Harden streaming, pagination, and config loading

This commit is contained in:
sid 2026-03-19 18:52:09 -06:00
parent 02c9c76667
commit 6cda36312a
16 changed files with 1079 additions and 244 deletions

View file

@ -12,5 +12,16 @@ ANTHROPIC_FORWARD_URL=https://api.anthropic.com
ANTHROPIC_VERSION=2023-06-01
ANTHROPIC_MAX_RETRIES=3
# OpenAI Configuration (for subagent routing)
# OPENAI_API_KEY=your-openai-api-key
# OPENAI_BASE_URL=https://api.openai.com
# OPENAI_ALLOW_CLIENT_API_KEY=false
# OPENAI_CLIENT_API_KEY_HEADER=x-openai-api-key
# Storage Configuration
DATABASE_PATH=requests.db
DB_PATH=requests.db
# CORS Configuration (comma-separated values)
# CORS_ALLOWED_ORIGINS=*
# CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
# CORS_ALLOWED_HEADERS=*

View file

@ -38,6 +38,36 @@ providers:
# Can also be set via OPENAI_BASE_URL environment variable
# base_url: "https://api.openai.com"
# Allow clients to provide their own API key via header
# Can also be set via OPENAI_ALLOW_CLIENT_API_KEY environment variable
allow_client_api_key: false
# Header name for client-provided API key (default: x-openai-api-key)
# Can also be set via OPENAI_CLIENT_API_KEY_HEADER environment variable
client_api_key_header: "x-openai-api-key"
# CORS Configuration
# Controls Cross-Origin Resource Sharing for the web UI
cors:
# Allowed origins (use ["*"] for all origins - not recommended for production)
# Can also be set via CORS_ALLOWED_ORIGINS environment variable (comma-separated)
allowed_origins:
- "*"
# Allowed HTTP methods
# Can also be set via CORS_ALLOWED_METHODS environment variable (comma-separated)
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
# Allowed headers (use ["*"] for all headers)
# Can also be set via CORS_ALLOWED_HEADERS environment variable (comma-separated)
allowed_headers:
- "*"
# Storage configuration
storage:
# SQLite database path for storing request history
@ -69,23 +99,30 @@ subagents:
# The following environment variables will override the YAML configuration:
#
# Server:
# PORT - Server port
# READ_TIMEOUT - Read timeout duration
# WRITE_TIMEOUT - Write timeout duration
# IDLE_TIMEOUT - Idle timeout duration
# PORT - Server port
# READ_TIMEOUT - Read timeout duration
# WRITE_TIMEOUT - Write timeout duration
# IDLE_TIMEOUT - Idle timeout duration
#
# Anthropic:
# ANTHROPIC_FORWARD_URL - Anthropic base URL
# ANTHROPIC_VERSION - Anthropic API version
# ANTHROPIC_MAX_RETRIES - Maximum retries for Anthropic requests
# ANTHROPIC_FORWARD_URL - Anthropic base URL
# ANTHROPIC_VERSION - Anthropic API version
# ANTHROPIC_MAX_RETRIES - Maximum retries for Anthropic requests
#
# OpenAI:
# OPENAI_API_KEY - OpenAI API key
# OPENAI_BASE_URL - OpenAI base URL
# OPENAI_API_KEY - OpenAI API key
# OPENAI_BASE_URL - OpenAI base URL
# OPENAI_ALLOW_CLIENT_API_KEY - Allow client-provided API keys (true/false)
# OPENAI_CLIENT_API_KEY_HEADER - Header name for client API key
#
# Storage:
# DB_PATH - Database file path
# DB_PATH - Database file path
#
# CORS:
# CORS_ALLOWED_ORIGINS - Comma-separated allowed origins
# CORS_ALLOWED_METHODS - Comma-separated allowed methods
# CORS_ALLOWED_HEADERS - Comma-separated allowed headers
#
# Subagents:
# SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs
# Example: "code-reviewer:claude-3-5-sonnet"
# SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs
# Example: "code-reviewer:claude-3-5-sonnet"

View file

@ -50,9 +50,9 @@ func main() {
r := mux.NewRouter()
corsHandler := handlers.CORS(
handlers.AllowedOrigins([]string{"*"}),
handlers.AllowedMethods([]string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}),
handlers.AllowedHeaders([]string{"*"}),
handlers.AllowedOrigins(cfg.CORS.AllowedOrigins),
handlers.AllowedMethods(cfg.CORS.AllowedMethods),
handlers.AllowedHeaders(cfg.CORS.AllowedHeaders),
)
r.Use(middleware.Logging)

View file

