// Source: fetch_ml/internal/middleware/security.go (Go, 346 lines, 9 KiB)

// Package middleware provides HTTP middleware for security and request handling.
package middleware
import (
"context"
"fmt"
"log"
"net"
"net/http"
"net/netip"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"golang.org/x/time/rate"
)
// SecurityMiddleware holds the state shared by the security middlewares
// defined as its methods.
type SecurityMiddleware struct {
	// rateLimiter is one limiter shared by every request that passes through
	// this instance. It is nil when rate limiting is disabled.
	rateLimiter *rate.Limiter
	// authConfig drives APIKeyAuth. A nil or disabled config lets every
	// request through unauthenticated.
	authConfig *auth.Config
	// jwtSecret holds the raw bytes of the secret given to NewSecurityMiddleware.
	jwtSecret []byte
}
// RateLimitOptions configures request throttling for NewSecurityMiddleware.
//
// Rate limiting is disabled when RequestsPerMinute is zero or negative. When
// BurstSize is zero or negative, it falls back to RequestsPerMinute.
type RateLimitOptions struct {
	RequestsPerMinute int // sustained rate, spread evenly over each minute
	BurstSize         int // burst size handed to the token-bucket limiter
}
// NewSecurityMiddleware builds a SecurityMiddleware from an auth configuration,
// a JWT secret and optional rate-limit options.
//
// Rate limiting is enabled only when rlOpts is non-nil and
// rlOpts.RequestsPerMinute is positive. The per-minute figure is converted to a
// per-second rate.Limit. A non-positive BurstSize defaults the burst to
// RequestsPerMinute.
func NewSecurityMiddleware(
	authConfig *auth.Config,
	jwtSecret string,
	rlOpts *RateLimitOptions,
) *SecurityMiddleware {
	mw := &SecurityMiddleware{
		authConfig: authConfig,
		jwtSecret:  []byte(jwtSecret),
	}
	if rlOpts == nil || rlOpts.RequestsPerMinute <= 0 {
		// No usable rate: leave rateLimiter nil so RateLimit becomes a no-op.
		return mw
	}
	perSecond := rate.Limit(float64(rlOpts.RequestsPerMinute) / 60.0)
	bucket := rlOpts.RequestsPerMinute
	if rlOpts.BurstSize > 0 {
		bucket = rlOpts.BurstSize
	}
	mw.rateLimiter = rate.NewLimiter(perSecond, bucket)
	return mw
}
// RateLimit wraps next with the instance's rate limiter. If rate limiting is
// disabled (no limiter configured), next is returned unchanged. Requests that
// the limiter refuses get a 429 Too Many Requests response.
//
// NOTE(review): a single limiter is shared by all clients, so one noisy client
// can exhaust the budget for everyone. Per-client limiting would need a keyed
// limiter.
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
	if sm.rateLimiter == nil {
		return next
	}
	limited := func(w http.ResponseWriter, r *http.Request) {
		if sm.rateLimiter.Allow() {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(limited)
}
// APIKeyAuth authenticates requests by API key.
//
// If no auth config is set, or the config is disabled, every request passes
// through untouched so local/dev setups work without keys. Otherwise the key
// is extracted from the request and validated:
//   - a failed check gets 401 "Invalid API key";
//   - a successful check continues with the resolved user attached to the
//     request context.
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cfg := sm.authConfig
		authRequired := cfg != nil && cfg.Enabled
		if !authRequired {
			next.ServeHTTP(w, r)
			return
		}
		user, err := cfg.ValidateAPIKey(auth.ExtractAPIKeyFromRequest(r))
		if err != nil {
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r.WithContext(auth.WithUserContext(r.Context(), user)))
	})
}
// SecurityHeaders sets a baseline of defensive response headers and then
// calls next:
//   - X-Frame-Options: DENY blocks framing (clickjacking).
//   - X-Content-Type-Options: nosniff disables MIME-type sniffing.
//   - X-XSS-Protection: 0 explicitly turns off the legacy browser XSS
//     auditor. That auditor is deprecated and can itself be abused to leak
//     data. XSS mitigation is left to the Content-Security-Policy below.
//   - Content-Security-Policy: default-src 'self'.
//   - Referrer-Policy: strict-origin-when-cross-origin.
//   - Strict-Transport-Security is set only when the request arrived over
//     TLS, because browsers ignore HSTS received over plain HTTP.
//
// NOTE(review): behind a TLS-terminating proxy r.TLS is nil, so HSTS is never
// sent. Confirm whether the proxy adds it.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-XSS-Protection", "0")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		if r.TLS != nil {
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}
		next.ServeHTTP(w, r)
	})
}
// IPWhitelist returns middleware that only admits requests whose client IP
// (as resolved by getClientIP) matches one of allowedIPs.
//
// Each allowedIPs entry is either a single address ("10.0.0.5", "::1") or a
// CIDR prefix ("10.0.0.0/8"). Blank entries are skipped, and malformed entries
// are logged and ignored. If no valid entry remains, every request is rejected
// (fail closed). Rejected requests get 403 "IP not whitelisted".
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are normalised to plain IPv4 on
// both sides before matching. netip compares v4 and v4-in-v6 forms as
// different addresses, and Prefix.Contains never matches across the two
// forms. Without this normalisation, an IPv4 whitelist entry would never match
// a client reported as ::ffff:a.b.c.d (e.g. on a dual-stack listener).
//
// NOTE(review): getClientIP prefers X-Forwarded-For / X-Real-IP over
// RemoteAddr. Unless a trusted proxy overwrites those headers, a client can
// choose the IP this check sees.
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
	parsedAddrs := make([]netip.Addr, 0, len(allowedIPs))
	parsedPrefixes := make([]netip.Prefix, 0, len(allowedIPs))
	for _, raw := range allowedIPs {
		val := strings.TrimSpace(raw)
		if val == "" {
			continue
		}
		if strings.Contains(val, "/") {
			p, err := netip.ParsePrefix(val)
			if err != nil {
				log.Printf("SECURITY: invalid ip whitelist cidr ignored: %q: %v", val, err)
				continue
			}
			// Rewrite ::ffff:a.b.c.d/N (N >= 96) as the equivalent IPv4 prefix.
			if pa := p.Addr(); pa.Is4In6() && p.Bits() >= 96 {
				p = netip.PrefixFrom(pa.Unmap(), p.Bits()-96)
			}
			parsedPrefixes = append(parsedPrefixes, p)
			continue
		}
		a, err := netip.ParseAddr(val)
		if err != nil {
			log.Printf("SECURITY: invalid ip whitelist addr ignored: %q: %v", val, err)
			continue
		}
		parsedAddrs = append(parsedAddrs, a.Unmap())
	}

	// isAllowed reports whether the (already unmapped) addr matches any entry.
	isAllowed := func(addr netip.Addr) bool {
		for _, a := range parsedAddrs {
			if a == addr {
				return true
			}
		}
		for _, p := range parsedPrefixes {
			if p.Contains(addr) {
				return true
			}
		}
		return false
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if len(parsedAddrs) == 0 && len(parsedPrefixes) == 0 {
				http.Error(w, "IP not whitelisted", http.StatusForbidden)
				return
			}
			addr, err := parseClientIP(getClientIP(r))
			if err != nil || !isAllowed(addr.Unmap()) {
				http.Error(w, "IP not whitelisted", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// CORS returns middleware that applies Cross-Origin Resource Sharing headers.
//
// allowedOrigins lists exact origins (e.g. "https://app.example.com"). Entries
// are whitespace-trimmed, blank entries are ignored, and "*" allows any
// origin. An allowed request Origin is echoed back in
// Access-Control-Allow-Origin.
//
// The Allow-Methods, Allow-Headers, Allow-Credentials and Max-Age headers are
// sent on every response. Every OPTIONS request is answered with 204 No
// Content without reaching next.
//
// "Vary: Origin" is added to every response, not just allowed ones. Whether
// Access-Control-Allow-Origin is present depends on the request's Origin, so
// without Vary a shared cache could replay a response that lacks the header
// to an allowed origin. Add (not Set) is used so Vary values set by other
// middleware are kept.
//
// NOTE(review): "*" together with Access-Control-Allow-Credentials: true lets
// any site make credentialed requests, because the concrete origin is
// reflected instead of a literal "*". Confirm this is intended.
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
	allowed := make([]string, 0, len(allowedOrigins))
	for _, o := range allowedOrigins {
		if v := strings.TrimSpace(o); v != "" {
			allowed = append(allowed, v)
		}
	}

	// originAllowed reports whether origin matches an entry (or "*").
	originAllowed := func(origin string) bool {
		for _, a := range allowed {
			if a == "*" || a == origin {
				return true
			}
		}
		return false
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Add("Vary", "Origin")
			if origin := r.Header.Get("Origin"); origin != "" && originAllowed(origin) {
				h.Set("Access-Control-Allow-Origin", origin)
			}
			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
			h.Set("Access-Control-Allow-Credentials", "true")
			h.Set("Access-Control-Max-Age", "86400")
			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusNoContent)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// RequestTimeout gives each request a context that expires after timeout and
// passes it to next. The derived context is cancelled when next returns.
//
// This only sets a deadline. It neither interrupts next nor writes a response
// when the deadline passes, so handlers must watch r.Context() to stop early.
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()
			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// RequestSizeLimit limits request bodies to maxSize bytes.
//
// A declared Content-Length above maxSize is refused up front with 413
// "Request too large". The body is also wrapped in http.MaxBytesReader.
// Checking Content-Length alone misses bodies of unknown length (chunked
// transfer encoding, ContentLength == -1). With the wrapper, such bodies
// cannot exceed the limit either: reads past maxSize return an error that
// handlers see while reading r.Body.
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength > maxSize {
				http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
				return
			}
			if r.Body != nil {
				// Shallow-copy the request rather than mutating the caller's,
				// mirroring http.MaxBytesHandler.
				limited := *r
				limited.Body = http.MaxBytesReader(w, r.Body, maxSize)
				r = &limited
			}
			next.ServeHTTP(w, r)
		})
	}
}
// AuditLogger sends security-relevant requests to logSecurityEvent after next
// has handled them.
//
// A request is security-relevant when any of these holds:
//   - the response status is 400 or above;
//   - the method is DELETE;
//   - the URL path contains "/admin".
//
// The logged "path" includes the raw query string, if any.
func AuditLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		// Read the URL before next runs, since handlers share r.URL.
		urlPath := r.URL.Path
		rawQuery := r.URL.RawQuery

		wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(wrapped, r)

		latency := time.Since(start)
		method := r.Method
		statusCode := wrapped.statusCode

		// Match admin routes on the path alone. Testing the path+query string
		// would flag unrelated requests whose query merely mentions "/admin"
		// (e.g. ?next=/admin).
		isAdmin := strings.Contains(urlPath, "/admin")
		if statusCode < 400 && method != http.MethodDelete && !isAdmin {
			return
		}

		loggedPath := urlPath
		if rawQuery != "" {
			loggedPath = loggedPath + "?" + rawQuery
		}
		logSecurityEvent(map[string]interface{}{
			"timestamp":  start.Unix(),
			"client_ip":  getClientIP(r),
			"method":     method,
			"path":       loggedPath,
			"status":     statusCode,
			"latency":    latency,
			"user_agent": r.UserAgent(),
			"referer":    r.Referer(),
		})
	})
}
// getClientIP returns a best-effort client IP for r, as an unvalidated string.
//
// Precedence:
//  1. the first comma-separated entry of X-Forwarded-For, whitespace-trimmed;
//  2. X-Real-IP, whitespace-trimmed;
//  3. the host part of r.RemoteAddr, or RemoteAddr verbatim if it cannot be
//     split into host and port.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them. Do not treat the result as authenticated when the
// service is reachable directly.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first, _, _ := strings.Cut(forwarded, ",")
		return strings.TrimSpace(first)
	}
	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		return strings.TrimSpace(realIP)
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// parseClientIP parses a client IP string into a netip.Addr.
//
// It accepts, after whitespace trimming:
//   - bare IPv4/IPv6 literals;
//   - host:port forms, including bracketed IPv6 such as "[::1]:443";
//   - bracketed IPv6 without a port ("[::1]").
//
// It returns an error for empty input or anything netip.ParseAddr rejects.
func parseClientIP(raw string) (netip.Addr, error) {
	candidate := strings.TrimSpace(raw)
	if len(candidate) == 0 {
		return netip.Addr{}, fmt.Errorf("empty ip")
	}
	// Drop a port if one is present. Bare IPv6 literals fail to split and
	// are left as they are.
	if host, _, splitErr := net.SplitHostPort(candidate); splitErr == nil {
		candidate = host
	}
	candidate = strings.TrimSuffix(strings.TrimPrefix(candidate, "["), "]")
	return netip.ParseAddr(candidate)
}
// responseWriter wraps an http.ResponseWriter and records the status code
// actually sent to the client.
//
// The creator should set statusCode to http.StatusOK. That value is kept when
// the handler writes nothing at all, matching what net/http sends.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool // true once the final (non-1xx) status has been sent
}

