fetch_ml/internal/storage/db.go
Jeremie Fraeys 803677be57 feat: implement Go backend with comprehensive API and internal packages
- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
2025-12-04 16:53:53 -05:00

433 lines
12 KiB
Go

package storage
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
// DBConfig holds the settings NewDB uses to open a database.
//
// For SQLite only Type and Connection are read. For PostgreSQL, Connection is
// used verbatim when it is set; otherwise a connection string is assembled
// from Host, Port, Username, Password and Database.
type DBConfig struct {
// Type selects the backend: "sqlite", "postgres" or "postgresql" (matched case-insensitively).
Type string
// Connection is the SQLite data source name, or a complete PostgreSQL connection string.
Connection string
// Host is the PostgreSQL host; "localhost" is used when empty.
Host string
// Port is the PostgreSQL port; 5432 is used when it is not positive.
Port int
// Username is the PostgreSQL user; it is left out of the connection string when empty.
Username string
// Password is the PostgreSQL password; it is left out of the connection string when empty.
Password string
// Database is the PostgreSQL database name; "fetch_ml" is used when empty.
Database string
}
// DB wraps a database/sql handle together with the backend type it was opened for.
type DB struct {
// conn is the underlying *sql.DB handle returned by sql.Open.
conn *sql.DB
// dbType is the lower-cased DBConfig.Type. Methods compare it with "sqlite"
// to choose between "?" and "$n" style placeholders.
dbType string
}
// NewDB opens a database handle for the backend named by config.Type.
//
// Supported types, matched case-insensitively, are "sqlite" and
// "postgres"/"postgresql". For SQLite, config.Connection is handed to the
// driver as the data source name and the foreign_keys and journal_mode=WAL
// pragmas are executed. For PostgreSQL the connection string comes from
// buildPostgresConnectionString.
//
// An error is returned for an unsupported type, when sql.Open rejects the
// arguments, or when a SQLite pragma fails; in the last case the handle is
// closed before returning so it is not leaked.
func NewDB(config DBConfig) (*DB, error) {
	dbType := strings.ToLower(config.Type)

	var (
		conn *sql.DB
		err  error
	)
	switch dbType {
	case "sqlite":
		conn, err = sql.Open("sqlite3", config.Connection)
		if err != nil {
			return nil, fmt.Errorf("failed to open SQLite database: %w", err)
		}
		// NOTE(review): these statements run on whichever pooled connection
		// Exec picks. foreign_keys is documented by SQLite as a
		// per-connection setting, so other connections in the pool may not
		// have it enabled — confirm, and consider driver DSN options instead.
		pragmas := []struct{ stmt, purpose string }{
			{"PRAGMA foreign_keys = ON", "foreign keys"},
			{"PRAGMA journal_mode = WAL", "WAL mode"},
		}
		for _, p := range pragmas {
			if _, err := conn.Exec(p.stmt); err != nil {
				// The pragma error is the one worth reporting; a secondary
				// Close error would only obscure it.
				_ = conn.Close()
				return nil, fmt.Errorf("failed to enable %s: %w", p.purpose, err)
			}
		}
	case "postgres", "postgresql":
		// "postgresql" is accepted as an alias for "postgres".
		conn, err = sql.Open("postgres", buildPostgresConnectionString(config))
		if err != nil {
			return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
		}
	default:
		return nil, fmt.Errorf("unsupported database type: %s", config.Type)
	}
	return &DB{conn: conn, dbType: dbType}, nil
}
// buildPostgresConnectionString returns the PostgreSQL connection string for
// config.
//
// A non-empty config.Connection is returned unchanged. Otherwise a
// keyword/value string is assembled with these defaults: host "localhost",
// port 5432 (when Port is not positive) and dbname "fetch_ml". user and
// password are only emitted when set. Values that contain whitespace, single
// quotes or backslashes are quoted and escaped so they cannot break the
// string apart or introduce additional keywords; plain values are emitted
// exactly as before.
//
// sslmode=disable is always appended; callers that need TLS must supply a
// full config.Connection instead.
func buildPostgresConnectionString(config DBConfig) string {
	if config.Connection != "" {
		return config.Connection
	}

	host := config.Host
	if host == "" {
		host = "localhost"
	}
	port := config.Port
	if port <= 0 {
		port = 5432
	}
	database := config.Database
	if database == "" {
		database = "fetch_ml"
	}

	var b strings.Builder
	fmt.Fprintf(&b, "host=%s port=%d", quotePostgresValue(host), port)
	if config.Username != "" {
		fmt.Fprintf(&b, " user=%s", quotePostgresValue(config.Username))
	}
	if config.Password != "" {
		fmt.Fprintf(&b, " password=%s", quotePostgresValue(config.Password))
	}
	fmt.Fprintf(&b, " dbname=%s", quotePostgresValue(database))
	b.WriteString(" sslmode=disable")
	return b.String()
}

