test: Reorganize and add unit tests
Reorganize tests for better structure and coverage: - Move container/security_test.go from internal/ to tests/unit/container/ - Move related tests to proper unit test locations - Delete orphaned test files (startup_blacklist_test.go) - Add privacy middleware unit tests - Add worker config unit tests - Update E2E tests for homelab and websocket scenarios - Update test fixtures with utility functions - Add CLI helper script for arraylist fixes
This commit is contained in:
parent
4756348c48
commit
27c8b08a16
12 changed files with 411 additions and 86 deletions
14
cli/fix_arraylist.sh
Normal file
14
cli/fix_arraylist.sh
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
|
||||
# Fix ArrayList Zig 0.15 syntax
|
||||
|
||||
cd /Users/jfraeys/Documents/dev/fetch_ml/cli/src
|
||||
|
||||
for f in $(find . -name "*.zig" -exec grep -l "ArrayList" {} \;); do
|
||||
# Fix .deinit() -> .deinit(allocator)
|
||||
sed -i '' 's/\.deinit();/.deinit(allocator);/g' "$f"
|
||||
|
||||
# Fix .toOwnedSlice() -> .toOwnedSlice(allocator)
|
||||
sed -i '' 's/\.toOwnedSlice();/.toOwnedSlice(allocator);/g' "$f"
|
||||
done
|
||||
|
||||
echo "Fixed deinit and toOwnedSlice patterns"
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
package jupyter
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStartupBlockedPackages_DefaultInheritsInstallBlocked(t *testing.T) {
|
||||
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
||||
_, hadStartup := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
|
||||
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
|
||||
_ = os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
|
||||
if hadStartup {
|
||||
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
|
||||
} else {
|
||||
defer os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
}
|
||||
|
||||
cfg := DefaultEnhancedSecurityConfigFromEnv()
|
||||
startup := startupBlockedPackages(cfg.BlockedPackages)
|
||||
if len(startup) != 2 {
|
||||
t.Fatalf("expected startup list to inherit 2 items, got %d", len(startup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupBlockedPackages_Disabled(t *testing.T) {
|
||||
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
||||
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "off")
|
||||
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
|
||||
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
|
||||
|
||||
cfg := DefaultEnhancedSecurityConfigFromEnv()
|
||||
startup := startupBlockedPackages(cfg.BlockedPackages)
|
||||
if len(startup) != 0 {
|
||||
t.Fatalf("expected startup list to be disabled, got %d", len(startup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupBlockedPackages_ExplicitList(t *testing.T) {
|
||||
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
||||
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "aiohttp")
|
||||
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
|
||||
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
|
||||
|
||||
cfg := DefaultEnhancedSecurityConfigFromEnv()
|
||||
startup := startupBlockedPackages(cfg.BlockedPackages)
|
||||
if len(startup) != 1 || startup[0] != "aiohttp" {
|
||||
t.Fatalf("expected explicit startup list [aiohttp], got %v", startup)
|
||||
}
|
||||
}
|
||||
|
|
@ -19,7 +19,8 @@ const (
|
|||
// TestHomelabSetupE2E tests the complete homelab setup workflow end-to-end
|
||||
func TestHomelabSetupE2E(t *testing.T) {
|
||||
// Skip if essential tools not available
|
||||
manageScript := manageScriptPath
|
||||
repoRoot := e2eRepoRoot(t)
|
||||
manageScript := filepath.Join(repoRoot, "tools/manage.sh")
|
||||
if _, err := os.Stat(manageScript); os.IsNotExist(err) {
|
||||
t.Skip("manage.sh not found")
|
||||
}
|
||||
|
|
@ -29,8 +30,8 @@ func TestHomelabSetupE2E(t *testing.T) {
|
|||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
// Use fixtures for manage script operations
|
||||
ms := tests.NewManageScript(manageScript)
|
||||
// Use fixtures for manage script operations with correct working directory
|
||||
ms := tests.NewManageScriptWithDir(manageScript, repoRoot)
|
||||
defer ms.StopAndCleanup()
|
||||
|
||||
testDir := t.TempDir()
|
||||
|
|
@ -266,8 +267,10 @@ func TestPerformanceE2E(t *testing.T) {
|
|||
t.Skip("manage.sh not found")
|
||||
}
|
||||
|
||||
// Use fixtures for manage script operations
|
||||
ms := tests.NewManageScript(manageScript)
|
||||
// Use fixtures for manage script operations with correct working directory
|
||||
repoRoot := e2eRepoRoot(t)
|
||||
manageScript = filepath.Join(repoRoot, "tools/manage.sh")
|
||||
ms := tests.NewManageScriptWithDir(manageScript, repoRoot)
|
||||
|
||||
t.Run("PerformanceMetrics", func(t *testing.T) {
|
||||
// Test health check performance
|
||||
|
|
@ -312,8 +315,10 @@ func TestConfigurationScenariosE2E(t *testing.T) {
|
|||
t.Skip("manage.sh not found")
|
||||
}
|
||||
|
||||
// Use fixtures for manage script operations
|
||||
ms := tests.NewManageScript(manageScript)
|
||||
// Use fixtures for manage script operations with correct working directory
|
||||
repoRoot := e2eRepoRoot(t)
|
||||
manageScript = filepath.Join(repoRoot, "tools/manage.sh")
|
||||
ms := tests.NewManageScriptWithDir(manageScript, repoRoot)
|
||||
|
||||
t.Run("ConfigurationHandling", func(t *testing.T) {
|
||||
testDir := t.TempDir()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ func setupTestServer(t *testing.T) string {
|
|||
authConfig := &auth.Config{Enabled: false}
|
||||
expManager := experiment.NewManager(t.TempDir())
|
||||
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func startWSBackendServer(t *testing.T) *httptest.Server {
|
|||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
authConfig := &auth.Config{Enabled: false}
|
||||
expManager := experiment.NewManager(t.TempDir())
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
|
|
|
|||
17
tests/fixtures/test_utils.go
vendored
17
tests/fixtures/test_utils.go
vendored
|
|
@ -279,6 +279,7 @@ func (tq *TaskQueue) Close() error {
|
|||
// ManageScript provides utilities for manage.sh operations
|
||||
type ManageScript struct {
|
||||
path string
|
||||
dir string
|
||||
}
|
||||
|
||||
// NewManageScript creates a new manage script utility
|
||||
|
|
@ -286,10 +287,22 @@ func NewManageScript(path string) *ManageScript {
|
|||
return &ManageScript{path: path}
|
||||
}
|
||||
|
||||
// NewManageScriptWithDir creates a new manage script utility with a specific working directory
|
||||
func NewManageScriptWithDir(path, dir string) *ManageScript {
|
||||
return &ManageScript{path: path, dir: dir}
|
||||
}
|
||||
|
||||
func (ms *ManageScript) setDir(cmd *exec.Cmd) {
|
||||
if ms.dir != "" {
|
||||
cmd.Dir = ms.dir
|
||||
}
|
||||
}
|
||||
|
||||
// Status gets the status of services
|
||||
func (ms *ManageScript) Status() (string, error) {
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility
|
||||
cmd := exec.CommandContext(context.Background(), ms.path, "status")
|
||||
ms.setDir(cmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
return string(output), err
|
||||
}
|
||||
|
|
@ -298,6 +311,7 @@ func (ms *ManageScript) Status() (string, error) {
|
|||
func (ms *ManageScript) Start() error {
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility
|
||||
cmd := exec.CommandContext(context.Background(), ms.path, "start")
|
||||
ms.setDir(cmd)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
|
@ -305,6 +319,7 @@ func (ms *ManageScript) Start() error {
|
|||
func (ms *ManageScript) Stop() error {
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility
|
||||
cmd := exec.CommandContext(context.Background(), ms.path, "stop")
|
||||
ms.setDir(cmd)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
|
@ -312,6 +327,7 @@ func (ms *ManageScript) Stop() error {
|
|||
func (ms *ManageScript) Cleanup() error {
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility
|
||||
cmd := exec.CommandContext(context.Background(), ms.path, "cleanup")
|
||||
ms.setDir(cmd)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
|
@ -325,6 +341,7 @@ func (ms *ManageScript) StopAndCleanup() {
|
|||
func (ms *ManageScript) Health() (string, error) {
|
||||
//nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility
|
||||
cmd := exec.CommandContext(context.Background(), ms.path, "health")
|
||||
ms.setDir(cmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
return string(output), err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
|
|
@ -144,7 +144,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
|
|
@ -246,7 +246,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ func TestNewWSHandler(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ func TestWSHandlerWebSocketUpgrade(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ func TestWSHandlerInvalidRequest(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ func TestWSHandlerPostRequest(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package container
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
)
|
||||
|
||||
// TestContainerSecurityPolicy enforces the security contract for container configurations.
|
||||
|
|
@ -10,13 +12,13 @@ import (
|
|||
func TestContainerSecurityPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config PodmanConfig
|
||||
config container.PodmanConfig
|
||||
shouldFail bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "reject privileged mode",
|
||||
config: PodmanConfig{
|
||||
config: container.PodmanConfig{
|
||||
Image: "pytorch:latest",
|
||||
Privileged: true, // NEVER allowed
|
||||
},
|
||||
|
|
@ -25,7 +27,7 @@ func TestContainerSecurityPolicy(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "reject host network",
|
||||
config: PodmanConfig{
|
||||
config: container.PodmanConfig{
|
||||
Image: "pytorch:latest",
|
||||
Network: "host", // NEVER allowed
|
||||
},
|
||||
|
|
@ -34,7 +36,7 @@ func TestContainerSecurityPolicy(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "accept valid configuration",
|
||||
config: PodmanConfig{
|
||||
config: container.PodmanConfig{
|
||||
Image: "pytorch:latest",
|
||||
Privileged: false,
|
||||
Network: "bridge",
|
||||
|
|
@ -45,7 +47,7 @@ func TestContainerSecurityPolicy(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "accept empty network (default bridge)",
|
||||
config: PodmanConfig{
|
||||
config: container.PodmanConfig{
|
||||
Image: "pytorch:latest",
|
||||
Privileged: false,
|
||||
Network: "", // Empty means default bridge
|
||||
|
|
@ -55,7 +57,7 @@ func TestContainerSecurityPolicy(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "warn on non-read-only mounts",
|
||||
config: PodmanConfig{
|
||||
config: container.PodmanConfig{
|
||||
Image: "pytorch:latest",
|
||||
Privileged: false,
|
||||
Network: "bridge",
|
||||
|
|
@ -68,11 +70,11 @@ func TestContainerSecurityPolicy(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateSecurityPolicy(tt.config)
|
||||
err := container.ValidateSecurityPolicy(tt.config)
|
||||
if tt.shouldFail {
|
||||
if err == nil {
|
||||
t.Errorf("%s: expected failure (%s), got success", tt.name, tt.reason)
|
||||
} else if !errors.Is(err, ErrSecurityViolation) {
|
||||
} else if !errors.Is(err, container.ErrSecurityViolation) {
|
||||
t.Errorf("%s: expected ErrSecurityViolation, got %v", tt.name, err)
|
||||
}
|
||||
} else {
|
||||
|
|
@ -87,22 +89,22 @@ func TestContainerSecurityPolicy(t *testing.T) {
|
|||
// TestSecurityPolicy_IsolationEnforcement verifies isolation boundaries
|
||||
func TestSecurityPolicy_IsolationEnforcement(t *testing.T) {
|
||||
t.Run("privileged_equals_root_access", func(t *testing.T) {
|
||||
cfg := PodmanConfig{
|
||||
cfg := container.PodmanConfig{
|
||||
Image: "test:latest",
|
||||
Privileged: true,
|
||||
}
|
||||
err := ValidateSecurityPolicy(cfg)
|
||||
err := container.ValidateSecurityPolicy(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("privileged mode must be rejected - it grants root access to host")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("host_network_equals_no_isolation", func(t *testing.T) {
|
||||
cfg := PodmanConfig{
|
||||
cfg := container.PodmanConfig{
|
||||
Image: "test:latest",
|
||||
Network: "host",
|
||||
}
|
||||
err := ValidateSecurityPolicy(cfg)
|
||||
err := container.ValidateSecurityPolicy(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("host network must be rejected - it removes network isolation")
|
||||
}
|
||||
130
tests/unit/middleware/privacy_test.go
Normal file
130
tests/unit/middleware/privacy_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/middleware"
|
||||
)
|
||||
|
||||
func TestPrivacyEnforcer_CanAccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *auth.User
|
||||
owner string
|
||||
level string
|
||||
team string
|
||||
enforceTeams bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "owner can access private",
|
||||
user: &auth.User{Name: "alice"},
|
||||
owner: "alice",
|
||||
level: "private",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-owner cannot access private",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "private",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "admin can access private",
|
||||
user: &auth.User{Name: "admin", Admin: true},
|
||||
owner: "alice",
|
||||
level: "private",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public allows all",
|
||||
user: &auth.User{Name: "anyone"},
|
||||
owner: "alice",
|
||||
level: "public",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "owner can access team",
|
||||
user: &auth.User{Name: "alice"},
|
||||
owner: "alice",
|
||||
level: "team",
|
||||
team: "research",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-owner denied team when enforcing",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "team",
|
||||
team: "research",
|
||||
enforceTeams: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-owner allowed team when not enforcing",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "team",
|
||||
team: "research",
|
||||
enforceTeams: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "anonymized allows all",
|
||||
user: &auth.User{Name: "anyone"},
|
||||
owner: "alice",
|
||||
level: "anonymized",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "unknown level defaults to private (deny)",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "unknown",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pe := middleware.NewPrivacyEnforcer(tt.enforceTeams, false)
|
||||
got, err := pe.CanAccess(ctx, tt.user, tt.owner, tt.level, tt.team)
|
||||
if err != nil {
|
||||
t.Errorf("CanAccess() error = %v", err)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("CanAccess() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrivacyLevelFromString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected middleware.PrivacyLevel
|
||||
}{
|
||||
{"private", middleware.PrivacyPrivate},
|
||||
{"team", middleware.PrivacyTeam},
|
||||
{"public", middleware.PrivacyPublic},
|
||||
{"anonymized", middleware.PrivacyAnonymized},
|
||||
{"unknown", middleware.PrivacyPrivate}, // Default
|
||||
{"", middleware.PrivacyPrivate}, // Default
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := middleware.GetPrivacyLevelFromString(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("GetPrivacyLevelFromString(%q) = %v, want %v",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
132
tests/unit/privacy/pii_test.go
Normal file
132
tests/unit/privacy/pii_test.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package privacy_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/privacy"
|
||||
)
|
||||
|
||||
func TestDetectPII(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
expected []string // Expected PII types found
|
||||
}{
|
||||
{
|
||||
name: "email detection",
|
||||
text: "Contact me at user@example.com for details",
|
||||
expected: []string{"email"},
|
||||
},
|
||||
{
|
||||
name: "SSN detection",
|
||||
text: "My SSN is 123-45-6789",
|
||||
expected: []string{"ssn"},
|
||||
},
|
||||
{
|
||||
name: "phone detection",
|
||||
text: "Call me at 555-123-4567",
|
||||
expected: []string{"phone"},
|
||||
},
|
||||
{
|
||||
name: "IP address detection",
|
||||
text: "Server at 192.168.1.1",
|
||||
expected: []string{"ip_address"},
|
||||
},
|
||||
{
|
||||
name: "multiple PII types",
|
||||
text: "Email: test@example.com, SSN: 123-45-6789",
|
||||
expected: []string{"email", "ssn"},
|
||||
},
|
||||
{
|
||||
name: "no PII",
|
||||
text: "This is just a normal hypothesis about learning rates",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "credit card detection",
|
||||
text: "Card: 4111-1111-1111-1111",
|
||||
expected: []string{"credit_card"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
findings := privacy.DetectPII(tt.text)
|
||||
|
||||
if len(tt.expected) == 0 {
|
||||
if len(findings) != 0 {
|
||||
t.Errorf("expected no PII, found %d findings", len(findings))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that all expected types are found
|
||||
foundTypes := make(map[string]bool)
|
||||
for _, f := range findings {
|
||||
foundTypes[f.Type] = true
|
||||
}
|
||||
|
||||
for _, expectedType := range tt.expected {
|
||||
if !foundTypes[expectedType] {
|
||||
t.Errorf("expected to find %s, but didn't", expectedType)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPII(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has PII",
|
||||
text: "Contact: user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no PII",
|
||||
text: "Learning rate 0.01 worked well",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := privacy.HasPII(tt.text)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasPII() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactSample(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
match string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "short match",
|
||||
match: "abc",
|
||||
want: "[PII]",
|
||||
},
|
||||
{
|
||||
name: "long match",
|
||||
match: "user@example.com",
|
||||
want: "us...om",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := privacy.RedactSample(tt.match)
|
||||
if got != tt.want {
|
||||
t.Errorf("redactSample() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
82
tests/unit/worker/config_test.go
Normal file
82
tests/unit/worker/config_test.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package worker_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
func TestSandboxConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config worker.SandboxConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid none network",
|
||||
config: worker.SandboxConfig{NetworkMode: "none", MaxRuntimeHours: 48},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid bridge network",
|
||||
config: worker.SandboxConfig{NetworkMode: "bridge", MaxRuntimeHours: 24},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid slirp4netns network",
|
||||
config: worker.SandboxConfig{NetworkMode: "slirp4netns", MaxRuntimeHours: 12},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid empty network",
|
||||
config: worker.SandboxConfig{NetworkMode: "", MaxRuntimeHours: 48},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid network mode",
|
||||
config: worker.SandboxConfig{NetworkMode: "host", MaxRuntimeHours: 48},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative runtime hours",
|
||||
config: worker.SandboxConfig{NetworkMode: "none", MaxRuntimeHours: -1},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero runtime hours is valid",
|
||||
config: worker.SandboxConfig{NetworkMode: "none", MaxRuntimeHours: 0},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSandboxConfig_WithSecrets(t *testing.T) {
|
||||
config := worker.SandboxConfig{
|
||||
NetworkMode: "none",
|
||||
ReadOnlyRoot: true,
|
||||
AllowSecrets: true,
|
||||
AllowedSecrets: []string{"HF_TOKEN", "WANDB_API_KEY"},
|
||||
MaxRuntimeHours: 48,
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Errorf("Valid sandbox config should not error: %v", err)
|
||||
}
|
||||
|
||||
if !config.AllowSecrets {
|
||||
t.Error("AllowSecrets should be true")
|
||||
}
|
||||
|
||||
if len(config.AllowedSecrets) != 2 {
|
||||
t.Errorf("Expected 2 allowed secrets, got %d", len(config.AllowedSecrets))
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue