From b9da198e1f75c35317f4b489a665b3ddbcd1896a Mon Sep 17 00:00:00 2001 From: sid Date: Thu, 19 Mar 2026 19:00:24 -0600 Subject: [PATCH] Harden proxy auth, storage, and conversation access --- .env.example | 18 +- README.md | 73 +++- config.yaml.example | 75 +++- proxy/cmd/proxy/main.go | 29 +- proxy/internal/config/config.go | 194 +++++++++-- proxy/internal/config/config_test.go | 107 ++++++ proxy/internal/middleware/auth.go | 83 +++++ proxy/internal/middleware/auth_test.go | 126 +++++++ proxy/internal/service/conversation.go | 157 ++++++++- proxy/internal/service/conversation_test.go | 107 ++++++ proxy/internal/service/storage_sqlite.go | 189 +++++++++- proxy/internal/service/storage_sqlite_test.go | 325 +++++++++++++++--- 12 files changed, 1362 insertions(+), 121 deletions(-) create mode 100644 proxy/internal/middleware/auth.go create mode 100644 proxy/internal/middleware/auth_test.go create mode 100644 proxy/internal/service/conversation_test.go diff --git a/.env.example b/.env.example index eddb829..85cba1b 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,7 @@ # Claude Code Monitor Configuration # Server Configuration +SERVER_HOST=127.0.0.1 PORT=3001 READ_TIMEOUT=500 WRITE_TIMEOUT=500 @@ -18,10 +19,21 @@ ANTHROPIC_MAX_RETRIES=3 # OPENAI_ALLOW_CLIENT_API_KEY=false # OPENAI_CLIENT_API_KEY_HEADER=x-openai-api-key +# Auth Configuration +# AUTH_ENABLED=false +# AUTH_TOKEN=change-me +# AUTH_API_KEY_HEADER=x-api-key +# AUTH_ALLOW_LOCALHOST_BYPASS=true + # Storage Configuration DB_PATH=requests.db +STORAGE_CAPTURE_REQUEST_BODY=true +STORAGE_CAPTURE_RESPONSE_BODY=true +STORAGE_METADATA_ONLY=false +STORAGE_RETENTION_DAYS=0 +# STORAGE_REDACTED_FIELDS=api_key,authorization,token,password,secret,access_token,refresh_token,client_secret # CORS Configuration (comma-separated values) -# CORS_ALLOWED_ORIGINS=* -# CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS -# CORS_ALLOWED_HEADERS=* \ No newline at end of file +# CORS_ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000,http://localhost:5173,http://127.0.0.1:5173 +# CORS_ALLOWED_METHODS=GET,POST,DELETE,OPTIONS +# CORS_ALLOWED_HEADERS=Accept,Authorization,Content-Type,Anthropic-Version,Anthropic-Beta,X-API-Key,X-Requested-With diff --git a/README.md b/README.md index 91c5be9..4fcdf22 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,13 @@ Claude Code Proxy serves three main purposes: - **Conversation Analysis**: View full conversation threads with tool usage - **Easy Setup**: One-command startup for both services +## Security Defaults + +- The proxy binds to `127.0.0.1` by default for local-only access. +- CORS defaults are restricted to localhost origins. +- If you want to expose the proxy on a public interface, you must set `AUTH_ENABLED=true` and provide `AUTH_TOKEN`. +- When auth is enabled, the proxy accepts either `Authorization: Bearer ` or `X-API-Key: `. + ## Quick Start ### Prerequisites @@ -75,16 +82,25 @@ Claude Code Proxy serves three main purposes: # Build the image docker build -t claude-code-proxy . - # Run with default settings - docker run -p 3001:3001 -p 5173:5173 claude-code-proxy + # Run locally without publishing ports + docker run claude-code-proxy + + # Run with published ports + docker run -p 3001:3001 -p 5173:5173 \ + -e SERVER_HOST=0.0.0.0 \ + -e AUTH_ENABLED=true \ + -e AUTH_TOKEN=change-me \ + claude-code-proxy ``` 4. **Run with persistent data and custom configuration** ```bash # Create a data directory for persistent SQLite database mkdir -p ./data - + # Option 1: Run with config file (recommended) + # If you expose the container with `-p`, set server.host to 0.0.0.0 + # and enable auth in the mounted config file. docker run -p 3001:3001 -p 5173:5173 \ -v ./data:/app/data \ -v ./config.yaml:/app/config.yaml:ro \ @@ -93,9 +109,11 @@ Claude Code Proxy serves three main purposes: # Option 2: Run with environment variables docker run -p 3001:3001 -p 5173:5173 \ -v ./data:/app/data \ + -e SERVER_HOST=0.0.0.0 \ -e ANTHROPIC_FORWARD_URL=https://api.anthropic.com \ + -e AUTH_ENABLED=true \ + -e AUTH_TOKEN=change-me \ -e PORT=3001 \ - -e WEB_PORT=5173 \ claude-code-proxy ``` @@ -113,9 +131,11 @@ Claude Code Proxy serves three main purposes: - ./data:/app/data - ./config.yaml:/app/config.yaml:ro # Mount config file environment: + - SERVER_HOST=0.0.0.0 - ANTHROPIC_FORWARD_URL=https://api.anthropic.com + - AUTH_ENABLED=true + - AUTH_TOKEN=change-me - PORT=3001 - - WEB_PORT=5173 - DB_PATH=/app/data/requests.db ``` @@ -169,6 +189,7 @@ make help # Show all commands Create a `config.yaml` file (or copy from `config.yaml.example`): ```yaml server: + host: 127.0.0.1 port: 3001 providers: @@ -180,6 +201,32 @@ providers: storage: db_path: "requests.db" + +auth: + enabled: false + token: "" +``` + +### Auth + +To expose the proxy beyond localhost, enable auth and provide a token: + +```yaml +auth: + enabled: true + token: "change-me" +``` + +Then send either: + +```bash +curl -H "Authorization: Bearer change-me" http://localhost:3001/v1/models +``` + +or: + +```bash +curl -H "X-API-Key: change-me" http://localhost:3001/v1/models ``` ### Subagent Configuration (Optional) @@ -241,6 +288,11 @@ Use case: Different specialists for different tasks, optimizing for speed/cost/q Override config via environment: - `PORT` - Server port +- `SERVER_HOST` - Server bind host +- `AUTH_ENABLED` - Enable auth for non-health endpoints +- `AUTH_TOKEN` - Shared auth secret +- `AUTH_API_KEY_HEADER` - Header name for API key auth +- `AUTH_ALLOW_LOCALHOST_BYPASS` - Allow localhost requests to bypass auth - `OPENAI_API_KEY` - OpenAI API key - `DB_PATH` - Database path - `SUBAGENT_MAPPINGS` - Comma-separated mappings (e.g., `"code-reviewer:gpt-4o,data-analyst:o3"`) @@ -251,22 +303,27 @@ All environment variables can be configured when running the Docker container: | Variable | Default | Description | |----------|---------|-------------| +| `SERVER_HOST` | `127.0.0.1` | Proxy bind host | | `PORT` | `3001` | Proxy server port | -| `WEB_PORT` | `5173` | Web dashboard port | | `READ_TIMEOUT` | `600` | Server read timeout (seconds) | | `WRITE_TIMEOUT` | `600` | Server write timeout (seconds) | | `IDLE_TIMEOUT` | `600` | Server idle timeout (seconds) | | `ANTHROPIC_FORWARD_URL` | `https://api.anthropic.com` | Target Anthropic API URL | | `ANTHROPIC_VERSION` | `2023-06-01` | Anthropic API version | | `ANTHROPIC_MAX_RETRIES` | `3` | Maximum retry attempts | +| `AUTH_ENABLED` | `false` | Enable auth for non-health endpoints | +| `AUTH_TOKEN` | `""` | Shared auth token | +| `AUTH_API_KEY_HEADER` | `x-api-key` | Header name for API-key style auth | +| `AUTH_ALLOW_LOCALHOST_BYPASS` | `true` | Allow loopback requests to bypass auth | | `DB_PATH` | `/app/data/requests.db` | SQLite database path | Example with custom configuration: ```bash docker run -p 3001:3001 -p 5173:5173 \ -v ./data:/app/data \ - -e PORT=8080 \ - -e WEB_PORT=3000 \ + -e SERVER_HOST=0.0.0.0 \ + -e AUTH_ENABLED=true \ + -e AUTH_TOKEN=change-me \ -e ANTHROPIC_FORWARD_URL=https://api.anthropic.com \ -e DB_PATH=/app/data/custom.db \ claude-code-proxy diff --git a/config.yaml.example b/config.yaml.example index 34338a8..bd820e6 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -4,6 +4,10 @@ # Server configuration server: + # Bind host for the proxy server. + # Defaults to 127.0.0.1 for local-only access. + host: 127.0.0.1 + # Port to listen on (default: 3001) port: 3001 @@ -49,29 +53,76 @@ providers: # CORS Configuration # Controls Cross-Origin Resource Sharing for the web UI cors: - # Allowed origins (use ["*"] for all origins - not recommended for production) + # Allowed origins. Defaults are localhost-only. # Can also be set via CORS_ALLOWED_ORIGINS environment variable (comma-separated) allowed_origins: - - "*" + - "http://localhost:3000" + - "http://127.0.0.1:3000" + - "http://localhost:5173" + - "http://127.0.0.1:5173" # Allowed HTTP methods # Can also be set via CORS_ALLOWED_METHODS environment variable (comma-separated) allowed_methods: - "GET" - "POST" - - "PUT" - "DELETE" - "OPTIONS" - # Allowed headers (use ["*"] for all headers) + # Allowed headers # Can also be set via CORS_ALLOWED_HEADERS environment variable (comma-separated) allowed_headers: - - "*" + - "Accept" + - "Authorization" + - "Content-Type" + - "Anthropic-Version" + - "Anthropic-Beta" + - "X-API-Key" + - "X-Requested-With" + +# Auth Configuration +# When enabled, all non-health endpoints require bearer token or X-API-Key auth. +auth: + # Enable auth for non-health endpoints + # Public/non-loopback binds must enable auth and set a token. + enabled: false + + # Shared secret used for Authorization: Bearer or X-API-Key: + token: "" + + # Header name used for API-key style auth + api_key_header: "x-api-key" + + # Allow requests from localhost to bypass auth when enabled + allow_localhost_bypass: true # Storage configuration storage: # SQLite database path for storing request history db_path: "requests.db" + + # Keep request bodies in storage. Disable for metadata-only tracking. + capture_request_body: true + + # Keep response bodies and streaming chunks in storage. + capture_response_body: true + + # Store only request/response metadata, not payload bodies. + metadata_only: false + + # Delete records older than this many days on write. 0 disables cleanup. + retention_days: 0 + + # JSON payload fields to redact before storage. + redacted_fields: + - api_key + - authorization + - token + - password + - secret + - access_token + - refresh_token + - client_secret # Directory for storing request files (if needed in future) # requests_dir: "./requests" @@ -99,6 +150,7 @@ subagents: # The following environment variables will override the YAML configuration: # # Server: +# SERVER_HOST - Bind host (default: 127.0.0.1) # PORT - Server port # READ_TIMEOUT - Read timeout duration # WRITE_TIMEOUT - Write timeout duration @@ -115,8 +167,19 @@ subagents: # OPENAI_ALLOW_CLIENT_API_KEY - Allow client-provided API keys (true/false) # OPENAI_CLIENT_API_KEY_HEADER - Header name for client API key # +# Auth: +# AUTH_ENABLED - Enable auth for non-health endpoints (true/false) +# AUTH_TOKEN - Shared secret for bearer / API-key auth +# AUTH_API_KEY_HEADER - Header name for API-key style auth +# AUTH_ALLOW_LOCALHOST_BYPASS - Allow loopback requests to bypass auth (true/false) +# # Storage: # DB_PATH - Database file path +# STORAGE_CAPTURE_REQUEST_BODY - Keep request bodies (true/false) +# STORAGE_CAPTURE_RESPONSE_BODY - Keep response bodies (true/false) +# STORAGE_METADATA_ONLY - Store metadata only (true/false) +# STORAGE_RETENTION_DAYS - Delete rows older than N days +# STORAGE_REDACTED_FIELDS - Comma-separated payload fields to redact # # CORS: # CORS_ALLOWED_ORIGINS - Comma-separated allowed origins @@ -125,4 +188,4 @@ subagents: # # Subagents: # SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs -# Example: "code-reviewer:claude-3-5-sonnet" \ No newline at end of file +# Example: "code-reviewer:claude-3-5-sonnet" diff --git a/proxy/cmd/proxy/main.go b/proxy/cmd/proxy/main.go index d5ef215..698b009 100644 --- a/proxy/cmd/proxy/main.go +++ b/proxy/cmd/proxy/main.go @@ -3,6 +3,7 @@ package main import ( "context" "log" + "net" "net/http" "os" "os/signal" @@ -56,6 +57,7 @@ func main() { ) r.Use(middleware.Logging) + r.Use(middleware.Auth(cfg.Auth)) r.HandleFunc("/v1/chat/completions", h.ChatCompletions).Methods("POST") r.HandleFunc("/v1/messages", h.Messages).Methods("POST") @@ -73,7 +75,7 @@ func main() { r.NotFoundHandler = http.HandlerFunc(h.NotFound) srv := &http.Server{ - Addr: ":" + cfg.Server.Port, + Addr: net.JoinHostPort(cfg.Server.Host, cfg.Server.Port), Handler: corsHandler(r), ReadTimeout: cfg.Server.ReadTimeout, WriteTimeout: cfg.Server.WriteTimeout, @@ -81,14 +83,19 @@ func main() { } go func() { - logger.Printf("🚀 Claude Code Monitor Server running on http://localhost:%s", cfg.Server.Port) + logger.Printf("🚀 Claude Code Monitor Server running on http://%s", srv.Addr) logger.Printf("📡 API endpoints available at:") - logger.Printf(" - POST http://localhost:%s/v1/messages (Anthropic format)", cfg.Server.Port) - logger.Printf(" - GET http://localhost:%s/v1/models", cfg.Server.Port) - logger.Printf(" - GET http://localhost:%s/health", cfg.Server.Port) + logger.Printf(" - POST http://%s/v1/messages (Anthropic format)", srv.Addr) + logger.Printf(" - GET http://%s/v1/models", srv.Addr) + logger.Printf(" - GET http://%s/health", srv.Addr) logger.Printf("🎨 Web UI available at:") - logger.Printf(" - GET http://localhost:%s/ (Request Visualizer)", cfg.Server.Port) - logger.Printf(" - GET http://localhost:%s/api/requests (Request API)", cfg.Server.Port) + logger.Printf(" - GET http://%s/ (Request Visualizer)", srv.Addr) + logger.Printf(" - GET http://%s/api/requests (Request API)", srv.Addr) + if cfg.Auth.Enabled { + logger.Printf("🔐 Auth enabled using bearer token or %s", cfg.Auth.APIKeyHeader) + } else { + logger.Printf("🔓 Auth disabled for local-only access") + } if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatalf("❌ Server failed to start: %v", err) @@ -105,7 +112,13 @@ func main() { defer cancel() if err := srv.Shutdown(ctx); err != nil { - logger.Fatalf("❌ Server forced to shutdown: %v", err) + logger.Printf("❌ Server forced to shutdown: %v", err) + } + + // Close storage service (checkpoints WAL, closes prepared statements) + logger.Println("🗄️ Closing database...") + if err := storageService.Close(); err != nil { + logger.Printf("❌ Error closing storage: %v", err) } logger.Println("✅ Server exited") diff --git a/proxy/internal/config/config.go b/proxy/internal/config/config.go index 1e5b110..45d5f79 100644 --- a/proxy/internal/config/config.go +++ b/proxy/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net" "os" "path/filepath" "strconv" @@ -17,6 +18,7 @@ type Config struct { Providers ProvidersConfig `yaml:"providers"` Storage StorageConfig `yaml:"storage"` Subagents SubagentsConfig `yaml:"subagents"` + Auth AuthConfig `yaml:"auth"` CORS CORSConfig `yaml:"cors"` Anthropic AnthropicConfig } @@ -28,6 +30,7 @@ type CORSConfig struct { } type ServerConfig struct { + Host string `yaml:"host"` Port string `yaml:"port"` Timeouts TimeoutsConfig `yaml:"timeouts"` // Legacy fields @@ -60,6 +63,13 @@ type OpenAIProviderConfig struct { ClientAPIKeyHeader string `yaml:"client_api_key_header"` // Header name for client API key (default: x-openai-api-key) } +type AuthConfig struct { + Enabled bool `yaml:"enabled"` + Token string `yaml:"token"` + APIKeyHeader string `yaml:"api_key_header"` + AllowLocalhostBypass bool `yaml:"allow_localhost_bypass"` +} + type AnthropicConfig struct { BaseURL string Version string @@ -67,8 +77,13 @@ type AnthropicConfig struct { } type StorageConfig struct { - RequestsDir string `yaml:"requests_dir"` - DBPath string `yaml:"db_path"` + RequestsDir string `yaml:"requests_dir"` + DBPath string `yaml:"db_path"` + CaptureRequestBody bool `yaml:"capture_request_body"` + CaptureResponseBody bool `yaml:"capture_response_body"` + MetadataOnly bool `yaml:"metadata_only"` + RetentionDays int `yaml:"retention_days"` + RedactedFields []string `yaml:"redacted_fields"` } type SubagentsConfig struct { @@ -88,40 +103,7 @@ func Load() (*Config, error) { } } - // Start with default configuration - cfg := &Config{ - Server: ServerConfig{ - Port: "3001", - ReadTimeout: 600 * time.Second, - WriteTimeout: 600 * time.Second, - IdleTimeout: 600 * time.Second, - }, - Providers: ProvidersConfig{ - Anthropic: AnthropicProviderConfig{ - BaseURL: "https://api.anthropic.com", - Version: "2023-06-01", - MaxRetries: 3, - }, - OpenAI: OpenAIProviderConfig{ - BaseURL: "https://api.openai.com", - APIKey: "", - AllowClientAPIKey: false, - ClientAPIKeyHeader: "x-openai-api-key", - }, - }, - Storage: StorageConfig{ - DBPath: "requests.db", - }, - Subagents: SubagentsConfig{ - Enable: false, - Mappings: make(map[string]string), - }, - CORS: CORSConfig{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"*"}, - }, - } + cfg := defaultConfig() if err := loadFirstAvailableConfig(cfg, candidateConfigPaths()); err != nil { return nil, err @@ -131,6 +113,9 @@ func Load() (*Config, error) { if envPort := os.Getenv("PORT"); envPort != "" { cfg.Server.Port = envPort } + if envHost := os.Getenv("SERVER_HOST"); envHost != "" { + cfg.Server.Host = envHost + } if envTimeout := os.Getenv("READ_TIMEOUT"); envTimeout != "" { cfg.Server.ReadTimeout = getDuration("READ_TIMEOUT", cfg.Server.ReadTimeout) } @@ -166,10 +151,44 @@ func Load() (*Config, error) { cfg.Providers.OpenAI.ClientAPIKeyHeader = envHeader } + // Override auth settings + if envAuthEnabled := os.Getenv("AUTH_ENABLED"); envAuthEnabled != "" { + cfg.Auth.Enabled = envAuthEnabled == "true" || envAuthEnabled == "1" + } + if envAuthToken := os.Getenv("AUTH_TOKEN"); envAuthToken != "" { + cfg.Auth.Token = envAuthToken + } + if envAPIKeyHeader := os.Getenv("AUTH_API_KEY_HEADER"); envAPIKeyHeader != "" { + cfg.Auth.APIKeyHeader = envAPIKeyHeader + } + if envLocalBypass := os.Getenv("AUTH_ALLOW_LOCALHOST_BYPASS"); envLocalBypass != "" { + cfg.Auth.AllowLocalhostBypass = envLocalBypass == "true" || envLocalBypass == "1" + } + // Override storage settings if envPath := os.Getenv("DB_PATH"); envPath != "" { cfg.Storage.DBPath = envPath } + if envCaptureReq := os.Getenv("STORAGE_CAPTURE_REQUEST_BODY"); envCaptureReq != "" { + cfg.Storage.CaptureRequestBody = envCaptureReq == "true" || envCaptureReq == "1" + } + if envCaptureResp := os.Getenv("STORAGE_CAPTURE_RESPONSE_BODY"); envCaptureResp != "" { + cfg.Storage.CaptureResponseBody = envCaptureResp == "true" || envCaptureResp == "1" + } + if envMetadataOnly := os.Getenv("STORAGE_METADATA_ONLY"); envMetadataOnly != "" { + cfg.Storage.MetadataOnly = envMetadataOnly == "true" || envMetadataOnly == "1" + } + if envRetentionDays := os.Getenv("STORAGE_RETENTION_DAYS"); envRetentionDays != "" { + cfg.Storage.RetentionDays = getInt("STORAGE_RETENTION_DAYS", cfg.Storage.RetentionDays) + } + if envRedacted := os.Getenv("STORAGE_REDACTED_FIELDS"); envRedacted != "" { + cfg.Storage.RedactedFields = splitAndTrim(envRedacted) + } + + if cfg.Storage.MetadataOnly { + cfg.Storage.CaptureRequestBody = false + cfg.Storage.CaptureResponseBody = false + } // Override CORS settings (comma-separated values) if envOrigins := os.Getenv("CORS_ALLOWED_ORIGINS"); envOrigins != "" { @@ -213,6 +232,10 @@ func Load() (*Config, error) { MaxRetries: cfg.Providers.Anthropic.MaxRetries, } + if err := validateSecurity(cfg); err != nil { + return nil, err + } + return cfg, nil } @@ -225,6 +248,76 @@ func (c *Config) loadFromFile(path string) error { return yaml.Unmarshal(data, c) } +func defaultConfig() *Config { + return &Config{ + Server: ServerConfig{ + Host: "127.0.0.1", + Port: "3001", + ReadTimeout: 600 * time.Second, + WriteTimeout: 600 * time.Second, + IdleTimeout: 600 * time.Second, + }, + Providers: ProvidersConfig{ + Anthropic: AnthropicProviderConfig{ + BaseURL: "https://api.anthropic.com", + Version: "2023-06-01", + MaxRetries: 3, + }, + OpenAI: OpenAIProviderConfig{ + BaseURL: "https://api.openai.com", + APIKey: "", + AllowClientAPIKey: false, + ClientAPIKeyHeader: "x-openai-api-key", + }, + }, + Storage: StorageConfig{ + DBPath: "requests.db", + CaptureRequestBody: true, + CaptureResponseBody: true, + MetadataOnly: false, + RetentionDays: 0, + RedactedFields: []string{ + "api_key", + "authorization", + "token", + "password", + "secret", + "access_token", + "refresh_token", + "client_secret", + }, + }, + Subagents: SubagentsConfig{ + Enable: false, + Mappings: make(map[string]string), + }, + Auth: AuthConfig{ + Enabled: false, + Token: "", + APIKeyHeader: "x-api-key", + AllowLocalhostBypass: true, + }, + CORS: CORSConfig{ + AllowedOrigins: []string{ + "http://localhost:3000", + "http://127.0.0.1:3000", + "http://localhost:5173", + "http://127.0.0.1:5173", + }, + AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{ + "Accept", + "Authorization", + "Content-Type", + "Anthropic-Version", + "Anthropic-Beta", + "X-API-Key", + "X-Requested-With", + }, + }, + } +} + func candidateConfigPaths() []string { paths := []string{ filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml"), @@ -246,6 +339,35 @@ func candidateConfigPaths() []string { return unique } +func validateSecurity(cfg *Config) error { + if cfg.Server.Host == "" { + cfg.Server.Host = "127.0.0.1" + } + + if !isLoopbackHost(cfg.Server.Host) && !cfg.Auth.Enabled { + return fmt.Errorf("refusing to bind to %q without auth enabled; set AUTH_ENABLED=true and AUTH_TOKEN for public access", cfg.Server.Host) + } + + if cfg.Auth.Enabled && cfg.Auth.Token == "" && !isLoopbackHost(cfg.Server.Host) { + return fmt.Errorf("auth is enabled for public access but AUTH_TOKEN is empty") + } + + return nil +} + +func isLoopbackHost(host string) bool { + host = strings.TrimSpace(host) + if host == "localhost" { + return true + } + + if ip := net.ParseIP(strings.Trim(host, "[]")); ip != nil { + return ip.IsLoopback() + } + + return false +} + func loadFirstAvailableConfig(cfg *Config, paths []string) error { for _, path := range paths { if _, err := os.Stat(path); err != nil { diff --git a/proxy/internal/config/config_test.go b/proxy/internal/config/config_test.go index afd0a8f..84e4cfa 100644 --- a/proxy/internal/config/config_test.go +++ b/proxy/internal/config/config_test.go @@ -6,6 +6,70 @@ import ( "testing" ) +func TestDefaultConfigIncludesStorageControls(t *testing.T) { + cfg := defaultConfig() + + if !cfg.Storage.CaptureRequestBody { + t.Fatal("expected request body capture to be enabled by default") + } + if !cfg.Storage.CaptureResponseBody { + t.Fatal("expected response body capture to be enabled by default") + } + if cfg.Storage.MetadataOnly { + t.Fatal("expected metadata-only mode to be disabled by default") + } + if cfg.Storage.RetentionDays != 0 { + t.Fatalf("expected retention to be disabled by default, got %d", cfg.Storage.RetentionDays) + } + if len(cfg.Storage.RedactedFields) == 0 { + t.Fatal("expected default redacted field list to be populated") + } +} + +func TestLoadFromFileParsesStorageControls(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + yaml := ` +storage: + db_path: /tmp/claude.db + capture_request_body: false + capture_response_body: false + metadata_only: true + retention_days: 7 + redacted_fields: + - api_key + - secret +` + if err := os.WriteFile(configPath, []byte(yaml), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + cfg := defaultConfig() + if err := cfg.loadFromFile(configPath); err != nil { + t.Fatalf("loadFromFile() error = %v", err) + } + + if cfg.Storage.DBPath != "/tmp/claude.db" { + t.Fatalf("unexpected db path %q", cfg.Storage.DBPath) + } + if !cfg.Storage.MetadataOnly { + t.Fatal("expected metadata-only mode to load from file") + } + if cfg.Storage.CaptureRequestBody { + t.Fatal("expected request body capture to load as disabled") + } + if cfg.Storage.CaptureResponseBody { + t.Fatal("expected response body capture to load as disabled") + } + if cfg.Storage.RetentionDays != 7 { + t.Fatalf("unexpected retention days %d", cfg.Storage.RetentionDays) + } + if len(cfg.Storage.RedactedFields) != 2 { + t.Fatalf("unexpected redacted field count %d", len(cfg.Storage.RedactedFields)) + } +} + func TestLoadFirstAvailableConfigReturnsParseError(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "config.yaml") @@ -28,3 +92,46 @@ func TestLoadFirstAvailableConfigSkipsMissingFiles(t *testing.T) { t.Fatalf("expected nil error for missing config, got %v", err) } } + +func TestDefaultConfigUsesLoopbackAndLocalCors(t *testing.T) { + cfg := defaultConfig() + + if cfg.Server.Host != "127.0.0.1" { + t.Fatalf("expected loopback host, got %q", cfg.Server.Host) + } + + if cfg.Auth.Enabled { + t.Fatal("expected auth to be disabled by default for local development") + } + + if len(cfg.CORS.AllowedOrigins) == 0 { + t.Fatal("expected local CORS origins to be configured") + } + + for _, origin := range cfg.CORS.AllowedOrigins { + if origin == "*" { + t.Fatal("expected wildcard origin to be removed from defaults") + } + } +} + +func TestValidateSecurityRejectsPublicBindWithoutAuth(t *testing.T) { + cfg := defaultConfig() + cfg.Server.Host = "0.0.0.0" + cfg.Auth.Enabled = false + + if err := validateSecurity(cfg); err == nil { + t.Fatal("expected validation error for public bind without auth") + } +} + +func TestValidateSecurityAllowsPublicBindWithAuthToken(t *testing.T) { + cfg := defaultConfig() + cfg.Server.Host = "0.0.0.0" + cfg.Auth.Enabled = true + cfg.Auth.Token = "secret" + + if err := validateSecurity(cfg); err != nil { + t.Fatalf("expected public bind with auth token to be allowed, got %v", err) + } +} diff --git a/proxy/internal/middleware/auth.go b/proxy/internal/middleware/auth.go new file mode 100644 index 0000000..95ab9d2 --- /dev/null +++ b/proxy/internal/middleware/auth.go @@ -0,0 +1,83 @@ +package middleware + +import ( + "encoding/json" + "net" + "net/http" + "strings" + + "github.com/seifghazi/claude-code-monitor/internal/config" +) + +func Auth(cfg config.AuthConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions || r.URL.Path == "/health" { + next.ServeHTTP(w, r) + return + } + + if !cfg.Enabled { + next.ServeHTTP(w, r) + return + } + + if cfg.AllowLocalhostBypass && isLocalhostRequest(r.RemoteAddr) { + next.ServeHTTP(w, r) + return + } + + if token, ok := extractAuthToken(r, cfg); ok && token == cfg.Token { + next.ServeHTTP(w, r) + return + } + + w.Header().Set("WWW-Authenticate", `Bearer realm="claude-code-proxy"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized", + }) + }) + } +} + +func extractAuthToken(r *http.Request, cfg config.AuthConfig) (string, bool) { + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + if authHeader != "" { + const bearerPrefix = "Bearer " + if len(authHeader) > len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) { + return strings.TrimSpace(authHeader[len(bearerPrefix):]), true + } + } + + if cfg.APIKeyHeader != "" { + if headerValue := strings.TrimSpace(r.Header.Get(cfg.APIKeyHeader)); headerValue != "" { + return headerValue, true + } + } + + // Accept the common X-API-Key header even if callers customize the config. + if cfg.APIKeyHeader != "X-API-Key" && cfg.APIKeyHeader != "x-api-key" { + if headerValue := strings.TrimSpace(r.Header.Get("X-API-Key")); headerValue != "" { + return headerValue, true + } + } + + return "", false +} + +func isLocalhostRequest(remoteAddr string) bool { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + host = remoteAddr + } + + host = strings.TrimSpace(strings.Trim(host, "[]")) + if host == "localhost" { + return true + } + + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} diff --git a/proxy/internal/middleware/auth_test.go b/proxy/internal/middleware/auth_test.go new file mode 100644 index 0000000..fe25558 --- /dev/null +++ b/proxy/internal/middleware/auth_test.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/seifghazi/claude-code-monitor/internal/config" +) + +func TestAuthAllowsLocalhostBypass(t *testing.T) { + called := false + handler := Auth(config.AuthConfig{ + Enabled: true, + Token: "secret", + APIKeyHeader: "X-API-Key", + AllowLocalhostBypass: true, + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "http://example.local/v1/messages", nil) + req.RemoteAddr = "127.0.0.1:45678" + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if !called { + t.Fatal("expected localhost request to bypass auth") + } +} + +func TestAuthRejectsMissingCredentials(t *testing.T) { + handler := Auth(config.AuthConfig{ + Enabled: true, + Token: "secret", + APIKeyHeader: "X-API-Key", + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "http://example.local/v1/messages", nil) + req.RemoteAddr = "10.1.2.3:45678" + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rr.Code) + } +} + +func TestAuthAcceptsBearerAndAPIKey(t *testing.T) { + testCases := []struct { + name string + setup func(*http.Request) + header string + }{ + { + name: "bearer", + setup: func(r *http.Request) { + r.Header.Set("Authorization", "Bearer secret") + }, + header: "Authorization", + }, + { + name: "api-key", + setup: func(r *http.Request) { + r.Header.Set("X-API-Key", "secret") + }, + header: "X-API-Key", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + called := false + handler := Auth(config.AuthConfig{ + Enabled: true, + Token: "secret", + APIKeyHeader: "X-API-Key", + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "http://example.local/v1/messages", nil) + req.RemoteAddr = "10.1.2.3:45678" + tc.setup(req) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if !called { + t.Fatalf("expected %s auth to pass", tc.header) + } + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + }) + } +} + +func TestAuthSkipsHealthAndOptions(t *testing.T) { + handler := Auth(config.AuthConfig{ + Enabled: true, + Token: "secret", + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.local/health", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected health request to bypass auth, got %d", rr.Code) + } + + req = httptest.NewRequest(http.MethodOptions, "http://example.local/v1/messages", nil) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected OPTIONS request to bypass auth, got %d", rr.Code) + } +} diff --git a/proxy/internal/service/conversation.go b/proxy/internal/service/conversation.go index 830f8a5..ddf36e9 100644 --- a/proxy/internal/service/conversation.go +++ b/proxy/internal/service/conversation.go @@ -59,8 +59,12 @@ type Conversation struct { func (cs *conversationService) GetConversations() (map[string][]*Conversation, error) { conversations := make(map[string][]*Conversation) var parseErrors []string + rootPath, err := cs.projectsRoot() + if err != nil { + return nil, fmt.Errorf("failed to resolve claude projects root: %w", err) + } - err := filepath.Walk(cs.claudeProjectsPath, func(path string, info os.FileInfo, err error) error { + err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error { if err != nil { // Log but don't fail the entire walk parseErrors = append(parseErrors, fmt.Sprintf("Error accessing %s: %v", path, err)) @@ -71,16 +75,27 @@ func (cs *conversationService) GetConversations() (map[string][]*Conversation, e return nil } - // Get the project path relative to claudeProjectsPath - projectDir := filepath.Dir(path) - projectRelPath, _ := filepath.Rel(cs.claudeProjectsPath, projectDir) + // Reject symlinked files or paths that escape the projects root. + resolvedPath, err := cs.resolveExistingPathWithinProjectsRoot(path) + if err != nil { + parseErrors = append(parseErrors, fmt.Sprintf("Skipping %s: %v", path, err)) + return nil + } + + // Get the project path relative to the resolved root. + projectDir := filepath.Dir(resolvedPath) + projectRelPath, err := filepath.Rel(rootPath, projectDir) + if err != nil { + parseErrors = append(parseErrors, fmt.Sprintf("Skipping %s: %v", path, err)) + return nil + } // Skip files directly in the projects directory if projectRelPath == "." || projectRelPath == "" { return nil } - conv, err := cs.parseConversationFile(path, projectRelPath) + conv, err := cs.parseConversationFile(resolvedPath, projectRelPath) if err != nil { // Log parsing errors but continue processing other files parseErrors = append(parseErrors, fmt.Sprintf("Failed to parse %s: %v", path, err)) @@ -113,9 +128,12 @@ func (cs *conversationService) GetConversations() (map[string][]*Conversation, e // GetConversation returns a specific conversation by project and session ID func (cs *conversationService) GetConversation(projectPath, sessionID string) (*Conversation, error) { - filePath := filepath.Join(cs.claudeProjectsPath, projectPath, sessionID+".jsonl") + filePath, resolvedProjectPath, err := cs.resolveConversationFile(projectPath, sessionID) + if err != nil { + return nil, fmt.Errorf("failed to resolve conversation path: %w", err) + } - conv, err := cs.parseConversationFile(filePath, projectPath) + conv, err := cs.parseConversationFile(filePath, resolvedProjectPath) if err != nil { return nil, fmt.Errorf("failed to parse conversation: %w", err) } @@ -126,7 +144,10 @@ func (cs *conversationService) GetConversation(projectPath, sessionID string) (* // GetConversationsByProject returns all conversations for a specific project func (cs *conversationService) GetConversationsByProject(projectPath string) ([]*Conversation, error) { var conversations []*Conversation - projectDir := filepath.Join(cs.claudeProjectsPath, projectPath) + projectDir, resolvedProjectPath, err := cs.resolveProjectDir(projectPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve project path: %w", err) + } files, err := os.ReadDir(projectDir) if err != nil { @@ -139,7 +160,7 @@ func (cs *conversationService) GetConversationsByProject(projectPath string) ([] } filePath := filepath.Join(projectDir, file.Name()) - conv, err := cs.parseConversationFile(filePath, projectPath) + conv, err := cs.parseConversationFile(filePath, resolvedProjectPath) if err != nil { continue } @@ -157,6 +178,124 @@ func (cs *conversationService) GetConversationsByProject(projectPath string) ([] return conversations, nil } +func (cs *conversationService) projectsRoot() (string, error) { + root, err := filepath.Abs(cs.claudeProjectsPath) + if err != nil { + return "", fmt.Errorf("failed to make projects root absolute: %w", err) + } + + resolvedRoot, err := filepath.EvalSymlinks(root) + if err != nil { + if os.IsNotExist(err) { + return root, nil + } + return "", fmt.Errorf("failed to resolve projects root symlinks: %w", err) + } + + return resolvedRoot, nil +} + +func (cs *conversationService) resolveProjectDir(projectPath string) (string, string, error) { + cleanedProjectPath, err := cleanRelativeConversationPath(projectPath) + if err != nil { + return "", "", err + } + + rootPath, err := cs.projectsRoot() + if err != nil { + return "", "", err + } + + candidate := filepath.Join(rootPath, cleanedProjectPath) + resolvedCandidate, err := cs.resolveExistingPathWithinProjectsRoot(candidate) + if err != nil { + return "", "", err + } + + return resolvedCandidate, cleanedProjectPath, nil +} + +func (cs *conversationService) resolveConversationFile(projectPath, sessionID string) (string, string, error) { + if sessionID == "" { + return "", "", fmt.Errorf("session ID is required") + } + + if sessionID != filepath.Base(sessionID) || sessionID == "." || sessionID == ".." { + return "", "", fmt.Errorf("invalid session ID: %s", sessionID) + } + + projectDir, cleanedProjectPath, err := cs.resolveProjectDir(projectPath) + if err != nil { + return "", "", err + } + + candidate := filepath.Join(projectDir, sessionID+".jsonl") + resolvedCandidate, err := cs.resolveExistingPathWithinProjectsRoot(candidate) + if err != nil { + return "", "", err + } + + return resolvedCandidate, cleanedProjectPath, nil +} + +func (cs *conversationService) resolveExistingPathWithinProjectsRoot(path string) (string, error) { + rootPath, err := cs.projectsRoot() + if err != nil { + return "", err + } + + absolutePath, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("failed to make path absolute: %w", err) + } + + normalizedPath := filepath.Clean(absolutePath) + if !pathWithinRoot(normalizedPath, rootPath) { + return "", fmt.Errorf("path escapes projects root: %s", path) + } + + resolvedPath, err := filepath.EvalSymlinks(normalizedPath) + if err != nil { + return "", fmt.Errorf("failed to resolve path symlinks: %w", err) + } + + if !pathWithinRoot(resolvedPath, rootPath) { + return "", fmt.Errorf("path escapes projects root after symlink resolution: %s", path) + } + + return resolvedPath, nil +} + +func cleanRelativeConversationPath(p string) (string, error) { + if p == "" { + return "", fmt.Errorf("path is required") + } + + if filepath.IsAbs(p) { + return "", fmt.Errorf("absolute paths are not allowed: %s", p) + } + + cleaned := filepath.Clean(p) + if cleaned == "." || cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("path escapes projects root: %s", p) + } + + return cleaned, nil +} + +func pathWithinRoot(candidatePath, rootPath string) bool { + relPath, err := filepath.Rel(rootPath, candidatePath) + if err != nil { + return false + } + + if relPath == "." { + return true + } + + return relPath != ".." && !strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) +} + // parseConversationFile reads and parses a JSONL conversation file func (cs *conversationService) parseConversationFile(filePath, projectPath string) (*Conversation, error) { // Get file info for modification time diff --git a/proxy/internal/service/conversation_test.go b/proxy/internal/service/conversation_test.go new file mode 100644 index 0000000..4c79cca --- /dev/null +++ b/proxy/internal/service/conversation_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "os" + "path/filepath" + "testing" +) + +func TestConversationServiceAllowsNestedProjectPaths(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "team", "app") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + sessionPath := filepath.Join(projectDir, "session.jsonl") + if err := os.WriteFile(sessionPath, []byte(`{"timestamp":"2026-03-19T12:00:00Z","type":"user","message":"hello"}`+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + svc := &conversationService{claudeProjectsPath: root} + + conversation, err := svc.GetConversation("team/app", "session") + if err != nil { + t.Fatalf("GetConversation() error = %v", err) + } + + if conversation.SessionID != "session" { + t.Fatalf("expected session ID %q, got %q", "session", conversation.SessionID) + } + + if conversation.ProjectPath != "team/app" { + t.Fatalf("expected project path %q, got %q", "team/app", conversation.ProjectPath) + } + + if len(conversation.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(conversation.Messages)) + } + + conversations, err := svc.GetConversationsByProject("team/app") + if err != nil { + t.Fatalf("GetConversationsByProject() error = %v", err) + } + + if len(conversations) != 1 { + t.Fatalf("expected 1 conversation, got %d", len(conversations)) + } +} + +func TestConversationServiceRejectsTraversalPaths(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "team", "app") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + sessionPath := filepath.Join(projectDir, "session.jsonl") + if err := os.WriteFile(sessionPath, []byte(`{"timestamp":"2026-03-19T12:00:00Z","type":"user","message":"hello"}`+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + svc := &conversationService{claudeProjectsPath: root} + + if _, err := svc.GetConversation("../outside", "session"); err == nil { + t.Fatal("expected traversal project path to be rejected") + } + + if _, err := svc.GetConversation("team/app", "../session"); err == nil { + t.Fatal("expected traversal session ID to be rejected") + } + + if _, err := svc.GetConversationsByProject("../../outside"); err == nil { + t.Fatal("expected traversal project listing to be rejected") + } +} + +func TestConversationServiceRejectsSymlinkEscapes(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "team") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + outsideDir := filepath.Join(t.TempDir(), "outside") + if err := os.MkdirAll(outsideDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + if err := os.WriteFile(filepath.Join(outsideDir, "session.jsonl"), []byte(`{"timestamp":"2026-03-19T12:00:00Z","type":"user","message":"hello"}`+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + linkPath := filepath.Join(projectDir, "app") + if err := os.Symlink(outsideDir, linkPath); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + svc := &conversationService{claudeProjectsPath: root} + + if _, err := svc.GetConversation("team/app", "session"); err == nil { + t.Fatal("expected symlink escape to be rejected") + } + + if _, err := svc.GetConversationsByProject("team/app"); err == nil { + t.Fatal("expected symlink project listing to be rejected") + } +} diff --git a/proxy/internal/service/storage_sqlite.go b/proxy/internal/service/storage_sqlite.go index e8aa062..a9e00db 100644 --- a/proxy/internal/service/storage_sqlite.go +++ b/proxy/internal/service/storage_sqlite.go @@ -74,6 +74,10 @@ func NewSQLiteStorageServiceWithLogger(cfg *config.StorageConfig, logger *log.Lo return nil, fmt.Errorf("failed to prepare statements: %w", err) } + if err := service.cleanupExpiredRequests(); err != nil { + logger.Printf("Warning: failed to apply retention policy during startup: %v", err) + } + return service, nil } @@ -194,7 +198,12 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e return "", fmt.Errorf("failed to marshal headers: %w", err) } - bodyJSON, err := json.Marshal(request.Body) + bodyForStorage, err := s.prepareRequestBodyForStorage(request.Body) + if err != nil { + return "", fmt.Errorf("failed to prepare body for storage: %w", err) + } + + bodyJSON, err := json.Marshal(bodyForStorage) if err != nil { return "", fmt.Errorf("failed to marshal body: %w", err) } @@ -217,6 +226,10 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e return "", fmt.Errorf("failed to insert request: %w", err) } + if err := s.cleanupExpiredRequests(); err != nil { + s.logger.Printf("Warning: failed to apply retention policy: %v", err) + } + return request.RequestID, nil } @@ -302,11 +315,20 @@ func (s *sqliteStorageService) UpdateRequestWithGrading(requestID string, grade return fmt.Errorf("request %s not found", requestID) } + if err := s.cleanupExpiredRequests(); err != nil { + s.logger.Printf("Warning: failed to apply retention policy: %v", err) + } + return nil } func (s *sqliteStorageService) UpdateRequestWithResponse(request *model.RequestLog) error { - responseJSON, err := json.Marshal(request.Response) + responseForStorage, err := s.prepareResponseForStorage(request.Response) + if err != nil { + return fmt.Errorf("failed to prepare response for storage: %w", err) + } + + responseJSON, err := json.Marshal(responseForStorage) if err != nil { return fmt.Errorf("failed to marshal response: %w", err) } @@ -337,7 +359,12 @@ func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog return fmt.Errorf("failed to marshal headers: %w", err) } - bodyJSON, err := json.Marshal(request.Body) + bodyForStorage, err := s.prepareRequestBodyForStorage(request.Body) + if err != nil { + return fmt.Errorf("failed to prepare body for storage: %w", err) + } + + bodyJSON, err := json.Marshal(bodyForStorage) if err != nil { return fmt.Errorf("failed to marshal body: %w", err) } @@ -362,7 +389,12 @@ func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog // Update with response if present if request.Response != nil { - responseJSON, err := json.Marshal(request.Response) + responseForStorage, err := s.prepareResponseForStorage(request.Response) + if err != nil { + return fmt.Errorf("failed to prepare response for storage: %w", err) + } + + responseJSON, err := json.Marshal(responseForStorage) if err != nil { return fmt.Errorf("failed to marshal response: %w", err) } @@ -377,6 +409,10 @@ func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog return fmt.Errorf("failed to commit transaction: %w", err) } + if err := s.cleanupExpiredRequests(); err != nil { + s.logger.Printf("Warning: failed to apply retention policy: %v", err) + } + return nil } @@ -579,6 +615,8 @@ func (s *sqliteStorageService) Close() error { // Helper functions +const redactionPlaceholder = "[REDACTED]" + // escapeLikePattern escapes special characters in LIKE patterns func escapeLikePattern(s string) string { // Escape \, %, and _ characters @@ -696,3 +734,146 @@ func (s *sqliteStorageService) unmarshalRequestFields(req *model.RequestLog, hea return nil } + +func (s *sqliteStorageService) cleanupExpiredRequests() error { + if s.config == nil || s.config.RetentionDays <= 0 { + return nil + } + + _, err := s.DeleteRequestsOlderThan(time.Duration(s.config.RetentionDays) * 24 * time.Hour) + return err +} + +func (s *sqliteStorageService) prepareRequestBodyForStorage(body interface{}) (interface{}, error) { + if s.shouldSuppressBodies() { + return storageBodyPlaceholder("metadata_only"), nil + } + if s.config != nil && !s.config.CaptureRequestBody { + return storageBodyPlaceholder("request_body_disabled"), nil + } + + normalized, err := normalizeJSONValue(body) + if err != nil { + return nil, err + } + + fields := []string{} + if s.config != nil { + fields = s.config.RedactedFields + } + + return redactJSONValue(normalized, redactedFieldSet(fields)), nil +} + +func (s *sqliteStorageService) prepareResponseForStorage(response *model.ResponseLog) (*model.ResponseLog, error) { + if response == nil { + return nil, nil + } + + clone := *response + if s.shouldSuppressBodies() || (s.config != nil && !s.config.CaptureResponseBody) { + clone.Body = nil + clone.BodyText = "" + clone.StreamingChunks = nil + return &clone, nil + } + + if len(clone.Body) > 0 { + fields := []string{} + if s.config != nil { + fields = s.config.RedactedFields + } + + sanitizedBody, err := sanitizeRawJSON(clone.Body, redactedFieldSet(fields)) + if err != nil { + // Preserve the original payload if it cannot be parsed as JSON. + s.logger.Printf("Warning: failed to redact response body: %v", err) + } else { + clone.Body = sanitizedBody + } + } + + return &clone, nil +} + +func (s *sqliteStorageService) shouldSuppressBodies() bool { + return s.config != nil && s.config.MetadataOnly +} + +func normalizeJSONValue(value interface{}) (interface{}, error) { + if value == nil { + return nil, nil + } + + data, err := json.Marshal(value) + if err != nil { + return nil, err + } + + var normalized interface{} + if err := json.Unmarshal(data, &normalized); err != nil { + return nil, err + } + + return normalized, nil +} + +func sanitizeRawJSON(raw json.RawMessage, redacted map[string]struct{}) (json.RawMessage, error) { + if len(raw) == 0 { + return raw, nil + } + + var value interface{} + if err := json.Unmarshal(raw, &value); err != nil { + return raw, err + } + + sanitized := redactJSONValue(value, redacted) + data, err := json.Marshal(sanitized) + if err != nil { + return raw, err + } + + return json.RawMessage(data), nil +} + +func redactJSONValue(value interface{}, redacted map[string]struct{}) interface{} { + switch typed := value.(type) { + case map[string]interface{}: + result := make(map[string]interface{}, len(typed)) + for key, child := range typed { + if _, ok := redacted[strings.ToLower(key)]; ok { + result[key] = redactionPlaceholder + continue + } + result[key] = redactJSONValue(child, redacted) + } + return result + case []interface{}: + result := make([]interface{}, len(typed)) + for i, child := range typed { + result[i] = redactJSONValue(child, redacted) + } + return result + default: + return value + } +} + +func storageBodyPlaceholder(mode string) map[string]interface{} { + return map[string]interface{}{ + "_storage_mode": mode, + } +} + +func redactedFieldSet(fields []string) map[string]struct{} { + set := make(map[string]struct{}, len(fields)) + for _, field := range fields { + field = strings.TrimSpace(strings.ToLower(field)) + if field == "" { + continue + } + set[field] = struct{}{} + } + return set +} diff --git a/proxy/internal/service/storage_sqlite_test.go b/proxy/internal/service/storage_sqlite_test.go index 258cef2..982dbdf 100644 --- a/proxy/internal/service/storage_sqlite_test.go +++ b/proxy/internal/service/storage_sqlite_test.go @@ -1,7 +1,7 @@ package service import ( - "fmt" + "encoding/json" "path/filepath" "testing" "time" @@ -10,9 +10,276 @@ import ( "github.com/seifghazi/claude-code-monitor/internal/model" ) -func TestSQLiteStorageServiceGetRequestsUsesSQLPaginationAndFiltering(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "requests.db") - storage, err := NewSQLiteStorageService(&config.StorageConfig{DBPath: dbPath}) +func TestSQLiteStorageServiceRedactsRequestAndResponseBodies(t *testing.T) { + storage := newTestSQLiteStorage(t, config.StorageConfig{ + DBPath: filepath.Join(t.TempDir(), "requests.db"), + CaptureRequestBody: true, + CaptureResponseBody: true, + RedactedFields: []string{"api_key", "secret"}, + }) + + request := &model.RequestLog{ + RequestID: "redact-123", + Timestamp: time.Now().UTC().Format(time.RFC3339), + Method: "POST", + Endpoint: "/v1/messages", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: map[string]interface{}{ + "api_key": "abc123", + "nested": map[string]interface{}{ + "secret": "top-secret", + "visible": "ok", + }, + }, + Model: "claude-3-5-sonnet", + UserAgent: "test", + ContentType: "application/json", + } + + if _, err := storage.SaveRequest(request); err != nil { + t.Fatalf("SaveRequest() error = %v", err) + } + + request.Response = &model.ResponseLog{ + StatusCode: httpStatusOK, + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: json.RawMessage(`{"secret":"response-secret","visible":"yes"}`), + ResponseTime: 12, + CompletedAt: time.Now().UTC().Format(time.RFC3339), + } + + if err := storage.UpdateRequestWithResponse(request); err != nil { + t.Fatalf("UpdateRequestWithResponse() error = %v", err) + } + + got, _, err := storage.GetRequestByShortID("123") + if err != nil { + t.Fatalf("GetRequestByShortID() error = %v", err) + } + + body, ok := got.Body.(map[string]interface{}) + if !ok { + t.Fatalf("expected request body to be a map, got %T", got.Body) + } + if body["api_key"] != redactionPlaceholder { + t.Fatalf("expected api_key to be redacted, got %#v", body["api_key"]) + } + + nested, ok := body["nested"].(map[string]interface{}) + if !ok { + t.Fatalf("expected nested body to be a map, got %T", body["nested"]) + } + if nested["secret"] != redactionPlaceholder { + t.Fatalf("expected nested secret to be redacted, got %#v", nested["secret"]) + } + if nested["visible"] != "ok" { + t.Fatalf("expected visible field to remain, got %#v", nested["visible"]) + } + + if got.Response == nil || len(got.Response.Body) == 0 { + t.Fatal("expected response body to be stored") + } + + var responseBody map[string]interface{} + if err := json.Unmarshal(got.Response.Body, &responseBody); err != nil { + t.Fatalf("response body unmarshal failed: %v", err) + } + if responseBody["secret"] != redactionPlaceholder { + t.Fatalf("expected response secret to be redacted, got %#v", responseBody["secret"]) + } + if responseBody["visible"] != "yes" { + t.Fatalf("expected response visible field to remain, got %#v", responseBody["visible"]) + } +} + +func TestSQLiteStorageServiceHonorsMetadataOnlyMode(t *testing.T) { + storage := newTestSQLiteStorage(t, config.StorageConfig{ + DBPath: filepath.Join(t.TempDir(), "requests.db"), + CaptureRequestBody: true, + CaptureResponseBody: true, + MetadataOnly: true, + }) + + request := &model.RequestLog{ + RequestID: "metadata-123", + Timestamp: time.Now().UTC().Format(time.RFC3339), + Method: "POST", + Endpoint: "/v1/messages", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: map[string]interface{}{ + "message": "keep me out of storage", + }, + Model: "claude-3-5-sonnet", + UserAgent: "test", + ContentType: "application/json", + Response: &model.ResponseLog{ + StatusCode: httpStatusOK, + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: json.RawMessage(`{"answer":"secret"}`), + StreamingChunks: []string{ + "data: chunk-1", + }, + ResponseTime: 10, + CompletedAt: time.Now().UTC().Format(time.RFC3339), + }, + } + + if _, err := storage.SaveRequest(request); err != nil { + t.Fatalf("SaveRequest() error = %v", err) + } + if err := storage.UpdateRequestWithResponse(request); err != nil { + t.Fatalf("UpdateRequestWithResponse() error = %v", err) + } + + got, _, err := storage.GetRequestByShortID("123") + if err != nil { + t.Fatalf("GetRequestByShortID() error = %v", err) + } + + body, ok := got.Body.(map[string]interface{}) + if !ok { + t.Fatalf("expected metadata-only body placeholder map, got %T", got.Body) + } + if body["_storage_mode"] != "metadata_only" { + t.Fatalf("expected metadata-only placeholder, got %#v", body["_storage_mode"]) + } + + if got.Response == nil { + t.Fatal("expected response log to exist") + } + if len(got.Response.Body) != 0 { + t.Fatalf("expected response body to be removed, got %s", string(got.Response.Body)) + } + if got.Response.BodyText != "" { + t.Fatalf("expected response body text to be removed, got %q", got.Response.BodyText) + } + if len(got.Response.StreamingChunks) != 0 { + t.Fatalf("expected streaming chunks to be removed, got %d", len(got.Response.StreamingChunks)) + } + if got.Response.StatusCode != httpStatusOK { + t.Fatalf("expected response status to remain, got %d", got.Response.StatusCode) + } +} + +func TestSQLiteStorageServiceHonorsBodyCaptureToggles(t *testing.T) { + storage := newTestSQLiteStorage(t, config.StorageConfig{ + DBPath: filepath.Join(t.TempDir(), "requests.db"), + CaptureRequestBody: false, + CaptureResponseBody: false, + MetadataOnly: false, + }) + + request := &model.RequestLog{ + RequestID: "toggle-123", + Timestamp: time.Now().UTC().Format(time.RFC3339), + Method: "POST", + Endpoint: "/v1/messages", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: map[string]interface{}{ + "message": "do not store me", + }, + Model: "claude-3-5-sonnet", + UserAgent: "test", + ContentType: "application/json", + Response: &model.ResponseLog{ + StatusCode: httpStatusOK, + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: json.RawMessage(`{"answer":"do not store me"}`), + BodyText: "sensitive text", + StreamingChunks: []string{"data: chunk-1"}, + ResponseTime: 10, + CompletedAt: time.Now().UTC().Format(time.RFC3339), + }, + } + + if _, err := storage.SaveRequest(request); err != nil { + t.Fatalf("SaveRequest() error = %v", err) + } + if err := storage.UpdateRequestWithResponse(request); err != nil { + t.Fatalf("UpdateRequestWithResponse() error = %v", err) + } + + got, _, err := storage.GetRequestByShortID("123") + if err != nil { + t.Fatalf("GetRequestByShortID() error = %v", err) + } + + body, ok := got.Body.(map[string]interface{}) + if !ok { + t.Fatalf("expected body placeholder map, got %T", got.Body) + } + if body["_storage_mode"] != "request_body_disabled" { + t.Fatalf("expected request body disabled placeholder, got %#v", body["_storage_mode"]) + } + + if got.Response == nil { + t.Fatal("expected response log to exist") + } + if len(got.Response.Body) != 0 { + t.Fatalf("expected response body to be omitted, got %s", string(got.Response.Body)) + } + if got.Response.BodyText != "" { + t.Fatalf("expected response body text to be omitted, got %q", got.Response.BodyText) + } + if len(got.Response.StreamingChunks) != 0 { + t.Fatalf("expected streaming chunks to be omitted, got %d", len(got.Response.StreamingChunks)) + } +} + +func TestSQLiteStorageServiceDeletesExpiredRequestsOnWrite(t *testing.T) { + storage := newTestSQLiteStorage(t, config.StorageConfig{ + DBPath: filepath.Join(t.TempDir(), "requests.db"), + RetentionDays: 1, + RedactedFields: []string{}, + }) + + oldRequest := &model.RequestLog{ + RequestID: "old-123", + Timestamp: time.Now().Add(-48 * time.Hour).UTC().Format(time.RFC3339), + Method: "POST", + Endpoint: "/v1/messages", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: map[string]interface{}{"message": "old"}, + Model: "claude-3-5-sonnet", + UserAgent: "test", + ContentType: "application/json", + } + if _, err := storage.SaveRequest(oldRequest); err != nil { + t.Fatalf("SaveRequest(old) error = %v", err) + } + + recentRequest := &model.RequestLog{ + RequestID: "recent-123", + Timestamp: time.Now().UTC().Format(time.RFC3339), + Method: "POST", + Endpoint: "/v1/messages", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + Body: map[string]interface{}{"message": "recent"}, + Model: "claude-3-5-sonnet", + UserAgent: "test", + ContentType: "application/json", + } + if _, err := storage.SaveRequest(recentRequest); err != nil { + t.Fatalf("SaveRequest(recent) error = %v", err) + } + + got, err := storage.GetAllRequests("all") + if err != nil { + t.Fatalf("GetAllRequests() error = %v", err) + } + + if len(got) != 1 { + t.Fatalf("expected 1 request after retention cleanup, got %d", len(got)) + } + if got[0].RequestID != "recent-123" { + t.Fatalf("expected recent request to remain, got %s", got[0].RequestID) + } +} + +func newTestSQLiteStorage(t *testing.T, cfg config.StorageConfig) *sqliteStorageService { + t.Helper() + + storage, err := NewSQLiteStorageService(&cfg) if err != nil { t.Fatalf("NewSQLiteStorageService() error = %v", err) } @@ -21,49 +288,13 @@ func TestSQLiteStorageServiceGetRequestsUsesSQLPaginationAndFiltering(t *testing if !ok { t.Fatalf("unexpected storage type %T", storage) } - defer sqliteStorage.Close() - - requests := []struct { - id string - model string - }{ - {id: "1", model: "claude-3-5-sonnet"}, - {id: "2", model: "gpt-4o"}, - {id: "3", model: "claude-3-5-sonnet"}, - {id: "4", model: "gpt-4o-mini"}, - } - - for i, req := range requests { - _, err := storage.SaveRequest(&model.RequestLog{ - RequestID: req.id, - Timestamp: time.Date(2026, 3, 19, 12, 0, i, 0, time.UTC).Format(time.RFC3339), - Method: "POST", - Endpoint: "/v1/messages", - Headers: map[string][]string{"Content-Type": {"application/json"}}, - Body: map[string]string{"request": fmt.Sprintf("body-%d", i)}, - Model: req.model, - UserAgent: "test", - ContentType: "application/json", - }) - if err != nil { - t.Fatalf("SaveRequest() error = %v", err) + t.Cleanup(func() { + if err := sqliteStorage.Close(); err != nil { + t.Errorf("Close() error = %v", err) } - } + }) - got, total, err := storage.GetRequests(1, 1, "gpt") - if err != nil { - t.Fatalf("GetRequests() error = %v", err) - } - - if total != 2 { - t.Fatalf("expected filtered total 2, got %d", total) - } - - if len(got) != 1 { - t.Fatalf("expected 1 paginated result, got %d", len(got)) - } - - if got[0].RequestID != "4" { - t.Fatalf("expected newest filtered request ID 4, got %s", got[0].RequestID) - } + return sqliteStorage } + +const httpStatusOK = 200