fetch_ml/cmd/worker/worker_server.go
// Package main implements the ML task worker
package main
import (
"context"
"fmt"
"log"
"log/slog"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/errtypes"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/telemetry"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
*network.SSHClient
}
// isValidName reports whether input is non-empty, shorter than 256 characters,
// and contains only characters that are safe to pass to a subprocess:
// letters, digits, '.', '_', and '-'.
func isValidName(input string) bool {
if len(input) == 0 || len(input) >= 256 {
return false
}
for _, r := range input {
ok := r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' ||
r >= '0' && r <= '9' || r == '.' || r == '_' || r == '-'
if !ok {
return false
}
}
return true
}
// NewMLServer creates a new ML server connection, returning a local client
// when local mode is enabled and an SSH-backed client otherwise.
func NewMLServer(cfg *Config) (*MLServer, error) {
if cfg.LocalMode {
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
}
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
if err != nil {
return nil, err
}
return &MLServer{SSHClient: client}, nil
}
// Worker represents an ML task worker.
type Worker struct {
id string
config *Config
server *MLServer
queue *queue.TaskQueue
running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
runningMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
metrics *metrics.Metrics
metricsSrv *http.Server
datasetCache map[string]time.Time
datasetCacheMu sync.RWMutex
datasetCacheTTL time.Duration
// Graceful shutdown fields
shutdownCh chan struct{}
activeTasks sync.Map // map[string]*queue.Task - track active tasks
gracefulWait sync.WaitGroup
}
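// setupMetricsExporter registers Prometheus collectors for the worker's
// counters and gauges and, when metrics are enabled, serves them over HTTP
// at /metrics on the configured listen address.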
func (w *Worker) setupMetricsExporter() {
if !w.config.Metrics.Enabled {
return
}
reg := prometheus.NewRegistry()
reg.MustRegister(
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
collectors.NewGoCollector(),
)
labels := prometheus.Labels{"worker_id": w.id}
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_processed_total",
Help: "Total tasks processed successfully by this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.TasksProcessed.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_failed_total",
Help: "Total tasks failed by this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.TasksFailed.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_active",
Help: "Number of tasks currently running on this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.runningCount())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_queued",
Help: "Latest observed queue depth from Redis.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.QueuedTasks.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_data_transferred_bytes_total",
Help: "Total bytes transferred while fetching datasets.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.DataTransferred.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_data_fetch_time_seconds_total",
Help: "Total time spent fetching datasets (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_execution_time_seconds_total",
Help: "Total execution time for completed tasks (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_worker_max_concurrency",
Help: "Configured maximum concurrent tasks for this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.config.MaxWorkers)
}))
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
srv := &http.Server{
Addr: w.config.Metrics.ListenAddr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
w.metricsSrv = srv
go func() {
w.logger.Info("metrics exporter listening",
"addr", w.config.Metrics.ListenAddr,
"enabled", w.config.Metrics.Enabled)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
w.logger.Warn("metrics exporter stopped",
"error", err)
}
}()
}
// NewWorker creates a new worker instance.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
srv, err := NewMLServer(cfg)
if err != nil {
return nil, err
}
queueCfg := queue.Config{
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
MetricsFlushInterval: cfg.MetricsFlushInterval,
}
taskQueue, err := queue.NewTaskQueue(queueCfg)
if err != nil {
return nil, err
}
// Create data_dir if it doesn't exist (for production without NAS)
if cfg.DataDir != "" {
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
}
}
ctx, cancel := context.WithCancel(context.Background())
ctx = logging.EnsureTrace(ctx)
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
baseLogger := logging.NewLogger(slog.LevelInfo, false)
logger := baseLogger.Component(ctx, "worker")
workerMetrics := &metrics.Metrics{}
worker := &Worker{
id: cfg.WorkerID,
config: cfg,
server: srv,
queue: taskQueue,
running: make(map[string]context.CancelFunc),
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
ctx: ctx,
cancel: cancel,
logger: logger,
metrics: workerMetrics,
shutdownCh: make(chan struct{}),
}
worker.setupMetricsExporter()
return worker, nil
}
// Start starts the worker's main processing loop.
func (w *Worker) Start() {
w.logger.Info("worker started",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers,
"poll_interval", w.config.PollInterval)
go w.heartbeat()
for {
select {
case <-w.ctx.Done():
w.logger.Info("shutdown signal received, waiting for tasks")
w.waitForTasks()
return
default:
}
if w.runningCount() >= w.config.MaxWorkers {
time.Sleep(50 * time.Millisecond)
continue
}
queueStart := time.Now()
blockTimeout := time.Duration(w.config.PollInterval) * time.Second
task, err := w.queue.GetNextTaskWithLeaseBlocking(w.config.WorkerID, w.config.TaskLeaseDuration, blockTimeout)
queueLatency := time.Since(queueStart)
if err != nil {
if err == context.DeadlineExceeded {
continue
}
w.logger.Error("error fetching task",
"worker_id", w.id,
"error", err)
continue
}
if task == nil {
if queueLatency > 200*time.Millisecond {
w.logger.Debug("queue poll latency",
"latency_ms", queueLatency.Milliseconds())
}
continue
}
if depth, derr := w.queue.QueueDepth(); derr == nil {
if queueLatency > 100*time.Millisecond || depth > 0 {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"remaining_depth", depth)
}
} else if queueLatency > 100*time.Millisecond {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"depth_error", derr)
}
go w.executeTaskWithLease(task)
}
}
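// heartbeat reports worker liveness to the queue every 30 seconds until the
// worker context is cancelled.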
func (w *Worker) heartbeat() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
if err := w.queue.Heartbeat(w.id); err != nil {
w.logger.Warn("heartbeat failed",
"worker_id", w.id,
"error", err)
}
}
}
}
// fetchDatasets fetches the task's datasets via the data_manager tool,
// skipping any dataset whose cache entry is still fresh.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
logger := w.logger.Job(ctx, task.JobName, task.ID)
logger.Info("fetching datasets",
"worker_id", w.id,
"dataset_count", len(task.Datasets))
for _, dataset := range task.Datasets {
if w.datasetIsFresh(dataset) {
logger.Debug("skipping cached dataset",
"dataset", dataset)
continue
}
// Check for cancellation before each dataset fetch
select {
case <-ctx.Done():
return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
default:
}
logger.Info("fetching dataset",
"worker_id", w.id,
"dataset", dataset)
// Create command with context for cancellation support
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
// Validate inputs to prevent command injection
if !isValidName(task.JobName) || !isValidName(dataset) {
cancel()
return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
}
//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
cmd := exec.CommandContext(cmdCtx,
w.config.DataManagerPath,
"fetch",
task.JobName,
dataset,
)
output, err := cmd.CombinedOutput()
cancel() // Clean up context
if err != nil {
return &errtypes.DataFetchError{
Dataset: dataset,
JobName: task.JobName,
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
}
}
logger.Info("dataset ready",
"worker_id", w.id,
"dataset", dataset)
w.markDatasetFetched(dataset)
}
return nil
}
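// runJob stages the job from the pending directory into the running
// directory, executes it (directly in local mode, otherwise inside a podman
// container), and moves the job directory to finished or failed based on the
// outcome.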
func (w *Worker) runJob(ctx context.Context, task *queue.Task) error {
// Validate job name to prevent path traversal
if err := container.ValidateJobName(task.JobName); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
jobPaths := config.NewJobPaths(w.config.BasePath)
pendingDir := jobPaths.PendingPath()
jobDir := filepath.Join(pendingDir, task.JobName)
outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName)
logFile := filepath.Join(outputDir, "output.log")
// Create pending directory
if err := os.MkdirAll(pendingDir, 0750); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create pending dir: %w", err),
}
}
// Create job directory in pending
if err := os.MkdirAll(jobDir, 0750); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create job dir: %w", err),
}
}
// Sanitize paths
var err error
jobDir, err = container.SanitizePath(jobDir)
if err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
outputDir, err = container.SanitizePath(outputDir)
if err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
// Create output directory
if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) {
if err := os.MkdirAll(outputDir, 0750); err != nil {
return "", fmt.Errorf("mkdir failed: %w", err)
}
return "", nil
}); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create output dir: %w", err),
}
}
// Move job from pending to running
stagingStart := time.Now()
if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) {
// Remove existing directory if it exists
if _, err := os.Stat(outputDir); err == nil {
if err := os.RemoveAll(outputDir); err != nil {
return "", fmt.Errorf("remove existing failed: %w", err)
}
}
if err := os.Rename(jobDir, outputDir); err != nil {
return "", fmt.Errorf("rename failed: %w", err)
}
return "", nil
}); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to move job: %w", err),
}
}
stagingDuration := time.Since(stagingStart)
// In local mode, execute directly without podman
if w.config.LocalMode {
// Create experiment script
scriptContent := `#!/bin/bash
set -e
echo "Starting experiment: ` + task.JobName + `"
echo "Task ID: ` + task.ID + `"
echo "Timestamp: $(date)"
# Simulate ML experiment
echo "Loading data..."
sleep 1
echo "Training model..."
sleep 2
echo "Evaluating model..."
sleep 1
# Generate results
ACCURACY=0.95
LOSS=0.05
EPOCHS=10
echo ""
echo "=== EXPERIMENT RESULTS ==="
echo "Accuracy: $ACCURACY"
echo "Loss: $LOSS"
echo "Epochs: $EPOCHS"
echo "Status: SUCCESS"
echo "========================="
echo "Experiment completed successfully!"
`
scriptPath := filepath.Join(outputDir, "run.sh")
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "execution",
Err: fmt.Errorf("failed to write script: %w", err),
}
}
logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
w.logger.Warn("failed to open log file for local output", "path", logFile, "error", err)
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "execution",
Err: fmt.Errorf("failed to open log file: %w", err),
}
}
defer logFileHandle.Close()
// Execute the script directly
localCmd := exec.CommandContext(ctx, "bash", scriptPath)
localCmd.Stdout = logFileHandle
localCmd.Stderr = logFileHandle
w.logger.Info("executing local job",
"job", task.JobName,
"task_id", task.ID,
"script", scriptPath)
if err := localCmd.Run(); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "execution",
Err: fmt.Errorf("execution failed: %w", err),
}
}
return nil
}
if w.config.PodmanImage == "" {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: fmt.Errorf("podman_image must be configured"),
}
}
containerWorkspace := w.config.ContainerWorkspace
if containerWorkspace == "" {
containerWorkspace = config.DefaultContainerWorkspace
}
containerResults := w.config.ContainerResults
if containerResults == "" {
containerResults = config.DefaultContainerResults
}
podmanCfg := container.PodmanConfig{
Image: w.config.PodmanImage,
Workspace: filepath.Join(outputDir, "code"),
Results: filepath.Join(outputDir, "results"),
ContainerWorkspace: containerWorkspace,
ContainerResults: containerResults,
GPUAccess: w.config.GPUAccess,
}
scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript)
requirementsPath := filepath.Join(containerWorkspace, "requirements.txt")
var extraArgs []string
if task.Args != "" {
extraArgs = strings.Fields(task.Args)
}
ioBefore, ioErr := telemetry.ReadProcessIO()
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, requirementsPath, extraArgs)
logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err == nil {
// Stream podman output to the job log; close the handle when runJob returns.
defer logFileHandle.Close()
podmanCmd.Stdout = logFileHandle
podmanCmd.Stderr = logFileHandle
} else {
w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err)
}
w.logger.Info("executing podman job",
"job", task.JobName,
"image", w.config.PodmanImage,
"workspace", podmanCfg.Workspace,
"results", podmanCfg.Results)
containerStart := time.Now()
if err := podmanCmd.Run(); err != nil {
containerDuration := time.Since(containerStart)
// Move job to failed directory
failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) {
if err := os.Rename(outputDir, failedDir); err != nil {
return "", fmt.Errorf("rename to failed failed: %w", err)
}
return "", nil
}); moveErr != nil {
w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr)
}
if ioErr == nil {
if after, err := telemetry.ReadProcessIO(); err == nil {
delta := telemetry.DiffIO(ioBefore, after)
w.logger.Debug("worker io stats",
"job", task.JobName,
"read_bytes", delta.ReadBytes,
"write_bytes", delta.WriteBytes)
}
}
w.logger.Info("job timing (failure)",
"job", task.JobName,
"staging_ms", stagingDuration.Milliseconds(),
"container_ms", containerDuration.Milliseconds(),
"finalize_ms", 0,
"total_ms", time.Since(stagingStart).Milliseconds(),
)
return fmt.Errorf("execution failed: %w", err)
}
containerDuration := time.Since(containerStart)
finalizeStart := time.Now()
// Move job to finished directory
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) {
if err := os.Rename(outputDir, finishedDir); err != nil {
return "", fmt.Errorf("rename to finished failed: %w", err)
}
return "", nil
}); moveErr != nil {
w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr)
}
finalizeDuration := time.Since(finalizeStart)
totalDuration := time.Since(stagingStart)
var ioDelta telemetry.IOStats
if ioErr == nil {
if after, err := telemetry.ReadProcessIO(); err == nil {
ioDelta = telemetry.DiffIO(ioBefore, after)
}
}
w.logger.Info("job timing",
"job", task.JobName,
"staging_ms", stagingDuration.Milliseconds(),
"container_ms", containerDuration.Milliseconds(),
"finalize_ms", finalizeDuration.Milliseconds(),
"total_ms", totalDuration.Milliseconds(),
"io_read_bytes", ioDelta.ReadBytes,
"io_write_bytes", ioDelta.WriteBytes,
)
return nil
}
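// parseDatasets extracts the comma-separated list following a --datasets
// flag in the task's argument string, returning nil when the flag is absent.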
func parseDatasets(args string) []string {
if !strings.Contains(args, "--datasets") {
return nil
}
parts := strings.Fields(args)
for i, part := range parts {
if part == "--datasets" && i+1 < len(parts) {
return strings.Split(parts[i+1], ",")
}
}
return nil
}
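// waitForTasks blocks until all running tasks have completed, polling once
// per second and giving up after a five-minute timeout.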
func (w *Worker) waitForTasks() {
timeout := time.After(5 * time.Minute)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-timeout:
w.logger.Warn("shutdown timeout, force stopping",
"running_tasks", len(w.running))
return
case <-ticker.C:
count := w.runningCount()
if count == 0 {
w.logger.Info("all tasks completed, shutting down")
return
}
w.logger.Debug("waiting for tasks to complete",
"remaining", count)
}
}
}
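// runningCount returns the number of tasks currently tracked as running.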
func (w *Worker) runningCount() int {
w.runningMu.RLock()
defer w.runningMu.RUnlock()
return len(w.running)
}
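// datasetIsFresh reports whether the dataset's cache entry exists and has not
// yet expired.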
func (w *Worker) datasetIsFresh(dataset string) bool {
w.datasetCacheMu.RLock()
defer w.datasetCacheMu.RUnlock()
expires, ok := w.datasetCache[dataset]
return ok && time.Now().Before(expires)
}
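// markDatasetFetched records the dataset as freshly fetched, extending its
// cache entry by the configured TTL.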
func (w *Worker) markDatasetFetched(dataset string) {
expires := time.Now().Add(w.datasetCacheTTL)
w.datasetCacheMu.Lock()
w.datasetCache[dataset] = expires
w.datasetCacheMu.Unlock()
}
// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
stats := w.metrics.GetStats()
stats["worker_id"] = w.id
stats["max_workers"] = w.config.MaxWorkers
return stats
}
// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
w.cancel()
w.waitForTasks()
// Close the server and queue connections, logging any cleanup errors.
if err := w.server.Close(); err != nil {
w.logger.Warn("error closing server connection", "error", err)
}
if err := w.queue.Close(); err != nil {
w.logger.Warn("error closing queue connection", "error", err)
}
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics exporter shutdown error", "error", err)
}
}
w.logger.Info("worker stopped", "worker_id", w.id)
}
// executeTaskWithLease executes a task under a lease, with heartbeat renewal
// and retry on transient failures.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
// Track task for graceful shutdown
w.gracefulWait.Add(1)
w.activeTasks.Store(task.ID, task)
defer w.gracefulWait.Done()
defer w.activeTasks.Delete(task.ID)
// Create task-specific context with timeout
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
defer taskCancel()
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
logger.Info("starting task",
"worker_id", w.id,
"datasets", task.Datasets,
"priority", task.Priority)
// Record task start
w.metrics.RecordTaskStart()
defer w.metrics.RecordTaskCompletion()
// Check for context cancellation
select {
case <-taskCtx.Done():
logger.Info("task cancelled before execution")
return
default:
}
// Parse datasets from task arguments
if task.Datasets == nil {
task.Datasets = parseDatasets(task.Args)
}
// Start heartbeat goroutine
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer cancelHeartbeat()
go w.heartbeatLoop(heartbeatCtx, task.ID)
// Update task status
task.Status = "running"
now := time.Now()
task.StartedAt = &now
task.WorkerID = w.id
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
logger.Error("failed to update task status", "error", err)
w.metrics.RecordTaskFailure()
return
}
if w.config.AutoFetchData && len(task.Datasets) > 0 {
if err := w.fetchDatasets(taskCtx, task); err != nil {
logger.Error("data fetch failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
err := w.queue.UpdateTask(task)
if err != nil {
logger.Error("failed to update task status after data fetch failure", "error", err)
}
w.metrics.RecordTaskFailure()
return
}
}
// Execute job with panic recovery
var execErr error
func() {
defer func() {
if r := recover(); r != nil {
execErr = fmt.Errorf("panic during execution: %v", r)
}
}()
execErr = w.runJob(taskCtx, task)
}()
// Finalize task
endTime := time.Now()
task.EndedAt = &endTime
if execErr != nil {
task.Error = execErr.Error()
// Check if transient error (network, timeout, etc)
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
w.logger.Warn("task failed with transient error, will retry",
"task_id", task.ID,
"error", execErr,
"retry_count", task.RetryCount)
_ = w.queue.RetryTask(task)
} else {
task.Status = "failed"
_ = w.queue.UpdateTaskWithMetrics(task, "final")
}
} else {
task.Status = "completed"
// Read output file for completed tasks
jobPaths := config.NewJobPaths(w.config.BasePath)
outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName)
logFile := filepath.Join(outputDir, "output.log")
if outputBytes, err := os.ReadFile(logFile); err == nil {
task.Output = string(outputBytes)
}
_ = w.queue.UpdateTaskWithMetrics(task, "final")
}
// Release lease
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}
// heartbeatLoop periodically renews the task lease and the worker heartbeat
// until its context is cancelled.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
ticker := time.NewTicker(w.config.HeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
return
}
// Also update worker heartbeat
_ = w.queue.Heartbeat(w.config.WorkerID)
}
}
}
// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
// Wait for active tasks with timeout
done := make(chan struct{})
go func() {
w.gracefulWait.Wait()
close(done)
}()
timeout := time.After(w.config.GracefulTimeout)
select {
case <-done:
w.logger.Info("all tasks completed, shutdown successful")
case <-timeout:
w.logger.Warn("graceful shutdown timeout, releasing active leases")
w.releaseAllLeases()
}
return w.queue.Close()
}
// releaseAllLeases releases the leases held for all tasks still tracked as active.
func (w *Worker) releaseAllLeases() {
w.activeTasks.Range(func(key, _ interface{}) bool {
taskID := key.(string)
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
}
return true
})
}
// countActiveTasks returns the number of tasks currently tracked for graceful shutdown.
func (w *Worker) countActiveTasks() int {
count := 0
w.activeTasks.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
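// isTransientError reports whether the error message matches a known
// transient condition (network failure, timeout, or temporary resource
// exhaustion) that makes the task worth retrying.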
func isTransientError(err error) bool {
if err == nil {
return false
}
// Check if error is transient (network, timeout, resource unavailable, etc)
errStr := err.Error()
transientIndicators := []string{
"connection refused",
"timeout",
"temporary failure",
"resource temporarily unavailable",
"no such host",
"network unreachable",
}
for _, indicator := range transientIndicators {
if strings.Contains(strings.ToLower(errStr), indicator) {
return true
}
}
return false
}
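// main loads configuration, validates authentication, starts the worker, and
// shuts it down gracefully on SIGINT or SIGTERM.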
func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
if err := auth.ValidateFlags(authFlags); err != nil {
log.Fatalf("Authentication flag error: %v", err)
}
// Get API key from various sources
apiKey := auth.GetAPIKeyFromSources(authFlags)
// Load configuration
configPath := "config-local.yaml"
if authFlags.ConfigFile != "" {
configPath = authFlags.ConfigFile
}
resolvedConfig, err := config.ResolveConfigPath(configPath)
if err != nil {
log.Fatalf("%v", err)
}
cfg, err := LoadConfig(resolvedConfig)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Validate authentication configuration
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
log.Fatalf("Invalid authentication configuration: %v", err)
}
// Validate configuration
if err := cfg.Validate(); err != nil {
log.Fatalf("Invalid configuration: %v", err)
}
// Test authentication if enabled
if cfg.Auth.Enabled && apiKey != "" {
user, err := cfg.Auth.ValidateAPIKey(apiKey)
if err != nil {
log.Fatalf("Authentication failed: %v", err)
}
log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin)
} else if cfg.Auth.Enabled {
log.Fatal("Authentication required but no API key provided")
}
worker, err := NewWorker(cfg, apiKey)
if err != nil {
log.Fatalf("Failed to create worker: %v", err)
}
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go worker.Start()
sig := <-sigChan
log.Printf("Received signal: %v", sig)
// Use graceful shutdown
if err := worker.Shutdown(); err != nil {
log.Printf("Graceful shutdown error: %v", err)
worker.Stop() // Fallback to force stop
} else {
log.Println("Worker shut down gracefully")
}
}