package worker

import (
	"archive/tar"
	"compress/gzip"
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path"
	"path/filepath"
	"strings"

	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
)

// SnapshotFetcher is an interface for fetching snapshots.
//
// Get returns a stream over the object stored under key in bucket. The caller
// is responsible for closing the returned reader.
type SnapshotFetcher interface {
	Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
}

// Compile-time check that the MinIO-backed fetcher satisfies SnapshotFetcher.
var _ SnapshotFetcher = (*minioSnapshotFetcher)(nil)

// minioSnapshotFetcher is the SnapshotFetcher implementation backed by a
// MinIO client.
type minioSnapshotFetcher struct {
	client *minio.Client
}

// Get opens the object bucket/key through the MinIO client using default
// GetObjectOptions and returns it as an io.ReadCloser.
func (f *minioSnapshotFetcher) Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) {
	obj, err := f.client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
	if err != nil {
		return nil, err
	}
	return obj, nil
}

// newMinioSnapshotFetcher builds a MinIO-backed fetcher from cfg.
//
// It fails when cfg is nil or when the (whitespace-trimmed) endpoint or bucket
// is empty. The bucket is only validated here; it is not stored on the fetcher
// because Get receives the bucket on every call.
func newMinioSnapshotFetcher(cfg *SnapshotStoreConfig) (*minioSnapshotFetcher, error) {
	if cfg == nil {
		return nil, fmt.Errorf("missing snapshot store config")
	}
	endpoint := strings.TrimSpace(cfg.Endpoint)
	if endpoint == "" {
		return nil, fmt.Errorf("missing snapshot store endpoint")
	}
	if strings.TrimSpace(cfg.Bucket) == "" {
		return nil, fmt.Errorf("missing snapshot store bucket")
	}

	client, err := minio.New(endpoint, &minio.Options{
		Creds:      cfg.credentials(),
		Secure:     cfg.Secure,
		Region:     strings.TrimSpace(cfg.Region),
		MaxRetries: cfg.MaxRetries,
	})
	if err != nil {
		return nil, fmt.Errorf("creating snapshot store client: %w", err)
	}
	return &minioSnapshotFetcher{client: client}, nil
}

// credentials selects the credentials used to talk to the snapshot store.
//
// When both the access key and the secret key are non-empty after trimming,
// static V4 credentials are returned (with the optional session token).
// Otherwise, including for a nil receiver, a chain of the MinIO and AWS
// environment providers is returned.
func (c *SnapshotStoreConfig) credentials() *credentials.Credentials {
	if c != nil {
		ak := strings.TrimSpace(c.AccessKey)
		sk := strings.TrimSpace(c.SecretKey)
		st := strings.TrimSpace(c.SessionToken)
		if ak != "" && sk != "" {
			return credentials.NewStaticV4(ak, sk, st)
		}
	}
	return credentials.NewChainCredentials([]credentials.Provider{
		&credentials.EnvMinio{},
		&credentials.EnvAWS{},
	})
}

