Update API layer for scheduler integration: - WebSocket handlers with scheduler protocol support - Jobs WebSocket endpoint with priority queue integration - Validation middleware for scheduler messages - Server configuration with security hardening - Protocol definitions for worker-scheduler communication - Dataset handlers with tenant isolation checks - Response helpers with audit context - OpenAPI spec updates for new endpoints
200 lines
5.6 KiB
Go
200 lines
5.6 KiB
Go
// Package ws provides WebSocket handling for the API.
|
|
package ws
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"path/filepath"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
|
)
|
|
|
|
// String values shared by the validation handler: task statuses and the names
// of the experiment sub-directories in which run manifests are looked up.
const (
	// running is compared against task.Status and is also the directory name
	// used for tasks that are not in one of the statuses listed below.
	running = "running"
	// finished is the directory name used for completed, failed and cancelled
	// tasks.
	finished = "finished"

	// Task statuses whose run manifest is expected under the finished directory.
	completed = "completed"
	failed    = "failed"
	cancelled = "cancelled"
)
|
|
|
|
// handleValidateRequest handles the ValidateRequest opcode (0x16)
|
|
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
|
|
// mode=0: commit_id validation [commit_id_len:1][commit_id:var]
|
|
// mode=1: task_id validation [task_id_len:1][task_id:var]
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
|
}
|
|
|
|
mode := payload[17]
|
|
|
|
if mode == 0 {
|
|
// Commit ID validation (basic)
|
|
if len(payload) < 20 {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "",
|
|
)
|
|
}
|
|
commitIDLen := int(payload[18])
|
|
if len(payload) < 19+commitIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "")
|
|
}
|
|
commitIDBytes := payload[19 : 19+commitIDLen]
|
|
commitIDHex := fmt.Sprintf("%x", commitIDBytes)
|
|
|
|
report := map[string]any{
|
|
"ok": true,
|
|
"commit_id": commitIDHex,
|
|
}
|
|
payloadBytes, _ := json.Marshal(report)
|
|
return h.sendDataPacket(conn, "validate", payloadBytes)
|
|
}
|
|
|
|
// Task ID validation (mode=1) - full validation with checks
|
|
if len(payload) < 20 {
|
|
return h.sendErrorPacket(
|
|
conn, ErrorCodeInvalidRequest, "payload too short for task validation", "",
|
|
)
|
|
}
|
|
|
|
taskIDLen := int(payload[18])
|
|
if len(payload) < 19+taskIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "")
|
|
}
|
|
taskID := string(payload[19 : 19+taskIDLen])
|
|
|
|
// Initialize validation report
|
|
checks := make(map[string]any)
|
|
ok := true
|
|
|
|
// Get task from queue
|
|
if h.taskQueue == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "")
|
|
}
|
|
|
|
task, err := h.taskQueue.GetTask(taskID)
|
|
if err != nil || task == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "")
|
|
}
|
|
|
|
// Run manifest validation - load manifest if it exists
|
|
rmCheck := map[string]any{"ok": true}
|
|
rmCommitCheck := map[string]any{"ok": true}
|
|
rmLocCheck := map[string]any{"ok": true}
|
|
rmLifecycle := map[string]any{"ok": true}
|
|
var narrativeWarnings, outcomeWarnings []string
|
|
|
|
// Determine expected location based on task status
|
|
expectedLocation := running
|
|
if task.Status == completed || task.Status == cancelled || task.Status == failed {
|
|
expectedLocation = finished
|
|
}
|
|
|
|
// Try to load run manifest from appropriate location
|
|
var rm *manifest.RunManifest
|
|
var rmLoadErr error
|
|
|
|
if h.expManager != nil {
|
|
// Try expected location first
|
|
jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName)
|
|
rm, rmLoadErr = manifest.LoadFromDir(jobDir)
|
|
|
|
// If not found and task is running, also check finished (wrong location test)
|
|
if rmLoadErr != nil && task.Status == running {
|
|
wrongDir := filepath.Join(h.expManager.BasePath(), finished, task.JobName)
|
|
rm, _ = manifest.LoadFromDir(wrongDir)
|
|
if rm != nil {
|
|
// Manifest exists but in wrong location
|
|
rmLocCheck["ok"] = false
|
|
rmLocCheck["expected"] = running
|
|
rmLocCheck["actual"] = finished
|
|
ok = false
|
|
}
|
|
}
|
|
}
|
|
|
|
if rm == nil {
|
|
// No run manifest found
|
|
if task.Status == running || task.Status == completed {
|
|
rmCheck["ok"] = false
|
|
ok = false
|
|
}
|
|
} else {
|
|
// Run manifest exists - validate it
|
|
|
|
// Check commit_id match
|
|
taskCommitID := task.Metadata["commit_id"]
|
|
if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID {
|
|
rmCommitCheck["ok"] = false
|
|
rmCommitCheck["expected"] = taskCommitID
|
|
ok = false
|
|
}
|
|
|
|
// Check lifecycle ordering (started_at < ended_at)
|
|
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) {
|
|
rmLifecycle["ok"] = false
|
|
ok = false
|
|
}
|
|
|
|
// Validate narrative if present
|
|
if rm.Narrative != nil {
|
|
nv := manifest.ValidateNarrative(rm.Narrative)
|
|
if len(nv.Errors) > 0 {
|
|
ok = false
|
|
}
|
|
narrativeWarnings = nv.Warnings
|
|
}
|
|
|
|
// Validate outcome if present
|
|
if rm.Outcome != nil {
|
|
ov := manifest.ValidateOutcome(rm.Outcome)
|
|
if len(ov.Errors) > 0 {
|
|
ok = false
|
|
}
|
|
outcomeWarnings = ov.Warnings
|
|
}
|
|
}
|
|
|
|
checks["run_manifest"] = rmCheck
|
|
checks["run_manifest_commit_id"] = rmCommitCheck
|
|
checks["run_manifest_location"] = rmLocCheck
|
|
checks["run_manifest_lifecycle"] = rmLifecycle
|
|
|
|
// Resources check
|
|
resCheck := map[string]any{"ok": true}
|
|
if task.CPU < 0 {
|
|
resCheck["ok"] = false
|
|
ok = false
|
|
}
|
|
checks["resources"] = resCheck
|
|
|
|
// Snapshot check
|
|
snapCheck := map[string]any{"ok": true}
|
|
if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
|
|
// Verify snapshot SHA
|
|
dataDir := h.dataDir
|
|
if dataDir == "" {
|
|
dataDir = filepath.Join(h.expManager.BasePath(), "data")
|
|
}
|
|
snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
|
actualSHA, _ := integrity.DirOverallSHA256Hex(snapPath)
|
|
expectedSHA := task.Metadata["snapshot_sha256"]
|
|
if actualSHA != expectedSHA {
|
|
snapCheck["ok"] = false
|
|
snapCheck["actual"] = actualSHA
|
|
ok = false
|
|
}
|
|
}
|
|
checks["snapshot"] = snapCheck
|
|
|
|
report := map[string]any{
|
|
"ok": ok,
|
|
"checks": checks,
|
|
"narrative_warnings": narrativeWarnings,
|
|
"outcome_warnings": outcomeWarnings,
|
|
}
|
|
payloadBytes, _ := json.Marshal(report)
|
|
return h.sendDataPacket(conn, "validate", payloadBytes)
|
|
}
|