Merge pull request #10 from seifghazi/sg/subagent-support

Route Subagents to OpenAI Models

commit 7574829604

23 changed files with 1970 additions and 337 deletions
.gitignore (vendored, 6 changes)

@@ -38,4 +38,8 @@ coverage/
 
 # Temporary files
 tmp/
 temp/
+
+
+# Config
+config.yaml
Makefile (2 changes)

@@ -1,4 +1,4 @@
 .PHONY: all build run clean install dev
 
 # Default target
 all: install build
README.md (106 changes)

@@ -2,18 +2,20 @@
 
-A dual-purpose monitoring solution that serves as both a proxy for Claude Code requests and a visualization dashboard for your Claude API conversations.
+A transparent proxy for capturing and visualizing in-flight Claude Code requests and conversations, with optional agent routing to different LLM providers.
 
 ## What It Does
 
-Claude Code Proxy serves two main purposes:
+Claude Code Proxy serves three main purposes:
 
 1. **Claude Code Proxy**: Intercepts and monitors requests from Claude Code (claude.ai/code) to the Anthropic API, allowing you to see what Claude Code is doing in real-time
 2. **Conversation Viewer**: Displays and analyzes your Claude API conversations with a beautiful web interface
+3. **Agent Routing (Optional)**: Routes specific Claude Code agents to different LLM providers (e.g., route the code-reviewer agent to GPT-4o)
 
 ## Features
 
 - **Transparent Proxy**: Routes Claude Code requests through the monitor without disruption
+- **Agent Routing (Optional)**: Map specific Claude Code agents to different LLM models
 - **Request Monitoring**: SQLite-based logging of all API interactions
 - **Live Dashboard**: Real-time visualization of requests and responses
 - **Conversation Analysis**: View full conversation threads with tool usage
@@ -36,9 +38,9 @@ Claude Code Proxy serves two main purposes:
 cd claude-code-proxy
 ```
 
-2. **Set up your environment variables**
+2. **Configure the proxy**
 ```bash
-cp .env.example .env
+cp config.yaml.example config.yaml
 ```
 
 3. **Install and run** (first time)
@@ -46,11 +48,6 @@
 make install # Install all dependencies
 make dev # Start both services
 ```
 
-Or use the script that does both:
-```bash
-./run.sh
-```
-
 4. **Subsequent runs** (after initial setup)
 ```bash
@@ -154,15 +151,86 @@ make help # Show all commands
 
 ## Configuration
 
-### Local Development
+### Basic Setup
 
-Create a `.env` file with:
-```
-PORT=3001
-DB_PATH=requests.db
-ANTHROPIC_FORWARD_URL=https://api.anthropic.com
-```
-
-See `.env.example` for all available options.
+Create a `config.yaml` file (or copy from `config.yaml.example`):
+```yaml
+server:
+  port: 3001
+
+providers:
+  anthropic:
+    base_url: "https://api.anthropic.com"
+
+  openai: # if enabling subagent routing
+    api_key: "your-openai-key" # Or set OPENAI_API_KEY env var
+
+storage:
+  db_path: "requests.db"
+```
+
+### Subagent Configuration (Optional)
+
+The proxy supports routing specific Claude Code agents to different LLM providers. This is an **optional** feature that's disabled by default.
+
+#### Enabling Subagent Routing
+
+1. **Enable the feature** in `config.yaml`:
+```yaml
+subagents:
+  enable: true # Set to true to enable subagent routing
+  mappings:
+    code-reviewer: "gpt-4o"
+    data-analyst: "o3"
+    doc-writer: "gpt-3.5-turbo"
+```
+
+2. **Set up your Claude Code agents** following Anthropic's official documentation:
+   - 📖 **[Claude Code Subagents Documentation](https://docs.anthropic.com/en/docs/claude-code/sub-agents)**
+
+3. **How it works**: When Claude Code uses a subagent that matches one of your mappings, the proxy will automatically route the request to the specified model instead of Claude.
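The mapping lookup behind that routing step can be sketched in Go. This is an illustrative sketch only: `routeModel` and its parameters are hypothetical names, not the proxy's actual `DetermineRoute` implementation, and it assumes the agent name has already been detected from the request.

```go
package main

import "fmt"

// routeModel sketches the core decision: if subagent routing is enabled and
// the detected agent has a mapping, re-target the request at the mapped
// model; otherwise keep the originally requested model.
func routeModel(enabled bool, mappings map[string]string, agent, originalModel string) string {
	if !enabled {
		return originalModel
	}
	if target, ok := mappings[agent]; ok {
		return target
	}
	return originalModel
}

func main() {
	mappings := map[string]string{"code-reviewer": "gpt-4o"}
	fmt.Println(routeModel(true, mappings, "code-reviewer", "claude-sonnet-4"))  // gpt-4o
	fmt.Println(routeModel(false, mappings, "code-reviewer", "claude-sonnet-4")) // claude-sonnet-4
}
```

Unmapped agents fall through to the original model, which is why the feature is safe to leave enabled with a partial mappings table.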
+
+### Practical Examples
+
+**Example 1: Code Review Agent → GPT-4o**
+```yaml
+# config.yaml
+subagents:
+  enable: true
+  mappings:
+    code-reviewer: "gpt-4o"
+```
+Use case: Route code review tasks to GPT-4o for faster responses while keeping complex coding tasks on Claude.
+
+**Example 2: Reasoning Agent → O3**
+```yaml
+# config.yaml
+subagents:
+  enable: true
+  mappings:
+    deep-reasoning: "o3"
+```
+Use case: Send complex reasoning tasks to O3 while using Claude for general coding.
+
+**Example 3: Multiple Agents**
+```yaml
+# config.yaml
+subagents:
+  enable: true
+  mappings:
+    streaming-systems-engineer: "o3"
+    frontend-developer: "gpt-4o-mini"
+    security-auditor: "gpt-4o"
+```
+Use case: Different specialists for different tasks, optimizing for speed/cost/quality.
+
+### Environment Variables
+
+Override config via environment:
+- `PORT` - Server port
+- `OPENAI_API_KEY` - OpenAI API key
+- `DB_PATH` - Database path
+- `SUBAGENT_MAPPINGS` - Comma-separated mappings (e.g., `"code-reviewer:gpt-4o,data-analyst:o3"`)
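The comma-separated `SUBAGENT_MAPPINGS` format could be parsed with a helper along these lines. `parseSubagentMappings` is a hypothetical name for illustration; the proxy's actual parser may handle edge cases differently.

```go
package main

import (
	"fmt"
	"strings"
)

// parseSubagentMappings splits a value like
// "code-reviewer:gpt-4o,data-analyst:o3" into an agent-to-model map,
// skipping malformed or empty pairs.
func parseSubagentMappings(raw string) map[string]string {
	mappings := make(map[string]string)
	for _, pair := range strings.Split(raw, ",") {
		parts := strings.SplitN(strings.TrimSpace(pair), ":", 2)
		if len(parts) == 2 && parts[0] != "" && parts[1] != "" {
			mappings[parts[0]] = parts[1]
		}
	}
	return mappings
}

func main() {
	fmt.Println(parseSubagentMappings("code-reviewer:gpt-4o,data-analyst:o3"))
}
```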
 
 ### Docker Environment Variables
@@ -216,12 +284,6 @@ claude-code-proxy/
 - Request/response body inspection
 - Conversation threading
 
-### Prompt Analysis
-- Automatic prompt grading
-- Best practices evaluation
-- Complexity assessment
-- Response quality metrics
-
 ### Web Dashboard
 - Real-time request streaming
 - Interactive request explorer
config.yaml.example (new file, 91 lines)

@@ -0,0 +1,91 @@
+# LLM Proxy Configuration Example
+# This file demonstrates all available configuration options
+# Copy this file to config.yaml and customize as needed
+
+# Server configuration
+server:
+  # Port to listen on (default: 3001)
+  port: 3001
+
+  # Timeout configurations
+  timeouts:
+    # Maximum duration for reading the entire request, including the body
+    read: 10m
+
+    # Maximum duration before timing out writes of the response
+    write: 10m
+
+    # Maximum amount of time to wait for the next request when keep-alives are enabled
+    idle: 10m
+
+# Provider configurations
+providers:
+  # Anthropic Claude configuration
+  anthropic:
+    # Base URL for Anthropic API (can be changed for custom endpoints)
+    base_url: "https://api.anthropic.com"
+
+    # Maximum number of retries for failed requests
+    max_retries: 3
+
+  # OpenAI configuration
+  openai:
+    # API key for OpenAI
+    # Can also be set via OPENAI_API_KEY environment variable
+    # api_key: "..."
+
+    # Base URL for OpenAI API (can be changed for custom endpoints)
+    # Can also be set via OPENAI_BASE_URL environment variable
+    # base_url: "https://api.openai.com"
+
+# Storage configuration
+storage:
+  # SQLite database path for storing request history
+  db_path: "requests.db"
+
+  # Directory for storing request files (if needed in future)
+  # requests_dir: "./requests"
+
+# Subagent Configuration (Optional)
+# Enable this feature if you want to route specific Claude Code agents to different LLM providers
+# For subagent setup instructions, see: https://docs.anthropic.com/en/docs/claude-code/sub-agents
+subagents:
+  # Enable subagent routing (default: false)
+  enable: false
+
+  # Maps subagent types to specific models
+  # Only used when enable: true
+  mappings:
+    # Code review specialist (example)
+    # code-reviewer: "gpt-4o"
+
+    # Data analysis expert (example)
+    # data-analyst: "o3"
+
+    # Documentation writer (example)
+    # doc-writer: "gpt-3.5-turbo"
+
+# Environment variable overrides:
+# The following environment variables will override the YAML configuration:
+#
+# Server:
+#   PORT - Server port
+#   READ_TIMEOUT - Read timeout duration
+#   WRITE_TIMEOUT - Write timeout duration
+#   IDLE_TIMEOUT - Idle timeout duration
+#
+# Anthropic:
+#   ANTHROPIC_FORWARD_URL - Anthropic base URL
+#   ANTHROPIC_VERSION - Anthropic API version
+#   ANTHROPIC_MAX_RETRIES - Maximum retries for Anthropic requests
+#
+# OpenAI:
+#   OPENAI_API_KEY - OpenAI API key
+#   OPENAI_BASE_URL - OpenAI base URL
+#
+# Storage:
+#   DB_PATH - Database file path
+#
+# Subagents:
+#   SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs
+#   Example: "code-reviewer:claude-3-5-sonnet"
@@ -15,6 +15,7 @@ import (
     "github.com/seifghazi/claude-code-monitor/internal/config"
     "github.com/seifghazi/claude-code-monitor/internal/handler"
     "github.com/seifghazi/claude-code-monitor/internal/middleware"
+    "github.com/seifghazi/claude-code-monitor/internal/provider"
     "github.com/seifghazi/claude-code-monitor/internal/service"
 )

@@ -26,6 +27,15 @@ func main() {
         logger.Fatalf("❌ Failed to load configuration: %v", err)
     }
 
+    // Initialize providers
+    providers := make(map[string]provider.Provider)
+    providers["anthropic"] = provider.NewAnthropicProvider(&cfg.Providers.Anthropic)
+    providers["openai"] = provider.NewOpenAIProvider(&cfg.Providers.OpenAI)
+
+    // Initialize model router
+    modelRouter := service.NewModelRouter(cfg, providers, logger)
+
+    // Use legacy anthropic service for backward compatibility
     anthropicService := service.NewAnthropicService(&cfg.Anthropic)
 
     // Use SQLite storage

@@ -35,7 +45,7 @@ func main() {
     }
     logger.Println("🗿 SQLite database ready")
 
-    h := handler.New(anthropicService, storageService, logger)
+    h := handler.New(anthropicService, storageService, logger, modelRouter)
 
     r := mux.NewRouter()

@@ -73,16 +83,12 @@ func main() {
     go func() {
         logger.Printf("🚀 Claude Code Monitor Server running on http://localhost:%s", cfg.Server.Port)
         logger.Printf("📡 API endpoints available at:")
-        logger.Printf(" - POST http://localhost:%s/v1/chat/completions (OpenAI format)", cfg.Server.Port)
         logger.Printf(" - POST http://localhost:%s/v1/messages (Anthropic format)", cfg.Server.Port)
         logger.Printf(" - GET http://localhost:%s/v1/models", cfg.Server.Port)
         logger.Printf(" - GET http://localhost:%s/health", cfg.Server.Port)
-        logger.Printf(" - POST http://localhost:%s/api/grade-prompt (Prompt grading)", cfg.Server.Port)
         logger.Printf("🎨 Web UI available at:")
         logger.Printf(" - GET http://localhost:%s/ (Request Visualizer)", cfg.Server.Port)
         logger.Printf(" - GET http://localhost:%s/api/requests (Request API)", cfg.Server.Port)
-        logger.Printf("🔍 All requests logged with comprehensive error handling")
-        logger.Printf("🎯 Auto prompt grading with Anthropic best practices")
 
         if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
             logger.Fatalf("❌ Server failed to start: %v", err)
go.mod

@@ -7,6 +7,7 @@ require (
     github.com/gorilla/mux v1.8.1
     github.com/joho/godotenv v1.5.1
     github.com/mattn/go-sqlite3 v1.14.28
+    gopkg.in/yaml.v3 v3.0.1
 )
 
 require github.com/felixge/httpsnoop v1.0.3 // indirect
go.sum

@@ -8,3 +8,7 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
 github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
 github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -7,21 +7,48 @@ import (
     "time"
 
     "github.com/joho/godotenv"
+    "gopkg.in/yaml.v3"
 )
 
 type Config struct {
-    Server    ServerConfig
+    Server    ServerConfig    `yaml:"server"`
+    Providers ProvidersConfig `yaml:"providers"`
+    Storage   StorageConfig   `yaml:"storage"`
+    Subagents SubagentsConfig `yaml:"subagents"`
     Anthropic AnthropicConfig
-    Storage   StorageConfig
 }
 
 type ServerConfig struct {
-    Port         string
+    Port     string         `yaml:"port"`
+    Timeouts TimeoutsConfig `yaml:"timeouts"`
+    // Legacy fields
     ReadTimeout  time.Duration
     WriteTimeout time.Duration
     IdleTimeout  time.Duration
 }
 
+type TimeoutsConfig struct {
+    Read  string `yaml:"read"`
+    Write string `yaml:"write"`
+    Idle  string `yaml:"idle"`
+}
+
+type ProvidersConfig struct {
+    Anthropic AnthropicProviderConfig `yaml:"anthropic"`
+    OpenAI    OpenAIProviderConfig    `yaml:"openai"`
+}
+
+type AnthropicProviderConfig struct {
+    BaseURL    string `yaml:"base_url"`
+    Version    string `yaml:"version"`
+    MaxRetries int    `yaml:"max_retries"`
+}
+
+type OpenAIProviderConfig struct {
+    BaseURL string `yaml:"base_url"`
+    APIKey  string `yaml:"api_key"`
+}
+
 type AnthropicConfig struct {
     BaseURL string
     Version string
@@ -29,8 +56,13 @@ type AnthropicConfig struct {
 }
 
 type StorageConfig struct {
-    RequestsDir string
-    DBPath      string
+    RequestsDir string `yaml:"requests_dir"`
+    DBPath      string `yaml:"db_path"`
+}
+
+type SubagentsConfig struct {
+    Enable   bool              `yaml:"enable"`
+    Mappings map[string]string `yaml:"mappings"`
 }
 
 func Load() (*Config, error) {
@@ -45,26 +77,132 @@ func Load() (*Config, error) {
         }
     }
 
+    // Start with default configuration
     cfg := &Config{
         Server: ServerConfig{
-            Port:         getEnv("PORT", "3001"),
-            ReadTimeout:  getDuration("READ_TIMEOUT", 600*time.Second),  // Increased to 10 minutes
-            WriteTimeout: getDuration("WRITE_TIMEOUT", 600*time.Second), // Increased to 10 minutes
-            IdleTimeout:  getDuration("IDLE_TIMEOUT", 600*time.Second),  // Increased to 10 minutes
+            Port:         "3001",
+            ReadTimeout:  600 * time.Second,
+            WriteTimeout: 600 * time.Second,
+            IdleTimeout:  600 * time.Second,
         },
-        Anthropic: AnthropicConfig{
-            BaseURL:    getEnv("ANTHROPIC_FORWARD_URL", "https://api.anthropic.com"),
-            Version:    getEnv("ANTHROPIC_VERSION", "2023-06-01"),
-            MaxRetries: getInt("ANTHROPIC_MAX_RETRIES", 3),
+        Providers: ProvidersConfig{
+            Anthropic: AnthropicProviderConfig{
+                BaseURL:    "https://api.anthropic.com",
+                Version:    "2023-06-01",
+                MaxRetries: 3,
+            },
+            OpenAI: OpenAIProviderConfig{
+                BaseURL: "https://api.openai.com",
+                APIKey:  "",
+            },
         },
         Storage: StorageConfig{
-            DBPath: getEnv("DB_PATH", "requests.db"),
+            DBPath: "requests.db",
+        },
+        Subagents: SubagentsConfig{
+            Enable:   false,
+            Mappings: make(map[string]string),
         },
     }
 
+    // Try to load config.yaml from the project root
+    // The proxy binary is in proxy/ directory, config.yaml is in the parent
+    configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml")
+
+    // If that doesn't work, try relative to current directory
+    if _, err := os.Stat(configPath); err != nil {
+        // Try common locations relative to where the binary might be run
+        for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} {
+            if _, err := os.Stat(tryPath); err == nil {
+                configPath = tryPath
+                break
+            }
+        }
+    }
+
+    cfg.loadFromFile(configPath)
+
+    // Apply environment variable overrides AFTER loading from file
+    if envPort := os.Getenv("PORT"); envPort != "" {
+        cfg.Server.Port = envPort
+    }
+    if envTimeout := os.Getenv("READ_TIMEOUT"); envTimeout != "" {
+        cfg.Server.ReadTimeout = getDuration("READ_TIMEOUT", cfg.Server.ReadTimeout)
+    }
+    if envTimeout := os.Getenv("WRITE_TIMEOUT"); envTimeout != "" {
+        cfg.Server.WriteTimeout = getDuration("WRITE_TIMEOUT", cfg.Server.WriteTimeout)
+    }
+    if envTimeout := os.Getenv("IDLE_TIMEOUT"); envTimeout != "" {
+        cfg.Server.IdleTimeout = getDuration("IDLE_TIMEOUT", cfg.Server.IdleTimeout)
+    }
+
+    // Override Anthropic settings
+    if envURL := os.Getenv("ANTHROPIC_FORWARD_URL"); envURL != "" {
+        cfg.Providers.Anthropic.BaseURL = envURL
+    }
+    if envVersion := os.Getenv("ANTHROPIC_VERSION"); envVersion != "" {
+        cfg.Providers.Anthropic.Version = envVersion
+    }
+    if envRetries := os.Getenv("ANTHROPIC_MAX_RETRIES"); envRetries != "" {
+        cfg.Providers.Anthropic.MaxRetries = getInt("ANTHROPIC_MAX_RETRIES", cfg.Providers.Anthropic.MaxRetries)
+    }
+
+    // Override OpenAI settings
+    if envURL := os.Getenv("OPENAI_BASE_URL"); envURL != "" {
+        cfg.Providers.OpenAI.BaseURL = envURL
+    }
+    if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" {
+        cfg.Providers.OpenAI.APIKey = envKey
+    }
+
+    // Override storage settings
+    if envPath := os.Getenv("DB_PATH"); envPath != "" {
+        cfg.Storage.DBPath = envPath
+    }
+
+    // Sync legacy Anthropic config
+    cfg.Anthropic = AnthropicConfig{
+        BaseURL:    cfg.Providers.Anthropic.BaseURL,
+        Version:    cfg.Providers.Anthropic.Version,
+        MaxRetries: cfg.Providers.Anthropic.MaxRetries,
+    }
+
+    // After loading from file, apply any timeout conversions if needed
+    if cfg.Server.Timeouts.Read != "" {
+        if duration, err := time.ParseDuration(cfg.Server.Timeouts.Read); err == nil {
+            cfg.Server.ReadTimeout = duration
+        }
+    }
+    if cfg.Server.Timeouts.Write != "" {
+        if duration, err := time.ParseDuration(cfg.Server.Timeouts.Write); err == nil {
+            cfg.Server.WriteTimeout = duration
+        }
+    }
+    if cfg.Server.Timeouts.Idle != "" {
+        if duration, err := time.ParseDuration(cfg.Server.Timeouts.Idle); err == nil {
+            cfg.Server.IdleTimeout = duration
+        }
+    }
+
+    // Sync legacy Anthropic config with new structure
+    cfg.Anthropic = AnthropicConfig{
+        BaseURL:    cfg.Providers.Anthropic.BaseURL,
+        Version:    cfg.Providers.Anthropic.Version,
+        MaxRetries: cfg.Providers.Anthropic.MaxRetries,
+    }
+
     return cfg, nil
 }
 
+func (c *Config) loadFromFile(path string) error {
+    data, err := os.ReadFile(path)
+    if err != nil {
+        return err
+    }
+
+    return yaml.Unmarshal(data, c)
+}
+
 func getEnv(key, defaultValue string) string {
     if value := os.Getenv(key); value != "" {
         return value
@@ -2,6 +2,7 @@ package handler
 
 import (
     "bufio"
+    "bytes"
    "crypto/rand"
     "encoding/hex"
     "encoding/json"
@@ -25,35 +26,37 @@ type Handler struct {
     anthropicService    service.AnthropicService
     storageService      service.StorageService
     conversationService service.ConversationService
+    modelRouter         *service.ModelRouter
+    logger              *log.Logger
 }
 
-func New(anthropicService service.AnthropicService, storageService service.StorageService, logger *log.Logger) *Handler {
+func New(anthropicService service.AnthropicService, storageService service.StorageService, logger *log.Logger, modelRouter *service.ModelRouter) *Handler {
     conversationService := service.NewConversationService()
 
     return &Handler{
         anthropicService:    anthropicService,
         storageService:      storageService,
         conversationService: conversationService,
+        modelRouter:         modelRouter,
+        logger:              logger,
     }
 }
 
 func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
-    log.Println("🤖 Chat completion request received (OpenAI format)")
 
     // This endpoint is for compatibility but we're an Anthropic proxy
     // Return a helpful error message
     writeErrorResponse(w, "This is an Anthropic proxy. Please use the /v1/messages endpoint instead of /v1/chat/completions", http.StatusBadRequest)
 }
 
 func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
-    log.Println("🤖 Messages request received (Anthropic format)")
+    // Get body bytes from context (set by middleware)
 
     bodyBytes := getBodyBytes(r)
     if bodyBytes == nil {
         http.Error(w, "Error reading request body", http.StatusBadRequest)
         return
     }
 
+    // Parse the request
     var req model.AnthropicRequest
     if err := json.Unmarshal(bodyBytes, &req); err != nil {
         log.Printf("❌ Error parsing JSON: %v", err)
@@ -64,27 +67,55 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
     requestID := generateRequestID()
     startTime := time.Now()
 
-    // Create request log
+    // Use model router to determine provider and route the request
+    decision, err := h.modelRouter.DetermineRoute(&req)
+    if err != nil {
+        log.Printf("❌ Error routing request: %v", err)
+        writeErrorResponse(w, "Failed to route request", http.StatusInternalServerError)
+        return
+    }
+
+    // Create request log with routing information
     requestLog := &model.RequestLog{
         RequestID: requestID,
         Timestamp: time.Now().Format(time.RFC3339),
         Method:    r.Method,
-        Endpoint:  "/v1/messages",
+        Endpoint:  r.URL.Path,
         Headers:   SanitizeHeaders(r.Header),
         Body:      req,
-        Model:     req.Model,
+        Model:         decision.OriginalModel,
+        OriginalModel: decision.OriginalModel,
+        RoutedModel:   decision.TargetModel,
         UserAgent:   r.Header.Get("User-Agent"),
         ContentType: r.Header.Get("Content-Type"),
     }
 
     if _, err := h.storageService.SaveRequest(requestLog); err != nil {
         log.Printf("❌ Error saving request: %v", err)
     }
 
-    // Forward the request to Anthropic
-    resp, err := h.anthropicService.ForwardRequest(r.Context(), r)
+    // If the model was changed by routing, update the request body
+    if decision.TargetModel != decision.OriginalModel {
+        req.Model = decision.TargetModel
+
+        // Re-marshal the request with the updated model
+        updatedBodyBytes, err := json.Marshal(req)
+        if err != nil {
+            log.Printf("❌ Error marshaling updated request: %v", err)
+            writeErrorResponse(w, "Failed to process request", http.StatusInternalServerError)
+            return
+        }
+
+        // Update the request body
+        r.Body = io.NopCloser(bytes.NewReader(updatedBodyBytes))
+        r.ContentLength = int64(len(updatedBodyBytes))
+        r.Header.Set("Content-Length", fmt.Sprintf("%d", len(updatedBodyBytes)))
+    }
+
+    // Forward the request to the selected provider
+    resp, err := decision.Provider.ForwardRequest(r.Context(), r)
     if err != nil {
-        log.Printf("❌ Error forwarding to Anthropic API: %v", err)
+        log.Printf("❌ Error forwarding to %s API: %v", decision.Provider.Name(), err)
         writeErrorResponse(w, "Failed to forward request", http.StatusInternalServerError)
         return
     }
@@ -99,7 +130,6 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
 }
 
 func (h *Handler) Models(w http.ResponseWriter, r *http.Request) {
-    log.Println("📋 Models list requested")
 
     response := &model.ModelsResponse{
         Object: "list",

@@ -140,7 +170,7 @@ func (h *Handler) Health(w http.ResponseWriter, r *http.Request) {
 func (h *Handler) UI(w http.ResponseWriter, r *http.Request) {
     htmlContent, err := os.ReadFile("index.html")
     if err != nil {
-        log.Printf("❌ Error reading index.html: %v", err)
+        // Error reading index.html
         http.Error(w, "UI not available", http.StatusNotFound)
         return
     }

@@ -166,8 +196,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
         modelFilter = "all"
     }
 
-    log.Printf("📊 GetRequests called - page: %d, limit: %d, modelFilter: %s", page, limit, modelFilter)
-
     // Get all requests with model filter applied at storage level
     allRequests, err := h.storageService.GetAllRequests(modelFilter)
     if err != nil {

@@ -176,8 +204,6 @@
         return
     }
 
-    log.Printf("📊 Got %d requests from storage (filter: %s)", len(allRequests), modelFilter)
-
     // Convert pointers to values for consistency
     requests := make([]model.RequestLog, len(allRequests))
     for i, req := range allRequests {

@@ -201,8 +227,6 @@
|
||||||
requests = requests[start:end]
|
requests = requests[start:end]
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("📊 Returning %d requests after pagination", len(requests))
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(struct {
|
json.NewEncoder(w).Encode(struct {
|
||||||
Requests []model.RequestLog `json:"requests"`
|
Requests []model.RequestLog `json:"requests"`
|
||||||
|
|
@ -214,17 +238,14 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteRequests(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) DeleteRequests(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Println("🗑️ Clearing request history")
|
|
||||||
|
|
||||||
clearedCount, err := h.storageService.ClearRequests()
|
clearedCount, err := h.storageService.ClearRequests()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("❌ Error clearing requests: %v", err)
|
log.Printf("Error clearing requests: %v", err)
|
||||||
writeErrorResponse(w, "Error clearing request history", http.StatusInternalServerError)
|
writeErrorResponse(w, "Error clearing request history", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("✅ Deleted %d request files", clearedCount)
|
|
||||||
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"message": "Request history cleared",
|
"message": "Request history cleared",
|
||||||
"deleted": clearedCount,
|
"deleted": clearedCount,
|
||||||
|
|
@ -238,7 +259,6 @@ func (h *Handler) NotFound(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
|
func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
|
||||||
log.Println("🌊 Streaming response detected, forwarding stream...")
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
|
@ -298,7 +318,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture usage data and metadata from message_start event
|
// Capture metadata from message_start event
|
||||||
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_start" {
|
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_start" {
|
||||||
if message, ok := genericEvent["message"].(map[string]interface{}); ok {
|
if message, ok := genericEvent["message"].(map[string]interface{}); ok {
|
||||||
// Capture message metadata
|
// Capture message metadata
|
||||||
|
|
@ -311,51 +331,40 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
||||||
if reason, ok := message["stop_reason"].(string); ok {
|
if reason, ok := message["stop_reason"].(string); ok {
|
||||||
stopReason = reason
|
stopReason = reason
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture initial usage data from message_start
|
|
||||||
if usage, ok := message["usage"].(map[string]interface{}); ok {
|
|
||||||
finalUsage = &model.AnthropicUsage{}
|
|
||||||
if inputTokens, ok := usage["input_tokens"].(float64); ok {
|
|
||||||
finalUsage.InputTokens = int(inputTokens)
|
|
||||||
}
|
|
||||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
|
||||||
finalUsage.OutputTokens = int(outputTokens)
|
|
||||||
}
|
|
||||||
if cacheCreation, ok := usage["cache_creation_input_tokens"].(float64); ok {
|
|
||||||
finalUsage.CacheCreationInputTokens = int(cacheCreation)
|
|
||||||
}
|
|
||||||
if cacheRead, ok := usage["cache_read_input_tokens"].(float64); ok {
|
|
||||||
finalUsage.CacheReadInputTokens = int(cacheRead)
|
|
||||||
}
|
|
||||||
if tier, ok := usage["service_tier"].(string); ok {
|
|
||||||
finalUsage.ServiceTier = tier
|
|
||||||
}
|
|
||||||
log.Printf("📊 Captured initial usage from message_start: %+v", finalUsage)
|
|
||||||
} else {
|
|
||||||
log.Printf("⚠️ No usage data found in message_start event")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update output tokens from message_delta event
|
// Capture usage data from message_delta event
|
||||||
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_delta" {
|
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_delta" {
|
||||||
// Usage is at top level for message_delta events
|
// Usage is at top level for message_delta events
|
||||||
if usage, ok := genericEvent["usage"].(map[string]interface{}); ok {
|
if usage, ok := genericEvent["usage"].(map[string]interface{}); ok {
|
||||||
if finalUsage != nil {
|
// Create finalUsage if it doesn't exist yet
|
||||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
if finalUsage == nil {
|
||||||
finalUsage.OutputTokens = int(outputTokens)
|
finalUsage = &model.AnthropicUsage{}
|
||||||
log.Printf("📊 Updated output tokens from message_delta: %d", int(outputTokens))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Printf("⚠️ finalUsage is nil when trying to update from message_delta usage")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture all usage fields
|
||||||
|
if inputTokens, ok := usage["input_tokens"].(float64); ok {
|
||||||
|
finalUsage.InputTokens = int(inputTokens)
|
||||||
|
}
|
||||||
|
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
||||||
|
finalUsage.OutputTokens = int(outputTokens)
|
||||||
|
}
|
||||||
|
if cacheCreation, ok := usage["cache_creation_input_tokens"].(float64); ok {
|
||||||
|
finalUsage.CacheCreationInputTokens = int(cacheCreation)
|
||||||
|
}
|
||||||
|
if cacheRead, ok := usage["cache_read_input_tokens"].(float64); ok {
|
||||||
|
finalUsage.CacheReadInputTokens = int(cacheRead)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse as structured event for content processing
|
// Parse as structured event for content processing
|
||||||
var event model.StreamingEvent
|
var event model.StreamingEvent
|
||||||
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
||||||
continue // Skip if structured parsing fails, but we already got the usage data above
|
// Skip if structured parsing fails, but we already got the usage data above
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
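The rewritten `message_delta` handling above hinges on two details: `encoding/json` decodes numbers inside a `map[string]interface{}` as `float64`, so each field needs a type assertion before conversion; and `finalUsage` is now created lazily, so usage that arrives only in a delta event is no longer dropped. A self-contained sketch of the same pattern (the struct is trimmed; the real `model.AnthropicUsage` also carries a service-tier field):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for model.AnthropicUsage, for illustration only.
type anthropicUsage struct {
	InputTokens              int
	OutputTokens             int
	CacheCreationInputTokens int
	CacheReadInputTokens     int
}

// captureUsage mirrors the message_delta handling above: the struct is
// created lazily on first sight of usage data, and each JSON number is
// type-asserted as float64 before conversion to int.
func captureUsage(final *anthropicUsage, usage map[string]interface{}) *anthropicUsage {
	if final == nil {
		final = &anthropicUsage{}
	}
	if v, ok := usage["input_tokens"].(float64); ok {
		final.InputTokens = int(v)
	}
	if v, ok := usage["output_tokens"].(float64); ok {
		final.OutputTokens = int(v)
	}
	if v, ok := usage["cache_creation_input_tokens"].(float64); ok {
		final.CacheCreationInputTokens = int(v)
	}
	if v, ok := usage["cache_read_input_tokens"].(float64); ok {
		final.CacheReadInputTokens = int(v)
	}
	return final
}

func main() {
	var event map[string]interface{}
	json.Unmarshal([]byte(`{"type":"message_delta","usage":{"input_tokens":12,"output_tokens":34}}`), &event)
	usage, _ := event["usage"].(map[string]interface{})
	final := captureUsage(nil, usage)
	fmt.Println(final.InputTokens, final.OutputTokens) // 12 34
}
```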
 
@@ -409,9 +418,6 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
 	// Add usage data if we captured it
 	if finalUsage != nil {
 		responseBody["usage"] = finalUsage
-		log.Printf("📊 Final usage data being stored: %+v", finalUsage)
-	} else {
-		log.Printf("⚠️ No usage data captured for streaming response - finalUsage is nil")
 	}
 
 	// Marshal to JSON for storage
@@ -436,10 +442,6 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
 }
 
 func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
-	// Log response headers for debugging
-	log.Printf("📋 Response headers: Content-Encoding=%s, Content-Type=%s, Status=%d",
-		resp.Header.Get("Content-Encoding"), resp.Header.Get("Content-Type"), resp.StatusCode)
-
 	responseBytes, err := io.ReadAll(resp.Body)
 	if err != nil {
 		log.Printf("❌ Error reading Anthropic response: %v", err)
@@ -447,11 +449,6 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.R
 		return
 	}
 
-	// Log first few bytes to help debug compression issues
-	if len(responseBytes) > 0 {
-		log.Printf("📊 Response body starts with: %x (first 10 bytes)", responseBytes[:min(10, len(responseBytes))])
-	}
-
 	responseLog := &model.ResponseLog{
 		StatusCode: resp.StatusCode,
 		Headers:    SanitizeHeaders(resp.Header),
@@ -466,7 +463,6 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.R
 	if err := json.Unmarshal(responseBytes, &anthropicResp); err == nil {
 		// Successfully parsed - store the structured response
 		responseLog.Body = json.RawMessage(responseBytes)
-		log.Printf("✅ Successfully parsed Anthropic response")
 	} else {
 		// If parsing fails, store as text but log the error
 		log.Printf("⚠️ Failed to parse Anthropic response: %v", err)
@@ -491,7 +487,6 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.R
 		return
 	}
 
-	log.Println("✅ Successfully forwarded request to Anthropic API")
 	w.Header().Set("Content-Type", "application/json")
 	w.Write(responseBytes)
 }
@@ -597,7 +592,6 @@ func extractTextFromMessage(message json.RawMessage) string {
 // Conversation handlers
 
 func (h *Handler) GetConversations(w http.ResponseWriter, r *http.Request) {
-	log.Println("📚 Getting conversations from Claude projects")
 
 	conversations, err := h.conversationService.GetConversations()
 	if err != nil {
@@ -687,8 +681,6 @@ func (h *Handler) GetConversationByID(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	log.Printf("📖 Getting conversation %s from project %s", sessionID, projectPath)
-
 	conversation, err := h.conversationService.GetConversation(projectPath, sessionID)
 	if err != nil {
 		log.Printf("❌ Error getting conversation: %v", err)
@@ -706,8 +698,6 @@ func (h *Handler) GetConversationsByProject(w http.ResponseWriter, r *http.Reque
 		return
 	}
 
-	log.Printf("📁 Getting conversations for project %s", projectPath)
-
 	conversations, err := h.conversationService.GetConversationsByProject(projectPath)
 	if err != nil {
 		log.Printf("❌ Error getting project conversations: %v", err)
 
@@ -3,11 +3,10 @@ package middleware
 import (
 	"bytes"
 	"context"
-	"encoding/json"
+	"fmt"
 	"io"
 	"log"
 	"net/http"
-	"strings"
 	"time"
 
 	"github.com/seifghazi/claude-code-monitor/internal/model"
@@ -16,11 +15,10 @@ import (
 func Logging(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		start := time.Now()
-		log.Printf("%s - %s %s", start.Format(time.RFC3339), r.Method, r.URL.Path)
-		log.Printf("Headers: %s", formatHeaders(r.Header))
 
+		// For POST requests with body, read and store the bytes
 		var bodyBytes []byte
-		if r.Body != nil {
+		if r.Body != nil && (r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH") {
 			var err error
 			bodyBytes, err = io.ReadAll(r.Body)
 			if err != nil {
@@ -30,64 +28,29 @@ func Logging(next http.Handler) http.Handler {
 			}
 			r.Body.Close()
 			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-		}
 
-		ctx := context.WithValue(r.Context(), model.BodyBytesKey, bodyBytes)
-		r = r.WithContext(ctx)
-		log.Printf("Body length: %d bytes", len(bodyBytes))
-		if len(bodyBytes) > 0 {
-			logRequestBody(bodyBytes)
+			// Store raw bytes in context for handler to use
+			ctx := context.WithValue(r.Context(), model.BodyBytesKey, bodyBytes)
+			r = r.WithContext(ctx)
 		}
-		log.Println("---")
 
 		wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
 		next.ServeHTTP(wrapped, r)
 
 		duration := time.Since(start)
-		log.Printf("Response: %d %s (took %v)", wrapped.statusCode, http.StatusText(wrapped.statusCode), duration)
+		statusColor := getStatusColor(wrapped.statusCode)
 
+		log.Printf("%s %s %s%d%s %s (%s)",
+			r.Method,
+			r.URL.Path,
+			statusColor,
+			wrapped.statusCode,
+			colorReset,
+			http.StatusText(wrapped.statusCode),
+			formatDuration(duration))
 	})
 }
 
-func formatHeaders(headers http.Header) string {
-	headerMap := make(map[string][]string)
-	for k, v := range headers {
-		headerMap[k] = sanitizeHeaderValue(k, v)
-	}
-	headerBytes, _ := json.MarshalIndent(headerMap, "", "  ")
-	return string(headerBytes)
-}
-
-func sanitizeHeaderValue(key string, values []string) []string {
-	lowerKey := strings.ToLower(key)
-	sensitiveHeaders := []string{
-		"x-api-key",
-		"api-key",
-		"authorization",
-		"anthropic-api-key",
-		"openai-api-key",
-		"bearer",
-	}
-
-	for _, sensitive := range sensitiveHeaders {
-		if strings.Contains(lowerKey, sensitive) {
-			return []string{"[REDACTED]"}
-		}
-	}
-	return values
-}
-
-func logRequestBody(bodyBytes []byte) {
-	var bodyJSON interface{}
-	if err := json.Unmarshal(bodyBytes, &bodyJSON); err == nil {
-		bodyStr, _ := json.MarshalIndent(bodyJSON, "", "  ")
-		log.Printf("Body: %s", string(bodyStr))
-	} else {
-		log.Printf("❌ Failed to parse body as JSON: %v", err)
-		log.Printf("Raw body: %s", string(bodyBytes))
-	}
-}
-
 type responseWriter struct {
 	http.ResponseWriter
 	statusCode int
@@ -97,3 +60,37 @@ func (rw *responseWriter) WriteHeader(code int) {
 	rw.statusCode = code
 	rw.ResponseWriter.WriteHeader(code)
 }
+
+// ANSI color codes
+const (
+	colorReset  = "\033[0m"
+	colorGreen  = "\033[32m"
+	colorYellow = "\033[33m"
+	colorRed    = "\033[31m"
+	colorBlue   = "\033[34m"
+	colorCyan   = "\033[36m"
+)
+
+func getStatusColor(status int) string {
+	switch {
+	case status >= 200 && status < 300:
+		return colorGreen
+	case status >= 300 && status < 400:
+		return colorBlue
+	case status >= 400 && status < 500:
+		return colorYellow
+	case status >= 500:
+		return colorRed
+	default:
+		return colorReset
+	}
+}
+
+func formatDuration(d time.Duration) string {
+	if d < time.Millisecond {
+		return fmt.Sprintf("%dµs", d.Microseconds())
+	} else if d < time.Second {
+		return fmt.Sprintf("%dms", d.Milliseconds())
+	}
+	return fmt.Sprintf("%.2fs", d.Seconds())
+}
 
@@ -25,17 +25,19 @@ type CriteriaScore struct {
 }
 
 type RequestLog struct {
 	RequestID   string              `json:"requestId"`
 	Timestamp   string              `json:"timestamp"`
 	Method      string              `json:"method"`
 	Endpoint    string              `json:"endpoint"`
 	Headers     map[string][]string `json:"headers"`
 	Body        interface{}         `json:"body"`
 	Model       string              `json:"model,omitempty"`
-	UserAgent   string              `json:"userAgent"`
-	ContentType string              `json:"contentType"`
-	PromptGrade *PromptGrade        `json:"promptGrade,omitempty"`
-	Response    *ResponseLog        `json:"response,omitempty"`
+	OriginalModel string            `json:"originalModel,omitempty"`
+	RoutedModel   string            `json:"routedModel,omitempty"`
+	UserAgent     string            `json:"userAgent"`
+	ContentType   string            `json:"contentType"`
+	PromptGrade   *PromptGrade      `json:"promptGrade,omitempty"`
+	Response      *ResponseLog      `json:"response,omitempty"`
 }
 
 type ResponseLog struct {
@@ -129,14 +131,9 @@ type Tool struct {
 }
 
 type InputSchema struct {
 	Type       string                 `json:"type"`
-	Properties map[string]Property    `json:"properties"`
+	Properties map[string]interface{} `json:"properties"`
 	Required   []string               `json:"required,omitempty"`
-}
-
-type Property struct {
-	Type        string `json:"type"`
-	Description string `json:"description"`
 }
 
 type AnthropicRequest struct {
@@ -147,6 +144,7 @@ type AnthropicRequest struct {
 	System      []AnthropicSystemMessage `json:"system,omitempty"`
 	Stream      bool                     `json:"stream,omitempty"`
 	Tools       []Tool                   `json:"tools,omitempty"`
+	ToolChoice  interface{}              `json:"tool_choice,omitempty"`
 }
 
 type ModelsResponse struct {
proxy/internal/provider/anthropic.go (new file, 131 lines)
@@ -0,0 +1,131 @@
+package provider
+
+import (
+	"compress/gzip"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"path"
+	"strings"
+	"time"
+
+	"github.com/seifghazi/claude-code-monitor/internal/config"
+)
+
+type AnthropicProvider struct {
+	client *http.Client
+	config *config.AnthropicProviderConfig
+}
+
+func NewAnthropicProvider(cfg *config.AnthropicProviderConfig) Provider {
+	return &AnthropicProvider{
+		client: &http.Client{
+			Timeout: 300 * time.Second, // 5 minutes timeout
+		},
+		config: cfg,
+	}
+}
+
+func (p *AnthropicProvider) Name() string {
+	return "anthropic"
+}
+
+func (p *AnthropicProvider) ForwardRequest(ctx context.Context, originalReq *http.Request) (*http.Response, error) {
+	// Clone the request to avoid modifying the original
+	proxyReq := originalReq.Clone(ctx)
+
+	// Parse the configured base URL
+	baseURL, err := url.Parse(p.config.BaseURL)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse base URL '%s': %w", p.config.BaseURL, err)
+	}
+
+	if baseURL.Scheme == "" || baseURL.Host == "" {
+		return nil, fmt.Errorf("invalid base URL, scheme and host are required: %s", p.config.BaseURL)
+	}
+
+	// Update the destination URL
+	proxyReq.URL.Scheme = baseURL.Scheme
+	proxyReq.URL.Host = baseURL.Host
+	proxyReq.URL.Path = path.Join(baseURL.Path, originalReq.URL.Path)
+
+	// Preserve query parameters
+	proxyReq.URL.RawQuery = originalReq.URL.RawQuery
+
+	// Update request headers
+	proxyReq.RequestURI = ""
+	proxyReq.Host = baseURL.Host
+
+	// Remove hop-by-hop headers
+	removeHopByHopHeaders(proxyReq.Header)
+
+	// Add required headers if not present
+	if proxyReq.Header.Get("anthropic-version") == "" {
+		proxyReq.Header.Set("anthropic-version", p.config.Version)
+	}
+
+	// Support gzip encoding
+	proxyReq.Header.Set("Accept-Encoding", "gzip")
+
+	// Forward the request
+	resp, err := p.client.Do(proxyReq)
+	if err != nil {
+		return nil, fmt.Errorf("failed to forward request: %w", err)
+	}
+
+	// Handle gzip-encoded responses
+	if resp.Header.Get("Content-Encoding") == "gzip" {
+		resp.Header.Del("Content-Encoding")
+		resp.Header.Del("Content-Length")
+		gzipReader, err := gzip.NewReader(resp.Body)
+		if err != nil {
+			resp.Body.Close()
+			return nil, fmt.Errorf("failed to create gzip reader: %w", err)
+		}
+		resp.Body = &gzipResponseBody{
+			Reader: gzipReader,
+			closer: resp.Body,
+		}
+	}
+
+	return resp, nil
+}
+
+type gzipResponseBody struct {
+	io.Reader
+	closer io.Closer
+}
+
+func (g *gzipResponseBody) Close() error {
+	if gzReader, ok := g.Reader.(*gzip.Reader); ok {
+		gzReader.Close()
+	}
+	return g.closer.Close()
+}
+
+func removeHopByHopHeaders(header http.Header) {
+	hopByHopHeaders := []string{
+		"Connection",
+		"Keep-Alive",
+		"Proxy-Authenticate",
+		"Proxy-Authorization",
+		"TE",
+		"Trailers",
+		"Transfer-Encoding",
+		"Upgrade",
+	}
+
+	for _, h := range hopByHopHeaders {
+		header.Del(h)
+	}
+
+	// Remove any headers specified in the Connection header
+	if connection := header.Get("Connection"); connection != "" {
+		for _, h := range strings.Split(connection, ",") {
+			header.Del(strings.TrimSpace(h))
+		}
+		header.Del("Connection")
+	}
+}
722
proxy/internal/provider/openai.go
Normal file
722
proxy/internal/provider/openai.go
Normal file
|
|
@ -0,0 +1,722 @@
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||||
|
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIProvider struct {
|
||||||
|
client *http.Client
|
||||||
|
config *config.OpenAIProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenAIProvider(cfg *config.OpenAIProviderConfig) Provider {
|
||||||
|
return &OpenAIProvider{
|
||||||
|
client: &http.Client{
|
||||||
|
Timeout: 300 * time.Second, // 5 minutes timeout
|
||||||
|
},
|
||||||
|
config: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) Name() string {
|
||||||
|
return "openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.Request) (*http.Response, error) {
|
||||||
|
// First, we need to convert the Anthropic request to OpenAI format
|
||||||
|
bodyBytes, err := io.ReadAll(originalReq.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
}
|
||||||
|
originalReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
|
var anthropicReq model.AnthropicRequest
|
||||||
|
if err := json.Unmarshal(bodyBytes, &anthropicReq); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse anthropic request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to OpenAI format
|
||||||
|
openAIReq := convertAnthropicToOpenAI(&anthropicReq)
|
||||||
|
newBodyBytes, err := json.Marshal(openAIReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal openai request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone the request with new body
|
||||||
|
proxyReq := originalReq.Clone(ctx)
|
||||||
|
proxyReq.Body = io.NopCloser(bytes.NewReader(newBodyBytes))
|
||||||
|
proxyReq.ContentLength = int64(len(newBodyBytes))
|
||||||
|
|
||||||
|
// Parse the configured base URL
|
||||||
|
baseURL, err := url.Parse(p.config.BaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse base URL '%s': %w", p.config.BaseURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the destination URL for OpenAI
|
||||||
|
proxyReq.URL.Scheme = baseURL.Scheme
|
||||||
|
proxyReq.URL.Host = baseURL.Host
|
||||||
|
proxyReq.URL.Path = "/v1/chat/completions" // OpenAI endpoint
|
||||||
|
|
||||||
|
// Update request headers
|
||||||
|
proxyReq.RequestURI = ""
|
||||||
|
proxyReq.Host = baseURL.Host
|
||||||
|
|
||||||
|
// Remove Anthropic-specific headers
|
||||||
|
proxyReq.Header.Del("anthropic-version")
|
||||||
|
proxyReq.Header.Del("x-api-key")
|
||||||
|
|
||||||
|
// Add OpenAI headers
|
||||||
|
if p.config.APIKey != "" {
		proxyReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
	}
	proxyReq.Header.Set("Content-Type", "application/json")

	// Forward the request
	resp, err := p.client.Do(proxyReq)
	if err != nil {
		return nil, fmt.Errorf("failed to forward request: %w", err)
	}

	// Check for error responses
	if resp.StatusCode >= 400 {
		// Read the error body for debugging
		errorBody, _ := io.ReadAll(resp.Body)
		resp.Body.Close()

		// OpenAI API error - return it to the client in Anthropic format
		errorResp := map[string]interface{}{
			"type": "error",
			"error": map[string]interface{}{
				"type":    "api_error",
				"message": fmt.Sprintf("OpenAI API error: %s", string(errorBody)),
			},
		}
		errorJSON, _ := json.Marshal(errorResp)

		// Create a new response with the error
		resp.Body = io.NopCloser(bytes.NewReader(errorJSON))
		resp.Header.Set("Content-Type", "application/json")
		resp.Header.Del("Content-Encoding")
		resp.ContentLength = int64(len(errorJSON))

		return resp, nil
	}

	// Handle gzip-encoded responses
	var bodyReader io.ReadCloser = resp.Body
	if resp.Header.Get("Content-Encoding") == "gzip" {
		gzReader, err := gzip.NewReader(resp.Body)
		if err != nil {
			resp.Body.Close()
			return nil, fmt.Errorf("failed to create gzip reader: %w", err)
		}
		bodyReader = gzReader
		resp.Header.Del("Content-Encoding")
		resp.Header.Del("Content-Length")
	}

	// For streaming responses, we need to convert back to Anthropic format
	if anthropicReq.Stream {
		// Create a pipe to transform the response
		pr, pw := io.Pipe()

		// Start a goroutine to transform the stream
		go func() {
			defer pw.Close()
			defer bodyReader.Close()
			transformOpenAIStreamToAnthropic(bodyReader, pw)
		}()

		// Replace the response body with our transformed stream
		resp.Body = pr
	} else {
		// For non-streaming, read and convert the response
		respBody, err := io.ReadAll(bodyReader)
		bodyReader.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}

		// Convert OpenAI response back to Anthropic format
		transformedBody := transformOpenAIResponseToAnthropic(respBody)
		resp.Body = io.NopCloser(bytes.NewReader(transformedBody))
		resp.ContentLength = int64(len(transformedBody))
		resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(transformedBody)))
	}

	return resp, nil
}

func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{} {
	messages := []map[string]interface{}{}

	// Combine all system messages into a single system message for OpenAI
	if len(req.System) > 0 {
		systemContent := ""
		for i, sysMsg := range req.System {
			if i > 0 {
				systemContent += "\n\n"
			}
			systemContent += sysMsg.Text
		}
		messages = append(messages, map[string]interface{}{
			"role":    "system",
			"content": systemContent,
		})
	}

	// Add conversation messages
	for _, msg := range req.Messages {
		// Handle messages with raw content that may contain tool results
		if contentArray, ok := msg.Content.([]interface{}); ok {
			// Check if this message contains tool results
			hasToolResults := false
			for _, item := range contentArray {
				if block, ok := item.(map[string]interface{}); ok {
					if blockType, hasType := block["type"].(string); hasType && blockType == "tool_result" {
						hasToolResults = true
						break
					}
				}
			}

			if hasToolResults {
				textContent := ""

				for _, item := range contentArray {
					if block, ok := item.(map[string]interface{}); ok {
						if blockType, hasType := block["type"].(string); hasType {
							if blockType == "text" {
								if text, hasText := block["text"].(string); hasText {
									textContent += text + "\n"
								}
							} else if blockType == "tool_result" {
								// Extract tool ID
								toolID := ""
								if id, hasID := block["tool_use_id"].(string); hasID {
									toolID = id
								}

								// Handle different formats of tool result content
								resultContent := ""
								if content, hasContent := block["content"]; hasContent {
									if contentStr, ok := content.(string); ok {
										resultContent = contentStr
									} else if contentList, ok := content.([]interface{}); ok {
										// If content is a list of blocks, extract text from each
										for _, c := range contentList {
											if contentMap, ok := c.(map[string]interface{}); ok {
												if contentMap["type"] == "text" {
													if text, ok := contentMap["text"].(string); ok {
														resultContent += text + "\n"
													}
												} else if text, hasText := contentMap["text"]; hasText {
													// Handle any dict by trying to extract text
													resultContent += fmt.Sprintf("%v\n", text)
												} else {
													// Try to JSON serialize
													if jsonBytes, err := json.Marshal(contentMap); err == nil {
														resultContent += string(jsonBytes) + "\n"
													} else {
														resultContent += fmt.Sprintf("%v\n", contentMap)
													}
												}
											}
										}
									} else if contentDict, ok := content.(map[string]interface{}); ok {
										// Handle dictionary content
										if contentDict["type"] == "text" {
											if text, ok := contentDict["text"].(string); ok {
												resultContent = text
											}
										} else {
											// Try to JSON serialize
											if jsonBytes, err := json.Marshal(contentDict); err == nil {
												resultContent = string(jsonBytes)
											} else {
												resultContent = fmt.Sprintf("%v", contentDict)
											}
										}
									} else {
										// Handle any other type by converting to string
										if jsonBytes, err := json.Marshal(content); err == nil {
											resultContent = string(jsonBytes)
										} else {
											resultContent = fmt.Sprintf("%v", content)
										}
									}
								}

								// In OpenAI format, tool results come from the user (matching Python behavior)
								textContent += fmt.Sprintf("Tool result for %s:\n%s\n", toolID, resultContent)
							}
						}
					}
				}

				// Add as a single user message with all the content
				if textContent == "" {
					textContent = "..."
				}
				messages = append(messages, map[string]interface{}{
					"role":    msg.Role,
					"content": strings.TrimSpace(textContent),
				})
			} else {
				// Handle regular messages with content blocks
				content := ""

				for _, item := range contentArray {
					if block, ok := item.(map[string]interface{}); ok {
						if blockType, hasType := block["type"].(string); hasType && blockType == "text" {
							if text, hasText := block["text"].(string); hasText {
								if content != "" {
									content += "\n"
								}
								content += text
							}
						}
					}
				}

				// Ensure content is never empty
				if content == "" {
					content = "..."
				}

				messages = append(messages, map[string]interface{}{
					"role":    msg.Role,
					"content": content,
				})
			}
		} else {
			// Handle simple string content
			contentBlocks := msg.GetContentBlocks()
			content := ""

			// Concatenate all text blocks
			for _, block := range contentBlocks {
				if block.Type == "text" {
					if content != "" {
						content += "\n"
					}
					content += block.Text
				}
			}

			// Ensure content is never empty
			if content == "" {
				content = "..."
			}

			messages = append(messages, map[string]interface{}{
				"role":    msg.Role,
				"content": content,
			})
		}
	}

	// Check if max_tokens exceeds the model's limit and cap it if necessary
	maxTokensLimit := 16384 // Assuming this is the limit for the model
	if req.MaxTokens > maxTokensLimit {
		// Cap max_tokens to the model limit
		req.MaxTokens = maxTokensLimit
	}

	// All OpenAI models now use max_completion_tokens instead of the deprecated max_tokens
	openAIReq := map[string]interface{}{
		"model":                 req.Model,
		"messages":              messages,
		"stream":                req.Stream,
		"max_completion_tokens": req.MaxTokens,
	}

	// If streaming is enabled, request usage data to be included in the final chunk
	if req.Stream {
		openAIReq["stream_options"] = map[string]interface{}{
			"include_usage": true,
		}
	}

	// Check if this is an o-series model (they don't support temperature)
	isOSeriesModel := strings.HasPrefix(req.Model, "o1") || strings.HasPrefix(req.Model, "o3")

	// Only include temperature for non-o-series models
	if !isOSeriesModel {
		openAIReq["temperature"] = req.Temperature
	}

	// Convert Anthropic tools to OpenAI format
	if len(req.Tools) > 0 {
		tools := make([]map[string]interface{}, 0, len(req.Tools))
		for _, tool := range req.Tools {
			// Skip tools with empty names
			if tool.Name == "" {
				continue
			}

			// Build parameters with error checking
			parameters := make(map[string]interface{})
			parameters["type"] = tool.InputSchema.Type
			if parameters["type"] == "" {
				parameters["type"] = "object" // Default to object type
			}

			// Handle properties safely with array validation
			if tool.InputSchema.Properties != nil {
				// Fix array properties that are missing an items field
				fixedProperties := make(map[string]interface{})
				for propName, propValue := range tool.InputSchema.Properties {
					if prop, ok := propValue.(map[string]interface{}); ok {
						// Check if this is an array type missing items
						if propType, hasType := prop["type"]; hasType && propType == "array" {
							if _, hasItems := prop["items"]; !hasItems {
								// Add a default items definition for array properties missing it
								prop["items"] = map[string]interface{}{"type": "string"}
							}
						}
						fixedProperties[propName] = prop
					} else {
						// Keep non-map properties as-is
						fixedProperties[propName] = propValue
					}
				}
				parameters["properties"] = fixedProperties
			} else {
				parameters["properties"] = make(map[string]interface{})
			}

			// Handle required fields
			if len(tool.InputSchema.Required) > 0 {
				parameters["required"] = tool.InputSchema.Required
			}

			// Build function definition
			functionDef := map[string]interface{}{
				"name":       tool.Name,
				"parameters": parameters,
			}

			// Add description if present
			if tool.Description != "" {
				functionDef["description"] = tool.Description
			}

			openAITool := map[string]interface{}{
				"type":     "function",
				"function": functionDef,
			}
			tools = append(tools, openAITool)
		}
		openAIReq["tools"] = tools

		// Handle tool_choice if present
		if req.ToolChoice != nil {
			// Convert Anthropic tool_choice to OpenAI format
			if toolChoiceMap, ok := req.ToolChoice.(map[string]interface{}); ok {
				choiceType := toolChoiceMap["type"]
				switch choiceType {
				case "auto":
					openAIReq["tool_choice"] = "auto"
				case "any":
					openAIReq["tool_choice"] = "required"
				case "tool":
					// Specific tool choice
					if name, hasName := toolChoiceMap["name"].(string); hasName {
						openAIReq["tool_choice"] = map[string]interface{}{
							"type": "function",
							"function": map[string]interface{}{
								"name": name,
							},
						}
					}
				default:
					// Default to auto if we can't determine
					openAIReq["tool_choice"] = "auto"
				}
			}
		}
	}

	return openAIReq
}
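The tool_choice branch above is simple enough to check in isolation. A minimal sketch of that mapping; the `mapToolChoice` helper is ours for illustration, not a function in this repo:

```go
package main

import "fmt"

// mapToolChoice mirrors the switch in convertAnthropicToOpenAI:
// Anthropic "auto" -> "auto", "any" -> "required", "tool" -> an OpenAI
// function selector; anything unrecognized falls back to "auto".
func mapToolChoice(choice map[string]interface{}) interface{} {
	switch choice["type"] {
	case "auto":
		return "auto"
	case "any":
		return "required"
	case "tool":
		if name, ok := choice["name"].(string); ok {
			return map[string]interface{}{
				"type":     "function",
				"function": map[string]interface{}{"name": name},
			}
		}
	}
	return "auto"
}

func main() {
	fmt.Println(mapToolChoice(map[string]interface{}{"type": "any"}))
	fmt.Println(mapToolChoice(map[string]interface{}{"type": "tool", "name": "Read"}))
}
```

Note the semantic gap the mapping papers over: Anthropic's "any" (the model must call some tool) corresponds to OpenAI's "required", not "any".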

func getMapKeys(m map[string]interface{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
	// This is a simplified transformation
	// In production, you'd want to handle all fields properly
	var openAIResp map[string]interface{}
	if err := json.Unmarshal(respBody, &openAIResp); err != nil {
		return respBody // Return as-is if we can't parse
	}

	// Extract the assistant's message
	var contentBlocks []map[string]interface{}

	if choices, ok := openAIResp["choices"].([]interface{}); ok && len(choices) > 0 {
		if choice, ok := choices[0].(map[string]interface{}); ok {
			if msg, ok := choice["message"].(map[string]interface{}); ok {
				// Handle regular text content
				if content, ok := msg["content"].(string); ok && content != "" {
					contentBlocks = append(contentBlocks, map[string]interface{}{
						"type": "text",
						"text": content,
					})
				}

				// Handle tool calls
				if toolCalls, ok := msg["tool_calls"].([]interface{}); ok {
					// Since this proxy forwards to Claude/Anthropic API, we should always
					// use tool_use blocks so Claude can execute the tools properly
					// (regardless of which model generated the response)
					for _, tc := range toolCalls {
						if toolCall, ok := tc.(map[string]interface{}); ok {
							if function, ok := toolCall["function"].(map[string]interface{}); ok {
								// Convert OpenAI tool call to Anthropic tool_use format
								anthropicToolUse := map[string]interface{}{
									"type": "tool_use",
									"id":   toolCall["id"],
									"name": function["name"],
								}

								// Parse the arguments JSON string
								if argsStr, ok := function["arguments"].(string); ok {
									var args map[string]interface{}
									if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
										anthropicToolUse["input"] = args
									} else {
										// If parsing fails, wrap in a raw field (matching the Python implementation)
										anthropicToolUse["input"] = map[string]interface{}{"raw": argsStr}
									}
								} else if args, ok := function["arguments"].(map[string]interface{}); ok {
									// Already a map, use directly
									anthropicToolUse["input"] = args
								} else {
									// Fallback for any other type
									anthropicToolUse["input"] = map[string]interface{}{"raw": fmt.Sprintf("%v", function["arguments"])}
								}

								contentBlocks = append(contentBlocks, anthropicToolUse)
							}
						}
					}
				}
			}
		}
	}

	// If no content blocks were created, add a default empty text block
	if len(contentBlocks) == 0 {
		contentBlocks = []map[string]interface{}{
			{"type": "text", "text": ""},
		}
	}

	// Build Anthropic-style response
	anthropicResp := map[string]interface{}{
		"id":      openAIResp["id"],
		"type":    "message",
		"role":    "assistant",
		"content": contentBlocks,
		"model":   openAIResp["model"],
	}

	// Convert OpenAI usage format to Anthropic format
	if usage, ok := openAIResp["usage"].(map[string]interface{}); ok {
		anthropicUsage := map[string]interface{}{}

		// Map prompt_tokens to input_tokens
		if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
			anthropicUsage["input_tokens"] = int(promptTokens)
		}

		// Map completion_tokens to output_tokens
		if completionTokens, ok := usage["completion_tokens"].(float64); ok {
			anthropicUsage["output_tokens"] = int(completionTokens)
		}

		// Include total_tokens if present (though Anthropic format doesn't typically use it)
		if totalTokens, ok := usage["total_tokens"].(float64); ok {
			anthropicUsage["total_tokens"] = int(totalTokens)
		}

		anthropicResp["usage"] = anthropicUsage
	}

	result, _ := json.Marshal(anthropicResp)
	return result
}
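The usage conversion at the end of the function reduces to a small field rename. A standalone sketch (the `mapUsage` helper is ours, not part of the repo), relying on the fact that `encoding/json` decodes JSON numbers into `interface{}` as `float64`:

```go
package main

import "fmt"

// mapUsage mirrors the conversion in transformOpenAIResponseToAnthropic:
// OpenAI reports prompt_tokens/completion_tokens, while Anthropic's message
// format uses input_tokens/output_tokens.
func mapUsage(openAIUsage map[string]interface{}) map[string]int {
	anthropic := map[string]int{}
	if v, ok := openAIUsage["prompt_tokens"].(float64); ok {
		anthropic["input_tokens"] = int(v)
	}
	if v, ok := openAIUsage["completion_tokens"].(float64); ok {
		anthropic["output_tokens"] = int(v)
	}
	return anthropic
}

func main() {
	usage := map[string]interface{}{
		"prompt_tokens":     128.0,
		"completion_tokens": 42.0,
		"total_tokens":      170.0,
	}
	fmt.Println(mapUsage(usage))
}
```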

func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
	defer openAIStream.Close()

	scanner := bufio.NewScanner(openAIStream)
	var messageStarted bool
	var contentStarted bool

	for scanner.Scan() {
		line := scanner.Text()

		// Skip empty lines
		if line == "" {
			continue
		}

		// Handle SSE data lines
		if strings.HasPrefix(line, "data: ") {
			data := strings.TrimPrefix(line, "data: ")

			// Handle end of stream
			if data == "[DONE]" {
				// Send Anthropic-style completion
				if contentStarted {
					fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
				}
				if messageStarted {
					fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n")
					fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n")
				}
				break
			}

			// Parse OpenAI response
			var openAIChunk map[string]interface{}
			if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil {
				continue
			}

			// Check for usage data BEFORE processing choices
			// According to OpenAI docs, usage is sent in the final chunk with an empty choices array
			if usage, hasUsage := openAIChunk["usage"].(map[string]interface{}); hasUsage {
				// Convert OpenAI usage to Anthropic format
				anthropicUsage := map[string]interface{}{}

				// Handle both float64 and int types
				if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
					anthropicUsage["input_tokens"] = int(promptTokens)
				} else if promptTokens, ok := usage["prompt_tokens"].(int); ok {
					anthropicUsage["input_tokens"] = promptTokens
				}

				if completionTokens, ok := usage["completion_tokens"].(float64); ok {
					anthropicUsage["output_tokens"] = int(completionTokens)
				} else if completionTokens, ok := usage["completion_tokens"].(int); ok {
					anthropicUsage["output_tokens"] = completionTokens
				}

				if len(anthropicUsage) > 0 {
					// Send usage data in a message_delta event
					usageDelta := map[string]interface{}{
						"type":  "message_delta",
						"delta": map[string]interface{}{},
						"usage": anthropicUsage,
					}
					usageJSON, _ := json.Marshal(usageDelta)
					fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON)
				}
			}

			// Extract choices array
			choices, ok := openAIChunk["choices"].([]interface{})
			if !ok || len(choices) == 0 {
				// Skip further processing if no choices, but we already handled usage above
				continue
			}

			choice, ok := choices[0].(map[string]interface{})
			if !ok {
				continue
			}

			delta, ok := choice["delta"].(map[string]interface{})
			if !ok {
				continue
			}

			// Handle first chunk - send message_start
			if !messageStarted {
				messageStarted = true
				messageStart := map[string]interface{}{
					"type": "message_start",
					"message": map[string]interface{}{
						"id":            openAIChunk["id"],
						"type":          "message",
						"role":          "assistant",
						"model":         openAIChunk["model"],
						"content":       []interface{}{},
						"stop_reason":   nil,
						"stop_sequence": nil,
						"usage": map[string]interface{}{
							// Empty usage - will be updated in final chunk
						},
					},
				}
				startJSON, _ := json.Marshal(messageStart)
				fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON)
			}

			// Handle content
			if content, hasContent := delta["content"].(string); hasContent && content != "" {
				if !contentStarted {
					contentStarted = true
					// Send content_block_start
					blockStart := map[string]interface{}{
						"type":  "content_block_start",
						"index": 0,
						"content_block": map[string]interface{}{
							"type": "text",
							"text": "",
						},
					}
					blockStartJSON, _ := json.Marshal(blockStart)
					fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON)
				}

				// Send content_block_delta
				contentDelta := map[string]interface{}{
					"type":  "content_block_delta",
					"index": 0,
					"delta": map[string]interface{}{
						"type": "text_delta",
						"text": content,
					},
				}
				deltaJSON, _ := json.Marshal(contentDelta)
				fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON)
			}
		}
	}
}
15 proxy/internal/provider/provider.go Normal file
@@ -0,0 +1,15 @@
package provider

import (
	"context"
	"net/http"
)

// Provider is the interface that all LLM providers must implement
type Provider interface {
	// Name returns the provider name (e.g., "anthropic", "openai")
	Name() string

	// ForwardRequest forwards a request to the provider's API
	ForwardRequest(ctx context.Context, req *http.Request) (*http.Response, error)
}
@@ -30,29 +30,29 @@ func NewConversationService() ConversationService {

 // ConversationMessage represents a single message in a Claude conversation
 type ConversationMessage struct {
 	ParentUUID  *string         `json:"parentUuid"`
 	IsSidechain bool            `json:"isSidechain"`
 	UserType    string          `json:"userType"`
 	CWD         string          `json:"cwd"`
 	SessionID   string          `json:"sessionId"`
 	Version     string          `json:"version"`
 	Type        string          `json:"type"`
 	Message     json.RawMessage `json:"message"`
 	UUID        string          `json:"uuid"`
 	Timestamp   string          `json:"timestamp"`
 	ParsedTime  time.Time       `json:"-"`
 }

 // Conversation represents a complete conversation session
 type Conversation struct {
 	SessionID    string                 `json:"sessionId"`
 	ProjectPath  string                 `json:"projectPath"`
 	ProjectName  string                 `json:"projectName"`
 	Messages     []*ConversationMessage `json:"messages"`
 	StartTime    time.Time              `json:"startTime"`
 	EndTime      time.Time              `json:"endTime"`
 	MessageCount int                    `json:"messageCount"`
 	FileModTime  time.Time              `json:"-"` // Used for sorting, not exported
 }

 // GetConversations returns all conversations organized by project

@@ -74,7 +74,7 @@ func (cs *conversationService) GetConversations() (map[string][]*Conversation, e
 		// Get the project path relative to claudeProjectsPath
 		projectDir := filepath.Dir(path)
 		projectRelPath, _ := filepath.Rel(cs.claudeProjectsPath, projectDir)

 		// Skip files directly in the projects directory
 		if projectRelPath == "." || projectRelPath == "" {
 			return nil

@@ -99,18 +99,7 @@ func (cs *conversationService) GetConversations() (map[string][]*Conversation, e
 		return nil, fmt.Errorf("failed to walk claude projects: %w", err)
 	}

-	// Log any parsing errors encountered
-	if len(parseErrors) > 0 {
-		fmt.Printf("Warning: Encountered %d parsing errors while loading conversations:\n", len(parseErrors))
-		for i, err := range parseErrors {
-			if i < 5 { // Only show first 5 errors to avoid spam
-				fmt.Printf("  - %s\n", err)
-			}
-		}
-		if len(parseErrors) > 5 {
-			fmt.Printf("  ... and %d more errors\n", len(parseErrors)-5)
-		}
-	}
+	// Some parsing errors may have occurred but were handled

 	// Sort conversations within each project by file modification time (newest first)
 	for project := range conversations {

@@ -125,7 +114,7 @@ func (cs *conversationService) GetConversations() (map[string][]*Conversation, e
 // GetConversation returns a specific conversation by project and session ID
 func (cs *conversationService) GetConversation(projectPath, sessionID string) (*Conversation, error) {
 	filePath := filepath.Join(cs.claudeProjectsPath, projectPath, sessionID+".jsonl")

 	conv, err := cs.parseConversationFile(filePath, projectPath)
 	if err != nil {
 		return nil, fmt.Errorf("failed to parse conversation: %w", err)

@@ -175,7 +164,7 @@ func (cs *conversationService) parseConversationFile(filePath, projectPath strin
 	if err != nil {
 		return nil, fmt.Errorf("failed to stat file: %w", err)
 	}

 	file, err := os.Open(filePath)
 	if err != nil {
 		return nil, fmt.Errorf("failed to open file: %w", err)

@@ -185,9 +174,9 @@ func (cs *conversationService) parseConversationFile(filePath, projectPath strin
 	var messages []*ConversationMessage
 	var parseErrors int
 	lineNum := 0

 	scanner := bufio.NewScanner(file)

 	// Increase buffer size for large messages
 	const maxScanTokenSize = 10 * 1024 * 1024 // 10MB
 	buf := make([]byte, maxScanTokenSize)

@@ -196,18 +185,18 @@ func (cs *conversationService) parseConversationFile(filePath, projectPath strin
 	for scanner.Scan() {
 		lineNum++
 		line := scanner.Bytes()

 		// Skip empty lines
 		if len(line) == 0 {
 			continue
 		}

 		var msg ConversationMessage
 		if err := json.Unmarshal(line, &msg); err != nil {
 			parseErrors++
 			// Log only first few errors to avoid spam
 			if parseErrors <= 3 {
-				fmt.Printf("Warning: Failed to parse line %d in %s: %v\n", lineNum, filePath, err)
+				// Skip malformed line
 			}
 			continue
 		}

@@ -219,7 +208,7 @@ func (cs *conversationService) parseConversationFile(filePath, projectPath strin
 			// Try alternative timestamp formats
 			parsedTime, err = time.Parse(time.RFC3339Nano, msg.Timestamp)
 			if err != nil {
-				fmt.Printf("Warning: Failed to parse timestamp '%s' in %s\n", msg.Timestamp, filePath)
+				// Skip message with invalid timestamp
 			}
 		}
 		msg.ParsedTime = parsedTime

@@ -233,7 +222,7 @@ func (cs *conversationService) parseConversationFile(filePath, projectPath strin
 	}

 	if parseErrors > 3 {
-		fmt.Printf("Warning: Total of %d lines failed to parse in %s\n", parseErrors, filePath)
+		// Some lines failed to parse but were skipped
 	}

 	// Return empty conversation if no messages (caller can decide what to do)

@@ -303,4 +292,4 @@ func (cs *conversationService) parseConversationFile(filePath, projectPath strin
 		MessageCount: len(messages),
 		FileModTime:  fileInfo.ModTime(),
 	}, nil
 }
269 proxy/internal/service/model_router.go Normal file
@@ -0,0 +1,269 @@
package service

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"log"
	"os"
	"strings"

	"github.com/seifghazi/claude-code-monitor/internal/config"
	"github.com/seifghazi/claude-code-monitor/internal/model"
	"github.com/seifghazi/claude-code-monitor/internal/provider"
)

// RoutingDecision contains the result of routing analysis
type RoutingDecision struct {
	Provider      provider.Provider
	OriginalModel string
	TargetModel   string
}

type ModelRouter struct {
	config             *config.Config
	providers          map[string]provider.Provider
	subagentMappings   map[string]string             // agentName -> targetModel
	customAgentPrompts map[string]SubagentDefinition // promptHash -> definition
	modelProviderMap   map[string]string             // model -> provider mapping
	logger             *log.Logger
}

type SubagentDefinition struct {
	Name           string
	TargetModel    string
	TargetProvider string
	FullPrompt     string // Stored for debugging
}

func NewModelRouter(cfg *config.Config, providers map[string]provider.Provider, logger *log.Logger) *ModelRouter {
	router := &ModelRouter{
		config:             cfg,
		providers:          providers,
		subagentMappings:   cfg.Subagents.Mappings,
		customAgentPrompts: make(map[string]SubagentDefinition),
		modelProviderMap:   initializeModelProviderMap(),
		logger:             logger,
	}

	// Only load custom agents if subagents are enabled
	if cfg.Subagents.Enable {
		router.loadCustomAgents()
	} else {
		logger.Println("")
		logger.Println("ℹ️ Subagent routing is disabled")
		logger.Println("  Enable it in config.yaml to route Claude Code agents to different LLM providers")
		logger.Println("")
	}
	return router
}

// initializeModelProviderMap creates a mapping of model names to their providers
func initializeModelProviderMap() map[string]string {
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
|
||||||
|
// OpenAI models
|
||||||
|
openaiModels := []string{
|
||||||
|
// GPT-4.1 family
|
||||||
|
"gpt-4.1", "gpt-4.1-2025-04-14",
|
||||||
|
"gpt-4.1-mini", "gpt-4.1-mini-2025-04-14",
|
||||||
|
"gpt-4.1-nano", "gpt-4.1-nano-2025-04-14",
|
||||||
|
|
||||||
|
// GPT-4.5
|
||||||
|
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
|
||||||
|
|
||||||
|
// GPT-4o variants
|
||||||
|
"gpt-4o", "gpt-4o-2024-08-06",
|
||||||
|
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||||
|
|
||||||
|
// GPT-3.5 variants
|
||||||
|
"gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-instruct",
|
||||||
|
|
||||||
|
// O1 series
|
||||||
|
"o1", "o1-2024-12-17",
|
||||||
|
"o1-pro", "o1-pro-2025-03-19",
|
||||||
|
"o1-mini", "o1-mini-2024-09-12",
|
||||||
|
|
||||||
|
// O3 series
|
||||||
|
"o3-pro", "o3-pro-2025-06-10",
|
||||||
|
"o3", "o3-2025-04-16",
|
||||||
|
"o3-mini", "o3-mini-2025-01-31",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range openaiModels {
|
||||||
|
modelMap[model] = "openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anthropic models
|
||||||
|
anthropicModels := []string{
|
||||||
|
"claude-opus-4-20250514",
|
||||||
|
"claude-sonnet-4-20250514",
|
||||||
|
"claude-3-7-sonnet-20250219",
|
||||||
|
"claude-3-5-haiku-20241022",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range anthropicModels {
|
||||||
|
modelMap[model] = "anthropic"
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractStaticPrompt extracts the portion before "Notes:" if it exists
|
||||||
|
func (r *ModelRouter) extractStaticPrompt(systemPrompt string) string {
|
||||||
|
// Find the "Notes:" section
|
||||||
|
notesIndex := strings.Index(systemPrompt, "\nNotes:")
|
||||||
|
if notesIndex == -1 {
|
||||||
|
notesIndex = strings.Index(systemPrompt, "\n\nNotes:")
|
||||||
|
}
|
||||||
|
|
||||||
|
if notesIndex != -1 {
|
||||||
|
// Return only the part before "Notes:"
|
||||||
|
return strings.TrimSpace(systemPrompt[:notesIndex])
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no "Notes:" section, return the whole prompt
|
||||||
|
return strings.TrimSpace(systemPrompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRouter) loadCustomAgents() {
|
||||||
|
for agentName, targetModel := range r.subagentMappings {
|
||||||
|
// Try loading from project level first, then user level
|
||||||
|
paths := []string{
|
||||||
|
fmt.Sprintf(".claude/agents/%s.md", agentName),
|
||||||
|
fmt.Sprintf("%s/.claude/agents/%s.md", os.Getenv("HOME"), agentName),
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, path := range paths {
|
||||||
|
content, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse agent file: metadata\n---\nsystem prompt
|
||||||
|
parts := strings.Split(string(content), "\n---\n")
|
||||||
|
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
systemPrompt := strings.TrimSpace(parts[1])
|
||||||
|
|
||||||
|
// Extract only the static part (before "Notes:" if it exists)
|
||||||
|
staticPrompt := r.extractStaticPrompt(systemPrompt)
|
||||||
|
hash := r.hashString(staticPrompt)
|
||||||
|
|
||||||
|
// Determine provider for the target model
|
||||||
|
providerName := r.getProviderNameForModel(targetModel)
|
||||||
|
|
||||||
|
r.customAgentPrompts[hash] = SubagentDefinition{
|
||||||
|
Name: agentName,
|
||||||
|
TargetModel: targetModel,
|
||||||
|
TargetProvider: providerName,
|
||||||
|
FullPrompt: staticPrompt,
|
||||||
|
}
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log warning if subagent is mapped but definition not found
|
||||||
|
if !found {
|
||||||
|
r.logger.Printf("⚠️ Subagent '%s' is mapped to '%s' but definition file not found in:\n", agentName, targetModel)
|
||||||
|
for _, path := range paths {
|
||||||
|
r.logger.Printf(" - %s\n", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pretty print loaded subagents
|
||||||
|
if len(r.customAgentPrompts) > 0 {
|
||||||
|
r.logger.Println("")
|
||||||
|
r.logger.Println("🤖 Subagent Model Mappings:")
|
||||||
|
r.logger.Println("──────────────────────────────────────")
|
||||||
|
|
||||||
|
for _, def := range r.customAgentPrompts {
|
||||||
|
r.logger.Printf(" \033[36m%s\033[0m → \033[32m%s\033[0m",
|
||||||
|
def.Name, def.TargetModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Println("──────────────────────────────────────")
|
||||||
|
r.logger.Println("")
|
||||||
|
}
|
||||||
|
}
|
||||||
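`loadCustomAgents` above splits each `.claude/agents/<name>.md` file on `\n---\n` into front-matter metadata and a system prompt. A small sketch of that parsing under the same layout assumption (`splitAgentFile` is a hypothetical helper, not part of the PR):

```go
package main

import (
	"fmt"
	"strings"
)

// splitAgentFile parses the agent-file layout the router expects:
// front-matter metadata, then a "---" separator line, then the system prompt.
func splitAgentFile(content string) (meta, prompt string, ok bool) {
	parts := strings.SplitN(content, "\n---\n", 2)
	if len(parts) < 2 {
		return "", "", false // no separator: not a usable agent file
	}
	return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]), true
}

func main() {
	content := "name: code-reviewer\nmodel: gpt-4o\n---\nYou are a code reviewer."
	meta, prompt, ok := splitAgentFile(content)
	fmt.Println(ok, meta, "|", prompt)
}
```

Note the PR's own code uses `strings.Split` and takes `parts[1]`, so a prompt that itself contains `\n---\n` would be truncated; `SplitN` with a limit of 2 avoids that edge case.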
+// DetermineRoute analyzes the request and returns routing information without modifying the request
+func (r *ModelRouter) DetermineRoute(req *model.AnthropicRequest) (*RoutingDecision, error) {
+	decision := &RoutingDecision{
+		OriginalModel: req.Model,
+		TargetModel:   req.Model, // default to original
+	}
+
+	// Check if subagents are enabled
+	if !r.config.Subagents.Enable {
+		// Subagents disabled, use default provider
+		providerName := r.getProviderNameForModel(decision.TargetModel)
+		decision.Provider = r.providers[providerName]
+		if decision.Provider == nil {
+			return nil, fmt.Errorf("no provider found for model %s", decision.TargetModel)
+		}
+		return decision, nil
+	}
+
+	// Claude Code pattern: Check if we have exactly 2 system messages
+	if len(req.System) == 2 {
+		// First should be "You are Claude Code..."
+		if strings.Contains(req.System[0].Text, "You are Claude Code") {
+			// Second message could be either:
+			// 1. A regular Claude Code prompt (no Notes: section)
+			// 2. A subagent prompt (may have Notes: section)
+			fullPrompt := req.System[1].Text
+
+			// Extract static portion (before "Notes:" if it exists)
+			staticPrompt := r.extractStaticPrompt(fullPrompt)
+			promptHash := r.hashString(staticPrompt)
+
+			// Check if this matches a known custom agent
+			if definition, exists := r.customAgentPrompts[promptHash]; exists {
+				r.logger.Printf("\033[36m%s\033[0m → \033[32m%s\033[0m",
+					req.Model, definition.TargetModel)
+
+				decision.TargetModel = definition.TargetModel
+				decision.Provider = r.providers[definition.TargetProvider]
+				if decision.Provider == nil {
+					return nil, fmt.Errorf("provider %s not found for model %s",
+						definition.TargetProvider, definition.TargetModel)
+				}
+
+				return decision, nil
+			}
+		}
+	}
+
+	// Default: use the original model and its provider
+	providerName := r.getProviderNameForModel(decision.TargetModel)
+	decision.Provider = r.providers[providerName]
+	if decision.Provider == nil {
+		return nil, fmt.Errorf("no provider found for model %s", decision.TargetModel)
+	}
+
+	return decision, nil
+}
+
+func (r *ModelRouter) hashString(s string) string {
+	h := sha256.New()
+	h.Write([]byte(s))
+	fullHash := hex.EncodeToString(h.Sum(nil))
+	shortHash := fullHash[:16]
+	return shortHash
+}
+
+func (r *ModelRouter) getProviderNameForModel(model string) string {
+	if provider, exists := r.modelProviderMap[model]; exists {
+		return provider
+	}
+
+	// Default to anthropic
+	r.logger.Printf("⚠️ Model '%s' doesn't match any known patterns, defaulting to anthropic", model)
+	return "anthropic"
+}
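The router identifies a subagent by hashing the static portion of the request's second system prompt and looking the hash up in a table built at startup. A self-contained sketch of that lookup; `extractStatic` and `shortHash` are simplified stand-ins for the methods above (a single search for `"\nNotes:"` suffices, since any occurrence of `"\n\nNotes:"` also contains it):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// extractStatic keeps only the portion before a "\nNotes:" marker,
// so per-request dynamic notes don't change the hash.
func extractStatic(prompt string) string {
	if i := strings.Index(prompt, "\nNotes:"); i != -1 {
		return strings.TrimSpace(prompt[:i])
	}
	return strings.TrimSpace(prompt)
}

// shortHash is the 16-hex-char sha256 prefix used as the lookup key.
func shortHash(s string) string {
	sum := sha256.Sum256([]byte(s))
	return hex.EncodeToString(sum[:])[:16]
}

func main() {
	table := map[string]string{ // promptHash -> target model
		shortHash("You are a code reviewer."): "gpt-4o",
	}
	incoming := "You are a code reviewer.\n\nNotes:\n- repo: claude-code-monitor"
	if target, ok := table[shortHash(extractStatic(incoming))]; ok {
		fmt.Println("route to", target)
	}
}
```

Hashing the static prefix is what makes the match robust: the dynamic `Notes:` section varies per request, but the agent's fixed system prompt does not.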
137	proxy/internal/service/model_router_test.go	Normal file
@@ -0,0 +1,137 @@
+package service
+
+import (
+	"log"
+	"os"
+	"testing"
+
+	"github.com/seifghazi/claude-code-monitor/internal/config"
+	"github.com/seifghazi/claude-code-monitor/internal/model"
+	"github.com/seifghazi/claude-code-monitor/internal/provider"
+)
+
+func TestModelRouter_EdgeCases(t *testing.T) {
+	// Setup
+	cfg := &config.Config{
+		Subagents: config.SubagentsConfig{
+			Mappings: map[string]string{
+				"streaming-systems-engineer": "gpt-4o",
+			},
+		},
+	}
+
+	providers := make(map[string]provider.Provider)
+	providers["anthropic"] = nil
+	providers["openai"] = nil
+
+	logger := log.New(os.Stdout, "test: ", log.LstdFlags)
+	router := NewModelRouter(cfg, providers, logger)
+
+	tests := []struct {
+		name          string
+		request       *model.AnthropicRequest
+		expectedRoute string
+		expectedModel string
+		description   string
+	}{
+		{
+			name: "Regular Claude Code request (no Notes section)",
+			request: &model.AnthropicRequest{
+				Model: "claude-3-opus-20240229",
+				System: []model.AnthropicSystemMessage{
+					{Text: "You are Claude Code, Anthropic's official CLI for Claude."},
+					{Text: "You are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user."},
+				},
+			},
+			expectedRoute: "anthropic",
+			expectedModel: "claude-3-opus-20240229",
+			description:   "Regular Claude Code requests should use original model",
+		},
+		{
+			name: "Non-Claude Code request",
+			request: &model.AnthropicRequest{
+				Model: "claude-3-opus-20240229",
+				System: []model.AnthropicSystemMessage{
+					{Text: "You are a helpful assistant."},
+				},
+			},
+			expectedRoute: "anthropic",
+			expectedModel: "claude-3-opus-20240229",
+			description:   "Non-Claude Code requests should use original model",
+		},
+		{
+			name: "Single system message",
+			request: &model.AnthropicRequest{
+				Model:  "claude-3-opus-20240229",
+				System: []model.AnthropicSystemMessage{},
+			},
+			expectedRoute: "anthropic",
+			expectedModel: "claude-3-opus-20240229",
+			description:   "Requests with no system messages should use original model",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if len(tt.request.System) == 2 {
+				// Test extract static prompt for second message
+				fullPrompt := tt.request.System[1].Text
+				staticPrompt := router.extractStaticPrompt(fullPrompt)
+
+				// Verify no "Notes:" in static prompt
+				if contains(staticPrompt, "Notes:") {
+					t.Errorf("Static prompt should not contain 'Notes:' section")
+				}
+			}
+
+			// Log for manual verification
+			t.Logf("Test case: %s", tt.description)
+		})
+	}
+}
+
+func TestModelRouter_ExtractStaticPrompt(t *testing.T) {
+	router := &ModelRouter{}
+
+	tests := []struct {
+		name     string
+		input    string
+		expected string
+	}{
+		{
+			name:     "Prompt with Notes section",
+			input:    "You are an expert engineer.\n\nNotes:\n- Some dynamic content\n- More notes",
+			expected: "You are an expert engineer.",
+		},
+		{
+			name:     "Prompt without Notes section",
+			input:    "You are an expert engineer.\nNo notes here.",
+			expected: "You are an expert engineer.\nNo notes here.",
+		},
+		{
+			name:     "Prompt with double newline before Notes",
+			input:    "You are an expert.\n\nNotes:\nDynamic content",
+			expected: "You are an expert.",
+		},
+		{
+			name:     "Empty prompt",
+			input:    "",
+			expected: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := router.extractStaticPrompt(tt.input)
+			if result != tt.expected {
+				t.Errorf("extractStaticPrompt() = %q, want %q", result, tt.expected)
+			}
+		})
+	}
+}
+
+func contains(s, substr string) bool {
+	return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
+		(len(s) > 0 && len(substr) > 0 && s[0:len(substr)] == substr) ||
+		(len(s) > len(substr) && contains(s[1:], substr)))
+}
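The test file's recursive `contains` helper reimplements substring search that the standard library already provides. A sketch of the idiomatic equivalent (`containsSubstr` is a hypothetical wrapper, shown only for comparison):

```go
package main

import (
	"fmt"
	"strings"
)

// containsSubstr gives the same answer as the hand-rolled recursive
// helper in the test file; strings.Contains is the idiomatic choice
// and runs in linear time without recursion.
func containsSubstr(s, substr string) bool {
	return strings.Contains(s, substr)
}

func main() {
	static := "You are an expert engineer."
	fmt.Println(containsSubstr(static, "Notes:")) // the test expects no Notes: in a static prompt
}
```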
@@ -4,7 +4,6 @@ import (
 	"database/sql"
 	"encoding/json"
 	"fmt"
-	"log"
 	"strings"

 	_ "github.com/mattn/go-sqlite3"
@@ -50,6 +49,8 @@ func (s *sqliteStorageService) createTables() error {
 		prompt_grade TEXT,
 		response TEXT,
 		model TEXT,
+		original_model TEXT,
+		routed_model TEXT,
 		created_at DATETIME DEFAULT CURRENT_TIMESTAMP
 	);

@@ -74,8 +75,8 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
 	}

 	query := `
-		INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model)
-		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+		INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 	`

 	_, err = s.db.Exec(query,
@@ -88,6 +89,8 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
 		request.UserAgent,
 		request.ContentType,
 		request.Model,
+		request.OriginalModel,
+		request.RoutedModel,
 	)

 	if err != nil {
@@ -108,7 +111,7 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
 	// Get paginated results
 	offset := (page - 1) * limit
 	query := `
-		SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
+		SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
 		FROM requests
 		ORDER BY timestamp DESC
 		LIMIT ? OFFSET ?
@@ -138,21 +141,23 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
 			&req.ContentType,
 			&promptGradeJSON,
 			&responseJSON,
+			&req.OriginalModel,
+			&req.RoutedModel,
 		)
 		if err != nil {
-			log.Printf("Error scanning row: %v", err)
+			// Error scanning row - skip
 			continue
 		}

 		// Unmarshal JSON fields
 		if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
-			log.Printf("Error unmarshaling headers: %v", err)
+			// Error unmarshaling headers
 			continue
 		}

 		var body interface{}
 		if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
-			log.Printf("Error unmarshaling body: %v", err)
+			// Error unmarshaling body
 			continue
 		}
 		req.Body = body
@@ -228,7 +233,7 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error {

 func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
 	query := `
-		SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
+		SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
 		FROM requests
 		WHERE id LIKE ?
 		ORDER BY timestamp DESC
@@ -251,6 +256,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
 		&req.ContentType,
 		&promptGradeJSON,
 		&responseJSON,
+		&req.OriginalModel,
+		&req.RoutedModel,
 	)

 	if err == sql.ErrNoRows {
@@ -294,7 +301,7 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig {

 func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
 	query := `
-		SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
+		SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
 		FROM requests
 	`
 	args := []interface{}{}
@@ -302,7 +309,7 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
 	if modelFilter != "" && modelFilter != "all" {
 		query += " WHERE LOWER(model) LIKE ?"
 		args = append(args, "%"+strings.ToLower(modelFilter)+"%")
-		log.Printf("🔍 SQL Query with filter: %s, args: %v", query, args)
 	}

 	query += " ORDER BY timestamp DESC"
@@ -331,23 +338,23 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
 			&req.ContentType,
 			&promptGradeJSON,
 			&responseJSON,
+			&req.OriginalModel,
+			&req.RoutedModel,
 		)
 		if err != nil {
-			log.Printf("Error scanning row: %v", err)
+			// Error scanning row - skip
 			continue
 		}

-		log.Printf("🔍 Scanned request - ID: %s, Model: %s", req.RequestID, req.Model)
-
 		// Unmarshal JSON fields
 		if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
-			log.Printf("Error unmarshaling headers: %v", err)
+			// Error unmarshaling headers
 			continue
 		}

 		var body interface{}
 		if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
-			log.Printf("Error unmarshaling body: %v", err)
+			// Error unmarshaling body
 			continue
 		}
 		req.Body = body
BIN	proxy/proxy	Executable file
Binary file not shown.
@@ -108,18 +108,6 @@ export function ConversationThread({ conversation }: ConversationThreadProps) {

   const messages = analyzeConversationFlow();

-  // Debug logging to identify assistant response issues
-  console.log('Conversation Debug:', {
-    messageCount: conversation.messageCount,
-    totalMessages: messages.length,
-    messages: messages.map(m => ({
-      role: m.role,
-      contentPreview: JSON.stringify(m.content)?.substring(0, 50),
-      turn: m.turnNumber,
-      ts: m.timestamp,
-    })),
-  });
-
   if (messages.length === 0) {
     return (
       <div className="text-center py-12">
@@ -29,6 +29,8 @@ interface Request {
   method: string;
   endpoint: string;
   headers: Record<string, string[]>;
+  originalModel?: string;
+  routedModel?: string;
   body?: {
     model?: string;
     messages?: Array<{
@@ -80,7 +82,7 @@ interface RequestDetailContentProps {
 export default function RequestDetailContent({ request, onGrade }: RequestDetailContentProps) {
   const [expandedSections, setExpandedSections] = useState<Record<string, boolean>>({
     overview: true,
-    conversation: true
+    // conversation: true
   });
   const [copied, setCopied] = useState<Record<string, boolean>>({});

@@ -150,7 +152,7 @@ export default function RequestDetailContent({ request, onGrade }: RequestDetail
             <div className="flex items-center space-x-3">
               <span className="text-gray-500 font-medium min-w-[80px]">Endpoint:</span>
               <code className="text-blue-600 bg-blue-50 px-2 py-1 rounded font-mono text-xs border border-blue-200">
-                {request.endpoint}
+                {request.routedModel && request.routedModel.startsWith('gpt-') ? '/v1/chat/completions' : request.endpoint}
               </code>
             </div>
           </div>
@@ -329,12 +331,49 @@ export default function RequestDetailContent({ request, onGrade }: RequestDetail
           </div>
         </div>
         {expandedSections.model && (
-          <div className="p-6">
-            <div className="grid grid-cols-2 gap-4">
-              <div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
-                <div className="text-xs text-gray-500 mb-1">Model</div>
-                <div className="text-sm font-medium text-gray-900">{request.body.model || 'N/A'}</div>
+          <div className="p-6 space-y-4">
+            {/* Model Routing Information */}
+            {request.routedModel && request.routedModel !== request.originalModel && (
+              <div className="bg-gradient-to-r from-purple-50 to-blue-50 border border-purple-200 rounded-xl p-4">
+                <div className="flex items-center space-x-4">
+                  <div className="flex-1">
+                    <div className="flex items-center space-x-2 mb-2">
+                      <span className="text-sm font-semibold text-purple-700">Requested Model</span>
+                      <code className="text-xs bg-white px-2 py-1 rounded font-mono border border-purple-200">
+                        {request.originalModel || request.body.model}
+                      </code>
+                    </div>
+                    <div className="flex items-center space-x-3">
+                      <div className="flex items-center space-x-2">
+                        <ArrowLeftRight className="w-4 h-4 text-purple-600" />
+                        <span className="text-xs text-purple-600 font-medium">Routed to</span>
+                      </div>
+                      <code className="text-sm bg-white px-3 py-1.5 rounded font-mono font-semibold border border-blue-200 text-blue-700">
+                        {request.routedModel}
+                      </code>
+                      <span className="text-xs bg-blue-100 text-blue-700 px-2 py-1 rounded-full border border-blue-200">
+                        {request.routedModel.startsWith('gpt-') || request.routedModel.startsWith('o') ? 'OpenAI' : 'Anthropic'}
+                      </span>
+                    </div>
+                  </div>
+                  <div className="text-right">
+                    <div className="text-xs text-gray-500 mb-1">Target Endpoint</div>
+                    <code className="text-xs bg-white px-2 py-1 rounded font-mono border border-gray-200">
+                      {request.routedModel.startsWith('gpt-') ? '/v1/chat/completions' : '/v1/messages'}
+                    </code>
+                  </div>
+                </div>
               </div>
+            )}
+
+            {/* Model Parameters */}
+            <div className="grid grid-cols-2 gap-4">
+              {!request.routedModel || request.routedModel === request.originalModel ? (
+                <div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
+                  <div className="text-xs text-gray-500 mb-1">Model</div>
+                  <div className="text-sm font-medium text-gray-900">{request.originalModel || request.body.model || 'N/A'}</div>
+                </div>
+              ) : null}
               <div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
                 <div className="text-xs text-gray-500 mb-1">Max Tokens</div>
                 <div className="text-sm font-medium text-gray-900">
@@ -619,6 +658,57 @@ function ResponseDetails({ response }: { response: NonNullable<Request['response
           </div>
         </div>

+        {/* Token Usage */}
+        {response.body?.usage && (
+          <div className="grid grid-cols-2 lg:grid-cols-4 gap-4">
+            <div className="bg-indigo-50 border border-indigo-200 rounded-lg p-4">
+              <div className="flex items-center space-x-2 mb-2">
+                <Brain className="w-4 h-4 text-indigo-600" />
+                <span className="text-xs font-medium text-indigo-700">Input Tokens</span>
+              </div>
+              <div className="text-lg font-bold text-indigo-700">
+                {response.body.usage.input_tokens?.toLocaleString() || '0'}
+              </div>
+              <div className="text-xs text-indigo-700 opacity-75">Prompt</div>
+            </div>
+
+            <div className="bg-emerald-50 border border-emerald-200 rounded-lg p-4">
+              <div className="flex items-center space-x-2 mb-2">
+                <MessageCircle className="w-4 h-4 text-emerald-600" />
+                <span className="text-xs font-medium text-emerald-700">Output Tokens</span>
+              </div>
+              <div className="text-lg font-bold text-emerald-700">
+                {response.body.usage.output_tokens?.toLocaleString() || '0'}
+              </div>
+              <div className="text-xs text-emerald-700 opacity-75">Response</div>
+            </div>
+
+            <div className="bg-amber-50 border border-amber-200 rounded-lg p-4">
+              <div className="flex items-center space-x-2 mb-2">
+                <Cpu className="w-4 h-4 text-amber-600" />
+                <span className="text-xs font-medium text-amber-700">Total Tokens</span>
+              </div>
+              <div className="text-lg font-bold text-amber-700">
+                {((response.body.usage.input_tokens || 0) + (response.body.usage.output_tokens || 0)).toLocaleString()}
+              </div>
+              <div className="text-xs text-amber-700 opacity-75">Combined</div>
+            </div>
+
+            {response.body.usage.cache_read_input_tokens && (
+              <div className="bg-green-50 border border-green-200 rounded-lg p-4">
+                <div className="flex items-center space-x-2 mb-2">
+                  <Bot className="w-4 h-4 text-green-600" />
+                  <span className="text-xs font-medium text-green-700">Cached Tokens</span>
+                </div>
+                <div className="text-lg font-bold text-green-700">
+                  {response.body.usage.cache_read_input_tokens.toLocaleString()}
+                </div>
+                <div className="text-xs text-green-700 opacity-75">From Cache</div>
+              </div>
+            )}
+          </div>
+        )}
+
         {/* Response Headers */}
         {response.headers && (
           <div className="bg-gray-50 border border-gray-200 rounded-xl overflow-hidden">
@@ -16,11 +16,6 @@ interface TodoListProps {
 }

 export function TodoList({ todos }: TodoListProps) {
-  // Debug: Log the structure of the first todo
-  if (todos && todos.length > 0) {
-    console.log('Todo structure:', todos[0]);
-  }
-
   if (!todos || todos.length === 0) {
     return (
       <div className="bg-gray-50 border border-gray-200 rounded-lg p-4 text-center">
@@ -28,7 +28,8 @@ import {
   Copy,
   Check,
   Lightbulb,
-  Loader2
+  Loader2,
+  ArrowLeftRight
 } from "lucide-react";
 
 import RequestDetailContent from "../components/RequestDetailContent";
@@ -50,6 +51,8 @@ interface Request {
   method: string;
   endpoint: string;
   headers: Record<string, string[]>;
+  originalModel?: string;
+  routedModel?: string;
   body?: {
     model?: string;
     messages?: Array<{
@@ -187,53 +190,8 @@ export default function Index() {
       });
     } catch (error) {
       console.error('Failed to load requests:', error);
-
-      // Fallback to example data for demo
-      const exampleRequest = {
-        timestamp: "2025-06-04T23:47:37-04:00",
-        method: "POST",
-        endpoint: "/v1/messages",
-        headers: {
-          "User-Agent": ["claude-cli/1.0.11 (external, cli)"],
-          "Content-Type": ["application/json"],
-          "Anthropic-Version": ["2023-06-01"]
-        },
-        body: {
-          model: "claude-sonnet-4-20250514",
-          messages: [
-            {
-              role: "user",
-              content: [{
-                type: "text",
-                text: "I need to extract the complete list of tools available to Claude Code from the request file..."
-              }]
-            }
-          ],
-          max_tokens: 32000,
-          temperature: 1,
-          stream: true
-        }
-      };
-
       startTransition(() => {
-        // setRequests([
-        // { ...exampleRequest, id: 1 },
-        // {
-        //   ...exampleRequest,
-        //   id: 2,
-        //   timestamp: "2025-06-04T23:45:12-04:00",
-        //   endpoint: "/v1/chat/completions",
-        //   body: { ...exampleRequest.body, model: "gpt-4-turbo" }
-        // },
-        // {
-        //   ...exampleRequest,
-        //   id: 3,
-        //   timestamp: "2025-06-04T23:42:33-04:00",
-        //   method: "GET",
-        //   endpoint: "/v1/models",
-        //   body: undefined
-        // }
-        // ]);
+        setRequests([]);
       });
     } finally {
       setIsFetching(false);
@@ -363,12 +321,21 @@ export default function Index() {
       parts.push(`⏱️ ${seconds}s`);
     }
 
-    // Add model if available
-    if (request.body?.model) {
-      const modelShort = request.body.model.includes('opus') ? 'Opus' :
-                        request.body.model.includes('sonnet') ? 'Sonnet' :
-                        request.body.model.includes('haiku') ? 'Haiku' : 'Model';
+    // Add model if available (use routed model if different from original)
+    const model = request.routedModel || request.body?.model;
+    if (model) {
+      const modelShort = model.includes('opus') ? 'Opus' :
+                        model.includes('sonnet') ? 'Sonnet' :
+                        model.includes('haiku') ? 'Haiku' :
+                        model.includes('gpt-4o') ? 'gpt-4o' :
+                        model.includes('o3') ? 'o3' :
+                        model.includes('o3-mini') ? 'o3-mini' : 'Model';
       parts.push(`🤖 ${modelShort}`);
+
+      // Show routing info if model was routed
+      if (request.routedModel && request.originalModel && request.routedModel !== request.originalModel) {
+        parts.push(`→ routed`);
+      }
     }
 
     return parts.length > 0 ? parts.join(' • ') : '📡 API request';
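A note on the short-label chain this hunk adds: it tests `model.includes('o3')` before `model.includes('o3-mini')`, and since `'o3-mini'.includes('o3')` is true, the `'o3-mini'` branch is unreachable and such models fall through to the `o3` label. The mapping can be sketched as a standalone helper with the more specific substring tested first (the helper name is illustrative, not part of the commit):

```typescript
// Sketch of the hunk's model → short-label mapping (illustrative helper).
// 'o3-mini' is tested before 'o3': the broader substring must come last,
// or its branch shadows the more specific one.
function modelShortLabel(model: string): string {
  if (model.includes('opus')) return 'Opus';
  if (model.includes('sonnet')) return 'Sonnet';
  if (model.includes('haiku')) return 'Haiku';
  if (model.includes('gpt-4o')) return 'gpt-4o';
  if (model.includes('o3-mini')) return 'o3-mini';
  if (model.includes('o3')) return 'o3';
  return 'Model';
}

console.log(modelShortLabel('claude-sonnet-4-20250514')); // "Sonnet"
console.log(modelShortLabel('o3-mini'));                  // "o3-mini"
console.log(modelShortLabel('gpt-4o-2024-08-06'));        // "gpt-4o"
```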
@@ -516,6 +483,26 @@ export default function Index() {
     }
   }, [viewMode, modelFilter]);
 
+  // Handle escape key to close modals
+  useEffect(() => {
+    const handleEscapeKey = (event: KeyboardEvent) => {
+      if (event.key === 'Escape') {
+        if (isModalOpen) {
+          closeModal();
+        } else if (isConversationModalOpen) {
+          setIsConversationModalOpen(false);
+          setSelectedConversation(null);
+        }
+      }
+    };
+
+    window.addEventListener('keydown', handleEscapeKey);
+
+    return () => {
+      window.removeEventListener('keydown', handleEscapeKey);
+    };
+  }, [isModalOpen, isConversationModalOpen]);
+
   const filteredRequests = filterRequests(filter);
 
   return (
@@ -671,13 +658,25 @@ export default function Index() {
                       {/* Model and Status */}
                       <div className="flex items-center space-x-3 mb-1">
                         <h3 className="text-sm font-medium">
-                          {request.body?.model ? (
-                            request.body.model.includes('opus') ? <span className="text-purple-600 font-semibold">Opus</span> :
-                            request.body.model.includes('sonnet') ? <span className="text-indigo-600 font-semibold">Sonnet</span> :
-                            request.body.model.includes('haiku') ? <span className="text-teal-600 font-semibold">Haiku</span> :
-                            <span className="text-gray-900">{request.body.model.split('-')[0]}</span>
+                          {request.routedModel || request.body?.model ? (
+                            // Use routedModel if available, otherwise fall back to body.model
+                            (() => {
+                              const model = request.routedModel || request.body?.model || '';
+                              if (model.includes('opus')) return <span className="text-purple-600 font-semibold">Opus</span>;
+                              if (model.includes('sonnet')) return <span className="text-indigo-600 font-semibold">Sonnet</span>;
+                              if (model.includes('haiku')) return <span className="text-teal-600 font-semibold">Haiku</span>;
+                              if (model.includes('gpt-4o')) return <span className="text-green-600 font-semibold">GPT-4o</span>;
+                              if (model.includes('gpt')) return <span className="text-green-600 font-semibold">GPT</span>;
+                              return <span className="text-gray-900">{model.split('-')[0]}</span>;
+                            })()
                           ) : <span className="text-gray-900">API</span>}
                         </h3>
+                        {request.routedModel && request.routedModel !== request.originalModel && (
+                          <span className="text-xs px-1.5 py-0.5 bg-blue-100 text-blue-700 rounded font-medium flex items-center space-x-1">
+                            <ArrowLeftRight className="w-3 h-3" />
+                            <span>routed</span>
+                          </span>
+                        )}
                         {request.response?.statusCode && (
                           <span className={`text-xs font-medium px-1.5 py-0.5 rounded ${
                             request.response.statusCode >= 200 && request.response.statusCode < 300
@@ -698,7 +697,7 @@ export default function Index() {
 
                       {/* Endpoint */}
                       <div className="text-xs text-gray-600 font-mono mb-1">
-                        {request.endpoint}
+                        {request.routedModel && request.routedModel.startsWith('gpt-') ? '/v1/chat/completions' : request.endpoint}
                       </div>
 
                       {/* Metrics Row */}
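The hunk above rewrites the displayed endpoint when a request was routed to a GPT model. Extracted as a plain function (the helper name is illustrative; the commit inlines this ternary in JSX):

```typescript
// Illustrative: requests routed to GPT models are served by OpenAI's chat
// completions endpoint, so the list shows that path instead of the
// Anthropic endpoint the client originally called.
function displayedEndpoint(endpoint: string, routedModel?: string): string {
  return routedModel && routedModel.startsWith('gpt-')
    ? '/v1/chat/completions'
    : endpoint;
}

console.log(displayedEndpoint('/v1/messages', 'gpt-4o')); // "/v1/chat/completions"
console.log(displayedEndpoint('/v1/messages'));           // "/v1/messages"
```

Note that routed `o3` models do not match `startsWith('gpt-')`, so they keep the original endpoint in this display.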