package store_test

import (
	"path/filepath"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/store"
)

// Seed statements shared by the tests below. Every test opens its own fresh
// database (see testDBPath), so these fixed IDs never collide between tests
// or with leftovers from earlier runs.
const (
	insertExperimentSQL = `
		INSERT INTO ml_experiments (experiment_id, name)
		VALUES ('exp1', 'Test Experiment');
	`
	insertUnsyncedRunSQL = `
		INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
		VALUES ('run1', 'exp1', 'Test Run', 'FINISHED', 0);
	`
	insertSyncedRunSQL = `
		INSERT INTO ml_runs (run_id, experiment_id, name, status, synced)
		VALUES ('run2', 'exp1', 'Synced Run', 'FINISHED', 1);
	`
)

// testDBPath returns a database path inside a per-test temporary directory.
//
// The testing package deletes that directory, including any -wal/-shm side
// files created next to the database, after the test function and its
// deferred calls (such as s.Close) have returned. This replaces hard-coded
// /tmp paths, which are not portable and are shared between concurrent or
// interrupted test runs.
func testDBPath(t *testing.T) string {
	t.Helper()
	return filepath.Join(t.TempDir(), "fetchml.db")
}

// TestOpen checks that store.Open succeeds on a fresh path and exposes a
// non-nil connection through DB().
func TestOpen(t *testing.T) {
	s, err := store.Open(testDBPath(t))
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer s.Close()

	if s.DB() == nil {
		t.Fatal("Database connection is nil")
	}
}

// TestGetUnsyncedRuns seeds one unsynced and one already-synced run and
// checks that GetUnsyncedRuns reports only the unsynced one.
func TestGetUnsyncedRuns(t *testing.T) {
	s, err := store.Open(testDBPath(t))
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer s.Close()

	// Seed directly through the exported DB() handle.
	if _, err := s.DB().Exec(insertExperimentSQL); err != nil {
		t.Fatalf("Failed to insert experiment: %v", err)
	}
	if _, err := s.DB().Exec(insertUnsyncedRunSQL); err != nil {
		t.Fatalf("Failed to insert run: %v", err)
	}
	// Without a synced row, a query that ignored the synced flag would
	// still pass this test.
	if _, err := s.DB().Exec(insertSyncedRunSQL); err != nil {
		t.Fatalf("Failed to insert synced run: %v", err)
	}

	runs, err := s.GetUnsyncedRuns()
	if err != nil {
		t.Fatalf("Failed to get unsynced runs: %v", err)
	}
	if len(runs) != 1 {
		t.Fatalf("Expected 1 unsynced run, got %d", len(runs))
	}
	if runs[0].RunID != "run1" {
		t.Fatalf("Expected run_id %q, got %q", "run1", runs[0].RunID)
	}
}

// TestMarkRunSynced checks that MarkRunSynced sets the synced column of the
// given run to 1.
func TestMarkRunSynced(t *testing.T) {
	s, err := store.Open(testDBPath(t))
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer s.Close()

	if _, err := s.DB().Exec(insertExperimentSQL); err != nil {
		t.Fatalf("Failed to insert experiment: %v", err)
	}
	if _, err := s.DB().Exec(insertUnsyncedRunSQL); err != nil {
		t.Fatalf("Failed to insert run: %v", err)
	}

	const runID = "run1"
	if err := s.MarkRunSynced(runID); err != nil {
		t.Fatalf("Failed to mark run as synced: %v", err)
	}

	// Read the column back through the exported DB() handle with a bound
	// parameter instead of an inlined literal.
	var synced int
	err = s.DB().QueryRow("SELECT synced FROM ml_runs WHERE run_id = ?", runID).Scan(&synced)
	if err != nil {
		t.Fatalf("Failed to query run: %v", err)
	}
	if synced != 1 {
		t.Fatalf("Expected synced=1, got %d", synced)
	}
}