fetch_ml/internal/api/ws/validate.go
Jeremie Fraeys 420de879ff
feat(api): integrate scheduler protocol and WebSocket enhancements
Update API layer for scheduler integration:
- WebSocket handlers with scheduler protocol support
- Jobs WebSocket endpoint with priority queue integration
- Validation middleware for scheduler messages
- Server configuration with security hardening
- Protocol definitions for worker-scheduler communication
- Dataset handlers with tenant isolation checks
- Response helpers with audit context
- OpenAPI spec updates for new endpoints
2026-02-26 12:05:57 -05:00

200 lines
5.6 KiB
Go

// Package ws provides WebSocket handling for the API.
package ws
import (
"encoding/json"
"fmt"
"path/filepath"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
// Names used by the validate handler when comparing task state and locating
// run manifests on disk.
const (
	// running is compared against task.Status and is also the run-directory
	// name where the manifest of a non-terminal task is expected.
	running = "running"

	// Terminal task statuses: a task in any of these states is expected to
	// have its run manifest under the finished directory.
	completed = "completed"
	failed    = "failed"
	cancelled = "cancelled"

	// finished is the run-directory name used for tasks in a terminal status.
	finished = "finished"
)
// handleValidateRequest handles the ValidateRequest opcode (0x16).
//
// Payload layout: [opcode:1][api_key_hash:16][mode:1][id_len:1][id:var]
//   - mode=0: the ID is a commit ID. It is echoed back hex-encoded in a report
//     with "ok": true; no further checks are performed.
//   - mode=1: the ID is a task ID. The task is looked up in the task queue and
//     a report of run-manifest, resource and snapshot checks is returned.
//     Any non-zero mode currently takes this path.
//
// Malformed payloads and unknown tasks are answered with an error packet.
// The returned error is whatever the packet send returns, or a wrapped
// encoding error if the report cannot be marshalled.
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// The mode byte sits right after the opcode and the 16-byte API key hash.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
	}
	mode := payload[17]

	if mode == 0 {
		// Commit ID validation (basic): needs the length byte plus one more byte.
		if len(payload) < 20 {
			return h.sendErrorPacket(
				conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "",
			)
		}
		commitIDLen := int(payload[18])
		if len(payload) < 19+commitIDLen {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "")
		}
		commitIDBytes := payload[19 : 19+commitIDLen]
		commitIDHex := fmt.Sprintf("%x", commitIDBytes)
		report := map[string]any{
			"ok":        true,
			"commit_id": commitIDHex,
		}
		payloadBytes, err := json.Marshal(report)
		if err != nil {
			return fmt.Errorf("encoding commit validation report: %w", err)
		}
		return h.sendDataPacket(conn, "validate", payloadBytes)
	}

	// Task ID validation (mode=1) - full validation with checks.
	if len(payload) < 20 {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "payload too short for task validation", "",
		)
	}
	taskIDLen := int(payload[18])
	if len(payload) < 19+taskIDLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "")
	}
	taskID := string(payload[19 : 19+taskIDLen])

	// Get task from queue.
	if h.taskQueue == nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "")
	}
	task, err := h.taskQueue.GetTask(taskID)
	if err != nil || task == nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "")
	}

	// Validation report state: every check starts out passing and is flipped
	// to failed (together with the overall ok flag) when a problem is found.
	checks := make(map[string]any)
	ok := true
	rmCheck := map[string]any{"ok": true}
	rmCommitCheck := map[string]any{"ok": true}
	rmLocCheck := map[string]any{"ok": true}
	rmLifecycle := map[string]any{"ok": true}
	var narrativeWarnings, outcomeWarnings []string

	// Determine expected manifest location based on task status.
	expectedLocation := running
	if task.Status == completed || task.Status == cancelled || task.Status == failed {
		expectedLocation = finished
	}

	// Try to load the run manifest from the appropriate location.
	var rm *manifest.RunManifest
	var rmLoadErr error
	if h.expManager != nil {
		// Try expected location first.
		jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName)
		rm, rmLoadErr = manifest.LoadFromDir(jobDir)
		// If not found and the task is running, also look in finished: a
		// manifest there means the run directory is in the wrong location.
		if rmLoadErr != nil && task.Status == running {
			wrongDir := filepath.Join(h.expManager.BasePath(), finished, task.JobName)
			// The load error is intentionally dropped: only the presence of a
			// manifest in the wrong location matters for this probe.
			rm, _ = manifest.LoadFromDir(wrongDir)
			if rm != nil {
				rmLocCheck["ok"] = false
				rmLocCheck["expected"] = running
				rmLocCheck["actual"] = finished
				ok = false
			}
		}
	}

	if rm == nil {
		// No run manifest found: only a failure for running or completed tasks.
		if task.Status == running || task.Status == completed {
			rmCheck["ok"] = false
			ok = false
		}
	} else {
		// Check commit_id match; skipped when either side is empty.
		taskCommitID := task.Metadata["commit_id"]
		if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID {
			rmCommitCheck["ok"] = false
			rmCommitCheck["expected"] = taskCommitID
			ok = false
		}
		// Check lifecycle ordering (started_at < ended_at) when both are set.
		if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) {
			rmLifecycle["ok"] = false
			ok = false
		}
		// Validate narrative if present; errors fail the report, warnings are
		// passed through to the client.
		if rm.Narrative != nil {
			nv := manifest.ValidateNarrative(rm.Narrative)
			if len(nv.Errors) > 0 {
				ok = false
			}
			narrativeWarnings = nv.Warnings
		}
		// Validate outcome if present, same policy as the narrative.
		if rm.Outcome != nil {
			ov := manifest.ValidateOutcome(rm.Outcome)
			if len(ov.Errors) > 0 {
				ok = false
			}
			outcomeWarnings = ov.Warnings
		}
	}
	checks["run_manifest"] = rmCheck
	checks["run_manifest_commit_id"] = rmCommitCheck
	checks["run_manifest_location"] = rmLocCheck
	checks["run_manifest_lifecycle"] = rmLifecycle

	// Resources check: a negative CPU request is invalid.
	resCheck := map[string]any{"ok": true}
	if task.CPU < 0 {
		resCheck["ok"] = false
		ok = false
	}
	checks["resources"] = resCheck

	// Snapshot check: only performed when the task names a snapshot and
	// carries an expected digest in its metadata.
	snapCheck := map[string]any{"ok": true}
	if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
		dataDir := h.dataDir
		// Fall back to <base>/data, but only when an experiment manager is
		// configured; it is treated as optional elsewhere in this handler.
		if dataDir == "" && h.expManager != nil {
			dataDir = filepath.Join(h.expManager.BasePath(), "data")
		}
		if dataDir == "" {
			snapCheck["ok"] = false
			snapCheck["error"] = "data directory not configured"
			ok = false
		} else {
			snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
			actualSHA, shaErr := integrity.DirOverallSHA256Hex(snapPath)
			expectedSHA := task.Metadata["snapshot_sha256"]
			if shaErr != nil || actualSHA != expectedSHA {
				snapCheck["ok"] = false
				snapCheck["actual"] = actualSHA
				if shaErr != nil {
					// Generic message so server filesystem paths are not
					// exposed to the client.
					snapCheck["error"] = "snapshot could not be hashed"
				}
				ok = false
			}
		}
	}
	checks["snapshot"] = snapCheck

	report := map[string]any{
		"ok":                 ok,
		"checks":             checks,
		"narrative_warnings": narrativeWarnings,
		"outcome_warnings":   outcomeWarnings,
	}
	payloadBytes, err := json.Marshal(report)
	if err != nil {
		return fmt.Errorf("encoding task validation report: %w", err)
	}
	return h.sendDataPacket(conn, "validate", payloadBytes)
}