package storage import ( "database/sql" "encoding/json" "fmt" "strings" "time" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) type DBConfig struct { Type string Connection string Host string Port int Username string Password string Database string } type DB struct { conn *sql.DB dbType string } func NewDB(config DBConfig) (*DB, error) { var conn *sql.DB var err error switch strings.ToLower(config.Type) { case "sqlite": conn, err = sql.Open("sqlite3", config.Connection) if err != nil { return nil, fmt.Errorf("failed to open SQLite database: %w", err) } // Enable foreign keys if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil { return nil, fmt.Errorf("failed to enable foreign keys: %w", err) } // Enable WAL mode for better concurrency if _, err := conn.Exec("PRAGMA journal_mode = WAL"); err != nil { return nil, fmt.Errorf("failed to enable WAL mode: %w", err) } case "postgres": connStr := buildPostgresConnectionString(config) conn, err = sql.Open("postgres", connStr) if err != nil { return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) } case "postgresql": // Handle "postgresql" as alias for "postgres" connStr := buildPostgresConnectionString(config) conn, err = sql.Open("postgres", connStr) if err != nil { return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) } default: return nil, fmt.Errorf("unsupported database type: %s", config.Type) } return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil } func buildPostgresConnectionString(config DBConfig) string { if config.Connection != "" { return config.Connection } var connStr strings.Builder connStr.WriteString("host=") if config.Host != "" { connStr.WriteString(config.Host) } else { connStr.WriteString("localhost") } if config.Port > 0 { connStr.WriteString(fmt.Sprintf(" port=%d", config.Port)) } else { connStr.WriteString(" port=5432") } if config.Username != "" { connStr.WriteString(fmt.Sprintf(" user=%s", config.Username)) } if config.Password != "" { connStr.WriteString(fmt.Sprintf(" password=%s", config.Password)) } if config.Database != "" { connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database)) } else { connStr.WriteString(" dbname=fetch_ml") } connStr.WriteString(" sslmode=disable") return connStr.String() } // Legacy constructor for backward compatibility func NewDBFromPath(dbPath string) (*DB, error) { return NewDB(DBConfig{ Type: "sqlite", Connection: dbPath, }) } type Job struct { ID string `json:"id"` JobName string `json:"job_name"` Args string `json:"args"` Status string `json:"status"` Priority int64 `json:"priority"` CreatedAt time.Time `json:"created_at"` StartedAt *time.Time `json:"started_at,omitempty"` EndedAt *time.Time `json:"ended_at,omitempty"` WorkerID string `json:"worker_id,omitempty"` Error string `json:"error,omitempty"` Datasets []string `json:"datasets,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` UpdatedAt time.Time `json:"updated_at"` } type Worker struct { ID string `json:"id"` Hostname string `json:"hostname"` LastHeartbeat time.Time `json:"last_heartbeat"` Status string `json:"status"` CurrentJobs int `json:"current_jobs"` MaxJobs int `json:"max_jobs"` Metadata map[string]string `json:"metadata,omitempty"` } func (db *DB) Initialize(schema string) error { if _, err := db.conn.Exec(schema); err != nil { return fmt.Errorf("failed to initialize database: %w", err) } return nil } func (db *DB) Close() error { return db.conn.Close() } // Job operations func (db *DB) CreateJob(job *Job) error { datasetsJSON, _ := json.Marshal(job.Datasets) metadataJSON, _ := json.Marshal(job.Metadata) var query string if db.dbType == "sqlite" { query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)` } else { query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7)` } _, err := db.conn.Exec(query, job.ID, job.JobName, job.Args, job.Status, job.Priority, string(datasetsJSON), string(metadataJSON)) if err != nil { return fmt.Errorf("failed to create job: %w", err) } return nil } func (db *DB) GetJob(id string) (*Job, error) { var query string if db.dbType == "sqlite" { query = `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs WHERE id = ?` } else { query = `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs WHERE id = $1` } var job Job var datasetsJSON, metadataJSON string var workerID sql.NullString var errorMsg sql.NullString err := db.conn.QueryRow(query, id).Scan( &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, &job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID, &errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt) if err != nil { return nil, fmt.Errorf("failed to get job: %w", err) } if workerID.Valid { job.WorkerID = workerID.String } if errorMsg.Valid { job.Error = errorMsg.String } json.Unmarshal([]byte(datasetsJSON), &job.Datasets) json.Unmarshal([]byte(metadataJSON), &job.Metadata) return &job, nil } func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error { var query string if db.dbType == "sqlite" { query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?, started_at = CASE WHEN ? = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END, ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END WHERE id = ?` } else { query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3, started_at = CASE WHEN $4 = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END, ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END WHERE id = $6` } _, err := db.conn.Exec(query, status, workerID, errorMsg, status, status, id) if err != nil { return fmt.Errorf("failed to update job status: %w", err) } return nil } func (db *DB) ListJobs(status string, limit int) ([]*Job, error) { var query string if db.dbType == "sqlite" { query = `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs` } else { query = `SELECT id, job_name, args, status, priority, created_at, started_at, ended_at, worker_id, error, datasets, metadata, updated_at FROM jobs` } var args []interface{} if status != "" { if db.dbType == "sqlite" { query += " WHERE status = ?" } else { query += " WHERE status = $1" } args = append(args, status) } query += " ORDER BY created_at DESC" if limit > 0 { if db.dbType == "sqlite" { query += " LIMIT ?" } else { query += fmt.Sprintf(" LIMIT $%d", len(args)+1) } args = append(args, limit) } rows, err := db.conn.Query(query, args...) if err != nil { return nil, fmt.Errorf("failed to list jobs: %w", err) } defer rows.Close() var jobs []*Job for rows.Next() { var job Job var datasetsJSON, metadataJSON string var workerID sql.NullString var errorMsg sql.NullString err := rows.Scan(&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, &job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID, &errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt) if err != nil { return nil, fmt.Errorf("failed to scan job: %w", err) } if workerID.Valid { job.WorkerID = workerID.String } if errorMsg.Valid { job.Error = errorMsg.String } json.Unmarshal([]byte(datasetsJSON), &job.Datasets) json.Unmarshal([]byte(metadataJSON), &job.Metadata) jobs = append(jobs, &job) } return jobs, nil } // Worker operations func (db *DB) RegisterWorker(worker *Worker) error { metadataJSON, _ := json.Marshal(worker.Metadata) var query string if db.dbType == "sqlite" { query = `INSERT OR REPLACE INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) VALUES (?, ?, ?, ?, ?, ?)` } else { query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (id) DO UPDATE SET hostname = EXCLUDED.hostname, status = EXCLUDED.status, current_jobs = EXCLUDED.current_jobs, max_jobs = EXCLUDED.max_jobs, metadata = EXCLUDED.metadata` } _, err := db.conn.Exec(query, worker.ID, worker.Hostname, worker.Status, worker.CurrentJobs, worker.MaxJobs, string(metadataJSON)) if err != nil { return fmt.Errorf("failed to register worker: %w", err) } return nil } func (db *DB) UpdateWorkerHeartbeat(workerID string) error { var query string if db.dbType == "sqlite" { query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = ?` } else { query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = $1` } _, err := db.conn.Exec(query, workerID) if err != nil { return fmt.Errorf("failed to update worker heartbeat: %w", err) } return nil } func (db *DB) GetActiveWorkers() ([]*Worker, error) { var query string if db.dbType == "sqlite" { query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')` } else { query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'` } rows, err := db.conn.Query(query) if err != nil { return nil, fmt.Errorf("failed to get active workers: %w", err) } defer rows.Close() var workers []*Worker for rows.Next() { var worker Worker var metadataJSON string err := rows.Scan(&worker.ID, &worker.Hostname, &worker.LastHeartbeat, &worker.Status, &worker.CurrentJobs, &worker.MaxJobs, &metadataJSON) if err != nil { return nil, fmt.Errorf("failed to scan worker: %w", err) } json.Unmarshal([]byte(metadataJSON), &worker.Metadata) workers = append(workers, &worker) } return workers, nil } // Metrics operations func (db *DB) RecordJobMetric(jobID, metricName, metricValue string) error { var query string if db.dbType == "sqlite" { query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES (?, ?, ?)` } else { query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES ($1, $2, $3)` } _, err := db.conn.Exec(query, jobID, metricName, metricValue) if err != nil { return fmt.Errorf("failed to record job metric: %w", err) } return nil } func (db *DB) RecordSystemMetric(metricName, metricValue string) error { var query string if db.dbType == "sqlite" { query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES (?, ?)` } else { query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES ($1, $2)` } _, err := db.conn.Exec(query, metricName, metricValue) if err != nil { return fmt.Errorf("failed to record system metric: %w", err) } return nil } func (db *DB) GetJobMetrics(jobID string) (map[string]string, error) { var query string if db.dbType == "sqlite" { query = `SELECT metric_name, metric_value FROM job_metrics WHERE job_id = ? ORDER BY timestamp DESC` } else { query = `SELECT metric_name, metric_value FROM job_metrics WHERE job_id = $1 ORDER BY timestamp DESC` } rows, err := db.conn.Query(query, jobID) if err != nil { return nil, fmt.Errorf("failed to get job metrics: %w", err) } defer rows.Close() metrics := make(map[string]string) for rows.Next() { var name, value string if err := rows.Scan(&name, &value); err != nil { return nil, fmt.Errorf("failed to scan metric: %w", err) } metrics[name] = value } return metrics, nil }