package queue import ( "context" "encoding/json" "fmt" "strings" "time" "github.com/redis/go-redis/v9" ) const ( defaultMetricsFlushInterval = 500 * time.Millisecond defaultLeaseDuration = 30 * time.Minute defaultMaxRetries = 3 defaultBlockTimeout = 1 * time.Second ) // TaskQueue manages ML experiment tasks via Redis type TaskQueue struct { client *redis.Client ctx context.Context cancel context.CancelFunc metricsCh chan metricEvent metricsDone chan struct{} flushEvery time.Duration } type metricEvent struct { JobName string Metric string Value float64 } // Config holds configuration for TaskQueue type Config struct { RedisAddr string RedisPassword string RedisDB int MetricsFlushInterval time.Duration } // NewTaskQueue creates a new task queue instance with Redis backend. func NewTaskQueue(cfg Config) (*TaskQueue, error) { var opts *redis.Options var err error if len(cfg.RedisAddr) > 8 && cfg.RedisAddr[:8] == "redis://" { opts, err = redis.ParseURL(cfg.RedisAddr) if err != nil { return nil, fmt.Errorf("invalid redis url: %w", err) } } else { opts = &redis.Options{ Addr: cfg.RedisAddr, Password: cfg.RedisPassword, DB: cfg.RedisDB, // Connection pooling optimizations PoolSize: 50, // Increase connection pool size MinIdleConns: 10, // Maintain minimum idle connections MaxRetries: 3, // Reduce retries for faster failure DialTimeout: 5 * time.Second, // Faster connection timeout ReadTimeout: 3 * time.Second, // Faster read timeout WriteTimeout: 3 * time.Second, // Faster write timeout PoolTimeout: 4 * time.Second, // Faster pool timeout } } rdb := redis.NewClient(opts) ctx, cancel := context.WithCancel(context.Background()) if err := rdb.Ping(ctx).Err(); err != nil { cancel() return nil, fmt.Errorf("redis connection failed: %w", err) } flushEvery := cfg.MetricsFlushInterval if flushEvery == 0 { flushEvery = defaultMetricsFlushInterval } tq := &TaskQueue{ client: rdb, ctx: ctx, cancel: cancel, metricsCh: make(chan metricEvent, 256), metricsDone: make(chan struct{}), flushEvery: flushEvery, } go tq.runMetricsBuffer() go tq.runLeaseReclamation() // Start lease reclamation background job return tq, nil } // AddTask adds a new task to the queue with default retry settings func (tq *TaskQueue) AddTask(task *Task) error { // Set default retry settings if not specified if task.MaxRetries == 0 { task.MaxRetries = defaultMaxRetries } taskData, err := json.Marshal(task) if err != nil { return fmt.Errorf("failed to marshal task: %w", err) } pipe := tq.client.Pipeline() // Store task data pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour) // Add to priority queue (ZSET) // Use priority as score (higher priority = higher score) pipe.ZAddArgs(tq.ctx, TaskQueueKey, redis.ZAddArgs{ Members: []redis.Z{ { Score: float64(task.Priority), Member: task.ID, }, }, }) // Initialize status pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName, "status", task.Status, "task_id", task.ID, "updated_at", time.Now().Format(time.RFC3339)) // Queue depth update piggybacks on the same pipeline round-trip to avoid // an extra Redis call per enqueue under high load. depthCmd := pipe.ZCard(tq.ctx, TaskQueueKey) _, err = pipe.Exec(tq.ctx) if err != nil { return fmt.Errorf("failed to enqueue task: %w", err) } // Record metrics TasksQueued.Inc() // Update queue depth UpdateQueueDepth(depthCmd.Val()) return nil } func isBlockingUnsupported(err error) bool { if err == nil { return false } msg := strings.ToLower(err.Error()) return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax") } func (tq *TaskQueue) pollUntilDeadline(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) { deadline := time.Now().Add(blockTimeout) sleep := 25 * time.Millisecond for { if time.Now().After(deadline) { return nil, nil } task, err := tq.GetNextTaskWithLease(workerID, leaseDuration) if err != nil { return nil, err } if task != nil { return task, nil } select { case <-tq.ctx.Done(): return nil, tq.ctx.Err() case <-time.After(sleep): } } } // GetNextTask gets the next task without lease (backward compatible) func (tq *TaskQueue) GetNextTask() (*Task, error) { result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result() if err != nil { return nil, err } if len(result) == 0 { return nil, nil } taskID := result[0].Member.(string) return tq.GetTask(taskID) } // GetNextTaskWithLease gets the next task and acquires a lease func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) { if leaseDuration == 0 { leaseDuration = defaultLeaseDuration } // Pop highest priority task result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result() if err != nil { return nil, err } if len(result) == 0 { return nil, nil } taskID := result[0].Member.(string) task, err := tq.GetTask(taskID) if err != nil { // Re-queue the task if we can't fetch it tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{ Score: result[0].Score, Member: taskID, }) return nil, err } // Acquire lease now := time.Now() leaseExpiry := now.Add(leaseDuration) task.LeaseExpiry = &leaseExpiry task.LeasedBy = workerID // Update task with lease if err := tq.UpdateTask(task); err != nil { // Re-queue if update fails tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{ Score: result[0].Score, Member: taskID, }) return nil, err } return task, nil } // GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task before acquiring a lease. func (tq *TaskQueue) GetNextTaskWithLeaseBlocking( workerID string, leaseDuration, blockTimeout time.Duration, ) (*Task, error) { if leaseDuration == 0 { leaseDuration = defaultLeaseDuration } if blockTimeout <= 0 { blockTimeout = defaultBlockTimeout } result, err := tq.client.BZPopMax(tq.ctx, blockTimeout, TaskQueueKey).Result() if err == redis.Nil { return nil, nil } if err != nil { if isBlockingUnsupported(err) { return tq.pollUntilDeadline(workerID, leaseDuration, blockTimeout) } return nil, err } taskID, ok := result.Member.(string) if !ok { return nil, fmt.Errorf("unexpected task id type %T", result.Member) } task, err := tq.GetTask(taskID) if err != nil { return nil, err } now := time.Now() leaseExpiry := now.Add(leaseDuration) task.LeaseExpiry = &leaseExpiry task.LeasedBy = workerID if err := tq.UpdateTask(task); err != nil { return nil, err } return task, nil } // RenewLease renews the lease on a task (heartbeat) func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error { if leaseDuration == 0 { leaseDuration = defaultLeaseDuration } task, err := tq.GetTask(taskID) if err != nil { return err } // Verify the worker owns the lease if task.LeasedBy != workerID { return fmt.Errorf("task leased by different worker: %s", task.LeasedBy) } // Renew lease leaseExpiry := time.Now().Add(leaseDuration) task.LeaseExpiry = &leaseExpiry // Record renewal metric RecordLeaseRenewal(workerID) return tq.UpdateTask(task) } // ReleaseLease releases the lease on a task func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error { task, err := tq.GetTask(taskID) if err != nil { return err } // Verify the worker owns the lease if task.LeasedBy != workerID { return fmt.Errorf("task leased by different worker: %s", task.LeasedBy) } // Clear lease task.LeaseExpiry = nil task.LeasedBy = "" return tq.UpdateTask(task) } // RetryTask re-queues a failed task with smart backoff based on error category func (tq *TaskQueue) RetryTask(task *Task) error { if task.RetryCount >= task.MaxRetries { // Move to dead letter queue RecordDLQAddition("max_retries") return tq.MoveToDeadLetterQueue(task, "max retries exceeded") } // Classify the error if it exists errorCategory := ErrorUnknown if task.Error != "" { errorCategory = ClassifyError(fmt.Errorf("%s", task.Error)) } // Check if error is retryable if !IsRetryable(errorCategory) { RecordDLQAddition(string(errorCategory)) return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory)) } task.RetryCount++ task.Status = "queued" task.LastError = task.Error // Preserve last error task.Error = "" // Clear current error // Calculate smart backoff based on error category backoffSeconds := RetryDelay(errorCategory, task.RetryCount) nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second) task.NextRetry = &nextRetry // Clear lease task.LeaseExpiry = nil task.LeasedBy = "" // Record retry metrics RecordTaskRetry(task.JobName, errorCategory) // Re-queue with same priority return tq.AddTask(task) } // MoveToDeadLetterQueue moves a task to the dead letter queue func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error { task.Status = "failed" task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError) taskData, err := json.Marshal(task) if err != nil { return err } // Store in dead letter queue with timestamp key := "task:dlq:" + task.ID // Record metrics RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError))) pipe := tq.client.Pipeline() pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour) pipe.ZRem(tq.ctx, TaskQueueKey, task.ID) _, err = pipe.Exec(tq.ctx) return err } // runLeaseReclamation reclaims expired leases every 1 minute func (tq *TaskQueue) runLeaseReclamation() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-tq.ctx.Done(): return case <-ticker.C: if err := tq.reclaimExpiredLeases(); err != nil { // Log error but continue continue } } } } // reclaimExpiredLeases finds and re-queues tasks with expired leases func (tq *TaskQueue) reclaimExpiredLeases() error { // Scan for all task keys iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator() now := time.Now() for iter.Next(tq.ctx) { taskKey := iter.Val() taskID := taskKey[len(TaskPrefix):] task, err := tq.GetTask(taskID) if err != nil { continue } // Check if lease expired and task is still running if task.LeaseExpiry != nil && task.LeaseExpiry.Before(now) && task.Status == "running" { // Lease expired - retry or fail the task task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy) // Record lease expiration RecordLeaseExpiration() if task.RetryCount < task.MaxRetries { // Retry the task if err := tq.RetryTask(task); err != nil { continue } } else { // Max retries exceeded - move to DLQ if err := tq.MoveToDeadLetterQueue(task, "lease expiry after max retries"); err != nil { continue } } } } return iter.Err() } // GetTask retrieves a task by ID func (tq *TaskQueue) GetTask(taskID string) (*Task, error) { data, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result() if err != nil { return nil, err } var task Task if err := json.Unmarshal([]byte(data), &task); err != nil { return nil, err } return &task, nil } // GetAllTasks retrieves all tasks from the queue func (tq *TaskQueue) GetAllTasks() ([]*Task, error) { // Get all task keys keys, err := tq.client.Keys(tq.ctx, TaskPrefix+"*").Result() if err != nil { return nil, err } var tasks []*Task for _, key := range keys { data, err := tq.client.Get(tq.ctx, key).Result() if err != nil { continue // Skip tasks that can't be retrieved } var task Task if err := json.Unmarshal([]byte(data), &task); err != nil { continue // Skip malformed tasks } tasks = append(tasks, &task) } return tasks, nil } // GetTaskByName retrieves a task by its job name func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) { tasks, err := tq.GetAllTasks() if err != nil { return nil, err } for _, task := range tasks { if task.JobName == jobName { return task, nil } } return nil, fmt.Errorf("task with job name '%s' not found", jobName) } // CancelTask marks a task as cancelled func (tq *TaskQueue) CancelTask(taskID string) error { task, err := tq.GetTask(taskID) if err != nil { return err } // Update task status to cancelled task.Status = "cancelled" now := time.Now() task.EndedAt = &now return tq.UpdateTask(task) } // UpdateTask updates a task in Redis func (tq *TaskQueue) UpdateTask(task *Task) error { taskData, err := json.Marshal(task) if err != nil { return err } pipe := tq.client.Pipeline() pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour) pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName, "status", task.Status, "task_id", task.ID, "updated_at", time.Now().Format(time.RFC3339)) _, err = pipe.Exec(tq.ctx) return err } // UpdateTaskWithMetrics updates task and records metrics func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error { if err := tq.UpdateTask(task); err != nil { return err } metricName := "tasks_" + action return tq.RecordMetric(task.JobName, metricName, 1) } // RecordMetric records a metric value func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error { evt := metricEvent{JobName: jobName, Metric: metric, Value: value} select { case tq.metricsCh <- evt: return nil default: return tq.writeMetrics(jobName, map[string]float64{metric: value}) } } // Heartbeat records worker heartbeat func (tq *TaskQueue) Heartbeat(workerID string) error { return tq.client.HSet(tq.ctx, WorkerHeartbeat, workerID, time.Now().Unix()).Err() } // QueueDepth returns the number of pending tasks func (tq *TaskQueue) QueueDepth() (int64, error) { return tq.client.ZCard(tq.ctx, TaskQueueKey).Result() } // Close closes the task queue and cleans up resources func (tq *TaskQueue) Close() error { tq.cancel() <-tq.metricsDone // Wait for metrics buffer to finish return tq.client.Close() } // GetRedisClient returns the underlying Redis client for direct access func (tq *TaskQueue) GetRedisClient() *redis.Client { return tq.client } // WaitForNextTask waits for next task with timeout func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) { if ctx == nil { ctx = tq.ctx } result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result() if err == redis.Nil { return nil, nil } if err != nil { return nil, err } member, ok := result.Member.(string) if !ok { return nil, fmt.Errorf("unexpected task id type %T", result.Member) } return tq.GetTask(member) } // runMetricsBuffer buffers and flushes metrics func (tq *TaskQueue) runMetricsBuffer() { defer close(tq.metricsDone) ticker := time.NewTicker(tq.flushEvery) defer ticker.Stop() pending := make(map[string]map[string]float64) flush := func() { for job, metrics := range pending { if err := tq.writeMetrics(job, metrics); err != nil { continue } delete(pending, job) } } for { select { case <-tq.ctx.Done(): flush() return case evt, ok := <-tq.metricsCh: if !ok { flush() return } if _, exists := pending[evt.JobName]; !exists { pending[evt.JobName] = make(map[string]float64) } pending[evt.JobName][evt.Metric] = evt.Value case <-ticker.C: flush() } } } // writeMetrics writes metrics to Redis func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error { if len(metrics) == 0 { return nil } key := JobMetricsPrefix + jobName args := make([]any, 0, len(metrics)*2+2) args = append(args, "timestamp", time.Now().Unix()) for metric, value := range metrics { args = append(args, metric, value) } return tq.client.HSet(context.Background(), key, args...).Err() }