diff --git a/internal/queue/backend.go b/internal/queue/backend.go new file mode 100644 index 0000000..2b2f088 --- /dev/null +++ b/internal/queue/backend.go @@ -0,0 +1,76 @@ +package queue + +import ( + "errors" + "time" +) + +var ErrInvalidQueueBackend = errors.New("invalid queue backend") + +type Backend interface { + AddTask(task *Task) error + GetNextTask() (*Task, error) + PeekNextTask() (*Task, error) + + GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) + GetNextTaskWithLeaseBlocking(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) + RenewLease(taskID string, workerID string, leaseDuration time.Duration) error + ReleaseLease(taskID string, workerID string) error + + RetryTask(task *Task) error + MoveToDeadLetterQueue(task *Task, reason string) error + + GetTask(taskID string) (*Task, error) + GetAllTasks() ([]*Task, error) + GetTaskByName(jobName string) (*Task, error) + CancelTask(taskID string) error + + UpdateTask(task *Task) error + UpdateTaskWithMetrics(task *Task, action string) error + + RecordMetric(jobName, metric string, value float64) error + Heartbeat(workerID string) error + QueueDepth() (int64, error) + + SetWorkerPrewarmState(state PrewarmState) error + ClearWorkerPrewarmState(workerID string) error + GetWorkerPrewarmState(workerID string) (*PrewarmState, error) + GetAllWorkerPrewarmStates() ([]PrewarmState, error) + + SignalPrewarmGC() error + PrewarmGCRequestValue() (string, error) + + Close() error +} + +type QueueBackend string + +const ( + QueueBackendRedis QueueBackend = "redis" + QueueBackendSQLite QueueBackend = "sqlite" +) + +type BackendConfig struct { + Backend QueueBackend + RedisAddr string + RedisPassword string + RedisDB int + SQLitePath string + MetricsFlushInterval time.Duration +} + +func NewBackend(cfg BackendConfig) (Backend, error) { + switch cfg.Backend { + case "", QueueBackendRedis: + return NewTaskQueue(Config{ + RedisAddr: cfg.RedisAddr, + RedisPassword: cfg.RedisPassword, + RedisDB: cfg.RedisDB, + MetricsFlushInterval: cfg.MetricsFlushInterval, + }) + case QueueBackendSQLite: + return NewSQLiteQueue(cfg.SQLitePath) + default: + return nil, ErrInvalidQueueBackend + } +} diff --git a/internal/queue/errors.go b/internal/queue/errors.go index 1c179f5..aa0b030 100644 --- a/internal/queue/errors.go +++ b/internal/queue/errors.go @@ -174,8 +174,10 @@ func IsRetryable(category ErrorCategory) bool { // GetUserMessage returns a user-friendly error message with suggestions func GetUserMessage(category ErrorCategory, err error) string { messages := map[ErrorCategory]string{ - ErrorNetwork: "Network connectivity issue. Please check your network connection and try again.", - ErrorResource: "System resource exhausted. The system may be under heavy load. Try again later or contact support.", + ErrorNetwork: "Network connectivity issue. Please check your network " + + "connection and try again.", + ErrorResource: "System resource exhausted. The system may be under heavy load. " + + "Try again later or contact support.", ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.", ErrorAuth: "Authentication failed. Please check your API key or credentials.", ErrorValidation: "Invalid input. Please review your request and correct any errors.", diff --git a/internal/queue/queue.go b/internal/queue/queue.go index 1c61c8d..c83459d 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -15,6 +15,7 @@ const ( defaultLeaseDuration = 30 * time.Minute defaultMaxRetries = 3 defaultBlockTimeout = 1 * time.Second + PrewarmGCRequestKey = "ml:prewarm:gc_request" ) // TaskQueue manages ML experiment tasks via Redis @@ -33,6 +34,21 @@ type metricEvent struct { Value float64 } +type PrewarmState struct { + WorkerID string `json:"worker_id"` + TaskID string `json:"task_id"` + SnapshotID string `json:"snapshot_id,omitempty"` + StartedAt string `json:"started_at"` + UpdatedAt string `json:"updated_at"` + Phase string `json:"phase"` + EnvImage string `json:"env_image,omitempty"` + DatasetCnt int `json:"dataset_count"` + EnvHit int64 `json:"env_hit,omitempty"` + EnvMiss int64 `json:"env_miss,omitempty"` + EnvBuilt int64 `json:"env_built,omitempty"` + EnvTimeNs int64 `json:"env_time_ns,omitempty"` +} + // Config holds configuration for TaskQueue type Config struct { RedisAddr string @@ -155,7 +171,10 @@ func isBlockingUnsupported(err error) bool { return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax") } -func (tq *TaskQueue) pollUntilDeadline(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) { +func (tq *TaskQueue) pollUntilDeadline( + workerID string, + leaseDuration, blockTimeout time.Duration, +) (*Task, error) { deadline := time.Now().Add(blockTimeout) sleep := 25 * time.Millisecond @@ -194,8 +213,98 @@ func (tq *TaskQueue) GetNextTask() (*Task, error) { return tq.GetTask(taskID) } +// PeekNextTask returns the highest priority task without removing it from the queue. +// This is intended for best-effort prewarm logic; it must never be required for correctness. +func (tq *TaskQueue) PeekNextTask() (*Task, error) { + // ZRANGE with REV gives highest scores first. + ids, err := tq.client.ZRevRange(tq.ctx, TaskQueueKey, 0, 0).Result() + if err != nil { + return nil, err + } + if len(ids) == 0 { + return nil, nil + } + return tq.GetTask(ids[0]) +} + +func (tq *TaskQueue) SetWorkerPrewarmState(state PrewarmState) error { + if state.WorkerID == "" { + return fmt.Errorf("missing worker_id") + } + key := WorkerPrewarmKey + state.WorkerID + data, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("marshal prewarm state: %w", err) + } + // Keep short TTL to avoid stale prewarm state if worker dies. + return tq.client.Set(tq.ctx, key, data, 30*time.Second).Err() +} + +func (tq *TaskQueue) ClearWorkerPrewarmState(workerID string) error { + if workerID == "" { + return fmt.Errorf("missing worker_id") + } + key := WorkerPrewarmKey + workerID + return tq.client.Del(tq.ctx, key).Err() +} + +func (tq *TaskQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) { + if workerID == "" { + return nil, fmt.Errorf("missing worker_id") + } + key := WorkerPrewarmKey + workerID + v, err := tq.client.Get(tq.ctx, key).Result() + if err == redis.Nil { + return nil, nil + } + if err != nil { + return nil, err + } + var state PrewarmState + if err := json.Unmarshal([]byte(v), &state); err != nil { + return nil, fmt.Errorf("unmarshal prewarm state: %w", err) + } + return &state, nil +} + +func (tq *TaskQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) { + var cursor uint64 + pattern := WorkerPrewarmKey + "*" + out := make([]PrewarmState, 0, 8) + + for { + keys, next, err := tq.client.Scan(tq.ctx, cursor, pattern, 50).Result() + if err != nil { + return nil, err + } + cursor = next + for _, key := range keys { + v, err := tq.client.Get(tq.ctx, key).Result() + if err == redis.Nil { + continue + } + if err != nil { + return nil, err + } + var state PrewarmState + if err := json.Unmarshal([]byte(v), &state); err != nil { + return nil, fmt.Errorf("unmarshal prewarm state: %w", err) + } + out = append(out, state) + } + if cursor == 0 { + break + } + } + + return out, nil +} + // GetNextTaskWithLease gets the next task and acquires a lease -func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) { +func (tq *TaskQueue) GetNextTaskWithLease( + workerID string, + leaseDuration time.Duration, +) (*Task, error) { if leaseDuration == 0 { leaseDuration = defaultLeaseDuration } @@ -239,7 +348,8 @@ func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Du return task, nil } -// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task before acquiring a lease. +// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task. +// Once a task is received, it then acquires a lease. func (tq *TaskQueue) GetNextTaskWithLeaseBlocking( workerID string, leaseDuration, blockTimeout time.Duration, @@ -580,6 +690,23 @@ func (tq *TaskQueue) Close() error { return tq.client.Close() } +func (tq *TaskQueue) SignalPrewarmGC() error { + return tq.client.Set( + tq.ctx, + PrewarmGCRequestKey, + time.Now().UnixNano(), + 10*time.Minute, + ).Err() +} + +func (tq *TaskQueue) PrewarmGCRequestValue() (string, error) { + v, err := tq.client.Get(tq.ctx, PrewarmGCRequestKey).Result() + if err == redis.Nil { + return "", nil + } + return v, err +} + // GetRedisClient returns the underlying Redis client for direct access func (tq *TaskQueue) GetRedisClient() *redis.Client { return tq.client diff --git a/internal/queue/sqlite_queue.go b/internal/queue/sqlite_queue.go new file mode 100644 index 0000000..af960fd --- /dev/null +++ b/internal/queue/sqlite_queue.go @@ -0,0 +1,762 @@ +package queue + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +type SQLiteQueue struct { + db *sql.DB + ctx context.Context + cancel context.CancelFunc +} + +func NewSQLiteQueue(path string) (*SQLiteQueue, error) { + if path == "" { + return nil, fmt.Errorf("sqlite queue path is required") + } + + db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=5000&_foreign_keys=on", path)) + if err != nil { + return nil, err + } + + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + ctx, cancel := context.WithCancel(context.Background()) + q := &SQLiteQueue{db: db, ctx: ctx, cancel: cancel} + + if err := q.initSchema(); err != nil { + _ = db.Close() + cancel() + return nil, err + } + + go q.leaseReclaimer() + go q.kvJanitor() + return q, nil +} + +func (q *SQLiteQueue) initSchema() error { + stmts := []string{ + "PRAGMA journal_mode=WAL;", + "PRAGMA synchronous=NORMAL;", + `CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + job_name TEXT, + status TEXT, + priority INTEGER, + created_at INTEGER, + updated_at INTEGER, + payload BLOB + );`, + "CREATE INDEX IF NOT EXISTS idx_tasks_job_name ON tasks(job_name);", + "CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);", + `CREATE TABLE IF NOT EXISTS queue ( + task_id TEXT PRIMARY KEY, + priority INTEGER, + available_at INTEGER, + created_at INTEGER + );`, + "CREATE INDEX IF NOT EXISTS idx_queue_available ON queue(available_at, priority DESC, created_at);", + `CREATE TABLE IF NOT EXISTS worker_prewarm ( + worker_id TEXT PRIMARY KEY, + payload BLOB, + updated_at INTEGER + );`, + "CREATE INDEX IF NOT EXISTS idx_prewarm_updated ON worker_prewarm(updated_at);", + `CREATE TABLE IF NOT EXISTS kv ( + key TEXT PRIMARY KEY, + value TEXT, + expires_at INTEGER + );`, + "CREATE INDEX IF NOT EXISTS idx_kv_expires ON kv(expires_at);", + `CREATE TABLE IF NOT EXISTS worker_heartbeat ( + worker_id TEXT PRIMARY KEY, + last_seen INTEGER + );`, + } + + for _, stmt := range stmts { + if _, err := q.db.ExecContext(q.ctx, stmt); err != nil { + return fmt.Errorf("sqlite schema init failed: %w", err) + } + } + return nil +} + +func (q *SQLiteQueue) Close() error { + q.cancel() + return q.db.Close() +} + +func (q *SQLiteQueue) AddTask(task *Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if task.ID == "" { + return fmt.Errorf("task id is required") + } + if task.JobName == "" { + return fmt.Errorf("job name is required") + } + + if task.MaxRetries == 0 { + task.MaxRetries = defaultMaxRetries + } + + now := time.Now().UTC() + if task.CreatedAt.IsZero() { + task.CreatedAt = now + } + + payload, err := json.Marshal(task) + if err != nil { + return err + } + + createdAt := task.CreatedAt.UnixNano() + updatedAt := now.UnixNano() + + availableAt := now.UnixNano() + if task.NextRetry != nil { + availableAt = task.NextRetry.UTC().UnixNano() + } + + tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + _, err = tx.ExecContext( + q.ctx, + `INSERT INTO tasks(id, job_name, status, priority, created_at, updated_at, payload) + VALUES(?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + job_name=excluded.job_name, + status=excluded.status, + priority=excluded.priority, + created_at=excluded.created_at, + updated_at=excluded.updated_at, + payload=excluded.payload`, + task.ID, + task.JobName, + task.Status, + task.Priority, + createdAt, + updatedAt, + payload, + ) + if err != nil { + return err + } + + if task.Status == "queued" { + _, err = tx.ExecContext( + q.ctx, + `INSERT INTO queue(task_id, priority, available_at, created_at) + VALUES(?, ?, ?, ?) + ON CONFLICT(task_id) DO UPDATE SET + priority=excluded.priority, + available_at=excluded.available_at, + created_at=excluded.created_at`, + task.ID, + task.Priority, + availableAt, + createdAt, + ) + if err != nil { + return err + } + } else { + _, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID) + if err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + TasksQueued.Inc() + if depth, derr := q.QueueDepth(); derr == nil { + UpdateQueueDepth(depth) + } + return nil +} + +func (q *SQLiteQueue) GetTask(taskID string) (*Task, error) { + row := q.db.QueryRowContext(q.ctx, "SELECT payload FROM tasks WHERE id = ?", taskID) + var payload []byte + if err := row.Scan(&payload); err != nil { + return nil, err + } + var t Task + if err := json.Unmarshal(payload, &t); err != nil { + return nil, err + } + return &t, nil +} + +func (q *SQLiteQueue) GetAllTasks() ([]*Task, error) { + rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks") + if err != nil { + return nil, err + } + defer rows.Close() + + out := make([]*Task, 0, 32) + for rows.Next() { + var payload []byte + if err := rows.Scan(&payload); err != nil { + return nil, err + } + var t Task + if err := json.Unmarshal(payload, &t); err != nil { + return nil, err + } + out = append(out, &t) + } + return out, rows.Err() +} + +func (q *SQLiteQueue) GetTaskByName(jobName string) (*Task, error) { + row := q.db.QueryRowContext( + q.ctx, + "SELECT payload FROM tasks WHERE job_name = ? ORDER BY created_at DESC LIMIT 1", + jobName, + ) + var payload []byte + if err := row.Scan(&payload); err != nil { + return nil, err + } + var t Task + if err := json.Unmarshal(payload, &t); err != nil { + return nil, err + } + return &t, nil +} + +func (q *SQLiteQueue) UpdateTask(task *Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + + payload, err := json.Marshal(task) + if err != nil { + return err + } + + now := time.Now().UTC().UnixNano() + availableAt := time.Now().UTC().UnixNano() + if task.NextRetry != nil { + availableAt = task.NextRetry.UTC().UnixNano() + } + + tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + _, err = tx.ExecContext( + q.ctx, + `UPDATE tasks SET job_name=?, status=?, priority=?, updated_at=?, payload=? WHERE id=?`, + task.JobName, + task.Status, + task.Priority, + now, + payload, + task.ID, + ) + if err != nil { + return err + } + + if task.Status == "queued" { + _, err = tx.ExecContext( + q.ctx, + `INSERT INTO queue(task_id, priority, available_at, created_at) + VALUES(?, ?, ?, ?) + ON CONFLICT(task_id) DO UPDATE SET + priority=excluded.priority, + available_at=excluded.available_at`, + task.ID, + task.Priority, + availableAt, + task.CreatedAt.UTC().UnixNano(), + ) + if err != nil { + return err + } + } else { + _, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID) + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (q *SQLiteQueue) UpdateTaskWithMetrics(task *Task, _ string) error { + return q.UpdateTask(task) +} + +func (q *SQLiteQueue) GetNextTask() (*Task, error) { + t, _, err := q.peekOrPop(false) + return t, err +} + +func (q *SQLiteQueue) PeekNextTask() (*Task, error) { + t, _, err := q.peekOrPop(true) + return t, err +} + +func (q *SQLiteQueue) peekOrPop(peek bool) (*Task, int64, error) { + now := time.Now().UTC().UnixNano() + + query := `SELECT q.task_id, t.payload FROM queue q + JOIN tasks t ON t.id = q.task_id + WHERE q.available_at <= ? + ORDER BY q.priority DESC, q.created_at ASC + LIMIT 1` + + row := q.db.QueryRowContext(q.ctx, query, now) + var taskID string + var payload []byte + if err := row.Scan(&taskID, &payload); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, 0, nil + } + return nil, 0, err + } + + var t Task + if err := json.Unmarshal(payload, &t); err != nil { + return nil, 0, err + } + + if !peek { + if _, err := q.db.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil { + return nil, 0, err + } + if depth, derr := q.QueueDepth(); derr == nil { + UpdateQueueDepth(depth) + } + } + + return &t, 0, nil +} + +func (q *SQLiteQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) { + return q.claimTask(workerID, leaseDuration) +} + +func (q *SQLiteQueue) GetNextTaskWithLeaseBlocking( + workerID string, + leaseDuration, blockTimeout time.Duration, +) (*Task, error) { + if blockTimeout <= 0 { + blockTimeout = defaultBlockTimeout + } + deadline := time.Now().Add(blockTimeout) + for { + t, err := q.claimTask(workerID, leaseDuration) + if err != nil { + return nil, err + } + if t != nil { + return t, nil + } + if time.Now().After(deadline) { + return nil, nil + } + time.Sleep(50 * time.Millisecond) + } +} + +func (q *SQLiteQueue) claimTask(workerID string, leaseDuration time.Duration) (*Task, error) { + if leaseDuration == 0 { + leaseDuration = defaultLeaseDuration + } + if workerID == "" { + return nil, fmt.Errorf("worker_id is required") + } + + now := time.Now().UTC() + nowNs := now.UnixNano() + + tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + return nil, err + } + defer func() { _ = tx.Rollback() }() + + row := tx.QueryRowContext( + q.ctx, + `SELECT q.task_id, t.payload FROM queue q + JOIN tasks t ON t.id = q.task_id + WHERE q.available_at <= ? + ORDER BY q.priority DESC, q.created_at ASC + LIMIT 1`, + nowNs, + ) + + var taskID string + var payload []byte + if err := row.Scan(&taskID, &payload); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + + var t Task + if err := json.Unmarshal(payload, &t); err != nil { + return nil, err + } + + // Mirror Redis semantics: acquire lease fields but do not modify status here. + exp := now.Add(leaseDuration) + t.LeaseExpiry = &exp + t.LeasedBy = workerID + + newPayload, err := json.Marshal(&t) + if err != nil { + return nil, err + } + + if _, err := tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil { + return nil, err + } + if _, err := tx.ExecContext( + q.ctx, + "UPDATE tasks SET updated_at=?, payload=? WHERE id=?", + nowNs, + newPayload, + taskID, + ); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + if depth, derr := q.QueueDepth(); derr == nil { + UpdateQueueDepth(depth) + } + return &t, nil +} + +func (q *SQLiteQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error { + t, err := q.GetTask(taskID) + if err != nil { + return err + } + if t.LeasedBy != workerID { + return fmt.Errorf("task leased by different worker: %s", t.LeasedBy) + } + if leaseDuration == 0 { + leaseDuration = defaultLeaseDuration + } + exp := time.Now().UTC().Add(leaseDuration) + t.LeaseExpiry = &exp + RecordLeaseRenewal(workerID) + return q.UpdateTask(t) +} + +func (q *SQLiteQueue) ReleaseLease(taskID string, workerID string) error { + t, err := q.GetTask(taskID) + if err != nil { + return err + } + if t.LeasedBy != workerID { + return fmt.Errorf("task leased by different worker: %s", t.LeasedBy) + } + t.LeaseExpiry = nil + t.LeasedBy = "" + return q.UpdateTask(t) +} + +func (q *SQLiteQueue) RetryTask(task *Task) error { + if task.RetryCount >= task.MaxRetries { + RecordDLQAddition("max_retries") + return q.MoveToDeadLetterQueue(task, "max retries exceeded") + } + + errorCategory := ErrorUnknown + if task.Error != "" { + errorCategory = ClassifyError(fmt.Errorf("%s", task.Error)) + } + if !IsRetryable(errorCategory) { + RecordDLQAddition(string(errorCategory)) + return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory)) + } + + task.RetryCount++ + task.Status = "queued" + task.LastError = task.Error + task.Error = "" + + backoffSeconds := RetryDelay(errorCategory, task.RetryCount) + nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second) + task.NextRetry = &nextRetry + task.LeaseExpiry = nil + task.LeasedBy = "" + + RecordTaskRetry(task.JobName, errorCategory) + return q.AddTask(task) +} + +func (q *SQLiteQueue) MoveToDeadLetterQueue(task *Task, reason string) error { + task.Status = "failed" + task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError) + + RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError))) + return q.UpdateTask(task) +} + +func (q *SQLiteQueue) CancelTask(taskID string) error { + t, err := q.GetTask(taskID) + if err != nil { + return err + } + t.Status = "cancelled" + now := time.Now().UTC() + t.EndedAt = &now + return q.UpdateTask(t) +} + +func (q *SQLiteQueue) RecordMetric(_, _ string, _ float64) error { + return nil +} + +func (q *SQLiteQueue) Heartbeat(workerID string) error { + if workerID == "" { + return fmt.Errorf("worker_id is required") + } + _, err := q.db.ExecContext( + q.ctx, + `INSERT INTO worker_heartbeat(worker_id, last_seen) VALUES(?, ?) + ON CONFLICT(worker_id) DO UPDATE SET last_seen=excluded.last_seen`, + workerID, + time.Now().UTC().UnixNano(), + ) + return err +} + +func (q *SQLiteQueue) QueueDepth() (int64, error) { + row := q.db.QueryRowContext(q.ctx, "SELECT COUNT(1) FROM queue") + var n int64 + if err := row.Scan(&n); err != nil { + return 0, err + } + return n, nil +} + +func (q *SQLiteQueue) SetWorkerPrewarmState(state PrewarmState) error { + if state.WorkerID == "" { + return fmt.Errorf("missing worker_id") + } + if state.StartedAt == "" { + state.StartedAt = time.Now().UTC().Format(time.RFC3339) + } + state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + + payload, err := json.Marshal(state) + if err != nil { + return err + } + _, err = q.db.ExecContext( + q.ctx, + `INSERT INTO worker_prewarm(worker_id, payload, updated_at) VALUES(?, ?, ?) + ON CONFLICT(worker_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at`, + state.WorkerID, + payload, + time.Now().UTC().UnixNano(), + ) + return err +} + +func (q *SQLiteQueue) ClearWorkerPrewarmState(workerID string) error { + if workerID == "" { + return fmt.Errorf("missing worker_id") + } + _, err := q.db.ExecContext(q.ctx, "DELETE FROM worker_prewarm WHERE worker_id = ?", workerID) + return err +} + +func (q *SQLiteQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) { + if workerID == "" { + return nil, fmt.Errorf("missing worker_id") + } + row := q.db.QueryRowContext(q.ctx, "SELECT payload, updated_at FROM worker_prewarm WHERE worker_id = ?", workerID) + var payload []byte + var updatedAt int64 + if err := row.Scan(&payload, &updatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + + // Mirror Redis TTL semantics (30s). + if time.Since(time.Unix(0, updatedAt)) > 30*time.Second { + return nil, nil + } + + var st PrewarmState + if err := json.Unmarshal(payload, &st); err != nil { + return nil, err + } + return &st, nil +} + +func (q *SQLiteQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) { + cutoff := time.Now().UTC().Add(-30 * time.Second).UnixNano() + rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM worker_prewarm WHERE updated_at >= ?", cutoff) + if err != nil { + return nil, err + } + defer rows.Close() + + out := make([]PrewarmState, 0, 8) + for rows.Next() { + var payload []byte + if err := rows.Scan(&payload); err != nil { + return nil, err + } + var st PrewarmState + if err := json.Unmarshal(payload, &st); err != nil { + return nil, err + } + out = append(out, st) + } + return out, rows.Err() +} + +func (q *SQLiteQueue) SignalPrewarmGC() error { + return q.kvSet(PrewarmGCRequestKey, fmt.Sprintf("%d", time.Now().UTC().UnixNano()), 10*time.Minute) +} + +func (q *SQLiteQueue) PrewarmGCRequestValue() (string, error) { + v, _, err := q.kvGet(PrewarmGCRequestKey) + return v, err +} + +func (q *SQLiteQueue) kvSet(key, value string, ttl time.Duration) error { + exp := time.Now().UTC().Add(ttl).UnixNano() + _, err := q.db.ExecContext( + q.ctx, + `INSERT INTO kv(key, value, expires_at) VALUES(?, ?, ?) + ON CONFLICT(key) DO UPDATE SET value=excluded.value, expires_at=excluded.expires_at`, + key, + value, + exp, + ) + return err +} + +func (q *SQLiteQueue) kvGet(key string) (string, bool, error) { + row := q.db.QueryRowContext(q.ctx, "SELECT value, expires_at FROM kv WHERE key = ?", key) + var value string + var exp int64 + if err := row.Scan(&value, &exp); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + return "", false, err + } + if exp > 0 && time.Now().UTC().UnixNano() > exp { + _, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE key = ?", key) + return "", false, nil + } + return value, true, nil +} + +func (q *SQLiteQueue) kvJanitor() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-q.ctx.Done(): + return + case <-ticker.C: + _, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE expires_at > 0 AND expires_at < ?", time.Now().UTC().UnixNano()) + } + } +} + +func (q *SQLiteQueue) leaseReclaimer() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-q.ctx.Done(): + return + case <-ticker.C: + _ = q.reclaimExpiredLeases() + } + } +} + +func (q *SQLiteQueue) ReclaimExpiredLeases() error { + return q.reclaimExpiredLeases() +} + +func (q *SQLiteQueue) reclaimExpiredLeases() error { + now := time.Now().UTC() + rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks") + if err != nil { + return err + } + + payloads := make([][]byte, 0, 32) + for rows.Next() { + var payload []byte + if err := rows.Scan(&payload); err != nil { + _ = rows.Close() + return err + } + payloads = append(payloads, payload) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return err + } + _ = rows.Close() + + for _, payload := range payloads { + var t Task + if err := json.Unmarshal(payload, &t); err != nil { + continue + } + if t.LeaseExpiry == nil { + continue + } + if t.Status != "running" { + continue + } + if t.LeaseExpiry.Before(now) { + t.Error = fmt.Sprintf("worker %s lease expired", t.LeasedBy) + RecordLeaseExpiration() + if t.RetryCount < t.MaxRetries { + _ = q.RetryTask(&t) + } else { + _ = q.MoveToDeadLetterQueue(&t, "lease expiry after max retries") + } + } + } + return nil +} diff --git a/internal/queue/task.go b/internal/queue/task.go index 8a0cded..9bcad75 100644 --- a/internal/queue/task.go +++ b/internal/queue/task.go @@ -6,21 +6,41 @@ import ( "github.com/jfraeys/fetch_ml/internal/config" ) +// DatasetSpec describes a dataset input with optional provenance fields. +type DatasetSpec struct { + Name string `json:"name"` + Version string `json:"version,omitempty"` + Checksum string `json:"checksum,omitempty"` + URI string `json:"uri,omitempty"` +} + // Task represents an ML experiment task type Task struct { - ID string `json:"id"` - JobName string `json:"job_name"` - Args string `json:"args"` - Status string `json:"status"` // queued, running, completed, failed - Priority int64 `json:"priority"` - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` - EndedAt *time.Time `json:"ended_at,omitempty"` - WorkerID string `json:"worker_id,omitempty"` - Error string `json:"error,omitempty"` - Output string `json:"output,omitempty"` - Datasets []string `json:"datasets,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` + ID string `json:"id"` + JobName string `json:"job_name"` + Args string `json:"args"` + Status string `json:"status"` // queued, running, completed, failed + Priority int64 `json:"priority"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + EndedAt *time.Time `json:"ended_at,omitempty"` + WorkerID string `json:"worker_id,omitempty"` + Error string `json:"error,omitempty"` + Output string `json:"output,omitempty"` + // TODO(phase1): SnapshotID is an opaque identifier only. + // TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution. + SnapshotID string `json:"snapshot_id,omitempty"` + // DatasetSpecs is the preferred structured dataset input and should be authoritative. + DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"` + // Datasets is kept for backward compatibility (legacy callers). + Datasets []string `json:"datasets,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + + // Resource requests (optional, 0 means unspecified) + CPU int `json:"cpu,omitempty"` + MemoryGB int `json:"memory_gb,omitempty"` + GPU int `json:"gpu,omitempty"` + GPUMemory string `json:"gpu_memory,omitempty"` // User ownership and permissions UserID string `json:"user_id"` // User who owns this task @@ -36,6 +56,38 @@ type Task struct { MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3) LastError string `json:"last_error,omitempty"` // Last error encountered NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff) + + // Optional tracking configuration for this task + Tracking *TrackingConfig `json:"tracking,omitempty"` +} + +// TrackingConfig specifies experiment tracking tools to enable for a task. +type TrackingConfig struct { + MLflow *MLflowTrackingConfig `json:"mlflow,omitempty"` + TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"` + Wandb *WandbTrackingConfig `json:"wandb,omitempty"` +} + +// MLflowTrackingConfig controls MLflow integration. +type MLflowTrackingConfig struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode,omitempty"` // "sidecar" | "remote" | "disabled" + TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode +} + +// TensorBoardTrackingConfig controls TensorBoard integration. +type TensorBoardTrackingConfig struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode,omitempty"` // "sidecar" | "disabled" +} + +// WandbTrackingConfig controls Weights & Biases integration. +type WandbTrackingConfig struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode,omitempty"` // "remote" | "disabled" + APIKey string `json:"api_key,omitempty"` + Project string `json:"project,omitempty"` + Entity string `json:"entity,omitempty"` } // Redis key constants @@ -44,5 +96,6 @@ var ( TaskPrefix = config.RedisTaskPrefix TaskStatusPrefix = config.RedisTaskStatusPrefix WorkerHeartbeat = config.RedisWorkerHeartbeat + WorkerPrewarmKey = config.RedisWorkerPrewarmKey JobMetricsPrefix = config.RedisJobMetricsPrefix ) diff --git a/internal/storage/db_connect.go b/internal/storage/db_connect.go new file mode 100644 index 0000000..df5b3bd --- /dev/null +++ b/internal/storage/db_connect.go @@ -0,0 +1,146 @@ +// Package storage provides database abstraction and job management. +package storage + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + _ "github.com/lib/pq" // PostgreSQL driver + _ "github.com/mattn/go-sqlite3" // SQLite driver +) + +// DBConfig holds database connection configuration. +type DBConfig struct { + Type string + Connection string + Host string + Port int + Username string + Password string + Database string +} + +// DB wraps a database connection with type information. +type DB struct { + conn *sql.DB + dbType string +} + +// DBTypeSQLite is the constant for SQLite database type +const DBTypeSQLite = "sqlite" + +// NewDB creates a new database connection. +func NewDB(config DBConfig) (*DB, error) { + var conn *sql.DB + var err error + + switch strings.ToLower(config.Type) { + case DBTypeSQLite: + conn, err = sql.Open("sqlite3", config.Connection) + if err != nil { + return nil, fmt.Errorf("failed to open SQLite database: %w", err) + } + // Enable foreign keys + if _, err := conn.ExecContext(context.Background(), "PRAGMA foreign_keys = ON"); err != nil { + return nil, fmt.Errorf("failed to enable foreign keys: %w", err) + } + // Enable WAL mode for better concurrency + if _, err := conn.ExecContext(context.Background(), "PRAGMA journal_mode = WAL"); err != nil { + return nil, fmt.Errorf("failed to enable WAL mode: %w", err) + } + // Additional SQLite optimizations for throughput + if _, err := conn.ExecContext(context.Background(), "PRAGMA synchronous = NORMAL"); err != nil { + return nil, fmt.Errorf("failed to set synchronous mode: %w", err) + } + if _, err := conn.ExecContext(context.Background(), "PRAGMA cache_size = 10000"); err != nil { + return nil, fmt.Errorf("failed to set cache size: %w", err) + } + if _, err := conn.ExecContext(context.Background(), "PRAGMA temp_store = MEMORY"); err != nil { + return nil, fmt.Errorf("failed to set temp store: %w", err) + } + case "postgres": + connStr := buildPostgresConnectionString(config) + conn, err = sql.Open("postgres", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) + } + case "postgresql": + // Handle "postgresql" as alias for "postgres" + connStr := buildPostgresConnectionString(config) + conn, err = sql.Open("postgres", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) + } + default: + return nil, fmt.Errorf("unsupported database type: %s", config.Type) + } + + // Optimize connection pool for better throughput + conn.SetMaxOpenConns(50) // Increase max open connections + conn.SetMaxIdleConns(25) // Maintain idle connections + conn.SetConnMaxLifetime(5 * time.Minute) // Connection lifetime + conn.SetConnMaxIdleTime(2 * time.Minute) // Idle connection timeout + + return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil +} + +func buildPostgresConnectionString(config DBConfig) string { + if config.Connection != "" { + return config.Connection + } + + var connStr strings.Builder + connStr.WriteString("host=") + if config.Host != "" { + connStr.WriteString(config.Host) + } else { + connStr.WriteString("localhost") + } + + if config.Port > 0 { + connStr.WriteString(fmt.Sprintf(" port=%d", config.Port)) + } else { + connStr.WriteString(" port=5432") + } + + if config.Username != "" { + connStr.WriteString(fmt.Sprintf(" user=%s", config.Username)) + } + + if config.Password != "" { + connStr.WriteString(fmt.Sprintf(" password=%s", config.Password)) + } + + if config.Database != "" { + connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database)) + } else { + connStr.WriteString(" dbname=fetch_ml") + } + + connStr.WriteString(" sslmode=disable") + return connStr.String() +} + +// NewDBFromPath creates a new database from a file path (legacy constructor). +func NewDBFromPath(dbPath string) (*DB, error) { + return NewDB(DBConfig{ + Type: DBTypeSQLite, + Connection: dbPath, + }) +} + +// Initialize creates database schema. +func (db *DB) Initialize(schema string) error { + if _, err := db.conn.ExecContext(context.Background(), schema); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } + return nil +} + +// Close closes the database connection. +func (db *DB) Close() error { + return db.conn.Close() +} diff --git a/internal/storage/db_experiments.go b/internal/storage/db_experiments.go new file mode 100644 index 0000000..ef9b19c --- /dev/null +++ b/internal/storage/db_experiments.go @@ -0,0 +1,564 @@ +// Package storage provides database abstraction and job management. +package storage + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" +) + +type Experiment struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Status string `json:"status"` + UserID string `json:"user_id,omitempty"` + WorkspaceID string `json:"workspace_id,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type ExperimentEnvironment struct { + PythonVersion string `json:"python_version"` + CUDAVersion string `json:"cuda_version,omitempty"` + SystemOS string `json:"system_os"` + SystemArch string `json:"system_arch"` + Hostname string `json:"hostname"` + RequirementsHash string `json:"requirements_hash"` + CondaEnvHash string `json:"conda_env_hash,omitempty"` + Dependencies json.RawMessage `json:"dependencies,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type ExperimentGitInfo struct { + CommitSHA string `json:"commit_sha"` + Branch string `json:"branch"` + RemoteURL string `json:"remote_url"` + IsDirty bool `json:"is_dirty"` + DiffPatch string `json:"diff_patch,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type ExperimentSeeds struct { + Numpy *int64 `json:"numpy_seed,omitempty"` + Torch *int64 `json:"torch_seed,omitempty"` + TensorFlow *int64 `json:"tensorflow_seed,omitempty"` + Random *int64 `json:"random_seed,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type Dataset struct { + Name string `json:"name"` + URL string `json:"url"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type ExperimentWithMetadata struct { + Experiment Experiment `json:"experiment"` + Environment *ExperimentEnvironment `json:"environment,omitempty"` + GitInfo *ExperimentGitInfo `json:"git_info,omitempty"` + Seeds *ExperimentSeeds `json:"seeds,omitempty"` +} + +func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error { + if exp == nil { + return fmt.Errorf("experiment is nil") + } + if exp.ID == "" { + return fmt.Errorf("experiment id is required") + } + if exp.Name == "" { + return fmt.Errorf("experiment name is required") + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + description = excluded.description, + status = excluded.status, + user_id = excluded.user_id, + workspace_id = excluded.workspace_id` + } else { + query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (id) DO UPDATE SET + name = EXCLUDED.name, + description = EXCLUDED.description, + status = EXCLUDED.status, + user_id = EXCLUDED.user_id, + workspace_id = EXCLUDED.workspace_id` + } + + _, err := db.conn.ExecContext( + ctx, + query, + exp.ID, + exp.Name, + exp.Description, + exp.Status, + exp.UserID, + exp.WorkspaceID, + ) + if err != nil { + return fmt.Errorf("failed to upsert experiment: %w", err) + } + return nil +} + +func (db *DB) UpsertExperimentEnvironment( + ctx context.Context, + experimentID string, + env *ExperimentEnvironment, +) error { + if experimentID == "" { + return fmt.Errorf("experiment id is required") + } + if env == nil { + return fmt.Errorf("environment is nil") + } + + deps := "" + if len(env.Dependencies) > 0 { + deps = string(env.Dependencies) + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO experiment_environments + (experiment_id, python_version, cuda_version, system_os, system_arch, hostname, + requirements_hash, conda_env_hash, dependencies) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(experiment_id) DO UPDATE SET + python_version = excluded.python_version, + cuda_version = excluded.cuda_version, + system_os = excluded.system_os, + system_arch = excluded.system_arch, + hostname = excluded.hostname, + requirements_hash = excluded.requirements_hash, + conda_env_hash = excluded.conda_env_hash, + dependencies = excluded.dependencies` + } else { + query = `INSERT INTO experiment_environments + (experiment_id, python_version, cuda_version, system_os, system_arch, hostname, + requirements_hash, conda_env_hash, dependencies) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (experiment_id) DO UPDATE SET + python_version = EXCLUDED.python_version, + cuda_version = EXCLUDED.cuda_version, + system_os = EXCLUDED.system_os, + system_arch = EXCLUDED.system_arch, + hostname = EXCLUDED.hostname, + requirements_hash = EXCLUDED.requirements_hash, + conda_env_hash = EXCLUDED.conda_env_hash, + dependencies = EXCLUDED.dependencies` + } + + _, err := db.conn.ExecContext( + ctx, + query, + experimentID, + env.PythonVersion, + env.CUDAVersion, + env.SystemOS, + env.SystemArch, + env.Hostname, + env.RequirementsHash, + env.CondaEnvHash, + deps, + ) + if err != nil { + return fmt.Errorf("failed to upsert experiment environment: %w", err) + } + return nil +} + +func (db *DB) UpsertExperimentGitInfo( + ctx context.Context, + experimentID string, + info *ExperimentGitInfo, +) error { + if experimentID == "" { + return fmt.Errorf("experiment id is required") + } + if info == nil { + return fmt.Errorf("git info is nil") + } + + isDirty := 0 + if info.IsDirty { + isDirty = 1 + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO experiment_git_info + (experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(experiment_id) DO UPDATE SET + commit_sha = excluded.commit_sha, + branch = excluded.branch, + remote_url = excluded.remote_url, + is_dirty = excluded.is_dirty, + diff_patch = excluded.diff_patch` + } else { + query = `INSERT INTO experiment_git_info + (experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (experiment_id) DO UPDATE SET + commit_sha = EXCLUDED.commit_sha, + branch = EXCLUDED.branch, + remote_url = EXCLUDED.remote_url, + is_dirty = EXCLUDED.is_dirty, + diff_patch = EXCLUDED.diff_patch` + } + + _, err := db.conn.ExecContext( + ctx, + query, + experimentID, + info.CommitSHA, + info.Branch, + info.RemoteURL, + isDirty, + info.DiffPatch, + ) + if err != nil { + return fmt.Errorf("failed to upsert experiment git info: %w", err) + } + return nil +} + +func (db *DB) UpsertExperimentSeeds( + ctx context.Context, + experimentID string, + seeds *ExperimentSeeds, +) error { + if experimentID == "" { + return fmt.Errorf("experiment id is required") + } + if seeds == nil { + return fmt.Errorf("seeds is nil") + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO experiment_seeds + (experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(experiment_id) DO UPDATE SET + numpy_seed = excluded.numpy_seed, + torch_seed = excluded.torch_seed, + tensorflow_seed = excluded.tensorflow_seed, + random_seed = excluded.random_seed` + } else { + query = `INSERT INTO experiment_seeds + (experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (experiment_id) DO UPDATE SET + numpy_seed = EXCLUDED.numpy_seed, + torch_seed = EXCLUDED.torch_seed, + tensorflow_seed = EXCLUDED.tensorflow_seed, + random_seed = EXCLUDED.random_seed` + } + + _, err := db.conn.ExecContext( + ctx, + query, + experimentID, + seeds.Numpy, + seeds.Torch, + seeds.TensorFlow, + seeds.Random, + ) + if err != nil { + return fmt.Errorf("failed to upsert experiment seeds: %w", err) + } + return nil +} + +func (db *DB) GetExperimentWithMetadata( + ctx context.Context, + experimentID string, +) (*ExperimentWithMetadata, error) { + if experimentID == "" { + return nil, fmt.Errorf("experiment id is required") + } + + var exp Experiment + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT id, name, + COALESCE(description, ''), COALESCE(status, ''), + COALESCE(user_id, ''), COALESCE(workspace_id, ''), + created_at, updated_at + FROM experiments WHERE id = ?` + } else { + query = `SELECT id, name, + COALESCE(description, ''), COALESCE(status, ''), + COALESCE(user_id, ''), COALESCE(workspace_id, ''), + created_at, updated_at + FROM experiments WHERE id = $1` + } + + err := db.conn.QueryRowContext(ctx, query, experimentID).Scan( + &exp.ID, + &exp.Name, + &exp.Description, + &exp.Status, + &exp.UserID, + &exp.WorkspaceID, + &exp.CreatedAt, + &exp.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to get experiment: %w", err) + } + + result := &ExperimentWithMetadata{Experiment: exp} + + var env ExperimentEnvironment + var envDeps sql.NullString + var envQuery string + if db.dbType == DBTypeSQLite { + envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''), + COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''), + COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''), + dependencies, created_at + FROM experiment_environments WHERE experiment_id = ?` + } else { + envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''), + COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''), + COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''), + dependencies, created_at + FROM experiment_environments WHERE experiment_id = $1` + } + if err := db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan( + &env.PythonVersion, + &env.CUDAVersion, + &env.SystemOS, + &env.SystemArch, + &env.Hostname, + &env.RequirementsHash, + &env.CondaEnvHash, + &envDeps, + &env.CreatedAt, + ); err == nil { + if envDeps.Valid && envDeps.String != "" { + env.Dependencies = json.RawMessage(envDeps.String) + } + result.Environment = &env + } + + var git ExperimentGitInfo + var gitDirty sql.NullInt64 + var gitQuery string + if db.dbType == DBTypeSQLite { + gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''), + COALESCE(remote_url, ''), COALESCE(is_dirty, 0), + COALESCE(diff_patch, ''), created_at + FROM experiment_git_info WHERE experiment_id = ?` + } else { + gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''), + COALESCE(remote_url, ''), COALESCE(is_dirty, 0), + COALESCE(diff_patch, ''), created_at + FROM experiment_git_info WHERE experiment_id = $1` + } + if err := db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan( + &git.CommitSHA, + &git.Branch, + &git.RemoteURL, + &gitDirty, + &git.DiffPatch, + &git.CreatedAt, + ); err == nil { + git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0 + result.GitInfo = &git + } + + var seeds ExperimentSeeds + var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64 + var seedsQuery string + if db.dbType == DBTypeSQLite { + seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at + FROM experiment_seeds WHERE experiment_id = ?` + } else { + seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at + FROM experiment_seeds WHERE experiment_id = $1` + } + if err := db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan( + &numpySeed, + &torchSeed, + &tfSeed, + &randSeed, + &seeds.CreatedAt, + ); err == nil { + if numpySeed.Valid { + v := numpySeed.Int64 + seeds.Numpy = &v + } + if torchSeed.Valid { + v := torchSeed.Int64 + seeds.Torch = &v + } + if tfSeed.Valid { + v := tfSeed.Int64 + seeds.TensorFlow = &v + } + if randSeed.Valid { + v := randSeed.Int64 + seeds.Random = &v + } + result.Seeds = &seeds + } + + return result, nil +} + +func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error { + if ds == nil { + return fmt.Errorf("dataset is nil") + } + if ds.Name == "" { + return fmt.Errorf("dataset name is required") + } + if ds.URL == "" { + return fmt.Errorf("dataset url is required") + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO datasets (name, url) + VALUES (?, ?) + ON CONFLICT(name) DO UPDATE SET + url = excluded.url` + } else { + query = `INSERT INTO datasets (name, url) + VALUES ($1, $2) + ON CONFLICT (name) DO UPDATE SET + url = EXCLUDED.url` + } + + if _, err := db.conn.ExecContext(ctx, query, ds.Name, ds.URL); err != nil { + return fmt.Errorf("failed to upsert dataset: %w", err) + } + return nil +} + +func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) { + if name == "" { + return nil, fmt.Errorf("dataset name is required") + } + + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?` + } else { + query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1` + } + + var ds Dataset + if err := db.conn.QueryRowContext(ctx, query, name).Scan( + &ds.Name, + &ds.URL, + &ds.CreatedAt, + &ds.UpdatedAt, + ); err != nil { + if err == sql.ErrNoRows { + return nil, err + } + return nil, fmt.Errorf("failed to get dataset: %w", err) + } + return &ds, nil +} + +func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) { + query := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC` + var args []interface{} + if limit > 0 { + if db.dbType == DBTypeSQLite { + query += " LIMIT ?" + } else { + query += " LIMIT $1" + } + args = append(args, limit) + } + + rows, err := db.conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to list datasets: %w", err) + } + defer func() { _ = rows.Close() }() + + var out []*Dataset + for rows.Next() { + var ds Dataset + if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan dataset: %w", err) + } + out = append(out, &ds) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating datasets: %w", err) + } + return out, nil +} + +func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) { + if term == "" { + return []*Dataset{}, nil + } + + // Escape %/_ for LIKE and use parameterized query. + escaped := strings.ReplaceAll(term, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, "%", "\\%") + escaped = strings.ReplaceAll(escaped, "_", "\\_") + pattern := "%" + escaped + "%" + + var query string + var args []interface{} + if db.dbType == DBTypeSQLite { + query = `SELECT name, url, created_at, updated_at FROM datasets + WHERE name LIKE ? ESCAPE '\' + ORDER BY name ASC` + args = append(args, pattern) + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + } + } else { + query = `SELECT name, url, created_at, updated_at FROM datasets + WHERE name LIKE $1 ESCAPE '\' + ORDER BY name ASC` + args = append(args, pattern) + if limit > 0 { + query += " LIMIT $2" + args = append(args, limit) + } + } + + rows, err := db.conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to search datasets: %w", err) + } + defer func() { _ = rows.Close() }() + + var out []*Dataset + for rows.Next() { + var ds Dataset + if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan dataset: %w", err) + } + out = append(out, &ds) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating datasets: %w", err) + } + return out, nil +} diff --git a/internal/storage/db.go b/internal/storage/db_jobs.go similarity index 61% rename from internal/storage/db.go rename to internal/storage/db_jobs.go index 8a09823..dd547f6 100644 --- a/internal/storage/db.go +++ b/internal/storage/db_jobs.go @@ -1,4 +1,3 @@ -// Package storage provides database abstraction and job management. package storage import ( @@ -6,133 +5,9 @@ import ( "database/sql" "encoding/json" "fmt" - "strings" "time" - - _ "github.com/lib/pq" // PostgreSQL driver - _ "github.com/mattn/go-sqlite3" // SQLite driver ) -// DBConfig holds database connection configuration. -type DBConfig struct { - Type string - Connection string - Host string - Port int - Username string - Password string - Database string -} - -// DB wraps a database connection with type information. -type DB struct { - conn *sql.DB - dbType string -} - -// DBTypeSQLite is the constant for SQLite database type -const DBTypeSQLite = "sqlite" - -// NewDB creates a new database connection. -func NewDB(config DBConfig) (*DB, error) { - var conn *sql.DB - var err error - - switch strings.ToLower(config.Type) { - case DBTypeSQLite: - conn, err = sql.Open("sqlite3", config.Connection) - if err != nil { - return nil, fmt.Errorf("failed to open SQLite database: %w", err) - } - // Enable foreign keys - if _, err := conn.ExecContext(context.Background(), "PRAGMA foreign_keys = ON"); err != nil { - return nil, fmt.Errorf("failed to enable foreign keys: %w", err) - } - // Enable WAL mode for better concurrency - if _, err := conn.ExecContext(context.Background(), "PRAGMA journal_mode = WAL"); err != nil { - return nil, fmt.Errorf("failed to enable WAL mode: %w", err) - } - // Additional SQLite optimizations for throughput - if _, err := conn.ExecContext(context.Background(), "PRAGMA synchronous = NORMAL"); err != nil { - return nil, fmt.Errorf("failed to set synchronous mode: %w", err) - } - if _, err := conn.ExecContext(context.Background(), "PRAGMA cache_size = 10000"); err != nil { - return nil, fmt.Errorf("failed to set cache size: %w", err) - } - if _, err := conn.ExecContext(context.Background(), "PRAGMA temp_store = MEMORY"); err != nil { - return nil, fmt.Errorf("failed to set temp store: %w", err) - } - case "postgres": - connStr := buildPostgresConnectionString(config) - conn, err = sql.Open("postgres", connStr) - if err != nil { - return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) - } - case "postgresql": - // Handle "postgresql" as alias for "postgres" - connStr := buildPostgresConnectionString(config) - conn, err = sql.Open("postgres", connStr) - if err != nil { - return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) - } - default: - return nil, fmt.Errorf("unsupported database type: %s", config.Type) - } - - // Optimize connection pool for better throughput - conn.SetMaxOpenConns(50) // Increase max open connections - conn.SetMaxIdleConns(25) // Maintain idle connections - conn.SetConnMaxLifetime(5 * time.Minute) // Connection lifetime - conn.SetConnMaxIdleTime(2 * time.Minute) // Idle connection timeout - - return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil -} - -func buildPostgresConnectionString(config DBConfig) string { - if config.Connection != "" { - return config.Connection - } - - var connStr strings.Builder - connStr.WriteString("host=") - if config.Host != "" { - connStr.WriteString(config.Host) - } else { - connStr.WriteString("localhost") - } - - if config.Port > 0 { - connStr.WriteString(fmt.Sprintf(" port=%d", config.Port)) - } else { - connStr.WriteString(" port=5432") - } - - if config.Username != "" { - connStr.WriteString(fmt.Sprintf(" user=%s", config.Username)) - } - - if config.Password != "" { - connStr.WriteString(fmt.Sprintf(" password=%s", config.Password)) - } - - if config.Database != "" { - connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database)) - } else { - connStr.WriteString(" dbname=fetch_ml") - } - - connStr.WriteString(" sslmode=disable") - return connStr.String() -} - -// NewDBFromPath creates a new database from a file path (legacy constructor). -func NewDBFromPath(dbPath string) (*DB, error) { - return NewDB(DBConfig{ - Type: DBTypeSQLite, - Connection: dbPath, - }) -} - // Job represents a machine learning job in the system. type Job struct { ID string `json:"id"` @@ -161,19 +36,6 @@ type Worker struct { Metadata map[string]string `json:"metadata,omitempty"` } -// Initialize creates database schema. -func (db *DB) Initialize(schema string) error { - if _, err := db.conn.ExecContext(context.Background(), schema); err != nil { - return fmt.Errorf("failed to initialize database: %w", err) - } - return nil -} - -// Close closes the database connection. -func (db *DB) Close() error { - return db.conn.Close() -} - // CreateJob inserts a new job into the database. func (db *DB) CreateJob(job *Job) error { datasetsJSON, _ := json.Marshal(job.Datasets) @@ -185,11 +47,20 @@ func (db *DB) CreateJob(job *Job) error { VALUES (?, ?, ?, ?, ?, ?, ?)` } else { query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7)` + VALUES ($1, $2, $3, $4, $5, $6, $7)` } - _, err := db.conn.ExecContext(context.Background(), query, job.ID, job.JobName, job.Args, job.Status, - job.Priority, string(datasetsJSON), string(metadataJSON)) + _, err := db.conn.ExecContext( + context.Background(), + query, + job.ID, + job.JobName, + job.Args, + job.Status, + job.Priority, + string(datasetsJSON), + string(metadataJSON), + ) if err != nil { return fmt.Errorf("failed to create job: %w", err) } @@ -242,21 +113,30 @@ func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error { var query string if db.dbType == DBTypeSQLite { query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?, - started_at = CASE WHEN ? = 'running' AND started_at IS NULL - THEN CURRENT_TIMESTAMP ELSE started_at END, - ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL - THEN CURRENT_TIMESTAMP ELSE ended_at END - WHERE id = ?` + started_at = CASE WHEN ? = 'running' AND started_at IS NULL + THEN CURRENT_TIMESTAMP ELSE started_at END, + ended_at = CASE WHEN ? IN ('completed', 'failed') + AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END + WHERE id = ?` } else { query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3, - started_at = CASE WHEN $4 = 'running' AND started_at IS NULL - THEN CURRENT_TIMESTAMP ELSE started_at END, - ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL - THEN CURRENT_TIMESTAMP ELSE ended_at END - WHERE id = $6` + started_at = CASE WHEN $4 = 'running' AND started_at IS NULL + THEN CURRENT_TIMESTAMP ELSE started_at END, + ended_at = CASE WHEN $5 IN ('completed', 'failed') + AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END + WHERE id = $6` } - _, err := db.conn.ExecContext(context.Background(), query, status, workerID, errorMsg, status, status, id) + _, err := db.conn.ExecContext( + context.Background(), + query, + status, + workerID, + errorMsg, + status, + status, + id, + ) if err != nil { return fmt.Errorf("failed to update job status: %w", err) } @@ -339,17 +219,25 @@ func (db *DB) RegisterWorker(worker *Worker) error { VALUES (?, ?, ?, ?, ?, ?)` } else { query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (id) DO UPDATE SET - hostname = EXCLUDED.hostname, - status = EXCLUDED.status, - current_jobs = EXCLUDED.current_jobs, - max_jobs = EXCLUDED.max_jobs, - metadata = EXCLUDED.metadata` + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (id) DO UPDATE SET + hostname = EXCLUDED.hostname, + status = EXCLUDED.status, + current_jobs = EXCLUDED.current_jobs, + max_jobs = EXCLUDED.max_jobs, + metadata = EXCLUDED.metadata` } - _, err := db.conn.ExecContext(context.Background(), query, worker.ID, worker.Hostname, worker.Status, - worker.CurrentJobs, worker.MaxJobs, string(metadataJSON)) + _, err := db.conn.ExecContext( + context.Background(), + query, + worker.ID, + worker.Hostname, + worker.Status, + worker.CurrentJobs, + worker.MaxJobs, + string(metadataJSON), + ) if err != nil { return fmt.Errorf("failed to register worker: %w", err) } @@ -377,10 +265,14 @@ func (db *DB) GetActiveWorkers() ([]*Worker, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata - FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')` + FROM workers + WHERE status = 'active' + AND last_heartbeat > datetime('now', '-30 seconds')` } else { query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata - FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'` + FROM workers + WHERE status = 'active' + AND last_heartbeat > NOW() - INTERVAL '30 seconds'` } rows, err := db.conn.QueryContext(context.Background(), query) diff --git a/internal/storage/schema_embed.go b/internal/storage/schema_embed.go new file mode 100644 index 0000000..8589355 --- /dev/null +++ b/internal/storage/schema_embed.go @@ -0,0 +1,32 @@ +package storage + +import ( + "embed" + "fmt" + "strings" +) + +//go:embed schema_sqlite.sql +var sqliteSchemaFS embed.FS + +//go:embed schema_postgres.sql +var postgresSchemaFS embed.FS + +func SchemaForDBType(dbType string) (string, error) { + switch strings.ToLower(dbType) { + case DBTypeSQLite: + b, err := sqliteSchemaFS.ReadFile("schema_sqlite.sql") + if err != nil { + return "", fmt.Errorf("failed to read sqlite schema: %w", err) + } + return string(b), nil + case "postgres", "postgresql": + b, err := postgresSchemaFS.ReadFile("schema_postgres.sql") + if err != nil { + return "", fmt.Errorf("failed to read postgres schema: %w", err) + } + return string(b), nil + default: + return "", fmt.Errorf("unsupported database type: %s", dbType) + } +} diff --git a/internal/storage/schema_postgres.sql b/internal/storage/schema_postgres.sql index 713ba40..37a65ac 100644 --- a/internal/storage/schema_postgres.sql +++ b/internal/storage/schema_postgres.sql @@ -43,6 +43,13 @@ CREATE TABLE IF NOT EXISTS system_metrics ( PRIMARY KEY (metric_name, timestamp) ); +CREATE TABLE IF NOT EXISTS datasets ( + name TEXT PRIMARY KEY, + url TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); + -- Indexes for performance CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at); @@ -52,6 +59,8 @@ CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp); CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat); CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp); +CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name); + -- Function to update updated_at timestamp CREATE OR REPLACE FUNCTION update_updated_at_column() RETURNS TRIGGER AS $$ diff --git a/internal/storage/schema_sqlite.sql b/internal/storage/schema_sqlite.sql index ce415c4..e66d385 100644 --- a/internal/storage/schema_sqlite.sql +++ b/internal/storage/schema_sqlite.sql @@ -43,6 +43,59 @@ CREATE TABLE IF NOT EXISTS system_metrics ( PRIMARY KEY (metric_name, timestamp) ); +CREATE TABLE IF NOT EXISTS experiments ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + status TEXT DEFAULT 'pending', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + user_id TEXT, + workspace_id TEXT +); + +CREATE TABLE IF NOT EXISTS experiment_environments ( + experiment_id TEXT PRIMARY KEY, + python_version TEXT, + cuda_version TEXT, + system_os TEXT, + system_arch TEXT, + hostname TEXT, + requirements_hash TEXT, + conda_env_hash TEXT, + dependencies TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS experiment_git_info ( + experiment_id TEXT PRIMARY KEY, + commit_sha TEXT, + branch TEXT, + remote_url TEXT, + is_dirty INTEGER DEFAULT 0, + diff_patch TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS experiment_seeds ( + experiment_id TEXT PRIMARY KEY, + numpy_seed INTEGER, + torch_seed INTEGER, + tensorflow_seed INTEGER, + random_seed INTEGER, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS datasets ( + name TEXT PRIMARY KEY, + url TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + -- Indexes for performance CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at); @@ -51,6 +104,11 @@ CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id); CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp); CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat); CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp); +CREATE INDEX IF NOT EXISTS idx_experiments_created_at ON experiments(created_at); +CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status); +CREATE INDEX IF NOT EXISTS idx_experiments_user_id ON experiments(user_id); + +CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name); -- Triggers to update timestamps CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp @@ -59,3 +117,17 @@ CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp BEGIN UPDATE jobs SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; END; + +CREATE TRIGGER IF NOT EXISTS update_experiments_timestamp + AFTER UPDATE ON experiments + FOR EACH ROW + BEGIN + UPDATE experiments SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; + END; + +CREATE TRIGGER IF NOT EXISTS update_datasets_timestamp + AFTER UPDATE ON datasets + FOR EACH ROW + BEGIN + UPDATE datasets SET updated_at = CURRENT_TIMESTAMP WHERE name = NEW.name; + END;