package storage import ( "context" "database/sql" "encoding/json" "fmt" "time" ) // Job represents a machine learning job in the system. type Job struct { ID string `json:"id"` JobName string `json:"job_name"` Args string `json:"args"` Status string `json:"status"` Priority int64 `json:"priority"` CreatedAt time.Time `json:"created_at"` StartedAt *time.Time `json:"started_at,omitempty"` EndedAt *time.Time `json:"ended_at,omitempty"` WorkerID string `json:"worker_id,omitempty"` Error string `json:"error,omitempty"` Datasets []string `json:"datasets,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` UpdatedAt time.Time `json:"updated_at"` } // Worker represents a worker node in the system. type Worker struct { ID string `json:"id"` Hostname string `json:"hostname"` LastHeartbeat time.Time `json:"last_heartbeat"` Status string `json:"status"` CurrentJobs int `json:"current_jobs"` MaxJobs int `json:"max_jobs"` Metadata map[string]string `json:"metadata,omitempty"` } // CreateJob inserts a new job into the database. func (db *DB) CreateJob(job *Job) error { datasetsJSON, _ := json.Marshal(job.Datasets) metadataJSON, _ := json.Marshal(job.Metadata) var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)` } else { query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7)` } _, err := db.conn.ExecContext( context.Background(), query, job.ID, job.JobName, job.Args, job.Status, job.Priority, string(datasetsJSON), string(metadataJSON), ) if err != nil { return fmt.Errorf("failed to create job: %w", err) } return nil } // GetJob retrieves a job by ID. func (db *DB) GetJob(id string) (*Job, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs WHERE id = ?` } else { query = `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs WHERE id = $1` } var job Job var datasetsJSON, metadataJSON string var workerID sql.NullString var errorMsg sql.NullString err := db.conn.QueryRowContext(context.Background(), query, id).Scan( &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, &job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID, &errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt) if err != nil { return nil, fmt.Errorf("failed to get job: %w", err) } if workerID.Valid { job.WorkerID = workerID.String } if errorMsg.Valid { job.Error = errorMsg.String } _ = json.Unmarshal([]byte(datasetsJSON), &job.Datasets) _ = json.Unmarshal([]byte(metadataJSON), &job.Metadata) return &job, nil } // UpdateJobStatus updates the status of a job with optional worker and error info. func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error { var query string if db.dbType == DBTypeSQLite { query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?, started_at = CASE WHEN ? = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END, ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END WHERE id = ?` } else { query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3, started_at = CASE WHEN $4 = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END, ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END WHERE id = $6` } _, err := db.conn.ExecContext( context.Background(), query, status, workerID, errorMsg, status, status, id, ) if err != nil { return fmt.Errorf("failed to update job status: %w", err) } return nil } // ListJobs retrieves jobs with optional status filter and limit. func (db *DB) ListJobs(status string, limit int) ([]*Job, error) { query := `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs` var args []interface{} if status != "" { if db.dbType == DBTypeSQLite { query += " WHERE status = ?" } else { query += " WHERE status = $1" } args = append(args, status) } query += " ORDER BY created_at DESC" if limit > 0 { if db.dbType == DBTypeSQLite { query += " LIMIT ?" } else { query += fmt.Sprintf(" LIMIT $%d", len(args)+1) } args = append(args, limit) } rows, err := db.conn.QueryContext(context.Background(), query, args...) if err != nil { return nil, fmt.Errorf("failed to list jobs: %w", err) } defer func() { _ = rows.Close() }() var jobs []*Job for rows.Next() { var job Job var datasetsJSON, metadataJSON string var workerID sql.NullString var errorMsg sql.NullString err := rows.Scan(&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, &job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID, &errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt) if err != nil { return nil, fmt.Errorf("failed to scan job: %w", err) } if workerID.Valid { job.WorkerID = workerID.String } if errorMsg.Valid { job.Error = errorMsg.String } _ = json.Unmarshal([]byte(datasetsJSON), &job.Datasets) _ = json.Unmarshal([]byte(metadataJSON), &job.Metadata) jobs = append(jobs, &job) } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error iterating jobs: %w", err) } return jobs, nil } // DeleteJob removes a job from the database by ID. func (db *DB) DeleteJob(id string) error { var query string if db.dbType == DBTypeSQLite { query = `DELETE FROM jobs WHERE id = ?` } else { query = `DELETE FROM jobs WHERE id = $1` } _, err := db.conn.ExecContext(context.Background(), query, id) if err != nil { return fmt.Errorf("failed to delete job: %w", err) } return nil } // DeleteJobsByPrefix removes all jobs with IDs matching the given prefix. func (db *DB) DeleteJobsByPrefix(prefix string) error { var query string if db.dbType == DBTypeSQLite { query = `DELETE FROM jobs WHERE id LIKE ?` } else { query = `DELETE FROM jobs WHERE id LIKE $1` } _, err := db.conn.ExecContext(context.Background(), query, prefix+"%") if err != nil { return fmt.Errorf("failed to delete jobs by prefix: %w", err) } return nil } // RegisterWorker registers or updates a worker in the database. func (db *DB) RegisterWorker(worker *Worker) error { metadataJSON, _ := json.Marshal(worker.Metadata) var query string if db.dbType == DBTypeSQLite { query = `INSERT OR REPLACE INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) VALUES (?, ?, ?, ?, ?, ?)` } else { query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (id) DO UPDATE SET hostname = EXCLUDED.hostname, status = EXCLUDED.status, current_jobs = EXCLUDED.current_jobs, max_jobs = EXCLUDED.max_jobs, metadata = EXCLUDED.metadata` } _, err := db.conn.ExecContext( context.Background(), query, worker.ID, worker.Hostname, worker.Status, worker.CurrentJobs, worker.MaxJobs, string(metadataJSON), ) if err != nil { return fmt.Errorf("failed to register worker: %w", err) } return nil } // UpdateWorkerHeartbeat updates the last heartbeat timestamp for a worker. func (db *DB) UpdateWorkerHeartbeat(workerID string) error { var query string if db.dbType == DBTypeSQLite { query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = ?` } else { query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = $1` } _, err := db.conn.ExecContext(context.Background(), query, workerID) if err != nil { return fmt.Errorf("failed to update worker heartbeat: %w", err) } return nil } // GetActiveWorkers retrieves all currently active workers. func (db *DB) GetActiveWorkers() ([]*Worker, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')` } else { query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'` } rows, err := db.conn.QueryContext(context.Background(), query) if err != nil { return nil, fmt.Errorf("failed to get active workers: %w", err) } defer func() { _ = rows.Close() }() var workers []*Worker for rows.Next() { var worker Worker var metadataJSON string err := rows.Scan(&worker.ID, &worker.Hostname, &worker.LastHeartbeat, &worker.Status, &worker.CurrentJobs, &worker.MaxJobs, &metadataJSON) if err != nil { return nil, fmt.Errorf("failed to scan worker: %w", err) } _ = json.Unmarshal([]byte(metadataJSON), &worker.Metadata) workers = append(workers, &worker) } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error iterating workers: %w", err) } return workers, nil } // RecordJobMetric records a metric for a specific job. func (db *DB) RecordJobMetric(jobID, metricName, metricValue string) error { var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES (?, ?, ?)` } else { query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES ($1, $2, $3)` } _, err := db.conn.ExecContext(context.Background(), query, jobID, metricName, metricValue) if err != nil { return fmt.Errorf("failed to record job metric: %w", err) } return nil } // RecordSystemMetric records a system-wide metric. func (db *DB) RecordSystemMetric(metricName, metricValue string) error { var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES (?, ?)` } else { query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES ($1, $2)` } _, err := db.conn.ExecContext(context.Background(), query, metricName, metricValue) if err != nil { return fmt.Errorf("failed to record system metric: %w", err) } return nil } // GetJobMetrics retrieves all metrics for a specific job. func (db *DB) GetJobMetrics(jobID string) (map[string]string, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT metric_name, metric_value FROM job_metrics WHERE job_id = ? ORDER BY timestamp DESC` } else { query = `SELECT metric_name, metric_value FROM job_metrics WHERE job_id = $1 ORDER BY timestamp DESC` } rows, err := db.conn.QueryContext(context.Background(), query, jobID) if err != nil { return nil, fmt.Errorf("failed to get job metrics: %w", err) } defer func() { _ = rows.Close() }() metrics := make(map[string]string) for rows.Next() { var name, value string if err := rows.Scan(&name, &value); err != nil { return nil, fmt.Errorf("failed to scan metric: %w", err) } metrics[name] = value } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error iterating job metrics: %w", err) } return metrics, nil }