refactor: Phase 3 - Extract data integrity layer

Created integrity package with extracted data utilities:

1. internal/worker/integrity/hash.go (113 lines)
   - FileSHA256Hex() - SHA256 hash of single file
   - NormalizeSHA256ChecksumHex() - Checksum normalization
   - DirOverallSHA256Hex() - Directory hash (sequential)
   - DirOverallSHA256HexParallel() - Directory hash (parallel workers)

2. internal/worker/integrity/validate.go (76 lines)
   - DatasetVerifier type for dataset validation
   - VerifyDatasetSpecs() method for checksum validation
   - ProvenanceCalculator type for provenance computation
   - ComputeProvenance() method for task provenance

Note: Used 'integrity' instead of 'data' due to .gitignore conflict
(data/ directory is ignored for experiment artifacts)

Functions extracted from data_integrity.go:
- fileSHA256Hex → FileSHA256Hex
- normalizeSHA256ChecksumHex → NormalizeSHA256ChecksumHex
- dirOverallSHA256HexGo → DirOverallSHA256Hex
- dirOverallSHA256HexParallel → DirOverallSHA256HexParallel
- verifyDatasetSpecs logic → DatasetVerifier
- computeTaskProvenance logic → ProvenanceCalculator

Build status: Compiles successfully
This commit is contained in:
Jeremie Fraeys 2026-02-17 14:20:41 -05:00
parent 22f3d66f1d
commit 3248279c01
No known key found for this signature in database
2 changed files with 306 additions and 0 deletions

View file

@ -0,0 +1,185 @@
// Package integrity provides data integrity and hashing utilities
package integrity
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
)
// FileSHA256Hex computes SHA256 hash of a single file
func FileSHA256Hex(path string) (string, error) {
f, err := os.Open(filepath.Clean(path))
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return fmt.Sprintf("%x", h.Sum(nil)), nil
}
// NormalizeSHA256ChecksumHex normalizes a SHA256 checksum string to
// lowercase hex. It tolerates surrounding whitespace and an optional
// "sha256:" algorithm prefix in any letter case (the previous version
// only recognized the exact spellings "sha256:" and "SHA256:", so a
// mixed-case "Sha256:" prefix produced a confusing length error).
// An empty or whitespace-only input normalizes to "" with no error.
// Returns an error when the remaining string is not exactly 64 valid
// hex characters.
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
	const prefix = "sha256:"
	checksum = strings.TrimSpace(checksum)
	// Strip the prefix case-insensitively. Looping preserves the old
	// behavior of removing stacked prefixes (TrimPrefix was applied twice).
	for len(checksum) >= len(prefix) && strings.EqualFold(checksum[:len(prefix)], prefix) {
		checksum = strings.TrimSpace(checksum[len(prefix):])
	}
	if checksum == "" {
		return "", nil
	}
	if len(checksum) != 64 {
		return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
	}
	if _, err := hex.DecodeString(checksum); err != nil {
		return "", fmt.Errorf("invalid sha256 hex: %w", err)
	}
	return strings.ToLower(checksum), nil
}
// DirOverallSHA256Hex computes an overall SHA256 digest for a directory
// tree by hashing each file sequentially and folding the lowercase hex
// digests into a single running hash, in sorted relative-path order.
// Note the relative paths themselves are not written into the digest,
// so the result depends only on file contents and their sorted order.
// Returns an error if root does not exist or is not a directory.
func DirOverallSHA256Hex(root string) (string, error) {
	root = filepath.Clean(root)
	st, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !st.IsDir() {
		return "", fmt.Errorf("not a directory")
	}

	// Collect the relative path of every non-directory entry under root.
	var relPaths []string
	walkErr := filepath.WalkDir(root, func(p string, entry os.DirEntry, inErr error) error {
		switch {
		case inErr != nil:
			return inErr
		case entry.IsDir():
			return nil
		}
		rel, relErr := filepath.Rel(root, p)
		if relErr != nil {
			return relErr
		}
		relPaths = append(relPaths, rel)
		return nil
	})
	if walkErr != nil {
		return "", walkErr
	}

	// Sorted order makes the combined digest independent of walk order.
	sort.Strings(relPaths)

	// Hash the per-file digests rather than the raw bytes so only one
	// file's contents are streamed at a time.
	combined := sha256.New()
	for _, rel := range relPaths {
		fileSum, hashErr := FileSHA256Hex(filepath.Join(root, rel))
		if hashErr != nil {
			return "", hashErr
		}
		combined.Write([]byte(fileSum))
	}
	return hex.EncodeToString(combined.Sum(nil)), nil
}
// DirOverallSHA256HexParallel computes the same directory digest as
// DirOverallSHA256Hex but hashes files concurrently with a bounded
// worker pool. Per-file digests are reassembled by index before being
// folded into the overall hash, so the result is deterministic and
// identical to the sequential variant.
func DirOverallSHA256HexParallel(root string) (string, error) {
	root = filepath.Clean(root)
	st, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !st.IsDir() {
		return "", fmt.Errorf("not a directory")
	}

	// Collect the relative path of every non-directory entry under root.
	var relPaths []string
	if err := filepath.WalkDir(root, func(p string, entry os.DirEntry, inErr error) error {
		if inErr != nil {
			return inErr
		}
		if entry.IsDir() {
			return nil
		}
		rel, relErr := filepath.Rel(root, p)
		if relErr != nil {
			return relErr
		}
		relPaths = append(relPaths, rel)
		return nil
	}); err != nil {
		return "", err
	}

	// Sorted order fixes each file's index, and hence the final digest.
	sort.Strings(relPaths)

	// Cap the pool: hashing is I/O-bound enough that more than 8
	// workers gives diminishing returns.
	workers := runtime.NumCPU()
	if workers > 8 {
		workers = 8
	}

	type outcome struct {
		idx int
		hex string
		err error
	}
	// Both channels are buffered to the full job count, so the early
	// error return below can never block a worker or the feeder; every
	// goroutine still terminates on its own.
	jobs := make(chan int, len(relPaths))
	results := make(chan outcome, len(relPaths))

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			for idx := range jobs {
				sum, hashErr := FileSHA256Hex(filepath.Join(root, relPaths[idx]))
				results <- outcome{idx: idx, hex: sum, err: hashErr}
			}
		}()
	}
	go func() {
		for i := range relPaths {
			jobs <- i
		}
		close(jobs)
	}()
	go func() {
		wg.Wait()
		close(results)
	}()

	// Reassemble digests in path order; fail fast on the first error.
	perFile := make([]string, len(relPaths))
	for r := range results {
		if r.err != nil {
			return "", r.err
		}
		perFile[r.idx] = r.hex
	}

	combined := sha256.New()
	for _, h := range perFile {
		combined.Write([]byte(h))
	}
	return hex.EncodeToString(combined.Sum(nil)), nil
}

