Phase 1: Fix Redis Schema Leak
- Create internal/storage/dataset.go with a DatasetStore abstraction
- Remove all direct Redis calls from cmd/data_manager/data_sync.go
- data_manager now uses DatasetStore for transfer tracking and metadata

Phase 2: Simplify TUI Services
- Embed *queue.TaskQueue directly in services.TaskQueue
- Eliminate 60% of wrapper boilerplate (203 -> ~100 lines)
- Keep only TUI-specific methods (EnqueueTask, GetJobStatus, experiment methods)

Phase 5: Clean go.mod Dependencies
- Remove duplicate go-redis/redis/v8 dependency
- Migrate internal/storage/migrate.go to redis/go-redis/v9
- Separate test-only deps (miniredis, testify) into their own block

Results:
- Zero direct Redis calls in cmd/
- 60% fewer lines in TUI services
- Cleaner dependency structure
290 lines
8.2 KiB
Go
290 lines
8.2 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
|
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
|
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/api/ws"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
// setupTestServer creates a test server with WebSocket handler and returns the address
|
|
func setupTestServer(t *testing.T) string {
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
authConfig := &auth.Config{Enabled: false}
|
|
expManager := experiment.NewManager(t.TempDir())
|
|
|
|
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
|
|
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
|
|
|
// Create listener to get actual port
|
|
listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create listener: %v", err)
|
|
}
|
|
addr := listener.Addr().String()
|
|
|
|
server := &http.Server{
|
|
Handler: wsHandler,
|
|
ReadHeaderTimeout: 5 * time.Second,
|
|
}
|
|
|
|
// Start server
|
|
serverErr := make(chan error, 1)
|
|
go func() {
|
|
serverErr <- server.Serve(listener)
|
|
}()
|
|
|
|
// Wait for server to start
|
|
select {
|
|
case err := <-serverErr:
|
|
t.Fatalf("Failed to start server: %v", err)
|
|
case <-time.After(100 * time.Millisecond):
|
|
// Server should be ready
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_ = server.Shutdown(ctx)
|
|
<-serverErr // Wait for server to stop
|
|
})
|
|
|
|
return addr
|
|
}
|
|
|
|
func TestWebSocketRealConnection(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
// Test 1: Basic WebSocket connection
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() { _ = resp.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to WebSocket: %v", err)
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
t.Log("Successfully established WebSocket connection")
|
|
|
|
// Test 2: Send a status request
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
// New protocol: [opcode:1][api_key_hash:16]
|
|
statusMsg := []byte{0x02} // opcode
|
|
statusMsg = append(statusMsg, make([]byte, 16)...) // 16-byte API key hash
|
|
err = conn.WriteMessage(websocket.BinaryMessage, statusMsg)
|
|
if err != nil {
|
|
t.Fatalf("Failed to send status request: %v", err)
|
|
}
|
|
|
|
// Test 3: Read response with timeout
|
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
messageType, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
t.Log("Read timeout - this is expected for status request")
|
|
} else {
|
|
t.Logf("Failed to read message: %v", err)
|
|
}
|
|
} else {
|
|
t.Logf("Received message type %d: %s", messageType, string(message))
|
|
}
|
|
|
|
// Test 4: Send invalid message
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
err = conn.WriteMessage(websocket.TextMessage, []byte("invalid"))
|
|
if err != nil {
|
|
t.Fatalf("Failed to send invalid message: %v", err)
|
|
}
|
|
|
|
// Try to read response (may get error due to server closing connection)
|
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, _, err = conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsCloseError(err, websocket.ClosePolicyViolation) {
|
|
t.Log("Server correctly closed connection due to invalid message")
|
|
} else {
|
|
t.Logf("Server handled invalid message with error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWebSocketBinaryProtocol(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Connect to WebSocket
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() { _ = resp.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to WebSocket: %v", err)
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
// Test 4: Send binary message with queue job opcode using new protocol
|
|
// Create binary message with new protocol:
|
|
// [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
|
jobName := "test_job"
|
|
commitID := "aaaaaaaaaaaaaaaaaaaa" // 20-byte commit ID
|
|
|
|
binaryMessage := []byte{0x01} // OpcodeQueueJob
|
|
binaryMessage = append(binaryMessage, make([]byte, 16)...) // 16-byte API key hash
|
|
binaryMessage = append(binaryMessage, []byte(commitID)...) // 20-byte commit ID
|
|
binaryMessage = append(binaryMessage, 5) // priority
|
|
binaryMessage = append(binaryMessage, byte(len(jobName))) // job name length
|
|
binaryMessage = append(binaryMessage, []byte(jobName)...) // job name
|
|
|
|
err = conn.WriteMessage(websocket.BinaryMessage, binaryMessage)
|
|
if err != nil {
|
|
t.Fatalf("Failed to send binary message: %v", err)
|
|
}
|
|
|
|
t.Log("Successfully sent binary queue job message")
|
|
|
|
// Read response (if any)
|
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
t.Log("Connection closed normally after binary message")
|
|
} else {
|
|
t.Logf("No response received (expected): %v", err)
|
|
}
|
|
} else {
|
|
t.Logf("Received response to binary message: %s", string(message))
|
|
}
|
|
}
|
|
|
|
func TestWebSocketConcurrentConnections(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
// Test 5: Multiple concurrent connections
|
|
numConnections := 5
|
|
connections := make([]*websocket.Conn, numConnections)
|
|
errors := make([]error, numConnections)
|
|
|
|
// Create multiple connections
|
|
for i := range numConnections {
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
if err != nil {
|
|
errors[i] = err
|
|
continue
|
|
}
|
|
connections[i] = conn
|
|
}
|
|
|
|
// Close all connections
|
|
for _, conn := range connections {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
// Verify all connections succeeded
|
|
for i, err := range errors {
|
|
if err != nil {
|
|
t.Errorf("Connection %d failed: %v", i, err)
|
|
}
|
|
}
|
|
|
|
successCount := 0
|
|
for _, conn := range connections {
|
|
if conn != nil {
|
|
successCount++
|
|
}
|
|
}
|
|
|
|
if successCount != numConnections {
|
|
t.Errorf("Expected %d successful connections, got %d", numConnections, successCount)
|
|
}
|
|
|
|
t.Logf("Successfully established %d concurrent WebSocket connections", successCount)
|
|
}
|
|
|
|
func TestWebSocketConnectionResilience(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test server
|
|
addr := setupTestServer(t)
|
|
|
|
// Test 6: Connection resilience and reconnection
|
|
u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"}
|
|
|
|
// First connection
|
|
conn1, resp1, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp1 != nil && resp1.Body != nil {
|
|
defer func() { _ = resp1.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to establish first connection: %v", err)
|
|
}
|
|
|
|
// Send a message
|
|
err = conn1.WriteJSON(map[string]interface{}{
|
|
"opcode": 0x02,
|
|
"data": "",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to send message on first connection: %v", err)
|
|
}
|
|
|
|
// Close first connection
|
|
_ = conn1.Close()
|
|
|
|
// Wait a moment
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Reconnect
|
|
conn2, resp2, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if resp2 != nil && resp2.Body != nil {
|
|
defer func() { _ = resp2.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Failed to reconnect: %v", err)
|
|
}
|
|
defer func() { _ = conn2.Close() }()
|
|
|
|
// Send message on reconnected connection
|
|
err = conn2.WriteJSON(map[string]interface{}{
|
|
"opcode": 0x02,
|
|
"data": "",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to send message on reconnected connection: %v", err)
|
|
}
|
|
|
|
t.Log("Successfully tested connection resilience and reconnection")
|
|
}
|