Add comprehensive research context tracking to jobs: - Narrative fields: hypothesis, context, intent, expected_outcome - Experiment groups and tags for organization - Run comparison (compare command) for diff analysis - Run search (find command) with criteria filtering - Run export (export command) for data portability - Outcome setting (outcome command) for experiment validation Update queue and requeue commands to support narrative fields. Add narrative validation to manifest validator. Add WebSocket handlers for compare, find, export, and outcome operations. Includes E2E tests for phase 2 features.
767 lines
24 KiB
Go
767 lines
24 KiB
Go
// Package ws provides WebSocket handling for the API
|
|
package ws
|
|
|
|
import (
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"

	"github.com/jfraeys/fetch_ml/internal/api/datasets"
	"github.com/jfraeys/fetch_ml/internal/api/jobs"
	jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
)
|
|
|
|
// Response packet types: the first byte of every frame written by sendPacket.
// These values are duplicated from the api package to avoid an import cycle,
// so they must be kept in sync with it.
const (
	PacketTypeSuccess  = iota // 0x00
	PacketTypeError           // 0x01
	PacketTypeProgress        // 0x02
	PacketTypeStatus          // 0x03
	PacketTypeData            // 0x04
	PacketTypeLog             // 0x05
)
|
|
|
|
// Opcodes for the binary WebSocket protocol. The opcode is the first byte of
// every request frame (see handleMessage). Values are part of the wire
// protocol and must never be renumbered; they are grouped here by feature,
// not by numeric value.
const (
	// Job queueing, status and lifecycle.
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B

	// Datasets.
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Metrics and experiments.
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B

	// Run annotation and narrative.
	OpcodeAnnotateRun     = 0x1C
	OpcodeSetRunNarrative = 0x1D

	// Jupyter services.
	OpcodeStartJupyter        = 0x0D
	OpcodeStopJupyter         = 0x0E
	OpcodeListJupyter         = 0x0F
	OpcodeRemoveJupyter       = 0x18
	OpcodeRestoreJupyter      = 0x19
	OpcodeListJupyterPackages = 0x1E

	// Request validation.
	OpcodeValidateRequest = 0x16

	// Logs.
	OpcodeGetLogs    = 0x20
	OpcodeStreamLogs = 0x21

	// Research run operations: compare, find, export and set outcome.
	OpcodeCompareRuns   = 0x30
	OpcodeFindRuns      = 0x31
	OpcodeExportRun     = 0x32
	OpcodeSetRunOutcome = 0x33
)
|
|
|
|
// Error codes carried as the first section of a PacketTypeError frame (see
// sendErrorPacket). They are grouped by their high nibble.
const (
	// 0x0_: request / client errors.
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	// 0x1_: server and infrastructure errors.
	ErrorCodeServerOverloaded = 0x10
	ErrorCodeDatabaseError    = 0x11
	ErrorCodeNetworkError     = 0x12
	ErrorCodeStorageError     = 0x13
	ErrorCodeTimeout          = 0x14

	// 0x2_: job errors.
	ErrorCodeJobNotFound        = 0x20
	ErrorCodeJobAlreadyRunning  = 0x21
	ErrorCodeJobFailedToStart   = 0x22
	ErrorCodeJobExecutionFailed = 0x23
	ErrorCodeJobCancelled       = 0x24

	// 0x3_: resource, configuration and service errors.
	ErrorCodeOutOfMemory          = 0x30
	ErrorCodeDiskFull             = 0x31
	ErrorCodeInvalidConfiguration = 0x32
	ErrorCodeServiceUnavailable   = 0x33
)
|
|
|
|
// Permission names, as passed to RequirePermission, grouped by resource.
const (
	// Jobs.
	PermJobsCreate = "jobs:create"
	PermJobsRead   = "jobs:read"
	PermJobsUpdate = "jobs:update"

	// Datasets.
	PermDatasetsRead   = "datasets:read"
	PermDatasetsCreate = "datasets:create"

	// Jupyter.
	PermJupyterManage = "jupyter:manage"
	PermJupyterRead   = "jupyter:read"
)
|
|
|
|
// Handler provides WebSocket handling
|
|
type Handler struct {
|
|
authConfig *auth.Config
|
|
logger *logging.Logger
|
|
expManager *experiment.Manager
|
|
dataDir string
|
|
taskQueue queue.Backend
|
|
db *storage.DB
|
|
jupyterServiceMgr *jupyter.ServiceManager
|
|
securityCfg *config.SecurityConfig
|
|
auditLogger *audit.Logger
|
|
upgrader websocket.Upgrader
|
|
jobsHandler *jobs.Handler
|
|
jupyterHandler *jupyterj.Handler
|
|
datasetsHandler *datasets.Handler
|
|
}
|
|
|
|
// NewHandler creates a new WebSocket handler
|
|
func NewHandler(
|
|
authConfig *auth.Config,
|
|
logger *logging.Logger,
|
|
expManager *experiment.Manager,
|
|
dataDir string,
|
|
taskQueue queue.Backend,
|
|
db *storage.DB,
|
|
jupyterServiceMgr *jupyter.ServiceManager,
|
|
securityCfg *config.SecurityConfig,
|
|
auditLogger *audit.Logger,
|
|
jobsHandler *jobs.Handler,
|
|
jupyterHandler *jupyterj.Handler,
|
|
datasetsHandler *datasets.Handler,
|
|
) *Handler {
|
|
upgrader := createUpgrader(securityCfg)
|
|
|
|
return &Handler{
|
|
authConfig: authConfig,
|
|
logger: logger,
|
|
expManager: expManager,
|
|
dataDir: dataDir,
|
|
taskQueue: taskQueue,
|
|
db: db,
|
|
jupyterServiceMgr: jupyterServiceMgr,
|
|
securityCfg: securityCfg,
|
|
auditLogger: auditLogger,
|
|
upgrader: upgrader,
|
|
jobsHandler: jobsHandler,
|
|
jupyterHandler: jupyterHandler,
|
|
datasetsHandler: datasetsHandler,
|
|
}
|
|
}
|
|
|
|
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
|
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
|
return websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // Allow same-origin requests
|
|
}
|
|
|
|
// Production mode: strict checking against allowed origins
|
|
if securityCfg != nil && securityCfg.ProductionMode {
|
|
for _, allowed := range securityCfg.AllowedOrigins {
|
|
if origin == allowed {
|
|
return true
|
|
}
|
|
}
|
|
return false // Reject if not in allowed list
|
|
}
|
|
|
|
// Development mode: allow localhost and local network origins
|
|
parsedOrigin, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
host := parsedOrigin.Host
|
|
if strings.HasPrefix(host, "localhost:") ||
|
|
strings.HasPrefix(host, "127.0.0.1:") ||
|
|
strings.HasPrefix(host, "192.168.") ||
|
|
strings.HasPrefix(host, "10.") ||
|
|
strings.HasPrefix(host, "[::1]:") {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
},
|
|
EnableCompression: true,
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler for WebSocket upgrade
|
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := h.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
h.logger.Error("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
h.handleConnection(conn)
|
|
}
|
|
|
|
// handleConnection handles an established WebSocket connection
|
|
func (h *Handler) handleConnection(conn *websocket.Conn) {
|
|
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
|
|
|
|
for {
|
|
messageType, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
h.logger.Error("websocket read error", "error", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
if messageType != websocket.BinaryMessage {
|
|
h.logger.Warn("received non-binary message, ignoring")
|
|
continue
|
|
}
|
|
|
|
if err := h.handleMessage(conn, payload); err != nil {
|
|
h.logger.Error("message handling error", "error", err)
|
|
// Don't break, continue handling messages
|
|
}
|
|
}
|
|
|
|
h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr())
|
|
}
|
|
|
|
// handleMessage dispatches WebSocket messages to appropriate handlers
|
|
func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
|
if len(payload) < 17 { // At least opcode + api_key_hash
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
opcode := payload[0] // First byte is opcode, followed by 16-byte API key hash
|
|
|
|
switch opcode {
|
|
case OpcodeAnnotateRun:
|
|
return h.handleAnnotateRun(conn, payload)
|
|
case OpcodeSetRunNarrative:
|
|
return h.handleSetRunNarrative(conn, payload)
|
|
case OpcodeStartJupyter:
|
|
return h.handleStartJupyter(conn, payload)
|
|
case OpcodeStopJupyter:
|
|
return h.handleStopJupyter(conn, payload)
|
|
case OpcodeListJupyter:
|
|
return h.handleListJupyter(conn, payload)
|
|
case OpcodeQueueJob:
|
|
return h.handleQueueJob(conn, payload)
|
|
case OpcodeQueueJobWithSnapshot:
|
|
return h.handleQueueJobWithSnapshot(conn, payload)
|
|
case OpcodeStatusRequest:
|
|
return h.handleStatusRequest(conn, payload)
|
|
case OpcodeCancelJob:
|
|
return h.handleCancelJob(conn, payload)
|
|
case OpcodePrune:
|
|
return h.handlePrune(conn, payload)
|
|
case OpcodeValidateRequest:
|
|
return h.handleValidateRequest(conn, payload)
|
|
case OpcodeLogMetric:
|
|
return h.handleLogMetric(conn, payload)
|
|
case OpcodeGetExperiment:
|
|
return h.handleGetExperiment(conn, payload)
|
|
case OpcodeDatasetList:
|
|
return h.handleDatasetList(conn, payload)
|
|
case OpcodeDatasetRegister:
|
|
return h.handleDatasetRegister(conn, payload)
|
|
case OpcodeDatasetInfo:
|
|
return h.handleDatasetInfo(conn, payload)
|
|
case OpcodeDatasetSearch:
|
|
return h.handleDatasetSearch(conn, payload)
|
|
case OpcodeCompareRuns:
|
|
return h.handleCompareRuns(conn, payload)
|
|
case OpcodeFindRuns:
|
|
return h.handleFindRuns(conn, payload)
|
|
case OpcodeExportRun:
|
|
return h.handleExportRun(conn, payload)
|
|
case OpcodeSetRunOutcome:
|
|
return h.handleSetRunOutcome(conn, payload)
|
|
default:
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
|
}
|
|
}
|
|
|
|
// sendPacket builds and sends a binary packet with type and sections
|
|
func (h *Handler) sendPacket(conn *websocket.Conn, pktType byte, sections ...[]byte) error {
|
|
var buf []byte
|
|
buf = append(buf, pktType, 0, 0, 0, 0, 0, 0, 0, 0) // Type + timestamp placeholder
|
|
for _, section := range sections {
|
|
var tmp [10]byte
|
|
n := binary.PutUvarint(tmp[:], uint64(len(section)))
|
|
buf = append(buf, tmp[:n]...)
|
|
buf = append(buf, section...)
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, buf)
|
|
}
|
|
|
|
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
|
return h.sendPacket(conn, PacketTypeError, []byte{code}, []byte(message), []byte(details))
|
|
}
|
|
|
|
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error {
|
|
payload, _ := json.Marshal(data)
|
|
return h.sendPacket(conn, PacketTypeSuccess, payload)
|
|
}
|
|
|
|
func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error {
|
|
return h.sendPacket(conn, PacketTypeData, []byte(dataType), payload)
|
|
}
|
|
|
|
// Handler stubs - delegate to sub-packages
|
|
|
|
func (h *Handler) withAuth(conn *websocket.Conn, payload []byte, handler func(*auth.User) error) error {
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
return handler(user)
|
|
}
|
|
|
|
func (h *Handler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
|
|
if h.jobsHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jobs handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jobsHandler.HandleAnnotateRun(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
|
|
if h.jobsHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jobs handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jobsHandler.HandleSetRunNarrative(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
|
|
if h.jupyterHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jupyter handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jupyterHandler.HandleStartJupyter(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
|
|
if h.jupyterHandler == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "jupyter handler not available", "")
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jupyterHandler.HandleStopJupyter(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
|
if h.jupyterHandler == nil {
|
|
return h.sendSuccessPacket(conn, map[string]any{"success": true, "services": []any{}, "count": 0})
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.jupyterHandler.HandleListJupyter(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8]
|
|
if len(payload) < 16+1+8 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
|
|
offset := 16
|
|
nameLen := int(payload[offset])
|
|
offset++
|
|
if nameLen <= 0 || len(payload) < offset+nameLen+8 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
|
|
}
|
|
name := string(payload[offset : offset+nameLen])
|
|
offset += nameLen
|
|
|
|
value := binary.BigEndian.Uint64(payload[offset : offset+8])
|
|
|
|
h.logger.Info("metric logged", "name", name, "value", value, "user", user.Name)
|
|
|
|
// Persist to database if available
|
|
if h.db != nil {
|
|
if err := h.db.RecordMetric(context.Background(), name, float64(value), user.Name); err != nil {
|
|
h.logger.Warn("failed to persist metric", "error", err, "name", name)
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"message": "Metric logged",
|
|
"metric": name,
|
|
"value": value,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][commit_id_len:1][commit_id:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
|
|
}
|
|
|
|
// Check authentication and permissions
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
commitIDLen := int(payload[offset])
|
|
offset++
|
|
if commitIDLen <= 0 || len(payload) < offset+commitIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid commit ID length", "")
|
|
}
|
|
commitID := string(payload[offset : offset+commitIDLen])
|
|
|
|
// Check if experiment exists
|
|
if h.expManager == nil || !h.expManager.ExperimentExists(commitID) {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "experiment not found", commitID)
|
|
}
|
|
|
|
// Read experiment metadata
|
|
meta, err := h.expManager.ReadMetadata(commitID)
|
|
if err != nil {
|
|
h.logger.Warn("failed to read experiment metadata", "commit_id", commitID, "error", err)
|
|
meta = &experiment.Metadata{CommitID: commitID}
|
|
}
|
|
|
|
// Read manifest if available
|
|
manifest, _ := h.expManager.ReadManifest(commitID)
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"commit_id": commitID,
|
|
"job_name": meta.JobName,
|
|
"user": meta.User,
|
|
"timestamp": meta.Timestamp,
|
|
"files_count": len(manifest.Files),
|
|
"overall_sha": manifest.OverallSHA,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetList(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendSuccessPacket(conn, map[string]any{"success": true, "message": "Dataset registered"})
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetRegister(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendDataPacket(conn, "dataset_info", []byte("{}"))
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetInfo(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
|
|
if h.datasetsHandler == nil {
|
|
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
|
}
|
|
return h.withAuth(conn, payload, func(user *auth.User) error {
|
|
return h.datasetsHandler.HandleDatasetSearch(conn, payload, user)
|
|
})
|
|
}
|
|
|
|
func (h *Handler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16]
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
|
|
// Return queue status as Data packet
|
|
queueLength := 0
|
|
if h.taskQueue != nil {
|
|
if depth, err := h.taskQueue.QueueDepth(); err == nil {
|
|
queueLength = int(depth)
|
|
}
|
|
}
|
|
|
|
status := map[string]any{
|
|
"queue_length": queueLength,
|
|
"status": "ok",
|
|
"authenticated": user != nil,
|
|
"authenticated_user": user.Name,
|
|
}
|
|
|
|
payloadBytes, _ := json.Marshal(status)
|
|
return h.sendDataPacket(conn, "status", payloadBytes)
|
|
}
|
|
|
|
// selectDependencyManifest returns the name of the first dependency manifest
// present directly inside filesPath. It checks, in priority order:
// requirements.txt, package.json, Cargo.toml, go.mod, pom.xml, build.gradle.
// Any entry that os.Stat succeeds on counts, including a directory. If none
// is found, it returns an error.
func selectDependencyManifest(filesPath string) (string, error) {
	candidates := [...]string{
		"requirements.txt",
		"package.json",
		"Cargo.toml",
		"go.mod",
		"pom.xml",
		"build.gradle",
	}

	for _, candidate := range candidates {
		fullPath := filepath.Join(filesPath, candidate)
		if _, statErr := os.Stat(fullPath); statErr == nil {
			return candidate, nil
		}
	}
	return "", errors.New("no dependency manifest found")
}
|
|
|
|
// Authenticate validates API key from payload
|
|
func (h *Handler) Authenticate(payload []byte) (*auth.User, error) {
|
|
if len(payload) < 16 {
|
|
return nil, errors.New("payload too short")
|
|
}
|
|
return &auth.User{Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}}, nil
|
|
}
|
|
|
|
// RequirePermission checks user permission
|
|
func (h *Handler) RequirePermission(user *auth.User, permission string) bool {
|
|
if user == nil {
|
|
return false
|
|
}
|
|
return user.Admin || user.Permissions[permission]
|
|
}
|
|
|
|
// handleCompareRuns compares two runs and returns differences
|
|
func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][run_a_len:1][run_a:var][run_b_len:1][run_b:var]
|
|
if len(payload) < 16+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "compare runs payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
runALen := int(payload[offset])
|
|
offset++
|
|
if runALen <= 0 || len(payload) < offset+runALen+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run A length", "")
|
|
}
|
|
runA := string(payload[offset : offset+runALen])
|
|
offset += runALen
|
|
|
|
runBLen := int(payload[offset])
|
|
offset++
|
|
if runBLen <= 0 || len(payload) < offset+runBLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run B length", "")
|
|
}
|
|
runB := string(payload[offset : offset+runBLen])
|
|
|
|
// Fetch both experiments
|
|
metaA, errA := h.expManager.ReadMetadata(runA)
|
|
metaB, errB := h.expManager.ReadMetadata(runB)
|
|
|
|
// Build comparison result
|
|
result := map[string]any{
|
|
"run_a": runA,
|
|
"run_b": runB,
|
|
"success": true,
|
|
}
|
|
|
|
// Add metadata if available
|
|
if errA == nil && errB == nil {
|
|
result["job_name_match"] = metaA.JobName == metaB.JobName
|
|
result["user_match"] = metaA.User == metaB.User
|
|
result["timestamp_diff"] = metaB.Timestamp - metaA.Timestamp
|
|
}
|
|
|
|
// Read manifests for comparison
|
|
manifestA, _ := h.expManager.ReadManifest(runA)
|
|
manifestB, _ := h.expManager.ReadManifest(runB)
|
|
|
|
if manifestA != nil && manifestB != nil {
|
|
result["overall_sha_match"] = manifestA.OverallSHA == manifestB.OverallSHA
|
|
result["files_count_a"] = len(manifestA.Files)
|
|
result["files_count_b"] = len(manifestB.Files)
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, result)
|
|
}
|
|
|
|
// handleFindRuns searches for runs based on criteria
|
|
func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][query_len:2][query:var]
|
|
if len(payload) < 16+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "find runs payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
queryLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
|
offset += 2
|
|
if queryLen > 0 && len(payload) >= offset+int(queryLen) {
|
|
// Parse query JSON
|
|
queryData := payload[offset : offset+int(queryLen)]
|
|
var query map[string]any
|
|
if err := json.Unmarshal(queryData, &query); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query JSON", err.Error())
|
|
}
|
|
|
|
h.logger.Info("search query", "query", query, "user", user.Name)
|
|
}
|
|
|
|
// For now, return placeholder results
|
|
results := []map[string]any{
|
|
{"id": "run_abc", "job_name": "train", "outcome": "validates"},
|
|
{"id": "run_def", "job_name": "eval", "outcome": "partial"},
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"results": results,
|
|
"count": len(results),
|
|
})
|
|
}
|
|
|
|
// handleExportRun exports a run with optional anonymization
|
|
func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var][options_len:2][options:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "export run payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
if !h.RequirePermission(user, PermJobsRead) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
runIDLen := int(payload[offset])
|
|
offset++
|
|
if runIDLen <= 0 || len(payload) < offset+runIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
|
}
|
|
runID := string(payload[offset : offset+runIDLen])
|
|
offset += runIDLen
|
|
|
|
// Parse options if present
|
|
var options map[string]any
|
|
if len(payload) >= offset+2 {
|
|
optsLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
|
offset += 2
|
|
if optsLen > 0 && len(payload) >= offset+int(optsLen) {
|
|
json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
|
|
}
|
|
}
|
|
|
|
// Check if experiment exists
|
|
if !h.expManager.ExperimentExists(runID) {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID)
|
|
}
|
|
|
|
anonymize := false
|
|
if options != nil {
|
|
if v, ok := options["anonymize"].(bool); ok {
|
|
anonymize = v
|
|
}
|
|
}
|
|
|
|
h.logger.Info("exporting run", "run_id", runID, "anonymize", anonymize, "user", user.Name)
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"run_id": runID,
|
|
"message": "Export request received",
|
|
"anonymize": anonymize,
|
|
})
|
|
}
|
|
|
|
// handleSetRunOutcome sets the outcome for a run
|
|
func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var][outcome_data_len:2][outcome_data:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run outcome payload too short", "")
|
|
}
|
|
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
|
}
|
|
if !h.RequirePermission(user, PermJobsUpdate) {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
|
}
|
|
|
|
offset := 16
|
|
runIDLen := int(payload[offset])
|
|
offset++
|
|
if runIDLen <= 0 || len(payload) < offset+runIDLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
|
}
|
|
runID := string(payload[offset : offset+runIDLen])
|
|
offset += runIDLen
|
|
|
|
// Parse outcome data
|
|
outcomeLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
|
offset += 2
|
|
if outcomeLen == 0 || len(payload) < offset+int(outcomeLen) {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome data", "")
|
|
}
|
|
|
|
var outcomeData map[string]any
|
|
if err := json.Unmarshal(payload[offset:offset+int(outcomeLen)], &outcomeData); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome JSON", err.Error())
|
|
}
|
|
|
|
// Validate outcome status
|
|
validOutcomes := map[string]bool{"validates": true, "refutes": true, "inconclusive": true, "partial": true}
|
|
outcome, ok := outcomeData["outcome"].(string)
|
|
if !ok || !validOutcomes[outcome] {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome status", "must be: validates, refutes, inconclusive, or partial")
|
|
}
|
|
|
|
h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)
|
|
|
|
return h.sendSuccessPacket(conn, map[string]any{
|
|
"success": true,
|
|
"run_id": runID,
|
|
"outcome": outcome,
|
|
"message": "Outcome updated",
|
|
})
|
|
}
|