- Remove duplicate hash_selector.go (build tags handle switching) - Fix benchmark to use worker.DirOverallSHA256Hex - Fix snapshot_store.go to use integrity.DirOverallSHA256Hex directly - Native tests pass, benchmarks now correctly test native vs Go
279 lines
6.7 KiB
Go
279 lines
6.7 KiB
Go
package worker
|
|
|
|
import (
|
|
"archive/tar"
|
|
"compress/gzip"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
|
"github.com/minio/minio-go/v7"
|
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
|
)
|
|
|
|
type SnapshotFetcher interface {
|
|
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
|
|
}
|
|
|
|
type minioSnapshotFetcher struct {
|
|
client *minio.Client
|
|
}
|
|
|
|
func (f *minioSnapshotFetcher) Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) {
|
|
obj, err := f.client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return obj, nil
|
|
}
|
|
|
|
func newMinioSnapshotFetcher(cfg *SnapshotStoreConfig) (*minioSnapshotFetcher, error) {
|
|
if cfg == nil {
|
|
return nil, fmt.Errorf("missing snapshot store config")
|
|
}
|
|
endpoint := strings.TrimSpace(cfg.Endpoint)
|
|
if endpoint == "" {
|
|
return nil, fmt.Errorf("missing snapshot store endpoint")
|
|
}
|
|
bucket := strings.TrimSpace(cfg.Bucket)
|
|
if bucket == "" {
|
|
return nil, fmt.Errorf("missing snapshot store bucket")
|
|
}
|
|
creds := cfg.credentials()
|
|
client, err := minio.New(endpoint, &minio.Options{
|
|
Creds: creds,
|
|
Secure: cfg.Secure,
|
|
Region: strings.TrimSpace(cfg.Region),
|
|
MaxRetries: cfg.MaxRetries,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &minioSnapshotFetcher{client: client}, nil
|
|
}
|
|
|
|
func (c *SnapshotStoreConfig) credentials() *credentials.Credentials {
|
|
if c != nil {
|
|
ak := strings.TrimSpace(c.AccessKey)
|
|
sk := strings.TrimSpace(c.SecretKey)
|
|
st := strings.TrimSpace(c.SessionToken)
|
|
if ak != "" && sk != "" {
|
|
return credentials.NewStaticV4(ak, sk, st)
|
|
}
|
|
}
|
|
return credentials.NewChainCredentials([]credentials.Provider{
|
|
&credentials.EnvMinio{},
|
|
&credentials.EnvAWS{},
|
|
})
|
|
}
|
|
|
|
func ResolveSnapshot(
|
|
ctx context.Context,
|
|
dataDir string,
|
|
cfg *SnapshotStoreConfig,
|
|
snapshotID string,
|
|
wantSHA256 string,
|
|
fetcher SnapshotFetcher,
|
|
) (string, error) {
|
|
dataDir = strings.TrimSpace(dataDir)
|
|
if dataDir == "" {
|
|
return "", fmt.Errorf("missing data_dir")
|
|
}
|
|
snapshotID = strings.TrimSpace(snapshotID)
|
|
if snapshotID == "" {
|
|
return "", fmt.Errorf("missing snapshot_id")
|
|
}
|
|
if err := container.ValidateJobName(snapshotID); err != nil {
|
|
return "", fmt.Errorf("invalid snapshot_id: %w", err)
|
|
}
|
|
want, err := integrity.NormalizeSHA256ChecksumHex(wantSHA256)
|
|
if err != nil || want == "" {
|
|
return "", fmt.Errorf("invalid snapshot_sha256")
|
|
}
|
|
|
|
cacheDir := filepath.Join(dataDir, "snapshots", "sha256", want)
|
|
if info, err := os.Stat(cacheDir); err == nil && info.IsDir() {
|
|
return cacheDir, nil
|
|
}
|
|
|
|
if cfg == nil || !cfg.Enabled {
|
|
return filepath.Join(dataDir, "snapshots", snapshotID), nil
|
|
}
|
|
|
|
bucket := strings.TrimSpace(cfg.Bucket)
|
|
if bucket == "" {
|
|
return "", fmt.Errorf("missing snapshot store bucket")
|
|
}
|
|
prefix := strings.Trim(strings.TrimSpace(cfg.Prefix), "/")
|
|
key := snapshotID + ".tar.gz"
|
|
if prefix != "" {
|
|
key = path.Join(prefix, key)
|
|
}
|
|
|
|
if fetcher == nil {
|
|
mf, err := newMinioSnapshotFetcher(cfg)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
fetcher = mf
|
|
}
|
|
|
|
fetchCtx := ctx
|
|
if cfg.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
fetchCtx, cancel = context.WithTimeout(ctx, cfg.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
rc, err := fetcher.Get(fetchCtx, bucket, key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer func() { _ = rc.Close() }()
|
|
|
|
tmpRoot := filepath.Join(dataDir, "snapshots", ".tmp")
|
|
// Use PathRegistry for consistent directory creation
|
|
paths := config.FromEnv()
|
|
if err := paths.EnsureDir(tmpRoot); err != nil {
|
|
return "", err
|
|
}
|
|
workDir, err := os.MkdirTemp(tmpRoot, "fetchml-snapshot-")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer func() { _ = os.RemoveAll(workDir) }()
|
|
|
|
archivePath := filepath.Join(workDir, "snapshot.tar.gz")
|
|
f, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
_, copyErr := io.Copy(f, rc)
|
|
closeErr := f.Close()
|
|
if copyErr != nil {
|
|
return "", copyErr
|
|
}
|
|
if closeErr != nil {
|
|
return "", closeErr
|
|
}
|
|
|
|
extractDir := filepath.Join(workDir, "extracted")
|
|
if err := paths.EnsureDir(extractDir); err != nil {
|
|
return "", err
|
|
}
|
|
if err := extractTarGz(archivePath, extractDir); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
got, err := integrity.DirOverallSHA256Hex(extractDir)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if got != want {
|
|
return "", fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", want, got)
|
|
}
|
|
|
|
if err := paths.EnsureDir(filepath.Dir(cacheDir)); err != nil {
|
|
return "", err
|
|
}
|
|
if err := os.Rename(extractDir, cacheDir); err != nil {
|
|
if info, statErr := os.Stat(cacheDir); statErr == nil && info.IsDir() {
|
|
return cacheDir, nil
|
|
}
|
|
return "", err
|
|
}
|
|
|
|
return cacheDir, nil
|
|
}
|
|
|
|
func extractTarGz(archivePath, dstDir string) error {
|
|
archivePath = filepath.Clean(archivePath)
|
|
dstDir = filepath.Clean(dstDir)
|
|
|
|
f, err := os.Open(archivePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = f.Close() }()
|
|
|
|
gz, err := gzip.NewReader(f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = gz.Close() }()
|
|
|
|
tr := tar.NewReader(gz)
|
|
for {
|
|
hdr, err := tr.Next()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
name := strings.TrimSpace(hdr.Name)
|
|
name = strings.TrimPrefix(name, "./")
|
|
clean := path.Clean(name)
|
|
if clean == "." {
|
|
continue
|
|
}
|
|
if strings.HasPrefix(clean, "../") || clean == ".." || strings.HasPrefix(clean, "/") {
|
|
return fmt.Errorf("invalid tar entry")
|
|
}
|
|
|
|
target, err := safeJoin(dstDir, filepath.FromSlash(clean))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch hdr.Typeflag {
|
|
case tar.TypeDir:
|
|
if err := os.MkdirAll(target, 0750); err != nil {
|
|
return err
|
|
}
|
|
case tar.TypeReg:
|
|
if err := os.MkdirAll(filepath.Dir(target), 0750); err != nil {
|
|
return err
|
|
}
|
|
mode := hdr.FileInfo().Mode() & 0777
|
|
out, err := fileutil.SecureOpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := io.CopyN(out, tr, hdr.Size); err != nil {
|
|
_ = out.Close()
|
|
return err
|
|
}
|
|
if err := out.Close(); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return fmt.Errorf("unsupported tar entry type")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// safeJoin joins rel onto baseDir and returns the cleaned result. It returns
// an error if the result would lie outside baseDir.
//
// Containment is decided with filepath.Rel rather than a string-prefix test.
// A prefix test of the form baseDir+separator wrongly rejects every child
// in two cases:
//   - A root base: the prefix becomes "//".
//   - A "." base: filepath.Join(".", "a") is "a", which lacks the "./" prefix.
//
// Joining to baseDir itself (e.g. rel ".") is allowed.
func safeJoin(baseDir, rel string) (string, error) {
	base := filepath.Clean(baseDir)
	joined := filepath.Join(base, rel)

	within, err := filepath.Rel(base, joined)
	if err != nil || within == ".." || strings.HasPrefix(within, ".."+string(os.PathSeparator)) {
		return "", fmt.Errorf("invalid relative path")
	}
	return joined, nil
}
|
|
|
|
// ExtractTarGz is an exported wrapper for testing/benchmarking.
|
|
func ExtractTarGz(archivePath, dstDir string) error {
|
|
return extractTarGz(archivePath, dstDir)
|
|
}
|