// quotePostgresValue renders v as a keyword/value connection-string value.
// Values without whitespace, single quotes or backslashes are returned as-is;
// anything else (including the empty string) is wrapped in single quotes with
// backslashes and single quotes backslash-escaped.
func quotePostgresValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\r\n\v\f'\\") {
		return v
	}
	escaped := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v)
	return "'" + escaped + "'"
}
// NewDBFromPath opens a SQLite database using dbPath as the data source name.
//
// It is the legacy constructor kept for backward compatibility and simply
// delegates to NewDB with Type "sqlite" and Connection dbPath.
func NewDBFromPath(dbPath string) (*DB, error) {
	sqliteConfig := DBConfig{Type: "sqlite", Connection: dbPath}
	return NewDB(sqliteConfig)
}
// Job mirrors a row of the jobs table and is serialised to JSON using the
// field tags below.
type Job struct {
// ID is the id column; GetJob looks jobs up by it.
ID string `json:"id"`
// JobName is the job_name column.
JobName string `json:"job_name"`
// Args is the args column, stored as an opaque string.
Args string `json:"args"`
// Status is the status column. UpdateJobStatus gives 'running', 'completed'
// and 'failed' special treatment when stamping StartedAt/EndedAt.
Status string `json:"status"`
// Priority is the priority column.
Priority int64 `json:"priority"`
// CreatedAt is read from created_at. CreateJob does not write it
// (presumably filled by a schema default — confirm against the schema).
CreatedAt time.Time `json:"created_at"`
// StartedAt is set by UpdateJobStatus the first time the status becomes
// 'running'; it is omitted from JSON while nil.
StartedAt *time.Time `json:"started_at,omitempty"`
// EndedAt is set by UpdateJobStatus the first time the status becomes
// 'completed' or 'failed'; it is omitted from JSON while nil.
EndedAt *time.Time `json:"ended_at,omitempty"`
// WorkerID is the worker_id column; a NULL column is read back as "".
WorkerID string `json:"worker_id,omitempty"`
// Error is the error column; a NULL column is read back as "".
Error string `json:"error,omitempty"`
// Datasets is persisted JSON-encoded in the datasets column.
Datasets []string `json:"datasets,omitempty"`
// Metadata is persisted JSON-encoded in the metadata column.
Metadata map[string]string `json:"metadata,omitempty"`
// UpdatedAt is read from updated_at. No statement in this file writes it
// (presumably maintained by the schema — confirm).
UpdatedAt time.Time `json:"updated_at"`
}
// Worker mirrors a row of the workers table and is serialised to JSON using
// the field tags below.
type Worker struct {
// ID is the id column; RegisterWorker upserts on it.
ID string `json:"id"`
// Hostname is the hostname column.
Hostname string `json:"hostname"`
// LastHeartbeat is the last_heartbeat column, refreshed by
// UpdateWorkerHeartbeat. GetActiveWorkers only returns workers whose
// heartbeat is less than 30 seconds old.
LastHeartbeat time.Time `json:"last_heartbeat"`
// Status is the status column; GetActiveWorkers filters on 'active'.
Status string `json:"status"`
// CurrentJobs is the current_jobs column.
CurrentJobs int `json:"current_jobs"`
// MaxJobs is the max_jobs column.
MaxJobs int `json:"max_jobs"`
// Metadata is persisted JSON-encoded in the metadata column.
Metadata map[string]string `json:"metadata,omitempty"`
}
// Initialize executes the SQL in schema against the database.
// Any execution error is wrapped and returned.
func (db *DB) Initialize(schema string) error {
	_, execErr := db.conn.Exec(schema)
	if execErr == nil {
		return nil
	}
	return fmt.Errorf("failed to initialize database: %w", execErr)
}
// Close closes the underlying database handle and reports the error, if any,
// returned by doing so.
func (db *DB) Close() error {
	closeErr := db.conn.Close()
	return closeErr
}
// Job operations

// CreateJob inserts job into the jobs table.
//
// Only id, job_name, args, status, priority, datasets and metadata are
// written; Datasets and Metadata are stored JSON-encoded. The remaining
// columns are left to the database.
//
// It returns an error when job is nil, when either field cannot be encoded,
// or when the INSERT fails.
func (db *DB) CreateJob(job *Job) error {
	if job == nil {
		return fmt.Errorf("failed to create job: job is nil")
	}

	datasetsJSON, err := json.Marshal(job.Datasets)
	if err != nil {
		return fmt.Errorf("failed to create job: encoding datasets: %w", err)
	}
	metadataJSON, err := json.Marshal(job.Metadata)
	if err != nil {
		return fmt.Errorf("failed to create job: encoding metadata: %w", err)
	}

	// SQLite uses "?" placeholders, every other backend "$n".
	var query string
	if db.dbType == "sqlite" {
		query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)`
	} else {
		query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)`
	}

	if _, err := db.conn.Exec(query, job.ID, job.JobName, job.Args, job.Status,
		job.Priority, string(datasetsJSON), string(metadataJSON)); err != nil {
		return fmt.Errorf("failed to create job: %w", err)
	}
	return nil
}
// GetJob loads the job with the given id.
//
// NULL worker_id and error columns are returned as empty strings. The
// datasets and metadata columns are decoded from JSON; a NULL or empty column
// leaves the corresponding field nil.
//
// The returned error wraps the underlying one, so a missing row can be
// detected with errors.Is(err, sql.ErrNoRows). Malformed JSON in datasets or
// metadata is reported as an error instead of being silently dropped.
func (db *DB) GetJob(id string) (*Job, error) {
	// SQLite uses "?" placeholders, every other backend "$n".
	var query string
	if db.dbType == "sqlite" {
		query = `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs WHERE id = ?`
	} else {
		query = `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs WHERE id = $1`
	}

	var (
		job                        Job
		workerID, errorMsg         sql.NullString
		datasetsJSON, metadataJSON sql.NullString
	)
	err := db.conn.QueryRow(query, id).Scan(
		&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
		&job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID,
		&errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt)
	if err != nil {
		return nil, fmt.Errorf("failed to get job: %w", err)
	}

	// NullString.String is "" when the column was NULL.
	job.WorkerID = workerID.String
	job.Error = errorMsg.String

	if datasetsJSON.String != "" {
		if err := json.Unmarshal([]byte(datasetsJSON.String), &job.Datasets); err != nil {
			return nil, fmt.Errorf("failed to get job: decoding datasets: %w", err)
		}
	}
	if metadataJSON.String != "" {
		if err := json.Unmarshal([]byte(metadataJSON.String), &job.Metadata); err != nil {
			return nil, fmt.Errorf("failed to get job: decoding metadata: %w", err)
		}
	}
	return &job, nil
}
// UpdateJobStatus writes status, workerID and errorMsg to the job identified
// by id.
//
// started_at is stamped with the current timestamp when status is 'running'
// and the column is still NULL; ended_at is stamped likewise when status is
// 'completed' or 'failed'. The number of affected rows is not inspected, so an
// unknown id does not produce an error.
func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error {
	var stmt string
	switch db.dbType {
	case "sqlite":
		stmt = `UPDATE jobs SET status = ?, worker_id = ?, error = ?,
started_at = CASE WHEN ? = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = ?`
	default:
		stmt = `UPDATE jobs SET status = $1, worker_id = $2, error = $3,
started_at = CASE WHEN $4 = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = $6`
	}

	// status is bound three times: once for the column and once for each CASE.
	if _, execErr := db.conn.Exec(stmt, status, workerID, errorMsg, status, status, id); execErr != nil {
		return fmt.Errorf("failed to update job status: %w", execErr)
	}
	return nil
}
// ListJobs returns jobs ordered by created_at, newest first.
//
// When status is non-empty only jobs with that status are returned. When
// limit is positive at most limit jobs are returned; otherwise the result is
// unbounded. NULL worker_id and error columns become empty strings, and NULL
// or empty datasets/metadata columns leave the corresponding field nil.
//
// It returns an error when the query fails, a row cannot be scanned or
// decoded, or iteration stops early because of a driver error.
func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
	// placeholder renders the n-th (1-based) bind parameter for the backend:
	// "?" for SQLite, "$n" otherwise.
	placeholder := func(n int) string {
		if db.dbType == "sqlite" {
			return "?"
		}
		return fmt.Sprintf("$%d", n)
	}

	query := `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs`
	var args []interface{}
	if status != "" {
		args = append(args, status)
		query += " WHERE status = " + placeholder(len(args))
	}
	query += " ORDER BY created_at DESC"
	if limit > 0 {
		args = append(args, limit)
		query += " LIMIT " + placeholder(len(args))
	}

	rows, err := db.conn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to list jobs: %w", err)
	}
	defer rows.Close()

	var jobs []*Job
	for rows.Next() {
		var (
			job                        Job
			workerID, errorMsg         sql.NullString
			datasetsJSON, metadataJSON sql.NullString
		)
		if err := rows.Scan(&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID,
			&errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan job: %w", err)
		}

		// NullString.String is "" when the column was NULL.
		job.WorkerID = workerID.String
		job.Error = errorMsg.String

		if datasetsJSON.String != "" {
			if err := json.Unmarshal([]byte(datasetsJSON.String), &job.Datasets); err != nil {
				return nil, fmt.Errorf("failed to scan job: decoding datasets: %w", err)
			}
		}
		if metadataJSON.String != "" {
			if err := json.Unmarshal([]byte(metadataJSON.String), &job.Metadata); err != nil {
				return nil, fmt.Errorf("failed to scan job: decoding metadata: %w", err)
			}
		}
		jobs = append(jobs, &job)
	}
	// rows.Next returning false can mean "done" or "failed"; without this
	// check a mid-stream error would yield a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to list jobs: %w", err)
	}
	return jobs, nil
}
// Worker operations

