package handler import ( "bytes" "context" "io" "log" "net/http" "net/http/httptest" "testing" "time" "github.com/seifghazi/claude-code-monitor/internal/config" "github.com/seifghazi/claude-code-monitor/internal/model" "github.com/seifghazi/claude-code-monitor/internal/provider" ) type passthroughStorageStub struct { saveRequestFn func(*model.RequestLog) (string, error) updateRequestFn func(*model.RequestLog) error getSettingsFn func() (*model.ProxySettings, error) savedRequests int updatedRequests int } func (s *passthroughStorageStub) SaveRequest(request *model.RequestLog) (string, error) { s.savedRequests++ if s.saveRequestFn != nil { return s.saveRequestFn(request) } return "req-1", nil } func (s *passthroughStorageStub) GetRequests(int, int, string) ([]model.RequestLog, int, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetAllRequests(string) ([]*model.RequestLog, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetRequestByShortID(string) (*model.RequestLog, string, error) { panic("unexpected call") } func (s *passthroughStorageStub) ClearRequests() (int, error) { panic("unexpected call") } func (s *passthroughStorageStub) UpdateRequestWithGrading(string, *model.PromptGrade) error { panic("unexpected call") } func (s *passthroughStorageStub) UpdateRequestWithResponse(request *model.RequestLog) error { s.updatedRequests++ if s.updateRequestFn != nil { return s.updateRequestFn(request) } return nil } func (s *passthroughStorageStub) DeleteRequestsOlderThan(time.Duration) (int, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetDatabaseStats() (map[string]interface{}, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetUsageStats(string, string, string, string) (*model.UsageStats, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetRequestsSummary(string) ([]*model.RequestSummary, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetRequestsSummaryPaginated(string, string, string, int, int) ([]*model.RequestSummary, int, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetStats(string, string, string) (*model.DashboardStats, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetHourlyStats(string, string, int, string) (*model.HourlyStatsResponse, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetModelStats(string, string, string) (*model.ModelStatsResponse, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetLatestRequestDate() (*time.Time, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetDistinctOrganizations() ([]string, error) { panic("unexpected call") } func (s *passthroughStorageStub) GetSettings() (*model.ProxySettings, error) { if s.getSettingsFn != nil { return s.getSettingsFn() } return &model.ProxySettings{}, nil } func (s *passthroughStorageStub) SaveSettings(*model.ProxySettings) error { panic("unexpected call") } func (s *passthroughStorageStub) GetConfig() *config.StorageConfig { return &config.StorageConfig{} } func (s *passthroughStorageStub) EnsureDirectoryExists() error { return nil } func (s *passthroughStorageStub) Close() error { return nil } type passthroughProviderStub struct { forwardRequestFn func(context.Context, *http.Request) (*http.Response, error) } func (p *passthroughProviderStub) Name() string { return "stub" } func (p *passthroughProviderStub) ForwardRequest(ctx context.Context, req *http.Request) (*http.Response, error) { if p.forwardRequestFn != nil { return p.forwardRequestFn(ctx, req) } panic("unexpected call") } var _ provider.Provider = (*passthroughProviderStub)(nil) func TestChatCompletionsReturnsHelpfulError(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) rr := httptest.NewRecorder() (&Handler{}).ChatCompletions(rr, req) if rr.Code != http.StatusBadRequest { t.Fatalf("expected 400, got %d", rr.Code) } var response model.ErrorResponse decodeJSONBody(t, rr, &response) if response.Error == "" { t.Fatal("expected non-empty error message") } } func TestModelsReturnsEmptyList(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) rr := httptest.NewRecorder() (&Handler{}).Models(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected 200, got %d", rr.Code) } var response model.ModelsResponse decodeJSONBody(t, rr, &response) if response.Object != "list" || len(response.Data) != 0 { t.Fatalf("unexpected models response: %#v", response) } } func TestHealthReturnsHealthyStatus(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/health", nil) rr := httptest.NewRecorder() (&Handler{}).Health(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected 200, got %d", rr.Code) } var response model.HealthResponse decodeJSONBody(t, rr, &response) if response.Status != "healthy" { t.Fatalf("unexpected health response: %#v", response) } if response.Timestamp.IsZero() { t.Fatalf("expected non-zero timestamp: %#v", response) } } func TestOpenAPIEndpointsExposeDiscoveryFormats(t *testing.T) { t.Run("json", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/openapi.json", nil) rr := httptest.NewRecorder() (&Handler{}).OpenAPIJSON(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected 200, got %d", rr.Code) } if ct := rr.Header().Get("Content-Type"); ct != "application/json" { t.Fatalf("expected json content type, got %q", ct) } if rr.Header().Get("Access-Control-Allow-Origin") != "*" { t.Fatalf("expected CORS header, got %#v", rr.Header()) } var payload map[string]interface{} decodeJSONBody(t, rr, &payload) if payload["openapi"] == nil || payload["paths"] == nil { t.Fatalf("unexpected openapi json payload: %#v", payload) } }) t.Run("yaml", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/openapi.yaml", nil) rr := httptest.NewRecorder() (&Handler{}).OpenAPIYAML(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected 200, got %d", rr.Code) } if ct := rr.Header().Get("Content-Type"); ct != "application/x-yaml" { t.Fatalf("expected yaml content type, got %q", ct) } if body := rr.Body.String(); !bytes.Contains([]byte(body), []byte("openapi:")) { t.Fatalf("expected yaml body to contain openapi header, got %q", body) } }) } func TestProxyPassthroughForwardsResponseAndPersistsMetadata(t *testing.T) { storage := &passthroughStorageStub{} upstreamResponse := `{"ok":true}` provider := &passthroughProviderStub{ forwardRequestFn: func(ctx context.Context, req *http.Request) (*http.Response, error) { if req.Header.Get("X-Added") != "from-test" { t.Fatalf("expected request header rule to be applied, got headers %#v", req.Header) } bodyBytes, err := io.ReadAll(req.Body) if err != nil { t.Fatalf("failed reading forwarded request body: %v", err) } if string(bodyBytes) != `{"input":"hello"}` { t.Fatalf("unexpected forwarded body %q", string(bodyBytes)) } resp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ "Content-Type": []string{"application/json"}, "X-Upstream": []string{"yes"}, "Connection": []string{"keep-alive"}, "Anthropic-Organization-Id": []string{"org-123"}, }, Body: io.NopCloser(bytes.NewBufferString(upstreamResponse)), } return resp, nil }, } storage.saveRequestFn = func(request *model.RequestLog) (string, error) { if request.Method != http.MethodPost || request.Endpoint != "/v1/quota" { t.Fatalf("unexpected saved request metadata: %#v", request) } if request.ContentType != "application/json" { t.Fatalf("unexpected content type %q", request.ContentType) } bodyMap, ok := request.Body.(map[string]interface{}) if !ok || bodyMap["input"] != "hello" { t.Fatalf("expected parsed request body, got %#v", request.Body) } if authValues := request.Headers["Authorization"]; len(authValues) != 1 || authValues[0] == "Bearer secret" || authValues[0] == "" { t.Fatalf("expected sanitized authorization header, got %#v", request.Headers) } return "req-1", nil } storage.updateRequestFn = func(request *model.RequestLog) error { if request.OrganizationID != "org-123" { t.Fatalf("expected organization id extracted, got %#v", request) } if request.Response == nil || request.Response.StatusCode != http.StatusAccepted { t.Fatalf("expected response metadata recorded, got %#v", request.Response) } if string(request.Response.Body) != upstreamResponse { t.Fatalf("expected json response body stored, got %#v", request.Response) } if connectionValues := request.Response.Headers["Connection"]; len(connectionValues) != 0 { t.Fatalf("expected sanitized stored headers to exclude hop-by-hop data, got %#v", request.Response.Headers) } return nil } h := &Handler{ storageService: storage, anthropicProvider: provider, logger: log.New(io.Discard, "", 0), cachedSettings: &model.ProxySettings{ RequestHeaderRules: []model.HeaderRule{{Header: "X-Added", Action: "set", Value: "from-test", Enabled: true}}, ResponseHeaderRules: []model.HeaderRule{{Header: "X-Proxy", Action: "set", Value: "applied", Enabled: true}}, }, } req := httptest.NewRequest(http.MethodPost, "/v1/quota", bytes.NewBufferString(`{"input":"hello"}`)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer secret") rr := httptest.NewRecorder() h.ProxyPassthrough(rr, req) if rr.Code != http.StatusAccepted { t.Fatalf("expected 202, got %d", rr.Code) } if rr.Body.String() != upstreamResponse { t.Fatalf("unexpected client body %q", rr.Body.String()) } if rr.Header().Get("X-Upstream") != "yes" { t.Fatalf("expected upstream header to be forwarded, got %#v", rr.Header()) } if rr.Header().Get("X-Proxy") != "applied" { t.Fatalf("expected response header rule to be applied, got %#v", rr.Header()) } if rr.Header().Get("Connection") != "" { t.Fatalf("expected hop-by-hop header to be stripped, got %#v", rr.Header()) } if storage.savedRequests != 1 || storage.updatedRequests != 1 { t.Fatalf("expected request save/update pair, got saves=%d updates=%d", storage.savedRequests, storage.updatedRequests) } }