package middleware

import (
	"crypto/subtle"
	"encoding/json"
	"net"
	"net/http"
	"strings"

	"github.com/seifghazi/claude-code-monitor/internal/config"
)

// writeJSON sets a JSON Content-Type, writes status, and encodes v as the
// response body.
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line and headers are already sent, so an encode failure
	// cannot be reported to the client; ignoring it is deliberate.
	_ = json.NewEncoder(w).Encode(v)
}

// Auth returns middleware that enforces token authentication according to cfg.
//
// A request is passed to next without a token when any of these hold:
//   - its method is OPTIONS,
//   - its path is one of the public endpoints accepted by isPublicBypassPath,
//   - cfg.Enabled is false,
//   - cfg.AllowLocalhostBypass is set and the peer address is loopback.
//
// Otherwise the request must present cfg.Token in "Authorization: Bearer
// <token>", in the header named by cfg.APIKeyHeader, or in X-API-Key. A request
// is accepted if any of those credentials matches. Requests without a matching
// credential get 401 with a WWW-Authenticate challenge and the JSON body
// {"error":"unauthorized"}.
func Auth(cfg config.AuthConfig) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// NOTE(review): every OPTIONS request bypasses auth, not only CORS
			// preflights; confirm downstream handlers do nothing sensitive on OPTIONS.
			if r.Method == http.MethodOptions || isPublicBypassPath(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}
			if !cfg.Enabled {
				next.ServeHTTP(w, r)
				return
			}
			// RemoteAddr is the immediate TCP peer. If a reverse proxy on the same
			// host fronts this server, every request will look like localhost.
			if cfg.AllowLocalhostBypass && isLocalhostRequest(r.RemoteAddr) {
				next.ServeHTTP(w, r)
				return
			}
			if hasValidToken(r, cfg) {
				next.ServeHTTP(w, r)
				return
			}

			w.Header().Set("WWW-Authenticate", `Bearer realm="claude-code-proxy"`)
			writeJSON(w, http.StatusUnauthorized, map[string]string{
				"error": "unauthorized",
			})
		})
	}
}

// isPublicBypassPath reports whether path is an endpoint that is always served
// without authentication. Matching is exact, so e.g. "/health/" is not public.
func isPublicBypassPath(path string) bool {
	switch path {
	case "/health", "/livez", "/openapi.json", "/openapi.yaml":
		return true
	default:
		return false
	}
}

// hasValidToken reports whether any credential presented on r equals
// cfg.Token.
//
// Every candidate header is checked, not only the first one present. This lets
// a client carry upstream credentials in Authorization while authenticating to
// this proxy through the API-key header, or vice versa. The comparison is
// constant-time so response timing does not leak the token's contents. An empty
// configured token never matches, so a misconfigured server fails closed.
func hasValidToken(r *http.Request, cfg config.AuthConfig) bool {
	if cfg.Token == "" {
		return false
	}
	want := []byte(cfg.Token)
	for _, candidate := range authTokenCandidates(r, cfg) {
		if subtle.ConstantTimeCompare([]byte(candidate), want) == 1 {
			return true
		}
	}
	return false
}

// extractAuthToken returns the highest-priority credential presented on r, in
// this order:
//  1. the Bearer token from Authorization,
//  2. the header named by cfg.APIKeyHeader,
//  3. X-API-Key.
//
// The second result is false when no credential is present.
func extractAuthToken(r *http.Request, cfg config.AuthConfig) (string, bool) {
	candidates := authTokenCandidates(r, cfg)
	if len(candidates) == 0 {
		return "", false
	}
	return candidates[0], true
}

// authTokenCandidates returns every non-empty, whitespace-trimmed credential
// presented on r, in extractAuthToken's priority order.
func authTokenCandidates(r *http.Request, cfg config.AuthConfig) []string {
	const xAPIKey = "X-API-Key"

	candidates := make([]string, 0, 3)
	if token, ok := bearerToken(r.Header.Get("Authorization")); ok {
		candidates = append(candidates, token)
	}
	if cfg.APIKeyHeader != "" {
		if v := strings.TrimSpace(r.Header.Get(cfg.APIKeyHeader)); v != "" {
			candidates = append(candidates, v)
		}
	}
	// Accept the common X-API-Key header even if callers customize the config.
	// Header names are case-insensitive, so skip it when it is the configured
	// header under any spelling (it was already read above).
	if !strings.EqualFold(cfg.APIKeyHeader, xAPIKey) {
		if v := strings.TrimSpace(r.Header.Get(xAPIKey)); v != "" {
			candidates = append(candidates, v)
		}
	}
	return candidates
}

// bearerToken extracts the credential from an Authorization header value of
// the form "Bearer <token>". The scheme name is matched case-insensitively,
// and surrounding whitespace is trimmed. When ok is true, the token is
// non-empty.
func bearerToken(header string) (token string, ok bool) {
	const prefix = "Bearer "
	header = strings.TrimSpace(header)
	if len(header) <= len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", false
	}
	return strings.TrimSpace(header[len(prefix):]), true
}

// isLocalhostRequest reports whether remoteAddr refers to a loopback host:
// a loopback IPv4/IPv6 address (with or without port and brackets), or the
// literal name "localhost".
func isLocalhostRequest(remoteAddr string) bool {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// No port present; treat the whole value as the host.
		host = remoteAddr
	}
	host = strings.TrimSpace(strings.Trim(host, "[]"))
	if host == "localhost" {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}