Harden streaming, pagination, and config loading
This commit is contained in:
parent
02c9c76667
commit
6cda36312a
16 changed files with 1079 additions and 244 deletions
13
.env.example
13
.env.example
|
|
@ -12,5 +12,16 @@ ANTHROPIC_FORWARD_URL=https://api.anthropic.com
|
|||
ANTHROPIC_VERSION=2023-06-01
|
||||
ANTHROPIC_MAX_RETRIES=3
|
||||
|
||||
# OpenAI Configuration (for subagent routing)
|
||||
# OPENAI_API_KEY=your-openai-api-key
|
||||
# OPENAI_BASE_URL=https://api.openai.com
|
||||
# OPENAI_ALLOW_CLIENT_API_KEY=false
|
||||
# OPENAI_CLIENT_API_KEY_HEADER=x-openai-api-key
|
||||
|
||||
# Storage Configuration
|
||||
DATABASE_PATH=requests.db
|
||||
DB_PATH=requests.db
|
||||
|
||||
# CORS Configuration (comma-separated values)
|
||||
# CORS_ALLOWED_ORIGINS=*
|
||||
# CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
|
||||
# CORS_ALLOWED_HEADERS=*
|
||||
|
|
@ -38,6 +38,36 @@ providers:
|
|||
# Can also be set via OPENAI_BASE_URL environment variable
|
||||
# base_url: "https://api.openai.com"
|
||||
|
||||
# Allow clients to provide their own API key via header
|
||||
# Can also be set via OPENAI_ALLOW_CLIENT_API_KEY environment variable
|
||||
allow_client_api_key: false
|
||||
|
||||
# Header name for client-provided API key (default: x-openai-api-key)
|
||||
# Can also be set via OPENAI_CLIENT_API_KEY_HEADER environment variable
|
||||
client_api_key_header: "x-openai-api-key"
|
||||
|
||||
# CORS Configuration
|
||||
# Controls Cross-Origin Resource Sharing for the web UI
|
||||
cors:
|
||||
# Allowed origins (use ["*"] for all origins - not recommended for production)
|
||||
# Can also be set via CORS_ALLOWED_ORIGINS environment variable (comma-separated)
|
||||
allowed_origins:
|
||||
- "*"
|
||||
|
||||
# Allowed HTTP methods
|
||||
# Can also be set via CORS_ALLOWED_METHODS environment variable (comma-separated)
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
- "OPTIONS"
|
||||
|
||||
# Allowed headers (use ["*"] for all headers)
|
||||
# Can also be set via CORS_ALLOWED_HEADERS environment variable (comma-separated)
|
||||
allowed_headers:
|
||||
- "*"
|
||||
|
||||
# Storage configuration
|
||||
storage:
|
||||
# SQLite database path for storing request history
|
||||
|
|
@ -82,10 +112,17 @@ subagents:
|
|||
# OpenAI:
|
||||
# OPENAI_API_KEY - OpenAI API key
|
||||
# OPENAI_BASE_URL - OpenAI base URL
|
||||
# OPENAI_ALLOW_CLIENT_API_KEY - Allow client-provided API keys (true/false)
|
||||
# OPENAI_CLIENT_API_KEY_HEADER - Header name for client API key
|
||||
#
|
||||
# Storage:
|
||||
# DB_PATH - Database file path
|
||||
#
|
||||
# CORS:
|
||||
# CORS_ALLOWED_ORIGINS - Comma-separated allowed origins
|
||||
# CORS_ALLOWED_METHODS - Comma-separated allowed methods
|
||||
# CORS_ALLOWED_HEADERS - Comma-separated allowed headers
|
||||
#
|
||||
# Subagents:
|
||||
# SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs
|
||||
# Example: "code-reviewer:claude-3-5-sonnet"
|
||||
|
|
@ -50,9 +50,9 @@ func main() {
|
|||
r := mux.NewRouter()
|
||||
|
||||
corsHandler := handlers.CORS(
|
||||
handlers.AllowedOrigins([]string{"*"}),
|
||||
handlers.AllowedMethods([]string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}),
|
||||
handlers.AllowedHeaders([]string{"*"}),
|
||||
handlers.AllowedOrigins(cfg.CORS.AllowedOrigins),
|
||||
handlers.AllowedMethods(cfg.CORS.AllowedMethods),
|
||||
handlers.AllowedHeaders(cfg.CORS.AllowedHeaders),
|
||||
)
|
||||
|
||||
r.Use(middleware.Logging)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
|
|
@ -15,9 +17,16 @@ type Config struct {
|
|||
Providers ProvidersConfig `yaml:"providers"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
Subagents SubagentsConfig `yaml:"subagents"`
|
||||
CORS CORSConfig `yaml:"cors"`
|
||||
Anthropic AnthropicConfig
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
AllowedMethods []string `yaml:"allowed_methods"`
|
||||
AllowedHeaders []string `yaml:"allowed_headers"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port string `yaml:"port"`
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
|
|
@ -47,6 +56,8 @@ type AnthropicProviderConfig struct {
|
|||
type OpenAIProviderConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
AllowClientAPIKey bool `yaml:"allow_client_api_key"` // Allow clients to provide their own API key
|
||||
ClientAPIKeyHeader string `yaml:"client_api_key_header"` // Header name for client API key (default: x-openai-api-key)
|
||||
}
|
||||
|
||||
type AnthropicConfig struct {
|
||||
|
|
@ -94,6 +105,8 @@ func Load() (*Config, error) {
|
|||
OpenAI: OpenAIProviderConfig{
|
||||
BaseURL: "https://api.openai.com",
|
||||
APIKey: "",
|
||||
AllowClientAPIKey: false,
|
||||
ClientAPIKeyHeader: "x-openai-api-key",
|
||||
},
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
|
|
@ -103,24 +116,16 @@ func Load() (*Config, error) {
|
|||
Enable: false,
|
||||
Mappings: make(map[string]string),
|
||||
},
|
||||
CORS: CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"*"},
|
||||
},
|
||||
}
|
||||
|
||||
// Try to load config.yaml from the project root
|
||||
// The proxy binary is in proxy/ directory, config.yaml is in the parent
|
||||
configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml")
|
||||
|
||||
// If that doesn't work, try relative to current directory
|
||||
if _, err := os.Stat(configPath); err != nil {
|
||||
// Try common locations relative to where the binary might be run
|
||||
for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} {
|
||||
if _, err := os.Stat(tryPath); err == nil {
|
||||
configPath = tryPath
|
||||
break
|
||||
if err := loadFirstAvailableConfig(cfg, candidateConfigPaths()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg.loadFromFile(configPath)
|
||||
|
||||
// Apply environment variable overrides AFTER loading from file
|
||||
if envPort := os.Getenv("PORT"); envPort != "" {
|
||||
|
|
@ -154,12 +159,29 @@ func Load() (*Config, error) {
|
|||
if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" {
|
||||
cfg.Providers.OpenAI.APIKey = envKey
|
||||
}
|
||||
if envAllow := os.Getenv("OPENAI_ALLOW_CLIENT_API_KEY"); envAllow != "" {
|
||||
cfg.Providers.OpenAI.AllowClientAPIKey = envAllow == "true" || envAllow == "1"
|
||||
}
|
||||
if envHeader := os.Getenv("OPENAI_CLIENT_API_KEY_HEADER"); envHeader != "" {
|
||||
cfg.Providers.OpenAI.ClientAPIKeyHeader = envHeader
|
||||
}
|
||||
|
||||
// Override storage settings
|
||||
if envPath := os.Getenv("DB_PATH"); envPath != "" {
|
||||
cfg.Storage.DBPath = envPath
|
||||
}
|
||||
|
||||
// Override CORS settings (comma-separated values)
|
||||
if envOrigins := os.Getenv("CORS_ALLOWED_ORIGINS"); envOrigins != "" {
|
||||
cfg.CORS.AllowedOrigins = splitAndTrim(envOrigins)
|
||||
}
|
||||
if envMethods := os.Getenv("CORS_ALLOWED_METHODS"); envMethods != "" {
|
||||
cfg.CORS.AllowedMethods = splitAndTrim(envMethods)
|
||||
}
|
||||
if envHeaders := os.Getenv("CORS_ALLOWED_HEADERS"); envHeaders != "" {
|
||||
cfg.CORS.AllowedHeaders = splitAndTrim(envHeaders)
|
||||
}
|
||||
|
||||
// Sync legacy Anthropic config
|
||||
cfg.Anthropic = AnthropicConfig{
|
||||
BaseURL: cfg.Providers.Anthropic.BaseURL,
|
||||
|
|
@ -203,6 +225,45 @@ func (c *Config) loadFromFile(path string) error {
|
|||
return yaml.Unmarshal(data, c)
|
||||
}
|
||||
|
||||
func candidateConfigPaths() []string {
|
||||
paths := []string{
|
||||
filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml"),
|
||||
"config.yaml",
|
||||
"../config.yaml",
|
||||
"../../config.yaml",
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(paths))
|
||||
unique := make([]string, 0, len(paths))
|
||||
for _, path := range paths {
|
||||
if _, ok := seen[path]; ok {
|
||||
continue
|
||||
}
|
||||
seen[path] = struct{}{}
|
||||
unique = append(unique, path)
|
||||
}
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
func loadFirstAvailableConfig(cfg *Config, paths []string) error {
|
||||
for _, path := range paths {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to stat config file %q: %w", path, err)
|
||||
}
|
||||
|
||||
if err := cfg.loadFromFile(path); err != nil {
|
||||
return fmt.Errorf("failed to load config file %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
|
|
@ -237,3 +298,15 @@ func getInt(key string, defaultValue int) int {
|
|||
|
||||
return intValue
|
||||
}
|
||||
|
||||
func splitAndTrim(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
|
|||
30
proxy/internal/config/config_test.go
Normal file
30
proxy/internal/config/config_test.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadFirstAvailableConfigReturnsParseError(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.yaml")
|
||||
|
||||
if err := os.WriteFile(configPath, []byte("server: ["), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{}
|
||||
if err := loadFirstAvailableConfig(cfg, []string{configPath}); err == nil {
|
||||
t.Fatal("expected parse error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFirstAvailableConfigSkipsMissingFiles(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
if err := loadFirstAvailableConfig(cfg, []string{
|
||||
filepath.Join(t.TempDir(), "missing.yaml"),
|
||||
}); err != nil {
|
||||
t.Fatalf("expected nil error for missing config, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
|
|
@ -20,6 +19,7 @@ import (
|
|||
|
||||
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||||
"github.com/seifghazi/claude-code-monitor/internal/service"
|
||||
"github.com/seifghazi/claude-code-monitor/internal/sse"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
|
|
@ -179,37 +179,13 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
|
|||
modelFilter = "all"
|
||||
}
|
||||
|
||||
// Get all requests with model filter applied at storage level
|
||||
allRequests, err := h.storageService.GetAllRequests(modelFilter)
|
||||
requests, total, err := h.storageService.GetRequests(page, limit, modelFilter)
|
||||
if err != nil {
|
||||
log.Printf("Error getting requests: %v", err)
|
||||
http.Error(w, "Failed to get requests", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert pointers to values for consistency
|
||||
requests := make([]model.RequestLog, len(allRequests))
|
||||
for i, req := range allRequests {
|
||||
if req != nil {
|
||||
requests[i] = *req
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate total before pagination
|
||||
total := len(requests)
|
||||
|
||||
// Apply pagination
|
||||
start := (page - 1) * limit
|
||||
end := start + limit
|
||||
if start >= len(requests) {
|
||||
requests = []model.RequestLog{}
|
||||
} else {
|
||||
if end > len(requests) {
|
||||
end = len(requests)
|
||||
}
|
||||
requests = requests[start:end]
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(struct {
|
||||
Requests []model.RequestLog `json:"requests"`
|
||||
|
|
@ -242,6 +218,8 @@ func (h *Handler) NotFound(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
|
||||
// Forward important upstream headers (rate limits, request IDs, etc.)
|
||||
ForwardResponseHeaders(w, resp)
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
|
|
@ -278,16 +256,17 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
var messageID string
|
||||
var modelName string
|
||||
var stopReason string
|
||||
var sawMessageStop bool
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
streamErr := sse.ForEachLine(resp.Body, func(line string) error {
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
streamingChunks = append(streamingChunks, line)
|
||||
fmt.Fprintf(w, "%s\n\n", line)
|
||||
if _, err := fmt.Fprintf(w, "%s\n\n", line); err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
|
@ -298,7 +277,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
var genericEvent map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonData), &genericEvent); err != nil {
|
||||
log.Printf("⚠️ Error unmarshalling streaming event: %v", err)
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Capture metadata from message_start event
|
||||
|
|
@ -347,7 +326,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
var event model.StreamingEvent
|
||||
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
||||
// Skip if structured parsing fails, but we already got the usage data above
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
|
|
@ -366,8 +345,13 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
toolCalls = append(toolCalls, *event.ContentBlock)
|
||||
}
|
||||
case "message_stop":
|
||||
// End of stream - scanner will exit on its own
|
||||
sawMessageStop = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if streamErr == nil && !sawMessageStop {
|
||||
streamErr = io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
responseLog := &model.ResponseLog{
|
||||
|
|
@ -378,6 +362,9 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
IsStreaming: true,
|
||||
CompletedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
if streamErr != nil {
|
||||
responseLog.StreamError = streamErr.Error()
|
||||
}
|
||||
|
||||
// Create a structured response body that matches Anthropic's format
|
||||
var contentBlocks []model.AnthropicContentBlock
|
||||
|
|
@ -417,14 +404,31 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Resp
|
|||
log.Printf("❌ Error updating request with streaming response: %v", err)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("❌ Streaming error: %v", err)
|
||||
if streamErr != nil {
|
||||
log.Printf("❌ Streaming error: %v", streamErr)
|
||||
// Send error event to client in Anthropic streaming format
|
||||
errorEvent := map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": map[string]interface{}{
|
||||
"type": "stream_error",
|
||||
"message": fmt.Sprintf("Stream interrupted: %v", streamErr),
|
||||
},
|
||||
}
|
||||
if errorJSON, jsonErr := json.Marshal(errorEvent); jsonErr == nil {
|
||||
fmt.Fprintf(w, "data: %s\n\n", errorJSON)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Println("✅ Streaming response completed")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
|
||||
// Forward important upstream headers (rate limits, request IDs, etc.)
|
||||
ForwardResponseHeaders(w, resp)
|
||||
|
||||
responseBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("❌ Error reading Anthropic response: %v", err)
|
||||
|
|
@ -464,6 +468,7 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.R
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("❌ Anthropic API error: %d %s", resp.StatusCode, string(responseBytes))
|
||||
// Headers already forwarded at start of function
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
w.Write(responseBytes)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,68 @@ import (
|
|||
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||||
)
|
||||
|
||||
// Headers that should be forwarded from upstream responses to clients
|
||||
var forwardableResponseHeaders = []string{
|
||||
// Anthropic rate limit headers
|
||||
"anthropic-ratelimit-requests-limit",
|
||||
"anthropic-ratelimit-requests-remaining",
|
||||
"anthropic-ratelimit-requests-reset",
|
||||
"anthropic-ratelimit-tokens-limit",
|
||||
"anthropic-ratelimit-tokens-remaining",
|
||||
"anthropic-ratelimit-tokens-reset",
|
||||
// Standard rate limit headers
|
||||
"x-ratelimit-limit",
|
||||
"x-ratelimit-remaining",
|
||||
"x-ratelimit-reset",
|
||||
"retry-after",
|
||||
// Request tracking
|
||||
"x-request-id",
|
||||
"request-id",
|
||||
// Anthropic specific
|
||||
"anthropic-organization-id",
|
||||
// OpenAI specific
|
||||
"openai-organization",
|
||||
"openai-processing-ms",
|
||||
"openai-version",
|
||||
"x-request-id",
|
||||
}
|
||||
|
||||
// ForwardResponseHeaders copies important headers from upstream response to client response
|
||||
func ForwardResponseHeaders(w http.ResponseWriter, resp *http.Response) {
|
||||
for _, header := range forwardableResponseHeaders {
|
||||
if values := resp.Header.Values(header); len(values) > 0 {
|
||||
for _, value := range values {
|
||||
w.Header().Add(header, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CopyAllResponseHeaders copies all non-hop-by-hop headers from upstream to client
|
||||
func CopyAllResponseHeaders(w http.ResponseWriter, resp *http.Response) {
|
||||
hopByHopHeaders := map[string]bool{
|
||||
"connection": true,
|
||||
"keep-alive": true,
|
||||
"proxy-authenticate": true,
|
||||
"proxy-authorization": true,
|
||||
"te": true,
|
||||
"trailers": true,
|
||||
"transfer-encoding": true,
|
||||
"upgrade": true,
|
||||
"content-encoding": true, // We handle decompression ourselves
|
||||
"content-length": true, // May change after decompression
|
||||
}
|
||||
|
||||
for key, values := range resp.Header {
|
||||
if hopByHopHeaders[strings.ToLower(key)] {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeHeaders removes sensitive headers before logging/storage
|
||||
func SanitizeHeaders(headers http.Header) http.Header {
|
||||
sanitized := make(http.Header)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ type ResponseLog struct {
|
|||
Headers map[string][]string `json:"headers"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
BodyText string `json:"bodyText,omitempty"`
|
||||
StreamError string `json:"streamError,omitempty"`
|
||||
ResponseTime int64 `json:"responseTime"`
|
||||
StreamingChunks []string `json:"streamingChunks,omitempty"`
|
||||
IsStreaming bool `json:"isStreaming"`
|
||||
|
|
|
|||
|
|
@ -66,8 +66,14 @@ func (p *AnthropicProvider) ForwardRequest(ctx context.Context, originalReq *htt
|
|||
proxyReq.Header.Set("anthropic-version", p.config.Version)
|
||||
}
|
||||
|
||||
// Support gzip encoding
|
||||
// Handle Accept-Encoding: We accept gzip from upstream for efficiency,
|
||||
// but we decompress before forwarding to the client. This is transparent
|
||||
// to the client - they receive uncompressed data regardless of what they requested.
|
||||
// We preserve gzip if client already requested it, otherwise add it.
|
||||
clientEncoding := proxyReq.Header.Get("Accept-Encoding")
|
||||
if clientEncoding == "" || !strings.Contains(clientEncoding, "gzip") {
|
||||
proxyReq.Header.Set("Accept-Encoding", "gzip")
|
||||
}
|
||||
|
||||
// Forward the request
|
||||
resp, err := p.client.Do(proxyReq)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
|
|
@ -15,6 +14,7 @@ import (
|
|||
|
||||
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||||
"github.com/seifghazi/claude-code-monitor/internal/sse"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
|
|
@ -78,13 +78,29 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
|||
// Remove Anthropic-specific headers
|
||||
proxyReq.Header.Del("anthropic-version")
|
||||
proxyReq.Header.Del("x-api-key")
|
||||
proxyReq.Header.Del("Authorization")
|
||||
|
||||
// Determine which API key to use
|
||||
apiKey := p.config.APIKey
|
||||
|
||||
// Check for client-provided API key if allowed
|
||||
if p.config.AllowClientAPIKey && p.config.ClientAPIKeyHeader != "" {
|
||||
if clientKey := originalReq.Header.Get(p.config.ClientAPIKeyHeader); clientKey != "" {
|
||||
apiKey = clientKey
|
||||
}
|
||||
}
|
||||
|
||||
// Add OpenAI headers
|
||||
if p.config.APIKey != "" {
|
||||
proxyReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
if apiKey != "" {
|
||||
proxyReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
proxyReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Remove the client API key header from the proxied request
|
||||
if p.config.ClientAPIKeyHeader != "" {
|
||||
proxyReq.Header.Del(p.config.ClientAPIKeyHeader)
|
||||
}
|
||||
|
||||
// Forward the request
|
||||
resp, err := p.client.Do(proxyReq)
|
||||
if err != nil {
|
||||
|
|
@ -139,9 +155,12 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.R
|
|||
|
||||
// Start a goroutine to transform the stream
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
defer bodyReader.Close()
|
||||
transformOpenAIStreamToAnthropic(bodyReader, pw)
|
||||
if err := transformOpenAIStreamToAnthropic(bodyReader, pw); err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
// Replace the response body with our transformed stream
|
||||
|
|
@ -332,10 +351,10 @@ func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{
|
|||
})
|
||||
}
|
||||
}
|
||||
// Check if max_tokens exceeds the model's limit and cap it if necessary
|
||||
maxTokensLimit := 16384 // Assuming this is the limit for the model
|
||||
if req.MaxTokens > maxTokensLimit {
|
||||
// Capping max_tokens to model limit
|
||||
// Get model-specific max token limit
|
||||
// Let the API handle validation for unknown models rather than using arbitrary caps
|
||||
maxTokensLimit := getModelMaxTokens(req.Model)
|
||||
if maxTokensLimit > 0 && req.MaxTokens > maxTokensLimit {
|
||||
req.MaxTokens = maxTokensLimit
|
||||
}
|
||||
|
||||
|
|
@ -473,6 +492,52 @@ func min(a, b int) int {
|
|||
return b
|
||||
}
|
||||
|
||||
// getModelMaxTokens returns the max output tokens for known models
|
||||
// Returns 0 for unknown models, letting the API handle validation
|
||||
func getModelMaxTokens(model string) int {
|
||||
// Model-specific max completion token limits
|
||||
modelLimits := map[string]int{
|
||||
// GPT-4 Turbo and GPT-4o models
|
||||
"gpt-4-turbo": 4096,
|
||||
"gpt-4-turbo-preview": 4096,
|
||||
"gpt-4o": 16384,
|
||||
"gpt-4o-mini": 16384,
|
||||
"gpt-4o-2024-05-13": 16384,
|
||||
"gpt-4o-2024-08-06": 16384,
|
||||
// GPT-4 models
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-32k": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
// GPT-3.5 models
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-16k": 4096,
|
||||
"gpt-3.5-turbo-0125": 4096,
|
||||
"gpt-3.5-turbo-1106": 4096,
|
||||
// o1 reasoning models
|
||||
"o1": 100000,
|
||||
"o1-preview": 32768,
|
||||
"o1-mini": 65536,
|
||||
// o3 reasoning models (estimated based on o1 patterns)
|
||||
"o3": 100000,
|
||||
"o3-mini": 65536,
|
||||
}
|
||||
|
||||
// Check for exact match first
|
||||
if limit, ok := modelLimits[model]; ok {
|
||||
return limit
|
||||
}
|
||||
|
||||
// Check for prefix matches for versioned models
|
||||
for prefix, limit := range modelLimits {
|
||||
if strings.HasPrefix(model, prefix) {
|
||||
return limit
|
||||
}
|
||||
}
|
||||
|
||||
// Return 0 for unknown models - let the API validate
|
||||
return 0
|
||||
}
|
||||
|
||||
func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
||||
// This is a simplified transformation
|
||||
// In production, you'd want to handle all fields properly
|
||||
|
|
@ -579,19 +644,16 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
|
|||
return result
|
||||
}
|
||||
|
||||
func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
|
||||
defer openAIStream.Close()
|
||||
|
||||
scanner := bufio.NewScanner(openAIStream)
|
||||
func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
|
||||
var messageStarted bool
|
||||
var contentStarted bool
|
||||
var sawDone bool
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
err := sse.ForEachLine(openAIStream, func(line string) error {
|
||||
|
||||
// Skip empty lines
|
||||
if line == "" {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle SSE data lines
|
||||
|
|
@ -600,21 +662,28 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
|
|||
|
||||
// Handle end of stream
|
||||
if data == "[DONE]" {
|
||||
sawDone = true
|
||||
// Send Anthropic-style completion
|
||||
if contentStarted {
|
||||
fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if messageStarted {
|
||||
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n")
|
||||
fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n")
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
break
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse OpenAI response
|
||||
var openAIChunk map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for usage data BEFORE processing choices
|
||||
|
|
@ -644,7 +713,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
|
|||
"usage": anthropicUsage,
|
||||
}
|
||||
usageJSON, _ := json.Marshal(usageDelta)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON)
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -652,17 +723,17 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
|
|||
choices, ok := openAIChunk["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
// Skip further processing if no choices, but we already handled usage above
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
choice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
delta, ok := choice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle first chunk - send message_start
|
||||
|
|
@ -684,7 +755,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
|
|||
},
|
||||
}
|
||||
startJSON, _ := json.Marshal(messageStart)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON)
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Handle content
|
||||
|
|
@ -701,7 +774,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
|
|||
},
|
||||
}
|
||||
blockStartJSON, _ := json.Marshal(blockStart)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON)
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Send content_block_delta
|
||||
|
|
@ -714,9 +789,22 @@ func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStrea
|
|||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(contentDelta)
|
||||
fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON)
|
||||
if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !sawDone {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
63
proxy/internal/provider/openai_test.go
Normal file
63
proxy/internal/provider/openai_test.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||
)
|
||||
|
||||
func TestOpenAIProviderForwardRequestClearsInboundAuthorization(t *testing.T) {
|
||||
var gotAuthorization string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuthorization = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"resp_1","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"ok"}}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewOpenAIProvider(&config.OpenAIProviderConfig{
|
||||
BaseURL: server.URL,
|
||||
}).(*OpenAIProvider)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://proxy.local/v1/messages", strings.NewReader(`{"model":"gpt-4o","messages":[],"max_tokens":16}`))
|
||||
req.Header.Set("Authorization", "Bearer should-not-leak")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := provider.ForwardRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("ForwardRequest() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if gotAuthorization != "" {
|
||||
t.Fatalf("expected forwarded Authorization header to be empty, got %q", gotAuthorization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransformOpenAIStreamToAnthropicHandlesLargeEvents(t *testing.T) {
|
||||
largeContent := strings.Repeat("x", 128*1024)
|
||||
openAIStream := strings.NewReader("data: {\"id\":\"chatcmpl_1\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"" + largeContent + "\"}}]}\n\n" +
|
||||
"data: [DONE]\n\n")
|
||||
|
||||
var output strings.Builder
|
||||
if err := transformOpenAIStreamToAnthropic(openAIStream, &output); err != nil {
|
||||
t.Fatalf("transformOpenAIStreamToAnthropic() error = %v", err)
|
||||
}
|
||||
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "\"message_start\"") {
|
||||
t.Fatal("expected message_start event in output")
|
||||
}
|
||||
if !strings.Contains(got, largeContent) {
|
||||
t.Fatal("expected large content to be preserved in output")
|
||||
}
|
||||
if !strings.Contains(got, "\"message_stop\"") {
|
||||
t.Fatal("expected message_stop event in output")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,18 +1,33 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||||
)
|
||||
|
||||
type StorageService interface {
|
||||
// Core CRUD operations
|
||||
SaveRequest(request *model.RequestLog) (string, error)
|
||||
GetRequests(page, limit int) ([]model.RequestLog, int, error)
|
||||
GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error)
|
||||
GetAllRequests(modelFilter string) ([]*model.RequestLog, error)
|
||||
GetRequestByShortID(shortID string) (*model.RequestLog, string, error)
|
||||
ClearRequests() (int, error)
|
||||
|
||||
// Update operations
|
||||
UpdateRequestWithGrading(requestID string, grade *model.PromptGrade) error
|
||||
UpdateRequestWithResponse(request *model.RequestLog) error
|
||||
EnsureDirectoryExists() error
|
||||
GetRequestByShortID(shortID string) (*model.RequestLog, string, error)
|
||||
|
||||
// Maintenance operations
|
||||
DeleteRequestsOlderThan(age time.Duration) (int, error)
|
||||
GetDatabaseStats() (map[string]interface{}, error)
|
||||
|
||||
// Configuration
|
||||
GetConfig() *config.StorageConfig
|
||||
GetAllRequests(modelFilter string) ([]*model.RequestLog, error)
|
||||
EnsureDirectoryExists() error
|
||||
|
||||
// Cleanup - implements io.Closer
|
||||
io.Closer
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
|
|
@ -15,23 +17,63 @@ import (
|
|||
type sqliteStorageService struct {
|
||||
db *sql.DB
|
||||
config *config.StorageConfig
|
||||
logger *log.Logger
|
||||
|
||||
// Prepared statements for frequently used queries
|
||||
stmtInsertRequest *sql.Stmt
|
||||
stmtUpdateResponse *sql.Stmt
|
||||
stmtUpdateGrading *sql.Stmt
|
||||
stmtGetRequestByID *sql.Stmt
|
||||
stmtGetRequestsPage *sql.Stmt
|
||||
stmtGetRequestsCount *sql.Stmt
|
||||
stmtDeleteOldRequests *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteStorageService(cfg *config.StorageConfig) (StorageService, error) {
|
||||
db, err := sql.Open("sqlite3", cfg.DBPath)
|
||||
return NewSQLiteStorageServiceWithLogger(cfg, log.Default())
|
||||
}
|
||||
|
||||
func NewSQLiteStorageServiceWithLogger(cfg *config.StorageConfig, logger *log.Logger) (StorageService, error) {
|
||||
// Enable WAL mode and other optimizations via connection string
|
||||
// _journal_mode=WAL: Write-Ahead Logging for better concurrent read performance
|
||||
// _synchronous=NORMAL: Good balance of safety and performance
|
||||
// _busy_timeout=5000: Wait up to 5 seconds if database is locked
|
||||
// _cache_size=-20000: Use 20MB of memory for cache (negative = KB)
|
||||
connStr := cfg.DBPath + "?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000"
|
||||
|
||||
db, err := sql.Open("sqlite3", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
// SQLite only supports one writer at a time, but can handle multiple readers
|
||||
db.SetMaxOpenConns(1) // Serialize writes to avoid SQLITE_BUSY errors
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
// Verify connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
service := &sqliteStorageService{
|
||||
db: db,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if err := service.createTables(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to create tables: %w", err)
|
||||
}
|
||||
|
||||
if err := service.prepareStatements(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to prepare statements: %w", err)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
|
|
@ -39,7 +81,7 @@ func (s *sqliteStorageService) createTables() error {
|
|||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS requests (
|
||||
id TEXT PRIMARY KEY,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
timestamp DATETIME NOT NULL,
|
||||
method TEXT NOT NULL,
|
||||
endpoint TEXT NOT NULL,
|
||||
headers TEXT NOT NULL,
|
||||
|
|
@ -50,17 +92,100 @@ func (s *sqliteStorageService) createTables() error {
|
|||
response TEXT,
|
||||
model TEXT,
|
||||
original_model TEXT,
|
||||
routed_model TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
routed_model TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_timestamp ON requests(timestamp DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_endpoint ON requests(endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_model ON requests(model);
|
||||
-- Index for listing requests by time (most common query)
|
||||
CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC);
|
||||
|
||||
-- Index for filtering by model
|
||||
CREATE INDEX IF NOT EXISTS idx_requests_model ON requests(model);
|
||||
|
||||
-- Index for filtering by endpoint
|
||||
CREATE INDEX IF NOT EXISTS idx_requests_endpoint ON requests(endpoint);
|
||||
`
|
||||
|
||||
_, err := s.db.Exec(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
s.migrateSchema()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqliteStorageService) migrateSchema() {
|
||||
// Ensure WAL mode is enabled (in case opened without connection string params)
|
||||
_, err := s.db.Exec("PRAGMA journal_mode=WAL")
|
||||
if err != nil {
|
||||
s.logger.Printf("Warning: failed to set WAL mode: %v", err)
|
||||
}
|
||||
|
||||
// Drop old redundant index if it exists (we renamed to idx_requests_timestamp)
|
||||
s.db.Exec("DROP INDEX IF EXISTS idx_timestamp")
|
||||
}
|
||||
|
||||
func (s *sqliteStorageService) prepareStatements() error {
|
||||
var err error
|
||||
|
||||
s.stmtInsertRequest, err = s.db.Prepare(`
|
||||
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare insert statement: %w", err)
|
||||
}
|
||||
|
||||
s.stmtUpdateResponse, err = s.db.Prepare(`
|
||||
UPDATE requests SET response = ? WHERE id = ?
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare update response statement: %w", err)
|
||||
}
|
||||
|
||||
s.stmtUpdateGrading, err = s.db.Prepare(`
|
||||
UPDATE requests SET prompt_grade = ? WHERE id = ?
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare update grading statement: %w", err)
|
||||
}
|
||||
|
||||
s.stmtGetRequestByID, err = s.db.Prepare(`
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
WHERE id = ?
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare get by ID statement: %w", err)
|
||||
}
|
||||
|
||||
s.stmtGetRequestsPage, err = s.db.Prepare(`
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare get requests page statement: %w", err)
|
||||
}
|
||||
|
||||
s.stmtGetRequestsCount, err = s.db.Prepare(`
|
||||
SELECT COUNT(*) FROM requests
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare count statement: %w", err)
|
||||
}
|
||||
|
||||
s.stmtDeleteOldRequests, err = s.db.Prepare(`
|
||||
DELETE FROM requests WHERE timestamp < ?
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare delete old requests statement: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, error) {
|
||||
|
|
@ -74,12 +199,7 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
|||
return "", fmt.Errorf("failed to marshal body: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err = s.db.Exec(query,
|
||||
_, err = s.stmtInsertRequest.Exec(
|
||||
request.RequestID,
|
||||
request.Timestamp,
|
||||
request.Method,
|
||||
|
|
@ -100,10 +220,24 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
|||
return request.RequestID, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, int, error) {
|
||||
func (s *sqliteStorageService) GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error) {
|
||||
whereClause := ""
|
||||
countArgs := []interface{}{}
|
||||
queryArgs := []interface{}{}
|
||||
|
||||
if modelFilter != "" && modelFilter != "all" {
|
||||
// Escape LIKE special characters to prevent pattern injection
|
||||
escapedFilter := escapeLikePattern(strings.ToLower(modelFilter))
|
||||
whereClause = " WHERE LOWER(model) LIKE ? ESCAPE '\\'"
|
||||
filterValue := "%" + escapedFilter + "%"
|
||||
countArgs = append(countArgs, filterValue)
|
||||
queryArgs = append(queryArgs, filterValue)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
var total int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&total)
|
||||
countQuery := "SELECT COUNT(*) FROM requests" + whereClause
|
||||
err := s.db.QueryRow(countQuery, countArgs...).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to get total count: %w", err)
|
||||
}
|
||||
|
|
@ -112,71 +246,21 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog,
|
|||
offset := (page - 1) * limit
|
||||
query := `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
FROM requests` + whereClause + `
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
queryArgs = append(queryArgs, limit, offset)
|
||||
|
||||
rows, err := s.db.Query(query, limit, offset)
|
||||
rows, err := s.db.Query(query, queryArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query requests: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var requests []model.RequestLog
|
||||
for rows.Next() {
|
||||
var req model.RequestLog
|
||||
var headersJSON, bodyJSON string
|
||||
var promptGradeJSON, responseJSON sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&req.RequestID,
|
||||
&req.Timestamp,
|
||||
&req.Method,
|
||||
&req.Endpoint,
|
||||
&headersJSON,
|
||||
&bodyJSON,
|
||||
&req.Model,
|
||||
&req.UserAgent,
|
||||
&req.ContentType,
|
||||
&promptGradeJSON,
|
||||
&responseJSON,
|
||||
&req.OriginalModel,
|
||||
&req.RoutedModel,
|
||||
)
|
||||
requests, err := s.scanRequestRows(rows)
|
||||
if err != nil {
|
||||
// Error scanning row - skip
|
||||
continue
|
||||
}
|
||||
|
||||
// Unmarshal JSON fields
|
||||
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
||||
// Error unmarshaling headers
|
||||
continue
|
||||
}
|
||||
|
||||
var body interface{}
|
||||
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
|
||||
// Error unmarshaling body
|
||||
continue
|
||||
}
|
||||
req.Body = body
|
||||
|
||||
if promptGradeJSON.Valid {
|
||||
var grade model.PromptGrade
|
||||
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
|
||||
req.PromptGrade = &grade
|
||||
}
|
||||
}
|
||||
|
||||
if responseJSON.Valid {
|
||||
var resp model.ResponseLog
|
||||
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
|
||||
req.Response = &resp
|
||||
}
|
||||
}
|
||||
|
||||
requests = append(requests, req)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return requests, total, nil
|
||||
|
|
@ -193,6 +277,12 @@ func (s *sqliteStorageService) ClearRequests() (int, error) {
|
|||
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
// Reclaim space after clearing all data
|
||||
_, err = s.db.Exec("VACUUM")
|
||||
if err != nil {
|
||||
s.logger.Printf("Warning: failed to vacuum database: %v", err)
|
||||
}
|
||||
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
|
|
@ -202,12 +292,16 @@ func (s *sqliteStorageService) UpdateRequestWithGrading(requestID string, grade
|
|||
return fmt.Errorf("failed to marshal grade: %w", err)
|
||||
}
|
||||
|
||||
query := "UPDATE requests SET prompt_grade = ? WHERE id = ?"
|
||||
_, err = s.db.Exec(query, string(gradeJSON), requestID)
|
||||
result, err := s.stmtUpdateGrading.Exec(string(gradeJSON), requestID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update request with grading: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("request %s not found", requestID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -217,12 +311,72 @@ func (s *sqliteStorageService) UpdateRequestWithResponse(request *model.RequestL
|
|||
return fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
|
||||
query := "UPDATE requests SET response = ? WHERE id = ?"
|
||||
_, err = s.db.Exec(query, string(responseJSON), request.RequestID)
|
||||
result, err := s.stmtUpdateResponse.Exec(string(responseJSON), request.RequestID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update request with response: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("request %s not found", request.RequestID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveRequestWithResponse saves a request and its response in a single transaction
|
||||
func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
headersJSON, err := json.Marshal(request.Headers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal headers: %w", err)
|
||||
}
|
||||
|
||||
bodyJSON, err := json.Marshal(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal body: %w", err)
|
||||
}
|
||||
|
||||
// Insert request
|
||||
_, err = tx.Stmt(s.stmtInsertRequest).Exec(
|
||||
request.RequestID,
|
||||
request.Timestamp,
|
||||
request.Method,
|
||||
request.Endpoint,
|
||||
string(headersJSON),
|
||||
string(bodyJSON),
|
||||
request.UserAgent,
|
||||
request.ContentType,
|
||||
request.Model,
|
||||
request.OriginalModel,
|
||||
request.RoutedModel,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert request: %w", err)
|
||||
}
|
||||
|
||||
// Update with response if present
|
||||
if request.Response != nil {
|
||||
responseJSON, err := json.Marshal(request.Response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Stmt(s.stmtUpdateResponse).Exec(string(responseJSON), request.RequestID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -232,10 +386,13 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error {
|
|||
}
|
||||
|
||||
func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
|
||||
// Escape LIKE special characters to prevent pattern injection
|
||||
escapedID := escapeLikePattern(shortID)
|
||||
|
||||
query := `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
WHERE id LIKE ?
|
||||
WHERE id LIKE ? ESCAPE '\'
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
|
@ -244,7 +401,7 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
|
|||
var headersJSON, bodyJSON string
|
||||
var promptGradeJSON, responseJSON sql.NullString
|
||||
|
||||
err := s.db.QueryRow(query, "%"+shortID).Scan(
|
||||
err := s.db.QueryRow(query, "%"+escapedID).Scan(
|
||||
&req.RequestID,
|
||||
&req.Timestamp,
|
||||
&req.Method,
|
||||
|
|
@ -267,29 +424,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.Reque
|
|||
return nil, "", fmt.Errorf("failed to query request: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal JSON fields
|
||||
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to unmarshal headers: %w", err)
|
||||
}
|
||||
|
||||
var body interface{}
|
||||
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
|
||||
}
|
||||
req.Body = body
|
||||
|
||||
if promptGradeJSON.Valid {
|
||||
var grade model.PromptGrade
|
||||
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
|
||||
req.PromptGrade = &grade
|
||||
}
|
||||
}
|
||||
|
||||
if responseJSON.Valid {
|
||||
var resp model.ResponseLog
|
||||
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
|
||||
req.Response = &resp
|
||||
}
|
||||
if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &req, req.RequestID, nil
|
||||
|
|
@ -300,19 +436,36 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig {
|
|||
}
|
||||
|
||||
func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
|
||||
query := `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
`
|
||||
return s.GetAllRequestsWithLimit(modelFilter, 0) // 0 means no limit
|
||||
}
|
||||
|
||||
// GetAllRequestsWithLimit returns requests with an optional limit (0 = no limit)
|
||||
func (s *sqliteStorageService) GetAllRequestsWithLimit(modelFilter string, limit int) ([]*model.RequestLog, error) {
|
||||
var query string
|
||||
args := []interface{}{}
|
||||
|
||||
if modelFilter != "" && modelFilter != "all" {
|
||||
query += " WHERE LOWER(model) LIKE ?"
|
||||
args = append(args, "%"+strings.ToLower(modelFilter)+"%")
|
||||
|
||||
// Escape LIKE special characters
|
||||
escapedFilter := escapeLikePattern(strings.ToLower(modelFilter))
|
||||
query = `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
WHERE LOWER(model) LIKE ? ESCAPE '\'
|
||||
ORDER BY timestamp DESC
|
||||
`
|
||||
args = append(args, "%"+escapedFilter+"%")
|
||||
} else {
|
||||
query = `
|
||||
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||
FROM requests
|
||||
ORDER BY timestamp DESC
|
||||
`
|
||||
}
|
||||
|
||||
query += " ORDER BY timestamp DESC"
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(query, args...)
|
||||
if err != nil {
|
||||
|
|
@ -321,6 +474,124 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
|
|||
defer rows.Close()
|
||||
|
||||
var requests []*model.RequestLog
|
||||
for rows.Next() {
|
||||
req, err := s.scanSingleRow(rows)
|
||||
if err != nil {
|
||||
s.logger.Printf("Warning: failed to scan request row: %v", err)
|
||||
continue
|
||||
}
|
||||
requests = append(requests, req)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
||||
}
|
||||
|
||||
return requests, nil
|
||||
}
|
||||
|
||||
// DeleteRequestsOlderThan removes requests older than the specified duration
|
||||
func (s *sqliteStorageService) DeleteRequestsOlderThan(age time.Duration) (int, error) {
|
||||
cutoff := time.Now().Add(-age)
|
||||
|
||||
result, err := s.stmtDeleteOldRequests.Exec(cutoff.Format(time.RFC3339))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete old requests: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// GetDatabaseStats returns statistics about the database
|
||||
func (s *sqliteStorageService) GetDatabaseStats() (map[string]interface{}, error) {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
// Get row count
|
||||
var count int
|
||||
err := s.stmtGetRequestsCount.QueryRow().Scan(&count)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get count: %w", err)
|
||||
}
|
||||
stats["total_requests"] = count
|
||||
|
||||
// Get database size
|
||||
var pageCount, pageSize int
|
||||
err = s.db.QueryRow("PRAGMA page_count").Scan(&pageCount)
|
||||
if err == nil {
|
||||
err = s.db.QueryRow("PRAGMA page_size").Scan(&pageSize)
|
||||
if err == nil {
|
||||
stats["database_size_bytes"] = pageCount * pageSize
|
||||
}
|
||||
}
|
||||
|
||||
// Get oldest and newest timestamps
|
||||
var oldest, newest sql.NullString
|
||||
err = s.db.QueryRow("SELECT MIN(timestamp), MAX(timestamp) FROM requests").Scan(&oldest, &newest)
|
||||
if err == nil {
|
||||
if oldest.Valid {
|
||||
stats["oldest_request"] = oldest.String
|
||||
}
|
||||
if newest.Valid {
|
||||
stats["newest_request"] = newest.String
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStorageService) Close() error {
|
||||
// Close prepared statements
|
||||
if s.stmtInsertRequest != nil {
|
||||
s.stmtInsertRequest.Close()
|
||||
}
|
||||
if s.stmtUpdateResponse != nil {
|
||||
s.stmtUpdateResponse.Close()
|
||||
}
|
||||
if s.stmtUpdateGrading != nil {
|
||||
s.stmtUpdateGrading.Close()
|
||||
}
|
||||
if s.stmtGetRequestByID != nil {
|
||||
s.stmtGetRequestByID.Close()
|
||||
}
|
||||
if s.stmtGetRequestsPage != nil {
|
||||
s.stmtGetRequestsPage.Close()
|
||||
}
|
||||
if s.stmtGetRequestsCount != nil {
|
||||
s.stmtGetRequestsCount.Close()
|
||||
}
|
||||
if s.stmtDeleteOldRequests != nil {
|
||||
s.stmtDeleteOldRequests.Close()
|
||||
}
|
||||
|
||||
// Checkpoint WAL before closing
|
||||
_, err := s.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)")
|
||||
if err != nil {
|
||||
s.logger.Printf("Warning: failed to checkpoint WAL: %v", err)
|
||||
}
|
||||
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// escapeLikePattern escapes special characters in LIKE patterns
|
||||
func escapeLikePattern(s string) string {
|
||||
// Escape \, %, and _ characters
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return s
|
||||
}
|
||||
|
||||
// scanRequestRows scans multiple rows into a slice of RequestLog
|
||||
func (s *sqliteStorageService) scanRequestRows(rows *sql.Rows) ([]model.RequestLog, error) {
|
||||
var requests []model.RequestLog
|
||||
|
||||
for rows.Next() {
|
||||
var req model.RequestLog
|
||||
var headersJSON, bodyJSON string
|
||||
|
|
@ -342,43 +613,86 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.Requ
|
|||
&req.RoutedModel,
|
||||
)
|
||||
if err != nil {
|
||||
// Error scanning row - skip
|
||||
s.logger.Printf("Warning: failed to scan row: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Unmarshal JSON fields
|
||||
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
||||
// Error unmarshaling headers
|
||||
if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
|
||||
s.logger.Printf("Warning: failed to unmarshal request fields: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
requests = append(requests, req)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
||||
}
|
||||
|
||||
return requests, nil
|
||||
}
|
||||
|
||||
// scanSingleRow scans a single row into a RequestLog pointer
|
||||
func (s *sqliteStorageService) scanSingleRow(rows *sql.Rows) (*model.RequestLog, error) {
|
||||
var req model.RequestLog
|
||||
var headersJSON, bodyJSON string
|
||||
var promptGradeJSON, responseJSON sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&req.RequestID,
|
||||
&req.Timestamp,
|
||||
&req.Method,
|
||||
&req.Endpoint,
|
||||
&headersJSON,
|
||||
&bodyJSON,
|
||||
&req.Model,
|
||||
&req.UserAgent,
|
||||
&req.ContentType,
|
||||
&promptGradeJSON,
|
||||
&responseJSON,
|
||||
&req.OriginalModel,
|
||||
&req.RoutedModel,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// unmarshalRequestFields unmarshals JSON fields into a RequestLog
|
||||
func (s *sqliteStorageService) unmarshalRequestFields(req *model.RequestLog, headersJSON, bodyJSON string, promptGradeJSON, responseJSON sql.NullString) error {
|
||||
if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal headers: %w", err)
|
||||
}
|
||||
|
||||
var body interface{}
|
||||
if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
|
||||
// Error unmarshaling body
|
||||
continue
|
||||
return fmt.Errorf("failed to unmarshal body: %w", err)
|
||||
}
|
||||
req.Body = body
|
||||
|
||||
if promptGradeJSON.Valid {
|
||||
var grade model.PromptGrade
|
||||
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
|
||||
if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err != nil {
|
||||
s.logger.Printf("Warning: failed to unmarshal prompt grade: %v", err)
|
||||
} else {
|
||||
req.PromptGrade = &grade
|
||||
}
|
||||
}
|
||||
|
||||
if responseJSON.Valid {
|
||||
var resp model.ResponseLog
|
||||
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
|
||||
if err := json.Unmarshal([]byte(responseJSON.String), &resp); err != nil {
|
||||
s.logger.Printf("Warning: failed to unmarshal response: %v", err)
|
||||
} else {
|
||||
req.Response = &resp
|
||||
}
|
||||
}
|
||||
|
||||
requests = append(requests, &req)
|
||||
}
|
||||
|
||||
return requests, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStorageService) Close() error {
|
||||
return s.db.Close()
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
69
proxy/internal/service/storage_sqlite_test.go
Normal file
69
proxy/internal/service/storage_sqlite_test.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||||
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||||
)
|
||||
|
||||
func TestSQLiteStorageServiceGetRequestsUsesSQLPaginationAndFiltering(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "requests.db")
|
||||
storage, err := NewSQLiteStorageService(&config.StorageConfig{DBPath: dbPath})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSQLiteStorageService() error = %v", err)
|
||||
}
|
||||
|
||||
sqliteStorage, ok := storage.(*sqliteStorageService)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected storage type %T", storage)
|
||||
}
|
||||
defer sqliteStorage.Close()
|
||||
|
||||
requests := []struct {
|
||||
id string
|
||||
model string
|
||||
}{
|
||||
{id: "1", model: "claude-3-5-sonnet"},
|
||||
{id: "2", model: "gpt-4o"},
|
||||
{id: "3", model: "claude-3-5-sonnet"},
|
||||
{id: "4", model: "gpt-4o-mini"},
|
||||
}
|
||||
|
||||
for i, req := range requests {
|
||||
_, err := storage.SaveRequest(&model.RequestLog{
|
||||
RequestID: req.id,
|
||||
Timestamp: time.Date(2026, 3, 19, 12, 0, i, 0, time.UTC).Format(time.RFC3339),
|
||||
Method: "POST",
|
||||
Endpoint: "/v1/messages",
|
||||
Headers: map[string][]string{"Content-Type": {"application/json"}},
|
||||
Body: map[string]string{"request": fmt.Sprintf("body-%d", i)},
|
||||
Model: req.model,
|
||||
UserAgent: "test",
|
||||
ContentType: "application/json",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SaveRequest() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
got, total, err := storage.GetRequests(1, 1, "gpt")
|
||||
if err != nil {
|
||||
t.Fatalf("GetRequests() error = %v", err)
|
||||
}
|
||||
|
||||
if total != 2 {
|
||||
t.Fatalf("expected filtered total 2, got %d", total)
|
||||
}
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 paginated result, got %d", len(got))
|
||||
}
|
||||
|
||||
if got[0].RequestID != "4" {
|
||||
t.Fatalf("expected newest filtered request ID 4, got %s", got[0].RequestID)
|
||||
}
|
||||
}
|
||||
30
proxy/internal/sse/sse.go
Normal file
30
proxy/internal/sse/sse.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package sse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ForEachLine reads line-oriented SSE content without bufio.Scanner's token limit.
|
||||
func ForEachLine(r io.Reader, fn func(string) error) error {
|
||||
reader := bufio.NewReader(r)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(line) > 0 {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if callErr := fn(line); callErr != nil {
|
||||
return callErr
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
31
proxy/internal/sse/sse_test.go
Normal file
31
proxy/internal/sse/sse_test.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package sse
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestForEachLineHandlesLargeLines(t *testing.T) {
|
||||
largePayload := strings.Repeat("x", 128*1024)
|
||||
input := "data: " + largePayload + "\n\n"
|
||||
|
||||
var lines []string
|
||||
if err := ForEachLine(strings.NewReader(input), func(line string) error {
|
||||
lines = append(lines, line)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("ForEachLine() error = %v", err)
|
||||
}
|
||||
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("expected 2 lines, got %d", len(lines))
|
||||
}
|
||||
|
||||
if lines[0] != "data: "+largePayload {
|
||||
t.Fatalf("unexpected first line length: got %d", len(lines[0]))
|
||||
}
|
||||
|
||||
if lines[1] != "" {
|
||||
t.Fatalf("expected blank separator line, got %q", lines[1])
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue