refactor(storage,queue): split storage layer and add sqlite queue backend

This commit is contained in:
Jeremie Fraeys 2026-01-05 12:31:02 -05:00
parent e901ddd810
commit 6ff5324e74
11 changed files with 1916 additions and 181 deletions

76
internal/queue/backend.go Normal file
View file

@ -0,0 +1,76 @@
package queue
import (
"errors"
"time"
)
// ErrInvalidQueueBackend is returned by NewBackend when the configured
// backend name is not one of the supported values.
var ErrInvalidQueueBackend = errors.New("invalid queue backend")

// Backend abstracts the task queue so the rest of the system can run
// against either the Redis or the SQLite implementation.
type Backend interface {
	// Task lifecycle: enqueue, fetch, lease, retry, dead-letter.
	AddTask(task *Task) error
	GetNextTask() (*Task, error)
	// PeekNextTask returns the next task without removing it (best-effort,
	// used for prewarm hints).
	PeekNextTask() (*Task, error)
	GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error)
	GetNextTaskWithLeaseBlocking(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error)
	RenewLease(taskID string, workerID string, leaseDuration time.Duration) error
	ReleaseLease(taskID string, workerID string) error
	RetryTask(task *Task) error
	MoveToDeadLetterQueue(task *Task, reason string) error
	// Task inspection and mutation.
	GetTask(taskID string) (*Task, error)
	GetAllTasks() ([]*Task, error)
	GetTaskByName(jobName string) (*Task, error)
	CancelTask(taskID string) error
	UpdateTask(task *Task) error
	UpdateTaskWithMetrics(task *Task, action string) error
	// Metrics and worker liveness.
	RecordMetric(jobName, metric string, value float64) error
	Heartbeat(workerID string) error
	QueueDepth() (int64, error)
	// Prewarm coordination (best-effort, short-TTL worker state).
	SetWorkerPrewarmState(state PrewarmState) error
	ClearWorkerPrewarmState(workerID string) error
	GetWorkerPrewarmState(workerID string) (*PrewarmState, error)
	GetAllWorkerPrewarmStates() ([]PrewarmState, error)
	SignalPrewarmGC() error
	PrewarmGCRequestValue() (string, error)
	Close() error
}
// QueueBackend names a queue implementation selectable via configuration.
type QueueBackend string

// Supported backend names; an empty value defaults to Redis in NewBackend.
const (
	QueueBackendRedis  QueueBackend = "redis"
	QueueBackendSQLite QueueBackend = "sqlite"
)

// BackendConfig carries the settings needed to construct either backend.
// Only the fields relevant to the selected Backend are consulted.
type BackendConfig struct {
	Backend QueueBackend
	// Redis-only settings.
	RedisAddr     string
	RedisPassword string
	RedisDB       int
	// SQLite-only setting: path to the database file.
	SQLitePath string
	MetricsFlushInterval time.Duration
}
// NewBackend constructs the queue implementation selected by cfg.Backend.
// An empty backend name defaults to Redis; unrecognized names return
// ErrInvalidQueueBackend.
func NewBackend(cfg BackendConfig) (Backend, error) {
	if cfg.Backend == QueueBackendSQLite {
		return NewSQLiteQueue(cfg.SQLitePath)
	}
	if cfg.Backend == "" || cfg.Backend == QueueBackendRedis {
		redisCfg := Config{
			RedisAddr:            cfg.RedisAddr,
			RedisPassword:        cfg.RedisPassword,
			RedisDB:              cfg.RedisDB,
			MetricsFlushInterval: cfg.MetricsFlushInterval,
		}
		return NewTaskQueue(redisCfg)
	}
	return nil, ErrInvalidQueueBackend
}

View file

