// Package middleware provides HTTP middleware for security and request handling. package middleware import ( "context" "fmt" "log" "net" "net/http" "net/netip" "strings" "time" "github.com/jfraeys/fetch_ml/internal/auth" "golang.org/x/time/rate" ) // SecurityMiddleware provides comprehensive security features type SecurityMiddleware struct { rateLimiter *rate.Limiter authConfig *auth.Config jwtSecret []byte } // RateLimitOptions configures request throttling. type RateLimitOptions struct { RequestsPerMinute int BurstSize int } // NewSecurityMiddleware creates a new security middleware instance. func NewSecurityMiddleware( authConfig *auth.Config, jwtSecret string, rlOpts *RateLimitOptions, ) *SecurityMiddleware { sm := &SecurityMiddleware{ authConfig: authConfig, jwtSecret: []byte(jwtSecret), } // Configure rate limiter if enabled if rlOpts != nil && rlOpts.RequestsPerMinute > 0 { limit := rate.Limit(float64(rlOpts.RequestsPerMinute) / 60.0) burst := rlOpts.BurstSize if burst <= 0 { burst = rlOpts.RequestsPerMinute } sm.rateLimiter = rate.NewLimiter(limit, burst) } return sm } // RateLimit provides rate limiting middleware. func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler { if sm.rateLimiter == nil { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !sm.rateLimiter.Allow() { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // APIKeyAuth provides API key authentication middleware. func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // If authentication is not configured or disabled, allow all requests. // This keeps local/dev environments functional without requiring API keys. if sm.authConfig == nil || !sm.authConfig.Enabled { next.ServeHTTP(w, r) return } apiKey := auth.ExtractAPIKeyFromRequest(r) // Validate API key using auth config user, err := sm.authConfig.ValidateAPIKey(apiKey) if err != nil { http.Error(w, "Invalid API key", http.StatusUnauthorized) return } ctx := auth.WithUserContext(r.Context(), user) next.ServeHTTP(w, r.WithContext(ctx)) }) } // SecurityHeaders provides security headers middleware. func SecurityHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Prevent clickjacking w.Header().Set("X-Frame-Options", "DENY") // Prevent MIME type sniffing w.Header().Set("X-Content-Type-Options", "nosniff") // Enable XSS protection w.Header().Set("X-XSS-Protection", "1; mode=block") // Content Security Policy w.Header().Set("Content-Security-Policy", "default-src 'self'") // Referrer policy w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // HSTS (HTTPS only) if r.TLS != nil { w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") } next.ServeHTTP(w, r) }) } // IPWhitelist provides IP whitelist middleware. func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler { parsedAddrs := make([]netip.Addr, 0, len(allowedIPs)) parsedPrefixes := make([]netip.Prefix, 0, len(allowedIPs)) for _, raw := range allowedIPs { val := strings.TrimSpace(raw) if val == "" { continue } if strings.Contains(val, "/") { p, err := netip.ParsePrefix(val) if err != nil { log.Printf("SECURITY: invalid ip whitelist cidr ignored: %q: %v", val, err) continue } parsedPrefixes = append(parsedPrefixes, p) continue } a, err := netip.ParseAddr(val) if err != nil { log.Printf("SECURITY: invalid ip whitelist addr ignored: %q: %v", val, err) continue } parsedAddrs = append(parsedAddrs, a) } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if len(parsedAddrs) == 0 && len(parsedPrefixes) == 0 { http.Error(w, "IP not whitelisted", http.StatusForbidden) return } clientIPStr := getClientIP(r) addr, err := parseClientIP(clientIPStr) if err != nil { http.Error(w, "IP not whitelisted", http.StatusForbidden) return } allowed := false for _, a := range parsedAddrs { if a == addr { allowed = true break } } if !allowed { for _, p := range parsedPrefixes { if p.Contains(addr) { allowed = true break } } } if !allowed { http.Error(w, "IP not whitelisted", http.StatusForbidden) return } next.ServeHTTP(w, r) }) } } // CORS middleware with configured allowed origins func CORS(allowedOrigins []string) func(http.Handler) http.Handler { allowed := make([]string, 0, len(allowedOrigins)) for _, o := range allowedOrigins { v := strings.TrimSpace(o) if v != "" { allowed = append(allowed, v) } } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin != "" { isAllowed := false for _, a := range allowed { if a == "*" || origin == a { isAllowed = true break } } if isAllowed { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") } } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } } // RequestTimeout provides request timeout middleware. func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } } // RequestSizeLimit provides request size limiting middleware. func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ContentLength > maxSize { http.Error(w, "Request too large", http.StatusRequestEntityTooLarge) return } next.ServeHTTP(w, r) }) } } // AuditLogger provides security audit logging middleware. func AuditLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() path := r.URL.Path raw := r.URL.RawQuery // Wrap response writer to capture status code wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} // Process request next.ServeHTTP(wrapped, r) // Log after processing latency := time.Since(start) clientIP := getClientIP(r) method := r.Method statusCode := wrapped.statusCode if raw != "" { path = path + "?" + raw } // Log security-relevant events if statusCode >= 400 || method == "DELETE" || strings.Contains(path, "/admin") { // Log to security audit system logSecurityEvent(map[string]interface{}{ "timestamp": start.Unix(), "client_ip": clientIP, "method": method, "path": path, "status": statusCode, "latency": latency, "user_agent": r.UserAgent(), "referer": r.Referer(), }) } }) } // Helper to get client IP func getClientIP(r *http.Request) string { // Check X-Forwarded-For header if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the first IP in the list if idx := strings.Index(xff, ","); idx != -1 { return strings.TrimSpace(xff[:idx]) } return strings.TrimSpace(xff) } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { return strings.TrimSpace(xri) } // Fall back to RemoteAddr host, _, err := net.SplitHostPort(r.RemoteAddr) if err == nil { return host } return r.RemoteAddr } func parseClientIP(raw string) (netip.Addr, error) { s := strings.TrimSpace(raw) if s == "" { return netip.Addr{}, fmt.Errorf("empty ip") } // Try host:port parsing first (covers IPv4 and bracketed IPv6) if host, _, err := net.SplitHostPort(s); err == nil { s = host } // Trim brackets for IPv6 literals s = strings.TrimPrefix(s, "[") s = strings.TrimSuffix(s, "]") return netip.ParseAddr(s) } // Response writer wrapper to capture status code type responseWriter struct { http.ResponseWriter statusCode int } func (rw *responseWriter) WriteHeader(code int) { rw.statusCode = code rw.ResponseWriter.WriteHeader(code) } func logSecurityEvent(event map[string]interface{}) { // Implementation would send to security monitoring system // For now, just log (in production, use proper logging) log.Printf( "SECURITY AUDIT: %s %s %s %v", event["client_ip"], event["method"], event["path"], event["status"], ) }