package audit_test

import (
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/logging"
)

// replaceAllSubstr replaces all non-overlapping occurrences of old with new in s.
// It is a thin wrapper over strings.ReplaceAll.
func replaceAllSubstr(s, old, new string) string {
	return strings.ReplaceAll(s, old, new)
}

// TestChainVerifier_VerifyLogFile_EmptyLog checks that verifying a path that
// does not exist succeeds with a valid, zero-event result instead of an error.
func TestChainVerifier_VerifyLogFile_EmptyLog(t *testing.T) {
	logger := logging.NewLogger(0, false)
	verifier := audit.NewChainVerifier(logger)

	// A non-existent file stands in for the empty-log case.
	result, err := verifier.VerifyLogFile("/nonexistent/path/audit.log")
	if err != nil {
		t.Fatalf("expected no error for non-existent file, got: %v", err)
	}
	if !result.Valid {
		t.Error("expected valid result for empty log")
	}
	if result.TotalEvents != 0 {
		t.Errorf("expected 0 events, got %d", result.TotalEvents)
	}
}

// TestChainVerifier_VerifyLogFile_ValidChain writes three events through the
// audit logger and checks that the file verifies as an intact chain with a
// non-empty root hash.
func TestChainVerifier_VerifyLogFile_ValidChain(t *testing.T) {
	logger := logging.NewLogger(0, false)
	verifier := audit.NewChainVerifier(logger)

	tempDir := t.TempDir()
	logPath := filepath.Join(tempDir, "audit.log")

	al, err := audit.NewLogger(true, logPath, logger)
	if err != nil {
		t.Fatalf("failed to create audit logger: %v", err)
	}

	testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC)
	al.Log(audit.Event{
		Timestamp: testTime,
		EventType: audit.EventJobStarted,
		UserID:    "user1",
		Success:   true,
	})
	al.Log(audit.Event{
		Timestamp: testTime.Add(time.Second),
		EventType: audit.EventJobCompleted,
		UserID:    "user1",
		Success:   true,
	})
	al.Log(audit.Event{
		Timestamp: testTime.Add(2 * time.Second),
		EventType: audit.EventFileRead,
		UserID:    "user1",
		Success:   true,
	})
	// Close before verifying, as the other tests here do. If the logger
	// buffers writes, verifying while it is still open could read a partial file.
	al.Close()

	result, err := verifier.VerifyLogFile(logPath)
	if err != nil {
		t.Fatalf("verification failed: %v", err)
	}
	if !result.Valid {
		t.Errorf("expected valid chain, got error: %s", result.Error)
	}
	if result.TotalEvents != 3 {
		t.Errorf("expected 3 events, got %d", result.TotalEvents)
	}
	if result.ChainRootHash == "" {
		t.Error("expected non-empty chain root hash")
	}
}

// TestChainVerifier_VerifyLogFile_TamperedChain rewrites the user ID in a
// written log and checks that verification no longer reports the chain as
// valid.
func TestChainVerifier_VerifyLogFile_TamperedChain(t *testing.T) {
	logger := logging.NewLogger(0, false)
	verifier := audit.NewChainVerifier(logger)

	tempDir := t.TempDir()
	logPath := filepath.Join(tempDir, "audit.log")

	al, err := audit.NewLogger(true, logPath, logger)
	if err != nil {
		t.Fatalf("failed to create audit logger: %v", err)
	}

	testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC)
	al.Log(audit.Event{
		Timestamp: testTime,
		EventType: audit.EventJobStarted,
		UserID:    "user1",
		Success:   true,
	})
	al.Log(audit.Event{
		Timestamp: testTime.Add(time.Second),
		EventType: audit.EventJobCompleted,
		UserID:    "user1",
		Success:   true,
	})
	al.Close()

	data, err := os.ReadFile(logPath)
	if err != nil {
		t.Fatalf("failed to read log file: %v", err)
	}

	// Rewrite every occurrence of "user1" (i.e. in all logged events) without
	// recomputing any hashes.
	content := string(data)
	tamperedContent := replaceAllSubstr(content, "user1", "attacker")
	if tamperedContent == content {
		// Guard against a stale fixture: if nothing changed, the test would
		// be verifying an untouched log and its premise would be false.
		t.Fatal("tampering did not modify the log file; expected \"user1\" in its contents")
	}

	if err := os.WriteFile(logPath, []byte(tamperedContent), 0600); err != nil {
		t.Fatalf("failed to write tampered log: %v", err)
	}

	result, err := verifier.VerifyLogFile(logPath)
	if err != nil {
		// An error is accepted as tamper detection. The result may be nil
		// here, so it must not be dereferenced.
		t.Logf("verification returned error (treated as detection): %v", err)
		return
	}
	if result.Valid {
		t.Fatal("expected tampering to be detected")
	}
	if result.FirstTampered == -1 {
		t.Error("expected FirstTampered to be set")
	}
}

