Some checks failed
Build CLI with Embedded SQLite / build (arm64, aarch64-linux) (push) Waiting to run
Build CLI with Embedded SQLite / build (x86_64, x86_64-linux) (push) Waiting to run
Build CLI with Embedded SQLite / build-macos (arm64) (push) Waiting to run
Build CLI with Embedded SQLite / build-macos (x86_64) (push) Waiting to run
Security Scan / Security Analysis (push) Waiting to run
Security Scan / Native Library Security (push) Waiting to run
Checkout test / test (push) Successful in 6s
CI/CD Pipeline / Test (push) Failing after 1s
CI/CD Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI/CD Pipeline / Build (push) Has been skipped
CI/CD Pipeline / Test Scripts (push) Has been skipped
CI/CD Pipeline / Test Native Libraries (push) Has been skipped
CI/CD Pipeline / GPU Golden Test Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 39s
CI/CD Pipeline / Docker Build (push) Has been skipped
- Surface GPUDetectionInfo from parseGPUCountFromConfig for detection metadata - Document FETCH_ML_TOTAL_CPU and FETCH_ML_GPU_SLOTS_PER_GPU env vars - Add debug logging for all env var overrides to stderr - Track config-layer auto-detection in GPUDetectionInfo.ConfigLayerAutoDetected - Add --include-all flag to artifact scanner (includeAll parameter) - Add AMD production mode enforcement (error in non-local mode) - Add GPU detector unit tests for env overrides and AMD aliasing
235 lines
6.5 KiB
Go
235 lines
6.5 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/metrics"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/resources"
|
|
"github.com/jfraeys/fetch_ml/internal/tracking"
|
|
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
|
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
|
)
|
|
|
|
// NewWorker creates a new worker instance with composed dependencies.
|
|
func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|
// Create queue backend
|
|
backendCfg := queue.BackendConfig{
|
|
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
|
|
RedisAddr: cfg.RedisAddr,
|
|
RedisPassword: cfg.RedisPassword,
|
|
RedisDB: cfg.RedisDB,
|
|
SQLitePath: cfg.Queue.SQLitePath,
|
|
FilesystemPath: cfg.Queue.FilesystemPath,
|
|
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
|
|
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
|
}
|
|
|
|
queueClient, err := queue.NewBackend(backendCfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
ctx = logging.EnsureTrace(ctx)
|
|
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
|
|
|
|
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
|
logger := baseLogger.Component(ctx, "worker")
|
|
metricsObj := &metrics.Metrics{}
|
|
|
|
podmanMgr, err := container.NewPodmanManager(logger)
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("failed to create podman manager: %w", err)
|
|
}
|
|
|
|
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
|
|
}
|
|
|
|
trackingRegistry := tracking.NewRegistry(logger)
|
|
pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
|
|
|
|
if len(cfg.Plugins) == 0 {
|
|
logger.Warn("no plugins configured, defining defaults")
|
|
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
|
|
logger,
|
|
podmanMgr,
|
|
trackingplugins.MLflowOptions{
|
|
ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
|
|
},
|
|
)
|
|
if err == nil {
|
|
trackingRegistry.Register(mlflowPlugin)
|
|
}
|
|
|
|
tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
|
|
logger,
|
|
podmanMgr,
|
|
trackingplugins.TensorBoardOptions{
|
|
LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
|
|
},
|
|
)
|
|
if err == nil {
|
|
trackingRegistry.Register(tensorboardPlugin)
|
|
}
|
|
|
|
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
|
|
} else {
|
|
if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("failed to load plugins: %w", err)
|
|
}
|
|
}
|
|
|
|
// Create run loop configuration
|
|
runLoopConfig := lifecycle.RunLoopConfig{
|
|
WorkerID: cfg.WorkerID,
|
|
MaxWorkers: cfg.MaxWorkers,
|
|
PollInterval: time.Duration(cfg.PollInterval) * time.Second,
|
|
TaskLeaseDuration: cfg.TaskLeaseDuration,
|
|
HeartbeatInterval: cfg.HeartbeatInterval,
|
|
GracefulTimeout: cfg.GracefulTimeout,
|
|
PrewarmEnabled: cfg.PrewarmEnabled,
|
|
}
|
|
|
|
// Create executors
|
|
localExecutor := executor.NewLocalExecutor(logger, nil)
|
|
containerExecutor := executor.NewContainerExecutor(
|
|
logger,
|
|
nil,
|
|
executor.ContainerConfig{
|
|
PodmanImage: cfg.PodmanImage,
|
|
BasePath: cfg.BasePath,
|
|
},
|
|
)
|
|
|
|
// Create task executor adapter
|
|
exec := executor.NewTaskExecutorAdapter(
|
|
localExecutor,
|
|
containerExecutor,
|
|
cfg.LocalMode,
|
|
)
|
|
|
|
// Create state manager for task lifecycle management
|
|
stateMgr := lifecycle.NewStateManager(nil) // Can pass audit logger if available
|
|
|
|
// Create job runner
|
|
jobRunner := executor.NewJobRunner(
|
|
localExecutor,
|
|
containerExecutor,
|
|
nil, // ManifestWriter - can be added later if needed
|
|
logger,
|
|
)
|
|
|
|
runLoop := lifecycle.NewRunLoop(
|
|
runLoopConfig,
|
|
queueClient,
|
|
exec,
|
|
metricsObj,
|
|
logger,
|
|
stateMgr,
|
|
)
|
|
|
|
// Create resource manager
|
|
gpuCount, gpuDetectionInfo := parseGPUCountFromConfig(cfg)
|
|
rm, err := resources.NewManager(resources.Options{
|
|
TotalCPU: parseCPUFromConfig(cfg),
|
|
GPUCount: gpuCount,
|
|
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
|
|
})
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("failed to init resource manager: %w", err)
|
|
}
|
|
|
|
worker := &Worker{
|
|
id: cfg.WorkerID,
|
|
config: cfg,
|
|
logger: logger,
|
|
runLoop: runLoop,
|
|
runner: jobRunner,
|
|
metrics: metricsObj,
|
|
health: lifecycle.NewHealthMonitor(),
|
|
resources: rm,
|
|
jupyter: jupyterMgr,
|
|
gpuDetectionInfo: gpuDetectionInfo,
|
|
}
|
|
|
|
// Log GPU configuration
|
|
if !cfg.LocalMode {
|
|
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
|
|
if gpuType == "amd" {
|
|
cancel()
|
|
return nil, fmt.Errorf(
|
|
"AMD GPU mode is not supported in production (FETCH_ML_GPU_TYPE=amd). " +
|
|
"Use 'nvidia', 'apple', 'none', or GPUDevices config. " +
|
|
"AMD support is available in local mode for experimental development",
|
|
)
|
|
} else if cfg.AppleGPU.Enabled {
|
|
logger.Warn(
|
|
"apple MPS GPU mode is intended for development; do not use in production",
|
|
"gpu_type", "apple",
|
|
)
|
|
}
|
|
}
|
|
|
|
// Pre-pull tracking images in background
|
|
go worker.prePullImages()
|
|
|
|
// Cancel context is not needed after creation
|
|
cancel()
|
|
|
|
return worker, nil
|
|
}
|
|
|
|
// prePullImages pulls required container images in the background
|
|
func (w *Worker) prePullImages() {
|
|
if w.config.LocalMode {
|
|
return
|
|
}
|
|
|
|
w.logger.Info("starting image pre-pulling")
|
|
|
|
// Pull worker image
|
|
if w.config.PodmanImage != "" {
|
|
w.pullImage(w.config.PodmanImage)
|
|
}
|
|
|
|
// Pull plugin images
|
|
for name, cfg := range w.config.Plugins {
|
|
if !cfg.Enabled || cfg.Image == "" {
|
|
continue
|
|
}
|
|
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
|
w.pullImage(cfg.Image)
|
|
}
|
|
}
|
|
|
|
// pullImage pulls a single container image
|
|
func (w *Worker) pullImage(image string) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
|
defer cancel()
|
|
|
|
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
|
if output, err := cmd.CombinedOutput(); err != nil {
|
|
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
|
} else {
|
|
w.logger.Info("image pulled successfully", "image", image)
|
|
}
|
|
}
|