package storage

import (
	"context"
	"encoding/json"
	"path/filepath"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
)

// TestExperimentMetadataRoundTripSQLite writes an experiment together with its
// environment, git info and seed records into a freshly initialized SQLite
// database, then reads everything back through GetExperimentWithMetadata and
// checks that a selection of the stored fields comes back unchanged.
func TestExperimentMetadataRoundTripSQLite(t *testing.T) {
	t.Parallel()

	schema, err := storage.SchemaForDBType(storage.DBTypeSQLite)
	if err != nil {
		t.Fatalf("SchemaForDBType(sqlite) failed: %v", err)
	}

	// Each run gets its own database file inside a per-test temp directory.
	dbPath := filepath.Join(t.TempDir(), "test.sqlite")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("NewDBFromPath failed: %v", err)
	}
	// Best-effort close: the test outcome is decided by the assertions below,
	// so a close error is deliberately ignored.
	defer func() { _ = db.Close() }()

	if err := db.Initialize(schema); err != nil {
		t.Fatalf("Initialize failed: %v", err)
	}

	// Bound every database call so a hung query fails the test quickly.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	exp := &storage.Experiment{
		ID:          "exp-1",
		Name:        "train-resnet",
		Description: "test run",
		Status:      "pending",
		UserID:      "alice",
		WorkspaceID: "ws-1",
	}
	if err := db.UpsertExperiment(ctx, exp); err != nil {
		t.Fatalf("UpsertExperiment failed: %v", err)
	}

	depsJSON, err := json.Marshal([]map[string]string{
		{"name": "numpy", "version": "1.26.0", "source": "pip"},
	})
	if err != nil {
		t.Fatalf("Marshal deps failed: %v", err)
	}

	env := &storage.ExperimentEnvironment{
		PythonVersion:    "Python 3.12.0",
		CUDAVersion:      "",
		SystemOS:         "darwin",
		SystemArch:       "arm64",
		Hostname:         "host",
		RequirementsHash: "abc123",
		Dependencies:     depsJSON,
	}
	if err := db.UpsertExperimentEnvironment(ctx, exp.ID, env); err != nil {
		t.Fatalf("UpsertExperimentEnvironment failed: %v", err)
	}

	git := &storage.ExperimentGitInfo{
		CommitSHA: "deadbeef",
		Branch:    "main",
		RemoteURL: "git@example.com:repo.git",
		IsDirty:   true,
		DiffPatch: "diff --git ...",
	}
	if err := db.UpsertExperimentGitInfo(ctx, exp.ID, git); err != nil {
		t.Fatalf("UpsertExperimentGitInfo failed: %v", err)
	}

	numpySeed := int64(123)
	randSeed := int64(999)
	seeds := &storage.ExperimentSeeds{
		Numpy:  &numpySeed,
		Random: &randSeed,
	}
	if err := db.UpsertExperimentSeeds(ctx, exp.ID, seeds); err != nil {
		t.Fatalf("UpsertExperimentSeeds failed: %v", err)
	}

	got, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	if err != nil {
		t.Fatalf("GetExperimentWithMetadata failed: %v", err)
	}

	// Core experiment fields.
	if got.Experiment.ID != exp.ID {
		t.Fatalf("expected id %q, got %q", exp.ID, got.Experiment.ID)
	}
	if got.Experiment.Name != exp.Name {
		t.Fatalf("expected name %q, got %q", exp.Name, got.Experiment.Name)
	}
	if got.Experiment.UserID != exp.UserID {
		t.Fatalf("expected user_id %q, got %q", exp.UserID, got.Experiment.UserID)
	}

	// Environment record.
	if got.Environment == nil {
		t.Fatal("expected environment, got nil")
	}
	if got.Environment.PythonVersion != env.PythonVersion {
		t.Fatalf("expected python_version %q, got %q", env.PythonVersion, got.Environment.PythonVersion)
	}

	// Git record.
	if got.GitInfo == nil {
		t.Fatal("expected git_info, got nil")
	}
	if !got.GitInfo.IsDirty {
		t.Fatal("expected is_dirty true, got false")
	}

	// Seed record. The seed fields are pointers, so the nil case and the
	// wrong-value case are reported separately; the latter prints the stored
	// value rather than a pointer address.
	if got.Seeds == nil {
		t.Fatal("expected seeds, got nil")
	}
	checkSeed := func(label string, have *int64, want int64) {
		t.Helper()
		if have == nil {
			t.Fatalf("expected %s %d, got nil", label, want)
		}
		if *have != want {
			t.Fatalf("expected %s %d, got %d", label, want, *have)
		}
	}
	checkSeed("numpy_seed", got.Seeds.Numpy, numpySeed)
	checkSeed("random_seed", got.Seeds.Random, randSeed)
}