- Refactor internal/worker and internal/queue packages - Update cmd/tui for monitoring interface - Update test configurations
225 lines
4.9 KiB
Go
225 lines
4.9 KiB
Go
// Package redis provides a Redis-based queue implementation
|
|
package redis
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/domain"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Queue implements a Redis-based task queue
|
|
type Queue struct {
|
|
client *redis.Client
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
metricsCh chan metricEvent
|
|
metricsDone chan struct{}
|
|
flushEvery time.Duration
|
|
}
|
|
|
|
type metricEvent struct {
|
|
JobName string
|
|
Metric string
|
|
Value float64
|
|
}
|
|
|
|
// Config holds configuration for Queue
|
|
type Config struct {
|
|
RedisAddr string
|
|
RedisPassword string
|
|
RedisDB int
|
|
MetricsFlushInterval time.Duration
|
|
}
|
|
|
|
// NewQueue creates a new Redis queue instance
|
|
func NewQueue(cfg Config) (*Queue, error) {
|
|
var opts *redis.Options
|
|
var err error
|
|
|
|
if len(cfg.RedisAddr) > 8 && cfg.RedisAddr[:8] == "redis://" {
|
|
opts, err = redis.ParseURL(cfg.RedisAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid redis url: %w", err)
|
|
}
|
|
} else {
|
|
opts = &redis.Options{
|
|
Addr: cfg.RedisAddr,
|
|
Password: cfg.RedisPassword,
|
|
DB: cfg.RedisDB,
|
|
PoolSize: 50,
|
|
MinIdleConns: 10,
|
|
MaxRetries: 3,
|
|
DialTimeout: 5 * time.Second,
|
|
ReadTimeout: 3 * time.Second,
|
|
}
|
|
}
|
|
|
|
client := redis.NewClient(opts)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Test connection
|
|
if err := client.Ping(ctx).Err(); err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
|
}
|
|
|
|
flushEvery := cfg.MetricsFlushInterval
|
|
if flushEvery == 0 {
|
|
flushEvery = 500 * time.Millisecond
|
|
}
|
|
|
|
q := &Queue{
|
|
client: client,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
metricsCh: make(chan metricEvent, 100),
|
|
metricsDone: make(chan struct{}),
|
|
flushEvery: flushEvery,
|
|
}
|
|
|
|
go q.metricsFlusher()
|
|
|
|
return q, nil
|
|
}
|
|
|
|
// Close closes the queue
|
|
func (q *Queue) Close() error {
|
|
q.cancel()
|
|
close(q.metricsCh)
|
|
<-q.metricsDone
|
|
return q.client.Close()
|
|
}
|
|
|
|
// AddTask adds a task to the queue
|
|
func (q *Queue) AddTask(task *domain.Task) error {
|
|
if task == nil {
|
|
return errors.New("task is nil")
|
|
}
|
|
if task.ID == "" {
|
|
return errors.New("task ID is required")
|
|
}
|
|
|
|
data, err := json.Marshal(task)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal task: %w", err)
|
|
}
|
|
|
|
// Add to task hash and queue
|
|
pipe := q.client.Pipeline()
|
|
pipe.HSet(q.ctx, "ml:task:"+task.ID, "data", data)
|
|
pipe.LPush(q.ctx, "ml:queue", task.ID)
|
|
_, err = pipe.Exec(q.ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTask retrieves a task by ID
|
|
func (q *Queue) GetTask(id string) (*domain.Task, error) {
|
|
if id == "" {
|
|
return nil, errors.New("task ID is required")
|
|
}
|
|
|
|
data, err := q.client.HGet(q.ctx, "ml:task:"+id, "data").Bytes()
|
|
if err == redis.Nil {
|
|
return nil, fmt.Errorf("task not found: %s", id)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get task: %w", err)
|
|
}
|
|
|
|
var task domain.Task
|
|
if err := json.Unmarshal(data, &task); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal task: %w", err)
|
|
}
|
|
|
|
return &task, nil
|
|
}
|
|
|
|
// ListTasks lists all tasks in the queue
|
|
func (q *Queue) ListTasks() ([]*domain.Task, error) {
|
|
// Get all task IDs from the queue
|
|
ids, err := q.client.LRange(q.ctx, "ml:queue", 0, -1).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list tasks: %w", err)
|
|
}
|
|
|
|
var tasks []*domain.Task
|
|
for _, id := range ids {
|
|
task, err := q.GetTask(id)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
// CancelTask cancels a task
|
|
func (q *Queue) CancelTask(id string) error {
|
|
if id == "" {
|
|
return errors.New("task ID is required")
|
|
}
|
|
|
|
// Remove from queue and mark as cancelled
|
|
pipe := q.client.Pipeline()
|
|
pipe.LRem(q.ctx, "ml:queue", 0, id)
|
|
pipe.HSet(q.ctx, "ml:task:"+id, "status", "cancelled")
|
|
_, err := pipe.Exec(q.ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to cancel task: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateTask updates a task
|
|
func (q *Queue) UpdateTask(task *domain.Task) error {
|
|
if task == nil || task.ID == "" {
|
|
return errors.New("task is nil or missing ID")
|
|
}
|
|
|
|
data, err := json.Marshal(task)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal task: %w", err)
|
|
}
|
|
|
|
if err := q.client.HSet(q.ctx, "ml:task:"+task.ID, "data", data).Err(); err != nil {
|
|
return fmt.Errorf("failed to update task: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// metricsFlusher periodically flushes metrics
|
|
func (q *Queue) metricsFlusher() {
|
|
ticker := time.NewTicker(q.flushEvery)
|
|
defer ticker.Stop()
|
|
|
|
metrics := make(map[string]map[string]float64)
|
|
|
|
for {
|
|
select {
|
|
case <-q.ctx.Done():
|
|
close(q.metricsDone)
|
|
return
|
|
case evt := <-q.metricsCh:
|
|
if metrics[evt.JobName] == nil {
|
|
metrics[evt.JobName] = make(map[string]float64)
|
|
}
|
|
metrics[evt.JobName][evt.Metric] = evt.Value
|
|
case <-ticker.C:
|
|
// Flush metrics to Redis or other backend
|
|
metrics = make(map[string]map[string]float64)
|
|
}
|
|
}
|
|
}
|