fetch_ml/tests/e2e/websocket_e2e_test.go
Jeremie Fraeys dbf96020af
refactor(dependency-hygiene): Fix Redis leak, simplify TUI wrapper, clean go.mod
Phase 1: Fix Redis Schema Leak
- Create internal/storage/dataset.go with DatasetStore abstraction
- Remove all direct Redis calls from cmd/data_manager/data_sync.go
- data_manager now uses DatasetStore for transfer tracking and metadata

Phase 2: Simplify TUI Services
- Embed *queue.TaskQueue directly in services.TaskQueue
- Eliminate 60% of wrapper boilerplate (203 -> ~100 lines)
- Keep only TUI-specific methods (EnqueueTask, GetJobStatus, experiment methods)

Phase 5: Clean go.mod Dependencies
- Remove duplicate go-redis/redis/v8 dependency
- Migrate internal/storage/migrate.go to redis/go-redis/v9
- Separate test-only deps (miniredis, testify) into own block

Results:
- Zero direct Redis calls in cmd/
- 60% fewer lines in TUI services
- Cleaner dependency structure
2026-02-17 21:13:49 -05:00

290 lines
8.2 KiB
Go

package tests
import (
"context"
"log/slog"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// setupTestServer starts an HTTP server that serves the WebSocket handler on
// an OS-assigned loopback port and returns the listener address ("host:port").
//
// The handler is built with auth.Config{Enabled: false}, an experiment manager
// rooted in t.TempDir(), and nil/empty values for every other dependency.
// Shutdown is registered with t.Cleanup, so callers never stop the server
// themselves. Any setup failure aborts the calling test via t.Fatalf.
func setupTestServer(t *testing.T) string {
	t.Helper() // report setup failures at the caller's line, not inside this helper

	logger := logging.NewLogger(slog.LevelInfo, false)
	authConfig := &auth.Config{Enabled: false}
	expManager := experiment.NewManager(t.TempDir())
	jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)

	// Listen on port 0 so the OS picks a free port; the real address is read
	// back from the listener.
	listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create listener: %v", err)
	}
	addr := listener.Addr().String()

	server := &http.Server{
		Handler:           wsHandler,
		ReadHeaderTimeout: 5 * time.Second,
	}

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even if nobody is receiving at that moment.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- server.Serve(listener)
	}()

	// Give Serve a short window to fail fast; if no error arrives within
	// 100ms the server is assumed to be running.
	select {
	case err := <-serverErr:
		t.Fatalf("Failed to start server: %v", err)
	case <-time.After(100 * time.Millisecond):
		// Server should be ready
	}

	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		// Best-effort shutdown: a failure here is worth seeing in the test
		// log but must not fail a test that has otherwise passed.
		if err := server.Shutdown(ctx); err != nil {
			t.Logf("Server shutdown returned error: %v", err)
		}
		<-serverErr // Wait for the Serve goroutine to finish
	})

	return addr
}
// TestWebSocketRealConnection opens a real WebSocket connection to the test
// server, sends a status request followed by an invalid text frame, and logs
// whatever the server answers. Only dial and write failures fail the test;
// read errors are logged and tolerated.
func TestWebSocketRealConnection(t *testing.T) {
	t.Parallel()

	serverAddr := setupTestServer(t)

	// Step 1: establish the connection.
	endpoint := url.URL{Scheme: "ws", Host: serverAddr, Path: "/ws"}
	wsConn, handshake, dialErr := websocket.DefaultDialer.Dial(endpoint.String(), nil)
	if handshake != nil && handshake.Body != nil {
		defer func() { _ = handshake.Body.Close() }()
	}
	if dialErr != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", dialErr)
	}
	defer func() { _ = wsConn.Close() }()
	t.Log("Successfully established WebSocket connection")

	// Step 2: send a status request.
	// Frame layout: [opcode:1][api_key_hash:16]; the hash is left all zeroes.
	_ = wsConn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	statusFrame := make([]byte, 1+16)
	statusFrame[0] = 0x02 // opcode
	if writeErr := wsConn.WriteMessage(websocket.BinaryMessage, statusFrame); writeErr != nil {
		t.Fatalf("Failed to send status request: %v", writeErr)
	}

	// Step 3: wait up to five seconds for a reply and log the outcome.
	_ = wsConn.SetReadDeadline(time.Now().Add(5 * time.Second))
	msgType, payload, readErr := wsConn.ReadMessage()
	if readErr == nil {
		t.Logf("Received message type %d: %s", msgType, string(payload))
	} else if ne, isNetErr := readErr.(net.Error); isNetErr && ne.Timeout() {
		t.Log("Read timeout - this is expected for status request")
	} else {
		t.Logf("Failed to read message: %v", readErr)
	}

	// Step 4: send a text frame that is not part of the binary protocol.
	_ = wsConn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if writeErr := wsConn.WriteMessage(websocket.TextMessage, []byte("invalid")); writeErr != nil {
		t.Fatalf("Failed to send invalid message: %v", writeErr)
	}

	// The server may answer, stay silent, or close the connection; any read
	// error is only logged.
	_ = wsConn.SetReadDeadline(time.Now().Add(2 * time.Second))
	if _, _, finalErr := wsConn.ReadMessage(); finalErr != nil {
		switch {
		case websocket.IsCloseError(finalErr, websocket.ClosePolicyViolation):
			t.Log("Server correctly closed connection due to invalid message")
		default:
			t.Logf("Server handled invalid message with error: %v", finalErr)
		}
	}
}
// TestWebSocketBinaryProtocol sends one queue-job frame in the binary wire
// format and logs the server's reaction. Only dial and write failures fail
// the test; a missing or error response is logged and tolerated.
func TestWebSocketBinaryProtocol(t *testing.T) {
	t.Parallel()

	// setupTestServer returns only after the listener is bound, so the
	// address can be dialed right away without an extra wait.
	addr := setupTestServer(t)

	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
	conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Queue-job frame layout:
	// [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
	jobName := "test_job"
	commitID := "aaaaaaaaaaaaaaaaaaaa" // 20-byte commit ID
	binaryMessage := make([]byte, 0, 1+16+len(commitID)+1+1+len(jobName))
	binaryMessage = append(binaryMessage, 0x01)                  // OpcodeQueueJob
	binaryMessage = append(binaryMessage, make([]byte, 16)...)   // 16-byte API key hash (all zeroes)
	binaryMessage = append(binaryMessage, commitID...)           // 20-byte commit ID
	binaryMessage = append(binaryMessage, 5, byte(len(jobName))) // priority, job name length
	binaryMessage = append(binaryMessage, jobName...)            // job name

	// Bound the write so a stalled peer cannot hang the test.
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	err = conn.WriteMessage(websocket.BinaryMessage, binaryMessage)
	if err != nil {
		t.Fatalf("Failed to send binary message: %v", err)
	}
	t.Log("Successfully sent binary queue job message")

	// Read response (if any); every outcome is only logged.
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, message, err := conn.ReadMessage()
	if err != nil {
		if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
			t.Log("Connection closed normally after binary message")
		} else {
			t.Logf("No response received (expected): %v", err)
		}
	} else {
		t.Logf("Received response to binary message: %s", string(message))
	}
}
// TestWebSocketConcurrentConnections checks that the server accepts several
// WebSocket connections that are open at the same time. The connections are
// dialed one after another, held open together, and then all closed; any
// failed dial is reported as a test error.
func TestWebSocketConcurrentConnections(t *testing.T) {
	t.Parallel()

	serverAddr := setupTestServer(t)

	const total = 5
	target := (&url.URL{Scheme: "ws", Host: serverAddr, Path: "/ws"}).String()

	// Slot i holds either the open connection or the dial error for attempt i.
	open := make([]*websocket.Conn, total)
	dialErrs := make([]error, total)
	for idx := 0; idx < total; idx++ {
		c, handshake, err := websocket.DefaultDialer.Dial(target, nil)
		if handshake != nil && handshake.Body != nil {
			_ = handshake.Body.Close()
		}
		if err == nil {
			open[idx] = c
		} else {
			dialErrs[idx] = err
		}
	}

	// Tear everything down, counting how many dials actually succeeded.
	established := 0
	for _, c := range open {
		if c == nil {
			continue
		}
		established++
		_ = c.Close()
	}

	// Report individual failures first, then the aggregate mismatch.
	for idx, err := range dialErrs {
		if err != nil {
			t.Errorf("Connection %d failed: %v", idx, err)
		}
	}
	if established != total {
		t.Errorf("Expected %d successful connections, got %d", total, established)
	}
	t.Logf("Successfully established %d concurrent WebSocket connections", established)
}
// TestWebSocketConnectionResilience verifies that a client can connect, send
// a message, disconnect, and then connect and send again against the same
// server. Only dial and write failures fail the test; no responses are read.
func TestWebSocketConnectionResilience(t *testing.T) {
	t.Parallel()

	addr := setupTestServer(t)
	u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}

	// Payload sent on both connections.
	// NOTE(review): the other tests in this file send opcode 0x02 as a binary
	// frame; confirm the server is meant to accept this JSON form as well.
	statusRequest := map[string]any{
		"opcode": 0x02,
		"data":   "",
	}

	// First connection
	conn1, resp1, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp1 != nil && resp1.Body != nil {
		defer func() { _ = resp1.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("Failed to establish first connection: %v", err)
	}

	// Send, then close unconditionally before inspecting the write result so
	// the connection is released even when the write fails.
	err = conn1.WriteJSON(statusRequest)
	_ = conn1.Close()
	if err != nil {
		t.Fatalf("Failed to send message on first connection: %v", err)
	}

	// Brief pause between dropping the first connection and dialing again.
	time.Sleep(100 * time.Millisecond)

	// Reconnect
	conn2, resp2, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if resp2 != nil && resp2.Body != nil {
		defer func() { _ = resp2.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("Failed to reconnect: %v", err)
	}
	defer func() { _ = conn2.Close() }()

	// Send message on reconnected connection
	if err := conn2.WriteJSON(statusRequest); err != nil {
		t.Fatalf("Failed to send message on reconnected connection: %v", err)
	}

	t.Log("Successfully tested connection resilience and reconnection")
}