refactor: Phase 5 cleanup - Remove original ws_*.go files

Removed original monolithic WebSocket handler files after extracting
to focused packages:

Deleted:
- ws_jobs.go (1,365 lines) → Extracted to api/jobs/handlers.go
- ws_jupyter.go (512 lines) → Extracted to api/jupyter/handlers.go
- ws_validate.go (523 lines) → Extracted to api/validate/handlers.go
- ws_handler.go (379 lines) → Extracted to api/ws/handler.go
- ws_datasets.go (174 lines) - Functionality not migrated
- ws_tls_auth.go (101 lines) - Functionality not migrated

Updated:
- routes.go - Changed NewWSHandler → ws.NewHandler

Lines deleted: ~3,000+ lines from monolithic files
Build status: Compiles successfully
This commit is contained in:
Jeremie Fraeys 2026-02-17 13:33:00 -05:00
parent f0ffbb4a3d
commit d9c5750ed8
No known key found for this signature in database
7 changed files with 2 additions and 3053 deletions

View file

@ -3,6 +3,7 @@ package api
import (
"net/http"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/prommetrics"
)
@ -49,7 +50,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
// Register WebSocket handler with security config and audit logger
securityCfg := getSecurityConfig(s.config)
wsHandler := NewWSHandler(
wsHandler := ws.NewHandler(
s.config.BuildAuthConfig(),
s.logger,
s.expManager,

View file

@ -1,173 +0,0 @@
package api
import (
	"database/sql"
	"encoding/binary"
	"encoding/json"
	"errors"
	"net/url"
	"strings"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/api/helpers"
	"github.com/jfraeys/fetch_ml/internal/storage"
)
// handleDatasetList replies to a dataset-list request with the JSON encoding
// of whatever h.db.ListDatasets returns, wrapped in a "datasets" data packet.
//
// The request is first passed through authenticate (minimum payload length
// ProtocolMinDatasetList), requirePermission (PermDatasetsRead) and
// requireDB; a non-nil error from any of them is returned unchanged.
// Database and serialization failures are reported to the client as error
// packets, and the result of sending that packet is returned.
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
	caller, authErr := h.authenticate(conn, payload, ProtocolMinDatasetList)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(caller, PermDatasetsRead, conn); permErr != nil {
		return permErr
	}
	if dbErr := h.requireDB(conn); dbErr != nil {
		return dbErr
	}

	dbCtx, release := helpers.DBContextShort()
	defer release()

	// NOTE(review): the second argument (0) is presumably a limit meaning
	// "no limit" — confirm against storage.DB.ListDatasets.
	listing, listErr := h.db.ListDatasets(dbCtx, 0)
	if listErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", listErr.Error())
	}

	encoded, encErr := json.Marshal(listing)
	if encErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", encErr.Error())
	}
	return h.sendResponsePacket(conn, NewDataPacket("datasets", encoded))
}
// handleDatasetRegister parses a dataset name and URL from the payload and
// upserts them through h.db.UpsertDataset.
//
// Payload layout, as parsed below:
//
//	[api_key_hash:ProtocolAPIKeyHashLen][name_len:1][name][url_len:2, big-endian][url]
//
// The request must pass authenticate (minimum length
// ProtocolMinDatasetRegister), requirePermission (PermDatasetsCreate) and
// requireDB. Malformed lengths, a blank name, or a URL that fails to parse or
// has no scheme are answered with an ErrorCodeInvalidRequest packet.
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
	caller, authErr := h.authenticate(conn, payload, ProtocolMinDatasetRegister)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(caller, PermDatasetsCreate, conn); permErr != nil {
		return permErr
	}
	if dbErr := h.requireDB(conn); dbErr != nil {
		return dbErr
	}

	// Name: one length byte followed by the name bytes. The "+2" keeps room
	// for the URL length prefix that follows.
	pos := ProtocolAPIKeyHashLen
	nameLen := int(payload[pos])
	pos++
	nameEnd := pos + nameLen
	if nameLen <= 0 || len(payload) < nameEnd+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
	}
	name := string(payload[pos:nameEnd])

	// URL: two-byte big-endian length followed by the URL bytes.
	urlStart := nameEnd + 2
	urlLen := int(binary.BigEndian.Uint16(payload[nameEnd:urlStart]))
	if urlLen <= 0 || len(payload) < urlStart+urlLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "")
	}
	urlStr := string(payload[urlStart : urlStart+urlLen])

	if strings.TrimSpace(name) == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
	}
	parsed, parseErr := url.Parse(urlStr)
	if parseErr != nil || parsed.Scheme == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
	}

	dbCtx, release := helpers.DBContextShort()
	defer release()

	record := &storage.Dataset{Name: name, URL: urlStr}
	if upsertErr := h.db.UpsertDataset(dbCtx, record); upsertErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", upsertErr.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered"))
}
// handleDatasetInfo looks up one dataset by name and replies with its JSON
// encoding in a "dataset" data packet.
//
// Payload layout, as parsed below:
//
//	[api_key_hash:ProtocolAPIKeyHashLen][name_len:1][name]
//
// The request must pass authenticate (minimum length ProtocolMinDatasetInfo),
// requirePermission (PermDatasetsRead) and requireDB. A lookup error that is,
// or wraps, sql.ErrNoRows is reported as ErrorCodeResourceNotFound; any other
// lookup error is reported as ErrorCodeDatabaseError.
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinDatasetInfo)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
		return err
	}
	if err := h.requireDB(conn); err != nil {
		return err
	}

	offset := ProtocolAPIKeyHashLen
	nameLen := int(payload[offset])
	offset++
	if nameLen <= 0 || len(payload) < offset+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
	}
	name := string(payload[offset : offset+nameLen])

	ctx, cancel := helpers.DBContextShort()
	defer cancel()

	ds, err := h.db.GetDataset(ctx, name)
	if err != nil {
		// errors.Is (rather than ==) still matches when the storage layer
		// wraps sql.ErrNoRows with additional context.
		if errors.Is(err, sql.ErrNoRows) {
			return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "")
		}
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error())
	}

	data, err := json.Marshal(ds)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("dataset", data))
}
// handleDatasetSearch runs h.db.SearchDatasets with a whitespace-trimmed
// search term and replies with the JSON-encoded result in a "datasets" data
// packet.
//
// Payload layout, as parsed below:
//
//	[api_key_hash:ProtocolAPIKeyHashLen][term_len:1][term]
//
// The request must pass authenticate (minimum length
// ProtocolMinDatasetSearch), requirePermission (PermDatasetsRead) and
// requireDB. A zero-length term is accepted.
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
	caller, authErr := h.authenticate(conn, payload, ProtocolMinDatasetSearch)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(caller, PermDatasetsRead, conn); permErr != nil {
		return permErr
	}
	if dbErr := h.requireDB(conn); dbErr != nil {
		return dbErr
	}

	lenAt := ProtocolAPIKeyHashLen
	termLen := int(payload[lenAt])
	termStart := lenAt + 1
	termEnd := termStart + termLen
	if termLen < 0 || len(payload) < termEnd {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "")
	}
	query := strings.TrimSpace(string(payload[termStart:termEnd]))

	dbCtx, release := helpers.DBContextShort()
	defer release()

	matches, searchErr := h.db.SearchDatasets(dbCtx, query, 0)
	if searchErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", searchErr.Error())
	}

	encoded, encErr := json.Marshal(matches)
	if encErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", encErr.Error())
	}
	return h.sendResponsePacket(conn, NewDataPacket("datasets", encoded))
}

