fetch_ml/internal/storage/db_experiments.go

564 lines
15 KiB
Go

// Package storage provides database abstraction and job management.
package storage
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
// Experiment is a row of the experiments table.
//
// ID and Name are required by UpsertExperiment. CreatedAt and UpdatedAt are
// filled in when the row is read back (GetExperimentWithMetadata);
// UpsertExperiment does not write these two fields.
type Experiment struct {
	ID          string    `json:"id"`
	Name        string    `json:"name"`
	Description string    `json:"description,omitempty"`
	Status      string    `json:"status"`
	UserID      string    `json:"user_id,omitempty"`
	WorkspaceID string    `json:"workspace_id,omitempty"`
	CreatedAt   time.Time `json:"created_at"`
	UpdatedAt   time.Time `json:"updated_at"`
}
// ExperimentEnvironment describes the host/software environment recorded for
// an experiment. It is stored one row per experiment in
// experiment_environments (the upsert conflicts on experiment_id).
//
// Dependencies is raw JSON. It is persisted as text and handed back verbatim
// on reads, and it is left empty when the stored value is NULL or "".
// CreatedAt is only populated on reads.
type ExperimentEnvironment struct {
	PythonVersion    string          `json:"python_version"`
	CUDAVersion      string          `json:"cuda_version,omitempty"`
	SystemOS         string          `json:"system_os"`
	SystemArch       string          `json:"system_arch"`
	Hostname         string          `json:"hostname"`
	RequirementsHash string          `json:"requirements_hash"`
	CondaEnvHash     string          `json:"conda_env_hash,omitempty"`
	Dependencies     json.RawMessage `json:"dependencies,omitempty"`
	CreatedAt        time.Time       `json:"created_at"`
}
// ExperimentGitInfo records the git state associated with an experiment. It
// is stored one row per experiment in experiment_git_info (the upsert
// conflicts on experiment_id).
//
// IsDirty is persisted as the integer 1 (true) or 0 (false). CreatedAt is
// only populated on reads.
type ExperimentGitInfo struct {
	CommitSHA string    `json:"commit_sha"`
	Branch    string    `json:"branch"`
	RemoteURL string    `json:"remote_url"`
	IsDirty   bool      `json:"is_dirty"`
	DiffPatch string    `json:"diff_patch,omitempty"`
	CreatedAt time.Time `json:"created_at"`
}
// ExperimentSeeds holds the random seeds recorded for an experiment. It is
// stored one row per experiment in experiment_seeds (the upsert conflicts on
// experiment_id).
//
// A nil field means the seed was not recorded. It is bound as a nil *int64,
// which database/sql's default parameter conversion turns into NULL, and a
// NULL column is read back as nil. CreatedAt is only populated on reads.
type ExperimentSeeds struct {
	Numpy      *int64    `json:"numpy_seed,omitempty"`
	Torch      *int64    `json:"torch_seed,omitempty"`
	TensorFlow *int64    `json:"tensorflow_seed,omitempty"`
	Random     *int64    `json:"random_seed,omitempty"`
	CreatedAt  time.Time `json:"created_at"`
}
// Dataset maps a dataset name to its URL. It is a row of the datasets table,
// keyed by Name (the upsert conflicts on name).
//
// CreatedAt and UpdatedAt are only populated on reads.
type Dataset struct {
	Name      string    `json:"name"`
	URL       string    `json:"url"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
// ExperimentWithMetadata bundles an experiment with its optional metadata
// records, as assembled by GetExperimentWithMetadata.
//
// Environment, GitInfo and Seeds are nil when the corresponding record was
// not loaded. Nil sections are omitted from the JSON encoding.
type ExperimentWithMetadata struct {
	Experiment  Experiment             `json:"experiment"`
	Environment *ExperimentEnvironment `json:"environment,omitempty"`
	GitInfo     *ExperimentGitInfo     `json:"git_info,omitempty"`
	Seeds       *ExperimentSeeds       `json:"seeds,omitempty"`
}
// UpsertExperiment inserts exp into the experiments table, or overwrites the
// mutable columns of the existing row that has the same ID.
//
// exp must be non-nil with a non-empty ID and Name. exp.CreatedAt and
// exp.UpdatedAt are ignored. created_at is not written here (presumably a
// column default — confirm against the schema). updated_at is set to
// CURRENT_TIMESTAMP whenever an existing row is overwritten.
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {
	if exp == nil {
		return fmt.Errorf("experiment is nil")
	}
	if exp.ID == "" {
		return fmt.Errorf("experiment id is required")
	}
	if exp.Name == "" {
		return fmt.Errorf("experiment name is required")
	}

	// Neither SQLite nor PostgreSQL refreshes updated_at on its own, so the
	// conflict branch bumps it explicitly; otherwise it would stay frozen at
	// the creation time.
	var query string
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
name = excluded.name,
description = excluded.description,
status = excluded.status,
user_id = excluded.user_id,
workspace_id = excluded.workspace_id,
updated_at = CURRENT_TIMESTAMP`
	} else {
		query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description,
status = EXCLUDED.status,
user_id = EXCLUDED.user_id,
workspace_id = EXCLUDED.workspace_id,
updated_at = CURRENT_TIMESTAMP`
	}

	_, err := db.conn.ExecContext(
		ctx,
		query,
		exp.ID,
		exp.Name,
		exp.Description,
		exp.Status,
		exp.UserID,
		exp.WorkspaceID,
	)
	if err != nil {
		return fmt.Errorf("failed to upsert experiment: %w", err)
	}
	return nil
}
// UpsertExperimentEnvironment stores the environment recorded for
// experimentID, replacing any previous row in experiment_environments.
//
// experimentID must be non-empty and env non-nil. When env.Dependencies is
// non-empty it must be valid JSON. It is persisted as text and later returned
// as a json.RawMessage, and an invalid payload would make every JSON encoding
// of the loaded environment fail. Empty dependencies are stored as "".
// env.CreatedAt is ignored.
func (db *DB) UpsertExperimentEnvironment(
	ctx context.Context,
	experimentID string,
	env *ExperimentEnvironment,
) error {
	if experimentID == "" {
		return fmt.Errorf("experiment id is required")
	}
	if env == nil {
		return fmt.Errorf("environment is nil")
	}

	deps := ""
	if len(env.Dependencies) > 0 {
		// Reject bad JSON here rather than letting it poison later reads.
		if !json.Valid(env.Dependencies) {
			return fmt.Errorf("environment dependencies must be valid JSON")
		}
		deps = string(env.Dependencies)
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO experiment_environments
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
requirements_hash, conda_env_hash, dependencies)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(experiment_id) DO UPDATE SET
python_version = excluded.python_version,
cuda_version = excluded.cuda_version,
system_os = excluded.system_os,
system_arch = excluded.system_arch,
hostname = excluded.hostname,
requirements_hash = excluded.requirements_hash,
conda_env_hash = excluded.conda_env_hash,
dependencies = excluded.dependencies`
	} else {
		query = `INSERT INTO experiment_environments
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
requirements_hash, conda_env_hash, dependencies)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (experiment_id) DO UPDATE SET
python_version = EXCLUDED.python_version,
cuda_version = EXCLUDED.cuda_version,
system_os = EXCLUDED.system_os,
system_arch = EXCLUDED.system_arch,
hostname = EXCLUDED.hostname,
requirements_hash = EXCLUDED.requirements_hash,
conda_env_hash = EXCLUDED.conda_env_hash,
dependencies = EXCLUDED.dependencies`
	}

	_, err := db.conn.ExecContext(
		ctx,
		query,
		experimentID,
		env.PythonVersion,
		env.CUDAVersion,
		env.SystemOS,
		env.SystemArch,
		env.Hostname,
		env.RequirementsHash,
		env.CondaEnvHash,
		deps,
	)
	if err != nil {
		return fmt.Errorf("failed to upsert experiment environment: %w", err)
	}
	return nil
}
// UpsertExperimentGitInfo saves the git state for experimentID, replacing any
// row already present in experiment_git_info for that experiment.
//
// experimentID must be non-empty and info non-nil. IsDirty is written as the
// integer 1 or 0. info.CreatedAt is ignored.
func (db *DB) UpsertExperimentGitInfo(
	ctx context.Context,
	experimentID string,
	info *ExperimentGitInfo,
) error {
	switch {
	case experimentID == "":
		return fmt.Errorf("experiment id is required")
	case info == nil:
		return fmt.Errorf("git info is nil")
	}

	dirtyFlag := 0
	if info.IsDirty {
		dirtyFlag = 1
	}

	// PostgreSQL form by default; SQLite swaps in its own placeholder syntax.
	stmt := `INSERT INTO experiment_git_info
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (experiment_id) DO UPDATE SET
commit_sha = EXCLUDED.commit_sha,
branch = EXCLUDED.branch,
remote_url = EXCLUDED.remote_url,
is_dirty = EXCLUDED.is_dirty,
diff_patch = EXCLUDED.diff_patch`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO experiment_git_info
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(experiment_id) DO UPDATE SET
commit_sha = excluded.commit_sha,
branch = excluded.branch,
remote_url = excluded.remote_url,
is_dirty = excluded.is_dirty,
diff_patch = excluded.diff_patch`
	}

	values := []interface{}{
		experimentID,
		info.CommitSHA,
		info.Branch,
		info.RemoteURL,
		dirtyFlag,
		info.DiffPatch,
	}
	if _, err := db.conn.ExecContext(ctx, stmt, values...); err != nil {
		return fmt.Errorf("failed to upsert experiment git info: %w", err)
	}
	return nil
}
// UpsertExperimentSeeds saves the seeds recorded for experimentID, replacing
// any row already present in experiment_seeds for that experiment.
//
// experimentID must be non-empty and seeds non-nil. Each seed pointer is bound
// as-is, so an unset (nil) seed is stored as NULL by database/sql's default
// parameter conversion. seeds.CreatedAt is ignored.
func (db *DB) UpsertExperimentSeeds(
	ctx context.Context,
	experimentID string,
	seeds *ExperimentSeeds,
) error {
	switch {
	case experimentID == "":
		return fmt.Errorf("experiment id is required")
	case seeds == nil:
		return fmt.Errorf("seeds is nil")
	}

	// PostgreSQL form by default; SQLite swaps in its own placeholder syntax.
	stmt := `INSERT INTO experiment_seeds
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (experiment_id) DO UPDATE SET
numpy_seed = EXCLUDED.numpy_seed,
torch_seed = EXCLUDED.torch_seed,
tensorflow_seed = EXCLUDED.tensorflow_seed,
random_seed = EXCLUDED.random_seed`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO experiment_seeds
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(experiment_id) DO UPDATE SET
numpy_seed = excluded.numpy_seed,
torch_seed = excluded.torch_seed,
tensorflow_seed = excluded.tensorflow_seed,
random_seed = excluded.random_seed`
	}

	values := []interface{}{
		experimentID,
		seeds.Numpy,
		seeds.Torch,
		seeds.TensorFlow,
		seeds.Random,
	}
	if _, err := db.conn.ExecContext(ctx, stmt, values...); err != nil {
		return fmt.Errorf("failed to upsert experiment seeds: %w", err)
	}
	return nil
}
// GetExperimentWithMetadata loads the experiment identified by experimentID
// together with its environment, git and seed records.
//
// The experiment row is required: when it is missing, the returned error
// wraps sql.ErrNoRows. Each metadata record is optional, and its field in the
// result stays nil when its table has no row for the experiment. Any other
// failure while reading a record is returned instead of being dropped. This
// covers a driver error, a scan/type mismatch, or ctx being cancelled. A nil
// section in a successful result therefore always means "not recorded", never
// "could not be read".
func (db *DB) GetExperimentWithMetadata(
	ctx context.Context,
	experimentID string,
) (*ExperimentWithMetadata, error) {
	if experimentID == "" {
		return nil, fmt.Errorf("experiment id is required")
	}

	// Every lookup binds experimentID as its only argument; the backends
	// differ only in how that placeholder is spelled.
	idArg := "$1"
	if db.dbType == DBTypeSQLite {
		idArg = "?"
	}

	var exp Experiment
	expQuery := `SELECT id, name,
COALESCE(description, ''), COALESCE(status, ''),
COALESCE(user_id, ''), COALESCE(workspace_id, ''),
created_at, updated_at
FROM experiments WHERE id = ` + idArg
	if err := db.conn.QueryRowContext(ctx, expQuery, experimentID).Scan(
		&exp.ID,
		&exp.Name,
		&exp.Description,
		&exp.Status,
		&exp.UserID,
		&exp.WorkspaceID,
		&exp.CreatedAt,
		&exp.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to get experiment: %w", err)
	}
	result := &ExperimentWithMetadata{Experiment: exp}

	// Row.Scan reports "no row" as sql.ErrNoRows itself (unwrapped), so a
	// direct comparison distinguishes "not recorded" from real failures below.

	var env ExperimentEnvironment
	var envDeps sql.NullString
	envQuery := `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
dependencies, created_at
FROM experiment_environments WHERE experiment_id = ` + idArg
	err := db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan(
		&env.PythonVersion,
		&env.CUDAVersion,
		&env.SystemOS,
		&env.SystemArch,
		&env.Hostname,
		&env.RequirementsHash,
		&env.CondaEnvHash,
		&envDeps,
		&env.CreatedAt,
	)
	switch {
	case err == nil:
		// NULL and "" both mean "no dependencies recorded".
		if envDeps.Valid && envDeps.String != "" {
			env.Dependencies = json.RawMessage(envDeps.String)
		}
		result.Environment = &env
	case err != sql.ErrNoRows:
		return nil, fmt.Errorf("failed to get experiment environment: %w", err)
	}

	var git ExperimentGitInfo
	var gitDirty sql.NullInt64
	gitQuery := `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
COALESCE(diff_patch, ''), created_at
FROM experiment_git_info WHERE experiment_id = ` + idArg
	err = db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan(
		&git.CommitSHA,
		&git.Branch,
		&git.RemoteURL,
		&gitDirty,
		&git.DiffPatch,
		&git.CreatedAt,
	)
	switch {
	case err == nil:
		// is_dirty is stored as an integer flag; any non-zero value is dirty.
		git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0
		result.GitInfo = &git
	case err != sql.ErrNoRows:
		return nil, fmt.Errorf("failed to get experiment git info: %w", err)
	}

	var seeds ExperimentSeeds
	var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64
	seedsQuery := `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
FROM experiment_seeds WHERE experiment_id = ` + idArg
	err = db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan(
		&numpySeed,
		&torchSeed,
		&tfSeed,
		&randSeed,
		&seeds.CreatedAt,
	)
	switch {
	case err == nil:
		// A NULL seed column maps back to a nil pointer ("not recorded").
		seedPtr := func(n sql.NullInt64) *int64 {
			if !n.Valid {
				return nil
			}
			v := n.Int64
			return &v
		}
		seeds.Numpy = seedPtr(numpySeed)
		seeds.Torch = seedPtr(torchSeed)
		seeds.TensorFlow = seedPtr(tfSeed)
		seeds.Random = seedPtr(randSeed)
		result.Seeds = &seeds
	case err != sql.ErrNoRows:
		return nil, fmt.Errorf("failed to get experiment seeds: %w", err)
	}

	return result, nil
}
// UpsertDataset registers ds in the datasets table, or updates the URL of the
// existing dataset with the same name.
//
// ds must be non-nil with a non-empty Name and URL. ds.CreatedAt and
// ds.UpdatedAt are ignored. updated_at is set to CURRENT_TIMESTAMP whenever
// an existing row is overwritten, because neither SQLite nor PostgreSQL
// refreshes it automatically.
func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error {
	if ds == nil {
		return fmt.Errorf("dataset is nil")
	}
	if ds.Name == "" {
		return fmt.Errorf("dataset name is required")
	}
	if ds.URL == "" {
		return fmt.Errorf("dataset url is required")
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO datasets (name, url)
VALUES (?, ?)
ON CONFLICT(name) DO UPDATE SET
url = excluded.url,
updated_at = CURRENT_TIMESTAMP`
	} else {
		query = `INSERT INTO datasets (name, url)
VALUES ($1, $2)
ON CONFLICT (name) DO UPDATE SET
url = EXCLUDED.url,
updated_at = CURRENT_TIMESTAMP`
	}

	if _, err := db.conn.ExecContext(ctx, query, ds.Name, ds.URL); err != nil {
		return fmt.Errorf("failed to upsert dataset: %w", err)
	}
	return nil
}
// GetDataset returns the dataset registered under name.
//
// name must be non-empty. A missing dataset is reported by returning
// sql.ErrNoRows itself, unwrapped, so callers can compare against it
// directly. Any other failure is wrapped with context.
func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) {
	if name == "" {
		return nil, fmt.Errorf("dataset name is required")
	}

	stmt := `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?`
	}

	ds := new(Dataset)
	row := db.conn.QueryRowContext(ctx, stmt, name)
	switch err := row.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); {
	case err == sql.ErrNoRows:
		return nil, err
	case err != nil:
		return nil, fmt.Errorf("failed to get dataset: %w", err)
	}
	return ds, nil
}
// ListDatasets returns registered datasets ordered by name.
//
// A positive limit caps the number of rows returned; zero or a negative limit
// returns every dataset. The result is nil when there are no datasets.
func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) {
	var (
		stmt   = `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC`
		params []interface{}
	)
	if limit > 0 {
		placeholder := "$1"
		if db.dbType == DBTypeSQLite {
			placeholder = "?"
		}
		stmt += " LIMIT " + placeholder
		params = []interface{}{limit}
	}

	rows, err := db.conn.QueryContext(ctx, stmt, params...)
	if err != nil {
		return nil, fmt.Errorf("failed to list datasets: %w", err)
	}
	// Close errors are deliberately ignored; iteration failures surface
	// through rows.Err below.
	defer func() { _ = rows.Close() }()

	var datasets []*Dataset
	for rows.Next() {
		ds := new(Dataset)
		if scanErr := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); scanErr != nil {
			return nil, fmt.Errorf("failed to scan dataset: %w", scanErr)
		}
		datasets = append(datasets, ds)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating datasets: %w", iterErr)
	}
	return datasets, nil
}
// SearchDatasets returns datasets whose name contains term as a literal
// substring, ordered by name.
//
// The LIKE wildcards (% and _) and the backslash escape character in term
// are escaped, so they match only themselves. A positive limit caps the
// number of rows returned.
//
// The result is always a non-nil slice. It is empty both when term is "" and
// when nothing matches, so callers (e.g. JSON encoding: [] rather than null)
// see the same shape in either case.
//
// NOTE(review): by default SQLite's LIKE is case-insensitive for ASCII while
// PostgreSQL's is case-sensitive, so matching differs between backends.
func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) {
	if term == "" {
		return []*Dataset{}, nil
	}

	// Single pass: escape the escape character itself plus both wildcards.
	// The query below declares '\' as the LIKE ESCAPE character.
	escaper := strings.NewReplacer(`\`, `\\`, "%", `\%`, "_", `\_`)
	pattern := "%" + escaper.Replace(term) + "%"

	patternArg, limitArg := "$1", "$2"
	if db.dbType == DBTypeSQLite {
		patternArg, limitArg = "?", "?"
	}
	query := `SELECT name, url, created_at, updated_at FROM datasets
WHERE name LIKE ` + patternArg + ` ESCAPE '\'
ORDER BY name ASC`
	args := []interface{}{pattern}
	if limit > 0 {
		query += " LIMIT " + limitArg
		args = append(args, limit)
	}

	rows, err := db.conn.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to search datasets: %w", err)
	}
	// Close errors are deliberately ignored; iteration failures surface
	// through rows.Err below.
	defer func() { _ = rows.Close() }()

	out := []*Dataset{}
	for rows.Next() {
		var ds Dataset
		if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan dataset: %w", err)
		}
		out = append(out, &ds)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating datasets: %w", err)
	}
	return out, nil
}