@ -1,9 +1,11 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/joho/godotenv"
@ -15,9 +17,16 @@ type Config struct {
Providers ProvidersConfig `yaml:"providers"`
Storage StorageConfig `yaml:"storage"`
Subagents SubagentsConfig `yaml:"subagents"`
CORS CORSConfig `yaml:"cors"`
Anthropic AnthropicConfig
}
type CORSConfig struct {
AllowedOrigins []string `yaml:"allowed_origins"`
AllowedMethods []string `yaml:"allowed_methods"`
AllowedHeaders []string `yaml:"allowed_headers"`
}
type ServerConfig struct {
Port string `yaml:"port"`
Timeouts TimeoutsConfig `yaml:"timeouts"`
@ -45,8 +54,10 @@ type AnthropicProviderConfig struct {
}
type OpenAIProviderConfig struct {
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
AllowClientAPIKey bool `yaml:"allow_client_api_key"` // Allow clients to provide their own API key
ClientAPIKeyHeader string `yaml:"client_api_key_header"` // Header name for client API key (default: x-openai-api-key)
}
type AnthropicConfig struct {
@ -92,8 +103,10 @@ func Load() (*Config, error) {
MaxRetries: 3,
},
OpenAI: OpenAIProviderConfig{
BaseURL: "https://api.openai.com",
APIKey: "",
BaseURL: "https://api.openai.com",
APIKey: "",
AllowClientAPIKey: false,
ClientAPIKeyHeader: "x-openai-api-key",
},
},
Storage: StorageConfig{
@ -103,25 +116,17 @@ func Load() (*Config, error) {
Enable: false,
Mappings: make(map[string]string),
},
CORS: CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"*"},
},
}
// Try to load config.yaml from the project root
// The proxy binary is in proxy/ directory, config.yaml is in the parent
configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml")
// If that doesn't work, try relative to current directory
if _, err := os.Stat(configPath); err != nil {
// Try common locations relative to where the binary might be run
for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} {
if _, err := os.Stat(tryPath); err == nil {
configPath = tryPath
break
}
}
if err := loadFirstAvailableConfig(cfg, candidateConfigPaths()); err != nil {
return nil, err
}
cfg.loadFromFile(configPath)
// Apply environment variable overrides AFTER loading from file
if envPort := os.Getenv("PORT"); envPort != "" {
cfg.Server.Port = envPort
@ -154,12 +159,29 @@ func Load() (*Config, error) {
if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" {
cfg.Providers.OpenAI.APIKey = envKey
}
if envAllow := os.Getenv("OPENAI_ALLOW_CLIENT_API_KEY"); envAllow != "" {
cfg.Providers.OpenAI.AllowClientAPIKey = envAllow == "true" || envAllow == "1"
}
if envHeader := os.Getenv("OPENAI_CLIENT_API_KEY_HEADER"); envHeader != "" {
cfg.Providers.OpenAI.ClientAPIKeyHeader = envHeader
}
// Override storage settings
if envPath := os.Getenv("DB_PATH"); envPath != "" {
cfg.Storage.DBPath = envPath
}
// Override CORS settings (comma-separated values)
if envOrigins := os.Getenv("CORS_ALLOWED_ORIGINS"); envOrigins != "" {
cfg.CORS.AllowedOrigins = splitAndTrim(envOrigins)
}
if envMethods := os.Getenv("CORS_ALLOWED_METHODS"); envMethods != "" {
cfg.CORS.AllowedMethods = splitAndTrim(envMethods)
}
if envHeaders := os.Getenv("CORS_ALLOWED_HEADERS"); envHeaders != "" {
cfg.CORS.AllowedHeaders = splitAndTrim(envHeaders)
}
// Sync legacy Anthropic config
cfg.Anthropic = AnthropicConfig{
BaseURL: cfg.Providers.Anthropic.BaseURL,
@ -203,6 +225,45 @@ func (c *Config) loadFromFile(path string) error {
return yaml.Unmarshal(data, c)
}
func candidateConfigPaths() []string {
paths := []string{
filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml"),
"config.yaml",
"../config.yaml",
"../../config.yaml",
}
seen := make(map[string]struct{}, len(paths))
unique := make([]string, 0, len(paths))
for _, path := range paths {
if _, ok := seen[path]; ok {
continue
}
seen[path] = struct{}{}
unique = append(unique, path)
}
return unique
}
func loadFirstAvailableConfig(cfg *Config, paths []string) error {
for _, path := range paths {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
continue
}
return fmt.Errorf("failed to stat config file %q: %w", path, err)
}
if err := cfg.loadFromFile(path); err != nil {
return fmt.Errorf("failed to load config file %q: %w", path, err)
}
return nil
}
return nil
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
@ -237,3 +298,15 @@ func getInt(key string, defaultValue int) int {
return intValue
}
func splitAndTrim(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}

View file

@ -0,0 +1,30 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadFirstAvailableConfigReturnsParseError(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("server: ["), 0o600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
cfg := &Config{}
if err := loadFirstAvailableConfig(cfg, []string{configPath}); err == nil {
t.Fatal("expected parse error, got nil")
}
}
func TestLoadFirstAvailableConfigSkipsMissingFiles(t *testing.T) {
cfg := &Config{}
if err := loadFirstAvailableConfig(cfg, []string{
filepath.Join(t.TempDir(), "missing.yaml"),
}); err != nil {
t.Fatalf("expected nil error for missing config, got %v", err)
}
}

View file

@ -1,7 +1,6 @@
package handler
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/hex"
@ -20,6 +19,7 @@ import (
"github.com/seifghazi/claude-code-monitor/internal/model"
"github.com/seifghazi/claude-code-monitor/internal/service"
"github.com/seifghazi/claude-code-monitor/internal/sse"
)
type Handler struct {
@ -179,37 +179,13 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
modelFilter = "all"
}
// Get all requests with model filter applied at storage level
allRequests, err := h.storageService.GetAllRequests(modelFilter)
requests, total, err := h.storageService.GetRequests(page, limit, modelFilter)
if err != nil {
log.Printf("Error getting requests: %v", err)
http.Error(w, "Failed to get requests", http.StatusInternalServerError)
return
}
// Convert pointers to values for consistency
requests := make([]model.RequestLog, len(allRequests))
for i, req := range allRequests {
if req != nil {
requests[i] = *req
}
}
// Calculate total before pagination
total := len(requests)
// Apply pagination
start := (page - 1) * limit
end := start + limit
if start >= len(requests) {
requests = []model.RequestLog{}
} else {
if end > len(requests) {
end = len(requests)
}
requests = requests[start:end]
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(struct {
Requests []model.RequestLog `json:"requests"`
@ -242,6 +218,8 @@ func (h *Handler) NotFound(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
// Forward important upstream headers (rate limits, request IDs, etc.)
ForwardResponseHeaders(w, resp)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
@ -278,16 +256,17 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
var messageID string
var modelName string
var stopReason string
var sawMessageStop bool
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
streamErr := sse.ForEachLine(resp.Body, func(line string) error {
if line == "" || !strings.HasPrefix(line, "data:") {
continue
return nil
}
streamingChunks = append(streamingChunks, line)
fmt.Fprintf(w, "%s\n\n", line)
if _, err := fmt.Fprintf(w, "%s\n\n", line); err != nil {
return err
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
@ -298,7 +277,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
var genericEvent map[string]interface{}
if err := json.Unmarshal([]byte(jsonData), &genericEvent); err != nil {
log.Printf("⚠️ Error unmarshalling streaming event: %v", err)
continue
return nil
}
// Capture metadata from message_start event
@ -347,7 +326,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
var event model.StreamingEvent
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
// Skip if structured parsing fails, but we already got the usage data above
continue
return nil
}
switch event.Type {
@ -366,8 +345,13 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
toolCalls = append(toolCalls, *event.ContentBlock)
}
case "message_stop":
// End of stream - scanner will exit on its own
sawMessageStop = true
}
return nil
})
if streamErr == nil && !sawMessageStop {
streamErr = io.ErrUnexpectedEOF
}
responseLog := &model.ResponseLog{
@ -378,6 +362,9 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
IsStreaming: true,
CompletedAt: time.Now().Format(time.RFC3339),
}
if streamErr != nil {
responseLog.StreamError = streamErr.Error()
}
// Create a structured response body that matches Anthropic's format
var contentBlocks []model.AnthropicContentBlock
@ -417,14 +404,31 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
log.Printf("❌ Error updating request with streaming response: %v", err)
}
if err := scanner.Err(); err != nil {
log.Printf("❌ Streaming error: %v", err)
if streamErr != nil {
log.Printf("❌ Streaming error: %v", streamErr)
// Send error event to client in Anthropic streaming format
errorEvent := map[string]interface{}{
"type": "error",
"error": map[string]interface{}{
"type": "stream_error",
"message": fmt.Sprintf("Stream interrupted: %v", streamErr),
},
}
if errorJSON, jsonErr := json.Marshal(errorEvent); jsonErr == nil {
fmt.Fprintf(w, "data: %s\n\n", errorJSON)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
} else {
log.Println("✅ Streaming response completed")
}
}
func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
// Forward important upstream headers (rate limits, request IDs, etc.)
ForwardResponseHeaders(w, resp)
responseBytes, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("❌ Error reading Anthropic response: %v", err)
@ -464,6 +468,7 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.R
if resp.StatusCode != http.StatusOK {
log.Printf("❌ Anthropic API error: %d %s", resp.StatusCode, string(responseBytes))
// Headers already forwarded at start of function
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(resp.StatusCode)
w.Write(responseBytes)

View file

@ -11,6 +11,68 @@ import (
"github.com/seifghazi/claude-code-monitor/internal/model"
)
// Headers that should be forwarded from upstream responses to clients
var forwardableResponseHeaders = []string{
// Anthropic rate limit headers
"anthropic-ratelimit-requests-limit",
"anthropic-ratelimit-requests-remaining",
"anthropic-ratelimit-requests-reset",
"anthropic-ratelimit-tokens-limit",
"anthropic-ratelimit-tokens-remaining",
"anthropic-ratelimit-tokens-reset",
// Standard rate limit headers
"x-ratelimit-limit",
"x-ratelimit-remaining",
"x-ratelimit-reset",
"retry-after",
// Request tracking
"x-request-id",
"request-id",
// Anthropic specific
"anthropic-organization-id",
// OpenAI specific
"openai-organization",
"openai-processing-ms",
"openai-version",
"x-request-id",
}
// ForwardResponseHeaders copies important headers from upstream response to client response
func ForwardResponseHeaders(w http.ResponseWriter, resp *http.Response) {
for _, header := range forwardableResponseHeaders {
if values := resp.Header.Values(header); len(values) > 0 {
for _, value := range values {
w.Header().Add(header, value)
}
}
}
}
// CopyAllResponseHeaders copies all non-hop-by-hop headers from upstream to client
func CopyAllResponseHeaders(w http.ResponseWriter, resp *http.Response) {
hopByHopHeaders := map[string]bool{
"connection": true,
"keep-alive": true,
"proxy-authenticate": true,
"proxy-authorization": true,
"te": true,
"trailers": true,
"transfer-encoding": true,
"upgrade": true,
"content-encoding": true, // We handle decompression ourselves
"content-length": true, // May change after decompression
}
for key, values := range resp.Header {
if hopByHopHeaders[strings.ToLower(key)] {
continue
}
for _, value := range values {
w.Header().Add(key, value)
}
}
}
// SanitizeHeaders removes sensitive headers before logging/storage
func SanitizeHeaders(headers http.Header) http.Header {
sanitized := make(http.Header)

View file

@ -45,6 +45,7 @@ type ResponseLog struct {
Headers map[string][]string `json:"headers"`
Body json.RawMessage `json:"body,omitempty"`
BodyText string `json:"bodyText,omitempty"`
StreamError string `json:"streamError,omitempty"`
ResponseTime int64 `json:"responseTime"`
StreamingChunks []string `json:"streamingChunks,omitempty"`
IsStreaming bool `json:"isStreaming"`

View file

@ -66,8 +66,14 @@ func (p *AnthropicProvider) ForwardRequest(ctx context.Context, originalReq *htt
proxyReq.Header.Set("anthropic-version", p.config.Version)
}
// Support gzip encoding
proxyReq.Header.Set("Accept-Encoding", "gzip")
// Handle Accept-Encoding: We accept gzip from upstream for efficiency,
// but we decompress before forwarding to the client. This is transparent
// to the client - they receive uncompressed data regardless of what they requested.
// We preserve gzip if client already requested it, otherwise add it.
clientEncoding := proxyReq.Header.Get("Accept-Encoding")
if clientEncoding == "" || !strings.Contains(clientEncoding, "gzip") {
proxyReq.Header.Set("Accept-Encoding", "gzip")
}
// Forward the request
resp, err := p.client.Do(proxyReq)

View file

@ -1,7 +1,6 @@
package provider
import (
"bufio"
"bytes"
"compress/gzip"
"context"
@ -15,6 +14,7 @@ import (
"github.com/seifghazi/claude-code-monitor/internal/config"
"github.com/seifghazi/claude-code-monitor/internal/model"
"github.com/seifghazi/claude-code-monitor/internal/sse"
)
type OpenAIProvider struct {
@ -78,13 +78,29 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
// Remove Anthropic-specific headers
proxyReq.Header.Del("anthropic-version")
proxyReq.Header.Del("x-api-key")
proxyReq.Header.Del("Authorization")
// Determine which API key to use
apiKey := p.config.APIKey
// Check for client-provided API key if allowed
if p.config.AllowClientAPIKey && p.config.ClientAPIKeyHeader != "" {
if clientKey := originalReq.Header.Get(p.config.ClientAPIKeyHeader); clientKey != "" {
apiKey = clientKey
}
}
// Add OpenAI headers
if p.config.APIKey != "" {
proxyReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
if apiKey != "" {
proxyReq.Header.Set("Authorization", "Bearer "+apiKey)
}
proxyReq.Header.Set("Content-Type", "application/json")
// Remove the client API key header from the proxied request
if p.config.ClientAPIKeyHeader != "" {
proxyReq.Header.Del(p.config.ClientAPIKeyHeader)
}
// Forward the request
resp, err := p.client.Do(proxyReq)
if err != nil {
@ -139,9 +155,12 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
// Start a goroutine to transform the stream
go func() {
defer pw.Close()
defer bodyReader.Close()
transformOpenAIStreamToAnthropic(bodyReader, pw)
if err := transformOpenAIStreamToAnthropic(bodyReader, pw); err != nil {
_ = pw.CloseWithError(err)
return
}
_ = pw.Close()
}()
// Replace the response body with our transformed stream
@ -332,10 +351,10 @@ func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{
})
}
}
// Check if max_tokens exceeds the model's limit and cap it if necessary
maxTokensLimit := 16384 // Assuming this is the limit for the model
if req.MaxTokens > maxTokensLimit {
// Capping max_tokens to model limit
// Get model-specific max token limit
// Let the API handle validation for unknown models rather than using arbitrary caps
maxTokensLimit := getModelMaxTokens(req.Model)
if maxTokensLimit > 0 && req.MaxTokens > maxTokensLimit {
req.MaxTokens = maxTokensLimit
}
@ -473,6 +492,52 @@ func min(a, b int) int {
return b
}
// getModelMaxTokens returns the max output tokens for known models
// Returns 0 for unknown models, letting the API handle validation
func getModelMaxTokens(model string) int {
// Model-specific max completion token limits
modelLimits := map[string]int{
// GPT-4 Turbo and GPT-4o models
"gpt-4-turbo": 4096,
"gpt-4-turbo-preview": 4096,
"gpt-4o": 16384,
"gpt-4o-mini": 16384,
"gpt-4o-2024-05-13": 16384,
"gpt-4o-2024-08-06": 16384,
// GPT-4 models
"gpt-4": 8192,
"gpt-4-32k": 8192,
"gpt-4-0613": 8192,
// GPT-3.5 models
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 4096,
"gpt-3.5-turbo-0125": 4096,
"gpt-3.5-turbo-1106": 4096,
// o1 reasoning models
"o1": 100000,
"o1-preview": 32768,
"o1-mini": 65536,
// o3 reasoning models (estimated based on o1 patterns)
"o3": 100000,
"o3-mini": 65536,
}
// Check for exact match first
if limit, ok := modelLimits[model]; ok {
return limit
}
// Check for prefix matches for versioned models
for prefix, limit := range modelLimits {
if strings.HasPrefix(model, prefix) {
return limit
}
}
// Return 0 for unknown models - let the API validate
return 0
}
func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
// This is a simplified transformation
// In production, you'd want to handle all fields properly
@ -579,19 +644,16 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
return result
}
func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
defer openAIStream.Close()
scanner := bufio.NewScanner(openAIStream)
func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
var messageStarted bool
var contentStarted bool
var sawDone bool
for scanner.Scan() {
line := scanner.Text()
err := sse.ForEachLine(openAIStream, func(line string) error {
// Skip empty lines
if line == "" {
continue
return nil
}
// Handle SSE data lines
@ -600,21 +662,28 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
// Handle end of stream
if data == "[DONE]" {
sawDone = true
// Send Anthropic-style completion
if contentStarted {
fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n"); err != nil {
return err
}
}
if messageStarted {
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n")
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n")
if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n"); err != nil {
return err
}
if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n"); err != nil {
return err
}
}
break
return nil
}
// Parse OpenAI response
var openAIChunk map[string]interface{}
if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil {
continue
return nil
}
// Check for usage data BEFORE processing choices
@ -644,7 +713,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
"usage": anthropicUsage,
}
usageJSON, _ := json.Marshal(usageDelta)
fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON)
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON); err != nil {
return err
}
}
}
@ -652,17 +723,17 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
choices, ok := openAIChunk["choices"].([]interface{})
if !ok || len(choices) == 0 {
// Skip further processing if no choices, but we already handled usage above
continue
return nil
}
choice, ok := choices[0].(map[string]interface{})
if !ok {
continue
return nil
}
delta, ok := choice["delta"].(map[string]interface{})
if !ok {
continue
return nil
}
// Handle first chunk - send message_start
@ -684,7 +755,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
},
}
startJSON, _ := json.Marshal(messageStart)
fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON)
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON); err != nil {
return err
}
}
// Handle content
@ -701,7 +774,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
},
}
blockStartJSON, _ := json.Marshal(blockStart)
fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON)
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON); err != nil {
return err
}
}
// Send content_block_delta
@ -714,9 +789,22 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
},
}
deltaJSON, _ := json.Marshal(contentDelta)
fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON)
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON); err != nil {
return err
}
}
}
return nil
})
if err != nil {
return err
}
if !sawDone {
return io.ErrUnexpectedEOF
}
return nil
}

