346 lines
9 KiB
Go
346 lines
9 KiB
Go
// Package middleware provides HTTP middleware for security and request handling.
|
|
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// SecurityMiddleware provides comprehensive security features
|
|
type SecurityMiddleware struct {
|
|
rateLimiter *rate.Limiter
|
|
authConfig *auth.Config
|
|
jwtSecret []byte
|
|
}
|
|
|
|
// RateLimitOptions controls the token bucket that NewSecurityMiddleware
// builds for SecurityMiddleware.RateLimit.
type RateLimitOptions struct {
	// RequestsPerMinute is the sustained request rate. Zero or a negative
	// value disables rate limiting entirely.
	RequestsPerMinute int

	// BurstSize is the bucket capacity (how many requests may arrive at
	// once). Zero or a negative value falls back to RequestsPerMinute.
	BurstSize int
}
|
|
|
|
// NewSecurityMiddleware creates a new security middleware instance.
|
|
func NewSecurityMiddleware(
|
|
authConfig *auth.Config,
|
|
jwtSecret string,
|
|
rlOpts *RateLimitOptions,
|
|
) *SecurityMiddleware {
|
|
sm := &SecurityMiddleware{
|
|
authConfig: authConfig,
|
|
jwtSecret: []byte(jwtSecret),
|
|
}
|
|
|
|
// Configure rate limiter if enabled
|
|
if rlOpts != nil && rlOpts.RequestsPerMinute > 0 {
|
|
limit := rate.Limit(float64(rlOpts.RequestsPerMinute) / 60.0)
|
|
burst := rlOpts.BurstSize
|
|
if burst <= 0 {
|
|
burst = rlOpts.RequestsPerMinute
|
|
}
|
|
sm.rateLimiter = rate.NewLimiter(limit, burst)
|
|
}
|
|
|
|
return sm
|
|
}
|
|
|
|
// RateLimit provides rate limiting middleware.
|
|
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
|
|
if sm.rateLimiter == nil {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !sm.rateLimiter.Allow() {
|
|
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// APIKeyAuth provides API key authentication middleware.
|
|
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// If authentication is not configured or disabled, allow all requests.
|
|
// This keeps local/dev environments functional without requiring API keys.
|
|
if sm.authConfig == nil || !sm.authConfig.Enabled {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
apiKey := auth.ExtractAPIKeyFromRequest(r)
|
|
|
|
// Validate API key using auth config
|
|
user, err := sm.authConfig.ValidateAPIKey(apiKey)
|
|
if err != nil {
|
|
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
ctx := auth.WithUserContext(r.Context(), user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// SecurityHeaders sets a fixed set of defensive response headers and then
// calls next:
//
//   - X-Frame-Options: DENY — forbid framing (clickjacking).
//   - X-Content-Type-Options: nosniff — forbid MIME sniffing.
//   - X-XSS-Protection: 1; mode=block — legacy browser XSS filter.
//   - Content-Security-Policy: default-src 'self' — same-origin resources only.
//   - Referrer-Policy: strict-origin-when-cross-origin.
//
// Strict-Transport-Security is added only when r.TLS is non-nil, i.e. when
// this server terminated TLS itself. Requests TLS-terminated by a proxy in
// front of it do not get HSTS.
//
// NOTE(review): current OWASP guidance recommends "X-XSS-Protection: 0";
// confirm before changing, since it alters every response.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range [...][2]string{
			{"X-Frame-Options", "DENY"},
			{"X-Content-Type-Options", "nosniff"},
			{"X-XSS-Protection", "1; mode=block"},
			{"Content-Security-Policy", "default-src 'self'"},
			{"Referrer-Policy", "strict-origin-when-cross-origin"},
		} {
			hdr.Set(kv[0], kv[1])
		}

		if r.TLS != nil {
			hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// IPWhitelist provides IP whitelist middleware.
|
|
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
|
|
parsedAddrs := make([]netip.Addr, 0, len(allowedIPs))
|
|
parsedPrefixes := make([]netip.Prefix, 0, len(allowedIPs))
|
|
for _, raw := range allowedIPs {
|
|
val := strings.TrimSpace(raw)
|
|
if val == "" {
|
|
continue
|
|
}
|
|
if strings.Contains(val, "/") {
|
|
p, err := netip.ParsePrefix(val)
|
|
if err != nil {
|
|
log.Printf("SECURITY: invalid ip whitelist cidr ignored: %q: %v", val, err)
|
|
continue
|
|
}
|
|
parsedPrefixes = append(parsedPrefixes, p)
|
|
continue
|
|
}
|
|
a, err := netip.ParseAddr(val)
|
|
if err != nil {
|
|
log.Printf("SECURITY: invalid ip whitelist addr ignored: %q: %v", val, err)
|
|
continue
|
|
}
|
|
parsedAddrs = append(parsedAddrs, a)
|
|
}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if len(parsedAddrs) == 0 && len(parsedPrefixes) == 0 {
|
|
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
clientIPStr := getClientIP(r)
|
|
addr, err := parseClientIP(clientIPStr)
|
|
if err != nil {
|
|
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
allowed := false
|
|
for _, a := range parsedAddrs {
|
|
if a == addr {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
for _, p := range parsedPrefixes {
|
|
if p.Contains(addr) {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if !allowed {
|
|
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// CORS returns middleware that applies a cross-origin resource sharing policy
// for allowedOrigins.
//
// Entries are trimmed and blank entries dropped. An entry of "*" allows any
// origin. When the request's Origin header is allowed, that exact origin is
// echoed in Access-Control-Allow-Origin.
//
// Access-Control-Allow-Credentials: true is sent only for origins listed
// explicitly. Echoing an arbitrary origin together with credentials would let
// any website make credentialed cross-origin requests on a user's behalf.
// Origins admitted only through "*" therefore get no credentials header.
//
// "Vary: Origin" is added to every response because the CORS headers depend
// on the Origin request header; without it a shared cache could serve one
// origin's response to another. The allowed methods/headers and a 24h
// preflight max-age are always set. Every OPTIONS request is answered with
// 204 No Content without calling next.
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
	allowAny := false
	explicit := make(map[string]struct{}, len(allowedOrigins))
	for _, o := range allowedOrigins {
		switch v := strings.TrimSpace(o); v {
		case "":
			// Ignore blank configuration entries.
		case "*":
			allowAny = true
		default:
			explicit[v] = struct{}{}
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			// Add, not Set, so Vary values from other layers are preserved.
			h.Add("Vary", "Origin")

			if origin := r.Header.Get("Origin"); origin != "" {
				_, listed := explicit[origin]
				if listed || allowAny {
					h.Set("Access-Control-Allow-Origin", origin)
				}
				if listed {
					h.Set("Access-Control-Allow-Credentials", "true")
				}
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
			h.Set("Access-Control-Max-Age", "86400")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusNoContent)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// RequestTimeout returns middleware that gives each request a context with the
// given timeout. The derived context is cancelled when next returns.
//
// Only the context deadline is set. Nothing here writes a response when the
// deadline passes, so handlers must observe r.Context() for the timeout to
// have any effect.
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			bounded, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()
			next.ServeHTTP(w, r.WithContext(bounded))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
|
|
|
|
// RequestSizeLimit returns middleware that caps request bodies at maxSize
// bytes.
//
// A request whose declared Content-Length exceeds maxSize is rejected up
// front with 413 Request Entity Too Large and never reaches next. That check
// alone is bypassable: chunked uploads report an unknown length
// (ContentLength == -1). The body is therefore also wrapped in
// http.MaxBytesReader, so any read past maxSize bytes returns an error instead
// of more data. Handlers must treat a body-read error as a failed request.
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength > maxSize {
				http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
				return
			}
			// Enforce the cap on the bytes actually read, not just the header.
			if r.Body != nil {
				r.Body = http.MaxBytesReader(w, r.Body, maxSize)
			}
			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// AuditLogger provides security audit logging middleware.
|
|
func AuditLogger(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
path := r.URL.Path
|
|
raw := r.URL.RawQuery
|
|
|
|
// Wrap response writer to capture status code
|
|
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
|
|
|
// Process request
|
|
next.ServeHTTP(wrapped, r)
|
|
|
|
// Log after processing
|
|
latency := time.Since(start)
|
|
clientIP := getClientIP(r)
|
|
method := r.Method
|
|
statusCode := wrapped.statusCode
|
|
|
|
if raw != "" {
|
|
path = path + "?" + raw
|
|
}
|
|
|
|
// Log security-relevant events
|
|
if statusCode >= 400 || method == "DELETE" || strings.Contains(path, "/admin") {
|
|
// Log to security audit system
|
|
logSecurityEvent(map[string]interface{}{
|
|
"timestamp": start.Unix(),
|
|
"client_ip": clientIP,
|
|
"method": method,
|
|
"path": path,
|
|
"status": statusCode,
|
|
"latency": latency,
|
|
"user_agent": r.UserAgent(),
|
|
"referer": r.Referer(),
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// getClientIP returns the best-guess client IP for r as a string, checking in
// order:
//
//  1. the first (left-most) entry of X-Forwarded-For, whitespace-trimmed;
//  2. X-Real-IP, whitespace-trimmed;
//  3. the host part of r.RemoteAddr, or RemoteAddr unchanged if it is not
//     in host:port form.
//
// The result is not validated as an IP address.
//
// SECURITY NOTE(review): both headers are supplied by the client unless a
// trusted proxy overwrites them, and the left-most X-Forwarded-For entry is
// always client-controlled. Do not treat this value as authoritative for
// access control without a trusted-proxy configuration.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first, _, _ := strings.Cut(forwarded, ",")
		return strings.TrimSpace(first)
	}

	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		return strings.TrimSpace(realIP)
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
|
|
|
|
// parseClientIP parses a client IP string into a netip.Addr.
//
// Surrounding whitespace is trimmed, and a blank input yields the error
// "empty ip". A host:port form (IPv4 or bracketed IPv6) has its port removed,
// and any remaining surrounding brackets are stripped before parsing. Parse
// failures are returned as reported by netip.ParseAddr.
func parseClientIP(raw string) (netip.Addr, error) {
	candidate := strings.TrimSpace(raw)
	if len(candidate) == 0 {
		return netip.Addr{}, fmt.Errorf("empty ip")
	}

	host, _, splitErr := net.SplitHostPort(candidate)
	if splitErr == nil {
		candidate = host
	}

	// Handle a bare bracketed IPv6 literal such as "[::1]" (no port).
	candidate = strings.TrimSuffix(strings.TrimPrefix(candidate, "["), "]")
	return netip.ParseAddr(candidate)
}
|
|
|
|
// responseWriter wraps an http.ResponseWriter so AuditLogger can see the
// status code that was actually sent to the client.
//
// The creator must initialise statusCode to http.StatusOK. That value is
// reported when the handler writes nothing at all.
type responseWriter struct {
	http.ResponseWriter
	statusCode int
	// wroteHeader is set once a final (non-interim) status has been committed.
	wroteHeader bool
}

// WriteHeader records the first final status code and forwards every call to
// the wrapped writer.
//
// net/http ignores WriteHeader calls once the status is committed, so
// recording later calls would log a status the client never saw. Interim 1xx
// responses (other than 101 Switching Protocols) are forwarded but not
// recorded, because a final status still follows them.
func (rw *responseWriter) WriteHeader(code int) {
	interim := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !interim {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write commits an implicit 200 OK on the first write, as net/http does, so a
// later WriteHeader call cannot change the recorded status.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional interfaces (e.g. flushing) through this wrapper.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
|
|
|
|
// logSecurityEvent writes a one-line audit record for event to the standard
// logger.
//
// Only the "client_ip", "method", "path" and "status" keys are printed. Other
// keys (such as timestamp, latency, user_agent and referer supplied by
// AuditLogger) are currently ignored. A missing key prints as fmt's
// "%!verb(<nil>)" placeholder.
//
// client_ip and path are printed with %q. Both can be controlled by the
// client: the IP may come from the X-Forwarded-For header, and the path is
// percent-decoded, so "%0a" becomes a real newline. Printing them raw would let
// a request forge extra log lines or fields.
//
// TODO: forward events to a dedicated security monitoring system instead of
// the process log.
func logSecurityEvent(event map[string]interface{}) {
	log.Printf(
		"SECURITY AUDIT: %q %s %q %v",
		event["client_ip"],
		event["method"],
		event["path"],
		event["status"],
	)
}
|