package storage import ( "context" "database/sql" "fmt" "log" "time" ) // AuditLogEntry represents a single audit log entry type AuditLogEntry struct { ID int64 `json:"id"` TaskID string `json:"task_id"` UserID *string `json:"user_id,omitempty"` Token *string `json:"token,omitempty"` Action string `json:"action"` AccessedAt time.Time `json:"accessed_at"` IPAddress *string `json:"ip_address,omitempty"` } // LogTaskAccess records an access event in the audit log. func (db *DB) LogTaskAccess(taskID string, userID, token, action, ipAddress *string) error { var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address) VALUES (?, ?, ?, ?, ?)` } else { query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address) VALUES ($1, $2, $3, $4, $5)` } _, err := db.conn.ExecContext(context.Background(), query, taskID, userID, token, action, ipAddress) if err != nil { return fmt.Errorf("failed to log task access: %w", err) } return nil } // GetAuditLogForTask retrieves audit log entries for a specific task. func (db *DB) GetAuditLogForTask(taskID string, limit int) ([]*AuditLogEntry, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address FROM task_access_log WHERE task_id = ? ORDER BY accessed_at DESC LIMIT ?` } else { query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address FROM task_access_log WHERE task_id = $1 ORDER BY accessed_at DESC LIMIT $2` } rows, err := db.conn.QueryContext(context.Background(), query, taskID, limit) if err != nil { return nil, fmt.Errorf("failed to get audit log: %w", err) } defer func() { if err := rows.Close(); err != nil { log.Printf("ERROR: failed to close rows: %v", err) } }() return db.scanAuditLogEntries(rows) } // GetAuditLogForUser retrieves audit log entries for a specific user. func (db *DB) GetAuditLogForUser(userID string, limit int) ([]*AuditLogEntry, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address FROM task_access_log WHERE user_id = ? ORDER BY accessed_at DESC LIMIT ?` } else { query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address FROM task_access_log WHERE user_id = $1 ORDER BY accessed_at DESC LIMIT $2` } rows, err := db.conn.QueryContext(context.Background(), query, userID, limit) if err != nil { return nil, fmt.Errorf("failed to get audit log: %w", err) } defer func() { if err := rows.Close(); err != nil { log.Printf("ERROR: failed to close rows: %v", err) } }() return db.scanAuditLogEntries(rows) } // GetAuditLogForToken retrieves audit log entries for a specific token. func (db *DB) GetAuditLogForToken(token string, limit int) ([]*AuditLogEntry, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address FROM task_access_log WHERE token = ? ORDER BY accessed_at DESC LIMIT ?` } else { query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address FROM task_access_log WHERE token = $1 ORDER BY accessed_at DESC LIMIT $2` } rows, err := db.conn.QueryContext(context.Background(), query, token, limit) if err != nil { return nil, fmt.Errorf("failed to get audit log: %w", err) } defer func() { if err := rows.Close(); err != nil { log.Printf("ERROR: failed to close rows: %v", err) } }() return db.scanAuditLogEntries(rows) } // scanAuditLogEntries scans rows into AuditLogEntry structs. func (db *DB) scanAuditLogEntries(rows *sql.Rows) ([]*AuditLogEntry, error) { var entries []*AuditLogEntry for rows.Next() { var e AuditLogEntry var userID, token, ipAddress sql.NullString err := rows.Scan(&e.ID, &e.TaskID, &userID, &token, &e.Action, &e.AccessedAt, &ipAddress) if err != nil { return nil, fmt.Errorf("failed to scan audit log entry: %w", err) } if userID.Valid { e.UserID = &userID.String } if token.Valid { e.Token = &token.String } if ipAddress.Valid { e.IPAddress = &ipAddress.String } entries = append(entries, &e) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating audit log entries: %w", err) } return entries, nil } // DeleteOldAuditLogs removes audit log entries older than the retention period. // This should be called by a nightly job. func (db *DB) DeleteOldAuditLogs(retentionDays int) (int64, error) { var query string if db.dbType == DBTypeSQLite { query = `DELETE FROM task_access_log WHERE accessed_at < datetime('now', '-' || ? || ' days')` } else { query = `DELETE FROM task_access_log WHERE accessed_at < NOW() - INTERVAL '1 day' * $1` } result, err := db.conn.ExecContext(context.Background(), query, retentionDays) if err != nil { return 0, fmt.Errorf("failed to delete old audit logs: %w", err) } affected, err := result.RowsAffected() if err != nil { return 0, fmt.Errorf("failed to get rows affected: %w", err) } return affected, nil } // CountAuditLogs returns the total number of audit log entries. func (db *DB) CountAuditLogs() (int64, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT COUNT(*) FROM task_access_log` } else { query = `SELECT COUNT(*) FROM task_access_log` } var count int64 err := db.conn.QueryRowContext(context.Background(), query).Scan(&count) if err != nil { return 0, fmt.Errorf("failed to count audit logs: %w", err) } return count, nil } // GetOldestAuditLogDate returns the date of the oldest audit log entry. func (db *DB) GetOldestAuditLogDate() (*time.Time, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT MIN(accessed_at) FROM task_access_log` } else { query = `SELECT MIN(accessed_at) FROM task_access_log` } var date sql.NullTime err := db.conn.QueryRowContext(context.Background(), query).Scan(&date) if err != nil { return nil, fmt.Errorf("failed to get oldest audit log date: %w", err) } if !date.Valid { return nil, nil } return &date.Time, nil }