feat(worker): add integrity checks, snapshot staging, and prewarm support

Jeremie Fraeys committed 2026-01-05 12:31:13 -05:00
parent add4a90e62
commit 82034c68f3
12 changed files with 4493 additions and 1145 deletions


@@ -1,179 +0,0 @@
package main
import (
"fmt"
"path/filepath"
"time"
"github.com/google/uuid"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"gopkg.in/yaml.v3"
)
const (
defaultMetricsFlushInterval = 500 * time.Millisecond
datasetCacheDefaultTTL = 30 * time.Minute
)
// Config holds worker configuration.
type Config struct {
Host string `yaml:"host"`
User string `yaml:"user"`
SSHKey string `yaml:"ssh_key"`
Port int `yaml:"port"`
BasePath string `yaml:"base_path"`
TrainScript string `yaml:"train_script"`
RedisAddr string `yaml:"redis_addr"`
RedisPassword string `yaml:"redis_password"`
RedisDB int `yaml:"redis_db"`
KnownHosts string `yaml:"known_hosts"`
WorkerID string `yaml:"worker_id"`
MaxWorkers int `yaml:"max_workers"`
PollInterval int `yaml:"poll_interval_seconds"`
Resources config.ResourceConfig `yaml:"resources"`
LocalMode bool `yaml:"local_mode"`
// Authentication
Auth auth.Config `yaml:"auth"`
// Metrics exporter
Metrics MetricsConfig `yaml:"metrics"`
// Metrics buffering
MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`
// Data management
DataManagerPath string `yaml:"data_manager_path"`
AutoFetchData bool `yaml:"auto_fetch_data"`
DataDir string `yaml:"data_dir"`
DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`
// Podman execution
PodmanImage string `yaml:"podman_image"`
ContainerWorkspace string `yaml:"container_workspace"`
ContainerResults string `yaml:"container_results"`
GPUAccess bool `yaml:"gpu_access"`
// Task lease and retry settings
TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min)
HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // How often to renew lease (default: 1min)
MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Graceful shutdown timeout (default: 5min)
}
// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
Enabled bool `yaml:"enabled"`
ListenAddr string `yaml:"listen_addr"`
}
// LoadConfig loads worker configuration from a YAML file.
func LoadConfig(path string) (*Config, error) {
data, err := fileutil.SecureFileRead(path)
if err != nil {
return nil, err
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
// Get smart defaults for current environment
smart := config.GetSmartDefaults()
if cfg.Port == 0 {
cfg.Port = config.DefaultSSHPort
}
if cfg.Host == "" {
cfg.Host = smart.Host()
}
if cfg.BasePath == "" {
cfg.BasePath = smart.BasePath()
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = smart.RedisAddr()
}
if cfg.KnownHosts == "" {
cfg.KnownHosts = smart.KnownHostsPath()
}
if cfg.WorkerID == "" {
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
}
cfg.Resources.ApplyDefaults()
if cfg.MaxWorkers > 0 {
cfg.Resources.MaxWorkers = cfg.MaxWorkers
} else {
cfg.MaxWorkers = cfg.Resources.MaxWorkers
}
if cfg.PollInterval == 0 {
cfg.PollInterval = smart.PollInterval()
}
if cfg.DataManagerPath == "" {
cfg.DataManagerPath = "./data_manager"
}
if cfg.DataDir == "" {
if cfg.Host == "" || !cfg.AutoFetchData {
cfg.DataDir = config.DefaultLocalDataDir
} else {
cfg.DataDir = smart.DataDir()
}
}
if cfg.Metrics.ListenAddr == "" {
cfg.Metrics.ListenAddr = ":9100"
}
if cfg.MetricsFlushInterval == 0 {
cfg.MetricsFlushInterval = defaultMetricsFlushInterval
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
// Set lease and retry defaults
if cfg.TaskLeaseDuration == 0 {
cfg.TaskLeaseDuration = 30 * time.Minute
}
if cfg.HeartbeatInterval == 0 {
cfg.HeartbeatInterval = 1 * time.Minute
}
if cfg.MaxRetries == 0 {
cfg.MaxRetries = 3
}
if cfg.GracefulTimeout == 0 {
cfg.GracefulTimeout = 5 * time.Minute
}
return &cfg, nil
}
// Validate implements config.Validator interface.
func (c *Config) Validate() error {
if c.Port != 0 {
if err := config.ValidatePort(c.Port); err != nil {
return fmt.Errorf("invalid SSH port: %w", err)
}
}
if c.BasePath != "" {
// Convert relative paths to absolute
c.BasePath = config.ExpandPath(c.BasePath)
if !filepath.IsAbs(c.BasePath) {
c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
}
}
if c.RedisAddr != "" {
if err := config.ValidateRedisAddr(c.RedisAddr); err != nil {
return fmt.Errorf("invalid Redis configuration: %w", err)
}
}
if c.MaxWorkers < 1 {
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
}
return nil
}
// Task struct and Redis constants moved to internal/queue

File diff suppressed because it is too large.

internal/envpool/envpool.go (new file, 288 lines)

@@ -0,0 +1,288 @@
package envpool
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
)
type CommandRunner interface {
CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error)
}
type execRunner struct{}
func (r execRunner) CombinedOutput(
ctx context.Context,
name string,
args ...string,
) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
return cmd.CombinedOutput()
}
type Pool struct {
runner CommandRunner
imagePrefix string
cacheMu sync.Mutex
cache map[string]cacheEntry
cacheTTL time.Duration
}
type cacheEntry struct {
exists bool
expires time.Time
}
func New(imagePrefix string) *Pool {
prefix := strings.TrimSpace(imagePrefix)
if prefix == "" {
prefix = "fetchml-prewarm"
}
return &Pool{
runner: execRunner{},
imagePrefix: prefix,
cache: make(map[string]cacheEntry),
cacheTTL: 30 * time.Second,
}
}
func (p *Pool) WithRunner(r CommandRunner) *Pool {
if r != nil {
p.runner = r
}
return p
}
func (p *Pool) WithCacheTTL(ttl time.Duration) *Pool {
if ttl > 0 {
p.cacheTTL = ttl
}
return p
}
func (p *Pool) WarmImageTag(depsManifestSHA256 string) (string, error) {
sha := strings.TrimSpace(depsManifestSHA256)
if sha == "" {
return "", fmt.Errorf("missing deps sha256")
}
if !isLowerHexLen(sha, 64) {
return "", fmt.Errorf("invalid deps sha256")
}
return fmt.Sprintf("%s:%s", p.imagePrefix, sha[:12]), nil
}
func (p *Pool) ImageExists(ctx context.Context, imageRef string) (bool, error) {
ref := strings.TrimSpace(imageRef)
if ref == "" {
return false, fmt.Errorf("missing image ref")
}
p.cacheMu.Lock()
if ent, ok := p.cache[ref]; ok && time.Now().Before(ent.expires) {
exists := ent.exists
p.cacheMu.Unlock()
return exists, nil
}
p.cacheMu.Unlock()
out, err := p.runner.CombinedOutput(ctx, "podman", "image", "inspect", ref)
if err == nil {
p.setCache(ref, true)
return true, nil
}
if looksLikeImageNotFound(out) {
p.setCache(ref, false)
return false, nil
}
var ee *exec.ExitError
if errors.As(err, &ee) {
p.setCache(ref, false)
return false, nil
}
return false, err
}
func looksLikeImageNotFound(out []byte) bool {
s := strings.ToLower(strings.TrimSpace(string(out)))
if s == "" {
return false
}
return strings.Contains(s, "no such") ||
strings.Contains(s, "not found") ||
strings.Contains(s, "does not exist")
}
func (p *Pool) setCache(imageRef string, exists bool) {
p.cacheMu.Lock()
p.cache[imageRef] = cacheEntry{exists: exists, expires: time.Now().Add(p.cacheTTL)}
p.cacheMu.Unlock()
}
type PrepareRequest struct {
BaseImage string
TargetImage string
HostWorkspace string
ContainerWorkspace string
DepsPathInContainer string
}
func (p *Pool) PruneImages(ctx context.Context, olderThan time.Duration) error {
if olderThan <= 0 {
return fmt.Errorf("invalid olderThan")
}
h := int(olderThan.Round(time.Hour).Hours())
if h < 1 {
h = 1
}
until := fmt.Sprintf("%dh", h)
_, err := p.runner.CombinedOutput(
ctx,
"podman",
"image",
"prune",
"-a",
"-f",
"--filter",
"label=fetchml.prewarm=true",
"--filter",
"until="+until,
)
return err
}
func (p *Pool) Prepare(ctx context.Context, req PrepareRequest) error {
baseImage := strings.TrimSpace(req.BaseImage)
targetImage := strings.TrimSpace(req.TargetImage)
hostWS := strings.TrimSpace(req.HostWorkspace)
containerWS := strings.TrimSpace(req.ContainerWorkspace)
depsInContainer := strings.TrimSpace(req.DepsPathInContainer)
if baseImage == "" {
return fmt.Errorf("missing base image")
}
if targetImage == "" {
return fmt.Errorf("missing target image")
}
if hostWS == "" {
return fmt.Errorf("missing host workspace")
}
if containerWS == "" {
return fmt.Errorf("missing container workspace")
}
if depsInContainer == "" {
return fmt.Errorf("missing deps path")
}
// Require deps to sit strictly under the workspace so a sibling path such
// as "/ws-evil" cannot satisfy a bare "/ws" prefix check.
if !strings.HasPrefix(depsInContainer, strings.TrimSuffix(containerWS, "/")+"/") {
return fmt.Errorf("deps path must be under container workspace")
}
exists, err := p.ImageExists(ctx, targetImage)
if err != nil {
return err
}
if exists {
return nil
}
containerName, err := randomContainerName("fetchml-prewarm")
if err != nil {
return err
}
// Do not use --rm since we need a container to commit.
runArgs := []string{
"run",
"--name", containerName,
"--security-opt", "no-new-privileges",
"--cap-drop", "ALL",
"--userns", "keep-id",
"-v", fmt.Sprintf("%s:%s:rw", hostWS, containerWS),
baseImage,
"--workspace", containerWS,
"--deps", depsInContainer,
"--prepare-only",
}
if out, err := p.runner.CombinedOutput(ctx, "podman", runArgs...); err != nil {
_ = p.cleanupContainer(context.Background(), containerName)
return fmt.Errorf("podman run prewarm failed: %w", scrubOutput(out, err))
}
if out, err := p.runner.CombinedOutput(
ctx,
"podman",
"commit",
// Apply the prune label at commit time via a Dockerfile-style change;
// podman has no "image label" subcommand, so labeling after the fact
// would silently no-op and PruneImages' label filter would never match.
"--change", "LABEL fetchml.prewarm=true",
containerName,
targetImage,
); err != nil {
_ = p.cleanupContainer(context.Background(), containerName)
return fmt.Errorf("podman commit prewarm failed: %w", scrubOutput(out, err))
}
_ = p.cleanupContainer(context.Background(), containerName)
p.setCache(targetImage, true)
return nil
}
func (p *Pool) cleanupContainer(ctx context.Context, name string) error {
n := strings.TrimSpace(name)
if n == "" {
return nil
}
_, err := p.runner.CombinedOutput(ctx, "podman", "rm", n)
return err
}
func randomContainerName(prefix string) (string, error) {
p := strings.TrimSpace(prefix)
if p == "" {
p = "fetchml-prewarm"
}
b := make([]byte, 6)
if _, err := rand.Read(b); err != nil {
return "", err
}
return fmt.Sprintf("%s-%s", p, hex.EncodeToString(b)), nil
}
func isLowerHexLen(s string, want int) bool {
if len(s) != want {
return false
}
for i := 0; i < len(s); i++ {
c := s[i]
if (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') {
continue
}
return false
}
return true
}
func scrubOutput(out []byte, err error) error {
if len(out) == 0 {
return err
}
s := strings.TrimSpace(string(out))
if len(s) > 400 {
s = s[:400]
}
return fmt.Errorf("%w (output=%q)", err, s)
}
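A test-style usage sketch of the pool (assumptions: the fake runner stands in for podman, and the deps digest is made up; this is not part of the diff):

package envpool_test
import (
"context"
"errors"
"strings"
"testing"
"github.com/jfraeys/fetch_ml/internal/envpool"
)
type fakeRunner struct{ calls int }
// CombinedOutput simulates "image not found" for every podman invocation.
func (f *fakeRunner) CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) {
f.calls++
return []byte("Error: no such image"), errors.New("exit status 125")
}
func TestWarmTagAndLookup(t *testing.T) {
pool := envpool.New("fetchml-prewarm").WithRunner(&fakeRunner{})
sha := strings.Repeat("ab", 32) // 64 lowercase hex chars, as WarmImageTag requires
tag, err := pool.WarmImageTag(sha)
if err != nil || tag != "fetchml-prewarm:abababababab" {
t.Fatalf("tag=%q err=%v", tag, err)
}
exists, err := pool.ImageExists(context.Background(), tag)
if err != nil || exists {
t.Fatalf("expected not-found, got exists=%v err=%v", exists, err)
}
}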


@@ -0,0 +1,323 @@
package resources
import (
"context"
"errors"
"fmt"
"math"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/jfraeys/fetch_ml/internal/queue"
)
type Manager struct {
mu sync.Mutex
cond *sync.Cond
totalCPU int
freeCPU int
slotsPerGPU int
gpuFree []int
acquireTotal atomic.Int64
acquireWaitTotal atomic.Int64
acquireTimeoutTotal atomic.Int64
acquireWaitNanos atomic.Int64
}
type Snapshot struct {
TotalCPU int
FreeCPU int
SlotsPerGPU int
GPUFree []int
AcquireTotal int64
AcquireWaitTotal int64
AcquireTimeoutTotal int64
AcquireWaitSeconds float64
}
func FormatCUDAVisibleDevices(lease *Lease) string {
if lease == nil {
return "-1"
}
if len(lease.gpus) == 0 {
return "-1"
}
idx := make([]int, 0, len(lease.gpus))
for _, g := range lease.gpus {
idx = append(idx, g.Index)
}
sort.Ints(idx)
parts := make([]string, 0, len(idx))
for _, i := range idx {
parts = append(parts, strconv.Itoa(i))
}
return strings.Join(parts, ",")
}
type GPUAllocation struct {
Index int
Slots int
}
type Lease struct {
cpu int
gpus []GPUAllocation
m *Manager
}
func (l *Lease) CPU() int { return l.cpu }
func (l *Lease) GPUs() []GPUAllocation {
out := make([]GPUAllocation, len(l.gpus))
copy(out, l.gpus)
return out
}
func (l *Lease) Release() {
if l == nil || l.m == nil {
return
}
m := l.m
m.mu.Lock()
defer m.mu.Unlock()
if l.cpu > 0 {
m.freeCPU += l.cpu
if m.freeCPU > m.totalCPU {
m.freeCPU = m.totalCPU
}
}
for _, g := range l.gpus {
if g.Index >= 0 && g.Index < len(m.gpuFree) {
m.gpuFree[g.Index] += g.Slots
if m.gpuFree[g.Index] > m.slotsPerGPU {
m.gpuFree[g.Index] = m.slotsPerGPU
}
}
}
m.cond.Broadcast()
}
type Options struct {
TotalCPU int
GPUCount int
SlotsPerGPU int
}
func NewManager(opts Options) (*Manager, error) {
if opts.TotalCPU < 0 {
return nil, fmt.Errorf("total cpu must be >= 0")
}
if opts.GPUCount < 0 {
return nil, fmt.Errorf("gpu count must be >= 0")
}
if opts.SlotsPerGPU <= 0 {
opts.SlotsPerGPU = 1
}
m := &Manager{
totalCPU: opts.TotalCPU,
freeCPU: opts.TotalCPU,
slotsPerGPU: opts.SlotsPerGPU,
gpuFree: make([]int, opts.GPUCount),
}
for i := range m.gpuFree {
m.gpuFree[i] = m.slotsPerGPU
}
m.cond = sync.NewCond(&m.mu)
return m, nil
}
func (m *Manager) Snapshot() Snapshot {
if m == nil {
return Snapshot{}
}
m.mu.Lock()
gpuFree := make([]int, len(m.gpuFree))
copy(gpuFree, m.gpuFree)
totalCPU := m.totalCPU
freeCPU := m.freeCPU
slotsPerGPU := m.slotsPerGPU
m.mu.Unlock()
waitNanos := m.acquireWaitNanos.Load()
return Snapshot{
TotalCPU: totalCPU,
FreeCPU: freeCPU,
SlotsPerGPU: slotsPerGPU,
GPUFree: gpuFree,
AcquireTotal: m.acquireTotal.Load(),
AcquireWaitTotal: m.acquireWaitTotal.Load(),
AcquireTimeoutTotal: m.acquireTimeoutTotal.Load(),
AcquireWaitSeconds: float64(waitNanos) / float64(time.Second),
}
}
func (m *Manager) Acquire(ctx context.Context, task *queue.Task) (*Lease, error) {
if m == nil {
return nil, fmt.Errorf("resource manager is nil")
}
if task == nil {
return nil, fmt.Errorf("task is nil")
}
if ctx == nil {
return nil, fmt.Errorf("context is nil")
}
m.acquireTotal.Add(1)
start := time.Now()
waited := false
reqCPU := task.CPU
if reqCPU < 0 {
return nil, fmt.Errorf("cpu request must be >= 0")
}
if reqCPU > m.totalCPU {
return nil, fmt.Errorf("cpu request %d exceeds total cpu %d", reqCPU, m.totalCPU)
}
reqGPU := task.GPU
if reqGPU < 0 {
return nil, fmt.Errorf("gpu request must be >= 0")
}
if reqGPU > len(m.gpuFree) {
return nil, fmt.Errorf("gpu request %d exceeds available gpus %d", reqGPU, len(m.gpuFree))
}
slotsPerTaskGPU, err := m.gpuSlotsForTask(task.GPUMemory)
if err != nil {
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
for {
if ctx.Err() != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
m.acquireTimeoutTotal.Add(1)
}
return nil, ctx.Err()
}
gpuAlloc, ok := m.tryAllocateGPUsLocked(reqGPU, slotsPerTaskGPU)
if ok && (reqCPU == 0 || m.freeCPU >= reqCPU) {
if reqCPU > 0 {
m.freeCPU -= reqCPU
}
for _, g := range gpuAlloc {
m.gpuFree[g.Index] -= g.Slots
}
if waited {
m.acquireWaitTotal.Add(1)
m.acquireWaitNanos.Add(time.Since(start).Nanoseconds())
}
return &Lease{cpu: reqCPU, gpus: gpuAlloc, m: m}, nil
}
waited = true
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
m.mu.Lock()
m.cond.Broadcast()
m.mu.Unlock()
case <-done:
}
}()
m.cond.Wait()
close(done)
}
}
func (m *Manager) gpuSlotsForTask(gpuMem string) (int, error) {
if m.slotsPerGPU <= 0 {
return 1, nil
}
if strings.TrimSpace(gpuMem) == "" {
return m.slotsPerGPU, nil
}
if frac, ok := parseFraction(strings.TrimSpace(gpuMem)); ok {
if frac <= 0 {
return 1, nil
}
if frac > 1 {
frac = 1
}
slots := int(math.Ceil(frac * float64(m.slotsPerGPU)))
if slots < 1 {
slots = 1
}
if slots > m.slotsPerGPU {
slots = m.slotsPerGPU
}
return slots, nil
}
return m.slotsPerGPU, nil
}
func (m *Manager) tryAllocateGPUsLocked(reqGPU int, slotsPerTaskGPU int) ([]GPUAllocation, bool) {
if reqGPU == 0 {
return nil, true
}
if slotsPerTaskGPU <= 0 {
slotsPerTaskGPU = m.slotsPerGPU
}
alloc := make([]GPUAllocation, 0, reqGPU)
used := make(map[int]struct{}, reqGPU)
for len(alloc) < reqGPU {
bestIdx := -1
bestFree := -1
for i := 0; i < len(m.gpuFree); i++ {
if _, ok := used[i]; ok {
continue
}
free := m.gpuFree[i]
if free >= slotsPerTaskGPU && free > bestFree {
bestFree = free
bestIdx = i
}
}
if bestIdx < 0 {
return nil, false
}
used[bestIdx] = struct{}{}
alloc = append(alloc, GPUAllocation{Index: bestIdx, Slots: slotsPerTaskGPU})
}
return alloc, true
}
func parseFraction(s string) (float64, bool) {
if s == "" {
return 0, false
}
if strings.HasSuffix(s, "%") {
v := strings.TrimSuffix(s, "%")
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
if err != nil {
return 0, false
}
return f / 100.0, true
}
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return 0, false
}
if f > 1 {
return 0, false
}
return f, true
}
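A lease-lifecycle sketch under assumptions (the queue.Task literal uses only the CPU/GPU/GPUMemory fields read by Acquire, and all values are illustrative):

package main
import (
"context"
"log"
"os"
"time"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
)
func main() {
mgr, err := resources.NewManager(resources.Options{TotalCPU: 8, GPUCount: 2, SlotsPerGPU: 4})
if err != nil {
log.Fatal(err)
}
// "50%" of 4 slots rounds up to 2 slots on a single GPU.
task := &queue.Task{CPU: 2, GPU: 1, GPUMemory: "50%"}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
lease, err := mgr.Acquire(ctx, task) // blocks until 2 CPUs and 2 GPU slots are free
if err != nil {
log.Fatal(err) // a deadline here increments Snapshot().AcquireTimeoutTotal
}
defer lease.Release() // wakes other waiters via the condition variable
_ = os.Setenv("CUDA_VISIBLE_DEVICES", resources.FormatCUDAVisibleDevices(lease))
}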

internal/worker/config.go (new file, 371 lines)

@@ -0,0 +1,371 @@
package worker
import (
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
"gopkg.in/yaml.v3"
)
const (
defaultMetricsFlushInterval = 500 * time.Millisecond
datasetCacheDefaultTTL = 30 * time.Minute
)
type QueueConfig struct {
Backend string `yaml:"backend"`
SQLitePath string `yaml:"sqlite_path"`
}
// Config holds worker configuration.
type Config struct {
Host string `yaml:"host"`
User string `yaml:"user"`
SSHKey string `yaml:"ssh_key"`
Port int `yaml:"port"`
BasePath string `yaml:"base_path"`
TrainScript string `yaml:"train_script"`
RedisURL string `yaml:"redis_url"`
RedisAddr string `yaml:"redis_addr"`
RedisPassword string `yaml:"redis_password"`
RedisDB int `yaml:"redis_db"`
Queue QueueConfig `yaml:"queue"`
KnownHosts string `yaml:"known_hosts"`
WorkerID string `yaml:"worker_id"`
MaxWorkers int `yaml:"max_workers"`
PollInterval int `yaml:"poll_interval_seconds"`
Resources config.ResourceConfig `yaml:"resources"`
LocalMode bool `yaml:"local_mode"`
// Authentication
Auth auth.Config `yaml:"auth"`
// Metrics exporter
Metrics MetricsConfig `yaml:"metrics"`
// Metrics buffering
MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`
// Data management
DataManagerPath string `yaml:"data_manager_path"`
AutoFetchData bool `yaml:"auto_fetch_data"`
DataDir string `yaml:"data_dir"`
DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`
SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"`
// Provenance enforcement
// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
ProvenanceBestEffort bool `yaml:"provenance_best_effort"`
// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
PrewarmEnabled bool `yaml:"prewarm_enabled"`
// Podman execution
PodmanImage string `yaml:"podman_image"`
ContainerWorkspace string `yaml:"container_workspace"`
ContainerResults string `yaml:"container_results"`
GPUDevices []string `yaml:"gpu_devices"`
GPUVendor string `yaml:"gpu_vendor"`
GPUVisibleDevices []int `yaml:"gpu_visible_devices"`
GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"`
// Apple M-series GPU configuration
AppleGPU AppleGPUConfig `yaml:"apple_gpu"`
// Task lease and retry settings
TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // Renew lease (default: 1min)
MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min)
// Plugins configuration
Plugins map[string]factory.PluginConfig `yaml:"plugins"`
}
// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
Enabled bool `yaml:"enabled"`
ListenAddr string `yaml:"listen_addr"`
}
type SnapshotStoreConfig struct {
Enabled bool `yaml:"enabled"`
Endpoint string `yaml:"endpoint"`
Secure bool `yaml:"secure"`
Region string `yaml:"region"`
Bucket string `yaml:"bucket"`
Prefix string `yaml:"prefix"`
AccessKey string `yaml:"access_key"`
SecretKey string `yaml:"secret_key"`
SessionToken string `yaml:"session_token"`
Timeout time.Duration `yaml:"timeout"`
MaxRetries int `yaml:"max_retries"`
}
// AppleGPUConfig holds configuration for Apple M-series GPU support
type AppleGPUConfig struct {
Enabled bool `yaml:"enabled"`
MetalDevice string `yaml:"metal_device"`
MPSRuntime string `yaml:"mps_runtime"`
}
// LoadConfig loads worker configuration from a YAML file.
func LoadConfig(path string) (*Config, error) {
data, err := fileutil.SecureFileRead(path)
if err != nil {
return nil, err
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
if strings.TrimSpace(cfg.RedisURL) != "" {
cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL))
cfg.RedisAddr = cfg.RedisURL
cfg.RedisPassword = ""
cfg.RedisDB = 0
}
// Get smart defaults for current environment
smart := config.GetSmartDefaults()
if cfg.Port == 0 {
cfg.Port = config.DefaultSSHPort
}
if cfg.Host == "" {
cfg.Host = smart.Host()
}
if cfg.BasePath == "" {
cfg.BasePath = smart.BasePath()
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = smart.RedisAddr()
}
if cfg.KnownHosts == "" {
cfg.KnownHosts = smart.KnownHostsPath()
}
if cfg.WorkerID == "" {
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
}
cfg.Resources.ApplyDefaults()
if cfg.MaxWorkers > 0 {
cfg.Resources.MaxWorkers = cfg.MaxWorkers
} else {
cfg.MaxWorkers = cfg.Resources.MaxWorkers
}
if cfg.PollInterval == 0 {
cfg.PollInterval = smart.PollInterval()
}
if cfg.DataManagerPath == "" {
cfg.DataManagerPath = "./data_manager"
}
if cfg.DataDir == "" {
if cfg.Host == "" || !cfg.AutoFetchData {
cfg.DataDir = config.DefaultLocalDataDir
} else {
cfg.DataDir = smart.DataDir()
}
}
if cfg.SnapshotStore.Timeout == 0 {
cfg.SnapshotStore.Timeout = 10 * time.Minute
}
if cfg.SnapshotStore.MaxRetries == 0 {
cfg.SnapshotStore.MaxRetries = 3
}
if cfg.Metrics.ListenAddr == "" {
cfg.Metrics.ListenAddr = ":9100"
}
if cfg.MetricsFlushInterval == 0 {
cfg.MetricsFlushInterval = defaultMetricsFlushInterval
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
if strings.TrimSpace(cfg.Queue.Backend) == "" {
cfg.Queue.Backend = string(queue.QueueBackendRedis)
}
if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) {
if strings.TrimSpace(cfg.Queue.SQLitePath) == "" {
cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db")
}
cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath)
}
if strings.TrimSpace(cfg.GPUVendor) == "" {
if cfg.AppleGPU.Enabled {
cfg.GPUVendor = string(GPUTypeApple)
} else if len(cfg.GPUDevices) > 0 ||
len(cfg.GPUVisibleDevices) > 0 ||
len(cfg.GPUVisibleDeviceIDs) > 0 {
cfg.GPUVendor = string(GPUTypeNVIDIA)
} else {
cfg.GPUVendor = string(GPUTypeNone)
}
}
// Set lease and retry defaults
if cfg.TaskLeaseDuration == 0 {
cfg.TaskLeaseDuration = 30 * time.Minute
}
if cfg.HeartbeatInterval == 0 {
cfg.HeartbeatInterval = 1 * time.Minute
}
if cfg.MaxRetries == 0 {
cfg.MaxRetries = 3
}
if cfg.GracefulTimeout == 0 {
cfg.GracefulTimeout = 5 * time.Minute
}
return &cfg, nil
}
// Validate implements config.Validator interface.
func (c *Config) Validate() error {
if c.Port != 0 {
if err := config.ValidatePort(c.Port); err != nil {
return fmt.Errorf("invalid SSH port: %w", err)
}
}
if c.BasePath != "" {
// Convert relative paths to absolute
c.BasePath = config.ExpandPath(c.BasePath)
if !filepath.IsAbs(c.BasePath) {
c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
}
}
backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
if backend == "" {
backend = string(queue.QueueBackendRedis)
c.Queue.Backend = backend
}
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) {
return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite)
}
if backend == string(queue.QueueBackendSQLite) {
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite)
}
c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath)
if !filepath.IsAbs(c.Queue.SQLitePath) {
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
}
}
if c.RedisAddr != "" {
addr := strings.TrimSpace(c.RedisAddr)
if strings.HasPrefix(addr, "redis://") {
u, err := url.Parse(addr)
if err != nil {
return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err)
}
if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("invalid Redis configuration: invalid redis url")
}
} else {
if err := config.ValidateRedisAddr(addr); err != nil {
return fmt.Errorf("invalid Redis configuration: %w", err)
}
}
}
if c.MaxWorkers < 1 {
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
}
switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) {
case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd":
// ok
default:
return fmt.Errorf(
"gpu_vendor must be one of %q, %q, %q, %q",
string(GPUTypeNVIDIA),
"amd",
string(GPUTypeApple),
string(GPUTypeNone),
)
}
// Strict GPU visibility configuration:
// - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive.
// - UUID-style gpu_visible_device_ids is NVIDIA-only.
vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor))
if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 {
return fmt.Errorf("gpu_visible_devices and gpu_visible_device_ids are mutually exclusive")
}
if len(c.GPUVisibleDeviceIDs) > 0 {
if vendor != string(GPUTypeNVIDIA) {
return fmt.Errorf(
"gpu_visible_device_ids is only supported when gpu_vendor is %q",
string(GPUTypeNVIDIA),
)
}
for _, id := range c.GPUVisibleDeviceIDs {
id = strings.TrimSpace(id)
if id == "" {
return fmt.Errorf("gpu_visible_device_ids contains an empty value")
}
if !strings.HasPrefix(id, "GPU-") {
return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id)
}
}
}
if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) {
if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 {
return fmt.Errorf(
"gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q",
vendor,
)
}
}
if vendor == "amd" {
if len(c.GPUVisibleDeviceIDs) > 0 {
return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor)
}
for _, idx := range c.GPUVisibleDevices {
if idx < 0 {
return fmt.Errorf("gpu_visible_devices contains negative index %d", idx)
}
}
}
if c.SnapshotStore.Enabled {
if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" {
return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true")
}
if strings.TrimSpace(c.SnapshotStore.Bucket) == "" {
return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true")
}
ak := strings.TrimSpace(c.SnapshotStore.AccessKey)
sk := strings.TrimSpace(c.SnapshotStore.SecretKey)
if (ak == "") != (sk == "") {
return fmt.Errorf(
"snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty",
)
}
if c.SnapshotStore.Timeout < 0 {
return fmt.Errorf("snapshot_store.timeout must be >= 0")
}
if c.SnapshotStore.MaxRetries < 0 {
return fmt.Errorf("snapshot_store.max_retries must be >= 0")
}
}
return nil
}
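For orientation, an illustrative worker.yaml exercising the new fields (a sketch: every value is made up, duration fields are omitted since the loader applies defaults, and the backend/vendor strings assume queue.QueueBackendSQLite and GPUTypeNVIDIA serialize as "sqlite" and "nvidia"):

worker_id: worker-a1b2c3d4
max_workers: 4
queue:
  backend: sqlite
  sqlite_path: ~/fetchml/queue.db
snapshot_store:
  enabled: true
  endpoint: minio.internal:9000
  bucket: fetchml-snapshots
  prefix: snapshots/
prewarm_enabled: true            # Phase 1: best-effort prewarm of the next queued task
provenance_best_effort: false    # fail closed on provenance mismatch (the default)
gpu_vendor: nvidia
gpu_visible_devices: [0, 1]      # mutually exclusive with gpu_visible_device_ids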

internal/worker/core.go (new file, 547 lines)

@@ -0,0 +1,547 @@
package worker
import (
"context"
"fmt"
"log"
"log/slog"
"math"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/envpool"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
*network.SSHClient
}
// JupyterManager is the subset of the Jupyter service manager used by the worker.
// It exists to keep task execution testable.
type JupyterManager interface {
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
StopService(ctx context.Context, serviceID string) error
RemoveService(ctx context.Context, serviceID string, purge bool) error
RestoreWorkspace(ctx context.Context, name string) (string, error)
ListServices() []*jupyter.JupyterService
}
// isValidName performs a basic sanity check: the input must be non-empty and
// shorter than 256 bytes. It does not restrict the character set; callers
// pass the value to exec.Command argv directly, with no shell involved.
func isValidName(input string) bool {
return len(input) > 0 && len(input) < 256
}
// NewMLServer creates a new ML server connection (local in LocalMode, SSH otherwise).
func NewMLServer(cfg *Config) (*MLServer, error) {
if cfg.LocalMode {
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
}
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
if err != nil {
return nil, err
}
return &MLServer{SSHClient: client}, nil
}
// Worker represents an ML task worker.
type Worker struct {
id string
config *Config
server *MLServer
queue queue.Backend
resources *resources.Manager
running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
runningMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
metrics *metrics.Metrics
metricsSrv *http.Server
datasetCache map[string]time.Time
datasetCacheMu sync.RWMutex
datasetCacheTTL time.Duration
// Graceful shutdown fields
shutdownCh chan struct{}
activeTasks sync.Map // map[string]*queue.Task - track active tasks
gracefulWait sync.WaitGroup
podman *container.PodmanManager
jupyter JupyterManager
trackingRegistry *tracking.Registry
envPool *envpool.Pool
prewarmMu sync.Mutex
prewarmTargetID string
prewarmCancel context.CancelFunc
prewarmStartedAt time.Time
}
func envInt(name string) (int, bool) {
v := strings.TrimSpace(os.Getenv(name))
if v == "" {
return 0, false
}
n, err := strconv.Atoi(v)
if err != nil {
return 0, false
}
return n, true
}
func parseCPUFromConfig(cfg *Config) int {
if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
return n
}
if cfg != nil {
if cfg.Resources.PodmanCPUs != "" {
if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil {
if f < 0 {
return 0
}
return int(math.Floor(f))
}
}
}
return runtime.NumCPU()
}
func parseGPUCountFromConfig(cfg *Config) int {
factory := &GPUDetectorFactory{}
detector := factory.CreateDetector(cfg)
return detector.DetectGPUCount()
}
func (w *Worker) getGPUDetector() GPUDetector {
factory := &GPUDetectorFactory{}
return factory.CreateDetector(w.config)
}
func parseGPUSlotsPerGPUFromConfig() int {
if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 {
return n
}
return 1
}
func (w *Worker) setupMetricsExporter() {
if !w.config.Metrics.Enabled {
return
}
reg := prometheus.NewRegistry()
reg.MustRegister(
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
collectors.NewGoCollector(),
)
labels := prometheus.Labels{"worker_id": w.id}
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_processed_total",
Help: "Total tasks processed successfully by this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.TasksProcessed.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_failed_total",
Help: "Total tasks failed by this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.TasksFailed.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_active",
Help: "Number of tasks currently running on this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.runningCount())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_queued",
Help: "Latest observed queue depth from Redis.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.QueuedTasks.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_data_transferred_bytes_total",
Help: "Total bytes transferred while fetching datasets.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.DataTransferred.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_data_fetch_time_seconds_total",
Help: "Total time spent fetching datasets (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_execution_time_seconds_total",
Help: "Total execution time for completed tasks (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_env_hit_total",
Help: "Total environment prewarm hits (warmed image already existed).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmEnvHit.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_env_miss_total",
Help: "Total environment prewarm misses (warmed image did not exist yet).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmEnvMiss.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_env_built_total",
Help: "Total environment prewarm images built.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmEnvBuilt.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_env_time_seconds_total",
Help: "Total time spent building prewarm images (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_snapshot_hit_total",
Help: "Total prewarmed snapshot hits (snapshots found in .prewarm/).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmSnapshotHit.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_snapshot_miss_total",
Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmSnapshotMiss.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_snapshot_built_total",
Help: "Total snapshots prewarmed into .prewarm/.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmSnapshotBuilt.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_prewarm_snapshot_time_seconds_total",
Help: "Total time spent prewarming snapshots (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_worker_max_concurrency",
Help: "Configured maximum concurrent tasks for this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.config.MaxWorkers)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_cpu_total",
Help: "Total CPU tokens managed by the worker resource manager.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().TotalCPU)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_cpu_free",
Help: "Free CPU tokens currently available in the worker resource manager.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().FreeCPU)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_total",
Help: "Total resource acquisition attempts.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().AcquireTotal)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_wait_total",
Help: "Total resource acquisitions that had to wait for resources.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().AcquireWaitTotal)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_timeout_total",
Help: "Total resource acquisition attempts that timed out.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().AcquireTimeoutTotal)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_wait_seconds_total",
Help: "Total seconds spent waiting for resources across all acquisitions.",
ConstLabels: labels,
}, func() float64 {
return w.resources.Snapshot().AcquireWaitSeconds
}))
snap := w.resources.Snapshot()
for i := range snap.GPUFree {
gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)}
idx := i
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_gpu_slots_total",
Help: "Total GPU slots per GPU index.",
ConstLabels: gpuLabels,
}, func() float64 {
return float64(w.resources.Snapshot().SlotsPerGPU)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_gpu_slots_free",
Help: "Free GPU slots per GPU index.",
ConstLabels: gpuLabels,
}, func() float64 {
s := w.resources.Snapshot()
if idx < 0 || idx >= len(s.GPUFree) {
return 0
}
return float64(s.GPUFree[idx])
}))
}
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
srv := &http.Server{
Addr: w.config.Metrics.ListenAddr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
w.metricsSrv = srv
go func() {
w.logger.Info("metrics exporter listening",
"addr", w.config.Metrics.ListenAddr,
"enabled", w.config.Metrics.Enabled)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
w.logger.Warn("metrics exporter stopped",
"error", err)
}
}()
}
// NewWorker creates a new worker instance.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
srv, err := NewMLServer(cfg)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
if closeErr := srv.Close(); closeErr != nil {
log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
}
}
}()
backendCfg := queue.BackendConfig{
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
SQLitePath: cfg.Queue.SQLitePath,
MetricsFlushInterval: cfg.MetricsFlushInterval,
}
queueClient, err := queue.NewBackend(backendCfg)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
if closeErr := queueClient.Close(); closeErr != nil {
log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
}
}
}()
// Create data_dir if it doesn't exist (for production without NAS)
if cfg.DataDir != "" {
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
}
}
ctx, cancel := context.WithCancel(context.Background())
defer func() {
if err != nil {
cancel()
}
}()
ctx = logging.EnsureTrace(ctx)
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
baseLogger := logging.NewLogger(slog.LevelInfo, false)
logger := baseLogger.Component(ctx, "worker")
metricsObj := &metrics.Metrics{}
podmanMgr, err := container.NewPodmanManager(logger)
if err != nil {
return nil, fmt.Errorf("failed to create podman manager: %w", err)
}
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
if err != nil {
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
}
trackingRegistry := tracking.NewRegistry(logger)
pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
if len(cfg.Plugins) == 0 {
logger.Warn("no plugins configured, defining defaults")
// Register defaults manually for backward compatibility/local dev
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
logger,
podmanMgr,
trackingplugins.MLflowOptions{
ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
},
)
if err == nil {
trackingRegistry.Register(mlflowPlugin)
}
tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
logger,
podmanMgr,
trackingplugins.TensorBoardOptions{
LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
},
)
if err == nil {
trackingRegistry.Register(tensorboardPlugin)
}
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
} else {
// Assign the outer err so the deferred cleanups above run on failure.
if err = pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
return nil, fmt.Errorf("failed to load plugins: %w", err)
}
}
worker := &Worker{
id: cfg.WorkerID,
config: cfg,
server: srv,
queue: queueClient,
running: make(map[string]context.CancelFunc),
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
ctx: ctx,
cancel: cancel,
logger: logger,
metrics: metricsObj,
shutdownCh: make(chan struct{}),
podman: podmanMgr,
jupyter: jupyterMgr,
trackingRegistry: trackingRegistry,
envPool: envpool.New(""),
}
rm, rmErr := resources.NewManager(resources.Options{
TotalCPU: parseCPUFromConfig(cfg),
GPUCount: parseGPUCountFromConfig(cfg),
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
})
if rmErr != nil {
// Route the failure through err so the deferred cleanups close srv and the queue.
err = fmt.Errorf("failed to init resource manager: %w", rmErr)
return nil, err
}
worker.resources = rm
if !cfg.LocalMode {
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
if cfg.AppleGPU.Enabled {
logger.Warn("apple MPS GPU mode is intended for development; do not use in production",
"gpu_type", "apple",
)
}
if gpuType == "amd" {
logger.Warn("amd GPU mode is intended for development; do not use in production",
"gpu_type", "amd",
)
}
}
worker.setupMetricsExporter()
// Pre-pull tracking images in background
go worker.prePullImages()
return worker, nil
}
func (w *Worker) prePullImages() {
if w.config.LocalMode {
return
}
w.logger.Info("starting image pre-pulling")
// Pull worker image
if w.config.PodmanImage != "" {
w.pullImage(w.config.PodmanImage)
}
// Pull plugin images
for name, cfg := range w.config.Plugins {
if !cfg.Enabled || cfg.Image == "" {
continue
}
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
w.pullImage(cfg.Image)
}
}
func (w *Worker) pullImage(image string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "podman", "pull", image)
if output, err := cmd.CombinedOutput(); err != nil {
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
} else {
w.logger.Info("image pulled successfully", "image", image)
}
}
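Tying it together, a minimal boot sketch (assumptions: error handling is abbreviated, the second NewWorker argument is unused per the signature above, and the worker's run loop lives outside this diff):

package main
import (
"log"
"github.com/jfraeys/fetch_ml/internal/worker"
)
func main() {
cfg, err := worker.LoadConfig("worker.yaml")
if err != nil {
log.Fatal(err)
}
if err := cfg.Validate(); err != nil {
log.Fatal(err)
}
w, err := worker.NewWorker(cfg, "")
if err != nil {
log.Fatal(err)
}
_ = w // hand off to the worker's poll/run loop (not shown in this diff)
}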


@@ -0,0 +1,824 @@
package worker
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/errtypes"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// fetchDatasets fetches the task's datasets via the external data_manager binary.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
logger := w.logger.Job(ctx, task.JobName, task.ID)
logger.Info("fetching datasets",
"worker_id", w.id,
"dataset_count", len(task.Datasets))
for _, dataset := range task.Datasets {
if w.datasetIsFresh(dataset) {
logger.Debug("skipping cached dataset",
"dataset", dataset)
continue
}
// Check for cancellation before each dataset fetch
select {
case <-ctx.Done():
return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
default:
}
logger.Info("fetching dataset",
"worker_id", w.id,
"dataset", dataset)
// Create command with context for cancellation support
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
// Validate inputs to prevent command injection
if !isValidName(task.JobName) || !isValidName(dataset) {
cancel()
return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
}
//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
cmd := exec.CommandContext(cmdCtx,
w.config.DataManagerPath,
"fetch",
task.JobName,
dataset,
)
output, err := cmd.CombinedOutput()
cancel() // Clean up context
if err != nil {
return &errtypes.DataFetchError{
Dataset: dataset,
JobName: task.JobName,
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
}
}
logger.Info("dataset ready",
"worker_id", w.id,
"dataset", dataset)
w.markDatasetFetched(dataset)
}
return nil
}
func resolveDatasets(task *queue.Task) []string {
if task == nil {
return nil
}
if len(task.DatasetSpecs) > 0 {
out := make([]string, 0, len(task.DatasetSpecs))
for _, ds := range task.DatasetSpecs {
if ds.Name != "" {
out = append(out, ds.Name)
}
}
if len(out) > 0 {
return out
}
}
if len(task.Datasets) > 0 {
return task.Datasets
}
return parseDatasets(task.Args)
}
func parseDatasets(args string) []string {
if !strings.Contains(args, "--datasets") {
return nil
}
parts := strings.Fields(args)
for i, part := range parts {
if part == "--datasets" && i+1 < len(parts) {
return strings.Split(parts[i+1], ",")
}
}
return nil
}
func (w *Worker) datasetIsFresh(dataset string) bool {
w.datasetCacheMu.RLock()
defer w.datasetCacheMu.RUnlock()
expires, ok := w.datasetCache[dataset]
return ok && time.Now().Before(expires)
}
func (w *Worker) markDatasetFetched(dataset string) {
expires := time.Now().Add(w.datasetCacheTTL)
w.datasetCacheMu.Lock()
w.datasetCache[dataset] = expires
w.datasetCacheMu.Unlock()
}
func (w *Worker) cancelPrewarmLocked() {
if w.prewarmCancel != nil {
w.prewarmCancel()
w.prewarmCancel = nil
}
w.prewarmTargetID = ""
}
func (w *Worker) prewarmNextLoop() {
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
return
}
if w.ctx == nil || w.queue == nil || w.metrics == nil {
return
}
// Phase 1: Best-effort prewarm of the next queued task.
// This must never be required for correctness.
runOnce := func() {
_, err := w.PrewarmNextOnce(w.ctx)
if err != nil {
w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
}
}
// Run once immediately so prewarm doesn't lag behind the worker loop.
runOnce()
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
w.prewarmMu.Lock()
w.cancelPrewarmLocked()
w.prewarmMu.Unlock()
return
case <-ticker.C:
}
runOnce()
}
}
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
return false, nil
}
if ctx == nil || w.queue == nil || w.metrics == nil {
return false, nil
}
next, err := w.queue.PeekNextTask()
if err != nil {
return false, err
}
if next == nil {
w.prewarmMu.Lock()
w.cancelPrewarmLocked()
w.prewarmMu.Unlock()
return false, nil
}
return w.prewarmTaskOnce(ctx, next)
}
func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
return false, nil
}
if ctx == nil || w.queue == nil || w.metrics == nil {
return false, nil
}
if next == nil {
return false, nil
}
w.prewarmMu.Lock()
if w.prewarmTargetID == next.ID {
w.prewarmMu.Unlock()
return false, nil
}
w.cancelPrewarmLocked()
prewarmCtx, cancel := context.WithCancel(ctx)
w.prewarmCancel = cancel
w.prewarmTargetID = next.ID
w.prewarmStartedAt = time.Now()
startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)
phase := "datasets"
dsCnt := len(resolveDatasets(next))
snapID := next.SnapshotID
if strings.TrimSpace(snapID) != "" {
phase = "snapshot"
} else if dsCnt == 0 {
phase = "env"
}
_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
WorkerID: w.id,
TaskID: next.ID,
SnapshotID: snapID,
StartedAt: startedAt,
UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Phase: phase,
DatasetCnt: dsCnt,
EnvHit: w.metrics.PrewarmEnvHit.Load(),
EnvMiss: w.metrics.PrewarmEnvMiss.Load(),
EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(),
EnvTimeNs: w.metrics.PrewarmEnvTime.Load(),
})
w.prewarmMu.Unlock()
w.logger.Info("prewarm started",
"worker_id", w.id,
"task_id", next.ID,
"snapshot_id", snapID,
"phase", phase,
)
local := *next
local.Datasets = resolveDatasets(&local)
hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
hasEnv := false
if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
if local.Metadata != nil {
depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
commitID := strings.TrimSpace(local.Metadata["commit_id"])
if depsSHA != "" && commitID != "" {
expMgr := experiment.NewManager(w.config.BasePath)
hostWorkspace := expMgr.GetFilesPath(commitID)
if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
hasEnv = true
}
}
}
}
}
if !hasSnapshot && !hasDatasets && !hasEnv {
_ = w.queue.ClearWorkerPrewarmState(w.id)
return false, nil
}
if hasSnapshot {
want := ""
if local.Metadata != nil {
want = local.Metadata["snapshot_sha256"]
}
start := time.Now()
src, err := ResolveSnapshot(
prewarmCtx,
w.config.DataDir,
&w.config.SnapshotStore,
local.SnapshotID,
want,
nil,
)
if err != nil {
return true, err
}
dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
_ = os.RemoveAll(dst)
if err := copyDir(src, dst); err != nil {
return true, err
}
w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
}
if hasDatasets {
if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
return true, err
}
}
_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
WorkerID: w.id,
TaskID: local.ID,
SnapshotID: local.SnapshotID,
StartedAt: startedAt,
UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Phase: "ready",
DatasetCnt: len(local.Datasets),
EnvHit: w.metrics.PrewarmEnvHit.Load(),
EnvMiss: w.metrics.PrewarmEnvMiss.Load(),
EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(),
EnvTimeNs: w.metrics.PrewarmEnvTime.Load(),
})
w.logger.Info("prewarm ready",
"worker_id", w.id,
"task_id", local.ID,
"snapshot_id", local.SnapshotID,
)
return true, nil
}
func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if task.SnapshotID == "" {
return nil
}
if err := container.ValidateJobName(task.SnapshotID); err != nil {
return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err)
}
if task.Metadata == nil {
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
}
want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
if err != nil {
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err)
}
if want == "" {
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
}
path, err := ResolveSnapshot(
ctx,
w.config.DataDir,
&w.config.SnapshotStore,
task.SnapshotID,
want,
nil,
)
if err != nil {
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
}
got, err := dirOverallSHA256Hex(path)
if err != nil {
return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err)
}
if got != want {
return fmt.Errorf(
"snapshot %q: checksum mismatch: expected %s, got %s",
task.SnapshotID,
want,
got,
)
}
w.logger.Job(
ctx,
task.JobName,
task.ID,
).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID)
return nil
}
func fileSHA256Hex(path string) (string, error) {
f, err := os.Open(filepath.Clean(path))
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return fmt.Sprintf("%x", h.Sum(nil)), nil
}
func normalizeSHA256ChecksumHex(checksum string) (string, error) {
checksum = strings.TrimSpace(checksum)
checksum = strings.TrimPrefix(checksum, "sha256:")
checksum = strings.TrimPrefix(checksum, "SHA256:")
checksum = strings.TrimSpace(checksum)
if checksum == "" {
return "", nil
}
if len(checksum) != 64 {
return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
}
if _, err := hex.DecodeString(checksum); err != nil {
return "", fmt.Errorf("invalid sha256 hex: %w", err)
}
return strings.ToLower(checksum), nil
}
func dirOverallSHA256Hex(root string) (string, error) {
root = filepath.Clean(root)
info, err := os.Stat(root)
if err != nil {
return "", err
}
if !info.IsDir() {
return "", fmt.Errorf("not a directory")
}
var files []string
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
files = append(files, rel)
return nil
})
if err != nil {
return "", err
}
// Deterministic order.
sort.Strings(files)
// Hash file hashes to avoid holding all bytes.
overall := sha256.New()
for _, rel := range files {
p := filepath.Join(root, rel)
sum, err := fileSHA256Hex(p)
if err != nil {
return "", err
}
overall.Write([]byte(sum))
}
return fmt.Sprintf("%x", overall.Sum(nil)), nil
}
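// Producer-side sketch (hypothetical helper, same package, not part of this
// diff): stamp the value verified above into task metadata before enqueueing,
// assuming the snapshot already sits under <data_dir>/snapshots/<id>.
func stampSnapshotChecksum(dataDir string, task *queue.Task) error {
dir := filepath.Join(dataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
sum, err := dirOverallSHA256Hex(dir) // per-file sha256 hexes, hashed in sorted path order
if err != nil {
return err
}
if task.Metadata == nil {
task.Metadata = map[string]string{}
}
// normalizeSHA256ChecksumHex strips the "sha256:" prefix on the verify side.
task.Metadata["snapshot_sha256"] = "sha256:" + sum
return nil
}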
func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if len(task.DatasetSpecs) == 0 {
return nil
}
logger := w.logger.Job(ctx, task.JobName, task.ID)
for _, ds := range task.DatasetSpecs {
want, err := normalizeSHA256ChecksumHex(ds.Checksum)
if err != nil {
return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
}
if want == "" {
continue
}
if err := container.ValidateJobName(ds.Name); err != nil {
return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
}
path := filepath.Join(w.config.DataDir, ds.Name)
got, err := dirOverallSHA256Hex(path)
if err != nil {
return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
}
if got != want {
return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
}
logger.Info("dataset checksum verified", "dataset", ds.Name)
}
return nil
}
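// For reference, the shape of a task that passes verifyDatasetSpecs (a sketch,
// not part of this diff; queue.DatasetSpec is assumed to expose at least the
// Name and Checksum fields used above, and the names/digests are made up).
func exampleVerifiableTask() *queue.Task {
return &queue.Task{
JobName: "train-resnet",
DatasetSpecs: []queue.DatasetSpec{
// Verified against <data_dir>/cifar10 by dirOverallSHA256Hex.
{Name: "cifar10", Checksum: "sha256:" + strings.Repeat("0", 64)},
// Empty checksum: verification is skipped for this dataset.
{Name: "imagenet-val"},
},
}
}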
func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
if task == nil {
return nil, fmt.Errorf("task is nil")
}
out := map[string]string{}
if task.SnapshotID != "" {
out["snapshot_id"] = task.SnapshotID
}
datasets := resolveDatasets(task)
if len(datasets) > 0 {
out["datasets"] = strings.Join(datasets, ",")
}
if len(task.DatasetSpecs) > 0 {
b, err := json.Marshal(task.DatasetSpecs)
if err != nil {
return nil, fmt.Errorf("marshal dataset_specs: %w", err)
}
out["dataset_specs"] = string(b)
}
if task.Metadata == nil {
return out, nil
}
commitID := task.Metadata["commit_id"]
if commitID == "" {
return out, nil
}
expMgr := experiment.NewManager(basePath)
manifest, err := expMgr.ReadManifest(commitID)
if err == nil && manifest != nil && manifest.OverallSHA != "" {
out["experiment_manifest_overall_sha"] = manifest.OverallSHA
}
filesPath := expMgr.GetFilesPath(commitID)
depName, err := selectDependencyManifest(filesPath)
if err == nil && depName != "" {
depPath := filepath.Join(filesPath, depName)
sha, err := fileSHA256Hex(depPath)
if err == nil && sha != "" {
out["deps_manifest_name"] = depName
out["deps_manifest_sha256"] = sha
}
}
return out, nil
}
func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
if task == nil {
return
}
prov, err := computeTaskProvenance(w.config.BasePath, task)
if err != nil {
w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
return
}
if len(prov) == 0 {
return
}
if task.Metadata == nil {
task.Metadata = map[string]string{}
}
for k, v := range prov {
if v == "" {
continue
}
// Phase 1: best-effort only; do not error if overwriting.
task.Metadata[k] = v
}
}
func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if task.Metadata == nil {
return fmt.Errorf("missing task metadata")
}
commitID := task.Metadata["commit_id"]
if commitID == "" {
return fmt.Errorf("missing commit_id")
}
current, err := computeTaskProvenance(w.config.BasePath, task)
if err != nil {
return err
}
snapshotCur := ""
if task.SnapshotID != "" {
want := ""
if task.Metadata != nil {
want = task.Metadata["snapshot_sha256"]
}
wantNorm, nerr := normalizeSHA256ChecksumHex(want)
if nerr != nil {
if w.config != nil && w.config.ProvenanceBestEffort {
w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
"snapshot_id", task.SnapshotID,
"error", nerr)
} else {
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
}
} else if wantNorm != "" {
resolved, err := ResolveSnapshot(
ctx, w.config.DataDir,
&w.config.SnapshotStore,
task.SnapshotID,
wantNorm,
nil,
)
if err != nil {
if w.config != nil && w.config.ProvenanceBestEffort {
w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
"snapshot_id", task.SnapshotID,
"error", err)
} else {
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
}
} else {
sha, err := dirOverallSHA256Hex(resolved)
if err == nil {
snapshotCur = sha
} else if w.config != nil && w.config.ProvenanceBestEffort {
w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
"snapshot_id", task.SnapshotID,
"error", err)
} else {
return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
}
}
}
if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
// Best-effort fallback: if the caller didn't provide snapshot_sha256,
// compute from the local snapshot directory if it exists.
localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
if sha, err := dirOverallSHA256Hex(localPath); err == nil {
snapshotCur = sha
}
}
}
logger := w.logger.Job(ctx, task.JobName, task.ID)
type requiredField struct {
Key string
Cur string
}
required := []requiredField{
{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
}
if task.SnapshotID != "" {
required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
}
for _, f := range required {
want := strings.TrimSpace(task.Metadata[f.Key])
if f.Key == "snapshot_sha256" {
norm, nerr := normalizeSHA256ChecksumHex(want)
if nerr != nil {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
"snapshot_id", task.SnapshotID,
"error", nerr)
want = ""
} else {
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
}
} else {
want = norm
}
}
if want == "" {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("missing provenance field; continuing due to best-effort mode",
"field", f.Key)
if f.Cur != "" {
if f.Key == "snapshot_sha256" {
task.Metadata[f.Key] = "sha256:" + f.Cur
} else {
task.Metadata[f.Key] = f.Cur
}
}
continue
}
return fmt.Errorf("missing provenance field: %s", f.Key)
}
if f.Cur == "" {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("unable to compute provenance field; continuing due to best-effort mode",
"field", f.Key)
continue
}
return fmt.Errorf("unable to compute provenance field: %s", f.Key)
}
if want != f.Cur {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("provenance mismatch; continuing due to best-effort mode",
"field", f.Key,
"expected", want,
"current", f.Cur)
if f.Key == "snapshot_sha256" {
task.Metadata[f.Key] = "sha256:" + f.Cur
} else {
task.Metadata[f.Key] = f.Cur
}
continue
}
return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
}
}
return nil
}
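// exampleStrictProvenanceMetadata is a hedged sketch of the task metadata
// that strict (non best-effort) enforcement expects; every key appears in
// enforceTaskProvenance above, and all values are hypothetical.
func exampleStrictProvenanceMetadata() map[string]string {
return map[string]string{
"commit_id":                       "abc123",
"snapshot_sha256":                 "sha256:" + strings.Repeat("ab", 32), // required only when SnapshotID is set
"experiment_manifest_overall_sha": strings.Repeat("ef", 32),
"deps_manifest_name":              "environment.yml",
"deps_manifest_sha256":            strings.Repeat("01", 32),
}
}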
func selectDependencyManifest(filesPath string) (string, error) {
if filesPath == "" {
return "", fmt.Errorf("missing files path")
}
candidates := []string{
"environment.yml",
"environment.yaml",
"poetry.lock",
"pyproject.toml",
"requirements.txt",
}
for _, name := range candidates {
p := filepath.Join(filesPath, name)
if _, err := os.Stat(p); err == nil {
if name == "poetry.lock" {
pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
if _, err := os.Stat(pyprojectPath); err != nil {
return "", fmt.Errorf(
"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
}
}
return name, nil
}
}
return "", fmt.Errorf(
"missing dependency manifest (supported: environment.yml, environment.yaml, " +
"poetry.lock, pyproject.toml, requirements.txt)")
}
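// exampleSelectManifest is a hedged usage sketch: given an experiment's
// files directory (hypothetical path), selectDependencyManifest returns the
// highest-priority manifest present, preferring conda environments over
// Poetry and plain requirements files.
func exampleSelectManifest() {
name, err := selectDependencyManifest("/srv/fetchml/experiments/abc123/files")
if err != nil {
fmt.Println("no manifest:", err)
return
}
fmt.Println("selected:", name)
}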
// Exported wrappers for tests under tests/.
func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }
func SelectDependencyManifest(filesPath string) (string, error) {
return selectDependencyManifest(filesPath)
}
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
return normalizeSHA256ChecksumHex(checksum)
}
func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
return computeTaskProvenance(basePath, task)
}
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
return w.enforceTaskProvenance(ctx, task)
}
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
return w.verifyDatasetSpecs(ctx, task)
}
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
return w.verifySnapshot(ctx, task)
}
func NewTestWorker(cfg *Config) *Worker {
baseLogger := logging.NewLogger(slog.LevelInfo, false)
ctx := logging.EnsureTrace(context.Background())
logger := baseLogger.Component(ctx, "worker")
if cfg == nil {
cfg = &Config{}
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
return &Worker{
id: cfg.WorkerID,
config: cfg,
logger: logger,
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
}
}
func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
baseLogger := logging.NewLogger(slog.LevelInfo, false)
ctx := logging.EnsureTrace(context.Background())
ctx, cancel := context.WithCancel(ctx)
logger := baseLogger.Component(ctx, "worker")
if cfg == nil {
cfg = &Config{}
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
return &Worker{
id: cfg.WorkerID,
config: cfg,
logger: logger,
queue: tq,
metrics: &metrics.Metrics{},
ctx: ctx,
cancel: cancel,
running: make(map[string]context.CancelFunc),
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
}
}
func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
w := NewTestWorkerWithQueue(cfg, tq)
w.jupyter = jm
return w
}

internal/worker/execution.go (new file, 1029 lines; diff suppressed because it is too large)


@@ -0,0 +1,168 @@
package worker
import (
"os"
"path/filepath"
"strings"
)
// GPUType represents different GPU types
type GPUType string
const (
GPUTypeNVIDIA GPUType = "nvidia"
GPUTypeApple GPUType = "apple"
GPUTypeNone GPUType = "none"
)
// GPUDetector interface for detecting GPU availability
type GPUDetector interface {
DetectGPUCount() int
GetGPUType() GPUType
GetDevicePaths() []string
}
// NVIDIA GPUDetector implementation
type NVIDIADetector struct{}
func (d *NVIDIADetector) DetectGPUCount() int {
if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
return n
}
// Could use nvidia-smi or other detection methods here.
return 0
}
func (d *NVIDIADetector) GetGPUType() GPUType {
return GPUTypeNVIDIA
}
func (d *NVIDIADetector) GetDevicePaths() []string {
// Prefer standard NVIDIA device nodes when present.
patterns := []string{
"/dev/nvidiactl",
"/dev/nvidia-modeset",
"/dev/nvidia-uvm",
"/dev/nvidia-uvm-tools",
"/dev/nvidia*",
}
seen := make(map[string]struct{})
out := make([]string, 0, 8)
for _, pat := range patterns {
// Defensive skip for bare (relative) patterns; every entry above is absolute.
if filepath.Base(pat) == pat {
continue
}
if strings.Contains(pat, "*") {
matches, _ := filepath.Glob(pat)
for _, m := range matches {
if _, ok := seen[m]; ok {
continue
}
if _, err := os.Stat(m); err != nil {
continue
}
seen[m] = struct{}{}
out = append(out, m)
}
continue
}
if _, ok := seen[pat]; ok {
continue
}
if _, err := os.Stat(pat); err != nil {
continue
}
seen[pat] = struct{}{}
out = append(out, pat)
}
// Fallback for non-NVIDIA setups where only generic DRM device exists.
if len(out) == 0 {
if _, err := os.Stat("/dev/dri"); err == nil {
out = append(out, "/dev/dri")
}
}
return out
}
// Apple M-series GPUDetector implementation
type AppleDetector struct {
enabled bool
}
func (d *AppleDetector) DetectGPUCount() int {
if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
return n
}
if d.enabled {
return 1
}
return 0
}
func (d *AppleDetector) GetGPUType() GPUType {
return GPUTypeApple
}
func (d *AppleDetector) GetDevicePaths() []string {
return []string{"/dev/metal", "/dev/mps"}
}
// None GPUDetector implementation
type NoneDetector struct{}
func (d *NoneDetector) DetectGPUCount() int {
return 0
}
func (d *NoneDetector) GetGPUType() GPUType {
return GPUTypeNone
}
func (d *NoneDetector) GetDevicePaths() []string {
return nil
}
// GPUDetectorFactory creates appropriate GPU detector based on config
type GPUDetectorFactory struct{}
func (f *GPUDetectorFactory) CreateDetector(cfg *Config) GPUDetector {
// Check for explicit environment override
if gpuType := os.Getenv("FETCH_ML_GPU_TYPE"); gpuType != "" {
switch gpuType {
case string(GPUTypeNVIDIA):
return &NVIDIADetector{}
case string(GPUTypeApple):
return &AppleDetector{enabled: true}
case string(GPUTypeNone):
return &NoneDetector{}
}
}
// Respect configured vendor when explicitly set.
if cfg != nil {
switch GPUType(cfg.GPUVendor) {
case GPUTypeApple:
return &AppleDetector{enabled: cfg.AppleGPU.Enabled}
case GPUTypeNone:
return &NoneDetector{}
case GPUTypeNVIDIA:
return &NVIDIADetector{}
case "amd":
// AMD uses similar device exposure patterns in this codebase.
return &NVIDIADetector{}
}
}
// Auto-detect based on config
if cfg != nil {
if cfg.AppleGPU.Enabled {
return &AppleDetector{enabled: true}
}
if len(cfg.GPUDevices) > 0 {
return &NVIDIADetector{}
}
}
// Default to no GPU
return &NoneDetector{}
}
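// exampleDetectGPU is a hedged sketch of the factory flow above: the
// FETCH_ML_GPU_TYPE environment override wins, then the configured vendor,
// then config-based auto-detection, and finally NoneDetector.
func exampleDetectGPU(cfg *Config) (GPUType, int, []string) {
f := &GPUDetectorFactory{}
d := f.CreateDetector(cfg)
return d.GetGPUType(), d.DetectGPUCount(), d.GetDevicePaths()
}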


@@ -0,0 +1,130 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/queue"
)
const (
jupyterTaskTypeKey = "task_type"
jupyterTaskTypeValue = "jupyter"
jupyterTaskActionKey = "jupyter_action"
jupyterActionStart = "start"
jupyterActionStop = "stop"
jupyterActionRemove = "remove"
jupyterActionRestore = "restore"
jupyterActionList = "list"
jupyterNameKey = "jupyter_name"
jupyterWorkspaceKey = "jupyter_workspace"
jupyterServiceIDKey = "jupyter_service_id"
jupyterTaskOutputType = "jupyter_output"
)
type jupyterTaskOutput struct {
Type string `json:"type"`
Service *jupyter.JupyterService `json:"service,omitempty"`
Services []*jupyter.JupyterService `json:"services"`
RestorePath string `json:"restore_path,omitempty"`
}
func isJupyterTask(task *queue.Task) bool {
if task == nil || task.Metadata == nil {
return false
}
return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue
}
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
if w == nil {
return nil, fmt.Errorf("worker is nil")
}
if task == nil {
return nil, fmt.Errorf("task is nil")
}
if w.jupyter == nil {
return nil, fmt.Errorf("jupyter manager not configured")
}
if task.Metadata == nil {
return nil, fmt.Errorf("missing task metadata")
}
action := strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey]))
if action == "" {
return nil, fmt.Errorf("missing jupyter action")
}
// Validate job name since it is used as the task status key and shows up in logs.
if err := container.ValidateJobName(task.JobName); err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
switch action {
case jupyterActionStart:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
ws := strings.TrimSpace(task.Metadata[jupyterWorkspaceKey])
if name == "" {
return nil, fmt.Errorf("missing jupyter name")
}
if ws == "" {
return nil, fmt.Errorf("missing jupyter workspace")
}
service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
if err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Service: service}
return json.Marshal(out)
case jupyterActionStop:
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter service id")
}
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
return json.Marshal(out)
case jupyterActionRemove:
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter service id")
}
purge := strings.EqualFold(strings.TrimSpace(task.Metadata["jupyter_purge"]), "true")
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
return json.Marshal(out)
case jupyterActionList:
services := w.jupyter.ListServices()
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
return json.Marshal(out)
case jupyterActionRestore:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
if name == "" {
return nil, fmt.Errorf("missing jupyter name")
}
restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
if err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType, RestorePath: restoredPath}
return json.Marshal(out)
default:
return nil, fmt.Errorf("invalid jupyter action: %s", action)
}
}
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
return w.runJupyterTask(ctx, task)
}
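// exampleJupyterStartMetadata is a hedged sketch of the metadata a queued
// Jupyter "start" task carries; the keys reuse the constants above, and the
// name/workspace values are hypothetical.
func exampleJupyterStartMetadata() map[string]string {
return map[string]string{
jupyterTaskTypeKey:   jupyterTaskTypeValue,
jupyterTaskActionKey: jupyterActionStart,
jupyterNameKey:       "research-notebook",
jupyterWorkspaceKey:  "/workspaces/research",
}
}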

internal/worker/runloop.go (new file, 525 lines)

@@ -0,0 +1,525 @@
package worker
import (
"context"
"fmt"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
)
// Start starts the worker's main processing loop.
func (w *Worker) Start() {
w.logger.Info("worker started",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers,
"poll_interval", w.config.PollInterval)
go w.heartbeat()
if w.config != nil && w.config.PrewarmEnabled {
go w.prewarmNextLoop()
go w.prewarmImageGCLoop()
}
for {
switch {
case w.ctx.Err() != nil:
w.logger.Info("shutdown signal received, waiting for tasks")
w.waitForTasks()
return
default:
}
if w.runningCount() >= w.config.MaxWorkers {
time.Sleep(50 * time.Millisecond)
continue
}
queueStart := time.Now()
blockTimeout := time.Duration(w.config.PollInterval) * time.Second
task, err := w.queue.GetNextTaskWithLeaseBlocking(
w.config.WorkerID,
w.config.TaskLeaseDuration,
blockTimeout,
)
queueLatency := time.Since(queueStart)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
continue
}
w.logger.Error("error fetching task",
"worker_id", w.id,
"error", err)
continue
}
if task == nil {
if queueLatency > 200*time.Millisecond {
w.logger.Debug("queue poll latency",
"latency_ms", queueLatency.Milliseconds())
}
continue
}
if depth, derr := w.queue.QueueDepth(); derr == nil {
if queueLatency > 100*time.Millisecond || depth > 0 {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"remaining_depth", depth)
}
} else if queueLatency > 100*time.Millisecond {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"depth_error", derr)
}
// Reserve a running slot *before* starting the goroutine so we don't drain
// the entire queue while max_workers is 1.
w.reserveRunningSlot(task.ID)
go w.executeTaskWithLease(task)
}
}
func (w *Worker) reserveRunningSlot(taskID string) {
w.runningMu.Lock()
defer w.runningMu.Unlock()
if w.running == nil {
w.running = make(map[string]context.CancelFunc)
}
// Track a cancel func for future shutdown handling; currently best-effort.
_, cancel := context.WithCancel(w.ctx)
w.running[taskID] = cancel
}
func (w *Worker) releaseRunningSlot(taskID string) {
w.runningMu.Lock()
defer w.runningMu.Unlock()
if w.running == nil {
return
}
if cancel, ok := w.running[taskID]; ok {
cancel()
delete(w.running, taskID)
}
}
func (w *Worker) prewarmImageGCLoop() {
if w.config == nil || !w.config.PrewarmEnabled {
return
}
if w.envPool == nil {
return
}
if w.config.LocalMode {
return
}
if strings.TrimSpace(w.config.PodmanImage) == "" {
return
}
lastSeen := ""
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
runGC := func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
_ = w.envPool.PruneImages(ctx, 24*time.Hour)
}
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
// Exactly one GC pass runs per tick: a newly observed queue request is
// recorded and serviced first; otherwise the periodic pass runs below.
if w.queue != nil {
v, err := w.queue.PrewarmGCRequestValue()
if err == nil && v != "" && v != lastSeen {
lastSeen = v
runGC()
continue
}
}
runGC()
}
}
}
func (w *Worker) heartbeat() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
if err := w.queue.Heartbeat(w.id); err != nil {
w.logger.Warn("heartbeat failed",
"worker_id", w.id,
"error", err)
}
}
}
}
func (w *Worker) waitForTasks() {
timeout := time.After(5 * time.Minute)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-timeout:
w.logger.Warn("shutdown timeout, force stopping",
"running_tasks", len(w.running))
return
case <-ticker.C:
count := w.runningCount()
if count == 0 {
w.logger.Info("all tasks completed, shutting down")
return
}
w.logger.Debug("waiting for tasks to complete",
"remaining", count)
}
}
}
func (w *Worker) runningCount() int {
w.runningMu.RLock()
defer w.runningMu.RUnlock()
return len(w.running)
}
// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
stats := w.metrics.GetStats()
stats["worker_id"] = w.id
stats["max_workers"] = w.config.MaxWorkers
return stats
}
// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
w.cancel()
w.waitForTasks()
// Close connections, logging (rather than failing) on any error.
if err := w.server.Close(); err != nil {
w.logger.Warn("error closing server connection", "error", err)
}
if err := w.queue.Close(); err != nil {
w.logger.Warn("error closing queue connection", "error", err)
}
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics exporter shutdown error", "error", err)
}
}
w.logger.Info("worker stopped", "worker_id", w.id)
}
// Execute task with lease management and retry.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
defer w.releaseRunningSlot(task.ID)
// Track task for graceful shutdown
w.gracefulWait.Add(1)
w.activeTasks.Store(task.ID, task)
defer w.gracefulWait.Done()
defer w.activeTasks.Delete(task.ID)
// Create task-specific context with timeout
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
defer taskCancel()
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
logger.Info("starting task",
"worker_id", w.id,
"datasets", task.Datasets,
"priority", task.Priority)
// Record task start
w.metrics.RecordTaskStart()
defer w.metrics.RecordTaskCompletion()
// Check for context cancellation
select {
case <-taskCtx.Done():
logger.Info("task cancelled before execution")
return
default:
}
// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
if isJupyterTask(task) {
out, err := w.runJupyterTask(taskCtx, task)
endTime := time.Now()
task.EndedAt = &endTime
if err != nil {
logger.Error("jupyter task failed", "error", err)
task.Status = "failed"
task.Error = err.Error()
_ = w.queue.UpdateTaskWithMetrics(task, "final")
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
if len(out) > 0 {
task.Output = string(out)
}
task.Status = "completed"
_ = w.queue.UpdateTaskWithMetrics(task, "final")
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
// Parse datasets from task arguments
task.Datasets = resolveDatasets(task)
if err := w.validateTaskForExecution(taskCtx, task); err != nil {
logger.Error("task validation failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Validation failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error("failed to update task status after validation failure", "error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
logger.Error("provenance validation failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task status after provenance validation failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
lease, err := w.resources.Acquire(taskCtx, task)
if err != nil {
logger.Error("resource acquisition failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task status after resource acquisition failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
defer lease.Release()
// Start heartbeat goroutine
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer cancelHeartbeat()
go w.heartbeatLoop(heartbeatCtx, task.ID)
// Update task status
task.Status = "running"
now := time.Now()
task.StartedAt = &now
task.WorkerID = w.id
// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
w.recordTaskProvenance(taskCtx, task)
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
logger.Error("failed to update task status", "error", err)
w.metrics.RecordTaskFailure()
return
}
if w.config.AutoFetchData && len(task.Datasets) > 0 {
if err := w.fetchDatasets(taskCtx, task); err != nil {
logger.Error("data fetch failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
err := w.queue.UpdateTask(task)
if err != nil {
logger.Error("failed to update task status after data fetch failure", "error", err)
}
w.metrics.RecordTaskFailure()
return
}
}
if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
logger.Error("dataset checksum verification failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task after dataset checksum verification failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
if err := w.verifySnapshot(taskCtx, task); err != nil {
logger.Error("snapshot checksum verification failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task after snapshot checksum verification failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
// Execute job with panic recovery
var execErr error
func() {
defer func() {
if r := recover(); r != nil {
execErr = fmt.Errorf("panic during execution: %v", r)
}
}()
cudaVisible := resources.FormatCUDAVisibleDevices(lease)
execErr = w.runJob(taskCtx, task, cudaVisible)
}()
// Finalize task
endTime := time.Now()
task.EndedAt = &endTime
if execErr != nil {
task.Error = execErr.Error()
// Check if transient error (network, timeout, etc)
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
w.logger.Warn("task failed with transient error, will retry",
"task_id", task.ID,
"error", execErr,
"retry_count", task.RetryCount)
_ = w.queue.RetryTask(task)
} else {
task.Status = "failed"
_ = w.queue.UpdateTaskWithMetrics(task, "final")
}
} else {
task.Status = "completed"
_ = w.queue.UpdateTaskWithMetrics(task, "final")
}
// Release lease
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}
// Heartbeat loop to renew lease.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
ticker := time.NewTicker(w.config.HeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
return
}
// Also update worker heartbeat
_ = w.queue.Heartbeat(w.config.WorkerID)
}
}
}
// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
// Wait for active tasks with timeout
done := make(chan struct{})
go func() {
w.gracefulWait.Wait()
close(done)
}()
timeout := time.After(w.config.GracefulTimeout)
select {
case <-done:
w.logger.Info("all tasks completed, shutdown successful")
case <-timeout:
w.logger.Warn("graceful shutdown timeout, releasing active leases")
w.releaseAllLeases()
}
return w.queue.Close()
}
// Release all active leases.
func (w *Worker) releaseAllLeases() {
w.activeTasks.Range(func(key, _ interface{}) bool {
taskID := key.(string)
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
}
return true
})
}
// Helper functions.
func (w *Worker) countActiveTasks() int {
count := 0
w.activeTasks.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
func isTransientError(err error) bool {
if err == nil {
return false
}
// Check if error is transient (network, timeout, resource unavailable, etc)
errStr := err.Error()
transientIndicators := []string{
"connection refused",
"timeout",
"temporary failure",
"resource temporarily unavailable",
"no such host",
"network unreachable",
}
for _, indicator := range transientIndicators {
if strings.Contains(strings.ToLower(errStr), indicator) {
return true
}
}
return false
}
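// exampleTransientCheck is a hedged sketch of how the substring-based retry
// classification above behaves for two representative errors.
func exampleTransientCheck() {
fmt.Println(isTransientError(fmt.Errorf("dial tcp 10.0.0.5:6379: connection refused"))) // true: retried
fmt.Println(isTransientError(fmt.Errorf("exit status 1")))                              // false: terminal failure
}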


@@ -0,0 +1,270 @@
package worker
import (
"archive/tar"
"compress/gzip"
"context"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
type SnapshotFetcher interface {
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
}
type minioSnapshotFetcher struct {
client *minio.Client
}
func (f *minioSnapshotFetcher) Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) {
obj, err := f.client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
if err != nil {
return nil, err
}
return obj, nil
}
func newMinioSnapshotFetcher(cfg *SnapshotStoreConfig) (*minioSnapshotFetcher, error) {
if cfg == nil {
return nil, fmt.Errorf("missing snapshot store config")
}
endpoint := strings.TrimSpace(cfg.Endpoint)
if endpoint == "" {
return nil, fmt.Errorf("missing snapshot store endpoint")
}
bucket := strings.TrimSpace(cfg.Bucket)
if bucket == "" {
return nil, fmt.Errorf("missing snapshot store bucket")
}
creds := cfg.credentials()
client, err := minio.New(endpoint, &minio.Options{
Creds: creds,
Secure: cfg.Secure,
Region: strings.TrimSpace(cfg.Region),
MaxRetries: cfg.MaxRetries,
})
if err != nil {
return nil, err
}
return &minioSnapshotFetcher{client: client}, nil
}
func (c *SnapshotStoreConfig) credentials() *credentials.Credentials {
if c != nil {
ak := strings.TrimSpace(c.AccessKey)
sk := strings.TrimSpace(c.SecretKey)
st := strings.TrimSpace(c.SessionToken)
if ak != "" && sk != "" {
return credentials.NewStaticV4(ak, sk, st)
}
}
return credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvMinio{},
&credentials.EnvAWS{},
})
}
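// exampleSnapshotStoreConfig is a hedged sketch of the two credential styles
// the resolver accepts: static keys, or empty keys so the credential chain
// falls back to MINIO_*/AWS_* environment variables. Field names mirror the
// usage in this file; the endpoint and bucket values are hypothetical.
func exampleSnapshotStoreConfig() *SnapshotStoreConfig {
return &SnapshotStoreConfig{
Enabled:  true,
Endpoint: "minio.internal:9000",
Bucket:   "fetchml-snapshots",
Prefix:   "prod",
Secure:   true,
// AccessKey/SecretKey left empty: credentials() uses the env chain.
}
}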
func ResolveSnapshot(
ctx context.Context,
dataDir string,
cfg *SnapshotStoreConfig,
snapshotID string,
wantSHA256 string,
fetcher SnapshotFetcher,
) (string, error) {
dataDir = strings.TrimSpace(dataDir)
if dataDir == "" {
return "", fmt.Errorf("missing data_dir")
}
snapshotID = strings.TrimSpace(snapshotID)
if snapshotID == "" {
return "", fmt.Errorf("missing snapshot_id")
}
if err := container.ValidateJobName(snapshotID); err != nil {
return "", fmt.Errorf("invalid snapshot_id: %w", err)
}
want, err := normalizeSHA256ChecksumHex(wantSHA256)
if err != nil {
return "", fmt.Errorf("invalid snapshot_sha256: %w", err)
}
if want == "" {
return "", fmt.Errorf("missing snapshot_sha256")
}
cacheDir := filepath.Join(dataDir, "snapshots", "sha256", want)
if info, err := os.Stat(cacheDir); err == nil && info.IsDir() {
return cacheDir, nil
}
if cfg == nil || !cfg.Enabled {
return filepath.Join(dataDir, "snapshots", snapshotID), nil
}
bucket := strings.TrimSpace(cfg.Bucket)
if bucket == "" {
return "", fmt.Errorf("missing snapshot store bucket")
}
prefix := strings.Trim(strings.TrimSpace(cfg.Prefix), "/")
key := snapshotID + ".tar.gz"
if prefix != "" {
key = path.Join(prefix, key)
}
if fetcher == nil {
mf, err := newMinioSnapshotFetcher(cfg)
if err != nil {
return "", err
}
fetcher = mf
}
fetchCtx := ctx
if cfg.Timeout > 0 {
var cancel context.CancelFunc
fetchCtx, cancel = context.WithTimeout(ctx, cfg.Timeout)
defer cancel()
}
rc, err := fetcher.Get(fetchCtx, bucket, key)
if err != nil {
return "", err
}
defer func() { _ = rc.Close() }()
tmpRoot := filepath.Join(dataDir, "snapshots", ".tmp")
if err := os.MkdirAll(tmpRoot, 0750); err != nil {
return "", err
}
workDir, err := os.MkdirTemp(tmpRoot, "fetchml-snapshot-")
if err != nil {
return "", err
}
defer func() { _ = os.RemoveAll(workDir) }()
archivePath := filepath.Join(workDir, "snapshot.tar.gz")
f, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return "", err
}
_, copyErr := io.Copy(f, rc)
closeErr := f.Close()
if copyErr != nil {
return "", copyErr
}
if closeErr != nil {
return "", closeErr
}
extractDir := filepath.Join(workDir, "extracted")
if err := os.MkdirAll(extractDir, 0750); err != nil {
return "", err
}
if err := extractTarGz(archivePath, extractDir); err != nil {
return "", err
}
got, err := dirOverallSHA256Hex(extractDir)
if err != nil {
return "", err
}
if got != want {
return "", fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", want, got)
}
if err := os.MkdirAll(filepath.Dir(cacheDir), 0750); err != nil {
return "", err
}
if err := os.Rename(extractDir, cacheDir); err != nil {
if info, statErr := os.Stat(cacheDir); statErr == nil && info.IsDir() {
return cacheDir, nil
}
return "", err
}
return cacheDir, nil
}
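// exampleResolveSnapshot is a hedged usage sketch: ResolveSnapshot returns
// the content-addressed cache directory when it already exists, and otherwise
// fetches <prefix>/<snapshot_id>.tar.gz, verifies the tree digest, and moves
// the extracted tree into place. All values below are hypothetical.
func exampleResolveSnapshot(ctx context.Context, cfg *SnapshotStoreConfig) {
dir, err := ResolveSnapshot(ctx, "/data", cfg, "snap-2026-01-05",
"sha256:"+strings.Repeat("ab", 32), nil)
if err != nil {
fmt.Println("resolve failed:", err)
return
}
fmt.Println("snapshot staged at:", dir)
}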
func extractTarGz(archivePath, dstDir string) error {
archivePath = filepath.Clean(archivePath)
dstDir = filepath.Clean(dstDir)
f, err := os.Open(archivePath)
if err != nil {
return err
}
defer func() { _ = f.Close() }()
gz, err := gzip.NewReader(f)
if err != nil {
return err
}
defer func() { _ = gz.Close() }()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
name := strings.TrimSpace(hdr.Name)
name = strings.TrimPrefix(name, "./")
clean := path.Clean(name)
if clean == "." {
continue
}
if strings.HasPrefix(clean, "../") || clean == ".." || strings.HasPrefix(clean, "/") {
return fmt.Errorf("invalid tar entry")
}
target, err := safeJoin(dstDir, filepath.FromSlash(clean))
if err != nil {
return err
}
switch hdr.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0750); err != nil {
return err
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(target), 0750); err != nil {
return err
}
mode := hdr.FileInfo().Mode() & 0777
out, err := fileutil.SecureOpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return err
}
if _, err := io.CopyN(out, tr, hdr.Size); err != nil {
_ = out.Close()
return err
}
if err := out.Close(); err != nil {
return err
}
default:
return fmt.Errorf("unsupported tar entry type")
}
}
return nil
}
func safeJoin(baseDir, rel string) (string, error) {
baseDir = filepath.Clean(baseDir)
joined := filepath.Join(baseDir, rel)
joined = filepath.Clean(joined)
basePrefix := baseDir + string(os.PathSeparator)
if joined != baseDir && !strings.HasPrefix(joined, basePrefix) {
return "", fmt.Errorf("invalid relative path")
}
return joined, nil
}
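// exampleSafeJoin is a hedged sketch of the traversal guard above: a relative
// path that stays under baseDir resolves normally, while one that cleans to a
// location outside it is rejected.
func exampleSafeJoin() {
p, _ := safeJoin("/data/snapshots/extract", "model/weights.bin")
fmt.Println(p) // /data/snapshots/extract/model/weights.bin
if _, err := safeJoin("/data/snapshots/extract", "../escape"); err != nil {
fmt.Println("rejected:", err) // invalid relative path
}
}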