diff --git a/internal/worker/artifacts.go b/internal/worker/artifacts.go index adf005f..355466c 100644 --- a/internal/worker/artifacts.go +++ b/internal/worker/artifacts.go @@ -12,7 +12,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/manifest" ) -func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error) { +func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) { runDir = strings.TrimSpace(runDir) if runDir == "" { return nil, fmt.Errorf("run dir is empty") @@ -27,6 +27,7 @@ func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error) var files []manifest.ArtifactFile var total int64 + var fileCount int now := time.Now().UTC() @@ -92,12 +93,22 @@ func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error) return err } + // Check artifact caps before adding + fileCount++ + if caps != nil && caps.MaxArtifactFiles > 0 && fileCount > caps.MaxArtifactFiles { + return fmt.Errorf("artifact file count cap exceeded: %d files (max %d)", fileCount, caps.MaxArtifactFiles) + } + + total += info.Size() + if caps != nil && caps.MaxArtifactTotalBytes > 0 && total > caps.MaxArtifactTotalBytes { + return fmt.Errorf("artifact total size cap exceeded: %d bytes (max %d)", total, caps.MaxArtifactTotalBytes) + } + files = append(files, manifest.ArtifactFile{ Path: rel, SizeBytes: info.Size(), Modified: info.ModTime().UTC(), }) - total += info.Size() return nil }) if err != nil { @@ -119,6 +130,6 @@ const manifestFilename = "run_manifest.json" // ScanArtifacts is an exported wrapper for testing/benchmarking. // When includeAll is false, excludes code/, snapshot/, *.log files, and symlinks. -func ScanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error) { - return scanArtifacts(runDir, includeAll) +func ScanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) { + return scanArtifacts(runDir, includeAll, caps) } diff --git a/internal/worker/native_bridge_libs.go b/internal/worker/native_bridge_libs.go index 0fac19e..bc8d951 100644 --- a/internal/worker/native_bridge_libs.go +++ b/internal/worker/native_bridge_libs.go @@ -67,7 +67,7 @@ func HasSIMDSHA256() bool { } func ScanArtifactsNative(runDir string) (*manifest.Artifacts, error) { - return ScanArtifacts(runDir, false) + return ScanArtifacts(runDir, false, nil) } func ExtractTarGzNative(archivePath, dstDir string) error { diff --git a/tests/benchmarks/artifact_and_snapshot_bench_test.go b/tests/benchmarks/artifact_and_snapshot_bench_test.go index 50caee7..2115ccd 100644 --- a/tests/benchmarks/artifact_and_snapshot_bench_test.go +++ b/tests/benchmarks/artifact_and_snapshot_bench_test.go @@ -131,7 +131,7 @@ func BenchmarkScanArtifacts(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - _, err := worker.ScanArtifacts(runDir, false) + _, err := worker.ScanArtifacts(runDir, false, nil) if err != nil { b.Fatal(err) } diff --git a/tests/benchmarks/artifact_scanner_bench_test.go b/tests/benchmarks/artifact_scanner_bench_test.go index 2b31b95..9e3d66a 100644 --- a/tests/benchmarks/artifact_scanner_bench_test.go +++ b/tests/benchmarks/artifact_scanner_bench_test.go @@ -19,7 +19,7 @@ func BenchmarkArtifactScanGo(b *testing.B) { b.ReportAllocs() for b.Loop() { - _, err := worker.ScanArtifacts(tmpDir, false) + _, err := worker.ScanArtifacts(tmpDir, false, nil) if err != nil { b.Fatal(err) } @@ -57,7 +57,7 @@ func BenchmarkArtifactScanLarge(b *testing.B) { b.Run("Go", func(b *testing.B) { b.ReportAllocs() for b.Loop() { - _, err := worker.ScanArtifacts(tmpDir, false) + _, err := worker.ScanArtifacts(tmpDir, false, nil) if err != nil { b.Fatal(err) } diff --git a/tests/unit/worker/artifacts_test.go b/tests/unit/worker/artifacts_test.go index a99df85..cf67960 100644 --- a/tests/unit/worker/artifacts_test.go +++ b/tests/unit/worker/artifacts_test.go @@ -30,7 +30,7 @@ func TestScanArtifacts_SkipsKnownPathsAndLogs(t *testing.T) { mustWrite("checkpoints/best.pt", []byte("checkpoint")) mustWrite("plots/loss.png", []byte("png")) - art, err := worker.ScanArtifacts(runDir, false) + art, err := worker.ScanArtifacts(runDir, false, nil) if err != nil { t.Fatalf("scanArtifacts: %v", err) }