refactor(storage,queue): split storage layer and add sqlite queue backend
This commit is contained in:
parent
e901ddd810
commit
6ff5324e74
11 changed files with 1916 additions and 181 deletions
76
internal/queue/backend.go
Normal file
76
internal/queue/backend.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrInvalidQueueBackend is returned by NewBackend when BackendConfig.Backend
// names a backend other than "redis" or "sqlite".
var ErrInvalidQueueBackend = errors.New("invalid queue backend")
|
||||
|
||||
// Backend is the storage-agnostic contract shared by the Redis and SQLite
// queue implementations. Both TaskQueue and SQLiteQueue satisfy it.
type Backend interface {
	// Enqueue / dequeue.
	AddTask(task *Task) error
	GetNextTask() (*Task, error)
	// PeekNextTask inspects the head of the queue without removing it.
	PeekNextTask() (*Task, error)

	// Lease-based dequeue: a worker claims a task for leaseDuration and must
	// renew or release the lease.
	GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error)
	GetNextTaskWithLeaseBlocking(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error)
	RenewLease(taskID string, workerID string, leaseDuration time.Duration) error
	ReleaseLease(taskID string, workerID string) error

	// Failure handling: retry with backoff, or park permanently.
	RetryTask(task *Task) error
	MoveToDeadLetterQueue(task *Task, reason string) error

	// Task lookup and lifecycle.
	GetTask(taskID string) (*Task, error)
	GetAllTasks() ([]*Task, error)
	GetTaskByName(jobName string) (*Task, error)
	CancelTask(taskID string) error

	// Persistence of task mutations.
	UpdateTask(task *Task) error
	UpdateTaskWithMetrics(task *Task, action string) error

	// Operational telemetry.
	RecordMetric(jobName, metric string, value float64) error
	Heartbeat(workerID string) error
	QueueDepth() (int64, error)

	// Best-effort worker prewarm state (short-TTL, advisory only).
	SetWorkerPrewarmState(state PrewarmState) error
	ClearWorkerPrewarmState(workerID string) error
	GetWorkerPrewarmState(workerID string) (*PrewarmState, error)
	GetAllWorkerPrewarmStates() ([]PrewarmState, error)

	// Prewarm garbage-collection signaling between processes.
	SignalPrewarmGC() error
	PrewarmGCRequestValue() (string, error)

	// Close releases the backend's resources.
	Close() error
}
|
||||
|
||||
// QueueBackend names a queue storage implementation selectable via
// BackendConfig.Backend.
type QueueBackend string

