- Add API server with WebSocket support and REST endpoints - Implement authentication system with API keys and permissions - Add task queue system with Redis backend and error handling - Include storage layer with database migrations and schemas - Add comprehensive logging, metrics, and telemetry - Implement security middleware and network utilities - Add experiment management and container orchestration - Include configuration management with smart defaults
259 lines
6.7 KiB
Go
259 lines
6.7 KiB
Go
package middleware
|
|
|
|
import (
	"bufio"
	"context"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
// SecurityMiddleware provides comprehensive security features
|
|
type SecurityMiddleware struct {
|
|
rateLimiter *rate.Limiter
|
|
apiKeys map[string]bool
|
|
jwtSecret []byte
|
|
}
|
|
|
|
func NewSecurityMiddleware(apiKeys []string, jwtSecret string) *SecurityMiddleware {
|
|
keyMap := make(map[string]bool)
|
|
for _, key := range apiKeys {
|
|
keyMap[key] = true
|
|
}
|
|
|
|
return &SecurityMiddleware{
|
|
rateLimiter: rate.NewLimiter(rate.Limit(60), 10), // 60 requests per minute, burst of 10
|
|
apiKeys: keyMap,
|
|
jwtSecret: []byte(jwtSecret),
|
|
}
|
|
}
|
|
|
|
// RateLimit wraps next so that every request must first take a token from the
// middleware's rate limiter. When no token is available, the request is
// answered with 429 Too Many Requests and next is not called.
//
// The limiter is the single instance held by sm, so the budget is shared by
// all clients rather than tracked per client.
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		if sm.rateLimiter.Allow() {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(serve)
}
|
|
|
|
// API key authentication
|
|
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
apiKey := r.Header.Get("X-API-Key")
|
|
if apiKey == "" {
|
|
// Also check Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
|
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
|
|
}
|
|
}
|
|
|
|
if !sm.apiKeys[apiKey] {
|
|
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// SecurityHeaders sets a fixed set of defensive response headers and then
// calls next:
//
//   - X-Frame-Options: DENY — forbid framing (clickjacking).
//   - X-Content-Type-Options: nosniff — disable MIME-type sniffing.
//   - X-XSS-Protection: 0 — explicitly disable the legacy browser XSS auditor.
//     The former "1; mode=block" value is deprecated. The auditor it enabled
//     could itself be abused to leak information, and modern browsers have
//     removed it; Content-Security-Policy is the replacement.
//   - Content-Security-Policy: default-src 'self'.
//   - Referrer-Policy: strict-origin-when-cross-origin.
//   - Strict-Transport-Security, only when the request arrived over TLS
//     (r.TLS != nil); browsers ignore HSTS received over plain HTTP.
//
// NOTE(review): behind a TLS-terminating proxy r.TLS is nil, so HSTS is never
// sent; confirm how TLS is terminated in deployment.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-XSS-Protection", "0")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		if r.TLS != nil {
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// IP whitelist middleware
|
|
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
clientIP := getClientIP(r)
|
|
|
|
// Check if client IP is in whitelist
|
|
allowed := false
|
|
for _, ip := range allowedIPs {
|
|
if strings.Contains(ip, "/") {
|
|
// CIDR notation - would need proper IP net parsing
|
|
if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) {
|
|
allowed = true
|
|
break
|
|
}
|
|
} else {
|
|
if clientIP == ip {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if !allowed {
|
|
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// CORS returns middleware that grants cross-origin access only to a fixed
// allowlist of origins.
//
// For an allowed Origin header, the origin is echoed in
// Access-Control-Allow-Origin together with the permitted methods and headers,
// the credentials flag, and a 24h preflight cache lifetime. Other origins get
// no CORS grant headers, so browsers block their cross-origin reads.
//
// "Vary: Origin" is always added because the response depends on the Origin
// header. Without it, a shared cache could serve one origin's
// Access-Control-Allow-Origin to another.
//
// Every OPTIONS request is answered with 204 No Content and does not reach
// next.
func CORS(next http.Handler) http.Handler {
	// Only allow specific origins in production. Built once per wrapped
	// handler instead of on every request.
	allowedOrigins := map[string]bool{
		"https://ml-experiments.example.com": true,
		"https://app.example.com":            true,
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")
		h := w.Header()
		h.Add("Vary", "Origin")

		if allowedOrigins[origin] {
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
			h.Set("Access-Control-Allow-Credentials", "true")
			h.Set("Access-Control-Max-Age", "86400")
		}

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// RequestTimeout returns middleware that gives each request's context a
// deadline of timeout before calling next, and cancels that context when next
// returns.
//
// It does not abort the handler or write a response by itself; handlers and
// the calls they make must watch r.Context() for the deadline to take effect.
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()
			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
|
|
|
|
// RequestSizeLimit returns middleware that caps request bodies at maxSize
// bytes.
//
// A request whose declared Content-Length already exceeds maxSize is rejected
// up front with 413 Request Entity Too Large. A declared length alone is not
// enough: it is -1 for chunked or unknown-length bodies, which previously
// bypassed the check entirely. The body is therefore also wrapped in
// http.MaxBytesReader, so reading past maxSize returns an error to the handler.
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength > maxSize {
				http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
				return
			}
			if r.Body != nil {
				r.Body = http.MaxBytesReader(w, r.Body, maxSize)
			}
			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// AuditLogger wraps next with security audit logging. It runs next through a
// status-capturing responseWriter. Afterwards it passes a summary to
// logSecurityEvent when the response status is 400 or higher, the method is
// DELETE, or the request path (including any query string) contains "/admin".
//
// The event map carries: timestamp (Unix seconds at request start),
// client_ip, method, path (with "?query" appended when present), status,
// latency (a time.Duration), user_agent and referer.
func AuditLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Snapshot the request target before the handler runs, in case the
		// handler rewrites r.URL in place.
		began := time.Now()
		target, query := r.URL.Path, r.URL.RawQuery

		rec := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(rec, r)
		elapsed := time.Since(began)

		if query != "" {
			target = target + "?" + query
		}

		status := rec.statusCode
		securityRelevant := status >= 400 ||
			r.Method == "DELETE" ||
			strings.Contains(target, "/admin")
		if !securityRelevant {
			return
		}

		logSecurityEvent(map[string]interface{}{
			"timestamp":  began.Unix(),
			"client_ip":  getClientIP(r),
			"method":     r.Method,
			"path":       target,
			"status":     status,
			"latency":    elapsed,
			"user_agent": r.UserAgent(),
			"referer":    r.Referer(),
		})
	})
}
|
|
|
|
// getClientIP returns the client address for r as an IP string without a
// port.
//
// Sources are tried in order:
//  1. the first entry of X-Forwarded-For (trimmed);
//  2. X-Real-IP (trimmed);
//  3. the host part of r.RemoteAddr, or RemoteAddr unchanged if it has no
//     port.
//
// NOTE(review): the two proxy headers are taken from the request as sent, so
// any client can set them unless a trusted reverse proxy overwrites them.
// Access decisions based on this value (e.g. an IP allowlist) should only
// honour these headers when the request comes from a known proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Take the left-most entry, which by convention names the originating client.
		if idx := strings.Index(xff, ","); idx != -1 {
			xff = xff[:idx]
		}
		return strings.TrimSpace(xff)
	}

	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}

	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"). Slicing at
	// the last ':' would return "[::1]" there and turn a port-less "::1" into "::".
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// Response writer wrapper to capture status code
|
|
type responseWriter struct {
|
|
http.ResponseWriter
|
|
statusCode int
|
|
}
|
|
|
|
func (rw *responseWriter) WriteHeader(code int) {
|
|
rw.statusCode = code
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
// logSecurityEvent writes a one-line audit record to the standard logger.
// It prints the "client_ip", "method", "path" and "status" entries of event;
// other keys are currently not emitted.
//
// The text fields are formatted with %q (quoted, control characters escaped).
// The path comes from the decoded request URL and client_ip may come from a
// client-supplied header, so with %s a value containing "\n" could forge
// additional audit lines.
func logSecurityEvent(event map[string]interface{}) {
	// Implementation would send to security monitoring system.
	// For now, just log (in production, use proper logging).
	log.Printf("SECURITY AUDIT: %q %q %q %v", event["client_ip"], event["method"], event["path"], event["status"])
}
|