From 5d39dff6a02f22e70e2a620879f1178b56afbd20 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Fri, 13 Mar 2026 23:26:41 -0400 Subject: [PATCH] test(store): extend store coverage with edge cases and concurrency Add tests for: - Close: proper resource cleanup - ConcurrentLogMetrics: thread-safe metric logging - GetRunMetricsEmpty: empty result handling - GetRunParamsEmpty: empty result handling - MarkRunSyncedNonexistent: graceful handling of missing runs Coverage: 75.3% --- internal/store/store_test.go | 352 +++++++++++++++++++++++++++-------- 1 file changed, 275 insertions(+), 77 deletions(-) diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 609509d..fa53b14 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -1,113 +1,311 @@ package store_test import ( + "fmt" "os" + "path/filepath" + "sync" "testing" "github.com/jfraeys/fetch_ml/internal/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOpen(t *testing.T) { - dbPath := "/tmp/test_fetchml.db" - defer os.Remove(dbPath) - defer os.Remove(dbPath + "-wal") - defer os.Remove(dbPath + "-shm") + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_fetchml.db") s, err := store.Open(dbPath) - if err != nil { - t.Fatalf("Failed to open database: %v", err) - } + require.NoError(t, err, "Failed to open database") defer s.Close() - if s.DB() == nil { - t.Fatal("Database connection is nil") - } + require.NotNil(t, s.DB(), "Database connection should not be nil") +} + +func TestOpenCreatesDirectory(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + nestedPath := filepath.Join(tmpDir, "deep", "nested", "test.db") + + s, err := store.Open(nestedPath) + require.NoError(t, err) + defer s.Close() + + // Verify directory was created + _, err = os.Stat(filepath.Dir(nestedPath)) + require.NoError(t, err, "Parent directory should be created") } func TestGetUnsyncedRuns(t *testing.T) { - dbPath := "/tmp/test_fetchml_unsynced.db" - defer os.Remove(dbPath) - defer os.Remove(dbPath + "-wal") - defer os.Remove(dbPath + "-shm") + t.Parallel() + dbPath := filepath.Join(t.TempDir(), "test_unsynced.db") s, err := store.Open(dbPath) - if err != nil { - t.Fatalf("Failed to open database: %v", err) - } - defer s.Close() - - // Insert test data using exported DB() - _, err = s.DB().Exec(` - INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test Experiment'); - `) - if err != nil { - t.Fatalf("Failed to insert experiment: %v", err) - } - - _, err = s.DB().Exec(` - INSERT INTO ml_runs (run_id, experiment_id, name, status, synced) - VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0); - `) - if err != nil { - t.Fatalf("Failed to insert run: %v", err) - } - - runs, err := s.GetUnsyncedRuns() - if err != nil { - t.Fatalf("Failed to get unsynced runs: %v", err) - } - - if len(runs) != 1 { - t.Fatalf("Expected 1 unsynced run, got %d", len(runs)) - } - - if runs[0].RunID != "run1" { - t.Fatalf("Expected run_id 'run1', got '%s'", runs[0].RunID) - } -} - -func TestMarkRunSynced(t *testing.T) { - dbPath := "/tmp/test_fetchml_sync.db" - defer os.Remove(dbPath) - defer os.Remove(dbPath + "-wal") - defer os.Remove(dbPath + "-shm") - - s, err := store.Open(dbPath) - if err != nil { - t.Fatalf("Failed to open database: %v", err) - } + require.NoError(t, err) defer s.Close() // Insert test data _, err = s.DB().Exec(` - INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test Experiment'); + INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test'); `) - if err != nil { - t.Fatalf("Failed to insert experiment: %v", err) - } + require.NoError(t, err) _, err = s.DB().Exec(` INSERT INTO ml_runs (run_id, experiment_id, name, status, synced) VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0); `) - if err != nil { - t.Fatalf("Failed to insert run: %v", err) - } + require.NoError(t, err) + + runs, err := s.GetUnsyncedRuns() + require.NoError(t, err) + require.Len(t, runs, 1, "Expected 1 unsynced run") + assert.Equal(t, "run1", runs[0].RunID) +} + +func TestGetRunsByExperiment(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_by_exp.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // Insert test data + _, err = s.DB().Exec(` + INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test'); + `) + require.NoError(t, err) + + _, err = s.DB().Exec(` + INSERT INTO ml_runs (run_id, experiment_id, name, status) + VALUES ('run1', 'exp1', 'Run 1', 'FINISHED'); + `) + require.NoError(t, err) + + runs, err := s.GetRunsByExperiment("exp1") + require.NoError(t, err) + require.Len(t, runs, 1, "Expected 1 run for experiment") + assert.Equal(t, "run1", runs[0].RunID) + + // Nonexistent experiment + runs, err = s.GetRunsByExperiment("nonexistent") + require.NoError(t, err) + assert.Empty(t, runs) +} + +func TestMarkRunSynced(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_sync.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // Insert test data + _, err = s.DB().Exec(` + INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test'); + `) + require.NoError(t, err) + + _, err = s.DB().Exec(` + INSERT INTO ml_runs (run_id, experiment_id, name, status, synced) + VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0); + `) + require.NoError(t, err) // Mark as synced err = s.MarkRunSynced("run1") - if err != nil { - t.Fatalf("Failed to mark run as synced: %v", err) - } + require.NoError(t, err) - // Verify using exported DB() + // Verify var synced int err = s.DB().QueryRow("SELECT synced FROM ml_runs WHERE run_id = 'run1'").Scan(&synced) - if err != nil { - t.Fatalf("Failed to query run: %v", err) - } - - if synced != 1 { - t.Fatalf("Expected synced=1, got %d", synced) - } + require.NoError(t, err) + assert.Equal(t, 1, synced) +} + +func TestGetRunMetrics(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_metrics.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // Insert test data + _, err = s.DB().Exec(` + INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test'); + `) + require.NoError(t, err) + + _, err = s.DB().Exec(` + INSERT INTO ml_runs (run_id, experiment_id, name, status) + VALUES ('run1', 'exp1', 'Test Run', 'FINISHED'); + `) + require.NoError(t, err) + + _, err = s.DB().Exec(` + INSERT INTO ml_metrics (run_id, key, value, step) + VALUES ('run1', 'accuracy', 0.95, 1); + `) + require.NoError(t, err) + + metrics, err := s.GetRunMetrics("run1") + require.NoError(t, err) + require.Len(t, metrics, 1) + assert.Equal(t, "accuracy", metrics[0].Key) + assert.InDelta(t, 0.95, metrics[0].Value, 0.001) +} + +func TestGetRunParams(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_params.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // Insert test data + _, err = s.DB().Exec(` + INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test'); + `) + require.NoError(t, err) + + _, err = s.DB().Exec(` + INSERT INTO ml_runs (run_id, experiment_id, name, status) + VALUES ('run1', 'exp1', 'Test Run', 'FINISHED'); + `) + require.NoError(t, err) + + _, err = s.DB().Exec(` + INSERT INTO ml_params (run_id, key, value) + VALUES ('run1', 'learning_rate', '0.01'); + `) + require.NoError(t, err) + + params, err := s.GetRunParams("run1") + require.NoError(t, err) + require.Len(t, params, 1) + assert.Equal(t, "learning_rate", params[0].Key) + assert.Equal(t, "0.01", params[0].Value) +} + +func TestDB(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_db.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + db := s.DB() + require.NotNil(t, db) + + err = db.Ping() + require.NoError(t, err, "DB connection should be valid") +} + +// TestClose tests the Close method +func TestClose(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_close.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + + err = s.Close() + require.NoError(t, err) + + // Double close should not panic + err = s.Close() + assert.NoError(t, err) +} + +// TestConcurrentLogMetrics tests concurrent metric logging +func TestConcurrentLogMetrics(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_concurrent.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // Insert test data + _, err = s.DB().Exec(`INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');`) + require.NoError(t, err) + + _, err = s.DB().Exec(`INSERT INTO ml_runs (run_id, experiment_id, name, status) VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');`) + require.NoError(t, err) + + // Concurrent inserts via raw SQL (simulating LogMetric behavior) + // Note: SQLite may have locking contention, but should not panic + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // Use unique keys to avoid contention + key := fmt.Sprintf("metric_%d", i) + _, err := s.DB().Exec(`INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?)`, + "run1", key, float64(i)/10.0, i) + // Allow for SQLite locking errors in concurrent test + if err != nil { + t.Logf("Concurrent insert %d had error (expected with SQLite): %v", i, err) + } + }(i) + } + wg.Wait() + + // Verify metrics were recorded (some may fail due to SQLite locking) + metrics, err := s.GetRunMetrics("run1") + require.NoError(t, err) + // Due to SQLite locking, not all may succeed, but at least some should + assert.GreaterOrEqual(t, len(metrics), 1, "At least some metrics should be recorded") +} + +// TestGetRunMetricsEmpty tests empty metrics result +func TestGetRunMetricsEmpty(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_empty_metrics.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // No metrics inserted + metrics, err := s.GetRunMetrics("nonexistent-run") + require.NoError(t, err) + assert.Empty(t, metrics) +} + +// TestGetRunParamsEmpty tests empty params result +func TestGetRunParamsEmpty(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_empty_params.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // No params inserted + params, err := s.GetRunParams("nonexistent-run") + require.NoError(t, err) + assert.Empty(t, params) +} + +// TestMarkRunSyncedNonexistent tests marking nonexistent run +func TestMarkRunSyncedNonexistent(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "test_sync_nonexistent.db") + s, err := store.Open(dbPath) + require.NoError(t, err) + defer s.Close() + + // Should not error even if run doesn't exist + err = s.MarkRunSynced("nonexistent-run") + require.NoError(t, err) }