@ -174,8 +174,10 @@ func IsRetryable(category ErrorCategory) bool {
// GetUserMessage returns a user-friendly error message with suggestions
func GetUserMessage(category ErrorCategory, err error) string {
messages := map[ErrorCategory]string{
ErrorNetwork: "Network connectivity issue. Please check your network connection and try again.",
ErrorResource: "System resource exhausted. The system may be under heavy load. Try again later or contact support.",
ErrorNetwork: "Network connectivity issue. Please check your network " +
"connection and try again.",
ErrorResource: "System resource exhausted. The system may be under heavy load. " +
"Try again later or contact support.",
ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.",
ErrorAuth: "Authentication failed. Please check your API key or credentials.",
ErrorValidation: "Invalid input. Please review your request and correct any errors.",

View file

@ -15,6 +15,7 @@ const (
defaultLeaseDuration = 30 * time.Minute
defaultMaxRetries = 3
defaultBlockTimeout = 1 * time.Second
PrewarmGCRequestKey = "ml:prewarm:gc_request"
)
// TaskQueue manages ML experiment tasks via Redis
@ -33,6 +34,21 @@ type metricEvent struct {
Value float64
}
// PrewarmState is a worker's self-reported cache-warming status. Backends
// store it with a ~30s freshness window so stale entries disappear if the
// worker dies.
type PrewarmState struct {
	WorkerID   string `json:"worker_id"`
	TaskID     string `json:"task_id"`
	SnapshotID string `json:"snapshot_id,omitempty"`
	// StartedAt/UpdatedAt are string timestamps; the SQLite backend writes
	// RFC3339 — presumably other writers match; TODO confirm.
	StartedAt string `json:"started_at"`
	UpdatedAt string `json:"updated_at"`
	Phase     string `json:"phase"`
	EnvImage  string `json:"env_image,omitempty"`
	DatasetCnt int `json:"dataset_count"`
	// Environment-cache counters: hits, misses, builds, and build time.
	EnvHit    int64 `json:"env_hit,omitempty"`
	EnvMiss   int64 `json:"env_miss,omitempty"`
	EnvBuilt  int64 `json:"env_built,omitempty"`
	EnvTimeNs int64 `json:"env_time_ns,omitempty"`
}
// Config holds configuration for TaskQueue
type Config struct {
RedisAddr string
@ -155,7 +171,10 @@ func isBlockingUnsupported(err error) bool {
return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax")
}
func (tq *TaskQueue) pollUntilDeadline(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
func (tq *TaskQueue) pollUntilDeadline(
workerID string,
leaseDuration, blockTimeout time.Duration,
) (*Task, error) {
deadline := time.Now().Add(blockTimeout)
sleep := 25 * time.Millisecond
@ -194,8 +213,98 @@ func (tq *TaskQueue) GetNextTask() (*Task, error) {
return tq.GetTask(taskID)
}
// PeekNextTask returns the highest-priority task without removing it from
// the queue. It exists for best-effort prewarm logic and must never be
// required for correctness.
func (tq *TaskQueue) PeekNextTask() (*Task, error) {
	// Reverse range over the sorted set yields the highest score first.
	members, err := tq.client.ZRevRange(tq.ctx, TaskQueueKey, 0, 0).Result()
	switch {
	case err != nil:
		return nil, err
	case len(members) == 0:
		// Empty queue: nothing to peek, not an error.
		return nil, nil
	}
	return tq.GetTask(members[0])
}
// SetWorkerPrewarmState publishes this worker's prewarm status under its
// per-worker key with a short TTL, so state evaporates if the worker dies.
func (tq *TaskQueue) SetWorkerPrewarmState(state PrewarmState) error {
	if state.WorkerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	encoded, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal prewarm state: %w", err)
	}
	// Keep short TTL to avoid stale prewarm state if worker dies.
	return tq.client.Set(tq.ctx, WorkerPrewarmKey+state.WorkerID, encoded, 30*time.Second).Err()
}
// ClearWorkerPrewarmState removes this worker's prewarm entry, if any.
func (tq *TaskQueue) ClearWorkerPrewarmState(workerID string) error {
	if workerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	return tq.client.Del(tq.ctx, WorkerPrewarmKey+workerID).Err()
}
// GetWorkerPrewarmState fetches one worker's prewarm entry. A missing or
// TTL-expired key yields (nil, nil) rather than an error.
func (tq *TaskQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
	if workerID == "" {
		return nil, fmt.Errorf("missing worker_id")
	}
	raw, err := tq.client.Get(tq.ctx, WorkerPrewarmKey+workerID).Result()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, err
	}
	state := new(PrewarmState)
	if err := json.Unmarshal([]byte(raw), state); err != nil {
		return nil, fmt.Errorf("unmarshal prewarm state: %w", err)
	}
	return state, nil
}
// GetAllWorkerPrewarmStates walks every worker prewarm key via SCAN and
// decodes each value. Keys that expire between SCAN and GET (redis.Nil)
// are silently skipped.
func (tq *TaskQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	var cursor uint64
	pattern := WorkerPrewarmKey + "*"
	out := make([]PrewarmState, 0, 8)
	for {
		// SCAN in batches of 50 instead of KEYS to avoid blocking Redis.
		keys, next, err := tq.client.Scan(tq.ctx, cursor, pattern, 50).Result()
		if err != nil {
			return nil, err
		}
		cursor = next
		for _, key := range keys {
			v, err := tq.client.Get(tq.ctx, key).Result()
			if err == redis.Nil {
				// Key expired mid-scan; not an error.
				continue
			}
			if err != nil {
				return nil, err
			}
			var state PrewarmState
			if err := json.Unmarshal([]byte(v), &state); err != nil {
				return nil, fmt.Errorf("unmarshal prewarm state: %w", err)
			}
			out = append(out, state)
		}
		// Cursor 0 marks the end of the scan.
		if cursor == 0 {
			break
		}
	}
	return out, nil
}
// GetNextTaskWithLease gets the next task and acquires a lease
func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
func (tq *TaskQueue) GetNextTaskWithLease(
workerID string,
leaseDuration time.Duration,
) (*Task, error) {
if leaseDuration == 0 {
leaseDuration = defaultLeaseDuration
}
@ -239,7 +348,8 @@ func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Du
return task, nil
}
// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task before acquiring a lease.
// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task.
// Once a task is received, it then acquires a lease.
func (tq *TaskQueue) GetNextTaskWithLeaseBlocking(
workerID string,
leaseDuration, blockTimeout time.Duration,
@ -580,6 +690,23 @@ func (tq *TaskQueue) Close() error {
return tq.client.Close()
}
// SignalPrewarmGC asks workers to run prewarm garbage collection by writing
// a timestamp value under the shared GC-request key (10 minute TTL).
func (tq *TaskQueue) SignalPrewarmGC() error {
	requestedAt := time.Now().UnixNano()
	return tq.client.Set(tq.ctx, PrewarmGCRequestKey, requestedAt, 10*time.Minute).Err()
}
// PrewarmGCRequestValue returns the current GC-request marker, or "" when
// no request is pending (missing key is not an error).
func (tq *TaskQueue) PrewarmGCRequestValue() (string, error) {
	value, err := tq.client.Get(tq.ctx, PrewarmGCRequestKey).Result()
	if err == redis.Nil {
		return "", nil
	}
	return value, err
}
// GetRedisClient returns the underlying Redis client for direct access
func (tq *TaskQueue) GetRedisClient() *redis.Client {
return tq.client

View file

@ -0,0 +1,762 @@
package queue
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
_ "github.com/mattn/go-sqlite3"
)
// SQLiteQueue is a single-file, single-connection implementation of the
// queue Backend. ctx/cancel bound the lifetime of its background
// goroutines (lease reclaimer and kv janitor); Close cancels them.
type SQLiteQueue struct {
	db     *sql.DB
	ctx    context.Context
	cancel context.CancelFunc
}
// NewSQLiteQueue opens (or creates) the queue database at path, applies the
// schema, and starts the background reclaimer/janitor goroutines.
func NewSQLiteQueue(path string) (*SQLiteQueue, error) {
	if path == "" {
		return nil, fmt.Errorf("sqlite queue path is required")
	}
	// NOTE(review): path is spliced into a file: URI; a path containing
	// '?' or '#' would corrupt the DSN options — confirm callers pass
	// plain filesystem paths.
	db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=5000&_foreign_keys=on", path))
	if err != nil {
		return nil, err
	}
	// A single connection serializes all access through database/sql,
	// avoiding SQLITE_BUSY between readers, writers, and the janitors.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	ctx, cancel := context.WithCancel(context.Background())
	q := &SQLiteQueue{db: db, ctx: ctx, cancel: cancel}
	if err := q.initSchema(); err != nil {
		// Clean up both the handle and the context on setup failure.
		_ = db.Close()
		cancel()
		return nil, err
	}
	go q.leaseReclaimer()
	go q.kvJanitor()
	return q, nil
}
// initSchema applies WAL/synchronous pragmas and creates all tables and
// indexes idempotently (IF NOT EXISTS), so it is safe on every startup.
func (q *SQLiteQueue) initSchema() error {
	stmts := []string{
		"PRAGMA journal_mode=WAL;",
		"PRAGMA synchronous=NORMAL;",
		// tasks: full task record; payload holds the JSON-encoded Task,
		// with a few columns denormalized for indexing.
		`CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
job_name TEXT,
status TEXT,
priority INTEGER,
created_at INTEGER,
updated_at INTEGER,
payload BLOB
);`,
		"CREATE INDEX IF NOT EXISTS idx_tasks_job_name ON tasks(job_name);",
		"CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);",
		// queue: ready-to-run entries; available_at delays retries.
		`CREATE TABLE IF NOT EXISTS queue (
task_id TEXT PRIMARY KEY,
priority INTEGER,
available_at INTEGER,
created_at INTEGER
);`,
		"CREATE INDEX IF NOT EXISTS idx_queue_available ON queue(available_at, priority DESC, created_at);",
		// worker_prewarm: per-worker prewarm state with updated_at used as
		// a freshness window (mirrors the Redis TTL semantics).
		`CREATE TABLE IF NOT EXISTS worker_prewarm (
worker_id TEXT PRIMARY KEY,
payload BLOB,
updated_at INTEGER
);`,
		"CREATE INDEX IF NOT EXISTS idx_prewarm_updated ON worker_prewarm(updated_at);",
		// kv: generic key/value store with lazy expiry (see kvGet/kvJanitor).
		`CREATE TABLE IF NOT EXISTS kv (
key TEXT PRIMARY KEY,
value TEXT,
expires_at INTEGER
);`,
		"CREATE INDEX IF NOT EXISTS idx_kv_expires ON kv(expires_at);",
		// worker_heartbeat: last check-in time per worker.
		`CREATE TABLE IF NOT EXISTS worker_heartbeat (
worker_id TEXT PRIMARY KEY,
last_seen INTEGER
);`,
	}
	for _, stmt := range stmts {
		if _, err := q.db.ExecContext(q.ctx, stmt); err != nil {
			return fmt.Errorf("sqlite schema init failed: %w", err)
		}
	}
	return nil
}
// Close stops the background goroutines (via context cancellation) and
// closes the database handle.
func (q *SQLiteQueue) Close() error {
	q.cancel()
	return q.db.Close()
}
// AddTask validates and upserts a task record and, when the task status is
// "queued", makes it visible on the ready queue (honoring NextRetry as the
// earliest availability). Any other status removes it from the queue. The
// task row and queue row are written in one serializable transaction.
func (q *SQLiteQueue) AddTask(task *Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.ID == "" {
		return fmt.Errorf("task id is required")
	}
	if task.JobName == "" {
		return fmt.Errorf("job name is required")
	}
	if task.MaxRetries == 0 {
		task.MaxRetries = defaultMaxRetries
	}
	now := time.Now().UTC()
	if task.CreatedAt.IsZero() {
		task.CreatedAt = now
	}
	payload, err := json.Marshal(task)
	if err != nil {
		return err
	}
	createdAt := task.CreatedAt.UnixNano()
	updatedAt := now.UnixNano()
	// A pending retry delays when the task becomes claimable.
	availableAt := now.UnixNano()
	if task.NextRetry != nil {
		availableAt = task.NextRetry.UTC().UnixNano()
	}
	queued := task.Status == "queued"
	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return err
	}
	// Rollback is a no-op after a successful Commit.
	defer func() { _ = tx.Rollback() }()
	_, err = tx.ExecContext(
		q.ctx,
		`INSERT INTO tasks(id, job_name, status, priority, created_at, updated_at, payload)
VALUES(?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
job_name=excluded.job_name,
status=excluded.status,
priority=excluded.priority,
created_at=excluded.created_at,
updated_at=excluded.updated_at,
payload=excluded.payload`,
		task.ID,
		task.JobName,
		task.Status,
		task.Priority,
		createdAt,
		updatedAt,
		payload,
	)
	if err != nil {
		return err
	}
	if queued {
		_, err = tx.ExecContext(
			q.ctx,
			`INSERT INTO queue(task_id, priority, available_at, created_at)
VALUES(?, ?, ?, ?)
ON CONFLICT(task_id) DO UPDATE SET
priority=excluded.priority,
available_at=excluded.available_at,
created_at=excluded.created_at`,
			task.ID,
			task.Priority,
			availableAt,
			createdAt,
		)
		if err != nil {
			return err
		}
	} else {
		// Non-queued statuses must not linger on the ready queue.
		_, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID)
		if err != nil {
			return err
		}
	}
	if err := tx.Commit(); err != nil {
		return err
	}
	// Only count tasks that actually entered the queue; upserts of
	// running/finished tasks must not inflate the queued-tasks counter.
	if queued {
		TasksQueued.Inc()
	}
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	return nil
}
// GetTask loads one task by ID and decodes its JSON payload. A missing row
// surfaces as the driver's scan error (sql.ErrNoRows).
func (q *SQLiteQueue) GetTask(taskID string) (*Task, error) {
	var raw []byte
	err := q.db.QueryRowContext(q.ctx, "SELECT payload FROM tasks WHERE id = ?", taskID).Scan(&raw)
	if err != nil {
		return nil, err
	}
	task := new(Task)
	if err := json.Unmarshal(raw, task); err != nil {
		return nil, err
	}
	return task, nil
}
// GetAllTasks returns every stored task, decoded from its JSON payload.
func (q *SQLiteQueue) GetAllTasks() ([]*Task, error) {
	rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	tasks := make([]*Task, 0, 32)
	for rows.Next() {
		var raw []byte
		if scanErr := rows.Scan(&raw); scanErr != nil {
			return nil, scanErr
		}
		task := new(Task)
		if decodeErr := json.Unmarshal(raw, task); decodeErr != nil {
			return nil, decodeErr
		}
		tasks = append(tasks, task)
	}
	// Surface any iteration error the loop did not see.
	return tasks, rows.Err()
}
// GetTaskByName returns the most recently created task for jobName; a
// missing job surfaces as the driver's scan error (sql.ErrNoRows).
func (q *SQLiteQueue) GetTaskByName(jobName string) (*Task, error) {
	const query = "SELECT payload FROM tasks WHERE job_name = ? ORDER BY created_at DESC LIMIT 1"
	var raw []byte
	if err := q.db.QueryRowContext(q.ctx, query, jobName).Scan(&raw); err != nil {
		return nil, err
	}
	task := new(Task)
	if err := json.Unmarshal(raw, task); err != nil {
		return nil, err
	}
	return task, nil
}
// UpdateTask persists a task's current state and reconciles its queue row
// in one serializable transaction: "queued" tasks are (re)inserted on the
// ready queue, any other status is removed from it.
func (q *SQLiteQueue) UpdateTask(task *Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	payload, err := json.Marshal(task)
	if err != nil {
		return err
	}
	now := time.Now().UTC().UnixNano()
	// A pending retry delays when the task becomes claimable again.
	availableAt := time.Now().UTC().UnixNano()
	if task.NextRetry != nil {
		availableAt = task.NextRetry.UTC().UnixNano()
	}
	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit succeeds.
	defer func() { _ = tx.Rollback() }()
	// NOTE(review): a zero-row UPDATE (unknown task.ID) is not treated as
	// an error here — confirm callers always update existing tasks.
	_, err = tx.ExecContext(
		q.ctx,
		`UPDATE tasks SET job_name=?, status=?, priority=?, updated_at=?, payload=? WHERE id=?`,
		task.JobName,
		task.Status,
		task.Priority,
		now,
		payload,
		task.ID,
	)
	if err != nil {
		return err
	}
	if task.Status == "queued" {
		// On conflict created_at is intentionally NOT overwritten, so FIFO
		// ordering within a priority level is preserved across re-queues.
		_, err = tx.ExecContext(
			q.ctx,
			`INSERT INTO queue(task_id, priority, available_at, created_at)
VALUES(?, ?, ?, ?)
ON CONFLICT(task_id) DO UPDATE SET
priority=excluded.priority,
available_at=excluded.available_at`,
			task.ID,
			task.Priority,
			availableAt,
			task.CreatedAt.UTC().UnixNano(),
		)
		if err != nil {
			return err
		}
	} else {
		_, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// UpdateTaskWithMetrics satisfies the Backend interface; the SQLite backend
// ignores the action label and simply persists the task.
func (q *SQLiteQueue) UpdateTaskWithMetrics(task *Task, _ string) error {
	return q.UpdateTask(task)
}

// GetNextTask pops the highest-priority ready task, returning (nil, nil)
// when the queue is empty.
func (q *SQLiteQueue) GetNextTask() (*Task, error) {
	t, _, err := q.peekOrPop(false)
	return t, err
}

// PeekNextTask returns the next task without removing it from the queue.
func (q *SQLiteQueue) PeekNextTask() (*Task, error) {
	t, _, err := q.peekOrPop(true)
	return t, err
}
// peekOrPop returns the best ready task: highest priority first, oldest
// first within a priority, skipping rows whose available_at is still in the
// future. With peek=true the queue row stays; otherwise it is deleted.
// The int64 return is currently always 0 and unused by callers.
func (q *SQLiteQueue) peekOrPop(peek bool) (*Task, int64, error) {
	now := time.Now().UTC().UnixNano()
	query := `SELECT q.task_id, t.payload FROM queue q
JOIN tasks t ON t.id = q.task_id
WHERE q.available_at <= ?
ORDER BY q.priority DESC, q.created_at ASC
LIMIT 1`
	row := q.db.QueryRowContext(q.ctx, query, now)
	var taskID string
	var payload []byte
	if err := row.Scan(&taskID, &payload); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Empty queue is not an error.
			return nil, 0, nil
		}
		return nil, 0, err
	}
	var t Task
	if err := json.Unmarshal(payload, &t); err != nil {
		return nil, 0, err
	}
	if !peek {
		// NOTE(review): select+delete is not wrapped in a transaction;
		// the single DB connection serializes in-process access, but
		// confirm no second process shares this file.
		if _, err := q.db.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil {
			return nil, 0, err
		}
		if depth, derr := q.QueueDepth(); derr == nil {
			UpdateQueueDepth(depth)
		}
	}
	return &t, 0, nil
}
// GetNextTaskWithLease claims the next ready task for workerID, stamping
// lease fields on it (see claimTask). Returns (nil, nil) when idle.
func (q *SQLiteQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
	return q.claimTask(workerID, leaseDuration)
}
// GetNextTaskWithLeaseBlocking polls for a claimable task until one is
// found or blockTimeout elapses, returning (nil, nil) on timeout. A
// non-positive timeout falls back to the package default.
func (q *SQLiteQueue) GetNextTaskWithLeaseBlocking(
	workerID string,
	leaseDuration, blockTimeout time.Duration,
) (*Task, error) {
	if blockTimeout <= 0 {
		blockTimeout = defaultBlockTimeout
	}
	stopAt := time.Now().Add(blockTimeout)
	for {
		task, err := q.claimTask(workerID, leaseDuration)
		if err != nil || task != nil {
			return task, err
		}
		if time.Now().After(stopAt) {
			// Timed out with nothing to do; not an error.
			return nil, nil
		}
		// SQLite has no blocking pop, so poll at a modest cadence.
		time.Sleep(50 * time.Millisecond)
	}
}
// claimTask atomically pops the best ready task and stamps lease fields
// (LeaseExpiry, LeasedBy) into its payload, all in one serializable
// transaction. Returns (nil, nil) when no task is ready.
func (q *SQLiteQueue) claimTask(workerID string, leaseDuration time.Duration) (*Task, error) {
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	if workerID == "" {
		return nil, fmt.Errorf("worker_id is required")
	}
	now := time.Now().UTC()
	nowNs := now.UnixNano()
	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return nil, err
	}
	// Rollback is a no-op once Commit succeeds.
	defer func() { _ = tx.Rollback() }()
	// Same ordering as peekOrPop: priority DESC, then FIFO by created_at.
	row := tx.QueryRowContext(
		q.ctx,
		`SELECT q.task_id, t.payload FROM queue q
JOIN tasks t ON t.id = q.task_id
WHERE q.available_at <= ?
ORDER BY q.priority DESC, q.created_at ASC
LIMIT 1`,
		nowNs,
	)
	var taskID string
	var payload []byte
	if err := row.Scan(&taskID, &payload); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Nothing claimable right now.
			return nil, nil
		}
		return nil, err
	}
	var t Task
	if err := json.Unmarshal(payload, &t); err != nil {
		return nil, err
	}
	// Mirror Redis semantics: acquire lease fields but do not modify status here.
	exp := now.Add(leaseDuration)
	t.LeaseExpiry = &exp
	t.LeasedBy = workerID
	newPayload, err := json.Marshal(&t)
	if err != nil {
		return nil, err
	}
	// Remove from the ready queue and persist the leased payload together.
	if _, err := tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil {
		return nil, err
	}
	if _, err := tx.ExecContext(
		q.ctx,
		"UPDATE tasks SET updated_at=?, payload=? WHERE id=?",
		nowNs,
		newPayload,
		taskID,
	); err != nil {
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, err
	}
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	return &t, nil
}
// RenewLease extends workerID's lease on taskID by leaseDuration (package
// default when zero). It fails if the task is leased by a different worker
// or cannot be loaded.
func (q *SQLiteQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
	t, err := q.GetTask(taskID)
	if err != nil {
		return err
	}
	if t.LeasedBy != workerID {
		return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
	}
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	exp := time.Now().UTC().Add(leaseDuration)
	t.LeaseExpiry = &exp
	if err := q.UpdateTask(t); err != nil {
		return err
	}
	// Count the renewal only after it has actually been persisted; the
	// previous ordering recorded renewals that then failed to commit.
	RecordLeaseRenewal(workerID)
	return nil
}
// ReleaseLease clears workerID's lease on taskID. Only the current lease
// holder may release; anyone else gets an error naming the holder.
func (q *SQLiteQueue) ReleaseLease(taskID string, workerID string) error {
	task, err := q.GetTask(taskID)
	if err != nil {
		return err
	}
	if task.LeasedBy != workerID {
		return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
	}
	task.LeasedBy = ""
	task.LeaseExpiry = nil
	return q.UpdateTask(task)
}
// RetryTask re-enqueues a failed task with exponential backoff, or routes
// it to the dead-letter queue when retries are exhausted or the error is
// classified as non-retryable.
func (q *SQLiteQueue) RetryTask(task *Task) error {
	if task.RetryCount >= task.MaxRetries {
		RecordDLQAddition("max_retries")
		return q.MoveToDeadLetterQueue(task, "max retries exceeded")
	}
	errorCategory := ErrorUnknown
	if task.Error != "" {
		errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
	}
	if !IsRetryable(errorCategory) {
		RecordDLQAddition(string(errorCategory))
		return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
	}
	// Shift the current error into LastError and reset lease state so the
	// re-queued task is claimable by any worker.
	task.RetryCount++
	task.Status = "queued"
	task.LastError = task.Error
	task.Error = ""
	// RetryDelay yields category- and attempt-dependent backoff seconds;
	// NextRetry delays availability via the queue's available_at column.
	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
	nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
	task.NextRetry = &nextRetry
	task.LeaseExpiry = nil
	task.LeasedBy = ""
	RecordTaskRetry(task.JobName, errorCategory)
	return q.AddTask(task)
}
// MoveToDeadLetterQueue marks the task as terminally failed, annotating its
// error with the DLQ reason, and records the failure metric.
func (q *SQLiteQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
	task.Status = "failed"
	task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
	category := ClassifyError(fmt.Errorf("%s", task.LastError))
	RecordTaskFailure(task.JobName, category)
	return q.UpdateTask(task)
}
// CancelTask marks taskID as cancelled, stamps its end time, and persists
// the change (which also drops it from the ready queue via UpdateTask).
func (q *SQLiteQueue) CancelTask(taskID string) error {
	task, err := q.GetTask(taskID)
	if err != nil {
		return err
	}
	endedAt := time.Now().UTC()
	task.Status = "cancelled"
	task.EndedAt = &endedAt
	return q.UpdateTask(task)
}
// RecordMetric is a no-op for the SQLite backend; it exists only to satisfy
// the Backend interface (the Redis backend persists per-job metrics).
func (q *SQLiteQueue) RecordMetric(_, _ string, _ float64) error {
	return nil
}
// Heartbeat upserts one row per worker holding its most recent check-in
// time (nanoseconds since epoch, UTC).
func (q *SQLiteQueue) Heartbeat(workerID string) error {
	if workerID == "" {
		return fmt.Errorf("worker_id is required")
	}
	const upsert = `INSERT INTO worker_heartbeat(worker_id, last_seen) VALUES(?, ?)
ON CONFLICT(worker_id) DO UPDATE SET last_seen=excluded.last_seen`
	_, err := q.db.ExecContext(q.ctx, upsert, workerID, time.Now().UTC().UnixNano())
	return err
}
// QueueDepth returns the number of entries on the ready queue.
func (q *SQLiteQueue) QueueDepth() (int64, error) {
	var depth int64
	if err := q.db.QueryRowContext(q.ctx, "SELECT COUNT(1) FROM queue").Scan(&depth); err != nil {
		return 0, err
	}
	return depth, nil
}
// SetWorkerPrewarmState upserts this worker's prewarm record, defaulting
// StartedAt on first write and refreshing UpdatedAt (both RFC3339). The
// row's updated_at column drives the 30s freshness window used by readers.
func (q *SQLiteQueue) SetWorkerPrewarmState(state PrewarmState) error {
	if state.WorkerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	nowUTC := time.Now().UTC()
	if state.StartedAt == "" {
		state.StartedAt = nowUTC.Format(time.RFC3339)
	}
	state.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	encoded, err := json.Marshal(state)
	if err != nil {
		return err
	}
	const upsert = `INSERT INTO worker_prewarm(worker_id, payload, updated_at) VALUES(?, ?, ?)
ON CONFLICT(worker_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at`
	_, err = q.db.ExecContext(q.ctx, upsert, state.WorkerID, encoded, time.Now().UTC().UnixNano())
	return err
}
// ClearWorkerPrewarmState deletes this worker's prewarm row, if present.
func (q *SQLiteQueue) ClearWorkerPrewarmState(workerID string) error {
	if workerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	_, err := q.db.ExecContext(q.ctx, "DELETE FROM worker_prewarm WHERE worker_id = ?", workerID)
	return err
}
// GetWorkerPrewarmState returns one worker's prewarm record, or (nil, nil)
// when the row is missing or older than the 30-second freshness window.
func (q *SQLiteQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
	if workerID == "" {
		return nil, fmt.Errorf("missing worker_id")
	}
	row := q.db.QueryRowContext(q.ctx, "SELECT payload, updated_at FROM worker_prewarm WHERE worker_id = ?", workerID)
	var payload []byte
	var updatedAt int64
	if err := row.Scan(&payload, &updatedAt); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// No record for this worker; not an error.
			return nil, nil
		}
		return nil, err
	}
	// Mirror Redis TTL semantics (30s).
	if time.Since(time.Unix(0, updatedAt)) > 30*time.Second {
		return nil, nil
	}
	var st PrewarmState
	if err := json.Unmarshal(payload, &st); err != nil {
		return nil, err
	}
	return &st, nil
}
// GetAllWorkerPrewarmStates returns every prewarm record updated within the
// last 30 seconds (mirroring the Redis TTL), decoded from JSON.
func (q *SQLiteQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	cutoff := time.Now().UTC().Add(-30 * time.Second).UnixNano()
	rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM worker_prewarm WHERE updated_at >= ?", cutoff)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	states := make([]PrewarmState, 0, 8)
	for rows.Next() {
		var raw []byte
		if scanErr := rows.Scan(&raw); scanErr != nil {
			return nil, scanErr
		}
		var state PrewarmState
		if decodeErr := json.Unmarshal(raw, &state); decodeErr != nil {
			return nil, decodeErr
		}
		states = append(states, state)
	}
	return states, rows.Err()
}
// SignalPrewarmGC writes a timestamp marker under the shared GC-request key
// with a 10-minute expiry (same contract as the Redis backend).
func (q *SQLiteQueue) SignalPrewarmGC() error {
	return q.kvSet(PrewarmGCRequestKey, fmt.Sprintf("%d", time.Now().UTC().UnixNano()), 10*time.Minute)
}

