fetch_ml/internal/worker/worker.go
Jeremie Fraeys 4c8c9dfe4b
refactor: Export SelectDependencyManifest for API helpers
- Renamed selectDependencyManifest to SelectDependencyManifest (exported)
- Added re-export in worker package for backward compatibility
- Updated internal call in container.go to use exported function
- API helpers can now access via worker.SelectDependencyManifest

Build status: Compiles successfully
2026-02-17 16:45:59 -05:00

140 lines
3.9 KiB
Go

// Package worker provides the ML task worker implementation.
package worker
import (
"context"
"net/http"
"time"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/resources"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)
// MLServer is a thin compatibility holder for an SSH client value.
//
// SSHClient is untyped, so this struct places no constraints on what is
// stored in it (presumably a network.SSHClient; confirm against callers).
type MLServer struct {
	SSHClient any
}
// JupyterManager is the set of Jupyter service operations the worker depends
// on. The grouping below follows the method names and signatures only; the
// actual behavior is defined by whichever implementation is supplied.
type JupyterManager interface {
	// Service lifecycle, addressed by request or by service ID.
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	StopService(ctx context.Context, serviceID string) error
	RemoveService(ctx context.Context, serviceID string, purge bool) error

	// Inspection of services and of the packages installed in a named service.
	ListServices() []*jupyter.JupyterService
	ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)

	// Workspace restoration by name; the meaning of the returned string is
	// implementation-defined (NOTE(review): likely a path - confirm).
	RestoreWorkspace(ctx context.Context, name string) (string, error)
}
// isValidName reports whether input has a length between 1 and 255 bytes
// inclusive.
//
// NOTE(review): only the byte length is checked. No character filtering is
// performed, so callers must not rely on this function to reject unsafe
// characters (path separators, shell metacharacters, and so on).
func isValidName(input string) bool {
	n := len(input)
	if n == 0 {
		return false
	}
	return n < 256
}
// NewMLServer returns a pointer to a zero-valued MLServer and a nil error.
//
// cfg is accepted for signature compatibility but is not read here, and no
// connection is established by this function.
func NewMLServer(cfg *Config) (*MLServer, error) {
	return new(MLServer), nil
}
// Worker represents an ML task worker with composed dependencies.
//
// In the code visible here, only metricsSrv (in Stop and Shutdown) and
// runLoop (in runningCount) are nil-checked before use; the other fields
// that are used are accessed without a guard.
type Worker struct {
// id is the worker identifier returned by GetID and attached to log entries.
id string
// config supplies MaxWorkers for logging and metrics and is passed to the
// GPU detector factory.
config *Config
// logger receives the worker's lifecycle Info and Warn messages.
logger *logging.Logger
// Composed dependencies from previous phases
// runLoop is started by Start, stopped by Stop/Shutdown, and queried for the
// running task count.
runLoop *lifecycle.RunLoop
// runner is not referenced by the methods visible in this section.
runner *executor.JobRunner
// metrics provides the base stats map returned by GetMetrics.
metrics *metrics.Metrics
// metricsSrv is optional; when non-nil it is shut down by Stop and Shutdown.
metricsSrv *http.Server
// health records heartbeats and answers IsHealthy.
health *lifecycle.HealthMonitor
// resources is not referenced by the methods visible in this section.
resources *resources.Manager
// Legacy fields for backward compatibility during migration
// jupyter is not referenced by the methods visible in this section.
jupyter JupyterManager
}
// Start logs the worker ID and the configured MaxWorkers, records a heartbeat
// with the health monitor, and then starts the run loop.
//
// Whether this call blocks depends on RunLoop.Start, which is defined outside
// this file.
func (w *Worker) Start() {
	workerID, maxConcurrent := w.id, w.config.MaxWorkers
	w.logger.Info("worker starting", "worker_id", workerID, "max_concurrent", maxConcurrent)

	// Heartbeat is recorded before the loop is started, matching the original order.
	w.health.RecordHeartbeat()
	w.runLoop.Start()
}
// Stop stops the run loop and then shuts down the metrics server, if one is
// configured. A metrics server shutdown failure is logged as a warning and
// is not reported to the caller.
func (w *Worker) Stop() {
	w.logger.Info("worker stopping", "worker_id", w.id)
	w.runLoop.Stop()
	w.shutdownMetricsServer()
	w.logger.Info("worker stopped", "worker_id", w.id)
}

// Shutdown performs the same sequence as Stop (stop the run loop, then shut
// down the metrics server) with its own log messages.
//
// The returned error is currently always nil: a metrics server shutdown
// failure is logged as a warning rather than returned.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "worker_id", w.id)
	w.runLoop.Stop()
	w.shutdownMetricsServer()
	w.logger.Info("worker shut down gracefully", "worker_id", w.id)
	return nil
}

// shutdownMetricsServer shuts down the metrics HTTP server, giving it at most
// five seconds to finish. It is a no-op when no metrics server is configured,
// and it logs (rather than returns) any shutdown error.
func (w *Worker) shutdownMetricsServer() {
	if w.metricsSrv == nil {
		return
	}

	const timeout = 5 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	if err := w.metricsSrv.Shutdown(ctx); err != nil {
		w.logger.Warn("metrics server shutdown error", "error", err)
	}
}
// IsHealthy reports the health monitor's verdict for a five-minute window.
//
// How the window is interpreted is decided by HealthMonitor.IsHealthy, which
// is defined outside this file.
func (w *Worker) IsHealthy() bool {
	const window = 5 * time.Minute
	healthy := w.health.IsHealthy(window)
	return healthy
}
// GetMetrics returns the stats map produced by the metrics collector, with
// three worker-level entries added: "worker_id", "max_workers" and "healthy".
//
// The entries are written into the map that GetStats returns; no copy is made
// here.
func (w *Worker) GetMetrics() map[string]any {
	snapshot := w.metrics.GetStats()

	snapshot["worker_id"] = w.id
	snapshot["max_workers"] = w.config.MaxWorkers

	healthy := w.IsHealthy()
	snapshot["healthy"] = healthy

	return snapshot
}
// GetID returns the identifier stored on this worker.
func (w *Worker) GetID() string {
	workerID := w.id
	return workerID
}
// runningCount returns the run loop's RunningCount, or 0 when the worker has
// no run loop.
func (w *Worker) runningCount() int {
	if loop := w.runLoop; loop != nil {
		return loop.RunningCount()
	}
	return 0
}
// getGPUDetector builds a new GPUDetectorFactory on every call and returns
// the detector it creates for the worker's config.
func (w *Worker) getGPUDetector() GPUDetector {
	return new(GPUDetectorFactory).CreateDetector(w.config)
}
// SelectDependencyManifest is a package-level pass-through to
// executor.SelectDependencyManifest, kept so API helpers can call it via the
// worker package.
//
// filesPath is forwarded unchanged and both results are returned unchanged;
// see the executor package for how the manifest is chosen.
func SelectDependencyManifest(filesPath string) (string, error) {
	manifest, err := executor.SelectDependencyManifest(filesPath)
	return manifest, err
}