View file

@ -1,379 +0,0 @@
package api
import (
"compress/flate"
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// Opcodes of the binary WebSocket protocol. The opcode is the first byte of
// every message; handleMessage dispatches on it. Values are grouped by
// feature here, so they are not in numeric order within the block.
const (
	// Job queueing and lifecycle.
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B

	// Run annotation.
	OpcodeAnnotateRun     = 0x1C
	OpcodeSetRunNarrative = 0x1D

	// Datasets.
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Experiments and metrics.
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B

	// Jupyter services.
	OpcodeStartJupyter        = 0x0D
	OpcodeStopJupyter         = 0x0E
	OpcodeListJupyter         = 0x0F
	OpcodeRemoveJupyter       = 0x18
	OpcodeRestoreJupyter      = 0x19
	OpcodeListJupyterPackages = 0x1E

	// Validation.
	OpcodeValidateRequest = 0x16

	// Logs.
	OpcodeGetLogs    = 0x20
	OpcodeStreamLogs = 0x21
)
// createUpgrader creates a WebSocket upgrader with the given security
// configuration.
//
// Origin policy applied by the returned upgrader's CheckOrigin:
//   - A request without an Origin header is always accepted.
//   - When securityCfg is non-nil and ProductionMode is set, the Origin must
//     equal one of securityCfg.AllowedOrigins exactly.
//   - Otherwise (development mode) the Origin must parse as a URL and satisfy
//     isDevOriginAllowed.
//
// The upgrader uses a 10s handshake timeout, 16 KiB read/write buffers and
// has compression negotiation enabled.
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
	return websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			origin := r.Header.Get("Origin")
			if origin == "" {
				return true // No Origin header: not a cross-origin browser request.
			}

			// Production mode: strict checking against allowed origins.
			if securityCfg != nil && securityCfg.ProductionMode {
				for _, allowed := range securityCfg.AllowedOrigins {
					if origin == allowed {
						return true
					}
				}
				return false // Reject if not in allowed list.
			}

			// Development mode: allow localhost and local network origins.
			parsedOrigin, err := url.Parse(origin)
			if err != nil {
				return false
			}
			return isDevOriginAllowed(parsedOrigin)
		},
		// Performance optimizations
		HandshakeTimeout:  10 * time.Second,
		ReadBufferSize:    16 * 1024,
		WriteBufferSize:   16 * 1024,
		EnableCompression: true,
	}
}

