package api

import (
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
)

// expectedOpcodes pins the byte value of each binary-protocol opcode exported
// by the api package. A slice (not a map) keeps failure output in a stable order.
var expectedOpcodes = []struct {
	name string
	got  byte
	want byte
}{
	{"OpcodeQueueJob", api.OpcodeQueueJob, 0x01},
	{"OpcodeStatusRequest", api.OpcodeStatusRequest, 0x02},
	{"OpcodeCancelJob", api.OpcodeCancelJob, 0x03},
	{"OpcodePrune", api.OpcodePrune, 0x04},
}

// checkOpcodes reports an error for every opcode whose value differs from
// expectedOpcodes.
func checkOpcodes(t *testing.T) {
	t.Helper()
	for _, tc := range expectedOpcodes {
		if tc.got != tc.want {
			t.Errorf("Expected %s to be 0x%02X, got 0x%02X", tc.name, tc.want, tc.got)
		}
	}
}

// serveWS builds a WSHandler with test dependencies, serves req through an
// httptest.ResponseRecorder, and returns the recorded response. The
// experiment manager gets a per-test temporary directory (instead of a shared
// "/tmp") so parallel tests do not share state; the response body is closed
// automatically when the test finishes.
func serveWS(t *testing.T, req *http.Request) *http.Response {
	t.Helper()

	handler := api.NewWSHandler(
		&auth.AuthConfig{},
		logging.NewLogger(slog.LevelInfo, false),
		experiment.NewManager(t.TempDir()),
		nil,
	)

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	resp := w.Result()
	t.Cleanup(func() { _ = resp.Body.Close() })
	return resp
}

// TestNewWSHandler checks that the constructor returns a usable handler.
func TestNewWSHandler(t *testing.T) {
	t.Parallel()

	logger := logging.NewLogger(slog.LevelInfo, false)
	expManager := experiment.NewManager(t.TempDir())

	handler := api.NewWSHandler(&auth.AuthConfig{}, logger, expManager, nil)
	if handler == nil {
		t.Fatal("Expected non-nil WSHandler")
	}
}

// TestWSHandlerConstants checks the binary-protocol opcode values.
func TestWSHandlerConstants(t *testing.T) {
	t.Parallel()
	checkOpcodes(t)
}

// TestWSHandlerWebSocketUpgrade sends a well-formed WebSocket upgrade request
// and checks that the handler fails gracefully instead of panicking.
func TestWSHandlerWebSocketUpgrade(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/ws", nil)
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Connection", "upgrade")
	req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
	req.Header.Set("Sec-WebSocket-Version", "13")

	resp := serveWS(t, req)

	// httptest.ResponseRecorder does not implement http.Hijacker, so the
	// upgrade cannot complete here. Accept either 500 (hijack failure) or
	// 400 (request rejected earlier); the point is that the handler responds.
	if resp.StatusCode != http.StatusInternalServerError && resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected status 500 or 400 for httptest limitation, got %d", resp.StatusCode)
	}

	t.Log("WebSocket upgrade test completed - expected limitation with httptest.ResponseRecorder")
}

// TestWSHandlerInvalidRequest checks that a plain GET without WebSocket
// headers is rejected with 400.
func TestWSHandlerInvalidRequest(t *testing.T) {
	t.Parallel()

	resp := serveWS(t, httptest.NewRequest(http.MethodGet, "/ws", nil))

	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected status 400 for invalid WebSocket request, got %d", resp.StatusCode)
	}
}

// TestWSHandlerPostRequest checks that a POST to the WebSocket endpoint is
// rejected with 400.
func TestWSHandlerPostRequest(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodPost, "/ws", strings.NewReader("data"))
	resp := serveWS(t, req)

	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected status 400 for POST request, got %d", resp.StatusCode)
	}
}

// TestWSHandlerOriginCheck is a placeholder for origin validation.
//
// NOTE(review): this test does not exercise the handler at all; it only
// verifies that httptest preserves the Origin header it was given, so it
// always passes. TODO: once the upgrader's origin policy is tightened, send
// an upgrade request through the handler with an untrusted Origin and assert
// it is rejected.
func TestWSHandlerOriginCheck(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/ws", nil)
	req.Header.Set("Origin", "https://example.com")

	if req.Header.Get("Origin") != "https://example.com" {
		t.Error("Origin header not set correctly")
	}
}

// TestWebSocketMessageConstants checks the binary-protocol opcode values via
// the shared opcode table.
func TestWebSocketMessageConstants(t *testing.T) {
	t.Parallel()
	checkOpcodes(t)
}

// Note: Full WebSocket integration tests would require a more complex setup
// with actual WebSocket connections, which is typically done in integration tests
// rather than unit tests. These tests focus on the handler setup and basic request handling.