// Package ws provides WebSocket handling for the API package ws import ( "errors" "net/http" "net/url" "strings" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" ) // Opcodes for binary WebSocket protocol const ( OpcodeQueueJob = 0x01 OpcodeStatusRequest = 0x02 OpcodeCancelJob = 0x03 OpcodePrune = 0x04 OpcodeDatasetList = 0x06 OpcodeDatasetRegister = 0x07 OpcodeDatasetInfo = 0x08 OpcodeDatasetSearch = 0x09 OpcodeLogMetric = 0x0A OpcodeGetExperiment = 0x0B OpcodeQueueJobWithTracking = 0x0C OpcodeQueueJobWithSnapshot = 0x17 OpcodeQueueJobWithArgs = 0x1A OpcodeQueueJobWithNote = 0x1B OpcodeAnnotateRun = 0x1C OpcodeSetRunNarrative = 0x1D OpcodeStartJupyter = 0x0D OpcodeStopJupyter = 0x0E OpcodeRemoveJupyter = 0x18 OpcodeRestoreJupyter = 0x19 OpcodeListJupyter = 0x0F OpcodeListJupyterPackages = 0x1E OpcodeValidateRequest = 0x16 // Logs opcodes OpcodeGetLogs = 0x20 OpcodeStreamLogs = 0x21 ) // Error codes const ( ErrorCodeUnknownError = 0x00 ErrorCodeInvalidRequest = 0x01 ErrorCodeAuthenticationFailed = 0x02 ErrorCodePermissionDenied = 0x03 ErrorCodeResourceNotFound = 0x04 ErrorCodeResourceAlreadyExists = 0x05 ErrorCodeServerOverloaded = 0x10 ErrorCodeDatabaseError = 0x11 ErrorCodeNetworkError = 0x12 ErrorCodeStorageError = 0x13 ErrorCodeTimeout = 0x14 ErrorCodeJobNotFound = 0x20 ErrorCodeJobAlreadyRunning = 0x21 ErrorCodeJobFailedToStart = 0x22 ErrorCodeJobExecutionFailed = 0x23 ErrorCodeJobCancelled = 0x24 ErrorCodeOutOfMemory = 0x30 ErrorCodeDiskFull = 0x31 ErrorCodeInvalidConfiguration = 0x32 ErrorCodeServiceUnavailable = 0x33 ) // Permissions const ( PermJobsCreate = "jobs:create" PermJobsRead = "jobs:read" PermJobsUpdate = "jobs:update" PermDatasetsRead = "datasets:read" PermDatasetsCreate = "datasets:create" PermJupyterManage = "jupyter:manage" PermJupyterRead = "jupyter:read" ) // Handler provides WebSocket handling type Handler struct { authConfig *auth.Config logger *logging.Logger expManager *experiment.Manager dataDir string taskQueue queue.Backend db *storage.DB jupyterServiceMgr *jupyter.ServiceManager securityCfg *config.SecurityConfig auditLogger *audit.Logger upgrader websocket.Upgrader } // NewHandler creates a new WebSocket handler func NewHandler( authConfig *auth.Config, logger *logging.Logger, expManager *experiment.Manager, dataDir string, taskQueue queue.Backend, db *storage.DB, jupyterServiceMgr *jupyter.ServiceManager, securityCfg *config.SecurityConfig, auditLogger *audit.Logger, ) *Handler { upgrader := createUpgrader(securityCfg) return &Handler{ authConfig: authConfig, logger: logger, expManager: expManager, dataDir: dataDir, taskQueue: taskQueue, db: db, jupyterServiceMgr: jupyterServiceMgr, securityCfg: securityCfg, auditLogger: auditLogger, upgrader: upgrader, } } // createUpgrader creates a WebSocket upgrader with the given security configuration func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader { return websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { return true // Allow same-origin requests } // Production mode: strict checking against allowed origins if securityCfg != nil && securityCfg.ProductionMode { for _, allowed := range securityCfg.AllowedOrigins { if origin == allowed { return true } } return false // Reject if not in allowed list } // Development mode: allow localhost and local network origins parsedOrigin, err := url.Parse(origin) if err != nil { return false } host := parsedOrigin.Host if strings.HasPrefix(host, "localhost:") || strings.HasPrefix(host, "127.0.0.1:") || strings.HasPrefix(host, "192.168.") || strings.HasPrefix(host, "10.") || strings.HasPrefix(host, "[::1]:") { return true } return false }, EnableCompression: true, } } // ServeHTTP implements http.Handler for WebSocket upgrade func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := h.upgrader.Upgrade(w, r, nil) if err != nil { h.logger.Error("websocket upgrade failed", "error", err) return } defer conn.Close() h.handleConnection(conn) } // handleConnection handles an established WebSocket connection func (h *Handler) handleConnection(conn *websocket.Conn) { h.logger.Info("websocket connection established", "remote", conn.RemoteAddr()) for { messageType, payload, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { h.logger.Error("websocket read error", "error", err) } break } if messageType != websocket.BinaryMessage { h.logger.Warn("received non-binary message, ignoring") continue } if err := h.handleMessage(conn, payload); err != nil { h.logger.Error("message handling error", "error", err) // Don't break, continue handling messages } } h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr()) } // handleMessage dispatches WebSocket messages to appropriate handlers func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error { if len(payload) < 17 { // At least opcode + api_key_hash return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } opcode := payload[16] // After 16-byte API key hash switch opcode { case OpcodeAnnotateRun: return h.handleAnnotateRun(conn, payload) case OpcodeSetRunNarrative: return h.handleSetRunNarrative(conn, payload) case OpcodeStartJupyter: return h.handleStartJupyter(conn, payload) case OpcodeStopJupyter: return h.handleStopJupyter(conn, payload) case OpcodeListJupyter: return h.handleListJupyter(conn, payload) case OpcodeValidateRequest: return h.handleValidateRequest(conn, payload) default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode)) } } // sendErrorPacket sends an error response packet func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { err := map[string]interface{}{ "error": true, "code": code, "message": message, "details": details, } return conn.WriteJSON(err) } // sendSuccessPacket sends a success response packet func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error { return conn.WriteJSON(data) } // Handler stubs - these would delegate to sub-packages in full implementation func (h *Handler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error { // Would delegate to jobs package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Annotate run handled", }) } func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error { // Would delegate to jobs package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Set run narrative handled", }) } func (h *Handler) handleStartJupyter(conn *websocket.Conn, payload []byte) error { // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Start jupyter handled", }) } func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error { // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Stop jupyter handled", }) } func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error { // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "List jupyter handled", }) } func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error { // Would delegate to validate package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Validate request handled", }) } // Authenticate extracts and validates the API key from payload func (h *Handler) Authenticate(payload []byte) (*auth.User, error) { if len(payload) < 16 { return nil, errors.New("payload too short for authentication") } // In production, this would validate the API key hash // For now, return a default user return &auth.User{ Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}, }, nil } // RequirePermission checks if a user has a required permission func (h *Handler) RequirePermission(user *auth.User, permission string) bool { if user == nil { return false } if user.Admin { return true } return user.Permissions[permission] }