// PrewarmGCRequestValue returns the pending GC-request marker, or "" when
// there is none (the found flag from kvGet is intentionally dropped).
func (q *SQLiteQueue) PrewarmGCRequestValue() (string, error) {
	v, _, err := q.kvGet(PrewarmGCRequestKey)
	return v, err
}
// kvSet upserts a key/value pair with an absolute expiry ttl from now
// (nanoseconds since epoch); kvGet and the janitor enforce the expiry.
func (q *SQLiteQueue) kvSet(key, value string, ttl time.Duration) error {
	expiresAt := time.Now().UTC().Add(ttl).UnixNano()
	const upsert = `INSERT INTO kv(key, value, expires_at) VALUES(?, ?, ?)
ON CONFLICT(key) DO UPDATE SET value=excluded.value, expires_at=excluded.expires_at`
	_, err := q.db.ExecContext(q.ctx, upsert, key, value, expiresAt)
	return err
}
// kvGet reads a key, enforcing expiry lazily: an expired row is deleted on
// read and reported as absent. Returns (value, found, err); a missing key
// is ("", false, nil), not an error.
func (q *SQLiteQueue) kvGet(key string) (string, bool, error) {
	row := q.db.QueryRowContext(q.ctx, "SELECT value, expires_at FROM kv WHERE key = ?", key)
	var value string
	var exp int64
	if err := row.Scan(&value, &exp); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", false, nil
		}
		return "", false, err
	}
	// expires_at == 0 means no expiry; otherwise reap stale rows eagerly.
	if exp > 0 && time.Now().UTC().UnixNano() > exp {
		// Best-effort cleanup; the janitor will catch anything missed.
		_, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE key = ?", key)
		return "", false, nil
	}
	return value, true, nil
}
// kvJanitor periodically purges expired kv rows (every 30s) until the
// queue's context is cancelled by Close. Deletion errors are ignored:
// kvGet also enforces expiry lazily, so missed sweeps are harmless.
func (q *SQLiteQueue) kvJanitor() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-q.ctx.Done():
			return
		case <-ticker.C:
			_, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE expires_at > 0 AND expires_at < ?", time.Now().UTC().UnixNano())
		}
	}
}
// leaseReclaimer periodically requeues or dead-letters tasks whose lease
// has expired (every minute) until the queue's context is cancelled by
// Close. Errors are ignored; the next tick retries.
func (q *SQLiteQueue) leaseReclaimer() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-q.ctx.Done():
			return
		case <-ticker.C:
			_ = q.reclaimExpiredLeases()
		}
	}
}

