package worker import ( "context" "net/http" "sync" "time" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/envpool" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/metrics" "github.com/jfraeys/fetch_ml/internal/network" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/resources" "github.com/jfraeys/fetch_ml/internal/tracking" ) // MLServer wraps network.SSHClient for backward compatibility. type MLServer struct { *network.SSHClient } // JupyterManager is the subset of the Jupyter service manager used by the worker. // It exists to keep task execution testable. type JupyterManager interface { StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) StopService(ctx context.Context, serviceID string) error RemoveService(ctx context.Context, serviceID string, purge bool) error RestoreWorkspace(ctx context.Context, name string) (string, error) ListServices() []*jupyter.JupyterService ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) } // isValidName validates that input strings contain only safe characters. // isValidName checks if the input string is a valid name. func isValidName(input string) bool { return len(input) > 0 && len(input) < 256 } // NewMLServer creates a new ML server connection. // NewMLServer returns a new MLServer instance. func NewMLServer(cfg *Config) (*MLServer, error) { if cfg.LocalMode { return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil } client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) if err != nil { return nil, err } return &MLServer{SSHClient: client}, nil } // Worker represents an ML task worker. type Worker struct { id string config *Config server *MLServer queue queue.Backend resources *resources.Manager running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown runningMu sync.RWMutex ctx context.Context cancel context.CancelFunc logger *logging.Logger metrics *metrics.Metrics metricsSrv *http.Server datasetCache map[string]time.Time datasetCacheMu sync.RWMutex datasetCacheTTL time.Duration // Graceful shutdown fields shutdownCh chan struct{} activeTasks sync.Map // map[string]*queue.Task - track active tasks gracefulWait sync.WaitGroup podman *container.PodmanManager jupyter JupyterManager trackingRegistry *tracking.Registry envPool *envpool.Pool prewarmMu sync.Mutex prewarmTargetID string prewarmCancel context.CancelFunc prewarmStartedAt time.Time } func (w *Worker) getGPUDetector() GPUDetector { factory := &GPUDetectorFactory{} return factory.CreateDetector(w.config) }