fetch_ml/internal/storage/db_experiments.go
Jeremie Fraeys 2b1ef10514
test(chaos): add worker disconnect chaos test and queue improvements
Chaos testing:
- Add worker_disconnect_chaos_test.go for network partition resilience
- Test scheduler hub recovery and job reassignment scenarios

Queue layer updates:
- event_store.go: add event sourcing for queue operations
- native_queue.go: extend native queue with batch operations and indexing
2026-03-12 12:08:21 -04:00

731 lines
20 KiB
Go

// Package storage provides database abstraction and job management.
package storage
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
// Experiment is the core experiment record. UpsertExperiment writes ID, Name,
// Description, Status, UserID and WorkspaceID to the experiments table;
// GetExperimentWithMetadata reads all eight fields back.
type Experiment struct {
// CreatedAt is read from the created_at column; UpsertExperiment does not write it.
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is read from the updated_at column; UpsertExperiment does not write it.
UpdatedAt time.Time `json:"updated_at"`
// ID is the conflict key for upserts and must be non-empty.
ID string `json:"id"`
// Name must be non-empty for UpsertExperiment to accept the record.
Name string `json:"name"`
// Description is optional; a NULL column is read back as "".
Description string `json:"description,omitempty"`
// Status is stored verbatim; this file does not restrict its values.
Status string `json:"status"`
// UserID is optional; a NULL column is read back as "".
UserID string `json:"user_id,omitempty"`
// WorkspaceID is optional; a NULL column is read back as "".
WorkspaceID string `json:"workspace_id,omitempty"`
}
// ExperimentEnvironment is the per-experiment environment record kept in the
// experiment_environments table (one row per experiment_id).
type ExperimentEnvironment struct {
// CreatedAt is read from the created_at column; the upsert does not write it.
CreatedAt time.Time `json:"created_at"`
PythonVersion string `json:"python_version"`
CUDAVersion string `json:"cuda_version,omitempty"`
SystemOS string `json:"system_os"`
SystemArch string `json:"system_arch"`
Hostname string `json:"hostname"`
// RequirementsHash and CondaEnvHash are stored verbatim.
// NOTE(review): presumably digests of requirements / conda environment files — confirm with the producer.
RequirementsHash string `json:"requirements_hash"`
CondaEnvHash string `json:"conda_env_hash,omitempty"`
// Dependencies is raw JSON. It is written to the dependencies column as a
// string ("" when empty) and left nil on read when the column is NULL or "".
Dependencies json.RawMessage `json:"dependencies,omitempty"`
}
// ExperimentGitInfo is the per-experiment source-control record kept in the
// experiment_git_info table (one row per experiment_id).
type ExperimentGitInfo struct {
// CreatedAt is read from the created_at column; the upsert does not write it.
CreatedAt time.Time `json:"created_at"`
CommitSHA string `json:"commit_sha"`
Branch string `json:"branch"`
RemoteURL string `json:"remote_url"`
// DiffPatch is stored verbatim.
// NOTE(review): presumably the uncommitted diff when IsDirty is true — confirm with the producer.
DiffPatch string `json:"diff_patch,omitempty"`
// IsDirty is persisted as the integer 0 or 1 and read back as "non-zero".
IsDirty bool `json:"is_dirty"`
}
// ExperimentSeeds holds the random seeds recorded for an experiment in the
// experiment_seeds table. Each seed is a pointer so that "not recorded" can be
// told apart from a seed of 0: on read, a NULL column leaves the pointer nil.
type ExperimentSeeds struct {
// Numpy maps to the numpy_seed column.
Numpy *int64 `json:"numpy_seed,omitempty"`
// Torch maps to the torch_seed column.
Torch *int64 `json:"torch_seed,omitempty"`
// TensorFlow maps to the tensorflow_seed column.
TensorFlow *int64 `json:"tensorflow_seed,omitempty"`
// Random maps to the random_seed column.
Random *int64 `json:"random_seed,omitempty"`
// CreatedAt is read from the created_at column; the upsert does not write it.
CreatedAt time.Time `json:"created_at"`
}
// Dataset is a named dataset location stored in the datasets table.
type Dataset struct {
// CreatedAt and UpdatedAt are read from the created_at / updated_at columns;
// UpsertDataset does not write either of them.
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// Name is the conflict key for upserts and must be non-empty.
Name string `json:"name"`
// URL must be non-empty; an upsert on an existing Name replaces it.
URL string `json:"url"`
}
// ExperimentWithMetadata is the result of GetExperimentWithMetadata: the
// experiment row plus whichever optional metadata records were found for it.
type ExperimentWithMetadata struct {
// Environment is nil when GetExperimentWithMetadata did not populate it.
Environment *ExperimentEnvironment `json:"environment,omitempty"`
// GitInfo is nil when GetExperimentWithMetadata did not populate it.
GitInfo *ExperimentGitInfo `json:"git_info,omitempty"`
// Seeds is nil when GetExperimentWithMetadata did not populate it.
Seeds *ExperimentSeeds `json:"seeds,omitempty"`
// Experiment is always set on a successful lookup.
Experiment Experiment `json:"experiment"`
}
// UpsertExperiment inserts exp into the experiments table or, when a row with
// the same id already exists, overwrites that row's name, description, status,
// user_id and workspace_id.
//
// It returns an error when exp is nil, when exp.ID or exp.Name is empty, or
// when the statement fails (the underlying error is wrapped).
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {
	// Cases are evaluated top to bottom, so the nil check guards the field reads.
	switch {
	case exp == nil:
		return fmt.Errorf("experiment is nil")
	case exp.ID == "":
		return fmt.Errorf("experiment id is required")
	case exp.Name == "":
		return fmt.Errorf("experiment name is required")
	}

	// Start from the PostgreSQL form ($n placeholders) and swap in the SQLite
	// form (? placeholders) when that backend is configured.
	stmt := `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description,
status = EXCLUDED.status,
user_id = EXCLUDED.user_id,
workspace_id = EXCLUDED.workspace_id`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
name = excluded.name,
description = excluded.description,
status = excluded.status,
user_id = excluded.user_id,
workspace_id = excluded.workspace_id`
	}

	params := []any{exp.ID, exp.Name, exp.Description, exp.Status, exp.UserID, exp.WorkspaceID}
	if _, execErr := db.conn.ExecContext(ctx, stmt, params...); execErr != nil {
		return fmt.Errorf("failed to upsert experiment: %w", execErr)
	}
	return nil
}
// UpsertExperimentEnvironment stores env as the environment record of
// experimentID, replacing every column of an existing record for that
// experiment.
//
// It returns an error when experimentID is empty, when env is nil, or when the
// statement fails (the underlying error is wrapped).
func (db *DB) UpsertExperimentEnvironment(
	ctx context.Context,
	experimentID string,
	env *ExperimentEnvironment,
) error {
	switch {
	case experimentID == "":
		return fmt.Errorf("experiment id is required")
	case env == nil:
		return fmt.Errorf("environment is nil")
	}

	// Converting a nil or empty RawMessage yields "", so no length check is needed.
	depsText := string(env.Dependencies)

	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `INSERT INTO experiment_environments
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
requirements_hash, conda_env_hash, dependencies)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(experiment_id) DO UPDATE SET
python_version = excluded.python_version,
cuda_version = excluded.cuda_version,
system_os = excluded.system_os,
system_arch = excluded.system_arch,
hostname = excluded.hostname,
requirements_hash = excluded.requirements_hash,
conda_env_hash = excluded.conda_env_hash,
dependencies = excluded.dependencies`
	default:
		stmt = `INSERT INTO experiment_environments
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
requirements_hash, conda_env_hash, dependencies)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (experiment_id) DO UPDATE SET
python_version = EXCLUDED.python_version,
cuda_version = EXCLUDED.cuda_version,
system_os = EXCLUDED.system_os,
system_arch = EXCLUDED.system_arch,
hostname = EXCLUDED.hostname,
requirements_hash = EXCLUDED.requirements_hash,
conda_env_hash = EXCLUDED.conda_env_hash,
dependencies = EXCLUDED.dependencies`
	}

	if _, execErr := db.conn.ExecContext(
		ctx, stmt,
		experimentID,
		env.PythonVersion, env.CUDAVersion,
		env.SystemOS, env.SystemArch, env.Hostname,
		env.RequirementsHash, env.CondaEnvHash,
		depsText,
	); execErr != nil {
		return fmt.Errorf("failed to upsert experiment environment: %w", execErr)
	}
	return nil
}
// UpsertExperimentGitInfo stores info as the git record of experimentID,
// replacing every column of an existing record for that experiment.
//
// It returns an error when experimentID is empty, when info is nil, or when
// the statement fails (the underlying error is wrapped).
func (db *DB) UpsertExperimentGitInfo(
	ctx context.Context,
	experimentID string,
	info *ExperimentGitInfo,
) error {
	switch {
	case experimentID == "":
		return fmt.Errorf("experiment id is required")
	case info == nil:
		return fmt.Errorf("git info is nil")
	}

	// The dirty flag is bound as an integer 0/1 rather than a bool.
	var dirtyFlag int
	if info.IsDirty {
		dirtyFlag = 1
	}

	stmt := `INSERT INTO experiment_git_info
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (experiment_id) DO UPDATE SET
commit_sha = EXCLUDED.commit_sha,
branch = EXCLUDED.branch,
remote_url = EXCLUDED.remote_url,
is_dirty = EXCLUDED.is_dirty,
diff_patch = EXCLUDED.diff_patch`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO experiment_git_info
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(experiment_id) DO UPDATE SET
commit_sha = excluded.commit_sha,
branch = excluded.branch,
remote_url = excluded.remote_url,
is_dirty = excluded.is_dirty,
diff_patch = excluded.diff_patch`
	}

	params := []any{
		experimentID,
		info.CommitSHA,
		info.Branch,
		info.RemoteURL,
		dirtyFlag,
		info.DiffPatch,
	}
	if _, execErr := db.conn.ExecContext(ctx, stmt, params...); execErr != nil {
		return fmt.Errorf("failed to upsert experiment git info: %w", execErr)
	}
	return nil
}
// UpsertExperimentSeeds stores seeds as the seed record of experimentID,
// replacing all four seed columns of an existing record for that experiment.
// The seed pointers are handed to the driver as-is (a nil pointer is
// presumably stored as NULL — the read path treats NULL as "not recorded").
//
// It returns an error when experimentID is empty, when seeds is nil, or when
// the statement fails (the underlying error is wrapped).
func (db *DB) UpsertExperimentSeeds(
	ctx context.Context,
	experimentID string,
	seeds *ExperimentSeeds,
) error {
	switch {
	case experimentID == "":
		return fmt.Errorf("experiment id is required")
	case seeds == nil:
		return fmt.Errorf("seeds is nil")
	}

	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `INSERT INTO experiment_seeds
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(experiment_id) DO UPDATE SET
numpy_seed = excluded.numpy_seed,
torch_seed = excluded.torch_seed,
tensorflow_seed = excluded.tensorflow_seed,
random_seed = excluded.random_seed`
	default:
		stmt = `INSERT INTO experiment_seeds
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (experiment_id) DO UPDATE SET
numpy_seed = EXCLUDED.numpy_seed,
torch_seed = EXCLUDED.torch_seed,
tensorflow_seed = EXCLUDED.tensorflow_seed,
random_seed = EXCLUDED.random_seed`
	}

	if _, execErr := db.conn.ExecContext(
		ctx, stmt,
		experimentID, seeds.Numpy, seeds.Torch, seeds.TensorFlow, seeds.Random,
	); execErr != nil {
		return fmt.Errorf("failed to upsert experiment seeds: %w", execErr)
	}
	return nil
}
// GetExperimentWithMetadata loads the experiment identified by experimentID
// together with its optional environment, git and seed records.
//
// The experiment row is mandatory: if it cannot be read, the call fails with a
// wrapped error (sql.ErrNoRows when no such experiment exists). Each metadata
// table is optional: when it has no row for the experiment, the matching field
// of the result stays nil. Any other failure while reading metadata — a
// cancelled context, a driver error, a scan error — is returned to the caller
// rather than being reported as "no metadata", so an absent record can be told
// apart from a broken lookup.
//
// It also returns an error when experimentID is empty.
func (db *DB) GetExperimentWithMetadata(
	ctx context.Context,
	experimentID string,
) (*ExperimentWithMetadata, error) {
	if experimentID == "" {
		return nil, fmt.Errorf("experiment id is required")
	}

	// --- experiments (required) ---
	var exp Experiment
	var query string
	if db.dbType == DBTypeSQLite {
		query = `SELECT id, name,
COALESCE(description, ''), COALESCE(status, ''),
COALESCE(user_id, ''), COALESCE(workspace_id, ''),
created_at, updated_at
FROM experiments WHERE id = ?`
	} else {
		query = `SELECT id, name,
COALESCE(description, ''), COALESCE(status, ''),
COALESCE(user_id, ''), COALESCE(workspace_id, ''),
created_at, updated_at
FROM experiments WHERE id = $1`
	}
	err := db.conn.QueryRowContext(ctx, query, experimentID).Scan(
		&exp.ID,
		&exp.Name,
		&exp.Description,
		&exp.Status,
		&exp.UserID,
		&exp.WorkspaceID,
		&exp.CreatedAt,
		&exp.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to get experiment: %w", err)
	}
	result := &ExperimentWithMetadata{Experiment: exp}

	// --- experiment_environments (optional) ---
	var env ExperimentEnvironment
	var envDeps sql.NullString
	var envQuery string
	if db.dbType == DBTypeSQLite {
		envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
dependencies, created_at
FROM experiment_environments WHERE experiment_id = ?`
	} else {
		envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
dependencies, created_at
FROM experiment_environments WHERE experiment_id = $1`
	}
	err = db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan(
		&env.PythonVersion,
		&env.CUDAVersion,
		&env.SystemOS,
		&env.SystemArch,
		&env.Hostname,
		&env.RequirementsHash,
		&env.CondaEnvHash,
		&envDeps,
		&env.CreatedAt,
	)
	switch {
	case err == nil:
		// NULL and "" both mean "no dependencies recorded".
		if envDeps.Valid && envDeps.String != "" {
			env.Dependencies = json.RawMessage(envDeps.String)
		}
		result.Environment = &env
	case err != sql.ErrNoRows:
		// Compared with != (as elsewhere in this file): Row.Scan returns
		// sql.ErrNoRows itself when the query matched nothing.
		return nil, fmt.Errorf("failed to get experiment environment: %w", err)
	}

	// --- experiment_git_info (optional) ---
	var git ExperimentGitInfo
	var gitDirty sql.NullInt64
	var gitQuery string
	if db.dbType == DBTypeSQLite {
		gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
COALESCE(diff_patch, ''), created_at
FROM experiment_git_info WHERE experiment_id = ?`
	} else {
		gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
COALESCE(diff_patch, ''), created_at
FROM experiment_git_info WHERE experiment_id = $1`
	}
	err = db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan(
		&git.CommitSHA,
		&git.Branch,
		&git.RemoteURL,
		&gitDirty,
		&git.DiffPatch,
		&git.CreatedAt,
	)
	switch {
	case err == nil:
		git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0
		result.GitInfo = &git
	case err != sql.ErrNoRows:
		return nil, fmt.Errorf("failed to get experiment git info: %w", err)
	}

	// --- experiment_seeds (optional) ---
	var seeds ExperimentSeeds
	var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64
	var seedsQuery string
	if db.dbType == DBTypeSQLite {
		seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
FROM experiment_seeds WHERE experiment_id = ?`
	} else {
		seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
FROM experiment_seeds WHERE experiment_id = $1`
	}
	err = db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan(
		&numpySeed,
		&torchSeed,
		&tfSeed,
		&randSeed,
		&seeds.CreatedAt,
	)
	switch {
	case err == nil:
		// A NULL column leaves the corresponding pointer nil ("not recorded").
		if numpySeed.Valid {
			v := numpySeed.Int64
			seeds.Numpy = &v
		}
		if torchSeed.Valid {
			v := torchSeed.Int64
			seeds.Torch = &v
		}
		if tfSeed.Valid {
			v := tfSeed.Int64
			seeds.TensorFlow = &v
		}
		if randSeed.Valid {
			v := randSeed.Int64
			seeds.Random = &v
		}
		result.Seeds = &seeds
	case err != sql.ErrNoRows:
		return nil, fmt.Errorf("failed to get experiment seeds: %w", err)
	}

	return result, nil
}
// UpsertDataset inserts ds into the datasets table or, when a dataset with the
// same name already exists, replaces that dataset's url.
//
// It returns an error when ds is nil, when ds.Name or ds.URL is empty, or when
// the statement fails (the underlying error is wrapped).
func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error {
	switch {
	case ds == nil:
		return fmt.Errorf("dataset is nil")
	case ds.Name == "":
		return fmt.Errorf("dataset name is required")
	case ds.URL == "":
		return fmt.Errorf("dataset url is required")
	}

	stmt := `INSERT INTO datasets (name, url)
VALUES ($1, $2)
ON CONFLICT (name) DO UPDATE SET
url = EXCLUDED.url`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO datasets (name, url)
VALUES (?, ?)
ON CONFLICT(name) DO UPDATE SET
url = excluded.url`
	}

	_, execErr := db.conn.ExecContext(ctx, stmt, ds.Name, ds.URL)
	if execErr != nil {
		return fmt.Errorf("failed to upsert dataset: %w", execErr)
	}
	return nil
}
// GetDataset looks up the dataset called name.
//
// When no such dataset exists it returns sql.ErrNoRows itself (unwrapped), so
// callers can compare against it directly. Any other lookup failure is
// wrapped. An empty name is rejected with an error.
func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) {
	if name == "" {
		return nil, fmt.Errorf("dataset name is required")
	}

	stmt := `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?`
	}

	found := new(Dataset)
	scanErr := db.conn.QueryRowContext(ctx, stmt, name).Scan(
		&found.Name, &found.URL, &found.CreatedAt, &found.UpdatedAt,
	)
	switch {
	case scanErr == nil:
		return found, nil
	case scanErr == sql.ErrNoRows:
		// Passed through untouched so "not found" keeps its sentinel identity.
		return nil, scanErr
	default:
		return nil, fmt.Errorf("failed to get dataset: %w", scanErr)
	}
}
// ListDatasets returns the datasets ordered by name ascending. A positive
// limit caps the number of rows; a limit of zero or less returns every row.
//
// The result is a nil slice when the table is empty. Query, scan and iteration
// failures are returned wrapped.
func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) {
	stmt := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC`
	var params []any
	if limit > 0 {
		switch db.dbType {
		case DBTypeSQLite:
			stmt += " LIMIT ?"
		default:
			stmt += " LIMIT $1"
		}
		params = append(params, limit)
	}

	rows, queryErr := db.conn.QueryContext(ctx, stmt, params...)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to list datasets: %w", queryErr)
	}
	defer func() { _ = rows.Close() }()

	var datasets []*Dataset
	for rows.Next() {
		item := new(Dataset)
		scanErr := rows.Scan(&item.Name, &item.URL, &item.CreatedAt, &item.UpdatedAt)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan dataset: %w", scanErr)
		}
		datasets = append(datasets, item)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating datasets: %w", iterErr)
	}
	return datasets, nil
}
// SearchDatasets returns the datasets whose name contains term, ordered by
// name ascending. A positive limit caps the number of rows; a limit of zero or
// less returns every match.
//
// term is matched literally: the LIKE metacharacters % and _ and the escape
// character \ are escaped before the pattern is bound as a query parameter.
// An empty term yields an empty, non-nil slice without touching the database;
// a search with no matches yields a nil slice.
func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) {
	if term == "" {
		return []*Dataset{}, nil
	}

	// One pass over term: each special byte gains a leading backslash, which the
	// queries below declare as the LIKE escape character.
	likeEscaper := strings.NewReplacer("\\", "\\\\", "%", "\\%", "_", "\\_")
	pattern := "%" + likeEscaper.Replace(term) + "%"

	stmt := `SELECT name, url, created_at, updated_at FROM datasets
WHERE name LIKE $1 ESCAPE '\'
ORDER BY name ASC`
	limitClause := " LIMIT $2"
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT name, url, created_at, updated_at FROM datasets
WHERE name LIKE ? ESCAPE '\'
ORDER BY name ASC`
		limitClause = " LIMIT ?"
	}

	params := []any{pattern}
	if limit > 0 {
		stmt += limitClause
		params = append(params, limit)
	}

	rows, queryErr := db.conn.QueryContext(ctx, stmt, params...)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to search datasets: %w", queryErr)
	}
	defer func() { _ = rows.Close() }()

	var matches []*Dataset
	for rows.Next() {
		item := new(Dataset)
		scanErr := rows.Scan(&item.Name, &item.URL, &item.CreatedAt, &item.UpdatedAt)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan dataset: %w", scanErr)
		}
		matches = append(matches, item)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating datasets: %w", iterErr)
	}
	return matches, nil
}
// AssociateTaskWithExperiment links the task taskID to the experiment
// experimentID by adding a row to experiment_tasks. Linking a pair that is
// already linked is a no-op and not an error.
//
// Both identifiers must be non-empty; like the other experiment methods in
// this file, the call is rejected up front instead of writing a link row with
// a blank key. A failing statement is returned wrapped.
//
// The method takes no context, so the statement runs under
// context.Background() and cannot be cancelled by the caller.
func (db *DB) AssociateTaskWithExperiment(taskID, experimentID string) error {
	if taskID == "" {
		return fmt.Errorf("task id is required")
	}
	if experimentID == "" {
		return fmt.Errorf("experiment id is required")
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = `INSERT OR IGNORE INTO experiment_tasks (experiment_id, task_id) VALUES (?, ?)`
	} else {
		query = `INSERT INTO experiment_tasks (experiment_id, task_id) VALUES ($1, $2)
ON CONFLICT (experiment_id, task_id) DO NOTHING`
	}

	// Column order is (experiment_id, task_id), the reverse of the parameter
	// order of this method, so the arguments are bound experiment first.
	_, err := db.conn.ExecContext(context.Background(), query, experimentID, taskID)
	if err != nil {
		return fmt.Errorf("failed to associate task with experiment: %w", err)
	}
	return nil
}
// GetExperimentVisibility reports the visibility of an experiment, taken from
// the most recently created job linked to it through experiment_tasks.
//
// The result is "private" when that job's visibility is NULL or when the
// experiment has no linked jobs at all. Any other lookup failure is returned
// wrapped, with an empty visibility. The query runs under
// context.Background().
func (db *DB) GetExperimentVisibility(experimentID string) (string, error) {
	stmt := `SELECT COALESCE(j.visibility, 'private')
FROM experiment_tasks et
JOIN jobs j ON j.id = et.task_id
WHERE et.experiment_id = $1
ORDER BY j.created_at DESC
LIMIT 1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COALESCE(j.visibility, 'private')
FROM experiment_tasks et
JOIN jobs j ON j.id = et.task_id
WHERE et.experiment_id = ?
ORDER BY j.created_at DESC
LIMIT 1`
	}

	var level string
	scanErr := db.conn.QueryRowContext(context.Background(), stmt, experimentID).Scan(&level)
	switch {
	case scanErr == nil:
		return level, nil
	case scanErr == sql.ErrNoRows:
		// No linked job: fall back to the most restrictive answer.
		return "private", nil
	default:
		return "", fmt.Errorf("failed to get experiment visibility: %w", scanErr)
	}
}
// CascadeExperimentVisibility sets the visibility column of every job linked
// to experimentID through experiment_tasks to the given value. The value is
// written as-is; it is not validated here. A failing statement is returned
// wrapped. The update runs under context.Background().
func (db *DB) CascadeExperimentVisibility(experimentID, visibility string) error {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `UPDATE jobs
SET visibility = ?
WHERE id IN (
SELECT task_id FROM experiment_tasks WHERE experiment_id = ?
)`
	default:
		stmt = `UPDATE jobs
SET visibility = $1
WHERE id IN (
SELECT task_id FROM experiment_tasks WHERE experiment_id = $2
)`
	}

	// Bind order follows the statement: new visibility first, experiment second.
	if _, execErr := db.conn.ExecContext(context.Background(), stmt, visibility, experimentID); execErr != nil {
		return fmt.Errorf("failed to cascade visibility: %w", execErr)
	}
	return nil
}
// ListTasksForExperiment returns one page of the jobs linked to experimentID
// through experiment_tasks, newest first (created_at DESC, then id DESC).
//
// opts.Limit is the page size; values <= 0 or > 100 are replaced by 100.
// opts.Cursor selects where the page starts: it is decoded into a created-at
// string and a job id, and only jobs whose (created_at || id) sorts before
// that pair are returned. When the decoded created-at part is empty, no cursor
// filter is applied.
//
// It returns the page, a cursor for the next page ("" when there is no further
// page), and an error if the query, a row scan or the iteration fails. The
// query runs under context.Background().
func (db *DB) ListTasksForExperiment(experimentID string, opts ListTasksOptions) ([]*Job, string, error) {
	if opts.Limit <= 0 || opts.Limit > 100 {
		opts.Limit = 100
	}

	// NOTE(review): the third result of decodeCursor is discarded, so a cursor
	// that fails to decode presumably falls back to the first page — confirm
	// that this is intended.
	cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)

	var query string
	if db.dbType == DBTypeSQLite {
		query = `
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
FROM jobs j
JOIN experiment_tasks et ON et.task_id = j.id
WHERE et.experiment_id = ?
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
ORDER BY j.created_at DESC, j.id DESC
LIMIT ?`
	} else {
		query = `
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
FROM jobs j
JOIN experiment_tasks et ON et.task_id = j.id
WHERE et.experiment_id = $1
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
ORDER BY j.created_at DESC, j.id DESC
LIMIT $4`
	}

	// Fetch one row beyond the page size: its presence signals another page.
	fetchLimit := opts.Limit + 1

	// Both dialects bind the same four arguments in the same order, so a single
	// call serves either query.
	rows, err := db.conn.QueryContext(
		context.Background(),
		query,
		experimentID,
		cursorCreatedAt,
		cursorCreatedAt+cursorID,
		fetchLimit,
	)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list experiment tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var jobs []*Job
	for rows.Next() {
		job := &Job{}
		var createdAt, updatedAt sql.NullTime
		var startedAt, endedAt sql.NullTime
		var datasetsJSON, metadataJSON []byte
		err := rows.Scan(
			&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
			&createdAt, &updatedAt, &startedAt, &endedAt,
		)
		if err != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", err)
		}
		if createdAt.Valid {
			job.CreatedAt = createdAt.Time
		}
		if updatedAt.Valid {
			job.UpdatedAt = updatedAt.Time
		}
		// startedAt / endedAt are declared per iteration, so each job gets a
		// pointer to its own value.
		if startedAt.Valid {
			job.StartedAt = &startedAt.Time
		}
		if endedAt.Valid {
			job.EndedAt = &endedAt.Time
		}
		// Decoding is best-effort: a malformed datasets/metadata column leaves
		// the field at its zero value instead of failing the whole page.
		if len(datasetsJSON) > 0 {
			_ = json.Unmarshal(datasetsJSON, &job.Datasets)
		}
		if len(metadataJSON) > 0 {
			_ = json.Unmarshal(metadataJSON, &job.Metadata)
		}
		jobs = append(jobs, job)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	nextCursor := ""
	if len(jobs) > opts.Limit {
		// The extra row proved there is more: point the cursor at the last job
		// of this page and drop the extra row.
		lastJob := jobs[opts.Limit-1]
		nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
		jobs = jobs[:opts.Limit]
	}
	return jobs, nextCursor, nil
}