Move schema ownership to infrastructure layer: - Redis keys: config/constants.go -> queue/keys.go (TaskQueueKey, TaskPrefix, etc.) - Filesystem paths: config/paths.go -> storage/paths.go (JobPaths) - Create config/shared.go with RedisConfig, SSHConfig - Update all imports: worker/, api/helpers, api/ws_jobs, api/ws_validate - Clean up: remove duplicates from queue/task.go, queue/queue.go, config/paths.go Build status: Compiles successfully
523 lines
17 KiB
Go
523 lines
17 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
|
)
|
|
|
|
// validateCheck records the outcome of one named validation check inside a
// validateReport. Expected/Actual carry the two compared values when a
// mismatch is meaningful to show; Details carries a human-readable
// explanation of a failure. All fields other than OK are omitted from the
// JSON encoding when empty.
type validateCheck struct {
	OK       bool   `json:"ok"`
	Expected string `json:"expected,omitempty"`
	Actual   string `json:"actual,omitempty"`
	Details  string `json:"details,omitempty"`
}
|
|
|
|
// validateReport is the JSON payload returned to the client for a validate
// request. OK is the overall verdict: handleValidateRequest starts it at
// true and flips it to false whenever any required check fails. Checks is
// keyed by check name (e.g. "run_manifest", "snapshot", "dataset:<name>").
// Errors are fatal findings; Warnings are best-effort findings that do not
// flip OK. TS is the report generation time in RFC 3339 (nanosecond) UTC.
type validateReport struct {
	OK       bool                     `json:"ok"`
	CommitID string                   `json:"commit_id,omitempty"`
	TaskID   string                   `json:"task_id,omitempty"`
	Checks   map[string]validateCheck `json:"checks"`
	Errors   []string                 `json:"errors,omitempty"`
	Warnings []string                 `json:"warnings,omitempty"`
	TS       string                   `json:"ts"`
}
|
|
|
|
func shouldRequireRunManifest(task *queue.Task) bool {
|
|
if task == nil {
|
|
return false
|
|
}
|
|
s := strings.ToLower(strings.TrimSpace(task.Status))
|
|
switch s {
|
|
case "running", "completed", "failed":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// expectedRunManifestBucketForStatus maps a task status to the job
// directory bucket where its run manifest is expected to live:
// queued/pending -> "pending", running -> "running",
// completed/finished -> "finished", failed -> "failed".
// The second result is false for statuses with no known bucket.
// Matching is case-insensitive and ignores surrounding whitespace.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(status))
	var bucket string
	switch normalized {
	case "queued", "pending":
		bucket = "pending"
	case "running":
		bucket = "running"
	case "completed", "finished":
		bucket = "finished"
	case "failed":
		bucket = "failed"
	default:
		return "", false
	}
	return bucket, true
}
|
|
|
|
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
|
|
if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
|
|
return "", "", false
|
|
}
|
|
jobPaths := storage.NewJobPaths(basePath)
|
|
typedRoots := []struct {
|
|
bucket string
|
|
root string
|
|
}{
|
|
{bucket: "running", root: jobPaths.RunningPath()},
|
|
{bucket: "pending", root: jobPaths.PendingPath()},
|
|
{bucket: "finished", root: jobPaths.FinishedPath()},
|
|
{bucket: "failed", root: jobPaths.FailedPath()},
|
|
}
|
|
for _, item := range typedRoots {
|
|
root := item.root
|
|
dir := filepath.Join(root, jobName)
|
|
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
|
if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
|
|
return dir, item.bucket, true
|
|
}
|
|
}
|
|
}
|
|
return "", "", false
|
|
}
|
|
|
|
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
|
|
if task == nil {
|
|
return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
|
|
}
|
|
|
|
if task.CPU < 0 {
|
|
chk := validateCheck{OK: false, Details: "cpu must be >= 0"}
|
|
return chk, []string{"invalid cpu request"}
|
|
}
|
|
if task.MemoryGB < 0 {
|
|
chk := validateCheck{OK: false, Details: "memory_gb must be >= 0"}
|
|
return chk, []string{"invalid memory request"}
|
|
}
|
|
if task.GPU < 0 {
|
|
chk := validateCheck{OK: false, Details: "gpu must be >= 0"}
|
|
return chk, []string{"invalid gpu request"}
|
|
}
|
|
|
|
if strings.TrimSpace(task.GPUMemory) != "" {
|
|
s := strings.TrimSpace(task.GPUMemory)
|
|
if strings.HasSuffix(s, "%") {
|
|
v := strings.TrimSuffix(s, "%")
|
|
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
|
if err != nil || f <= 0 || f > 100 {
|
|
details := "gpu_memory must be a percentage in (0,100]"
|
|
chk := validateCheck{OK: false, Details: details}
|
|
return chk, []string{"invalid gpu_memory"}
|
|
}
|
|
} else {
|
|
f, err := strconv.ParseFloat(s, 64)
|
|
if err != nil || f <= 0 || f > 1 {
|
|
chk := validateCheck{OK: false, Details: "gpu_memory must be a fraction in (0,1]"}
|
|
return chk, []string{"invalid gpu_memory"}
|
|
}
|
|
}
|
|
}
|
|
|
|
if task.GPU == 0 && strings.TrimSpace(task.GPUMemory) != "" {
|
|
chk := validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"}
|
|
return chk, []string{"invalid gpu_memory"}
|
|
}
|
|
|
|
return validateCheck{OK: true}, nil
|
|
}
|
|
|
|
// handleValidateRequest services a binary "validate" request over the
// websocket. It authenticates the caller, resolves the validation target
// (a raw commit id or a task id whose metadata supplies the commit id),
// runs a series of integrity checks — commit id format, experiment
// manifest, deps manifest, run manifest location/identity/lifecycle, and
// snapshot/dataset checksums — and replies with a JSON validateReport
// data packet.
//
// Protocol- and auth-level failures are returned via sendErrorPacket;
// individual check failures are recorded inside the report instead
// (r.OK = false plus an entry in r.Errors/r.Checks).
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
	// target_type: 0=commit_id (20 bytes), 1=task_id (string)
	// TODO(context): Add a versioned validate protocol once we need more target types/fields.
	// Minimum frame is 18 bytes: 16-byte key hash + 1 type byte + 1 length byte.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
	}
	apiKeyHash := payload[:16]
	targetType := payload[16]
	idLen := int(payload[17])
	if idLen < 1 || len(payload) < 18+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
	}
	idBytes := payload[18 : 18+idLen]

	// Validate API key and user
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Invalid API key",
			err.Error(),
		)
	}
	// Permission checks only apply when auth is configured and enabled.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to validate jobs",
			"",
		)
	}
	// All manifest checks below need the experiment manager.
	if h.expManager == nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServiceUnavailable,
			"Experiment manager not available",
			"",
		)
	}

	// Resolve the validation target. For target_type 1 the looked-up task
	// also enables the task-specific checks further down; for target_type 0
	// task stays nil and only commit-level checks run.
	var task *queue.Task
	commitID := ""
	switch targetType {
	case 0:
		// Raw commit id: 20 bytes, rendered as 40 lowercase hex characters.
		if len(idBytes) != 20 {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
		}
		commitID = fmt.Sprintf("%x", idBytes)
	case 1:
		taskID := string(idBytes)
		if h.queue == nil {
			return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
		}
		t, err := h.queue.GetTask(taskID)
		if err != nil {
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
		}
		task = t
		// Non-admin callers may only validate tasks they own or created.
		if h.authConfig != nil &&
			h.authConfig.Enabled &&
			!user.Admin &&
			task.UserID != user.Name &&
			task.CreatedBy != user.Name {
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only validate your own jobs",
				"",
			)
		}
		// The task's metadata must carry the commit id it was submitted with.
		if task.Metadata == nil || task.Metadata["commit_id"] == "" {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
		}
		commitID = task.Metadata["commit_id"]
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
	}

	// The report starts optimistic; each failed required check flips r.OK
	// and appends to r.Errors, while best-effort findings go to r.Warnings.
	r := validateReport{
		OK:     true,
		TS:     time.Now().UTC().Format(time.RFC3339Nano),
		Checks: map[string]validateCheck{},
	}
	if task != nil {
		r.TaskID = task.ID
	}
	if commitID != "" {
		r.CommitID = commitID
	}

	// Validate commit id format
	if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok {
		r.OK = false
		r.Errors = append(r.Errors, errMsg)
	}

	// Experiment manifest integrity
	// Skipped when the commit id format check already failed (r.OK false).
	// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
	if r.OK {
		if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok {
			r.OK = false
			r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details}
			r.Errors = append(r.Errors, "experiment manifest validation failed")
		} else {
			r.Checks["experiment_manifest"] = validateCheck{OK: true}
		}
	}

	// Deps manifest presence + hash
	// The validateCheck(...) conversions bridge the helpers package's
	// structurally-identical check type into this package's type.
	// TODO(context): Allow client to declare which dependency manifest is authoritative.
	filesPath := h.expManager.GetFilesPath(commitID)
	depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID)
	if depErrs != nil {
		r.OK = false
		r.Checks["deps_manifest"] = validateCheck(depCheck)
		r.Errors = append(r.Errors, depErrs...)
	} else {
		r.Checks["deps_manifest"] = validateCheck(depCheck)
	}

	// Compare against expected task metadata if available.
	if task != nil {
		// Resource request sanity (cpu/memory/gpu/gpu_memory bounds).
		resCheck, resErrs := validateResourcesForTask(task)
		r.Checks["resources"] = resCheck
		if !resCheck.OK {
			r.OK = false
			r.Errors = append(r.Errors, resErrs...)
		}

		// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
		if err := container.ValidateJobName(task.JobName); err != nil {
			r.OK = false
			r.Errors = append(r.Errors, "invalid job name")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
		} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
			// No base path configured: an error if the manifest is required
			// for this status, otherwise only a warning.
			if shouldRequireRunManifest(task) {
				r.OK = false
				r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			} else {
				r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			}
		} else {
			manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
			if !found {
				// Absent manifest: required vs. best-effort, as above.
				if shouldRequireRunManifest(task) {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest not found")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				} else {
					r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				}
			} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
				r.OK = false
				r.Errors = append(r.Errors, "unable to read run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
			} else {
				r.Checks["run_manifest"] = validateCheck{OK: true}

				// The bucket the manifest was found in should match the
				// bucket implied by the task's status.
				expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status)
				if ok {
					if expectedBucket != manifestBucket {
						msg := "run manifest location mismatch"
						chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket}
						if shouldRequireRunManifest(task) {
							r.OK = false
							r.Errors = append(r.Errors, msg)
							r.Checks["run_manifest_location"] = chk
						} else {
							r.Warnings = append(r.Warnings, msg)
							r.Checks["run_manifest_location"] = chk
						}
					} else {
						r.Checks["run_manifest_location"] = validateCheck{
							OK:       true,
							Expected: expectedBucket,
							Actual:   manifestBucket,
						}
					}
				}

				// Validate task ID using helper
				taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID)
				r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck)
				if !taskIDCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest task_id mismatch")
				}

				// Validate commit ID using helper
				commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"])
				r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck)
				if !commitCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest commit_id mismatch")
				}

				// Validate deps provenance using helper: the deps manifest
				// name/sha recorded in the run manifest must match what the
				// task's metadata declared at submission.
				depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
				depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
				depGotName := strings.TrimSpace(rm.DepsManifestName)
				depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
				depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA)
				r.Checks["run_manifest_deps"] = validateCheck(depsCheck)
				if !depsCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
				}

				// Validate snapshot using helpers (only when the task
				// declared a snapshot): id and sha256 recorded in the run
				// manifest must match the task's values.
				if strings.TrimSpace(task.SnapshotID) != "" {
					snapWantID := strings.TrimSpace(task.SnapshotID)
					snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
					snapGotID := strings.TrimSpace(rm.SnapshotID)
					snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)

					snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID)
					r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck)
					if !snapIDCheck.OK {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
					}

					snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA)
					r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck)
					if !snapSHACheck.OK {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
					}
				}

				// Validate lifecycle using helper; lifecycle problems are
				// warnings unless the manifest is required for this status.
				lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status)
				if lifecycleOK {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
				} else {
					chk := validateCheck{OK: false, Details: details}
					if shouldRequireRunManifest(task) {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					} else {
						r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					}
				}
			}
		}

		// Expected experiment manifest overall SHA (recorded in the task's
		// metadata) must match the currently stored experiment manifest.
		want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
		cur := ""
		if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
			cur = man.OverallSHA
		}
		if want == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
		} else if cur == "" {
			r.OK = false
			r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
		} else if want != cur {
			r.OK = false
			r.Errors = append(r.Errors, "experiment manifest overall sha mismatch")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
		} else {
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
		}

		// Expected deps manifest provenance from task metadata, compared
		// against the deps manifest file actually stored for the commit.
		wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
		wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
		if wantDep == "" || wantDepSha == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected deps manifest provenance")
			r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
		} else if depName != "" {
			// Hash error is deliberately ignored: sha stays empty and the
			// comparison below then reports a mismatch.
			sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName))
			ok := (wantDep == depName && wantDepSha == sha)
			if !ok {
				r.OK = false
				r.Errors = append(r.Errors, "deps manifest provenance mismatch")
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       false,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			} else {
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       true,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			}
		}

		// Snapshot/dataset checks require dataDir.
		// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
		// TODO(context): Validate snapshots by digest.
		if task.SnapshotID != "" {
			if h.dataDir == "" {
				r.OK = false
				r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot")
				r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
			} else {
				wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
				if nerr != nil || wantSnap == "" {
					r.OK = false
					r.Errors = append(r.Errors, "missing/invalid snapshot_sha256")
					r.Checks["snapshot"] = validateCheck{OK: false}
				} else {
					// Recompute the snapshot's overall hash from
					// <dataDir>/snapshots/<snapshot_id> and compare.
					curSnap, err := worker.DirOverallSHA256Hex(
						filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
					)
					if err != nil {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot hash computation failed")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
					} else if curSnap != wantSnap {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot checksum mismatch")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
					} else {
						r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
					}
				}
			}
		}

		// Dataset checksum verification: one check per dataset spec that
		// declares a checksum; specs without one are skipped silently.
		if len(task.DatasetSpecs) > 0 {
			// TODO(context): Add dataset URI fetch/verification.
			// TODO(context): Currently only validates local materialized datasets.
			for _, ds := range task.DatasetSpecs {
				if ds.Checksum == "" {
					continue
				}
				key := "dataset:" + ds.Name
				if h.dataDir == "" {
					r.OK = false
					r.Errors = append(
						r.Errors,
						"api server data_dir not configured; cannot validate dataset checksums",
					)
					r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
					continue
				}
				wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
				if nerr != nil || wantDS == "" {
					r.OK = false
					r.Errors = append(r.Errors, "invalid dataset checksum format")
					r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
					continue
				}
				// Recompute the dataset hash from <dataDir>/<name> and compare.
				curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
				if err != nil {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum computation failed")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
					continue
				}
				if curDS != wantDS {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum mismatch")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
					continue
				}
				r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
			}
		}
	}

	// Serialize the report and ship it back as a "validate" data packet.
	payloadBytes, err := json.Marshal(r)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeUnknownError,
			"failed to serialize validate report",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}
|