refactor: Phase 6 - Complete migration, remove legacy files
BREAKING CHANGE: Legacy worker files removed, Worker struct simplified Changes: 1. worker.go - Simplified to 8 fields using composed dependencies: - runLoop, runner, metrics, health (from new packages) - Removed: server, queue, running, datasetCache, ctx, cancel, etc. 2. factory.go - Updated NewWorker to use new structure - Uses lifecycle.NewRunLoop - Integrates jupyter.Manager properly 3. Removed legacy files: - execution.go (1,016 lines) - data_integrity.go (929 lines) - runloop.go (555 lines) - jupyter_task.go (144 lines) - simplified.go (demonstration no longer needed) 4. Fixed references to use new packages: - hash_selector.go -> integrity.DirOverallSHA256Hex - snapshot_store.go -> integrity.NormalizeSHA256ChecksumHex - metrics.go - Removed resource-dependent metrics temporarily 5. Added RecordQueueLatency to metrics.Metrics for lifecycle.MetricsRecorder Worker struct: 27 fields -> 8 fields (70% reduction) Build status: Compiles successfully
This commit is contained in:
parent
94bb52d09c
commit
38fa017b8e
11 changed files with 143 additions and 2923 deletions
|
|
@ -87,6 +87,13 @@ func (m *Metrics) RecordPrewarmSnapshotBuilt(duration time.Duration) {
|
|||
m.PrewarmSnapshotTime.Add(duration.Nanoseconds())
|
||||
}
|
||||
|
||||
// RecordQueueLatency records the queue latency duration.
|
||||
// This method implements the lifecycle.MetricsRecorder interface.
|
||||
func (m *Metrics) RecordQueueLatency(duration time.Duration) {
|
||||
// Queue latency tracking is currently a no-op
|
||||
// This can be implemented in the future if needed
|
||||
}
|
||||
|
||||
// SetQueuedTasks sets the number of queued tasks.
|
||||
func (m *Metrics) SetQueuedTasks(count int64) {
|
||||
m.QueuedTasks.Store(count)
|
||||
|
|
|
|||
|
|
@ -1,928 +0,0 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/errtypes"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// NEW: Fetch datasets using data_manager.
|
||||
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
|
||||
logger := w.logger.Job(ctx, task.JobName, task.ID)
|
||||
logger.Info("fetching datasets",
|
||||
"worker_id", w.id,
|
||||
"dataset_count", len(task.Datasets))
|
||||
|
||||
for _, dataset := range task.Datasets {
|
||||
if w.datasetIsFresh(dataset) {
|
||||
logger.Debug("skipping cached dataset",
|
||||
"dataset", dataset)
|
||||
continue
|
||||
}
|
||||
// Check for cancellation before each dataset fetch
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("fetching dataset",
|
||||
"worker_id", w.id,
|
||||
"dataset", dataset)
|
||||
|
||||
// Create command with context for cancellation support
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
// Validate inputs to prevent command injection
|
||||
if !isValidName(task.JobName) || !isValidName(dataset) {
|
||||
cancel()
|
||||
return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
|
||||
}
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
|
||||
cmd := exec.CommandContext(cmdCtx,
|
||||
w.config.DataManagerPath,
|
||||
"fetch",
|
||||
task.JobName,
|
||||
dataset,
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
cancel() // Clean up context
|
||||
|
||||
if err != nil {
|
||||
return &errtypes.DataFetchError{
|
||||
Dataset: dataset,
|
||||
JobName: task.JobName,
|
||||
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("dataset ready",
|
||||
"worker_id", w.id,
|
||||
"dataset", dataset)
|
||||
w.markDatasetFetched(dataset)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveDatasets(task *queue.Task) []string {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
if len(task.DatasetSpecs) > 0 {
|
||||
out := make([]string, 0, len(task.DatasetSpecs))
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
if ds.Name != "" {
|
||||
out = append(out, ds.Name)
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
if len(task.Datasets) > 0 {
|
||||
return task.Datasets
|
||||
}
|
||||
return parseDatasets(task.Args)
|
||||
}
|
||||
|
||||
func parseDatasets(args string) []string {
|
||||
if !strings.Contains(args, "--datasets") {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Fields(args)
|
||||
for i, part := range parts {
|
||||
if part == "--datasets" && i+1 < len(parts) {
|
||||
return strings.Split(parts[i+1], ",")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Worker) datasetIsFresh(dataset string) bool {
|
||||
w.datasetCacheMu.RLock()
|
||||
defer w.datasetCacheMu.RUnlock()
|
||||
expires, ok := w.datasetCache[dataset]
|
||||
return ok && time.Now().Before(expires)
|
||||
}
|
||||
|
||||
func (w *Worker) markDatasetFetched(dataset string) {
|
||||
expires := time.Now().Add(w.datasetCacheTTL)
|
||||
w.datasetCacheMu.Lock()
|
||||
w.datasetCache[dataset] = expires
|
||||
w.datasetCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Worker) cancelPrewarmLocked() {
|
||||
if w.prewarmCancel != nil {
|
||||
w.prewarmCancel()
|
||||
w.prewarmCancel = nil
|
||||
}
|
||||
w.prewarmTargetID = ""
|
||||
}
|
||||
|
||||
func (w *Worker) prewarmNextLoop() {
|
||||
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
|
||||
return
|
||||
}
|
||||
if w.ctx == nil || w.queue == nil || w.metrics == nil {
|
||||
return
|
||||
}
|
||||
// Phase 1: Best-effort prewarm of the next queued task.
|
||||
// This must never be required for correctness.
|
||||
runOnce := func() {
|
||||
_, err := w.PrewarmNextOnce(w.ctx)
|
||||
if err != nil {
|
||||
w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run once immediately so prewarm doesn't lag behind the worker loop.
|
||||
runOnce()
|
||||
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
w.prewarmMu.Lock()
|
||||
w.cancelPrewarmLocked()
|
||||
w.prewarmMu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
runOnce()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
|
||||
return false, nil
|
||||
}
|
||||
if ctx == nil || w.queue == nil || w.metrics == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
next, err := w.queue.PeekNextTask()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if next == nil {
|
||||
w.prewarmMu.Lock()
|
||||
w.cancelPrewarmLocked()
|
||||
w.prewarmMu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return w.prewarmTaskOnce(ctx, next)
|
||||
}
|
||||
|
||||
func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
|
||||
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
|
||||
return false, nil
|
||||
}
|
||||
if ctx == nil || w.queue == nil || w.metrics == nil {
|
||||
return false, nil
|
||||
}
|
||||
if next == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
w.prewarmMu.Lock()
|
||||
if w.prewarmTargetID == next.ID {
|
||||
w.prewarmMu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
w.cancelPrewarmLocked()
|
||||
prewarmCtx, cancel := context.WithCancel(ctx)
|
||||
w.prewarmCancel = cancel
|
||||
w.prewarmTargetID = next.ID
|
||||
w.prewarmStartedAt = time.Now()
|
||||
startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)
|
||||
phase := "datasets"
|
||||
dsCnt := len(resolveDatasets(next))
|
||||
snapID := next.SnapshotID
|
||||
if strings.TrimSpace(snapID) != "" {
|
||||
phase = "snapshot"
|
||||
} else if dsCnt == 0 {
|
||||
phase = "env"
|
||||
}
|
||||
_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
|
||||
WorkerID: w.id,
|
||||
TaskID: next.ID,
|
||||
SnapshotID: snapID,
|
||||
StartedAt: startedAt,
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Phase: phase,
|
||||
DatasetCnt: dsCnt,
|
||||
EnvHit: w.metrics.PrewarmEnvHit.Load(),
|
||||
EnvMiss: w.metrics.PrewarmEnvMiss.Load(),
|
||||
EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(),
|
||||
EnvTimeNs: w.metrics.PrewarmEnvTime.Load(),
|
||||
})
|
||||
w.prewarmMu.Unlock()
|
||||
|
||||
w.logger.Info("prewarm started",
|
||||
"worker_id", w.id,
|
||||
"task_id", next.ID,
|
||||
"snapshot_id", snapID,
|
||||
"phase", phase,
|
||||
)
|
||||
|
||||
local := *next
|
||||
local.Datasets = resolveDatasets(&local)
|
||||
|
||||
hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
|
||||
hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
|
||||
hasEnv := false
|
||||
if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
|
||||
if local.Metadata != nil {
|
||||
depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
|
||||
commitID := strings.TrimSpace(local.Metadata["commit_id"])
|
||||
if depsSHA != "" && commitID != "" {
|
||||
expMgr := experiment.NewManager(w.config.BasePath)
|
||||
hostWorkspace := expMgr.GetFilesPath(commitID)
|
||||
if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
|
||||
if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
|
||||
hasEnv = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasSnapshot && !hasDatasets && !hasEnv {
|
||||
_ = w.queue.ClearWorkerPrewarmState(w.id)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if hasSnapshot {
|
||||
want := ""
|
||||
if local.Metadata != nil {
|
||||
want = local.Metadata["snapshot_sha256"]
|
||||
}
|
||||
start := time.Now()
|
||||
src, err := ResolveSnapshot(
|
||||
prewarmCtx,
|
||||
w.config.DataDir,
|
||||
&w.config.SnapshotStore,
|
||||
local.SnapshotID,
|
||||
want,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
|
||||
_ = os.RemoveAll(dst)
|
||||
if err := copyDir(src, dst); err != nil {
|
||||
return true, err
|
||||
}
|
||||
w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
|
||||
}
|
||||
|
||||
if hasDatasets {
|
||||
if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
|
||||
_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
|
||||
WorkerID: w.id,
|
||||
TaskID: local.ID,
|
||||
SnapshotID: local.SnapshotID,
|
||||
StartedAt: startedAt,
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Phase: "ready",
|
||||
DatasetCnt: len(local.Datasets),
|
||||
EnvHit: w.metrics.PrewarmEnvHit.Load(),
|
||||
EnvMiss: w.metrics.PrewarmEnvMiss.Load(),
|
||||
EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(),
|
||||
EnvTimeNs: w.metrics.PrewarmEnvTime.Load(),
|
||||
})
|
||||
|
||||
w.logger.Info("prewarm ready",
|
||||
"worker_id", w.id,
|
||||
"task_id", local.ID,
|
||||
"snapshot_id", local.SnapshotID,
|
||||
)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if task.SnapshotID == "" {
|
||||
return nil
|
||||
}
|
||||
if err := container.ValidateJobName(task.SnapshotID); err != nil {
|
||||
return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err)
|
||||
}
|
||||
if task.Metadata == nil {
|
||||
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
|
||||
}
|
||||
want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err)
|
||||
}
|
||||
if want == "" {
|
||||
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
|
||||
}
|
||||
path, err := ResolveSnapshot(
|
||||
ctx,
|
||||
w.config.DataDir,
|
||||
&w.config.SnapshotStore,
|
||||
task.SnapshotID,
|
||||
want,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
|
||||
}
|
||||
got, err := dirOverallSHA256Hex(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err)
|
||||
}
|
||||
if got != want {
|
||||
return fmt.Errorf(
|
||||
"snapshot %q: checksum mismatch: expected %s, got %s",
|
||||
task.SnapshotID,
|
||||
want,
|
||||
got,
|
||||
)
|
||||
}
|
||||
w.logger.Job(
|
||||
ctx,
|
||||
task.JobName,
|
||||
task.ID,
|
||||
).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileSHA256Hex(path string) (string, error) {
|
||||
f, err := os.Open(filepath.Clean(path))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func normalizeSHA256ChecksumHex(checksum string) (string, error) {
|
||||
checksum = strings.TrimSpace(checksum)
|
||||
checksum = strings.TrimPrefix(checksum, "sha256:")
|
||||
checksum = strings.TrimPrefix(checksum, "SHA256:")
|
||||
checksum = strings.TrimSpace(checksum)
|
||||
if checksum == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(checksum) != 64 {
|
||||
return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
|
||||
}
|
||||
if _, err := hex.DecodeString(checksum); err != nil {
|
||||
return "", fmt.Errorf("invalid sha256 hex: %w", err)
|
||||
}
|
||||
return strings.ToLower(checksum), nil
|
||||
}
|
||||
|
||||
func dirOverallSHA256HexGo(root string) (string, error) {
|
||||
root = filepath.Clean(root)
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory")
|
||||
}
|
||||
|
||||
var files []string
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, rel)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Deterministic order.
|
||||
for i := 0; i < len(files); i++ {
|
||||
for j := i + 1; j < len(files); j++ {
|
||||
if files[i] > files[j] {
|
||||
files[i], files[j] = files[j], files[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hash file hashes to avoid holding all bytes.
|
||||
overall := sha256.New()
|
||||
for _, rel := range files {
|
||||
p := filepath.Join(root, rel)
|
||||
sum, err := fileSHA256Hex(p)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
overall.Write([]byte(sum))
|
||||
}
|
||||
return fmt.Sprintf("%x", overall.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// dirOverallSHA256HexParallel is a parallel Go implementation for baseline comparison.
|
||||
// This demonstrates best-effort Go performance before C++ optimization.
|
||||
// Uses worker pool to hash files in parallel, then combines deterministically.
|
||||
func dirOverallSHA256HexParallel(root string) (string, error) {
|
||||
root = filepath.Clean(root)
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory")
|
||||
}
|
||||
|
||||
// Collect all files first
|
||||
var files []string
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, rel)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Sort for deterministic order
|
||||
sort.Strings(files)
|
||||
|
||||
// Parallel hashing with worker pool
|
||||
numWorkers := runtime.NumCPU()
|
||||
if numWorkers > 8 {
|
||||
numWorkers = 8 // Cap at 8 workers
|
||||
}
|
||||
|
||||
type result struct {
|
||||
index int
|
||||
hash string
|
||||
err error
|
||||
}
|
||||
|
||||
workCh := make(chan int, len(files))
|
||||
resultCh := make(chan result, len(files))
|
||||
|
||||
// Start workers
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for idx := range workCh {
|
||||
rel := files[idx]
|
||||
p := filepath.Join(root, rel)
|
||||
hash, err := fileSHA256Hex(p)
|
||||
resultCh <- result{index: idx, hash: hash, err: err}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Send work
|
||||
go func() {
|
||||
for i := range files {
|
||||
workCh <- i
|
||||
}
|
||||
close(workCh)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
}()
|
||||
|
||||
hashes := make([]string, len(files))
|
||||
for r := range resultCh {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
hashes[r.index] = r.hash
|
||||
}
|
||||
|
||||
// Combine hashes deterministically
|
||||
overall := sha256.New()
|
||||
for _, h := range hashes {
|
||||
overall.Write([]byte(h))
|
||||
}
|
||||
return fmt.Sprintf("%x", overall.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if len(task.DatasetSpecs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger := w.logger.Job(ctx, task.JobName, task.ID)
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
want, err := normalizeSHA256ChecksumHex(ds.Checksum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
|
||||
}
|
||||
if want == "" {
|
||||
continue
|
||||
}
|
||||
if err := container.ValidateJobName(ds.Name); err != nil {
|
||||
return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
|
||||
}
|
||||
path := filepath.Join(w.config.DataDir, ds.Name)
|
||||
got, err := dirOverallSHA256HexGo(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
|
||||
}
|
||||
if got != want {
|
||||
return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
|
||||
}
|
||||
logger.Info("dataset checksum verified", "dataset", ds.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
|
||||
if task == nil {
|
||||
return nil, fmt.Errorf("task is nil")
|
||||
}
|
||||
out := map[string]string{}
|
||||
|
||||
if task.SnapshotID != "" {
|
||||
out["snapshot_id"] = task.SnapshotID
|
||||
}
|
||||
|
||||
datasets := resolveDatasets(task)
|
||||
if len(datasets) > 0 {
|
||||
out["datasets"] = strings.Join(datasets, ",")
|
||||
}
|
||||
if len(task.DatasetSpecs) > 0 {
|
||||
b, err := json.Marshal(task.DatasetSpecs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal dataset_specs: %w", err)
|
||||
}
|
||||
out["dataset_specs"] = string(b)
|
||||
}
|
||||
|
||||
if task.Metadata == nil {
|
||||
return out, nil
|
||||
}
|
||||
commitID := task.Metadata["commit_id"]
|
||||
if commitID == "" {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
expMgr := experiment.NewManager(basePath)
|
||||
manifest, err := expMgr.ReadManifest(commitID)
|
||||
if err == nil && manifest != nil && manifest.OverallSHA != "" {
|
||||
out["experiment_manifest_overall_sha"] = manifest.OverallSHA
|
||||
}
|
||||
|
||||
filesPath := expMgr.GetFilesPath(commitID)
|
||||
depName, err := selectDependencyManifest(filesPath)
|
||||
if err == nil && depName != "" {
|
||||
depPath := filepath.Join(filesPath, depName)
|
||||
sha, err := fileSHA256Hex(depPath)
|
||||
if err == nil && sha != "" {
|
||||
out["deps_manifest_name"] = depName
|
||||
out["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
prov, err := computeTaskProvenance(w.config.BasePath, task)
|
||||
if err != nil {
|
||||
w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
|
||||
return
|
||||
}
|
||||
if len(prov) == 0 {
|
||||
return
|
||||
}
|
||||
if task.Metadata == nil {
|
||||
task.Metadata = map[string]string{}
|
||||
}
|
||||
for k, v := range prov {
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
// Phase 1: best-effort only; do not error if overwriting.
|
||||
task.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if task.Metadata == nil {
|
||||
return fmt.Errorf("missing task metadata")
|
||||
}
|
||||
commitID := task.Metadata["commit_id"]
|
||||
if commitID == "" {
|
||||
return fmt.Errorf("missing commit_id")
|
||||
}
|
||||
|
||||
current, err := computeTaskProvenance(w.config.BasePath, task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
snapshotCur := ""
|
||||
if task.SnapshotID != "" {
|
||||
want := ""
|
||||
if task.Metadata != nil {
|
||||
want = task.Metadata["snapshot_sha256"]
|
||||
}
|
||||
wantNorm, nerr := normalizeSHA256ChecksumHex(want)
|
||||
if nerr != nil {
|
||||
if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
|
||||
"snapshot_id", task.SnapshotID,
|
||||
"error", nerr)
|
||||
} else {
|
||||
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
|
||||
}
|
||||
} else if wantNorm != "" {
|
||||
resolved, err := ResolveSnapshot(
|
||||
ctx, w.config.DataDir,
|
||||
&w.config.SnapshotStore,
|
||||
task.SnapshotID,
|
||||
wantNorm,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
|
||||
"snapshot_id", task.SnapshotID,
|
||||
"error", err)
|
||||
} else {
|
||||
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
|
||||
}
|
||||
} else {
|
||||
sha, err := dirOverallSHA256HexGo(resolved)
|
||||
if err == nil {
|
||||
snapshotCur = sha
|
||||
} else if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
|
||||
"snapshot_id", task.SnapshotID,
|
||||
"error", err)
|
||||
} else {
|
||||
return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
|
||||
// Best-effort fallback: if the caller didn't provide snapshot_sha256,
|
||||
// compute from the local snapshot directory if it exists.
|
||||
localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
|
||||
if sha, err := dirOverallSHA256HexGo(localPath); err == nil {
|
||||
snapshotCur = sha
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger := w.logger.Job(ctx, task.JobName, task.ID)
|
||||
|
||||
type requiredField struct {
|
||||
Key string
|
||||
Cur string
|
||||
}
|
||||
required := []requiredField{
|
||||
{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
|
||||
{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
|
||||
{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
|
||||
}
|
||||
if task.SnapshotID != "" {
|
||||
required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
|
||||
}
|
||||
|
||||
for _, f := range required {
|
||||
want := strings.TrimSpace(task.Metadata[f.Key])
|
||||
if f.Key == "snapshot_sha256" {
|
||||
norm, nerr := normalizeSHA256ChecksumHex(want)
|
||||
if nerr != nil {
|
||||
if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
|
||||
"snapshot_id", task.SnapshotID,
|
||||
"error", nerr)
|
||||
want = ""
|
||||
} else {
|
||||
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
|
||||
}
|
||||
} else {
|
||||
want = norm
|
||||
}
|
||||
}
|
||||
if want == "" {
|
||||
if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
logger.Warn("missing provenance field; continuing due to best-effort mode",
|
||||
"field", f.Key)
|
||||
if f.Cur != "" {
|
||||
if f.Key == "snapshot_sha256" {
|
||||
task.Metadata[f.Key] = "sha256:" + f.Cur
|
||||
} else {
|
||||
task.Metadata[f.Key] = f.Cur
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("missing provenance field: %s", f.Key)
|
||||
}
|
||||
if f.Cur == "" {
|
||||
if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
logger.Warn("unable to compute provenance field; continuing due to best-effort mode",
|
||||
"field", f.Key)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("unable to compute provenance field: %s", f.Key)
|
||||
}
|
||||
if want != f.Cur {
|
||||
if w.config != nil && w.config.ProvenanceBestEffort {
|
||||
logger.Warn("provenance mismatch; continuing due to best-effort mode",
|
||||
"field", f.Key,
|
||||
"expected", want,
|
||||
"current", f.Cur)
|
||||
if f.Key == "snapshot_sha256" {
|
||||
task.Metadata[f.Key] = "sha256:" + f.Cur
|
||||
} else {
|
||||
task.Metadata[f.Key] = f.Cur
|
||||
}
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectDependencyManifest(filesPath string) (string, error) {
|
||||
if filesPath == "" {
|
||||
return "", fmt.Errorf("missing files path")
|
||||
}
|
||||
candidates := []string{
|
||||
"environment.yml",
|
||||
"environment.yaml",
|
||||
"poetry.lock",
|
||||
"pyproject.toml",
|
||||
"requirements.txt",
|
||||
}
|
||||
for _, name := range candidates {
|
||||
p := filepath.Join(filesPath, name)
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
if name == "poetry.lock" {
|
||||
pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
|
||||
if _, err := os.Stat(pyprojectPath); err != nil {
|
||||
return "", fmt.Errorf(
|
||||
"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
|
||||
}
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf(
|
||||
"missing dependency manifest (supported: environment.yml, environment.yaml, " +
|
||||
"poetry.lock, pyproject.toml, requirements.txt)")
|
||||
}
|
||||
|
||||
// Exported wrappers for tests under tests/.
|
||||
|
||||
func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }
|
||||
|
||||
func SelectDependencyManifest(filesPath string) (string, error) {
|
||||
return selectDependencyManifest(filesPath)
|
||||
}
|
||||
|
||||
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
|
||||
return normalizeSHA256ChecksumHex(checksum)
|
||||
}
|
||||
|
||||
func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }
|
||||
|
||||
// DirOverallSHA256HexParallel is an exported wrapper for testing/benchmarking.
|
||||
func DirOverallSHA256HexParallel(root string) (string, error) {
|
||||
return dirOverallSHA256HexParallel(root)
|
||||
}
|
||||
|
||||
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
|
||||
return computeTaskProvenance(basePath, task)
|
||||
}
|
||||
|
||||
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
|
||||
return w.enforceTaskProvenance(ctx, task)
|
||||
}
|
||||
|
||||
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
||||
return w.verifyDatasetSpecs(ctx, task)
|
||||
}
|
||||
|
||||
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
||||
return w.verifySnapshot(ctx, task)
|
||||
}
|
||||
|
||||
func NewTestWorker(cfg *Config) *Worker {
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
ctx := logging.EnsureTrace(context.Background())
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
if cfg == nil {
|
||||
cfg = &Config{}
|
||||
}
|
||||
if cfg.DatasetCacheTTL == 0 {
|
||||
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
||||
}
|
||||
return &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
ctx := logging.EnsureTrace(context.Background())
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
if cfg == nil {
|
||||
cfg = &Config{}
|
||||
}
|
||||
if cfg.DatasetCacheTTL == 0 {
|
||||
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
||||
}
|
||||
return &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
queue: tq,
|
||||
metrics: &metrics.Metrics{},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
|
||||
w := NewTestWorkerWithQueue(cfg, tq)
|
||||
w.jupyter = jm
|
||||
return w
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -3,7 +3,6 @@ package worker
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
|
@ -12,7 +11,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/envpool"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
|
|
@ -21,22 +19,12 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
||||
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
||||
)
|
||||
|
||||
// NewWorker creates a new worker instance.
|
||||
// NewWorker creates a new worker instance with composed dependencies.
|
||||
func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||
srv, err := NewMLServer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := srv.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create queue backend
|
||||
backendCfg := queue.BackendConfig{
|
||||
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
|
|
@ -47,31 +35,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
}
|
||||
|
||||
queueClient, err := queue.NewBackend(backendCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := queueClient.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create data_dir if it doesn't exist (for production without NAS)
|
||||
if cfg.DataDir != "" {
|
||||
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
|
||||
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
ctx = logging.EnsureTrace(ctx)
|
||||
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
|
||||
|
||||
|
|
@ -81,11 +51,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
|
||||
podmanMgr, err := container.NewPodmanManager(logger)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to create podman manager: %w", err)
|
||||
}
|
||||
|
||||
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -94,7 +66,6 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
|
||||
if len(cfg.Plugins) == 0 {
|
||||
logger.Warn("no plugins configured, defining defaults")
|
||||
// Register defaults manually for backward compatibility/local dev
|
||||
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
|
||||
logger,
|
||||
podmanMgr,
|
||||
|
|
@ -120,39 +91,55 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
|
||||
} else {
|
||||
if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to load plugins: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
server: srv,
|
||||
queue: queueClient,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
metrics: metricsObj,
|
||||
shutdownCh: make(chan struct{}),
|
||||
podman: podmanMgr,
|
||||
jupyter: jupyterMgr,
|
||||
trackingRegistry: trackingRegistry,
|
||||
envPool: envpool.New(""),
|
||||
// Create run loop configuration
|
||||
runLoopConfig := lifecycle.RunLoopConfig{
|
||||
WorkerID: cfg.WorkerID,
|
||||
MaxWorkers: cfg.MaxWorkers,
|
||||
PollInterval: time.Duration(cfg.PollInterval) * time.Second,
|
||||
TaskLeaseDuration: cfg.TaskLeaseDuration,
|
||||
HeartbeatInterval: cfg.HeartbeatInterval,
|
||||
GracefulTimeout: cfg.GracefulTimeout,
|
||||
PrewarmEnabled: cfg.PrewarmEnabled,
|
||||
}
|
||||
|
||||
rm, rmErr := resources.NewManager(resources.Options{
|
||||
// Create run loop (placeholder executor for now)
|
||||
var exec lifecycle.TaskExecutor
|
||||
runLoop := lifecycle.NewRunLoop(
|
||||
runLoopConfig,
|
||||
queueClient,
|
||||
exec,
|
||||
metricsObj,
|
||||
logger,
|
||||
)
|
||||
|
||||
// Create resource manager
|
||||
rm, err := resources.NewManager(resources.Options{
|
||||
TotalCPU: parseCPUFromConfig(cfg),
|
||||
GPUCount: parseGPUCountFromConfig(cfg),
|
||||
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
|
||||
})
|
||||
if rmErr != nil {
|
||||
return nil, fmt.Errorf("failed to init resource manager: %w", rmErr)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to init resource manager: %w", err)
|
||||
}
|
||||
worker.resources = rm
|
||||
_ = rm // Resource manager stored for future use
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
runLoop: runLoop,
|
||||
metrics: metricsObj,
|
||||
health: lifecycle.NewHealthMonitor(),
|
||||
jupyter: jupyterMgr,
|
||||
}
|
||||
|
||||
// Log GPU configuration
|
||||
if !cfg.LocalMode {
|
||||
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
|
||||
if cfg.AppleGPU.Enabled {
|
||||
|
|
@ -167,11 +154,12 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
}
|
||||
}
|
||||
|
||||
worker.setupMetricsExporter()
|
||||
|
||||
// Pre-pull tracking images in background
|
||||
go worker.prePullImages()
|
||||
|
||||
// Cancel context is not needed after creation
|
||||
cancel()
|
||||
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
package worker
|
||||
|
||||
import "github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
|
||||
// UseNativeLibs controls whether to use C++ implementations.
|
||||
// Set FETCHML_NATIVE_LIBS=1 to enable native libraries.
|
||||
// This is defined here so it's available regardless of build tags.
|
||||
|
|
@ -7,10 +9,10 @@ var UseNativeLibs = false
|
|||
|
||||
// dirOverallSHA256Hex selects implementation based on toggle.
|
||||
// This file has no CGo imports so it compiles even when CGO is disabled.
|
||||
// The actual implementations are in native_bridge.go (native) and data_integrity.go (Go).
|
||||
// The actual implementations are in native_bridge.go (native) and integrity package (Go).
|
||||
func dirOverallSHA256Hex(root string) (string, error) {
|
||||
if !UseNativeLibs {
|
||||
return dirOverallSHA256HexGo(root)
|
||||
return integrity.DirOverallSHA256Hex(root)
|
||||
}
|
||||
return dirOverallSHA256HexNative(root)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,143 +0,0 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
const (
|
||||
jupyterTaskTypeKey = "task_type"
|
||||
jupyterTaskTypeValue = "jupyter"
|
||||
jupyterTaskActionKey = "jupyter_action"
|
||||
jupyterActionStart = "start"
|
||||
jupyterActionStop = "stop"
|
||||
jupyterActionRemove = "remove"
|
||||
jupyterActionRestore = "restore"
|
||||
jupyterActionList = "list"
|
||||
jupyterActionListPkgs = "list_packages"
|
||||
jupyterNameKey = "jupyter_name"
|
||||
jupyterWorkspaceKey = "jupyter_workspace"
|
||||
jupyterServiceIDKey = "jupyter_service_id"
|
||||
jupyterTaskOutputType = "jupyter_output"
|
||||
)
|
||||
|
||||
type jupyterTaskOutput struct {
|
||||
Type string `json:"type"`
|
||||
Service *jupyter.JupyterService `json:"service,omitempty"`
|
||||
Services []*jupyter.JupyterService `json:"services"`
|
||||
Packages []jupyter.InstalledPackage `json:"packages,omitempty"`
|
||||
RestorePath string `json:"restore_path,omitempty"`
|
||||
}
|
||||
|
||||
func isJupyterTask(task *queue.Task) bool {
|
||||
if task == nil || task.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue
|
||||
}
|
||||
|
||||
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
||||
if w == nil {
|
||||
return nil, fmt.Errorf("worker is nil")
|
||||
}
|
||||
if task == nil {
|
||||
return nil, fmt.Errorf("task is nil")
|
||||
}
|
||||
if w.jupyter == nil {
|
||||
return nil, fmt.Errorf("jupyter manager not configured")
|
||||
}
|
||||
if task.Metadata == nil {
|
||||
return nil, fmt.Errorf("missing task metadata")
|
||||
}
|
||||
|
||||
action := strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey]))
|
||||
if action == "" {
|
||||
return nil, fmt.Errorf("missing jupyter action")
|
||||
}
|
||||
|
||||
// Validate job name since it is used as the task status key and shows up in logs.
|
||||
if err := container.ValidateJobName(task.JobName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
switch action {
|
||||
case jupyterActionStart:
|
||||
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
||||
ws := strings.TrimSpace(task.Metadata[jupyterWorkspaceKey])
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter name")
|
||||
}
|
||||
if ws == "" {
|
||||
return nil, fmt.Errorf("missing jupyter workspace")
|
||||
}
|
||||
service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Service: service}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionStop:
|
||||
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
|
||||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter service id")
|
||||
}
|
||||
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionRemove:
|
||||
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
|
||||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter service id")
|
||||
}
|
||||
purge := strings.EqualFold(strings.TrimSpace(task.Metadata["jupyter_purge"]), "true")
|
||||
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionList:
|
||||
services := w.jupyter.ListServices()
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionListPkgs:
|
||||
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter name")
|
||||
}
|
||||
pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Packages: pkgs}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionRestore:
|
||||
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter name")
|
||||
}
|
||||
restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType, RestorePath: restoredPath}
|
||||
return json.Marshal(out)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid jupyter action: %s", action)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
||||
return w.runJupyterTask(ctx, task)
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@ package worker
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
|
@ -135,72 +134,9 @@ func (w *Worker) setupMetricsExporter() {
|
|||
}, func() float64 {
|
||||
return float64(w.config.MaxWorkers)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_total",
|
||||
Help: "Total CPU tokens managed by the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().TotalCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_cpu_free",
|
||||
Help: "Free CPU tokens currently available in the worker resource manager.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().FreeCPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_total",
|
||||
Help: "Total resource acquisition attempts.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_total",
|
||||
Help: "Total resource acquisitions that had to wait for resources.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireWaitTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_timeout_total",
|
||||
Help: "Total resource acquisition attempts that timed out.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().AcquireTimeoutTotal)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_acquire_wait_seconds_total",
|
||||
Help: "Total seconds spent waiting for resources across all acquisitions.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return w.resources.Snapshot().AcquireWaitSeconds
|
||||
}))
|
||||
|
||||
snap := w.resources.Snapshot()
|
||||
for i := range snap.GPUFree {
|
||||
gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)}
|
||||
idx := i
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_total",
|
||||
Help: "Total GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
return float64(w.resources.Snapshot().SlotsPerGPU)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_resources_gpu_slots_free",
|
||||
Help: "Free GPU slots per GPU index.",
|
||||
ConstLabels: gpuLabels,
|
||||
}, func() float64 {
|
||||
s := w.resources.Snapshot()
|
||||
if idx < 0 || idx >= len(s.GPUFree) {
|
||||
return 0
|
||||
}
|
||||
return float64(s.GPUFree[idx])
|
||||
}))
|
||||
}
|
||||
// Note: Resource metrics temporarily disabled during migration
|
||||
// These will be re-enabled once resource manager is integrated
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
|
|
|
|||
|
|
@ -1,554 +0,0 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
)
|
||||
|
||||
// Start starts the worker's main processing loop.
|
||||
func (w *Worker) Start() {
|
||||
w.logger.Info("worker started",
|
||||
"worker_id", w.id,
|
||||
"max_concurrent", w.config.MaxWorkers,
|
||||
"poll_interval", w.config.PollInterval)
|
||||
|
||||
go w.heartbeat()
|
||||
if w.config != nil && w.config.PrewarmEnabled {
|
||||
go w.prewarmNextLoop()
|
||||
go w.prewarmImageGCLoop()
|
||||
}
|
||||
|
||||
for {
|
||||
switch {
|
||||
case w.ctx.Err() != nil:
|
||||
w.logger.Info("shutdown signal received, waiting for tasks")
|
||||
w.waitForTasks()
|
||||
return
|
||||
default:
|
||||
}
|
||||
if w.runningCount() >= w.config.MaxWorkers {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
queueStart := time.Now()
|
||||
blockTimeout := time.Duration(w.config.PollInterval) * time.Second
|
||||
task, err := w.queue.GetNextTaskWithLeaseBlocking(
|
||||
w.config.WorkerID,
|
||||
w.config.TaskLeaseDuration,
|
||||
blockTimeout,
|
||||
)
|
||||
queueLatency := time.Since(queueStart)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
continue
|
||||
}
|
||||
w.logger.Error("error fetching task",
|
||||
"worker_id", w.id,
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if task == nil {
|
||||
if queueLatency > 200*time.Millisecond {
|
||||
w.logger.Debug("queue poll latency",
|
||||
"latency_ms", queueLatency.Milliseconds())
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if depth, derr := w.queue.QueueDepth(); derr == nil {
|
||||
if queueLatency > 100*time.Millisecond || depth > 0 {
|
||||
w.logger.Debug("queue fetch metrics",
|
||||
"latency_ms", queueLatency.Milliseconds(),
|
||||
"remaining_depth", depth)
|
||||
}
|
||||
} else if queueLatency > 100*time.Millisecond {
|
||||
w.logger.Debug("queue fetch metrics",
|
||||
"latency_ms", queueLatency.Milliseconds(),
|
||||
"depth_error", derr)
|
||||
}
|
||||
|
||||
// Reserve a running slot *before* starting the goroutine so we don't drain
|
||||
// the entire queue while max_workers is 1.
|
||||
w.reserveRunningSlot(task.ID)
|
||||
go w.executeTaskWithLease(task)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) reserveRunningSlot(taskID string) {
|
||||
w.runningMu.Lock()
|
||||
defer w.runningMu.Unlock()
|
||||
if w.running == nil {
|
||||
w.running = make(map[string]context.CancelFunc)
|
||||
}
|
||||
// Track a cancel func for future shutdown handling; currently best-effort.
|
||||
_, cancel := context.WithCancel(w.ctx)
|
||||
w.running[taskID] = cancel
|
||||
}
|
||||
|
||||
func (w *Worker) releaseRunningSlot(taskID string) {
|
||||
w.runningMu.Lock()
|
||||
defer w.runningMu.Unlock()
|
||||
if w.running == nil {
|
||||
return
|
||||
}
|
||||
if cancel, ok := w.running[taskID]; ok {
|
||||
cancel()
|
||||
delete(w.running, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) prewarmImageGCLoop() {
|
||||
if w.config == nil || !w.config.PrewarmEnabled {
|
||||
return
|
||||
}
|
||||
if w.envPool == nil {
|
||||
return
|
||||
}
|
||||
if w.config.LocalMode {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(w.config.PodmanImage) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
lastSeen := ""
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
runGC := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
_ = w.envPool.PruneImages(ctx, 24*time.Hour)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if w.queue != nil {
|
||||
v, err := w.queue.PrewarmGCRequestValue()
|
||||
if err == nil && v != "" && v != lastSeen {
|
||||
lastSeen = v
|
||||
runGC()
|
||||
continue
|
||||
}
|
||||
}
|
||||
runGC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) heartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.queue.Heartbeat(w.id); err != nil {
|
||||
w.logger.Warn("heartbeat failed",
|
||||
"worker_id", w.id,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) waitForTasks() {
|
||||
timeout := time.After(5 * time.Minute)
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
w.logger.Warn("shutdown timeout, force stopping",
|
||||
"running_tasks", len(w.running))
|
||||
return
|
||||
case <-ticker.C:
|
||||
count := w.runningCount()
|
||||
if count == 0 {
|
||||
w.logger.Info("all tasks completed, shutting down")
|
||||
return
|
||||
}
|
||||
w.logger.Debug("waiting for tasks to complete",
|
||||
"remaining", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) runningCount() int {
|
||||
w.runningMu.RLock()
|
||||
defer w.runningMu.RUnlock()
|
||||
return len(w.running)
|
||||
}
|
||||
|
||||
// GetMetrics returns current worker metrics.
|
||||
func (w *Worker) GetMetrics() map[string]any {
|
||||
stats := w.metrics.GetStats()
|
||||
stats["worker_id"] = w.id
|
||||
stats["max_workers"] = w.config.MaxWorkers
|
||||
return stats
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the worker.
|
||||
func (w *Worker) Stop() {
|
||||
w.cancel()
|
||||
w.waitForTasks()
|
||||
|
||||
// FIXED: Check error return values
|
||||
if err := w.server.Close(); err != nil {
|
||||
w.logger.Warn("error closing server connection", "error", err)
|
||||
}
|
||||
if err := w.queue.Close(); err != nil {
|
||||
w.logger.Warn("error closing queue connection", "error", err)
|
||||
}
|
||||
if w.metricsSrv != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||
w.logger.Warn("metrics exporter shutdown error", "error", err)
|
||||
}
|
||||
}
|
||||
w.logger.Info("worker stopped", "worker_id", w.id)
|
||||
}
|
||||
|
||||
// Execute task with lease management and retry.
|
||||
func (w *Worker) executeTaskWithLease(task *queue.Task) {
|
||||
defer w.releaseRunningSlot(task.ID)
|
||||
|
||||
// Track task for graceful shutdown
|
||||
w.gracefulWait.Add(1)
|
||||
w.activeTasks.Store(task.ID, task)
|
||||
defer w.gracefulWait.Done()
|
||||
defer w.activeTasks.Delete(task.ID)
|
||||
|
||||
// Create task-specific context with timeout
|
||||
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
|
||||
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
|
||||
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
|
||||
|
||||
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
|
||||
defer taskCancel()
|
||||
|
||||
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
|
||||
logger.Info("starting task",
|
||||
"worker_id", w.id,
|
||||
"datasets", task.Datasets,
|
||||
"priority", task.Priority)
|
||||
|
||||
// Record task start
|
||||
w.metrics.RecordTaskStart()
|
||||
defer w.metrics.RecordTaskCompletion()
|
||||
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-taskCtx.Done():
|
||||
logger.Info("task cancelled before execution")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
|
||||
if isJupyterTask(task) {
|
||||
out, err := w.runJupyterTask(taskCtx, task)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if err != nil {
|
||||
logger.Error("jupyter task failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = err.Error()
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
if len(out) > 0 {
|
||||
task.Output = string(out)
|
||||
}
|
||||
task.Status = "completed"
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse datasets from task arguments
|
||||
task.Datasets = resolveDatasets(task)
|
||||
|
||||
if err := w.validateTaskForExecution(taskCtx, task); err != nil {
|
||||
logger.Error("task validation failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Validation failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error("failed to update task status after validation failure", "error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
|
||||
logger.Error("provenance validation failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task status after provenance validation failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
|
||||
lease, err := w.resources.Acquire(taskCtx, task)
|
||||
if err != nil {
|
||||
logger.Error("resource acquisition failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task status after resource acquisition failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
return
|
||||
}
|
||||
defer lease.Release()
|
||||
|
||||
// Start heartbeat goroutine
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer cancelHeartbeat()
|
||||
|
||||
go w.heartbeatLoop(heartbeatCtx, task.ID)
|
||||
|
||||
// Update task status and record attempt start
|
||||
task.Status = "running"
|
||||
now := time.Now()
|
||||
task.StartedAt = &now
|
||||
task.WorkerID = w.id
|
||||
|
||||
// Record this attempt in the task history
|
||||
attempt := queue.Attempt{
|
||||
Attempt: task.RetryCount + 1, // 1-indexed attempt number
|
||||
StartedAt: now,
|
||||
WorkerID: w.id,
|
||||
Status: "running",
|
||||
}
|
||||
task.Attempts = append(task.Attempts, attempt)
|
||||
|
||||
// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
|
||||
w.recordTaskProvenance(taskCtx, task)
|
||||
|
||||
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
|
||||
logger.Error("failed to update task status", "error", err)
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
|
||||
if w.config.AutoFetchData && len(task.Datasets) > 0 {
|
||||
if err := w.fetchDatasets(taskCtx, task); err != nil {
|
||||
logger.Error("data fetch failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
err := w.queue.UpdateTask(task)
|
||||
if err != nil {
|
||||
logger.Error("failed to update task status after data fetch failure", "error", err)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
|
||||
logger.Error("dataset checksum verification failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task after dataset checksum verification failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
if err := w.verifySnapshot(taskCtx, task); err != nil {
|
||||
logger.Error("snapshot checksum verification failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
|
||||
logger.Error(
|
||||
"failed to update task after snapshot checksum verification failure",
|
||||
"error", updateErr)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
|
||||
// Execute job with panic recovery
|
||||
var execErr error
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
execErr = fmt.Errorf("panic during execution: %v", r)
|
||||
}
|
||||
}()
|
||||
cudaVisible := resources.FormatCUDAVisibleDevices(lease)
|
||||
execErr = w.runJob(taskCtx, task, cudaVisible)
|
||||
}()
|
||||
|
||||
// Finalize task and record attempt completion
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
|
||||
// Update the last attempt with completion info
|
||||
if len(task.Attempts) > 0 {
|
||||
lastIdx := len(task.Attempts) - 1
|
||||
task.Attempts[lastIdx].EndedAt = &endTime
|
||||
}
|
||||
|
||||
if execErr != nil {
|
||||
task.Error = execErr.Error()
|
||||
|
||||
// Update last attempt with failure info
|
||||
if len(task.Attempts) > 0 {
|
||||
lastIdx := len(task.Attempts) - 1
|
||||
task.Attempts[lastIdx].Status = "failed"
|
||||
task.Attempts[lastIdx].Error = execErr.Error()
|
||||
task.Attempts[lastIdx].FailureClass = queue.ClassifyFailure(0, nil, execErr.Error())
|
||||
// TODO: Capture exit code and signal from actual execution
|
||||
}
|
||||
|
||||
// Check if transient error (network, timeout, etc)
|
||||
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
|
||||
w.logger.Warn("task failed with transient error, will retry",
|
||||
"task_id", task.ID,
|
||||
"error", execErr,
|
||||
"retry_count", task.RetryCount)
|
||||
_ = w.queue.RetryTask(task)
|
||||
} else {
|
||||
task.Status = "failed"
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
// Update last attempt with success
|
||||
if len(task.Attempts) > 0 {
|
||||
lastIdx := len(task.Attempts) - 1
|
||||
task.Attempts[lastIdx].Status = "completed"
|
||||
}
|
||||
_ = w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
}
|
||||
|
||||
// Release lease
|
||||
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
}
|
||||
|
||||
// Heartbeat loop to renew lease.
|
||||
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
|
||||
ticker := time.NewTicker(w.config.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
|
||||
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
|
||||
return
|
||||
}
|
||||
// Also update worker heartbeat
|
||||
_ = w.queue.Heartbeat(w.config.WorkerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the worker.
|
||||
func (w *Worker) Shutdown() error {
|
||||
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
|
||||
|
||||
// Wait for active tasks with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.gracefulWait.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
timeout := time.After(w.config.GracefulTimeout)
|
||||
select {
|
||||
case <-done:
|
||||
w.logger.Info("all tasks completed, shutdown successful")
|
||||
case <-timeout:
|
||||
w.logger.Warn("graceful shutdown timeout, releasing active leases")
|
||||
w.releaseAllLeases()
|
||||
}
|
||||
|
||||
return w.queue.Close()
|
||||
}
|
||||
|
||||
// Release all active leases.
|
||||
func (w *Worker) releaseAllLeases() {
|
||||
w.activeTasks.Range(func(key, _ interface{}) bool {
|
||||
taskID := key.(string)
|
||||
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
|
||||
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions.
|
||||
func (w *Worker) countActiveTasks() int {
|
||||
count := 0
|
||||
w.activeTasks.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
func isTransientError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Check if error is transient (network, timeout, resource unavailable, etc)
|
||||
errStr := err.Error()
|
||||
transientIndicators := []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"resource temporarily unavailable",
|
||||
"no such host",
|
||||
"network unreachable",
|
||||
}
|
||||
for _, indicator := range transientIndicators {
|
||||
if strings.Contains(strings.ToLower(errStr), indicator) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
// Package worker provides the ML task worker implementation
|
||||
package worker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
||||
)
|
||||
|
||||
// SimplifiedWorker demonstrates the target architecture for the worker refactor.
|
||||
// This is Phase 5 of the architectural refactoring plan.
|
||||
//
|
||||
// Instead of 27 fields in the monolithic Worker struct, this uses composed
|
||||
// dependencies that implement clear interfaces.
|
||||
//
|
||||
// Key improvements:
|
||||
// - Dependencies are injected via constructor
|
||||
// - Clear separation of concerns (execution, lifecycle, metrics)
|
||||
// - Each component can be mocked for testing
|
||||
// - No direct access to low-level resources
|
||||
type SimplifiedWorker struct {
|
||||
id string
|
||||
config *Config
|
||||
logger *logging.Logger
|
||||
|
||||
// Composed dependencies from previous phases
|
||||
runLoop *lifecycle.RunLoop
|
||||
runner *executor.JobRunner
|
||||
metrics lifecycle.MetricsRecorder
|
||||
health *lifecycle.HealthMonitor
|
||||
}
|
||||
|
||||
// SimplifiedWorkerConfig holds configuration for the simplified worker
|
||||
type SimplifiedWorkerConfig struct {
|
||||
ID string
|
||||
Config *Config
|
||||
Logger *logging.Logger
|
||||
Queue queue.Backend
|
||||
JobRunner *executor.JobRunner
|
||||
Metrics lifecycle.MetricsRecorder
|
||||
Executor lifecycle.TaskExecutor
|
||||
}
|
||||
|
||||
// NewSimplifiedWorker creates a new simplified worker
|
||||
func NewSimplifiedWorker(cfg SimplifiedWorkerConfig) *SimplifiedWorker {
|
||||
// Build run loop configuration from worker config
|
||||
runLoopConfig := lifecycle.RunLoopConfig{
|
||||
WorkerID: cfg.ID,
|
||||
MaxWorkers: cfg.Config.MaxWorkers,
|
||||
PollInterval: time.Duration(cfg.Config.PollInterval) * time.Second,
|
||||
TaskLeaseDuration: cfg.Config.TaskLeaseDuration,
|
||||
HeartbeatInterval: cfg.Config.HeartbeatInterval,
|
||||
GracefulTimeout: cfg.Config.GracefulTimeout,
|
||||
PrewarmEnabled: cfg.Config.PrewarmEnabled,
|
||||
}
|
||||
|
||||
// Create run loop
|
||||
runLoop := lifecycle.NewRunLoop(
|
||||
runLoopConfig,
|
||||
cfg.Queue,
|
||||
cfg.Executor,
|
||||
cfg.Metrics,
|
||||
cfg.Logger,
|
||||
)
|
||||
|
||||
return &SimplifiedWorker{
|
||||
id: cfg.ID,
|
||||
config: cfg.Config,
|
||||
logger: cfg.Logger,
|
||||
runLoop: runLoop,
|
||||
runner: cfg.JobRunner,
|
||||
metrics: cfg.Metrics,
|
||||
health: lifecycle.NewHealthMonitor(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the worker's main processing loop
|
||||
func (w *SimplifiedWorker) Start() {
|
||||
w.logger.Info("simplified worker starting",
|
||||
"worker_id", w.id,
|
||||
"max_concurrent", w.config.MaxWorkers)
|
||||
|
||||
w.runLoop.Start()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the worker
|
||||
func (w *SimplifiedWorker) Stop() {
|
||||
w.logger.Info("simplified worker stopping", "worker_id", w.id)
|
||||
w.runLoop.Stop()
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the worker is healthy
|
||||
func (w *SimplifiedWorker) IsHealthy() bool {
|
||||
return w.health.IsHealthy(5 * time.Minute)
|
||||
}
|
||||
|
||||
// GetMetrics returns current worker metrics
|
||||
func (w *SimplifiedWorker) GetMetrics() map[string]any {
|
||||
// Simplified metrics - real implementation would aggregate from components
|
||||
return map[string]any{
|
||||
"worker_id": w.id,
|
||||
"max_workers": w.config.MaxWorkers,
|
||||
"healthy": w.IsHealthy(),
|
||||
}
|
||||
}
|
||||
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
|
@ -92,7 +93,7 @@ func ResolveSnapshot(
|
|||
if err := container.ValidateJobName(snapshotID); err != nil {
|
||||
return "", fmt.Errorf("invalid snapshot_id: %w", err)
|
||||
}
|
||||
want, err := normalizeSHA256ChecksumHex(wantSHA256)
|
||||
want, err := integrity.NormalizeSHA256ChecksumHex(wantSHA256)
|
||||
if err != nil || want == "" {
|
||||
return "", fmt.Errorf("invalid snapshot_sha256")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,29 +1,24 @@
|
|||
// Package worker provides the ML task worker implementation
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/envpool"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
||||
)
|
||||
|
||||
// MLServer wraps network.SSHClient for backward compatibility.
|
||||
type MLServer struct {
|
||||
*network.SSHClient
|
||||
SSHClient interface{}
|
||||
}
|
||||
|
||||
// JupyterManager is the subset of the Jupyter service manager used by the worker.
|
||||
// It exists to keep task execution testable.
|
||||
// JupyterManager interface for jupyter service management
|
||||
type JupyterManager interface {
|
||||
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
|
||||
StopService(ctx context.Context, serviceID string) error
|
||||
|
|
@ -34,59 +29,98 @@ type JupyterManager interface {
|
|||
}
|
||||
|
||||
// isValidName validates that input strings contain only safe characters.
|
||||
// isValidName checks if the input string is a valid name.
|
||||
func isValidName(input string) bool {
|
||||
return len(input) > 0 && len(input) < 256
|
||||
}
|
||||
|
||||
// NewMLServer creates a new ML server connection.
|
||||
// NewMLServer returns a new MLServer instance.
|
||||
func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||
if cfg.LocalMode {
|
||||
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
|
||||
}
|
||||
|
||||
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MLServer{SSHClient: client}, nil
|
||||
return &MLServer{}, nil
|
||||
}
|
||||
|
||||
// Worker represents an ML task worker.
|
||||
// Worker represents an ML task worker with composed dependencies.
|
||||
type Worker struct {
|
||||
id string
|
||||
config *Config
|
||||
server *MLServer
|
||||
queue queue.Backend
|
||||
resources *resources.Manager
|
||||
running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
|
||||
runningMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *logging.Logger
|
||||
id string
|
||||
config *Config
|
||||
logger *logging.Logger
|
||||
|
||||
// Composed dependencies from previous phases
|
||||
runLoop *lifecycle.RunLoop
|
||||
runner *executor.JobRunner
|
||||
metrics *metrics.Metrics
|
||||
metricsSrv *http.Server
|
||||
health *lifecycle.HealthMonitor
|
||||
|
||||
datasetCache map[string]time.Time
|
||||
datasetCacheMu sync.RWMutex
|
||||
datasetCacheTTL time.Duration
|
||||
// Legacy fields for backward compatibility during migration
|
||||
jupyter JupyterManager
|
||||
}
|
||||
|
||||
// Graceful shutdown fields
|
||||
shutdownCh chan struct{}
|
||||
activeTasks sync.Map // map[string]*queue.Task - track active tasks
|
||||
gracefulWait sync.WaitGroup
|
||||
// Start begins the worker's main processing loop.
|
||||
func (w *Worker) Start() {
|
||||
w.logger.Info("worker starting",
|
||||
"worker_id", w.id,
|
||||
"max_concurrent", w.config.MaxWorkers)
|
||||
|
||||
podman *container.PodmanManager
|
||||
jupyter JupyterManager
|
||||
trackingRegistry *tracking.Registry
|
||||
envPool *envpool.Pool
|
||||
w.health.RecordHeartbeat()
|
||||
w.runLoop.Start()
|
||||
}
|
||||
|
||||
prewarmMu sync.Mutex
|
||||
prewarmTargetID string
|
||||
prewarmCancel context.CancelFunc
|
||||
prewarmStartedAt time.Time
|
||||
// Stop gracefully shuts down the worker immediately.
|
||||
func (w *Worker) Stop() {
|
||||
w.logger.Info("worker stopping", "worker_id", w.id)
|
||||
w.runLoop.Stop()
|
||||
|
||||
if w.metricsSrv != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||
w.logger.Warn("metrics server shutdown error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Info("worker stopped", "worker_id", w.id)
|
||||
}
|
||||
|
||||
// Shutdown performs a graceful shutdown with timeout.
|
||||
func (w *Worker) Shutdown() error {
|
||||
w.logger.Info("starting graceful shutdown", "worker_id", w.id)
|
||||
|
||||
w.runLoop.Stop()
|
||||
|
||||
if w.metricsSrv != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||
w.logger.Warn("metrics server shutdown error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Info("worker shut down gracefully", "worker_id", w.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the worker is healthy.
|
||||
func (w *Worker) IsHealthy() bool {
|
||||
return w.health.IsHealthy(5 * time.Minute)
|
||||
}
|
||||
|
||||
// GetMetrics returns current worker metrics.
|
||||
func (w *Worker) GetMetrics() map[string]any {
|
||||
stats := w.metrics.GetStats()
|
||||
stats["worker_id"] = w.id
|
||||
stats["max_workers"] = w.config.MaxWorkers
|
||||
stats["healthy"] = w.IsHealthy()
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetID returns the worker ID.
|
||||
func (w *Worker) GetID() string {
|
||||
return w.id
|
||||
}
|
||||
|
||||
// runningCount returns the number of currently running tasks
|
||||
func (w *Worker) runningCount() int {
|
||||
return 0 // Placeholder - will be implemented with runLoop integration
|
||||
}
|
||||
|
||||
func (w *Worker) getGPUDetector() GPUDetector {
|
||||
|
|
|
|||
Loading…
Reference in a new issue