// fetch_ml/internal/worker/data_integrity.go
//
// Data-integrity helpers for the worker: dataset fetching and TTL caching,
// best-effort prewarm of the next queued task, snapshot/dataset checksum
// verification, and task provenance recording and enforcement.
package worker
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/errtypes"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/queue"
)
// fetchDatasets downloads every dataset required by task via the external
// data_manager binary, skipping datasets whose TTL cache entry is still
// fresh. Fetches run sequentially; ctx cancellation is honored between
// datasets, and each individual fetch is bounded by a 30-minute timeout.
// On success the dataset is recorded in the worker's cache so later tasks
// can skip it.
//
// Fix: input validation now happens before the per-dataset timeout context
// is created, and the context cleanup is handled by a deferred cancel inside
// a helper instead of a manually-sequenced cancel() in the loop body.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	logger.Info("fetching datasets",
		"worker_id", w.id,
		"dataset_count", len(task.Datasets))
	for _, dataset := range task.Datasets {
		if w.datasetIsFresh(dataset) {
			logger.Debug("skipping cached dataset",
				"dataset", dataset)
			continue
		}
		// Check for cancellation before each dataset fetch.
		select {
		case <-ctx.Done():
			return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
		default:
		}
		// Validate inputs to prevent command injection before spending any
		// resources on the fetch.
		if !isValidName(task.JobName) || !isValidName(dataset) {
			return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
		}
		logger.Info("fetching dataset",
			"worker_id", w.id,
			"dataset", dataset)
		if err := w.fetchOneDataset(ctx, task.JobName, dataset); err != nil {
			return err
		}
		logger.Info("dataset ready",
			"worker_id", w.id,
			"dataset", dataset)
		w.markDatasetFetched(dataset)
	}
	return nil
}

