fetch_ml/tests/unit/api/ws_test.go
Jeremie Fraeys 27c8b08a16
test: Reorganize and add unit tests
Reorganize tests for better structure and coverage:
- Move container/security_test.go from internal/ to tests/unit/container/
- Move related tests to proper unit test locations
- Delete orphaned test files (startup_blacklist_test.go)
- Add privacy middleware unit tests
- Add worker config unit tests
- Update E2E tests for homelab and websocket scenarios
- Update test fixtures with utility functions
- Add CLI helper script for arraylist fixes
2026-02-18 21:28:13 -05:00

203 lines
6.8 KiB
Go

//nolint:revive // Package name 'api' is appropriate for this test package
package api_test
import (
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// TestNewWSHandler verifies that wspkg.NewHandler returns a non-nil handler
// when wired with real jobs, jupyter and datasets sub-handlers and nil for the
// remaining dependencies.
func TestNewWSHandler(t *testing.T) {
	t.Parallel()

	authConfig := &auth.Config{}
	logger := logging.NewLogger(slog.LevelInfo, false)
	// A per-test temporary directory (removed automatically when the test
	// finishes) keeps parallel tests isolated from each other and from the
	// shared /tmp, and works on platforms without a /tmp.
	expManager := experiment.NewManager(t.TempDir())
	jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
	datasetsHandler := datasets.NewHandler(logger, nil, "")

	handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil,
		jobsHandler, jupyterHandler, datasetsHandler)
	if handler == nil {
		t.Fatal("Expected non-nil WSHandler")
	}
}
// TestWSHandlerConstants pins each binary-protocol opcode exported by the ws
// package to its expected byte value, reporting every mismatch (checks are not
// short-circuited).
func TestWSHandlerConstants(t *testing.T) {
	t.Parallel()

	// Each entry carries its own failure format so the messages stay specific
	// to the opcode being checked.
	type opcodeCheck struct {
		actual   byte
		expected byte
		format   string
	}
	for _, c := range []opcodeCheck{
		{wspkg.OpcodeQueueJob, 0x01, "Expected OpcodeQueueJob to be 0x01, got %d"},
		{wspkg.OpcodeStatusRequest, 0x02, "Expected OpcodeStatusRequest to be 0x02, got %d"},
		{wspkg.OpcodeCancelJob, 0x03, "Expected OpcodeCancelJob to be 0x03, got %d"},
		{wspkg.OpcodePrune, 0x04, "Expected OpcodePrune to be 0x04, got %d"},
	} {
		if c.actual == c.expected {
			continue
		}
		t.Errorf(c.format, c.actual)
	}
}
// TestWSHandlerWebSocketUpgrade sends a well-formed WebSocket upgrade request
// to the handler and checks that it answers without panicking.
//
// httptest.ResponseRecorder does not implement http.Hijacker, so the upgrade
// itself cannot complete. The handler is expected to respond with 500 (hijack
// not supported) or 400 (handshake rejected for another reason).
func TestWSHandlerWebSocketUpgrade(t *testing.T) {
	t.Parallel()

	authConfig := &auth.Config{}
	logger := logging.NewLogger(slog.LevelInfo, false)
	// Per-test directory instead of the shared /tmp keeps parallel tests isolated.
	expManager := experiment.NewManager(t.TempDir())
	jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil,
		jobsHandler, jupyterHandler, datasetsHandler)

	req := httptest.NewRequest(http.MethodGet, "/ws", nil)
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Connection", "upgrade")
	// Sample client nonce from the RFC 6455 opening-handshake example.
	req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
	req.Header.Set("Sec-WebSocket-Version", "13")

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	resp := w.Result()
	defer func() { _ = resp.Body.Close() }() // recorder bodies are in-memory; close error is irrelevant

	// The important thing is that the handler responds instead of panicking;
	// the status reflects the recorder's lack of hijacking support.
	if resp.StatusCode != http.StatusInternalServerError && resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected status 500 or 400 for httptest limitation, got %d", resp.StatusCode)
	}
	t.Log("WebSocket upgrade test completed - expected limitation with httptest.ResponseRecorder")
}
// TestWSHandlerInvalidRequest sends a plain GET to /ws with none of the
// WebSocket handshake headers and expects the handler to refuse the upgrade
// with 400 Bad Request.
func TestWSHandlerInvalidRequest(t *testing.T) {
	t.Parallel()

	authConfig := &auth.Config{}
	logger := logging.NewLogger(slog.LevelInfo, false)
	// Per-test directory instead of the shared /tmp keeps parallel tests isolated.
	expManager := experiment.NewManager(t.TempDir())
	jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil,
		jobsHandler, jupyterHandler, datasetsHandler)

	req := httptest.NewRequest(http.MethodGet, "/ws", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	resp := w.Result()
	defer func() { _ = resp.Body.Close() }() // recorder bodies are in-memory; close error is irrelevant

	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected status 400 for invalid WebSocket request, got %d", resp.StatusCode)
	}
}
// TestWSHandlerPostRequest sends a POST with a body to /ws and expects 400 Bad
// Request. A WebSocket opening handshake must be a GET request, so the
// handler should not attempt an upgrade.
func TestWSHandlerPostRequest(t *testing.T) {
	t.Parallel()

	authConfig := &auth.Config{}
	logger := logging.NewLogger(slog.LevelInfo, false)
	// Per-test directory instead of the shared /tmp keeps parallel tests isolated.
	expManager := experiment.NewManager(t.TempDir())
	jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil,
		jobsHandler, jupyterHandler, datasetsHandler)

	req := httptest.NewRequest(http.MethodPost, "/ws", strings.NewReader("data"))
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	resp := w.Result()
	defer func() { _ = resp.Body.Close() }() // recorder bodies are in-memory; close error is irrelevant

	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected status 400 for POST request, got %d", resp.StatusCode)
	}
}
func TestWSHandlerOriginCheck(t *testing.T) {
t.Parallel() // Enable parallel execution
// This test verifies that the CheckOrigin function exists and returns true
// The actual implementation should be improved for security
// Create a request with Origin header
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Origin", "https://example.com")
// The upgrader should accept this origin (currently returns true)
// This is a placeholder test - the origin checking should be enhanced
if req.Header.Get("Origin") != "https://example.com" {
t.Error("Origin header not set correctly")
}
}
// TestWebSocketMessageConstants checks the wire values of the binary protocol
// opcodes exported by the ws package.
//
// The cases live in an ordered slice rather than two parallel maps. Failures
// are therefore reported in a stable order, and there is no unreachable
// "missing expected value" branch to maintain.
func TestWebSocketMessageConstants(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		got  byte
		want byte
	}{
		{"OpcodeQueueJob", wspkg.OpcodeQueueJob, 0x01},
		{"OpcodeStatusRequest", wspkg.OpcodeStatusRequest, 0x02},
		{"OpcodeCancelJob", wspkg.OpcodeCancelJob, 0x03},
		{"OpcodePrune", wspkg.OpcodePrune, 0x04},
	}
	for _, tc := range cases {
		if tc.got != tc.want {
			// Hex formatting matches how the opcodes are written in the protocol.
			t.Errorf("Expected %s to be 0x%02x, got 0x%02x", tc.name, tc.want, tc.got)
		}
	}
}
// Note: Full WebSocket tests need a real connection (for example, an
// httptest.Server dialed by a WebSocket client), which belongs in integration
// tests rather than unit tests. These tests cover only handler construction
// and basic request handling.