claude-code-proxy/proxy/internal/service/storage_sqlite.go
Seif Ghazi ae71ec4f72
Ready
2025-06-29 20:50:04 -04:00

386 lines
9.3 KiB
Go

package service
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/seifghazi/claude-code-monitor/internal/config"
"github.com/seifghazi/claude-code-monitor/internal/model"
)
// sqliteStorageService stores request logs in a SQLite database. It is the
// concrete type that NewSQLiteStorageService returns as a StorageService.
type sqliteStorageService struct {
// db is the database handle opened by NewSQLiteStorageService.
db *sql.DB
// config is the storage configuration the service was built with; it is
// handed back unchanged by GetConfig.
config *config.StorageConfig
}
// NewSQLiteStorageService opens the SQLite database at cfg.DBPath, makes sure
// the requests table and its indexes exist, and returns a StorageService
// backed by that database.
//
// It returns an error when cfg is nil, when the database cannot be opened, or
// when the schema cannot be created. If schema creation fails, the database
// handle is closed before returning so it is not leaked.
func NewSQLiteStorageService(cfg *config.StorageConfig) (StorageService, error) {
	if cfg == nil {
		return nil, fmt.Errorf("failed to open database: storage config is nil")
	}

	db, err := sql.Open("sqlite3", cfg.DBPath)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	service := &sqliteStorageService{
		db:     db,
		config: cfg,
	}

	if err := service.createTables(); err != nil {
		// The caller never receives the service on this path, so nobody else
		// can close the handle; release it here.
		if closeErr := db.Close(); closeErr != nil {
			return nil, fmt.Errorf("failed to create tables: %w (closing database also failed: %v)", err, closeErr)
		}
		return nil, fmt.Errorf("failed to create tables: %w", err)
	}

	return service, nil
}
// createTables runs the schema script against the database. Every statement
// uses IF NOT EXISTS, so the requests table and its timestamp, endpoint and
// model indexes are only created when they are missing.
func (s *sqliteStorageService) createTables() error {
	const ddl = `
CREATE TABLE IF NOT EXISTS requests (
id TEXT PRIMARY KEY,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
method TEXT NOT NULL,
endpoint TEXT NOT NULL,
headers TEXT NOT NULL,
body TEXT NOT NULL,
user_agent TEXT,
content_type TEXT,
prompt_grade TEXT,
response TEXT,
model TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_timestamp ON requests(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_endpoint ON requests(endpoint);
CREATE INDEX IF NOT EXISTS idx_model ON requests(model);
`
	if _, execErr := s.db.Exec(ddl); execErr != nil {
		return execErr
	}
	return nil
}
// SaveRequest inserts the given request log as a new row in the requests
// table and returns the ID it was stored under (request.RequestID).
//
// Headers and Body are stored as JSON text. When Body is a JSON object that
// carries a string "model" entry, that value is written to the model column
// and also copied into request.Model; otherwise the model column receives an
// empty string.
func (s *sqliteStorageService) SaveRequest(request *model.RequestLog) (string, error) {
	encodedHeaders, err := json.Marshal(request.Headers)
	if err != nil {
		return "", fmt.Errorf("failed to marshal headers: %w", err)
	}

	encodedBody, err := json.Marshal(request.Body)
	if err != nil {
		return "", fmt.Errorf("failed to marshal body: %w", err)
	}

	// Pull the model name out of the body when the body is a JSON object.
	modelName := ""
	if fields, isObject := request.Body.(map[string]interface{}); isObject {
		if name, isString := fields["model"].(string); isString {
			modelName = name
			request.Model = name // keep the struct in sync with the stored column
		}
	}

	const insertSQL = `
INSERT INTO requests (id, timestamp, method, endpoint, headers, body, user_agent, content_type, model)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	insertArgs := []interface{}{
		request.RequestID,
		request.Timestamp,
		request.Method,
		request.Endpoint,
		string(encodedHeaders),
		string(encodedBody),
		request.UserAgent,
		request.ContentType,
		modelName,
	}
	if _, err := s.db.Exec(insertSQL, insertArgs...); err != nil {
		return "", fmt.Errorf("failed to insert request: %w", err)
	}

	return request.RequestID, nil
}
// GetRequests returns one page of stored requests, newest first, together
// with the total number of rows in the requests table.
//
// page is 1-based: the query skips (page-1)*limit rows and returns at most
// limit rows. A row that cannot be scanned, or whose headers or body JSON
// cannot be decoded, is logged and skipped rather than failing the whole
// page. prompt_grade and response are optional and are attached only when
// the column is non-NULL and decodes cleanly.
//
// An error is returned when the count query fails, when the page query fails,
// or when row iteration stops early because of a database error.
func (s *sqliteStorageService) GetRequests(page, limit int) ([]model.RequestLog, int, error) {
	// Get total count
	var total int
	err := s.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&total)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to get total count: %w", err)
	}

	// Get paginated results
	offset := (page - 1) * limit
	query := `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
FROM requests
ORDER BY timestamp DESC
LIMIT ? OFFSET ?
`
	rows, err := s.db.Query(query, limit, offset)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query requests: %w", err)
	}
	defer rows.Close()

	var requests []model.RequestLog
	for rows.Next() {
		var req model.RequestLog
		var headersJSON, bodyJSON string
		var promptGradeJSON, responseJSON sql.NullString

		err := rows.Scan(
			&req.RequestID,
			&req.Timestamp,
			&req.Method,
			&req.Endpoint,
			&headersJSON,
			&bodyJSON,
			&req.Model,
			&req.UserAgent,
			&req.ContentType,
			&promptGradeJSON,
			&responseJSON,
		)
		if err != nil {
			log.Printf("Error scanning row: %v", err)
			continue
		}

		// Unmarshal JSON fields
		if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
			log.Printf("Error unmarshaling headers: %v", err)
			continue
		}

		var body interface{}
		if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
			log.Printf("Error unmarshaling body: %v", err)
			continue
		}
		req.Body = body

		// Optional columns: a decode failure leaves the field unset.
		if promptGradeJSON.Valid {
			var grade model.PromptGrade
			if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
				req.PromptGrade = &grade
			}
		}

		if responseJSON.Valid {
			var resp model.ResponseLog
			if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
				req.Response = &resp
			}
		}

		requests = append(requests, req)
	}

	// rows.Next returns false both at the end of the result set and on error;
	// without this check a mid-iteration failure would look like a short page.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to iterate requests: %w", err)
	}

	return requests, total, nil
}
// ClearRequests deletes every row from the requests table and reports how
// many rows the database says were removed.
func (s *sqliteStorageService) ClearRequests() (int, error) {
	res, execErr := s.db.Exec("DELETE FROM requests")
	if execErr != nil {
		return 0, fmt.Errorf("failed to clear requests: %w", execErr)
	}

	deleted, countErr := res.RowsAffected()
	if countErr != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", countErr)
	}

	return int(deleted), nil
}
// UpdateRequestWithGrading serializes grade to JSON and writes it into the
// prompt_grade column of the row whose id equals requestID.
func (s *sqliteStorageService) UpdateRequestWithGrading(requestID string, grade *model.PromptGrade) error {
	encoded, marshalErr := json.Marshal(grade)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal grade: %w", marshalErr)
	}

	const updateSQL = "UPDATE requests SET prompt_grade = ? WHERE id = ?"
	if _, execErr := s.db.Exec(updateSQL, string(encoded), requestID); execErr != nil {
		return fmt.Errorf("failed to update request with grading: %w", execErr)
	}

	return nil
}
// UpdateRequestWithResponse serializes request.Response to JSON and writes it
// into the response column of the row whose id equals request.RequestID.
func (s *sqliteStorageService) UpdateRequestWithResponse(request *model.RequestLog) error {
	encoded, marshalErr := json.Marshal(request.Response)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal response: %w", marshalErr)
	}

	const updateSQL = "UPDATE requests SET response = ? WHERE id = ?"
	if _, execErr := s.db.Exec(updateSQL, string(encoded), request.RequestID); execErr != nil {
		return fmt.Errorf("failed to update request with response: %w", execErr)
	}

	return nil
}
// EnsureDirectoryExists is a no-op for the SQLite backend and always returns
// nil.
func (s *sqliteStorageService) EnsureDirectoryExists() error {
// This backend keeps its data in the database, so there is no directory to
// create here.
return nil
}
// GetRequestByShortID looks up the most recent request whose ID ends with
// shortID and returns it together with its full ID.
//
// shortID is matched literally: LIKE metacharacters ('%' and '_') and the
// escape character itself are escaped before the suffix pattern is built, so
// they cannot widen the match to unrelated rows.
//
// An error is returned when no row matches, when the query fails, or when the
// stored headers or body JSON cannot be decoded. prompt_grade and response
// are optional and are attached only when the column is non-NULL and decodes
// cleanly.
func (s *sqliteStorageService) GetRequestByShortID(shortID string) (*model.RequestLog, string, error) {
	query := `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
FROM requests
WHERE id LIKE ? ESCAPE '\'
ORDER BY timestamp DESC
LIMIT 1
`
	// Neutralize wildcard characters in the caller-supplied value; only the
	// leading '%' added below is meant to act as a wildcard.
	literalSuffix := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(shortID)

	var req model.RequestLog
	var headersJSON, bodyJSON string
	var promptGradeJSON, responseJSON sql.NullString

	err := s.db.QueryRow(query, "%"+literalSuffix).Scan(
		&req.RequestID,
		&req.Timestamp,
		&req.Method,
		&req.Endpoint,
		&headersJSON,
		&bodyJSON,
		&req.Model,
		&req.UserAgent,
		&req.ContentType,
		&promptGradeJSON,
		&responseJSON,
	)
	if err == sql.ErrNoRows {
		return nil, "", fmt.Errorf("request with ID %s not found", shortID)
	}
	if err != nil {
		return nil, "", fmt.Errorf("failed to query request: %w", err)
	}

	// Unmarshal JSON fields
	if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
		return nil, "", fmt.Errorf("failed to unmarshal headers: %w", err)
	}

	var body interface{}
	if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
		return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
	}
	req.Body = body

	// Optional columns: a decode failure leaves the field unset.
	if promptGradeJSON.Valid {
		var grade model.PromptGrade
		if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
			req.PromptGrade = &grade
		}
	}

	if responseJSON.Valid {
		var resp model.ResponseLog
		if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
			req.Response = &resp
		}
	}

	return &req, req.RequestID, nil
}
// GetConfig returns the storage configuration the service was constructed
// with. The stored pointer is returned as-is, not a copy.
func (s *sqliteStorageService) GetConfig() *config.StorageConfig {
return s.config
}
// GetAllRequests returns every stored request, newest first, optionally
// restricted by model name.
//
// When modelFilter is neither empty nor "all", only rows whose model column
// contains modelFilter (compared case-insensitively) are returned. A row that
// cannot be scanned, or whose headers or body JSON cannot be decoded, is
// logged and skipped. prompt_grade and response are optional and are attached
// only when the column is non-NULL and decodes cleanly.
//
// An error is returned when the query fails or when row iteration stops early
// because of a database error.
func (s *sqliteStorageService) GetAllRequests(modelFilter string) ([]*model.RequestLog, error) {
	query := `
SELECT id, timestamp, method, endpoint, headers, body, model, user_agent, content_type, prompt_grade, response
FROM requests
`
	args := []interface{}{}

	if modelFilter != "" && modelFilter != "all" {
		query += " WHERE LOWER(model) LIKE ?"
		args = append(args, "%"+strings.ToLower(modelFilter)+"%")
		log.Printf("🔍 SQL Query with filter: %s, args: %v", query, args)
	}

	query += " ORDER BY timestamp DESC"

	rows, err := s.db.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query requests: %w", err)
	}
	defer rows.Close()

	var requests []*model.RequestLog
	for rows.Next() {
		var req model.RequestLog
		var headersJSON, bodyJSON string
		var promptGradeJSON, responseJSON sql.NullString

		err := rows.Scan(
			&req.RequestID,
			&req.Timestamp,
			&req.Method,
			&req.Endpoint,
			&headersJSON,
			&bodyJSON,
			&req.Model,
			&req.UserAgent,
			&req.ContentType,
			&promptGradeJSON,
			&responseJSON,
		)
		if err != nil {
			log.Printf("Error scanning row: %v", err)
			continue
		}

		log.Printf("🔍 Scanned request - ID: %s, Model: %s", req.RequestID, req.Model)

		// Unmarshal JSON fields
		if err := json.Unmarshal([]byte(headersJSON), &req.Headers); err != nil {
			log.Printf("Error unmarshaling headers: %v", err)
			continue
		}

		var body interface{}
		if err := json.Unmarshal([]byte(bodyJSON), &body); err != nil {
			log.Printf("Error unmarshaling body: %v", err)
			continue
		}
		req.Body = body

		// Optional columns: a decode failure leaves the field unset.
		if promptGradeJSON.Valid {
			var grade model.PromptGrade
			if err := json.Unmarshal([]byte(promptGradeJSON.String), &grade); err == nil {
				req.PromptGrade = &grade
			}
		}

		if responseJSON.Valid {
			var resp model.ResponseLog
			if err := json.Unmarshal([]byte(responseJSON.String), &resp); err == nil {
				req.Response = &resp
			}
		}

		requests = append(requests, &req)
	}

	// rows.Next returns false both at the end of the result set and on error;
	// without this check a mid-iteration failure would look like a complete
	// but shorter list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate requests: %w", err)
	}

	return requests, nil
}
// Close closes the underlying database handle and returns whatever error
// that close reports.
func (s *sqliteStorageService) Close() error {
return s.db.Close()
}