package middleware import ( "bytes" "context" "encoding/json" "fmt" "io" "log" "net/http" "strings" "time" "github.com/seifghazi/claude-code-monitor/internal/model" ) func Logging(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() // For POST requests with body, read and store the bytes var bodyBytes []byte if r.Body != nil && (r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH") { var err error bodyBytes, err = io.ReadAll(r.Body) if err != nil { log.Printf("❌ Error reading request body: %v", err) http.Error(w, "Error reading request body", http.StatusBadRequest) return } r.Body.Close() r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Store raw bytes in context for handler to use ctx := context.WithValue(r.Context(), model.BodyBytesKey, bodyBytes) r = r.WithContext(ctx) } wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} next.ServeHTTP(wrapped, r) duration := time.Since(start) statusColor := getStatusColor(wrapped.statusCode) // Build a richer log line for proxy requests if isProxyRequest(r.URL.Path) { details := buildProxyLogDetails(r, bodyBytes, wrapped, duration) log.Printf("%s%s%s %s", statusColor, details, colorReset, colorDim+formatDuration(duration)+colorReset) } else { log.Printf("%s %s %s%d%s (%s)", r.Method, r.URL.Path, statusColor, wrapped.statusCode, colorReset, formatDuration(duration)) } }) } // isProxyRequest returns true for /v1/* API proxy paths func isProxyRequest(path string) bool { return strings.HasPrefix(path, "/v1/") } // buildProxyLogDetails creates a rich single-line log for proxy requests func buildProxyLogDetails(r *http.Request, bodyBytes []byte, w *responseWriter, duration time.Duration) string { var parts []string // Status parts = append(parts, fmt.Sprintf("%d", w.statusCode)) // Method + path parts = append(parts, fmt.Sprintf("%s %s", r.Method, r.URL.Path)) // Extract model and stream flag from request body if len(bodyBytes) > 0 { var body struct { Model string `json:"model"` Stream bool `json:"stream"` } if err := json.Unmarshal(bodyBytes, &body); err == nil { if body.Model != "" { // Shorten model name for readability modelShort := shortenModel(body.Model) parts = append(parts, colorCyan+modelShort+colorReset) } if body.Stream { parts = append(parts, "stream") } } } // Client info — use X-Forwarded-For if behind proxy, else RemoteAddr clientIP := r.Header.Get("X-Forwarded-For") if clientIP == "" { clientIP = r.Header.Get("X-Real-Ip") } if clientIP == "" { clientIP = r.RemoteAddr } // Strip port from IP if host, _, err := splitHostPort(clientIP); err == nil { clientIP = host } // User-Agent — extract just the tool/client name ua := r.Header.Get("User-Agent") if clientName := extractClientName(ua); clientName != "" { parts = append(parts, colorDim+clientName+colorReset) } if clientIP != "" { parts = append(parts, colorDim+clientIP+colorReset) } return strings.Join(parts, " ") } // shortenModel turns "claude-sonnet-4-20250514" into "sonnet-4" func shortenModel(model string) string { lower := strings.ToLower(model) for _, family := range []string{"opus", "sonnet", "haiku"} { if strings.Contains(lower, family) { // Find version number after family name idx := strings.Index(lower, family) rest := lower[idx+len(family):] // Extract version like "-4" or "-4-20250514" rest = strings.TrimLeft(rest, "-") if dashIdx := strings.Index(rest, "-"); dashIdx > 0 { // Keep just the version number (e.g. "4" from "4-20250514") return family + "-" + rest[:dashIdx] } if rest != "" { return family + "-" + rest } return family } } // For non-Claude models, take first two segments segs := strings.SplitN(model, "-", 3) if len(segs) >= 2 { return segs[0] + "-" + segs[1] } return model } // extractClientName pulls a recognizable client name from User-Agent func extractClientName(ua string) string { if ua == "" { return "" } lower := strings.ToLower(ua) switch { case strings.Contains(lower, "claude-code"): return "claude-code" case strings.Contains(lower, "cursor"): return "cursor" case strings.Contains(lower, "continue"): return "continue" case strings.Contains(lower, "anthropic-sdk"): return "sdk" case strings.Contains(lower, "python"): return "python" case strings.Contains(lower, "node"): return "node" case strings.Contains(lower, "curl"): return "curl" default: // Take first token if spaceIdx := strings.IndexByte(ua, ' '); spaceIdx > 0 { first := ua[:spaceIdx] if len(first) > 20 { return first[:20] } return first } if len(ua) > 20 { return ua[:20] } return ua } } // splitHostPort is a simple wrapper that handles IPs without ports func splitHostPort(addr string) (string, string, error) { // If there's no colon or it's an IPv6 without port, just return the addr if !strings.Contains(addr, ":") { return addr, "", nil } // Handle comma-separated X-Forwarded-For if commaIdx := strings.IndexByte(addr, ','); commaIdx > 0 { addr = strings.TrimSpace(addr[:commaIdx]) } host, port, err := splitAddr(addr) return host, port, err } func splitAddr(addr string) (string, string, error) { if strings.Count(addr, ":") > 1 { // IPv6 — may or may not have brackets return addr, "", nil } idx := strings.LastIndexByte(addr, ':') if idx < 0 { return addr, "", nil } return addr[:idx], addr[idx+1:], nil } type responseWriter struct { http.ResponseWriter statusCode int } func (rw *responseWriter) WriteHeader(code int) { rw.statusCode = code rw.ResponseWriter.WriteHeader(code) } // Flush propagates to the underlying http.Flusher. Without this, embedding // http.ResponseWriter (an interface) silently drops Flush(), so SSE writes // buffer in net/http until the body closes — breaking token-by-token // streaming UX. func (rw *responseWriter) Flush() { if f, ok := rw.ResponseWriter.(http.Flusher); ok { f.Flush() } } // ANSI color codes const ( colorReset = "\033[0m" colorGreen = "\033[32m" colorYellow = "\033[33m" colorRed = "\033[31m" colorBlue = "\033[34m" colorCyan = "\033[36m" colorDim = "\033[2m" ) func getStatusColor(status int) string { switch { case status >= 200 && status < 300: return colorGreen case status >= 300 && status < 400: return colorBlue case status >= 400 && status < 500: return colorYellow case status >= 500: return colorRed default: return colorReset } } func formatDuration(d time.Duration) string { if d < time.Millisecond { return fmt.Sprintf("%dµs", d.Microseconds()) } else if d < time.Second { return fmt.Sprintf("%dms", d.Milliseconds()) } return fmt.Sprintf("%.2fs", d.Seconds()) }