fetch_ml/internal/queue/queue.go
Jeremie Fraeys 6866ba9366
refactor(queue): integrate scheduler backend and storage improvements
Update queue and storage systems for scheduler integration:
- Queue backend with scheduler coordination
- Filesystem queue with batch operations
- Deduplication with tenant-aware keys
- Storage layer with audit logging hooks
- Domain models (Task, Events, Errors) with scheduler fields
- Database layer with tenant isolation
- Dataset storage with integrity checks
2026-02-26 12:06:46 -05:00

816 lines
20 KiB
Go

package queue
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/domain"
	"github.com/redis/go-redis/v9"
)
const (
	// defaultMetricsFlushInterval is how often the metrics buffer flushes
	// to Redis when Config.MetricsFlushInterval is zero.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// defaultLeaseDuration is the task lease length used when callers
	// pass a zero lease duration.
	defaultLeaseDuration = 30 * time.Minute
	// defaultMaxRetries is applied to tasks enqueued with MaxRetries == 0.
	defaultMaxRetries = 3
	// defaultBlockTimeout bounds blocking pops when callers pass <= 0.
	defaultBlockTimeout = 1 * time.Second
)
// TaskQueue manages ML experiment tasks via Redis.
//
// Pending task IDs are ordered by priority in a single ZSET
// (TaskQueueKey); task bodies are stored as JSON under TaskPrefix+ID.
// NOTE(review): storing a context in a struct is unidiomatic Go, but it
// scopes the lifetime of the three background goroutines started by
// NewTaskQueue — changing it would ripple through every method.
type TaskQueue struct {
	ctx         context.Context    // lifetime of background goroutines; cancelled by Close
	client      *redis.Client
	cancel      context.CancelFunc // stops background goroutines
	metricsCh   chan metricEvent   // buffered feed into runMetricsBuffer
	metricsDone chan struct{}      // closed when runMetricsBuffer exits
	dedup       *CommitDedup       // in-process commit dedup (see AddTaskDedup)
	flushEvery  time.Duration      // metrics flush cadence
}
// metricEvent is one buffered metric sample, batched by
// runMetricsBuffer and persisted via writeMetrics.
type metricEvent struct {
	JobName string
	Metric  string
	Value   float64
}
// PrewarmState is the JSON document a worker publishes under
// WorkerPrewarmKey+WorkerID while prewarming for an upcoming task.
// It is stored with a short TTL (see SetWorkerPrewarmState) so state
// from a dead worker expires on its own.
type PrewarmState struct {
	WorkerID   string `json:"worker_id"`
	TaskID     string `json:"task_id"`
	SnapshotID string `json:"snapshot_id,omitempty"`
	StartedAt  string `json:"started_at"` // timestamps are strings; format set by the writer — TODO confirm RFC3339
	UpdatedAt  string `json:"updated_at"`
	Phase      string `json:"phase"`
	EnvImage   string `json:"env_image,omitempty"`
	DatasetCnt int    `json:"dataset_count"`
	// Environment-cache counters; durations carried as nanoseconds.
	EnvHit    int64 `json:"env_hit,omitempty"`
	EnvMiss   int64 `json:"env_miss,omitempty"`
	EnvBuilt  int64 `json:"env_built,omitempty"`
	EnvTimeNs int64 `json:"env_time_ns,omitempty"`
}
// Config holds configuration for TaskQueue.
type Config struct {
	RedisAddr     string // host:port, or a redis:// URL (parsed via redis.ParseURL)
	RedisPassword string // only used when RedisAddr is a plain address, not a URL
	RedisDB       int    // only used when RedisAddr is a plain address, not a URL
	// MetricsFlushInterval controls metrics batching cadence;
	// zero selects defaultMetricsFlushInterval.
	MetricsFlushInterval time.Duration
}
// NewTaskQueue creates a new task queue instance with Redis backend.
//
// cfg.RedisAddr may be either a plain host:port address or a full
// redis:// / rediss:// URL; URL form is parsed with redis.ParseURL, so
// credentials and DB embedded in the URL take effect (and
// cfg.RedisPassword / cfg.RedisDB are ignored). The connection is
// verified with PING before returning. Three background goroutines are
// started (metrics buffering, lease reclamation, dedup cleanup); they
// run until Close cancels the queue's context.
func NewTaskQueue(cfg Config) (*TaskQueue, error) {
	var opts *redis.Options
	// Accept both plain redis:// and TLS rediss:// URL schemes; the old
	// byte-slice prefix check only matched redis:// exactly.
	if strings.HasPrefix(cfg.RedisAddr, "redis://") || strings.HasPrefix(cfg.RedisAddr, "rediss://") {
		parsed, err := redis.ParseURL(cfg.RedisAddr)
		if err != nil {
			return nil, fmt.Errorf("invalid redis url: %w", err)
		}
		opts = parsed
	} else {
		opts = &redis.Options{
			Addr:     cfg.RedisAddr,
			Password: cfg.RedisPassword,
			DB:       cfg.RedisDB,
			// Connection pooling optimizations.
			PoolSize:     50,              // larger connection pool for bursty load
			MinIdleConns: 10,              // keep warm connections available
			MaxRetries:   3,               // bounded retries for faster failure
			DialTimeout:  5 * time.Second, // fail fast on unreachable server
			ReadTimeout:  3 * time.Second,
			WriteTimeout: 3 * time.Second,
			PoolTimeout:  4 * time.Second,
		}
	}
	rdb := redis.NewClient(opts)
	ctx, cancel := context.WithCancel(context.Background())
	// Fail fast if Redis is unreachable rather than on first use.
	if err := rdb.Ping(ctx).Err(); err != nil {
		cancel()
		return nil, fmt.Errorf("redis connection failed: %w", err)
	}
	flushEvery := cfg.MetricsFlushInterval
	if flushEvery == 0 {
		flushEvery = defaultMetricsFlushInterval
	}
	tq := &TaskQueue{
		client:      rdb,
		ctx:         ctx,
		cancel:      cancel,
		metricsCh:   make(chan metricEvent, 256),
		metricsDone: make(chan struct{}),
		flushEvery:  flushEvery,
		dedup:       NewCommitDedup(1 * time.Hour), // 1 hour default TTL for commit dedup
	}
	go tq.runMetricsBuffer()
	go tq.runLeaseReclamation() // reclaim expired leases in the background
	go tq.runDedupCleanup()     // expire stale dedup entries in the background
	return tq, nil
}
// AddTask adds a new task to the queue with default retry settings.
// The task body, the priority-queue entry, and the status hash are all
// written in one pipeline round-trip.
func (tq *TaskQueue) AddTask(task *Task) error {
	// Apply the default retry budget when none was specified.
	if task.MaxRetries == 0 {
		task.MaxRetries = defaultMaxRetries
	}
	payload, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}
	p := tq.client.Pipeline()
	// Persist the task body for a week.
	p.Set(tq.ctx, TaskPrefix+task.ID, payload, 7*24*time.Hour)
	// Priority queue (ZSET): higher priority maps to a higher score.
	p.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
		Score:  float64(task.Priority),
		Member: task.ID,
	})
	// Initialize the per-job status hash.
	p.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
		"status", task.Status,
		"task_id", task.ID,
		"updated_at", time.Now().Format(time.RFC3339))
	// Queue depth update piggybacks on the same pipeline round-trip to
	// avoid an extra Redis call per enqueue under high load.
	depth := p.ZCard(tq.ctx, TaskQueueKey)
	if _, err = p.Exec(tq.ctx); err != nil {
		return fmt.Errorf("failed to enqueue task: %w", err)
	}
	// Record metrics.
	TasksQueued.Inc()
	UpdateQueueDepth(depth.Val())
	return nil
}
func isBlockingUnsupported(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax")
}
// pollUntilDeadline is the non-blocking fallback used when the server
// lacks BZPOPMAX: it polls GetNextTaskWithLease every 25ms until a task
// arrives, blockTimeout elapses (returns nil, nil), or the queue's
// context is cancelled.
func (tq *TaskQueue) pollUntilDeadline(
	workerID string,
	leaseDuration, blockTimeout time.Duration,
) (*Task, error) {
	const pollInterval = 25 * time.Millisecond
	deadline := time.Now().Add(blockTimeout)
	for !time.Now().After(deadline) {
		task, err := tq.GetNextTaskWithLease(workerID, leaseDuration)
		if err != nil {
			return nil, err
		}
		if task != nil {
			return task, nil
		}
		// Sleep between polls, but wake immediately on shutdown.
		select {
		case <-tq.ctx.Done():
			return nil, tq.ctx.Err()
		case <-time.After(pollInterval):
		}
	}
	return nil, nil
}
// GetNextTask gets the next task without lease (backward compatible).
// It pops the highest-priority member from the queue ZSET and fetches
// the task body; returns (nil, nil) when the queue is empty.
func (tq *TaskQueue) GetNextTask() (*Task, error) {
	result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
	if err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, nil
	}
	// Checked assertion: a non-string member must not panic the caller
	// (matches the handling in GetNextTaskWithLeaseBlocking).
	taskID, ok := result[0].Member.(string)
	if !ok {
		return nil, fmt.Errorf("unexpected task id type %T", result[0].Member)
	}
	return tq.GetTask(taskID)
}
// PeekNextTask returns the highest priority task without removing it
// from the queue. This is intended for best-effort prewarm logic; it
// must never be required for correctness.
func (tq *TaskQueue) PeekNextTask() (*Task, error) {
	// Reverse range yields the highest score (= highest priority) first.
	members, err := tq.client.ZRevRange(tq.ctx, TaskQueueKey, 0, 0).Result()
	if err != nil {
		return nil, err
	}
	if len(members) == 0 {
		return nil, nil
	}
	return tq.GetTask(members[0])
}
// SetWorkerPrewarmState publishes a worker's prewarm state as JSON
// under WorkerPrewarmKey+WorkerID with a 30-second TTL.
func (tq *TaskQueue) SetWorkerPrewarmState(state PrewarmState) error {
	if state.WorkerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	payload, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal prewarm state: %w", err)
	}
	// Keep short TTL to avoid stale prewarm state if worker dies.
	return tq.client.Set(tq.ctx, WorkerPrewarmKey+state.WorkerID, payload, 30*time.Second).Err()
}
// ClearWorkerPrewarmState deletes the worker's prewarm state key.
func (tq *TaskQueue) ClearWorkerPrewarmState(workerID string) error {
	if workerID == "" {
		return fmt.Errorf("missing worker_id")
	}
	return tq.client.Del(tq.ctx, WorkerPrewarmKey+workerID).Err()
}
// GetWorkerPrewarmState fetches and decodes a worker's prewarm state.
// Returns (nil, nil) when no state is stored (or it has expired).
func (tq *TaskQueue) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
	if workerID == "" {
		return nil, fmt.Errorf("missing worker_id")
	}
	v, err := tq.client.Get(tq.ctx, WorkerPrewarmKey+workerID).Result()
	// errors.Is instead of == so wrapped redis.Nil is still recognized.
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	var state PrewarmState
	if err := json.Unmarshal([]byte(v), &state); err != nil {
		return nil, fmt.Errorf("unmarshal prewarm state: %w", err)
	}
	return &state, nil
}
// GetAllWorkerPrewarmStates scans all WorkerPrewarmKey* keys and
// decodes each stored prewarm state. Keys that expire between SCAN and
// GET are skipped; decode failures abort with an error.
func (tq *TaskQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	var cursor uint64
	pattern := WorkerPrewarmKey + "*"
	out := make([]PrewarmState, 0, 8)
	for {
		keys, next, err := tq.client.Scan(tq.ctx, cursor, pattern, 50).Result()
		if err != nil {
			return nil, err
		}
		cursor = next
		for _, key := range keys {
			v, err := tq.client.Get(tq.ctx, key).Result()
			// errors.Is instead of == so wrapped redis.Nil is recognized;
			// a key may legitimately expire mid-scan (30s TTL).
			if errors.Is(err, redis.Nil) {
				continue
			}
			if err != nil {
				return nil, err
			}
			var state PrewarmState
			if err := json.Unmarshal([]byte(v), &state); err != nil {
				return nil, fmt.Errorf("unmarshal prewarm state: %w", err)
			}
			out = append(out, state)
		}
		// Cursor 0 marks the end of the SCAN iteration.
		if cursor == 0 {
			break
		}
	}
	return out, nil
}
// GetNextTaskWithLease gets the next task and acquires a lease.
// It pops the highest-priority member, fetches the body, and stamps
// LeasedBy/LeaseExpiry for workerID (defaultLeaseDuration when zero).
// Returns (nil, nil) on an empty queue. On fetch/update failure the ID
// is pushed back with its original score so the task is not lost.
func (tq *TaskQueue) GetNextTaskWithLease(
	workerID string,
	leaseDuration time.Duration,
) (*Task, error) {
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	// Pop highest priority task.
	result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
	if err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, nil
	}
	// Checked assertion: a corrupt member must not panic the worker
	// (matches the handling in GetNextTaskWithLeaseBlocking).
	taskID, ok := result[0].Member.(string)
	if !ok {
		return nil, fmt.Errorf("unexpected task id type %T", result[0].Member)
	}
	// requeue restores the popped entry with its original score.
	requeue := func() {
		tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
			Score:  result[0].Score,
			Member: taskID,
		})
	}
	task, err := tq.GetTask(taskID)
	if err != nil {
		requeue()
		return nil, err
	}
	// Acquire lease.
	leaseExpiry := time.Now().Add(leaseDuration)
	task.LeaseExpiry = &leaseExpiry
	task.LeasedBy = workerID
	if err := tq.UpdateTask(task); err != nil {
		requeue()
		return nil, err
	}
	return task, nil
}
// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a
// task, then acquires a lease for workerID. Returns (nil, nil) on
// timeout. If the server does not support BZPOPMAX it falls back to
// polling. On fetch/update failure the popped ID is pushed back with
// its original score — previously the task was silently lost here,
// unlike the non-blocking GetNextTaskWithLease.
func (tq *TaskQueue) GetNextTaskWithLeaseBlocking(
	workerID string,
	leaseDuration, blockTimeout time.Duration,
) (*Task, error) {
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	if blockTimeout <= 0 {
		blockTimeout = defaultBlockTimeout
	}
	result, err := tq.client.BZPopMax(tq.ctx, blockTimeout, TaskQueueKey).Result()
	// errors.Is instead of == so wrapped redis.Nil is still recognized.
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		if isBlockingUnsupported(err) {
			return tq.pollUntilDeadline(workerID, leaseDuration, blockTimeout)
		}
		return nil, err
	}
	taskID, ok := result.Member.(string)
	if !ok {
		return nil, fmt.Errorf("unexpected task id type %T", result.Member)
	}
	// requeue restores the popped entry so a transient failure below
	// does not drop the task.
	requeue := func() {
		tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
			Score:  result.Score,
			Member: taskID,
		})
	}
	task, err := tq.GetTask(taskID)
	if err != nil {
		requeue()
		return nil, err
	}
	leaseExpiry := time.Now().Add(leaseDuration)
	task.LeaseExpiry = &leaseExpiry
	task.LeasedBy = workerID
	if err := tq.UpdateTask(task); err != nil {
		requeue()
		return nil, err
	}
	return task, nil
}
// RenewLease renews the lease on a task (heartbeat).
// Only the worker currently holding the lease may renew it; the renewal
// metric is recorded only after the lease update actually persists
// (previously it was counted even when UpdateTask failed).
func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	task, err := tq.GetTask(taskID)
	if err != nil {
		return err
	}
	// Verify the worker owns the lease.
	if task.LeasedBy != workerID {
		return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
	}
	// Renew lease.
	leaseExpiry := time.Now().Add(leaseDuration)
	task.LeaseExpiry = &leaseExpiry
	if err := tq.UpdateTask(task); err != nil {
		return err
	}
	// Record the renewal only once it has been persisted.
	RecordLeaseRenewal(workerID)
	return nil
}
// ReleaseLease releases the lease on a task.
// Only the worker currently holding the lease may release it.
func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
	task, err := tq.GetTask(taskID)
	if err != nil {
		return err
	}
	// Reject releases from any worker that is not the lease holder.
	if task.LeasedBy != workerID {
		return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
	}
	// Drop the lease fields and persist.
	task.LeasedBy = ""
	task.LeaseExpiry = nil
	return tq.UpdateTask(task)
}
// RetryTask re-queues a failed task with smart backoff based on failure class.
//
// Flow: (1) if the retry budget is exhausted, route to the DLQ;
// (2) classify task.Error to a domain.FailureClass; (3) ask
// domain.ShouldAutoRetry whether this class is retryable at the
// CURRENT retry count — note the count is incremented only after this
// check; (4) re-enqueue with class-specific backoff via AddTask.
// NextRetry is advisory metadata here; the task re-enters the queue
// immediately — presumably a consumer honors NextRetry (TODO confirm).
func (tq *TaskQueue) RetryTask(task *Task) error {
	if task.RetryCount >= task.MaxRetries {
		// Retry budget exhausted: move to dead letter queue.
		RecordDLQAddition("max_retries")
		return tq.MoveToDeadLetterQueue(task, "max retries exceeded")
	}
	// Classify the error if it exists; default to FailureUnknown.
	failureClass := domain.FailureUnknown
	if task.Error != "" {
		failureClass = domain.ClassifyFailure(0, nil, task.Error)
	}
	// Non-retryable classes (per domain policy) go straight to the DLQ.
	if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
		RecordDLQAddition(string(failureClass))
		return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
	}
	task.RetryCount++
	task.Status = "queued"
	task.LastError = task.Error // preserve the error that triggered this retry
	task.Error = ""             // clear current error for the next attempt
	// Class-aware backoff (seconds), scaled by the new retry count.
	backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
	nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
	task.NextRetry = &nextRetry
	// Clear the lease so another worker can claim the retry.
	task.LeaseExpiry = nil
	task.LeasedBy = ""
	// Record retry metrics.
	RecordTaskRetry(task.JobName, failureClass)
	// Re-queue with the same priority (AddTask reads task.Priority).
	return tq.AddTask(task)
}
// MoveToDeadLetterQueue moves a task to the dead letter queue.
// The task is marked failed, annotated with the DLQ reason, stored
// under task:dlq:<id> for 30 days, and removed from the pending queue.
func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
	task.Status = "failed"
	task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
	payload, err := json.Marshal(task)
	if err != nil {
		return err
	}
	// Classify the last real error for failure metrics.
	failureClass := domain.ClassifyFailure(0, nil, task.LastError)
	RecordTaskFailure(task.JobName, failureClass)
	p := tq.client.Pipeline()
	// DLQ entries are retained for 30 days.
	p.Set(tq.ctx, "task:dlq:"+task.ID, payload, 30*24*time.Hour)
	// Ensure the task is no longer schedulable.
	p.ZRem(tq.ctx, TaskQueueKey, task.ID)
	_, err = p.Exec(tq.ctx)
	return err
}
// runLeaseReclamation reclaims expired leases every 1 minute.
// Runs until the queue's context is cancelled (Close). Reclamation
// errors are logged and the loop continues — the old comment promised
// logging but the error was silently discarded.
func (tq *TaskQueue) runLeaseReclamation() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-tq.ctx.Done():
			return
		case <-ticker.C:
			if err := tq.reclaimExpiredLeases(); err != nil {
				// Best-effort: log and keep the reclamation loop alive.
				log.Printf("queue: lease reclamation failed: %v", err)
			}
		}
	}
}
// reclaimExpiredLeases finds and re-queues tasks with expired leases.
// It scans all task keys; any task still marked "running" whose lease
// expiry is in the past is retried (within its retry budget) or moved
// to the DLQ. Per-task errors are skipped so one bad task cannot stall
// the sweep; the next sweep will see it again.
func (tq *TaskQueue) reclaimExpiredLeases() error {
	now := time.Now()
	iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator()
	for iter.Next(tq.ctx) {
		id := iter.Val()[len(TaskPrefix):]
		task, err := tq.GetTask(id)
		if err != nil {
			continue // unreadable task; retry on the next sweep
		}
		// Only running tasks with a lease already in the past qualify.
		leaseExpired := task.LeaseExpiry != nil && task.LeaseExpiry.Before(now)
		if !leaseExpired || task.Status != "running" {
			continue
		}
		task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy)
		RecordLeaseExpiration()
		if task.RetryCount < task.MaxRetries {
			_ = tq.RetryTask(task) // best-effort; errors retried next sweep
		} else {
			_ = tq.MoveToDeadLetterQueue(task, "lease expiry after max retries")
		}
	}
	return iter.Err()
}
// GetTask retrieves a task by ID, decoding the JSON body stored under
// TaskPrefix+taskID. Propagates redis.Nil when the key is absent.
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
	raw, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result()
	if err != nil {
		return nil, err
	}
	task := new(Task)
	if err := json.Unmarshal([]byte(raw), task); err != nil {
		return nil, err
	}
	return task, nil
}
// GetAllTasks retrieves all tasks currently stored in Redis.
// It iterates keys with SCAN rather than KEYS: KEYS is a blocking O(N)
// command that can stall the Redis server on a large keyspace, and the
// rest of this file (reclaimExpiredLeases, GetAllWorkerPrewarmStates)
// already uses SCAN. Unreadable or malformed tasks are skipped.
func (tq *TaskQueue) GetAllTasks() ([]*Task, error) {
	iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator()
	var tasks []*Task
	for iter.Next(tq.ctx) {
		data, err := tq.client.Get(tq.ctx, iter.Val()).Result()
		if err != nil {
			continue // Skip tasks that can't be retrieved
		}
		var task Task
		if err := json.Unmarshal([]byte(data), &task); err != nil {
			continue // Skip malformed tasks
		}
		tasks = append(tasks, &task)
	}
	if err := iter.Err(); err != nil {
		return nil, err
	}
	return tasks, nil
}
// GetTaskByName retrieves a task by its job name.
// Linear scan over all stored tasks; returns an error when no task
// matches.
func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) {
	tasks, err := tq.GetAllTasks()
	if err != nil {
		return nil, err
	}
	for _, t := range tasks {
		if t.JobName == jobName {
			return t, nil
		}
	}
	return nil, fmt.Errorf("task with job name '%s' not found", jobName)
}
// CancelTask marks a task as cancelled and removes it from the pending
// queue. The ZRem is new: previously a cancelled task's ID stayed in
// the priority ZSET, so a worker could still pop and run it. ZRem is a
// no-op if the task was already popped.
func (tq *TaskQueue) CancelTask(taskID string) error {
	task, err := tq.GetTask(taskID)
	if err != nil {
		return err
	}
	// Update task status to cancelled.
	task.Status = "cancelled"
	now := time.Now()
	task.EndedAt = &now
	// Drop the pending-queue entry so it cannot be dequeued.
	if err := tq.client.ZRem(tq.ctx, TaskQueueKey, taskID).Err(); err != nil {
		return err
	}
	return tq.UpdateTask(task)
}
// UpdateTask persists a task's JSON body (7-day TTL) and refreshes the
// per-job status hash in a single pipeline round-trip.
func (tq *TaskQueue) UpdateTask(task *Task) error {
	payload, err := json.Marshal(task)
	if err != nil {
		return err
	}
	p := tq.client.Pipeline()
	p.Set(tq.ctx, TaskPrefix+task.ID, payload, 7*24*time.Hour)
	p.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
		"status", task.Status,
		"task_id", task.ID,
		"updated_at", time.Now().Format(time.RFC3339))
	_, execErr := p.Exec(tq.ctx)
	return execErr
}
// UpdateTaskWithMetrics updates the task and records a
// "tasks_<action>" metric sample for its job.
func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error {
	if err := tq.UpdateTask(task); err != nil {
		return err
	}
	return tq.RecordMetric(task.JobName, "tasks_"+action, 1)
}
// RecordMetric records a metric value. It is non-blocking: the sample
// goes to the buffered metrics channel when there is room, otherwise
// it is written to Redis synchronously instead of being dropped.
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
	select {
	case tq.metricsCh <- metricEvent{JobName: jobName, Metric: metric, Value: value}:
		return nil
	default:
		// Buffer full: fall back to a direct write.
		return tq.writeMetrics(jobName, map[string]float64{metric: value})
	}
}
// Heartbeat records a worker heartbeat as a Unix timestamp in the
// shared heartbeat hash.
func (tq *TaskQueue) Heartbeat(workerID string) error {
	now := time.Now().Unix()
	return tq.client.HSet(tq.ctx, WorkerHeartbeat, workerID, now).Err()
}
// QueueDepth returns the number of pending tasks (cardinality of the
// priority ZSET).
func (tq *TaskQueue) QueueDepth() (int64, error) {
	depth, err := tq.client.ZCard(tq.ctx, TaskQueueKey).Result()
	return depth, err
}
// Close closes the task queue and cleans up resources.
// It cancels the context shared by the background goroutines (metrics
// buffer, lease reclamation, dedup cleanup), waits for the metrics
// buffer to perform its final flush, then closes the Redis client.
func (tq *TaskQueue) Close() error {
	tq.cancel()
	// Block until runMetricsBuffer has flushed pending metrics and exited.
	<-tq.metricsDone
	return tq.client.Close()
}
// SignalPrewarmGC requests a prewarm garbage-collection pass by writing
// the current time (ns) under PrewarmGCRequestKey. The 10-minute TTL
// lets an unconsumed request expire on its own.
func (tq *TaskQueue) SignalPrewarmGC() error {
	stamp := time.Now().UnixNano()
	return tq.client.Set(tq.ctx, PrewarmGCRequestKey, stamp, 10*time.Minute).Err()
}
// PrewarmGCRequestValue returns the pending prewarm-GC request value,
// or "" when no request is outstanding.
func (tq *TaskQueue) PrewarmGCRequestValue() (string, error) {
	v, err := tq.client.Get(tq.ctx, PrewarmGCRequestKey).Result()
	// errors.Is instead of == so wrapped redis.Nil is still recognized.
	if errors.Is(err, redis.Nil) {
		return "", nil
	}
	return v, err
}
// GetRedisClient returns the underlying Redis client for direct access.
// Callers operating on queue keys should prefer the TaskQueue methods,
// which keep the task body, queue ZSET, and status hash consistent.
func (tq *TaskQueue) GetRedisClient() *redis.Client {
	return tq.client
}
// WaitForNextTask blocks up to timeout for the next task without
// acquiring a lease. A nil ctx falls back to the queue's own context.
// Returns (nil, nil) on timeout.
func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) {
	if ctx == nil {
		ctx = tq.ctx
	}
	result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result()
	// errors.Is instead of == so wrapped redis.Nil is still recognized.
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	member, ok := result.Member.(string)
	if !ok {
		return nil, fmt.Errorf("unexpected task id type %T", result.Member)
	}
	return tq.GetTask(member)
}
// runDedupCleanup sweeps expired commit-dedup entries every 5 minutes
// until the queue's context is cancelled (Close).
func (tq *TaskQueue) runDedupCleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			tq.dedup.Cleanup()
		case <-tq.ctx.Done():
			return
		}
	}
}
// AddTaskDedup adds a task with commit deduplication check.
// Returns ErrAlreadyQueued if the same job+commit was recently queued.
// An empty commitID disables deduplication entirely.
func (tq *TaskQueue) AddTaskDedup(task *Task, commitID string) error {
	dedupEnabled := commitID != ""
	if dedupEnabled && tq.dedup.IsDuplicate(task.JobName, commitID) {
		return ErrAlreadyQueued
	}
	if err := tq.AddTask(task); err != nil {
		return err
	}
	// Only mark the commit as queued once the enqueue succeeded.
	if dedupEnabled {
		tq.dedup.MarkQueued(task.JobName, commitID)
	}
	return nil
}
// runMetricsBuffer batches metric events per job and flushes them to
// Redis every flushEvery. Within a window, a later sample for the same
// (job, metric) pair overwrites the earlier one. A final flush runs on
// shutdown before metricsDone is closed.
func (tq *TaskQueue) runMetricsBuffer() {
	defer close(tq.metricsDone)
	ticker := time.NewTicker(tq.flushEvery)
	defer ticker.Stop()
	pending := make(map[string]map[string]float64)
	flush := func() {
		for job, batch := range pending {
			if err := tq.writeMetrics(job, batch); err != nil {
				// Keep the batch so the next flush can retry it.
				continue
			}
			delete(pending, job)
		}
	}
	for {
		select {
		case <-tq.ctx.Done():
			flush()
			return
		case evt, ok := <-tq.metricsCh:
			if !ok {
				flush()
				return
			}
			bucket := pending[evt.JobName]
			if bucket == nil {
				bucket = make(map[string]float64)
				pending[evt.JobName] = bucket
			}
			bucket[evt.Metric] = evt.Value
		case <-ticker.C:
			flush()
		}
	}
}
// writeMetrics writes a batch of metric values into the job's metrics
// hash, stamped with the current Unix time. A nil/empty batch is a
// no-op.
// It deliberately uses context.Background() rather than tq.ctx so the
// final flush during shutdown (runMetricsBuffer flushes after Close has
// cancelled tq.ctx) can still reach Redis.
func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error {
	if len(metrics) == 0 {
		return nil
	}
	key := JobMetricsPrefix + jobName
	// Flatten to HSET field/value pairs, leading with a timestamp.
	args := make([]any, 0, len(metrics)*2+2)
	args = append(args, "timestamp", time.Now().Unix())
	for metric, value := range metrics {
		args = append(args, metric, value)
	}
	return tq.client.HSet(context.Background(), key, args...).Err()
}