fetch_ml/internal/storage/db_audit.go
Jeremie Fraeys fbcf4d38e5
feat(storage): add groups, tasks, tokens, and audit database schemas
Add comprehensive database storage layer for new features:

- db_groups.go: Lab group management with members, roles (admin/member/viewer),
  and group-based task visibility queries

- db_tasks.go: Task visibility system (private/lab/institution/open),
  task sharing with expiry, public clone tokens, and optimized
  ListTasksForUser() for access control

- db_tokens.go: Secure token management for public task access and cloning,
  with SHA-256 hashed token storage and automatic cleanup

- db_audit.go: Audit log persistence with checkpoint chains, tamper
  detection, and log rotation support

- schema_sqlite.sql: Updated schema with:
  - groups, group_members tables
  - tasks.visibility enum, task_shares with expiry
  - access_tokens table with hashed tokens
  - audit_logs, audit_checkpoints tables
  - indexes for all foreign keys and query patterns

- db_experiments.go: Add CascadeVisibilityToTasks() for propagating
  visibility changes from experiments to associated tasks
2026-03-08 12:48:42 -04:00

225 lines
6.2 KiB
Go

package storage
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"time"
)
// AuditLogEntry is one row read back from the task_access_log table: a record
// of an action performed against a task. The optional columns (user, token,
// client IP address) are pointers so that a SQL NULL is represented as nil and
// dropped from JSON output via omitempty.
type AuditLogEntry struct {
	ID         int64     `json:"id"`
	TaskID     string    `json:"task_id"`
	UserID     *string   `json:"user_id,omitempty"`    // nil when no user was recorded
	Token      *string   `json:"token,omitempty"`      // nil when no token was recorded
	Action     string    `json:"action"`
	AccessedAt time.Time `json:"accessed_at"`
	IPAddress  *string   `json:"ip_address,omitempty"` // nil when no address was recorded
}
// LogTaskAccess inserts one row into task_access_log describing an action
// taken on taskID. Any of the pointer arguments may be nil, in which case
// database/sql binds NULL for that column. accessed_at is not supplied here;
// presumably it is filled by a column default — TODO confirm against schema.
//
// NOTE(review): token is persisted exactly as passed. If callers hand in raw
// access tokens, they are stored in plaintext in the audit log — confirm
// whether a hash should be logged instead.
func (db *DB) LogTaskAccess(taskID string, userID, token, action, ipAddress *string) error {
	// PostgreSQL-style placeholders by default; SQLite uses '?'.
	stmt := `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
VALUES ($1, $2, $3, $4, $5)`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
VALUES (?, ?, ?, ?, ?)`
	}

	if _, execErr := db.conn.ExecContext(context.Background(), stmt, taskID, userID, token, action, ipAddress); execErr != nil {
		return fmt.Errorf("failed to log task access: %w", execErr)
	}
	return nil
}
// GetAuditLogForTask returns at most limit entries from task_access_log whose
// task_id equals taskID, newest (by accessed_at) first. limit is passed to the
// database unchanged, so its handling of zero/negative values is
// backend-specific.
func (db *DB) GetAuditLogForTask(taskID string, limit int) ([]*AuditLogEntry, error) {
	var query string
	switch db.dbType {
	case DBTypeSQLite:
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE task_id = ?
ORDER BY accessed_at DESC
LIMIT ?`
	default:
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE task_id = $1
ORDER BY accessed_at DESC
LIMIT $2`
	}

	rows, queryErr := db.conn.QueryContext(context.Background(), query, taskID, limit)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", queryErr)
	}
	// A close failure is logged rather than returned so it cannot mask the
	// scan result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rows)
}
// GetAuditLogForUser returns at most limit entries from task_access_log whose
// user_id equals userID, newest (by accessed_at) first. limit is passed to the
// database unchanged, so its handling of zero/negative values is
// backend-specific.
func (db *DB) GetAuditLogForUser(userID string, limit int) ([]*AuditLogEntry, error) {
	var query string
	switch db.dbType {
	case DBTypeSQLite:
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE user_id = ?
ORDER BY accessed_at DESC
LIMIT ?`
	default:
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE user_id = $1
ORDER BY accessed_at DESC
LIMIT $2`
	}

	rows, queryErr := db.conn.QueryContext(context.Background(), query, userID, limit)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", queryErr)
	}
	// A close failure is logged rather than returned so it cannot mask the
	// scan result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rows)
}
// GetAuditLogForToken returns at most limit entries from task_access_log whose
// token column equals token (compared exactly as stored), newest (by
// accessed_at) first. limit is passed to the database unchanged, so its
// handling of zero/negative values is backend-specific.
func (db *DB) GetAuditLogForToken(token string, limit int) ([]*AuditLogEntry, error) {
	var query string
	switch db.dbType {
	case DBTypeSQLite:
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE token = ?
ORDER BY accessed_at DESC
LIMIT ?`
	default:
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE token = $1
ORDER BY accessed_at DESC
LIMIT $2`
	}

	rows, queryErr := db.conn.QueryContext(context.Background(), query, token, limit)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", queryErr)
	}
	// A close failure is logged rather than returned so it cannot mask the
	// scan result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rows)
}
// scanAuditLogEntries reads every remaining row from rows into AuditLogEntry
// values. Each row must carry, in order: id, task_id, user_id, token, action,
// accessed_at, ip_address. NULL user_id/token/ip_address columns become nil
// pointers. A nil slice is returned when there are no rows. rows is not closed
// here; the caller is responsible for closing it.
func (db *DB) scanAuditLogEntries(rows *sql.Rows) ([]*AuditLogEntry, error) {
	// optional converts a nullable column into a *string (nil for NULL).
	optional := func(ns sql.NullString) *string {
		if !ns.Valid {
			return nil
		}
		v := ns.String
		return &v
	}

	var result []*AuditLogEntry
	for rows.Next() {
		entry := &AuditLogEntry{}
		var user, tok, ip sql.NullString
		if scanErr := rows.Scan(&entry.ID, &entry.TaskID, &user, &tok, &entry.Action, &entry.AccessedAt, &ip); scanErr != nil {
			return nil, fmt.Errorf("failed to scan audit log entry: %w", scanErr)
		}
		entry.UserID = optional(user)
		entry.Token = optional(tok)
		entry.IPAddress = optional(ip)
		result = append(result, entry)
	}

	// rows.Next returning false may hide an iteration error; surface it.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating audit log entries: %w", iterErr)
	}
	return result, nil
}
// DeleteOldAuditLogs removes audit log entries whose accessed_at is more than
// retentionDays days before the current database time, and returns the number
// of rows deleted. This should be called by a nightly job.
//
// retentionDays must be >= 0; a negative value returns an error without
// touching the database. Without this guard the result was backend-dependent:
// on PostgreSQL the cutoff `NOW() - INTERVAL '1 day' * $1` moves into the
// future and deletes every entry, while on SQLite the concatenated modifier
// becomes "--N days". A retentionDays of 0 deletes every entry older than now.
func (db *DB) DeleteOldAuditLogs(retentionDays int) (int64, error) {
	if retentionDays < 0 {
		return 0, fmt.Errorf("invalid audit log retention period: %d days (must be >= 0)", retentionDays)
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = `DELETE FROM task_access_log
WHERE accessed_at < datetime('now', '-' || ? || ' days')`
	} else {
		query = `DELETE FROM task_access_log
WHERE accessed_at < NOW() - INTERVAL '1 day' * $1`
	}

	result, err := db.conn.ExecContext(context.Background(), query, retentionDays)
	if err != nil {
		return 0, fmt.Errorf("failed to delete old audit logs: %w", err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", err)
	}
	return affected, nil
}
// CountAuditLogs returns the total number of rows in task_access_log.
//
// The statement is identical for SQLite and PostgreSQL (it takes no
// parameters), so the per-backend branch the original carried was redundant.
func (db *DB) CountAuditLogs() (int64, error) {
	const query = `SELECT COUNT(*) FROM task_access_log`

	var count int64
	if err := db.conn.QueryRowContext(context.Background(), query).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count audit logs: %w", err)
	}
	return count, nil
}
// GetOldestAuditLogDate returns the accessed_at timestamp of the oldest audit
// log entry, or (nil, nil) when the log has no entries with a timestamp.
//
// The oldest row is selected with ORDER BY ... LIMIT 1 rather than MIN().
// SQLite reports no declared type for aggregate result columns
// (sqlite3_column_decltype returns NULL for expressions). Drivers that decode
// DATETIME/TIMESTAMP values based on the declared type (e.g. mattn/go-sqlite3)
// therefore hand MIN(accessed_at) back as a plain string, which cannot be
// scanned into a time value. Selecting the column directly keeps its declared
// type. This assumes accessed_at's declared type is one the driver decodes as
// time (scanAuditLogEntries relies on the same thing) — TODO confirm against
// schema.
//
// NULL timestamps are excluded so the result matches MIN() semantics on both
// backends. SQLite sorts NULLs first in ASC order; PostgreSQL sorts them last.
// The query is portable, so no per-backend branch is needed.
func (db *DB) GetOldestAuditLogDate() (*time.Time, error) {
	const query = `SELECT accessed_at FROM task_access_log
WHERE accessed_at IS NOT NULL
ORDER BY accessed_at ASC
LIMIT 1`

	var oldest time.Time
	err := db.conn.QueryRowContext(context.Background(), query).Scan(&oldest)
	if errors.Is(err, sql.ErrNoRows) {
		// Empty log: no oldest entry.
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get oldest audit log date: %w", err)
	}
	return &oldest, nil
}