package storage_test

import (
	"path/filepath"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// auditTestSchema extends the base schema with the task_access_log table.
// task_access_log.task_id references jobs(id) with ON DELETE CASCADE.
const auditTestSchema = `
CREATE TABLE IF NOT EXISTS jobs (
	id TEXT PRIMARY KEY,
	job_name TEXT NOT NULL,
	args TEXT,
	status TEXT NOT NULL DEFAULT 'pending',
	priority INTEGER DEFAULT 0,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	started_at DATETIME,
	ended_at DATETIME,
	worker_id TEXT,
	error TEXT,
	datasets TEXT,
	metadata TEXT,
	updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS task_access_log (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	task_id TEXT NOT NULL,
	user_id TEXT,
	token TEXT,
	action TEXT NOT NULL,
	accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	ip_address TEXT,
	FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
`

// setupAuditTestDB creates a fresh database in a per-test temporary directory,
// initializes it with auditTestSchema, and registers a cleanup that closes it.
// Any setup failure aborts the calling test.
func setupAuditTestDB(t *testing.T) *storage.DB {
	t.Helper()

	// filepath.Join keeps the path correct regardless of the OS separator.
	dbPath := filepath.Join(t.TempDir(), "audit_test.db")

	db, err := storage.NewDBFromPath(dbPath)
	require.NoError(t, err, "Failed to create database")

	err = db.Initialize(auditTestSchema)
	require.NoError(t, err, "Failed to initialize database schema")

	// Best-effort close: the temp directory is discarded after the test anyway.
	t.Cleanup(func() { _ = db.Close() })

	return db
}

// createTestJobForAudit inserts a minimal pending job with the given id and
// returns it. The calling test is aborted if the insert fails.
func createTestJobForAudit(t *testing.T, db *storage.DB, id string) *storage.Job {
	t.Helper()

	job := &storage.Job{
		ID:     id,
		Status: "pending",
	}

	err := db.CreateJob(job)
	require.NoError(t, err, "Failed to create test job")

	return job
}

// TestLogTaskAccess tests logging task access events, including the error
// expected when the referenced task does not exist.
func TestLogTaskAccess(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-1")

	cases := []struct {
		name      string
		taskID    string
		userID    *string
		token     *string
		action    string
		ipAddress *string
		wantErr   bool
	}{
		{
			name:      "user view access",
			taskID:    job.ID,
			userID:    strPtr("user-123"),
			token:     nil,
			action:    "view",
			ipAddress: strPtr("192.168.1.1"),
			wantErr:   false,
		},
		{
			name:      "token-based clone access",
			taskID:    job.ID,
			userID:    nil,
			token:     strPtr("share-token-abc"),
			action:    "clone",
			ipAddress: strPtr("10.0.0.1"),
			wantErr:   false,
		},
		{
			name:      "execute action",
			taskID:    job.ID,
			userID:    strPtr("user-456"),
			token:     nil,
			action:    "execute",
			ipAddress: nil,
			wantErr:   false,
		},
		{
			// NOTE(review): presumably relies on the foreign key on task_id
			// being enforced by the storage layer; confirm.
			name:      "nonexistent task",
			taskID:    "nonexistent-task",
			userID:    strPtr("user-789"),
			token:     nil,
			action:    "view",
			ipAddress: strPtr("127.0.0.1"),
			wantErr:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			action := tc.action // local copy so a pointer can be passed

			err := db.LogTaskAccess(tc.taskID, tc.userID, tc.token, &action, tc.ipAddress)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
		})
	}
}

// TestGetAuditLogForTask tests retrieving audit log entries for a task.
func TestGetAuditLogForTask(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-2")

	// Log multiple access events.
	userID := "user-audit"
	action1 := "view"
	action2 := "clone"
	ip := "192.168.1.100"

	err := db.LogTaskAccess(job.ID, &userID, nil, &action1, &ip)
	require.NoError(t, err)

	err = db.LogTaskAccess(job.ID, &userID, nil, &action2, &ip)
	require.NoError(t, err)

	// Retrieve audit log.
	entries, err := db.GetAuditLogForTask(job.ID, 10)
	require.NoError(t, err)
	require.Len(t, entries, 2, "Should have 2 audit log entries")

	// Verify entries are ordered by accessed_at DESC.
	// NOTE(review): both rows are written back-to-back; this assumes the
	// storage layer orders same-timestamp rows newest-first. Confirm.
	assert.Equal(t, job.ID, entries[0].TaskID)
	assert.Equal(t, action2, entries[0].Action, "Most recent entry should be first")
	assert.Equal(t, action1, entries[1].Action)

	// Verify user ID is set correctly.
	require.NotNil(t, entries[0].UserID)
	assert.Equal(t, userID, *entries[0].UserID)
}

// TestGetAuditLogForTaskWithLimit tests the limit parameter.
func TestGetAuditLogForTaskWithLimit(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-3")

	// Log 5 access events.
	userID := "user-limit"
	action := "view"
	ip := "192.168.1.1"

	for i := 0; i < 5; i++ {
		err := db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
		require.NoError(t, err)
		time.Sleep(10 * time.Millisecond) // Ensure different timestamps
	}

	// Test with limit of 2.
	entries, err := db.GetAuditLogForTask(job.ID, 2)
	require.NoError(t, err)
	assert.Len(t, entries, 2, "Should respect limit parameter")

	// Limit 0 in SQL may return all or none depending on implementation, so
	// the result is deliberately discarded and only the error is checked.
	_, err = db.GetAuditLogForTask(job.ID, 0)
	require.NoError(t, err)
}

