package tests

import (
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
)

// setupTestServer starts an HTTP server that serves the WebSocket handler on
// an OS-assigned loopback port and returns the listener's "host:port" address.
//
// The handler is built with authentication disabled and an experiment manager
// rooted in a per-test temporary directory. The server is shut down through
// t.Cleanup, so callers do not need to stop it themselves.
func setupTestServer(t *testing.T) string {
	t.Helper()

	logger := logging.NewLogger(slog.LevelInfo, false)
	authConfig := &auth.Config{Enabled: false}
	expManager := experiment.NewManager(t.TempDir())
	wsHandler := api.NewWSHandler(authConfig, logger, expManager, nil)

	// Listen on port 0 so the OS picks a free port; the real address is read
	// back from the listener. This keeps parallel tests from colliding.
	listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create listener: %v", err)
	}
	addr := listener.Addr().String()

	server := &http.Server{
		Handler:           wsHandler,
		ReadHeaderTimeout: 5 * time.Second,
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if nobody is receiving at that moment.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- server.Serve(listener)
	}()

	// Give Serve a short window to fail fast; no error within the window is
	// treated as "server is up".
	select {
	case err := <-serverErr:
		t.Fatalf("Failed to start server: %v", err)
	case <-time.After(100 * time.Millisecond):
		// Server should be ready
	}

	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = server.Shutdown(ctx) // best effort: the test is already finishing
		<-serverErr              // wait for the Serve goroutine to return
	})

	return addr
}

// dialTestServer opens a WebSocket connection to the "/ws" endpoint at addr.
//
// The handshake response body is closed before returning. On failure the
// returned connection is nil and err describes the dial error. The helper
// never touches *testing.T, so it is safe to call from goroutines other than
// the one running the test.
func dialTestServer(addr string) (*websocket.Conn, error) {
	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
	conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp != nil && resp.Body != nil {
		_ = resp.Body.Close()
	}
	return conn, err
}

// TestWebSocketRealConnection connects to a live server, sends a status
// request followed by an invalid text message, and logs how the server
// responds. Only connection and write failures fail the test; read outcomes
// are logged because the server is not required to answer.
func TestWebSocketRealConnection(t *testing.T) {
	t.Parallel()

	addr := setupTestServer(t)

	// Test 1: Basic WebSocket connection
	conn, err := dialTestServer(addr)
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", err)
	}
	defer func() { _ = conn.Close() }()

	t.Log("Successfully established WebSocket connection")

	// Test 2: Send a status request
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err := conn.WriteMessage(websocket.BinaryMessage, []byte{0x02, 0x00}); err != nil {
		t.Fatalf("Failed to send status request: %v", err)
	}

	// Test 3: Read response with timeout
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	messageType, message, err := conn.ReadMessage()
	var netErr net.Error
	switch {
	case err == nil:
		t.Logf("Received message type %d: %s", messageType, string(message))
	case errors.As(err, &netErr) && netErr.Timeout():
		// errors.As also matches a timeout that has been wrapped.
		t.Log("Read timeout - this is expected for status request")
	default:
		t.Logf("Failed to read message: %v", err)
	}

	// Test 4: Send invalid message
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err := conn.WriteMessage(websocket.TextMessage, []byte("invalid")); err != nil {
		t.Fatalf("Failed to send invalid message: %v", err)
	}

	// Try to read response (may get error due to server closing connection)
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	if _, _, err := conn.ReadMessage(); err != nil {
		if websocket.IsCloseError(err, websocket.ClosePolicyViolation) {
			t.Log("Server correctly closed connection due to invalid message")
		} else {
			t.Logf("Server handled invalid message with error: %v", err)
		}
	}
}

