feat(worker): add integrity checks, snapshot staging, and prewarm support
This commit is contained in:
parent
add4a90e62
commit
82034c68f3
12 changed files with 4493 additions and 1145 deletions
|
|
@ -1,179 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
	// defaultMetricsFlushInterval bounds how often buffered metrics are flushed.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// datasetCacheDefaultTTL is the fallback TTL for the dataset cache.
	datasetCacheDefaultTTL = 30 * time.Minute
)

// Config holds worker configuration.
// Zero/empty fields are filled with environment-aware defaults in LoadConfig.
type Config struct {
	Host          string                `yaml:"host"`
	User          string                `yaml:"user"`
	SSHKey        string                `yaml:"ssh_key"`
	Port          int                   `yaml:"port"`
	BasePath      string                `yaml:"base_path"`
	TrainScript   string                `yaml:"train_script"`
	RedisAddr     string                `yaml:"redis_addr"`
	RedisPassword string                `yaml:"redis_password"`
	RedisDB       int                   `yaml:"redis_db"`
	KnownHosts    string                `yaml:"known_hosts"`
	WorkerID      string                `yaml:"worker_id"`
	MaxWorkers    int                   `yaml:"max_workers"`
	PollInterval  int                   `yaml:"poll_interval_seconds"`
	Resources     config.ResourceConfig `yaml:"resources"`
	LocalMode     bool                  `yaml:"local_mode"`

	// Authentication
	Auth auth.Config `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`
	// Metrics buffering
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management
	DataManagerPath string        `yaml:"data_manager_path"`
	AutoFetchData   bool          `yaml:"auto_fetch_data"`
	DataDir         string        `yaml:"data_dir"`
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	// Podman execution
	PodmanImage        string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults   string `yaml:"container_results"`
	GPUAccess          bool   `yaml:"gpu_access"`

	// Task lease and retry settings
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // How often to renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Graceful shutdown timeout (default: 5min)
}

// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
	Enabled    bool   `yaml:"enabled"`
	ListenAddr string `yaml:"listen_addr"`
}
|
||||
|
||||
// LoadConfig loads worker configuration from a YAML file.
// After unmarshalling it fills every unset field with a default, preferring
// environment-aware values from config.GetSmartDefaults where available.
func LoadConfig(path string) (*Config, error) {
	data, err := fileutil.SecureFileRead(path)
	if err != nil {
		return nil, err
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}

	// Get smart defaults for current environment
	smart := config.GetSmartDefaults()

	if cfg.Port == 0 {
		cfg.Port = config.DefaultSSHPort
	}
	if cfg.Host == "" {
		cfg.Host = smart.Host()
	}
	if cfg.BasePath == "" {
		cfg.BasePath = smart.BasePath()
	}
	if cfg.RedisAddr == "" {
		cfg.RedisAddr = smart.RedisAddr()
	}
	if cfg.KnownHosts == "" {
		cfg.KnownHosts = smart.KnownHostsPath()
	}
	if cfg.WorkerID == "" {
		// Short random suffix keeps worker IDs unique yet log-friendly.
		cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
	}
	cfg.Resources.ApplyDefaults()
	// Keep MaxWorkers and Resources.MaxWorkers in sync, whichever was set.
	if cfg.MaxWorkers > 0 {
		cfg.Resources.MaxWorkers = cfg.MaxWorkers
	} else {
		cfg.MaxWorkers = cfg.Resources.MaxWorkers
	}
	if cfg.PollInterval == 0 {
		cfg.PollInterval = smart.PollInterval()
	}
	if cfg.DataManagerPath == "" {
		cfg.DataManagerPath = "./data_manager"
	}
	if cfg.DataDir == "" {
		// Remote data dir only makes sense with a host and auto-fetch enabled.
		if cfg.Host == "" || !cfg.AutoFetchData {
			cfg.DataDir = config.DefaultLocalDataDir
		} else {
			cfg.DataDir = smart.DataDir()
		}
	}
	if cfg.Metrics.ListenAddr == "" {
		cfg.Metrics.ListenAddr = ":9100"
	}
	if cfg.MetricsFlushInterval == 0 {
		cfg.MetricsFlushInterval = defaultMetricsFlushInterval
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}

	// Set lease and retry defaults
	if cfg.TaskLeaseDuration == 0 {
		cfg.TaskLeaseDuration = 30 * time.Minute
	}
	if cfg.HeartbeatInterval == 0 {
		cfg.HeartbeatInterval = 1 * time.Minute
	}
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = 3
	}
	if cfg.GracefulTimeout == 0 {
		cfg.GracefulTimeout = 5 * time.Minute
	}

	return &cfg, nil
}
|
||||
|
||||
// Validate implements config.Validator interface.
|
||||
func (c *Config) Validate() error {
|
||||
if c.Port != 0 {
|
||||
if err := config.ValidatePort(c.Port); err != nil {
|
||||
return fmt.Errorf("invalid SSH port: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.BasePath != "" {
|
||||
// Convert relative paths to absolute
|
||||
c.BasePath = config.ExpandPath(c.BasePath)
|
||||
if !filepath.IsAbs(c.BasePath) {
|
||||
c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
|
||||
}
|
||||
}
|
||||
|
||||
if c.RedisAddr != "" {
|
||||
if err := config.ValidateRedisAddr(c.RedisAddr); err != nil {
|
||||
return fmt.Errorf("invalid Redis configuration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.MaxWorkers < 1 {
|
||||
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Task struct and Redis constants moved to internal/queue
|
||||
File diff suppressed because it is too large
Load diff
288
internal/envpool/envpool.go
Normal file
288
internal/envpool/envpool.go
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
package envpool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CommandRunner abstracts command execution so the Pool can be exercised in
// tests with a fake runner instead of invoking the real podman binary.
type CommandRunner interface {
	// CombinedOutput runs name with args under ctx and returns the merged
	// stdout+stderr along with any execution error.
	CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error)
}

// execRunner is the default CommandRunner, backed by os/exec.
type execRunner struct{}

// CombinedOutput mirrors (*exec.Cmd).CombinedOutput with context support.
func (r execRunner) CombinedOutput(
	ctx context.Context,
	name string,
	args ...string,
) ([]byte, error) {
	cmd := exec.CommandContext(ctx, name, args...)
	return cmd.CombinedOutput()
}
|
||||
|
||||
// Pool manages prewarmed podman images keyed by dependency-manifest digest.
// It caches "does this image exist" answers for cacheTTL to avoid repeated
// `podman image inspect` calls.
type Pool struct {
	runner CommandRunner // command executor; execRunner in production

	imagePrefix string // repository prefix for prewarm image tags

	cacheMu  sync.Mutex            // guards cache
	cache    map[string]cacheEntry // imageRef -> cached existence result
	cacheTTL time.Duration         // validity window for cache entries
}

// cacheEntry is one cached existence answer with its expiry instant.
type cacheEntry struct {
	exists  bool
	expires time.Time
}
|
||||
|
||||
func New(imagePrefix string) *Pool {
|
||||
prefix := strings.TrimSpace(imagePrefix)
|
||||
if prefix == "" {
|
||||
prefix = "fetchml-prewarm"
|
||||
}
|
||||
return &Pool{
|
||||
runner: execRunner{},
|
||||
imagePrefix: prefix,
|
||||
cache: make(map[string]cacheEntry),
|
||||
cacheTTL: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) WithRunner(r CommandRunner) *Pool {
|
||||
if r != nil {
|
||||
p.runner = r
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pool) WithCacheTTL(ttl time.Duration) *Pool {
|
||||
if ttl > 0 {
|
||||
p.cacheTTL = ttl
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pool) WarmImageTag(depsManifestSHA256 string) (string, error) {
|
||||
sha := strings.TrimSpace(depsManifestSHA256)
|
||||
if sha == "" {
|
||||
return "", fmt.Errorf("missing deps sha256")
|
||||
}
|
||||
if !isLowerHexLen(sha, 64) {
|
||||
return "", fmt.Errorf("invalid deps sha256")
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", p.imagePrefix, sha[:12]), nil
|
||||
}
|
||||
|
||||
// ImageExists reports whether imageRef is present in local podman storage,
// consulting (and refreshing) the TTL cache before shelling out to
// `podman image inspect`.
func (p *Pool) ImageExists(ctx context.Context, imageRef string) (bool, error) {
	ref := strings.TrimSpace(imageRef)
	if ref == "" {
		return false, fmt.Errorf("missing image ref")
	}

	// Fast path: serve a still-valid cached answer.
	p.cacheMu.Lock()
	if ent, ok := p.cache[ref]; ok && time.Now().Before(ent.expires) {
		exists := ent.exists
		p.cacheMu.Unlock()
		return exists, nil
	}
	p.cacheMu.Unlock()

	out, err := p.runner.CombinedOutput(ctx, "podman", "image", "inspect", ref)
	if err == nil {
		p.setCache(ref, true)
		return true, nil
	}
	// Inspect failed: classify the failure. Output matching a "missing image"
	// diagnostic means a clean "absent" answer.
	if looksLikeImageNotFound(out) {
		p.setCache(ref, false)
		return false, nil
	}
	// Any other non-zero exit from podman is also treated (and cached) as
	// "absent". NOTE(review): this caches false even for unrelated inspect
	// failures (e.g. storage errors) — confirm that is intended.
	var ee *exec.ExitError
	if errors.As(err, &ee) {
		p.setCache(ref, false)
		return false, nil
	}
	// Non-exit errors (binary missing, ctx cancelled, ...) surface to caller.
	return false, err
}
|
||||
|
||||
// looksLikeImageNotFound reports whether podman output reads like a
// "missing image" diagnostic rather than some other failure.
func looksLikeImageNotFound(out []byte) bool {
	msg := strings.ToLower(strings.TrimSpace(string(out)))
	if msg == "" {
		return false
	}
	for _, needle := range []string{"no such", "not found", "does not exist"} {
		if strings.Contains(msg, needle) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func (p *Pool) setCache(imageRef string, exists bool) {
|
||||
p.cacheMu.Lock()
|
||||
p.cache[imageRef] = cacheEntry{exists: exists, expires: time.Now().Add(p.cacheTTL)}
|
||||
p.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
// PrepareRequest describes one prewarm build: run BaseImage with
// HostWorkspace mounted at ContainerWorkspace, install dependencies from
// DepsPathInContainer, then commit the result as TargetImage.
type PrepareRequest struct {
	BaseImage           string // image to start from
	TargetImage         string // tag to commit the prewarmed container as
	HostWorkspace       string // host directory bind-mounted into the container
	ContainerWorkspace  string // mount point inside the container
	DepsPathInContainer string // deps manifest path; must live under ContainerWorkspace
}
|
||||
|
||||
func (p *Pool) PruneImages(ctx context.Context, olderThan time.Duration) error {
|
||||
if olderThan <= 0 {
|
||||
return fmt.Errorf("invalid olderThan")
|
||||
}
|
||||
h := int(olderThan.Round(time.Hour).Hours())
|
||||
if h < 1 {
|
||||
h = 1
|
||||
}
|
||||
until := fmt.Sprintf("%dh", h)
|
||||
_, err := p.runner.CombinedOutput(
|
||||
ctx,
|
||||
"podman",
|
||||
"image",
|
||||
"prune",
|
||||
"-a",
|
||||
"-f",
|
||||
"--filter",
|
||||
"label=fetchml.prewarm=true",
|
||||
"--filter",
|
||||
"until="+until,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Pool) Prepare(ctx context.Context, req PrepareRequest) error {
|
||||
baseImage := strings.TrimSpace(req.BaseImage)
|
||||
targetImage := strings.TrimSpace(req.TargetImage)
|
||||
hostWS := strings.TrimSpace(req.HostWorkspace)
|
||||
containerWS := strings.TrimSpace(req.ContainerWorkspace)
|
||||
depsInContainer := strings.TrimSpace(req.DepsPathInContainer)
|
||||
|
||||
if baseImage == "" {
|
||||
return fmt.Errorf("missing base image")
|
||||
}
|
||||
if targetImage == "" {
|
||||
return fmt.Errorf("missing target image")
|
||||
}
|
||||
if hostWS == "" {
|
||||
return fmt.Errorf("missing host workspace")
|
||||
}
|
||||
if containerWS == "" {
|
||||
return fmt.Errorf("missing container workspace")
|
||||
}
|
||||
if depsInContainer == "" {
|
||||
return fmt.Errorf("missing deps path")
|
||||
}
|
||||
if !strings.HasPrefix(depsInContainer, containerWS) {
|
||||
return fmt.Errorf("deps path must be under container workspace")
|
||||
}
|
||||
|
||||
exists, err := p.ImageExists(ctx, targetImage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
containerName, err := randomContainerName("fetchml-prewarm")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not use --rm since we need a container to commit.
|
||||
runArgs := []string{
|
||||
"run",
|
||||
"--name", containerName,
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--cap-drop", "ALL",
|
||||
"--userns", "keep-id",
|
||||
"-v", fmt.Sprintf("%s:%s:rw", hostWS, containerWS),
|
||||
baseImage,
|
||||
"--workspace", containerWS,
|
||||
"--deps", depsInContainer,
|
||||
"--prepare-only",
|
||||
}
|
||||
|
||||
if out, err := p.runner.CombinedOutput(ctx, "podman", runArgs...); err != nil {
|
||||
_ = p.cleanupContainer(context.Background(), containerName)
|
||||
return fmt.Errorf("podman run prewarm failed: %w", scrubOutput(out, err))
|
||||
}
|
||||
|
||||
if out, err := p.runner.CombinedOutput(
|
||||
ctx,
|
||||
"podman",
|
||||
"commit",
|
||||
containerName,
|
||||
targetImage,
|
||||
); err != nil {
|
||||
_ = p.cleanupContainer(context.Background(), containerName)
|
||||
return fmt.Errorf("podman commit prewarm failed: %w", scrubOutput(out, err))
|
||||
}
|
||||
_, _ = p.runner.CombinedOutput(
|
||||
ctx,
|
||||
"podman",
|
||||
"image",
|
||||
"label",
|
||||
targetImage,
|
||||
"fetchml.prewarm=true",
|
||||
)
|
||||
|
||||
_ = p.cleanupContainer(context.Background(), containerName)
|
||||
p.setCache(targetImage, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pool) cleanupContainer(ctx context.Context, name string) error {
|
||||
n := strings.TrimSpace(name)
|
||||
if n == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := p.runner.CombinedOutput(ctx, "podman", "rm", n)
|
||||
return err
|
||||
}
|
||||
|
||||
func randomContainerName(prefix string) (string, error) {
|
||||
p := strings.TrimSpace(prefix)
|
||||
if p == "" {
|
||||
p = "fetchml-prewarm"
|
||||
}
|
||||
b := make([]byte, 6)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", p, hex.EncodeToString(b)), nil
|
||||
}
|
||||
|
||||
// isLowerHexLen reports whether s is exactly want characters, all lowercase
// hexadecimal ([0-9a-f]).
func isLowerHexLen(s string, want int) bool {
	if len(s) != want {
		return false
	}
	for _, c := range []byte(s) {
		switch {
		case c >= '0' && c <= '9':
		case c >= 'a' && c <= 'f':
		default:
			return false
		}
	}
	return true
}
|
||||
|
||||
func scrubOutput(out []byte, err error) error {
|
||||
if len(out) == 0 {
|
||||
return err
|
||||
}
|
||||
s := strings.TrimSpace(string(out))
|
||||
if len(s) > 400 {
|
||||
s = s[:400]
|
||||
}
|
||||
return fmt.Errorf("%w (output=%q)", err, s)
|
||||
}
|
||||
323
internal/resources/manager.go
Normal file
323
internal/resources/manager.go
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
package resources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// Manager tracks a fixed pool of CPUs and GPU slots and hands them out as
// Leases. Waiters block on cond until capacity is released.
type Manager struct {
	mu          sync.Mutex // guards all capacity fields below
	cond        *sync.Cond // signalled on Release / ctx cancellation
	totalCPU    int        // fixed CPU capacity
	freeCPU     int        // CPUs currently unallocated
	slotsPerGPU int        // slot granularity per GPU
	gpuFree     []int      // free slots per GPU, indexed by GPU

	// Counters are atomics so Snapshot can read them without the mutex.
	acquireTotal        atomic.Int64 // total Acquire calls
	acquireWaitTotal    atomic.Int64 // Acquires that had to wait
	acquireTimeoutTotal atomic.Int64 // Acquires that hit a deadline
	acquireWaitNanos    atomic.Int64 // cumulative time spent waiting
}

// Snapshot is a point-in-time copy of Manager capacity and counters.
type Snapshot struct {
	TotalCPU    int
	FreeCPU     int
	SlotsPerGPU int
	GPUFree     []int

	AcquireTotal        int64
	AcquireWaitTotal    int64
	AcquireTimeoutTotal int64
	AcquireWaitSeconds  float64
}
|
||||
|
||||
func FormatCUDAVisibleDevices(lease *Lease) string {
|
||||
if lease == nil {
|
||||
return "-1"
|
||||
}
|
||||
if len(lease.gpus) == 0 {
|
||||
return "-1"
|
||||
}
|
||||
idx := make([]int, 0, len(lease.gpus))
|
||||
for _, g := range lease.gpus {
|
||||
idx = append(idx, g.Index)
|
||||
}
|
||||
sort.Ints(idx)
|
||||
parts := make([]string, 0, len(idx))
|
||||
for _, i := range idx {
|
||||
parts = append(parts, strconv.Itoa(i))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// GPUAllocation is one GPU's contribution to a lease: which GPU and how
// many of its slots.
type GPUAllocation struct {
	Index int // GPU index
	Slots int // slots taken on that GPU
}

// Lease is a granted resource reservation; call Release to return it.
type Lease struct {
	cpu  int             // CPUs held
	gpus []GPUAllocation // GPU slots held
	m    *Manager        // owning manager, nil for a detached lease
}

// CPU returns the number of CPUs held by the lease.
func (l *Lease) CPU() int { return l.cpu }

// GPUs returns a defensive copy of the lease's GPU allocations.
func (l *Lease) GPUs() []GPUAllocation {
	out := make([]GPUAllocation, len(l.gpus))
	copy(out, l.gpus)
	return out
}
|
||||
|
||||
// Release returns the lease's CPUs and GPU slots to the manager and wakes
// all waiters. Safe on a nil lease or a lease with no manager. Free counts
// are clamped to capacity, so a double Release cannot over-credit the pool
// (though it is still a caller bug).
func (l *Lease) Release() {
	if l == nil || l.m == nil {
		return
	}
	m := l.m
	m.mu.Lock()
	defer m.mu.Unlock()

	if l.cpu > 0 {
		m.freeCPU += l.cpu
		// Clamp: never report more free CPUs than exist.
		if m.freeCPU > m.totalCPU {
			m.freeCPU = m.totalCPU
		}
	}
	for _, g := range l.gpus {
		// Ignore out-of-range indices defensively.
		if g.Index >= 0 && g.Index < len(m.gpuFree) {
			m.gpuFree[g.Index] += g.Slots
			// Clamp per-GPU free slots to the configured maximum.
			if m.gpuFree[g.Index] > m.slotsPerGPU {
				m.gpuFree[g.Index] = m.slotsPerGPU
			}
		}
	}
	// Wake every waiter; each re-checks whether its request now fits.
	m.cond.Broadcast()
}
|
||||
|
||||
// Options configures NewManager.
type Options struct {
	TotalCPU    int // CPU capacity; must be >= 0
	GPUCount    int // number of GPUs; must be >= 0
	SlotsPerGPU int // slot granularity per GPU; <= 0 defaults to 1
}
|
||||
|
||||
func NewManager(opts Options) (*Manager, error) {
|
||||
if opts.TotalCPU < 0 {
|
||||
return nil, fmt.Errorf("total cpu must be >= 0")
|
||||
}
|
||||
if opts.GPUCount < 0 {
|
||||
return nil, fmt.Errorf("gpu count must be >= 0")
|
||||
}
|
||||
if opts.SlotsPerGPU <= 0 {
|
||||
opts.SlotsPerGPU = 1
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
totalCPU: opts.TotalCPU,
|
||||
freeCPU: opts.TotalCPU,
|
||||
slotsPerGPU: opts.SlotsPerGPU,
|
||||
gpuFree: make([]int, opts.GPUCount),
|
||||
}
|
||||
for i := range m.gpuFree {
|
||||
m.gpuFree[i] = m.slotsPerGPU
|
||||
}
|
||||
m.cond = sync.NewCond(&m.mu)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Snapshot() Snapshot {
|
||||
if m == nil {
|
||||
return Snapshot{}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
gpuFree := make([]int, len(m.gpuFree))
|
||||
copy(gpuFree, m.gpuFree)
|
||||
totalCPU := m.totalCPU
|
||||
freeCPU := m.freeCPU
|
||||
slotsPerGPU := m.slotsPerGPU
|
||||
m.mu.Unlock()
|
||||
|
||||
waitNanos := m.acquireWaitNanos.Load()
|
||||
return Snapshot{
|
||||
TotalCPU: totalCPU,
|
||||
FreeCPU: freeCPU,
|
||||
SlotsPerGPU: slotsPerGPU,
|
||||
GPUFree: gpuFree,
|
||||
AcquireTotal: m.acquireTotal.Load(),
|
||||
AcquireWaitTotal: m.acquireWaitTotal.Load(),
|
||||
AcquireTimeoutTotal: m.acquireTimeoutTotal.Load(),
|
||||
AcquireWaitSeconds: float64(waitNanos) / float64(time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire blocks until the task's CPU and GPU requests can be satisfied (or
// ctx ends) and returns a Lease holding them. Requests larger than total
// capacity fail immediately rather than waiting forever.
func (m *Manager) Acquire(ctx context.Context, task *queue.Task) (*Lease, error) {
	if m == nil {
		return nil, fmt.Errorf("resource manager is nil")
	}
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	if ctx == nil {
		return nil, fmt.Errorf("context is nil")
	}

	m.acquireTotal.Add(1)
	start := time.Now()
	waited := false

	// Validate requests against fixed capacity. totalCPU and len(gpuFree)
	// are set once in NewManager, so reading them without the lock is safe.
	reqCPU := task.CPU
	if reqCPU < 0 {
		return nil, fmt.Errorf("cpu request must be >= 0")
	}
	if reqCPU > m.totalCPU {
		return nil, fmt.Errorf("cpu request %d exceeds total cpu %d", reqCPU, m.totalCPU)
	}

	reqGPU := task.GPU
	if reqGPU < 0 {
		return nil, fmt.Errorf("gpu request must be >= 0")
	}
	if reqGPU > len(m.gpuFree) {
		return nil, fmt.Errorf("gpu request %d exceeds available gpus %d", reqGPU, len(m.gpuFree))
	}

	slotsPerTaskGPU, err := m.gpuSlotsForTask(task.GPUMemory)
	if err != nil {
		return nil, err
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	for {
		// Check cancellation first so a cancelled caller never grabs
		// resources; only deadline expiry counts as a "timeout" metric.
		if ctx.Err() != nil {
			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
				m.acquireTimeoutTotal.Add(1)
			}
			return nil, ctx.Err()
		}

		// Try a full allocation. tryAllocateGPUsLocked only plans — it does
		// not mutate gpuFree — so a failed CPU check discards the plan safely.
		gpuAlloc, ok := m.tryAllocateGPUsLocked(reqGPU, slotsPerTaskGPU)
		if ok && (reqCPU == 0 || m.freeCPU >= reqCPU) {
			if reqCPU > 0 {
				m.freeCPU -= reqCPU
			}
			for _, g := range gpuAlloc {
				m.gpuFree[g.Index] -= g.Slots
			}
			if waited {
				m.acquireWaitTotal.Add(1)
				m.acquireWaitNanos.Add(time.Since(start).Nanoseconds())
			}
			return &Lease{cpu: reqCPU, gpus: gpuAlloc, m: m}, nil
		}
		waited = true

		// cond.Wait cannot observe ctx, so spawn a watcher that broadcasts
		// on cancellation; `done` stops it once Wait returns. The watcher
		// takes m.mu before broadcasting, which only succeeds after Wait has
		// released the lock — so the wakeup cannot be lost.
		done := make(chan struct{})
		go func() {
			select {
			case <-ctx.Done():
				m.mu.Lock()
				m.cond.Broadcast()
				m.mu.Unlock()
			case <-done:
			}
		}()
		m.cond.Wait()
		close(done)
	}
}
|
||||
|
||||
func (m *Manager) gpuSlotsForTask(gpuMem string) (int, error) {
|
||||
if m.slotsPerGPU <= 0 {
|
||||
return 1, nil
|
||||
}
|
||||
if strings.TrimSpace(gpuMem) == "" {
|
||||
return m.slotsPerGPU, nil
|
||||
}
|
||||
|
||||
if frac, ok := parseFraction(strings.TrimSpace(gpuMem)); ok {
|
||||
if frac <= 0 {
|
||||
return 1, nil
|
||||
}
|
||||
if frac > 1 {
|
||||
frac = 1
|
||||
}
|
||||
slots := int(math.Ceil(frac * float64(m.slotsPerGPU)))
|
||||
if slots < 1 {
|
||||
slots = 1
|
||||
}
|
||||
if slots > m.slotsPerGPU {
|
||||
slots = m.slotsPerGPU
|
||||
}
|
||||
return slots, nil
|
||||
}
|
||||
|
||||
return m.slotsPerGPU, nil
|
||||
}
|
||||
|
||||
func (m *Manager) tryAllocateGPUsLocked(reqGPU int, slotsPerTaskGPU int) ([]GPUAllocation, bool) {
|
||||
if reqGPU == 0 {
|
||||
return nil, true
|
||||
}
|
||||
if slotsPerTaskGPU <= 0 {
|
||||
slotsPerTaskGPU = m.slotsPerGPU
|
||||
}
|
||||
|
||||
alloc := make([]GPUAllocation, 0, reqGPU)
|
||||
used := make(map[int]struct{}, reqGPU)
|
||||
|
||||
for len(alloc) < reqGPU {
|
||||
bestIdx := -1
|
||||
bestFree := -1
|
||||
for i := 0; i < len(m.gpuFree); i++ {
|
||||
if _, ok := used[i]; ok {
|
||||
continue
|
||||
}
|
||||
free := m.gpuFree[i]
|
||||
if free >= slotsPerTaskGPU && free > bestFree {
|
||||
bestFree = free
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return nil, false
|
||||
}
|
||||
used[bestIdx] = struct{}{}
|
||||
alloc = append(alloc, GPUAllocation{Index: bestIdx, Slots: slotsPerTaskGPU})
|
||||
}
|
||||
return alloc, true
|
||||
}
|
||||
|
||||
// parseFraction parses s as a GPU fraction in one of two forms: a percentage
// ("50%") or a plain decimal ("0.5"). It returns (fraction, true) on success.
// Values above 1 (i.e. above 100%) are rejected in BOTH forms — previously
// only the plain-decimal branch was bounded, so "150%" slipped through as
// 1.5. Negative values are passed through for the caller to clamp, matching
// the original plain-decimal behavior.
func parseFraction(s string) (float64, bool) {
	if s == "" {
		return 0, false
	}

	raw := s
	percent := strings.HasSuffix(s, "%")
	if percent {
		raw = strings.TrimSpace(strings.TrimSuffix(s, "%"))
	}

	f, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		return 0, false
	}
	if percent {
		f /= 100.0
	}
	// Consistent upper bound: a request can never exceed one whole GPU.
	if f > 1 {
		return 0, false
	}
	return f, true
}
|
||||
371
internal/worker/config.go
Normal file
371
internal/worker/config.go
Normal file
|
|
@ -0,0 +1,371 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
	// defaultMetricsFlushInterval bounds how often buffered metrics are flushed.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// datasetCacheDefaultTTL is the fallback TTL for the dataset cache.
	datasetCacheDefaultTTL = 30 * time.Minute
)

// QueueConfig selects the task-queue backend and its storage location.
type QueueConfig struct {
	Backend    string `yaml:"backend"`     // one of queue.QueueBackendRedis / queue.QueueBackendSQLite
	SQLitePath string `yaml:"sqlite_path"` // database path; used only by the sqlite backend
}
||||
|
||||
// Config holds worker configuration.
// Unset fields are filled with environment-aware defaults in LoadConfig.
type Config struct {
	Host          string                `yaml:"host"`
	User          string                `yaml:"user"`
	SSHKey        string                `yaml:"ssh_key"`
	Port          int                   `yaml:"port"`
	BasePath      string                `yaml:"base_path"`
	TrainScript   string                `yaml:"train_script"`
	RedisURL      string                `yaml:"redis_url"` // when set, overrides redis_addr/password/db (see LoadConfig)
	RedisAddr     string                `yaml:"redis_addr"`
	RedisPassword string                `yaml:"redis_password"`
	RedisDB       int                   `yaml:"redis_db"`
	Queue         QueueConfig           `yaml:"queue"`
	KnownHosts    string                `yaml:"known_hosts"`
	WorkerID      string                `yaml:"worker_id"`
	MaxWorkers    int                   `yaml:"max_workers"`
	PollInterval  int                   `yaml:"poll_interval_seconds"`
	Resources     config.ResourceConfig `yaml:"resources"`
	LocalMode     bool                  `yaml:"local_mode"`

	// Authentication
	Auth auth.Config `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`
	// Metrics buffering
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management
	DataManagerPath string        `yaml:"data_manager_path"`
	AutoFetchData   bool          `yaml:"auto_fetch_data"`
	DataDir         string        `yaml:"data_dir"`
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"`

	// Provenance enforcement
	// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
	ProvenanceBestEffort bool `yaml:"provenance_best_effort"`

	// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
	PrewarmEnabled bool `yaml:"prewarm_enabled"`

	// Podman execution
	PodmanImage         string   `yaml:"podman_image"`
	ContainerWorkspace  string   `yaml:"container_workspace"`
	ContainerResults    string   `yaml:"container_results"`
	GPUDevices          []string `yaml:"gpu_devices"`
	GPUVendor           string   `yaml:"gpu_vendor"`             // inferred in LoadConfig when empty
	GPUVisibleDevices   []int    `yaml:"gpu_visible_devices"`    // mutually exclusive with gpu_visible_device_ids
	GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"` // NVIDIA-only "GPU-…" UUIDs

	// Apple M-series GPU configuration
	AppleGPU AppleGPUConfig `yaml:"apple_gpu"`

	// Task lease and retry settings
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // Renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Shutdown timeout (default: 5min)

	// Plugins configuration
	Plugins map[string]factory.PluginConfig `yaml:"plugins"`
}

// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
	Enabled    bool   `yaml:"enabled"`
	ListenAddr string `yaml:"listen_addr"`
}

// SnapshotStoreConfig configures the object store used for staging
// snapshots (S3-compatible endpoint fields).
type SnapshotStoreConfig struct {
	Enabled      bool          `yaml:"enabled"`
	Endpoint     string        `yaml:"endpoint"`
	Secure       bool          `yaml:"secure"`
	Region       string        `yaml:"region"`
	Bucket       string        `yaml:"bucket"`
	Prefix       string        `yaml:"prefix"`
	AccessKey    string        `yaml:"access_key"`
	SecretKey    string        `yaml:"secret_key"`
	SessionToken string        `yaml:"session_token"`
	Timeout      time.Duration `yaml:"timeout"`     // default 10min (see LoadConfig)
	MaxRetries   int           `yaml:"max_retries"` // default 3 (see LoadConfig)
}

// AppleGPUConfig holds configuration for Apple M-series GPU support
type AppleGPUConfig struct {
	Enabled     bool   `yaml:"enabled"`
	MetalDevice string `yaml:"metal_device"`
	MPSRuntime  string `yaml:"mps_runtime"`
}
|
||||
|
||||
// LoadConfig loads worker configuration from a YAML file.
// After unmarshalling it resolves redis_url precedence, then fills every
// unset field with a default (environment-aware where config.GetSmartDefaults
// provides one).
func LoadConfig(path string) (*Config, error) {
	data, err := fileutil.SecureFileRead(path)
	if err != nil {
		return nil, err
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}

	// redis_url wins over the discrete fields: it is env-expanded, copied
	// into RedisAddr, and the separate password/db are cleared (the URL is
	// expected to carry them).
	if strings.TrimSpace(cfg.RedisURL) != "" {
		cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL))
		cfg.RedisAddr = cfg.RedisURL
		cfg.RedisPassword = ""
		cfg.RedisDB = 0
	}

	// Get smart defaults for current environment
	smart := config.GetSmartDefaults()

	if cfg.Port == 0 {
		cfg.Port = config.DefaultSSHPort
	}
	if cfg.Host == "" {
		cfg.Host = smart.Host()
	}
	if cfg.BasePath == "" {
		cfg.BasePath = smart.BasePath()
	}
	if cfg.RedisAddr == "" {
		cfg.RedisAddr = smart.RedisAddr()
	}
	if cfg.KnownHosts == "" {
		cfg.KnownHosts = smart.KnownHostsPath()
	}
	if cfg.WorkerID == "" {
		// Short random suffix keeps worker IDs unique yet log-friendly.
		cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
	}
	cfg.Resources.ApplyDefaults()
	// Keep MaxWorkers and Resources.MaxWorkers in sync, whichever was set.
	if cfg.MaxWorkers > 0 {
		cfg.Resources.MaxWorkers = cfg.MaxWorkers
	} else {
		cfg.MaxWorkers = cfg.Resources.MaxWorkers
	}
	if cfg.PollInterval == 0 {
		cfg.PollInterval = smart.PollInterval()
	}
	if cfg.DataManagerPath == "" {
		cfg.DataManagerPath = "./data_manager"
	}
	if cfg.DataDir == "" {
		// Remote data dir only makes sense with a host and auto-fetch enabled.
		if cfg.Host == "" || !cfg.AutoFetchData {
			cfg.DataDir = config.DefaultLocalDataDir
		} else {
			cfg.DataDir = smart.DataDir()
		}
	}
	if cfg.SnapshotStore.Timeout == 0 {
		cfg.SnapshotStore.Timeout = 10 * time.Minute
	}
	if cfg.SnapshotStore.MaxRetries == 0 {
		cfg.SnapshotStore.MaxRetries = 3
	}
	if cfg.Metrics.ListenAddr == "" {
		cfg.Metrics.ListenAddr = ":9100"
	}
	if cfg.MetricsFlushInterval == 0 {
		cfg.MetricsFlushInterval = defaultMetricsFlushInterval
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}

	// Queue backend defaults to Redis; the sqlite backend gets a default
	// database file under DataDir.
	if strings.TrimSpace(cfg.Queue.Backend) == "" {
		cfg.Queue.Backend = string(queue.QueueBackendRedis)
	}
	if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) {
		if strings.TrimSpace(cfg.Queue.SQLitePath) == "" {
			cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db")
		}
		cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath)
	}

	// Infer the GPU vendor when unset: Apple config wins, then any NVIDIA
	// device hints, otherwise "none".
	if strings.TrimSpace(cfg.GPUVendor) == "" {
		if cfg.AppleGPU.Enabled {
			cfg.GPUVendor = string(GPUTypeApple)
		} else if len(cfg.GPUDevices) > 0 ||
			len(cfg.GPUVisibleDevices) > 0 ||
			len(cfg.GPUVisibleDeviceIDs) > 0 {
			cfg.GPUVendor = string(GPUTypeNVIDIA)
		} else {
			cfg.GPUVendor = string(GPUTypeNone)
		}
	}

	// Set lease and retry defaults
	if cfg.TaskLeaseDuration == 0 {
		cfg.TaskLeaseDuration = 30 * time.Minute
	}
	if cfg.HeartbeatInterval == 0 {
		cfg.HeartbeatInterval = 1 * time.Minute
	}
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = 3
	}
	if cfg.GracefulTimeout == 0 {
		cfg.GracefulTimeout = 5 * time.Minute
	}

	return &cfg, nil
}
|
||||
|
||||
// Validate implements config.Validator interface.
|
||||
func (c *Config) Validate() error {
|
||||
if c.Port != 0 {
|
||||
if err := config.ValidatePort(c.Port); err != nil {
|
||||
return fmt.Errorf("invalid SSH port: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.BasePath != "" {
|
||||
// Convert relative paths to absolute
|
||||
c.BasePath = config.ExpandPath(c.BasePath)
|
||||
if !filepath.IsAbs(c.BasePath) {
|
||||
c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
|
||||
}
|
||||
}
|
||||
|
||||
backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
|
||||
if backend == "" {
|
||||
backend = string(queue.QueueBackendRedis)
|
||||
c.Queue.Backend = backend
|
||||
}
|
||||
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) {
|
||||
return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite)
|
||||
}
|
||||
|
||||
if backend == string(queue.QueueBackendSQLite) {
|
||||
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
|
||||
return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite)
|
||||
}
|
||||
c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath)
|
||||
if !filepath.IsAbs(c.Queue.SQLitePath) {
|
||||
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
|
||||
}
|
||||
}
|
||||
|
||||
if c.RedisAddr != "" {
|
||||
addr := strings.TrimSpace(c.RedisAddr)
|
||||
if strings.HasPrefix(addr, "redis://") {
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err)
|
||||
}
|
||||
if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("invalid Redis configuration: invalid redis url")
|
||||
}
|
||||
} else {
|
||||
if err := config.ValidateRedisAddr(addr); err != nil {
|
||||
return fmt.Errorf("invalid Redis configuration: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.MaxWorkers < 1 {
|
||||
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) {
|
||||
case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd":
|
||||
// ok
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"gpu_vendor must be one of %q, %q, %q, %q",
|
||||
string(GPUTypeNVIDIA),
|
||||
"amd",
|
||||
string(GPUTypeApple),
|
||||
string(GPUTypeNone),
|
||||
)
|
||||
}
|
||||
|
||||
// Strict GPU visibility configuration:
|
||||
// - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive.
|
||||
// - UUID-style gpu_visible_device_ids is NVIDIA-only.
|
||||
vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor))
|
||||
if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 {
|
||||
return fmt.Errorf("gpu_visible_devices and gpu_visible_device_ids are mutually exclusive")
|
||||
}
|
||||
if len(c.GPUVisibleDeviceIDs) > 0 {
|
||||
if vendor != string(GPUTypeNVIDIA) {
|
||||
return fmt.Errorf(
|
||||
"gpu_visible_device_ids is only supported when gpu_vendor is %q",
|
||||
string(GPUTypeNVIDIA),
|
||||
)
|
||||
}
|
||||
for _, id := range c.GPUVisibleDeviceIDs {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return fmt.Errorf("gpu_visible_device_ids contains an empty value")
|
||||
}
|
||||
if !strings.HasPrefix(id, "GPU-") {
|
||||
return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) {
|
||||
if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 {
|
||||
return fmt.Errorf(
|
||||
"gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q",
|
||||
vendor,
|
||||
)
|
||||
}
|
||||
}
|
||||
if vendor == "amd" {
|
||||
if len(c.GPUVisibleDeviceIDs) > 0 {
|
||||
return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor)
|
||||
}
|
||||
for _, idx := range c.GPUVisibleDevices {
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("gpu_visible_devices contains negative index %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.SnapshotStore.Enabled {
|
||||
if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" {
|
||||
return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true")
|
||||
}
|
||||
if strings.TrimSpace(c.SnapshotStore.Bucket) == "" {
|
||||
return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true")
|
||||
}
|
||||
ak := strings.TrimSpace(c.SnapshotStore.AccessKey)
|
||||
sk := strings.TrimSpace(c.SnapshotStore.SecretKey)
|
||||
if (ak == "") != (sk == "") {
|
||||
return fmt.Errorf(
|
||||
"snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty",
|
||||
)
|
||||
}
|
||||
if c.SnapshotStore.Timeout < 0 {
|
||||
return fmt.Errorf("snapshot_store.timeout must be >= 0")
|
||||
}
|
||||
if c.SnapshotStore.MaxRetries < 0 {
|
||||
return fmt.Errorf("snapshot_store.max_retries must be >= 0")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
547
internal/worker/core.go
Normal file
547
internal/worker/core.go
Normal file
|
|
@ -0,0 +1,547 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/envpool"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
||||
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// MLServer wraps network.SSHClient for backward compatibility.
// Embedding exposes the underlying client's methods (Exec, Close, ...)
// directly on MLServer without delegation boilerplate.
type MLServer struct {
	*network.SSHClient
}
|
||||
|
||||
// JupyterManager is the subset of the Jupyter service manager used by the worker.
// It exists to keep task execution testable.
type JupyterManager interface {
	// StartService launches a Jupyter service described by req.
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	// StopService stops the service identified by serviceID.
	StopService(ctx context.Context, serviceID string) error
	// RemoveService removes the service; purge requests a fuller cleanup
	// (exact semantics defined by the implementation — confirm there).
	RemoveService(ctx context.Context, serviceID string, purge bool) error
	// RestoreWorkspace restores the named workspace and returns a string
	// result (presumably its path — verify against the implementation).
	RestoreWorkspace(ctx context.Context, name string) (string, error)
	// ListServices returns the services currently known to the manager.
	ListServices() []*jupyter.JupyterService
}
|
||||
|
||||
// isValidName reports whether input is safe to pass as a subprocess
// argument: non-empty, shorter than 256 bytes, free of ASCII control
// characters, whitespace, and DEL, and not starting with '-' (which an
// invoked tool could interpret as a flag).
//
// The previous implementation only checked the length, despite being
// used as the injection guard before launching data_manager; the doc
// comment promised character validation, so this now enforces it.
func isValidName(input string) bool {
	if len(input) == 0 || len(input) >= 256 {
		return false
	}
	if strings.HasPrefix(input, "-") {
		return false
	}
	for _, r := range input {
		// Reject space (0x20), all C0 control characters, and DEL.
		if r < 0x21 || r == 0x7f {
			return false
		}
	}
	return true
}
|
||||
|
||||
// NewMLServer creates a new ML server connection.
|
||||
// NewMLServer returns a new MLServer instance.
|
||||
func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||
if cfg.LocalMode {
|
||||
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
|
||||
}
|
||||
|
||||
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MLServer{SSHClient: client}, nil
|
||||
}
|
||||
|
||||
// Worker represents an ML task worker.
//
// A Worker pulls tasks from the queue backend, executes them against
// the ML server, and exposes metrics. All fields are initialized by
// NewWorker; the zero value is not usable.
type Worker struct {
	id         string             // stable worker identifier (cfg.WorkerID)
	config     *Config            // full worker configuration
	server     *MLServer          // SSH or local connection to the ML server
	queue      queue.Backend      // task queue backend (redis or sqlite)
	resources  *resources.Manager // CPU/GPU token accounting
	running    map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
	runningMu  sync.RWMutex       // guards running
	ctx        context.Context    // worker lifetime context; cancelled by cancel
	cancel     context.CancelFunc // cancels ctx on shutdown
	logger     *logging.Logger
	metrics    *metrics.Metrics // atomic counters sampled by the exporter
	metricsSrv *http.Server     // /metrics server; nil when metrics are disabled

	// Dataset fetch cache: dataset name -> expiry time of its entry.
	// Guarded by datasetCacheMu; entries are valid for datasetCacheTTL.
	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup

	podman           *container.PodmanManager
	jupyter          JupyterManager
	trackingRegistry *tracking.Registry // experiment-tracking plugins (mlflow, tensorboard, ...)
	envPool          *envpool.Pool      // warmed environment image pool

	// Prewarm state: at most one prewarm is in flight at a time, keyed
	// by the head-of-queue task ID. All four fields are guarded by
	// prewarmMu.
	prewarmMu        sync.Mutex
	prewarmTargetID  string
	prewarmCancel    context.CancelFunc
	prewarmStartedAt time.Time
}
|
||||
|
||||
// envInt reads the named environment variable as a decimal integer.
// The second result is false when the variable is unset, blank, or not
// parseable as an int.
func envInt(name string) (int, bool) {
	raw := strings.TrimSpace(os.Getenv(name))
	if raw == "" {
		return 0, false
	}
	if parsed, err := strconv.Atoi(raw); err == nil {
		return parsed, true
	}
	return 0, false
}
|
||||
|
||||
func parseCPUFromConfig(cfg *Config) int {
|
||||
if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
|
||||
return n
|
||||
}
|
||||
if cfg != nil {
|
||||
if cfg.Resources.PodmanCPUs != "" {
|
||||
if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil {
|
||||
if f < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(math.Floor(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
return runtime.NumCPU()
|
||||
}
|
||||
|
||||
func parseGPUCountFromConfig(cfg *Config) int {
|
||||
factory := &GPUDetectorFactory{}
|
||||
detector := factory.CreateDetector(cfg)
|
||||
return detector.DetectGPUCount()
|
||||
}
|
||||
|
||||
func (w *Worker) getGPUDetector() GPUDetector {
|
||||
factory := &GPUDetectorFactory{}
|
||||
return factory.CreateDetector(w.config)
|
||||
}
|
||||
|
||||
func parseGPUSlotsPerGPUFromConfig() int {
|
||||
if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (w *Worker) setupMetricsExporter() {
|
||||
if !w.config.Metrics.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
reg.MustRegister(
|
||||
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||
collectors.NewGoCollector(),
|
||||
)
|
||||
|
||||
labels := prometheus.Labels{"worker_id": w.id}
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_processed_total",
|
||||
Help: "Total tasks processed successfully by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksProcessed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_failed_total",
|
||||
Help: "Total tasks failed by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksFailed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_active",
|
||||
Help: "Number of tasks currently running on this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.runningCount())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_queued",
|
||||
Help: "Latest observed queue depth from Redis.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.QueuedTasks.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_transferred_bytes_total",
|
||||
Help: "Total bytes transferred while fetching datasets.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataTransferred.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_fetch_time_seconds_total",
|
||||
Help: "Total time spent fetching datasets (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_execution_time_seconds_total",
|
||||
Help: "Total execution time for completed tasks (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_hit_total",
|
||||
Help: "Total environment prewarm hits (warmed image already existed).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvHit.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_miss_total",
|
||||
Help: "Total environment prewarm misses (warmed image did not exist yet).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvMiss.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_built_total",
|
||||
Help: "Total environment prewarm images built.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvBuilt.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_env_time_seconds_total",
|
||||
Help: "Total time spent building prewarm images (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_hit_total",
|
||||
Help: "Total prewarmed snapshot hits (snapshots found in .prewarm/).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotHit.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_miss_total",
|
||||
Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotMiss.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_built_total",
|
||||
Help: "Total snapshots prewarmed into .prewarm/.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotBuilt.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_prewarm_snapshot_time_seconds_total",
|
||||
Help: "Total time spent prewarming snapshots (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_worker_max_concurrency",
|
||||
Help: "Configured maximum concurrent tasks for this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.config.MaxWorkers)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_total",
|
||||
Help: "Total CPU tokens managed by the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().TotalCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_free",
|
||||
Help: "Free CPU tokens currently available in the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().FreeCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_total",
|
||||
Help: "Total resource acquisition attempts.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_total",
|
||||
Help: "Total resource acquisitions that had to wait for resources.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireWaitTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_timeout_total",
|
||||
Help: "Total resource acquisition attempts that timed out.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTimeoutTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_seconds_total",
|
||||
Help: "Total seconds spent waiting for resources across all acquisitions.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return w.resources.Snapshot().AcquireWaitSeconds
|
||||
}))
|
||||
|
||||
snap := w.resources.Snapshot()
|
||||
for i := range snap.GPUFree {
|
||||
gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)}
|
||||
idx := i
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_total",
|
||||
Help: "Total GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().SlotsPerGPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_free",
|
||||
Help: "Free GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
s := w.resources.Snapshot()
|
||||
if idx < 0 || idx >= len(s.GPUFree) {
|
||||
return 0
|
||||
}
|
||||
return float64(s.GPUFree[idx])
|
||||
}))
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: w.config.Metrics.ListenAddr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
w.metricsSrv = srv
|
||||
go func() {
|
||||
w.logger.Info("metrics exporter listening",
|
||||
"addr", w.config.Metrics.ListenAddr,
|
||||
"enabled", w.config.Metrics.Enabled)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
w.logger.Warn("metrics exporter stopped",
|
||||
"error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// NewWorker creates a new worker instance.
|
||||
func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||
srv, err := NewMLServer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := srv.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
backendCfg := queue.BackendConfig{
|
||||
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
SQLitePath: cfg.Queue.SQLitePath,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
}
|
||||
queueClient, err := queue.NewBackend(backendCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := queueClient.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create data_dir if it doesn't exist (for production without NAS)
|
||||
if cfg.DataDir != "" {
|
||||
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
|
||||
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
ctx = logging.EnsureTrace(ctx)
|
||||
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
|
||||
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
metricsObj := &metrics.Metrics{}
|
||||
|
||||
podmanMgr, err := container.NewPodmanManager(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create podman manager: %w", err)
|
||||
}
|
||||
|
||||
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
|
||||
}
|
||||
|
||||
trackingRegistry := tracking.NewRegistry(logger)
|
||||
pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
|
||||
|
||||
if len(cfg.Plugins) == 0 {
|
||||
logger.Warn("no plugins configured, defining defaults")
|
||||
// Register defaults manually for backward compatibility/local dev
|
||||
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
trackingplugins.MLflowOptions{
|
||||
ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
trackingRegistry.Register(mlflowPlugin)
|
||||
}
|
||||
|
||||
tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
trackingplugins.TensorBoardOptions{
|
||||
LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
trackingRegistry.Register(tensorboardPlugin)
|
||||
}
|
||||
|
||||
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
|
||||
} else {
|
||||
if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
|
||||
return nil, fmt.Errorf("failed to load plugins: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
server: srv,
|
||||
queue: queueClient,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
metrics: metricsObj,
|
||||
shutdownCh: make(chan struct{}),
|
||||
podman: podmanMgr,
|
||||
jupyter: jupyterMgr,
|
||||
trackingRegistry: trackingRegistry,
|
||||
envPool: envpool.New(""),
|
||||
}
|
||||
|
||||
rm, rmErr := resources.NewManager(resources.Options{
|
||||
TotalCPU: parseCPUFromConfig(cfg),
|
||||
GPUCount: parseGPUCountFromConfig(cfg),
|
||||
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
|
||||
})
|
||||
if rmErr != nil {
|
||||
return nil, fmt.Errorf("failed to init resource manager: %w", rmErr)
|
||||
}
|
||||
worker.resources = rm
|
||||
|
||||
if !cfg.LocalMode {
|
||||
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
|
||||
if cfg.AppleGPU.Enabled {
|
||||
logger.Warn("apple MPS GPU mode is intended for development; do not use in production",
|
||||
"gpu_type", "apple",
|
||||
)
|
||||
}
|
||||
if gpuType == "amd" {
|
||||
logger.Warn("amd GPU mode is intended for development; do not use in production",
|
||||
"gpu_type", "amd",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
worker.setupMetricsExporter()
|
||||
|
||||
// Pre-pull tracking images in background
|
||||
go worker.prePullImages()
|
||||
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
func (w *Worker) prePullImages() {
|
||||
if w.config.LocalMode {
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Info("starting image pre-pulling")
|
||||
|
||||
// Pull worker image
|
||||
if w.config.PodmanImage != "" {
|
||||
w.pullImage(w.config.PodmanImage)
|
||||
}
|
||||
|
||||
// Pull plugin images
|
||||
for name, cfg := range w.config.Plugins {
|
||||
if !cfg.Enabled || cfg.Image == "" {
|
||||
continue
|
||||
}
|
||||
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
||||
w.pullImage(cfg.Image)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) pullImage(image string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
||||
} else {
|
||||
w.logger.Info("image pulled successfully", "image", image)
|
||||
}
|
||||
}
|
||||
824
internal/worker/data_integrity.go
Normal file
824
internal/worker/data_integrity.go
Normal file
|
|
@ -0,0 +1,824 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/errtypes"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// NEW: Fetch datasets using data_manager.
|
||||
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
|
||||
logger := w.logger.Job(ctx, task.JobName, task.ID)
|
||||
logger.Info("fetching datasets",
|
||||
"worker_id", w.id,
|
||||
"dataset_count", len(task.Datasets))
|
||||
|
||||
for _, dataset := range task.Datasets {
|
||||
if w.datasetIsFresh(dataset) {
|
||||
logger.Debug("skipping cached dataset",
|
||||
"dataset", dataset)
|
||||
continue
|
||||
}
|
||||
// Check for cancellation before each dataset fetch
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("fetching dataset",
|
||||
"worker_id", w.id,
|
||||
"dataset", dataset)
|
||||
|
||||
// Create command with context for cancellation support
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
// Validate inputs to prevent command injection
|
||||
if !isValidName(task.JobName) || !isValidName(dataset) {
|
||||
cancel()
|
||||
return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
|
||||
}
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
|
||||
cmd := exec.CommandContext(cmdCtx,
|
||||
w.config.DataManagerPath,
|
||||
"fetch",
|
||||
task.JobName,
|
||||
dataset,
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
cancel() // Clean up context
|
||||
|
||||
if err != nil {
|
||||
return &errtypes.DataFetchError{
|
||||
Dataset: dataset,
|
||||
JobName: task.JobName,
|
||||
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("dataset ready",
|
||||
"worker_id", w.id,
|
||||
"dataset", dataset)
|
||||
w.markDatasetFetched(dataset)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveDatasets(task *queue.Task) []string {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
if len(task.DatasetSpecs) > 0 {
|
||||
out := make([]string, 0, len(task.DatasetSpecs))
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
if ds.Name != "" {
|
||||
out = append(out, ds.Name)
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
if len(task.Datasets) > 0 {
|
||||
return task.Datasets
|
||||
}
|
||||
return parseDatasets(task.Args)
|
||||
}
|
||||
|
||||
// parseDatasets extracts the comma-separated value following the first
// "--datasets" flag in args. Returns nil when the flag is absent or is
// the last token (no value).
func parseDatasets(args string) []string {
	if !strings.Contains(args, "--datasets") {
		return nil
	}

	fields := strings.Fields(args)
	for idx := 0; idx+1 < len(fields); idx++ {
		if fields[idx] == "--datasets" {
			return strings.Split(fields[idx+1], ",")
		}
	}

	return nil
}
|
||||
|
||||
func (w *Worker) datasetIsFresh(dataset string) bool {
|
||||
w.datasetCacheMu.RLock()
|
||||
defer w.datasetCacheMu.RUnlock()
|
||||
expires, ok := w.datasetCache[dataset]
|
||||
return ok && time.Now().Before(expires)
|
||||
}
|
||||
|
||||
func (w *Worker) markDatasetFetched(dataset string) {
|
||||
expires := time.Now().Add(w.datasetCacheTTL)
|
||||
w.datasetCacheMu.Lock()
|
||||
w.datasetCache[dataset] = expires
|
||||
w.datasetCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Worker) cancelPrewarmLocked() {
|
||||
if w.prewarmCancel != nil {
|
||||
w.prewarmCancel()
|
||||
w.prewarmCancel = nil
|
||||
}
|
||||
w.prewarmTargetID = ""
|
||||
}
|
||||
|
||||
// prewarmNextLoop periodically prewarms the next queued task via
// PrewarmNextOnce, so a task's snapshot/datasets/environment are ready
// before it is dequeued. Runs until w.ctx is cancelled. Purely
// best-effort: errors are logged and never affect task execution.
func (w *Worker) prewarmNextLoop() {
	// Defensive guards: tolerate a partially constructed worker or a
	// configuration where prewarming is disabled.
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.ctx == nil || w.queue == nil || w.metrics == nil {
		return
	}
	// Phase 1: Best-effort prewarm of the next queued task.
	// This must never be required for correctness.
	runOnce := func() {
		_, err := w.PrewarmNextOnce(w.ctx)
		if err != nil {
			w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
		}
	}

	// Run once immediately so prewarm doesn't lag behind the worker loop.
	runOnce()

	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-w.ctx.Done():
			// Shutdown: cancel any in-flight prewarm before exiting.
			w.prewarmMu.Lock()
			w.cancelPrewarmLocked()
			w.prewarmMu.Unlock()
			return
		case <-ticker.C:
		}
		runOnce()
	}
}
|
||||
|
||||
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
|
||||
return false, nil
|
||||
}
|
||||
if ctx == nil || w.queue == nil || w.metrics == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
next, err := w.queue.PeekNextTask()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if next == nil {
|
||||
w.prewarmMu.Lock()
|
||||
w.cancelPrewarmLocked()
|
||||
w.prewarmMu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return w.prewarmTaskOnce(ctx, next)
|
||||
}
|
||||
|
||||
// prewarmTaskOnce best-effort prepares resources for the given
// head-of-queue task: stages its snapshot into .prewarm/snapshots/,
// pre-fetches its datasets, and checks for a warmed environment image.
// It returns (true, err) when a prewarm was attempted and (false, nil)
// when there was nothing to do, prewarming is disabled, or this task is
// already the prewarm target. Progress is published to the queue via
// SetWorkerPrewarmState (errors deliberately ignored — the state is
// advisory).
func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
	// Defensive guards: tolerate a partially constructed worker.
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}
	if next == nil {
		return false, nil
	}

	w.prewarmMu.Lock()
	// Already targeting this exact task: nothing new to do.
	if w.prewarmTargetID == next.ID {
		w.prewarmMu.Unlock()
		return false, nil
	}
	// Head of queue changed: abort the previous prewarm and retarget.
	// The cancel func is stored so a later retarget/shutdown can stop
	// this attempt mid-flight.
	w.cancelPrewarmLocked()
	prewarmCtx, cancel := context.WithCancel(ctx)
	w.prewarmCancel = cancel
	w.prewarmTargetID = next.ID
	w.prewarmStartedAt = time.Now()
	startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)
	// Report the phase of the dominant work item: snapshot staging wins
	// over dataset fetching, which wins over env-image checks.
	phase := "datasets"
	dsCnt := len(resolveDatasets(next))
	snapID := next.SnapshotID
	if strings.TrimSpace(snapID) != "" {
		phase = "snapshot"
	} else if dsCnt == 0 {
		phase = "env"
	}
	// Publish the "started" state; best-effort, error ignored.
	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     next.ID,
		SnapshotID: snapID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      phase,
		DatasetCnt: dsCnt,
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})
	w.prewarmMu.Unlock()

	w.logger.Info("prewarm started",
		"worker_id", w.id,
		"task_id", next.ID,
		"snapshot_id", snapID,
		"phase", phase,
	)

	// Work on a copy so dataset resolution doesn't mutate the queued task.
	local := *next
	local.Datasets = resolveDatasets(&local)

	hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
	hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
	hasEnv := false
	// Env prewarm only applies when running in containers with a deps
	// manifest and commit recorded in the task metadata, and a warm
	// image tag already known to the pool.
	if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
		if local.Metadata != nil {
			depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
			commitID := strings.TrimSpace(local.Metadata["commit_id"])
			if depsSHA != "" && commitID != "" {
				expMgr := experiment.NewManager(w.config.BasePath)
				hostWorkspace := expMgr.GetFilesPath(commitID)
				if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
					if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
						hasEnv = true
					}
				}
			}
		}
	}
	if !hasSnapshot && !hasDatasets && !hasEnv {
		// Nothing to prewarm for this task: clear any advisory state.
		_ = w.queue.ClearWorkerPrewarmState(w.id)
		return false, nil
	}

	if hasSnapshot {
		want := ""
		if local.Metadata != nil {
			want = local.Metadata["snapshot_sha256"]
		}
		start := time.Now()
		src, err := ResolveSnapshot(
			prewarmCtx,
			w.config.DataDir,
			&w.config.SnapshotStore,
			local.SnapshotID,
			want,
			nil,
		)
		if err != nil {
			return true, err
		}
		// Stage a private copy under .prewarm/, replacing any stale one.
		dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
		_ = os.RemoveAll(dst)
		if err := copyDir(src, dst); err != nil {
			return true, err
		}
		w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
	}

	if hasDatasets {
		if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
			return true, err
		}
	}

	// Publish the "ready" state; best-effort, error ignored.
	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     local.ID,
		SnapshotID: local.SnapshotID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      "ready",
		DatasetCnt: len(local.Datasets),
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})

	w.logger.Info("prewarm ready",
		"worker_id", w.id,
		"task_id", local.ID,
		"snapshot_id", local.SnapshotID,
	)

	return true, nil
}
|
||||
|
||||
func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if task.SnapshotID == "" {
|
||||
return nil
|
||||
}
|
||||
if err := container.ValidateJobName(task.SnapshotID); err != nil {
|
||||
return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err)
|
||||
}
|
||||
if task.Metadata == nil {
|
||||
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
|
||||
}
|
||||
want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err)
|
||||
}
|
||||
if want == "" {
|
||||
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
|
||||
}
|
||||
path, err := ResolveSnapshot(
|
||||
ctx,
|
||||
w.config.DataDir,
|
||||
&w.config.SnapshotStore,
|
||||
task.SnapshotID,
|
||||
want,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
|
||||
}
|
||||
got, err := dirOverallSHA256Hex(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err)
|
||||
}
|
||||
if got != want {
|
||||
return fmt.Errorf(
|
||||
"snapshot %q: checksum mismatch: expected %s, got %s",
|
||||
task.SnapshotID,
|
||||
want,
|
||||
got,
|
||||
)
|
||||
}
|
||||
w.logger.Job(
|
||||
ctx,
|
||||
task.JobName,
|
||||
task.ID,
|
||||
).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// fileSHA256Hex returns the lowercase hex SHA-256 digest of the file at path.
func fileSHA256Hex(path string) (string, error) {
	f, err := os.Open(filepath.Clean(path))
	if err != nil {
		return "", err
	}
	// Best-effort close; the file is only read, so the close error is ignorable.
	defer func() { _ = f.Close() }()

	digest := sha256.New()
	if _, err := io.Copy(digest, f); err != nil {
		return "", err
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}
|
||||
|
||||
// normalizeSHA256ChecksumHex canonicalizes a user-supplied sha256 checksum:
// surrounding whitespace and an optional "sha256:"/"SHA256:" prefix are
// stripped, the remainder must be exactly 64 hex characters, and the result
// is returned lowercased. An empty (post-trim) input yields ("", nil) so
// callers can treat "no checksum" as a distinct, non-error case.
func normalizeSHA256ChecksumHex(checksum string) (string, error) {
	c := strings.TrimSpace(checksum)
	c = strings.TrimPrefix(c, "sha256:")
	c = strings.TrimPrefix(c, "SHA256:")
	c = strings.TrimSpace(c)

	switch {
	case c == "":
		return "", nil
	case len(c) != 64:
		return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(c))
	}
	if _, err := hex.DecodeString(c); err != nil {
		return "", fmt.Errorf("invalid sha256 hex: %w", err)
	}
	return strings.ToLower(c), nil
}
|
||||
|
||||
func dirOverallSHA256Hex(root string) (string, error) {
|
||||
root = filepath.Clean(root)
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory")
|
||||
}
|
||||
|
||||
var files []string
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, rel)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Deterministic order.
|
||||
for i := 0; i < len(files); i++ {
|
||||
for j := i + 1; j < len(files); j++ {
|
||||
if files[i] > files[j] {
|
||||
files[i], files[j] = files[j], files[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hash file hashes to avoid holding all bytes.
|
||||
overall := sha256.New()
|
||||
for _, rel := range files {
|
||||
p := filepath.Join(root, rel)
|
||||
sum, err := fileSHA256Hex(p)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
overall.Write([]byte(sum))
|
||||
}
|
||||
return fmt.Sprintf("%x", overall.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if len(task.DatasetSpecs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger := w.logger.Job(ctx, task.JobName, task.ID)
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
want, err := normalizeSHA256ChecksumHex(ds.Checksum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
|
||||
}
|
||||
if want == "" {
|
||||
continue
|
||||
}
|
||||
if err := container.ValidateJobName(ds.Name); err != nil {
|
||||
return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
|
||||
}
|
||||
path := filepath.Join(w.config.DataDir, ds.Name)
|
||||
got, err := dirOverallSHA256Hex(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
|
||||
}
|
||||
if got != want {
|
||||
return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
|
||||
}
|
||||
logger.Info("dataset checksum verified", "dataset", ds.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeTaskProvenance derives provenance metadata for a task from its
// fields and from on-disk experiment artifacts under basePath.
//
// Keys that may be populated:
//   - "snapshot_id"                      — copied from task.SnapshotID
//   - "datasets"                         — comma-joined resolved dataset names
//   - "dataset_specs"                    — JSON-encoded task.DatasetSpecs
//   - "experiment_manifest_overall_sha"  — from the commit's manifest, if readable
//   - "deps_manifest_name"/"deps_manifest_sha256" — dependency manifest identity
//
// The manifest-derived keys require a "commit_id" metadata entry; when it is
// absent the function returns early with whatever task-derived keys it has.
// Manifest/dependency lookups are best-effort: their errors are swallowed and
// the corresponding keys simply stay unset.
func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	out := map[string]string{}

	if task.SnapshotID != "" {
		out["snapshot_id"] = task.SnapshotID
	}

	datasets := resolveDatasets(task)
	if len(datasets) > 0 {
		out["datasets"] = strings.Join(datasets, ",")
	}
	if len(task.DatasetSpecs) > 0 {
		b, err := json.Marshal(task.DatasetSpecs)
		if err != nil {
			return nil, fmt.Errorf("marshal dataset_specs: %w", err)
		}
		out["dataset_specs"] = string(b)
	}

	// Everything below needs a commit to look up on disk.
	if task.Metadata == nil {
		return out, nil
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return out, nil
	}

	expMgr := experiment.NewManager(basePath)
	// Best-effort: a missing/unreadable manifest just leaves the key unset.
	manifest, err := expMgr.ReadManifest(commitID)
	if err == nil && manifest != nil && manifest.OverallSHA != "" {
		out["experiment_manifest_overall_sha"] = manifest.OverallSHA
	}

	filesPath := expMgr.GetFilesPath(commitID)
	depName, err := selectDependencyManifest(filesPath)
	if err == nil && depName != "" {
		depPath := filepath.Join(filesPath, depName)
		// Best-effort: hashing failure leaves both deps_* keys unset.
		sha, err := fileSHA256Hex(depPath)
		if err == nil && sha != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}

	return out, nil
}
|
||||
|
||||
func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
prov, err := computeTaskProvenance(w.config.BasePath, task)
|
||||
if err != nil {
|
||||
w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
|
||||
return
|
||||
}
|
||||
if len(prov) == 0 {
|
||||
return
|
||||
}
|
||||
if task.Metadata == nil {
|
||||
task.Metadata = map[string]string{}
|
||||
}
|
||||
for k, v := range prov {
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
// Phase 1: best-effort only; do not error if overwriting.
|
||||
task.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// enforceTaskProvenance verifies that the provenance metadata recorded on a
// task still matches what can be computed from the worker's local state.
//
// Required fields: experiment_manifest_overall_sha, deps_manifest_name,
// deps_manifest_sha256, and — when the task references a snapshot —
// snapshot_sha256. In strict mode any missing, uncomputable, or mismatching
// field is an error. When ProvenanceBestEffort is set, problems are logged as
// warnings instead, and recomputed values are written back into
// task.Metadata (snapshot hashes re-gain their "sha256:" prefix on write).
func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.Metadata == nil {
		return fmt.Errorf("missing task metadata")
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id")
	}

	// Current provenance as computed from local artifacts (manifest, deps).
	current, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		return err
	}

	// snapshotCur is the locally computed snapshot digest, "" when unavailable.
	snapshotCur := ""
	if task.SnapshotID != "" {
		want := ""
		if task.Metadata != nil {
			want = task.Metadata["snapshot_sha256"]
		}
		wantNorm, nerr := normalizeSHA256ChecksumHex(want)
		if nerr != nil {
			// Invalid declared checksum: fatal in strict mode, warn otherwise.
			if w.config != nil && w.config.ProvenanceBestEffort {
				w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
					"snapshot_id", task.SnapshotID,
					"error", nerr)
			} else {
				return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
			}
		} else if wantNorm != "" {
			// Resolve the snapshot locally, then hash the resolved directory.
			resolved, err := ResolveSnapshot(
				ctx, w.config.DataDir,
				&w.config.SnapshotStore,
				task.SnapshotID,
				wantNorm,
				nil,
			)
			if err != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID,
						"error", err)
				} else {
					return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
				}
			} else {
				sha, err := dirOverallSHA256Hex(resolved)
				if err == nil {
					snapshotCur = sha
				} else if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID,
						"error", err)
				} else {
					return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
				}
			}
		}
		if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
			// Best-effort fallback: if the caller didn't provide snapshot_sha256,
			// compute from the local snapshot directory if it exists.
			localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
			if sha, err := dirOverallSHA256Hex(localPath); err == nil {
				snapshotCur = sha
			}
		}
	}

	logger := w.logger.Job(ctx, task.JobName, task.ID)

	// requiredField pairs a metadata key with its locally computed value.
	type requiredField struct {
		Key string
		Cur string
	}
	required := []requiredField{
		{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
		{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
		{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
	}
	if task.SnapshotID != "" {
		required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
	}

	for _, f := range required {
		want := strings.TrimSpace(task.Metadata[f.Key])
		if f.Key == "snapshot_sha256" {
			// Snapshot checksums must be normalized before comparison.
			norm, nerr := normalizeSHA256ChecksumHex(want)
			if nerr != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
						"snapshot_id", task.SnapshotID,
						"error", nerr)
					want = ""
				} else {
					return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
				}
			} else {
				want = norm
			}
		}
		if want == "" {
			// Field missing from task metadata.
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("missing provenance field; continuing due to best-effort mode",
					"field", f.Key)
				if f.Cur != "" {
					// Backfill from the locally computed value.
					if f.Key == "snapshot_sha256" {
						task.Metadata[f.Key] = "sha256:" + f.Cur
					} else {
						task.Metadata[f.Key] = f.Cur
					}
				}
				continue
			}
			return fmt.Errorf("missing provenance field: %s", f.Key)
		}
		if f.Cur == "" {
			// Field declared on the task but not computable locally.
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("unable to compute provenance field; continuing due to best-effort mode",
					"field", f.Key)
				continue
			}
			return fmt.Errorf("unable to compute provenance field: %s", f.Key)
		}
		if want != f.Cur {
			// Declared and computed values disagree.
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("provenance mismatch; continuing due to best-effort mode",
					"field", f.Key,
					"expected", want,
					"current", f.Cur)
				// Overwrite with the locally computed value and continue.
				if f.Key == "snapshot_sha256" {
					task.Metadata[f.Key] = "sha256:" + f.Cur
				} else {
					task.Metadata[f.Key] = f.Cur
				}
				continue
			}
			return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
		}
	}

	return nil
}
|
||||
|
||||
// selectDependencyManifest picks the dependency manifest to use from
// filesPath, trying conda environment files first, then Poetry, then plain
// pip requirements. A poetry.lock without an accompanying pyproject.toml is
// rejected. An error is returned when no supported manifest exists.
func selectDependencyManifest(filesPath string) (string, error) {
	if filesPath == "" {
		return "", fmt.Errorf("missing files path")
	}
	// Priority order matters: first match wins.
	for _, name := range []string{
		"environment.yml",
		"environment.yaml",
		"poetry.lock",
		"pyproject.toml",
		"requirements.txt",
	} {
		if _, err := os.Stat(filepath.Join(filesPath, name)); err != nil {
			continue
		}
		if name == "poetry.lock" {
			// Poetry projects are only usable with their pyproject.toml.
			if _, err := os.Stat(filepath.Join(filesPath, "pyproject.toml")); err != nil {
				return "", fmt.Errorf(
					"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
			}
		}
		return name, nil
	}
	return "", fmt.Errorf(
		"missing dependency manifest (supported: environment.yml, environment.yaml, " +
			"poetry.lock, pyproject.toml, requirements.txt)")
}
|
||||
|
||||
// Exported wrappers for tests under tests/.
|
||||
|
||||
// ResolveDatasets is an exported wrapper over resolveDatasets for tests under tests/.
func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }
|
||||
|
||||
// SelectDependencyManifest is an exported wrapper over selectDependencyManifest for tests under tests/.
func SelectDependencyManifest(filesPath string) (string, error) {
	return selectDependencyManifest(filesPath)
}
|
||||
|
||||
// NormalizeSHA256ChecksumHex is an exported wrapper over normalizeSHA256ChecksumHex for tests under tests/.
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
	return normalizeSHA256ChecksumHex(checksum)
}
|
||||
|
||||
// DirOverallSHA256Hex is an exported wrapper over dirOverallSHA256Hex for tests under tests/.
func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }
|
||||
|
||||
// ComputeTaskProvenance is an exported wrapper over computeTaskProvenance for tests under tests/.
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	return computeTaskProvenance(basePath, task)
}
|
||||
|
||||
// EnforceTaskProvenance is an exported wrapper over enforceTaskProvenance for tests under tests/.
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	return w.enforceTaskProvenance(ctx, task)
}
|
||||
|
||||
// VerifyDatasetSpecs is an exported wrapper over verifyDatasetSpecs for tests under tests/.
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	return w.verifyDatasetSpecs(ctx, task)
}
|
||||
|
||||
// VerifySnapshot is an exported wrapper over verifySnapshot for tests under tests/.
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
	return w.verifySnapshot(ctx, task)
}
|
||||
|
||||
func NewTestWorker(cfg *Config) *Worker {
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
ctx := logging.EnsureTrace(context.Background())
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
if cfg == nil {
|
||||
cfg = &Config{}
|
||||
}
|
||||
if cfg.DatasetCacheTTL == 0 {
|
||||
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
||||
}
|
||||
return &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
ctx := logging.EnsureTrace(context.Background())
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
if cfg == nil {
|
||||
cfg = &Config{}
|
||||
}
|
||||
if cfg.DatasetCacheTTL == 0 {
|
||||
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
||||
}
|
||||
return &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
queue: tq,
|
||||
metrics: &metrics.Metrics{},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestWorkerWithJupyter builds a test Worker via NewTestWorkerWithQueue and
// attaches the given Jupyter manager.
func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
	w := NewTestWorkerWithQueue(cfg, tq)
	w.jupyter = jm
	return w
}
|
||||
1029
internal/worker/execution.go
Normal file
1029
internal/worker/execution.go
Normal file
File diff suppressed because it is too large
Load diff
168
internal/worker/gpu_detector.go
Normal file
168
internal/worker/gpu_detector.go
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GPUType represents different GPU types
type GPUType string

