refactor(api): internal refactoring for TUI and worker modules

- Refactor internal/worker and internal/queue packages
- Update cmd/tui for monitoring interface
- Update test configurations
This commit is contained in:
Jeremie Fraeys 2026-02-20 15:51:23 -05:00
parent 7583932897
commit 23e5f3d1dc
No known key found for this signature in database
28 changed files with 1620 additions and 636 deletions

View file

@ -28,19 +28,19 @@ func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportS
// Returns the path to the exported file
func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) {
	s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize)

	// Placeholder: no request is sent yet. The real implementation would call
	// POST /api/jobs/{id}/export?anonymize=true and stream the archive to disk.
	// For now only a timestamped path is synthesized and returned.
	stamp := time.Now().Unix()
	exportPath := fmt.Sprintf("/tmp/%s_export_%d.tar.gz", jobName, stamp)

	s.logger.Info("export complete", "job", jobName, "path", exportPath)
	return exportPath, nil
}
// ExportOptions contains options for export
type ExportOptions struct {
Anonymize bool
Anonymize bool
IncludeLogs bool
IncludeData bool
}

View file

@ -16,37 +16,37 @@ import (
// WebSocketClient manages real-time updates from the server
type WebSocketClient struct {
conn *websocket.Conn
serverURL string
apiKey string
logger *logging.Logger
conn *websocket.Conn
serverURL string
apiKey string
logger *logging.Logger
// Channels for different update types
jobUpdates chan model.JobUpdateMsg
gpuUpdates chan model.GPUUpdateMsg
jobUpdates chan model.JobUpdateMsg
gpuUpdates chan model.GPUUpdateMsg
statusUpdates chan model.StatusMsg
// Control
ctx context.Context
cancel context.CancelFunc
connected bool
ctx context.Context
cancel context.CancelFunc
connected bool
}
// JobUpdateMsg represents a real-time job status update
type JobUpdateMsg struct {
JobName string `json:"job_name"`
Status string `json:"status"`
TaskID string `json:"task_id"`
Progress int `json:"progress"`
JobName string `json:"job_name"`
Status string `json:"status"`
TaskID string `json:"task_id"`
Progress int `json:"progress"`
}
// GPUUpdateMsg represents a real-time GPU status update
type GPUUpdateMsg struct {
DeviceID int `json:"device_id"`
Utilization int `json:"utilization"`
MemoryUsed int64 `json:"memory_used"`
MemoryTotal int64 `json:"memory_total"`
Temperature int `json:"temperature"`
DeviceID int `json:"device_id"`
Utilization int `json:"utilization"`
MemoryUsed int64 `json:"memory_used"`
MemoryTotal int64 `json:"memory_total"`
Temperature int `json:"temperature"`
}
// NewWebSocketClient creates a new WebSocket client
@ -71,26 +71,26 @@ func (c *WebSocketClient) Connect() error {
if err != nil {
return fmt.Errorf("invalid server URL: %w", err)
}
// Convert http/https to ws/wss
wsScheme := "ws"
if u.Scheme == "https" {
wsScheme = "wss"
}
wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)
// Create dialer with timeout
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
Subprotocols: []string{"fetchml-v1"},
}
// Add API key to headers
headers := http.Header{}
if c.apiKey != "" {
headers.Set("X-API-Key", c.apiKey)
}
conn, resp, err := dialer.Dial(wsURL, headers)
if err != nil {
if resp != nil {
@ -98,17 +98,17 @@ func (c *WebSocketClient) Connect() error {
}
return fmt.Errorf("websocket dial failed: %w", err)
}
c.conn = conn
c.connected = true
c.logger.Info("websocket connected", "url", wsURL)
// Start message handler
go c.messageHandler()
// Start heartbeat
go c.heartbeat()
return nil
}
@ -134,15 +134,15 @@ func (c *WebSocketClient) messageHandler() {
return
default:
}
if c.conn == nil {
time.Sleep(100 * time.Millisecond)
continue
}
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
// Read message
messageType, data, err := c.conn.ReadMessage()
if err != nil {
@ -150,7 +150,7 @@ func (c *WebSocketClient) messageHandler() {
c.logger.Error("websocket read error", "error", err)
}
c.connected = false
// Attempt reconnect
time.Sleep(5 * time.Second)
if err := c.Connect(); err != nil {
@ -158,7 +158,7 @@ func (c *WebSocketClient) messageHandler() {
}
continue
}
// Handle binary vs text messages
if messageType == websocket.BinaryMessage {
c.handleBinaryMessage(data)
@ -173,11 +173,11 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
if len(data) < 2 {
return
}
// Binary protocol: [opcode:1][data...]
opcode := data[0]
payload := data[1:]
switch opcode {
case 0x01: // Job update
var update JobUpdateMsg
@ -186,7 +186,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
return
}
c.jobUpdates <- model.JobUpdateMsg(update)
case 0x02: // GPU update
var update GPUUpdateMsg
if err := json.Unmarshal(payload, &update); err != nil {
@ -194,7 +194,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
return
}
c.gpuUpdates <- model.GPUUpdateMsg(update)
case 0x03: // Status message
var status model.StatusMsg
if err := json.Unmarshal(payload, &status); err != nil {
@ -212,7 +212,7 @@ func (c *WebSocketClient) handleTextMessage(data []byte) {
c.logger.Error("failed to unmarshal text message", "error", err)
return
}
msgType, _ := msg["type"].(string)
switch msgType {
case "job_update":
@ -228,7 +228,7 @@ func (c *WebSocketClient) handleTextMessage(data []byte) {
func (c *WebSocketClient) heartbeat() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
@ -249,12 +249,12 @@ func (c *WebSocketClient) Subscribe(channels ...string) error {
if !c.connected {
return fmt.Errorf("not connected")
}
subMsg := map[string]interface{}{
"action": "subscribe",
"channels": channels,
}
data, _ := json.Marshal(subMsg)
return c.conn.WriteMessage(websocket.TextMessage, data)
}

156
internal/api/adapter.go Normal file
View file

@ -0,0 +1,156 @@
// Package api provides HTTP handlers and OpenAPI-generated server interface implementations
package api
import (
"net/http"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
"github.com/jfraeys/fetch_ml/internal/api/jupyter"
"github.com/labstack/echo/v4"
)
// HandlerAdapter implements the generated ServerInterface by delegating to the
// existing per-domain handlers. The methods in this file nil-check jobsHandler
// and jupyterHandler before delegating to them.
type HandlerAdapter struct {
jobsHandler *jobs.Handler
jupyterHandler *jupyter.Handler
// datasetsHandler is stored by NewHandlerAdapter; no method visible in this
// file reads it yet.
datasetsHandler *datasets.Handler
}
// NewHandlerAdapter builds a HandlerAdapter around the supplied domain
// handlers. The arguments are stored as given, nil included.
func NewHandlerAdapter(
	jobsHandler *jobs.Handler,
	jupyterHandler *jupyter.Handler,
	datasetsHandler *datasets.Handler,
) *HandlerAdapter {
	adapter := new(HandlerAdapter)
	adapter.jobsHandler = jobsHandler
	adapter.jupyterHandler = jupyterHandler
	adapter.datasetsHandler = datasetsHandler
	return adapter
}
// Compile-time assertion that *HandlerAdapter satisfies the generated
// ServerInterface; the build fails here if a required method is missing.
var _ ServerInterface = (*HandlerAdapter)(nil)
// toHTTPHandler adapts a net/http-style handler function to an
// echo.HandlerFunc so existing handlers can be served by the echo router.
//
// The returned handler always reports a nil error: h is expected to write its
// own error responses.
func toHTTPHandler(h func(http.ResponseWriter, *http.Request)) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Pass echo's *Response (which implements http.ResponseWriter) rather
		// than its raw underlying Writer, so echo still records the status
		// code, byte count and committed state for middleware that runs after
		// the handler.
		h(c.Response(), c.Request())
		return nil
	}
}
// GetHealth answers the health check with a plain-text "OK" and HTTP 200.
func (a *HandlerAdapter) GetHealth(ctx echo.Context) error {
	const body = "OK\n"
	return ctx.String(http.StatusOK, body)
}
// GetV1Experiments is a stub: it always responds 200 with an empty experiment
// list and a message noting the endpoint is not implemented.
func (a *HandlerAdapter) GetV1Experiments(ctx echo.Context) error {
	payload := map[string]any{
		"experiments": []any{},
		"message":     "Not yet implemented",
	}
	return ctx.JSON(http.StatusOK, payload)
}
// PostV1Experiments is a stub for experiment creation.
//
// It responds 501 Not Implemented, matching the other unimplemented endpoints
// of this adapter. Answering 201 Created would tell clients that a resource
// exists when nothing was stored. The original "message" field is kept so
// existing consumers of the body still find it.
func (a *HandlerAdapter) PostV1Experiments(ctx echo.Context) error {
	return ctx.JSON(http.StatusNotImplemented, map[string]any{
		"error":   "Not implemented",
		"code":    "NOT_IMPLEMENTED",
		"message": "Not yet implemented",
	})
}
// GetV1JupyterServices lists Jupyter services by delegating to the Jupyter
// handler. It responds 503 when no Jupyter handler was configured.
func (a *HandlerAdapter) GetV1JupyterServices(ctx echo.Context) error {
	if a.jupyterHandler != nil {
		return toHTTPHandler(a.jupyterHandler.ListServicesHTTP)(ctx)
	}

	unavailable := map[string]any{
		"error": "Jupyter service not available",
		"code":  "SERVICE_UNAVAILABLE",
	}
	return ctx.JSON(http.StatusServiceUnavailable, unavailable)
}
// PostV1JupyterServices starts a Jupyter service by delegating to the Jupyter
// handler. It responds 503 when no Jupyter handler was configured.
func (a *HandlerAdapter) PostV1JupyterServices(ctx echo.Context) error {
	jh := a.jupyterHandler
	if jh == nil {
		body := map[string]any{
			"error": "Jupyter service not available",
			"code":  "SERVICE_UNAVAILABLE",
		}
		return ctx.JSON(http.StatusServiceUnavailable, body)
	}

	start := toHTTPHandler(jh.StartServiceHTTP)
	return start(ctx)
}
// DeleteV1JupyterServicesServiceId is the REST entry point for stopping a
// Jupyter service. It responds 503 when no Jupyter handler is configured and
// 501 otherwise; serviceId is currently unused.
func (a *HandlerAdapter) DeleteV1JupyterServicesServiceId(ctx echo.Context, serviceId string) error {
	var (
		status int
		body   map[string]any
	)

	if a.jupyterHandler == nil {
		status = http.StatusServiceUnavailable
		body = map[string]any{
			"error": "Jupyter service not available",
			"code":  "SERVICE_UNAVAILABLE",
		}
	} else {
		// TODO: Implement when StopServiceHTTP is available
		status = http.StatusNotImplemented
		body = map[string]any{
			"error":   "Not implemented",
			"code":    "NOT_IMPLEMENTED",
			"message": "Jupyter service stop not yet implemented via REST API",
		}
	}

	return ctx.JSON(status, body)
}
// GetV1Queue reports queue status. The response is hard-coded: status
// "healthy" with zero pending and zero running entries.
func (a *HandlerAdapter) GetV1Queue(ctx echo.Context) error {
	queueStatus := map[string]any{
		"status":  "healthy",
		"pending": 0,
		"running": 0,
	}
	return ctx.JSON(http.StatusOK, queueStatus)
}
// GetV1Tasks lists tasks by delegating to the jobs handler's ListAllJobsHTTP.
// It responds 503 when no jobs handler is configured; params is not used here.
func (a *HandlerAdapter) GetV1Tasks(ctx echo.Context, params GetV1TasksParams) error {
	if a.jobsHandler != nil {
		list := toHTTPHandler(a.jobsHandler.ListAllJobsHTTP)
		return list(ctx)
	}

	return ctx.JSON(http.StatusServiceUnavailable, map[string]any{
		"error": "Jobs handler not available",
		"code":  "SERVICE_UNAVAILABLE",
	})
}
// PostV1Tasks is a stub for task creation: it always responds 501 and points
// the caller at the WebSocket API.
func (a *HandlerAdapter) PostV1Tasks(ctx echo.Context) error {
	notImplemented := map[string]any{
		"error":   "Not implemented",
		"code":    "NOT_IMPLEMENTED",
		"message": "Task creation via REST API not yet implemented - use WebSocket",
	}
	return ctx.JSON(http.StatusNotImplemented, notImplemented)
}
// DeleteV1TasksTaskId is a stub for task cancellation: it always responds 501
// and points the caller at the WebSocket API. taskId is currently unused.
func (a *HandlerAdapter) DeleteV1TasksTaskId(ctx echo.Context, taskId string) error {
	notImplemented := map[string]any{
		"error":   "Not implemented",
		"code":    "NOT_IMPLEMENTED",
		"message": "Task cancellation via REST API not yet implemented - use WebSocket",
	}
	return ctx.JSON(http.StatusNotImplemented, notImplemented)
}
// GetV1TasksTaskId is a stub for task lookup: it always responds 501 and
// points the caller at the WebSocket API. taskId is currently unused.
func (a *HandlerAdapter) GetV1TasksTaskId(ctx echo.Context, taskId string) error {
	notImplemented := map[string]any{
		"error":   "Not implemented",
		"code":    "NOT_IMPLEMENTED",
		"message": "Task details via REST API not yet implemented - use WebSocket",
	}
	return ctx.JSON(http.StatusNotImplemented, notImplemented)
}
// GetWs answers plain HTTP requests to the WebSocket endpoint with
// 426 Upgrade Required; it does not perform a WebSocket upgrade itself.
func (a *HandlerAdapter) GetWs(ctx echo.Context) error {
	upgradeRequired := map[string]any{
		"error":   "WebSocket connection required",
		"code":    "UPGRADE_REQUIRED",
		"message": "Use WebSocket protocol to connect to this endpoint",
	}
	return ctx.JSON(http.StatusUpgradeRequired, upgradeRequired)
}

View file

@ -1,133 +0,0 @@
// Package api provides error handling utilities for the API
package api
import (
"encoding/json"
"net/http"
"time"
)
// ErrorResponse represents a standardized error response.
// Details and RequestID are omitted from the JSON when empty.
type ErrorResponse struct {
Error bool `json:"error"`
// Code is one of the ErrCode* byte values declared in this file.
Code byte `json:"code"`
Message string `json:"message"`
Details string `json:"details,omitempty"`
// Timestamp is set to the current UTC time by WriteError and NewErrorResponse.
Timestamp time.Time `json:"timestamp"`
// RequestID is not populated by any constructor visible in this file.
RequestID string `json:"request_id,omitempty"`
}
// SuccessResponse represents a standardized success response.
// Data is omitted from the JSON when nil.
type SuccessResponse struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
// Timestamp is set to the current UTC time by WriteSuccess and NewSuccessResponse.
Timestamp time.Time `json:"timestamp"`
}
// WriteError sends a JSON ErrorResponse with the given API error code,
// message and details, using statusCode as the HTTP status. The timestamp is
// the current UTC time. The encoder's error is not checked.
func WriteError(w http.ResponseWriter, code byte, message, details string, statusCode int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)

	var body ErrorResponse
	body.Error = true
	body.Code = code
	body.Message = message
	body.Details = details
	body.Timestamp = time.Now().UTC()

	enc := json.NewEncoder(w)
	enc.Encode(body)
}
// WriteSuccess sends a JSON SuccessResponse wrapping data with HTTP 200. The
// timestamp is the current UTC time. The encoder's error is not checked.
func WriteSuccess(w http.ResponseWriter, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	var body SuccessResponse
	body.Success = true
	body.Data = data
	body.Timestamp = time.Now().UTC()

	enc := json.NewEncoder(w)
	enc.Encode(body)
}
// Common error codes for API responses.
//
// These are untyped constants; functions in this file take them as byte. The
// values are grouped by their high nibble: 0x0_ for the general request codes,
// 0x1_ for the server-side codes, 0x2_ for the job codes, and 0x3_ for the
// remaining resource/configuration codes.
const (
ErrCodeUnknownError = 0x00
ErrCodeInvalidRequest = 0x01
ErrCodeAuthenticationFailed = 0x02
ErrCodePermissionDenied = 0x03
ErrCodeResourceNotFound = 0x04
ErrCodeResourceAlreadyExists = 0x05
ErrCodeServerOverloaded = 0x10
ErrCodeDatabaseError = 0x11
ErrCodeNetworkError = 0x12
ErrCodeStorageError = 0x13
ErrCodeTimeout = 0x14
ErrCodeJobNotFound = 0x20
ErrCodeJobAlreadyRunning = 0x21
ErrCodeJobFailedToStart = 0x22
ErrCodeJobExecutionFailed = 0x23
ErrCodeJobCancelled = 0x24
ErrCodeOutOfMemory = 0x30
ErrCodeDiskFull = 0x31
ErrCodeInvalidConfiguration = 0x32
ErrCodeServiceUnavailable = 0x33
)
// HTTP status code mappings: package-local aliases for the net/http status
// constants. ErrorCodeToHTTPStatus returns all of them except
// StatusTooManyRequests, which it never returns.
const (
StatusBadRequest = http.StatusBadRequest
StatusUnauthorized = http.StatusUnauthorized
StatusForbidden = http.StatusForbidden
StatusNotFound = http.StatusNotFound
StatusConflict = http.StatusConflict
StatusInternalServerError = http.StatusInternalServerError
StatusServiceUnavailable = http.StatusServiceUnavailable
StatusTooManyRequests = http.StatusTooManyRequests
)
// ErrorCodeToHTTPStatus maps API error codes to HTTP status codes
func ErrorCodeToHTTPStatus(code byte) int {
switch code {
case ErrCodeInvalidRequest:
return StatusBadRequest
case ErrCodeAuthenticationFailed:
return StatusUnauthorized
case ErrCodePermissionDenied:
return StatusForbidden
case ErrCodeResourceNotFound, ErrCodeJobNotFound:
return StatusNotFound
case ErrCodeResourceAlreadyExists, ErrCodeJobAlreadyRunning:
return StatusConflict
case ErrCodeServerOverloaded, ErrCodeServiceUnavailable:
return StatusServiceUnavailable
case ErrCodeDatabaseError, ErrCodeNetworkError, ErrCodeStorageError:
return StatusInternalServerError
default:
return StatusInternalServerError
}
}
// NewErrorResponse returns an ErrorResponse carrying code, message and
// details, flagged as an error and stamped with the current UTC time.
func NewErrorResponse(code byte, message, details string) ErrorResponse {
	var resp ErrorResponse
	resp.Error = true
	resp.Code = code
	resp.Message = message
	resp.Details = details
	resp.Timestamp = time.Now().UTC()
	return resp
}
// NewSuccessResponse returns a SuccessResponse wrapping data, flagged as
// successful and stamped with the current UTC time.
func NewSuccessResponse(data interface{}) SuccessResponse {
	var resp SuccessResponse
	resp.Success = true
	resp.Data = data
	resp.Timestamp = time.Now().UTC()
	return resp
}

View file

@ -5,6 +5,7 @@ import (
"os"
"strings"
apimiddleware "github.com/jfraeys/fetch_ml/internal/api/middleware"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
@ -49,8 +50,8 @@ func (s *Server) initializeComponents() error {
// Initialize Jupyter service manager
s.initJupyterServiceManager()
// Initialize handlers
s.handlers = NewHandlers(s.expManager, nil, s.logger)
// Initialize OpenAPI validation middleware (optional, doesn't block startup)
s.initValidationMiddleware()
return nil
}
@ -250,6 +251,24 @@ func (s *Server) initAuditLogger() *audit.Logger {
return al
}
// initValidationMiddleware sets up request validation against the OpenAPI
// spec at api/openapi.yaml (relative to the working directory). It is
// best-effort: a missing spec is logged at debug level, a spec that fails to
// load is logged as a warning, and in both cases s.validationMiddleware is
// left unset.
func (s *Server) initValidationMiddleware() {
	const specPath = "api/openapi.yaml"

	// Only initialize if OpenAPI spec exists
	_, statErr := os.Stat(specPath)
	if statErr != nil {
		s.logger.Debug("OpenAPI spec not found, skipping validation middleware")
		return
	}

	validator, loadErr := apimiddleware.NewValidationMiddleware(specPath)
	if loadErr != nil {
		s.logger.Warn("failed to initialize OpenAPI validation middleware", "error", loadErr)
		return
	}

	s.validationMiddleware = validator
	s.logger.Info("OpenAPI validation middleware initialized")
}
// getSecurityConfig extracts security config from server config
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
return &config.SecurityConfig{

View file

@ -1,313 +0,0 @@
// Package api provides HTTP handlers for the fetch_ml API server
package api
import (
"encoding/json"
"fmt"
"net/http"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// Handlers groups all HTTP handlers and the dependencies they share.
type Handlers struct {
// expManager is used to check that an experiment exists before a workspace
// is linked to or synced with it.
expManager *experiment.Manager
// jupyterServiceMgr may be nil; RegisterHandlers then registers no Jupyter routes.
jupyterServiceMgr *jupyter.ServiceManager
logger *logging.Logger
}
// NewHandlers bundles the given dependencies into a Handlers value. The
// arguments are stored as given, nil included.
func NewHandlers(
	expManager *experiment.Manager,
	jupyterServiceMgr *jupyter.ServiceManager,
	logger *logging.Logger,
) *Handlers {
	h := new(Handlers)
	h.expManager = expManager
	h.jupyterServiceMgr = jupyterServiceMgr
	h.logger = logger
	return h
}
// RegisterHandlers attaches this group's routes to mux. The /db-status route
// is always registered; the Jupyter routes are added only when a Jupyter
// service manager is present.
func (h *Handlers) RegisterHandlers(mux *http.ServeMux) {
	// Health check endpoints
	mux.HandleFunc("/db-status", h.handleDBStatus)

	if h.jupyterServiceMgr == nil {
		return
	}

	// Jupyter service endpoints
	mux.HandleFunc("/api/jupyter/services", h.handleJupyterServices)
	mux.HandleFunc("/api/jupyter/experiments/link", h.handleJupyterExperimentLink)
	mux.HandleFunc("/api/jupyter/experiments/sync", h.handleJupyterExperimentSync)
}
// handleHealth writes HTTP 200 followed by the body "OK\n". Write errors are
// deliberately ignored.
func (h *Handlers) handleHealth(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
	_, _ = fmt.Fprint(w, "OK\n")
}
// handleDBStatus responds with a fixed JSON payload reporting the database
// status as "unknown". No database handle is available to this handler, so no
// real check is made.
func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	// This would need the DB instance passed to handlers;
	// for now, return a basic response.
	payload := helpers.MarshalJSONOrEmpty(map[string]any{
		"status":  "unknown",
		"message": "Database status check not implemented",
	})

	if _, writeErr := w.Write(payload); writeErr != nil {
		h.logger.Error("failed to write response", "error", writeErr)
	}
}
// handleJupyterServices dispatches Jupyter service management requests by
// HTTP method: GET lists, POST starts and DELETE stops a service.
//
// A missing user yields 401 and an unsupported method 405. GET requires the
// "jupyter:read" permission, while POST and DELETE require "jupyter:manage";
// a missing permission yields 403.
func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	user := auth.GetUserFromContext(r.Context())
	if user == nil {
		http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
		return
	}

	// Resolve the required permission and target operation first; the
	// permission check below is then shared by every supported method.
	var (
		permission string
		operation  func(http.ResponseWriter, *http.Request)
	)
	switch r.Method {
	case http.MethodGet:
		permission, operation = "jupyter:read", h.listJupyterServices
	case http.MethodPost:
		permission, operation = "jupyter:manage", h.startJupyterService
	case http.MethodDelete:
		permission, operation = "jupyter:manage", h.stopJupyterService
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	if !user.HasPermission(permission) {
		http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
		return
	}
	operation(w, r)
}
// listJupyterServices writes the service manager's current service list as
// the response body with HTTP 200.
func (h *Handlers) listJupyterServices(w http.ResponseWriter, _ *http.Request) {
	running := h.jupyterServiceMgr.ListServices()

	w.WriteHeader(http.StatusOK)
	_, writeErr := w.Write(helpers.MarshalJSONOrEmpty(running))
	if writeErr != nil {
		h.logger.Error("failed to write response", "error", writeErr)
	}
}
// startJupyterService decodes a jupyter.StartRequest from the request body,
// asks the service manager to start it and writes the resulting service with
// HTTP 201. An undecodable body yields 400; a start failure yields 500.
func (h *Handlers) startJupyterService(w http.ResponseWriter, r *http.Request) {
	var startReq jupyter.StartRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&startReq); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	started, startErr := h.jupyterServiceMgr.StartService(r.Context(), &startReq)
	if startErr != nil {
		http.Error(w, fmt.Sprintf("Failed to start service: %v", startErr), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusCreated)
	_, writeErr := w.Write(helpers.MarshalJSONOrEmpty(started))
	if writeErr != nil {
		h.logger.Error("failed to write response", "error", writeErr)
	}
}
// stopJupyterService stops the service named by the "id" query parameter and
// confirms with a small JSON document and HTTP 200. A missing id yields 400;
// a stop failure yields 500.
func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) {
	id := r.URL.Query().Get("id")
	if id == "" {
		http.Error(w, "Service ID is required", http.StatusBadRequest)
		return
	}

	stopErr := h.jupyterServiceMgr.StopService(r.Context(), id)
	if stopErr != nil {
		http.Error(w, fmt.Sprintf("Failed to stop service: %v", stopErr), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	confirmation := map[string]string{
		"status": "stopped",
		"id":     id,
	}
	if encodeErr := json.NewEncoder(w).Encode(confirmation); encodeErr != nil {
		h.logger.Error("failed to encode response", "error", encodeErr)
	}
}
// handleJupyterExperimentLink links a Jupyter workspace to an experiment.
//
// Checks run in this order: user present (401), "jupyter:manage" permission
// (403), POST method (405), decodable body (400), workspace and
// experiment_id present (400), experiment exists (404). A failure while
// linking or while reading back the workspace metadata yields 500. On
// success it responds 201 with the workspace metadata.
func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	user := auth.GetUserFromContext(r.Context())
	switch {
	case user == nil:
		http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
		return
	case !user.HasPermission("jupyter:manage"):
		http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
		return
	case r.Method != http.MethodPost:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var link struct {
		Workspace    string `json:"workspace"`
		ExperimentID string `json:"experiment_id"`
		ServiceID    string `json:"service_id,omitempty"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&link); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	missingField := link.Workspace == "" || link.ExperimentID == ""
	if missingField {
		http.Error(w, "Workspace and experiment_id are required", http.StatusBadRequest)
		return
	}

	if exists := h.expManager.ExperimentExists(link.ExperimentID); !exists {
		http.Error(w, "Experiment not found", http.StatusNotFound)
		return
	}

	// Link workspace with experiment using service manager
	linkErr := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(
		link.Workspace,
		link.ExperimentID,
		link.ServiceID,
	)
	if linkErr != nil {
		http.Error(w, fmt.Sprintf("Failed to link workspace: %v", linkErr), http.StatusInternalServerError)
		return
	}

	// Get workspace metadata to return
	meta, metaErr := h.jupyterServiceMgr.GetWorkspaceMetadata(link.Workspace)
	if metaErr != nil {
		msg := fmt.Sprintf("Failed to get workspace metadata: %v", metaErr)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	h.logger.Info("jupyter workspace linked with experiment",
		"workspace", link.Workspace,
		"experiment_id", link.ExperimentID,
		"service_id", link.ServiceID)

	w.WriteHeader(http.StatusCreated)
	result := map[string]any{
		"status": "linked",
		"data":   meta,
	}
	if encodeErr := json.NewEncoder(w).Encode(result); encodeErr != nil {
		h.logger.Error("failed to encode response", "error", encodeErr)
	}
}
// handleJupyterExperimentSync synchronizes a Jupyter workspace with an
// experiment.
//
// Checks run in this order: user present (401), "jupyter:manage" permission
// (403), POST method (405), decodable body (400), workspace, experiment_id
// and direction present (400), experiment exists (404). A failure while
// syncing or while reading back the workspace metadata yields 500. On
// success it responds 200 with a summary of the sync. sync_type is echoed
// back and logged but is not passed to the service manager.
func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	user := auth.GetUserFromContext(r.Context())
	switch {
	case user == nil:
		http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
		return
	case !user.HasPermission("jupyter:manage"):
		http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
		return
	case r.Method != http.MethodPost:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var sync struct {
		Workspace    string `json:"workspace"`
		ExperimentID string `json:"experiment_id"`
		Direction    string `json:"direction"` // "pull" or "push"
		SyncType     string `json:"sync_type"` // "data", "notebooks", "all"
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&sync); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	missingField := sync.Workspace == "" || sync.ExperimentID == "" || sync.Direction == ""
	if missingField {
		http.Error(w, "Workspace, experiment_id, and direction are required", http.StatusBadRequest)
		return
	}

	// Validate experiment exists
	if exists := h.expManager.ExperimentExists(sync.ExperimentID); !exists {
		http.Error(w, "Experiment not found", http.StatusNotFound)
		return
	}

	// Perform sync operation using service manager
	syncErr := h.jupyterServiceMgr.SyncWorkspaceWithExperiment(
		r.Context(), sync.Workspace, sync.ExperimentID, sync.Direction)
	if syncErr != nil {
		http.Error(w, fmt.Sprintf("Failed to sync workspace: %v", syncErr), http.StatusInternalServerError)
		return
	}

	// Get updated metadata
	meta, metaErr := h.jupyterServiceMgr.GetWorkspaceMetadata(sync.Workspace)
	if metaErr != nil {
		msg := fmt.Sprintf("Failed to get workspace metadata: %v", metaErr)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	// Create sync result
	summary := map[string]any{
		"workspace":     sync.Workspace,
		"experiment_id": sync.ExperimentID,
		"direction":     sync.Direction,
		"sync_type":     sync.SyncType,
		"synced_at":     meta.LastSync,
		"status":        "completed",
		"metadata":      meta,
	}

	h.logger.Info("jupyter workspace sync completed",
		"workspace", sync.Workspace,
		"experiment_id", sync.ExperimentID,
		"direction", sync.Direction,
		"sync_type", sync.SyncType)

	w.WriteHeader(http.StatusOK)
	if encodeErr := json.NewEncoder(w).Encode(summary); encodeErr != nil {
		h.logger.Error("failed to encode response", "error", encodeErr)
	}
}

View file

@ -15,9 +15,9 @@ import (
// Handler provides Jupyter-related WebSocket handlers
type Handler struct {
logger *logging.Logger
jupyterMgr *jupyter.ServiceManager
authConfig *auth.Config
logger *logging.Logger
jupyterMgr *jupyter.ServiceManager
authConfig *auth.Config
}
// NewHandler creates a new Jupyter handler
@ -35,11 +35,11 @@ func NewHandler(
// Error codes
const (
ErrorCodeInvalidRequest = 0x01
ErrorCodeAuthenticationFailed = 0x02
ErrorCodePermissionDenied = 0x03
ErrorCodeResourceNotFound = 0x04
ErrorCodeServiceUnavailable = 0x33
ErrorCodeInvalidRequest = 0x01
ErrorCodeAuthenticationFailed = 0x02
ErrorCodePermissionDenied = 0x03
ErrorCodeResourceNotFound = 0x04
ErrorCodeServiceUnavailable = 0x33
)
// Permissions
@ -141,18 +141,18 @@ func (h *Handler) HandleListJupyter(conn *websocket.Conn, payload []byte, user *
if h.jupyterMgr == nil {
return h.sendSuccessPacket(conn, map[string]interface{}{
"success": true,
"services": []interface{}{},
"count": 0,
"success": true,
"services": []interface{}{},
"count": 0,
})
}
services := h.jupyterMgr.ListServices()
return h.sendSuccessPacket(conn, map[string]interface{}{
"success": true,
"services": services,
"count": len(services),
"success": true,
"services": services,
"count": len(services),
})
}

View file

@ -26,6 +26,10 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
if len(s.config.Security.IPWhitelist) > 0 {
handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler)
}
// Add OpenAPI validation if available
if s.validationMiddleware != nil {
handler = s.validationMiddleware.ValidateRequest(handler)
}
handler.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,87 @@
// Package middleware provides request/response validation using OpenAPI spec
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/getkin/kin-openapi/routers"
"github.com/getkin/kin-openapi/routers/gorillamux"
)
// ValidationMiddleware validates HTTP requests against an OpenAPI spec.
// Construct it with NewValidationMiddleware.
type ValidationMiddleware struct {
// router matches an incoming request to the spec route it should be validated against.
router routers.Router
}
// NewValidationMiddleware loads the OpenAPI document at specPath, validates
// the document itself and builds a router over it for request matching. The
// error from whichever of those steps fails is returned unchanged.
func NewValidationMiddleware(specPath string) (*ValidationMiddleware, error) {
	specLoader := openapi3.NewLoader()

	spec, err := specLoader.LoadFromFile(specPath)
	if err != nil {
		return nil, err
	}

	// Reject a malformed spec before any request is matched against it.
	if err = spec.Validate(specLoader.Context); err != nil {
		return nil, err
	}

	specRouter, err := gorillamux.NewRouter(spec)
	if err != nil {
		return nil, err
	}

	mw := &ValidationMiddleware{router: specRouter}
	return mw, nil
}
// ValidateRequest returns middleware that validates each incoming request
// against the OpenAPI spec before passing it to next.
//
// Requests to /health/ok, /health and /metrics, and requests that match no
// route in the spec, are passed through unvalidated. A request whose body
// cannot be read, or that fails validation, is answered with 400 and a JSON
// error document; next is not called.
//
// NOTE(review): no openapi3filter.Options are supplied, so no
// AuthenticationFunc is set. If the spec declares security requirements,
// confirm how the validator treats them without one.
func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// reject writes the 400 response shared by both failure paths.
		reject := func(kind, message string) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			// The status line is already sent; an encode failure here can
			// only mean the client went away, so there is nothing to report.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"error":   kind,
				"message": message,
			})
		}

		// Skip validation for health endpoints
		switch r.URL.Path {
		case "/health/ok", "/health", "/metrics":
			next.ServeHTTP(w, r)
			return
		}

		// Find the route
		route, pathParams, err := v.router.FindRoute(r)
		if err != nil {
			// Route not in spec - allow through (might be unregistered endpoint)
			next.ServeHTTP(w, r)
			return
		}

		// Buffer the body so it can be replayed after validation.
		var bodyBytes []byte
		hasBody := r.Body != nil && r.Body != http.NoBody
		if hasBody {
			bodyBytes, err = io.ReadAll(r.Body)
			if err != nil {
				// Do not validate or forward a truncated body as if it were complete.
				reject("invalid request body", err.Error())
				return
			}
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}

		// Validate request - body is read from r.Body automatically
		input := &openapi3filter.RequestValidationInput{
			Request:    r,
			PathParams: pathParams,
			Route:      route,
		}
		if err := openapi3filter.ValidateRequest(r.Context(), input); err != nil {
			reject("validation failed", err.Error())
			return
		}

		// Rewind the body for the next handler regardless of how much of it
		// the validator consumed.
		if hasBody {
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}

		next.ServeHTTP(w, r)
	})
}

View file

@ -0,0 +1,173 @@
// Package responses provides structured API response types with security-conscious error handling.
package responses
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"regexp"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
// ErrorResponse provides a sanitized error response to clients.
// It includes a trace ID for support lookup while preventing information leakage.
// WriteError and WriteErrorMessage log the unsanitized error under the same
// trace ID.
type ErrorResponse struct {
Error string `json:"error"` // Sanitized error message for clients (output of sanitizeError)
Code string `json:"code"` // Machine-readable error code (one of the ErrCode* constants)
TraceID string `json:"trace_id"` // For support lookup (internal correlation)
}
// Error codes for machine-readable error identification.
// statusToCode maps HTTP statuses onto these values; errorCodeFromStatus
// falls back to ErrCodeInternal for any status without a mapping.
const (
ErrCodeBadRequest = "BAD_REQUEST"
ErrCodeUnauthorized = "UNAUTHORIZED"
ErrCodeForbidden = "FORBIDDEN"
ErrCodeNotFound = "NOT_FOUND"
ErrCodeConflict = "CONFLICT"
ErrCodeRateLimited = "RATE_LIMITED"
ErrCodeInternal = "INTERNAL_ERROR"
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
ErrCodeValidation = "VALIDATION_ERROR"
)
// statusToCode maps an HTTP status to the machine-readable error code sent
// to clients. Statuses without an entry are handled by errorCodeFromStatus.
var statusToCode = map[int]string{
	http.StatusBadRequest:          ErrCodeBadRequest,
	http.StatusUnauthorized:        ErrCodeUnauthorized,
	http.StatusForbidden:           ErrCodeForbidden,
	http.StatusNotFound:            ErrCodeNotFound,
	http.StatusConflict:            ErrCodeConflict,
	http.StatusUnprocessableEntity: ErrCodeValidation,
	http.StatusTooManyRequests:     ErrCodeRateLimited,
	http.StatusInternalServerError: ErrCodeInternal,
	http.StatusServiceUnavailable:  ErrCodeServiceUnavailable,
}
// Patterns to sanitize from error messages (security: prevent information leakage)
var (
// Remove file paths. The pattern matches a '/' and every following
// non-whitespace character, so it also matches URL paths and any other
// token that contains a slash.
pathPattern = regexp.MustCompile(`/[^\s]*`)
// Remove sensitive keywords. sanitizeError does a case-insensitive
// substring match, so words that merely contain one of these (for example
// "author" or "keyboard") also match.
sensitiveKeywords = []string{"password", "secret", "token", "key", "credential", "auth"}
)
// WriteError writes a sanitized error response to the client.
//
// It takes the trace ID from the request context (generating one when none is
// present), logs the full error with request details when logger is non-nil,
// and sends the client a JSON ErrorResponse containing only the sanitized
// message, the error code derived from status, and the trace ID.
//
// A nil err is tolerated: it is logged as an empty error and the client
// receives the generic message sanitizeError produces for empty input.
func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, logger *logging.Logger) {
	traceID := logging.TraceIDFromContext(r.Context())
	if traceID == "" {
		traceID = generateTraceID()
	}

	// Calling err.Error() on a nil error would panic inside the error path.
	var detail string
	if err != nil {
		detail = err.Error()
	}

	// Log the full error internally with all details
	if logger != nil {
		logger.Error("request failed",
			"trace_id", traceID,
			"method", r.Method,
			"path", r.URL.Path,
			"status", status,
			"error", detail,
			"client_ip", getClientIP(r),
		)
	}

	// Build sanitized response
	resp := ErrorResponse{
		Error:   sanitizeError(detail),
		Code:    errorCodeFromStatus(status),
		TraceID: traceID,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent, so an encode failure cannot be
	// reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(resp)
}
// WriteErrorMessage is the string-message counterpart of WriteError: it logs
// message in full when logger is non-nil, then sends the client a JSON
// ErrorResponse with the sanitized message, the error code derived from
// status, and the trace ID (taken from the request context, or generated
// when absent).
func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, message string, logger *logging.Logger) {
	id := logging.TraceIDFromContext(r.Context())
	if id == "" {
		id = generateTraceID()
	}

	if logger != nil {
		logger.Error("request failed",
			"trace_id", id,
			"method", r.Method,
			"path", r.URL.Path,
			"status", status,
			"error", message,
		)
	}

	var body ErrorResponse
	body.Error = sanitizeError(message)
	body.Code = errorCodeFromStatus(status)
	body.TraceID = id

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(body)
}
// sanitizeError removes potentially sensitive information from error messages
// before they are sent to clients.
//
// Rules, applied in order:
//   - an empty message becomes "An error occurred";
//   - anything matching pathPattern is replaced by "[path]";
//   - if the message contains any of sensitiveKeywords (case-insensitive
//     substring match) the whole message becomes "An error occurred";
//   - the phrases "internal error" / "Internal Error" are replaced;
//   - messages longer than 200 bytes are cut to at most 200 bytes plus "...".
//
// Truncation never leaves a partial UTF-8 sequence: bytes of a rune split by
// the cut (and any other invalid UTF-8 in the kept prefix) are dropped, so
// the result encodes cleanly as JSON.
func sanitizeError(msg string) string {
	const (
		generic = "An error occurred"
		maxLen  = 200
	)

	if msg == "" {
		return generic
	}

	// Remove file paths
	msg = pathPattern.ReplaceAllString(msg, "[path]")

	// A message that mentions a sensitive keyword is replaced wholesale.
	lowerMsg := strings.ToLower(msg)
	for _, keyword := range sensitiveKeywords {
		if strings.Contains(lowerMsg, keyword) {
			return generic
		}
	}

	// Remove internal error details
	msg = strings.ReplaceAll(msg, "internal error", "an error occurred")
	msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred")

	// Truncate if too long. A plain byte slice could cut a multi-byte rune in
	// half, so the prefix is cleaned of invalid UTF-8 before use.
	if len(msg) > maxLen {
		msg = strings.ToValidUTF8(msg[:maxLen], "") + "..."
	}
	return msg
}
// errorCodeFromStatus looks up the error code registered for an HTTP status
// in statusToCode, falling back to ErrCodeInternal for unmapped statuses.
func errorCodeFromStatus(status int) string {
	code, known := statusToCode[status]
	if !known {
		return ErrCodeInternal
	}
	return code
}
// getClientIP reports the client address for a request.
//
// Precedence: the first comma-separated entry of X-Forwarded-For, then
// X-Real-IP, then r.RemoteAddr. Header values are whitespace-trimmed;
// RemoteAddr is returned untouched. Both headers are taken at face value.
func getClientIP(r *http.Request) string {
	forwarded := r.Header.Get("X-Forwarded-For")
	realIP := r.Header.Get("X-Real-IP")

	switch {
	case forwarded != "":
		// Only the left-most hop is reported.
		first, _, _ := strings.Cut(forwarded, ",")
		return strings.TrimSpace(first)
	case realIP != "":
		return strings.TrimSpace(realIP)
	default:
		return r.RemoteAddr
	}
}
// generateTraceID produces a fallback trace ID: the current wall-clock time
// in nanoseconds since the Unix epoch, rendered as a decimal string.
func generateTraceID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprint(nanos)
}

View file

@ -2,12 +2,14 @@ package api
import (
"net/http"
"os"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
"github.com/jfraeys/fetch_ml/internal/api/jupyter"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/prommetrics"
"github.com/labstack/echo/v4"
)
// registerRoutes sets up all HTTP routes and handlers
@ -34,9 +36,6 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
// Register WebSocket endpoint
s.registerWebSocketRoutes(mux)
// Register HTTP API handlers
s.handlers.RegisterHandlers(mux)
// Register new REST API endpoints for TUI
jobsHandler := jobs.NewHandler(
s.expManager,
@ -52,13 +51,73 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
// Team jobs endpoint: GET /api/jobs?all_users=true
mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP)
// Register OpenAPI-generated routes with Echo router
s.registerOpenAPIRoutes(mux, jobsHandler)
// Register API documentation endpoint
s.registerDocsRoutes(mux)
}
// registerDocsRoutes mounts a static file server for the API documentation
// under /docs/ when the docs/api path (relative to the process working
// directory) can be stat'ed. Any stat error is treated as "not available"
// and only produces a debug log entry.
func (s *Server) registerDocsRoutes(mux *http.ServeMux) {
	const docsDir = "docs/api"

	if _, statErr := os.Stat(docsDir); statErr != nil {
		s.logger.Debug("API documentation not available", "path", docsDir)
		return
	}

	files := http.FileServer(http.Dir(docsDir))
	mux.Handle("/docs/", http.StripPrefix("/docs/", files))
	s.logger.Info("API documentation endpoint registered", "path", "/docs/")
}
// registerOpenAPIRoutes builds an Echo router carrying the generated OpenAPI
// handlers and mounts it on mux at /health, /v1/ and /ws.
//
// The jobs handler is supplied by the caller; the jupyter and datasets
// handlers are constructed here from the server's logger, service manager,
// auth config, database and data directory. All three are combined by
// NewHandlerAdapter into the ServerInterface passed to RegisterHandlers.
func (s *Server) registerOpenAPIRoutes(mux *http.ServeMux, jobsHandler *jobs.Handler) {
	router := echo.New()

	// Arguments are evaluated left to right, so the jupyter handler is still
	// created before the datasets handler.
	adapter := NewHandlerAdapter(
		jobsHandler,
		jupyter.NewHandler(s.logger, s.jupyterServiceMgr, s.config.BuildAuthConfig()),
		datasets.NewHandler(s.logger, s.db, s.config.DataDir),
	)
	RegisterHandlers(router, adapter)

	// Bridge so the Echo router can be mounted on the net/http mux and sit
	// behind the existing middleware chain.
	bridge := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		router.ServeHTTP(rw, req)
	})

	// NOTE(review): the original comment says these paths take precedence
	// over legacy routes — confirm against the other registrations on mux.
	for _, pattern := range []string{"/health", "/v1/", "/ws"} {
		mux.Handle(pattern, bridge)
	}

	s.logger.Info("OpenAPI-generated routes registered with Echo router")
}
// registerHealthRoutes sets up health check endpoints.
//
// It lets a HealthHandler built from the server register its own routes on
// mux, then adds a /health/ok route.
//
// NOTE(review): the /health/ok line reads s.handlers, but the post-change
// Server field list shown in this commit no longer contains a handlers field —
// confirm whether this line survives the change.
func (s *Server) registerHealthRoutes(mux *http.ServeMux) {
healthHandler := NewHealthHandler(s)
healthHandler.RegisterRoutes(mux)
mux.HandleFunc("/health/ok", s.handlers.handleHealth)
s.logger.Info("health check endpoints registered")
}

View file

@ -8,6 +8,7 @@ import (
"syscall"
"time"
apimiddleware "github.com/jfraeys/fetch_ml/internal/api/middleware"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
@ -20,18 +21,18 @@ import (
// Server represents the API server
//
// NOTE(review): the field list appears twice in this view. The first run is
// the pre-change layout (it includes handlers); the second run is the
// post-change layout (handlers is gone, validationMiddleware is new). Confirm
// which set is current before relying on a field.
type Server struct {
config *ServerConfig
httpServer *http.Server
logger *logging.Logger
expManager *experiment.Manager
taskQueue queue.Backend
db *storage.DB
handlers *Handlers
sec *middleware.SecurityMiddleware
cleanupFuncs []func()
jupyterServiceMgr *jupyter.ServiceManager
auditLogger *audit.Logger
promMetrics *prommetrics.Metrics // Prometheus metrics
config *ServerConfig
httpServer *http.Server
logger *logging.Logger
expManager *experiment.Manager
taskQueue queue.Backend
db *storage.DB
sec *middleware.SecurityMiddleware
cleanupFuncs []func()
jupyterServiceMgr *jupyter.ServiceManager
auditLogger *audit.Logger
promMetrics *prommetrics.Metrics // Prometheus metrics
validationMiddleware *apimiddleware.ValidationMiddleware // OpenAPI validation
}
// NewServer creates a new API server

642
internal/api/server_gen.go Normal file
View file

@ -0,0 +1,642 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
package api
import (
"bytes"
"compress/gzip"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/getkin/kin-openapi/openapi3"
"github.com/labstack/echo/v4"
"github.com/oapi-codegen/runtime"
)
// ApiKeyAuthScopes is the context key the generated wrappers pass to ctx.Set
// to attach a []string of scopes before invoking a handler. Every wrapper in
// this file stores an empty slice; GetHealth is the only one that does not
// set the key at all.
const (
ApiKeyAuthScopes = "ApiKeyAuth.Scopes"
)
// Defines values for ErrorResponseCode.
//
// These are the wire values of ErrorResponse.Code. The Go identifiers drop
// the underscores present in the corresponding strings (BADREQUEST encodes as
// "BAD_REQUEST").
const (
BADREQUEST ErrorResponseCode = "BAD_REQUEST"
CONFLICT ErrorResponseCode = "CONFLICT"
FORBIDDEN ErrorResponseCode = "FORBIDDEN"
INTERNALERROR ErrorResponseCode = "INTERNAL_ERROR"
NOTFOUND ErrorResponseCode = "NOT_FOUND"
RATELIMITED ErrorResponseCode = "RATE_LIMITED"
SERVICEUNAVAILABLE ErrorResponseCode = "SERVICE_UNAVAILABLE"
UNAUTHORIZED ErrorResponseCode = "UNAUTHORIZED"
VALIDATIONERROR ErrorResponseCode = "VALIDATION_ERROR"
)
// Defines values for ExperimentStatus.
//
// These are the wire values carried by Experiment.Status.
const (
Active ExperimentStatus = "active"
Archived ExperimentStatus = "archived"
Deleted ExperimentStatus = "deleted"
)
// Defines values for HealthResponseStatus.
//
// These are the wire values carried by HealthResponse.Status.
const (
Degraded HealthResponseStatus = "degraded"
Healthy HealthResponseStatus = "healthy"
Unhealthy HealthResponseStatus = "unhealthy"
)
// Defines values for JupyterServiceStatus.
//
// These are the wire values carried by JupyterService.Status. The identifiers
// are prefixed with the type name.
const (
JupyterServiceStatusError JupyterServiceStatus = "error"
JupyterServiceStatusRunning JupyterServiceStatus = "running"
JupyterServiceStatusStarting JupyterServiceStatus = "starting"
JupyterServiceStatusStopped JupyterServiceStatus = "stopped"
JupyterServiceStatusStopping JupyterServiceStatus = "stopping"
)
// Defines values for TaskStatus.
//
// These are the wire values carried by Task.Status. The constants are listed
// alphabetically, not in lifecycle order.
const (
TaskStatusCollecting TaskStatus = "collecting"
TaskStatusCompleted TaskStatus = "completed"
TaskStatusFailed TaskStatus = "failed"
TaskStatusPreparing TaskStatus = "preparing"
TaskStatusQueued TaskStatus = "queued"
TaskStatusRunning TaskStatus = "running"
)
// Defines values for GetV1TasksParamsStatus.
//
// These are the values accepted by the status query filter in
// GetV1TasksParams. The set is narrower than TaskStatus: "collecting" and
// "preparing" have no counterpart here.
const (
Completed GetV1TasksParamsStatus = "completed"
Failed GetV1TasksParamsStatus = "failed"
Queued GetV1TasksParamsStatus = "queued"
Running GetV1TasksParamsStatus = "running"
)
// CreateExperimentRequest defines model for CreateExperimentRequest.
//
// It is the JSON body type for PostV1Experiments (see
// PostV1ExperimentsJSONRequestBody). Name is always serialized; Description
// is a pointer with omitempty and is left out of the JSON when nil.
type CreateExperimentRequest struct {
Description *string `json:"description,omitempty"`
Name string `json:"name"`
}
// CreateTaskRequest defines model for CreateTaskRequest.
//
// It is the JSON body type for PostV1Tasks (see PostV1TasksJSONRequestBody).
// JobName is the only field without omitempty; every other field is a pointer
// and is left out of the JSON when nil.
type CreateTaskRequest struct {
// Args Command-line arguments for the training script
Args *string `json:"args,omitempty"`
// Cpu CPU cores requested
Cpu *int `json:"cpu,omitempty"`
// DatasetSpecs and Datasets are both optional; how the two relate is not
// visible in this file.
DatasetSpecs *[]DatasetSpec `json:"dataset_specs,omitempty"`
Datasets *[]string `json:"datasets,omitempty"`
// Gpu GPUs requested
Gpu *int `json:"gpu,omitempty"`
// JobName Unique identifier for the job
JobName string `json:"job_name"`
// MemoryGb Memory (GB) requested
MemoryGb *int `json:"memory_gb,omitempty"`
Metadata *map[string]string `json:"metadata,omitempty"`
Priority *int `json:"priority,omitempty"`
// SnapshotId Reference to experiment snapshot
SnapshotId *string `json:"snapshot_id,omitempty"`
}
// DatasetSpec defines model for DatasetSpec.
//
// Every field is an optional string: nil pointers are omitted from the JSON.
// The type is referenced by CreateTaskRequest.DatasetSpecs.
type DatasetSpec struct {
MountPath *string `json:"mount_path,omitempty"`
Name *string `json:"name,omitempty"`
Sha256 *string `json:"sha256,omitempty"`
Source *string `json:"source,omitempty"`
}
// ErrorResponse defines model for ErrorResponse.
//
// All three fields are non-pointer and carry no omitempty, so they are always
// present in the serialized JSON. BadRequest, NotFound, RateLimited,
// Unauthorized and ValidationError are aliases of this type.
type ErrorResponse struct {
Code ErrorResponseCode `json:"code"`
// Error Sanitized error message
Error string `json:"error"`
// TraceId Support correlation ID
TraceId string `json:"trace_id"`
}
// ErrorResponseCode defines model for ErrorResponse.Code.
//
// It is a plain string type; the declared constants list the expected values
// but nothing in this file rejects other strings.
type ErrorResponseCode string
// Experiment defines model for Experiment.
//
// Every field is an optional pointer and is omitted from the JSON when nil.
type Experiment struct {
CommitId *string `json:"commit_id,omitempty"`
CreatedAt *time.Time `json:"created_at,omitempty"`
Id *string `json:"id,omitempty"`
Name *string `json:"name,omitempty"`
Status *ExperimentStatus `json:"status,omitempty"`
}
// ExperimentStatus defines model for Experiment.Status.
//
// It is a plain string type; see the Active, Archived and Deleted constants.
type ExperimentStatus string
// HealthResponse defines model for HealthResponse.
//
// Every field is an optional pointer and is omitted from the JSON when nil.
type HealthResponse struct {
Status *HealthResponseStatus `json:"status,omitempty"`
Timestamp *time.Time `json:"timestamp,omitempty"`
Version *string `json:"version,omitempty"`
}
// HealthResponseStatus defines model for HealthResponse.Status.
//
// It is a plain string type; see the Degraded, Healthy and Unhealthy
// constants.
type HealthResponseStatus string
// JupyterService defines model for JupyterService.
//
// Every field is an optional pointer and is omitted from the JSON when nil.
//
// NOTE(review): Token is serialized whenever it is set — confirm that
// exposing it in API responses is intended.
type JupyterService struct {
CreatedAt *time.Time `json:"created_at,omitempty"`
Id *string `json:"id,omitempty"`
Name *string `json:"name,omitempty"`
Status *JupyterServiceStatus `json:"status,omitempty"`
Token *string `json:"token,omitempty"`
Url *string `json:"url,omitempty"`
}
// JupyterServiceStatus defines model for JupyterService.Status.
//
// It is a plain string type; see the JupyterServiceStatus* constants.
type JupyterServiceStatus string
// QueueStats defines model for QueueStats.
//
// Every counter is an optional *int: a nil pointer is omitted from the JSON,
// which lets a consumer tell "not reported" apart from an explicit zero.
type QueueStats struct {
// Completed Tasks completed today
Completed *int `json:"completed,omitempty"`
// Failed Tasks failed today
Failed *int `json:"failed,omitempty"`
// Queued Tasks waiting to run
Queued *int `json:"queued,omitempty"`
// Running Tasks currently executing
Running *int `json:"running,omitempty"`
// Workers Active workers
Workers *int `json:"workers,omitempty"`
}
// StartJupyterRequest defines model for StartJupyterRequest.
//
// It is the JSON body type for PostV1JupyterServices (see
// PostV1JupyterServicesJSONRequestBody). Name is always serialized; Image and
// Workspace are optional pointers omitted when nil.
type StartJupyterRequest struct {
Image *string `json:"image,omitempty"`
Name string `json:"name"`
Workspace *string `json:"workspace,omitempty"`
}
// Task defines model for Task.
//
// Every field is an optional pointer and is omitted from the JSON when nil.
// The resource fields (Cpu, Gpu, MemoryGb), Datasets, JobName, Priority and
// SnapshotId share their JSON names with the same-named fields of
// CreateTaskRequest.
type Task struct {
Cpu *int `json:"cpu,omitempty"`
CreatedAt *time.Time `json:"created_at,omitempty"`
Datasets *[]string `json:"datasets,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
Error *string `json:"error,omitempty"`
Gpu *int `json:"gpu,omitempty"`
// Id Unique task identifier
Id *string `json:"id,omitempty"`
JobName *string `json:"job_name,omitempty"`
MaxRetries *int `json:"max_retries,omitempty"`
MemoryGb *int `json:"memory_gb,omitempty"`
Output *string `json:"output,omitempty"`
Priority *int `json:"priority,omitempty"`
RetryCount *int `json:"retry_count,omitempty"`
SnapshotId *string `json:"snapshot_id,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
Status *TaskStatus `json:"status,omitempty"`
UserId *string `json:"user_id,omitempty"`
WorkerId *string `json:"worker_id,omitempty"`
}
// TaskStatus defines model for Task.Status.
//
// It is a plain string type; see the TaskStatus* constants.
type TaskStatus string
// TaskList defines model for TaskList.
//
// It pairs a slice of Task values with Limit, Offset and Total counters.
// Every field is an optional pointer and is omitted from the JSON when nil.
type TaskList struct {
Limit *int `json:"limit,omitempty"`
Offset *int `json:"offset,omitempty"`
Tasks *[]Task `json:"tasks,omitempty"`
Total *int `json:"total,omitempty"`
}
// The five declarations below are type aliases (declared with "="), not new
// types: each name is interchangeable with ErrorResponse and with the others.

// BadRequest defines model for BadRequest.
type BadRequest = ErrorResponse
// NotFound defines model for NotFound.
type NotFound = ErrorResponse
// RateLimited defines model for RateLimited.
type RateLimited = ErrorResponse
// Unauthorized defines model for Unauthorized.
type Unauthorized = ErrorResponse
// ValidationError defines model for ValidationError.
type ValidationError = ErrorResponse
// GetV1TasksParams defines parameters for GetV1Tasks.
//
// The fields are filled from the request's query string by the GetV1Tasks
// wrapper, which binds "status", "limit" and "offset" as optional parameters.
// No default values are applied in this file.
type GetV1TasksParams struct {
Status *GetV1TasksParamsStatus `form:"status,omitempty" json:"status,omitempty"`
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
}
// GetV1TasksParamsStatus defines parameters for GetV1Tasks.
//
// It is the plain string type of the status query filter; see the Completed,
// Failed, Queued and Running constants.
type GetV1TasksParamsStatus string
// The three declarations below are type aliases (declared with "="): each
// gives an operation-specific name to an existing request model.

// PostV1ExperimentsJSONRequestBody defines body for PostV1Experiments for application/json ContentType.
type PostV1ExperimentsJSONRequestBody = CreateExperimentRequest
// PostV1JupyterServicesJSONRequestBody defines body for PostV1JupyterServices for application/json ContentType.
type PostV1JupyterServicesJSONRequestBody = StartJupyterRequest
// PostV1TasksJSONRequestBody defines body for PostV1Tasks for application/json ContentType.
type PostV1TasksJSONRequestBody = CreateTaskRequest
// ServerInterface represents all server handlers.
//
// Each method corresponds to one route; the "(METHOD /path)" comment above a
// method names it. Path parameters arrive as extra string arguments and the
// query parameters of GetV1Tasks arrive as a GetV1TasksParams value, both
// already bound by ServerInterfaceWrapper. Request bodies are not decoded by
// the wrappers; implementations read them from ctx themselves.
type ServerInterface interface {
// Health check
// (GET /health)
GetHealth(ctx echo.Context) error
// List experiments
// (GET /v1/experiments)
GetV1Experiments(ctx echo.Context) error
// Create experiment
// (POST /v1/experiments)
PostV1Experiments(ctx echo.Context) error
// List Jupyter services
// (GET /v1/jupyter/services)
GetV1JupyterServices(ctx echo.Context) error
// Start Jupyter service
// (POST /v1/jupyter/services)
PostV1JupyterServices(ctx echo.Context) error
// Stop Jupyter service
// (DELETE /v1/jupyter/services/{serviceId})
DeleteV1JupyterServicesServiceId(ctx echo.Context, serviceId string) error
// Queue status
// (GET /v1/queue)
GetV1Queue(ctx echo.Context) error
// List tasks
// (GET /v1/tasks)
GetV1Tasks(ctx echo.Context, params GetV1TasksParams) error
// Create task
// (POST /v1/tasks)
PostV1Tasks(ctx echo.Context) error
// Cancel/delete task
// (DELETE /v1/tasks/{taskId})
DeleteV1TasksTaskId(ctx echo.Context, taskId string) error
// Get task details
// (GET /v1/tasks/{taskId})
GetV1TasksTaskId(ctx echo.Context, taskId string) error
// WebSocket connection
// (GET /ws)
GetWs(ctx echo.Context) error
}
// ServerInterfaceWrapper converts echo contexts to parameters.
//
// Its methods have the echo.HandlerFunc shape: they bind path and query
// parameters from the echo.Context, answer 400 on a binding failure, and
// otherwise delegate to the matching method on Handler.
type ServerInterfaceWrapper struct {
Handler ServerInterface
}
// GetHealth forwards the request to Handler.GetHealth. There are no
// parameters to bind and, unlike the other wrappers, no auth scopes are set.
func (w *ServerInterfaceWrapper) GetHealth(ctx echo.Context) error {
	return w.Handler.GetHealth(ctx)
}
// GetV1Experiments stores an empty scope list under ApiKeyAuthScopes and
// forwards the request to Handler.GetV1Experiments.
func (w *ServerInterfaceWrapper) GetV1Experiments(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.GetV1Experiments(ctx)
}
// PostV1Experiments stores an empty scope list under ApiKeyAuthScopes and
// forwards the request to Handler.PostV1Experiments. The request body is not
// decoded here.
func (w *ServerInterfaceWrapper) PostV1Experiments(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.PostV1Experiments(ctx)
}
// GetV1JupyterServices stores an empty scope list under ApiKeyAuthScopes and
// forwards the request to Handler.GetV1JupyterServices.
func (w *ServerInterfaceWrapper) GetV1JupyterServices(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.GetV1JupyterServices(ctx)
}
// PostV1JupyterServices stores an empty scope list under ApiKeyAuthScopes and
// forwards the request to Handler.PostV1JupyterServices. The request body is
// not decoded here.
func (w *ServerInterfaceWrapper) PostV1JupyterServices(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.PostV1JupyterServices(ctx)
}
// DeleteV1JupyterServicesServiceId binds the required "serviceId" path
// parameter and forwards it to Handler.DeleteV1JupyterServicesServiceId.
// A binding failure is answered with HTTP 400 before any scopes are set or
// the handler is reached.
func (w *ServerInterfaceWrapper) DeleteV1JupyterServicesServiceId(ctx echo.Context) error {
	var serviceId string

	pathOpts := runtime.BindStyledParameterOptions{
		ParamLocation: runtime.ParamLocationPath,
		Explode:       false,
		Required:      true,
	}
	if bindErr := runtime.BindStyledParameterWithOptions("simple", "serviceId", ctx.Param("serviceId"), &serviceId, pathOpts); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter serviceId: %s", bindErr))
	}

	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.DeleteV1JupyterServicesServiceId(ctx, serviceId)
}
// GetV1Queue stores an empty scope list under ApiKeyAuthScopes and forwards
// the request to Handler.GetV1Queue.
func (w *ServerInterfaceWrapper) GetV1Queue(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.GetV1Queue(ctx)
}
// GetV1Tasks binds the optional "status", "limit" and "offset" query
// parameters into a GetV1TasksParams value and forwards it to
// Handler.GetV1Tasks. The scope list is set first; the parameters are then
// bound in that order, and the first binding failure is answered with
// HTTP 400 without reaching the handler.
func (w *ServerInterfaceWrapper) GetV1Tasks(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})

	// Collects every query parameter for the handler call.
	var params GetV1TasksParams

	if bindErr := runtime.BindQueryParameter("form", true, false, "status", ctx.QueryParams(), &params.Status); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter status: %s", bindErr))
	}
	if bindErr := runtime.BindQueryParameter("form", true, false, "limit", ctx.QueryParams(), &params.Limit); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter limit: %s", bindErr))
	}
	if bindErr := runtime.BindQueryParameter("form", true, false, "offset", ctx.QueryParams(), &params.Offset); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter offset: %s", bindErr))
	}

	return w.Handler.GetV1Tasks(ctx, params)
}
// PostV1Tasks stores an empty scope list under ApiKeyAuthScopes and forwards
// the request to Handler.PostV1Tasks. The request body is not decoded here.
func (w *ServerInterfaceWrapper) PostV1Tasks(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.PostV1Tasks(ctx)
}
// DeleteV1TasksTaskId binds the required "taskId" path parameter and forwards
// it to Handler.DeleteV1TasksTaskId. A binding failure is answered with
// HTTP 400 before any scopes are set or the handler is reached.
func (w *ServerInterfaceWrapper) DeleteV1TasksTaskId(ctx echo.Context) error {
	var taskId string

	pathOpts := runtime.BindStyledParameterOptions{
		ParamLocation: runtime.ParamLocationPath,
		Explode:       false,
		Required:      true,
	}
	if bindErr := runtime.BindStyledParameterWithOptions("simple", "taskId", ctx.Param("taskId"), &taskId, pathOpts); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter taskId: %s", bindErr))
	}

	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.DeleteV1TasksTaskId(ctx, taskId)
}
// GetV1TasksTaskId binds the required "taskId" path parameter and forwards it
// to Handler.GetV1TasksTaskId. A binding failure is answered with HTTP 400
// before any scopes are set or the handler is reached.
func (w *ServerInterfaceWrapper) GetV1TasksTaskId(ctx echo.Context) error {
	var taskId string

	pathOpts := runtime.BindStyledParameterOptions{
		ParamLocation: runtime.ParamLocationPath,
		Explode:       false,
		Required:      true,
	}
	if bindErr := runtime.BindStyledParameterWithOptions("simple", "taskId", ctx.Param("taskId"), &taskId, pathOpts); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter taskId: %s", bindErr))
	}

	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.GetV1TasksTaskId(ctx, taskId)
}
// GetWs stores an empty scope list under ApiKeyAuthScopes and forwards the
// request to Handler.GetWs.
func (w *ServerInterfaceWrapper) GetWs(ctx echo.Context) error {
	ctx.Set(ApiKeyAuthScopes, []string{})
	return w.Handler.GetWs(ctx)
}
// EchoRouter is a simple interface which specifies the echo.Route addition
// functions that are present on both echo.Echo and echo.Group, since we want
// to allow using either of them for path registration.
//
// It declares one method per HTTP verb. RegisterHandlersWithBaseURL in this
// file only calls GET, POST and DELETE.
type EchoRouter interface {
CONNECT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
HEAD(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
OPTIONS(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
PATCH(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
TRACE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
}
// RegisterHandlers registers every server route on router without a path
// prefix. It is RegisterHandlersWithBaseURL called with an empty base URL.
func RegisterHandlers(router EchoRouter, si ServerInterface) {
	const noPrefix = ""
	RegisterHandlersWithBaseURL(router, si, noPrefix)
}
// RegisterHandlersWithBaseURL registers every server route on router with
// baseURL prepended to each path, so the API can be served under a prefix.
// baseURL is concatenated verbatim; no slash is inserted or removed.
func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) {
	w := &ServerInterfaceWrapper{Handler: si}

	// at builds the full route path for a spec-relative path.
	at := func(specPath string) string { return baseURL + specPath }

	router.GET(at("/health"), w.GetHealth)

	router.GET(at("/v1/experiments"), w.GetV1Experiments)
	router.POST(at("/v1/experiments"), w.PostV1Experiments)

	router.GET(at("/v1/jupyter/services"), w.GetV1JupyterServices)
	router.POST(at("/v1/jupyter/services"), w.PostV1JupyterServices)
	router.DELETE(at("/v1/jupyter/services/:serviceId"), w.DeleteV1JupyterServicesServiceId)

	router.GET(at("/v1/queue"), w.GetV1Queue)

	router.GET(at("/v1/tasks"), w.GetV1Tasks)
	router.POST(at("/v1/tasks"), w.PostV1Tasks)
	router.DELETE(at("/v1/tasks/:taskId"), w.DeleteV1TasksTaskId)
	router.GET(at("/v1/tasks/:taskId"), w.GetV1TasksTaskId)

	router.GET(at("/ws"), w.GetWs)
}
// Base64 encoded, gzipped, json marshaled Swagger object
var swaggerSpec = []string{
"H4sIAAAAAAAC/8RaW1fbuhL+K1ra+6FdxyFAL2c3bwHSNqfh0gTaszbhBGFPEoEtqZIMZLPy38+SZCe+",
"yEDphSdIJI1nPn2a+TTOHQ55IjgDphXu3GEJSnCmwH7YIdEQvqWgtPkUcqaB2X+JEDENiaactS8VZ+Y7",
"uCWJiMHNjAB38E53bzLsfT7pjY5xgEFKLnEH99k1iWmEpLOMplwmROMAa0lCmNAIdzDZutgOX0WvW/Bm",
"+rb177/ebbbIRRi1YLq1/er1m7fmG7wMsArnkBDzyD8lTHEH/9Feh9N2o6rdM08eZoHh5XIZ4AhUKKkw",
"AdRdMpYPuH7PUxY9KfCDw+PJ+8OTg71C2ENQPJUhIMZNzMb0c4bscWcZ4CHRMKAJ1fC0wIfd495k0N/v",
"H/dKsRMNKDZ2EdyGABE8b/DHnKOEsEW+4QoHeA4kAmlpPwQtF63uVIM0H8trRxByFimUMk1jJNeRSVBg",
"La2d1AthUKFMwwyk8WQZ4BNGUj3nkv7zRJBPDronxx8Ph/2/SyDnJOYSJVQpymaoe9RHV7B4Vqy7qZ4D",
"01lYFnEqwbLti/HXft1zMTwBjC/dQX+ve9w/PJj0hsPDYQGQtXk0JTR+Zs7VvVmujFvW7UogGnq3AiRN",
"gOlC5hWSC5CauqxcsrsimdKSsplxmJHEIpSQ2wGwmZ7jztb2X0F14jLAq83onLpVZ6tZ/OISQpsJnV/H",
"RF01ekTkrO4Z3uVJQljUiikDROQsNVEpk++RngPSklBmWOrW4KAeSShSj9mjExRyCSo/vW5jK0ctwBHR",
"RIGeKAGh9Y5qSNRD+7jnVo0EhMZIZpZISRYFo2V7Nb+rq2a+OD4cnTwUwiW/mOS7WV58wui3FBCNzNGa",
"UpArWC/5BQ6Ke//2dYAF0RqkWfi/U9L6p9v6e7P1btI6+9efPtgTSLhcTGYX9efu2yH04sPOywd8T0AT",
"g5blRxRRY4DERyXeNAG3Jp+QlEuqF86TKUljjTtvbHw0SRPc2doMcEJZ9sHniGJEqDnX9tTf1crgFCSw",
"EJDmCFZnD+WL8EPHZrVFvqNTJFPt0CQ8ZXoiiNmje05xbUDNyfabt/4hW9I9Qz5oyxmr5p7Lr3cYmEH2",
"tCLmKjXo/eFwp7+31zvAQUn87B4evB/0d82KijToHxz3hgfdwSprj3rDL/3d3uTkoPul2x90dwY9HNTT",
"+5mHr5BXj0qlJoxqU2WRnYASUIrMwMf4dVmoGUmF4FKbjCMhdvm7v/cgLZxLgUOxYN/HknXK9+1CktCc",
"uvX8aFNzNCF2ZSakOyZJQUvTxBtqg6lmummiU1VkAgk1vTa2iQzn9Nqe/whiMJngLHgM9T4CifW8mXv1",
"Z87tioV90kwSJyBTln/tY4UBQGmSiMdjcw1S+auqL4r/pGKhQY5AXtPQd4KeZXeUJlKb0QDLlDH3n9Jc",
"iMK/Fj7HUS90/Ar82iKVcSmYVFL8qC3/nEIKI01c6ayRXDj61I6fkR0KrSYgzSOy8NabTFQ1WHCj9yz/",
"ZvxrXH5DqMHUFAmZMq+BHOymEFIpgel4geAWwjTbobqZGy6vsmtIRUfbU4fy8fpaH+ojQ4aMp43qjSYm",
"KxYLLL50S9piobkM552YaLM0+A5mGkeVIE3l6FHS00DnoYtTU3XsnnLgnibogEXf+ZxVjaqNzJrC8ZWj",
"TPdpoq4K4s/3wKJ2/DE1SG4nErQsa7aS0ivIxfowT7VItTf0n6DujGeLSWi0lP/xFfnny6Dye0lTz7pZ",
"+jAhgSCymn9DHscQ6vxDnu5WWcuXhFMFsslplwX8o8uGgzSgvsNvexcNGzedKmgYMwR8/KXKnmPPMdJc",
"k9jbKKnFYFCHMDVsGRmrzvuuoJ9g0U2dhq4kTNf8sBcjUmpC4ABTM8O1fHCexPB/W92jfusTFCoEsQ9w",
"93nKpjzvUZDQApMtfA86nO8PUCYWcb0JctS3fiSEkZkpJPuD4m3DookIi1CWq5FyokJtjNmY/fEHGmWx",
"j1k3jhGwSHBqLtMv4DYEoZETQiicQ3ilXuZdlrwBVIkfXVNi7opjdr4K+Rw5NDbQullnHKUKAZtyGUKE",
"BMjcYu6XvUSgj4RFMWWzsWvkmDt+HPMbRFDImaJKmyDd0UI3VM9RQsI5ZdCSQCJyEQMyOtkhYKUy6u+p",
"zpidn59fKs7G7G7MEBq7JDrGHTRuEvdjHLipxqCbae8Nk93Dvd5qMJfjbkKa0qg15bKl3PaN8Zgt7cPH",
"ttRTHZtN3h+gr/bYGQxwQSzirY3NDdsq4gIYERR38KuNzY1X2KbauWVq2+2Q+XfmDlX1IqpTyZTdeJD5",
"frpEs4EOeHUL8/q5ge1jpf22H+EO/gDa6WuTgYrd/O3NzUe02B7X7qooeE+/a+QCoQrlIr14hnHn9CzA",
"Kk0SIhe4k90IHH/txPb1Vnt9QlQjbiarIWLORGGyB5IvW73ShB9C5lFJr3Ctq6W+Olw2Dj4thWEBW0Fk",
"Z5SGAyy48mDiOnaIIAY3hRU1WI648uBiVeIOjxY/jSxNnc1lWQhqmcKytjNbP82N4obUN2A9ijIdWcE/",
"QxWKVhxPc7WcJ+0CWT08LF8bfw8XK1fV7+BjtSD5SFmfs2amj3E+BH4+63x3n9/MuCrsdZgr0KFMjVZQ",
"tqFUYW6kX/su+68fLV1yMGqzvhl79vvadozyxbZ8SZKAthfS0zunmmzHcqWZVGF2GVjPi7CVRD2rgf7a",
"96Ytx8R1LKqYcNEIiZXjjTXjA+j8Po7sTFtnqdI0bKgdtnvxK+tpoT3iYcnnqpNlKNbDqVpBsFLo95dN",
"Jz2tKOPC9ejRlMYasjuMBwzb0Gggx7cU5KLADudTkQq1G1PxlvSIi9Ey8D/KXWSKT1rfKDdLV8rNTV/z",
"xG81uwN5zfrMnP1Cjqwucfeka7frywC/dg/22Vs52C78vsMu2Xp4SenttVm0/e7hRcWfFXjqx8ppv5wZ",
"pRcJ1ZmcqV+cGmRNztJfJ2iKr0R/c2FxF2rPDxuIusrlC1JpGIJS0zSOF7+XEtsPL6q++/9xKmX6TDto",
"CkmwfWf+PLIYWtoc2/mPqn86n/qTi5/bScJCiOMMVjftfnhWP1uqYGMNtV3wGUTBPRr1WVDY/D3HIwJN",
"aKx+FFIjInTJnuHcTXPJ/QoXIx5egV61b2xLSAKJbaPRWUtFRPS677Pv2hroeCFAjVkLnZtZEzfrvINs",
"RK7KonBO2Kw4K6+n+bwpZVTNIbIzBGWz8w76BCBaJKbXgF64mCOnBs4FZ7Pzl7YFUmPI19q1ZctliqaQ",
"Q84YhLZzAUqTi9g6Um0JlBt6p2fLUo/AZ83t8kMmbCvCkbdSNnlIYhTBNcRcuBf/di7O3nThudai027H",
"Zt6cK915ZwI1aqFs6EjyKHXxeSyoTrtNBN2Ygg7nSbyR/YxpI+QJXp4t/x8AAP//HXQa3YQpAAA=",
}
// decodeSpec joins the swaggerSpec chunks, base64-decodes the result and
// gunzips it, returning the raw bytes of the embedded specification. The
// returned error wraps the underlying failure and says whether decoding or
// decompression failed.
func decodeSpec() ([]byte, error) {
	encoded := strings.Join(swaggerSpec, "")

	compressed, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return nil, fmt.Errorf("error base64 decoding spec: %w", err)
	}

	reader, err := gzip.NewReader(bytes.NewReader(compressed))
	if err != nil {
		return nil, fmt.Errorf("error decompressing spec: %w", err)
	}

	var spec bytes.Buffer
	if _, err := spec.ReadFrom(reader); err != nil {
		return nil, fmt.Errorf("error decompressing spec: %w", err)
	}
	return spec.Bytes(), nil
}
// rawSpec returns the decoded specification bytes. The decode runs once, when
// this package-level variable is initialized.
var rawSpec = decodeSpecCached()

// decodeSpecCached decodes the embedded spec immediately and returns a
// closure that hands back that same result (bytes and error) on every call.
// The byte slice is shared between callers, not copied.
func decodeSpecCached() func() ([]byte, error) {
	spec, decodeErr := decodeSpec()
	cached := func() ([]byte, error) {
		return spec, decodeErr
	}
	return cached
}
// PathToRawSpec constructs a synthetic filesystem for resolving external
// references when loading OpenAPI specifications. The returned map has a
// single entry mapping pathToFile to the embedded spec loader, or no entries
// when pathToFile is empty.
func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) {
	specs := map[string]func() ([]byte, error){}
	if pathToFile != "" {
		specs[pathToFile] = rawSpec
	}
	return specs
}
// GetSwagger returns the Swagger specification corresponding to the generated
// code in this file, parsed from the embedded spec bytes.
//
// External references are enabled on the loader and resolved through a lookup
// table built by PathToRawSpec. Because that table is built with an empty
// path it has no entries, so any external reference the loader asks for is
// answered with a "path not found" error.
func GetSwagger() (swagger *openapi3.T, err error) {
	known := PathToRawSpec("")

	loader := openapi3.NewLoader()
	loader.IsExternalRefsAllowed = true
	loader.ReadFromURIFunc = func(_ *openapi3.Loader, uri *url.URL) ([]byte, error) {
		cleaned := path.Clean(uri.String())
		if read, found := known[cleaned]; found {
			return read()
		}
		return nil, fmt.Errorf("path not found: %s", cleaned)
	}

	var specData []byte
	specData, err = rawSpec()
	if err != nil {
		return nil, err
	}

	swagger, err = loader.LoadFromData(specData)
	return swagger, err
}

View file

@ -23,11 +23,11 @@ const (
// Argon2id parameters (OWASP recommended minimum)
//
// Each constant is declared exactly once; declaring the same names twice in
// one const block is a redeclaration error.
//
// NOTE(review): time=1 with 64 MiB of memory is the combination suggested in
// the golang.org/x/crypto/argon2 IDKey documentation — confirm the OWASP
// attribution against the current Password Storage Cheat Sheet.
const (
	argon2idTime    = 1         // Iterations
	argon2idMemory  = 64 * 1024 // 64 MB (value is in KiB)
	argon2idThreads = 4         // Parallelism
	argon2idKeyLen  = 32        // 256-bit output
	argon2idSaltLen = 16        // 128-bit salt
)
// HashedKey represents a hashed API key with its algorithm and salt

View file

@ -21,35 +21,39 @@ func NewPathRegistry(root string) *PathRegistry {
}
// Binary paths
//
// Each accessor is declared once (a second declaration of the same method is
// a redeclaration error). All of them only join path elements; none touches
// the filesystem.

// BinDir returns the "bin" directory under RootDir.
func (p *PathRegistry) BinDir() string { return filepath.Join(p.RootDir, "bin") }

// APIServerBinary returns the path of the "api-server" binary inside BinDir.
func (p *PathRegistry) APIServerBinary() string { return filepath.Join(p.BinDir(), "api-server") }

// WorkerBinary returns the path of the "worker" binary inside BinDir.
func (p *PathRegistry) WorkerBinary() string { return filepath.Join(p.BinDir(), "worker") }

// TUIBinary returns the path of the "tui" binary inside BinDir.
func (p *PathRegistry) TUIBinary() string { return filepath.Join(p.BinDir(), "tui") }

// DataManagerBinary returns the path of the "data_manager" binary inside BinDir.
func (p *PathRegistry) DataManagerBinary() string { return filepath.Join(p.BinDir(), "data_manager") }
// Data paths
//
// Each accessor is declared once (a second declaration of the same method is
// a redeclaration error). All of them only join path elements; none touches
// the filesystem.

// DataDir returns the "data" directory under RootDir.
func (p *PathRegistry) DataDir() string { return filepath.Join(p.RootDir, "data") }

// ActiveDataDir returns the "active" directory under DataDir.
func (p *PathRegistry) ActiveDataDir() string { return filepath.Join(p.DataDir(), "active") }

// JupyterStateDir returns the "active/jupyter" directory under DataDir.
func (p *PathRegistry) JupyterStateDir() string {
	return filepath.Join(p.DataDir(), "active", "jupyter")
}

// ExperimentsDir returns the "experiments" directory under DataDir.
func (p *PathRegistry) ExperimentsDir() string { return filepath.Join(p.DataDir(), "experiments") }

// ProdSmokeDir returns the "prod-smoke" directory under DataDir.
func (p *PathRegistry) ProdSmokeDir() string { return filepath.Join(p.DataDir(), "prod-smoke") }
// Database paths
//
// Each accessor is declared once (a second declaration of the same method is
// a redeclaration error).

// DBDir returns the "db" directory under RootDir.
func (p *PathRegistry) DBDir() string { return filepath.Join(p.RootDir, "db") }

// SQLitePath returns the path of the "fetch_ml.db" file inside DBDir.
func (p *PathRegistry) SQLitePath() string { return filepath.Join(p.DBDir(), "fetch_ml.db") }
// Log paths
//
// Each accessor is declared once (a second declaration of the same method is
// a redeclaration error).

// LogDir returns the "logs" directory under RootDir.
func (p *PathRegistry) LogDir() string { return filepath.Join(p.RootDir, "logs") }

// AuditLogPath returns the path of the "fetchml-audit.log" file inside LogDir.
func (p *PathRegistry) AuditLogPath() string { return filepath.Join(p.LogDir(), "fetchml-audit.log") }
// Config paths
//
// Each accessor is declared once (a second declaration of the same method is
// a redeclaration error).

// ConfigDir returns the "configs" directory under RootDir.
func (p *PathRegistry) ConfigDir() string { return filepath.Join(p.RootDir, "configs") }

// APIServerConfig returns the path "api/dev.yaml" under ConfigDir.
func (p *PathRegistry) APIServerConfig() string {
	return filepath.Join(p.ConfigDir(), "api", "dev.yaml")
}

// WorkerConfigDir returns the "workers" directory under ConfigDir.
func (p *PathRegistry) WorkerConfigDir() string { return filepath.Join(p.ConfigDir(), "workers") }
// Test paths
//
// Each accessor is declared once (a second declaration of the same method is
// a redeclaration error).

// TestResultsDir returns the "test_results" directory under RootDir.
func (p *PathRegistry) TestResultsDir() string { return filepath.Join(p.RootDir, "test_results") }

// TempDir returns the "tmp" directory under RootDir.
func (p *PathRegistry) TempDir() string { return filepath.Join(p.RootDir, "tmp") }
// State file paths (for service persistence)
func (p *PathRegistry) JupyterServicesFile() string {
@ -83,13 +87,13 @@ func detectRepoRoot() string {
if err != nil {
return "."
}
// Walk up directory tree looking for go.mod
for {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
return dir
}
parent := filepath.Dir(dir)
if parent == dir {
// Reached root
@ -97,7 +101,7 @@ func detectRepoRoot() string {
}
dir = parent
}
return "."
}

View file

@ -23,7 +23,9 @@ func NewEnvSecretsManager() *EnvSecretsManager { return &EnvSecretsManager{} }
// Get looks up key in the process environment and returns its value.
// An unset variable and a variable set to the empty string are treated the
// same way: both yield a "secret ... not found" error. ctx is not consulted.
func (e *EnvSecretsManager) Get(ctx context.Context, key string) (string, error) {
	if secret := os.Getenv(key); secret != "" {
		return secret, nil
	}
	return "", fmt.Errorf("secret %s not found", key)
}
@ -47,6 +49,8 @@ func (e *EnvSecretsManager) List(ctx context.Context, prefix string) ([]string,
// RedactSecret masks a secret for safe logging.
//
// Secrets of 8 bytes or fewer are replaced entirely by "***". Longer secrets
// keep their first and last 4 bytes with "..." in between. Lengths are
// measured in bytes, not runes.
func RedactSecret(secret string) string {
	const visible = 4 // bytes kept at each end of a long secret

	n := len(secret)
	if n > 2*visible {
		return secret[:visible] + "..." + secret[n-visible:]
	}
	return "***"
}

View file

@ -4,11 +4,11 @@ package domain
type JobStatus string
const (
StatusPending JobStatus = "pending"
StatusQueued JobStatus = "queued"
StatusRunning JobStatus = "running"
StatusCompleted JobStatus = "completed"
StatusFailed JobStatus = "failed"
StatusPending JobStatus = "pending"
StatusQueued JobStatus = "queued"
StatusRunning JobStatus = "running"
StatusCompleted JobStatus = "completed"
StatusFailed JobStatus = "failed"
)
// String returns the string representation of the status

View file

@ -164,6 +164,15 @@ func (l *Logger) Job(ctx context.Context, job string, task string) *Logger {
return l.WithContext(ctx, "job_name", job, "task_id", task)
}
// TraceIDFromContext returns the trace ID stored in ctx under CtxTraceID.
// If the key is absent, or the stored value is not a string, it returns "".
func TraceIDFromContext(ctx context.Context) string {
	// A failed comma-ok assertion leaves traceID as the zero value "".
	traceID, _ := ctx.Value(CtxTraceID).(string)
	return traceID
}
// Fatal logs an error message and exits with status 1.
func (l *Logger) Fatal(msg string, args ...any) {
l.Error(msg, args...)

View file

@ -14,19 +14,19 @@ import (
// EventStore provides append-only event storage using Redis Streams.
// Events are stored chronologically and can be queried for audit trails.
type EventStore struct {
client *redis.Client
ctx context.Context
retentionDays int
maxStreamLen int64
client *redis.Client
ctx context.Context
retentionDays int
maxStreamLen int64
}
// EventStoreConfig holds configuration for the event store.
type EventStoreConfig struct {
RedisAddr string
RedisPassword string
RedisDB int
RetentionDays int // How long to keep events (default: 7)
MaxStreamLen int64 // Max events per task stream (default: 1000)
RedisAddr string
RedisPassword string
RedisDB int
RetentionDays int // How long to keep events (default: 7)
MaxStreamLen int64 // Max events per task stream (default: 1000)
}
// NewEventStore creates a new event store instance.

View file

@ -29,7 +29,7 @@ type metricEvent struct {
}
// Config holds configuration for Queue
type Config struct {
type Config struct {
RedisAddr string
RedisPassword string
RedisDB int
@ -48,9 +48,9 @@ func NewQueue(cfg Config) (*Queue, error) {
}
} else {
opts = &redis.Options{
Addr: cfg.RedisAddr,
Password: cfg.RedisPassword,
DB: cfg.RedisDB,
Addr: cfg.RedisAddr,
Password: cfg.RedisPassword,
DB: cfg.RedisDB,
PoolSize: 50,
MinIdleConns: 10,
MaxRetries: 3,

View file

@ -17,14 +17,14 @@ type Metric struct {
// MetricSummary represents aggregated metric statistics
type MetricSummary struct {
Name string `json:"name"`
Count int64 `json:"count"`
Avg float64 `json:"avg"`
Min float64 `json:"min"`
Max float64 `json:"max"`
Sum float64 `json:"sum"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Name string `json:"name"`
Count int64 `json:"count"`
Avg float64 `json:"avg"`
Min float64 `json:"min"`
Max float64 `json:"max"`
Sum float64 `json:"sum"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
}
// RecordMetric records a metric to the database

View file

@ -0,0 +1,106 @@
// Package errors provides structured error types for the worker.
package errors
import (
	"errors"
	"fmt"
	"sort"
	"strings"
	"time"
)
// ExecutionError provides structured error context for task execution failures.
// It captures the task ID, execution phase, specific operation, root cause,
// and additional context to make debugging easier.
type ExecutionError struct {
	TaskID    string            // The task that failed
	Phase     string            // Current TaskState (queued, preparing, running, collecting)
	Operation string            // Specific operation that failed (e.g., "create_workspace", "fetch_dataset")
	Cause     error             // The underlying error
	Context   map[string]string // Additional context (paths, IDs, etc.)
	Timestamp time.Time         // When the error occurred
}

// Error implements the error interface. The message has the form
// "[<phase>/<operation>] task=<id>: <cause>". Context and Timestamp are not
// part of the message; use ContextString for the context.
func (e ExecutionError) Error() string {
	return fmt.Sprintf("[%s/%s] task=%s: %v", e.Phase, e.Operation, e.TaskID, e.Cause)
}

// Unwrap returns the underlying cause so errors.Is and errors.As can walk
// the error chain.
func (e ExecutionError) Unwrap() error {
	return e.Cause
}

// WithContext records key=value in the error context and returns the error.
//
// The receiver is a value, but Context is a map: when the receiver already
// has a non-nil Context, the entry is written into that same map, so the
// change is also visible through the original value and any other copies.
// Only when Context is nil is a fresh map allocated for the returned copy.
func (e ExecutionError) WithContext(key, value string) ExecutionError {
	if e.Context == nil {
		e.Context = make(map[string]string)
	}
	e.Context[key] = value
	return e
}

// ContextString renders the context as "k1=v1, k2=v2, ...", with entries
// ordered by key so the output is stable across calls. It returns "" when
// there is no context.
func (e ExecutionError) ContextString() string {
	if len(e.Context) == 0 {
		return ""
	}

	// Map iteration order is randomized; sort the keys for deterministic output.
	keys := make([]string, 0, len(e.Context))
	for k := range e.Context {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b strings.Builder
	for i, k := range keys {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(k)
		b.WriteByte('=')
		b.WriteString(e.Context[k])
	}
	return b.String()
}

// NewExecutionError creates an ExecutionError for the given task, phase and
// operation, wrapping cause. Context starts as an empty map and Timestamp is
// set to the current time.
func NewExecutionError(taskID, phase, operation string, cause error) ExecutionError {
	return ExecutionError{
		TaskID:    taskID,
		Phase:     phase,
		Operation: operation,
		Cause:     cause,
		Context:   make(map[string]string),
		Timestamp: time.Now(),
	}
}

// IsExecutionError reports whether err, or any error it wraps, is an
// ExecutionError. It returns false for a nil error.
func IsExecutionError(err error) bool {
	// errors.As follows Unwrap chains, so an ExecutionError that has been
	// wrapped by another error is still detected.
	var target ExecutionError
	return errors.As(err, &target)
}
// Common error operations for the worker lifecycle.
// These are untyped string constants, grouped by the stage of the lifecycle
// they belong to. They are presumably intended as values for
// ExecutionError.Operation (its field comment cites "create_workspace" and
// "fetch_dataset" as examples) — confirm against callers.
const (
// Preparation operations
OpCreateWorkspace = "create_workspace"
OpFetchDataset = "fetch_dataset"
OpMountVolume = "mount_volume"
OpSetupEnvironment = "setup_environment"
OpStageSnapshot = "stage_snapshot"
// Execution operations
OpStartContainer = "start_container"
OpExecuteCommand = "execute_command"
OpMonitorExecution = "monitor_execution"
// Collection operations
OpCollectResults = "collect_results"
OpUploadArtifacts = "upload_artifacts"
OpCleanupWorkspace = "cleanup_workspace"
// Validation operations
OpValidateManifest = "validate_manifest"
OpCheckResources = "check_resources"
OpVerifyProvenance = "verify_provenance"
)
// Common phase names (should match TaskState values).
// These are plain untyped string constants, so nothing in this package
// enforces that they stay in sync with the TaskState definitions; they must
// be updated by hand if those change.
const (
PhaseQueued = "queued"
PhasePreparing = "preparing"
PhaseRunning = "running"
PhaseCollecting = "collecting"
PhaseCompleted = "completed"
PhaseFailed = "failed"
)

View file

@ -126,6 +126,9 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
cfg.LocalMode,
)
// Create state manager for task lifecycle management
stateMgr := lifecycle.NewStateManager(nil) // Can pass audit logger if available
// Create job runner
jobRunner := executor.NewJobRunner(
localExecutor,
@ -140,6 +143,7 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
exec,
metricsObj,
logger,
stateMgr,
)
// Create resource manager

View file

@ -33,6 +33,7 @@ type RunLoop struct {
executor TaskExecutor
metrics MetricsRecorder
logger Logger
stateMgr *StateManager
// State management
running map[string]context.CancelFunc
@ -64,6 +65,7 @@ func NewRunLoop(
executor TaskExecutor,
metrics MetricsRecorder,
logger Logger,
stateMgr *StateManager,
) *RunLoop {
ctx, cancel := context.WithCancel(context.Background())
return &RunLoop{
@ -72,6 +74,7 @@ func NewRunLoop(
executor: executor,
metrics: metrics,
logger: logger,
stateMgr: stateMgr,
running: make(map[string]context.CancelFunc),
ctx: ctx,
cancel: cancel,
@ -185,17 +188,45 @@ func (r *RunLoop) waitForTasks() {
func (r *RunLoop) executeTask(task *queue.Task) {
defer r.releaseRunningSlot(task.ID)
// Transition to preparing state
if r.stateMgr != nil {
if err := r.stateMgr.Transition(task, StatePreparing); err != nil {
r.logger.Error("failed to transition task state", "task_id", task.ID, "error", err)
}
}
r.metrics.RecordTaskStart()
defer r.metrics.RecordTaskCompletion()
r.logger.Info("starting task", "task_id", task.ID, "job_name", task.JobName)
// Transition to running state
if r.stateMgr != nil {
if err := r.stateMgr.Transition(task, StateRunning); err != nil {
r.logger.Error("failed to transition task state", "task_id", task.ID, "error", err)
}
}
taskCtx, cancel := context.WithTimeout(r.ctx, 24*time.Hour)
defer cancel()
if err := r.executor.Execute(taskCtx, task); err != nil {
r.logger.Error("task execution failed", "task_id", task.ID, "error", err)
r.metrics.RecordTaskFailure()
// Transition to failed state
if r.stateMgr != nil {
if stateErr := r.stateMgr.Transition(task, StateFailed); stateErr != nil {
r.logger.Error("failed to transition task to failed state", "task_id", task.ID, "error", stateErr)
}
}
} else {
// Transition to completed state
if r.stateMgr != nil {
if err := r.stateMgr.Transition(task, StateCompleted); err != nil {
r.logger.Error("failed to transition task to completed state", "task_id", task.ID, "error", err)
}
}
}
_ = r.queue.ReleaseLease(task.ID, r.config.WorkerID)

View file

@ -0,0 +1,130 @@
// Package lifecycle provides task lifecycle management with explicit state transitions.
package lifecycle
import (
"fmt"
"slices"
"time"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/domain"
)
// TaskState represents the current state of a task in its lifecycle.
// These states form a finite state machine with valid transitions defined below.
// The underlying type is string; Transition converts to and from the task's
// string Status field.
type TaskState string
const (
// StateQueued indicates the task is waiting to be picked up by a worker.
StateQueued TaskState = "queued"
// StatePreparing indicates the task is being prepared (workspace setup, data staging).
StatePreparing TaskState = "preparing"
// StateRunning indicates the task is currently executing in a container.
StateRunning TaskState = "running"
// StateCollecting indicates the task has finished execution and results are being collected.
StateCollecting TaskState = "collecting"
// StateCompleted indicates the task finished successfully.
// It is a terminal state: no transitions out of it are allowed.
StateCompleted TaskState = "completed"
// StateFailed indicates the task failed during execution.
// It is a terminal state: no transitions out of it are allowed.
StateFailed TaskState = "failed"
)
// ValidTransitions defines the allowed state transitions.
// The key is the "from" state, the value is a list of valid "to" states.
// This enforces that state transitions follow the expected lifecycle.
//
// The happy path is strictly linear:
// queued -> preparing -> running -> collecting -> completed.
// Every non-terminal state may also move to failed. Note that running cannot
// move directly to completed; it must pass through collecting first.
// Terminal states (completed, failed) are present with an empty list, so they
// are recognized as known states that allow no further transitions. A state
// that is not a key in this map is treated as unknown by the lookup helpers.
var ValidTransitions = map[TaskState][]TaskState{
StateQueued: {StatePreparing, StateFailed},
StatePreparing: {StateRunning, StateFailed},
StateRunning: {StateCollecting, StateFailed},
StateCollecting: {StateCompleted, StateFailed},
StateCompleted: {},
StateFailed: {},
}
// StateTransitionError reports a rejected state change. From is the state the
// task was in and To is the state that was requested.
type StateTransitionError struct {
	From TaskState
	To   TaskState
}

// Error implements the error interface, formatting the rejected change as
// "invalid state transition: <from> -> <to>".
func (e StateTransitionError) Error() string {
	src, dst := string(e.From), string(e.To)
	return fmt.Sprintf("invalid state transition: %s -> %s", src, dst)
}
// StateManager applies validated state transitions to tasks and, when an
// audit logger is configured, records each transition.
type StateManager struct {
	enabled bool          // true when an audit logger was supplied
	auditor *audit.Logger // destination for transition audit events; may be nil
}

// NewStateManager builds a StateManager around auditor. Passing nil is
// allowed: transitions are still validated and applied, but no audit events
// are emitted.
func NewStateManager(auditor *audit.Logger) *StateManager {
	sm := &StateManager{auditor: auditor}
	sm.enabled = auditor != nil
	return sm
}
// Transition attempts to move task from its current state to the state to.
//
// The current state is read from task.Status. The change is checked against
// ValidTransitions; if it is not allowed, a StateTransitionError is returned
// and the task is left unmodified. On success an audit event is emitted (when
// an audit logger is configured) and task.Status is set to the new state.
//
// A nil task yields an error instead of a nil-pointer panic.
func (sm *StateManager) Transition(task *domain.Task, to TaskState) error {
	if task == nil {
		return fmt.Errorf("cannot transition nil task to state %s", to)
	}

	from := TaskState(task.Status)

	// Validate the transition before touching the task or the audit log.
	if err := sm.validateTransition(from, to); err != nil {
		return err
	}

	// Audit the transition before updating, so the event carries both states.
	if sm.enabled && sm.auditor != nil {
		sm.auditor.Log(audit.Event{
			// NOTE(review): every transition, including those to completed or
			// failed, is recorded with the EventJobStarted type; the actual
			// states are only visible in Metadata. Confirm this is intended.
			EventType: audit.EventJobStarted,
			Timestamp: time.Now(),
			Resource:  task.ID,
			Action:    "task_state_change",
			Success:   true,
			Metadata: map[string]any{
				"job_name":  task.JobName,
				"old_state": string(from),
				"new_state": string(to),
			},
		})
	}

	// Update task state.
	task.Status = string(to)
	return nil
}
// validateTransition returns nil when moving from -> to is permitted by
// ValidTransitions, and a StateTransitionError otherwise. An unknown "from"
// state is rejected in the same way as a disallowed "to" state.
func (sm *StateManager) validateTransition(from, to TaskState) error {
	if !CanTransition(from, to) {
		return StateTransitionError{From: from, To: to}
	}
	return nil
}
// IsTerminalState reports whether state is one of the end states of the
// lifecycle (completed or failed), from which no further transitions are
// allowed.
func IsTerminalState(state TaskState) bool {
	switch state {
	case StateCompleted, StateFailed:
		return true
	default:
		return false
	}
}
// CanTransition reports whether ValidTransitions permits moving from -> to.
// It returns false when from is not a known state.
func CanTransition(from, to TaskState) bool {
	// Looking up an unknown state yields a nil slice, which contains nothing,
	// so no separate existence check is needed.
	targets := ValidTransitions[from]
	return slices.Contains(targets, to)
}

View file

@ -18,11 +18,11 @@ func BenchmarkDatasetSizeComparison(b *testing.B) {
numFiles int
totalMB int
}{
{"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB
{"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB
{"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB
{"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB
{"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB
{"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB
{"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB
{"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB
{"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB
{"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB
}
for _, tc := range sizes {

View file

@ -11,6 +11,7 @@ import (
// - Regex matching is CPU-intensive
// - High volume log pipelines process thousands of messages/sec
// - C++ can use Hyperscan/RE2 for parallel regex matching
//
// Expected speedup: 3-5x for high-volume logging
func BenchmarkLogSanitizeMessage(b *testing.B) {
// Test messages with various sensitive data patterns

View file

@ -1,68 +1,68 @@
package worker_test
import (
"os"
"path/filepath"
"testing"
"os"
"path/filepath"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/worker"
)
func TestScanArtifacts_SkipsKnownPathsAndLogs(t *testing.T) {
runDir := t.TempDir()
runDir := t.TempDir()
mustWrite := func(rel string, data []byte) {
p := filepath.Join(runDir, rel)
if err := os.MkdirAll(filepath.Dir(p), 0750); err != nil {
t.Fatalf("mkdir: %v", err)
}
if err := os.WriteFile(p, data, 0600); err != nil {
t.Fatalf("write file: %v", err)
}
}
mustWrite := func(rel string, data []byte) {
p := filepath.Join(runDir, rel)
if err := os.MkdirAll(filepath.Dir(p), 0750); err != nil {
t.Fatalf("mkdir: %v", err)
}
if err := os.WriteFile(p, data, 0600); err != nil {
t.Fatalf("write file: %v", err)
}
}
mustWrite("run_manifest.json", []byte("{}"))
mustWrite("output.log", []byte("log"))
mustWrite("code/ignored.txt", []byte("ignore"))
mustWrite("snapshot/ignored.bin", []byte("ignore"))
mustWrite("run_manifest.json", []byte("{}"))
mustWrite("output.log", []byte("log"))
mustWrite("code/ignored.txt", []byte("ignore"))
mustWrite("snapshot/ignored.bin", []byte("ignore"))
mustWrite("results/metrics.jsonl", []byte("m"))
mustWrite("checkpoints/best.pt", []byte("checkpoint"))
mustWrite("plots/loss.png", []byte("png"))
mustWrite("results/metrics.jsonl", []byte("m"))
mustWrite("checkpoints/best.pt", []byte("checkpoint"))
mustWrite("plots/loss.png", []byte("png"))
art, err := worker.ScanArtifacts(runDir)
if err != nil {
t.Fatalf("scanArtifacts: %v", err)
}
if art == nil {
t.Fatalf("expected artifacts")
}
art, err := worker.ScanArtifacts(runDir)
if err != nil {
t.Fatalf("scanArtifacts: %v", err)
}
if art == nil {
t.Fatalf("expected artifacts")
}
paths := make([]string, 0, len(art.Files))
var total int64
for _, f := range art.Files {
paths = append(paths, f.Path)
total += f.SizeBytes
}
paths := make([]string, 0, len(art.Files))
var total int64
for _, f := range art.Files {
paths = append(paths, f.Path)
total += f.SizeBytes
}
want := []string{
"checkpoints/best.pt",
"plots/loss.png",
"results/metrics.jsonl",
}
if len(paths) != len(want) {
t.Fatalf("expected %d files, got %d: %v", len(want), len(paths), paths)
}
for i := range want {
if paths[i] != want[i] {
t.Fatalf("expected paths[%d]=%q, got %q", i, want[i], paths[i])
}
}
want := []string{
"checkpoints/best.pt",
"plots/loss.png",
"results/metrics.jsonl",
}
if len(paths) != len(want) {
t.Fatalf("expected %d files, got %d: %v", len(want), len(paths), paths)
}
for i := range want {
if paths[i] != want[i] {
t.Fatalf("expected paths[%d]=%q, got %q", i, want[i], paths[i])
}
}
if art.TotalSizeBytes != total {
t.Fatalf("expected total_size_bytes=%d, got %d", total, art.TotalSizeBytes)
}
if art.DiscoveryTime.IsZero() {
t.Fatalf("expected discovery_time")
}
if art.TotalSizeBytes != total {
t.Fatalf("expected total_size_bytes=%d, got %d", total, art.TotalSizeBytes)
}
if art.DiscoveryTime.IsZero() {
t.Fatalf("expected discovery_time")
}
}