fetch_ml/internal/queue/queue.go

package queue

import (
    "context"
    "encoding/json"
    "fmt"
    "strings"
    "time"

    "github.com/redis/go-redis/v9"
)

const (
    defaultMetricsFlushInterval = 500 * time.Millisecond
    defaultLeaseDuration        = 30 * time.Minute
    defaultMaxRetries           = 3
    defaultBlockTimeout         = 1 * time.Second
)

// TaskQueue manages ML experiment tasks via Redis.
type TaskQueue struct {
    client      *redis.Client
    ctx         context.Context
    cancel      context.CancelFunc
    metricsCh   chan metricEvent
    metricsDone chan struct{}
    flushEvery  time.Duration
}

type metricEvent struct {
    JobName string
    Metric  string
    Value   float64
}

// Config holds configuration for TaskQueue.
type Config struct {
    RedisAddr            string
    RedisPassword        string
    RedisDB              int
    MetricsFlushInterval time.Duration
}

// NewTaskQueue creates a new task queue instance with a Redis backend.
func NewTaskQueue(cfg Config) (*TaskQueue, error) {
    var opts *redis.Options
    var err error
    if strings.HasPrefix(cfg.RedisAddr, "redis://") {
        opts, err = redis.ParseURL(cfg.RedisAddr)
        if err != nil {
            return nil, fmt.Errorf("invalid redis url: %w", err)
        }
    } else {
        opts = &redis.Options{
            Addr:     cfg.RedisAddr,
            Password: cfg.RedisPassword,
            DB:       cfg.RedisDB,
            // Connection pool tuned for high enqueue/dequeue throughput.
            PoolSize:     50, // larger connection pool
            MinIdleConns: 10, // keep warm connections available
            MaxRetries:   3,  // bounded retries so failures surface quickly
            DialTimeout:  5 * time.Second,
            ReadTimeout:  3 * time.Second,
            WriteTimeout: 3 * time.Second,
            PoolTimeout:  4 * time.Second,
        }
    }
    rdb := redis.NewClient(opts)
    ctx, cancel := context.WithCancel(context.Background())
    if err := rdb.Ping(ctx).Err(); err != nil {
        cancel()
        return nil, fmt.Errorf("redis connection failed: %w", err)
    }
    flushEvery := cfg.MetricsFlushInterval
    if flushEvery == 0 {
        flushEvery = defaultMetricsFlushInterval
    }
    tq := &TaskQueue{
        client:      rdb,
        ctx:         ctx,
        cancel:      cancel,
        metricsCh:   make(chan metricEvent, 256),
        metricsDone: make(chan struct{}),
        flushEvery:  flushEvery,
    }
    go tq.runMetricsBuffer()
    go tq.runLeaseReclamation() // reclaim expired leases in the background
    return tq, nil
}
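
// A minimal construction sketch (illustrative; the address and flush interval
// below are assumptions, not defaults required by this package):
//
//    tq, err := NewTaskQueue(Config{
//        RedisAddr:            "localhost:6379",
//        MetricsFlushInterval: time.Second,
//    })
//    if err != nil {
//        log.Fatalf("queue init failed: %v", err)
//    }
//    defer tq.Close()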

// AddTask adds a new task to the queue with default retry settings.
func (tq *TaskQueue) AddTask(task *Task) error {
    // Apply the default retry limit if none was specified.
    if task.MaxRetries == 0 {
        task.MaxRetries = defaultMaxRetries
    }
    taskData, err := json.Marshal(task)
    if err != nil {
        return fmt.Errorf("failed to marshal task: %w", err)
    }
    pipe := tq.client.Pipeline()
    // Store task data with a 7-day TTL.
    pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
    // Add to the priority queue (ZSET); higher priority = higher score.
    pipe.ZAddArgs(tq.ctx, TaskQueueKey, redis.ZAddArgs{
        Members: []redis.Z{
            {
                Score:  float64(task.Priority),
                Member: task.ID,
            },
        },
    })
    // Initialize status.
    pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
        "status", task.Status,
        "task_id", task.ID,
        "updated_at", time.Now().Format(time.RFC3339))
    // The queue-depth read piggybacks on the same pipeline round-trip to
    // avoid an extra Redis call per enqueue under high load.
    depthCmd := pipe.ZCard(tq.ctx, TaskQueueKey)
    if _, err = pipe.Exec(tq.ctx); err != nil {
        return fmt.Errorf("failed to enqueue task: %w", err)
    }
    // Record metrics and the new queue depth.
    TasksQueued.Inc()
    UpdateQueueDepth(depthCmd.Val())
    return nil
}
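
// A minimal enqueue sketch (illustrative; the Task struct is defined elsewhere
// in this package, and the field values here are assumptions):
//
//    task := &Task{
//        ID:       "task-123",
//        JobName:  "resnet50-sweep",
//        Priority: 5,
//        Status:   "queued",
//    }
//    if err := tq.AddTask(task); err != nil {
//        log.Printf("enqueue failed: %v", err)
//    }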

// isBlockingUnsupported reports whether the error indicates the Redis server
// does not support the BZPOPMAX command.
func isBlockingUnsupported(err error) bool {
    if err == nil {
        return false
    }
    msg := strings.ToLower(err.Error())
    return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax")
}

// pollUntilDeadline emulates a blocking pop by polling GetNextTaskWithLease
// until a task arrives or blockTimeout elapses.
func (tq *TaskQueue) pollUntilDeadline(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
    deadline := time.Now().Add(blockTimeout)
    sleep := 25 * time.Millisecond
    for {
        if time.Now().After(deadline) {
            return nil, nil
        }
        task, err := tq.GetNextTaskWithLease(workerID, leaseDuration)
        if err != nil {
            return nil, err
        }
        if task != nil {
            return task, nil
        }
        select {
        case <-tq.ctx.Done():
            return nil, tq.ctx.Err()
        case <-time.After(sleep):
        }
    }
}

// GetNextTask gets the next task without acquiring a lease (backward compatible).
func (tq *TaskQueue) GetNextTask() (*Task, error) {
    result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
    if err != nil {
        return nil, err
    }
    if len(result) == 0 {
        return nil, nil
    }
    taskID := result[0].Member.(string)
    return tq.GetTask(taskID)
}

// GetNextTaskWithLease pops the highest-priority task and acquires a lease on it.
func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
    if leaseDuration == 0 {
        leaseDuration = defaultLeaseDuration
    }
    // Pop the highest-priority task.
    result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
    if err != nil {
        return nil, err
    }
    if len(result) == 0 {
        return nil, nil
    }
    taskID := result[0].Member.(string)
    task, err := tq.GetTask(taskID)
    if err != nil {
        // Re-queue the task if we can't fetch it.
        tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
            Score:  result[0].Score,
            Member: taskID,
        })
        return nil, err
    }
    // Acquire the lease.
    now := time.Now()
    leaseExpiry := now.Add(leaseDuration)
    task.LeaseExpiry = &leaseExpiry
    task.LeasedBy = workerID
    // Persist the lease.
    if err := tq.UpdateTask(task); err != nil {
        // Re-queue if the update fails.
        tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
            Score:  result[0].Score,
            Member: taskID,
        })
        return nil, err
    }
    return task, nil
}

// GetNextTaskWithLeaseBlocking blocks for up to blockTimeout waiting for a task,
// then acquires a lease. It falls back to polling if the server lacks BZPOPMAX.
func (tq *TaskQueue) GetNextTaskWithLeaseBlocking(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
    if leaseDuration == 0 {
        leaseDuration = defaultLeaseDuration
    }
    if blockTimeout <= 0 {
        blockTimeout = defaultBlockTimeout
    }
    result, err := tq.client.BZPopMax(tq.ctx, blockTimeout, TaskQueueKey).Result()
    if err == redis.Nil {
        return nil, nil
    }
    if err != nil {
        if isBlockingUnsupported(err) {
            return tq.pollUntilDeadline(workerID, leaseDuration, blockTimeout)
        }
        return nil, err
    }
    taskID, ok := result.Member.(string)
    if !ok {
        return nil, fmt.Errorf("unexpected task id type %T", result.Member)
    }
    task, err := tq.GetTask(taskID)
    if err != nil {
        return nil, err
    }
    now := time.Now()
    leaseExpiry := now.Add(leaseDuration)
    task.LeaseExpiry = &leaseExpiry
    task.LeasedBy = workerID
    if err := tq.UpdateTask(task); err != nil {
        return nil, err
    }
    return task, nil
}
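
// A minimal worker-loop sketch combining the blocking pop with lease renewal
// (illustrative; processTask, the worker ID, and the durations are assumptions):
//
//    const workerID = "worker-1"
//    for {
//        task, err := tq.GetNextTaskWithLeaseBlocking(workerID, 30*time.Minute, time.Second)
//        if err != nil || task == nil {
//            continue // timeout or transient error; poll again
//        }
//        done := make(chan struct{})
//        go func() { // heartbeat: renew the lease while the task runs
//            t := time.NewTicker(10 * time.Minute)
//            defer t.Stop()
//            for {
//                select {
//                case <-done:
//                    return
//                case <-t.C:
//                    _ = tq.RenewLease(task.ID, workerID, 30*time.Minute)
//                }
//            }
//        }()
//        err = processTask(task)
//        close(done)
//        if err != nil {
//            task.Error = err.Error()
//            _ = tq.RetryTask(task) // backoff re-queue, or DLQ when exhausted
//        } else {
//            _ = tq.ReleaseLease(task.ID, workerID)
//        }
//    }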

// RenewLease renews the lease on a task (worker heartbeat).
func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
    if leaseDuration == 0 {
        leaseDuration = defaultLeaseDuration
    }
    task, err := tq.GetTask(taskID)
    if err != nil {
        return err
    }
    // Verify the worker owns the lease.
    if task.LeasedBy != workerID {
        return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
    }
    // Renew the lease.
    leaseExpiry := time.Now().Add(leaseDuration)
    task.LeaseExpiry = &leaseExpiry
    // Record the renewal metric.
    RecordLeaseRenewal(workerID)
    return tq.UpdateTask(task)
}

// ReleaseLease releases the lease on a task.
func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
    task, err := tq.GetTask(taskID)
    if err != nil {
        return err
    }
    // Verify the worker owns the lease.
    if task.LeasedBy != workerID {
        return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
    }
    // Clear the lease.
    task.LeaseExpiry = nil
    task.LeasedBy = ""
    return tq.UpdateTask(task)
}

// RetryTask re-queues a failed task with smart backoff based on error category.
func (tq *TaskQueue) RetryTask(task *Task) error {
    if task.RetryCount >= task.MaxRetries {
        // Retries exhausted: move to the dead letter queue.
        RecordDLQAddition("max_retries")
        return tq.MoveToDeadLetterQueue(task, "max retries exceeded")
    }
    // Classify the error if one was recorded.
    errorCategory := ErrorUnknown
    if task.Error != "" {
        errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
    }
    // Non-retryable errors go straight to the DLQ.
    if !IsRetryable(errorCategory) {
        RecordDLQAddition(string(errorCategory))
        return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
    }
    task.RetryCount++
    task.Status = "queued"
    task.LastError = task.Error // preserve the last error
    task.Error = ""             // clear the current error
    // Calculate smart backoff based on the error category.
    backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
    nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
    task.NextRetry = &nextRetry
    // Clear the lease.
    task.LeaseExpiry = nil
    task.LeasedBy = ""
    // Record retry metrics.
    RecordTaskRetry(task.JobName, errorCategory)
    // Re-queue with the same priority.
    return tq.AddTask(task)
}

// MoveToDeadLetterQueue moves a task to the dead letter queue.
func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
    task.Status = "failed"
    task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
    taskData, err := json.Marshal(task)
    if err != nil {
        return err
    }
    // Store in the dead letter queue with a 30-day TTL.
    key := "task:dlq:" + task.ID
    // Record metrics.
    RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
    pipe := tq.client.Pipeline()
    pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour)
    pipe.ZRem(tq.ctx, TaskQueueKey, task.ID)
    _, err = pipe.Exec(tq.ctx)
    return err
}

// runLeaseReclamation reclaims expired leases once per minute.
func (tq *TaskQueue) runLeaseReclamation() {
    ticker := time.NewTicker(1 * time.Minute)
    defer ticker.Stop()
    for {
        select {
        case <-tq.ctx.Done():
            return
        case <-ticker.C:
            if err := tq.reclaimExpiredLeases(); err != nil {
                // Swallow the error and try again on the next tick.
                continue
            }
        }
    }
}

// reclaimExpiredLeases finds and re-queues tasks whose leases have expired.
func (tq *TaskQueue) reclaimExpiredLeases() error {
    // Scan task keys incrementally (SCAN avoids blocking Redis).
    iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator()
    now := time.Now()
    for iter.Next(tq.ctx) {
        taskKey := iter.Val()
        taskID := taskKey[len(TaskPrefix):]
        task, err := tq.GetTask(taskID)
        if err != nil {
            continue
        }
        // Only running tasks with an expired lease need reclamation.
        if task.LeaseExpiry != nil && task.LeaseExpiry.Before(now) && task.Status == "running" {
            task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy)
            RecordLeaseExpiration()
            if task.RetryCount < task.MaxRetries {
                // Retry the task.
                if err := tq.RetryTask(task); err != nil {
                    continue
                }
            } else {
                // Max retries exceeded: move to the DLQ.
                if err := tq.MoveToDeadLetterQueue(task, "lease expiry after max retries"); err != nil {
                    continue
                }
            }
        }
    }
    return iter.Err()
}

// GetTask retrieves a task by ID.
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
    data, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result()
    if err != nil {
        return nil, err
    }
    var task Task
    if err := json.Unmarshal([]byte(data), &task); err != nil {
        return nil, err
    }
    return &task, nil
}

// GetAllTasks retrieves all tasks from the queue. Note that KEYS is O(N) and
// blocks Redis while it runs; it is acceptable for small queues, but a
// SCAN-based iteration is preferable at scale.
func (tq *TaskQueue) GetAllTasks() ([]*Task, error) {
    keys, err := tq.client.Keys(tq.ctx, TaskPrefix+"*").Result()
    if err != nil {
        return nil, err
    }
    var tasks []*Task
    for _, key := range keys {
        data, err := tq.client.Get(tq.ctx, key).Result()
        if err != nil {
            continue // skip tasks that can't be retrieved
        }
        var task Task
        if err := json.Unmarshal([]byte(data), &task); err != nil {
            continue // skip malformed tasks
        }
        tasks = append(tasks, &task)
    }
    return tasks, nil
}

// GetTaskByName retrieves a task by its job name via a linear scan of all tasks.
func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) {
    tasks, err := tq.GetAllTasks()
    if err != nil {
        return nil, err
    }
    for _, task := range tasks {
        if task.JobName == jobName {
            return task, nil
        }
    }
    return nil, fmt.Errorf("task with job name '%s' not found", jobName)
}

// CancelTask marks a task as cancelled.
func (tq *TaskQueue) CancelTask(taskID string) error {
    task, err := tq.GetTask(taskID)
    if err != nil {
        return err
    }
    task.Status = "cancelled"
    now := time.Now()
    task.EndedAt = &now
    return tq.UpdateTask(task)
}

// UpdateTask persists a task and its status hash to Redis in one pipeline.
func (tq *TaskQueue) UpdateTask(task *Task) error {
    taskData, err := json.Marshal(task)
    if err != nil {
        return err
    }
    pipe := tq.client.Pipeline()
    pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
    pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
        "status", task.Status,
        "task_id", task.ID,
        "updated_at", time.Now().Format(time.RFC3339))
    _, err = pipe.Exec(tq.ctx)
    return err
}

// UpdateTaskWithMetrics updates a task and records a corresponding metric.
func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error {
    if err := tq.UpdateTask(task); err != nil {
        return err
    }
    metricName := "tasks_" + action
    return tq.RecordMetric(task.JobName, metricName, 1)
}

// RecordMetric records a metric value. The write is buffered through the
// metrics channel when there is room; if the buffer is full it falls back
// to a synchronous write so the metric is not dropped.
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
    evt := metricEvent{JobName: jobName, Metric: metric, Value: value}
    select {
    case tq.metricsCh <- evt:
        return nil
    default:
        return tq.writeMetrics(jobName, map[string]float64{metric: value})
    }
}
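
// A minimal metrics sketch (the job and metric names are illustrative):
//
//    _ = tq.RecordMetric("resnet50-sweep", "epoch", 12)
//    _ = tq.RecordMetric("resnet50-sweep", "val_loss", 0.142)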

// Heartbeat records a worker heartbeat timestamp.
func (tq *TaskQueue) Heartbeat(workerID string) error {
    return tq.client.HSet(tq.ctx, WorkerHeartbeat,
        workerID, time.Now().Unix()).Err()
}

// QueueDepth returns the number of pending tasks.
func (tq *TaskQueue) QueueDepth() (int64, error) {
    return tq.client.ZCard(tq.ctx, TaskQueueKey).Result()
}

// Close shuts down the task queue and cleans up resources.
func (tq *TaskQueue) Close() error {
    tq.cancel()
    <-tq.metricsDone // wait for the metrics buffer to flush and exit
    return tq.client.Close()
}

// GetRedisClient returns the underlying Redis client for direct access.
func (tq *TaskQueue) GetRedisClient() *redis.Client {
    return tq.client
}

// WaitForNextTask waits for the next task with a timeout. Unlike the
// lease-acquiring variants, it does not lease the task it returns.
func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) {
    if ctx == nil {
        ctx = tq.ctx
    }
    result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result()
    if err == redis.Nil {
        return nil, nil
    }
    if err != nil {
        return nil, err
    }
    member, ok := result.Member.(string)
    if !ok {
        return nil, fmt.Errorf("unexpected task id type %T", result.Member)
    }
    return tq.GetTask(member)
}

// runMetricsBuffer coalesces metric events per job and flushes them to Redis
// on a ticker, on shutdown, or if the channel is closed. Within a flush
// window, the latest value for each metric wins.
func (tq *TaskQueue) runMetricsBuffer() {
    defer close(tq.metricsDone)
    ticker := time.NewTicker(tq.flushEvery)
    defer ticker.Stop()
    pending := make(map[string]map[string]float64)
    flush := func() {
        for job, metrics := range pending {
            if err := tq.writeMetrics(job, metrics); err != nil {
                continue // keep the entry and retry on the next flush
            }
            delete(pending, job)
        }
    }
    for {
        select {
        case <-tq.ctx.Done():
            flush()
            return
        case evt, ok := <-tq.metricsCh:
            if !ok {
                flush()
                return
            }
            if _, exists := pending[evt.JobName]; !exists {
                pending[evt.JobName] = make(map[string]float64)
            }
            pending[evt.JobName][evt.Metric] = evt.Value
        case <-ticker.C:
            flush()
        }
    }
}

// writeMetrics writes a batch of metrics to the job's metrics hash. It uses
// context.Background() rather than tq.ctx so the final flush still succeeds
// after Close has cancelled the queue context.
func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error {
    if len(metrics) == 0 {
        return nil
    }
    key := JobMetricsPrefix + jobName
    args := make([]any, 0, len(metrics)*2+2)
    args = append(args, "timestamp", time.Now().Unix())
    for metric, value := range metrics {
        args = append(args, metric, value)
    }
    return tq.client.HSet(context.Background(), key, args...).Err()
}