test: update test suite and remove deprecated privacy middleware
Test improvements: - fixtures/: Updated mocks, fixtures with group context, SSH server, TUI driver - integration/: WebSocket queue and handler tests with groups - e2e/: WebSocket and TLS proxy end-to-end tests - unit/api/ws_test.go: WebSocket API tests - unit/scheduler/service_templates_test.go: Service template tests - benchmarks/scheduler_bench_test.go: Performance benchmarks Cleanup: - Remove privacy middleware (replaced by audit system) - Remove privacy_test.go
This commit is contained in:
parent
cb142213fa
commit
c74e91dd69
15 changed files with 89 additions and 285 deletions
|
|
@ -1,94 +0,0 @@
|
|||
// Package middleware provides privacy enforcement for experiment access control.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
)
|
||||
|
||||
// PrivacyLevel defines experiment visibility levels.
|
||||
type PrivacyLevel string
|
||||
|
||||
const (
|
||||
// PrivacyPrivate restricts access to owner only.
|
||||
PrivacyPrivate PrivacyLevel = "private"
|
||||
// PrivacyTeam allows team members to view.
|
||||
PrivacyTeam PrivacyLevel = "team"
|
||||
// PrivacyPublic allows all authenticated users.
|
||||
PrivacyPublic PrivacyLevel = "public"
|
||||
// PrivacyAnonymized allows access with PII stripped.
|
||||
PrivacyAnonymized PrivacyLevel = "anonymized"
|
||||
)
|
||||
|
||||
// PrivacyEnforcer handles privacy access control.
|
||||
type PrivacyEnforcer struct {
|
||||
enforceTeams bool
|
||||
auditAccess bool
|
||||
}
|
||||
|
||||
// NewPrivacyEnforcer creates a privacy enforcer.
|
||||
func NewPrivacyEnforcer(enforceTeams, auditAccess bool) *PrivacyEnforcer {
|
||||
return &PrivacyEnforcer{
|
||||
enforceTeams: enforceTeams,
|
||||
auditAccess: auditAccess,
|
||||
}
|
||||
}
|
||||
|
||||
// CanAccess checks if a user can access an experiment.
|
||||
func (pe *PrivacyEnforcer) CanAccess(
|
||||
ctx context.Context,
|
||||
user *auth.User,
|
||||
experimentOwner string,
|
||||
level string,
|
||||
team string,
|
||||
) (bool, error) {
|
||||
privacyLevel := GetPrivacyLevelFromString(level)
|
||||
switch privacyLevel {
|
||||
case PrivacyPublic:
|
||||
return true, nil
|
||||
case PrivacyPrivate:
|
||||
return user.Name == experimentOwner || user.Admin, nil
|
||||
case PrivacyTeam:
|
||||
if user.Name == experimentOwner || user.Admin {
|
||||
return true, nil
|
||||
}
|
||||
if !pe.enforceTeams {
|
||||
return true, nil // Teams not enforced, allow access
|
||||
}
|
||||
// Check if user is in same team
|
||||
return pe.isUserInTeam(ctx, user, team)
|
||||
case PrivacyAnonymized:
|
||||
// Anonymized data is accessible but with PII stripped
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknown privacy level: %s", privacyLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func (pe *PrivacyEnforcer) isUserInTeam(ctx context.Context, user *auth.User, team string) (bool, error) {
|
||||
// Note: Team membership check not yet implemented.
|
||||
// Future: query teams database or use JWT claims for verification.
|
||||
// Currently denies access when team enforcement is enabled.
|
||||
_ = ctx
|
||||
_ = user
|
||||
_ = team
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetPrivacyLevelFromString converts string to PrivacyLevel.
|
||||
func GetPrivacyLevelFromString(level string) PrivacyLevel {
|
||||
switch level {
|
||||
case "private":
|
||||
return PrivacyPrivate
|
||||
case "team":
|
||||
return PrivacyTeam
|
||||
case "public":
|
||||
return PrivacyPublic
|
||||
case "anonymized":
|
||||
return PrivacyAnonymized
|
||||
default:
|
||||
return PrivacyPrivate // Default to private for safety
|
||||
}
|
||||
}
|
||||
|
|
@ -73,9 +73,13 @@ func BenchmarkStateStoreAppend(b *testing.B) {
|
|||
|
||||
// BenchmarkSchedulerSubmitJob measures job submission throughput
|
||||
func BenchmarkSchedulerSubmitJob(b *testing.B) {
|
||||
// Create isolated state directory
|
||||
stateDir := b.TempDir()
|
||||
|
||||
// Create scheduler directly for benchmark
|
||||
cfg := scheduler.HubConfig{
|
||||
BindAddr: "localhost:0",
|
||||
StateDir: stateDir,
|
||||
DefaultBatchSlots: 4,
|
||||
StarvationThresholdMins: 5,
|
||||
AcceptanceTimeoutSecs: 5,
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ func setupTestServer(t *testing.T) string {
|
|||
authConfig := &auth.Config{Enabled: false}
|
||||
expManager := experiment.NewManager(t.TempDir())
|
||||
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil)
|
||||
|
||||
// Create listener to get actual port
|
||||
listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ func startWSBackendServer(t *testing.T) *httptest.Server {
|
|||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
authConfig := &auth.Config{Enabled: false}
|
||||
expManager := experiment.NewManager(t.TempDir())
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil)
|
||||
|
||||
srv := httptest.NewServer(h)
|
||||
t.Cleanup(srv.Close)
|
||||
|
|
|
|||
47
tests/fixtures/scheduler_fixture.go
vendored
47
tests/fixtures/scheduler_fixture.go
vendored
|
|
@ -2,6 +2,7 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -72,33 +73,45 @@ func (f *SchedulerTestFixture) Cleanup() {
|
|||
// Then stop the hub
|
||||
f.Hub.Stop()
|
||||
// Clean up isolated state directory
|
||||
os.RemoveAll(f.stateDir)
|
||||
if err := os.RemoveAll(f.stateDir); err != nil {
|
||||
// Log cleanup error but don't fail test
|
||||
fmt.Fprintf(os.Stderr, "failed to remove state dir: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHubConfig returns a default hub configuration for testing
|
||||
func DefaultHubConfig() scheduler.HubConfig {
|
||||
tokens := map[string]string{
|
||||
"test-token-worker-restart-1": "worker-restart-1",
|
||||
"test-token-mode-switch-worker": "mode-switch-worker",
|
||||
"test-token-mode-switch-worker-2": "mode-switch-worker-2",
|
||||
"test-token-e2e-worker-1": "e2e-worker-1",
|
||||
"test-token-e2e-worker-2": "e2e-worker-2",
|
||||
"test-token-worker-death-test": "worker-death-test",
|
||||
"test-token-worker-split-1": "worker-split-1",
|
||||
"test-token-worker-split-2": "worker-split-2",
|
||||
"test-token-worker-split-3": "worker-split-3",
|
||||
"test-token-worker-timeout": "worker-timeout",
|
||||
"test-token-worker-gang": "worker-gang",
|
||||
"test-token-bench-worker": "bench-worker",
|
||||
"test-token-bench-hb-worker": "bench-hb-worker",
|
||||
"test-token-bench-assign-worker": "bench-assign-worker",
|
||||
}
|
||||
|
||||
// Add tokens for dynamic benchmark worker IDs (0-999 for each pattern)
|
||||
for i := range 1000 {
|
||||
tokens[fmt.Sprintf("test-token-bench-worker-%d", i)] = fmt.Sprintf("bench-worker-%d", i)
|
||||
tokens[fmt.Sprintf("test-token-bench-multi-worker-%d", i)] = fmt.Sprintf("bench-multi-worker-%d", i)
|
||||
}
|
||||
|
||||
return scheduler.HubConfig{
|
||||
BindAddr: "localhost:0",
|
||||
DefaultBatchSlots: 4,
|
||||
StarvationThresholdMins: 5,
|
||||
AcceptanceTimeoutSecs: 5,
|
||||
GangAllocTimeoutSecs: 10,
|
||||
WorkerTokens: map[string]string{
|
||||
"test-token-worker-restart-1": "worker-restart-1",
|
||||
"test-token-mode-switch-worker": "mode-switch-worker",
|
||||
"test-token-mode-switch-worker-2": "mode-switch-worker-2",
|
||||
"test-token-e2e-worker-1": "e2e-worker-1",
|
||||
"test-token-e2e-worker-2": "e2e-worker-2",
|
||||
"test-token-worker-death-test": "worker-death-test",
|
||||
"test-token-worker-split-1": "worker-split-1",
|
||||
"test-token-worker-split-2": "worker-split-2",
|
||||
"test-token-worker-split-3": "worker-split-3",
|
||||
"test-token-worker-timeout": "worker-timeout",
|
||||
"test-token-worker-gang": "worker-gang",
|
||||
"test-token-bench-worker": "bench-worker",
|
||||
"test-token-bench-hb-worker": "bench-hb-worker",
|
||||
"test-token-bench-assign-worker": "bench-assign-worker",
|
||||
},
|
||||
// #nosec G101 -- These are test fixture tokens, not real credentials
|
||||
WorkerTokens: tokens,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
7
tests/fixtures/scheduler_mock.go
vendored
7
tests/fixtures/scheduler_mock.go
vendored
|
|
@ -3,7 +3,9 @@ package tests
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -201,7 +203,10 @@ func (mw *MockWorker) Close() {
|
|||
mw.mu.Unlock()
|
||||
|
||||
close(mw.SendCh)
|
||||
mw.Conn.Close()
|
||||
if err := mw.Conn.Close(); err != nil {
|
||||
// Log but don't fail in cleanup
|
||||
fmt.Fprintf(os.Stderr, "failed to close connection: %v\n", err)
|
||||
}
|
||||
mw.wg.Wait()
|
||||
}
|
||||
|
||||
|
|
|
|||
20
tests/fixtures/ssh_server.go
vendored
20
tests/fixtures/ssh_server.go
vendored
|
|
@ -45,6 +45,7 @@ func NewSSHTestServer(t *testing.T) *SSHTestServer {
|
|||
}
|
||||
|
||||
// Read private key
|
||||
// #nosec G304 -- path is constructed from repo root to known test key location
|
||||
privateKey, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read private key: %v", err)
|
||||
|
|
@ -98,7 +99,7 @@ func (s *SSHTestServer) waitForSSH() error {
|
|||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(s.Signer),
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // Test only
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 -- Test only
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
|
|
@ -106,9 +107,8 @@ func (s *SSHTestServer) waitForSSH() error {
|
|||
|
||||
// Retry with backoff
|
||||
for i := 0; i < 10; i++ {
|
||||
client, err := ssh.Dial("tcp", addr, config)
|
||||
if err == nil {
|
||||
client.Close()
|
||||
if client, err := ssh.Dial("tcp", addr, config); err == nil {
|
||||
_ = client.Close()
|
||||
return nil
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
|
@ -124,7 +124,7 @@ func (s *SSHTestServer) NewClient() (*ssh.Client, error) {
|
|||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(s.Signer),
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // Test only
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 -- Test only
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ func (s *SSHTestServer) ExecWithPTY(cmd string, term string, width, height int)
|
|||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -190,14 +190,14 @@ func (s *SSHTestServer) ExecWithPTY(cmd string, term string, width, height int)
|
|||
}
|
||||
|
||||
if err := session.RequestPty(term, width, height, modes); err != nil {
|
||||
session.Close()
|
||||
client.Close()
|
||||
_ = session.Close()
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("failed to request pty: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Start(cmd); err != nil {
|
||||
session.Close()
|
||||
client.Close()
|
||||
_ = session.Close()
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("failed to start command: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
4
tests/fixtures/test_utils.go
vendored
4
tests/fixtures/test_utils.go
vendored
|
|
@ -513,7 +513,7 @@ func CopyDir(src, dst string) error {
|
|||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
//nolint:gosec // G304: Potential file inclusion via variable - this is a test utility
|
||||
// #nosec G304 -- test utility with controlled paths
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -525,7 +525,7 @@ func copyFile(src, dst string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
//nolint:gosec // G304: Potential file inclusion via variable - this is a test utility
|
||||
// #nosec G304 -- test utility with controlled paths
|
||||
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
2
tests/fixtures/tui_driver.go
vendored
2
tests/fixtures/tui_driver.go
vendored
|
|
@ -167,7 +167,7 @@ func (d *TUIDriver) WaitForOutput(expected string, timeout time.Duration) error
|
|||
|
||||
// Close closes the TUI driver and session
|
||||
func (d *TUIDriver) Close() error {
|
||||
d.stdin.Close()
|
||||
_ = d.stdin.Close()
|
||||
return d.session.Close()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,41 +12,42 @@ import (
|
|||
|
||||
func main() {
|
||||
fixturesDir := filepath.Join("tests", "fixtures", "consistency")
|
||||
|
||||
|
||||
// Load current expected hashes
|
||||
expectedPath := filepath.Join(fixturesDir, "dataset_hash", "expected_hashes.json")
|
||||
// #nosec G304 -- path is a hardcoded test fixture path
|
||||
data, err := os.ReadFile(expectedPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to read expected hashes: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
||||
var expected consistency.ExpectedHashes
|
||||
if err := json.Unmarshal(data, &expected); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to parse expected hashes: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
||||
// Use Go implementation as reference
|
||||
goImpl := consistency.NewGoImpl()
|
||||
|
||||
|
||||
updated := false
|
||||
for i, fixture := range expected.Fixtures {
|
||||
fixturePath := filepath.Join(fixturesDir, "dataset_hash", fixture.ID, "input")
|
||||
|
||||
|
||||
// Check if fixture exists
|
||||
if _, err := os.Stat(fixturePath); os.IsNotExist(err) {
|
||||
fmt.Printf("Skipping %s: fixture not found at %s\n", fixture.ID, fixturePath)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Compute hash using reference implementation
|
||||
hash, err := goImpl.HashDataset(fixturePath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error hashing %s: %v\n", fixture.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Update if different or TODO
|
||||
if fixture.ExpectedHash == "TODO_COMPUTE" {
|
||||
fmt.Printf("%s: computed %s\n", fixture.ID, hash)
|
||||
|
|
@ -59,7 +60,7 @@ func main() {
|
|||
} else {
|
||||
fmt.Printf("%s: unchanged (%s)\n", fixture.ID, hash)
|
||||
}
|
||||
|
||||
|
||||
// Compute individual file hashes
|
||||
for j, file := range fixture.Files {
|
||||
if file.ContentHash == "TODO" || file.ContentHash == "" {
|
||||
|
|
@ -75,23 +76,23 @@ func main() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !updated {
|
||||
fmt.Println("\nNo updates needed.")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Write updated hashes
|
||||
output, err := json.MarshalIndent(expected, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to marshal updated hashes: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(expectedPath, output, 0644); err != nil {
|
||||
|
||||
if err := os.WriteFile(expectedPath, output, 0600); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write updated hashes: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
||||
fmt.Println("\nUpdated expected_hashes.json")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
|
|
@ -64,6 +64,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
nil, // groupsHandler
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
@ -149,7 +150,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
|
|
@ -165,6 +166,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
nil, // groupsHandler
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
@ -254,7 +256,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
|
|
@ -270,6 +272,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
nil, // groupsHandler
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
|
|||
nil, // jobsHandler
|
||||
nil, // jupyterHandler
|
||||
nil, // datasetsHandler
|
||||
nil, // groupsHandler
|
||||
)
|
||||
server := httptest.NewServer(handler)
|
||||
return server, tq, expManager, s, db
|
||||
|
|
@ -601,6 +602,7 @@ func setupWSIntegrationServer(t *testing.T) (
|
|||
nil, // jobsHandler
|
||||
nil, // jupyterHandler
|
||||
nil, // datasetsHandler
|
||||
nil, // groupsHandler
|
||||
)
|
||||
// Setup test server
|
||||
server := httptest.NewServer(handler)
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ func TestNewWSHandler(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil)
|
||||
|
||||
if handler == nil {
|
||||
t.Error("Expected non-nil WSHandler")
|
||||
|
|
@ -61,11 +61,11 @@ func TestWSHandlerWebSocketUpgrade(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil)
|
||||
|
||||
// Create a test HTTP request
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
|
|
@ -101,11 +101,11 @@ func TestWSHandlerInvalidRequest(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil)
|
||||
|
||||
// Create a test HTTP request without WebSocket headers
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
|
|
@ -129,11 +129,11 @@ func TestWSHandlerPostRequest(t *testing.T) {
|
|||
authConfig := &auth.Config{}
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
expManager := experiment.NewManager("/tmp")
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler, nil)
|
||||
|
||||
// Create a POST request (should fail)
|
||||
req := httptest.NewRequest("POST", "/ws", strings.NewReader("data"))
|
||||
|
|
|
|||
|
|
@ -1,130 +0,0 @@
|
|||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/middleware"
|
||||
)
|
||||
|
||||
func TestPrivacyEnforcer_CanAccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *auth.User
|
||||
owner string
|
||||
level string
|
||||
team string
|
||||
enforceTeams bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "owner can access private",
|
||||
user: &auth.User{Name: "alice"},
|
||||
owner: "alice",
|
||||
level: "private",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-owner cannot access private",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "private",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "admin can access private",
|
||||
user: &auth.User{Name: "admin", Admin: true},
|
||||
owner: "alice",
|
||||
level: "private",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public allows all",
|
||||
user: &auth.User{Name: "anyone"},
|
||||
owner: "alice",
|
||||
level: "public",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "owner can access team",
|
||||
user: &auth.User{Name: "alice"},
|
||||
owner: "alice",
|
||||
level: "team",
|
||||
team: "research",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-owner denied team when enforcing",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "team",
|
||||
team: "research",
|
||||
enforceTeams: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-owner allowed team when not enforcing",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "team",
|
||||
team: "research",
|
||||
enforceTeams: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "anonymized allows all",
|
||||
user: &auth.User{Name: "anyone"},
|
||||
owner: "alice",
|
||||
level: "anonymized",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "unknown level defaults to private (deny)",
|
||||
user: &auth.User{Name: "bob"},
|
||||
owner: "alice",
|
||||
level: "unknown",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pe := middleware.NewPrivacyEnforcer(tt.enforceTeams, false)
|
||||
got, err := pe.CanAccess(ctx, tt.user, tt.owner, tt.level, tt.team)
|
||||
if err != nil {
|
||||
t.Errorf("CanAccess() error = %v", err)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("CanAccess() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrivacyLevelFromString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected middleware.PrivacyLevel
|
||||
}{
|
||||
{"private", middleware.PrivacyPrivate},
|
||||
{"team", middleware.PrivacyTeam},
|
||||
{"public", middleware.PrivacyPublic},
|
||||
{"anonymized", middleware.PrivacyAnonymized},
|
||||
{"unknown", middleware.PrivacyPrivate}, // Default
|
||||
{"", middleware.PrivacyPrivate}, // Default
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := middleware.GetPrivacyLevelFromString(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("GetPrivacyLevelFromString(%q) = %v, want %v",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -243,10 +243,10 @@ func TestTemplateVariableExpansion(t *testing.T) {
|
|||
}
|
||||
assert.True(t, hasServicePort, "Command should contain {{SERVICE_PORT}} template variable")
|
||||
|
||||
// Check env contains secret template
|
||||
// Check env contains token template (used for secret generation)
|
||||
val, ok := template.Env["JUPYTER_TOKEN"]
|
||||
assert.True(t, ok, "Should have JUPYTER_TOKEN env var")
|
||||
assert.Contains(t, val, "{{SECRET:", "Should use secret template")
|
||||
assert.Contains(t, val, "{{TOKEN:", "Should use token template for secret generation")
|
||||
}
|
||||
|
||||
// BenchmarkPortAllocation benchmarks port allocation performance
|
||||
|
|
|
|||
Loading…
Reference in a new issue