- Fix YAML tags in auth config struct (json -> yaml) - Update CLI configs to use pre-hashed API keys - Remove double hashing in WebSocket client - Fix port mapping (9102 -> 9103) in CLI commands - Update permission keys to use jobs:read, jobs:create, etc. - Clean up all debug logging from CLI and server - All user roles now authenticate correctly: * Admin: Can queue jobs and see all jobs * Researcher: Can queue jobs and see own jobs * Analyst: Can see status (read-only access) Multi-user authentication is now fully functional.
291 lines
7.6 KiB
Go
291 lines
7.6 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/api"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
// setupTestServer creates a test server with WebSocket handler and returns the address
|
|
func setupTestServer(t *testing.T) string {
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
authConfig := &auth.Config{Enabled: false}
|
|
expManager := experiment.NewManager(t.TempDir())
|
|
|
|
wsHandler := api.NewWSHandler(authConfig, logger, expManager, nil)
|
|
|
|
// Create listener to get actual port
|
|
listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create listener: %v", err)
|
|
}
|
|
addr := listener.Addr().String()
|
|
|
|
server := &http.Server{
|
|
Handler: wsHandler,
|
|
ReadHeaderTimeout: 5 * time.Second,
|
|
}
|
|
|
|
// Start server
|
|
serverErr := make(chan error, 1)
|
|
go func() {
|
|
serverErr <- server.Serve(listener)
|
|
}()
|
|
|
|
// Wait for server to start
|
|
select {
|
|
case err := <-serverErr:
|
|
t.Fatalf("Failed to start server: %v", err)
|
|
case <-time.After(100 * time.Millisecond):
|
|
// Server should be ready
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_ = server.Shutdown(ctx)
|
|
<-serverErr // Wait for server to stop
|
|
})
|
|
|
|
return addr
|
|
}
|
|
|
|
func TestWebSocketRealConnection(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
// Test 1: Basic WebSocket connection
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() { _ = resp.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to WebSocket: %v", err)
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
t.Log("Successfully established WebSocket connection")
|
|
|
|
// Test 2: Send a status request
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
err = conn.WriteMessage(websocket.BinaryMessage, []byte{0x02, 0x00})
|
|
if err != nil {
|
|
t.Fatalf("Failed to send status request: %v", err)
|
|
}
|
|
|
|
// Test 3: Read response with timeout
|
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
messageType, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
t.Log("Read timeout - this is expected for status request")
|
|
} else {
|
|
t.Logf("Failed to read message: %v", err)
|
|
}
|
|
} else {
|
|
t.Logf("Received message type %d: %s", messageType, string(message))
|
|
}
|
|
|
|
// Test 4: Send invalid message
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
err = conn.WriteMessage(websocket.TextMessage, []byte("invalid"))
|
|
if err != nil {
|
|
t.Fatalf("Failed to send invalid message: %v", err)
|
|
}
|
|
|
|
// Try to read response (may get error due to server closing connection)
|
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, _, err = conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsCloseError(err, websocket.ClosePolicyViolation) {
|
|
t.Log("Server correctly closed connection due to invalid message")
|
|
} else {
|
|
t.Logf("Server handled invalid message with error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWebSocketBinaryProtocol(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Connect to WebSocket
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() { _ = resp.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to WebSocket: %v", err)
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
// Test 4: Send binary message with queue job opcode
|
|
jobData := map[string]interface{}{
|
|
"job_id": "test-job-1",
|
|
"commit_id": "abc123",
|
|
"user": "testuser",
|
|
"script": "train.py",
|
|
}
|
|
|
|
dataBytes, _ := json.Marshal(jobData)
|
|
|
|
// Create binary message: [opcode][data_length][data]
|
|
binaryMessage := make([]byte, 1+4+len(dataBytes))
|
|
binaryMessage[0] = 0x01 // OpcodeQueueJob
|
|
|
|
// Add data length (big endian)
|
|
binaryMessage[1] = byte(len(dataBytes) >> 24)
|
|
binaryMessage[2] = byte(len(dataBytes) >> 16)
|
|
binaryMessage[3] = byte(len(dataBytes) >> 8)
|
|
binaryMessage[4] = byte(len(dataBytes))
|
|
|
|
// Add data
|
|
copy(binaryMessage[5:], dataBytes)
|
|
|
|
err = conn.WriteMessage(websocket.BinaryMessage, binaryMessage)
|
|
if err != nil {
|
|
t.Fatalf("Failed to send binary message: %v", err)
|
|
}
|
|
|
|
t.Log("Successfully sent binary queue job message")
|
|
|
|
// Read response (if any)
|
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
t.Log("Connection closed normally after binary message")
|
|
} else {
|
|
t.Logf("No response received (expected): %v", err)
|
|
}
|
|
} else {
|
|
t.Logf("Received response to binary message: %s", string(message))
|
|
}
|
|
}
|
|
|
|
func TestWebSocketConcurrentConnections(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
// Test 5: Multiple concurrent connections
|
|
numConnections := 5
|
|
connections := make([]*websocket.Conn, numConnections)
|
|
errors := make([]error, numConnections)
|
|
|
|
// Create multiple connections
|
|
for i := range numConnections {
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
if err != nil {
|
|
errors[i] = err
|
|
continue
|
|
}
|
|
connections[i] = conn
|
|
}
|
|
|
|
// Close all connections
|
|
for _, conn := range connections {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
// Verify all connections succeeded
|
|
for i, err := range errors {
|
|
if err != nil {
|
|
t.Errorf("Connection %d failed: %v", i, err)
|
|
}
|
|
}
|
|
|
|
successCount := 0
|
|
for _, conn := range connections {
|
|
if conn != nil {
|
|
successCount++
|
|
}
|
|
}
|
|
|
|
if successCount != numConnections {
|
|
t.Errorf("Expected %d successful connections, got %d", numConnections, successCount)
|
|
}
|
|
|
|
t.Logf("Successfully established %d concurrent WebSocket connections", successCount)
|
|
}
|
|
|
|
func TestWebSocketConnectionResilience(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
// Test 6: Connection resilience and reconnection
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
|
|
// First connection
|
|
conn1, resp1, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp1 != nil && resp1.Body != nil {
|
|
defer func() { _ = resp1.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to establish first connection: %v", err)
|
|
}
|
|
|
|
// Send a message
|
|
err = conn1.WriteJSON(map[string]interface{}{
|
|
"opcode": 0x02,
|
|
"data": "",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to send message on first connection: %v", err)
|
|
}
|
|
|
|
// Close first connection
|
|
_ = conn1.Close()
|
|
|
|
// Wait a moment
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Reconnect
|
|
conn2, resp2, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp2 != nil && resp2.Body != nil {
|
|
defer func() { _ = resp2.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to reconnect: %v", err)
|
|
}
|
|
defer func() { _ = conn2.Close() }()
|
|
|
|
// Send message on reconnected connection
|
|
err = conn2.WriteJSON(map[string]interface{}{
|
|
"opcode": 0x02,
|
|
"data": "",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to send message on reconnected connection: %v", err)
|
|
}
|
|
|
|
t.Log("Successfully tested connection resilience and reconnection")
|
|
}
|