- Update datasets handlers with improved error handling - Refactor WebSocket handler for better organization - Clean up jobs.go handler implementation - Add websocket_metrics table to Postgres and SQLite schemas
225 lines
6.4 KiB
Go
225 lines
6.4 KiB
Go
// Package ws provides WebSocket handling for the API.
|
|
package ws
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
|
)
|
|
|
|
// handleQueueJob handles the QueueJob opcode (0x01)
|
|
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
|
if len(payload) < 39 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
commitIDBytes := payload[17:37]
|
|
commitIDHex := hex.EncodeToString(commitIDBytes)
|
|
priority := payload[37]
|
|
jobNameLen := int(payload[38])
|
|
|
|
if len(payload) < 39+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
|
}
|
|
jobName := string(payload[39 : 39+jobNameLen])
|
|
|
|
// Parse optional resource fields
|
|
cpu, memoryGB, gpu, gpuMemory := 0, 0, 0, ""
|
|
pos := 39 + jobNameLen
|
|
if len(payload) > pos {
|
|
cpu = int(payload[pos])
|
|
pos++
|
|
if len(payload) > pos {
|
|
memoryGB = int(payload[pos])
|
|
pos++
|
|
if len(payload) > pos {
|
|
gpu = int(payload[pos])
|
|
pos++
|
|
if len(payload) > pos {
|
|
gpuMemLen := int(payload[pos])
|
|
pos++
|
|
if len(payload) >= pos+gpuMemLen {
|
|
gpuMemory = string(payload[pos : pos+gpuMemLen])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
task := &queue.Task{
|
|
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
|
JobName: jobName,
|
|
Status: "queued",
|
|
Priority: int64(priority),
|
|
CreatedAt: time.Now(),
|
|
UserID: "user",
|
|
CreatedBy: "user",
|
|
CPU: cpu,
|
|
MemoryGB: memoryGB,
|
|
GPU: gpu,
|
|
GPUMemory: gpuMemory,
|
|
Metadata: map[string]string{"commit_id": commitIDHex},
|
|
}
|
|
|
|
// Auto-detect deps manifest and compute manifest SHA
|
|
if h.expManager != nil {
|
|
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
|
depsName, _ := selectDependencyManifest(filesPath)
|
|
if depsName != "" {
|
|
task.Metadata["deps_manifest_name"] = depsName
|
|
depsPath := filepath.Join(filesPath, depsName)
|
|
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
|
task.Metadata["deps_manifest_sha256"] = sha
|
|
}
|
|
}
|
|
|
|
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
|
if data, err := os.ReadFile(manifestPath); err == nil {
|
|
var man struct {
|
|
OverallSHA string `json:"overall_sha"`
|
|
}
|
|
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
|
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
|
}
|
|
}
|
|
}
|
|
|
|
if h.taskQueue != nil {
|
|
if err := h.taskQueue.AddTask(task); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"task_id": task.ID,
|
|
})
|
|
}
|
|
|
|
// handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17)
|
|
func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
|
|
if len(payload) < 41 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
commitIDBytes := payload[17:37]
|
|
commitIDHex := hex.EncodeToString(commitIDBytes)
|
|
priority := payload[37]
|
|
jobNameLen := int(payload[38])
|
|
|
|
if len(payload) < 39+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
|
}
|
|
jobName := string(payload[39 : 39+jobNameLen])
|
|
|
|
pos := 39 + jobNameLen
|
|
snapshotIDLen := int(payload[pos])
|
|
pos++
|
|
snapshotID := string(payload[pos : pos+snapshotIDLen])
|
|
pos += snapshotIDLen
|
|
snapshotSHALen := int(payload[pos])
|
|
pos++
|
|
snapshotSHA := string(payload[pos : pos+snapshotSHALen])
|
|
|
|
task := &queue.Task{
|
|
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
|
JobName: jobName,
|
|
Status: "queued",
|
|
Priority: int64(priority),
|
|
CreatedAt: time.Now(),
|
|
UserID: "user",
|
|
CreatedBy: "user",
|
|
SnapshotID: snapshotID,
|
|
Metadata: map[string]string{
|
|
"commit_id": commitIDHex,
|
|
"snapshot_sha256": snapshotSHA,
|
|
},
|
|
}
|
|
|
|
if h.expManager != nil {
|
|
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
|
depsName, _ := selectDependencyManifest(filesPath)
|
|
if depsName != "" {
|
|
task.Metadata["deps_manifest_name"] = depsName
|
|
depsPath := filepath.Join(filesPath, depsName)
|
|
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
|
task.Metadata["deps_manifest_sha256"] = sha
|
|
}
|
|
}
|
|
|
|
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
|
if data, err := os.ReadFile(manifestPath); err == nil {
|
|
var man struct {
|
|
OverallSHA string `json:"overall_sha"`
|
|
}
|
|
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
|
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
|
}
|
|
}
|
|
}
|
|
|
|
if h.taskQueue != nil {
|
|
if err := h.taskQueue.AddTask(task); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"task_id": task.ID,
|
|
})
|
|
}
|
|
|
|
// handleCancelJob handles the CancelJob opcode (0x03)
|
|
func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
jobNameLen := int(payload[17])
|
|
if len(payload) < 18+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
|
}
|
|
jobName := string(payload[18 : 18+jobNameLen])
|
|
|
|
if h.taskQueue != nil {
|
|
task, err := h.taskQueue.GetTaskByName(jobName)
|
|
if err == nil && task != nil {
|
|
task.Status = "cancelled"
|
|
h.taskQueue.UpdateTask(task)
|
|
}
|
|
}
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Job cancelled",
|
|
})
|
|
}
|
|
|
|
// handlePrune handles the Prune opcode (0x04)
|
|
func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload: [api_key_hash:16][prune_type:1][value:4]
|
|
if len(payload) < 16+1+4 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
|
|
}
|
|
|
|
// Authenticate user
|
|
// Skip 16-byte API key hash for now (authentication would use it)
|
|
// offset := 16
|
|
// pruneType := payload[offset]
|
|
// value := binary.BigEndian.Uint32(payload[offset+1 : offset+5])
|
|
|
|
return h.sendSuccessPacket(conn, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Prune completed",
|
|
"pruned": 0,
|
|
})
|
|
}
|