fetch_ml/internal/api/ws_validate.go

642 lines
20 KiB
Go

package api
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// validateCheck is the outcome of a single named check inside a
// validateReport. Every string field is dropped from the JSON encoding when
// it is empty.
type validateCheck struct {
	// OK reports whether the check passed.
	OK bool `json:"ok"`
	// Expected is the value the check was looking for, when one applies.
	Expected string `json:"expected,omitempty"`
	// Actual is the value that was observed, when one applies.
	Actual string `json:"actual,omitempty"`
	// Details carries free-form context, typically an error message.
	Details string `json:"details,omitempty"`
}
// validateReport is the JSON document produced for a validate request. It
// aggregates the individual checks together with an overall verdict.
type validateReport struct {
	// OK is the overall verdict of the report.
	OK bool `json:"ok"`
	// CommitID is the commit the report refers to; omitted when empty.
	CommitID string `json:"commit_id,omitempty"`
	// TaskID is the task the report refers to; omitted when empty.
	TaskID string `json:"task_id,omitempty"`
	// Checks maps a check name to its individual outcome.
	Checks map[string]validateCheck `json:"checks"`
	// Errors lists short failure messages; omitted when empty.
	Errors []string `json:"errors,omitempty"`
	// Warnings lists short non-fatal messages; omitted when empty.
	Warnings []string `json:"warnings,omitempty"`
	// TS is the report timestamp as a string.
	TS string `json:"ts"`
}
// shouldRequireRunManifest reports whether task is in a state for which the
// run manifest is treated as mandatory: "running", "completed" or "failed",
// compared case-insensitively after trimming surrounding whitespace. A nil
// task, or any other status, yields false.
func shouldRequireRunManifest(task *queue.Task) bool {
	if task == nil {
		return false
	}
	status := strings.ToLower(strings.TrimSpace(task.Status))
	return status == "running" || status == "completed" || status == "failed"
}
// expectedRunManifestBucketForStatus maps a task status to the name of the
// job bucket in which its run manifest is expected to live. The status is
// matched case-insensitively with surrounding whitespace ignored. The boolean
// result is false, and the bucket empty, for any status without a mapping.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	switch normalized := strings.ToLower(strings.TrimSpace(status)); normalized {
	case "running", "failed":
		// For these two the bucket carries the same name as the status.
		return normalized, true
	case "queued", "pending":
		return "pending", true
	case "completed", "finished":
		return "finished", true
	}
	return "", false
}
// findRunManifestDir looks for the directory of jobName under the job buckets
// derived from basePath and returns the first one that is a directory and for
// which the path given by manifest.ManifestPath can be stat'ed.
//
// Buckets are probed in the fixed order running, pending, finished, failed.
// The results are the job directory, the bucket name it was found in, and
// whether anything was found. A blank basePath or jobName is never found.
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}

	paths := config.NewJobPaths(basePath)
	// Parallel arrays: buckets[i] is the name of the bucket rooted at roots[i].
	buckets := [...]string{"running", "pending", "finished", "failed"}
	roots := [...]string{
		paths.RunningPath(),
		paths.PendingPath(),
		paths.FinishedPath(),
		paths.FailedPath(),
	}

	for i, bucket := range buckets {
		candidate := filepath.Join(roots[i], jobName)
		info, err := os.Stat(candidate)
		if err != nil || !info.IsDir() {
			continue
		}
		if _, err := os.Stat(manifest.ManifestPath(candidate)); err != nil {
			continue
		}
		return candidate, bucket, true
	}
	return "", "", false
}
// validateResourcesForTask sanity-checks the resource request carried by task.
//
// It returns the check outcome plus the short error messages to append to the
// report (nil when the check passes). The first violated rule wins:
//
//   - task must not be nil;
//   - cpu, memory_gb and gpu must be >= 0;
//   - gpu_memory, when set, must be either a percentage in (0,100] written
//     with a trailing "%", or a plain fraction in (0,1];
//   - gpu_memory may only be set when gpu is non-zero.
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
	if task == nil {
		return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
	}
	if task.CPU < 0 {
		return validateCheck{OK: false, Details: "cpu must be >= 0"}, []string{"invalid cpu request"}
	}
	if task.MemoryGB < 0 {
		return validateCheck{OK: false, Details: "memory_gb must be >= 0"}, []string{"invalid memory request"}
	}
	if task.GPU < 0 {
		return validateCheck{OK: false, Details: "gpu must be >= 0"}, []string{"invalid gpu request"}
	}

	gpuMem := strings.TrimSpace(task.GPUMemory)
	if gpuMem == "" {
		return validateCheck{OK: true}, nil
	}

	// The range tests below are deliberately written as !(lo < f && f <= hi):
	// strconv.ParseFloat parses "NaN" without an error, and every ordered
	// comparison against NaN is false, so a test of the form
	// "f <= lo || f > hi" would let NaN through.
	if strings.HasSuffix(gpuMem, "%") {
		pct := strings.TrimSpace(strings.TrimSuffix(gpuMem, "%"))
		f, err := strconv.ParseFloat(pct, 64)
		if err != nil || !(f > 0 && f <= 100) {
			chk := validateCheck{OK: false, Details: "gpu_memory must be a percentage in (0,100]"}
			return chk, []string{"invalid gpu_memory"}
		}
	} else {
		f, err := strconv.ParseFloat(gpuMem, 64)
		if err != nil || !(f > 0 && f <= 1) {
			chk := validateCheck{OK: false, Details: "gpu_memory must be a fraction in (0,1]"}
			return chk, []string{"invalid gpu_memory"}
		}
	}

	if task.GPU == 0 {
		chk := validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"}
		return chk, []string{"invalid gpu_memory"}
	}
	return validateCheck{OK: true}, nil
}
// handleValidateRequest answers a "validate" WebSocket request: it builds a
// validateReport for either a bare commit or a task and sends it back as a
// "validate" data packet.
//
// Payload layout: [api_key_hash:16][target_type:1][id_len:1][id:var], where
// target_type 0 carries a 20-byte commit id and target_type 1 a task id.
//
// Request-level problems (malformed payload, failed authentication, missing
// permission, unavailable services, unknown task, task without commit_id) are
// answered with an error packet and no report. Validation findings are never
// returned as Go errors; they are recorded in the report's Checks, Errors and
// Warnings. The returned error is whatever the packet-sending helper returns.
//
// A commit id that fails the 40-hex-character format check is reported as an
// error and is not used to look anything up on disk; checks that depend on it
// are skipped or reported as unreadable.
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
	// target_type: 0=commit_id (20 bytes), 1=task_id (string)
	// TODO(context): Add a versioned validate protocol once we need more target types/fields.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
	}
	apiKeyHash := payload[:16]
	targetType := payload[16]
	idLen := int(payload[17])
	if idLen < 1 || len(payload) < 18+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
	}
	idBytes := payload[18 : 18+idLen]

	// Validate API key and user
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Invalid API key",
			err.Error(),
		)
	}
	authEnabled := h.authConfig != nil && h.authConfig.Enabled
	if authEnabled && !user.HasPermission("jobs:read") {
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to validate jobs",
			"",
		)
	}
	if h.expManager == nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServiceUnavailable,
			"Experiment manager not available",
			"",
		)
	}

	// Resolve the validation target. task stays nil for commit-only requests.
	var task *queue.Task
	commitID := ""
	switch targetType {
	case 0:
		if len(idBytes) != 20 {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
		}
		commitID = fmt.Sprintf("%x", idBytes)
	case 1:
		taskID := string(idBytes)
		if h.queue == nil {
			return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
		}
		t, err := h.queue.GetTask(taskID)
		if err != nil {
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
		}
		task = t
		// Non-admin users may only validate tasks they own or created.
		if authEnabled &&
			!user.Admin &&
			task.UserID != user.Name &&
			task.CreatedBy != user.Name {
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only validate your own jobs",
				"",
			)
		}
		if task.Metadata == nil || task.Metadata["commit_id"] == "" {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
		}
		commitID = task.Metadata["commit_id"]
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
	}

	r := validateReport{
		OK:     true,
		TS:     time.Now().UTC().Format(time.RFC3339Nano),
		Checks: map[string]validateCheck{},
	}
	if task != nil {
		r.TaskID = task.ID
	}
	r.CommitID = commitID

	// fail records a hard validation error and marks the whole report failed.
	fail := func(msg string) {
		r.OK = false
		r.Errors = append(r.Errors, msg)
	}

	// Validate commit id format. For task targets the id comes from task
	// metadata, so it must pass this check before it is used to locate files.
	commitValid := true
	if len(commitID) != 40 {
		commitValid = false
		fail("invalid commit_id length")
	} else if _, err := hex.DecodeString(commitID); err != nil {
		commitValid = false
		fail("invalid commit_id hex")
	}

	// depName/depSHA describe the deps manifest currently on disk. They stay
	// empty when it could not be selected or hashed.
	var depName, depSHA string
	if commitValid {
		// Experiment manifest integrity
		// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
		if err := h.expManager.ValidateManifest(commitID); err != nil {
			fail("experiment manifest validation failed")
			r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()}
		} else {
			r.Checks["experiment_manifest"] = validateCheck{OK: true}
		}

		// Deps manifest presence + hash
		// TODO(context): Allow client to declare which dependency manifest is authoritative.
		filesPath := h.expManager.GetFilesPath(commitID)
		var depErr error
		depName, depErr = worker.SelectDependencyManifest(filesPath)
		if depErr != nil {
			fail("deps manifest missing")
			r.Checks["deps_manifest"] = validateCheck{OK: false, Details: depErr.Error()}
		} else if sha, err := fileSHA256Hex(filepath.Join(filesPath, depName)); err != nil {
			fail("deps manifest hash failed")
			r.Checks["deps_manifest"] = validateCheck{OK: false, Details: err.Error()}
		} else {
			depSHA = sha
			r.Checks["deps_manifest"] = validateCheck{OK: true, Actual: depName + ":" + depSHA}
		}
	} else {
		// The id is already reported as invalid above; never build
		// filesystem paths from it.
		r.Checks["deps_manifest"] = validateCheck{OK: false, Details: "skipped: invalid commit_id"}
	}

	// Compare against expected task metadata if available.
	if task != nil {
		resCheck, resErrs := validateResourcesForTask(task)
		r.Checks["resources"] = resCheck
		if !resCheck.OK {
			r.OK = false
			r.Errors = append(r.Errors, resErrs...)
		}

		// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
		required := shouldRequireRunManifest(task)
		// failIfRequired records msg as an error when the run manifest is
		// required for this task and as a warning otherwise.
		failIfRequired := func(msg string) {
			if required {
				fail(msg)
			} else {
				r.Warnings = append(r.Warnings, msg)
			}
		}

		if err := container.ValidateJobName(task.JobName); err != nil {
			fail("invalid job name")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
		} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
			failIfRequired("missing api base_path; cannot validate run manifest")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
		} else {
			manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
			if !found {
				if required {
					fail("run manifest not found")
				} else {
					r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
				}
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
			} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
				fail("unable to read run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
			} else {
				r.Checks["run_manifest"] = validateCheck{OK: true}

				// Location: the manifest should sit in the bucket matching the task status.
				if expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status); ok {
					inPlace := expectedBucket == manifestBucket
					if !inPlace {
						failIfRequired("run manifest location mismatch")
					}
					r.Checks["run_manifest_location"] = validateCheck{
						OK:       inPlace,
						Expected: expectedBucket,
						Actual:   manifestBucket,
					}
				}

				// Task id recorded in the manifest must match exactly.
				switch {
				case strings.TrimSpace(rm.TaskID) == "":
					fail("run manifest missing task_id")
					r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID}
				case rm.TaskID != task.ID:
					fail("run manifest task_id mismatch")
					r.Checks["run_manifest_task_id"] = validateCheck{
						OK:       false,
						Expected: task.ID,
						Actual:   rm.TaskID,
					}
				default:
					r.Checks["run_manifest_task_id"] = validateCheck{
						OK:       true,
						Expected: task.ID,
						Actual:   rm.TaskID,
					}
				}

				// Commit id: only a non-empty manifest value that differs is a
				// mismatch; an empty manifest value is tolerated.
				commitWant := strings.TrimSpace(task.Metadata["commit_id"])
				commitGot := strings.TrimSpace(rm.CommitID)
				if commitWant != "" {
					mismatch := commitGot != "" && commitWant != commitGot
					if mismatch {
						fail("run manifest commit_id mismatch")
					}
					r.Checks["run_manifest_commit_id"] = validateCheck{
						OK:       !mismatch,
						Expected: commitWant,
						Actual:   commitGot,
					}
				}

				// Deps provenance: compared only when both sides are fully populated.
				depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
				depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
				depGotName := strings.TrimSpace(rm.DepsManifestName)
				depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
				if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" {
					same := depWantName == depGotName && depWantSHA == depGotSHA
					if !same {
						fail("run manifest deps provenance mismatch")
					}
					r.Checks["run_manifest_deps"] = validateCheck{
						OK:       same,
						Expected: depWantName + ":" + depWantSHA,
						Actual:   depGotName + ":" + depGotSHA,
					}
				}

				// Snapshot provenance: as with commit_id, empty manifest
				// values are tolerated and only differing values fail.
				if snapWantID := strings.TrimSpace(task.SnapshotID); snapWantID != "" {
					snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
					snapGotID := strings.TrimSpace(rm.SnapshotID)
					snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)

					idMismatch := snapGotID != "" && snapWantID != snapGotID
					if idMismatch {
						fail("run manifest snapshot_id mismatch")
					}
					r.Checks["run_manifest_snapshot_id"] = validateCheck{
						OK:       !idMismatch,
						Expected: snapWantID,
						Actual:   snapGotID,
					}

					if snapWantSHA != "" {
						shaMismatch := snapGotSHA != "" && snapWantSHA != snapGotSHA
						if shaMismatch {
							fail("run manifest snapshot_sha256 mismatch")
						}
						r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
							OK:       !shaMismatch,
							Expected: snapWantSHA,
							Actual:   snapGotSHA,
						}
					}
				}

				// Lifecycle: timestamps and exit code must be consistent with
				// the task status. Only the first problem found is reported
				// in details.
				statusLower := strings.ToLower(strings.TrimSpace(task.Status))
				lifecycleOK := true
				details := ""
				switch statusLower {
				case "running":
					if rm.StartedAt.IsZero() {
						lifecycleOK = false
						details = "missing started_at for running task"
					}
					if !rm.EndedAt.IsZero() {
						lifecycleOK = false
						if details == "" {
							details = "ended_at must be empty for running task"
						}
					}
					if rm.ExitCode != nil {
						lifecycleOK = false
						if details == "" {
							details = "exit_code must be empty for running task"
						}
					}
				case "completed", "failed":
					if rm.StartedAt.IsZero() {
						lifecycleOK = false
						details = "missing started_at for completed/failed task"
					}
					if rm.EndedAt.IsZero() {
						lifecycleOK = false
						if details == "" {
							details = "missing ended_at for completed/failed task"
						}
					}
					if rm.ExitCode == nil {
						lifecycleOK = false
						if details == "" {
							details = "missing exit_code for completed/failed task"
						}
					}
					if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
						lifecycleOK = false
						if details == "" {
							details = "ended_at is before started_at"
						}
					}
				case "queued", "pending":
					// queued/pending tasks may not have started yet.
					if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
						lifecycleOK = false
						details = "queued/pending task should not have ended_at/exit_code"
					}
				}
				if lifecycleOK {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
				} else {
					failIfRequired("run manifest lifecycle invalid")
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: false, Details: details}
				}
			}
		}

		// Experiment manifest overall SHA recorded on the task vs. the one on
		// disk. The on-disk manifest is only read for a well-formed commit id.
		want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
		cur := ""
		if commitValid {
			if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
				cur = man.OverallSHA
			}
		}
		switch {
		case want == "":
			fail("missing expected experiment_manifest_overall_sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
		case cur == "":
			fail("unable to read current experiment manifest overall sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
		case want != cur:
			fail("experiment manifest overall sha mismatch")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
		default:
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
		}

		// Deps manifest provenance recorded on the task vs. the file on disk.
		// depSHA is the hash computed above; it is empty when hashing failed,
		// which is then reported as a mismatch.
		wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
		wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
		if wantDep == "" || wantDepSha == "" {
			fail("missing expected deps manifest provenance")
			r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
		} else if depName != "" {
			same := wantDep == depName && wantDepSha == depSHA
			if !same {
				fail("deps manifest provenance mismatch")
			}
			r.Checks["expected_deps_manifest"] = validateCheck{
				OK:       same,
				Expected: wantDep + ":" + wantDepSha,
				Actual:   depName + ":" + depSHA,
			}
		}

		// Snapshot/dataset checks require dataDir.
		// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
		// TODO(context): Validate snapshots by digest.
		// NOTE(review): task.SnapshotID and ds.Name are joined into paths under
		// dataDir without a containment check; confirm they are validated when
		// the task is submitted.
		if task.SnapshotID != "" {
			if h.dataDir == "" {
				fail("api server data_dir not configured; cannot validate snapshot")
				r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
			} else if wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(
				task.Metadata["snapshot_sha256"],
			); nerr != nil || wantSnap == "" {
				fail("missing/invalid snapshot_sha256")
				r.Checks["snapshot"] = validateCheck{OK: false}
			} else {
				curSnap, err := worker.DirOverallSHA256Hex(
					filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
				)
				switch {
				case err != nil:
					fail("snapshot hash computation failed")
					r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
				case curSnap != wantSnap:
					fail("snapshot checksum mismatch")
					r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
				default:
					r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
				}
			}
		}

		// TODO(context): Add dataset URI fetch/verification.
		// TODO(context): Currently only validates local materialized datasets.
		for _, ds := range task.DatasetSpecs {
			// Datasets without a declared checksum are not verified.
			if ds.Checksum == "" {
				continue
			}
			key := "dataset:" + ds.Name
			if h.dataDir == "" {
				fail("api server data_dir not configured; cannot validate dataset checksums")
				r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
				continue
			}
			wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
			if nerr != nil || wantDS == "" {
				fail("invalid dataset checksum format")
				r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
				continue
			}
			curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
			if err != nil {
				fail("dataset checksum computation failed")
				r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
				continue
			}
			if curDS != wantDS {
				fail("dataset checksum mismatch")
				r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
				continue
			}
			r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
		}
	}

	payloadBytes, err := json.Marshal(r)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeUnknownError,
			"failed to serialize validate report",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}