fetch_ml/tests/unit/worker/snapshot_store_test.go

139 lines
3.7 KiB
Go

package worker_test
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"io"
"os"
"path/filepath"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// memFetcher is an in-memory test double handed to worker.ResolveSnapshot in
// place of a real snapshot fetcher. Each Get call is counted; when err is set
// it is returned unchanged, otherwise a reader over data is served.
type memFetcher struct {
	calls int    // number of Get invocations seen so far
	data  []byte // payload returned when err is nil
	err   error  // when non-nil, returned from every Get instead of data
}

// Get records the invocation and returns either the configured error (with a
// nil reader) or a no-op-closing reader over data. The context and both
// string arguments are ignored.
func (m *memFetcher) Get(_ context.Context, _, _ string) (io.ReadCloser, error) {
	m.calls++
	if fetchErr := m.err; fetchErr != nil {
		return nil, fetchErr
	}
	payload := bytes.NewReader(m.data)
	return io.NopCloser(payload), nil
}
// makeTarGz builds an in-memory gzip-compressed tar archive containing one
// entry (mode 0644) per key of files, with the mapped bytes as its contents.
// Any write error aborts the calling test via t.Fatalf. Entries are written in
// Go map iteration order, so their order inside the archive is not fixed.
func makeTarGz(t *testing.T, files map[string][]byte) []byte {
	t.Helper()

	var out bytes.Buffer
	zw := gzip.NewWriter(&out)
	aw := tar.NewWriter(zw)

	for entryName, content := range files {
		hdr := tar.Header{Name: entryName, Mode: 0644, Size: int64(len(content))}
		if err := aw.WriteHeader(&hdr); err != nil {
			t.Fatalf("tar header: %v", err)
		}
		if _, err := aw.Write(content); err != nil {
			t.Fatalf("tar write: %v", err)
		}
	}

	// Close the tar writer first so its trailer is written into the gzip
	// stream before the gzip footer is emitted.
	if err := aw.Close(); err != nil {
		t.Fatalf("tar close: %v", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatalf("gz close: %v", err)
	}
	return out.Bytes()
}
// TestResolveSnapshot_CacheHit checks that when the cache directory
// <dataDir>/snapshots/sha256/<sha> already exists, ResolveSnapshot returns
// that directory and never calls the fetcher.
//
// want is a synthetic digest that does not match the contents of file.txt,
// so a passing test also shows that a cache hit is accepted without
// re-hashing the cached directory.
func TestResolveSnapshot_CacheHit(t *testing.T) {
	dataDir := t.TempDir()
	want := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
	cacheDir := filepath.Join(dataDir, "snapshots", "sha256", want)
	if err := os.MkdirAll(cacheDir, 0750); err != nil {
		t.Fatalf("mkdir: %v", err)
	}
	if err := os.WriteFile(filepath.Join(cacheDir, "file.txt"), []byte("ok"), 0600); err != nil {
		t.Fatalf("write: %v", err)
	}

	// The fetcher fails if it is consulted, so a cache miss cannot pass.
	f := &memFetcher{err: io.EOF}
	cfg := &worker.SnapshotStoreConfig{Enabled: true, Endpoint: "minio:9000", Bucket: "b", Secure: false}
	p, err := worker.ResolveSnapshot(context.Background(), dataDir, cfg, "snap-1", want, f)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if p != cacheDir {
		t.Fatalf("unexpected path: got %q, want %q", p, cacheDir)
	}
	if f.calls != 0 {
		t.Fatalf("expected no fetcher calls, got %d", f.calls)
	}
}
// TestResolveSnapshot_DownloadAndVerify covers the cache-miss path. The
// archive is fetched exactly once and extracted. It is accepted because its
// contents hash (via worker.DirOverallSHA256Hex) to the digest of a reference
// directory holding the same file.
func TestResolveSnapshot_DownloadAndVerify(t *testing.T) {
	const (
		fileName    = "file.txt"
		fileContent = "ok"
	)
	dataDir, refDir := t.TempDir(), t.TempDir()

	// The reference directory defines the checksum the download must match.
	if err := os.WriteFile(filepath.Join(refDir, fileName), []byte(fileContent), 0600); err != nil {
		t.Fatalf("write: %v", err)
	}
	want, err := worker.DirOverallSHA256Hex(refDir)
	if err != nil {
		t.Fatalf("hash: %v", err)
	}

	fetcher := &memFetcher{data: makeTarGz(t, map[string][]byte{fileName: []byte(fileContent)})}
	cfg := &worker.SnapshotStoreConfig{Enabled: true, Endpoint: "minio:9000", Bucket: "b", Secure: false}

	snapDir, err := worker.ResolveSnapshot(context.Background(), dataDir, cfg, "snap-1", want, fetcher)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := fetcher.calls; got != 1 {
		t.Fatalf("expected one fetch call, got %d", got)
	}

	extracted, err := os.ReadFile(filepath.Join(snapDir, fileName))
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if got := string(extracted); got != fileContent {
		t.Fatalf("unexpected contents: %q", got)
	}
}
// TestResolveSnapshot_ChecksumMismatch checks that ResolveSnapshot fails when
// the downloaded archive's contents do not hash to the requested digest.
func TestResolveSnapshot_ChecksumMismatch(t *testing.T) {
	dataDir := t.TempDir()
	// Synthetic digest that does not match the archive built below.
	want := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
	tarBytes := makeTarGz(t, map[string][]byte{"file.txt": []byte("ok")})
	f := &memFetcher{data: tarBytes}
	cfg := &worker.SnapshotStoreConfig{Enabled: true, Endpoint: "minio:9000", Bucket: "b", Secure: false}
	if _, err := worker.ResolveSnapshot(context.Background(), dataDir, cfg, "snap-1", want, f); err == nil {
		t.Fatalf("expected error")
	}
	// Any early failure (for example, rejecting the config before fetching)
	// would also satisfy the check above. Requiring exactly one fetch shows
	// the error arose after the archive was actually downloaded.
	if f.calls != 1 {
		t.Fatalf("expected one fetch call, got %d", f.calls)
	}
}
func TestResolveSnapshot_RejectsTraversal(t *testing.T) {
dataDir := t.TempDir()
want := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
tarBytes := makeTarGz(t, map[string][]byte{"../evil": []byte("no")})
f := &memFetcher{data: tarBytes}
cfg := &worker.SnapshotStoreConfig{Enabled: true, Endpoint: "minio:9000", Bucket: "b", Secure: false}
if _, err := worker.ResolveSnapshot(context.Background(), dataDir, cfg, "snap-1", want, f); err == nil {
t.Fatalf("expected error")
}
}