fetch_ml/internal/worker/factory.go
Jeremie Fraeys a5c1a9fc0b
refactor: Phase 4 - split worker package into focused files
Split 551-line worker/core.go into single-concern files:

- worker/config.go (+44 lines)
  - Added config parsing: envInt(), parseCPUFromConfig(), parseGPUCountFromConfig()
  - parseGPUSlotsPerGPUFromConfig()
  - Now has all config logic in one place (440 lines total)

- worker/metrics.go (new file, 172 lines)
  - Extracted setupMetricsExporter() with ~30 Prometheus metric registrations
  - Isolated metrics logic for easy modification

- worker/factory.go (new file, 183 lines)
  - Extracted NewWorker() factory function
  - Moved prePullImages(), pullImage() from core.go
  - Centralized worker instantiation

- worker/worker.go (renamed from core.go, ~100 lines)
  - Now just defines Worker struct, MLServer, JupyterManager
  - Clean, focused file without mixed concerns

Lines redistributed: ~350 lines moved from monolithic core.go
Build status: Compiles successfully
2026-02-17 12:57:02 -05:00

212 lines
5.9 KiB
Go

package worker
import (
"context"
"fmt"
"log"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/envpool"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
)
// NewWorker builds a Worker from cfg. It connects the ML server and the task
// queue, creates the podman and Jupyter managers, registers tracking plugins
// (built-in defaults when cfg.Plugins is empty, otherwise the configured
// ones), initialises the resource manager, sets up the metrics exporter and
// starts image pre-pulling in a background goroutine.
//
// The second parameter is ignored.
//
// On failure a nil Worker is returned together with the error, and whatever
// was already opened is torn down again: the worker context is cancelled and
// the queue and server connections are closed.
func NewWorker(cfg *Config, _ string) (_ *Worker, err error) {
	// err is a named result on purpose: the deferred cleanups below must see
	// the error of every return path, including returns that wrap an error
	// held in an inner-scope variable.
	srv, err := NewMLServer(cfg)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			if closeErr := srv.Close(); closeErr != nil {
				log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
			}
		}
	}()

	backendCfg := queue.BackendConfig{
		Backend:              queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
		RedisAddr:            cfg.RedisAddr,
		RedisPassword:        cfg.RedisPassword,
		RedisDB:              cfg.RedisDB,
		SQLitePath:           cfg.Queue.SQLitePath,
		FilesystemPath:       cfg.Queue.FilesystemPath,
		FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
		MetricsFlushInterval: cfg.MetricsFlushInterval,
	}
	queueClient, err := queue.NewBackend(backendCfg)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			if closeErr := queueClient.Close(); closeErr != nil {
				log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
			}
		}
	}()

	// Create data_dir if it doesn't exist (for production without NAS).
	// Best effort: a failure is only logged and must not abort startup, so
	// the error is kept in its own variable instead of the named result.
	// NOTE(review): cfg.DataDir is interpolated unquoted into the command
	// string; paths containing spaces or shell metacharacters will misbehave.
	if cfg.DataDir != "" {
		if _, mkdirErr := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); mkdirErr != nil {
			log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, mkdirErr)
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		if err != nil {
			cancel()
		}
	}()
	ctx = logging.EnsureTrace(ctx)
	ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)

	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	logger := baseLogger.Component(ctx, "worker")
	metricsObj := &metrics.Metrics{}

	podmanMgr, err := container.NewPodmanManager(logger)
	if err != nil {
		return nil, fmt.Errorf("failed to create podman manager: %w", err)
	}

	jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
	if err != nil {
		return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
	}

	trackingRegistry := tracking.NewRegistry(logger)
	pluginLoader := factory.NewPluginLoader(logger, podmanMgr)

	if len(cfg.Plugins) == 0 {
		logger.Warn("no plugins configured, defining defaults")
		// Register defaults manually for backward compatibility/local dev.
		// A default plugin that cannot be created is skipped, not fatal, but
		// the reason is logged so the missing plugin is diagnosable.
		mlflowPlugin, mlflowErr := trackingplugins.NewMLflowPlugin(
			logger,
			podmanMgr,
			trackingplugins.MLflowOptions{
				ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
			},
		)
		if mlflowErr == nil {
			trackingRegistry.Register(mlflowPlugin)
		} else {
			logger.Warn("failed to create default mlflow plugin", "error", mlflowErr)
		}

		tensorboardPlugin, tensorboardErr := trackingplugins.NewTensorBoardPlugin(
			logger,
			podmanMgr,
			trackingplugins.TensorBoardOptions{
				LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
			},
		)
		if tensorboardErr == nil {
			trackingRegistry.Register(tensorboardPlugin)
		} else {
			logger.Warn("failed to create default tensorboard plugin", "error", tensorboardErr)
		}

		trackingRegistry.Register(trackingplugins.NewWandbPlugin())
	} else {
		if loadErr := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); loadErr != nil {
			return nil, fmt.Errorf("failed to load plugins: %w", loadErr)
		}
	}

	worker := &Worker{
		id:               cfg.WorkerID,
		config:           cfg,
		server:           srv,
		queue:            queueClient,
		running:          make(map[string]context.CancelFunc),
		datasetCache:     make(map[string]time.Time),
		datasetCacheTTL:  cfg.DatasetCacheTTL,
		ctx:              ctx,
		cancel:           cancel,
		logger:           logger,
		metrics:          metricsObj,
		shutdownCh:       make(chan struct{}),
		podman:           podmanMgr,
		jupyter:          jupyterMgr,
		trackingRegistry: trackingRegistry,
		envPool:          envpool.New(""),
	}

	rm, err := resources.NewManager(resources.Options{
		TotalCPU:    parseCPUFromConfig(cfg),
		GPUCount:    parseGPUCountFromConfig(cfg),
		SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to init resource manager: %w", err)
	}
	worker.resources = rm

	// Outside local mode, warn loudly about GPU modes meant for development.
	if !cfg.LocalMode {
		gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
		if cfg.AppleGPU.Enabled {
			logger.Warn("apple MPS GPU mode is intended for development; do not use in production",
				"gpu_type", "apple",
			)
		}
		if gpuType == "amd" {
			logger.Warn("amd GPU mode is intended for development; do not use in production",
				"gpu_type", "amd",
			)
		}
	}

	worker.setupMetricsExporter()

	// Pre-pull tracking images in background
	go worker.prePullImages()

	return worker, nil
}
// prePullImages pulls the container images the worker is configured to use:
// first config.PodmanImage (when set), then the image of every entry in
// config.Plugins that is enabled and names an image. It does nothing in
// local mode. NewWorker runs it in its own goroutine.
func (w *Worker) prePullImages() {
	if w.config.LocalMode {
		return
	}

	w.logger.Info("starting image pre-pulling")

	// Worker image first.
	if workerImage := w.config.PodmanImage; workerImage != "" {
		w.pullImage(workerImage)
	}

	// Then each enabled plugin that declares an image.
	for pluginName, pluginCfg := range w.config.Plugins {
		if pluginCfg.Enabled && pluginCfg.Image != "" {
			w.logger.Info("pre-pulling plugin image", "plugin", pluginName, "image", pluginCfg.Image)
			w.pullImage(pluginCfg.Image)
		}
	}
}
// pullImage pulls a single container image by running `podman pull <image>`
// and logs the outcome; failures are logged as warnings and not returned.
//
// The pull is limited to 10 minutes and is derived from the worker context,
// so cancelling the worker also aborts an in-flight pull. If the worker
// context is already done, the pull is skipped.
func (w *Worker) pullImage(image string) {
	// Fall back to Background for a Worker constructed without a context so
	// context.WithTimeout never receives a nil parent.
	parent := w.ctx
	if parent == nil {
		parent = context.Background()
	}
	if parent.Err() != nil {
		w.logger.Info("skipping image pull: worker context is done", "image", image)
		return
	}

	ctx, cancel := context.WithTimeout(parent, 10*time.Minute)
	defer cancel()

	output, err := exec.CommandContext(ctx, "podman", "pull", image).CombinedOutput()
	if err != nil {
		w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
		return
	}
	w.logger.Info("image pulled successfully", "image", image)
}