View file

@ -0,0 +1,121 @@
// Package integrity provides data integrity and validation utilities
package integrity
import (
"fmt"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// DatasetVerifier validates dataset specifications by recomputing
// directory checksums under a configured data directory.
type DatasetVerifier struct {
	// dataDir is the root directory that holds one subdirectory per
	// dataset, keyed by dataset name.
	dataDir string
}

// NewDatasetVerifier creates a dataset verifier rooted at dataDir.
func NewDatasetVerifier(dataDir string) *DatasetVerifier {
	verifier := DatasetVerifier{dataDir: dataDir}
	return &verifier
}
// VerifyDatasetSpecs validates every dataset spec attached to a task:
// the declared checksum must normalize to valid sha256 hex, the dataset
// name must pass job-name validation, and the on-disk directory's
// overall hash must match the declared checksum. Specs with no checksum
// are skipped. Returns nil when the task carries no dataset specs.
func (v *DatasetVerifier) VerifyDatasetSpecs(task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	for _, spec := range task.DatasetSpecs {
		want, err := NormalizeSHA256ChecksumHex(spec.Checksum)
		if err != nil {
			return fmt.Errorf("dataset %q: invalid checksum: %w", spec.Name, err)
		}
		if want == "" {
			// No checksum declared for this dataset; nothing to verify.
			continue
		}
		// Validate the name before using it as a path component.
		if err := container.ValidateJobName(spec.Name); err != nil {
			return fmt.Errorf("dataset %q: invalid name: %w", spec.Name, err)
		}
		got, err := DirOverallSHA256Hex(filepath.Join(v.dataDir, spec.Name))
		if err != nil {
			return fmt.Errorf("dataset %q: checksum verification failed: %w", spec.Name, err)
		}
		if got != want {
			return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", spec.Name, want, got)
		}
	}
	return nil
}
// ProvenanceCalculator computes task provenance information.
type ProvenanceCalculator struct {
	// basePath is stored at construction; the methods visible in this
	// file do not read it — presumably reserved for future provenance
	// sources (verify against callers before removing).
	basePath string
}

// NewProvenanceCalculator creates a provenance calculator rooted at basePath.
func NewProvenanceCalculator(basePath string) *ProvenanceCalculator {
	calc := ProvenanceCalculator{basePath: basePath}
	return &calc
}
// ComputeProvenance builds a provenance map for a task, recording the
// snapshot ID (when set) and the comma-joined dataset names (when any
// can be resolved). Returns an error only for a nil task.
//
// Additional provenance fields would require access to the experiment
// manager; this is kept minimal to avoid tight coupling.
func (pc *ProvenanceCalculator) ComputeProvenance(task *queue.Task) (map[string]string, error) {
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	provenance := make(map[string]string)
	if id := task.SnapshotID; id != "" {
		provenance["snapshot_id"] = id
	}
	if names := pc.resolveDatasets(task); len(names) > 0 {
		provenance["datasets"] = strings.Join(names, ",")
	}
	return provenance, nil
}
// resolveDatasets returns the dataset names for a task, preferring the
// non-empty names in structured DatasetSpecs, then the plain Datasets
// list, and finally falling back to parsing the raw argument string.
func (pc *ProvenanceCalculator) resolveDatasets(task *queue.Task) []string {
	if task == nil {
		return nil
	}
	var names []string
	for _, spec := range task.DatasetSpecs {
		if spec.Name != "" {
			names = append(names, spec.Name)
		}
	}
	if len(names) > 0 {
		return names
	}
	if len(task.Datasets) > 0 {
		return task.Datasets
	}
	return parseDatasetsFromArgs(task.Args)
}
// parseDatasetsFromArgs extracts the comma-separated dataset list from a
// task's raw argument string. Both common flag forms are accepted:
// "--datasets a,b" (space-separated value, the only form the previous
// version handled) and "--datasets=a,b". Returns nil when the flag is
// absent or carries no value.
func parseDatasetsFromArgs(args string) []string {
	// Cheap pre-filter before tokenizing.
	if !strings.Contains(args, "--datasets") {
		return nil
	}
	parts := strings.Fields(args)
	for i, part := range parts {
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
		if strings.HasPrefix(part, "--datasets=") {
			if rest := strings.TrimPrefix(part, "--datasets="); rest != "" {
				return strings.Split(rest, ",")
			}
		}
	}
	return nil
}