From 27c8b08a1652e35051c299095e906443d76892d7 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Wed, 18 Feb 2026 21:28:13 -0500 Subject: [PATCH] test: Reorganize and add unit tests Reorganize tests for better structure and coverage: - Move container/security_test.go from internal/ to tests/unit/container/ - Move related tests to proper unit test locations - Delete orphaned test files (startup_blacklist_test.go) - Add privacy middleware unit tests - Add worker config unit tests - Update E2E tests for homelab and websocket scenarios - Update test fixtures with utility functions - Add CLI helper script for arraylist fixes --- cli/fix_arraylist.sh | 14 ++ internal/jupyter/startup_blacklist_test.go | 57 -------- tests/e2e/homelab_e2e_test.go | 19 ++- tests/e2e/websocket_e2e_test.go | 2 +- tests/e2e/wss_reverse_proxy_e2e_test.go | 2 +- tests/fixtures/test_utils.go | 17 +++ .../websocket_queue_integration_test.go | 6 +- tests/unit/api/ws_test.go | 8 +- .../unit}/container/security_test.go | 28 ++-- tests/unit/middleware/privacy_test.go | 130 +++++++++++++++++ tests/unit/privacy/pii_test.go | 132 ++++++++++++++++++ tests/unit/worker/config_test.go | 82 +++++++++++ 12 files changed, 411 insertions(+), 86 deletions(-) create mode 100644 cli/fix_arraylist.sh delete mode 100644 internal/jupyter/startup_blacklist_test.go rename {internal => tests/unit}/container/security_test.go (81%) create mode 100644 tests/unit/middleware/privacy_test.go create mode 100644 tests/unit/privacy/pii_test.go create mode 100644 tests/unit/worker/config_test.go diff --git a/cli/fix_arraylist.sh b/cli/fix_arraylist.sh new file mode 100644 index 0000000..a1801f5 --- /dev/null +++ b/cli/fix_arraylist.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Fix ArrayList Zig 0.15 syntax + +cd /Users/jfraeys/Documents/dev/fetch_ml/cli/src + +for f in $(find . -name "*.zig" -exec grep -l "ArrayList" {} \;); do + # Fix .deinit() -> .deinit(allocator) + sed -i '' 's/\.deinit();/.deinit(allocator);/g' "$f" + + # Fix .toOwnedSlice() -> .toOwnedSlice(allocator) + sed -i '' 's/\.toOwnedSlice();/.toOwnedSlice(allocator);/g' "$f" +done + +echo "Fixed deinit and toOwnedSlice patterns" diff --git a/internal/jupyter/startup_blacklist_test.go b/internal/jupyter/startup_blacklist_test.go deleted file mode 100644 index e3be4c4..0000000 --- a/internal/jupyter/startup_blacklist_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package jupyter - -import ( - "os" - "testing" -) - -func TestStartupBlockedPackages_DefaultInheritsInstallBlocked(t *testing.T) { - oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES") - _, hadStartup := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") - oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") - - _ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3") - _ = os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") - defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall) - if hadStartup { - defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup) - } else { - defer os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") - } - - cfg := DefaultEnhancedSecurityConfigFromEnv() - startup := startupBlockedPackages(cfg.BlockedPackages) - if len(startup) != 2 { - t.Fatalf("expected startup list to inherit 2 items, got %d", len(startup)) - } -} - -func TestStartupBlockedPackages_Disabled(t *testing.T) { - oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES") - oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") - _ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3") - _ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "off") - defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall) - defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup) - - cfg := DefaultEnhancedSecurityConfigFromEnv() - startup := startupBlockedPackages(cfg.BlockedPackages) - if len(startup) != 0 { - t.Fatalf("expected startup list to be disabled, got %d", len(startup)) - } -} - -func TestStartupBlockedPackages_ExplicitList(t *testing.T) { - oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES") - oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") - _ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3") - _ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "aiohttp") - defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall) - defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup) - - cfg := DefaultEnhancedSecurityConfigFromEnv() - startup := startupBlockedPackages(cfg.BlockedPackages) - if len(startup) != 1 || startup[0] != "aiohttp" { - t.Fatalf("expected explicit startup list [aiohttp], got %v", startup) - } -} diff --git a/tests/e2e/homelab_e2e_test.go b/tests/e2e/homelab_e2e_test.go index 96e677f..2cefbbf 100644 --- a/tests/e2e/homelab_e2e_test.go +++ b/tests/e2e/homelab_e2e_test.go @@ -19,7 +19,8 @@ const ( // TestHomelabSetupE2E tests the complete homelab setup workflow end-to-end func TestHomelabSetupE2E(t *testing.T) { // Skip if essential tools not available - manageScript := manageScriptPath + repoRoot := e2eRepoRoot(t) + manageScript := filepath.Join(repoRoot, "tools/manage.sh") if _, err := os.Stat(manageScript); os.IsNotExist(err) { t.Skip("manage.sh not found") } @@ -29,8 +30,8 @@ func TestHomelabSetupE2E(t *testing.T) { t.Skip("CLI not built - run 'make build' first") } - // Use fixtures for manage script operations - ms := tests.NewManageScript(manageScript) + // Use fixtures for manage script operations with correct working directory + ms := tests.NewManageScriptWithDir(manageScript, repoRoot) defer ms.StopAndCleanup() testDir := t.TempDir() @@ -266,8 +267,10 @@ func TestPerformanceE2E(t *testing.T) { t.Skip("manage.sh not found") } - // Use fixtures for manage script operations - ms := tests.NewManageScript(manageScript) + // Use fixtures for manage script operations with correct working directory + repoRoot := e2eRepoRoot(t) + manageScript = filepath.Join(repoRoot, "tools/manage.sh") + ms := tests.NewManageScriptWithDir(manageScript, repoRoot) t.Run("PerformanceMetrics", func(t *testing.T) { // Test health check performance @@ -312,8 +315,10 @@ func TestConfigurationScenariosE2E(t *testing.T) { t.Skip("manage.sh not found") } - // Use fixtures for manage script operations - ms := tests.NewManageScript(manageScript) + // Use fixtures for manage script operations with correct working directory + repoRoot := e2eRepoRoot(t) + manageScript = filepath.Join(repoRoot, "tools/manage.sh") + ms := tests.NewManageScriptWithDir(manageScript, repoRoot) t.Run("ConfigurationHandling", func(t *testing.T) { testDir := t.TempDir() diff --git a/tests/e2e/websocket_e2e_test.go b/tests/e2e/websocket_e2e_test.go index 7a0bf1a..9c85a08 100644 --- a/tests/e2e/websocket_e2e_test.go +++ b/tests/e2e/websocket_e2e_test.go @@ -25,7 +25,7 @@ func setupTestServer(t *testing.T) string { authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") diff --git a/tests/e2e/wss_reverse_proxy_e2e_test.go b/tests/e2e/wss_reverse_proxy_e2e_test.go index 3dbc53a..965f1bd 100644 --- a/tests/e2e/wss_reverse_proxy_e2e_test.go +++ b/tests/e2e/wss_reverse_proxy_e2e_test.go @@ -40,7 +40,7 @@ func startWSBackendServer(t *testing.T) *httptest.Server { logger := logging.NewLogger(slog.LevelInfo, false) authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) diff --git a/tests/fixtures/test_utils.go b/tests/fixtures/test_utils.go index 7616a2f..b679caa 100644 --- a/tests/fixtures/test_utils.go +++ b/tests/fixtures/test_utils.go @@ -279,6 +279,7 @@ func (tq *TaskQueue) Close() error { // ManageScript provides utilities for manage.sh operations type ManageScript struct { path string + dir string } // NewManageScript creates a new manage script utility @@ -286,10 +287,22 @@ func NewManageScript(path string) *ManageScript { return &ManageScript{path: path} } +// NewManageScriptWithDir creates a new manage script utility with a specific working directory +func NewManageScriptWithDir(path, dir string) *ManageScript { + return &ManageScript{path: path, dir: dir} +} + +func (ms *ManageScript) setDir(cmd *exec.Cmd) { + if ms.dir != "" { + cmd.Dir = ms.dir + } +} + // Status gets the status of services func (ms *ManageScript) Status() (string, error) { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "status") + ms.setDir(cmd) output, err := cmd.CombinedOutput() return string(output), err } @@ -298,6 +311,7 @@ func (ms *ManageScript) Status() (string, error) { func (ms *ManageScript) Start() error { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "start") + ms.setDir(cmd) return cmd.Run() } @@ -305,6 +319,7 @@ func (ms *ManageScript) Start() error { func (ms *ManageScript) Stop() error { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "stop") + ms.setDir(cmd) return cmd.Run() } @@ -312,6 +327,7 @@ func (ms *ManageScript) Stop() error { func (ms *ManageScript) Cleanup() error { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "cleanup") + ms.setDir(cmd) return cmd.Run() } @@ -325,6 +341,7 @@ func (ms *ManageScript) StopAndCleanup() { func (ms *ManageScript) Health() (string, error) { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "health") + ms.setDir(cmd) output, err := cmd.CombinedOutput() return string(output), err } diff --git a/tests/integration/websocket_queue_integration_test.go b/tests/integration/websocket_queue_integration_test.go index ee3045e..e82c756 100644 --- a/tests/integration/websocket_queue_integration_test.go +++ b/tests/integration/websocket_queue_integration_test.go @@ -47,7 +47,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( @@ -144,7 +144,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( @@ -246,7 +246,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( diff --git a/tests/unit/api/ws_test.go b/tests/unit/api/ws_test.go index cc24c11..77df093 100644 --- a/tests/unit/api/ws_test.go +++ b/tests/unit/api/ws_test.go @@ -23,7 +23,7 @@ func TestNewWSHandler(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") @@ -61,7 +61,7 @@ func TestWSHandlerWebSocketUpgrade(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") @@ -101,7 +101,7 @@ func TestWSHandlerInvalidRequest(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") @@ -129,7 +129,7 @@ func TestWSHandlerPostRequest(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") diff --git a/internal/container/security_test.go b/tests/unit/container/security_test.go similarity index 81% rename from internal/container/security_test.go rename to tests/unit/container/security_test.go index 5e4bac8..4d5a805 100644 --- a/internal/container/security_test.go +++ b/tests/unit/container/security_test.go @@ -1,8 +1,10 @@ -package container +package tests import ( "errors" "testing" + + "github.com/jfraeys/fetch_ml/internal/container" ) // TestContainerSecurityPolicy enforces the security contract for container configurations. @@ -10,13 +12,13 @@ import ( func TestContainerSecurityPolicy(t *testing.T) { tests := []struct { name string - config PodmanConfig + config container.PodmanConfig shouldFail bool reason string }{ { name: "reject privileged mode", - config: PodmanConfig{ + config: container.PodmanConfig{ Image: "pytorch:latest", Privileged: true, // NEVER allowed }, @@ -25,7 +27,7 @@ func TestContainerSecurityPolicy(t *testing.T) { }, { name: "reject host network", - config: PodmanConfig{ + config: container.PodmanConfig{ Image: "pytorch:latest", Network: "host", // NEVER allowed }, @@ -34,7 +36,7 @@ func TestContainerSecurityPolicy(t *testing.T) { }, { name: "accept valid configuration", - config: PodmanConfig{ + config: container.PodmanConfig{ Image: "pytorch:latest", Privileged: false, Network: "bridge", @@ -45,7 +47,7 @@ func TestContainerSecurityPolicy(t *testing.T) { }, { name: "accept empty network (default bridge)", - config: PodmanConfig{ + config: container.PodmanConfig{ Image: "pytorch:latest", Privileged: false, Network: "", // Empty means default bridge @@ -55,7 +57,7 @@ func TestContainerSecurityPolicy(t *testing.T) { }, { name: "warn on non-read-only mounts", - config: PodmanConfig{ + config: container.PodmanConfig{ Image: "pytorch:latest", Privileged: false, Network: "bridge", @@ -68,11 +70,11 @@ func TestContainerSecurityPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateSecurityPolicy(tt.config) + err := container.ValidateSecurityPolicy(tt.config) if tt.shouldFail { if err == nil { t.Errorf("%s: expected failure (%s), got success", tt.name, tt.reason) - } else if !errors.Is(err, ErrSecurityViolation) { + } else if !errors.Is(err, container.ErrSecurityViolation) { t.Errorf("%s: expected ErrSecurityViolation, got %v", tt.name, err) } } else { @@ -87,22 +89,22 @@ func TestContainerSecurityPolicy(t *testing.T) { // TestSecurityPolicy_IsolationEnforcement verifies isolation boundaries func TestSecurityPolicy_IsolationEnforcement(t *testing.T) { t.Run("privileged_equals_root_access", func(t *testing.T) { - cfg := PodmanConfig{ + cfg := container.PodmanConfig{ Image: "test:latest", Privileged: true, } - err := ValidateSecurityPolicy(cfg) + err := container.ValidateSecurityPolicy(cfg) if err == nil { t.Fatal("privileged mode must be rejected - it grants root access to host") } }) t.Run("host_network_equals_no_isolation", func(t *testing.T) { - cfg := PodmanConfig{ + cfg := container.PodmanConfig{ Image: "test:latest", Network: "host", } - err := ValidateSecurityPolicy(cfg) + err := container.ValidateSecurityPolicy(cfg) if err == nil { t.Fatal("host network must be rejected - it removes network isolation") } diff --git a/tests/unit/middleware/privacy_test.go b/tests/unit/middleware/privacy_test.go new file mode 100644 index 0000000..c6e7c95 --- /dev/null +++ b/tests/unit/middleware/privacy_test.go @@ -0,0 +1,130 @@ +package middleware_test + +import ( + "context" + "testing" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/middleware" +) + +func TestPrivacyEnforcer_CanAccess(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + user *auth.User + owner string + level string + team string + enforceTeams bool + want bool + }{ + { + name: "owner can access private", + user: &auth.User{Name: "alice"}, + owner: "alice", + level: "private", + want: true, + }, + { + name: "non-owner cannot access private", + user: &auth.User{Name: "bob"}, + owner: "alice", + level: "private", + want: false, + }, + { + name: "admin can access private", + user: &auth.User{Name: "admin", Admin: true}, + owner: "alice", + level: "private", + want: true, + }, + { + name: "public allows all", + user: &auth.User{Name: "anyone"}, + owner: "alice", + level: "public", + want: true, + }, + { + name: "owner can access team", + user: &auth.User{Name: "alice"}, + owner: "alice", + level: "team", + team: "research", + want: true, + }, + { + name: "non-owner denied team when enforcing", + user: &auth.User{Name: "bob"}, + owner: "alice", + level: "team", + team: "research", + enforceTeams: true, + want: false, + }, + { + name: "non-owner allowed team when not enforcing", + user: &auth.User{Name: "bob"}, + owner: "alice", + level: "team", + team: "research", + enforceTeams: false, + want: true, + }, + { + name: "anonymized allows all", + user: &auth.User{Name: "anyone"}, + owner: "alice", + level: "anonymized", + want: true, + }, + { + name: "unknown level defaults to private (deny)", + user: &auth.User{Name: "bob"}, + owner: "alice", + level: "unknown", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pe := middleware.NewPrivacyEnforcer(tt.enforceTeams, false) + got, err := pe.CanAccess(ctx, tt.user, tt.owner, tt.level, tt.team) + if err != nil { + t.Errorf("CanAccess() error = %v", err) + return + } + if got != tt.want { + t.Errorf("CanAccess() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetPrivacyLevelFromString(t *testing.T) { + tests := []struct { + input string + expected middleware.PrivacyLevel + }{ + {"private", middleware.PrivacyPrivate}, + {"team", middleware.PrivacyTeam}, + {"public", middleware.PrivacyPublic}, + {"anonymized", middleware.PrivacyAnonymized}, + {"unknown", middleware.PrivacyPrivate}, // Default + {"", middleware.PrivacyPrivate}, // Default + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := middleware.GetPrivacyLevelFromString(tt.input) + if got != tt.expected { + t.Errorf("GetPrivacyLevelFromString(%q) = %v, want %v", + tt.input, got, tt.expected) + } + }) + } +} diff --git a/tests/unit/privacy/pii_test.go b/tests/unit/privacy/pii_test.go new file mode 100644 index 0000000..e2b9298 --- /dev/null +++ b/tests/unit/privacy/pii_test.go @@ -0,0 +1,132 @@ +package privacy_test + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/privacy" +) + +func TestDetectPII(t *testing.T) { + tests := []struct { + name string + text string + expected []string // Expected PII types found + }{ + { + name: "email detection", + text: "Contact me at user@example.com for details", + expected: []string{"email"}, + }, + { + name: "SSN detection", + text: "My SSN is 123-45-6789", + expected: []string{"ssn"}, + }, + { + name: "phone detection", + text: "Call me at 555-123-4567", + expected: []string{"phone"}, + }, + { + name: "IP address detection", + text: "Server at 192.168.1.1", + expected: []string{"ip_address"}, + }, + { + name: "multiple PII types", + text: "Email: test@example.com, SSN: 123-45-6789", + expected: []string{"email", "ssn"}, + }, + { + name: "no PII", + text: "This is just a normal hypothesis about learning rates", + expected: []string{}, + }, + { + name: "credit card detection", + text: "Card: 4111-1111-1111-1111", + expected: []string{"credit_card"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + findings := privacy.DetectPII(tt.text) + + if len(tt.expected) == 0 { + if len(findings) != 0 { + t.Errorf("expected no PII, found %d findings", len(findings)) + } + return + } + + // Check that all expected types are found + foundTypes := make(map[string]bool) + for _, f := range findings { + foundTypes[f.Type] = true + } + + for _, expectedType := range tt.expected { + if !foundTypes[expectedType] { + t.Errorf("expected to find %s, but didn't", expectedType) + } + } + }) + } +} + +func TestHasPII(t *testing.T) { + tests := []struct { + name string + text string + expected bool + }{ + { + name: "has PII", + text: "Contact: user@example.com", + expected: true, + }, + { + name: "no PII", + text: "Learning rate 0.01 worked well", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := privacy.HasPII(tt.text) + if result != tt.expected { + t.Errorf("HasPII() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestRedactSample(t *testing.T) { + tests := []struct { + name string + match string + want string + }{ + { + name: "short match", + match: "abc", + want: "[PII]", + }, + { + name: "long match", + match: "user@example.com", + want: "us...om", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := privacy.RedactSample(tt.match) + if got != tt.want { + t.Errorf("redactSample() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tests/unit/worker/config_test.go b/tests/unit/worker/config_test.go new file mode 100644 index 0000000..6c16d78 --- /dev/null +++ b/tests/unit/worker/config_test.go @@ -0,0 +1,82 @@ +package worker_test + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/worker" +) + +func TestSandboxConfig_Validate(t *testing.T) { + tests := []struct { + name string + config worker.SandboxConfig + wantErr bool + }{ + { + name: "valid none network", + config: worker.SandboxConfig{NetworkMode: "none", MaxRuntimeHours: 48}, + wantErr: false, + }, + { + name: "valid bridge network", + config: worker.SandboxConfig{NetworkMode: "bridge", MaxRuntimeHours: 24}, + wantErr: false, + }, + { + name: "valid slirp4netns network", + config: worker.SandboxConfig{NetworkMode: "slirp4netns", MaxRuntimeHours: 12}, + wantErr: false, + }, + { + name: "valid empty network", + config: worker.SandboxConfig{NetworkMode: "", MaxRuntimeHours: 48}, + wantErr: false, + }, + { + name: "invalid network mode", + config: worker.SandboxConfig{NetworkMode: "host", MaxRuntimeHours: 48}, + wantErr: true, + }, + { + name: "negative runtime hours", + config: worker.SandboxConfig{NetworkMode: "none", MaxRuntimeHours: -1}, + wantErr: true, + }, + { + name: "zero runtime hours is valid", + config: worker.SandboxConfig{NetworkMode: "none", MaxRuntimeHours: 0}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSandboxConfig_WithSecrets(t *testing.T) { + config := worker.SandboxConfig{ + NetworkMode: "none", + ReadOnlyRoot: true, + AllowSecrets: true, + AllowedSecrets: []string{"HF_TOKEN", "WANDB_API_KEY"}, + MaxRuntimeHours: 48, + } + + if err := config.Validate(); err != nil { + t.Errorf("Valid sandbox config should not error: %v", err) + } + + if !config.AllowSecrets { + t.Error("AllowSecrets should be true") + } + + if len(config.AllowedSecrets) != 2 { + t.Errorf("Expected 2 allowed secrets, got %d", len(config.AllowedSecrets)) + } +}