// ResolveSnapshot returns the local directory holding the snapshot identified
// by snapshotID, downloading and verifying it first when necessary.
//
// Resolution order:
//  1. dataDir, snapshotID and wantSHA256 are trimmed/normalized and validated;
//     any invalid value yields an error.
//  2. If <dataDir>/snapshots/sha256/<wantSHA256> already exists as a directory
//     it is returned as-is.
//  3. If cfg is nil or not enabled, <dataDir>/snapshots/<snapshotID> is
//     returned without checking that it exists.
//  4. Otherwise the object "<prefix>/<snapshotID>.tar.gz" is fetched from
//     cfg.Bucket (through fetcher, or a MinIO fetcher built from cfg when
//     fetcher is nil), extracted in a temporary work directory under
//     <dataDir>/snapshots/.tmp, checked against wantSHA256 with
//     integrity.DirOverallSHA256Hex, and finally renamed into the cache
//     directory from step 2.
//
// When cfg.Timeout is positive it bounds the fetch context. The temporary work
// directory is always removed before returning.
func ResolveSnapshot(
	ctx context.Context,
	dataDir string,
	cfg *SnapshotStoreConfig,
	snapshotID string,
	wantSHA256 string,
	fetcher SnapshotFetcher,
) (string, error) {
	dataDir = strings.TrimSpace(dataDir)
	if dataDir == "" {
		return "", fmt.Errorf("missing data_dir")
	}
	snapshotID = strings.TrimSpace(snapshotID)
	if snapshotID == "" {
		return "", fmt.Errorf("missing snapshot_id")
	}
	if err := container.ValidateJobName(snapshotID); err != nil {
		return "", fmt.Errorf("invalid snapshot_id: %w", err)
	}
	want, err := integrity.NormalizeSHA256ChecksumHex(wantSHA256)
	if err != nil || want == "" {
		return "", fmt.Errorf("invalid snapshot_sha256")
	}

	// Content-addressed cache hit: nothing to download.
	cacheDir := filepath.Join(dataDir, "snapshots", "sha256", want)
	if info, err := os.Stat(cacheDir); err == nil && info.IsDir() {
		return cacheDir, nil
	}

	// Without an enabled store, fall back to the by-ID snapshot location.
	if cfg == nil || !cfg.Enabled {
		return filepath.Join(dataDir, "snapshots", snapshotID), nil
	}

	bucket := strings.TrimSpace(cfg.Bucket)
	if bucket == "" {
		return "", fmt.Errorf("missing snapshot store bucket")
	}
	// Object keys are slash-separated, hence path (not filepath).
	key := snapshotID + ".tar.gz"
	if prefix := strings.Trim(strings.TrimSpace(cfg.Prefix), "/"); prefix != "" {
		key = path.Join(prefix, key)
	}

	if fetcher == nil {
		mf, err := newMinioSnapshotFetcher(cfg)
		if err != nil {
			return "", err
		}
		fetcher = mf
	}

	fetchCtx := ctx
	if cfg.Timeout > 0 {
		var cancel context.CancelFunc
		fetchCtx, cancel = context.WithTimeout(ctx, cfg.Timeout)
		defer cancel()
	}

	rc, err := fetcher.Get(fetchCtx, bucket, key)
	if err != nil {
		return "", fmt.Errorf("fetching snapshot %q: %w", key, err)
	}
	// Best-effort close: the payload has been fully copied (or has already
	// failed) by the time this runs.
	defer func() { _ = rc.Close() }()

	tmpRoot := filepath.Join(dataDir, "snapshots", ".tmp")
	// Use PathRegistry for consistent directory creation
	paths := config.FromEnv()
	if err := paths.EnsureDir(tmpRoot); err != nil {
		return "", err
	}
	workDir, err := os.MkdirTemp(tmpRoot, "fetchml-snapshot-")
	if err != nil {
		return "", err
	}
	// Best-effort cleanup; after a successful rename only the archive is left.
	defer func() { _ = os.RemoveAll(workDir) }()

	archivePath := filepath.Join(workDir, "snapshot.tar.gz")
	f, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return "", err
	}
	// Close unconditionally, then report the copy error first since it is the
	// more specific failure.
	_, copyErr := io.Copy(f, rc)
	closeErr := f.Close()
	if copyErr != nil {
		return "", fmt.Errorf("downloading snapshot %q: %w", key, copyErr)
	}
	if closeErr != nil {
		return "", fmt.Errorf("downloading snapshot %q: %w", key, closeErr)
	}

	extractDir := filepath.Join(workDir, "extracted")
	if err := paths.EnsureDir(extractDir); err != nil {
		return "", err
	}
	if err := extractTarGz(archivePath, extractDir); err != nil {
		return "", fmt.Errorf("extracting snapshot %q: %w", key, err)
	}

	got, err := integrity.DirOverallSHA256Hex(extractDir)
	if err != nil {
		return "", fmt.Errorf("hashing snapshot %q: %w", key, err)
	}
	if got != want {
		return "", fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", want, got)
	}

	if err := paths.EnsureDir(filepath.Dir(cacheDir)); err != nil {
		return "", err
	}
	if err := os.Rename(extractDir, cacheDir); err != nil {
		// If the cache directory exists anyway (for instance because a
		// concurrent call populated it first), treat that as success.
		if info, statErr := os.Stat(cacheDir); statErr == nil && info.IsDir() {
			return cacheDir, nil
		}
		return "", err
	}
	return cacheDir, nil
}

