// Package worker provides the ML task worker implementation package worker import ( "context" "net/http" "time" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/metrics" "github.com/jfraeys/fetch_ml/internal/resources" "github.com/jfraeys/fetch_ml/internal/worker/executor" "github.com/jfraeys/fetch_ml/internal/worker/lifecycle" ) // MLServer wraps network.SSHClient for backward compatibility. type MLServer struct { SSHClient interface{} } // JupyterManager interface for jupyter service management type JupyterManager interface { StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) StopService(ctx context.Context, serviceID string) error RemoveService(ctx context.Context, serviceID string, purge bool) error RestoreWorkspace(ctx context.Context, name string) (string, error) ListServices() []*jupyter.JupyterService ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) } // isValidName validates that input strings contain only safe characters. func isValidName(input string) bool { return len(input) > 0 && len(input) < 256 } // NewMLServer creates a new ML server connection. func NewMLServer(cfg *Config) (*MLServer, error) { return &MLServer{}, nil } // Worker represents an ML task worker with composed dependencies. type Worker struct { id string config *Config logger *logging.Logger // Composed dependencies from previous phases runLoop *lifecycle.RunLoop runner *executor.JobRunner metrics *metrics.Metrics metricsSrv *http.Server health *lifecycle.HealthMonitor resources *resources.Manager // Legacy fields for backward compatibility during migration jupyter JupyterManager } // Start begins the worker's main processing loop. func (w *Worker) Start() { w.logger.Info("worker starting", "worker_id", w.id, "max_concurrent", w.config.MaxWorkers) w.health.RecordHeartbeat() w.runLoop.Start() } // Stop gracefully shuts down the worker immediately. func (w *Worker) Stop() { w.logger.Info("worker stopping", "worker_id", w.id) w.runLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.logger.Warn("metrics server shutdown error", "error", err) } } w.logger.Info("worker stopped", "worker_id", w.id) } // Shutdown performs a graceful shutdown with timeout. func (w *Worker) Shutdown() error { w.logger.Info("starting graceful shutdown", "worker_id", w.id) w.runLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.logger.Warn("metrics server shutdown error", "error", err) } } w.logger.Info("worker shut down gracefully", "worker_id", w.id) return nil } // IsHealthy returns true if the worker is healthy. func (w *Worker) IsHealthy() bool { return w.health.IsHealthy(5 * time.Minute) } // GetMetrics returns current worker metrics. func (w *Worker) GetMetrics() map[string]any { stats := w.metrics.GetStats() stats["worker_id"] = w.id stats["max_workers"] = w.config.MaxWorkers stats["healthy"] = w.IsHealthy() return stats } // GetID returns the worker ID. func (w *Worker) GetID() string { return w.id } // runningCount returns the number of currently running tasks func (w *Worker) runningCount() int { if w.runLoop == nil { return 0 } return w.runLoop.RunningCount() } func (w *Worker) getGPUDetector() GPUDetector { factory := &GPUDetectorFactory{} return factory.CreateDetector(w.config) }