fetch_ml/internal/worker/factory.go
Jeremie Fraeys 3fb6902fa1
feat(worker): integrate scheduler endpoints and security hardening
Update worker system for scheduler integration:
- Worker server with scheduler registration
- Configuration with scheduler endpoint support
- Artifact handling with integrity verification
- Container executor with supply chain validation
- Local executor enhancements
- GPU detection improvements (cross-platform)
- Error handling with execution context
- Factory pattern for executor instantiation
- Hash integrity with native library support
2026-02-26 12:06:16 -05:00

248 lines
6.8 KiB
Go

package worker
import (
"context"
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)
// NewWorker creates a new worker instance with composed dependencies.
//
// It wires the queue backend, container/jupyter managers, tracking plugins,
// executors, lifecycle run loop, and resource manager into a single Worker.
// The second parameter is unused and retained only for signature
// compatibility with existing callers.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
	// Create queue backend
	backendCfg := queue.BackendConfig{
		Mode:                 cfg.Mode,
		Backend:              queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
		RedisAddr:            cfg.RedisAddr,
		RedisPassword:        cfg.RedisPassword,
		RedisDB:              cfg.RedisDB,
		SQLitePath:           cfg.Queue.SQLitePath,
		FilesystemPath:       cfg.Queue.FilesystemPath,
		FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
		MetricsFlushInterval: cfg.MetricsFlushInterval,
		Scheduler: queue.SchedulerConfig{
			Address: cfg.Scheduler.Address,
			Cert:    cfg.Scheduler.Cert,
			Token:   cfg.Scheduler.Token,
		},
	}
	queueClient, err := queue.NewBackend(backendCfg)
	if err != nil {
		// Wrap for consistency with the other error paths in this constructor.
		return nil, fmt.Errorf("failed to create queue backend: %w", err)
	}

	// The context exists only to carry trace/worker metadata into the logger;
	// it is cancelled before returning (see the end of this function).
	ctx, cancel := context.WithCancel(context.Background())
	ctx = logging.EnsureTrace(ctx)
	ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	logger := baseLogger.Component(ctx, "worker")
	metricsObj := &metrics.Metrics{}

	podmanMgr, err := container.NewPodmanManager(logger)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to create podman manager: %w", err)
	}
	jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
	}

	// Tracking plugins: either load the configured set, or fall back to a
	// best-effort default set (MLflow, TensorBoard, W&B). Default-plugin
	// failures are logged but non-fatal so the worker can still start.
	trackingRegistry := tracking.NewRegistry(logger)
	pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
	if len(cfg.Plugins) == 0 {
		logger.Warn("no plugins configured, defining defaults")
		mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
			logger,
			podmanMgr,
			trackingplugins.MLflowOptions{
				ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
			},
		)
		if err != nil {
			// Previously swallowed silently; surface it so operators can tell
			// why the default plugin is missing.
			logger.Warn("failed to initialize default mlflow plugin", "error", err)
		} else {
			trackingRegistry.Register(mlflowPlugin)
		}
		tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
			logger,
			podmanMgr,
			trackingplugins.TensorBoardOptions{
				LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
			},
		)
		if err != nil {
			logger.Warn("failed to initialize default tensorboard plugin", "error", err)
		} else {
			trackingRegistry.Register(tensorboardPlugin)
		}
		trackingRegistry.Register(trackingplugins.NewWandbPlugin())
	} else {
		if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
			cancel()
			return nil, fmt.Errorf("failed to load plugins: %w", err)
		}
	}

	// Create run loop configuration
	runLoopConfig := lifecycle.RunLoopConfig{
		WorkerID:          cfg.WorkerID,
		MaxWorkers:        cfg.MaxWorkers,
		PollInterval:      time.Duration(cfg.PollInterval) * time.Second,
		TaskLeaseDuration: cfg.TaskLeaseDuration,
		HeartbeatInterval: cfg.HeartbeatInterval,
		GracefulTimeout:   cfg.GracefulTimeout,
		PrewarmEnabled:    cfg.PrewarmEnabled,
	}

	// Create executors
	localExecutor := executor.NewLocalExecutor(logger, nil)
	containerExecutor := executor.NewContainerExecutor(
		logger,
		nil,
		executor.ContainerConfig{
			PodmanImage: cfg.PodmanImage,
			BasePath:    cfg.BasePath,
		},
	)

	// Create task executor adapter
	exec := executor.NewTaskExecutorAdapter(
		localExecutor,
		containerExecutor,
		cfg.LocalMode,
	)

	// Create state manager for task lifecycle management
	stateMgr := lifecycle.NewStateManager(nil) // Can pass audit logger if available

	// Create job runner
	jobRunner := executor.NewJobRunner(
		localExecutor,
		containerExecutor,
		nil, // ManifestWriter - can be added later if needed
		logger,
	)
	runLoop := lifecycle.NewRunLoop(
		runLoopConfig,
		queueClient,
		exec,
		metricsObj,
		logger,
		stateMgr,
	)

	// Create resource manager
	gpuCount, gpuDetectionInfo := parseGPUCountFromConfig(cfg)
	rm, err := resources.NewManager(resources.Options{
		TotalCPU:    parseCPUFromConfig(cfg),
		GPUCount:    gpuCount,
		SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
	})
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to init resource manager: %w", err)
	}

	worker := &Worker{
		ID:               cfg.WorkerID,
		Config:           cfg,
		Logger:           logger,
		RunLoop:          runLoop,
		Runner:           jobRunner,
		Metrics:          metricsObj,
		Health:           lifecycle.NewHealthMonitor(),
		Resources:        rm,
		Jupyter:          jupyterMgr,
		gpuDetectionInfo: gpuDetectionInfo,
	}

	// In distributed mode, store the scheduler connection for heartbeats
	if cfg.Mode == "distributed" {
		if schedBackend, ok := queueClient.(*queue.SchedulerBackend); ok {
			worker.schedulerConn = schedBackend.Conn()
		}
	}

	// Log GPU configuration. AMD is rejected outright in production builds;
	// Apple MPS is allowed but warned against.
	if !cfg.LocalMode {
		gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
		if gpuType == "amd" {
			cancel()
			return nil, fmt.Errorf(
				"AMD GPU mode is not supported in production (FETCH_ML_GPU_TYPE=amd). " +
					"Use 'nvidia', 'apple', 'none', or GPUDevices config. " +
					"AMD support is available in local mode for experimental development",
			)
		} else if cfg.AppleGPU.Enabled {
			logger.Warn(
				"apple MPS GPU mode is intended for development; do not use in production",
				"gpu_type", "apple",
			)
		}
	}

	// Pre-pull tracking images in background
	go worker.prePullImages()

	// The context only carried logging metadata; cancel it now to release
	// the WithCancel resources. The derived logger keeps its attributes.
	cancel()
	return worker, nil
}
// prePullImages pulls required container images in the background so that
// the first task does not pay the image-download cost. It is a no-op in
// local mode, where containers are not used.
func (w *Worker) prePullImages() {
	if w.Config.LocalMode {
		return
	}
	w.Logger.Info("starting image pre-pulling")

	// The worker's own execution image comes first.
	if img := w.Config.PodmanImage; img != "" {
		w.pullImage(img)
	}

	// Then the images of every enabled tracking plugin that declares one.
	for name, pluginCfg := range w.Config.Plugins {
		if pluginCfg.Enabled && pluginCfg.Image != "" {
			w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", pluginCfg.Image)
			w.pullImage(pluginCfg.Image)
		}
	}
}
// pullImage pulls a single container image via the podman CLI, bounding the
// pull with a generous timeout. Failures are logged as warnings rather than
// returned, since pre-pulling is best-effort.
func (w *Worker) pullImage(image string) {
	const pullTimeout = 10 * time.Minute
	ctx, cancel := context.WithTimeout(context.Background(), pullTimeout)
	defer cancel()

	output, err := exec.CommandContext(ctx, "podman", "pull", image).CombinedOutput()
	if err != nil {
		w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
		return
	}
	w.Logger.Info("image pulled successfully", "image", image)
}