fetch_ml/internal/api/protocol.go
Jeremie Fraeys 188cf55939
refactor(api): overhaul WebSocket handler and protocol layer
Major WebSocket handler refactor:
- Rewrite ws/handler.go with structured message routing and backpressure
- Add connection lifecycle management with heartbeats and timeouts
- Implement graceful connection draining for zero-downtime restarts

Protocol improvements:
- Define structured protocol types in protocol.go for hub communication
- Add versioned message envelopes for backward compatibility
- Standardize error codes and response formats across WebSocket API

Job streaming via WebSocket:
- Simplify ws/jobs.go with async job status streaming
- Add compression for high-volume job updates

Testing:
- Update websocket_e2e_test.go for new protocol semantics
- Add connection resilience tests
2026-03-12 12:01:21 -04:00

382 lines
10 KiB
Go

package api
import (
"encoding/binary"
"encoding/json"
"fmt"
"sync"
"time"
)
// safeUint64FromTime returns t as whole seconds since the Unix epoch.
// Instants before the epoch (negative Unix seconds) are reported as 0 so the
// signed-to-unsigned conversion can never wrap around.
func safeUint64FromTime(t time.Time) uint64 {
	if secs := t.Unix(); secs >= 0 {
		return uint64(secs)
	}
	return 0
}
// bufferPool hands out pointers to empty byte slices that start with 256 bytes
// of capacity. A pointer is stored instead of the slice header itself so that
// boxing the value into the pool's `any` does not allocate.
var bufferPool = sync.Pool{
	New: func() any {
		scratch := make([]byte, 0, 256)
		return &scratch
	},
}
// Response packet types.
// The value is written as the first byte of every serialized packet and
// selects which type-specific fields follow the 8-byte timestamp.
const (
// Body: length-prefixed SuccessMessage.
PacketTypeSuccess = 0x00
// Body: ErrorCode byte, then length-prefixed ErrorMessage and ErrorDetails.
PacketTypeError = 0x01
// Body: ProgressType byte, ProgressValue, ProgressTotal, then ProgressMessage.
PacketTypeProgress = 0x02
// Body: length-prefixed StatusData.
PacketTypeStatus = 0x03
// Body: length-prefixed DataType, then length-prefixed DataPayload.
PacketTypeData = 0x04
// Body: LogLevel byte, then length-prefixed LogMessage.
PacketTypeLog = 0x05
)
// Error codes - byte values for compact binary wire format
// Groupings are intentional and indicate error categories:
//
// 0x00-0x05 = Generic client errors (validation, auth, permissions)
// 0x10-0x14 = Infrastructure errors (server, database, network, storage, timeout)
// 0x20-0x24 = Job lifecycle errors (not found, running, failed to start, execution failed, cancelled)
// 0x30-0x33 = Resource exhaustion errors (OOM, disk full, config, unavailable)
//
// For human-readable error codes, use the string constants from internal/api/errors.
// This package provides ByteCodeFromErrorCode() to bridge string codes to wire bytes.
const (
// Generic errors (0x00-0x05).
// ErrorCodeUnknownError is also what ByteCodeFromErrorCode returns for any
// string code it does not recognize.
ErrorCodeUnknownError = 0x00
ErrorCodeInvalidRequest = 0x01
ErrorCodeAuthenticationFailed = 0x02
ErrorCodePermissionDenied = 0x03
ErrorCodeResourceNotFound = 0x04
ErrorCodeResourceAlreadyExists = 0x05
// Infrastructure errors (0x10-0x14).
ErrorCodeServerOverloaded = 0x10
ErrorCodeDatabaseError = 0x11
ErrorCodeNetworkError = 0x12
ErrorCodeStorageError = 0x13
ErrorCodeTimeout = 0x14
// Job lifecycle errors (0x20-0x24).
ErrorCodeJobNotFound = 0x20
ErrorCodeJobAlreadyRunning = 0x21
ErrorCodeJobFailedToStart = 0x22
ErrorCodeJobExecutionFailed = 0x23
ErrorCodeJobCancelled = 0x24
// Resource exhaustion errors (0x30-0x33).
ErrorCodeOutOfMemory = 0x30
ErrorCodeDiskFull = 0x31
ErrorCodeInvalidConfiguration = 0x32
ErrorCodeServiceUnavailable = 0x33
)
// Progress types.
// The value travels in the ProgressType byte of a progress packet. The
// serializer in this file writes it verbatim and does not interpret it, so the
// meaning of ProgressValue/ProgressTotal for each type is up to the producer
// and consumer of the packet.
const (
ProgressTypePercentage = 0x00
ProgressTypeStage = 0x01
ProgressTypeMessage = 0x02
ProgressTypeBytesTransferred = 0x03
)
// Log levels.
// The value travels in the LogLevel byte of a log packet; GetLogLevelName maps
// each of these values to its upper-case name.
const (
LogLevelDebug = 0x00
LogLevelInfo = 0x01
LogLevelWarn = 0x02
LogLevelError = 0x03
)
// ResponsePacket represents a structured response packet.
//
// PacketType selects which of the remaining fields Serialize writes; fields
// that belong to a different packet type are ignored during serialization.
// String fields are written with a 16-bit length prefix and DataPayload with a
// 32-bit length prefix, all big-endian.
type ResponsePacket struct {
// DataType labels the payload; written for PacketTypeData.
DataType string
// SuccessMessage is written for PacketTypeSuccess only.
SuccessMessage string
// LogMessage is written for PacketTypeLog, after LogLevel.
LogMessage string
// ErrorMessage is written for PacketTypeError, after ErrorCode.
ErrorMessage string
// ErrorDetails is written for PacketTypeError, after ErrorMessage.
ErrorDetails string
// ProgressMessage is written last for PacketTypeProgress.
ProgressMessage string
// StatusData is written for PacketTypeStatus.
StatusData string
// DataPayload is the raw body written for PacketTypeData, after DataType.
DataPayload []byte
// Timestamp is written for every packet as 8 big-endian bytes. The
// constructors in this file set it to the current Unix time in seconds.
Timestamp uint64
// ProgressValue and ProgressTotal are written for PacketTypeProgress as
// 4 big-endian bytes each.
ProgressValue uint32
ProgressTotal uint32
// ErrorCode is one of the ErrorCode* wire bytes; written for PacketTypeError.
ErrorCode byte
// ProgressType is one of the ProgressType* values; written for PacketTypeProgress.
ProgressType byte
// LogLevel is one of the LogLevel* values; written for PacketTypeLog.
LogLevel byte
// PacketType is one of the PacketType* values and is always the first byte
// on the wire.
PacketType byte
}
// NewSuccessPacket builds a PacketTypeSuccess packet that carries message and
// is stamped with the current Unix time in seconds.
func NewSuccessPacket(message string) *ResponsePacket {
	pkt := new(ResponsePacket)
	pkt.PacketType = PacketTypeSuccess
	pkt.Timestamp = safeUint64FromTime(time.Now())
	pkt.SuccessMessage = message
	return pkt
}
// NewSuccessPacketWithPayload JSON-encodes payload and wraps it in a data
// packet (PacketTypeData) whose DataType is "status".
//
// message is stored in SuccessMessage, but the serializer only writes DataType
// and DataPayload for data packets, so message is not sent on the wire.
//
// If payload cannot be encoded as JSON, the function returns an error packet
// (PacketTypeError, ErrorCodeUnknownError) carrying the encoder's error text
// instead of a data packet with an empty payload.
func NewSuccessPacketWithPayload(message string, payload any) *ResponsePacket {
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		// Surface the failure to the peer rather than silently sending an
		// empty, undecodable payload.
		return &ResponsePacket{
			PacketType:   PacketTypeError,
			Timestamp:    safeUint64FromTime(time.Now()),
			ErrorCode:    ErrorCodeUnknownError,
			ErrorMessage: "failed to encode response payload",
			ErrorDetails: err.Error(),
		}
	}
	return &ResponsePacket{
		PacketType:     PacketTypeData,
		Timestamp:      safeUint64FromTime(time.Now()),
		SuccessMessage: message,
		DataType:       "status",
		DataPayload:    payloadBytes,
	}
}
// NewErrorPacket builds a PacketTypeError packet.
// errorCode is a string error code (see internal/api/errors); it is translated
// to its wire byte with ByteCodeFromErrorCode. message and details are stored
// unchanged, and the packet is stamped with the current Unix time in seconds.
func NewErrorPacket(errorCode string, message string, details string) *ResponsePacket {
	pkt := new(ResponsePacket)
	pkt.PacketType = PacketTypeError
	pkt.Timestamp = safeUint64FromTime(time.Now())
	pkt.ErrorCode = ByteCodeFromErrorCode(errorCode)
	pkt.ErrorMessage = message
	pkt.ErrorDetails = details
	return pkt
}
// NewProgressPacket builds a PacketTypeProgress packet from the given progress
// type byte, current value, total and message, stamped with the current Unix
// time in seconds.
func NewProgressPacket(
	progressType byte,
	value uint32,
	total uint32,
	message string,
) *ResponsePacket {
	pkt := new(ResponsePacket)
	pkt.PacketType = PacketTypeProgress
	pkt.Timestamp = safeUint64FromTime(time.Now())
	pkt.ProgressType = progressType
	pkt.ProgressValue = value
	pkt.ProgressTotal = total
	pkt.ProgressMessage = message
	return pkt
}
// NewStatusPacket builds a PacketTypeStatus packet whose StatusData is data,
// stamped with the current Unix time in seconds.
func NewStatusPacket(data string) *ResponsePacket {
	pkt := new(ResponsePacket)
	pkt.PacketType = PacketTypeStatus
	pkt.Timestamp = safeUint64FromTime(time.Now())
	pkt.StatusData = data
	return pkt
}
// NewDataPacket builds a PacketTypeData packet labelled dataType that carries
// payload, stamped with the current Unix time in seconds. The payload slice is
// stored as given, not copied.
func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
	pkt := new(ResponsePacket)
	pkt.PacketType = PacketTypeData
	pkt.Timestamp = safeUint64FromTime(time.Now())
	pkt.DataType = dataType
	pkt.DataPayload = payload
	return pkt
}
// NewLogPacket builds a PacketTypeLog packet with the given level byte and
// message, stamped with the current Unix time in seconds.
func NewLogPacket(level byte, message string) *ResponsePacket {
	pkt := new(ResponsePacket)
	pkt.PacketType = PacketTypeLog
	pkt.Timestamp = safeUint64FromTime(time.Now())
	pkt.LogLevel = level
	pkt.LogMessage = message
	return pkt
}
// Serialize converts the packet to its binary wire format.
//
// Layout: one PacketType byte, an 8-byte big-endian timestamp, then the
// type-specific fields. It returns an error if PacketType is not one of the
// PacketType* constants.
//
// A fresh buffer sized by estimatedSize is allocated on every call. The
// returned slice is handed to the caller, so it must not be backed by pooled
// memory: a later call could otherwise overwrite bytes the caller still holds.
func (p *ResponsePacket) Serialize() ([]byte, error) {
	// Compute the size once; it is used purely as the initial capacity.
	buf := make([]byte, 0, p.estimatedSize())
	return serializePacketToBuffer(p, buf)
}
// serializePacketToBuffer appends the wire encoding of p to buf and returns the
// extended slice. Every packet starts with the type byte and an 8-byte
// big-endian timestamp; the fields that follow depend on the packet type.
// An unrecognized packet type yields a nil slice and an error.
func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) {
	// Common header shared by all packet types.
	buf = append(buf, p.PacketType)
	buf = binary.BigEndian.AppendUint64(buf, p.Timestamp)

	// Type-specific body.
	switch kind := p.PacketType; kind {
	case PacketTypeSuccess:
		return appendString(buf, p.SuccessMessage), nil
	case PacketTypeError:
		buf = append(buf, p.ErrorCode)
		buf = appendString(buf, p.ErrorMessage)
		return appendString(buf, p.ErrorDetails), nil
	case PacketTypeProgress:
		buf = append(buf, p.ProgressType)
		buf = appendUint32(buf, p.ProgressValue)
		buf = appendUint32(buf, p.ProgressTotal)
		return appendString(buf, p.ProgressMessage), nil
	case PacketTypeStatus:
		return appendString(buf, p.StatusData), nil
	case PacketTypeData:
		buf = appendString(buf, p.DataType)
		return appendBytes(buf, p.DataPayload), nil
	case PacketTypeLog:
		buf = append(buf, p.LogLevel)
		return appendString(buf, p.LogMessage), nil
	default:
		return nil, fmt.Errorf("unknown packet type: %d", kind)
	}
}
// uint16ToBytes splits v into its big-endian (high, low) byte pair.
func uint16ToBytes(v uint16) (high, low byte) {
	var b [2]byte
	binary.BigEndian.PutUint16(b[:], v)
	return b[0], b[1]
}

// appendString appends s to buf as a 16-bit big-endian length prefix followed
// by the string bytes, and returns the extended slice.
//
// The prefix can describe at most 65535 bytes. A longer string is truncated to
// that many bytes so that the prefix always matches the data that follows;
// writing the full string behind a clamped prefix would leave the remaining
// bytes to be misread as the next field. Truncation is byte-wise and may split
// a multi-byte UTF-8 sequence.
func appendString(buf []byte, s string) []byte {
	const maxLen = 1<<16 - 1
	if len(s) > maxLen {
		s = s[:maxLen]
	}
	// #nosec G115 -- len(s) <= maxLen after the truncation above, safe conversion
	high, low := uint16ToBytes(uint16(len(s)))
	buf = append(buf, high, low)
	return append(buf, s...)
}
// uint32ToBytes returns the 4-byte big-endian encoding of v.
func uint32ToBytes(v uint32) [4]byte {
	var b [4]byte
	binary.BigEndian.PutUint32(b[:], v)
	return b
}

// appendBytes appends b to buf as a 32-bit big-endian length prefix followed by
// the bytes themselves, and returns the extended slice.
//
// The prefix can describe at most 2^32-1 bytes; a longer slice is truncated to
// that length so the prefix always matches the data that follows. The bound is
// checked in uint64 so the code also compiles where int is 32 bits wide (an
// untyped 4294967295 cannot be converted to a 32-bit int).
func appendBytes(buf []byte, b []byte) []byte {
	const maxLen = 1<<32 - 1
	n := uint64(len(b))
	if n > maxLen {
		n = maxLen
		b = b[:n]
	}
	// #nosec G115 -- n <= maxLen (max uint32) after the clamp above, safe conversion
	prefix := uint32ToBytes(uint32(n))
	buf = append(buf, prefix[:]...)
	return append(buf, b...)
}
// appendUint32 appends value to buf as 4 big-endian bytes and returns the
// extended slice.
func appendUint32(buf []byte, value uint32) []byte {
	return binary.BigEndian.AppendUint32(buf, value)
}
// estimatedSize returns the number of bytes the packet is expected to occupy
// once serialized: the fixed header plus, per packet type, the single-byte
// fields, the length prefixes and the current lengths of the variable fields.
// An unknown packet type counts only the header.
func (p *ResponsePacket) estimatedSize() int {
	const (
		headerLen      = 1 + 8 // packet type + timestamp
		strPrefixLen   = 2     // 16-bit string length prefix
		bytesPrefixLen = 4     // 32-bit payload length prefix
		uint32Len      = 4
	)

	size := headerLen
	switch p.PacketType {
	case PacketTypeSuccess:
		size += strPrefixLen + len(p.SuccessMessage)
	case PacketTypeError:
		size += 1 + strPrefixLen + len(p.ErrorMessage) + strPrefixLen + len(p.ErrorDetails)
	case PacketTypeProgress:
		size += 1 + uint32Len + uint32Len + strPrefixLen + len(p.ProgressMessage)
	case PacketTypeStatus:
		size += strPrefixLen + len(p.StatusData)
	case PacketTypeData:
		size += strPrefixLen + len(p.DataType) + bytesPrefixLen + len(p.DataPayload)
	case PacketTypeLog:
		size += 1 + strPrefixLen + len(p.LogMessage)
	}
	return size
}
// ByteCodeFromErrorCode converts string error codes from internal/api/errors to wire format bytes
// This is the single mapping point between human-readable string codes and compact binary codes
func ByteCodeFromErrorCode(code string) byte {
switch code {
case "INVALID_REQUEST", "BAD_REQUEST":
return ErrorCodeInvalidRequest
case "AUTHENTICATION_FAILED":
return ErrorCodeAuthenticationFailed
case "PERMISSION_DENIED", "FORBIDDEN":
return ErrorCodePermissionDenied
case "RESOURCE_NOT_FOUND", "NOT_FOUND":
return ErrorCodeResourceNotFound
case "RESOURCE_ALREADY_EXISTS":
return ErrorCodeResourceAlreadyExists
case "SERVER_OVERLOADED":
return ErrorCodeServerOverloaded
case "DATABASE_ERROR":
return ErrorCodeDatabaseError
case "NETWORK_ERROR":
return ErrorCodeNetworkError
case "STORAGE_ERROR":
return ErrorCodeStorageError
case "TIMEOUT":
return ErrorCodeTimeout
case "JOB_NOT_FOUND":
return ErrorCodeJobNotFound
case "JOB_ALREADY_RUNNING":
return ErrorCodeJobAlreadyRunning
case "JOB_FAILED_TO_START":
return ErrorCodeJobFailedToStart
case "JOB_EXECUTION_FAILED":
return ErrorCodeJobExecutionFailed
case "JOB_CANCELLED":
return ErrorCodeJobCancelled
case "OUT_OF_MEMORY":
return ErrorCodeOutOfMemory
case "DISK_FULL":
return ErrorCodeDiskFull
case "INVALID_CONFIGURATION":
return ErrorCodeInvalidConfiguration
case "SERVICE_UNAVAILABLE":
return ErrorCodeServiceUnavailable
default:
return ErrorCodeUnknownError
}
}
// GetLogLevelName returns the upper-case name of a log level byte:
// "DEBUG", "INFO", "WARN" or "ERROR". Any other value yields "UNKNOWN".
func GetLogLevelName(level byte) string {
	// Table indexed by the LogLevel* constants.
	names := [...]string{
		LogLevelDebug: "DEBUG",
		LogLevelInfo:  "INFO",
		LogLevelWarn:  "WARN",
		LogLevelError: "ERROR",
	}
	if int(level) < len(names) {
		return names[level]
	}
	return "UNKNOWN"
}