fetch_ml/internal/storage/db_experiments_test.go
Jeremie Fraeys 50b6506243
test(storage): add comprehensive storage layer tests
Add tests for:
- dataset: Redis dataset operations, transfer tracking
- db_audit: audit logging with hash chain, access tracking
- db_experiments: experiment metadata, dataset associations
- db_tasks: task listing with pagination for users and groups
- db_jobs: job CRUD, state transitions, worker assignment

Coverage: storage package ~40%+
2026-03-13 23:26:33 -04:00

715 lines
17 KiB
Go

package storage_test
import (
	"context"
	"encoding/json"
	"path/filepath"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// experimentsTestSchema is the schema applied by setupExperimentsTestDB via
// db.Initialize. It declares the jobs table, the experiments table and its
// per-experiment metadata tables (environments, git info, seeds), the
// experiment_tasks join table linking experiments to jobs, and the datasets
// table.
//
// NOTE(review): jobs.visibility defaults to 'lab' here, while
// TestGetExperimentVisibility expects "private" for a job created without an
// explicit visibility — confirm this schema mirrors the production schema and
// what GetExperimentVisibility actually reads.
// NOTE(review): assuming the backend is SQLite, the FOREIGN KEY ... ON DELETE
// CASCADE clauses are only enforced when foreign_keys is enabled on the
// connection — verify whether storage.NewDBFromPath turns that on.
const experimentsTestSchema = `
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
job_name TEXT NOT NULL,
args TEXT,
status TEXT NOT NULL DEFAULT 'pending',
priority INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
ended_at DATETIME,
worker_id TEXT,
user_id TEXT,
error TEXT,
datasets TEXT,
metadata TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
visibility TEXT NOT NULL DEFAULT 'lab'
);
CREATE TABLE IF NOT EXISTS experiments (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT,
status TEXT DEFAULT 'pending',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
user_id TEXT,
workspace_id TEXT
);
CREATE TABLE IF NOT EXISTS experiment_environments (
experiment_id TEXT PRIMARY KEY,
python_version TEXT,
cuda_version TEXT,
system_os TEXT,
system_arch TEXT,
hostname TEXT,
requirements_hash TEXT,
conda_env_hash TEXT,
dependencies TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS experiment_git_info (
experiment_id TEXT PRIMARY KEY,
commit_sha TEXT,
branch TEXT,
remote_url TEXT,
is_dirty INTEGER DEFAULT 0,
diff_patch TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS experiment_seeds (
experiment_id TEXT PRIMARY KEY,
numpy_seed INTEGER,
torch_seed INTEGER,
tensorflow_seed INTEGER,
random_seed INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS experiment_tasks (
experiment_id TEXT NOT NULL,
task_id TEXT NOT NULL,
added_at DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (experiment_id, task_id),
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS datasets (
name TEXT PRIMARY KEY,
url TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
`
// setupExperimentsTestDB creates a fresh database file inside a per-test
// temporary directory, applies experimentsTestSchema to it, and registers a
// cleanup that closes the database when the test finishes.
//
// The close cleanup is registered before the schema is applied so the handle
// is still released if Initialize fails and require aborts the test. Cleanups
// run last-registered-first, so the DB is closed before t.TempDir removes its
// directory.
func setupExperimentsTestDB(t *testing.T) *storage.DB {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "experiments_test.db")
	db, err := storage.NewDBFromPath(dbPath)
	require.NoError(t, err, "Failed to create database")
	t.Cleanup(func() {
		// Best-effort teardown: a close failure should not mask the test's
		// own result.
		_ = db.Close()
	})

	err = db.Initialize(experimentsTestSchema)
	require.NoError(t, err, "Failed to initialize database schema")
	return db
}
// TestUpsertExperiment covers UpsertExperiment input validation (nil
// experiment, empty ID, empty name) and, for the accepted cases, reads the
// experiment back to confirm it was actually persisted.
func TestUpsertExperiment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	cases := []struct {
		name    string
		exp     *storage.Experiment
		wantErr bool
	}{
		{
			name: "valid experiment",
			exp: &storage.Experiment{
				ID:          "exp-1",
				Name:        "Test Experiment",
				Description: "Test description",
				Status:      "active",
				UserID:      "user-123",
				WorkspaceID: "ws-456",
			},
			wantErr: false,
		},
		{
			name: "minimal experiment",
			exp: &storage.Experiment{
				ID:   "exp-2",
				Name: "Minimal Experiment",
			},
			wantErr: false,
		},
		{
			name:    "nil experiment",
			exp:     nil,
			wantErr: true,
		},
		{
			name: "empty id",
			exp: &storage.Experiment{
				ID:   "",
				Name: "No ID",
			},
			wantErr: true,
		},
		{
			name: "empty name",
			exp: &storage.Experiment{
				ID:   "exp-3",
				Name: "",
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := db.UpsertExperiment(ctx, tc.exp)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			// A nil error alone does not prove the row was written; read it
			// back and confirm the identifying fields round-trip.
			got, err := db.GetExperimentWithMetadata(ctx, tc.exp.ID)
			require.NoError(t, err)
			require.NotNil(t, got)
			assert.Equal(t, tc.exp.ID, got.Experiment.ID)
			assert.Equal(t, tc.exp.Name, got.Experiment.Name)
		})
	}
}
// TestUpsertExperimentUpdate verifies that a second UpsertExperiment call for
// an existing ID overwrites the stored name, description and status.
func TestUpsertExperimentUpdate(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	subject := &storage.Experiment{
		ID:          "exp-update",
		Name:        "Original Name",
		Description: "Original description",
		Status:      "pending",
		UserID:      "user-1",
	}
	require.NoError(t, db.UpsertExperiment(ctx, subject))

	// Mutate the same record and upsert it again under the unchanged ID.
	subject.Name, subject.Description, subject.Status =
		"Updated Name", "Updated description", "active"
	require.NoError(t, db.UpsertExperiment(ctx, subject))

	fetched, err := db.GetExperimentWithMetadata(ctx, subject.ID)
	require.NoError(t, err)

	stored := fetched.Experiment
	assert.Equal(t, "Updated Name", stored.Name)
	assert.Equal(t, "Updated description", stored.Description)
	assert.Equal(t, "active", stored.Status)
}
// TestUpsertExperimentEnvironment checks that UpsertExperimentEnvironment
// accepts a fully populated environment for an existing experiment and
// rejects an empty experiment ID or a nil environment.
func TestUpsertExperimentEnvironment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// experiment_environments references experiments(id), so the parent
	// experiment is created first.
	parent := &storage.Experiment{
		ID:   "exp-env",
		Name: "Environment Test",
	}
	require.NoError(t, db.UpsertExperiment(ctx, parent))

	fullEnv := &storage.ExperimentEnvironment{
		PythonVersion:    "3.9.7",
		CUDAVersion:      "11.8",
		SystemOS:         "linux",
		SystemArch:       "x86_64",
		Hostname:         "test-host",
		RequirementsHash: "abc123",
		CondaEnvHash:     "def456",
		Dependencies:     json.RawMessage(`{"numpy": "1.21.0"}`),
	}

	tests := []struct {
		desc         string
		experimentID string
		environment  *storage.ExperimentEnvironment
		shouldFail   bool
	}{
		{desc: "valid environment", experimentID: parent.ID, environment: fullEnv},
		{
			desc:         "empty experiment id",
			experimentID: "",
			environment:  &storage.ExperimentEnvironment{PythonVersion: "3.8"},
			shouldFail:   true,
		},
		{desc: "nil environment", experimentID: parent.ID, environment: nil, shouldFail: true},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			err := db.UpsertExperimentEnvironment(ctx, tt.experimentID, tt.environment)
			if !tt.shouldFail {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
		})
	}
}
// TestUpsertExperimentGitInfo stores git information for an experiment and
// checks that commit SHA, branch, remote URL and dirty flag are returned by
// GetExperimentWithMetadata.
func TestUpsertExperimentGitInfo(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	const expID = "exp-git"
	require.NoError(t, db.UpsertExperiment(ctx, &storage.Experiment{
		ID:   expID,
		Name: "Git Test",
	}))

	want := &storage.ExperimentGitInfo{
		CommitSHA: "a1b2c3d4e5f6",
		Branch:    "main",
		RemoteURL: "https://github.com/user/repo.git",
		IsDirty:   true,
		DiffPatch: "diff --git a/file.txt b/file.txt",
	}
	require.NoError(t, db.UpsertExperimentGitInfo(ctx, expID, want))

	got, err := db.GetExperimentWithMetadata(ctx, expID)
	require.NoError(t, err)
	require.NotNil(t, got.GitInfo)

	stored := got.GitInfo
	assert.Equal(t, want.CommitSHA, stored.CommitSHA)
	assert.Equal(t, want.Branch, stored.Branch)
	assert.Equal(t, want.RemoteURL, stored.RemoteURL)
	assert.Equal(t, want.IsDirty, stored.IsDirty)
}
// TestUpsertExperimentSeeds stores only the numpy and torch seeds and checks
// that those two come back with their values while the unset TensorFlow and
// random seeds come back nil.
func TestUpsertExperimentSeeds(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	parent := &storage.Experiment{ID: "exp-seeds", Name: "Seeds Test"}
	require.NoError(t, db.UpsertExperiment(ctx, parent))

	numpySeed, torchSeed := int64(42), int64(123)
	require.NoError(t, db.UpsertExperimentSeeds(ctx, parent.ID, &storage.ExperimentSeeds{
		Numpy: &numpySeed,
		Torch: &torchSeed,
	}))

	fetched, err := db.GetExperimentWithMetadata(ctx, parent.ID)
	require.NoError(t, err)

	seeds := fetched.Seeds
	require.NotNil(t, seeds)
	require.NotNil(t, seeds.Numpy)
	assert.Equal(t, numpySeed, *seeds.Numpy)
	require.NotNil(t, seeds.Torch)
	assert.Equal(t, torchSeed, *seeds.Torch)

	// Seeds that were never set must not be materialized.
	assert.Nil(t, seeds.TensorFlow)
	assert.Nil(t, seeds.Random)
}
// TestGetExperimentWithMetadata checks that GetExperimentWithMetadata returns
// the stored experiment together with the environment record attached to it,
// and that an unknown experiment ID is reported as an error.
func TestGetExperimentWithMetadata(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create an experiment and attach an environment record to it.
	exp := &storage.Experiment{
		ID:          "exp-full",
		Name:        "Full Metadata Test",
		Description: "Test experiment",
		Status:      "active",
		UserID:      "user-test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	env := &storage.ExperimentEnvironment{
		PythonVersion: "3.10.0",
		SystemOS:      "darwin",
		SystemArch:    "arm64",
	}
	err = db.UpsertExperimentEnvironment(ctx, exp.ID, env)
	require.NoError(t, err)

	// Retrieve the experiment together with its metadata.
	result, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	require.NoError(t, err)
	assert.Equal(t, exp.ID, result.Experiment.ID)
	assert.Equal(t, exp.Name, result.Experiment.Name)
	assert.Equal(t, exp.Description, result.Experiment.Description)

	// The stored environment values must be returned, not just a non-nil
	// placeholder.
	require.NotNil(t, result.Environment)
	assert.Equal(t, env.PythonVersion, result.Environment.PythonVersion)
	assert.Equal(t, env.SystemOS, result.Environment.SystemOS)
	assert.Equal(t, env.SystemArch, result.Environment.SystemArch)

	// An unknown experiment ID must be reported as an error.
	_, err = db.GetExperimentWithMetadata(ctx, "nonexistent")
	require.Error(t, err)
}
// TestGetExperimentWithMetadataEmptyID checks that an empty experiment ID is
// rejected with an error.
func TestGetExperimentWithMetadataEmptyID(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)

	const emptyID = ""
	_, err := db.GetExperimentWithMetadata(context.Background(), emptyID)
	require.Error(t, err)
}
// TestUpsertDataset checks that UpsertDataset accepts a dataset with both a
// name and a URL, and rejects a nil dataset, an empty name, or an empty URL.
func TestUpsertDataset(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	scenarios := []struct {
		label     string
		input     *storage.Dataset
		shouldErr bool
	}{
		{
			label: "valid dataset",
			input: &storage.Dataset{Name: "dataset-1", URL: "s3://bucket/dataset-1"},
		},
		{
			label:     "nil dataset",
			input:     nil,
			shouldErr: true,
		},
		{
			label:     "empty name",
			input:     &storage.Dataset{Name: "", URL: "s3://bucket/data"},
			shouldErr: true,
		},
		{
			label:     "empty url",
			input:     &storage.Dataset{Name: "dataset-2", URL: ""},
			shouldErr: true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			err := db.UpsertDataset(ctx, sc.input)
			if !sc.shouldErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
		})
	}
}
// TestGetDataset round-trips a dataset through UpsertDataset/GetDataset and
// checks that unknown and empty names are reported as errors.
func TestGetDataset(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	stored := &storage.Dataset{
		Name: "test-dataset",
		URL:  "s3://bucket/test-data",
	}
	require.NoError(t, db.UpsertDataset(ctx, stored))

	fetched, err := db.GetDataset(ctx, stored.Name)
	require.NoError(t, err)
	assert.Equal(t, stored.Name, fetched.Name)
	assert.Equal(t, stored.URL, fetched.URL)

	// Neither a missing dataset nor an empty name may succeed.
	for _, badName := range []string{"nonexistent", ""} {
		_, err = db.GetDataset(ctx, badName)
		require.Error(t, err, "GetDataset(%q) should fail", badName)
	}
}
// TestListDatasets inserts three datasets and checks ListDatasets both with a
// limit of 0 (all three come back) and with a limit of 2.
func TestListDatasets(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	fixtures := []*storage.Dataset{
		{Name: "dataset-a", URL: "s3://bucket/a"},
		{Name: "dataset-b", URL: "s3://bucket/b"},
		{Name: "dataset-c", URL: "s3://bucket/c"},
	}
	for i := range fixtures {
		require.NoError(t, db.UpsertDataset(ctx, fixtures[i]))
	}

	// A limit of 0 returns every dataset.
	everything, err := db.ListDatasets(ctx, 0)
	require.NoError(t, err)
	assert.Len(t, everything, 3)

	firstTwo, err := db.ListDatasets(ctx, 2)
	require.NoError(t, err)
	assert.Len(t, firstTwo, 2)
}
// TestSearchDatasets seeds three datasets and checks how many results
// SearchDatasets returns for several terms, including a non-matching and an
// empty term (both expected to yield zero results).
func TestSearchDatasets(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	for _, ds := range []*storage.Dataset{
		{Name: "imagenet-train", URL: "s3://bucket/imagenet/train"},
		{Name: "imagenet-val", URL: "s3://bucket/imagenet/val"},
		{Name: "coco-dataset", URL: "s3://bucket/coco"},
	} {
		require.NoError(t, db.UpsertDataset(ctx, ds))
	}

	queries := []struct {
		label     string
		term      string
		expectLen int
		expectErr bool
	}{
		{label: "search imagenet", term: "imagenet", expectLen: 2},
		{label: "search coco", term: "coco", expectLen: 1},
		{label: "search val", term: "val", expectLen: 1},
		{label: "no match", term: "nonexistent", expectLen: 0},
		{label: "empty term", term: "", expectLen: 0},
	}

	for _, q := range queries {
		t.Run(q.label, func(t *testing.T) {
			hits, err := db.SearchDatasets(ctx, q.term, 10)
			if q.expectErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			assert.Len(t, hits, q.expectLen)
		})
	}
}
// TestAssociateTaskWithExperiment links two jobs to an experiment, checks
// that exactly those jobs are listed for it, and checks that re-associating an
// already linked job succeeds without creating a duplicate link.
func TestAssociateTaskWithExperiment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create experiment
	exp := &storage.Experiment{
		ID:   "exp-tasks",
		Name: "Task Association Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Create jobs
	job1 := &storage.Job{ID: "job-1", Status: "pending"}
	job2 := &storage.Job{ID: "job-2", Status: "pending"}
	err = db.CreateJob(job1)
	require.NoError(t, err)
	err = db.CreateJob(job2)
	require.NoError(t, err)

	// Associate tasks
	err = db.AssociateTaskWithExperiment(job1.ID, exp.ID)
	require.NoError(t, err)
	err = db.AssociateTaskWithExperiment(job2.ID, exp.ID)
	require.NoError(t, err)

	// Both jobs, and only those jobs, are listed for the experiment.
	tasks, _, err := db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{Limit: 10})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)
	ids := make([]string, 0, len(tasks))
	for _, task := range tasks {
		ids = append(ids, task.ID)
	}
	assert.ElementsMatch(t, []string{job1.ID, job2.ID}, ids)

	// Idempotent: associating again must not error...
	err = db.AssociateTaskWithExperiment(job1.ID, exp.ID)
	require.NoError(t, err)

	// ...and must not produce a duplicate link.
	tasks, _, err = db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{Limit: 10})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)
}
// TestGetExperimentVisibility checks the visibility reported for an experiment
// whose only associated job was created without an explicit visibility, and
// for an experiment ID that does not exist. Both are expected to be "private".
func TestGetExperimentVisibility(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	parent := &storage.Experiment{ID: "exp-visibility", Name: "Visibility Test"}
	require.NoError(t, db.UpsertExperiment(ctx, parent))

	// The job deliberately sets no visibility of its own.
	linked := &storage.Job{ID: "job-vis", Status: "pending"}
	require.NoError(t, db.CreateJob(linked))
	require.NoError(t, db.AssociateTaskWithExperiment(linked.ID, parent.ID))

	// NOTE(review): the test schema gives jobs.visibility a 'lab' default, yet
	// "private" is expected here; presumably GetExperimentVisibility falls back
	// to "private" (e.g. via COALESCE) — confirm against its query.
	got, err := db.GetExperimentVisibility(parent.ID)
	require.NoError(t, err)
	assert.Equal(t, "private", got)

	// An unknown experiment is not an error and also reports "private".
	missing, err := db.GetExperimentVisibility("nonexistent")
	require.NoError(t, err)
	assert.Equal(t, "private", missing)
}
// TestCascadeExperimentVisibility links two jobs to an experiment, cascades a
// new visibility ("institution") to them, and checks that the call succeeds
// and both jobs remain associated afterwards.
func TestCascadeExperimentVisibility(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	parent := &storage.Experiment{ID: "exp-cascade", Name: "Cascade Test"}
	require.NoError(t, db.UpsertExperiment(ctx, parent))

	jobIDs := []string{"job-c1", "job-c2"}
	for _, id := range jobIDs {
		require.NoError(t, db.CreateJob(&storage.Job{ID: id, Status: "pending"}))
	}
	for _, id := range jobIDs {
		require.NoError(t, db.AssociateTaskWithExperiment(id, parent.ID))
	}

	require.NoError(t, db.CascadeExperimentVisibility(parent.ID, "institution"))

	// The listed task type exposes no visibility field, so only the
	// association count is checked after the cascade.
	linked, _, err := db.ListTasksForExperiment(parent.ID, storage.ListTasksOptions{Limit: 10})
	require.NoError(t, err)
	assert.Len(t, linked, 2)
}
// TestListTasksForExperiment associates five jobs with an experiment and pages
// through them two at a time, checking page sizes and that a continuation
// cursor is returned while tasks remain.
func TestListTasksForExperiment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	parent := &storage.Experiment{ID: "exp-list-tasks", Name: "List Tasks Test"}
	require.NoError(t, db.UpsertExperiment(ctx, parent))

	// Five jobs: job-list-a … job-list-e.
	for _, suffix := range []string{"a", "b", "c", "d", "e"} {
		job := &storage.Job{ID: "job-list-" + suffix, Status: "pending"}
		require.NoError(t, db.CreateJob(job))
		require.NoError(t, db.AssociateTaskWithExperiment(job.ID, parent.ID))
	}

	firstPage, nextCursor, err := db.ListTasksForExperiment(parent.ID, storage.ListTasksOptions{Limit: 2})
	require.NoError(t, err)
	assert.Len(t, firstPage, 2)
	assert.NotEmpty(t, nextCursor, "Should have next cursor")

	secondPage, afterSecond, err := db.ListTasksForExperiment(parent.ID, storage.ListTasksOptions{
		Limit:  2,
		Cursor: nextCursor,
	})
	require.NoError(t, err)
	assert.Len(t, secondPage, 2)
	assert.NotEmpty(t, afterSecond)

	// NOTE(review): page contents are only logged, not compared — per the
	// original author, jobs sharing the same created_at may make the cursor
	// return overlapping pages. Confirm and tighten once ordering is stable.
	if len(firstPage) == 0 || len(secondPage) == 0 {
		return
	}
	t.Logf("First page: %v, Second page: %v", firstPage[0].ID, secondPage[0].ID)
}