Add comprehensive benchmarking suite for C++ optimization targets:
- tests/benchmarks/dataset_hash_bench_test.go - dirOverallSHA256Hex profiling
- tests/benchmarks/queue_bench_test.go - filesystem queue profiling
- tests/benchmarks/artifact_and_snapshot_bench_test.go - scanArtifacts/extractTarGz profiling
- tests/unit/worker/artifacts_test.go - moved from internal/ for clean separation

Add parallel Go implementation as baseline for C++ comparison:
- internal/worker/data_integrity.go: dirOverallSHA256HexParallel() with worker pool
- Benchmarks show 2.1x speedup (3.97ms -> 1.90ms) vs sequential

Exported wrappers for testing:
- ScanArtifacts() - artifact scanning
- ExtractTarGz() - tar.gz extraction
- DirOverallSHA256HexParallel() - parallel hashing

Profiling results (Apple M2 Ultra):
- dirOverallSHA256Hex: 78% syscall overhead (target for mmap C++)
- rebuildIndex: 96% syscall overhead (target for binary index C++)
- scanArtifacts: 87% syscall overhead (target for fast traversal C++)
- extractTarGz: 95% syscall overhead (target for parallel gzip C++)

Related: C++ optimization strategy in memory 5d5f0bb6
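
For reference, a minimal sketch of the shape these benchmarks take (the package name, fixture setup, and file sizes below are illustrative assumptions, not the actual contents of tests/benchmarks/):

package benchmarks

import (
	"fmt"
	"os"
	"path/filepath"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker"
)

// BenchmarkDirOverallSHA256HexParallel hashes a small synthetic tree;
// the real benchmarks use representative dataset fixtures.
func BenchmarkDirOverallSHA256HexParallel(b *testing.B) {
	root := b.TempDir()
	for i := 0; i < 64; i++ {
		p := filepath.Join(root, fmt.Sprintf("f%02d", i))
		if err := os.WriteFile(p, make([]byte, 4096), 0o644); err != nil {
			b.Fatal(err)
		}
	}
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		if _, err := worker.DirOverallSHA256HexParallel(root); err != nil {
			b.Fatal(err)
		}
	}
}
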
package worker

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/errtypes"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/queue"
)

// fetchDatasets fetches the task's datasets by shelling out to data_manager,
// skipping any dataset whose cache entry is still fresh.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	logger.Info("fetching datasets",
		"worker_id", w.id,
		"dataset_count", len(task.Datasets))

	for _, dataset := range task.Datasets {
		if w.datasetIsFresh(dataset) {
			logger.Debug("skipping cached dataset",
				"dataset", dataset)
			continue
		}
		// Check for cancellation before each dataset fetch.
		select {
		case <-ctx.Done():
			return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
		default:
		}

		logger.Info("fetching dataset",
			"worker_id", w.id,
			"dataset", dataset)

		// Create command with context for cancellation support.
		cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
		// Validate inputs to prevent command injection.
		if !isValidName(task.JobName) || !isValidName(dataset) {
			cancel()
			return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
		}
		//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
		cmd := exec.CommandContext(cmdCtx,
			w.config.DataManagerPath,
			"fetch",
			task.JobName,
			dataset,
		)

		output, err := cmd.CombinedOutput()
		cancel() // Clean up the timeout context.

		if err != nil {
			return &errtypes.DataFetchError{
				Dataset: dataset,
				JobName: task.JobName,
				Err:     fmt.Errorf("command failed: %w, output: %s", err, output),
			}
		}

		logger.Info("dataset ready",
			"worker_id", w.id,
			"dataset", dataset)
		w.markDatasetFetched(dataset)
	}

	return nil
}

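// resolveDatasets returns the dataset names for a task, preferring
// DatasetSpecs, then the Datasets list, then a --datasets flag in Args.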
func resolveDatasets(task *queue.Task) []string {
	if task == nil {
		return nil
	}
	if len(task.DatasetSpecs) > 0 {
		out := make([]string, 0, len(task.DatasetSpecs))
		for _, ds := range task.DatasetSpecs {
			if ds.Name != "" {
				out = append(out, ds.Name)
			}
		}
		if len(out) > 0 {
			return out
		}
	}
	if len(task.Datasets) > 0 {
		return task.Datasets
	}
	return parseDatasets(task.Args)
}

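// parseDatasets extracts the comma-separated value following a
// "--datasets" flag, e.g. "--epochs 3 --datasets a,b" yields ["a", "b"].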
func parseDatasets(args string) []string {
	if !strings.Contains(args, "--datasets") {
		return nil
	}

	parts := strings.Fields(args)
	for i, part := range parts {
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
	}

	return nil
}

func (w *Worker) datasetIsFresh(dataset string) bool {
	w.datasetCacheMu.RLock()
	defer w.datasetCacheMu.RUnlock()
	expires, ok := w.datasetCache[dataset]
	return ok && time.Now().Before(expires)
}

func (w *Worker) markDatasetFetched(dataset string) {
	expires := time.Now().Add(w.datasetCacheTTL)
	w.datasetCacheMu.Lock()
	w.datasetCache[dataset] = expires
	w.datasetCacheMu.Unlock()
}

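// cancelPrewarmLocked cancels any in-flight prewarm and clears its
// target; callers must hold w.prewarmMu.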
func (w *Worker) cancelPrewarmLocked() {
	if w.prewarmCancel != nil {
		w.prewarmCancel()
		w.prewarmCancel = nil
	}
	w.prewarmTargetID = ""
}

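// prewarmNextLoop polls the queue every 500ms and prewarms the next
// queued task until the worker context is cancelled.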
func (w *Worker) prewarmNextLoop() {
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.ctx == nil || w.queue == nil || w.metrics == nil {
		return
	}
	// Phase 1: Best-effort prewarm of the next queued task.
	// This must never be required for correctness.
	runOnce := func() {
		_, err := w.PrewarmNextOnce(w.ctx)
		if err != nil {
			w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
		}
	}

	// Run once immediately so prewarm doesn't lag behind the worker loop.
	runOnce()

	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-w.ctx.Done():
			w.prewarmMu.Lock()
			w.cancelPrewarmLocked()
			w.prewarmMu.Unlock()
			return
		case <-ticker.C:
		}
		runOnce()
	}
}

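// PrewarmNextOnce peeks at the next queued task and prewarms it once.
// It reports whether any prewarm work was attempted.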
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}

	next, err := w.queue.PeekNextTask()
	if err != nil {
		return false, err
	}
	if next == nil {
		w.prewarmMu.Lock()
		w.cancelPrewarmLocked()
		w.prewarmMu.Unlock()
		return false, nil
	}

	return w.prewarmTaskOnce(ctx, next)
}

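// prewarmTaskOnce stages the snapshot, datasets, and environment image
// for a peeked task and publishes prewarm state to the queue. It is a
// no-op if the task is already the current prewarm target.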
func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}
	if next == nil {
		return false, nil
	}

	w.prewarmMu.Lock()
	if w.prewarmTargetID == next.ID {
		w.prewarmMu.Unlock()
		return false, nil
	}
	w.cancelPrewarmLocked()
	prewarmCtx, cancel := context.WithCancel(ctx)
	w.prewarmCancel = cancel
	w.prewarmTargetID = next.ID
	w.prewarmStartedAt = time.Now()
	startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)
	phase := "datasets"
	dsCnt := len(resolveDatasets(next))
	snapID := next.SnapshotID
	if strings.TrimSpace(snapID) != "" {
		phase = "snapshot"
	} else if dsCnt == 0 {
		phase = "env"
	}
	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     next.ID,
		SnapshotID: snapID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      phase,
		DatasetCnt: dsCnt,
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})
	w.prewarmMu.Unlock()

	w.logger.Info("prewarm started",
		"worker_id", w.id,
		"task_id", next.ID,
		"snapshot_id", snapID,
		"phase", phase,
	)

	local := *next
	local.Datasets = resolveDatasets(&local)

	hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
	hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
	hasEnv := false
	if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
		if local.Metadata != nil {
			depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
			commitID := strings.TrimSpace(local.Metadata["commit_id"])
			if depsSHA != "" && commitID != "" {
				expMgr := experiment.NewManager(w.config.BasePath)
				hostWorkspace := expMgr.GetFilesPath(commitID)
				if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
					if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
						hasEnv = true
					}
				}
			}
		}
	}
	if !hasSnapshot && !hasDatasets && !hasEnv {
		_ = w.queue.ClearWorkerPrewarmState(w.id)
		return false, nil
	}

	if hasSnapshot {
		want := ""
		if local.Metadata != nil {
			want = local.Metadata["snapshot_sha256"]
		}
		start := time.Now()
		src, err := ResolveSnapshot(
			prewarmCtx,
			w.config.DataDir,
			&w.config.SnapshotStore,
			local.SnapshotID,
			want,
			nil,
		)
		if err != nil {
			return true, err
		}
		dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
		_ = os.RemoveAll(dst)
		if err := copyDir(src, dst); err != nil {
			return true, err
		}
		w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
	}

	if hasDatasets {
		if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
			return true, err
		}
	}

	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     local.ID,
		SnapshotID: local.SnapshotID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      "ready",
		DatasetCnt: len(local.Datasets),
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})

	w.logger.Info("prewarm ready",
		"worker_id", w.id,
		"task_id", local.ID,
		"snapshot_id", local.SnapshotID,
	)

	return true, nil
}

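// verifySnapshot resolves the task's snapshot and checks its directory
// hash against the snapshot_sha256 metadata field.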
func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.SnapshotID == "" {
		return nil
	}
	if err := container.ValidateJobName(task.SnapshotID); err != nil {
		return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err)
	}
	if task.Metadata == nil {
		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
	}
	want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
	if err != nil {
		return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err)
	}
	if want == "" {
		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
	}
	path, err := ResolveSnapshot(
		ctx,
		w.config.DataDir,
		&w.config.SnapshotStore,
		task.SnapshotID,
		want,
		nil,
	)
	if err != nil {
		return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
	}
	got, err := dirOverallSHA256Hex(path)
	if err != nil {
		return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err)
	}
	if got != want {
		return fmt.Errorf(
			"snapshot %q: checksum mismatch: expected %s, got %s",
			task.SnapshotID,
			want,
			got,
		)
	}
	w.logger.Job(
		ctx,
		task.JobName,
		task.ID,
	).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID)
	return nil
}

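// fileSHA256Hex returns the lowercase hex SHA-256 of a single file.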
func fileSHA256Hex(path string) (string, error) {
	f, err := os.Open(filepath.Clean(path))
	if err != nil {
		return "", err
	}
	defer func() { _ = f.Close() }()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", h.Sum(nil)), nil
}

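// normalizeSHA256ChecksumHex strips an optional "sha256:"/"SHA256:" prefix
// and validates that the remainder is 64 hex characters, returning it
// lowercased. An empty input normalizes to "".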
func normalizeSHA256ChecksumHex(checksum string) (string, error) {
	checksum = strings.TrimSpace(checksum)
	checksum = strings.TrimPrefix(checksum, "sha256:")
	checksum = strings.TrimPrefix(checksum, "SHA256:")
	checksum = strings.TrimSpace(checksum)
	if checksum == "" {
		return "", nil
	}
	if len(checksum) != 64 {
		return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
	}
	if _, err := hex.DecodeString(checksum); err != nil {
		return "", fmt.Errorf("invalid sha256 hex: %w", err)
	}
	return strings.ToLower(checksum), nil
}

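// dirOverallSHA256Hex computes a deterministic hash of a directory tree:
// the SHA-256 of each file's hex digest, concatenated in sorted
// relative-path order.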
func dirOverallSHA256Hex(root string) (string, error) {
	root = filepath.Clean(root)
	info, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory")
	}

	var files []string
	err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		rel, err := filepath.Rel(root, path)
		if err != nil {
			return err
		}
		files = append(files, rel)
		return nil
	})
	if err != nil {
		return "", err
	}

	// Deterministic order.
	sort.Strings(files)

	// Hash file hashes to avoid holding all bytes.
	overall := sha256.New()
	for _, rel := range files {
		p := filepath.Join(root, rel)
		sum, err := fileSHA256Hex(p)
		if err != nil {
			return "", err
		}
		overall.Write([]byte(sum))
	}
	return fmt.Sprintf("%x", overall.Sum(nil)), nil
}

// dirOverallSHA256HexParallel is a parallel Go implementation for baseline comparison.
// This demonstrates best-effort Go performance before C++ optimization.
// Uses a worker pool to hash files in parallel, then combines deterministically.
func dirOverallSHA256HexParallel(root string) (string, error) {
	root = filepath.Clean(root)
	info, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory")
	}

	// Collect all files first.
	var files []string
	err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		rel, err := filepath.Rel(root, path)
		if err != nil {
			return err
		}
		files = append(files, rel)
		return nil
	})
	if err != nil {
		return "", err
	}

	// Sort for deterministic order.
	sort.Strings(files)

	// Parallel hashing with a worker pool.
	numWorkers := runtime.NumCPU()
	if numWorkers > 8 {
		numWorkers = 8 // Cap at 8 workers.
	}

	type result struct {
		index int
		hash  string
		err   error
	}

	workCh := make(chan int, len(files))
	resultCh := make(chan result, len(files))

	// Start workers.
	var wg sync.WaitGroup
	for i := 0; i < numWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for idx := range workCh {
				rel := files[idx]
				p := filepath.Join(root, rel)
				hash, err := fileSHA256Hex(p)
				resultCh <- result{index: idx, hash: hash, err: err}
			}
		}()
	}

	// Send work.
	go func() {
		for i := range files {
			workCh <- i
		}
		close(workCh)
	}()

	// Close the results channel once all workers finish.
	go func() {
		wg.Wait()
		close(resultCh)
	}()

	// Collect results; the buffered channel lets workers drain even on early error.
	hashes := make([]string, len(files))
	for r := range resultCh {
		if r.err != nil {
			return "", r.err
		}
		hashes[r.index] = r.hash
	}

	// Combine hashes deterministically.
	overall := sha256.New()
	for _, h := range hashes {
		overall.Write([]byte(h))
	}
	return fmt.Sprintf("%x", overall.Sum(nil)), nil
}

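// verifyDatasetSpecs checks each dataset spec that carries a checksum
// against the hash of its directory under DataDir; specs without
// checksums are skipped.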
func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if len(task.DatasetSpecs) == 0 {
		return nil
	}

	logger := w.logger.Job(ctx, task.JobName, task.ID)
	for _, ds := range task.DatasetSpecs {
		want, err := normalizeSHA256ChecksumHex(ds.Checksum)
		if err != nil {
			return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
		}
		if want == "" {
			continue
		}
		if err := container.ValidateJobName(ds.Name); err != nil {
			return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
		}
		path := filepath.Join(w.config.DataDir, ds.Name)
		got, err := dirOverallSHA256Hex(path)
		if err != nil {
			return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
		}
		if got != want {
			return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
		}
		logger.Info("dataset checksum verified", "dataset", ds.Name)
	}
	return nil
}

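// computeTaskProvenance collects provenance metadata for a task: snapshot
// ID, dataset names/specs, the experiment manifest's overall SHA, and the
// dependency manifest's name and SHA-256.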
func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	out := map[string]string{}

	if task.SnapshotID != "" {
		out["snapshot_id"] = task.SnapshotID
	}

	datasets := resolveDatasets(task)
	if len(datasets) > 0 {
		out["datasets"] = strings.Join(datasets, ",")
	}
	if len(task.DatasetSpecs) > 0 {
		b, err := json.Marshal(task.DatasetSpecs)
		if err != nil {
			return nil, fmt.Errorf("marshal dataset_specs: %w", err)
		}
		out["dataset_specs"] = string(b)
	}

	if task.Metadata == nil {
		return out, nil
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return out, nil
	}

	expMgr := experiment.NewManager(basePath)
	manifest, err := expMgr.ReadManifest(commitID)
	if err == nil && manifest != nil && manifest.OverallSHA != "" {
		out["experiment_manifest_overall_sha"] = manifest.OverallSHA
	}

	filesPath := expMgr.GetFilesPath(commitID)
	depName, err := selectDependencyManifest(filesPath)
	if err == nil && depName != "" {
		depPath := filepath.Join(filesPath, depName)
		sha, err := fileSHA256Hex(depPath)
		if err == nil && sha != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}

	return out, nil
}

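// recordTaskProvenance merges computed provenance into task metadata,
// best-effort: failures are logged at debug level, not returned.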
func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
	if task == nil {
		return
	}
	prov, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
		return
	}
	if len(prov) == 0 {
		return
	}
	if task.Metadata == nil {
		task.Metadata = map[string]string{}
	}
	for k, v := range prov {
		if v == "" {
			continue
		}
		// Phase 1: best-effort only; do not error if overwriting.
		task.Metadata[k] = v
	}
}

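// enforceTaskProvenance recomputes provenance and compares it against the
// task's recorded metadata. With ProvenanceBestEffort set, mismatches and
// missing fields are logged and backfilled instead of failing the task.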
func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.Metadata == nil {
		return fmt.Errorf("missing task metadata")
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id")
	}

	current, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		return err
	}

	snapshotCur := ""
	if task.SnapshotID != "" {
		want := task.Metadata["snapshot_sha256"]
		wantNorm, nerr := normalizeSHA256ChecksumHex(want)
		if nerr != nil {
			if w.config != nil && w.config.ProvenanceBestEffort {
				w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
					"snapshot_id", task.SnapshotID,
					"error", nerr)
			} else {
				return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
			}
		} else if wantNorm != "" {
			resolved, err := ResolveSnapshot(
				ctx, w.config.DataDir,
				&w.config.SnapshotStore,
				task.SnapshotID,
				wantNorm,
				nil,
			)
			if err != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID,
						"error", err)
				} else {
					return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
				}
			} else {
				sha, err := dirOverallSHA256Hex(resolved)
				if err == nil {
					snapshotCur = sha
				} else if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID,
						"error", err)
				} else {
					return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
				}
			}
		}
		if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
			// Best-effort fallback: if the caller didn't provide snapshot_sha256,
			// compute from the local snapshot directory if it exists.
			localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
			if sha, err := dirOverallSHA256Hex(localPath); err == nil {
				snapshotCur = sha
			}
		}
	}

	logger := w.logger.Job(ctx, task.JobName, task.ID)

	type requiredField struct {
		Key string
		Cur string
	}
	required := []requiredField{
		{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
		{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
		{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
	}
	if task.SnapshotID != "" {
		required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
	}

	for _, f := range required {
		want := strings.TrimSpace(task.Metadata[f.Key])
		if f.Key == "snapshot_sha256" {
			norm, nerr := normalizeSHA256ChecksumHex(want)
			if nerr != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
						"snapshot_id", task.SnapshotID,
						"error", nerr)
					want = ""
				} else {
					return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
				}
			} else {
				want = norm
			}
		}
		if want == "" {
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("missing provenance field; continuing due to best-effort mode",
					"field", f.Key)
				if f.Cur != "" {
					if f.Key == "snapshot_sha256" {
						task.Metadata[f.Key] = "sha256:" + f.Cur
					} else {
						task.Metadata[f.Key] = f.Cur
					}
				}
				continue
			}
			return fmt.Errorf("missing provenance field: %s", f.Key)
		}
		if f.Cur == "" {
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("unable to compute provenance field; continuing due to best-effort mode",
					"field", f.Key)
				continue
			}
			return fmt.Errorf("unable to compute provenance field: %s", f.Key)
		}
		if want != f.Cur {
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("provenance mismatch; continuing due to best-effort mode",
					"field", f.Key,
					"expected", want,
					"current", f.Cur)
				if f.Key == "snapshot_sha256" {
					task.Metadata[f.Key] = "sha256:" + f.Cur
				} else {
					task.Metadata[f.Key] = f.Cur
				}
				continue
			}
			return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
		}
	}

	return nil
}

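// selectDependencyManifest returns the first dependency manifest present
// in filesPath, in precedence order: environment.yml, environment.yaml,
// poetry.lock (which requires pyproject.toml), pyproject.toml,
// requirements.txt.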
func selectDependencyManifest(filesPath string) (string, error) {
	if filesPath == "" {
		return "", fmt.Errorf("missing files path")
	}
	candidates := []string{
		"environment.yml",
		"environment.yaml",
		"poetry.lock",
		"pyproject.toml",
		"requirements.txt",
	}
	for _, name := range candidates {
		p := filepath.Join(filesPath, name)
		if _, err := os.Stat(p); err == nil {
			if name == "poetry.lock" {
				pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
				if _, err := os.Stat(pyprojectPath); err != nil {
					return "", fmt.Errorf(
						"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
				}
			}
			return name, nil
		}
	}
	return "", fmt.Errorf(
		"missing dependency manifest (supported: environment.yml, environment.yaml, " +
			"poetry.lock, pyproject.toml, requirements.txt)")
}

// Exported wrappers for tests under tests/.

func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }

func SelectDependencyManifest(filesPath string) (string, error) {
	return selectDependencyManifest(filesPath)
}

func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
	return normalizeSHA256ChecksumHex(checksum)
}

func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }

// DirOverallSHA256HexParallel is an exported wrapper for testing/benchmarking.
func DirOverallSHA256HexParallel(root string) (string, error) {
	return dirOverallSHA256HexParallel(root)
}

func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	return computeTaskProvenance(basePath, task)
}

func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	return w.enforceTaskProvenance(ctx, task)
}

func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	return w.verifyDatasetSpecs(ctx, task)
}

func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
	return w.verifySnapshot(ctx, task)
}

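// NewTestWorker constructs a minimal Worker for unit tests, without a
// queue, metrics, or run context.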
func NewTestWorker(cfg *Config) *Worker {
	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	ctx := logging.EnsureTrace(context.Background())
	logger := baseLogger.Component(ctx, "worker")
	if cfg == nil {
		cfg = &Config{}
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	return &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		logger:          logger,
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
	}
}

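// NewTestWorkerWithQueue constructs a Worker wired to a queue backend,
// with metrics and a cancellable context, for integration-style tests.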
func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	ctx := logging.EnsureTrace(context.Background())
	ctx, cancel := context.WithCancel(ctx)
	logger := baseLogger.Component(ctx, "worker")
	if cfg == nil {
		cfg = &Config{}
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	return &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		logger:          logger,
		queue:           tq,
		metrics:         &metrics.Metrics{},
		ctx:             ctx,
		cancel:          cancel,
		running:         make(map[string]context.CancelFunc),
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
	}
}

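// NewTestWorkerWithJupyter is NewTestWorkerWithQueue plus an injected
// JupyterManager.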
func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
	w := NewTestWorkerWithQueue(cfg, tq)
	w.jupyter = jm
	return w
}