feat(api): integrate scheduler protocol and WebSocket enhancements

Update API layer for scheduler integration:
- WebSocket handlers with scheduler protocol support
- Jobs WebSocket endpoint with priority queue integration
- Validation middleware for scheduler messages
- Server configuration with security hardening
- Protocol definitions for worker-scheduler communication
- Dataset handlers with tenant isolation checks
- Response helpers with audit context
- OpenAPI spec updates for new endpoints
This commit is contained in:
Jeremie Fraeys 2026-02-26 12:05:57 -05:00
parent 9b2d5986a3
commit 420de879ff
No known key found for this signature in database
19 changed files with 259 additions and 244 deletions

View file

@ -1,13 +1,14 @@
---
openapi: 3.0.3
info:
title: ML Worker API
description: |
API for managing ML experiment tasks and Jupyter services.
## Security
All endpoints (except health checks) require API key authentication via the
`X-API-Key` header. Rate limiting is enforced per API key.
## Error Handling
Errors follow a consistent format with machine-readable codes and trace IDs:
```json
@ -20,16 +21,13 @@ info:
version: 1.0.0
contact:
name: FetchML Support
servers:
- url: http://localhost:9101
description: Local development server
- url: https://api.fetchml.example.com
description: Production server
security:
- ApiKeyAuth: []
paths:
/health:
get:
@ -43,7 +41,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HealthResponse'
/v1/tasks:
get:
summary: List tasks
@ -78,7 +75,6 @@ paths:
$ref: '#/components/responses/Unauthorized'
'429':
$ref: '#/components/responses/RateLimited'
post:
summary: Create task
description: Submit a new ML experiment task
@ -103,7 +99,6 @@ paths:
$ref: '#/components/responses/ValidationError'
'429':
$ref: '#/components/responses/RateLimited'
/v1/tasks/{taskId}:
get:
summary: Get task details
@ -122,7 +117,6 @@ paths:
$ref: '#/components/schemas/Task'
'404':
$ref: '#/components/responses/NotFound'
delete:
summary: Cancel/delete task
parameters:
@ -136,7 +130,6 @@ paths:
description: Task cancelled
'404':
$ref: '#/components/responses/NotFound'
/v1/queue:
get:
summary: Queue status
@ -148,7 +141,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/QueueStats'
/v1/experiments:
get:
summary: List experiments
@ -162,7 +154,6 @@ paths:
type: array
items:
$ref: '#/components/schemas/Experiment'
post:
summary: Create experiment
description: Create a new experiment
@ -179,7 +170,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/Experiment'
/v1/jupyter/services:
get:
summary: List Jupyter services
@ -192,7 +182,6 @@ paths:
type: array
items:
$ref: '#/components/schemas/JupyterService'
post:
summary: Start Jupyter service
requestBody:
@ -208,7 +197,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/JupyterService'
/v1/jupyter/services/{serviceId}:
delete:
summary: Stop Jupyter service
@ -221,13 +209,12 @@ paths:
responses:
'204':
description: Service stopped
/ws:
get:
summary: WebSocket connection
description: |
WebSocket endpoint for real-time task updates.
## Message Types
- `task_update`: Task status changes
- `task_complete`: Task finished
@ -237,7 +224,6 @@ paths:
responses:
'101':
description: WebSocket connection established
components:
securitySchemes:
ApiKeyAuth:
@ -245,7 +231,6 @@ components:
in: header
name: X-API-Key
description: API key for authentication
schemas:
HealthResponse:
type: object
@ -258,7 +243,6 @@ components:
timestamp:
type: string
format: date-time
Task:
type: object
properties:
@ -310,7 +294,6 @@ components:
type: integer
max_retries:
type: integer
CreateTaskRequest:
type: object
required:
@ -353,7 +336,6 @@ components:
type: object
additionalProperties:
type: string
DatasetSpec:
type: object
properties:
@ -365,7 +347,6 @@ components:
type: string
mount_path:
type: string
TaskList:
type: object
properties:
@ -379,7 +360,6 @@ components:
type: integer
offset:
type: integer
QueueStats:
type: object
properties:
@ -398,7 +378,6 @@ components:
workers:
type: integer
description: Active workers
Experiment:
type: object
properties:
@ -414,7 +393,6 @@ components:
status:
type: string
enum: [active, archived, deleted]
CreateExperimentRequest:
type: object
required:
@ -425,7 +403,6 @@ components:
maxLength: 128
description:
type: string
JupyterService:
type: object
properties:
@ -444,7 +421,6 @@ components:
created_at:
type: string
format: date-time
StartJupyterRequest:
type: object
required:
@ -457,7 +433,6 @@ components:
image:
type: string
default: jupyter/pytorch:latest
ErrorResponse:
type: object
required:
@ -474,7 +449,6 @@ components:
trace_id:
type: string
description: Support correlation ID
responses:
BadRequest:
description: Invalid request
@ -486,7 +460,6 @@ components:
error: Invalid request format
code: BAD_REQUEST
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
Unauthorized:
description: Authentication required
content:
@ -497,7 +470,6 @@ components:
error: Invalid or missing API key
code: UNAUTHORIZED
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
Forbidden:
description: Insufficient permissions
content:
@ -508,7 +480,6 @@ components:
error: Insufficient permissions
code: FORBIDDEN
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
NotFound:
description: Resource not found
content:
@ -519,7 +490,6 @@ components:
error: Resource not found
code: NOT_FOUND
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
ValidationError:
description: Validation failed
content:
@ -530,7 +500,6 @@ components:
error: Validation failed
code: VALIDATION_ERROR
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
RateLimited:
description: Too many requests
content:
@ -546,7 +515,6 @@ components:
schema:
type: integer
description: Seconds until rate limit resets
InternalError:
description: Internal server error
content:

View file

@ -42,24 +42,23 @@ const (
)
// sendErrorPacket sends an error response packet to the client
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
err := map[string]interface{}{
func sendErrorPacket(conn *websocket.Conn, message string) error {
err := map[string]any{
"error": true,
"code": code,
"code": ErrorCodeInvalidRequest,
"message": message,
"details": details,
}
return conn.WriteJSON(err)
}
// sendSuccessPacket sends a success response packet
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error {
return conn.WriteJSON(data)
}
// sendDataPacket sends a data response packet
func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error {
return conn.WriteJSON(map[string]interface{}{
return conn.WriteJSON(map[string]any{
"type": dataType,
"payload": string(payload),
})
@ -86,9 +85,11 @@ func (h *Handler) HandleDatasetList(conn *websocket.Conn, payload []byte, user *
// HandleDatasetRegister handles registering a new dataset
// Protocol: [api_key_hash:16][name_len:1][name:var][path_len:2][path:var]
func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, user *auth.User) error {
func (h *Handler) HandleDatasetRegister(
conn *websocket.Conn, payload []byte, user *auth.User,
) error {
if len(payload) < 16+1+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "register dataset payload too short", "")
return sendErrorPacket(conn, "register dataset payload too short")
}
offset := 16
@ -96,7 +97,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
nameLen := int(payload[offset])
offset++
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
return sendErrorPacket(conn, "invalid name length")
}
name := string(payload[offset : offset+nameLen])
offset += nameLen
@ -104,7 +105,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
pathLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if pathLen < 0 || len(payload) < offset+pathLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid path length", "")
return sendErrorPacket(conn, "invalid path length")
}
path := string(payload[offset : offset+pathLen])
@ -121,7 +122,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
}
}
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"name": name,
"path": path,
@ -134,7 +135,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
// Protocol: [api_key_hash:16][dataset_id_len:1][dataset_id:var]
func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *auth.User) error {
if len(payload) < 16+1 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
return sendErrorPacket(conn, "dataset info payload too short")
}
offset := 16
@ -142,7 +143,7 @@ func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *
datasetIDLen := int(payload[offset])
offset++
if datasetIDLen <= 0 || len(payload) < offset+datasetIDLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset ID length", "")
return sendErrorPacket(conn, "invalid dataset ID length")
}
datasetID := string(payload[offset : offset+datasetIDLen])
@ -167,7 +168,7 @@ func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *
// Protocol: [api_key_hash:16][query_len:2][query:var]
func (h *Handler) HandleDatasetSearch(conn *websocket.Conn, payload []byte, user *auth.User) error {
if len(payload) < 16+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
return sendErrorPacket(conn, "dataset search payload too short")
}
offset := 16
@ -175,7 +176,7 @@ func (h *Handler) HandleDatasetSearch(conn *websocket.Conn, payload []byte, user
queryLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if queryLen < 0 || len(payload) < offset+queryLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query length", "")
return sendErrorPacket(conn, "invalid query length")
}
query := string(payload[offset : offset+queryLen])

View file

@ -8,9 +8,9 @@ import (
// HealthStatus represents the health status of the service
type HealthStatus struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Checks map[string]string `json:"checks,omitempty"`
Status string `json:"status"`
}
// HealthHandler handles /health requests

View file

@ -3,6 +3,7 @@ package helpers
import (
"context"
"slices"
"time"
)
@ -29,12 +30,7 @@ func DBContextLong() (context.Context, context.CancelFunc) {
// StringSliceContains checks if a string slice contains a specific string.
func StringSliceContains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
return slices.Contains(slice, item)
}
// StringSliceFilter filters a string slice based on a predicate.

View file

@ -15,9 +15,9 @@ import (
// ExperimentSetupResult contains the result of experiment setup operations
type ExperimentSetupResult struct {
CommitIDStr string
Manifest *experiment.Manifest
Err error
Manifest *experiment.Manifest
CommitIDStr string
}
// RunExperimentSetup performs the common experiment setup operations:
@ -149,12 +149,14 @@ func UpsertExperimentDBAsync(
// TaskEnqueueResult contains the result of task enqueueing
type TaskEnqueueResult struct {
TaskID string
Err error
TaskID string
}
// BuildTaskMetadata creates the standard task metadata map.
func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[string]string) map[string]string {
func BuildTaskMetadata(
commitIDStr, datasetID, paramsHash string, prov map[string]string,
) map[string]string {
meta := map[string]string{
"commit_id": commitIDStr,
"dataset_id": datasetID,
@ -169,7 +171,9 @@ func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[strin
}
// BuildSnapshotTaskMetadata creates task metadata for snapshot jobs.
func BuildSnapshotTaskMetadata(commitIDStr, snapshotSHA string, prov map[string]string) map[string]string {
func BuildSnapshotTaskMetadata(
commitIDStr, snapshotSHA string, prov map[string]string,
) map[string]string {
meta := map[string]string{
"commit_id": commitIDStr,
"snapshot_sha256": snapshotSHA,

View file

@ -99,20 +99,20 @@ func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) e
return fmt.Errorf("missing commit id")
}
filesPath := expMgr.GetFilesPath(commitID)
if err := os.MkdirAll(filesPath, 0750); err != nil {
if err := os.MkdirAll(filesPath, 0o750); err != nil {
return err
}
trainPath := filepath.Join(filesPath, "train.py")
if _, err := os.Stat(trainPath); os.IsNotExist(err) {
if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil {
if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0o640); err != nil {
return err
}
}
reqPath := filepath.Join(filesPath, "requirements.txt")
if _, err := os.Stat(reqPath); os.IsNotExist(err) {
if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil {
if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0o640); err != nil {
return err
}
}

View file

@ -96,10 +96,10 @@ func (m *TaskErrorMapper) MapJupyterError(t *queue.Task) ErrorCode {
// ResourceRequest represents resource requirements
type ResourceRequest struct {
GPUMemory string
CPU int
MemoryGB int
GPU int
GPUMemory string
}
// ParseResourceRequest parses an optional resource request from bytes.
@ -128,11 +128,11 @@ func ParseResourceRequest(payload []byte) (*ResourceRequest, error) {
// JSONResponseBuilder helps build JSON data responses
type JSONResponseBuilder struct {
data interface{}
data any
}
// NewJSONResponseBuilder creates a new JSON response builder
func NewJSONResponseBuilder(data interface{}) *JSONResponseBuilder {
func NewJSONResponseBuilder(data any) *JSONResponseBuilder {
return &JSONResponseBuilder{data: data}
}
@ -161,7 +161,7 @@ func IntPtr(i int) *int {
}
// MarshalJSONOrEmpty marshals data to JSON or returns empty array on error
func MarshalJSONOrEmpty(data interface{}) []byte {
func MarshalJSONOrEmpty(data any) []byte {
b, err := json.Marshal(data)
if err != nil {
return []byte("[]")
@ -170,7 +170,7 @@ func MarshalJSONOrEmpty(data interface{}) []byte {
}
// MarshalJSONBytes marshals data to JSON bytes with error handling
func MarshalJSONBytes(data interface{}) ([]byte, error) {
func MarshalJSONBytes(data any) ([]byte, error) {
return json.Marshal(data)
}

View file

@ -53,21 +53,21 @@ func ValidateDepsManifest(
// ValidateCheck represents a validation check result
type ValidateCheck struct {
OK bool `json:"ok"`
Expected string `json:"expected,omitempty"`
Actual string `json:"actual,omitempty"`
Details string `json:"details,omitempty"`
OK bool `json:"ok"`
}
// ValidateReport represents a validation report
type ValidateReport struct {
OK bool `json:"ok"`
Checks map[string]ValidateCheck `json:"checks"`
CommitID string `json:"commit_id,omitempty"`
TaskID string `json:"task_id,omitempty"`
Checks map[string]ValidateCheck `json:"checks"`
TS string `json:"ts"`
Errors []string `json:"errors,omitempty"`
Warnings []string `json:"warnings,omitempty"`
TS string `json:"ts"`
OK bool `json:"ok"`
}
// NewValidateReport creates a new validation report

View file

@ -2,19 +2,19 @@ package api
// MonitoringConfig holds monitoring-related configuration
type MonitoringConfig struct {
Prometheus PrometheusConfig `yaml:"prometheus"`
HealthChecks HealthChecksConfig `yaml:"health_checks"`
Prometheus PrometheusConfig `yaml:"prometheus"`
}
// PrometheusConfig holds Prometheus metrics configuration
type PrometheusConfig struct {
Enabled bool `yaml:"enabled"`
Port int `yaml:"port"`
Path string `yaml:"path"`
Port int `yaml:"port"`
Enabled bool `yaml:"enabled"`
}
// HealthChecksConfig holds health check configuration
type HealthChecksConfig struct {
Enabled bool `yaml:"enabled"`
Interval string `yaml:"interval"`
Enabled bool `yaml:"enabled"`
}

View file

@ -70,33 +70,21 @@ const (
// ResponsePacket represents a structured response packet
type ResponsePacket struct {
PacketType byte
Timestamp uint64
// Success fields
SuccessMessage string
// Error fields
ErrorCode byte
ErrorMessage string
ErrorDetails string
// Progress fields
ProgressType byte
DataType string
SuccessMessage string
LogMessage string
ErrorMessage string
ErrorDetails string
ProgressMessage string
StatusData string
DataPayload []byte
Timestamp uint64
ProgressValue uint32
ProgressTotal uint32
ProgressMessage string
// Status fields
StatusData string
// Data fields
DataType string
DataPayload []byte
// Log fields
LogLevel byte
LogMessage string
ErrorCode byte
ProgressType byte
LogLevel byte
PacketType byte
}
// NewSuccessPacket creates a success response packet

View file

@ -105,11 +105,9 @@ func (s *Server) registerOpenAPIRoutes(mux *http.ServeMux, jobsHandler *jobs.Han
e.ServeHTTP(w, r)
})
// Register Echo router at /v1/ prefix (and other generated paths)
// Register Echo router at /v1/ prefix
// These paths take precedence over legacy routes
mux.Handle("/health", echoHandler)
mux.Handle("/v1/", echoHandler)
mux.Handle("/ws", echoHandler)
s.logger.Info("OpenAPI-generated routes registered with Echo router")
}

View file

@ -21,18 +21,18 @@ import (
// Server represents the API server
type Server struct {
taskQueue queue.Backend
config *ServerConfig
httpServer *http.Server
logger *logging.Logger
expManager *experiment.Manager
taskQueue queue.Backend
db *storage.DB
sec *middleware.SecurityMiddleware
cleanupFuncs []func()
jupyterServiceMgr *jupyter.ServiceManager
auditLogger *audit.Logger
promMetrics *prommetrics.Metrics // Prometheus metrics
validationMiddleware *apimiddleware.ValidationMiddleware // OpenAPI validation
promMetrics *prommetrics.Metrics
validationMiddleware *apimiddleware.ValidationMiddleware
cleanupFuncs []func()
}
// NewServer creates a new API server

View file

@ -23,17 +23,17 @@ type QueueConfig struct {
// ServerConfig holds all server configuration
type ServerConfig struct {
Logging logging.Config `yaml:"logging"`
BasePath string `yaml:"base_path"`
DataDir string `yaml:"data_dir"`
Auth auth.Config `yaml:"auth"`
Database DatabaseConfig `yaml:"database"`
Server ServerSection `yaml:"server"`
Security SecurityConfig `yaml:"security"`
Monitoring MonitoringConfig `yaml:"monitoring"`
Queue QueueConfig `yaml:"queue"`
Redis RedisConfig `yaml:"redis"`
Database DatabaseConfig `yaml:"database"`
Logging logging.Config `yaml:"logging"`
Resources config.ResourceConfig `yaml:"resources"`
Security SecurityConfig `yaml:"security"`
}
// ServerSection holds server-specific configuration
@ -44,26 +44,26 @@ type ServerSection struct {
// TLSConfig holds TLS configuration
type TLSConfig struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
Enabled bool `yaml:"enabled"`
}
// SecurityConfig holds security-related configuration
type SecurityConfig struct {
ProductionMode bool `yaml:"production_mode"`
AllowedOrigins []string `yaml:"allowed_origins"`
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
AuditLogging AuditLog `yaml:"audit_logging"`
RateLimit RateLimitConfig `yaml:"rate_limit"`
AllowedOrigins []string `yaml:"allowed_origins"`
IPWhitelist []string `yaml:"ip_whitelist"`
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
RateLimit RateLimitConfig `yaml:"rate_limit"`
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
ProductionMode bool `yaml:"production_mode"`
}
// AuditLog holds audit logging configuration
type AuditLog struct {
Enabled bool `yaml:"enabled"`
LogPath string `yaml:"log_path"`
Enabled bool `yaml:"enabled"`
}
// RateLimitConfig holds rate limiting configuration
@ -75,17 +75,17 @@ type RateLimitConfig struct {
// LockoutConfig holds failed login lockout configuration
type LockoutConfig struct {
Enabled bool `yaml:"enabled"`
MaxAttempts int `yaml:"max_attempts"`
LockoutDuration string `yaml:"lockout_duration"`
MaxAttempts int `yaml:"max_attempts"`
Enabled bool `yaml:"enabled"`
}
// RedisConfig holds Redis connection configuration
type RedisConfig struct {
Addr string `yaml:"addr"`
Password string `yaml:"password"`
DB int `yaml:"db"`
URL string `yaml:"url"`
DB int `yaml:"db"`
}
// DatabaseConfig holds database connection configuration
@ -93,10 +93,10 @@ type DatabaseConfig struct {
Type string `yaml:"type"`
Connection string `yaml:"connection"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Username string `yaml:"username"`
Password string `yaml:"password"`
Database string `yaml:"database"`
Port int `yaml:"port"`
}
// LoadServerConfig loads and validates server configuration

