fetch_ml/internal/worker/integrity/validate.go
Jeremie Fraeys 0b5e99f720
refactor(scheduler,worker): improve service management and GPU detection
Scheduler enhancements:
- auth.go: Group membership validation in authentication
- hub.go: Task distribution with group affinity
- port_allocator.go: Dynamic port allocation with conflict resolution
- scheduler_conn.go: Connection pooling and retry logic
- service_manager.go: Lifecycle management for scheduler services
- service_templates.go: Template-based service configuration
- state.go: Persistent state management with recovery

Worker improvements:
- config.go: Extended configuration for task visibility rules
- execution/setup.go: Sandboxed execution environment setup
- executor/container.go: Container runtime integration
- executor/runner.go: Task runner with visibility enforcement
- gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback)
- integrity/validate.go: Data integrity validation
- lifecycle/runloop.go: Improved runloop with graceful shutdown
- lifecycle/service_manager.go: Service lifecycle coordination
- process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups
- tenant/manager.go: Multi-tenant resource isolation
- tenant/middleware.go: Tenant context propagation
- worker.go: Core worker with group-scoped task execution
2026-03-08 13:03:15 -04:00

158 lines
4.1 KiB
Go

// Package integrity provides data integrity and validation utilities.
package integrity
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
)
// DatasetVerifier checks dataset specifications against the datasets stored
// beneath a single root directory.
type DatasetVerifier struct {
	// dataDir is the root under which each dataset lives in a directory
	// named after the dataset.
	dataDir string
}

// NewDatasetVerifier returns a DatasetVerifier rooted at dataDir.
func NewDatasetVerifier(dataDir string) *DatasetVerifier {
	verifier := new(DatasetVerifier)
	verifier.dataDir = dataDir
	return verifier
}
// VerifyDatasetSpecs compares every dataset spec on task that carries a
// checksum against the SHA-256 digest computed over the dataset's directory
// (dataDir/<name>). Specs whose normalized checksum is empty are skipped.
//
// It returns an error for a nil task, for a checksum that fails
// normalization, for a dataset name rejected by container.ValidateJobName,
// when the directory digest cannot be computed, or on a digest mismatch.
// The first failing spec stops verification.
func (v *DatasetVerifier) VerifyDatasetSpecs(task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	// Ranging over an empty DatasetSpecs does nothing, so an empty task
	// falls through to the final nil return.
	for _, spec := range task.DatasetSpecs {
		expected, err := NormalizeSHA256ChecksumHex(spec.Checksum)
		switch {
		case err != nil:
			return fmt.Errorf("dataset %q: invalid checksum: %w", spec.Name, err)
		case expected == "":
			// No checksum to verify for this dataset.
			continue
		}

		// The name becomes a path component below, so it is validated first.
		if nameErr := container.ValidateJobName(spec.Name); nameErr != nil {
			return fmt.Errorf("dataset %q: invalid name: %w", spec.Name, nameErr)
		}

		actual, hashErr := DirOverallSHA256Hex(filepath.Join(v.dataDir, spec.Name))
		if hashErr != nil {
			return fmt.Errorf("dataset %q: checksum verification failed: %w", spec.Name, hashErr)
		}
		if actual != expected {
			return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", spec.Name, expected, actual)
		}
	}
	return nil
}
// ProvenanceCalculator derives provenance metadata for tasks, reading
// per-commit experiment files from beneath a base directory.
type ProvenanceCalculator struct {
	// basePath is the directory containing one subdirectory per commit ID.
	basePath string
}

