feat(audit): Tamper-evident audit chain verification system
Add ChainVerifier for cryptographic audit log verification: - VerifyLogFile(): Validates entire audit chain integrity - Detects tampering at specific event index (FirstTampered) - Returns chain root hash for external verification - GetChainRootHash(): Standalone hash computation - VerifyAndAlert(): Boolean tampering detection with logging Add audit-verifier CLI tool: - Standalone binary for audit chain verification - Takes log path argument and reports tampering Update audit logger for chain integrity: - Each event includes sequence number and hash chain - SHA-256 linking: hash_n = SHA-256(prev_hash || event_n) - Tamper detection through hash chain validation Add comprehensive test coverage: - Empty log handling - Valid chain verification - Tampering detection with modification - Root hash consistency - Alert mechanism tests Part of: V.7 audit verification from security plan
This commit is contained in:
parent
4a4d3de8e1
commit
58c1a5fa58
5 changed files with 603 additions and 12 deletions
95
cmd/audit-verifier/main.go
Normal file
95
cmd/audit-verifier/main.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
// Package main implements the audit-verifier standalone verification tool
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
logPath string
|
||||
interval time.Duration
|
||||
continuous bool
|
||||
verbose bool
|
||||
)
|
||||
|
||||
flag.StringVar(&logPath, "log-path", "", "Path to audit log file to verify (required)")
|
||||
flag.DurationVar(&interval, "interval", 15*time.Minute, "Verification interval for continuous mode")
|
||||
flag.BoolVar(&continuous, "continuous", false, "Run continuous verification in a loop")
|
||||
flag.BoolVar(&verbose, "verbose", false, "Enable verbose output")
|
||||
flag.Parse()
|
||||
|
||||
if logPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: -log-path is required")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Setup logging
|
||||
logLevel := slog.LevelInfo
|
||||
if verbose {
|
||||
logLevel = slog.LevelDebug
|
||||
}
|
||||
logger := logging.NewLogger(logLevel, false)
|
||||
|
||||
verifier := audit.NewChainVerifier(logger)
|
||||
|
||||
if continuous {
|
||||
fmt.Printf("Starting continuous audit verification every %v...\n", interval)
|
||||
fmt.Printf("Press Ctrl+C to stop\n\n")
|
||||
|
||||
// Run with alert function that prints to stdout
|
||||
verifier.ContinuousVerification(logPath, interval, func(result *audit.VerificationResult) {
|
||||
printResult(result)
|
||||
if !result.Valid {
|
||||
// In continuous mode, we don't exit on tampering - we keep monitoring
|
||||
// The alert function should notify appropriate channels (email, slack, etc.)
|
||||
fmt.Println("\n*** TAMPERING DETECTED - INVESTIGATE IMMEDIATELY ***")
|
||||
}
|
||||
})
|
||||
} else {
|
||||
// Single verification run
|
||||
fmt.Printf("Verifying audit log: %s\n", logPath)
|
||||
|
||||
result, err := verifier.VerifyLogFile(logPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
printResult(result)
|
||||
|
||||
if !result.Valid {
|
||||
fmt.Println("\n*** VERIFICATION FAILED - AUDIT CHAIN TAMPERING DETECTED ***")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Audit chain integrity verified")
|
||||
}
|
||||
}
|
||||
|
||||
func printResult(result *audit.VerificationResult) {
|
||||
fmt.Printf("\nVerification Time: %s\n", result.Timestamp.Format(time.RFC3339))
|
||||
fmt.Printf("Total Events: %d\n", result.TotalEvents)
|
||||
fmt.Printf("Valid: %v\n", result.Valid)
|
||||
|
||||
if result.ChainRootHash != "" {
|
||||
fmt.Printf("Chain Root Hash: %s...\n", result.ChainRootHash[:16])
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
if result.FirstTampered != -1 {
|
||||
fmt.Printf("First Tampered Event: %d\n", result.FirstTampered)
|
||||
}
|
||||
if result.Error != "" {
|
||||
fmt.Printf("Error: %s\n", result.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -100,7 +100,7 @@ func (al *Logger) Log(event Event) {
|
|||
event.PrevHash = al.lastHash
|
||||
|
||||
// Calculate hash of this event for tamper evidence
|
||||
event.EventHash = al.calculateEventHash(event)
|
||||
event.EventHash = al.CalculateEventHash(event)
|
||||
al.lastHash = event.EventHash
|
||||
|
||||
// Marshal to JSON
|
||||
|
|
@ -133,8 +133,9 @@ func (al *Logger) Log(event Event) {
|
|||
}
|
||||
}
|
||||
|
||||
// calculateEventHash computes SHA-256 hash of event data for integrity chain
|
||||
func (al *Logger) calculateEventHash(event Event) string {
|
||||
// CalculateEventHash computes SHA-256 hash of event data for integrity chain
|
||||
// Exported for testing purposes
|
||||
func (al *Logger) CalculateEventHash(event Event) string {
|
||||
// Create a copy without the hash field for hashing
|
||||
eventCopy := event
|
||||
eventCopy.EventHash = ""
|
||||
|
|
@ -195,7 +196,7 @@ func (al *Logger) VerifyChain(events []Event) (tamperedSeq int, err error) {
|
|||
}
|
||||
|
||||
// Verify event hash
|
||||
expectedHash := al.calculateEventHash(event)
|
||||
expectedHash := al.CalculateEventHash(event)
|
||||
if event.EventHash != expectedHash {
|
||||
return int(event.SequenceNum), fmt.Errorf(
|
||||
"hash mismatch at sequence %d: expected %s, got %s",
|
||||
|
|
|
|||
219
internal/audit/verifier.go
Normal file
219
internal/audit/verifier.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// ChainVerifier provides continuous verification of audit log integrity
|
||||
// by checking the chained hash structure and detecting any tampering.
|
||||
type ChainVerifier struct {
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewChainVerifier creates a new audit chain verifier
|
||||
func NewChainVerifier(logger *logging.Logger) *ChainVerifier {
|
||||
return &ChainVerifier{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// VerificationResult contains the outcome of a chain verification
|
||||
//type VerificationResult struct {
|
||||
// Timestamp time.Time
|
||||
// TotalEvents int
|
||||
// Valid bool
|
||||
// FirstTampered int64 // Sequence number of first tampered event, -1 if none
|
||||
// Error string // Error message if verification failed
|
||||
// ChainRootHash string // Hash of the last valid event (for external verification)
|
||||
//}
|
||||
|
||||
// VerificationResult contains the outcome of a chain verification
|
||||
type VerificationResult struct {
|
||||
Timestamp time.Time
|
||||
TotalEvents int
|
||||
Valid bool
|
||||
FirstTampered int64 // Sequence number of first tampered event, -1 if none
|
||||
Error string // Error message if verification failed
|
||||
ChainRootHash string // Hash of the last valid event (for external verification)
|
||||
}
|
||||
|
||||
// VerifyLogFile performs a complete verification of an audit log file.
|
||||
// It checks the integrity chain by verifying each event's hash and
|
||||
// ensuring the previous hash links are unbroken.
|
||||
func (cv *ChainVerifier) VerifyLogFile(logPath string) (*VerificationResult, error) {
|
||||
result := &VerificationResult{
|
||||
Timestamp: time.Now().UTC(),
|
||||
TotalEvents: 0,
|
||||
Valid: true,
|
||||
FirstTampered: -1,
|
||||
}
|
||||
|
||||
// Open the log file
|
||||
file, err := os.Open(logPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// No log file yet - this is valid (no entries to verify)
|
||||
return result, nil
|
||||
}
|
||||
result.Valid = false
|
||||
result.Error = fmt.Sprintf("failed to open log file: %v", err)
|
||||
return result, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create a temporary logger to calculate hashes
|
||||
tempLogger, _ := NewLogger(false, "", cv.logger)
|
||||
|
||||
var events []Event
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(line), &event); err != nil {
|
||||
result.Valid = false
|
||||
result.Error = fmt.Sprintf("failed to parse event at line %d: %v", lineNum, err)
|
||||
return result, fmt.Errorf("parse error at line %d: %w", lineNum, err)
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
result.TotalEvents++
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
result.Valid = false
|
||||
result.Error = fmt.Sprintf("error reading log file: %v", err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Verify the chain
|
||||
tamperedSeq, err := tempLogger.VerifyChain(events)
|
||||
if err != nil {
|
||||
result.Valid = false
|
||||
result.FirstTampered = int64(tamperedSeq)
|
||||
result.Error = err.Error()
|
||||
return result, err
|
||||
}
|
||||
|
||||
if tamperedSeq != -1 {
|
||||
result.Valid = false
|
||||
result.FirstTampered = int64(tamperedSeq)
|
||||
result.Error = fmt.Sprintf("tampering detected at sequence %d", tamperedSeq)
|
||||
}
|
||||
|
||||
// Set the chain root hash (hash of the last event)
|
||||
if len(events) > 0 {
|
||||
lastEvent := events[len(events)-1]
|
||||
result.ChainRootHash = lastEvent.EventHash
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ContinuousVerification runs verification at regular intervals and reports any issues.
|
||||
// This should be run as a background goroutine in long-running services.
|
||||
func (cv *ChainVerifier) ContinuousVerification(logPath string, interval time.Duration, alertFunc func(*VerificationResult)) {
|
||||
if interval <= 0 {
|
||||
interval = 15 * time.Minute // Default: 15 minutes for HIPAA, use 1 hour otherwise
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run initial verification
|
||||
cv.runAndReport(logPath, alertFunc)
|
||||
|
||||
for range ticker.C {
|
||||
cv.runAndReport(logPath, alertFunc)
|
||||
}
|
||||
}
|
||||
|
||||
// runAndReport performs verification and calls the alert function if issues are found
|
||||
func (cv *ChainVerifier) runAndReport(logPath string, alertFunc func(*VerificationResult)) {
|
||||
result, err := cv.VerifyLogFile(logPath)
|
||||
if err != nil {
|
||||
if cv.logger != nil {
|
||||
cv.logger.Error("audit chain verification error", "error", err, "log_path", logPath)
|
||||
}
|
||||
// Still report the error
|
||||
if alertFunc != nil {
|
||||
alertFunc(result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Report if not valid or if we just want to log successful verification periodically
|
||||
if !result.Valid {
|
||||
if cv.logger != nil {
|
||||
cv.logger.Error("audit chain tampering detected",
|
||||
"first_tampered", result.FirstTampered,
|
||||
"total_events", result.TotalEvents,
|
||||
"chain_root", result.ChainRootHash[:16])
|
||||
}
|
||||
if alertFunc != nil {
|
||||
alertFunc(result)
|
||||
}
|
||||
} else {
|
||||
if cv.logger != nil {
|
||||
cv.logger.Debug("audit chain verification passed",
|
||||
"total_events", result.TotalEvents,
|
||||
"chain_root", result.ChainRootHash[:16])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyAndAlert performs a single verification and returns true if tampering detected
|
||||
func (cv *ChainVerifier) VerifyAndAlert(logPath string) (bool, error) {
|
||||
result, err := cv.VerifyLogFile(logPath)
|
||||
if err != nil {
|
||||
return true, err // Treat errors as potential tampering
|
||||
}
|
||||
|
||||
return !result.Valid, nil
|
||||
}
|
||||
|
||||
// GetChainRootHash returns the hash of the last event in the chain
|
||||
// This can be published to an external append-only store for independent verification
|
||||
func (cv *ChainVerifier) GetChainRootHash(logPath string) (string, error) {
|
||||
file, err := os.Open(logPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var lastLine string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line != "" {
|
||||
lastLine = line
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if lastLine == "" {
|
||||
return "", fmt.Errorf("no events in log file")
|
||||
}
|
||||
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(lastLine), &event); err != nil {
|
||||
return "", fmt.Errorf("failed to parse last event: %w", err)
|
||||
}
|
||||
|
||||
return event.EventHash, nil
|
||||
}
|
||||
243
tests/unit/audit/verifier_test.go
Normal file
243
tests/unit/audit/verifier_test.go
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
package audit_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// replaceAllSubstr replaces all occurrences of old with new in s
|
||||
func replaceAllSubstr(s, old, new string) string {
|
||||
return strings.ReplaceAll(s, old, new)
|
||||
}
|
||||
|
||||
func TestChainVerifier_VerifyLogFile_EmptyLog(t *testing.T) {
|
||||
logger := logging.NewLogger(0, false)
|
||||
verifier := audit.NewChainVerifier(logger)
|
||||
|
||||
// Test with non-existent file (empty log case)
|
||||
result, err := verifier.VerifyLogFile("/nonexistent/path/audit.log")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for non-existent file, got: %v", err)
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
t.Error("expected valid result for empty log")
|
||||
}
|
||||
if result.TotalEvents != 0 {
|
||||
t.Errorf("expected 0 events, got %d", result.TotalEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainVerifier_VerifyLogFile_ValidChain(t *testing.T) {
|
||||
logger := logging.NewLogger(0, false)
|
||||
verifier := audit.NewChainVerifier(logger)
|
||||
|
||||
// Create a temporary log file with valid chain
|
||||
tempDir := t.TempDir()
|
||||
logPath := filepath.Join(tempDir, "audit.log")
|
||||
|
||||
// Create logger that writes to file
|
||||
al, err := audit.NewLogger(true, logPath, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create audit logger: %v", err)
|
||||
}
|
||||
defer al.Close()
|
||||
|
||||
// Log some events
|
||||
testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC)
|
||||
al.Log(audit.Event{
|
||||
Timestamp: testTime,
|
||||
EventType: audit.EventJobStarted,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
al.Log(audit.Event{
|
||||
Timestamp: testTime.Add(time.Second),
|
||||
EventType: audit.EventJobCompleted,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
al.Log(audit.Event{
|
||||
Timestamp: testTime.Add(2 * time.Second),
|
||||
EventType: audit.EventFileRead,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
|
||||
// Verify the chain
|
||||
result, err := verifier.VerifyLogFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("verification failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("expected valid chain, got error: %s", result.Error)
|
||||
}
|
||||
if result.TotalEvents != 3 {
|
||||
t.Errorf("expected 3 events, got %d", result.TotalEvents)
|
||||
}
|
||||
if result.ChainRootHash == "" {
|
||||
t.Error("expected non-empty chain root hash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainVerifier_VerifyLogFile_TamperedChain(t *testing.T) {
|
||||
logger := logging.NewLogger(0, false)
|
||||
verifier := audit.NewChainVerifier(logger)
|
||||
|
||||
// Create a temporary log file
|
||||
tempDir := t.TempDir()
|
||||
logPath := filepath.Join(tempDir, "audit.log")
|
||||
|
||||
// Create logger and log events
|
||||
al, err := audit.NewLogger(true, logPath, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create audit logger: %v", err)
|
||||
}
|
||||
|
||||
testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC)
|
||||
al.Log(audit.Event{
|
||||
Timestamp: testTime,
|
||||
EventType: audit.EventJobStarted,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
al.Log(audit.Event{
|
||||
Timestamp: testTime.Add(time.Second),
|
||||
EventType: audit.EventJobCompleted,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
al.Close()
|
||||
|
||||
// Tamper with the file by modifying the second event
|
||||
data, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
// Replace "user1" with "attacker" in the content
|
||||
content := string(data)
|
||||
// Simple string replacement
|
||||
tamperedContent := replaceAllSubstr(content, "user1", "attacker")
|
||||
|
||||
// Write back tampered data
|
||||
if err := os.WriteFile(logPath, []byte(tamperedContent), 0600); err != nil {
|
||||
t.Fatalf("failed to write tampered log: %v", err)
|
||||
}
|
||||
|
||||
// Verify should detect tampering
|
||||
result, err := verifier.VerifyLogFile(logPath)
|
||||
if err == nil && result.Valid {
|
||||
t.Error("expected tampering to be detected")
|
||||
}
|
||||
|
||||
if result.FirstTampered == -1 {
|
||||
t.Error("expected FirstTampered to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainVerifier_GetChainRootHash(t *testing.T) {
|
||||
logger := logging.NewLogger(0, false)
|
||||
verifier := audit.NewChainVerifier(logger)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
logPath := filepath.Join(tempDir, "audit.log")
|
||||
|
||||
// Create logger and log events
|
||||
al, err := audit.NewLogger(true, logPath, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create audit logger: %v", err)
|
||||
}
|
||||
|
||||
testTime := time.Date(2026, 2, 23, 12, 0, 0, 0, time.UTC)
|
||||
al.Log(audit.Event{
|
||||
Timestamp: testTime,
|
||||
EventType: audit.EventJobStarted,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
al.Close()
|
||||
|
||||
// Get chain root hash
|
||||
rootHash, err := verifier.GetChainRootHash(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get chain root hash: %v", err)
|
||||
}
|
||||
|
||||
if rootHash == "" {
|
||||
t.Error("expected non-empty root hash")
|
||||
}
|
||||
|
||||
// Verify it matches the result from VerifyLogFile
|
||||
result, err := verifier.VerifyLogFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("verification failed: %v", err)
|
||||
}
|
||||
|
||||
if rootHash != result.ChainRootHash {
|
||||
t.Errorf("root hash mismatch: GetChainRootHash=%s, VerifyLogFile=%s", rootHash, result.ChainRootHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainVerifier_VerifyAndAlert(t *testing.T) {
|
||||
logger := logging.NewLogger(0, false)
|
||||
verifier := audit.NewChainVerifier(logger)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
logPath := filepath.Join(tempDir, "audit.log")
|
||||
|
||||
// Create valid log
|
||||
al, err := audit.NewLogger(true, logPath, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create audit logger: %v", err)
|
||||
}
|
||||
|
||||
al.Log(audit.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: audit.EventJobStarted,
|
||||
UserID: "user1",
|
||||
Success: true,
|
||||
})
|
||||
al.Close()
|
||||
|
||||
// VerifyAndAlert should return false (no tampering)
|
||||
tampered, err := verifier.VerifyAndAlert(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("verification failed: %v", err)
|
||||
}
|
||||
if tampered {
|
||||
t.Error("expected no tampering detected for valid log")
|
||||
}
|
||||
|
||||
// Test with non-existent file
|
||||
tampered, err = verifier.VerifyAndAlert("/nonexistent/path/audit.log")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for non-existent file: %v", err)
|
||||
}
|
||||
if tampered {
|
||||
t.Error("expected no tampering for empty log")
|
||||
}
|
||||
}
|
||||
|
||||
// splitLines splits byte slice by newlines
|
||||
func splitLines(data []byte) [][]byte {
|
||||
var lines [][]byte
|
||||
start := 0
|
||||
for i := 0; i < len(data); i++ {
|
||||
if data[i] == '\n' {
|
||||
lines = append(lines, data[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(data) {
|
||||
lines = append(lines, data[start:])
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package security
|
|||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
|
|
@ -43,11 +44,12 @@ func TestAuditLogger_VerifyChain(t *testing.T) {
|
|||
}
|
||||
defer al.Close()
|
||||
|
||||
// Create a valid chain of events
|
||||
// Create a valid chain of events with real hashes
|
||||
events := []audit.Event{
|
||||
{
|
||||
EventType: audit.EventAuthSuccess,
|
||||
UserID: "user1",
|
||||
Timestamp: time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
SequenceNum: 1,
|
||||
PrevHash: "",
|
||||
},
|
||||
|
|
@ -55,29 +57,60 @@ func TestAuditLogger_VerifyChain(t *testing.T) {
|
|||
EventType: audit.EventFileRead,
|
||||
UserID: "user1",
|
||||
Resource: "/data/file.txt",
|
||||
Timestamp: time.Date(2026, 1, 1, 12, 1, 0, 0, time.UTC),
|
||||
SequenceNum: 2,
|
||||
},
|
||||
{
|
||||
EventType: audit.EventFileWrite,
|
||||
UserID: "user1",
|
||||
Resource: "/data/output.txt",
|
||||
Timestamp: time.Date(2026, 1, 1, 12, 2, 0, 0, time.UTC),
|
||||
SequenceNum: 3,
|
||||
},
|
||||
}
|
||||
|
||||
// Calculate hashes for each event
|
||||
// Calculate real hashes for each event
|
||||
for i := range events {
|
||||
if i > 0 {
|
||||
events[i].PrevHash = events[i-1].EventHash
|
||||
}
|
||||
// We can't easily call calculateEventHash since it's private
|
||||
// In real test, we'd use the logged events
|
||||
events[i].EventHash = "dummy_hash_for_testing"
|
||||
events[i].EventHash = al.CalculateEventHash(events[i])
|
||||
}
|
||||
|
||||
// Test verification with valid chain
|
||||
// In real scenario, we'd verify the actual hashes
|
||||
_, _ = al.VerifyChain(events)
|
||||
// Test verification with valid chain - should pass
|
||||
tamperedSeq, err := al.VerifyChain(events)
|
||||
if err != nil {
|
||||
t.Errorf("VerifyChain failed for valid chain: %v", err)
|
||||
}
|
||||
if tamperedSeq != -1 {
|
||||
t.Errorf("Expected valid chain (tamperedSeq=-1), got %d", tamperedSeq)
|
||||
}
|
||||
|
||||
// Test tamper detection - modify an event hash
|
||||
tamperedEvents := make([]audit.Event, len(events))
|
||||
copy(tamperedEvents, events)
|
||||
tamperedEvents[1].EventHash = "tampered_hash_1234567890abcdef"
|
||||
|
||||
tamperedSeq, err = al.VerifyChain(tamperedEvents)
|
||||
if err == nil {
|
||||
t.Error("Expected error for tampered chain, got nil")
|
||||
}
|
||||
if tamperedSeq != 2 {
|
||||
t.Errorf("Expected tamperedSeq=2, got %d", tamperedSeq)
|
||||
}
|
||||
|
||||
// Test chain break detection - modify prev_hash
|
||||
brokenEvents := make([]audit.Event, len(events))
|
||||
copy(brokenEvents, events)
|
||||
brokenEvents[1].PrevHash = "wrong_prev_hash_1234567890abcdef"
|
||||
|
||||
tamperedSeq, err = al.VerifyChain(brokenEvents)
|
||||
if err == nil {
|
||||
t.Error("Expected error for broken chain, got nil")
|
||||
}
|
||||
if tamperedSeq != 2 {
|
||||
t.Errorf("Expected tamperedSeq=2 for broken chain, got %d", tamperedSeq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditLogger_LogFileAccess(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue