- Update Makefile with native build targets (preparing for C++) - Add profiler and performance regression detector commands - Update CI/testing scripts - Add additional unit tests for API, jupyter, queue, manifest
129 lines
2.9 KiB
Go
129 lines
2.9 KiB
Go
package jupyter_test
|
|
|
|
import (
|
|
"log/slog"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
func TestPackageBlacklistEnforcement(t *testing.T) {
|
|
cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
|
|
|
|
blocked := cfg.BlockedPackages
|
|
foundAiohttp := false
|
|
|
|
for _, pkg := range blocked {
|
|
if pkg == "aiohttp" {
|
|
foundAiohttp = true
|
|
}
|
|
}
|
|
|
|
if !foundAiohttp {
|
|
t.Fatalf("expected aiohttp to be blocked by default")
|
|
}
|
|
}
|
|
|
|
func TestPackageBlacklistEnvironmentOverride(t *testing.T) {
|
|
old := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
|
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "custom-package,another-package")
|
|
t.Cleanup(func() {
|
|
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", old)
|
|
})
|
|
|
|
cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
|
|
|
|
foundCustom := false
|
|
foundAnother := false
|
|
for _, pkg := range cfg.BlockedPackages {
|
|
if pkg == "custom-package" {
|
|
foundCustom = true
|
|
}
|
|
if pkg == "another-package" {
|
|
foundAnother = true
|
|
}
|
|
}
|
|
|
|
if !foundCustom {
|
|
t.Fatalf("expected custom-package to be blocked from env")
|
|
}
|
|
if !foundAnother {
|
|
t.Fatalf("expected another-package to be blocked from env")
|
|
}
|
|
}
|
|
|
|
func TestPackageValidation(t *testing.T) {
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
cfg := jupyter.DefaultEnhancedSecurityConfigFromEnv()
|
|
sm := jupyter.NewSecurityManager(logger, cfg)
|
|
|
|
pkgReq := &jupyter.PackageRequest{
|
|
PackageName: "aiohttp",
|
|
RequestedBy: "test-user",
|
|
Channel: "pypi",
|
|
Version: "2.28.0",
|
|
}
|
|
|
|
err := sm.ValidatePackageRequest(pkgReq)
|
|
if err == nil {
|
|
t.Fatalf("expected validation to fail for blocked package requests")
|
|
}
|
|
|
|
pkgReq.PackageName = "numpy"
|
|
pkgReq.Channel = "conda-forge"
|
|
err = sm.ValidatePackageRequest(pkgReq)
|
|
if err != nil {
|
|
t.Fatalf("expected validation to pass for numpy, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestPackageParsing(t *testing.T) {
|
|
pipOutput := `numpy==1.24.0
|
|
pandas==2.0.0
|
|
requests==2.28.0
|
|
# Some comment
|
|
scipy==1.10.0`
|
|
|
|
packages := jupyter.ParsePipList(pipOutput)
|
|
expected := []string{"numpy", "pandas", "requests", "scipy"}
|
|
if len(packages) != len(expected) {
|
|
t.Fatalf("expected %d packages, got %d", len(expected), len(packages))
|
|
}
|
|
for _, exp := range expected {
|
|
found := false
|
|
for _, got := range packages {
|
|
if got == exp {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatalf("expected to find package %q in parsed list", exp)
|
|
}
|
|
}
|
|
|
|
condaOutput := `numpy=1.24.0=py39h8ecf13d_0
|
|
pandas=2.0.0=py39h8ecf13d_0
|
|
requests=2.28.0=py39h8ecf13d_0
|
|
# Some comment
|
|
scipy=1.10.0=py39h8ecf13d_0`
|
|
|
|
packages = jupyter.ParseCondaList(condaOutput)
|
|
if len(packages) != len(expected) {
|
|
t.Fatalf("expected %d packages, got %d", len(expected), len(packages))
|
|
}
|
|
for _, exp := range expected {
|
|
found := false
|
|
for _, got := range packages {
|
|
if got == exp {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatalf("expected to find package %q in parsed conda list", exp)
|
|
}
|
|
}
|
|
}
|