fetch_ml/internal/api/ws_validate.go
Jeremie Fraeys b05470b30a
refactor: improve API structure and WebSocket protocol
- Extract WebSocket protocol handling to dedicated module
- Add helper functions for DB operations, validation, and responses
- Improve WebSocket frame handling and opcodes
- Refactor dataset, job, and Jupyter handlers
- Add duplicate detection processing
2026-02-16 20:38:12 -05:00

523 lines
17 KiB
Go

package api
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// validateCheck records the outcome of a single validation probe.
// Expected and Actual carry the two compared values when a mismatch is
// being reported; Details holds a human-readable explanation. Optional
// fields are omitted from the JSON encoding when empty.
type validateCheck struct {
	OK       bool   `json:"ok"`
	Expected string `json:"expected,omitempty"`
	Actual   string `json:"actual,omitempty"`
	Details  string `json:"details,omitempty"`
}
// validateReport is the aggregate validation result sent back to the
// WebSocket client. OK is the overall verdict (false as soon as any
// check fails); Checks maps a check name to its individual result,
// while Errors and Warnings collect human-readable messages. TS is the
// report creation time, formatted as RFC3339Nano in UTC.
type validateReport struct {
	OK       bool                      `json:"ok"`
	CommitID string                    `json:"commit_id,omitempty"`
	TaskID   string                    `json:"task_id,omitempty"`
	Checks   map[string]validateCheck  `json:"checks"`
	Errors   []string                  `json:"errors,omitempty"`
	Warnings []string                  `json:"warnings,omitempty"`
	TS       string                    `json:"ts"`
}
// shouldRequireRunManifest reports whether a run manifest must exist for
// the given task. A manifest is mandatory once the task has entered (or
// passed through) execution — status "running", "completed", or "failed";
// for any other status (or a nil task) its absence is only a warning.
// Status matching is case-insensitive and ignores surrounding whitespace.
func shouldRequireRunManifest(task *queue.Task) bool {
	if task == nil {
		return false
	}
	status := strings.ToLower(strings.TrimSpace(task.Status))
	return status == "running" || status == "completed" || status == "failed"
}
// expectedRunManifestBucketForStatus maps a task status to the job
// directory bucket its run manifest is expected to live in
// ("pending", "running", "finished", or "failed"). The second return
// value is false when the status is not recognized. Matching is
// case-insensitive and ignores surrounding whitespace.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(status))
	var bucket string
	switch normalized {
	case "queued", "pending":
		bucket = "pending"
	case "running":
		bucket = "running"
	case "completed", "finished":
		bucket = "finished"
	case "failed":
		bucket = "failed"
	default:
		return "", false
	}
	return bucket, true
}
// findRunManifestDir locates the directory holding jobName's run manifest
// under basePath. It probes the typed job roots in priority order —
// running, pending, finished, failed — and returns the first job directory
// that exists and contains a manifest file, together with the bucket name
// it was found in. The final return value is false when either argument is
// blank or no manifest is found.
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}
	paths := config.NewJobPaths(basePath)
	candidates := []struct {
		bucket string
		root   string
	}{
		{bucket: "running", root: paths.RunningPath()},
		{bucket: "pending", root: paths.PendingPath()},
		{bucket: "finished", root: paths.FinishedPath()},
		{bucket: "failed", root: paths.FailedPath()},
	}
	for _, candidate := range candidates {
		jobDir := filepath.Join(candidate.root, jobName)
		info, err := os.Stat(jobDir)
		if err != nil || !info.IsDir() {
			continue
		}
		if _, err := os.Stat(manifest.ManifestPath(jobDir)); err != nil {
			continue
		}
		return jobDir, candidate.bucket, true
	}
	return "", "", false
}
// validateResourcesForTask sanity-checks the resource requests on a task:
// cpu, memory, and gpu counts must be non-negative, and gpu_memory (when
// set) must be either a percentage in (0,100] (suffix "%") or a fraction
// in (0,1], and requires gpu > 0. On the first violation it returns a
// failing validateCheck plus the error strings to append to the report;
// on success it returns an OK check and a nil slice.
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
	if task == nil {
		return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
	}
	// fail packages a failing check with its single report error.
	fail := func(details, reportErr string) (validateCheck, []string) {
		return validateCheck{OK: false, Details: details}, []string{reportErr}
	}
	switch {
	case task.CPU < 0:
		return fail("cpu must be >= 0", "invalid cpu request")
	case task.MemoryGB < 0:
		return fail("memory_gb must be >= 0", "invalid memory request")
	case task.GPU < 0:
		return fail("gpu must be >= 0", "invalid gpu request")
	}
	gpuMem := strings.TrimSpace(task.GPUMemory)
	if gpuMem != "" {
		if strings.HasSuffix(gpuMem, "%") {
			// Percentage form, e.g. "50%".
			pct := strings.TrimSpace(strings.TrimSuffix(gpuMem, "%"))
			value, err := strconv.ParseFloat(pct, 64)
			if err != nil || value <= 0 || value > 100 {
				return fail("gpu_memory must be a percentage in (0,100]", "invalid gpu_memory")
			}
		} else {
			// Fraction form, e.g. "0.5".
			value, err := strconv.ParseFloat(gpuMem, 64)
			if err != nil || value <= 0 || value > 1 {
				return fail("gpu_memory must be a fraction in (0,1]", "invalid gpu_memory")
			}
		}
		// A memory share only makes sense when at least one GPU is requested.
		if task.GPU == 0 {
			return fail("gpu_memory requires gpu > 0", "invalid gpu_memory")
		}
	}
	return validateCheck{OK: true}, nil
}
// handleValidateRequest serves a binary "validate" request received over the
// WebSocket connection. It authenticates the caller from the api_key_hash,
// resolves the validation target (a raw 20-byte commit id, or a task id whose
// metadata supplies the commit id), then builds a validateReport covering:
// commit-id format, experiment-manifest integrity, deps-manifest presence and
// hash, task resource requests, run-manifest location/provenance/lifecycle,
// snapshot checksum, and dataset checksums. The JSON-encoded report is sent
// back as a "validate" data packet.
//
// All protocol, auth, and lookup failures are reported to the peer via
// sendErrorPacket; the returned error is whatever the underlying write
// produced.
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
	// target_type: 0=commit_id (20 bytes), 1=task_id (string)
	// TODO(context): Add a versioned validate protocol once we need more target types/fields.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
	}
	apiKeyHash := payload[:16]
	targetType := payload[16]
	idLen := int(payload[17])
	// The id must be non-empty and fully contained within the payload.
	if idLen < 1 || len(payload) < 18+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
	}
	idBytes := payload[18 : 18+idLen]
	// Validate API key and user
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Invalid API key",
			err.Error(),
		)
	}
	// Permission checks only apply when auth is enabled in config.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to validate jobs",
			"",
		)
	}
	// Every check below needs the experiment manager.
	if h.expManager == nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServiceUnavailable,
			"Experiment manager not available",
			"",
		)
	}
	// Resolve the validation target. task stays nil for commit-only requests,
	// which skips all task-dependent checks further down.
	var task *queue.Task
	commitID := ""
	switch targetType {
	case 0:
		if len(idBytes) != 20 {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
		}
		// Raw 20-byte id is hex-encoded into the canonical commit id string.
		commitID = fmt.Sprintf("%x", idBytes)
	case 1:
		taskID := string(idBytes)
		if h.queue == nil {
			return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
		}
		t, err := h.queue.GetTask(taskID)
		if err != nil {
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
		}
		task = t
		// Non-admin users may only validate tasks they own or created.
		if h.authConfig != nil &&
			h.authConfig.Enabled &&
			!user.Admin &&
			task.UserID != user.Name &&
			task.CreatedBy != user.Name {
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only validate your own jobs",
				"",
			)
		}
		if task.Metadata == nil || task.Metadata["commit_id"] == "" {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
		}
		commitID = task.Metadata["commit_id"]
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
	}
	// Start optimistic: r.OK flips to false as soon as any check fails.
	r := validateReport{
		OK:     true,
		TS:     time.Now().UTC().Format(time.RFC3339Nano),
		Checks: map[string]validateCheck{},
	}
	if task != nil {
		r.TaskID = task.ID
	}
	if commitID != "" {
		r.CommitID = commitID
	}
	// Validate commit id format
	if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok {
		r.OK = false
		r.Errors = append(r.Errors, errMsg)
	}
	// Experiment manifest integrity
	// (skipped when the commit id format was already rejected above)
	// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
	if r.OK {
		if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok {
			r.OK = false
			r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details}
			r.Errors = append(r.Errors, "experiment manifest validation failed")
		} else {
			r.Checks["experiment_manifest"] = validateCheck{OK: true}
		}
	}
	// Deps manifest presence + hash
	// TODO(context): Allow client to declare which dependency manifest is authoritative.
	// filesPath is used later to re-hash the deps manifest file on disk.
	filesPath := h.expManager.GetFilesPath(commitID)
	depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID)
	if depErrs != nil {
		r.OK = false
		r.Checks["deps_manifest"] = validateCheck(depCheck)
		r.Errors = append(r.Errors, depErrs...)
	} else {
		r.Checks["deps_manifest"] = validateCheck(depCheck)
	}
	// Compare against expected task metadata if available.
	if task != nil {
		// Resource request sanity (cpu/memory/gpu/gpu_memory).
		resCheck, resErrs := validateResourcesForTask(task)
		r.Checks["resources"] = resCheck
		if !resCheck.OK {
			r.OK = false
			r.Errors = append(r.Errors, resErrs...)
		}
		// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
		if err := container.ValidateJobName(task.JobName); err != nil {
			r.OK = false
			r.Errors = append(r.Errors, "invalid job name")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
		} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
			// No base path means manifest directories cannot be located; whether
			// that is fatal depends on the task's lifecycle status.
			if shouldRequireRunManifest(task) {
				r.OK = false
				r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			} else {
				r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			}
		} else {
			manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
			if !found {
				// Absence is an error only once the task has started running.
				if shouldRequireRunManifest(task) {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest not found")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				} else {
					r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				}
			} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
				r.OK = false
				r.Errors = append(r.Errors, "unable to read run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
			} else {
				r.Checks["run_manifest"] = validateCheck{OK: true}
				// Location check: the manifest's bucket (running/pending/...) should
				// agree with the bucket implied by the task's status.
				expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status)
				if ok {
					if expectedBucket != manifestBucket {
						msg := "run manifest location mismatch"
						chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket}
						if shouldRequireRunManifest(task) {
							r.OK = false
							r.Errors = append(r.Errors, msg)
							r.Checks["run_manifest_location"] = chk
						} else {
							// For queued tasks the files may simply not have moved yet.
							r.Warnings = append(r.Warnings, msg)
							r.Checks["run_manifest_location"] = chk
						}
					} else {
						r.Checks["run_manifest_location"] = validateCheck{
							OK:       true,
							Expected: expectedBucket,
							Actual:   manifestBucket,
						}
					}
				}
				// Validate task ID using helper
				taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID)
				r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck)
				if !taskIDCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest task_id mismatch")
				}
				// Validate commit ID using helper
				commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"])
				r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck)
				if !commitCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest commit_id mismatch")
				}
				// Validate deps provenance using helper: compare the deps manifest
				// name/sha recorded in task metadata against the run manifest.
				depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
				depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
				depGotName := strings.TrimSpace(rm.DepsManifestName)
				depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
				depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA)
				r.Checks["run_manifest_deps"] = validateCheck(depsCheck)
				if !depsCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
				}
				// Validate snapshot using helpers (only when the task uses a snapshot).
				if strings.TrimSpace(task.SnapshotID) != "" {
					snapWantID := strings.TrimSpace(task.SnapshotID)
					snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
					snapGotID := strings.TrimSpace(rm.SnapshotID)
					snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)
					snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID)
					r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck)
					if !snapIDCheck.OK {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
					}
					snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA)
					r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck)
					if !snapSHACheck.OK {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
					}
				}
				// Validate lifecycle using helper: manifest lifecycle fields must be
				// consistent with the task's status; advisory for queued tasks.
				lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status)
				if lifecycleOK {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
				} else {
					chk := validateCheck{OK: false, Details: details}
					if shouldRequireRunManifest(task) {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					} else {
						r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					}
				}
			}
		}
		// Compare the experiment manifest's overall SHA against the value the
		// task recorded at submission time.
		want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
		cur := ""
		if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
			cur = man.OverallSHA
		}
		if want == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
		} else if cur == "" {
			r.OK = false
			r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
		} else if want != cur {
			r.OK = false
			r.Errors = append(r.Errors, "experiment manifest overall sha mismatch")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
		} else {
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
		}
		// Re-hash the deps manifest file on disk and compare against the
		// provenance recorded in task metadata.
		wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
		wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
		if wantDep == "" || wantDepSha == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected deps manifest provenance")
			r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
		} else if depName != "" {
			// NOTE(review): hash error ignored here — presumably a failed read
			// yields an empty sha and therefore a mismatch below; confirm
			// FileSHA256Hex's contract.
			sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName))
			ok := (wantDep == depName && wantDepSha == sha)
			if !ok {
				r.OK = false
				r.Errors = append(r.Errors, "deps manifest provenance mismatch")
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       false,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			} else {
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       true,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			}
		}
		// Snapshot/dataset checks require dataDir.
		// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
		// TODO(context): Validate snapshots by digest.
		if task.SnapshotID != "" {
			if h.dataDir == "" {
				r.OK = false
				r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot")
				r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
			} else {
				wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
				if nerr != nil || wantSnap == "" {
					r.OK = false
					r.Errors = append(r.Errors, "missing/invalid snapshot_sha256")
					r.Checks["snapshot"] = validateCheck{OK: false}
				} else {
					// Recompute the overall hash of the materialized snapshot tree.
					curSnap, err := worker.DirOverallSHA256Hex(
						filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
					)
					if err != nil {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot hash computation failed")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
					} else if curSnap != wantSnap {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot checksum mismatch")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
					} else {
						r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
					}
				}
			}
		}
		if len(task.DatasetSpecs) > 0 {
			// TODO(context): Add dataset URI fetch/verification.
			// TODO(context): Currently only validates local materialized datasets.
			for _, ds := range task.DatasetSpecs {
				// Datasets without a declared checksum are skipped, not failed.
				if ds.Checksum == "" {
					continue
				}
				key := "dataset:" + ds.Name
				if h.dataDir == "" {
					r.OK = false
					r.Errors = append(
						r.Errors,
						"api server data_dir not configured; cannot validate dataset checksums",
					)
					r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
					continue
				}
				wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
				if nerr != nil || wantDS == "" {
					r.OK = false
					r.Errors = append(r.Errors, "invalid dataset checksum format")
					r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
					continue
				}
				curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
				if err != nil {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum computation failed")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
					continue
				}
				if curDS != wantDS {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum mismatch")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
					continue
				}
				r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
			}
		}
	}
	// Serialize the finished report and send it back to the client.
	payloadBytes, err := json.Marshal(r)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeUnknownError,
			"failed to serialize validate report",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}