Update API layer for scheduler integration: - WebSocket handlers with scheduler protocol support - Jobs WebSocket endpoint with priority queue integration - Validation middleware for scheduler messages - Server configuration with security hardening - Protocol definitions for worker-scheduler communication - Dataset handlers with tenant isolation checks - Response helpers with audit context - OpenAPI spec updates for new endpoints
907 lines
27 KiB
Go
907 lines
27 KiB
Go
// Package ws provides WebSocket handling for the API
|
|
package ws
|
|
|
|
import (
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"

	"github.com/jfraeys/fetch_ml/internal/api/datasets"
	"github.com/jfraeys/fetch_ml/internal/api/jobs"
	jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
)
|
|
|
|
// Response packet types. The value is written as the first byte of every
// packet the server sends. They are duplicated from the api package to avoid
// an import cycle.
const (
	// PacketTypeSuccess marks a successful response.
	PacketTypeSuccess = 0x00
	// PacketTypeError marks an error response.
	PacketTypeError = 0x01
	// PacketTypeProgress marks a progress packet.
	PacketTypeProgress = 0x02
	// PacketTypeStatus marks a status packet.
	PacketTypeStatus = 0x03
	// PacketTypeData marks a typed data packet.
	PacketTypeData = 0x04
	// PacketTypeLog marks a log packet.
	PacketTypeLog = 0x05
)
|
|
|
|
// Opcodes for the binary WebSocket protocol. The opcode is the first byte of
// every request frame. Values are part of the wire format and must not be
// renumbered.
const (
	// Job queue opcodes
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeValidateRequest      = 0x16
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B

	// Dataset opcodes
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Metric and experiment opcodes
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B

	// Run annotation opcodes
	OpcodeAnnotateRun     = 0x1C
	OpcodeSetRunNarrative = 0x1D

	// Jupyter opcodes
	OpcodeStartJupyter        = 0x0D
	OpcodeStopJupyter         = 0x0E
	OpcodeListJupyter         = 0x0F
	OpcodeRemoveJupyter       = 0x18
	OpcodeRestoreJupyter      = 0x19
	OpcodeListJupyterPackages = 0x1E

	// Logs opcodes
	OpcodeGetLogs    = 0x20
	OpcodeStreamLogs = 0x21

	// Run analysis opcodes
	OpcodeCompareRuns   = 0x30
	OpcodeFindRuns      = 0x31
	OpcodeExportRun     = 0x32
	OpcodeSetRunOutcome = 0x33
)
|
|
|
|
// Error codes. The code is sent as the single-byte first section of an error
// packet. Values are part of the wire format and must not be renumbered.
const (
	// Request-level errors
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	// Server and infrastructure errors
	ErrorCodeServerOverloaded = 0x10
	ErrorCodeDatabaseError    = 0x11
	ErrorCodeNetworkError     = 0x12
	ErrorCodeStorageError     = 0x13
	ErrorCodeTimeout          = 0x14

	// Job errors
	ErrorCodeJobNotFound        = 0x20
	ErrorCodeJobAlreadyRunning  = 0x21
	ErrorCodeJobFailedToStart   = 0x22
	ErrorCodeJobExecutionFailed = 0x23
	ErrorCodeJobCancelled       = 0x24

	// Resource and configuration errors
	ErrorCodeOutOfMemory          = 0x30
	ErrorCodeDiskFull             = 0x31
	ErrorCodeInvalidConfiguration = 0x32
	ErrorCodeServiceUnavailable   = 0x33
)
|
|
|
|
// Permissions are the permission names looked up in a user's permission set
// by RequirePermission.
const (
	// Job permissions
	PermJobsCreate = "jobs:create"
	PermJobsRead   = "jobs:read"
	PermJobsUpdate = "jobs:update"

	// Dataset permissions
	PermDatasetsRead   = "datasets:read"
	PermDatasetsCreate = "datasets:create"

	// Jupyter permissions
	PermJupyterManage = "jupyter:manage"
	PermJupyterRead   = "jupyter:read"
)
|
|
|
|
// ClientType identifies which kind of front end is on the other side of a
// WebSocket connection.
type ClientType int