// isDevOriginAllowed reports whether a parsed Origin is acceptable in
// development mode: any host on port 8080, the name "localhost" (with or
// without a port), or a literal loopback/private IP address.
//
// The address checks are done on the parsed IP rather than on string
// prefixes, so host names such as "10.example.com" or public addresses in
// 172.0.0.0/8 outside 172.16.0.0/12 are not mistaken for local ones.
func isDevOriginAllowed(origin *url.URL) bool {
	if strings.HasSuffix(origin.Host, ":8080") {
		return true
	}
	hostname := origin.Hostname()
	if strings.EqualFold(hostname, "localhost") {
		return true
	}
	ip := net.ParseIP(hostname)
	return ip != nil && (ip.IsLoopback() || ip.IsPrivate())
}
// WSHandler handles WebSocket connections for the API.
// It is built by NewWSHandler and serves connections through ServeHTTP,
// dispatching each binary message by opcode in handleMessage.
type WSHandler struct {
// authConfig validates API keys (ValidateAPIKey / ValidateAPIKeyHash).
// When nil, authenticate and authenticateWithHash return a built-in
// "default" admin user instead of validating anything.
authConfig *auth.Config
// logger is the component-scoped ("ws-handler") logger set up by NewWSHandler.
logger *logging.Logger
// expManager is stored by NewWSHandler; it is not used in the visible part of this file.
expManager *experiment.Manager
// dataDir is stored by NewWSHandler; it is not used in the visible part of this file.
dataDir string
// queue is the task backend used by enqueueJupyterTask (AddTask) and
// waitForTask (GetTask); both report an error when it is nil.
queue queue.Backend
// db backs the dataset handlers; requireDB rejects requests when it is nil.
db *storage.DB
// jupyterServiceMgr is stored by NewWSHandler; it is not used in the visible part of this file.
jupyterServiceMgr *jupyter.ServiceManager
// securityConfig is the configuration the upgrader was created from.
securityConfig *config.SecurityConfig
// auditLogger, when non-nil, receives LogAuthAttempt calls from ServeHTTP.
auditLogger *audit.Logger
// upgrader performs the HTTP-to-WebSocket upgrade in ServeHTTP.
upgrader websocket.Upgrader
}
// NewWSHandler creates a new WebSocket handler.
//
// The supplied logger is narrowed to the "ws-handler" component, and the
// WebSocket upgrader is derived from securityConfig via createUpgrader. All
// other arguments are stored on the handler as given.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	dataDir string,
	taskQueue queue.Backend,
	db *storage.DB,
	jupyterServiceMgr *jupyter.ServiceManager,
	securityConfig *config.SecurityConfig,
	auditLogger *audit.Logger,
) *WSHandler {
	scopedLogger := logger.Component(logging.EnsureTrace(context.Background()), "ws-handler")

	handler := &WSHandler{
		authConfig:        authConfig,
		logger:            scopedLogger,
		expManager:        expManager,
		dataDir:           dataDir,
		queue:             taskQueue,
		db:                db,
		jupyterServiceMgr: jupyterServiceMgr,
		securityConfig:    securityConfig,
		auditLogger:       auditLogger,
	}
	handler.upgrader = createUpgrader(securityConfig)
	return handler
}
// enableLowLatencyTCP requests TCP_NODELAY (Nagle's algorithm off) on the
// connection underneath conn so small packets are not delayed.
//
// It does nothing when conn is nil or when the underlying connection is not
// a *net.TCPConn. A failure to set the option is logged as a warning only.
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) {
	if conn == nil {
		return
	}
	tcpConn, isTCP := conn.UnderlyingConn().(*net.TCPConn)
	if !isTCP {
		return
	}
	if err := tcpConn.SetNoDelay(true); err != nil {
		logger.Warn("failed to enable tcp no delay", "error", err)
	}
}
// ServeHTTP upgrades an HTTP request to a WebSocket connection and then
// serves binary protocol messages on it until the peer goes away.
//
// Steps, in order:
//  1. Security response headers are set (HSTS only when the request is TLS).
//  2. If authentication is enabled, the API key extracted from the request is
//     validated; a failure is logged, optionally audit-logged, and answered
//     with 401 before any upgrade happens.
//  3. The connection is upgraded, write compression is turned on at
//     flate.BestSpeed, and enableLowLatencyTCP is applied.
//  4. Messages are read in a loop. Non-binary messages are skipped with a
//     warning; binary messages go to handleMessage, and a returned error is
//     logged and sent back as a structured error packet.
//
// The connection is closed when the read loop ends.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Security headers.
	respHeaders := w.Header()
	respHeaders.Set("X-Frame-Options", "DENY")
	respHeaders.Set("X-Content-Type-Options", "nosniff")
	respHeaders.Set("X-XSS-Protection", "1; mode=block")
	if r.TLS != nil {
		// HSTS only makes sense over HTTPS.
		respHeaders.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	}

	// The API key is checked before the upgrade so a rejected client gets a
	// plain HTTP 401.
	key := auth.ExtractAPIKeyFromRequest(r)
	remote := r.RemoteAddr

	if h.authConfig != nil && h.authConfig.Enabled {
		// Only the first (at most) eight characters of the key are logged.
		shown := key
		if len(shown) > 8 {
			shown = shown[:8]
		}
		h.logger.Info(
			"websocket auth attempt",
			"api_key_length", len(key),
			"api_key_prefix", shown,
		)

		authedUser, authErr := h.authConfig.ValidateAPIKey(key)
		if authErr != nil {
			h.logger.Warn("websocket authentication failed", "error", authErr)
			if h.auditLogger != nil {
				h.auditLogger.LogAuthAttempt(shown, remote, false, authErr.Error())
			}
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}

		h.logger.Info("websocket authentication succeeded")
		if h.auditLogger != nil && authedUser != nil {
			h.auditLogger.LogAuthAttempt(authedUser.Name, remote, true, "")
		}
	}

	conn, upgradeErr := h.upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		h.logger.Error("websocket upgrade failed", "error", upgradeErr)
		return
	}

	conn.EnableWriteCompression(true)
	if levelErr := conn.SetCompressionLevel(flate.BestSpeed); levelErr != nil {
		h.logger.Warn("failed to set websocket compression level", "error", levelErr)
	}
	enableLowLatencyTCP(conn, h.logger)

	defer func() {
		// Best-effort close; there is nothing useful to do with the error.
		_ = conn.Close()
	}()

	h.logger.Info("websocket connection established", "remote", remote)

	for {
		kind, frame, readErr := conn.ReadMessage()
		if readErr != nil {
			// Going-away and abnormal closures are expected and not logged.
			if websocket.IsUnexpectedCloseError(
				readErr,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure,
			) {
				h.logger.Error("websocket read error", "error", readErr)
			}
			return
		}

		if kind != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}

		if handleErr := h.handleMessage(conn, frame); handleErr != nil {
			h.logger.Error("message handling error", "error", handleErr)
			// Send structured error response so CLI clients can parse it.
			// (Raw fallback bytes cause client-side InvalidPacket errors.)
			_ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", handleErr.Error())
		}
	}
}
// handleMessage routes one binary message to the handler for its opcode.
//
// The first byte of message is the opcode; the remaining bytes are passed to
// the handler as the payload. An empty message or an opcode with no handler
// yields an error, which ServeHTTP logs and relays to the client.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) == 0 {
		return fmt.Errorf("message too short")
	}
	opcode, payload := message[0], message[1:]

	switch opcode {
	// Job queueing and lifecycle.
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, payload)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, payload)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, payload)
	case OpcodePrune:
		return h.handlePrune(conn, payload)
	case OpcodeQueueJobWithTracking:
		return h.handleQueueJobWithTracking(conn, payload)
	case OpcodeQueueJobWithSnapshot:
		return h.handleQueueJobWithSnapshot(conn, payload)
	case OpcodeQueueJobWithArgs:
		return h.handleQueueJobWithArgs(conn, payload)
	case OpcodeQueueJobWithNote:
		return h.handleQueueJobWithNote(conn, payload)

	// Run annotation.
	case OpcodeAnnotateRun:
		return h.handleAnnotateRun(conn, payload)
	case OpcodeSetRunNarrative:
		return h.handleSetRunNarrative(conn, payload)

	// Datasets.
	case OpcodeDatasetList:
		return h.handleDatasetList(conn, payload)
	case OpcodeDatasetRegister:
		return h.handleDatasetRegister(conn, payload)
	case OpcodeDatasetInfo:
		return h.handleDatasetInfo(conn, payload)
	case OpcodeDatasetSearch:
		return h.handleDatasetSearch(conn, payload)

	// Experiments and metrics.
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, payload)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, payload)

	// Jupyter services.
	case OpcodeStartJupyter:
		return h.handleStartJupyter(conn, payload)
	case OpcodeStopJupyter:
		return h.handleStopJupyter(conn, payload)
	case OpcodeListJupyter:
		return h.handleListJupyter(conn, payload)
	case OpcodeRemoveJupyter:
		return h.handleRemoveJupyter(conn, payload)
	case OpcodeRestoreJupyter:
		return h.handleRestoreJupyter(conn, payload)
	case OpcodeListJupyterPackages:
		return h.handleListJupyterPackages(conn, payload)

	// Validation.
	case OpcodeValidateRequest:
		return h.handleValidateRequest(conn, payload)

	// Logs.
	case OpcodeGetLogs:
		return h.handleGetLogs(conn, payload)
	case OpcodeStreamLogs:
		return h.handleStreamLogs(conn, payload)
	}

	return fmt.Errorf("unknown opcode: 0x%02x", opcode)
}
// AuthHandler is the signature of a message handler that, besides the
// connection and the raw payload, receives an already-authenticated user.
// NOTE(review): no use of this type is visible in this part of the file.
type AuthHandler func(conn *websocket.Conn, payload []byte, user *auth.User) error
// authenticate validates the API key from raw payload and returns the user.
//
// payload must be at least minLen bytes long and, independently of minLen,
// long enough to hold the 16-byte API key hash at its start; otherwise an
// ErrorCodeInvalidRequest packet ("payload too short") is sent. The hash is
// then resolved by authenticateWithHash, which either validates it against
// h.authConfig or, when no auth config is set, returns the built-in default
// admin user.
//
// On failure the returned user is nil and the returned error is whatever
// sendErrorPacket returned.
// NOTE(review): if sendErrorPacket returns nil after a successful write, a
// failure here surfaces as (nil, nil) and callers that only test err will
// continue with a nil user — confirm against sendErrorPacket.
func (h *WSHandler) authenticate(conn *websocket.Conn, payload []byte, minLen int) (*auth.User, error) {
	const apiKeyHashLen = 16

	// The explicit hash-length check keeps payload[:apiKeyHashLen] from
	// panicking if a caller ever passes a minLen smaller than the hash.
	if len(payload) < minLen || len(payload) < apiKeyHashLen {
		return nil, h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	return h.authenticateWithHash(conn, payload[:apiKeyHashLen])
}
// authenticateWithHash validates a pre-extracted API key hash.
//
// With no auth config set on the handler, every caller is treated as a
// built-in "default" admin user holding the wildcard permission. Otherwise
// the hash is checked with h.authConfig.ValidateAPIKeyHash; a rejection is
// logged and reported to the client as ErrorCodeAuthenticationFailed, and
// the result of sending that packet is returned as the error.
func (h *WSHandler) authenticateWithHash(conn *websocket.Conn, apiKeyHash []byte) (*auth.User, error) {
	if h.authConfig == nil {
		fallback := &auth.User{
			Name:        "default",
			Admin:       true,
			Roles:       []string{"admin"},
			Permissions: map[string]bool{"*": true},
		}
		return fallback, nil
	}

	resolved, validateErr := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
	if validateErr != nil {
		h.logger.Error("invalid api key", "error", validateErr)
		return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", validateErr.Error())
	}
	return resolved, nil
}
// requirePermission checks if the user has the required permission.
//
// The check is enforced only when an auth config is present and enabled; in
// every other case nil is returned without consulting user. When the
// permission is missing, the denial is logged and an
// ErrorCodePermissionDenied packet is sent; the result of sending it is
// returned.
func (h *WSHandler) requirePermission(
	user *auth.User,
	permission string,
	conn *websocket.Conn,
) error {
	enforced := h.authConfig != nil && h.authConfig.Enabled
	if !enforced || user.HasPermission(permission) {
		return nil
	}

	h.logger.Error("insufficient permissions", "user", user.Name, "required", permission)
	denial := fmt.Sprintf("Insufficient permissions: %s", permission)
	return h.sendErrorPacket(conn, ErrorCodePermissionDenied, denial, "")
}
// requireDB checks if the database is configured.
// When h.db is nil it sends an ErrorCodeDatabaseError packet to the client
// and returns the result of that send; otherwise it returns nil.
func (h *WSHandler) requireDB(conn *websocket.Conn) error {
	if h.db != nil {
		return nil
	}
	return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
}

File diff suppressed because it is too large Load diff

View file

@ -1,512 +0,0 @@
package api
import (
"encoding/binary"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// JupyterTaskErrorCode returns the error code for a Jupyter task.
// It exists for backward compatibility: the mapping itself is done by a
// helpers.TaskErrorMapper, whose result is converted to a byte.
func JupyterTaskErrorCode(t *queue.Task) byte {
	return byte(helpers.NewTaskErrorMapper().MapJupyterError(t))
}
// jupyterTaskOutput is the JSON document the Jupyter handlers decode from a
// finished task's Output field. Every field except Type is optional
// (omitempty); the raw-message fields are kept undecoded so they can be
// forwarded to the client as-is.
type jupyterTaskOutput struct {
// Type is decoded from "type"; it is not read by the handlers visible in this file.
Type string `json:"type"`
// Service is read by handleStartJupyter, which decodes it as a jupyterServiceView.
Service json.RawMessage `json:"service,omitempty"`
// Services is forwarded by handleListJupyter in a "jupyter_services" packet.
Services json.RawMessage `json:"services,omitempty"`
// Packages is forwarded by handleListJupyterPackages in a "jupyter_packages" packet.
Packages json.RawMessage `json:"packages,omitempty"`
// RestorePath is appended to the success message by handleRestoreJupyter when non-blank.
RestorePath string `json:"restore_path,omitempty"`
}
// handleRestoreJupyter enqueues a "restore" Jupyter task for the named
// workspace, waits up to two minutes for it to finish, and reports the
// outcome to the client.
//
// Payload layout, as parsed below:
//
//	[api_key_hash:ProtocolAPIKeyHashLen][name_len:1][name]
//
// The request must pass authenticate (minimum length 18) and
// requirePermission (PermJupyterManage). If the finished task's output is
// JSON carrying a non-blank restore_path, that path is included in the
// success message.
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
	caller, authErr := h.authenticate(conn, payload, 18)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(caller, PermJupyterManage, conn); permErr != nil {
		return permErr
	}

	lenAt := ProtocolAPIKeyHashLen
	nameLen := int(payload[lenAt])
	nameStart := lenAt + 1
	if len(payload) < nameStart+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	workspace := strings.TrimSpace(string(payload[nameStart : nameStart+nameLen]))

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRestore,
		jupyterNameKey:       workspace,
	}
	taskID, enqueueErr := h.enqueueJupyterTask(caller.Name, fmt.Sprintf("jupyter-restore-%s", workspace), meta)
	if enqueueErr != nil {
		h.logger.Error("failed to enqueue jupyter restore", "error", enqueueErr)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "")
	}

	finished, waitErr := h.waitForTask(taskID, 2*time.Minute)
	if waitErr != nil {
		h.logger.Error("failed waiting for jupyter restore", "error", waitErr)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if finished.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(finished), "Failed to restore Jupyter workspace", strings.TrimSpace(finished.Error))
	}

	// Default message; upgraded below when the worker reported where the
	// workspace was restored to.
	msg := fmt.Sprintf("Restored Jupyter workspace '%s'", workspace)
	if raw := strings.TrimSpace(finished.Output); raw != "" {
		var decoded jupyterTaskOutput
		if json.Unmarshal([]byte(raw), &decoded) == nil {
			if dest := strings.TrimSpace(decoded.RestorePath); dest != "" {
				msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", workspace, dest)
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
// jupyterServiceView is the subset of a Jupyter service description that
// handleStartJupyter decodes from jupyterTaskOutput.Service.
type jupyterServiceView struct {
// Name is decoded from "name"; it is not read by the visible handlers.
Name string `json:"name"`
// URL is decoded from "url" and, when non-blank, included in the start success message.
URL string `json:"url"`
}
// Metadata keys attached to queued Jupyter tasks.
const (
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskActionKey = "jupyter_action"
	jupyterNameKey       = "jupyter_name"
	jupyterWorkspaceKey  = "jupyter_workspace"
	jupyterServiceIDKey  = "jupyter_service_id"
)

// jupyterTaskTypeValue is the value stored under jupyterTaskTypeKey by
// enqueueJupyterTask.
const jupyterTaskTypeValue = "jupyter"

// Values stored under jupyterTaskActionKey, one per Jupyter handler.
const (
	jupyterActionStart    = "start"
	jupyterActionStop     = "stop"
	jupyterActionRemove   = "remove"
	jupyterActionRestore  = "restore"
	jupyterActionList     = "list"
	jupyterActionListPkgs = "list_packages"
)
// handleListJupyterPackages enqueues a "list_packages" Jupyter task for the
// named service, waits up to two minutes for it, and replies with the
// package list in a "jupyter_packages" data packet.
//
// Protocol: [api_key_hash:16][name_len:1][name:var]
//
// The caller must authenticate and hold the "jupyter:read" permission. If
// validateWSUser yields no user, the request is rejected as unauthenticated
// rather than proceeding to dereference a nil user. An empty or undecodable
// task output, or one without packages, produces an empty JSON array.
func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter packages payload too short", "")
	}

	apiKeyHash := payload[:16]

	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	// user.Name is required below to attribute the task, so a missing user
	// cannot be allowed through.
	if user == nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "")
	}
	if !user.HasPermission("jupyter:read") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	p := helpers.NewPayloadParser(payload, 16)
	name, err := p.ParseLengthPrefixedString()
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	name = strings.TrimSpace(name)
	if name == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionListPkgs,
		jupyterNameKey:       name,
	}
	jobName := fmt.Sprintf("jupyter-packages-%s", name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter packages list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter packages list", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter packages list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter packages", strings.TrimSpace(result.Error))
	}

	out := strings.TrimSpace(result.Output)
	if out == "" {
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
	}
	var payloadOut jupyterTaskOutput
	if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
		pkgs := payloadOut.Packages
		if len(pkgs) == 0 {
			pkgs = []byte("[]")
		}
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", pkgs))
	}

	// Output was not valid JSON: fall back to an empty list.
	return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
}
// enqueueJupyterTask adds a queued, priority-100 Jupyter task to h.queue and
// returns its freshly generated UUID.
//
// It fails, in this order, when no queue is configured, when jobName is
// rejected by container.ValidateJobName, or when userName is blank. meta may
// be nil; the task-type marker (jupyterTaskTypeKey) is always written into
// it, so a non-nil map passed by the caller is modified. userName is recorded
// as the task's UserID, Username and CreatedBy.
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
	if h.queue == nil {
		return "", fmt.Errorf("task queue not configured")
	}
	if nameErr := container.ValidateJobName(jobName); nameErr != nil {
		return "", nameErr
	}
	if strings.TrimSpace(userName) == "" {
		return "", fmt.Errorf("missing user")
	}

	if meta == nil {
		meta = map[string]string{}
	}
	meta[jupyterTaskTypeKey] = jupyterTaskTypeValue

	id := uuid.New().String()
	job := &queue.Task{
		ID:        id,
		JobName:   jobName,
		Args:      "",
		Status:    "queued",
		Priority:  100, // high priority; interactive request
		CreatedAt: time.Now(),
		UserID:    userName,
		Username:  userName,
		CreatedBy: userName,
		Metadata:  meta,
	}

	if addErr := h.queue.AddTask(job); addErr != nil {
		return "", addErr
	}
	return id, nil
}
// waitForTask polls h.queue for taskID every 200ms until the task reaches a
// terminal status ("completed", "failed" or "cancelled") and returns it.
//
// It returns an error when no queue is configured, when GetTask fails, or
// when timeout elapses first. A nil task from GetTask is treated as "not
// there yet" and polled again. The calling goroutine is blocked for the
// whole wait.
func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) {
	if h.queue == nil {
		return nil, fmt.Errorf("task queue not configured")
	}

	const pollInterval = 200 * time.Millisecond
	deadline := time.Now().Add(timeout)

	for !time.Now().After(deadline) {
		task, err := h.queue.GetTask(taskID)
		if err != nil {
			return nil, err
		}
		if task != nil {
			switch task.Status {
			case "completed", "failed", "cancelled":
				return task, nil
			}
		}
		time.Sleep(pollInterval)
	}
	return nil, fmt.Errorf("timed out waiting for worker")
}
// handleStartJupyter enqueues a "start" Jupyter task for the given name and
// workspace, waits up to two minutes for it, and reports the outcome.
//
// Protocol:
//
//	[api_key_hash:16][name_len:1][name:var][workspace_len:2, big-endian][workspace:var][password_len:1][password:var]
//
// The caller must authenticate and hold the "jupyter:manage" permission. If
// validateWSUser yields no user, the request is rejected as unauthenticated
// rather than proceeding to dereference a nil user. The password is
// length-validated but otherwise unused. A task failure whose error text
// mentions "already exists" or "already in use" is reported as
// ErrorCodeResourceAlreadyExists. On success, the service URL is included in
// the message when the task output provides one.
func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 21 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]

	// Verify API key
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	// user.Name is required below to attribute the task, so a missing user
	// cannot be allowed through.
	if user == nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	offset := 16

	// Name; the "+2" keeps room for the workspace length prefix.
	nameLen := int(payload[offset])
	offset++
	if len(payload) < offset+nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	name := strings.TrimSpace(string(payload[offset : offset+nameLen]))
	offset += nameLen

	// Workspace; the "+1" keeps room for the password length byte.
	workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if len(payload) < offset+workspaceLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
	}
	workspace := strings.TrimSpace(string(payload[offset : offset+workspaceLen]))
	offset += workspaceLen

	// Password is only checked for a consistent length; its value is not
	// passed on to the task.
	passwordLen := int(payload[offset])
	offset++
	if len(payload) < offset+passwordLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStart,
		jupyterNameKey:       name,
		jupyterWorkspaceKey:  workspace,
	}
	jobName := fmt.Sprintf("jupyter-%s", name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter task", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter start", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		h.logger.Error("jupyter task failed", "error", result.Error)
		details := strings.TrimSpace(result.Error)
		lower := strings.ToLower(details)
		if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
			return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
		}
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to start Jupyter service", details)
	}

	msg := fmt.Sprintf("Started Jupyter service '%s'", name)
	out := strings.TrimSpace(result.Output)
	if out != "" {
		var payloadOut jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 {
			var svc jupyterServiceView
			if err := json.Unmarshal(payloadOut.Service, &svc); err == nil {
				if serviceURL := strings.TrimSpace(svc.URL); serviceURL != "" {
					msg = fmt.Sprintf("Started Jupyter service '%s' at %s", name, serviceURL)
				}
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
// handleStopJupyter enqueues a "stop" Jupyter task for the given service ID,
// waits up to two minutes for it, and reports the outcome.
//
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
//
// The caller must authenticate and hold the "jupyter:manage" permission. If
// validateWSUser yields no user, the request is rejected as unauthenticated
// rather than proceeding to dereference a nil user.
func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]

	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	// user.Name is required below to attribute the task, so a missing user
	// cannot be allowed through.
	if user == nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	p := helpers.NewPayloadParser(payload, 16)
	serviceID, err := p.ParseLengthPrefixedString()
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStop,
		jupyterServiceIDKey:  strings.TrimSpace(serviceID),
	}
	jobName := fmt.Sprintf("jupyter-stop-%s", strings.TrimSpace(serviceID))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter stop", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter stop", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
}
// handleRemoveJupyter enqueues a "remove" Jupyter task for the given service
// ID, waits up to two minutes for it, and reports the outcome.
//
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var][purge:1, optional]
//
// The optional trailing byte requests a purge when it equals 0x01; when it
// is absent or has any other value, purge is false (trash-first behavior).
// The flag is passed to the worker as the "jupyter_purge" metadata entry.
//
// The caller must authenticate and hold the "jupyter:manage" permission. If
// validateWSUser yields no user, the request is rejected as unauthenticated
// rather than proceeding to dereference a nil user. Note that the payload is
// parsed before the credentials are checked, so a malformed payload is
// reported as an invalid request even for an unauthenticated caller.
func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]

	p := helpers.NewPayloadParser(payload, 16)
	serviceID, err := p.ParseLengthPrefixedString()
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}

	// Optional: purge flag (1 byte). Default false for trash-first behavior.
	purge := false
	if p.HasRemaining() {
		// The error is ignored because HasRemaining was just checked.
		purgeByte, _ := p.ParseByte()
		purge = purgeByte == 0x01
	}

	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	// user.Name is required below to attribute the task, so a missing user
	// cannot be allowed through.
	if user == nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRemove,
		jupyterServiceIDKey:  strings.TrimSpace(serviceID),
		"jupyter_purge":      fmt.Sprintf("%t", purge),
	}
	jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
}
// handleListJupyter enqueues a "list" Jupyter task on behalf of the caller,
// waits up to two minutes for it, and replies with the services in a
// "jupyter_services" data packet.
//
// The request must pass authenticate (minimum length ProtocolMinDatasetList)
// and requirePermission (PermJupyterRead). An empty or undecodable task
// output, or one without services, produces an empty JSON array.
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
	caller, authErr := h.authenticate(conn, payload, ProtocolMinDatasetList)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(caller, PermJupyterRead, conn); permErr != nil {
		return permErr
	}

	meta := map[string]string{jupyterTaskActionKey: jupyterActionList}
	taskID, enqueueErr := h.enqueueJupyterTask(caller.Name, fmt.Sprintf("jupyter-list-%s", caller.Name), meta)
	if enqueueErr != nil {
		h.logger.Error("failed to enqueue jupyter list", "error", enqueueErr)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "")
	}

	finished, waitErr := h.waitForTask(taskID, 2*time.Minute)
	if waitErr != nil {
		h.logger.Error("failed waiting for jupyter list", "error", waitErr)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if finished.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(finished), "Failed to list Jupyter services", strings.TrimSpace(finished.Error))
	}

	// Blank output and output that is not valid JSON both yield an empty list.
	raw := strings.TrimSpace(finished.Output)
	var decoded jupyterTaskOutput
	if raw == "" || json.Unmarshal([]byte(raw), &decoded) != nil {
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
	}

	services := decoded.Services
	if len(services) == 0 {
		services = []byte("[]")
	}
	return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", services))
}

