Updated test file to pass jobs, jupyter, and datasets handlers to NewHandler. All tests now pass.
203 lines
6.8 KiB
Go
203 lines
6.8 KiB
Go
//nolint:revive // Package name 'api_test' is intentional: this is the external test package for 'api'
|
|
package api_test
|
|
|
|
import (
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
|
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
|
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
|
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
func TestNewWSHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
authConfig := &auth.Config{}
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
expManager := experiment.NewManager("/tmp")
|
|
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
|
|
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
|
|
|
if handler == nil {
|
|
t.Error("Expected non-nil WSHandler")
|
|
}
|
|
}
|
|
|
|
func TestWSHandlerConstants(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Test that constants are defined correctly
|
|
if wspkg.OpcodeQueueJob != 0x01 {
|
|
t.Errorf("Expected OpcodeQueueJob to be 0x01, got %d", wspkg.OpcodeQueueJob)
|
|
}
|
|
|
|
if wspkg.OpcodeStatusRequest != 0x02 {
|
|
t.Errorf("Expected OpcodeStatusRequest to be 0x02, got %d", wspkg.OpcodeStatusRequest)
|
|
}
|
|
|
|
if wspkg.OpcodeCancelJob != 0x03 {
|
|
t.Errorf("Expected OpcodeCancelJob to be 0x03, got %d", wspkg.OpcodeCancelJob)
|
|
}
|
|
|
|
if wspkg.OpcodePrune != 0x04 {
|
|
t.Errorf("Expected OpcodePrune to be 0x04, got %d", wspkg.OpcodePrune)
|
|
}
|
|
}
|
|
|
|
func TestWSHandlerWebSocketUpgrade(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
authConfig := &auth.Config{}
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
expManager := experiment.NewManager("/tmp")
|
|
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
|
|
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
|
|
|
// Create a test HTTP request
|
|
req := httptest.NewRequest("GET", "/ws", nil)
|
|
req.Header.Set("Upgrade", "websocket")
|
|
req.Header.Set("Connection", "upgrade")
|
|
req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
|
|
|
// Create a ResponseRecorder to capture the response
|
|
w := httptest.NewRecorder()
|
|
|
|
// Call the handler
|
|
handler.ServeHTTP(w, req)
|
|
|
|
// Check that the upgrade was attempted
|
|
resp := w.Result()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
// httptest.ResponseRecorder doesn't support hijacking, so WebSocket upgrade will fail
|
|
// We expect either 500 (due to hijacker limitation) or 400 (due to other issues
|
|
// The important thing is that the handler doesn't panic and responds
|
|
if resp.StatusCode != http.StatusInternalServerError && resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected status 500 or 400 for httptest limitation, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// The test verifies that the handler attempts the upgrade and handles errors gracefully
|
|
t.Log("WebSocket upgrade test completed - expected limitation with httptest.ResponseRecorder")
|
|
}
|
|
|
|
func TestWSHandlerInvalidRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
authConfig := &auth.Config{}
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
expManager := experiment.NewManager("/tmp")
|
|
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
|
|
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
|
|
|
// Create a test HTTP request without WebSocket headers
|
|
req := httptest.NewRequest("GET", "/ws", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Call the handler
|
|
handler.ServeHTTP(w, req)
|
|
|
|
// Should fail the upgrade
|
|
resp := w.Result()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected status 400 for invalid WebSocket request, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWSHandlerPostRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
authConfig := &auth.Config{}
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
expManager := experiment.NewManager("/tmp")
|
|
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
|
|
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
|
|
|
// Create a POST request (should fail)
|
|
req := httptest.NewRequest("POST", "/ws", strings.NewReader("data"))
|
|
w := httptest.NewRecorder()
|
|
|
|
// Call the handler
|
|
handler.ServeHTTP(w, req)
|
|
|
|
// Should fail the upgrade
|
|
resp := w.Result()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected status 400 for POST request, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// TestWSHandlerOriginCheck is a placeholder for origin-validation coverage.
//
// It currently only confirms that an Origin header set on a test request can
// be read back. It does NOT invoke the WebSocket handler or its upgrader's
// origin check, so it says nothing about which origins are accepted.
// TODO: exercise the real origin policy (e.g. via an httptest.Server and a
// WebSocket client dialing with different Origin headers).
func TestWSHandlerOriginCheck(t *testing.T) {
	t.Parallel() // Enable parallel execution

	const origin = "https://example.com"

	req := httptest.NewRequest(http.MethodGet, "/ws", nil)
	req.Header.Set("Origin", origin)

	if got := req.Header.Get("Origin"); got != origin {
		t.Error("Origin header not set correctly")
	}
}
|
|
|
|
func TestWebSocketMessageConstants(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Test that binary protocol constants are properly defined
|
|
constants := map[string]byte{
|
|
"OpcodeQueueJob": wspkg.OpcodeQueueJob,
|
|
"OpcodeStatusRequest": wspkg.OpcodeStatusRequest,
|
|
"OpcodeCancelJob": wspkg.OpcodeCancelJob,
|
|
"OpcodePrune": wspkg.OpcodePrune,
|
|
}
|
|
|
|
expectedValues := map[string]byte{
|
|
"OpcodeQueueJob": 0x01,
|
|
"OpcodeStatusRequest": 0x02,
|
|
"OpcodeCancelJob": 0x03,
|
|
"OpcodePrune": 0x04,
|
|
}
|
|
|
|
for name, actual := range constants {
|
|
expected, exists := expectedValues[name]
|
|
if !exists {
|
|
t.Errorf("Constant %s not found in expected values", name)
|
|
continue
|
|
}
|
|
if actual != expected {
|
|
t.Errorf("Expected %s to be %d, got %d", name, expected, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Note: Full WebSocket integration tests would require a more complex setup
|
|
// with actual WebSocket connections, which is typically done in integration tests
|
|
// rather than unit tests. These tests focus on the handler setup and basic request handling.
|