Add tests for:
- Close: proper resource cleanup
- ConcurrentLogMetrics: thread-safe metric logging
- GetRunMetricsEmpty: empty result handling
- GetRunParamsEmpty: empty result handling
- MarkRunSyncedNonexistent: graceful handling of missing runs

Coverage: 75.3%
311 lines
7.8 KiB
Go
311 lines
7.8 KiB
Go
package store_test
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/store"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestOpen(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_fetchml.db")
|
|
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err, "Failed to open database")
|
|
defer s.Close()
|
|
|
|
require.NotNil(t, s.DB(), "Database connection should not be nil")
|
|
}
|
|
|
|
func TestOpenCreatesDirectory(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
nestedPath := filepath.Join(tmpDir, "deep", "nested", "test.db")
|
|
|
|
s, err := store.Open(nestedPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Verify directory was created
|
|
_, err = os.Stat(filepath.Dir(nestedPath))
|
|
require.NoError(t, err, "Parent directory should be created")
|
|
}
|
|
|
|
func TestGetUnsyncedRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_unsynced.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Insert test data
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
|
|
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0);
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
runs, err := s.GetUnsyncedRuns()
|
|
require.NoError(t, err)
|
|
require.Len(t, runs, 1, "Expected 1 unsynced run")
|
|
assert.Equal(t, "run1", runs[0].RunID)
|
|
}
|
|
|
|
func TestGetRunsByExperiment(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_by_exp.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Insert test data
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_runs (run_id, experiment_id, name, status)
|
|
VALUES ('run1', 'exp1', 'Run 1', 'FINISHED');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
runs, err := s.GetRunsByExperiment("exp1")
|
|
require.NoError(t, err)
|
|
require.Len(t, runs, 1, "Expected 1 run for experiment")
|
|
assert.Equal(t, "run1", runs[0].RunID)
|
|
|
|
// Nonexistent experiment
|
|
runs, err = s.GetRunsByExperiment("nonexistent")
|
|
require.NoError(t, err)
|
|
assert.Empty(t, runs)
|
|
}
|
|
|
|
func TestMarkRunSynced(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_sync.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Insert test data
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
|
|
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0);
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
// Mark as synced
|
|
err = s.MarkRunSynced("run1")
|
|
require.NoError(t, err)
|
|
|
|
// Verify
|
|
var synced int
|
|
err = s.DB().QueryRow("SELECT synced FROM ml_runs WHERE run_id = 'run1'").Scan(&synced)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, synced)
|
|
}
|
|
|
|
func TestGetRunMetrics(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_metrics.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Insert test data
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_runs (run_id, experiment_id, name, status)
|
|
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_metrics (run_id, key, value, step)
|
|
VALUES ('run1', 'accuracy', 0.95, 1);
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
metrics, err := s.GetRunMetrics("run1")
|
|
require.NoError(t, err)
|
|
require.Len(t, metrics, 1)
|
|
assert.Equal(t, "accuracy", metrics[0].Key)
|
|
assert.InDelta(t, 0.95, metrics[0].Value, 0.001)
|
|
}
|
|
|
|
func TestGetRunParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_params.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Insert test data
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_runs (run_id, experiment_id, name, status)
|
|
VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`
|
|
INSERT INTO ml_params (run_id, key, value)
|
|
VALUES ('run1', 'learning_rate', '0.01');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
params, err := s.GetRunParams("run1")
|
|
require.NoError(t, err)
|
|
require.Len(t, params, 1)
|
|
assert.Equal(t, "learning_rate", params[0].Key)
|
|
assert.Equal(t, "0.01", params[0].Value)
|
|
}
|
|
|
|
func TestDB(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_db.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
db := s.DB()
|
|
require.NotNil(t, db)
|
|
|
|
err = db.Ping()
|
|
require.NoError(t, err, "DB connection should be valid")
|
|
}
|
|
|
|
// TestClose tests the Close method
|
|
func TestClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_close.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
|
|
err = s.Close()
|
|
require.NoError(t, err)
|
|
|
|
// Double close should not panic
|
|
err = s.Close()
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// TestConcurrentLogMetrics tests concurrent metric logging
|
|
func TestConcurrentLogMetrics(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_concurrent.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Insert test data
|
|
_, err = s.DB().Exec(`INSERT INTO ml_experiments (experiment_id, name) VALUES ('exp1', 'Test');`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = s.DB().Exec(`INSERT INTO ml_runs (run_id, experiment_id, name, status) VALUES ('run1', 'exp1', 'Test Run', 'FINISHED');`)
|
|
require.NoError(t, err)
|
|
|
|
// Concurrent inserts via raw SQL (simulating LogMetric behavior)
|
|
// Note: SQLite may have locking contention, but should not panic
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func(i int) {
|
|
defer wg.Done()
|
|
// Use unique keys to avoid contention
|
|
key := fmt.Sprintf("metric_%d", i)
|
|
_, err := s.DB().Exec(`INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?)`,
|
|
"run1", key, float64(i)/10.0, i)
|
|
// Allow for SQLite locking errors in concurrent test
|
|
if err != nil {
|
|
t.Logf("Concurrent insert %d had error (expected with SQLite): %v", i, err)
|
|
}
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
// Verify metrics were recorded (some may fail due to SQLite locking)
|
|
metrics, err := s.GetRunMetrics("run1")
|
|
require.NoError(t, err)
|
|
// Due to SQLite locking, not all may succeed, but at least some should
|
|
assert.GreaterOrEqual(t, len(metrics), 1, "At least some metrics should be recorded")
|
|
}
|
|
|
|
// TestGetRunMetricsEmpty tests empty metrics result
|
|
func TestGetRunMetricsEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_empty_metrics.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// No metrics inserted
|
|
metrics, err := s.GetRunMetrics("nonexistent-run")
|
|
require.NoError(t, err)
|
|
assert.Empty(t, metrics)
|
|
}
|
|
|
|
// TestGetRunParamsEmpty tests empty params result
|
|
func TestGetRunParamsEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_empty_params.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// No params inserted
|
|
params, err := s.GetRunParams("nonexistent-run")
|
|
require.NoError(t, err)
|
|
assert.Empty(t, params)
|
|
}
|
|
|
|
// TestMarkRunSyncedNonexistent tests marking nonexistent run
|
|
func TestMarkRunSyncedNonexistent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dbPath := filepath.Join(t.TempDir(), "test_sync_nonexistent.db")
|
|
s, err := store.Open(dbPath)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
// Should not error even if run doesn't exist
|
|
err = s.MarkRunSynced("nonexistent-run")
|
|
require.NoError(t, err)
|
|
}
|