fetch_ml/tests/e2e/websocket_e2e_test.go
Jeremie Fraeys c980167041 test: implement comprehensive test suite with multiple test types
- Add end-to-end tests for complete workflow validation
- Include integration tests for API and database interactions
- Add unit tests for all major components and utilities
- Include performance tests for payload handling
- Add CLI API integration tests
- Include Podman container integration tests
- Add WebSocket and queue execution tests
- Include shell script tests for setup validation

Provides comprehensive test coverage ensuring platform reliability
and functionality across all components and interactions.
2025-12-04 16:55:13 -05:00

275 lines
7 KiB
Go

package tests
import (
"context"
"encoding/json"
"log/slog"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// setupTestServer starts an HTTP server that serves the WebSocket handler on
// an ephemeral loopback port and returns the server together with its
// "host:port" address.
//
// The handler is built with authentication disabled, an experiment manager
// rooted in a per-test temporary directory, and a nil fourth argument.
// Shutdown is registered through t.Cleanup, so callers do not have to stop
// the server themselves. Any setup failure aborts the calling test.
func setupTestServer(t *testing.T) (*http.Server, string) {
	// Attribute setup failures to the calling test rather than to this helper.
	t.Helper()

	logger := logging.NewLogger(slog.LevelInfo, false)
	authConfig := &auth.AuthConfig{Enabled: false}
	expManager := experiment.NewManager(t.TempDir())
	wsHandler := api.NewWSHandler(authConfig, logger, expManager, nil)

	// Port 0 lets the OS pick a free port, which keeps parallel tests from
	// colliding; the real address is read back from the listener.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create listener: %v", err)
	}
	addr := listener.Addr().String()

	server := &http.Server{
		Handler: wsHandler,
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even when nobody is receiving at that moment.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- server.Serve(listener)
	}()

	// Serve only returns on failure or shutdown, so an error arriving within
	// this window means the server could not start.
	select {
	case err := <-serverErr:
		t.Fatalf("Failed to start server: %v", err)
	case <-time.After(100 * time.Millisecond):
		// No early error: treat the server as ready.
	}

	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			t.Logf("Failed to shut down server cleanly: %v", err)
		}
		// Wait for the serving goroutine to finish. http.ErrServerClosed is
		// the expected result after Shutdown; anything else is worth surfacing.
		if err := <-serverErr; err != nil && err != http.ErrServerClosed {
			t.Logf("Server stopped with unexpected error: %v", err)
		}
	})

	return server, addr
}
// TestWebSocketRealConnection dials the test server over a real TCP socket,
// sends a binary status request followed by an invalid text frame, and logs
// whatever the server answers. Only dial and write failures fail the test;
// read outcomes are reported through the test log.
func TestWebSocketRealConnection(t *testing.T) {
	t.Parallel()

	_, addr := setupTestServer(t)

	// Step 1: open the connection.
	endpoint := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
	conn, _, err := websocket.DefaultDialer.Dial(endpoint.String(), nil)
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", err)
	}
	defer conn.Close()
	t.Log("Successfully established WebSocket connection")

	// Step 2: send the two-byte status request as a binary frame.
	statusRequest := []byte{0x02, 0x00}
	conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err := conn.WriteMessage(websocket.BinaryMessage, statusRequest); err != nil {
		t.Fatalf("Failed to send status request: %v", err)
	}

	// Step 3: wait up to five seconds for a reply. A timeout or any other
	// read error is logged, not treated as a failure.
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	msgType, payload, readErr := conn.ReadMessage()
	netErr, isNetErr := readErr.(net.Error)
	switch {
	case readErr == nil:
		t.Logf("Received message type %d: %s", msgType, string(payload))
	case isNetErr && netErr.Timeout():
		t.Log("Read timeout - this is expected for status request")
	default:
		t.Logf("Failed to read message: %v", readErr)
	}

	// Step 4: send a text frame the handler is not expected to accept.
	conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err := conn.WriteMessage(websocket.TextMessage, []byte("invalid")); err != nil {
		t.Fatalf("Failed to send invalid message: %v", err)
	}

	// The server may close the connection in response, so a read error here
	// is only logged.
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, _, err = conn.ReadMessage()
	switch {
	case err == nil:
		// A reply arrived; nothing to report.
	case websocket.IsCloseError(err, websocket.ClosePolicyViolation):
		t.Log("Server correctly closed connection due to invalid message")
	default:
		t.Logf("Server handled invalid message with error: %v", err)
	}
}
// TestWebSocketBinaryProtocol sends a single binary queue-job frame to the
// test server and logs the reply, if there is one. The test fails only when
// the connection cannot be established or the frame cannot be built or sent;
// the absence of a reply is logged and tolerated.
func TestWebSocketBinaryProtocol(t *testing.T) {
	t.Parallel()

	// setupTestServer returns only after the listener is bound and its own
	// startup wait has elapsed, so no additional sleep is needed before dialing.
	_, addr := setupTestServer(t)

	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
	conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", err)
	}
	defer conn.Close()

	jobData := map[string]any{
		"job_id":    "test-job-1",
		"commit_id": "abc123",
		"user":      "testuser",
		"script":    "train.py",
	}
	dataBytes, err := json.Marshal(jobData)
	if err != nil {
		t.Fatalf("Failed to marshal job data: %v", err)
	}

	// Frame layout: [opcode (1 byte)][payload length (4 bytes, big endian)][payload].
	binaryMessage := make([]byte, 1+4+len(dataBytes))
	binaryMessage[0] = 0x01 // OpcodeQueueJob
	binaryMessage[1] = byte(len(dataBytes) >> 24)
	binaryMessage[2] = byte(len(dataBytes) >> 16)
	binaryMessage[3] = byte(len(dataBytes) >> 8)
	binaryMessage[4] = byte(len(dataBytes))
	copy(binaryMessage[5:], dataBytes)

	// Bound the write so a stalled server cannot hang the test.
	if err := conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set write deadline: %v", err)
	}
	if err := conn.WriteMessage(websocket.BinaryMessage, binaryMessage); err != nil {
		t.Fatalf("Failed to send binary message: %v", err)
	}
	t.Log("Successfully sent binary queue job message")

	// A reply is optional: wait briefly and log whatever happens.
	if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}
	_, message, err := conn.ReadMessage()
	switch {
	case err == nil:
		t.Logf("Received response to binary message: %s", string(message))
	case websocket.IsCloseError(err, websocket.CloseNormalClosure):
		t.Log("Connection closed normally after binary message")
	default:
		t.Logf("No response received (expected): %v", err)
	}
}
// TestWebSocketConcurrentConnections opens several WebSocket connections to
// the same test server, one after another, holds them open together, then
// closes them and reports every dial that failed.
func TestWebSocketConcurrentConnections(t *testing.T) {
	t.Parallel()

	_, addr := setupTestServer(t)

	const total = 5
	endpoint := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
	target := endpoint.String()

	// Slot i holds either the open connection or the dial error for attempt i.
	conns := make([]*websocket.Conn, total)
	dialErrs := make([]error, total)
	for idx := range conns {
		c, _, err := websocket.DefaultDialer.Dial(target, nil)
		if err == nil {
			conns[idx] = c
		} else {
			dialErrs[idx] = err
		}
	}

	// Release every connection that was opened, counting them on the way.
	opened := 0
	for _, c := range conns {
		if c == nil {
			continue
		}
		opened++
		c.Close()
	}

	// Report each failed attempt individually, then the overall tally.
	for idx, dialErr := range dialErrs {
		if dialErr != nil {
			t.Errorf("Connection %d failed: %v", idx, dialErr)
		}
	}
	if opened != total {
		t.Errorf("Expected %d successful connections, got %d", total, opened)
	}
	t.Logf("Successfully established %d concurrent WebSocket connections", opened)
}
// TestWebSocketConnectionResilience checks that a client can connect, send a
// message, disconnect, and then connect and send again against the same
// server. It fails if either dial or either write returns an error.
func TestWebSocketConnectionResilience(t *testing.T) {
	t.Parallel()

	_, addr := setupTestServer(t)

	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}

	// The same JSON payload is sent on both connections.
	statusMsg := map[string]any{
		"opcode": 0x02,
		"data":   "",
	}

	// First connection.
	conn1, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		t.Fatalf("Failed to establish first connection: %v", err)
	}
	if err := conn1.WriteJSON(statusMsg); err != nil {
		// Close before aborting so the failure path does not leak the socket.
		conn1.Close()
		t.Fatalf("Failed to send message on first connection: %v", err)
	}
	if err := conn1.Close(); err != nil {
		t.Logf("Failed to close first connection: %v", err)
	}

	// Pause briefly between the disconnect and the reconnect attempt.
	time.Sleep(100 * time.Millisecond)

	// Second connection to the same endpoint.
	conn2, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		t.Fatalf("Failed to reconnect: %v", err)
	}
	defer conn2.Close()

	if err := conn2.WriteJSON(statusMsg); err != nil {
		t.Fatalf("Failed to send message on reconnected connection: %v", err)
	}
	t.Log("Successfully tested connection resilience and reconnection")
}