From 6cda36312aa9062294b28256d2068a468b225b3d Mon Sep 17 00:00:00 2001 From: sid Date: Thu, 19 Mar 2026 18:52:09 -0600 Subject: [PATCH] Harden streaming, pagination, and config loading --- .env.example | 13 +- config.yaml.example | 67 +- proxy/cmd/proxy/main.go | 6 +- proxy/internal/config/config.go | 111 +++- proxy/internal/config/config_test.go | 30 + proxy/internal/handler/handlers.go | 77 +-- proxy/internal/handler/utils.go | 62 ++ proxy/internal/model/models.go | 1 + proxy/internal/provider/anthropic.go | 10 +- proxy/internal/provider/openai.go | 144 ++++- proxy/internal/provider/openai_test.go | 63 ++ proxy/internal/service/storage.go | 23 +- proxy/internal/service/storage_sqlite.go | 586 ++++++++++++++---- proxy/internal/service/storage_sqlite_test.go | 69 +++ proxy/internal/sse/sse.go | 30 + proxy/internal/sse/sse_test.go | 31 + 16 files changed, 1079 insertions(+), 244 deletions(-) create mode 100644 proxy/internal/config/config_test.go create mode 100644 proxy/internal/provider/openai_test.go create mode 100644 proxy/internal/service/storage_sqlite_test.go create mode 100644 proxy/internal/sse/sse.go create mode 100644 proxy/internal/sse/sse_test.go diff --git a/.env.example b/.env.example index 1fd02fb..eddb829 100644 --- a/.env.example +++ b/.env.example @@ -12,5 +12,16 @@ ANTHROPIC_FORWARD_URL=https://api.anthropic.com ANTHROPIC_VERSION=2023-06-01 ANTHROPIC_MAX_RETRIES=3 +# OpenAI Configuration (for subagent routing) +# OPENAI_API_KEY=your-openai-api-key +# OPENAI_BASE_URL=https://api.openai.com +# OPENAI_ALLOW_CLIENT_API_KEY=false +# OPENAI_CLIENT_API_KEY_HEADER=x-openai-api-key + # Storage Configuration -DATABASE_PATH=requests.db \ No newline at end of file +DB_PATH=requests.db + +# CORS Configuration (comma-separated values) +# CORS_ALLOWED_ORIGINS=* +# CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS +# CORS_ALLOWED_HEADERS=* \ No newline at end of file diff --git a/config.yaml.example b/config.yaml.example index aecc68a..34338a8 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -24,20 +24,50 @@ providers: anthropic: # Base URL for Anthropic API (can be changed for custom endpoints) base_url: "https://api.anthropic.com" - + # Maximum number of retries for failed requests max_retries: 3 - + # OpenAI configuration openai: # API key for OpenAI # Can also be set via OPENAI_API_KEY environment variable # api_key: "..." - + # Base URL for OpenAI API (can be changed for custom endpoints) # Can also be set via OPENAI_BASE_URL environment variable # base_url: "https://api.openai.com" + # Allow clients to provide their own API key via header + # Can also be set via OPENAI_ALLOW_CLIENT_API_KEY environment variable + allow_client_api_key: false + + # Header name for client-provided API key (default: x-openai-api-key) + # Can also be set via OPENAI_CLIENT_API_KEY_HEADER environment variable + client_api_key_header: "x-openai-api-key" + +# CORS Configuration +# Controls Cross-Origin Resource Sharing for the web UI +cors: + # Allowed origins (use ["*"] for all origins - not recommended for production) + # Can also be set via CORS_ALLOWED_ORIGINS environment variable (comma-separated) + allowed_origins: + - "*" + + # Allowed HTTP methods + # Can also be set via CORS_ALLOWED_METHODS environment variable (comma-separated) + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + - "OPTIONS" + + # Allowed headers (use ["*"] for all headers) + # Can also be set via CORS_ALLOWED_HEADERS environment variable (comma-separated) + allowed_headers: + - "*" + # Storage configuration storage: # SQLite database path for storing request history @@ -69,23 +99,30 @@ subagents: # The following environment variables will override the YAML configuration: # # Server: -# PORT - Server port -# READ_TIMEOUT - Read timeout duration -# WRITE_TIMEOUT - Write timeout duration -# IDLE_TIMEOUT - Idle timeout duration +# PORT - Server port +# READ_TIMEOUT - Read timeout duration +# WRITE_TIMEOUT - Write timeout duration +# IDLE_TIMEOUT - Idle timeout duration # # Anthropic: -# ANTHROPIC_FORWARD_URL - Anthropic base URL -# ANTHROPIC_VERSION - Anthropic API version -# ANTHROPIC_MAX_RETRIES - Maximum retries for Anthropic requests +# ANTHROPIC_FORWARD_URL - Anthropic base URL +# ANTHROPIC_VERSION - Anthropic API version +# ANTHROPIC_MAX_RETRIES - Maximum retries for Anthropic requests # # OpenAI: -# OPENAI_API_KEY - OpenAI API key -# OPENAI_BASE_URL - OpenAI base URL +# OPENAI_API_KEY - OpenAI API key +# OPENAI_BASE_URL - OpenAI base URL +# OPENAI_ALLOW_CLIENT_API_KEY - Allow client-provided API keys (true/false) +# OPENAI_CLIENT_API_KEY_HEADER - Header name for client API key # # Storage: -# DB_PATH - Database file path +# DB_PATH - Database file path +# +# CORS: +# CORS_ALLOWED_ORIGINS - Comma-separated allowed origins +# CORS_ALLOWED_METHODS - Comma-separated allowed methods +# CORS_ALLOWED_HEADERS - Comma-separated allowed headers # # Subagents: -# SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs -# Example: "code-reviewer:claude-3-5-sonnet" \ No newline at end of file +# SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs +# Example: "code-reviewer:claude-3-5-sonnet" \ No newline at end of file diff --git a/proxy/cmd/proxy/main.go b/proxy/cmd/proxy/main.go index 8623203..d5ef215 100644 --- a/proxy/cmd/proxy/main.go +++ b/proxy/cmd/proxy/main.go @@ -50,9 +50,9 @@ func main() { r := mux.NewRouter() corsHandler := handlers.CORS( - handlers.AllowedOrigins([]string{"*"}), - handlers.AllowedMethods([]string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}), - handlers.AllowedHeaders([]string{"*"}), + handlers.AllowedOrigins(cfg.CORS.AllowedOrigins), + handlers.AllowedMethods(cfg.CORS.AllowedMethods), + handlers.AllowedHeaders(cfg.CORS.AllowedHeaders), ) r.Use(middleware.Logging) diff --git a/proxy/internal/config/config.go b/proxy/internal/config/config.go index eb3e6bb..1e5b110 100644 --- a/proxy/internal/config/config.go +++ b/proxy/internal/config/config.go @@ -1,9 +1,11 @@ package config import ( + "fmt" "os" "path/filepath" "strconv" + "strings" "time" "github.com/joho/godotenv" @@ -15,9 +17,16 @@ type Config struct { Providers ProvidersConfig `yaml:"providers"` Storage StorageConfig `yaml:"storage"` Subagents SubagentsConfig `yaml:"subagents"` + CORS CORSConfig `yaml:"cors"` Anthropic AnthropicConfig } +type CORSConfig struct { + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` +} + type ServerConfig struct { Port string `yaml:"port"` Timeouts TimeoutsConfig `yaml:"timeouts"` @@ -45,8 +54,10 @@ type AnthropicProviderConfig struct { } type OpenAIProviderConfig struct { - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + AllowClientAPIKey bool `yaml:"allow_client_api_key"` // Allow clients to provide their own API key + ClientAPIKeyHeader string `yaml:"client_api_key_header"` // Header name for client API key (default: x-openai-api-key) } type AnthropicConfig struct { @@ -92,8 +103,10 @@ func Load() (*Config, error) { MaxRetries: 3, }, OpenAI: OpenAIProviderConfig{ - BaseURL: "https://api.openai.com", - APIKey: "", + BaseURL: "https://api.openai.com", + APIKey: "", + AllowClientAPIKey: false, + ClientAPIKeyHeader: "x-openai-api-key", }, }, Storage: StorageConfig{ @@ -103,25 +116,17 @@ func Load() (*Config, error) { Enable: false, Mappings: make(map[string]string), }, + CORS: CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"*"}, + }, } - // Try to load config.yaml from the project root - // The proxy binary is in proxy/ directory, config.yaml is in the parent - configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml") - - // If that doesn't work, try relative to current directory - if _, err := os.Stat(configPath); err != nil { - // Try common locations relative to where the binary might be run - for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} { - if _, err := os.Stat(tryPath); err == nil { - configPath = tryPath - break - } - } + if err := loadFirstAvailableConfig(cfg, candidateConfigPaths()); err != nil { + return nil, err } - cfg.loadFromFile(configPath) - // Apply environment variable overrides AFTER loading from file if envPort := os.Getenv("PORT"); envPort != "" { cfg.Server.Port = envPort @@ -154,12 +159,29 @@ func Load() (*Config, error) { if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" { cfg.Providers.OpenAI.APIKey = envKey } + if envAllow := os.Getenv("OPENAI_ALLOW_CLIENT_API_KEY"); envAllow != "" { + cfg.Providers.OpenAI.AllowClientAPIKey = envAllow == "true" || envAllow == "1" + } + if envHeader := os.Getenv("OPENAI_CLIENT_API_KEY_HEADER"); envHeader != "" { + cfg.Providers.OpenAI.ClientAPIKeyHeader = envHeader + } // Override storage settings if envPath := os.Getenv("DB_PATH"); envPath != "" { cfg.Storage.DBPath = envPath } + // Override CORS settings (comma-separated values) + if envOrigins := os.Getenv("CORS_ALLOWED_ORIGINS"); envOrigins != "" { + cfg.CORS.AllowedOrigins = splitAndTrim(envOrigins) + } + if envMethods := os.Getenv("CORS_ALLOWED_METHODS"); envMethods != "" { + cfg.CORS.AllowedMethods = splitAndTrim(envMethods) + } + if envHeaders := os.Getenv("CORS_ALLOWED_HEADERS"); envHeaders != "" { + cfg.CORS.AllowedHeaders = splitAndTrim(envHeaders) + } + // Sync legacy Anthropic config cfg.Anthropic = AnthropicConfig{ BaseURL: cfg.Providers.Anthropic.BaseURL, @@ -203,6 +225,45 @@ func (c *Config) loadFromFile(path string) error { return yaml.Unmarshal(data, c) } +func candidateConfigPaths() []string { + paths := []string{ + filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml"), + "config.yaml", + "../config.yaml", + "../../config.yaml", + } + + seen := make(map[string]struct{}, len(paths)) + unique := make([]string, 0, len(paths)) + for _, path := range paths { + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + unique = append(unique, path) + } + + return unique +} + +func loadFirstAvailableConfig(cfg *Config, paths []string) error { + for _, path := range paths { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + continue + } + return fmt.Errorf("failed to stat config file %q: %w", path, err) + } + + if err := cfg.loadFromFile(path); err != nil { + return fmt.Errorf("failed to load config file %q: %w", path, err) + } + return nil + } + + return nil +} + func getEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value @@ -237,3 +298,15 @@ func getInt(key string, defaultValue int) int { return intValue } + +func splitAndTrim(s string) []string { + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/proxy/internal/config/config_test.go b/proxy/internal/config/config_test.go new file mode 100644 index 0000000..afd0a8f --- /dev/null +++ b/proxy/internal/config/config_test.go @@ -0,0 +1,30 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadFirstAvailableConfigReturnsParseError(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := os.WriteFile(configPath, []byte("server: ["), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + cfg := &Config{} + if err := loadFirstAvailableConfig(cfg, []string{configPath}); err == nil { + t.Fatal("expected parse error, got nil") + } +} + +func TestLoadFirstAvailableConfigSkipsMissingFiles(t *testing.T) { + cfg := &Config{} + if err := loadFirstAvailableConfig(cfg, []string{ + filepath.Join(t.TempDir(), "missing.yaml"), + }); err != nil { + t.Fatalf("expected nil error for missing config, got %v", err) + } +} diff --git a/proxy/internal/handler/handlers.go b/proxy/internal/handler/handlers.go index 05d1360..3f3c6f6 100644 --- a/proxy/internal/handler/handlers.go +++ b/proxy/internal/handler/handlers.go @@ -1,7 +1,6 @@ package handler import ( - "bufio" "bytes" "crypto/rand" "encoding/hex" @@ -20,6 +19,7 @@ import ( "github.com/seifghazi/claude-code-monitor/internal/model" "github.com/seifghazi/claude-code-monitor/internal/service" + "github.com/seifghazi/claude-code-monitor/internal/sse" ) type Handler struct { @@ -179,37 +179,13 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) { modelFilter = "all" } - // Get all requests with model filter applied at storage level - allRequests, err := h.storageService.GetAllRequests(modelFilter) + requests, total, err := h.storageService.GetRequests(page, limit, modelFilter) if err != nil { log.Printf("Error getting requests: %v", err) http.Error(w, "Failed to get requests", http.StatusInternalServerError) return } - // Convert pointers to values for consistency - requests := make([]model.RequestLog, len(allRequests)) - for i, req := range allRequests { - if req != nil { - requests[i] = *req - } - } - - // Calculate total before pagination - total := len(requests) - - // Apply pagination - start := (page - 1) * limit - end := start + limit - if start >= len(requests) { - requests = []model.RequestLog{} - } else { - if end > len(requests) { - end = len(requests) - } - requests = requests[start:end] - } - w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(struct { Requests []model.RequestLog `json:"requests"` @@ -242,6 +218,8 @@ func (h *Handler) NotFound(w http.ResponseWriter, r *http.Request) { } func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) { + // Forward important upstream headers (rate limits, request IDs, etc.) + ForwardResponseHeaders(w, resp) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -278,16 +256,17 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp var messageID string var modelName string var stopReason string + var sawMessageStop bool - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() + streamErr := sse.ForEachLine(resp.Body, func(line string) error { if line == "" || !strings.HasPrefix(line, "data:") { - continue + return nil } streamingChunks = append(streamingChunks, line) - fmt.Fprintf(w, "%s\n\n", line) + if _, err := fmt.Fprintf(w, "%s\n\n", line); err != nil { + return err + } if f, ok := w.(http.Flusher); ok { f.Flush() } @@ -298,7 +277,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp var genericEvent map[string]interface{} if err := json.Unmarshal([]byte(jsonData), &genericEvent); err != nil { log.Printf("⚠️ Error unmarshalling streaming event: %v", err) - continue + return nil } // Capture metadata from message_start event @@ -347,7 +326,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp var event model.StreamingEvent if err := json.Unmarshal([]byte(jsonData), &event); err != nil { // Skip if structured parsing fails, but we already got the usage data above - continue + return nil } switch event.Type { @@ -366,8 +345,13 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp toolCalls = append(toolCalls, *event.ContentBlock) } case "message_stop": - // End of stream - scanner will exit on its own + sawMessageStop = true } + return nil + }) + + if streamErr == nil && !sawMessageStop { + streamErr = io.ErrUnexpectedEOF } responseLog := &model.ResponseLog{ @@ -378,6 +362,9 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp IsStreaming: true, CompletedAt: time.Now().Format(time.RFC3339), } + if streamErr != nil { + responseLog.StreamError = streamErr.Error() + } // Create a structured response body that matches Anthropic's format var contentBlocks []model.AnthropicContentBlock @@ -417,14 +404,31 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp log.Printf("❌ Error updating request with streaming response: %v", err) } - if err := scanner.Err(); err != nil { - log.Printf("❌ Streaming error: %v", err) + if streamErr != nil { + log.Printf("❌ Streaming error: %v", streamErr) + // Send error event to client in Anthropic streaming format + errorEvent := map[string]interface{}{ + "type": "error", + "error": map[string]interface{}{ + "type": "stream_error", + "message": fmt.Sprintf("Stream interrupted: %v", streamErr), + }, + } + if errorJSON, jsonErr := json.Marshal(errorEvent); jsonErr == nil { + fmt.Fprintf(w, "data: %s\n\n", errorJSON) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } } else { log.Println("✅ Streaming response completed") } } func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) { + // Forward important upstream headers (rate limits, request IDs, etc.) + ForwardResponseHeaders(w, resp) + responseBytes, err := io.ReadAll(resp.Body) if err != nil { log.Printf("❌ Error reading Anthropic response: %v", err) @@ -464,6 +468,7 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.R if resp.StatusCode != http.StatusOK { log.Printf("❌ Anthropic API error: %d %s", resp.StatusCode, string(responseBytes)) + // Headers already forwarded at start of function w.Header().Set("Content-Type", "application/json") w.WriteHeader(resp.StatusCode) w.Write(responseBytes) diff --git a/proxy/internal/handler/utils.go b/proxy/internal/handler/utils.go index cf88be5..571f9f1 100644 --- a/proxy/internal/handler/utils.go +++ b/proxy/internal/handler/utils.go @@ -11,6 +11,68 @@ import ( "github.com/seifghazi/claude-code-monitor/internal/model" ) +// Headers that should be forwarded from upstream responses to clients +var forwardableResponseHeaders = []string{ + // Anthropic rate limit headers + "anthropic-ratelimit-requests-limit", + "anthropic-ratelimit-requests-remaining", + "anthropic-ratelimit-requests-reset", + "anthropic-ratelimit-tokens-limit", + "anthropic-ratelimit-tokens-remaining", + "anthropic-ratelimit-tokens-reset", + // Standard rate limit headers + "x-ratelimit-limit", + "x-ratelimit-remaining", + "x-ratelimit-reset", + "retry-after", + // Request tracking + "x-request-id", + "request-id", + // Anthropic specific + "anthropic-organization-id", + // OpenAI specific + "openai-organization", + "openai-processing-ms", + "openai-version", + "x-request-id", +} + +// ForwardResponseHeaders copies important headers from upstream response to client response +func ForwardResponseHeaders(w http.ResponseWriter, resp *http.Response) { + for _, header := range forwardableResponseHeaders { + if values := resp.Header.Values(header); len(values) > 0 { + for _, value := range values { + w.Header().Add(header, value) + } + } + } +} + +// CopyAllResponseHeaders copies all non-hop-by-hop headers from upstream to client +func CopyAllResponseHeaders(w http.ResponseWriter, resp *http.Response) { + hopByHopHeaders := map[string]bool{ + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "te": true, + "trailers": true, + "transfer-encoding": true, + "upgrade": true, + "content-encoding": true, // We handle decompression ourselves + "content-length": true, // May change after decompression + } + + for key, values := range resp.Header { + if hopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, value := range values { + w.Header().Add(key, value) + } + } +} + // SanitizeHeaders removes sensitive headers before logging/storage func SanitizeHeaders(headers http.Header) http.Header { sanitized := make(http.Header) diff --git a/proxy/internal/model/models.go b/proxy/internal/model/models.go index a5d03c0..1e5f5ef 100644 --- a/proxy/internal/model/models.go +++ b/proxy/internal/model/models.go @@ -45,6 +45,7 @@ type ResponseLog struct { Headers map[string][]string `json:"headers"` Body json.RawMessage `json:"body,omitempty"` BodyText string `json:"bodyText,omitempty"` + StreamError string `json:"streamError,omitempty"` ResponseTime int64 `json:"responseTime"` StreamingChunks []string `json:"streamingChunks,omitempty"` IsStreaming bool `json:"isStreaming"` diff --git a/proxy/internal/provider/anthropic.go b/proxy/internal/provider/anthropic.go index 64b3039..9c39389 100644 --- a/proxy/internal/provider/anthropic.go +++ b/proxy/internal/provider/anthropic.go @@ -66,8 +66,14 @@ func (p *AnthropicProvider) ForwardRequest(ctx context.Context, originalReq *htt proxyReq.Header.Set("anthropic-version", p.config.Version) } - // Support gzip encoding - proxyReq.Header.Set("Accept-Encoding", "gzip") + // Handle Accept-Encoding: We accept gzip from upstream for efficiency, + // but we decompress before forwarding to the client. This is transparent + // to the client - they receive uncompressed data regardless of what they requested. + // We preserve gzip if client already requested it, otherwise add it. + clientEncoding := proxyReq.Header.Get("Accept-Encoding") + if clientEncoding == "" || !strings.Contains(clientEncoding, "gzip") { + proxyReq.Header.Set("Accept-Encoding", "gzip") + } // Forward the request resp, err := p.client.Do(proxyReq) diff --git a/proxy/internal/provider/openai.go b/proxy/internal/provider/openai.go index 5973026..7afcc16 100644 --- a/proxy/internal/provider/openai.go +++ b/proxy/internal/provider/openai.go @@ -1,7 +1,6 @@ package provider import ( - "bufio" "bytes" "compress/gzip" "context" @@ -15,6 +14,7 @@ import ( "github.com/seifghazi/claude-code-monitor/internal/config" "github.com/seifghazi/claude-code-monitor/internal/model" + "github.com/seifghazi/claude-code-monitor/internal/sse" ) type OpenAIProvider struct { @@ -78,13 +78,29 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R // Remove Anthropic-specific headers proxyReq.Header.Del("anthropic-version") proxyReq.Header.Del("x-api-key") + proxyReq.Header.Del("Authorization") + + // Determine which API key to use + apiKey := p.config.APIKey + + // Check for client-provided API key if allowed + if p.config.AllowClientAPIKey && p.config.ClientAPIKeyHeader != "" { + if clientKey := originalReq.Header.Get(p.config.ClientAPIKeyHeader); clientKey != "" { + apiKey = clientKey + } + } // Add OpenAI headers - if p.config.APIKey != "" { - proxyReq.Header.Set("Authorization", "Bearer "+p.config.APIKey) + if apiKey != "" { + proxyReq.Header.Set("Authorization", "Bearer "+apiKey) } proxyReq.Header.Set("Content-Type", "application/json") + // Remove the client API key header from the proxied request + if p.config.ClientAPIKeyHeader != "" { + proxyReq.Header.Del(p.config.ClientAPIKeyHeader) + } + // Forward the request resp, err := p.client.Do(proxyReq) if err != nil { @@ -139,9 +155,12 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R // Start a goroutine to transform the stream go func() { - defer pw.Close() defer bodyReader.Close() - transformOpenAIStreamToAnthropic(bodyReader, pw) + if err := transformOpenAIStreamToAnthropic(bodyReader, pw); err != nil { + _ = pw.CloseWithError(err) + return + } + _ = pw.Close() }() // Replace the response body with our transformed stream @@ -332,10 +351,10 @@ func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{ }) } } - // Check if max_tokens exceeds the model's limit and cap it if necessary - maxTokensLimit := 16384 // Assuming this is the limit for the model - if req.MaxTokens > maxTokensLimit { - // Capping max_tokens to model limit + // Get model-specific max token limit + // Let the API handle validation for unknown models rather than using arbitrary caps + maxTokensLimit := getModelMaxTokens(req.Model) + if maxTokensLimit > 0 && req.MaxTokens > maxTokensLimit { req.MaxTokens = maxTokensLimit } @@ -473,6 +492,52 @@ func min(a, b int) int { return b } +// getModelMaxTokens returns the max output tokens for known models +// Returns 0 for unknown models, letting the API handle validation +func getModelMaxTokens(model string) int { + // Model-specific max completion token limits + modelLimits := map[string]int{ + // GPT-4 Turbo and GPT-4o models + "gpt-4-turbo": 4096, + "gpt-4-turbo-preview": 4096, + "gpt-4o": 16384, + "gpt-4o-mini": 16384, + "gpt-4o-2024-05-13": 16384, + "gpt-4o-2024-08-06": 16384, + // GPT-4 models + "gpt-4": 8192, + "gpt-4-32k": 8192, + "gpt-4-0613": 8192, + // GPT-3.5 models + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-16k": 4096, + "gpt-3.5-turbo-0125": 4096, + "gpt-3.5-turbo-1106": 4096, + // o1 reasoning models + "o1": 100000, + "o1-preview": 32768, + "o1-mini": 65536, + // o3 reasoning models (estimated based on o1 patterns) + "o3": 100000, + "o3-mini": 65536, + } + + // Check for exact match first + if limit, ok := modelLimits[model]; ok { + return limit + } + + // Check for prefix matches for versioned models + for prefix, limit := range modelLimits { + if strings.HasPrefix(model, prefix) { + return limit + } + } + + // Return 0 for unknown models - let the API validate + return 0 +} + func transformOpenAIResponseToAnthropic(respBody []byte) []byte { // This is a simplified transformation // In production, you'd want to handle all fields properly @@ -579,19 +644,16 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte { return result } -func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) { - defer openAIStream.Close() - - scanner := bufio.NewScanner(openAIStream) +func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error { var messageStarted bool var contentStarted bool + var sawDone bool - for scanner.Scan() { - line := scanner.Text() + err := sse.ForEachLine(openAIStream, func(line string) error { // Skip empty lines if line == "" { - continue + return nil } // Handle SSE data lines @@ -600,21 +662,28 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea // Handle end of stream if data == "[DONE]" { + sawDone = true // Send Anthropic-style completion if contentStarted { - fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n") + if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n"); err != nil { + return err + } } if messageStarted { - fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n") - fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n") + if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n"); err != nil { + return err + } + if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n"); err != nil { + return err + } } - break + return nil } // Parse OpenAI response var openAIChunk map[string]interface{} if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil { - continue + return nil } // Check for usage data BEFORE processing choices @@ -644,7 +713,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea "usage": anthropicUsage, } usageJSON, _ := json.Marshal(usageDelta) - fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON) + if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON); err != nil { + return err + } } } @@ -652,17 +723,17 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea choices, ok := openAIChunk["choices"].([]interface{}) if !ok || len(choices) == 0 { // Skip further processing if no choices, but we already handled usage above - continue + return nil } choice, ok := choices[0].(map[string]interface{}) if !ok { - continue + return nil } delta, ok := choice["delta"].(map[string]interface{}) if !ok { - continue + return nil } // Handle first chunk - send message_start @@ -684,7 +755,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea }, } startJSON, _ := json.Marshal(messageStart) - fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON) + if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON); err != nil { + return err + } } // Handle content @@ -701,7 +774,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea }, } blockStartJSON, _ := json.Marshal(blockStart) - fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON) + if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON); err != nil { + return err + } } // Send content_block_delta @@ -714,9 +789,22 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea }, } deltaJSON, _ := json.Marshal(contentDelta) - fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON) + if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON); err != nil { + return err + } } } + return nil + }) + + if err != nil { + return err } + + if !sawDone { + return io.ErrUnexpectedEOF + } + + return nil } diff --git a/proxy/internal/provider/openai_test.go b/proxy/internal/provider/openai_test.go new file mode 100644 index 0000000..892ad2c --- /dev/null +++ b/proxy/internal/provider/openai_test.go @@ -0,0 +1,63 @@ +package provider + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/seifghazi/claude-code-monitor/internal/config" +) + +func TestOpenAIProviderForwardRequestClearsInboundAuthorization(t *testing.T) { + var gotAuthorization string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuthorization = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"resp_1","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"ok"}}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`)) + })) + defer server.Close() + + provider := NewOpenAIProvider(&config.OpenAIProviderConfig{ + BaseURL: server.URL, + }).(*OpenAIProvider) + + req := httptest.NewRequest(http.MethodPost, "http://proxy.local/v1/messages", strings.NewReader(`{"model":"gpt-4o","messages":[],"max_tokens":16}`)) + req.Header.Set("Authorization", "Bearer should-not-leak") + req.Header.Set("Content-Type", "application/json") + + resp, err := provider.ForwardRequest(context.Background(), req) + if err != nil { + t.Fatalf("ForwardRequest() error = %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if gotAuthorization != "" { + t.Fatalf("expected forwarded Authorization header to be empty, got %q", gotAuthorization) + } +} + +func TestTransformOpenAIStreamToAnthropicHandlesLargeEvents(t *testing.T) { + largeContent := strings.Repeat("x", 128*1024) + openAIStream := strings.NewReader("data: {\"id\":\"chatcmpl_1\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"" + largeContent + "\"}}]}\n\n" + + "data: [DONE]\n\n") + + var output strings.Builder + if err := transformOpenAIStreamToAnthropic(openAIStream, &output); err != nil { + t.Fatalf("transformOpenAIStreamToAnthropic() error = %v", err) + } + + got := output.String() + if !strings.Contains(got, "\"message_start\"") { + t.Fatal("expected message_start event in output") + } + if !strings.Contains(got, largeContent) { + t.Fatal("expected large content to be preserved in output") + } + if !strings.Contains(got, "\"message_stop\"") { + t.Fatal("expected message_stop event in output") + } +} diff --git a/proxy/internal/service/storage.go b/proxy/internal/service/storage.go index 868c616..da0b05e 100644 --- a/proxy/internal/service/storage.go +++ b/proxy/internal/service/storage.go @@ -1,18 +1,33 @@ package service import ( + "io" + "time" + "github.com/seifghazi/claude-code-monitor/internal/config" "github.com/seifghazi/claude-code-monitor/internal/model" ) type StorageService interface { + // Core CRUD operations SaveRequest(request *model.RequestLog) (string, error) - GetRequests(page, limit int) ([]model.RequestLog, int, error) + GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error) + GetAllRequests(modelFilter string) ([]*model.RequestLog, error) + GetRequestByShortID(shortID string) (*model.RequestLog, string, error) ClearRequests() (int, error) + + // Update operations UpdateRequestWithGrading(requestID string, grade *model.PromptGrade) error UpdateRequestWithResponse(request *model.RequestLog) error - EnsureDirectoryExists() error - GetRequestByShortID(shortID string) (*model.RequestLog, string, error) + + // Maintenance operations + DeleteRequestsOlderThan(age time.Duration) (int, error) + GetDatabaseStats() (map[string]interface{}, error) + + // Configuration GetConfig() *config.StorageConfig - GetAllRequests(modelFilter string) ([]*model.RequestLog, error) + EnsureDirectoryExists() error + + // Cleanup - implements io.Closer + io.Closer } diff --git a/proxy/internal/service/storage_sqlite.go b/proxy/internal/service/storage_sqlite.go index 77a52b4..e8aa062 100644 --- a/proxy/internal/service/storage_sqlite.go +++ b/proxy/internal/service/storage_sqlite.go @@ -4,7 +4,9 @@ import ( "database/sql" "encoding/json" "fmt" + "log" "strings" + "time" _ "github.com/mattn/go-sqlite3" @@ -15,23 +17,63 @@ import ( type sqliteStorageService struct { db *sql.DB config *config.StorageConfig + logger *log.Logger + + // Prepared statements for frequently used queries + stmtInsertRequest *sql.Stmt + stmtUpdateResponse *sql.Stmt + stmtUpdateGrading *sql.Stmt + stmtGetRequestByID *sql.Stmt + stmtGetRequestsPage *sql.Stmt + stmtGetRequestsCount *sql.Stmt + stmtDeleteOldRequests *sql.Stmt } func NewSQLiteStorageService(cfg *config.StorageConfig) (StorageService, error) { - db, err := sql.Open("sqlite3", cfg.DBPath) + return NewSQLiteStorageServiceWithLogger(cfg, log.Default()) +} + +func NewSQLiteStorageServiceWithLogger(cfg *config.StorageConfig, logger *log.Logger) (StorageService, error) { + // Enable WAL mode and other optimizations via connection string + // _journal_mode=WAL: Write-Ahead Logging for better concurrent read performance + // _synchronous=NORMAL: Good balance of safety and performance + // _busy_timeout=5000: Wait up to 5 seconds if database is locked + // _cache_size=-20000: Use 20MB of memory for cache (negative = KB) + connStr := cfg.DBPath + "?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000" + + db, err := sql.Open("sqlite3", connStr) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } + // Configure connection pool + // SQLite only supports one writer at a time, but can handle multiple readers + db.SetMaxOpenConns(1) // Serialize writes to avoid SQLITE_BUSY errors + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(time.Hour) + + // Verify connection + if err := db.Ping(); err != nil { + db.Close() + return nil, fmt.Errorf("failed to ping database: %w", err) + } + service := &sqliteStorageService{ db: db, config: cfg, + logger: logger, } if err := service.createTables(); err != nil { + db.Close() return nil, fmt.Errorf("failed to create tables: %w", err) } + if err := service.prepareStatements(); err != nil { + db.Close() + return nil, fmt.Errorf("failed to prepare statements: %w", err) + } + return service, nil } @@ -39,7 +81,7 @@ func (s *sqliteStorageService) createTables() error { schema := ` CREATE TABLE IF NOT EXISTS requests ( id TEXT PRIMARY KEY, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + timestamp DATETIME NOT NULL, method TEXT NOT NULL, endpoint TEXT NOT NULL, headers TEXT NOT NULL, @@ -50,17 +92,100 @@ func (s *sqliteStorageService) createTables() error { response TEXT, model TEXT, original_model TEXT, - routed_model TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP + routed_model TEXT ); - CREATE INDEX IF NOT EXISTS idx_timestamp ON requests(timestamp DESC); - CREATE INDEX IF NOT EXISTS idx_endpoint ON requests(endpoint); - CREATE INDEX IF NOT EXISTS idx_model ON requests(model); + -- Index for listing requests by time (most common query) + CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC); + + -- Index for filtering by model + CREATE INDEX IF NOT EXISTS idx_requests_model ON requests(model); + + -- Index for filtering by endpoint + CREATE INDEX IF NOT EXISTS idx_requests_endpoint ON requests(endpoint); ` _, err := s.db.Exec(schema) - return err + if err != nil { + return err + } + + // Run migrations + s.migrateSchema() + + return nil +} + +func (s *sqliteStorageService) migrateSchema() { + // Ensure WAL mode is enabled (in case opened without connection string params) + _, err := s.db.Exec("PRAGMA journal_mode=WAL") + if err != nil { + s.logger.Printf("Warning: failed to set WAL mode: %v", err) + } + + // Drop old redundant index if it exists (we renamed to idx_requests_timestamp) + s.db.Exec("DROP INDEX IF EXISTS idx_timestamp") +} + +func (s *sqliteStorageService) prepareStatements() error { + var err error + + s.stmtInsertRequest, err = s.db.Prepare(` + INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + + s.stmtUpdateResponse, err = s.db.Prepare(` + UPDATE requests SET response = ? WHERE id = ? + `) + if err != nil { + return fmt.Errorf("failed to prepare update response statement: %w", err) + } + + s.stmtUpdateGrading, err = s.db.Prepare(` + UPDATE requests SET prompt_grade = ? WHERE id = ? + `) + if err != nil { + return fmt.Errorf("failed to prepare update grading statement: %w", err) + } + + s.stmtGetRequestByID, err = s.db.Prepare(` + SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model + FROM requests + WHERE id = ? + `) + if err != nil { + return fmt.Errorf("failed to prepare get by ID statement: %w", err) + } + + s.stmtGetRequestsPage, err = s.db.Prepare(` + SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model + FROM requests + ORDER BY timestamp DESC + LIMIT ? OFFSET ? + `) + if err != nil { + return fmt.Errorf("failed to prepare get requests page statement: %w", err) + } + + s.stmtGetRequestsCount, err = s.db.Prepare(` + SELECT COUNT(*) FROM requests + `) + if err != nil { + return fmt.Errorf("failed to prepare count statement: %w", err) + } + + s.stmtDeleteOldRequests, err = s.db.Prepare(` + DELETE FROM requests WHERE timestamp < ? + `) + if err != nil { + return fmt.Errorf("failed to prepare delete old requests statement: %w", err) + } + + return nil } func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, error) { @@ -74,12 +199,7 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e return "", fmt.Errorf("failed to marshal body: %w", err) } - query := ` - INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - - _, err = s.db.Exec(query, + _, err = s.stmtInsertRequest.Exec( request.RequestID, request.Timestamp, request.Method, @@ -100,10 +220,24 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e return request.RequestID, nil } -func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, int, error) { +func (s *sqliteStorageService) GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error) { + whereClause := "" + countArgs := []interface{}{} + queryArgs := []interface{}{} + + if modelFilter != "" && modelFilter != "all" { + // Escape LIKE special characters to prevent pattern injection + escapedFilter := escapeLikePattern(strings.ToLower(modelFilter)) + whereClause = " WHERE LOWER(model) LIKE ? ESCAPE '\\'" + filterValue := "%" + escapedFilter + "%" + countArgs = append(countArgs, filterValue) + queryArgs = append(queryArgs, filterValue) + } + // Get total count var total int - err := s.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&total) + countQuery := "SELECT COUNT(*) FROM requests" + whereClause + err := s.db.QueryRow(countQuery, countArgs...).Scan(&total) if err != nil { return nil, 0, fmt.Errorf("failed to get total count: %w", err) } @@ -112,71 +246,21 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, offset := (page - 1) * limit query := ` SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model - FROM requests + FROM requests` + whereClause + ` ORDER BY timestamp DESC LIMIT ? OFFSET ? ` + queryArgs = append(queryArgs, limit, offset) - rows, err := s.db.Query(query, limit, offset) + rows, err := s.db.Query(query, queryArgs...) if err != nil { return nil, 0, fmt.Errorf("failed to query requests: %w", err) } defer rows.Close() - var requests []model.RequestLog - for rows.Next() { - var req model.RequestLog - var headersJSON, bodyJSON string - var promptGradeJSON, responseJSON sql.NullString - - err := rows.Scan( - &req.RequestID, - &req.Timestamp, - &req.Method, - &req.Endpoint, - &headersJSON, - &bodyJSON, - &req.Model, - &req.UserAgent, - &req.ContentType, - &promptGradeJSON, - &responseJSON, - &req.OriginalModel, - &req.RoutedModel, - ) - if err != nil { - // Error scanning row - skip - continue - } - - // Unmarshal JSON fields - if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil { - // Error unmarshaling headers - continue - } - - var body interface{} - if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil { - // Error unmarshaling body - continue - } - req.Body = body - - if promptGradeJSON.Valid { - var grade model.PromptGrade - if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil { - req.PromptGrade = &grade - } - } - - if responseJSON.Valid { - var resp model.ResponseLog - if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil { - req.Response = &resp - } - } - - requests = append(requests, req) + requests, err := s.scanRequestRows(rows) + if err != nil { + return nil, 0, err } return requests, total, nil @@ -193,6 +277,12 @@ func (s *sqliteStorageService) ClearRequests() (int, error) { return 0, fmt.Errorf("failed to get rows affected: %w", err) } + // Reclaim space after clearing all data + _, err = s.db.Exec("VACUUM") + if err != nil { + s.logger.Printf("Warning: failed to vacuum database: %v", err) + } + return int(rowsAffected), nil } @@ -202,12 +292,16 @@ func (s *sqliteStorageService) UpdateRequestWithGrading(requestID string, grade return fmt.Errorf("failed to marshal grade: %w", err) } - query := "UPDATE requests SET prompt_grade = ? WHERE id = ?" - _, err = s.db.Exec(query, string(gradeJSON), requestID) + result, err := s.stmtUpdateGrading.Exec(string(gradeJSON), requestID) if err != nil { return fmt.Errorf("failed to update request with grading: %w", err) } + rowsAffected, _ := result.RowsAffected() + if rowsAffected == 0 { + return fmt.Errorf("request %s not found", requestID) + } + return nil } @@ -217,12 +311,72 @@ func (s *sqliteStorageService) UpdateRequestWithResponse(request *model.RequestL return fmt.Errorf("failed to marshal response: %w", err) } - query := "UPDATE requests SET response = ? WHERE id = ?" - _, err = s.db.Exec(query, string(responseJSON), request.RequestID) + result, err := s.stmtUpdateResponse.Exec(string(responseJSON), request.RequestID) if err != nil { return fmt.Errorf("failed to update request with response: %w", err) } + rowsAffected, _ := result.RowsAffected() + if rowsAffected == 0 { + return fmt.Errorf("request %s not found", request.RequestID) + } + + return nil +} + +// SaveRequestWithResponse saves a request and its response in a single transaction +func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog) error { + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + headersJSON, err := json.Marshal(request.Headers) + if err != nil { + return fmt.Errorf("failed to marshal headers: %w", err) + } + + bodyJSON, err := json.Marshal(request.Body) + if err != nil { + return fmt.Errorf("failed to marshal body: %w", err) + } + + // Insert request + _, err = tx.Stmt(s.stmtInsertRequest).Exec( + request.RequestID, + request.Timestamp, + request.Method, + request.Endpoint, + string(headersJSON), + string(bodyJSON), + request.UserAgent, + request.ContentType, + request.Model, + request.OriginalModel, + request.RoutedModel, + ) + if err != nil { + return fmt.Errorf("failed to insert request: %w", err) + } + + // Update with response if present + if request.Response != nil { + responseJSON, err := json.Marshal(request.Response) + if err != nil { + return fmt.Errorf("failed to marshal response: %w", err) + } + + _, err = tx.Stmt(s.stmtUpdateResponse).Exec(string(responseJSON), request.RequestID) + if err != nil { + return fmt.Errorf("failed to update response: %w", err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + return nil } @@ -232,10 +386,13 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error { } func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) { + // Escape LIKE special characters to prevent pattern injection + escapedID := escapeLikePattern(shortID) + query := ` SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model FROM requests - WHERE id LIKE ? + WHERE id LIKE ? ESCAPE '\' ORDER BY timestamp DESC LIMIT 1 ` @@ -244,7 +401,7 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque var headersJSON, bodyJSON string var promptGradeJSON, responseJSON sql.NullString - err := s.db.QueryRow(query, "%"+shortID).Scan( + err := s.db.QueryRow(query, "%"+escapedID).Scan( &req.RequestID, &req.Timestamp, &req.Method, @@ -267,29 +424,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque return nil, "", fmt.Errorf("failed to query request: %w", err) } - // Unmarshal JSON fields - if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil { - return nil, "", fmt.Errorf("failed to unmarshal headers: %w", err) - } - - var body interface{} - if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil { - return nil, "", fmt.Errorf("failed to unmarshal body: %w", err) - } - req.Body = body - - if promptGradeJSON.Valid { - var grade model.PromptGrade - if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil { - req.PromptGrade = &grade - } - } - - if responseJSON.Valid { - var resp model.ResponseLog - if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil { - req.Response = &resp - } + if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil { + return nil, "", err } return &req, req.RequestID, nil @@ -300,19 +436,36 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig { } func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) { - query := ` - SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model - FROM requests - ` + return s.GetAllRequestsWithLimit(modelFilter, 0) // 0 means no limit +} + +// GetAllRequestsWithLimit returns requests with an optional limit (0 = no limit) +func (s *sqliteStorageService) GetAllRequestsWithLimit(modelFilter string, limit int) ([]*model.RequestLog, error) { + var query string args := []interface{}{} if modelFilter != "" && modelFilter != "all" { - query += " WHERE LOWER(model) LIKE ?" - args = append(args, "%"+strings.ToLower(modelFilter)+"%") - + // Escape LIKE special characters + escapedFilter := escapeLikePattern(strings.ToLower(modelFilter)) + query = ` + SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model + FROM requests + WHERE LOWER(model) LIKE ? ESCAPE '\' + ORDER BY timestamp DESC + ` + args = append(args, "%"+escapedFilter+"%") + } else { + query = ` + SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model + FROM requests + ORDER BY timestamp DESC + ` } - query += " ORDER BY timestamp DESC" + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + } rows, err := s.db.Query(query, args...) if err != nil { @@ -321,6 +474,124 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ defer rows.Close() var requests []*model.RequestLog + for rows.Next() { + req, err := s.scanSingleRow(rows) + if err != nil { + s.logger.Printf("Warning: failed to scan request row: %v", err) + continue + } + requests = append(requests, req) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %w", err) + } + + return requests, nil +} + +// DeleteRequestsOlderThan removes requests older than the specified duration +func (s *sqliteStorageService) DeleteRequestsOlderThan(age time.Duration) (int, error) { + cutoff := time.Now().Add(-age) + + result, err := s.stmtDeleteOldRequests.Exec(cutoff.Format(time.RFC3339)) + if err != nil { + return 0, fmt.Errorf("failed to delete old requests: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("failed to get rows affected: %w", err) + } + + return int(rowsAffected), nil +} + +// GetDatabaseStats returns statistics about the database +func (s *sqliteStorageService) GetDatabaseStats() (map[string]interface{}, error) { + stats := make(map[string]interface{}) + + // Get row count + var count int + err := s.stmtGetRequestsCount.QueryRow().Scan(&count) + if err != nil { + return nil, fmt.Errorf("failed to get count: %w", err) + } + stats["total_requests"] = count + + // Get database size + var pageCount, pageSize int + err = s.db.QueryRow("PRAGMA page_count").Scan(&pageCount) + if err == nil { + err = s.db.QueryRow("PRAGMA page_size").Scan(&pageSize) + if err == nil { + stats["database_size_bytes"] = pageCount * pageSize + } + } + + // Get oldest and newest timestamps + var oldest, newest sql.NullString + err = s.db.QueryRow("SELECT MIN(timestamp), MAX(timestamp) FROM requests").Scan(&oldest, &newest) + if err == nil { + if oldest.Valid { + stats["oldest_request"] = oldest.String + } + if newest.Valid { + stats["newest_request"] = newest.String + } + } + + return stats, nil +} + +func (s *sqliteStorageService) Close() error { + // Close prepared statements + if s.stmtInsertRequest != nil { + s.stmtInsertRequest.Close() + } + if s.stmtUpdateResponse != nil { + s.stmtUpdateResponse.Close() + } + if s.stmtUpdateGrading != nil { + s.stmtUpdateGrading.Close() + } + if s.stmtGetRequestByID != nil { + s.stmtGetRequestByID.Close() + } + if s.stmtGetRequestsPage != nil { + s.stmtGetRequestsPage.Close() + } + if s.stmtGetRequestsCount != nil { + s.stmtGetRequestsCount.Close() + } + if s.stmtDeleteOldRequests != nil { + s.stmtDeleteOldRequests.Close() + } + + // Checkpoint WAL before closing + _, err := s.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)") + if err != nil { + s.logger.Printf("Warning: failed to checkpoint WAL: %v", err) + } + + return s.db.Close() +} + +// Helper functions + +// escapeLikePattern escapes special characters in LIKE patterns +func escapeLikePattern(s string) string { + // Escape \, %, and _ characters + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `%`, `\%`) + s = strings.ReplaceAll(s, `_`, `\_`) + return s +} + +// scanRequestRows scans multiple rows into a slice of RequestLog +func (s *sqliteStorageService) scanRequestRows(rows *sql.Rows) ([]model.RequestLog, error) { + var requests []model.RequestLog + for rows.Next() { var req model.RequestLog var headersJSON, bodyJSON string @@ -342,43 +613,86 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ &req.RoutedModel, ) if err != nil { - // Error scanning row - skip + s.logger.Printf("Warning: failed to scan row: %v", err) continue } - // Unmarshal JSON fields - if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil { - // Error unmarshaling headers + if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil { + s.logger.Printf("Warning: failed to unmarshal request fields: %v", err) continue } - var body interface{} - if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil { - // Error unmarshaling body - continue - } - req.Body = body + requests = append(requests, req) + } - if promptGradeJSON.Valid { - var grade model.PromptGrade - if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil { - req.PromptGrade = &grade - } - } - - if responseJSON.Valid { - var resp model.ResponseLog - if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil { - req.Response = &resp - } - } - - requests = append(requests, &req) + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %w", err) } return requests, nil } -func (s *sqliteStorageService) Close() error { - return s.db.Close() +// scanSingleRow scans a single row into a RequestLog pointer +func (s *sqliteStorageService) scanSingleRow(rows *sql.Rows) (*model.RequestLog, error) { + var req model.RequestLog + var headersJSON, bodyJSON string + var promptGradeJSON, responseJSON sql.NullString + + err := rows.Scan( + &req.RequestID, + &req.Timestamp, + &req.Method, + &req.Endpoint, + &headersJSON, + &bodyJSON, + &req.Model, + &req.UserAgent, + &req.ContentType, + &promptGradeJSON, + &responseJSON, + &req.OriginalModel, + &req.RoutedModel, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + + if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil { + return nil, err + } + + return &req, nil +} + +// unmarshalRequestFields unmarshals JSON fields into a RequestLog +func (s *sqliteStorageService) unmarshalRequestFields(req *model.RequestLog, headersJSON, bodyJSON string, promptGradeJSON, responseJSON sql.NullString) error { + if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil { + return fmt.Errorf("failed to unmarshal headers: %w", err) + } + + var body interface{} + if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil { + return fmt.Errorf("failed to unmarshal body: %w", err) + } + req.Body = body + + if promptGradeJSON.Valid { + var grade model.PromptGrade + if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err != nil { + s.logger.Printf("Warning: failed to unmarshal prompt grade: %v", err) + } else { + req.PromptGrade = &grade + } + } + + if responseJSON.Valid { + var resp model.ResponseLog + if err := json.Unmarshal([]byte(responseJSON.String), &resp); err != nil { + s.logger.Printf("Warning: failed to unmarshal response: %v", err) + } else { + req.Response = &resp + } + } + + return nil } diff --git a/proxy/internal/service/storage_sqlite_test.go b/proxy/internal/service/storage_sqlite_test.go new file mode 100644 index 0000000..258cef2 --- /dev/null +++ b/proxy/internal/service/storage_sqlite_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "fmt" + "path/filepath" + "testing" + "time" + + "github.com/seifghazi/claude-code-monitor/internal/config" + "github.com/seifghazi/claude-code-monitor/internal/model" +) + +func TestSQLiteStorageServiceGetRequestsUsesSQLPaginationAndFiltering(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "requests.db") + storage, err := NewSQLiteStorageService(&config.StorageConfig{DBPath: dbPath}) + if err != nil { + t.Fatalf("NewSQLiteStorageService() error = %v", err) + } + + sqliteStorage, ok := storage.(*sqliteStorageService) + if !ok { + t.Fatalf("unexpected storage type %T", storage) + } + defer sqliteStorage.Close() + + requests := []struct { + id string + model string + }{ + {id: "1", model: "claude-3-5-sonnet"}, + {id: "2", model: "gpt-4o"}, + {id: "3", model: "claude-3-5-sonnet"}, + {id: "4", model: "gpt-4o-mini"}, + } + + for i, req := range requests { + _, err := storage.SaveRequest(&model.RequestLog{ + RequestID: req.id, + Timestamp: time.Date(2026, 3, 19, 12, 0, i, 0, time.UTC).Format(time.RFC3339), + Method: "POST", + Endpoint: "/v1/messages", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: map[string]string{"request": fmt.Sprintf("body-%d", i)}, + Model: req.model, + UserAgent: "test", + ContentType: "application/json", + }) + if err != nil { + t.Fatalf("SaveRequest() error = %v", err) + } + } + + got, total, err := storage.GetRequests(1, 1, "gpt") + if err != nil { + t.Fatalf("GetRequests() error = %v", err) + } + + if total != 2 { + t.Fatalf("expected filtered total 2, got %d", total) + } + + if len(got) != 1 { + t.Fatalf("expected 1 paginated result, got %d", len(got)) + } + + if got[0].RequestID != "4" { + t.Fatalf("expected newest filtered request ID 4, got %s", got[0].RequestID) + } +} diff --git a/proxy/internal/sse/sse.go b/proxy/internal/sse/sse.go new file mode 100644 index 0000000..ce1c1e2 --- /dev/null +++ b/proxy/internal/sse/sse.go @@ -0,0 +1,30 @@ +package sse + +import ( + "bufio" + "io" + "strings" +) + +// ForEachLine reads line-oriented SSE content without bufio.Scanner's token limit. +func ForEachLine(r io.Reader, fn func(string) error) error { + reader := bufio.NewReader(r) + + for { + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + return err + } + + if len(line) > 0 { + line = strings.TrimRight(line, "\r\n") + if callErr := fn(line); callErr != nil { + return callErr + } + } + + if err == io.EOF { + return nil + } + } +} diff --git a/proxy/internal/sse/sse_test.go b/proxy/internal/sse/sse_test.go new file mode 100644 index 0000000..67319f4 --- /dev/null +++ b/proxy/internal/sse/sse_test.go @@ -0,0 +1,31 @@ +package sse + +import ( + "strings" + "testing" +) + +func TestForEachLineHandlesLargeLines(t *testing.T) { + largePayload := strings.Repeat("x", 128*1024) + input := "data: " + largePayload + "\n\n" + + var lines []string + if err := ForEachLine(strings.NewReader(input), func(line string) error { + lines = append(lines, line) + return nil + }); err != nil { + t.Fatalf("ForEachLine() error = %v", err) + } + + if len(lines) != 2 { + t.Fatalf("expected 2 lines, got %d", len(lines)) + } + + if lines[0] != "data: "+largePayload { + t.Fatalf("unexpected first line length: got %d", len(lines[0])) + } + + if lines[1] != "" { + t.Fatalf("expected blank separator line, got %q", lines[1]) + } +}