refactor(api): internal refactoring for TUI and worker modules
- Refactor internal/worker and internal/queue packages - Update cmd/tui for monitoring interface - Update test configurations
This commit is contained in:
parent
7583932897
commit
23e5f3d1dc
28 changed files with 1620 additions and 636 deletions
|
|
@ -28,19 +28,19 @@ func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportS
|
|||
// Returns the path to the exported file
|
||||
func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) {
|
||||
s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize)
|
||||
|
||||
|
||||
// Placeholder - actual implementation would call API
|
||||
// POST /api/jobs/{id}/export?anonymize=true
|
||||
|
||||
|
||||
exportPath := fmt.Sprintf("/tmp/%s_export_%d.tar.gz", jobName, time.Now().Unix())
|
||||
|
||||
|
||||
s.logger.Info("export complete", "job", jobName, "path", exportPath)
|
||||
return exportPath, nil
|
||||
}
|
||||
|
||||
// ExportOptions contains options for export.
type ExportOptions struct {
	Anonymize   bool
	IncludeLogs bool
	IncludeData bool
}
|
||||
|
|
|
|||
|
|
@ -16,37 +16,37 @@ import (
|
|||
|
||||
// WebSocketClient manages real-time updates from the server.
type WebSocketClient struct {
	conn      *websocket.Conn
	serverURL string
	apiKey    string
	logger    *logging.Logger

	// Channels for different update types
	jobUpdates    chan model.JobUpdateMsg
	gpuUpdates    chan model.GPUUpdateMsg
	statusUpdates chan model.StatusMsg

	// Control
	ctx       context.Context
	cancel    context.CancelFunc
	connected bool
}
|
||||
|
||||
// JobUpdateMsg represents a real-time job status update.
// NOTE(review): this type is converted to model.JobUpdateMsg elsewhere in
// this package, so the field sets must stay identical — verify when editing.
type JobUpdateMsg struct {
	JobName  string `json:"job_name"`
	Status   string `json:"status"`
	TaskID   string `json:"task_id"`
	Progress int    `json:"progress"`
}
|
||||
|
||||
// GPUUpdateMsg represents a real-time GPU status update.
// NOTE(review): converted to model.GPUUpdateMsg elsewhere in this package;
// the field sets must stay identical — verify when editing.
type GPUUpdateMsg struct {
	DeviceID    int   `json:"device_id"`
	Utilization int   `json:"utilization"`
	MemoryUsed  int64 `json:"memory_used"`
	MemoryTotal int64 `json:"memory_total"`
	Temperature int   `json:"temperature"`
}
|
||||
|
||||
// NewWebSocketClient creates a new WebSocket client
|
||||
|
|
@ -71,26 +71,26 @@ func (c *WebSocketClient) Connect() error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("invalid server URL: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Convert http/https to ws/wss
|
||||
wsScheme := "ws"
|
||||
if u.Scheme == "https" {
|
||||
wsScheme = "wss"
|
||||
}
|
||||
wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)
|
||||
|
||||
|
||||
// Create dialer with timeout
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
Subprotocols: []string{"fetchml-v1"},
|
||||
}
|
||||
|
||||
|
||||
// Add API key to headers
|
||||
headers := http.Header{}
|
||||
if c.apiKey != "" {
|
||||
headers.Set("X-API-Key", c.apiKey)
|
||||
}
|
||||
|
||||
|
||||
conn, resp, err := dialer.Dial(wsURL, headers)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
|
|
@ -98,17 +98,17 @@ func (c *WebSocketClient) Connect() error {
|
|||
}
|
||||
return fmt.Errorf("websocket dial failed: %w", err)
|
||||
}
|
||||
|
||||
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.logger.Info("websocket connected", "url", wsURL)
|
||||
|
||||
|
||||
// Start message handler
|
||||
go c.messageHandler()
|
||||
|
||||
|
||||
// Start heartbeat
|
||||
go c.heartbeat()
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -134,15 +134,15 @@ func (c *WebSocketClient) messageHandler() {
|
|||
return
|
||||
default:
|
||||
}
|
||||
|
||||
|
||||
if c.conn == nil {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Set read deadline
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
|
||||
// Read message
|
||||
messageType, data, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
|
|
@ -150,7 +150,7 @@ func (c *WebSocketClient) messageHandler() {
|
|||
c.logger.Error("websocket read error", "error", err)
|
||||
}
|
||||
c.connected = false
|
||||
|
||||
|
||||
// Attempt reconnect
|
||||
time.Sleep(5 * time.Second)
|
||||
if err := c.Connect(); err != nil {
|
||||
|
|
@ -158,7 +158,7 @@ func (c *WebSocketClient) messageHandler() {
|
|||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Handle binary vs text messages
|
||||
if messageType == websocket.BinaryMessage {
|
||||
c.handleBinaryMessage(data)
|
||||
|
|
@ -173,11 +173,11 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
|
|||
if len(data) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Binary protocol: [opcode:1][data...]
|
||||
opcode := data[0]
|
||||
payload := data[1:]
|
||||
|
||||
|
||||
switch opcode {
|
||||
case 0x01: // Job update
|
||||
var update JobUpdateMsg
|
||||
|
|
@ -186,7 +186,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
|
|||
return
|
||||
}
|
||||
c.jobUpdates <- model.JobUpdateMsg(update)
|
||||
|
||||
|
||||
case 0x02: // GPU update
|
||||
var update GPUUpdateMsg
|
||||
if err := json.Unmarshal(payload, &update); err != nil {
|
||||
|
|
@ -194,7 +194,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
|
|||
return
|
||||
}
|
||||
c.gpuUpdates <- model.GPUUpdateMsg(update)
|
||||
|
||||
|
||||
case 0x03: // Status message
|
||||
var status model.StatusMsg
|
||||
if err := json.Unmarshal(payload, &status); err != nil {
|
||||
|
|
@ -212,7 +212,7 @@ func (c *WebSocketClient) handleTextMessage(data []byte) {
|
|||
c.logger.Error("failed to unmarshal text message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
msgType, _ := msg["type"].(string)
|
||||
switch msgType {
|
||||
case "job_update":
|
||||
|
|
@ -228,7 +228,7 @@ func (c *WebSocketClient) handleTextMessage(data []byte) {
|
|||
func (c *WebSocketClient) heartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
|
|
@ -249,12 +249,12 @@ func (c *WebSocketClient) Subscribe(channels ...string) error {
|
|||
if !c.connected {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
|
||||
subMsg := map[string]interface{}{
|
||||
"action": "subscribe",
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
|
||||
data, _ := json.Marshal(subMsg)
|
||||
return c.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
|
|
|||
156
internal/api/adapter.go
Normal file
156
internal/api/adapter.go
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
// Package api provides HTTP handlers and OpenAPI-generated server interface implementations
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// HandlerAdapter implements the generated ServerInterface using existing
// handlers. Any handler field may be nil, in which case the corresponding
// endpoints report the service as unavailable.
type HandlerAdapter struct {
	jobsHandler     *jobs.Handler
	jupyterHandler  *jupyter.Handler
	datasetsHandler *datasets.Handler
}
|
||||
|
||||
// NewHandlerAdapter creates a new handler adapter
|
||||
func NewHandlerAdapter(
|
||||
jobsHandler *jobs.Handler,
|
||||
jupyterHandler *jupyter.Handler,
|
||||
datasetsHandler *datasets.Handler,
|
||||
) *HandlerAdapter {
|
||||
return &HandlerAdapter{
|
||||
jobsHandler: jobsHandler,
|
||||
jupyterHandler: jupyterHandler,
|
||||
datasetsHandler: datasetsHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure HandlerAdapter implements the generated interface at compile time.
var _ ServerInterface = (*HandlerAdapter)(nil)
|
||||
|
||||
// toHTTPHandler converts echo.Context to standard HTTP handler
|
||||
func toHTTPHandler(h func(http.ResponseWriter, *http.Request)) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
h(c.Response().Writer, c.Request())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetHealth implements the health check endpoint
|
||||
func (a *HandlerAdapter) GetHealth(ctx echo.Context) error {
|
||||
return ctx.String(200, "OK\n")
|
||||
}
|
||||
|
||||
// GetV1Experiments lists all experiments
|
||||
func (a *HandlerAdapter) GetV1Experiments(ctx echo.Context) error {
|
||||
return ctx.JSON(200, map[string]any{
|
||||
"experiments": []any{},
|
||||
"message": "Not yet implemented",
|
||||
})
|
||||
}
|
||||
|
||||
// PostV1Experiments creates a new experiment
|
||||
func (a *HandlerAdapter) PostV1Experiments(ctx echo.Context) error {
|
||||
return ctx.JSON(201, map[string]any{
|
||||
"message": "Not yet implemented",
|
||||
})
|
||||
}
|
||||
|
||||
// GetV1JupyterServices lists all Jupyter services
|
||||
func (a *HandlerAdapter) GetV1JupyterServices(ctx echo.Context) error {
|
||||
if a.jupyterHandler == nil {
|
||||
return ctx.JSON(503, map[string]any{
|
||||
"error": "Jupyter service not available",
|
||||
"code": "SERVICE_UNAVAILABLE",
|
||||
})
|
||||
}
|
||||
return toHTTPHandler(a.jupyterHandler.ListServicesHTTP)(ctx)
|
||||
}
|
||||
|
||||
// PostV1JupyterServices starts a new Jupyter service
|
||||
func (a *HandlerAdapter) PostV1JupyterServices(ctx echo.Context) error {
|
||||
if a.jupyterHandler == nil {
|
||||
return ctx.JSON(503, map[string]any{
|
||||
"error": "Jupyter service not available",
|
||||
"code": "SERVICE_UNAVAILABLE",
|
||||
})
|
||||
}
|
||||
return toHTTPHandler(a.jupyterHandler.StartServiceHTTP)(ctx)
|
||||
}
|
||||
|
||||
// DeleteV1JupyterServicesServiceId stops a Jupyter service
|
||||
func (a *HandlerAdapter) DeleteV1JupyterServicesServiceId(ctx echo.Context, serviceId string) error {
|
||||
if a.jupyterHandler == nil {
|
||||
return ctx.JSON(503, map[string]any{
|
||||
"error": "Jupyter service not available",
|
||||
"code": "SERVICE_UNAVAILABLE",
|
||||
})
|
||||
}
|
||||
// TODO: Implement when StopServiceHTTP is available
|
||||
return ctx.JSON(501, map[string]any{
|
||||
"error": "Not implemented",
|
||||
"code": "NOT_IMPLEMENTED",
|
||||
"message": "Jupyter service stop not yet implemented via REST API",
|
||||
})
|
||||
}
|
||||
|
||||
// GetV1Queue returns queue status
|
||||
func (a *HandlerAdapter) GetV1Queue(ctx echo.Context) error {
|
||||
return ctx.JSON(200, map[string]any{
|
||||
"status": "healthy",
|
||||
"pending": 0,
|
||||
"running": 0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetV1Tasks lists all tasks
|
||||
func (a *HandlerAdapter) GetV1Tasks(ctx echo.Context, params GetV1TasksParams) error {
|
||||
if a.jobsHandler == nil {
|
||||
return ctx.JSON(503, map[string]any{
|
||||
"error": "Jobs handler not available",
|
||||
"code": "SERVICE_UNAVAILABLE",
|
||||
})
|
||||
}
|
||||
return toHTTPHandler(a.jobsHandler.ListAllJobsHTTP)(ctx)
|
||||
}
|
||||
|
||||
// PostV1Tasks creates a new task
|
||||
func (a *HandlerAdapter) PostV1Tasks(ctx echo.Context) error {
|
||||
return ctx.JSON(501, map[string]any{
|
||||
"error": "Not implemented",
|
||||
"code": "NOT_IMPLEMENTED",
|
||||
"message": "Task creation via REST API not yet implemented - use WebSocket",
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteV1TasksTaskId cancels/deletes a task
|
||||
func (a *HandlerAdapter) DeleteV1TasksTaskId(ctx echo.Context, taskId string) error {
|
||||
return ctx.JSON(501, map[string]any{
|
||||
"error": "Not implemented",
|
||||
"code": "NOT_IMPLEMENTED",
|
||||
"message": "Task cancellation via REST API not yet implemented - use WebSocket",
|
||||
})
|
||||
}
|
||||
|
||||
// GetV1TasksTaskId gets task details
|
||||
func (a *HandlerAdapter) GetV1TasksTaskId(ctx echo.Context, taskId string) error {
|
||||
return ctx.JSON(501, map[string]any{
|
||||
"error": "Not implemented",
|
||||
"code": "NOT_IMPLEMENTED",
|
||||
"message": "Task details via REST API not yet implemented - use WebSocket",
|
||||
})
|
||||
}
|
||||
|
||||
// GetWs handles WebSocket connections
|
||||
func (a *HandlerAdapter) GetWs(ctx echo.Context) error {
|
||||
return ctx.JSON(426, map[string]any{
|
||||
"error": "WebSocket connection required",
|
||||
"code": "UPGRADE_REQUIRED",
|
||||
"message": "Use WebSocket protocol to connect to this endpoint",
|
||||
})
|
||||
}
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
// Package api provides error handling utilities for the API
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorResponse represents a standardized error response.
// Code carries one of the ErrCode* byte constants defined in this package.
type ErrorResponse struct {
	Error     bool      `json:"error"`
	Code      byte      `json:"code"`
	Message   string    `json:"message"`
	Details   string    `json:"details,omitempty"`
	Timestamp time.Time `json:"timestamp"`
	RequestID string    `json:"request_id,omitempty"`
}

// SuccessResponse represents a standardized success response.
type SuccessResponse struct {
	Success   bool        `json:"success"`
	Data      interface{} `json:"data,omitempty"`
	Timestamp time.Time   `json:"timestamp"`
}
|
||||
|
||||
// WriteError writes a standardized error response
|
||||
func WriteError(w http.ResponseWriter, code byte, message, details string, statusCode int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
response := ErrorResponse{
|
||||
Error: true,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// WriteSuccess writes a standardized success response
|
||||
func WriteSuccess(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := SuccessResponse{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// Common error codes for API responses.
// Grouped by high nibble: 0x0x request/auth errors, 0x1x infrastructure
// errors, 0x2x job lifecycle errors, 0x3x resource/configuration errors.
const (
	ErrCodeUnknownError          = 0x00
	ErrCodeInvalidRequest        = 0x01
	ErrCodeAuthenticationFailed  = 0x02
	ErrCodePermissionDenied      = 0x03
	ErrCodeResourceNotFound      = 0x04
	ErrCodeResourceAlreadyExists = 0x05
	ErrCodeServerOverloaded      = 0x10
	ErrCodeDatabaseError         = 0x11
	ErrCodeNetworkError          = 0x12
	ErrCodeStorageError          = 0x13
	ErrCodeTimeout               = 0x14
	ErrCodeJobNotFound           = 0x20
	ErrCodeJobAlreadyRunning     = 0x21
	ErrCodeJobFailedToStart      = 0x22
	ErrCodeJobExecutionFailed    = 0x23
	ErrCodeJobCancelled          = 0x24
	ErrCodeOutOfMemory           = 0x30
	ErrCodeDiskFull              = 0x31
	ErrCodeInvalidConfiguration  = 0x32
	ErrCodeServiceUnavailable    = 0x33
)
|
||||
|
||||
// HTTP status code mappings.
// These are straight aliases of the net/http constants, kept so callers
// in this package can reference statuses without importing net/http.
const (
	StatusBadRequest          = http.StatusBadRequest
	StatusUnauthorized        = http.StatusUnauthorized
	StatusForbidden           = http.StatusForbidden
	StatusNotFound            = http.StatusNotFound
	StatusConflict            = http.StatusConflict
	StatusInternalServerError = http.StatusInternalServerError
	StatusServiceUnavailable  = http.StatusServiceUnavailable
	StatusTooManyRequests     = http.StatusTooManyRequests
)
|
||||
|
||||
// ErrorCodeToHTTPStatus maps API error codes to HTTP status codes
|
||||
func ErrorCodeToHTTPStatus(code byte) int {
|
||||
switch code {
|
||||
case ErrCodeInvalidRequest:
|
||||
return StatusBadRequest
|
||||
case ErrCodeAuthenticationFailed:
|
||||
return StatusUnauthorized
|
||||
case ErrCodePermissionDenied:
|
||||
return StatusForbidden
|
||||
case ErrCodeResourceNotFound, ErrCodeJobNotFound:
|
||||
return StatusNotFound
|
||||
case ErrCodeResourceAlreadyExists, ErrCodeJobAlreadyRunning:
|
||||
return StatusConflict
|
||||
case ErrCodeServerOverloaded, ErrCodeServiceUnavailable:
|
||||
return StatusServiceUnavailable
|
||||
case ErrCodeDatabaseError, ErrCodeNetworkError, ErrCodeStorageError:
|
||||
return StatusInternalServerError
|
||||
default:
|
||||
return StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorResponse creates a new error response with the given details
|
||||
func NewErrorResponse(code byte, message, details string) ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Error: true,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewSuccessResponse creates a new success response with the given data
|
||||
func NewSuccessResponse(data interface{}) SuccessResponse {
|
||||
return SuccessResponse{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
|
||||
apimiddleware "github.com/jfraeys/fetch_ml/internal/api/middleware"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
|
|
@ -49,8 +50,8 @@ func (s *Server) initializeComponents() error {
|
|||
// Initialize Jupyter service manager
|
||||
s.initJupyterServiceManager()
|
||||
|
||||
// Initialize handlers
|
||||
s.handlers = NewHandlers(s.expManager, nil, s.logger)
|
||||
// Initialize OpenAPI validation middleware (optional, doesn't block startup)
|
||||
s.initValidationMiddleware()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -250,6 +251,24 @@ func (s *Server) initAuditLogger() *audit.Logger {
|
|||
return al
|
||||
}
|
||||
|
||||
// initValidationMiddleware initializes the OpenAPI validation middleware
|
||||
func (s *Server) initValidationMiddleware() {
|
||||
// Only initialize if OpenAPI spec exists
|
||||
if _, err := os.Stat("api/openapi.yaml"); err != nil {
|
||||
s.logger.Debug("OpenAPI spec not found, skipping validation middleware")
|
||||
return
|
||||
}
|
||||
|
||||
vm, err := apimiddleware.NewValidationMiddleware("api/openapi.yaml")
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to initialize OpenAPI validation middleware", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.validationMiddleware = vm
|
||||
s.logger.Info("OpenAPI validation middleware initialized")
|
||||
}
|
||||
|
||||
// getSecurityConfig extracts security config from server config
|
||||
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
|
||||
return &config.SecurityConfig{
|
||||
|
|
|
|||
|
|
@ -1,313 +0,0 @@
|
|||
// Package api provides HTTP handlers for the fetch_ml API server
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// Handlers groups all HTTP handlers.
// jupyterServiceMgr may be nil; Jupyter routes are then not registered.
type Handlers struct {
	expManager        *experiment.Manager
	jupyterServiceMgr *jupyter.ServiceManager
	logger            *logging.Logger
}
|
||||
|
||||
// NewHandlers creates a new handler group
|
||||
func NewHandlers(
|
||||
expManager *experiment.Manager,
|
||||
jupyterServiceMgr *jupyter.ServiceManager,
|
||||
logger *logging.Logger,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
expManager: expManager,
|
||||
jupyterServiceMgr: jupyterServiceMgr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandlers registers all HTTP handlers with the mux.
// Jupyter routes are only mounted when a service manager is configured.
func (h *Handlers) RegisterHandlers(mux *http.ServeMux) {
	// Health check endpoints
	// NOTE(review): handleHealth exists in this file but is not registered
	// here — presumably health is served by another router; confirm.
	mux.HandleFunc("/db-status", h.handleDBStatus)

	// Jupyter service endpoints
	if h.jupyterServiceMgr != nil {
		mux.HandleFunc("/api/jupyter/services", h.handleJupyterServices)
		mux.HandleFunc("/api/jupyter/experiments/link", h.handleJupyterExperimentLink)
		mux.HandleFunc("/api/jupyter/experiments/sync", h.handleJupyterExperimentSync)
	}
}
|
||||
|
||||
// handleHealth responds with a simple health check
|
||||
func (h *Handlers) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(w, "OK\n")
|
||||
}
|
||||
|
||||
// handleDBStatus responds with database connection status.
// Currently a stub: it always reports "unknown" because the DB instance
// is not available to this handler group.
func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// This would need the DB instance passed to handlers
	// For now, return a basic response
	response := map[string]any{
		"status":  "unknown",
		"message": "Database status check not implemented",
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(helpers.MarshalJSONOrEmpty(response)); err != nil {
		h.logger.Error("failed to write response", "error", err)
	}
}
|
||||
|
||||
// handleJupyterServices handles Jupyter service management requests
|
||||
func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if !user.HasPermission("jupyter:read") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.listJupyterServices(w, r)
|
||||
case http.MethodPost:
|
||||
if !user.HasPermission("jupyter:manage") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.startJupyterService(w, r)
|
||||
case http.MethodDelete:
|
||||
if !user.HasPermission("jupyter:manage") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.stopJupyterService(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// listJupyterServices lists all Jupyter services as a JSON array.
// Caller (handleJupyterServices) has already checked auth/permissions.
func (h *Handlers) listJupyterServices(w http.ResponseWriter, _ *http.Request) {
	services := h.jupyterServiceMgr.ListServices()
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(helpers.MarshalJSONOrEmpty(services)); err != nil {
		h.logger.Error("failed to write response", "error", err)
	}
}
|
||||
|
||||
// startJupyterService starts a new Jupyter service from the JSON-encoded
// jupyter.StartRequest in the body and returns the created service (201).
// Caller has already checked auth/permissions.
func (h *Handlers) startJupyterService(w http.ResponseWriter, r *http.Request) {
	var req jupyter.StartRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	ctx := r.Context()
	service, err := h.jupyterServiceMgr.StartService(ctx, &req)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to start service: %v", err), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusCreated)
	if _, err := w.Write(helpers.MarshalJSONOrEmpty(service)); err != nil {
		h.logger.Error("failed to write response", "error", err)
	}
}
|
||||
|
||||
// stopJupyterService stops the Jupyter service named by the "id" query
// parameter and confirms with a {"status":"stopped"} body.
// Caller has already checked auth/permissions.
func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) {
	serviceID := r.URL.Query().Get("id")
	if serviceID == "" {
		http.Error(w, "Service ID is required", http.StatusBadRequest)
		return
	}

	ctx := r.Context()
	if err := h.jupyterServiceMgr.StopService(ctx, serviceID); err != nil {
		http.Error(w, fmt.Sprintf("Failed to stop service: %v", err), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(map[string]string{
		"status": "stopped",
		"id":     serviceID,
	}); err != nil {
		h.logger.Error("failed to encode response", "error", err)
	}
}
|
||||
|
||||
// handleJupyterExperimentLink handles linking Jupyter workspaces with
// experiments. POST only; requires a user with "jupyter:manage".
// On success it returns 201 with the linked workspace's metadata.
func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	user := auth.GetUserFromContext(r.Context())
	if user == nil {
		http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
		return
	}
	if !user.HasPermission("jupyter:manage") {
		http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
		return
	}

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// workspace and experiment_id are mandatory; service_id optionally
	// associates the link with a specific running service.
	var req struct {
		Workspace    string `json:"workspace"`
		ExperimentID string `json:"experiment_id"`
		ServiceID    string `json:"service_id,omitempty"`
	}

	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.Workspace == "" || req.ExperimentID == "" {
		http.Error(w, "Workspace and experiment_id are required", http.StatusBadRequest)
		return
	}

	if !h.expManager.ExperimentExists(req.ExperimentID) {
		http.Error(w, "Experiment not found", http.StatusNotFound)
		return
	}

	// Link workspace with experiment using service manager
	if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(
		req.Workspace,
		req.ExperimentID,
		req.ServiceID,
	); err != nil {
		http.Error(w, fmt.Sprintf("Failed to link workspace: %v", err), http.StatusInternalServerError)
		return
	}

	// Get workspace metadata to return
	metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace)
	if err != nil {
		http.Error(
			w,
			fmt.Sprintf("Failed to get workspace metadata: %v", err),
			http.StatusInternalServerError,
		)
		return
	}

	h.logger.Info("jupyter workspace linked with experiment",
		"workspace", req.Workspace,
		"experiment_id", req.ExperimentID,
		"service_id", req.ServiceID)

	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(map[string]interface{}{
		"status": "linked",
		"data":   metadata,
	}); err != nil {
		h.logger.Error("failed to encode response", "error", err)
	}
}
|
||||
|
||||
// handleJupyterExperimentSync handles synchronization between Jupyter
// workspaces and experiments. POST only; requires "jupyter:manage".
// On success it returns 200 with a sync summary and updated metadata.
func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	user := auth.GetUserFromContext(r.Context())
	if user == nil {
		http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
		return
	}
	if !user.HasPermission("jupyter:manage") {
		http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
		return
	}

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// NOTE(review): Direction is required but its value ("pull"/"push")
	// is not validated here — presumably SyncWorkspaceWithExperiment
	// rejects other values; confirm.
	var req struct {
		Workspace    string `json:"workspace"`
		ExperimentID string `json:"experiment_id"`
		Direction    string `json:"direction"` // "pull" or "push"
		SyncType     string `json:"sync_type"` // "data", "notebooks", "all"
	}

	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.Workspace == "" || req.ExperimentID == "" || req.Direction == "" {
		http.Error(w, "Workspace, experiment_id, and direction are required", http.StatusBadRequest)
		return
	}

	// Validate experiment exists
	if !h.expManager.ExperimentExists(req.ExperimentID) {
		http.Error(w, "Experiment not found", http.StatusNotFound)
		return
	}

	// Perform sync operation using service manager
	ctx := r.Context()
	if err := h.jupyterServiceMgr.SyncWorkspaceWithExperiment(
		ctx, req.Workspace, req.ExperimentID, req.Direction); err != nil {
		http.Error(w, fmt.Sprintf("Failed to sync workspace: %v", err), http.StatusInternalServerError)
		return
	}

	// Get updated metadata
	metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace)
	if err != nil {
		http.Error(
			w,
			fmt.Sprintf("Failed to get workspace metadata: %v", err),
			http.StatusInternalServerError,
		)
		return
	}

	// Create sync result
	syncResult := map[string]interface{}{
		"workspace":     req.Workspace,
		"experiment_id": req.ExperimentID,
		"direction":     req.Direction,
		"sync_type":     req.SyncType,
		"synced_at":     metadata.LastSync,
		"status":        "completed",
		"metadata":      metadata,
	}

	h.logger.Info("jupyter workspace sync completed",
		"workspace", req.Workspace,
		"experiment_id", req.ExperimentID,
		"direction", req.Direction,
		"sync_type", req.SyncType)

	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(syncResult); err != nil {
		h.logger.Error("failed to encode response", "error", err)
	}
}
|
||||
|
|
@ -15,9 +15,9 @@ import (
|
|||
|
||||
// Handler provides Jupyter-related WebSocket handlers.
type Handler struct {
	logger     *logging.Logger
	jupyterMgr *jupyter.ServiceManager
	authConfig *auth.Config
}
|
||||
|
||||
// NewHandler creates a new Jupyter handler
|
||||
|
|
@ -35,11 +35,11 @@ func NewHandler(
|
|||
|
||||
// Error codes
// NOTE(review): these values appear to mirror the API error-code bytes
// (0x01-0x04, 0x33) defined in the api package — confirm they stay in sync.
const (
	ErrorCodeInvalidRequest       = 0x01
	ErrorCodeAuthenticationFailed = 0x02
	ErrorCodePermissionDenied     = 0x03
	ErrorCodeResourceNotFound     = 0x04
	ErrorCodeServiceUnavailable   = 0x33
)
|
||||
|
||||
// Permissions
|
||||
|
|
@ -141,18 +141,18 @@ func (h *Handler) HandleListJupyter(conn *websocket.Conn, payload []byte, user *
|
|||
|
||||
if h.jupyterMgr == nil {
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"services": []interface{}{},
|
||||
"count": 0,
|
||||
"success": true,
|
||||
"services": []interface{}{},
|
||||
"count": 0,
|
||||
})
|
||||
}
|
||||
|
||||
services := h.jupyterMgr.ListServices()
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"services": services,
|
||||
"count": len(services),
|
||||
"success": true,
|
||||
"services": services,
|
||||
"count": len(services),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,10 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
|
|||
if len(s.config.Security.IPWhitelist) > 0 {
|
||||
handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler)
|
||||
}
|
||||
// Add OpenAPI validation if available
|
||||
if s.validationMiddleware != nil {
|
||||
handler = s.validationMiddleware.ValidateRequest(handler)
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
87
internal/api/middleware/validation.go
Normal file
87
internal/api/middleware/validation.go
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
// Package middleware provides request/response validation using OpenAPI spec
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/getkin/kin-openapi/openapi3filter"
|
||||
"github.com/getkin/kin-openapi/routers"
|
||||
"github.com/getkin/kin-openapi/routers/gorillamux"
|
||||
)
|
||||
|
||||
// ValidationMiddleware validates HTTP requests against OpenAPI spec
|
||||
type ValidationMiddleware struct {
|
||||
router routers.Router
|
||||
}
|
||||
|
||||
// NewValidationMiddleware creates a new validation middleware from OpenAPI spec
|
||||
func NewValidationMiddleware(specPath string) (*ValidationMiddleware, error) {
|
||||
// Load OpenAPI spec
|
||||
loader := openapi3.NewLoader()
|
||||
doc, err := loader.LoadFromFile(specPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate the spec itself
|
||||
if err := doc.Validate(loader.Context); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create router for path matching
|
||||
router, err := gorillamux.NewRouter(doc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ValidationMiddleware{router: router}, nil
|
||||
}
|
||||
|
||||
// ValidateRequest validates an incoming HTTP request
|
||||
func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip validation for health endpoints
|
||||
if r.URL.Path == "/health/ok" || r.URL.Path == "/health" || r.URL.Path == "/metrics" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Find the route
|
||||
route, pathParams, err := v.router.FindRoute(r)
|
||||
if err != nil {
|
||||
// Route not in spec - allow through (might be unregistered endpoint)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Read and restore body for validation
|
||||
var bodyBytes []byte
|
||||
if r.Body != nil && r.Body != http.NoBody {
|
||||
bodyBytes, _ = io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
// Validate request - body is read from r.Body automatically
|
||||
requestValidationInput := &openapi3filter.RequestValidationInput{
|
||||
Request: r,
|
||||
PathParams: pathParams,
|
||||
Route: route,
|
||||
}
|
||||
|
||||
if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": "validation failed",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
173
internal/api/responses/errors.go
Normal file
173
internal/api/responses/errors.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
// Package responses provides structured API response types with security-conscious error handling.
|
||||
package responses
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// ErrorResponse provides a sanitized error response to clients.
|
||||
// It includes a trace ID for support lookup while preventing information leakage.
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"` // Sanitized error message for clients
|
||||
Code string `json:"code"` // Machine-readable error code
|
||||
TraceID string `json:"trace_id"` // For support lookup (internal correlation)
|
||||
}
|
||||
|
||||
// Error codes for machine-readable error identification
|
||||
const (
|
||||
ErrCodeBadRequest = "BAD_REQUEST"
|
||||
ErrCodeUnauthorized = "UNAUTHORIZED"
|
||||
ErrCodeForbidden = "FORBIDDEN"
|
||||
ErrCodeNotFound = "NOT_FOUND"
|
||||
ErrCodeConflict = "CONFLICT"
|
||||
ErrCodeRateLimited = "RATE_LIMITED"
|
||||
ErrCodeInternal = "INTERNAL_ERROR"
|
||||
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
|
||||
ErrCodeValidation = "VALIDATION_ERROR"
|
||||
)
|
||||
|
||||
// HTTP status to error code mapping
|
||||
var statusToCode = map[int]string{
|
||||
http.StatusBadRequest: ErrCodeBadRequest,
|
||||
http.StatusUnauthorized: ErrCodeUnauthorized,
|
||||
http.StatusForbidden: ErrCodeForbidden,
|
||||
http.StatusNotFound: ErrCodeNotFound,
|
||||
http.StatusConflict: ErrCodeConflict,
|
||||
http.StatusTooManyRequests: ErrCodeRateLimited,
|
||||
http.StatusInternalServerError: ErrCodeInternal,
|
||||
http.StatusServiceUnavailable: ErrCodeServiceUnavailable,
|
||||
422: ErrCodeValidation, // Unprocessable Entity
|
||||
}
|
||||
|
||||
// Patterns to sanitize from error messages (security: prevent information leakage)
|
||||
var (
|
||||
// Remove file paths
|
||||
pathPattern = regexp.MustCompile(`/[^\s]*`)
|
||||
// Remove sensitive keywords
|
||||
sensitiveKeywords = []string{"password", "secret", "token", "key", "credential", "auth"}
|
||||
)
|
||||
|
||||
// WriteError writes a sanitized error response to the client.
|
||||
// It extracts the trace ID from the context, logs the full error internally,
|
||||
// and returns a sanitized message to the client.
|
||||
func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, logger *logging.Logger) {
|
||||
traceID := logging.TraceIDFromContext(r.Context())
|
||||
if traceID == "" {
|
||||
traceID = generateTraceID()
|
||||
}
|
||||
|
||||
// Log the full error internally with all details
|
||||
if logger != nil {
|
||||
logger.Error("request failed",
|
||||
"trace_id", traceID,
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", status,
|
||||
"error", err.Error(),
|
||||
"client_ip", getClientIP(r),
|
||||
)
|
||||
}
|
||||
|
||||
// Build sanitized response
|
||||
resp := ErrorResponse{
|
||||
Error: sanitizeError(err.Error()),
|
||||
Code: errorCodeFromStatus(status),
|
||||
TraceID: traceID,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// WriteErrorMessage writes a sanitized error response with a custom message.
|
||||
func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, message string, logger *logging.Logger) {
|
||||
traceID := logging.TraceIDFromContext(r.Context())
|
||||
if traceID == "" {
|
||||
traceID = generateTraceID()
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
logger.Error("request failed",
|
||||
"trace_id", traceID,
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", status,
|
||||
"error", message,
|
||||
)
|
||||
}
|
||||
|
||||
resp := ErrorResponse{
|
||||
Error: sanitizeError(message),
|
||||
Code: errorCodeFromStatus(status),
|
||||
TraceID: traceID,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// sanitizeError removes potentially sensitive information from error messages.
|
||||
// It prevents information leakage to clients while preserving useful context.
|
||||
func sanitizeError(msg string) string {
|
||||
if msg == "" {
|
||||
return "An error occurred"
|
||||
}
|
||||
|
||||
// Remove file paths
|
||||
msg = pathPattern.ReplaceAllString(msg, "[path]")
|
||||
|
||||
// Remove sensitive keywords and their values
|
||||
lowerMsg := strings.ToLower(msg)
|
||||
for _, keyword := range sensitiveKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
return "An error occurred"
|
||||
}
|
||||
}
|
||||
|
||||
// Remove internal error details
|
||||
msg = strings.ReplaceAll(msg, "internal error", "an error occurred")
|
||||
msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred")
|
||||
|
||||
// Truncate if too long
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200] + "..."
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// errorCodeFromStatus returns the appropriate error code for an HTTP status.
|
||||
func errorCodeFromStatus(status int) string {
|
||||
if code, ok := statusToCode[status]; ok {
|
||||
return code
|
||||
}
|
||||
return ErrCodeInternal
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// generateTraceID generates a new trace ID when one isn't in context.
|
||||
func generateTraceID() string {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
|
|
@ -2,12 +2,14 @@ package api
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/ws"
|
||||
"github.com/jfraeys/fetch_ml/internal/prommetrics"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// registerRoutes sets up all HTTP routes and handlers
|
||||
|
|
@ -34,9 +36,6 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
|
|||
// Register WebSocket endpoint
|
||||
s.registerWebSocketRoutes(mux)
|
||||
|
||||
// Register HTTP API handlers
|
||||
s.handlers.RegisterHandlers(mux)
|
||||
|
||||
// Register new REST API endpoints for TUI
|
||||
jobsHandler := jobs.NewHandler(
|
||||
s.expManager,
|
||||
|
|
@ -52,13 +51,73 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
|
|||
|
||||
// Team jobs endpoint: GET /api/jobs?all_users=true
|
||||
mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP)
|
||||
|
||||
// Register OpenAPI-generated routes with Echo router
|
||||
s.registerOpenAPIRoutes(mux, jobsHandler)
|
||||
|
||||
// Register API documentation endpoint
|
||||
s.registerDocsRoutes(mux)
|
||||
}
|
||||
|
||||
// registerDocsRoutes sets up API documentation serving
|
||||
func (s *Server) registerDocsRoutes(mux *http.ServeMux) {
|
||||
// Check if docs directory exists
|
||||
if _, err := os.Stat("docs/api"); err == nil {
|
||||
docsFS := http.FileServer(http.Dir("docs/api"))
|
||||
mux.Handle("/docs/", http.StripPrefix("/docs/", docsFS))
|
||||
s.logger.Info("API documentation endpoint registered", "path", "/docs/")
|
||||
} else {
|
||||
s.logger.Debug("API documentation not available", "path", "docs/api")
|
||||
}
|
||||
}
|
||||
|
||||
// registerOpenAPIRoutes sets up Echo router with generated OpenAPI handlers
|
||||
func (s *Server) registerOpenAPIRoutes(mux *http.ServeMux, jobsHandler *jobs.Handler) {
|
||||
// Create Echo instance for OpenAPI-generated routes
|
||||
e := echo.New()
|
||||
|
||||
// Create jupyter handler
|
||||
jupyterHandler := jupyter.NewHandler(
|
||||
s.logger,
|
||||
s.jupyterServiceMgr,
|
||||
s.config.BuildAuthConfig(),
|
||||
)
|
||||
|
||||
// Create datasets handler
|
||||
datasetsHandler := datasets.NewHandler(
|
||||
s.logger,
|
||||
s.db,
|
||||
s.config.DataDir,
|
||||
)
|
||||
|
||||
// Create adapter implementing ServerInterface
|
||||
handlerAdapter := NewHandlerAdapter(
|
||||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
)
|
||||
|
||||
// Register generated OpenAPI routes
|
||||
RegisterHandlers(e, handlerAdapter)
|
||||
|
||||
// Wrap Echo router to work with existing middleware chain
|
||||
echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
e.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Register Echo router at /v1/ prefix (and other generated paths)
|
||||
// These paths take precedence over legacy routes
|
||||
mux.Handle("/health", echoHandler)
|
||||
mux.Handle("/v1/", echoHandler)
|
||||
mux.Handle("/ws", echoHandler)
|
||||
|
||||
s.logger.Info("OpenAPI-generated routes registered with Echo router")
|
||||
}
|
||||
|
||||
// registerHealthRoutes sets up health check endpoints
|
||||
func (s *Server) registerHealthRoutes(mux *http.ServeMux) {
|
||||
healthHandler := NewHealthHandler(s)
|
||||
healthHandler.RegisterRoutes(mux)
|
||||
mux.HandleFunc("/health/ok", s.handlers.handleHealth)
|
||||
s.logger.Info("health check endpoints registered")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
apimiddleware "github.com/jfraeys/fetch_ml/internal/api/middleware"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
|
|
@ -20,18 +21,18 @@ import (
|
|||
|
||||
// Server represents the API server
|
||||
type Server struct {
|
||||
config *ServerConfig
|
||||
httpServer *http.Server
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
taskQueue queue.Backend
|
||||
db *storage.DB
|
||||
handlers *Handlers
|
||||
sec *middleware.SecurityMiddleware
|
||||
cleanupFuncs []func()
|
||||
jupyterServiceMgr *jupyter.ServiceManager
|
||||
auditLogger *audit.Logger
|
||||
promMetrics *prommetrics.Metrics // Prometheus metrics
|
||||
config *ServerConfig
|
||||
httpServer *http.Server
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
taskQueue queue.Backend
|
||||
db *storage.DB
|
||||
sec *middleware.SecurityMiddleware
|
||||
cleanupFuncs []func()
|
||||
jupyterServiceMgr *jupyter.ServiceManager
|
||||
auditLogger *audit.Logger
|
||||
promMetrics *prommetrics.Metrics // Prometheus metrics
|
||||
validationMiddleware *apimiddleware.ValidationMiddleware // OpenAPI validation
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
|
|
|
|||
642
internal/api/server_gen.go
Normal file
642
internal/api/server_gen.go
Normal file
|
|
@ -0,0 +1,642 @@
|
|||
// Package api provides primitives to interact with the openapi HTTP API.
|
||||
//
|
||||
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/oapi-codegen/runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiKeyAuthScopes = "ApiKeyAuth.Scopes"
|
||||
)
|
||||
|
||||
// Defines values for ErrorResponseCode.
|
||||
const (
|
||||
BADREQUEST ErrorResponseCode = "BAD_REQUEST"
|
||||
CONFLICT ErrorResponseCode = "CONFLICT"
|
||||
FORBIDDEN ErrorResponseCode = "FORBIDDEN"
|
||||
INTERNALERROR ErrorResponseCode = "INTERNAL_ERROR"
|
||||
NOTFOUND ErrorResponseCode = "NOT_FOUND"
|
||||
RATELIMITED ErrorResponseCode = "RATE_LIMITED"
|
||||
SERVICEUNAVAILABLE ErrorResponseCode = "SERVICE_UNAVAILABLE"
|
||||
UNAUTHORIZED ErrorResponseCode = "UNAUTHORIZED"
|
||||
VALIDATIONERROR ErrorResponseCode = "VALIDATION_ERROR"
|
||||
)
|
||||
|
||||
// Defines values for ExperimentStatus.
|
||||
const (
|
||||
Active ExperimentStatus = "active"
|
||||
Archived ExperimentStatus = "archived"
|
||||
Deleted ExperimentStatus = "deleted"
|
||||
)
|
||||
|
||||
// Defines values for HealthResponseStatus.
|
||||
const (
|
||||
Degraded HealthResponseStatus = "degraded"
|
||||
Healthy HealthResponseStatus = "healthy"
|
||||
Unhealthy HealthResponseStatus = "unhealthy"
|
||||
)
|
||||
|
||||
// Defines values for JupyterServiceStatus.
|
||||
const (
|
||||
JupyterServiceStatusError JupyterServiceStatus = "error"
|
||||
JupyterServiceStatusRunning JupyterServiceStatus = "running"
|
||||
JupyterServiceStatusStarting JupyterServiceStatus = "starting"
|
||||
JupyterServiceStatusStopped JupyterServiceStatus = "stopped"
|
||||
JupyterServiceStatusStopping JupyterServiceStatus = "stopping"
|
||||
)
|
||||
|
||||
// Defines values for TaskStatus.
|
||||
const (
|
||||
TaskStatusCollecting TaskStatus = "collecting"
|
||||
TaskStatusCompleted TaskStatus = "completed"
|
||||
TaskStatusFailed TaskStatus = "failed"
|
||||
TaskStatusPreparing TaskStatus = "preparing"
|
||||
TaskStatusQueued TaskStatus = "queued"
|
||||
TaskStatusRunning TaskStatus = "running"
|
||||
)
|
||||
|
||||
// Defines values for GetV1TasksParamsStatus.
|
||||
const (
|
||||
Completed GetV1TasksParamsStatus = "completed"
|
||||
Failed GetV1TasksParamsStatus = "failed"
|
||||
Queued GetV1TasksParamsStatus = "queued"
|
||||
Running GetV1TasksParamsStatus = "running"
|
||||
)
|
||||
|
||||
// CreateExperimentRequest defines model for CreateExperimentRequest.
|
||||
type CreateExperimentRequest struct {
|
||||
Description *string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// CreateTaskRequest defines model for CreateTaskRequest.
|
||||
type CreateTaskRequest struct {
|
||||
// Args Command-line arguments for the training script
|
||||
Args *string `json:"args,omitempty"`
|
||||
|
||||
// Cpu CPU cores requested
|
||||
Cpu *int `json:"cpu,omitempty"`
|
||||
DatasetSpecs *[]DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
Datasets *[]string `json:"datasets,omitempty"`
|
||||
|
||||
// Gpu GPUs requested
|
||||
Gpu *int `json:"gpu,omitempty"`
|
||||
|
||||
// JobName Unique identifier for the job
|
||||
JobName string `json:"job_name"`
|
||||
|
||||
// MemoryGb Memory (GB) requested
|
||||
MemoryGb *int `json:"memory_gb,omitempty"`
|
||||
Metadata *map[string]string `json:"metadata,omitempty"`
|
||||
Priority *int `json:"priority,omitempty"`
|
||||
|
||||
// SnapshotId Reference to experiment snapshot
|
||||
SnapshotId *string `json:"snapshot_id,omitempty"`
|
||||
}
|
||||
|
||||
// DatasetSpec defines model for DatasetSpec.
|
||||
type DatasetSpec struct {
|
||||
MountPath *string `json:"mount_path,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Sha256 *string `json:"sha256,omitempty"`
|
||||
Source *string `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse defines model for ErrorResponse.
|
||||
type ErrorResponse struct {
|
||||
Code ErrorResponseCode `json:"code"`
|
||||
|
||||
// Error Sanitized error message
|
||||
Error string `json:"error"`
|
||||
|
||||
// TraceId Support correlation ID
|
||||
TraceId string `json:"trace_id"`
|
||||
}
|
||||
|
||||
// ErrorResponseCode defines model for ErrorResponse.Code.
|
||||
type ErrorResponseCode string
|
||||
|
||||
// Experiment defines model for Experiment.
|
||||
type Experiment struct {
|
||||
CommitId *string `json:"commit_id,omitempty"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
Id *string `json:"id,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Status *ExperimentStatus `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// ExperimentStatus defines model for Experiment.Status.
|
||||
type ExperimentStatus string
|
||||
|
||||
// HealthResponse defines model for HealthResponse.
|
||||
type HealthResponse struct {
|
||||
Status *HealthResponseStatus `json:"status,omitempty"`
|
||||
Timestamp *time.Time `json:"timestamp,omitempty"`
|
||||
Version *string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
// HealthResponseStatus defines model for HealthResponse.Status.
|
||||
type HealthResponseStatus string
|
||||
|
||||
// JupyterService defines model for JupyterService.
|
||||
type JupyterService struct {
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
Id *string `json:"id,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Status *JupyterServiceStatus `json:"status,omitempty"`
|
||||
Token *string `json:"token,omitempty"`
|
||||
Url *string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// JupyterServiceStatus defines model for JupyterService.Status.
|
||||
type JupyterServiceStatus string
|
||||
|
||||
// QueueStats defines model for QueueStats.
|
||||
type QueueStats struct {
|
||||
// Completed Tasks completed today
|
||||
Completed *int `json:"completed,omitempty"`
|
||||
|
||||
// Failed Tasks failed today
|
||||
Failed *int `json:"failed,omitempty"`
|
||||
|
||||
// Queued Tasks waiting to run
|
||||
Queued *int `json:"queued,omitempty"`
|
||||
|
||||
// Running Tasks currently executing
|
||||
Running *int `json:"running,omitempty"`
|
||||
|
||||
// Workers Active workers
|
||||
Workers *int `json:"workers,omitempty"`
|
||||
}
|
||||
|
||||
// StartJupyterRequest defines model for StartJupyterRequest.
|
||||
type StartJupyterRequest struct {
|
||||
Image *string `json:"image,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Workspace *string `json:"workspace,omitempty"`
|
||||
}
|
||||
|
||||
// Task defines model for Task.
|
||||
type Task struct {
|
||||
Cpu *int `json:"cpu,omitempty"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
Datasets *[]string `json:"datasets,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
Gpu *int `json:"gpu,omitempty"`
|
||||
|
||||
// Id Unique task identifier
|
||||
Id *string `json:"id,omitempty"`
|
||||
JobName *string `json:"job_name,omitempty"`
|
||||
MaxRetries *int `json:"max_retries,omitempty"`
|
||||
MemoryGb *int `json:"memory_gb,omitempty"`
|
||||
Output *string `json:"output,omitempty"`
|
||||
Priority *int `json:"priority,omitempty"`
|
||||
RetryCount *int `json:"retry_count,omitempty"`
|
||||
SnapshotId *string `json:"snapshot_id,omitempty"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
Status *TaskStatus `json:"status,omitempty"`
|
||||
UserId *string `json:"user_id,omitempty"`
|
||||
WorkerId *string `json:"worker_id,omitempty"`
|
||||
}
|
||||
|
||||
// TaskStatus defines model for Task.Status.
|
||||
type TaskStatus string
|
||||
|
||||
// TaskList defines model for TaskList.
|
||||
type TaskList struct {
|
||||
Limit *int `json:"limit,omitempty"`
|
||||
Offset *int `json:"offset,omitempty"`
|
||||
Tasks *[]Task `json:"tasks,omitempty"`
|
||||
Total *int `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// BadRequest defines model for BadRequest.
|
||||
type BadRequest = ErrorResponse
|
||||
|
||||
// NotFound defines model for NotFound.
|
||||
type NotFound = ErrorResponse
|
||||
|
||||
// RateLimited defines model for RateLimited.
|
||||
type RateLimited = ErrorResponse
|
||||
|
||||
// Unauthorized defines model for Unauthorized.
|
||||
type Unauthorized = ErrorResponse
|
||||
|
||||
// ValidationError defines model for ValidationError.
|
||||
type ValidationError = ErrorResponse
|
||||
|
||||
// GetV1TasksParams defines parameters for GetV1Tasks.
|
||||
type GetV1TasksParams struct {
|
||||
Status *GetV1TasksParamsStatus `form:"status,omitempty" json:"status,omitempty"`
|
||||
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
|
||||
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
|
||||
}
|
||||
|
||||
// GetV1TasksParamsStatus defines parameters for GetV1Tasks.
|
||||
type GetV1TasksParamsStatus string
|
||||
|
||||
// PostV1ExperimentsJSONRequestBody defines body for PostV1Experiments for application/json ContentType.
|
||||
type PostV1ExperimentsJSONRequestBody = CreateExperimentRequest
|
||||
|
||||
// PostV1JupyterServicesJSONRequestBody defines body for PostV1JupyterServices for application/json ContentType.
|
||||
type PostV1JupyterServicesJSONRequestBody = StartJupyterRequest
|
||||
|
||||
// PostV1TasksJSONRequestBody defines body for PostV1Tasks for application/json ContentType.
|
||||
type PostV1TasksJSONRequestBody = CreateTaskRequest
|
||||
|
||||
// ServerInterface represents all server handlers.
|
||||
type ServerInterface interface {
|
||||
// Health check
|
||||
// (GET /health)
|
||||
GetHealth(ctx echo.Context) error
|
||||
// List experiments
|
||||
// (GET /v1/experiments)
|
||||
GetV1Experiments(ctx echo.Context) error
|
||||
// Create experiment
|
||||
// (POST /v1/experiments)
|
||||
PostV1Experiments(ctx echo.Context) error
|
||||
// List Jupyter services
|
||||
// (GET /v1/jupyter/services)
|
||||
GetV1JupyterServices(ctx echo.Context) error
|
||||
// Start Jupyter service
|
||||
// (POST /v1/jupyter/services)
|
||||
PostV1JupyterServices(ctx echo.Context) error
|
||||
// Stop Jupyter service
|
||||
// (DELETE /v1/jupyter/services/{serviceId})
|
||||
DeleteV1JupyterServicesServiceId(ctx echo.Context, serviceId string) error
|
||||
// Queue status
|
||||
// (GET /v1/queue)
|
||||
GetV1Queue(ctx echo.Context) error
|
||||
// List tasks
|
||||
// (GET /v1/tasks)
|
||||
GetV1Tasks(ctx echo.Context, params GetV1TasksParams) error
|
||||
// Create task
|
||||
// (POST /v1/tasks)
|
||||
PostV1Tasks(ctx echo.Context) error
|
||||
// Cancel/delete task
|
||||
// (DELETE /v1/tasks/{taskId})
|
||||
DeleteV1TasksTaskId(ctx echo.Context, taskId string) error
|
||||
// Get task details
|
||||
// (GET /v1/tasks/{taskId})
|
||||
GetV1TasksTaskId(ctx echo.Context, taskId string) error
|
||||
// WebSocket connection
|
||||
// (GET /ws)
|
||||
GetWs(ctx echo.Context) error
|
||||
}
|
||||
|
||||
// ServerInterfaceWrapper converts echo contexts to parameters.
|
||||
type ServerInterfaceWrapper struct {
|
||||
Handler ServerInterface
|
||||
}
|
||||
|
||||
// GetHealth converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetHealth(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetHealth(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetV1Experiments converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetV1Experiments(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetV1Experiments(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// PostV1Experiments converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) PostV1Experiments(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.PostV1Experiments(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetV1JupyterServices converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetV1JupyterServices(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetV1JupyterServices(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// PostV1JupyterServices converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) PostV1JupyterServices(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.PostV1JupyterServices(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteV1JupyterServicesServiceId converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) DeleteV1JupyterServicesServiceId(ctx echo.Context) error {
|
||||
var err error
|
||||
// ------------- Path parameter "serviceId" -------------
|
||||
var serviceId string
|
||||
|
||||
err = runtime.BindStyledParameterWithOptions("simple", "serviceId", ctx.Param("serviceId"), &serviceId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter serviceId: %s", err))
|
||||
}
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.DeleteV1JupyterServicesServiceId(ctx, serviceId)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetV1Queue converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetV1Queue(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetV1Queue(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetV1Tasks converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetV1Tasks(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params GetV1TasksParams
|
||||
// ------------- Optional query parameter "status" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "status", ctx.QueryParams(), ¶ms.Status)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter status: %s", err))
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "limit" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "limit", ctx.QueryParams(), ¶ms.Limit)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter limit: %s", err))
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "offset" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "offset", ctx.QueryParams(), ¶ms.Offset)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter offset: %s", err))
|
||||
}
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetV1Tasks(ctx, params)
|
||||
return err
|
||||
}
|
||||
|
||||
// PostV1Tasks converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) PostV1Tasks(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.PostV1Tasks(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteV1TasksTaskId converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) DeleteV1TasksTaskId(ctx echo.Context) error {
|
||||
var err error
|
||||
// ------------- Path parameter "taskId" -------------
|
||||
var taskId string
|
||||
|
||||
err = runtime.BindStyledParameterWithOptions("simple", "taskId", ctx.Param("taskId"), &taskId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter taskId: %s", err))
|
||||
}
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.DeleteV1TasksTaskId(ctx, taskId)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetV1TasksTaskId converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetV1TasksTaskId(ctx echo.Context) error {
|
||||
var err error
|
||||
// ------------- Path parameter "taskId" -------------
|
||||
var taskId string
|
||||
|
||||
err = runtime.BindStyledParameterWithOptions("simple", "taskId", ctx.Param("taskId"), &taskId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter taskId: %s", err))
|
||||
}
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetV1TasksTaskId(ctx, taskId)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetWs converts echo context to params.
|
||||
func (w *ServerInterfaceWrapper) GetWs(ctx echo.Context) error {
|
||||
var err error
|
||||
|
||||
ctx.Set(ApiKeyAuthScopes, []string{})
|
||||
|
||||
// Invoke the callback with all the unmarshaled arguments
|
||||
err = w.Handler.GetWs(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// This is a simple interface which specifies echo.Route addition functions which
|
||||
// are present on both echo.Echo and echo.Group, since we want to allow using
|
||||
// either of them for path registration
|
||||
type EchoRouter interface {
|
||||
CONNECT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
HEAD(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
OPTIONS(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
PATCH(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
TRACE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route
|
||||
}
|
||||
|
||||
// RegisterHandlers adds each server route to the EchoRouter.
|
||||
func RegisterHandlers(router EchoRouter, si ServerInterface) {
|
||||
RegisterHandlersWithBaseURL(router, si, "")
|
||||
}
|
||||
|
||||
// Registers handlers, and prepends BaseURL to the paths, so that the paths
|
||||
// can be served under a prefix.
|
||||
func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) {
|
||||
|
||||
wrapper := ServerInterfaceWrapper{
|
||||
Handler: si,
|
||||
}
|
||||
|
||||
router.GET(baseURL+"/health", wrapper.GetHealth)
|
||||
router.GET(baseURL+"/v1/experiments", wrapper.GetV1Experiments)
|
||||
router.POST(baseURL+"/v1/experiments", wrapper.PostV1Experiments)
|
||||
router.GET(baseURL+"/v1/jupyter/services", wrapper.GetV1JupyterServices)
|
||||
router.POST(baseURL+"/v1/jupyter/services", wrapper.PostV1JupyterServices)
|
||||
router.DELETE(baseURL+"/v1/jupyter/services/:serviceId", wrapper.DeleteV1JupyterServicesServiceId)
|
||||
router.GET(baseURL+"/v1/queue", wrapper.GetV1Queue)
|
||||
router.GET(baseURL+"/v1/tasks", wrapper.GetV1Tasks)
|
||||
router.POST(baseURL+"/v1/tasks", wrapper.PostV1Tasks)
|
||||
router.DELETE(baseURL+"/v1/tasks/:taskId", wrapper.DeleteV1TasksTaskId)
|
||||
router.GET(baseURL+"/v1/tasks/:taskId", wrapper.GetV1TasksTaskId)
|
||||
router.GET(baseURL+"/ws", wrapper.GetWs)
|
||||
|
||||
}
|
||||
|
||||
// Base64 encoded, gzipped, json marshaled Swagger object
|
||||
var swaggerSpec = []string{
|
||||
|
||||
"H4sIAAAAAAAC/8RaW1fbuhL+K1ra+6FdxyFAL2c3bwHSNqfh0gTaszbhBGFPEoEtqZIMZLPy38+SZCe+",
|
||||
"yEDphSdIJI1nPn2a+TTOHQ55IjgDphXu3GEJSnCmwH7YIdEQvqWgtPkUcqaB2X+JEDENiaactS8VZ+Y7",
|
||||
"uCWJiMHNjAB38E53bzLsfT7pjY5xgEFKLnEH99k1iWmEpLOMplwmROMAa0lCmNAIdzDZutgOX0WvW/Bm",
|
||||
"+rb177/ebbbIRRi1YLq1/er1m7fmG7wMsArnkBDzyD8lTHEH/9Feh9N2o6rdM08eZoHh5XIZ4AhUKKkw",
|
||||
"AdRdMpYPuH7PUxY9KfCDw+PJ+8OTg71C2ENQPJUhIMZNzMb0c4bscWcZ4CHRMKAJ1fC0wIfd495k0N/v",
|
||||
"H/dKsRMNKDZ2EdyGABE8b/DHnKOEsEW+4QoHeA4kAmlpPwQtF63uVIM0H8trRxByFimUMk1jJNeRSVBg",
|
||||
"La2d1AthUKFMwwyk8WQZ4BNGUj3nkv7zRJBPDronxx8Ph/2/SyDnJOYSJVQpymaoe9RHV7B4Vqy7qZ4D",
|
||||
"01lYFnEqwbLti/HXft1zMTwBjC/dQX+ve9w/PJj0hsPDYQGQtXk0JTR+Zs7VvVmujFvW7UogGnq3AiRN",
|
||||
"gOlC5hWSC5CauqxcsrsimdKSsplxmJHEIpSQ2wGwmZ7jztb2X0F14jLAq83onLpVZ6tZ/OISQpsJnV/H",
|
||||
"RF01ekTkrO4Z3uVJQljUiikDROQsNVEpk++RngPSklBmWOrW4KAeSShSj9mjExRyCSo/vW5jK0ctwBHR",
|
||||
"RIGeKAGh9Y5qSNRD+7jnVo0EhMZIZpZISRYFo2V7Nb+rq2a+OD4cnTwUwiW/mOS7WV58wui3FBCNzNGa",
|
||||
"UpArWC/5BQ6Ke//2dYAF0RqkWfi/U9L6p9v6e7P1btI6+9efPtgTSLhcTGYX9efu2yH04sPOywd8T0AT",
|
||||
"g5blRxRRY4DERyXeNAG3Jp+QlEuqF86TKUljjTtvbHw0SRPc2doMcEJZ9sHniGJEqDnX9tTf1crgFCSw",
|
||||
"EJDmCFZnD+WL8EPHZrVFvqNTJFPt0CQ8ZXoiiNmje05xbUDNyfabt/4hW9I9Qz5oyxmr5p7Lr3cYmEH2",
|
||||
"tCLmKjXo/eFwp7+31zvAQUn87B4evB/0d82KijToHxz3hgfdwSprj3rDL/3d3uTkoPul2x90dwY9HNTT",
|
||||
"+5mHr5BXj0qlJoxqU2WRnYASUIrMwMf4dVmoGUmF4FKbjCMhdvm7v/cgLZxLgUOxYN/HknXK9+1CktCc",
|
||||
"uvX8aFNzNCF2ZSakOyZJQUvTxBtqg6lmummiU1VkAgk1vTa2iQzn9Nqe/whiMJngLHgM9T4CifW8mXv1",
|
||||
"Z87tioV90kwSJyBTln/tY4UBQGmSiMdjcw1S+auqL4r/pGKhQY5AXtPQd4KeZXeUJlKb0QDLlDH3n9Jc",
|
||||
"iMK/Fj7HUS90/Ar82iKVcSmYVFL8qC3/nEIKI01c6ayRXDj61I6fkR0KrSYgzSOy8NabTFQ1WHCj9yz/",
|
||||
"ZvxrXH5DqMHUFAmZMq+BHOymEFIpgel4geAWwjTbobqZGy6vsmtIRUfbU4fy8fpaH+ojQ4aMp43qjSYm",
|
||||
"KxYLLL50S9piobkM552YaLM0+A5mGkeVIE3l6FHS00DnoYtTU3XsnnLgnibogEXf+ZxVjaqNzJrC8ZWj",
|
||||
"TPdpoq4K4s/3wKJ2/DE1SG4nErQsa7aS0ivIxfowT7VItTf0n6DujGeLSWi0lP/xFfnny6Dye0lTz7pZ",
|
||||
"+jAhgSCymn9DHscQ6vxDnu5WWcuXhFMFsslplwX8o8uGgzSgvsNvexcNGzedKmgYMwR8/KXKnmPPMdJc",
|
||||
"k9jbKKnFYFCHMDVsGRmrzvuuoJ9g0U2dhq4kTNf8sBcjUmpC4ABTM8O1fHCexPB/W92jfusTFCoEsQ9w",
|
||||
"93nKpjzvUZDQApMtfA86nO8PUCYWcb0JctS3fiSEkZkpJPuD4m3DookIi1CWq5FyokJtjNmY/fEHGmWx",
|
||||
"j1k3jhGwSHBqLtMv4DYEoZETQiicQ3ilXuZdlrwBVIkfXVNi7opjdr4K+Rw5NDbQullnHKUKAZtyGUKE",
|
||||
"BMjcYu6XvUSgj4RFMWWzsWvkmDt+HPMbRFDImaJKmyDd0UI3VM9RQsI5ZdCSQCJyEQMyOtkhYKUy6u+p",
|
||||
"zpidn59fKs7G7G7MEBq7JDrGHTRuEvdjHLipxqCbae8Nk93Dvd5qMJfjbkKa0qg15bKl3PaN8Zgt7cPH",
|
||||
"ttRTHZtN3h+gr/bYGQxwQSzirY3NDdsq4gIYERR38KuNzY1X2KbauWVq2+2Q+XfmDlX1IqpTyZTdeJD5",
|
||||
"frpEs4EOeHUL8/q5ge1jpf22H+EO/gDa6WuTgYrd/O3NzUe02B7X7qooeE+/a+QCoQrlIr14hnHn9CzA",
|
||||
"Kk0SIhe4k90IHH/txPb1Vnt9QlQjbiarIWLORGGyB5IvW73ShB9C5lFJr3Ctq6W+Olw2Dj4thWEBW0Fk",
|
||||
"Z5SGAyy48mDiOnaIIAY3hRU1WI648uBiVeIOjxY/jSxNnc1lWQhqmcKytjNbP82N4obUN2A9ijIdWcE/",
|
||||
"QxWKVhxPc7WcJ+0CWT08LF8bfw8XK1fV7+BjtSD5SFmfs2amj3E+BH4+63x3n9/MuCrsdZgr0KFMjVZQ",
|
||||
"tqFUYW6kX/su+68fLV1yMGqzvhl79vvadozyxbZ8SZKAthfS0zunmmzHcqWZVGF2GVjPi7CVRD2rgf7a",
|
||||
"96Ytx8R1LKqYcNEIiZXjjTXjA+j8Po7sTFtnqdI0bKgdtnvxK+tpoT3iYcnnqpNlKNbDqVpBsFLo95dN",
|
||||
"Jz2tKOPC9ejRlMYasjuMBwzb0Gggx7cU5KLADudTkQq1G1PxlvSIi9Ey8D/KXWSKT1rfKDdLV8rNTV/z",
|
||||
"xG81uwN5zfrMnP1Cjqwucfeka7frywC/dg/22Vs52C78vsMu2Xp4SenttVm0/e7hRcWfFXjqx8ppv5wZ",
|
||||
"pRcJ1ZmcqV+cGmRNztJfJ2iKr0R/c2FxF2rPDxuIusrlC1JpGIJS0zSOF7+XEtsPL6q++/9xKmX6TDto",
|
||||
"CkmwfWf+PLIYWtoc2/mPqn86n/qTi5/bScJCiOMMVjftfnhWP1uqYGMNtV3wGUTBPRr1WVDY/D3HIwJN",
|
||||
"aKx+FFIjInTJnuHcTXPJ/QoXIx5egV61b2xLSAKJbaPRWUtFRPS677Pv2hroeCFAjVkLnZtZEzfrvINs",
|
||||
"RK7KonBO2Kw4K6+n+bwpZVTNIbIzBGWz8w76BCBaJKbXgF64mCOnBs4FZ7Pzl7YFUmPI19q1ZctliqaQ",
|
||||
"Q84YhLZzAUqTi9g6Um0JlBt6p2fLUo/AZ83t8kMmbCvCkbdSNnlIYhTBNcRcuBf/di7O3nThudai027H",
|
||||
"Zt6cK915ZwI1aqFs6EjyKHXxeSyoTrtNBN2Ygg7nSbyR/YxpI+QJXp4t/x8AAP//HXQa3YQpAAA=",
|
||||
}
|
||||
|
||||
// GetSwagger returns the content of the embedded swagger specification file
|
||||
// or error if failed to decode
|
||||
func decodeSpec() ([]byte, error) {
|
||||
zipped, err := base64.StdEncoding.DecodeString(strings.Join(swaggerSpec, ""))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error base64 decoding spec: %w", err)
|
||||
}
|
||||
zr, err := gzip.NewReader(bytes.NewReader(zipped))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decompressing spec: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
_, err = buf.ReadFrom(zr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decompressing spec: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
var rawSpec = decodeSpecCached()
|
||||
|
||||
// a naive cached of a decoded swagger spec
|
||||
func decodeSpecCached() func() ([]byte, error) {
|
||||
data, err := decodeSpec()
|
||||
return func() ([]byte, error) {
|
||||
return data, err
|
||||
}
|
||||
}
|
||||
|
||||
// Constructs a synthetic filesystem for resolving external references when loading openapi specifications.
|
||||
func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) {
|
||||
res := make(map[string]func() ([]byte, error))
|
||||
if len(pathToFile) > 0 {
|
||||
res[pathToFile] = rawSpec
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// GetSwagger returns the Swagger specification corresponding to the generated code
|
||||
// in this file. The external references of Swagger specification are resolved.
|
||||
// The logic of resolving external references is tightly connected to "import-mapping" feature.
|
||||
// Externally referenced files must be embedded in the corresponding golang packages.
|
||||
// Urls can be supported but this task was out of the scope.
|
||||
func GetSwagger() (swagger *openapi3.T, err error) {
|
||||
resolvePath := PathToRawSpec("")
|
||||
|
||||
loader := openapi3.NewLoader()
|
||||
loader.IsExternalRefsAllowed = true
|
||||
loader.ReadFromURIFunc = func(loader *openapi3.Loader, url *url.URL) ([]byte, error) {
|
||||
pathToFile := url.String()
|
||||
pathToFile = path.Clean(pathToFile)
|
||||
getSpec, ok := resolvePath[pathToFile]
|
||||
if !ok {
|
||||
err1 := fmt.Errorf("path not found: %s", pathToFile)
|
||||
return nil, err1
|
||||
}
|
||||
return getSpec()
|
||||
}
|
||||
var specData []byte
|
||||
specData, err = rawSpec()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
swagger, err = loader.LoadFromData(specData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -23,11 +23,11 @@ const (
|
|||
|
||||
// Argon2id parameters (OWASP recommended minimum)
|
||||
const (
|
||||
argon2idTime = 1 // Iterations
|
||||
argon2idMemory = 64 * 1024 // 64 MB
|
||||
argon2idThreads = 4 // Parallelism
|
||||
argon2idKeyLen = 32 // 256-bit output
|
||||
argon2idSaltLen = 16 // 128-bit salt
|
||||
argon2idTime = 1 // Iterations
|
||||
argon2idMemory = 64 * 1024 // 64 MB
|
||||
argon2idThreads = 4 // Parallelism
|
||||
argon2idKeyLen = 32 // 256-bit output
|
||||
argon2idSaltLen = 16 // 128-bit salt
|
||||
)
|
||||
|
||||
// HashedKey represents a hashed API key with its algorithm and salt
|
||||
|
|
|
|||
|
|
@ -21,35 +21,39 @@ func NewPathRegistry(root string) *PathRegistry {
|
|||
}
|
||||
|
||||
// Binary paths
|
||||
func (p *PathRegistry) BinDir() string { return filepath.Join(p.RootDir, "bin") }
|
||||
func (p *PathRegistry) APIServerBinary() string { return filepath.Join(p.BinDir(), "api-server") }
|
||||
func (p *PathRegistry) WorkerBinary() string { return filepath.Join(p.BinDir(), "worker") }
|
||||
func (p *PathRegistry) TUIBinary() string { return filepath.Join(p.BinDir(), "tui") }
|
||||
func (p *PathRegistry) DataManagerBinary() string { return filepath.Join(p.BinDir(), "data_manager") }
|
||||
func (p *PathRegistry) BinDir() string { return filepath.Join(p.RootDir, "bin") }
|
||||
func (p *PathRegistry) APIServerBinary() string { return filepath.Join(p.BinDir(), "api-server") }
|
||||
func (p *PathRegistry) WorkerBinary() string { return filepath.Join(p.BinDir(), "worker") }
|
||||
func (p *PathRegistry) TUIBinary() string { return filepath.Join(p.BinDir(), "tui") }
|
||||
func (p *PathRegistry) DataManagerBinary() string { return filepath.Join(p.BinDir(), "data_manager") }
|
||||
|
||||
// Data paths
|
||||
func (p *PathRegistry) DataDir() string { return filepath.Join(p.RootDir, "data") }
|
||||
func (p *PathRegistry) ActiveDataDir() string { return filepath.Join(p.DataDir(), "active") }
|
||||
func (p *PathRegistry) JupyterStateDir() string { return filepath.Join(p.DataDir(), "active", "jupyter") }
|
||||
func (p *PathRegistry) ExperimentsDir() string { return filepath.Join(p.DataDir(), "experiments") }
|
||||
func (p *PathRegistry) ProdSmokeDir() string { return filepath.Join(p.DataDir(), "prod-smoke") }
|
||||
func (p *PathRegistry) DataDir() string { return filepath.Join(p.RootDir, "data") }
|
||||
func (p *PathRegistry) ActiveDataDir() string { return filepath.Join(p.DataDir(), "active") }
|
||||
func (p *PathRegistry) JupyterStateDir() string {
|
||||
return filepath.Join(p.DataDir(), "active", "jupyter")
|
||||
}
|
||||
func (p *PathRegistry) ExperimentsDir() string { return filepath.Join(p.DataDir(), "experiments") }
|
||||
func (p *PathRegistry) ProdSmokeDir() string { return filepath.Join(p.DataDir(), "prod-smoke") }
|
||||
|
||||
// Database paths
|
||||
func (p *PathRegistry) DBDir() string { return filepath.Join(p.RootDir, "db") }
|
||||
func (p *PathRegistry) SQLitePath() string { return filepath.Join(p.DBDir(), "fetch_ml.db") }
|
||||
func (p *PathRegistry) DBDir() string { return filepath.Join(p.RootDir, "db") }
|
||||
func (p *PathRegistry) SQLitePath() string { return filepath.Join(p.DBDir(), "fetch_ml.db") }
|
||||
|
||||
// Log paths
|
||||
func (p *PathRegistry) LogDir() string { return filepath.Join(p.RootDir, "logs") }
|
||||
func (p *PathRegistry) AuditLogPath() string { return filepath.Join(p.LogDir(), "fetchml-audit.log") }
|
||||
func (p *PathRegistry) LogDir() string { return filepath.Join(p.RootDir, "logs") }
|
||||
func (p *PathRegistry) AuditLogPath() string { return filepath.Join(p.LogDir(), "fetchml-audit.log") }
|
||||
|
||||
// Config paths
|
||||
func (p *PathRegistry) ConfigDir() string { return filepath.Join(p.RootDir, "configs") }
|
||||
func (p *PathRegistry) APIServerConfig() string { return filepath.Join(p.ConfigDir(), "api", "dev.yaml") }
|
||||
func (p *PathRegistry) WorkerConfigDir() string { return filepath.Join(p.ConfigDir(), "workers") }
|
||||
func (p *PathRegistry) ConfigDir() string { return filepath.Join(p.RootDir, "configs") }
|
||||
func (p *PathRegistry) APIServerConfig() string {
|
||||
return filepath.Join(p.ConfigDir(), "api", "dev.yaml")
|
||||
}
|
||||
func (p *PathRegistry) WorkerConfigDir() string { return filepath.Join(p.ConfigDir(), "workers") }
|
||||
|
||||
// Test paths
|
||||
func (p *PathRegistry) TestResultsDir() string { return filepath.Join(p.RootDir, "test_results") }
|
||||
func (p *PathRegistry) TempDir() string { return filepath.Join(p.RootDir, "tmp") }
|
||||
func (p *PathRegistry) TestResultsDir() string { return filepath.Join(p.RootDir, "test_results") }
|
||||
func (p *PathRegistry) TempDir() string { return filepath.Join(p.RootDir, "tmp") }
|
||||
|
||||
// State file paths (for service persistence)
|
||||
func (p *PathRegistry) JupyterServicesFile() string {
|
||||
|
|
@ -83,13 +87,13 @@ func detectRepoRoot() string {
|
|||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
|
||||
|
||||
// Walk up directory tree looking for go.mod
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
return dir
|
||||
}
|
||||
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
// Reached root
|
||||
|
|
@ -97,7 +101,7 @@ func detectRepoRoot() string {
|
|||
}
|
||||
dir = parent
|
||||
}
|
||||
|
||||
|
||||
return "."
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ func NewEnvSecretsManager() *EnvSecretsManager { return &EnvSecretsManager{} }
|
|||
|
||||
func (e *EnvSecretsManager) Get(ctx context.Context, key string) (string, error) {
|
||||
value := os.Getenv(key)
|
||||
if value == "" { return "", fmt.Errorf("secret %s not found", key) }
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("secret %s not found", key)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
|
|
@ -47,6 +49,8 @@ func (e *EnvSecretsManager) List(ctx context.Context, prefix string) ([]string,
|
|||
|
||||
// RedactSecret masks a secret for safe logging
|
||||
func RedactSecret(secret string) string {
|
||||
if len(secret) <= 8 { return "***" }
|
||||
if len(secret) <= 8 {
|
||||
return "***"
|
||||
}
|
||||
return secret[:4] + "..." + secret[len(secret)-4:]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ package domain
|
|||
type JobStatus string
|
||||
|
||||
const (
|
||||
StatusPending JobStatus = "pending"
|
||||
StatusQueued JobStatus = "queued"
|
||||
StatusRunning JobStatus = "running"
|
||||
StatusCompleted JobStatus = "completed"
|
||||
StatusFailed JobStatus = "failed"
|
||||
StatusPending JobStatus = "pending"
|
||||
StatusQueued JobStatus = "queued"
|
||||
StatusRunning JobStatus = "running"
|
||||
StatusCompleted JobStatus = "completed"
|
||||
StatusFailed JobStatus = "failed"
|
||||
)
|
||||
|
||||
// String returns the string representation of the status
|
||||
|
|
|
|||
|
|
@ -164,6 +164,15 @@ func (l *Logger) Job(ctx context.Context, job string, task string) *Logger {
|
|||
return l.WithContext(ctx, "job_name", job, "task_id", task)
|
||||
}
|
||||
|
||||
// TraceIDFromContext extracts the trace ID from a context.
|
||||
// Returns empty string if no trace ID is present.
|
||||
func TraceIDFromContext(ctx context.Context) string {
|
||||
if id, ok := ctx.Value(CtxTraceID).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Fatal logs an error message and exits with status 1.
|
||||
func (l *Logger) Fatal(msg string, args ...any) {
|
||||
l.Error(msg, args...)
|
||||
|
|
|
|||
|
|
@ -14,19 +14,19 @@ import (
|
|||
// EventStore provides append-only event storage using Redis Streams.
|
||||
// Events are stored chronologically and can be queried for audit trails.
|
||||
type EventStore struct {
|
||||
client *redis.Client
|
||||
ctx context.Context
|
||||
retentionDays int
|
||||
maxStreamLen int64
|
||||
client *redis.Client
|
||||
ctx context.Context
|
||||
retentionDays int
|
||||
maxStreamLen int64
|
||||
}
|
||||
|
||||
// EventStoreConfig holds configuration for the event store.
|
||||
type EventStoreConfig struct {
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RetentionDays int // How long to keep events (default: 7)
|
||||
MaxStreamLen int64 // Max events per task stream (default: 1000)
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RetentionDays int // How long to keep events (default: 7)
|
||||
MaxStreamLen int64 // Max events per task stream (default: 1000)
|
||||
}
|
||||
|
||||
// NewEventStore creates a new event store instance.
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ type metricEvent struct {
|
|||
}
|
||||
|
||||
// Config holds configuration for Queue
|
||||
type Config struct {
|
||||
type Config struct {
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
|
|
@ -48,9 +48,9 @@ func NewQueue(cfg Config) (*Queue, error) {
|
|||
}
|
||||
} else {
|
||||
opts = &redis.Options{
|
||||
Addr: cfg.RedisAddr,
|
||||
Password: cfg.RedisPassword,
|
||||
DB: cfg.RedisDB,
|
||||
Addr: cfg.RedisAddr,
|
||||
Password: cfg.RedisPassword,
|
||||
DB: cfg.RedisDB,
|
||||
PoolSize: 50,
|
||||
MinIdleConns: 10,
|
||||
MaxRetries: 3,
|
||||
|
|
|
|||
|
|
@ -17,14 +17,14 @@ type Metric struct {
|
|||
|
||||
// MetricSummary represents aggregated metric statistics
|
||||
type MetricSummary struct {
|
||||
Name string `json:"name"`
|
||||
Count int64 `json:"count"`
|
||||
Avg float64 `json:"avg"`
|
||||
Min float64 `json:"min"`
|
||||
Max float64 `json:"max"`
|
||||
Sum float64 `json:"sum"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Name string `json:"name"`
|
||||
Count int64 `json:"count"`
|
||||
Avg float64 `json:"avg"`
|
||||
Min float64 `json:"min"`
|
||||
Max float64 `json:"max"`
|
||||
Sum float64 `json:"sum"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
}
|
||||
|
||||
// RecordMetric records a metric to the database
|
||||
|
|
|
|||
106
internal/worker/errors/execution.go
Normal file
106
internal/worker/errors/execution.go
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
// Package errors provides structured error types for the worker.
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExecutionError provides structured error context for task execution failures.
|
||||
// It captures the task ID, execution phase, specific operation, root cause,
|
||||
// and additional context to make debugging easier.
|
||||
type ExecutionError struct {
|
||||
TaskID string // The task that failed
|
||||
Phase string // Current TaskState (queued, preparing, running, collecting)
|
||||
Operation string // Specific operation that failed (e.g., "create_workspace", "fetch_dataset")
|
||||
Cause error // The underlying error
|
||||
Context map[string]string // Additional context (paths, IDs, etc.)
|
||||
Timestamp time.Time // When the error occurred
|
||||
}
|
||||
|
||||
// Error implements the error interface with a formatted message.
|
||||
func (e ExecutionError) Error() string {
|
||||
return fmt.Sprintf("[%s/%s] task=%s: %v", e.Phase, e.Operation, e.TaskID, e.Cause)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error for error chain inspection.
|
||||
func (e ExecutionError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// WithContext adds a key-value pair to the error context.
|
||||
func (e ExecutionError) WithContext(key, value string) ExecutionError {
|
||||
if e.Context == nil {
|
||||
e.Context = make(map[string]string)
|
||||
}
|
||||
e.Context[key] = value
|
||||
return e
|
||||
}
|
||||
|
||||
// ContextString returns a formatted string of all context values.
|
||||
func (e ExecutionError) ContextString() string {
|
||||
if len(e.Context) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := ""
|
||||
for k, v := range e.Context {
|
||||
if result != "" {
|
||||
result += ", "
|
||||
}
|
||||
result += fmt.Sprintf("%s=%s", k, v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// NewExecutionError creates a new ExecutionError with the given parameters.
|
||||
func NewExecutionError(taskID, phase, operation string, cause error) ExecutionError {
|
||||
return ExecutionError{
|
||||
TaskID: taskID,
|
||||
Phase: phase,
|
||||
Operation: operation,
|
||||
Cause: cause,
|
||||
Context: make(map[string]string),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// IsExecutionError checks if an error is an ExecutionError.
|
||||
func IsExecutionError(err error) bool {
|
||||
_, ok := err.(ExecutionError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Common error operations for the worker lifecycle.
|
||||
const (
|
||||
// Preparation operations
|
||||
OpCreateWorkspace = "create_workspace"
|
||||
OpFetchDataset = "fetch_dataset"
|
||||
OpMountVolume = "mount_volume"
|
||||
OpSetupEnvironment = "setup_environment"
|
||||
OpStageSnapshot = "stage_snapshot"
|
||||
|
||||
// Execution operations
|
||||
OpStartContainer = "start_container"
|
||||
OpExecuteCommand = "execute_command"
|
||||
OpMonitorExecution = "monitor_execution"
|
||||
|
||||
// Collection operations
|
||||
OpCollectResults = "collect_results"
|
||||
OpUploadArtifacts = "upload_artifacts"
|
||||
OpCleanupWorkspace = "cleanup_workspace"
|
||||
|
||||
// Validation operations
|
||||
OpValidateManifest = "validate_manifest"
|
||||
OpCheckResources = "check_resources"
|
||||
OpVerifyProvenance = "verify_provenance"
|
||||
)
|
||||
|
||||
// Common phase names (should match TaskState values).
|
||||
const (
|
||||
PhaseQueued = "queued"
|
||||
PhasePreparing = "preparing"
|
||||
PhaseRunning = "running"
|
||||
PhaseCollecting = "collecting"
|
||||
PhaseCompleted = "completed"
|
||||
PhaseFailed = "failed"
|
||||
)
|
||||
|
|
@ -126,6 +126,9 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
cfg.LocalMode,
|
||||
)
|
||||
|
||||
// Create state manager for task lifecycle management
|
||||
stateMgr := lifecycle.NewStateManager(nil) // Can pass audit logger if available
|
||||
|
||||
// Create job runner
|
||||
jobRunner := executor.NewJobRunner(
|
||||
localExecutor,
|
||||
|
|
@ -140,6 +143,7 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
exec,
|
||||
metricsObj,
|
||||
logger,
|
||||
stateMgr,
|
||||
)
|
||||
|
||||
// Create resource manager
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ type RunLoop struct {
|
|||
executor TaskExecutor
|
||||
metrics MetricsRecorder
|
||||
logger Logger
|
||||
stateMgr *StateManager
|
||||
|
||||
// State management
|
||||
running map[string]context.CancelFunc
|
||||
|
|
@ -64,6 +65,7 @@ func NewRunLoop(
|
|||
executor TaskExecutor,
|
||||
metrics MetricsRecorder,
|
||||
logger Logger,
|
||||
stateMgr *StateManager,
|
||||
) *RunLoop {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &RunLoop{
|
||||
|
|
@ -72,6 +74,7 @@ func NewRunLoop(
|
|||
executor: executor,
|
||||
metrics: metrics,
|
||||
logger: logger,
|
||||
stateMgr: stateMgr,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
|
|
@ -185,17 +188,45 @@ func (r *RunLoop) waitForTasks() {
|
|||
func (r *RunLoop) executeTask(task *queue.Task) {
|
||||
defer r.releaseRunningSlot(task.ID)
|
||||
|
||||
// Transition to preparing state
|
||||
if r.stateMgr != nil {
|
||||
if err := r.stateMgr.Transition(task, StatePreparing); err != nil {
|
||||
r.logger.Error("failed to transition task state", "task_id", task.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.metrics.RecordTaskStart()
|
||||
defer r.metrics.RecordTaskCompletion()
|
||||
|
||||
r.logger.Info("starting task", "task_id", task.ID, "job_name", task.JobName)
|
||||
|
||||
// Transition to running state
|
||||
if r.stateMgr != nil {
|
||||
if err := r.stateMgr.Transition(task, StateRunning); err != nil {
|
||||
r.logger.Error("failed to transition task state", "task_id", task.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
taskCtx, cancel := context.WithTimeout(r.ctx, 24*time.Hour)
|
||||
defer cancel()
|
||||
|
||||
if err := r.executor.Execute(taskCtx, task); err != nil {
|
||||
r.logger.Error("task execution failed", "task_id", task.ID, "error", err)
|
||||
r.metrics.RecordTaskFailure()
|
||||
|
||||
// Transition to failed state
|
||||
if r.stateMgr != nil {
|
||||
if stateErr := r.stateMgr.Transition(task, StateFailed); stateErr != nil {
|
||||
r.logger.Error("failed to transition task to failed state", "task_id", task.ID, "error", stateErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Transition to completed state
|
||||
if r.stateMgr != nil {
|
||||
if err := r.stateMgr.Transition(task, StateCompleted); err != nil {
|
||||
r.logger.Error("failed to transition task to completed state", "task_id", task.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = r.queue.ReleaseLease(task.ID, r.config.WorkerID)
|
||||
|
|
|
|||
130
internal/worker/lifecycle/states.go
Normal file
130
internal/worker/lifecycle/states.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
// Package lifecycle provides task lifecycle management with explicit state transitions.
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
)
|
||||
|
||||
// TaskState represents the current state of a task in its lifecycle.
|
||||
// These states form a finite state machine with valid transitions defined below.
|
||||
type TaskState string
|
||||
|
||||
const (
|
||||
// StateQueued indicates the task is waiting to be picked up by a worker.
|
||||
StateQueued TaskState = "queued"
|
||||
// StatePreparing indicates the task is being prepared (workspace setup, data staging).
|
||||
StatePreparing TaskState = "preparing"
|
||||
// StateRunning indicates the task is currently executing in a container.
|
||||
StateRunning TaskState = "running"
|
||||
// StateCollecting indicates the task has finished execution and results are being collected.
|
||||
StateCollecting TaskState = "collecting"
|
||||
// StateCompleted indicates the task finished successfully.
|
||||
StateCompleted TaskState = "completed"
|
||||
// StateFailed indicates the task failed during execution.
|
||||
StateFailed TaskState = "failed"
|
||||
)
|
||||
|
||||
// ValidTransitions defines the allowed state transitions.
|
||||
// The key is the "from" state, the value is a list of valid "to" states.
|
||||
// This enforces that state transitions follow the expected lifecycle.
|
||||
var ValidTransitions = map[TaskState][]TaskState{
|
||||
StateQueued: {StatePreparing, StateFailed},
|
||||
StatePreparing: {StateRunning, StateFailed},
|
||||
StateRunning: {StateCollecting, StateFailed},
|
||||
StateCollecting: {StateCompleted, StateFailed},
|
||||
StateCompleted: {},
|
||||
StateFailed: {},
|
||||
}
|
||||
|
||||
// StateTransitionError is returned when an invalid state transition is attempted.
|
||||
type StateTransitionError struct {
|
||||
From TaskState
|
||||
To TaskState
|
||||
}
|
||||
|
||||
func (e StateTransitionError) Error() string {
|
||||
return fmt.Sprintf("invalid state transition: %s -> %s", e.From, e.To)
|
||||
}
|
||||
|
||||
// StateManager manages task state transitions with audit logging.
|
||||
type StateManager struct {
|
||||
enabled bool
|
||||
auditor *audit.Logger
|
||||
}
|
||||
|
||||
// NewStateManager creates a new state manager with the given audit logger.
|
||||
func NewStateManager(auditor *audit.Logger) *StateManager {
|
||||
return &StateManager{
|
||||
enabled: auditor != nil,
|
||||
auditor: auditor,
|
||||
}
|
||||
}
|
||||
|
||||
// Transition attempts to transition a task from its current state to a new state.
|
||||
// It validates the transition, updates the task status, and logs the event.
|
||||
// Returns StateTransitionError if the transition is not valid.
|
||||
func (sm *StateManager) Transition(task *domain.Task, to TaskState) error {
|
||||
from := TaskState(task.Status)
|
||||
|
||||
// Validate the transition
|
||||
if err := sm.validateTransition(from, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Audit the transition before updating
|
||||
if sm.enabled && sm.auditor != nil {
|
||||
sm.auditor.Log(audit.Event{
|
||||
EventType: audit.EventJobStarted,
|
||||
Timestamp: time.Now(),
|
||||
Resource: task.ID,
|
||||
Action: "task_state_change",
|
||||
Success: true,
|
||||
Metadata: map[string]interface{}{
|
||||
"job_name": task.JobName,
|
||||
"old_state": string(from),
|
||||
"new_state": string(to),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Update task state
|
||||
task.Status = string(to)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTransition checks if a transition from one state to another is valid.
|
||||
func (sm *StateManager) validateTransition(from, to TaskState) error {
|
||||
// Check if "from" state is valid
|
||||
allowed, ok := ValidTransitions[from]
|
||||
if !ok {
|
||||
return StateTransitionError{From: from, To: to}
|
||||
}
|
||||
|
||||
// Check if "to" state is in the allowed list
|
||||
if slices.Contains(allowed, to) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return StateTransitionError{From: from, To: to}
|
||||
}
|
||||
|
||||
// IsTerminalState returns true if the state is terminal (no further transitions allowed).
|
||||
func IsTerminalState(state TaskState) bool {
|
||||
return state == StateCompleted || state == StateFailed
|
||||
}
|
||||
|
||||
// CanTransition returns true if a transition from -> to is valid.
|
||||
func CanTransition(from, to TaskState) bool {
|
||||
allowed, ok := ValidTransitions[from]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return slices.Contains(allowed, to)
|
||||
}
|
||||
|
|
@ -18,11 +18,11 @@ func BenchmarkDatasetSizeComparison(b *testing.B) {
|
|||
numFiles int
|
||||
totalMB int
|
||||
}{
|
||||
{"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB
|
||||
{"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB
|
||||
{"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB
|
||||
{"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB
|
||||
{"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB
|
||||
{"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB
|
||||
{"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB
|
||||
{"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB
|
||||
{"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB
|
||||
{"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB
|
||||
}
|
||||
|
||||
for _, tc := range sizes {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
// - Regex matching is CPU-intensive
|
||||
// - High volume log pipelines process thousands of messages/sec
|
||||
// - C++ can use Hyperscan/RE2 for parallel regex matching
|
||||
//
|
||||
// Expected speedup: 3-5x for high-volume logging
|
||||
func BenchmarkLogSanitizeMessage(b *testing.B) {
|
||||
// Test messages with various sensitive data patterns
|
||||
|
|
|
|||
|
|
@ -1,68 +1,68 @@
|
|||
package worker_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
func TestScanArtifacts_SkipsKnownPathsAndLogs(t *testing.T) {
|
||||
runDir := t.TempDir()
|
||||
runDir := t.TempDir()
|
||||
|
||||
mustWrite := func(rel string, data []byte) {
|
||||
p := filepath.Join(runDir, rel)
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0750); err != nil {
|
||||
t.Fatalf("mkdir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(p, data, 0600); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
}
|
||||
mustWrite := func(rel string, data []byte) {
|
||||
p := filepath.Join(runDir, rel)
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0750); err != nil {
|
||||
t.Fatalf("mkdir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(p, data, 0600); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
mustWrite("run_manifest.json", []byte("{}"))
|
||||
mustWrite("output.log", []byte("log"))
|
||||
mustWrite("code/ignored.txt", []byte("ignore"))
|
||||
mustWrite("snapshot/ignored.bin", []byte("ignore"))
|
||||
mustWrite("run_manifest.json", []byte("{}"))
|
||||
mustWrite("output.log", []byte("log"))
|
||||
mustWrite("code/ignored.txt", []byte("ignore"))
|
||||
mustWrite("snapshot/ignored.bin", []byte("ignore"))
|
||||
|
||||
mustWrite("results/metrics.jsonl", []byte("m"))
|
||||
mustWrite("checkpoints/best.pt", []byte("checkpoint"))
|
||||
mustWrite("plots/loss.png", []byte("png"))
|
||||
mustWrite("results/metrics.jsonl", []byte("m"))
|
||||
mustWrite("checkpoints/best.pt", []byte("checkpoint"))
|
||||
mustWrite("plots/loss.png", []byte("png"))
|
||||
|
||||
art, err := worker.ScanArtifacts(runDir)
|
||||
if err != nil {
|
||||
t.Fatalf("scanArtifacts: %v", err)
|
||||
}
|
||||
if art == nil {
|
||||
t.Fatalf("expected artifacts")
|
||||
}
|
||||
art, err := worker.ScanArtifacts(runDir)
|
||||
if err != nil {
|
||||
t.Fatalf("scanArtifacts: %v", err)
|
||||
}
|
||||
if art == nil {
|
||||
t.Fatalf("expected artifacts")
|
||||
}
|
||||
|
||||
paths := make([]string, 0, len(art.Files))
|
||||
var total int64
|
||||
for _, f := range art.Files {
|
||||
paths = append(paths, f.Path)
|
||||
total += f.SizeBytes
|
||||
}
|
||||
paths := make([]string, 0, len(art.Files))
|
||||
var total int64
|
||||
for _, f := range art.Files {
|
||||
paths = append(paths, f.Path)
|
||||
total += f.SizeBytes
|
||||
}
|
||||
|
||||
want := []string{
|
||||
"checkpoints/best.pt",
|
||||
"plots/loss.png",
|
||||
"results/metrics.jsonl",
|
||||
}
|
||||
if len(paths) != len(want) {
|
||||
t.Fatalf("expected %d files, got %d: %v", len(want), len(paths), paths)
|
||||
}
|
||||
for i := range want {
|
||||
if paths[i] != want[i] {
|
||||
t.Fatalf("expected paths[%d]=%q, got %q", i, want[i], paths[i])
|
||||
}
|
||||
}
|
||||
want := []string{
|
||||
"checkpoints/best.pt",
|
||||
"plots/loss.png",
|
||||
"results/metrics.jsonl",
|
||||
}
|
||||
if len(paths) != len(want) {
|
||||
t.Fatalf("expected %d files, got %d: %v", len(want), len(paths), paths)
|
||||
}
|
||||
for i := range want {
|
||||
if paths[i] != want[i] {
|
||||
t.Fatalf("expected paths[%d]=%q, got %q", i, want[i], paths[i])
|
||||
}
|
||||
}
|
||||
|
||||
if art.TotalSizeBytes != total {
|
||||
t.Fatalf("expected total_size_bytes=%d, got %d", total, art.TotalSizeBytes)
|
||||
}
|
||||
if art.DiscoveryTime.IsZero() {
|
||||
t.Fatalf("expected discovery_time")
|
||||
}
|
||||
if art.TotalSizeBytes != total {
|
||||
t.Fatalf("expected total_size_bytes=%d, got %d", total, art.TotalSizeBytes)
|
||||
}
|
||||
if art.DiscoveryTime.IsZero() {
|
||||
t.Fatalf("expected discovery_time")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue