package storage_test

import (
	"context"
	"encoding/json"
	"path/filepath"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// experimentsTestSchema defines every table the tests in this file touch:
// jobs, experiments, the per-experiment metadata tables (environment, git
// info, seeds), the experiment/task link table, and datasets.
const experimentsTestSchema = `
CREATE TABLE IF NOT EXISTS jobs (
	id TEXT PRIMARY KEY,
	job_name TEXT NOT NULL,
	args TEXT,
	status TEXT NOT NULL DEFAULT 'pending',
	priority INTEGER DEFAULT 0,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	started_at DATETIME,
	ended_at DATETIME,
	worker_id TEXT,
	user_id TEXT,
	error TEXT,
	datasets TEXT,
	metadata TEXT,
	updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	visibility TEXT NOT NULL DEFAULT 'lab'
);
CREATE TABLE IF NOT EXISTS experiments (
	id TEXT PRIMARY KEY,
	name TEXT NOT NULL,
	description TEXT,
	status TEXT DEFAULT 'pending',
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	user_id TEXT,
	workspace_id TEXT
);
CREATE TABLE IF NOT EXISTS experiment_environments (
	experiment_id TEXT PRIMARY KEY,
	python_version TEXT,
	cuda_version TEXT,
	system_os TEXT,
	system_arch TEXT,
	hostname TEXT,
	requirements_hash TEXT,
	conda_env_hash TEXT,
	dependencies TEXT,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS experiment_git_info (
	experiment_id TEXT PRIMARY KEY,
	commit_sha TEXT,
	branch TEXT,
	remote_url TEXT,
	is_dirty INTEGER DEFAULT 0,
	diff_patch TEXT,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS experiment_seeds (
	experiment_id TEXT PRIMARY KEY,
	numpy_seed INTEGER,
	torch_seed INTEGER,
	tensorflow_seed INTEGER,
	random_seed INTEGER,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS experiment_tasks (
	experiment_id TEXT NOT NULL,
	task_id TEXT NOT NULL,
	added_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	PRIMARY KEY (experiment_id, task_id),
	FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
	FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS datasets (
	name TEXT PRIMARY KEY,
	url TEXT NOT NULL,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
`

// setupExperimentsTestDB opens a database file inside a per-test temporary
// directory, applies experimentsTestSchema, and returns the handle. The
// handle is closed automatically when the test finishes. Any setup failure
// aborts the calling test.
func setupExperimentsTestDB(t *testing.T) *storage.DB {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "experiments_test.db")
	db, err := storage.NewDBFromPath(dbPath)
	require.NoError(t, err, "Failed to create database")

	// Register the close as soon as the handle exists so that a schema
	// failure below cannot leak it. Cleanups run last-in-first-out, so the
	// database is closed before t.TempDir removes the directory.
	t.Cleanup(func() {
		if closeErr := db.Close(); closeErr != nil {
			// Best effort: report the problem without failing the test.
			t.Logf("closing test database: %v", closeErr)
		}
	})

	err = db.Initialize(experimentsTestSchema)
	require.NoError(t, err, "Failed to initialize database schema")

	return db
}

// TestUpsertExperiment checks which experiment values UpsertExperiment
// accepts and which it rejects (nil, empty ID, empty name).
func TestUpsertExperiment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	cases := []struct {
		name    string
		exp     *storage.Experiment
		wantErr bool
	}{
		{
			name: "valid experiment",
			exp: &storage.Experiment{
				ID:          "exp-1",
				Name:        "Test Experiment",
				Description: "Test description",
				Status:      "active",
				UserID:      "user-123",
				WorkspaceID: "ws-456",
			},
			wantErr: false,
		},
		{
			name: "minimal experiment",
			exp: &storage.Experiment{
				ID:   "exp-2",
				Name: "Minimal Experiment",
			},
			wantErr: false,
		},
		{
			name:    "nil experiment",
			exp:     nil,
			wantErr: true,
		},
		{
			name: "empty id",
			exp: &storage.Experiment{
				ID:   "",
				Name: "No ID",
			},
			wantErr: true,
		},
		{
			name: "empty name",
			exp: &storage.Experiment{
				ID:   "exp-3",
				Name: "",
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := db.UpsertExperiment(ctx, tc.exp)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
		})
	}
}

// TestUpsertExperimentUpdate checks that a second UpsertExperiment call with
// the same ID overwrites the name, description, and status.
func TestUpsertExperimentUpdate(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create the initial experiment.
	exp := &storage.Experiment{
		ID:          "exp-update",
		Name:        "Original Name",
		Description: "Original description",
		Status:      "pending",
		UserID:      "user-1",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Upsert again under the same ID with changed fields.
	exp.Name = "Updated Name"
	exp.Description = "Updated description"
	exp.Status = "active"
	err = db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Read the experiment back and confirm the new values were stored.
	result, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	require.NoError(t, err)
	assert.Equal(t, "Updated Name", result.Experiment.Name)
	assert.Equal(t, "Updated description", result.Experiment.Description)
	assert.Equal(t, "active", result.Experiment.Status)
}

// TestUpsertExperimentEnvironment checks that UpsertExperimentEnvironment
// accepts a populated environment and rejects an empty experiment ID or a nil
// environment.
func TestUpsertExperimentEnvironment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// The environment rows reference an experiment, so create one first.
	exp := &storage.Experiment{
		ID:   "exp-env",
		Name: "Environment Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	cases := []struct {
		name    string
		expID   string
		env     *storage.ExperimentEnvironment
		wantErr bool
	}{
		{
			name:  "valid environment",
			expID: exp.ID,
			env: &storage.ExperimentEnvironment{
				PythonVersion:    "3.9.7",
				CUDAVersion:      "11.8",
				SystemOS:         "linux",
				SystemArch:       "x86_64",
				Hostname:         "test-host",
				RequirementsHash: "abc123",
				CondaEnvHash:     "def456",
				Dependencies:     json.RawMessage(`{"numpy": "1.21.0"}`),
			},
			wantErr: false,
		},
		{
			name:    "empty experiment id",
			expID:   "",
			env:     &storage.ExperimentEnvironment{PythonVersion: "3.8"},
			wantErr: true,
		},
		{
			name:    "nil environment",
			expID:   exp.ID,
			env:     nil,
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := db.UpsertExperimentEnvironment(ctx, tc.expID, tc.env)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
		})
	}
}

// TestUpsertExperimentGitInfo checks that git information stored with
// UpsertExperimentGitInfo is returned by GetExperimentWithMetadata.
func TestUpsertExperimentGitInfo(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create the experiment the git info belongs to.
	exp := &storage.Experiment{
		ID:   "exp-git",
		Name: "Git Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Store git info.
	gitInfo := &storage.ExperimentGitInfo{
		CommitSHA: "a1b2c3d4e5f6",
		Branch:    "main",
		RemoteURL: "https://github.com/user/repo.git",
		IsDirty:   true,
		DiffPatch: "diff --git a/file.txt b/file.txt",
	}
	err = db.UpsertExperimentGitInfo(ctx, exp.ID, gitInfo)
	require.NoError(t, err)

	// Retrieve and verify.
	result, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	require.NoError(t, err)
	require.NotNil(t, result.GitInfo)
	assert.Equal(t, gitInfo.CommitSHA, result.GitInfo.CommitSHA)
	assert.Equal(t, gitInfo.Branch, result.GitInfo.Branch)
	assert.Equal(t, gitInfo.RemoteURL, result.GitInfo.RemoteURL)
	assert.Equal(t, gitInfo.IsDirty, result.GitInfo.IsDirty)
}

// TestUpsertExperimentSeeds checks that the seeds that were set round-trip
// through storage and that the seeds left unset come back as nil.
func TestUpsertExperimentSeeds(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create the experiment the seeds belong to.
	exp := &storage.Experiment{
		ID:   "exp-seeds",
		Name: "Seeds Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Store only the numpy and torch seeds; the others stay nil.
	numpySeed := int64(42)
	torchSeed := int64(123)
	seeds := &storage.ExperimentSeeds{
		Numpy: &numpySeed,
		Torch: &torchSeed,
	}
	err = db.UpsertExperimentSeeds(ctx, exp.ID, seeds)
	require.NoError(t, err)

	// Retrieve and verify.
	result, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	require.NoError(t, err)
	require.NotNil(t, result.Seeds)
	require.NotNil(t, result.Seeds.Numpy)
	assert.Equal(t, numpySeed, *result.Seeds.Numpy)
	require.NotNil(t, result.Seeds.Torch)
	assert.Equal(t, torchSeed, *result.Seeds.Torch)
	assert.Nil(t, result.Seeds.TensorFlow)
	assert.Nil(t, result.Seeds.Random)
}

// TestGetExperimentWithMetadata checks that an experiment is returned together
// with its stored environment and that an unknown ID yields an error.
func TestGetExperimentWithMetadata(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create an experiment and attach environment metadata to it.
	exp := &storage.Experiment{
		ID:          "exp-full",
		Name:        "Full Metadata Test",
		Description: "Test experiment",
		Status:      "active",
		UserID:      "user-test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	env := &storage.ExperimentEnvironment{
		PythonVersion: "3.10.0",
		SystemOS:      "darwin",
		SystemArch:    "arm64",
	}
	err = db.UpsertExperimentEnvironment(ctx, exp.ID, env)
	require.NoError(t, err)

	// Retrieve the experiment; the environment stored above must be present.
	result, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	require.NoError(t, err)
	assert.Equal(t, exp.ID, result.Experiment.ID)
	assert.Equal(t, exp.Name, result.Experiment.Name)
	assert.NotNil(t, result.Environment)

	// A nonexistent experiment is an error.
	_, err = db.GetExperimentWithMetadata(ctx, "nonexistent")
	require.Error(t, err)
}

// TestGetExperimentWithMetadataEmptyID checks that an empty experiment ID is
// rejected.
func TestGetExperimentWithMetadataEmptyID(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	_, err := db.GetExperimentWithMetadata(ctx, "")
	require.Error(t, err)
}

// TestUpsertDataset checks which dataset values UpsertDataset accepts and
// which it rejects (nil, empty name, empty URL).
func TestUpsertDataset(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	cases := []struct {
		name    string
		ds      *storage.Dataset
		wantErr bool
	}{
		{
			name: "valid dataset",
			ds: &storage.Dataset{
				Name: "dataset-1",
				URL:  "s3://bucket/dataset-1",
			},
			wantErr: false,
		},
		{
			name:    "nil dataset",
			ds:      nil,
			wantErr: true,
		},
		{
			name: "empty name",
			ds: &storage.Dataset{
				Name: "",
				URL:  "s3://bucket/data",
			},
			wantErr: true,
		},
		{
			name: "empty url",
			ds: &storage.Dataset{
				Name: "dataset-2",
				URL:  "",
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := db.UpsertDataset(ctx, tc.ds)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
		})
	}
}

// TestGetDataset checks that a stored dataset can be read back by name and
// that an unknown or empty name yields an error.
func TestGetDataset(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create a dataset.
	ds := &storage.Dataset{
		Name: "test-dataset",
		URL:  "s3://bucket/test-data",
	}
	err := db.UpsertDataset(ctx, ds)
	require.NoError(t, err)

	// Retrieve it.
	retrieved, err := db.GetDataset(ctx, ds.Name)
	require.NoError(t, err)
	assert.Equal(t, ds.Name, retrieved.Name)
	assert.Equal(t, ds.URL, retrieved.URL)

	// Nonexistent dataset.
	_, err = db.GetDataset(ctx, "nonexistent")
	require.Error(t, err)

	// Empty name.
	_, err = db.GetDataset(ctx, "")
	require.Error(t, err)
}

// TestListDatasets checks that ListDatasets returns every dataset when called
// with a limit of 0 and honours a positive limit.
func TestListDatasets(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create multiple datasets.
	datasets := []*storage.Dataset{
		{Name: "dataset-a", URL: "s3://bucket/a"},
		{Name: "dataset-b", URL: "s3://bucket/b"},
		{Name: "dataset-c", URL: "s3://bucket/c"},
	}
	for _, ds := range datasets {
		err := db.UpsertDataset(ctx, ds)
		require.NoError(t, err)
	}

	// A limit of 0 is expected to return all three datasets.
	list, err := db.ListDatasets(ctx, 0)
	require.NoError(t, err)
	assert.Len(t, list, 3)

	// A limit of 2 caps the result at two entries.
	listLimited, err := db.ListDatasets(ctx, 2)
	require.NoError(t, err)
	assert.Len(t, listLimited, 2)
}

// TestSearchDatasets checks how many datasets SearchDatasets returns for a
// range of search terms, including a term with no match and an empty term.
func TestSearchDatasets(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create datasets.
	datasets := []*storage.Dataset{
		{Name: "imagenet-train", URL: "s3://bucket/imagenet/train"},
		{Name: "imagenet-val", URL: "s3://bucket/imagenet/val"},
		{Name: "coco-dataset", URL: "s3://bucket/coco"},
	}
	for _, ds := range datasets {
		err := db.UpsertDataset(ctx, ds)
		require.NoError(t, err)
	}

	cases := []struct {
		name    string
		term    string
		wantLen int
		wantErr bool
	}{
		{name: "search imagenet", term: "imagenet", wantLen: 2, wantErr: false},
		{name: "search coco", term: "coco", wantLen: 1, wantErr: false},
		{name: "search val", term: "val", wantLen: 1, wantErr: false},
		{name: "no match", term: "nonexistent", wantLen: 0, wantErr: false},
		{name: "empty term", term: "", wantLen: 0, wantErr: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			results, err := db.SearchDatasets(ctx, tc.term, 10)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			assert.Len(t, results, tc.wantLen)
		})
	}
}

// TestAssociateTaskWithExperiment checks that jobs can be linked to an
// experiment, that they are listed afterwards, and that linking the same job
// twice neither fails nor adds a duplicate.
func TestAssociateTaskWithExperiment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create experiment.
	exp := &storage.Experiment{
		ID:   "exp-tasks",
		Name: "Task Association Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Create jobs.
	job1 := &storage.Job{ID: "job-1", Status: "pending"}
	job2 := &storage.Job{ID: "job-2", Status: "pending"}
	err = db.CreateJob(job1)
	require.NoError(t, err)
	err = db.CreateJob(job2)
	require.NoError(t, err)

	// Associate tasks.
	err = db.AssociateTaskWithExperiment(job1.ID, exp.ID)
	require.NoError(t, err)
	err = db.AssociateTaskWithExperiment(job2.ID, exp.ID)
	require.NoError(t, err)

	// List tasks for the experiment.
	tasks, _, err := db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{Limit: 10})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)

	// Idempotent: associating again must not error...
	err = db.AssociateTaskWithExperiment(job1.ID, exp.ID)
	require.NoError(t, err)

	// ...and must not create a second link for the same job.
	tasks, _, err = db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{Limit: 10})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)
}

