// Package middleware provides HTTP middleware for security and request handling. package middleware import ( "context" "log" "net/http" "strings" "time" "golang.org/x/time/rate" ) // SecurityMiddleware provides comprehensive security features type SecurityMiddleware struct { rateLimiter *rate.Limiter apiKeys map[string]bool jwtSecret []byte } // RateLimitOptions configures request throttling. type RateLimitOptions struct { RequestsPerMinute int BurstSize int } // NewSecurityMiddleware creates a new security middleware instance. func NewSecurityMiddleware(apiKeys []string, jwtSecret string, rlOpts *RateLimitOptions) *SecurityMiddleware { keyMap := make(map[string]bool) for _, key := range apiKeys { keyMap[key] = true } sm := &SecurityMiddleware{ apiKeys: keyMap, jwtSecret: []byte(jwtSecret), } // Configure rate limiter if enabled if rlOpts != nil && rlOpts.RequestsPerMinute > 0 { limit := rate.Limit(float64(rlOpts.RequestsPerMinute) / 60.0) burst := rlOpts.BurstSize if burst <= 0 { burst = rlOpts.RequestsPerMinute } sm.rateLimiter = rate.NewLimiter(limit, burst) } return sm } // RateLimit provides rate limiting middleware. func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler { if sm.rateLimiter == nil { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !sm.rateLimiter.Allow() { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // APIKeyAuth provides API key authentication middleware. 
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiKey := r.Header.Get("X-API-Key") if apiKey == "" { // Also check Authorization header authHeader := r.Header.Get("Authorization") if strings.HasPrefix(authHeader, "Bearer ") { apiKey = strings.TrimPrefix(authHeader, "Bearer ") } } if !sm.apiKeys[apiKey] { http.Error(w, "Invalid API key", http.StatusUnauthorized) return } next.ServeHTTP(w, r) }) } // SecurityHeaders provides security headers middleware. func SecurityHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Prevent clickjacking w.Header().Set("X-Frame-Options", "DENY") // Prevent MIME type sniffing w.Header().Set("X-Content-Type-Options", "nosniff") // Enable XSS protection w.Header().Set("X-XSS-Protection", "1; mode=block") // Content Security Policy w.Header().Set("Content-Security-Policy", "default-src 'self'") // Referrer policy w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // HSTS (HTTPS only) if r.TLS != nil { w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") } next.ServeHTTP(w, r) }) } // IPWhitelist provides IP whitelist middleware. 
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getClientIP(r) // Check if client IP is in whitelist allowed := false for _, ip := range allowedIPs { if strings.Contains(ip, "/") { // CIDR notation - would need proper IP net parsing if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) { allowed = true break } } else { if clientIP == ip { allowed = true break } } } if !allowed { http.Error(w, "IP not whitelisted", http.StatusForbidden) return } next.ServeHTTP(w, r) }) } } // CORS middleware with restrictive defaults func CORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Only allow specific origins in production allowedOrigins := []string{ "https://ml-experiments.example.com", "https://app.example.com", } isAllowed := false for _, allowed := range allowedOrigins { if origin == allowed { isAllowed = true break } } if isAllowed { w.Header().Set("Access-Control-Allow-Origin", origin) } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // RequestTimeout provides request timeout middleware. func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } } // RequestSizeLimit provides request size limiting middleware. 
// Requests whose declared Content-Length exceeds maxSize are rejected with
// 413 Request Entity Too Large. The body is also capped at maxSize bytes, so
// chunked or unknown-length requests cannot get past the limit: reading
// beyond it returns an error to the handler.
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength > maxSize {
				http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
				return
			}

			// ContentLength is -1 for chunked bodies and can understate the
			// real size, so enforce the limit on the bytes actually read.
			if r.Body != nil {
				r.Body = http.MaxBytesReader(w, r.Body, maxSize)
			}

			next.ServeHTTP(w, r)
		})
	}
}

// AuditLogger provides security audit logging middleware.
//
// After next has run, it logs the request when the response status is >= 400,
// the method is DELETE, or the path (including the query string) contains
// "/admin".
func AuditLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		path := r.URL.Path
		raw := r.URL.RawQuery

		// Wrap response writer to capture status code.
		wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}

		// Process request.
		next.ServeHTTP(wrapped, r)

		// Log after processing.
		latency := time.Since(start)
		clientIP := getClientIP(r)
		method := r.Method
		statusCode := wrapped.statusCode

		if raw != "" {
			path = path + "?" + raw
		}

		// Log security-relevant events.
		if statusCode >= 400 || method == http.MethodDelete || strings.Contains(path, "/admin") {
			// Log to security audit system.
			logSecurityEvent(map[string]interface{}{
				"timestamp":  start.Unix(),
				"client_ip":  clientIP,
				"method":     method,
				"path":       path,
				"status":     statusCode,
				"latency":    latency,
				"user_agent": r.UserAgent(),
				"referer":    r.Referer(),
			})
		}
	})
}

// getClientIP returns the client address for r.
//
// It prefers the first entry of X-Forwarded-For, then X-Real-IP, and finally
// the host part of RemoteAddr with any IPv6 brackets removed.
//
// NOTE(review): both headers are set by the client unless a trusted proxy
// overwrites them, so the result can be spoofed when the server is reachable
// directly. Confirm the deployment before relying on it for access control.
func getClientIP(r *http.Request) string {
	// Check X-Forwarded-For header.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Take the first IP in the list.
		if idx := strings.Index(xff, ","); idx != -1 {
			return strings.TrimSpace(xff[:idx])
		}
		return strings.TrimSpace(xff)
	}

	// Check X-Real-IP header.
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}

	// Fall back to RemoteAddr, dropping the ":port" suffix.
	host := r.RemoteAddr
	if idx := strings.LastIndex(host, ":"); idx != -1 {
		host = host[:idx]
	}
	// IPv6 remote addresses look like "[::1]:8080"; strip the brackets so the
	// result is a plain address.
	return strings.Trim(host, "[]")
}

// responseWriter wraps an http.ResponseWriter to capture the status code of
// the response that is actually sent.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool // true once the final status has been committed
}

// WriteHeader records the first final status code and forwards the call.
// Later calls are forwarded but not recorded, because the status has already
// been committed. Informational 1xx codes (other than 101) are not final and
// are passed through unrecorded.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational && !rw.wroteHeader {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A write without a prior WriteHeader
// commits an implicit 200 OK, so the recorded status is fixed at that point.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach the
// optional interfaces of the underlying writer.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// logSecurityEvent writes a one-line summary of event to the standard logger,
// using the "client_ip", "method", "path" and "status" keys.
func logSecurityEvent(event map[string]interface{}) {
	// Implementation would send to security monitoring system.
	// For now, just log (in production, use proper logging).
	// The path is request-controlled and URL-decoded, so it is quoted with %q
	// to stop embedded newlines from forging extra log lines.
	log.Printf("SECURITY AUDIT: %s %s %q %v",
		event["client_ip"], event["method"], event["path"], event["status"])
}