// Change summary:
// - Refactor worker configuration management
// - Improve container executor lifecycle handling
// - Update runloop and worker core logic
// - Enhance scheduler service template generation
// - Remove obsolete 'scheduler' symlink/directory
// Package worker provides the ML task worker implementation
|
|
package worker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"math/rand"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/metrics"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/resources"
|
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/execution"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
|
|
)
|
|
|
|
// Worker represents an ML task worker with composed dependencies.
// It is constructed elsewhere in the package; Start/Stop/Shutdown manage its
// lifecycle, and the remaining methods expose queue, integrity, and execution
// helpers built on top of the composed components.
type Worker struct {
	// Jupyter manages Jupyter sessions; exposed to plugins via
	// GetJupyterManager (plugins.TaskRunner interface).
	Jupyter plugins.JupyterManager
	// QueueClient is the task queue backend, used for peeking tasks and
	// persisting prewarm state. May be nil (see PrewarmNextOnce).
	QueueClient queue.Backend
	// Config carries worker settings (paths, mode, concurrency, scheduler
	// heartbeat interval, provenance policy, prewarm toggle, ...).
	Config *Config
	// Logger is the structured worker logger; Start installs it as the
	// process-wide slog default.
	Logger *logging.Logger
	// RunLoop drives the main task-processing loop and tracks running jobs.
	RunLoop *lifecycle.RunLoop
	// Runner executes job tasks (see RunJob).
	Runner *executor.JobRunner
	// Metrics aggregates worker statistics (see GetMetrics).
	Metrics *metrics.Metrics
	// metricsSrv is the optional metrics HTTP server, shut down with a
	// 5-second timeout in Stop/Shutdown.
	metricsSrv *http.Server
	// Health records heartbeats and answers IsHealthy.
	Health *lifecycle.HealthMonitor
	// Resources manages compute resources for the worker.
	Resources *resources.Manager
	// ID uniquely identifies this worker in logs and scheduler messages.
	ID string
	// gpuDetectionInfo holds the result of GPU detection performed at setup.
	gpuDetectionInfo GPUDetectionInfo
	schedulerConn    *scheduler.SchedulerConn // For distributed mode
	// ctx/cancel bound the lifetime of background goroutines (e.g. the
	// heartbeat loop); created in Start, cancelled in Stop.
	ctx    context.Context
	cancel context.CancelFunc
}
|
|
|
|
// Start begins the worker's main processing loop.
|
|
func (w *Worker) Start() {
|
|
w.Logger.Info("worker starting",
|
|
"worker_id", w.ID,
|
|
"max_concurrent", w.Config.MaxWorkers,
|
|
"mode", w.Config.Mode,
|
|
)
|
|
slog.SetDefault(w.Logger.Logger)
|
|
|
|
w.ctx, w.cancel = context.WithCancel(context.Background())
|
|
w.Health.RecordHeartbeat()
|
|
|
|
// Start heartbeat loop for distributed mode
|
|
if w.Config.Mode == "distributed" && w.schedulerConn != nil {
|
|
go w.heartbeatLoop()
|
|
}
|
|
|
|
w.RunLoop.Start()
|
|
}
|
|
|
|
// heartbeatLoop sends periodic heartbeats with slot status to scheduler
|
|
func (w *Worker) heartbeatLoop() {
|
|
// Use configured interval or default to 10s
|
|
intervalSecs := w.Config.Scheduler.HeartbeatIntervalSecs
|
|
if intervalSecs == 0 {
|
|
intervalSecs = 10
|
|
}
|
|
|
|
// Add jitter (0-5s) to prevent thundering herd
|
|
jitter := time.Duration(rand.Intn(5)) * time.Second
|
|
interval := time.Duration(intervalSecs)*time.Second + jitter
|
|
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-w.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
w.Health.RecordHeartbeat()
|
|
if w.schedulerConn != nil {
|
|
slots := scheduler.SlotStatus{
|
|
BatchTotal: w.Config.MaxWorkers,
|
|
BatchInUse: w.RunLoop.RunningCount(),
|
|
}
|
|
w.schedulerConn.Send(scheduler.Message{
|
|
Type: scheduler.MsgHeartbeat,
|
|
Payload: mustMarshal(scheduler.HeartbeatPayload{
|
|
WorkerID: w.ID,
|
|
Slots: slots,
|
|
}),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop gracefully shuts down the worker immediately.
|
|
func (w *Worker) Stop() {
|
|
w.Logger.Info("worker stopping", "worker_id", w.ID)
|
|
|
|
if w.cancel != nil {
|
|
w.cancel()
|
|
}
|
|
|
|
w.RunLoop.Stop()
|
|
|
|
if w.metricsSrv != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
|
w.Logger.Warn("metrics server shutdown error", "error", err)
|
|
}
|
|
}
|
|
|
|
w.Logger.Info("worker stopped", "worker_id", w.ID)
|
|
}
|
|
|
|
// Shutdown performs a graceful shutdown with timeout.
|
|
func (w *Worker) Shutdown() error {
|
|
w.Logger.Info("starting graceful shutdown", "worker_id", w.ID)
|
|
|
|
w.RunLoop.Stop()
|
|
|
|
if w.metricsSrv != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
|
w.Logger.Warn("metrics server shutdown error", "error", err)
|
|
}
|
|
}
|
|
|
|
w.Logger.Info("worker shut down gracefully", "worker_id", w.ID)
|
|
return nil
|
|
}
|
|
|
|
// IsHealthy returns true if the worker is healthy.
|
|
func (w *Worker) IsHealthy() bool {
|
|
return w.Health.IsHealthy(5 * time.Minute)
|
|
}
|
|
|
|
// GetMetrics returns current worker metrics.
|
|
func (w *Worker) GetMetrics() map[string]any {
|
|
stats := w.Metrics.GetStats()
|
|
stats["worker_id"] = w.ID
|
|
stats["max_workers"] = w.Config.MaxWorkers
|
|
stats["healthy"] = w.IsHealthy()
|
|
return stats
|
|
}
|
|
|
|
// GetID returns the worker ID.
|
|
func (w *Worker) GetID() string {
|
|
return w.ID
|
|
}
|
|
|
|
// SelectDependencyManifest re-exports the executor function for API helpers.
|
|
// It detects the dependency manifest file in the given directory.
|
|
func SelectDependencyManifest(filesPath string) (string, error) {
|
|
return executor.SelectDependencyManifest(filesPath)
|
|
}
|
|
|
|
// DirOverallSHA256Hex re-exports the integrity function for test compatibility.
|
|
func DirOverallSHA256Hex(root string) (string, error) {
|
|
return integrity.DirOverallSHA256Hex(root)
|
|
}
|
|
|
|
// NormalizeSHA256ChecksumHex re-exports the integrity function for test compatibility.
|
|
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
|
|
return integrity.NormalizeSHA256ChecksumHex(checksum)
|
|
}
|
|
|
|
// StageSnapshot re-exports the execution function for test compatibility.
|
|
func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error {
|
|
return execution.StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir)
|
|
}
|
|
|
|
// StageSnapshotFromPath re-exports the execution function for test compatibility.
|
|
func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error {
|
|
return execution.StageSnapshotFromPath(basePath, taskID, srcPath, jobDir)
|
|
}
|
|
|
|
// ComputeTaskProvenance computes provenance information for a task.
|
|
// This re-exports the integrity function for test compatibility.
|
|
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
|
|
pc := integrity.NewProvenanceCalculator(basePath)
|
|
return pc.ComputeProvenance(task)
|
|
}
|
|
|
|
// VerifyDatasetSpecs verifies dataset specifications for this task.
|
|
// This is a test compatibility method that wraps the integrity package.
|
|
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
|
dataDir := w.Config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = "/tmp/data"
|
|
}
|
|
verifier := integrity.NewDatasetVerifier(dataDir)
|
|
return verifier.VerifyDatasetSpecs(task)
|
|
}
|
|
|
|
// EnforceTaskProvenance enforces provenance requirements for a task.
// It validates and/or populates provenance metadata based on the ProvenanceBestEffort config.
// In strict mode (ProvenanceBestEffort=false), it returns an error if metadata doesn't match computed values.
// In best-effort mode (ProvenanceBestEffort=true), it populates missing metadata fields.
//
// Three provenance artifacts are checked, in order:
//  1. experiment_manifest_overall_sha — hash of the commit directory under BasePath
//  2. deps_manifest_sha256 — hash of the (possibly auto-detected) dependency manifest
//  3. snapshot_sha256 — hash of the snapshot directory (only when task.SnapshotID is set)
//
// The ctx parameter is currently unused. Note that task.Metadata may be
// mutated in best-effort mode.
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}

	// Resolve working directories, falling back to temp-dir defaults when
	// unconfigured (mirrors other methods on Worker).
	basePath := w.Config.BasePath
	if basePath == "" {
		basePath = os.TempDir()
	}
	dataDir := w.Config.DataDir
	if dataDir == "" {
		dataDir = filepath.Join(basePath, "data")
	}

	bestEffort := w.Config.ProvenanceBestEffort

	// Get commit_id from metadata. This is required in both modes: without
	// it the experiment directory cannot be located at all.
	// (Reading from a nil map is safe and yields "".)
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id in task metadata")
	}

	// Compute the experiment manifest SHA over the commit directory.
	expPath := filepath.Join(basePath, commitID)
	manifestSHA, err := integrity.DirOverallSHA256Hex(expPath)
	if err != nil {
		if !bestEffort {
			return fmt.Errorf("failed to compute experiment manifest SHA: %w", err)
		}
		// In best-effort mode, we'll use whatever is provided or skip;
		// an empty SHA may be written into the metadata below.
		manifestSHA = ""
	}

	// Handle experiment_manifest_overall_sha: strict mode requires it to be
	// present and to match; best-effort mode populates it when missing.
	expectedManifestSHA := task.Metadata["experiment_manifest_overall_sha"]
	if expectedManifestSHA == "" {
		if !bestEffort {
			return fmt.Errorf("missing experiment_manifest_overall_sha in task metadata")
		}
		// Populate in best-effort mode (allocating the map if needed).
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		task.Metadata["experiment_manifest_overall_sha"] = manifestSHA
	} else if !bestEffort && expectedManifestSHA != manifestSHA {
		return fmt.Errorf("experiment manifest SHA mismatch: expected %s, got %s", expectedManifestSHA, manifestSHA)
	}

	// Handle deps_manifest_sha256 - auto-detect the manifest file name if
	// not provided. Detection failure is deliberately ignored: an empty name
	// skips the whole deps check below.
	filesPath := filepath.Join(expPath, "files")
	depsManifestName := task.Metadata["deps_manifest_name"]
	if depsManifestName == "" {
		// Auto-detect manifest file
		depsManifestName, _ = executor.SelectDependencyManifest(filesPath)
	}
	if depsManifestName != "" {
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		// Record the (possibly auto-detected) name back into metadata.
		task.Metadata["deps_manifest_name"] = depsManifestName
		depsPath := filepath.Join(filesPath, depsManifestName)
		depsSHA, err := integrity.FileSHA256Hex(depsPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute deps manifest SHA: %w", err)
			}
			depsSHA = ""
		}

		expectedDepsSHA := task.Metadata["deps_manifest_sha256"]
		if expectedDepsSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing deps_manifest_sha256 in task metadata")
			}
			task.Metadata["deps_manifest_sha256"] = depsSHA
		} else if !bestEffort && expectedDepsSHA != depsSHA {
			return fmt.Errorf("deps manifest SHA mismatch: expected %s, got %s", expectedDepsSHA, depsSHA)
		}
	}

	// Handle snapshot_sha256 if SnapshotID is set.
	if task.SnapshotID != "" {
		snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
		snapSHA, err := integrity.DirOverallSHA256Hex(snapPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute snapshot SHA: %w", err)
			}
			snapSHA = ""
		}

		// Normalize the expected value (e.g. strip a "sha256:" prefix).
		// A normalization error is treated the same as a missing value.
		expectedSnapSHA, _ := integrity.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
		if expectedSnapSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing snapshot_sha256 in task metadata")
			}
			if task.Metadata == nil {
				task.Metadata = map[string]string{}
			}
			task.Metadata["snapshot_sha256"] = snapSHA
		} else if !bestEffort && expectedSnapSHA != snapSHA {
			return fmt.Errorf("snapshot SHA mismatch: expected %s, got %s", expectedSnapSHA, snapSHA)
		}
	}

	return nil
}
|
|
|
|
// VerifySnapshot verifies snapshot integrity for this task.
|
|
// It computes the SHA256 of the snapshot directory and compares with task metadata.
|
|
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
|
if task.SnapshotID == "" {
|
|
return nil // No snapshot to verify
|
|
}
|
|
|
|
dataDir := w.Config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = os.TempDir() + "/data"
|
|
}
|
|
|
|
// Get expected checksum from metadata
|
|
expectedChecksum, ok := task.Metadata["snapshot_sha256"]
|
|
if !ok || expectedChecksum == "" {
|
|
return fmt.Errorf("missing snapshot_sha256 in task metadata")
|
|
}
|
|
|
|
// Normalize the checksum (remove sha256: prefix if present)
|
|
expectedChecksum, err := integrity.NormalizeSHA256ChecksumHex(expectedChecksum)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid snapshot_sha256 format: %w", err)
|
|
}
|
|
|
|
// Compute actual checksum of snapshot directory
|
|
snapshotDir := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
actualChecksum, err := integrity.DirOverallSHA256Hex(snapshotDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to compute snapshot hash: %w", err)
|
|
}
|
|
|
|
// Compare checksums
|
|
if actualChecksum != expectedChecksum {
|
|
return fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetJupyterManager returns the Jupyter manager for plugin use
|
|
// This implements the plugins.TaskRunner interface
|
|
func (w *Worker) GetJupyterManager() plugins.JupyterManager {
|
|
return w.Jupyter
|
|
}
|
|
|
|
// PrewarmNextOnce prewarms the next task in queue.
|
|
// It fetches the next task, verifies its snapshot, and stages it to the prewarm directory.
|
|
// Returns true if prewarming was performed, false if disabled or queue empty.
|
|
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|
// Check if prewarming is enabled
|
|
if !w.Config.PrewarmEnabled {
|
|
return false, nil
|
|
}
|
|
|
|
// Get base path and data directory
|
|
basePath := w.Config.BasePath
|
|
if basePath == "" {
|
|
basePath = "/tmp"
|
|
}
|
|
dataDir := w.Config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = filepath.Join(basePath, "data")
|
|
}
|
|
|
|
// Create prewarm directory
|
|
prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots")
|
|
if err := os.MkdirAll(prewarmDir, 0o750); err != nil {
|
|
return false, fmt.Errorf("failed to create prewarm directory: %w", err)
|
|
}
|
|
|
|
// Try to get next task from queue client if available (peek, don't lease)
|
|
if w.QueueClient != nil {
|
|
task, err := w.QueueClient.PeekNextTask()
|
|
if err != nil {
|
|
// Queue empty - check if we have existing prewarm state
|
|
// Return false but preserve any existing state (don't delete)
|
|
state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID)
|
|
if state != nil {
|
|
// We have existing state, return true to indicate prewarm is active
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
if task != nil && task.SnapshotID != "" {
|
|
// Resolve snapshot path using SHA from metadata if available
|
|
snapshotSHA := task.Metadata["snapshot_sha256"]
|
|
if snapshotSHA != "" {
|
|
snapshotSHA, _ = integrity.NormalizeSHA256ChecksumHex(snapshotSHA)
|
|
}
|
|
|
|
var srcDir string
|
|
if snapshotSHA != "" {
|
|
// Check if snapshot exists in SHA cache directory
|
|
shaDir := filepath.Join(dataDir, "snapshots", "sha256", snapshotSHA)
|
|
if info, err := os.Stat(shaDir); err == nil && info.IsDir() {
|
|
srcDir = shaDir
|
|
}
|
|
}
|
|
|
|
// Fall back to direct snapshot path if SHA directory doesn't exist
|
|
if srcDir == "" {
|
|
srcDir = filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
}
|
|
|
|
dstDir := filepath.Join(prewarmDir, task.ID)
|
|
if err := execution.CopyDir(srcDir, dstDir); err != nil {
|
|
return false, fmt.Errorf("failed to stage snapshot: %w", err)
|
|
}
|
|
|
|
// Store prewarm state in queue backend
|
|
if w.QueueClient != nil {
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
state := queue.PrewarmState{
|
|
WorkerID: w.ID,
|
|
TaskID: task.ID,
|
|
SnapshotID: task.SnapshotID,
|
|
StartedAt: now,
|
|
UpdatedAt: now,
|
|
Phase: "staged",
|
|
}
|
|
_ = w.QueueClient.SetWorkerPrewarmState(state)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
|
|
if w.RunLoop != nil {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// RunJob runs a job task.
|
|
// It uses the JobRunner to execute the job and write the run manifest.
|
|
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
|
|
if w.Runner == nil {
|
|
return fmt.Errorf("job runner not configured")
|
|
}
|
|
|
|
basePath := w.Config.BasePath
|
|
if basePath == "" {
|
|
basePath = "/tmp"
|
|
}
|
|
|
|
// Determine execution mode
|
|
mode := executor.ModeAuto
|
|
if w.Config.LocalMode {
|
|
mode = executor.ModeLocal
|
|
}
|
|
|
|
// Create minimal GPU environment (empty for now)
|
|
gpuEnv := interfaces.ExecutionEnv{}
|
|
|
|
// Run the job
|
|
return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
|
|
}
|
|
|
|
// mustMarshal serializes v to JSON, panicking on failure.
// It is used for scheduler message payloads whose types are known to be
// marshalable; per Go convention for Must-prefixed helpers, a marshal error
// indicates a programmer bug. The previous version silently discarded the
// error and could send an empty payload.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(fmt.Sprintf("mustMarshal: %v", err))
	}
	return b
}
|