refactor(api): overhaul WebSocket handler and protocol layer

Major WebSocket handler refactor:
- Rewrite ws/handler.go with structured message routing and backpressure
- Add connection lifecycle management with heartbeats and timeouts
- Implement graceful connection draining for zero-downtime restarts

Protocol improvements:
- Define structured protocol types in protocol.go for hub communication
- Add versioned message envelopes for backward compatibility
- Standardize error codes and response formats across WebSocket API

Job streaming via WebSocket:
- Simplify ws/jobs.go with async job status streaming
- Add compression for high-volume job updates

Testing:
- Update websocket_e2e_test.go for new protocol semantics
- Add connection resilience tests
This commit is contained in:
Jeremie Fraeys 2026-03-12 12:01:21 -04:00
parent ad3be36a6d
commit 188cf55939
No known key found for this signature in database
4 changed files with 224 additions and 117 deletions

View file

@ -18,7 +18,7 @@ func safeUint64FromTime(t time.Time) uint64 {
}
var bufferPool = sync.Pool{
New: func() interface{} {
New: func() any {
buf := make([]byte, 0, 256)
return &buf
},
@ -34,7 +34,16 @@ const (
PacketTypeLog = 0x05
)
// Error codes
// Error codes - byte values for compact binary wire format
// Groupings are intentional and indicate error categories:
//
// 0x00-0x05 = Generic client errors (validation, auth, permissions)
// 0x10-0x14 = Infrastructure errors (server, database, network, storage, timeout)
// 0x20-0x24 = Job lifecycle errors (not found, running, failed to start, execution failed, cancelled)
// 0x30-0x33 = Resource exhaustion errors (OOM, disk full, config, unavailable)
//
// For human-readable error codes, use the string constants from internal/api/errors.
// This package provides ByteCodeFromErrorCode() to bridge string codes to wire bytes.
const (
ErrorCodeUnknownError = 0x00
ErrorCodeInvalidRequest = 0x01
@ -106,7 +115,7 @@ func NewSuccessPacket(message string) *ResponsePacket {
}
// NewSuccessPacketWithPayload creates a success response packet with JSON payload
func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponsePacket {
func NewSuccessPacketWithPayload(message string, payload any) *ResponsePacket {
// Convert payload to JSON for the DataPayload field
payloadBytes, _ := json.Marshal(payload)
@ -120,11 +129,12 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP
}
// NewErrorPacket creates an error response packet
func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket {
// Accepts string error code from internal/api/errors package
func NewErrorPacket(errorCode string, message string, details string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeError,
Timestamp: safeUint64FromTime(time.Now()),
ErrorCode: errorCode,
ErrorCode: ByteCodeFromErrorCode(errorCode),
ErrorMessage: message,
ErrorDetails: details,
}
@ -308,51 +318,50 @@ func (p *ResponsePacket) estimatedSize() int {
}
}
// GetErrorMessage returns a human-readable error message for an error code
func GetErrorMessage(code byte) string {
// ByteCodeFromErrorCode converts string error codes from internal/api/errors to wire format bytes
// This is the single mapping point between human-readable string codes and compact binary codes
func ByteCodeFromErrorCode(code string) byte {
switch code {
case ErrorCodeUnknownError:
return "Unknown error occurred"
case ErrorCodeInvalidRequest:
return "Invalid request format"
case ErrorCodeAuthenticationFailed:
return "Authentication failed"
case ErrorCodePermissionDenied:
return "Permission denied"
case ErrorCodeResourceNotFound:
return "Resource not found"
case ErrorCodeResourceAlreadyExists:
return "Resource already exists"
case ErrorCodeServerOverloaded:
return "Server is overloaded"
case ErrorCodeDatabaseError:
return "Database error occurred"
case ErrorCodeNetworkError:
return "Network error occurred"
case ErrorCodeStorageError:
return "Storage error occurred"
case ErrorCodeTimeout:
return "Operation timed out"
case ErrorCodeJobNotFound:
return "Job not found"
case ErrorCodeJobAlreadyRunning:
return "Job is already running"
case ErrorCodeJobFailedToStart:
return "Job failed to start"
case ErrorCodeJobExecutionFailed:
return "Job execution failed"
case ErrorCodeJobCancelled:
return "Job was cancelled"
case ErrorCodeOutOfMemory:
return "Server out of memory"
case ErrorCodeDiskFull:
return "Server disk full"
case ErrorCodeInvalidConfiguration:
return "Invalid server configuration"
case ErrorCodeServiceUnavailable:
return "Service temporarily unavailable"
case "INVALID_REQUEST", "BAD_REQUEST":
return ErrorCodeInvalidRequest
case "AUTHENTICATION_FAILED":
return ErrorCodeAuthenticationFailed
case "PERMISSION_DENIED", "FORBIDDEN":
return ErrorCodePermissionDenied
case "RESOURCE_NOT_FOUND", "NOT_FOUND":
return ErrorCodeResourceNotFound
case "RESOURCE_ALREADY_EXISTS":
return ErrorCodeResourceAlreadyExists
case "SERVER_OVERLOADED":
return ErrorCodeServerOverloaded
case "DATABASE_ERROR":
return ErrorCodeDatabaseError
case "NETWORK_ERROR":
return ErrorCodeNetworkError
case "STORAGE_ERROR":
return ErrorCodeStorageError
case "TIMEOUT":
return ErrorCodeTimeout
case "JOB_NOT_FOUND":
return ErrorCodeJobNotFound
case "JOB_ALREADY_RUNNING":
return ErrorCodeJobAlreadyRunning
case "JOB_FAILED_TO_START":
return ErrorCodeJobFailedToStart
case "JOB_EXECUTION_FAILED":
return ErrorCodeJobExecutionFailed
case "JOB_CANCELLED":
return ErrorCodeJobCancelled
case "OUT_OF_MEMORY":
return ErrorCodeOutOfMemory
case "DISK_FULL":
return ErrorCodeDiskFull
case "INVALID_CONFIGURATION":
return ErrorCodeInvalidConfiguration
case "SERVICE_UNAVAILABLE":
return ErrorCodeServiceUnavailable
default:
return "Unknown error code"
return ErrorCodeUnknownError
}
}

View file

@ -27,11 +27,20 @@ import (
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
apierrors "github.com/jfraeys/fetch_ml/internal/api/errors"
"github.com/jfraeys/fetch_ml/internal/api/groups"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
)
// min returns the minimum of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}
// Response packet types (duplicated from api package to avoid import cycle)
const (
PacketTypeSuccess = 0x00
@ -86,6 +95,12 @@ const (
OpcodeRemoveMember = 0x56
OpcodeListGroupTasks = 0x57
// Task sharing opcodes
OpcodeShareTask = 0x60
OpcodeCreateOpenLink = 0x61
OpcodeListTasks = 0x62
OpcodeSetTaskVisibility = 0x63
//
OpcodeCompareRuns = 0x30
OpcodeFindRuns = 0x31
@ -93,28 +108,28 @@ const (
OpcodeSetRunOutcome = 0x33
)
// Error codes
// Error codes - using standardized error codes from errors package
const (
ErrorCodeUnknownError = 0x00
ErrorCodeInvalidRequest = 0x01
ErrorCodeAuthenticationFailed = 0x02
ErrorCodePermissionDenied = 0x03
ErrorCodeResourceNotFound = 0x04
ErrorCodeResourceAlreadyExists = 0x05
ErrorCodeServerOverloaded = 0x10
ErrorCodeDatabaseError = 0x11
ErrorCodeNetworkError = 0x12
ErrorCodeStorageError = 0x13
ErrorCodeTimeout = 0x14
ErrorCodeJobNotFound = 0x20
ErrorCodeJobAlreadyRunning = 0x21
ErrorCodeJobFailedToStart = 0x22
ErrorCodeJobExecutionFailed = 0x23
ErrorCodeJobCancelled = 0x24
ErrorCodeOutOfMemory = 0x30
ErrorCodeDiskFull = 0x31
ErrorCodeInvalidConfiguration = 0x32
ErrorCodeServiceUnavailable = 0x33
ErrorCodeUnknownError = apierrors.CodeUnknownError
ErrorCodeInvalidRequest = apierrors.CodeInvalidRequest
ErrorCodeAuthenticationFailed = apierrors.CodeAuthenticationFailed
ErrorCodePermissionDenied = apierrors.CodePermissionDenied
ErrorCodeResourceNotFound = apierrors.CodeResourceNotFound
ErrorCodeResourceAlreadyExists = apierrors.CodeResourceAlreadyExists
ErrorCodeServerOverloaded = apierrors.CodeServerOverloaded
ErrorCodeDatabaseError = apierrors.CodeDatabaseError
ErrorCodeNetworkError = apierrors.CodeNetworkError
ErrorCodeStorageError = apierrors.CodeStorageError
ErrorCodeTimeout = apierrors.CodeTimeout
ErrorCodeJobNotFound = apierrors.CodeJobNotFound
ErrorCodeJobAlreadyRunning = apierrors.CodeJobAlreadyRunning
ErrorCodeJobFailedToStart = apierrors.CodeJobFailedToStart
ErrorCodeJobExecutionFailed = apierrors.CodeJobExecutionFailed
ErrorCodeJobCancelled = apierrors.CodeJobCancelled
ErrorCodeOutOfMemory = apierrors.CodeOutOfMemory
ErrorCodeDiskFull = apierrors.CodeDiskFull
ErrorCodeInvalidConfiguration = apierrors.CodeInvalidConfiguration
ErrorCodeServiceUnavailable = apierrors.CodeServiceUnavailable
)
// Permissions
@ -369,6 +384,14 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
return h.handleRemoveMember(conn, payload)
case OpcodeListGroupTasks:
return h.handleListGroupTasks(conn, payload)
case OpcodeShareTask:
return h.handleShareTask(conn, payload)
case OpcodeCreateOpenLink:
return h.handleCreateOpenLink(conn, payload)
case OpcodeListTasks:
return h.handleListTasks(conn, payload)
case OpcodeSetTaskVisibility:
return h.handleSetTaskVisibility(conn, payload)
default:
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
}
@ -387,13 +410,12 @@ func (h *Handler) sendPacket(conn *websocket.Conn, pktType byte, sections ...[]b
return conn.WriteMessage(websocket.BinaryMessage, buf)
}
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
return h.sendPacket(conn, PacketTypeError, []byte{code}, []byte(message), []byte(details))
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code string, message, details string) error {
return apierrors.SendErrorPacket(conn, code, message, details)
}
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error {
payload, _ := json.Marshal(data)
return h.sendPacket(conn, PacketTypeSuccess, payload)
return apierrors.SendSuccessPacket(conn, data)
}
func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error {
@ -452,7 +474,7 @@ func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error
func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
if h.jupyterHandler == nil {
return h.sendSuccessPacket(conn, map[string]any{"success": true, "services": []any{}, "count": 0})
return h.sendSuccessPacket(conn, map[string]any{"services": []any{}, "count": 0})
}
return h.withAuth(conn, payload, func(user *auth.User) error {
return h.jupyterHandler.HandleListJupyter(conn, payload, user)
@ -507,6 +529,113 @@ func (h *Handler) handleListGroupTasks(conn *websocket.Conn, payload []byte) err
})
}
func (h *Handler) handleShareTask(conn *websocket.Conn, payload []byte) error {
return h.withAuth(conn, payload, func(user *auth.User) error {
// Parse payload: [api_key_hash:16][task_id_len:2][task_id:var][user_id_len:2][user_id:var][group_id_len:2][group_id:var]
if len(payload) < 16+2+2+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "share task payload too short", "")
}
offset := 16
taskIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if taskIDLen <= 0 || len(payload) < offset+taskIDLen+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid task_id length", "")
}
taskID := string(payload[offset : offset+taskIDLen])
offset += taskIDLen
userIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
var sharedUserID string
if userIDLen > 0 {
if len(payload) < offset+userIDLen+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid user_id length", "")
}
sharedUserID = string(payload[offset : offset+userIDLen])
offset += userIDLen
}
groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
var groupID string
if groupIDLen > 0 && len(payload) >= offset+groupIDLen {
groupID = string(payload[offset : offset+groupIDLen])
}
h.logger.Info("sharing task", "task_id", taskID, "user", user.Name, "shared_with", sharedUserID, "group", groupID)
return h.sendSuccessPacket(conn, map[string]any{"task_id": taskID, "message": "Task shared successfully"})
})
}
func (h *Handler) handleCreateOpenLink(conn *websocket.Conn, payload []byte) error {
return h.withAuth(conn, payload, func(user *auth.User) error {
// Parse payload: [api_key_hash:16][task_id_len:2][task_id:var][expires_days:2][max_accesses:4]
if len(payload) < 16+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "create open link payload too short", "")
}
offset := 16
taskIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if taskIDLen <= 0 || len(payload) < offset+taskIDLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid task_id length", "")
}
taskID := string(payload[offset : offset+taskIDLen])
// Generate a simple token using timestamp
token := fmt.Sprintf("tok_%d_%s", time.Now().Unix(), taskID[0:min(8, len(taskID))])
baseURL := "https://api.fetchml.local"
shareLink := fmt.Sprintf("%s/api/tasks/%s?token=%s", baseURL, taskID, token)
h.logger.Info("created open link", "task_id", taskID, "user", user.Name)
return h.sendSuccessPacket(conn, map[string]any{"token": token, "share_link": shareLink, "task_id": taskID})
})
}
func (h *Handler) handleListTasks(conn *websocket.Conn, payload []byte) error {
return h.withAuth(conn, payload, func(user *auth.User) error {
h.logger.Info("listing tasks", "user", user.Name)
// Return placeholder - would query database in production
return h.sendSuccessPacket(conn, map[string]any{"tasks": []any{}, "count": 0})
})
}
func (h *Handler) handleSetTaskVisibility(conn *websocket.Conn, payload []byte) error {
return h.withAuth(conn, payload, func(user *auth.User) error {
// Parse payload: [api_key_hash:16][task_id_len:2][task_id:var][visibility_len:1][visibility:var]
if len(payload) < 16+2+1 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set visibility payload too short", "")
}
offset := 16
taskIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if taskIDLen <= 0 || len(payload) < offset+taskIDLen+1 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid task_id length", "")
}
taskID := string(payload[offset : offset+taskIDLen])
offset += taskIDLen
visLen := int(payload[offset])
offset += 1
if visLen <= 0 || len(payload) < offset+visLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid visibility length", "")
}
visibility := string(payload[offset : offset+visLen])
h.logger.Info("setting task visibility", "task_id", taskID, "visibility", visibility, "user", user.Name)
return h.sendSuccessPacket(conn, map[string]any{"task_id": taskID, "visibility": visibility})
})
}
func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
// Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8]
if len(payload) < 16+1+8 {
@ -541,7 +670,6 @@ func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
}
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"message": "Metric logged",
"metric": name,
"value": value,
@ -589,7 +717,6 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro
manifest, _ := h.expManager.ReadManifest(commitID)
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"commit_id": commitID,
"job_name": meta.JobName,
"user": meta.User,
@ -610,7 +737,7 @@ func (h *Handler) handleDatasetList(conn *websocket.Conn, payload []byte) error
func (h *Handler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
if h.datasetsHandler == nil {
return h.sendSuccessPacket(conn, map[string]any{"success": true, "message": "Dataset registered"})
return h.sendSuccessPacket(conn, map[string]any{"message": "Dataset registered"})
}
return h.withAuth(conn, payload, func(user *auth.User) error {
return h.datasetsHandler.HandleDatasetRegister(conn, payload, user)
@ -741,9 +868,8 @@ func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error
// Build comparison result
result := map[string]any{
"run_a": runA,
"run_b": runB,
"success": true,
"run_a": runA,
"run_b": runB,
}
// Add metadata if available
@ -803,11 +929,7 @@ func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
{"id": "run_def", "job_name": "eval", "outcome": "partial"},
}
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"results": results,
"count": len(results),
})
return h.sendSuccessPacket(conn, map[string]any{"results": results, "count": len(results)})
}
// handleExportRun exports a run with optional anonymization
@ -863,12 +985,7 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
h.logger.Info("exporting run", "run_id", runID, "anonymize", anonymize, "user", user.Name)
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"run_id": runID,
"message": "Export request received",
"anonymize": anonymize,
})
return h.sendSuccessPacket(conn, map[string]any{"run_id": runID, "message": "Export request received", "anonymize": anonymize})
}
// handleSetRunOutcome sets the outcome for a run
@ -925,12 +1042,7 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"run_id": runID,
"outcome": outcome,
"message": "Outcome updated",
})
return h.sendSuccessPacket(conn, map[string]any{"run_id": runID, "outcome": outcome, "message": "Outcome updated"})
}
// handleQueryRunInfo handles run info queries from the CLI
@ -981,7 +1093,6 @@ func (h *Handler) handleQueryRunInfo(conn *websocket.Conn, payload []byte) error
"job_name": meta.JobName,
"user": meta.User,
"timestamp": meta.Timestamp,
"success": true,
}
if manifest != nil {

View file

@ -135,10 +135,7 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
}
}
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"task_id": task.ID,
})
return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID})
}
// handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17)
@ -193,10 +190,7 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt
}
}
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"task_id": task.ID,
})
return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID})
}
// handleCancelJob handles the CancelJob opcode (0x03)
@ -221,10 +215,7 @@ func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
}
}
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"message": "Job cancelled",
})
return h.sendSuccessPacket(conn, map[string]any{"message": "Job cancelled"})
}
// handlePrune handles the Prune opcode (0x04)
@ -240,9 +231,5 @@ func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error {
// pruneType := payload[offset]
// value := binary.BigEndian.Uint32(payload[offset+1 : offset+5])
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"message": "Prune completed",
"pruned": 0,
})
return h.sendSuccessPacket(conn, map[string]any{"message": "Prune completed", "pruned": 0})
}

View file

@ -253,7 +253,7 @@ func TestWebSocketConnectionResilience(t *testing.T) {
}
// Send a message
err = conn1.WriteJSON(map[string]interface{}{
err = conn1.WriteJSON(map[string]any{
"opcode": 0x02,
"data": "",
})
@ -278,7 +278,7 @@ func TestWebSocketConnectionResilience(t *testing.T) {
defer func() { _ = conn2.Close() }()
// Send message on reconnected connection
err = conn2.WriteJSON(map[string]interface{}{
err = conn2.WriteJSON(map[string]any{
"opcode": 0x02,
"data": "",
})