fetch_ml/internal/worker/snapshot_store.go
Jeremie Fraeys 158c525bef
fix: resolve benchmark and build tag conflicts
- Remove duplicate hash_selector.go (build tags handle switching)
- Fix benchmark to use worker.DirOverallSHA256Hex
- Fix snapshot_store.go to use integrity.DirOverallSHA256Hex directly
- Native tests pass, benchmarks now correctly test native vs Go
2026-02-21 14:26:48 -05:00

279 lines
6.7 KiB
Go

package worker
import (
"archive/tar"
"compress/gzip"
"context"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
// SnapshotFetcher abstracts the object store that snapshot archives are
// downloaded from, so ResolveSnapshot can be given an alternative
// implementation instead of the MinIO-backed default.
type SnapshotFetcher interface {
// Get opens the object stored under key in bucket and returns it as a
// stream. ResolveSnapshot closes the returned reader once it has finished
// with it.
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
}
// minioSnapshotFetcher is the SnapshotFetcher implementation backed by a
// MinIO client.
type minioSnapshotFetcher struct {
// client is the MinIO client that Get issues GetObject calls on.
client *minio.Client
}
// Get requests the object stored under key in bucket from the MinIO client,
// using default GetObjectOptions, and returns it as an io.ReadCloser.
func (f *minioSnapshotFetcher) Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) {
	object, getErr := f.client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
	if getErr == nil {
		return object, nil
	}
	// Return an untyped nil so the caller never receives a non-nil interface
	// value that wraps a nil *minio.Object.
	return nil, getErr
}
// newMinioSnapshotFetcher builds a MinIO-backed SnapshotFetcher from cfg.
//
// It fails when cfg is nil, when the endpoint or bucket is blank after
// trimming whitespace, or when the MinIO client cannot be constructed.
// Credentials come from cfg.credentials().
func newMinioSnapshotFetcher(cfg *SnapshotStoreConfig) (*minioSnapshotFetcher, error) {
	if cfg == nil {
		return nil, fmt.Errorf("missing snapshot store config")
	}

	endpoint := strings.TrimSpace(cfg.Endpoint)
	switch {
	case endpoint == "":
		return nil, fmt.Errorf("missing snapshot store endpoint")
	case strings.TrimSpace(cfg.Bucket) == "":
		// The bucket is not needed to build the client, but a config without
		// one is rejected here anyway.
		return nil, fmt.Errorf("missing snapshot store bucket")
	}

	opts := &minio.Options{
		Creds:      cfg.credentials(),
		Secure:     cfg.Secure,
		Region:     strings.TrimSpace(cfg.Region),
		MaxRetries: cfg.MaxRetries,
	}
	client, newErr := minio.New(endpoint, opts)
	if newErr != nil {
		return nil, newErr
	}
	return &minioSnapshotFetcher{client: client}, nil
}
// credentials selects the credential source for the snapshot store.
//
// When the receiver is non-nil and both the access key and secret key are
// non-blank after trimming, static V4 credentials are returned (with the
// trimmed session token, which may be empty). In every other case, including a
// nil receiver, a chain of the EnvMinio and EnvAWS providers is returned.
func (c *SnapshotStoreConfig) credentials() *credentials.Credentials {
	if c != nil {
		accessKey, secretKey := strings.TrimSpace(c.AccessKey), strings.TrimSpace(c.SecretKey)
		if accessKey != "" && secretKey != "" {
			return credentials.NewStaticV4(accessKey, secretKey, strings.TrimSpace(c.SessionToken))
		}
	}

	envProviders := []credentials.Provider{
		&credentials.EnvMinio{},
		&credentials.EnvAWS{},
	}
	return credentials.NewChainCredentials(envProviders)
}
// ResolveSnapshot returns the local directory for the snapshot identified by
// snapshotID, downloading and verifying it first when that is required.
//
// Resolution order:
//  1. If dataDir/snapshots/sha256/<checksum> already exists as a directory, it
//     is returned as-is.
//  2. If cfg is nil or not enabled, dataDir/snapshots/<snapshotID> is returned
//     without checking that it exists.
//  3. Otherwise the object <prefix>/<snapshotID>.tar.gz (no prefix segment when
//     cfg.Prefix is blank) is fetched from cfg.Bucket, unpacked into a
//     temporary working directory, its overall SHA-256 is compared against
//     wantSHA256, and the extracted tree is renamed into the checksum-keyed
//     cache directory.
//
// A nil fetcher is replaced by a MinIO-backed fetcher built from cfg. When
// cfg.Timeout is positive, the context handed to the fetcher is bounded by it.
// The temporary working directory is removed before returning.
func ResolveSnapshot(
	ctx context.Context,
	dataDir string,
	cfg *SnapshotStoreConfig,
	snapshotID string,
	wantSHA256 string,
	fetcher SnapshotFetcher,
) (string, error) {
	// Validate inputs.
	if dataDir = strings.TrimSpace(dataDir); dataDir == "" {
		return "", fmt.Errorf("missing data_dir")
	}
	if snapshotID = strings.TrimSpace(snapshotID); snapshotID == "" {
		return "", fmt.Errorf("missing snapshot_id")
	}
	if nameErr := container.ValidateJobName(snapshotID); nameErr != nil {
		return "", fmt.Errorf("invalid snapshot_id: %w", nameErr)
	}
	wantHex, normErr := integrity.NormalizeSHA256ChecksumHex(wantSHA256)
	if normErr != nil || wantHex == "" {
		return "", fmt.Errorf("invalid snapshot_sha256")
	}

	// isDir reports whether p can be stat'ed and is a directory.
	isDir := func(p string) bool {
		st, statErr := os.Stat(p)
		return statErr == nil && st.IsDir()
	}

	// A checksum-keyed cache hit short-circuits everything else.
	cachedDir := filepath.Join(dataDir, "snapshots", "sha256", wantHex)
	if isDir(cachedDir) {
		return cachedDir, nil
	}

	// Without an enabled store, fall back to the ID-keyed local directory.
	if cfg == nil || !cfg.Enabled {
		return filepath.Join(dataDir, "snapshots", snapshotID), nil
	}

	// Work out where the archive lives in the object store.
	bucket := strings.TrimSpace(cfg.Bucket)
	if bucket == "" {
		return "", fmt.Errorf("missing snapshot store bucket")
	}
	objectKey := snapshotID + ".tar.gz"
	if keyPrefix := strings.Trim(strings.TrimSpace(cfg.Prefix), "/"); keyPrefix != "" {
		objectKey = path.Join(keyPrefix, objectKey)
	}

	if fetcher == nil {
		minioFetcher, err := newMinioSnapshotFetcher(cfg)
		if err != nil {
			return "", err
		}
		fetcher = minioFetcher
	}

	// Bound the fetch by the configured timeout, if any. The cancel func is
	// deferred to function exit, so the deadline also covers reading the body.
	fetchCtx := ctx
	if timeout := cfg.Timeout; timeout > 0 {
		timedCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		fetchCtx = timedCtx
	}

	body, err := fetcher.Get(fetchCtx, bucket, objectKey)
	if err != nil {
		return "", err
	}
	defer func() { _ = body.Close() }()

	// Use PathRegistry for consistent directory creation
	dirs := config.FromEnv()
	scratchRoot := filepath.Join(dataDir, "snapshots", ".tmp")
	if err := dirs.EnsureDir(scratchRoot); err != nil {
		return "", err
	}
	workDir, err := os.MkdirTemp(scratchRoot, "fetchml-snapshot-")
	if err != nil {
		return "", err
	}
	defer func() { _ = os.RemoveAll(workDir) }()

	// Spool the archive to disk before extracting it.
	archivePath := filepath.Join(workDir, "snapshot.tar.gz")
	archive, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return "", err
	}
	_, copyErr := io.Copy(archive, body)
	closeErr := archive.Close()
	switch {
	case copyErr != nil:
		// A copy failure takes precedence over any close failure.
		return "", copyErr
	case closeErr != nil:
		return "", closeErr
	}

	extractDir := filepath.Join(workDir, "extracted")
	if err := dirs.EnsureDir(extractDir); err != nil {
		return "", err
	}
	if err := extractTarGz(archivePath, extractDir); err != nil {
		return "", err
	}

	// Verify the extracted tree before it is allowed into the cache.
	gotHex, err := integrity.DirOverallSHA256Hex(extractDir)
	if err != nil {
		return "", err
	}
	if gotHex != wantHex {
		return "", fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", wantHex, gotHex)
	}

	if err := dirs.EnsureDir(filepath.Dir(cachedDir)); err != nil {
		return "", err
	}
	if renameErr := os.Rename(extractDir, cachedDir); renameErr != nil {
		// If the cache directory exists despite the failed rename, accept it
		// (presumably another resolver populated it in the meantime).
		if isDir(cachedDir) {
			return cachedDir, nil
		}
		return "", renameErr
	}
	return cachedDir, nil
}
// extractTarGz unpacks the gzip-compressed tar archive at archivePath into
// dstDir.
//
// Entry names are trimmed of surrounding whitespace, stripped of a leading
// "./" and cleaned; an entry that cleans to "." is skipped. Entries that are
// absolute or that climb out with ".." are rejected with "invalid tar entry",
// and every target is additionally checked by safeJoin. Only directories and
// regular files are supported; any other entry type aborts extraction.
// Directories are created with mode 0750, and regular files take the
// permission bits recorded in the archive header.
func extractTarGz(archivePath, dstDir string) error {
	archivePath = filepath.Clean(archivePath)
	dstDir = filepath.Clean(dstDir)

	src, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	defer func() { _ = src.Close() }()

	gzReader, err := gzip.NewReader(src)
	if err != nil {
		return err
	}
	defer func() { _ = gzReader.Close() }()

	tarReader := tar.NewReader(gzReader)
	for {
		hdr, err := tarReader.Next()
		switch {
		case err == io.EOF:
			// End of archive: everything was extracted.
			return nil
		case err != nil:
			return err
		}

		// Normalise the slash-separated entry name before validating it.
		entry := path.Clean(strings.TrimPrefix(strings.TrimSpace(hdr.Name), "./"))
		if entry == "." {
			continue
		}
		if entry == ".." || strings.HasPrefix(entry, "../") || path.IsAbs(entry) {
			return fmt.Errorf("invalid tar entry")
		}

		target, err := safeJoin(dstDir, filepath.FromSlash(entry))
		if err != nil {
			return err
		}

		if hdr.Typeflag == tar.TypeDir {
			if err := os.MkdirAll(target, 0750); err != nil {
				return err
			}
			continue
		}
		if hdr.Typeflag != tar.TypeReg {
			return fmt.Errorf("unsupported tar entry type")
		}

		// Regular file: ensure its parent exists, then copy exactly hdr.Size
		// bytes from the archive stream.
		if err := os.MkdirAll(filepath.Dir(target), 0750); err != nil {
			return err
		}
		perm := hdr.FileInfo().Mode() & 0777
		dst, err := fileutil.SecureOpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
		if err != nil {
			return err
		}
		_, copyErr := io.CopyN(dst, tarReader, hdr.Size)
		closeErr := dst.Close()
		if copyErr != nil {
			// A copy failure takes precedence over any close failure.
			return copyErr
		}
		if closeErr != nil {
			return closeErr
		}
	}
}
// safeJoin joins rel onto baseDir and rejects any result that lexically
// escapes baseDir.
//
// Containment is decided with filepath.Rel instead of a string-prefix
// comparison. A prefix test breaks when the cleaned base is "." (the joined
// path "a" does not start with "./") or "/" (the prefix would become "//"),
// so valid children of those bases were wrongly rejected.
//
// The check is purely lexical: symbolic links are not resolved.
//
// It returns the cleaned joined path, or an "invalid relative path" error when
// the result falls outside baseDir.
func safeJoin(baseDir, rel string) (string, error) {
	baseDir = filepath.Clean(baseDir)
	// filepath.Join already returns a cleaned path.
	joined := filepath.Join(baseDir, rel)

	within, err := filepath.Rel(baseDir, joined)
	if err != nil {
		return "", fmt.Errorf("invalid relative path")
	}
	// Only an exact ".." or a leading "../" component means escape; a name
	// such as "..hidden" is an ordinary child and must stay allowed.
	if within == ".." || strings.HasPrefix(within, ".."+string(os.PathSeparator)) {
		return "", fmt.Errorf("invalid relative path")
	}
	return joined, nil
}
// ExtractTarGz is an exported wrapper around extractTarGz for
// testing/benchmarking. It unpacks the gzip-compressed tar archive at
// archivePath into dstDir and returns whatever error extractTarGz reports.
func ExtractTarGz(archivePath, dstDir string) error {
return extractTarGz(archivePath, dstDir)
}