// Known client types.
const (
	// ClientTypeCLI is the zero value of ClientType.
	ClientTypeCLI ClientType = 0
	// ClientTypeTUI is the only client type that receives broadcasts.
	ClientTypeTUI ClientType = 1
)
|
|
|
|
// Client represents a connected WebSocket client
|
|
type Client struct {
|
|
conn *websocket.Conn
|
|
User string
|
|
RemoteAddr string
|
|
Type ClientType
|
|
}
|
|
|
|
// Handler provides WebSocket handling
|
|
type Handler struct {
|
|
taskQueue queue.Backend
|
|
datasetsHandler *datasets.Handler
|
|
logger *logging.Logger
|
|
expManager *experiment.Manager
|
|
clients map[*Client]bool
|
|
db *storage.DB
|
|
jupyterServiceMgr *jupyter.ServiceManager
|
|
securityCfg *config.SecurityConfig
|
|
auditLogger *audit.Logger
|
|
authConfig *auth.Config
|
|
jobsHandler *jobs.Handler
|
|
jupyterHandler *jupyterj.Handler
|
|
upgrader websocket.Upgrader
|
|
dataDir string
|
|
clientsMu sync.RWMutex
|
|
}
|
|
|
|
// NewHandler creates a new WebSocket handler
|
|
func NewHandler(
|
|
authConfig *auth.Config,
|
|
logger *logging.Logger,
|
|
expManager *experiment.Manager,
|
|
dataDir string,
|
|
taskQueue queue.Backend,
|
|
db *storage.DB,
|
|
jupyterServiceMgr *jupyter.ServiceManager,
|
|
securityCfg *config.SecurityConfig,
|
|
auditLogger *audit.Logger,
|
|
jobsHandler *jobs.Handler,
|
|
jupyterHandler *jupyterj.Handler,
|
|
datasetsHandler *datasets.Handler,
|
|
) *Handler {
|
|
upgrader := createUpgrader(securityCfg)
|
|
|
|
return &Handler{
|
|
authConfig: authConfig,
|
|
logger: logger,
|
|
expManager: expManager,
|
|
dataDir: dataDir,
|
|
taskQueue: taskQueue,
|
|
db: db,
|
|
jupyterServiceMgr: jupyterServiceMgr,
|
|
securityCfg: securityCfg,
|
|
auditLogger: auditLogger,
|
|
upgrader: upgrader,
|
|
jobsHandler: jobsHandler,
|
|
jupyterHandler: jupyterHandler,
|
|
datasetsHandler: datasetsHandler,
|
|
clients: make(map[*Client]bool),
|
|
}
|
|
}
|
|
|
|
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
|
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
|
return websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // Allow same-origin requests
|
|
}
|
|
|
|
// Production mode: strict checking against allowed origins
|
|
if securityCfg != nil && securityCfg.ProductionMode {
|
|
return slices.Contains(securityCfg.AllowedOrigins, origin)
|
|
}
|
|
|
|
// Development mode: allow localhost and local network origins
|
|
parsedOrigin, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
host := parsedOrigin.Host
|
|
if strings.HasPrefix(host, "localhost:") ||
|
|
strings.HasPrefix(host, "127.0.0.1:") ||
|
|
strings.HasPrefix(host, "192.168.") ||
|
|
strings.HasPrefix(host, "10.") ||
|
|
strings.HasPrefix(host, "[::1]:") {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
},
|
|
EnableCompression: true,
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler for WebSocket upgrade
|
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := h.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
h.logger.Error("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
h.logger.Warn("error closing websocket connection", "error", err)
|
|
}
|
|
}()
|
|
|
|
h.handleConnection(conn)
|
|
}
|
|
|
|
// handleConnection handles an established WebSocket connection
|
|
func (h *Handler) handleConnection(conn *websocket.Conn) {
|
|
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
|
|
|
|
// Register client
|
|
client := &Client{
|
|
conn: conn,
|
|
Type: ClientTypeTUI, // Assume TUI for now, could detect from handshake
|
|
User: "tui-user",
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
}
|
|
|
|
h.clientsMu.Lock()
|
|
h.clients[client] = true
|
|
h.clientsMu.Unlock()
|
|
|
|
defer func() {
|
|
h.clientsMu.Lock()
|
|
delete(h.clients, client)
|
|
h.clientsMu.Unlock()
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
for {
|
|
messageType, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(
|
|
err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure,
|
|
) {
|
|
h.logger.Error("websocket read error", "error", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
if messageType != websocket.BinaryMessage {
|
|
h.logger.Warn("received non-binary message, ignoring")
|
|
continue
|
|
}
|
|
|
|
if err := h.handleMessage(conn, payload); err != nil {
|
|
h.logger.Error("message handling error", "error", err)
|
|
// Don't break, continue handling messages
|
|
}
|
|
}
|
|
|
|
h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr())
|
|
}
|
|
|
|
// handleMessage dispatches WebSocket messages to appropriate handlers
|
|
func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
|
if len(payload) < 17 { // At least opcode + api_key_hash
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
opcode := payload[0] // First byte is opcode, followed by 16-byte API key hash
|
|
|
|
switch opcode {
|
|
case OpcodeAnnotateRun:
|
|
return h.handleAnnotateRun(conn, payload)
|
|
case OpcodeSetRunNarrative:
|
|
return h.handleSetRunNarrative(conn, payload)
|
|
case OpcodeStartJupyter:
|
|
return h.handleStartJupyter(conn, payload)
|
|
case OpcodeStopJupyter:
|
|
return h.handleStopJupyter(conn, payload)
|
|
case OpcodeListJupyter:
|
|
return h.handleListJupyter(conn, payload)
|
|
case OpcodeQueueJob:
|
|
return h.handleQueueJob(conn, payload)
|
|
case OpcodeQueueJobWithSnapshot:
|
|
return h.handleQueueJobWithSnapshot(conn, payload)
|
|
case OpcodeStatusRequest:
|
|
return h.handleStatusRequest(conn, payload)
|
|
case OpcodeCancelJob:
|
|
return h.handleCancelJob(conn, payload)
|
|
case OpcodePrune:
|
|
return h.handlePrune(conn, payload)
|
|
case OpcodeValidateRequest:
|
|
return h.handleValidateRequest(conn, payload)
|
|
case OpcodeLogMetric:
|
|
return h.handleLogMetric(conn, payload)
|
|
case OpcodeGetExperiment:
|
|
return h.handleGetExperiment(conn, payload)
|
|
case OpcodeDatasetList:
|
|
return h.handleDatasetList(conn, payload)
|
|
case OpcodeDatasetRegister:
|
|
return h.handleDatasetRegister(conn, payload)
|
|
case OpcodeDatasetInfo:
|
|
return h.handleDatasetInfo(conn, payload)
|
|
case OpcodeDatasetSearch:
|
|
return h.handleDatasetSearch(conn, payload)
|
|
case OpcodeCompareRuns:
|
|
return h.handleCompareRuns(conn, payload)
|
|
case OpcodeFindRuns:
|
|
return h.handleFindRuns(conn, payload)
|
|
case OpcodeExportRun:
|
|
return h.handleExportRun(conn, payload)
|
|
case OpcodeSetRunOutcome:
|
|
return h.handleSetRunOutcome(conn, payload)
|
|
default:
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
|
}
|
|
}
|
|
|
|
// sendPacket builds and sends a binary packet with type and sections
|
|
func (h *Handler) sendPacket(conn *websocket.Conn, pktType byte, sections ...[]byte) error {
|
|
var buf []byte
|
|
buf = append(buf, pktType, 0, 0, 0, 0, 0, 0, 0, 0) // Type + timestamp placeholder
|
|
for _, section := range sections {
|
|
var tmp [10]byte
|
|
n := binary.PutUvarint(tmp[:], uint64(len(section)))
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, section...)
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, buf)
|
|
}
|
|
|
|
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
|
return h.sendPacket(conn, PacketTypeError, []byte{code}, []byte(message), []byte(details))
|
|
}
|
|
|
|
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error {
|
|
payload, _ := json.Marshal(data)
|
|
return h.sendPacket(conn, PacketTypeSuccess, payload)
|
|
}
|
|
|
|
func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error {
|
|
return h.sendPacket(conn, PacketTypeData, []byte(dataType), payload)
|
|
}
|
|
|
|
// Handler stubs - delegate to sub-packages
|
|
|
|
func (h *Handler) withAuth(
|
|
conn *websocket.Conn, payload []byte, handler func(*auth.User) error,
|
|
) error {
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
return handler(user)
|
|
}
|
|
|
|
func (h *Handler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
|
|
if h.jobsHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jobs handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jobsHandler.HandleAnnotateRun(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
|
|
if h.jobsHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jobs handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jobsHandler.HandleSetRunNarrative(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
|
|
if h.jupyterHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jupyter handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jupyterHandler.HandleStartJupyter(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
|
|
if h.jupyterHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jupyter handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jupyterHandler.HandleStopJupyter(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
|
if h.jupyterHandler == nil {
|
|
return h.sendSuccessPacket(conn, map[string]any{"success": true, "services": []any{}, "count": 0})
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jupyterHandler.HandleListJupyter(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8]
|
|
if len(payload) < 16+1+8 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
|
|
offset := 16
|
|
nameLen := int(payload[offset])
|
|
offset++
|
|
if nameLen <= 0 || len(payload) < offset+nameLen+8 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
|
|
}
|
|
name := string(payload[offset : offset+nameLen])
|
|
offset += nameLen
|
|
|
|
value := binary.BigEndian.Uint64(payload[offset : offset+8])
|
|
|
|
h.logger.Info("metric logged", "name", name, "value", value, "user", user.Name)
|
|
|
|
// Persist to database if available
|
|
if h.db != nil {
|
|
if err := h.db.RecordMetric(context.Background(), name, float64(value), user.Name); err != nil {
|
|
h.logger.Warn("failed to persist metric", "error", err, "name", name)
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"message": "Metric logged",
|
|
"metric": name,
|
|
"value": value,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][commit_id_len:1][commit_id:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
|
|
}
|
|
|
|
// Check authentication and permissions
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
commitIDLen := int(payload[offset])
|
|
offset++
|
|
if commitIDLen <= 0 || len(payload) < offset+commitIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid commit ID length", "")
|
|
}
|
|
commitID := string(payload[offset : offset+commitIDLen])
|
|
|
|
// Check if experiment exists
|
|
if h.expManager == nil || !h.expManager.ExperimentExists(commitID) {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "experiment not found", commitID)
|
|
}
|
|
|
|
// Read experiment metadata
|
|
meta, err := h.expManager.ReadMetadata(commitID)
|
|
if err != nil {
|
|
h.logger.Warn("failed to read experiment metadata", "commit_id", commitID, "error", err)
|
|
meta = &experiment.Metadata{CommitID: commitID}
|
|
}
|
|
|
|
// Read manifest if available
|
|
manifest, _ := h.expManager.ReadManifest(commitID)
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"commit_id": commitID,
|
|
"job_name": meta.JobName,
|
|
"user": meta.User,
|
|
"timestamp": meta.Timestamp,
|
|
"files_count": len(manifest.Files),
|
|
"overall_sha": manifest.OverallSHA,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetList(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendSuccessPacket(conn, map[string]any{"success": true, "message": "Dataset registered"})
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetRegister(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendDataPacket(conn, "dataset_info", []byte("{}"))
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetInfo(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetSearch(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16]
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
|
|
// Return queue status as Data packet
|
|
queueLength := 0
|
|
if h.taskQueue != nil {
|
|
if depth, err := h.taskQueue.QueueDepth(); err == nil {
|
|
queueLength = int(depth)
|
|
}
|
|
}
|
|
|
|
status := map[string]any{
|
|
"queue_length": queueLength,
|
|
"status": "ok",
|
|
"authenticated": user != nil,
|
|
"authenticated_user": user.Name,
|
|
}
|
|
|
|
payloadBytes, _ := json.Marshal(status)
|
|
return h.sendDataPacket(conn, "status", payloadBytes)
|
|
}
|
|
|
|
// selectDependencyManifest returns the name of the first known dependency
// manifest found directly inside filesPath.
//
// Candidates are probed in a fixed priority order: requirements.txt,
// package.json, Cargo.toml, go.mod, pom.xml, build.gradle. Only the file name
// is returned, not the joined path. An error is returned when none of the
// candidates can be stat'ed.
func selectDependencyManifest(filesPath string) (string, error) {
	candidates := [...]string{
		"requirements.txt",
		"package.json",
		"Cargo.toml",
		"go.mod",
		"pom.xml",
		"build.gradle",
	}

	for i := range candidates {
		candidate := candidates[i]
		_, statErr := os.Stat(filepath.Join(filesPath, candidate))
		if statErr != nil {
			continue // missing or unreadable: try the next candidate
		}
		return candidate, nil
	}

	return "", fmt.Errorf("no dependency manifest found")
}
|
|
|
|
// Authenticate validates API key from payload
|
|
func (h *Handler) Authenticate(payload []byte) (*auth.User, error) {
|
|
if len(payload) < 16 {
|
|
return nil, errors.New("payload too short")
|
|
}
|
|
return &auth.User{
|
|
Name: "websocket-user",
|
|
Admin: false,
|
|
Roles: []string{"user"},
|
|
Permissions: map[string]bool{"jobs:read": true},
|
|
}, nil
|
|
}
|
|
|
|
// RequirePermission checks user permission
|
|
func (h *Handler) RequirePermission(user *auth.User, permission string) bool {
|
|
if user == nil {
|
|
return false
|
|
}
|
|
return user.Admin || user.Permissions[permission]
|
|
}
|
|
|
|
// handleCompareRuns compares two runs and returns differences
|
|
func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][run_a_len:1][run_a:var][run_b_len:1][run_b:var]
|
|
if len(payload) < 16+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "compare runs payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
runALen := int(payload[offset])
|
|
offset++
|
|
if runALen <= 0 || len(payload) < offset+runALen+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run A length", "")
|
|
}
|
|
runA := string(payload[offset : offset+runALen])
|
|
offset += runALen
|
|
|
|
runBLen := int(payload[offset])
|
|
offset++
|
|
if runBLen <= 0 || len(payload) < offset+runBLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run B length", "")
|
|
}
|
|
runB := string(payload[offset : offset+runBLen])
|
|
|
|
// Fetch both experiments
|
|
metaA, errA := h.expManager.ReadMetadata(runA)
|
|
metaB, errB := h.expManager.ReadMetadata(runB)
|
|
|
|
// Build comparison result
|
|
result := map[string]any{
|
|
"run_a": runA,
|
|
"run_b": runB,
|
|
"success": true,
|
|
}
|
|
|
|
// Add metadata if available
|
|
if errA == nil && errB == nil {
|
|
result["job_name_match"] = metaA.JobName == metaB.JobName
|
|
result["user_match"] = metaA.User == metaB.User
|
|
result["timestamp_diff"] = metaB.Timestamp - metaA.Timestamp
|
|
}
|
|
|
|
// Read manifests for comparison
|
|
manifestA, _ := h.expManager.ReadManifest(runA)
|
|
manifestB, _ := h.expManager.ReadManifest(runB)
|
|
|
|
if manifestA != nil && manifestB != nil {
|
|
result["overall_sha_match"] = manifestA.OverallSHA == manifestB.OverallSHA
|
|
result["files_count_a"] = len(manifestA.Files)
|
|
result["files_count_b"] = len(manifestB.Files)
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, result)
|
|
}
|
|
|
|
// handleFindRuns searches for runs based on criteria
|
|
func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][query_len:2][query:var]
|
|
if len(payload) < 16+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "find runs payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
queryLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
|
offset += 2
|
|
if queryLen > 0 && len(payload) >= offset+int(queryLen) {
|
|
// Parse query JSON
|
|
queryData := payload[offset : offset+int(queryLen)]
|
|
var query map[string]any
|
|
if err := json.Unmarshal(queryData, &query); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query JSON", err.Error())
|
|
}
|
|
|
|
h.logger.Info("search query", "query", query, "user", user.Name)
|
|
}
|
|
|
|
// For now, return placeholder results
|
|
results := []map[string]any{
|
|
{"id": "run_abc", "job_name": "train", "outcome": "validates"},
|
|
{"id": "run_def", "job_name": "eval", "outcome": "partial"},
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"results": results,
|
|
"count": len(results),
|
|
})
|
|
}
|
|
|
|
// handleExportRun exports a run with optional anonymization
|
|
func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var][options_len:2][options:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "export run payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
runIDLen := int(payload[offset])
|
|
offset++
|
|
if runIDLen <= 0 || len(payload) < offset+runIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
|
}
|
|
runID := string(payload[offset : offset+runIDLen])
|
|
offset += runIDLen
|
|
|
|
// Parse options if present
|
|
var options map[string]any
|
|
if len(payload) >= offset+2 {
|
|
optsLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
|
offset += 2
|
|
if optsLen > 0 && len(payload) >= offset+int(optsLen) {
|
|
err := json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid options JSON", err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if experiment exists
|
|
if !h.expManager.ExperimentExists(runID) {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID)
|
|
}
|
|
|
|
anonymize := false
|
|
if options != nil {
|
|
if v, ok := options["anonymize"].(bool); ok {
|
|
anonymize = v
|
|
}
|
|
}
|
|
|
|
h.logger.Info("exporting run", "run_id", runID, "anonymize", anonymize, "user", user.Name)
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"run_id": runID,
|
|
"message": "Export request received",
|
|
"anonymize": anonymize,
|
|
})
|
|
}
|
|
|
|
// handleSetRunOutcome sets the outcome for a run
|
|
func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var][outcome_data_len:2][outcome_data:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run outcome payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
|
)
|
|
}
|
|
if !h.RequirePermission(user, PermJobsUpdate) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
runIDLen := int(payload[offset])
|
|
offset++
|
|
if runIDLen <= 0 || len(payload) < offset+runIDLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
|
}
|
|
runID := string(payload[offset : offset+runIDLen])
|
|
offset += runIDLen
|
|
|
|
// Parse outcome data
|
|
outcomeLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
|
offset += 2
|
|
if outcomeLen == 0 || len(payload) < offset+int(outcomeLen) {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome data", "")
|
|
}
|
|
|
|
var outcomeData map[string]any
|
|
if err := json.Unmarshal(payload[offset:offset+int(outcomeLen)], &outcomeData); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome JSON", err.Error())
|
|
}
|
|
|
|
// Validate outcome status
|
|
validOutcomes := map[string]bool{
|
|
"validates": true, "refutes": true, "inconclusive": true, "partial": true,
|
|
}
|
|
outcome, ok := outcomeData["outcome"].(string)
|
|
if !ok || !validOutcomes[outcome] {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeInvalidRequest,
|
|
"invalid outcome status",
|
|
"must be: validates, refutes, inconclusive, or partial",
|
|
)
|
|
}
|
|
|
|
h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"run_id": runID,
|
|
"outcome": outcome,
|
|
"message": "Outcome updated",
|
|
})
|
|
}
|
|
|
|
// BroadcastJobUpdate sends job status update to all connected TUI clients
|
|
func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) {
|
|
h.clientsMu.RLock()
|
|
defer h.clientsMu.RUnlock()
|
|
|
|
msg := map[string]any{
|
|
"type": "job_update",
|
|
"job_name": jobName,
|
|
"status": status,
|
|
"progress": progress,
|
|
"time": time.Now().Unix(),
|
|
}
|
|
|
|
payload, _ := json.Marshal(msg)
|
|
|
|
for client := range h.clients {
|
|
if client.Type == ClientTypeTUI {
|
|
if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
|
h.logger.Warn("failed to broadcast to client", "error", err, "client", client.RemoteAddr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastGPUUpdate sends GPU status update to all connected TUI clients
|
|
func (h *Handler) BroadcastGPUUpdate(deviceID, utilization int, memoryUsed, memoryTotal int64) {
|
|
h.clientsMu.RLock()
|
|
defer h.clientsMu.RUnlock()
|
|
|
|
msg := map[string]any{
|
|
"type": "gpu_update",
|
|
"device_id": deviceID,
|
|
"utilization": utilization,
|
|
"memory_used": memoryUsed,
|
|
"memory_total": memoryTotal,
|
|
"time": time.Now().Unix(),
|
|
}
|
|
|
|
payload, _ := json.Marshal(msg)
|
|
|
|
for client := range h.clients {
|
|
if client.Type == ClientTypeTUI {
|
|
if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
|
h.logger.Warn("failed to broadcast GPU update", "error", err, "client", client.RemoteAddr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetConnectedClientCount returns the number of connected TUI clients
|
|
func (h *Handler) GetConnectedClientCount() int {
|
|
h.clientsMu.RLock()
|
|
defer h.clientsMu.RUnlock()
|
|
|
|
count := 0
|
|
for client := range h.clients {
|
|
if client.Type == ClientTypeTUI {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|