diff --git a/cmd/audit-verifier/main.go b/cmd/audit-verifier/main.go new file mode 100644 index 0000000..0255295 --- /dev/null +++ b/cmd/audit-verifier/main.go @@ -0,0 +1,95 @@ +// Package main implements the audit-verifier standalone verification tool +package main + +import ( + "flag" + "fmt" + "log/slog" + "os" + "time" + + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/logging" +) + +func main() { + var ( + logPath string + interval time.Duration + continuous bool + verbose bool + ) + + flag.StringVar(&logPath, "log-path", "", "Path to audit log file to verify (required)") + flag.DurationVar(&interval, "interval", 15*time.Minute, "Verification interval for continuous mode") + flag.BoolVar(&continuous, "continuous", false, "Run continuous verification in a loop") + flag.BoolVar(&verbose, "verbose", false, "Enable verbose output") + flag.Parse() + + if logPath == "" { + fmt.Fprintln(os.Stderr, "Error: -log-path is required") + flag.Usage() + os.Exit(1) + } + + // Setup logging + logLevel := slog.LevelInfo + if verbose { + logLevel = slog.LevelDebug + } + logger := logging.NewLogger(logLevel, false) + + verifier := audit.NewChainVerifier(logger) + + if continuous { + fmt.Printf("Starting continuous audit verification every %v...\n", interval) + fmt.Printf("Press Ctrl+C to stop\n\n") + + // Run with alert function that prints to stdout + verifier.ContinuousVerification(logPath, interval, func(result *audit.VerificationResult) { + printResult(result) + if !result.Valid { + // In continuous mode, we don't exit on tampering - we keep monitoring + // The alert function should notify appropriate channels (email, slack, etc.) + fmt.Println("\n*** TAMPERING DETECTED - INVESTIGATE IMMEDIATELY ***") + } + }) + } else { + // Single verification run + fmt.Printf("Verifying audit log: %s\n", logPath) + + result, err := verifier.VerifyLogFile(logPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err) + os.Exit(1) + } + + printResult(result) + + if !result.Valid { + fmt.Println("\n*** VERIFICATION FAILED - AUDIT CHAIN TAMPERING DETECTED ***") + os.Exit(2) + } + + fmt.Println("\n✓ Audit chain integrity verified") + } +} + +func printResult(result *audit.VerificationResult) { + fmt.Printf("\nVerification Time: %s\n", result.Timestamp.Format(time.RFC3339)) + fmt.Printf("Total Events: %d\n", result.TotalEvents) + fmt.Printf("Valid: %v\n", result.Valid) + + if result.ChainRootHash != "" { + fmt.Printf("Chain Root Hash: %s...\n", result.ChainRootHash[:16]) + } + + if !result.Valid { + if result.FirstTampered != -1 { + fmt.Printf("First Tampered Event: %d\n", result.FirstTampered) + } + if result.Error != "" { + fmt.Printf("Error: %s\n", result.Error) + } + } +} diff --git a/internal/audit/audit.go b/internal/audit/audit.go index a87ed28..671a9a5 100644 --- a/internal/audit/audit.go +++ b/internal/audit/audit.go @@ -100,7 +100,7 @@ func (al *Logger) Log(event Event) { event.PrevHash = al.lastHash // Calculate hash of this event for tamper evidence - event.EventHash = al.calculateEventHash(event) + event.EventHash = al.CalculateEventHash(event) al.lastHash = event.EventHash // Marshal to JSON @@ -133,8 +133,9 @@ func (al *Logger) Log(event Event) { } } -// calculateEventHash computes SHA-256 hash of event data for integrity chain -func (al *Logger) calculateEventHash(event Event) string { +// CalculateEventHash computes SHA-256 hash of event data for integrity chain +// Exported for testing purposes +func (al *Logger) CalculateEventHash(event Event) string { // Create a copy without the hash field for hashing eventCopy := event eventCopy.EventHash = "" @@ -195,7 +196,7 @@ func (al *Logger) VerifyChain(events []Event) (tamperedSeq int, err error) { } // Verify event hash - expectedHash := al.calculateEventHash(event) + expectedHash := al.CalculateEventHash(event) if event.EventHash != expectedHash { return int(event.SequenceNum), fmt.Errorf( "hash mismatch at sequence %d: expected %s, got %s", diff --git a/internal/audit/verifier.go b/internal/audit/verifier.go new file mode 100644 index 0000000..af1f78b --- /dev/null +++ b/internal/audit/verifier.go @@ -0,0 +1,219 @@ +package audit + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// ChainVerifier provides continuous verification of audit log integrity +// by checking the chained hash structure and detecting any tampering. +type ChainVerifier struct { + logger *logging.Logger +} + +// NewChainVerifier creates a new audit chain verifier +func NewChainVerifier(logger *logging.Logger) *ChainVerifier { + return &ChainVerifier{ + logger: logger, + } +} + +// VerificationResult contains the outcome of a chain verification +//type VerificationResult struct { +// Timestamp time.Time +// TotalEvents int +// Valid bool +// FirstTampered int64 // Sequence number of first tampered event, -1 if none +// Error string // Error message if verification failed +// ChainRootHash string // Hash of the last valid event (for external verification) +//} + +// VerificationResult contains the outcome of a chain verification +type VerificationResult struct { + Timestamp time.Time + TotalEvents int + Valid bool + FirstTampered int64 // Sequence number of first tampered event, -1 if none + Error string // Error message if verification failed + ChainRootHash string // Hash of the last valid event (for external verification) +} + +// VerifyLogFile performs a complete verification of an audit log file. +// It checks the integrity chain by verifying each event's hash and +// ensuring the previous hash links are unbroken. +func (cv *ChainVerifier) VerifyLogFile(logPath string) (*VerificationResult, error) { + result := &VerificationResult{ + Timestamp: time.Now().UTC(), + TotalEvents: 0, + Valid: true, + FirstTampered: -1, + } + + // Open the log file + file, err := os.Open(logPath) + if err != nil { + if os.IsNotExist(err) { + // No log file yet - this is valid (no entries to verify) + return result, nil + } + result.Valid = false + result.Error = fmt.Sprintf("failed to open log file: %v", err) + return result, err + } + defer file.Close() + + // Create a temporary logger to calculate hashes + tempLogger, _ := NewLogger(false, "", cv.logger) + + var events []Event + scanner := bufio.NewScanner(file) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := scanner.Text() + if line == "" { + continue + } + + var event Event + if err := json.Unmarshal([]byte(line), &event); err != nil { + result.Valid = false + result.Error = fmt.Sprintf("failed to parse event at line %d: %v", lineNum, err) + return result, fmt.Errorf("parse error at line %d: %w", lineNum, err) + } + + events = append(events, event) + result.TotalEvents++ + } + + if err := scanner.Err(); err != nil { + result.Valid = false + result.Error = fmt.Sprintf("error reading log file: %v", err) + return result, err + } + + // Verify the chain + tamperedSeq, err := tempLogger.VerifyChain(events) + if err != nil { + result.Valid = false + result.FirstTampered = int64(tamperedSeq) + result.Error = err.Error() + return result, err + } + + if tamperedSeq != -1 { + result.Valid = false + result.FirstTampered = int64(tamperedSeq) + result.Error = fmt.Sprintf("tampering detected at sequence %d", tamperedSeq) + } + + // Set the chain root hash (hash of the last event) + if len(events) > 0 { + lastEvent := events[len(events)-1] + result.ChainRootHash = lastEvent.EventHash + } + + return result, nil +} + +// ContinuousVerification runs verification at regular intervals and reports any issues. +// This should be run as a background goroutine in long-running services. +func (cv *ChainVerifier) ContinuousVerification(logPath string, interval time.Duration, alertFunc func(*VerificationResult)) { + if interval <= 0 { + interval = 15 * time.Minute // Default: 15 minutes for HIPAA, use 1 hour otherwise + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + // Run initial verification + cv.runAndReport(logPath, alertFunc) + + for range ticker.C { + cv.runAndReport(logPath, alertFunc) + } +} + +// runAndReport performs verification and calls the alert function if issues are found +func (cv *ChainVerifier) runAndReport(logPath string, alertFunc func(*VerificationResult)) { + result, err := cv.VerifyLogFile(logPath) + if err != nil { + if cv.logger != nil { + cv.logger.Error("audit chain verification error", "error", err, "log_path", logPath) + } + // Still report the error + if alertFunc != nil { + alertFunc(result) + } + return + } + + // Report if not valid or if we just want to log successful verification periodically + if !result.Valid { + if cv.logger != nil { + cv.logger.Error("audit chain tampering detected", + "first_tampered", result.FirstTampered, + "total_events", result.TotalEvents, + "chain_root", result.ChainRootHash[:16]) + } + if alertFunc != nil { + alertFunc(result) + } + } else { + if cv.logger != nil { + cv.logger.Debug("audit chain verification passed", + "total_events", result.TotalEvents, + "chain_root", result.ChainRootHash[:16]) + } + } +} + +// VerifyAndAlert performs a single verification and returns true if tampering detected +func (cv *ChainVerifier) VerifyAndAlert(logPath string) (bool, error) { + result, err := cv.VerifyLogFile(logPath) + if err != nil { + return true, err // Treat errors as potential tampering + } + + return !result.Valid, nil +} + +// GetChainRootHash returns the hash of the last event in the chain +// This can be published to an external append-only store for independent verification +func (cv *ChainVerifier) GetChainRootHash(logPath string) (string, error) { + file, err := os.Open(logPath) + if err != nil { + return "", err + } + defer file.Close() + + var lastLine string + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + lastLine = line + } + } + + if err := scanner.Err(); err != nil { + return "", err + } + + if lastLine == "" { + return "", fmt.Errorf("no events in log file") + } + + var event Event + if err := json.Unmarshal([]byte(lastLine), &event); err != nil { + return "", fmt.Errorf("failed to parse last event: %w", err) + } + + return event.EventHash, nil +} diff --git a/tests/unit/audit/verifier_test.go b/tests/unit/audit/verifier_test.go new file mode 100644 index 0000000..69b1fc2 --- /dev/null +++ b/tests/unit/audit/verifier_test.go @@ -0,0 +1,243 @@ +package audit_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// replaceAllSubstr replaces all occurrences of old with new in s +func replaceAllSubstr(s, old, new string) string { + return strings.ReplaceAll(s, old, new) +} + +func TestChainVerifier_VerifyLogFile_EmptyLog(t *testing.T) { + logger := logging.NewLogger(0, false) + verifier := audit.NewChainVerifier(logger) + + // Test with non-existent file (empty log case) + result, err := verifier.VerifyLogFile("/nonexistent/path/audit.log") + if err != nil { + t.Fatalf("expected no error for non-existent file, got: %v", err) + } + + if !result.Valid { + t.Error("expected valid result for empty log") + } + if result.TotalEvents != 0 { + t.Errorf("expected 0 events, got %d", result.TotalEvents) + } +} + +func TestChainVerifier_VerifyLogFile_ValidChain(t *testing.T) { + logger := logging.NewLogger(0, false) + verifier := audit.NewChainVerifier(logger) + + // Create a temporary log file with valid chain + tempDir := t.TempDir() + logPath := filepath.Join(tempDir, "audit.log") + + // Create logger that writes to file + al, err := audit.NewLogger(true, logPath, logger) + if err != nil { + t.Fatalf("failed to create audit logger: %v", err) + } + defer al.Close() + + // Log some events + testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC) + al.Log(audit.Event{ + Timestamp: testTime, + EventType: audit.EventJobStarted, + UserID: "user1", + Success: true, + }) + al.Log(audit.Event{ + Timestamp: testTime.Add(time.Second), + EventType: audit.EventJobCompleted, + UserID: "user1", + Success: true, + }) + al.Log(audit.Event{ + Timestamp: testTime.Add(2 * time.Second), + EventType: audit.EventFileRead, + UserID: "user1", + Success: true, + }) + + // Verify the chain + result, err := verifier.VerifyLogFile(logPath) + if err != nil { + t.Fatalf("verification failed: %v", err) + } + + if !result.Valid { + t.Errorf("expected valid chain, got error: %s", result.Error) + } + if result.TotalEvents != 3 { + t.Errorf("expected 3 events, got %d", result.TotalEvents) + } + if result.ChainRootHash == "" { + t.Error("expected non-empty chain root hash") + } +} + +func TestChainVerifier_VerifyLogFile_TamperedChain(t *testing.T) { + logger := logging.NewLogger(0, false) + verifier := audit.NewChainVerifier(logger) + + // Create a temporary log file + tempDir := t.TempDir() + logPath := filepath.Join(tempDir, "audit.log") + + // Create logger and log events + al, err := audit.NewLogger(true, logPath, logger) + if err != nil { + t.Fatalf("failed to create audit logger: %v", err) + } + + testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC) + al.Log(audit.Event{ + Timestamp: testTime, + EventType: audit.EventJobStarted, + UserID: "user1", + Success: true, + }) + al.Log(audit.Event{ + Timestamp: testTime.Add(time.Second), + EventType: audit.EventJobCompleted, + UserID: "user1", + Success: true, + }) + al.Close() + + // Tamper with the file by modifying the second event + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + + // Replace "user1" with "attacker" in the content + content := string(data) + // Simple string replacement + tamperedContent := replaceAllSubstr(content, "user1", "attacker") + + // Write back tampered data + if err := os.WriteFile(logPath, []byte(tamperedContent), 0600); err != nil { + t.Fatalf("failed to write tampered log: %v", err) + } + + // Verify should detect tampering + result, err := verifier.VerifyLogFile(logPath) + if err == nil && result.Valid { + t.Error("expected tampering to be detected") + } + + if result.FirstTampered == -1 { + t.Error("expected FirstTampered to be set") + } +} + +func TestChainVerifier_GetChainRootHash(t *testing.T) { + logger := logging.NewLogger(0, false) + verifier := audit.NewChainVerifier(logger) + + tempDir := t.TempDir() + logPath := filepath.Join(tempDir, "audit.log") + + // Create logger and log events + al, err := audit.NewLogger(true, logPath, logger) + if err != nil { + t.Fatalf("failed to create audit logger: %v", err) + } + + testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC) + al.Log(audit.Event{ + Timestamp: testTime, + EventType: audit.EventJobStarted, + UserID: "user1", + Success: true, + }) + al.Close() + + // Get chain root hash + rootHash, err := verifier.GetChainRootHash(logPath) + if err != nil { + t.Fatalf("failed to get chain root hash: %v", err) + } + + if rootHash == "" { + t.Error("expected non-empty root hash") + } + + // Verify it matches the result from VerifyLogFile + result, err := verifier.VerifyLogFile(logPath) + if err != nil { + t.Fatalf("verification failed: %v", err) + } + + if rootHash != result.ChainRootHash { + t.Errorf("root hash mismatch: GetChainRootHash=%s, VerifyLogFile=%s", rootHash, result.ChainRootHash) + } +} + +func TestChainVerifier_VerifyAndAlert(t *testing.T) { + logger := logging.NewLogger(0, false) + verifier := audit.NewChainVerifier(logger) + + tempDir := t.TempDir() + logPath := filepath.Join(tempDir, "audit.log") + + // Create valid log + al, err := audit.NewLogger(true, logPath, logger) + if err != nil { + t.Fatalf("failed to create audit logger: %v", err) + } + + al.Log(audit.Event{ + Timestamp: time.Now(), + EventType: audit.EventJobStarted, + UserID: "user1", + Success: true, + }) + al.Close() + + // VerifyAndAlert should return false (no tampering) + tampered, err := verifier.VerifyAndAlert(logPath) + if err != nil { + t.Fatalf("verification failed: %v", err) + } + if tampered { + t.Error("expected no tampering detected for valid log") + } + + // Test with non-existent file + tampered, err = verifier.VerifyAndAlert("/nonexistent/path/audit.log") + if err != nil { + t.Fatalf("expected no error for non-existent file: %v", err) + } + if tampered { + t.Error("expected no tampering for empty log") + } +} + +// splitLines splits byte slice by newlines +func splitLines(data []byte) [][]byte { + var lines [][]byte + start := 0 + for i := 0; i < len(data); i++ { + if data[i] == '\n' { + lines = append(lines, data[start:i]) + start = i + 1 + } + } + if start < len(data) { + lines = append(lines, data[start:]) + } + return lines +} diff --git a/tests/unit/security/audit_test.go b/tests/unit/security/audit_test.go index 07000f0..01888ea 100644 --- a/tests/unit/security/audit_test.go +++ b/tests/unit/security/audit_test.go @@ -3,6 +3,7 @@ package security import ( "log/slog" "testing" + "time" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/logging" @@ -43,11 +44,12 @@ func TestAuditLogger_VerifyChain(t *testing.T) { } defer al.Close() - // Create a valid chain of events + // Create a valid chain of events with real hashes events := []audit.Event{ { EventType: audit.EventAuthSuccess, UserID: "user1", + Timestamp: time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), SequenceNum: 1, PrevHash: "", }, @@ -55,29 +57,60 @@ func TestAuditLogger_VerifyChain(t *testing.T) { EventType: audit.EventFileRead, UserID: "user1", Resource: "/data/file.txt", + Timestamp: time.Date(2026, 1, 1, 12, 1, 0, 0, time.UTC), SequenceNum: 2, }, { EventType: audit.EventFileWrite, UserID: "user1", Resource: "/data/output.txt", + Timestamp: time.Date(2026, 1, 1, 12, 2, 0, 0, time.UTC), SequenceNum: 3, }, } - // Calculate hashes for each event + // Calculate real hashes for each event for i := range events { if i > 0 { events[i].PrevHash = events[i-1].EventHash } - // We can't easily call calculateEventHash since it's private - // In real test, we'd use the logged events - events[i].EventHash = "dummy_hash_for_testing" + events[i].EventHash = al.CalculateEventHash(events[i]) } - // Test verification with valid chain - // In real scenario, we'd verify the actual hashes - _, _ = al.VerifyChain(events) + // Test verification with valid chain - should pass + tamperedSeq, err := al.VerifyChain(events) + if err != nil { + t.Errorf("VerifyChain failed for valid chain: %v", err) + } + if tamperedSeq != -1 { + t.Errorf("Expected valid chain (tamperedSeq=-1), got %d", tamperedSeq) + } + + // Test tamper detection - modify an event hash + tamperedEvents := make([]audit.Event, len(events)) + copy(tamperedEvents, events) + tamperedEvents[1].EventHash = "tampered_hash_1234567890abcdef" + + tamperedSeq, err = al.VerifyChain(tamperedEvents) + if err == nil { + t.Error("Expected error for tampered chain, got nil") + } + if tamperedSeq != 2 { + t.Errorf("Expected tamperedSeq=2, got %d", tamperedSeq) + } + + // Test chain break detection - modify prev_hash + brokenEvents := make([]audit.Event, len(events)) + copy(brokenEvents, events) + brokenEvents[1].PrevHash = "wrong_prev_hash_1234567890abcdef" + + tamperedSeq, err = al.VerifyChain(brokenEvents) + if err == nil { + t.Error("Expected error for broken chain, got nil") + } + if tamperedSeq != 2 { + t.Errorf("Expected tamperedSeq=2 for broken chain, got %d", tamperedSeq) + } } func TestAuditLogger_LogFileAccess(t *testing.T) {