// ReclaimExpiredLeases exposes a single reclaim pass for callers/tests that
// do not want to wait for the background ticker.
func (q *SQLiteQueue) ReclaimExpiredLeases() error {
	return q.reclaimExpiredLeases()
}
// reclaimExpiredLeases scans running tasks and, for each whose lease has
// expired, either re-queues it (retries remaining) or moves it to the
// dead-letter queue. The status filter runs in SQL — AddTask/UpdateTask
// keep the status column in sync with the payload — so the janitor no
// longer decodes every task in the table.
func (q *SQLiteQueue) reclaimExpiredLeases() error {
	now := time.Now().UTC()
	// Collect payloads first: the writes below reuse the single SQLite
	// connection, so the result set must be fully drained and closed before
	// RetryTask / MoveToDeadLetterQueue run.
	payloads, err := q.runningTaskPayloads()
	if err != nil {
		return err
	}
	for _, payload := range payloads {
		var t Task
		if err := json.Unmarshal(payload, &t); err != nil {
			// Best-effort janitor: skip corrupt rows rather than abort.
			continue
		}
		// Defensive re-checks in case a row changed since the query ran.
		if t.Status != "running" || t.LeaseExpiry == nil {
			continue
		}
		if !t.LeaseExpiry.Before(now) {
			continue
		}
		t.Error = fmt.Sprintf("worker %s lease expired", t.LeasedBy)
		RecordLeaseExpiration()
		if t.RetryCount < t.MaxRetries {
			_ = q.RetryTask(&t)
		} else {
			_ = q.MoveToDeadLetterQueue(&t, "lease expiry after max retries")
		}
	}
	return nil
}

// runningTaskPayloads returns the raw JSON payloads of all tasks whose
// status column is "running".
func (q *SQLiteQueue) runningTaskPayloads() ([][]byte, error) {
	rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks WHERE status = 'running'")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	payloads := make([][]byte, 0, 32)
	for rows.Next() {
		var payload []byte
		if err := rows.Scan(&payload); err != nil {
			return nil, err
		}
		payloads = append(payloads, payload)
	}
	return payloads, rows.Err()
}

View file

@ -6,21 +6,41 @@ import (
"github.com/jfraeys/fetch_ml/internal/config"
)
// DatasetSpec describes a dataset input with optional provenance fields.
type DatasetSpec struct {
	Name string `json:"name"`
	// Optional provenance: pinned version, content checksum, and source URI.
	Version  string `json:"version,omitempty"`
	Checksum string `json:"checksum,omitempty"`
	URI      string `json:"uri,omitempty"`
}
// Task represents an ML experiment task
type Task struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"` // queued, running, completed, failed
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Output string `json:"output,omitempty"`
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"` // queued, running, completed, failed
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Output string `json:"output,omitempty"`
// TODO(phase1): SnapshotID is an opaque identifier only.
// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
SnapshotID string `json:"snapshot_id,omitempty"`
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
// Datasets is kept for backward compatibility (legacy callers).
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
// Resource requests (optional, 0 means unspecified)
CPU int `json:"cpu,omitempty"`
MemoryGB int `json:"memory_gb,omitempty"`
GPU int `json:"gpu,omitempty"`
GPUMemory string `json:"gpu_memory,omitempty"`
// User ownership and permissions
UserID string `json:"user_id"` // User who owns this task
@ -36,6 +56,38 @@ type Task struct {
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
LastError string `json:"last_error,omitempty"` // Last error encountered
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
// Optional tracking configuration for this task
Tracking *TrackingConfig `json:"tracking,omitempty"`
}
// TrackingConfig specifies experiment tracking tools to enable for a task.
// A nil sub-config leaves that tool disabled.
type TrackingConfig struct {
	MLflow      *MLflowTrackingConfig      `json:"mlflow,omitempty"`
	TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
	Wandb       *WandbTrackingConfig       `json:"wandb,omitempty"`
}

