diff --git a/cmd/tui/internal/services/export.go b/cmd/tui/internal/services/export.go index 03ce623..890f765 100644 --- a/cmd/tui/internal/services/export.go +++ b/cmd/tui/internal/services/export.go @@ -28,19 +28,19 @@ func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportS // Returns the path to the exported file func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) { s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize) - + // Placeholder - actual implementation would call API // POST /api/jobs/{id}/export?anonymize=true - + exportPath := fmt.Sprintf("/tmp/%s_export_%d.tar.gz", jobName, time.Now().Unix()) - + s.logger.Info("export complete", "job", jobName, "path", exportPath) return exportPath, nil } // ExportOptions contains options for export type ExportOptions struct { - Anonymize bool + Anonymize bool IncludeLogs bool IncludeData bool } diff --git a/cmd/tui/internal/services/websocket.go b/cmd/tui/internal/services/websocket.go index 670a725..0653f3b 100644 --- a/cmd/tui/internal/services/websocket.go +++ b/cmd/tui/internal/services/websocket.go @@ -16,37 +16,37 @@ import ( // WebSocketClient manages real-time updates from the server type WebSocketClient struct { - conn *websocket.Conn - serverURL string - apiKey string - logger *logging.Logger - + conn *websocket.Conn + serverURL string + apiKey string + logger *logging.Logger + // Channels for different update types - jobUpdates chan model.JobUpdateMsg - gpuUpdates chan model.GPUUpdateMsg + jobUpdates chan model.JobUpdateMsg + gpuUpdates chan model.GPUUpdateMsg statusUpdates chan model.StatusMsg - + // Control - ctx context.Context - cancel context.CancelFunc - connected bool + ctx context.Context + cancel context.CancelFunc + connected bool } // JobUpdateMsg represents a real-time job status update type JobUpdateMsg struct { - JobName string `json:"job_name"` - Status string `json:"status"` - TaskID string `json:"task_id"` - Progress int `json:"progress"` + JobName string `json:"job_name"` + Status string `json:"status"` + TaskID string `json:"task_id"` + Progress int `json:"progress"` } // GPUUpdateMsg represents a real-time GPU status update type GPUUpdateMsg struct { - DeviceID int `json:"device_id"` - Utilization int `json:"utilization"` - MemoryUsed int64 `json:"memory_used"` - MemoryTotal int64 `json:"memory_total"` - Temperature int `json:"temperature"` + DeviceID int `json:"device_id"` + Utilization int `json:"utilization"` + MemoryUsed int64 `json:"memory_used"` + MemoryTotal int64 `json:"memory_total"` + Temperature int `json:"temperature"` } // NewWebSocketClient creates a new WebSocket client @@ -71,26 +71,26 @@ func (c *WebSocketClient) Connect() error { if err != nil { return fmt.Errorf("invalid server URL: %w", err) } - + // Convert http/https to ws/wss wsScheme := "ws" if u.Scheme == "https" { wsScheme = "wss" } wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host) - + // Create dialer with timeout dialer := websocket.Dialer{ HandshakeTimeout: 10 * time.Second, Subprotocols: []string{"fetchml-v1"}, } - + // Add API key to headers headers := http.Header{} if c.apiKey != "" { headers.Set("X-API-Key", c.apiKey) } - + conn, resp, err := dialer.Dial(wsURL, headers) if err != nil { if resp != nil { @@ -98,17 +98,17 @@ func (c *WebSocketClient) Connect() error { } return fmt.Errorf("websocket dial failed: %w", err) } - + c.conn = conn c.connected = true c.logger.Info("websocket connected", "url", wsURL) - + // Start message handler go c.messageHandler() - + // Start heartbeat go c.heartbeat() - + return nil } @@ -134,15 +134,15 @@ func (c *WebSocketClient) messageHandler() { return default: } - + if c.conn == nil { time.Sleep(100 * time.Millisecond) continue } - + // Set read deadline c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - + // Read message messageType, data, err := c.conn.ReadMessage() if err != nil { @@ -150,7 +150,7 @@ func (c *WebSocketClient) messageHandler() { c.logger.Error("websocket read error", "error", err) } c.connected = false - + // Attempt reconnect time.Sleep(5 * time.Second) if err := c.Connect(); err != nil { @@ -158,7 +158,7 @@ func (c *WebSocketClient) messageHandler() { } continue } - + // Handle binary vs text messages if messageType == websocket.BinaryMessage { c.handleBinaryMessage(data) @@ -173,11 +173,11 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) { if len(data) < 2 { return } - + // Binary protocol: [opcode:1][data...] opcode := data[0] payload := data[1:] - + switch opcode { case 0x01: // Job update var update JobUpdateMsg @@ -186,7 +186,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) { return } c.jobUpdates <- model.JobUpdateMsg(update) - + case 0x02: // GPU update var update GPUUpdateMsg if err := json.Unmarshal(payload, &update); err != nil { @@ -194,7 +194,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) { return } c.gpuUpdates <- model.GPUUpdateMsg(update) - + case 0x03: // Status message var status model.StatusMsg if err := json.Unmarshal(payload, &status); err != nil { @@ -212,7 +212,7 @@ func (c *WebSocketClient) handleTextMessage(data []byte) { c.logger.Error("failed to unmarshal text message", "error", err) return } - + msgType, _ := msg["type"].(string) switch msgType { case "job_update": @@ -228,7 +228,7 @@ func (c *WebSocketClient) handleTextMessage(data []byte) { func (c *WebSocketClient) heartbeat() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() - + for { select { case <-c.ctx.Done(): @@ -249,12 +249,12 @@ func (c *WebSocketClient) Subscribe(channels ...string) error { if !c.connected { return fmt.Errorf("not connected") } - + subMsg := map[string]interface{}{ "action": "subscribe", "channels": channels, } - + data, _ := json.Marshal(subMsg) return c.conn.WriteMessage(websocket.TextMessage, data) } diff --git a/internal/api/adapter.go b/internal/api/adapter.go new file mode 100644 index 0000000..cfa3f3d --- /dev/null +++ b/internal/api/adapter.go @@ -0,0 +1,156 @@ +// Package api provides HTTP handlers and OpenAPI-generated server interface implementations +package api + +import ( + "net/http" + + "github.com/jfraeys/fetch_ml/internal/api/datasets" + "github.com/jfraeys/fetch_ml/internal/api/jobs" + "github.com/jfraeys/fetch_ml/internal/api/jupyter" + "github.com/labstack/echo/v4" +) + +// HandlerAdapter implements the generated ServerInterface using existing handlers +type HandlerAdapter struct { + jobsHandler *jobs.Handler + jupyterHandler *jupyter.Handler + datasetsHandler *datasets.Handler +} + +// NewHandlerAdapter creates a new handler adapter +func NewHandlerAdapter( + jobsHandler *jobs.Handler, + jupyterHandler *jupyter.Handler, + datasetsHandler *datasets.Handler, +) *HandlerAdapter { + return &HandlerAdapter{ + jobsHandler: jobsHandler, + jupyterHandler: jupyterHandler, + datasetsHandler: datasetsHandler, + } +} + +// Ensure HandlerAdapter implements the generated interface +var _ ServerInterface = (*HandlerAdapter)(nil) + +// toHTTPHandler converts echo.Context to standard HTTP handler +func toHTTPHandler(h func(http.ResponseWriter, *http.Request)) echo.HandlerFunc { + return func(c echo.Context) error { + h(c.Response().Writer, c.Request()) + return nil + } +} + +// GetHealth implements the health check endpoint +func (a *HandlerAdapter) GetHealth(ctx echo.Context) error { + return ctx.String(200, "OK\n") +} + +// GetV1Experiments lists all experiments +func (a *HandlerAdapter) GetV1Experiments(ctx echo.Context) error { + return ctx.JSON(200, map[string]any{ + "experiments": []any{}, + "message": "Not yet implemented", + }) +} + +// PostV1Experiments creates a new experiment +func (a *HandlerAdapter) PostV1Experiments(ctx echo.Context) error { + return ctx.JSON(201, map[string]any{ + "message": "Not yet implemented", + }) +} + +// GetV1JupyterServices lists all Jupyter services +func (a *HandlerAdapter) GetV1JupyterServices(ctx echo.Context) error { + if a.jupyterHandler == nil { + return ctx.JSON(503, map[string]any{ + "error": "Jupyter service not available", + "code": "SERVICE_UNAVAILABLE", + }) + } + return toHTTPHandler(a.jupyterHandler.ListServicesHTTP)(ctx) +} + +// PostV1JupyterServices starts a new Jupyter service +func (a *HandlerAdapter) PostV1JupyterServices(ctx echo.Context) error { + if a.jupyterHandler == nil { + return ctx.JSON(503, map[string]any{ + "error": "Jupyter service not available", + "code": "SERVICE_UNAVAILABLE", + }) + } + return toHTTPHandler(a.jupyterHandler.StartServiceHTTP)(ctx) +} + +// DeleteV1JupyterServicesServiceId stops a Jupyter service +func (a *HandlerAdapter) DeleteV1JupyterServicesServiceId(ctx echo.Context, serviceId string) error { + if a.jupyterHandler == nil { + return ctx.JSON(503, map[string]any{ + "error": "Jupyter service not available", + "code": "SERVICE_UNAVAILABLE", + }) + } + // TODO: Implement when StopServiceHTTP is available + return ctx.JSON(501, map[string]any{ + "error": "Not implemented", + "code": "NOT_IMPLEMENTED", + "message": "Jupyter service stop not yet implemented via REST API", + }) +} + +// GetV1Queue returns queue status +func (a *HandlerAdapter) GetV1Queue(ctx echo.Context) error { + return ctx.JSON(200, map[string]any{ + "status": "healthy", + "pending": 0, + "running": 0, + }) +} + +// GetV1Tasks lists all tasks +func (a *HandlerAdapter) GetV1Tasks(ctx echo.Context, params GetV1TasksParams) error { + if a.jobsHandler == nil { + return ctx.JSON(503, map[string]any{ + "error": "Jobs handler not available", + "code": "SERVICE_UNAVAILABLE", + }) + } + return toHTTPHandler(a.jobsHandler.ListAllJobsHTTP)(ctx) +} + +// PostV1Tasks creates a new task +func (a *HandlerAdapter) PostV1Tasks(ctx echo.Context) error { + return ctx.JSON(501, map[string]any{ + "error": "Not implemented", + "code": "NOT_IMPLEMENTED", + "message": "Task creation via REST API not yet implemented - use WebSocket", + }) +} + +// DeleteV1TasksTaskId cancels/deletes a task +func (a *HandlerAdapter) DeleteV1TasksTaskId(ctx echo.Context, taskId string) error { + return ctx.JSON(501, map[string]any{ + "error": "Not implemented", + "code": "NOT_IMPLEMENTED", + "message": "Task cancellation via REST API not yet implemented - use WebSocket", + }) +} + +// GetV1TasksTaskId gets task details +func (a *HandlerAdapter) GetV1TasksTaskId(ctx echo.Context, taskId string) error { + return ctx.JSON(501, map[string]any{ + "error": "Not implemented", + "code": "NOT_IMPLEMENTED", + "message": "Task details via REST API not yet implemented - use WebSocket", + }) +} + +// GetWs handles WebSocket connections +func (a *HandlerAdapter) GetWs(ctx echo.Context) error { + return ctx.JSON(426, map[string]any{ + "error": "WebSocket connection required", + "code": "UPGRADE_REQUIRED", + "message": "Use WebSocket protocol to connect to this endpoint", + }) +} diff --git a/internal/api/errors.go b/internal/api/errors.go deleted file mode 100644 index e118968..0000000 --- a/internal/api/errors.go +++ /dev/null @@ -1,133 +0,0 @@ -// Package api provides error handling utilities for the API -package api - -import ( - "encoding/json" - "net/http" - "time" -) - -// ErrorResponse represents a standardized error response -type ErrorResponse struct { - Error bool `json:"error"` - Code byte `json:"code"` - Message string `json:"message"` - Details string `json:"details,omitempty"` - Timestamp time.Time `json:"timestamp"` - RequestID string `json:"request_id,omitempty"` -} - -// SuccessResponse represents a standardized success response -type SuccessResponse struct { - Success bool `json:"success"` - Data interface{} `json:"data,omitempty"` - Timestamp time.Time `json:"timestamp"` -} - -// WriteError writes a standardized error response -func WriteError(w http.ResponseWriter, code byte, message, details string, statusCode int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - - response := ErrorResponse{ - Error: true, - Code: code, - Message: message, - Details: details, - Timestamp: time.Now().UTC(), - } - - json.NewEncoder(w).Encode(response) -} - -// WriteSuccess writes a standardized success response -func WriteSuccess(w http.ResponseWriter, data interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - - response := SuccessResponse{ - Success: true, - Data: data, - Timestamp: time.Now().UTC(), - } - - json.NewEncoder(w).Encode(response) -} - -// Common error codes for API responses -const ( - ErrCodeUnknownError = 0x00 - ErrCodeInvalidRequest = 0x01 - ErrCodeAuthenticationFailed = 0x02 - ErrCodePermissionDenied = 0x03 - ErrCodeResourceNotFound = 0x04 - ErrCodeResourceAlreadyExists = 0x05 - ErrCodeServerOverloaded = 0x10 - ErrCodeDatabaseError = 0x11 - ErrCodeNetworkError = 0x12 - ErrCodeStorageError = 0x13 - ErrCodeTimeout = 0x14 - ErrCodeJobNotFound = 0x20 - ErrCodeJobAlreadyRunning = 0x21 - ErrCodeJobFailedToStart = 0x22 - ErrCodeJobExecutionFailed = 0x23 - ErrCodeJobCancelled = 0x24 - ErrCodeOutOfMemory = 0x30 - ErrCodeDiskFull = 0x31 - ErrCodeInvalidConfiguration = 0x32 - ErrCodeServiceUnavailable = 0x33 -) - -// HTTP status code mappings -const ( - StatusBadRequest = http.StatusBadRequest - StatusUnauthorized = http.StatusUnauthorized - StatusForbidden = http.StatusForbidden - StatusNotFound = http.StatusNotFound - StatusConflict = http.StatusConflict - StatusInternalServerError = http.StatusInternalServerError - StatusServiceUnavailable = http.StatusServiceUnavailable - StatusTooManyRequests = http.StatusTooManyRequests -) - -// ErrorCodeToHTTPStatus maps API error codes to HTTP status codes -func ErrorCodeToHTTPStatus(code byte) int { - switch code { - case ErrCodeInvalidRequest: - return StatusBadRequest - case ErrCodeAuthenticationFailed: - return StatusUnauthorized - case ErrCodePermissionDenied: - return StatusForbidden - case ErrCodeResourceNotFound, ErrCodeJobNotFound: - return StatusNotFound - case ErrCodeResourceAlreadyExists, ErrCodeJobAlreadyRunning: - return StatusConflict - case ErrCodeServerOverloaded, ErrCodeServiceUnavailable: - return StatusServiceUnavailable - case ErrCodeDatabaseError, ErrCodeNetworkError, ErrCodeStorageError: - return StatusInternalServerError - default: - return StatusInternalServerError - } -} - -// NewErrorResponse creates a new error response with the given details -func NewErrorResponse(code byte, message, details string) ErrorResponse { - return ErrorResponse{ - Error: true, - Code: code, - Message: message, - Details: details, - Timestamp: time.Now().UTC(), - } -} - -// NewSuccessResponse creates a new success response with the given data -func NewSuccessResponse(data interface{}) SuccessResponse { - return SuccessResponse{ - Success: true, - Data: data, - Timestamp: time.Now().UTC(), - } -} diff --git a/internal/api/factory.go b/internal/api/factory.go index ad85ef7..e9313e9 100644 --- a/internal/api/factory.go +++ b/internal/api/factory.go @@ -5,6 +5,7 @@ import ( "os" "strings" + apimiddleware "github.com/jfraeys/fetch_ml/internal/api/middleware" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" @@ -49,8 +50,8 @@ func (s *Server) initializeComponents() error { // Initialize Jupyter service manager s.initJupyterServiceManager() - // Initialize handlers - s.handlers = NewHandlers(s.expManager, nil, s.logger) + // Initialize OpenAPI validation middleware (optional, doesn't block startup) + s.initValidationMiddleware() return nil } @@ -250,6 +251,24 @@ func (s *Server) initAuditLogger() *audit.Logger { return al } +// initValidationMiddleware initializes the OpenAPI validation middleware +func (s *Server) initValidationMiddleware() { + // Only initialize if OpenAPI spec exists + if _, err := os.Stat("api/openapi.yaml"); err != nil { + s.logger.Debug("OpenAPI spec not found, skipping validation middleware") + return + } + + vm, err := apimiddleware.NewValidationMiddleware("api/openapi.yaml") + if err != nil { + s.logger.Warn("failed to initialize OpenAPI validation middleware", "error", err) + return + } + + s.validationMiddleware = vm + s.logger.Info("OpenAPI validation middleware initialized") +} + // getSecurityConfig extracts security config from server config func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig { return &config.SecurityConfig{ diff --git a/internal/api/handlers.go b/internal/api/handlers.go deleted file mode 100644 index 0788ea3..0000000 --- a/internal/api/handlers.go +++ /dev/null @@ -1,313 +0,0 @@ -// Package api provides HTTP handlers for the fetch_ml API server -package api - -import ( - "encoding/json" - "fmt" - "net/http" - - "github.com/jfraeys/fetch_ml/internal/api/helpers" - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/experiment" - "github.com/jfraeys/fetch_ml/internal/jupyter" - "github.com/jfraeys/fetch_ml/internal/logging" -) - -// Handlers groups all HTTP handlers -type Handlers struct { - expManager *experiment.Manager - jupyterServiceMgr *jupyter.ServiceManager - logger *logging.Logger -} - -// NewHandlers creates a new handler group -func NewHandlers( - expManager *experiment.Manager, - jupyterServiceMgr *jupyter.ServiceManager, - logger *logging.Logger, -) *Handlers { - return &Handlers{ - expManager: expManager, - jupyterServiceMgr: jupyterServiceMgr, - logger: logger, - } -} - -// RegisterHandlers registers all HTTP handlers with the mux -func (h *Handlers) RegisterHandlers(mux *http.ServeMux) { - // Health check endpoints - mux.HandleFunc("/db-status", h.handleDBStatus) - - // Jupyter service endpoints - if h.jupyterServiceMgr != nil { - mux.HandleFunc("/api/jupyter/services", h.handleJupyterServices) - mux.HandleFunc("/api/jupyter/experiments/link", h.handleJupyterExperimentLink) - mux.HandleFunc("/api/jupyter/experiments/sync", h.handleJupyterExperimentSync) - } -} - -// handleHealth responds with a simple health check -func (h *Handlers) handleHealth(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprintf(w, "OK\n") -} - -// handleDBStatus responds with database connection status -func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - - // This would need the DB instance passed to handlers - // For now, return a basic response - response := map[string]any{ - "status": "unknown", - "message": "Database status check not implemented", - } - - w.WriteHeader(http.StatusOK) - if _, err := w.Write(helpers.MarshalJSONOrEmpty(response)); err != nil { - h.logger.Error("failed to write response", "error", err) - } -} - -// handleJupyterServices handles Jupyter service management requests -func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - user := auth.GetUserFromContext(r.Context()) - if user == nil { - http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) - return - } - - switch r.Method { - case http.MethodGet: - if !user.HasPermission("jupyter:read") { - http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) - return - } - h.listJupyterServices(w, r) - case http.MethodPost: - if !user.HasPermission("jupyter:manage") { - http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) - return - } - h.startJupyterService(w, r) - case http.MethodDelete: - if !user.HasPermission("jupyter:manage") { - http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) - return - } - h.stopJupyterService(w, r) - default: - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } -} - -// listJupyterServices lists all Jupyter services -func (h *Handlers) listJupyterServices(w http.ResponseWriter, _ *http.Request) { - services := h.jupyterServiceMgr.ListServices() - w.WriteHeader(http.StatusOK) - if _, err := w.Write(helpers.MarshalJSONOrEmpty(services)); err != nil { - h.logger.Error("failed to write response", "error", err) - } -} - -// startJupyterService starts a new Jupyter service -func (h *Handlers) startJupyterService(w http.ResponseWriter, r *http.Request) { - var req jupyter.StartRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - ctx := r.Context() - service, err := h.jupyterServiceMgr.StartService(ctx, &req) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to start service: %v", err), http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusCreated) - if _, err := w.Write(helpers.MarshalJSONOrEmpty(service)); err != nil { - h.logger.Error("failed to write response", "error", err) - } -} - -// stopJupyterService stops a Jupyter service -func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) { - serviceID := r.URL.Query().Get("id") - if serviceID == "" { - http.Error(w, "Service ID is required", http.StatusBadRequest) - return - } - - ctx := r.Context() - if err := h.jupyterServiceMgr.StopService(ctx, serviceID); err != nil { - http.Error(w, fmt.Sprintf("Failed to stop service: %v", err), http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]string{ - "status": "stopped", - "id": serviceID, - }); err != nil { - h.logger.Error("failed to encode response", "error", err) - } -} - -// handleJupyterExperimentLink handles linking Jupyter workspaces with experiments -func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - user := auth.GetUserFromContext(r.Context()) - if user == nil { - http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) - return - } - if !user.HasPermission("jupyter:manage") { - http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) - return - } - - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - var req struct { - Workspace string `json:"workspace"` - ExperimentID string `json:"experiment_id"` - ServiceID string `json:"service_id,omitempty"` - } - - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - if req.Workspace == "" || req.ExperimentID == "" { - http.Error(w, "Workspace and experiment_id are required", http.StatusBadRequest) - return - } - - if !h.expManager.ExperimentExists(req.ExperimentID) { - http.Error(w, "Experiment not found", http.StatusNotFound) - return - } - - // Link workspace with experiment using service manager - if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment( - req.Workspace, - req.ExperimentID, - req.ServiceID, - ); err != nil { - http.Error(w, fmt.Sprintf("Failed to link workspace: %v", err), http.StatusInternalServerError) - return - } - - // Get workspace metadata to return - metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace) - if err != nil { - http.Error( - w, - fmt.Sprintf("Failed to get workspace metadata: %v", err), - http.StatusInternalServerError, - ) - return - } - - h.logger.Info("jupyter workspace linked with experiment", - "workspace", req.Workspace, - "experiment_id", req.ExperimentID, - "service_id", req.ServiceID) - - w.WriteHeader(http.StatusCreated) - if err := json.NewEncoder(w).Encode(map[string]interface{}{ - "status": "linked", - "data": metadata, - }); err != nil { - h.logger.Error("failed to encode response", "error", err) - } -} - -// handleJupyterExperimentSync handles synchronization between Jupyter workspaces and experiments -func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - user := auth.GetUserFromContext(r.Context()) - if user == nil { - http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) - return - } - if !user.HasPermission("jupyter:manage") { - http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) - return - } - - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - var req struct { - Workspace string `json:"workspace"` - ExperimentID string `json:"experiment_id"` - Direction string `json:"direction"` // "pull" or "push" - SyncType string `json:"sync_type"` // "data", "notebooks", "all" - } - - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - if req.Workspace == "" || req.ExperimentID == "" || req.Direction == "" { - http.Error(w, "Workspace, experiment_id, and direction are required", http.StatusBadRequest) - return - } - - // Validate experiment exists - if !h.expManager.ExperimentExists(req.ExperimentID) { - http.Error(w, "Experiment not found", http.StatusNotFound) - return - } - - // Perform sync operation using service manager - ctx := r.Context() - if err := h.jupyterServiceMgr.SyncWorkspaceWithExperiment( - ctx, req.Workspace, req.ExperimentID, req.Direction); err != nil { - http.Error(w, fmt.Sprintf("Failed to sync workspace: %v", err), http.StatusInternalServerError) - return - } - - // Get updated metadata - metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace) - if err != nil { - http.Error( - w, - fmt.Sprintf("Failed to get workspace metadata: %v", err), - http.StatusInternalServerError, - ) - return - } - - // Create sync result - syncResult := map[string]interface{}{ - "workspace": req.Workspace, - "experiment_id": req.ExperimentID, - "direction": req.Direction, - "sync_type": req.SyncType, - "synced_at": metadata.LastSync, - "status": "completed", - "metadata": metadata, - } - - h.logger.Info("jupyter workspace sync completed", - "workspace", req.Workspace, - "experiment_id", req.ExperimentID, - "direction", req.Direction, - "sync_type", req.SyncType) - - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(syncResult); err != nil { - h.logger.Error("failed to encode response", "error", err) - } -} diff --git a/internal/api/jupyter/handlers.go b/internal/api/jupyter/handlers.go index 73e74c4..a179d2b 100644 --- a/internal/api/jupyter/handlers.go +++ b/internal/api/jupyter/handlers.go @@ -15,9 +15,9 @@ import ( // Handler provides Jupyter-related WebSocket handlers type Handler struct { - logger *logging.Logger - jupyterMgr *jupyter.ServiceManager - authConfig *auth.Config + logger *logging.Logger + jupyterMgr *jupyter.ServiceManager + authConfig *auth.Config } // NewHandler creates a new Jupyter handler @@ -35,11 +35,11 @@ func NewHandler( // Error codes const ( - ErrorCodeInvalidRequest = 0x01 - ErrorCodeAuthenticationFailed = 0x02 - ErrorCodePermissionDenied = 0x03 - ErrorCodeResourceNotFound = 0x04 - ErrorCodeServiceUnavailable = 0x33 + ErrorCodeInvalidRequest = 0x01 + ErrorCodeAuthenticationFailed = 0x02 + ErrorCodePermissionDenied = 0x03 + ErrorCodeResourceNotFound = 0x04 + ErrorCodeServiceUnavailable = 0x33 ) // Permissions @@ -141,18 +141,18 @@ func (h *Handler) HandleListJupyter(conn *websocket.Conn, payload []byte, user * if h.jupyterMgr == nil { return h.sendSuccessPacket(conn, map[string]interface{}{ - "success": true, - "services": []interface{}{}, - "count": 0, + "success": true, + "services": []interface{}{}, + "count": 0, }) } services := h.jupyterMgr.ListServices() return h.sendSuccessPacket(conn, map[string]interface{}{ - "success": true, - "services": services, - "count": len(services), + "success": true, + "services": services, + "count": len(services), }) } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index dede950..bea04ca 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -26,6 +26,10 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { if len(s.config.Security.IPWhitelist) > 0 { handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler) } + // Add OpenAPI validation if available + if s.validationMiddleware != nil { + handler = s.validationMiddleware.ValidateRequest(handler) + } handler.ServeHTTP(w, r) }) } diff --git a/internal/api/middleware/validation.go b/internal/api/middleware/validation.go new file mode 100644 index 0000000..5805e83 --- /dev/null +++ b/internal/api/middleware/validation.go @@ -0,0 +1,87 @@ +// Package middleware provides request/response validation using OpenAPI spec +package middleware + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +// ValidationMiddleware validates HTTP requests against OpenAPI spec +type ValidationMiddleware struct { + router routers.Router +} + +// NewValidationMiddleware creates a new validation middleware from OpenAPI spec +func NewValidationMiddleware(specPath string) (*ValidationMiddleware, error) { + // Load OpenAPI spec + loader := openapi3.NewLoader() + doc, err := loader.LoadFromFile(specPath) + if err != nil { + return nil, err + } + + // Validate the spec itself + if err := doc.Validate(loader.Context); err != nil { + return nil, err + } + + // Create router for path matching + router, err := gorillamux.NewRouter(doc) + if err != nil { + return nil, err + } + + return &ValidationMiddleware{router: router}, nil +} + +// ValidateRequest validates an incoming HTTP request +func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip validation for health endpoints + if r.URL.Path == "/health/ok" || r.URL.Path == "/health" || r.URL.Path == "/metrics" { + next.ServeHTTP(w, r) + return + } + + // Find the route + route, pathParams, err := v.router.FindRoute(r) + if err != nil { + // Route not in spec - allow through (might be unregistered endpoint) + next.ServeHTTP(w, r) + return + } + + // Read and restore body for validation + var bodyBytes []byte + if r.Body != nil && r.Body != http.NoBody { + bodyBytes, _ = io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + // Validate request - body is read from r.Body automatically + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + } + + if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]any{ + "error": "validation failed", + "message": err.Error(), + }) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/api/responses/errors.go b/internal/api/responses/errors.go new file mode 100644 index 0000000..f729c94 --- /dev/null +++ b/internal/api/responses/errors.go @@ -0,0 +1,173 @@ +// Package responses provides structured API response types with security-conscious error handling. +package responses + +import ( + "encoding/json" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// ErrorResponse provides a sanitized error response to clients. +// It includes a trace ID for support lookup while preventing information leakage. +type ErrorResponse struct { + Error string `json:"error"` // Sanitized error message for clients + Code string `json:"code"` // Machine-readable error code + TraceID string `json:"trace_id"` // For support lookup (internal correlation) +} + +// Error codes for machine-readable error identification +const ( + ErrCodeBadRequest = "BAD_REQUEST" + ErrCodeUnauthorized = "UNAUTHORIZED" + ErrCodeForbidden = "FORBIDDEN" + ErrCodeNotFound = "NOT_FOUND" + ErrCodeConflict = "CONFLICT" + ErrCodeRateLimited = "RATE_LIMITED" + ErrCodeInternal = "INTERNAL_ERROR" + ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE" + ErrCodeValidation = "VALIDATION_ERROR" +) + +// HTTP status to error code mapping +var statusToCode = map[int]string{ + http.StatusBadRequest: ErrCodeBadRequest, + http.StatusUnauthorized: ErrCodeUnauthorized, + http.StatusForbidden: ErrCodeForbidden, + http.StatusNotFound: ErrCodeNotFound, + http.StatusConflict: ErrCodeConflict, + http.StatusTooManyRequests: ErrCodeRateLimited, + http.StatusInternalServerError: ErrCodeInternal, + http.StatusServiceUnavailable: ErrCodeServiceUnavailable, + 422: ErrCodeValidation, // Unprocessable Entity +} + +// Patterns to sanitize from error messages (security: prevent information leakage) +var ( + // Remove file paths + pathPattern = regexp.MustCompile(`/[^\s]*`) + // Remove sensitive keywords + sensitiveKeywords = []string{"password", "secret", "token", "key", "credential", "auth"} +) + +// WriteError writes a sanitized error response to the client. +// It extracts the trace ID from the context, logs the full error internally, +// and returns a sanitized message to the client. +func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, logger *logging.Logger) { + traceID := logging.TraceIDFromContext(r.Context()) + if traceID == "" { + traceID = generateTraceID() + } + + // Log the full error internally with all details + if logger != nil { + logger.Error("request failed", + "trace_id", traceID, + "method", r.Method, + "path", r.URL.Path, + "status", status, + "error", err.Error(), + "client_ip", getClientIP(r), + ) + } + + // Build sanitized response + resp := ErrorResponse{ + Error: sanitizeError(err.Error()), + Code: errorCodeFromStatus(status), + TraceID: traceID, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(resp) +} + +// WriteErrorMessage writes a sanitized error response with a custom message. +func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, message string, logger *logging.Logger) { + traceID := logging.TraceIDFromContext(r.Context()) + if traceID == "" { + traceID = generateTraceID() + } + + if logger != nil { + logger.Error("request failed", + "trace_id", traceID, + "method", r.Method, + "path", r.URL.Path, + "status", status, + "error", message, + ) + } + + resp := ErrorResponse{ + Error: sanitizeError(message), + Code: errorCodeFromStatus(status), + TraceID: traceID, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(resp) +} + +// sanitizeError removes potentially sensitive information from error messages. +// It prevents information leakage to clients while preserving useful context. +func sanitizeError(msg string) string { + if msg == "" { + return "An error occurred" + } + + // Remove file paths + msg = pathPattern.ReplaceAllString(msg, "[path]") + + // Remove sensitive keywords and their values + lowerMsg := strings.ToLower(msg) + for _, keyword := range sensitiveKeywords { + if strings.Contains(lowerMsg, keyword) { + return "An error occurred" + } + } + + // Remove internal error details + msg = strings.ReplaceAll(msg, "internal error", "an error occurred") + msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred") + + // Truncate if too long + if len(msg) > 200 { + msg = msg[:200] + "..." + } + + return msg +} + +// errorCodeFromStatus returns the appropriate error code for an HTTP status. +func errorCodeFromStatus(status int) string { + if code, ok := statusToCode[status]; ok { + return code + } + return ErrCodeInternal +} + +// getClientIP extracts the client IP from the request. +func getClientIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + return r.RemoteAddr +} + +// generateTraceID generates a new trace ID when one isn't in context. +func generateTraceID() string { + return fmt.Sprintf("%d", time.Now().UnixNano()) +} diff --git a/internal/api/routes.go b/internal/api/routes.go index d8df8ed..d6dc88d 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -2,12 +2,14 @@ package api import ( "net/http" + "os" "github.com/jfraeys/fetch_ml/internal/api/datasets" "github.com/jfraeys/fetch_ml/internal/api/jobs" "github.com/jfraeys/fetch_ml/internal/api/jupyter" "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/prommetrics" + "github.com/labstack/echo/v4" ) // registerRoutes sets up all HTTP routes and handlers @@ -34,9 +36,6 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { // Register WebSocket endpoint s.registerWebSocketRoutes(mux) - // Register HTTP API handlers - s.handlers.RegisterHandlers(mux) - // Register new REST API endpoints for TUI jobsHandler := jobs.NewHandler( s.expManager, @@ -52,13 +51,73 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { // Team jobs endpoint: GET /api/jobs?all_users=true mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP) + + // Register OpenAPI-generated routes with Echo router + s.registerOpenAPIRoutes(mux, jobsHandler) + + // Register API documentation endpoint + s.registerDocsRoutes(mux) +} + +// registerDocsRoutes sets up API documentation serving +func (s *Server) registerDocsRoutes(mux *http.ServeMux) { + // Check if docs directory exists + if _, err := os.Stat("docs/api"); err == nil { + docsFS := http.FileServer(http.Dir("docs/api")) + mux.Handle("/docs/", http.StripPrefix("/docs/", docsFS)) + s.logger.Info("API documentation endpoint registered", "path", "/docs/") + } else { + s.logger.Debug("API documentation not available", "path", "docs/api") + } +} + +// registerOpenAPIRoutes sets up Echo router with generated OpenAPI handlers +func (s *Server) registerOpenAPIRoutes(mux *http.ServeMux, jobsHandler *jobs.Handler) { + // Create Echo instance for OpenAPI-generated routes + e := echo.New() + + // Create jupyter handler + jupyterHandler := jupyter.NewHandler( + s.logger, + s.jupyterServiceMgr, + s.config.BuildAuthConfig(), + ) + + // Create datasets handler + datasetsHandler := datasets.NewHandler( + s.logger, + s.db, + s.config.DataDir, + ) + + // Create adapter implementing ServerInterface + handlerAdapter := NewHandlerAdapter( + jobsHandler, + jupyterHandler, + datasetsHandler, + ) + + // Register generated OpenAPI routes + RegisterHandlers(e, handlerAdapter) + + // Wrap Echo router to work with existing middleware chain + echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + e.ServeHTTP(w, r) + }) + + // Register Echo router at /v1/ prefix (and other generated paths) + // These paths take precedence over legacy routes + mux.Handle("/health", echoHandler) + mux.Handle("/v1/", echoHandler) + mux.Handle("/ws", echoHandler) + + s.logger.Info("OpenAPI-generated routes registered with Echo router") } // registerHealthRoutes sets up health check endpoints func (s *Server) registerHealthRoutes(mux *http.ServeMux) { healthHandler := NewHealthHandler(s) healthHandler.RegisterRoutes(mux) - mux.HandleFunc("/health/ok", s.handlers.handleHealth) s.logger.Info("health check endpoints registered") } diff --git a/internal/api/server.go b/internal/api/server.go index 57889db..e95f90c 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -8,6 +8,7 @@ import ( "syscall" "time" + apimiddleware "github.com/jfraeys/fetch_ml/internal/api/middleware" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" @@ -20,18 +21,18 @@ import ( // Server represents the API server type Server struct { - config *ServerConfig - httpServer *http.Server - logger *logging.Logger - expManager *experiment.Manager - taskQueue queue.Backend - db *storage.DB - handlers *Handlers - sec *middleware.SecurityMiddleware - cleanupFuncs []func() - jupyterServiceMgr *jupyter.ServiceManager - auditLogger *audit.Logger - promMetrics *prommetrics.Metrics // Prometheus metrics + config *ServerConfig + httpServer *http.Server + logger *logging.Logger + expManager *experiment.Manager + taskQueue queue.Backend + db *storage.DB + sec *middleware.SecurityMiddleware + cleanupFuncs []func() + jupyterServiceMgr *jupyter.ServiceManager + auditLogger *audit.Logger + promMetrics *prommetrics.Metrics // Prometheus metrics + validationMiddleware *apimiddleware.ValidationMiddleware // OpenAPI validation } // NewServer creates a new API server diff --git a/internal/api/server_gen.go b/internal/api/server_gen.go new file mode 100644 index 0000000..a559bcc --- /dev/null +++ b/internal/api/server_gen.go @@ -0,0 +1,642 @@ +// Package api provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT. +package api + +import ( + "bytes" + "compress/gzip" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "time" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/labstack/echo/v4" + "github.com/oapi-codegen/runtime" +) + +const ( + ApiKeyAuthScopes = "ApiKeyAuth.Scopes" +) + +// Defines values for ErrorResponseCode. +const ( + BADREQUEST ErrorResponseCode = "BAD_REQUEST" + CONFLICT ErrorResponseCode = "CONFLICT" + FORBIDDEN ErrorResponseCode = "FORBIDDEN" + INTERNALERROR ErrorResponseCode = "INTERNAL_ERROR" + NOTFOUND ErrorResponseCode = "NOT_FOUND" + RATELIMITED ErrorResponseCode = "RATE_LIMITED" + SERVICEUNAVAILABLE ErrorResponseCode = "SERVICE_UNAVAILABLE" + UNAUTHORIZED ErrorResponseCode = "UNAUTHORIZED" + VALIDATIONERROR ErrorResponseCode = "VALIDATION_ERROR" +) + +// Defines values for ExperimentStatus. +const ( + Active ExperimentStatus = "active" + Archived ExperimentStatus = "archived" + Deleted ExperimentStatus = "deleted" +) + +// Defines values for HealthResponseStatus. +const ( + Degraded HealthResponseStatus = "degraded" + Healthy HealthResponseStatus = "healthy" + Unhealthy HealthResponseStatus = "unhealthy" +) + +// Defines values for JupyterServiceStatus. +const ( + JupyterServiceStatusError JupyterServiceStatus = "error" + JupyterServiceStatusRunning JupyterServiceStatus = "running" + JupyterServiceStatusStarting JupyterServiceStatus = "starting" + JupyterServiceStatusStopped JupyterServiceStatus = "stopped" + JupyterServiceStatusStopping JupyterServiceStatus = "stopping" +) + +// Defines values for TaskStatus. +const ( + TaskStatusCollecting TaskStatus = "collecting" + TaskStatusCompleted TaskStatus = "completed" + TaskStatusFailed TaskStatus = "failed" + TaskStatusPreparing TaskStatus = "preparing" + TaskStatusQueued TaskStatus = "queued" + TaskStatusRunning TaskStatus = "running" +) + +// Defines values for GetV1TasksParamsStatus. +const ( + Completed GetV1TasksParamsStatus = "completed" + Failed GetV1TasksParamsStatus = "failed" + Queued GetV1TasksParamsStatus = "queued" + Running GetV1TasksParamsStatus = "running" +) + +// CreateExperimentRequest defines model for CreateExperimentRequest. +type CreateExperimentRequest struct { + Description *string `json:"description,omitempty"` + Name string `json:"name"` +} + +// CreateTaskRequest defines model for CreateTaskRequest. +type CreateTaskRequest struct { + // Args Command-line arguments for the training script + Args *string `json:"args,omitempty"` + + // Cpu CPU cores requested + Cpu *int `json:"cpu,omitempty"` + DatasetSpecs *[]DatasetSpec `json:"dataset_specs,omitempty"` + Datasets *[]string `json:"datasets,omitempty"` + + // Gpu GPUs requested + Gpu *int `json:"gpu,omitempty"` + + // JobName Unique identifier for the job + JobName string `json:"job_name"` + + // MemoryGb Memory (GB) requested + MemoryGb *int `json:"memory_gb,omitempty"` + Metadata *map[string]string `json:"metadata,omitempty"` + Priority *int `json:"priority,omitempty"` + + // SnapshotId Reference to experiment snapshot + SnapshotId *string `json:"snapshot_id,omitempty"` +} + +// DatasetSpec defines model for DatasetSpec. +type DatasetSpec struct { + MountPath *string `json:"mount_path,omitempty"` + Name *string `json:"name,omitempty"` + Sha256 *string `json:"sha256,omitempty"` + Source *string `json:"source,omitempty"` +} + +// ErrorResponse defines model for ErrorResponse. +type ErrorResponse struct { + Code ErrorResponseCode `json:"code"` + + // Error Sanitized error message + Error string `json:"error"` + + // TraceId Support correlation ID + TraceId string `json:"trace_id"` +} + +// ErrorResponseCode defines model for ErrorResponse.Code. +type ErrorResponseCode string + +// Experiment defines model for Experiment. +type Experiment struct { + CommitId *string `json:"commit_id,omitempty"` + CreatedAt *time.Time `json:"created_at,omitempty"` + Id *string `json:"id,omitempty"` + Name *string `json:"name,omitempty"` + Status *ExperimentStatus `json:"status,omitempty"` +} + +// ExperimentStatus defines model for Experiment.Status. +type ExperimentStatus string + +// HealthResponse defines model for HealthResponse. +type HealthResponse struct { + Status *HealthResponseStatus `json:"status,omitempty"` + Timestamp *time.Time `json:"timestamp,omitempty"` + Version *string `json:"version,omitempty"` +} + +// HealthResponseStatus defines model for HealthResponse.Status. +type HealthResponseStatus string + +// JupyterService defines model for JupyterService. +type JupyterService struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + Id *string `json:"id,omitempty"` + Name *string `json:"name,omitempty"` + Status *JupyterServiceStatus `json:"status,omitempty"` + Token *string `json:"token,omitempty"` + Url *string `json:"url,omitempty"` +} + +// JupyterServiceStatus defines model for JupyterService.Status. +type JupyterServiceStatus string + +// QueueStats defines model for QueueStats. +type QueueStats struct { + // Completed Tasks completed today + Completed *int `json:"completed,omitempty"` + + // Failed Tasks failed today + Failed *int `json:"failed,omitempty"` + + // Queued Tasks waiting to run + Queued *int `json:"queued,omitempty"` + + // Running Tasks currently executing + Running *int `json:"running,omitempty"` + + // Workers Active workers + Workers *int `json:"workers,omitempty"` +} + +// StartJupyterRequest defines model for StartJupyterRequest. +type StartJupyterRequest struct { + Image *string `json:"image,omitempty"` + Name string `json:"name"` + Workspace *string `json:"workspace,omitempty"` +} + +// Task defines model for Task. +type Task struct { + Cpu *int `json:"cpu,omitempty"` + CreatedAt *time.Time `json:"created_at,omitempty"` + Datasets *[]string `json:"datasets,omitempty"` + EndedAt *time.Time `json:"ended_at,omitempty"` + Error *string `json:"error,omitempty"` + Gpu *int `json:"gpu,omitempty"` + + // Id Unique task identifier + Id *string `json:"id,omitempty"` + JobName *string `json:"job_name,omitempty"` + MaxRetries *int `json:"max_retries,omitempty"` + MemoryGb *int `json:"memory_gb,omitempty"` + Output *string `json:"output,omitempty"` + Priority *int `json:"priority,omitempty"` + RetryCount *int `json:"retry_count,omitempty"` + SnapshotId *string `json:"snapshot_id,omitempty"` + StartedAt *time.Time `json:"started_at,omitempty"` + Status *TaskStatus `json:"status,omitempty"` + UserId *string `json:"user_id,omitempty"` + WorkerId *string `json:"worker_id,omitempty"` +} + +// TaskStatus defines model for Task.Status. +type TaskStatus string + +// TaskList defines model for TaskList. +type TaskList struct { + Limit *int `json:"limit,omitempty"` + Offset *int `json:"offset,omitempty"` + Tasks *[]Task `json:"tasks,omitempty"` + Total *int `json:"total,omitempty"` +} + +// BadRequest defines model for BadRequest. +type BadRequest = ErrorResponse + +// NotFound defines model for NotFound. +type NotFound = ErrorResponse + +// RateLimited defines model for RateLimited. +type RateLimited = ErrorResponse + +// Unauthorized defines model for Unauthorized. +type Unauthorized = ErrorResponse + +// ValidationError defines model for ValidationError. +type ValidationError = ErrorResponse + +// GetV1TasksParams defines parameters for GetV1Tasks. +type GetV1TasksParams struct { + Status *GetV1TasksParamsStatus `form:"status,omitempty" json:"status,omitempty"` + Limit *int `form:"limit,omitempty" json:"limit,omitempty"` + Offset *int `form:"offset,omitempty" json:"offset,omitempty"` +} + +// GetV1TasksParamsStatus defines parameters for GetV1Tasks. +type GetV1TasksParamsStatus string + +// PostV1ExperimentsJSONRequestBody defines body for PostV1Experiments for application/json ContentType. +type PostV1ExperimentsJSONRequestBody = CreateExperimentRequest + +// PostV1JupyterServicesJSONRequestBody defines body for PostV1JupyterServices for application/json ContentType. +type PostV1JupyterServicesJSONRequestBody = StartJupyterRequest + +// PostV1TasksJSONRequestBody defines body for PostV1Tasks for application/json ContentType. +type PostV1TasksJSONRequestBody = CreateTaskRequest + +// ServerInterface represents all server handlers. +type ServerInterface interface { + // Health check + // (GET /health) + GetHealth(ctx echo.Context) error + // List experiments + // (GET /v1/experiments) + GetV1Experiments(ctx echo.Context) error + // Create experiment + // (POST /v1/experiments) + PostV1Experiments(ctx echo.Context) error + // List Jupyter services + // (GET /v1/jupyter/services) + GetV1JupyterServices(ctx echo.Context) error + // Start Jupyter service + // (POST /v1/jupyter/services) + PostV1JupyterServices(ctx echo.Context) error + // Stop Jupyter service + // (DELETE /v1/jupyter/services/{serviceId}) + DeleteV1JupyterServicesServiceId(ctx echo.Context, serviceId string) error + // Queue status + // (GET /v1/queue) + GetV1Queue(ctx echo.Context) error + // List tasks + // (GET /v1/tasks) + GetV1Tasks(ctx echo.Context, params GetV1TasksParams) error + // Create task + // (POST /v1/tasks) + PostV1Tasks(ctx echo.Context) error + // Cancel/delete task + // (DELETE /v1/tasks/{taskId}) + DeleteV1TasksTaskId(ctx echo.Context, taskId string) error + // Get task details + // (GET /v1/tasks/{taskId}) + GetV1TasksTaskId(ctx echo.Context, taskId string) error + // WebSocket connection + // (GET /ws) + GetWs(ctx echo.Context) error +} + +// ServerInterfaceWrapper converts echo contexts to parameters. +type ServerInterfaceWrapper struct { + Handler ServerInterface +} + +// GetHealth converts echo context to params. +func (w *ServerInterfaceWrapper) GetHealth(ctx echo.Context) error { + var err error + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetHealth(ctx) + return err +} + +// GetV1Experiments converts echo context to params. +func (w *ServerInterfaceWrapper) GetV1Experiments(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetV1Experiments(ctx) + return err +} + +// PostV1Experiments converts echo context to params. +func (w *ServerInterfaceWrapper) PostV1Experiments(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.PostV1Experiments(ctx) + return err +} + +// GetV1JupyterServices converts echo context to params. +func (w *ServerInterfaceWrapper) GetV1JupyterServices(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetV1JupyterServices(ctx) + return err +} + +// PostV1JupyterServices converts echo context to params. +func (w *ServerInterfaceWrapper) PostV1JupyterServices(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.PostV1JupyterServices(ctx) + return err +} + +// DeleteV1JupyterServicesServiceId converts echo context to params. +func (w *ServerInterfaceWrapper) DeleteV1JupyterServicesServiceId(ctx echo.Context) error { + var err error + // ------------- Path parameter "serviceId" ------------- + var serviceId string + + err = runtime.BindStyledParameterWithOptions("simple", "serviceId", ctx.Param("serviceId"), &serviceId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter serviceId: %s", err)) + } + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.DeleteV1JupyterServicesServiceId(ctx, serviceId) + return err +} + +// GetV1Queue converts echo context to params. +func (w *ServerInterfaceWrapper) GetV1Queue(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetV1Queue(ctx) + return err +} + +// GetV1Tasks converts echo context to params. +func (w *ServerInterfaceWrapper) GetV1Tasks(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Parameter object where we will unmarshal all parameters from the context + var params GetV1TasksParams + // ------------- Optional query parameter "status" ------------- + + err = runtime.BindQueryParameter("form", true, false, "status", ctx.QueryParams(), ¶ms.Status) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter status: %s", err)) + } + + // ------------- Optional query parameter "limit" ------------- + + err = runtime.BindQueryParameter("form", true, false, "limit", ctx.QueryParams(), ¶ms.Limit) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter limit: %s", err)) + } + + // ------------- Optional query parameter "offset" ------------- + + err = runtime.BindQueryParameter("form", true, false, "offset", ctx.QueryParams(), ¶ms.Offset) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter offset: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetV1Tasks(ctx, params) + return err +} + +// PostV1Tasks converts echo context to params. +func (w *ServerInterfaceWrapper) PostV1Tasks(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.PostV1Tasks(ctx) + return err +} + +// DeleteV1TasksTaskId converts echo context to params. +func (w *ServerInterfaceWrapper) DeleteV1TasksTaskId(ctx echo.Context) error { + var err error + // ------------- Path parameter "taskId" ------------- + var taskId string + + err = runtime.BindStyledParameterWithOptions("simple", "taskId", ctx.Param("taskId"), &taskId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter taskId: %s", err)) + } + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.DeleteV1TasksTaskId(ctx, taskId) + return err +} + +// GetV1TasksTaskId converts echo context to params. +func (w *ServerInterfaceWrapper) GetV1TasksTaskId(ctx echo.Context) error { + var err error + // ------------- Path parameter "taskId" ------------- + var taskId string + + err = runtime.BindStyledParameterWithOptions("simple", "taskId", ctx.Param("taskId"), &taskId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter taskId: %s", err)) + } + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetV1TasksTaskId(ctx, taskId) + return err +} + +// GetWs converts echo context to params. +func (w *ServerInterfaceWrapper) GetWs(ctx echo.Context) error { + var err error + + ctx.Set(ApiKeyAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetWs(ctx) + return err +} + +// This is a simple interface which specifies echo.Route addition functions which +// are present on both echo.Echo and echo.Group, since we want to allow using +// either of them for path registration +type EchoRouter interface { + CONNECT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + HEAD(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + OPTIONS(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + PATCH(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + TRACE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route +} + +// RegisterHandlers adds each server route to the EchoRouter. +func RegisterHandlers(router EchoRouter, si ServerInterface) { + RegisterHandlersWithBaseURL(router, si, "") +} + +// Registers handlers, and prepends BaseURL to the paths, so that the paths +// can be served under a prefix. +func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) { + + wrapper := ServerInterfaceWrapper{ + Handler: si, + } + + router.GET(baseURL+"/health", wrapper.GetHealth) + router.GET(baseURL+"/v1/experiments", wrapper.GetV1Experiments) + router.POST(baseURL+"/v1/experiments", wrapper.PostV1Experiments) + router.GET(baseURL+"/v1/jupyter/services", wrapper.GetV1JupyterServices) + router.POST(baseURL+"/v1/jupyter/services", wrapper.PostV1JupyterServices) + router.DELETE(baseURL+"/v1/jupyter/services/:serviceId", wrapper.DeleteV1JupyterServicesServiceId) + router.GET(baseURL+"/v1/queue", wrapper.GetV1Queue) + router.GET(baseURL+"/v1/tasks", wrapper.GetV1Tasks) + router.POST(baseURL+"/v1/tasks", wrapper.PostV1Tasks) + router.DELETE(baseURL+"/v1/tasks/:taskId", wrapper.DeleteV1TasksTaskId) + router.GET(baseURL+"/v1/tasks/:taskId", wrapper.GetV1TasksTaskId) + router.GET(baseURL+"/ws", wrapper.GetWs) + +} + +// Base64 encoded, gzipped, json marshaled Swagger object +var swaggerSpec = []string{ + + "H4sIAAAAAAAC/8RaW1fbuhL+K1ra+6FdxyFAL2c3bwHSNqfh0gTaszbhBGFPEoEtqZIMZLPy38+SZCe+", + "yEDphSdIJI1nPn2a+TTOHQ55IjgDphXu3GEJSnCmwH7YIdEQvqWgtPkUcqaB2X+JEDENiaactS8VZ+Y7", + "uCWJiMHNjAB38E53bzLsfT7pjY5xgEFKLnEH99k1iWmEpLOMplwmROMAa0lCmNAIdzDZutgOX0WvW/Bm", + "+rb177/ebbbIRRi1YLq1/er1m7fmG7wMsArnkBDzyD8lTHEH/9Feh9N2o6rdM08eZoHh5XIZ4AhUKKkw", + "AdRdMpYPuH7PUxY9KfCDw+PJ+8OTg71C2ENQPJUhIMZNzMb0c4bscWcZ4CHRMKAJ1fC0wIfd495k0N/v", + "H/dKsRMNKDZ2EdyGABE8b/DHnKOEsEW+4QoHeA4kAmlpPwQtF63uVIM0H8trRxByFimUMk1jJNeRSVBg", + "La2d1AthUKFMwwyk8WQZ4BNGUj3nkv7zRJBPDronxx8Ph/2/SyDnJOYSJVQpymaoe9RHV7B4Vqy7qZ4D", + "01lYFnEqwbLti/HXft1zMTwBjC/dQX+ve9w/PJj0hsPDYQGQtXk0JTR+Zs7VvVmujFvW7UogGnq3AiRN", + "gOlC5hWSC5CauqxcsrsimdKSsplxmJHEIpSQ2wGwmZ7jztb2X0F14jLAq83onLpVZ6tZ/OISQpsJnV/H", + "RF01ekTkrO4Z3uVJQljUiikDROQsNVEpk++RngPSklBmWOrW4KAeSShSj9mjExRyCSo/vW5jK0ctwBHR", + "RIGeKAGh9Y5qSNRD+7jnVo0EhMZIZpZISRYFo2V7Nb+rq2a+OD4cnTwUwiW/mOS7WV58wui3FBCNzNGa", + "UpArWC/5BQ6Ke//2dYAF0RqkWfi/U9L6p9v6e7P1btI6+9efPtgTSLhcTGYX9efu2yH04sPOywd8T0AT", + "g5blRxRRY4DERyXeNAG3Jp+QlEuqF86TKUljjTtvbHw0SRPc2doMcEJZ9sHniGJEqDnX9tTf1crgFCSw", + "EJDmCFZnD+WL8EPHZrVFvqNTJFPt0CQ8ZXoiiNmje05xbUDNyfabt/4hW9I9Qz5oyxmr5p7Lr3cYmEH2", + "tCLmKjXo/eFwp7+31zvAQUn87B4evB/0d82KijToHxz3hgfdwSprj3rDL/3d3uTkoPul2x90dwY9HNTT", + "+5mHr5BXj0qlJoxqU2WRnYASUIrMwMf4dVmoGUmF4FKbjCMhdvm7v/cgLZxLgUOxYN/HknXK9+1CktCc", + "uvX8aFNzNCF2ZSakOyZJQUvTxBtqg6lmummiU1VkAgk1vTa2iQzn9Nqe/whiMJngLHgM9T4CifW8mXv1", + "Z87tioV90kwSJyBTln/tY4UBQGmSiMdjcw1S+auqL4r/pGKhQY5AXtPQd4KeZXeUJlKb0QDLlDH3n9Jc", + "iMK/Fj7HUS90/Ar82iKVcSmYVFL8qC3/nEIKI01c6ayRXDj61I6fkR0KrSYgzSOy8NabTFQ1WHCj9yz/", + "ZvxrXH5DqMHUFAmZMq+BHOymEFIpgel4geAWwjTbobqZGy6vsmtIRUfbU4fy8fpaH+ojQ4aMp43qjSYm", + "KxYLLL50S9piobkM552YaLM0+A5mGkeVIE3l6FHS00DnoYtTU3XsnnLgnibogEXf+ZxVjaqNzJrC8ZWj", + "TPdpoq4K4s/3wKJ2/DE1SG4nErQsa7aS0ivIxfowT7VItTf0n6DujGeLSWi0lP/xFfnny6Dye0lTz7pZ", + "+jAhgSCymn9DHscQ6vxDnu5WWcuXhFMFsslplwX8o8uGgzSgvsNvexcNGzedKmgYMwR8/KXKnmPPMdJc", + "k9jbKKnFYFCHMDVsGRmrzvuuoJ9g0U2dhq4kTNf8sBcjUmpC4ABTM8O1fHCexPB/W92jfusTFCoEsQ9w", + "93nKpjzvUZDQApMtfA86nO8PUCYWcb0JctS3fiSEkZkpJPuD4m3DookIi1CWq5FyokJtjNmY/fEHGmWx", + "j1k3jhGwSHBqLtMv4DYEoZETQiicQ3ilXuZdlrwBVIkfXVNi7opjdr4K+Rw5NDbQullnHKUKAZtyGUKE", + "BMjcYu6XvUSgj4RFMWWzsWvkmDt+HPMbRFDImaJKmyDd0UI3VM9RQsI5ZdCSQCJyEQMyOtkhYKUy6u+p", + "zpidn59fKs7G7G7MEBq7JDrGHTRuEvdjHLipxqCbae8Nk93Dvd5qMJfjbkKa0qg15bKl3PaN8Zgt7cPH", + "ttRTHZtN3h+gr/bYGQxwQSzirY3NDdsq4gIYERR38KuNzY1X2KbauWVq2+2Q+XfmDlX1IqpTyZTdeJD5", + "frpEs4EOeHUL8/q5ge1jpf22H+EO/gDa6WuTgYrd/O3NzUe02B7X7qooeE+/a+QCoQrlIr14hnHn9CzA", + "Kk0SIhe4k90IHH/txPb1Vnt9QlQjbiarIWLORGGyB5IvW73ShB9C5lFJr3Ctq6W+Olw2Dj4thWEBW0Fk", + "Z5SGAyy48mDiOnaIIAY3hRU1WI648uBiVeIOjxY/jSxNnc1lWQhqmcKytjNbP82N4obUN2A9ijIdWcE/", + "QxWKVhxPc7WcJ+0CWT08LF8bfw8XK1fV7+BjtSD5SFmfs2amj3E+BH4+63x3n9/MuCrsdZgr0KFMjVZQ", + "tqFUYW6kX/su+68fLV1yMGqzvhl79vvadozyxbZ8SZKAthfS0zunmmzHcqWZVGF2GVjPi7CVRD2rgf7a", + "96Ytx8R1LKqYcNEIiZXjjTXjA+j8Po7sTFtnqdI0bKgdtnvxK+tpoT3iYcnnqpNlKNbDqVpBsFLo95dN", + "Jz2tKOPC9ejRlMYasjuMBwzb0Gggx7cU5KLADudTkQq1G1PxlvSIi9Ey8D/KXWSKT1rfKDdLV8rNTV/z", + "xG81uwN5zfrMnP1Cjqwucfeka7frywC/dg/22Vs52C78vsMu2Xp4SenttVm0/e7hRcWfFXjqx8ppv5wZ", + "pRcJ1ZmcqV+cGmRNztJfJ2iKr0R/c2FxF2rPDxuIusrlC1JpGIJS0zSOF7+XEtsPL6q++/9xKmX6TDto", + "CkmwfWf+PLIYWtoc2/mPqn86n/qTi5/bScJCiOMMVjftfnhWP1uqYGMNtV3wGUTBPRr1WVDY/D3HIwJN", + "aKx+FFIjInTJnuHcTXPJ/QoXIx5egV61b2xLSAKJbaPRWUtFRPS677Pv2hroeCFAjVkLnZtZEzfrvINs", + "RK7KonBO2Kw4K6+n+bwpZVTNIbIzBGWz8w76BCBaJKbXgF64mCOnBs4FZ7Pzl7YFUmPI19q1ZctliqaQ", + "Q84YhLZzAUqTi9g6Um0JlBt6p2fLUo/AZ83t8kMmbCvCkbdSNnlIYhTBNcRcuBf/di7O3nThudai027H", + "Zt6cK915ZwI1aqFs6EjyKHXxeSyoTrtNBN2Ygg7nSbyR/YxpI+QJXp4t/x8AAP//HXQa3YQpAAA=", +} + +// GetSwagger returns the content of the embedded swagger specification file +// or error if failed to decode +func decodeSpec() ([]byte, error) { + zipped, err := base64.StdEncoding.DecodeString(strings.Join(swaggerSpec, "")) + if err != nil { + return nil, fmt.Errorf("error base64 decoding spec: %w", err) + } + zr, err := gzip.NewReader(bytes.NewReader(zipped)) + if err != nil { + return nil, fmt.Errorf("error decompressing spec: %w", err) + } + var buf bytes.Buffer + _, err = buf.ReadFrom(zr) + if err != nil { + return nil, fmt.Errorf("error decompressing spec: %w", err) + } + + return buf.Bytes(), nil +} + +var rawSpec = decodeSpecCached() + +// a naive cached of a decoded swagger spec +func decodeSpecCached() func() ([]byte, error) { + data, err := decodeSpec() + return func() ([]byte, error) { + return data, err + } +} + +// Constructs a synthetic filesystem for resolving external references when loading openapi specifications. +func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) { + res := make(map[string]func() ([]byte, error)) + if len(pathToFile) > 0 { + res[pathToFile] = rawSpec + } + + return res +} + +// GetSwagger returns the Swagger specification corresponding to the generated code +// in this file. The external references of Swagger specification are resolved. +// The logic of resolving external references is tightly connected to "import-mapping" feature. +// Externally referenced files must be embedded in the corresponding golang packages. +// Urls can be supported but this task was out of the scope. +func GetSwagger() (swagger *openapi3.T, err error) { + resolvePath := PathToRawSpec("") + + loader := openapi3.NewLoader() + loader.IsExternalRefsAllowed = true + loader.ReadFromURIFunc = func(loader *openapi3.Loader, url *url.URL) ([]byte, error) { + pathToFile := url.String() + pathToFile = path.Clean(pathToFile) + getSpec, ok := resolvePath[pathToFile] + if !ok { + err1 := fmt.Errorf("path not found: %s", pathToFile) + return nil, err1 + } + return getSpec() + } + var specData []byte + specData, err = rawSpec() + if err != nil { + return + } + swagger, err = loader.LoadFromData(specData) + if err != nil { + return + } + return +} diff --git a/internal/auth/crypto.go b/internal/auth/crypto.go index 991e421..1762ca3 100644 --- a/internal/auth/crypto.go +++ b/internal/auth/crypto.go @@ -23,11 +23,11 @@ const ( // Argon2id parameters (OWASP recommended minimum) const ( - argon2idTime = 1 // Iterations - argon2idMemory = 64 * 1024 // 64 MB - argon2idThreads = 4 // Parallelism - argon2idKeyLen = 32 // 256-bit output - argon2idSaltLen = 16 // 128-bit salt + argon2idTime = 1 // Iterations + argon2idMemory = 64 * 1024 // 64 MB + argon2idThreads = 4 // Parallelism + argon2idKeyLen = 32 // 256-bit output + argon2idSaltLen = 16 // 128-bit salt ) // HashedKey represents a hashed API key with its algorithm and salt diff --git a/internal/config/paths.go b/internal/config/paths.go index 8c4c324..91adeec 100644 --- a/internal/config/paths.go +++ b/internal/config/paths.go @@ -21,35 +21,39 @@ func NewPathRegistry(root string) *PathRegistry { } // Binary paths -func (p *PathRegistry) BinDir() string { return filepath.Join(p.RootDir, "bin") } -func (p *PathRegistry) APIServerBinary() string { return filepath.Join(p.BinDir(), "api-server") } -func (p *PathRegistry) WorkerBinary() string { return filepath.Join(p.BinDir(), "worker") } -func (p *PathRegistry) TUIBinary() string { return filepath.Join(p.BinDir(), "tui") } -func (p *PathRegistry) DataManagerBinary() string { return filepath.Join(p.BinDir(), "data_manager") } +func (p *PathRegistry) BinDir() string { return filepath.Join(p.RootDir, "bin") } +func (p *PathRegistry) APIServerBinary() string { return filepath.Join(p.BinDir(), "api-server") } +func (p *PathRegistry) WorkerBinary() string { return filepath.Join(p.BinDir(), "worker") } +func (p *PathRegistry) TUIBinary() string { return filepath.Join(p.BinDir(), "tui") } +func (p *PathRegistry) DataManagerBinary() string { return filepath.Join(p.BinDir(), "data_manager") } // Data paths -func (p *PathRegistry) DataDir() string { return filepath.Join(p.RootDir, "data") } -func (p *PathRegistry) ActiveDataDir() string { return filepath.Join(p.DataDir(), "active") } -func (p *PathRegistry) JupyterStateDir() string { return filepath.Join(p.DataDir(), "active", "jupyter") } -func (p *PathRegistry) ExperimentsDir() string { return filepath.Join(p.DataDir(), "experiments") } -func (p *PathRegistry) ProdSmokeDir() string { return filepath.Join(p.DataDir(), "prod-smoke") } +func (p *PathRegistry) DataDir() string { return filepath.Join(p.RootDir, "data") } +func (p *PathRegistry) ActiveDataDir() string { return filepath.Join(p.DataDir(), "active") } +func (p *PathRegistry) JupyterStateDir() string { + return filepath.Join(p.DataDir(), "active", "jupyter") +} +func (p *PathRegistry) ExperimentsDir() string { return filepath.Join(p.DataDir(), "experiments") } +func (p *PathRegistry) ProdSmokeDir() string { return filepath.Join(p.DataDir(), "prod-smoke") } // Database paths -func (p *PathRegistry) DBDir() string { return filepath.Join(p.RootDir, "db") } -func (p *PathRegistry) SQLitePath() string { return filepath.Join(p.DBDir(), "fetch_ml.db") } +func (p *PathRegistry) DBDir() string { return filepath.Join(p.RootDir, "db") } +func (p *PathRegistry) SQLitePath() string { return filepath.Join(p.DBDir(), "fetch_ml.db") } // Log paths -func (p *PathRegistry) LogDir() string { return filepath.Join(p.RootDir, "logs") } -func (p *PathRegistry) AuditLogPath() string { return filepath.Join(p.LogDir(), "fetchml-audit.log") } +func (p *PathRegistry) LogDir() string { return filepath.Join(p.RootDir, "logs") } +func (p *PathRegistry) AuditLogPath() string { return filepath.Join(p.LogDir(), "fetchml-audit.log") } // Config paths -func (p *PathRegistry) ConfigDir() string { return filepath.Join(p.RootDir, "configs") } -func (p *PathRegistry) APIServerConfig() string { return filepath.Join(p.ConfigDir(), "api", "dev.yaml") } -func (p *PathRegistry) WorkerConfigDir() string { return filepath.Join(p.ConfigDir(), "workers") } +func (p *PathRegistry) ConfigDir() string { return filepath.Join(p.RootDir, "configs") } +func (p *PathRegistry) APIServerConfig() string { + return filepath.Join(p.ConfigDir(), "api", "dev.yaml") +} +func (p *PathRegistry) WorkerConfigDir() string { return filepath.Join(p.ConfigDir(), "workers") } // Test paths -func (p *PathRegistry) TestResultsDir() string { return filepath.Join(p.RootDir, "test_results") } -func (p *PathRegistry) TempDir() string { return filepath.Join(p.RootDir, "tmp") } +func (p *PathRegistry) TestResultsDir() string { return filepath.Join(p.RootDir, "test_results") } +func (p *PathRegistry) TempDir() string { return filepath.Join(p.RootDir, "tmp") } // State file paths (for service persistence) func (p *PathRegistry) JupyterServicesFile() string { @@ -83,13 +87,13 @@ func detectRepoRoot() string { if err != nil { return "." } - + // Walk up directory tree looking for go.mod for { if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { return dir } - + parent := filepath.Dir(dir) if parent == dir { // Reached root @@ -97,7 +101,7 @@ func detectRepoRoot() string { } dir = parent } - + return "." } diff --git a/internal/config/secrets.go b/internal/config/secrets.go index 23ed7ce..7755715 100644 --- a/internal/config/secrets.go +++ b/internal/config/secrets.go @@ -23,7 +23,9 @@ func NewEnvSecretsManager() *EnvSecretsManager { return &EnvSecretsManager{} } func (e *EnvSecretsManager) Get(ctx context.Context, key string) (string, error) { value := os.Getenv(key) - if value == "" { return "", fmt.Errorf("secret %s not found", key) } + if value == "" { + return "", fmt.Errorf("secret %s not found", key) + } return value, nil } @@ -47,6 +49,8 @@ func (e *EnvSecretsManager) List(ctx context.Context, prefix string) ([]string, // RedactSecret masks a secret for safe logging func RedactSecret(secret string) string { - if len(secret) <= 8 { return "***" } + if len(secret) <= 8 { + return "***" + } return secret[:4] + "..." + secret[len(secret)-4:] } diff --git a/internal/domain/status.go b/internal/domain/status.go index 58f83de..1e0a00f 100644 --- a/internal/domain/status.go +++ b/internal/domain/status.go @@ -4,11 +4,11 @@ package domain type JobStatus string const ( - StatusPending JobStatus = "pending" - StatusQueued JobStatus = "queued" - StatusRunning JobStatus = "running" - StatusCompleted JobStatus = "completed" - StatusFailed JobStatus = "failed" + StatusPending JobStatus = "pending" + StatusQueued JobStatus = "queued" + StatusRunning JobStatus = "running" + StatusCompleted JobStatus = "completed" + StatusFailed JobStatus = "failed" ) // String returns the string representation of the status diff --git a/internal/logging/logging.go b/internal/logging/logging.go index b46bb5c..f5abae2 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -164,6 +164,15 @@ func (l *Logger) Job(ctx context.Context, job string, task string) *Logger { return l.WithContext(ctx, "job_name", job, "task_id", task) } +// TraceIDFromContext extracts the trace ID from a context. +// Returns empty string if no trace ID is present. +func TraceIDFromContext(ctx context.Context) string { + if id, ok := ctx.Value(CtxTraceID).(string); ok { + return id + } + return "" +} + // Fatal logs an error message and exits with status 1. func (l *Logger) Fatal(msg string, args ...any) { l.Error(msg, args...) diff --git a/internal/queue/event_store.go b/internal/queue/event_store.go index 7ed2ea0..4bb0819 100644 --- a/internal/queue/event_store.go +++ b/internal/queue/event_store.go @@ -14,19 +14,19 @@ import ( // EventStore provides append-only event storage using Redis Streams. // Events are stored chronologically and can be queried for audit trails. type EventStore struct { - client *redis.Client - ctx context.Context - retentionDays int - maxStreamLen int64 + client *redis.Client + ctx context.Context + retentionDays int + maxStreamLen int64 } // EventStoreConfig holds configuration for the event store. type EventStoreConfig struct { - RedisAddr string - RedisPassword string - RedisDB int - RetentionDays int // How long to keep events (default: 7) - MaxStreamLen int64 // Max events per task stream (default: 1000) + RedisAddr string + RedisPassword string + RedisDB int + RetentionDays int // How long to keep events (default: 7) + MaxStreamLen int64 // Max events per task stream (default: 1000) } // NewEventStore creates a new event store instance. diff --git a/internal/queue/redis/queue.go b/internal/queue/redis/queue.go index d1cb00c..5b913d2 100644 --- a/internal/queue/redis/queue.go +++ b/internal/queue/redis/queue.go @@ -29,7 +29,7 @@ type metricEvent struct { } // Config holds configuration for Queue - type Config struct { +type Config struct { RedisAddr string RedisPassword string RedisDB int @@ -48,9 +48,9 @@ func NewQueue(cfg Config) (*Queue, error) { } } else { opts = &redis.Options{ - Addr: cfg.RedisAddr, - Password: cfg.RedisPassword, - DB: cfg.RedisDB, + Addr: cfg.RedisAddr, + Password: cfg.RedisPassword, + DB: cfg.RedisDB, PoolSize: 50, MinIdleConns: 10, MaxRetries: 3, diff --git a/internal/storage/db_metrics.go b/internal/storage/db_metrics.go index 56148ae..7ef3f68 100644 --- a/internal/storage/db_metrics.go +++ b/internal/storage/db_metrics.go @@ -17,14 +17,14 @@ type Metric struct { // MetricSummary represents aggregated metric statistics type MetricSummary struct { - Name string `json:"name"` - Count int64 `json:"count"` - Avg float64 `json:"avg"` - Min float64 `json:"min"` - Max float64 `json:"max"` - Sum float64 `json:"sum"` - StartTime time.Time `json:"start_time"` - EndTime time.Time `json:"end_time"` + Name string `json:"name"` + Count int64 `json:"count"` + Avg float64 `json:"avg"` + Min float64 `json:"min"` + Max float64 `json:"max"` + Sum float64 `json:"sum"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` } // RecordMetric records a metric to the database diff --git a/internal/worker/errors/execution.go b/internal/worker/errors/execution.go new file mode 100644 index 0000000..2c0e885 --- /dev/null +++ b/internal/worker/errors/execution.go @@ -0,0 +1,106 @@ +// Package errors provides structured error types for the worker. +package errors + +import ( + "fmt" + "time" +) + +// ExecutionError provides structured error context for task execution failures. +// It captures the task ID, execution phase, specific operation, root cause, +// and additional context to make debugging easier. +type ExecutionError struct { + TaskID string // The task that failed + Phase string // Current TaskState (queued, preparing, running, collecting) + Operation string // Specific operation that failed (e.g., "create_workspace", "fetch_dataset") + Cause error // The underlying error + Context map[string]string // Additional context (paths, IDs, etc.) + Timestamp time.Time // When the error occurred +} + +// Error implements the error interface with a formatted message. +func (e ExecutionError) Error() string { + return fmt.Sprintf("[%s/%s] task=%s: %v", e.Phase, e.Operation, e.TaskID, e.Cause) +} + +// Unwrap returns the underlying error for error chain inspection. +func (e ExecutionError) Unwrap() error { + return e.Cause +} + +// WithContext adds a key-value pair to the error context. +func (e ExecutionError) WithContext(key, value string) ExecutionError { + if e.Context == nil { + e.Context = make(map[string]string) + } + e.Context[key] = value + return e +} + +// ContextString returns a formatted string of all context values. +func (e ExecutionError) ContextString() string { + if len(e.Context) == 0 { + return "" + } + result := "" + for k, v := range e.Context { + if result != "" { + result += ", " + } + result += fmt.Sprintf("%s=%s", k, v) + } + return result +} + +// NewExecutionError creates a new ExecutionError with the given parameters. +func NewExecutionError(taskID, phase, operation string, cause error) ExecutionError { + return ExecutionError{ + TaskID: taskID, + Phase: phase, + Operation: operation, + Cause: cause, + Context: make(map[string]string), + Timestamp: time.Now(), + } +} + +// IsExecutionError checks if an error is an ExecutionError. +func IsExecutionError(err error) bool { + _, ok := err.(ExecutionError) + return ok +} + +// Common error operations for the worker lifecycle. +const ( + // Preparation operations + OpCreateWorkspace = "create_workspace" + OpFetchDataset = "fetch_dataset" + OpMountVolume = "mount_volume" + OpSetupEnvironment = "setup_environment" + OpStageSnapshot = "stage_snapshot" + + // Execution operations + OpStartContainer = "start_container" + OpExecuteCommand = "execute_command" + OpMonitorExecution = "monitor_execution" + + // Collection operations + OpCollectResults = "collect_results" + OpUploadArtifacts = "upload_artifacts" + OpCleanupWorkspace = "cleanup_workspace" + + // Validation operations + OpValidateManifest = "validate_manifest" + OpCheckResources = "check_resources" + OpVerifyProvenance = "verify_provenance" +) + +// Common phase names (should match TaskState values). +const ( + PhaseQueued = "queued" + PhasePreparing = "preparing" + PhaseRunning = "running" + PhaseCollecting = "collecting" + PhaseCompleted = "completed" + PhaseFailed = "failed" +) diff --git a/internal/worker/factory.go b/internal/worker/factory.go index 3e9e312..0535510 100644 --- a/internal/worker/factory.go +++ b/internal/worker/factory.go @@ -126,6 +126,9 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { cfg.LocalMode, ) + // Create state manager for task lifecycle management + stateMgr := lifecycle.NewStateManager(nil) // Can pass audit logger if available + // Create job runner jobRunner := executor.NewJobRunner( localExecutor, @@ -140,6 +143,7 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { exec, metricsObj, logger, + stateMgr, ) // Create resource manager diff --git a/internal/worker/lifecycle/runloop.go b/internal/worker/lifecycle/runloop.go index bfb66ab..ba6d687 100644 --- a/internal/worker/lifecycle/runloop.go +++ b/internal/worker/lifecycle/runloop.go @@ -33,6 +33,7 @@ type RunLoop struct { executor TaskExecutor metrics MetricsRecorder logger Logger + stateMgr *StateManager // State management running map[string]context.CancelFunc @@ -64,6 +65,7 @@ func NewRunLoop( executor TaskExecutor, metrics MetricsRecorder, logger Logger, + stateMgr *StateManager, ) *RunLoop { ctx, cancel := context.WithCancel(context.Background()) return &RunLoop{ @@ -72,6 +74,7 @@ func NewRunLoop( executor: executor, metrics: metrics, logger: logger, + stateMgr: stateMgr, running: make(map[string]context.CancelFunc), ctx: ctx, cancel: cancel, @@ -185,17 +188,45 @@ func (r *RunLoop) waitForTasks() { func (r *RunLoop) executeTask(task *queue.Task) { defer r.releaseRunningSlot(task.ID) + // Transition to preparing state + if r.stateMgr != nil { + if err := r.stateMgr.Transition(task, StatePreparing); err != nil { + r.logger.Error("failed to transition task state", "task_id", task.ID, "error", err) + } + } + r.metrics.RecordTaskStart() defer r.metrics.RecordTaskCompletion() r.logger.Info("starting task", "task_id", task.ID, "job_name", task.JobName) + // Transition to running state + if r.stateMgr != nil { + if err := r.stateMgr.Transition(task, StateRunning); err != nil { + r.logger.Error("failed to transition task state", "task_id", task.ID, "error", err) + } + } + taskCtx, cancel := context.WithTimeout(r.ctx, 24*time.Hour) defer cancel() if err := r.executor.Execute(taskCtx, task); err != nil { r.logger.Error("task execution failed", "task_id", task.ID, "error", err) r.metrics.RecordTaskFailure() + + // Transition to failed state + if r.stateMgr != nil { + if stateErr := r.stateMgr.Transition(task, StateFailed); stateErr != nil { + r.logger.Error("failed to transition task to failed state", "task_id", task.ID, "error", stateErr) + } + } + } else { + // Transition to completed state + if r.stateMgr != nil { + if err := r.stateMgr.Transition(task, StateCompleted); err != nil { + r.logger.Error("failed to transition task to completed state", "task_id", task.ID, "error", err) + } + } } _ = r.queue.ReleaseLease(task.ID, r.config.WorkerID) diff --git a/internal/worker/lifecycle/states.go b/internal/worker/lifecycle/states.go new file mode 100644 index 0000000..607d660 --- /dev/null +++ b/internal/worker/lifecycle/states.go @@ -0,0 +1,130 @@ +// Package lifecycle provides task lifecycle management with explicit state transitions. +package lifecycle + +import ( + "fmt" + "slices" + "time" + + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/domain" +) + +// TaskState represents the current state of a task in its lifecycle. +// These states form a finite state machine with valid transitions defined below. +type TaskState string + +const ( + // StateQueued indicates the task is waiting to be picked up by a worker. + StateQueued TaskState = "queued" + // StatePreparing indicates the task is being prepared (workspace setup, data staging). + StatePreparing TaskState = "preparing" + // StateRunning indicates the task is currently executing in a container. + StateRunning TaskState = "running" + // StateCollecting indicates the task has finished execution and results are being collected. + StateCollecting TaskState = "collecting" + // StateCompleted indicates the task finished successfully. + StateCompleted TaskState = "completed" + // StateFailed indicates the task failed during execution. + StateFailed TaskState = "failed" +) + +// ValidTransitions defines the allowed state transitions. +// The key is the "from" state, the value is a list of valid "to" states. +// This enforces that state transitions follow the expected lifecycle. +var ValidTransitions = map[TaskState][]TaskState{ + StateQueued: {StatePreparing, StateFailed}, + StatePreparing: {StateRunning, StateFailed}, + StateRunning: {StateCollecting, StateFailed}, + StateCollecting: {StateCompleted, StateFailed}, + StateCompleted: {}, + StateFailed: {}, +} + +// StateTransitionError is returned when an invalid state transition is attempted. +type StateTransitionError struct { + From TaskState + To TaskState +} + +func (e StateTransitionError) Error() string { + return fmt.Sprintf("invalid state transition: %s -> %s", e.From, e.To) +} + +// StateManager manages task state transitions with audit logging. +type StateManager struct { + enabled bool + auditor *audit.Logger +} + +// NewStateManager creates a new state manager with the given audit logger. +func NewStateManager(auditor *audit.Logger) *StateManager { + return &StateManager{ + enabled: auditor != nil, + auditor: auditor, + } +} + +// Transition attempts to transition a task from its current state to a new state. +// It validates the transition, updates the task status, and logs the event. +// Returns StateTransitionError if the transition is not valid. +func (sm *StateManager) Transition(task *domain.Task, to TaskState) error { + from := TaskState(task.Status) + + // Validate the transition + if err := sm.validateTransition(from, to); err != nil { + return err + } + + // Audit the transition before updating + if sm.enabled && sm.auditor != nil { + sm.auditor.Log(audit.Event{ + EventType: audit.EventJobStarted, + Timestamp: time.Now(), + Resource: task.ID, + Action: "task_state_change", + Success: true, + Metadata: map[string]interface{}{ + "job_name": task.JobName, + "old_state": string(from), + "new_state": string(to), + }, + }) + } + + // Update task state + task.Status = string(to) + + return nil +} + +// validateTransition checks if a transition from one state to another is valid. +func (sm *StateManager) validateTransition(from, to TaskState) error { + // Check if "from" state is valid + allowed, ok := ValidTransitions[from] + if !ok { + return StateTransitionError{From: from, To: to} + } + + // Check if "to" state is in the allowed list + if slices.Contains(allowed, to) { + return nil + } + + return StateTransitionError{From: from, To: to} +} + +// IsTerminalState returns true if the state is terminal (no further transitions allowed). +func IsTerminalState(state TaskState) bool { + return state == StateCompleted || state == StateFailed +} + +// CanTransition returns true if a transition from -> to is valid. +func CanTransition(from, to TaskState) bool { + allowed, ok := ValidTransitions[from] + if !ok { + return false + } + + return slices.Contains(allowed, to) +} diff --git a/tests/benchmarks/dataset_size_comparison_test.go b/tests/benchmarks/dataset_size_comparison_test.go index f1b8c18..f65d2bd 100644 --- a/tests/benchmarks/dataset_size_comparison_test.go +++ b/tests/benchmarks/dataset_size_comparison_test.go @@ -18,11 +18,11 @@ func BenchmarkDatasetSizeComparison(b *testing.B) { numFiles int totalMB int }{ - {"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB - {"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB - {"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB - {"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB - {"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB + {"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB + {"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB + {"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB + {"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB + {"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB } for _, tc := range sizes { diff --git a/tests/benchmarks/log_sanitize_bench_test.go b/tests/benchmarks/log_sanitize_bench_test.go index 1a6e24d..7635d04 100644 --- a/tests/benchmarks/log_sanitize_bench_test.go +++ b/tests/benchmarks/log_sanitize_bench_test.go @@ -11,6 +11,7 @@ import ( // - Regex matching is CPU-intensive // - High volume log pipelines process thousands of messages/sec // - C++ can use Hyperscan/RE2 for parallel regex matching +// // Expected speedup: 3-5x for high-volume logging func BenchmarkLogSanitizeMessage(b *testing.B) { // Test messages with various sensitive data patterns diff --git a/tests/unit/worker/artifacts_test.go b/tests/unit/worker/artifacts_test.go index 19791f3..dd527f4 100644 --- a/tests/unit/worker/artifacts_test.go +++ b/tests/unit/worker/artifacts_test.go @@ -1,68 +1,68 @@ package worker_test import ( -"os" -"path/filepath" -"testing" + "os" + "path/filepath" + "testing" -"github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/worker" ) func TestScanArtifacts_SkipsKnownPathsAndLogs(t *testing.T) { -runDir := t.TempDir() + runDir := t.TempDir() -mustWrite := func(rel string, data []byte) { -p := filepath.Join(runDir, rel) -if err := os.MkdirAll(filepath.Dir(p), 0750); err != nil { -t.Fatalf("mkdir: %v", err) -} -if err := os.WriteFile(p, data, 0600); err != nil { -t.Fatalf("write file: %v", err) -} -} + mustWrite := func(rel string, data []byte) { + p := filepath.Join(runDir, rel) + if err := os.MkdirAll(filepath.Dir(p), 0750); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(p, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + } -mustWrite("run_manifest.json", []byte("{}")) -mustWrite("output.log", []byte("log")) -mustWrite("code/ignored.txt", []byte("ignore")) -mustWrite("snapshot/ignored.bin", []byte("ignore")) + mustWrite("run_manifest.json", []byte("{}")) + mustWrite("output.log", []byte("log")) + mustWrite("code/ignored.txt", []byte("ignore")) + mustWrite("snapshot/ignored.bin", []byte("ignore")) -mustWrite("results/metrics.jsonl", []byte("m")) -mustWrite("checkpoints/best.pt", []byte("checkpoint")) -mustWrite("plots/loss.png", []byte("png")) + mustWrite("results/metrics.jsonl", []byte("m")) + mustWrite("checkpoints/best.pt", []byte("checkpoint")) + mustWrite("plots/loss.png", []byte("png")) -art, err := worker.ScanArtifacts(runDir) -if err != nil { -t.Fatalf("scanArtifacts: %v", err) -} -if art == nil { -t.Fatalf("expected artifacts") -} + art, err := worker.ScanArtifacts(runDir) + if err != nil { + t.Fatalf("scanArtifacts: %v", err) + } + if art == nil { + t.Fatalf("expected artifacts") + } -paths := make([]string, 0, len(art.Files)) -var total int64 -for _, f := range art.Files { -paths = append(paths, f.Path) -total += f.SizeBytes -} + paths := make([]string, 0, len(art.Files)) + var total int64 + for _, f := range art.Files { + paths = append(paths, f.Path) + total += f.SizeBytes + } -want := []string{ -"checkpoints/best.pt", -"plots/loss.png", -"results/metrics.jsonl", -} -if len(paths) != len(want) { -t.Fatalf("expected %d files, got %d: %v", len(want), len(paths), paths) -} -for i := range want { -if paths[i] != want[i] { -t.Fatalf("expected paths[%d]=%q, got %q", i, want[i], paths[i]) -} -} + want := []string{ + "checkpoints/best.pt", + "plots/loss.png", + "results/metrics.jsonl", + } + if len(paths) != len(want) { + t.Fatalf("expected %d files, got %d: %v", len(want), len(paths), paths) + } + for i := range want { + if paths[i] != want[i] { + t.Fatalf("expected paths[%d]=%q, got %q", i, want[i], paths[i]) + } + } -if art.TotalSizeBytes != total { -t.Fatalf("expected total_size_bytes=%d, got %d", total, art.TotalSizeBytes) -} -if art.DiscoveryTime.IsZero() { -t.Fatalf("expected discovery_time") -} + if art.TotalSizeBytes != total { + t.Fatalf("expected total_size_bytes=%d, got %d", total, art.TotalSizeBytes) + } + if art.DiscoveryTime.IsZero() { + t.Fatalf("expected discovery_time") + } }