const (
	// QueueBackendRedis selects the Redis-backed TaskQueue (the default).
	QueueBackendRedis QueueBackend = "redis"
	// QueueBackendSQLite selects the embedded SQLite-backed queue.
	QueueBackendSQLite QueueBackend = "sqlite"
)
|
||||
|
||||
// BackendConfig selects and configures a queue backend. Only the fields for
// the chosen Backend are consulted; the rest are ignored.
type BackendConfig struct {
	// Backend picks the implementation; empty means Redis.
	Backend QueueBackend
	// Redis connection settings (Backend == "redis" or empty).
	RedisAddr     string
	RedisPassword string
	RedisDB       int
	// SQLitePath is the database file path (Backend == "sqlite"); required
	// for that backend.
	SQLitePath string
	// MetricsFlushInterval is forwarded to the Redis backend's metrics
	// flusher. The SQLite backend does not use it.
	MetricsFlushInterval time.Duration
}
|
||||
|
||||
func NewBackend(cfg BackendConfig) (Backend, error) {
|
||||
switch cfg.Backend {
|
||||
case "", QueueBackendRedis:
|
||||
return NewTaskQueue(Config{
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
})
|
||||
case QueueBackendSQLite:
|
||||
return NewSQLiteQueue(cfg.SQLitePath)
|
||||
default:
|
||||
return nil, ErrInvalidQueueBackend
|
||||
}
|
||||
}
|
||||
|
|
@ -174,8 +174,10 @@ func IsRetryable(category ErrorCategory) bool {
|
|||
// GetUserMessage returns a user-friendly error message with suggestions
|
||||
func GetUserMessage(category ErrorCategory, err error) string {
|
||||
messages := map[ErrorCategory]string{
|
||||
ErrorNetwork: "Network connectivity issue. Please check your network connection and try again.",
|
||||
ErrorResource: "System resource exhausted. The system may be under heavy load. Try again later or contact support.",
|
||||
ErrorNetwork: "Network connectivity issue. Please check your network " +
|
||||
"connection and try again.",
|
||||
ErrorResource: "System resource exhausted. The system may be under heavy load. " +
|
||||
"Try again later or contact support.",
|
||||
ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.",
|
||||
ErrorAuth: "Authentication failed. Please check your API key or credentials.",
|
||||
ErrorValidation: "Invalid input. Please review your request and correct any errors.",
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ const (
|
|||
defaultLeaseDuration = 30 * time.Minute
|
||||
defaultMaxRetries = 3
|
||||
defaultBlockTimeout = 1 * time.Second
|
||||
PrewarmGCRequestKey = "ml:prewarm:gc_request"
|
||||
)
|
||||
|
||||
// TaskQueue manages ML experiment tasks via Redis
|
||||
|
|
@ -33,6 +34,21 @@ type metricEvent struct {
|
|||
Value float64
|
||||
}
|
||||
|
||||
// PrewarmState is a worker's self-reported snapshot of environment prewarm
// progress. It is advisory, short-lived state (stored with a ~30s TTL by the
// backends) and must never be required for correctness.
type PrewarmState struct {
	WorkerID   string `json:"worker_id"`
	TaskID     string `json:"task_id"`
	SnapshotID string `json:"snapshot_id,omitempty"`
	// StartedAt/UpdatedAt are RFC 3339 strings — TODO confirm writers always
	// use time.RFC3339 (the SQLite backend does; callers may pre-fill).
	StartedAt string `json:"started_at"`
	UpdatedAt string `json:"updated_at"`
	Phase     string `json:"phase"`
	EnvImage  string `json:"env_image,omitempty"`
	// DatasetCnt counts datasets involved in the prewarm.
	DatasetCnt int `json:"dataset_count"`
	// Environment cache counters and cumulative build time (nanoseconds).
	EnvHit    int64 `json:"env_hit,omitempty"`
	EnvMiss   int64 `json:"env_miss,omitempty"`
	EnvBuilt  int64 `json:"env_built,omitempty"`
	EnvTimeNs int64 `json:"env_time_ns,omitempty"`
}
|
||||
|
||||
// Config holds configuration for TaskQueue
|
||||
type Config struct {
|
||||
RedisAddr string
|
||||
|
|
@ -155,7 +171,10 @@ func isBlockingUnsupported(err error) bool {
|
|||
return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax")
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) pollUntilDeadline(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
|
||||
func (tq *TaskQueue) pollUntilDeadline(
|
||||
workerID string,
|
||||
leaseDuration, blockTimeout time.Duration,
|
||||
) (*Task, error) {
|
||||
deadline := time.Now().Add(blockTimeout)
|
||||
sleep := 25 * time.Millisecond
|
||||
|
||||
|
|
@ -194,8 +213,98 @@ func (tq *TaskQueue) GetNextTask() (*Task, error) {
|
|||
return tq.GetTask(taskID)
|
||||
}
|
||||
|
||||
// PeekNextTask returns the highest priority task without removing it from the queue.
|
||||
// This is intended for best-effort prewarm logic; it must never be required for correctness.
|
||||
func (tq *TaskQueue) PeekNextTask() (*Task, error) {
|
||||
// ZRANGE with REV gives highest scores first.
|
||||
ids, err := tq.client.ZRevRange(tq.ctx, TaskQueueKey, 0, 0).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return tq.GetTask(ids[0])
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) SetWorkerPrewarmState(state PrewarmState) error {
|
||||
if state.WorkerID == "" {
|
||||
return fmt.Errorf("missing worker_id")
|
||||
}
|
||||
key := WorkerPrewarmKey + state.WorkerID
|
||||
data, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal prewarm state: %w", err)
|
||||
}
|
||||
// Keep short TTL to avoid stale prewarm state if worker dies.
|
||||
return tq.client.Set(tq.ctx, key, data, 30*time.Second).Err()
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) ClearWorkerPrewarmState(workerID string) error {
|
||||
if workerID == "" {
|
||||
return fmt.Errorf("missing worker_id")
|
||||
}
|
||||
key := WorkerPrewarmKey + workerID
|
||||
return tq.client.Del(tq.ctx, key).Err()
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
|
||||
if workerID == "" {
|
||||
return nil, fmt.Errorf("missing worker_id")
|
||||
}
|
||||
key := WorkerPrewarmKey + workerID
|
||||
v, err := tq.client.Get(tq.ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var state PrewarmState
|
||||
if err := json.Unmarshal([]byte(v), &state); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal prewarm state: %w", err)
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
|
||||
var cursor uint64
|
||||
pattern := WorkerPrewarmKey + "*"
|
||||
out := make([]PrewarmState, 0, 8)
|
||||
|
||||
for {
|
||||
keys, next, err := tq.client.Scan(tq.ctx, cursor, pattern, 50).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cursor = next
|
||||
for _, key := range keys {
|
||||
v, err := tq.client.Get(tq.ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var state PrewarmState
|
||||
if err := json.Unmarshal([]byte(v), &state); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal prewarm state: %w", err)
|
||||
}
|
||||
out = append(out, state)
|
||||
}
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetNextTaskWithLease gets the next task and acquires a lease
|
||||
func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
|
||||
func (tq *TaskQueue) GetNextTaskWithLease(
|
||||
workerID string,
|
||||
leaseDuration time.Duration,
|
||||
) (*Task, error) {
|
||||
if leaseDuration == 0 {
|
||||
leaseDuration = defaultLeaseDuration
|
||||
}
|
||||
|
|
@ -239,7 +348,8 @@ func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Du
|
|||
return task, nil
|
||||
}
|
||||
|
||||
// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task before acquiring a lease.
|
||||
// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task.
|
||||
// Once a task is received, it then acquires a lease.
|
||||
func (tq *TaskQueue) GetNextTaskWithLeaseBlocking(
|
||||
workerID string,
|
||||
leaseDuration, blockTimeout time.Duration,
|
||||
|
|
@ -580,6 +690,23 @@ func (tq *TaskQueue) Close() error {
|
|||
return tq.client.Close()
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) SignalPrewarmGC() error {
|
||||
return tq.client.Set(
|
||||
tq.ctx,
|
||||
PrewarmGCRequestKey,
|
||||
time.Now().UnixNano(),
|
||||
10*time.Minute,
|
||||
).Err()
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) PrewarmGCRequestValue() (string, error) {
|
||||
v, err := tq.client.Get(tq.ctx, PrewarmGCRequestKey).Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetRedisClient returns the underlying Redis client for direct access
|
||||
func (tq *TaskQueue) GetRedisClient() *redis.Client {
|
||||
return tq.client
|
||||
|
|
|
|||
762
internal/queue/sqlite_queue.go
Normal file
762
internal/queue/sqlite_queue.go
Normal file
|
|
@ -0,0 +1,762 @@
|
|||
package queue
|
||||
|
||||
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
|
||||
|
||||
// SQLiteQueue is a single-process queue Backend persisted in one SQLite file.
// NOTE(review): the context is stored in the struct (contrary to the usual
// "pass ctx per call" guideline) because it doubles as the shutdown signal
// for the background janitor goroutines; cancel is invoked by Close.
type SQLiteQueue struct {
	db     *sql.DB
	ctx    context.Context
	cancel context.CancelFunc
}
|
||||
|
||||
func NewSQLiteQueue(path string) (*SQLiteQueue, error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("sqlite queue path is required")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=5000&_foreign_keys=on", path))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
q := &SQLiteQueue{db: db, ctx: ctx, cancel: cancel}
|
||||
|
||||
if err := q.initSchema(); err != nil {
|
||||
_ = db.Close()
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go q.leaseReclaimer()
|
||||
go q.kvJanitor()
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// initSchema applies PRAGMAs and creates all tables/indexes idempotently.
// Statement order matters: journal_mode/synchronous must be set before DDL.
func (q *SQLiteQueue) initSchema() error {
	stmts := []string{
		// WAL + NORMAL: the standard durability/throughput trade-off for a
		// single-writer embedded queue.
		"PRAGMA journal_mode=WAL;",
		"PRAGMA synchronous=NORMAL;",
		// tasks: canonical record; payload is the JSON-encoded Task, the
		// other columns are denormalized for indexing only.
		`CREATE TABLE IF NOT EXISTS tasks (
			id TEXT PRIMARY KEY,
			job_name TEXT,
			status TEXT,
			priority INTEGER,
			created_at INTEGER,
			updated_at INTEGER,
			payload BLOB
		);`,
		"CREATE INDEX IF NOT EXISTS idx_tasks_job_name ON tasks(job_name);",
		"CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);",
		// queue: membership table for runnable tasks; available_at gates
		// delayed retries (nanosecond timestamps).
		`CREATE TABLE IF NOT EXISTS queue (
			task_id TEXT PRIMARY KEY,
			priority INTEGER,
			available_at INTEGER,
			created_at INTEGER
		);`,
		// Matches the dequeue ORDER BY (priority DESC, created_at ASC).
		"CREATE INDEX IF NOT EXISTS idx_queue_available ON queue(available_at, priority DESC, created_at);",
		// worker_prewarm: short-TTL advisory state, filtered by updated_at.
		`CREATE TABLE IF NOT EXISTS worker_prewarm (
			worker_id TEXT PRIMARY KEY,
			payload BLOB,
			updated_at INTEGER
		);`,
		"CREATE INDEX IF NOT EXISTS idx_prewarm_updated ON worker_prewarm(updated_at);",
		// kv: generic expiring key/value store (mirrors Redis SET with TTL).
		`CREATE TABLE IF NOT EXISTS kv (
			key TEXT PRIMARY KEY,
			value TEXT,
			expires_at INTEGER
		);`,
		"CREATE INDEX IF NOT EXISTS idx_kv_expires ON kv(expires_at);",
		// worker_heartbeat: liveness timestamps per worker.
		`CREATE TABLE IF NOT EXISTS worker_heartbeat (
			worker_id TEXT PRIMARY KEY,
			last_seen INTEGER
		);`,
	}

	for _, stmt := range stmts {
		if _, err := q.db.ExecContext(q.ctx, stmt); err != nil {
			return fmt.Errorf("sqlite schema init failed: %w", err)
		}
	}
	return nil
}
|
||||
|
||||
func (q *SQLiteQueue) Close() error {
|
||||
q.cancel()
|
||||
return q.db.Close()
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) AddTask(task *Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if task.ID == "" {
|
||||
return fmt.Errorf("task id is required")
|
||||
}
|
||||
if task.JobName == "" {
|
||||
return fmt.Errorf("job name is required")
|
||||
}
|
||||
|
||||
if task.MaxRetries == 0 {
|
||||
task.MaxRetries = defaultMaxRetries
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = now
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createdAt := task.CreatedAt.UnixNano()
|
||||
updatedAt := now.UnixNano()
|
||||
|
||||
availableAt := now.UnixNano()
|
||||
if task.NextRetry != nil {
|
||||
availableAt = task.NextRetry.UTC().UnixNano()
|
||||
}
|
||||
|
||||
tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
q.ctx,
|
||||
`INSERT INTO tasks(id, job_name, status, priority, created_at, updated_at, payload)
|
||||
VALUES(?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
job_name=excluded.job_name,
|
||||
status=excluded.status,
|
||||
priority=excluded.priority,
|
||||
created_at=excluded.created_at,
|
||||
updated_at=excluded.updated_at,
|
||||
payload=excluded.payload`,
|
||||
task.ID,
|
||||
task.JobName,
|
||||
task.Status,
|
||||
task.Priority,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
payload,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if task.Status == "queued" {
|
||||
_, err = tx.ExecContext(
|
||||
q.ctx,
|
||||
`INSERT INTO queue(task_id, priority, available_at, created_at)
|
||||
VALUES(?, ?, ?, ?)
|
||||
ON CONFLICT(task_id) DO UPDATE SET
|
||||
priority=excluded.priority,
|
||||
available_at=excluded.available_at,
|
||||
created_at=excluded.created_at`,
|
||||
task.ID,
|
||||
task.Priority,
|
||||
availableAt,
|
||||
createdAt,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
_, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
TasksQueued.Inc()
|
||||
if depth, derr := q.QueueDepth(); derr == nil {
|
||||
UpdateQueueDepth(depth)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetTask(taskID string) (*Task, error) {
|
||||
row := q.db.QueryRowContext(q.ctx, "SELECT payload FROM tasks WHERE id = ?", taskID)
|
||||
var payload []byte
|
||||
if err := row.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var t Task
|
||||
if err := json.Unmarshal(payload, &t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetAllTasks() ([]*Task, error) {
|
||||
rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
out := make([]*Task, 0, 32)
|
||||
for rows.Next() {
|
||||
var payload []byte
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var t Task
|
||||
if err := json.Unmarshal(payload, &t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, &t)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetTaskByName(jobName string) (*Task, error) {
|
||||
row := q.db.QueryRowContext(
|
||||
q.ctx,
|
||||
"SELECT payload FROM tasks WHERE job_name = ? ORDER BY created_at DESC LIMIT 1",
|
||||
jobName,
|
||||
)
|
||||
var payload []byte
|
||||
if err := row.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var t Task
|
||||
if err := json.Unmarshal(payload, &t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) UpdateTask(task *Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().UnixNano()
|
||||
availableAt := time.Now().UTC().UnixNano()
|
||||
if task.NextRetry != nil {
|
||||
availableAt = task.NextRetry.UTC().UnixNano()
|
||||
}
|
||||
|
||||
tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
q.ctx,
|
||||
`UPDATE tasks SET job_name=?, status=?, priority=?, updated_at=?, payload=? WHERE id=?`,
|
||||
task.JobName,
|
||||
task.Status,
|
||||
task.Priority,
|
||||
now,
|
||||
payload,
|
||||
task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if task.Status == "queued" {
|
||||
_, err = tx.ExecContext(
|
||||
q.ctx,
|
||||
`INSERT INTO queue(task_id, priority, available_at, created_at)
|
||||
VALUES(?, ?, ?, ?)
|
||||
ON CONFLICT(task_id) DO UPDATE SET
|
||||
priority=excluded.priority,
|
||||
available_at=excluded.available_at`,
|
||||
task.ID,
|
||||
task.Priority,
|
||||
availableAt,
|
||||
task.CreatedAt.UTC().UnixNano(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
_, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) UpdateTaskWithMetrics(task *Task, _ string) error {
|
||||
return q.UpdateTask(task)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetNextTask() (*Task, error) {
|
||||
t, _, err := q.peekOrPop(false)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) PeekNextTask() (*Task, error) {
|
||||
t, _, err := q.peekOrPop(true)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) peekOrPop(peek bool) (*Task, int64, error) {
|
||||
now := time.Now().UTC().UnixNano()
|
||||
|
||||
query := `SELECT q.task_id, t.payload FROM queue q
|
||||
JOIN tasks t ON t.id = q.task_id
|
||||
WHERE q.available_at <= ?
|
||||
ORDER BY q.priority DESC, q.created_at ASC
|
||||
LIMIT 1`
|
||||
|
||||
row := q.db.QueryRowContext(q.ctx, query, now)
|
||||
var taskID string
|
||||
var payload []byte
|
||||
if err := row.Scan(&taskID, &payload); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var t Task
|
||||
if err := json.Unmarshal(payload, &t); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if !peek {
|
||||
if _, err := q.db.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if depth, derr := q.QueueDepth(); derr == nil {
|
||||
UpdateQueueDepth(depth)
|
||||
}
|
||||
}
|
||||
|
||||
return &t, 0, nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
|
||||
return q.claimTask(workerID, leaseDuration)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetNextTaskWithLeaseBlocking(
|
||||
workerID string,
|
||||
leaseDuration, blockTimeout time.Duration,
|
||||
) (*Task, error) {
|
||||
if blockTimeout <= 0 {
|
||||
blockTimeout = defaultBlockTimeout
|
||||
}
|
||||
deadline := time.Now().Add(blockTimeout)
|
||||
for {
|
||||
t, err := q.claimTask(workerID, leaseDuration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t != nil {
|
||||
return t, nil
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return nil, nil
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// claimTask atomically dequeues the highest-priority available task and
// records a lease (LeasedBy/LeaseExpiry) for workerID. Returns (nil, nil)
// when nothing is currently available. The select/delete/update sequence is
// wrapped in a serializable transaction so two workers cannot claim the same
// task.
func (q *SQLiteQueue) claimTask(workerID string, leaseDuration time.Duration) (*Task, error) {
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	if workerID == "" {
		return nil, fmt.Errorf("worker_id is required")
	}

	now := time.Now().UTC()
	nowNs := now.UnixNano()

	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return nil, err
	}
	// Rollback is a no-op once Commit succeeds; error deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Same head-selection order as peekOrPop: priority DESC, then FIFO.
	row := tx.QueryRowContext(
		q.ctx,
		`SELECT q.task_id, t.payload FROM queue q
		JOIN tasks t ON t.id = q.task_id
		WHERE q.available_at <= ?
		ORDER BY q.priority DESC, q.created_at ASC
		LIMIT 1`,
		nowNs,
	)

	var taskID string
	var payload []byte
	if err := row.Scan(&taskID, &payload); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Empty queue: not an error for callers.
			return nil, nil
		}
		return nil, err
	}

	var t Task
	if err := json.Unmarshal(payload, &t); err != nil {
		return nil, err
	}

	// Mirror Redis semantics: acquire lease fields but do not modify status here.
	exp := now.Add(leaseDuration)
	t.LeaseExpiry = &exp
	t.LeasedBy = workerID

	newPayload, err := json.Marshal(&t)
	if err != nil {
		return nil, err
	}

	// Remove from the runnable queue and persist the lease in one commit.
	if _, err := tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil {
		return nil, err
	}
	if _, err := tx.ExecContext(
		q.ctx,
		"UPDATE tasks SET updated_at=?, payload=? WHERE id=?",
		nowNs,
		newPayload,
		taskID,
	); err != nil {
		return nil, err
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}

	// Best-effort gauge refresh; depth errors are ignored.
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	return &t, nil
}
|
||||
|
||||
func (q *SQLiteQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
|
||||
t, err := q.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t.LeasedBy != workerID {
|
||||
return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
|
||||
}
|
||||
if leaseDuration == 0 {
|
||||
leaseDuration = defaultLeaseDuration
|
||||
}
|
||||
exp := time.Now().UTC().Add(leaseDuration)
|
||||
t.LeaseExpiry = &exp
|
||||
RecordLeaseRenewal(workerID)
|
||||
return q.UpdateTask(t)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) ReleaseLease(taskID string, workerID string) error {
|
||||
t, err := q.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t.LeasedBy != workerID {
|
||||
return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
|
||||
}
|
||||
t.LeaseExpiry = nil
|
||||
t.LeasedBy = ""
|
||||
return q.UpdateTask(t)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) RetryTask(task *Task) error {
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
RecordDLQAddition("max_retries")
|
||||
return q.MoveToDeadLetterQueue(task, "max retries exceeded")
|
||||
}
|
||||
|
||||
errorCategory := ErrorUnknown
|
||||
if task.Error != "" {
|
||||
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
|
||||
}
|
||||
if !IsRetryable(errorCategory) {
|
||||
RecordDLQAddition(string(errorCategory))
|
||||
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
|
||||
}
|
||||
|
||||
task.RetryCount++
|
||||
task.Status = "queued"
|
||||
task.LastError = task.Error
|
||||
task.Error = ""
|
||||
|
||||
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
|
||||
nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
|
||||
task.NextRetry = &nextRetry
|
||||
task.LeaseExpiry = nil
|
||||
task.LeasedBy = ""
|
||||
|
||||
RecordTaskRetry(task.JobName, errorCategory)
|
||||
return q.AddTask(task)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
|
||||
|
||||
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
|
||||
return q.UpdateTask(task)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) CancelTask(taskID string) error {
|
||||
t, err := q.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Status = "cancelled"
|
||||
now := time.Now().UTC()
|
||||
t.EndedAt = &now
|
||||
return q.UpdateTask(t)
|
||||
}
|
||||
|
||||
// RecordMetric satisfies the Backend interface. The SQLite backend has no
// metrics event pipeline (unlike the Redis backend), so this is a deliberate
// no-op; all arguments are ignored.
func (q *SQLiteQueue) RecordMetric(_, _ string, _ float64) error {
	return nil
}
|
||||
|
||||
func (q *SQLiteQueue) Heartbeat(workerID string) error {
|
||||
if workerID == "" {
|
||||
return fmt.Errorf("worker_id is required")
|
||||
}
|
||||
_, err := q.db.ExecContext(
|
||||
q.ctx,
|
||||
`INSERT INTO worker_heartbeat(worker_id, last_seen) VALUES(?, ?)
|
||||
ON CONFLICT(worker_id) DO UPDATE SET last_seen=excluded.last_seen`,
|
||||
workerID,
|
||||
time.Now().UTC().UnixNano(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) QueueDepth() (int64, error) {
|
||||
row := q.db.QueryRowContext(q.ctx, "SELECT COUNT(1) FROM queue")
|
||||
var n int64
|
||||
if err := row.Scan(&n); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) SetWorkerPrewarmState(state PrewarmState) error {
|
||||
if state.WorkerID == "" {
|
||||
return fmt.Errorf("missing worker_id")
|
||||
}
|
||||
if state.StartedAt == "" {
|
||||
state.StartedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
state.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
payload, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = q.db.ExecContext(
|
||||
q.ctx,
|
||||
`INSERT INTO worker_prewarm(worker_id, payload, updated_at) VALUES(?, ?, ?)
|
||||
ON CONFLICT(worker_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at`,
|
||||
state.WorkerID,
|
||||
payload,
|
||||
time.Now().UTC().UnixNano(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) ClearWorkerPrewarmState(workerID string) error {
|
||||
if workerID == "" {
|
||||
return fmt.Errorf("missing worker_id")
|
||||
}
|
||||
_, err := q.db.ExecContext(q.ctx, "DELETE FROM worker_prewarm WHERE worker_id = ?", workerID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
|
||||
if workerID == "" {
|
||||
return nil, fmt.Errorf("missing worker_id")
|
||||
}
|
||||
row := q.db.QueryRowContext(q.ctx, "SELECT payload, updated_at FROM worker_prewarm WHERE worker_id = ?", workerID)
|
||||
var payload []byte
|
||||
var updatedAt int64
|
||||
if err := row.Scan(&payload, &updatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Mirror Redis TTL semantics (30s).
|
||||
if time.Since(time.Unix(0, updatedAt)) > 30*time.Second {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var st PrewarmState
|
||||
if err := json.Unmarshal(payload, &st); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &st, nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
|
||||
cutoff := time.Now().UTC().Add(-30 * time.Second).UnixNano()
|
||||
rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM worker_prewarm WHERE updated_at >= ?", cutoff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
out := make([]PrewarmState, 0, 8)
|
||||
for rows.Next() {
|
||||
var payload []byte
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var st PrewarmState
|
||||
if err := json.Unmarshal(payload, &st); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, st)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) SignalPrewarmGC() error {
|
||||
return q.kvSet(PrewarmGCRequestKey, fmt.Sprintf("%d", time.Now().UTC().UnixNano()), 10*time.Minute)
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) PrewarmGCRequestValue() (string, error) {
|
||||
v, _, err := q.kvGet(PrewarmGCRequestKey)
|
||||
return v, err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) kvSet(key, value string, ttl time.Duration) error {
|
||||
exp := time.Now().UTC().Add(ttl).UnixNano()
|
||||
_, err := q.db.ExecContext(
|
||||
q.ctx,
|
||||
`INSERT INTO kv(key, value, expires_at) VALUES(?, ?, ?)
|
||||
ON CONFLICT(key) DO UPDATE SET value=excluded.value, expires_at=excluded.expires_at`,
|
||||
key,
|
||||
value,
|
||||
exp,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) kvGet(key string) (string, bool, error) {
|
||||
row := q.db.QueryRowContext(q.ctx, "SELECT value, expires_at FROM kv WHERE key = ?", key)
|
||||
var value string
|
||||
var exp int64
|
||||
if err := row.Scan(&value, &exp); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, err
|
||||
}
|
||||
if exp > 0 && time.Now().UTC().UnixNano() > exp {
|
||||
_, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE key = ?", key)
|
||||
return "", false, nil
|
||||
}
|
||||
return value, true, nil
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) kvJanitor() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE expires_at > 0 AND expires_at < ?", time.Now().UTC().UnixNano())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *SQLiteQueue) leaseReclaimer() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = q.reclaimExpiredLeases()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReclaimExpiredLeases exposes a single on-demand reclaim pass (the same work
// the background leaseReclaimer performs each minute), e.g. for tests or
// operational tooling.
func (q *SQLiteQueue) ReclaimExpiredLeases() error {
	return q.reclaimExpiredLeases()
}
|
||||
|
||||
func (q *SQLiteQueue) reclaimExpiredLeases() error {
|
||||
now := time.Now().UTC()
|
||||
rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payloads := make([][]byte, 0, 32)
|
||||
for rows.Next() {
|
||||
var payload []byte
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
_ = rows.Close()
|
||||
return err
|
||||
}
|
||||
payloads = append(payloads, payload)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
_ = rows.Close()
|
||||
return err
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
for _, payload := range payloads {
|
||||
var t Task
|
||||
if err := json.Unmarshal(payload, &t); err != nil {
|
||||
continue
|
||||
}
|
||||
if t.LeaseExpiry == nil {
|
||||
continue
|
||||
}
|
||||
if t.Status != "running" {
|
||||
continue
|
||||
}
|
||||
if t.LeaseExpiry.Before(now) {
|
||||
t.Error = fmt.Sprintf("worker %s lease expired", t.LeasedBy)
|
||||
RecordLeaseExpiration()
|
||||
if t.RetryCount < t.MaxRetries {
|
||||
_ = q.RetryTask(&t)
|
||||
} else {
|
||||
_ = q.MoveToDeadLetterQueue(&t, "lease expiry after max retries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -6,21 +6,41 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
)
|
||||
|
||||
// DatasetSpec describes a dataset input with optional provenance fields.
type DatasetSpec struct {
	Name     string `json:"name"`               // logical dataset name (required)
	Version  string `json:"version,omitempty"`  // optional version label
	Checksum string `json:"checksum,omitempty"` // optional content checksum for later verification
	URI      string `json:"uri,omitempty"`      // optional location the dataset resolves to
}
|
||||
|
||||
// Task represents an ML experiment task
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"` // queued, running, completed, failed
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"` // queued, running, completed, failed
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
// TODO(phase1): SnapshotID is an opaque identifier only.
|
||||
// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
|
||||
SnapshotID string `json:"snapshot_id,omitempty"`
|
||||
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
|
||||
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
// Datasets is kept for backward compatibility (legacy callers).
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
// Resource requests (optional, 0 means unspecified)
|
||||
CPU int `json:"cpu,omitempty"`
|
||||
MemoryGB int `json:"memory_gb,omitempty"`
|
||||
GPU int `json:"gpu,omitempty"`
|
||||
GPUMemory string `json:"gpu_memory,omitempty"`
|
||||
|
||||
// User ownership and permissions
|
||||
UserID string `json:"user_id"` // User who owns this task
|
||||
|
|
@ -36,6 +56,38 @@ type Task struct {
|
|||
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
|
||||
LastError string `json:"last_error,omitempty"` // Last error encountered
|
||||
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
|
||||
|
||||
// Optional tracking configuration for this task
|
||||
Tracking *TrackingConfig `json:"tracking,omitempty"`
|
||||
}
|
||||
|
||||
// TrackingConfig specifies experiment tracking tools to enable for a task.
// A nil sub-config leaves that integration unconfigured.
type TrackingConfig struct {
	MLflow      *MLflowTrackingConfig      `json:"mlflow,omitempty"`
	TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
	Wandb       *WandbTrackingConfig       `json:"wandb,omitempty"`
}
|
||||
|
||||
// MLflowTrackingConfig controls MLflow integration.
type MLflowTrackingConfig struct {
	Enabled     bool   `json:"enabled"`
	Mode        string `json:"mode,omitempty"`         // "sidecar" | "remote" | "disabled"
	TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
}
|
||||
|
||||
// TensorBoardTrackingConfig controls TensorBoard integration.
type TensorBoardTrackingConfig struct {
	Enabled bool   `json:"enabled"`
	Mode    string `json:"mode,omitempty"` // "sidecar" | "disabled"
}
|
||||
|
||||
// WandbTrackingConfig controls Weights & Biases integration.
type WandbTrackingConfig struct {
	Enabled bool   `json:"enabled"`
	Mode    string `json:"mode,omitempty"`    // "remote" | "disabled"
	APIKey  string `json:"api_key,omitempty"` // NOTE(review): serialized in task payloads — confirm this is acceptable for secrets
	Project string `json:"project,omitempty"`
	Entity  string `json:"entity,omitempty"`
}
|
||||
|
||||
// Redis key constants
|
||||
|
|
@ -44,5 +96,6 @@ var (
|
|||
TaskPrefix = config.RedisTaskPrefix
|
||||
TaskStatusPrefix = config.RedisTaskStatusPrefix
|
||||
WorkerHeartbeat = config.RedisWorkerHeartbeat
|
||||
WorkerPrewarmKey = config.RedisWorkerPrewarmKey
|
||||
JobMetricsPrefix = config.RedisJobMetricsPrefix
|
||||
)
|
||||
|
|
|
|||
146
internal/storage/db_connect.go
Normal file
146
internal/storage/db_connect.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
// Package storage provides database abstraction and job management.
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq" // PostgreSQL driver
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// DBConfig holds database connection configuration.
type DBConfig struct {
	Type       string // backend: "sqlite", "postgres", or "postgresql"
	Connection string // full DSN (postgres) or file path (sqlite); when set, overrides the fields below
	Host       string // postgres host (default "localhost")
	Port       int    // postgres port (default 5432)
	Username   string // postgres user (optional)
	Password   string // postgres password (optional)
	Database   string // postgres database name (default "fetch_ml")
}
|
||||
|
||||
// DB wraps a database connection with type information.
type DB struct {
	conn   *sql.DB // pooled connection handle
	dbType string  // normalized (lowercase) backend name, e.g. "sqlite" or "postgres"
}
|
||||
|
||||
// DBTypeSQLite is the constant for SQLite database type.
const DBTypeSQLite = "sqlite"
|
||||
|
||||
// NewDB creates a new database connection.
|
||||
func NewDB(config DBConfig) (*DB, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
|
||||
switch strings.ToLower(config.Type) {
|
||||
case DBTypeSQLite:
|
||||
conn, err = sql.Open("sqlite3", config.Connection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||
}
|
||||
// Enable foreign keys
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||
}
|
||||
// Enable WAL mode for better concurrency
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA journal_mode = WAL"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable WAL mode: %w", err)
|
||||
}
|
||||
// Additional SQLite optimizations for throughput
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA synchronous = NORMAL"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA cache_size = 10000"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set cache size: %w", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA temp_store = MEMORY"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set temp store: %w", err)
|
||||
}
|
||||
case "postgres":
|
||||
connStr := buildPostgresConnectionString(config)
|
||||
conn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
case "postgresql":
|
||||
// Handle "postgresql" as alias for "postgres"
|
||||
connStr := buildPostgresConnectionString(config)
|
||||
conn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", config.Type)
|
||||
}
|
||||
|
||||
// Optimize connection pool for better throughput
|
||||
conn.SetMaxOpenConns(50) // Increase max open connections
|
||||
conn.SetMaxIdleConns(25) // Maintain idle connections
|
||||
conn.SetConnMaxLifetime(5 * time.Minute) // Connection lifetime
|
||||
conn.SetConnMaxIdleTime(2 * time.Minute) // Idle connection timeout
|
||||
|
||||
return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil
|
||||
}
|
||||
|
||||
func buildPostgresConnectionString(config DBConfig) string {
|
||||
if config.Connection != "" {
|
||||
return config.Connection
|
||||
}
|
||||
|
||||
var connStr strings.Builder
|
||||
connStr.WriteString("host=")
|
||||
if config.Host != "" {
|
||||
connStr.WriteString(config.Host)
|
||||
} else {
|
||||
connStr.WriteString("localhost")
|
||||
}
|
||||
|
||||
if config.Port > 0 {
|
||||
connStr.WriteString(fmt.Sprintf(" port=%d", config.Port))
|
||||
} else {
|
||||
connStr.WriteString(" port=5432")
|
||||
}
|
||||
|
||||
if config.Username != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" user=%s", config.Username))
|
||||
}
|
||||
|
||||
if config.Password != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" password=%s", config.Password))
|
||||
}
|
||||
|
||||
if config.Database != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database))
|
||||
} else {
|
||||
connStr.WriteString(" dbname=fetch_ml")
|
||||
}
|
||||
|
||||
connStr.WriteString(" sslmode=disable")
|
||||
return connStr.String()
|
||||
}
|
||||
|
||||
// NewDBFromPath creates a new database from a file path (legacy constructor).
// It is equivalent to NewDB with Type "sqlite" and Connection set to dbPath.
func NewDBFromPath(dbPath string) (*DB, error) {
	return NewDB(DBConfig{
		Type:       DBTypeSQLite,
		Connection: dbPath,
	})
}
|
||||
|
||||
// Initialize creates database schema.
|
||||
func (db *DB) Initialize(schema string) error {
|
||||
if _, err := db.conn.ExecContext(context.Background(), schema); err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection pool. The DB must not be used after
// Close returns.
func (db *DB) Close() error {
	return db.conn.Close()
}
|
||||
564
internal/storage/db_experiments.go
Normal file
564
internal/storage/db_experiments.go
Normal file
|
|
@ -0,0 +1,564 @@
|
|||
// Package storage provides database abstraction and job management.
|
||||
package storage
|
||||
|
||||
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"
)
|
||||
|
||||
// Experiment is the core experiment record stored in the experiments table.
type Experiment struct {
	ID          string    `json:"id"`                     // primary key
	Name        string    `json:"name"`                   // required display name
	Description string    `json:"description,omitempty"`
	Status      string    `json:"status"`
	UserID      string    `json:"user_id,omitempty"`      // owning user
	WorkspaceID string    `json:"workspace_id,omitempty"` // owning workspace
	CreatedAt   time.Time `json:"created_at"`
	UpdatedAt   time.Time `json:"updated_at"`
}
|
||||
|
||||
// ExperimentEnvironment captures the runtime environment an experiment ran
// in (language/runtime versions, host details, dependency hashes).
type ExperimentEnvironment struct {
	PythonVersion    string          `json:"python_version"`
	CUDAVersion      string          `json:"cuda_version,omitempty"`
	SystemOS         string          `json:"system_os"`
	SystemArch       string          `json:"system_arch"`
	Hostname         string          `json:"hostname"`
	RequirementsHash string          `json:"requirements_hash"`
	CondaEnvHash     string          `json:"conda_env_hash,omitempty"`
	Dependencies     json.RawMessage `json:"dependencies,omitempty"` // raw JSON, stored verbatim as text
	CreatedAt        time.Time       `json:"created_at"`
}
|
||||
|
||||
// ExperimentGitInfo captures the git provenance of an experiment's code.
type ExperimentGitInfo struct {
	CommitSHA string    `json:"commit_sha"`
	Branch    string    `json:"branch"`
	RemoteURL string    `json:"remote_url"`
	IsDirty   bool      `json:"is_dirty"`              // true if the working tree had uncommitted changes
	DiffPatch string    `json:"diff_patch,omitempty"`  // uncommitted changes as a patch, when captured
	CreatedAt time.Time `json:"created_at"`
}
|
||||
|
||||
// ExperimentSeeds records the RNG seeds used by an experiment. Nil pointers
// mean the corresponding seed was not set/recorded.
type ExperimentSeeds struct {
	Numpy      *int64    `json:"numpy_seed,omitempty"`
	Torch      *int64    `json:"torch_seed,omitempty"`
	TensorFlow *int64    `json:"tensorflow_seed,omitempty"`
	Random     *int64    `json:"random_seed,omitempty"`
	CreatedAt  time.Time `json:"created_at"`
}
|
||||
|
||||
// Dataset is a named dataset reference, keyed by Name in the datasets table.
type Dataset struct {
	Name      string    `json:"name"` // unique key
	URL       string    `json:"url"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
||||
|
||||
// ExperimentWithMetadata bundles an experiment with its optional metadata
// rows; nil fields indicate the metadata was not recorded (or not readable).
type ExperimentWithMetadata struct {
	Experiment  Experiment             `json:"experiment"`
	Environment *ExperimentEnvironment `json:"environment,omitempty"`
	GitInfo     *ExperimentGitInfo     `json:"git_info,omitempty"`
	Seeds       *ExperimentSeeds       `json:"seeds,omitempty"`
}
|
||||
|
||||
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {
|
||||
if exp == nil {
|
||||
return fmt.Errorf("experiment is nil")
|
||||
}
|
||||
if exp.ID == "" {
|
||||
return fmt.Errorf("experiment id is required")
|
||||
}
|
||||
if exp.Name == "" {
|
||||
return fmt.Errorf("experiment name is required")
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
description = excluded.description,
|
||||
status = excluded.status,
|
||||
user_id = excluded.user_id,
|
||||
workspace_id = excluded.workspace_id`
|
||||
} else {
|
||||
query = `INSERT INTO experiments (id, name, description, status, user_id, workspace_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
description = EXCLUDED.description,
|
||||
status = EXCLUDED.status,
|
||||
user_id = EXCLUDED.user_id,
|
||||
workspace_id = EXCLUDED.workspace_id`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
exp.ID,
|
||||
exp.Name,
|
||||
exp.Description,
|
||||
exp.Status,
|
||||
exp.UserID,
|
||||
exp.WorkspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert experiment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) UpsertExperimentEnvironment(
|
||||
ctx context.Context,
|
||||
experimentID string,
|
||||
env *ExperimentEnvironment,
|
||||
) error {
|
||||
if experimentID == "" {
|
||||
return fmt.Errorf("experiment id is required")
|
||||
}
|
||||
if env == nil {
|
||||
return fmt.Errorf("environment is nil")
|
||||
}
|
||||
|
||||
deps := ""
|
||||
if len(env.Dependencies) > 0 {
|
||||
deps = string(env.Dependencies)
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO experiment_environments
|
||||
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
|
||||
requirements_hash, conda_env_hash, dependencies)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(experiment_id) DO UPDATE SET
|
||||
python_version = excluded.python_version,
|
||||
cuda_version = excluded.cuda_version,
|
||||
system_os = excluded.system_os,
|
||||
system_arch = excluded.system_arch,
|
||||
hostname = excluded.hostname,
|
||||
requirements_hash = excluded.requirements_hash,
|
||||
conda_env_hash = excluded.conda_env_hash,
|
||||
dependencies = excluded.dependencies`
|
||||
} else {
|
||||
query = `INSERT INTO experiment_environments
|
||||
(experiment_id, python_version, cuda_version, system_os, system_arch, hostname,
|
||||
requirements_hash, conda_env_hash, dependencies)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (experiment_id) DO UPDATE SET
|
||||
python_version = EXCLUDED.python_version,
|
||||
cuda_version = EXCLUDED.cuda_version,
|
||||
system_os = EXCLUDED.system_os,
|
||||
system_arch = EXCLUDED.system_arch,
|
||||
hostname = EXCLUDED.hostname,
|
||||
requirements_hash = EXCLUDED.requirements_hash,
|
||||
conda_env_hash = EXCLUDED.conda_env_hash,
|
||||
dependencies = EXCLUDED.dependencies`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
experimentID,
|
||||
env.PythonVersion,
|
||||
env.CUDAVersion,
|
||||
env.SystemOS,
|
||||
env.SystemArch,
|
||||
env.Hostname,
|
||||
env.RequirementsHash,
|
||||
env.CondaEnvHash,
|
||||
deps,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert experiment environment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) UpsertExperimentGitInfo(
|
||||
ctx context.Context,
|
||||
experimentID string,
|
||||
info *ExperimentGitInfo,
|
||||
) error {
|
||||
if experimentID == "" {
|
||||
return fmt.Errorf("experiment id is required")
|
||||
}
|
||||
if info == nil {
|
||||
return fmt.Errorf("git info is nil")
|
||||
}
|
||||
|
||||
isDirty := 0
|
||||
if info.IsDirty {
|
||||
isDirty = 1
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO experiment_git_info
|
||||
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(experiment_id) DO UPDATE SET
|
||||
commit_sha = excluded.commit_sha,
|
||||
branch = excluded.branch,
|
||||
remote_url = excluded.remote_url,
|
||||
is_dirty = excluded.is_dirty,
|
||||
diff_patch = excluded.diff_patch`
|
||||
} else {
|
||||
query = `INSERT INTO experiment_git_info
|
||||
(experiment_id, commit_sha, branch, remote_url, is_dirty, diff_patch)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (experiment_id) DO UPDATE SET
|
||||
commit_sha = EXCLUDED.commit_sha,
|
||||
branch = EXCLUDED.branch,
|
||||
remote_url = EXCLUDED.remote_url,
|
||||
is_dirty = EXCLUDED.is_dirty,
|
||||
diff_patch = EXCLUDED.diff_patch`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
experimentID,
|
||||
info.CommitSHA,
|
||||
info.Branch,
|
||||
info.RemoteURL,
|
||||
isDirty,
|
||||
info.DiffPatch,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert experiment git info: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) UpsertExperimentSeeds(
|
||||
ctx context.Context,
|
||||
experimentID string,
|
||||
seeds *ExperimentSeeds,
|
||||
) error {
|
||||
if experimentID == "" {
|
||||
return fmt.Errorf("experiment id is required")
|
||||
}
|
||||
if seeds == nil {
|
||||
return fmt.Errorf("seeds is nil")
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO experiment_seeds
|
||||
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(experiment_id) DO UPDATE SET
|
||||
numpy_seed = excluded.numpy_seed,
|
||||
torch_seed = excluded.torch_seed,
|
||||
tensorflow_seed = excluded.tensorflow_seed,
|
||||
random_seed = excluded.random_seed`
|
||||
} else {
|
||||
query = `INSERT INTO experiment_seeds
|
||||
(experiment_id, numpy_seed, torch_seed, tensorflow_seed, random_seed)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (experiment_id) DO UPDATE SET
|
||||
numpy_seed = EXCLUDED.numpy_seed,
|
||||
torch_seed = EXCLUDED.torch_seed,
|
||||
tensorflow_seed = EXCLUDED.tensorflow_seed,
|
||||
random_seed = EXCLUDED.random_seed`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
experimentID,
|
||||
seeds.Numpy,
|
||||
seeds.Torch,
|
||||
seeds.TensorFlow,
|
||||
seeds.Random,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert experiment seeds: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExperimentWithMetadata loads an experiment and whatever optional
// metadata rows exist for it (environment, git info, seeds). The experiment
// row itself is required; each metadata sub-query is best-effort — when its
// row is missing (or the query fails for any reason) the corresponding
// result field is simply left nil.
func (db *DB) GetExperimentWithMetadata(
	ctx context.Context,
	experimentID string,
) (*ExperimentWithMetadata, error) {
	if experimentID == "" {
		return nil, fmt.Errorf("experiment id is required")
	}

	// Required experiment row. COALESCE maps NULL text columns to "".
	var exp Experiment
	var query string
	if db.dbType == DBTypeSQLite {
		query = `SELECT id, name,
			COALESCE(description, ''), COALESCE(status, ''),
			COALESCE(user_id, ''), COALESCE(workspace_id, ''),
			created_at, updated_at
			FROM experiments WHERE id = ?`
	} else {
		query = `SELECT id, name,
			COALESCE(description, ''), COALESCE(status, ''),
			COALESCE(user_id, ''), COALESCE(workspace_id, ''),
			created_at, updated_at
			FROM experiments WHERE id = $1`
	}

	err := db.conn.QueryRowContext(ctx, query, experimentID).Scan(
		&exp.ID,
		&exp.Name,
		&exp.Description,
		&exp.Status,
		&exp.UserID,
		&exp.WorkspaceID,
		&exp.CreatedAt,
		&exp.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to get experiment: %w", err)
	}

	result := &ExperimentWithMetadata{Experiment: exp}

	// Optional environment row. NOTE: any scan error (not just
	// sql.ErrNoRows) is deliberately swallowed, leaving Environment nil.
	var env ExperimentEnvironment
	var envDeps sql.NullString
	var envQuery string
	if db.dbType == DBTypeSQLite {
		envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
			COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
			COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
			dependencies, created_at
			FROM experiment_environments WHERE experiment_id = ?`
	} else {
		envQuery = `SELECT COALESCE(python_version, ''), COALESCE(cuda_version, ''),
			COALESCE(system_os, ''), COALESCE(system_arch, ''), COALESCE(hostname, ''),
			COALESCE(requirements_hash, ''), COALESCE(conda_env_hash, ''),
			dependencies, created_at
			FROM experiment_environments WHERE experiment_id = $1`
	}
	if err := db.conn.QueryRowContext(ctx, envQuery, experimentID).Scan(
		&env.PythonVersion,
		&env.CUDAVersion,
		&env.SystemOS,
		&env.SystemArch,
		&env.Hostname,
		&env.RequirementsHash,
		&env.CondaEnvHash,
		&envDeps,
		&env.CreatedAt,
	); err == nil {
		// dependencies may be NULL; only attach non-empty JSON.
		if envDeps.Valid && envDeps.String != "" {
			env.Dependencies = json.RawMessage(envDeps.String)
		}
		result.Environment = &env
	}

	// Optional git-info row (same best-effort semantics as above).
	var git ExperimentGitInfo
	var gitDirty sql.NullInt64
	var gitQuery string
	if db.dbType == DBTypeSQLite {
		gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
			COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
			COALESCE(diff_patch, ''), created_at
			FROM experiment_git_info WHERE experiment_id = ?`
	} else {
		gitQuery = `SELECT COALESCE(commit_sha, ''), COALESCE(branch, ''),
			COALESCE(remote_url, ''), COALESCE(is_dirty, 0),
			COALESCE(diff_patch, ''), created_at
			FROM experiment_git_info WHERE experiment_id = $1`
	}
	if err := db.conn.QueryRowContext(ctx, gitQuery, experimentID).Scan(
		&git.CommitSHA,
		&git.Branch,
		&git.RemoteURL,
		&gitDirty,
		&git.DiffPatch,
		&git.CreatedAt,
	); err == nil {
		// is_dirty is stored as an integer flag; any non-zero value is dirty.
		git.IsDirty = gitDirty.Valid && gitDirty.Int64 != 0
		result.GitInfo = &git
	}

	// Optional seeds row (same best-effort semantics). NULL columns become
	// nil pointers on the struct.
	var seeds ExperimentSeeds
	var numpySeed, torchSeed, tfSeed, randSeed sql.NullInt64
	var seedsQuery string
	if db.dbType == DBTypeSQLite {
		seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
			FROM experiment_seeds WHERE experiment_id = ?`
	} else {
		seedsQuery = `SELECT numpy_seed, torch_seed, tensorflow_seed, random_seed, created_at
			FROM experiment_seeds WHERE experiment_id = $1`
	}
	if err := db.conn.QueryRowContext(ctx, seedsQuery, experimentID).Scan(
		&numpySeed,
		&torchSeed,
		&tfSeed,
		&randSeed,
		&seeds.CreatedAt,
	); err == nil {
		if numpySeed.Valid {
			v := numpySeed.Int64
			seeds.Numpy = &v
		}
		if torchSeed.Valid {
			v := torchSeed.Int64
			seeds.Torch = &v
		}
		if tfSeed.Valid {
			v := tfSeed.Int64
			seeds.TensorFlow = &v
		}
		if randSeed.Valid {
			v := randSeed.Int64
			seeds.Random = &v
		}
		result.Seeds = &seeds
	}

	return result, nil
}
|
||||
|
||||
func (db *DB) UpsertDataset(ctx context.Context, ds *Dataset) error {
|
||||
if ds == nil {
|
||||
return fmt.Errorf("dataset is nil")
|
||||
}
|
||||
if ds.Name == "" {
|
||||
return fmt.Errorf("dataset name is required")
|
||||
}
|
||||
if ds.URL == "" {
|
||||
return fmt.Errorf("dataset url is required")
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO datasets (name, url)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(name) DO UPDATE SET
|
||||
url = excluded.url`
|
||||
} else {
|
||||
query = `INSERT INTO datasets (name, url)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (name) DO UPDATE SET
|
||||
url = EXCLUDED.url`
|
||||
}
|
||||
|
||||
if _, err := db.conn.ExecContext(ctx, query, ds.Name, ds.URL); err != nil {
|
||||
return fmt.Errorf("failed to upsert dataset: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("dataset name is required")
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = ?`
|
||||
} else {
|
||||
query = `SELECT name, url, created_at, updated_at FROM datasets WHERE name = $1`
|
||||
}
|
||||
|
||||
var ds Dataset
|
||||
if err := db.conn.QueryRowContext(ctx, query, name).Scan(
|
||||
&ds.Name,
|
||||
&ds.URL,
|
||||
&ds.CreatedAt,
|
||||
&ds.UpdatedAt,
|
||||
); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get dataset: %w", err)
|
||||
}
|
||||
return &ds, nil
|
||||
}
|
||||
|
||||
func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) {
|
||||
query := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC`
|
||||
var args []interface{}
|
||||
if limit > 0 {
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query += " LIMIT ?"
|
||||
} else {
|
||||
query += " LIMIT $1"
|
||||
}
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list datasets: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var out []*Dataset
|
||||
for rows.Next() {
|
||||
var ds Dataset
|
||||
if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan dataset: %w", err)
|
||||
}
|
||||
out = append(out, &ds)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating datasets: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Dataset, error) {
|
||||
if term == "" {
|
||||
return []*Dataset{}, nil
|
||||
}
|
||||
|
||||
// Escape %/_ for LIKE and use parameterized query.
|
||||
escaped := strings.ReplaceAll(term, "\\", "\\\\")
|
||||
escaped = strings.ReplaceAll(escaped, "%", "\\%")
|
||||
escaped = strings.ReplaceAll(escaped, "_", "\\_")
|
||||
pattern := "%" + escaped + "%"
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT name, url, created_at, updated_at FROM datasets
|
||||
WHERE name LIKE ? ESCAPE '\'
|
||||
ORDER BY name ASC`
|
||||
args = append(args, pattern)
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
args = append(args, limit)
|
||||
}
|
||||
} else {
|
||||
query = `SELECT name, url, created_at, updated_at FROM datasets
|
||||
WHERE name LIKE $1 ESCAPE '\'
|
||||
ORDER BY name ASC`
|
||||
args = append(args, pattern)
|
||||
if limit > 0 {
|
||||
query += " LIMIT $2"
|
||||
args = append(args, limit)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search datasets: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var out []*Dataset
|
||||
for rows.Next() {
|
||||
var ds Dataset
|
||||
if err := rows.Scan(&ds.Name, &ds.URL, &ds.CreatedAt, &ds.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan dataset: %w", err)
|
||||
}
|
||||
out = append(out, &ds)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating datasets: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
// Package storage provides database abstraction and job management.
|
||||
package storage
|
||||
|
||||
import (
|
||||
|
|
@ -6,133 +5,9 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq" // PostgreSQL driver
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// DBConfig holds database connection configuration.
|
||||
type DBConfig struct {
|
||||
Type string
|
||||
Connection string
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
// DB wraps a database connection with type information.
|
||||
type DB struct {
|
||||
conn *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
// DBTypeSQLite is the constant for SQLite database type
|
||||
const DBTypeSQLite = "sqlite"
|
||||
|
||||
// NewDB creates a new database connection.
|
||||
func NewDB(config DBConfig) (*DB, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
|
||||
switch strings.ToLower(config.Type) {
|
||||
case DBTypeSQLite:
|
||||
conn, err = sql.Open("sqlite3", config.Connection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||
}
|
||||
// Enable foreign keys
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||
}
|
||||
// Enable WAL mode for better concurrency
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA journal_mode = WAL"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable WAL mode: %w", err)
|
||||
}
|
||||
// Additional SQLite optimizations for throughput
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA synchronous = NORMAL"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA cache_size = 10000"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set cache size: %w", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(context.Background(), "PRAGMA temp_store = MEMORY"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set temp store: %w", err)
|
||||
}
|
||||
case "postgres":
|
||||
connStr := buildPostgresConnectionString(config)
|
||||
conn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
case "postgresql":
|
||||
// Handle "postgresql" as alias for "postgres"
|
||||
connStr := buildPostgresConnectionString(config)
|
||||
conn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", config.Type)
|
||||
}
|
||||
|
||||
// Optimize connection pool for better throughput
|
||||
conn.SetMaxOpenConns(50) // Increase max open connections
|
||||
conn.SetMaxIdleConns(25) // Maintain idle connections
|
||||
conn.SetConnMaxLifetime(5 * time.Minute) // Connection lifetime
|
||||
conn.SetConnMaxIdleTime(2 * time.Minute) // Idle connection timeout
|
||||
|
||||
return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil
|
||||
}
|
||||
|
||||
func buildPostgresConnectionString(config DBConfig) string {
|
||||
if config.Connection != "" {
|
||||
return config.Connection
|
||||
}
|
||||
|
||||
var connStr strings.Builder
|
||||
connStr.WriteString("host=")
|
||||
if config.Host != "" {
|
||||
connStr.WriteString(config.Host)
|
||||
} else {
|
||||
connStr.WriteString("localhost")
|
||||
}
|
||||
|
||||
if config.Port > 0 {
|
||||
connStr.WriteString(fmt.Sprintf(" port=%d", config.Port))
|
||||
} else {
|
||||
connStr.WriteString(" port=5432")
|
||||
}
|
||||
|
||||
if config.Username != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" user=%s", config.Username))
|
||||
}
|
||||
|
||||
if config.Password != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" password=%s", config.Password))
|
||||
}
|
||||
|
||||
if config.Database != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database))
|
||||
} else {
|
||||
connStr.WriteString(" dbname=fetch_ml")
|
||||
}
|
||||
|
||||
connStr.WriteString(" sslmode=disable")
|
||||
return connStr.String()
|
||||
}
|
||||
|
||||
// NewDBFromPath creates a new database from a file path (legacy constructor).
|
||||
func NewDBFromPath(dbPath string) (*DB, error) {
|
||||
return NewDB(DBConfig{
|
||||
Type: DBTypeSQLite,
|
||||
Connection: dbPath,
|
||||
})
|
||||
}
|
||||
|
||||
// Job represents a machine learning job in the system.
|
||||
type Job struct {
|
||||
ID string `json:"id"`
|
||||
|
|
@ -161,19 +36,6 @@ type Worker struct {
|
|||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Initialize creates database schema.
|
||||
func (db *DB) Initialize(schema string) error {
|
||||
if _, err := db.conn.ExecContext(context.Background(), schema); err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (db *DB) Close() error {
|
||||
return db.conn.Close()
|
||||
}
|
||||
|
||||
// CreateJob inserts a new job into the database.
|
||||
func (db *DB) CreateJob(job *Job) error {
|
||||
datasetsJSON, _ := json.Marshal(job.Datasets)
|
||||
|
|
@ -185,11 +47,20 @@ func (db *DB) CreateJob(job *Job) error {
|
|||
VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, job.ID, job.JobName, job.Args, job.Status,
|
||||
job.Priority, string(datasetsJSON), string(metadataJSON))
|
||||
_, err := db.conn.ExecContext(
|
||||
context.Background(),
|
||||
query,
|
||||
job.ID,
|
||||
job.JobName,
|
||||
job.Args,
|
||||
job.Status,
|
||||
job.Priority,
|
||||
string(datasetsJSON),
|
||||
string(metadataJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create job: %w", err)
|
||||
}
|
||||
|
|
@ -242,21 +113,30 @@ func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error {
|
|||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?,
|
||||
started_at = CASE WHEN ? = 'running' AND started_at IS NULL
|
||||
THEN CURRENT_TIMESTAMP ELSE started_at END,
|
||||
ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL
|
||||
THEN CURRENT_TIMESTAMP ELSE ended_at END
|
||||
WHERE id = ?`
|
||||
started_at = CASE WHEN ? = 'running' AND started_at IS NULL
|
||||
THEN CURRENT_TIMESTAMP ELSE started_at END,
|
||||
ended_at = CASE WHEN ? IN ('completed', 'failed')
|
||||
AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
|
||||
WHERE id = ?`
|
||||
} else {
|
||||
query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3,
|
||||
started_at = CASE WHEN $4 = 'running' AND started_at IS NULL
|
||||
THEN CURRENT_TIMESTAMP ELSE started_at END,
|
||||
ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL
|
||||
THEN CURRENT_TIMESTAMP ELSE ended_at END
|
||||
WHERE id = $6`
|
||||
started_at = CASE WHEN $4 = 'running' AND started_at IS NULL
|
||||
THEN CURRENT_TIMESTAMP ELSE started_at END,
|
||||
ended_at = CASE WHEN $5 IN ('completed', 'failed')
|
||||
AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
|
||||
WHERE id = $6`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, status, workerID, errorMsg, status, status, id)
|
||||
_, err := db.conn.ExecContext(
|
||||
context.Background(),
|
||||
query,
|
||||
status,
|
||||
workerID,
|
||||
errorMsg,
|
||||
status,
|
||||
status,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update job status: %w", err)
|
||||
}
|
||||
|
|
@ -339,17 +219,25 @@ func (db *DB) RegisterWorker(worker *Worker) error {
|
|||
VALUES (?, ?, ?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
hostname = EXCLUDED.hostname,
|
||||
status = EXCLUDED.status,
|
||||
current_jobs = EXCLUDED.current_jobs,
|
||||
max_jobs = EXCLUDED.max_jobs,
|
||||
metadata = EXCLUDED.metadata`
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
hostname = EXCLUDED.hostname,
|
||||
status = EXCLUDED.status,
|
||||
current_jobs = EXCLUDED.current_jobs,
|
||||
max_jobs = EXCLUDED.max_jobs,
|
||||
metadata = EXCLUDED.metadata`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, worker.ID, worker.Hostname, worker.Status,
|
||||
worker.CurrentJobs, worker.MaxJobs, string(metadataJSON))
|
||||
_, err := db.conn.ExecContext(
|
||||
context.Background(),
|
||||
query,
|
||||
worker.ID,
|
||||
worker.Hostname,
|
||||
worker.Status,
|
||||
worker.CurrentJobs,
|
||||
worker.MaxJobs,
|
||||
string(metadataJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register worker: %w", err)
|
||||
}
|
||||
|
|
@ -377,10 +265,14 @@ func (db *DB) GetActiveWorkers() ([]*Worker, error) {
|
|||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
|
||||
FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')`
|
||||
FROM workers
|
||||
WHERE status = 'active'
|
||||
AND last_heartbeat > datetime('now', '-30 seconds')`
|
||||
} else {
|
||||
query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
|
||||
FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
|
||||
FROM workers
|
||||
WHERE status = 'active'
|
||||
AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(context.Background(), query)
|
||||
32
internal/storage/schema_embed.go
Normal file
32
internal/storage/schema_embed.go
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed schema_sqlite.sql
|
||||
var sqliteSchemaFS embed.FS
|
||||
|
||||
//go:embed schema_postgres.sql
|
||||
var postgresSchemaFS embed.FS
|
||||
|
||||
func SchemaForDBType(dbType string) (string, error) {
|
||||
switch strings.ToLower(dbType) {
|
||||
case DBTypeSQLite:
|
||||
b, err := sqliteSchemaFS.ReadFile("schema_sqlite.sql")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read sqlite schema: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
case "postgres", "postgresql":
|
||||
b, err := postgresSchemaFS.ReadFile("schema_postgres.sql")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read postgres schema: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported database type: %s", dbType)
|
||||
}
|
||||
}
|
||||
|
|
@ -43,6 +43,13 @@ CREATE TABLE IF NOT EXISTS system_metrics (
|
|||
PRIMARY KEY (metric_name, timestamp)
|
||||
);
|
||||
|
||||
-- Registered datasets, addressable by unique name; url is the fetchable
-- source location. updated_at is presumably refreshed on UPDATE via the
-- update_updated_at_column() function defined below -- confirm the trigger
-- is wired up for this table.
CREATE TABLE IF NOT EXISTS datasets (
    name TEXT PRIMARY KEY,
    url TEXT NOT NULL,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);

-- Indexes for performance
-- Jobs are filtered by status and listed chronologically.
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
|
||||
|
|
@ -52,6 +59,8 @@ CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
|
|||
-- last_heartbeat is scanned to find active workers within a recent window.
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);

-- NOTE(review): name is already the PRIMARY KEY of datasets, so this index
-- duplicates the implicit primary-key index -- consider dropping it.
CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name);
|
||||
|
||||
-- Function to update updated_at timestamp
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
|
|
|
|||
|
|
@ -43,6 +43,59 @@ CREATE TABLE IF NOT EXISTS system_metrics (
|
|||
PRIMARY KEY (metric_name, timestamp)
|
||||
);
|
||||
|
||||
-- Core experiment record. status defaults to 'pending'; updated_at is
-- maintained by the update_experiments_timestamp trigger defined below.
CREATE TABLE IF NOT EXISTS experiments (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    description TEXT,
    status TEXT DEFAULT 'pending',
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    user_id TEXT,
    workspace_id TEXT
);

-- 1:1 snapshot of the runtime environment captured for an experiment
-- (interpreter/CUDA versions, host info, dependency hashes). Rows are
-- removed automatically when the parent experiment is deleted.
-- NOTE(review): ON DELETE CASCADE only fires when PRAGMA foreign_keys=ON
-- is set per connection in SQLite -- confirm the connection setup enables it.
CREATE TABLE IF NOT EXISTS experiment_environments (
    experiment_id TEXT PRIMARY KEY,
    python_version TEXT,
    cuda_version TEXT,
    system_os TEXT,
    system_arch TEXT,
    hostname TEXT,
    requirements_hash TEXT,
    conda_env_hash TEXT,
    dependencies TEXT,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);

-- 1:1 git provenance for an experiment; is_dirty is a 0/1 boolean flag and
-- diff_patch stores the uncommitted diff, if any.
CREATE TABLE IF NOT EXISTS experiment_git_info (
    experiment_id TEXT PRIMARY KEY,
    commit_sha TEXT,
    branch TEXT,
    remote_url TEXT,
    is_dirty INTEGER DEFAULT 0,
    diff_patch TEXT,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);

-- 1:1 RNG seeds recorded per experiment for reproducibility.
CREATE TABLE IF NOT EXISTS experiment_seeds (
    experiment_id TEXT PRIMARY KEY,
    numpy_seed INTEGER,
    torch_seed INTEGER,
    tensorflow_seed INTEGER,
    random_seed INTEGER,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);

-- Registered datasets, addressable by unique name; url is the fetchable
-- source location. updated_at is maintained by the
-- update_datasets_timestamp trigger defined below.
CREATE TABLE IF NOT EXISTS datasets (
    name TEXT PRIMARY KEY,
    url TEXT NOT NULL,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

-- Indexes for performance
-- Jobs are filtered by status and listed chronologically.
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
|
||||
|
|
@ -51,6 +104,11 @@ CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id);
|
|||
CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
-- last_heartbeat is scanned to find active workers within a recent window.
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);
-- Experiments are listed chronologically and filtered by status/owner.
CREATE INDEX IF NOT EXISTS idx_experiments_created_at ON experiments(created_at);
CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status);
CREATE INDEX IF NOT EXISTS idx_experiments_user_id ON experiments(user_id);

-- NOTE(review): name is already the PRIMARY KEY of datasets, so this index
-- duplicates the implicit primary-key index -- consider dropping it.
CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name);
|
||||
|
||||
-- Triggers to update timestamps
|
||||
CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp
|
||||
|
|
@ -59,3 +117,17 @@ CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp
|
|||
BEGIN
|
||||
UPDATE jobs SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END;
|
||||
|
||||
-- Keep experiments.updated_at current on every row update. The trigger's
-- own UPDATE does not re-fire it recursively (SQLite recursive triggers
-- are off by default).
CREATE TRIGGER IF NOT EXISTS update_experiments_timestamp
    AFTER UPDATE ON experiments
    FOR EACH ROW
    BEGIN
        UPDATE experiments SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
    END;

-- Keep datasets.updated_at current on every row update; keyed by name,
-- the table's primary key.
CREATE TRIGGER IF NOT EXISTS update_datasets_timestamp
    AFTER UPDATE ON datasets
    FOR EACH ROW
    BEGIN
        UPDATE datasets SET updated_at = CURRENT_TIMESTAMP WHERE name = NEW.name;
    END;
||||
|
|
|
|||
Loading…
Reference in a new issue