const (
	// GPUTypeNVIDIA selects NVIDIA device detection.
	GPUTypeNVIDIA GPUType = "nvidia"
	// GPUTypeApple selects Apple silicon (Metal/MPS) detection.
	GPUTypeApple GPUType = "apple"
	// GPUTypeNone disables GPU detection entirely.
	GPUTypeNone GPUType = "none"
)
|
||||
|
||||
// GPUDetector interface for detecting GPU availability
type GPUDetector interface {
	// DetectGPUCount reports the number of usable GPUs (0 when none).
	DetectGPUCount() int
	// GetGPUType identifies the vendor family this detector handles.
	GetGPUType() GPUType
	// GetDevicePaths lists the host device nodes this detector would expose.
	GetDevicePaths() []string
}
|
||||
|
||||
// NVIDIA GPUDetector implementation
|
||||
type NVIDIADetector struct{}
|
||||
|
||||
func (d *NVIDIADetector) DetectGPUCount() int {
|
||||
if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
|
||||
return n
|
||||
}
|
||||
// Could use nvidia-sml or other detection methods here
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d *NVIDIADetector) GetGPUType() GPUType {
|
||||
return GPUTypeNVIDIA
|
||||
}
|
||||
|
||||
func (d *NVIDIADetector) GetDevicePaths() []string {
|
||||
// Prefer standard NVIDIA device nodes when present.
|
||||
patterns := []string{
|
||||
"/dev/nvidiactl",
|
||||
"/dev/nvidia-modeset",
|
||||
"/dev/nvidia-uvm",
|
||||
"/dev/nvidia-uvm-tools",
|
||||
"/dev/nvidia*",
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, 8)
|
||||
for _, pat := range patterns {
|
||||
if filepath.Base(pat) == pat {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(pat, "*") {
|
||||
matches, _ := filepath.Glob(pat)
|
||||
for _, m := range matches {
|
||||
if _, ok := seen[m]; ok {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(m); err != nil {
|
||||
continue
|
||||
}
|
||||
seen[m] = struct{}{}
|
||||
out = append(out, m)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[pat]; ok {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(pat); err != nil {
|
||||
continue
|
||||
}
|
||||
seen[pat] = struct{}{}
|
||||
out = append(out, pat)
|
||||
}
|
||||
// Fallback for non-NVIDIA setups where only generic DRM device exists.
|
||||
if len(out) == 0 {
|
||||
if _, err := os.Stat("/dev/dri"); err == nil {
|
||||
out = append(out, "/dev/dri")
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Apple M-series GPUDetector implementation
|
||||
type AppleDetector struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (d *AppleDetector) DetectGPUCount() int {
|
||||
if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
|
||||
return n
|
||||
}
|
||||
if d.enabled {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d *AppleDetector) GetGPUType() GPUType {
|
||||
return GPUTypeApple
|
||||
}
|
||||
|
||||
func (d *AppleDetector) GetDevicePaths() []string {
|
||||
return []string{"/dev/metal", "/dev/mps"}
|
||||
}
|
||||
|
||||
// None GPUDetector implementation
type NoneDetector struct{}

// DetectGPUCount always reports zero GPUs.
func (d *NoneDetector) DetectGPUCount() int {
	return 0
}

// GetGPUType identifies this detector as the no-GPU implementation.
func (d *NoneDetector) GetGPUType() GPUType {
	return GPUTypeNone
}

// GetDevicePaths returns no device nodes.
func (d *NoneDetector) GetDevicePaths() []string {
	return nil
}
|
||||
|
||||
// GPUDetectorFactory creates appropriate GPU detector based on config
|
||||
type GPUDetectorFactory struct{}
|
||||
|
||||
func (f *GPUDetectorFactory) CreateDetector(cfg *Config) GPUDetector {
|
||||
// Check for explicit environment override
|
||||
if gpuType := os.Getenv("FETCH_ML_GPU_TYPE"); gpuType != "" {
|
||||
switch gpuType {
|
||||
case string(GPUTypeNVIDIA):
|
||||
return &NVIDIADetector{}
|
||||
case string(GPUTypeApple):
|
||||
return &AppleDetector{enabled: true}
|
||||
case string(GPUTypeNone):
|
||||
return &NoneDetector{}
|
||||
}
|
||||
}
|
||||
|
||||
// Respect configured vendor when explicitly set.
|
||||
if cfg != nil {
|
||||
switch GPUType(cfg.GPUVendor) {
|
||||
case GPUTypeApple:
|
||||
return &AppleDetector{enabled: cfg.AppleGPU.Enabled}
|
||||
case GPUTypeNone:
|
||||
return &NoneDetector{}
|
||||
case GPUTypeNVIDIA:
|
||||
return &NVIDIADetector{}
|
||||
case "amd":
|
||||
// AMD uses similar device exposure patterns in this codebase.
|
||||
return &NVIDIADetector{}
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect based on config
|
||||
if cfg != nil {
|
||||
if cfg.AppleGPU.Enabled {
|
||||
return &AppleDetector{enabled: true}
|
||||
}
|
||||
if len(cfg.GPUDevices) > 0 {
|
||||
return &NVIDIADetector{}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to no GPU
|
||||
return &NoneDetector{}
|
||||
}
|
||||
130
internal/worker/jupyter_task.go
Normal file
130
internal/worker/jupyter_task.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// Metadata keys and values used to encode Jupyter control operations on
// queue tasks.
const (
	// jupyterTaskTypeKey/-Value mark a task as a Jupyter control task.
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter"
	// jupyterTaskActionKey selects one of the action values below.
	jupyterTaskActionKey = "jupyter_action"
	jupyterActionStart   = "start"
	jupyterActionStop    = "stop"
	jupyterActionRemove  = "remove"
	jupyterActionRestore = "restore"
	jupyterActionList    = "list"
	// Per-action arguments carried in task metadata.
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"
	// jupyterTaskOutputType tags the JSON payload returned by these tasks.
	jupyterTaskOutputType = "jupyter_output"
)
|
||||
|
||||
// jupyterTaskOutput is the JSON payload produced by a Jupyter control task.
type jupyterTaskOutput struct {
	// Type is always jupyterTaskOutputType so consumers can dispatch on it.
	Type string `json:"type"`
	// Service is populated by the "start" action.
	Service *jupyter.JupyterService `json:"service,omitempty"`
	// Services is populated by the "list" action.
	// NOTE(review): no omitempty, so this encodes as null for other actions —
	// confirm consumers accept that.
	Services []*jupyter.JupyterService `json:"services"`
	// RestorePath is populated by the "restore" action.
	RestorePath string `json:"restore_path,omitempty"`
}
|
||||
|
||||
func isJupyterTask(task *queue.Task) bool {
|
||||
if task == nil || task.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue
|
||||
}
|
||||
|
||||
// runJupyterTask executes a Jupyter control task (start/stop/remove/restore/
// list) against the worker's JupyterManager and returns the JSON-encoded
// jupyterTaskOutput payload. Every action runs under a 2-minute timeout.
// It returns an error when the worker/task/manager/metadata are missing, when
// the action or its required arguments are absent, or when the manager call
// fails.
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	if w == nil {
		return nil, fmt.Errorf("worker is nil")
	}
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	if w.jupyter == nil {
		return nil, fmt.Errorf("jupyter manager not configured")
	}
	if task.Metadata == nil {
		return nil, fmt.Errorf("missing task metadata")
	}

	action := strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey]))
	if action == "" {
		return nil, fmt.Errorf("missing jupyter action")
	}

	// Validate job name since it is used as the task status key and shows up in logs.
	if err := container.ValidateJobName(task.JobName); err != nil {
		return nil, err
	}

	// Bound every manager call so a stuck container operation cannot wedge the task.
	ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
	defer cancel()

	switch action {
	case jupyterActionStart:
		// "start" requires both a service name and a workspace path.
		name := strings.TrimSpace(task.Metadata[jupyterNameKey])
		ws := strings.TrimSpace(task.Metadata[jupyterWorkspaceKey])
		if name == "" {
			return nil, fmt.Errorf("missing jupyter name")
		}
		if ws == "" {
			return nil, fmt.Errorf("missing jupyter workspace")
		}
		service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
		if err != nil {
			return nil, err
		}
		out := jupyterTaskOutput{Type: jupyterTaskOutputType, Service: service}
		return json.Marshal(out)
	case jupyterActionStop:
		serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter service id")
		}
		if err := w.jupyter.StopService(ctx, serviceID); err != nil {
			return nil, err
		}
		out := jupyterTaskOutput{Type: jupyterTaskOutputType}
		return json.Marshal(out)
	case jupyterActionRemove:
		serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter service id")
		}
		// Optional purge of persisted state when jupyter_purge is
		// (case-insensitively) "true".
		purge := strings.EqualFold(strings.TrimSpace(task.Metadata["jupyter_purge"]), "true")
		if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
			return nil, err
		}
		out := jupyterTaskOutput{Type: jupyterTaskOutputType}
		return json.Marshal(out)
	case jupyterActionList:
		services := w.jupyter.ListServices()
		out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
		return json.Marshal(out)
	case jupyterActionRestore:
		name := strings.TrimSpace(task.Metadata[jupyterNameKey])
		if name == "" {
			return nil, fmt.Errorf("missing jupyter name")
		}
		restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
		if err != nil {
			return nil, err
		}
		out := jupyterTaskOutput{Type: jupyterTaskOutputType, RestorePath: restoredPath}
		return json.Marshal(out)
	default:
		return nil, fmt.Errorf("invalid jupyter action: %s", action)
	}
}
|
||||
|
||||
// RunJupyterTask is an exported wrapper over runJupyterTask for tests under tests/.
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	return w.runJupyterTask(ctx, task)
}
|
||||
525
internal/worker/runloop.go
Normal file
525
internal/worker/runloop.go
Normal file
|
|
@ -0,0 +1,525 @@
|
|||
package worker
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
)
|
||||
|
||||
// Start starts the worker's main processing loop.
|
||||
func (w *Worker) Start() {
|
||||
w.logger.Info("worker started",
|
||||
"worker_id", w.id,
|
||||
"max_concurrent", w.config.MaxWorkers,
|
||||
"poll_interval", w.config.PollInterval)
|
||||
|
||||
go w.heartbeat()
|
||||
if w.config != nil && w.config.PrewarmEnabled {
|
||||
go w.prewarmNextLoop()
|
||||
go w.prewarmImageGCLoop()
|
||||
}
|
||||
|
||||
for {
|
||||
switch {
|
||||
case w.ctx.Err() != nil:
|
||||
w.logger.Info("shutdown signal received, waiting for tasks")
|
||||
w.waitForTasks()
|
||||
return
|
||||
default:
|
||||
}
|
||||
if w.runningCount() >= w.config.MaxWorkers {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
queueStart := time.Now()
|
||||
blockTimeout := time.Duration(w.config.PollInterval) * time.Second
|
||||
task, err := w.queue.GetNextTaskWithLeaseBlocking(
|
||||
w.config.WorkerID,
|
||||
w.config.TaskLeaseDuration,
|
||||
blockTimeout,
|
||||
)
|
||||
queueLatency := time.Since(queueStart)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
continue
|
||||
}
|
||||
w.logger.Error("error fetching task",
|
||||
"worker_id", w.id,
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if task == nil {
|
||||
if queueLatency > 200*time.Millisecond {
|
||||
w.logger.Debug("queue poll latency",
|
||||
"latency_ms", queueLatency.Milliseconds())
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if depth, derr := w.queue.QueueDepth(); derr == nil {
|
||||
if queueLatency > 100*time.Millisecond || depth > 0 {
|
||||
w.logger.Debug("queue fetch metrics",
|
||||
"latency_ms", queueLatency.Milliseconds(),
|
||||
"remaining_depth", depth)
|
||||
}
|
||||
} else if queueLatency > 100*time.Millisecond {
|
||||
w.logger.Debug("queue fetch metrics",
|
||||
"latency_ms", queueLatency.Milliseconds(),
|
||||
"depth_error", derr)
|
||||
}
|
||||
|
||||
// Reserve a running slot *before* starting the goroutine so we don't drain
|
||||
// the entire queue while max_workers is 1.
|
||||
w.reserveRunningSlot(task.ID)
|
||||
go w.executeTaskWithLease(task)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) reserveRunningSlot(taskID string) {
|
||||
w.runningMu.Lock()
|
||||
defer w.runningMu.Unlock()
|
||||
if w.running == nil {
|
||||
w.running = make(map[string]context.CancelFunc)
|
||||
}
|
||||
// Track a cancel func for future shutdown handling; currently best-effort.
|
||||
_, cancel := context.WithCancel(w.ctx)
|
||||
w.running[taskID] = cancel
|
||||
}
|
||||
|
||||
func (w *Worker) releaseRunningSlot(taskID string) {
|
||||
w.runningMu.Lock()
|
||||
defer w.runningMu.Unlock()
|
||||
if w.running == nil {
|
||||
return
|
||||
}
|
||||
if cancel, ok := w.running[taskID]; ok {
|
||||
cancel()
|
||||
delete(w.running, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// prewarmImageGCLoop periodically prunes prewarm-built container images.
// It runs only when prewarm is enabled, an environment pool exists, the
// worker is not in local mode, and a Podman image is configured. Every five
// minutes it either honors an externally published GC request from the queue
// (deduplicated by request value) or falls back to a routine prune of images
// older than 24 hours. The loop exits when the worker context is cancelled.
func (w *Worker) prewarmImageGCLoop() {
	if w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.envPool == nil {
		return
	}
	if w.config.LocalMode {
		return
	}
	if strings.TrimSpace(w.config.PodmanImage) == "" {
		return
	}

	// lastSeen deduplicates externally requested GC runs by request value.
	lastSeen := ""
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	runGC := func() {
		// Each prune gets its own bounded context, independent of w.ctx.
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
		defer cancel()
		// Best-effort prune; errors are intentionally ignored.
		_ = w.envPool.PruneImages(ctx, 24*time.Hour)
	}

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if w.queue != nil {
				// An unseen external GC request value triggers a run and
				// suppresses the routine fallback for this tick.
				v, err := w.queue.PrewarmGCRequestValue()
				if err == nil && v != "" && v != lastSeen {
					lastSeen = v
					runGC()
					continue
				}
			}
			runGC()
		}
	}
}
|
||||
|
||||
func (w *Worker) heartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.queue.Heartbeat(w.id); err != nil {
|
||||
w.logger.Warn("heartbeat failed",
|
||||
"worker_id", w.id,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) waitForTasks() {
|
||||
timeout := time.After(5 * time.Minute)
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
w.logger.Warn("shutdown timeout, force stopping",
|
||||
"running_tasks", len(w.running))
|
||||
return
|
||||
case <-ticker.C:
|
||||
count := w.runningCount()
|
||||
if count == 0 {
|
||||
w.logger.Info("all tasks completed, shutting down")
|
||||
return
|
||||
}
|
||||
w.logger.Debug("waiting for tasks to complete",
|
||||
"remaining", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) runningCount() int {
|
||||
w.runningMu.RLock()
|
||||
defer w.runningMu.RUnlock()
|
||||
return len(w.running)
|
||||
}
|
||||
|
||||
// GetMetrics returns current worker metrics.
|
||||
func (w *Worker) GetMetrics() map[string]any {
|
||||
stats := w.metrics.GetStats()
|
||||
stats["worker_id"] = w.id
|
||||
stats["max_workers"] = w.config.MaxWorkers
|
||||
return stats
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the worker: it cancels the run context, waits
// for in-flight tasks (bounded by waitForTasks' internal timeout), then
// closes the server and queue connections and shuts down the metrics
// exporter with a 5-second deadline.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	// Close errors are non-fatal during shutdown but worth surfacing.
	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}
	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}
	w.logger.Info("worker stopped", "worker_id", w.id)
}
|
||||
|
||||
// Execute task with lease management and retry.
|
||||
func (w *Worker) executeTaskWithLease(task *queue.Task) {
|
||||
defer w.releaseRunningSlot(task.ID)
|
||||
|
||||
// Track task for graceful shutdown
|
||||
w.gracefulWait.Add(1)
|
||||
w.activeTasks.Store(task.ID, task)
|
||||
defer w.gracefulWait.Done()
|
||||
defer w.activeTasks.Delete(task.ID)
|
||||
|
||||
// Create task-specific context with timeout
|
||||
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
|
||||
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
|
||||
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
|
||||
|
||||
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
|
||||
defer taskCancel()
|
||||
|
||||
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
|
||||
logger.Info("starting task",
|
||||
"worker_id", w.id,
|
||||
"datasets", task.Datasets,
|
||||
"priority", task.Priority)
|
||||
|
||||
// Record task start
|
||||
w.metrics.RecordTaskStart()
|
||||
defer w.metrics.RecordTaskCompletion()
|
||||
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-taskCtx.Done():
|
||||
logger.Info("task cancelled before execution")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
|
||||
if isJupyterTask(task) {
|
||||
out, err := w.runJupyterTask(taskCtx, task)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if err != nil {
|
||||
logger.Error("jupyter task failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = err.Error()
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
if len(out) > 0 {
|
||||
task.Output = string(out)
|
||||
}
|
||||
task.Status = "completed"
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse datasets from task arguments
|
||||
task.Datasets = resolveDatasets(task)
|
||||
|
||||
if err := w.validateTaskForExecution(taskCtx, task); err != nil {
|
||||
logger.Error("task validation failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Validation failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error("failed to update task status after validation failure", "error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
|
||||
logger.Error("provenance validation failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task status after provenance validation failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
|
||||
lease, err := w.resources.Acquire(taskCtx, task)
|
||||
if err != nil {
|
||||
logger.Error("resource acquisition failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task status after resource acquisition failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
defer lease.Release()
|
||||
|
||||
// Start heartbeat goroutine
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer cancelHeartbeat()
|
||||
|
||||
go w.heartbeatLoop(heartbeatCtx, task.ID)
|
||||
|
||||
// Update task status
|
||||
task.Status = "running"
|
||||
now := time.Now()
|
||||
task.StartedAt = &now
|
||||
task.WorkerID = w.id
|
||||
|
||||
// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
|
||||
w.recordTaskProvenance(taskCtx, task)
|
||||
|
||||
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
|
||||
logger.Error("failed to update task status", "error", err)
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
|
||||
if w.config.AutoFetchData && len(task.Datasets) > 0 {
|
||||
if err := w.fetchDatasets(taskCtx, task); err != nil {
|
||||
logger.Error("data fetch failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
err := w.queue.UpdateTask(task)
|
||||
if err != nil {
|
||||
logger.Error("failed to update task status after data fetch failure", "error", err)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
|
||||
logger.Error("dataset checksum verification failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task after dataset checksum verification failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
if err := w.verifySnapshot(taskCtx, task); err != nil {
|
||||
logger.Error("snapshot checksum verification failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task after snapshot checksum verification failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
|
||||
// Execute job with panic recovery
|
||||
var execErr error
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
execErr = fmt.Errorf("panic during execution: %v", r)
|
||||
}
|
||||
}()
|
||||
cudaVisible := resources.FormatCUDAVisibleDevices(lease)
|
||||
execErr = w.runJob(taskCtx, task, cudaVisible)
|
||||
}()
|
||||
|
||||
// Finalize task
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
|
||||
if execErr != nil {
|
||||
task.Error = execErr.Error()
|
||||
|
||||
// Check if transient error (network, timeout, etc)
|
||||
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
|
||||
w.logger.Warn("task failed with transient error, will retry",
|
||||
"task_id", task.ID,
|
||||
"error", execErr,
|
||||
"retry_count", task.RetryCount)
|
||||
_ = w.queue.RetryTask(task)
|
||||
} else {
|
||||
task.Status = "failed"
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
}
|
||||
|
||||
// Release lease
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
}
|
||||
|
||||
// Heartbeat loop to renew lease.
|
||||
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
|
||||
ticker := time.NewTicker(w.config.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
|
||||
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
|
||||
return
|
||||
}
|
||||
// Also update worker heartbeat
|
||||
_ = w.queue.Heartbeat(w.config.WorkerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the worker.
|
||||
func (w *Worker) Shutdown() error {
|
||||
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
|
||||
|
||||
// Wait for active tasks with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.gracefulWait.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
timeout := time.After(w.config.GracefulTimeout)
|
||||
select {
|
||||
case <-done:
|
||||
w.logger.Info("all tasks completed, shutdown successful")
|
||||
case <-timeout:
|
||||
w.logger.Warn("graceful shutdown timeout, releasing active leases")
|
||||
w.releaseAllLeases()
|
||||
}
|
||||
|
||||
return w.queue.Close()
|
||||
}
|
||||
|
||||
// Release all active leases.
|
||||
func (w *Worker) releaseAllLeases() {
|
||||
w.activeTasks.Range(func(key, _ interface{}) bool {
|
||||
taskID := key.(string)
|
||||
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
|
||||
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions.
|
||||
func (w *Worker) countActiveTasks() int {
|
||||
count := 0
|
||||
w.activeTasks.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
func isTransientError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Check if error is transient (network, timeout, resource unavailable, etc)
|
||||
errStr := err.Error()
|
||||
transientIndicators := []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"resource temporarily unavailable",
|
||||
"no such host",
|
||||
"network unreachable",
|
||||
}
|
||||
for _, indicator := range transientIndicators {
|
||||
if strings.Contains(strings.ToLower(errStr), indicator) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
270
internal/worker/snapshot_store.go
Normal file
270
internal/worker/snapshot_store.go
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
// SnapshotFetcher abstracts retrieval of a snapshot object from a bucket
// store so callers (and tests) can substitute their own implementation for
// the default MinIO-backed fetcher. The returned ReadCloser streams the raw
// object body and must be closed by the caller.
type SnapshotFetcher interface {
	Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
}
|
||||
|
||||
// minioSnapshotFetcher implements SnapshotFetcher on top of a MinIO client.
type minioSnapshotFetcher struct {
	client *minio.Client
}

// Get streams the object stored at bucket/key.
//
// The explicit nil return on error is deliberate: returning obj directly
// would box a typed nil *minio.Object into a non-nil io.ReadCloser
// interface value (the classic Go nil-interface trap).
func (f *minioSnapshotFetcher) Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) {
	obj, err := f.client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
	if err != nil {
		return nil, err
	}
	return obj, nil
}
|
||||
|
||||
func newMinioSnapshotFetcher(cfg *SnapshotStoreConfig) (*minioSnapshotFetcher, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("missing snapshot store config")
|
||||
}
|
||||
endpoint := strings.TrimSpace(cfg.Endpoint)
|
||||
if endpoint == "" {
|
||||
return nil, fmt.Errorf("missing snapshot store endpoint")
|
||||
}
|
||||
bucket := strings.TrimSpace(cfg.Bucket)
|
||||
if bucket == "" {
|
||||
return nil, fmt.Errorf("missing snapshot store bucket")
|
||||
}
|
||||
creds := cfg.credentials()
|
||||
client, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: creds,
|
||||
Secure: cfg.Secure,
|
||||
Region: strings.TrimSpace(cfg.Region),
|
||||
MaxRetries: cfg.MaxRetries,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &minioSnapshotFetcher{client: client}, nil
|
||||
}
|
||||
|
||||
func (c *SnapshotStoreConfig) credentials() *credentials.Credentials {
|
||||
if c != nil {
|
||||
ak := strings.TrimSpace(c.AccessKey)
|
||||
sk := strings.TrimSpace(c.SecretKey)
|
||||
st := strings.TrimSpace(c.SessionToken)
|
||||
if ak != "" && sk != "" {
|
||||
return credentials.NewStaticV4(ak, sk, st)
|
||||
}
|
||||
}
|
||||
return credentials.NewChainCredentials([]credentials.Provider{
|
||||
&credentials.EnvMinio{},
|
||||
&credentials.EnvAWS{},
|
||||
})
|
||||
}
|
||||
|
||||
// ResolveSnapshot returns a local directory containing the snapshot
// identified by snapshotID, verified against the expected SHA-256 checksum.
//
// Resolution order:
//  1. A content-addressed cache hit under
//     <dataDir>/snapshots/sha256/<checksum> is returned immediately.
//  2. If the remote store is nil or disabled, the legacy local path
//     <dataDir>/snapshots/<snapshotID> is returned without verification
//     (its existence is checked by the caller, not here).
//  3. Otherwise the archive <prefix>/<snapshotID>.tar.gz is downloaded
//     (via the supplied fetcher, or a MinIO client built from cfg when
//     fetcher is nil), extracted into a temp dir under dataDir,
//     checksum-verified, and renamed into the cache.
//
// A lost os.Rename race with a concurrent resolver is tolerated: if the
// cache directory exists afterwards, that (already verified) copy is used.
func ResolveSnapshot(
	ctx context.Context,
	dataDir string,
	cfg *SnapshotStoreConfig,
	snapshotID string,
	wantSHA256 string,
	fetcher SnapshotFetcher,
) (string, error) {
	// Validate inputs before touching the filesystem or network.
	dataDir = strings.TrimSpace(dataDir)
	if dataDir == "" {
		return "", fmt.Errorf("missing data_dir")
	}
	snapshotID = strings.TrimSpace(snapshotID)
	if snapshotID == "" {
		return "", fmt.Errorf("missing snapshot_id")
	}
	// Reuse the job-name validator to reject path-hostile snapshot IDs
	// before they are used to build object keys and directory names.
	if err := container.ValidateJobName(snapshotID); err != nil {
		return "", fmt.Errorf("invalid snapshot_id: %w", err)
	}
	want, err := normalizeSHA256ChecksumHex(wantSHA256)
	if err != nil || want == "" {
		return "", fmt.Errorf("invalid snapshot_sha256")
	}

	// Fast path: content-addressed cache hit.
	cacheDir := filepath.Join(dataDir, "snapshots", "sha256", want)
	if info, err := os.Stat(cacheDir); err == nil && info.IsDir() {
		return cacheDir, nil
	}

	// Remote store disabled: fall back to the legacy per-ID local layout.
	if cfg == nil || !cfg.Enabled {
		return filepath.Join(dataDir, "snapshots", snapshotID), nil
	}

	bucket := strings.TrimSpace(cfg.Bucket)
	if bucket == "" {
		return "", fmt.Errorf("missing snapshot store bucket")
	}
	// Object key: <prefix>/<snapshotID>.tar.gz (path.Join keeps it
	// slash-separated regardless of host OS).
	prefix := strings.Trim(strings.TrimSpace(cfg.Prefix), "/")
	key := snapshotID + ".tar.gz"
	if prefix != "" {
		key = path.Join(prefix, key)
	}

	if fetcher == nil {
		mf, err := newMinioSnapshotFetcher(cfg)
		if err != nil {
			return "", err
		}
		fetcher = mf
	}

	// Bound the fetch with the configured timeout, if any.
	fetchCtx := ctx
	if cfg.Timeout > 0 {
		var cancel context.CancelFunc
		fetchCtx, cancel = context.WithTimeout(ctx, cfg.Timeout)
		defer cancel()
	}

	rc, err := fetcher.Get(fetchCtx, bucket, key)
	if err != nil {
		return "", err
	}
	defer func() { _ = rc.Close() }()

	// Stage the download under dataDir so the final os.Rename stays on the
	// same filesystem.
	tmpRoot := filepath.Join(dataDir, "snapshots", ".tmp")
	if err := os.MkdirAll(tmpRoot, 0750); err != nil {
		return "", err
	}
	workDir, err := os.MkdirTemp(tmpRoot, "fetchml-snapshot-")
	if err != nil {
		return "", err
	}
	defer func() { _ = os.RemoveAll(workDir) }()

	archivePath := filepath.Join(workDir, "snapshot.tar.gz")
	f, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return "", err
	}
	// Surface the copy error first; a close error alone still fails the
	// staging (a truncated archive must never be extracted).
	_, copyErr := io.Copy(f, rc)
	closeErr := f.Close()
	if copyErr != nil {
		return "", copyErr
	}
	if closeErr != nil {
		return "", closeErr
	}

	extractDir := filepath.Join(workDir, "extracted")
	if err := os.MkdirAll(extractDir, 0750); err != nil {
		return "", err
	}
	if err := extractTarGz(archivePath, extractDir); err != nil {
		return "", err
	}

	// Verify the checksum over the extracted tree before publishing it.
	got, err := dirOverallSHA256Hex(extractDir)
	if err != nil {
		return "", err
	}
	if got != want {
		return "", fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", want, got)
	}

	if err := os.MkdirAll(filepath.Dir(cacheDir), 0750); err != nil {
		return "", err
	}
	if err := os.Rename(extractDir, cacheDir); err != nil {
		// Rename race: a concurrent resolver may have published the same
		// checksum-addressed directory first — use theirs.
		if info, statErr := os.Stat(cacheDir); statErr == nil && info.IsDir() {
			return cacheDir, nil
		}
		return "", err
	}

	return cacheDir, nil
}
|
||||
|
||||
// extractTarGz extracts a gzip-compressed tar archive into dstDir.
//
// Security posture: entry names are normalized and checked against path
// traversal ("..", absolute paths) both on the slash-separated tar name and
// again via safeJoin on the OS path; only directories and regular files are
// accepted — symlinks, hard links, devices, and any other entry type fail
// the whole extraction.
//
// NOTE(review): there is no cap on total extracted size or entry count, so
// a hostile archive could act as a decompression bomb — confirm archives
// only come from trusted producers, or add a limit.
// NOTE(review): the default case also rejects PAX global-header entries
// (tar.TypeXGlobalHeader), which archive/tar surfaces to the caller; verify
// the snapshot producer never emits them.
func extractTarGz(archivePath, dstDir string) error {
	archivePath = filepath.Clean(archivePath)
	dstDir = filepath.Clean(dstDir)

	f, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	defer func() { _ = f.Close() }()

	gz, err := gzip.NewReader(f)
	if err != nil {
		return err
	}
	defer func() { _ = gz.Close() }()

	tr := tar.NewReader(gz)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		// Normalize the (slash-separated) entry name and reject traversal
		// before it is ever mapped onto the local filesystem.
		name := strings.TrimSpace(hdr.Name)
		name = strings.TrimPrefix(name, "./")
		clean := path.Clean(name)
		if clean == "." {
			continue
		}
		if strings.HasPrefix(clean, "../") || clean == ".." || strings.HasPrefix(clean, "/") {
			return fmt.Errorf("invalid tar entry")
		}

		// Second, OS-level containment check: target must stay under dstDir.
		target, err := safeJoin(dstDir, filepath.FromSlash(clean))
		if err != nil {
			return err
		}

		switch hdr.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(target, 0750); err != nil {
				return err
			}
		case tar.TypeReg:
			// Parent directories may not have their own tar entries.
			if err := os.MkdirAll(filepath.Dir(target), 0750); err != nil {
				return err
			}
			// Mask to permission bits only (strips setuid/setgid/sticky).
			mode := hdr.FileInfo().Mode() & 0777
			out, err := fileutil.SecureOpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
			if err != nil {
				return err
			}
			// CopyN bounds the write to the declared entry size; a short
			// read surfaces as an error (EOF), failing the extraction.
			if _, err := io.CopyN(out, tr, hdr.Size); err != nil {
				_ = out.Close()
				return err
			}
			if err := out.Close(); err != nil {
				return err
			}
		default:
			// Symlinks, hard links, devices, FIFOs, etc. are not allowed.
			return fmt.Errorf("unsupported tar entry type")
		}
	}
	return nil
}
|
||||
|
||||
func safeJoin(baseDir, rel string) (string, error) {
|
||||
baseDir = filepath.Clean(baseDir)
|
||||
joined := filepath.Join(baseDir, rel)
|
||||
joined = filepath.Clean(joined)
|
||||
basePrefix := baseDir + string(os.PathSeparator)
|
||||
if joined != baseDir && !strings.HasPrefix(joined, basePrefix) {
|
||||
return "", fmt.Errorf("invalid relative path")
|
||||
}
|
||||
return joined, nil
|
||||
}
|
||||
Loading…
Reference in a new issue