package api import ( "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" "strconv" "strings" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker" ) type validateCheck struct { OK bool `json:"ok"` Expected string `json:"expected,omitempty"` Actual string `json:"actual,omitempty"` Details string `json:"details,omitempty"` } type validateReport struct { OK bool `json:"ok"` CommitID string `json:"commit_id,omitempty"` TaskID string `json:"task_id,omitempty"` Checks map[string]validateCheck `json:"checks"` Errors []string `json:"errors,omitempty"` Warnings []string `json:"warnings,omitempty"` TS string `json:"ts"` } func shouldRequireRunManifest(task *queue.Task) bool { if task == nil { return false } s := strings.ToLower(strings.TrimSpace(task.Status)) switch s { case "running", "completed", "failed": return true default: return false } } func expectedRunManifestBucketForStatus(status string) (string, bool) { s := strings.ToLower(strings.TrimSpace(status)) switch s { case "queued", "pending": return "pending", true case "running": return "running", true case "completed", "finished": return "finished", true case "failed": return "failed", true default: return "", false } } func findRunManifestDir(basePath string, jobName string) (string, string, bool) { if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" { return "", "", false } jobPaths := config.NewJobPaths(basePath) typedRoots := []struct { bucket string root string }{ {bucket: "running", root: jobPaths.RunningPath()}, {bucket: "pending", root: jobPaths.PendingPath()}, {bucket: "finished", root: jobPaths.FinishedPath()}, {bucket: "failed", root: jobPaths.FailedPath()}, } for _, item := range typedRoots { root := item.root dir := filepath.Join(root, jobName) if info, err := os.Stat(dir); err == nil && info.IsDir() { if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { return dir, item.bucket, true } } } return "", "", false } func validateResourcesForTask(task *queue.Task) (validateCheck, []string) { if task == nil { return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"} } if task.CPU < 0 { chk := validateCheck{OK: false, Details: "cpu must be >= 0"} return chk, []string{"invalid cpu request"} } if task.MemoryGB < 0 { chk := validateCheck{OK: false, Details: "memory_gb must be >= 0"} return chk, []string{"invalid memory request"} } if task.GPU < 0 { chk := validateCheck{OK: false, Details: "gpu must be >= 0"} return chk, []string{"invalid gpu request"} } if strings.TrimSpace(task.GPUMemory) != "" { s := strings.TrimSpace(task.GPUMemory) if strings.HasSuffix(s, "%") { v := strings.TrimSuffix(s, "%") f, err := strconv.ParseFloat(strings.TrimSpace(v), 64) if err != nil || f <= 0 || f > 100 { details := "gpu_memory must be a percentage in (0,100]" chk := validateCheck{OK: false, Details: details} return chk, []string{"invalid gpu_memory"} } } else { f, err := strconv.ParseFloat(s, 64) if err != nil || f <= 0 || f > 1 { chk := validateCheck{OK: false, Details: "gpu_memory must be a fraction in (0,1]"} return chk, []string{"invalid gpu_memory"} } } } if task.GPU == 0 && strings.TrimSpace(task.GPUMemory) != "" { chk := validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"} return chk, []string{"invalid gpu_memory"} } return validateCheck{OK: true}, nil } func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var] // target_type: 0=commit_id (20 bytes), 1=task_id (string) // TODO(context): Add a versioned validate protocol once we need more target types/fields. if len(payload) < 18 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "") } apiKeyHash := payload[:16] targetType := payload[16] idLen := int(payload[17]) if idLen < 1 || len(payload) < 18+idLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "") } idBytes := payload[18 : 18+idLen] // Validate API key and user user, err := h.validateWSUser(apiKeyHash) if err != nil { return h.sendErrorPacket( conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error(), ) } if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") { return h.sendErrorPacket( conn, ErrorCodePermissionDenied, "Insufficient permissions to validate jobs", "", ) } if h.expManager == nil { return h.sendErrorPacket( conn, ErrorCodeServiceUnavailable, "Experiment manager not available", "", ) } var task *queue.Task commitID := "" switch targetType { case 0: if len(idBytes) != 20 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "") } commitID = fmt.Sprintf("%x", idBytes) case 1: taskID := string(idBytes) if h.queue == nil { return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "") } t, err := h.queue.GetTask(taskID) if err != nil { return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error()) } task = t if h.authConfig != nil && h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name { return h.sendErrorPacket( conn, ErrorCodePermissionDenied, "You can only validate your own jobs", "", ) } if task.Metadata == nil || task.Metadata["commit_id"] == "" { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "") } commitID = task.Metadata["commit_id"] default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "") } r := validateReport{ OK: true, TS: time.Now().UTC().Format(time.RFC3339Nano), Checks: map[string]validateCheck{}, } if task != nil { r.TaskID = task.ID } if commitID != "" { r.CommitID = commitID } // Validate commit id format if len(commitID) != 40 { r.OK = false r.Errors = append(r.Errors, "invalid commit_id length") } else if _, err := hex.DecodeString(commitID); err != nil { r.OK = false r.Errors = append(r.Errors, "invalid commit_id hex") } // Experiment manifest integrity // TODO(context): Extend report to include per-file diff list on mismatch (bounded output). if r.OK { if err := h.expManager.ValidateManifest(commitID); err != nil { r.OK = false r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()} r.Errors = append(r.Errors, "experiment manifest validation failed") } else { r.Checks["experiment_manifest"] = validateCheck{OK: true} } } // Deps manifest presence + hash // TODO(context): Allow client to declare which dependency manifest is authoritative. filesPath := h.expManager.GetFilesPath(commitID) depName, depErr := worker.SelectDependencyManifest(filesPath) if depErr != nil { r.OK = false r.Checks["deps_manifest"] = validateCheck{ OK: false, Details: depErr.Error(), } r.Errors = append(r.Errors, "deps manifest missing") } else { sha, err := fileSHA256Hex(filepath.Join(filesPath, depName)) if err != nil { r.OK = false r.Checks["deps_manifest"] = validateCheck{ OK: false, Details: err.Error(), } r.Errors = append(r.Errors, "deps manifest hash failed") } else { r.Checks["deps_manifest"] = validateCheck{ OK: true, Actual: depName + ":" + sha, } } } // Compare against expected task metadata if available. if task != nil { resCheck, resErrs := validateResourcesForTask(task) r.Checks["resources"] = resCheck if !resCheck.OK { r.OK = false r.Errors = append(r.Errors, resErrs...) } // Run manifest checks: best-effort for queued tasks, required for running/completed/failed. if err := container.ValidateJobName(task.JobName); err != nil { r.OK = false r.Errors = append(r.Errors, "invalid job name") r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"} } else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" { if shouldRequireRunManifest(task) { r.OK = false r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest") r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"} } else { r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest") r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"} } } else { manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName) if !found { if shouldRequireRunManifest(task) { r.OK = false r.Errors = append(r.Errors, "run manifest not found") r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"} } else { r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)") r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"} } } else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil { r.OK = false r.Errors = append(r.Errors, "unable to read run manifest") r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"} } else { r.Checks["run_manifest"] = validateCheck{OK: true} expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status) if ok { if expectedBucket != manifestBucket { msg := "run manifest location mismatch" chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket} if shouldRequireRunManifest(task) { r.OK = false r.Errors = append(r.Errors, msg) r.Checks["run_manifest_location"] = chk } else { r.Warnings = append(r.Warnings, msg) r.Checks["run_manifest_location"] = chk } } else { r.Checks["run_manifest_location"] = validateCheck{ OK: true, Expected: expectedBucket, Actual: manifestBucket, } } } if strings.TrimSpace(rm.TaskID) == "" { r.OK = false r.Errors = append(r.Errors, "run manifest missing task_id") r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID} } else if rm.TaskID != task.ID { r.OK = false r.Errors = append(r.Errors, "run manifest task_id mismatch") r.Checks["run_manifest_task_id"] = validateCheck{ OK: false, Expected: task.ID, Actual: rm.TaskID, } } else { r.Checks["run_manifest_task_id"] = validateCheck{ OK: true, Expected: task.ID, Actual: rm.TaskID, } } commitWant := strings.TrimSpace(task.Metadata["commit_id"]) commitGot := strings.TrimSpace(rm.CommitID) if commitWant != "" && commitGot != "" && commitWant != commitGot { r.OK = false r.Errors = append(r.Errors, "run manifest commit_id mismatch") r.Checks["run_manifest_commit_id"] = validateCheck{ OK: false, Expected: commitWant, Actual: commitGot, } } else if commitWant != "" { r.Checks["run_manifest_commit_id"] = validateCheck{ OK: true, Expected: commitWant, Actual: commitGot, } } depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"]) depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) depGotName := strings.TrimSpace(rm.DepsManifestName) depGotSHA := strings.TrimSpace(rm.DepsManifestSHA) if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" { expectedDep := depWantName + ":" + depWantSHA actualDep := depGotName + ":" + depGotSHA if depWantName != depGotName || depWantSHA != depGotSHA { r.OK = false r.Errors = append(r.Errors, "run manifest deps provenance mismatch") r.Checks["run_manifest_deps"] = validateCheck{ OK: false, Expected: expectedDep, Actual: actualDep, } } else { r.Checks["run_manifest_deps"] = validateCheck{ OK: true, Expected: expectedDep, Actual: actualDep, } } } if strings.TrimSpace(task.SnapshotID) != "" { snapWantID := strings.TrimSpace(task.SnapshotID) snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"]) snapGotID := strings.TrimSpace(rm.SnapshotID) snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256) if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID { r.OK = false r.Errors = append(r.Errors, "run manifest snapshot_id mismatch") r.Checks["run_manifest_snapshot_id"] = validateCheck{ OK: false, Expected: snapWantID, Actual: snapGotID, } } else { r.Checks["run_manifest_snapshot_id"] = validateCheck{ OK: true, Expected: snapWantID, Actual: snapGotID, } } if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA { r.OK = false r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch") r.Checks["run_manifest_snapshot_sha256"] = validateCheck{ OK: false, Expected: snapWantSHA, Actual: snapGotSHA, } } else if snapWantSHA != "" { r.Checks["run_manifest_snapshot_sha256"] = validateCheck{ OK: true, Expected: snapWantSHA, Actual: snapGotSHA, } } } statusLower := strings.ToLower(strings.TrimSpace(task.Status)) lifecycleOK := true details := "" switch statusLower { case "running": if rm.StartedAt.IsZero() { lifecycleOK = false details = "missing started_at for running task" } if !rm.EndedAt.IsZero() { lifecycleOK = false if details == "" { details = "ended_at must be empty for running task" } } if rm.ExitCode != nil { lifecycleOK = false if details == "" { details = "exit_code must be empty for running task" } } case "completed", "failed": if rm.StartedAt.IsZero() { lifecycleOK = false details = "missing started_at for completed/failed task" } if rm.EndedAt.IsZero() { lifecycleOK = false if details == "" { details = "missing ended_at for completed/failed task" } } if rm.ExitCode == nil { lifecycleOK = false if details == "" { details = "missing exit_code for completed/failed task" } } if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) { lifecycleOK = false if details == "" { details = "ended_at is before started_at" } } case "queued", "pending": // queued/pending tasks may not have started yet. if !rm.EndedAt.IsZero() || rm.ExitCode != nil { lifecycleOK = false details = "queued/pending task should not have ended_at/exit_code" } } if lifecycleOK { r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true} } else { chk := validateCheck{OK: false, Details: details} if shouldRequireRunManifest(task) { r.OK = false r.Errors = append(r.Errors, "run manifest lifecycle invalid") r.Checks["run_manifest_lifecycle"] = chk } else { r.Warnings = append(r.Warnings, "run manifest lifecycle invalid") r.Checks["run_manifest_lifecycle"] = chk } } } } want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"]) cur := "" if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil { cur = man.OverallSHA } if want == "" { r.OK = false r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha") r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur} } else if cur == "" { r.OK = false r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha") r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want} } else if want != cur { r.OK = false r.Errors = append(r.Errors, "experiment manifest overall sha mismatch") r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur} } else { r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur} } wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"]) wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) if wantDep == "" || wantDepSha == "" { r.OK = false r.Errors = append(r.Errors, "missing expected deps manifest provenance") r.Checks["expected_deps_manifest"] = validateCheck{OK: false} } else if depName != "" { sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName)) ok := (wantDep == depName && wantDepSha == sha) if !ok { r.OK = false r.Errors = append(r.Errors, "deps manifest provenance mismatch") r.Checks["expected_deps_manifest"] = validateCheck{ OK: false, Expected: wantDep + ":" + wantDepSha, Actual: depName + ":" + sha, } } else { r.Checks["expected_deps_manifest"] = validateCheck{ OK: true, Expected: wantDep + ":" + wantDepSha, Actual: depName + ":" + sha, } } } // Snapshot/dataset checks require dataDir. // TODO(context): Support snapshot stores beyond local filesystem (e.g. S3). // TODO(context): Validate snapshots by digest. if task.SnapshotID != "" { if h.dataDir == "" { r.OK = false r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot") r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"} } else { wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"]) if nerr != nil || wantSnap == "" { r.OK = false r.Errors = append(r.Errors, "missing/invalid snapshot_sha256") r.Checks["snapshot"] = validateCheck{OK: false} } else { curSnap, err := worker.DirOverallSHA256Hex( filepath.Join(h.dataDir, "snapshots", task.SnapshotID), ) if err != nil { r.OK = false r.Errors = append(r.Errors, "snapshot hash computation failed") r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()} } else if curSnap != wantSnap { r.OK = false r.Errors = append(r.Errors, "snapshot checksum mismatch") r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap} } else { r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap} } } } } if len(task.DatasetSpecs) > 0 { // TODO(context): Add dataset URI fetch/verification. // TODO(context): Currently only validates local materialized datasets. for _, ds := range task.DatasetSpecs { if ds.Checksum == "" { continue } key := "dataset:" + ds.Name if h.dataDir == "" { r.OK = false r.Errors = append( r.Errors, "api server data_dir not configured; cannot validate dataset checksums", ) r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"} continue } wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum) if nerr != nil || wantDS == "" { r.OK = false r.Errors = append(r.Errors, "invalid dataset checksum format") r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"} continue } curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name)) if err != nil { r.OK = false r.Errors = append(r.Errors, "dataset checksum computation failed") r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()} continue } if curDS != wantDS { r.OK = false r.Errors = append(r.Errors, "dataset checksum mismatch") r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS} continue } r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS} } } } payloadBytes, err := json.Marshal(r) if err != nil { return h.sendErrorPacket( conn, ErrorCodeUnknownError, "failed to serialize validate report", err.Error(), ) } return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes)) }