fetch_ml/tests/unit/storage/experiment_metadata_test.go

122 lines
3.2 KiB
Go

package storage
import (
	"context"
	"encoding/json"
	"path/filepath"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
)
// TestExperimentMetadataRoundTripSQLite writes an experiment plus its
// environment, git info and RNG seeds into a freshly initialized SQLite
// database via the Upsert* methods, then checks that
// GetExperimentWithMetadata returns the same values.
func TestExperimentMetadataRoundTripSQLite(t *testing.T) {
	t.Parallel()

	schema, err := storage.SchemaForDBType(storage.DBTypeSQLite)
	if err != nil {
		t.Fatalf("SchemaForDBType(sqlite) failed: %v", err)
	}

	// Each test gets its own database file inside a per-test temp dir.
	dbPath := filepath.Join(t.TempDir(), "test.sqlite")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("NewDBFromPath failed: %v", err)
	}
	// Cleanups run LIFO, so this runs before t.TempDir removes the directory.
	// A failing Close is reported rather than silently discarded.
	t.Cleanup(func() {
		if err := db.Close(); err != nil {
			t.Errorf("Close failed: %v", err)
		}
	})
	if err := db.Initialize(schema); err != nil {
		t.Fatalf("Initialize failed: %v", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	exp := &storage.Experiment{
		ID:          "exp-1",
		Name:        "train-resnet",
		Description: "test run",
		Status:      "pending",
		UserID:      "alice",
		WorkspaceID: "ws-1",
	}
	if err := db.UpsertExperiment(ctx, exp); err != nil {
		t.Fatalf("UpsertExperiment failed: %v", err)
	}

	depsJSON, err := json.Marshal([]map[string]string{{"name": "numpy", "version": "1.26.0", "source": "pip"}})
	if err != nil {
		t.Fatalf("Marshal deps failed: %v", err)
	}
	env := &storage.ExperimentEnvironment{
		PythonVersion:    "Python 3.12.0",
		CUDAVersion:      "",
		SystemOS:         "darwin",
		SystemArch:       "arm64",
		Hostname:         "host",
		RequirementsHash: "abc123",
		Dependencies:     depsJSON,
	}
	if err := db.UpsertExperimentEnvironment(ctx, exp.ID, env); err != nil {
		t.Fatalf("UpsertExperimentEnvironment failed: %v", err)
	}

	git := &storage.ExperimentGitInfo{
		CommitSHA: "deadbeef",
		Branch:    "main",
		RemoteURL: "git@example.com:repo.git",
		IsDirty:   true,
		DiffPatch: "diff --git ...",
	}
	if err := db.UpsertExperimentGitInfo(ctx, exp.ID, git); err != nil {
		t.Fatalf("UpsertExperimentGitInfo failed: %v", err)
	}

	numpySeed := int64(123)
	randSeed := int64(999)
	seeds := &storage.ExperimentSeeds{
		Numpy:  &numpySeed,
		Random: &randSeed,
	}
	if err := db.UpsertExperimentSeeds(ctx, exp.ID, seeds); err != nil {
		t.Fatalf("UpsertExperimentSeeds failed: %v", err)
	}

	got, err := db.GetExperimentWithMetadata(ctx, exp.ID)
	if err != nil {
		t.Fatalf("GetExperimentWithMetadata failed: %v", err)
	}

	// Field mismatches use Errorf so every difference is reported in one run;
	// Fatal is reserved for nil sections that would otherwise cause a panic.
	if got.Experiment.ID != exp.ID {
		t.Errorf("expected id %q, got %q", exp.ID, got.Experiment.ID)
	}
	if got.Experiment.Name != exp.Name {
		t.Errorf("expected name %q, got %q", exp.Name, got.Experiment.Name)
	}
	if got.Experiment.UserID != exp.UserID {
		t.Errorf("expected user_id %q, got %q", exp.UserID, got.Experiment.UserID)
	}

	if got.Environment == nil {
		t.Fatal("expected environment, got nil")
	}
	if got.Environment.PythonVersion != env.PythonVersion {
		t.Errorf("expected python_version %q, got %q", env.PythonVersion, got.Environment.PythonVersion)
	}
	if got.Environment.RequirementsHash != env.RequirementsHash {
		t.Errorf("expected requirements_hash %q, got %q", env.RequirementsHash, got.Environment.RequirementsHash)
	}

	if got.GitInfo == nil {
		t.Fatal("expected git_info, got nil")
	}
	if !got.GitInfo.IsDirty {
		t.Error("expected is_dirty true, got false")
	}
	if got.GitInfo.CommitSHA != git.CommitSHA {
		t.Errorf("expected commit_sha %q, got %q", git.CommitSHA, got.GitInfo.CommitSHA)
	}
	if got.GitInfo.Branch != git.Branch {
		t.Errorf("expected branch %q, got %q", git.Branch, got.GitInfo.Branch)
	}

	if got.Seeds == nil {
		t.Fatal("expected seeds, got nil")
	}
	// Seeds are *int64: dereference for the message, otherwise %v would print
	// a pointer address instead of the stored value.
	switch {
	case got.Seeds.Numpy == nil:
		t.Errorf("expected numpy_seed %d, got nil", numpySeed)
	case *got.Seeds.Numpy != numpySeed:
		t.Errorf("expected numpy_seed %d, got %d", numpySeed, *got.Seeds.Numpy)
	}
	switch {
	case got.Seeds.Random == nil:
		t.Errorf("expected random_seed %d, got nil", randSeed)
	case *got.Seeds.Random != randSeed:
		t.Errorf("expected random_seed %d, got %d", randSeed, *got.Seeds.Random)
	}
}