package tests import ( "context" "log/slog" "net" "net/http" "net/url" "testing" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/api/datasets" "github.com/jfraeys/fetch_ml/internal/api/jobs" jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" ) // setupTestServer creates a test server with WebSocket handler and returns the address func setupTestServer(t *testing.T) string { logger := logging.NewLogger(slog.LevelInfo, false) authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) // Create listener to get actual port listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Failed to create listener: %v", err) } addr := listener.Addr().String() server := &http.Server{ Handler: wsHandler, ReadHeaderTimeout: 5 * time.Second, } // Start server serverErr := make(chan error, 1) go func() { serverErr <- server.Serve(listener) }() // Wait for server to start select { case err := <-serverErr: t.Fatalf("Failed to start server: %v", err) case <-time.After(100 * time.Millisecond): // Server should be ready } t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = server.Shutdown(ctx) <-serverErr // Wait for server to stop }) return addr } func TestWebSocketRealConnection(t *testing.T) { t.Parallel() // Enable parallel execution // Setup test server addr := setupTestServer(t) // Test 1: Basic WebSocket connection u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"} conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) if resp != nil && resp.Body != nil { defer func() { _ = resp.Body.Close() }() } if err != nil { t.Fatalf("Failed to connect to WebSocket: %v", err) } defer func() { _ = conn.Close() }() t.Log("Successfully established WebSocket connection") // Test 2: Send a status request _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // New protocol: [opcode:1][api_key_hash:16] statusMsg := []byte{0x02} // opcode statusMsg = append(statusMsg, make([]byte, 16)...) // 16-byte API key hash err = conn.WriteMessage(websocket.BinaryMessage, statusMsg) if err != nil { t.Fatalf("Failed to send status request: %v", err) } // Test 3: Read response with timeout _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) messageType, message, err := conn.ReadMessage() if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { t.Log("Read timeout - this is expected for status request") } else { t.Logf("Failed to read message: %v", err) } } else { t.Logf("Received message type %d: %s", messageType, string(message)) } // Test 4: Send invalid message _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, []byte("invalid")) if err != nil { t.Fatalf("Failed to send invalid message: %v", err) } // Try to read response (may get error due to server closing connection) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _, _, err = conn.ReadMessage() if err != nil { if websocket.IsCloseError(err, websocket.ClosePolicyViolation) { t.Log("Server correctly closed connection due to invalid message") } else { t.Logf("Server handled invalid message with error: %v", err) } } } func TestWebSocketBinaryProtocol(t *testing.T) { t.Parallel() // Enable parallel execution // Setup test server addr := setupTestServer(t) time.Sleep(100 * time.Millisecond) // Connect to WebSocket u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"} conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) if resp != nil && resp.Body != nil { defer func() { _ = resp.Body.Close() }() } if err != nil { t.Fatalf("Failed to connect to WebSocket: %v", err) } defer func() { _ = conn.Close() }() // Test 4: Send binary message with queue job opcode using new protocol // Create binary message with new protocol: // [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] jobName := "test_job" commitID := "aaaaaaaaaaaaaaaaaaaa" // 20-byte commit ID binaryMessage := []byte{0x01} // OpcodeQueueJob binaryMessage = append(binaryMessage, make([]byte, 16)...) // 16-byte API key hash binaryMessage = append(binaryMessage, []byte(commitID)...) // 20-byte commit ID binaryMessage = append(binaryMessage, 5) // priority binaryMessage = append(binaryMessage, byte(len(jobName))) // job name length binaryMessage = append(binaryMessage, []byte(jobName)...) // job name err = conn.WriteMessage(websocket.BinaryMessage, binaryMessage) if err != nil { t.Fatalf("Failed to send binary message: %v", err) } t.Log("Successfully sent binary queue job message") // Read response (if any) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _, message, err := conn.ReadMessage() if err != nil { if websocket.IsCloseError(err, websocket.CloseNormalClosure) { t.Log("Connection closed normally after binary message") } else { t.Logf("No response received (expected): %v", err) } } else { t.Logf("Received response to binary message: %s", string(message)) } } func TestWebSocketConcurrentConnections(t *testing.T) { t.Parallel() // Enable parallel execution // Setup test server addr := setupTestServer(t) // Test 5: Multiple concurrent connections numConnections := 5 connections := make([]*websocket.Conn, numConnections) errors := make([]error, numConnections) // Create multiple connections for i := range numConnections { u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"} conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) if resp != nil && resp.Body != nil { _ = resp.Body.Close() } if err != nil { errors[i] = err continue } connections[i] = conn } // Close all connections for _, conn := range connections { if conn != nil { _ = conn.Close() } } // Verify all connections succeeded for i, err := range errors { if err != nil { t.Errorf("Connection %d failed: %v", i, err) } } successCount := 0 for _, conn := range connections { if conn != nil { successCount++ } } if successCount != numConnections { t.Errorf("Expected %d successful connections, got %d", numConnections, successCount) } t.Logf("Successfully established %d concurrent WebSocket connections", successCount) } func TestWebSocketConnectionResilience(t *testing.T) { t.Parallel() // Enable parallel execution // Setup test server addr := setupTestServer(t) // Test 6: Connection resilience and reconnection u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"} // First connection conn1, resp1, err := websocket.DefaultDialer.Dial(u.String(), nil) if resp1 != nil && resp1.Body != nil { defer func() { _ = resp1.Body.Close() }() } if err != nil { t.Fatalf("Failed to establish first connection: %v", err) } // Send a message err = conn1.WriteJSON(map[string]interface{}{ "opcode": 0x02, "data": "", }) if err != nil { t.Fatalf("Failed to send message on first connection: %v", err) } // Close first connection _ = conn1.Close() // Wait a moment time.Sleep(100 * time.Millisecond) // Reconnect conn2, resp2, err := websocket.DefaultDialer.Dial(u.String(), nil) if resp2 != nil && resp2.Body != nil { defer func() { _ = resp2.Body.Close() }() } if err != nil { t.Fatalf("Failed to reconnect: %v", err) } defer func() { _ = conn2.Close() }() // Send message on reconnected connection err = conn2.WriteJSON(map[string]interface{}{ "opcode": 0x02, "data": "", }) if err != nil { t.Fatalf("Failed to send message on reconnected connection: %v", err) } t.Log("Successfully tested connection resilience and reconnection") }