claude-code-proxy/proxy/internal/handler/handlers_proxy_test.go

303 lines
10 KiB
Go
Raw Normal View History

Local fork: hardening + ops improvements (timeout knob, demotion, /livez, drain) This commit captures both the prior accumulated work-in-progress (framework migration web/→svelte/, postgres storage, conversation viewer, dashboard auth, OpenAPI spec, integration tests) AND today's operational improvements layered on top. History wasn't checkpointed incrementally; happy to split it via interactive rebase if a reviewer wants smaller commits. Today's changes (in addition to the older WIP): 1. Configurable upstream response-header timeout - ANTHROPIC_RESPONSE_HEADER_TIMEOUT env (default 300s) - Replaces hardcoded 300s in provider/anthropic.go that was firing on opus + 1M-context + extended thinking non-streaming requests - Files: internal/config/config.go, internal/provider/anthropic.go 2. Structured forward-error diagnostic logging - When a forward to Anthropic fails, log a single key=value line with request_id, model, stream, body_bytes, has_thinking, anthropic_beta, query, elapsed, ctx_err — alongside the existing human-readable error line for back-compat - Files: internal/handler/handlers.go (logForwardFailure) 3. Full SSE protocol passthrough + Flusher fix - handler/handlers.go: forward all SSE lines verbatim (event:, id:, retry:, : comments, blank-line terminators), not only data:. Previous code produced malformed SSE for strict parsers. - middleware/logging.go: explicit Flush() method on responseWriter. Embedding http.ResponseWriter (interface) does not auto-promote Flush(), so every w.(http.Flusher) check in the streaming handler was returning ok=false and SSE writes buffered in net/http until the body closed. 4. Non-streaming → streaming demotion (feature-flagged) - ANTHROPIC_DEMOTE_NONSTREAMING env (default false) - When enabled and the routed provider is anthropic, force stream=true upstream for clients that asked for stream=false. 
Receive SSE, accumulate via accumulateSSEToMessage (handles text, tool_use with partial_json reassembly, thinking, signature, citations_delta, usage merge), and synthesize a single non-streaming JSON response. - Eliminates the ResponseHeaderTimeout class of failure entirely. - Body rewrite uses json.Decoder + UseNumber() to preserve integer precision in unknown nested fields (tool inputs from prior turns). - Files: internal/config/config.go, internal/handler/handlers.go, cmd/proxy/main.go, cmd/proxy/main_test.go 5. Live operational state: /livez gauge + graceful drain - New internal/runtime package: atomic in-flight counter + draining flag - New middleware/inflight.go: increments runtime gauge, applied to /v1/* subrouter so Messages, ChatCompletions, and ProxyPassthrough are all counted - /v1/* moved to a gorilla/mux subrouter so the InFlight middleware applies surgically; /health, /livez, /openapi.* remain on parent router (unauthenticated, uncounted) - Health handler returns 503 draining when runtime.IsDraining() is true, so Traefik stops routing to a slot before drain begins - New /livez handler returns {status, in_flight, draining, timestamp} - SIGTERM handler in main.go: SetDraining(true), poll for in_flight==0 with 32-min ceiling and 1s tick (logs every 10s), then srv.Shutdown - Auth bypass list extended with /livez - Files: internal/runtime/runtime.go (new), internal/middleware/inflight.go (new), internal/middleware/auth.go, internal/handler/handlers.go (Health, Livez, runtime import), cmd/proxy/main.go (subrouter, drain loop) 6. OpenAPI spec updates - Document Health 503 response and new DrainingResponse schema - Add /livez path with LivezResponse schema - Files: internal/handler/openapi.go Verified: go build ./... clean, go test ./... all pass, go vet clean. Three rounds of codex peer review across changes 1-5; all feedback addressed (citations_delta, json.Number precision, drain-loop logging via lastLog timestamp, PathPrefix tightened to "/v1/").
2026-05-02 15:15:58 -06:00
package handler
import (
"bytes"
"context"
"io"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/seifghazi/claude-code-monitor/internal/config"
"github.com/seifghazi/claude-code-monitor/internal/model"
"github.com/seifghazi/claude-code-monitor/internal/provider"
)
// passthroughStorageStub is a hand-written storage double for the handler
// tests in this file. Three operations can be customised through hook
// functions; SaveRequest and UpdateRequestWithResponse also count how often
// they were called.
type passthroughStorageStub struct {
	// Optional hooks. A nil hook makes the matching method return a canned
	// success value instead.
	saveRequestFn   func(*model.RequestLog) (string, error)
	updateRequestFn func(*model.RequestLog) error
	getSettingsFn   func() (*model.ProxySettings, error)

	// Call counters incremented by SaveRequest and UpdateRequestWithResponse.
	savedRequests, updatedRequests int
}
// SaveRequest bumps savedRequests and then delegates to saveRequestFn when a
// hook is installed. Without a hook it returns the fixed ID "req-1" and a nil
// error.
func (s *passthroughStorageStub) SaveRequest(request *model.RequestLog) (string, error) {
	s.savedRequests++
	if s.saveRequestFn == nil {
		return "req-1", nil
	}
	return s.saveRequestFn(request)
}
// GetRequests is not expected to be reached by the tests using this stub; it
// panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetRequests(int, int, string) ([]model.RequestLog, int, error) {
	panic("unexpected call")
}
// GetAllRequests is not expected to be reached by the tests using this stub;
// it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetAllRequests(string) ([]*model.RequestLog, error) {
	panic("unexpected call")
}
// GetRequestByShortID is not expected to be reached by the tests using this
// stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetRequestByShortID(string) (*model.RequestLog, string, error) {
	panic("unexpected call")
}
// ClearRequests is not expected to be reached by the tests using this stub;
// it panics so that any such call fails loudly.
func (s *passthroughStorageStub) ClearRequests() (int, error) {
	panic("unexpected call")
}
// UpdateRequestWithGrading is not expected to be reached by the tests using
// this stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) UpdateRequestWithGrading(string, *model.PromptGrade) error {
	panic("unexpected call")
}
// UpdateRequestWithResponse bumps updatedRequests and then delegates to
// updateRequestFn when a hook is installed. Without a hook it returns nil.
func (s *passthroughStorageStub) UpdateRequestWithResponse(request *model.RequestLog) error {
	s.updatedRequests++
	if s.updateRequestFn == nil {
		return nil
	}
	return s.updateRequestFn(request)
}
// DeleteRequestsOlderThan is not expected to be reached by the tests using
// this stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) DeleteRequestsOlderThan(time.Duration) (int, error) {
	panic("unexpected call")
}
// GetDatabaseStats is not expected to be reached by the tests using this
// stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetDatabaseStats() (map[string]interface{}, error) {
	panic("unexpected call")
}
// GetUsageStats is not expected to be reached by the tests using this stub;
// it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetUsageStats(string, string, string, string) (*model.UsageStats, error) {
	panic("unexpected call")
}
// GetRequestsSummary is not expected to be reached by the tests using this
// stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetRequestsSummary(string) ([]*model.RequestSummary, error) {
	panic("unexpected call")
}
// GetRequestsSummaryPaginated is not expected to be reached by the tests
// using this stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetRequestsSummaryPaginated(string, string, string, int, int) ([]*model.RequestSummary, int, error) {
	panic("unexpected call")
}
// GetStats is not expected to be reached by the tests using this stub; it
// panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetStats(string, string, string) (*model.DashboardStats, error) {
	panic("unexpected call")
}
// GetHourlyStats is not expected to be reached by the tests using this stub;
// it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetHourlyStats(string, string, int, string) (*model.HourlyStatsResponse, error) {
	panic("unexpected call")
}
// GetModelStats is not expected to be reached by the tests using this stub;
// it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetModelStats(string, string, string) (*model.ModelStatsResponse, error) {
	panic("unexpected call")
}
// GetLatestRequestDate is not expected to be reached by the tests using this
// stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetLatestRequestDate() (*time.Time, error) {
	panic("unexpected call")
}
// GetDistinctOrganizations is not expected to be reached by the tests using
// this stub; it panics so that any such call fails loudly.
func (s *passthroughStorageStub) GetDistinctOrganizations() ([]string, error) {
	panic("unexpected call")
}
// GetSettings delegates to getSettingsFn when a hook is installed. Without a
// hook it returns a freshly allocated, zero-valued ProxySettings and a nil
// error.
func (s *passthroughStorageStub) GetSettings() (*model.ProxySettings, error) {
	if s.getSettingsFn == nil {
		return &model.ProxySettings{}, nil
	}
	return s.getSettingsFn()
}
// SaveSettings is not expected to be reached by the tests using this stub; it
// panics so that any such call fails loudly.
func (s *passthroughStorageStub) SaveSettings(*model.ProxySettings) error {
	panic("unexpected call")
}
// GetConfig returns a freshly allocated, zero-valued StorageConfig on every
// call.
func (s *passthroughStorageStub) GetConfig() *config.StorageConfig {
	return &config.StorageConfig{}
}
// EnsureDirectoryExists is a no-op in this stub and always reports success.
func (s *passthroughStorageStub) EnsureDirectoryExists() error {
	return nil
}
// Close is a no-op in this stub and always reports success.
func (s *passthroughStorageStub) Close() error {
	return nil
}
// passthroughProviderStub is a provider double whose ForwardRequest behaviour
// is supplied by the test through forwardRequestFn.
type passthroughProviderStub struct {
	// forwardRequestFn handles ForwardRequest calls; leaving it nil makes
	// ForwardRequest panic.
	forwardRequestFn func(context.Context, *http.Request) (*http.Response, error)
}
// Name reports the fixed provider name "stub".
func (p *passthroughProviderStub) Name() string {
	return "stub"
}
// ForwardRequest hands the call to forwardRequestFn and returns its result.
// It panics when no hook has been installed, so an unplanned forward fails
// loudly.
func (p *passthroughProviderStub) ForwardRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
	if p.forwardRequestFn == nil {
		panic("unexpected call")
	}
	return p.forwardRequestFn(ctx, req)
}
// Compile-time check that *passthroughProviderStub satisfies provider.Provider.
var _ provider.Provider = (*passthroughProviderStub)(nil)
// TestChatCompletionsReturnsHelpfulError calls ChatCompletions on a
// zero-valued Handler and expects a 400 whose decoded ErrorResponse carries a
// non-empty Error field.
func TestChatCompletionsReturnsHelpfulError(t *testing.T) {
	rec := httptest.NewRecorder()
	(&Handler{}).ChatCompletions(rec, httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))

	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", status)
	}

	var errResp model.ErrorResponse
	decodeJSONBody(t, rec, &errResp)
	if errResp.Error == "" {
		t.Fatal("expected non-empty error message")
	}
}
// TestModelsReturnsEmptyList calls Models on a zero-valued Handler and
// expects a 200 whose decoded ModelsResponse has Object "list" and no Data
// entries.
func TestModelsReturnsEmptyList(t *testing.T) {
	rec := httptest.NewRecorder()
	(&Handler{}).Models(rec, httptest.NewRequest(http.MethodGet, "/v1/models", nil))

	if status := rec.Code; status != http.StatusOK {
		t.Fatalf("expected 200, got %d", status)
	}

	var list model.ModelsResponse
	decodeJSONBody(t, rec, &list)
	isEmptyList := list.Object == "list" && len(list.Data) == 0
	if !isEmptyList {
		t.Fatalf("unexpected models response: %#v", list)
	}
}
// TestHealthReturnsHealthyStatus calls Health on a zero-valued Handler and
// expects a 200 whose decoded HealthResponse has Status "healthy" and a
// non-zero Timestamp.
func TestHealthReturnsHealthyStatus(t *testing.T) {
	rec := httptest.NewRecorder()
	(&Handler{}).Health(rec, httptest.NewRequest(http.MethodGet, "/health", nil))

	if status := rec.Code; status != http.StatusOK {
		t.Fatalf("expected 200, got %d", status)
	}

	var health model.HealthResponse
	decodeJSONBody(t, rec, &health)
	// Fatalf stops the test, so only the first failing case is reported.
	switch {
	case health.Status != "healthy":
		t.Fatalf("unexpected health response: %#v", health)
	case health.Timestamp.IsZero():
		t.Fatalf("expected non-zero timestamp: %#v", health)
	}
}
// TestOpenAPIEndpointsExposeDiscoveryFormats exercises the OpenAPIJSON and
// OpenAPIYAML handlers on a zero-valued Handler.
//
// The "json" subtest expects a 200, an "application/json" content type, a
// wildcard Access-Control-Allow-Origin header, and a decoded document with
// non-nil "openapi" and "paths" keys. The "yaml" subtest expects a 200, an
// "application/x-yaml" content type, and a body containing "openapi:".
func TestOpenAPIEndpointsExposeDiscoveryFormats(t *testing.T) {
	t.Run("json", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/openapi.json", nil)
		rr := httptest.NewRecorder()
		(&Handler{}).OpenAPIJSON(rr, req)
		if rr.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d", rr.Code)
		}
		if ct := rr.Header().Get("Content-Type"); ct != "application/json" {
			t.Fatalf("expected json content type, got %q", ct)
		}
		if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
			t.Fatalf("expected CORS header, got %#v", rr.Header())
		}
		var payload map[string]interface{}
		decodeJSONBody(t, rr, &payload)
		if payload["openapi"] == nil || payload["paths"] == nil {
			t.Fatalf("unexpected openapi json payload: %#v", payload)
		}
	})
	t.Run("yaml", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/openapi.yaml", nil)
		rr := httptest.NewRecorder()
		(&Handler{}).OpenAPIYAML(rr, req)
		if rr.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d", rr.Code)
		}
		if ct := rr.Header().Get("Content-Type"); ct != "application/x-yaml" {
			t.Fatalf("expected yaml content type, got %q", ct)
		}
		// Search the recorder's buffer directly instead of converting it to a
		// string and back to bytes; the string form is only built on failure.
		if !bytes.Contains(rr.Body.Bytes(), []byte("openapi:")) {
			t.Fatalf("expected yaml body to contain openapi header, got %q", rr.Body.String())
		}
	})
}
// TestProxyPassthroughForwardsResponseAndPersistsMetadata drives
// Handler.ProxyPassthrough with a stubbed storage service and provider and
// checks three things:
//
//   - the forwarded request carries the configured request header rule and
//     the original JSON body;
//   - the client receives the upstream status, body, and headers, with the
//     response header rule applied and the Connection header removed;
//   - storage sees exactly one SaveRequest (parsed body, non-raw
//     Authorization value) and one UpdateRequestWithResponse (organization
//     ID, status, body, no Connection header).
//
// NOTE(review): the stub callbacks call t.Fatalf, which is only valid on the
// goroutine running the test. This presumes ProxyPassthrough invokes the
// provider and storage synchronously — confirm against the handler.
func TestProxyPassthroughForwardsResponseAndPersistsMetadata(t *testing.T) {
	storage := &passthroughStorageStub{}
	upstreamResponse := `{"ok":true}`
	// Named upstream rather than provider so the imported provider package is
	// not shadowed inside this function.
	upstream := &passthroughProviderStub{
		forwardRequestFn: func(ctx context.Context, req *http.Request) (*http.Response, error) {
			if req.Header.Get("X-Added") != "from-test" {
				t.Fatalf("expected request header rule to be applied, got headers %#v", req.Header)
			}
			bodyBytes, err := io.ReadAll(req.Body)
			if err != nil {
				t.Fatalf("failed reading forwarded request body: %v", err)
			}
			if string(bodyBytes) != `{"input":"hello"}` {
				t.Fatalf("unexpected forwarded body %q", string(bodyBytes))
			}
			resp := &http.Response{
				StatusCode: http.StatusAccepted,
				Header: http.Header{
					"Content-Type":              []string{"application/json"},
					"X-Upstream":                []string{"yes"},
					"Connection":                []string{"keep-alive"},
					"Anthropic-Organization-Id": []string{"org-123"},
				},
				Body: io.NopCloser(bytes.NewBufferString(upstreamResponse)),
			}
			return resp, nil
		},
	}
	// Validate what the handler records before the request is forwarded.
	storage.saveRequestFn = func(request *model.RequestLog) (string, error) {
		if request.Method != http.MethodPost || request.Endpoint != "/v1/quota" {
			t.Fatalf("unexpected saved request metadata: %#v", request)
		}
		if request.ContentType != "application/json" {
			t.Fatalf("unexpected content type %q", request.ContentType)
		}
		bodyMap, ok := request.Body.(map[string]interface{})
		if !ok || bodyMap["input"] != "hello" {
			t.Fatalf("expected parsed request body, got %#v", request.Body)
		}
		// The stored Authorization value must be present but must not be the
		// raw credential sent by the client.
		if authValues := request.Headers["Authorization"]; len(authValues) != 1 || authValues[0] == "Bearer secret" || authValues[0] == "" {
			t.Fatalf("expected sanitized authorization header, got %#v", request.Headers)
		}
		return "req-1", nil
	}
	// Validate what the handler records once the upstream response is known.
	storage.updateRequestFn = func(request *model.RequestLog) error {
		if request.OrganizationID != "org-123" {
			t.Fatalf("expected organization id extracted, got %#v", request)
		}
		if request.Response == nil || request.Response.StatusCode != http.StatusAccepted {
			t.Fatalf("expected response metadata recorded, got %#v", request.Response)
		}
		if string(request.Response.Body) != upstreamResponse {
			t.Fatalf("expected json response body stored, got %#v", request.Response)
		}
		if connectionValues := request.Response.Headers["Connection"]; len(connectionValues) != 0 {
			t.Fatalf("expected sanitized stored headers to exclude hop-by-hop data, got %#v", request.Response.Headers)
		}
		return nil
	}
	h := &Handler{
		storageService:    storage,
		anthropicProvider: upstream,
		logger:            log.New(io.Discard, "", 0),
		cachedSettings: &model.ProxySettings{
			RequestHeaderRules:  []model.HeaderRule{{Header: "X-Added", Action: "set", Value: "from-test", Enabled: true}},
			ResponseHeaderRules: []model.HeaderRule{{Header: "X-Proxy", Action: "set", Value: "applied", Enabled: true}},
		},
	}
	req := httptest.NewRequest(http.MethodPost, "/v1/quota", bytes.NewBufferString(`{"input":"hello"}`))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer secret")
	rr := httptest.NewRecorder()
	h.ProxyPassthrough(rr, req)
	if rr.Code != http.StatusAccepted {
		t.Fatalf("expected 202, got %d", rr.Code)
	}
	if rr.Body.String() != upstreamResponse {
		t.Fatalf("unexpected client body %q", rr.Body.String())
	}
	if rr.Header().Get("X-Upstream") != "yes" {
		t.Fatalf("expected upstream header to be forwarded, got %#v", rr.Header())
	}
	if rr.Header().Get("X-Proxy") != "applied" {
		t.Fatalf("expected response header rule to be applied, got %#v", rr.Header())
	}
	if rr.Header().Get("Connection") != "" {
		t.Fatalf("expected hop-by-hop header to be stripped, got %#v", rr.Header())
	}
	if storage.savedRequests != 1 || storage.updatedRequests != 1 {
		t.Fatalf("expected request save/update pair, got saves=%d updates=%d", storage.savedRequests, storage.updatedRequests)
	}
}