package api

import (
	"compress/flate"
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
)

// Opcodes for the binary WebSocket protocol. The first byte of every binary
// message carries one of these values and selects the handler; the remaining
// bytes are the payload. Values are part of the wire format and must not be
// renumbered.
const (
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeDatasetList          = 0x06
	OpcodeDatasetRegister      = 0x07
	OpcodeDatasetInfo          = 0x08
	OpcodeDatasetSearch        = 0x09
	OpcodeLogMetric            = 0x0A
	OpcodeGetExperiment        = 0x0B
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B
	OpcodeAnnotateRun          = 0x1C
	OpcodeSetRunNarrative      = 0x1D
	OpcodeStartJupyter         = 0x0D
	OpcodeStopJupyter          = 0x0E
	OpcodeRemoveJupyter        = 0x18
	OpcodeRestoreJupyter       = 0x19
	OpcodeListJupyter          = 0x0F
	OpcodeListJupyterPackages  = 0x1E
	OpcodeValidateRequest      = 0x16
)

// createUpgrader creates a WebSocket upgrader with the given security
// configuration.
//
// The CheckOrigin policy is:
//   - A request without an Origin header is accepted.
//   - If securityCfg is non-nil and ProductionMode is set, the Origin header
//     must equal one of securityCfg.AllowedOrigins exactly.
//   - Otherwise (development mode, including a nil securityCfg) the origin
//     must satisfy isLocalDevOrigin.
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
	return websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			origin := r.Header.Get("Origin")
			if origin == "" {
				// No Origin header present; accepted unconditionally.
				return true
			}

			// Production mode: strict checking against allowed origins.
			if securityCfg != nil && securityCfg.ProductionMode {
				for _, allowed := range securityCfg.AllowedOrigins {
					if origin == allowed {
						return true
					}
				}
				return false // Reject if not in allowed list.
			}

			// Development mode: allow localhost and local network origins.
			return isLocalDevOrigin(origin)
		},
		// Performance optimizations
		HandshakeTimeout:  10 * time.Second,
		ReadBufferSize:    16 * 1024,
		WriteBufferSize:   16 * 1024,
		EnableCompression: true,
	}
}

// isLocalDevOrigin reports whether origin is acceptable in development mode.
//
// An origin is accepted when any of the following holds:
//   - its port is 8080 (any host);
//   - its hostname is "localhost" (case-insensitive, with or without a port);
//   - its hostname is an IP literal that is a loopback address or a private
//     address (RFC 1918 for IPv4, RFC 4193 for IPv6).
//
// The hostname is parsed rather than prefix-matched so that DNS names such as
// "10.example.com" and public addresses such as 172.217.0.1 are not mistaken
// for local-network hosts. Origins that fail to parse are rejected.
//
// NOTE(review): the port-8080 rule accepts every host on that port; it is
// kept for compatibility with the previous behaviour — confirm it is still
// wanted.
func isLocalDevOrigin(origin string) bool {
	parsed, err := url.Parse(origin)
	if err != nil {
		return false
	}
	if parsed.Port() == "8080" {
		return true
	}

	// Hostname strips the port and, for IPv6 literals, the square brackets.
	hostname := parsed.Hostname()
	if strings.EqualFold(hostname, "localhost") {
		return true
	}

	ip := net.ParseIP(hostname)
	if ip == nil {
		// Not an IP literal and not localhost: treat as non-local.
		return false
	}
	return ip.IsLoopback() || ip.IsPrivate()
}

// WSHandler handles WebSocket connections for the API. It implements
// http.Handler through its ServeHTTP method.
type WSHandler struct {
	authConfig        *auth.Config            // API-key validation settings; nil or disabled skips authentication
	logger            *logging.Logger         // component-scoped logger ("ws-handler")
	expManager        *experiment.Manager     // experiment manager supplied by the caller
	dataDir           string                  // data directory supplied by the caller
	queue             queue.Backend           // task queue backend
	db                *storage.DB             // storage handle
	jupyterServiceMgr *jupyter.ServiceManager // Jupyter service manager
	securityConfig    *config.SecurityConfig  // security settings; also used to build upgrader
	auditLogger       *audit.Logger           // optional audit logger; nil disables audit entries
	upgrader          websocket.Upgrader      // upgrader built by createUpgrader from securityConfig
}

// NewWSHandler creates a new WebSocket handler.
//
// The supplied logger is scoped to the "ws-handler" component, and the
// WebSocket upgrader is derived from securityConfig via createUpgrader. All
// other arguments are stored unchanged.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	dataDir string,
	taskQueue queue.Backend,
	db *storage.DB,
	jupyterServiceMgr *jupyter.ServiceManager,
	securityConfig *config.SecurityConfig,
	auditLogger *audit.Logger,
) *WSHandler {
	return &WSHandler{
		authConfig:        authConfig,
		logger:            logger.Component(logging.EnsureTrace(context.Background()), "ws-handler"),
		expManager:        expManager,
		dataDir:           dataDir,
		queue:             taskQueue,
		db:                db,
		jupyterServiceMgr: jupyterServiceMgr,
		securityConfig:    securityConfig,
		auditLogger:       auditLogger,
		upgrader:          createUpgrader(securityConfig),
	}
}

// enableLowLatencyTCP disables Nagle's algorithm to reduce latency for small packets.
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) { if conn == nil { return } if tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn); ok { if err := tcpConn.SetNoDelay(true); err != nil { logger.Warn("failed to enable tcp no delay", "error", err) } } } func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Add security headers w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-XSS-Protection", "1; mode=block") if r.TLS != nil { // Only set HSTS if using HTTPS w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } // Check API key before upgrading WebSocket apiKey := auth.ExtractAPIKeyFromRequest(r) clientIP := r.RemoteAddr // Validate API key if authentication is enabled if h.authConfig != nil && h.authConfig.Enabled { prefixLen := len(apiKey) if prefixLen > 8 { prefixLen = 8 } h.logger.Info( "websocket auth attempt", "api_key_length", len(apiKey), "api_key_prefix", apiKey[:prefixLen], ) userID, err := h.authConfig.ValidateAPIKey(apiKey) if err != nil { h.logger.Warn("websocket authentication failed", "error", err) // Audit log failed authentication if h.auditLogger != nil { h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error()) } http.Error(w, "Invalid API key", http.StatusUnauthorized) return } h.logger.Info("websocket authentication succeeded") // Audit log successful authentication if h.auditLogger != nil && userID != nil { h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "") } } conn, err := h.upgrader.Upgrade(w, r, nil) if err != nil { h.logger.Error("websocket upgrade failed", "error", err) return } conn.EnableWriteCompression(true) if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil { h.logger.Warn("failed to set websocket compression level", "error", err) } enableLowLatencyTCP(conn, h.logger) defer func() { _ = conn.Close() }() h.logger.Info("websocket connection established", 
"remote", r.RemoteAddr) for { messageType, message, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError( err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, ) { h.logger.Error("websocket read error", "error", err) } break } if messageType != websocket.BinaryMessage { h.logger.Warn("received non-binary message") continue } if err := h.handleMessage(conn, message); err != nil { h.logger.Error("message handling error", "error", err) // Send structured error response so CLI clients can parse it. // (Raw fallback bytes cause client-side InvalidPacket errors.) _ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error()) } } } func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { if len(message) < 1 { return fmt.Errorf("message too short") } opcode := message[0] payload := message[1:] switch opcode { case OpcodeQueueJob: return h.handleQueueJob(conn, payload) case OpcodeQueueJobWithTracking: return h.handleQueueJobWithTracking(conn, payload) case OpcodeQueueJobWithSnapshot: return h.handleQueueJobWithSnapshot(conn, payload) case OpcodeQueueJobWithArgs: return h.handleQueueJobWithArgs(conn, payload) case OpcodeQueueJobWithNote: return h.handleQueueJobWithNote(conn, payload) case OpcodeAnnotateRun: return h.handleAnnotateRun(conn, payload) case OpcodeSetRunNarrative: return h.handleSetRunNarrative(conn, payload) case OpcodeStatusRequest: return h.handleStatusRequest(conn, payload) case OpcodeCancelJob: return h.handleCancelJob(conn, payload) case OpcodePrune: return h.handlePrune(conn, payload) case OpcodeDatasetList: return h.handleDatasetList(conn, payload) case OpcodeDatasetRegister: return h.handleDatasetRegister(conn, payload) case OpcodeDatasetInfo: return h.handleDatasetInfo(conn, payload) case OpcodeDatasetSearch: return h.handleDatasetSearch(conn, payload) case OpcodeLogMetric: return h.handleLogMetric(conn, payload) case OpcodeGetExperiment: return 
h.handleGetExperiment(conn, payload) case OpcodeStartJupyter: return h.handleStartJupyter(conn, payload) case OpcodeStopJupyter: return h.handleStopJupyter(conn, payload) case OpcodeRemoveJupyter: return h.handleRemoveJupyter(conn, payload) case OpcodeRestoreJupyter: return h.handleRestoreJupyter(conn, payload) case OpcodeListJupyter: return h.handleListJupyter(conn, payload) case OpcodeListJupyterPackages: return h.handleListJupyterPackages(conn, payload) case OpcodeValidateRequest: return h.handleValidateRequest(conn, payload) default: return fmt.Errorf("unknown opcode: 0x%02x", opcode) } }