// TestGetAuditLogForUser tests retrieving audit log entries for a specific user.
func TestGetAuditLogForUser(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job1 := createTestJobForAudit(t, db, "audit-task-4a")
	job2 := createTestJobForAudit(t, db, "audit-task-4b")

	user1 := "user-specific"
	user2 := "user-other"
	action := "view"
	ip := "192.168.1.1"

	// Log access for user1 on both tasks.
	err := db.LogTaskAccess(job1.ID, &user1, nil, &action, &ip)
	require.NoError(t, err)

	err = db.LogTaskAccess(job2.ID, &user1, nil, &action, &ip)
	require.NoError(t, err)

	// Log access for user2.
	err = db.LogTaskAccess(job1.ID, &user2, nil, &action, &ip)
	require.NoError(t, err)

	// Get audit log for user1.
	entries, err := db.GetAuditLogForUser(user1, 10)
	require.NoError(t, err)
	require.Len(t, entries, 2, "Should have 2 entries for user1")

	// Verify both tasks are in the results.
	taskIDs := make(map[string]bool)
	for _, e := range entries {
		taskIDs[e.TaskID] = true

		// Guard the dereference so a NULL user_id fails the test cleanly
		// instead of panicking.
		require.NotNil(t, e.UserID)
		assert.Equal(t, user1, *e.UserID)
	}
	assert.True(t, taskIDs[job1.ID], "Should include job1")
	assert.True(t, taskIDs[job2.ID], "Should include job2")
}

// TestGetAuditLogForToken tests retrieving audit log by token.
func TestGetAuditLogForToken(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-5")

	token1 := "token-abc-123"
	token2 := "token-def-456"
	action := "clone"
	ip := "10.0.0.1"

	// Log 2 accesses with token1.
	err := db.LogTaskAccess(job.ID, nil, &token1, &action, &ip)
	require.NoError(t, err)

	err = db.LogTaskAccess(job.ID, nil, &token1, &action, &ip)
	require.NoError(t, err)

	// Log 1 access with token2.
	err = db.LogTaskAccess(job.ID, nil, &token2, &action, &ip)
	require.NoError(t, err)

	// Get audit log for token1.
	entries, err := db.GetAuditLogForToken(token1, 10)
	require.NoError(t, err)
	require.Len(t, entries, 2, "Should have 2 entries for token1")

	for _, e := range entries {
		require.NotNil(t, e.Token)
		assert.Equal(t, token1, *e.Token)
	}
}

// TestDeleteOldAuditLogs tests cleanup of old audit log entries.
func TestDeleteOldAuditLogs(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-6")

	// Log some access events.
	userID := "user-cleanup"
	action := "view"
	ip := "192.168.1.1"

	for i := 0; i < 3; i++ {
		err := db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
		require.NoError(t, err)
	}

	// Count before deletion.
	countBefore, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(3), countBefore)

	// Delete logs older than 0 days (should delete all).
	deleted, err := db.DeleteOldAuditLogs(0)
	require.NoError(t, err)
	assert.Equal(t, int64(3), deleted, "Should delete all 3 entries")

	// Verify count is 0.
	countAfter, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(0), countAfter)
}

// TestCountAuditLogs tests counting total audit log entries.
func TestCountAuditLogs(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)

	// Count when empty.
	count, err := db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(0), count)

	// Add some entries.
	job1 := createTestJobForAudit(t, db, "audit-task-7a")
	job2 := createTestJobForAudit(t, db, "audit-task-7b")

	userID := "user-count"
	action := "view"
	ip := "192.168.1.1"

	err = db.LogTaskAccess(job1.ID, &userID, nil, &action, &ip)
	require.NoError(t, err)

	err = db.LogTaskAccess(job2.ID, &userID, nil, &action, &ip)
	require.NoError(t, err)

	// Count again.
	count, err = db.CountAuditLogs()
	require.NoError(t, err)
	assert.Equal(t, int64(2), count)
}

// TestGetOldestAuditLogDate tests retrieving the oldest audit log date.
func TestGetOldestAuditLogDate(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)

	// When empty, should return nil.
	date, err := db.GetOldestAuditLogDate()
	require.NoError(t, err)
	assert.Nil(t, date)

	// Add an entry.
	job := createTestJobForAudit(t, db, "audit-task-8")
	userID := "user-date"
	action := "view"
	ip := "192.168.1.1"

	beforeLog := time.Now().UTC()
	err = db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
	require.NoError(t, err)
	afterLog := time.Now().UTC()

	// Get oldest date.
	date, err = db.GetOldestAuditLogDate()
	require.NoError(t, err)
	require.NotNil(t, date)

	// Verify the date is within expected range (allow 1 second tolerance for
	// test execution time).
	assert.WithinDuration(t, beforeLog, *date, time.Second, "Date should be close to current time")
	assert.WithinDuration(t, afterLog, *date, time.Second, "Date should be close to current time")
}

// TestAuditLogWithNilValues tests handling of NULL values in audit log.
func TestAuditLogWithNilValues(t *testing.T) {
	t.Parallel()

	db := setupAuditTestDB(t)
	job := createTestJobForAudit(t, db, "audit-task-9")

	action := "view"
	err := db.LogTaskAccess(job.ID, nil, nil, &action, nil)
	require.NoError(t, err)

	// Retrieve and verify.
	entries, err := db.GetAuditLogForTask(job.ID, 10)
	require.NoError(t, err)
	require.Len(t, entries, 1)

	entry := entries[0]
	assert.Nil(t, entry.UserID)
	assert.Nil(t, entry.Token)
	assert.Nil(t, entry.IPAddress)
	assert.Equal(t, "view", entry.Action)
}

// strPtr returns a pointer to a copy of s.
func strPtr(s string) *string {
	return &s
}