diff --git a/internal/api/ws_compat.go b/internal/api/ws_compat.go deleted file mode 100644 index d44dbc1..0000000 --- a/internal/api/ws_compat.go +++ /dev/null @@ -1,64 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/jfraeys/fetch_ml/internal/api/ws" - "github.com/jfraeys/fetch_ml/internal/audit" - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/config" - "github.com/jfraeys/fetch_ml/internal/experiment" - "github.com/jfraeys/fetch_ml/internal/jupyter" - "github.com/jfraeys/fetch_ml/internal/logging" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/storage" -) - -// Re-export WebSocket types from ws package for backward compatibility -// Deprecated: Use api/ws package directly - -// NewWSHandler creates a new WebSocket handler (re-export from ws package) -// Deprecated: Use ws.NewHandler instead -func NewWSHandler( - authConfig *auth.Config, - logger *logging.Logger, - expManager *experiment.Manager, - dataDir string, - taskQueue queue.Backend, - db *storage.DB, - jupyterServiceMgr *jupyter.ServiceManager, - securityCfg *config.SecurityConfig, - auditLogger *audit.Logger, -) http.Handler { - return ws.NewHandler(authConfig, logger, expManager, dataDir, taskQueue, db, jupyterServiceMgr, securityCfg, auditLogger) -} - -// WebSocket opcodes (re-exported from ws package) -// Deprecated: Use ws package constants directly -const ( - OpcodeQueueJob = ws.OpcodeQueueJob - OpcodeStatusRequest = ws.OpcodeStatusRequest - OpcodeCancelJob = ws.OpcodeCancelJob - OpcodePrune = ws.OpcodePrune - OpcodeDatasetList = ws.OpcodeDatasetList - OpcodeDatasetRegister = ws.OpcodeDatasetRegister - OpcodeDatasetInfo = ws.OpcodeDatasetInfo - OpcodeDatasetSearch = ws.OpcodeDatasetSearch - OpcodeLogMetric = ws.OpcodeLogMetric - OpcodeGetExperiment = ws.OpcodeGetExperiment - OpcodeQueueJobWithTracking = ws.OpcodeQueueJobWithTracking - OpcodeQueueJobWithSnapshot = ws.OpcodeQueueJobWithSnapshot - OpcodeQueueJobWithArgs = ws.OpcodeQueueJobWithArgs - OpcodeQueueJobWithNote = ws.OpcodeQueueJobWithNote - OpcodeAnnotateRun = ws.OpcodeAnnotateRun - OpcodeSetRunNarrative = ws.OpcodeSetRunNarrative - OpcodeStartJupyter = ws.OpcodeStartJupyter - OpcodeStopJupyter = ws.OpcodeStopJupyter - OpcodeRemoveJupyter = ws.OpcodeRemoveJupyter - OpcodeRestoreJupyter = ws.OpcodeRestoreJupyter - OpcodeListJupyter = ws.OpcodeListJupyter - OpcodeListJupyterPackages = ws.OpcodeListJupyterPackages - OpcodeValidateRequest = ws.OpcodeValidateRequest - OpcodeGetLogs = ws.OpcodeGetLogs - OpcodeStreamLogs = ws.OpcodeStreamLogs -) diff --git a/tests/e2e/websocket_e2e_test.go b/tests/e2e/websocket_e2e_test.go index 1ae7b64..672b6d1 100644 --- a/tests/e2e/websocket_e2e_test.go +++ b/tests/e2e/websocket_e2e_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/api" + "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" @@ -22,7 +22,7 @@ func setupTestServer(t *testing.T) string { authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - wsHandler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) // Create listener to get actual port listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") diff --git a/tests/e2e/wss_reverse_proxy_e2e_test.go b/tests/e2e/wss_reverse_proxy_e2e_test.go index 7719902..bc474b2 100644 --- a/tests/e2e/wss_reverse_proxy_e2e_test.go +++ b/tests/e2e/wss_reverse_proxy_e2e_test.go @@ -11,7 +11,7 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/api" + "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" @@ -37,7 +37,7 @@ func startWSBackendServer(t *testing.T) *httptest.Server { logger := logging.NewLogger(slog.LevelInfo, false) authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - h := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) srv := httptest.NewServer(h) t.Cleanup(srv.Close) diff --git a/tests/integration/websocket_queue_integration_test.go b/tests/integration/websocket_queue_integration_test.go index 8254439..8179b52 100644 --- a/tests/integration/websocket_queue_integration_test.go +++ b/tests/integration/websocket_queue_integration_test.go @@ -14,6 +14,7 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/api" + wspkg "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" @@ -43,7 +44,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - wsHandler := api.NewWSHandler( + wsHandler := wspkg.NewHandler( authCfg, logger, expMgr, @@ -134,7 +135,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - wsHandler := api.NewWSHandler( + wsHandler := wspkg.NewHandler( authCfg, logger, expMgr, @@ -230,7 +231,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} - wsHandler := api.NewWSHandler( + wsHandler := wspkg.NewHandler( authCfg, logger, expMgr, @@ -387,7 +388,7 @@ func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte copy(apiKeyHash, []byte(strings.Repeat("0", 16))) buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)) - buf = append(buf, api.OpcodeQueueJob) + buf = append(buf, wspkg.OpcodeQueueJob) buf = append(buf, apiKeyHash...) buf = append(buf, commitID...) buf = append(buf, priority) @@ -426,7 +427,7 @@ func buildQueueJobWithSnapshotMessage( } buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB)) - buf = append(buf, api.OpcodeQueueJobWithSnapshot) + buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot) buf = append(buf, apiKeyHash...) buf = append(buf, commitID...) buf = append(buf, priority) diff --git a/tests/integration/ws_handler_integration_test.go b/tests/integration/ws_handler_integration_test.go index ffff270..e922f56 100644 --- a/tests/integration/ws_handler_integration_test.go +++ b/tests/integration/ws_handler_integration_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "github.com/jfraeys/fetch_ml/internal/api" + wspkg "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" @@ -58,7 +59,7 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) ( require.NoError(t, err) require.NoError(t, db.Initialize(schema)) - handler := api.NewWSHandler( + handler := wspkg.NewHandler( authConfig, logger, expManager, @@ -142,7 +143,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestMissingForRunning_Fails(t * copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) @@ -221,7 +222,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestCommitMismatch_Fails(t *tes copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) @@ -300,7 +301,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestLocationMismatch_Fails(t *t copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) @@ -384,7 +385,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestLifecycleOrdering_Fails(t * copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) @@ -452,7 +453,7 @@ func TestWSHandler_ValidateRequest_TaskID_InvalidResources(t *testing.T) { copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) @@ -529,7 +530,7 @@ func TestWSHandler_ValidateRequest_TaskID_SnapshotMismatch(t *testing.T) { copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) @@ -584,7 +585,7 @@ func setupWSIntegrationServer(t *testing.T) ( require.NoError(t, db.Initialize(schema)) // Create handler - handler := api.NewWSHandler( + handler := wspkg.NewHandler( authConfig, logger, expManager, @@ -623,7 +624,7 @@ func TestWSHandler_QueueJob_Integration(t *testing.T) { // Prepare queue_job message // Protocol: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] - opcode := byte(api.OpcodeQueueJob) + opcode := byte(wspkg.OpcodeQueueJob) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) commitID := make([]byte, 20) @@ -706,7 +707,7 @@ func TestWSHandler_StatusRequest_Integration(t *testing.T) { // Prepare status_request message // Protocol: [opcode:1][api_key_hash:16] - opcode := byte(api.OpcodeStatusRequest) + opcode := byte(wspkg.OpcodeStatusRequest) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) @@ -752,7 +753,7 @@ func TestWSHandler_ValidateRequest_Integration(t *testing.T) { copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+20) - msg = append(msg, byte(api.OpcodeValidateRequest)) + msg = append(msg, byte(wspkg.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(0)) msg = append(msg, byte(20)) @@ -801,7 +802,7 @@ func TestWSHandler_CancelJob_Integration(t *testing.T) { // Prepare cancel_job message // Protocol: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var] - opcode := byte(api.OpcodeCancelJob) + opcode := byte(wspkg.OpcodeCancelJob) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) jobName := "job-to-cancel" @@ -846,7 +847,7 @@ func TestWSHandler_Prune_Integration(t *testing.T) { // Prepare prune message // Protocol: [opcode:1][api_key_hash:16][prune_type:1][value:4] - opcode := byte(api.OpcodePrune) + opcode := byte(wspkg.OpcodePrune) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) pruneType := byte(0) // Keep N @@ -897,7 +898,7 @@ func TestWSHandler_LogMetric_Integration(t *testing.T) { // Prepare log_metric message // Protocol: [opcode:1][api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var] - opcode := byte(api.OpcodeLogMetric) + opcode := byte(wspkg.OpcodeLogMetric) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) commitID := []byte(commitIDStr) @@ -957,7 +958,7 @@ func TestWSHandler_GetExperiment_Integration(t *testing.T) { // Prepare get_experiment message // Protocol: [opcode:1][api_key_hash:16][commit_id:20] - opcode := byte(api.OpcodeGetExperiment) + opcode := byte(wspkg.OpcodeGetExperiment) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) commitID := []byte(commitIDStr) @@ -995,7 +996,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) { // 1) List should return empty array { msg := make([]byte, 0, 1+16) - msg = append(msg, byte(api.OpcodeDatasetList)) + msg = append(msg, byte(wspkg.OpcodeDatasetList)) msg = append(msg, apiKeyHash...) err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) @@ -1010,7 +1011,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) { urlStr := "https://example.com/mnist.tar.gz" { msg := make([]byte, 0, 1+16+1+len(name)+2+len(urlStr)) - msg = append(msg, byte(api.OpcodeDatasetRegister)) + msg = append(msg, byte(wspkg.OpcodeDatasetRegister)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(len(name))) msg = append(msg, []byte(name)...) @@ -1030,7 +1031,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) { // 3) Info should return PacketTypeData { msg := make([]byte, 0, 1+16+1+len(name)) - msg = append(msg, byte(api.OpcodeDatasetInfo)) + msg = append(msg, byte(wspkg.OpcodeDatasetInfo)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(len(name))) msg = append(msg, []byte(name)...) @@ -1047,7 +1048,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) { term := "mn" { msg := make([]byte, 0, 1+16+1+len(term)) - msg = append(msg, byte(api.OpcodeDatasetSearch)) + msg = append(msg, byte(wspkg.OpcodeDatasetSearch)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(len(term))) msg = append(msg, []byte(term)...) diff --git a/tests/unit/api/ws_jobs_args_test.go b/tests/unit/api/ws_jobs_args_test.go index 122770c..5104cc6 100644 --- a/tests/unit/api/ws_jobs_args_test.go +++ b/tests/unit/api/ws_jobs_args_test.go @@ -4,17 +4,17 @@ package api_test import ( "testing" - "github.com/jfraeys/fetch_ml/internal/api" + wspkg "github.com/jfraeys/fetch_ml/internal/api/ws" ) func TestWSHandlerNewQueueOpcodeExists(t *testing.T) { - if api.OpcodeQueueJobWithNote != 0x1B { - t.Fatalf("expected OpcodeQueueJobWithNote to be 0x1B, got %d", api.OpcodeQueueJobWithNote) + if wspkg.OpcodeQueueJobWithNote != 0x1B { + t.Fatalf("expected OpcodeQueueJobWithNote to be 0x1B, got %d", wspkg.OpcodeQueueJobWithNote) } - if api.OpcodeAnnotateRun != 0x1C { - t.Fatalf("expected OpcodeAnnotateRun to be 0x1C, got %d", api.OpcodeAnnotateRun) + if wspkg.OpcodeAnnotateRun != 0x1C { + t.Fatalf("expected OpcodeAnnotateRun to be 0x1C, got %d", wspkg.OpcodeAnnotateRun) } - if api.OpcodeSetRunNarrative != 0x1D { - t.Fatalf("expected OpcodeSetRunNarrative to be 0x1D, got %d", api.OpcodeSetRunNarrative) + if wspkg.OpcodeSetRunNarrative != 0x1D { + t.Fatalf("expected OpcodeSetRunNarrative to be 0x1D, got %d", wspkg.OpcodeSetRunNarrative) } } diff --git a/tests/unit/api/ws_test.go b/tests/unit/api/ws_test.go index 22cdf5b..bb494aa 100644 --- a/tests/unit/api/ws_test.go +++ b/tests/unit/api/ws_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "github.com/jfraeys/fetch_ml/internal/api" + wspkg "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" @@ -21,7 +21,7 @@ func TestNewWSHandler(t *testing.T) { logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger expManager := experiment.NewManager("/tmp") - handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) if handler == nil { t.Error("Expected non-nil WSHandler") @@ -32,20 +32,20 @@ func TestWSHandlerConstants(t *testing.T) { t.Parallel() // Enable parallel execution // Test that constants are defined correctly - if api.OpcodeQueueJob != 0x01 { - t.Errorf("Expected OpcodeQueueJob to be 0x01, got %d", api.OpcodeQueueJob) + if wspkg.OpcodeQueueJob != 0x01 { + t.Errorf("Expected OpcodeQueueJob to be 0x01, got %d", wspkg.OpcodeQueueJob) } - if api.OpcodeStatusRequest != 0x02 { - t.Errorf("Expected OpcodeStatusRequest to be 0x02, got %d", api.OpcodeStatusRequest) + if wspkg.OpcodeStatusRequest != 0x02 { + t.Errorf("Expected OpcodeStatusRequest to be 0x02, got %d", wspkg.OpcodeStatusRequest) } - if api.OpcodeCancelJob != 0x03 { - t.Errorf("Expected OpcodeCancelJob to be 0x03, got %d", api.OpcodeCancelJob) + if wspkg.OpcodeCancelJob != 0x03 { + t.Errorf("Expected OpcodeCancelJob to be 0x03, got %d", wspkg.OpcodeCancelJob) } - if api.OpcodePrune != 0x04 { - t.Errorf("Expected OpcodePrune to be 0x04, got %d", api.OpcodePrune) + if wspkg.OpcodePrune != 0x04 { + t.Errorf("Expected OpcodePrune to be 0x04, got %d", wspkg.OpcodePrune) } } @@ -56,7 +56,7 @@ func TestWSHandlerWebSocketUpgrade(t *testing.T) { logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger expManager := experiment.NewManager("/tmp") - handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) // Create a test HTTP request req := httptest.NewRequest("GET", "/ws", nil) @@ -93,7 +93,7 @@ func TestWSHandlerInvalidRequest(t *testing.T) { logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger expManager := experiment.NewManager("/tmp") - handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) // Create a test HTTP request without WebSocket headers req := httptest.NewRequest("GET", "/ws", nil) @@ -118,7 +118,7 @@ func TestWSHandlerPostRequest(t *testing.T) { logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger expManager := experiment.NewManager("/tmp") - handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) // Create a POST request (should fail) req := httptest.NewRequest("POST", "/ws", strings.NewReader("data")) @@ -158,10 +158,10 @@ func TestWebSocketMessageConstants(t *testing.T) { // Test that binary protocol constants are properly defined constants := map[string]byte{ - "OpcodeQueueJob": api.OpcodeQueueJob, - "OpcodeStatusRequest": api.OpcodeStatusRequest, - "OpcodeCancelJob": api.OpcodeCancelJob, - "OpcodePrune": api.OpcodePrune, + "OpcodeQueueJob": wspkg.OpcodeQueueJob, + "OpcodeStatusRequest": wspkg.OpcodeStatusRequest, + "OpcodeCancelJob": wspkg.OpcodeCancelJob, + "OpcodePrune": wspkg.OpcodePrune, } expectedValues := map[string]byte{