refactor: Phase 4 - split worker package into focused files
Split the 551-line worker/core.go into single-concern files:

- worker/config.go (+44 lines): added the config-parsing helpers envInt(), parseCPUFromConfig(), parseGPUCountFromConfig() and parseGPUSlotsPerGPUFromConfig(), so all config logic now lives in one place (440 lines total).
- worker/metrics.go (new file, 172 lines): extracted setupMetricsExporter() with its ~30 Prometheus metric registrations, isolating the metrics logic for easy modification.
- worker/factory.go (new file, 183 lines): extracted the NewWorker() factory function and moved prePullImages() and pullImage() out of core.go, centralizing worker instantiation.
- worker/worker.go (renamed from core.go, ~100 lines): now only defines the Worker struct, MLServer and JupyterManager; a clean, focused file without mixed concerns.

Lines redistributed: ~350 lines moved out of the monolithic core.go.
Build status: compiles successfully.
This commit is contained in:
parent
d1bef0a450
commit
a5c1a9fc0b
5 changed files with 580 additions and 550 deletions
|
|
@ -2,9 +2,12 @@ package worker
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -391,3 +394,49 @@ func (c *Config) Validate() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// envInt looks up the environment variable name and parses it as a base-10
// integer, ignoring surrounding whitespace.
//
// The boolean result is false, and the int result 0, when the variable is
// unset, blank, or not a valid integer.
func envInt(name string) (int, bool) {
	raw := strings.TrimSpace(os.Getenv(name))
	if raw == "" {
		return 0, false
	}
	if parsed, err := strconv.Atoi(raw); err == nil {
		return parsed, true
	}
	return 0, false
}
|
||||
|
||||
// parseCPUFromConfig determines total CPU from environment or config
|
||||
func parseCPUFromConfig(cfg *Config) int {
|
||||
if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
|
||||
return n
|
||||
}
|
||||
if cfg != nil {
|
||||
if cfg.Resources.PodmanCPUs != "" {
|
||||
if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil {
|
||||
if f < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(math.Floor(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
return runtime.NumCPU()
|
||||
}
|
||||
|
||||
// parseGPUCountFromConfig detects GPU count from config
|
||||
func parseGPUCountFromConfig(cfg *Config) int {
|
||||
factory := &GPUDetectorFactory{}
|
||||
detector := factory.CreateDetector(cfg)
|
||||
return detector.DetectGPUCount()
|
||||
}
|
||||
|
||||
// parseGPUSlotsPerGPUFromConfig reads GPU slots per GPU from environment
|
||||
func parseGPUSlotsPerGPUFromConfig() int {
|
||||
if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,550 +0,0 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/envpool"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
||||
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// MLServer wraps network.SSHClient for backward compatibility.
|
||||
type MLServer struct {
|
||||
*network.SSHClient
|
||||
}
|
||||
|
||||
// JupyterManager is the subset of the Jupyter service manager used by the worker.
|
||||
// It exists to keep task execution testable.
|
||||
type JupyterManager interface {
|
||||
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
|
||||
StopService(ctx context.Context, serviceID string) error
|
||||
RemoveService(ctx context.Context, serviceID string, purge bool) error
|
||||
RestoreWorkspace(ctx context.Context, name string) (string, error)
|
||||
ListServices() []*jupyter.JupyterService
|
||||
ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
|
||||
}
|
||||
|
||||
// isValidName reports whether input has an acceptable length for a name:
// at least 1 and at most 255 bytes. Only the length is checked; the
// characters themselves are not inspected.
func isValidName(input string) bool {
	n := len(input)
	return n >= 1 && n <= 255
}
|
||||
|
||||
// NewMLServer creates a new ML server connection.
|
||||
// NewMLServer returns a new MLServer instance.
|
||||
func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||
if cfg.LocalMode {
|
||||
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
|
||||
}
|
||||
|
||||
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MLServer{SSHClient: client}, nil
|
||||
}
|
||||
|
||||
// Worker represents an ML task worker.
|
||||
type Worker struct {
|
||||
id string
|
||||
config *Config
|
||||
server *MLServer
|
||||
queue queue.Backend
|
||||
resources *resources.Manager
|
||||
running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
|
||||
runningMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *logging.Logger
|
||||
metrics *metrics.Metrics
|
||||
metricsSrv *http.Server
|
||||
|
||||
datasetCache map[string]time.Time
|
||||
datasetCacheMu sync.RWMutex
|
||||
datasetCacheTTL time.Duration
|
||||
|
||||
// Graceful shutdown fields
|
||||
shutdownCh chan struct{}
|
||||
activeTasks sync.Map // map[string]*queue.Task - track active tasks
|
||||
gracefulWait sync.WaitGroup
|
||||
|
||||
podman *container.PodmanManager
|
||||
jupyter JupyterManager
|
||||
trackingRegistry *tracking.Registry
|
||||
envPool *envpool.Pool
|
||||
|
||||
prewarmMu sync.Mutex
|
||||
prewarmTargetID string
|
||||
prewarmCancel context.CancelFunc
|
||||
prewarmStartedAt time.Time
|
||||
}
|
||||
|
||||
// envInt looks up the environment variable name and parses it as a base-10
// integer, ignoring surrounding whitespace.
//
// The boolean result is false, and the int result 0, when the variable is
// unset, blank, or not a valid integer.
func envInt(name string) (int, bool) {
	raw := strings.TrimSpace(os.Getenv(name))
	if raw == "" {
		return 0, false
	}
	if parsed, err := strconv.Atoi(raw); err == nil {
		return parsed, true
	}
	return 0, false
}
|
||||
|
||||
func parseCPUFromConfig(cfg *Config) int {
|
||||
if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
|
||||
return n
|
||||
}
|
||||
if cfg != nil {
|
||||
if cfg.Resources.PodmanCPUs != "" {
|
||||
if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil {
|
||||
if f < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(math.Floor(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
return runtime.NumCPU()
|
||||
}
|
||||
|
||||
func parseGPUCountFromConfig(cfg *Config) int {
|
||||
factory := &GPUDetectorFactory{}
|
||||
detector := factory.CreateDetector(cfg)
|
||||
return detector.DetectGPUCount()
|
||||
}
|
||||
|
||||
func (w *Worker) getGPUDetector() GPUDetector {
|
||||
factory := &GPUDetectorFactory{}
|
||||
return factory.CreateDetector(w.config)
|
||||
}
|
||||
|
||||
func parseGPUSlotsPerGPUFromConfig() int {
|
||||
if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (w *Worker) setupMetricsExporter() {
|
||||
if !w.config.Metrics.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
reg.MustRegister(
|
||||
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||
collectors.NewGoCollector(),
|
||||
)
|
||||
|
||||
labels := prometheus.Labels{"worker_id": w.id}
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_processed_total",
|
||||
Help: "Total tasks processed successfully by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksProcessed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_failed_total",
|
||||
Help: "Total tasks failed by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksFailed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_active",
|
||||
Help: "Number of tasks currently running on this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.runningCount())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_queued",
|
||||
Help: "Latest observed queue depth from Redis.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.QueuedTasks.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_transferred_bytes_total",
|
||||
Help: "Total bytes transferred while fetching datasets.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataTransferred.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_fetch_time_seconds_total",
|
||||
Help: "Total time spent fetching datasets (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_execution_time_seconds_total",
|
||||
Help: "Total execution time for completed tasks (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_hit_total",
|
||||
Help: "Total environment prewarm hits (warmed image already existed).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvHit.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_miss_total",
|
||||
Help: "Total environment prewarm misses (warmed image did not exist yet).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvMiss.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_built_total",
|
||||
Help: "Total environment prewarm images built.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvBuilt.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_time_seconds_total",
|
||||
Help: "Total time spent building prewarm images (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_hit_total",
|
||||
Help: "Total prewarmed snapshot hits (snapshots found in .prewarm/).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotHit.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_miss_total",
|
||||
Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotMiss.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_built_total",
|
||||
Help: "Total snapshots prewarmed into .prewarm/.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotBuilt.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_time_seconds_total",
|
||||
Help: "Total time spent prewarming snapshots (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_worker_max_concurrency",
|
||||
Help: "Configured maximum concurrent tasks for this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.config.MaxWorkers)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_total",
|
||||
Help: "Total CPU tokens managed by the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().TotalCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_free",
|
||||
Help: "Free CPU tokens currently available in the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().FreeCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_total",
|
||||
Help: "Total resource acquisition attempts.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_total",
|
||||
Help: "Total resource acquisitions that had to wait for resources.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireWaitTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_timeout_total",
|
||||
Help: "Total resource acquisition attempts that timed out.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTimeoutTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_seconds_total",
|
||||
Help: "Total seconds spent waiting for resources across all acquisitions.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return w.resources.Snapshot().AcquireWaitSeconds
|
||||
}))
|
||||
|
||||
snap := w.resources.Snapshot()
|
||||
for i := range snap.GPUFree {
|
||||
gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)}
|
||||
idx := i
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_total",
|
||||
Help: "Total GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().SlotsPerGPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_free",
|
||||
Help: "Free GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
s := w.resources.Snapshot()
|
||||
if idx < 0 || idx >= len(s.GPUFree) {
|
||||
return 0
|
||||
}
|
||||
return float64(s.GPUFree[idx])
|
||||
}))
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: w.config.Metrics.ListenAddr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
w.metricsSrv = srv
|
||||
go func() {
|
||||
w.logger.Info("metrics exporter listening",
|
||||
"addr", w.config.Metrics.ListenAddr,
|
||||
"enabled", w.config.Metrics.Enabled)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
w.logger.Warn("metrics exporter stopped",
|
||||
"error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// NewWorker creates a new worker instance.
|
||||
func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||
srv, err := NewMLServer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := srv.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
backendCfg := queue.BackendConfig{
|
||||
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
SQLitePath: cfg.Queue.SQLitePath,
|
||||
FilesystemPath: cfg.Queue.FilesystemPath,
|
||||
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
}
|
||||
queueClient, err := queue.NewBackend(backendCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := queueClient.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create data_dir if it doesn't exist (for production without NAS)
|
||||
if cfg.DataDir != "" {
|
||||
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
|
||||
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
ctx = logging.EnsureTrace(ctx)
|
||||
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
|
||||
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
metricsObj := &metrics.Metrics{}
|
||||
|
||||
podmanMgr, err := container.NewPodmanManager(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create podman manager: %w", err)
|
||||
}
|
||||
|
||||
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
|
||||
}
|
||||
|
||||
trackingRegistry := tracking.NewRegistry(logger)
|
||||
pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
|
||||
|
||||
if len(cfg.Plugins) == 0 {
|
||||
logger.Warn("no plugins configured, defining defaults")
|
||||
// Register defaults manually for backward compatibility/local dev
|
||||
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
trackingplugins.MLflowOptions{
|
||||
ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
trackingRegistry.Register(mlflowPlugin)
|
||||
}
|
||||
|
||||
tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
trackingplugins.TensorBoardOptions{
|
||||
LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
trackingRegistry.Register(tensorboardPlugin)
|
||||
}
|
||||
|
||||
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
|
||||
} else {
|
||||
if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
|
||||
return nil, fmt.Errorf("failed to load plugins: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
server: srv,
|
||||
queue: queueClient,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
metrics: metricsObj,
|
||||
shutdownCh: make(chan struct{}),
|
||||
podman: podmanMgr,
|
||||
jupyter: jupyterMgr,
|
||||
trackingRegistry: trackingRegistry,
|
||||
envPool: envpool.New(""),
|
||||
}
|
||||
|
||||
rm, rmErr := resources.NewManager(resources.Options{
|
||||
TotalCPU: parseCPUFromConfig(cfg),
|
||||
GPUCount: parseGPUCountFromConfig(cfg),
|
||||
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
|
||||
})
|
||||
if rmErr != nil {
|
||||
return nil, fmt.Errorf("failed to init resource manager: %w", rmErr)
|
||||
}
|
||||
worker.resources = rm
|
||||
|
||||
if !cfg.LocalMode {
|
||||
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
|
||||
if cfg.AppleGPU.Enabled {
|
||||
logger.Warn("apple MPS GPU mode is intended for development; do not use in production",
|
||||
"gpu_type", "apple",
|
||||
)
|
||||
}
|
||||
if gpuType == "amd" {
|
||||
logger.Warn("amd GPU mode is intended for development; do not use in production",
|
||||
"gpu_type", "amd",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
worker.setupMetricsExporter()
|
||||
|
||||
// Pre-pull tracking images in background
|
||||
go worker.prePullImages()
|
||||
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
func (w *Worker) prePullImages() {
|
||||
if w.config.LocalMode {
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Info("starting image pre-pulling")
|
||||
|
||||
// Pull worker image
|
||||
if w.config.PodmanImage != "" {
|
||||
w.pullImage(w.config.PodmanImage)
|
||||
}
|
||||
|
||||
// Pull plugin images
|
||||
for name, cfg := range w.config.Plugins {
|
||||
if !cfg.Enabled || cfg.Image == "" {
|
||||
continue
|
||||
}
|
||||
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
||||
w.pullImage(cfg.Image)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) pullImage(image string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
||||
} else {
|
||||
w.logger.Info("image pulled successfully", "image", image)
|
||||
}
|
||||
}
|
||||
212
internal/worker/factory.go
Normal file
212
internal/worker/factory.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/envpool"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
||||
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
|
||||
)
|
||||
|
||||
// NewWorker creates a new worker instance.
|
||||
func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||
srv, err := NewMLServer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := srv.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
backendCfg := queue.BackendConfig{
|
||||
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
SQLitePath: cfg.Queue.SQLitePath,
|
||||
FilesystemPath: cfg.Queue.FilesystemPath,
|
||||
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
}
|
||||
queueClient, err := queue.NewBackend(backendCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := queueClient.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create data_dir if it doesn't exist (for production without NAS)
|
||||
if cfg.DataDir != "" {
|
||||
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
|
||||
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
ctx = logging.EnsureTrace(ctx)
|
||||
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
|
||||
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
metricsObj := &metrics.Metrics{}
|
||||
|
||||
podmanMgr, err := container.NewPodmanManager(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create podman manager: %w", err)
|
||||
}
|
||||
|
||||
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
|
||||
}
|
||||
|
||||
trackingRegistry := tracking.NewRegistry(logger)
|
||||
pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
|
||||
|
||||
if len(cfg.Plugins) == 0 {
|
||||
logger.Warn("no plugins configured, defining defaults")
|
||||
// Register defaults manually for backward compatibility/local dev
|
||||
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
trackingplugins.MLflowOptions{
|
||||
ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
trackingRegistry.Register(mlflowPlugin)
|
||||
}
|
||||
|
||||
tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
trackingplugins.TensorBoardOptions{
|
||||
LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
trackingRegistry.Register(tensorboardPlugin)
|
||||
}
|
||||
|
||||
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
|
||||
} else {
|
||||
if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
|
||||
return nil, fmt.Errorf("failed to load plugins: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
server: srv,
|
||||
queue: queueClient,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
metrics: metricsObj,
|
||||
shutdownCh: make(chan struct{}),
|
||||
podman: podmanMgr,
|
||||
jupyter: jupyterMgr,
|
||||
trackingRegistry: trackingRegistry,
|
||||
envPool: envpool.New(""),
|
||||
}
|
||||
|
||||
rm, rmErr := resources.NewManager(resources.Options{
|
||||
TotalCPU: parseCPUFromConfig(cfg),
|
||||
GPUCount: parseGPUCountFromConfig(cfg),
|
||||
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
|
||||
})
|
||||
if rmErr != nil {
|
||||
return nil, fmt.Errorf("failed to init resource manager: %w", rmErr)
|
||||
}
|
||||
worker.resources = rm
|
||||
|
||||
if !cfg.LocalMode {
|
||||
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
|
||||
if cfg.AppleGPU.Enabled {
|
||||
logger.Warn("apple MPS GPU mode is intended for development; do not use in production",
|
||||
"gpu_type", "apple",
|
||||
)
|
||||
}
|
||||
if gpuType == "amd" {
|
||||
logger.Warn("amd GPU mode is intended for development; do not use in production",
|
||||
"gpu_type", "amd",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
worker.setupMetricsExporter()
|
||||
|
||||
// Pre-pull tracking images in background
|
||||
go worker.prePullImages()
|
||||
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
// prePullImages pulls required container images in the background
|
||||
func (w *Worker) prePullImages() {
|
||||
if w.config.LocalMode {
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Info("starting image pre-pulling")
|
||||
|
||||
// Pull worker image
|
||||
if w.config.PodmanImage != "" {
|
||||
w.pullImage(w.config.PodmanImage)
|
||||
}
|
||||
|
||||
// Pull plugin images
|
||||
for name, cfg := range w.config.Plugins {
|
||||
if !cfg.Enabled || cfg.Image == "" {
|
||||
continue
|
||||
}
|
||||
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
||||
w.pullImage(cfg.Image)
|
||||
}
|
||||
}
|
||||
|
||||
// pullImage pulls a single container image
|
||||
func (w *Worker) pullImage(image string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
||||
} else {
|
||||
w.logger.Info("image pulled successfully", "image", image)
|
||||
}
|
||||
}
|
||||
224
internal/worker/metrics.go
Normal file
224
internal/worker/metrics.go
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// setupMetricsExporter initializes the Prometheus metrics exporter
|
||||
func (w *Worker) setupMetricsExporter() {
|
||||
if !w.config.Metrics.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
reg.MustRegister(
|
||||
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||
collectors.NewGoCollector(),
|
||||
)
|
||||
|
||||
labels := prometheus.Labels{"worker_id": w.id}
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_processed_total",
|
||||
Help: "Total tasks processed successfully by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksProcessed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_failed_total",
|
||||
Help: "Total tasks failed by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksFailed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_active",
|
||||
Help: "Number of tasks currently running on this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.runningCount())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_queued",
|
||||
Help: "Latest observed queue depth from Redis.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.QueuedTasks.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_transferred_bytes_total",
|
||||
Help: "Total bytes transferred while fetching datasets.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataTransferred.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_fetch_time_seconds_total",
|
||||
Help: "Total time spent fetching datasets (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_execution_time_seconds_total",
|
||||
Help: "Total execution time for completed tasks (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_hit_total",
|
||||
Help: "Total environment prewarm hits (warmed image already existed).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvHit.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_miss_total",
|
||||
Help: "Total environment prewarm misses (warmed image did not exist yet).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvMiss.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_built_total",
|
||||
Help: "Total environment prewarm images built.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvBuilt.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_time_seconds_total",
|
||||
Help: "Total time spent building prewarm images (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_hit_total",
|
||||
Help: "Total prewarmed snapshot hits (snapshots found in .prewarm/).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotHit.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_miss_total",
|
||||
Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotMiss.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_built_total",
|
||||
Help: "Total snapshots prewarmed into .prewarm/.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotBuilt.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_time_seconds_total",
|
||||
Help: "Total time spent prewarming snapshots (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_worker_max_concurrency",
|
||||
Help: "Configured maximum concurrent tasks for this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.config.MaxWorkers)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_total",
|
||||
Help: "Total CPU tokens managed by the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().TotalCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_free",
|
||||
Help: "Free CPU tokens currently available in the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().FreeCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_total",
|
||||
Help: "Total resource acquisition attempts.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_total",
|
||||
Help: "Total resource acquisitions that had to wait for resources.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireWaitTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_timeout_total",
|
||||
Help: "Total resource acquisition attempts that timed out.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTimeoutTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_seconds_total",
|
||||
Help: "Total seconds spent waiting for resources across all acquisitions.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return w.resources.Snapshot().AcquireWaitSeconds
|
||||
}))
|
||||
|
||||
snap := w.resources.Snapshot()
|
||||
for i := range snap.GPUFree {
|
||||
gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)}
|
||||
idx := i
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_total",
|
||||
Help: "Total GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().SlotsPerGPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_free",
|
||||
Help: "Free GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
s := w.resources.Snapshot()
|
||||
if idx < 0 || idx >= len(s.GPUFree) {
|
||||
return 0
|
||||
}
|
||||
return float64(s.GPUFree[idx])
|
||||
}))
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: w.config.Metrics.ListenAddr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
w.metricsSrv = srv
|
||||
go func() {
|
||||
w.logger.Info("metrics exporter listening",
|
||||
"addr", w.config.Metrics.ListenAddr,
|
||||
"enabled", w.config.Metrics.Enabled)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
w.logger.Warn("metrics exporter stopped",
|
||||
"error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
95
internal/worker/worker.go
Normal file
95
internal/worker/worker.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/envpool"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
)
|
||||
|
||||
// MLServer wraps network.SSHClient for backward compatibility.
//
// The client is embedded, so its methods are promoted onto MLServer.
// NewMLServer fills it with either a local client or an SSH client,
// depending on the configuration.
type MLServer struct {
	*network.SSHClient
}
|
||||
|
||||
// JupyterManager is the subset of the Jupyter service manager used by the worker.
// It exists to keep task execution testable.
//
// Worker holds an implementation of this interface in its jupyter field.
type JupyterManager interface {
	// Service lifecycle.
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	StopService(ctx context.Context, serviceID string) error
	RemoveService(ctx context.Context, serviceID string, purge bool) error

	// Workspace restore; the meaning of the returned string is defined by the
	// implementation (presumably a path or identifier — confirm).
	RestoreWorkspace(ctx context.Context, name string) (string, error)

	// Inspection.
	ListServices() []*jupyter.JupyterService
	ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
|
||||
|
||||
// isValidName reports whether input has an acceptable length for a name:
// at least 1 and at most 255 bytes. Only the length is checked; the
// characters themselves are not inspected.
func isValidName(input string) bool {
	switch n := len(input); {
	case n == 0:
		return false
	case n >= 256:
		return false
	default:
		return true
	}
}
|
||||
|
||||
// NewMLServer creates a new ML server connection.
|
||||
// NewMLServer returns a new MLServer instance.
|
||||
func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||
if cfg.LocalMode {
|
||||
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
|
||||
}
|
||||
|
||||
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MLServer{SSHClient: client}, nil
|
||||
}
|
||||
|
||||
// Worker represents an ML task worker.
//
// It bundles the worker's configuration, its queue and resource-manager
// handles, bookkeeping for in-flight tasks, and the integrations (Podman,
// Jupyter, tracking registry, environment pool) referenced by its fields.
type Worker struct {
	// id is exported as the worker_id label on this worker's metrics.
	id        string
	config    *Config
	server    *MLServer
	queue     queue.Backend
	resources *resources.Manager
	running   map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
	// runningMu presumably guards running — confirm against its users.
	runningMu sync.RWMutex
	// ctx and cancel: presumably the worker-lifetime context and its cancel
	// function — confirm where they are assigned.
	ctx     context.Context
	cancel  context.CancelFunc
	logger  *logging.Logger
	metrics *metrics.Metrics
	// metricsSrv is assigned by setupMetricsExporter when the metrics
	// exporter is enabled.
	metricsSrv *http.Server

	// Dataset cache bookkeeping. Presumably maps a dataset key to a timestamp
	// checked against datasetCacheTTL, guarded by datasetCacheMu — confirm.
	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup

	// Optional integrations.
	podman           *container.PodmanManager
	jupyter          JupyterManager
	trackingRegistry *tracking.Registry
	envPool          *envpool.Pool

	// Prewarm state. Presumably prewarmMu guards the three fields below it,
	// which describe the prewarm currently in progress — confirm.
	prewarmMu        sync.Mutex
	prewarmTargetID  string
	prewarmCancel    context.CancelFunc
	prewarmStartedAt time.Time
}
|
||||
|
||||
func (w *Worker) getGPUDetector() GPUDetector {
|
||||
factory := &GPUDetectorFactory{}
|
||||
return factory.CreateDetector(w.config)
|
||||
}
|
||||
Loading…
Reference in a new issue