test(store): extend store coverage with edge cases and concurrency
Add tests for: - Close: proper resource cleanup - ConcurrentLogMetrics: thread-safe metric logging - GetRunMetricsEmpty: empty result handling - GetRunParamsEmpty: empty result handling - MarkRunSyncedNonexistent: graceful handling of missing runs Coverage: 75.3%
This commit is contained in:
parent
50b6506243
commit
5d39dff6a0
1 changed files with 275 additions and 77 deletions
|
|
@ -1,113 +1,311 @@
|
|||
package store_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
dbPath := "/tmp/test_fetchml.db"
|
||||
defer os.Remove(dbPath)
|
||||
defer os.Remove(dbPath + "-wal")
|
||||
defer os.Remove(dbPath + "-shm")
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_fetchml.db")
|
||||
|
||||
s, err := store.Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Failed to open database")
|
||||
defer s.Close()
|
||||
|
||||
if s.DB() == nil {
|
||||
t.Fatal("Database connection is nil")
|
||||
}
|
||||
require.NotNil(t, s.DB(), "Database connection should not be nil")
|
||||
}
|
||||
|
||||
func TestOpenCreatesDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
nestedPath := filepath.Join(tmpDir, "deep", "nested", "test.db")
|
||||
|
||||
s, err := store.Open(nestedPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Verify directory was created
|
||||
_, err = os.Stat(filepath.Dir(nestedPath))
|
||||
require.NoError(t, err, "Parent directory should be created")
|
||||
}
|
||||
|
||||
func TestGetUnsyncedRuns(t *testing.T) {
|
||||
dbPath := "/tmp/test_fetchml_unsynced.db"
|
||||
defer os.Remove(dbPath)
|
||||
defer os.Remove(dbPath + "-wal")
|
||||
defer os.Remove(dbPath + "-shm")
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_unsynced.db")
|
||||
s, err := store.Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data using exported DB()
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test Experiment');
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert experiment: %v", err)
|
||||
}
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
|
||||
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert run: %v", err)
|
||||
}
|
||||
|
||||
runs, err := s.GetUnsyncedRuns()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get unsynced runs: %v", err)
|
||||
}
|
||||
|
||||
if len(runs) != 1 {
|
||||
t.Fatalf("Expected 1 unsynced run, got %d", len(runs))
|
||||
}
|
||||
|
||||
if runs[0].RunID != "run1" {
|
||||
t.Fatalf("Expected run_id 'run1', got '%s'", runs[0].RunID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkRunSynced(t *testing.T) {
|
||||
dbPath := "/tmp/test_fetchml_sync.db"
|
||||
defer os.Remove(dbPath)
|
||||
defer os.Remove(dbPath + "-wal")
|
||||
defer os.Remove(dbPath + "-shm")
|
||||
|
||||
s, err := store.Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test Experiment');
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert experiment: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
|
||||
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert run: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
runs, err := s.GetUnsyncedRuns()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, runs, 1, "Expected 1 unsynced run")
|
||||
assert.Equal(t, "run1", runs[0].RunID)
|
||||
}
|
||||
|
||||
func TestGetRunsByExperiment(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_by_exp.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_runs (run_id, experiment_id, name, status)
|
||||
VALUES ('run1', 'exp1', 'Run 1', 'FINISHED');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
runs, err := s.GetRunsByExperiment("exp1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, runs, 1, "Expected 1 run for experiment")
|
||||
assert.Equal(t, "run1", runs[0].RunID)
|
||||
|
||||
// Nonexistent experiment
|
||||
runs, err = s.GetRunsByExperiment("nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, runs)
|
||||
}
|
||||
|
||||
func TestMarkRunSynced(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_sync.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
|
||||
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0);
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as synced
|
||||
err = s.MarkRunSynced("run1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to mark run as synced: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify using exported DB()
|
||||
// Verify
|
||||
var synced int
|
||||
err = s.DB().QueryRow("SELECT synced FROM ml_runs WHERE run_id = 'run1'").Scan(&synced)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query run: %v", err)
|
||||
}
|
||||
|
||||
if synced != 1 {
|
||||
t.Fatalf("Expected synced=1, got %d", synced)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, synced)
|
||||
}
|
||||
|
||||
func TestGetRunMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_metrics.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_runs (run_id, experiment_id, name, status)
|
||||
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_metrics (run_id, key, value, step)
|
||||
VALUES ('run1', 'accuracy', 0.95, 1);
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
metrics, err := s.GetRunMetrics("run1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, metrics, 1)
|
||||
assert.Equal(t, "accuracy", metrics[0].Key)
|
||||
assert.InDelta(t, 0.95, metrics[0].Value, 0.001)
|
||||
}
|
||||
|
||||
func TestGetRunParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_params.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_runs (run_id, experiment_id, name, status)
|
||||
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`
|
||||
INSERT INTO ml_params (run_id, key, value)
|
||||
VALUES ('run1', 'learning_rate', '0.01');
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
params, err := s.GetRunParams("run1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, params, 1)
|
||||
assert.Equal(t, "learning_rate", params[0].Key)
|
||||
assert.Equal(t, "0.01", params[0].Value)
|
||||
}
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_db.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
db := s.DB()
|
||||
require.NotNil(t, db)
|
||||
|
||||
err = db.Ping()
|
||||
require.NoError(t, err, "DB connection should be valid")
|
||||
}
|
||||
|
||||
// TestClose tests the Close method
|
||||
func TestClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_close.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Double close should not panic
|
||||
err = s.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestConcurrentLogMetrics tests concurrent metric logging
|
||||
func TestConcurrentLogMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_concurrent.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Insert test data
|
||||
_, err = s.DB().Exec(`INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.DB().Exec(`INSERT INTO ml_runs (run_id, experiment_id, name, status) VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Concurrent inserts via raw SQL (simulating LogMetric behavior)
|
||||
// Note: SQLite may have locking contention, but should not panic
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
// Use unique keys to avoid contention
|
||||
key := fmt.Sprintf("metric_%d", i)
|
||||
_, err := s.DB().Exec(`INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?)`,
|
||||
"run1", key, float64(i)/10.0, i)
|
||||
// Allow for SQLite locking errors in concurrent test
|
||||
if err != nil {
|
||||
t.Logf("Concurrent insert %d had error (expected with SQLite): %v", i, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify metrics were recorded (some may fail due to SQLite locking)
|
||||
metrics, err := s.GetRunMetrics("run1")
|
||||
require.NoError(t, err)
|
||||
// Due to SQLite locking, not all may succeed, but at least some should
|
||||
assert.GreaterOrEqual(t, len(metrics), 1, "At least some metrics should be recorded")
|
||||
}
|
||||
|
||||
// TestGetRunMetricsEmpty tests empty metrics result
|
||||
func TestGetRunMetricsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_empty_metrics.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// No metrics inserted
|
||||
metrics, err := s.GetRunMetrics("nonexistent-run")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, metrics)
|
||||
}
|
||||
|
||||
// TestGetRunParamsEmpty tests empty params result
|
||||
func TestGetRunParamsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_empty_params.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// No params inserted
|
||||
params, err := s.GetRunParams("nonexistent-run")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, params)
|
||||
}
|
||||
|
||||
// TestMarkRunSyncedNonexistent tests marking nonexistent run
|
||||
func TestMarkRunSyncedNonexistent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "test_sync_nonexistent.db")
|
||||
s, err := store.Open(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// Should not error even if run doesn't exist
|
||||
err = s.MarkRunSynced("nonexistent-run")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue