Working version with openai
This commit is contained in:
parent
80a25f7ba7
commit
1e0173c768
11 changed files with 904 additions and 179 deletions
|
|
@ -81,45 +81,95 @@ func Load() (*Config, error) {
|
||||||
// Start with default configuration
|
// Start with default configuration
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
Port: getEnv("PORT", "3001"),
|
Port: "3001",
|
||||||
ReadTimeout: getDuration("READ_TIMEOUT", 600*time.Second),
|
ReadTimeout: 600 * time.Second,
|
||||||
WriteTimeout: getDuration("WRITE_TIMEOUT", 600*time.Second),
|
WriteTimeout: 600 * time.Second,
|
||||||
IdleTimeout: getDuration("IDLE_TIMEOUT", 600*time.Second),
|
IdleTimeout: 600 * time.Second,
|
||||||
},
|
},
|
||||||
Providers: ProvidersConfig{
|
Providers: ProvidersConfig{
|
||||||
Anthropic: AnthropicProviderConfig{
|
Anthropic: AnthropicProviderConfig{
|
||||||
BaseURL: getEnv("ANTHROPIC_FORWARD_URL", "https://api.anthropic.com"),
|
BaseURL: "https://api.anthropic.com",
|
||||||
Version: getEnv("ANTHROPIC_VERSION", "2023-06-01"),
|
Version: "2023-06-01",
|
||||||
MaxRetries: getInt("ANTHROPIC_MAX_RETRIES", 3),
|
MaxRetries: 3,
|
||||||
},
|
},
|
||||||
OpenAI: OpenAIProviderConfig{
|
OpenAI: OpenAIProviderConfig{
|
||||||
BaseURL: getEnv("OPENAI_BASE_URL", "https://api.openai.com"),
|
BaseURL: "https://api.openai.com",
|
||||||
APIKey: getEnv("OPENAI_API_KEY", ""),
|
APIKey: "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Storage: StorageConfig{
|
Storage: StorageConfig{
|
||||||
DBPath: getEnv("DB_PATH", "requests.db"),
|
DBPath: "requests.db",
|
||||||
},
|
},
|
||||||
Subagents: SubagentsConfig{
|
Subagents: SubagentsConfig{
|
||||||
Mappings: make(map[string]string),
|
Mappings: make(map[string]string),
|
||||||
},
|
},
|
||||||
// Legacy field for backward compatibility
|
|
||||||
Anthropic: AnthropicConfig{
|
|
||||||
BaseURL: getEnv("ANTHROPIC_FORWARD_URL", "https://api.anthropic.com"),
|
|
||||||
Version: getEnv("ANTHROPIC_VERSION", "2023-06-01"),
|
|
||||||
MaxRetries: getInt("ANTHROPIC_MAX_RETRIES", 3),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to load from YAML config file if specified
|
// Try to load config.yaml from the project root
|
||||||
configPath := getEnv("CONFIG_PATH", "../config.yaml")
|
// The proxy binary is in proxy/ directory, config.yaml is in the parent
|
||||||
if configPath != "" {
|
configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml")
|
||||||
if err := cfg.loadFromFile(configPath); err != nil {
|
|
||||||
// Log error but continue with defaults
|
// If that doesn't work, try relative to current directory
|
||||||
fmt.Printf("Warning: Failed to load config from %s: %v\n", configPath, err)
|
if _, err := os.Stat(configPath); err != nil {
|
||||||
|
// Try common locations relative to where the binary might be run
|
||||||
|
for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} {
|
||||||
|
if _, err := os.Stat(tryPath); err == nil {
|
||||||
|
configPath = tryPath
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := cfg.loadFromFile(configPath); err == nil {
|
||||||
|
fmt.Printf("Loaded config from %s\n", configPath)
|
||||||
|
fmt.Printf("Subagent mappings: %+v\n", cfg.Subagents.Mappings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply environment variable overrides AFTER loading from file
|
||||||
|
if envPort := os.Getenv("PORT"); envPort != "" {
|
||||||
|
cfg.Server.Port = envPort
|
||||||
|
}
|
||||||
|
if envTimeout := os.Getenv("READ_TIMEOUT"); envTimeout != "" {
|
||||||
|
cfg.Server.ReadTimeout = getDuration("READ_TIMEOUT", cfg.Server.ReadTimeout)
|
||||||
|
}
|
||||||
|
if envTimeout := os.Getenv("WRITE_TIMEOUT"); envTimeout != "" {
|
||||||
|
cfg.Server.WriteTimeout = getDuration("WRITE_TIMEOUT", cfg.Server.WriteTimeout)
|
||||||
|
}
|
||||||
|
if envTimeout := os.Getenv("IDLE_TIMEOUT"); envTimeout != "" {
|
||||||
|
cfg.Server.IdleTimeout = getDuration("IDLE_TIMEOUT", cfg.Server.IdleTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Anthropic settings
|
||||||
|
if envURL := os.Getenv("ANTHROPIC_FORWARD_URL"); envURL != "" {
|
||||||
|
cfg.Providers.Anthropic.BaseURL = envURL
|
||||||
|
}
|
||||||
|
if envVersion := os.Getenv("ANTHROPIC_VERSION"); envVersion != "" {
|
||||||
|
cfg.Providers.Anthropic.Version = envVersion
|
||||||
|
}
|
||||||
|
if envRetries := os.Getenv("ANTHROPIC_MAX_RETRIES"); envRetries != "" {
|
||||||
|
cfg.Providers.Anthropic.MaxRetries = getInt("ANTHROPIC_MAX_RETRIES", cfg.Providers.Anthropic.MaxRetries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override OpenAI settings
|
||||||
|
if envURL := os.Getenv("OPENAI_BASE_URL"); envURL != "" {
|
||||||
|
cfg.Providers.OpenAI.BaseURL = envURL
|
||||||
|
}
|
||||||
|
if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" {
|
||||||
|
cfg.Providers.OpenAI.APIKey = envKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override storage settings
|
||||||
|
if envPath := os.Getenv("DB_PATH"); envPath != "" {
|
||||||
|
cfg.Storage.DBPath = envPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync legacy Anthropic config
|
||||||
|
cfg.Anthropic = AnthropicConfig{
|
||||||
|
BaseURL: cfg.Providers.Anthropic.BaseURL,
|
||||||
|
Version: cfg.Providers.Anthropic.Version,
|
||||||
|
MaxRetries: cfg.Providers.Anthropic.MaxRetries,
|
||||||
|
}
|
||||||
|
|
||||||
// After loading from file, apply any timeout conversions if needed
|
// After loading from file, apply any timeout conversions if needed
|
||||||
if cfg.Server.Timeouts.Read != "" {
|
if cfg.Server.Timeouts.Read != "" {
|
||||||
if duration, err := time.ParseDuration(cfg.Server.Timeouts.Read); err == nil {
|
if duration, err := time.ParseDuration(cfg.Server.Timeouts.Read); err == nil {
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ func TestLoad(t *testing.T) {
|
||||||
originalPort := os.Getenv("PORT")
|
originalPort := os.Getenv("PORT")
|
||||||
originalAnthropicURL := os.Getenv("ANTHROPIC_FORWARD_URL")
|
originalAnthropicURL := os.Getenv("ANTHROPIC_FORWARD_URL")
|
||||||
originalOpenAIKey := os.Getenv("OPENAI_API_KEY")
|
originalOpenAIKey := os.Getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
// Restore after test
|
// Restore after test
|
||||||
defer func() {
|
defer func() {
|
||||||
os.Setenv("CONFIG_PATH", originalConfigPath)
|
os.Setenv("CONFIG_PATH", originalConfigPath)
|
||||||
|
|
@ -82,7 +82,7 @@ subagents:
|
||||||
// Clear environment variables
|
// Clear environment variables
|
||||||
os.Unsetenv("CONFIG_PATH")
|
os.Unsetenv("CONFIG_PATH")
|
||||||
os.Unsetenv("PORT")
|
os.Unsetenv("PORT")
|
||||||
|
|
||||||
// Create empty config directory
|
// Create empty config directory
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
os.Setenv("CONFIG_PATH", filepath.Join(tempDir, "nonexistent.yaml"))
|
os.Setenv("CONFIG_PATH", filepath.Join(tempDir, "nonexistent.yaml"))
|
||||||
|
|
@ -188,7 +188,7 @@ func TestConfig_ParseTimeouts(t *testing.T) {
|
||||||
{"Valid minutes", "5m", 5, false},
|
{"Valid minutes", "5m", 5, false},
|
||||||
{"Valid seconds", "30s", 0, false}, // Will be 30 seconds, not minutes
|
{"Valid seconds", "30s", 0, false}, // Will be 30 seconds, not minutes
|
||||||
{"Valid hours", "2h", 120, false},
|
{"Valid hours", "2h", 120, false},
|
||||||
{"Empty string", "", 10, false}, // Should use default
|
{"Empty string", "", 10, false}, // Should use default
|
||||||
{"Invalid format", "invalid", 10, false}, // Should use default
|
{"Invalid format", "invalid", 10, false}, // Should use default
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -199,4 +199,4 @@ func TestConfig_ParseTimeouts(t *testing.T) {
|
||||||
// For now, we'll skip the implementation details
|
// For now, we'll skip the implementation details
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
@ -68,23 +70,6 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||||
requestID := generateRequestID()
|
requestID := generateRequestID()
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
// Create request log
|
|
||||||
requestLog := &model.RequestLog{
|
|
||||||
RequestID: requestID,
|
|
||||||
Timestamp: time.Now().Format(time.RFC3339),
|
|
||||||
Method: r.Method,
|
|
||||||
Endpoint: "/v1/messages",
|
|
||||||
Headers: SanitizeHeaders(r.Header),
|
|
||||||
Body: req,
|
|
||||||
Model: req.Model,
|
|
||||||
UserAgent: r.Header.Get("User-Agent"),
|
|
||||||
ContentType: r.Header.Get("Content-Type"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := h.storageService.SaveRequest(requestLog); err != nil {
|
|
||||||
log.Printf("❌ Error saving request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use model router to determine provider and route the request
|
// Use model router to determine provider and route the request
|
||||||
provider, originalModel, err := h.modelRouter.RouteRequest(&req)
|
provider, originalModel, err := h.modelRouter.RouteRequest(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -93,9 +78,44 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update request log with original model (for tracking)
|
// Create request log with routing information
|
||||||
requestLog.OriginalModel = originalModel
|
requestLog := &model.RequestLog{
|
||||||
requestLog.RoutedModel = req.Model
|
RequestID: requestID,
|
||||||
|
Timestamp: time.Now().Format(time.RFC3339),
|
||||||
|
Method: r.Method,
|
||||||
|
Endpoint: "/v1/messages",
|
||||||
|
Headers: SanitizeHeaders(r.Header),
|
||||||
|
Body: req,
|
||||||
|
Model: req.Model,
|
||||||
|
OriginalModel: originalModel,
|
||||||
|
RoutedModel: req.Model,
|
||||||
|
UserAgent: r.Header.Get("User-Agent"),
|
||||||
|
ContentType: r.Header.Get("Content-Type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := h.storageService.SaveRequest(requestLog); err != nil {
|
||||||
|
log.Printf("❌ Error saving request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the model was changed by routing, update the request body
|
||||||
|
if req.Model != originalModel {
|
||||||
|
// Re-marshal the request with the updated model
|
||||||
|
updatedBodyBytes, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("❌ Error marshaling updated request: %v", err)
|
||||||
|
writeErrorResponse(w, "Failed to process request", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request with the updated body
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(updatedBodyBytes))
|
||||||
|
r.ContentLength = int64(len(updatedBodyBytes))
|
||||||
|
r.Header.Set("Content-Length", fmt.Sprintf("%d", len(updatedBodyBytes)))
|
||||||
|
|
||||||
|
// Update the context with new body bytes for logging
|
||||||
|
ctx := context.WithValue(r.Context(), model.BodyBytesKey, updatedBodyBytes)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// Forward the request to the selected provider
|
// Forward the request to the selected provider
|
||||||
resp, err := provider.ForwardRequest(r.Context(), r)
|
resp, err := provider.ForwardRequest(r.Context(), r)
|
||||||
|
|
@ -182,8 +202,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
||||||
modelFilter = "all"
|
modelFilter = "all"
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("📊 GetRequests called - page: %d, limit: %d, modelFilter: %s", page, limit, modelFilter)
|
|
||||||
|
|
||||||
// Get all requests with model filter applied at storage level
|
// Get all requests with model filter applied at storage level
|
||||||
allRequests, err := h.storageService.GetAllRequests(modelFilter)
|
allRequests, err := h.storageService.GetAllRequests(modelFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -192,8 +210,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("📊 Got %d requests from storage (filter: %s)", len(allRequests), modelFilter)
|
|
||||||
|
|
||||||
// Convert pointers to values for consistency
|
// Convert pointers to values for consistency
|
||||||
requests := make([]model.RequestLog, len(allRequests))
|
requests := make([]model.RequestLog, len(allRequests))
|
||||||
for i, req := range allRequests {
|
for i, req := range allRequests {
|
||||||
|
|
@ -217,8 +233,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
||||||
requests = requests[start:end]
|
requests = requests[start:end]
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("📊 Returning %d requests after pagination", len(requests))
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(struct {
|
json.NewEncoder(w).Encode(struct {
|
||||||
Requests []model.RequestLog `json:"requests"`
|
Requests []model.RequestLog `json:"requests"`
|
||||||
|
|
@ -314,7 +328,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture usage data and metadata from message_start event
|
// Capture metadata from message_start event
|
||||||
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_start" {
|
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_start" {
|
||||||
if message, ok := genericEvent["message"].(map[string]interface{}); ok {
|
if message, ok := genericEvent["message"].(map[string]interface{}); ok {
|
||||||
// Capture message metadata
|
// Capture message metadata
|
||||||
|
|
@ -327,51 +341,42 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
||||||
if reason, ok := message["stop_reason"].(string); ok {
|
if reason, ok := message["stop_reason"].(string); ok {
|
||||||
stopReason = reason
|
stopReason = reason
|
||||||
}
|
}
|
||||||
|
// Don't capture usage from message_start - it will come in message_delta
|
||||||
// Capture initial usage data from message_start
|
|
||||||
if usage, ok := message["usage"].(map[string]interface{}); ok {
|
|
||||||
finalUsage = &model.AnthropicUsage{}
|
|
||||||
if inputTokens, ok := usage["input_tokens"].(float64); ok {
|
|
||||||
finalUsage.InputTokens = int(inputTokens)
|
|
||||||
}
|
|
||||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
|
||||||
finalUsage.OutputTokens = int(outputTokens)
|
|
||||||
}
|
|
||||||
if cacheCreation, ok := usage["cache_creation_input_tokens"].(float64); ok {
|
|
||||||
finalUsage.CacheCreationInputTokens = int(cacheCreation)
|
|
||||||
}
|
|
||||||
if cacheRead, ok := usage["cache_read_input_tokens"].(float64); ok {
|
|
||||||
finalUsage.CacheReadInputTokens = int(cacheRead)
|
|
||||||
}
|
|
||||||
if tier, ok := usage["service_tier"].(string); ok {
|
|
||||||
finalUsage.ServiceTier = tier
|
|
||||||
}
|
|
||||||
log.Printf("📊 Captured initial usage from message_start: %+v", finalUsage)
|
|
||||||
} else {
|
|
||||||
log.Printf("⚠️ No usage data found in message_start event")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update output tokens from message_delta event
|
// Capture usage data from message_delta event
|
||||||
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_delta" {
|
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_delta" {
|
||||||
// Usage is at top level for message_delta events
|
// Usage is at top level for message_delta events
|
||||||
if usage, ok := genericEvent["usage"].(map[string]interface{}); ok {
|
if usage, ok := genericEvent["usage"].(map[string]interface{}); ok {
|
||||||
if finalUsage != nil {
|
// Create finalUsage if it doesn't exist yet
|
||||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
if finalUsage == nil {
|
||||||
finalUsage.OutputTokens = int(outputTokens)
|
finalUsage = &model.AnthropicUsage{}
|
||||||
log.Printf("📊 Updated output tokens from message_delta: %d", int(outputTokens))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Printf("⚠️ finalUsage is nil when trying to update from message_delta usage")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture all usage fields
|
||||||
|
if inputTokens, ok := usage["input_tokens"].(float64); ok {
|
||||||
|
finalUsage.InputTokens = int(inputTokens)
|
||||||
|
}
|
||||||
|
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
||||||
|
finalUsage.OutputTokens = int(outputTokens)
|
||||||
|
}
|
||||||
|
if cacheCreation, ok := usage["cache_creation_input_tokens"].(float64); ok {
|
||||||
|
finalUsage.CacheCreationInputTokens = int(cacheCreation)
|
||||||
|
}
|
||||||
|
if cacheRead, ok := usage["cache_read_input_tokens"].(float64); ok {
|
||||||
|
finalUsage.CacheReadInputTokens = int(cacheRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("📊 Captured usage from message_delta: %+v", finalUsage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse as structured event for content processing
|
// Parse as structured event for content processing
|
||||||
var event model.StreamingEvent
|
var event model.StreamingEvent
|
||||||
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
||||||
continue // Skip if structured parsing fails, but we already got the usage data above
|
// Skip if structured parsing fails, but we already got the usage data above
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
|
|
|
||||||
|
|
@ -16,10 +16,10 @@ import (
|
||||||
|
|
||||||
// MockStorageService implements StorageService interface for testing
|
// MockStorageService implements StorageService interface for testing
|
||||||
type MockStorageService struct {
|
type MockStorageService struct {
|
||||||
SavedRequests []model.RequestLog
|
SavedRequests []model.RequestLog
|
||||||
ReturnError error
|
ReturnError error
|
||||||
RequestsToReturn []model.RequestLog
|
RequestsToReturn []model.RequestLog
|
||||||
TotalRequests int
|
TotalRequests int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockStorageService) SaveRequest(request *model.RequestLog) (string, error) {
|
func (m *MockStorageService) SaveRequest(request *model.RequestLog) (string, error) {
|
||||||
|
|
@ -87,8 +87,8 @@ func (m *MockStorageService) GetAllRequests(modelFilter string) ([]*model.Reques
|
||||||
|
|
||||||
// MockAnthropicService implements AnthropicService interface for testing
|
// MockAnthropicService implements AnthropicService interface for testing
|
||||||
type MockAnthropicService struct {
|
type MockAnthropicService struct {
|
||||||
ReturnResponse *http.Response
|
ReturnResponse *http.Response
|
||||||
ReturnError error
|
ReturnError error
|
||||||
ReceivedRequest *http.Request
|
ReceivedRequest *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -103,8 +103,8 @@ func (m *MockAnthropicService) ForwardRequest(ctx context.Context, originalReq *
|
||||||
// Return a default successful response
|
// Return a default successful response
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
StatusCode: 200,
|
StatusCode: 200,
|
||||||
Body: io.NopCloser(bytes.NewBufferString(`{"id":"test","content":[{"text":"Hello"}]}`)),
|
Body: io.NopCloser(bytes.NewBufferString(`{"id":"test","content":[{"text":"Hello"}]}`)),
|
||||||
Header: make(http.Header),
|
Header: make(http.Header),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -284,4 +284,4 @@ func TestDeleteRequestsEndpoint(t *testing.T) {
|
||||||
if response["deleted"] != float64(2) { // JSON unmarshals numbers as float64
|
if response["deleted"] != float64(2) { // JSON unmarshals numbers as float64
|
||||||
t.Errorf("Expected 2 deleted requests, got %v", response["deleted"])
|
t.Errorf("Expected 2 deleted requests, got %v", response["deleted"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,6 @@ import (
|
||||||
func Logging(next http.Handler) http.Handler {
|
func Logging(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
log.Printf("%s - %s %s", start.Format(time.RFC3339), r.Method, r.URL.Path)
|
|
||||||
log.Printf("Headers: %s", formatHeaders(r.Header))
|
|
||||||
|
|
||||||
var bodyBytes []byte
|
var bodyBytes []byte
|
||||||
if r.Body != nil {
|
if r.Body != nil {
|
||||||
|
|
@ -35,12 +33,6 @@ func Logging(next http.Handler) http.Handler {
|
||||||
ctx := context.WithValue(r.Context(), model.BodyBytesKey, bodyBytes)
|
ctx := context.WithValue(r.Context(), model.BodyBytesKey, bodyBytes)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
log.Printf("Body length: %d bytes", len(bodyBytes))
|
|
||||||
if len(bodyBytes) > 0 {
|
|
||||||
logRequestBody(bodyBytes)
|
|
||||||
}
|
|
||||||
log.Println("---")
|
|
||||||
|
|
||||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||||
next.ServeHTTP(wrapped, r)
|
next.ServeHTTP(wrapped, r)
|
||||||
|
|
||||||
|
|
@ -77,17 +69,6 @@ func sanitizeHeaderValue(key string, values []string) []string {
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func logRequestBody(bodyBytes []byte) {
|
|
||||||
var bodyJSON interface{}
|
|
||||||
if err := json.Unmarshal(bodyBytes, &bodyJSON); err == nil {
|
|
||||||
bodyStr, _ := json.MarshalIndent(bodyJSON, "", " ")
|
|
||||||
log.Printf("Body: %s", string(bodyStr))
|
|
||||||
} else {
|
|
||||||
log.Printf("❌ Failed to parse body as JSON: %v", err)
|
|
||||||
log.Printf("Raw body: %s", string(bodyBytes))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type responseWriter struct {
|
type responseWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
statusCode int
|
statusCode int
|
||||||
|
|
|
||||||
|
|
@ -131,14 +131,9 @@ type Tool struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputSchema struct {
|
type InputSchema struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Properties map[string]Property `json:"properties"`
|
Properties map[string]interface{} `json:"properties"`
|
||||||
Required []string `json:"required,omitempty"`
|
Required []string `json:"required,omitempty"`
|
||||||
}
|
|
||||||
|
|
||||||
type Property struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnthropicRequest struct {
|
type AnthropicRequest struct {
|
||||||
|
|
@ -149,6 +144,7 @@ type AnthropicRequest struct {
|
||||||
System []AnthropicSystemMessage `json:"system,omitempty"`
|
System []AnthropicSystemMessage `json:"system,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelsResponse struct {
|
type ModelsResponse struct {
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,16 @@
|
||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/seifghazi/claude-code-monitor/internal/config"
|
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||||
|
|
@ -88,6 +91,47 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
||||||
return nil, fmt.Errorf("failed to forward request: %w", err)
|
return nil, fmt.Errorf("failed to forward request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for error responses
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
// Read the error body for debugging
|
||||||
|
errorBody, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// Log the error details
|
||||||
|
fmt.Printf("OpenAI API error: Status=%d, Body=%s\n", resp.StatusCode, string(errorBody))
|
||||||
|
|
||||||
|
// Create an error response in Anthropic format
|
||||||
|
errorResp := map[string]interface{}{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]interface{}{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": fmt.Sprintf("OpenAI API error: %s", string(errorBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
errorJSON, _ := json.Marshal(errorResp)
|
||||||
|
|
||||||
|
// Create a new response with the error
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(errorJSON))
|
||||||
|
resp.Header.Set("Content-Type", "application/json")
|
||||||
|
resp.Header.Del("Content-Encoding")
|
||||||
|
resp.ContentLength = int64(len(errorJSON))
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle gzip-encoded responses
|
||||||
|
var bodyReader io.ReadCloser = resp.Body
|
||||||
|
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||||
|
gzReader, err := gzip.NewReader(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
bodyReader = gzReader
|
||||||
|
resp.Header.Del("Content-Encoding")
|
||||||
|
resp.Header.Del("Content-Length")
|
||||||
|
}
|
||||||
|
|
||||||
// For streaming responses, we need to convert back to Anthropic format
|
// For streaming responses, we need to convert back to Anthropic format
|
||||||
if anthropicReq.Stream {
|
if anthropicReq.Stream {
|
||||||
// Create a pipe to transform the response
|
// Create a pipe to transform the response
|
||||||
|
|
@ -96,15 +140,16 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
||||||
// Start a goroutine to transform the stream
|
// Start a goroutine to transform the stream
|
||||||
go func() {
|
go func() {
|
||||||
defer pw.Close()
|
defer pw.Close()
|
||||||
transformOpenAIStreamToAnthropic(resp.Body, pw)
|
defer bodyReader.Close()
|
||||||
|
transformOpenAIStreamToAnthropic(bodyReader, pw)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Replace the response body with our transformed stream
|
// Replace the response body with our transformed stream
|
||||||
resp.Body = pr
|
resp.Body = pr
|
||||||
} else {
|
} else {
|
||||||
// For non-streaming, read and convert the response
|
// For non-streaming, read and convert the response
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(bodyReader)
|
||||||
resp.Body.Close()
|
bodyReader.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -122,41 +167,315 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
||||||
func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{} {
|
func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{} {
|
||||||
messages := []map[string]interface{}{}
|
messages := []map[string]interface{}{}
|
||||||
|
|
||||||
// Add system messages
|
// Combine all system messages into a single system message for OpenAI
|
||||||
for _, sysMsg := range req.System {
|
if len(req.System) > 0 {
|
||||||
|
systemContent := ""
|
||||||
|
for i, sysMsg := range req.System {
|
||||||
|
if i > 0 {
|
||||||
|
systemContent += "\n\n"
|
||||||
|
}
|
||||||
|
systemContent += sysMsg.Text
|
||||||
|
}
|
||||||
messages = append(messages, map[string]interface{}{
|
messages = append(messages, map[string]interface{}{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": sysMsg.Text,
|
"content": systemContent,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add conversation messages
|
// Add conversation messages
|
||||||
for _, msg := range req.Messages {
|
for _, msg := range req.Messages {
|
||||||
// Get content blocks from the message
|
// Handle messages with raw content that may contain tool results
|
||||||
contentBlocks := msg.GetContentBlocks()
|
if contentArray, ok := msg.Content.([]interface{}); ok {
|
||||||
content := ""
|
// Check if this message contains tool results
|
||||||
if len(contentBlocks) > 0 {
|
hasToolResults := false
|
||||||
// Use the first text block
|
for _, item := range contentArray {
|
||||||
content = contentBlocks[0].Text
|
if block, ok := item.(map[string]interface{}); ok {
|
||||||
}
|
if blockType, hasType := block["type"].(string); hasType && blockType == "tool_result" {
|
||||||
|
hasToolResults = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
messages = append(messages, map[string]interface{}{
|
if hasToolResults {
|
||||||
"role": msg.Role,
|
textContent := ""
|
||||||
"content": content,
|
|
||||||
})
|
for _, item := range contentArray {
|
||||||
|
if block, ok := item.(map[string]interface{}); ok {
|
||||||
|
if blockType, hasType := block["type"].(string); hasType {
|
||||||
|
if blockType == "text" {
|
||||||
|
if text, hasText := block["text"].(string); hasText {
|
||||||
|
textContent += text + "\n"
|
||||||
|
}
|
||||||
|
} else if blockType == "tool_result" {
|
||||||
|
// Extract tool ID
|
||||||
|
toolID := ""
|
||||||
|
if id, hasID := block["tool_use_id"].(string); hasID {
|
||||||
|
toolID = id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different formats of tool result content
|
||||||
|
resultContent := ""
|
||||||
|
if content, hasContent := block["content"]; hasContent {
|
||||||
|
if contentStr, ok := content.(string); ok {
|
||||||
|
resultContent = contentStr
|
||||||
|
} else if contentList, ok := content.([]interface{}); ok {
|
||||||
|
// If content is a list of blocks, extract text from each
|
||||||
|
for _, c := range contentList {
|
||||||
|
if contentMap, ok := c.(map[string]interface{}); ok {
|
||||||
|
if contentMap["type"] == "text" {
|
||||||
|
if text, ok := contentMap["text"].(string); ok {
|
||||||
|
resultContent += text + "\n"
|
||||||
|
}
|
||||||
|
} else if text, hasText := contentMap["text"]; hasText {
|
||||||
|
// Handle any dict by trying to extract text
|
||||||
|
resultContent += fmt.Sprintf("%v\n", text)
|
||||||
|
} else {
|
||||||
|
// Try to JSON serialize
|
||||||
|
if jsonBytes, err := json.Marshal(contentMap); err == nil {
|
||||||
|
resultContent += string(jsonBytes) + "\n"
|
||||||
|
} else {
|
||||||
|
resultContent += fmt.Sprintf("%v\n", contentMap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if contentDict, ok := content.(map[string]interface{}); ok {
|
||||||
|
// Handle dictionary content
|
||||||
|
if contentDict["type"] == "text" {
|
||||||
|
if text, ok := contentDict["text"].(string); ok {
|
||||||
|
resultContent = text
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to JSON serialize
|
||||||
|
if jsonBytes, err := json.Marshal(contentDict); err == nil {
|
||||||
|
resultContent = string(jsonBytes)
|
||||||
|
} else {
|
||||||
|
resultContent = fmt.Sprintf("%v", contentDict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle any other type by converting to string
|
||||||
|
if jsonBytes, err := json.Marshal(content); err == nil {
|
||||||
|
resultContent = string(jsonBytes)
|
||||||
|
} else {
|
||||||
|
resultContent = fmt.Sprintf("%v", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In OpenAI format, tool results come from the user (matching Python behavior)
|
||||||
|
textContent += fmt.Sprintf("Tool result for %s:\n%s\n", toolID, resultContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add as a single user message with all the content
|
||||||
|
if textContent == "" {
|
||||||
|
textContent = "..."
|
||||||
|
}
|
||||||
|
messages = append(messages, map[string]interface{}{
|
||||||
|
"role": msg.Role,
|
||||||
|
"content": strings.TrimSpace(textContent),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Handle regular messages with content blocks
|
||||||
|
content := ""
|
||||||
|
|
||||||
|
for _, item := range contentArray {
|
||||||
|
if block, ok := item.(map[string]interface{}); ok {
|
||||||
|
if blockType, hasType := block["type"].(string); hasType && blockType == "text" {
|
||||||
|
if text, hasText := block["text"].(string); hasText {
|
||||||
|
if content != "" {
|
||||||
|
content += "\n"
|
||||||
|
}
|
||||||
|
content += text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure content is never empty
|
||||||
|
if content == "" {
|
||||||
|
content = "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, map[string]interface{}{
|
||||||
|
"role": msg.Role,
|
||||||
|
"content": content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle simple string content
|
||||||
|
contentBlocks := msg.GetContentBlocks()
|
||||||
|
content := ""
|
||||||
|
|
||||||
|
// Concatenate all text blocks
|
||||||
|
for _, block := range contentBlocks {
|
||||||
|
if block.Type == "text" {
|
||||||
|
if content != "" {
|
||||||
|
content += "\n"
|
||||||
|
}
|
||||||
|
content += block.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure content is never empty
|
||||||
|
if content == "" {
|
||||||
|
content = "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, map[string]interface{}{
|
||||||
|
"role": msg.Role,
|
||||||
|
"content": content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if max_tokens exceeds the model's limit and cap it if necessary
|
||||||
|
maxTokensLimit := 16384 // Assuming this is the limit for the model
|
||||||
|
if req.MaxTokens > maxTokensLimit {
|
||||||
|
fmt.Printf("Warning: max_tokens is too large: %d. Capping to %d.\n", req.MaxTokens, maxTokensLimit)
|
||||||
|
req.MaxTokens = maxTokensLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All OpenAI models now use max_completion_tokens instead of deprecated max_tokens
|
||||||
openAIReq := map[string]interface{}{
|
openAIReq := map[string]interface{}{
|
||||||
"model": req.Model,
|
"model": req.Model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": req.Temperature,
|
"stream": req.Stream,
|
||||||
"max_tokens": req.MaxTokens,
|
"max_completion_tokens": req.MaxTokens,
|
||||||
"stream": req.Stream,
|
}
|
||||||
|
|
||||||
|
// If streaming is enabled, request usage data to be included in the final chunk
|
||||||
|
if req.Stream {
|
||||||
|
openAIReq["stream_options"] = map[string]interface{}{
|
||||||
|
"include_usage": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is an o-series model (they don't support temperature)
|
||||||
|
isOSeriesModel := strings.HasPrefix(req.Model, "o1") || strings.HasPrefix(req.Model, "o3")
|
||||||
|
|
||||||
|
// Only include temperature for non-o-series models
|
||||||
|
if !isOSeriesModel {
|
||||||
|
openAIReq["temperature"] = req.Temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Using max_completion_tokens=%d for model %s\n", req.MaxTokens, req.Model)
|
||||||
|
|
||||||
|
// Convert Anthropic tools to OpenAI format
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
tools := make([]map[string]interface{}, 0, len(req.Tools))
|
||||||
|
for _, tool := range req.Tools {
|
||||||
|
// Ensure tool has required fields
|
||||||
|
if tool.Name == "" {
|
||||||
|
fmt.Printf("Warning: Skipping tool with empty name\n")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build parameters with error checking
|
||||||
|
parameters := make(map[string]interface{})
|
||||||
|
parameters["type"] = tool.InputSchema.Type
|
||||||
|
if parameters["type"] == "" {
|
||||||
|
parameters["type"] = "object" // Default to object type
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle properties safely with array validation
|
||||||
|
if tool.InputSchema.Properties != nil {
|
||||||
|
// Fix array properties that are missing items field
|
||||||
|
fixedProperties := make(map[string]interface{})
|
||||||
|
for propName, propValue := range tool.InputSchema.Properties {
|
||||||
|
if prop, ok := propValue.(map[string]interface{}); ok {
|
||||||
|
// Check if this is an array type missing items
|
||||||
|
if propType, hasType := prop["type"]; hasType && propType == "array" {
|
||||||
|
if _, hasItems := prop["items"]; !hasItems {
|
||||||
|
// Add default items definition for arrays
|
||||||
|
fmt.Printf("Warning: Array property '%s' in tool '%s' missing items - adding default\n", propName, tool.Name)
|
||||||
|
prop["items"] = map[string]interface{}{"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fixedProperties[propName] = prop
|
||||||
|
} else {
|
||||||
|
// Keep non-map properties as-is
|
||||||
|
fixedProperties[propName] = propValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parameters["properties"] = fixedProperties
|
||||||
|
} else {
|
||||||
|
parameters["properties"] = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle required fields
|
||||||
|
if len(tool.InputSchema.Required) > 0 {
|
||||||
|
parameters["required"] = tool.InputSchema.Required
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build function definition
|
||||||
|
functionDef := map[string]interface{}{
|
||||||
|
"name": tool.Name,
|
||||||
|
"parameters": parameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add description if present
|
||||||
|
if tool.Description != "" {
|
||||||
|
functionDef["description"] = tool.Description
|
||||||
|
}
|
||||||
|
|
||||||
|
openAITool := map[string]interface{}{
|
||||||
|
"type": "function",
|
||||||
|
"function": functionDef,
|
||||||
|
}
|
||||||
|
tools = append(tools, openAITool)
|
||||||
|
}
|
||||||
|
openAIReq["tools"] = tools
|
||||||
|
|
||||||
|
// Handle tool_choice if present
|
||||||
|
if req.ToolChoice != nil {
|
||||||
|
// Convert Anthropic tool_choice to OpenAI format
|
||||||
|
if toolChoiceMap, ok := req.ToolChoice.(map[string]interface{}); ok {
|
||||||
|
choiceType := toolChoiceMap["type"]
|
||||||
|
switch choiceType {
|
||||||
|
case "auto":
|
||||||
|
openAIReq["tool_choice"] = "auto"
|
||||||
|
case "any":
|
||||||
|
openAIReq["tool_choice"] = "required"
|
||||||
|
case "tool":
|
||||||
|
// Specific tool choice
|
||||||
|
if name, hasName := toolChoiceMap["name"].(string); hasName {
|
||||||
|
openAIReq["tool_choice"] = map[string]interface{}{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Default to auto if we can't determine
|
||||||
|
openAIReq["tool_choice"] = "auto"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return openAIReq
|
return openAIReq
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getMapKeys(m map[string]interface{}) []string {
|
||||||
|
keys := make([]string, 0, len(m))
|
||||||
|
for k := range m {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
||||||
// This is a simplified transformation
|
// This is a simplified transformation
|
||||||
// In production, you'd want to handle all fields properly
|
// In production, you'd want to handle all fields properly
|
||||||
|
|
@ -166,25 +485,97 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract the assistant's message
|
// Extract the assistant's message
|
||||||
content := ""
|
var contentBlocks []map[string]interface{}
|
||||||
|
|
||||||
if choices, ok := openAIResp["choices"].([]interface{}); ok && len(choices) > 0 {
|
if choices, ok := openAIResp["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||||
if msg, ok := choice["message"].(map[string]interface{}); ok {
|
if msg, ok := choice["message"].(map[string]interface{}); ok {
|
||||||
if c, ok := msg["content"].(string); ok {
|
// Handle regular text content
|
||||||
content = c
|
if content, ok := msg["content"].(string); ok && content != "" {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls
|
||||||
|
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok {
|
||||||
|
// Since this proxy forwards to Claude/Anthropic API, we should always
|
||||||
|
// use tool_use blocks so Claude can execute the tools properly
|
||||||
|
// (regardless of which model generated the response)
|
||||||
|
for _, tc := range toolCalls {
|
||||||
|
if toolCall, ok := tc.(map[string]interface{}); ok {
|
||||||
|
if function, ok := toolCall["function"].(map[string]interface{}); ok {
|
||||||
|
// Convert OpenAI tool call to Anthropic tool_use format
|
||||||
|
anthropicToolUse := map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolCall["id"],
|
||||||
|
"name": function["name"],
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the arguments JSON string
|
||||||
|
if argsStr, ok := function["arguments"].(string); ok {
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
|
||||||
|
anthropicToolUse["input"] = args
|
||||||
|
} else {
|
||||||
|
// If parsing fails, wrap in a raw field like Python does
|
||||||
|
fmt.Printf("Warning: Failed to parse tool arguments as JSON: %v\n", err)
|
||||||
|
anthropicToolUse["input"] = map[string]interface{}{"raw": argsStr}
|
||||||
|
}
|
||||||
|
} else if args, ok := function["arguments"].(map[string]interface{}); ok {
|
||||||
|
// Already a map, use directly
|
||||||
|
anthropicToolUse["input"] = args
|
||||||
|
} else {
|
||||||
|
// Fallback for any other type
|
||||||
|
anthropicToolUse["input"] = map[string]interface{}{"raw": fmt.Sprintf("%v", function["arguments"])}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlocks = append(contentBlocks, anthropicToolUse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If no content blocks were created, add a default empty text block
|
||||||
|
if len(contentBlocks) == 0 {
|
||||||
|
contentBlocks = []map[string]interface{}{
|
||||||
|
{"type": "text", "text": ""},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build Anthropic-style response
|
// Build Anthropic-style response
|
||||||
anthropicResp := map[string]interface{}{
|
anthropicResp := map[string]interface{}{
|
||||||
"id": openAIResp["id"],
|
"id": openAIResp["id"],
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": []map[string]string{{"type": "text", "text": content}},
|
"content": contentBlocks,
|
||||||
"model": openAIResp["model"],
|
"model": openAIResp["model"],
|
||||||
"usage": openAIResp["usage"],
|
}
|
||||||
|
|
||||||
|
// Convert OpenAI usage format to Anthropic format
|
||||||
|
if usage, ok := openAIResp["usage"].(map[string]interface{}); ok {
|
||||||
|
anthropicUsage := map[string]interface{}{}
|
||||||
|
|
||||||
|
// Map prompt_tokens to input_tokens
|
||||||
|
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
|
||||||
|
anthropicUsage["input_tokens"] = int(promptTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map completion_tokens to output_tokens
|
||||||
|
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
|
||||||
|
anthropicUsage["output_tokens"] = int(completionTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include total_tokens if needed (though Anthropic format doesn't typically use it)
|
||||||
|
if totalTokens, ok := usage["total_tokens"].(float64); ok {
|
||||||
|
anthropicUsage["total_tokens"] = int(totalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropicResp["usage"] = anthropicUsage
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := json.Marshal(anthropicResp)
|
result, _ := json.Marshal(anthropicResp)
|
||||||
|
|
@ -194,7 +585,154 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
||||||
func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
|
func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
|
||||||
defer openAIStream.Close()
|
defer openAIStream.Close()
|
||||||
|
|
||||||
// This is a placeholder - in production you'd parse SSE events
|
scanner := bufio.NewScanner(openAIStream)
|
||||||
// and transform them from OpenAI format to Anthropic format
|
var messageStarted bool
|
||||||
io.Copy(anthropicStream, openAIStream)
|
var contentStarted bool
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
|
||||||
|
// Skip empty lines
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle SSE data lines
|
||||||
|
if strings.HasPrefix(line, "data: ") {
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
|
||||||
|
// Handle end of stream
|
||||||
|
if data == "[DONE]" {
|
||||||
|
// Send Anthropic-style completion
|
||||||
|
if contentStarted {
|
||||||
|
fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
||||||
|
}
|
||||||
|
if messageStarted {
|
||||||
|
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n")
|
||||||
|
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n")
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse OpenAI response
|
||||||
|
var openAIChunk map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug: Check if this is the final chunk
|
||||||
|
if choices, ok := openAIChunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||||
|
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||||
|
if finishReason, ok := choice["finish_reason"]; ok && finishReason != nil {
|
||||||
|
fmt.Printf("🏁 Final chunk detected with finish_reason: %v\n", finishReason)
|
||||||
|
fmt.Printf("🏁 Full final chunk: %+v\n", openAIChunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for usage data BEFORE processing choices
|
||||||
|
// According to OpenAI docs, usage is sent in the final chunk with empty choices array
|
||||||
|
if usage, hasUsage := openAIChunk["usage"].(map[string]interface{}); hasUsage {
|
||||||
|
fmt.Printf("🔍 Found usage data in OpenAI stream: %+v\n", usage)
|
||||||
|
fmt.Printf("🔍 Full OpenAI chunk with usage: %+v\n", openAIChunk)
|
||||||
|
|
||||||
|
// Convert OpenAI usage to Anthropic format
|
||||||
|
anthropicUsage := map[string]interface{}{}
|
||||||
|
|
||||||
|
// Handle both float64 and int types
|
||||||
|
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
|
||||||
|
anthropicUsage["input_tokens"] = int(promptTokens)
|
||||||
|
} else if promptTokens, ok := usage["prompt_tokens"].(int); ok {
|
||||||
|
anthropicUsage["input_tokens"] = promptTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
|
||||||
|
anthropicUsage["output_tokens"] = int(completionTokens)
|
||||||
|
} else if completionTokens, ok := usage["completion_tokens"].(int); ok {
|
||||||
|
anthropicUsage["output_tokens"] = completionTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(anthropicUsage) > 0 {
|
||||||
|
// Send usage data in a message_delta event
|
||||||
|
usageDelta := map[string]interface{}{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": map[string]interface{}{},
|
||||||
|
"usage": anthropicUsage,
|
||||||
|
}
|
||||||
|
usageJSON, _ := json.Marshal(usageDelta)
|
||||||
|
fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract choices array
|
||||||
|
choices, ok := openAIChunk["choices"].([]interface{})
|
||||||
|
if !ok || len(choices) == 0 {
|
||||||
|
// Skip further processing if no choices, but we already handled usage above
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
choice, ok := choices[0].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
delta, ok := choice["delta"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle first chunk - send message_start
|
||||||
|
if !messageStarted {
|
||||||
|
messageStarted = true
|
||||||
|
messageStart := map[string]interface{}{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"id": openAIChunk["id"],
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": openAIChunk["model"],
|
||||||
|
"content": []interface{}{},
|
||||||
|
"stop_reason": nil,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
// Empty usage - will be updated in final chunk
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
startJSON, _ := json.Marshal(messageStart)
|
||||||
|
fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle content
|
||||||
|
if content, hasContent := delta["content"].(string); hasContent && content != "" {
|
||||||
|
if !contentStarted {
|
||||||
|
contentStarted = true
|
||||||
|
// Send content_block_start
|
||||||
|
blockStart := map[string]interface{}{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
blockStartJSON, _ := json.Marshal(blockStart)
|
||||||
|
fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send content_block_delta
|
||||||
|
contentDelta := map[string]interface{}{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"type": "text_delta",
|
||||||
|
"text": content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
deltaJSON, _ := json.Marshal(contentDelta)
|
||||||
|
fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,8 @@ func (r *ModelRouter) extractStaticPrompt(systemPrompt string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ModelRouter) loadCustomAgents() {
|
func (r *ModelRouter) loadCustomAgents() {
|
||||||
|
r.logger.Printf("Loading custom agents from mappings: %+v", r.subagentMappings)
|
||||||
|
|
||||||
for agentName, targetModel := range r.subagentMappings {
|
for agentName, targetModel := range r.subagentMappings {
|
||||||
// Try loading from project level first, then user level
|
// Try loading from project level first, then user level
|
||||||
paths := []string{
|
paths := []string{
|
||||||
|
|
@ -67,20 +69,28 @@ func (r *ModelRouter) loadCustomAgents() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
|
r.logger.Printf("Trying to load agent from: %s", path)
|
||||||
content, err := os.ReadFile(path)
|
content, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
r.logger.Printf("Failed to read %s: %v", path, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.logger.Printf("Successfully read agent file: %s (size: %d bytes)", path, len(content))
|
||||||
|
|
||||||
// Parse agent file: metadata\n---\nsystem prompt
|
// Parse agent file: metadata\n---\nsystem prompt
|
||||||
parts := strings.Split(string(content), "\n---\n")
|
parts := strings.Split(string(content), "\n---\n")
|
||||||
|
r.logger.Printf("Agent file parts: %d", len(parts))
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
systemPrompt := strings.TrimSpace(parts[1])
|
systemPrompt := strings.TrimSpace(parts[1])
|
||||||
|
r.logger.Printf("System prompt (first 200 chars): %.200s", systemPrompt)
|
||||||
|
|
||||||
// Extract only the static part (before "Notes:" if it exists)
|
// Extract only the static part (before "Notes:" if it exists)
|
||||||
staticPrompt := r.extractStaticPrompt(systemPrompt)
|
staticPrompt := r.extractStaticPrompt(systemPrompt)
|
||||||
hash := r.hashString(staticPrompt)
|
hash := r.hashString(staticPrompt)
|
||||||
|
|
||||||
|
r.logger.Printf("Static prompt after extraction (first 200 chars): %.200s", staticPrompt)
|
||||||
|
|
||||||
// Determine provider for the target model
|
// Determine provider for the target model
|
||||||
providerName := r.getProviderNameForModel(targetModel)
|
providerName := r.getProviderNameForModel(targetModel)
|
||||||
|
|
||||||
|
|
@ -91,20 +101,35 @@ func (r *ModelRouter) loadCustomAgents() {
|
||||||
FullPrompt: staticPrompt,
|
FullPrompt: staticPrompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
r.logger.Printf("Loaded custom agent: %s (hash: %s) -> %s",
|
r.logger.Printf("Loaded custom agent: %s (hash: %s) -> %s (provider: %s)",
|
||||||
agentName, hash, targetModel)
|
agentName, hash, targetModel, providerName)
|
||||||
break
|
break
|
||||||
|
} else {
|
||||||
|
r.logger.Printf("Invalid agent file format for %s: expected at least 2 parts separated by ---", agentName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.logger.Printf("Total custom agents loaded: %d", len(r.customAgentPrompts))
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteRequest determines which provider and model to use for a request
|
// RouteRequest determines which provider and model to use for a request
|
||||||
func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provider, string, error) {
|
func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provider, string, error) {
|
||||||
originalModel := req.Model
|
originalModel := req.Model
|
||||||
|
|
||||||
|
r.logger.Printf("RouteRequest: Model=%s, System messages count=%d", originalModel, len(req.System))
|
||||||
|
|
||||||
|
// Debug: Print loaded custom agents
|
||||||
|
r.logger.Printf("Loaded custom agents: %d", len(r.customAgentPrompts))
|
||||||
|
for hash, def := range r.customAgentPrompts {
|
||||||
|
r.logger.Printf(" Agent: %s (hash: %s) -> %s", def.Name, hash, def.TargetModel)
|
||||||
|
}
|
||||||
|
|
||||||
// Claude Code pattern: Check if we have exactly 2 system messages
|
// Claude Code pattern: Check if we have exactly 2 system messages
|
||||||
if len(req.System) == 2 {
|
if len(req.System) == 2 {
|
||||||
|
r.logger.Printf("System[0]: %.100s...", req.System[0].Text)
|
||||||
|
r.logger.Printf("System[1]: %.100s...", req.System[1].Text)
|
||||||
|
|
||||||
// First should be "You are Claude Code..."
|
// First should be "You are Claude Code..."
|
||||||
if strings.Contains(req.System[0].Text, "You are Claude Code") {
|
if strings.Contains(req.System[0].Text, "You are Claude Code") {
|
||||||
// Second message could be either:
|
// Second message could be either:
|
||||||
|
|
@ -117,6 +142,9 @@ func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provid
|
||||||
staticPrompt := r.extractStaticPrompt(fullPrompt)
|
staticPrompt := r.extractStaticPrompt(fullPrompt)
|
||||||
promptHash := r.hashString(staticPrompt)
|
promptHash := r.hashString(staticPrompt)
|
||||||
|
|
||||||
|
r.logger.Printf("Static prompt hash: %s", promptHash)
|
||||||
|
r.logger.Printf("Static prompt (first 200 chars): %.200s", staticPrompt)
|
||||||
|
|
||||||
// Check if this matches a known custom agent
|
// Check if this matches a known custom agent
|
||||||
if definition, exists := r.customAgentPrompts[promptHash]; exists {
|
if definition, exists := r.customAgentPrompts[promptHash]; exists {
|
||||||
r.logger.Printf("Subagent '%s' detected -> routing to %s",
|
r.logger.Printf("Subagent '%s' detected -> routing to %s",
|
||||||
|
|
@ -133,7 +161,7 @@ func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provid
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a regular Claude Code request (not a known subagent)
|
// This is a regular Claude Code request (not a known subagent)
|
||||||
r.logger.Printf("Regular Claude Code request detected, using original model %s", originalModel)
|
r.logger.Printf("No matching subagent found for hash %s, using original model %s", promptHash, originalModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -150,14 +178,17 @@ func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provid
|
||||||
func (r *ModelRouter) hashString(s string) string {
|
func (r *ModelRouter) hashString(s string) string {
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write([]byte(s))
|
h.Write([]byte(s))
|
||||||
return hex.EncodeToString(h.Sum(nil))[:16]
|
fullHash := hex.EncodeToString(h.Sum(nil))
|
||||||
|
shortHash := fullHash[:16]
|
||||||
|
r.logger.Printf("Hashing string (length: %d) -> %s", len(s), shortHash)
|
||||||
|
return shortHash
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ModelRouter) getProviderNameForModel(model string) string {
|
func (r *ModelRouter) getProviderNameForModel(model string) string {
|
||||||
// Map models to providers
|
// Map models to providers
|
||||||
if strings.HasPrefix(model, "claude") {
|
if strings.HasPrefix(model, "claude") {
|
||||||
return "anthropic"
|
return "anthropic"
|
||||||
} else if strings.HasPrefix(model, "gpt") {
|
} else if strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o") {
|
||||||
return "openai"
|
return "openai"
|
||||||
}
|
}
|
||||||
// Default to anthropic
|
// Default to anthropic
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,8 @@ func (s *sqliteStorageService) createTables() error {
|
||||||
prompt_grade TEXT,
|
prompt_grade TEXT,
|
||||||
response TEXT,
|
response TEXT,
|
||||||
model TEXT,
|
model TEXT,
|
||||||
|
original_model TEXT,
|
||||||
|
routed_model TEXT,
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -74,8 +76,8 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model)
|
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err = s.db.Exec(query,
|
_, err = s.db.Exec(query,
|
||||||
|
|
@ -88,6 +90,8 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
||||||
request.UserAgent,
|
request.UserAgent,
|
||||||
request.ContentType,
|
request.ContentType,
|
||||||
request.Model,
|
request.Model,
|
||||||
|
request.OriginalModel,
|
||||||
|
request.RoutedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -108,7 +112,7 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
|
||||||
// Get paginated results
|
// Get paginated results
|
||||||
offset := (page - 1) * limit
|
offset := (page - 1) * limit
|
||||||
query := `
|
query := `
|
||||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
|
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||||
FROM requests
|
FROM requests
|
||||||
ORDER BY timestamp DESC
|
ORDER BY timestamp DESC
|
||||||
LIMIT ? OFFSET ?
|
LIMIT ? OFFSET ?
|
||||||
|
|
@ -138,6 +142,8 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
|
||||||
&req.ContentType,
|
&req.ContentType,
|
||||||
&promptGradeJSON,
|
&promptGradeJSON,
|
||||||
&responseJSON,
|
&responseJSON,
|
||||||
|
&req.OriginalModel,
|
||||||
|
&req.RoutedModel,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error scanning row: %v", err)
|
log.Printf("Error scanning row: %v", err)
|
||||||
|
|
@ -228,7 +234,7 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error {
|
||||||
|
|
||||||
func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
|
func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
|
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||||
FROM requests
|
FROM requests
|
||||||
WHERE id LIKE ?
|
WHERE id LIKE ?
|
||||||
ORDER BY timestamp DESC
|
ORDER BY timestamp DESC
|
||||||
|
|
@ -251,6 +257,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
|
||||||
&req.ContentType,
|
&req.ContentType,
|
||||||
&promptGradeJSON,
|
&promptGradeJSON,
|
||||||
&responseJSON,
|
&responseJSON,
|
||||||
|
&req.OriginalModel,
|
||||||
|
&req.RoutedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
|
@ -294,7 +302,7 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig {
|
||||||
|
|
||||||
func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
|
func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
|
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||||
FROM requests
|
FROM requests
|
||||||
`
|
`
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
|
@ -331,13 +339,15 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
|
||||||
&req.ContentType,
|
&req.ContentType,
|
||||||
&promptGradeJSON,
|
&promptGradeJSON,
|
||||||
&responseJSON,
|
&responseJSON,
|
||||||
|
&req.OriginalModel,
|
||||||
|
&req.RoutedModel,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error scanning row: %v", err)
|
log.Printf("Error scanning row: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("🔍 Scanned request - ID: %s, Model: %s", req.RequestID, req.Model)
|
// log.Printf("🔍 Scanned request - ID: %s, Model: %s", req.RequestID, req.Model)
|
||||||
|
|
||||||
// Unmarshal JSON fields
|
// Unmarshal JSON fields
|
||||||
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@ interface Request {
|
||||||
method: string;
|
method: string;
|
||||||
endpoint: string;
|
endpoint: string;
|
||||||
headers: Record<string, string[]>;
|
headers: Record<string, string[]>;
|
||||||
|
originalModel?: string;
|
||||||
|
routedModel?: string;
|
||||||
body?: {
|
body?: {
|
||||||
model?: string;
|
model?: string;
|
||||||
messages?: Array<{
|
messages?: Array<{
|
||||||
|
|
@ -150,7 +152,7 @@ export default function RequestDetailContent({ request, onGrade }: RequestDetail
|
||||||
<div className="flex items-center space-x-3">
|
<div className="flex items-center space-x-3">
|
||||||
<span className="text-gray-500 font-medium min-w-[80px]">Endpoint:</span>
|
<span className="text-gray-500 font-medium min-w-[80px]">Endpoint:</span>
|
||||||
<code className="text-blue-600 bg-blue-50 px-2 py-1 rounded font-mono text-xs border border-blue-200">
|
<code className="text-blue-600 bg-blue-50 px-2 py-1 rounded font-mono text-xs border border-blue-200">
|
||||||
{request.endpoint}
|
{request.routedModel && request.routedModel.startsWith('gpt-') ? '/v1/chat/completions' : request.endpoint}
|
||||||
</code>
|
</code>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -329,12 +331,49 @@ export default function RequestDetailContent({ request, onGrade }: RequestDetail
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{expandedSections.model && (
|
{expandedSections.model && (
|
||||||
<div className="p-6">
|
<div className="p-6 space-y-4">
|
||||||
<div className="grid grid-cols-2 gap-4">
|
{/* Model Routing Information */}
|
||||||
<div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
|
{request.routedModel && request.routedModel !== request.originalModel && (
|
||||||
<div className="text-xs text-gray-500 mb-1">Model</div>
|
<div className="bg-gradient-to-r from-purple-50 to-blue-50 border border-purple-200 rounded-xl p-4">
|
||||||
<div className="text-sm font-medium text-gray-900">{request.body.model || 'N/A'}</div>
|
<div className="flex items-center space-x-4">
|
||||||
|
<div className="flex-1">
|
||||||
|
<div className="flex items-center space-x-2 mb-2">
|
||||||
|
<span className="text-sm font-semibold text-purple-700">Requested Model</span>
|
||||||
|
<code className="text-xs bg-white px-2 py-1 rounded font-mono border border-purple-200">
|
||||||
|
{request.originalModel || request.body.model}
|
||||||
|
</code>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center space-x-3">
|
||||||
|
<div className="flex items-center space-x-2">
|
||||||
|
<ArrowLeftRight className="w-4 h-4 text-purple-600" />
|
||||||
|
<span className="text-xs text-purple-600 font-medium">Routed to</span>
|
||||||
|
</div>
|
||||||
|
<code className="text-sm bg-white px-3 py-1.5 rounded font-mono font-semibold border border-blue-200 text-blue-700">
|
||||||
|
{request.routedModel}
|
||||||
|
</code>
|
||||||
|
<span className="text-xs bg-blue-100 text-blue-700 px-2 py-1 rounded-full border border-blue-200">
|
||||||
|
{request.routedModel.startsWith('gpt-') ? 'OpenAI' : 'Anthropic'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="text-right">
|
||||||
|
<div className="text-xs text-gray-500 mb-1">Target Endpoint</div>
|
||||||
|
<code className="text-xs bg-white px-2 py-1 rounded font-mono border border-gray-200">
|
||||||
|
{request.routedModel.startsWith('gpt-') ? '/v1/chat/completions' : '/v1/messages'}
|
||||||
|
</code>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Model Parameters */}
|
||||||
|
<div className="grid grid-cols-2 gap-4">
|
||||||
|
{!request.routedModel || request.routedModel === request.originalModel ? (
|
||||||
|
<div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
|
||||||
|
<div className="text-xs text-gray-500 mb-1">Model</div>
|
||||||
|
<div className="text-sm font-medium text-gray-900">{request.originalModel || request.body.model || 'N/A'}</div>
|
||||||
|
</div>
|
||||||
|
) : null}
|
||||||
<div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
|
<div className="bg-gray-50 border border-gray-200 rounded-lg p-3">
|
||||||
<div className="text-xs text-gray-500 mb-1">Max Tokens</div>
|
<div className="text-xs text-gray-500 mb-1">Max Tokens</div>
|
||||||
<div className="text-sm font-medium text-gray-900">
|
<div className="text-sm font-medium text-gray-900">
|
||||||
|
|
@ -619,6 +658,57 @@ function ResponseDetails({ response }: { response: NonNullable<Request['response
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* Token Usage */}
|
||||||
|
{response.body?.usage && (
|
||||||
|
<div className="grid grid-cols-2 lg:grid-cols-4 gap-4">
|
||||||
|
<div className="bg-indigo-50 border border-indigo-200 rounded-lg p-4">
|
||||||
|
<div className="flex items-center space-x-2 mb-2">
|
||||||
|
<Brain className="w-4 h-4 text-indigo-600" />
|
||||||
|
<span className="text-xs font-medium text-indigo-700">Input Tokens</span>
|
||||||
|
</div>
|
||||||
|
<div className="text-lg font-bold text-indigo-700">
|
||||||
|
{response.body.usage.input_tokens?.toLocaleString() || '0'}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-indigo-700 opacity-75">Prompt</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="bg-emerald-50 border border-emerald-200 rounded-lg p-4">
|
||||||
|
<div className="flex items-center space-x-2 mb-2">
|
||||||
|
<MessageCircle className="w-4 h-4 text-emerald-600" />
|
||||||
|
<span className="text-xs font-medium text-emerald-700">Output Tokens</span>
|
||||||
|
</div>
|
||||||
|
<div className="text-lg font-bold text-emerald-700">
|
||||||
|
{response.body.usage.output_tokens?.toLocaleString() || '0'}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-emerald-700 opacity-75">Response</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="bg-amber-50 border border-amber-200 rounded-lg p-4">
|
||||||
|
<div className="flex items-center space-x-2 mb-2">
|
||||||
|
<Cpu className="w-4 h-4 text-amber-600" />
|
||||||
|
<span className="text-xs font-medium text-amber-700">Total Tokens</span>
|
||||||
|
</div>
|
||||||
|
<div className="text-lg font-bold text-amber-700">
|
||||||
|
{((response.body.usage.input_tokens || 0) + (response.body.usage.output_tokens || 0)).toLocaleString()}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-amber-700 opacity-75">Combined</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{response.body.usage.cache_read_input_tokens && (
|
||||||
|
<div className="bg-green-50 border border-green-200 rounded-lg p-4">
|
||||||
|
<div className="flex items-center space-x-2 mb-2">
|
||||||
|
<Bot className="w-4 h-4 text-green-600" />
|
||||||
|
<span className="text-xs font-medium text-green-700">Cached Tokens</span>
|
||||||
|
</div>
|
||||||
|
<div className="text-lg font-bold text-green-700">
|
||||||
|
{response.body.usage.cache_read_input_tokens.toLocaleString()}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-green-700 opacity-75">From Cache</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Response Headers */}
|
{/* Response Headers */}
|
||||||
{response.headers && (
|
{response.headers && (
|
||||||
<div className="bg-gray-50 border border-gray-200 rounded-xl overflow-hidden">
|
<div className="bg-gray-50 border border-gray-200 rounded-xl overflow-hidden">
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@ import {
|
||||||
Copy,
|
Copy,
|
||||||
Check,
|
Check,
|
||||||
Lightbulb,
|
Lightbulb,
|
||||||
Loader2
|
Loader2,
|
||||||
|
ArrowLeftRight
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
|
|
||||||
import RequestDetailContent from "../components/RequestDetailContent";
|
import RequestDetailContent from "../components/RequestDetailContent";
|
||||||
|
|
@ -50,6 +51,8 @@ interface Request {
|
||||||
method: string;
|
method: string;
|
||||||
endpoint: string;
|
endpoint: string;
|
||||||
headers: Record<string, string[]>;
|
headers: Record<string, string[]>;
|
||||||
|
originalModel?: string;
|
||||||
|
routedModel?: string;
|
||||||
body?: {
|
body?: {
|
||||||
model?: string;
|
model?: string;
|
||||||
messages?: Array<{
|
messages?: Array<{
|
||||||
|
|
@ -363,12 +366,21 @@ export default function Index() {
|
||||||
parts.push(`⏱️ ${seconds}s`);
|
parts.push(`⏱️ ${seconds}s`);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add model if available
|
// Add model if available (use routed model if different from original)
|
||||||
if (request.body?.model) {
|
const model = request.routedModel || request.body?.model;
|
||||||
const modelShort = request.body.model.includes('opus') ? 'Opus' :
|
if (model) {
|
||||||
request.body.model.includes('sonnet') ? 'Sonnet' :
|
const modelShort = model.includes('opus') ? 'Opus' :
|
||||||
request.body.model.includes('haiku') ? 'Haiku' : 'Model';
|
model.includes('sonnet') ? 'Sonnet' :
|
||||||
|
model.includes('haiku') ? 'Haiku' :
|
||||||
|
model.includes('gpt-4o') ? 'gpt-4o' :
|
||||||
|
model.includes('o3') ? 'o3' :
|
||||||
|
model.includes('o3-mini') ? 'o3-mini' : 'Model';
|
||||||
parts.push(`🤖 ${modelShort}`);
|
parts.push(`🤖 ${modelShort}`);
|
||||||
|
|
||||||
|
// Show routing info if model was routed
|
||||||
|
if (request.routedModel && request.originalModel && request.routedModel !== request.originalModel) {
|
||||||
|
parts.push(`→ routed`);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return parts.length > 0 ? parts.join(' • ') : '📡 API request';
|
return parts.length > 0 ? parts.join(' • ') : '📡 API request';
|
||||||
|
|
@ -671,13 +683,25 @@ export default function Index() {
|
||||||
{/* Model and Status */}
|
{/* Model and Status */}
|
||||||
<div className="flex items-center space-x-3 mb-1">
|
<div className="flex items-center space-x-3 mb-1">
|
||||||
<h3 className="text-sm font-medium">
|
<h3 className="text-sm font-medium">
|
||||||
{request.body?.model ? (
|
{request.routedModel || request.body?.model ? (
|
||||||
request.body.model.includes('opus') ? <span className="text-purple-600 font-semibold">Opus</span> :
|
// Use routedModel if available, otherwise fall back to body.model
|
||||||
request.body.model.includes('sonnet') ? <span className="text-indigo-600 font-semibold">Sonnet</span> :
|
(() => {
|
||||||
request.body.model.includes('haiku') ? <span className="text-teal-600 font-semibold">Haiku</span> :
|
const model = request.routedModel || request.body?.model || '';
|
||||||
<span className="text-gray-900">{request.body.model.split('-')[0]}</span>
|
if (model.includes('opus')) return <span className="text-purple-600 font-semibold">Opus</span>;
|
||||||
|
if (model.includes('sonnet')) return <span className="text-indigo-600 font-semibold">Sonnet</span>;
|
||||||
|
if (model.includes('haiku')) return <span className="text-teal-600 font-semibold">Haiku</span>;
|
||||||
|
if (model.includes('gpt-4o')) return <span className="text-green-600 font-semibold">GPT-4o</span>;
|
||||||
|
if (model.includes('gpt')) return <span className="text-green-600 font-semibold">GPT</span>;
|
||||||
|
return <span className="text-gray-900">{model.split('-')[0]}</span>;
|
||||||
|
})()
|
||||||
) : <span className="text-gray-900">API</span>}
|
) : <span className="text-gray-900">API</span>}
|
||||||
</h3>
|
</h3>
|
||||||
|
{request.routedModel && request.routedModel !== request.originalModel && (
|
||||||
|
<span className="text-xs px-1.5 py-0.5 bg-blue-100 text-blue-700 rounded font-medium flex items-center space-x-1">
|
||||||
|
<ArrowLeftRight className="w-3 h-3" />
|
||||||
|
<span>routed</span>
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
{request.response?.statusCode && (
|
{request.response?.statusCode && (
|
||||||
<span className={`text-xs font-medium px-1.5 py-0.5 rounded ${
|
<span className={`text-xs font-medium px-1.5 py-0.5 rounded ${
|
||||||
request.response.statusCode >= 200 && request.response.statusCode < 300
|
request.response.statusCode >= 200 && request.response.statusCode < 300
|
||||||
|
|
@ -698,7 +722,7 @@ export default function Index() {
|
||||||
|
|
||||||
{/* Endpoint */}
|
{/* Endpoint */}
|
||||||
<div className="text-xs text-gray-600 font-mono mb-1">
|
<div className="text-xs text-gray-600 font-mono mb-1">
|
||||||
{request.endpoint}
|
{request.routedModel && request.routedModel.startsWith('gpt-') ? '/v1/chat/completions' : request.endpoint}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Metrics Row */}
|
{/* Metrics Row */}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue