fetch_ml/internal/audit/audit.go
Jeremie Fraeys 66f262d788
security: improve audit, crypto, and config handling
- Enhance audit checkpoint system
- Update KMS provider and tenant key management
- Refine configuration constants
- Improve TUI config handling
2026-03-04 13:23:42 -05:00

542 lines
14 KiB
Go

package audit
import (
"bufio"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// EventType represents the type of audit event
type EventType string

const (
	// Authentication events record login attempts and their outcomes.
	EventAuthAttempt EventType = "authentication_attempt"
	EventAuthSuccess EventType = "authentication_success"
	EventAuthFailure EventType = "authentication_failure"

	// Job lifecycle events, from queueing through completion or failure.
	EventJobQueued    EventType = "job_queued"
	EventJobStarted   EventType = "job_started"
	EventJobCompleted EventType = "job_completed"
	EventJobFailed    EventType = "job_failed"

	// Jupyter service lifecycle events.
	EventJupyterStart EventType = "jupyter_start"
	EventJupyterStop  EventType = "jupyter_stop"

	// Experiment lifecycle events.
	EventExperimentCreated EventType = "experiment_created"
	EventExperimentDeleted EventType = "experiment_deleted"

	// HIPAA-specific file access events
	EventFileRead      EventType = "file_read"
	EventFileWrite     EventType = "file_write"
	EventFileDelete    EventType = "file_delete"
	EventDatasetAccess EventType = "dataset_access"

	// KMS encryption events per ADR-012 through ADR-015
	EventKMSEncrypt    EventType = "kms_encrypt"
	EventKMSDecrypt    EventType = "kms_decrypt"
	EventKMSKeyCreate  EventType = "kms_key_create"
	EventKMSKeyRotate  EventType = "kms_key_rotate"
	EventKMSKeyDisable EventType = "kms_key_disable"
	EventKMSKeyDelete  EventType = "kms_key_delete"
)
// Event represents an audit log event with integrity chain.
//
// Each event carries its own SHA-256 hash (EventHash) and the hash of the
// preceding event (PrevHash), forming a tamper-evident chain.
//
// SECURITY NOTE: Metadata uses map[string]any; encoding/json marshals map
// keys in sorted order, which keeps the event hash deterministic. If you
// need to hash events externally, ensure the same ordering is used, or
// exclude Metadata from the hashed portion.
type Event struct {
	Timestamp   time.Time      `json:"timestamp"`              // set to time.Now().UTC() by Logger.Log
	Metadata    map[string]any `json:"metadata,omitempty"`     // free-form context (e.g. tenant_id, kms_key_id)
	EventType   EventType      `json:"event_type"`             // one of the Event* constants
	UserID      string         `json:"user_id,omitempty"`      // acting user (or tenant, for KMS events)
	IPAddress   string         `json:"ip_address,omitempty"`   // client source address, when known
	Resource    string         `json:"resource,omitempty"`     // resource acted upon (file path, job ID, key ID, ...)
	Action      string         `json:"action,omitempty"`       // short verb describing the operation
	ErrorMsg    string         `json:"error,omitempty"`        // failure detail when Success is false
	PrevHash    string         `json:"prev_hash,omitempty"`    // EventHash of the previous event ("" for the first)
	EventHash   string         `json:"event_hash,omitempty"`   // SHA-256 of this event with EventHash cleared
	SequenceNum int64          `json:"sequence_num,omitempty"` // 1-based monotonic position in the chain
	Success     bool           `json:"success"`                // whether the audited operation succeeded
}
// Logger handles audit logging with integrity chain
type Logger struct {
	file        *os.File        // append-only JSONL log file (nil when file output is not configured)
	logger      *logging.Logger // optional structured-logger mirror (may be nil)
	filePath    string          // validated, symlink-resolved path of the log file
	lastHash    string          // EventHash of the most recently logged event
	sequenceNum int64           // SequenceNum of the most recently logged event
	mu          sync.Mutex      // guards file writes and chain state (lastHash, sequenceNum)
	enabled     bool            // when false, Log is a no-op
}
// NewLogger creates a new audit logger rooted at the production base
// directory (/var/lib/fetch_ml/audit). The filePath is validated for path
// traversal and symlink attacks and must stay within that base directory.
func NewLogger(enabled bool, filePath string, logger *logging.Logger) (*Logger, error) {
	// Production base; tests use NewLoggerWithBase to point elsewhere.
	const defaultBaseDir = "/var/lib/fetch_ml/audit"
	return NewLoggerWithBase(enabled, filePath, logger, defaultBaseDir)
}
// NewLoggerWithBase creates a new audit logger anchored at a configurable
// base directory. This is useful for testing; production code should use
// NewLogger, which supplies the default base.
func NewLoggerWithBase(enabled bool, filePath string, logger *logging.Logger, baseDir string) (*Logger, error) {
	lg := &Logger{
		enabled: enabled,
		logger:  logger,
	}
	// A disabled logger (or one without a file path) is still valid: it
	// simply performs no file output.
	if !enabled || filePath == "" {
		return lg, nil
	}

	// Reject traversal/symlink tricks and anchor the path under baseDir.
	fullPath, err := validateAndSecurePath(filePath, baseDir)
	if err != nil {
		return nil, fmt.Errorf("invalid audit log path: %w", err)
	}
	if err := checkFileNotSymlink(fullPath); err != nil {
		return nil, fmt.Errorf("audit log security check failed: %w", err)
	}

	if err := os.MkdirAll(filepath.Dir(fullPath), 0o700); err != nil {
		return nil, fmt.Errorf("failed to create audit directory: %w", err)
	}
	f, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
	if err != nil {
		return nil, fmt.Errorf("failed to open audit log file: %w", err)
	}
	lg.file = f
	lg.filePath = fullPath

	// Pick up sequenceNum/lastHash from any existing log so a restart
	// does not break the integrity chain.
	if err := lg.resumeFromFile(); err != nil {
		f.Close()
		return nil, fmt.Errorf("failed to resume audit chain: %w", err)
	}
	return lg, nil
}
// Log logs an audit event on the tamper-evident hash chain.
//
// The event's Timestamp, SequenceNum, PrevHash, and EventHash are filled
// in here. Chain state (sequenceNum, lastHash) is committed only after the
// event has been durably persisted — or immediately when no file is
// configured — so a marshal or write failure cannot leave the in-memory
// chain ahead of the on-disk log, which would break verification for
// every subsequent event.
func (al *Logger) Log(event Event) {
	if !al.enabled {
		return
	}
	event.Timestamp = time.Now().UTC()

	al.mu.Lock()
	defer al.mu.Unlock()

	// Tentatively extend the chain; al.sequenceNum/al.lastHash are not
	// updated until the event is known to be persisted.
	event.SequenceNum = al.sequenceNum + 1
	event.PrevHash = al.lastHash
	event.EventHash = al.CalculateEventHash(event)

	data, err := json.Marshal(event)
	if err != nil {
		if al.logger != nil {
			al.logger.Error("failed to marshal audit event", "error", err)
		}
		// Chain state untouched: the next event reuses this slot.
		return
	}

	// Write to file if configured
	if al.file != nil {
		if _, err := al.file.Write(append(data, '\n')); err != nil {
			if al.logger != nil {
				al.logger.Error("failed to write audit event", "error", err)
			}
			// Do not advance the chain past an event that never reached
			// disk; resumeFromFile tolerates any partially written line.
			return
		}
		// fsync ensures data is flushed to disk before the in-memory hash
		// advances. Critical for crash safety: prevents chain inconsistency
		// if the system crashes between the write and the state commit.
		if syncErr := al.file.Sync(); syncErr != nil && al.logger != nil {
			al.logger.Error("failed to sync audit log", "error", syncErr)
		}
	}

	// Commit chain state now that the event is persisted (or file-less).
	al.sequenceNum = event.SequenceNum
	al.lastHash = event.EventHash

	// Also log via structured logger
	if al.logger != nil {
		al.logger.Info("audit_event",
			"event_type", event.EventType,
			"user_id", event.UserID,
			"resource", event.Resource,
			"success", event.Success,
			"seq", event.SequenceNum,
			"hash", truncateHash(event.EventHash, 16),
		)
	}
}
// CalculateEventHash computes the SHA-256 hash of the event for the
// integrity chain. EventHash itself is excluded from the hashed payload,
// while PrevHash is kept so successive hashes chain together.
// Exported for testing purposes.
func (al *Logger) CalculateEventHash(event Event) string {
	clone := event
	clone.EventHash = "" // keep PrevHash for chaining
	payload, err := json.Marshal(clone)
	if err != nil {
		// Deterministic fallback built from the fields that define the
		// event's position in the chain.
		payload = []byte(fmt.Sprintf(
			"%s:%s:%d:%s",
			event.Timestamp.UTC().Format(time.RFC3339Nano),
			event.EventType,
			event.SequenceNum,
			event.PrevHash,
		))
	}
	sum := sha256.Sum256(payload)
	return hex.EncodeToString(sum[:])
}
// LogFileAccess logs a file access operation (HIPAA requirement).
// Only the four file-access event types are accepted; anything else is
// rejected with an error log so misclassified events cannot slip through.
func (al *Logger) LogFileAccess(
	eventType EventType,
	userID, filePath, ipAddr string,
	success bool,
	errMsg string,
) {
	actionFor := map[EventType]string{
		EventFileRead:      "read",
		EventFileWrite:     "write",
		EventFileDelete:    "delete",
		EventDatasetAccess: "dataset_access",
	}
	action, ok := actionFor[eventType]
	if !ok {
		// Defensive: prevent silent misclassification
		if al.logger != nil {
			al.logger.Error(
				"invalid file access event type",
				"event_type", eventType,
			)
		}
		return
	}
	al.Log(Event{
		EventType: eventType,
		UserID:    userID,
		IPAddress: ipAddr,
		Resource:  filePath,
		Action:    action,
		Success:   success,
		ErrorMsg:  errMsg,
	})
}
// VerifyChain checks the integrity of the audit log chain.
// The events slice must be provided in ascending sequence order.
// Returns the first sequence number where tampering is detected, or -1 if valid.
//
// Three invariants are checked per event:
//  1. SequenceNum is exactly i+1 (contiguous, starting at 1) — this also
//     guarantees the first event has SequenceNum == 1, so no separate
//     first-event sequence check is needed.
//  2. PrevHash matches the previous event's EventHash ("" for the first).
//  3. EventHash matches a fresh recomputation over the event's contents.
func (al *Logger) VerifyChain(events []Event) (tamperedSeq int, err error) {
	if len(events) == 0 {
		return -1, nil
	}
	var expectedPrevHash string
	for i, event := range events {
		// Enforce strict sequence ordering (events must be sorted by SequenceNum)
		if event.SequenceNum != int64(i+1) {
			return int(event.SequenceNum), fmt.Errorf(
				"sequence mismatch: expected %d, got %d",
				i+1, event.SequenceNum,
			)
		}
		if i == 0 {
			if event.PrevHash != "" {
				return int(event.SequenceNum), fmt.Errorf(
					"first event must have empty prev_hash",
				)
			}
		} else if event.PrevHash != expectedPrevHash {
			return int(event.SequenceNum), fmt.Errorf(
				"chain break at sequence %d",
				event.SequenceNum,
			)
		}
		// Recompute the hash to detect any tampering with event contents.
		if expectedHash := al.CalculateEventHash(event); event.EventHash != expectedHash {
			return int(event.SequenceNum), fmt.Errorf(
				"hash mismatch at sequence %d",
				event.SequenceNum,
			)
		}
		expectedPrevHash = event.EventHash
	}
	return -1, nil
}
// LogAuthAttempt logs an authentication attempt, classifying it as a
// success or failure event based on the success flag.
func (al *Logger) LogAuthAttempt(userID, ipAddr string, success bool, errMsg string) {
	evt := Event{
		EventType: EventAuthFailure,
		UserID:    userID,
		IPAddress: ipAddr,
		Success:   success,
		ErrorMsg:  errMsg,
	}
	if success {
		evt.EventType = EventAuthSuccess
	}
	al.Log(evt)
}
// LogJobOperation logs a job-related operation. The job ID is recorded as
// the event resource, with a fixed "job_operation" action.
func (al *Logger) LogJobOperation(
	eventType EventType,
	userID, jobID, ipAddr string,
	success bool,
	errMsg string,
) {
	evt := Event{
		EventType: eventType,
		Action:    "job_operation",
		Resource:  jobID,
		UserID:    userID,
		IPAddress: ipAddr,
		Success:   success,
		ErrorMsg:  errMsg,
	}
	al.Log(evt)
}
// LogJupyterOperation logs a Jupyter service operation. The service ID is
// recorded as the event resource, with a fixed "jupyter_operation" action.
func (al *Logger) LogJupyterOperation(
	eventType EventType,
	userID, serviceID, ipAddr string,
	success bool,
	errMsg string,
) {
	evt := Event{
		EventType: eventType,
		Action:    "jupyter_operation",
		Resource:  serviceID,
		UserID:    userID,
		IPAddress: ipAddr,
		Success:   success,
		ErrorMsg:  errMsg,
	}
	al.Log(evt)
}
// LogKMSOperation logs a KMS encryption/decryption or key management operation.
// Per ADR-012 through ADR-015: All key operations must be logged with tenant ID.
// The tenant acts as the event's user; the KMS key is the resource, and the
// tenant/key (and optional artifact) IDs are duplicated into Metadata.
func (al *Logger) LogKMSOperation(
	eventType EventType,
	tenantID, artifactID, kmsKeyID string,
	success bool,
	errMsg string,
) {
	meta := map[string]any{
		"tenant_id":  tenantID,
		"kms_key_id": kmsKeyID,
	}
	if artifactID != "" {
		meta["artifact_id"] = artifactID
	}
	evt := Event{
		EventType: eventType,
		Action:    string(eventType),
		UserID:    tenantID, // Tenant is the entity performing the operation
		Resource:  kmsKeyID,
		Success:   success,
		ErrorMsg:  errMsg,
		Metadata:  meta,
	}
	al.Log(evt)
}
// Close closes the audit logger's file, if one is open. Close is
// idempotent: calling it again returns nil rather than an error on an
// already-closed descriptor.
func (al *Logger) Close() error {
	al.mu.Lock()
	defer al.mu.Unlock()
	if al.file == nil {
		return nil
	}
	err := al.file.Close()
	// Drop the handle so a second Close is a no-op and any later Log call
	// falls back to structured-logger-only output instead of writing to a
	// closed file descriptor.
	al.file = nil
	return err
}
// resumeFromFile reads the last entry from the audit log file and restores
// the chain state (sequenceNum and lastHash) to prevent chain reset on restart.
// This is critical for tamper-evident logging integrity.
// Corrupted lines are logged and skipped; the last parseable event wins.
func (al *Logger) resumeFromFile() error {
	if al.file == nil {
		return nil
	}
	// Re-open read-only to scan existing entries (al.file is write-only).
	file, err := os.Open(al.filePath)
	if err != nil {
		return fmt.Errorf("failed to open audit log for resume: %w", err)
	}
	defer file.Close()

	var lastEvent Event
	scanner := bufio.NewScanner(file)
	// Events carrying large Metadata can exceed bufio.Scanner's default
	// 64 KiB line limit, which would abort resume with ErrTooLong and
	// prevent the logger from starting. Allow lines up to 4 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := scanner.Text()
		if line == "" {
			continue
		}
		var event Event
		if err := json.Unmarshal([]byte(line), &event); err != nil {
			// Corrupted line - log but continue
			if al.logger != nil {
				al.logger.Warn("corrupted audit log entry during resume",
					"line", lineNum,
					"error", err)
			}
			continue
		}
		lastEvent = event
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading audit log during resume: %w", err)
	}

	// Restore chain state from the last valid event. A zero SequenceNum
	// means the file was empty (or wholly corrupt): start a fresh chain.
	if lastEvent.SequenceNum > 0 {
		al.sequenceNum = lastEvent.SequenceNum
		al.lastHash = lastEvent.EventHash
		if al.logger != nil {
			al.logger.Info("audit chain resumed",
				"sequence", al.sequenceNum,
				"hash_preview", truncateHash(al.lastHash, 16))
		}
	}
	return nil
}
// truncateHash returns at most maxLen leading characters of hash, giving a
// safe short preview for log output.
func truncateHash(hash string, maxLen int) string {
	if len(hash) > maxLen {
		return hash[:maxLen]
	}
	return hash
}
// validateAndSecurePath validates a file path for security issues.
// It rejects absolute paths and path traversal, resolves symlinks in both
// the base directory and the target, and ensures the final resolved path
// stays within baseDir. Returns the resolved full path on success.
func validateAndSecurePath(filePath, baseDir string) (string, error) {
	// Reject absolute paths: callers may only name a location relative to
	// the base directory.
	if filepath.IsAbs(filePath) {
		return "", fmt.Errorf("absolute paths not allowed: %s", filePath)
	}
	// Clean the path to resolve any . or .. components; anything still
	// starting with ".." is attempting to escape the base.
	cleanPath := filepath.Clean(filePath)
	if strings.HasPrefix(cleanPath, "..") {
		return "", fmt.Errorf("path traversal attempt detected: %s", filePath)
	}
	// Resolve base directory symlinks (critical: a symlinked base would
	// defeat every prefix check below).
	resolvedBase, err := filepath.EvalSymlinks(baseDir)
	if err != nil {
		// Base may not exist yet, use as-is but this is less secure
		resolvedBase = baseDir
	}
	fullPath := filepath.Join(resolvedBase, cleanPath)
	// Resolve any symlinks in the full path, falling back progressively
	// when the file or its parent does not exist yet.
	resolvedPath, err := filepath.EvalSymlinks(fullPath)
	if err != nil {
		parent := filepath.Dir(fullPath)
		resolvedParent, perr := filepath.EvalSymlinks(parent)
		if perr != nil {
			// Neither file nor parent exists - validate textually that
			// the joined path stays within the base directory.
			if !strings.HasPrefix(fullPath, resolvedBase+string(os.PathSeparator)) &&
				fullPath != resolvedBase {
				return "", fmt.Errorf("path escapes base directory: %s", filePath)
			}
			resolvedPath = fullPath
		} else {
			// Parent resolved - verify it's still within base. The check
			// is separator-aware so a sibling such as "/base-evil" cannot
			// pass a bare prefix test against "/base".
			if resolvedParent != resolvedBase &&
				!strings.HasPrefix(resolvedParent, resolvedBase+string(os.PathSeparator)) {
				return "", fmt.Errorf("parent directory escapes base: %s", filePath)
			}
			// Reconstruct path with resolved parent
			resolvedPath = filepath.Join(resolvedParent, filepath.Base(fullPath))
		}
	}
	// Final verification: resolved path must be within base directory
	if !strings.HasPrefix(resolvedPath, resolvedBase+string(os.PathSeparator)) &&
		resolvedPath != resolvedBase {
		return "", fmt.Errorf("path escapes base directory after symlink resolution: %s", filePath)
	}
	return resolvedPath, nil
}
// checkFileNotSymlink verifies that the given path is not a symlink.
// A missing file passes (there is nothing to attack via symlink yet);
// Lstat is used so the link itself is inspected rather than its target.
func checkFileNotSymlink(path string) error {
	info, err := os.Lstat(path)
	switch {
	case os.IsNotExist(err):
		return nil // File doesn't exist, can't be a symlink
	case err != nil:
		return fmt.Errorf("failed to stat file: %w", err)
	case info.Mode()&os.ModeSymlink != 0:
		return fmt.Errorf("file is a symlink: %s", path)
	default:
		return nil
	}
}