// Package middleware provides HTTP middleware including audit logging package middleware import ( "net/http" "strings" "time" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/storage" ) // Middleware provides audit logging for task access type Middleware struct { db *storage.DB logger *logging.Logger } // NewMiddleware creates a new audit logging middleware func NewMiddleware(db *storage.DB, logger *logging.Logger) *Middleware { return &Middleware{ db: db, logger: logger, } } // Logger returns an HTTP middleware that logs task access to the audit log. // It should be applied to routes that access tasks (view, clone, execute, modify). func (m *Middleware) Logger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract task ID from URL path if present taskID := extractTaskID(r.URL.Path) if taskID == "" { // No task ID in path, skip audit logging next.ServeHTTP(w, r) return } // Determine action based on HTTP method and path action := determineAction(r.Method, r.URL.Path) // Get user or token var userID, token *string user := auth.GetUserFromContext(r.Context()) if user != nil { u := user.Name userID = &u } else { // Check for token in query params t := r.URL.Query().Get("token") if t != "" { token = &t } } // Get IP address ipStr := getClientIP(r) ip := &ipStr // Log the access if err := m.db.LogTaskAccess(taskID, userID, token, &action, ip); err != nil { m.logger.Error("failed to log task access", "error", err, "task_id", taskID) // Don't fail the request, just log the error } next.ServeHTTP(w, r) }) } // extractTaskID extracts task ID from URL path patterns like: // /api/tasks/{id} // /api/tasks/{id}/clone // /api/tasks/{id}/execute func extractTaskID(path string) string { // Remove query string if present if idx := strings.Index(path, "?"); idx != -1 { path = path[:idx] } // Check for task patterns if !strings.Contains(path, "/tasks/") { return "" } parts := strings.Split(path, "/") for i, part := range parts { if part == "tasks" && i+1 < len(parts) { taskID := parts[i+1] // Validate it's not a sub-path like "tasks" or "all" if taskID != "" && taskID != "all" && taskID != "list" { return taskID } } } return "" } // determineAction maps HTTP method and path to audit action. func determineAction(method, path string) string { lowerPath := strings.ToLower(path) switch method { case http.MethodGet: if strings.Contains(lowerPath, "/clone") { return "clone" } return "view" case http.MethodPost, http.MethodPut, http.MethodPatch: if strings.Contains(lowerPath, "/execute") || strings.Contains(lowerPath, "/run") { return "execute" } return "modify" case http.MethodDelete: return "delete" default: return "view" } } // RetentionJob runs the nightly audit log retention cleanup. // It deletes audit log entries older than the configured retention period. type RetentionJob struct { db *storage.DB logger *logging.Logger retentionDays int } // NewRetentionJob creates a new retention job. // Default retention is 2 years (730 days) if not specified. func NewRetentionJob(db *storage.DB, logger *logging.Logger, retentionDays int) *RetentionJob { if retentionDays <= 0 { retentionDays = 730 // 2 years default } return &RetentionJob{ db: db, logger: logger, retentionDays: retentionDays, } } // Run executes the retention cleanup once. func (j *RetentionJob) Run() error { j.logger.Info("starting audit log retention cleanup", "retention_days", j.retentionDays) deleted, err := j.db.DeleteOldAuditLogs(j.retentionDays) if err != nil { j.logger.Error("audit log retention cleanup failed", "error", err) return err } j.logger.Info("audit log retention cleanup completed", "deleted_entries", deleted) return nil } // RunPeriodic runs the retention job at the specified interval. // This should be called in a goroutine at application startup. func (j *RetentionJob) RunPeriodic(interval time.Duration, stopCh <-chan struct{}) { ticker := time.NewTicker(interval) defer ticker.Stop() // Run immediately on startup if err := j.Run(); err != nil { j.logger.Error("initial audit log retention cleanup failed", "error", err) } for { select { case <-ticker.C: if err := j.Run(); err != nil { j.logger.Error("periodic audit log retention cleanup failed", "error", err) } case <-stopCh: j.logger.Info("stopping audit log retention job") return } } } // StartNightlyRetentionJob starts a retention job that runs once per day at midnight UTC. func StartNightlyRetentionJob(db *storage.DB, logger *logging.Logger, retentionDays int, stopCh <-chan struct{}) { job := NewRetentionJob(db, logger, retentionDays) // Calculate time until next midnight UTC now := time.Now().UTC() nextMidnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.UTC) durationUntilMidnight := nextMidnight.Sub(now) logger.Info("scheduling nightly audit log retention job", "next_run", nextMidnight.Format(time.RFC3339), "retention_days", retentionDays, ) // Wait until midnight, then start the periodic ticker go func() { time.Sleep(durationUntilMidnight) job.RunPeriodic(24*time.Hour, stopCh) }() }