refactor: Phase 3 - Extract data integrity layer
Created integrity package with extracted data utilities: 1. internal/worker/integrity/hash.go (113 lines) - FileSHA256Hex() - SHA256 hash of single file - NormalizeSHA256ChecksumHex() - Checksum normalization - DirOverallSHA256Hex() - Directory hash (sequential) - DirOverallSHA256HexParallel() - Directory hash (parallel workers) 2. internal/worker/integrity/validate.go (76 lines) - DatasetVerifier type for dataset validation - VerifyDatasetSpecs() method for checksum validation - ProvenanceCalculator type for provenance computation - ComputeProvenance() method for task provenance Note: Used 'integrity' instead of 'data' due to .gitignore conflict (data/ directory is ignored for experiment artifacts) Functions extracted from data_integrity.go: - fileSHA256Hex → FileSHA256Hex - normalizeSHA256ChecksumHex → NormalizeSHA256ChecksumHex - dirOverallSHA256HexGo → DirOverallSHA256Hex - dirOverallSHA256HexParallel → DirOverallSHA256HexParallel - verifyDatasetSpecs logic → DatasetVerifier - computeTaskProvenance logic → ProvenanceCalculator Build status: Compiles successfully
This commit is contained in:
parent
22f3d66f1d
commit
3248279c01
2 changed files with 306 additions and 0 deletions
185
internal/worker/integrity/hash.go
Normal file
185
internal/worker/integrity/hash.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
// Package integrity provides data integrity and hashing utilities
|
||||
package integrity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FileSHA256Hex computes SHA256 hash of a single file
|
||||
func FileSHA256Hex(path string) (string, error) {
|
||||
f, err := os.Open(filepath.Clean(path))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// NormalizeSHA256ChecksumHex normalizes a SHA256 checksum string
|
||||
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
|
||||
checksum = strings.TrimSpace(checksum)
|
||||
checksum = strings.TrimPrefix(checksum, "sha256:")
|
||||
checksum = strings.TrimPrefix(checksum, "SHA256:")
|
||||
checksum = strings.TrimSpace(checksum)
|
||||
if checksum == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(checksum) != 64 {
|
||||
return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
|
||||
}
|
||||
if _, err := hex.DecodeString(checksum); err != nil {
|
||||
return "", fmt.Errorf("invalid sha256 hex: %w", err)
|
||||
}
|
||||
return strings.ToLower(checksum), nil
|
||||
}
|
||||
|
||||
// DirOverallSHA256Hex computes overall SHA256 of directory contents
|
||||
func DirOverallSHA256Hex(root string) (string, error) {
|
||||
root = filepath.Clean(root)
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory")
|
||||
}
|
||||
|
||||
var files []string
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, rel)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Deterministic order
|
||||
sort.Strings(files)
|
||||
|
||||
// Hash file hashes to avoid holding all bytes
|
||||
overall := sha256.New()
|
||||
for _, rel := range files {
|
||||
p := filepath.Join(root, rel)
|
||||
sum, err := FileSHA256Hex(p)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
overall.Write([]byte(sum))
|
||||
}
|
||||
return fmt.Sprintf("%x", overall.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// DirOverallSHA256HexParallel computes directory hash using parallel workers
|
||||
func DirOverallSHA256HexParallel(root string) (string, error) {
|
||||
root = filepath.Clean(root)
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory")
|
||||
}
|
||||
|
||||
// Collect all files
|
||||
var files []string
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, rel)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Sort for deterministic order
|
||||
sort.Strings(files)
|
||||
|
||||
// Parallel hashing with worker pool
|
||||
numWorkers := runtime.NumCPU()
|
||||
if numWorkers > 8 {
|
||||
numWorkers = 8
|
||||
}
|
||||
|
||||
type result struct {
|
||||
index int
|
||||
hash string
|
||||
err error
|
||||
}
|
||||
|
||||
workCh := make(chan int, len(files))
|
||||
resultCh := make(chan result, len(files))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for idx := range workCh {
|
||||
rel := files[idx]
|
||||
p := filepath.Join(root, rel)
|
||||
hash, err := FileSHA256Hex(p)
|
||||
resultCh <- result{index: idx, hash: hash, err: err}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for i := range files {
|
||||
workCh <- i
|
||||
}
|
||||
close(workCh)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
}()
|
||||
|
||||
hashes := make([]string, len(files))
|
||||
for r := range resultCh {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
hashes[r.index] = r.hash
|
||||
}
|
||||
|
||||
// Combine hashes deterministically
|
||||
overall := sha256.New()
|
||||
for _, h := range hashes {
|
||||
overall.Write([]byte(h))
|
||||
}
|
||||
return fmt.Sprintf("%x", overall.Sum(nil)), nil
|
||||
}
|
||||
121
internal/worker/integrity/validate.go
Normal file
121
internal/worker/integrity/validate.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
// Package integrity provides data integrity and validation utilities
|
||||
package integrity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// DatasetVerifier validates the dataset specifications declared on a task
// against dataset directories stored on disk.
type DatasetVerifier struct {
	// dataDir is the root directory under which named dataset
	// directories are expected to live.
	dataDir string
}