// RegisterWorker inserts worker into the workers table, or overwrites the
// existing row with the same id.
//
// The columns written are id, hostname, status, current_jobs, max_jobs and
// metadata (JSON-encoded). SQLite uses INSERT OR REPLACE, while other backends
// update the listed columns in place via ON CONFLICT (id) DO UPDATE.
//
// NOTE(review): SQLite's REPLACE rewrites the whole row, so columns not listed
// here (e.g. last_heartbeat) are presumably reset to their defaults on SQLite
// but preserved on PostgreSQL — confirm against the schema.
//
// It returns an error when worker is nil, when the metadata cannot be
// encoded, or when the statement fails.
func (db *DB) RegisterWorker(worker *Worker) error {
	if worker == nil {
		return fmt.Errorf("failed to register worker: worker is nil")
	}

	metadataJSON, err := json.Marshal(worker.Metadata)
	if err != nil {
		return fmt.Errorf("failed to register worker: encoding metadata: %w", err)
	}

	var query string
	if db.dbType == "sqlite" {
		query = `INSERT OR REPLACE INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
VALUES (?, ?, ?, ?, ?, ?)`
	} else {
		query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (id) DO UPDATE SET
hostname = EXCLUDED.hostname,
status = EXCLUDED.status,
current_jobs = EXCLUDED.current_jobs,
max_jobs = EXCLUDED.max_jobs,
metadata = EXCLUDED.metadata`
	}

	if _, err := db.conn.Exec(query, worker.ID, worker.Hostname, worker.Status,
		worker.CurrentJobs, worker.MaxJobs, string(metadataJSON)); err != nil {
		return fmt.Errorf("failed to register worker: %w", err)
	}
	return nil
}
// UpdateWorkerHeartbeat sets last_heartbeat to the database's current
// timestamp for the worker identified by workerID. The number of affected rows
// is not inspected, so an unknown id does not produce an error.
func (db *DB) UpdateWorkerHeartbeat(workerID string) error {
	var stmt string
	switch db.dbType {
	case "sqlite":
		stmt = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = ?`
	default:
		stmt = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = $1`
	}

	if _, execErr := db.conn.Exec(stmt, workerID); execErr != nil {
		return fmt.Errorf("failed to update worker heartbeat: %w", execErr)
	}
	return nil
}
// GetActiveWorkers returns the workers whose status is 'active' and whose
// last_heartbeat is less than 30 seconds old according to the database clock.
//
// The metadata column is decoded from JSON; a NULL or empty column leaves
// Metadata nil.
//
// It returns an error when the query fails, a row cannot be scanned or
// decoded, or iteration stops early because of a driver error.
func (db *DB) GetActiveWorkers() ([]*Worker, error) {
	// The two backends spell "30 seconds ago" differently.
	var query string
	if db.dbType == "sqlite" {
		query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')`
	} else {
		query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
	}

	rows, err := db.conn.Query(query)
	if err != nil {
		return nil, fmt.Errorf("failed to get active workers: %w", err)
	}
	defer rows.Close()

	var workers []*Worker
	for rows.Next() {
		var (
			worker       Worker
			metadataJSON sql.NullString
		)
		if err := rows.Scan(&worker.ID, &worker.Hostname, &worker.LastHeartbeat,
			&worker.Status, &worker.CurrentJobs, &worker.MaxJobs, &metadataJSON); err != nil {
			return nil, fmt.Errorf("failed to scan worker: %w", err)
		}
		if metadataJSON.String != "" {
			if err := json.Unmarshal([]byte(metadataJSON.String), &worker.Metadata); err != nil {
				return nil, fmt.Errorf("failed to scan worker: decoding metadata: %w", err)
			}
		}
		workers = append(workers, &worker)
	}
	// rows.Next returning false can mean "done" or "failed"; without this
	// check a mid-stream error would yield a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to get active workers: %w", err)
	}
	return workers, nil
}
// Metrics operations

