fetch_ml/internal/queue/sqlite_queue.go

762 lines
17 KiB
Go

package queue
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
_ "github.com/mattn/go-sqlite3"
)
// SQLiteQueue is a task queue whose tasks, queue entries, worker prewarm
// state, key/value entries and worker heartbeats are stored in one SQLite
// database.
//
// ctx scopes every query issued through db as well as the background
// goroutines started by NewSQLiteQueue; Close cancels it before closing db.
type SQLiteQueue struct {
	// db is the database handle; NewSQLiteQueue caps it at one open connection.
	db *sql.DB
	// ctx is passed to every query and watched by the background goroutines.
	ctx context.Context
	// cancel cancels ctx; invoked by Close.
	cancel context.CancelFunc
}
// NewSQLiteQueue opens (creating if necessary) the SQLite database at path,
// ensures the queue schema exists, and starts the background lease reclaimer
// and key/value janitor goroutines. The returned queue must be released with
// Close. An empty path is rejected.
func NewSQLiteQueue(path string) (*SQLiteQueue, error) {
	if path == "" {
		return nil, fmt.Errorf("sqlite queue path is required")
	}

	dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_foreign_keys=on", path)
	db, openErr := sql.Open("sqlite3", dsn)
	if openErr != nil {
		return nil, openErr
	}
	// A single pooled connection serializes all database access from this queue.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)

	ctx, cancel := context.WithCancel(context.Background())
	queue := &SQLiteQueue{db: db, ctx: ctx, cancel: cancel}

	if schemaErr := queue.initSchema(); schemaErr != nil {
		_ = db.Close()
		cancel()
		return nil, schemaErr
	}

	go queue.leaseReclaimer()
	go queue.kvJanitor()
	return queue, nil
}
// initSchema switches the database to WAL journaling with NORMAL sync and
// creates every table and index the queue relies on, if absent. Statements
// run in order and the first failure aborts initialization with a wrapped
// error.
func (q *SQLiteQueue) initSchema() error {
	pragmas := []string{
		"PRAGMA journal_mode=WAL;",
		"PRAGMA synchronous=NORMAL;",
	}
	ddl := []string{
		`CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
job_name TEXT,
status TEXT,
priority INTEGER,
created_at INTEGER,
updated_at INTEGER,
payload BLOB
);`,
		"CREATE INDEX IF NOT EXISTS idx_tasks_job_name ON tasks(job_name);",
		"CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);",
		`CREATE TABLE IF NOT EXISTS queue (
task_id TEXT PRIMARY KEY,
priority INTEGER,
available_at INTEGER,
created_at INTEGER
);`,
		"CREATE INDEX IF NOT EXISTS idx_queue_available ON queue(available_at, priority DESC, created_at);",
		`CREATE TABLE IF NOT EXISTS worker_prewarm (
worker_id TEXT PRIMARY KEY,
payload BLOB,
updated_at INTEGER
);`,
		"CREATE INDEX IF NOT EXISTS idx_prewarm_updated ON worker_prewarm(updated_at);",
		`CREATE TABLE IF NOT EXISTS kv (
key TEXT PRIMARY KEY,
value TEXT,
expires_at INTEGER
);`,
		"CREATE INDEX IF NOT EXISTS idx_kv_expires ON kv(expires_at);",
		`CREATE TABLE IF NOT EXISTS worker_heartbeat (
worker_id TEXT PRIMARY KEY,
last_seen INTEGER
);`,
	}

	// Pragmas first, then DDL: same order as a single combined list.
	all := append(pragmas, ddl...)
	for i := 0; i < len(all); i++ {
		if _, err := q.db.ExecContext(q.ctx, all[i]); err != nil {
			return fmt.Errorf("sqlite schema init failed: %w", err)
		}
	}
	return nil
}
// Close cancels the queue context, which signals the background goroutines
// to stop, and then closes the database handle, returning db.Close's error.
// NOTE(review): the goroutines are signalled but not waited for.
func (q *SQLiteQueue) Close() error {
	stop, db := q.cancel, q.db
	stop()
	return db.Close()
}
// AddTask inserts task, or overwrites an existing task with the same ID, and
// synchronizes its queue entry in the same serializable transaction. A task
// whose Status is "queued" is (re)placed in the queue, eligible at NextRetry
// when set or immediately otherwise. Any other status removes it from the
// queue.
//
// Defaults are written back onto task itself: MaxRetries becomes
// defaultMaxRetries when zero, and CreatedAt becomes now when unset. On
// success the TasksQueued counter is incremented and the queue-depth gauge is
// refreshed (best effort).
func (q *SQLiteQueue) AddTask(task *Task) error {
	switch {
	case task == nil:
		return fmt.Errorf("task is nil")
	case task.ID == "":
		return fmt.Errorf("task id is required")
	case task.JobName == "":
		return fmt.Errorf("job name is required")
	}
	if task.MaxRetries == 0 {
		task.MaxRetries = defaultMaxRetries
	}
	now := time.Now().UTC()
	if task.CreatedAt.IsZero() {
		task.CreatedAt = now
	}

	// Marshal only after defaults are applied so the stored payload carries them.
	payload, err := json.Marshal(task)
	if err != nil {
		return err
	}

	nowNs := now.UnixNano()
	createdNs := task.CreatedAt.UnixNano()
	readyNs := nowNs
	if task.NextRetry != nil {
		readyNs = task.NextRetry.UTC().UnixNano()
	}

	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err = tx.ExecContext(
		q.ctx,
		`INSERT INTO tasks(id, job_name, status, priority, created_at, updated_at, payload)
VALUES(?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
job_name=excluded.job_name,
status=excluded.status,
priority=excluded.priority,
created_at=excluded.created_at,
updated_at=excluded.updated_at,
payload=excluded.payload`,
		task.ID, task.JobName, task.Status, task.Priority, createdNs, nowNs, payload,
	); err != nil {
		return err
	}

	if task.Status == "queued" {
		_, err = tx.ExecContext(
			q.ctx,
			`INSERT INTO queue(task_id, priority, available_at, created_at)
VALUES(?, ?, ?, ?)
ON CONFLICT(task_id) DO UPDATE SET
priority=excluded.priority,
available_at=excluded.available_at,
created_at=excluded.created_at`,
			task.ID, task.Priority, readyNs, createdNs,
		)
	} else {
		_, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID)
	}
	if err != nil {
		return err
	}

	if err = tx.Commit(); err != nil {
		return err
	}
	TasksQueued.Inc()
	if depth, depthErr := q.QueueDepth(); depthErr == nil {
		UpdateQueueDepth(depth)
	}
	return nil
}
// GetTask loads the task with the given ID by decoding its stored JSON
// payload. When no such task exists, the sql.ErrNoRows error from Scan is
// returned unchanged.
func (q *SQLiteQueue) GetTask(taskID string) (*Task, error) {
	var raw []byte
	err := q.db.QueryRowContext(q.ctx, "SELECT payload FROM tasks WHERE id = ?", taskID).Scan(&raw)
	if err != nil {
		return nil, err
	}
	t := new(Task)
	if err := json.Unmarshal(raw, t); err != nil {
		return nil, err
	}
	return t, nil
}
// GetAllTasks decodes and returns every stored task, whatever its status, in
// the order SQLite yields them. Any scan or decode failure aborts the call.
// An iteration error from rows.Err is returned alongside the tasks read so far.
func (q *SQLiteQueue) GetAllTasks() ([]*Task, error) {
	rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	tasks := make([]*Task, 0, 32)
	for rows.Next() {
		var raw []byte
		if scanErr := rows.Scan(&raw); scanErr != nil {
			return nil, scanErr
		}
		t := new(Task)
		if decodeErr := json.Unmarshal(raw, t); decodeErr != nil {
			return nil, decodeErr
		}
		tasks = append(tasks, t)
	}
	return tasks, rows.Err()
}
// GetTaskByName returns the most recently created task with the given job
// name, choosing by created_at. When none matches, the sql.ErrNoRows error
// from Scan is returned unchanged.
func (q *SQLiteQueue) GetTaskByName(jobName string) (*Task, error) {
	const newestByName = "SELECT payload FROM tasks WHERE job_name = ? ORDER BY created_at DESC LIMIT 1"
	var raw []byte
	if err := q.db.QueryRowContext(q.ctx, newestByName, jobName).Scan(&raw); err != nil {
		return nil, err
	}
	t := new(Task)
	if err := json.Unmarshal(raw, t); err != nil {
		return nil, err
	}
	return t, nil
}
// UpdateTask overwrites the stored row for task.ID and synchronizes the queue
// in one serializable transaction. The columns written are job_name, status,
// priority, payload and updated_at; created_at is left as is.
//
// A task whose Status is "queued" is inserted into the queue or has its entry
// re-prioritized. It becomes eligible at NextRetry when set, or immediately
// otherwise. Any other status removes the task from the queue.
//
// UpdateTask never creates a task. If no row matches task.ID, it returns an
// error wrapping sql.ErrNoRows, rather than leaving an orphaned queue entry
// that QueueDepth would count but the dequeue JOIN could never serve. On
// success the queue-depth gauge is refreshed (best effort), as in AddTask.
func (q *SQLiteQueue) UpdateTask(task *Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.ID == "" {
		return fmt.Errorf("task id is required")
	}
	payload, err := json.Marshal(task)
	if err != nil {
		return err
	}
	// One clock reading keeps updated_at and the default available_at identical.
	now := time.Now().UTC().UnixNano()
	availableAt := now
	if task.NextRetry != nil {
		availableAt = task.NextRetry.UTC().UnixNano()
	}

	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	res, err := tx.ExecContext(
		q.ctx,
		`UPDATE tasks SET job_name=?, status=?, priority=?, updated_at=?, payload=? WHERE id=?`,
		task.JobName,
		task.Status,
		task.Priority,
		now,
		payload,
		task.ID,
	)
	if err != nil {
		return err
	}
	updated, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if updated == 0 {
		return fmt.Errorf("update task %q: %w", task.ID, sql.ErrNoRows)
	}

	if task.Status == "queued" {
		_, err = tx.ExecContext(
			q.ctx,
			`INSERT INTO queue(task_id, priority, available_at, created_at)
VALUES(?, ?, ?, ?)
ON CONFLICT(task_id) DO UPDATE SET
priority=excluded.priority,
available_at=excluded.available_at`,
			task.ID,
			task.Priority,
			availableAt,
			task.CreatedAt.UTC().UnixNano(),
		)
	} else {
		_, err = tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", task.ID)
	}
	if err != nil {
		return err
	}

	if err := tx.Commit(); err != nil {
		return err
	}
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	return nil
}
// UpdateTaskWithMetrics persists task exactly as UpdateTask does. The string
// argument is ignored by this backend; presumably it exists to satisfy a
// shared queue interface (confirm against callers).
func (q *SQLiteQueue) UpdateTaskWithMetrics(task *Task, _ string) error {
	err := q.UpdateTask(task)
	return err
}
// GetNextTask removes and returns the highest-priority ready task from the
// queue. It returns (nil, nil) when no task is ready.
func (q *SQLiteQueue) GetNextTask() (*Task, error) {
	next, _, err := q.peekOrPop(false)
	if err != nil {
		return nil, err
	}
	return next, nil
}
// PeekNextTask returns the highest-priority ready task without removing it
// from the queue. It returns (nil, nil) when no task is ready.
func (q *SQLiteQueue) PeekNextTask() (*Task, error) {
	head, _, err := q.peekOrPop(true)
	if err != nil {
		return nil, err
	}
	return head, nil
}
// peekOrPop finds the ready queue entry with the highest priority, breaking
// ties by oldest created_at. An entry is ready when its available_at is not in
// the future. The entry's task is decoded and returned.
//
// When peek is true, nothing is modified. Otherwise the entry is deleted from
// the queue table (the tasks row is untouched) and the queue-depth gauge is
// refreshed (best effort). The function returns (nil, 0, nil) when no task is
// ready. The int64 result is currently always 0.
//
// The pop path runs the SELECT and the DELETE in one transaction and checks
// that the DELETE actually removed the row, so concurrent consumers cannot
// both be handed the same task.
func (q *SQLiteQueue) peekOrPop(peek bool) (*Task, int64, error) {
	const query = `SELECT q.task_id, t.payload FROM queue q
JOIN tasks t ON t.id = q.task_id
WHERE q.available_at <= ?
ORDER BY q.priority DESC, q.created_at ASC
LIMIT 1`
	now := time.Now().UTC().UnixNano()

	// Peeking reads straight from the pool; popping reads inside the transaction.
	queryRow := q.db.QueryRowContext
	var tx *sql.Tx
	if !peek {
		var err error
		tx, err = q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
		if err != nil {
			return nil, 0, err
		}
		// After a successful Commit this Rollback returns sql.ErrTxDone, which is ignored.
		defer func() { _ = tx.Rollback() }()
		queryRow = tx.QueryRowContext
	}

	var taskID string
	var payload []byte
	if err := queryRow(q.ctx, query, now).Scan(&taskID, &payload); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, 0, nil
		}
		return nil, 0, err
	}
	var t Task
	if err := json.Unmarshal(payload, &t); err != nil {
		return nil, 0, err
	}
	if peek {
		return &t, 0, nil
	}

	res, err := tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID)
	if err != nil {
		return nil, 0, err
	}
	removed, err := res.RowsAffected()
	if err != nil {
		return nil, 0, err
	}
	if removed == 0 {
		// Another consumer dequeued it first; report "nothing ready" rather
		// than hand out a duplicate.
		return nil, 0, nil
	}
	if err := tx.Commit(); err != nil {
		return nil, 0, err
	}
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	return &t, 0, nil
}
// GetNextTaskWithLease claims the next ready task for workerID with a lease
// of leaseDuration, delegating to claimTask. It returns (nil, nil) when no
// task is ready.
func (q *SQLiteQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
	claimed, err := q.claimTask(workerID, leaseDuration)
	return claimed, err
}
// GetNextTaskWithLeaseBlocking repeatedly calls claimTask until it returns a
// task, returns an error, or blockTimeout elapses. A non-positive blockTimeout
// means defaultBlockTimeout. If the timeout elapses, the result is (nil, nil).
//
// Attempts are spaced pollInterval apart, and the last wait is trimmed so it
// never sleeps past the deadline. Waiting is abandoned as soon as the queue
// context is cancelled (Close does this), in which case the context error is
// returned.
func (q *SQLiteQueue) GetNextTaskWithLeaseBlocking(
	workerID string,
	leaseDuration, blockTimeout time.Duration,
) (*Task, error) {
	const pollInterval = 50 * time.Millisecond
	if blockTimeout <= 0 {
		blockTimeout = defaultBlockTimeout
	}
	deadline := time.Now().Add(blockTimeout)
	for {
		t, err := q.claimTask(workerID, leaseDuration)
		if err != nil {
			return nil, err
		}
		if t != nil {
			return t, nil
		}

		remaining := time.Until(deadline)
		if remaining <= 0 {
			return nil, nil
		}
		wait := pollInterval
		if remaining < wait {
			wait = remaining
		}

		timer := time.NewTimer(wait)
		select {
		case <-q.ctx.Done():
			timer.Stop()
			return nil, q.ctx.Err()
		case <-timer.C:
		}
	}
}
// claimTask takes the next ready task off the queue for workerID inside one
// serializable transaction. It selects the highest-priority entry (oldest
// created_at among equals) whose available_at has passed, and deletes it from
// the queue table. It then rewrites the task's payload with LeasedBy=workerID
// and LeaseExpiry=now+leaseDuration. The task's Status is deliberately left
// unchanged. After commit, the queue-depth gauge is refreshed (best effort).
//
// It returns (nil, nil) when no task is ready, and an error when workerID is
// empty. A non-positive leaseDuration falls back to defaultLeaseDuration; a
// negative value would otherwise grant a lease that had already expired.
func (q *SQLiteQueue) claimTask(workerID string, leaseDuration time.Duration) (*Task, error) {
	if leaseDuration <= 0 {
		leaseDuration = defaultLeaseDuration
	}
	if workerID == "" {
		return nil, fmt.Errorf("worker_id is required")
	}
	now := time.Now().UTC()
	nowNs := now.UnixNano()
	tx, err := q.db.BeginTx(q.ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return nil, err
	}
	// After a successful Commit this Rollback returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()
	row := tx.QueryRowContext(
		q.ctx,
		`SELECT q.task_id, t.payload FROM queue q
JOIN tasks t ON t.id = q.task_id
WHERE q.available_at <= ?
ORDER BY q.priority DESC, q.created_at ASC
LIMIT 1`,
		nowNs,
	)
	var taskID string
	var payload []byte
	if err := row.Scan(&taskID, &payload); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	var t Task
	if err := json.Unmarshal(payload, &t); err != nil {
		return nil, err
	}
	// Mirror Redis semantics: acquire lease fields but do not modify status here.
	exp := now.Add(leaseDuration)
	t.LeaseExpiry = &exp
	t.LeasedBy = workerID
	newPayload, err := json.Marshal(&t)
	if err != nil {
		return nil, err
	}
	if _, err := tx.ExecContext(q.ctx, "DELETE FROM queue WHERE task_id = ?", taskID); err != nil {
		return nil, err
	}
	if _, err := tx.ExecContext(
		q.ctx,
		"UPDATE tasks SET updated_at=?, payload=? WHERE id=?",
		nowNs,
		newPayload,
		taskID,
	); err != nil {
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, err
	}
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	return &t, nil
}
// RenewLease extends the lease that workerID holds on taskID to
// leaseDuration from now. A non-positive leaseDuration means
// defaultLeaseDuration. It fails if workerID is empty, since an empty ID would
// otherwise "own" any unleased task. It also fails if the task is leased by a
// different worker or cannot be loaded or saved.
//
// The renewal metric is recorded only after the new expiry has been persisted.
// NOTE(review): the read (GetTask) and write (UpdateTask) are separate
// statements, not one transaction.
func (q *SQLiteQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
	if workerID == "" {
		return fmt.Errorf("worker_id is required")
	}
	t, err := q.GetTask(taskID)
	if err != nil {
		return err
	}
	if t.LeasedBy != workerID {
		return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
	}
	if leaseDuration <= 0 {
		leaseDuration = defaultLeaseDuration
	}
	exp := time.Now().UTC().Add(leaseDuration)
	t.LeaseExpiry = &exp
	if err := q.UpdateTask(t); err != nil {
		return err
	}
	RecordLeaseRenewal(workerID)
	return nil
}
// ReleaseLease clears the lease (LeasedBy and LeaseExpiry) on taskID and
// persists the task with UpdateTask. It fails if the task cannot be loaded
// or is leased by a worker other than workerID.
func (q *SQLiteQueue) ReleaseLease(taskID string, workerID string) error {
	task, err := q.GetTask(taskID)
	switch {
	case err != nil:
		return err
	case task.LeasedBy != workerID:
		return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
	}
	task.LeasedBy, task.LeaseExpiry = "", nil
	return q.UpdateTask(task)
}
// RetryTask either requeues a failed task with backoff or moves it to the
// dead letter queue.
//
// The task goes to the dead letter queue when RetryCount has reached
// MaxRetries, or when its Error classifies as non-retryable. Otherwise
// RetryTask:
//   - increments RetryCount,
//   - moves Error into LastError,
//   - clears the lease,
//   - sets NextRetry to now plus the category's backoff (in seconds),
//   - re-adds the task as "queued" via AddTask.
//
// A nil task is rejected with an error instead of panicking, matching
// AddTask and UpdateTask.
func (q *SQLiteQueue) RetryTask(task *Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.RetryCount >= task.MaxRetries {
		RecordDLQAddition("max_retries")
		return q.MoveToDeadLetterQueue(task, "max retries exceeded")
	}
	errorCategory := ErrorUnknown
	if task.Error != "" {
		errorCategory = ClassifyError(errors.New(task.Error))
	}
	if !IsRetryable(errorCategory) {
		RecordDLQAddition(string(errorCategory))
		return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
	}
	task.RetryCount++
	task.Status = "queued"
	task.LastError = task.Error
	task.Error = ""
	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
	nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
	task.NextRetry = &nextRetry
	task.LeaseExpiry = nil
	task.LeasedBy = ""
	RecordTaskRetry(task.JobName, errorCategory)
	return q.AddTask(task)
}
// MoveToDeadLetterQueue marks task as permanently failed and persists it.
// Status becomes "failed" and Error records reason plus the last error seen.
// Saving goes through UpdateTask, which also removes the task from the queue.
//
// A non-empty task.Error is the failure that triggered this call, such as a
// non-retryable error or an expired lease. It is promoted into LastError
// first. Otherwise it would be overwritten without being recorded, and the
// failure metric would classify a stale previous error instead.
func (q *SQLiteQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.Error != "" {
		task.LastError = task.Error
	}
	task.Status = "failed"
	task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
	RecordTaskFailure(task.JobName, ClassifyError(errors.New(task.LastError)))
	return q.UpdateTask(task)
}
// CancelTask loads taskID, sets its Status to "cancelled" and its EndedAt to
// the current UTC time, and persists it with UpdateTask (which also drops it
// from the queue).
func (q *SQLiteQueue) CancelTask(taskID string) error {
	task, err := q.GetTask(taskID)
	if err != nil {
		return err
	}
	endedAt := time.Now().UTC()
	task.Status, task.EndedAt = "cancelled", &endedAt
	return q.UpdateTask(task)
}
// RecordMetric is a no-op in the SQLite backend: every argument is ignored
// and it always returns nil.
func (q *SQLiteQueue) RecordMetric(string, string, float64) error {
	var noErr error
	return noErr
}
// Heartbeat upserts workerID's last_seen timestamp (UTC Unix nanoseconds) in
// the worker_heartbeat table. An empty workerID is rejected.
func (q *SQLiteQueue) Heartbeat(workerID string) error {
	if workerID == "" {
		return fmt.Errorf("worker_id is required")
	}
	const upsert = `INSERT INTO worker_heartbeat(worker_id, last_seen) VALUES(?, ?)
ON CONFLICT(worker_id) DO UPDATE SET last_seen=excluded.last_seen`
	seenAt := time.Now().UTC().UnixNano()
	if _, err := q.db.ExecContext(q.ctx, upsert, workerID, seenAt); err != nil {
		return err
	}
	return nil
}
// QueueDepth returns the number of rows in the queue table, counting entries
// whether or not they are ready yet.
func (q *SQLiteQueue) QueueDepth() (int64, error) {
	var depth int64
	if err := q.db.QueryRowContext(q.ctx, "SELECT COUNT(1) FROM queue").Scan(&depth); err != nil {
		return 0, err
	}
	return depth, nil
}
// SetWorkerPrewarmState upserts the prewarm state for state.WorkerID. It
// fills in StartedAt when empty, always stamps UpdatedAt (both RFC 3339, UTC),
// and records updated_at in Unix nanoseconds. The readers treat entries older
// than 30s as absent.
//
// A single clock reading feeds both UpdatedAt in the payload and the
// updated_at column used by the freshness checks, so the two always agree.
func (q *SQLiteQueue) SetWorkerPrewarmState(state PrewarmState) error {
	if state.WorkerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	now := time.Now().UTC()
	stamp := now.Format(time.RFC3339)
	if state.StartedAt == "" {
		state.StartedAt = stamp
	}
	state.UpdatedAt = stamp
	payload, err := json.Marshal(state)
	if err != nil {
		return err
	}
	_, err = q.db.ExecContext(
		q.ctx,
		`INSERT INTO worker_prewarm(worker_id, payload, updated_at) VALUES(?, ?, ?)
ON CONFLICT(worker_id) DO UPDATE SET payload=excluded.payload, updated_at=excluded.updated_at`,
		state.WorkerID,
		payload,
		now.UnixNano(),
	)
	return err
}
// ClearWorkerPrewarmState deletes any stored prewarm state for workerID. An
// empty workerID is rejected. Deleting a missing entry is not an error.
func (q *SQLiteQueue) ClearWorkerPrewarmState(workerID string) error {
	if workerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	if _, err := q.db.ExecContext(q.ctx, "DELETE FROM worker_prewarm WHERE worker_id = ?", workerID); err != nil {
		return err
	}
	return nil
}
// GetWorkerPrewarmState returns workerID's stored prewarm state. It returns
// (nil, nil) when no entry exists or the entry was last updated more than 30s
// ago. An empty workerID is rejected.
func (q *SQLiteQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
	if workerID == "" {
		return nil, fmt.Errorf("missing worker_id")
	}
	var (
		raw       []byte
		updatedNs int64
	)
	err := q.db.QueryRowContext(q.ctx, "SELECT payload, updated_at FROM worker_prewarm WHERE worker_id = ?", workerID).
		Scan(&raw, &updatedNs)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	// Mirror Redis TTL semantics (30s): stale entries read as absent.
	if age := time.Since(time.Unix(0, updatedNs)); age > 30*time.Second {
		return nil, nil
	}
	st := new(PrewarmState)
	if err := json.Unmarshal(raw, st); err != nil {
		return nil, err
	}
	return st, nil
}
// GetAllWorkerPrewarmStates returns every prewarm state updated within the
// last 30s. Any scan or decode failure aborts the call. An iteration error
// from rows.Err is returned alongside the states read so far.
func (q *SQLiteQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	freshSince := time.Now().UTC().Add(-30 * time.Second).UnixNano()
	rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM worker_prewarm WHERE updated_at >= ?", freshSince)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	states := make([]PrewarmState, 0, 8)
	for rows.Next() {
		var raw []byte
		if scanErr := rows.Scan(&raw); scanErr != nil {
			return nil, scanErr
		}
		var st PrewarmState
		if decodeErr := json.Unmarshal(raw, &st); decodeErr != nil {
			return nil, decodeErr
		}
		states = append(states, st)
	}
	return states, rows.Err()
}
// SignalPrewarmGC stores the current UTC time in Unix nanoseconds (as a
// decimal string) under PrewarmGCRequestKey, expiring after 10 minutes.
func (q *SQLiteQueue) SignalPrewarmGC() error {
	const ttl = 10 * time.Minute
	requestedAt := fmt.Sprintf("%d", time.Now().UTC().UnixNano())
	return q.kvSet(PrewarmGCRequestKey, requestedAt, ttl)
}
// PrewarmGCRequestValue returns the value last stored by SignalPrewarmGC,
// or "" when nothing is stored or the entry has expired.
func (q *SQLiteQueue) PrewarmGCRequestValue() (string, error) {
	requestedAt, _, lookupErr := q.kvGet(PrewarmGCRequestKey)
	return requestedAt, lookupErr
}
// kvSet upserts key=value in the kv table. A positive ttl sets expires_at to
// now+ttl (UTC Unix nanoseconds). A non-positive ttl stores expires_at = 0,
// which kvGet and kvJanitor treat as "never expires". Without this, a zero
// ttl would write an entry that was already expired on the next read.
func (q *SQLiteQueue) kvSet(key, value string, ttl time.Duration) error {
	var exp int64
	if ttl > 0 {
		exp = time.Now().UTC().Add(ttl).UnixNano()
	}
	_, err := q.db.ExecContext(
		q.ctx,
		`INSERT INTO kv(key, value, expires_at) VALUES(?, ?, ?)
ON CONFLICT(key) DO UPDATE SET value=excluded.value, expires_at=excluded.expires_at`,
		key,
		value,
		exp,
	)
	return err
}
// kvGet returns the value stored under key and whether it was found. An
// expires_at of 0 means the entry never expires. An entry past its expiry is
// reported as missing ("", false, nil) and lazily deleted on a best-effort
// basis.
func (q *SQLiteQueue) kvGet(key string) (string, bool, error) {
	row := q.db.QueryRowContext(q.ctx, "SELECT value, expires_at FROM kv WHERE key = ?", key)
	var value string
	var exp int64
	if err := row.Scan(&value, &exp); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", false, nil
		}
		return "", false, err
	}
	if exp > 0 && time.Now().UTC().UnixNano() > exp {
		// Match expires_at too, so a kvSet that refreshed the key between the
		// SELECT above and this DELETE is not wiped out. Errors are ignored:
		// kvJanitor will remove the row later anyway.
		_, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE key = ? AND expires_at = ?", key, exp)
		return "", false, nil
	}
	return value, true, nil
}
// kvJanitor runs until the queue context is cancelled. Every 30s it deletes
// kv rows whose non-zero expires_at lies in the past.
func (q *SQLiteQueue) kvJanitor() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	done := q.ctx.Done()
	for {
		select {
		case <-ticker.C:
			cutoff := time.Now().UTC().UnixNano()
			// Best effort: a failed sweep is simply retried on the next tick.
			_, _ = q.db.ExecContext(q.ctx, "DELETE FROM kv WHERE expires_at > 0 AND expires_at < ?", cutoff)
		case <-done:
			return
		}
	}
}
// leaseReclaimer runs until the queue context is cancelled, performing one
// reclaimExpiredLeases pass every minute.
func (q *SQLiteQueue) leaseReclaimer() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()
	done := q.ctx.Done()
	for {
		select {
		case <-ticker.C:
			// Errors are dropped; the next tick simply tries again.
			_ = q.reclaimExpiredLeases()
		case <-done:
			return
		}
	}
}
// ReclaimExpiredLeases runs a single lease-reclamation pass immediately. This
// is the same pass the background reclaimer performs every minute.
func (q *SQLiteQueue) ReclaimExpiredLeases() error {
	err := q.reclaimExpiredLeases()
	return err
}
// reclaimExpiredLeases finds "running" tasks whose LeaseExpiry has passed.
// For each one it records the expiry in task.Error and in the lease-expiration
// metric. It then hands the task to RetryTask if retries remain, or to
// MoveToDeadLetterQueue if they do not.
//
// Candidates are pre-filtered in SQL on the indexed status column; AddTask and
// UpdateTask write task.Status into that column. This avoids decoding every
// task in the table. The decoded payload's Status is still re-checked.
//
// All rows are read and the cursor is closed before any writes. The pool holds
// a single connection, so writing while rows are open would block. Per-task
// problems (undecodable payloads, retry/DLQ errors) are skipped so that one
// bad task cannot stall reclamation of the rest. Only query/scan errors are
// returned.
func (q *SQLiteQueue) reclaimExpiredLeases() error {
	now := time.Now().UTC()
	payloads, err := func() ([][]byte, error) {
		rows, err := q.db.QueryContext(q.ctx, "SELECT payload FROM tasks WHERE status = ?", "running")
		if err != nil {
			return nil, err
		}
		defer rows.Close()
		out := make([][]byte, 0, 32)
		for rows.Next() {
			var payload []byte
			if err := rows.Scan(&payload); err != nil {
				return nil, err
			}
			out = append(out, payload)
		}
		return out, rows.Err()
	}()
	if err != nil {
		return err
	}

	for _, payload := range payloads {
		var t Task
		if err := json.Unmarshal(payload, &t); err != nil {
			continue
		}
		if t.LeaseExpiry == nil || t.Status != "running" {
			continue
		}
		if !t.LeaseExpiry.Before(now) {
			continue
		}
		t.Error = fmt.Sprintf("worker %s lease expired", t.LeasedBy)
		RecordLeaseExpiration()
		if t.RetryCount < t.MaxRetries {
			_ = q.RetryTask(&t)
		} else {
			_ = q.MoveToDeadLetterQueue(&t, "lease expiry after max retries")
		}
	}
	return nil
}