fetch_ml/internal/storage/db_audit.go
Jeremie Fraeys 4da027868d
fix(storage): handle NULL values and state tracking in database operations
Fixes to support proper test coverage:

- db_jobs.go: UpdateJobStatus now checks RowsAffected and returns error
  for nonexistent jobs instead of silently succeeding
- db_audit.go: GetOldestAuditLogDate uses sql.NullString to parse SQLite
  datetime strings in YYYY-MM-DD HH:MM:SS format with RFC3339 fallback
- db_experiments.go: ListTasksForExperiment uses sql.NullString for
  nullable worker_id and error fields to prevent scan errors
- db_connect.go: DB struct adds isClosed state tracking with mutex;
  Close() now returns error on double close to match test expectations
2026-03-13 23:27:35 -04:00

235 lines
6.5 KiB
Go

package storage
import (
"context"
"database/sql"
"fmt"
"log"
"time"
)
// AuditLogEntry represents a single audit log entry: one row of the
// task_access_log table as read back by the GetAuditLogFor* queries.
// Nullable columns are represented as pointers that are nil for SQL NULL.
type AuditLogEntry struct {
	// ID is the row's id column.
	ID int64 `json:"id"`
	// TaskID is the task_id column the entry was recorded for.
	TaskID string `json:"task_id"`
	// UserID is the user_id column, or nil when it is NULL.
	UserID *string `json:"user_id,omitempty"`
	// Token is the token column, or nil when it is NULL.
	Token *string `json:"token,omitempty"`
	// Action is the action string stored by LogTaskAccess.
	Action string `json:"action"`
	// AccessedAt is the accessed_at column; the read queries order by it,
	// newest first.
	AccessedAt time.Time `json:"accessed_at"`
	// IPAddress is the ip_address column, or nil when it is NULL.
	IPAddress *string `json:"ip_address,omitempty"`
}
// LogTaskAccess inserts one row into task_access_log describing an access to
// taskID. Any of userID, token, action and ipAddress may be nil; a nil
// pointer is stored as SQL NULL. accessed_at is not supplied here, so it is
// left to the column's database-side default.
func (db *DB) LogTaskAccess(taskID string, userID, token, action, ipAddress *string) error {
	// PostgreSQL-style numbered placeholders unless the backend is SQLite.
	stmt := `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
VALUES ($1, $2, $3, $4, $5)`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
VALUES (?, ?, ?, ?, ?)`
	}

	if _, execErr := db.conn.ExecContext(context.Background(), stmt, taskID, userID, token, action, ipAddress); execErr != nil {
		return fmt.Errorf("failed to log task access: %w", execErr)
	}
	return nil
}
// GetAuditLogForTask returns up to limit audit log entries whose task_id
// equals taskID, newest accessed_at first. The limit value is passed straight
// to the SQL LIMIT clause.
func (db *DB) GetAuditLogForTask(taskID string, limit int) ([]*AuditLogEntry, error) {
	var sqlText string
	switch db.dbType {
	case DBTypeSQLite:
		sqlText = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE task_id = ?
ORDER BY accessed_at DESC
LIMIT ?`
	default:
		sqlText = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE task_id = $1
ORDER BY accessed_at DESC
LIMIT $2`
	}

	rs, qErr := db.conn.QueryContext(context.Background(), sqlText, taskID, limit)
	if qErr != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", qErr)
	}
	// A close failure is only logged; it does not override the scan result.
	defer func() {
		if closeErr := rs.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rs)
}
// GetAuditLogForUser returns up to limit audit log entries whose user_id
// equals userID, ordered by accessed_at descending (most recent first).
func (db *DB) GetAuditLogForUser(userID string, limit int) ([]*AuditLogEntry, error) {
	// Start from the numbered-placeholder form and swap in the SQLite
	// variant when that backend is in use.
	query := `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE user_id = $1
ORDER BY accessed_at DESC
LIMIT $2`
	if db.dbType == DBTypeSQLite {
		query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE user_id = ?
ORDER BY accessed_at DESC
LIMIT ?`
	}

	resultRows, err := db.conn.QueryContext(context.Background(), query, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", err)
	}

	closeRows := func() {
		if cerr := resultRows.Close(); cerr != nil {
			log.Printf("ERROR: failed to close rows: %v", cerr)
		}
	}
	defer closeRows()

	return db.scanAuditLogEntries(resultRows)
}
// GetAuditLogForToken returns up to limit audit log entries recorded with the
// given token, most recent accessed_at first.
func (db *DB) GetAuditLogForToken(token string, limit int) ([]*AuditLogEntry, error) {
	const (
		sqliteQuery = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE token = ?
ORDER BY accessed_at DESC
LIMIT ?`
		postgresQuery = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE token = $1
ORDER BY accessed_at DESC
LIMIT $2`
	)

	query := postgresQuery
	if db.dbType == DBTypeSQLite {
		query = sqliteQuery
	}

	cursor, err := db.conn.QueryContext(context.Background(), query, token, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", err)
	}
	defer func() {
		// Logged rather than returned so it cannot mask the scan outcome.
		if closeErr := cursor.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(cursor)
}
// scanAuditLogEntries drains rows into AuditLogEntry values. Each row must
// carry, in order: id, task_id, user_id, token, action, accessed_at,
// ip_address. NULL user_id, token and ip_address values become nil pointers.
//
// rows is not closed here; the caller owns it. When rows yields nothing the
// result is a nil slice, not an empty one.
func (db *DB) scanAuditLogEntries(rows *sql.Rows) ([]*AuditLogEntry, error) {
	// optional maps a nullable text column to *string (nil for SQL NULL).
	optional := func(ns sql.NullString) *string {
		if !ns.Valid {
			return nil
		}
		v := ns.String
		return &v
	}

	var result []*AuditLogEntry
	for rows.Next() {
		entry := new(AuditLogEntry)
		var user, tok, ip sql.NullString
		if scanErr := rows.Scan(&entry.ID, &entry.TaskID, &user, &tok, &entry.Action, &entry.AccessedAt, &ip); scanErr != nil {
			return nil, fmt.Errorf("failed to scan audit log entry: %w", scanErr)
		}
		entry.UserID, entry.Token, entry.IPAddress = optional(user), optional(tok), optional(ip)
		result = append(result, entry)
	}

	// Next returning false may hide an iteration error; surface it.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating audit log entries: %w", iterErr)
	}
	return result, nil
}
// DeleteOldAuditLogs removes audit log entries whose accessed_at is older than
// retentionDays days before now, and returns the number of rows deleted.
// This should be called by a nightly job.
//
// retentionDays must be non-negative; a negative value is rejected before any
// statement runs. Without that guard the backends disagree badly. On
// PostgreSQL, NOW() - INTERVAL '1 day' * (-N) moves the cutoff N days into
// the future, which would delete every entry. On SQLite, the modifier becomes
// the malformed '--N days'. A retention of 0 deletes everything older than now.
func (db *DB) DeleteOldAuditLogs(retentionDays int) (int64, error) {
	if retentionDays < 0 {
		return 0, fmt.Errorf("invalid audit log retention period %d days: must not be negative", retentionDays)
	}

	var query string
	if db.dbType == DBTypeSQLite {
		// The day count is bound as a parameter and concatenated into a
		// date modifier such as '-30 days'.
		// NOTE(review): this is a text comparison against datetime()'s
		// 'YYYY-MM-DD HH:MM:SS' output, so it assumes accessed_at is stored in
		// that same form - confirm against the schema.
		query = `DELETE FROM task_access_log
WHERE accessed_at < datetime('now', '-' || ? || ' days')`
	} else {
		query = `DELETE FROM task_access_log
WHERE accessed_at < NOW() - INTERVAL '1 day' * $1`
	}

	result, err := db.conn.ExecContext(context.Background(), query, retentionDays)
	if err != nil {
		return 0, fmt.Errorf("failed to delete old audit logs: %w", err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", err)
	}
	return affected, nil
}
// CountAuditLogs reports how many rows task_access_log currently holds.
func (db *DB) CountAuditLogs() (int64, error) {
	// COUNT(*) is portable, so SQLite and PostgreSQL share one statement.
	const query = `SELECT COUNT(*) FROM task_access_log`

	var total int64
	if err := db.conn.QueryRowContext(context.Background(), query).Scan(&total); err != nil {
		return 0, fmt.Errorf("failed to count audit logs: %w", err)
	}
	return total, nil
}
// GetOldestAuditLogDate returns the accessed_at timestamp of the oldest audit
// log entry, or (nil, nil) when the table is empty.
//
// MIN(accessed_at) is scanned into a sql.NullString rather than a time.Time.
// SQLite's MIN() result carries no declared column type, so the value may come
// back as plain text. If a driver returns a time.Time instead, database/sql
// renders it as RFC 3339 text for a string destination. The text is then
// matched against several common timestamp layouts. An error is returned only
// if none of them fit; it names the offending value.
func (db *DB) GetOldestAuditLogDate() (*time.Time, error) {
	// The statement is identical for SQLite and PostgreSQL.
	const query = `SELECT MIN(accessed_at) FROM task_access_log`

	var dateStr sql.NullString
	if err := db.conn.QueryRowContext(context.Background(), query).Scan(&dateStr); err != nil {
		return nil, fmt.Errorf("failed to get oldest audit log date: %w", err)
	}
	if !dateStr.Valid || dateStr.String == "" {
		return nil, nil
	}

	// Candidate layouts, tried in order. Values without an offset are read as
	// UTC (time.Parse's default). time.Parse also accepts fractional seconds
	// right after the seconds field even though no layout spells them out.
	layouts := [...]string{
		"2006-01-02 15:04:05",       // SQLite datetime() / CURRENT_TIMESTAMP text
		time.RFC3339,                // ISO 8601 with offset, e.g. a formatted time.Time
		"2006-01-02 15:04:05Z07:00", // space separator with offset, e.g. "2024-03-01 10:20:30-04:00"
		"2006-01-02T15:04:05",       // ISO 8601 without offset
	}
	var parseErr error
	for _, layout := range layouts {
		t, err := time.Parse(layout, dateStr.String)
		if err == nil {
			return &t, nil
		}
		parseErr = err
	}
	return nil, fmt.Errorf("failed to parse date %q: %w", dateStr.String, parseErr)
}