// RecordJobMetric appends a (metricName, metricValue) row for jobID to the
// job_metrics table. Existing rows are not modified.
func (db *DB) RecordJobMetric(jobID, metricName, metricValue string) error {
	var stmt string
	switch db.dbType {
	case "sqlite":
		stmt = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES (?, ?, ?)`
	default:
		stmt = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES ($1, $2, $3)`
	}

	if _, execErr := db.conn.Exec(stmt, jobID, metricName, metricValue); execErr != nil {
		return fmt.Errorf("failed to record job metric: %w", execErr)
	}
	return nil
}
// RecordSystemMetric appends a (metricName, metricValue) row to the
// system_metrics table. Existing rows are not modified.
func (db *DB) RecordSystemMetric(metricName, metricValue string) error {
	var stmt string
	switch db.dbType {
	case "sqlite":
		stmt = `INSERT INTO system_metrics (metric_name, metric_value) VALUES (?, ?)`
	default:
		stmt = `INSERT INTO system_metrics (metric_name, metric_value) VALUES ($1, $2)`
	}

	if _, execErr := db.conn.Exec(stmt, metricName, metricValue); execErr != nil {
		return fmt.Errorf("failed to record system metric: %w", execErr)
	}
	return nil
}
// GetJobMetrics returns the most recently recorded value of every metric
// stored for jobID, keyed by metric name.
//
// Rows are read newest first and only the first value seen for each name is
// kept, so older samples of the same metric never overwrite a newer one. A job
// without metrics yields an empty, non-nil map.
//
// It returns an error when the query fails, a row cannot be scanned, or
// iteration stops early because of a driver error.
func (db *DB) GetJobMetrics(jobID string) (map[string]string, error) {
	// SQLite uses "?" placeholders, every other backend "$n".
	var query string
	if db.dbType == "sqlite" {
		query = `SELECT metric_name, metric_value FROM job_metrics
WHERE job_id = ? ORDER BY timestamp DESC`
	} else {
		query = `SELECT metric_name, metric_value FROM job_metrics
WHERE job_id = $1 ORDER BY timestamp DESC`
	}

	rows, err := db.conn.Query(query, jobID)
	if err != nil {
		return nil, fmt.Errorf("failed to get job metrics: %w", err)
	}
	defer rows.Close()

	metrics := make(map[string]string)
	for rows.Next() {
		var name, value string
		if err := rows.Scan(&name, &value); err != nil {
			return nil, fmt.Errorf("failed to scan metric: %w", err)
		}
		// Because of ORDER BY timestamp DESC the first row per name is the
		// latest sample; assigning unconditionally would keep the oldest.
		if _, seen := metrics[name]; !seen {
			metrics[name] = value
		}
	}
	// rows.Next returning false can mean "done" or "failed"; without this
	// check a mid-stream error would yield a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to get job metrics: %w", err)
	}
	return metrics, nil
}