View file

@ -1,100 +0,0 @@
package api
import (
"crypto/tls"
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"golang.org/x/crypto/acme/autocert"
)
// SetupTLSConfig builds the *http.Server carrying the TLS settings for the
// WebSocket endpoint.
//
// Selection rules, in order:
//   - certFile and keyFile both non-empty: a server restricted to TLS 1.2 or
//     newer with a fixed list of ECDHE-RSA AEAD cipher suites. The certificate
//     files are not read here.
//   - host non-empty: a server whose TLS configuration comes from an autocert
//     (Let's Encrypt) manager limited to host, caching under /var/www/.cache.
//   - neither: a nil server.
//
// The returned error is always nil, so callers must check the server for nil.
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
	// Bounding header reads protects against Slowloris-style attacks.
	const headerTimeout = 10 * time.Second

	switch {
	case certFile != "" && keyFile != "":
		// Caller-provided certificates.
		tlsCfg := &tls.Config{
			MinVersion: tls.VersionTLS12,
			CipherSuites: []uint16{
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
			},
		}
		return &http.Server{
			ReadHeaderTimeout: headerTimeout,
			TLSConfig:         tlsCfg,
		}, nil

	case host != "":
		// Certificates obtained automatically for the single allowed host.
		mgr := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(host),
			Cache:      autocert.DirCache("/var/www/.cache"),
		}
		return &http.Server{
			ReadHeaderTimeout: headerTimeout,
			TLSConfig:         mgr.TLSConfig(),
		}, nil

	default:
		return nil, nil
	}
}
// verifyAPIKeyHash checks a binary API key hash against the handler's auth
// configuration.
//
// It returns nil when no auth configuration is set or auth is disabled, and
// also when the configuration accepts the hash. Any validation failure is
// reported as a generic "invalid api key" error; the underlying cause is not
// included.
func (h *WSHandler) verifyAPIKeyHash(hash []byte) error {
	authRequired := h.authConfig != nil && h.authConfig.Enabled
	if !authRequired {
		return nil
	}
	if _, verr := h.authConfig.ValidateAPIKeyHash(hash); verr != nil {
		return fmt.Errorf("invalid api key")
	}
	return nil
}
// sendErrorPacket builds an error packet from errorCode, message and details
// and writes it to conn through sendResponsePacket, returning that call's
// result.
func (h *WSHandler) sendErrorPacket(
	conn *websocket.Conn,
	errorCode byte,
	message string,
	details string,
) error {
	return h.sendResponsePacket(conn, NewErrorPacket(errorCode, message, details))
}
// sendResponsePacket serializes packet and writes it to conn as a single
// binary WebSocket message.
//
// If serialization fails, the failure is logged and the two-byte fallback
// frame {0xFF, 0x00} is written instead. The returned error is always the
// result of the WebSocket write.
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
	frame, serr := packet.Serialize()
	if serr == nil {
		return conn.WriteMessage(websocket.BinaryMessage, frame)
	}

	h.logger.Error("failed to serialize response packet", "error", serr)
	// Minimal error frame so the client still receives a reply.
	fallback := []byte{0xFF, 0x00}
	return conn.WriteMessage(websocket.BinaryMessage, fallback)
}
// validateWSUser resolves the user behind a binary API key hash.
//
// Without an auth configuration every caller is mapped to a built-in
// "default" admin user holding the wildcard permission. Otherwise the hash is
// validated by the auth configuration and its user or error is returned.
//
// NOTE(review): unlike verifyAPIKeyHash, this does not look at
// authConfig.Enabled — confirm ValidateAPIKeyHash handles disabled auth.
func (h *WSHandler) validateWSUser(apiKeyHash []byte) (*auth.User, error) {
	if h.authConfig == nil {
		defaultUser := &auth.User{
			Name:        "default",
			Admin:       true,
			Roles:       []string{"admin"},
			Permissions: map[string]bool{"*": true},
		}
		return defaultUser, nil
	}

	resolved, verr := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
	if verr != nil {
		return nil, verr
	}
	return resolved, nil
}

View file

@ -1,523 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// validateCheck is the outcome of a single named check inside a
// validateReport. Only OK is always present in the JSON form; the string
// fields are omitted when empty.
type validateCheck struct {
	// OK reports whether the check passed.
	OK bool `json:"ok"`

	// Expected is the value the check wanted to find.
	Expected string `json:"expected,omitempty"`

	// Actual is the value the check found.
	Actual string `json:"actual,omitempty"`

	// Details carries a free-form explanation, typically a failure reason.
	Details string `json:"details,omitempty"`
}
// validateReport is the JSON document returned for a validate request. It
// aggregates the individual check outcomes together with the error and
// warning messages collected while validating.
type validateReport struct {
	// OK is the overall verdict of the validation.
	OK bool `json:"ok"`

	// CommitID is the commit the validation ran against; omitted when empty.
	CommitID string `json:"commit_id,omitempty"`

	// TaskID is the task the validation ran against; omitted when empty.
	TaskID string `json:"task_id,omitempty"`

	// Checks maps a check name to its outcome.
	Checks map[string]validateCheck `json:"checks"`

	// Errors lists failure messages; omitted when empty.
	Errors []string `json:"errors,omitempty"`

	// Warnings lists non-fatal findings; omitted when empty.
	Warnings []string `json:"warnings,omitempty"`

	// TS is the timestamp of the report.
	TS string `json:"ts"`
}
// shouldRequireRunManifest reports whether a task's status means its run
// manifest must exist. That is the case for the statuses "running",
// "completed" and "failed", compared case-insensitively after trimming
// surrounding whitespace. A nil task never requires a manifest.
func shouldRequireRunManifest(task *queue.Task) bool {
	if task == nil {
		return false
	}
	status := strings.ToLower(strings.TrimSpace(task.Status))
	return status == "running" || status == "completed" || status == "failed"
}
// expectedRunManifestBucketForStatus maps a task status to the name of the
// bucket its run manifest is expected to be in.
//
// The status is trimmed and lower-cased before matching:
//
//	queued, pending     -> "pending"
//	running             -> "running"
//	completed, finished -> "finished"
//	failed              -> "failed"
//
// The boolean result is false, and the bucket empty, for any other status.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	bucket := ""
	switch strings.ToLower(strings.TrimSpace(status)) {
	case "queued", "pending":
		bucket = "pending"
	case "running":
		bucket = "running"
	case "completed", "finished":
		bucket = "finished"
	case "failed":
		bucket = "failed"
	}
	return bucket, bucket != ""
}
// findRunManifestDir locates the job directory that holds the run manifest
// for jobName below basePath.
//
// The job roots are probed in the order running, pending, finished, failed.
// The first root containing a directory named jobName with a manifest file in
// it wins. It returns that directory, the name of the bucket it was found in,
// and true. If either argument is blank, or no bucket matches, it returns
// ("", "", false).
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}

	paths := storage.NewJobPaths(basePath)
	// Parallel arrays; the order defines the search priority.
	buckets := [...]string{"running", "pending", "finished", "failed"}
	roots := [...]string{
		paths.RunningPath(),
		paths.PendingPath(),
		paths.FinishedPath(),
		paths.FailedPath(),
	}

	for i, bucket := range buckets {
		candidate := filepath.Join(roots[i], jobName)
		info, statErr := os.Stat(candidate)
		if statErr != nil || !info.IsDir() {
			continue
		}
		if _, manifestErr := os.Stat(manifest.ManifestPath(candidate)); manifestErr != nil {
			continue
		}
		return candidate, bucket, true
	}
	return "", "", false
}
// validateResourcesForTask sanity-checks the resource request of a task.
//
// It returns the "resources" check outcome plus the error messages to add to
// the report (nil when the check passes). Rules, evaluated in this order:
//   - task must not be nil;
//   - CPU, MemoryGB and GPU must each be >= 0;
//   - GPUMemory, when set, must be either a percentage "N%" with N in (0,100]
//     or a plain fraction in (0,1];
//   - GPUMemory may only be set when GPU is non-zero.
//
// Non-finite values such as "NaN" are rejected for GPUMemory.
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
	if task == nil {
		return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
	}
	if task.CPU < 0 {
		return validateCheck{OK: false, Details: "cpu must be >= 0"}, []string{"invalid cpu request"}
	}
	if task.MemoryGB < 0 {
		return validateCheck{OK: false, Details: "memory_gb must be >= 0"}, []string{"invalid memory request"}
	}
	if task.GPU < 0 {
		return validateCheck{OK: false, Details: "gpu must be >= 0"}, []string{"invalid gpu request"}
	}

	gpuMem := strings.TrimSpace(task.GPUMemory)
	if gpuMem == "" {
		return validateCheck{OK: true}, nil
	}

	// The range tests below are written as a negated "in range" condition on
	// purpose: strconv.ParseFloat accepts "NaN", and every ordered comparison
	// with NaN is false, so "f <= 0 || f > max" would let NaN through.
	if strings.HasSuffix(gpuMem, "%") {
		pct := strings.TrimSpace(strings.TrimSuffix(gpuMem, "%"))
		f, err := strconv.ParseFloat(pct, 64)
		if err != nil || !(f > 0 && f <= 100) {
			details := "gpu_memory must be a percentage in (0,100]"
			return validateCheck{OK: false, Details: details}, []string{"invalid gpu_memory"}
		}
	} else {
		f, err := strconv.ParseFloat(gpuMem, 64)
		if err != nil || !(f > 0 && f <= 1) {
			details := "gpu_memory must be a fraction in (0,1]"
			return validateCheck{OK: false, Details: details}, []string{"invalid gpu_memory"}
		}
	}

	if task.GPU == 0 {
		return validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"}, []string{"invalid gpu_memory"}
	}
	return validateCheck{OK: true}, nil
}
// handleValidateRequest answers a "validate" request with a JSON
// validateReport.
//
// Wire format of payload:
//
//	[api_key_hash:16][target_type:1][id_len:1][id:var]
//
// target_type 0 carries a raw 20-byte commit id, which is reported
// hex-encoded. target_type 1 carries a task id string; the task is loaded from
// the queue and the "commit_id" entry of its metadata is validated, together
// with the task's resources, run manifest, expected manifest hashes, snapshot
// and dataset checksums.
//
// Malformed requests, authentication or permission failures and unavailable
// dependencies are answered with an error packet. Otherwise every check
// outcome is collected in the report (Checks, Errors, Warnings) and the report
// is sent as a "validate" data packet, including when validation failed.
//
// The returned error is always the result of writing the response packet.
//
// TODO(context): Add a versioned validate protocol once we need more target types/fields.
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// Fixed header: 16-byte key hash, target type, id length.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
	}
	apiKeyHash := payload[:16]
	targetType := payload[16]
	idLen := int(payload[17])
	if idLen < 1 || len(payload) < 18+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
	}
	idBytes := payload[18 : 18+idLen]

	// Validate API key and user
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Invalid API key",
			err.Error(),
		)
	}
	authEnabled := h.authConfig != nil && h.authConfig.Enabled
	if authEnabled && !user.HasPermission("jobs:read") {
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to validate jobs",
			"",
		)
	}
	if h.expManager == nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServiceUnavailable,
			"Experiment manager not available",
			"",
		)
	}

	// Resolve the validation target to a commit id (and a task, if given).
	var task *queue.Task
	commitID := ""
	switch targetType {
	case 0:
		if len(idBytes) != 20 {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
		}
		commitID = fmt.Sprintf("%x", idBytes)
	case 1:
		taskID := string(idBytes)
		if h.queue == nil {
			return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
		}
		t, err := h.queue.GetTask(taskID)
		if err != nil {
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
		}
		if t == nil {
			// Guard against a queue that reports "no task" without an error;
			// dereferencing it below would panic.
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", "")
		}
		task = t
		// Non-admin users may only inspect tasks they own or created.
		if authEnabled &&
			!user.Admin &&
			task.UserID != user.Name &&
			task.CreatedBy != user.Name {
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only validate your own jobs",
				"",
			)
		}
		if task.Metadata == nil || task.Metadata["commit_id"] == "" {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
		}
		commitID = task.Metadata["commit_id"]
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
	}

	r := validateReport{
		OK:       true,
		CommitID: commitID,
		TS:       time.Now().UTC().Format(time.RFC3339Nano),
		Checks:   map[string]validateCheck{},
	}
	if task != nil {
		r.TaskID = task.ID
	}

	// fail marks the whole report as failed and records the reasons.
	fail := func(msgs ...string) {
		r.OK = false
		r.Errors = append(r.Errors, msgs...)
	}

	// Validate commit id format
	if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok {
		fail(errMsg)
	}

	// Experiment manifest integrity (only when the commit id format was accepted).
	// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
	if r.OK {
		if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok {
			r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details}
			fail("experiment manifest validation failed")
		} else {
			r.Checks["experiment_manifest"] = validateCheck{OK: true}
		}
	}

	// Deps manifest presence + hash
	// TODO(context): Allow client to declare which dependency manifest is authoritative.
	// NOTE(review): unlike the experiment manifest check, this and the later
	// commit-based lookups still run when the commit id format was rejected
	// above — confirm the helpers tolerate arbitrary commit id strings.
	filesPath := h.expManager.GetFilesPath(commitID)
	depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID)
	r.Checks["deps_manifest"] = validateCheck(depCheck)
	if depErrs != nil {
		fail(depErrs...)
	}

	// Compare against expected task metadata if available.
	if task != nil {
		// Findings about the run manifest are errors for tasks that must have
		// one and only warnings otherwise.
		required := shouldRequireRunManifest(task)

		resCheck, resErrs := validateResourcesForTask(task)
		r.Checks["resources"] = resCheck
		if !resCheck.OK {
			fail(resErrs...)
		}

		// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
		if err := container.ValidateJobName(task.JobName); err != nil {
			fail("invalid job name")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
		} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
			const msg = "missing api base_path; cannot validate run manifest"
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			if required {
				fail(msg)
			} else {
				r.Warnings = append(r.Warnings, msg)
			}
		} else {
			manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
			if !found {
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				if required {
					fail("run manifest not found")
				} else {
					r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
				}
			} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
				fail("unable to read run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
			} else {
				r.Checks["run_manifest"] = validateCheck{OK: true}

				// The manifest must sit in the bucket matching the task status.
				if expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status); ok {
					chk := validateCheck{
						OK:       expectedBucket == manifestBucket,
						Expected: expectedBucket,
						Actual:   manifestBucket,
					}
					r.Checks["run_manifest_location"] = chk
					if !chk.OK {
						const msg = "run manifest location mismatch"
						if required {
							fail(msg)
						} else {
							r.Warnings = append(r.Warnings, msg)
						}
					}
				}

				// Validate task ID using helper
				taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID)
				r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck)
				if !taskIDCheck.OK {
					fail("run manifest task_id mismatch")
				}

				// Validate commit ID using helper
				commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"])
				r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck)
				if !commitCheck.OK {
					fail("run manifest commit_id mismatch")
				}

				// Validate deps provenance using helper
				depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
				depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
				depGotName := strings.TrimSpace(rm.DepsManifestName)
				depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
				depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA)
				r.Checks["run_manifest_deps"] = validateCheck(depsCheck)
				if !depsCheck.OK {
					fail("run manifest deps provenance mismatch")
				}

				// Validate snapshot using helpers
				if strings.TrimSpace(task.SnapshotID) != "" {
					snapWantID := strings.TrimSpace(task.SnapshotID)
					snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
					snapGotID := strings.TrimSpace(rm.SnapshotID)
					snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)

					snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID)
					r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck)
					if !snapIDCheck.OK {
						fail("run manifest snapshot_id mismatch")
					}

					snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA)
					r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck)
					if !snapSHACheck.OK {
						fail("run manifest snapshot_sha256 mismatch")
					}
				}

				// Validate lifecycle using helper
				lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status)
				if lifecycleOK {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
				} else {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: false, Details: details}
					if required {
						fail("run manifest lifecycle invalid")
					} else {
						r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
					}
				}
			}
		}

		// The experiment manifest hash recorded on the task must match the
		// manifest currently stored for the commit.
		want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
		cur := ""
		if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
			cur = man.OverallSHA
		}
		switch {
		case want == "":
			fail("missing expected experiment_manifest_overall_sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
		case cur == "":
			fail("unable to read current experiment manifest overall sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
		case want != cur:
			fail("experiment manifest overall sha mismatch")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
		default:
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
		}

		// The deps manifest recorded on the task must match the file on disk.
		// NOTE(review): when the task carries provenance but depName is empty,
		// no "expected_deps_manifest" check is recorded — presumably the
		// "deps_manifest" check above already failed; confirm.
		wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
		wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
		if wantDep == "" || wantDepSha == "" {
			fail("missing expected deps manifest provenance")
			r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
		} else if depName != "" {
			sha, shaErr := helpers.FileSHA256Hex(filepath.Join(filesPath, depName))
			chk := validateCheck{
				// A hash that could not be computed never counts as a match.
				OK:       shaErr == nil && wantDep == depName && wantDepSha == sha,
				Expected: wantDep + ":" + wantDepSha,
				Actual:   depName + ":" + sha,
			}
			if shaErr != nil {
				// Tell the client why the actual hash is missing.
				chk.Details = shaErr.Error()
			}
			if !chk.OK {
				fail("deps manifest provenance mismatch")
			}
			r.Checks["expected_deps_manifest"] = chk
		}

		// Snapshot/dataset checks require dataDir.
		// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
		// TODO(context): Validate snapshots by digest.
		// NOTE(review): SnapshotID and dataset names are joined into paths
		// below without validation here — confirm they are sanitized upstream.
		if task.SnapshotID != "" {
			if h.dataDir == "" {
				fail("api server data_dir not configured; cannot validate snapshot")
				r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
			} else {
				wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
				if nerr != nil || wantSnap == "" {
					fail("missing/invalid snapshot_sha256")
					r.Checks["snapshot"] = validateCheck{OK: false}
				} else {
					curSnap, err := worker.DirOverallSHA256Hex(
						filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
					)
					switch {
					case err != nil:
						fail("snapshot hash computation failed")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
					case curSnap != wantSnap:
						fail("snapshot checksum mismatch")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
					default:
						r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
					}
				}
			}
		}

		// TODO(context): Add dataset URI fetch/verification.
		// TODO(context): Currently only validates local materialized datasets.
		for _, ds := range task.DatasetSpecs {
			// Datasets without a declared checksum are not verified.
			if ds.Checksum == "" {
				continue
			}
			key := "dataset:" + ds.Name
			if h.dataDir == "" {
				fail("api server data_dir not configured; cannot validate dataset checksums")
				r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
				continue
			}
			wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
			if nerr != nil || wantDS == "" {
				fail("invalid dataset checksum format")
				r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
				continue
			}
			curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
			if err != nil {
				fail("dataset checksum computation failed")
				r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
				continue
			}
			if curDS != wantDS {
				fail("dataset checksum mismatch")
				r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
				continue
			}
			r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
		}
	}

	payloadBytes, err := json.Marshal(r)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeUnknownError,
			"failed to serialize validate report",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}