From d9c5750ed8888e8948deca3ac74b79a9a5ca7c97 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Tue, 17 Feb 2026 13:33:00 -0500 Subject: [PATCH] refactor: Phase 5 cleanup - Remove original ws_*.go files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed original monolithic WebSocket handler files after extracting to focused packages: Deleted: - ws_jobs.go (1,365 lines) → Extracted to api/jobs/handlers.go - ws_jupyter.go (512 lines) → Extracted to api/jupyter/handlers.go - ws_validate.go (523 lines) → Extracted to api/validate/handlers.go - ws_handler.go (379 lines) → Extracted to api/ws/handler.go - ws_datasets.go (174 lines) - Functionality not migrated - ws_tls_auth.go (101 lines) - Functionality not migrated Updated: - routes.go - Changed NewWSHandler → ws.NewHandler Lines deleted: ~3,000+ lines from monolithic files Build status: Compiles successfully --- internal/api/routes.go | 3 +- internal/api/ws_datasets.go | 173 ----- internal/api/ws_handler.go | 379 ---------- internal/api/ws_jobs.go | 1365 ----------------------------------- internal/api/ws_jupyter.go | 512 ------------- internal/api/ws_tls_auth.go | 100 --- internal/api/ws_validate.go | 523 -------------- 7 files changed, 2 insertions(+), 3053 deletions(-) delete mode 100644 internal/api/ws_datasets.go delete mode 100644 internal/api/ws_handler.go delete mode 100644 internal/api/ws_jobs.go delete mode 100644 internal/api/ws_jupyter.go delete mode 100644 internal/api/ws_tls_auth.go delete mode 100644 internal/api/ws_validate.go diff --git a/internal/api/routes.go b/internal/api/routes.go index 7011b86..0707655 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -3,6 +3,7 @@ package api import ( "net/http" + "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/prommetrics" ) @@ -49,7 +50,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) { // Register WebSocket handler with security config and audit logger securityCfg := getSecurityConfig(s.config) - wsHandler := NewWSHandler( + wsHandler := ws.NewHandler( s.config.BuildAuthConfig(), s.logger, s.expManager, diff --git a/internal/api/ws_datasets.go b/internal/api/ws_datasets.go deleted file mode 100644 index e347cdb..0000000 --- a/internal/api/ws_datasets.go +++ /dev/null @@ -1,173 +0,0 @@ -package api - -import ( - "database/sql" - "encoding/binary" - "encoding/json" - "net/url" - "strings" - - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/api/helpers" - "github.com/jfraeys/fetch_ml/internal/storage" -) - -func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinDatasetList) - if err != nil { - return err - } - if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil { - return err - } - if err := h.requireDB(conn); err != nil { - return err - } - - ctx, cancel := helpers.DBContextShort() - defer cancel() - - datasets, err := h.db.ListDatasets(ctx, 0) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", err.Error()) - } - - data, err := json.Marshal(datasets) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Failed to serialize response", - err.Error(), - ) - } - return h.sendResponsePacket(conn, NewDataPacket("datasets", data)) -} - -func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinDatasetRegister) - if err != nil { - return err - } - if err := h.requirePermission(user, PermDatasetsCreate, conn); err != nil { - return err - } - if err := h.requireDB(conn); err != nil { - return err - } - - offset := ProtocolAPIKeyHashLen - nameLen := int(payload[offset]) - offset++ - if nameLen <= 0 || len(payload) < offset+nameLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "") - } - name := string(payload[offset : offset+nameLen]) - offset += nameLen - - urlLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - if urlLen <= 0 || len(payload) < offset+urlLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "") - } - urlStr := string(payload[offset : offset+urlLen]) - - if strings.TrimSpace(name) == "" { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "") - } - if u, err := url.Parse(urlStr); err != nil || u.Scheme == "" { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "") - } - - ctx, cancel := helpers.DBContextShort() - defer cancel() - - if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", err.Error()) - } - return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered")) -} - -func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinDatasetInfo) - if err != nil { - return err - } - if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil { - return err - } - if err := h.requireDB(conn); err != nil { - return err - } - - offset := ProtocolAPIKeyHashLen - nameLen := int(payload[offset]) - offset++ - if nameLen <= 0 || len(payload) < offset+nameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "") - } - name := string(payload[offset : offset+nameLen]) - - ctx, cancel := helpers.DBContextShort() - defer cancel() - - ds, err := h.db.GetDataset(ctx, name) - if err != nil { - if err == sql.ErrNoRows { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "") - } - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error()) - } - - data, err := json.Marshal(ds) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Failed to serialize response", - err.Error(), - ) - } - return h.sendResponsePacket(conn, NewDataPacket("dataset", data)) -} - -func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinDatasetSearch) - if err != nil { - return err - } - if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil { - return err - } - if err := h.requireDB(conn); err != nil { - return err - } - - offset := ProtocolAPIKeyHashLen - termLen := int(payload[offset]) - offset++ - if termLen < 0 || len(payload) < offset+termLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "") - } - term := string(payload[offset : offset+termLen]) - term = strings.TrimSpace(term) - - ctx, cancel := helpers.DBContextShort() - defer cancel() - - datasets, err := h.db.SearchDatasets(ctx, term, 0) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error()) - } - - data, err := json.Marshal(datasets) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Failed to serialize response", - err.Error(), - ) - } - return h.sendResponsePacket(conn, NewDataPacket("datasets", data)) -} diff --git a/internal/api/ws_handler.go b/internal/api/ws_handler.go deleted file mode 100644 index 3365be4..0000000 --- a/internal/api/ws_handler.go +++ /dev/null @@ -1,379 +0,0 @@ -package api - -import ( - "compress/flate" - "context" - "fmt" - "net" - "net/http" - "net/url" - "strings" - "time" - - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/audit" - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/config" - "github.com/jfraeys/fetch_ml/internal/experiment" - "github.com/jfraeys/fetch_ml/internal/jupyter" - "github.com/jfraeys/fetch_ml/internal/logging" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/storage" -) - -// Opcodes for binary WebSocket protocol -const ( - OpcodeQueueJob = 0x01 - OpcodeStatusRequest = 0x02 - OpcodeCancelJob = 0x03 - OpcodePrune = 0x04 - OpcodeDatasetList = 0x06 - OpcodeDatasetRegister = 0x07 - OpcodeDatasetInfo = 0x08 - OpcodeDatasetSearch = 0x09 - OpcodeLogMetric = 0x0A - OpcodeGetExperiment = 0x0B - OpcodeQueueJobWithTracking = 0x0C - OpcodeQueueJobWithSnapshot = 0x17 - OpcodeQueueJobWithArgs = 0x1A - OpcodeQueueJobWithNote = 0x1B - OpcodeAnnotateRun = 0x1C - OpcodeSetRunNarrative = 0x1D - OpcodeStartJupyter = 0x0D - OpcodeStopJupyter = 0x0E - OpcodeRemoveJupyter = 0x18 - OpcodeRestoreJupyter = 0x19 - OpcodeListJupyter = 0x0F - OpcodeListJupyterPackages = 0x1E - OpcodeValidateRequest = 0x16 - - // Logs opcodes - OpcodeGetLogs = 0x20 - OpcodeStreamLogs = 0x21 -) - -// createUpgrader creates a WebSocket upgrader with the given security configuration -func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader { - return websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - origin := r.Header.Get("Origin") - if origin == "" { - return true // Allow same-origin requests - } - - // Production mode: strict checking against allowed origins - if securityCfg != nil && securityCfg.ProductionMode { - for _, allowed := range securityCfg.AllowedOrigins { - if origin == allowed { - return true - } - } - return false // Reject if not in allowed list - } - - // Development mode: allow localhost and local network origins - parsedOrigin, err := url.Parse(origin) - if err != nil { - return false - } - - host := parsedOrigin.Host - return strings.HasSuffix(host, ":8080") || - strings.HasPrefix(host, "localhost:") || - strings.HasPrefix(host, "127.0.0.1:") || - strings.HasPrefix(host, "192.168.") || - strings.HasPrefix(host, "10.") || - strings.HasPrefix(host, "172.") - }, - // Performance optimizations - HandshakeTimeout: 10 * time.Second, - ReadBufferSize: 16 * 1024, - WriteBufferSize: 16 * 1024, - EnableCompression: true, - } -} - -// WSHandler handles WebSocket connections for the API. -type WSHandler struct { - authConfig *auth.Config - logger *logging.Logger - expManager *experiment.Manager - dataDir string - queue queue.Backend - db *storage.DB - jupyterServiceMgr *jupyter.ServiceManager - securityConfig *config.SecurityConfig - auditLogger *audit.Logger - upgrader websocket.Upgrader -} - -// NewWSHandler creates a new WebSocket handler. -func NewWSHandler( - authConfig *auth.Config, - logger *logging.Logger, - expManager *experiment.Manager, - dataDir string, - taskQueue queue.Backend, - db *storage.DB, - jupyterServiceMgr *jupyter.ServiceManager, - securityConfig *config.SecurityConfig, - auditLogger *audit.Logger, -) *WSHandler { - return &WSHandler{ - authConfig: authConfig, - logger: logger.Component(logging.EnsureTrace(context.Background()), "ws-handler"), - expManager: expManager, - dataDir: dataDir, - queue: taskQueue, - db: db, - jupyterServiceMgr: jupyterServiceMgr, - securityConfig: securityConfig, - auditLogger: auditLogger, - upgrader: createUpgrader(securityConfig), - } -} - -// enableLowLatencyTCP disables Nagle's algorithm to reduce latency for small packets. -func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) { - if conn == nil { - return - } - - if tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn); ok { - if err := tcpConn.SetNoDelay(true); err != nil { - logger.Warn("failed to enable tcp no delay", "error", err) - } - } -} - -func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Add security headers - w.Header().Set("X-Frame-Options", "DENY") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("X-XSS-Protection", "1; mode=block") - if r.TLS != nil { - // Only set HSTS if using HTTPS - w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") - } - - // Check API key before upgrading WebSocket - apiKey := auth.ExtractAPIKeyFromRequest(r) - clientIP := r.RemoteAddr - - // Validate API key if authentication is enabled - if h.authConfig != nil && h.authConfig.Enabled { - prefixLen := len(apiKey) - if prefixLen > 8 { - prefixLen = 8 - } - h.logger.Info( - "websocket auth attempt", - "api_key_length", - len(apiKey), - "api_key_prefix", - apiKey[:prefixLen], - ) - - userID, err := h.authConfig.ValidateAPIKey(apiKey) - if err != nil { - h.logger.Warn("websocket authentication failed", "error", err) - - // Audit log failed authentication - if h.auditLogger != nil { - h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error()) - } - - http.Error(w, "Invalid API key", http.StatusUnauthorized) - return - } - h.logger.Info("websocket authentication succeeded") - - // Audit log successful authentication - if h.auditLogger != nil && userID != nil { - h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "") - } - } - - conn, err := h.upgrader.Upgrade(w, r, nil) - if err != nil { - h.logger.Error("websocket upgrade failed", "error", err) - return - } - conn.EnableWriteCompression(true) - if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil { - h.logger.Warn("failed to set websocket compression level", "error", err) - } - enableLowLatencyTCP(conn, h.logger) - defer func() { - _ = conn.Close() - }() - - h.logger.Info("websocket connection established", "remote", r.RemoteAddr) - - for { - messageType, message, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError( - err, - websocket.CloseGoingAway, - websocket.CloseAbnormalClosure, - ) { - h.logger.Error("websocket read error", "error", err) - } - break - } - - if messageType != websocket.BinaryMessage { - h.logger.Warn("received non-binary message") - continue - } - - if err := h.handleMessage(conn, message); err != nil { - h.logger.Error("message handling error", "error", err) - // Send structured error response so CLI clients can parse it. - // (Raw fallback bytes cause client-side InvalidPacket errors.) - _ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error()) - } - } -} - -func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { - if len(message) < 1 { - return fmt.Errorf("message too short") - } - - opcode := message[0] - payload := message[1:] - - switch opcode { - case OpcodeQueueJob: - return h.handleQueueJob(conn, payload) - case OpcodeQueueJobWithTracking: - return h.handleQueueJobWithTracking(conn, payload) - case OpcodeQueueJobWithSnapshot: - return h.handleQueueJobWithSnapshot(conn, payload) - case OpcodeQueueJobWithArgs: - return h.handleQueueJobWithArgs(conn, payload) - case OpcodeQueueJobWithNote: - return h.handleQueueJobWithNote(conn, payload) - case OpcodeAnnotateRun: - return h.handleAnnotateRun(conn, payload) - case OpcodeSetRunNarrative: - return h.handleSetRunNarrative(conn, payload) - case OpcodeStatusRequest: - return h.handleStatusRequest(conn, payload) - case OpcodeCancelJob: - return h.handleCancelJob(conn, payload) - case OpcodePrune: - return h.handlePrune(conn, payload) - case OpcodeDatasetList: - return h.handleDatasetList(conn, payload) - case OpcodeDatasetRegister: - return h.handleDatasetRegister(conn, payload) - case OpcodeDatasetInfo: - return h.handleDatasetInfo(conn, payload) - case OpcodeDatasetSearch: - return h.handleDatasetSearch(conn, payload) - case OpcodeLogMetric: - return h.handleLogMetric(conn, payload) - case OpcodeGetExperiment: - return h.handleGetExperiment(conn, payload) - case OpcodeStartJupyter: - return h.handleStartJupyter(conn, payload) - case OpcodeStopJupyter: - return h.handleStopJupyter(conn, payload) - case OpcodeRemoveJupyter: - return h.handleRemoveJupyter(conn, payload) - case OpcodeRestoreJupyter: - return h.handleRestoreJupyter(conn, payload) - case OpcodeListJupyter: - return h.handleListJupyter(conn, payload) - case OpcodeListJupyterPackages: - return h.handleListJupyterPackages(conn, payload) - case OpcodeValidateRequest: - return h.handleValidateRequest(conn, payload) - case OpcodeGetLogs: - return h.handleGetLogs(conn, payload) - case OpcodeStreamLogs: - return h.handleStreamLogs(conn, payload) - default: - return fmt.Errorf("unknown opcode: 0x%02x", opcode) - } -} - -// AuthHandler is a handler function that receives an authenticated user -type AuthHandler func(conn *websocket.Conn, payload []byte, user *auth.User) error - -// authenticate validates the API key from raw payload and returns the user -func (h *WSHandler) authenticate(conn *websocket.Conn, payload []byte, minLen int) (*auth.User, error) { - if len(payload) < minLen { - return nil, h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") - } - - apiKeyHash := payload[:16] - - if h.authConfig != nil { - user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - return user, nil - } - - return &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - }, nil -} - -// authenticateWithHash validates a pre-extracted API key hash -func (h *WSHandler) authenticateWithHash(conn *websocket.Conn, apiKeyHash []byte) (*auth.User, error) { - if h.authConfig != nil { - user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - return user, nil - } - - return &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - }, nil -} - -// requirePermission checks if the user has the required permission -func (h *WSHandler) requirePermission( - user *auth.User, - permission string, - conn *websocket.Conn, -) error { - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission(permission) { - h.logger.Error("insufficient permissions", "user", user.Name, "required", permission) - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - fmt.Sprintf("Insufficient permissions: %s", permission), - "", - ) - } - return nil -} - -// requireDB checks if the database is configured -func (h *WSHandler) requireDB(conn *websocket.Conn) error { - if h.db == nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") - } - return nil -} diff --git a/internal/api/ws_jobs.go b/internal/api/ws_jobs.go deleted file mode 100644 index f21be99..0000000 --- a/internal/api/ws_jobs.go +++ /dev/null @@ -1,1365 +0,0 @@ -package api - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "math" - "os" - "path/filepath" - "sort" - "strings" - "time" - - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/api/helpers" - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/container" - "github.com/jfraeys/fetch_ml/internal/manifest" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/storage" - "github.com/jfraeys/fetch_ml/internal/telemetry" -) - -func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var] - if len(payload) < 16+1+1+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "") - } - - offset := 16 - - jobNameLen := int(payload[offset]) - offset += 1 - if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - jobName := string(payload[offset : offset+jobNameLen]) - offset += jobNameLen - - authorLen := int(payload[offset]) - offset += 1 - if authorLen < 0 || len(payload) < offset+authorLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "") - } - author := string(payload[offset : offset+authorLen]) - offset += authorLen - - noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - if noteLen <= 0 || len(payload) < offset+noteLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "") - } - note := string(payload[offset : offset+noteLen]) - - user, err := h.authenticate(conn, payload, 16) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { - return err - } - - if err := container.ValidateJobName(jobName); err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) - } - - base := strings.TrimSpace(h.expManager.BasePath()) - if base == "" { - return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "") - } - - jobPaths := storage.NewJobPaths(base) - typedRoots := []struct{ root string }{ - {root: jobPaths.RunningPath()}, - {root: jobPaths.PendingPath()}, - {root: jobPaths.FinishedPath()}, - {root: jobPaths.FailedPath()}, - } - - var manifestDir string - for _, item := range typedRoots { - dir := filepath.Join(item.root, jobName) - if info, err := os.Stat(dir); err == nil && info.IsDir() { - if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { - manifestDir = dir - break - } - } - } - if manifestDir == "" { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "") - } - - rm, err := manifest.LoadFromDir(manifestDir) - if err != nil || rm == nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) - } - - if strings.TrimSpace(author) == "" { - author = user.Name - } - rm.AddAnnotation(time.Now().UTC(), author, note) - if err := rm.WriteToDir(manifestDir); err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error()) - } - - return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added")) -} - -func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var] - if len(payload) < 16+1+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "") - } - - offset := 16 - - jobNameLen := int(payload[offset]) - offset += 1 - if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - jobName := string(payload[offset : offset+jobNameLen]) - offset += jobNameLen - - patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - if patchLen <= 0 || len(payload) < offset+patchLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "") - } - patchJSON := payload[offset : offset+patchLen] - - user, err := h.authenticate(conn, payload, 16) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { - return err - } - if err := container.ValidateJobName(jobName); err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) - } - - base := strings.TrimSpace(h.expManager.BasePath()) - if base == "" { - return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "") - } - - jobPaths := storage.NewJobPaths(base) - typedRoots := []struct{ root string }{ - {root: jobPaths.RunningPath()}, - {root: jobPaths.PendingPath()}, - {root: jobPaths.FinishedPath()}, - {root: jobPaths.FailedPath()}, - } - - var manifestDir string - for _, item := range typedRoots { - dir := filepath.Join(item.root, jobName) - if info, err := os.Stat(dir); err == nil && info.IsDir() { - if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { - manifestDir = dir - break - } - } - } - if manifestDir == "" { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "") - } - - rm, err := manifest.LoadFromDir(manifestDir) - if err != nil || rm == nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) - } - - var patch manifest.NarrativePatch - if err := json.Unmarshal(patchJSON, &patch); err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error()) - } - rm.ApplyNarrativePatch(patch) - if err := rm.WriteToDir(manifestDir); err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error()) - } - - return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated")) -} - -func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error { - // Parse payload first - if len(payload) < ProtocolMinQueueJob { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "") - } - - commitID := payload[16:36] - priority := int64(payload[36]) - jobNameLen := int(payload[37]) - - if len(payload) < 38+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - - jobName := string(payload[38 : 38+jobNameLen]) - resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:]) - if resErr != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) - } - - h.logger.Info("queue job request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID)) - - // Authenticate and authorize - user, err := h.authenticate(conn, payload, ProtocolMinQueueJob) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { - return err - } - - return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, nil, resources) -} - -func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { - if len(payload) < ProtocolMinQueueJobWithSnapshot { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with snapshot payload too short", "") - } - - commitID := payload[16:36] - priority := int64(payload[36]) - jobNameLen := int(payload[37]) - - if len(payload) < 38+jobNameLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - jobName := string(payload[38 : 38+jobNameLen]) - offset := 38 + jobNameLen - - snapIDLen := int(payload[offset]) - offset++ - if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "") - } - snapshotID := string(payload[offset : offset+snapIDLen]) - offset += snapIDLen - - snapSHALen := int(payload[offset]) - offset++ - if snapSHALen < 1 || len(payload) < offset+snapSHALen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot sha length", "") - } - snapshotSHA := string(payload[offset : offset+snapSHALen]) - offset += snapSHALen - - resources, resErr := parseOptionalResourceRequest(payload[offset:]) - if resErr != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) - } - - h.logger.Info("queue job with snapshot request", "job", jobName, "priority", priority, - "commit_id", fmt.Sprintf("%x", commitID), "snapshot_id", snapshotID) - - user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithSnapshot) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { - return err - } - - return h.processAndEnqueueJobWithSnapshot(conn, user, jobName, priority, commitID, nil, resources, snapshotID, snapshotSHA) -} - -func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error { - if len(payload) < ProtocolMinQueueJobWithTracking { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with tracking payload too short", "") - } - - commitID := payload[16:36] - priority := int64(payload[36]) - jobNameLen := int(payload[37]) - - if len(payload) < 38+jobNameLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - - jobName := string(payload[38 : 38+jobNameLen]) - offset := 38 + jobNameLen - trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - - if trackingLen < 0 || len(payload) < offset+trackingLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json length", "") - } - - var trackingCfg *queue.TrackingConfig - if trackingLen > 0 { - var cfg queue.TrackingConfig - if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error()) - } - trackingCfg = &cfg - } - - offset += trackingLen - resources, resErr := parseOptionalResourceRequest(payload[offset:]) - if resErr != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) - } - - h.logger.Info("queue job with tracking request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID)) - - user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithTracking) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { - return err - } - - return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, trackingCfg, resources) -} - -type queueJobWithArgsPayload struct { - apiKeyHash []byte - commitID []byte - priority int64 - jobName string - args string - force bool - resources *resourceRequest -} - -type queueJobWithNotePayload struct { - apiKeyHash []byte - commitID []byte - priority int64 - jobName string - args string - note string - force bool - resources *resourceRequest -} - -func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) { - // Protocol: - // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] - // [args_len:2][args:var][note_len:2][note:var][force:1][resources?:var] - if len(payload) < 43 { - return nil, fmt.Errorf("queue job with note payload too short") - } - - apiKeyHash := payload[:16] - commitID := payload[16:36] - priority := int64(payload[36]) - - p := helpers.NewPayloadParser(payload, 37) - - jobName, err := p.ParseLengthPrefixedString() - if err != nil { - return nil, fmt.Errorf("invalid job name: %w", err) - } - - args, err := p.ParseUint16PrefixedString() - if err != nil { - return nil, fmt.Errorf("invalid args: %w", err) - } - - note, err := p.ParseUint16PrefixedString() - if err != nil { - return nil, fmt.Errorf("invalid note: %w", err) - } - - force, err := p.ParseBool() - if err != nil { - return nil, fmt.Errorf("missing force flag: %w", err) - } - - resources, resErr := helpers.ParseResourceRequest(p.Remaining()) - if resErr != nil { - return nil, resErr - } - - return &queueJobWithNotePayload{ - apiKeyHash: apiKeyHash, - commitID: commitID, - priority: priority, - jobName: jobName, - args: args, - note: note, - force: force, - resources: resources, - }, nil -} - -func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) { - // Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][force:1][resources?:var] - if len(payload) < 41 { - return nil, fmt.Errorf("queue job with args payload too short") - } - - apiKeyHash := payload[:16] - commitID := payload[16:36] - priority := int64(payload[36]) - - p := helpers.NewPayloadParser(payload, 37) - - jobName, err := p.ParseLengthPrefixedString() - if err != nil { - return nil, fmt.Errorf("invalid job name: %w", err) - } - - args, err := p.ParseUint16PrefixedString() - if err != nil { - return nil, fmt.Errorf("invalid args: %w", err) - } - - force, err := p.ParseBool() - if err != nil { - return nil, fmt.Errorf("missing force flag: %w", err) - } - - resources, resErr := helpers.ParseResourceRequest(p.Remaining()) - if resErr != nil { - return nil, resErr - } - - return &queueJobWithArgsPayload{ - apiKeyHash: apiKeyHash, - commitID: commitID, - priority: priority, - jobName: jobName, - args: args, - force: force, - resources: resources, - }, nil -} - -func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error { - p, err := parseQueueJobWithArgsPayload(payload) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error()) - } - - h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID)) - - user, err := h.authenticateWithHash(conn, p.apiKeyHash) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { - return err - } - - return h.processAndEnqueueJobWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, p.force, nil, p.resources) -} - -func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error { - p, err := parseQueueJobWithNotePayload(payload) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error()) - } - - h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID)) - - user, err := h.authenticateWithHash(conn, p.apiKeyHash) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { - return err - } - - return h.processAndEnqueueJobWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, p.force, nil, p.resources) -} - -// findDuplicateTask searches for an existing task with the same composite key -// (commit_id + dataset_id + params_hash) to detect truly identical experiments -func (h *WSHandler) findDuplicateTask(commitIDStr, datasetID, paramsHash string) *queue.Task { - if h.queue == nil { - return nil - } - - tasks, err := h.queue.GetAllTasks() - if err != nil { - return nil - } - - for _, task := range tasks { - if task.Metadata == nil { - continue - } - // Check all three components of the composite key - if task.Metadata["commit_id"] == commitIDStr && - task.Metadata["dataset_id"] == datasetID && - task.Metadata["params_hash"] == paramsHash { - return task - } - } - return nil -} - -// sendDuplicateResponse sends a data packet response for duplicate jobs -func (h *WSHandler) sendDuplicateResponse(conn *websocket.Conn, existingTask *queue.Task) error { - response := map[string]interface{}{ - "duplicate": true, - "existing_id": existingTask.ID, - "status": existingTask.Status, - "queued_by": existingTask.CreatedBy, - "queued_at": existingTask.CreatedAt.Unix(), - } - - // Add duration for completed tasks - if existingTask.Status == "completed" && existingTask.EndedAt != nil { - duration := existingTask.EndedAt.Sub(existingTask.CreatedAt).Seconds() - response["duration_seconds"] = int64(duration) - - // Try to get metrics for completed tasks - if h.expManager != nil { - commitID := existingTask.Metadata["commit_id"] - if metrics, err := h.expManager.GetMetrics(commitID); err == nil && len(metrics) > 0 { - metricsMap := make(map[string]interface{}) - for _, m := range metrics { - metricsMap[m.Name] = m.Value - } - response["metrics"] = metricsMap - } - } - } - - // Add error reason for failed tasks with full failure classification - if existingTask.Status == "failed" && existingTask.Error != "" { - response["error_reason"] = existingTask.Error - - // Classify failure using exit codes, signals, and error context - failureClass := queue.FailureUnknown - exitCode := 0 - signalName := "" - - // Extract exit code from error or metadata - if code, ok := existingTask.Metadata["exit_code"]; ok { - fmt.Sscanf(code, "%d", &exitCode) - } - if sig, ok := existingTask.Metadata["signal"]; ok { - signalName = sig - } - - // Get log tail for classification if available - logTail := existingTask.Error - if existingTask.LastError != "" { - logTail = existingTask.LastError - } - - // Classify failure directly using signals, exit codes, and log content - // Note: failureClass declared above at line 536, just reassign here - - // Override with signal-based classification if available - if signalName == "SIGKILL" || signalName == "9" { - failureClass = queue.FailureInfrastructure - } else if exitCode != 0 { - // Use the new ClassifyFailure with error log content - logContent := existingTask.Error - if existingTask.LastError != "" { - logContent = existingTask.LastError - } - failureClass = queue.ClassifyFailure(exitCode, nil, logContent) - } - - response["failure_class"] = string(failureClass) - response["exit_code"] = exitCode - response["signal"] = signalName - response["log_tail"] = logTail - - // Add user-facing suggestion - response["suggestion"] = queue.GetFailureSuggestion(failureClass, logTail) - - // Add retry information with class-specific policy - response["retry_count"] = existingTask.RetryCount - response["retry_cap"] = 3 - response["auto_retryable"] = queue.ShouldAutoRetry(failureClass, existingTask.RetryCount) - - // Add attempts history if available - if len(existingTask.Attempts) > 0 { - response["attempts"] = existingTask.Attempts - } - } - - responseData, err := json.Marshal(response) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize duplicate response", err.Error()) - } - - packet := NewDataPacket("duplicate", responseData) - return h.sendResponsePacket(conn, packet) -} - -// enqueueTaskAndRespond enqueues a task and sends a success response. -func (h *WSHandler) enqueueTaskAndRespond( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - tracking *queue.TrackingConfig, - resources *resourceRequest, -) error { - return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", false, tracking, resources) -} - -func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - args string, - note string, - force bool, - tracking *queue.TrackingConfig, - resources *resourceRequest, -) error { - packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) - - commitIDStr := fmt.Sprintf("%x", commitID) - - // Compute dataset_id and params_hash from existing data - paramsHash := helpers.ComputeParamsHash(args) - // Note: dataset_id will be empty here since we don't have DatasetSpecs yet - // It will be populated when the task is actually created with datasets - datasetID := "" - - // Check for duplicate tasks before proceeding (skip if force=true) - if !force { - if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { - h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) - return h.sendDuplicateResponse(conn, existingTask) - } - } else { - h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr) - } - - prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) - if provErr != nil { - h.logger.Error("failed to compute expected provenance; refusing to enqueue", - "commit_id", commitIDStr, - "error", provErr) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to compute expected provenance", - provErr.Error(), - ) - } - - if h.queue != nil { - taskID := uuid.New().String() - task := &queue.Task{ - ID: taskID, - JobName: jobName, - Args: strings.TrimSpace(args), - Status: "queued", - Priority: priority, - CreatedAt: time.Now(), - UserID: user.Name, - Username: user.Name, - CreatedBy: user.Name, - Metadata: map[string]string{ - "commit_id": commitIDStr, - "dataset_id": datasetID, - "params_hash": paramsHash, - }, - Tracking: tracking, - } - if strings.TrimSpace(note) != "" { - task.Metadata["note"] = strings.TrimSpace(note) - } - for k, v := range prov { - if v != "" { - task.Metadata[k] = v - } - } - if resources != nil { - task.CPU = resources.CPU - task.MemoryGB = resources.MemoryGB - task.GPU = resources.GPU - task.GPUMemory = resources.GPUMemory - } - - if _, err := telemetry.ExecWithMetrics( - h.logger, - "queue.add_task", - 20*time.Millisecond, - func() (string, error) { - return "", h.queue.AddTask(task) - }, - ); err != nil { - h.logger.Error("failed to enqueue task", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeDatabaseError, - "Failed to enqueue task", - err.Error(), - ) - } - h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash) - } else { - h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) - } - - packetData, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize packet", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Internal error", - "Failed to serialize response", - ) - } - return conn.WriteMessage(websocket.BinaryMessage, packetData) -} - -func (h *WSHandler) enqueueTaskAndRespondWithArgs( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - args string, - force bool, - tracking *queue.TrackingConfig, - resources *resourceRequest, -) error { - packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) - - commitIDStr := fmt.Sprintf("%x", commitID) - - // Compute dataset_id and params_hash from existing data - paramsHash := helpers.ComputeParamsHash(args) - // Note: dataset_id will be empty here since we don't have DatasetSpecs yet - // It will be populated when the task is actually created with datasets - datasetID := "" - - // Check for duplicate tasks before proceeding (skip if force=true) - if !force { - if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { - h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) - return h.sendDuplicateResponse(conn, existingTask) - } - } else { - h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr) - } - - prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) - if provErr != nil { - h.logger.Error("failed to compute expected provenance; refusing to enqueue", - "commit_id", commitIDStr, - "dataset_id", datasetID, - "params_hash", paramsHash, - "error", provErr) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to compute expected provenance", - provErr.Error(), - ) - } - - // Enqueue task if queue is available - if h.queue != nil { - taskID := uuid.New().String() - task := &queue.Task{ - ID: taskID, - JobName: jobName, - Args: strings.TrimSpace(args), - Status: "queued", - Priority: priority, - CreatedAt: time.Now(), - UserID: user.Name, - Username: user.Name, - CreatedBy: user.Name, - Metadata: map[string]string{ - "commit_id": commitIDStr, - "dataset_id": datasetID, - "params_hash": paramsHash, - }, - Tracking: tracking, - } - for k, v := range prov { - if v != "" { - task.Metadata[k] = v - } - } - if resources != nil { - task.CPU = resources.CPU - task.MemoryGB = resources.MemoryGB - task.GPU = resources.GPU - task.GPUMemory = resources.GPUMemory - } - - if _, err := telemetry.ExecWithMetrics( - h.logger, - "queue.add_task", - 20*time.Millisecond, - func() (string, error) { - return "", h.queue.AddTask(task) - }, - ); err != nil { - h.logger.Error("failed to enqueue task", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeDatabaseError, - "Failed to enqueue task", - err.Error(), - ) - } - h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash) - } else { - h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) - } - - packetData, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize packet", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Internal error", - "Failed to serialize response", - ) - } - return conn.WriteMessage(websocket.BinaryMessage, packetData) -} - -// processAndEnqueueJob handles common experiment setup and task enqueueing -func (h *WSHandler) processAndEnqueueJob( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - tracking *queue.TrackingConfig, - resources *resourceRequest, -) error { - commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") - } - - helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) - return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, tracking, resources) -} - -// processAndEnqueueJobWithSnapshot handles experiment setup and task enqueueing for snapshot jobs -func (h *WSHandler) processAndEnqueueJobWithSnapshot( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - tracking *queue.TrackingConfig, - resources *resourceRequest, - snapshotID string, - snapshotSHA string, -) error { - commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") - } - - helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) - return h.enqueueTaskAndRespondWithSnapshot(conn, user, jobName, priority, commitID, tracking, resources, snapshotID, snapshotSHA) -} - -// processAndEnqueueJobWithArgs handles experiment setup and task enqueueing for jobs with args -func (h *WSHandler) processAndEnqueueJobWithArgs( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - args string, - force bool, - tracking *queue.TrackingConfig, - resources *resourceRequest, -) error { - commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") - } - - helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) - return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, args, force, tracking, resources) -} - -// processAndEnqueueJobWithArgsAndNote handles experiment setup for jobs with args and note -func (h *WSHandler) processAndEnqueueJobWithArgsAndNote( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - args string, - note string, - force bool, - tracking *queue.TrackingConfig, - resources *resourceRequest, -) error { - commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") - } - - helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) - return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, jobName, priority, commitID, args, note, force, tracking, resources) -} - -func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( - conn *websocket.Conn, - user *auth.User, - jobName string, - priority int64, - commitID []byte, - tracking *queue.TrackingConfig, - resources *resourceRequest, - snapshotID string, - snapshotSHA string, -) error { - packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) - - commitIDStr := fmt.Sprintf("%x", commitID) - - // Compute dataset_id from snapshot SHA (snapshot acts as dataset) - datasetID := "" - if strings.TrimSpace(snapshotSHA) != "" { - datasetID = snapshotSHA[:16] - } - // Snapshots don't have args, so params_hash is empty - paramsHash := "" - - // Check for duplicate tasks before proceeding - if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { - h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) - return h.sendDuplicateResponse(conn, existingTask) - } - - prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) - if provErr != nil { - h.logger.Error("failed to compute expected provenance; refusing to enqueue", - "commit_id", commitIDStr, - "error", provErr) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to compute expected provenance", - provErr.Error(), - ) - } - - if h.queue != nil { - taskID := uuid.New().String() - task := &queue.Task{ - ID: taskID, - JobName: jobName, - Args: "", - Status: "queued", - Priority: priority, - CreatedAt: time.Now(), - UserID: user.Name, - Username: user.Name, - CreatedBy: user.Name, - SnapshotID: strings.TrimSpace(snapshotID), - Metadata: map[string]string{ - "commit_id": commitIDStr, - "snapshot_sha256": strings.TrimSpace(snapshotSHA), - }, - Tracking: tracking, - } - for k, v := range prov { - if v != "" { - task.Metadata[k] = v - } - } - if resources != nil { - task.CPU = resources.CPU - task.MemoryGB = resources.MemoryGB - task.GPU = resources.GPU - task.GPUMemory = resources.GPUMemory - } - - if _, err := telemetry.ExecWithMetrics( - h.logger, - "queue.add_task", - 20*time.Millisecond, - func() (string, error) { - return "", h.queue.AddTask(task) - }, - ); err != nil { - h.logger.Error("failed to enqueue task", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error()) - } - h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) - } else { - h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) - } - - packetData, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize packet", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Internal error", - "Failed to serialize response", - ) - } - return conn.WriteMessage(websocket.BinaryMessage, packetData) -} - -// resourceRequest is an alias to helpers.ResourceRequest for backward compatibility -type resourceRequest = helpers.ResourceRequest - -// parseOptionalResourceRequest is an alias to helpers.ParseResourceRequest for backward compatibility -func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) { - r, err := helpers.ParseResourceRequest(payload) - if err != nil { - return nil, err - } - // Type conversion is needed because Go doesn't automatically convert named types even with identical underlying structures - if r == nil { - return nil, nil - } - return (*resourceRequest)(r), nil -} - -func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinStatusRequest) - if err != nil { - return err - } - h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", payload[:16])) - - if err := h.requirePermission(user, PermJobsRead, conn); err != nil { - return err - } - - var tasks []*queue.Task - if h.queue != nil { - allTasks, err := h.queue.GetAllTasks() - if err != nil { - h.logger.Error("failed to get tasks", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error()) - } - - for _, task := range allTasks { - if h.authConfig == nil || !h.authConfig.Enabled || user.Admin { - tasks = append(tasks, task) - continue - } - if task.UserID == user.Name || task.CreatedBy == user.Name { - tasks = append(tasks, task) - } - } - } - - h.logger.Info("building status response") - status := map[string]any{ - "user": map[string]any{ - "name": user.Name, - "admin": user.Admin, - "roles": user.Roles, - }, - "tasks": map[string]any{ - "total": len(tasks), - "queued": countTasksByStatus(tasks, "queued"), - "running": countTasksByStatus(tasks, "running"), - "failed": countTasksByStatus(tasks, "failed"), - "completed": countTasksByStatus(tasks, "completed"), - }, - "queue": tasks, - } - if h.queue != nil { - if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil { - sort.Slice(states, func(i, j int) bool { - if states[i].WorkerID != states[j].WorkerID { - return states[i].WorkerID < states[j].WorkerID - } - return states[i].TaskID < states[j].TaskID - }) - status["prewarm"] = states - } - } - - h.logger.Info("serializing JSON response") - jsonData, err := json.Marshal(status) - if err != nil { - h.logger.Error("failed to marshal JSON", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Internal error", - "Failed to serialize response", - ) - } - - h.logger.Info("sending websocket JSON response", "len", len(jsonData)) - return h.sendResponsePacket(conn, NewDataPacket("status", jsonData)) -} - -// countTasksByStatus counts tasks by their status -func countTasksByStatus(tasks []*queue.Task, status string) int { - count := 0 - for _, task := range tasks { - if task.Status == status { - count++ - } - } - return count -} - -func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinCancelJob) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { - return err - } - - jobNameLen := int(payload[ProtocolAPIKeyHashLen]) - if len(payload) < ProtocolAPIKeyHashLen+1+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - - jobName := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+jobNameLen]) - h.logger.Info("cancel job request", "job", jobName) - - if h.queue == nil { - h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) - return nil - } - - task, err := h.queue.GetTaskByName(jobName) - if err != nil { - h.logger.Error("task not found", "job", jobName, "error", err) - return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) - } - - if h.authConfig != nil && h.authConfig.Enabled && !user.Admin && - task.UserID != user.Name && task.CreatedBy != user.Name { - h.logger.Error( - "unauthorized job cancellation attempt", - "user", user.Name, - "job", jobName, - "task_owner", task.UserID, - ) - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "You can only cancel your own jobs", - "", - ) - } - - if err := h.queue.CancelTask(task.ID); err != nil { - h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err) - return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error()) - } - - h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name) - return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))) -} - -func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinPrune) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { - return err - } - - pruneType := payload[ProtocolAPIKeyHashLen] - value := binary.BigEndian.Uint32(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+5]) - - h.logger.Info("prune request", "type", pruneType, "value", value) - - var keepCount int - var olderThanDays int - - switch pruneType { - case 0: - keepCount = int(value) - case 1: - olderThanDays = int(value) - default: - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - fmt.Sprintf("invalid prune type: %d", pruneType), - "", - ) - } - - pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays) - if err != nil { - h.logger.Error("prune failed", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error()) - } - if h.queue != nil { - _ = h.queue.SignalPrewarmGC() - } - - h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned) - return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))) -} - -func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinLogMetric) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { - return err - } - - commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen] - step := int(binary.BigEndian.Uint32(payload[36:40])) - valueBits := binary.BigEndian.Uint64(payload[40:48]) - value := math.Float64frombits(valueBits) - nameLen := int(payload[48]) - - if len(payload) < 49+nameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "") - } - - name := string(payload[49 : 49+nameLen]) - - if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil { - h.logger.Error("failed to log metric", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) - } - - return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged")) -} - -func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinGetExperiment) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsRead, conn); err != nil { - return err - } - - commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen] - - meta, err := h.expManager.ReadMetadata(fmt.Sprintf("%x", commitID)) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error()) - } - - metrics, err := h.expManager.GetMetrics(fmt.Sprintf("%x", commitID)) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error()) - } - - var dbMeta *storage.ExperimentWithMetadata - if h.db != nil { - ctx, cancel := helpers.DBContextShort() - defer cancel() - m, err := h.db.GetExperimentWithMetadata(ctx, fmt.Sprintf("%x", commitID)) - if err == nil { - dbMeta = m - } - } - - response := map[string]interface{}{ - "metadata": meta, - "metrics": metrics, - } - if dbMeta != nil { - response["reproducibility"] = dbMeta - } - - responseData, err := json.Marshal(response) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Failed to serialize response", - err.Error(), - ) - } - - return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData)) -} - -// handleGetLogs handles requests to fetch logs for a task/run -func (h *WSHandler) handleGetLogs(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinGetLogs) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsRead, conn); err != nil { - return err - } - - targetIDLen := int(payload[ProtocolAPIKeyHashLen]) - if len(payload) < ProtocolMinGetLogs+targetIDLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", fmt.Sprintf("got %d, need %d", len(payload), ProtocolMinGetLogs+targetIDLen)) - } - - targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen]) - h.logger.Info("get logs request", "target_id", targetID, "user", user.Name) - - // TODO: Implement actual log fetching from storage - // For now, return a stub response - response := map[string]interface{}{ - "target_id": targetID, - "logs": "[Stub] Log content would appear here\nLine 1: Log output\nLine 2: More output\n", - "truncated": false, - "total_lines": 3, - } - - responseData, err := json.Marshal(response) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error()) - } - - return h.sendResponsePacket(conn, NewDataPacket("logs", responseData)) -} - -// handleStreamLogs handles requests to stream logs in real-time -func (h *WSHandler) handleStreamLogs(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinStreamLogs) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJobsRead, conn); err != nil { - return err - } - - targetIDLen := int(payload[ProtocolAPIKeyHashLen]) - if len(payload) < ProtocolMinStreamLogs+targetIDLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", "") - } - - targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen]) - h.logger.Info("stream logs request", "target_id", targetID, "user", user.Name) - - // TODO: Implement actual log streaming - // For now, return a stub response indicating streaming started - response := map[string]interface{}{ - "target_id": targetID, - "streaming": true, - "message": "[Stub] Log streaming would start here. This feature is not yet fully implemented.", - } - - responseData, err := json.Marshal(response) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error()) - } - - return h.sendResponsePacket(conn, NewDataPacket("logs_stream", responseData)) -} diff --git a/internal/api/ws_jupyter.go b/internal/api/ws_jupyter.go deleted file mode 100644 index 578c089..0000000 --- a/internal/api/ws_jupyter.go +++ /dev/null @@ -1,512 +0,0 @@ -package api - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/api/helpers" - "github.com/jfraeys/fetch_ml/internal/container" - "github.com/jfraeys/fetch_ml/internal/queue" -) - -// JupyterTaskErrorCode returns the error code for a Jupyter task. -// This is kept for backward compatibility and delegates to the helper. -func JupyterTaskErrorCode(t *queue.Task) byte { - mapper := helpers.NewTaskErrorMapper() - return byte(mapper.MapJupyterError(t)) -} - -type jupyterTaskOutput struct { - Type string `json:"type"` - Service json.RawMessage `json:"service,omitempty"` - Services json.RawMessage `json:"services,omitempty"` - Packages json.RawMessage `json:"packages,omitempty"` - RestorePath string `json:"restore_path,omitempty"` -} - -func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, 18) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJupyterManage, conn); err != nil { - return err - } - - offset := ProtocolAPIKeyHashLen - nameLen := int(payload[offset]) - offset++ - if len(payload) < offset+nameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") - } - name := string(payload[offset : offset+nameLen]) - - meta := map[string]string{ - jupyterTaskActionKey: jupyterActionRestore, - jupyterNameKey: strings.TrimSpace(name), - } - jobName := fmt.Sprintf("jupyter-restore-%s", strings.TrimSpace(name)) - taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) - if err != nil { - h.logger.Error("failed to enqueue jupyter restore", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "") - } - - result, err := h.waitForTask(taskID, 2*time.Minute) - if err != nil { - h.logger.Error("failed waiting for jupyter restore", "error", err) - return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") - } - if result.Status != "completed" { - return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error)) - } - - msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name)) - out := strings.TrimSpace(result.Output) - if out != "" { - var payloadOut jupyterTaskOutput - if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { - if strings.TrimSpace(payloadOut.RestorePath) != "" { - msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", strings.TrimSpace(name), strings.TrimSpace(payloadOut.RestorePath)) - } - } - } - - return h.sendResponsePacket(conn, NewSuccessPacket(msg)) -} - -type jupyterServiceView struct { - Name string `json:"name"` - URL string `json:"url"` -} - -const ( - jupyterTaskTypeKey = "task_type" - jupyterTaskTypeValue = "jupyter" - - jupyterTaskActionKey = "jupyter_action" - jupyterActionStart = "start" - jupyterActionStop = "stop" - jupyterActionRemove = "remove" - jupyterActionRestore = "restore" - jupyterActionList = "list" - jupyterActionListPkgs = "list_packages" - - jupyterNameKey = "jupyter_name" - jupyterWorkspaceKey = "jupyter_workspace" - jupyterServiceIDKey = "jupyter_service_id" -) - -func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][name_len:1][name:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter packages payload too short", "") - } - - apiKeyHash := payload[:16] - - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - user, err := h.validateWSUser(apiKeyHash) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - if user != nil && !user.HasPermission("jupyter:read") { - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") - } - - p := helpers.NewPayloadParser(payload, 16) - name, err := p.ParseLengthPrefixedString() - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") - } - name = strings.TrimSpace(name) - if name == "" { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "") - } - - meta := map[string]string{ - jupyterTaskActionKey: jupyterActionListPkgs, - jupyterNameKey: name, - } - jobName := fmt.Sprintf("jupyter-packages-%s", name) - taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) - if err != nil { - h.logger.Error("failed to enqueue jupyter packages list", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter packages list", "") - } - - result, err := h.waitForTask(taskID, 2*time.Minute) - if err != nil { - h.logger.Error("failed waiting for jupyter packages list", "error", err) - return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") - } - if result.Status != "completed" { - return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter packages", strings.TrimSpace(result.Error)) - } - - out := strings.TrimSpace(result.Output) - if out == "" { - return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{}))) - } - var payloadOut jupyterTaskOutput - if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { - payload := payloadOut.Packages - if len(payload) == 0 { - payload = []byte("[]") - } - return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload)) - } - - return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{}))) -} - -func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) { - if h.queue == nil { - return "", fmt.Errorf("task queue not configured") - } - if err := container.ValidateJobName(jobName); err != nil { - return "", err - } - if strings.TrimSpace(userName) == "" { - return "", fmt.Errorf("missing user") - } - if meta == nil { - meta = make(map[string]string) - } - meta[jupyterTaskTypeKey] = jupyterTaskTypeValue - - taskID := uuid.New().String() - task := &queue.Task{ - ID: taskID, - JobName: jobName, - Args: "", - Status: "queued", - Priority: 100, // high priority; interactive request - CreatedAt: time.Now(), - UserID: userName, - Username: userName, - CreatedBy: userName, - Metadata: meta, - } - if err := h.queue.AddTask(task); err != nil { - return "", err - } - return taskID, nil -} - -func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) { - if h.queue == nil { - return nil, fmt.Errorf("task queue not configured") - } - deadline := time.Now().Add(timeout) - for { - if time.Now().After(deadline) { - return nil, fmt.Errorf("timed out waiting for worker") - } - t, err := h.queue.GetTask(taskID) - if err != nil { - return nil, err - } - if t == nil { - time.Sleep(200 * time.Millisecond) - continue - } - if t.Status == "completed" || t.Status == "failed" || t.Status == "cancelled" { - return t, nil - } - time.Sleep(200 * time.Millisecond) - } -} - -func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error { - // Protocol: - // [api_key_hash:16][name][workspace][password] - if len(payload) < 21 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "") - } - - apiKeyHash := payload[:16] - - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - user, err := h.validateWSUser(apiKeyHash) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - if user != nil && !user.HasPermission("jupyter:manage") { - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") - } - - offset := 16 - nameLen := int(payload[offset]) - offset++ - - if len(payload) < offset+nameLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") - } - - name := string(payload[offset : offset+nameLen]) - offset += nameLen - - workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - - if len(payload) < offset+workspaceLen+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "") - } - - workspace := string(payload[offset : offset+workspaceLen]) - offset += workspaceLen - - passwordLen := int(payload[offset]) - offset++ - - if len(payload) < offset+passwordLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "") - } - - // Password is parsed but not used in StartRequest - // offset += passwordLen (already advanced during parsing) - - meta := map[string]string{ - jupyterTaskActionKey: jupyterActionStart, - jupyterNameKey: strings.TrimSpace(name), - jupyterWorkspaceKey: strings.TrimSpace(workspace), - } - jobName := fmt.Sprintf("jupyter-%s", strings.TrimSpace(name)) - taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) - if err != nil { - h.logger.Error("failed to enqueue jupyter task", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "") - } - - result, err := h.waitForTask(taskID, 2*time.Minute) - if err != nil { - h.logger.Error("failed waiting for jupyter start", "error", err) - return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") - } - if result.Status != "completed" { - h.logger.Error("jupyter task failed", "error", result.Error) - details := strings.TrimSpace(result.Error) - lower := strings.ToLower(details) - if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") { - return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details) - } - return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to start Jupyter service", details) - } - - msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name)) - out := strings.TrimSpace(result.Output) - if out != "" { - var payloadOut jupyterTaskOutput - if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 { - var svc jupyterServiceView - if err := json.Unmarshal(payloadOut.Service, &svc); err == nil { - if strings.TrimSpace(svc.URL) != "" { - msg = fmt.Sprintf("Started Jupyter service '%s' at %s", strings.TrimSpace(name), strings.TrimSpace(svc.URL)) - } - } - } - } - return h.sendResponsePacket(conn, NewSuccessPacket(msg)) -} - -func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][service_id_len:1][service_id:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "") - } - - apiKeyHash := payload[:16] - - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - user, err := h.validateWSUser(apiKeyHash) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - if user != nil && !user.HasPermission("jupyter:manage") { - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") - } - - p := helpers.NewPayloadParser(payload, 16) - serviceID, err := p.ParseLengthPrefixedString() - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "") - } - - meta := map[string]string{ - jupyterTaskActionKey: jupyterActionStop, - jupyterServiceIDKey: strings.TrimSpace(serviceID), - } - jobName := fmt.Sprintf("jupyter-stop-%s", strings.TrimSpace(serviceID)) - taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) - if err != nil { - h.logger.Error("failed to enqueue jupyter stop", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "") - } - result, err := h.waitForTask(taskID, 2*time.Minute) - if err != nil { - h.logger.Error("failed waiting for jupyter stop", "error", err) - return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") - } - if result.Status != "completed" { - return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to stop Jupyter service", strings.TrimSpace(result.Error)) - } - return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID))) -} - -func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][service_id_len:1][service_id:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "") - } - - apiKeyHash := payload[:16] - - p := helpers.NewPayloadParser(payload, 16) - serviceID, err := p.ParseLengthPrefixedString() - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "") - } - - // Optional: purge flag (1 byte). Default false for trash-first behavior. - purge := false - if p.HasRemaining() { - purgeByte, _ := p.ParseByte() - purge = purgeByte == 0x01 - } - - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - user, err := h.validateWSUser(apiKeyHash) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - if user != nil && !user.HasPermission("jupyter:manage") { - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") - } - - meta := map[string]string{ - jupyterTaskActionKey: jupyterActionRemove, - jupyterServiceIDKey: strings.TrimSpace(serviceID), - "jupyter_purge": fmt.Sprintf("%t", purge), - } - jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID)) - taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) - if err != nil { - h.logger.Error("failed to enqueue jupyter remove", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "") - } - - result, err := h.waitForTask(taskID, 2*time.Minute) - if err != nil { - h.logger.Error("failed waiting for jupyter remove", "error", err) - return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") - } - if result.Status != "completed" { - return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to remove Jupyter service", strings.TrimSpace(result.Error)) - } - return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID))) -} - -func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error { - user, err := h.authenticate(conn, payload, ProtocolMinDatasetList) - if err != nil { - return err - } - if err := h.requirePermission(user, PermJupyterRead, conn); err != nil { - return err - } - - meta := map[string]string{ - jupyterTaskActionKey: jupyterActionList, - } - jobName := fmt.Sprintf("jupyter-list-%s", user.Name) - taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) - if err != nil { - h.logger.Error("failed to enqueue jupyter list", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "") - } - result, err := h.waitForTask(taskID, 2*time.Minute) - if err != nil { - h.logger.Error("failed waiting for jupyter list", "error", err) - return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") - } - if result.Status != "completed" { - return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter services", strings.TrimSpace(result.Error)) - } - - out := strings.TrimSpace(result.Output) - if out == "" { - return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{}))) - } - var payloadOut jupyterTaskOutput - if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { - payload := payloadOut.Services - if len(payload) == 0 { - payload = []byte("[]") - } - return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload)) - } - return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{}))) -} diff --git a/internal/api/ws_tls_auth.go b/internal/api/ws_tls_auth.go deleted file mode 100644 index 41730be..0000000 --- a/internal/api/ws_tls_auth.go +++ /dev/null @@ -1,100 +0,0 @@ -package api - -import ( - "crypto/tls" - "fmt" - "net/http" - "time" - - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/auth" - "golang.org/x/crypto/acme/autocert" -) - -// SetupTLSConfig creates TLS configuration for WebSocket server -func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) { - var server *http.Server - - if certFile != "" && keyFile != "" { - // Use provided certificates - server = &http.Server{ - ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks - TLSConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - }, - }, - } - } else if host != "" { - // Use Let's Encrypt with autocert - certManager := &autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(host), - Cache: autocert.DirCache("/var/www/.cache"), - } - - server = &http.Server{ - ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks - TLSConfig: certManager.TLSConfig(), - } - } - - return server, nil -} - -// verifyAPIKeyHash verifies the provided binary hash against stored API keys -func (h *WSHandler) verifyAPIKeyHash(hash []byte) error { - if h.authConfig == nil || !h.authConfig.Enabled { - return nil // No auth required - } - _, err := h.authConfig.ValidateAPIKeyHash(hash) - if err != nil { - return fmt.Errorf("invalid api key") - } - return nil -} - -// sendErrorPacket sends an error response packet -func (h *WSHandler) sendErrorPacket( - conn *websocket.Conn, - errorCode byte, - message string, - details string, -) error { - packet := NewErrorPacket(errorCode, message, details) - return h.sendResponsePacket(conn, packet) -} - -// sendResponsePacket sends a structured response packet -func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error { - data, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize response packet", "error", err) - // Fallback to simple error response - return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) - } - - return conn.WriteMessage(websocket.BinaryMessage, data) -} - -func (h *WSHandler) validateWSUser(apiKeyHash []byte) (*auth.User, error) { - if h.authConfig != nil { - user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - return nil, err - } - return user, nil - } - - return &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - }, nil -} diff --git a/internal/api/ws_validate.go b/internal/api/ws_validate.go deleted file mode 100644 index 6d88d38..0000000 --- a/internal/api/ws_validate.go +++ /dev/null @@ -1,523 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/api/helpers" - "github.com/jfraeys/fetch_ml/internal/container" - "github.com/jfraeys/fetch_ml/internal/manifest" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/storage" - "github.com/jfraeys/fetch_ml/internal/worker" -) - -type validateCheck struct { - OK bool `json:"ok"` - Expected string `json:"expected,omitempty"` - Actual string `json:"actual,omitempty"` - Details string `json:"details,omitempty"` -} - -type validateReport struct { - OK bool `json:"ok"` - CommitID string `json:"commit_id,omitempty"` - TaskID string `json:"task_id,omitempty"` - Checks map[string]validateCheck `json:"checks"` - Errors []string `json:"errors,omitempty"` - Warnings []string `json:"warnings,omitempty"` - TS string `json:"ts"` -} - -func shouldRequireRunManifest(task *queue.Task) bool { - if task == nil { - return false - } - s := strings.ToLower(strings.TrimSpace(task.Status)) - switch s { - case "running", "completed", "failed": - return true - default: - return false - } -} - -func expectedRunManifestBucketForStatus(status string) (string, bool) { - s := strings.ToLower(strings.TrimSpace(status)) - switch s { - case "queued", "pending": - return "pending", true - case "running": - return "running", true - case "completed", "finished": - return "finished", true - case "failed": - return "failed", true - default: - return "", false - } -} - -func findRunManifestDir(basePath string, jobName string) (string, string, bool) { - if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" { - return "", "", false - } - jobPaths := storage.NewJobPaths(basePath) - typedRoots := []struct { - bucket string - root string - }{ - {bucket: "running", root: jobPaths.RunningPath()}, - {bucket: "pending", root: jobPaths.PendingPath()}, - {bucket: "finished", root: jobPaths.FinishedPath()}, - {bucket: "failed", root: jobPaths.FailedPath()}, - } - for _, item := range typedRoots { - root := item.root - dir := filepath.Join(root, jobName) - if info, err := os.Stat(dir); err == nil && info.IsDir() { - if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { - return dir, item.bucket, true - } - } - } - return "", "", false -} - -func validateResourcesForTask(task *queue.Task) (validateCheck, []string) { - if task == nil { - return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"} - } - - if task.CPU < 0 { - chk := validateCheck{OK: false, Details: "cpu must be >= 0"} - return chk, []string{"invalid cpu request"} - } - if task.MemoryGB < 0 { - chk := validateCheck{OK: false, Details: "memory_gb must be >= 0"} - return chk, []string{"invalid memory request"} - } - if task.GPU < 0 { - chk := validateCheck{OK: false, Details: "gpu must be >= 0"} - return chk, []string{"invalid gpu request"} - } - - if strings.TrimSpace(task.GPUMemory) != "" { - s := strings.TrimSpace(task.GPUMemory) - if strings.HasSuffix(s, "%") { - v := strings.TrimSuffix(s, "%") - f, err := strconv.ParseFloat(strings.TrimSpace(v), 64) - if err != nil || f <= 0 || f > 100 { - details := "gpu_memory must be a percentage in (0,100]" - chk := validateCheck{OK: false, Details: details} - return chk, []string{"invalid gpu_memory"} - } - } else { - f, err := strconv.ParseFloat(s, 64) - if err != nil || f <= 0 || f > 1 { - chk := validateCheck{OK: false, Details: "gpu_memory must be a fraction in (0,1]"} - return chk, []string{"invalid gpu_memory"} - } - } - } - - if task.GPU == 0 && strings.TrimSpace(task.GPUMemory) != "" { - chk := validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"} - return chk, []string{"invalid gpu_memory"} - } - - return validateCheck{OK: true}, nil -} - -func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var] - // target_type: 0=commit_id (20 bytes), 1=task_id (string) - // TODO(context): Add a versioned validate protocol once we need more target types/fields. - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "") - } - apiKeyHash := payload[:16] - targetType := payload[16] - idLen := int(payload[17]) - if idLen < 1 || len(payload) < 18+idLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "") - } - idBytes := payload[18 : 18+idLen] - - // Validate API key and user - user, err := h.validateWSUser(apiKeyHash) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Invalid API key", - err.Error(), - ) - } - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") { - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to validate jobs", - "", - ) - } - if h.expManager == nil { - return h.sendErrorPacket( - conn, - ErrorCodeServiceUnavailable, - "Experiment manager not available", - "", - ) - } - - var task *queue.Task - commitID := "" - switch targetType { - case 0: - if len(idBytes) != 20 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "") - } - commitID = fmt.Sprintf("%x", idBytes) - case 1: - taskID := string(idBytes) - if h.queue == nil { - return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "") - } - t, err := h.queue.GetTask(taskID) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error()) - } - task = t - if h.authConfig != nil && - h.authConfig.Enabled && - !user.Admin && - task.UserID != user.Name && - task.CreatedBy != user.Name { - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "You can only validate your own jobs", - "", - ) - } - if task.Metadata == nil || task.Metadata["commit_id"] == "" { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "") - } - commitID = task.Metadata["commit_id"] - default: - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "") - } - - r := validateReport{ - OK: true, - TS: time.Now().UTC().Format(time.RFC3339Nano), - Checks: map[string]validateCheck{}, - } - if task != nil { - r.TaskID = task.ID - } - if commitID != "" { - r.CommitID = commitID - } - - // Validate commit id format - if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok { - r.OK = false - r.Errors = append(r.Errors, errMsg) - } - - // Experiment manifest integrity - // TODO(context): Extend report to include per-file diff list on mismatch (bounded output). - if r.OK { - if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok { - r.OK = false - r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details} - r.Errors = append(r.Errors, "experiment manifest validation failed") - } else { - r.Checks["experiment_manifest"] = validateCheck{OK: true} - } - } - - // Deps manifest presence + hash - // TODO(context): Allow client to declare which dependency manifest is authoritative. - filesPath := h.expManager.GetFilesPath(commitID) - depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID) - if depErrs != nil { - r.OK = false - r.Checks["deps_manifest"] = validateCheck(depCheck) - r.Errors = append(r.Errors, depErrs...) - } else { - r.Checks["deps_manifest"] = validateCheck(depCheck) - } - - // Compare against expected task metadata if available. - if task != nil { - resCheck, resErrs := validateResourcesForTask(task) - r.Checks["resources"] = resCheck - if !resCheck.OK { - r.OK = false - r.Errors = append(r.Errors, resErrs...) - } - - // Run manifest checks: best-effort for queued tasks, required for running/completed/failed. - if err := container.ValidateJobName(task.JobName); err != nil { - r.OK = false - r.Errors = append(r.Errors, "invalid job name") - r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"} - } else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" { - if shouldRequireRunManifest(task) { - r.OK = false - r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest") - r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"} - } else { - r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest") - r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"} - } - } else { - manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName) - if !found { - if shouldRequireRunManifest(task) { - r.OK = false - r.Errors = append(r.Errors, "run manifest not found") - r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"} - } else { - r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)") - r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"} - } - } else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil { - r.OK = false - r.Errors = append(r.Errors, "unable to read run manifest") - r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"} - } else { - r.Checks["run_manifest"] = validateCheck{OK: true} - - expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status) - if ok { - if expectedBucket != manifestBucket { - msg := "run manifest location mismatch" - chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket} - if shouldRequireRunManifest(task) { - r.OK = false - r.Errors = append(r.Errors, msg) - r.Checks["run_manifest_location"] = chk - } else { - r.Warnings = append(r.Warnings, msg) - r.Checks["run_manifest_location"] = chk - } - } else { - r.Checks["run_manifest_location"] = validateCheck{ - OK: true, - Expected: expectedBucket, - Actual: manifestBucket, - } - } - } - - // Validate task ID using helper - taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID) - r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck) - if !taskIDCheck.OK { - r.OK = false - r.Errors = append(r.Errors, "run manifest task_id mismatch") - } - - // Validate commit ID using helper - commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"]) - r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck) - if !commitCheck.OK { - r.OK = false - r.Errors = append(r.Errors, "run manifest commit_id mismatch") - } - - // Validate deps provenance using helper - depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"]) - depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) - depGotName := strings.TrimSpace(rm.DepsManifestName) - depGotSHA := strings.TrimSpace(rm.DepsManifestSHA) - depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA) - r.Checks["run_manifest_deps"] = validateCheck(depsCheck) - if !depsCheck.OK { - r.OK = false - r.Errors = append(r.Errors, "run manifest deps provenance mismatch") - } - - // Validate snapshot using helpers - if strings.TrimSpace(task.SnapshotID) != "" { - snapWantID := strings.TrimSpace(task.SnapshotID) - snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"]) - snapGotID := strings.TrimSpace(rm.SnapshotID) - snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256) - - snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID) - r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck) - if !snapIDCheck.OK { - r.OK = false - r.Errors = append(r.Errors, "run manifest snapshot_id mismatch") - } - - snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA) - r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck) - if !snapSHACheck.OK { - r.OK = false - r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch") - } - } - - // Validate lifecycle using helper - lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status) - if lifecycleOK { - r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true} - } else { - chk := validateCheck{OK: false, Details: details} - if shouldRequireRunManifest(task) { - r.OK = false - r.Errors = append(r.Errors, "run manifest lifecycle invalid") - r.Checks["run_manifest_lifecycle"] = chk - } else { - r.Warnings = append(r.Warnings, "run manifest lifecycle invalid") - r.Checks["run_manifest_lifecycle"] = chk - } - } - } - } - - want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"]) - cur := "" - if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil { - cur = man.OverallSHA - } - if want == "" { - r.OK = false - r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha") - r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur} - } else if cur == "" { - r.OK = false - r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha") - r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want} - } else if want != cur { - r.OK = false - r.Errors = append(r.Errors, "experiment manifest overall sha mismatch") - r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur} - } else { - r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur} - } - - wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"]) - wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) - if wantDep == "" || wantDepSha == "" { - r.OK = false - r.Errors = append(r.Errors, "missing expected deps manifest provenance") - r.Checks["expected_deps_manifest"] = validateCheck{OK: false} - } else if depName != "" { - sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName)) - ok := (wantDep == depName && wantDepSha == sha) - if !ok { - r.OK = false - r.Errors = append(r.Errors, "deps manifest provenance mismatch") - r.Checks["expected_deps_manifest"] = validateCheck{ - OK: false, - Expected: wantDep + ":" + wantDepSha, - Actual: depName + ":" + sha, - } - } else { - r.Checks["expected_deps_manifest"] = validateCheck{ - OK: true, - Expected: wantDep + ":" + wantDepSha, - Actual: depName + ":" + sha, - } - } - } - - // Snapshot/dataset checks require dataDir. - // TODO(context): Support snapshot stores beyond local filesystem (e.g. S3). - // TODO(context): Validate snapshots by digest. - if task.SnapshotID != "" { - if h.dataDir == "" { - r.OK = false - r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot") - r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"} - } else { - wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"]) - if nerr != nil || wantSnap == "" { - r.OK = false - r.Errors = append(r.Errors, "missing/invalid snapshot_sha256") - r.Checks["snapshot"] = validateCheck{OK: false} - } else { - curSnap, err := worker.DirOverallSHA256Hex( - filepath.Join(h.dataDir, "snapshots", task.SnapshotID), - ) - if err != nil { - r.OK = false - r.Errors = append(r.Errors, "snapshot hash computation failed") - r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()} - } else if curSnap != wantSnap { - r.OK = false - r.Errors = append(r.Errors, "snapshot checksum mismatch") - r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap} - } else { - r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap} - } - } - } - } - - if len(task.DatasetSpecs) > 0 { - // TODO(context): Add dataset URI fetch/verification. - // TODO(context): Currently only validates local materialized datasets. - for _, ds := range task.DatasetSpecs { - if ds.Checksum == "" { - continue - } - key := "dataset:" + ds.Name - if h.dataDir == "" { - r.OK = false - r.Errors = append( - r.Errors, - "api server data_dir not configured; cannot validate dataset checksums", - ) - r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"} - continue - } - wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum) - if nerr != nil || wantDS == "" { - r.OK = false - r.Errors = append(r.Errors, "invalid dataset checksum format") - r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"} - continue - } - curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name)) - if err != nil { - r.OK = false - r.Errors = append(r.Errors, "dataset checksum computation failed") - r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()} - continue - } - if curDS != wantDS { - r.OK = false - r.Errors = append(r.Errors, "dataset checksum mismatch") - r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS} - continue - } - r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS} - } - } - } - - payloadBytes, err := json.Marshal(r) - if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeUnknownError, - "failed to serialize validate report", - err.Error(), - ) - } - return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes)) -}