Phase 7 of the monorepo maintainability plan.

New files created:
- model/jobs.go - Job type, JobStatus constants, list.Item interface
- model/messages.go - tea.Msg types (JobsLoadedMsg, StatusMsg, TickMsg, etc.)
- model/styles.go - NewJobListDelegate(), JobListTitleStyle(), SpinnerStyle()
- model/keys.go - KeyMap struct, DefaultKeys() function

Modified files:
- model/state.go - reduced from 226 to ~130 lines
  - Removed: Job, JobStatus, KeyMap, Keys, inline styles
  - Kept: State struct, domain re-exports, ViewMode, DatasetInfo, InitialState()
- controller/commands.go - use the model. prefix for message types
- controller/controller.go - use the model. prefix for message types
- controller/settings.go - use model.SettingsContentMsg

Deleted files:
- controller/keys.go (moved to model/keys.go since State references KeyMap)

Result:
- No file >150 lines in the model/ package
- Single concern per file: state, jobs, messages, styles, keys
- All 41 test packages pass
556 lines
17 KiB
Go
556 lines
17 KiB
Go
// Package worker provides the ML task worker implementation.
|
|
package worker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/metrics"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/resources"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/execution"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
|
)
|
|
|
|
// MLServer is a legacy placeholder for an ML server connection, retained for
// backward compatibility (it historically wrapped network.SSHClient).
type MLServer struct {
	// SSHClient is an untyped slot for the underlying client. Nothing in
	// this file reads or assigns it.
	SSHClient any
}
|
|
|
|
// JupyterManager interface for jupyter service management
|
|
type JupyterManager interface {
|
|
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
|
|
StopService(ctx context.Context, serviceID string) error
|
|
RemoveService(ctx context.Context, serviceID string, purge bool) error
|
|
RestoreWorkspace(ctx context.Context, name string) (string, error)
|
|
ListServices() []*jupyter.JupyterService
|
|
ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
|
|
}
|
|
|
|
// _isValidName reports whether input has an acceptable length: at least one
// byte and at most 255 bytes (lengths are measured in bytes, not runes).
//
// NOTE(review): this performs no character validation at all, although the
// name suggests it (and its earlier doc comment claimed it checked for "only
// safe characters"). Add an explicit allow-list if any caller relies on that.
func _isValidName(input string) bool {
	size := len(input)
	if size == 0 {
		return false
	}
	return size <= 255
}
|
|
|
|
// NewMLServer creates a new ML server connection.
|
|
func NewMLServer(cfg *Config) (*MLServer, error) {
|
|
return &MLServer{}, nil
|
|
}
|
|
|
|
// Worker represents an ML task worker with composed dependencies.
|
|
type Worker struct {
|
|
id string
|
|
config *Config
|
|
logger *logging.Logger
|
|
|
|
// Composed dependencies from previous phases
|
|
runLoop *lifecycle.RunLoop
|
|
runner *executor.JobRunner
|
|
metrics *metrics.Metrics
|
|
metricsSrv *http.Server
|
|
health *lifecycle.HealthMonitor
|
|
resources *resources.Manager
|
|
|
|
// Legacy fields for backward compatibility during migration
|
|
jupyter JupyterManager
|
|
queueClient queue.Backend // Stored for prewarming access
|
|
}
|
|
|
|
// Start begins the worker's main processing loop.
|
|
func (w *Worker) Start() {
|
|
w.logger.Info("worker starting",
|
|
"worker_id", w.id,
|
|
"max_concurrent", w.config.MaxWorkers)
|
|
|
|
w.health.RecordHeartbeat()
|
|
w.runLoop.Start()
|
|
}
|
|
|
|
// Stop gracefully shuts down the worker immediately.
|
|
func (w *Worker) Stop() {
|
|
w.logger.Info("worker stopping", "worker_id", w.id)
|
|
w.runLoop.Stop()
|
|
|
|
if w.metricsSrv != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
|
w.logger.Warn("metrics server shutdown error", "error", err)
|
|
}
|
|
}
|
|
|
|
w.logger.Info("worker stopped", "worker_id", w.id)
|
|
}
|
|
|
|
// Shutdown performs a graceful shutdown with timeout.
|
|
func (w *Worker) Shutdown() error {
|
|
w.logger.Info("starting graceful shutdown", "worker_id", w.id)
|
|
|
|
w.runLoop.Stop()
|
|
|
|
if w.metricsSrv != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
|
w.logger.Warn("metrics server shutdown error", "error", err)
|
|
}
|
|
}
|
|
|
|
w.logger.Info("worker shut down gracefully", "worker_id", w.id)
|
|
return nil
|
|
}
|
|
|
|
// IsHealthy returns true if the worker is healthy.
|
|
func (w *Worker) IsHealthy() bool {
|
|
return w.health.IsHealthy(5 * time.Minute)
|
|
}
|
|
|
|
// GetMetrics returns current worker metrics.
|
|
func (w *Worker) GetMetrics() map[string]any {
|
|
stats := w.metrics.GetStats()
|
|
stats["worker_id"] = w.id
|
|
stats["max_workers"] = w.config.MaxWorkers
|
|
stats["healthy"] = w.IsHealthy()
|
|
return stats
|
|
}
|
|
|
|
// GetID returns the worker ID.
|
|
func (w *Worker) GetID() string {
|
|
return w.id
|
|
}
|
|
|
|
// runningCount returns the number of currently running tasks
|
|
func (w *Worker) runningCount() int {
|
|
if w.runLoop == nil {
|
|
return 0
|
|
}
|
|
return w.runLoop.RunningCount()
|
|
}
|
|
|
|
func (w *Worker) _getGPUDetector() GPUDetector {
|
|
factory := &GPUDetectorFactory{}
|
|
return factory.CreateDetector(w.config)
|
|
}
|
|
|
|
// SelectDependencyManifest re-exports the executor function for API helpers.
|
|
// It detects the dependency manifest file in the given directory.
|
|
func SelectDependencyManifest(filesPath string) (string, error) {
|
|
return executor.SelectDependencyManifest(filesPath)
|
|
}
|
|
|
|
// DirOverallSHA256Hex re-exports the integrity function for test compatibility.
|
|
func DirOverallSHA256Hex(root string) (string, error) {
|
|
return integrity.DirOverallSHA256Hex(root)
|
|
}
|
|
|
|
// NormalizeSHA256ChecksumHex re-exports the integrity function for test compatibility.
|
|
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
|
|
return integrity.NormalizeSHA256ChecksumHex(checksum)
|
|
}
|
|
|
|
// StageSnapshot re-exports the execution function for test compatibility.
|
|
func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error {
|
|
return execution.StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir)
|
|
}
|
|
|
|
// StageSnapshotFromPath re-exports the execution function for test compatibility.
|
|
func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error {
|
|
return execution.StageSnapshotFromPath(basePath, taskID, srcPath, jobDir)
|
|
}
|
|
|
|
// ComputeTaskProvenance computes provenance information for a task.
|
|
// This re-exports the integrity function for test compatibility.
|
|
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
|
|
pc := integrity.NewProvenanceCalculator(basePath)
|
|
return pc.ComputeProvenance(task)
|
|
}
|
|
|
|
// VerifyDatasetSpecs verifies dataset specifications for this task.
|
|
// This is a test compatibility method that wraps the integrity package.
|
|
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
|
dataDir := w.config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = "/tmp/data"
|
|
}
|
|
verifier := integrity.NewDatasetVerifier(dataDir)
|
|
return verifier.VerifyDatasetSpecs(task)
|
|
}
|
|
|
|
// EnforceTaskProvenance enforces provenance requirements for a task.
|
|
// It validates and/or populates provenance metadata based on the ProvenanceBestEffort config.
|
|
// In strict mode (ProvenanceBestEffort=false), it returns an error if metadata doesn't match computed values.
|
|
// In best-effort mode (ProvenanceBestEffort=true), it populates missing metadata fields.
|
|
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
|
|
if task == nil {
|
|
return fmt.Errorf("task is nil")
|
|
}
|
|
|
|
basePath := w.config.BasePath
|
|
if basePath == "" {
|
|
basePath = "/tmp"
|
|
}
|
|
dataDir := w.config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = filepath.Join(basePath, "data")
|
|
}
|
|
|
|
bestEffort := w.config.ProvenanceBestEffort
|
|
|
|
// Get commit_id from metadata
|
|
commitID := task.Metadata["commit_id"]
|
|
if commitID == "" {
|
|
return fmt.Errorf("missing commit_id in task metadata")
|
|
}
|
|
|
|
// Compute and verify experiment manifest SHA
|
|
expPath := filepath.Join(basePath, commitID)
|
|
manifestSHA, err := integrity.DirOverallSHA256Hex(expPath)
|
|
if err != nil {
|
|
if !bestEffort {
|
|
return fmt.Errorf("failed to compute experiment manifest SHA: %w", err)
|
|
}
|
|
// In best-effort mode, we'll use whatever is provided or skip
|
|
manifestSHA = ""
|
|
}
|
|
|
|
// Handle experiment_manifest_overall_sha
|
|
expectedManifestSHA := task.Metadata["experiment_manifest_overall_sha"]
|
|
if expectedManifestSHA == "" {
|
|
if !bestEffort {
|
|
return fmt.Errorf("missing experiment_manifest_overall_sha in task metadata")
|
|
}
|
|
// Populate in best-effort mode
|
|
if task.Metadata == nil {
|
|
task.Metadata = map[string]string{}
|
|
}
|
|
task.Metadata["experiment_manifest_overall_sha"] = manifestSHA
|
|
} else if !bestEffort && expectedManifestSHA != manifestSHA {
|
|
return fmt.Errorf("experiment manifest SHA mismatch: expected %s, got %s", expectedManifestSHA, manifestSHA)
|
|
}
|
|
|
|
// Handle deps_manifest_sha256 - auto-detect if not provided
|
|
filesPath := filepath.Join(expPath, "files")
|
|
depsManifestName := task.Metadata["deps_manifest_name"]
|
|
if depsManifestName == "" {
|
|
// Auto-detect manifest file
|
|
depsManifestName, _ = executor.SelectDependencyManifest(filesPath)
|
|
}
|
|
if depsManifestName != "" {
|
|
if task.Metadata == nil {
|
|
task.Metadata = map[string]string{}
|
|
}
|
|
task.Metadata["deps_manifest_name"] = depsManifestName
|
|
depsPath := filepath.Join(filesPath, depsManifestName)
|
|
depsSHA, err := integrity.FileSHA256Hex(depsPath)
|
|
if err != nil {
|
|
if !bestEffort {
|
|
return fmt.Errorf("failed to compute deps manifest SHA: %w", err)
|
|
}
|
|
depsSHA = ""
|
|
}
|
|
|
|
expectedDepsSHA := task.Metadata["deps_manifest_sha256"]
|
|
if expectedDepsSHA == "" {
|
|
if !bestEffort {
|
|
return fmt.Errorf("missing deps_manifest_sha256 in task metadata")
|
|
}
|
|
task.Metadata["deps_manifest_sha256"] = depsSHA
|
|
} else if !bestEffort && expectedDepsSHA != depsSHA {
|
|
return fmt.Errorf("deps manifest SHA mismatch: expected %s, got %s", expectedDepsSHA, depsSHA)
|
|
}
|
|
}
|
|
|
|
// Handle snapshot_sha256 if SnapshotID is set
|
|
if task.SnapshotID != "" {
|
|
snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
snapSHA, err := integrity.DirOverallSHA256Hex(snapPath)
|
|
if err != nil {
|
|
if !bestEffort {
|
|
return fmt.Errorf("failed to compute snapshot SHA: %w", err)
|
|
}
|
|
snapSHA = ""
|
|
}
|
|
|
|
expectedSnapSHA, _ := integrity.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
|
|
if expectedSnapSHA == "" {
|
|
if !bestEffort {
|
|
return fmt.Errorf("missing snapshot_sha256 in task metadata")
|
|
}
|
|
if task.Metadata == nil {
|
|
task.Metadata = map[string]string{}
|
|
}
|
|
task.Metadata["snapshot_sha256"] = snapSHA
|
|
} else if !bestEffort && expectedSnapSHA != snapSHA {
|
|
return fmt.Errorf("snapshot SHA mismatch: expected %s, got %s", expectedSnapSHA, snapSHA)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// VerifySnapshot verifies snapshot integrity for this task.
|
|
// It computes the SHA256 of the snapshot directory and compares with task metadata.
|
|
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
|
if task.SnapshotID == "" {
|
|
return nil // No snapshot to verify
|
|
}
|
|
|
|
dataDir := w.config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = "/tmp/data"
|
|
}
|
|
|
|
// Get expected checksum from metadata
|
|
expectedChecksum, ok := task.Metadata["snapshot_sha256"]
|
|
if !ok || expectedChecksum == "" {
|
|
return fmt.Errorf("missing snapshot_sha256 in task metadata")
|
|
}
|
|
|
|
// Normalize the checksum (remove sha256: prefix if present)
|
|
expectedChecksum, err := integrity.NormalizeSHA256ChecksumHex(expectedChecksum)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid snapshot_sha256 format: %w", err)
|
|
}
|
|
|
|
// Compute actual checksum of snapshot directory
|
|
snapshotDir := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
actualChecksum, err := integrity.DirOverallSHA256Hex(snapshotDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to compute snapshot hash: %w", err)
|
|
}
|
|
|
|
// Compare checksums
|
|
if actualChecksum != expectedChecksum {
|
|
return fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RunJupyterTask runs a Jupyter-related task.
|
|
// It handles start, stop, remove, restore, and list_packages actions.
|
|
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
|
if w.jupyter == nil {
|
|
return nil, fmt.Errorf("jupyter manager not configured")
|
|
}
|
|
|
|
action := task.Metadata["jupyter_action"]
|
|
if action == "" {
|
|
action = "start" // Default action
|
|
}
|
|
|
|
switch action {
|
|
case "start":
|
|
name := task.Metadata["jupyter_name"]
|
|
if name == "" {
|
|
name = task.Metadata["jupyter_workspace"]
|
|
}
|
|
if name == "" {
|
|
// Extract from jobName if format is "jupyter-<name>"
|
|
if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" {
|
|
name = task.JobName[8:]
|
|
}
|
|
}
|
|
if name == "" {
|
|
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
|
}
|
|
|
|
req := &jupyter.StartRequest{Name: name}
|
|
service, err := w.jupyter.StartService(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
|
|
}
|
|
|
|
output := map[string]interface{}{
|
|
"type": "start",
|
|
"service": service,
|
|
}
|
|
return json.Marshal(output)
|
|
|
|
case "stop":
|
|
serviceID := task.Metadata["jupyter_service_id"]
|
|
if serviceID == "" {
|
|
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
|
}
|
|
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
|
|
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
|
|
}
|
|
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
|
|
|
|
case "remove":
|
|
serviceID := task.Metadata["jupyter_service_id"]
|
|
if serviceID == "" {
|
|
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
|
}
|
|
purge := task.Metadata["jupyter_purge"] == "true"
|
|
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
|
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
|
|
}
|
|
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
|
|
|
|
case "restore":
|
|
name := task.Metadata["jupyter_name"]
|
|
if name == "" {
|
|
name = task.Metadata["jupyter_workspace"]
|
|
}
|
|
if name == "" {
|
|
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
|
}
|
|
serviceID, err := w.jupyter.RestoreWorkspace(ctx, name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
|
|
}
|
|
return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID})
|
|
|
|
case "list_packages":
|
|
serviceName := task.Metadata["jupyter_name"]
|
|
if serviceName == "" {
|
|
// Extract from jobName if format is "jupyter-packages-<name>"
|
|
if len(task.JobName) > 16 && task.JobName[:16] == "jupyter-packages-" {
|
|
serviceName = task.JobName[16:]
|
|
}
|
|
}
|
|
if serviceName == "" {
|
|
return nil, fmt.Errorf("missing jupyter_name in task metadata")
|
|
}
|
|
|
|
packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list installed packages: %w", err)
|
|
}
|
|
|
|
output := map[string]interface{}{
|
|
"type": "list_packages",
|
|
"packages": packages,
|
|
}
|
|
return json.Marshal(output)
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unknown jupyter action: %s", action)
|
|
}
|
|
}
|
|
|
|
// PrewarmNextOnce prewarms the next task in queue.
|
|
// It fetches the next task, verifies its snapshot, and stages it to the prewarm directory.
|
|
// Returns true if prewarming was performed, false if disabled or queue empty.
|
|
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|
// Check if prewarming is enabled
|
|
if !w.config.PrewarmEnabled {
|
|
return false, nil
|
|
}
|
|
|
|
// Get base path and data directory
|
|
basePath := w.config.BasePath
|
|
if basePath == "" {
|
|
basePath = "/tmp"
|
|
}
|
|
dataDir := w.config.DataDir
|
|
if dataDir == "" {
|
|
dataDir = filepath.Join(basePath, "data")
|
|
}
|
|
|
|
// Create prewarm directory
|
|
prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots")
|
|
if err := os.MkdirAll(prewarmDir, 0750); err != nil {
|
|
return false, fmt.Errorf("failed to create prewarm directory: %w", err)
|
|
}
|
|
|
|
// Try to get next task from queue client if available (peek, don't lease)
|
|
if w.queueClient != nil {
|
|
task, err := w.queueClient.PeekNextTask()
|
|
if err != nil {
|
|
// Queue empty - check if we have existing prewarm state
|
|
// Return false but preserve any existing state (don't delete)
|
|
state, _ := w.queueClient.GetWorkerPrewarmState(w.id)
|
|
if state != nil {
|
|
// We have existing state, return true to indicate prewarm is active
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
if task != nil && task.SnapshotID != "" {
|
|
// Resolve snapshot path using SHA from metadata if available
|
|
snapshotSHA := task.Metadata["snapshot_sha256"]
|
|
if snapshotSHA != "" {
|
|
snapshotSHA, _ = integrity.NormalizeSHA256ChecksumHex(snapshotSHA)
|
|
}
|
|
|
|
var srcDir string
|
|
if snapshotSHA != "" {
|
|
// Check if snapshot exists in SHA cache directory
|
|
shaDir := filepath.Join(dataDir, "snapshots", "sha256", snapshotSHA)
|
|
if info, err := os.Stat(shaDir); err == nil && info.IsDir() {
|
|
srcDir = shaDir
|
|
}
|
|
}
|
|
|
|
// Fall back to direct snapshot path if SHA directory doesn't exist
|
|
if srcDir == "" {
|
|
srcDir = filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
}
|
|
|
|
dstDir := filepath.Join(prewarmDir, task.ID)
|
|
if err := execution.CopyDir(srcDir, dstDir); err != nil {
|
|
return false, fmt.Errorf("failed to stage snapshot: %w", err)
|
|
}
|
|
|
|
// Store prewarm state in queue backend
|
|
if w.queueClient != nil {
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
state := queue.PrewarmState{
|
|
WorkerID: w.id,
|
|
TaskID: task.ID,
|
|
SnapshotID: task.SnapshotID,
|
|
StartedAt: now,
|
|
UpdatedAt: now,
|
|
Phase: "staged",
|
|
}
|
|
_ = w.queueClient.SetWorkerPrewarmState(state)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
|
|
if w.runLoop != nil {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// RunJob runs a job task.
|
|
// It uses the JobRunner to execute the job and write the run manifest.
|
|
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
|
|
if w.runner == nil {
|
|
return fmt.Errorf("job runner not configured")
|
|
}
|
|
|
|
basePath := w.config.BasePath
|
|
if basePath == "" {
|
|
basePath = "/tmp"
|
|
}
|
|
|
|
// Determine execution mode
|
|
mode := executor.ModeAuto
|
|
if w.config.LocalMode {
|
|
mode = executor.ModeLocal
|
|
}
|
|
|
|
// Create minimal GPU environment (empty for now)
|
|
gpuEnv := interfaces.ExecutionEnv{}
|
|
|
|
// Run the job
|
|
return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv)
|
|
}
|