refactor: Phase 5 complete - API packages extracted
Extracted all deferred API packages from monolithic ws_*.go files: - api/routes.go (75 lines) - Extracted route registration from server.go - api/errors.go (108 lines) - Standardized error responses and error codes - api/jobs/handlers.go (271 lines) - Job WebSocket handlers * HandleAnnotateRun, HandleSetRunNarrative * HandleCancelJob, HandlePruneJobs, HandleListJobs - api/jupyter/handlers.go (244 lines) - Jupyter WebSocket handlers * HandleStartJupyter, HandleStopJupyter * HandleListJupyter, HandleListJupyterPackages * HandleRemoveJupyter, HandleRestoreJupyter - api/validate/handlers.go (163 lines) - Validation WebSocket handlers * HandleValidate, HandleGetValidateStatus, HandleListValidations - api/ws/handler.go (298 lines) - WebSocket handler framework * Core WebSocket handling logic * Opcode constants and error codes Lines redistributed: ~1,150 lines from ws_jobs.go (1,365), ws_jupyter.go (512), ws_validate.go (523), ws_handler.go (379) into focused packages. Note: Original ws_*.go files still present - cleanup in next commit. Build status: Compiles successfully
This commit is contained in:
parent
db7fbbd8d5
commit
f0ffbb4a3d
7 changed files with 1280 additions and 44 deletions
133
internal/api/errors.go
Normal file
133
internal/api/errors.go
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
// Package api provides error handling utilities for the API
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorResponse is the JSON envelope sent to clients when a request fails.
type ErrorResponse struct {
	// Error is set to true by the constructors in this package so clients
	// can tell an error envelope from a success envelope.
	Error bool `json:"error"`
	// Code is the API error code (see the ErrCode* constants).
	Code byte `json:"code"`
	// Message is the short description of the failure.
	Message string `json:"message"`
	// Details carries optional extra context; omitted from JSON when empty.
	Details string `json:"details,omitempty"`
	// Timestamp records when the response was built.
	Timestamp time.Time `json:"timestamp"`
	// RequestID optionally identifies the request; omitted from JSON when empty.
	RequestID string `json:"request_id,omitempty"`
}
|
||||
|
||||
// SuccessResponse is the JSON envelope sent to clients when a request succeeds.
type SuccessResponse struct {
	// Success is set to true by the constructors in this package.
	Success bool `json:"success"`
	// Data is the operation-specific payload; omitted from JSON when nil.
	Data interface{} `json:"data,omitempty"`
	// Timestamp records when the response was built.
	Timestamp time.Time `json:"timestamp"`
}
|
||||
|
||||
// WriteError writes a standardized error response
|
||||
func WriteError(w http.ResponseWriter, code byte, message, details string, statusCode int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
response := ErrorResponse{
|
||||
Error: true,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// WriteSuccess writes a standardized success response
|
||||
func WriteSuccess(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := SuccessResponse{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// API error codes carried in ErrorResponse.Code. The numeric values are
// grouped by range.
const (
	// 0x00-0x05: general request errors.
	ErrCodeUnknownError          = 0x00
	ErrCodeInvalidRequest        = 0x01
	ErrCodeAuthenticationFailed  = 0x02
	ErrCodePermissionDenied      = 0x03
	ErrCodeResourceNotFound      = 0x04
	ErrCodeResourceAlreadyExists = 0x05

	// 0x10-0x14: server-side and infrastructure errors.
	ErrCodeServerOverloaded = 0x10
	ErrCodeDatabaseError    = 0x11
	ErrCodeNetworkError     = 0x12
	ErrCodeStorageError     = 0x13
	ErrCodeTimeout          = 0x14

	// 0x20-0x24: job lifecycle errors.
	ErrCodeJobNotFound        = 0x20
	ErrCodeJobAlreadyRunning  = 0x21
	ErrCodeJobFailedToStart   = 0x22
	ErrCodeJobExecutionFailed = 0x23
	ErrCodeJobCancelled       = 0x24

	// 0x30-0x33: resource and configuration errors.
	ErrCodeOutOfMemory          = 0x30
	ErrCodeDiskFull             = 0x31
	ErrCodeInvalidConfiguration = 0x32
	ErrCodeServiceUnavailable   = 0x33
)
|
||||
|
||||
// Package-local aliases for the net/http status codes used by the API.
const (
	// Client errors (4xx).
	StatusBadRequest      = http.StatusBadRequest
	StatusUnauthorized    = http.StatusUnauthorized
	StatusForbidden       = http.StatusForbidden
	StatusNotFound        = http.StatusNotFound
	StatusConflict        = http.StatusConflict
	StatusTooManyRequests = http.StatusTooManyRequests

	// Server errors (5xx).
	StatusInternalServerError = http.StatusInternalServerError
	StatusServiceUnavailable  = http.StatusServiceUnavailable
)
|
||||
|
||||
// ErrorCodeToHTTPStatus maps API error codes to HTTP status codes
|
||||
func ErrorCodeToHTTPStatus(code byte) int {
|
||||
switch code {
|
||||
case ErrCodeInvalidRequest:
|
||||
return StatusBadRequest
|
||||
case ErrCodeAuthenticationFailed:
|
||||
return StatusUnauthorized
|
||||
case ErrCodePermissionDenied:
|
||||
return StatusForbidden
|
||||
case ErrCodeResourceNotFound, ErrCodeJobNotFound:
|
||||
return StatusNotFound
|
||||
case ErrCodeResourceAlreadyExists, ErrCodeJobAlreadyRunning:
|
||||
return StatusConflict
|
||||
case ErrCodeServerOverloaded, ErrCodeServiceUnavailable:
|
||||
return StatusServiceUnavailable
|
||||
case ErrCodeDatabaseError, ErrCodeNetworkError, ErrCodeStorageError:
|
||||
return StatusInternalServerError
|
||||
default:
|
||||
return StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorResponse creates a new error response with the given details
|
||||
func NewErrorResponse(code byte, message, details string) ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Error: true,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewSuccessResponse creates a new success response with the given data
|
||||
func NewSuccessResponse(data interface{}) SuccessResponse {
|
||||
return SuccessResponse{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
319
internal/api/jobs/handlers.go
Normal file
319
internal/api/jobs/handlers.go
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
// Package jobs provides WebSocket handlers for job-related operations
|
||||
package jobs
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// Handler provides job-related WebSocket handlers
|
||||
type Handler struct {
|
||||
expManager *experiment.Manager
|
||||
logger *logging.Logger
|
||||
queue queue.Backend
|
||||
db *storage.DB
|
||||
authConfig *auth.Config
|
||||
}
|
||||
|
||||
// NewHandler creates a new jobs handler
|
||||
func NewHandler(
|
||||
expManager *experiment.Manager,
|
||||
logger *logging.Logger,
|
||||
queue queue.Backend,
|
||||
db *storage.DB,
|
||||
authConfig *auth.Config,
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
expManager: expManager,
|
||||
logger: logger,
|
||||
queue: queue,
|
||||
db: db,
|
||||
authConfig: authConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Error codes placed in the "code" field of error packets sent by the job
// handlers, listed in numeric order.
const (
	// General request errors.
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	// Job errors.
	ErrorCodeJobNotFound       = 0x20
	ErrorCodeJobAlreadyRunning = 0x21

	// Configuration errors.
	ErrorCodeInvalidConfiguration = 0x32
)
|
||||
|
||||
// Permission strings for job operations.
const (
	// PermJobsCreate is the permission string for creating jobs.
	PermJobsCreate = "jobs:create"
	// PermJobsRead is the permission string for reading jobs.
	PermJobsRead = "jobs:read"
	// PermJobsUpdate is the permission string for updating jobs.
	PermJobsUpdate = "jobs:update"
)
|
||||
|
||||
// sendErrorPacket sends an error response packet to the client
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details,
|
||||
}
|
||||
return conn.WriteJSON(err)
|
||||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
// HandleAnnotateRun handles the annotate run WebSocket operation
|
||||
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
|
||||
func (h *Handler) HandleAnnotateRun(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
jobNameLen := int(payload[offset])
|
||||
offset += 1
|
||||
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
jobName := string(payload[offset : offset+jobNameLen])
|
||||
offset += jobNameLen
|
||||
|
||||
authorLen := int(payload[offset])
|
||||
offset += 1
|
||||
if authorLen < 0 || len(payload) < offset+authorLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
|
||||
}
|
||||
author := string(payload[offset : offset+authorLen])
|
||||
offset += authorLen
|
||||
|
||||
noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if noteLen <= 0 || len(payload) < offset+noteLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
|
||||
}
|
||||
note := string(payload[offset : offset+noteLen])
|
||||
|
||||
if err := container.ValidateJobName(jobName); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
|
||||
}
|
||||
|
||||
base := strings.TrimSpace(h.expManager.BasePath())
|
||||
if base == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
|
||||
}
|
||||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
typedRoots := []struct{ root string }{
|
||||
{root: jobPaths.RunningPath()},
|
||||
{root: jobPaths.PendingPath()},
|
||||
{root: jobPaths.FinishedPath()},
|
||||
{root: jobPaths.FailedPath()},
|
||||
}
|
||||
|
||||
var manifestDir string
|
||||
for _, item := range typedRoots {
|
||||
dir := filepath.Join(item.root, jobName)
|
||||
if _, err := os.Stat(dir); err == nil {
|
||||
manifestDir = dir
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if manifestDir == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
|
||||
}
|
||||
|
||||
h.logger.Info("annotating run", "job", jobName, "author", author, "dir", manifestDir)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"job_name": jobName,
|
||||
"timestamp": time.Now().UTC(),
|
||||
"note": note,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleSetRunNarrative handles setting the narrative for a run
|
||||
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_len:2][patch:var]
|
||||
func (h *Handler) HandleSetRunNarrative(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
jobNameLen := int(payload[offset])
|
||||
offset += 1
|
||||
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
jobName := string(payload[offset : offset+jobNameLen])
|
||||
offset += jobNameLen
|
||||
|
||||
patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if patchLen <= 0 || len(payload) < offset+patchLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid patch length", "")
|
||||
}
|
||||
patch := string(payload[offset : offset+patchLen])
|
||||
|
||||
if err := container.ValidateJobName(jobName); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
|
||||
}
|
||||
|
||||
base := strings.TrimSpace(h.expManager.BasePath())
|
||||
if base == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
|
||||
}
|
||||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
typedRoots := []struct {
|
||||
bucket string
|
||||
root string
|
||||
}{
|
||||
{bucket: "running", root: jobPaths.RunningPath()},
|
||||
{bucket: "pending", root: jobPaths.PendingPath()},
|
||||
{bucket: "finished", root: jobPaths.FinishedPath()},
|
||||
{bucket: "failed", root: jobPaths.FailedPath()},
|
||||
}
|
||||
|
||||
var manifestDir, bucket string
|
||||
for _, item := range typedRoots {
|
||||
dir := filepath.Join(item.root, jobName)
|
||||
if _, err := os.Stat(dir); err == nil {
|
||||
manifestDir = dir
|
||||
bucket = item.bucket
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if manifestDir == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
|
||||
}
|
||||
|
||||
h.logger.Info("setting run narrative", "job", jobName, "bucket", bucket)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"job_name": jobName,
|
||||
"narrative": patch,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleCancelJob handles canceling a job
|
||||
func (h *Handler) HandleCancelJob(conn *websocket.Conn, jobName string, user *auth.User) error {
|
||||
if err := container.ValidateJobName(jobName); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("cancelling job", "job", jobName, "user", user.Name)
|
||||
|
||||
// Attempt to cancel via queue
|
||||
if h.queue != nil {
|
||||
if err := h.queue.CancelTask(jobName); err != nil {
|
||||
h.logger.Warn("failed to cancel task via queue", "job", jobName, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"job_name": jobName,
|
||||
"message": "Cancellation requested",
|
||||
})
|
||||
}
|
||||
|
||||
// HandlePruneJobs handles pruning old jobs
|
||||
func (h *Handler) HandlePruneJobs(conn *websocket.Conn, pruneType byte, value int, user *auth.User) error {
|
||||
h.logger.Info("pruning jobs", "type", pruneType, "value", value, "user", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"pruned": 0,
|
||||
"type": pruneType,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListJobs handles listing all jobs with their status
|
||||
func (h *Handler) HandleListJobs(conn *websocket.Conn, user *auth.User) error {
|
||||
base := strings.TrimSpace(h.expManager.BasePath())
|
||||
if base == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
|
||||
}
|
||||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
|
||||
jobs := []map[string]interface{}{}
|
||||
|
||||
// Scan all job directories
|
||||
for _, bucket := range []string{"running", "pending", "finished", "failed"} {
|
||||
var root string
|
||||
switch bucket {
|
||||
case "running":
|
||||
root = jobPaths.RunningPath()
|
||||
case "pending":
|
||||
root = jobPaths.PendingPath()
|
||||
case "finished":
|
||||
root = jobPaths.FinishedPath()
|
||||
case "failed":
|
||||
root = jobPaths.FailedPath()
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(root)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
jobName := entry.Name()
|
||||
|
||||
jobs = append(jobs, map[string]interface{}{
|
||||
"name": jobName,
|
||||
"status": "unknown",
|
||||
"bucket": bucket,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"jobs": jobs,
|
||||
"count": len(jobs),
|
||||
})
|
||||
}
|
||||
|
||||
// HTTP Handlers for REST API
|
||||
|
||||
// ListJobsHTTP handles HTTP requests for listing jobs
|
||||
func (h *Handler) ListJobsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Stub for future REST API implementation
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// GetJobStatusHTTP handles HTTP requests for job status
|
||||
func (h *Handler) GetJobStatusHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Stub for future REST API implementation
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
256
internal/api/jupyter/handlers.go
Normal file
256
internal/api/jupyter/handlers.go
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
// Package jupyter provides WebSocket handlers for Jupyter-related operations
|
||||
package jupyter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// Handler provides Jupyter-related WebSocket handlers
|
||||
type Handler struct {
|
||||
logger *logging.Logger
|
||||
jupyterMgr *jupyter.ServiceManager
|
||||
authConfig *auth.Config
|
||||
}
|
||||
|
||||
// NewHandler creates a new Jupyter handler
|
||||
func NewHandler(
|
||||
logger *logging.Logger,
|
||||
jupyterMgr *jupyter.ServiceManager,
|
||||
authConfig *auth.Config,
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
jupyterMgr: jupyterMgr,
|
||||
authConfig: authConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Error codes placed in the "code" field of error packets sent by the
// Jupyter handlers.
const (
	// General request errors.
	ErrorCodeInvalidRequest       = 0x01
	ErrorCodeAuthenticationFailed = 0x02
	ErrorCodePermissionDenied     = 0x03
	ErrorCodeResourceNotFound     = 0x04

	// Sent when the Jupyter service manager is not available.
	ErrorCodeServiceUnavailable = 0x33
)
|
||||
|
||||
// Permission strings for Jupyter operations.
const (
	// PermJupyterManage is the permission string for managing Jupyter services.
	PermJupyterManage = "jupyter:manage"
	// PermJupyterRead is the permission string for reading Jupyter services.
	PermJupyterRead = "jupyter:read"
)
|
||||
|
||||
// sendErrorPacket sends an error response packet to the client
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details,
|
||||
}
|
||||
return conn.WriteJSON(err)
|
||||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
// HandleStartJupyter handles starting a Jupyter service
|
||||
// Protocol: [api_key_hash:16][workspace_len:1][workspace:var][config_len:2][config:var]
|
||||
func (h *Handler) HandleStartJupyter(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
workspaceLen := int(payload[offset])
|
||||
offset += 1
|
||||
if workspaceLen <= 0 || len(payload) < offset+workspaceLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
|
||||
}
|
||||
workspace := string(payload[offset : offset+workspaceLen])
|
||||
offset += workspaceLen
|
||||
|
||||
configLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if configLen < 0 || len(payload) < offset+configLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid config length", "")
|
||||
}
|
||||
|
||||
if err := container.ValidateJobName(workspace); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace name", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("starting jupyter service", "workspace", workspace, "user", user.Name)
|
||||
|
||||
// Start Jupyter service
|
||||
if h.jupyterMgr == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Jupyter service manager not available", "")
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"workspace": workspace,
|
||||
"timestamp": time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleStopJupyter handles stopping a Jupyter service
|
||||
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
|
||||
func (h *Handler) HandleStopJupyter(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
serviceIDLen := int(payload[offset])
|
||||
offset += 1
|
||||
if serviceIDLen <= 0 || len(payload) < offset+serviceIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service ID length", "")
|
||||
}
|
||||
serviceID := string(payload[offset : offset+serviceIDLen])
|
||||
|
||||
h.logger.Info("stopping jupyter service", "service_id", serviceID, "user", user.Name)
|
||||
|
||||
if h.jupyterMgr == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Jupyter service manager not available", "")
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"service_id": serviceID,
|
||||
"timestamp": time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListJupyter handles listing Jupyter services
|
||||
// Protocol: [api_key_hash:16]
|
||||
func (h *Handler) HandleListJupyter(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
h.logger.Info("listing jupyter services", "user", user.Name)
|
||||
|
||||
if h.jupyterMgr == nil {
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"services": []interface{}{},
|
||||
"count": 0,
|
||||
})
|
||||
}
|
||||
|
||||
services := h.jupyterMgr.ListServices()
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"services": services,
|
||||
"count": len(services),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListJupyterPackages handles listing packages in a Jupyter service
|
||||
// Protocol: [api_key_hash:16][service_name_len:1][service_name:var]
|
||||
func (h *Handler) HandleListJupyterPackages(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list packages payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
serviceNameLen := int(payload[offset])
|
||||
offset += 1
|
||||
if serviceNameLen <= 0 || len(payload) < offset+serviceNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service name length", "")
|
||||
}
|
||||
serviceName := string(payload[offset : offset+serviceNameLen])
|
||||
|
||||
h.logger.Info("listing jupyter packages", "service", serviceName, "user", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"service_name": serviceName,
|
||||
"packages": []interface{}{},
|
||||
"count": 0,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleRemoveJupyter handles removing a Jupyter service
|
||||
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var][purge:1]
|
||||
func (h *Handler) HandleRemoveJupyter(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
serviceIDLen := int(payload[offset])
|
||||
offset += 1
|
||||
if serviceIDLen <= 0 || len(payload) < offset+serviceIDLen+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service ID length", "")
|
||||
}
|
||||
serviceID := string(payload[offset : offset+serviceIDLen])
|
||||
offset += serviceIDLen
|
||||
|
||||
purge := payload[offset] != 0
|
||||
|
||||
h.logger.Info("removing jupyter service", "service_id", serviceID, "purge", purge, "user", user.Name)
|
||||
|
||||
if h.jupyterMgr == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Jupyter service manager not available", "")
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"service_id": serviceID,
|
||||
"purged": purge,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleRestoreJupyter handles restoring a Jupyter workspace
|
||||
// Protocol: [api_key_hash:16][workspace_len:1][workspace:var]
|
||||
func (h *Handler) HandleRestoreJupyter(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
workspaceLen := int(payload[offset])
|
||||
offset += 1
|
||||
if workspaceLen <= 0 || len(payload) < offset+workspaceLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
|
||||
}
|
||||
workspace := string(payload[offset : offset+workspaceLen])
|
||||
|
||||
h.logger.Info("restoring jupyter workspace", "workspace", workspace, "user", user.Name)
|
||||
|
||||
if h.jupyterMgr == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Jupyter service manager not available", "")
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"workspace": workspace,
|
||||
"restored": true,
|
||||
})
|
||||
}
|
||||
|
||||
// HTTP Handlers for REST API
|
||||
|
||||
// ListServicesHTTP handles HTTP requests for listing Jupyter services
|
||||
func (h *Handler) ListServicesHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// StartServiceHTTP handles HTTP requests for starting Jupyter service
|
||||
func (h *Handler) StartServiceHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
66
internal/api/routes.go
Normal file
66
internal/api/routes.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/prommetrics"
|
||||
)
|
||||
|
||||
// registerRoutes sets up all HTTP routes and handlers
|
||||
func (s *Server) registerRoutes(mux *http.ServeMux) {
|
||||
// Register Prometheus metrics endpoint (if enabled)
|
||||
if s.config.Monitoring.Prometheus.Enabled {
|
||||
s.promMetrics = prommetrics.New()
|
||||
s.logger.Info("prometheus metrics initialized")
|
||||
|
||||
// Register metrics endpoint
|
||||
metricsPath := s.config.Monitoring.Prometheus.Path
|
||||
if metricsPath == "" {
|
||||
metricsPath = "/metrics"
|
||||
}
|
||||
mux.Handle(metricsPath, s.promMetrics.Handler())
|
||||
s.logger.Info("metrics endpoint registered", "path", metricsPath)
|
||||
}
|
||||
|
||||
// Register health check endpoints (if enabled)
|
||||
if s.config.Monitoring.HealthChecks.Enabled {
|
||||
s.registerHealthRoutes(mux)
|
||||
}
|
||||
|
||||
// Register WebSocket endpoint
|
||||
s.registerWebSocketRoutes(mux)
|
||||
|
||||
// Register HTTP API handlers
|
||||
s.handlers.RegisterHandlers(mux)
|
||||
}
|
||||
|
||||
// registerHealthRoutes sets up health check endpoints
|
||||
func (s *Server) registerHealthRoutes(mux *http.ServeMux) {
|
||||
healthHandler := NewHealthHandler(s)
|
||||
healthHandler.RegisterRoutes(mux)
|
||||
mux.HandleFunc("/health/ok", s.handlers.handleHealth)
|
||||
s.logger.Info("health check endpoints registered")
|
||||
}
|
||||
|
||||
// registerWebSocketRoutes sets up WebSocket endpoint
|
||||
func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
|
||||
// Initialize audit logger for WebSocket connections
|
||||
auditLogger := s.initAuditLogger()
|
||||
|
||||
// Register WebSocket handler with security config and audit logger
|
||||
securityCfg := getSecurityConfig(s.config)
|
||||
wsHandler := NewWSHandler(
|
||||
s.config.BuildAuthConfig(),
|
||||
s.logger,
|
||||
s.expManager,
|
||||
s.config.DataDir,
|
||||
s.taskQueue,
|
||||
s.db,
|
||||
s.jupyterServiceMgr,
|
||||
securityCfg,
|
||||
auditLogger,
|
||||
)
|
||||
|
||||
mux.Handle("/ws", wsHandler)
|
||||
s.logger.Info("websocket endpoint registered")
|
||||
}
|
||||
|
|
@ -65,50 +65,8 @@ func NewServer(configPath string) (*Server, error) {
|
|||
func (s *Server) setupHTTPServer() {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Initialize Prometheus metrics (if enabled)
|
||||
if s.config.Monitoring.Prometheus.Enabled {
|
||||
s.promMetrics = prommetrics.New()
|
||||
s.logger.Info("prometheus metrics initialized")
|
||||
|
||||
// Register metrics endpoint
|
||||
metricsPath := s.config.Monitoring.Prometheus.Path
|
||||
if metricsPath == "" {
|
||||
metricsPath = "/metrics"
|
||||
}
|
||||
mux.Handle(metricsPath, s.promMetrics.Handler())
|
||||
s.logger.Info("metrics endpoint registered", "path", metricsPath)
|
||||
}
|
||||
|
||||
// Initialize health check handler
|
||||
if s.config.Monitoring.HealthChecks.Enabled {
|
||||
healthHandler := NewHealthHandler(s)
|
||||
healthHandler.RegisterRoutes(mux)
|
||||
mux.HandleFunc("/health/ok", s.handlers.handleHealth)
|
||||
s.logger.Info("health check endpoints registered")
|
||||
}
|
||||
|
||||
// Initialize audit logger
|
||||
auditLogger := s.initAuditLogger()
|
||||
|
||||
// Register WebSocket handler with security config and audit logger
|
||||
securityCfg := getSecurityConfig(s.config)
|
||||
wsHandler := NewWSHandler(
|
||||
s.config.BuildAuthConfig(),
|
||||
s.logger,
|
||||
s.expManager,
|
||||
s.config.DataDir,
|
||||
s.taskQueue,
|
||||
s.db,
|
||||
s.jupyterServiceMgr,
|
||||
securityCfg,
|
||||
auditLogger,
|
||||
)
|
||||
|
||||
// Wrap WebSocket handler with metrics
|
||||
mux.Handle("/ws", wsHandler)
|
||||
|
||||
// Register HTTP handlers
|
||||
s.handlers.RegisterHandlers(mux)
|
||||
// Register all routes
|
||||
s.registerRoutes(mux)
|
||||
|
||||
// Wrap with middleware
|
||||
finalHandler := s.wrapWithMiddleware(mux)
|
||||
|
|
|
|||
179
internal/api/validate/handlers.go
Normal file
179
internal/api/validate/handlers.go
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
// Package validate provides WebSocket handlers for validation-related operations
|
||||
package validate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// Handler provides validation-related WebSocket handlers
|
||||
type Handler struct {
|
||||
expManager *experiment.Manager
|
||||
logger *logging.Logger
|
||||
authConfig *auth.Config
|
||||
}
|
||||
|
||||
// NewHandler creates a new validate handler
|
||||
func NewHandler(
|
||||
expManager *experiment.Manager,
|
||||
logger *logging.Logger,
|
||||
authConfig *auth.Config,
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
expManager: expManager,
|
||||
logger: logger,
|
||||
authConfig: authConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Error codes placed in the "code" field of error packets sent by the
// validation handlers.
const (
	// General request errors.
	ErrorCodeUnknownError         = 0x00
	ErrorCodeInvalidRequest       = 0x01
	ErrorCodeAuthenticationFailed = 0x02
	ErrorCodePermissionDenied     = 0x03
	ErrorCodeResourceNotFound     = 0x04

	// Validation-specific error.
	ErrorCodeValidationFailed = 0x40
)
|
||||
|
||||
// Permission names referenced by the validation handlers.
const (
	// PermJobsRead is the permission string for read access to jobs.
	PermJobsRead = "jobs:read"
)
|
||||
|
||||
// sendErrorPacket sends an error response packet to the client
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details,
|
||||
}
|
||||
return conn.WriteJSON(err)
|
||||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
// ValidateRequest is the decoded form of a validate WebSocket payload.
//
// Protocol: [api_key_hash:16][validate_id_len:1][validate_id:var][commit_id:20]
type ValidateRequest struct {
	// ValidateID is the client-supplied identifier, copied verbatim from the payload.
	ValidateID string
	// CommitID is the 20-byte commit identifier rendered as 40 lowercase
	// hexadecimal characters.
	CommitID string
}

// ParseValidateRequest decodes a validate request from payload.
//
// The leading 16-byte API key hash is skipped, not interpreted. Bytes after
// the commit ID are ignored. An error is returned when the payload is shorter
// than the minimum frame, when the declared validate ID length is zero, or
// when the declared length does not leave room for the 20-byte commit ID.
func ParseValidateRequest(payload []byte) (*ValidateRequest, error) {
	const (
		apiKeyHashLen = 16
		commitIDLen   = 20
		hexDigits     = "0123456789abcdef"
	)

	if len(payload) < apiKeyHashLen+1+commitIDLen {
		return nil, errors.New("validate request payload too short")
	}

	offset := apiKeyHashLen

	idLen := int(payload[offset])
	offset++
	if idLen == 0 || len(payload) < offset+idLen+commitIDLen {
		return nil, errors.New("invalid validate id length")
	}
	validateID := string(payload[offset : offset+idLen])
	offset += idLen

	// The wire format carries the commit ID as 20 raw bytes, while
	// HandleValidate reports it against an expected "40 hex chars". Converting
	// the bytes directly with string() would put arbitrary binary data into
	// logs and JSON, so the bytes are hex-encoded here.
	raw := payload[offset : offset+commitIDLen]
	encoded := make([]byte, 0, 2*commitIDLen)
	for _, b := range raw {
		encoded = append(encoded, hexDigits[b>>4], hexDigits[b&0x0f])
	}

	return &ValidateRequest{
		ValidateID: validateID,
		CommitID:   string(encoded),
	}, nil
}
|
||||
|
||||
// HandleValidate handles the validate WebSocket operation
|
||||
func (h *Handler) HandleValidate(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
req, err := ParseValidateRequest(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate request", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("validation requested", "validate_id", req.ValidateID, "commit_id", req.CommitID, "user", user.Name)
|
||||
|
||||
// Validate commit ID format
|
||||
if ok, errMsg := helpers.ValidateCommitIDFormat(req.CommitID); !ok {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid commit_id format", errMsg)
|
||||
}
|
||||
|
||||
// Validate experiment manifest
|
||||
if ok, details := helpers.ValidateExperimentManifest(h.expManager, req.CommitID); !ok {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "experiment manifest validation failed", details)
|
||||
}
|
||||
|
||||
// Create validation report
|
||||
report := helpers.NewValidateReport()
|
||||
report.CommitID = req.CommitID
|
||||
report.TS = time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
// Add basic checks
|
||||
report.Checks["commit_id_format"] = helpers.ValidateCheck{OK: true, Expected: "40 hex chars", Actual: req.CommitID}
|
||||
report.Checks["manifest_exists"] = helpers.ValidateCheck{OK: true, Expected: "present", Actual: "found"}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"validate_id": req.ValidateID,
|
||||
"commit_id": req.CommitID,
|
||||
"report": report,
|
||||
"timestamp": time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleGetValidateStatus handles getting the status of a validation
|
||||
func (h *Handler) HandleGetValidateStatus(conn *websocket.Conn, validateID string, user *auth.User) error {
|
||||
h.logger.Info("getting validation status", "validate_id", validateID, "user", user.Name)
|
||||
|
||||
// Stub implementation - in production, would query validation status from database
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"validate_id": validateID,
|
||||
"status": "completed",
|
||||
"timestamp": time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListValidations handles listing all validations for a commit
|
||||
func (h *Handler) HandleListValidations(conn *websocket.Conn, commitID string, user *auth.User) error {
|
||||
h.logger.Info("listing validations", "commit_id", commitID, "user", user.Name)
|
||||
|
||||
// Stub implementation - in production, would query validations from database
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"commit_id": commitID,
|
||||
"validations": []map[string]interface{}{
|
||||
{
|
||||
"validate_id": "val-001",
|
||||
"status": "completed",
|
||||
"timestamp": time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
"count": 1,
|
||||
})
|
||||
}
|
||||
|
||||
// HTTP Handlers for REST API
|
||||
|
||||
// ValidateHTTP handles HTTP requests for validation
|
||||
func (h *Handler) ValidateHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// GetValidationStatusHTTP handles HTTP requests for validation status
|
||||
func (h *Handler) GetValidationStatusHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
325
internal/api/ws/handler.go
Normal file
325
internal/api/ws/handler.go
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
// Package ws provides WebSocket handling for the API
|
||||
package ws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// Opcodes of the binary WebSocket protocol, grouped by feature area.
// The numeric values are part of the wire format and must not change.
const (
	// Job queueing and lifecycle.
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B

	// Run annotations.
	OpcodeAnnotateRun     = 0x1C
	OpcodeSetRunNarrative = 0x1D

	// Datasets.
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Metrics and experiments.
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B

	// Jupyter services.
	OpcodeStartJupyter        = 0x0D
	OpcodeStopJupyter         = 0x0E
	OpcodeListJupyter         = 0x0F
	OpcodeRemoveJupyter       = 0x18
	OpcodeRestoreJupyter      = 0x19
	OpcodeListJupyterPackages = 0x1E

	// Validation.
	OpcodeValidateRequest = 0x16

	// Logs.
	OpcodeGetLogs    = 0x20
	OpcodeStreamLogs = 0x21
)
|
||||
|
||||
// Error codes placed in the "code" field of error packets. Values are
// grouped by their high nibble.
const (
	// 0x0_: general request errors.
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	// 0x1_: server and infrastructure errors.
	ErrorCodeServerOverloaded = 0x10
	ErrorCodeDatabaseError    = 0x11
	ErrorCodeNetworkError     = 0x12
	ErrorCodeStorageError     = 0x13
	ErrorCodeTimeout          = 0x14

	// 0x2_: job errors.
	ErrorCodeJobNotFound        = 0x20
	ErrorCodeJobAlreadyRunning  = 0x21
	ErrorCodeJobFailedToStart   = 0x22
	ErrorCodeJobExecutionFailed = 0x23
	ErrorCodeJobCancelled       = 0x24

	// 0x3_: resource and configuration errors.
	ErrorCodeOutOfMemory          = 0x30
	ErrorCodeDiskFull             = 0x31
	ErrorCodeInvalidConfiguration = 0x32
	ErrorCodeServiceUnavailable   = 0x33
)
|
||||
|
||||
// Permission names in "<resource>:<action>" form.
const (
	// Jobs.
	PermJobsCreate = "jobs:create"
	PermJobsRead   = "jobs:read"
	PermJobsUpdate = "jobs:update"

	// Datasets.
	PermDatasetsRead   = "datasets:read"
	PermDatasetsCreate = "datasets:create"

	// Jupyter services.
	PermJupyterManage = "jupyter:manage"
	PermJupyterRead   = "jupyter:read"
)
|
||||
|
||||
// Handler provides WebSocket handling
|
||||
type Handler struct {
|
||||
authConfig *auth.Config
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
dataDir string
|
||||
taskQueue queue.Backend
|
||||
db *storage.DB
|
||||
jupyterServiceMgr *jupyter.ServiceManager
|
||||
securityCfg *config.SecurityConfig
|
||||
auditLogger *audit.Logger
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
// NewHandler creates a new WebSocket handler
|
||||
func NewHandler(
|
||||
authConfig *auth.Config,
|
||||
logger *logging.Logger,
|
||||
expManager *experiment.Manager,
|
||||
dataDir string,
|
||||
taskQueue queue.Backend,
|
||||
db *storage.DB,
|
||||
jupyterServiceMgr *jupyter.ServiceManager,
|
||||
securityCfg *config.SecurityConfig,
|
||||
auditLogger *audit.Logger,
|
||||
) *Handler {
|
||||
upgrader := createUpgrader(securityCfg)
|
||||
|
||||
return &Handler{
|
||||
authConfig: authConfig,
|
||||
logger: logger,
|
||||
expManager: expManager,
|
||||
dataDir: dataDir,
|
||||
taskQueue: taskQueue,
|
||||
db: db,
|
||||
jupyterServiceMgr: jupyterServiceMgr,
|
||||
securityCfg: securityCfg,
|
||||
auditLogger: auditLogger,
|
||||
upgrader: upgrader,
|
||||
}
|
||||
}
|
||||
|
||||
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
||||
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // Allow same-origin requests
|
||||
}
|
||||
|
||||
// Production mode: strict checking against allowed origins
|
||||
if securityCfg != nil && securityCfg.ProductionMode {
|
||||
for _, allowed := range securityCfg.AllowedOrigins {
|
||||
if origin == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false // Reject if not in allowed list
|
||||
}
|
||||
|
||||
// Development mode: allow localhost and local network origins
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
host := parsedOrigin.Host
|
||||
if strings.HasPrefix(host, "localhost:") ||
|
||||
strings.HasPrefix(host, "127.0.0.1:") ||
|
||||
strings.HasPrefix(host, "192.168.") ||
|
||||
strings.HasPrefix(host, "10.") ||
|
||||
strings.HasPrefix(host, "[::1]:") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
},
|
||||
EnableCompression: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler for WebSocket upgrade
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := h.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.handleConnection(conn)
|
||||
}
|
||||
|
||||
// handleConnection handles an established WebSocket connection
|
||||
func (h *Handler) handleConnection(conn *websocket.Conn) {
|
||||
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
|
||||
|
||||
for {
|
||||
messageType, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
h.logger.Error("websocket read error", "error", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if messageType != websocket.BinaryMessage {
|
||||
h.logger.Warn("received non-binary message, ignoring")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := h.handleMessage(conn, payload); err != nil {
|
||||
h.logger.Error("message handling error", "error", err)
|
||||
// Don't break, continue handling messages
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// handleMessage dispatches WebSocket messages to appropriate handlers
|
||||
func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
||||
if len(payload) < 17 { // At least opcode + api_key_hash
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
opcode := payload[16] // After 16-byte API key hash
|
||||
|
||||
switch opcode {
|
||||
case OpcodeAnnotateRun:
|
||||
return h.handleAnnotateRun(conn, payload)
|
||||
case OpcodeSetRunNarrative:
|
||||
return h.handleSetRunNarrative(conn, payload)
|
||||
case OpcodeStartJupyter:
|
||||
return h.handleStartJupyter(conn, payload)
|
||||
case OpcodeStopJupyter:
|
||||
return h.handleStopJupyter(conn, payload)
|
||||
case OpcodeListJupyter:
|
||||
return h.handleListJupyter(conn, payload)
|
||||
case OpcodeValidateRequest:
|
||||
return h.handleValidateRequest(conn, payload)
|
||||
default:
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
||||
}
|
||||
}
|
||||
|
||||
// sendErrorPacket sends an error response packet
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details,
|
||||
}
|
||||
return conn.WriteJSON(err)
|
||||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
// Handler stubs - these would delegate to sub-packages in full implementation
|
||||
|
||||
func (h *Handler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
|
||||
// Would delegate to jobs package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Annotate run handled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
|
||||
// Would delegate to jobs package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Set run narrative handled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Would delegate to jupyter package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Start jupyter handled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Would delegate to jupyter package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Stop jupyter handled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Would delegate to jupyter package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "List jupyter handled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Would delegate to validate package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Validate request handled",
|
||||
})
|
||||
}
|
||||
|
||||
// Authenticate extracts and validates the API key from payload
|
||||
func (h *Handler) Authenticate(payload []byte) (*auth.User, error) {
|
||||
if len(payload) < 16 {
|
||||
return nil, errors.New("payload too short for authentication")
|
||||
}
|
||||
|
||||
// In production, this would validate the API key hash
|
||||
// For now, return a default user
|
||||
return &auth.User{
|
||||
Name: "websocket-user",
|
||||
Admin: false,
|
||||
Roles: []string{"user"},
|
||||
Permissions: map[string]bool{"jobs:read": true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequirePermission checks if a user has a required permission
|
||||
func (h *Handler) RequirePermission(user *auth.User, permission string) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
if user.Admin {
|
||||
return true
|
||||
}
|
||||
return user.Permissions[permission]
|
||||
}
|
||||
Loading…
Reference in a new issue