package jupyter_test

import (
	"log/slog"
	"os"
	"slices"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
)

// blockedPackagesEnv is the environment variable these tests use to override
// the package blacklist returned by jupyter.DefaultEnhancedSecurityConfigFromEnv.
const blockedPackagesEnv = "FETCHML_JUPYTER_BLOCKED_PACKAGES"

// unsetEnv removes key from the process environment for the duration of the
// test and restores its previous value (or its absence) during cleanup. Tests
// that assert on built-in defaults use it so an ambient override cannot make
// them fail spuriously. It mutates process-wide state, so do not call it from
// parallel tests.
func unsetEnv(t *testing.T, key string) {
	t.Helper()

	prev, had := os.LookupEnv(key)
	if err := os.Unsetenv(key); err != nil {
		t.Fatalf("unsetting %s: %v", key, err)
	}
	t.Cleanup(func() {
		// Best-effort restore: the test has already finished, so there is
		// nothing useful to do if the environment cannot be reset.
		if had {
			_ = os.Setenv(key, prev)
		} else {
			_ = os.Unsetenv(key)
		}
	})
}

// requirePackages fails the test unless got holds exactly len(want) entries
// and every name in want appears among them. Order is ignored. source labels
// the failure message (for example "pip" or "conda").
func requirePackages(t *testing.T, source string, got, want []string) {
	t.Helper()

	if len(got) != len(want) {
		t.Fatalf("%s: expected %d packages %v, got %d: %v", source, len(want), want, len(got), got)
	}
	for _, name := range want {
		if !slices.Contains(got, name) {
			t.Fatalf("%s: expected to find package %q in parsed list %v", source, name, got)
		}
	}
}

// TestPackageBlacklistEnforcement checks that aiohttp is on the default
// blacklist when no environment override is present.
func TestPackageBlacklistEnforcement(t *testing.T) {
	unsetEnv(t, blockedPackagesEnv)

	cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
	if !slices.Contains(cfg.BlockedPackages, "aiohttp") {
		t.Fatalf("expected aiohttp to be blocked by default")
	}
}

// TestPackageBlacklistEnvironmentOverride checks that each comma-separated
// name given in FETCHML_JUPYTER_BLOCKED_PACKAGES shows up in BlockedPackages.
func TestPackageBlacklistEnvironmentOverride(t *testing.T) {
	// t.Setenv restores the previous value, or unsets the variable if it was
	// originally absent, when the test finishes. A manual os.Setenv(key, old)
	// cleanup would instead leave it set to "" in the absent case.
	t.Setenv(blockedPackagesEnv, "custom-package,another-package")

	cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()

	if !slices.Contains(cfg.BlockedPackages, "custom-package") {
		t.Fatalf("expected custom-package to be blocked from env")
	}
	if !slices.Contains(cfg.BlockedPackages, "another-package") {
		t.Fatalf("expected another-package to be blocked from env")
	}
}

// TestPackageValidation checks that the security manager rejects a request
// for a default-blacklisted package (aiohttp) and accepts one for numpy.
func TestPackageValidation(t *testing.T) {
	unsetEnv(t, blockedPackagesEnv)

	logger := logging.NewLogger(slog.LevelInfo, false)
	cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
	sm := jupyter.NewSecurityManager(logger, cfg)

	pkgReq := &jupyter.PackageRequest{
		PackageName: "aiohttp",
		RequestedBy: "test-user",
		Channel:     "pypi",
		Version:     "2.28.0",
	}
	if err := sm.ValidatePackageRequest(pkgReq); err == nil {
		t.Fatalf("expected validation to fail for blocked package requests")
	}

	// Reuse the same request, switching only the package and channel;
	// RequestedBy and Version carry over unchanged.
	pkgReq.PackageName = "numpy"
	pkgReq.Channel = "conda-forge"
	if err := sm.ValidatePackageRequest(pkgReq); err != nil {
		t.Fatalf("expected validation to pass for numpy, got %v", err)
	}
}

// TestPackageParsing checks that ParsePipList and ParseCondaList extract the
// package names from `pip`-style (name==version) and `conda`-style
// (name=version=build) listings while ignoring comment lines.
func TestPackageParsing(t *testing.T) {
	expected := []string{"numpy", "pandas", "requests", "scipy"}

	pipOutput := "numpy==1.24.0\n" +
		"pandas==2.0.0\n" +
		"requests==2.28.0\n" +
		"# Some comment\n" +
		"scipy==1.10.0"
	requirePackages(t, "pip", jupyter.ParsePipList(pipOutput), expected)

	condaOutput := "numpy=1.24.0=py39h8ecf13d_0\n" +
		"pandas=2.0.0=py39h8ecf13d_0\n" +
		"requests=2.28.0=py39h8ecf13d_0\n" +
		"# Some comment\n" +
		"scipy=1.10.0=py39h8ecf13d_0"
	requirePackages(t, "conda", jupyter.ParseCondaList(condaOutput), expected)
}