// WriteHeader records code and forwards it.
//
// Only the first final status is recorded, because net/http ignores later
// ("superfluous") calls. Informational 1xx codes other than 101 Switching
// Protocols may precede the final status, so they do not lock in the value.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		if code < 100 || code > 199 || code == http.StatusSwitchingProtocols {
			rw.wroteHeader = true
		}
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b. A Write before any WriteHeader implies a 200 status, so
// that status is recorded and later WriteHeader calls are ignored, as in
// net/http.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the underlying writer. http.ResponseController uses it to
// reach optional interfaces (Flusher, Hijacker, deadlines) that this wrapper
// does not implement itself.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// logSecurityEvent writes a one-line audit record for event to the standard
// logger, in the form:
//
//	SECURITY AUDIT: "<client_ip>" <method> "<path>" <status>
//
// client_ip and path come from request data (forwarding headers and the
// decoded URL path). They are printed with %q so control characters are
// escaped: a path containing a decoded "%0a" cannot start a forged log line.
//
// Only client_ip, method, path and status are written; other keys in event
// are ignored here. In production this would forward the whole event to a
// security monitoring system.
func logSecurityEvent(event map[string]interface{}) {
	log.Printf(
		"SECURITY AUDIT: %q %s %q %v",
		event["client_ip"],
		event["method"],
		event["path"],
		event["status"],
	)
}