// MLflowTrackingConfig controls MLflow integration.
type MLflowTrackingConfig struct {
	Enabled     bool   `json:"enabled"`
	Mode        string `json:"mode,omitempty"`         // "sidecar" | "remote" | "disabled"
	TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
}

// TensorBoardTrackingConfig controls TensorBoard integration.
type TensorBoardTrackingConfig struct {
	Enabled bool   `json:"enabled"`
	Mode    string `json:"mode,omitempty"` // "sidecar" | "disabled"
}

// WandbTrackingConfig controls Weights & Biases integration.
type WandbTrackingConfig struct {
	Enabled bool   `json:"enabled"`
	Mode    string `json:"mode,omitempty"` // "remote" | "disabled"
	// NOTE(review): APIKey is serialized into the task payload; confirm
	// payloads are not logged or stored where the key could leak.
	APIKey  string `json:"api_key,omitempty"`
	Project string `json:"project,omitempty"`
	Entity  string `json:"entity,omitempty"`
}
// Redis key constants
@ -44,5 +96,6 @@ var (
TaskPrefix = config.RedisTaskPrefix
TaskStatusPrefix = config.RedisTaskStatusPrefix
WorkerHeartbeat = config.RedisWorkerHeartbeat
WorkerPrewarmKey = config.RedisWorkerPrewarmKey
JobMetricsPrefix = config.RedisJobMetricsPrefix
)

View file

@ -0,0 +1,146 @@
// Package storage provides database abstraction and job management.
package storage
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
_ "github.com/lib/pq" // PostgreSQL driver
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// DBConfig holds database connection configuration. For SQLite, Connection
// is the file path/DSN. For Postgres, either Connection is a complete
// connection string, or the Host/Port/Username/Password/Database fields
// are assembled into one (see buildPostgresConnectionString).
type DBConfig struct {
	Type       string
	Connection string
	Host       string
	Port       int
	Username   string
	Password   string
	Database   string
}

// DB wraps a database connection with type information.
type DB struct {
	conn   *sql.DB
	dbType string
}

// DBTypeSQLite is the constant for SQLite database type
const DBTypeSQLite = "sqlite"
// NewDB creates a new database connection. Supported (case-insensitive)
// config.Type values are "sqlite", "postgres", and "postgresql" (alias).
// The returned *DB owns the connection; callers must Close it.
func NewDB(config DBConfig) (*DB, error) {
	var conn *sql.DB
	var err error
	dbType := strings.ToLower(config.Type)
	switch dbType {
	case DBTypeSQLite:
		conn, err = sql.Open("sqlite3", config.Connection)
		if err != nil {
			return nil, fmt.Errorf("failed to open SQLite database: %w", err)
		}
		if err := applySQLitePragmas(conn); err != nil {
			// Close the handle on setup failure so it does not leak.
			_ = conn.Close()
			return nil, err
		}
	case "postgres", "postgresql":
		conn, err = sql.Open("postgres", buildPostgresConnectionString(config))
		if err != nil {
			return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
		}
	default:
		return nil, fmt.Errorf("unsupported database type: %s", config.Type)
	}
	// Optimize connection pool for better throughput.
	conn.SetMaxOpenConns(50)                 // Increase max open connections
	conn.SetMaxIdleConns(25)                 // Maintain idle connections
	conn.SetConnMaxLifetime(5 * time.Minute) // Connection lifetime
	conn.SetConnMaxIdleTime(2 * time.Minute) // Idle connection timeout
	return &DB{conn: conn, dbType: dbType}, nil
}

// applySQLitePragmas enables foreign keys, WAL mode, and throughput-oriented
// settings on a freshly opened SQLite handle.
func applySQLitePragmas(conn *sql.DB) error {
	pragmas := []struct {
		stmt string
		desc string
	}{
		{"PRAGMA foreign_keys = ON", "enable foreign keys"},
		{"PRAGMA journal_mode = WAL", "enable WAL mode"},
		{"PRAGMA synchronous = NORMAL", "set synchronous mode"},
		{"PRAGMA cache_size = 10000", "set cache size"},
		{"PRAGMA temp_store = MEMORY", "set temp store"},
	}
	ctx := context.Background()
	for _, p := range pragmas {
		if _, err := conn.ExecContext(ctx, p.stmt); err != nil {
			return fmt.Errorf("failed to %s: %w", p.desc, err)
		}
	}
	return nil
}
// buildPostgresConnectionString assembles a libpq key=value connection
// string from config, defaulting host=localhost, port=5432, and
// dbname=fetch_ml, and always disabling SSL. An explicit Connection wins.
func buildPostgresConnectionString(config DBConfig) string {
	if config.Connection != "" {
		return config.Connection
	}
	host := config.Host
	if host == "" {
		host = "localhost"
	}
	parts := []string{"host=" + host}
	if config.Port > 0 {
		parts = append(parts, fmt.Sprintf("port=%d", config.Port))
	} else {
		parts = append(parts, "port=5432")
	}
	if config.Username != "" {
		parts = append(parts, "user="+config.Username)
	}
	if config.Password != "" {
		parts = append(parts, "password="+config.Password)
	}
	dbName := config.Database
	if dbName == "" {
		dbName = "fetch_ml"
	}
	parts = append(parts, "dbname="+dbName, "sslmode=disable")
	return strings.Join(parts, " ")
}
// NewDBFromPath creates a new database from a file path (legacy constructor).
//
// It is a convenience wrapper around NewDB that always selects the SQLite
// backend and uses dbPath as the SQLite connection string.
func NewDBFromPath(dbPath string) (*DB, error) {
	return NewDB(DBConfig{
		Type:       DBTypeSQLite,
		Connection: dbPath,
	})
}
// Initialize creates database schema.
//
// The schema string is executed as-is against the underlying connection;
// it may contain multiple DDL statements (assumes the driver accepts a
// multi-statement script — TODO confirm for both drivers in use).
// Any execution error is wrapped with context.
func (db *DB) Initialize(schema string) error {
	if _, err := db.conn.ExecContext(context.Background(), schema); err != nil {
		return fmt.Errorf("failed to initialize database: %w", err)
	}
	return nil
}
// Close closes the database connection.
//
// It releases the underlying *sql.DB pool; the DB must not be used after
// Close returns.
func (db *DB) Close() error {
	return db.conn.Close()
}

View file

@ -0,0 +1,564 @@
// Package storage provides database abstraction and job management.
package storage
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"
)
// Experiment is the top-level record for a tracked ML experiment.
type Experiment struct {
	ID          string    `json:"id"`
	Name        string    `json:"name"`
	Description string    `json:"description,omitempty"`
	Status      string    `json:"status"`
	UserID      string    `json:"user_id,omitempty"`
	WorkspaceID string    `json:"workspace_id,omitempty"`
	CreatedAt   time.Time `json:"created_at"`
	UpdatedAt   time.Time `json:"updated_at"`
}

// ExperimentEnvironment captures the runtime environment an experiment ran
// in (interpreter, OS/arch, host, and dependency fingerprints).
type ExperimentEnvironment struct {
	PythonVersion    string          `json:"python_version"`
	CUDAVersion      string          `json:"cuda_version,omitempty"`
	SystemOS         string          `json:"system_os"`
	SystemArch       string          `json:"system_arch"`
	Hostname         string          `json:"hostname"`
	RequirementsHash string          `json:"requirements_hash"`
	CondaEnvHash     string          `json:"conda_env_hash,omitempty"`
	Dependencies     json.RawMessage `json:"dependencies,omitempty"` // raw JSON; decoding is deferred to callers
	CreatedAt        time.Time       `json:"created_at"`
}

// ExperimentGitInfo records the git state of the code an experiment ran
// from, including an optional diff of uncommitted changes.
type ExperimentGitInfo struct {
	CommitSHA string    `json:"commit_sha"`
	Branch    string    `json:"branch"`
	RemoteURL string    `json:"remote_url"`
	IsDirty   bool      `json:"is_dirty"`
	DiffPatch string    `json:"diff_patch,omitempty"`
	CreatedAt time.Time `json:"created_at"`
}

// ExperimentSeeds holds the RNG seeds recorded for an experiment. Nil
// pointers mean the corresponding seed was not recorded (stored as NULL).
type ExperimentSeeds struct {
	Numpy      *int64    `json:"numpy_seed,omitempty"`
	Torch      *int64    `json:"torch_seed,omitempty"`
	TensorFlow *int64    `json:"tensorflow_seed,omitempty"`
	Random     *int64    `json:"random_seed,omitempty"`
	CreatedAt  time.Time `json:"created_at"`
}

