From c74e91dd696f92f41e29f43e6c756ea04332a325 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Sun, 8 Mar 2026 13:03:55 -0400 Subject: [PATCH] test: update test suite and remove deprecated privacy middleware Test improvements: - fixtures/: Updated mocks, fixtures with group context, SSH server, TUI driver - integration/: WebSocket queue and handler tests with groups - e2e/: WebSocket and TLS proxy end-to-end tests - unit/api/ws_test.go: WebSocket API tests - unit/scheduler/service_templates_test.go: Service template tests - benchmarks/scheduler_bench_test.go: Performance benchmarks Cleanup: - Remove privacy middleware (replaced by audit system) - Remove privacy_test.go --- internal/middleware/privacy.go | 94 ------------- tests/benchmarks/scheduler_bench_test.go | 4 + tests/e2e/websocket_e2e_test.go | 4 +- tests/e2e/wss_reverse_proxy_e2e_test.go | 4 +- tests/fixtures/scheduler_fixture.go | 47 ++++--- tests/fixtures/scheduler_mock.go | 7 +- tests/fixtures/ssh_server.go | 20 +-- tests/fixtures/test_utils.go | 4 +- tests/fixtures/tui_driver.go | 2 +- tests/integration/consistency/cmd/update.go | 27 ++-- .../websocket_queue_integration_test.go | 9 +- .../ws_handler_integration_test.go | 2 + tests/unit/api/ws_test.go | 16 +-- tests/unit/middleware/privacy_test.go | 130 ------------------ .../unit/scheduler/service_templates_test.go | 4 +- 15 files changed, 89 insertions(+), 285 deletions(-) delete mode 100644 internal/middleware/privacy.go delete mode 100644 tests/unit/middleware/privacy_test.go diff --git a/internal/middleware/privacy.go b/internal/middleware/privacy.go deleted file mode 100644 index d2c063a..0000000 --- a/internal/middleware/privacy.go +++ /dev/null @@ -1,94 +0,0 @@ -// Package middleware provides privacy enforcement for experiment access control. -package middleware - -import ( - "context" - "fmt" - - "github.com/jfraeys/fetch_ml/internal/auth" -) - -// PrivacyLevel defines experiment visibility levels. -type PrivacyLevel string - -const ( - // PrivacyPrivate restricts access to owner only. - PrivacyPrivate PrivacyLevel = "private" - // PrivacyTeam allows team members to view. - PrivacyTeam PrivacyLevel = "team" - // PrivacyPublic allows all authenticated users. - PrivacyPublic PrivacyLevel = "public" - // PrivacyAnonymized allows access with PII stripped. - PrivacyAnonymized PrivacyLevel = "anonymized" -) - -// PrivacyEnforcer handles privacy access control. -type PrivacyEnforcer struct { - enforceTeams bool - auditAccess bool -} - -// NewPrivacyEnforcer creates a privacy enforcer. -func NewPrivacyEnforcer(enforceTeams, auditAccess bool) *PrivacyEnforcer { - return &PrivacyEnforcer{ - enforceTeams: enforceTeams, - auditAccess: auditAccess, - } -} - -// CanAccess checks if a user can access an experiment. -func (pe *PrivacyEnforcer) CanAccess( - ctx context.Context, - user *auth.User, - experimentOwner string, - level string, - team string, -) (bool, error) { - privacyLevel := GetPrivacyLevelFromString(level) - switch privacyLevel { - case PrivacyPublic: - return true, nil - case PrivacyPrivate: - return user.Name == experimentOwner || user.Admin, nil - case PrivacyTeam: - if user.Name == experimentOwner || user.Admin { - return true, nil - } - if !pe.enforceTeams { - return true, nil // Teams not enforced, allow access - } - // Check if user is in same team - return pe.isUserInTeam(ctx, user, team) - case PrivacyAnonymized: - // Anonymized data is accessible but with PII stripped - return true, nil - default: - return false, fmt.Errorf("unknown privacy level: %s", privacyLevel) - } -} - -func (pe *PrivacyEnforcer) isUserInTeam(ctx context.Context, user *auth.User, team string) (bool, error) { - // Note: Team membership check not yet implemented. - // Future: query teams database or use JWT claims for verification. - // Currently denies access when team enforcement is enabled. - _ = ctx - _ = user - _ = team - return false, nil -} - -// GetPrivacyLevelFromString converts string to PrivacyLevel. -func GetPrivacyLevelFromString(level string) PrivacyLevel { - switch level { - case "private": - return PrivacyPrivate - case "team": - return PrivacyTeam - case "public": - return PrivacyPublic - case "anonymized": - return PrivacyAnonymized - default: - return PrivacyPrivate // Default to private for safety - } -} diff --git a/tests/benchmarks/scheduler_bench_test.go b/tests/benchmarks/scheduler_bench_test.go index fe96d88..be585c9 100644 --- a/tests/benchmarks/scheduler_bench_test.go +++ b/tests/benchmarks/scheduler_bench_test.go @@ -73,9 +73,13 @@ func BenchmarkStateStoreAppend(b *testing.B) { // BenchmarkSchedulerSubmitJob measures job submission throughput func BenchmarkSchedulerSubmitJob(b *testing.B) { + // Create isolated state directory + stateDir := b.TempDir() + // Create scheduler directly for benchmark cfg := scheduler.HubConfig{ BindAddr: "localhost:0", + StateDir: stateDir, DefaultBatchSlots: 4, StarvationThresholdMins: 5, AcceptanceTimeoutSecs: 5, diff --git a/tests/e2e/websocket_e2e_test.go b/tests/e2e/websocket_e2e_test.go index 9c85a08..0081e65 100644 --- a/tests/e2e/websocket_e2e_test.go +++ b/tests/e2e/websocket_e2e_test.go @@ -25,11 +25,11 @@ func setupTestServer(t *testing.T) string { authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") - wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) + wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil) // Create listener to get actual port listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") diff --git a/tests/e2e/wss_reverse_proxy_e2e_test.go b/tests/e2e/wss_reverse_proxy_e2e_test.go index 965f1bd..37ecb77 100644 --- a/tests/e2e/wss_reverse_proxy_e2e_test.go +++ b/tests/e2e/wss_reverse_proxy_e2e_test.go @@ -40,10 +40,10 @@ func startWSBackendServer(t *testing.T) *httptest.Server { logger := logging.NewLogger(slog.LevelInfo, false) authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") - h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) + h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil) srv := httptest.NewServer(h) t.Cleanup(srv.Close) diff --git a/tests/fixtures/scheduler_fixture.go b/tests/fixtures/scheduler_fixture.go index 5502fa5..f2ac11e 100644 --- a/tests/fixtures/scheduler_fixture.go +++ b/tests/fixtures/scheduler_fixture.go @@ -2,6 +2,7 @@ package tests import ( + "fmt" "os" "testing" "time" @@ -72,33 +73,45 @@ func (f *SchedulerTestFixture) Cleanup() { // Then stop the hub f.Hub.Stop() // Clean up isolated state directory - os.RemoveAll(f.stateDir) + if err := os.RemoveAll(f.stateDir); err != nil { + // Log cleanup error but don't fail test + fmt.Fprintf(os.Stderr, "failed to remove state dir: %v\n", err) + } } // DefaultHubConfig returns a default hub configuration for testing func DefaultHubConfig() scheduler.HubConfig { + tokens := map[string]string{ + "test-token-worker-restart-1": "worker-restart-1", + "test-token-mode-switch-worker": "mode-switch-worker", + "test-token-mode-switch-worker-2": "mode-switch-worker-2", + "test-token-e2e-worker-1": "e2e-worker-1", + "test-token-e2e-worker-2": "e2e-worker-2", + "test-token-worker-death-test": "worker-death-test", + "test-token-worker-split-1": "worker-split-1", + "test-token-worker-split-2": "worker-split-2", + "test-token-worker-split-3": "worker-split-3", + "test-token-worker-timeout": "worker-timeout", + "test-token-worker-gang": "worker-gang", + "test-token-bench-worker": "bench-worker", + "test-token-bench-hb-worker": "bench-hb-worker", + "test-token-bench-assign-worker": "bench-assign-worker", + } + + // Add tokens for dynamic benchmark worker IDs (0-999 for each pattern) + for i := range 1000 { + tokens[fmt.Sprintf("test-token-bench-worker-%d", i)] = fmt.Sprintf("bench-worker-%d", i) + tokens[fmt.Sprintf("test-token-bench-multi-worker-%d", i)] = fmt.Sprintf("bench-multi-worker-%d", i) + } + return scheduler.HubConfig{ BindAddr: "localhost:0", DefaultBatchSlots: 4, StarvationThresholdMins: 5, AcceptanceTimeoutSecs: 5, GangAllocTimeoutSecs: 10, - WorkerTokens: map[string]string{ - "test-token-worker-restart-1": "worker-restart-1", - "test-token-mode-switch-worker": "mode-switch-worker", - "test-token-mode-switch-worker-2": "mode-switch-worker-2", - "test-token-e2e-worker-1": "e2e-worker-1", - "test-token-e2e-worker-2": "e2e-worker-2", - "test-token-worker-death-test": "worker-death-test", - "test-token-worker-split-1": "worker-split-1", - "test-token-worker-split-2": "worker-split-2", - "test-token-worker-split-3": "worker-split-3", - "test-token-worker-timeout": "worker-timeout", - "test-token-worker-gang": "worker-gang", - "test-token-bench-worker": "bench-worker", - "test-token-bench-hb-worker": "bench-hb-worker", - "test-token-bench-assign-worker": "bench-assign-worker", - }, + // #nosec G101 -- These are test fixture tokens, not real credentials + WorkerTokens: tokens, } } diff --git a/tests/fixtures/scheduler_mock.go b/tests/fixtures/scheduler_mock.go index 980740a..f38b057 100644 --- a/tests/fixtures/scheduler_mock.go +++ b/tests/fixtures/scheduler_mock.go @@ -3,7 +3,9 @@ package tests import ( "encoding/json" + "fmt" "net/http" + "os" "sync" "testing" "time" @@ -201,7 +203,10 @@ func (mw *MockWorker) Close() { mw.mu.Unlock() close(mw.SendCh) - mw.Conn.Close() + if err := mw.Conn.Close(); err != nil { + // Log but don't fail in cleanup + fmt.Fprintf(os.Stderr, "failed to close connection: %v\n", err) + } mw.wg.Wait() } diff --git a/tests/fixtures/ssh_server.go b/tests/fixtures/ssh_server.go index 2a4a9e8..e8ffb11 100644 --- a/tests/fixtures/ssh_server.go +++ b/tests/fixtures/ssh_server.go @@ -45,6 +45,7 @@ func NewSSHTestServer(t *testing.T) *SSHTestServer { } // Read private key + // #nosec G304 -- path is constructed from repo root to known test key location privateKey, err := os.ReadFile(privateKeyPath) if err != nil { t.Fatalf("failed to read private key: %v", err) @@ -98,7 +99,7 @@ func (s *SSHTestServer) waitForSSH() error { Auth: []ssh.AuthMethod{ ssh.PublicKeys(s.Signer), }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // Test only + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 -- Test only Timeout: 5 * time.Second, } @@ -106,9 +107,8 @@ func (s *SSHTestServer) waitForSSH() error { // Retry with backoff for i := 0; i < 10; i++ { - client, err := ssh.Dial("tcp", addr, config) - if err == nil { - client.Close() + if client, err := ssh.Dial("tcp", addr, config); err == nil { + _ = client.Close() return nil } time.Sleep(500 * time.Millisecond) @@ -124,7 +124,7 @@ func (s *SSHTestServer) NewClient() (*ssh.Client, error) { Auth: []ssh.AuthMethod{ ssh.PublicKeys(s.Signer), }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // Test only + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 -- Test only Timeout: 10 * time.Second, } @@ -178,7 +178,7 @@ func (s *SSHTestServer) ExecWithPTY(cmd string, term string, width, height int) session, err := client.NewSession() if err != nil { - client.Close() + _ = client.Close() return nil, fmt.Errorf("failed to create session: %w", err) } @@ -190,14 +190,14 @@ func (s *SSHTestServer) ExecWithPTY(cmd string, term string, width, height int) } if err := session.RequestPty(term, width, height, modes); err != nil { - session.Close() - client.Close() + _ = session.Close() + _ = client.Close() return nil, fmt.Errorf("failed to request pty: %w", err) } if err := session.Start(cmd); err != nil { - session.Close() - client.Close() + _ = session.Close() + _ = client.Close() return nil, fmt.Errorf("failed to start command: %w", err) } diff --git a/tests/fixtures/test_utils.go b/tests/fixtures/test_utils.go index ce80c59..df72512 100644 --- a/tests/fixtures/test_utils.go +++ b/tests/fixtures/test_utils.go @@ -513,7 +513,7 @@ func CopyDir(src, dst string) error { } func copyFile(src, dst string) error { - //nolint:gosec // G304: Potential file inclusion via variable - this is a test utility + // #nosec G304 -- test utility with controlled paths srcFile, err := os.Open(src) if err != nil { return err @@ -525,7 +525,7 @@ func copyFile(src, dst string) error { return err } - //nolint:gosec // G304: Potential file inclusion via variable - this is a test utility + // #nosec G304 -- test utility with controlled paths dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode()) if err != nil { return err diff --git a/tests/fixtures/tui_driver.go b/tests/fixtures/tui_driver.go index dbeab71..3732507 100644 --- a/tests/fixtures/tui_driver.go +++ b/tests/fixtures/tui_driver.go @@ -167,7 +167,7 @@ func (d *TUIDriver) WaitForOutput(expected string, timeout time.Duration) error // Close closes the TUI driver and session func (d *TUIDriver) Close() error { - d.stdin.Close() + _ = d.stdin.Close() return d.session.Close() } diff --git a/tests/integration/consistency/cmd/update.go b/tests/integration/consistency/cmd/update.go index 83e329c..bf4f4d4 100644 --- a/tests/integration/consistency/cmd/update.go +++ b/tests/integration/consistency/cmd/update.go @@ -12,41 +12,42 @@ import ( func main() { fixturesDir := filepath.Join("tests", "fixtures", "consistency") - + // Load current expected hashes expectedPath := filepath.Join(fixturesDir, "dataset_hash", "expected_hashes.json") + // #nosec G304 -- path is a hardcoded test fixture path data, err := os.ReadFile(expectedPath) if err != nil { fmt.Fprintf(os.Stderr, "Failed to read expected hashes: %v\n", err) os.Exit(1) } - + var expected consistency.ExpectedHashes if err := json.Unmarshal(data, &expected); err != nil { fmt.Fprintf(os.Stderr, "Failed to parse expected hashes: %v\n", err) os.Exit(1) } - + // Use Go implementation as reference goImpl := consistency.NewGoImpl() - + updated := false for i, fixture := range expected.Fixtures { fixturePath := filepath.Join(fixturesDir, "dataset_hash", fixture.ID, "input") - + // Check if fixture exists if _, err := os.Stat(fixturePath); os.IsNotExist(err) { fmt.Printf("Skipping %s: fixture not found at %s\n", fixture.ID, fixturePath) continue } - + // Compute hash using reference implementation hash, err := goImpl.HashDataset(fixturePath) if err != nil { fmt.Printf("Error hashing %s: %v\n", fixture.ID, err) continue } - + // Update if different or TODO if fixture.ExpectedHash == "TODO_COMPUTE" { fmt.Printf("%s: computed %s\n", fixture.ID, hash) @@ -59,7 +60,7 @@ func main() { } else { fmt.Printf("%s: unchanged (%s)\n", fixture.ID, hash) } - + // Compute individual file hashes for j, file := range fixture.Files { if file.ContentHash == "TODO" || file.ContentHash == "" { @@ -75,23 +76,23 @@ func main() { } } } - + if !updated { fmt.Println("\nNo updates needed.") return } - + // Write updated hashes output, err := json.MarshalIndent(expected, "", " ") if err != nil { fmt.Fprintf(os.Stderr, "Failed to marshal updated hashes: %v\n", err) os.Exit(1) } - - if err := os.WriteFile(expectedPath, output, 0644); err != nil { + + if err := os.WriteFile(expectedPath, output, 0600); err != nil { fmt.Fprintf(os.Stderr, "Failed to write updated hashes: %v\n", err) os.Exit(1) } - + fmt.Println("\nUpdated expected_hashes.json") } diff --git a/tests/integration/websocket_queue_integration_test.go b/tests/integration/websocket_queue_integration_test.go index 94482b7..f0ffe37 100644 --- a/tests/integration/websocket_queue_integration_test.go +++ b/tests/integration/websocket_queue_integration_test.go @@ -48,7 +48,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil) + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( @@ -64,6 +64,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) { jobsHandler, jupyterHandler, datasetsHandler, + nil, // groupsHandler ) server := httptest.NewServer(wsHandler) defer server.Close() @@ -149,7 +150,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil) + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( @@ -165,6 +166,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) { jobsHandler, jupyterHandler, datasetsHandler, + nil, // groupsHandler ) server := httptest.NewServer(wsHandler) defer server.Close() @@ -254,7 +256,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil) + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( @@ -270,6 +272,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { jobsHandler, jupyterHandler, datasetsHandler, + nil, // groupsHandler ) server := httptest.NewServer(wsHandler) defer server.Close() diff --git a/tests/integration/ws_handler_integration_test.go b/tests/integration/ws_handler_integration_test.go index 38e7f5a..7e0004a 100644 --- a/tests/integration/ws_handler_integration_test.go +++ b/tests/integration/ws_handler_integration_test.go @@ -72,6 +72,7 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) ( nil, // jobsHandler nil, // jupyterHandler nil, // datasetsHandler + nil, // groupsHandler ) server := httptest.NewServer(handler) return server, tq, expManager, s, db @@ -601,6 +602,7 @@ func setupWSIntegrationServer(t *testing.T) ( nil, // jobsHandler nil, // jupyterHandler nil, // datasetsHandler + nil, // groupsHandler ) // Setup test server server := httptest.NewServer(handler) diff --git a/tests/unit/api/ws_test.go b/tests/unit/api/ws_test.go index 77df093..96c9d54 100644 --- a/tests/unit/api/ws_test.go +++ b/tests/unit/api/ws_test.go @@ -23,11 +23,11 @@ func TestNewWSHandler(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") - handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil) if handler == nil { t.Error("Expected non-nil WSHandler") @@ -61,11 +61,11 @@ func TestWSHandlerWebSocketUpgrade(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") - handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil) // Create a test HTTP request req := httptest.NewRequest("GET", "/ws", nil) @@ -101,11 +101,11 @@ func TestWSHandlerInvalidRequest(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") - handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil) // Create a test HTTP request without WebSocket headers req := httptest.NewRequest("GET", "/ws", nil) @@ -129,11 +129,11 @@ func TestWSHandlerPostRequest(t *testing.T) { authConfig := &auth.Config{} logger := logging.NewLogger(slog.LevelInfo, false) expManager := experiment.NewManager("/tmp") - jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") - handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil) // Create a POST request (should fail) req := httptest.NewRequest("POST", "/ws", strings.NewReader("data")) diff --git a/tests/unit/middleware/privacy_test.go b/tests/unit/middleware/privacy_test.go deleted file mode 100644 index c6e7c95..0000000 --- a/tests/unit/middleware/privacy_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package middleware_test - -import ( - "context" - "testing" - - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/middleware" -) - -func TestPrivacyEnforcer_CanAccess(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - user *auth.User - owner string - level string - team string - enforceTeams bool - want bool - }{ - { - name: "owner can access private", - user: &auth.User{Name: "alice"}, - owner: "alice", - level: "private", - want: true, - }, - { - name: "non-owner cannot access private", - user: &auth.User{Name: "bob"}, - owner: "alice", - level: "private", - want: false, - }, - { - name: "admin can access private", - user: &auth.User{Name: "admin", Admin: true}, - owner: "alice", - level: "private", - want: true, - }, - { - name: "public allows all", - user: &auth.User{Name: "anyone"}, - owner: "alice", - level: "public", - want: true, - }, - { - name: "owner can access team", - user: &auth.User{Name: "alice"}, - owner: "alice", - level: "team", - team: "research", - want: true, - }, - { - name: "non-owner denied team when enforcing", - user: &auth.User{Name: "bob"}, - owner: "alice", - level: "team", - team: "research", - enforceTeams: true, - want: false, - }, - { - name: "non-owner allowed team when not enforcing", - user: &auth.User{Name: "bob"}, - owner: "alice", - level: "team", - team: "research", - enforceTeams: false, - want: true, - }, - { - name: "anonymized allows all", - user: &auth.User{Name: "anyone"}, - owner: "alice", - level: "anonymized", - want: true, - }, - { - name: "unknown level defaults to private (deny)", - user: &auth.User{Name: "bob"}, - owner: "alice", - level: "unknown", - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pe := middleware.NewPrivacyEnforcer(tt.enforceTeams, false) - got, err := pe.CanAccess(ctx, tt.user, tt.owner, tt.level, tt.team) - if err != nil { - t.Errorf("CanAccess() error = %v", err) - return - } - if got != tt.want { - t.Errorf("CanAccess() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestGetPrivacyLevelFromString(t *testing.T) { - tests := []struct { - input string - expected middleware.PrivacyLevel - }{ - {"private", middleware.PrivacyPrivate}, - {"team", middleware.PrivacyTeam}, - {"public", middleware.PrivacyPublic}, - {"anonymized", middleware.PrivacyAnonymized}, - {"unknown", middleware.PrivacyPrivate}, // Default - {"", middleware.PrivacyPrivate}, // Default - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := middleware.GetPrivacyLevelFromString(tt.input) - if got != tt.expected { - t.Errorf("GetPrivacyLevelFromString(%q) = %v, want %v", - tt.input, got, tt.expected) - } - }) - } -} diff --git a/tests/unit/scheduler/service_templates_test.go b/tests/unit/scheduler/service_templates_test.go index 20530d5..97d7726 100644 --- a/tests/unit/scheduler/service_templates_test.go +++ b/tests/unit/scheduler/service_templates_test.go @@ -243,10 +243,10 @@ func TestTemplateVariableExpansion(t *testing.T) { } assert.True(t, hasServicePort, "Command should contain {{SERVICE_PORT}} template variable") - // Check env contains secret template + // Check env contains token template (used for secret generation) val, ok := template.Env["JUPYTER_TOKEN"] assert.True(t, ok, "Should have JUPYTER_TOKEN env var") - assert.Contains(t, val, "{{SECRET:", "Should use secret template") + assert.Contains(t, val, "{{TOKEN:", "Should use token template for secret generation") } // BenchmarkPortAllocation benchmarks port allocation performance