// TestWebSocketBinaryProtocol sends one binary frame laid out as
// [opcode][4-byte big-endian payload length][JSON payload] and logs whatever
// the server sends back, if anything.
func TestWebSocketBinaryProtocol(t *testing.T) {
	t.Parallel()

	// setupTestServer returns only after the listener is bound, so no extra
	// wait is needed before dialing.
	addr := setupTestServer(t)

	conn, err := dialTestServer(addr)
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Send binary message with queue job opcode
	jobData := map[string]any{
		"job_id":    "test-job-1",
		"commit_id": "abc123",
		"user":      "testuser",
		"script":    "train.py",
	}
	dataBytes, err := json.Marshal(jobData)
	if err != nil {
		t.Fatalf("Failed to marshal job data: %v", err)
	}

	// Create binary message: [opcode][data_length][data]
	binaryMessage := make([]byte, 1+4+len(dataBytes))
	binaryMessage[0] = 0x01 // OpcodeQueueJob
	binary.BigEndian.PutUint32(binaryMessage[1:5], uint32(len(dataBytes)))
	copy(binaryMessage[5:], dataBytes)

	// Bound the write so a stalled server fails the test instead of hanging it.
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err := conn.WriteMessage(websocket.BinaryMessage, binaryMessage); err != nil {
		t.Fatalf("Failed to send binary message: %v", err)
	}

	t.Log("Successfully sent binary queue job message")

	// Read response (if any)
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, message, err := conn.ReadMessage()
	switch {
	case err == nil:
		t.Logf("Received response to binary message: %s", string(message))
	case websocket.IsCloseError(err, websocket.CloseNormalClosure):
		t.Log("Connection closed normally after binary message")
	default:
		t.Logf("No response received (expected): %v", err)
	}
}

// TestWebSocketConcurrentConnections dials several connections at the same
// time and requires every one of them to succeed.
func TestWebSocketConcurrentConnections(t *testing.T) {
	t.Parallel()

	addr := setupTestServer(t)

	// Multiple concurrent connections
	const numConnections = 5
	connections := make([]*websocket.Conn, numConnections)
	dialErrs := make([]error, numConnections)

	// Dial from separate goroutines so the handshakes really overlap. Each
	// goroutine writes only to its own slice index, so no locking is needed.
	var wg sync.WaitGroup
	for i := range numConnections {
		wg.Add(1)
		go func() {
			defer wg.Done()
			connections[i], dialErrs[i] = dialTestServer(addr)
		}()
	}
	wg.Wait()

	// Results are reported from the test goroutine, after all dials finish.
	successCount := 0
	for i, conn := range connections {
		if err := dialErrs[i]; err != nil {
			t.Errorf("Connection %d failed: %v", i, err)
		}
		if conn != nil {
			successCount++
			_ = conn.Close()
		}
	}

	if successCount != numConnections {
		t.Errorf("Expected %d successful connections, got %d", numConnections, successCount)
	}

	t.Logf("Successfully established %d concurrent WebSocket connections", successCount)
}

// TestWebSocketConnectionResilience checks that a client can connect, send a
// message, disconnect, and then connect and send again against the same
// server.
func TestWebSocketConnectionResilience(t *testing.T) {
	t.Parallel()

	addr := setupTestServer(t)

	// First connection
	conn1, err := dialTestServer(addr)
	if err != nil {
		t.Fatalf("Failed to establish first connection: %v", err)
	}

	// Send a message
	err = conn1.WriteJSON(map[string]any{
		"opcode": 0x02,
		"data":   "",
	})

	// Close the first connection before checking the write result, so a
	// failed write does not leak the connection.
	_ = conn1.Close()
	if err != nil {
		t.Fatalf("Failed to send message on first connection: %v", err)
	}

	// Wait a moment
	time.Sleep(100 * time.Millisecond)

	// Reconnect
	conn2, err := dialTestServer(addr)
	if err != nil {
		t.Fatalf("Failed to reconnect: %v", err)
	}
	defer func() { _ = conn2.Close() }()

	// Send message on reconnected connection
	err = conn2.WriteJSON(map[string]any{
		"opcode": 0x02,
		"data":   "",
	})
	if err != nil {
		t.Fatalf("Failed to send message on reconnected connection: %v", err)
	}

	t.Log("Successfully tested connection resilience and reconnection")
}