// Package main implements the ML task worker.
package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"log/slog"
	"net/http"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/errtypes"
	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/telemetry"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/collectors"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
	*network.SSHClient
}

// isValidName reports whether input is a safe job or dataset name: non-empty,
// under 256 characters, and limited to a conservative allowlist
// (alphanumerics, '-', '_', '.', '/'). The allowlist is an assumption; widen
// it if names legitimately need other characters.
func isValidName(input string) bool {
	if len(input) == 0 || len(input) >= 256 {
		return false
	}
	return strings.IndexFunc(input, func(r rune) bool {
		return !(r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' ||
			r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' || r == '/')
	}) == -1
}

// NewMLServer returns a new MLServer backed by either a local client or an
// SSH connection, depending on cfg.LocalMode.
func NewMLServer(cfg *Config) (*MLServer, error) {
	if cfg.LocalMode {
		return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
	}

	client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
	if err != nil {
		return nil, err
	}

	return &MLServer{SSHClient: client}, nil
}

// Worker represents an ML task worker.
type Worker struct {
	id         string
	config     *Config
	server     *MLServer
	queue      *queue.TaskQueue
	running    map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
	runningMu  sync.RWMutex
	ctx        context.Context
	cancel     context.CancelFunc
	logger     *logging.Logger
	metrics    *metrics.Metrics
	metricsSrv *http.Server

	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup
}
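
// setupMetricsExporter registers Prometheus collectors backed by the worker's
// in-memory counters and, when metrics are enabled, serves them at /metrics
// on config.Metrics.ListenAddr.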
func (w *Worker) setupMetricsExporter() {
	if !w.config.Metrics.Enabled {
		return
	}

	reg := prometheus.NewRegistry()
	reg.MustRegister(
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
		collectors.NewGoCollector(),
	)

	labels := prometheus.Labels{"worker_id": w.id}
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_processed_total",
		Help:        "Total tasks processed successfully by this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.TasksProcessed.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_failed_total",
		Help:        "Total tasks failed by this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.TasksFailed.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_active",
		Help:        "Number of tasks currently running on this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.runningCount())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_queued",
		Help:        "Latest observed queue depth from Redis.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.QueuedTasks.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_transferred_bytes_total",
		Help:        "Total bytes transferred while fetching datasets.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.DataTransferred.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_fetch_time_seconds_total",
		Help:        "Total time spent fetching datasets (seconds).",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_execution_time_seconds_total",
		Help:        "Total execution time for completed tasks (seconds).",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_worker_max_concurrency",
		Help:        "Configured maximum concurrent tasks for this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.config.MaxWorkers)
	}))

	mux := http.NewServeMux()
	mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))

	srv := &http.Server{
		Addr:              w.config.Metrics.ListenAddr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
	}

	w.metricsSrv = srv
	go func() {
		w.logger.Info("metrics exporter listening",
			"addr", w.config.Metrics.ListenAddr,
			"enabled", w.config.Metrics.Enabled)
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			w.logger.Warn("metrics exporter stopped",
				"error", err)
		}
	}()
}

// NewWorker creates a new worker instance.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
	srv, err := NewMLServer(cfg)
	if err != nil {
		return nil, err
	}

	queueCfg := queue.Config{
		RedisAddr:            cfg.RedisAddr,
		RedisPassword:        cfg.RedisPassword,
		RedisDB:              cfg.RedisDB,
		MetricsFlushInterval: cfg.MetricsFlushInterval,
	}
	taskQueue, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		return nil, err
	}

	// Create data_dir if it doesn't exist (for production without NAS)
	if cfg.DataDir != "" {
		if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
			log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	ctx = logging.EnsureTrace(ctx)
	ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)

	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	logger := baseLogger.Component(ctx, "worker")
	workerMetrics := &metrics.Metrics{}

	worker := &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		server:          srv,
		queue:           taskQueue,
		running:         make(map[string]context.CancelFunc),
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
		ctx:             ctx,
		cancel:          cancel,
		logger:          logger,
		metrics:         workerMetrics,
		shutdownCh:      make(chan struct{}),
	}

	worker.setupMetricsExporter()

	return worker, nil
}

// Start starts the worker's main processing loop.
func (w *Worker) Start() {
	w.logger.Info("worker started",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers,
		"poll_interval", w.config.PollInterval)

	go w.heartbeat()

	for {
		select {
		case <-w.ctx.Done():
			w.logger.Info("shutdown signal received, waiting for tasks")
			w.waitForTasks()
			return
		default:
		}

		if w.runningCount() >= w.config.MaxWorkers {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		queueStart := time.Now()
		blockTimeout := time.Duration(w.config.PollInterval) * time.Second
		task, err := w.queue.GetNextTaskWithLeaseBlocking(w.config.WorkerID, w.config.TaskLeaseDuration, blockTimeout)
		queueLatency := time.Since(queueStart)
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				continue
			}
			w.logger.Error("error fetching task",
				"worker_id", w.id,
				"error", err)
			continue
		}

		if task == nil {
			if queueLatency > 200*time.Millisecond {
				w.logger.Debug("queue poll latency",
					"latency_ms", queueLatency.Milliseconds())
			}
			continue
		}

		if depth, derr := w.queue.QueueDepth(); derr == nil {
			if queueLatency > 100*time.Millisecond || depth > 0 {
				w.logger.Debug("queue fetch metrics",
					"latency_ms", queueLatency.Milliseconds(),
					"remaining_depth", depth)
			}
		} else if queueLatency > 100*time.Millisecond {
			w.logger.Debug("queue fetch metrics",
				"latency_ms", queueLatency.Milliseconds(),
				"depth_error", derr)
		}

		go w.executeTaskWithLease(task)
	}
}
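
// heartbeat periodically reports worker liveness to the queue so that stale
// workers can be detected.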
func (w *Worker) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.Heartbeat(w.id); err != nil {
				w.logger.Warn("heartbeat failed",
					"worker_id", w.id,
					"error", err)
			}
		}
	}
}

// fetchDatasets fetches each dataset required by the task via the external
// data manager, skipping datasets whose cache entries are still fresh.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	logger.Info("fetching datasets",
		"worker_id", w.id,
		"dataset_count", len(task.Datasets))

	for _, dataset := range task.Datasets {
		if w.datasetIsFresh(dataset) {
			logger.Debug("skipping cached dataset",
				"dataset", dataset)
			continue
		}
		// Check for cancellation before each dataset fetch
		select {
		case <-ctx.Done():
			return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
		default:
		}

		logger.Info("fetching dataset",
			"worker_id", w.id,
			"dataset", dataset)

		// Create command with context for cancellation support
		cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
		// Validate inputs to prevent command injection
		if !isValidName(task.JobName) || !isValidName(dataset) {
			cancel()
			return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
		}
		//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
		cmd := exec.CommandContext(cmdCtx,
			w.config.DataManagerPath,
			"fetch",
			task.JobName,
			dataset,
		)

		output, err := cmd.CombinedOutput()
		cancel() // Clean up context

		if err != nil {
			return &errtypes.DataFetchError{
				Dataset: dataset,
				JobName: task.JobName,
				Err:     fmt.Errorf("command failed: %w, output: %s", err, output),
			}
		}

		logger.Info("dataset ready",
			"worker_id", w.id,
			"dataset", dataset)
		w.markDatasetFetched(dataset)
	}

	return nil
}
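
// runJob validates the job name, prepares the job's directories, and then
// executes the job either locally or inside a container.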
func (w *Worker) runJob(ctx context.Context, task *queue.Task) error {
	// Validate job name to prevent path traversal
	if err := container.ValidateJobName(task.JobName); err != nil {
		return &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     err,
		}
	}

	jobDir, outputDir, logFile, err := w.setupJobDirectories(task)
	if err != nil {
		return err
	}

	return w.executeJob(ctx, task, jobDir, outputDir, logFile)
}
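
// setupJobDirectories creates the pending and per-job directories under the
// configured base path and returns sanitized paths for the job directory,
// output directory, and log file.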
func (w *Worker) setupJobDirectories(task *queue.Task) (jobDir, outputDir, logFile string, err error) {
	jobPaths := config.NewJobPaths(w.config.BasePath)
	pendingDir := jobPaths.PendingPath()
	jobDir = filepath.Join(pendingDir, task.JobName)
	outputDir = filepath.Join(jobPaths.RunningPath(), task.JobName)
	logFile = filepath.Join(outputDir, "output.log")

	// Create pending directory
	if err := os.MkdirAll(pendingDir, 0750); err != nil {
		return "", "", "", &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "setup",
			Err:     fmt.Errorf("failed to create pending dir: %w", err),
		}
	}

	// Create job directory in pending
	if err := os.MkdirAll(jobDir, 0750); err != nil {
		return "", "", "", &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "setup",
			Err:     fmt.Errorf("failed to create job dir: %w", err),
		}
	}

	// Sanitize paths
	jobDir, err = container.SanitizePath(jobDir)
	if err != nil {
		return "", "", "", &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     err,
		}
	}
	outputDir, err = container.SanitizePath(outputDir)
	if err != nil {
		return "", "", "", &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     err,
		}
	}

	return jobDir, outputDir, logFile, nil
}
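
// executeJob creates the output directory, stages the job from pending/ to
// running/, and dispatches to local or container execution based on
// configuration.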
func (w *Worker) executeJob(ctx context.Context, task *queue.Task, jobDir, outputDir, logFile string) error {
	// Create output directory
	if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) {
		if err := os.MkdirAll(outputDir, 0750); err != nil {
			return "", fmt.Errorf("mkdir failed: %w", err)
		}
		return "", nil
	}); err != nil {
		return &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "setup",
			Err:     fmt.Errorf("failed to create output dir: %w", err),
		}
	}

	// Move job from pending to running
	stagingStart := time.Now()
	if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) {
		// Remove existing directory if it exists
		if _, err := os.Stat(outputDir); err == nil {
			if err := os.RemoveAll(outputDir); err != nil {
				return "", fmt.Errorf("remove existing failed: %w", err)
			}
		}
		if err := os.Rename(jobDir, outputDir); err != nil {
			return "", fmt.Errorf("rename failed: %w", err)
		}
		return "", nil
	}); err != nil {
		return &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "setup",
			Err:     fmt.Errorf("failed to move job: %w", err),
		}
	}
	stagingDuration := time.Since(stagingStart)

	// Execute job
	if w.config.LocalMode {
		return w.executeLocalJob(ctx, task, outputDir, logFile)
	}

	return w.executeContainerJob(ctx, task, outputDir, logFile, stagingDuration)
}
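
// executeLocalJob writes a simulated experiment script into the job's output
// directory and runs it directly with bash, appending output to the job log.
// It is used in local mode, where no container runtime is assumed.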
func (w *Worker) executeLocalJob(ctx context.Context, task *queue.Task, outputDir, logFile string) error {
	// Create experiment script
	scriptContent := `#!/bin/bash
set -e

echo "Starting experiment: ` + task.JobName + `"
echo "Task ID: ` + task.ID + `"
echo "Timestamp: $(date)"

# Simulate ML experiment
echo "Loading data..."
sleep 1

echo "Training model..."
sleep 2

echo "Evaluating model..."
sleep 1

# Generate results
ACCURACY=0.95
LOSS=0.05
EPOCHS=10

echo ""
echo "=== EXPERIMENT RESULTS ==="
echo "Accuracy: $ACCURACY"
echo "Loss: $LOSS"
echo "Epochs: $EPOCHS"
echo "Status: SUCCESS"
echo "========================="
echo "Experiment completed successfully!"
`

	scriptPath := filepath.Join(outputDir, "run.sh")
	if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil {
		return &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "execution",
			Err:     fmt.Errorf("failed to write script: %w", err),
		}
	}

	logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err != nil {
		w.logger.Warn("failed to open log file for local output", "path", logFile, "error", err)
		return &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "execution",
			Err:     fmt.Errorf("failed to open log file: %w", err),
		}
	}
	defer func() {
		if err := logFileHandle.Close(); err != nil {
			log.Printf("Warning: failed to close log file: %v", err)
		}
	}()

	// Execute the script directly
	localCmd := exec.CommandContext(ctx, "bash", scriptPath)
	localCmd.Stdout = logFileHandle
	localCmd.Stderr = logFileHandle

	w.logger.Info("executing local job",
		"job", task.JobName,
		"task_id", task.ID,
		"script", scriptPath)

	if err := localCmd.Run(); err != nil {
		return &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "execution",
			Err:     fmt.Errorf("execution failed: %w", err),
		}
	}

	return nil
}
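
// executeContainerJob runs the job inside a Podman container, records timing
// and process-I/O telemetry, and moves the job directory to failed/ or
// finished/ depending on the outcome.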
func (w *Worker) executeContainerJob(
	ctx context.Context,
	task *queue.Task,
	outputDir, logFile string,
	stagingDuration time.Duration,
) error {
	containerResults := w.config.ContainerResults
	if containerResults == "" {
		containerResults = config.DefaultContainerResults
	}

	containerWorkspace := w.config.ContainerWorkspace
	if containerWorkspace == "" {
		containerWorkspace = config.DefaultContainerWorkspace
	}

	jobPaths := config.NewJobPaths(w.config.BasePath)
	stagingStart := time.Now()

	podmanCfg := container.PodmanConfig{
		Image:              w.config.PodmanImage,
		Workspace:          filepath.Join(outputDir, "code"),
		Results:            filepath.Join(outputDir, "results"),
		ContainerWorkspace: containerWorkspace,
		ContainerResults:   containerResults,
		GPUAccess:          w.config.GPUAccess,
	}

	scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript)
	requirementsPath := filepath.Join(containerWorkspace, "requirements.txt")

	var extraArgs []string
	if task.Args != "" {
		extraArgs = strings.Fields(task.Args)
	}

	ioBefore, ioErr := telemetry.ReadProcessIO()
	podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, requirementsPath, extraArgs)
	logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err == nil {
		defer func() {
			if cerr := logFileHandle.Close(); cerr != nil {
				w.logger.Warn("failed to close log file", "path", logFile, "error", cerr)
			}
		}()
		podmanCmd.Stdout = logFileHandle
		podmanCmd.Stderr = logFileHandle
	} else {
		w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err)
	}

	w.logger.Info("executing podman job",
		"job", task.JobName,
		"image", w.config.PodmanImage,
		"workspace", podmanCfg.Workspace,
		"results", podmanCfg.Results)

	containerStart := time.Now()
	if err := podmanCmd.Run(); err != nil {
		containerDuration := time.Since(containerStart)
		// Move job to failed directory
		failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
		if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) {
			if err := os.Rename(outputDir, failedDir); err != nil {
				return "", fmt.Errorf("rename to failed failed: %w", err)
			}
			return "", nil
		}); moveErr != nil {
			w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr)
		}

		if ioErr == nil {
			if after, err := telemetry.ReadProcessIO(); err == nil {
				delta := telemetry.DiffIO(ioBefore, after)
				w.logger.Debug("worker io stats",
					"job", task.JobName,
					"read_bytes", delta.ReadBytes,
					"write_bytes", delta.WriteBytes)
			}
		}
		w.logger.Info("job timing (failure)",
			"job", task.JobName,
			"staging_ms", stagingDuration.Milliseconds(),
			"container_ms", containerDuration.Milliseconds(),
			"finalize_ms", 0,
			"total_ms", time.Since(stagingStart).Milliseconds(),
		)
		return fmt.Errorf("execution failed: %w", err)
	}
	containerDuration := time.Since(containerStart)

	finalizeStart := time.Now()
	// Move job to finished directory
	finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
	if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) {
		if err := os.Rename(outputDir, finishedDir); err != nil {
			return "", fmt.Errorf("rename to finished failed: %w", err)
		}
		return "", nil
	}); moveErr != nil {
		w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr)
	}
	finalizeDuration := time.Since(finalizeStart)
	totalDuration := time.Since(stagingStart)
	var ioDelta telemetry.IOStats
	if ioErr == nil {
		if after, err := telemetry.ReadProcessIO(); err == nil {
			ioDelta = telemetry.DiffIO(ioBefore, after)
		}
	}

	w.logger.Info("job timing",
		"job", task.JobName,
		"staging_ms", stagingDuration.Milliseconds(),
		"container_ms", containerDuration.Milliseconds(),
		"finalize_ms", finalizeDuration.Milliseconds(),
		"total_ms", totalDuration.Milliseconds(),
		"io_read_bytes", ioDelta.ReadBytes,
		"io_write_bytes", ioDelta.WriteBytes,
	)

	return nil
}
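
// parseDatasets extracts the comma-separated dataset list following a
// "--datasets" flag in the task's argument string. For example,
// parseDatasets("--epochs 10 --datasets mnist,cifar10") returns
// []string{"mnist", "cifar10"}; it returns nil when the flag is absent.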
func parseDatasets(args string) []string {
	if !strings.Contains(args, "--datasets") {
		return nil
	}

	parts := strings.Fields(args)
	for i, part := range parts {
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
	}

	return nil
}
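
// waitForTasks blocks until all running tasks finish or a five-minute
// shutdown timeout elapses, whichever comes first.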
func (w *Worker) waitForTasks() {
	timeout := time.After(5 * time.Minute)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-timeout:
			w.logger.Warn("shutdown timeout, force stopping",
				"running_tasks", w.runningCount())
			return
		case <-ticker.C:
			count := w.runningCount()
			if count == 0 {
				w.logger.Info("all tasks completed, shutting down")
				return
			}
			w.logger.Debug("waiting for tasks to complete",
				"remaining", count)
		}
	}
}

func (w *Worker) runningCount() int {
	w.runningMu.RLock()
	defer w.runningMu.RUnlock()
	return len(w.running)
}
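
// datasetIsFresh reports whether a dataset was fetched within the cache TTL
// and can therefore be skipped on the next request.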
func (w *Worker) datasetIsFresh(dataset string) bool {
	w.datasetCacheMu.RLock()
	defer w.datasetCacheMu.RUnlock()
	expires, ok := w.datasetCache[dataset]
	return ok && time.Now().Before(expires)
}

// markDatasetFetched records the dataset as freshly fetched for the
// configured cache TTL.
func (w *Worker) markDatasetFetched(dataset string) {
	expires := time.Now().Add(w.datasetCacheTTL)
	w.datasetCacheMu.Lock()
	w.datasetCache[dataset] = expires
	w.datasetCacheMu.Unlock()
}

// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
	stats := w.metrics.GetStats()
	stats["worker_id"] = w.id
	stats["max_workers"] = w.config.MaxWorkers
	return stats
}

// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}
	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}
	w.logger.Info("worker stopped", "worker_id", w.id)
}

// executeTaskWithLease executes a task under a queue lease, renewing the
// lease via heartbeats and retrying transient failures.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
	// Track task for graceful shutdown
	w.gracefulWait.Add(1)
	w.activeTasks.Store(task.ID, task)
	defer w.gracefulWait.Done()
	defer w.activeTasks.Delete(task.ID)

	// Create task-specific context with timeout
	taskCtx := logging.EnsureTrace(w.ctx)               // add trace + span if missing
	taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
	taskCtx = logging.CtxWithTask(taskCtx, task.ID)     // add task metadata

	taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
	defer taskCancel()

	// Register the cancel func so runningCount reflects in-flight tasks and
	// Start's MaxWorkers check actually limits concurrency.
	w.runningMu.Lock()
	w.running[task.ID] = taskCancel
	w.runningMu.Unlock()
	defer func() {
		w.runningMu.Lock()
		delete(w.running, task.ID)
		w.runningMu.Unlock()
	}()

	logger := w.logger.Job(taskCtx, task.JobName, task.ID)
	logger.Info("starting task",
		"worker_id", w.id,
		"datasets", task.Datasets,
		"priority", task.Priority)

	// Record task start
	w.metrics.RecordTaskStart()
	defer w.metrics.RecordTaskCompletion()

	// Check for context cancellation
	select {
	case <-taskCtx.Done():
		logger.Info("task cancelled before execution")
		return
	default:
	}

	// Parse datasets from task arguments
	if task.Datasets == nil {
		task.Datasets = parseDatasets(task.Args)
	}

	// Start heartbeat goroutine
	heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
	defer cancelHeartbeat()

	go w.heartbeatLoop(heartbeatCtx, task.ID)

	// Update task status
	task.Status = "running"
	now := time.Now()
	task.StartedAt = &now
	task.WorkerID = w.id

	if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
		logger.Error("failed to update task status", "error", err)
		w.metrics.RecordTaskFailure()
		return
	}

	if w.config.AutoFetchData && len(task.Datasets) > 0 {
		if err := w.fetchDatasets(taskCtx, task); err != nil {
			logger.Error("data fetch failed", "error", err)
			task.Status = "failed"
			task.Error = fmt.Sprintf("Data fetch failed: %v", err)
			endTime := time.Now()
			task.EndedAt = &endTime
			if err := w.queue.UpdateTask(task); err != nil {
				logger.Error("failed to update task status after data fetch failure", "error", err)
			}
			w.metrics.RecordTaskFailure()
			return
		}
	}

	// Execute job with panic recovery
	var execErr error
	func() {
		defer func() {
			if r := recover(); r != nil {
				execErr = fmt.Errorf("panic during execution: %v", r)
			}
		}()
		execErr = w.runJob(taskCtx, task)
	}()

	// Finalize task
	endTime := time.Now()
	task.EndedAt = &endTime

	if execErr != nil {
		task.Error = execErr.Error()

		// Check if transient error (network, timeout, etc)
		if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
			logger.Warn("task failed with transient error, will retry",
				"task_id", task.ID,
				"error", execErr,
				"retry_count", task.RetryCount)
			_ = w.queue.RetryTask(task)
		} else {
			task.Status = "failed"
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
		}
	} else {
		task.Status = "completed"

		// Read output file for completed tasks
		jobPaths := config.NewJobPaths(w.config.BasePath)
		outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName)
		logFile := filepath.Join(outputDir, "output.log")
		if outputBytes, err := os.ReadFile(logFile); err == nil {
			task.Output = string(outputBytes)
		}

		_ = w.queue.UpdateTaskWithMetrics(task, "final")
	}

	// Release lease
	_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}

// heartbeatLoop renews the task lease and the worker heartbeat until cancelled.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
	ticker := time.NewTicker(w.config.HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
				w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
				return
			}
			// Also update worker heartbeat
			_ = w.queue.Heartbeat(w.config.WorkerID)
		}
	}
}

// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())

	// Wait for active tasks with timeout
	done := make(chan struct{})
	go func() {
		w.gracefulWait.Wait()
		close(done)
	}()

	timeout := time.After(w.config.GracefulTimeout)
	select {
	case <-done:
		w.logger.Info("all tasks completed, shutdown successful")
	case <-timeout:
		w.logger.Warn("graceful shutdown timeout, releasing active leases")
		w.releaseAllLeases()
	}

	return w.queue.Close()
}

// releaseAllLeases releases the leases of all still-active tasks.
func (w *Worker) releaseAllLeases() {
	w.activeTasks.Range(func(key, _ interface{}) bool {
		taskID := key.(string)
		if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
			w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
		}
		return true
	})
}

// countActiveTasks returns the number of tasks currently tracked in activeTasks.
func (w *Worker) countActiveTasks() int {
	count := 0
	w.activeTasks.Range(func(_, _ interface{}) bool {
		count++
		return true
	})
	return count
}
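
// isTransientError reports whether an error looks like a transient network or
// resource failure worth retrying. The check is a heuristic over the error
// message text.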
func isTransientError(err error) bool {
	if err == nil {
		return false
	}
	// Check if error is transient (network, timeout, resource unavailable, etc)
	errStr := err.Error()
	transientIndicators := []string{
		"connection refused",
		"timeout",
		"temporary failure",
		"resource temporarily unavailable",
		"no such host",
		"network unreachable",
	}
	for _, indicator := range transientIndicators {
		if strings.Contains(strings.ToLower(errStr), indicator) {
			return true
		}
	}
	return false
}

func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	// Parse authentication flags
	authFlags := auth.ParseAuthFlags()
	if err := auth.ValidateFlags(authFlags); err != nil {
		log.Fatalf("Authentication flag error: %v", err)
	}

	// Get API key from various sources
	apiKey := auth.GetAPIKeyFromSources(authFlags)

	// Load configuration
	configPath := "config-local.yaml"
	if authFlags.ConfigFile != "" {
		configPath = authFlags.ConfigFile
	}

	resolvedConfig, err := config.ResolveConfigPath(configPath)
	if err != nil {
		log.Fatalf("%v", err)
	}

	cfg, err := LoadConfig(resolvedConfig)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Validate authentication configuration
	if err := cfg.Auth.ValidateAuthConfig(); err != nil {
		log.Fatalf("Invalid authentication configuration: %v", err)
	}

	// Validate configuration
	if err := cfg.Validate(); err != nil {
		log.Fatalf("Invalid configuration: %v", err)
	}

	// Test authentication if enabled
	if cfg.Auth.Enabled && apiKey != "" {
		user, err := cfg.Auth.ValidateAPIKey(apiKey)
		if err != nil {
			log.Fatalf("Authentication failed: %v", err)
		}
		log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin)
	} else if cfg.Auth.Enabled {
		log.Fatal("Authentication required but no API key provided")
	}

	worker, err := NewWorker(cfg, apiKey)
	if err != nil {
		log.Fatalf("Failed to create worker: %v", err)
	}

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	go worker.Start()

	sig := <-sigChan
	log.Printf("Received signal: %v", sig)

	// Use graceful shutdown
	if err := worker.Shutdown(); err != nil {
		log.Printf("Graceful shutdown error: %v", err)
		worker.Stop() // Fallback to force stop
	} else {
		log.Println("Worker shut down gracefully")
	}
}