// Dataset is a named entry in the dataset registry pointing at a URL.
type Dataset struct {
	Name      string    `json:"name"`
	URL       string    `json:"url"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}

// ExperimentWithMetadata bundles an experiment with its optional metadata
// rows; nil sub-fields mean the corresponding row was not found.
type ExperimentWithMetadata struct {
	Experiment  Experiment             `json:"experiment"`
	Environment *ExperimentEnvironment `json:"environment,omitempty"`
	GitInfo     *ExperimentGitInfo     `json:"git_info,omitempty"`
	Seeds       *ExperimentSeeds       `json:"seeds,omitempty"`
}
// UpsertExperiment inserts an experiment row or, when a row with the same
// id already exists, overwrites its name, description, status, user_id,
// and workspace_id with the supplied values.
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {
	// Reject obviously invalid input before touching the database.
	switch {
	case exp == nil:
		return fmt.Errorf("experiment is nil")
	case exp.ID == "":
		return fmt.Errorf("experiment id is required")
	case exp.Name == "":
		return fmt.Errorf("experiment name is required")
	}
	// Default to the PostgreSQL dialect; SQLite differs only in placeholder
	// style and upsert spelling.
	query := `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
		VALUES ($1, $2, $3, $4, $5, $6)
		ON CONFLICT (id) DO UPDATE SET
			name = EXCLUDED.name,
			description = EXCLUDED.description,
			status = EXCLUDED.status,
			user_id = EXCLUDED.user_id,
			workspace_id = EXCLUDED.workspace_id`
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
		VALUES (?, ?, ?, ?, ?, ?)
		ON CONFLICT(id) DO UPDATE SET
			name = excluded.name,
			description = excluded.description,
			status = excluded.status,
			user_id = excluded.user_id,
			workspace_id = excluded.workspace_id`
	}
	args := []interface{}{exp.ID, exp.Name, exp.Description, exp.Status, exp.UserID, exp.WorkspaceID}
	if _, err := db.conn.ExecContext(ctx, query, args...); err != nil {
		return fmt.Errorf("failed to upsert experiment: %w", err)
	}
	return nil
}
// UpsertExperimentEnvironment stores the runtime environment captured for
// an experiment, replacing any previously stored row for the same
// experiment id.
func (db *DB) UpsertExperimentEnvironment(
	ctx context.Context,
	experimentID string,
	env *ExperimentEnvironment,
) error {
	if experimentID == "" {
		return fmt.Errorf("experiment id is required")
	}
	if env == nil {
		return fmt.Errorf("environment is nil")
	}
	// Dependencies is raw JSON; persist it as text, with the empty string
	// standing in for "no dependencies recorded".
	var deps string
	if len(env.Dependencies) > 0 {
		deps = string(env.Dependencies)
	}
	// PostgreSQL dialect by default; SQLite differs only in placeholders
	// and upsert spelling.
	query := `INSERT INTO experiment_environments
		(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
		requirements_hash, conda_env_hash, dependencies)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
		ON CONFLICT (experiment_id) DO UPDATE SET
			python_version = EXCLUDED.python_version,
			cuda_version = EXCLUDED.cuda_version,
			system_os = EXCLUDED.system_os,
			system_arch = EXCLUDED.system_arch,
			hostname = EXCLUDED.hostname,
			requirements_hash = EXCLUDED.requirements_hash,
			conda_env_hash = EXCLUDED.conda_env_hash,
			dependencies = EXCLUDED.dependencies`
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO experiment_environments
		(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
		requirements_hash, conda_env_hash, dependencies)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(experiment_id) DO UPDATE SET
			python_version = excluded.python_version,
			cuda_version = excluded.cuda_version,
			system_os = excluded.system_os,
			system_arch = excluded.system_arch,
			hostname = excluded.hostname,
			requirements_hash = excluded.requirements_hash,
			conda_env_hash = excluded.conda_env_hash,
			dependencies = excluded.dependencies`
	}
	args := []interface{}{
		experimentID,
		env.PythonVersion,
		env.CUDAVersion,
		env.SystemOS,
		env.SystemArch,
		env.Hostname,
		env.RequirementsHash,
		env.CondaEnvHash,
		deps,
	}
	if _, err := db.conn.ExecContext(ctx, query, args...); err != nil {
		return fmt.Errorf("failed to upsert experiment environment: %w", err)
	}
	return nil
}
// UpsertExperimentGitInfo stores the git state captured for an experiment,
// replacing any previously stored row for the same experiment id.
func (db *DB) UpsertExperimentGitInfo(
	ctx context.Context,
	experimentID string,
	info *ExperimentGitInfo,
) error {
	if experimentID == "" {
		return fmt.Errorf("experiment id is required")
	}
	if info == nil {
		return fmt.Errorf("git info is nil")
	}
	// is_dirty is bound as 0/1 to match the SQLite INTEGER column.
	// NOTE(review): presumably the postgres column is integer-typed too —
	// confirm; binding an int would fail against a BOOLEAN column.
	isDirty := 0
	if info.IsDirty {
		isDirty = 1
	}
	// Dialect-specific upsert: SQLite uses ? placeholders and lowercase
	// excluded; PostgreSQL uses $n placeholders and EXCLUDED.
	var query string
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO experiment_git_info
		(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
		VALUES (?, ?, ?, ?, ?, ?)
		ON CONFLICT(experiment_id) DO UPDATE SET
			commit_sha = excluded.commit_sha,
			branch = excluded.branch,
			remote_url = excluded.remote_url,
			is_dirty = excluded.is_dirty,
			diff_patch = excluded.diff_patch`
	} else {
		query = `INSERT INTO experiment_git_info
		(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
		VALUES ($1, $2, $3, $4, $5, $6)
		ON CONFLICT (experiment_id) DO UPDATE SET
			commit_sha = EXCLUDED.commit_sha,
			branch = EXCLUDED.branch,
			remote_url = EXCLUDED.remote_url,
			is_dirty = EXCLUDED.is_dirty,
			diff_patch = EXCLUDED.diff_patch`
	}
	_, err := db.conn.ExecContext(
		ctx,
		query,
		experimentID,
		info.CommitSHA,
		info.Branch,
		info.RemoteURL,
		isDirty,
		info.DiffPatch,
	)
	if err != nil {
		return fmt.Errorf("failed to upsert experiment git info: %w", err)
	}
	return nil
}
// UpsertExperimentSeeds stores the RNG seeds recorded for an experiment,
// replacing any previously stored row. Nil seed pointers are written as
// SQL NULLs.
func (db *DB) UpsertExperimentSeeds(
	ctx context.Context,
	experimentID string,
	seeds *ExperimentSeeds,
) error {
	if experimentID == "" {
		return fmt.Errorf("experiment id is required")
	}
	if seeds == nil {
		return fmt.Errorf("seeds is nil")
	}
	// PostgreSQL dialect by default; SQLite differs only in placeholders
	// and upsert spelling.
	query := `INSERT INTO experiment_seeds
		(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
		VALUES ($1, $2, $3, $4, $5)
		ON CONFLICT (experiment_id) DO UPDATE SET
			numpy_seed = EXCLUDED.numpy_seed,
			torch_seed = EXCLUDED.torch_seed,
			tensorflow_seed = EXCLUDED.tensorflow_seed,
			random_seed = EXCLUDED.random_seed`
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO experiment_seeds
		(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
		VALUES (?, ?, ?, ?, ?)
		ON CONFLICT(experiment_id) DO UPDATE SET
			numpy_seed = excluded.numpy_seed,
			torch_seed = excluded.torch_seed,
			tensorflow_seed = excluded.tensorflow_seed,
			random_seed = excluded.random_seed`
	}
	args := []interface{}{experimentID, seeds.Numpy, seeds.Torch, seeds.TensorFlow, seeds.Random}
	if _, err := db.conn.ExecContext(ctx, query, args...); err != nil {
		return fmt.Errorf("failed to upsert experiment seeds: %w", err)
	}
	return nil
}
// GetExperimentWithMetadata loads an experiment plus its optional
// environment, git, and seed metadata rows.
//
// The experiment row itself is required: if that query fails (including
// no-such-row) an error is returned. The three metadata lookups are
// best-effort — when a lookup fails for any reason (including
// sql.ErrNoRows) the corresponding field is simply left nil rather than
// surfacing an error.
func (db *DB) GetExperimentWithMetadata(
	ctx context.Context,
	experimentID string,
) (*ExperimentWithMetadata, error) {
	if experimentID == "" {
		return nil, fmt.Errorf("experiment id is required")
	}
	// --- required experiment row ---
	// COALESCE folds NULLable text columns to "" so they scan into plain
	// strings. The two dialects differ only in placeholder style.
	var exp Experiment
	var query string
	if db.dbType == DBTypeSQLite {
		query = `SELECT id, name,
		COALESCE(description, ''), COALESCE(status, ''),
		COALESCE(user_id, ''), COALESCE(workspace_id, ''),
		created_at, updated_at
		FROM experiments WHERE id = ?`
	} else {
		query = `SELECT id, name,
		COALESCE(description, ''), COALESCE(status, ''),
		COALESCE(user_id, ''), COALESCE(workspace_id, ''),
		created_at, updated_at
		FROM experiments WHERE id = $1`
	}
	err := db.conn.QueryRowContext(ctx, query, experimentID).Scan(
		&exp.ID,
		&exp.Name,
		&exp.Description,
		&exp.Status,
		&exp.UserID,
		&exp.WorkspaceID,
		&exp.CreatedAt,
		&exp.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to get experiment: %w", err)
	}
	result := &ExperimentWithMetadata{Experiment: exp}
	// --- optional environment row (errors intentionally ignored) ---
	var env ExperimentEnvironment
	var envDeps sql.NullString // dependencies column may be NULL
	var envQuery string
	if db.dbType == DBTypeSQLite {
		envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
		COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
		COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
		dependencies, created_at
		FROM experiment_environments WHERE experiment_id = ?`
	} else {
		envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
		COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
		COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
		dependencies, created_at
		FROM experiment_environments WHERE experiment_id = $1`
	}
	if err := db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan(
		&env.PythonVersion,
		&env.CUDAVersion,
		&env.SystemOS,
		&env.SystemArch,
		&env.Hostname,
		&env.RequirementsHash,
		&env.CondaEnvHash,
		&envDeps,
		&env.CreatedAt,
	); err == nil {
		// Only attach dependencies when the column held non-empty text.
		if envDeps.Valid && envDeps.String != "" {
			env.Dependencies = json.RawMessage(envDeps.String)
		}
		result.Environment = &env
	}
	// --- optional git-info row (errors intentionally ignored) ---
	var git ExperimentGitInfo
	var gitDirty sql.NullInt64 // is_dirty stored as 0/1 integer
	var gitQuery string
	if db.dbType == DBTypeSQLite {
		gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
		COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
		COALESCE(diff_patch, ''), created_at
		FROM experiment_git_info WHERE experiment_id = ?`
	} else {
		gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
		COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
		COALESCE(diff_patch, ''), created_at
		FROM experiment_git_info WHERE experiment_id = $1`
	}
	if err := db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan(
		&git.CommitSHA,
		&git.Branch,
		&git.RemoteURL,
		&gitDirty,
		&git.DiffPatch,
		&git.CreatedAt,
	); err == nil {
		// Any non-zero integer counts as dirty; NULL counts as clean.
		git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0
		result.GitInfo = &git
	}
	// --- optional seeds row (errors intentionally ignored) ---
	var seeds ExperimentSeeds
	var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64
	var seedsQuery string
	if db.dbType == DBTypeSQLite {
		seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
		FROM experiment_seeds WHERE experiment_id = ?`
	} else {
		seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
		FROM experiment_seeds WHERE experiment_id = $1`
	}
	if err := db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan(
		&numpySeed,
		&torchSeed,
		&tfSeed,
		&randSeed,
		&seeds.CreatedAt,
	); err == nil {
		// NULL seed columns stay nil pointers; set values are copied out.
		if numpySeed.Valid {
			v := numpySeed.Int64
			seeds.Numpy = &v
		}
		if torchSeed.Valid {
			v := torchSeed.Int64
			seeds.Torch = &v
		}
		if tfSeed.Valid {
			v := tfSeed.Int64
			seeds.TensorFlow = &v
		}
		if randSeed.Valid {
			v := randSeed.Int64
			seeds.Random = &v
		}
		result.Seeds = &seeds
	}
	return result, nil
}
// UpsertDataset inserts a dataset or, when the name already exists,
// updates its URL.
func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error {
	// Validate before touching the database.
	switch {
	case ds == nil:
		return fmt.Errorf("dataset is nil")
	case ds.Name == "":
		return fmt.Errorf("dataset name is required")
	case ds.URL == "":
		return fmt.Errorf("dataset url is required")
	}
	// PostgreSQL dialect by default; SQLite differs only in placeholders
	// and upsert spelling.
	query := `INSERT INTO datasets (name, url)
		VALUES ($1, $2)
		ON CONFLICT (name) DO UPDATE SET
			url = EXCLUDED.url`
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO datasets (name, url)
		VALUES (?, ?)
		ON CONFLICT(name) DO UPDATE SET
			url = excluded.url`
	}
	if _, err := db.conn.ExecContext(ctx, query, ds.Name, ds.URL); err != nil {
		return fmt.Errorf("failed to upsert dataset: %w", err)
	}
	return nil
}
// GetDataset returns the dataset with the given name.
//
// When no such dataset exists, the raw sql.ErrNoRows sentinel is returned
// (unwrapped) so callers can test errors.Is(err, sql.ErrNoRows); all other
// failures are wrapped with context.
func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) {
	if name == "" {
		return nil, fmt.Errorf("dataset name is required")
	}
	var query string
	if db.dbType == DBTypeSQLite {
		query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?`
	} else {
		query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1`
	}
	var ds Dataset
	if err := db.conn.QueryRowContext(ctx, query, name).Scan(
		&ds.Name,
		&ds.URL,
		&ds.CreatedAt,
		&ds.UpdatedAt,
	); err != nil {
		// errors.Is (rather than ==) still matches if a driver or future
		// refactor wraps the sentinel.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, err
		}
		return nil, fmt.Errorf("failed to get dataset: %w", err)
	}
	return &ds, nil
}
// ListDatasets returns all datasets ordered by name ascending. A positive
// limit caps the number of rows returned; limit <= 0 returns everything.
func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) {
	query := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC`
	var args []interface{}
	if limit > 0 {
		clause := " LIMIT $1"
		if db.dbType == DBTypeSQLite {
			clause = " LIMIT ?"
		}
		query += clause
		args = append(args, limit)
	}
	rows, err := db.conn.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to list datasets: %w", err)
	}
	defer func() { _ = rows.Close() }()
	var results []*Dataset
	for rows.Next() {
		ds := new(Dataset)
		if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan dataset: %w", err)
		}
		results = append(results, ds)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating datasets: %w", err)
	}
	return results, nil
}
// SearchDatasets returns datasets whose name contains term as a literal
// substring, ordered by name. An empty term yields an empty, non-nil
// slice; a positive limit caps the number of rows returned.
func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) {
	if term == "" {
		return []*Dataset{}, nil
	}
	// Escape the LIKE metacharacters so term is matched literally, then
	// wrap it in wildcards for a contains-style search.
	escaper := strings.NewReplacer(`\`, `\\`, "%", `\%`, "_", `\_`)
	pattern := "%" + escaper.Replace(term) + "%"
	args := []interface{}{pattern}
	var query string
	if db.dbType == DBTypeSQLite {
		query = `SELECT name, url, created_at, updated_at FROM datasets
		WHERE name LIKE ? ESCAPE '\'
		ORDER BY name ASC`
		if limit > 0 {
			query += " LIMIT ?"
			args = append(args, limit)
		}
	} else {
		query = `SELECT name, url, created_at, updated_at FROM datasets
		WHERE name LIKE $1 ESCAPE '\'
		ORDER BY name ASC`
		if limit > 0 {
			query += " LIMIT $2"
			args = append(args, limit)
		}
	}
	rows, err := db.conn.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to search datasets: %w", err)
	}
	defer func() { _ = rows.Close() }()
	var results []*Dataset
	for rows.Next() {
		ds := new(Dataset)
		if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan dataset: %w", err)
		}
		results = append(results, ds)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating datasets: %w", err)
	}
	return results, nil
}

View file

@ -1,4 +1,3 @@
// Package storage provides database abstraction and job management.
package storage
import (
@ -6,133 +5,9 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
_ "github.com/lib/pq" // PostgreSQL driver
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// DBConfig holds database connection configuration.
//
// Connection, when set, is used verbatim (SQLite file path or full
// PostgreSQL connection string); otherwise the Host/Port/Username/
// Password/Database fields are assembled into one.
type DBConfig struct {
	Type       string // backend selector: "sqlite", "postgres", or "postgresql"
	Connection string // explicit connection string / file path (optional)
	Host       string
	Port       int
	Username   string
	Password   string
	Database   string
}

// DB wraps a database connection with type information.
type DB struct {
	conn   *sql.DB
	dbType string // lowercased DBConfig.Type, used to pick SQL dialects
}

// DBTypeSQLite is the constant for SQLite database type
const DBTypeSQLite = "sqlite"
// NewDB creates a new database connection.
//
// Supported types (case-insensitive): "sqlite" and "postgres"/"postgresql".
// For SQLite, foreign keys, WAL journaling, and several throughput
// pragmas are applied; on any pragma failure the half-configured
// connection is closed before returning, so no handle is leaked.
// The pool is tuned for throughput on all backends.
func NewDB(config DBConfig) (*DB, error) {
	var conn *sql.DB
	var err error
	dbType := strings.ToLower(config.Type)
	switch dbType {
	case DBTypeSQLite:
		conn, err = sql.Open("sqlite3", config.Connection)
		if err != nil {
			return nil, fmt.Errorf("failed to open SQLite database: %w", err)
		}
		// Apply pragmas in order; the desc fragment keeps the original
		// per-pragma error messages.
		pragmas := []struct{ stmt, desc string }{
			{"PRAGMA foreign_keys = ON", "enable foreign keys"},     // referential integrity
			{"PRAGMA journal_mode = WAL", "enable WAL mode"},        // better concurrency
			{"PRAGMA synchronous = NORMAL", "set synchronous mode"}, // throughput
			{"PRAGMA cache_size = 10000", "set cache size"},
			{"PRAGMA temp_store = MEMORY", "set temp store"},
		}
		for _, p := range pragmas {
			if _, perr := conn.ExecContext(context.Background(), p.stmt); perr != nil {
				_ = conn.Close() // don't leak the handle on setup failure
				return nil, fmt.Errorf("failed to %s: %w", p.desc, perr)
			}
		}
	case "postgres", "postgresql": // "postgresql" is an alias for "postgres"
		connStr := buildPostgresConnectionString(config)
		conn, err = sql.Open("postgres", connStr)
		if err != nil {
			return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
		}
	default:
		return nil, fmt.Errorf("unsupported database type: %s", config.Type)
	}
	// Optimize connection pool for better throughput.
	conn.SetMaxOpenConns(50)                 // increase max open connections
	conn.SetMaxIdleConns(25)                 // maintain idle connections
	conn.SetConnMaxLifetime(5 * time.Minute) // connection lifetime
	conn.SetConnMaxIdleTime(2 * time.Minute) // idle connection timeout
	return &DB{conn: conn, dbType: dbType}, nil
}
// buildPostgresConnectionString assembles a libpq keyword/value connection
// string from config. An explicit config.Connection is returned verbatim.
//
// Defaults: host=localhost, port=5432, dbname=fetch_ml, sslmode=disable.
// Values are single-quoted with backslash escaping per the libpq
// keyword/value syntax, so hosts, users, passwords, and database names
// containing spaces, quotes, or backslashes cannot corrupt the string.
func buildPostgresConnectionString(config DBConfig) string {
	if config.Connection != "" {
		return config.Connection
	}
	// quote wraps a value in single quotes, escaping embedded backslashes
	// and single quotes as libpq requires.
	quote := func(v string) string {
		v = strings.ReplaceAll(v, `\`, `\\`)
		v = strings.ReplaceAll(v, `'`, `\'`)
		return "'" + v + "'"
	}
	host := config.Host
	if host == "" {
		host = "localhost"
	}
	port := config.Port
	if port <= 0 {
		port = 5432
	}
	dbname := config.Database
	if dbname == "" {
		dbname = "fetch_ml"
	}
	var connStr strings.Builder
	fmt.Fprintf(&connStr, "host=%s port=%d", quote(host), port)
	if config.Username != "" {
		fmt.Fprintf(&connStr, " user=%s", quote(config.Username))
	}
	if config.Password != "" {
		fmt.Fprintf(&connStr, " password=%s", quote(config.Password))
	}
	fmt.Fprintf(&connStr, " dbname=%s", quote(dbname))
	connStr.WriteString(" sslmode=disable")
	return connStr.String()
}
// NewDBFromPath creates a new database from a file path (legacy constructor).
//
// It is a convenience wrapper around NewDB that always selects the SQLite
// backend and uses dbPath as the SQLite connection string.
func NewDBFromPath(dbPath string) (*DB, error) {
	return NewDB(DBConfig{
		Type:       DBTypeSQLite,
		Connection: dbPath,
	})
}
// Job represents a machine learning job in the system.
type Job struct {
ID string `json:"id"`
@ -161,19 +36,6 @@ type Worker struct {
Metadata map[string]string `json:"metadata,omitempty"`
}
// Initialize creates database schema.
func (db *DB) Initialize(schema string) error {
if _, err := db.conn.ExecContext(context.Background(), schema); err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
return nil
}
// Close closes the database connection.
func (db *DB) Close() error {
return db.conn.Close()
}
// CreateJob inserts a new job into the database.
func (db *DB) CreateJob(job *Job) error {
datasetsJSON, _ := json.Marshal(job.Datasets)
@ -185,11 +47,20 @@ func (db *DB) CreateJob(job *Job) error {
VALUES (?, ?, ?, ?, ?, ?, ?)`
} else {
query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)`
VALUES ($1, $2, $3, $4, $5, $6, $7)`
}
_, err := db.conn.ExecContext(context.Background(), query, job.ID, job.JobName, job.Args, job.Status,
job.Priority, string(datasetsJSON), string(metadataJSON))
_, err := db.conn.ExecContext(
context.Background(),
query,
job.ID,
job.JobName,
job.Args,
job.Status,
job.Priority,
string(datasetsJSON),
string(metadataJSON),
)
if err != nil {
return fmt.Errorf("failed to create job: %w", err)
}
@ -242,21 +113,30 @@ func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error {
var query string
if db.dbType == DBTypeSQLite {
query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?,
started_at = CASE WHEN ? = 'running' AND started_at IS NULL
THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL
THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = ?`
started_at = CASE WHEN ? = 'running' AND started_at IS NULL
THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN ? IN ('completed', 'failed')
AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = ?`
} else {
query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3,
started_at = CASE WHEN $4 = 'running' AND started_at IS NULL
THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL
THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = $6`
started_at = CASE WHEN $4 = 'running' AND started_at IS NULL
THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN $5 IN ('completed', 'failed')
AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = $6`
}
_, err := db.conn.ExecContext(context.Background(), query, status, workerID, errorMsg, status, status, id)
_, err := db.conn.ExecContext(
context.Background(),
query,
status,
workerID,
errorMsg,
status,
status,
id,
)
if err != nil {
return fmt.Errorf("failed to update job status: %w", err)
}
@ -339,17 +219,25 @@ func (db *DB) RegisterWorker(worker *Worker) error {
VALUES (?, ?, ?, ?, ?, ?)`
} else {
query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (id) DO UPDATE SET
hostname = EXCLUDED.hostname,
status = EXCLUDED.status,
current_jobs = EXCLUDED.current_jobs,
max_jobs = EXCLUDED.max_jobs,
metadata = EXCLUDED.metadata`
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (id) DO UPDATE SET
hostname = EXCLUDED.hostname,
status = EXCLUDED.status,
current_jobs = EXCLUDED.current_jobs,
max_jobs = EXCLUDED.max_jobs,
metadata = EXCLUDED.metadata`
}
_, err := db.conn.ExecContext(context.Background(), query, worker.ID, worker.Hostname, worker.Status,
worker.CurrentJobs, worker.MaxJobs, string(metadataJSON))
_, err := db.conn.ExecContext(
context.Background(),
query,
worker.ID,
worker.Hostname,
worker.Status,
worker.CurrentJobs,
worker.MaxJobs,
string(metadataJSON),
)
if err != nil {
return fmt.Errorf("failed to register worker: %w", err)
}
@ -377,10 +265,14 @@ func (db *DB) GetActiveWorkers() ([]*Worker, error) {
var query string
if db.dbType == DBTypeSQLite {
query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')`
FROM workers
WHERE status = 'active'
AND last_heartbeat > datetime('now', '-30 seconds')`
} else {
query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
FROM workers
WHERE status = 'active'
AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
}
rows, err := db.conn.QueryContext(context.Background(), query)

View file

@ -0,0 +1,32 @@
package storage
import (
"embed"
"fmt"
"strings"
)
//go:embed schema_sqlite.sql
var sqliteSchemaFS embed.FS
//go:embed schema_postgres.sql
var postgresSchemaFS embed.FS
// SchemaForDBType returns the embedded DDL schema for the given database
// type (case-insensitive). "postgres" and "postgresql" are synonyms; any
// other type is rejected.
func SchemaForDBType(dbType string) (string, error) {
	// Pick the embedded filesystem, file name, and error label first, then
	// read once.
	var (
		schemaFS embed.FS
		fileName string
		label    string
	)
	switch strings.ToLower(dbType) {
	case DBTypeSQLite:
		schemaFS, fileName, label = sqliteSchemaFS, "schema_sqlite.sql", "sqlite"
	case "postgres", "postgresql":
		schemaFS, fileName, label = postgresSchemaFS, "schema_postgres.sql", "postgres"
	default:
		return "", fmt.Errorf("unsupported database type: %s", dbType)
	}
	b, err := schemaFS.ReadFile(fileName)
	if err != nil {
		return "", fmt.Errorf("failed to read %s schema: %w", label, err)
	}
	return string(b), nil
}

View file

@ -43,6 +43,13 @@ CREATE TABLE IF NOT EXISTS system_metrics (
PRIMARY KEY (metric_name, timestamp)
);
-- Dataset registry: one row per dataset name, pointing at its source URL.
-- updated_at is presumably maintained by the update_updated_at_column()
-- trigger function defined below — confirm a trigger is attached.
CREATE TABLE IF NOT EXISTS datasets (
name TEXT PRIMARY KEY,
url TEXT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
@ -52,6 +59,8 @@ CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name);
-- Function to update updated_at timestamp
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$

View file

@ -43,6 +43,59 @@ CREATE TABLE IF NOT EXISTS system_metrics (
PRIMARY KEY (metric_name, timestamp)
);
-- Top-level record for a tracked ML experiment. updated_at is refreshed
-- by the update_experiments_timestamp trigger defined in this schema.
CREATE TABLE IF NOT EXISTS experiments (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT,
status TEXT DEFAULT 'pending',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
user_id TEXT,
workspace_id TEXT
);
-- Runtime environment captured for an experiment (one row per
-- experiment); dependencies holds raw JSON text. Rows are removed when
-- the parent experiment is deleted (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS experiment_environments (
experiment_id TEXT PRIMARY KEY,
python_version TEXT,
cuda_version TEXT,
system_os TEXT,
system_arch TEXT,
hostname TEXT,
requirements_hash TEXT,
conda_env_hash TEXT,
dependencies TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
-- Git state of the code an experiment ran from; is_dirty is a 0/1 flag
-- and diff_patch may hold the uncommitted diff. Cascades with experiments.
CREATE TABLE IF NOT EXISTS experiment_git_info (
experiment_id TEXT PRIMARY KEY,
commit_sha TEXT,
branch TEXT,
remote_url TEXT,
is_dirty INTEGER DEFAULT 0,
diff_patch TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
-- RNG seeds recorded for an experiment; NULL means a seed was not
-- recorded. Cascades with experiments.
CREATE TABLE IF NOT EXISTS experiment_seeds (
experiment_id TEXT PRIMARY KEY,
numpy_seed INTEGER,
torch_seed INTEGER,
tensorflow_seed INTEGER,
random_seed INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
-- Dataset registry: one row per dataset name, pointing at its source URL.
-- updated_at is refreshed by the update_datasets_timestamp trigger.
CREATE TABLE IF NOT EXISTS datasets (
name TEXT PRIMARY KEY,
url TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
@ -51,6 +104,11 @@ CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_experiments_created_at ON experiments(created_at);
CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status);
CREATE INDEX IF NOT EXISTS idx_experiments_user_id ON experiments(user_id);
CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name);
-- Triggers to update timestamps
CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp
@ -59,3 +117,17 @@ CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp
BEGIN
UPDATE jobs SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;
CREATE TRIGGER IF NOT EXISTS update_experiments_timestamp
AFTER UPDATE ON experiments
FOR EACH ROW
BEGIN
UPDATE experiments SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;
CREATE TRIGGER IF NOT EXISTS update_datasets_timestamp
AFTER UPDATE ON datasets
FOR EACH ROW
BEGIN
UPDATE datasets SET updated_at = CURRENT_TIMESTAMP WHERE name = NEW.name;
END;