Harden streaming, pagination, and config loading
This commit is contained in: parent 02c9c76667, commit 6cda36312a
16 changed files with 1079 additions and 244 deletions

.env.example (13 changed lines)
@@ -12,5 +12,16 @@ ANTHROPIC_FORWARD_URL=https://api.anthropic.com
 ANTHROPIC_VERSION=2023-06-01
 ANTHROPIC_MAX_RETRIES=3
 
+# OpenAI Configuration (for subagent routing)
+# OPENAI_API_KEY=your-openai-api-key
+# OPENAI_BASE_URL=https://api.openai.com
+# OPENAI_ALLOW_CLIENT_API_KEY=false
+# OPENAI_CLIENT_API_KEY_HEADER=x-openai-api-key
+
 # Storage Configuration
-DATABASE_PATH=requests.db
+DB_PATH=requests.db
+
+# CORS Configuration (comma-separated values)
+# CORS_ALLOWED_ORIGINS=*
+# CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
+# CORS_ALLOWED_HEADERS=*
@@ -24,20 +24,50 @@ providers:
   anthropic:
     # Base URL for Anthropic API (can be changed for custom endpoints)
     base_url: "https://api.anthropic.com"
 
     # Maximum number of retries for failed requests
     max_retries: 3
 
   # OpenAI configuration
   openai:
     # API key for OpenAI
    # Can also be set via OPENAI_API_KEY environment variable
     # api_key: "..."
 
     # Base URL for OpenAI API (can be changed for custom endpoints)
     # Can also be set via OPENAI_BASE_URL environment variable
     # base_url: "https://api.openai.com"
 
+    # Allow clients to provide their own API key via header
+    # Can also be set via OPENAI_ALLOW_CLIENT_API_KEY environment variable
+    allow_client_api_key: false
+
+    # Header name for client-provided API key (default: x-openai-api-key)
+    # Can also be set via OPENAI_CLIENT_API_KEY_HEADER environment variable
+    client_api_key_header: "x-openai-api-key"
+
+# CORS Configuration
+# Controls Cross-Origin Resource Sharing for the web UI
+cors:
+  # Allowed origins (use ["*"] for all origins - not recommended for production)
+  # Can also be set via CORS_ALLOWED_ORIGINS environment variable (comma-separated)
+  allowed_origins:
+    - "*"
+
+  # Allowed HTTP methods
+  # Can also be set via CORS_ALLOWED_METHODS environment variable (comma-separated)
+  allowed_methods:
+    - "GET"
+    - "POST"
+    - "PUT"
+    - "DELETE"
+    - "OPTIONS"
+
+  # Allowed headers (use ["*"] for all headers)
+  # Can also be set via CORS_ALLOWED_HEADERS environment variable (comma-separated)
+  allowed_headers:
+    - "*"
+
 # Storage configuration
 storage:
   # SQLite database path for storing request history

@@ -69,23 +99,30 @@ subagents:
 # The following environment variables will override the YAML configuration:
 #
 # Server:
 #   PORT - Server port
 #   READ_TIMEOUT - Read timeout duration
 #   WRITE_TIMEOUT - Write timeout duration
 #   IDLE_TIMEOUT - Idle timeout duration
 #
 # Anthropic:
 #   ANTHROPIC_FORWARD_URL - Anthropic base URL
 #   ANTHROPIC_VERSION - Anthropic API version
 #   ANTHROPIC_MAX_RETRIES - Maximum retries for Anthropic requests
 #
 # OpenAI:
 #   OPENAI_API_KEY - OpenAI API key
 #   OPENAI_BASE_URL - OpenAI base URL
+#   OPENAI_ALLOW_CLIENT_API_KEY - Allow client-provided API keys (true/false)
+#   OPENAI_CLIENT_API_KEY_HEADER - Header name for client API key
 #
 # Storage:
 #   DB_PATH - Database file path
+#
+# CORS:
+#   CORS_ALLOWED_ORIGINS - Comma-separated allowed origins
+#   CORS_ALLOWED_METHODS - Comma-separated allowed methods
+#   CORS_ALLOWED_HEADERS - Comma-separated allowed headers
 #
 # Subagents:
 #   SUBAGENT_MAPPINGS - Comma-separated subagent:model pairs
 #   Example: "code-reviewer:claude-3-5-sonnet"
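Note: the override order documented above is defaults first, then config.yaml, then environment variables. A quick way to pin that down is a test along these lines; Load() and PORT are names from this commit, the rest of the sketch is illustrative and would sit next to config_test.go.

    // Sketch only: an environment variable should win over the YAML value,
    // because Load() applies env overrides after reading config.yaml.
    func TestEnvOverridesYAML(t *testing.T) {
        t.Setenv("PORT", "9999")

        cfg, err := Load()
        if err != nil {
            t.Fatalf("Load() error = %v", err)
        }
        if cfg.Server.Port != "9999" {
            t.Fatalf("expected PORT override to win, got %q", cfg.Server.Port)
        }
    }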
@@ -50,9 +50,9 @@ func main() {
     r := mux.NewRouter()
 
     corsHandler := handlers.CORS(
-        handlers.AllowedOrigins([]string{"*"}),
-        handlers.AllowedMethods([]string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}),
-        handlers.AllowedHeaders([]string{"*"}),
+        handlers.AllowedOrigins(cfg.CORS.AllowedOrigins),
+        handlers.AllowedMethods(cfg.CORS.AllowedMethods),
+        handlers.AllowedHeaders(cfg.CORS.AllowedHeaders),
     )
 
     r.Use(middleware.Logging)
@@ -1,9 +1,11 @@
 package config
 
 import (
+    "fmt"
     "os"
     "path/filepath"
     "strconv"
+    "strings"
     "time"
 
     "github.com/joho/godotenv"

@@ -15,9 +17,16 @@ type Config struct {
     Providers ProvidersConfig `yaml:"providers"`
     Storage   StorageConfig   `yaml:"storage"`
     Subagents SubagentsConfig `yaml:"subagents"`
+    CORS      CORSConfig      `yaml:"cors"`
     Anthropic AnthropicConfig
 }
 
+type CORSConfig struct {
+    AllowedOrigins []string `yaml:"allowed_origins"`
+    AllowedMethods []string `yaml:"allowed_methods"`
+    AllowedHeaders []string `yaml:"allowed_headers"`
+}
+
 type ServerConfig struct {
     Port     string         `yaml:"port"`
     Timeouts TimeoutsConfig `yaml:"timeouts"`

@@ -45,8 +54,10 @@ type AnthropicProviderConfig struct {
 }
 
 type OpenAIProviderConfig struct {
     BaseURL string `yaml:"base_url"`
     APIKey  string `yaml:"api_key"`
+    AllowClientAPIKey  bool   `yaml:"allow_client_api_key"`  // Allow clients to provide their own API key
+    ClientAPIKeyHeader string `yaml:"client_api_key_header"` // Header name for client API key (default: x-openai-api-key)
 }
 
 type AnthropicConfig struct {

@@ -92,8 +103,10 @@ func Load() (*Config, error) {
             MaxRetries: 3,
         },
         OpenAI: OpenAIProviderConfig{
             BaseURL: "https://api.openai.com",
             APIKey:  "",
+            AllowClientAPIKey:  false,
+            ClientAPIKeyHeader: "x-openai-api-key",
         },
     },
     Storage: StorageConfig{

@@ -103,25 +116,17 @@ func Load() (*Config, error) {
             Enable:   false,
             Mappings: make(map[string]string),
         },
+        CORS: CORSConfig{
+            AllowedOrigins: []string{"*"},
+            AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
+            AllowedHeaders: []string{"*"},
+        },
     }
 
-    // Try to load config.yaml from the project root
-    // The proxy binary is in proxy/ directory, config.yaml is in the parent
-    configPath := filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml")
-
-    // If that doesn't work, try relative to current directory
-    if _, err := os.Stat(configPath); err != nil {
-        // Try common locations relative to where the binary might be run
-        for _, tryPath := range []string{"config.yaml", "../config.yaml", "../../config.yaml"} {
-            if _, err := os.Stat(tryPath); err == nil {
-                configPath = tryPath
-                break
-            }
-        }
+    if err := loadFirstAvailableConfig(cfg, candidateConfigPaths()); err != nil {
+        return nil, err
     }
-
-    cfg.loadFromFile(configPath)
 
     // Apply environment variable overrides AFTER loading from file
     if envPort := os.Getenv("PORT"); envPort != "" {
         cfg.Server.Port = envPort

@@ -154,12 +159,29 @@ func Load() (*Config, error) {
     if envKey := os.Getenv("OPENAI_API_KEY"); envKey != "" {
         cfg.Providers.OpenAI.APIKey = envKey
     }
+    if envAllow := os.Getenv("OPENAI_ALLOW_CLIENT_API_KEY"); envAllow != "" {
+        cfg.Providers.OpenAI.AllowClientAPIKey = envAllow == "true" || envAllow == "1"
+    }
+    if envHeader := os.Getenv("OPENAI_CLIENT_API_KEY_HEADER"); envHeader != "" {
+        cfg.Providers.OpenAI.ClientAPIKeyHeader = envHeader
+    }
 
     // Override storage settings
     if envPath := os.Getenv("DB_PATH"); envPath != "" {
         cfg.Storage.DBPath = envPath
     }
 
+    // Override CORS settings (comma-separated values)
+    if envOrigins := os.Getenv("CORS_ALLOWED_ORIGINS"); envOrigins != "" {
+        cfg.CORS.AllowedOrigins = splitAndTrim(envOrigins)
+    }
+    if envMethods := os.Getenv("CORS_ALLOWED_METHODS"); envMethods != "" {
+        cfg.CORS.AllowedMethods = splitAndTrim(envMethods)
+    }
+    if envHeaders := os.Getenv("CORS_ALLOWED_HEADERS"); envHeaders != "" {
+        cfg.CORS.AllowedHeaders = splitAndTrim(envHeaders)
+    }
+
     // Sync legacy Anthropic config
     cfg.Anthropic = AnthropicConfig{
         BaseURL: cfg.Providers.Anthropic.BaseURL,

@@ -203,6 +225,45 @@ func (c *Config) loadFromFile(path string) error {
     return yaml.Unmarshal(data, c)
 }
 
+func candidateConfigPaths() []string {
+    paths := []string{
+        filepath.Join(filepath.Dir(os.Args[0]), "..", "config.yaml"),
+        "config.yaml",
+        "../config.yaml",
+        "../../config.yaml",
+    }
+
+    seen := make(map[string]struct{}, len(paths))
+    unique := make([]string, 0, len(paths))
+    for _, path := range paths {
+        if _, ok := seen[path]; ok {
+            continue
+        }
+        seen[path] = struct{}{}
+        unique = append(unique, path)
+    }
+
+    return unique
+}
+
+func loadFirstAvailableConfig(cfg *Config, paths []string) error {
+    for _, path := range paths {
+        if _, err := os.Stat(path); err != nil {
+            if os.IsNotExist(err) {
+                continue
+            }
+            return fmt.Errorf("failed to stat config file %q: %w", path, err)
+        }
+
+        if err := cfg.loadFromFile(path); err != nil {
+            return fmt.Errorf("failed to load config file %q: %w", path, err)
+        }
+        return nil
+    }
+
+    return nil
+}
+
 func getEnv(key, defaultValue string) string {
     if value := os.Getenv(key); value != "" {
         return value

@@ -237,3 +298,15 @@ func getInt(key string, defaultValue int) int {
 
     return intValue
 }
+
+func splitAndTrim(s string) []string {
+    parts := strings.Split(s, ",")
+    result := make([]string, 0, len(parts))
+    for _, part := range parts {
+        trimmed := strings.TrimSpace(part)
+        if trimmed != "" {
+            result = append(result, trimmed)
+        }
+    }
+    return result
+}
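Note: for illustration, this is how the comma-separated CORS variables end up being parsed; splitAndTrim is from the hunk above, the sample value is made up.

    // CORS_ALLOWED_ORIGINS="https://a.example, https://b.example," yields
    // []string{"https://a.example", "https://b.example"}: entries are trimmed
    // and empty items from trailing commas are dropped.
    origins := splitAndTrim(os.Getenv("CORS_ALLOWED_ORIGINS"))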
proxy/internal/config/config_test.go (new file, 30 lines)

@@ -0,0 +1,30 @@
+package config
+
+import (
+    "os"
+    "path/filepath"
+    "testing"
+)
+
+func TestLoadFirstAvailableConfigReturnsParseError(t *testing.T) {
+    tempDir := t.TempDir()
+    configPath := filepath.Join(tempDir, "config.yaml")
+
+    if err := os.WriteFile(configPath, []byte("server: ["), 0o600); err != nil {
+        t.Fatalf("WriteFile() error = %v", err)
+    }
+
+    cfg := &Config{}
+    if err := loadFirstAvailableConfig(cfg, []string{configPath}); err == nil {
+        t.Fatal("expected parse error, got nil")
+    }
+}
+
+func TestLoadFirstAvailableConfigSkipsMissingFiles(t *testing.T) {
+    cfg := &Config{}
+    if err := loadFirstAvailableConfig(cfg, []string{
+        filepath.Join(t.TempDir(), "missing.yaml"),
+    }); err != nil {
+        t.Fatalf("expected nil error for missing config, got %v", err)
+    }
+}
@@ -1,7 +1,6 @@
 package handler
 
 import (
-    "bufio"
     "bytes"
     "crypto/rand"
     "encoding/hex"

@@ -20,6 +19,7 @@ import (
 
     "github.com/seifghazi/claude-code-monitor/internal/model"
     "github.com/seifghazi/claude-code-monitor/internal/service"
+    "github.com/seifghazi/claude-code-monitor/internal/sse"
 )
 
 type Handler struct {

@@ -179,37 +179,13 @@ func (h *Handler) GetRequests(w http.ResponseWriter, r *http.Request) {
         modelFilter = "all"
     }
 
-    // Get all requests with model filter applied at storage level
-    allRequests, err := h.storageService.GetAllRequests(modelFilter)
+    requests, total, err := h.storageService.GetRequests(page, limit, modelFilter)
     if err != nil {
         log.Printf("Error getting requests: %v", err)
         http.Error(w, "Failed to get requests", http.StatusInternalServerError)
         return
     }
 
-    // Convert pointers to values for consistency
-    requests := make([]model.RequestLog, len(allRequests))
-    for i, req := range allRequests {
-        if req != nil {
-            requests[i] = *req
-        }
-    }
-
-    // Calculate total before pagination
-    total := len(requests)
-
-    // Apply pagination
-    start := (page - 1) * limit
-    end := start + limit
-    if start >= len(requests) {
-        requests = []model.RequestLog{}
-    } else {
-        if end > len(requests) {
-            end = len(requests)
-        }
-        requests = requests[start:end]
-    }
-
     w.Header().Set("Content-Type", "application/json")
     json.NewEncoder(w).Encode(struct {
         Requests []model.RequestLog `json:"requests"`

@@ -242,6 +218,8 @@ func (h *Handler) NotFound(w http.ResponseWriter, r *http.Request) {
 }
 
 func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
+    // Forward important upstream headers (rate limits, request IDs, etc.)
+    ForwardResponseHeaders(w, resp)
 
     w.Header().Set("Content-Type", "text/event-stream")
     w.Header().Set("Cache-Control", "no-cache")

@@ -278,16 +256,17 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
     var messageID string
     var modelName string
     var stopReason string
+    var sawMessageStop bool
 
-    scanner := bufio.NewScanner(resp.Body)
-    for scanner.Scan() {
-        line := scanner.Text()
+    streamErr := sse.ForEachLine(resp.Body, func(line string) error {
         if line == "" || !strings.HasPrefix(line, "data:") {
-            continue
+            return nil
         }
 
         streamingChunks = append(streamingChunks, line)
-        fmt.Fprintf(w, "%s\n\n", line)
+        if _, err := fmt.Fprintf(w, "%s\n\n", line); err != nil {
+            return err
+        }
         if f, ok := w.(http.Flusher); ok {
             f.Flush()
         }

@@ -298,7 +277,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
         var genericEvent map[string]interface{}
         if err := json.Unmarshal([]byte(jsonData), &genericEvent); err != nil {
             log.Printf("⚠️ Error unmarshalling streaming event: %v", err)
-            continue
+            return nil
         }
 
         // Capture metadata from message_start event

@@ -347,7 +326,7 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
         var event model.StreamingEvent
         if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
             // Skip if structured parsing fails, but we already got the usage data above
-            continue
+            return nil
         }
 
         switch event.Type {

@@ -366,8 +345,13 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
                 toolCalls = append(toolCalls, *event.ContentBlock)
             }
         case "message_stop":
-            // End of stream - scanner will exit on its own
+            sawMessageStop = true
         }
+        return nil
+    })
+
+    if streamErr == nil && !sawMessageStop {
+        streamErr = io.ErrUnexpectedEOF
     }
 
     responseLog := &model.ResponseLog{

@@ -378,6 +362,9 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
         IsStreaming: true,
         CompletedAt: time.Now().Format(time.RFC3339),
     }
+    if streamErr != nil {
+        responseLog.StreamError = streamErr.Error()
+    }
 
     // Create a structured response body that matches Anthropic's format
     var contentBlocks []model.AnthropicContentBlock

@@ -417,14 +404,31 @@ func (h *Handler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
         log.Printf("❌ Error updating request with streaming response: %v", err)
     }
 
-    if err := scanner.Err(); err != nil {
-        log.Printf("❌ Streaming error: %v", err)
+    if streamErr != nil {
+        log.Printf("❌ Streaming error: %v", streamErr)
+        // Send error event to client in Anthropic streaming format
+        errorEvent := map[string]interface{}{
+            "type": "error",
+            "error": map[string]interface{}{
+                "type":    "stream_error",
+                "message": fmt.Sprintf("Stream interrupted: %v", streamErr),
+            },
+        }
+        if errorJSON, jsonErr := json.Marshal(errorEvent); jsonErr == nil {
+            fmt.Fprintf(w, "data: %s\n\n", errorJSON)
+            if f, ok := w.(http.Flusher); ok {
+                f.Flush()
+            }
+        }
     } else {
         log.Println("✅ Streaming response completed")
     }
 }
 
 func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
+    // Forward important upstream headers (rate limits, request IDs, etc.)
+    ForwardResponseHeaders(w, resp)
+
     responseBytes, err := io.ReadAll(resp.Body)
     if err != nil {
         log.Printf("❌ Error reading Anthropic response: %v", err)

@@ -464,6 +468,7 @@ func (h *Handler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, requestLog *model.RequestLog, startTime time.Time) {
 
     if resp.StatusCode != http.StatusOK {
         log.Printf("❌ Anthropic API error: %d %s", resp.StatusCode, string(responseBytes))
+        // Headers already forwarded at start of function
        w.Header().Set("Content-Type", "application/json")
         w.WriteHeader(resp.StatusCode)
         w.Write(responseBytes)
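Note: the new internal/sse package is imported above but its source is not part of this diff. A minimal ForEachLine could look roughly like the sketch below — an assumption, not the shipped implementation — reading lines with a growable bufio.Reader so oversized events are not cut off at bufio.Scanner's default 64 KB token limit, and stopping on the first read or callback error.

    package sse

    import (
        "bufio"
        "io"
        "strings"
    )

    // ForEachLine (sketch): invoke fn for every line of an SSE stream and stop
    // on the first error from either the reader or the callback.
    func ForEachLine(r io.Reader, fn func(line string) error) error {
        reader := bufio.NewReader(r)
        for {
            line, err := reader.ReadString('\n')
            if line != "" {
                if cbErr := fn(strings.TrimRight(line, "\r\n")); cbErr != nil {
                    return cbErr
                }
            }
            if err == io.EOF {
                return nil
            }
            if err != nil {
                return err
            }
        }
    }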
@@ -11,6 +11,68 @@ import (
     "github.com/seifghazi/claude-code-monitor/internal/model"
 )
 
+// Headers that should be forwarded from upstream responses to clients
+var forwardableResponseHeaders = []string{
+    // Anthropic rate limit headers
+    "anthropic-ratelimit-requests-limit",
+    "anthropic-ratelimit-requests-remaining",
+    "anthropic-ratelimit-requests-reset",
+    "anthropic-ratelimit-tokens-limit",
+    "anthropic-ratelimit-tokens-remaining",
+    "anthropic-ratelimit-tokens-reset",
+    // Standard rate limit headers
+    "x-ratelimit-limit",
+    "x-ratelimit-remaining",
+    "x-ratelimit-reset",
+    "retry-after",
+    // Request tracking
+    "x-request-id",
+    "request-id",
+    // Anthropic specific
+    "anthropic-organization-id",
+    // OpenAI specific
+    "openai-organization",
+    "openai-processing-ms",
+    "openai-version",
+    "x-request-id",
+}
+
+// ForwardResponseHeaders copies important headers from upstream response to client response
+func ForwardResponseHeaders(w http.ResponseWriter, resp *http.Response) {
+    for _, header := range forwardableResponseHeaders {
+        if values := resp.Header.Values(header); len(values) > 0 {
+            for _, value := range values {
+                w.Header().Add(header, value)
+            }
+        }
+    }
+}
+
+// CopyAllResponseHeaders copies all non-hop-by-hop headers from upstream to client
+func CopyAllResponseHeaders(w http.ResponseWriter, resp *http.Response) {
+    hopByHopHeaders := map[string]bool{
+        "connection":          true,
+        "keep-alive":          true,
+        "proxy-authenticate":  true,
+        "proxy-authorization": true,
+        "te":                  true,
+        "trailers":            true,
+        "transfer-encoding":   true,
+        "upgrade":             true,
+        "content-encoding":    true, // We handle decompression ourselves
+        "content-length":      true, // May change after decompression
+    }
+
+    for key, values := range resp.Header {
+        if hopByHopHeaders[strings.ToLower(key)] {
+            continue
+        }
+        for _, value := range values {
+            w.Header().Add(key, value)
+        }
+    }
+}
+
 // SanitizeHeaders removes sensitive headers before logging/storage
 func SanitizeHeaders(headers http.Header) http.Header {
     sanitized := make(http.Header)
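Note: a small check of the allow-list behaviour, using only ForwardResponseHeaders from the hunk above; the test name and sample values are illustrative.

    func TestForwardResponseHeadersCopiesRateLimitHeaders(t *testing.T) {
        upstream := &http.Response{Header: http.Header{}}
        upstream.Header.Set("anthropic-ratelimit-requests-remaining", "42")
        upstream.Header.Set("X-Secret-Internal", "do-not-copy") // not in the allow list

        rec := httptest.NewRecorder()
        ForwardResponseHeaders(rec, upstream)

        if got := rec.Header().Get("anthropic-ratelimit-requests-remaining"); got != "42" {
            t.Fatalf("expected rate limit header to be forwarded, got %q", got)
        }
        if rec.Header().Get("X-Secret-Internal") != "" {
            t.Fatal("unexpected header forwarded")
        }
    }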
@@ -45,6 +45,7 @@ type ResponseLog struct {
     Headers         map[string][]string `json:"headers"`
     Body            json.RawMessage     `json:"body,omitempty"`
     BodyText        string              `json:"bodyText,omitempty"`
+    StreamError     string              `json:"streamError,omitempty"`
     ResponseTime    int64               `json:"responseTime"`
     StreamingChunks []string            `json:"streamingChunks,omitempty"`
     IsStreaming     bool                `json:"isStreaming"`
@@ -66,8 +66,14 @@ func (p *AnthropicProvider) ForwardRequest(ctx context.Context, originalReq *http.Request)
         proxyReq.Header.Set("anthropic-version", p.config.Version)
     }
 
-    // Support gzip encoding
-    proxyReq.Header.Set("Accept-Encoding", "gzip")
+    // Handle Accept-Encoding: We accept gzip from upstream for efficiency,
+    // but we decompress before forwarding to the client. This is transparent
+    // to the client - they receive uncompressed data regardless of what they requested.
+    // We preserve gzip if client already requested it, otherwise add it.
+    clientEncoding := proxyReq.Header.Get("Accept-Encoding")
+    if clientEncoding == "" || !strings.Contains(clientEncoding, "gzip") {
+        proxyReq.Header.Set("Accept-Encoding", "gzip")
+    }
 
     // Forward the request
     resp, err := p.client.Do(proxyReq)
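Note: this hunk only negotiates Accept-Encoding; the matching decompression on the response path is outside this excerpt. A typical shape for it — an assumption, not code from this commit — is:

    // decodeBody returns a reader for the upstream body: the raw body, or a
    // gzip reader when the upstream compressed the response.
    func decodeBody(resp *http.Response) (io.ReadCloser, error) {
        if strings.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
            gz, err := gzip.NewReader(resp.Body)
            if err != nil {
                return nil, err
            }
            return gz, nil
        }
        return resp.Body, nil
    }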
@@ -1,7 +1,6 @@
 package provider
 
 import (
-    "bufio"
     "bytes"
     "compress/gzip"
     "context"

@@ -15,6 +14,7 @@ import (
 
     "github.com/seifghazi/claude-code-monitor/internal/config"
     "github.com/seifghazi/claude-code-monitor/internal/model"
+    "github.com/seifghazi/claude-code-monitor/internal/sse"
 )
 
 type OpenAIProvider struct {

@@ -78,13 +78,29 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.Request)
     // Remove Anthropic-specific headers
     proxyReq.Header.Del("anthropic-version")
     proxyReq.Header.Del("x-api-key")
+    proxyReq.Header.Del("Authorization")
+
+    // Determine which API key to use
+    apiKey := p.config.APIKey
+
+    // Check for client-provided API key if allowed
+    if p.config.AllowClientAPIKey && p.config.ClientAPIKeyHeader != "" {
+        if clientKey := originalReq.Header.Get(p.config.ClientAPIKeyHeader); clientKey != "" {
+            apiKey = clientKey
+        }
+    }
 
     // Add OpenAI headers
-    if p.config.APIKey != "" {
-        proxyReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+    if apiKey != "" {
+        proxyReq.Header.Set("Authorization", "Bearer "+apiKey)
     }
     proxyReq.Header.Set("Content-Type", "application/json")
 
+    // Remove the client API key header from the proxied request
+    if p.config.ClientAPIKeyHeader != "" {
+        proxyReq.Header.Del(p.config.ClientAPIKeyHeader)
+    }
+
     // Forward the request
     resp, err := p.client.Do(proxyReq)
     if err != nil {

@@ -139,9 +155,12 @@ func (p *OpenAIProvider) ForwardRequest(ctx context.Context, originalReq *http.Request)
 
     // Start a goroutine to transform the stream
     go func() {
-        defer pw.Close()
         defer bodyReader.Close()
-        transformOpenAIStreamToAnthropic(bodyReader, pw)
+        if err := transformOpenAIStreamToAnthropic(bodyReader, pw); err != nil {
+            _ = pw.CloseWithError(err)
+            return
+        }
+        _ = pw.Close()
     }()
 
     // Replace the response body with our transformed stream

@@ -332,10 +351,10 @@ func convertAnthropicToOpenAI(req *model.AnthropicRequest) map[string]interface{} {
             })
         }
     }
-    // Check if max_tokens exceeds the model's limit and cap it if necessary
-    maxTokensLimit := 16384 // Assuming this is the limit for the model
-    if req.MaxTokens > maxTokensLimit {
-        // Capping max_tokens to model limit
+    // Get model-specific max token limit
+    // Let the API handle validation for unknown models rather than using arbitrary caps
+    maxTokensLimit := getModelMaxTokens(req.Model)
+    if maxTokensLimit > 0 && req.MaxTokens > maxTokensLimit {
         req.MaxTokens = maxTokensLimit
     }
 
@@ -473,6 +492,52 @@ func min(a, b int) int {
     return b
 }
 
+// getModelMaxTokens returns the max output tokens for known models
+// Returns 0 for unknown models, letting the API handle validation
+func getModelMaxTokens(model string) int {
+    // Model-specific max completion token limits
+    modelLimits := map[string]int{
+        // GPT-4 Turbo and GPT-4o models
+        "gpt-4-turbo":         4096,
+        "gpt-4-turbo-preview": 4096,
+        "gpt-4o":              16384,
+        "gpt-4o-mini":         16384,
+        "gpt-4o-2024-05-13":   16384,
+        "gpt-4o-2024-08-06":   16384,
+        // GPT-4 models
+        "gpt-4":      8192,
+        "gpt-4-32k":  8192,
+        "gpt-4-0613": 8192,
+        // GPT-3.5 models
+        "gpt-3.5-turbo":      4096,
+        "gpt-3.5-turbo-16k":  4096,
+        "gpt-3.5-turbo-0125": 4096,
+        "gpt-3.5-turbo-1106": 4096,
+        // o1 reasoning models
+        "o1":         100000,
+        "o1-preview": 32768,
+        "o1-mini":    65536,
+        // o3 reasoning models (estimated based on o1 patterns)
+        "o3":      100000,
+        "o3-mini": 65536,
+    }
+
+    // Check for exact match first
+    if limit, ok := modelLimits[model]; ok {
+        return limit
+    }
+
+    // Check for prefix matches for versioned models
+    for prefix, limit := range modelLimits {
+        if strings.HasPrefix(model, prefix) {
+            return limit
+        }
+    }
+
+    // Return 0 for unknown models - let the API validate
+    return 0
+}
+
 func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
     // This is a simplified transformation
     // In production, you'd want to handle all fields properly

@@ -579,19 +644,16 @@ func transformOpenAIResponseToAnthropic(respBody []byte) []byte {
     return result
 }
 
-func transformOpenAIStreamToAnthropic(openAIStream io.ReadCloser, anthropicStream io.Writer) {
-    defer openAIStream.Close()
-
-    scanner := bufio.NewScanner(openAIStream)
+func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
     var messageStarted bool
     var contentStarted bool
+    var sawDone bool
 
-    for scanner.Scan() {
-        line := scanner.Text()
-
+    err := sse.ForEachLine(openAIStream, func(line string) error {
         // Skip empty lines
         if line == "" {
-            continue
+            return nil
         }
 
         // Handle SSE data lines

@@ -600,21 +662,28 @@ func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
 
         // Handle end of stream
         if data == "[DONE]" {
+            sawDone = true
             // Send Anthropic-style completion
             if contentStarted {
-                fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
+                if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n"); err != nil {
+                    return err
+                }
             }
             if messageStarted {
-                fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n")
-                fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n")
+                if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null}}\n\n"); err != nil {
+                    return err
+                }
+                if _, err := fmt.Fprintf(anthropicStream, "data: {\"type\":\"message_stop\"}\n\n"); err != nil {
+                    return err
+                }
             }
-            break
+            return nil
         }
 
         // Parse OpenAI response
         var openAIChunk map[string]interface{}
         if err := json.Unmarshal([]byte(data), &openAIChunk); err != nil {
-            continue
+            return nil
        }
 
         // Check for usage data BEFORE processing choices

@@ -644,7 +713,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
                 "usage": anthropicUsage,
             }
             usageJSON, _ := json.Marshal(usageDelta)
-            fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON)
+            if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", usageJSON); err != nil {
+                return err
+            }
         }
     }
 

@@ -652,17 +723,17 @@ func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
         choices, ok := openAIChunk["choices"].([]interface{})
         if !ok || len(choices) == 0 {
             // Skip further processing if no choices, but we already handled usage above
-            continue
+            return nil
         }
 
         choice, ok := choices[0].(map[string]interface{})
         if !ok {
-            continue
+            return nil
         }
 
         delta, ok := choice["delta"].(map[string]interface{})
         if !ok {
-            continue
+            return nil
         }
 
         // Handle first chunk - send message_start

@@ -684,7 +755,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
             },
         }
         startJSON, _ := json.Marshal(messageStart)
-        fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON)
+        if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", startJSON); err != nil {
+            return err
+        }
     }
 
     // Handle content

@@ -701,7 +774,9 @@ func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
             },
         }
         blockStartJSON, _ := json.Marshal(blockStart)
-        fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON)
+        if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", blockStartJSON); err != nil {
+            return err
+        }
     }
 
     // Send content_block_delta

@@ -714,9 +789,22 @@ func transformOpenAIStreamToAnthropic(openAIStream io.Reader, anthropicStream io.Writer) error {
             },
         }
         deltaJSON, _ := json.Marshal(contentDelta)
-        fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON)
+        if _, err := fmt.Fprintf(anthropicStream, "data: %s\n\n", deltaJSON); err != nil {
+            return err
+        }
     }
 
     }
+        return nil
+    })
+
+    if err != nil {
+        return err
     }
+
+    if !sawDone {
+        return io.ErrUnexpectedEOF
+    }
+
+    return nil
 }
proxy/internal/provider/openai_test.go (new file, 63 lines)

@@ -0,0 +1,63 @@
+package provider
+
+import (
+    "context"
+    "io"
+    "net/http"
+    "net/http/httptest"
+    "strings"
+    "testing"
+
+    "github.com/seifghazi/claude-code-monitor/internal/config"
+)
+
+func TestOpenAIProviderForwardRequestClearsInboundAuthorization(t *testing.T) {
+    var gotAuthorization string
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        gotAuthorization = r.Header.Get("Authorization")
+        w.Header().Set("Content-Type", "application/json")
+        _, _ = w.Write([]byte(`{"id":"resp_1","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"ok"}}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`))
+    }))
+    defer server.Close()
+
+    provider := NewOpenAIProvider(&config.OpenAIProviderConfig{
+        BaseURL: server.URL,
+    }).(*OpenAIProvider)
+
+    req := httptest.NewRequest(http.MethodPost, "http://proxy.local/v1/messages", strings.NewReader(`{"model":"gpt-4o","messages":[],"max_tokens":16}`))
+    req.Header.Set("Authorization", "Bearer should-not-leak")
+    req.Header.Set("Content-Type", "application/json")
+
+    resp, err := provider.ForwardRequest(context.Background(), req)
+    if err != nil {
+        t.Fatalf("ForwardRequest() error = %v", err)
+    }
+    defer resp.Body.Close()
+    _, _ = io.ReadAll(resp.Body)
+
+    if gotAuthorization != "" {
+        t.Fatalf("expected forwarded Authorization header to be empty, got %q", gotAuthorization)
+    }
+}
+
+func TestTransformOpenAIStreamToAnthropicHandlesLargeEvents(t *testing.T) {
+    largeContent := strings.Repeat("x", 128*1024)
+    openAIStream := strings.NewReader("data: {\"id\":\"chatcmpl_1\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"" + largeContent + "\"}}]}\n\n" +
+        "data: [DONE]\n\n")
+
+    var output strings.Builder
+    if err := transformOpenAIStreamToAnthropic(openAIStream, &output); err != nil {
+        t.Fatalf("transformOpenAIStreamToAnthropic() error = %v", err)
+    }
+
+    got := output.String()
+    if !strings.Contains(got, "\"message_start\"") {
+        t.Fatal("expected message_start event in output")
+    }
+    if !strings.Contains(got, largeContent) {
+        t.Fatal("expected large content to be preserved in output")
+    }
+    if !strings.Contains(got, "\"message_stop\"") {
+        t.Fatal("expected message_stop event in output")
+    }
+}
@@ -1,18 +1,33 @@
 package service
 
 import (
+    "io"
+    "time"
+
     "github.com/seifghazi/claude-code-monitor/internal/config"
     "github.com/seifghazi/claude-code-monitor/internal/model"
 )
 
 type StorageService interface {
+    // Core CRUD operations
     SaveRequest(request *model.RequestLog) (string, error)
-    GetRequests(page, limit int) ([]model.RequestLog, int, error)
+    GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error)
+    GetAllRequests(modelFilter string) ([]*model.RequestLog, error)
+    GetRequestByShortID(shortID string) (*model.RequestLog, string, error)
     ClearRequests() (int, error)
+
+    // Update operations
     UpdateRequestWithGrading(requestID string, grade *model.PromptGrade) error
     UpdateRequestWithResponse(request *model.RequestLog) error
-    EnsureDirectoryExists() error
-    GetRequestByShortID(shortID string) (*model.RequestLog, string, error)
+
+    // Maintenance operations
+    DeleteRequestsOlderThan(age time.Duration) (int, error)
+    GetDatabaseStats() (map[string]interface{}, error)
+
+    // Configuration
     GetConfig() *config.StorageConfig
-    GetAllRequests(modelFilter string) ([]*model.RequestLog, error)
+    EnsureDirectoryExists() error
+
+    // Cleanup - implements io.Closer
+    io.Closer
 }
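Note: embedding io.Closer means every StorageService can be shut down uniformly; a minimal wiring sketch (only NewSQLiteStorageService and the config types come from this commit, the surrounding code is illustrative):

    store, err := service.NewSQLiteStorageService(&cfg.Storage)
    if err != nil {
        log.Fatalf("storage init failed: %v", err)
    }
    defer store.Close() // satisfied via the embedded io.Closer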
@ -4,7 +4,9 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
|
|
@ -15,23 +17,63 @@ import (
|
||||||
type sqliteStorageService struct {
|
type sqliteStorageService struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
config *config.StorageConfig
|
config *config.StorageConfig
|
||||||
|
logger *log.Logger
|
||||||
|
|
||||||
|
// Prepared statements for frequently used queries
|
||||||
|
stmtInsertRequest *sql.Stmt
|
||||||
|
stmtUpdateResponse *sql.Stmt
|
||||||
|
stmtUpdateGrading *sql.Stmt
|
||||||
|
stmtGetRequestByID *sql.Stmt
|
||||||
|
stmtGetRequestsPage *sql.Stmt
|
||||||
|
stmtGetRequestsCount *sql.Stmt
|
||||||
|
stmtDeleteOldRequests *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSQLiteStorageService(cfg *config.StorageConfig) (StorageService, error) {
|
func NewSQLiteStorageService(cfg *config.StorageConfig) (StorageService, error) {
|
||||||
db, err := sql.Open("sqlite3", cfg.DBPath)
|
return NewSQLiteStorageServiceWithLogger(cfg, log.Default())
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLiteStorageServiceWithLogger(cfg *config.StorageConfig, logger *log.Logger) (StorageService, error) {
|
||||||
|
// Enable WAL mode and other optimizations via connection string
|
||||||
|
// _journal_mode=WAL: Write-Ahead Logging for better concurrent read performance
|
||||||
|
// _synchronous=NORMAL: Good balance of safety and performance
|
||||||
|
// _busy_timeout=5000: Wait up to 5 seconds if database is locked
|
||||||
|
// _cache_size=-20000: Use 20MB of memory for cache (negative = KB)
|
||||||
|
connStr := cfg.DBPath + "?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000"
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", connStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure connection pool
|
||||||
|
// SQLite only supports one writer at a time, but can handle multiple readers
|
||||||
|
db.SetMaxOpenConns(1) // Serialize writes to avoid SQLITE_BUSY errors
|
||||||
|
db.SetMaxIdleConns(1)
|
||||||
|
db.SetConnMaxLifetime(time.Hour)
|
||||||
|
|
||||||
|
// Verify connection
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
service := &sqliteStorageService{
|
service := &sqliteStorageService{
|
||||||
db: db,
|
db: db,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := service.createTables(); err != nil {
|
if err := service.createTables(); err != nil {
|
||||||
|
db.Close()
|
||||||
return nil, fmt.Errorf("failed to create tables: %w", err)
|
return nil, fmt.Errorf("failed to create tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := service.prepareStatements(); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("failed to prepare statements: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -39,7 +81,7 @@ func (s *sqliteStorageService) createTables() error {
|
||||||
schema := `
|
schema := `
|
||||||
CREATE TABLE IF NOT EXISTS requests (
|
CREATE TABLE IF NOT EXISTS requests (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME NOT NULL,
|
||||||
method TEXT NOT NULL,
|
method TEXT NOT NULL,
|
||||||
endpoint TEXT NOT NULL,
|
endpoint TEXT NOT NULL,
|
||||||
headers TEXT NOT NULL,
|
headers TEXT NOT NULL,
|
||||||
|
|
@ -50,17 +92,100 @@ func (s *sqliteStorageService) createTables() error {
|
||||||
response TEXT,
|
response TEXT,
|
||||||
model TEXT,
|
model TEXT,
|
||||||
original_model TEXT,
|
original_model TEXT,
|
||||||
routed_model TEXT,
|
routed_model TEXT
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_timestamp ON requests(timestamp DESC);
|
-- Index for listing requests by time (most common query)
|
||||||
CREATE INDEX IF NOT EXISTS idx_endpoint ON requests(endpoint);
|
CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC);
|
||||||
CREATE INDEX IF NOT EXISTS idx_model ON requests(model);
|
|
||||||
|
-- Index for filtering by model
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_requests_model ON requests(model);
|
||||||
|
|
||||||
|
-- Index for filtering by endpoint
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_requests_endpoint ON requests(endpoint);
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err := s.db.Exec(schema)
|
_, err := s.db.Exec(schema)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run migrations
|
||||||
|
s.migrateSchema()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sqliteStorageService) migrateSchema() {
|
||||||
|
// Ensure WAL mode is enabled (in case opened without connection string params)
|
||||||
|
_, err := s.db.Exec("PRAGMA journal_mode=WAL")
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Printf("Warning: failed to set WAL mode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop old redundant index if it exists (we renamed to idx_requests_timestamp)
|
||||||
|
s.db.Exec("DROP INDEX IF EXISTS idx_timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sqliteStorageService) prepareStatements() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
s.stmtInsertRequest, err = s.db.Prepare(`
|
||||||
|
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare insert statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stmtUpdateResponse, err = s.db.Prepare(`
|
||||||
|
UPDATE requests SET response = ? WHERE id = ?
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare update response statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stmtUpdateGrading, err = s.db.Prepare(`
|
||||||
|
UPDATE requests SET prompt_grade = ? WHERE id = ?
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare update grading statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stmtGetRequestByID, err = s.db.Prepare(`
|
||||||
|
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||||
|
FROM requests
|
||||||
|
WHERE id = ?
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare get by ID statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stmtGetRequestsPage, err = s.db.Prepare(`
|
||||||
|
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
|
||||||
|
FROM requests
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare get requests page statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stmtGetRequestsCount, err = s.db.Prepare(`
|
||||||
|
SELECT COUNT(*) FROM requests
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare count statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stmtDeleteOldRequests, err = s.db.Prepare(`
|
||||||
|
DELETE FROM requests WHERE timestamp < ?
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare delete old requests statement: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, error) {
|
func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, error) {
|
||||||
|
|
@ -74,12 +199,7 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
||||||
return "", fmt.Errorf("failed to marshal body: %w", err)
|
return "", fmt.Errorf("failed to marshal body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
_, err = s.stmtInsertRequest.Exec(
|
||||||
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model, original_model, routed_model)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err = s.db.Exec(query,
|
|
||||||
request.RequestID,
|
request.RequestID,
|
||||||
request.Timestamp,
|
request.Timestamp,
|
||||||
request.Method,
|
request.Method,
|
||||||
|
|
@ -100,10 +220,24 @@ func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, e
|
||||||
return request.RequestID, nil
|
return request.RequestID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, int, error) {
|
func (s *sqliteStorageService) GetRequests(page, limit int, modelFilter string) ([]model.RequestLog, int, error) {
|
||||||
|
whereClause := ""
|
||||||
|
countArgs := []interface{}{}
|
||||||
|
queryArgs := []interface{}{}
|
||||||
|
|
||||||
|
if modelFilter != "" && modelFilter != "all" {
|
||||||
|
// Escape LIKE special characters to prevent pattern injection
|
||||||
|
escapedFilter := escapeLikePattern(strings.ToLower(modelFilter))
|
||||||
|
whereClause = " WHERE LOWER(model) LIKE ? ESCAPE '\\'"
|
||||||
|
filterValue := "%" + escapedFilter + "%"
|
||||||
|
countArgs = append(countArgs, filterValue)
|
||||||
|
queryArgs = append(queryArgs, filterValue)
|
||||||
|
}
|
||||||
|
|
||||||
// Get total count
|
// Get total count
|
||||||
var total int
|
var total int
|
||||||
err := s.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&total)
|
countQuery := "SELECT COUNT(*) FROM requests" + whereClause
|
||||||
|
err := s.db.QueryRow(countQuery, countArgs...).Scan(&total)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to get total count: %w", err)
|
return nil, 0, fmt.Errorf("failed to get total count: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@@ -112,71 +246,21 @@ func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, int, error) {
    offset := (page - 1) * limit
    query := `
        SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
-       FROM requests
+       FROM requests` + whereClause + `
        ORDER BY timestamp DESC
        LIMIT ? OFFSET ?
    `
+   queryArgs = append(queryArgs, limit, offset)

-   rows, err := s.db.Query(query, limit, offset)
+   rows, err := s.db.Query(query, queryArgs...)
    if err != nil {
        return nil, 0, fmt.Errorf("failed to query requests: %w", err)
    }
    defer rows.Close()

-   var requests []model.RequestLog
-   for rows.Next() {
-       var req model.RequestLog
-       var headersJSON, bodyJSON string
-       var promptGradeJSON, responseJSON sql.NullString
-
-       err := rows.Scan(
-           &req.RequestID,
-           &req.Timestamp,
-           &req.Method,
-           &req.Endpoint,
-           &headersJSON,
-           &bodyJSON,
-           &req.Model,
-           &req.UserAgent,
-           &req.ContentType,
-           &promptGradeJSON,
-           &responseJSON,
-           &req.OriginalModel,
-           &req.RoutedModel,
-       )
-       if err != nil {
-           // Error scanning row - skip
-           continue
-       }
-
-       // Unmarshal JSON fields
-       if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
-           // Error unmarshaling headers
-           continue
-       }
-
-       var body interface{}
-       if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
-           // Error unmarshaling body
-           continue
-       }
-       req.Body = body
-
-       if promptGradeJSON.Valid {
-           var grade model.PromptGrade
-           if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
-               req.PromptGrade = &grade
-           }
-       }
-
-       if responseJSON.Valid {
-           var resp model.ResponseLog
-           if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
-               req.Response = &resp
-           }
-       }
-
-       requests = append(requests, req)
-   }
+   requests, err := s.scanRequestRows(rows)
+   if err != nil {
+       return nil, 0, err
+   }

    return requests, total, nil
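A small runnable sketch of the pagination and LIKE-escaping arithmetic GetRequests now performs; escapeLike below mirrors the escapeLikePattern helper added later in this diff, and the concrete values are illustrative:

package main

import (
    "fmt"
    "strings"
)

// escapeLike strips wildcard meaning from \, % and _ before the value is
// wrapped in %...% and bound to "LIKE ? ESCAPE '\'".
func escapeLike(s string) string {
    s = strings.ReplaceAll(s, `\`, `\\`)
    s = strings.ReplaceAll(s, `%`, `\%`)
    s = strings.ReplaceAll(s, `_`, `\_`)
    return s
}

func main() {
    page, limit := 2, 25
    offset := (page - 1) * limit // feeds LIMIT ? OFFSET ?
    pattern := "%" + escapeLike(strings.ToLower("GPT_4")) + "%"

    fmt.Println(offset)  // 25
    fmt.Println(pattern) // %gpt\_4%
}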
@@ -193,6 +277,12 @@ func (s *sqliteStorageService) ClearRequests() (int, error) {
        return 0, fmt.Errorf("failed to get rows affected: %w", err)
    }

+   // Reclaim space after clearing all data
+   _, err = s.db.Exec("VACUUM")
+   if err != nil {
+       s.logger.Printf("Warning: failed to vacuum database: %v", err)
+   }
+
    return int(rowsAffected), nil
}
@@ -202,12 +292,16 @@ func (s *sqliteStorageService) UpdateRequestWithGrading(requestID string, grade
        return fmt.Errorf("failed to marshal grade: %w", err)
    }

-   query := "UPDATE requests SET prompt_grade = ? WHERE id = ?"
-   _, err = s.db.Exec(query, string(gradeJSON), requestID)
+   result, err := s.stmtUpdateGrading.Exec(string(gradeJSON), requestID)
    if err != nil {
        return fmt.Errorf("failed to update request with grading: %w", err)
    }

+   rowsAffected, _ := result.RowsAffected()
+   if rowsAffected == 0 {
+       return fmt.Errorf("request %s not found", requestID)
+   }
+
    return nil
}
@@ -217,12 +311,72 @@ func (s *sqliteStorageService) UpdateRequestWithResponse(request *model.RequestL
        return fmt.Errorf("failed to marshal response: %w", err)
    }

-   query := "UPDATE requests SET response = ? WHERE id = ?"
-   _, err = s.db.Exec(query, string(responseJSON), request.RequestID)
+   result, err := s.stmtUpdateResponse.Exec(string(responseJSON), request.RequestID)
    if err != nil {
        return fmt.Errorf("failed to update request with response: %w", err)
    }

+   rowsAffected, _ := result.RowsAffected()
+   if rowsAffected == 0 {
+       return fmt.Errorf("request %s not found", request.RequestID)
+   }
+
+   return nil
+}
+
+// SaveRequestWithResponse saves a request and its response in a single transaction
+func (s *sqliteStorageService) SaveRequestWithResponse(request *model.RequestLog) error {
+   tx, err := s.db.Begin()
+   if err != nil {
+       return fmt.Errorf("failed to begin transaction: %w", err)
+   }
+   defer tx.Rollback()
+
+   headersJSON, err := json.Marshal(request.Headers)
+   if err != nil {
+       return fmt.Errorf("failed to marshal headers: %w", err)
+   }
+
+   bodyJSON, err := json.Marshal(request.Body)
+   if err != nil {
+       return fmt.Errorf("failed to marshal body: %w", err)
+   }
+
+   // Insert request
+   _, err = tx.Stmt(s.stmtInsertRequest).Exec(
+       request.RequestID,
+       request.Timestamp,
+       request.Method,
+       request.Endpoint,
+       string(headersJSON),
+       string(bodyJSON),
+       request.UserAgent,
+       request.ContentType,
+       request.Model,
+       request.OriginalModel,
+       request.RoutedModel,
+   )
+   if err != nil {
+       return fmt.Errorf("failed to insert request: %w", err)
+   }
+
+   // Update with response if present
+   if request.Response != nil {
+       responseJSON, err := json.Marshal(request.Response)
+       if err != nil {
+           return fmt.Errorf("failed to marshal response: %w", err)
+       }
+
+       _, err = tx.Stmt(s.stmtUpdateResponse).Exec(string(responseJSON), request.RequestID)
+       if err != nil {
+           return fmt.Errorf("failed to update response: %w", err)
+       }
+   }
+
+   if err := tx.Commit(); err != nil {
+       return fmt.Errorf("failed to commit transaction: %w", err)
+   }
+
    return nil
}
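A minimal sketch of the transaction shape used by SaveRequestWithResponse, assuming two statements prepared on the same *sql.DB; tx.Stmt rebinds a prepared statement to the transaction so the insert and the response update commit or roll back together (illustrative code, not from this repository):

package example

import "database/sql"

// saveAtomically inserts a row and, if respJSON is non-nil, updates it with
// the response, all inside one transaction.
func saveAtomically(db *sql.DB, insertStmt, updateStmt *sql.Stmt, id string, reqJSON, respJSON []byte) error {
    tx, err := db.Begin()
    if err != nil {
        return err
    }
    defer tx.Rollback() // harmless no-op once Commit succeeds

    if _, err := tx.Stmt(insertStmt).Exec(id, string(reqJSON)); err != nil {
        return err
    }
    if respJSON != nil {
        if _, err := tx.Stmt(updateStmt).Exec(string(respJSON), id); err != nil {
            return err
        }
    }
    return tx.Commit()
}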
@@ -232,10 +386,13 @@ func (s *sqliteStorageService) EnsureDirectoryExists() error {
}

func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
+   // Escape LIKE special characters to prevent pattern injection
+   escapedID := escapeLikePattern(shortID)
+
    query := `
        SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
        FROM requests
-       WHERE id LIKE ?
+       WHERE id LIKE ? ESCAPE '\'
        ORDER BY timestamp DESC
        LIMIT 1
    `
@@ -244,7 +401,7 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
    var headersJSON, bodyJSON string
    var promptGradeJSON, responseJSON sql.NullString

-   err := s.db.QueryRow(query, "%"+shortID).Scan(
+   err := s.db.QueryRow(query, "%"+escapedID).Scan(
        &req.RequestID,
        &req.Timestamp,
        &req.Method,
@@ -267,29 +424,8 @@ func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
        return nil, "", fmt.Errorf("failed to query request: %w", err)
    }

-   // Unmarshal JSON fields
-   if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
-       return nil, "", fmt.Errorf("failed to unmarshal headers: %w", err)
-   }
-
-   var body interface{}
-   if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
-       return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
-   }
-   req.Body = body
-
-   if promptGradeJSON.Valid {
-       var grade model.PromptGrade
-       if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
-           req.PromptGrade = &grade
-       }
-   }
-
-   if responseJSON.Valid {
-       var resp model.ResponseLog
-       if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
-           req.Response = &resp
-       }
-   }
+   if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
+       return nil, "", err
+   }

    return &req, req.RequestID, nil
@@ -300,19 +436,36 @@ func (s *sqliteStorageService) GetConfig() *config.StorageConfig {
}

func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
-   query := `
-       SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
-       FROM requests
-   `
+   return s.GetAllRequestsWithLimit(modelFilter, 0) // 0 means no limit
+}
+
+// GetAllRequestsWithLimit returns requests with an optional limit (0 = no limit)
+func (s *sqliteStorageService) GetAllRequestsWithLimit(modelFilter string, limit int) ([]*model.RequestLog, error) {
+   var query string
    args := []interface{}{}

    if modelFilter != "" && modelFilter != "all" {
-       query += " WHERE LOWER(model) LIKE ?"
-       args = append(args, "%"+strings.ToLower(modelFilter)+"%")
+       // Escape LIKE special characters
+       escapedFilter := escapeLikePattern(strings.ToLower(modelFilter))
+       query = `
+           SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
+           FROM requests
+           WHERE LOWER(model) LIKE ? ESCAPE '\'
+           ORDER BY timestamp DESC
+       `
+       args = append(args, "%"+escapedFilter+"%")
+   } else {
+       query = `
+           SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response, original_model, routed_model
+           FROM requests
+           ORDER BY timestamp DESC
+       `
    }

-   query += " ORDER BY timestamp DESC"
+   if limit > 0 {
+       query += " LIMIT ?"
+       args = append(args, limit)
+   }

    rows, err := s.db.Query(query, args...)
    if err != nil {
@@ -321,6 +474,124 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
    defer rows.Close()

    var requests []*model.RequestLog
+   for rows.Next() {
+       req, err := s.scanSingleRow(rows)
+       if err != nil {
+           s.logger.Printf("Warning: failed to scan request row: %v", err)
+           continue
+       }
+       requests = append(requests, req)
+   }
+
+   if err := rows.Err(); err != nil {
+       return nil, fmt.Errorf("error iterating rows: %w", err)
+   }
+
+   return requests, nil
+}
+
+// DeleteRequestsOlderThan removes requests older than the specified duration
+func (s *sqliteStorageService) DeleteRequestsOlderThan(age time.Duration) (int, error) {
+   cutoff := time.Now().Add(-age)
+
+   result, err := s.stmtDeleteOldRequests.Exec(cutoff.Format(time.RFC3339))
+   if err != nil {
+       return 0, fmt.Errorf("failed to delete old requests: %w", err)
+   }
+
+   rowsAffected, err := result.RowsAffected()
+   if err != nil {
+       return 0, fmt.Errorf("failed to get rows affected: %w", err)
+   }
+
+   return int(rowsAffected), nil
+}
+
+// GetDatabaseStats returns statistics about the database
+func (s *sqliteStorageService) GetDatabaseStats() (map[string]interface{}, error) {
+   stats := make(map[string]interface{})
+
+   // Get row count
+   var count int
+   err := s.stmtGetRequestsCount.QueryRow().Scan(&count)
+   if err != nil {
+       return nil, fmt.Errorf("failed to get count: %w", err)
+   }
+   stats["total_requests"] = count
+
+   // Get database size
+   var pageCount, pageSize int
+   err = s.db.QueryRow("PRAGMA page_count").Scan(&pageCount)
+   if err == nil {
+       err = s.db.QueryRow("PRAGMA page_size").Scan(&pageSize)
+       if err == nil {
+           stats["database_size_bytes"] = pageCount * pageSize
+       }
+   }
+
+   // Get oldest and newest timestamps
+   var oldest, newest sql.NullString
+   err = s.db.QueryRow("SELECT MIN(timestamp), MAX(timestamp) FROM requests").Scan(&oldest, &newest)
+   if err == nil {
+       if oldest.Valid {
+           stats["oldest_request"] = oldest.String
+       }
+       if newest.Valid {
+           stats["newest_request"] = newest.String
+       }
+   }
+
+   return stats, nil
+}
+
+func (s *sqliteStorageService) Close() error {
+   // Close prepared statements
+   if s.stmtInsertRequest != nil {
+       s.stmtInsertRequest.Close()
+   }
+   if s.stmtUpdateResponse != nil {
+       s.stmtUpdateResponse.Close()
+   }
+   if s.stmtUpdateGrading != nil {
+       s.stmtUpdateGrading.Close()
+   }
+   if s.stmtGetRequestByID != nil {
+       s.stmtGetRequestByID.Close()
+   }
+   if s.stmtGetRequestsPage != nil {
+       s.stmtGetRequestsPage.Close()
+   }
+   if s.stmtGetRequestsCount != nil {
+       s.stmtGetRequestsCount.Close()
+   }
+   if s.stmtDeleteOldRequests != nil {
+       s.stmtDeleteOldRequests.Close()
+   }
+
+   // Checkpoint WAL before closing
+   _, err := s.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)")
+   if err != nil {
+       s.logger.Printf("Warning: failed to checkpoint WAL: %v", err)
+   }
+
+   return s.db.Close()
+}
+
+// Helper functions
+
+// escapeLikePattern escapes special characters in LIKE patterns
+func escapeLikePattern(s string) string {
+   // Escape \, %, and _ characters
+   s = strings.ReplaceAll(s, `\`, `\\`)
+   s = strings.ReplaceAll(s, `%`, `\%`)
+   s = strings.ReplaceAll(s, `_`, `\_`)
+   return s
+}
+
+// scanRequestRows scans multiple rows into a slice of RequestLog
+func (s *sqliteStorageService) scanRequestRows(rows *sql.Rows) ([]model.RequestLog, error) {
+   var requests []model.RequestLog
+
    for rows.Next() {
        var req model.RequestLog
        var headersJSON, bodyJSON string
@@ -342,43 +613,86 @@ func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
            &req.RoutedModel,
        )
        if err != nil {
-           // Error scanning row - skip
+           s.logger.Printf("Warning: failed to scan row: %v", err)
            continue
        }

-       // Unmarshal JSON fields
-       if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
-           // Error unmarshaling headers
-           continue
-       }
-
-       var body interface{}
-       if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
-           // Error unmarshaling body
-           continue
-       }
-       req.Body = body
-
-       if promptGradeJSON.Valid {
-           var grade model.PromptGrade
-           if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
-               req.PromptGrade = &grade
-           }
-       }
-
-       if responseJSON.Valid {
-           var resp model.ResponseLog
-           if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
-               req.Response = &resp
-           }
-       }
-
-       requests = append(requests, &req)
+       if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
+           s.logger.Printf("Warning: failed to unmarshal request fields: %v", err)
+           continue
+       }
+
+       requests = append(requests, req)
    }

+   if err := rows.Err(); err != nil {
+       return nil, fmt.Errorf("error iterating rows: %w", err)
+   }
+
    return requests, nil
}

-func (s *sqliteStorageService) Close() error {
-   return s.db.Close()
+// scanSingleRow scans a single row into a RequestLog pointer
+func (s *sqliteStorageService) scanSingleRow(rows *sql.Rows) (*model.RequestLog, error) {
+   var req model.RequestLog
+   var headersJSON, bodyJSON string
+   var promptGradeJSON, responseJSON sql.NullString
+
+   err := rows.Scan(
+       &req.RequestID,
+       &req.Timestamp,
+       &req.Method,
+       &req.Endpoint,
+       &headersJSON,
+       &bodyJSON,
+       &req.Model,
+       &req.UserAgent,
+       &req.ContentType,
+       &promptGradeJSON,
+       &responseJSON,
+       &req.OriginalModel,
+       &req.RoutedModel,
+   )
+   if err != nil {
+       return nil, fmt.Errorf("failed to scan row: %w", err)
+   }
+
+   if err := s.unmarshalRequestFields(&req, headersJSON, bodyJSON, promptGradeJSON, responseJSON); err != nil {
+       return nil, err
+   }
+
+   return &req, nil
+}
+
+// unmarshalRequestFields unmarshals JSON fields into a RequestLog
+func (s *sqliteStorageService) unmarshalRequestFields(req *model.RequestLog, headersJSON, bodyJSON string, promptGradeJSON, responseJSON sql.NullString) error {
+   if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
+       return fmt.Errorf("failed to unmarshal headers: %w", err)
+   }
+
+   var body interface{}
+   if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
+       return fmt.Errorf("failed to unmarshal body: %w", err)
+   }
+   req.Body = body
+
+   if promptGradeJSON.Valid {
+       var grade model.PromptGrade
+       if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err != nil {
+           s.logger.Printf("Warning: failed to unmarshal prompt grade: %v", err)
+       } else {
+           req.PromptGrade = &grade
+       }
+   }
+
+   if responseJSON.Valid {
+       var resp model.ResponseLog
+       if err := json.Unmarshal([]byte(responseJSON.String), &resp); err != nil {
+           s.logger.Printf("Warning: failed to unmarshal response: %v", err)
+       } else {
+           req.Response = &resp
+       }
+   }
+
+   return nil
}
69
proxy/internal/service/storage_sqlite_test.go
Normal file

package service

import (
    "fmt"
    "path/filepath"
    "testing"
    "time"

    "github.com/seifghazi/claude-code-monitor/internal/config"
    "github.com/seifghazi/claude-code-monitor/internal/model"
)

func TestSQLiteStorageServiceGetRequestsUsesSQLPaginationAndFiltering(t *testing.T) {
    dbPath := filepath.Join(t.TempDir(), "requests.db")
    storage, err := NewSQLiteStorageService(&config.StorageConfig{DBPath: dbPath})
    if err != nil {
        t.Fatalf("NewSQLiteStorageService() error = %v", err)
    }

    sqliteStorage, ok := storage.(*sqliteStorageService)
    if !ok {
        t.Fatalf("unexpected storage type %T", storage)
    }
    defer sqliteStorage.Close()

    requests := []struct {
        id    string
        model string
    }{
        {id: "1", model: "claude-3-5-sonnet"},
        {id: "2", model: "gpt-4o"},
        {id: "3", model: "claude-3-5-sonnet"},
        {id: "4", model: "gpt-4o-mini"},
    }

    for i, req := range requests {
        _, err := storage.SaveRequest(&model.RequestLog{
            RequestID:   req.id,
            Timestamp:   time.Date(2026, 3, 19, 12, 0, i, 0, time.UTC).Format(time.RFC3339),
            Method:      "POST",
            Endpoint:    "/v1/messages",
            Headers:     map[string][]string{"Content-Type": {"application/json"}},
            Body:        map[string]string{"request": fmt.Sprintf("body-%d", i)},
            Model:       req.model,
            UserAgent:   "test",
            ContentType: "application/json",
        })
        if err != nil {
            t.Fatalf("SaveRequest() error = %v", err)
        }
    }

    got, total, err := storage.GetRequests(1, 1, "gpt")
    if err != nil {
        t.Fatalf("GetRequests() error = %v", err)
    }

    if total != 2 {
        t.Fatalf("expected filtered total 2, got %d", total)
    }

    if len(got) != 1 {
        t.Fatalf("expected 1 paginated result, got %d", len(got))
    }

    if got[0].RequestID != "4" {
        t.Fatalf("expected newest filtered request ID 4, got %s", got[0].RequestID)
    }
}
30
proxy/internal/sse/sse.go
Normal file

package sse

import (
    "bufio"
    "io"
    "strings"
)

// ForEachLine reads line-oriented SSE content without bufio.Scanner's token limit.
func ForEachLine(r io.Reader, fn func(string) error) error {
    reader := bufio.NewReader(r)

    for {
        line, err := reader.ReadString('\n')
        if err != nil && err != io.EOF {
            return err
        }

        if len(line) > 0 {
            line = strings.TrimRight(line, "\r\n")
            if callErr := fn(line); callErr != nil {
                return callErr
            }
        }

        if err == io.EOF {
            return nil
        }
    }
}
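A small usage sketch: feeding a streamed response body through ForEachLine and collecting "data:" payloads into events. The event-assembly logic is illustrative, not part of this package; the import path assumes the module root shown in the test file above.

package main

import (
    "fmt"
    "strings"

    "github.com/seifghazi/claude-code-monitor/internal/sse"
)

func main() {
    body := strings.NewReader("event: message\ndata: {\"delta\":\"hi\"}\n\ndata: [DONE]\n\n")

    var events []string
    var current strings.Builder

    _ = sse.ForEachLine(body, func(line string) error {
        switch {
        case strings.HasPrefix(line, "data: "):
            current.WriteString(strings.TrimPrefix(line, "data: "))
        case line == "" && current.Len() > 0: // blank line terminates the event
            events = append(events, current.String())
            current.Reset()
        }
        return nil
    })

    fmt.Println(events) // [{"delta":"hi"} [DONE]]
}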
31
proxy/internal/sse/sse_test.go
Normal file

package sse

import (
    "strings"
    "testing"
)

func TestForEachLineHandlesLargeLines(t *testing.T) {
    largePayload := strings.Repeat("x", 128*1024)
    input := "data: " + largePayload + "\n\n"

    var lines []string
    if err := ForEachLine(strings.NewReader(input), func(line string) error {
        lines = append(lines, line)
        return nil
    }); err != nil {
        t.Fatalf("ForEachLine() error = %v", err)
    }

    if len(lines) != 2 {
        t.Fatalf("expected 2 lines, got %d", len(lines))
    }

    if lines[0] != "data: "+largePayload {
        t.Fatalf("unexpected first line length: got %d", len(lines[0]))
    }

    if lines[1] != "" {
        t.Fatalf("expected blank separator line, got %q", lines[1])
    }
}