feat(api): integrate scheduler protocol and WebSocket enhancements
Update API layer for scheduler integration:
- WebSocket handlers with scheduler protocol support
- Jobs WebSocket endpoint with priority queue integration
- Validation middleware for scheduler messages
- Server configuration with security hardening
- Protocol definitions for worker-scheduler communication
- Dataset handlers with tenant isolation checks
- Response helpers with audit context
- OpenAPI spec updates for new endpoints
This commit is contained in:
parent
9b2d5986a3
commit
420de879ff
19 changed files with 259 additions and 244 deletions
|
|
@ -1,13 +1,14 @@
|
|||
---
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: ML Worker API
|
||||
description: |
|
||||
API for managing ML experiment tasks and Jupyter services.
|
||||
|
||||
|
||||
## Security
|
||||
All endpoints (except health checks) require API key authentication via the
|
||||
`X-API-Key` header. Rate limiting is enforced per API key.
|
||||
|
||||
|
||||
## Error Handling
|
||||
Errors follow a consistent format with machine-readable codes and trace IDs:
|
||||
```json
|
||||
|
|
@ -20,16 +21,13 @@ info:
|
|||
version: 1.0.0
|
||||
contact:
|
||||
name: FetchML Support
|
||||
|
||||
servers:
|
||||
- url: http://localhost:9101
|
||||
description: Local development server
|
||||
- url: https://api.fetchml.example.com
|
||||
description: Production server
|
||||
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
|
||||
paths:
|
||||
/health:
|
||||
get:
|
||||
|
|
@ -43,7 +41,6 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HealthResponse'
|
||||
|
||||
/v1/tasks:
|
||||
get:
|
||||
summary: List tasks
|
||||
|
|
@ -78,7 +75,6 @@ paths:
|
|||
$ref: '#/components/responses/Unauthorized'
|
||||
'429':
|
||||
$ref: '#/components/responses/RateLimited'
|
||||
|
||||
post:
|
||||
summary: Create task
|
||||
description: Submit a new ML experiment task
|
||||
|
|
@ -103,7 +99,6 @@ paths:
|
|||
$ref: '#/components/responses/ValidationError'
|
||||
'429':
|
||||
$ref: '#/components/responses/RateLimited'
|
||||
|
||||
/v1/tasks/{taskId}:
|
||||
get:
|
||||
summary: Get task details
|
||||
|
|
@ -122,7 +117,6 @@ paths:
|
|||
$ref: '#/components/schemas/Task'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
|
||||
delete:
|
||||
summary: Cancel/delete task
|
||||
parameters:
|
||||
|
|
@ -136,7 +130,6 @@ paths:
|
|||
description: Task cancelled
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
|
||||
/v1/queue:
|
||||
get:
|
||||
summary: Queue status
|
||||
|
|
@ -148,7 +141,6 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/QueueStats'
|
||||
|
||||
/v1/experiments:
|
||||
get:
|
||||
summary: List experiments
|
||||
|
|
@ -162,7 +154,6 @@ paths:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Experiment'
|
||||
|
||||
post:
|
||||
summary: Create experiment
|
||||
description: Create a new experiment
|
||||
|
|
@ -179,7 +170,6 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Experiment'
|
||||
|
||||
/v1/jupyter/services:
|
||||
get:
|
||||
summary: List Jupyter services
|
||||
|
|
@ -192,7 +182,6 @@ paths:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/JupyterService'
|
||||
|
||||
post:
|
||||
summary: Start Jupyter service
|
||||
requestBody:
|
||||
|
|
@ -208,7 +197,6 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JupyterService'
|
||||
|
||||
/v1/jupyter/services/{serviceId}:
|
||||
delete:
|
||||
summary: Stop Jupyter service
|
||||
|
|
@ -221,13 +209,12 @@ paths:
|
|||
responses:
|
||||
'204':
|
||||
description: Service stopped
|
||||
|
||||
/ws:
|
||||
get:
|
||||
summary: WebSocket connection
|
||||
description: |
|
||||
WebSocket endpoint for real-time task updates.
|
||||
|
||||
|
||||
## Message Types
|
||||
- `task_update`: Task status changes
|
||||
- `task_complete`: Task finished
|
||||
|
|
@ -237,7 +224,6 @@ paths:
|
|||
responses:
|
||||
'101':
|
||||
description: WebSocket connection established
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
ApiKeyAuth:
|
||||
|
|
@ -245,7 +231,6 @@ components:
|
|||
in: header
|
||||
name: X-API-Key
|
||||
description: API key for authentication
|
||||
|
||||
schemas:
|
||||
HealthResponse:
|
||||
type: object
|
||||
|
|
@ -258,7 +243,6 @@ components:
|
|||
timestamp:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
Task:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -310,7 +294,6 @@ components:
|
|||
type: integer
|
||||
max_retries:
|
||||
type: integer
|
||||
|
||||
CreateTaskRequest:
|
||||
type: object
|
||||
required:
|
||||
|
|
@ -353,7 +336,6 @@ components:
|
|||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
|
||||
DatasetSpec:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -365,7 +347,6 @@ components:
|
|||
type: string
|
||||
mount_path:
|
||||
type: string
|
||||
|
||||
TaskList:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -379,7 +360,6 @@ components:
|
|||
type: integer
|
||||
offset:
|
||||
type: integer
|
||||
|
||||
QueueStats:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -398,7 +378,6 @@ components:
|
|||
workers:
|
||||
type: integer
|
||||
description: Active workers
|
||||
|
||||
Experiment:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -414,7 +393,6 @@ components:
|
|||
status:
|
||||
type: string
|
||||
enum: [active, archived, deleted]
|
||||
|
||||
CreateExperimentRequest:
|
||||
type: object
|
||||
required:
|
||||
|
|
@ -425,7 +403,6 @@ components:
|
|||
maxLength: 128
|
||||
description:
|
||||
type: string
|
||||
|
||||
JupyterService:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -444,7 +421,6 @@ components:
|
|||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
StartJupyterRequest:
|
||||
type: object
|
||||
required:
|
||||
|
|
@ -457,7 +433,6 @@ components:
|
|||
image:
|
||||
type: string
|
||||
default: jupyter/pytorch:latest
|
||||
|
||||
ErrorResponse:
|
||||
type: object
|
||||
required:
|
||||
|
|
@ -474,7 +449,6 @@ components:
|
|||
trace_id:
|
||||
type: string
|
||||
description: Support correlation ID
|
||||
|
||||
responses:
|
||||
BadRequest:
|
||||
description: Invalid request
|
||||
|
|
@ -486,7 +460,6 @@ components:
|
|||
error: Invalid request format
|
||||
code: BAD_REQUEST
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
Unauthorized:
|
||||
description: Authentication required
|
||||
content:
|
||||
|
|
@ -497,7 +470,6 @@ components:
|
|||
error: Invalid or missing API key
|
||||
code: UNAUTHORIZED
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
Forbidden:
|
||||
description: Insufficient permissions
|
||||
content:
|
||||
|
|
@ -508,7 +480,6 @@ components:
|
|||
error: Insufficient permissions
|
||||
code: FORBIDDEN
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
|
|
@ -519,7 +490,6 @@ components:
|
|||
error: Resource not found
|
||||
code: NOT_FOUND
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
ValidationError:
|
||||
description: Validation failed
|
||||
content:
|
||||
|
|
@ -530,7 +500,6 @@ components:
|
|||
error: Validation failed
|
||||
code: VALIDATION_ERROR
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
RateLimited:
|
||||
description: Too many requests
|
||||
content:
|
||||
|
|
@ -546,7 +515,6 @@ components:
|
|||
schema:
|
||||
type: integer
|
||||
description: Seconds until rate limit resets
|
||||
|
||||
InternalError:
|
||||
description: Internal server error
|
||||
content:
|
||||
|
|
|
|||
|
|
@ -42,24 +42,23 @@ const (
|
|||
)
|
||||
|
||||
// sendErrorPacket sends an error response packet to the client
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
func sendErrorPacket(conn *websocket.Conn, message string) error {
|
||||
err := map[string]any{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"code": ErrorCodeInvalidRequest,
|
||||
"message": message,
|
||||
"details": details,
|
||||
}
|
||||
return conn.WriteJSON(err)
|
||||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
// sendDataPacket sends a data response packet
|
||||
func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error {
|
||||
return conn.WriteJSON(map[string]interface{}{
|
||||
return conn.WriteJSON(map[string]any{
|
||||
"type": dataType,
|
||||
"payload": string(payload),
|
||||
})
|
||||
|
|
@ -86,9 +85,11 @@ func (h *Handler) HandleDatasetList(conn *websocket.Conn, payload []byte, user *
|
|||
|
||||
// HandleDatasetRegister handles registering a new dataset
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var][path_len:2][path:var]
|
||||
func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
func (h *Handler) HandleDatasetRegister(
|
||||
conn *websocket.Conn, payload []byte, user *auth.User,
|
||||
) error {
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "register dataset payload too short", "")
|
||||
return sendErrorPacket(conn, "register dataset payload too short")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
|
@ -96,7 +97,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
|
|||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
||||
return sendErrorPacket(conn, "invalid name length")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
|
@ -104,7 +105,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
|
|||
pathLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if pathLen < 0 || len(payload) < offset+pathLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid path length", "")
|
||||
return sendErrorPacket(conn, "invalid path length")
|
||||
}
|
||||
path := string(payload[offset : offset+pathLen])
|
||||
|
||||
|
|
@ -121,7 +122,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
|
|||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"name": name,
|
||||
"path": path,
|
||||
|
|
@ -134,7 +135,7 @@ func (h *Handler) HandleDatasetRegister(conn *websocket.Conn, payload []byte, us
|
|||
// Protocol: [api_key_hash:16][dataset_id_len:1][dataset_id:var]
|
||||
func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
|
||||
return sendErrorPacket(conn, "dataset info payload too short")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
|
@ -142,7 +143,7 @@ func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *
|
|||
datasetIDLen := int(payload[offset])
|
||||
offset++
|
||||
if datasetIDLen <= 0 || len(payload) < offset+datasetIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset ID length", "")
|
||||
return sendErrorPacket(conn, "invalid dataset ID length")
|
||||
}
|
||||
datasetID := string(payload[offset : offset+datasetIDLen])
|
||||
|
||||
|
|
@ -167,7 +168,7 @@ func (h *Handler) HandleDatasetInfo(conn *websocket.Conn, payload []byte, user *
|
|||
// Protocol: [api_key_hash:16][query_len:2][query:var]
|
||||
func (h *Handler) HandleDatasetSearch(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
|
||||
return sendErrorPacket(conn, "dataset search payload too short")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
|
@ -175,7 +176,7 @@ func (h *Handler) HandleDatasetSearch(conn *websocket.Conn, payload []byte, user
|
|||
queryLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if queryLen < 0 || len(payload) < offset+queryLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query length", "")
|
||||
return sendErrorPacket(conn, "invalid query length")
|
||||
}
|
||||
query := string(payload[offset : offset+queryLen])
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ import (
|
|||
|
||||
// HealthStatus represents the health status of the service
|
||||
type HealthStatus struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Checks map[string]string `json:"checks,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// HealthHandler handles /health requests
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package helpers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -29,12 +30,7 @@ func DBContextLong() (context.Context, context.CancelFunc) {
|
|||
|
||||
// StringSliceContains checks if a string slice contains a specific string.
|
||||
func StringSliceContains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return slices.Contains(slice, item)
|
||||
}
|
||||
|
||||
// StringSliceFilter filters a string slice based on a predicate.
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ import (
|
|||
|
||||
// ExperimentSetupResult contains the result of experiment setup operations
|
||||
type ExperimentSetupResult struct {
|
||||
CommitIDStr string
|
||||
Manifest *experiment.Manifest
|
||||
Err error
|
||||
Manifest *experiment.Manifest
|
||||
CommitIDStr string
|
||||
}
|
||||
|
||||
// RunExperimentSetup performs the common experiment setup operations:
|
||||
|
|
@ -149,12 +149,14 @@ func UpsertExperimentDBAsync(
|
|||
|
||||
// TaskEnqueueResult contains the result of task enqueueing
|
||||
type TaskEnqueueResult struct {
|
||||
TaskID string
|
||||
Err error
|
||||
TaskID string
|
||||
}
|
||||
|
||||
// BuildTaskMetadata creates the standard task metadata map.
|
||||
func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[string]string) map[string]string {
|
||||
func BuildTaskMetadata(
|
||||
commitIDStr, datasetID, paramsHash string, prov map[string]string,
|
||||
) map[string]string {
|
||||
meta := map[string]string{
|
||||
"commit_id": commitIDStr,
|
||||
"dataset_id": datasetID,
|
||||
|
|
@ -169,7 +171,9 @@ func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[strin
|
|||
}
|
||||
|
||||
// BuildSnapshotTaskMetadata creates task metadata for snapshot jobs.
|
||||
func BuildSnapshotTaskMetadata(commitIDStr, snapshotSHA string, prov map[string]string) map[string]string {
|
||||
func BuildSnapshotTaskMetadata(
|
||||
commitIDStr, snapshotSHA string, prov map[string]string,
|
||||
) map[string]string {
|
||||
meta := map[string]string{
|
||||
"commit_id": commitIDStr,
|
||||
"snapshot_sha256": snapshotSHA,
|
||||
|
|
|
|||
|
|
@ -99,20 +99,20 @@ func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) e
|
|||
return fmt.Errorf("missing commit id")
|
||||
}
|
||||
filesPath := expMgr.GetFilesPath(commitID)
|
||||
if err := os.MkdirAll(filesPath, 0750); err != nil {
|
||||
if err := os.MkdirAll(filesPath, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trainPath := filepath.Join(filesPath, "train.py")
|
||||
if _, err := os.Stat(trainPath); os.IsNotExist(err) {
|
||||
if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil {
|
||||
if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0o640); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
reqPath := filepath.Join(filesPath, "requirements.txt")
|
||||
if _, err := os.Stat(reqPath); os.IsNotExist(err) {
|
||||
if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil {
|
||||
if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0o640); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,10 +96,10 @@ func (m *TaskErrorMapper) MapJupyterError(t *queue.Task) ErrorCode {
|
|||
|
||||
// ResourceRequest represents resource requirements
|
||||
type ResourceRequest struct {
|
||||
GPUMemory string
|
||||
CPU int
|
||||
MemoryGB int
|
||||
GPU int
|
||||
GPUMemory string
|
||||
}
|
||||
|
||||
// ParseResourceRequest parses an optional resource request from bytes.
|
||||
|
|
@ -128,11 +128,11 @@ func ParseResourceRequest(payload []byte) (*ResourceRequest, error) {
|
|||
|
||||
// JSONResponseBuilder helps build JSON data responses
|
||||
type JSONResponseBuilder struct {
|
||||
data interface{}
|
||||
data any
|
||||
}
|
||||
|
||||
// NewJSONResponseBuilder creates a new JSON response builder
|
||||
func NewJSONResponseBuilder(data interface{}) *JSONResponseBuilder {
|
||||
func NewJSONResponseBuilder(data any) *JSONResponseBuilder {
|
||||
return &JSONResponseBuilder{data: data}
|
||||
}
|
||||
|
||||
|
|
@ -161,7 +161,7 @@ func IntPtr(i int) *int {
|
|||
}
|
||||
|
||||
// MarshalJSONOrEmpty marshals data to JSON or returns empty array on error
|
||||
func MarshalJSONOrEmpty(data interface{}) []byte {
|
||||
func MarshalJSONOrEmpty(data any) []byte {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return []byte("[]")
|
||||
|
|
@ -170,7 +170,7 @@ func MarshalJSONOrEmpty(data interface{}) []byte {
|
|||
}
|
||||
|
||||
// MarshalJSONBytes marshals data to JSON bytes with error handling
|
||||
func MarshalJSONBytes(data interface{}) ([]byte, error) {
|
||||
func MarshalJSONBytes(data any) ([]byte, error) {
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -53,21 +53,21 @@ func ValidateDepsManifest(
|
|||
|
||||
// ValidateCheck represents a validation check result
|
||||
type ValidateCheck struct {
|
||||
OK bool `json:"ok"`
|
||||
Expected string `json:"expected,omitempty"`
|
||||
Actual string `json:"actual,omitempty"`
|
||||
Details string `json:"details,omitempty"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
// ValidateReport represents a validation report
|
||||
type ValidateReport struct {
|
||||
OK bool `json:"ok"`
|
||||
Checks map[string]ValidateCheck `json:"checks"`
|
||||
CommitID string `json:"commit_id,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
Checks map[string]ValidateCheck `json:"checks"`
|
||||
TS string `json:"ts"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
TS string `json:"ts"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
// NewValidateReport creates a new validation report
|
||||
|
|
|
|||
|
|
@ -2,19 +2,19 @@ package api
|
|||
|
||||
// MonitoringConfig holds monitoring-related configuration
|
||||
type MonitoringConfig struct {
|
||||
Prometheus PrometheusConfig `yaml:"prometheus"`
|
||||
HealthChecks HealthChecksConfig `yaml:"health_checks"`
|
||||
Prometheus PrometheusConfig `yaml:"prometheus"`
|
||||
}
|
||||
|
||||
// PrometheusConfig holds Prometheus metrics configuration
|
||||
type PrometheusConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Port int `yaml:"port"`
|
||||
Path string `yaml:"path"`
|
||||
Port int `yaml:"port"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// HealthChecksConfig holds health check configuration
|
||||
type HealthChecksConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Interval string `yaml:"interval"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,33 +70,21 @@ const (
|
|||
|
||||
// ResponsePacket represents a structured response packet
|
||||
type ResponsePacket struct {
|
||||
PacketType byte
|
||||
Timestamp uint64
|
||||
|
||||
// Success fields
|
||||
SuccessMessage string
|
||||
|
||||
// Error fields
|
||||
ErrorCode byte
|
||||
ErrorMessage string
|
||||
ErrorDetails string
|
||||
|
||||
// Progress fields
|
||||
ProgressType byte
|
||||
DataType string
|
||||
SuccessMessage string
|
||||
LogMessage string
|
||||
ErrorMessage string
|
||||
ErrorDetails string
|
||||
ProgressMessage string
|
||||
StatusData string
|
||||
DataPayload []byte
|
||||
Timestamp uint64
|
||||
ProgressValue uint32
|
||||
ProgressTotal uint32
|
||||
ProgressMessage string
|
||||
|
||||
// Status fields
|
||||
StatusData string
|
||||
|
||||
// Data fields
|
||||
DataType string
|
||||
DataPayload []byte
|
||||
|
||||
// Log fields
|
||||
LogLevel byte
|
||||
LogMessage string
|
||||
ErrorCode byte
|
||||
ProgressType byte
|
||||
LogLevel byte
|
||||
PacketType byte
|
||||
}
|
||||
|
||||
// NewSuccessPacket creates a success response packet
|
||||
|
|
|
|||
|
|
@ -105,11 +105,9 @@ func (s *Server) registerOpenAPIRoutes(mux *http.ServeMux, jobsHandler *jobs.Han
|
|||
e.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Register Echo router at /v1/ prefix (and other generated paths)
|
||||
// Register Echo router at /v1/ prefix
|
||||
// These paths take precedence over legacy routes
|
||||
mux.Handle("/health", echoHandler)
|
||||
mux.Handle("/v1/", echoHandler)
|
||||
mux.Handle("/ws", echoHandler)
|
||||
|
||||
s.logger.Info("OpenAPI-generated routes registered with Echo router")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,18 +21,18 @@ import (
|
|||
|
||||
// Server represents the API server
|
||||
type Server struct {
|
||||
taskQueue queue.Backend
|
||||
config *ServerConfig
|
||||
httpServer *http.Server
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
taskQueue queue.Backend
|
||||
db *storage.DB
|
||||
sec *middleware.SecurityMiddleware
|
||||
cleanupFuncs []func()
|
||||
jupyterServiceMgr *jupyter.ServiceManager
|
||||
auditLogger *audit.Logger
|
||||
promMetrics *prommetrics.Metrics // Prometheus metrics
|
||||
validationMiddleware *apimiddleware.ValidationMiddleware // OpenAPI validation
|
||||
promMetrics *prommetrics.Metrics
|
||||
validationMiddleware *apimiddleware.ValidationMiddleware
|
||||
cleanupFuncs []func()
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
|
|
|
|||
|
|
@ -23,17 +23,17 @@ type QueueConfig struct {
|
|||
|
||||
// ServerConfig holds all server configuration
|
||||
type ServerConfig struct {
|
||||
Logging logging.Config `yaml:"logging"`
|
||||
BasePath string `yaml:"base_path"`
|
||||
DataDir string `yaml:"data_dir"`
|
||||
Auth auth.Config `yaml:"auth"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Server ServerSection `yaml:"server"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Monitoring MonitoringConfig `yaml:"monitoring"`
|
||||
Queue QueueConfig `yaml:"queue"`
|
||||
Redis RedisConfig `yaml:"redis"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Logging logging.Config `yaml:"logging"`
|
||||
Resources config.ResourceConfig `yaml:"resources"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
}
|
||||
|
||||
// ServerSection holds server-specific configuration
|
||||
|
|
@ -44,26 +44,26 @@ type ServerSection struct {
|
|||
|
||||
// TLSConfig holds TLS configuration
|
||||
type TLSConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
CertFile string `yaml:"cert_file"`
|
||||
KeyFile string `yaml:"key_file"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// SecurityConfig holds security-related configuration
|
||||
type SecurityConfig struct {
|
||||
ProductionMode bool `yaml:"production_mode"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
|
||||
AuditLogging AuditLog `yaml:"audit_logging"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
IPWhitelist []string `yaml:"ip_whitelist"`
|
||||
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
|
||||
ProductionMode bool `yaml:"production_mode"`
|
||||
}
|
||||
|
||||
// AuditLog holds audit logging configuration
|
||||
type AuditLog struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
LogPath string `yaml:"log_path"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// RateLimitConfig holds rate limiting configuration
|
||||
|
|
@ -75,17 +75,17 @@ type RateLimitConfig struct {
|
|||
|
||||
// LockoutConfig holds failed login lockout configuration
|
||||
type LockoutConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
MaxAttempts int `yaml:"max_attempts"`
|
||||
LockoutDuration string `yaml:"lockout_duration"`
|
||||
MaxAttempts int `yaml:"max_attempts"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// RedisConfig holds Redis connection configuration
|
||||
type RedisConfig struct {
|
||||
Addr string `yaml:"addr"`
|
||||
Password string `yaml:"password"`
|
||||
DB int `yaml:"db"`
|
||||
URL string `yaml:"url"`
|
||||
DB int `yaml:"db"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds database connection configuration
|
||||
|
|
@ -93,10 +93,10 @@ type DatabaseConfig struct {
|
|||
Type string `yaml:"type"`
|
||||
Connection string `yaml:"connection"`
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database"`
|
||||
Port int `yaml:"port"`
|
||||
}
|
||||
|
||||
// LoadServerConfig loads and validates server configuration
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -123,30 +124,28 @@ const (
|
|||
// Client represents a connected WebSocket client
|
||||
type Client struct {
|
||||
conn *websocket.Conn
|
||||
Type ClientType
|
||||
User string
|
||||
RemoteAddr string
|
||||
Type ClientType
|
||||
}
|
||||
|
||||
// Handler provides WebSocket handling
|
||||
type Handler struct {
|
||||
authConfig *auth.Config
|
||||
taskQueue queue.Backend
|
||||
datasetsHandler *datasets.Handler
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
dataDir string
|
||||
taskQueue queue.Backend
|
||||
clients map[*Client]bool
|
||||
db *storage.DB
|
||||
jupyterServiceMgr *jupyter.ServiceManager
|
||||
securityCfg *config.SecurityConfig
|
||||
auditLogger *audit.Logger
|
||||
upgrader websocket.Upgrader
|
||||
authConfig *auth.Config
|
||||
jobsHandler *jobs.Handler
|
||||
jupyterHandler *jupyterj.Handler
|
||||
datasetsHandler *datasets.Handler
|
||||
|
||||
// Client management for push updates
|
||||
clients map[*Client]bool
|
||||
clientsMu sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
dataDir string
|
||||
clientsMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHandler creates a new WebSocket handler
|
||||
|
|
@ -195,12 +194,7 @@ func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
|||
|
||||
// Production mode: strict checking against allowed origins
|
||||
if securityCfg != nil && securityCfg.ProductionMode {
|
||||
for _, allowed := range securityCfg.AllowedOrigins {
|
||||
if origin == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false // Reject if not in allowed list
|
||||
return slices.Contains(securityCfg.AllowedOrigins, origin)
|
||||
}
|
||||
|
||||
// Development mode: allow localhost and local network origins
|
||||
|
|
@ -231,7 +225,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
h.logger.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
h.logger.Warn("error closing websocket connection", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
h.handleConnection(conn)
|
||||
}
|
||||
|
|
@ -256,13 +254,15 @@ func (h *Handler) handleConnection(conn *websocket.Conn) {
|
|||
h.clientsMu.Lock()
|
||||
delete(h.clients, client)
|
||||
h.clientsMu.Unlock()
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
messageType, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
if websocket.IsUnexpectedCloseError(
|
||||
err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure,
|
||||
) {
|
||||
h.logger.Error("websocket read error", "error", err)
|
||||
}
|
||||
break
|
||||
|
|
@ -366,10 +366,14 @@ func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload
|
|||
|
||||
// Handler stubs - delegate to sub-packages
|
||||
|
||||
func (h *Handler) withAuth(conn *websocket.Conn, payload []byte, handler func(*auth.User) error) error {
|
||||
func (h *Handler) withAuth(
|
||||
conn *websocket.Conn, payload []byte, handler func(*auth.User) error,
|
||||
) error {
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
return handler(user)
|
||||
}
|
||||
|
|
@ -427,7 +431,9 @@ func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
|||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
|
@ -467,7 +473,9 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro
|
|||
// Check authentication and permissions
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
|
|
@ -547,7 +555,9 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) erro
|
|||
// Parse payload: [api_key_hash:16]
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Return queue status as Data packet
|
||||
|
|
@ -571,7 +581,9 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) erro
|
|||
|
||||
// selectDependencyManifest auto-detects dependency manifest file
|
||||
func selectDependencyManifest(filesPath string) (string, error) {
|
||||
for _, name := range []string{"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle"} {
|
||||
for _, name := range []string{
|
||||
"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle",
|
||||
} {
|
||||
if _, err := os.Stat(filepath.Join(filesPath, name)); err == nil {
|
||||
return name, nil
|
||||
}
|
||||
|
|
@ -584,7 +596,12 @@ func (h *Handler) Authenticate(payload []byte) (*auth.User, error) {
|
|||
if len(payload) < 16 {
|
||||
return nil, errors.New("payload too short")
|
||||
}
|
||||
return &auth.User{Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}}, nil
|
||||
return &auth.User{
|
||||
Name: "websocket-user",
|
||||
Admin: false,
|
||||
Roles: []string{"user"},
|
||||
Permissions: map[string]bool{"jobs:read": true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequirePermission checks user permission
|
||||
|
|
@ -604,7 +621,9 @@ func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error
|
|||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
|
|
@ -666,7 +685,9 @@ func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
|
|||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
|
|
@ -708,7 +729,9 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
|
|||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
|
|
@ -729,7 +752,10 @@ func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
|
|||
optsLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
||||
offset += 2
|
||||
if optsLen > 0 && len(payload) >= offset+int(optsLen) {
|
||||
json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
|
||||
err := json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid options JSON", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -764,7 +790,9 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
|
|||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsUpdate) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
|
|
@ -792,10 +820,17 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
|
|||
}
|
||||
|
||||
// Validate outcome status
|
||||
validOutcomes := map[string]bool{"validates": true, "refutes": true, "inconclusive": true, "partial": true}
|
||||
validOutcomes := map[string]bool{
|
||||
"validates": true, "refutes": true, "inconclusive": true, "partial": true,
|
||||
}
|
||||
outcome, ok := outcomeData["outcome"].(string)
|
||||
if !ok || !validOutcomes[outcome] {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome status", "must be: validates, refutes, inconclusive, or partial")
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeInvalidRequest,
|
||||
"invalid outcome status",
|
||||
"must be: validates, refutes, inconclusive, or partial",
|
||||
)
|
||||
}
|
||||
|
||||
h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
|
@ -14,6 +15,59 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
)
|
||||
|
||||
func (h *Handler) populateExperimentIntegrityMetadata(
|
||||
task *queue.Task,
|
||||
commitIDHex string,
|
||||
) (string, error) {
|
||||
if h.expManager == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Validate commit ID (defense-in-depth)
|
||||
if len(commitIDHex) != 40 {
|
||||
return "", fmt.Errorf("invalid commit id length")
|
||||
}
|
||||
if _, err := hex.DecodeString(commitIDHex); err != nil {
|
||||
return "", fmt.Errorf("invalid commit id format")
|
||||
}
|
||||
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
|
||||
depsName, err := selectDependencyManifest(filesPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
basePath := filepath.Clean(h.expManager.BasePath())
|
||||
manifestPath := filepath.Join(basePath, commitIDHex, "manifest.json")
|
||||
manifestPath = filepath.Clean(manifestPath)
|
||||
|
||||
if !strings.HasPrefix(manifestPath, basePath+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("path traversal detected")
|
||||
}
|
||||
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct {
|
||||
OverallSHA string `json:"overall_sha"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
|
||||
return depsName, nil
|
||||
}
|
||||
|
||||
// handleQueueJob handles the QueueJob opcode (0x01)
|
||||
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
||||
|
|
@ -69,27 +123,10 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
|||
Metadata: map[string]string{"commit_id": commitIDHex},
|
||||
}
|
||||
|
||||
// Auto-detect deps manifest and compute manifest SHA
|
||||
if h.expManager != nil {
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
depsName, _ := selectDependencyManifest(filesPath)
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct {
|
||||
OverallSHA string `json:"overall_sha"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
if h.taskQueue != nil {
|
||||
|
|
@ -98,7 +135,7 @@ func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
|||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"task_id": task.ID,
|
||||
})
|
||||
|
|
@ -144,26 +181,10 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt
|
|||
},
|
||||
}
|
||||
|
||||
if h.expManager != nil {
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
depsName, _ := selectDependencyManifest(filesPath)
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct {
|
||||
OverallSHA string `json:"overall_sha"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
if h.taskQueue != nil {
|
||||
|
|
@ -172,7 +193,7 @@ func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byt
|
|||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"task_id": task.ID,
|
||||
})
|
||||
|
|
@ -194,11 +215,13 @@ func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
|||
task, err := h.taskQueue.GetTaskByName(jobName)
|
||||
if err == nil && task != nil {
|
||||
task.Status = "cancelled"
|
||||
h.taskQueue.UpdateTask(task)
|
||||
if err := h.taskQueue.UpdateTask(task); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to cancel task", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"message": "Job cancelled",
|
||||
})
|
||||
|
|
@ -217,7 +240,7 @@ func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error {
|
|||
// pruneType := payload[offset]
|
||||
// value := binary.BigEndian.Uint32(payload[offset+1 : offset+5])
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"message": "Prune completed",
|
||||
"pruned": 0,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,14 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
)
|
||||
|
||||
// String constants shared by the validation handler. completed, running,
// failed and cancelled are compared against task.Status; running and finished
// also name the expected run-manifest location (finished is joined onto the
// experiment base path as a directory name).
const (
|
||||
completed = "completed"
|
||||
running = "running"
|
||||
finished = "finished"
|
||||
failed = "failed"
|
||||
cancelled = "cancelled"
|
||||
)
|
||||
|
||||
// handleValidateRequest handles the ValidateRequest opcode (0x16)
|
||||
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
|
||||
|
|
@ -25,7 +33,9 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
if mode == 0 {
|
||||
// Commit ID validation (basic)
|
||||
if len(payload) < 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "")
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "",
|
||||
)
|
||||
}
|
||||
commitIDLen := int(payload[18])
|
||||
if len(payload) < 19+commitIDLen {
|
||||
|
|
@ -34,7 +44,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
commitIDBytes := payload[19 : 19+commitIDLen]
|
||||
commitIDHex := fmt.Sprintf("%x", commitIDBytes)
|
||||
|
||||
report := map[string]interface{}{
|
||||
report := map[string]any{
|
||||
"ok": true,
|
||||
"commit_id": commitIDHex,
|
||||
}
|
||||
|
|
@ -44,7 +54,9 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
|
||||
// Task ID validation (mode=1) - full validation with checks
|
||||
if len(payload) < 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "")
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeInvalidRequest, "payload too short for task validation", "",
|
||||
)
|
||||
}
|
||||
|
||||
taskIDLen := int(payload[18])
|
||||
|
|
@ -54,7 +66,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
taskID := string(payload[19 : 19+taskIDLen])
|
||||
|
||||
// Initialize validation report
|
||||
checks := make(map[string]interface{})
|
||||
checks := make(map[string]any)
|
||||
ok := true
|
||||
|
||||
// Get task from queue
|
||||
|
|
@ -68,16 +80,16 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
}
|
||||
|
||||
// Run manifest validation - load manifest if it exists
|
||||
rmCheck := map[string]interface{}{"ok": true}
|
||||
rmCommitCheck := map[string]interface{}{"ok": true}
|
||||
rmLocCheck := map[string]interface{}{"ok": true}
|
||||
rmLifecycle := map[string]interface{}{"ok": true}
|
||||
rmCheck := map[string]any{"ok": true}
|
||||
rmCommitCheck := map[string]any{"ok": true}
|
||||
rmLocCheck := map[string]any{"ok": true}
|
||||
rmLifecycle := map[string]any{"ok": true}
|
||||
var narrativeWarnings, outcomeWarnings []string
|
||||
|
||||
// Determine expected location based on task status
|
||||
expectedLocation := "running"
|
||||
if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" {
|
||||
expectedLocation = "finished"
|
||||
expectedLocation := running
|
||||
if task.Status == completed || task.Status == cancelled || task.Status == failed {
|
||||
expectedLocation = finished
|
||||
}
|
||||
|
||||
// Try to load run manifest from appropriate location
|
||||
|
|
@ -90,14 +102,14 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
rm, rmLoadErr = manifest.LoadFromDir(jobDir)
|
||||
|
||||
// If not found and task is running, also check finished (wrong location test)
|
||||
if rmLoadErr != nil && task.Status == "running" {
|
||||
wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName)
|
||||
if rmLoadErr != nil && task.Status == running {
|
||||
wrongDir := filepath.Join(h.expManager.BasePath(), finished, task.JobName)
|
||||
rm, _ = manifest.LoadFromDir(wrongDir)
|
||||
if rm != nil {
|
||||
// Manifest exists but in wrong location
|
||||
rmLocCheck["ok"] = false
|
||||
rmLocCheck["expected"] = "running"
|
||||
rmLocCheck["actual"] = "finished"
|
||||
rmLocCheck["expected"] = running
|
||||
rmLocCheck["actual"] = finished
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
|
@ -105,7 +117,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
|
||||
if rm == nil {
|
||||
// No run manifest found
|
||||
if task.Status == "running" || task.Status == "completed" {
|
||||
if task.Status == running || task.Status == completed {
|
||||
rmCheck["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
|
|
@ -151,7 +163,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
checks["run_manifest_lifecycle"] = rmLifecycle
|
||||
|
||||
// Resources check
|
||||
resCheck := map[string]interface{}{"ok": true}
|
||||
resCheck := map[string]any{"ok": true}
|
||||
if task.CPU < 0 {
|
||||
resCheck["ok"] = false
|
||||
ok = false
|
||||
|
|
@ -159,7 +171,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
checks["resources"] = resCheck
|
||||
|
||||
// Snapshot check
|
||||
snapCheck := map[string]interface{}{"ok": true}
|
||||
snapCheck := map[string]any{"ok": true}
|
||||
if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
|
||||
// Verify snapshot SHA
|
||||
dataDir := h.dataDir
|
||||
|
|
@ -177,7 +189,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
}
|
||||
checks["snapshot"] = snapCheck
|
||||
|
||||
report := map[string]interface{}{
|
||||
report := map[string]any{
|
||||
"ok": ok,
|
||||
"checks": checks,
|
||||
"narrative_warnings": narrativeWarnings,
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ package config
|
|||
|
||||
// ResourceConfig centralizes pacing and resource optimization knobs.
|
||||
type ResourceConfig struct {
|
||||
PodmanCPUs string `yaml:"podman_cpus" toml:"podman_cpus"`
|
||||
PodmanMemory string `yaml:"podman_memory" toml:"podman_memory"`
|
||||
MaxWorkers int `yaml:"max_workers" toml:"max_workers"`
|
||||
DesiredRPSPerWorker int `yaml:"desired_rps_per_worker" toml:"desired_rps_per_worker"`
|
||||
RequestsPerSec int `yaml:"requests_per_sec" toml:"requests_per_sec"`
|
||||
PodmanCPUs string `yaml:"podman_cpus" toml:"podman_cpus"`
|
||||
PodmanMemory string `yaml:"podman_memory" toml:"podman_memory"`
|
||||
RequestBurstOverride int `yaml:"request_burst" toml:"request_burst"`
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,33 +7,23 @@ import (
|
|||
|
||||
// SecurityConfig holds security-related configuration
|
||||
type SecurityConfig struct {
|
||||
// AllowedOrigins lists the allowed origins for WebSocket connections
|
||||
// Empty list defaults to localhost-only in production mode
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
|
||||
// ProductionMode enables strict security checks
|
||||
ProductionMode bool `yaml:"production_mode"`
|
||||
|
||||
// APIKeyRotationDays is the number of days before API keys should be rotated
|
||||
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
|
||||
|
||||
// AuditLogging configuration
|
||||
AuditLogging AuditLoggingConfig `yaml:"audit_logging"`
|
||||
|
||||
// IPWhitelist for additional connection filtering
|
||||
IPWhitelist []string `yaml:"ip_whitelist"`
|
||||
AuditLogging AuditLoggingConfig `yaml:"audit_logging"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
IPWhitelist []string `yaml:"ip_whitelist"`
|
||||
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
|
||||
ProductionMode bool `yaml:"production_mode"`
|
||||
}
|
||||
|
||||
// AuditLoggingConfig holds audit logging configuration
|
||||
type AuditLoggingConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
LogPath string `yaml:"log_path"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// PrivacyConfig holds privacy enforcement configuration
|
||||
type PrivacyConfig struct {
|
||||
DefaultLevel string `yaml:"default_level"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
DefaultLevel string `yaml:"default_level"` // private, team, public, anonymized
|
||||
EnforceTeams bool `yaml:"enforce_teams"`
|
||||
AuditAccess bool `yaml:"audit_access"`
|
||||
}
|
||||
|
|
@ -58,9 +48,9 @@ type MonitoringConfig struct {
|
|||
|
||||
// PrometheusConfig holds Prometheus metrics configuration
|
||||
type PrometheusConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Port int `yaml:"port"`
|
||||
Path string `yaml:"path"`
|
||||
Port int `yaml:"port"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// HealthChecksConfig holds health check configuration
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@ type RedisConfig struct {
|
|||
// SSHConfig holds SSH connection settings
|
||||
type SSHConfig struct {
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port int `yaml:"port" json:"port"`
|
||||
User string `yaml:"user" json:"user"`
|
||||
KeyPath string `yaml:"key_path" json:"key_path"`
|
||||
KnownHosts string `yaml:"known_hosts" json:"known_hosts"`
|
||||
Port int `yaml:"port" json:"port"`
|
||||
}
|
||||
|
||||
// ExpandPath expands environment variables and tilde in a path
|
||||
|
|
|
|||
Loading…
Reference in a new issue