122 lines
3.2 KiB
Go
122 lines
3.2 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
)
|
|
|
|
func TestExperimentMetadataRoundTripSQLite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
schema, err := storage.SchemaForDBType(storage.DBTypeSQLite)
|
|
if err != nil {
|
|
t.Fatalf("SchemaForDBType(sqlite) failed: %v", err)
|
|
}
|
|
|
|
dbPath := t.TempDir() + "/test.sqlite"
|
|
db, err := storage.NewDBFromPath(dbPath)
|
|
if err != nil {
|
|
t.Fatalf("NewDBFromPath failed: %v", err)
|
|
}
|
|
defer func() { _ = db.Close() }()
|
|
|
|
if err := db.Initialize(schema); err != nil {
|
|
t.Fatalf("Initialize failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
exp := &storage.Experiment{
|
|
ID: "exp-1",
|
|
Name: "train-resnet",
|
|
Description: "test run",
|
|
Status: "pending",
|
|
UserID: "alice",
|
|
WorkspaceID: "ws-1",
|
|
}
|
|
if err := db.UpsertExperiment(ctx, exp); err != nil {
|
|
t.Fatalf("UpsertExperiment failed: %v", err)
|
|
}
|
|
|
|
depsJSON, err := json.Marshal([]map[string]string{{"name": "numpy", "version": "1.26.0", "source": "pip"}})
|
|
if err != nil {
|
|
t.Fatalf("Marshal deps failed: %v", err)
|
|
}
|
|
env := &storage.ExperimentEnvironment{
|
|
PythonVersion: "Python 3.12.0",
|
|
CUDAVersion: "",
|
|
SystemOS: "darwin",
|
|
SystemArch: "arm64",
|
|
Hostname: "host",
|
|
RequirementsHash: "abc123",
|
|
Dependencies: depsJSON,
|
|
}
|
|
if err := db.UpsertExperimentEnvironment(ctx, exp.ID, env); err != nil {
|
|
t.Fatalf("UpsertExperimentEnvironment failed: %v", err)
|
|
}
|
|
|
|
git := &storage.ExperimentGitInfo{
|
|
CommitSHA: "deadbeef",
|
|
Branch: "main",
|
|
RemoteURL: "git@example.com:repo.git",
|
|
IsDirty: true,
|
|
DiffPatch: "diff --git ...",
|
|
}
|
|
if err := db.UpsertExperimentGitInfo(ctx, exp.ID, git); err != nil {
|
|
t.Fatalf("UpsertExperimentGitInfo failed: %v", err)
|
|
}
|
|
|
|
numpySeed := int64(123)
|
|
randSeed := int64(999)
|
|
seeds := &storage.ExperimentSeeds{
|
|
Numpy: &numpySeed,
|
|
Random: &randSeed,
|
|
}
|
|
if err := db.UpsertExperimentSeeds(ctx, exp.ID, seeds); err != nil {
|
|
t.Fatalf("UpsertExperimentSeeds failed: %v", err)
|
|
}
|
|
|
|
got, err := db.GetExperimentWithMetadata(ctx, exp.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetExperimentWithMetadata failed: %v", err)
|
|
}
|
|
|
|
if got.Experiment.ID != exp.ID {
|
|
t.Fatalf("expected id %q, got %q", exp.ID, got.Experiment.ID)
|
|
}
|
|
if got.Experiment.Name != exp.Name {
|
|
t.Fatalf("expected name %q, got %q", exp.Name, got.Experiment.Name)
|
|
}
|
|
if got.Experiment.UserID != exp.UserID {
|
|
t.Fatalf("expected user_id %q, got %q", exp.UserID, got.Experiment.UserID)
|
|
}
|
|
|
|
if got.Environment == nil {
|
|
t.Fatalf("expected environment, got nil")
|
|
}
|
|
if got.Environment.PythonVersion != env.PythonVersion {
|
|
t.Fatalf("expected python_version %q, got %q", env.PythonVersion, got.Environment.PythonVersion)
|
|
}
|
|
|
|
if got.GitInfo == nil {
|
|
t.Fatalf("expected git_info, got nil")
|
|
}
|
|
if got.GitInfo.IsDirty != true {
|
|
t.Fatalf("expected is_dirty true, got false")
|
|
}
|
|
|
|
if got.Seeds == nil {
|
|
t.Fatalf("expected seeds, got nil")
|
|
}
|
|
if got.Seeds.Numpy == nil || *got.Seeds.Numpy != numpySeed {
|
|
t.Fatalf("expected numpy_seed %d, got %+v", numpySeed, got.Seeds.Numpy)
|
|
}
|
|
if got.Seeds.Random == nil || *got.Seeds.Random != randSeed {
|
|
t.Fatalf("expected random_seed %d, got %+v", randSeed, got.Seeds.Random)
|
|
}
|
|
}
|