// NewProvenanceCalculator returns a ProvenanceCalculator rooted at basePath.
func NewProvenanceCalculator(basePath string) *ProvenanceCalculator {
	calc := new(ProvenanceCalculator)
	calc.basePath = basePath
	return calc
}
// ComputeProvenance collects provenance metadata for task into a flat string
// map. Each entry is added only when its source is available:
//
//   - "snapshot_id": task.SnapshotID, when non-empty.
//   - "datasets": the comma-joined names returned by resolveDatasets.
//   - "dataset_specs": task.DatasetSpecs encoded as JSON (omitted if
//     encoding fails).
//   - "experiment_manifest_overall_sha": the "overall_sha" field of
//     <basePath>/<commit_id>/manifest.json, when task.Metadata has a
//     commit_id and that file is readable and decodes as JSON.
//   - "deps_manifest_name" / "deps_manifest_sha256": a dependency manifest
//     under <basePath>/<commit_id>/files, named by the "deps_manifest_name"
//     metadata key or auto-detected, and its SHA-256 when it can be hashed.
//
// Unreadable or malformed files are skipped rather than reported. An error is
// returned for a nil task, or when the commit_id or deps_manifest_name
// metadata would resolve outside its base directory (for example "../other"
// or an absolute path). On error the returned map is nil.
func (pc *ProvenanceCalculator) ComputeProvenance(task *queue.Task) (map[string]string, error) {
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}

	out := map[string]string{}
	if task.SnapshotID != "" {
		out["snapshot_id"] = task.SnapshotID
	}
	if datasets := pc.resolveDatasets(task); len(datasets) > 0 {
		out["datasets"] = strings.Join(datasets, ",")
	}
	if len(task.DatasetSpecs) > 0 {
		// Best effort: a marshal failure omits the key instead of failing
		// the whole computation.
		if specsJSON, err := json.Marshal(task.DatasetSpecs); err == nil {
			out["dataset_specs"] = string(specsJSON)
		}
	}

	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return out, nil
	}
	// commit_id comes from task metadata. Reject values that would escape
	// basePath (e.g. "../x" or "/etc") before using it to build file paths.
	if !filepath.IsLocal(commitID) {
		return nil, fmt.Errorf("invalid commit_id %q: must be a relative path inside the experiment base directory", commitID)
	}
	commitDir := filepath.Join(pc.basePath, commitID)

	// Experiment manifest: record its overall SHA when the file is present.
	manifestPath := filepath.Join(commitDir, "manifest.json")
	// #nosec G304 -- commitID is checked with filepath.IsLocal above, so the path stays under basePath
	if data, err := os.ReadFile(manifestPath); err == nil {
		var manifest struct {
			OverallSHA string `json:"overall_sha"`
		}
		if err := json.Unmarshal(data, &manifest); err == nil {
			out["experiment_manifest_overall_sha"] = manifest.OverallSHA
		}
	}

	// Dependency manifest: an explicit name from metadata, else auto-detected.
	filesPath := filepath.Join(commitDir, "files")
	depsName := task.Metadata["deps_manifest_name"]
	if depsName == "" {
		// The detection error is deliberately ignored; provenance is best effort.
		depsName, _ = executor.SelectDependencyManifest(filesPath)
	} else if !filepath.IsLocal(depsName) {
		// A caller-supplied name must not reach outside the files directory.
		return nil, fmt.Errorf("invalid deps_manifest_name %q: must be a relative path inside the commit files directory", depsName)
	}
	if depsName != "" {
		out["deps_manifest_name"] = depsName
		if sha, err := FileSHA256Hex(filepath.Join(filesPath, depsName)); err == nil {
			out["deps_manifest_sha256"] = sha
		}
	}

	return out, nil
}
// resolveDatasets returns the dataset names attached to task. Sources are
// tried in this order of preference:
//
//   - the non-empty Name values in task.DatasetSpecs
//   - task.Datasets, returned as-is
//   - a "--datasets" flag parsed out of task.Args
//
// A nil task yields nil.
func (pc *ProvenanceCalculator) resolveDatasets(task *queue.Task) []string {
	if task == nil {
		return nil
	}

	names := make([]string, 0, len(task.DatasetSpecs))
	for _, spec := range task.DatasetSpecs {
		if spec.Name == "" {
			continue
		}
		names = append(names, spec.Name)
	}

	switch {
	case len(names) > 0:
		return names
	case len(task.Datasets) > 0:
		return task.Datasets
	default:
		return parseDatasetsFromArgs(task.Args)
	}
}
func parseDatasetsFromArgs(args string) []string {
if !strings.Contains(args, "--datasets") {
return nil
}
parts := strings.Fields(args)
for i, part := range parts {
if part == "--datasets" && i+1 < len(parts) {
return strings.Split(parts[i+1], ",")
}
}
return nil
}