// extractTarGz unpacks the gzip-compressed tar archive at archivePath into
// dstDir.
//
// Only directory and regular-file entries are supported; any other entry type
// aborts the extraction with an error. Entry names are trimmed, stripped of a
// leading "./" and cleaned; entries that clean to "." are skipped, and names
// that are absolute or climb out of the archive root are rejected. Directories
// are created with mode 0750 and files with the permission bits recorded in
// the archive.
//
// NOTE(review): there is no cap on the number of entries or on the total
// number of extracted bytes.
func extractTarGz(archivePath, dstDir string) error {
	archivePath = filepath.Clean(archivePath)
	dstDir = filepath.Clean(dstDir)

	f, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	defer func() { _ = f.Close() }()

	gz, err := gzip.NewReader(f)
	if err != nil {
		return fmt.Errorf("reading gzip stream: %w", err)
	}
	defer func() { _ = gz.Close() }()

	tr := tar.NewReader(gz)
	for {
		hdr, err := tr.Next()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading tar entry: %w", err)
		}

		// Tar names are slash-separated, so normalize with path before
		// converting to an OS path.
		name := strings.TrimSpace(hdr.Name)
		name = strings.TrimPrefix(name, "./")
		clean := path.Clean(name)
		if clean == "." {
			continue
		}
		if strings.HasPrefix(clean, "../") || clean == ".." || strings.HasPrefix(clean, "/") {
			return fmt.Errorf("invalid tar entry %q", hdr.Name)
		}

		target, err := safeJoin(dstDir, filepath.FromSlash(clean))
		if err != nil {
			return err
		}

		switch hdr.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(target, 0750); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(target), 0750); err != nil {
				return err
			}
			mode := hdr.FileInfo().Mode().Perm()
			out, err := fileutil.SecureOpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
			if err != nil {
				return err
			}
			// Copy exactly the size declared in the header; a short entry
			// surfaces as an error instead of a silently truncated file.
			if _, err := io.CopyN(out, tr, hdr.Size); err != nil {
				_ = out.Close()
				return fmt.Errorf("extracting %q: %w", hdr.Name, err)
			}
			if err := out.Close(); err != nil {
				return err
			}
		default:
			return fmt.Errorf("unsupported tar entry type %q for %q", hdr.Typeflag, hdr.Name)
		}
	}
	return nil
}

// safeJoin joins rel onto baseDir and returns the cleaned result, failing with
// an error when that result is neither baseDir itself nor located beneath it.
func safeJoin(baseDir, rel string) (string, error) {
	baseDir = filepath.Clean(baseDir)
	joined := filepath.Join(baseDir, rel) // Join already cleans its result.

	// Compare against baseDir plus a separator so that a sibling such as
	// "/base-other" is not mistaken for a child of "/base". A cleaned root
	// directory already ends with the separator and must not get a second one,
	// otherwise no path would ever match.
	sep := string(os.PathSeparator)
	basePrefix := baseDir
	if !strings.HasSuffix(basePrefix, sep) {
		basePrefix += sep
	}
	if joined != baseDir && !strings.HasPrefix(joined, basePrefix) {
		return "", fmt.Errorf("invalid relative path")
	}
	return joined, nil
}

// ExtractTarGz is an exported wrapper for testing/benchmarking.
func ExtractTarGz(archivePath, dstDir string) error {
	return extractTarGz(archivePath, dstDir)
}