// fetchOneDataset runs `data_manager fetch <job> <dataset>` with a 30-minute
// timeout derived from ctx. Failures are wrapped in errtypes.DataFetchError
// along with the command's combined output. The deferred cancel guarantees
// the timeout context is released on every path.
func (w *Worker) fetchOneDataset(ctx context.Context, jobName, dataset string) error {
	cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
	defer cancel()
	//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
	cmd := exec.CommandContext(cmdCtx,
		w.config.DataManagerPath,
		"fetch",
		jobName,
		dataset,
	)
	output, err := cmd.CombinedOutput()
	if err != nil {
		return &errtypes.DataFetchError{
			Dataset: dataset,
			JobName: jobName,
			Err:     fmt.Errorf("command failed: %w, output: %s", err, output),
		}
	}
	return nil
}
// resolveDatasets returns the dataset names for task, preferring structured
// DatasetSpecs, then the plain Datasets list, and finally names parsed out of
// the raw --datasets CLI argument. Returns nil for a nil task or when no
// source yields any names.
func resolveDatasets(task *queue.Task) []string {
	if task == nil {
		return nil
	}
	// Highest priority: spec entries with non-empty names.
	var names []string
	for _, spec := range task.DatasetSpecs {
		if spec.Name != "" {
			names = append(names, spec.Name)
		}
	}
	if len(names) > 0 {
		return names
	}
	// Fall back to the legacy list, then the raw argument string.
	if len(task.Datasets) > 0 {
		return task.Datasets
	}
	return parseDatasets(task.Args)
}
// parseDatasets extracts the comma-separated dataset list from a raw argument
// string. Both the space-separated form ("--datasets a,b") and the equals
// form ("--datasets=a,b") are recognized; nil is returned when the flag is
// absent or carries no value. The equals form is a backward-compatible
// generalization — the original only handled the space-separated form.
func parseDatasets(args string) []string {
	// Fast path: no flag anywhere in the string.
	if !strings.Contains(args, "--datasets") {
		return nil
	}
	parts := strings.Fields(args)
	for i, part := range parts {
		// Space-separated form: "--datasets a,b".
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
		// Equals form: "--datasets=a,b".
		if v, ok := strings.CutPrefix(part, "--datasets="); ok && v != "" {
			return strings.Split(v, ",")
		}
	}
	return nil
}
// datasetIsFresh reports whether dataset has a cache entry whose expiry is
// still in the future. Takes a read lock on the dataset cache.
func (w *Worker) datasetIsFresh(dataset string) bool {
	w.datasetCacheMu.RLock()
	expiry, cached := w.datasetCache[dataset]
	w.datasetCacheMu.RUnlock()
	if !cached {
		return false
	}
	return time.Now().Before(expiry)
}
// markDatasetFetched records dataset as freshly fetched, stamping its cache
// entry with an expiry of now plus the worker's configured TTL.
func (w *Worker) markDatasetFetched(dataset string) {
	deadline := time.Now().Add(w.datasetCacheTTL)
	w.datasetCacheMu.Lock()
	defer w.datasetCacheMu.Unlock()
	w.datasetCache[dataset] = deadline
}
// cancelPrewarmLocked aborts any in-flight prewarm (cancelling its context)
// and clears the prewarm target. Callers must hold w.prewarmMu — the
// "Locked" suffix marks that contract.
func (w *Worker) cancelPrewarmLocked() {
	if w.prewarmCancel != nil {
		w.prewarmCancel()
		w.prewarmCancel = nil
	}
	w.prewarmTargetID = ""
}
// prewarmNextLoop runs until the worker context is cancelled, attempting to
// prewarm the next queued task immediately and then every 500ms. Prewarm is
// strictly best-effort: failures are logged at Warn and never stop the loop.
func (w *Worker) prewarmNextLoop() {
	// Guard against disabled prewarm or a partially-constructed worker
	// (e.g. test constructors that skip queue/metrics/ctx wiring).
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.ctx == nil || w.queue == nil || w.metrics == nil {
		return
	}
	// Phase 1: Best-effort prewarm of the next queued task.
	// This must never be required for correctness.
	runOnce := func() {
		_, err := w.PrewarmNextOnce(w.ctx)
		if err != nil {
			w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
		}
	}
	// Run once immediately so prewarm doesn't lag behind the worker loop.
	runOnce()
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-w.ctx.Done():
			// Worker shutting down: abort any in-flight prewarm before exit.
			w.prewarmMu.Lock()
			w.cancelPrewarmLocked()
			w.prewarmMu.Unlock()
			return
		case <-ticker.C:
		}
		runOnce()
	}
}
// PrewarmNextOnce performs a single prewarm pass: it peeks at the next queued
// task and, when prewarming is enabled and the worker is fully wired, hands
// the task to prewarmTaskOnce. When the queue is empty, any in-flight prewarm
// is cancelled. Returns true when prewarm work was attempted.
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
	// Best-effort: quietly no-op when disabled or missing collaborators.
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}
	next, err := w.queue.PeekNextTask()
	if err != nil {
		return false, err
	}
	if next == nil {
		// Queue drained: drop any prewarm targeting a now-gone task.
		w.prewarmMu.Lock()
		defer w.prewarmMu.Unlock()
		w.cancelPrewarmLocked()
		return false, nil
	}
	return w.prewarmTaskOnce(ctx, next)
}
// prewarmTaskOnce best-effort prewarms a single queued task: it publishes
// prewarm state to the queue, then stages the task's snapshot copy, fetches
// its datasets, and reports readiness. Returns true when any prewarm work was
// attempted (even if it then failed), false when there was nothing to do or
// the same task is already being prewarmed.
func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
	// Guards mirror PrewarmNextOnce so this is safe to call directly.
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}
	if next == nil {
		return false, nil
	}
	w.prewarmMu.Lock()
	// Already targeting this exact task: nothing new to do.
	if w.prewarmTargetID == next.ID {
		w.prewarmMu.Unlock()
		return false, nil
	}
	// Target changed: cancel any previous prewarm and begin tracking this one.
	w.cancelPrewarmLocked()
	prewarmCtx, cancel := context.WithCancel(ctx)
	w.prewarmCancel = cancel
	w.prewarmTargetID = next.ID
	w.prewarmStartedAt = time.Now()
	startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)
	// Reported phase: snapshot staging wins over dataset fetching; with
	// neither present, only the environment can be warmed.
	phase := "datasets"
	dsCnt := len(resolveDatasets(next))
	snapID := next.SnapshotID
	if strings.TrimSpace(snapID) != "" {
		phase = "snapshot"
	} else if dsCnt == 0 {
		phase = "env"
	}
	// Publish initial prewarm state; best-effort, error deliberately ignored.
	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     next.ID,
		SnapshotID: snapID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      phase,
		DatasetCnt: dsCnt,
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})
	w.prewarmMu.Unlock()
	w.logger.Info("prewarm started",
		"worker_id", w.id,
		"task_id", next.ID,
		"snapshot_id", snapID,
		"phase", phase,
	)
	// Work on a copy so resolving dataset names doesn't mutate the queued task.
	local := *next
	local.Datasets = resolveDatasets(&local)
	hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
	hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
	hasEnv := false
	// Environment prewarm only applies in container mode with a configured
	// image, and only counts when a warm image tag already exists for the
	// task's dependency manifest hash.
	if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
		if local.Metadata != nil {
			depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
			commitID := strings.TrimSpace(local.Metadata["commit_id"])
			if depsSHA != "" && commitID != "" {
				expMgr := experiment.NewManager(w.config.BasePath)
				hostWorkspace := expMgr.GetFilesPath(commitID)
				if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
					if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
						hasEnv = true
					}
				}
			}
		}
	}
	// Nothing to prewarm at all: clear published state and report no work.
	if !hasSnapshot && !hasDatasets && !hasEnv {
		_ = w.queue.ClearWorkerPrewarmState(w.id)
		return false, nil
	}
	if hasSnapshot {
		want := ""
		if local.Metadata != nil {
			want = local.Metadata["snapshot_sha256"]
		}
		start := time.Now()
		// Resolve the snapshot source (verified against want when provided).
		src, err := ResolveSnapshot(
			prewarmCtx,
			w.config.DataDir,
			&w.config.SnapshotStore,
			local.SnapshotID,
			want,
			nil,
		)
		if err != nil {
			return true, err
		}
		// Stage a per-task copy under .prewarm/snapshots, replacing any
		// leftovers from an earlier attempt for the same task ID.
		dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
		_ = os.RemoveAll(dst)
		if err := copyDir(src, dst); err != nil {
			return true, err
		}
		w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
	}
	if hasDatasets {
		if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
			return true, err
		}
	}
	// Publish the final "ready" state; best-effort, error ignored.
	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     local.ID,
		SnapshotID: local.SnapshotID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      "ready",
		DatasetCnt: len(local.Datasets),
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})
	w.logger.Info("prewarm ready",
		"worker_id", w.id,
		"task_id", local.ID,
		"snapshot_id", local.SnapshotID,
	)
	return true, nil
}
// verifySnapshot validates the integrity of the task's snapshot, when one is
// referenced: the snapshot ID must be a safe name, the task metadata must
// carry a valid snapshot_sha256, and the resolved snapshot directory's
// overall checksum must match it. Tasks without a SnapshotID pass trivially.
func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.SnapshotID == "" {
		// No snapshot referenced; nothing to verify.
		return nil
	}
	snapID := task.SnapshotID
	// The ID becomes part of a filesystem path, so validate it first.
	if err := container.ValidateJobName(snapID); err != nil {
		return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", snapID, err)
	}
	if task.Metadata == nil {
		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", snapID)
	}
	want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
	switch {
	case err != nil:
		return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", snapID, err)
	case want == "":
		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", snapID)
	}
	path, err := ResolveSnapshot(ctx, w.config.DataDir, &w.config.SnapshotStore, snapID, want, nil)
	if err != nil {
		return fmt.Errorf("snapshot %q: resolve failed: %w", snapID, err)
	}
	got, err := dirOverallSHA256Hex(path)
	if err != nil {
		return fmt.Errorf("snapshot %q: checksum verification failed: %w", snapID, err)
	}
	if got != want {
		return fmt.Errorf(
			"snapshot %q: checksum mismatch: expected %s, got %s",
			snapID,
			want,
			got,
		)
	}
	w.logger.Job(
		ctx,
		task.JobName,
		task.ID,
	).Info("snapshot checksum verified", "snapshot_id", snapID)
	return nil
}
func fileSHA256Hex(path string) (string, error) {
f, err := os.Open(filepath.Clean(path))
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return fmt.Sprintf("%x", h.Sum(nil)), nil
}
func normalizeSHA256ChecksumHex(checksum string) (string, error) {
checksum = strings.TrimSpace(checksum)
checksum = strings.TrimPrefix(checksum, "sha256:")
checksum = strings.TrimPrefix(checksum, "SHA256:")
checksum = strings.TrimSpace(checksum)
if checksum == "" {
return "", nil
}
if len(checksum) != 64 {
return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
}
if _, err := hex.DecodeString(checksum); err != nil {
return "", fmt.Errorf("invalid sha256 hex: %w", err)
}
return strings.ToLower(checksum), nil
}
// dirOverallSHA256Hex computes a deterministic SHA-256 digest for a directory
// tree: each file's content hash (lowercase hex) is fed, in ascending
// relative-path order, into an overall SHA-256 whose hex digest is returned.
// Only file contents and their relative ordering participate; names,
// permissions, and empty directories do not affect the result. root must be
// an existing directory.
//
// Fix: the hand-rolled O(n^2) exchange sort is replaced with sort.Strings,
// which yields the same ascending lexicographic order, so previously
// recorded checksums are unaffected.
func dirOverallSHA256Hex(root string) (string, error) {
	root = filepath.Clean(root)
	info, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory")
	}
	// Collect relative paths of every non-directory entry.
	var files []string
	err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		rel, err := filepath.Rel(root, path)
		if err != nil {
			return err
		}
		files = append(files, rel)
		return nil
	})
	if err != nil {
		return "", err
	}
	// Deterministic order.
	sort.Strings(files)
	// Hash the per-file hashes so we never hold all file bytes in memory.
	overall := sha256.New()
	for _, rel := range files {
		sum, err := fileSHA256Hex(filepath.Join(root, rel))
		if err != nil {
			return "", err
		}
		overall.Write([]byte(sum))
	}
	return fmt.Sprintf("%x", overall.Sum(nil)), nil
}
// verifyDatasetSpecs validates the checksum of every dataset spec on the task
// that carries one. Specs without a checksum are skipped silently. Each
// dataset name is validated before being used as a path component under the
// configured DataDir.
func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if len(task.DatasetSpecs) == 0 {
		return nil
	}
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	for _, spec := range task.DatasetSpecs {
		expected, err := normalizeSHA256ChecksumHex(spec.Checksum)
		if err != nil {
			return fmt.Errorf("dataset %q: invalid checksum: %w", spec.Name, err)
		}
		if expected == "" {
			// No checksum pinned for this dataset; nothing to verify.
			continue
		}
		if err := container.ValidateJobName(spec.Name); err != nil {
			return fmt.Errorf("dataset %q: invalid name: %w", spec.Name, err)
		}
		actual, err := dirOverallSHA256Hex(filepath.Join(w.config.DataDir, spec.Name))
		if err != nil {
			return fmt.Errorf("dataset %q: checksum verification failed: %w", spec.Name, err)
		}
		if actual != expected {
			return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", spec.Name, expected, actual)
		}
		logger.Info("dataset checksum verified", "dataset", spec.Name)
	}
	return nil
}
// computeTaskProvenance derives provenance metadata for task: the snapshot
// ID, dataset names and specs, the experiment manifest's overall SHA, and
// the dependency manifest name plus its SHA-256. Manifest-derived fields are
// best-effort — read or hash failures simply omit those keys rather than
// erroring.
func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	out := map[string]string{}
	if task.SnapshotID != "" {
		out["snapshot_id"] = task.SnapshotID
	}
	datasets := resolveDatasets(task)
	if len(datasets) > 0 {
		out["datasets"] = strings.Join(datasets, ",")
	}
	if len(task.DatasetSpecs) > 0 {
		b, err := json.Marshal(task.DatasetSpecs)
		if err != nil {
			return nil, fmt.Errorf("marshal dataset_specs: %w", err)
		}
		out["dataset_specs"] = string(b)
	}
	// Without a commit ID there is no experiment workspace to inspect.
	if task.Metadata == nil {
		return out, nil
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return out, nil
	}
	expMgr := experiment.NewManager(basePath)
	// Experiment manifest overall SHA: recorded only when readable and set.
	manifest, err := expMgr.ReadManifest(commitID)
	if err == nil && manifest != nil && manifest.OverallSHA != "" {
		out["experiment_manifest_overall_sha"] = manifest.OverallSHA
	}
	filesPath := expMgr.GetFilesPath(commitID)
	// Dependency manifest name and content hash, when one is present.
	depName, err := selectDependencyManifest(filesPath)
	if err == nil && depName != "" {
		depPath := filepath.Join(filesPath, depName)
		sha, err := fileSHA256Hex(depPath)
		if err == nil && sha != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}
	return out, nil
}
// recordTaskProvenance computes provenance metadata for task and merges every
// non-empty value into task.Metadata, creating the map if needed.
// Best-effort: compute failures are logged at debug level and ignored, and
// existing metadata keys may be overwritten (Phase 1 semantics).
func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
	if task == nil {
		return
	}
	prov, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
		return
	}
	if len(prov) == 0 {
		return
	}
	if task.Metadata == nil {
		task.Metadata = map[string]string{}
	}
	// Phase 1: best-effort only; do not error if overwriting.
	for key, val := range prov {
		if val != "" {
			task.Metadata[key] = val
		}
	}
}
// enforceTaskProvenance checks that the provenance metadata recorded on the
// task still matches what can be recomputed from the current workspace and
// snapshot contents. In strict mode any missing, uncomputable, or mismatched
// field is a hard error; with config.ProvenanceBestEffort set, each problem
// is logged and the metadata is patched with the freshly computed value
// instead of failing the task.
func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.Metadata == nil {
		return fmt.Errorf("missing task metadata")
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id")
	}
	// Recompute provenance from the workspace for comparison below.
	current, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		return err
	}
	// snapshotCur holds the freshly computed snapshot checksum ("" = unknown).
	snapshotCur := ""
	if task.SnapshotID != "" {
		want := ""
		if task.Metadata != nil {
			want = task.Metadata["snapshot_sha256"]
		}
		wantNorm, nerr := normalizeSHA256ChecksumHex(want)
		if nerr != nil {
			// Malformed recorded checksum: warn-and-continue only in
			// best-effort mode.
			if w.config != nil && w.config.ProvenanceBestEffort {
				w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
					"snapshot_id", task.SnapshotID,
					"error", nerr)
			} else {
				return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
			}
		} else if wantNorm != "" {
			// Resolve the snapshot (may stage it locally) and hash it.
			resolved, err := ResolveSnapshot(
				ctx, w.config.DataDir,
				&w.config.SnapshotStore,
				task.SnapshotID,
				wantNorm,
				nil,
			)
			if err != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID,
						"error", err)
				} else {
					return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
				}
			} else {
				sha, err := dirOverallSHA256Hex(resolved)
				if err == nil {
					snapshotCur = sha
				} else if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID,
						"error", err)
				} else {
					return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
				}
			}
		}
		if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
			// Best-effort fallback: if the caller didn't provide snapshot_sha256,
			// compute from the local snapshot directory if it exists.
			localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
			if sha, err := dirOverallSHA256Hex(localPath); err == nil {
				snapshotCur = sha
			}
		}
	}
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	// requiredField pairs a metadata key with its freshly computed value.
	type requiredField struct {
		Key string
		Cur string
	}
	required := []requiredField{
		{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
		{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
		{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
	}
	if task.SnapshotID != "" {
		required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
	}
	for _, f := range required {
		want := strings.TrimSpace(task.Metadata[f.Key])
		// Snapshot checksums are normalized first so prefix/case differences
		// don't register as mismatches.
		if f.Key == "snapshot_sha256" {
			norm, nerr := normalizeSHA256ChecksumHex(want)
			if nerr != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
						"snapshot_id", task.SnapshotID,
						"error", nerr)
					want = ""
				} else {
					return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
				}
			} else {
				want = norm
			}
		}
		if want == "" {
			// Metadata lacks this field: strict mode errors; best-effort mode
			// backfills from the computed value when one is available.
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("missing provenance field; continuing due to best-effort mode",
					"field", f.Key)
				if f.Cur != "" {
					if f.Key == "snapshot_sha256" {
						task.Metadata[f.Key] = "sha256:" + f.Cur
					} else {
						task.Metadata[f.Key] = f.Cur
					}
				}
				continue
			}
			return fmt.Errorf("missing provenance field: %s", f.Key)
		}
		if f.Cur == "" {
			// No current value could be computed to compare against.
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("unable to compute provenance field; continuing due to best-effort mode",
					"field", f.Key)
				continue
			}
			return fmt.Errorf("unable to compute provenance field: %s", f.Key)
		}
		if want != f.Cur {
			// Mismatch: strict mode fails the task; best-effort mode replaces
			// the stale metadata with the current value and continues.
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("provenance mismatch; continuing due to best-effort mode",
					"field", f.Key,
					"expected", want,
					"current", f.Cur)
				if f.Key == "snapshot_sha256" {
					task.Metadata[f.Key] = "sha256:" + f.Cur
				} else {
					task.Metadata[f.Key] = f.Cur
				}
				continue
			}
			return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
		}
	}
	return nil
}
func selectDependencyManifest(filesPath string) (string, error) {
if filesPath == "" {
return "", fmt.Errorf("missing files path")
}
candidates := []string{
"environment.yml",
"environment.yaml",
"poetry.lock",
"pyproject.toml",
"requirements.txt",
}
for _, name := range candidates {
p := filepath.Join(filesPath, name)
if _, err := os.Stat(p); err == nil {
if name == "poetry.lock" {
pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
if _, err := os.Stat(pyprojectPath); err != nil {
return "", fmt.Errorf(
"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
}
}
return name, nil
}
}
return "", fmt.Errorf(
"missing dependency manifest (supported: environment.yml, environment.yaml, " +
"poetry.lock, pyproject.toml, requirements.txt)")
}
// Exported wrappers for tests under tests/.

// ResolveDatasets exposes resolveDatasets for external tests.
func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }

// SelectDependencyManifest exposes selectDependencyManifest for external tests.
func SelectDependencyManifest(filesPath string) (string, error) {
	return selectDependencyManifest(filesPath)
}

// NormalizeSHA256ChecksumHex exposes normalizeSHA256ChecksumHex for external tests.
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
	return normalizeSHA256ChecksumHex(checksum)
}

// DirOverallSHA256Hex exposes dirOverallSHA256Hex for external tests.
func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }

// ComputeTaskProvenance exposes computeTaskProvenance for external tests.
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	return computeTaskProvenance(basePath, task)
}

// EnforceTaskProvenance exposes enforceTaskProvenance for external tests.
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	return w.enforceTaskProvenance(ctx, task)
}

// VerifyDatasetSpecs exposes verifyDatasetSpecs for external tests.
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	return w.verifyDatasetSpecs(ctx, task)
}

// VerifySnapshot exposes verifySnapshot for external tests.
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
	return w.verifySnapshot(ctx, task)
}
// NewTestWorker builds a minimal Worker for tests: logger and dataset cache
// only, with no queue, metrics, or lifecycle context. A nil cfg is replaced
// by an empty Config, and a zero DatasetCacheTTL falls back to the default.
func NewTestWorker(cfg *Config) *Worker {
	if cfg == nil {
		cfg = &Config{}
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	base := logging.NewLogger(slog.LevelInfo, false)
	traceCtx := logging.EnsureTrace(context.Background())
	return &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		logger:          base.Component(traceCtx, "worker"),
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
	}
}
// NewTestWorkerWithQueue builds a Worker wired to the given queue backend for
// tests, including fresh metrics, a cancellable lifecycle context, and the
// running-task registry. A nil cfg is replaced by an empty Config, and a zero
// DatasetCacheTTL falls back to the default.
func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
	if cfg == nil {
		cfg = &Config{}
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	base := logging.NewLogger(slog.LevelInfo, false)
	ctx, cancel := context.WithCancel(logging.EnsureTrace(context.Background()))
	return &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		logger:          base.Component(ctx, "worker"),
		queue:           tq,
		metrics:         &metrics.Metrics{},
		ctx:             ctx,
		cancel:          cancel,
		running:         make(map[string]context.CancelFunc),
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
	}
}
// NewTestWorkerWithJupyter builds a queue-wired test Worker (see
// NewTestWorkerWithQueue) and attaches the given JupyterManager.
func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
	w := NewTestWorkerWithQueue(cfg, tq)
	w.jupyter = jm
	return w
}