// Package storage provides database abstraction and job management.
|
|
package storage
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"
)
|
|
|
|
type Experiment struct {
|
|
ID string `json:"id"`
|
|
Name string `json:"name"`
|
|
Description string `json:"description,omitempty"`
|
|
Status string `json:"status"`
|
|
UserID string `json:"user_id,omitempty"`
|
|
WorkspaceID string `json:"workspace_id,omitempty"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at"`
|
|
}
|
|
|
|
type ExperimentEnvironment struct {
|
|
PythonVersion string `json:"python_version"`
|
|
CUDAVersion string `json:"cuda_version,omitempty"`
|
|
SystemOS string `json:"system_os"`
|
|
SystemArch string `json:"system_arch"`
|
|
Hostname string `json:"hostname"`
|
|
RequirementsHash string `json:"requirements_hash"`
|
|
CondaEnvHash string `json:"conda_env_hash,omitempty"`
|
|
Dependencies json.RawMessage `json:"dependencies,omitempty"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
}
|
|
|
|
type ExperimentGitInfo struct {
|
|
CommitSHA string `json:"commit_sha"`
|
|
Branch string `json:"branch"`
|
|
RemoteURL string `json:"remote_url"`
|
|
IsDirty bool `json:"is_dirty"`
|
|
DiffPatch string `json:"diff_patch,omitempty"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
}
|
|
|
|
type ExperimentSeeds struct {
|
|
Numpy *int64 `json:"numpy_seed,omitempty"`
|
|
Torch *int64 `json:"torch_seed,omitempty"`
|
|
TensorFlow *int64 `json:"tensorflow_seed,omitempty"`
|
|
Random *int64 `json:"random_seed,omitempty"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
}
|
|
|
|
type Dataset struct {
|
|
Name string `json:"name"`
|
|
URL string `json:"url"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at"`
|
|
}
|
|
|
|
type ExperimentWithMetadata struct {
|
|
Experiment Experiment `json:"experiment"`
|
|
Environment *ExperimentEnvironment `json:"environment,omitempty"`
|
|
GitInfo *ExperimentGitInfo `json:"git_info,omitempty"`
|
|
Seeds *ExperimentSeeds `json:"seeds,omitempty"`
|
|
}
|
|
|
|
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {
|
|
if exp == nil {
|
|
return fmt.Errorf("experiment is nil")
|
|
}
|
|
if exp.ID == "" {
|
|
return fmt.Errorf("experiment id is required")
|
|
}
|
|
if exp.Name == "" {
|
|
return fmt.Errorf("experiment name is required")
|
|
}
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
name = excluded.name,
|
|
description = excluded.description,
|
|
status = excluded.status,
|
|
user_id = excluded.user_id,
|
|
workspace_id = excluded.workspace_id`
|
|
} else {
|
|
query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
ON CONFLICT (id) DO UPDATE SET
|
|
name = EXCLUDED.name,
|
|
description = EXCLUDED.description,
|
|
status = EXCLUDED.status,
|
|
user_id = EXCLUDED.user_id,
|
|
workspace_id = EXCLUDED.workspace_id`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(
|
|
ctx,
|
|
query,
|
|
exp.ID,
|
|
exp.Name,
|
|
exp.Description,
|
|
exp.Status,
|
|
exp.UserID,
|
|
exp.WorkspaceID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upsert experiment: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) UpsertExperimentEnvironment(
|
|
ctx context.Context,
|
|
experimentID string,
|
|
env *ExperimentEnvironment,
|
|
) error {
|
|
if experimentID == "" {
|
|
return fmt.Errorf("experiment id is required")
|
|
}
|
|
if env == nil {
|
|
return fmt.Errorf("environment is nil")
|
|
}
|
|
|
|
deps := ""
|
|
if len(env.Dependencies) > 0 {
|
|
deps = string(env.Dependencies)
|
|
}
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT INTO experiment_environments
|
|
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
|
|
requirements_hash, conda_env_hash, dependencies)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(experiment_id) DO UPDATE SET
|
|
python_version = excluded.python_version,
|
|
cuda_version = excluded.cuda_version,
|
|
system_os = excluded.system_os,
|
|
system_arch = excluded.system_arch,
|
|
hostname = excluded.hostname,
|
|
requirements_hash = excluded.requirements_hash,
|
|
conda_env_hash = excluded.conda_env_hash,
|
|
dependencies = excluded.dependencies`
|
|
} else {
|
|
query = `INSERT INTO experiment_environments
|
|
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
|
|
requirements_hash, conda_env_hash, dependencies)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
ON CONFLICT (experiment_id) DO UPDATE SET
|
|
python_version = EXCLUDED.python_version,
|
|
cuda_version = EXCLUDED.cuda_version,
|
|
system_os = EXCLUDED.system_os,
|
|
system_arch = EXCLUDED.system_arch,
|
|
hostname = EXCLUDED.hostname,
|
|
requirements_hash = EXCLUDED.requirements_hash,
|
|
conda_env_hash = EXCLUDED.conda_env_hash,
|
|
dependencies = EXCLUDED.dependencies`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(
|
|
ctx,
|
|
query,
|
|
experimentID,
|
|
env.PythonVersion,
|
|
env.CUDAVersion,
|
|
env.SystemOS,
|
|
env.SystemArch,
|
|
env.Hostname,
|
|
env.RequirementsHash,
|
|
env.CondaEnvHash,
|
|
deps,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upsert experiment environment: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) UpsertExperimentGitInfo(
|
|
ctx context.Context,
|
|
experimentID string,
|
|
info *ExperimentGitInfo,
|
|
) error {
|
|
if experimentID == "" {
|
|
return fmt.Errorf("experiment id is required")
|
|
}
|
|
if info == nil {
|
|
return fmt.Errorf("git info is nil")
|
|
}
|
|
|
|
isDirty := 0
|
|
if info.IsDirty {
|
|
isDirty = 1
|
|
}
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT INTO experiment_git_info
|
|
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(experiment_id) DO UPDATE SET
|
|
commit_sha = excluded.commit_sha,
|
|
branch = excluded.branch,
|
|
remote_url = excluded.remote_url,
|
|
is_dirty = excluded.is_dirty,
|
|
diff_patch = excluded.diff_patch`
|
|
} else {
|
|
query = `INSERT INTO experiment_git_info
|
|
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
ON CONFLICT (experiment_id) DO UPDATE SET
|
|
commit_sha = EXCLUDED.commit_sha,
|
|
branch = EXCLUDED.branch,
|
|
remote_url = EXCLUDED.remote_url,
|
|
is_dirty = EXCLUDED.is_dirty,
|
|
diff_patch = EXCLUDED.diff_patch`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(
|
|
ctx,
|
|
query,
|
|
experimentID,
|
|
info.CommitSHA,
|
|
info.Branch,
|
|
info.RemoteURL,
|
|
isDirty,
|
|
info.DiffPatch,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upsert experiment git info: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) UpsertExperimentSeeds(
|
|
ctx context.Context,
|
|
experimentID string,
|
|
seeds *ExperimentSeeds,
|
|
) error {
|
|
if experimentID == "" {
|
|
return fmt.Errorf("experiment id is required")
|
|
}
|
|
if seeds == nil {
|
|
return fmt.Errorf("seeds is nil")
|
|
}
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT INTO experiment_seeds
|
|
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
ON CONFLICT(experiment_id) DO UPDATE SET
|
|
numpy_seed = excluded.numpy_seed,
|
|
torch_seed = excluded.torch_seed,
|
|
tensorflow_seed = excluded.tensorflow_seed,
|
|
random_seed = excluded.random_seed`
|
|
} else {
|
|
query = `INSERT INTO experiment_seeds
|
|
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
ON CONFLICT (experiment_id) DO UPDATE SET
|
|
numpy_seed = EXCLUDED.numpy_seed,
|
|
torch_seed = EXCLUDED.torch_seed,
|
|
tensorflow_seed = EXCLUDED.tensorflow_seed,
|
|
random_seed = EXCLUDED.random_seed`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(
|
|
ctx,
|
|
query,
|
|
experimentID,
|
|
seeds.Numpy,
|
|
seeds.Torch,
|
|
seeds.TensorFlow,
|
|
seeds.Random,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upsert experiment seeds: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) GetExperimentWithMetadata(
|
|
ctx context.Context,
|
|
experimentID string,
|
|
) (*ExperimentWithMetadata, error) {
|
|
if experimentID == "" {
|
|
return nil, fmt.Errorf("experiment id is required")
|
|
}
|
|
|
|
var exp Experiment
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT id, name,
|
|
COALESCE(description, ''), COALESCE(status, ''),
|
|
COALESCE(user_id, ''), COALESCE(workspace_id, ''),
|
|
created_at, updated_at
|
|
FROM experiments WHERE id = ?`
|
|
} else {
|
|
query = `SELECT id, name,
|
|
COALESCE(description, ''), COALESCE(status, ''),
|
|
COALESCE(user_id, ''), COALESCE(workspace_id, ''),
|
|
created_at, updated_at
|
|
FROM experiments WHERE id = $1`
|
|
}
|
|
|
|
err := db.conn.QueryRowContext(ctx, query, experimentID).Scan(
|
|
&exp.ID,
|
|
&exp.Name,
|
|
&exp.Description,
|
|
&exp.Status,
|
|
&exp.UserID,
|
|
&exp.WorkspaceID,
|
|
&exp.CreatedAt,
|
|
&exp.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get experiment: %w", err)
|
|
}
|
|
|
|
result := &ExperimentWithMetadata{Experiment: exp}
|
|
|
|
var env ExperimentEnvironment
|
|
var envDeps sql.NullString
|
|
var envQuery string
|
|
if db.dbType == DBTypeSQLite {
|
|
envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
|
|
COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
|
|
COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
|
|
dependencies, created_at
|
|
FROM experiment_environments WHERE experiment_id = ?`
|
|
} else {
|
|
envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
|
|
COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
|
|
COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
|
|
dependencies, created_at
|
|
FROM experiment_environments WHERE experiment_id = $1`
|
|
}
|
|
if err := db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan(
|
|
&env.PythonVersion,
|
|
&env.CUDAVersion,
|
|
&env.SystemOS,
|
|
&env.SystemArch,
|
|
&env.Hostname,
|
|
&env.RequirementsHash,
|
|
&env.CondaEnvHash,
|
|
&envDeps,
|
|
&env.CreatedAt,
|
|
); err == nil {
|
|
if envDeps.Valid && envDeps.String != "" {
|
|
env.Dependencies = json.RawMessage(envDeps.String)
|
|
}
|
|
result.Environment = &env
|
|
}
|
|
|
|
var git ExperimentGitInfo
|
|
var gitDirty sql.NullInt64
|
|
var gitQuery string
|
|
if db.dbType == DBTypeSQLite {
|
|
gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
|
|
COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
|
|
COALESCE(diff_patch, ''), created_at
|
|
FROM experiment_git_info WHERE experiment_id = ?`
|
|
} else {
|
|
gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
|
|
COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
|
|
COALESCE(diff_patch, ''), created_at
|
|
FROM experiment_git_info WHERE experiment_id = $1`
|
|
}
|
|
if err := db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan(
|
|
&git.CommitSHA,
|
|
&git.Branch,
|
|
&git.RemoteURL,
|
|
&gitDirty,
|
|
&git.DiffPatch,
|
|
&git.CreatedAt,
|
|
); err == nil {
|
|
git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0
|
|
result.GitInfo = &git
|
|
}
|
|
|
|
var seeds ExperimentSeeds
|
|
var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64
|
|
var seedsQuery string
|
|
if db.dbType == DBTypeSQLite {
|
|
seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
|
|
FROM experiment_seeds WHERE experiment_id = ?`
|
|
} else {
|
|
seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
|
|
FROM experiment_seeds WHERE experiment_id = $1`
|
|
}
|
|
if err := db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan(
|
|
&numpySeed,
|
|
&torchSeed,
|
|
&tfSeed,
|
|
&randSeed,
|
|
&seeds.CreatedAt,
|
|
); err == nil {
|
|
if numpySeed.Valid {
|
|
v := numpySeed.Int64
|
|
seeds.Numpy = &v
|
|
}
|
|
if torchSeed.Valid {
|
|
v := torchSeed.Int64
|
|
seeds.Torch = &v
|
|
}
|
|
if tfSeed.Valid {
|
|
v := tfSeed.Int64
|
|
seeds.TensorFlow = &v
|
|
}
|
|
if randSeed.Valid {
|
|
v := randSeed.Int64
|
|
seeds.Random = &v
|
|
}
|
|
result.Seeds = &seeds
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error {
|
|
if ds == nil {
|
|
return fmt.Errorf("dataset is nil")
|
|
}
|
|
if ds.Name == "" {
|
|
return fmt.Errorf("dataset name is required")
|
|
}
|
|
if ds.URL == "" {
|
|
return fmt.Errorf("dataset url is required")
|
|
}
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT INTO datasets (name, url)
|
|
VALUES (?, ?)
|
|
ON CONFLICT(name) DO UPDATE SET
|
|
url = excluded.url`
|
|
} else {
|
|
query = `INSERT INTO datasets (name, url)
|
|
VALUES ($1, $2)
|
|
ON CONFLICT (name) DO UPDATE SET
|
|
url = EXCLUDED.url`
|
|
}
|
|
|
|
if _, err := db.conn.ExecContext(ctx, query, ds.Name, ds.URL); err != nil {
|
|
return fmt.Errorf("failed to upsert dataset: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) {
|
|
if name == "" {
|
|
return nil, fmt.Errorf("dataset name is required")
|
|
}
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?`
|
|
} else {
|
|
query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1`
|
|
}
|
|
|
|
var ds Dataset
|
|
if err := db.conn.QueryRowContext(ctx, query, name).Scan(
|
|
&ds.Name,
|
|
&ds.URL,
|
|
&ds.CreatedAt,
|
|
&ds.UpdatedAt,
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, err
|
|
}
|
|
return nil, fmt.Errorf("failed to get dataset: %w", err)
|
|
}
|
|
return &ds, nil
|
|
}
|
|
|
|
func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) {
|
|
query := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC`
|
|
var args []interface{}
|
|
if limit > 0 {
|
|
if db.dbType == DBTypeSQLite {
|
|
query += " LIMIT ?"
|
|
} else {
|
|
query += " LIMIT $1"
|
|
}
|
|
args = append(args, limit)
|
|
}
|
|
|
|
rows, err := db.conn.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list datasets: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var out []*Dataset
|
|
for rows.Next() {
|
|
var ds Dataset
|
|
if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("failed to scan dataset: %w", err)
|
|
}
|
|
out = append(out, &ds)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating datasets: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) {
|
|
if term == "" {
|
|
return []*Dataset{}, nil
|
|
}
|
|
|
|
// Escape %/_ for LIKE and use parameterized query.
|
|
escaped := strings.ReplaceAll(term, "\\", "\\\\")
|
|
escaped = strings.ReplaceAll(escaped, "%", "\\%")
|
|
escaped = strings.ReplaceAll(escaped, "_", "\\_")
|
|
pattern := "%" + escaped + "%"
|
|
|
|
var query string
|
|
var args []interface{}
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT name, url, created_at, updated_at FROM datasets
|
|
WHERE name LIKE ? ESCAPE '\'
|
|
ORDER BY name ASC`
|
|
args = append(args, pattern)
|
|
if limit > 0 {
|
|
query += " LIMIT ?"
|
|
args = append(args, limit)
|
|
}
|
|
} else {
|
|
query = `SELECT name, url, created_at, updated_at FROM datasets
|
|
WHERE name LIKE $1 ESCAPE '\'
|
|
ORDER BY name ASC`
|
|
args = append(args, pattern)
|
|
if limit > 0 {
|
|
query += " LIMIT $2"
|
|
args = append(args, limit)
|
|
}
|
|
}
|
|
|
|
rows, err := db.conn.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search datasets: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var out []*Dataset
|
|
for rows.Next() {
|
|
var ds Dataset
|
|
if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("failed to scan dataset: %w", err)
|
|
}
|
|
out = append(out, &ds)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating datasets: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|