fetch_ml/internal/middleware/security.go
Jeremie Fraeys 803677be57 feat: implement Go backend with comprehensive API and internal packages
- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
2025-12-04 16:53:53 -05:00

259 lines
6.7 KiB
Go

package middleware
import (
	"context"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/time/rate"
)
// SecurityMiddleware holds the state used by the method-based middlewares
// in this file (RateLimit, APIKeyAuth, IPWhitelist).
type SecurityMiddleware struct {
	// rateLimiter is consulted by RateLimit on every request; a single
	// limiter instance is shared by all requests passing through it.
	rateLimiter *rate.Limiter

	// apiKeys is the set of keys accepted by APIKeyAuth (true = accepted).
	apiKeys map[string]bool

	// jwtSecret is the secret handed to NewSecurityMiddleware.
	// NOTE(review): nothing in this file reads it — confirm it is used elsewhere.
	jwtSecret []byte
}
// NewSecurityMiddleware builds a SecurityMiddleware.
//
// apiKeys lists the keys accepted by APIKeyAuth. Empty strings are skipped
// so that a blank configuration entry can never match a request that simply
// omits its key. jwtSecret is stored as a byte slice.
//
// The limiter is created with rate.Limit(60), which golang.org/x/time/rate
// interprets as events per second: it admits a sustained 60 requests/second
// with bursts of up to 10. NOTE(review): the original inline comment said
// "60 requests per minute"; if that was the intent, use
// rate.Every(time.Second) instead. The one limiter is shared by every
// request that passes through RateLimit, regardless of client.
func NewSecurityMiddleware(apiKeys []string, jwtSecret string) *SecurityMiddleware {
	keyMap := make(map[string]bool, len(apiKeys))
	for _, key := range apiKeys {
		if key == "" {
			continue
		}
		keyMap[key] = true
	}
	return &SecurityMiddleware{
		rateLimiter: rate.NewLimiter(rate.Limit(60), 10),
		apiKeys:     keyMap,
		jwtSecret:   []byte(jwtSecret),
	}
}
// RateLimit wraps next so that each request must first take a token from
// sm.rateLimiter. When no token is available the request is answered with
// 429 Too Many Requests and next is not invoked.
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if permitted := sm.rateLimiter.Allow(); permitted {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(limited)
}
// API key authentication
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
// Also check Authorization header
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
}
}
if !sm.apiKeys[apiKey] {
http.Error(w, "Invalid API key", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
// SecurityHeaders sets a fixed set of defensive response headers and then
// calls next:
//   - X-Frame-Options: DENY (forbid framing; clickjacking defense)
//   - X-Content-Type-Options: nosniff (forbid MIME sniffing)
//   - X-XSS-Protection: 0 (disable the legacy browser XSS auditor)
//   - Content-Security-Policy: default-src 'self'
//   - Referrer-Policy: strict-origin-when-cross-origin
//   - Strict-Transport-Security, only when the request arrived over TLS
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		// The XSS auditor has been removed from modern browsers, and its
		// "1; mode=block" mode could itself be abused to leak information
		// from otherwise-safe pages. OWASP and MDN recommend "0" and relying
		// on Content-Security-Policy instead.
		h.Set("X-XSS-Protection", "0")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// HSTS is only honored on HTTPS responses. NOTE(review): behind a
		// TLS-terminating proxy r.TLS is nil, so this is never sent there.
		if r.TLS != nil {
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}
		next.ServeHTTP(w, r)
	})
}
// IP whitelist middleware
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
// Check if client IP is in whitelist
allowed := false
for _, ip := range allowedIPs {
if strings.Contains(ip, "/") {
// CIDR notation - would need proper IP net parsing
if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) {
allowed = true
break
}
} else {
if clientIP == ip {
allowed = true
break
}
}
}
if !allowed {
http.Error(w, "IP not whitelisted", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}
// CORS adds Cross-Origin Resource Sharing headers to every response.
//
// Only the origins listed in the switch below are echoed back in
// Access-Control-Allow-Origin; for any other Origin that header is omitted.
// The allowed methods, allowed headers, credentials flag and preflight max
// age are always set. OPTIONS requests are answered directly with
// 204 No Content and never reach next.
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// The response depends on the request's Origin, so caches must key
		// on it. Otherwise one origin's Allow-Origin value could be replayed
		// to another origin.
		h.Add("Vary", "Origin")

		// Only allow specific origins in production.
		switch origin := r.Header.Get("Origin"); origin {
		case "https://ml-experiments.example.com",
			"https://app.example.com":
			h.Set("Access-Control-Allow-Origin", origin)
		}

		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
		h.Set("Access-Control-Allow-Credentials", "true")
		h.Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// RequestTimeout returns middleware that gives each request's context a
// deadline of timeout before calling next. It does not abort the handler or
// write a response itself when the deadline passes. next, and whatever it
// calls, must observe r.Context() for the timeout to take effect. The
// derived context is cancelled as soon as next returns.
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			deadlineCtx, release := context.WithTimeout(r.Context(), timeout)
			defer release()
			next.ServeHTTP(w, r.WithContext(deadlineCtx))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// RequestSizeLimit returns middleware that caps request bodies at maxSize
// bytes.
//
// Requests that declare a Content-Length above maxSize are rejected up front
// with 413 Request Entity Too Large. Content-Length can be absent or
// unknown (chunked encoding reports ContentLength == -1), so the body is
// also wrapped in http.MaxBytesReader. Reads beyond maxSize then fail with
// an error instead of letting the handler consume an unbounded body.
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength > maxSize {
				http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
				return
			}
			if r.Body != nil {
				r.Body = http.MaxBytesReader(w, r.Body, maxSize)
			}
			next.ServeHTTP(w, r)
		})
	}
}
// AuditLogger wraps next and, after it returns, reports the request to
// logSecurityEvent when it is security-relevant. That means the response
// status was 400 or higher, the method was DELETE, or the URL path contains
// "/admin". The status is captured through responseWriter, so a handler
// that never calls WriteHeader is recorded as 200 OK.
//
// NOTE(review): the logged "path" includes the raw query string, which may
// carry credentials or tokens — consider redacting it.
func AuditLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		// Snapshot the URL before next runs, in case a handler rewrites r.URL.
		path := r.URL.Path
		rawQuery := r.URL.RawQuery

		wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(wrapped, r)

		latency := time.Since(start)
		status := wrapped.statusCode

		// Classify on the path alone, so that a query string such as
		// "?next=/admin" does not mark an ordinary request as admin access.
		relevant := status >= 400 ||
			r.Method == http.MethodDelete ||
			strings.Contains(path, "/admin")
		if !relevant {
			return
		}

		target := path
		if rawQuery != "" {
			target += "?" + rawQuery
		}
		logSecurityEvent(map[string]interface{}{
			"timestamp":  start.Unix(),
			"client_ip":  getClientIP(r),
			"method":     r.Method,
			"path":       target,
			"status":     status,
			"latency":    latency,
			"user_agent": r.UserAgent(),
			"referer":    r.Referer(),
		})
	})
}
// getClientIP returns a best-guess client IP for r, checking in order:
//  1. the first comma-separated entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr (port and IPv6 brackets removed; a
//     RemoteAddr without a port is returned unchanged).
//
// Header-derived values are whitespace-trimmed but not otherwise validated.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them. IPWhitelist and the audit log both
// rely on this value, so without such a proxy a client can spoof its
// address — confirm the deployment topology.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if idx := strings.Index(xff, ","); idx != -1 {
			return strings.TrimSpace(xff[:idx])
		}
		return strings.TrimSpace(xff)
	}
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}
	// SplitHostPort correctly handles "[::1]:8080". Splitting on the last
	// ':' would leave the brackets, and would mangle a port-less IPv6 address.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// responseWriter wraps an http.ResponseWriter and records the status code
