From 420de879ff61f307a8bc2f22847166ec8692d948 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 26 Feb 2026 12:05:57 -0500 Subject: [PATCH] feat(api): integrate scheduler protocol and WebSocket enhancements Update API layer for scheduler integration: - WebSocket handlers with scheduler protocol support - Jobs WebSocket endpoint with priority queue integration - Validation middleware for scheduler messages - Server configuration with security hardening - Protocol definitions for worker-scheduler communication - Dataset handlers with tenant isolation checks - Response helpers with audit context - OpenAPI spec updates for new endpoints --- api/openapi.yaml | 40 +------ internal/api/datasets/handlers.go | 31 +++--- internal/api/health.go | 2 +- internal/api/helpers/db_helpers.go | 8 +- internal/api/helpers/experiment_setup.go | 14 ++- internal/api/helpers/hash_helpers.go | 6 +- internal/api/helpers/response_helpers.go | 10 +- internal/api/helpers/validation_helpers.go | 8 +- internal/api/monitoring_config.go | 8 +- internal/api/protocol.go | 38 +++---- internal/api/routes.go | 4 +- internal/api/server.go | 8 +- internal/api/server_config.go | 26 ++--- internal/api/ws/handler.go | 101 ++++++++++++------ internal/api/ws/jobs.go | 115 ++++++++++++--------- internal/api/ws/validate.go | 50 +++++---- internal/config/resources.go | 4 +- internal/config/security.go | 28 ++--- internal/config/shared.go | 2 +- 19 files changed, 259 insertions(+), 244 deletions(-) diff --git a/api/openapi.yaml b/api/openapi.yaml index 935ce49..c75eddc 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -1,13 +1,14 @@ +--- openapi: 3.0.3 info: title: ML Worker API description: | API for managing ML experiment tasks and Jupyter services. - + ## Security All endpoints (except health checks) require API key authentication via the `X-API-Key` header. Rate limiting is enforced per API key. - + ## Error Handling Errors follow a consistent format with machine-readable codes and trace IDs: ```json @@ -20,16 +21,13 @@ info: version: 1.0.0 contact: name: FetchML Support - servers: - url: http://localhost:9101 description: Local development server - url: https://api.fetchml.example.com description: Production server - security: - ApiKeyAuth: [] - paths: /health: get: @@ -43,7 +41,6 @@ paths: application/json: schema: $ref: '#/components/schemas/HealthResponse' - /v1/tasks: get: summary: List tasks @@ -78,7 +75,6 @@ paths: $ref: '#/components/responses/Unauthorized' '429': $ref: '#/components/responses/RateLimited' - post: summary: Create task description: Submit a new ML experiment task @@ -103,7 +99,6 @@ paths: $ref: '#/components/responses/ValidationError' '429': $ref: '#/components/responses/RateLimited' - /v1/tasks/{taskId}: get: summary: Get task details @@ -122,7 +117,6 @@ paths: $ref: '#/components/schemas/Task' '404': $ref: '#/components/responses/NotFound' - delete: summary: Cancel/delete task parameters: @@ -136,7 +130,6 @@ paths: description: Task cancelled '404': $ref: '#/components/responses/NotFound' - /v1/queue: get: summary: Queue status @@ -148,7 +141,6 @@ paths: application/json: schema: $ref: '#/components/schemas/QueueStats' - /v1/experiments: get: summary: List experiments @@ -162,7 +154,6 @@ paths: type: array items: $ref: '#/components/schemas/Experiment' - post: summary: Create experiment description: Create a new experiment @@ -179,7 +170,6 @@ paths: application/json: schema: $ref: '#/components/schemas/Experiment' - /v1/jupyter/services: get: summary: List Jupyter services @@ -192,7 +182,6 @@ paths: type: array items: $ref: '#/components/schemas/JupyterService' - post: summary: Start Jupyter service requestBody: @@ -208,7 +197,6 @@ paths: application/json: schema: $ref: '#/components/schemas/JupyterService' - /v1/jupyter/services/{serviceId}: delete: summary: Stop Jupyter service @@ -221,13 +209,12 @@ paths: responses: '204': description: Service stopped - /ws: get: summary: WebSocket connection description: | WebSocket endpoint for real-time task updates. - + ## Message Types - `task_update`: Task status changes - `task_complete`: Task finished @@ -237,7 +224,6 @@ paths: responses: '101': description: WebSocket connection established - components: securitySchemes: ApiKeyAuth: @@ -245,7 +231,6 @@ components: in: header name: X-API-Key description: API key for authentication - schemas: HealthResponse: type: object @@ -258,7 +243,6 @@ components: timestamp: type: string format: date-time - Task: type: object properties: @@ -310,7 +294,6 @@ components: type: integer max_retries: type: integer - CreateTaskRequest: type: object required: @@ -353,7 +336,6 @@ components: type: object additionalProperties: type: string - DatasetSpec: type: object properties: @@ -365,7 +347,6 @@ components: type: string mount_path: type: string - TaskList: type: object properties: @@ -379,7 +360,6 @@ components: type: integer offset: type: integer - QueueStats: type: object properties: @@ -398,7 +378,6 @@ components: workers: type: integer description: Active workers - Experiment: type: object properties: @@ -414,7 +393,6 @@ components: status: type: string enum: [active, archived, deleted] - CreateExperimentRequest: type: object required: @@ -425,7 +403,6 @@ components: maxLength: 128 description: type: string - JupyterService: type: object properties: @@ -444,7 +421,6 @@ components: created_at: type: string format: date-time - StartJupyterRequest: type: object required: @@ -457,7 +433,6 @@ components: image: type: string default: jupyter/pytorch:latest - ErrorResponse: type: object required: @@ -474,7 +449,6 @@ components: trace_id: type: string description: Support correlation ID - responses: BadRequest: description: Invalid request @@ -486,7 +460,6 @@ components: error: Invalid request format code: BAD_REQUEST trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890 - Unauthorized: description: Authentication required content: @@ -497,7 +470,6 @@ components: error: Invalid or missing API key code: UNAUTHORIZED trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890 - Forbidden: description: Insufficient permissions content: @@ -508,7 +480,6 @@ components: error: Insufficient permissions code: FORBIDDEN trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890 - NotFound: description: Resource not found content: @@ -519,7 +490,6 @@ components: error: Resource not found code: NOT_FOUND trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890 - ValidationError: description: Validation failed content: @@ -530,7 +500,6 @@ components: error: Validation failed code: VALIDATION_ERROR trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890 - RateLimited: description: Too many requests content: @@ -546,7 +515,6 @@ components: schema: type: integer description: Seconds until rate limit resets - InternalError: description: Internal server error content: diff --git a/internal/api/datasets/handlers.go b/internal/api/datasets/handlers.go index f5a83b0..ef81b5f 100644 --- a/internal/api/datasets/handlers.go +++ b/internal/api/datasets/handlers.go @@ -42,24 +42,23 @@ const ( ) // sendErrorPacket sends an error response packet to the client -func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { - err := map[string]interface{}{ +func sendErrorPacket(conn *websocket.Conn, message string) error { + err := map[string]any{ "error": true, - "code": code, + "code": ErrorCodeInvalidRequest, "message": message, - "details": details, } return conn.WriteJSON(err) } // sendSuccessPacket sends a success response packet -func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error { +func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error { return conn.WriteJSON(data) } // sendDataPacket sends a data response packet func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error { - return conn.WriteJSON(map[string]interface{}{ + return conn.WriteJSON(map[string]any{ "type": dataType, "payload": string(payload), }) @@ -86,9 +85,11 @@ func (h *Handler) HandleDatasetList(conn *websocket.Conn, payload []byte, user * // HandleDatasetRegister handles registering a new dataset // Protocol: [api_key_hash:16][name_len:1][name:var][path_len:2][path:var] -func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, user *auth.User) error { +func (h *Handler) HandleDatasetRegister( + conn *websocket.Conn, payload []byte, user *auth.User, +) error { if len(payload) < 16+1+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "register dataset payload too short", "") + return sendErrorPacket(conn, "register dataset payload too short") } offset := 16 @@ -96,7 +97,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us nameLen := int(payload[offset]) offset++ if nameLen <= 0 || len(payload) < offset+nameLen+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") + return sendErrorPacket(conn, "invalid name length") } name := string(payload[offset : offset+nameLen]) offset += nameLen @@ -104,7 +105,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us pathLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) offset += 2 if pathLen < 0 || len(payload) < offset+pathLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid path length", "") + return sendErrorPacket(conn, "invalid path length") } path := string(payload[offset : offset+pathLen]) @@ -121,7 +122,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us } } - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "name": name, "path": path, @@ -134,7 +135,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us // Protocol: [api_key_hash:16][dataset_id_len:1][dataset_id:var] func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *auth.User) error { if len(payload) < 16+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "") + return sendErrorPacket(conn, "dataset info payload too short") } offset := 16 @@ -142,7 +143,7 @@ func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user * datasetIDLen := int(payload[offset]) offset++ if datasetIDLen <= 0 || len(payload) < offset+datasetIDLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset ID length", "") + return sendErrorPacket(conn, "invalid dataset ID length") } datasetID := string(payload[offset : offset+datasetIDLen]) @@ -167,7 +168,7 @@ func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user * // Protocol: [api_key_hash:16][query_len:2][query:var] func (h *Handler) HandleDatasetSearch(conn *websocket.Conn, payload []byte, user *auth.User) error { if len(payload) < 16+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "") + return sendErrorPacket(conn, "dataset search payload too short") } offset := 16 @@ -175,7 +176,7 @@ func (h *Handler) HandleDatasetSearch(conn *websocket.Conn, payload []byte, user queryLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) offset += 2 if queryLen < 0 || len(payload) < offset+queryLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query length", "") + return sendErrorPacket(conn, "invalid query length") } query := string(payload[offset : offset+queryLen]) diff --git a/internal/api/health.go b/internal/api/health.go index 282a339..4d5e12a 100644 --- a/internal/api/health.go +++ b/internal/api/health.go @@ -8,9 +8,9 @@ import ( // HealthStatus represents the health status of the service type HealthStatus struct { - Status string `json:"status"` Timestamp time.Time `json:"timestamp"` Checks map[string]string `json:"checks,omitempty"` + Status string `json:"status"` } // HealthHandler handles /health requests diff --git a/internal/api/helpers/db_helpers.go b/internal/api/helpers/db_helpers.go index 04213ad..e9b3b1c 100644 --- a/internal/api/helpers/db_helpers.go +++ b/internal/api/helpers/db_helpers.go @@ -3,6 +3,7 @@ package helpers import ( "context" + "slices" "time" ) @@ -29,12 +30,7 @@ func DBContextLong() (context.Context, context.CancelFunc) { // StringSliceContains checks if a string slice contains a specific string. func StringSliceContains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false + return slices.Contains(slice, item) } // StringSliceFilter filters a string slice based on a predicate. diff --git a/internal/api/helpers/experiment_setup.go b/internal/api/helpers/experiment_setup.go index 2ce3d90..67996c3 100644 --- a/internal/api/helpers/experiment_setup.go +++ b/internal/api/helpers/experiment_setup.go @@ -15,9 +15,9 @@ import ( // ExperimentSetupResult contains the result of experiment setup operations type ExperimentSetupResult struct { - CommitIDStr string - Manifest *experiment.Manifest Err error + Manifest *experiment.Manifest + CommitIDStr string } // RunExperimentSetup performs the common experiment setup operations: @@ -149,12 +149,14 @@ func UpsertExperimentDBAsync( // TaskEnqueueResult contains the result of task enqueueing type TaskEnqueueResult struct { - TaskID string Err error + TaskID string } // BuildTaskMetadata creates the standard task metadata map. -func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[string]string) map[string]string { +func BuildTaskMetadata( + commitIDStr, datasetID, paramsHash string, prov map[string]string, +) map[string]string { meta := map[string]string{ "commit_id": commitIDStr, "dataset_id": datasetID, @@ -169,7 +171,9 @@ func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[strin } // BuildSnapshotTaskMetadata creates task metadata for snapshot jobs. -func BuildSnapshotTaskMetadata(commitIDStr, snapshotSHA string, prov map[string]string) map[string]string { +func BuildSnapshotTaskMetadata( + commitIDStr, snapshotSHA string, prov map[string]string, +) map[string]string { meta := map[string]string{ "commit_id": commitIDStr, "snapshot_sha256": snapshotSHA, diff --git a/internal/api/helpers/hash_helpers.go b/internal/api/helpers/hash_helpers.go index a980a73..7c2b982 100644 --- a/internal/api/helpers/hash_helpers.go +++ b/internal/api/helpers/hash_helpers.go @@ -99,20 +99,20 @@ func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) e return fmt.Errorf("missing commit id") } filesPath := expMgr.GetFilesPath(commitID) - if err := os.MkdirAll(filesPath, 0750); err != nil { + if err := os.MkdirAll(filesPath, 0o750); err != nil { return err } trainPath := filepath.Join(filesPath, "train.py") if _, err := os.Stat(trainPath); os.IsNotExist(err) { - if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil { + if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0o640); err != nil { return err } } reqPath := filepath.Join(filesPath, "requirements.txt") if _, err := os.Stat(reqPath); os.IsNotExist(err) { - if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil { + if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0o640); err != nil { return err } } diff --git a/internal/api/helpers/response_helpers.go b/internal/api/helpers/response_helpers.go index 40ec0a1..94e250d 100644 --- a/internal/api/helpers/response_helpers.go +++ b/internal/api/helpers/response_helpers.go @@ -96,10 +96,10 @@ func (m *TaskErrorMapper) MapJupyterError(t *queue.Task) ErrorCode { // ResourceRequest represents resource requirements type ResourceRequest struct { + GPUMemory string CPU int MemoryGB int GPU int - GPUMemory string } // ParseResourceRequest parses an optional resource request from bytes. @@ -128,11 +128,11 @@ func ParseResourceRequest(payload []byte) (*ResourceRequest, error) { // JSONResponseBuilder helps build JSON data responses type JSONResponseBuilder struct { - data interface{} + data any } // NewJSONResponseBuilder creates a new JSON response builder -func NewJSONResponseBuilder(data interface{}) *JSONResponseBuilder { +func NewJSONResponseBuilder(data any) *JSONResponseBuilder { return &JSONResponseBuilder{data: data} } @@ -161,7 +161,7 @@ func IntPtr(i int) *int { } // MarshalJSONOrEmpty marshals data to JSON or returns empty array on error -func MarshalJSONOrEmpty(data interface{}) []byte { +func MarshalJSONOrEmpty(data any) []byte { b, err := json.Marshal(data) if err != nil { return []byte("[]") @@ -170,7 +170,7 @@ func MarshalJSONOrEmpty(data interface{}) []byte { } // MarshalJSONBytes marshals data to JSON bytes with error handling -func MarshalJSONBytes(data interface{}) ([]byte, error) { +func MarshalJSONBytes(data any) ([]byte, error) { return json.Marshal(data) } diff --git a/internal/api/helpers/validation_helpers.go b/internal/api/helpers/validation_helpers.go index 6812680..93b7822 100644 --- a/internal/api/helpers/validation_helpers.go +++ b/internal/api/helpers/validation_helpers.go @@ -53,21 +53,21 @@ func ValidateDepsManifest( // ValidateCheck represents a validation check result type ValidateCheck struct { - OK bool `json:"ok"` Expected string `json:"expected,omitempty"` Actual string `json:"actual,omitempty"` Details string `json:"details,omitempty"` + OK bool `json:"ok"` } // ValidateReport represents a validation report type ValidateReport struct { - OK bool `json:"ok"` + Checks map[string]ValidateCheck `json:"checks"` CommitID string `json:"commit_id,omitempty"` TaskID string `json:"task_id,omitempty"` - Checks map[string]ValidateCheck `json:"checks"` + TS string `json:"ts"` Errors []string `json:"errors,omitempty"` Warnings []string `json:"warnings,omitempty"` - TS string `json:"ts"` + OK bool `json:"ok"` } // NewValidateReport creates a new validation report diff --git a/internal/api/monitoring_config.go b/internal/api/monitoring_config.go index b334619..f91285e 100644 --- a/internal/api/monitoring_config.go +++ b/internal/api/monitoring_config.go @@ -2,19 +2,19 @@ package api // MonitoringConfig holds monitoring-related configuration type MonitoringConfig struct { - Prometheus PrometheusConfig `yaml:"prometheus"` HealthChecks HealthChecksConfig `yaml:"health_checks"` + Prometheus PrometheusConfig `yaml:"prometheus"` } // PrometheusConfig holds Prometheus metrics configuration type PrometheusConfig struct { - Enabled bool `yaml:"enabled"` - Port int `yaml:"port"` Path string `yaml:"path"` + Port int `yaml:"port"` + Enabled bool `yaml:"enabled"` } // HealthChecksConfig holds health check configuration type HealthChecksConfig struct { - Enabled bool `yaml:"enabled"` Interval string `yaml:"interval"` + Enabled bool `yaml:"enabled"` } diff --git a/internal/api/protocol.go b/internal/api/protocol.go index 5411015..1147969 100644 --- a/internal/api/protocol.go +++ b/internal/api/protocol.go @@ -70,33 +70,21 @@ const ( // ResponsePacket represents a structured response packet type ResponsePacket struct { - PacketType byte - Timestamp uint64 - - // Success fields - SuccessMessage string - - // Error fields - ErrorCode byte - ErrorMessage string - ErrorDetails string - - // Progress fields - ProgressType byte + DataType string + SuccessMessage string + LogMessage string + ErrorMessage string + ErrorDetails string + ProgressMessage string + StatusData string + DataPayload []byte + Timestamp uint64 ProgressValue uint32 ProgressTotal uint32 - ProgressMessage string - - // Status fields - StatusData string - - // Data fields - DataType string - DataPayload []byte - - // Log fields - LogLevel byte - LogMessage string + ErrorCode byte + ProgressType byte + LogLevel byte + PacketType byte } // NewSuccessPacket creates a success response packet diff --git a/internal/api/routes.go b/internal/api/routes.go index d6dc88d..ebeceb2 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -105,11 +105,9 @@ func (s *Server) registerOpenAPIRoutes(mux *http.ServeMux, jobsHandler *jobs.Han e.ServeHTTP(w, r) }) - // Register Echo router at /v1/ prefix (and other generated paths) + // Register Echo router at /v1/ prefix // These paths take precedence over legacy routes - mux.Handle("/health", echoHandler) mux.Handle("/v1/", echoHandler) - mux.Handle("/ws", echoHandler) s.logger.Info("OpenAPI-generated routes registered with Echo router") } diff --git a/internal/api/server.go b/internal/api/server.go index e95f90c..62c4a34 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -21,18 +21,18 @@ import ( // Server represents the API server type Server struct { + taskQueue queue.Backend config *ServerConfig httpServer *http.Server logger *logging.Logger expManager *experiment.Manager - taskQueue queue.Backend db *storage.DB sec *middleware.SecurityMiddleware - cleanupFuncs []func() jupyterServiceMgr *jupyter.ServiceManager auditLogger *audit.Logger - promMetrics *prommetrics.Metrics // Prometheus metrics - validationMiddleware *apimiddleware.ValidationMiddleware // OpenAPI validation + promMetrics *prommetrics.Metrics + validationMiddleware *apimiddleware.ValidationMiddleware + cleanupFuncs []func() } // NewServer creates a new API server diff --git a/internal/api/server_config.go b/internal/api/server_config.go index 6d78f7d..9752ea7 100644 --- a/internal/api/server_config.go +++ b/internal/api/server_config.go @@ -23,17 +23,17 @@ type QueueConfig struct { // ServerConfig holds all server configuration type ServerConfig struct { + Logging logging.Config `yaml:"logging"` BasePath string `yaml:"base_path"` DataDir string `yaml:"data_dir"` Auth auth.Config `yaml:"auth"` + Database DatabaseConfig `yaml:"database"` Server ServerSection `yaml:"server"` - Security SecurityConfig `yaml:"security"` Monitoring MonitoringConfig `yaml:"monitoring"` Queue QueueConfig `yaml:"queue"` Redis RedisConfig `yaml:"redis"` - Database DatabaseConfig `yaml:"database"` - Logging logging.Config `yaml:"logging"` Resources config.ResourceConfig `yaml:"resources"` + Security SecurityConfig `yaml:"security"` } // ServerSection holds server-specific configuration @@ -44,26 +44,26 @@ type ServerSection struct { // TLSConfig holds TLS configuration type TLSConfig struct { - Enabled bool `yaml:"enabled"` CertFile string `yaml:"cert_file"` KeyFile string `yaml:"key_file"` + Enabled bool `yaml:"enabled"` } // SecurityConfig holds security-related configuration type SecurityConfig struct { - ProductionMode bool `yaml:"production_mode"` - AllowedOrigins []string `yaml:"allowed_origins"` - APIKeyRotationDays int `yaml:"api_key_rotation_days"` AuditLogging AuditLog `yaml:"audit_logging"` - RateLimit RateLimitConfig `yaml:"rate_limit"` + AllowedOrigins []string `yaml:"allowed_origins"` IPWhitelist []string `yaml:"ip_whitelist"` FailedLockout LockoutConfig `yaml:"failed_login_lockout"` + RateLimit RateLimitConfig `yaml:"rate_limit"` + APIKeyRotationDays int `yaml:"api_key_rotation_days"` + ProductionMode bool `yaml:"production_mode"` } // AuditLog holds audit logging configuration type AuditLog struct { - Enabled bool `yaml:"enabled"` LogPath string `yaml:"log_path"` + Enabled bool `yaml:"enabled"` } // RateLimitConfig holds rate limiting configuration @@ -75,17 +75,17 @@ type RateLimitConfig struct { // LockoutConfig holds failed login lockout configuration type LockoutConfig struct { - Enabled bool `yaml:"enabled"` - MaxAttempts int `yaml:"max_attempts"` LockoutDuration string `yaml:"lockout_duration"` + MaxAttempts int `yaml:"max_attempts"` + Enabled bool `yaml:"enabled"` } // RedisConfig holds Redis connection configuration type RedisConfig struct { Addr string `yaml:"addr"` Password string `yaml:"password"` - DB int `yaml:"db"` URL string `yaml:"url"` + DB int `yaml:"db"` } // DatabaseConfig holds database connection configuration @@ -93,10 +93,10 @@ type DatabaseConfig struct { Type string `yaml:"type"` Connection string `yaml:"connection"` Host string `yaml:"host"` - Port int `yaml:"port"` Username string `yaml:"username"` Password string `yaml:"password"` Database string `yaml:"database"` + Port int `yaml:"port"` } // LoadServerConfig loads and validates server configuration diff --git a/internal/api/ws/handler.go b/internal/api/ws/handler.go index daafaa0..ad56f5d 100644 --- a/internal/api/ws/handler.go +++ b/internal/api/ws/handler.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "path/filepath" + "slices" "strings" "sync" "time" @@ -123,30 +124,28 @@ const ( // Client represents a connected WebSocket client type Client struct { conn *websocket.Conn - Type ClientType User string RemoteAddr string + Type ClientType } // Handler provides WebSocket handling type Handler struct { - authConfig *auth.Config + taskQueue queue.Backend + datasetsHandler *datasets.Handler logger *logging.Logger expManager *experiment.Manager - dataDir string - taskQueue queue.Backend + clients map[*Client]bool db *storage.DB jupyterServiceMgr *jupyter.ServiceManager securityCfg *config.SecurityConfig auditLogger *audit.Logger - upgrader websocket.Upgrader + authConfig *auth.Config jobsHandler *jobs.Handler jupyterHandler *jupyterj.Handler - datasetsHandler *datasets.Handler - - // Client management for push updates - clients map[*Client]bool - clientsMu sync.RWMutex + upgrader websocket.Upgrader + dataDir string + clientsMu sync.RWMutex } // NewHandler creates a new WebSocket handler @@ -195,12 +194,7 @@ func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader { // Production mode: strict checking against allowed origins if securityCfg != nil && securityCfg.ProductionMode { - for _, allowed := range securityCfg.AllowedOrigins { - if origin == allowed { - return true - } - } - return false // Reject if not in allowed list + return slices.Contains(securityCfg.AllowedOrigins, origin) } // Development mode: allow localhost and local network origins @@ -231,7 +225,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.logger.Error("websocket upgrade failed", "error", err) return } - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + h.logger.Warn("error closing websocket connection", "error", err) + } + }() h.handleConnection(conn) } @@ -256,13 +254,15 @@ func (h *Handler) handleConnection(conn *websocket.Conn) { h.clientsMu.Lock() delete(h.clients, client) h.clientsMu.Unlock() - conn.Close() + _ = conn.Close() }() for { messageType, payload, err := conn.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + if websocket.IsUnexpectedCloseError( + err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, + ) { h.logger.Error("websocket read error", "error", err) } break @@ -366,10 +366,14 @@ func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload // Handler stubs - delegate to sub-packages -func (h *Handler) withAuth(conn *websocket.Conn, payload []byte, handler func(*auth.User) error) error { +func (h *Handler) withAuth( + conn *websocket.Conn, payload []byte, handler func(*auth.User) error, +) error { user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } return handler(user) } @@ -427,7 +431,9 @@ func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error { user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } offset := 16 @@ -467,7 +473,9 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro // Check authentication and permissions user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } if !h.RequirePermission(user, PermJobsRead) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") @@ -547,7 +555,9 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) erro // Parse payload: [api_key_hash:16] user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } // Return queue status as Data packet @@ -571,7 +581,9 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) erro // selectDependencyManifest auto-detects dependency manifest file func selectDependencyManifest(filesPath string) (string, error) { - for _, name := range []string{"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle"} { + for _, name := range []string{ + "requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle", + } { if _, err := os.Stat(filepath.Join(filesPath, name)); err == nil { return name, nil } @@ -584,7 +596,12 @@ func (h *Handler) Authenticate(payload []byte) (*auth.User, error) { if len(payload) < 16 { return nil, errors.New("payload too short") } - return &auth.User{Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}}, nil + return &auth.User{ + Name: "websocket-user", + Admin: false, + Roles: []string{"user"}, + Permissions: map[string]bool{"jobs:read": true}, + }, nil } // RequirePermission checks user permission @@ -604,7 +621,9 @@ func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } if !h.RequirePermission(user, PermJobsRead) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") @@ -666,7 +685,9 @@ func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error { user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } if !h.RequirePermission(user, PermJobsRead) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") @@ -708,7 +729,9 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error { user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } if !h.RequirePermission(user, PermJobsRead) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") @@ -729,7 +752,10 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error { optsLen := binary.BigEndian.Uint16(payload[offset : offset+2]) offset += 2 if optsLen > 0 && len(payload) >= offset+int(optsLen) { - json.Unmarshal(payload[offset:offset+int(optsLen)], &options) + err := json.Unmarshal(payload[offset:offset+int(optsLen)], &options) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid options JSON", err.Error()) + } } } @@ -764,7 +790,9 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro user, err := h.Authenticate(payload) if err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) } if !h.RequirePermission(user, PermJobsUpdate) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") @@ -792,10 +820,17 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro } // Validate outcome status - validOutcomes := map[string]bool{"validates": true, "refutes": true, "inconclusive": true, "partial": true} + validOutcomes := map[string]bool{ + "validates": true, "refutes": true, "inconclusive": true, "partial": true, + } outcome, ok := outcomeData["outcome"].(string) if !ok || !validOutcomes[outcome] { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome status", "must be: validates, refutes, inconclusive, or partial") + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + "invalid outcome status", + "must be: validates, refutes, inconclusive, or partial", + ) } h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name) diff --git a/internal/api/ws/jobs.go b/internal/api/ws/jobs.go index ca3ff23..24bba5e 100644 --- a/internal/api/ws/jobs.go +++ b/internal/api/ws/jobs.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" "github.com/gorilla/websocket" @@ -14,6 +15,59 @@ import ( "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) +func (h *Handler) populateExperimentIntegrityMetadata( + task *queue.Task, + commitIDHex string, +) (string, error) { + if h.expManager == nil { + return "", nil + } + + // Validate commit ID (defense-in-depth) + if len(commitIDHex) != 40 { + return "", fmt.Errorf("invalid commit id length") + } + if _, err := hex.DecodeString(commitIDHex); err != nil { + return "", fmt.Errorf("invalid commit id format") + } + + filesPath := h.expManager.GetFilesPath(commitIDHex) + + depsName, err := selectDependencyManifest(filesPath) + if err != nil { + return "", err + } + + if depsName != "" { + task.Metadata["deps_manifest_name"] = depsName + + depsPath := filepath.Join(filesPath, depsName) + if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { + task.Metadata["deps_manifest_sha256"] = sha + } + } + + basePath := filepath.Clean(h.expManager.BasePath()) + manifestPath := filepath.Join(basePath, commitIDHex, "manifest.json") + manifestPath = filepath.Clean(manifestPath) + + if !strings.HasPrefix(manifestPath, basePath+string(os.PathSeparator)) { + return "", fmt.Errorf("path traversal detected") + } + + if data, err := os.ReadFile(manifestPath); err == nil { + var man struct { + OverallSHA string `json:"overall_sha"` + } + + if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { + task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA + } + } + + return depsName, nil +} + // handleQueueJob handles the QueueJob opcode (0x01) func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { // Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] @@ -69,27 +123,10 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { Metadata: map[string]string{"commit_id": commitIDHex}, } - // Auto-detect deps manifest and compute manifest SHA - if h.expManager != nil { - filesPath := h.expManager.GetFilesPath(commitIDHex) - depsName, _ := selectDependencyManifest(filesPath) - if depsName != "" { - task.Metadata["deps_manifest_name"] = depsName - depsPath := filepath.Join(filesPath, depsName) - if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { - task.Metadata["deps_manifest_sha256"] = sha - } - } - - manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") - if data, err := os.ReadFile(manifestPath); err == nil { - var man struct { - OverallSHA string `json:"overall_sha"` - } - if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { - task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA - } - } + if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil { + return h.sendErrorPacket( + conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(), + ) } if h.taskQueue != nil { @@ -98,7 +135,7 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { } } - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "task_id": task.ID, }) @@ -144,26 +181,10 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt }, } - if h.expManager != nil { - filesPath := h.expManager.GetFilesPath(commitIDHex) - depsName, _ := selectDependencyManifest(filesPath) - if depsName != "" { - task.Metadata["deps_manifest_name"] = depsName - depsPath := filepath.Join(filesPath, depsName) - if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { - task.Metadata["deps_manifest_sha256"] = sha - } - } - - manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") - if data, err := os.ReadFile(manifestPath); err == nil { - var man struct { - OverallSHA string `json:"overall_sha"` - } - if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { - task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA - } - } + if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil { + return h.sendErrorPacket( + conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(), + ) } if h.taskQueue != nil { @@ -172,7 +193,7 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt } } - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "task_id": task.ID, }) @@ -194,11 +215,13 @@ func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error { task, err := h.taskQueue.GetTaskByName(jobName) if err == nil && task != nil { task.Status = "cancelled" - h.taskQueue.UpdateTask(task) + if err := h.taskQueue.UpdateTask(task); err != nil { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to cancel task", err.Error()) + } } } - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "message": "Job cancelled", }) @@ -217,7 +240,7 @@ func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error { // pruneType := payload[offset] // value := binary.BigEndian.Uint32(payload[offset+1 : offset+5]) - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "message": "Prune completed", "pruned": 0, diff --git a/internal/api/ws/validate.go b/internal/api/ws/validate.go index 7dddc95..c9e1fa6 100644 --- a/internal/api/ws/validate.go +++ b/internal/api/ws/validate.go @@ -11,6 +11,14 @@ import ( "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) +const ( + completed = "completed" + running = "running" + finished = "finished" + failed = "failed" + cancelled = "cancelled" +) + // handleValidateRequest handles the ValidateRequest opcode (0x16) func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error { // Parse payload format: [opcode:1][api_key_hash:16][mode:1][...] @@ -25,7 +33,9 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er if mode == 0 { // Commit ID validation (basic) if len(payload) < 20 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "") + return h.sendErrorPacket( + conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "", + ) } commitIDLen := int(payload[18]) if len(payload) < 19+commitIDLen { @@ -34,7 +44,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er commitIDBytes := payload[19 : 19+commitIDLen] commitIDHex := fmt.Sprintf("%x", commitIDBytes) - report := map[string]interface{}{ + report := map[string]any{ "ok": true, "commit_id": commitIDHex, } @@ -44,7 +54,9 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er // Task ID validation (mode=1) - full validation with checks if len(payload) < 20 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "") + return h.sendErrorPacket( + conn, ErrorCodeInvalidRequest, "payload too short for task validation", "", + ) } taskIDLen := int(payload[18]) @@ -54,7 +66,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er taskID := string(payload[19 : 19+taskIDLen]) // Initialize validation report - checks := make(map[string]interface{}) + checks := make(map[string]any) ok := true // Get task from queue @@ -68,16 +80,16 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er } // Run manifest validation - load manifest if it exists - rmCheck := map[string]interface{}{"ok": true} - rmCommitCheck := map[string]interface{}{"ok": true} - rmLocCheck := map[string]interface{}{"ok": true} - rmLifecycle := map[string]interface{}{"ok": true} + rmCheck := map[string]any{"ok": true} + rmCommitCheck := map[string]any{"ok": true} + rmLocCheck := map[string]any{"ok": true} + rmLifecycle := map[string]any{"ok": true} var narrativeWarnings, outcomeWarnings []string // Determine expected location based on task status - expectedLocation := "running" - if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" { - expectedLocation = "finished" + expectedLocation := running + if task.Status == completed || task.Status == cancelled || task.Status == failed { + expectedLocation = finished } // Try to load run manifest from appropriate location @@ -90,14 +102,14 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er rm, rmLoadErr = manifest.LoadFromDir(jobDir) // If not found and task is running, also check finished (wrong location test) - if rmLoadErr != nil && task.Status == "running" { - wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName) + if rmLoadErr != nil && task.Status == running { + wrongDir := filepath.Join(h.expManager.BasePath(), finished, task.JobName) rm, _ = manifest.LoadFromDir(wrongDir) if rm != nil { // Manifest exists but in wrong location rmLocCheck["ok"] = false - rmLocCheck["expected"] = "running" - rmLocCheck["actual"] = "finished" + rmLocCheck["expected"] = running + rmLocCheck["actual"] = finished ok = false } } @@ -105,7 +117,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er if rm == nil { // No run manifest found - if task.Status == "running" || task.Status == "completed" { + if task.Status == running || task.Status == completed { rmCheck["ok"] = false ok = false } @@ -151,7 +163,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er checks["run_manifest_lifecycle"] = rmLifecycle // Resources check - resCheck := map[string]interface{}{"ok": true} + resCheck := map[string]any{"ok": true} if task.CPU < 0 { resCheck["ok"] = false ok = false @@ -159,7 +171,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er checks["resources"] = resCheck // Snapshot check - snapCheck := map[string]interface{}{"ok": true} + snapCheck := map[string]any{"ok": true} if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" { // Verify snapshot SHA dataDir := h.dataDir @@ -177,7 +189,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er } checks["snapshot"] = snapCheck - report := map[string]interface{}{ + report := map[string]any{ "ok": ok, "checks": checks, "narrative_warnings": narrativeWarnings, diff --git a/internal/config/resources.go b/internal/config/resources.go index 7038549..71584dc 100644 --- a/internal/config/resources.go +++ b/internal/config/resources.go @@ -2,11 +2,11 @@ package config // ResourceConfig centralizes pacing and resource optimization knobs. type ResourceConfig struct { + PodmanCPUs string `yaml:"podman_cpus" toml:"podman_cpus"` + PodmanMemory string `yaml:"podman_memory" toml:"podman_memory"` MaxWorkers int `yaml:"max_workers" toml:"max_workers"` DesiredRPSPerWorker int `yaml:"desired_rps_per_worker" toml:"desired_rps_per_worker"` RequestsPerSec int `yaml:"requests_per_sec" toml:"requests_per_sec"` - PodmanCPUs string `yaml:"podman_cpus" toml:"podman_cpus"` - PodmanMemory string `yaml:"podman_memory" toml:"podman_memory"` RequestBurstOverride int `yaml:"request_burst" toml:"request_burst"` } diff --git a/internal/config/security.go b/internal/config/security.go index 35f6177..50a3370 100644 --- a/internal/config/security.go +++ b/internal/config/security.go @@ -7,33 +7,23 @@ import ( // SecurityConfig holds security-related configuration type SecurityConfig struct { - // AllowedOrigins lists the allowed origins for WebSocket connections - // Empty list defaults to localhost-only in production mode - AllowedOrigins []string `yaml:"allowed_origins"` - - // ProductionMode enables strict security checks - ProductionMode bool `yaml:"production_mode"` - - // APIKeyRotationDays is the number of days before API keys should be rotated - APIKeyRotationDays int `yaml:"api_key_rotation_days"` - - // AuditLogging configuration - AuditLogging AuditLoggingConfig `yaml:"audit_logging"` - - // IPWhitelist for additional connection filtering - IPWhitelist []string `yaml:"ip_whitelist"` + AuditLogging AuditLoggingConfig `yaml:"audit_logging"` + AllowedOrigins []string `yaml:"allowed_origins"` + IPWhitelist []string `yaml:"ip_whitelist"` + APIKeyRotationDays int `yaml:"api_key_rotation_days"` + ProductionMode bool `yaml:"production_mode"` } // AuditLoggingConfig holds audit logging configuration type AuditLoggingConfig struct { - Enabled bool `yaml:"enabled"` LogPath string `yaml:"log_path"` + Enabled bool `yaml:"enabled"` } // PrivacyConfig holds privacy enforcement configuration type PrivacyConfig struct { + DefaultLevel string `yaml:"default_level"` Enabled bool `yaml:"enabled"` - DefaultLevel string `yaml:"default_level"` // private, team, public, anonymized EnforceTeams bool `yaml:"enforce_teams"` AuditAccess bool `yaml:"audit_access"` } @@ -58,9 +48,9 @@ type MonitoringConfig struct { // PrometheusConfig holds Prometheus metrics configuration type PrometheusConfig struct { - Enabled bool `yaml:"enabled"` - Port int `yaml:"port"` Path string `yaml:"path"` + Port int `yaml:"port"` + Enabled bool `yaml:"enabled"` } // HealthChecksConfig holds health check configuration diff --git a/internal/config/shared.go b/internal/config/shared.go index 19317e3..da70dfe 100644 --- a/internal/config/shared.go +++ b/internal/config/shared.go @@ -19,10 +19,10 @@ type RedisConfig struct { // SSHConfig holds SSH connection settings type SSHConfig struct { Host string `yaml:"host" json:"host"` - Port int `yaml:"port" json:"port"` User string `yaml:"user" json:"user"` KeyPath string `yaml:"key_path" json:"key_path"` KnownHosts string `yaml:"known_hosts" json:"known_hosts"` + Port int `yaml:"port" json:"port"` } // ExpandPath expands environment variables and tilde in a path