- Fix YAML tags in auth config struct (json -> yaml) - Update CLI configs to use pre-hashed API keys - Remove double hashing in WebSocket client - Fix port mapping (9102 -> 9103) in CLI commands - Update permission keys to use jobs:read, jobs:create, etc. - Clean up all debug logging from CLI and server - All user roles now authenticate correctly: * Admin: Can queue jobs and see all jobs * Researcher: Can queue jobs and see own jobs * Analyst: Can see status (read-only access) Multi-user authentication is now fully functional.
281 lines
7.5 KiB
Go
281 lines
7.5 KiB
Go
// Package middleware provides HTTP middleware for security and request handling.
|
|
package middleware
|
|
|
|
import (
	"context"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
// SecurityMiddleware provides comprehensive security features
|
|
type SecurityMiddleware struct {
|
|
rateLimiter *rate.Limiter
|
|
apiKeys map[string]bool
|
|
jwtSecret []byte
|
|
}
|
|
|
|
// RateLimitOptions controls the request throttling applied by
// SecurityMiddleware.RateLimit.
type RateLimitOptions struct {
	// RequestsPerMinute is the sustained number of requests allowed per
	// minute. Zero or a negative value disables rate limiting.
	RequestsPerMinute int

	// BurstSize is the maximum number of requests admitted in a burst.
	// Zero or a negative value falls back to RequestsPerMinute.
	BurstSize int
}
|
|
|
|
// NewSecurityMiddleware creates a new security middleware instance.
|
|
func NewSecurityMiddleware(apiKeys []string, jwtSecret string, rlOpts *RateLimitOptions) *SecurityMiddleware {
|
|
keyMap := make(map[string]bool)
|
|
for _, key := range apiKeys {
|
|
keyMap[key] = true
|
|
}
|
|
|
|
sm := &SecurityMiddleware{
|
|
apiKeys: keyMap,
|
|
jwtSecret: []byte(jwtSecret),
|
|
}
|
|
|
|
// Configure rate limiter if enabled
|
|
if rlOpts != nil && rlOpts.RequestsPerMinute > 0 {
|
|
limit := rate.Limit(float64(rlOpts.RequestsPerMinute) / 60.0)
|
|
burst := rlOpts.BurstSize
|
|
if burst <= 0 {
|
|
burst = rlOpts.RequestsPerMinute
|
|
}
|
|
sm.rateLimiter = rate.NewLimiter(limit, burst)
|
|
}
|
|
|
|
return sm
|
|
}
|
|
|
|
// RateLimit provides rate limiting middleware.
|
|
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
|
|
if sm.rateLimiter == nil {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !sm.rateLimiter.Allow() {
|
|
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// APIKeyAuth provides API key authentication middleware.
|
|
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
apiKey := r.Header.Get("X-API-Key")
|
|
if apiKey == "" {
|
|
// Also check Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
|
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
|
|
}
|
|
}
|
|
|
|
if !sm.apiKeys[apiKey] {
|
|
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// SecurityHeaders wraps next and sets defensive response headers before the
// wrapped handler runs:
//   - X-Frame-Options: DENY, which forbids framing (clickjacking).
//   - X-Content-Type-Options: nosniff, which disables MIME sniffing.
//   - X-XSS-Protection: 0, which turns off the legacy browser XSS auditor.
//   - Content-Security-Policy: default-src 'self'.
//   - Referrer-Policy: strict-origin-when-cross-origin.
//   - Strict-Transport-Security, only when the request arrived over TLS.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		// "1; mode=block" is deprecated. Modern browsers removed the XSS
		// auditor it enables, and the auditor could itself be abused for
		// cross-site information leaks. OWASP recommends "0"; the CSP below
		// is the actual XSS mitigation.
		h.Set("X-XSS-Protection", "0")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		// Browsers only honor HSTS over HTTPS. Note that r.TLS is nil when TLS
		// is terminated by a proxy in front of this server.
		if r.TLS != nil {
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// IPWhitelist provides IP whitelist middleware.
|
|
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
clientIP := getClientIP(r)
|
|
|
|
// Check if client IP is in whitelist
|
|
allowed := false
|
|
for _, ip := range allowedIPs {
|
|
if strings.Contains(ip, "/") {
|
|
// CIDR notation - would need proper IP net parsing
|
|
if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) {
|
|
allowed = true
|
|
break
|
|
}
|
|
} else {
|
|
if clientIP == ip {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if !allowed {
|
|
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// CORS applies CORS headers using the built-in production origin allowlist
// (https://ml-experiments.example.com and https://app.example.com). It is
// equivalent to CORSWithOrigins with those two origins.
func CORS(next http.Handler) http.Handler {
	return CORSWithOrigins([]string{
		"https://ml-experiments.example.com",
		"https://app.example.com",
	})(next)
}

// CORSWithOrigins returns CORS middleware that grants cross-origin access
// only to the listed origins, matched exactly against the Origin request
// header.
//
// Behavior:
//   - For an allowed origin, the origin is echoed in
//     Access-Control-Allow-Origin.
//   - Vary: Origin is always added, because the response depends on the
//     Origin header.
//   - The Allow-Methods, Allow-Headers, Allow-Credentials and Max-Age headers
//     are set on every response.
//   - Every OPTIONS request is answered with 204 No Content without reaching
//     next.
func CORSWithOrigins(allowedOrigins []string) func(http.Handler) http.Handler {
	// Build the lookup set once rather than on every request.
	allowed := make(map[string]bool, len(allowedOrigins))
	for _, o := range allowedOrigins {
		if o != "" {
			allowed[o] = true
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()

			// Without Vary: Origin, a shared cache may serve a response that
			// carries one origin's Access-Control-Allow-Origin to another
			// origin.
			h.Add("Vary", "Origin")

			if origin := r.Header.Get("Origin"); allowed[origin] {
				h.Set("Access-Control-Allow-Origin", origin)
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
			h.Set("Access-Control-Allow-Credentials", "true")
			h.Set("Access-Control-Max-Age", "86400")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusNoContent)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// RequestTimeout returns middleware that derives a context with the given
// timeout from each request's context and passes the request on with it.
// It only attaches the deadline; a handler must watch r.Context() for the
// timeout to have any effect. The derived context is cancelled as soon as
// the wrapped handler returns.
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, req *http.Request) {
			timedCtx, release := context.WithTimeout(req.Context(), timeout)
			defer release()
			next.ServeHTTP(w, req.WithContext(timedCtx))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
|
|
|
|
// RequestSizeLimit returns middleware that caps request bodies at maxSize
// bytes.
//
// A request whose declared Content-Length exceeds the limit is rejected up
// front with 413 Request Entity Too Large. Content-Length can be unknown
// (-1 for chunked transfer encoding) or simply wrong, so the body is also
// wrapped in http.MaxBytesReader. Reads past maxSize then fail with an error
// instead of letting a client stream an unbounded body.
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength > maxSize {
				http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
				return
			}
			if r.Body != nil {
				r.Body = http.MaxBytesReader(w, r.Body, maxSize)
			}
			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// AuditLogger provides security audit logging middleware.
|
|
func AuditLogger(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
path := r.URL.Path
|
|
raw := r.URL.RawQuery
|
|
|
|
// Wrap response writer to capture status code
|
|
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
|
|
|
// Process request
|
|
next.ServeHTTP(wrapped, r)
|
|
|
|
// Log after processing
|
|
latency := time.Since(start)
|
|
clientIP := getClientIP(r)
|
|
method := r.Method
|
|
statusCode := wrapped.statusCode
|
|
|
|
if raw != "" {
|
|
path = path + "?" + raw
|
|
}
|
|
|
|
// Log security-relevant events
|
|
if statusCode >= 400 || method == "DELETE" || strings.Contains(path, "/admin") {
|
|
// Log to security audit system
|
|
logSecurityEvent(map[string]interface{}{
|
|
"timestamp": start.Unix(),
|
|
"client_ip": clientIP,
|
|
"method": method,
|
|
"path": path,
|
|
"status": statusCode,
|
|
"latency": latency,
|
|
"user_agent": r.UserAgent(),
|
|
"referer": r.Referer(),
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// getClientIP returns a best-effort client IP address for r. It checks, in
// order:
//  1. the first entry of the X-Forwarded-For header,
//  2. the X-Real-IP header,
//  3. the host part of r.RemoteAddr, or RemoteAddr unchanged if it has no
//     port.
//
// Header values are whitespace-trimmed.
//
// NOTE(review): both headers are supplied by the client unless a trusted
// reverse proxy overwrites them, so the result is not authenticated and
// should not be relied on alone for access control.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The left-most entry is the originating client.
		first := xff
		if idx := strings.Index(xff, ","); idx != -1 {
			first = xff[:idx]
		}
		return strings.TrimSpace(first)
	}

	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}

	// SplitHostPort understands bracketed IPv6 ("[::1]:8080" -> "::1"),
	// which a plain split at the last colon gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	// No port present (or not a host:port form): return it unchanged.
	return r.RemoteAddr
}
|
|
|
|
// responseWriter wraps an http.ResponseWriter and remembers the status code
// actually sent to the client; AuditLogger uses it for that purpose.
//
// The creator must initialise statusCode to http.StatusOK, because a handler
// that writes a body without calling WriteHeader implicitly sends 200.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool // a final (non-informational) status has been sent
}

// WriteHeader records code if it is the first final status, then forwards
// it. Later calls are still forwarded (net/http ignores and logs them), but
// they do not change the recorded status. Audits therefore report what the
// client received.
func (rw *responseWriter) WriteHeader(code int) {
	// 1xx codes other than 101 are informational and may precede the final
	// status, so they are not recorded.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write marks the header as sent (an implicit 200 when WriteHeader was never
// called) and forwards b to the underlying writer.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the underlying writer, so http.ResponseController (Go 1.20+)
// can reach optional interfaces such as http.Flusher and http.Hijacker
// through this wrapper.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
|
|
|
|
// logSecurityEvent writes an audit event to the standard logger. Only the
// client_ip, method, path and status fields are emitted; any other fields in
// event are currently ignored.
//
// client_ip and path are formatted with %q. The path is URL-decoded and the
// IP may come from a client-supplied header, so either could contain
// newlines or control characters that would otherwise let a client forge
// extra audit lines.
func logSecurityEvent(event map[string]interface{}) {
	// TODO: forward to the security monitoring system; the standard logger
	// is a stand-in.
	log.Printf("SECURITY AUDIT: %q %s %q %v", event["client_ip"], event["method"], event["path"], event["status"])
}
|