// TestGetExperimentVisibility checks the visibility reported for an experiment
// with one associated job and for an experiment that does not exist.
func TestGetExperimentVisibility(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create experiment.
	exp := &storage.Experiment{
		ID:   "exp-visibility",
		Name: "Visibility Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Create a job (no visibility is set on it here) and associate it.
	job := &storage.Job{
		ID:     "job-vis",
		Status: "pending",
	}
	err = db.CreateJob(job)
	require.NoError(t, err)
	err = db.AssociateTaskWithExperiment(job.ID, exp.ID)
	require.NoError(t, err)

	// Jobs created without an explicit visibility are expected to report
	// "private".
	// NOTE(review): the test schema declares jobs.visibility with
	// DEFAULT 'lab', so this expectation depends on behavior inside the
	// storage package that is not visible here (the original comment
	// attributed it to a COALESCE) — confirm.
	visibility, err := db.GetExperimentVisibility(exp.ID)
	require.NoError(t, err)
	assert.Equal(t, "private", visibility)

	// A nonexistent experiment also reports "private" rather than an error.
	vis, err := db.GetExperimentVisibility("nonexistent")
	require.NoError(t, err)
	assert.Equal(t, "private", vis)
}

// TestCascadeExperimentVisibility checks that cascading a visibility change
// succeeds and leaves the experiment's task associations intact.
func TestCascadeExperimentVisibility(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create experiment.
	exp := &storage.Experiment{
		ID:   "exp-cascade",
		Name: "Cascade Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Create jobs.
	job1 := &storage.Job{ID: "job-c1", Status: "pending"}
	job2 := &storage.Job{ID: "job-c2", Status: "pending"}
	err = db.CreateJob(job1)
	require.NoError(t, err)
	err = db.CreateJob(job2)
	require.NoError(t, err)

	// Associate tasks.
	err = db.AssociateTaskWithExperiment(job1.ID, exp.ID)
	require.NoError(t, err)
	err = db.AssociateTaskWithExperiment(job2.ID, exp.ID)
	require.NoError(t, err)

	// Cascade the visibility update.
	err = db.CascadeExperimentVisibility(exp.ID, "institution")
	require.NoError(t, err)

	// Only the associations are verified here; the per-task visibility check
	// was removed because the listed tasks expose no such field.
	tasks, _, err := db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{Limit: 10})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)
}

// TestListTasksForExperiment checks page sizes and cursors when listing an
// experiment's tasks two at a time.
func TestListTasksForExperiment(t *testing.T) {
	t.Parallel()
	db := setupExperimentsTestDB(t)
	ctx := context.Background()

	// Create experiment.
	exp := &storage.Experiment{
		ID:   "exp-list-tasks",
		Name: "List Tasks Test",
	}
	err := db.UpsertExperiment(ctx, exp)
	require.NoError(t, err)

	// Create five jobs with IDs job-list-a through job-list-e.
	for i := 0; i < 5; i++ {
		job := &storage.Job{
			ID:     "job-list-" + string(rune('a'+i)),
			Status: "pending",
		}
		err = db.CreateJob(job)
		require.NoError(t, err)
		err = db.AssociateTaskWithExperiment(job.ID, exp.ID)
		require.NoError(t, err)
	}

	// First page.
	tasks, cursor, err := db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{Limit: 2})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)
	assert.NotEmpty(t, cursor, "Should have next cursor")

	// Second page, continuing from the first page's cursor.
	tasks2, cursor2, err := db.ListTasksForExperiment(exp.ID, storage.ListTasksOptions{
		Limit:  2,
		Cursor: cursor,
	})
	require.NoError(t, err)
	assert.Len(t, tasks2, 2)
	assert.NotEmpty(t, cursor2)

	// The pages are only logged, not compared: when all jobs share the same
	// created_at timestamp, cursor pagination may return the same results.
	if len(tasks) > 0 && len(tasks2) > 0 {
		t.Logf("First page: %v, Second page: %v", tasks[0].ID, tasks2[0].ID)
	}
}