fetch_ml/internal/storage/db_audit_test.go
Jeremie Fraeys 50b6506243
test(storage): add comprehensive storage layer tests
Add tests for:
- dataset: Redis dataset operations, transfer tracking
- db_audit: audit logging with hash chain, access tracking
- db_experiments: experiment metadata, dataset associations
- db_tasks: task listing with pagination for users and groups
- db_jobs: job CRUD, state transitions, worker assignment

Coverage: storage package ~40%+
2026-03-13 23:26:33 -04:00

393 lines
10 KiB
Go

package storage_test
import (
	"path/filepath"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// auditTestSchema is the self-contained SQLite schema used by the audit
// tests. It creates a minimal jobs table, which holds the parent rows that
// audit entries point at. It also creates the task_access_log table, where
// each row records one access (user or token, action, IP, timestamp) to a task.
//
// task_access_log.task_id declares a foreign key to jobs(id) with
// ON DELETE CASCADE. NOTE(review): SQLite only enforces foreign keys when
// PRAGMA foreign_keys=ON is set on the connection. The "nonexistent task"
// case in TestLogTaskAccess expects an error, so this presumably relies on
// the storage package enabling that pragma (or checking the job itself).
// Confirm this.
const auditTestSchema = `
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
job_name TEXT NOT NULL,
args TEXT,
status TEXT NOT NULL DEFAULT 'pending',
priority INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
ended_at DATETIME,
worker_id TEXT,
error TEXT,
datasets TEXT,
metadata TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS task_access_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL,
user_id TEXT,
token TEXT,
action TEXT NOT NULL,
accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
ip_address TEXT,
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
`
// setupAuditTestDB opens a fresh SQLite database inside a per-test temporary
// directory, applies auditTestSchema, and returns the handle. Every call yields
// an isolated database, so tests using it may run in parallel. The handle is
// closed automatically when the test finishes.
func setupAuditTestDB(t *testing.T) *storage.DB {
	t.Helper()

	// filepath.Join produces an OS-correct path rather than hard-coding "/".
	dbPath := filepath.Join(t.TempDir(), "audit_test.db")

	db, err := storage.NewDBFromPath(dbPath)
	require.NoError(t, err, "Failed to create database")

	// Register the close immediately after a successful open, so the handle is
	// released even if schema initialization below fails the test. Cleanups
	// run in LIFO order, so this runs before t.TempDir removes the directory.
	// Close errors are ignored on purpose: this is best-effort teardown of a
	// throwaway database.
	t.Cleanup(func() {
		_ = db.Close()
	})

	err = db.Initialize(auditTestSchema)
	require.NoError(t, err, "Failed to initialize database schema")

	return db
}
// createTestJobForAudit inserts a minimal job with the given ID and status
// "pending", then returns it. Audit entries reference jobs(id) in
// auditTestSchema, so each audit test needs at least one parent job row.
// The test fails immediately if the insert returns an error.
func createTestJobForAudit(t *testing.T, db *storage.DB, id string) *storage.Job {
	t.Helper()
	job := &storage.Job{
		ID:     id,
		Status: "pending",
	}
	require.NoError(t, db.CreateJob(job), "Failed to create test job")
	return job
}
// TestLogTaskAccess drives LogTaskAccess through a table of access events.
// User-identified, token-identified and IP-less events against an existing
// job must succeed. An event for a task ID with no job row must fail.
func TestLogTaskAccess(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-1")

	type accessCase struct {
		name      string
		taskID    string
		userID    *string
		token     *string
		action    string
		ipAddress *string
		wantErr   bool
	}

	// Fields left out of a literal take their zero value (nil / false).
	tests := []accessCase{
		{
			name:      "user view access",
			taskID:    job.ID,
			userID:    strPtr("user-123"),
			action:    "view",
			ipAddress: strPtr("192.168.1.1"),
		},
		{
			name:      "token-based clone access",
			taskID:    job.ID,
			token:     strPtr("share-token-abc"),
			action:    "clone",
			ipAddress: strPtr("10.0.0.1"),
		},
		{
			name:   "execute action",
			taskID: job.ID,
			userID: strPtr("user-456"),
			action: "execute",
		},
		{
			name:      "nonexistent task",
			taskID:    "nonexistent-task",
			userID:    strPtr("user-789"),
			action:    "view",
			ipAddress: strPtr("127.0.0.1"),
			wantErr:   true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// LogTaskAccess takes the action by pointer; strPtr hands it a
			// fresh copy of the case's value.
			err := db.LogTaskAccess(tt.taskID, tt.userID, tt.token, strPtr(tt.action), tt.ipAddress)
			if !tt.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
		})
	}
}
// TestGetAuditLogForTask checks that GetAuditLogForTask returns every entry
// logged for a task, most recent first, with the task and user IDs filled in.
func TestGetAuditLogForTask(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-2")

	userID := "user-audit"
	action1 := "view"
	action2 := "clone"
	ip := "192.168.1.100"

	err := db.LogTaskAccess(job.ID, &userID, nil, &action1, &ip)
	require.NoError(t, err)

	// The ordering assertion below needs the two entries to have distinct
	// accessed_at values. Back-to-back inserts can share a timestamp, and
	// their relative order is then unspecified. The gap matches
	// TestGetAuditLogForTaskWithLimit.
	// NOTE(review): if accessed_at ends up with one-second precision (the
	// schema default is SQLite CURRENT_TIMESTAMP), this gap is too short.
	// Confirm how LogTaskAccess stamps entries.
	time.Sleep(10 * time.Millisecond)

	err = db.LogTaskAccess(job.ID, &userID, nil, &action2, &ip)
	require.NoError(t, err)

	entries, err := db.GetAuditLogForTask(job.ID, 10)
	require.NoError(t, err)
	require.Len(t, entries, 2, "Should have 2 audit log entries")

	// Newest first: the "clone" logged second must come back first.
	assert.Equal(t, job.ID, entries[0].TaskID)
	assert.Equal(t, job.ID, entries[1].TaskID)
	assert.Equal(t, action2, entries[0].Action, "Most recent entry should be first")
	assert.Equal(t, action1, entries[1].Action)

	require.NotNil(t, entries[0].UserID)
	assert.Equal(t, userID, *entries[0].UserID)
}
// TestGetAuditLogForTaskWithLimit checks how GetAuditLogForTask applies its
// limit argument. A limit below the row count truncates the result. A limit
// above it returns every row. A limit of 0 must not error.
func TestGetAuditLogForTaskWithLimit(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-3")

	const total = 5
	userID := "user-limit"
	action := "view"
	ip := "192.168.1.1"
	for i := 0; i < total; i++ {
		err := db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
		require.NoError(t, err)
		time.Sleep(10 * time.Millisecond) // Ensure different timestamps
	}

	// Limit smaller than the number of rows: result is truncated.
	entries, err := db.GetAuditLogForTask(job.ID, 2)
	require.NoError(t, err)
	assert.Len(t, entries, 2, "Should respect limit parameter")

	// Limit larger than the number of rows: every row is returned.
	entries, err = db.GetAuditLogForTask(job.ID, 10)
	require.NoError(t, err)
	assert.Len(t, entries, total, "Limit above row count should return all entries")

	// Limit 0 is implementation-defined. A plain SQL "LIMIT 0" returns no rows,
	// but the method may treat 0 as "unlimited". So only check that the call
	// succeeds and never returns more rows than exist.
	entries, err = db.GetAuditLogForTask(job.ID, 0)
	require.NoError(t, err)
	assert.LessOrEqual(t, len(entries), total)
}
// TestGetAuditLogForUser checks that GetAuditLogForUser returns only the
// entries recorded for the requested user, across every task that user
// accessed.
func TestGetAuditLogForUser(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)
	job1 := createTestJobForAudit(t, db, "audit-task-4a")
	job2 := createTestJobForAudit(t, db, "audit-task-4b")

	user1 := "user-specific"
	user2 := "user-other"
	action := "view"
	ip := "192.168.1.1"

	// user1 accesses both tasks.
	err := db.LogTaskAccess(job1.ID, &user1, nil, &action, &ip)
	require.NoError(t, err)
	err = db.LogTaskAccess(job2.ID, &user1, nil, &action, &ip)
	require.NoError(t, err)

	// user2 accesses job1 only. This entry must not show up in user1's log.
	err = db.LogTaskAccess(job1.ID, &user2, nil, &action, &ip)
	require.NoError(t, err)

	entries, err := db.GetAuditLogForUser(user1, 10)
	require.NoError(t, err)
	require.Len(t, entries, 2, "Should have 2 entries for user1")

	taskIDs := make(map[string]bool, len(entries))
	for _, e := range entries {
		taskIDs[e.TaskID] = true
		// Guard the dereference so a NULL user_id fails the test cleanly
		// instead of panicking.
		require.NotNil(t, e.UserID)
		assert.Equal(t, user1, *e.UserID)
	}
	assert.True(t, taskIDs[job1.ID], "Should include job1")
	assert.True(t, taskIDs[job2.ID], "Should include job2")
}
// TestGetAuditLogForToken checks that GetAuditLogForToken returns exactly the
// entries logged with the requested share token and none logged with others.
func TestGetAuditLogForToken(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-5")

	wanted := "token-abc-123"
	other := "token-def-456"
	action := "clone"
	ip := "10.0.0.1"

	// Two accesses through the token under test, then one through an
	// unrelated token, in that order.
	for _, tok := range []*string{&wanted, &wanted, &other} {
		err := db.LogTaskAccess(job.ID, nil, tok, &action, &ip)
		require.NoError(t, err)
	}

	got, err := db.GetAuditLogForToken(wanted, 10)
	require.NoError(t, err)
	require.Len(t, got, 2, "Should have 2 entries for token1")

	for i := range got {
		require.NotNil(t, got[i].Token)
		assert.Equal(t, wanted, *got[i].Token)
	}
}
// TestDeleteOldAuditLogs checks that DeleteOldAuditLogs with a zero-day
// retention removes every existing entry and reports how many it removed.
func TestDeleteOldAuditLogs(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-6")

	const logged = 3
	userID, action, ip := "user-cleanup", "view", "192.168.1.1"
	for n := 0; n < logged; n++ {
		require.NoError(t, db.LogTaskAccess(job.ID, &userID, nil, &action, &ip))
	}

	before, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(logged), before)

	// With a cutoff of 0 days, every entry logged above is expected to count
	// as old.
	removed, err := db.DeleteOldAuditLogs(0)
	require.NoError(t, err)
	assert.Equal(t, int64(logged), removed, "Should delete all 3 entries")

	after, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(0), after)
}
// TestCountAuditLogs checks that CountAuditLogs reports zero on an empty
// table and then counts entries across multiple tasks.
func TestCountAuditLogs(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)

	initial, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(0), initial)

	// Create both parent jobs first, then log one access against each.
	jobs := []*storage.Job{
		createTestJobForAudit(t, db, "audit-task-7a"),
		createTestJobForAudit(t, db, "audit-task-7b"),
	}
	userID, action, ip := "user-count", "view", "192.168.1.1"
	for _, j := range jobs {
		err = db.LogTaskAccess(j.ID, &userID, nil, &action, &ip)
		require.NoError(t, err)
	}

	final, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(2), final)
}
// TestGetOldestAuditLogDate checks that GetOldestAuditLogDate returns nil
// for an empty log. After one entry is written, it must return a timestamp
// within a second of when that entry was logged.
func TestGetOldestAuditLogDate(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)

	// Nothing logged yet.
	oldest, err := db.GetOldestAuditLogDate()
	require.NoError(t, err)
	assert.Nil(t, oldest)

	job := createTestJobForAudit(t, db, "audit-task-8")
	userID, action, ip := "user-date", "view", "192.168.1.1"

	start := time.Now().UTC()
	err = db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
	require.NoError(t, err)
	end := time.Now().UTC()

	oldest, err = db.GetOldestAuditLogDate()
	require.NoError(t, err)
	require.NotNil(t, oldest)

	// Compare against both bracketing instants. The one-second slack absorbs
	// test execution time and any timestamp truncation in storage.
	const slack = time.Second
	for _, ref := range []time.Time{start, end} {
		assert.WithinDuration(t, ref, *oldest, slack, "Date should be close to current time")
	}
}
// TestAuditLogWithNilValues checks that an entry logged with no user, token
// or IP address reads back with those fields as nil and the action intact.
func TestAuditLogWithNilValues(t *testing.T) {
	t.Parallel()
	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-9")

	// Only the action is supplied; every optional field is nil.
	require.NoError(t, db.LogTaskAccess(job.ID, nil, nil, strPtr("view"), nil))

	entries, err := db.GetAuditLogForTask(job.ID, 10)
	require.NoError(t, err)
	require.Len(t, entries, 1)

	got := entries[0]
	assert.Nil(t, got.UserID)
	assert.Nil(t, got.Token)
	assert.Nil(t, got.IPAddress)
	assert.Equal(t, "view", got.Action)
}
// strPtr returns a pointer to a fresh copy of s. Table-driven tests use it to
// write optional *string arguments inline.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}