- Fix YAML tags in auth config struct (json -> yaml) - Update CLI configs to use pre-hashed API keys - Remove double hashing in WebSocket client - Fix port mapping (9102 -> 9103) in CLI commands - Update permission keys to use jobs:read, jobs:create, etc. - Clean up all debug logging from CLI and server - All user roles now authenticate correctly: * Admin: Can queue jobs and see all jobs * Researcher: Can queue jobs and see own jobs * Analyst: Can see status (read-only access) Multi-user authentication is now fully functional.
651 lines
16 KiB
Go
651 lines
16 KiB
Go
package queue
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Default tuning values, applied when the caller leaves the
// corresponding setting at its zero value.
const (
	defaultMetricsFlushInterval = 500 * time.Millisecond // cadence of the buffered metrics flush
	defaultLeaseDuration = 30 * time.Minute // how long a worker owns a task before the lease expires
	defaultMaxRetries = 3 // attempts before a task is moved to the dead letter queue
	defaultBlockTimeout = 1 * time.Second // max wait for a blocking pop (BZPOPMAX)
)
|
|
|
|
// TaskQueue manages ML experiment tasks via Redis.
//
// It owns two background goroutines (metrics buffering and lease
// reclamation) that are started by NewTaskQueue and stopped by Close.
type TaskQueue struct {
	client *redis.Client // shared Redis connection (pooled)
	ctx context.Context // lifecycle context for all queue operations
	cancel context.CancelFunc // cancels ctx; signals background goroutines to stop
	metricsCh chan metricEvent // buffered channel feeding runMetricsBuffer
	metricsDone chan struct{} // closed when runMetricsBuffer has fully drained
	flushEvery time.Duration // interval between metric flushes to Redis
}
|
|
|
|
// metricEvent is a single buffered metric sample, keyed by job and
// metric name, carried from RecordMetric to the flush goroutine.
type metricEvent struct {
	JobName string // job the sample belongs to
	Metric string // metric name (hash field in Redis)
	Value float64 // sample value; last write wins within a flush window
}
|
|
|
|
// Config holds configuration for TaskQueue.
type Config struct {
	RedisAddr string // "host:port" or a full "redis://" URL
	RedisPassword string // ignored when RedisAddr is a redis:// URL
	RedisDB int // ignored when RedisAddr is a redis:// URL
	MetricsFlushInterval time.Duration // zero means defaultMetricsFlushInterval
}
|
|
|
|
// NewTaskQueue creates a new task queue instance with Redis backend.
|
|
func NewTaskQueue(cfg Config) (*TaskQueue, error) {
|
|
var opts *redis.Options
|
|
var err error
|
|
|
|
if len(cfg.RedisAddr) > 8 && cfg.RedisAddr[:8] == "redis://" {
|
|
opts, err = redis.ParseURL(cfg.RedisAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid redis url: %w", err)
|
|
}
|
|
} else {
|
|
opts = &redis.Options{
|
|
Addr: cfg.RedisAddr,
|
|
Password: cfg.RedisPassword,
|
|
DB: cfg.RedisDB,
|
|
// Connection pooling optimizations
|
|
PoolSize: 50, // Increase connection pool size
|
|
MinIdleConns: 10, // Maintain minimum idle connections
|
|
MaxRetries: 3, // Reduce retries for faster failure
|
|
DialTimeout: 5 * time.Second, // Faster connection timeout
|
|
ReadTimeout: 3 * time.Second, // Faster read timeout
|
|
WriteTimeout: 3 * time.Second, // Faster write timeout
|
|
PoolTimeout: 4 * time.Second, // Faster pool timeout
|
|
}
|
|
}
|
|
|
|
rdb := redis.NewClient(opts)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
if err := rdb.Ping(ctx).Err(); err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("redis connection failed: %w", err)
|
|
}
|
|
|
|
flushEvery := cfg.MetricsFlushInterval
|
|
if flushEvery == 0 {
|
|
flushEvery = defaultMetricsFlushInterval
|
|
}
|
|
|
|
tq := &TaskQueue{
|
|
client: rdb,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
metricsCh: make(chan metricEvent, 256),
|
|
metricsDone: make(chan struct{}),
|
|
flushEvery: flushEvery,
|
|
}
|
|
|
|
go tq.runMetricsBuffer()
|
|
go tq.runLeaseReclamation() // Start lease reclamation background job
|
|
|
|
return tq, nil
|
|
}
|
|
|
|
// AddTask adds a new task to the queue with default retry settings
|
|
func (tq *TaskQueue) AddTask(task *Task) error {
|
|
// Set default retry settings if not specified
|
|
if task.MaxRetries == 0 {
|
|
task.MaxRetries = defaultMaxRetries
|
|
}
|
|
|
|
taskData, err := json.Marshal(task)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal task: %w", err)
|
|
}
|
|
|
|
pipe := tq.client.Pipeline()
|
|
|
|
// Store task data
|
|
pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
|
|
|
|
// Add to priority queue (ZSET)
|
|
// Use priority as score (higher priority = higher score)
|
|
pipe.ZAddArgs(tq.ctx, TaskQueueKey, redis.ZAddArgs{
|
|
Members: []redis.Z{
|
|
{
|
|
Score: float64(task.Priority),
|
|
Member: task.ID,
|
|
},
|
|
},
|
|
})
|
|
|
|
// Initialize status
|
|
pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
|
|
"status", task.Status,
|
|
"task_id", task.ID,
|
|
"updated_at", time.Now().Format(time.RFC3339))
|
|
|
|
// Queue depth update piggybacks on the same pipeline round-trip to avoid
|
|
// an extra Redis call per enqueue under high load.
|
|
depthCmd := pipe.ZCard(tq.ctx, TaskQueueKey)
|
|
|
|
_, err = pipe.Exec(tq.ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to enqueue task: %w", err)
|
|
}
|
|
|
|
// Record metrics
|
|
TasksQueued.Inc()
|
|
|
|
// Update queue depth
|
|
UpdateQueueDepth(depthCmd.Val())
|
|
|
|
return nil
|
|
}
|
|
|
|
func isBlockingUnsupported(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
msg := strings.ToLower(err.Error())
|
|
return strings.Contains(msg, "unknown command") && strings.Contains(msg, "bzpopmax")
|
|
}
|
|
|
|
func (tq *TaskQueue) pollUntilDeadline(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
|
|
deadline := time.Now().Add(blockTimeout)
|
|
sleep := 25 * time.Millisecond
|
|
|
|
for {
|
|
if time.Now().After(deadline) {
|
|
return nil, nil
|
|
}
|
|
|
|
task, err := tq.GetNextTaskWithLease(workerID, leaseDuration)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if task != nil {
|
|
return task, nil
|
|
}
|
|
|
|
select {
|
|
case <-tq.ctx.Done():
|
|
return nil, tq.ctx.Err()
|
|
case <-time.After(sleep):
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetNextTask gets the next task without lease (backward compatible)
|
|
func (tq *TaskQueue) GetNextTask() (*Task, error) {
|
|
result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(result) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
taskID := result[0].Member.(string)
|
|
return tq.GetTask(taskID)
|
|
}
|
|
|
|
// GetNextTaskWithLease gets the next task and acquires a lease
|
|
func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
|
|
if leaseDuration == 0 {
|
|
leaseDuration = defaultLeaseDuration
|
|
}
|
|
|
|
// Pop highest priority task
|
|
result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(result) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
taskID := result[0].Member.(string)
|
|
task, err := tq.GetTask(taskID)
|
|
if err != nil {
|
|
// Re-queue the task if we can't fetch it
|
|
tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
|
|
Score: result[0].Score,
|
|
Member: taskID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
// Acquire lease
|
|
now := time.Now()
|
|
leaseExpiry := now.Add(leaseDuration)
|
|
task.LeaseExpiry = &leaseExpiry
|
|
task.LeasedBy = workerID
|
|
|
|
// Update task with lease
|
|
if err := tq.UpdateTask(task); err != nil {
|
|
// Re-queue if update fails
|
|
tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
|
|
Score: result[0].Score,
|
|
Member: taskID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
return task, nil
|
|
}
|
|
|
|
// GetNextTaskWithLeaseBlocking blocks up to blockTimeout waiting for a task before acquiring a lease.
|
|
func (tq *TaskQueue) GetNextTaskWithLeaseBlocking(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
|
|
if leaseDuration == 0 {
|
|
leaseDuration = defaultLeaseDuration
|
|
}
|
|
if blockTimeout <= 0 {
|
|
blockTimeout = defaultBlockTimeout
|
|
}
|
|
|
|
result, err := tq.client.BZPopMax(tq.ctx, blockTimeout, TaskQueueKey).Result()
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
if isBlockingUnsupported(err) {
|
|
return tq.pollUntilDeadline(workerID, leaseDuration, blockTimeout)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
taskID, ok := result.Member.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected task id type %T", result.Member)
|
|
}
|
|
|
|
task, err := tq.GetTask(taskID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now()
|
|
leaseExpiry := now.Add(leaseDuration)
|
|
task.LeaseExpiry = &leaseExpiry
|
|
task.LeasedBy = workerID
|
|
|
|
if err := tq.UpdateTask(task); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return task, nil
|
|
}
|
|
|
|
// RenewLease renews the lease on a task (heartbeat)
|
|
func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
|
|
if leaseDuration == 0 {
|
|
leaseDuration = defaultLeaseDuration
|
|
}
|
|
|
|
task, err := tq.GetTask(taskID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify the worker owns the lease
|
|
if task.LeasedBy != workerID {
|
|
return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
|
|
}
|
|
|
|
// Renew lease
|
|
leaseExpiry := time.Now().Add(leaseDuration)
|
|
task.LeaseExpiry = &leaseExpiry
|
|
|
|
// Record renewal metric
|
|
RecordLeaseRenewal(workerID)
|
|
|
|
return tq.UpdateTask(task)
|
|
}
|
|
|
|
// ReleaseLease releases the lease on a task
|
|
func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
|
|
task, err := tq.GetTask(taskID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify the worker owns the lease
|
|
if task.LeasedBy != workerID {
|
|
return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
|
|
}
|
|
|
|
// Clear lease
|
|
task.LeaseExpiry = nil
|
|
task.LeasedBy = ""
|
|
|
|
return tq.UpdateTask(task)
|
|
}
|
|
|
|
// RetryTask re-queues a failed task with smart backoff based on error category.
//
// Flow: if the retry budget is exhausted or the error is classified as
// non-retryable, the task goes to the dead letter queue. Otherwise the
// retry counter is bumped, the error is rolled into LastError, the lease
// is cleared, and the task is enqueued again at its original priority.
//
// NOTE(review): NextRetry is recorded on the task, but AddTask places it
// back in the queue immediately — the backoff delay is not enforced by
// the queue itself. Presumably workers check NextRetry before executing;
// TODO confirm.
func (tq *TaskQueue) RetryTask(task *Task) error {
	// Retry budget exhausted: dead-letter instead of retrying.
	if task.RetryCount >= task.MaxRetries {
		RecordDLQAddition("max_retries")
		return tq.MoveToDeadLetterQueue(task, "max retries exceeded")
	}

	// Classify the recorded error (if any) to drive retry policy.
	errorCategory := ErrorUnknown
	if task.Error != "" {
		errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
	}

	// Non-retryable categories (per IsRetryable) go straight to the DLQ.
	if !IsRetryable(errorCategory) {
		RecordDLQAddition(string(errorCategory))
		return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
	}

	task.RetryCount++
	task.Status = "queued"
	task.LastError = task.Error // preserve last error for diagnostics
	task.Error = ""             // clear current error

	// Category-aware backoff: RetryDelay returns seconds for this
	// category/attempt pair.
	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
	nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
	task.NextRetry = &nextRetry

	// Clear any stale lease so another worker may pick the task up.
	task.LeaseExpiry = nil
	task.LeasedBy = ""

	RecordTaskRetry(task.JobName, errorCategory)

	// Re-queue with same priority (AddTask re-reads task.Priority).
	return tq.AddTask(task)
}
|
|
|
|
// MoveToDeadLetterQueue moves a task to the dead letter queue
|
|
func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
|
|
task.Status = "failed"
|
|
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
|
|
|
|
taskData, err := json.Marshal(task)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Store in dead letter queue with timestamp
|
|
key := "task:dlq:" + task.ID
|
|
|
|
// Record metrics
|
|
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
|
|
|
|
pipe := tq.client.Pipeline()
|
|
pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour)
|
|
pipe.ZRem(tq.ctx, TaskQueueKey, task.ID)
|
|
_, err = pipe.Exec(tq.ctx)
|
|
return err
|
|
}
|
|
|
|
// runLeaseReclamation reclaims expired leases every 1 minute
|
|
func (tq *TaskQueue) runLeaseReclamation() {
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-tq.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if err := tq.reclaimExpiredLeases(); err != nil {
|
|
// Log error but continue
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// reclaimExpiredLeases finds and re-queues tasks with expired leases.
//
// It SCANs all task keys (cursor-based, so it does not block Redis),
// and for each task that is still "running" past its lease expiry,
// either retries it (budget remaining) or moves it to the DLQ.
// Per-task errors are skipped so one bad record cannot stall the sweep.
func (tq *TaskQueue) reclaimExpiredLeases() error {
	// Cursor-based scan over every task key; count 100 per batch.
	iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator()
	now := time.Now()

	for iter.Next(tq.ctx) {
		taskKey := iter.Val()
		// Strip the prefix to recover the raw task ID.
		taskID := taskKey[len(TaskPrefix):]

		task, err := tq.GetTask(taskID)
		if err != nil {
			continue // unreadable task: skip, keep sweeping
		}

		// Only reclaim tasks that are both leased-out and stale:
		// lease expiry in the past AND status still "running".
		if task.LeaseExpiry != nil && task.LeaseExpiry.Before(now) && task.Status == "running" {
			// Record which worker lost the lease for diagnostics.
			task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy)

			RecordLeaseExpiration()

			if task.RetryCount < task.MaxRetries {
				// Budget remaining: hand back to RetryTask, which
				// re-queues with backoff bookkeeping.
				if err := tq.RetryTask(task); err != nil {
					continue
				}
			} else {
				// Budget exhausted: dead-letter the task.
				if err := tq.MoveToDeadLetterQueue(task, "lease expiry after max retries"); err != nil {
					continue
				}
			}
		}
	}

	// Surface any scan-level error (per-task errors were skipped above).
	return iter.Err()
}
|
|
|
|
// GetTask retrieves a task by ID
|
|
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
|
|
data, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var task Task
|
|
if err := json.Unmarshal([]byte(data), &task); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &task, nil
|
|
}
|
|
|
|
// GetAllTasks retrieves all tasks from the queue
|
|
func (tq *TaskQueue) GetAllTasks() ([]*Task, error) {
|
|
// Get all task keys
|
|
keys, err := tq.client.Keys(tq.ctx, TaskPrefix+"*").Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var tasks []*Task
|
|
for _, key := range keys {
|
|
data, err := tq.client.Get(tq.ctx, key).Result()
|
|
if err != nil {
|
|
continue // Skip tasks that can't be retrieved
|
|
}
|
|
|
|
var task Task
|
|
if err := json.Unmarshal([]byte(data), &task); err != nil {
|
|
continue // Skip malformed tasks
|
|
}
|
|
|
|
tasks = append(tasks, &task)
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
// GetTaskByName retrieves a task by its job name
|
|
func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) {
|
|
tasks, err := tq.GetAllTasks()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, task := range tasks {
|
|
if task.JobName == jobName {
|
|
return task, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("task with job name '%s' not found", jobName)
|
|
}
|
|
|
|
// CancelTask marks a task as cancelled
|
|
func (tq *TaskQueue) CancelTask(taskID string) error {
|
|
task, err := tq.GetTask(taskID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update task status to cancelled
|
|
task.Status = "cancelled"
|
|
now := time.Now()
|
|
task.EndedAt = &now
|
|
|
|
return tq.UpdateTask(task)
|
|
}
|
|
|
|
// UpdateTask updates a task in Redis
|
|
func (tq *TaskQueue) UpdateTask(task *Task) error {
|
|
taskData, err := json.Marshal(task)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pipe := tq.client.Pipeline()
|
|
pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
|
|
pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
|
|
"status", task.Status,
|
|
"task_id", task.ID,
|
|
"updated_at", time.Now().Format(time.RFC3339))
|
|
|
|
_, err = pipe.Exec(tq.ctx)
|
|
return err
|
|
}
|
|
|
|
// UpdateTaskWithMetrics updates task and records metrics
|
|
func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error {
|
|
if err := tq.UpdateTask(task); err != nil {
|
|
return err
|
|
}
|
|
|
|
metricName := "tasks_" + action
|
|
return tq.RecordMetric(task.JobName, metricName, 1)
|
|
}
|
|
|
|
// RecordMetric records a metric value
|
|
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
|
|
evt := metricEvent{JobName: jobName, Metric: metric, Value: value}
|
|
select {
|
|
case tq.metricsCh <- evt:
|
|
return nil
|
|
default:
|
|
return tq.writeMetrics(jobName, map[string]float64{metric: value})
|
|
}
|
|
}
|
|
|
|
// Heartbeat records worker heartbeat
|
|
func (tq *TaskQueue) Heartbeat(workerID string) error {
|
|
return tq.client.HSet(tq.ctx, WorkerHeartbeat,
|
|
workerID, time.Now().Unix()).Err()
|
|
}
|
|
|
|
// QueueDepth returns the number of pending tasks
|
|
func (tq *TaskQueue) QueueDepth() (int64, error) {
|
|
return tq.client.ZCard(tq.ctx, TaskQueueKey).Result()
|
|
}
|
|
|
|
// Close closes the task queue and cleans up resources
|
|
func (tq *TaskQueue) Close() error {
|
|
tq.cancel()
|
|
<-tq.metricsDone // Wait for metrics buffer to finish
|
|
return tq.client.Close()
|
|
}
|
|
|
|
// GetRedisClient returns the underlying Redis client for direct access
|
|
func (tq *TaskQueue) GetRedisClient() *redis.Client {
|
|
return tq.client
|
|
}
|
|
|
|
// WaitForNextTask waits for next task with timeout
|
|
func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) {
|
|
if ctx == nil {
|
|
ctx = tq.ctx
|
|
}
|
|
result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result()
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
member, ok := result.Member.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected task id type %T", result.Member)
|
|
}
|
|
return tq.GetTask(member)
|
|
}
|
|
|
|
// runMetricsBuffer accumulates metric samples from metricsCh and writes
// them to Redis in batches every flushEvery. It runs until the queue
// context is cancelled (or metricsCh is closed), flushing one final
// time before closing metricsDone so Close can wait on it.
//
// NOTE(review): within a flush window the latest sample for a given
// (job, metric) pair overwrites earlier ones (last-write-wins), so
// values are treated as gauges rather than summed — confirm callers
// expect this.
func (tq *TaskQueue) runMetricsBuffer() {
	defer close(tq.metricsDone) // signals Close that the flusher is done
	ticker := time.NewTicker(tq.flushEvery)
	defer ticker.Stop()
	// pending accumulates samples between flushes: job -> metric -> value.
	pending := make(map[string]map[string]float64)
	flush := func() {
		for job, metrics := range pending {
			// On write failure keep the job's samples for the next
			// flush attempt instead of dropping them.
			if err := tq.writeMetrics(job, metrics); err != nil {
				continue
			}
			delete(pending, job)
		}
	}

	for {
		select {
		case <-tq.ctx.Done():
			// Final flush on shutdown; writeMetrics uses its own
			// background context so this still reaches Redis.
			flush()
			return
		case evt, ok := <-tq.metricsCh:
			if !ok {
				flush()
				return
			}
			if _, exists := pending[evt.JobName]; !exists {
				pending[evt.JobName] = make(map[string]float64)
			}
			// Last write wins within a flush window (see note above).
			pending[evt.JobName][evt.Metric] = evt.Value
		case <-ticker.C:
			flush()
		}
	}
}
|
|
|
|
// writeMetrics writes metrics to Redis
|
|
func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error {
|
|
if len(metrics) == 0 {
|
|
return nil
|
|
}
|
|
key := JobMetricsPrefix + jobName
|
|
args := make([]any, 0, len(metrics)*2+2)
|
|
args = append(args, "timestamp", time.Now().Unix())
|
|
for metric, value := range metrics {
|
|
args = append(args, metric, value)
|
|
}
|
|
return tq.client.HSet(context.Background(), key, args...).Err()
|
|
}
|