Working version with openai
This commit is contained in:
parent
80a25f7ba7
commit
1e0173c768
11 changed files with 904 additions and 179 deletions
|
|
@ -81,45 +81,95 @@ func Load() (*Config, error) {
|
|||
// Start with default configuration
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Port: getEnv("PORT", "3001"),
|
||||
ReadTimeout: getDuration("READ_TIMEOUT", 600*time.Second),
|
||||
WriteTimeout: getDuration("WRITE_TIMEOUT", 600*time.Second),
|
||||
IdleTimeout: getDuration("IDLE_TIMEOUT", 600*time.Second),
|
||||
Port: "3001",
|
||||
ReadTimeout: 600 * time.Second,
|
||||
WriteTimeout: 600 * time.Second,
|
||||
IdleTimeout: 600 * time.Second,
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: AnthropicProviderConfig{
|
||||
BaseURL: getEnv("ANTHROPIC_FORWARD_URL", "https://api.anthropic.com"),
|
||||
Version: getEnv("ANTHROPIC_VERSION", "2023-06-01"),
|
||||
MaxRetries: getInt("ANTHROPIC_MAX_RETRIES", 3),
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
Version: "2023-06-01",
|
||||
MaxRetries: 3,
|
||||
},
|
||||
OpenAI: OpenAIProviderConfig{
|
||||
BaseURL: getEnv("OPENAI_BASE_URL", "https://api.openai.com"),
|
||||
APIKey: getEnv("OPENAI_API_KEY", ""),
|
||||
BaseURL: "https://api.openai.com",
|
||||
APIKey: "",
|
||||
},
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
DBPath: getEnv("DB_PATH", "requests.db"),
|
||||
DBPath: "requests.db",
|
||||
},
|
||||
Subagents: SubagentsConfig{
|
||||
Mappings: make(map[string]string),
|
||||
},
|
||||
// Legacy field for backward compatibility
|
||||
Anthropic: AnthropicConfig{
|
||||
BaseURL: getEnv("ANTHROPIC_FORWARD_URL", "https://api.anthropic.com"),
|
||||
Version: getEnv("ANTHROPIC_VERSION", "2023-06-01"),
|
||||
MaxRetries: getInt("ANTHROPIC_MAX_RETRIES", 3),
|
||||
},
|
||||
}
|
||||
|
||||
// Try to load from YAML config file if specified
|
||||
configPath := getEnv("CONFIG_PATH", "../config.yaml")
|
||||
if configPath != "" {
|
||||
if err := cfg.loadFromFile(configPath); err != nil {
|
||||
// Log error but continue with defaults
|
||||
fmt.Printf("Warning: Failed to load config from %s: %v\n", configPath, err)
|
||||
// Try to load config.yaml from the project root
|
||||
// The proxy binary is in proxy/ directory, config.yaml is in the parent
|
||||
configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml")
|
||||
|
||||
// If that doesn't work, try relative to current directory
|
||||
if _, err := os.Stat(configPath); err != nil {
|
||||
// Try common locations relative to where the binary might be run
|
||||
for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} {
|
||||
if _, err := os.Stat(tryPath); err == nil {
|
||||
configPath = tryPath
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := cfg.loadFromFile(configPath); err == nil {
|
||||
fmt.Printf("Loaded config from %s\n", configPath)
|
||||
fmt.Printf("Subagent mappings: %+v\n", cfg.Subagents.Mappings)
|
||||
}
|
||||
|
||||
// Apply environment variable overrides AFTER loading from file
|
||||
if envPort := os.Getenv("PORT"); envPort != "" {
|
||||
cfg.Server.Port = envPort
|
||||
}
|
||||
if envTimeout := os.Getenv("READ_TIMEOUT"); envTimeout != "" {
|
||||
cfg.Server.ReadTimeout = getDuration("READ_TIMEOUT", cfg.Server.ReadTimeout)
|
||||
}
|
||||
if envTimeout := os.Getenv("WRITE_TIMEOUT"); envTimeout != "" {
|
||||
cfg.Server.WriteTimeout = getDuration("WRITE_TIMEOUT", cfg.Server.WriteTimeout)
|
||||
}
|
||||
if envTimeout := os.Getenv("IDLE_TIMEOUT"); envTimeout != "" {
|
||||
cfg.Server.IdleTimeout = getDuration("IDLE_TIMEOUT", cfg.Server.IdleTimeout)
|
||||
}
|
||||
|
||||
// Override Anthropic settings
|
||||
if envURL := os.Getenv("ANTHROPIC_FORWARD_URL"); envURL != "" {
|
||||
cfg.Providers.Anthropic.BaseURL = envURL
|
||||
}
|
||||
if envVersion := os.Getenv("ANTHROPIC_VERSION"); envVersion != "" {
|
||||
cfg.Providers.Anthropic.Version = envVersion
|
||||
}
|
||||
if envRetries := os.Getenv("ANTHROPIC_MAX_RETRIES"); envRetries != "" {
|
||||
cfg.Providers.Anthropic.MaxRetries = getInt("ANTHROPIC_MAX_RETRIES", cfg.Providers.Anthropic.MaxRetries)
|
||||
}
|
||||
|
||||
// Override OpenAI settings
|
||||
if envURL := os.Getenv("OPENAI_BASE_URL"); envURL != "" {
|
||||
cfg.Providers.OpenAI.BaseURL = envURL
|
||||
}
|
||||
if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" {
|
||||
cfg.Providers.OpenAI.APIKey = envKey
|
||||
}
|
||||
|
||||
// Override storage settings
|
||||
if envPath := os.Getenv("DB_PATH"); envPath != "" {
|
||||
cfg.Storage.DBPath = envPath
|
||||
}
|
||||
|
||||
// Sync legacy Anthropic config
|
||||
cfg.Anthropic = AnthropicConfig{
|
||||
BaseURL: cfg.Providers.Anthropic.BaseURL,
|
||||
Version: cfg.Providers.Anthropic.Version,
|
||||
MaxRetries: cfg.Providers.Anthropic.MaxRetries,
|
||||
}
|
||||
|
||||
// After loading from file, apply any timeout conversions if needed
|
||||
if cfg.Server.Timeouts.Read != "" {
|
||||
if duration, err := time.ParseDuration(cfg.Server.Timeouts.Read); err == nil {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ func TestLoad(t *testing.T) {
|
|||
originalPort := os.Getenv("PORT")
|
||||
originalAnthropicURL := os.Getenv("ANTHROPIC_FORWARD_URL")
|
||||
originalOpenAIKey := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
// Restore after test
|
||||
defer func() {
|
||||
os.Setenv("CONFIG_PATH", originalConfigPath)
|
||||
|
|
@ -82,7 +82,7 @@ subagents:
|
|||
// Clear environment variables
|
||||
os.Unsetenv("CONFIG_PATH")
|
||||
os.Unsetenv("PORT")
|
||||
|
||||
|
||||
// Create empty config directory
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("CONFIG_PATH", filepath.Join(tempDir, "nonexistent.yaml"))
|
||||
|
|
@ -188,7 +188,7 @@ func TestConfig_ParseTimeouts(t *testing.T) {
|
|||
{"Valid minutes", "5m", 5, false},
|
||||
{"Valid seconds", "30s", 0, false}, // Will be 30 seconds, not minutes
|
||||
{"Valid hours", "2h", 120, false},
|
||||
{"Empty string", "", 10, false}, // Should use default
|
||||
{"Empty string", "", 10, false}, // Should use default
|
||||
{"Invalid format", "invalid", 10, false}, // Should use default
|
||||
}
|
||||
|
||||
|
|
@ -199,4 +199,4 @@ func TestConfig_ParseTimeouts(t *testing.T) {
|
|||
// For now, we'll skip the implementation details
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package handler
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
|
@ -68,23 +70,6 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
|||
requestID := generateRequestID()
|
||||
startTime := time.Now()
|
||||
|
||||
// Create request log
|
||||
requestLog := &model.RequestLog{
|
||||
RequestID: requestID,
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Method: r.Method,
|
||||
Endpoint: "/v1/messages",
|
||||
Headers: SanitizeHeaders(r.Header),
|
||||
Body: req,
|
||||
Model: req.Model,
|
||||
UserAgent: r.Header.Get("User-Agent"),
|
||||
ContentType: r.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
if _, err := h.storageService.SaveRequest(requestLog); err != nil {
|
||||
log.Printf("❌ Error saving request: %v", err)
|
||||
}
|
||||
|
||||
// Use model router to determine provider and route the request
|
||||
provider, originalModel, err := h.modelRouter.RouteRequest(&req)
|
||||
if err != nil {
|
||||
|
|
@ -93,9 +78,44 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Update request log with original model (for tracking)
|
||||
requestLog.OriginalModel = originalModel
|
||||
requestLog.RoutedModel = req.Model
|
||||
// Create request log with routing information
|
||||
requestLog := &model.RequestLog{
|
||||
RequestID: requestID,
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Method: r.Method,
|
||||
Endpoint: "/v1/messages",
|
||||
Headers: SanitizeHeaders(r.Header),
|
||||
Body: req,
|
||||
Model: req.Model,
|
||||
OriginalModel: originalModel,
|
||||
RoutedModel: req.Model,
|
||||
UserAgent: r.Header.Get("User-Agent"),
|
||||
ContentType: r.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
if _, err := h.storageService.SaveRequest(requestLog); err != nil {
|
||||
log.Printf("❌ Error saving request: %v", err)
|
||||
}
|
||||
|
||||
// If the model was changed by routing, update the request body
|
||||
if req.Model != originalModel {
|
||||
// Re-marshal the request with the updated model
|
||||
updatedBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
log.Printf("❌ Error marshaling updated request: %v", err)
|
||||
writeErrorResponse(w, "Failed to process request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new request with the updated body
|
||||
r.Body = io.NopCloser(bytes.NewReader(updatedBodyBytes))
|
||||
r.ContentLength = int64(len(updatedBodyBytes))
|
||||
r.Header.Set("Content-Length", fmt.Sprintf("%d", len(updatedBodyBytes)))
|
||||
|
||||
// Update the context with new body bytes for logging
|
||||
ctx := context.WithValue(r.Context(), model.BodyBytesKey, updatedBodyBytes)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
// Forward the request to the selected provider
|
||||
resp, err := provider.ForwardRequest(r.Context(), r)
|
||||
|
|
@ -182,8 +202,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
|||
modelFilter = "all"
|
||||
}
|
||||
|
||||
log.Printf("📊 GetRequests called - page: %d, limit: %d, modelFilter: %s", page, limit, modelFilter)
|
||||
|
||||
// Get all requests with model filter applied at storage level
|
||||
allRequests, err := h.storageService.GetAllRequests(modelFilter)
|
||||
if err != nil {
|
||||
|
|
@ -192,8 +210,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
log.Printf("📊 Got %d requests from storage (filter: %s)", len(allRequests), modelFilter)
|
||||
|
||||
// Convert pointers to values for consistency
|
||||
requests := make([]model.RequestLog, len(allRequests))
|
||||
for i, req := range allRequests {
|
||||
|
|
@ -217,8 +233,6 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
|||
requests = requests[start:end]
|
||||
}
|
||||
|
||||
log.Printf("📊 Returning %d requests after pagination", len(requests))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(struct {
|
||||
Requests []model.RequestLog `json:"requests"`
|
||||
|
|
@ -314,7 +328,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
continue
|
||||
}
|
||||
|
||||
// Capture usage data and metadata from message_start event
|
||||
// Capture metadata from message_start event
|
||||
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_start" {
|
||||
if message, ok := genericEvent["message"].(map[string]interface{}); ok {
|
||||
// Capture message metadata
|
||||
|
|
@ -327,51 +341,42 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
if reason, ok := message["stop_reason"].(string); ok {
|
||||
stopReason = reason
|
||||
}
|
||||
|
||||
// Capture initial usage data from message_start
|
||||
if usage, ok := message["usage"].(map[string]interface{}); ok {
|
||||
finalUsage = &model.AnthropicUsage{}
|
||||
if inputTokens, ok := usage["input_tokens"].(float64); ok {
|
||||
finalUsage.InputTokens = int(inputTokens)
|
||||
}
|
||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
||||
finalUsage.OutputTokens = int(outputTokens)
|
||||
}
|
||||
if cacheCreation, ok := usage["cache_creation_input_tokens"].(float64); ok {
|
||||
finalUsage.CacheCreationInputTokens = int(cacheCreation)
|
||||
}
|
||||
if cacheRead, ok := usage["cache_read_input_tokens"].(float64); ok {
|
||||
finalUsage.CacheReadInputTokens = int(cacheRead)
|
||||
}
|
||||
if tier, ok := usage["service_tier"].(string); ok {
|
||||
finalUsage.ServiceTier = tier
|
||||
}
|
||||
log.Printf("📊 Captured initial usage from message_start: %+v", finalUsage)
|
||||
} else {
|
||||
log.Printf("⚠️ No usage data found in message_start event")
|
||||
}
|
||||
// Don't capture usage from message_start - it will come in message_delta
|
||||
}
|
||||
}
|
||||
|
||||
// Update output tokens from message_delta event
|
||||
// Capture usage data from message_delta event
|
||||
if eventType, ok := genericEvent["type"].(string); ok && eventType == "message_delta" {
|
||||
// Usage is at top level for message_delta events
|
||||
if usage, ok := genericEvent["usage"].(map[string]interface{}); ok {
|
||||
if finalUsage != nil {
|
||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
||||
finalUsage.OutputTokens = int(outputTokens)
|
||||
log.Printf("📊 Updated output tokens from message_delta: %d", int(outputTokens))
|
||||
}
|
||||
} else {
|
||||
log.Printf("⚠️ finalUsage is nil when trying to update from message_delta usage")
|
||||
// Create finalUsage if it doesn't exist yet
|
||||
if finalUsage == nil {
|
||||
finalUsage = &model.AnthropicUsage{}
|
||||
}
|
||||
|
||||
// Capture all usage fields
|
||||
if inputTokens, ok := usage["input_tokens"].(float64); ok {
|
||||
finalUsage.InputTokens = int(inputTokens)
|
||||
}
|
||||
if outputTokens, ok := usage["output_tokens"].(float64); ok {
|
||||
finalUsage.OutputTokens = int(outputTokens)
|
||||
}
|
||||
if cacheCreation, ok := usage["cache_creation_input_tokens"].(float64); ok {
|
||||
finalUsage.CacheCreationInputTokens = int(cacheCreation)
|
||||
}
|
||||
if cacheRead, ok := usage["cache_read_input_tokens"].(float64); ok {
|
||||
finalUsage.CacheReadInputTokens = int(cacheRead)
|
||||
}
|
||||
|
||||
log.Printf("📊 Captured usage from message_delta: %+v", finalUsage)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as structured event for content processing
|
||||
var event model.StreamingEvent
|
||||
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
||||
continue // Skip if structured parsing fails, but we already got the usage data above
|
||||
// Skip if structured parsing fails, but we already got the usage data above
|
||||
continue
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ import (
|
|||
|
||||
// MockStorageService implements StorageService interface for testing
|
||||
type MockStorageService struct {
|
||||
SavedRequests []model.RequestLog
|
||||
ReturnError error
|
||||
SavedRequests []model.RequestLog
|
||||
ReturnError error
|
||||
RequestsToReturn []model.RequestLog
|
||||
TotalRequests int
|
||||
TotalRequests int
|
||||
}
|
||||
|
||||
func (m *MockStorageService) SaveRequest(request *model.RequestLog) (string, error) {
|
||||
|
|
@ -87,8 +87,8 @@ func (m *MockStorageService) GetAllRequests(modelFilter string) ([]*model.Reques
|
|||
|
||||
// MockAnthropicService implements AnthropicService interface for testing
|
||||
type MockAnthropicService struct {
|
||||
ReturnResponse *http.Response
|
||||
ReturnError error
|
||||
ReturnResponse *http.Response
|
||||
ReturnError error
|
||||
ReceivedRequest *http.Request
|
||||
}
|
||||
|
||||
|
|
@ -103,8 +103,8 @@ func (m *MockAnthropicService) ForwardRequest(ctx context.Context, originalReq *
|
|||
// Return a default successful response
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"id":"test","content":[{"text":"Hello"}]}`)),
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"id":"test","content":[{"text":"Hello"}]}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -284,4 +284,4 @@ func TestDeleteRequestsEndpoint(t *testing.T) {
|
|||
if response["deleted"] != float64(2) { // JSON unmarshals numbers as float64
|
||||
t.Errorf("Expected 2 deleted requests, got %v", response["deleted"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ import (
|
|||
func Logging(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
log.Printf("%s - %s %s", start.Format(time.RFC3339), r.Method, r.URL.Path)
|
||||
log.Printf("Headers: %s", formatHeaders(r.Header))
|
||||
|
||||
var bodyBytes []byte
|
||||
if r.Body != nil {
|
||||
|
|
@ -35,12 +33,6 @@ func Logging(next http.Handler) http.Handler {
|
|||
ctx := context.WithValue(r.Context(), model.BodyBytesKey, bodyBytes)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
log.Printf("Body length: %d bytes", len(bodyBytes))
|
||||
if len(bodyBytes) > 0 {
|
||||
logRequestBody(bodyBytes)
|
||||
}
|
||||
log.Println("---")
|
||||
|
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
|
|
@ -77,17 +69,6 @@ func sanitizeHeaderValue(key string, values []string) []string {
|
|||
return values
|
||||
}
|
||||
|
||||
func logRequestBody(bodyBytes []byte) {
|
||||
var bodyJSON interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &bodyJSON); err == nil {
|
||||
bodyStr, _ := json.MarshalIndent(bodyJSON, "", " ")
|
||||
log.Printf("Body: %s", string(bodyStr))
|
||||
} else {
|
||||
log.Printf("❌ Failed to parse body as JSON: %v", err)
|
||||
log.Printf("Raw body: %s", string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
|
|
|
|||
|
|
@ -131,14 +131,9 @@ type Tool struct {
|
|||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]Property `json:"properties"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type"`
|
||||
Properties map[string]interface{} `json:"properties"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type AnthropicRequest struct {
|
||||
|
|
@ -149,6 +144,7 @@ type AnthropicRequest struct {
|
|||
System []AnthropicSystemMessage `json:"system,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type ModelsResponse struct {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||
|
|
@ -88,6 +91,47 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
|||
return nil, fmt.Errorf("failed to forward request: %w", err)
|
||||
}
|
||||
|
||||
// Check for error responses
|
||||
if resp.StatusCode >= 400 {
|
||||
// Read the error body for debugging
|
||||
errorBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
// Log the error details
|
||||
fmt.Printf("OpenAI API error: Status=%d, Body=%s\n", resp.StatusCode, string(errorBody))
|
||||
|
||||
// Create an error response in Anthropic format
|
||||
errorResp := map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": map[string]interface{}{
|
||||
"type": "api_error",
|
||||
"message": fmt.Sprintf("OpenAI API error: %s", string(errorBody)),
|
||||
},
|
||||
}
|
||||
errorJSON, _ := json.Marshal(errorResp)
|
||||
|
||||
// Create a new response with the error
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorJSON))
|
||||
resp.Header.Set("Content-Type", "application/json")
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.ContentLength = int64(len(errorJSON))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Handle gzip-encoded responses
|
||||
var bodyReader io.ReadCloser = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
bodyReader = gzReader
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Content-Length")
|
||||
}
|
||||
|
||||
// For streaming responses, we need to convert back to Anthropic format
|
||||
if anthropicReq.Stream {
|
||||
// Create a pipe to transform the response
|
||||
|
|
@ -96,15 +140,16 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
|||
// Start a goroutine to transform the stream
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
transformOpenAIStreamToAnthropic(resp.Body, pw)
|
||||
defer bodyReader.Close()
|
||||
transformOpenAIStreamToAnthropic(bodyReader, pw)
|
||||
}()
|
||||
|
||||
// Replace the response body with our transformed stream
|
||||
resp.Body = pr
|
||||
} else {
|
||||
// For non-streaming, read and convert the response
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
respBody, err := io.ReadAll(bodyReader)
|
||||
bodyReader.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
|
@ -122,41 +167,315 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
|||
func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{} {
|
||||
messages := []map[string]interface{}{}
|
||||
|
||||
// Add system messages
|
||||
for _, sysMsg := range req.System {
|
||||
// Combine all system messages into a single system message for OpenAI
|
||||
if len(req.System) > 0 {
|
||||
systemContent := ""
|
||||
for i, sysMsg := range req.System {
|
||||
if i > 0 {
|
||||
systemContent += "\n\n"
|
||||
}
|
||||
systemContent += sysMsg.Text
|
||||
}
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": sysMsg.Text,
|
||||
"content": systemContent,
|
||||
})
|
||||
}
|
||||
|
||||
// Add conversation messages
|
||||
for _, msg := range req.Messages {
|
||||
// Get content blocks from the message
|
||||
contentBlocks := msg.GetContentBlocks()
|
||||
content := ""
|
||||
if len(contentBlocks) > 0 {
|
||||
// Use the first text block
|
||||
content = contentBlocks[0].Text
|
||||
}
|
||||
// Handle messages with raw content that may contain tool results
|
||||
if contentArray, ok := msg.Content.([]interface{}); ok {
|
||||
// Check if this message contains tool results
|
||||
hasToolResults := false
|
||||
for _, item := range contentArray {
|
||||
if block, ok := item.(map[string]interface{}); ok {
|
||||
if blockType, hasType := block["type"].(string); hasType && blockType == "tool_result" {
|
||||
hasToolResults = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": content,
|
||||
})
|
||||
if hasToolResults {
|
||||
textContent := ""
|
||||
|
||||
for _, item := range contentArray {
|
||||
if block, ok := item.(map[string]interface{}); ok {
|
||||
if blockType, hasType := block["type"].(string); hasType {
|
||||
if blockType == "text" {
|
||||
if text, hasText := block["text"].(string); hasText {
|
||||
textContent += text + "\n"
|
||||
}
|
||||
} else if blockType == "tool_result" {
|
||||
// Extract tool ID
|
||||
toolID := ""
|
||||
if id, hasID := block["tool_use_id"].(string); hasID {
|
||||
toolID = id
|
||||
}
|
||||
|
||||
// Handle different formats of tool result content
|
||||
resultContent := ""
|
||||
if content, hasContent := block["content"]; hasContent {
|
||||
if contentStr, ok := content.(string); ok {
|
||||
resultContent = contentStr
|
||||
} else if contentList, ok := content.([]interface{}); ok {
|
||||
// If content is a list of blocks, extract text from each
|
||||
for _, c := range contentList {
|
||||
if contentMap, ok := c.(map[string]interface{}); ok {
|
||||
if contentMap["type"] == "text" {
|
||||
if text, ok := contentMap["text"].(string); ok {
|
||||
resultContent += text + "\n"
|
||||
}
|
||||
} else if text, hasText := contentMap["text"]; hasText {
|
||||
// Handle any dict by trying to extract text
|
||||
resultContent += fmt.Sprintf("%v\n", text)
|
||||
} else {
|
||||
// Try to JSON serialize
|
||||
if jsonBytes, err := json.Marshal(contentMap); err == nil {
|
||||
resultContent += string(jsonBytes) + "\n"
|
||||
} else {
|
||||
resultContent += fmt.Sprintf("%v\n", contentMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if contentDict, ok := content.(map[string]interface{}); ok {
|
||||
// Handle dictionary content
|
||||
if contentDict["type"] == "text" {
|
||||
if text, ok := contentDict["text"].(string); ok {
|
||||
resultContent = text
|
||||
}
|
||||
} else {
|
||||
// Try to JSON serialize
|
||||
if jsonBytes, err := json.Marshal(contentDict); err == nil {
|
||||
resultContent = string(jsonBytes)
|
||||
} else {
|
||||
resultContent = fmt.Sprintf("%v", contentDict)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle any other type by converting to string
|
||||
if jsonBytes, err := json.Marshal(content); err == nil {
|
||||
resultContent = string(jsonBytes)
|
||||
} else {
|
||||
resultContent = fmt.Sprintf("%v", content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In OpenAI format, tool results come from the user (matching Python behavior)
|
||||
textContent += fmt.Sprintf("Tool result for %s:\n%s\n", toolID, resultContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add as a single user message with all the content
|
||||
if textContent == "" {
|
||||
textContent = "..."
|
||||
}
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": strings.TrimSpace(textContent),
|
||||
})
|
||||
} else {
|
||||
// Handle regular messages with content blocks
|
||||
content := ""
|
||||
|
||||
for _, item := range contentArray {
|
||||
if block, ok := item.(map[string]interface{}); ok {
|
||||
if blockType, hasType := block["type"].(string); hasType && blockType == "text" {
|
||||
if text, hasText := block["text"].(string); hasText {
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure content is never empty
|
||||
if content == "" {
|
||||
content = "..."
|
||||
}
|
||||
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Handle simple string content
|
||||
contentBlocks := msg.GetContentBlocks()
|
||||
content := ""
|
||||
|
||||
// Concatenate all text blocks
|
||||
for _, block := range contentBlocks {
|
||||
if block.Type == "text" {
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += block.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure content is never empty
|
||||
if content == "" {
|
||||
content = "..."
|
||||
}
|
||||
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
}
|
||||
// Check if max_tokens exceeds the model's limit and cap it if necessary
|
||||
maxTokensLimit := 16384 // Assuming this is the limit for the model
|
||||
if req.MaxTokens > maxTokensLimit {
|
||||
fmt.Printf("Warning: max_tokens is too large: %d. Capping to %d.\n", req.MaxTokens, maxTokensLimit)
|
||||
req.MaxTokens = maxTokensLimit
|
||||
}
|
||||
|
||||
// All OpenAI models now use max_completion_tokens instead of deprecated max_tokens
|
||||
openAIReq := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
"messages": messages,
|
||||
"temperature": req.Temperature,
|
||||
"max_tokens": req.MaxTokens,
|
||||
"stream": req.Stream,
|
||||
"model": req.Model,
|
||||
"messages": messages,
|
||||
"stream": req.Stream,
|
||||
"max_completion_tokens": req.MaxTokens,
|
||||
}
|
||||
|
||||
// If streaming is enabled, request usage data to be included in the final chunk
|
||||
if req.Stream {
|
||||
openAIReq["stream_options"] = map[string]interface{}{
|
||||
"include_usage": true,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is an o-series model (they don't support temperature)
|
||||
isOSeriesModel := strings.HasPrefix(req.Model, "o1") || strings.HasPrefix(req.Model, "o3")
|
||||
|
||||
// Only include temperature for non-o-series models
|
||||
if !isOSeriesModel {
|
||||
openAIReq["temperature"] = req.Temperature
|
||||
}
|
||||
|
||||
fmt.Printf("Using max_completion_tokens=%d for model %s\n", req.MaxTokens, req.Model)
|
||||
|
||||
// Convert Anthropic tools to OpenAI format
|
||||
if len(req.Tools) > 0 {
|
||||
tools := make([]map[string]interface{}, 0, len(req.Tools))
|
||||
for _, tool := range req.Tools {
|
||||
// Ensure tool has required fields
|
||||
if tool.Name == "" {
|
||||
fmt.Printf("Warning: Skipping tool with empty name\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// Build parameters with error checking
|
||||
parameters := make(map[string]interface{})
|
||||
parameters["type"] = tool.InputSchema.Type
|
||||
if parameters["type"] == "" {
|
||||
parameters["type"] = "object" // Default to object type
|
||||
}
|
||||
|
||||
// Handle properties safely with array validation
|
||||
if tool.InputSchema.Properties != nil {
|
||||
// Fix array properties that are missing items field
|
||||
fixedProperties := make(map[string]interface{})
|
||||
for propName, propValue := range tool.InputSchema.Properties {
|
||||
if prop, ok := propValue.(map[string]interface{}); ok {
|
||||
// Check if this is an array type missing items
|
||||
if propType, hasType := prop["type"]; hasType && propType == "array" {
|
||||
if _, hasItems := prop["items"]; !hasItems {
|
||||
// Add default items definition for arrays
|
||||
fmt.Printf("Warning: Array property '%s' in tool '%s' missing items - adding default\n", propName, tool.Name)
|
||||
prop["items"] = map[string]interface{}{"type": "string"}
|
||||
}
|
||||
}
|
||||
fixedProperties[propName] = prop
|
||||
} else {
|
||||
// Keep non-map properties as-is
|
||||
fixedProperties[propName] = propValue
|
||||
}
|
||||
}
|
||||
parameters["properties"] = fixedProperties
|
||||
} else {
|
||||
parameters["properties"] = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Handle required fields
|
||||
if len(tool.InputSchema.Required) > 0 {
|
||||
parameters["required"] = tool.InputSchema.Required
|
||||
}
|
||||
|
||||
// Build function definition
|
||||
functionDef := map[string]interface{}{
|
||||
"name": tool.Name,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
// Add description if present
|
||||
if tool.Description != "" {
|
||||
functionDef["description"] = tool.Description
|
||||
}
|
||||
|
||||
openAITool := map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": functionDef,
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
}
|
||||
openAIReq["tools"] = tools
|
||||
|
||||
// Handle tool_choice if present
|
||||
if req.ToolChoice != nil {
|
||||
// Convert Anthropic tool_choice to OpenAI format
|
||||
if toolChoiceMap, ok := req.ToolChoice.(map[string]interface{}); ok {
|
||||
choiceType := toolChoiceMap["type"]
|
||||
switch choiceType {
|
||||
case "auto":
|
||||
openAIReq["tool_choice"] = "auto"
|
||||
case "any":
|
||||
openAIReq["tool_choice"] = "required"
|
||||
case "tool":
|
||||
// Specific tool choice
|
||||
if name, hasName := toolChoiceMap["name"].(string); hasName {
|
||||
openAIReq["tool_choice"] = map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": name,
|
||||
},
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Default to auto if we can't determine
|
||||
openAIReq["tool_choice"] = "auto"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return openAIReq
|
||||
}
|
||||
|
||||
func getMapKeys(m map[string]interface{}) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
||||
// This is a simplified transformation
|
||||
// In production, you'd want to handle all fields properly
|
||||
|
|
@ -166,25 +485,97 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
|||
}
|
||||
|
||||
// Extract the assistant's message
|
||||
content := ""
|
||||
var contentBlocks []map[string]interface{}
|
||||
|
||||
if choices, ok := openAIResp["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if msg, ok := choice["message"].(map[string]interface{}); ok {
|
||||
if c, ok := msg["content"].(string); ok {
|
||||
content = c
|
||||
// Handle regular text content
|
||||
if content, ok := msg["content"].(string); ok && content != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok {
|
||||
// Since this proxy forwards to Claude/Anthropic API, we should always
|
||||
// use tool_use blocks so Claude can execute the tools properly
|
||||
// (regardless of which model generated the response)
|
||||
for _, tc := range toolCalls {
|
||||
if toolCall, ok := tc.(map[string]interface{}); ok {
|
||||
if function, ok := toolCall["function"].(map[string]interface{}); ok {
|
||||
// Convert OpenAI tool call to Anthropic tool_use format
|
||||
anthropicToolUse := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolCall["id"],
|
||||
"name": function["name"],
|
||||
}
|
||||
|
||||
// Parse the arguments JSON string
|
||||
if argsStr, ok := function["arguments"].(string); ok {
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
|
||||
anthropicToolUse["input"] = args
|
||||
} else {
|
||||
// If parsing fails, wrap in a raw field like Python does
|
||||
fmt.Printf("Warning: Failed to parse tool arguments as JSON: %v\n", err)
|
||||
anthropicToolUse["input"] = map[string]interface{}{"raw": argsStr}
|
||||
}
|
||||
} else if args, ok := function["arguments"].(map[string]interface{}); ok {
|
||||
// Already a map, use directly
|
||||
anthropicToolUse["input"] = args
|
||||
} else {
|
||||
// Fallback for any other type
|
||||
anthropicToolUse["input"] = map[string]interface{}{"raw": fmt.Sprintf("%v", function["arguments"])}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, anthropicToolUse)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no content blocks were created, add a default empty text block
|
||||
if len(contentBlocks) == 0 {
|
||||
contentBlocks = []map[string]interface{}{
|
||||
{"type": "text", "text": ""},
|
||||
}
|
||||
}
|
||||
|
||||
// Build Anthropic-style response
|
||||
anthropicResp := map[string]interface{}{
|
||||
"id": openAIResp["id"],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]string{{"type": "text", "text": content}},
|
||||
"content": contentBlocks,
|
||||
"model": openAIResp["model"],
|
||||
"usage": openAIResp["usage"],
|
||||
}
|
||||
|
||||
// Convert OpenAI usage format to Anthropic format
|
||||
if usage, ok := openAIResp["usage"].(map[string]interface{}); ok {
|
||||
anthropicUsage := map[string]interface{}{}
|
||||
|
||||
// Map prompt_tokens to input_tokens
|
||||
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
|
||||
anthropicUsage["input_tokens"] = int(promptTokens)
|
||||
}
|
||||
|
||||
// Map completion_tokens to output_tokens
|
||||
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
|
||||
anthropicUsage["output_tokens"] = int(completionTokens)
|
||||
}
|
||||
|
||||
// Include total_tokens if needed (though Anthropic format doesn't typically use it)
|
||||
if totalTokens, ok := usage["total_tokens"].(float64); ok {
|
||||
anthropicUsage["total_tokens"] = int(totalTokens)
|
||||
}
|
||||
|
||||
anthropicResp["usage"] = anthropicUsage
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(anthropicResp)
|
||||
|
|
@ -194,7 +585,154 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
|||
func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
|
||||
defer openAIStream.Close()
|
||||
|
||||
// This is a placeholder - in production you'd parse SSE events
|
||||
// and transform them from OpenAI format to Anthropic format
|
||||
io.Copy(anthropicStream, openAIStream)
|
||||
scanner := bufio.NewScanner(openAIStream)
|
||||
var messageStarted bool
|
||||
var contentStarted bool
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Skip empty lines
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle SSE data lines
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// Handle end of stream
|
||||
if data == "[DONE]" {
|
||||
// Send Anthropic-style completion
|
||||
if contentStarted {
|
||||
fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
||||
}
|
||||
if messageStarted {
|
||||
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n")
|
||||
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Parse OpenAI response
|
||||
var openAIChunk map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Debug: Check if this is the final chunk
|
||||
if choices, ok := openAIChunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if finishReason, ok := choice["finish_reason"]; ok && finishReason != nil {
|
||||
fmt.Printf("🏁 Final chunk detected with finish_reason: %v\n", finishReason)
|
||||
fmt.Printf("🏁 Full final chunk: %+v\n", openAIChunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for usage data BEFORE processing choices
|
||||
// According to OpenAI docs, usage is sent in the final chunk with empty choices array
|
||||
if usage, hasUsage := openAIChunk["usage"].(map[string]interface{}); hasUsage {
|
||||
fmt.Printf("🔍 Found usage data in OpenAI stream: %+v\n", usage)
|
||||
fmt.Printf("🔍 Full OpenAI chunk with usage: %+v\n", openAIChunk)
|
||||
|
||||
// Convert OpenAI usage to Anthropic format
|
||||
anthropicUsage := map[string]interface{}{}
|
||||
|
||||
// Handle both float64 and int types
|
||||
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
|
||||
anthropicUsage["input_tokens"] = int(promptTokens)
|
||||
} else if promptTokens, ok := usage["prompt_tokens"].(int); ok {
|
||||
anthropicUsage["input_tokens"] = promptTokens
|
||||
}
|
||||
|
||||
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
|
||||
anthropicUsage["output_tokens"] = int(completionTokens)
|
||||
} else if completionTokens, ok := usage["completion_tokens"].(int); ok {
|
||||
anthropicUsage["output_tokens"] = completionTokens
|
||||
}
|
||||
|
||||
if len(anthropicUsage) > 0 {
|
||||
// Send usage data in a message_delta event
|
||||
usageDelta := map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{},
|
||||
"usage": anthropicUsage,
|
||||
}
|
||||
usageJSON, _ := json.Marshal(usageDelta)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract choices array
|
||||
choices, ok := openAIChunk["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
// Skip further processing if no choices, but we already handled usage above
|
||||
continue
|
||||
}
|
||||
|
||||
choice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delta, ok := choice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle first chunk - send message_start
|
||||
if !messageStarted {
|
||||
messageStarted = true
|
||||
messageStart := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": openAIChunk["id"],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": openAIChunk["model"],
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
// Empty usage - will be updated in final chunk
|
||||
},
|
||||
},
|
||||
}
|
||||
startJSON, _ := json.Marshal(messageStart)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON)
|
||||
}
|
||||
|
||||
// Handle content
|
||||
if content, hasContent := delta["content"].(string); hasContent && content != "" {
|
||||
if !contentStarted {
|
||||
contentStarted = true
|
||||
// Send content_block_start
|
||||
blockStart := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
}
|
||||
blockStartJSON, _ := json.Marshal(blockStart)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON)
|
||||
}
|
||||
|
||||
// Send content_block_delta
|
||||
contentDelta := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": content,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(contentDelta)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ func (r *ModelRouter) extractStaticPrompt(systemPrompt string) string {
|
|||
}
|
||||
|
||||
func (r *ModelRouter) loadCustomAgents() {
|
||||
r.logger.Printf("Loading custom agents from mappings: %+v", r.subagentMappings)
|
||||
|
||||
for agentName, targetModel := range r.subagentMappings {
|
||||
// Try loading from project level first, then user level
|
||||
paths := []string{
|
||||
|
|
@ -67,20 +69,28 @@ func (r *ModelRouter) loadCustomAgents() {
|
|||
}
|
||||
|
||||
for _, path := range paths {
|
||||
r.logger.Printf("Trying to load agent from: %s", path)
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
r.logger.Printf("Failed to read %s: %v", path, err)
|
||||
continue
|
||||
}
|
||||
|
||||
r.logger.Printf("Successfully read agent file: %s (size: %d bytes)", path, len(content))
|
||||
|
||||
// Parse agent file: metadata\n---\nsystem prompt
|
||||
parts := strings.Split(string(content), "\n---\n")
|
||||
r.logger.Printf("Agent file parts: %d", len(parts))
|
||||
if len(parts) >= 2 {
|
||||
systemPrompt := strings.TrimSpace(parts[1])
|
||||
r.logger.Printf("System prompt (first 200 chars): %.200s", systemPrompt)
|
||||
|
||||
// Extract only the static part (before "Notes:" if it exists)
|
||||
staticPrompt := r.extractStaticPrompt(systemPrompt)
|
||||
hash := r.hashString(staticPrompt)
|
||||
|
||||
r.logger.Printf("Static prompt after extraction (first 200 chars): %.200s", staticPrompt)
|
||||
|
||||
// Determine provider for the target model
|
||||
providerName := r.getProviderNameForModel(targetModel)
|
||||
|
||||
|
|
@ -91,20 +101,35 @@ func (r *ModelRouter) loadCustomAgents() {
|
|||
FullPrompt: staticPrompt,
|
||||
}
|
||||
|
||||
r.logger.Printf("Loaded custom agent: %s (hash: %s) -> %s",
|
||||
agentName, hash, targetModel)
|
||||
r.logger.Printf("Loaded custom agent: %s (hash: %s) -> %s (provider: %s)",
|
||||
agentName, hash, targetModel, providerName)
|
||||
break
|
||||
} else {
|
||||
r.logger.Printf("Invalid agent file format for %s: expected at least 2 parts separated by ---", agentName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Printf("Total custom agents loaded: %d", len(r.customAgentPrompts))
|
||||
}
|
||||
|
||||
// RouteRequest determines which provider and model to use for a request
|
||||
func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provider, string, error) {
|
||||
originalModel := req.Model
|
||||
|
||||
r.logger.Printf("RouteRequest: Model=%s, System messages count=%d", originalModel, len(req.System))
|
||||
|
||||
// Debug: Print loaded custom agents
|
||||
r.logger.Printf("Loaded custom agents: %d", len(r.customAgentPrompts))
|
||||
for hash, def := range r.customAgentPrompts {
|
||||
r.logger.Printf(" Agent: %s (hash: %s) -> %s", def.Name, hash, def.TargetModel)
|
||||
}
|
||||
|
||||
// Claude Code pattern: Check if we have exactly 2 system messages
|
||||
if len(req.System) == 2 {
|
||||
r.logger.Printf("System[0]: %.100s...", req.System[0].Text)
|
||||
r.logger.Printf("System[1]: %.100s...", req.System[1].Text)
|
||||
|
||||
// First should be "You are Claude Code..."
|
||||
if strings.Contains(req.System[0].Text, "You are Claude Code") {
|
||||
// Second message could be either:
|
||||
|
|
@ -117,6 +142,9 @@ func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provid
|
|||
staticPrompt := r.extractStaticPrompt(fullPrompt)
|
||||
promptHash := r.hashString(staticPrompt)
|
||||
|
||||
r.logger.Printf("Static prompt hash: %s", promptHash)
|
||||
r.logger.Printf("Static prompt (first 200 chars): %.200s", staticPrompt)
|
||||
|
||||
// Check if this matches a known custom agent
|
||||
if definition, exists := r.customAgentPrompts[promptHash]; exists {
|
||||
r.logger.Printf("Subagent '%s' detected -> routing to %s",
|
||||
|
|
@ -133,7 +161,7 @@ func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provid
|
|||
}
|
||||
|
||||
// This is a regular Claude Code request (not a known subagent)
|
||||
r.logger.Printf("Regular Claude Code request detected, using original model %s", originalModel)
|
||||
r.logger.Printf("No matching subagent found for hash %s, using original model %s", promptHash, originalModel)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -150,14 +178,17 @@ func (r *ModelRouter) RouteRequest(req *model.AnthropicRequest) (provider.Provid
|
|||
func (r *ModelRouter) hashString(s string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(s))
|
||||
return hex.EncodeToString(h.Sum(nil))[:16]
|
||||
fullHash := hex.EncodeToString(h.Sum(nil))
|
||||
shortHash := fullHash[:16]
|
||||
r.logger.Printf("Hashing string (length: %d) -> %s", len(s), shortHash)
|
||||
return shortHash
|
||||
}
|
||||
|
||||
func (r *ModelRouter) getProviderNameForModel(model string) string {
|
||||
// Map models to providers
|
||||
if strings.HasPrefix(model, "claude") {
|
||||
return "anthropic"
|
||||
} else if strings.HasPrefix(model, "gpt") {
|
||||
} else if strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o") {
|
||||
return "openai"
|
||||
}
|
||||
// Default to anthropic
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ func (s *sqliteStorageService) createTables() error {
|
|||
prompt_grade TEXT,
|
||||
response TEXT,
|
||||
model TEXT,
|
||||
original_model TEXT,
|
||||
routed_model TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
|
|
@ -74,8 +76,8 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
|||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err = s.db.Exec(query,
|
||||
|
|
@ -88,6 +90,8 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
|||
request.UserAgent,
|
||||
request.ContentType,
|
||||
request.Model,
|
||||
request.OriginalModel,
|
||||
request.RoutedModel,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -108,7 +112,7 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
|
|||
// Get paginated results
|
||||
offset := (page - 1) * limit
|
||||
query := `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ? OFFSET ?
|
||||
|
|
@ -138,6 +142,8 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
|
|||
&req.ContentType,
|
||||
&promptGradeJSON,
|
||||
&responseJSON,
|
||||
&req.OriginalModel,
|
||||
&req.RoutedModel,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Error scanning row: %v", err)
|
||||
|
|
@ -228,7 +234,7 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error {
|
|||
|
||||
func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
|
||||
query := `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
WHERE id LIKE ?
|
||||
ORDER BY timestamp DESC
|
||||
|
|
@ -251,6 +257,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
|
|||
&req.ContentType,
|
||||
&promptGradeJSON,
|
||||
&responseJSON,
|
||||
&req.OriginalModel,
|
||||
&req.RoutedModel,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -294,7 +302,7 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig {
|
|||
|
||||
func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
|
||||
query := `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
|
@ -331,13 +339,15 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
|
|||
&req.ContentType,
|
||||
&promptGradeJSON,
|
||||
&responseJSON,
|
||||
&req.OriginalModel,
|
||||
&req.RoutedModel,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Error scanning row: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("🔍 Scanned request - ID: %s, Model: %s", req.RequestID, req.Model)
|
||||
// log.Printf("🔍 Scanned request - ID: %s, Model: %s", req.RequestID, req.Model)
|
||||
|
||||
// Unmarshal JSON fields
|
||||
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue