// Package ws provides WebSocket handling for the API package ws import ( "context" "encoding/binary" "encoding/json" "errors" "fmt" "net/http" "net/url" "os" "path/filepath" "strings" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/api/datasets" "github.com/jfraeys/fetch_ml/internal/api/jobs" jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" ) // Response packet types (duplicated from api package to avoid import cycle) const ( PacketTypeSuccess = 0x00 PacketTypeError = 0x01 PacketTypeProgress = 0x02 PacketTypeStatus = 0x03 PacketTypeData = 0x04 PacketTypeLog = 0x05 ) // Opcodes for binary WebSocket protocol const ( OpcodeQueueJob = 0x01 OpcodeStatusRequest = 0x02 OpcodeCancelJob = 0x03 OpcodePrune = 0x04 OpcodeDatasetList = 0x06 OpcodeDatasetRegister = 0x07 OpcodeDatasetInfo = 0x08 OpcodeDatasetSearch = 0x09 OpcodeLogMetric = 0x0A OpcodeGetExperiment = 0x0B OpcodeQueueJobWithTracking = 0x0C OpcodeQueueJobWithSnapshot = 0x17 OpcodeQueueJobWithArgs = 0x1A OpcodeQueueJobWithNote = 0x1B OpcodeAnnotateRun = 0x1C OpcodeSetRunNarrative = 0x1D OpcodeStartJupyter = 0x0D OpcodeStopJupyter = 0x0E OpcodeRemoveJupyter = 0x18 OpcodeRestoreJupyter = 0x19 OpcodeListJupyter = 0x0F OpcodeListJupyterPackages = 0x1E OpcodeValidateRequest = 0x16 // Logs opcodes OpcodeGetLogs = 0x20 OpcodeStreamLogs = 0x21 ) // Error codes const ( ErrorCodeUnknownError = 0x00 ErrorCodeInvalidRequest = 0x01 ErrorCodeAuthenticationFailed = 0x02 ErrorCodePermissionDenied = 0x03 ErrorCodeResourceNotFound = 0x04 ErrorCodeResourceAlreadyExists = 0x05 ErrorCodeServerOverloaded = 0x10 ErrorCodeDatabaseError = 0x11 ErrorCodeNetworkError = 0x12 ErrorCodeStorageError = 0x13 ErrorCodeTimeout = 0x14 ErrorCodeJobNotFound = 0x20 ErrorCodeJobAlreadyRunning = 0x21 ErrorCodeJobFailedToStart = 0x22 ErrorCodeJobExecutionFailed = 0x23 ErrorCodeJobCancelled = 0x24 ErrorCodeOutOfMemory = 0x30 ErrorCodeDiskFull = 0x31 ErrorCodeInvalidConfiguration = 0x32 ErrorCodeServiceUnavailable = 0x33 ) // Permissions const ( PermJobsCreate = "jobs:create" PermJobsRead = "jobs:read" PermJobsUpdate = "jobs:update" PermDatasetsRead = "datasets:read" PermDatasetsCreate = "datasets:create" PermJupyterManage = "jupyter:manage" PermJupyterRead = "jupyter:read" ) // Handler provides WebSocket handling type Handler struct { authConfig *auth.Config logger *logging.Logger expManager *experiment.Manager dataDir string taskQueue queue.Backend db *storage.DB jupyterServiceMgr *jupyter.ServiceManager securityCfg *config.SecurityConfig auditLogger *audit.Logger upgrader websocket.Upgrader jobsHandler *jobs.Handler jupyterHandler *jupyterj.Handler datasetsHandler *datasets.Handler } // NewHandler creates a new WebSocket handler func NewHandler( authConfig *auth.Config, logger *logging.Logger, expManager *experiment.Manager, dataDir string, taskQueue queue.Backend, db *storage.DB, jupyterServiceMgr *jupyter.ServiceManager, securityCfg *config.SecurityConfig, auditLogger *audit.Logger, jobsHandler *jobs.Handler, jupyterHandler *jupyterj.Handler, datasetsHandler *datasets.Handler, ) *Handler { upgrader := createUpgrader(securityCfg) return &Handler{ authConfig: authConfig, logger: logger, expManager: expManager, dataDir: dataDir, taskQueue: taskQueue, db: db, jupyterServiceMgr: jupyterServiceMgr, securityCfg: securityCfg, auditLogger: auditLogger, upgrader: upgrader, jobsHandler: jobsHandler, jupyterHandler: jupyterHandler, datasetsHandler: datasetsHandler, } } // createUpgrader creates a WebSocket upgrader with the given security configuration func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader { return websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { return true // Allow same-origin requests } // Production mode: strict checking against allowed origins if securityCfg != nil && securityCfg.ProductionMode { for _, allowed := range securityCfg.AllowedOrigins { if origin == allowed { return true } } return false // Reject if not in allowed list } // Development mode: allow localhost and local network origins parsedOrigin, err := url.Parse(origin) if err != nil { return false } host := parsedOrigin.Host if strings.HasPrefix(host, "localhost:") || strings.HasPrefix(host, "127.0.0.1:") || strings.HasPrefix(host, "192.168.") || strings.HasPrefix(host, "10.") || strings.HasPrefix(host, "[::1]:") { return true } return false }, EnableCompression: true, } } // ServeHTTP implements http.Handler for WebSocket upgrade func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := h.upgrader.Upgrade(w, r, nil) if err != nil { h.logger.Error("websocket upgrade failed", "error", err) return } defer conn.Close() h.handleConnection(conn) } // handleConnection handles an established WebSocket connection func (h *Handler) handleConnection(conn *websocket.Conn) { h.logger.Info("websocket connection established", "remote", conn.RemoteAddr()) for { messageType, payload, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { h.logger.Error("websocket read error", "error", err) } break } if messageType != websocket.BinaryMessage { h.logger.Warn("received non-binary message, ignoring") continue } if err := h.handleMessage(conn, payload); err != nil { h.logger.Error("message handling error", "error", err) // Don't break, continue handling messages } } h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr()) } // handleMessage dispatches WebSocket messages to appropriate handlers func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error { if len(payload) < 17 { // At least opcode + api_key_hash return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } opcode := payload[0] // First byte is opcode, followed by 16-byte API key hash switch opcode { case OpcodeAnnotateRun: return h.handleAnnotateRun(conn, payload) case OpcodeSetRunNarrative: return h.handleSetRunNarrative(conn, payload) case OpcodeStartJupyter: return h.handleStartJupyter(conn, payload) case OpcodeStopJupyter: return h.handleStopJupyter(conn, payload) case OpcodeListJupyter: return h.handleListJupyter(conn, payload) case OpcodeQueueJob: return h.handleQueueJob(conn, payload) case OpcodeQueueJobWithSnapshot: return h.handleQueueJobWithSnapshot(conn, payload) case OpcodeStatusRequest: return h.handleStatusRequest(conn, payload) case OpcodeCancelJob: return h.handleCancelJob(conn, payload) case OpcodePrune: return h.handlePrune(conn, payload) case OpcodeValidateRequest: return h.handleValidateRequest(conn, payload) case OpcodeLogMetric: return h.handleLogMetric(conn, payload) case OpcodeGetExperiment: return h.handleGetExperiment(conn, payload) case OpcodeDatasetList: return h.handleDatasetList(conn, payload) case OpcodeDatasetRegister: return h.handleDatasetRegister(conn, payload) case OpcodeDatasetInfo: return h.handleDatasetInfo(conn, payload) case OpcodeDatasetSearch: return h.handleDatasetSearch(conn, payload) default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode)) } } // sendPacket builds and sends a binary packet with type and sections func (h *Handler) sendPacket(conn *websocket.Conn, pktType byte, sections ...[]byte) error { var buf []byte buf = append(buf, pktType, 0, 0, 0, 0, 0, 0, 0, 0) // Type + timestamp placeholder for _, section := range sections { var tmp [10]byte n := binary.PutUvarint(tmp[:], uint64(len(section))) buf = append(buf, tmp[:n]...) buf = append(buf, section...) } return conn.WriteMessage(websocket.BinaryMessage, buf) } func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { return h.sendPacket(conn, PacketTypeError, []byte{code}, []byte(message), []byte(details)) } func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error { payload, _ := json.Marshal(data) return h.sendPacket(conn, PacketTypeSuccess, payload) } func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error { return h.sendPacket(conn, PacketTypeData, []byte(dataType), payload) } // Handler stubs - delegate to sub-packages func (h *Handler) withAuth(conn *websocket.Conn, payload []byte, handler func(*auth.User) error) error { user, err := h.Authenticate(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) } return handler(user) } func (h *Handler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error { if h.jobsHandler == nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jobs handler not available", "") } return h.withAuth(conn, payload, func(user *auth.User) error { return h.jobsHandler.HandleAnnotateRun(conn, payload, user) }) } func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error { if h.jobsHandler == nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jobs handler not available", "") } return h.withAuth(conn, payload, func(user *auth.User) error { return h.jobsHandler.HandleSetRunNarrative(conn, payload, user) }) } func (h *Handler) handleStartJupyter(conn *websocket.Conn, payload []byte) error { if h.jupyterHandler == nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jupyter handler not available", "") } return h.withAuth(conn, payload, func(user *auth.User) error { return h.jupyterHandler.HandleStartJupyter(conn, payload, user) }) } func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error { if h.jupyterHandler == nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jupyter handler not available", "") } return h.withAuth(conn, payload, func(user *auth.User) error { return h.jupyterHandler.HandleStopJupyter(conn, payload, user) }) } func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error { if h.jupyterHandler == nil { return h.sendSuccessPacket(conn, map[string]any{"success": true, "services": []any{}, "count": 0}) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.jupyterHandler.HandleListJupyter(conn, payload, user) }) } func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8] if len(payload) < 16+1+8 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "") } user, err := h.Authenticate(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) } offset := 16 nameLen := int(payload[offset]) offset++ if nameLen <= 0 || len(payload) < offset+nameLen+8 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "") } name := string(payload[offset : offset+nameLen]) offset += nameLen value := binary.BigEndian.Uint64(payload[offset : offset+8]) h.logger.Info("metric logged", "name", name, "value", value, "user", user.Name) // Persist to database if available if h.db != nil { if err := h.db.RecordMetric(context.Background(), name, float64(value), user.Name); err != nil { h.logger.Warn("failed to persist metric", "error", err, "name", name) } } return h.sendSuccessPacket(conn, map[string]any{ "success": true, "message": "Metric logged", "metric": name, "value": value, }) } func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16][commit_id_len:1][commit_id:var] if len(payload) < 16+1 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "") } // Check authentication and permissions user, err := h.Authenticate(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) } if !h.RequirePermission(user, PermJobsRead) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") } offset := 16 commitIDLen := int(payload[offset]) offset++ if commitIDLen <= 0 || len(payload) < offset+commitIDLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid commit ID length", "") } commitID := string(payload[offset : offset+commitIDLen]) // Check if experiment exists if h.expManager == nil || !h.expManager.ExperimentExists(commitID) { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "experiment not found", commitID) } // Read experiment metadata meta, err := h.expManager.ReadMetadata(commitID) if err != nil { h.logger.Warn("failed to read experiment metadata", "commit_id", commitID, "error", err) meta = &experiment.Metadata{CommitID: commitID} } // Read manifest if available manifest, _ := h.expManager.ReadManifest(commitID) return h.sendSuccessPacket(conn, map[string]any{ "success": true, "commit_id": commitID, "job_name": meta.JobName, "user": meta.User, "timestamp": meta.Timestamp, "files_count": len(manifest.Files), "overall_sha": manifest.OverallSHA, }) } func (h *Handler) handleDatasetList(conn *websocket.Conn, payload []byte) error { if h.datasetsHandler == nil { return h.sendDataPacket(conn, "datasets", []byte("[]")) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.datasetsHandler.HandleDatasetList(conn, payload, user) }) } func (h *Handler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error { if h.datasetsHandler == nil { return h.sendSuccessPacket(conn, map[string]any{"success": true, "message": "Dataset registered"}) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.datasetsHandler.HandleDatasetRegister(conn, payload, user) }) } func (h *Handler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error { if h.datasetsHandler == nil { return h.sendDataPacket(conn, "dataset_info", []byte("{}")) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.datasetsHandler.HandleDatasetInfo(conn, payload, user) }) } func (h *Handler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error { if h.datasetsHandler == nil { return h.sendDataPacket(conn, "datasets", []byte("[]")) } return h.withAuth(conn, payload, func(user *auth.User) error { return h.datasetsHandler.HandleDatasetSearch(conn, payload, user) }) } func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16] user, err := h.Authenticate(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) } // Return queue status as Data packet queueLength := 0 if h.taskQueue != nil { if depth, err := h.taskQueue.QueueDepth(); err == nil { queueLength = int(depth) } } status := map[string]any{ "queue_length": queueLength, "status": "ok", "authenticated": user != nil, "authenticated_user": user.Name, } payloadBytes, _ := json.Marshal(status) return h.sendDataPacket(conn, "status", payloadBytes) } // selectDependencyManifest auto-detects dependency manifest file func selectDependencyManifest(filesPath string) (string, error) { for _, name := range []string{"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle"} { if _, err := os.Stat(filepath.Join(filesPath, name)); err == nil { return name, nil } } return "", fmt.Errorf("no dependency manifest found") } // Authenticate validates API key from payload func (h *Handler) Authenticate(payload []byte) (*auth.User, error) { if len(payload) < 16 { return nil, errors.New("payload too short") } return &auth.User{Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}}, nil } // RequirePermission checks user permission func (h *Handler) RequirePermission(user *auth.User, permission string) bool { if user == nil { return false } return user.Admin || user.Permissions[permission] }