63 lines
2.1 KiB
Go
63 lines
2.1 KiB
Go
package provider
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/seifghazi/claude-code-monitor/internal/config"
|
|
)
|
|
|
|
func TestOpenAIProviderForwardRequestClearsInboundAuthorization(t *testing.T) {
|
|
var gotAuthorization string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotAuthorization = r.Header.Get("Authorization")
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"id":"resp_1","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"ok"}}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider := NewOpenAIProvider(&config.OpenAIProviderConfig{
|
|
BaseURL: server.URL,
|
|
}).(*OpenAIProvider)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://proxy.local/v1/messages", strings.NewReader(`{"model":"gpt-4o","messages":[],"max_tokens":16}`))
|
|
req.Header.Set("Authorization", "Bearer should-not-leak")
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := provider.ForwardRequest(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("ForwardRequest() error = %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
_, _ = io.ReadAll(resp.Body)
|
|
|
|
if gotAuthorization != "" {
|
|
t.Fatalf("expected forwarded Authorization header to be empty, got %q", gotAuthorization)
|
|
}
|
|
}
|
|
|
|
func TestTransformOpenAIStreamToAnthropicHandlesLargeEvents(t *testing.T) {
|
|
largeContent := strings.Repeat("x", 128*1024)
|
|
openAIStream := strings.NewReader("data: {\"id\":\"chatcmpl_1\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"" + largeContent + "\"}}]}\n\n" +
|
|
"data: [DONE]\n\n")
|
|
|
|
var output strings.Builder
|
|
if err := transformOpenAIStreamToAnthropic(openAIStream, &output); err != nil {
|
|
t.Fatalf("transformOpenAIStreamToAnthropic() error = %v", err)
|
|
}
|
|
|
|
got := output.String()
|
|
if !strings.Contains(got, "\"message_start\"") {
|
|
t.Fatal("expected message_start event in output")
|
|
}
|
|
if !strings.Contains(got, largeContent) {
|
|
t.Fatal("expected large content to be preserved in output")
|
|
}
|
|
if !strings.Contains(got, "\"message_stop\"") {
|
|
t.Fatal("expected message_stop event in output")
|
|
}
|
|
}
|