diff --git a/internal/api/audit/handlers.go b/internal/api/audit/handlers.go index 4ccb1a1..8ceebe2 100644 --- a/internal/api/audit/handlers.go +++ b/internal/api/audit/handlers.go @@ -116,7 +116,9 @@ func (h *Handler) GetV1AuditEvents(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode audit events", "error", err) + } return } @@ -129,7 +131,9 @@ func (h *Handler) GetV1AuditEvents(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode empty audit events", "error", err) + } } // PostV1AuditVerify handles POST /v1/audit/verify @@ -150,7 +154,9 @@ func (h *Handler) PostV1AuditVerify(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(result) + if err := json.NewEncoder(w).Encode(result); err != nil { + h.logger.Warn("failed to encode verification result", "error", err) + } } // GetV1AuditChainRoot handles GET /v1/audit/chain-root @@ -169,7 +175,9 @@ func (h *Handler) GetV1AuditChainRoot(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode chain root", "error", err) + } } // checkPermission checks if the user has the required permission diff --git a/internal/audit/audit.go b/internal/audit/audit.go index 705df29..184ff23 100644 --- a/internal/audit/audit.go +++ b/internal/audit/audit.go @@ -111,6 +111,7 @@ func NewLoggerWithBase(enabled bool, filePath string, logger *logging.Logger, ba return nil, fmt.Errorf("failed to create audit directory: %w", err) } + // #nosec G304 -- fullPath is validated through validateAndSecurePath and checkFileNotSymlink file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) if err != nil { return nil, fmt.Errorf("failed to open audit log file: %w", err) @@ -121,7 +122,9 @@ func NewLoggerWithBase(enabled bool, filePath string, logger *logging.Logger, ba // Restore chain state from existing log to prevent integrity break on restart if err := al.resumeFromFile(); err != nil { - file.Close() + if closeErr := file.Close(); closeErr != nil { + return nil, fmt.Errorf("failed to resume audit chain and close file: %w (close error: %v)", err, closeErr) + } return nil, fmt.Errorf("failed to resume audit chain: %w", err) } diff --git a/internal/audit/chain.go b/internal/audit/chain.go index ee1ff8e..83dd309 100644 --- a/internal/audit/chain.go +++ b/internal/audit/chain.go @@ -100,6 +100,7 @@ func (hc *HashChain) AddEvent(event Event) (*ChainEntry, error) { // VerifyChain verifies the integrity of a chain from a file func VerifyChain(filePath string) error { + // #nosec G304 -- filePath is internally controlled, not from user input file, err := os.Open(filePath) if err != nil { return fmt.Errorf("failed to open chain file: %w", err) @@ -188,6 +189,7 @@ func openOrCreateChainFile(filePath string) (*os.File, uint64, string, error) { // Check if file exists if _, err := os.Stat(filePath); os.IsNotExist(err) { // Create new file + // #nosec G304 -- filePath is internally controlled, not from user input file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { return nil, 0, "", fmt.Errorf("failed to create chain file: %w", err) @@ -196,6 +198,7 @@ func openOrCreateChainFile(filePath string) (*os.File, uint64, string, error) { } // File exists - verify integrity and get last state + // #nosec G304 -- filePath is internally controlled, not from user input file, err := os.OpenFile(filePath, os.O_RDONLY, 0600) if err != nil { return nil, 0, "", fmt.Errorf("failed to open existing chain file: %w", err) @@ -208,16 +211,19 @@ func openOrCreateChainFile(filePath string) (*os.File, uint64, string, error) { for decoder.More() { var entry ChainEntry if err := decoder.Decode(&entry); err != nil { - file.Close() + _ = file.Close() return nil, 0, "", fmt.Errorf("corrupted chain file: %w", err) } lastSeq = entry.SeqNum lastHash = entry.ThisHash } - file.Close() + if err := file.Close(); err != nil { + return nil, 0, "", fmt.Errorf("failed to close chain file after read: %w", err) + } // Reopen for appending + // #nosec G304 -- filePath is internally controlled, not from user input file, err = os.OpenFile(filePath, os.O_WRONLY|os.O_APPEND, 0600) if err != nil { return nil, 0, "", fmt.Errorf("failed to reopen chain file for append: %w", err) diff --git a/internal/audit/checkpoint.go b/internal/audit/checkpoint.go index c0bbb9c..cc33a1c 100644 --- a/internal/audit/checkpoint.go +++ b/internal/audit/checkpoint.go @@ -62,6 +62,9 @@ func (dcm *DBCheckpointManager) VerifyAgainstDB(filePath string) error { return err } + if localSeq < 0 { + return fmt.Errorf("sequence number cannot be negative: %d", localSeq) + } if uint64(localSeq) != dbSeq || localHash != dbHash { return fmt.Errorf( "TAMPERING DETECTED: local(seq=%d hash=%s) vs db(seq=%d hash=%s)", @@ -102,6 +105,9 @@ func (dcm *DBCheckpointManager) VerifyAllFiles() ([]VerificationResult, error) { if err != nil { result.Valid = false result.Error = fmt.Sprintf("read local file: %v", err) + } else if localSeq < 0 { + result.Valid = false + result.Error = fmt.Sprintf("sequence number cannot be negative: %d", localSeq) } else if uint64(localSeq) != dbSeq || localHash != dbHash { result.Valid = false result.FirstTampered = localSeq diff --git a/internal/audit/rotation.go b/internal/audit/rotation.go index bd8d439..5af10dc 100644 --- a/internal/audit/rotation.go +++ b/internal/audit/rotation.go @@ -60,6 +60,7 @@ func NewRotatingLogger(enabled bool, basePath, anchorDir string, logger *logging } // Open the log file for current date + // #nosec G304 -- fullPath is internally constructed from basePath and currentDate file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) if err != nil { return nil, fmt.Errorf("open audit log file: %w", err) @@ -76,7 +77,9 @@ func NewRotatingLogger(enabled bool, basePath, anchorDir string, logger *logging // Resume from file if it exists if err := al.resumeFromFile(); err != nil { - file.Close() + if closeErr := file.Close(); closeErr != nil { + return nil, fmt.Errorf("resume audit chain: %w, close: %v", err, closeErr) + } return nil, fmt.Errorf("resume audit chain: %w", err) } @@ -131,6 +134,9 @@ func (rl *RotatingLogger) Rotate() error { } // Create anchor file with last hash + if rl.sequenceNum < 0 { + return fmt.Errorf("sequence number cannot be negative: %d", rl.sequenceNum) + } anchor := AnchorFile{ Date: oldDate, LastHash: rl.lastHash, @@ -146,6 +152,7 @@ func (rl *RotatingLogger) Rotate() error { rl.currentDate = time.Now().UTC().Format("2006-01-02") newPath := filepath.Join(rl.basePath, fmt.Sprintf("audit-%s.log", rl.currentDate)) + // #nosec G304 -- newPath is internally constructed from basePath and currentDate f, err := os.OpenFile(newPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) if err != nil { return err @@ -202,6 +209,7 @@ func writeAnchorFile(path string, anchor AnchorFile) error { // readAnchorFile reads an anchor file from disk func readAnchorFile(path string) (*AnchorFile, error) { + // #nosec G304 -- path is an internally controlled anchor file path data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read anchor file: %w", err) @@ -216,6 +224,7 @@ func readAnchorFile(path string) (*AnchorFile, error) { // sha256File computes the SHA256 hash of a file func sha256File(path string) (string, error) { + // #nosec G304 -- path is an internally controlled audit log file data, err := os.ReadFile(path) if err != nil { return "", fmt.Errorf("read file: %w", err) @@ -246,6 +255,9 @@ func VerifyRotationIntegrity(logPath, anchorPath string) error { if err != nil { return err } + if lastSeq < 0 { + return fmt.Errorf("sequence number cannot be negative: %d", lastSeq) + } if uint64(lastSeq) != anchor.LastSeq || lastHash != anchor.LastHash { return fmt.Errorf("TAMPERING DETECTED: chain mismatch: expected(seq=%d,hash=%s), got(seq=%d,hash=%s)", anchor.LastSeq, anchor.LastHash, lastSeq, lastHash) @@ -256,6 +268,7 @@ func VerifyRotationIntegrity(logPath, anchorPath string) error { // getLastEventFromFile returns the last event's sequence and hash from a file func getLastEventFromFile(path string) (int64, string, error) { + // #nosec G304 -- path is an internally controlled audit log file file, err := os.Open(path) if err != nil { return 0, "", err diff --git a/internal/audit/sealed.go b/internal/audit/sealed.go index d533a7b..0b42fd7 100644 --- a/internal/audit/sealed.go +++ b/internal/audit/sealed.go @@ -62,13 +62,17 @@ func (ssm *SealedStateManager) Checkpoint(seq uint64, hash string) error { } if _, err := f.Write(append(data, '\n')); err != nil { - f.Close() + if errClose := f.Close(); errClose != nil { + return fmt.Errorf("write chain entry: %w, close: %v", err, errClose) + } return fmt.Errorf("write chain entry: %w", err) } // CRITICAL: fsync chain before returning — crash safety if err := f.Sync(); err != nil { - f.Close() + if errClose := f.Close(); errClose != nil { + return fmt.Errorf("sync sealed chain: %w, close: %v", err, errClose) + } return fmt.Errorf("sync sealed chain: %w", err) } diff --git a/internal/audit/verifier.go b/internal/audit/verifier.go index 7479baa..26edb29 100644 --- a/internal/audit/verifier.go +++ b/internal/audit/verifier.go @@ -55,6 +55,7 @@ func (cv *ChainVerifier) VerifyLogFile(logPath string) (*VerificationResult, err } // Open the log file + // #nosec G304 -- logPath is an audit log path, not arbitrary user input file, err := os.Open(logPath) if err != nil { if os.IsNotExist(err) { @@ -187,6 +188,7 @@ func (cv *ChainVerifier) VerifyAndAlert(logPath string) (bool, error) { // GetChainRootHash returns the hash of the last event in the chain // This can be published to an external append-only store for independent verification func (cv *ChainVerifier) GetChainRootHash(logPath string) (string, error) { + // #nosec G304 -- logPath is an audit log path, not arbitrary user input file, err := os.Open(logPath) if err != nil { return "", err diff --git a/internal/middleware/audit.go b/internal/middleware/audit.go new file mode 100644 index 0000000..146b67c --- /dev/null +++ b/internal/middleware/audit.go @@ -0,0 +1,200 @@ +// Package middleware provides HTTP middleware including audit logging +package middleware + +import ( + "net/http" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/storage" +) + +// Middleware provides audit logging for task access +type Middleware struct { + db *storage.DB + logger *logging.Logger +} + +// NewMiddleware creates a new audit logging middleware +func NewMiddleware(db *storage.DB, logger *logging.Logger) *Middleware { + return &Middleware{ + db: db, + logger: logger, + } +} + +// Logger returns an HTTP middleware that logs task access to the audit log. +// It should be applied to routes that access tasks (view, clone, execute, modify). +func (m *Middleware) Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Extract task ID from URL path if present + taskID := extractTaskID(r.URL.Path) + if taskID == "" { + // No task ID in path, skip audit logging + next.ServeHTTP(w, r) + return + } + + // Determine action based on HTTP method and path + action := determineAction(r.Method, r.URL.Path) + + // Get user or token + var userID, token *string + user := auth.GetUserFromContext(r.Context()) + if user != nil { + u := user.Name + userID = &u + } else { + // Check for token in query params + t := r.URL.Query().Get("token") + if t != "" { + token = &t + } + } + + // Get IP address + ipStr := getClientIP(r) + ip := &ipStr + + // Log the access + if err := m.db.LogTaskAccess(taskID, userID, token, &action, ip); err != nil { + m.logger.Error("failed to log task access", "error", err, "task_id", taskID) + // Don't fail the request, just log the error + } + + next.ServeHTTP(w, r) + }) +} + +// extractTaskID extracts task ID from URL path patterns like: +// /api/tasks/{id} +// /api/tasks/{id}/clone +// /api/tasks/{id}/execute +func extractTaskID(path string) string { + // Remove query string if present + if idx := strings.Index(path, "?"); idx != -1 { + path = path[:idx] + } + + // Check for task patterns + if !strings.Contains(path, "/tasks/") { + return "" + } + + parts := strings.Split(path, "/") + for i, part := range parts { + if part == "tasks" && i+1 < len(parts) { + taskID := parts[i+1] + // Validate it's not a sub-path like "tasks" or "all" + if taskID != "" && taskID != "all" && taskID != "list" { + return taskID + } + } + } + + return "" +} + +// determineAction maps HTTP method and path to audit action. +func determineAction(method, path string) string { + lowerPath := strings.ToLower(path) + + switch method { + case http.MethodGet: + if strings.Contains(lowerPath, "/clone") { + return "clone" + } + return "view" + case http.MethodPost, http.MethodPut, http.MethodPatch: + if strings.Contains(lowerPath, "/execute") || strings.Contains(lowerPath, "/run") { + return "execute" + } + return "modify" + case http.MethodDelete: + return "delete" + default: + return "view" + } +} + +// RetentionJob runs the nightly audit log retention cleanup. +// It deletes audit log entries older than the configured retention period. +type RetentionJob struct { + db *storage.DB + logger *logging.Logger + retentionDays int +} + +// NewRetentionJob creates a new retention job. +// Default retention is 2 years (730 days) if not specified. +func NewRetentionJob(db *storage.DB, logger *logging.Logger, retentionDays int) *RetentionJob { + if retentionDays <= 0 { + retentionDays = 730 // 2 years default + } + return &RetentionJob{ + db: db, + logger: logger, + retentionDays: retentionDays, + } +} + +// Run executes the retention cleanup once. +func (j *RetentionJob) Run() error { + j.logger.Info("starting audit log retention cleanup", "retention_days", j.retentionDays) + + deleted, err := j.db.DeleteOldAuditLogs(j.retentionDays) + if err != nil { + j.logger.Error("audit log retention cleanup failed", "error", err) + return err + } + + j.logger.Info("audit log retention cleanup completed", "deleted_entries", deleted) + return nil +} + +// RunPeriodic runs the retention job at the specified interval. +// This should be called in a goroutine at application startup. +func (j *RetentionJob) RunPeriodic(interval time.Duration, stopCh <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + // Run immediately on startup + if err := j.Run(); err != nil { + j.logger.Error("initial audit log retention cleanup failed", "error", err) + } + + for { + select { + case <-ticker.C: + if err := j.Run(); err != nil { + j.logger.Error("periodic audit log retention cleanup failed", "error", err) + } + case <-stopCh: + j.logger.Info("stopping audit log retention job") + return + } + } +} + +// StartNightlyRetentionJob starts a retention job that runs once per day at midnight UTC. +func StartNightlyRetentionJob(db *storage.DB, logger *logging.Logger, retentionDays int, stopCh <-chan struct{}) { + job := NewRetentionJob(db, logger, retentionDays) + + // Calculate time until next midnight UTC + now := time.Now().UTC() + nextMidnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.UTC) + durationUntilMidnight := nextMidnight.Sub(now) + + logger.Info("scheduling nightly audit log retention job", + "next_run", nextMidnight.Format(time.RFC3339), + "retention_days", retentionDays, + ) + + // Wait until midnight, then start the periodic ticker + go func() { + time.Sleep(durationUntilMidnight) + job.RunPeriodic(24*time.Hour, stopCh) + }() +}