refactor: Migrate all test imports from api to api/ws package

Updated 6 test files to use proper api/ws package imports:

1. tests/e2e/websocket_e2e_test.go
   - api.NewWSHandler → ws.NewHandler

2. tests/e2e/wss_reverse_proxy_e2e_test.go
   - api.NewWSHandler → ws.NewHandler

3. tests/integration/ws_handler_integration_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

4. tests/integration/websocket_queue_integration_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

5. tests/unit/api/ws_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

6. tests/unit/api/ws_jobs_args_test.go
   - api.Opcode* → wspkg.Opcode*

Removed api/ws_compat.go shim as all tests now use proper imports.

Build status: Compiles successfully
This commit is contained in:
Jeremie Fraeys 2026-02-17 13:52:20 -05:00
parent 83ca393ebc
commit d8cc2a4efa
No known key found for this signature in database
7 changed files with 54 additions and 116 deletions

View file

@ -1,64 +0,0 @@
package api
import (
"net/http"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// Re-export WebSocket types from ws package for backward compatibility
// Deprecated: Use api/ws package directly
// NewWSHandler creates a new WebSocket handler (re-export from ws package)
// Deprecated: Use ws.NewHandler instead
func NewWSHandler(
authConfig *auth.Config,
logger *logging.Logger,
expManager *experiment.Manager,
dataDir string,
taskQueue queue.Backend,
db *storage.DB,
jupyterServiceMgr *jupyter.ServiceManager,
securityCfg *config.SecurityConfig,
auditLogger *audit.Logger,
) http.Handler {
return ws.NewHandler(authConfig, logger, expManager, dataDir, taskQueue, db, jupyterServiceMgr, securityCfg, auditLogger)
}
// WebSocket opcodes (re-exported from ws package)
// Deprecated: Use ws package constants directly
const (
OpcodeQueueJob = ws.OpcodeQueueJob
OpcodeStatusRequest = ws.OpcodeStatusRequest
OpcodeCancelJob = ws.OpcodeCancelJob
OpcodePrune = ws.OpcodePrune
OpcodeDatasetList = ws.OpcodeDatasetList
OpcodeDatasetRegister = ws.OpcodeDatasetRegister
OpcodeDatasetInfo = ws.OpcodeDatasetInfo
OpcodeDatasetSearch = ws.OpcodeDatasetSearch
OpcodeLogMetric = ws.OpcodeLogMetric
OpcodeGetExperiment = ws.OpcodeGetExperiment
OpcodeQueueJobWithTracking = ws.OpcodeQueueJobWithTracking
OpcodeQueueJobWithSnapshot = ws.OpcodeQueueJobWithSnapshot
OpcodeQueueJobWithArgs = ws.OpcodeQueueJobWithArgs
OpcodeQueueJobWithNote = ws.OpcodeQueueJobWithNote
OpcodeAnnotateRun = ws.OpcodeAnnotateRun
OpcodeSetRunNarrative = ws.OpcodeSetRunNarrative
OpcodeStartJupyter = ws.OpcodeStartJupyter
OpcodeStopJupyter = ws.OpcodeStopJupyter
OpcodeRemoveJupyter = ws.OpcodeRemoveJupyter
OpcodeRestoreJupyter = ws.OpcodeRestoreJupyter
OpcodeListJupyter = ws.OpcodeListJupyter
OpcodeListJupyterPackages = ws.OpcodeListJupyterPackages
OpcodeValidateRequest = ws.OpcodeValidateRequest
OpcodeGetLogs = ws.OpcodeGetLogs
OpcodeStreamLogs = ws.OpcodeStreamLogs
)

View file

@ -10,7 +10,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -22,7 +22,7 @@ func setupTestServer(t *testing.T) string {
authConfig := &auth.Config{Enabled: false}
expManager := experiment.NewManager(t.TempDir())
wsHandler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
// Create listener to get actual port
listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")

View file

@ -11,7 +11,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -37,7 +37,7 @@ func startWSBackendServer(t *testing.T) *httptest.Server {
logger := logging.NewLogger(slog.LevelInfo, false)
authConfig := &auth.Config{Enabled: false}
expManager := experiment.NewManager(t.TempDir())
h := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
srv := httptest.NewServer(h)
t.Cleanup(srv.Close)

View file

@ -14,6 +14,7 @@ import (
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -43,7 +44,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
logger := logging.NewLogger(0, false)
authCfg := &auth.Config{Enabled: false}
wsHandler := api.NewWSHandler(
wsHandler := wspkg.NewHandler(
authCfg,
logger,
expMgr,
@ -134,7 +135,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
logger := logging.NewLogger(0, false)
authCfg := &auth.Config{Enabled: false}
wsHandler := api.NewWSHandler(
wsHandler := wspkg.NewHandler(
authCfg,
logger,
expMgr,
@ -230,7 +231,7 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
logger := logging.NewLogger(0, false)
authCfg := &auth.Config{Enabled: false}
wsHandler := api.NewWSHandler(
wsHandler := wspkg.NewHandler(
authCfg,
logger,
expMgr,
@ -387,7 +388,7 @@ func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes))
buf = append(buf, api.OpcodeQueueJob)
buf = append(buf, wspkg.OpcodeQueueJob)
buf = append(buf, apiKeyHash...)
buf = append(buf, commitID...)
buf = append(buf, priority)
@ -426,7 +427,7 @@ func buildQueueJobWithSnapshotMessage(
}
buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB))
buf = append(buf, api.OpcodeQueueJobWithSnapshot)
buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot)
buf = append(buf, apiKeyHash...)
buf = append(buf, commitID...)
buf = append(buf, priority)

View file

@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/jfraeys/fetch_ml/internal/api"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -58,7 +59,7 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
require.NoError(t, err)
require.NoError(t, db.Initialize(schema))
handler := api.NewWSHandler(
handler := wspkg.NewHandler(
authConfig,
logger,
expManager,
@ -142,7 +143,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestMissingForRunning_Fails(t *
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+len(taskID))
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(1))
msg = append(msg, byte(len(taskID)))
@ -221,7 +222,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestCommitMismatch_Fails(t *tes
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+len(taskID))
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(1))
msg = append(msg, byte(len(taskID)))
@ -300,7 +301,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestLocationMismatch_Fails(t *t
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+len(taskID))
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(1))
msg = append(msg, byte(len(taskID)))
@ -384,7 +385,7 @@ func TestWSHandler_ValidateRequest_TaskID_RunManifestLifecycleOrdering_Fails(t *
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+len(taskID))
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(1))
msg = append(msg, byte(len(taskID)))
@ -452,7 +453,7 @@ func TestWSHandler_ValidateRequest_TaskID_InvalidResources(t *testing.T) {
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+len(taskID))
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(1))
msg = append(msg, byte(len(taskID)))
@ -529,7 +530,7 @@ func TestWSHandler_ValidateRequest_TaskID_SnapshotMismatch(t *testing.T) {
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+len(taskID))
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(1))
msg = append(msg, byte(len(taskID)))
@ -584,7 +585,7 @@ func setupWSIntegrationServer(t *testing.T) (
require.NoError(t, db.Initialize(schema))
// Create handler
handler := api.NewWSHandler(
handler := wspkg.NewHandler(
authConfig,
logger,
expManager,
@ -623,7 +624,7 @@ func TestWSHandler_QueueJob_Integration(t *testing.T) {
// Prepare queue_job message
// Protocol: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
opcode := byte(api.OpcodeQueueJob)
opcode := byte(wspkg.OpcodeQueueJob)
apiKeyHash := make([]byte, 16)
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
commitID := make([]byte, 20)
@ -706,7 +707,7 @@ func TestWSHandler_StatusRequest_Integration(t *testing.T) {
// Prepare status_request message
// Protocol: [opcode:1][api_key_hash:16]
opcode := byte(api.OpcodeStatusRequest)
opcode := byte(wspkg.OpcodeStatusRequest)
apiKeyHash := make([]byte, 16)
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
@ -752,7 +753,7 @@ func TestWSHandler_ValidateRequest_Integration(t *testing.T) {
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
msg := make([]byte, 0, 1+16+1+1+20)
msg = append(msg, byte(api.OpcodeValidateRequest))
msg = append(msg, byte(wspkg.OpcodeValidateRequest))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(0))
msg = append(msg, byte(20))
@ -801,7 +802,7 @@ func TestWSHandler_CancelJob_Integration(t *testing.T) {
// Prepare cancel_job message
// Protocol: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var]
opcode := byte(api.OpcodeCancelJob)
opcode := byte(wspkg.OpcodeCancelJob)
apiKeyHash := make([]byte, 16)
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
jobName := "job-to-cancel"
@ -846,7 +847,7 @@ func TestWSHandler_Prune_Integration(t *testing.T) {
// Prepare prune message
// Protocol: [opcode:1][api_key_hash:16][prune_type:1][value:4]
opcode := byte(api.OpcodePrune)
opcode := byte(wspkg.OpcodePrune)
apiKeyHash := make([]byte, 16)
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
pruneType := byte(0) // Keep N
@ -897,7 +898,7 @@ func TestWSHandler_LogMetric_Integration(t *testing.T) {
// Prepare log_metric message
// Protocol: [opcode:1][api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var]
opcode := byte(api.OpcodeLogMetric)
opcode := byte(wspkg.OpcodeLogMetric)
apiKeyHash := make([]byte, 16)
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
commitID := []byte(commitIDStr)
@ -957,7 +958,7 @@ func TestWSHandler_GetExperiment_Integration(t *testing.T) {
// Prepare get_experiment message
// Protocol: [opcode:1][api_key_hash:16][commit_id:20]
opcode := byte(api.OpcodeGetExperiment)
opcode := byte(wspkg.OpcodeGetExperiment)
apiKeyHash := make([]byte, 16)
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
commitID := []byte(commitIDStr)
@ -995,7 +996,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) {
// 1) List should return empty array
{
msg := make([]byte, 0, 1+16)
msg = append(msg, byte(api.OpcodeDatasetList))
msg = append(msg, byte(wspkg.OpcodeDatasetList))
msg = append(msg, apiKeyHash...)
err := ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
@ -1010,7 +1011,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) {
urlStr := "https://example.com/mnist.tar.gz"
{
msg := make([]byte, 0, 1+16+1+len(name)+2+len(urlStr))
msg = append(msg, byte(api.OpcodeDatasetRegister))
msg = append(msg, byte(wspkg.OpcodeDatasetRegister))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(len(name)))
msg = append(msg, []byte(name)...)
@ -1030,7 +1031,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) {
// 3) Info should return PacketTypeData
{
msg := make([]byte, 0, 1+16+1+len(name))
msg = append(msg, byte(api.OpcodeDatasetInfo))
msg = append(msg, byte(wspkg.OpcodeDatasetInfo))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(len(name)))
msg = append(msg, []byte(name)...)
@ -1047,7 +1048,7 @@ func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) {
term := "mn"
{
msg := make([]byte, 0, 1+16+1+len(term))
msg = append(msg, byte(api.OpcodeDatasetSearch))
msg = append(msg, byte(wspkg.OpcodeDatasetSearch))
msg = append(msg, apiKeyHash...)
msg = append(msg, byte(len(term)))
msg = append(msg, []byte(term)...)

View file

@ -4,17 +4,17 @@ package api_test
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/api"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
)
func TestWSHandlerNewQueueOpcodeExists(t *testing.T) {
if api.OpcodeQueueJobWithNote != 0x1B {
t.Fatalf("expected OpcodeQueueJobWithNote to be 0x1B, got %d", api.OpcodeQueueJobWithNote)
if wspkg.OpcodeQueueJobWithNote != 0x1B {
t.Fatalf("expected OpcodeQueueJobWithNote to be 0x1B, got %d", wspkg.OpcodeQueueJobWithNote)
}
if api.OpcodeAnnotateRun != 0x1C {
t.Fatalf("expected OpcodeAnnotateRun to be 0x1C, got %d", api.OpcodeAnnotateRun)
if wspkg.OpcodeAnnotateRun != 0x1C {
t.Fatalf("expected OpcodeAnnotateRun to be 0x1C, got %d", wspkg.OpcodeAnnotateRun)
}
if api.OpcodeSetRunNarrative != 0x1D {
t.Fatalf("expected OpcodeSetRunNarrative to be 0x1D, got %d", api.OpcodeSetRunNarrative)
if wspkg.OpcodeSetRunNarrative != 0x1D {
t.Fatalf("expected OpcodeSetRunNarrative to be 0x1D, got %d", wspkg.OpcodeSetRunNarrative)
}
}

View file

@ -8,7 +8,7 @@ import (
"strings"
"testing"
"github.com/jfraeys/fetch_ml/internal/api"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -21,7 +21,7 @@ func TestNewWSHandler(t *testing.T) {
logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger
expManager := experiment.NewManager("/tmp")
handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
if handler == nil {
t.Error("Expected non-nil WSHandler")
@ -32,20 +32,20 @@ func TestWSHandlerConstants(t *testing.T) {
t.Parallel() // Enable parallel execution
// Test that constants are defined correctly
if api.OpcodeQueueJob != 0x01 {
t.Errorf("Expected OpcodeQueueJob to be 0x01, got %d", api.OpcodeQueueJob)
if wspkg.OpcodeQueueJob != 0x01 {
t.Errorf("Expected OpcodeQueueJob to be 0x01, got %d", wspkg.OpcodeQueueJob)
}
if api.OpcodeStatusRequest != 0x02 {
t.Errorf("Expected OpcodeStatusRequest to be 0x02, got %d", api.OpcodeStatusRequest)
if wspkg.OpcodeStatusRequest != 0x02 {
t.Errorf("Expected OpcodeStatusRequest to be 0x02, got %d", wspkg.OpcodeStatusRequest)
}
if api.OpcodeCancelJob != 0x03 {
t.Errorf("Expected OpcodeCancelJob to be 0x03, got %d", api.OpcodeCancelJob)
if wspkg.OpcodeCancelJob != 0x03 {
t.Errorf("Expected OpcodeCancelJob to be 0x03, got %d", wspkg.OpcodeCancelJob)
}
if api.OpcodePrune != 0x04 {
t.Errorf("Expected OpcodePrune to be 0x04, got %d", api.OpcodePrune)
if wspkg.OpcodePrune != 0x04 {
t.Errorf("Expected OpcodePrune to be 0x04, got %d", wspkg.OpcodePrune)
}
}
@ -56,7 +56,7 @@ func TestWSHandlerWebSocketUpgrade(t *testing.T) {
logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger
expManager := experiment.NewManager("/tmp")
handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
// Create a test HTTP request
req := httptest.NewRequest("GET", "/ws", nil)
@ -93,7 +93,7 @@ func TestWSHandlerInvalidRequest(t *testing.T) {
logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger
expManager := experiment.NewManager("/tmp")
handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
// Create a test HTTP request without WebSocket headers
req := httptest.NewRequest("GET", "/ws", nil)
@ -118,7 +118,7 @@ func TestWSHandlerPostRequest(t *testing.T) {
logger := logging.NewLogger(slog.LevelInfo, false) // Create a real logger
expManager := experiment.NewManager("/tmp")
handler := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
handler := wspkg.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
// Create a POST request (should fail)
req := httptest.NewRequest("POST", "/ws", strings.NewReader("data"))
@ -158,10 +158,10 @@ func TestWebSocketMessageConstants(t *testing.T) {
// Test that binary protocol constants are properly defined
constants := map[string]byte{
"OpcodeQueueJob": api.OpcodeQueueJob,
"OpcodeStatusRequest": api.OpcodeStatusRequest,
"OpcodeCancelJob": api.OpcodeCancelJob,
"OpcodePrune": api.OpcodePrune,
"OpcodeQueueJob": wspkg.OpcodeQueueJob,
"OpcodeStatusRequest": wspkg.OpcodeStatusRequest,
"OpcodeCancelJob": wspkg.OpcodeCancelJob,
"OpcodePrune": wspkg.OpcodePrune,
}
expectedValues := map[string]byte{