fetch_ml/tests/unit/jupyter/package_blacklist_test.go

143 lines
3.2 KiB
Go

package jupyter_test
import (
"log/slog"
"os"
"testing"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// TestPackageBlacklistEnforcement verifies that the config returned by
// DefaultEnhancedSecurityConfigFromEnv lists requests, urllib3 and httpx in
// its BlockedPackages.
func TestPackageBlacklistEnforcement(t *testing.T) {
	var hasRequests, hasURLLib3, hasHTTPX bool
	for _, name := range jupyter.DefaultEnhancedSecurityConfigFromEnv().BlockedPackages {
		switch name {
		case "requests":
			hasRequests = true
		case "urllib3":
			hasURLLib3 = true
		case "httpx":
			hasHTTPX = true
		}
	}

	// Fatalf stops the test, so only the first missing package (in this
	// order) is reported.
	switch {
	case !hasRequests:
		t.Fatalf("expected requests to be blocked by default")
	case !hasURLLib3:
		t.Fatalf("expected urllib3 to be blocked by default")
	case !hasHTTPX:
		t.Fatalf("expected httpx to be blocked by default")
	}
}
// TestPackageBlacklistEnvironmentOverride verifies that package names supplied
// through the comma-separated FETCHML_JUPYTER_BLOCKED_PACKAGES environment
// variable appear in the BlockedPackages of the config returned by
// DefaultEnhancedSecurityConfigFromEnv.
//
// It mutates the process environment, so it must not be run with t.Parallel.
func TestPackageBlacklistEnvironmentOverride(t *testing.T) {
	const envKey = "FETCHML_JUPYTER_BLOCKED_PACKAGES"

	// Record whether the variable existed at all: restoring an originally
	// unset variable with Setenv(key, "") would leave it defined (as empty)
	// for every later test in this process.
	old, hadOld := os.LookupEnv(envKey)
	if err := os.Setenv(envKey, "custom-package,another-package"); err != nil {
		t.Fatalf("setting %s: %v", envKey, err)
	}
	t.Cleanup(func() {
		var err error
		if hadOld {
			err = os.Setenv(envKey, old)
		} else {
			err = os.Unsetenv(envKey)
		}
		if err != nil {
			t.Errorf("restoring %s: %v", envKey, err)
		}
	})

	cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
	var foundCustom, foundAnother bool
	for _, pkg := range cfg.BlockedPackages {
		switch pkg {
		case "custom-package":
			foundCustom = true
		case "another-package":
			foundAnother = true
		}
	}
	if !foundCustom {
		t.Fatalf("expected custom-package to be blocked from env")
	}
	if !foundAnother {
		t.Fatalf("expected another-package to be blocked from env")
	}
}
func TestPackageValidation(t *testing.T) {
logger := logging.NewLogger(slog.LevelInfo, false)
cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
sm := jupyter.NewSecurityManager(logger, cfg)
pkgReq := &jupyter.PackageRequest{
PackageName: "requests",
RequestedBy: "test-user",
Channel: "pypi",
Version: "2.28.0",
}
err := sm.ValidatePackageRequest(pkgReq)
if err == nil {
t.Fatalf("expected validation to fail for blocked package requests")
}
pkgReq.PackageName = "numpy"
pkgReq.Channel = "conda-forge"
err = sm.ValidatePackageRequest(pkgReq)
if err != nil {
t.Fatalf("expected validation to pass for numpy, got %v", err)
}
}
// TestPackageParsing checks that ParsePipList handles `name==version` lines and
// ParseCondaList handles `name=version=build` lines. Both must return exactly
// the bare package names and ignore `#` comment lines.
func TestPackageParsing(t *testing.T) {
	expected := []string{"numpy", "pandas", "requests", "scipy"}

	// checkParsed reports, labelled with the parser that produced it, any
	// mismatch between got and expected. It uses Errorf so that both parsers
	// and every missing name are reported in a single run.
	checkParsed := func(source string, got []string) {
		t.Helper()
		if len(got) != len(expected) {
			t.Errorf("%s: expected %d packages, got %d: %q", source, len(expected), len(got), got)
		}
		for _, exp := range expected {
			found := false
			for _, name := range got {
				if name == exp {
					found = true
					break
				}
			}
			if !found {
				t.Errorf("%s: expected to find package %q in parsed list %q", source, exp, got)
			}
		}
	}

	pipOutput := `numpy==1.24.0
pandas==2.0.0
requests==2.28.0
# Some comment
scipy==1.10.0`
	checkParsed("pip", jupyter.ParsePipList(pipOutput))

	condaOutput := `numpy=1.24.0=py39h8ecf13d_0
pandas=2.0.0=py39h8ecf13d_0
requests=2.28.0=py39h8ecf13d_0
# Some comment
scipy=1.10.0=py39h8ecf13d_0`
	checkParsed("conda", jupyter.ParseCondaList(condaOutput))
}