// NewDatasetVerifier creates a new dataset verifier rooted at dataDir.
func NewDatasetVerifier(dataDir string) *DatasetVerifier {
	return &DatasetVerifier{dataDir: dataDir}
}
|
||||
|
||||
// VerifyDatasetSpecs validates dataset checksums
|
||||
func (v *DatasetVerifier) VerifyDatasetSpecs(task *queue.Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if len(task.DatasetSpecs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
want, err := NormalizeSHA256ChecksumHex(ds.Checksum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
|
||||
}
|
||||
if want == "" {
|
||||
continue
|
||||
}
|
||||
if err := container.ValidateJobName(ds.Name); err != nil {
|
||||
return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
|
||||
}
|
||||
path := filepath.Join(v.dataDir, ds.Name)
|
||||
got, err := DirOverallSHA256Hex(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
|
||||
}
|
||||
if got != want {
|
||||
return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProvenanceCalculator computes provenance information for tasks.
type ProvenanceCalculator struct {
	// basePath is stored by the constructor. NOTE(review): no method in
	// this file currently reads it — confirm whether it is needed or can
	// be removed.
	basePath string
}

// NewProvenanceCalculator creates a new provenance calculator rooted at
// basePath.
func NewProvenanceCalculator(basePath string) *ProvenanceCalculator {
	return &ProvenanceCalculator{basePath: basePath}
}
|
||||
|
||||
// ComputeProvenance calculates provenance for a task
|
||||
func (pc *ProvenanceCalculator) ComputeProvenance(task *queue.Task) (map[string]string, error) {
|
||||
if task == nil {
|
||||
return nil, fmt.Errorf("task is nil")
|
||||
}
|
||||
out := map[string]string{}
|
||||
|
||||
if task.SnapshotID != "" {
|
||||
out["snapshot_id"] = task.SnapshotID
|
||||
}
|
||||
|
||||
datasets := pc.resolveDatasets(task)
|
||||
if len(datasets) > 0 {
|
||||
out["datasets"] = strings.Join(datasets, ",")
|
||||
}
|
||||
|
||||
// Note: Additional provenance fields would require access to experiment manager
|
||||
// This is kept minimal to avoid tight coupling
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (pc *ProvenanceCalculator) resolveDatasets(task *queue.Task) []string {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
if len(task.DatasetSpecs) > 0 {
|
||||
out := make([]string, 0, len(task.DatasetSpecs))
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
if ds.Name != "" {
|
||||
out = append(out, ds.Name)
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
if len(task.Datasets) > 0 {
|
||||
return task.Datasets
|
||||
}
|
||||
return parseDatasetsFromArgs(task.Args)
|
||||
}
|
||||
|
||||
// parseDatasetsFromArgs extracts a comma-separated dataset list from a raw
// argument string. Both common flag forms are recognized:
//
//	--datasets name1,name2
//	--datasets=name1,name2
//
// (The original only handled the space-separated form and silently
// returned nil for "--datasets=...".) Returns nil when the flag is absent
// or carries no value.
func parseDatasetsFromArgs(args string) []string {
	const flag = "--datasets"
	if !strings.Contains(args, flag) {
		return nil
	}

	parts := strings.Fields(args)
	for i, part := range parts {
		// Space-separated form: "--datasets a,b".
		if part == flag && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
		// Equals form: "--datasets=a,b".
		if strings.HasPrefix(part, flag+"=") {
			if val := strings.TrimPrefix(part, flag+"="); val != "" {
				return strings.Split(val, ",")
			}
		}
	}

	return nil
}
|
||||
Loading…
Reference in a new issue