View file

@ -0,0 +1,63 @@
package provider
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/seifghazi/claude-code-monitor/internal/config"
)
func TestOpenAIProviderForwardRequestClearsInboundAuthorization(t *testing.T) {
var gotAuthorization string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuthorization = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"resp_1","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"ok"}}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`))
}))
defer server.Close()
provider := NewOpenAIProvider(&config.OpenAIProviderConfig{
BaseURL: server.URL,
}).(*OpenAIProvider)
req := httptest.NewRequest(http.MethodPost, "http://proxy.local/v1/messages", strings.NewReader(`{"model":"gpt-4o","messages":[],"max_tokens":16}`))
req.Header.Set("Authorization", "Bearer should-not-leak")
req.Header.Set("Content-Type", "application/json")
resp, err := provider.ForwardRequest(context.Background(), req)
if err != nil {
t.Fatalf("ForwardRequest() error = %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if gotAuthorization != "" {
t.Fatalf("expected forwarded Authorization header to be empty, got %q", gotAuthorization)
}
}
func TestTransformOpenAIStreamToAnthropicHandlesLargeEvents(t *testing.T) {
largeContent := strings.Repeat("x", 128*1024)
openAIStream := strings.NewReader("data: {\"id\":\"chatcmpl_1\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"" + largeContent + "\"}}]}\n\n" +
"data: [DONE]\n\n")
var output strings.Builder
if err := transformOpenAIStreamToAnthropic(openAIStream, &output); err != nil {
t.Fatalf("transformOpenAIStreamToAnthropic() error = %v", err)
}
got := output.String()
if !strings.Contains(got, "\"message_start\"") {
t.Fatal("expected message_start event in output")
}
if !strings.Contains(got, largeContent) {
t.Fatal("expected large content to be preserved in output")
}
if !strings.Contains(got, "\"message_stop\"") {
t.Fatal("expected message_stop event in output")
}
}

