fetch_ml/internal/api/ws/jobs.go
Jeremie Fraeys 188cf55939
refactor(api): overhaul WebSocket handler and protocol layer
Major WebSocket handler refactor:
- Rewrite ws/handler.go with structured message routing and backpressure
- Add connection lifecycle management with heartbeats and timeouts
- Implement graceful connection draining for zero-downtime restarts

Protocol improvements:
- Define structured protocol types in protocol.go for hub communication
- Add versioned message envelopes for backward compatibility
- Standardize error codes and response formats across WebSocket API

Job streaming via WebSocket:
- Simplify ws/jobs.go with async job status streaming
- Add compression for high-volume job updates

Testing:
- Update websocket_e2e_test.go for new protocol semantics
- Add connection resilience tests
2026-03-12 12:01:21 -04:00

235 lines
6.6 KiB
Go

// Package ws provides WebSocket handling for the API
package ws
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
// populateExperimentIntegrityMetadata attaches integrity metadata for the
// experiment identified by commitIDHex to task.Metadata and returns the name
// of the dependency manifest selected for it ("" if none was selected).
//
// When no experiment manager is configured it does nothing and returns
// ("", nil). Otherwise it may set:
//   - "deps_manifest_name": the manifest chosen by selectDependencyManifest;
//   - "deps_manifest_sha256": that file's SHA-256 (skipped if hashing fails);
//   - "experiment_manifest_overall_sha": the "overall_sha" field of
//     <base>/<commit>/manifest.json (skipped if the file is missing,
//     unreadable, malformed, or the field is empty).
//
// It returns an error if commitIDHex is not exactly 40 hex characters, if
// selectDependencyManifest fails, or if the manifest path would resolve
// outside the experiment base directory.
func (h *Handler) populateExperimentIntegrityMetadata(
	task *queue.Task,
	commitIDHex string,
) (string, error) {
	if h.expManager == nil {
		return "", nil
	}

	// Validate the commit ID (defense-in-depth): it is used as a path
	// component below, so it must be exactly 40 hex characters.
	if len(commitIDHex) != 40 {
		return "", fmt.Errorf("invalid commit id length")
	}
	if _, err := hex.DecodeString(commitIDHex); err != nil {
		return "", fmt.Errorf("invalid commit id format")
	}

	// Callers normally pre-populate Metadata, but writing to a nil map
	// panics, so create one for callers that do not.
	if task.Metadata == nil {
		task.Metadata = make(map[string]string)
	}

	filesPath := h.expManager.GetFilesPath(commitIDHex)
	depsName, err := selectDependencyManifest(filesPath)
	if err != nil {
		return "", err
	}
	if depsName != "" {
		task.Metadata["deps_manifest_name"] = depsName
		// Best-effort: an unhashable manifest is tolerated rather than
		// failing the whole request.
		depsPath := filepath.Join(filesPath, depsName)
		if sha, hashErr := integrity.FileSHA256Hex(depsPath); hashErr == nil {
			task.Metadata["deps_manifest_sha256"] = sha
		}
	}

	basePath := filepath.Clean(h.expManager.BasePath())
	// filepath.Join already returns a cleaned path.
	manifestPath := filepath.Join(basePath, commitIDHex, "manifest.json")
	// Check containment with filepath.Rel rather than a string-prefix test.
	// The prefix test wrongly rejects base paths such as "." (Join drops the
	// leading "./") and "/" (the expected prefix would become "//").
	rel, err := filepath.Rel(basePath, manifestPath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
		return "", fmt.Errorf("path traversal detected")
	}

	// The experiment manifest is optional; read and parse failures are
	// deliberately ignored.
	if data, readErr := os.ReadFile(manifestPath); readErr == nil {
		var man struct {
			OverallSHA string `json:"overall_sha"`
		}
		if json.Unmarshal(data, &man) == nil && man.OverallSHA != "" {
			task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
		}
	}
	return depsName, nil
}
// handleQueueJob handles the QueueJob opcode (0x01).
//
// Payload layout:
//
//	[opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//
// optionally followed by resource requests, each read only when present:
//
//	[cpu:1][memory_gb:1][gpu:1][gpu_memory_len:1][gpu_memory:var]
//
// A gpu_memory string whose declared length runs past the payload is ignored.
// The commit is resolved into integrity metadata, the task is enqueued when a
// task queue is configured, and the new task ID is sent back on success.
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	const (
		commitStart = 17 // after opcode + API key hash
		commitEnd   = 37
		priorityAt  = 37
		nameLenAt   = 38
		headerSize  = 39
	)
	if len(payload) < headerSize {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	nameEnd := headerSize + int(payload[nameLenAt])
	if len(payload) < nameEnd {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
	}
	commitIDHex := hex.EncodeToString(payload[commitStart:commitEnd])

	// Trailing optional resource fields; anything absent keeps its zero value.
	var cpu, memoryGB, gpu int
	var gpuMemory string
	trailer := payload[nameEnd:]
	if len(trailer) > 0 {
		cpu = int(trailer[0])
	}
	if len(trailer) > 1 {
		memoryGB = int(trailer[1])
	}
	if len(trailer) > 2 {
		gpu = int(trailer[2])
	}
	if len(trailer) > 3 {
		if memLen := int(trailer[3]); len(trailer) >= 4+memLen {
			gpuMemory = string(trailer[4 : 4+memLen])
		}
	}

	task := &queue.Task{
		ID:        fmt.Sprintf("task-%d", time.Now().UnixNano()),
		JobName:   string(payload[headerSize:nameEnd]),
		Status:    "queued",
		Priority:  int64(payload[priorityAt]),
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		CPU:       cpu,
		MemoryGB:  memoryGB,
		GPU:       gpu,
		GPUMemory: gpuMemory,
		Metadata:  map[string]string{"commit_id": commitIDHex},
	}

	if _, metaErr := h.populateExperimentIntegrityMetadata(task, commitIDHex); metaErr != nil {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", metaErr.Error(),
		)
	}

	if q := h.taskQueue; q != nil {
		if addErr := q.AddTask(task); addErr != nil {
			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", addErr.Error())
		}
	}
	return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID})
}
// handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17).
//
// Payload layout:
//
//	[0:17]   opcode + API key hash (same prefix as QueueJob; not inspected here)
//	[17:37]  commit ID, 20 raw bytes (stored hex-encoded)
//	[37]     priority
//	[38]     job_name_len, followed by job_name
//	         snapshot_id_len (1 byte), followed by snapshot_id
//	         snapshot_sha_len (1 byte), followed by snapshot_sha
//
// Every length-prefixed field is bounds-checked before it is read. A truncated
// or inconsistent payload therefore yields an ErrorCodeInvalidRequest packet
// instead of an index-out-of-range panic. On success the task is enqueued
// (when a task queue is configured) and its ID is returned to the client.
func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
	// 41 bytes = 39-byte fixed header + the two snapshot length bytes, i.e.
	// the minimum with an empty job name, snapshot ID and SHA.
	if len(payload) < 41 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	commitIDHex := hex.EncodeToString(payload[17:37])
	priority := payload[37]
	jobNameLen := int(payload[38])
	if len(payload) < 39+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
	}
	jobName := string(payload[39 : 39+jobNameLen])
	pos := 39 + jobNameLen

	// The remaining fields are client-controlled length prefixes; each read
	// must be checked against the actual payload size.
	if pos >= len(payload) {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "")
	}
	snapshotIDLen := int(payload[pos])
	pos++
	if len(payload) < pos+snapshotIDLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "")
	}
	snapshotID := string(payload[pos : pos+snapshotIDLen])
	pos += snapshotIDLen

	if pos >= len(payload) {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "")
	}
	snapshotSHALen := int(payload[pos])
	pos++
	if len(payload) < pos+snapshotSHALen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "")
	}
	snapshotSHA := string(payload[pos : pos+snapshotSHALen])

	task := &queue.Task{
		ID:         fmt.Sprintf("task-%d", time.Now().UnixNano()),
		JobName:    jobName,
		Status:     "queued",
		Priority:   int64(priority),
		CreatedAt:  time.Now(),
		UserID:     "user",
		CreatedBy:  "user",
		SnapshotID: snapshotID,
		Metadata: map[string]string{
			"commit_id":       commitIDHex,
			"snapshot_sha256": snapshotSHA,
		},
	}
	if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(),
		)
	}
	if h.taskQueue != nil {
		if err := h.taskQueue.AddTask(task); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
		}
	}
	return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID})
}
// handleCancelJob handles the CancelJob opcode (0x03).
//
// Byte 17 of the payload holds job_name_len and the job name follows it. The
// first 17 bytes are presumably opcode + API key hash, as in QueueJob; they
// are not inspected here. When a task queue is configured and a task with that
// name is found, its status is set to "cancelled" and persisted via UpdateTask.
//
// NOTE(review): a failed lookup or a missing task is not reported; the client
// still receives "Job cancelled". Confirm this idempotent behaviour is intended.
func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	const nameLenAt = 17
	if len(payload) <= nameLenAt {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	nameStart := nameLenAt + 1
	nameEnd := nameStart + int(payload[nameLenAt])
	if len(payload) < nameEnd {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
	}
	jobName := string(payload[nameStart:nameEnd])

	if q := h.taskQueue; q != nil {
		if found, lookupErr := q.GetTaskByName(jobName); lookupErr == nil && found != nil {
			found.Status = "cancelled"
			if updateErr := q.UpdateTask(found); updateErr != nil {
				return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to cancel task", updateErr.Error())
			}
		}
	}
	return h.sendSuccessPacket(conn, map[string]any{"message": "Job cancelled"})
}
// handlePrune handles the Prune opcode (0x04).
//
// This is currently a stub. After checking that the payload is at least 21
// bytes long, it performs no authentication and no pruning. It always replies
// with a success packet reporting zero pruned items.
//
// Intended payload layout: [api_key_hash:16][prune_type:1][value:4].
// NOTE(review): the other handlers in this file appear to place a 1-byte
// opcode before the 16-byte API key hash. Confirm whether this layout, and
// the 21-byte minimum, should account for it before the stub is implemented.
func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error {
	const (
		apiKeyHashLen = 16
		pruneTypeLen  = 1
		pruneValueLen = 4
	)
	if minLen := apiKeyHashLen + pruneTypeLen + pruneValueLen; len(payload) < minLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}
	// TODO: authenticate with the API key hash, then decode prune_type and the
	// value (the earlier draft read it as binary.BigEndian.Uint32) and prune.
	return h.sendSuccessPacket(conn, map[string]any{"pruned": 0, "message": "Prune completed"})
}