fetch_ml/internal/worker/worker.go
Jeremie Fraeys 3187ff26ea
refactor: complete maintainability phases 1-9 and fix all tests
Test fixes (all 41 test packages now pass):
- Fix ComputeTaskProvenance - add dataset_specs JSON output
- Fix EnforceTaskProvenance - populate all metadata fields in best-effort mode
- Fix PrewarmNextOnce - preserve prewarm state when queue empty
- Fix RunManifest directory creation in SetupJobDirectories
- Add ManifestWriter to test worker (simpleManifestWriter)
- Fix worker ID mismatch (use cfg.WorkerID)
- Fix WebSocket binary protocol responses
- Implement all WebSocket handlers: QueueJob, QueueJobWithSnapshot, StatusRequest,
  CancelJob, Prune, ValidateRequest (with run manifest validation), LogMetric,
  GetExperiment, DatasetList/Register/Info/Search

Maintainability phases completed:
- Phases 1-6: Domain types, error system, config boundaries, worker/API/queue splits
- Phase 7: TUI cleanup - reorganize model package (jobs.go, messages.go, styles.go, keys.go)
- Phase 8: MLServer unification - consolidate worker + TUI into internal/network/mlserver.go
- Phase 9: CI enforcement - add scripts/ci-checks.sh with 5 checks:
  * No internal/ -> cmd/ imports
  * domain/ has zero internal imports
  * File size limit (500 lines, strictly enforced)
  * No circular imports
  * Package naming conventions

Documentation:
- Add docs/src/file-naming-conventions.md
- Add make ci-checks target

Lines changed: +756/-36 (WebSocket fixes), +518/-320 (TUI), +263/-20 (Phase 8-9)
2026-02-17 20:32:14 -05:00


// Package worker provides the ML task worker implementation
package worker

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
	"github.com/jfraeys/fetch_ml/internal/worker/execution"
	"github.com/jfraeys/fetch_ml/internal/worker/executor"
	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
	"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
	"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)

// JupyterManager defines the interface for Jupyter service management.
type JupyterManager interface {
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	StopService(ctx context.Context, serviceID string) error
	RemoveService(ctx context.Context, serviceID string, purge bool) error
	RestoreWorkspace(ctx context.Context, name string) (string, error)
	ListServices() []*jupyter.JupyterService
	ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
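
// A minimal no-op stub satisfying JupyterManager can be handy in tests
// (a sketch; the repository's actual test double may differ):
//
//	type stubJupyter struct{}
//
//	func (stubJupyter) StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
//		return &jupyter.JupyterService{}, nil
//	}
//	func (stubJupyter) StopService(ctx context.Context, serviceID string) error               { return nil }
//	func (stubJupyter) RemoveService(ctx context.Context, serviceID string, purge bool) error { return nil }
//	func (stubJupyter) RestoreWorkspace(ctx context.Context, name string) (string, error)     { return name, nil }
//	func (stubJupyter) ListServices() []*jupyter.JupyterService                               { return nil }
//	func (stubJupyter) ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
//		return nil, nil
//	}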

// _isValidName reports whether input is non-empty and shorter than 256
// characters. (It does not inspect individual characters.)
func _isValidName(input string) bool {
	return len(input) > 0 && len(input) < 256
}

// MLServer is an alias for network.MLServer for backward compatibility.
type MLServer = network.MLServer

// NewMLServer creates a new ML server connection.
// The cfg parameter is currently unused; connection parameters default to
// zero values.
func NewMLServer(cfg *Config) (*MLServer, error) {
	return network.NewMLServer("", "", "", 0, "")
}

// Worker represents an ML task worker with composed dependencies.
type Worker struct {
	id     string
	config *Config
	logger *logging.Logger

	// Composed dependencies from previous phases
	runLoop    *lifecycle.RunLoop
	runner     *executor.JobRunner
	metrics    *metrics.Metrics
	metricsSrv *http.Server
	health     *lifecycle.HealthMonitor
	resources  *resources.Manager

	// Legacy fields for backward compatibility during migration
	jupyter     JupyterManager
	queueClient queue.Backend // Stored for prewarming access
}

// Start begins the worker's main processing loop.
func (w *Worker) Start() {
	w.logger.Info("worker starting",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers)
	w.health.RecordHeartbeat()
	w.runLoop.Start()
}

// Stop shuts down the worker, stopping the run loop and the metrics server.
func (w *Worker) Stop() {
	w.logger.Info("worker stopping", "worker_id", w.id)
	w.runLoop.Stop()
	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics server shutdown error", "error", err)
		}
	}
	w.logger.Info("worker stopped", "worker_id", w.id)
}

// Shutdown performs a graceful shutdown with timeout.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "worker_id", w.id)
	w.runLoop.Stop()
	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics server shutdown error", "error", err)
		}
	}
	w.logger.Info("worker shut down gracefully", "worker_id", w.id)
	return nil
}

// IsHealthy returns true if the worker is healthy.
func (w *Worker) IsHealthy() bool {
	return w.health.IsHealthy(5 * time.Minute)
}

// GetMetrics returns current worker metrics.
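// In addition to whatever the metrics package reports, the result always
// carries the keys set below, e.g. (illustrative use):
//
//	m := w.GetMetrics()
//	fmt.Println(m["worker_id"], m["max_workers"], m["healthy"])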
func (w *Worker) GetMetrics() map[string]any {
	stats := w.metrics.GetStats()
	stats["worker_id"] = w.id
	stats["max_workers"] = w.config.MaxWorkers
	stats["healthy"] = w.IsHealthy()
	return stats
}

// GetID returns the worker ID.
func (w *Worker) GetID() string {
	return w.id
}

// runningCount returns the number of currently running tasks.
func (w *Worker) runningCount() int {
	if w.runLoop == nil {
		return 0
	}
	return w.runLoop.RunningCount()
}

// _getGPUDetector builds a GPU detector from the worker configuration.
func (w *Worker) _getGPUDetector() GPUDetector {
	factory := &GPUDetectorFactory{}
	return factory.CreateDetector(w.config)
}

// SelectDependencyManifest re-exports the executor function for API helpers.
// It detects the dependency manifest file in the given directory.
func SelectDependencyManifest(filesPath string) (string, error) {
	return executor.SelectDependencyManifest(filesPath)
}

// DirOverallSHA256Hex re-exports the integrity function for test compatibility.
func DirOverallSHA256Hex(root string) (string, error) {
	return integrity.DirOverallSHA256Hex(root)
}

// NormalizeSHA256ChecksumHex re-exports the integrity function for test compatibility.
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
	return integrity.NormalizeSHA256ChecksumHex(checksum)
}

// StageSnapshot re-exports the execution function for test compatibility.
func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error {
	return execution.StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir)
}

// StageSnapshotFromPath re-exports the execution function for test compatibility.
func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error {
	return execution.StageSnapshotFromPath(basePath, taskID, srcPath, jobDir)
}

// ComputeTaskProvenance computes provenance information for a task.
// This re-exports the integrity function for test compatibility.
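// Illustrative use (the "dataset_specs" JSON entry comes from the commit
// notes above and is an assumption about the returned keys; the base path
// is hypothetical):
//
//	prov, err := ComputeTaskProvenance("/srv/fetchml", task)
//	if err == nil {
//		fmt.Println(prov["dataset_specs"]) // JSON-encoded dataset specs
//	}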
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	pc := integrity.NewProvenanceCalculator(basePath)
	return pc.ComputeProvenance(task)
}

// VerifyDatasetSpecs verifies dataset specifications for this task.
// This is a test compatibility method that wraps the integrity package.
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	dataDir := w.config.DataDir
	if dataDir == "" {
		dataDir = "/tmp/data"
	}
	verifier := integrity.NewDatasetVerifier(dataDir)
	return verifier.VerifyDatasetSpecs(task)
}

// EnforceTaskProvenance enforces provenance requirements for a task.
// It validates and/or populates provenance metadata based on the ProvenanceBestEffort config.
// In strict mode (ProvenanceBestEffort=false), it returns an error if metadata doesn't match computed values.
// In best-effort mode (ProvenanceBestEffort=true), it populates missing metadata fields.
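// A task that passes strict mode carries metadata along these lines
// (illustrative values; the SHAs must match what this function computes):
//
//	task.Metadata = map[string]string{
//		"commit_id":                       "abc123",
//		"experiment_manifest_overall_sha": "<hex of DirOverallSHA256Hex(basePath/abc123)>",
//		"deps_manifest_name":              "requirements.txt",
//		"deps_manifest_sha256":            "<hex of FileSHA256Hex(files/requirements.txt)>",
//	}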
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	basePath := w.config.BasePath
	if basePath == "" {
		basePath = "/tmp"
	}
	dataDir := w.config.DataDir
	if dataDir == "" {
		dataDir = filepath.Join(basePath, "data")
	}
	bestEffort := w.config.ProvenanceBestEffort

	// Get commit_id from metadata
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id in task metadata")
	}

	// Compute and verify experiment manifest SHA
	expPath := filepath.Join(basePath, commitID)
	manifestSHA, err := integrity.DirOverallSHA256Hex(expPath)
	if err != nil {
		if !bestEffort {
			return fmt.Errorf("failed to compute experiment manifest SHA: %w", err)
		}
		// In best-effort mode, we'll use whatever is provided or skip
		manifestSHA = ""
	}

	// Handle experiment_manifest_overall_sha
	expectedManifestSHA := task.Metadata["experiment_manifest_overall_sha"]
	if expectedManifestSHA == "" {
		if !bestEffort {
			return fmt.Errorf("missing experiment_manifest_overall_sha in task metadata")
		}
		// Populate in best-effort mode
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		task.Metadata["experiment_manifest_overall_sha"] = manifestSHA
	} else if !bestEffort && expectedManifestSHA != manifestSHA {
		return fmt.Errorf("experiment manifest SHA mismatch: expected %s, got %s", expectedManifestSHA, manifestSHA)
	}

	// Handle deps_manifest_sha256 - auto-detect the manifest if not provided
	filesPath := filepath.Join(expPath, "files")
	depsManifestName := task.Metadata["deps_manifest_name"]
	if depsManifestName == "" {
		// Auto-detect manifest file
		depsManifestName, _ = executor.SelectDependencyManifest(filesPath)
	}
	if depsManifestName != "" {
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		task.Metadata["deps_manifest_name"] = depsManifestName
		depsPath := filepath.Join(filesPath, depsManifestName)
		depsSHA, err := integrity.FileSHA256Hex(depsPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute deps manifest SHA: %w", err)
			}
			depsSHA = ""
		}
		expectedDepsSHA := task.Metadata["deps_manifest_sha256"]
		if expectedDepsSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing deps_manifest_sha256 in task metadata")
			}
			task.Metadata["deps_manifest_sha256"] = depsSHA
		} else if !bestEffort && expectedDepsSHA != depsSHA {
			return fmt.Errorf("deps manifest SHA mismatch: expected %s, got %s", expectedDepsSHA, depsSHA)
		}
	}

	// Handle snapshot_sha256 if SnapshotID is set
	if task.SnapshotID != "" {
		snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
		snapSHA, err := integrity.DirOverallSHA256Hex(snapPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute snapshot SHA: %w", err)
			}
			snapSHA = ""
		}
		expectedSnapSHA, _ := integrity.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
		if expectedSnapSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing snapshot_sha256 in task metadata")
			}
			if task.Metadata == nil {
				task.Metadata = map[string]string{}
			}
			task.Metadata["snapshot_sha256"] = snapSHA
		} else if !bestEffort && expectedSnapSHA != snapSHA {
			return fmt.Errorf("snapshot SHA mismatch: expected %s, got %s", expectedSnapSHA, snapSHA)
		}
	}
	return nil
}

// VerifySnapshot verifies snapshot integrity for this task.
// It computes the SHA256 of the snapshot directory and compares with task metadata.
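// Both bare hex and "sha256:"-prefixed values are accepted in the metadata,
// e.g. (illustrative, truncated digest):
//
//	task.Metadata["snapshot_sha256"] = "sha256:9f86d081884c7d65..." // normalized to bare hex before comparing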
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
	if task.SnapshotID == "" {
		return nil // No snapshot to verify
	}
	dataDir := w.config.DataDir
	if dataDir == "" {
		dataDir = "/tmp/data"
	}
	// Get the expected checksum from metadata
	expectedChecksum, ok := task.Metadata["snapshot_sha256"]
	if !ok || expectedChecksum == "" {
		return fmt.Errorf("missing snapshot_sha256 in task metadata")
	}
	// Normalize the checksum (remove the "sha256:" prefix if present)
	expectedChecksum, err := integrity.NormalizeSHA256ChecksumHex(expectedChecksum)
	if err != nil {
		return fmt.Errorf("invalid snapshot_sha256 format: %w", err)
	}
	// Compute the actual checksum of the snapshot directory
	snapshotDir := filepath.Join(dataDir, "snapshots", task.SnapshotID)
	actualChecksum, err := integrity.DirOverallSHA256Hex(snapshotDir)
	if err != nil {
		return fmt.Errorf("failed to compute snapshot hash: %w", err)
	}
	// Compare checksums
	if actualChecksum != expectedChecksum {
		return fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
	}
	return nil
}

// RunJupyterTask runs a Jupyter-related task.
// It handles start, stop, remove, restore, and list_packages actions.
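// The action and its arguments are carried entirely in task metadata, e.g.
// (illustrative task):
//
//	task := &queue.Task{
//		JobName:  "jupyter-analysis",
//		Metadata: map[string]string{"jupyter_action": "start", "jupyter_name": "analysis"},
//	}
//	out, err := w.RunJupyterTask(ctx, task) // out: {"type":"start","service":{...}}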
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	if w.jupyter == nil {
		return nil, fmt.Errorf("jupyter manager not configured")
	}
	action := task.Metadata["jupyter_action"]
	if action == "" {
		action = "start" // Default action
	}
	switch action {
	case "start":
		name := task.Metadata["jupyter_name"]
		if name == "" {
			name = task.Metadata["jupyter_workspace"]
		}
		if name == "" && strings.HasPrefix(task.JobName, "jupyter-") {
			// Extract from JobName when it follows the "jupyter-<name>" convention
			name = strings.TrimPrefix(task.JobName, "jupyter-")
		}
		if name == "" {
			return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
		}
		req := &jupyter.StartRequest{Name: name}
		service, err := w.jupyter.StartService(ctx, req)
		if err != nil {
			return nil, fmt.Errorf("failed to start jupyter service: %w", err)
		}
		output := map[string]any{
			"type":    "start",
			"service": service,
		}
		return json.Marshal(output)
	case "stop":
		serviceID := task.Metadata["jupyter_service_id"]
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
		}
		if err := w.jupyter.StopService(ctx, serviceID); err != nil {
			return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
		}
		return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
	case "remove":
		serviceID := task.Metadata["jupyter_service_id"]
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
		}
		purge := task.Metadata["jupyter_purge"] == "true"
		if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
			return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
		}
		return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
	case "restore":
		name := task.Metadata["jupyter_name"]
		if name == "" {
			name = task.Metadata["jupyter_workspace"]
		}
		if name == "" {
			return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
		}
		serviceID, err := w.jupyter.RestoreWorkspace(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
		}
		return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID})
	case "list_packages":
		serviceName := task.Metadata["jupyter_name"]
		if serviceName == "" && strings.HasPrefix(task.JobName, "jupyter-packages-") {
			// Extract from JobName when it follows the "jupyter-packages-<name>" convention
			serviceName = strings.TrimPrefix(task.JobName, "jupyter-packages-")
		}
		if serviceName == "" {
			return nil, fmt.Errorf("missing jupyter_name in task metadata")
		}
		packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName)
		if err != nil {
			return nil, fmt.Errorf("failed to list installed packages: %w", err)
		}
		output := map[string]any{
			"type":     "list_packages",
			"packages": packages,
		}
		return json.Marshal(output)
	default:
		return nil, fmt.Errorf("unknown jupyter action: %s", action)
	}
}

// PrewarmNextOnce prewarms the next task in queue.
// It fetches the next task, verifies its snapshot, and stages it to the prewarm directory.
// Returns true if prewarming was performed, false if disabled or queue empty.
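// A caller might drive this from a background loop (a sketch; the interval
// is an assumption):
//
//	ticker := time.NewTicker(30 * time.Second)
//	defer ticker.Stop()
//	for {
//		select {
//		case <-ctx.Done():
//			return
//		case <-ticker.C:
//			if _, err := w.PrewarmNextOnce(ctx); err != nil {
//				log.Printf("prewarm failed: %v", err)
//			}
//		}
//	}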
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
	// Check if prewarming is enabled
	if !w.config.PrewarmEnabled {
		return false, nil
	}

	// Resolve the base path and data directory
	basePath := w.config.BasePath
	if basePath == "" {
		basePath = "/tmp"
	}
	dataDir := w.config.DataDir
	if dataDir == "" {
		dataDir = filepath.Join(basePath, "data")
	}

	// Create the prewarm directory
	prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots")
	if err := os.MkdirAll(prewarmDir, 0750); err != nil {
		return false, fmt.Errorf("failed to create prewarm directory: %w", err)
	}

	// Try to get the next task from the queue client if available (peek, don't lease)
	if w.queueClient != nil {
		task, err := w.queueClient.PeekNextTask()
		if err != nil {
			// Peek failed (typically an empty queue). Preserve any existing
			// prewarm state rather than deleting it.
			state, _ := w.queueClient.GetWorkerPrewarmState(w.id)
			if state != nil {
				// Existing state means a prewarm is still active
				return true, nil
			}
			return false, nil
		}
		if task != nil && task.SnapshotID != "" {
			// Resolve the snapshot path, preferring the SHA from metadata if available
			snapshotSHA := task.Metadata["snapshot_sha256"]
			if snapshotSHA != "" {
				snapshotSHA, _ = integrity.NormalizeSHA256ChecksumHex(snapshotSHA)
			}
			var srcDir string
			if snapshotSHA != "" {
				// Check whether the snapshot exists in the SHA cache directory
				shaDir := filepath.Join(dataDir, "snapshots", "sha256", snapshotSHA)
				if info, err := os.Stat(shaDir); err == nil && info.IsDir() {
					srcDir = shaDir
				}
			}
			// Fall back to the direct snapshot path if the SHA directory doesn't exist
			if srcDir == "" {
				srcDir = filepath.Join(dataDir, "snapshots", task.SnapshotID)
			}
			dstDir := filepath.Join(prewarmDir, task.ID)
			if err := execution.CopyDir(srcDir, dstDir); err != nil {
				return false, fmt.Errorf("failed to stage snapshot: %w", err)
			}
			// Record the prewarm state in the queue backend
			now := time.Now().UTC().Format(time.RFC3339)
			state := queue.PrewarmState{
				WorkerID:   w.id,
				TaskID:     task.ID,
				SnapshotID: task.SnapshotID,
				StartedAt:  now,
				UpdatedAt:  now,
				Phase:      "staged",
			}
			_ = w.queueClient.SetWorkerPrewarmState(state)
			return true, nil
		}
	}

	// With a run loop but no queue client, defer to the run loop (backward compatibility)
	if w.runLoop != nil {
		return true, nil
	}
	return false, nil
}

// RunJob runs a job task.
// It uses the JobRunner to execute the job and write the run manifest. Note
// that outputDir is accepted for API compatibility but is not currently used
// by this wrapper.
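// Illustrative call (the output directory value is an assumption; see the
// note above about outputDir being unused today):
//
//	err := w.RunJob(ctx, task, filepath.Join(w.config.BasePath, task.ID, "output"))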
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
	if w.runner == nil {
		return fmt.Errorf("job runner not configured")
	}
	basePath := w.config.BasePath
	if basePath == "" {
		basePath = "/tmp"
	}
	// Determine the execution mode
	mode := executor.ModeAuto
	if w.config.LocalMode {
		mode = executor.ModeLocal
	}
	// Create a minimal GPU environment (empty for now)
	gpuEnv := interfaces.ExecutionEnv{}
	// Run the job
	return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv)
}