// that was actually committed to the client, for use by AuditLogger.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int  // committed status; creators seed it with the default (200)
	wroteHeader bool // true once a final status has been committed
}

// recordStatus stores code as the response status unless one is already
// committed. Later WriteHeader calls do not change what the client received.
func (rw *responseWriter) recordStatus(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		rw.wroteHeader = true
	}
}

// WriteHeader records the first final status code and forwards the call.
// 1xx codes other than 101 are informational and may be followed by the
// real status, so they are forwarded without being recorded.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational {
		rw.recordStatus(code)
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A Write before any WriteHeader
// implicitly commits 200 OK.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.recordStatus(http.StatusOK)
	return rw.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer when it implements http.Flusher, so
// streaming handlers keep working behind AuditLogger. Flushing before
// WriteHeader implicitly commits 200 OK.
func (rw *responseWriter) Flush() {
	f, ok := rw.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	rw.recordStatus(http.StatusOK)
	f.Flush()
}

// Unwrap exposes the wrapped writer so http.ResponseController can reach
// optional capabilities (e.g. Hijack for WebSocket upgrades) that this
// wrapper does not implement itself.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// logSecurityEvent writes a one-line audit record for event to the standard
// logger. Only the "client_ip", "method", "path" and "status" entries are
// printed; other keys are ignored.
//
// client_ip and path are formatted with %q because both can be
// attacker-influenced. A URL path can decode to contain "\n", and quoting
// keeps each event on a single line so a crafted request cannot forge
// additional audit entries.
//
// A full implementation would send events to a security monitoring system;
// for now they are only logged (in production, use proper logging).
func logSecurityEvent(event map[string]interface{}) {
	log.Printf("SECURITY AUDIT: %q %s %q %v",
		event["client_ip"], event["method"], event["path"], event["status"])
}