// TestChainVerifier_GetChainRootHash checks that GetChainRootHash returns a
// non-empty hash that matches the ChainRootHash reported by VerifyLogFile.
func TestChainVerifier_GetChainRootHash(t *testing.T) {
	logger := logging.NewLogger(0, false)
	verifier := audit.NewChainVerifier(logger)

	tempDir := t.TempDir()
	logPath := filepath.Join(tempDir, "audit.log")

	al, err := audit.NewLogger(true, logPath, logger)
	if err != nil {
		t.Fatalf("failed to create audit logger: %v", err)
	}

	testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC)
	al.Log(audit.Event{
		Timestamp: testTime,
		EventType: audit.EventJobStarted,
		UserID:    "user1",
		Success:   true,
	})
	al.Close()

	rootHash, err := verifier.GetChainRootHash(logPath)
	if err != nil {
		t.Fatalf("failed to get chain root hash: %v", err)
	}
	if rootHash == "" {
		t.Error("expected non-empty root hash")
	}

	// Both entry points must agree on the root hash.
	result, err := verifier.VerifyLogFile(logPath)
	if err != nil {
		t.Fatalf("verification failed: %v", err)
	}
	if rootHash != result.ChainRootHash {
		t.Errorf("root hash mismatch: GetChainRootHash=%s, VerifyLogFile=%s", rootHash, result.ChainRootHash)
	}
}

// TestChainVerifier_VerifyAndAlert checks that VerifyAndAlert reports no
// tampering for an untouched log and for a missing log file.
func TestChainVerifier_VerifyAndAlert(t *testing.T) {
	logger := logging.NewLogger(0, false)
	verifier := audit.NewChainVerifier(logger)

	tempDir := t.TempDir()
	logPath := filepath.Join(tempDir, "audit.log")

	al, err := audit.NewLogger(true, logPath, logger)
	if err != nil {
		t.Fatalf("failed to create audit logger: %v", err)
	}
	// Fixed timestamp, consistent with the other tests, keeps the log deterministic.
	al.Log(audit.Event{
		Timestamp: time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC),
		EventType: audit.EventJobStarted,
		UserID:    "user1",
		Success:   true,
	})
	al.Close()

	tampered, err := verifier.VerifyAndAlert(logPath)
	if err != nil {
		t.Fatalf("verification failed: %v", err)
	}
	if tampered {
		t.Error("expected no tampering detected for valid log")
	}

	// A missing file is treated as an empty log, not as tampering.
	tampered, err = verifier.VerifyAndAlert("/nonexistent/path/audit.log")
	if err != nil {
		t.Fatalf("expected no error for non-existent file: %v", err)
	}
	if tampered {
		t.Error("expected no tampering for empty log")
	}
}

// splitLines splits data on '\n'. Unlike bytes.Split, a trailing newline does
// not produce a final empty element. The returned slices alias data.
// NOTE(review): not referenced in this file; remove it if no other test in
// package audit_test uses it.
func splitLines(data []byte) [][]byte {
	var lines [][]byte
	start := 0
	for i := 0; i < len(data); i++ {
		if data[i] == '\n' {
			lines = append(lines, data[start:i])
			start = i + 1
		}
	}
	if start < len(data) {
		lines = append(lines, data[start:])
	}
	return lines
}