View file

@ -1,18 +1,33 @@
package service
import (
"io"
"time"
"github.com/seifghazi/claude-code-monitor/internal/config"
"github.com/seifghazi/claude-code-monitor/internal/model"
)
type StorageService interface {
// Core CRUD operations
SaveRequest(request *model.RequestLog) (string, error)
GetRequests(page, limit int) ([]model.RequestLog, int, error)
GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error)
GetAllRequests(modelFilter string) ([]*model.RequestLog, error)
GetRequestByShortID(shortID string) (*model.RequestLog, string, error)
ClearRequests() (int, error)
// Update operations
UpdateRequestWithGrading(requestID string, grade *model.PromptGrade) error
UpdateRequestWithResponse(request *model.RequestLog) error
EnsureDirectoryExists() error
GetRequestByShortID(shortID string) (*model.RequestLog, string, error)
// Maintenance operations
DeleteRequestsOlderThan(age time.Duration) (int, error)
GetDatabaseStats() (map[string]interface{}, error)
// Configuration
GetConfig() *config.StorageConfig
GetAllRequests(modelFilter string) ([]*model.RequestLog, error)
EnsureDirectoryExists() error
// Cleanup - implements io.Closer
io.Closer
}

View file

@ -4,7 +4,9 @@ import (
"database/sql"
"encoding/json"
"fmt"
"log"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
@ -15,23 +17,63 @@ import (
type sqliteStorageService struct {
db *sql.DB
config *config.StorageConfig
logger *log.Logger
// Prepared statements for frequently used queries
stmtInsertRequest *sql.Stmt
stmtUpdateResponse *sql.Stmt
stmtUpdateGrading *sql.Stmt
stmtGetRequestByID *sql.Stmt
stmtGetRequestsPage *sql.Stmt
stmtGetRequestsCount *sql.Stmt
stmtDeleteOldRequests *sql.Stmt
}
func NewSQLiteStorageService(cfg *config.StorageConfig) (StorageService, error) {
db, err := sql.Open("sqlite3", cfg.DBPath)
return NewSQLiteStorageServiceWithLogger(cfg, log.Default())
}
func NewSQLiteStorageServiceWithLogger(cfg *config.StorageConfig, logger *log.Logger) (StorageService, error) {
// Enable WAL mode and other optimizations via connection string
// _journal_mode=WAL: Write-Ahead Logging for better concurrent read performance
// _synchronous=NORMAL: Good balance of safety and performance
// _busy_timeout=5000: Wait up to 5 seconds if database is locked
// _cache_size=-20000: Use 20MB of memory for cache (negative = KB)
connStr := cfg.DBPath + "?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000"
db, err := sql.Open("sqlite3", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Configure connection pool
// SQLite only supports one writer at a time, but can handle multiple readers
db.SetMaxOpenConns(1) // Serialize writes to avoid SQLITE_BUSY errors
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(time.Hour)
// Verify connection
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
service := &sqliteStorageService{
db: db,
config: cfg,
logger: logger,
}
if err := service.createTables(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to create tables: %w", err)
}
if err := service.prepareStatements(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to prepare statements: %w", err)
}
return service, nil
}
@ -39,7 +81,7 @@ func (s *sqliteStorageService) createTables() error {
schema := `
CREATE TABLE IF NOT EXISTS requests (
id TEXT PRIMARY KEY,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
timestamp DATETIME NOT NULL,
method TEXT NOT NULL,
endpoint TEXT NOT NULL,
headers TEXT NOT NULL,
@ -50,17 +92,100 @@ func (s *sqliteStorageService) createTables() error {
response TEXT,
model TEXT,
original_model TEXT,
routed_model TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
routed_model TEXT
);
CREATE INDEX IF NOT EXISTS idx_timestamp ON requests(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_endpoint ON requests(endpoint);
CREATE INDEX IF NOT EXISTS idx_model ON requests(model);
-- Index for listing requests by time (most common query)
CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC);
-- Index for filtering by model
CREATE INDEX IF NOT EXISTS idx_requests_model ON requests(model);
-- Index for filtering by endpoint
CREATE INDEX IF NOT EXISTS idx_requests_endpoint ON requests(endpoint);
`
_, err := s.db.Exec(schema)
return err
if err != nil {
return err
}
// Run migrations
s.migrateSchema()
return nil
}
func (s *sqliteStorageService) migrateSchema() {
// Ensure WAL mode is enabled (in case opened without connection string params)
_, err := s.db.Exec("PRAGMA journal_mode=WAL")
if err != nil {
s.logger.Printf("Warning: failed to set WAL mode: %v", err)
}
// Drop old redundant index if it exists (we renamed to idx_requests_timestamp)
s.db.Exec("DROP INDEX IF EXISTS idx_timestamp")
}
func (s *sqliteStorageService) prepareStatements() error {
var err error
s.stmtInsertRequest, err = s.db.Prepare(`
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`)
if err != nil {
return fmt.Errorf("failed to prepare insert statement: %w", err)
}
s.stmtUpdateResponse, err = s.db.Prepare(`
UPDATE requests SET response = ? WHERE id = ?
`)
if err != nil {
return fmt.Errorf("failed to prepare update response statement: %w", err)
}
s.stmtUpdateGrading, err = s.db.Prepare(`
UPDATE requests SET prompt_grade = ? WHERE id = ?
`)
if err != nil {
return fmt.Errorf("failed to prepare update grading statement: %w", err)
}
s.stmtGetRequestByID, err = s.db.Prepare(`
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
WHERE id = ?
`)
if err != nil {
return fmt.Errorf("failed to prepare get by ID statement: %w", err)
}
s.stmtGetRequestsPage, err = s.db.Prepare(`
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
ORDER BY timestamp DESC
LIMIT ? OFFSET ?
`)
if err != nil {
return fmt.Errorf("failed to prepare get requests page statement: %w", err)
}
s.stmtGetRequestsCount, err = s.db.Prepare(`
SELECT COUNT(*) FROM requests
`)
if err != nil {
return fmt.Errorf("failed to prepare count statement: %w", err)
}
s.stmtDeleteOldRequests, err = s.db.Prepare(`
DELETE FROM requests WHERE timestamp < ?
`)
if err != nil {
return fmt.Errorf("failed to prepare delete old requests statement: %w", err)
}
return nil
}
func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, error) {
@ -74,12 +199,7 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
return "", fmt.Errorf("failed to marshal body: %w", err)
}
query := `
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err = s.db.Exec(query,
_, err = s.stmtInsertRequest.Exec(
request.RequestID,
request.Timestamp,
request.Method,
@ -100,10 +220,24 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
return request.RequestID, nil
}
func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, int, error) {
func (s *sqliteStorageService) GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error) {
whereClause := ""
countArgs := []interface{}{}
queryArgs := []interface{}{}
if modelFilter != "" && modelFilter != "all" {
// Escape LIKE special characters to prevent pattern injection
escapedFilter := escapeLikePattern(strings.ToLower(modelFilter))
whereClause = " WHERE LOWER(model) LIKE ? ESCAPE '\\'"
filterValue := "%" + escapedFilter + "%"
countArgs = append(countArgs, filterValue)
queryArgs = append(queryArgs, filterValue)
}
// Get total count
var total int
err := s.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&total)
countQuery := "SELECT COUNT(*) FROM requests" + whereClause
err := s.db.QueryRow(countQuery, countArgs...).Scan(&total)
if err != nil {
return nil, 0, fmt.Errorf("failed to get total count: %w", err)
}
@ -112,71 +246,21 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
offset := (page - 1) * limit
query := `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
FROM requests` + whereClause + `
ORDER BY timestamp DESC
LIMIT ? OFFSET ?
`
queryArgs = append(queryArgs, limit, offset)
rows, err := s.db.Query(query, limit, offset)
rows, err := s.db.Query(query, queryArgs...)
if err != nil {
return nil, 0, fmt.Errorf("failed to query requests: %w", err)
}
defer rows.Close()
var requests []model.RequestLog
for rows.Next() {
var req model.RequestLog
var headersJSON, bodyJSON string
var promptGradeJSON, responseJSON sql.NullString
err := rows.Scan(
&req.RequestID,
&req.Timestamp,
&req.Method,
&req.Endpoint,
&headersJSON,
&bodyJSON,
&req.Model,
&req.UserAgent,
&req.ContentType,
&promptGradeJSON,
&responseJSON,
&req.OriginalModel,
&req.RoutedModel,
)
if err != nil {
// Error scanning row - skip
continue
}
// Unmarshal JSON fields
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
// Error unmarshaling headers
continue
}
var body interface{}
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
// Error unmarshaling body
continue
}
req.Body = body
if promptGradeJSON.Valid {
var grade model.PromptGrade
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
req.PromptGrade = &grade
}
}
if responseJSON.Valid {
var resp model.ResponseLog
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
req.Response = &resp
}
}
requests = append(requests, req)
requests, err := s.scanRequestRows(rows)
if err != nil {
return nil, 0, err
}
return requests, total, nil
@ -193,6 +277,12 @@ func (s *sqliteStorageService) ClearRequests() (int, error) {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
// Reclaim space after clearing all data
_, err = s.db.Exec("VACUUM")
if err != nil {
s.logger.Printf("Warning: failed to vacuum database: %v", err)
}
return int(rowsAffected), nil
}
@ -202,12 +292,16 @@ func (s *sqliteStorageService) UpdateRequestWithGrading(requestID string, grade
return fmt.Errorf("failed to marshal grade: %w", err)
}
query := "UPDATE requests SET prompt_grade = ? WHERE id = ?"
_, err = s.db.Exec(query, string(gradeJSON), requestID)
result, err := s.stmtUpdateGrading.Exec(string(gradeJSON), requestID)
if err != nil {
return fmt.Errorf("failed to update request with grading: %w", err)
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
return fmt.Errorf("request %s not found", requestID)
}
return nil
}
@ -217,12 +311,72 @@ func (s *sqliteStorageService) UpdateRequestWithResponse(request *model.RequestL
return fmt.Errorf("failed to marshal response: %w", err)
}
query := "UPDATE requests SET response = ? WHERE id = ?"
_, err = s.db.Exec(query, string(responseJSON), request.RequestID)
result, err := s.stmtUpdateResponse.Exec(string(responseJSON), request.RequestID)
if err != nil {
return fmt.Errorf("failed to update request with response: %w", err)
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
return fmt.Errorf("request %s not found", request.RequestID)
}
return nil
}
// SaveRequestWithResponse saves a request and its response in a single transaction
func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
headersJSON, err := json.Marshal(request.Headers)
if err != nil {
return fmt.Errorf("failed to marshal headers: %w", err)
}
bodyJSON, err := json.Marshal(request.Body)
if err != nil {
return fmt.Errorf("failed to marshal body: %w", err)
}
// Insert request
_, err = tx.Stmt(s.stmtInsertRequest).Exec(
request.RequestID,
request.Timestamp,
request.Method,
request.Endpoint,
string(headersJSON),
string(bodyJSON),
request.UserAgent,
request.ContentType,
request.Model,
request.OriginalModel,
request.RoutedModel,
)
if err != nil {
return fmt.Errorf("failed to insert request: %w", err)
}
// Update with response if present
if request.Response != nil {
responseJSON, err := json.Marshal(request.Response)
if err != nil {
return fmt.Errorf("failed to marshal response: %w", err)
}
_, err = tx.Stmt(s.stmtUpdateResponse).Exec(string(responseJSON), request.RequestID)
if err != nil {
return fmt.Errorf("failed to update response: %w", err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
@ -232,10 +386,13 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error {
}
func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
// Escape LIKE special characters to prevent pattern injection
escapedID := escapeLikePattern(shortID)
query := `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
WHERE id LIKE ?
WHERE id LIKE ? ESCAPE '\'
ORDER BY timestamp DESC
LIMIT 1
`
@ -244,7 +401,7 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
var headersJSON, bodyJSON string
var promptGradeJSON, responseJSON sql.NullString
err := s.db.QueryRow(query, "%"+shortID).Scan(
err := s.db.QueryRow(query, "%"+escapedID).Scan(
&req.RequestID,
&req.Timestamp,
&req.Method,
@ -267,29 +424,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
return nil, "", fmt.Errorf("failed to query request: %w", err)
}
// Unmarshal JSON fields
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
return nil, "", fmt.Errorf("failed to unmarshal headers: %w", err)
}
var body interface{}
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
}
req.Body = body
if promptGradeJSON.Valid {
var grade model.PromptGrade
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
req.PromptGrade = &grade
}
}
if responseJSON.Valid {
var resp model.ResponseLog
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
req.Response = &resp
}
if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
return nil, "", err
}
return &req, req.RequestID, nil
@ -300,19 +436,36 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig {
}
func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
query := `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
`
return s.GetAllRequestsWithLimit(modelFilter, 0) // 0 means no limit
}
// GetAllRequestsWithLimit returns requests with an optional limit (0 = no limit)
func (s *sqliteStorageService) GetAllRequestsWithLimit(modelFilter string, limit int) ([]*model.RequestLog, error) {
var query string
args := []interface{}{}
if modelFilter != "" && modelFilter != "all" {
query += " WHERE LOWER(model) LIKE ?"
args = append(args, "%"+strings.ToLower(modelFilter)+"%")
// Escape LIKE special characters
escapedFilter := escapeLikePattern(strings.ToLower(modelFilter))
query = `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
WHERE LOWER(model) LIKE ? ESCAPE '\'
ORDER BY timestamp DESC
`
args = append(args, "%"+escapedFilter+"%")
} else {
query = `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
FROM requests
ORDER BY timestamp DESC
`
}
query += " ORDER BY timestamp DESC"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
}
rows, err := s.db.Query(query, args...)
if err != nil {
@ -321,6 +474,124 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
defer rows.Close()
var requests []*model.RequestLog
for rows.Next() {
req, err := s.scanSingleRow(rows)
if err != nil {
s.logger.Printf("Warning: failed to scan request row: %v", err)
continue
}
requests = append(requests, req)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating rows: %w", err)
}
return requests, nil
}
// DeleteRequestsOlderThan removes requests older than the specified duration
func (s *sqliteStorageService) DeleteRequestsOlderThan(age time.Duration) (int, error) {
cutoff := time.Now().Add(-age)
result, err := s.stmtDeleteOldRequests.Exec(cutoff.Format(time.RFC3339))
if err != nil {
return 0, fmt.Errorf("failed to delete old requests: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// GetDatabaseStats returns statistics about the database
func (s *sqliteStorageService) GetDatabaseStats() (map[string]interface{}, error) {
stats := make(map[string]interface{})
// Get row count
var count int
err := s.stmtGetRequestsCount.QueryRow().Scan(&count)
if err != nil {
return nil, fmt.Errorf("failed to get count: %w", err)
}
stats["total_requests"] = count
// Get database size
var pageCount, pageSize int
err = s.db.QueryRow("PRAGMA page_count").Scan(&pageCount)
if err == nil {
err = s.db.QueryRow("PRAGMA page_size").Scan(&pageSize)
if err == nil {
stats["database_size_bytes"] = pageCount * pageSize
}
}
// Get oldest and newest timestamps
var oldest, newest sql.NullString
err = s.db.QueryRow("SELECT MIN(timestamp), MAX(timestamp) FROM requests").Scan(&oldest, &newest)
if err == nil {
if oldest.Valid {
stats["oldest_request"] = oldest.String
}
if newest.Valid {
stats["newest_request"] = newest.String
}
}
return stats, nil
}
func (s *sqliteStorageService) Close() error {
// Close prepared statements
if s.stmtInsertRequest != nil {
s.stmtInsertRequest.Close()
}
if s.stmtUpdateResponse != nil {
s.stmtUpdateResponse.Close()
}
if s.stmtUpdateGrading != nil {
s.stmtUpdateGrading.Close()
}
if s.stmtGetRequestByID != nil {
s.stmtGetRequestByID.Close()
}
if s.stmtGetRequestsPage != nil {
s.stmtGetRequestsPage.Close()
}
if s.stmtGetRequestsCount != nil {
s.stmtGetRequestsCount.Close()
}
if s.stmtDeleteOldRequests != nil {
s.stmtDeleteOldRequests.Close()
}
// Checkpoint WAL before closing
_, err := s.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)")
if err != nil {
s.logger.Printf("Warning: failed to checkpoint WAL: %v", err)
}
return s.db.Close()
}
// Helper functions
// escapeLikePattern escapes special characters in LIKE patterns
func escapeLikePattern(s string) string {
// Escape \, %, and _ characters
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return s
}
// scanRequestRows scans multiple rows into a slice of RequestLog
func (s *sqliteStorageService) scanRequestRows(rows *sql.Rows) ([]model.RequestLog, error) {
var requests []model.RequestLog
for rows.Next() {
var req model.RequestLog
var headersJSON, bodyJSON string
@ -342,43 +613,86 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
&req.RoutedModel,
)
if err != nil {
// Error scanning row - skip
s.logger.Printf("Warning: failed to scan row: %v", err)
continue
}
// Unmarshal JSON fields
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
// Error unmarshaling headers
if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
s.logger.Printf("Warning: failed to unmarshal request fields: %v", err)
continue
}
var body interface{}
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
// Error unmarshaling body
continue
}
req.Body = body
requests = append(requests, req)
}
if promptGradeJSON.Valid {
var grade model.PromptGrade
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
req.PromptGrade = &grade
}
}
if responseJSON.Valid {
var resp model.ResponseLog
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
req.Response = &resp
}
}
requests = append(requests, &req)
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating rows: %w", err)
}
return requests, nil
}
func (s *sqliteStorageService) Close() error {
return s.db.Close()
// scanSingleRow scans a single row into a RequestLog pointer
func (s *sqliteStorageService) scanSingleRow(rows *sql.Rows) (*model.RequestLog, error) {
var req model.RequestLog
var headersJSON, bodyJSON string
var promptGradeJSON, responseJSON sql.NullString
err := rows.Scan(
&req.RequestID,
&req.Timestamp,
&req.Method,
&req.Endpoint,
&headersJSON,
&bodyJSON,
&req.Model,
&req.UserAgent,
&req.ContentType,
&promptGradeJSON,
&responseJSON,
&req.OriginalModel,
&req.RoutedModel,
)
if err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
return nil, err
}
return &req, nil
}
// unmarshalRequestFields unmarshals JSON fields into a RequestLog
func (s *sqliteStorageService) unmarshalRequestFields(req *model.RequestLog, headersJSON, bodyJSON string, promptGradeJSON, responseJSON sql.NullString) error {
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
return fmt.Errorf("failed to unmarshal headers: %w", err)
}
var body interface{}
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
return fmt.Errorf("failed to unmarshal body: %w", err)
}
req.Body = body
if promptGradeJSON.Valid {
var grade model.PromptGrade
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err != nil {
s.logger.Printf("Warning: failed to unmarshal prompt grade: %v", err)
} else {
req.PromptGrade = &grade
}
}
if responseJSON.Valid {
var resp model.ResponseLog
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err != nil {
s.logger.Printf("Warning: failed to unmarshal response: %v", err)
} else {
req.Response = &resp
}
}
return nil
}

View file

@ -0,0 +1,69 @@
package service
import (
"fmt"
"path/filepath"
"testing"
"time"
"github.com/seifghazi/claude-code-monitor/internal/config"
"github.com/seifghazi/claude-code-monitor/internal/model"
)
func TestSQLiteStorageServiceGetRequestsUsesSQLPaginationAndFiltering(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "requests.db")
storage, err := NewSQLiteStorageService(&config.StorageConfig{DBPath: dbPath})
if err != nil {
t.Fatalf("NewSQLiteStorageService() error = %v", err)
}
sqliteStorage, ok := storage.(*sqliteStorageService)
if !ok {
t.Fatalf("unexpected storage type %T", storage)
}
defer sqliteStorage.Close()
requests := []struct {
id string
model string
}{
{id: "1", model: "claude-3-5-sonnet"},
{id: "2", model: "gpt-4o"},
{id: "3", model: "claude-3-5-sonnet"},
{id: "4", model: "gpt-4o-mini"},
}
for i, req := range requests {
_, err := storage.SaveRequest(&model.RequestLog{
RequestID: req.id,
Timestamp: time.Date(2026, 3, 19, 12, 0, i, 0, time.UTC).Format(time.RFC3339),
Method: "POST",
Endpoint: "/v1/messages",
Headers: map[string][]string{"Content-Type": {"application/json"}},
Body: map[string]string{"request": fmt.Sprintf("body-%d", i)},
Model: req.model,
UserAgent: "test",
ContentType: "application/json",
})
if err != nil {
t.Fatalf("SaveRequest() error = %v", err)
}
}
got, total, err := storage.GetRequests(1, 1, "gpt")
if err != nil {
t.Fatalf("GetRequests() error = %v", err)
}
if total != 2 {
t.Fatalf("expected filtered total 2, got %d", total)
}
if len(got) != 1 {
t.Fatalf("expected 1 paginated result, got %d", len(got))
}
if got[0].RequestID != "4" {
t.Fatalf("expected newest filtered request ID 4, got %s", got[0].RequestID)
}
}

30
proxy/internal/sse/sse.go Normal file
View file

@ -0,0 +1,30 @@
package sse
import (
"bufio"
"io"
"strings"
)
// ForEachLine reads line-oriented SSE content without bufio.Scanner's token limit.
func ForEachLine(r io.Reader, fn func(string) error) error {
reader := bufio.NewReader(r)
for {
line, err := reader.ReadString('\n')
if err != nil && err != io.EOF {
return err
}
if len(line) > 0 {
line = strings.TrimRight(line, "\r\n")
if callErr := fn(line); callErr != nil {
return callErr
}
}
if err == io.EOF {
return nil
}
}
}

View file

@ -0,0 +1,31 @@
package sse
import (
"strings"
"testing"
)
func TestForEachLineHandlesLargeLines(t *testing.T) {
largePayload := strings.Repeat("x", 128*1024)
input := "data: " + largePayload + "\n\n"
var lines []string
if err := ForEachLine(strings.NewReader(input), func(line string) error {
lines = append(lines, line)
return nil
}); err != nil {
t.Fatalf("ForEachLine() error = %v", err)
}
if len(lines) != 2 {
t.Fatalf("expected 2 lines, got %d", len(lines))
}
if lines[0] != "data: "+largePayload {
t.Fatalf("unexpected first line length: got %d", len(lines[0]))
}
if lines[1] != "" {
t.Fatalf("expected blank separator line, got %q", lines[1])
}
}