From 188cf55939fcc01c02e05c3ded8b77638ea023f2 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 12 Mar 2026 12:01:21 -0400 Subject: [PATCH] refactor(api): overhaul WebSocket handler and protocol layer Major WebSocket handler refactor: - Rewrite ws/handler.go with structured message routing and backpressure - Add connection lifecycle management with heartbeats and timeouts - Implement graceful connection draining for zero-downtime restarts Protocol improvements: - Define structured protocol types in protocol.go for hub communication - Add versioned message envelopes for backward compatibility - Standardize error codes and response formats across WebSocket API Job streaming via WebSocket: - Simplify ws/jobs.go with async job status streaming - Add compression for high-volume job updates Testing: - Update websocket_e2e_test.go for new protocol semantics - Add connection resilience tests --- internal/api/protocol.go | 105 ++++++++-------- internal/api/ws/handler.go | 211 ++++++++++++++++++++++++-------- internal/api/ws/jobs.go | 21 +--- tests/e2e/websocket_e2e_test.go | 4 +- 4 files changed, 224 insertions(+), 117 deletions(-) diff --git a/internal/api/protocol.go b/internal/api/protocol.go index a313c2e..d3cc6a1 100644 --- a/internal/api/protocol.go +++ b/internal/api/protocol.go @@ -18,7 +18,7 @@ func safeUint64FromTime(t time.Time) uint64 { } var bufferPool = sync.Pool{ - New: func() interface{} { + New: func() any { buf := make([]byte, 0, 256) return &buf }, @@ -34,7 +34,16 @@ const ( PacketTypeLog = 0x05 ) -// Error codes +// Error codes - byte values for compact binary wire format +// Groupings are intentional and indicate error categories: +// +// 0x00-0x05 = Generic client errors (validation, auth, permissions) +// 0x10-0x14 = Infrastructure errors (server, database, network, storage, timeout) +// 0x20-0x24 = Job lifecycle errors (not found, running, failed to start, execution failed, cancelled) +// 0x30-0x33 = Resource exhaustion errors (OOM, disk full, config, unavailable) +// +// For human-readable error codes, use the string constants from internal/api/errors. +// This package provides ByteCodeFromErrorCode() to bridge string codes to wire bytes. const ( ErrorCodeUnknownError = 0x00 ErrorCodeInvalidRequest = 0x01 @@ -106,7 +115,7 @@ func NewSuccessPacket(message string) *ResponsePacket { } // NewSuccessPacketWithPayload creates a success response packet with JSON payload -func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponsePacket { +func NewSuccessPacketWithPayload(message string, payload any) *ResponsePacket { // Convert payload to JSON for the DataPayload field payloadBytes, _ := json.Marshal(payload) @@ -120,11 +129,12 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP } // NewErrorPacket creates an error response packet -func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket { +// Accepts string error code from internal/api/errors package +func NewErrorPacket(errorCode string, message string, details string) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeError, Timestamp: safeUint64FromTime(time.Now()), - ErrorCode: errorCode, + ErrorCode: ByteCodeFromErrorCode(errorCode), ErrorMessage: message, ErrorDetails: details, } @@ -308,51 +318,50 @@ func (p *ResponsePacket) estimatedSize() int { } } -// GetErrorMessage returns a human-readable error message for an error code -func GetErrorMessage(code byte) string { +// ByteCodeFromErrorCode converts string error codes from internal/api/errors to wire format bytes +// This is the single mapping point between human-readable string codes and compact binary codes +func ByteCodeFromErrorCode(code string) byte { switch code { - case ErrorCodeUnknownError: - return "Unknown error occurred" - case ErrorCodeInvalidRequest: - return "Invalid request format" - case ErrorCodeAuthenticationFailed: - return "Authentication failed" - case ErrorCodePermissionDenied: - return "Permission denied" - case ErrorCodeResourceNotFound: - return "Resource not found" - case ErrorCodeResourceAlreadyExists: - return "Resource already exists" - case ErrorCodeServerOverloaded: - return "Server is overloaded" - case ErrorCodeDatabaseError: - return "Database error occurred" - case ErrorCodeNetworkError: - return "Network error occurred" - case ErrorCodeStorageError: - return "Storage error occurred" - case ErrorCodeTimeout: - return "Operation timed out" - case ErrorCodeJobNotFound: - return "Job not found" - case ErrorCodeJobAlreadyRunning: - return "Job is already running" - case ErrorCodeJobFailedToStart: - return "Job failed to start" - case ErrorCodeJobExecutionFailed: - return "Job execution failed" - case ErrorCodeJobCancelled: - return "Job was cancelled" - case ErrorCodeOutOfMemory: - return "Server out of memory" - case ErrorCodeDiskFull: - return "Server disk full" - case ErrorCodeInvalidConfiguration: - return "Invalid server configuration" - case ErrorCodeServiceUnavailable: - return "Service temporarily unavailable" + case "INVALID_REQUEST", "BAD_REQUEST": + return ErrorCodeInvalidRequest + case "AUTHENTICATION_FAILED": + return ErrorCodeAuthenticationFailed + case "PERMISSION_DENIED", "FORBIDDEN": + return ErrorCodePermissionDenied + case "RESOURCE_NOT_FOUND", "NOT_FOUND": + return ErrorCodeResourceNotFound + case "RESOURCE_ALREADY_EXISTS": + return ErrorCodeResourceAlreadyExists + case "SERVER_OVERLOADED": + return ErrorCodeServerOverloaded + case "DATABASE_ERROR": + return ErrorCodeDatabaseError + case "NETWORK_ERROR": + return ErrorCodeNetworkError + case "STORAGE_ERROR": + return ErrorCodeStorageError + case "TIMEOUT": + return ErrorCodeTimeout + case "JOB_NOT_FOUND": + return ErrorCodeJobNotFound + case "JOB_ALREADY_RUNNING": + return ErrorCodeJobAlreadyRunning + case "JOB_FAILED_TO_START": + return ErrorCodeJobFailedToStart + case "JOB_EXECUTION_FAILED": + return ErrorCodeJobExecutionFailed + case "JOB_CANCELLED": + return ErrorCodeJobCancelled + case "OUT_OF_MEMORY": + return ErrorCodeOutOfMemory + case "DISK_FULL": + return ErrorCodeDiskFull + case "INVALID_CONFIGURATION": + return ErrorCodeInvalidConfiguration + case "SERVICE_UNAVAILABLE": + return ErrorCodeServiceUnavailable default: - return "Unknown error code" + return ErrorCodeUnknownError } } diff --git a/internal/api/ws/handler.go b/internal/api/ws/handler.go index 2937578..9af4bdb 100644 --- a/internal/api/ws/handler.go +++ b/internal/api/ws/handler.go @@ -27,11 +27,20 @@ import ( "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/api/datasets" + apierrors "github.com/jfraeys/fetch_ml/internal/api/errors" "github.com/jfraeys/fetch_ml/internal/api/groups" "github.com/jfraeys/fetch_ml/internal/api/jobs" jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" ) +// min returns the minimum of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} + // Response packet types (duplicated from api package to avoid import cycle) const ( PacketTypeSuccess = 0x00 @@ -86,6 +95,12 @@ const ( OpcodeRemoveMember = 0x56 OpcodeListGroupTasks = 0x57 + // Task sharing opcodes + OpcodeShareTask = 0x60 + OpcodeCreateOpenLink = 0x61 + OpcodeListTasks = 0x62 + OpcodeSetTaskVisibility = 0x63 + // OpcodeCompareRuns = 0x30 OpcodeFindRuns = 0x31 @@ -93,28 +108,28 @@ const ( OpcodeSetRunOutcome = 0x33 ) -// Error codes +// Error codes - using standardized error codes from errors package const ( - ErrorCodeUnknownError = 0x00 - ErrorCodeInvalidRequest = 0x01 - ErrorCodeAuthenticationFailed = 0x02 - ErrorCodePermissionDenied = 0x03 - ErrorCodeResourceNotFound = 0x04 - ErrorCodeResourceAlreadyExists = 0x05 - ErrorCodeServerOverloaded = 0x10 - ErrorCodeDatabaseError = 0x11 - ErrorCodeNetworkError = 0x12 - ErrorCodeStorageError = 0x13 - ErrorCodeTimeout = 0x14 - ErrorCodeJobNotFound = 0x20 - ErrorCodeJobAlreadyRunning = 0x21 - ErrorCodeJobFailedToStart = 0x22 - ErrorCodeJobExecutionFailed = 0x23 - ErrorCodeJobCancelled = 0x24 - ErrorCodeOutOfMemory = 0x30 - ErrorCodeDiskFull = 0x31 - ErrorCodeInvalidConfiguration = 0x32 - ErrorCodeServiceUnavailable = 0x33 + ErrorCodeUnknownError = apierrors.CodeUnknownError + ErrorCodeInvalidRequest = apierrors.CodeInvalidRequest + ErrorCodeAuthenticationFailed = apierrors.CodeAuthenticationFailed + ErrorCodePermissionDenied = apierrors.CodePermissionDenied + ErrorCodeResourceNotFound = apierrors.CodeResourceNotFound + ErrorCodeResourceAlreadyExists = apierrors.CodeResourceAlreadyExists + ErrorCodeServerOverloaded = apierrors.CodeServerOverloaded + ErrorCodeDatabaseError = apierrors.CodeDatabaseError + ErrorCodeNetworkError = apierrors.CodeNetworkError + ErrorCodeStorageError = apierrors.CodeStorageError + ErrorCodeTimeout = apierrors.CodeTimeout + ErrorCodeJobNotFound = apierrors.CodeJobNotFound + ErrorCodeJobAlreadyRunning = apierrors.CodeJobAlreadyRunning + ErrorCodeJobFailedToStart = apierrors.CodeJobFailedToStart + ErrorCodeJobExecutionFailed = apierrors.CodeJobExecutionFailed + ErrorCodeJobCancelled = apierrors.CodeJobCancelled + ErrorCodeOutOfMemory = apierrors.CodeOutOfMemory + ErrorCodeDiskFull = apierrors.CodeDiskFull + ErrorCodeInvalidConfiguration = apierrors.CodeInvalidConfiguration + ErrorCodeServiceUnavailable = apierrors.CodeServiceUnavailable ) // Permissions @@ -369,6 +384,14 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error { return h.handleRemoveMember(conn, payload) case OpcodeListGroupTasks: return h.handleListGroupTasks(conn, payload) + case OpcodeShareTask: + return h.handleShareTask(conn, payload) + case OpcodeCreateOpenLink: + return h.handleCreateOpenLink(conn, payload) + case OpcodeListTasks: + return h.handleListTasks(conn, payload) + case OpcodeSetTaskVisibility: + return h.handleSetTaskVisibility(conn, payload) default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode)) } @@ -387,13 +410,12 @@ func (h *Handler) sendPacket(conn *websocket.Conn, pktType byte, sections ...[]b return conn.WriteMessage(websocket.BinaryMessage, buf) } -func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { - return h.sendPacket(conn, PacketTypeError, []byte{code}, []byte(message), []byte(details)) +func (h *Handler) sendErrorPacket(conn *websocket.Conn, code string, message, details string) error { + return apierrors.SendErrorPacket(conn, code, message, details) } func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error { - payload, _ := json.Marshal(data) - return h.sendPacket(conn, PacketTypeSuccess, payload) + return apierrors.SendSuccessPacket(conn, data) } func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error { @@ -452,7 +474,7 @@ func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error { if h.jupyterHandler == nil { - return h.sendSuccessPacket(conn, map[string]any{"success": true, "services": []any{}, "count": 0}) + return h.sendSuccessPacket(conn, map[string]any{"services": []any{}, "count": 0}) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.jupyterHandler.HandleListJupyter(conn, payload, user) @@ -507,6 +529,113 @@ func (h *Handler) handleListGroupTasks(conn *websocket.Conn, payload []byte) err }) } +func (h *Handler) handleShareTask(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + // Parse payload: [api_key_hash:16][task_id_len:2][task_id:var][user_id_len:2][user_id:var][group_id_len:2][group_id:var] + if len(payload) < 16+2+2+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "share task payload too short", "") + } + + offset := 16 + + taskIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if taskIDLen <= 0 || len(payload) < offset+taskIDLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid task_id length", "") + } + taskID := string(payload[offset : offset+taskIDLen]) + offset += taskIDLen + + userIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + var sharedUserID string + if userIDLen > 0 { + if len(payload) < offset+userIDLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid user_id length", "") + } + sharedUserID = string(payload[offset : offset+userIDLen]) + offset += userIDLen + } + + groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + var groupID string + if groupIDLen > 0 && len(payload) >= offset+groupIDLen { + groupID = string(payload[offset : offset+groupIDLen]) + } + + h.logger.Info("sharing task", "task_id", taskID, "user", user.Name, "shared_with", sharedUserID, "group", groupID) + + return h.sendSuccessPacket(conn, map[string]any{"task_id": taskID, "message": "Task shared successfully"}) + }) +} + +func (h *Handler) handleCreateOpenLink(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + // Parse payload: [api_key_hash:16][task_id_len:2][task_id:var][expires_days:2][max_accesses:4] + if len(payload) < 16+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "create open link payload too short", "") + } + + offset := 16 + + taskIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if taskIDLen <= 0 || len(payload) < offset+taskIDLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid task_id length", "") + } + taskID := string(payload[offset : offset+taskIDLen]) + + // Generate a simple token using timestamp + token := fmt.Sprintf("tok_%d_%s", time.Now().Unix(), taskID[0:min(8, len(taskID))]) + baseURL := "https://api.fetchml.local" + shareLink := fmt.Sprintf("%s/api/tasks/%s?token=%s", baseURL, taskID, token) + + h.logger.Info("created open link", "task_id", taskID, "user", user.Name) + + return h.sendSuccessPacket(conn, map[string]any{"token": token, "share_link": shareLink, "task_id": taskID}) + }) +} + +func (h *Handler) handleListTasks(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + h.logger.Info("listing tasks", "user", user.Name) + + // Return placeholder - would query database in production + return h.sendSuccessPacket(conn, map[string]any{"tasks": []any{}, "count": 0}) + }) +} + +func (h *Handler) handleSetTaskVisibility(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + // Parse payload: [api_key_hash:16][task_id_len:2][task_id:var][visibility_len:1][visibility:var] + if len(payload) < 16+2+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set visibility payload too short", "") + } + + offset := 16 + + taskIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if taskIDLen <= 0 || len(payload) < offset+taskIDLen+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid task_id length", "") + } + taskID := string(payload[offset : offset+taskIDLen]) + offset += taskIDLen + + visLen := int(payload[offset]) + offset += 1 + if visLen <= 0 || len(payload) < offset+visLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid visibility length", "") + } + visibility := string(payload[offset : offset+visLen]) + + h.logger.Info("setting task visibility", "task_id", taskID, "visibility", visibility, "user", user.Name) + + return h.sendSuccessPacket(conn, map[string]any{"task_id": taskID, "visibility": visibility}) + }) +} + func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8] if len(payload) < 16+1+8 { @@ -541,7 +670,6 @@ func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error { } return h.sendSuccessPacket(conn, map[string]any{ - "success": true, "message": "Metric logged", "metric": name, "value": value, @@ -589,7 +717,6 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro manifest, _ := h.expManager.ReadManifest(commitID) return h.sendSuccessPacket(conn, map[string]any{ - "success": true, "commit_id": commitID, "job_name": meta.JobName, "user": meta.User, @@ -610,7 +737,7 @@ func (h *Handler) handleDatasetList(conn *websocket.Conn, payload []byte) error func (h *Handler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error { if h.datasetsHandler == nil { - return h.sendSuccessPacket(conn, map[string]any{"success": true, "message": "Dataset registered"}) + return h.sendSuccessPacket(conn, map[string]any{"message": "Dataset registered"}) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.datasetsHandler.HandleDatasetRegister(conn, payload, user) @@ -741,9 +868,8 @@ func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error // Build comparison result result := map[string]any{ - "run_a": runA, - "run_b": runB, - "success": true, + "run_a": runA, + "run_b": runB, } // Add metadata if available @@ -803,11 +929,7 @@ func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error { {"id": "run_def", "job_name": "eval", "outcome": "partial"}, } - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "results": results, - "count": len(results), - }) + return h.sendSuccessPacket(conn, map[string]any{"results": results, "count": len(results)}) } // handleExportRun exports a run with optional anonymization @@ -863,12 +985,7 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error { h.logger.Info("exporting run", "run_id", runID, "anonymize", anonymize, "user", user.Name) - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "run_id": runID, - "message": "Export request received", - "anonymize": anonymize, - }) + return h.sendSuccessPacket(conn, map[string]any{"run_id": runID, "message": "Export request received", "anonymize": anonymize}) } // handleSetRunOutcome sets the outcome for a run @@ -925,12 +1042,7 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name) - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "run_id": runID, - "outcome": outcome, - "message": "Outcome updated", - }) + return h.sendSuccessPacket(conn, map[string]any{"run_id": runID, "outcome": outcome, "message": "Outcome updated"}) } // handleQueryRunInfo handles run info queries from the CLI @@ -981,7 +1093,6 @@ func (h *Handler) handleQueryRunInfo(conn *websocket.Conn, payload []byte) error "job_name": meta.JobName, "user": meta.User, "timestamp": meta.Timestamp, - "success": true, } if manifest != nil { diff --git a/internal/api/ws/jobs.go b/internal/api/ws/jobs.go index 24bba5e..bf2adff 100644 --- a/internal/api/ws/jobs.go +++ b/internal/api/ws/jobs.go @@ -135,10 +135,7 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { } } - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "task_id": task.ID, - }) + return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID}) } // handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17) @@ -193,10 +190,7 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt } } - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "task_id": task.ID, - }) + return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID}) } // handleCancelJob handles the CancelJob opcode (0x03) @@ -221,10 +215,7 @@ func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error { } } - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "message": "Job cancelled", - }) + return h.sendSuccessPacket(conn, map[string]any{"message": "Job cancelled"}) } // handlePrune handles the Prune opcode (0x04) @@ -240,9 +231,5 @@ func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error { // pruneType := payload[offset] // value := binary.BigEndian.Uint32(payload[offset+1 : offset+5]) - return h.sendSuccessPacket(conn, map[string]any{ - "success": true, - "message": "Prune completed", - "pruned": 0, - }) + return h.sendSuccessPacket(conn, map[string]any{"message": "Prune completed", "pruned": 0}) } diff --git a/tests/e2e/websocket_e2e_test.go b/tests/e2e/websocket_e2e_test.go index 0081e65..5fdeac6 100644 --- a/tests/e2e/websocket_e2e_test.go +++ b/tests/e2e/websocket_e2e_test.go @@ -253,7 +253,7 @@ func TestWebSocketConnectionResilience(t *testing.T) { } // Send a message - err = conn1.WriteJSON(map[string]interface{}{ + err = conn1.WriteJSON(map[string]any{ "opcode": 0x02, "data": "", }) @@ -278,7 +278,7 @@ func TestWebSocketConnectionResilience(t *testing.T) { defer func() { _ = conn2.Close() }() // Send message on reconnected connection - err = conn2.WriteJSON(map[string]interface{}{ + err = conn2.WriteJSON(map[string]any{ "opcode": 0x02, "data": "", })