fetch_ml/internal/api/ws/jobs.go
Jeremie Fraeys 10e6416e11
refactor: update WebSocket handlers and database schemas
- Update datasets handlers with improved error handling
- Refactor WebSocket handler for better organization
- Clean up jobs.go handler implementation
- Add websocket_metrics table to Postgres and SQLite schemas
2026-02-18 14:36:30 -05:00

225 lines
6.4 KiB
Go

// Package ws provides WebSocket handling for the API
package ws
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
// handleQueueJob handles the QueueJob opcode (0x01): it decodes a job
// submission frame, enqueues a new task, and replies with the task ID.
//
// Payload layout (byte offsets):
//
//	[0]        opcode
//	[1:17]     api_key_hash (not read by this handler)
//	[17:37]    commit_id (20 raw bytes; stored hex-encoded)
//	[37]       priority
//	[38]       job_name length N
//	[39:39+N]  job_name
//
// After the job name come optional resource hints, each read only if bytes
// remain: cpu (1 byte), memory_gb (1), gpu (1), gpu_mem length M (1),
// gpu_mem (M). A gpu_mem string cut short by the end of the frame is ignored
// rather than rejected.
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	const fixedHeaderLen = 39
	if len(payload) < fixedHeaderLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}

	commitIDHex := hex.EncodeToString(payload[17:37])
	priority := payload[37]
	nameEnd := fixedHeaderLen + int(payload[38])
	if len(payload) < nameEnd {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
	}
	jobName := string(payload[fixedHeaderLen:nameEnd])

	// Everything after the job name is optional; absent fields keep their
	// zero values.
	var (
		cpu, memoryGB, gpu int
		gpuMemory          string
	)
	extra := payload[nameEnd:]
	if len(extra) > 0 {
		cpu = int(extra[0])
	}
	if len(extra) > 1 {
		memoryGB = int(extra[1])
	}
	if len(extra) > 2 {
		gpu = int(extra[2])
	}
	if len(extra) > 3 {
		if end := 4 + int(extra[3]); len(extra) >= end {
			gpuMemory = string(extra[4:end])
		}
	}

	task := &queue.Task{
		ID:        fmt.Sprintf("task-%d", time.Now().UnixNano()),
		JobName:   jobName,
		Status:    "queued",
		Priority:  int64(priority),
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		CPU:       cpu,
		MemoryGB:  memoryGB,
		GPU:       gpu,
		GPUMemory: gpuMemory,
		Metadata:  map[string]string{"commit_id": commitIDHex},
	}

	// Best-effort provenance: a missing or unreadable file simply leaves the
	// corresponding metadata key unset.
	if em := h.expManager; em != nil {
		filesPath := em.GetFilesPath(commitIDHex)
		// The second return value is deliberately ignored; an empty name means
		// no dependency manifest was selected.
		if depsName, _ := selectDependencyManifest(filesPath); depsName != "" {
			task.Metadata["deps_manifest_name"] = depsName
			if sha, err := integrity.FileSHA256Hex(filepath.Join(filesPath, depsName)); err == nil {
				task.Metadata["deps_manifest_sha256"] = sha
			}
		}

		manifestPath := filepath.Join(em.BasePath(), commitIDHex, "manifest.json")
		if raw, err := os.ReadFile(manifestPath); err == nil {
			var manifest struct {
				OverallSHA string `json:"overall_sha"`
			}
			if json.Unmarshal(raw, &manifest) == nil && manifest.OverallSHA != "" {
				task.Metadata["experiment_manifest_overall_sha"] = manifest.OverallSHA
			}
		}
	}

	if q := h.taskQueue; q != nil {
		if err := q.AddTask(task); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
		}
	}

	reply := map[string]interface{}{
		"success": true,
		"task_id": task.ID,
	}
	return h.sendSuccessPacket(conn, reply)
}
// handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17):
// like QueueJob, but the task is pinned to a dataset/code snapshot.
//
// Payload layout (byte offsets):
//
//	[0:17]     header (presumably opcode + api_key_hash, as in QueueJob; not read here)
//	[17:37]    commit_id (20 raw bytes; stored hex-encoded)
//	[37]       priority
//	[38]       job_name length N
//	[39:39+N]  job_name
//	then       snapshot_id length (1 byte) + snapshot_id
//	then       snapshot_sha length (1 byte) + snapshot_sha
//
// Every length-prefixed field is bounds-checked before slicing. A malformed
// frame gets an InvalidRequest error packet instead of triggering an
// index-out-of-range panic.
func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 41 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	commitIDBytes := payload[17:37]
	commitIDHex := hex.EncodeToString(commitIDBytes)
	priority := payload[37]
	jobNameLen := int(payload[38])
	if len(payload) < 39+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
	}
	jobName := string(payload[39 : 39+jobNameLen])

	pos := 39 + jobNameLen
	// readLenPrefixed consumes a one-byte length followed by that many bytes
	// starting at pos. It reports false (leaving pos unchanged) when the frame
	// is too short to contain the field.
	readLenPrefixed := func() (string, bool) {
		if pos >= len(payload) {
			return "", false
		}
		start := pos + 1
		end := start + int(payload[pos])
		if end > len(payload) {
			return "", false
		}
		pos = end
		return string(payload[start:end]), true
	}

	snapshotID, ok := readLenPrefixed()
	if !ok {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "")
	}
	snapshotSHA, ok := readLenPrefixed()
	if !ok {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "")
	}

	task := &queue.Task{
		ID:         fmt.Sprintf("task-%d", time.Now().UnixNano()),
		JobName:    jobName,
		Status:     "queued",
		Priority:   int64(priority),
		CreatedAt:  time.Now(),
		UserID:     "user",
		CreatedBy:  "user",
		SnapshotID: snapshotID,
		Metadata: map[string]string{
			"commit_id":       commitIDHex,
			"snapshot_sha256": snapshotSHA,
		},
	}

	// Best-effort provenance: a missing or unreadable file simply leaves the
	// corresponding metadata key unset.
	if h.expManager != nil {
		filesPath := h.expManager.GetFilesPath(commitIDHex)
		// The second return value is deliberately ignored; an empty name means
		// no dependency manifest was selected.
		depsName, _ := selectDependencyManifest(filesPath)
		if depsName != "" {
			task.Metadata["deps_manifest_name"] = depsName
			depsPath := filepath.Join(filesPath, depsName)
			if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
				task.Metadata["deps_manifest_sha256"] = sha
			}
		}
		manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
		if data, err := os.ReadFile(manifestPath); err == nil {
			var man struct {
				OverallSHA string `json:"overall_sha"`
			}
			if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
				task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
			}
		}
	}

	if h.taskQueue != nil {
		if err := h.taskQueue.AddTask(task); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
		}
	}
	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success": true,
		"task_id": task.ID,
	})
}
// handleCancelJob handles the CancelJob opcode (0x03).
//
// Payload layout (byte offsets):
//
//	[0:17]     header (presumably opcode + api_key_hash, as in QueueJob; not read here)
//	[17]       job_name length N
//	[18:18+N]  job_name
//
// If the queue finds a task by that name, its status is set to "cancelled".
// A lookup miss is still answered with success, as before. A failure to
// persist the status change is now reported to the client instead of being
// silently dropped.
func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	jobNameLen := int(payload[17])
	if len(payload) < 18+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
	}
	jobName := string(payload[18 : 18+jobNameLen])

	if h.taskQueue != nil {
		// NOTE(review): any GetTaskByName error, not just "not found", is
		// treated as a miss here. Distinguish them if the queue exposes a
		// not-found sentinel.
		task, err := h.taskQueue.GetTaskByName(jobName)
		if err == nil && task != nil {
			task.Status = "cancelled"
			if err := h.taskQueue.UpdateTask(task); err != nil {
				return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to cancel task", err.Error())
			}
		}
	}
	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success": true,
		"message": "Job cancelled",
	})
}
// handlePrune handles the Prune opcode (0x04).
//
// Expected payload: [api_key_hash:16][prune_type:1][value:4].
//
// This is currently a stub. It only checks that the payload is long enough.
// The API key hash, prune type and value are not decoded or acted on, no
// authentication is performed, and the reply always reports zero items pruned.
func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error {
	const (
		apiKeyHashLen = 16
		pruneTypeLen  = 1
		pruneValueLen = 4
		minPayloadLen = apiKeyHashLen + pruneTypeLen + pruneValueLen
	)
	if len(payload) < minPayloadLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}

	// TODO: authenticate with the API key hash, then decode prune_type
	// (payload[16]) and the 4-byte value (big-endian, per an earlier draft)
	// and actually prune.
	reply := map[string]interface{}{
		"success": true,
		"message": "Prune completed",
		"pruned":  0,
	}
	return h.sendSuccessPacket(conn, reply)
}