303 lines
10 KiB
Go
303 lines
10 KiB
Go
|
|
package handler
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"io"
|
||
|
|
"log"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/seifghazi/claude-code-monitor/internal/config"
|
||
|
|
"github.com/seifghazi/claude-code-monitor/internal/model"
|
||
|
|
"github.com/seifghazi/claude-code-monitor/internal/provider"
|
||
|
|
)
|
||
|
|
|
||
|
|
type passthroughStorageStub struct {
|
||
|
|
saveRequestFn func(*model.RequestLog) (string, error)
|
||
|
|
updateRequestFn func(*model.RequestLog) error
|
||
|
|
getSettingsFn func() (*model.ProxySettings, error)
|
||
|
|
savedRequests int
|
||
|
|
updatedRequests int
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *passthroughStorageStub) SaveRequest(request *model.RequestLog) (string, error) {
|
||
|
|
s.savedRequests++
|
||
|
|
if s.saveRequestFn != nil {
|
||
|
|
return s.saveRequestFn(request)
|
||
|
|
}
|
||
|
|
return "req-1", nil
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetRequests(int, int, string) ([]model.RequestLog, int, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetAllRequests(string) ([]*model.RequestLog, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetRequestByShortID(string) (*model.RequestLog, string, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) ClearRequests() (int, error) { panic("unexpected call") }
|
||
|
|
func (s *passthroughStorageStub) UpdateRequestWithGrading(string, *model.PromptGrade) error {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) UpdateRequestWithResponse(request *model.RequestLog) error {
|
||
|
|
s.updatedRequests++
|
||
|
|
if s.updateRequestFn != nil {
|
||
|
|
return s.updateRequestFn(request)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) DeleteRequestsOlderThan(time.Duration) (int, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetDatabaseStats() (map[string]interface{}, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetUsageStats(string, string, string, string) (*model.UsageStats, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetRequestsSummary(string) ([]*model.RequestSummary, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetRequestsSummaryPaginated(string, string, string, int, int) ([]*model.RequestSummary, int, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetStats(string, string, string) (*model.DashboardStats, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetHourlyStats(string, string, int, string) (*model.HourlyStatsResponse, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetModelStats(string, string, string) (*model.ModelStatsResponse, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetLatestRequestDate() (*time.Time, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetDistinctOrganizations() ([]string, error) {
|
||
|
|
panic("unexpected call")
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) GetSettings() (*model.ProxySettings, error) {
|
||
|
|
if s.getSettingsFn != nil {
|
||
|
|
return s.getSettingsFn()
|
||
|
|
}
|
||
|
|
return &model.ProxySettings{}, nil
|
||
|
|
}
|
||
|
|
func (s *passthroughStorageStub) SaveSettings(*model.ProxySettings) error { panic("unexpected call") }
|
||
|
|
func (s *passthroughStorageStub) GetConfig() *config.StorageConfig { return &config.StorageConfig{} }
|
||
|
|
func (s *passthroughStorageStub) EnsureDirectoryExists() error { return nil }
|
||
|
|
func (s *passthroughStorageStub) Close() error { return nil }
|
||
|
|
|
||
|
|
// passthroughProviderStub is a minimal provider fake whose forwarding
// behavior is supplied per test through forwardRequestFn.
type passthroughProviderStub struct {
	// forwardRequestFn handles ForwardRequest; leaving it nil makes
	// ForwardRequest panic.
	forwardRequestFn func(context.Context, *http.Request) (*http.Response, error)
}

// Name reports the fixed provider name "stub".
func (p *passthroughProviderStub) Name() string {
	return "stub"
}

// ForwardRequest passes ctx and req straight to forwardRequestFn and returns
// its result unchanged. It panics with "unexpected call" when no function has
// been configured.
func (p *passthroughProviderStub) ForwardRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
	fn := p.forwardRequestFn
	if fn == nil {
		panic("unexpected call")
	}
	return fn(ctx, req)
}
|
||
|
|
|
||
|
|
// Compile-time assertion that *passthroughProviderStub implements
// provider.Provider; the build breaks if the interface drifts.
var (
	_ provider.Provider = (*passthroughProviderStub)(nil)
)
|
||
|
|
|
||
|
|
func TestChatCompletionsReturnsHelpfulError(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
|
||
|
|
(&Handler{}).ChatCompletions(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusBadRequest {
|
||
|
|
t.Fatalf("expected 400, got %d", rr.Code)
|
||
|
|
}
|
||
|
|
|
||
|
|
var response model.ErrorResponse
|
||
|
|
decodeJSONBody(t, rr, &response)
|
||
|
|
if response.Error == "" {
|
||
|
|
t.Fatal("expected non-empty error message")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestModelsReturnsEmptyList(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
|
||
|
|
(&Handler{}).Models(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||
|
|
}
|
||
|
|
|
||
|
|
var response model.ModelsResponse
|
||
|
|
decodeJSONBody(t, rr, &response)
|
||
|
|
if response.Object != "list" || len(response.Data) != 0 {
|
||
|
|
t.Fatalf("unexpected models response: %#v", response)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHealthReturnsHealthyStatus(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
|
||
|
|
(&Handler{}).Health(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||
|
|
}
|
||
|
|
|
||
|
|
var response model.HealthResponse
|
||
|
|
decodeJSONBody(t, rr, &response)
|
||
|
|
if response.Status != "healthy" {
|
||
|
|
t.Fatalf("unexpected health response: %#v", response)
|
||
|
|
}
|
||
|
|
if response.Timestamp.IsZero() {
|
||
|
|
t.Fatalf("expected non-zero timestamp: %#v", response)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestOpenAPIEndpointsExposeDiscoveryFormats(t *testing.T) {
|
||
|
|
t.Run("json", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/openapi.json", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
|
||
|
|
(&Handler{}).OpenAPIJSON(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||
|
|
}
|
||
|
|
if ct := rr.Header().Get("Content-Type"); ct != "application/json" {
|
||
|
|
t.Fatalf("expected json content type, got %q", ct)
|
||
|
|
}
|
||
|
|
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||
|
|
t.Fatalf("expected CORS header, got %#v", rr.Header())
|
||
|
|
}
|
||
|
|
|
||
|
|
var payload map[string]interface{}
|
||
|
|
decodeJSONBody(t, rr, &payload)
|
||
|
|
if payload["openapi"] == nil || payload["paths"] == nil {
|
||
|
|
t.Fatalf("unexpected openapi json payload: %#v", payload)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("yaml", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/openapi.yaml", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
|
||
|
|
(&Handler{}).OpenAPIYAML(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||
|
|
}
|
||
|
|
if ct := rr.Header().Get("Content-Type"); ct != "application/x-yaml" {
|
||
|
|
t.Fatalf("expected yaml content type, got %q", ct)
|
||
|
|
}
|
||
|
|
if body := rr.Body.String(); !bytes.Contains([]byte(body), []byte("openapi:")) {
|
||
|
|
t.Fatalf("expected yaml body to contain openapi header, got %q", body)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestProxyPassthroughForwardsResponseAndPersistsMetadata(t *testing.T) {
|
||
|
|
storage := &passthroughStorageStub{}
|
||
|
|
upstreamResponse := `{"ok":true}`
|
||
|
|
provider := &passthroughProviderStub{
|
||
|
|
forwardRequestFn: func(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||
|
|
if req.Header.Get("X-Added") != "from-test" {
|
||
|
|
t.Fatalf("expected request header rule to be applied, got headers %#v", req.Header)
|
||
|
|
}
|
||
|
|
|
||
|
|
bodyBytes, err := io.ReadAll(req.Body)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("failed reading forwarded request body: %v", err)
|
||
|
|
}
|
||
|
|
if string(bodyBytes) != `{"input":"hello"}` {
|
||
|
|
t.Fatalf("unexpected forwarded body %q", string(bodyBytes))
|
||
|
|
}
|
||
|
|
|
||
|
|
resp := &http.Response{
|
||
|
|
StatusCode: http.StatusAccepted,
|
||
|
|
Header: http.Header{
|
||
|
|
"Content-Type": []string{"application/json"},
|
||
|
|
"X-Upstream": []string{"yes"},
|
||
|
|
"Connection": []string{"keep-alive"},
|
||
|
|
"Anthropic-Organization-Id": []string{"org-123"},
|
||
|
|
},
|
||
|
|
Body: io.NopCloser(bytes.NewBufferString(upstreamResponse)),
|
||
|
|
}
|
||
|
|
return resp, nil
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
storage.saveRequestFn = func(request *model.RequestLog) (string, error) {
|
||
|
|
if request.Method != http.MethodPost || request.Endpoint != "/v1/quota" {
|
||
|
|
t.Fatalf("unexpected saved request metadata: %#v", request)
|
||
|
|
}
|
||
|
|
if request.ContentType != "application/json" {
|
||
|
|
t.Fatalf("unexpected content type %q", request.ContentType)
|
||
|
|
}
|
||
|
|
bodyMap, ok := request.Body.(map[string]interface{})
|
||
|
|
if !ok || bodyMap["input"] != "hello" {
|
||
|
|
t.Fatalf("expected parsed request body, got %#v", request.Body)
|
||
|
|
}
|
||
|
|
if authValues := request.Headers["Authorization"]; len(authValues) != 1 || authValues[0] == "Bearer secret" || authValues[0] == "" {
|
||
|
|
t.Fatalf("expected sanitized authorization header, got %#v", request.Headers)
|
||
|
|
}
|
||
|
|
return "req-1", nil
|
||
|
|
}
|
||
|
|
storage.updateRequestFn = func(request *model.RequestLog) error {
|
||
|
|
if request.OrganizationID != "org-123" {
|
||
|
|
t.Fatalf("expected organization id extracted, got %#v", request)
|
||
|
|
}
|
||
|
|
if request.Response == nil || request.Response.StatusCode != http.StatusAccepted {
|
||
|
|
t.Fatalf("expected response metadata recorded, got %#v", request.Response)
|
||
|
|
}
|
||
|
|
if string(request.Response.Body) != upstreamResponse {
|
||
|
|
t.Fatalf("expected json response body stored, got %#v", request.Response)
|
||
|
|
}
|
||
|
|
if connectionValues := request.Response.Headers["Connection"]; len(connectionValues) != 0 {
|
||
|
|
t.Fatalf("expected sanitized stored headers to exclude hop-by-hop data, got %#v", request.Response.Headers)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
h := &Handler{
|
||
|
|
storageService: storage,
|
||
|
|
anthropicProvider: provider,
|
||
|
|
logger: log.New(io.Discard, "", 0),
|
||
|
|
cachedSettings: &model.ProxySettings{
|
||
|
|
RequestHeaderRules: []model.HeaderRule{{Header: "X-Added", Action: "set", Value: "from-test", Enabled: true}},
|
||
|
|
ResponseHeaderRules: []model.HeaderRule{{Header: "X-Proxy", Action: "set", Value: "applied", Enabled: true}},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/quota", bytes.NewBufferString(`{"input":"hello"}`))
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
req.Header.Set("Authorization", "Bearer secret")
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
|
||
|
|
h.ProxyPassthrough(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusAccepted {
|
||
|
|
t.Fatalf("expected 202, got %d", rr.Code)
|
||
|
|
}
|
||
|
|
if rr.Body.String() != upstreamResponse {
|
||
|
|
t.Fatalf("unexpected client body %q", rr.Body.String())
|
||
|
|
}
|
||
|
|
if rr.Header().Get("X-Upstream") != "yes" {
|
||
|
|
t.Fatalf("expected upstream header to be forwarded, got %#v", rr.Header())
|
||
|
|
}
|
||
|
|
if rr.Header().Get("X-Proxy") != "applied" {
|
||
|
|
t.Fatalf("expected response header rule to be applied, got %#v", rr.Header())
|
||
|
|
}
|
||
|
|
if rr.Header().Get("Connection") != "" {
|
||
|
|
t.Fatalf("expected hop-by-hop header to be stripped, got %#v", rr.Header())
|
||
|
|
}
|
||
|
|
if storage.savedRequests != 1 || storage.updatedRequests != 1 {
|
||
|
|
t.Fatalf("expected request save/update pair, got saves=%d updates=%d", storage.savedRequests, storage.updatedRequests)
|
||
|
|
}
|
||
|
|
}
|