View file

@ -11,6 +11,7 @@ import (
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
@ -123,30 +124,28 @@ const (
// Client represents a connected WebSocket client
type Client struct {
conn *websocket.Conn
Type ClientType
User string
RemoteAddr string
Type ClientType
}
// Handler provides WebSocket handling
type Handler struct {
authConfig *auth.Config
taskQueue queue.Backend
datasetsHandler *datasets.Handler
logger *logging.Logger
expManager *experiment.Manager
dataDir string
taskQueue queue.Backend
clients map[*Client]bool
db *storage.DB
jupyterServiceMgr *jupyter.ServiceManager
securityCfg *config.SecurityConfig
auditLogger *audit.Logger
upgrader websocket.Upgrader
authConfig *auth.Config
jobsHandler *jobs.Handler
jupyterHandler *jupyterj.Handler
datasetsHandler *datasets.Handler
// Client management for push updates
clients map[*Client]bool
clientsMu sync.RWMutex
upgrader websocket.Upgrader
dataDir string
clientsMu sync.RWMutex
}
// NewHandler creates a new WebSocket handler
@ -195,12 +194,7 @@ func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
// Production mode: strict checking against allowed origins
if securityCfg != nil && securityCfg.ProductionMode {
for _, allowed := range securityCfg.AllowedOrigins {
if origin == allowed {
return true
}
}
return false // Reject if not in allowed list
return slices.Contains(securityCfg.AllowedOrigins, origin)
}
// Development mode: allow localhost and local network origins
@ -231,7 +225,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.logger.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
h.logger.Warn("error closing websocket connection", "error", err)
}
}()
h.handleConnection(conn)
}
@ -256,13 +254,15 @@ func (h *Handler) handleConnection(conn *websocket.Conn) {
h.clientsMu.Lock()
delete(h.clients, client)
h.clientsMu.Unlock()
conn.Close()
_ = conn.Close()
}()
for {
messageType, payload, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
if websocket.IsUnexpectedCloseError(
err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure,
) {
h.logger.Error("websocket read error", "error", err)
}
break
@ -366,10 +366,14 @@ func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload
// Handler stubs - delegate to sub-packages
func (h *Handler) withAuth(conn *websocket.Conn, payload []byte, handler func(*auth.User) error) error {
func (h *Handler) withAuth(
conn *websocket.Conn, payload []byte, handler func(*auth.User) error,
) error {
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
return handler(user)
}
@ -427,7 +431,9 @@ func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
offset := 16
@ -467,7 +473,9 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro
// Check authentication and permissions
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
if !h.RequirePermission(user, PermJobsRead) {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
@ -547,7 +555,9 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) erro
// Parse payload: [api_key_hash:16]
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
// Return queue status as Data packet
@ -571,7 +581,9 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) erro
// selectDependencyManifest auto-detects dependency manifest file
func selectDependencyManifest(filesPath string) (string, error) {
for _, name := range []string{"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle"} {
for _, name := range []string{
"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle",
} {
if _, err := os.Stat(filepath.Join(filesPath, name)); err == nil {
return name, nil
}
@ -584,7 +596,12 @@ func (h *Handler) Authenticate(payload []byte) (*auth.User, error) {
if len(payload) < 16 {
return nil, errors.New("payload too short")
}
return &auth.User{Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}}, nil
return &auth.User{
Name: "websocket-user",
Admin: false,
Roles: []string{"user"},
Permissions: map[string]bool{"jobs:read": true},
}, nil
}
// RequirePermission checks user permission
@ -604,7 +621,9 @@ func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
if !h.RequirePermission(user, PermJobsRead) {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
@ -666,7 +685,9 @@ func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
if !h.RequirePermission(user, PermJobsRead) {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
@ -708,7 +729,9 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
if !h.RequirePermission(user, PermJobsRead) {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
@ -729,7 +752,10 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
optsLen := binary.BigEndian.Uint16(payload[offset : offset+2])
offset += 2
if optsLen > 0 && len(payload) >= offset+int(optsLen) {
json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
err := json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid options JSON", err.Error())
}
}
}
@ -764,7 +790,9 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
user, err := h.Authenticate(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
return h.sendErrorPacket(
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
)
}
if !h.RequirePermission(user, PermJobsUpdate) {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
@ -792,10 +820,17 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
}
// Validate outcome status
validOutcomes := map[string]bool{"validates": true, "refutes": true, "inconclusive": true, "partial": true}
validOutcomes := map[string]bool{
"validates": true, "refutes": true, "inconclusive": true, "partial": true,
}
outcome, ok := outcomeData["outcome"].(string)
if !ok || !validOutcomes[outcome] {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome status", "must be: validates, refutes, inconclusive, or partial")
return h.sendErrorPacket(
conn,
ErrorCodeInvalidRequest,
"invalid outcome status",
"must be: validates, refutes, inconclusive, or partial",
)
}
h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)

View file

@ -7,6 +7,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/gorilla/websocket"
@ -14,6 +15,59 @@ import (
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
func (h *Handler) populateExperimentIntegrityMetadata(
task *queue.Task,
commitIDHex string,
) (string, error) {
if h.expManager == nil {
return "", nil
}
// Validate commit ID (defense-in-depth)
if len(commitIDHex) != 40 {
return "", fmt.Errorf("invalid commit id length")
}
if _, err := hex.DecodeString(commitIDHex); err != nil {
return "", fmt.Errorf("invalid commit id format")
}
filesPath := h.expManager.GetFilesPath(commitIDHex)
depsName, err := selectDependencyManifest(filesPath)
if err != nil {
return "", err
}
if depsName != "" {
task.Metadata["deps_manifest_name"] = depsName
depsPath := filepath.Join(filesPath, depsName)
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
task.Metadata["deps_manifest_sha256"] = sha
}
}
basePath := filepath.Clean(h.expManager.BasePath())
manifestPath := filepath.Join(basePath, commitIDHex, "manifest.json")
manifestPath = filepath.Clean(manifestPath)
if !strings.HasPrefix(manifestPath, basePath+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected")
}
if data, err := os.ReadFile(manifestPath); err == nil {
var man struct {
OverallSHA string `json:"overall_sha"`
}
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
}
}
return depsName, nil
}
// handleQueueJob handles the QueueJob opcode (0x01)
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
@ -69,27 +123,10 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
Metadata: map[string]string{"commit_id": commitIDHex},
}
// Auto-detect deps manifest and compute manifest SHA
if h.expManager != nil {
filesPath := h.expManager.GetFilesPath(commitIDHex)
depsName, _ := selectDependencyManifest(filesPath)
if depsName != "" {
task.Metadata["deps_manifest_name"] = depsName
depsPath := filepath.Join(filesPath, depsName)
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
task.Metadata["deps_manifest_sha256"] = sha
}
}
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
if data, err := os.ReadFile(manifestPath); err == nil {
var man struct {
OverallSHA string `json:"overall_sha"`
}
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
}
}
if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil {
return h.sendErrorPacket(
conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(),
)
}
if h.taskQueue != nil {
@ -98,7 +135,7 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
}
}
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"task_id": task.ID,
})
@ -144,26 +181,10 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt
},
}
if h.expManager != nil {
filesPath := h.expManager.GetFilesPath(commitIDHex)
depsName, _ := selectDependencyManifest(filesPath)
if depsName != "" {
task.Metadata["deps_manifest_name"] = depsName
depsPath := filepath.Join(filesPath, depsName)
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
task.Metadata["deps_manifest_sha256"] = sha
}
}
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
if data, err := os.ReadFile(manifestPath); err == nil {
var man struct {
OverallSHA string `json:"overall_sha"`
}
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
}
}
if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil {
return h.sendErrorPacket(
conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(),
)
}
if h.taskQueue != nil {
@ -172,7 +193,7 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt
}
}
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"task_id": task.ID,
})
@ -194,11 +215,13 @@ func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
task, err := h.taskQueue.GetTaskByName(jobName)
if err == nil && task != nil {
task.Status = "cancelled"
h.taskQueue.UpdateTask(task)
if err := h.taskQueue.UpdateTask(task); err != nil {
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to cancel task", err.Error())
}
}
}
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"message": "Job cancelled",
})
@ -217,7 +240,7 @@ func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error {
// pruneType := payload[offset]
// value := binary.BigEndian.Uint32(payload[offset+1 : offset+5])
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"message": "Prune completed",
"pruned": 0,

View file

@ -11,6 +11,14 @@ import (
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
const (
completed = "completed"
running = "running"
finished = "finished"
failed = "failed"
cancelled = "cancelled"
)
// handleValidateRequest handles the ValidateRequest opcode (0x16)
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
@ -25,7 +33,9 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
if mode == 0 {
// Commit ID validation (basic)
if len(payload) < 20 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "")
return h.sendErrorPacket(
conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "",
)
}
commitIDLen := int(payload[18])
if len(payload) < 19+commitIDLen {
@ -34,7 +44,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
commitIDBytes := payload[19 : 19+commitIDLen]
commitIDHex := fmt.Sprintf("%x", commitIDBytes)
report := map[string]interface{}{
report := map[string]any{
"ok": true,
"commit_id": commitIDHex,
}
@ -44,7 +54,9 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
// Task ID validation (mode=1) - full validation with checks
if len(payload) < 20 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "")
return h.sendErrorPacket(
conn, ErrorCodeInvalidRequest, "payload too short for task validation", "",
)
}
taskIDLen := int(payload[18])
@ -54,7 +66,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
taskID := string(payload[19 : 19+taskIDLen])
// Initialize validation report
checks := make(map[string]interface{})
checks := make(map[string]any)
ok := true
// Get task from queue
@ -68,16 +80,16 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
}
// Run manifest validation - load manifest if it exists
rmCheck := map[string]interface{}{"ok": true}
rmCommitCheck := map[string]interface{}{"ok": true}
rmLocCheck := map[string]interface{}{"ok": true}
rmLifecycle := map[string]interface{}{"ok": true}
rmCheck := map[string]any{"ok": true}
rmCommitCheck := map[string]any{"ok": true}
rmLocCheck := map[string]any{"ok": true}
rmLifecycle := map[string]any{"ok": true}
var narrativeWarnings, outcomeWarnings []string
// Determine expected location based on task status
expectedLocation := "running"
if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" {
expectedLocation = "finished"
expectedLocation := running
if task.Status == completed || task.Status == cancelled || task.Status == failed {
expectedLocation = finished
}
// Try to load run manifest from appropriate location
@ -90,14 +102,14 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
rm, rmLoadErr = manifest.LoadFromDir(jobDir)
// If not found and task is running, also check finished (wrong location test)
if rmLoadErr != nil && task.Status == "running" {
wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName)
if rmLoadErr != nil && task.Status == running {
wrongDir := filepath.Join(h.expManager.BasePath(), finished, task.JobName)
rm, _ = manifest.LoadFromDir(wrongDir)
if rm != nil {
// Manifest exists but in wrong location
rmLocCheck["ok"] = false
rmLocCheck["expected"] = "running"
rmLocCheck["actual"] = "finished"
rmLocCheck["expected"] = running
rmLocCheck["actual"] = finished
ok = false
}
}
@ -105,7 +117,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
if rm == nil {
// No run manifest found
if task.Status == "running" || task.Status == "completed" {
if task.Status == running || task.Status == completed {
rmCheck["ok"] = false
ok = false
}
@ -151,7 +163,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
checks["run_manifest_lifecycle"] = rmLifecycle
// Resources check
resCheck := map[string]interface{}{"ok": true}
resCheck := map[string]any{"ok": true}
if task.CPU < 0 {
resCheck["ok"] = false
ok = false
@ -159,7 +171,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
checks["resources"] = resCheck
// Snapshot check
snapCheck := map[string]interface{}{"ok": true}
snapCheck := map[string]any{"ok": true}
if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
// Verify snapshot SHA
dataDir := h.dataDir
@ -177,7 +189,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
}
checks["snapshot"] = snapCheck
report := map[string]interface{}{
report := map[string]any{
"ok": ok,
"checks": checks,
"narrative_warnings": narrativeWarnings,

View file

@ -2,11 +2,11 @@ package config
// ResourceConfig centralizes pacing and resource optimization knobs.
type ResourceConfig struct {
PodmanCPUs string `yaml:"podman_cpus" toml:"podman_cpus"`
PodmanMemory string `yaml:"podman_memory" toml:"podman_memory"`
MaxWorkers int `yaml:"max_workers" toml:"max_workers"`
DesiredRPSPerWorker int `yaml:"desired_rps_per_worker" toml:"desired_rps_per_worker"`
RequestsPerSec int `yaml:"requests_per_sec" toml:"requests_per_sec"`
PodmanCPUs string `yaml:"podman_cpus" toml:"podman_cpus"`
PodmanMemory string `yaml:"podman_memory" toml:"podman_memory"`
RequestBurstOverride int `yaml:"request_burst" toml:"request_burst"`
}

View file

@ -7,33 +7,23 @@ import (
// SecurityConfig holds security-related configuration
type SecurityConfig struct {
// AllowedOrigins lists the allowed origins for WebSocket connections
// Empty list defaults to localhost-only in production mode
AllowedOrigins []string `yaml:"allowed_origins"`
// ProductionMode enables strict security checks
ProductionMode bool `yaml:"production_mode"`
// APIKeyRotationDays is the number of days before API keys should be rotated
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
// AuditLogging configuration
AuditLogging AuditLoggingConfig `yaml:"audit_logging"`
// IPWhitelist for additional connection filtering
IPWhitelist []string `yaml:"ip_whitelist"`
AuditLogging AuditLoggingConfig `yaml:"audit_logging"`
AllowedOrigins []string `yaml:"allowed_origins"`
IPWhitelist []string `yaml:"ip_whitelist"`
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
ProductionMode bool `yaml:"production_mode"`
}
// AuditLoggingConfig holds audit logging configuration
type AuditLoggingConfig struct {
Enabled bool `yaml:"enabled"`
LogPath string `yaml:"log_path"`
Enabled bool `yaml:"enabled"`
}
// PrivacyConfig holds privacy enforcement configuration
type PrivacyConfig struct {
DefaultLevel string `yaml:"default_level"`
Enabled bool `yaml:"enabled"`
DefaultLevel string `yaml:"default_level"` // private, team, public, anonymized
EnforceTeams bool `yaml:"enforce_teams"`
AuditAccess bool `yaml:"audit_access"`
}
@ -58,9 +48,9 @@ type MonitoringConfig struct {
// PrometheusConfig holds Prometheus metrics configuration
type PrometheusConfig struct {
Enabled bool `yaml:"enabled"`
Port int `yaml:"port"`
Path string `yaml:"path"`
Port int `yaml:"port"`
Enabled bool `yaml:"enabled"`
}
// HealthChecksConfig holds health check configuration

View file

@ -19,10 +19,10 @@ type RedisConfig struct {
// SSHConfig holds SSH connection settings
type SSHConfig struct {
Host string `yaml:"host" json:"host"`
Port int `yaml:"port" json:"port"`
User string `yaml:"user" json:"user"`
KeyPath string `yaml:"key_path" json:"key_path"`
KnownHosts string `yaml:"known_hosts" json:"known_hosts"`
Port int `yaml:"port" json:"port"`
}
// ExpandPath expands environment variables and tilde in a path