Phase 7 of the monorepo maintainability plan:

New files created:
- model/jobs.go - Job type, JobStatus constants, list.Item interface
- model/messages.go - tea.Msg types (JobsLoadedMsg, StatusMsg, TickMsg, etc.)
- model/styles.go - NewJobListDelegate(), JobListTitleStyle(), SpinnerStyle()
- model/keys.go - KeyMap struct, DefaultKeys() function

Modified files:
- model/state.go - reduced from 226 to ~130 lines
  - Removed: Job, JobStatus, KeyMap, Keys, inline styles
  - Kept: State struct, domain re-exports, ViewMode, DatasetInfo, InitialState()
- controller/commands.go - use the model. prefix for message types
- controller/controller.go - use the model. prefix for message types
- controller/settings.go - use model.SettingsContentMsg

Deleted files:
- controller/keys.go (moved to model/keys.go since State references KeyMap)

Result:
- No file >150 lines in the model/ package
- Single concern per file: state, jobs, messages, styles, keys
- All 41 test packages pass
879 lines
26 KiB
Go
879 lines
26 KiB
Go
// Package ws provides WebSocket handling for the API
|
|
package ws
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/audit"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
|
)
|
|
|
|
// Response packet types. The value is the first byte of a response packet.
// They are duplicated from the api package to avoid an import cycle, so the
// numeric values must stay in sync with that package.
const (
	PacketTypeSuccess  = 0x00 // JSON result payload
	PacketTypeError    = 0x01 // error code, message and details
	PacketTypeProgress = 0x02
	PacketTypeStatus   = 0x03
	PacketTypeData     = 0x04 // typed data payload
	PacketTypeLog      = 0x05
)
|
|
|
|
// Opcodes for the binary WebSocket protocol. The opcode is the first byte of
// every request payload. Values are part of the wire format and must not be
// renumbered; they are grouped here by feature area rather than by value.
const (
	// Job queueing and lifecycle.
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B
	OpcodeValidateRequest      = 0x16

	// Run annotations.
	OpcodeAnnotateRun     = 0x1C
	OpcodeSetRunNarrative = 0x1D

	// Datasets.
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Metrics and experiments.
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B

	// Jupyter services.
	OpcodeStartJupyter        = 0x0D
	OpcodeStopJupyter         = 0x0E
	OpcodeListJupyter         = 0x0F
	OpcodeRemoveJupyter       = 0x18
	OpcodeRestoreJupyter      = 0x19
	OpcodeListJupyterPackages = 0x1E

	// Logs.
	OpcodeGetLogs    = 0x20
	OpcodeStreamLogs = 0x21
)
|
|
|
|
// Error codes carried in the ErrorCode byte of an error packet. The values
// are part of the wire format; the high nibble groups related codes.
const (
	// 0x0_: request-level errors.
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	// 0x1_: server and infrastructure errors.
	ErrorCodeServerOverloaded = 0x10
	ErrorCodeDatabaseError    = 0x11
	ErrorCodeNetworkError     = 0x12
	ErrorCodeStorageError     = 0x13
	ErrorCodeTimeout          = 0x14

	// 0x2_: job errors.
	ErrorCodeJobNotFound        = 0x20
	ErrorCodeJobAlreadyRunning  = 0x21
	ErrorCodeJobFailedToStart   = 0x22
	ErrorCodeJobExecutionFailed = 0x23
	ErrorCodeJobCancelled       = 0x24

	// 0x3_: resource and configuration errors.
	ErrorCodeOutOfMemory          = 0x30
	ErrorCodeDiskFull             = 0x31
	ErrorCodeInvalidConfiguration = 0x32
	ErrorCodeServiceUnavailable   = 0x33
)
|
|
|
|
// Permission names, in "<resource>:<action>" form, used as keys when checking
// what a user is allowed to do.
const (
	// Jobs.
	PermJobsCreate = "jobs:create"
	PermJobsRead   = "jobs:read"
	PermJobsUpdate = "jobs:update"

	// Datasets.
	PermDatasetsRead   = "datasets:read"
	PermDatasetsCreate = "datasets:create"

	// Jupyter services.
	PermJupyterManage = "jupyter:manage"
	PermJupyterRead   = "jupyter:read"
)
|
|
|
|
// Handler provides WebSocket handling
|
|
type Handler struct {
|
|
authConfig *auth.Config
|
|
logger *logging.Logger
|
|
expManager *experiment.Manager
|
|
dataDir string
|
|
taskQueue queue.Backend
|
|
db *storage.DB
|
|
jupyterServiceMgr *jupyter.ServiceManager
|
|
securityCfg *config.SecurityConfig
|
|
auditLogger *audit.Logger
|
|
upgrader websocket.Upgrader
|
|
}
|
|
|
|
// NewHandler creates a new WebSocket handler
|
|
func NewHandler(
|
|
authConfig *auth.Config,
|
|
logger *logging.Logger,
|
|
expManager *experiment.Manager,
|
|
dataDir string,
|
|
taskQueue queue.Backend,
|
|
db *storage.DB,
|
|
jupyterServiceMgr *jupyter.ServiceManager,
|
|
securityCfg *config.SecurityConfig,
|
|
auditLogger *audit.Logger,
|
|
) *Handler {
|
|
upgrader := createUpgrader(securityCfg)
|
|
|
|
return &Handler{
|
|
authConfig: authConfig,
|
|
logger: logger,
|
|
expManager: expManager,
|
|
dataDir: dataDir,
|
|
taskQueue: taskQueue,
|
|
db: db,
|
|
jupyterServiceMgr: jupyterServiceMgr,
|
|
securityCfg: securityCfg,
|
|
auditLogger: auditLogger,
|
|
upgrader: upgrader,
|
|
}
|
|
}
|
|
|
|
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
|
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
|
return websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // Allow same-origin requests
|
|
}
|
|
|
|
// Production mode: strict checking against allowed origins
|
|
if securityCfg != nil && securityCfg.ProductionMode {
|
|
for _, allowed := range securityCfg.AllowedOrigins {
|
|
if origin == allowed {
|
|
return true
|
|
}
|
|
}
|
|
return false // Reject if not in allowed list
|
|
}
|
|
|
|
// Development mode: allow localhost and local network origins
|
|
parsedOrigin, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
host := parsedOrigin.Host
|
|
if strings.HasPrefix(host, "localhost:") ||
|
|
strings.HasPrefix(host, "127.0.0.1:") ||
|
|
strings.HasPrefix(host, "192.168.") ||
|
|
strings.HasPrefix(host, "10.") ||
|
|
strings.HasPrefix(host, "[::1]:") {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
},
|
|
EnableCompression: true,
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler for WebSocket upgrade
|
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := h.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
h.logger.Error("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
h.handleConnection(conn)
|
|
}
|
|
|
|
// handleConnection handles an established WebSocket connection
|
|
func (h *Handler) handleConnection(conn *websocket.Conn) {
|
|
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
|
|
|
|
for {
|
|
messageType, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
h.logger.Error("websocket read error", "error", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
if messageType != websocket.BinaryMessage {
|
|
h.logger.Warn("received non-binary message, ignoring")
|
|
continue
|
|
}
|
|
|
|
if err := h.handleMessage(conn, payload); err != nil {
|
|
h.logger.Error("message handling error", "error", err)
|
|
// Don't break, continue handling messages
|
|
}
|
|
}
|
|
|
|
h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr())
|
|
}
|
|
|
|
// handleMessage dispatches WebSocket messages to appropriate handlers
|
|
func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
|
if len(payload) < 17 { // At least opcode + api_key_hash
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
opcode := payload[0] // First byte is opcode, followed by 16-byte API key hash
|
|
|
|
switch opcode {
|
|
case OpcodeAnnotateRun:
|
|
return h.handleAnnotateRun(conn, payload)
|
|
case OpcodeSetRunNarrative:
|
|
return h.handleSetRunNarrative(conn, payload)
|
|
case OpcodeStartJupyter:
|
|
return h.handleStartJupyter(conn, payload)
|
|
case OpcodeStopJupyter:
|
|
return h.handleStopJupyter(conn, payload)
|
|
case OpcodeListJupyter:
|
|
return h.handleListJupyter(conn, payload)
|
|
case OpcodeQueueJob:
|
|
return h.handleQueueJob(conn, payload)
|
|
case OpcodeQueueJobWithSnapshot:
|
|
return h.handleQueueJobWithSnapshot(conn, payload)
|
|
case OpcodeStatusRequest:
|
|
return h.handleStatusRequest(conn, payload)
|
|
case OpcodeCancelJob:
|
|
return h.handleCancelJob(conn, payload)
|
|
case OpcodePrune:
|
|
return h.handlePrune(conn, payload)
|
|
case OpcodeValidateRequest:
|
|
return h.handleValidateRequest(conn, payload)
|
|
case OpcodeLogMetric:
|
|
return h.handleLogMetric(conn, payload)
|
|
case OpcodeGetExperiment:
|
|
return h.handleGetExperiment(conn, payload)
|
|
case OpcodeDatasetList:
|
|
return h.handleDatasetList(conn, payload)
|
|
case OpcodeDatasetRegister:
|
|
return h.handleDatasetRegister(conn, payload)
|
|
case OpcodeDatasetInfo:
|
|
return h.handleDatasetInfo(conn, payload)
|
|
case OpcodeDatasetSearch:
|
|
return h.handleDatasetSearch(conn, payload)
|
|
default:
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
|
}
|
|
}
|
|
|
|
// sendErrorPacket sends an error response packet
|
|
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
|
// Binary protocol: [PacketType:1][Timestamp:8][ErrorCode:1][ErrorMessageLen:varint][ErrorMessage][ErrorDetailsLen:varint][ErrorDetails]
|
|
var buf []byte
|
|
buf = append(buf, PacketTypeError)
|
|
|
|
// Timestamp (8 bytes, big-endian) - simplified, using 0 for now
|
|
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
|
|
|
// Error code
|
|
buf = append(buf, code)
|
|
|
|
// Error message with length prefix
|
|
msgLen := uint64(len(message))
|
|
var tmp [10]byte
|
|
n := binary.PutUvarint(tmp[:], msgLen)
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, message...)
|
|
|
|
// Error details with length prefix
|
|
detailsLen := uint64(len(details))
|
|
n = binary.PutUvarint(tmp[:], detailsLen)
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, details...)
|
|
|
|
return conn.WriteMessage(websocket.BinaryMessage, buf)
|
|
}
|
|
|
|
// sendSuccessPacket sends a success response packet with JSON payload
|
|
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
|
payload, err := json.Marshal(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Binary protocol: [PacketType:1][Timestamp:8][PayloadLen:varint][Payload]
|
|
var buf []byte
|
|
buf = append(buf, PacketTypeSuccess)
|
|
|
|
// Timestamp (8 bytes, big-endian)
|
|
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
|
|
|
// Payload with length prefix
|
|
payloadLen := uint64(len(payload))
|
|
var tmp [10]byte
|
|
n := binary.PutUvarint(tmp[:], payloadLen)
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, payload...)
|
|
|
|
return conn.WriteMessage(websocket.BinaryMessage, buf)
|
|
}
|
|
|
|
// sendDataPacket sends a data response packet
|
|
func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error {
|
|
// Binary protocol: [PacketType:1][Timestamp:8][DataTypeLen:varint][DataType][PayloadLen:varint][Payload]
|
|
var buf []byte
|
|
buf = append(buf, PacketTypeData)
|
|
|
|
// Timestamp (8 bytes, big-endian)
|
|
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
|
|
|
// DataType with length prefix
|
|
typeLen := uint64(len(dataType))
|
|
var tmp [10]byte
|
|
n := binary.PutUvarint(tmp[:], typeLen)
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, dataType...)
|
|
|
|
// Payload with length prefix
|
|
payloadLen := uint64(len(payload))
|
|
n = binary.PutUvarint(tmp[:], payloadLen)
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, payload...)
|
|
|
|
return conn.WriteMessage(websocket.BinaryMessage, buf)
|
|
}
|
|
|
|
// Handler stubs - these would delegate to sub-packages in full implementation
|
|
|
|
func (h *Handler) handleAnnotateRun(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to jobs package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Annotate run handled",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to jobs package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Set run narrative handled",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStartJupyter(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to jupyter package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Start jupyter handled",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStopJupyter(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to jupyter package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Stop jupyter handled",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleListJupyter(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to jupyter package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "List jupyter handled",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
|
|
// mode=0: commit_id validation [commit_id_len:1][commit_id:var]
|
|
// mode=1: task_id validation [task_id_len:1][task_id:var]
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
mode := payload[17]
|
|
|
|
if mode == 0 {
|
|
// Commit ID validation (basic)
|
|
if len(payload) < 20 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "")
|
|
}
|
|
commitIDLen := int(payload[18])
|
|
if len(payload) < 19+commitIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "")
|
|
}
|
|
commitIDBytes := payload[19 : 19+commitIDLen]
|
|
commitIDHex := fmt.Sprintf("%x", commitIDBytes)
|
|
|
|
report := map[string]interface{}{
|
|
"ok": true,
|
|
"commit_id": commitIDHex,
|
|
}
|
|
payloadBytes, _ := json.Marshal(report)
|
|
return h.sendDataPacket(conn, "validate", payloadBytes)
|
|
}
|
|
|
|
// Task ID validation (mode=1) - full validation with checks
|
|
if len(payload) < 20 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "")
|
|
}
|
|
|
|
taskIDLen := int(payload[18])
|
|
if len(payload) < 19+taskIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "")
|
|
}
|
|
taskID := string(payload[19 : 19+taskIDLen])
|
|
|
|
// Initialize validation report
|
|
checks := make(map[string]interface{})
|
|
ok := true
|
|
|
|
// Get task from queue
|
|
if h.taskQueue == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "")
|
|
}
|
|
|
|
task, err := h.taskQueue.GetTask(taskID)
|
|
if err != nil || task == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "")
|
|
}
|
|
|
|
// Run manifest validation - load manifest if it exists
|
|
rmCheck := map[string]interface{}{"ok": true}
|
|
rmCommitCheck := map[string]interface{}{"ok": true}
|
|
rmLocCheck := map[string]interface{}{"ok": true}
|
|
rmLifecycle := map[string]interface{}{"ok": true}
|
|
|
|
// Determine expected location based on task status
|
|
expectedLocation := "running"
|
|
if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" {
|
|
expectedLocation = "finished"
|
|
}
|
|
|
|
// Try to load run manifest from appropriate location
|
|
var rm *manifest.RunManifest
|
|
var rmLoadErr error
|
|
|
|
if h.expManager != nil {
|
|
// Try expected location first
|
|
jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName)
|
|
rm, rmLoadErr = manifest.LoadFromDir(jobDir)
|
|
|
|
// If not found and task is running, also check finished (wrong location test)
|
|
if rmLoadErr != nil && task.Status == "running" {
|
|
wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName)
|
|
rm, _ = manifest.LoadFromDir(wrongDir)
|
|
if rm != nil {
|
|
// Manifest exists but in wrong location
|
|
rmLocCheck["ok"] = false
|
|
rmLocCheck["expected"] = "running"
|
|
rmLocCheck["actual"] = "finished"
|
|
ok = false
|
|
}
|
|
}
|
|
}
|
|
|
|
if rm == nil {
|
|
// No run manifest found
|
|
if task.Status == "running" || task.Status == "completed" {
|
|
rmCheck["ok"] = false
|
|
ok = false
|
|
}
|
|
} else {
|
|
// Run manifest exists - validate it
|
|
|
|
// Check commit_id match
|
|
taskCommitID := task.Metadata["commit_id"]
|
|
if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID {
|
|
rmCommitCheck["ok"] = false
|
|
rmCommitCheck["expected"] = taskCommitID
|
|
ok = false
|
|
}
|
|
|
|
// Check lifecycle ordering (started_at < ended_at)
|
|
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) {
|
|
rmLifecycle["ok"] = false
|
|
ok = false
|
|
}
|
|
}
|
|
|
|
checks["run_manifest"] = rmCheck
|
|
checks["run_manifest_commit_id"] = rmCommitCheck
|
|
checks["run_manifest_location"] = rmLocCheck
|
|
checks["run_manifest_lifecycle"] = rmLifecycle
|
|
|
|
// Resources check
|
|
resCheck := map[string]interface{}{"ok": true}
|
|
if task.CPU < 0 {
|
|
resCheck["ok"] = false
|
|
ok = false
|
|
}
|
|
checks["resources"] = resCheck
|
|
|
|
// Snapshot check
|
|
snapCheck := map[string]interface{}{"ok": true}
|
|
if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
|
|
// Verify snapshot SHA
|
|
dataDir := h.dataDir
|
|
if dataDir == "" {
|
|
dataDir = filepath.Join(h.expManager.BasePath(), "data")
|
|
}
|
|
snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
actualSHA, _ := integrity.DirOverallSHA256Hex(snapPath)
|
|
expectedSHA := task.Metadata["snapshot_sha256"]
|
|
if actualSHA != expectedSHA {
|
|
snapCheck["ok"] = false
|
|
snapCheck["actual"] = actualSHA
|
|
ok = false
|
|
}
|
|
}
|
|
checks["snapshot"] = snapCheck
|
|
|
|
report := map[string]interface{}{
|
|
"ok": ok,
|
|
"checks": checks,
|
|
}
|
|
payloadBytes, _ := json.Marshal(report)
|
|
return h.sendDataPacket(conn, "validate", payloadBytes)
|
|
}
|
|
|
|
func (h *Handler) handleLogMetric(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to metrics package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Metric logged",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
|
|
// Check authentication and permissions
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
// Would delegate to experiment package
|
|
// For now, return error as expected by test
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "experiment not found", "")
|
|
}
|
|
|
|
func (h *Handler) handleDatasetList(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to dataset package
|
|
// Return empty list as expected by test
|
|
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
|
}
|
|
|
|
func (h *Handler) handleDatasetRegister(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to dataset package
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Dataset registered",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetInfo(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to dataset package
|
|
return h.sendDataPacket(conn, "dataset_info", []byte("{}"))
|
|
}
|
|
|
|
func (h *Handler) handleDatasetSearch(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to dataset package
|
|
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
|
}
|
|
|
|
func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var]
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
jobNameLen := int(payload[17])
|
|
if len(payload) < 18+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
|
}
|
|
jobName := string(payload[18 : 18+jobNameLen])
|
|
|
|
// Find and cancel the task
|
|
if h.taskQueue != nil {
|
|
task, err := h.taskQueue.GetTaskByName(jobName)
|
|
if err == nil && task != nil {
|
|
task.Status = "cancelled"
|
|
h.taskQueue.UpdateTask(task)
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Job cancelled",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handlePrune(conn *websocket.Conn, _payload []byte) error {
|
|
// Would delegate to experiment package for pruning
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Prune completed",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
|
// Optional: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var]
|
|
if len(payload) < 39 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
// Extract commit_id (20 bytes starting at position 17)
|
|
commitIDBytes := payload[17:37]
|
|
commitIDHex := hex.EncodeToString(commitIDBytes)
|
|
|
|
priority := payload[37]
|
|
jobNameLen := int(payload[38])
|
|
|
|
if len(payload) < 39+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
|
}
|
|
jobName := string(payload[39 : 39+jobNameLen])
|
|
|
|
// Parse optional resource fields if present
|
|
cpu := 0
|
|
memoryGB := 0
|
|
gpu := 0
|
|
gpuMemory := ""
|
|
|
|
pos := 39 + jobNameLen
|
|
if len(payload) > pos {
|
|
cpu = int(payload[pos])
|
|
pos++
|
|
if len(payload) > pos {
|
|
memoryGB = int(payload[pos])
|
|
pos++
|
|
if len(payload) > pos {
|
|
gpu = int(payload[pos])
|
|
pos++
|
|
if len(payload) > pos {
|
|
gpuMemLen := int(payload[pos])
|
|
pos++
|
|
if len(payload) >= pos+gpuMemLen {
|
|
gpuMemory = string(payload[pos : pos+gpuMemLen])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create task
|
|
task := &queue.Task{
|
|
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
|
JobName: jobName,
|
|
Status: "queued",
|
|
Priority: int64(priority),
|
|
CreatedAt: time.Now(),
|
|
UserID: "user",
|
|
CreatedBy: "user",
|
|
CPU: cpu,
|
|
MemoryGB: memoryGB,
|
|
GPU: gpu,
|
|
GPUMemory: gpuMemory,
|
|
Metadata: map[string]string{
|
|
"commit_id": commitIDHex,
|
|
},
|
|
}
|
|
|
|
// Auto-detect deps manifest and compute manifest SHA if experiment exists
|
|
if h.expManager != nil {
|
|
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
|
depsName, _ := selectDependencyManifest(filesPath)
|
|
if depsName != "" {
|
|
task.Metadata["deps_manifest_name"] = depsName
|
|
depsPath := filepath.Join(filesPath, depsName)
|
|
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
|
task.Metadata["deps_manifest_sha256"] = sha
|
|
}
|
|
}
|
|
|
|
// Get experiment manifest SHA
|
|
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
|
if data, err := os.ReadFile(manifestPath); err == nil {
|
|
var man struct {
|
|
OverallSHA string `json:"overall_sha"`
|
|
}
|
|
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
|
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add task to queue
|
|
if h.taskQueue != nil {
|
|
if err := h.taskQueue.AddTask(task); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"task_id": task.ID,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][snapshot_id_len:1][snapshot_id:var][snapshot_sha_len:1][snapshot_sha:var]
|
|
if len(payload) < 41 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
// Extract commit_id (20 bytes starting at position 17)
|
|
commitIDBytes := payload[17:37]
|
|
commitIDHex := hex.EncodeToString(commitIDBytes)
|
|
|
|
priority := payload[37]
|
|
jobNameLen := int(payload[38])
|
|
|
|
if len(payload) < 39+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
|
}
|
|
jobName := string(payload[39 : 39+jobNameLen])
|
|
|
|
// Parse snapshot_id
|
|
pos := 39 + jobNameLen
|
|
if len(payload) < pos+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length missing", "")
|
|
}
|
|
snapshotIDLen := int(payload[pos])
|
|
pos++
|
|
if len(payload) < pos+snapshotIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "")
|
|
}
|
|
snapshotID := string(payload[pos : pos+snapshotIDLen])
|
|
pos += snapshotIDLen
|
|
|
|
// Parse snapshot_sha
|
|
if len(payload) < pos+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length missing", "")
|
|
}
|
|
snapshotSHALen := int(payload[pos])
|
|
pos++
|
|
if len(payload) < pos+snapshotSHALen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "")
|
|
}
|
|
snapshotSHA := string(payload[pos : pos+snapshotSHALen])
|
|
|
|
// Create task
|
|
task := &queue.Task{
|
|
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
|
JobName: jobName,
|
|
Status: "queued",
|
|
Priority: int64(priority),
|
|
CreatedAt: time.Now(),
|
|
UserID: "user",
|
|
CreatedBy: "user",
|
|
SnapshotID: snapshotID,
|
|
Metadata: map[string]string{
|
|
"commit_id": commitIDHex,
|
|
"snapshot_sha256": snapshotSHA,
|
|
},
|
|
}
|
|
|
|
// Auto-detect deps manifest and compute manifest SHA if experiment exists
|
|
if h.expManager != nil {
|
|
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
|
depsName, _ := selectDependencyManifest(filesPath)
|
|
if depsName != "" {
|
|
task.Metadata["deps_manifest_name"] = depsName
|
|
depsPath := filepath.Join(filesPath, depsName)
|
|
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
|
task.Metadata["deps_manifest_sha256"] = sha
|
|
}
|
|
}
|
|
|
|
// Get experiment manifest SHA
|
|
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
|
if data, err := os.ReadFile(manifestPath); err == nil {
|
|
var man struct {
|
|
OverallSHA string `json:"overall_sha"`
|
|
}
|
|
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
|
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add task to queue
|
|
if h.taskQueue != nil {
|
|
if err := h.taskQueue.AddTask(task); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"task_id": task.ID,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStatusRequest(conn *websocket.Conn, _payload []byte) error {
|
|
// Return queue status as Data packet
|
|
status := map[string]interface{}{
|
|
"queue_length": 0,
|
|
"status": "ok",
|
|
}
|
|
|
|
if h.taskQueue != nil {
|
|
// Try to get queue length - this is a best-effort operation
|
|
// The queue backend may not support this directly
|
|
}
|
|
|
|
payloadBytes, _ := json.Marshal(status)
|
|
return h.sendDataPacket(conn, "status", payloadBytes)
|
|
}
|
|
|
|
// selectDependencyManifest returns the name of the first well-known
// dependency manifest that exists as a file directly inside filesPath.
//
// Candidates are tried in priority order: requirements.txt, package.json,
// Cargo.toml, go.mod, pom.xml, build.gradle. Entries that cannot be stat'ed
// and directories that merely share a candidate's name are skipped. If no
// candidate matches, it returns an empty name and a non-nil error.
func selectDependencyManifest(filesPath string) (string, error) {
	candidates := [...]string{
		"requirements.txt",
		"package.json",
		"Cargo.toml",
		"go.mod",
		"pom.xml",
		"build.gradle",
	}

	for _, name := range candidates {
		info, err := os.Stat(filepath.Join(filesPath, name))
		if err != nil || info.IsDir() {
			// Missing, unreadable, or a directory: not a usable manifest.
			continue
		}
		return name, nil
	}
	return "", errors.New("no dependency manifest found")
}
|
|
|
|
// Authenticate extracts and validates the API key from payload
|
|
func (h *Handler) Authenticate(payload []byte) (*auth.User, error) {
|
|
if len(payload) < 16 {
|
|
return nil, errors.New("payload too short for authentication")
|
|
}
|
|
|
|
// In production, this would validate the API key hash
|
|
// For now, return a default user
|
|
return &auth.User{
|
|
Name: "websocket-user",
|
|
Admin: false,
|
|
Roles: []string{"user"},
|
|
Permissions: map[string]bool{"jobs:read": true},
|
|
}, nil
|
|
}
|
|
|
|
// RequirePermission checks if a user has a required permission
|
|
func (h *Handler) RequirePermission(user *auth.User, permission string) bool {
|
|
if user == nil {
|
|
return false
|
|
}
|
|
if user.Admin {
|
|
return true
|
|
}
|
|
return user.Permissions[permission]
|
|
}
|