fetch_ml/internal/api/helpers/hash_helpers.go
Jeremie Fraeys 4c8c9dfe4b
refactor: Export SelectDependencyManifest for API helpers
- Renamed selectDependencyManifest to SelectDependencyManifest (exported)
- Added re-export in worker package for backward compatibility
- Updated internal call in container.go to use exported function
- API helpers can now access via worker.SelectDependencyManifest

Build status: Compiles successfully
2026-02-17 16:45:59 -05:00

121 lines
3.3 KiB
Go

// Package helpers provides shared utilities for WebSocket handlers.
package helpers
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
// ComputeDatasetID derives a short, deterministic identifier from dataset
// specs or, failing that, from plain dataset names.
//
// Dataset specs take precedence: each spec contributes its Checksum, or its
// Name when the checksum is empty; specs with neither are skipped. If no spec
// contributes anything, the entries of datasets are used instead. The
// contributing strings are concatenated in order, without a separator, and
// hashed with SHA-256; the first 16 hex characters of the digest are returned.
// The result is "" when no spec contributes and datasets is empty.
func ComputeDatasetID(datasetSpecs []queue.DatasetSpec, datasets []string) string {
	// shortDigest hashes the concatenation of parts and keeps 16 hex characters.
	shortDigest := func(parts []string) string {
		sum := sha256.Sum256([]byte(strings.Join(parts, "")))
		return hex.EncodeToString(sum[:])[:16]
	}

	var keys []string
	for _, spec := range datasetSpecs {
		switch {
		case spec.Checksum != "":
			keys = append(keys, spec.Checksum)
		case spec.Name != "":
			keys = append(keys, spec.Name)
		}
	}

	switch {
	case len(keys) > 0:
		return shortDigest(keys)
	case len(datasets) > 0:
		return shortDigest(datasets)
	default:
		return ""
	}
}
// ComputeParamsHash returns a short fingerprint of the args string.
//
// Leading and trailing whitespace is ignored. Blank input yields "";
// otherwise the result is the first 16 hex characters of the SHA-256 digest
// of the trimmed string.
func ComputeParamsHash(args string) string {
	trimmed := strings.TrimSpace(args)
	if trimmed == "" {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])[:16]
}
// FileSHA256Hex computes the SHA256 hash of the file at path.
//
// It is a thin wrapper that forwards to integrity.FileSHA256Hex so that file
// hashing stays consistent with that package; the result and error are
// returned exactly as that function produces them.
func FileSHA256Hex(path string) (string, error) {
return integrity.FileSHA256Hex(path)
}
// ExpectedProvenanceForCommit computes expected provenance metadata for a commit.
//
// On success the returned map always contains
// "experiment_manifest_overall_sha", taken from the commit's manifest. When a
// dependency manifest can be selected in the commit's files directory and
// hashed, "deps_manifest_name" and "deps_manifest_sha256" are included too.
//
// An error is returned when expMgr is nil, when the manifest cannot be read,
// or when the manifest is nil or has an empty OverallSHA. Failures while
// selecting or hashing the dependency manifest are not errors: those two keys
// are simply omitted.
func ExpectedProvenanceForCommit(
	expMgr *experiment.Manager,
	commitID string,
) (map[string]string, error) {
	// Reject a nil manager up front, as EnsureMinimalExperimentFiles does,
	// instead of calling methods on it.
	if expMgr == nil {
		return nil, fmt.Errorf("missing experiment manager")
	}

	manifest, err := expMgr.ReadManifest(commitID)
	if err != nil {
		// Returned unwrapped so callers can keep inspecting the original error.
		return nil, err
	}
	if manifest == nil || manifest.OverallSHA == "" {
		return nil, fmt.Errorf("missing manifest overall_sha")
	}

	out := map[string]string{
		"experiment_manifest_overall_sha": manifest.OverallSHA,
	}

	// Dependency manifest details are best effort: any failure below returns
	// the map without them rather than an error.
	filesPath := expMgr.GetFilesPath(commitID)
	depName, err := worker.SelectDependencyManifest(filesPath)
	if err != nil || strings.TrimSpace(depName) == "" {
		return out, nil
	}
	sha, err := FileSHA256Hex(filepath.Join(filesPath, depName))
	if err != nil || strings.TrimSpace(sha) == "" {
		return out, nil
	}
	out["deps_manifest_name"] = depName
	out["deps_manifest_sha256"] = sha
	return out, nil
}
// EnsureMinimalExperimentFiles ensures minimal experiment files exist.
//
// It creates the files directory for the trimmed commitID via os.MkdirAll
// (0750) and then, for train.py and requirements.txt in that order, writes a
// small placeholder through fileutil.SecureFileWrite (0640) whenever os.Stat
// reports that the file does not exist. Paths for which Stat reports anything
// else are left untouched.
//
// An error is returned for a nil manager, a blank commit ID, or a failed
// directory creation or file write.
func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error {
	if expMgr == nil {
		return fmt.Errorf("missing experiment manager")
	}
	id := strings.TrimSpace(commitID)
	if id == "" {
		return fmt.Errorf("missing commit id")
	}

	dir := expMgr.GetFilesPath(id)
	if err := os.MkdirAll(dir, 0750); err != nil {
		return err
	}

	// Placeholder files, processed in this order.
	placeholders := []struct {
		name    string
		content string
	}{
		{"train.py", "print('ok')\n"},
		{"requirements.txt", "numpy==1.0.0\n"},
	}
	for _, p := range placeholders {
		target := filepath.Join(dir, p.name)
		// Only a definite "does not exist" triggers a write; any other Stat
		// outcome (success or a different error) leaves the path alone.
		if _, statErr := os.Stat(target); !os.IsNotExist(statErr) {
			continue
		}
		if err := fileutil.SecureFileWrite(target, []byte(p.content), 0640); err != nil {
			return err
		}
	}
	return nil
}