// Package storage provides database abstraction and job management. package storage import ( "context" "database/sql" "encoding/json" "fmt" "strings" "time" ) type Experiment struct { ID string `json:"id"` Name string `json:"name"` Description string `json:"description,omitempty"` Status string `json:"status"` UserID string `json:"user_id,omitempty"` WorkspaceID string `json:"workspace_id,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } type ExperimentEnvironment struct { PythonVersion string `json:"python_version"` CUDAVersion string `json:"cuda_version,omitempty"` SystemOS string `json:"system_os"` SystemArch string `json:"system_arch"` Hostname string `json:"hostname"` RequirementsHash string `json:"requirements_hash"` CondaEnvHash string `json:"conda_env_hash,omitempty"` Dependencies json.RawMessage `json:"dependencies,omitempty"` CreatedAt time.Time `json:"created_at"` } type ExperimentGitInfo struct { CommitSHA string `json:"commit_sha"` Branch string `json:"branch"` RemoteURL string `json:"remote_url"` IsDirty bool `json:"is_dirty"` DiffPatch string `json:"diff_patch,omitempty"` CreatedAt time.Time `json:"created_at"` } type ExperimentSeeds struct { Numpy *int64 `json:"numpy_seed,omitempty"` Torch *int64 `json:"torch_seed,omitempty"` TensorFlow *int64 `json:"tensorflow_seed,omitempty"` Random *int64 `json:"random_seed,omitempty"` CreatedAt time.Time `json:"created_at"` } type Dataset struct { Name string `json:"name"` URL string `json:"url"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } type ExperimentWithMetadata struct { Experiment Experiment `json:"experiment"` Environment *ExperimentEnvironment `json:"environment,omitempty"` GitInfo *ExperimentGitInfo `json:"git_info,omitempty"` Seeds *ExperimentSeeds `json:"seeds,omitempty"` } func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error { if exp == nil { return fmt.Errorf("experiment is nil") } if exp.ID == "" { return fmt.Errorf("experiment id is required") } if exp.Name == "" { return fmt.Errorf("experiment name is required") } var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name = excluded.name, description = excluded.description, status = excluded.status, user_id = excluded.user_id, workspace_id = excluded.workspace_id` } else { query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description, status = EXCLUDED.status, user_id = EXCLUDED.user_id, workspace_id = EXCLUDED.workspace_id` } _, err := db.conn.ExecContext( ctx, query, exp.ID, exp.Name, exp.Description, exp.Status, exp.UserID, exp.WorkspaceID, ) if err != nil { return fmt.Errorf("failed to upsert experiment: %w", err) } return nil } func (db *DB) UpsertExperimentEnvironment( ctx context.Context, experimentID string, env *ExperimentEnvironment, ) error { if experimentID == "" { return fmt.Errorf("experiment id is required") } if env == nil { return fmt.Errorf("environment is nil") } deps := "" if len(env.Dependencies) > 0 { deps = string(env.Dependencies) } var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO experiment_environments (experiment_id, python_version, cuda_version, system_os, system_arch, hostname, requirements_hash, conda_env_hash, dependencies) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(experiment_id) DO UPDATE SET python_version = excluded.python_version, cuda_version = excluded.cuda_version, system_os = excluded.system_os, system_arch = excluded.system_arch, hostname = excluded.hostname, requirements_hash = excluded.requirements_hash, conda_env_hash = excluded.conda_env_hash, dependencies = excluded.dependencies` } else { query = `INSERT INTO experiment_environments (experiment_id, python_version, cuda_version, system_os, system_arch, hostname, requirements_hash, conda_env_hash, dependencies) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (experiment_id) DO UPDATE SET python_version = EXCLUDED.python_version, cuda_version = EXCLUDED.cuda_version, system_os = EXCLUDED.system_os, system_arch = EXCLUDED.system_arch, hostname = EXCLUDED.hostname, requirements_hash = EXCLUDED.requirements_hash, conda_env_hash = EXCLUDED.conda_env_hash, dependencies = EXCLUDED.dependencies` } _, err := db.conn.ExecContext( ctx, query, experimentID, env.PythonVersion, env.CUDAVersion, env.SystemOS, env.SystemArch, env.Hostname, env.RequirementsHash, env.CondaEnvHash, deps, ) if err != nil { return fmt.Errorf("failed to upsert experiment environment: %w", err) } return nil } func (db *DB) UpsertExperimentGitInfo( ctx context.Context, experimentID string, info *ExperimentGitInfo, ) error { if experimentID == "" { return fmt.Errorf("experiment id is required") } if info == nil { return fmt.Errorf("git info is nil") } isDirty := 0 if info.IsDirty { isDirty = 1 } var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO experiment_git_info (experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(experiment_id) DO UPDATE SET commit_sha = excluded.commit_sha, branch = excluded.branch, remote_url = excluded.remote_url, is_dirty = excluded.is_dirty, diff_patch = excluded.diff_patch` } else { query = `INSERT INTO experiment_git_info (experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (experiment_id) DO UPDATE SET commit_sha = EXCLUDED.commit_sha, branch = EXCLUDED.branch, remote_url = EXCLUDED.remote_url, is_dirty = EXCLUDED.is_dirty, diff_patch = EXCLUDED.diff_patch` } _, err := db.conn.ExecContext( ctx, query, experimentID, info.CommitSHA, info.Branch, info.RemoteURL, isDirty, info.DiffPatch, ) if err != nil { return fmt.Errorf("failed to upsert experiment git info: %w", err) } return nil } func (db *DB) UpsertExperimentSeeds( ctx context.Context, experimentID string, seeds *ExperimentSeeds, ) error { if experimentID == "" { return fmt.Errorf("experiment id is required") } if seeds == nil { return fmt.Errorf("seeds is nil") } var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO experiment_seeds (experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed) VALUES (?, ?, ?, ?, ?) ON CONFLICT(experiment_id) DO UPDATE SET numpy_seed = excluded.numpy_seed, torch_seed = excluded.torch_seed, tensorflow_seed = excluded.tensorflow_seed, random_seed = excluded.random_seed` } else { query = `INSERT INTO experiment_seeds (experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (experiment_id) DO UPDATE SET numpy_seed = EXCLUDED.numpy_seed, torch_seed = EXCLUDED.torch_seed, tensorflow_seed = EXCLUDED.tensorflow_seed, random_seed = EXCLUDED.random_seed` } _, err := db.conn.ExecContext( ctx, query, experimentID, seeds.Numpy, seeds.Torch, seeds.TensorFlow, seeds.Random, ) if err != nil { return fmt.Errorf("failed to upsert experiment seeds: %w", err) } return nil } func (db *DB) GetExperimentWithMetadata( ctx context.Context, experimentID string, ) (*ExperimentWithMetadata, error) { if experimentID == "" { return nil, fmt.Errorf("experiment id is required") } var exp Experiment var query string if db.dbType == DBTypeSQLite { query = `SELECT id, name, COALESCE(description, ''), COALESCE(status, ''), COALESCE(user_id, ''), COALESCE(workspace_id, ''), created_at, updated_at FROM experiments WHERE id = ?` } else { query = `SELECT id, name, COALESCE(description, ''), COALESCE(status, ''), COALESCE(user_id, ''), COALESCE(workspace_id, ''), created_at, updated_at FROM experiments WHERE id = $1` } err := db.conn.QueryRowContext(ctx, query, experimentID).Scan( &exp.ID, &exp.Name, &exp.Description, &exp.Status, &exp.UserID, &exp.WorkspaceID, &exp.CreatedAt, &exp.UpdatedAt, ) if err != nil { return nil, fmt.Errorf("failed to get experiment: %w", err) } result := &ExperimentWithMetadata{Experiment: exp} var env ExperimentEnvironment var envDeps sql.NullString var envQuery string if db.dbType == DBTypeSQLite { envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''), COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''), COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''), dependencies, created_at FROM experiment_environments WHERE experiment_id = ?` } else { envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''), COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''), COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''), dependencies, created_at FROM experiment_environments WHERE experiment_id = $1` } if err := db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan( &env.PythonVersion, &env.CUDAVersion, &env.SystemOS, &env.SystemArch, &env.Hostname, &env.RequirementsHash, &env.CondaEnvHash, &envDeps, &env.CreatedAt, ); err == nil { if envDeps.Valid && envDeps.String != "" { env.Dependencies = json.RawMessage(envDeps.String) } result.Environment = &env } var git ExperimentGitInfo var gitDirty sql.NullInt64 var gitQuery string if db.dbType == DBTypeSQLite { gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''), COALESCE(remote_url, ''), COALESCE(is_dirty, 0), COALESCE(diff_patch, ''), created_at FROM experiment_git_info WHERE experiment_id = ?` } else { gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''), COALESCE(remote_url, ''), COALESCE(is_dirty, 0), COALESCE(diff_patch, ''), created_at FROM experiment_git_info WHERE experiment_id = $1` } if err := db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan( &git.CommitSHA, &git.Branch, &git.RemoteURL, &gitDirty, &git.DiffPatch, &git.CreatedAt, ); err == nil { git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0 result.GitInfo = &git } var seeds ExperimentSeeds var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64 var seedsQuery string if db.dbType == DBTypeSQLite { seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at FROM experiment_seeds WHERE experiment_id = ?` } else { seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at FROM experiment_seeds WHERE experiment_id = $1` } if err := db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan( &numpySeed, &torchSeed, &tfSeed, &randSeed, &seeds.CreatedAt, ); err == nil { if numpySeed.Valid { v := numpySeed.Int64 seeds.Numpy = &v } if torchSeed.Valid { v := torchSeed.Int64 seeds.Torch = &v } if tfSeed.Valid { v := tfSeed.Int64 seeds.TensorFlow = &v } if randSeed.Valid { v := randSeed.Int64 seeds.Random = &v } result.Seeds = &seeds } return result, nil } func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error { if ds == nil { return fmt.Errorf("dataset is nil") } if ds.Name == "" { return fmt.Errorf("dataset name is required") } if ds.URL == "" { return fmt.Errorf("dataset url is required") } var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO datasets (name, url) VALUES (?, ?) ON CONFLICT(name) DO UPDATE SET url = excluded.url` } else { query = `INSERT INTO datasets (name, url) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET url = EXCLUDED.url` } if _, err := db.conn.ExecContext(ctx, query, ds.Name, ds.URL); err != nil { return fmt.Errorf("failed to upsert dataset: %w", err) } return nil } func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) { if name == "" { return nil, fmt.Errorf("dataset name is required") } var query string if db.dbType == DBTypeSQLite { query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?` } else { query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1` } var ds Dataset if err := db.conn.QueryRowContext(ctx, query, name).Scan( &ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt, ); err != nil { if err == sql.ErrNoRows { return nil, err } return nil, fmt.Errorf("failed to get dataset: %w", err) } return &ds, nil } func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) { query := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC` var args []interface{} if limit > 0 { if db.dbType == DBTypeSQLite { query += " LIMIT ?" } else { query += " LIMIT $1" } args = append(args, limit) } rows, err := db.conn.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("failed to list datasets: %w", err) } defer func() { _ = rows.Close() }() var out []*Dataset for rows.Next() { var ds Dataset if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil { return nil, fmt.Errorf("failed to scan dataset: %w", err) } out = append(out, &ds) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating datasets: %w", err) } return out, nil } func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) { if term == "" { return []*Dataset{}, nil } // Escape %/_ for LIKE and use parameterized query. escaped := strings.ReplaceAll(term, "\\", "\\\\") escaped = strings.ReplaceAll(escaped, "%", "\\%") escaped = strings.ReplaceAll(escaped, "_", "\\_") pattern := "%" + escaped + "%" var query string var args []interface{} if db.dbType == DBTypeSQLite { query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name LIKE ? ESCAPE '\' ORDER BY name ASC` args = append(args, pattern) if limit > 0 { query += " LIMIT ?" args = append(args, limit) } } else { query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name LIKE $1 ESCAPE '\' ORDER BY name ASC` args = append(args, pattern) if limit > 0 { query += " LIMIT $2" args = append(args, limit) } } rows, err := db.conn.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("failed to search datasets: %w", err) } defer func() { _ = rows.Close() }() var out []*Dataset for rows.Next() { var ds Dataset if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil { return nil, fmt.Errorf("failed to scan dataset: %w", err) } out = append(out, &ds) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating datasets: %w", err) } return out, nil }