Add tests for:
- dataset: Redis dataset operations, transfer tracking
- db_audit: audit logging with hash chain, access tracking
- db_experiments: experiment metadata, dataset associations
- db_tasks: task listing with pagination for users and groups
- db_jobs: job CRUD, state transitions, worker assignment

Coverage: storage package ~40%+
393 lines
10 KiB
Go
393 lines
10 KiB
Go
package storage_test
|
|
|
|
import (
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// auditTestSchema is the SQL applied to every audit test database. It
// declares two tables:
//
//   - jobs: one row per job, keyed by a text id.
//   - task_access_log: one row per access event. Its task_id column references
//     jobs(id) with ON DELETE CASCADE.
//
// Both statements use IF NOT EXISTS, so applying the schema to a database
// that already has these tables does not fail.
const auditTestSchema = `
CREATE TABLE IF NOT EXISTS jobs (
	id TEXT PRIMARY KEY,
	job_name TEXT NOT NULL,
	args TEXT,
	status TEXT NOT NULL DEFAULT 'pending',
	priority INTEGER DEFAULT 0,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	started_at DATETIME,
	ended_at DATETIME,
	worker_id TEXT,
	error TEXT,
	datasets TEXT,
	metadata TEXT,
	updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS task_access_log (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	task_id TEXT NOT NULL,
	user_id TEXT,
	token TEXT,
	action TEXT NOT NULL,
	accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	ip_address TEXT,
	FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
`
|
|
|
|
// setupAuditTestDB creates a fresh database with audit schema for each test
|
|
func setupAuditTestDB(t *testing.T) *storage.DB {
|
|
t.Helper()
|
|
dbPath := t.TempDir() + "/audit_test.db"
|
|
db, err := storage.NewDBFromPath(dbPath)
|
|
require.NoError(t, err, "Failed to create database")
|
|
|
|
err = db.Initialize(auditTestSchema)
|
|
require.NoError(t, err, "Failed to initialize database schema")
|
|
|
|
t.Cleanup(func() {
|
|
_ = db.Close()
|
|
})
|
|
|
|
return db
|
|
}
|
|
|
|
// createTestJob creates a minimal job for audit testing
|
|
func createTestJobForAudit(t *testing.T, db *storage.DB, id string) *storage.Job {
|
|
t.Helper()
|
|
job := &storage.Job{
|
|
ID: id,
|
|
Status: "pending",
|
|
}
|
|
err := db.CreateJob(job)
|
|
require.NoError(t, err, "Failed to create test job")
|
|
return job
|
|
}
|
|
|
|
// TestLogTaskAccess tests logging task access events
|
|
func TestLogTaskAccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job := createTestJobForAudit(t, db, "audit-task-1")
|
|
|
|
cases := []struct {
|
|
name string
|
|
taskID string
|
|
userID *string
|
|
token *string
|
|
action string
|
|
ipAddress *string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "user view access",
|
|
taskID: job.ID,
|
|
userID: strPtr("user-123"),
|
|
token: nil,
|
|
action: "view",
|
|
ipAddress: strPtr("192.168.1.1"),
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "token-based clone access",
|
|
taskID: job.ID,
|
|
userID: nil,
|
|
token: strPtr("share-token-abc"),
|
|
action: "clone",
|
|
ipAddress: strPtr("10.0.0.1"),
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "execute action",
|
|
taskID: job.ID,
|
|
userID: strPtr("user-456"),
|
|
token: nil,
|
|
action: "execute",
|
|
ipAddress: nil,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "nonexistent task",
|
|
taskID: "nonexistent-task",
|
|
userID: strPtr("user-789"),
|
|
token: nil,
|
|
action: "view",
|
|
ipAddress: strPtr("127.0.0.1"),
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
action := tc.action // capture for pointer
|
|
err := db.LogTaskAccess(tc.taskID, tc.userID, tc.token, &action, tc.ipAddress)
|
|
if tc.wantErr {
|
|
require.Error(t, err)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetAuditLogForTask tests retrieving audit log entries for a task
|
|
func TestGetAuditLogForTask(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job := createTestJobForAudit(t, db, "audit-task-2")
|
|
|
|
// Log multiple access events
|
|
userID := "user-audit"
|
|
action1 := "view"
|
|
action2 := "clone"
|
|
ip := "192.168.1.100"
|
|
|
|
err := db.LogTaskAccess(job.ID, &userID, nil, &action1, &ip)
|
|
require.NoError(t, err)
|
|
|
|
err = db.LogTaskAccess(job.ID, &userID, nil, &action2, &ip)
|
|
require.NoError(t, err)
|
|
|
|
// Retrieve audit log
|
|
entries, err := db.GetAuditLogForTask(job.ID, 10)
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 2, "Should have 2 audit log entries")
|
|
|
|
// Verify entries are ordered by accessed_at DESC
|
|
assert.Equal(t, job.ID, entries[0].TaskID)
|
|
assert.Equal(t, action2, entries[0].Action, "Most recent entry should be first")
|
|
assert.Equal(t, action1, entries[1].Action)
|
|
|
|
// Verify user ID is set correctly
|
|
require.NotNil(t, entries[0].UserID)
|
|
assert.Equal(t, userID, *entries[0].UserID)
|
|
}
|
|
|
|
// TestGetAuditLogForTaskWithLimit tests the limit parameter
|
|
func TestGetAuditLogForTaskWithLimit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job := createTestJobForAudit(t, db, "audit-task-3")
|
|
|
|
// Log 5 access events
|
|
userID := "user-limit"
|
|
action := "view"
|
|
ip := "192.168.1.1"
|
|
|
|
for i := 0; i < 5; i++ {
|
|
err := db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
time.Sleep(10 * time.Millisecond) // Ensure different timestamps
|
|
}
|
|
|
|
// Test with limit of 2
|
|
entries, err := db.GetAuditLogForTask(job.ID, 2)
|
|
require.NoError(t, err)
|
|
assert.Len(t, entries, 2, "Should respect limit parameter")
|
|
|
|
// Test with limit of 0 (should still return results, just limited by SQL)
|
|
entries, err = db.GetAuditLogForTask(job.ID, 0)
|
|
require.NoError(t, err)
|
|
// Limit 0 in SQL may return all or none depending on implementation
|
|
// Just verify it doesn't error
|
|
}
|
|
|
|
// TestGetAuditLogForUser tests retrieving audit log entries for a specific user
|
|
func TestGetAuditLogForUser(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job1 := createTestJobForAudit(t, db, "audit-task-4a")
|
|
job2 := createTestJobForAudit(t, db, "audit-task-4b")
|
|
|
|
user1 := "user-specific"
|
|
user2 := "user-other"
|
|
action := "view"
|
|
ip := "192.168.1.1"
|
|
|
|
// Log access for user1 on both tasks
|
|
err := db.LogTaskAccess(job1.ID, &user1, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
err = db.LogTaskAccess(job2.ID, &user1, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
|
|
// Log access for user2
|
|
err = db.LogTaskAccess(job1.ID, &user2, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
|
|
// Get audit log for user1
|
|
entries, err := db.GetAuditLogForUser(user1, 10)
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 2, "Should have 2 entries for user1")
|
|
|
|
// Verify both tasks are in the results
|
|
taskIDs := make(map[string]bool)
|
|
for _, e := range entries {
|
|
taskIDs[e.TaskID] = true
|
|
assert.Equal(t, user1, *e.UserID)
|
|
}
|
|
assert.True(t, taskIDs[job1.ID], "Should include job1")
|
|
assert.True(t, taskIDs[job2.ID], "Should include job2")
|
|
}
|
|
|
|
// TestGetAuditLogForToken tests retrieving audit log by token
|
|
func TestGetAuditLogForToken(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job := createTestJobForAudit(t, db, "audit-task-5")
|
|
|
|
token1 := "token-abc-123"
|
|
token2 := "token-def-456"
|
|
action := "clone"
|
|
ip := "10.0.0.1"
|
|
|
|
// Log 2 accesses with token1
|
|
err := db.LogTaskAccess(job.ID, nil, &token1, &action, &ip)
|
|
require.NoError(t, err)
|
|
err = db.LogTaskAccess(job.ID, nil, &token1, &action, &ip)
|
|
require.NoError(t, err)
|
|
|
|
// Log 1 access with token2
|
|
err = db.LogTaskAccess(job.ID, nil, &token2, &action, &ip)
|
|
require.NoError(t, err)
|
|
|
|
// Get audit log for token1
|
|
entries, err := db.GetAuditLogForToken(token1, 10)
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 2, "Should have 2 entries for token1")
|
|
|
|
for _, e := range entries {
|
|
require.NotNil(t, e.Token)
|
|
assert.Equal(t, token1, *e.Token)
|
|
}
|
|
}
|
|
|
|
// TestDeleteOldAuditLogs tests cleanup of old audit log entries
|
|
func TestDeleteOldAuditLogs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job := createTestJobForAudit(t, db, "audit-task-6")
|
|
|
|
// Log some access events
|
|
userID := "user-cleanup"
|
|
action := "view"
|
|
ip := "192.168.1.1"
|
|
|
|
for i := 0; i < 3; i++ {
|
|
err := db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Count before deletion
|
|
countBefore, err := db.CountAuditLogs()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(3), countBefore)
|
|
|
|
// Delete logs older than 0 days (should delete all)
|
|
deleted, err := db.DeleteOldAuditLogs(0)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(3), deleted, "Should delete all 3 entries")
|
|
|
|
// Verify count is 0
|
|
countAfter, err := db.CountAuditLogs()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(0), countAfter)
|
|
}
|
|
|
|
// TestCountAuditLogs tests counting total audit log entries
|
|
func TestCountAuditLogs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
|
|
// Count when empty
|
|
count, err := db.CountAuditLogs()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(0), count)
|
|
|
|
// Add some entries
|
|
job1 := createTestJobForAudit(t, db, "audit-task-7a")
|
|
job2 := createTestJobForAudit(t, db, "audit-task-7b")
|
|
|
|
userID := "user-count"
|
|
action := "view"
|
|
ip := "192.168.1.1"
|
|
|
|
err = db.LogTaskAccess(job1.ID, &userID, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
err = db.LogTaskAccess(job2.ID, &userID, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
|
|
// Count again
|
|
count, err = db.CountAuditLogs()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), count)
|
|
}
|
|
|
|
// TestGetOldestAuditLogDate tests retrieving the oldest audit log date
|
|
func TestGetOldestAuditLogDate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
|
|
// When empty, should return nil
|
|
date, err := db.GetOldestAuditLogDate()
|
|
require.NoError(t, err)
|
|
assert.Nil(t, date)
|
|
|
|
// Add an entry
|
|
job := createTestJobForAudit(t, db, "audit-task-8")
|
|
userID := "user-date"
|
|
action := "view"
|
|
ip := "192.168.1.1"
|
|
|
|
beforeLog := time.Now().UTC()
|
|
err = db.LogTaskAccess(job.ID, &userID, nil, &action, &ip)
|
|
require.NoError(t, err)
|
|
afterLog := time.Now().UTC()
|
|
|
|
// Get oldest date
|
|
date, err = db.GetOldestAuditLogDate()
|
|
require.NoError(t, err)
|
|
require.NotNil(t, date)
|
|
|
|
// Verify the date is within expected range (allow 1 second tolerance for test execution time)
|
|
assert.WithinDuration(t, beforeLog, *date, time.Second, "Date should be close to current time")
|
|
assert.WithinDuration(t, afterLog, *date, time.Second, "Date should be close to current time")
|
|
}
|
|
|
|
// TestAuditLogWithNilValues tests handling of NULL values in audit log
|
|
func TestAuditLogWithNilValues(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db := setupAuditTestDB(t)
|
|
job := createTestJobForAudit(t, db, "audit-task-9")
|
|
|
|
action := "view"
|
|
err := db.LogTaskAccess(job.ID, nil, nil, &action, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Retrieve and verify
|
|
entries, err := db.GetAuditLogForTask(job.ID, 10)
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 1)
|
|
|
|
entry := entries[0]
|
|
assert.Nil(t, entry.UserID)
|
|
assert.Nil(t, entry.Token)
|
|
assert.Nil(t, entry.IPAddress)
|
|
assert.Equal(t, "view", entry.Action)
|
|
}
|
|
|
|
// strPtr returns a pointer to a freshly allocated copy of s. Each call yields
// a distinct pointer, which makes it convenient for filling optional *string
// fields from literals.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|