fetch_ml/tests/e2e/websocket_e2e_test.go
Jeremie Fraeys ea15af1833 Fix multi-user authentication and clean up debug code
- Fix YAML tags in auth config struct (json -> yaml)
- Update CLI configs to use pre-hashed API keys
- Remove double hashing in WebSocket client
- Fix port mapping (9102 -> 9103) in CLI commands
- Update permission keys to use jobs:read, jobs:create, etc.
- Clean up all debug logging from CLI and server
- All user roles now authenticate correctly:
  * Admin: Can queue jobs and see all jobs
  * Researcher: Can queue jobs and see own jobs
  * Analyst: Can see status (read-only access)

Multi-user authentication is now fully functional.
2025-12-06 12:35:32 -05:00

291 lines
7.6 KiB
Go

package tests
import (
"context"
"encoding/json"
"log/slog"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// setupTestServer starts an HTTP server serving the WebSocket handler on an
// ephemeral 127.0.0.1 port and returns its "host:port" address. Auth is
// disabled and the experiment manager is rooted in a per-test temporary
// directory. The server is shut down via t.Cleanup when the test finishes,
// and an unexpected early exit of Serve is reported as a test error.
func setupTestServer(t *testing.T) string {
	t.Helper()

	logger := logging.NewLogger(slog.LevelInfo, false)
	authConfig := &auth.Config{Enabled: false}
	expManager := experiment.NewManager(t.TempDir())
	wsHandler := api.NewWSHandler(authConfig, logger, expManager, nil)

	// Bind the listener up front to learn the actual port. The socket is
	// already listening when this returns, so clients may dial immediately:
	// connections wait in the accept backlog until Serve picks them up, which
	// makes a fixed startup sleep unnecessary.
	listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create listener: %v", err)
	}
	addr := listener.Addr().String()

	server := &http.Server{
		Handler:           wsHandler,
		ReadHeaderTimeout: 5 * time.Second,
	}

	// Serve blocks until the server stops; its result is collected in cleanup.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- server.Serve(listener)
	}()

	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			t.Errorf("Failed to shut down server: %v", err)
		}
		// Serve returns ErrServerClosed after Shutdown; any other value means
		// it stopped on its own (e.g. Accept failed) while tests were running.
		if err := <-serverErr; err != http.ErrServerClosed {
			t.Errorf("Server stopped unexpectedly: %v", err)
		}
	})

	return addr
}
// TestWebSocketRealConnection drives a real client connection against the
// test server. It sends a binary status request and then an invalid text
// message. Read outcomes are logged rather than asserted.
func TestWebSocketRealConnection(t *testing.T) {
	t.Parallel() // Enable parallel execution

	addr := setupTestServer(t)
	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}

	// dial opens a connection to the test server, failing the test on a
	// handshake error, and closes the connection when the test ends.
	dial := func() *websocket.Conn {
		t.Helper()
		conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		if err != nil {
			t.Fatalf("Failed to connect to WebSocket: %v", err)
		}
		t.Cleanup(func() { _ = conn.Close() })
		return conn
	}

	// Test 1: Basic WebSocket connection
	conn := dial()
	t.Log("Successfully established WebSocket connection")

	// Test 2: Send a status request
	if err := conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set write deadline: %v", err)
	}
	if err := conn.WriteMessage(websocket.BinaryMessage, []byte{0x02, 0x00}); err != nil {
		t.Fatalf("Failed to send status request: %v", err)
	}

	// Test 3: Read response with timeout
	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}
	messageType, message, err := conn.ReadMessage()
	if err != nil {
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			t.Log("Read timeout - this is expected for status request")
		} else {
			t.Logf("Failed to read message: %v", err)
		}
		// Read errors on a gorilla/websocket Conn are permanent: every later
		// ReadMessage returns this same error. Continue on a fresh connection
		// so Test 4 sees the server's reaction instead of a stale error.
		conn = dial()
	} else {
		t.Logf("Received message type %d: %s", messageType, string(message))
	}

	// Test 4: Send invalid message
	if err := conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set write deadline: %v", err)
	}
	if err := conn.WriteMessage(websocket.TextMessage, []byte("invalid")); err != nil {
		t.Fatalf("Failed to send invalid message: %v", err)
	}

	// Try to read response (may get error due to server closing connection)
	if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}
	if _, _, err := conn.ReadMessage(); err != nil {
		if websocket.IsCloseError(err, websocket.ClosePolicyViolation) {
			t.Log("Server correctly closed connection due to invalid message")
		} else {
			t.Logf("Server handled invalid message with error: %v", err)
		}
	}
}
// TestWebSocketBinaryProtocol sends a length-prefixed binary "queue job"
// frame and logs whatever the server replies, if anything.
func TestWebSocketBinaryProtocol(t *testing.T) {
	t.Parallel() // Enable parallel execution

	// Setup test server. Its listener is already bound when this returns,
	// so no extra wait is needed before dialing.
	addr := setupTestServer(t)

	// Connect to WebSocket
	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
	conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Test 4: Send binary message with queue job opcode
	jobData := map[string]any{
		"job_id":    "test-job-1",
		"commit_id": "abc123",
		"user":      "testuser",
		"script":    "train.py",
	}
	dataBytes, err := json.Marshal(jobData)
	if err != nil {
		t.Fatalf("Failed to marshal job data: %v", err)
	}

	// Frame layout: [opcode:1 byte][data_length:4 bytes, big endian][data]
	n := len(dataBytes)
	binaryMessage := make([]byte, 0, 1+4+n)
	binaryMessage = append(binaryMessage,
		0x01, // OpcodeQueueJob
		byte(n>>24), byte(n>>16), byte(n>>8), byte(n),
	)
	binaryMessage = append(binaryMessage, dataBytes...)

	if err := conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set write deadline: %v", err)
	}
	if err := conn.WriteMessage(websocket.BinaryMessage, binaryMessage); err != nil {
		t.Fatalf("Failed to send binary message: %v", err)
	}
	t.Log("Successfully sent binary queue job message")

	// Read response (if any)
	if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}
	_, message, err := conn.ReadMessage()
	if err != nil {
		if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
			t.Log("Connection closed normally after binary message")
		} else {
			t.Logf("No response received (expected): %v", err)
		}
		return
	}
	t.Logf("Received response to binary message: %s", string(message))
}
// TestWebSocketConcurrentConnections dials several WebSocket connections at
// the same time and checks that every handshake succeeds.
func TestWebSocketConcurrentConnections(t *testing.T) {
	t.Parallel() // Enable parallel execution

	// Setup test server
	addr := setupTestServer(t)

	// Test 5: Multiple concurrent connections
	const numConnections = 5
	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}

	type dialResult struct {
		index int
		conn  *websocket.Conn
		err   error
	}
	// Buffered to numConnections so no dialing goroutine ever blocks on send.
	results := make(chan dialResult, numConnections)

	// Dial each connection from its own goroutine so the handshakes overlap.
	for i := range numConnections {
		go func() {
			conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
			if resp != nil && resp.Body != nil {
				_ = resp.Body.Close()
			}
			results <- dialResult{index: i, conn: conn, err: err}
		}()
	}

	// Collect exactly one result per goroutine; afterwards none are running.
	connections := make([]*websocket.Conn, numConnections)
	dialErrs := make([]error, numConnections)
	for range numConnections {
		r := <-results
		connections[r.index] = r.conn
		dialErrs[r.index] = r.err
	}

	// Close all connections
	for _, conn := range connections {
		if conn != nil {
			_ = conn.Close()
		}
	}

	// Verify all connections succeeded
	for i, err := range dialErrs {
		if err != nil {
			t.Errorf("Connection %d failed: %v", i, err)
		}
	}

	successCount := 0
	for _, conn := range connections {
		if conn != nil {
			successCount++
		}
	}
	if successCount != numConnections {
		t.Errorf("Expected %d successful connections, got %d", numConnections, successCount)
	}
	t.Logf("Successfully established %d concurrent WebSocket connections", successCount)
}
// TestWebSocketConnectionResilience sends a message, closes the connection,
// reconnects to the same server, and sends again on the new connection.
func TestWebSocketConnectionResilience(t *testing.T) {
	t.Parallel() // Enable parallel execution

	// Setup test server
	addr := setupTestServer(t)

	// Test 6: Connection resilience and reconnection
	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}

	// sendStatus writes the JSON status message on c under a bounded write
	// deadline so a stalled peer cannot hang the test.
	sendStatus := func(c *websocket.Conn) error {
		if err := c.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
			return err
		}
		return c.WriteJSON(map[string]any{
			"opcode": 0x02,
			"data":   "",
		})
	}

	// First connection
	conn1, resp1, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp1 != nil && resp1.Body != nil {
		defer func() { _ = resp1.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("Failed to establish first connection: %v", err)
	}

	// Send a message, then close the first connection before checking the
	// result so a failed write does not leak it.
	err = sendStatus(conn1)
	_ = conn1.Close()
	if err != nil {
		t.Fatalf("Failed to send message on first connection: %v", err)
	}

	// Wait a moment before reconnecting
	time.Sleep(100 * time.Millisecond)

	// Reconnect
	conn2, resp2, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp2 != nil && resp2.Body != nil {
		defer func() { _ = resp2.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("Failed to reconnect: %v", err)
	}
	defer func() { _ = conn2.Close() }()

	// Send message on reconnected connection
	if err := sendStatus(conn2); err != nil {
		t.Fatalf("Failed to send message on reconnected connection: %v", err)
	}
	t.Log("Successfully tested connection resilience and reconnection")
}