// Package integrity provides data integrity and validation utilities package integrity import ( "encoding/json" "fmt" "os" "path/filepath" "strings" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker/executor" ) // DatasetVerifier validates dataset specifications type DatasetVerifier struct { dataDir string } // NewDatasetVerifier creates a new dataset verifier func NewDatasetVerifier(dataDir string) *DatasetVerifier { return &DatasetVerifier{dataDir: dataDir} } // VerifyDatasetSpecs validates dataset checksums func (v *DatasetVerifier) VerifyDatasetSpecs(task *queue.Task) error { if task == nil { return fmt.Errorf("task is nil") } if len(task.DatasetSpecs) == 0 { return nil } for _, ds := range task.DatasetSpecs { want, err := NormalizeSHA256ChecksumHex(ds.Checksum) if err != nil { return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err) } if want == "" { continue } if err := container.ValidateJobName(ds.Name); err != nil { return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err) } path := filepath.Join(v.dataDir, ds.Name) got, err := DirOverallSHA256Hex(path) if err != nil { return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err) } if got != want { return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got) } } return nil } // ProvenanceCalculator computes task provenance information type ProvenanceCalculator struct { basePath string } // NewProvenanceCalculator creates a new provenance calculator func NewProvenanceCalculator(basePath string) *ProvenanceCalculator { return &ProvenanceCalculator{basePath: basePath} } // ComputeProvenance calculates provenance for a task func (pc *ProvenanceCalculator) ComputeProvenance(task *queue.Task) (map[string]string, error) { if task == nil { return nil, fmt.Errorf("task is nil") } out := map[string]string{} if task.SnapshotID != "" { out["snapshot_id"] = task.SnapshotID } datasets := pc.resolveDatasets(task) if len(datasets) > 0 { out["datasets"] = strings.Join(datasets, ",") } // Add dataset_specs as JSON if len(task.DatasetSpecs) > 0 { specsJSON, err := json.Marshal(task.DatasetSpecs) if err == nil { out["dataset_specs"] = string(specsJSON) } } // Get commit_id from metadata and read experiment manifest if commitID := task.Metadata["commit_id"]; commitID != "" { manifestPath := filepath.Join(pc.basePath, commitID, "manifest.json") if data, err := os.ReadFile(manifestPath); err == nil { var manifest struct { OverallSHA string `json:"overall_sha"` } if err := json.Unmarshal(data, &manifest); err == nil { out["experiment_manifest_overall_sha"] = manifest.OverallSHA } } // Add deps manifest info if available filesPath := filepath.Join(pc.basePath, commitID, "files") depsName := task.Metadata["deps_manifest_name"] if depsName == "" { // Auto-detect manifest file depsName, _ = executor.SelectDependencyManifest(filesPath) } if depsName != "" { out["deps_manifest_name"] = depsName depsPath := filepath.Join(filesPath, depsName) if sha, err := FileSHA256Hex(depsPath); err == nil { out["deps_manifest_sha256"] = sha } } } return out, nil } func (pc *ProvenanceCalculator) resolveDatasets(task *queue.Task) []string { if task == nil { return nil } if len(task.DatasetSpecs) > 0 { out := make([]string, 0, len(task.DatasetSpecs)) for _, ds := range task.DatasetSpecs { if ds.Name != "" { out = append(out, ds.Name) } } if len(out) > 0 { return out } } if len(task.Datasets) > 0 { return task.Datasets } return parseDatasetsFromArgs(task.Args) } func parseDatasetsFromArgs(args string) []string { if !strings.Contains(args, "--datasets") { return nil } parts := strings.Fields(args) for i, part := range parts { if part == "--datasets" && i+1 < len(parts) { return strings.Split(parts[i+1], ",") } } return nil }