From b05470b30af6dce899bbac7093eacfe98e9f9c26 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 16 Feb 2026 20:38:12 -0500 Subject: [PATCH] refactor: improve API structure and WebSocket protocol - Extract WebSocket protocol handling to dedicated module - Add helper functions for DB operations, validation, and responses - Improve WebSocket frame handling and opcodes - Refactor dataset, job, and Jupyter handlers - Add duplicate detection processing --- internal/api/handlers.go | 18 +- internal/api/helpers/db_helpers.go | 49 + internal/api/helpers/experiment_setup.go | 193 +++ internal/api/helpers/hash_helpers.go | 129 ++ internal/api/helpers/payload_parser.go | 121 ++ internal/api/helpers/response_helpers.go | 185 +++ internal/api/helpers/validation_helpers.go | 237 +++ internal/api/ws_datasets.go | 107 +- internal/api/ws_handler.go | 85 ++ internal/api/ws_jobs.go | 1532 +++++++------------- internal/api/ws_jupyter.go | 143 +- internal/api/ws_protocol.go | 43 + internal/api/ws_validate.go | 191 +-- 13 files changed, 1663 insertions(+), 1370 deletions(-) create mode 100644 internal/api/helpers/db_helpers.go create mode 100644 internal/api/helpers/experiment_setup.go create mode 100644 internal/api/helpers/hash_helpers.go create mode 100644 internal/api/helpers/payload_parser.go create mode 100644 internal/api/helpers/response_helpers.go create mode 100644 internal/api/helpers/validation_helpers.go create mode 100644 internal/api/ws_protocol.go diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 55bef74..0788ea3 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" + "github.com/jfraeys/fetch_ml/internal/api/helpers" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" @@ -62,9 +63,8 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) { "message": "Database status check not implemented", } - jsonBytes, _ := json.Marshal(response) w.WriteHeader(http.StatusOK) - if _, err := w.Write(jsonBytes); err != nil { + if _, err := w.Write(helpers.MarshalJSONOrEmpty(response)); err != nil { h.logger.Error("failed to write response", "error", err) } } @@ -105,13 +105,8 @@ func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) // listJupyterServices lists all Jupyter services func (h *Handlers) listJupyterServices(w http.ResponseWriter, _ *http.Request) { services := h.jupyterServiceMgr.ListServices() - jsonBytes, err := json.Marshal(services) - if err != nil { - http.Error(w, "Failed to marshal services", http.StatusInternalServerError) - return - } w.WriteHeader(http.StatusOK) - if _, err := w.Write(jsonBytes); err != nil { + if _, err := w.Write(helpers.MarshalJSONOrEmpty(services)); err != nil { h.logger.Error("failed to write response", "error", err) } } @@ -131,13 +126,8 @@ func (h *Handlers) startJupyterService(w http.ResponseWriter, r *http.Request) { return } - jsonBytes, err := json.Marshal(service) - if err != nil { - http.Error(w, "Failed to marshal service", http.StatusInternalServerError) - return - } w.WriteHeader(http.StatusCreated) - if _, err := w.Write(jsonBytes); err != nil { + if _, err := w.Write(helpers.MarshalJSONOrEmpty(service)); err != nil { h.logger.Error("failed to write response", "error", err) } } diff --git a/internal/api/helpers/db_helpers.go b/internal/api/helpers/db_helpers.go new file mode 100644 index 0000000..04213ad --- /dev/null +++ b/internal/api/helpers/db_helpers.go @@ -0,0 +1,49 @@ +// Package helpers provides shared utilities for WebSocket handlers. +package helpers + +import ( + "context" + "time" +) + +// DBContext provides a standard database operation context. +// It creates a context with the specified timeout and returns the context and cancel function. +func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), timeout) +} + +// DBContextShort returns a short-lived context for quick DB operations (3 seconds). +func DBContextShort() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 3*time.Second) +} + +// DBContextMedium returns a medium-lived context for standard DB operations (5 seconds). +func DBContextMedium() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 5*time.Second) +} + +// DBContextLong returns a long-lived context for complex DB operations (10 seconds). +func DBContextLong() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 10*time.Second) +} + +// StringSliceContains checks if a string slice contains a specific string. +func StringSliceContains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// StringSliceFilter filters a string slice based on a predicate. +func StringSliceFilter(slice []string, predicate func(string) bool) []string { + result := make([]string, 0) + for _, s := range slice { + if predicate(s) { + result = append(result, s) + } + } + return result +} diff --git a/internal/api/helpers/experiment_setup.go b/internal/api/helpers/experiment_setup.go new file mode 100644 index 0000000..2ce3d90 --- /dev/null +++ b/internal/api/helpers/experiment_setup.go @@ -0,0 +1,193 @@ +// Package helpers provides shared utilities for WebSocket handlers. +package helpers + +import ( + "context" + "fmt" + "time" + + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/storage" + "github.com/jfraeys/fetch_ml/internal/telemetry" +) + +// ExperimentSetupResult contains the result of experiment setup operations +type ExperimentSetupResult struct { + CommitIDStr string + Manifest *experiment.Manifest + Err error +} + +// RunExperimentSetup performs the common experiment setup operations: +// create experiment dir, write metadata, ensure minimal files, generate manifest. +// Returns the commitID string and any error that occurred. +func RunExperimentSetup( + logger *logging.Logger, + expMgr *experiment.Manager, + commitID []byte, + jobName string, + userName string, +) (string, error) { + commitIDStr := fmt.Sprintf("%x", commitID) + + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.create", 50*time.Millisecond, + func() (string, error) { return "", expMgr.CreateExperiment(commitIDStr) }, + ); err != nil { + logger.Error("failed to create experiment directory", "error", err) + return "", fmt.Errorf("failed to create experiment directory: %w", err) + } + + meta := &experiment.Metadata{ + CommitID: commitIDStr, + JobName: jobName, + User: userName, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.write_metadata", 50*time.Millisecond, + func() (string, error) { return "", expMgr.WriteMetadata(meta) }, + ); err != nil { + logger.Error("failed to save experiment metadata", "error", err) + return "", fmt.Errorf("failed to save experiment metadata: %w", err) + } + + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.ensure_minimal_files", 50*time.Millisecond, + func() (string, error) { return "", EnsureMinimalExperimentFiles(expMgr, commitIDStr) }, + ); err != nil { + logger.Error("failed to ensure minimal experiment files", "error", err) + return "", fmt.Errorf("failed to initialize experiment files: %w", err) + } + + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.generate_manifest", 100*time.Millisecond, + func() (string, error) { + manifest, err := expMgr.GenerateManifest(commitIDStr) + if err != nil { + return "", fmt.Errorf("failed to generate manifest: %w", err) + } + return "", expMgr.WriteManifest(manifest) + }, + ); err != nil { + logger.Error("failed to generate/write manifest", "error", err) + return "", fmt.Errorf("failed to generate content integrity manifest: %w", err) + } + + return commitIDStr, nil +} + +// RunExperimentSetupWithoutManifest performs experiment setup without manifest generation. +// Used for jobs with args/note where manifest generation is deferred. +func RunExperimentSetupWithoutManifest( + logger *logging.Logger, + expMgr *experiment.Manager, + commitID []byte, + jobName string, + userName string, +) (string, error) { + commitIDStr := fmt.Sprintf("%x", commitID) + + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.create", 50*time.Millisecond, + func() (string, error) { return "", expMgr.CreateExperiment(commitIDStr) }, + ); err != nil { + logger.Error("failed to create experiment directory", "error", err) + return "", fmt.Errorf("failed to create experiment directory: %w", err) + } + + meta := &experiment.Metadata{ + CommitID: commitIDStr, + JobName: jobName, + User: userName, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.write_metadata", 50*time.Millisecond, + func() (string, error) { return "", expMgr.WriteMetadata(meta) }, + ); err != nil { + logger.Error("failed to save experiment metadata", "error", err) + return "", fmt.Errorf("failed to save experiment metadata: %w", err) + } + + if _, err := telemetry.ExecWithMetrics( + logger, "experiment.ensure_minimal_files", 50*time.Millisecond, + func() (string, error) { return "", EnsureMinimalExperimentFiles(expMgr, commitIDStr) }, + ); err != nil { + logger.Error("failed to ensure minimal experiment files", "error", err) + return "", fmt.Errorf("failed to initialize experiment files: %w", err) + } + + return commitIDStr, nil +} + +// UpsertExperimentDBAsync upserts experiment data to the database asynchronously. +// This is a fire-and-forget operation that runs in a goroutine. +func UpsertExperimentDBAsync( + logger *logging.Logger, + db *storage.DB, + commitIDStr string, + jobName string, + userName string, +) { + if db == nil { + return + } + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + exp := &storage.Experiment{ID: commitIDStr, Name: jobName, Status: "pending", UserID: userName} + if _, err := telemetry.ExecWithMetrics(logger, "db.experiments.upsert", 50*time.Millisecond, + func() (string, error) { return "", db.UpsertExperiment(ctx, exp) }); err != nil { + logger.Error("failed to upsert experiment row", "error", err) + } + }() +} + +// TaskEnqueueResult contains the result of task enqueueing +type TaskEnqueueResult struct { + TaskID string + Err error +} + +// BuildTaskMetadata creates the standard task metadata map. +func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[string]string) map[string]string { + meta := map[string]string{ + "commit_id": commitIDStr, + "dataset_id": datasetID, + "params_hash": paramsHash, + } + for k, v := range prov { + if v != "" { + meta[k] = v + } + } + return meta +} + +// BuildSnapshotTaskMetadata creates task metadata for snapshot jobs. +func BuildSnapshotTaskMetadata(commitIDStr, snapshotSHA string, prov map[string]string) map[string]string { + meta := map[string]string{ + "commit_id": commitIDStr, + "snapshot_sha256": snapshotSHA, + } + for k, v := range prov { + if v != "" { + meta[k] = v + } + } + return meta +} + +// ApplyResourceRequest applies resource request to a task. +func ApplyResourceRequest(task *queue.Task, resources *ResourceRequest) { + if resources != nil { + task.CPU = resources.CPU + task.MemoryGB = resources.MemoryGB + task.GPU = resources.GPU + task.GPUMemory = resources.GPUMemory + } +} diff --git a/internal/api/helpers/hash_helpers.go b/internal/api/helpers/hash_helpers.go new file mode 100644 index 0000000..0fad0c4 --- /dev/null +++ b/internal/api/helpers/hash_helpers.go @@ -0,0 +1,129 @@ +// Package helpers provides shared utilities for WebSocket handlers. +package helpers + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/fileutil" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/worker" +) + +// ComputeDatasetID computes a dataset ID from dataset specs or dataset names. +func ComputeDatasetID(datasetSpecs []queue.DatasetSpec, datasets []string) string { + if len(datasetSpecs) > 0 { + var checksums []string + for _, ds := range datasetSpecs { + if ds.Checksum != "" { + checksums = append(checksums, ds.Checksum) + } else if ds.Name != "" { + checksums = append(checksums, ds.Name) + } + } + if len(checksums) > 0 { + h := sha256.New() + for _, cs := range checksums { + h.Write([]byte(cs)) + } + return hex.EncodeToString(h.Sum(nil))[:16] + } + } + if len(datasets) > 0 { + h := sha256.New() + for _, ds := range datasets { + h.Write([]byte(ds)) + } + return hex.EncodeToString(h.Sum(nil))[:16] + } + return "" +} + +// ComputeParamsHash computes a hash of the args string. +func ComputeParamsHash(args string) string { + if strings.TrimSpace(args) == "" { + return "" + } + h := sha256.New() + h.Write([]byte(strings.TrimSpace(args))) + return hex.EncodeToString(h.Sum(nil))[:16] +} + +// FileSHA256Hex computes the SHA256 hash of a file. +func FileSHA256Hex(path string) (string, error) { + f, err := os.Open(filepath.Clean(path)) + if err != nil { + return "", err + } + defer func() { _ = f.Close() }() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +// ExpectedProvenanceForCommit computes expected provenance metadata for a commit. +func ExpectedProvenanceForCommit( + expMgr *experiment.Manager, + commitID string, +) (map[string]string, error) { + out := map[string]string{} + manifest, err := expMgr.ReadManifest(commitID) + if err != nil { + return nil, err + } + if manifest == nil || manifest.OverallSHA == "" { + return nil, fmt.Errorf("missing manifest overall_sha") + } + out["experiment_manifest_overall_sha"] = manifest.OverallSHA + + filesPath := expMgr.GetFilesPath(commitID) + depName, err := worker.SelectDependencyManifest(filesPath) + if err == nil && strings.TrimSpace(depName) != "" { + depPath := filepath.Join(filesPath, depName) + sha, err := FileSHA256Hex(depPath) + if err == nil && strings.TrimSpace(sha) != "" { + out["deps_manifest_name"] = depName + out["deps_manifest_sha256"] = sha + } + } + return out, nil +} + +// EnsureMinimalExperimentFiles ensures minimal experiment files exist. +func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error { + if expMgr == nil { + return fmt.Errorf("missing experiment manager") + } + commitID = strings.TrimSpace(commitID) + if commitID == "" { + return fmt.Errorf("missing commit id") + } + filesPath := expMgr.GetFilesPath(commitID) + if err := os.MkdirAll(filesPath, 0750); err != nil { + return err + } + + trainPath := filepath.Join(filesPath, "train.py") + if _, err := os.Stat(trainPath); os.IsNotExist(err) { + if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil { + return err + } + } + + reqPath := filepath.Join(filesPath, "requirements.txt") + if _, err := os.Stat(reqPath); os.IsNotExist(err) { + if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil { + return err + } + } + + return nil +} diff --git a/internal/api/helpers/payload_parser.go b/internal/api/helpers/payload_parser.go new file mode 100644 index 0000000..3bc0af9 --- /dev/null +++ b/internal/api/helpers/payload_parser.go @@ -0,0 +1,121 @@ +// Package helpers provides shared utilities for WebSocket handlers. +package helpers + +import ( + "encoding/binary" + "fmt" +) + +// PayloadParser provides helpers for parsing binary WebSocket payloads. +type PayloadParser struct { + payload []byte + offset int +} + +// NewPayloadParser creates a new payload parser starting after the API key hash. +func NewPayloadParser(payload []byte, apiKeyHashLen int) *PayloadParser { + return &PayloadParser{ + payload: payload, + offset: apiKeyHashLen, + } +} + +// ParseByte parses a single byte and advances the offset. +func (p *PayloadParser) ParseByte() (byte, error) { + if p.offset >= len(p.payload) { + return 0, fmt.Errorf("payload too short at offset %d", p.offset) + } + b := p.payload[p.offset] + p.offset++ + return b, nil +} + +// ParseUint16 parses a 2-byte big-endian uint16 and advances the offset. +func (p *PayloadParser) ParseUint16() (uint16, error) { + if p.offset+2 > len(p.payload) { + return 0, fmt.Errorf("payload too short for uint16 at offset %d", p.offset) + } + v := binary.BigEndian.Uint16(p.payload[p.offset : p.offset+2]) + p.offset += 2 + return v, nil +} + +// ParseLengthPrefixedString parses a length-prefixed string. +// Format: [length:1][string:var] +func (p *PayloadParser) ParseLengthPrefixedString() (string, error) { + if p.offset >= len(p.payload) { + return "", fmt.Errorf("payload too short for length at offset %d", p.offset) + } + length := int(p.payload[p.offset]) + p.offset++ + if length < 0 { + return "", fmt.Errorf("invalid negative length at offset %d", p.offset-1) + } + if p.offset+length > len(p.payload) { + return "", fmt.Errorf("payload too short for string of length %d at offset %d", length, p.offset) + } + str := string(p.payload[p.offset : p.offset+length]) + p.offset += length + return str, nil +} + +// ParseUint16PrefixedString parses a string prefixed by a 2-byte length. +// Format: [length:2][string:var] +func (p *PayloadParser) ParseUint16PrefixedString() (string, error) { + if p.offset+2 > len(p.payload) { + return "", fmt.Errorf("payload too short for uint16 length at offset %d", p.offset) + } + length := int(binary.BigEndian.Uint16(p.payload[p.offset : p.offset+2])) + p.offset += 2 + if length < 0 { + return "", fmt.Errorf("invalid negative length at offset %d", p.offset-2) + } + if p.offset+length > len(p.payload) { + return "", fmt.Errorf("payload too short for string of length %d at offset %d", length, p.offset) + } + str := string(p.payload[p.offset : p.offset+length]) + p.offset += length + return str, nil +} + +// Payload returns the underlying payload bytes. +func (p *PayloadParser) Payload() []byte { + return p.payload +} + +// Offset returns the current offset into the payload. +func (p *PayloadParser) Offset() int { + return p.offset +} + +// HasRemaining returns true if there are remaining bytes. +func (p *PayloadParser) HasRemaining() bool { + return p.offset < len(p.payload) +} + +// Remaining returns the remaining bytes in the payload from current offset. +func (p *PayloadParser) Remaining() []byte { + if p.offset >= len(p.payload) { + return nil + } + return p.payload[p.offset:] +} + +// ParseBool parses a byte as a boolean (0 = false, non-zero = true). +func (p *PayloadParser) ParseBool() (bool, error) { + b, err := p.ParseByte() + if err != nil { + return false, err + } + return b != 0, nil +} + +// ParseFixedBytes parses a fixed-length byte slice. +func (p *PayloadParser) ParseFixedBytes(length int) ([]byte, error) { + if p.offset+length > len(p.payload) { + return nil, fmt.Errorf("payload too short for %d bytes at offset %d", length, p.offset) + } + bytes := p.payload[p.offset : p.offset+length] + p.offset += length + return bytes, nil +} diff --git a/internal/api/helpers/response_helpers.go b/internal/api/helpers/response_helpers.go new file mode 100644 index 0000000..40ec0a1 --- /dev/null +++ b/internal/api/helpers/response_helpers.go @@ -0,0 +1,185 @@ +// Package helpers provides shared utilities for WebSocket handlers. +package helpers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// ErrorCode represents WebSocket error codes +type ErrorCode byte + +// TaskErrorMapper maps task errors to error codes +type TaskErrorMapper struct{} + +// NewTaskErrorMapper creates a new task error mapper +func NewTaskErrorMapper() *TaskErrorMapper { + return &TaskErrorMapper{} +} + +// MapError maps a task error to an error code based on status and error message +func (m *TaskErrorMapper) MapError(t *queue.Task, defaultCode ErrorCode) ErrorCode { + if t == nil { + return defaultCode + } + status := strings.ToLower(strings.TrimSpace(t.Status)) + errStr := strings.ToLower(strings.TrimSpace(t.Error)) + + if status == "cancelled" { + return 0x24 // ErrorCodeJobCancelled + } + if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") { + return 0x30 // ErrorCodeOutOfMemory + } + if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") { + return 0x31 // ErrorCodeDiskFull + } + if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") { + return 0x33 // ErrorCodeServiceUnavailable + } + if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") { + return 0x14 // ErrorCodeTimeout + } + if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") { + return 0x12 // ErrorCodeNetworkError + } + if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") { + return 0x32 // ErrorCodeInvalidConfiguration + } + + // Default for worker-side execution failures + if status == "failed" { + return 0x23 // ErrorCodeJobExecutionFailed + } + return defaultCode +} + +// MapJupyterError maps Jupyter task errors to error codes +func (m *TaskErrorMapper) MapJupyterError(t *queue.Task) ErrorCode { + if t == nil { + return 0x00 // ErrorCodeUnknownError + } + status := strings.ToLower(strings.TrimSpace(t.Status)) + errStr := strings.ToLower(strings.TrimSpace(t.Error)) + + if status == "cancelled" { + return 0x24 // ErrorCodeJobCancelled + } + if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") { + return 0x30 // ErrorCodeOutOfMemory + } + if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") { + return 0x31 // ErrorCodeDiskFull + } + if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") { + return 0x33 // ErrorCodeServiceUnavailable + } + if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") { + return 0x14 // ErrorCodeTimeout + } + if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") { + return 0x12 // ErrorCodeNetworkError + } + if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") { + return 0x32 // ErrorCodeInvalidConfiguration + } + + // Default for worker-side execution failures + if status == "failed" { + return 0x23 // ErrorCodeJobExecutionFailed + } + return 0x00 // ErrorCodeUnknownError +} + +// ResourceRequest represents resource requirements +type ResourceRequest struct { + CPU int + MemoryGB int + GPU int + GPUMemory string +} + +// ParseResourceRequest parses an optional resource request from bytes. +// Format: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] +// If payload is empty, returns nil. +func ParseResourceRequest(payload []byte) (*ResourceRequest, error) { + if len(payload) == 0 { + return nil, nil + } + if len(payload) < 4 { + return nil, fmt.Errorf("resource payload too short") + } + cpu := int(payload[0]) + mem := int(payload[1]) + gpu := int(payload[2]) + gpuMemLen := int(payload[3]) + if gpuMemLen < 0 || len(payload) < 4+gpuMemLen { + return nil, fmt.Errorf("invalid gpu memory length") + } + gpuMem := "" + if gpuMemLen > 0 { + gpuMem = string(payload[4 : 4+gpuMemLen]) + } + return &ResourceRequest{CPU: cpu, MemoryGB: mem, GPU: gpu, GPUMemory: gpuMem}, nil +} + +// JSONResponseBuilder helps build JSON data responses +type JSONResponseBuilder struct { + data interface{} +} + +// NewJSONResponseBuilder creates a new JSON response builder +func NewJSONResponseBuilder(data interface{}) *JSONResponseBuilder { + return &JSONResponseBuilder{data: data} +} + +// Build marshals the data to JSON +func (b *JSONResponseBuilder) Build() ([]byte, error) { + return json.Marshal(b.data) +} + +// BuildOrEmpty marshals the data to JSON or returns empty array on error +func (b *JSONResponseBuilder) BuildOrEmpty() []byte { + data, err := json.Marshal(b.data) + if err != nil { + return []byte("[]") + } + return data +} + +// StringPtr returns a pointer to a string +func StringPtr(s string) *string { + return &s +} + +// IntPtr returns a pointer to an int +func IntPtr(i int) *int { + return &i +} + +// MarshalJSONOrEmpty marshals data to JSON or returns empty array on error +func MarshalJSONOrEmpty(data interface{}) []byte { + b, err := json.Marshal(data) + if err != nil { + return []byte("[]") + } + return b +} + +// MarshalJSONBytes marshals data to JSON bytes with error handling +func MarshalJSONBytes(data interface{}) ([]byte, error) { + return json.Marshal(data) +} + +// IsEmptyJSON checks if JSON data is empty or "null" +func IsEmptyJSON(data []byte) bool { + if len(data) == 0 { + return true + } + // Check for "null", "[]", "{}" or empty after trimming + s := strings.TrimSpace(string(data)) + return s == "" || s == "null" || s == "[]" || s == "{}" +} diff --git a/internal/api/helpers/validation_helpers.go b/internal/api/helpers/validation_helpers.go new file mode 100644 index 0000000..78404ec --- /dev/null +++ b/internal/api/helpers/validation_helpers.go @@ -0,0 +1,237 @@ +// Package helpers provides validation utilities for WebSocket handlers. +package helpers + +import ( + "encoding/hex" + "os" + "path/filepath" + "strings" + + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/manifest" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/worker" +) + +// ValidateCommitIDFormat validates the commit ID format (40 hex chars) +func ValidateCommitIDFormat(commitID string) (ok bool, errMsg string) { + if len(commitID) != 40 { + return false, "invalid commit_id length" + } + if _, err := hex.DecodeString(commitID); err != nil { + return false, "invalid commit_id hex" + } + return true, "" +} + +// ValidateExperimentManifest validates the experiment manifest integrity +func ValidateExperimentManifest(expMgr *experiment.Manager, commitID string) (ok bool, details string) { + if err := expMgr.ValidateManifest(commitID); err != nil { + return false, err.Error() + } + return true, "" +} + +// ValidateDepsManifest validates the dependency manifest presence and hash +func ValidateDepsManifest( + expMgr *experiment.Manager, + commitID string, +) (depName string, check ValidateCheck, errMsgs []string) { + filesPath := expMgr.GetFilesPath(commitID) + depName, depErr := worker.SelectDependencyManifest(filesPath) + if depErr != nil { + return "", ValidateCheck{OK: false, Details: depErr.Error()}, []string{"deps manifest missing"} + } + + sha, err := FileSHA256Hex(filepath.Join(filesPath, depName)) + if err != nil { + return depName, ValidateCheck{OK: false, Details: err.Error()}, []string{"deps manifest hash failed"} + } + return depName, ValidateCheck{OK: true, Actual: depName + ":" + sha}, nil +} + +// ValidateCheck represents a validation check result +type ValidateCheck struct { + OK bool `json:"ok"` + Expected string `json:"expected,omitempty"` + Actual string `json:"actual,omitempty"` + Details string `json:"details,omitempty"` +} + +// ValidateReport represents a validation report +type ValidateReport struct { + OK bool `json:"ok"` + CommitID string `json:"commit_id,omitempty"` + TaskID string `json:"task_id,omitempty"` + Checks map[string]ValidateCheck `json:"checks"` + Errors []string `json:"errors,omitempty"` + Warnings []string `json:"warnings,omitempty"` + TS string `json:"ts"` +} + +// NewValidateReport creates a new validation report +func NewValidateReport() ValidateReport { + return ValidateReport{ + OK: true, + Checks: map[string]ValidateCheck{}, + } +} + +// ShouldRequireRunManifest returns true if run manifest should be required for the given status +func ShouldRequireRunManifest(task *queue.Task) bool { + if task == nil { + return false + } + s := strings.ToLower(strings.TrimSpace(task.Status)) + switch s { + case "running", "completed", "failed": + return true + default: + return false + } +} + +// ExpectedRunManifestBucketForStatus returns the expected bucket for a given status +func ExpectedRunManifestBucketForStatus(status string) (string, bool) { + s := strings.ToLower(strings.TrimSpace(status)) + switch s { + case "queued", "pending": + return "pending", true + case "running": + return "running", true + case "completed", "finished": + return "finished", true + case "failed": + return "failed", true + default: + return "", false + } +} + +// FindRunManifestDir finds the run manifest directory for a job +func FindRunManifestDir(basePath string, jobName string) (dir string, bucket string, found bool) { + if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" { + return "", "", false + } + jobPaths := config.NewJobPaths(basePath) + typedRoots := []struct { + bucket string + root string + }{ + {bucket: "running", root: jobPaths.RunningPath()}, + {bucket: "pending", root: jobPaths.PendingPath()}, + {bucket: "finished", root: jobPaths.FinishedPath()}, + {bucket: "failed", root: jobPaths.FailedPath()}, + } + for _, item := range typedRoots { + dir := filepath.Join(item.root, jobName) + if info, err := os.Stat(dir); err == nil && info.IsDir() { + if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { + return dir, item.bucket, true + } + } + } + return "", "", false +} + +// ValidateRunManifestLifecycle validates the run manifest lifecycle fields +func ValidateRunManifestLifecycle(rm *manifest.RunManifest, status string) (ok bool, details string) { + statusLower := strings.ToLower(strings.TrimSpace(status)) + + switch statusLower { + case "running": + if rm.StartedAt.IsZero() { + return false, "missing started_at for running task" + } + if !rm.EndedAt.IsZero() { + return false, "ended_at must be empty for running task" + } + if rm.ExitCode != nil { + return false, "exit_code must be empty for running task" + } + case "completed", "failed": + if rm.StartedAt.IsZero() { + return false, "missing started_at for completed/failed task" + } + if rm.EndedAt.IsZero() { + return false, "missing ended_at for completed/failed task" + } + if rm.ExitCode == nil { + return false, "missing exit_code for completed/failed task" + } + if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) { + return false, "ended_at is before started_at" + } + case "queued", "pending": + // queued/pending tasks may not have started yet. + if !rm.EndedAt.IsZero() || rm.ExitCode != nil { + return false, "queued/pending task should not have ended_at/exit_code" + } + } + return true, "" +} + +// ValidateTaskIDMatch validates the task ID in the run manifest matches the expected task +func ValidateTaskIDMatch(rm *manifest.RunManifest, expectedTaskID string) ValidateCheck { + if strings.TrimSpace(rm.TaskID) == "" { + return ValidateCheck{OK: false, Expected: expectedTaskID} + } + if rm.TaskID != expectedTaskID { + return ValidateCheck{OK: false, Expected: expectedTaskID, Actual: rm.TaskID} + } + return ValidateCheck{OK: true, Expected: expectedTaskID, Actual: rm.TaskID} +} + +// ValidateCommitIDMatch validates the commit ID in the run manifest matches the expected commit +func ValidateCommitIDMatch(rmCommitID, expectedCommitID string) ValidateCheck { + want := strings.TrimSpace(expectedCommitID) + got := strings.TrimSpace(rmCommitID) + if want != "" && got != "" && want != got { + return ValidateCheck{OK: false, Expected: want, Actual: got} + } + if want != "" { + return ValidateCheck{OK: true, Expected: want, Actual: got} + } + return ValidateCheck{OK: true} +} + +// ValidateDepsProvenance validates the dependency manifest provenance +func ValidateDepsProvenance(wantName, wantSHA, gotName, gotSHA string) ValidateCheck { + if wantName == "" || wantSHA == "" || gotName == "" || gotSHA == "" { + return ValidateCheck{OK: true} + } + expected := wantName + ":" + wantSHA + actual := gotName + ":" + gotSHA + if wantName != gotName || wantSHA != gotSHA { + return ValidateCheck{OK: false, Expected: expected, Actual: actual} + } + return ValidateCheck{OK: true, Expected: expected, Actual: actual} +} + +// ValidateSnapshotID validates the snapshot ID in the run manifest +func ValidateSnapshotID(wantID, gotID string) ValidateCheck { + if wantID == "" || gotID == "" { + return ValidateCheck{OK: true, Expected: wantID, Actual: gotID} + } + if wantID != gotID { + return ValidateCheck{OK: false, Expected: wantID, Actual: gotID} + } + return ValidateCheck{OK: true, Expected: wantID, Actual: gotID} +} + +// ValidateSnapshotSHA validates the snapshot SHA in the run manifest +func ValidateSnapshotSHA(wantSHA, gotSHA string) ValidateCheck { + if wantSHA == "" || gotSHA == "" { + return ValidateCheck{OK: true, Expected: wantSHA, Actual: gotSHA} + } + if wantSHA != gotSHA { + return ValidateCheck{OK: false, Expected: wantSHA, Actual: gotSHA} + } + return ValidateCheck{OK: true, Expected: wantSHA, Actual: gotSHA} +} + +// ContainerStat is a function type for stat operations (for mocking in tests) +var ContainerStat = func(path string) (os.FileInfo, error) { + return os.Stat(path) +} diff --git a/internal/api/ws_datasets.go b/internal/api/ws_datasets.go index db2e50e..e347cdb 100644 --- a/internal/api/ws_datasets.go +++ b/internal/api/ws_datasets.go @@ -1,40 +1,30 @@ package api import ( - "context" "database/sql" "encoding/binary" "encoding/json" "net/url" "strings" - "time" "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/api/helpers" "github.com/jfraeys/fetch_ml/internal/storage" ) func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16] - if len(payload) < 16 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinDatasetList) + if err != nil { + return err } - apiKeyHash := payload[:16] - - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } + if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil { + return err } - if h.db == nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + if err := h.requireDB(conn); err != nil { + return err } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := helpers.DBContextShort() defer cancel() datasets, err := h.db.ListDatasets(ctx, 0) @@ -55,26 +45,18 @@ func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) erro } func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var] - if len(payload) < 16+1+2 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinDatasetRegister) + if err != nil { + return err } - apiKeyHash := payload[:16] - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } + if err := h.requirePermission(user, PermDatasetsCreate, conn); err != nil { + return err } - if h.db == nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + if err := h.requireDB(conn); err != nil { + return err } - offset := 16 + offset := ProtocolAPIKeyHashLen nameLen := int(payload[offset]) offset++ if nameLen <= 0 || len(payload) < offset+nameLen+2 { @@ -90,7 +72,6 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) } urlStr := string(payload[offset : offset+urlLen]) - // Minimal validation (server-side authoritative): name non-empty and url parseable. if strings.TrimSpace(name) == "" { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "") } @@ -98,7 +79,7 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "") } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := helpers.DBContextShort() defer cancel() if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil { @@ -108,26 +89,18 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) } func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][name_len:1][name:var] - if len(payload) < 16+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinDatasetInfo) + if err != nil { + return err } - apiKeyHash := payload[:16] - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } + if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil { + return err } - if h.db == nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + if err := h.requireDB(conn); err != nil { + return err } - offset := 16 + offset := ProtocolAPIKeyHashLen nameLen := int(payload[offset]) offset++ if nameLen <= 0 || len(payload) < offset+nameLen { @@ -135,7 +108,7 @@ func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) erro } name := string(payload[offset : offset+nameLen]) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := helpers.DBContextShort() defer cancel() ds, err := h.db.GetDataset(ctx, name) @@ -159,26 +132,18 @@ func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) erro } func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][term_len:1][term:var] - if len(payload) < 16+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinDatasetSearch) + if err != nil { + return err } - apiKeyHash := payload[:16] - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } + if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil { + return err } - if h.db == nil { - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + if err := h.requireDB(conn); err != nil { + return err } - offset := 16 + offset := ProtocolAPIKeyHashLen termLen := int(payload[offset]) offset++ if termLen < 0 || len(payload) < offset+termLen { @@ -187,7 +152,7 @@ func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) er term := string(payload[offset : offset+termLen]) term = strings.TrimSpace(term) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := helpers.DBContextShort() defer cancel() datasets, err := h.db.SearchDatasets(ctx, term, 0) diff --git a/internal/api/ws_handler.go b/internal/api/ws_handler.go index e2b5612..3365be4 100644 --- a/internal/api/ws_handler.go +++ b/internal/api/ws_handler.go @@ -46,6 +46,10 @@ const ( OpcodeListJupyter = 0x0F OpcodeListJupyterPackages = 0x1E OpcodeValidateRequest = 0x16 + + // Logs opcodes + OpcodeGetLogs = 0x20 + OpcodeStreamLogs = 0x21 ) // createUpgrader creates a WebSocket upgrader with the given security configuration @@ -288,7 +292,88 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { return h.handleListJupyterPackages(conn, payload) case OpcodeValidateRequest: return h.handleValidateRequest(conn, payload) + case OpcodeGetLogs: + return h.handleGetLogs(conn, payload) + case OpcodeStreamLogs: + return h.handleStreamLogs(conn, payload) default: return fmt.Errorf("unknown opcode: 0x%02x", opcode) } } + +// AuthHandler is a handler function that receives an authenticated user +type AuthHandler func(conn *websocket.Conn, payload []byte, user *auth.User) error + +// authenticate validates the API key from raw payload and returns the user +func (h *WSHandler) authenticate(conn *websocket.Conn, payload []byte, minLen int) (*auth.User, error) { + if len(payload) < minLen { + return nil, h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") + } + + apiKeyHash := payload[:16] + + if h.authConfig != nil { + user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + return user, nil + } + + return &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + }, nil +} + +// authenticateWithHash validates a pre-extracted API key hash +func (h *WSHandler) authenticateWithHash(conn *websocket.Conn, apiKeyHash []byte) (*auth.User, error) { + if h.authConfig != nil { + user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + return user, nil + } + + return &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + }, nil +} + +// requirePermission checks if the user has the required permission +func (h *WSHandler) requirePermission( + user *auth.User, + permission string, + conn *websocket.Conn, +) error { + if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission(permission) { + h.logger.Error("insufficient permissions", "user", user.Name, "required", permission) + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + fmt.Sprintf("Insufficient permissions: %s", permission), + "", + ) + } + return nil +} + +// requireDB checks if the database is configured +func (h *WSHandler) requireDB(conn *websocket.Conn) error { + if h.db == nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + } + return nil +} diff --git a/internal/api/ws_jobs.go b/internal/api/ws_jobs.go index 140529f..919c84c 100644 --- a/internal/api/ws_jobs.go +++ b/internal/api/ws_jobs.go @@ -1,13 +1,9 @@ package api import ( - "context" - "crypto/sha256" "encoding/binary" - "encoding/hex" "encoding/json" "fmt" - "io" "math" "os" "path/filepath" @@ -17,16 +13,14 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/api/helpers" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" - "github.com/jfraeys/fetch_ml/internal/experiment" - "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/telemetry" - "github.com/jfraeys/fetch_ml/internal/worker" ) func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error { @@ -35,7 +29,6 @@ func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) erro return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "") } - apiKeyHash := payload[:16] offset := 16 jobNameLen := int(payload[offset]) @@ -61,30 +54,12 @@ func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) erro } note := string(payload[offset : offset+noteLen]) - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + user, err := h.authenticate(conn, payload, 16) + if err != nil { + return err } - - // Permission model: if auth is enabled, require jobs:update to mutate shared run artifacts. - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to annotate runs", "") + if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { + return err } if err := container.ValidateJobName(jobName); err != nil { @@ -97,9 +72,7 @@ func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) erro } jobPaths := config.NewJobPaths(base) - typedRoots := []struct { - root string - }{ + typedRoots := []struct{ root string }{ {root: jobPaths.RunningPath()}, {root: jobPaths.PendingPath()}, {root: jobPaths.FinishedPath()}, @@ -125,7 +98,6 @@ func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) erro return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) } - // Default author to authenticated user if empty. if strings.TrimSpace(author) == "" { author = user.Name } @@ -143,7 +115,6 @@ func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "") } - apiKeyHash := payload[:16] offset := 16 jobNameLen := int(payload[offset]) @@ -161,28 +132,12 @@ func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) } patchJSON := payload[offset : offset+patchLen] - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + user, err := h.authenticate(conn, payload, 16) + if err != nil { + return err } - - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to update run narrative", "") + if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { + return err } if err := container.ValidateJobName(jobName); err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) @@ -232,83 +187,12 @@ func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated")) } -func fileSHA256Hex(path string) (string, error) { - f, err := os.Open(filepath.Clean(path)) - if err != nil { - return "", err - } - defer func() { _ = f.Close() }() - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return "", err - } - return hex.EncodeToString(h.Sum(nil)), nil -} - -func expectedProvenanceForCommit( - expMgr *experiment.Manager, - commitID string, -) (map[string]string, error) { - out := map[string]string{} - manifest, err := expMgr.ReadManifest(commitID) - if err != nil { - return nil, err - } - if manifest == nil || manifest.OverallSHA == "" { - return nil, fmt.Errorf("missing manifest overall_sha") - } - out["experiment_manifest_overall_sha"] = manifest.OverallSHA - - filesPath := expMgr.GetFilesPath(commitID) - depName, err := worker.SelectDependencyManifest(filesPath) - if err == nil && strings.TrimSpace(depName) != "" { - depPath := filepath.Join(filesPath, depName) - sha, err := fileSHA256Hex(depPath) - if err == nil && strings.TrimSpace(sha) != "" { - out["deps_manifest_name"] = depName - out["deps_manifest_sha256"] = sha - } - } - return out, nil -} - -func ensureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error { - if expMgr == nil { - return fmt.Errorf("missing experiment manager") - } - commitID = strings.TrimSpace(commitID) - if commitID == "" { - return fmt.Errorf("missing commit id") - } - filesPath := expMgr.GetFilesPath(commitID) - if err := os.MkdirAll(filesPath, 0750); err != nil { - return err - } - - trainPath := filepath.Join(filesPath, "train.py") - if _, err := os.Stat(trainPath); os.IsNotExist(err) { - if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil { - return err - } - } - - reqPath := filepath.Join(filesPath, "requirements.txt") - if _, err := os.Stat(reqPath); os.IsNotExist(err) { - if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil { - return err - } - } - - return nil -} - func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] - if len(payload) < 38 { + // Parse payload first + if len(payload) < ProtocolMinQueueJob { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "") } - apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) jobNameLen := int(payload[37]) @@ -318,180 +202,30 @@ func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error { } jobName := string(payload[38 : 38+jobNameLen]) - resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:]) if resErr != nil { - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - "invalid resource request", - resErr.Error(), - ) + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) } - h.logger.Info("queue job request", - "job", jobName, - "priority", priority, - "commit_id", fmt.Sprintf("%x", commitID), - ) + h.logger.Info("queue job request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID)) - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + // Authenticate and authorize + user, err := h.authenticate(conn, payload, ProtocolMinQueueJob) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { + return err } - // Check user permissions - if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { - h.logger.Info( - "job queued", - "job", jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), - "user", user.Name, - ) - } else { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to create jobs", - "", - ) - } - - // Create experiment directory and metadata (optimized) - if _, err := telemetry.ExecWithMetrics( - h.logger, - "experiment.create", - 50*time.Millisecond, - func() (string, error) { - return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID)) - }, - ); err != nil { - h.logger.Error("failed to create experiment directory", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to create experiment directory", - err.Error(), - ) - } - - meta := &experiment.Metadata{ - CommitID: fmt.Sprintf("%x", commitID), - JobName: jobName, - User: user.Name, - Timestamp: time.Now().Unix(), - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.WriteMetadata(meta) - }); err != nil { - h.logger.Error("failed to save experiment metadata", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to save experiment metadata", - err.Error(), - ) - } - - // Generate and write content integrity manifest - commitIDStr := fmt.Sprintf("%x", commitID) - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { - return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) - }); err != nil { - h.logger.Error("failed to ensure minimal experiment files", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to initialize experiment files", - err.Error(), - ) - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) { - manifest, err := h.expManager.GenerateManifest(commitIDStr) - if err != nil { - return "", fmt.Errorf("failed to generate manifest: %w", err) - } - if err := h.expManager.WriteManifest(manifest); err != nil { - return "", fmt.Errorf("failed to write manifest: %w", err) - } - return "", nil - }); err != nil { - h.logger.Error("failed to generate/write manifest", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to generate content integrity manifest", - err.Error(), - ) - } - - // Add user info to experiment metadata (deferred for performance) - go func() { - if h.db != nil { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - exp := &storage.Experiment{ - ID: fmt.Sprintf("%x", commitID), - Name: jobName, - Status: "pending", - UserID: user.Name, - } - if _, err := telemetry.ExecWithMetrics( - h.logger, - "db.experiments.upsert", - 50*time.Millisecond, - func() (string, error) { - return "", h.db.UpsertExperiment(ctx, exp) - }, - ); err != nil { - h.logger.Error("failed to upsert experiment row", "error", err) - } - } - - }() - - h.logger.Info( - "job queued", - "job", jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), - "user", user.Name, - ) - - return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, nil, resources) + return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, nil, resources) } func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { - if len(payload) < 40 { - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - "queue job with snapshot payload too short", - "", - ) + if len(payload) < ProtocolMinQueueJobWithSnapshot { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with snapshot payload too short", "") } - apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) jobNameLen := int(payload[37]) @@ -520,189 +254,32 @@ func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []b resources, resErr := parseOptionalResourceRequest(payload[offset:]) if resErr != nil { - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - "invalid resource request", - resErr.Error(), - ) + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) } - h.logger.Info("queue job with snapshot request", - "job", jobName, - "priority", priority, - "commit_id", fmt.Sprintf("%x", commitID), - "snapshot_id", snapshotID, - ) + h.logger.Info("queue job with snapshot request", "job", jobName, "priority", priority, + "commit_id", fmt.Sprintf("%x", commitID), "snapshot_id", snapshotID) - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithSnapshot) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { + return err } - if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { - h.logger.Info( - "job queued", - "job", jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), - "user", user.Name, - ) - } else { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to create jobs", - "", - ) - } - - if _, err := telemetry.ExecWithMetrics( - h.logger, - "experiment.create", - 50*time.Millisecond, - func() (string, error) { - return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID)) - }, - ); err != nil { - h.logger.Error("failed to create experiment directory", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to create experiment directory", - err.Error(), - ) - } - - meta := &experiment.Metadata{ - CommitID: fmt.Sprintf("%x", commitID), - JobName: jobName, - User: user.Name, - Timestamp: time.Now().Unix(), - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.WriteMetadata(meta) - }); err != nil { - h.logger.Error("failed to save experiment metadata", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to save experiment metadata", - err.Error(), - ) - } - - commitIDStr := fmt.Sprintf("%x", commitID) - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { - return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) - }); err != nil { - h.logger.Error("failed to ensure minimal experiment files", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to initialize experiment files", - err.Error(), - ) - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) { - manifest, err := h.expManager.GenerateManifest(commitIDStr) - if err != nil { - return "", fmt.Errorf("failed to generate manifest: %w", err) - } - if err := h.expManager.WriteManifest(manifest); err != nil { - return "", fmt.Errorf("failed to write manifest: %w", err) - } - return "", nil - }); err != nil { - h.logger.Error("failed to generate/write manifest", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to generate content integrity manifest", - err.Error(), - ) - } - - go func() { - if h.db != nil { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - exp := &storage.Experiment{ - ID: fmt.Sprintf("%x", commitID), - Name: jobName, - Status: "pending", - UserID: user.Name, - } - if _, err := telemetry.ExecWithMetrics( - h.logger, - "db.experiments.upsert", - 50*time.Millisecond, - func() (string, error) { - return "", h.db.UpsertExperiment(ctx, exp) - }, - ); err != nil { - h.logger.Error("failed to upsert experiment row", "error", err) - } - } - }() - - h.logger.Info( - "job queued", - "job", jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), - "user", user.Name, - ) - - return h.enqueueTaskAndRespondWithSnapshot( - conn, - user, - jobName, - priority, - commitID, - nil, - resources, - snapshotID, - snapshotSHA, - ) + return h.processAndEnqueueJobWithSnapshot(conn, user, jobName, priority, commitID, nil, resources, snapshotID, snapshotSHA) } -// handleQueueJobWithTracking queues a job with optional tracking configuration. -// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] -// [tracking_json_len:2][tracking_json:var] func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error { - if len(payload) < 38+2 { // minimum with zero-length tracking JSON - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - "queue job with tracking payload too short", - "", - ) + if len(payload) < ProtocolMinQueueJobWithTracking { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with tracking payload too short", "") } - apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) jobNameLen := int(payload[37]) - // Ensure we have job name and two bytes for tracking length if len(payload) < 38+jobNameLen+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } @@ -713,12 +290,7 @@ func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []b offset += 2 if trackingLen < 0 || len(payload) < offset+trackingLen { - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - "invalid tracking json length", - "", - ) + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json length", "") } var trackingCfg *queue.TrackingConfig @@ -733,145 +305,20 @@ func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []b offset += trackingLen resources, resErr := parseOptionalResourceRequest(payload[offset:]) if resErr != nil { - return h.sendErrorPacket( - conn, - ErrorCodeInvalidRequest, - "invalid resource request", - resErr.Error(), - ) + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) } - h.logger.Info("queue job with tracking request", - "job", jobName, - "priority", priority, - "commit_id", fmt.Sprintf("%x", commitID), - ) + h.logger.Info("queue job with tracking request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID)) - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithTracking) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { + return err } - // Check user permissions - if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { - h.logger.Info( - "job queued (with tracking)", - "job", jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), - "user", user.Name, - ) - } else { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to create jobs", - "", - ) - } - - // Create experiment directory and metadata (optimized) - if _, err := telemetry.ExecWithMetrics( - h.logger, - "experiment.create", - 50*time.Millisecond, - func() (string, error) { - return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID)) - }, - ); err != nil { - h.logger.Error("failed to create experiment directory", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to create experiment directory", - err.Error(), - ) - } - - meta := &experiment.Metadata{ - CommitID: fmt.Sprintf("%x", commitID), - JobName: jobName, - User: user.Name, - Timestamp: time.Now().Unix(), - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.WriteMetadata(meta) - }); err != nil { - h.logger.Error("failed to save experiment metadata", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to save experiment metadata", - err.Error(), - ) - } - - // Generate and write content integrity manifest - commitIDStr := fmt.Sprintf("%x", commitID) - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) { - manifest, err := h.expManager.GenerateManifest(commitIDStr) - if err != nil { - return "", fmt.Errorf("failed to generate manifest: %w", err) - } - if err := h.expManager.WriteManifest(manifest); err != nil { - return "", fmt.Errorf("failed to write manifest: %w", err) - } - return "", nil - }); err != nil { - h.logger.Error("failed to generate/write manifest", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeStorageError, - "Failed to generate content integrity manifest", - err.Error(), - ) - } - - // Add user info to experiment metadata (deferred for performance) - go func() { - if h.db != nil { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - exp := &storage.Experiment{ - ID: fmt.Sprintf("%x", commitID), - Name: jobName, - Status: "pending", - UserID: user.Name, - } - if _, err := telemetry.ExecWithMetrics( - h.logger, - "db.experiments.upsert", - 50*time.Millisecond, - func() (string, error) { - return "", h.db.UpsertExperiment(ctx, exp) - }, - ); err != nil { - h.logger.Error("failed to upsert experiment row", "error", err) - } - } - - }() - - return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources) + return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, trackingCfg, resources) } type queueJobWithArgsPayload struct { @@ -880,6 +327,7 @@ type queueJobWithArgsPayload struct { priority int64 jobName string args string + force bool resources *resourceRequest } @@ -890,50 +338,45 @@ type queueJobWithNotePayload struct { jobName string args string note string + force bool resources *resourceRequest } func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) { // Protocol: // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] - // [args_len:2][args:var][note_len:2][note:var][resources?:var] - if len(payload) < 42 { + // [args_len:2][args:var][note_len:2][note:var][force:1][resources?:var] + if len(payload) < 43 { return nil, fmt.Errorf("queue job with note payload too short") } apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) - jobNameLen := int(payload[37]) - if len(payload) < 38+jobNameLen+2 { - return nil, fmt.Errorf("invalid job name length") - } - jobName := string(payload[38 : 38+jobNameLen]) - offset := 38 + jobNameLen - argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - if argsLen < 0 || len(payload) < offset+argsLen+2 { - return nil, fmt.Errorf("invalid args length") - } - args := "" - if argsLen > 0 { - args = string(payload[offset : offset+argsLen]) - } - offset += argsLen + p := helpers.NewPayloadParser(payload, 37) - noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - if noteLen < 0 || len(payload) < offset+noteLen { - return nil, fmt.Errorf("invalid note length") + jobName, err := p.ParseLengthPrefixedString() + if err != nil { + return nil, fmt.Errorf("invalid job name: %w", err) } - note := "" - if noteLen > 0 { - note = string(payload[offset : offset+noteLen]) - } - offset += noteLen - resources, resErr := parseOptionalResourceRequest(payload[offset:]) + args, err := p.ParseUint16PrefixedString() + if err != nil { + return nil, fmt.Errorf("invalid args: %w", err) + } + + note, err := p.ParseUint16PrefixedString() + if err != nil { + return nil, fmt.Errorf("invalid note: %w", err) + } + + force, err := p.ParseBool() + if err != nil { + return nil, fmt.Errorf("missing force flag: %w", err) + } + + resources, resErr := helpers.ParseResourceRequest(p.Remaining()) if resErr != nil { return nil, resErr } @@ -945,38 +388,39 @@ func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, err jobName: jobName, args: args, note: note, + force: force, resources: resources, }, nil } func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) { - // Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][resources?:var] - if len(payload) < 40 { + // Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][force:1][resources?:var] + if len(payload) < 41 { return nil, fmt.Errorf("queue job with args payload too short") } apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) - jobNameLen := int(payload[37]) - if len(payload) < 38+jobNameLen+2 { - return nil, fmt.Errorf("invalid job name length") - } - jobName := string(payload[38 : 38+jobNameLen]) - offset := 38 + jobNameLen - argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) - offset += 2 - if argsLen < 0 || len(payload) < offset+argsLen { - return nil, fmt.Errorf("invalid args length") - } - args := "" - if argsLen > 0 { - args = string(payload[offset : offset+argsLen]) - } - offset += argsLen + p := helpers.NewPayloadParser(payload, 37) - resources, resErr := parseOptionalResourceRequest(payload[offset:]) + jobName, err := p.ParseLengthPrefixedString() + if err != nil { + return nil, fmt.Errorf("invalid job name: %w", err) + } + + args, err := p.ParseUint16PrefixedString() + if err != nil { + return nil, fmt.Errorf("invalid args: %w", err) + } + + force, err := p.ParseBool() + if err != nil { + return nil, fmt.Errorf("missing force flag: %w", err) + } + + resources, resErr := helpers.ParseResourceRequest(p.Remaining()) if resErr != nil { return nil, resErr } @@ -987,6 +431,7 @@ func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, err priority: priority, jobName: jobName, args: args, + force: force, resources: resources, }, nil } @@ -997,84 +442,17 @@ func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error()) } - h.logger.Info("queue job request", - "job", p.jobName, - "priority", p.priority, - "commit_id", fmt.Sprintf("%x", p.commitID), - ) + h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID)) - // Validate API key and get user information - var user *auth.User - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + user, err := h.authenticateWithHash(conn, p.apiKeyHash) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { + return err } - if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { - h.logger.Info( - "job queued", - "job", p.jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)), - "user", user.Name, - ) - } else { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to create jobs", - "", - ) - } - - commitIDStr := fmt.Sprintf("%x", p.commitID) - if _, err := telemetry.ExecWithMetrics( - h.logger, - "experiment.create", - 50*time.Millisecond, - func() (string, error) { - return "", h.expManager.CreateExperiment(commitIDStr) - }, - ); err != nil { - h.logger.Error("failed to create experiment directory", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error()) - } - - meta := &experiment.Metadata{ - CommitID: commitIDStr, - JobName: p.jobName, - User: user.Name, - Timestamp: time.Now().Unix(), - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.WriteMetadata(meta) - }); err != nil { - h.logger.Error("failed to save experiment metadata", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error()) - } - - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { - return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) - }); err != nil { - h.logger.Error("failed to ensure minimal experiment files", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error()) - } - - return h.enqueueTaskAndRespondWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, nil, p.resources) + return h.processAndEnqueueJobWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, p.force, nil, p.resources) } func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error { @@ -1083,83 +461,137 @@ func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error()) } - h.logger.Info("queue job request", - "job", p.jobName, - "priority", p.priority, - "commit_id", fmt.Sprintf("%x", p.commitID), - ) + h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID)) - var user *auth.User - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + user, err := h.authenticateWithHash(conn, p.apiKeyHash) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { + return err + } + + return h.processAndEnqueueJobWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, p.force, nil, p.resources) +} + +// findDuplicateTask searches for an existing task with the same composite key +// (commit_id + dataset_id + params_hash) to detect truly identical experiments +func (h *WSHandler) findDuplicateTask(commitIDStr, datasetID, paramsHash string) *queue.Task { + if h.queue == nil { + return nil + } + + tasks, err := h.queue.GetAllTasks() + if err != nil { + return nil + } + + for _, task := range tasks { + if task.Metadata == nil { + continue } - } else { - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, + // Check all three components of the composite key + if task.Metadata["commit_id"] == commitIDStr && + task.Metadata["dataset_id"] == datasetID && + task.Metadata["params_hash"] == paramsHash { + return task + } + } + return nil +} + +// sendDuplicateResponse sends a data packet response for duplicate jobs +func (h *WSHandler) sendDuplicateResponse(conn *websocket.Conn, existingTask *queue.Task) error { + response := map[string]interface{}{ + "duplicate": true, + "existing_id": existingTask.ID, + "status": existingTask.Status, + "queued_by": existingTask.CreatedBy, + "queued_at": existingTask.CreatedAt.Unix(), + } + + // Add duration for completed tasks + if existingTask.Status == "completed" && existingTask.EndedAt != nil { + duration := existingTask.EndedAt.Sub(existingTask.CreatedAt).Seconds() + response["duration_seconds"] = int64(duration) + + // Try to get metrics for completed tasks + if h.expManager != nil { + commitID := existingTask.Metadata["commit_id"] + if metrics, err := h.expManager.GetMetrics(commitID); err == nil && len(metrics) > 0 { + metricsMap := make(map[string]interface{}) + for _, m := range metrics { + metricsMap[m.Name] = m.Value + } + response["metrics"] = metricsMap + } } } - if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { - h.logger.Info( - "job queued", - "job", p.jobName, - "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)), - "user", user.Name, - ) - } else { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to create jobs", - "", - ) + // Add error reason for failed tasks with full failure classification + if existingTask.Status == "failed" && existingTask.Error != "" { + response["error_reason"] = existingTask.Error + + // Classify failure using exit codes, signals, and error context + failureClass := queue.FailureUnknown + exitCode := 0 + signalName := "" + + // Extract exit code from error or metadata + if code, ok := existingTask.Metadata["exit_code"]; ok { + fmt.Sscanf(code, "%d", &exitCode) + } + if sig, ok := existingTask.Metadata["signal"]; ok { + signalName = sig + } + + // Get log tail for classification if available + logTail := existingTask.Error + if existingTask.LastError != "" { + logTail = existingTask.LastError + } + + // Classify failure directly using signals, exit codes, and log content + // Note: failureClass declared above at line 536, just reassign here + + // Override with signal-based classification if available + if signalName == "SIGKILL" || signalName == "9" { + failureClass = queue.FailureInfrastructure + } else if exitCode != 0 { + // Use the new ClassifyFailure with error log content + logContent := existingTask.Error + if existingTask.LastError != "" { + logContent = existingTask.LastError + } + failureClass = queue.ClassifyFailure(exitCode, nil, logContent) + } + + response["failure_class"] = string(failureClass) + response["exit_code"] = exitCode + response["signal"] = signalName + response["log_tail"] = logTail + + // Add user-facing suggestion + response["suggestion"] = queue.GetFailureSuggestion(failureClass, logTail) + + // Add retry information with class-specific policy + response["retry_count"] = existingTask.RetryCount + response["retry_cap"] = 3 + response["auto_retryable"] = queue.ShouldAutoRetry(failureClass, existingTask.RetryCount) + + // Add attempts history if available + if len(existingTask.Attempts) > 0 { + response["attempts"] = existingTask.Attempts + } } - commitIDStr := fmt.Sprintf("%x", p.commitID) - if _, err := telemetry.ExecWithMetrics( - h.logger, - "experiment.create", - 50*time.Millisecond, - func() (string, error) { - return "", h.expManager.CreateExperiment(commitIDStr) - }, - ); err != nil { - h.logger.Error("failed to create experiment directory", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error()) + responseData, err := json.Marshal(response) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize duplicate response", err.Error()) } - meta := &experiment.Metadata{ - CommitID: commitIDStr, - JobName: p.jobName, - User: user.Name, - Timestamp: time.Now().Unix(), - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.WriteMetadata(meta) - }); err != nil { - h.logger.Error("failed to save experiment metadata", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error()) - } - - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { - return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) - }); err != nil { - h.logger.Error("failed to ensure minimal experiment files", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error()) - } - - return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, nil, p.resources) + packet := NewDataPacket("duplicate", responseData) + return h.sendResponsePacket(conn, packet) } // enqueueTaskAndRespond enqueues a task and sends a success response. @@ -1172,7 +604,7 @@ func (h *WSHandler) enqueueTaskAndRespond( tracking *queue.TrackingConfig, resources *resourceRequest, ) error { - return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", tracking, resources) + return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", false, tracking, resources) } func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( @@ -1183,13 +615,31 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( commitID []byte, args string, note string, + force bool, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) - prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) + + // Compute dataset_id and params_hash from existing data + paramsHash := helpers.ComputeParamsHash(args) + // Note: dataset_id will be empty here since we don't have DatasetSpecs yet + // It will be populated when the task is actually created with datasets + datasetID := "" + + // Check for duplicate tasks before proceeding (skip if force=true) + if !force { + if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { + h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) + return h.sendDuplicateResponse(conn, existingTask) + } + } else { + h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr) + } + + prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, @@ -1215,7 +665,9 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( Username: user.Name, CreatedBy: user.Name, Metadata: map[string]string{ - "commit_id": commitIDStr, + "commit_id": commitIDStr, + "dataset_id": datasetID, + "params_hash": paramsHash, }, Tracking: tracking, } @@ -1250,7 +702,7 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( err.Error(), ) } - h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) + h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash) } else { h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } @@ -1275,16 +727,36 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgs( priority int64, commitID []byte, args string, + force bool, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) - prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) + + // Compute dataset_id and params_hash from existing data + paramsHash := helpers.ComputeParamsHash(args) + // Note: dataset_id will be empty here since we don't have DatasetSpecs yet + // It will be populated when the task is actually created with datasets + datasetID := "" + + // Check for duplicate tasks before proceeding (skip if force=true) + if !force { + if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { + h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) + return h.sendDuplicateResponse(conn, existingTask) + } + } else { + h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr) + } + + prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, + "dataset_id", datasetID, + "params_hash", paramsHash, "error", provErr) return h.sendErrorPacket( conn, @@ -1308,7 +780,9 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgs( Username: user.Name, CreatedBy: user.Name, Metadata: map[string]string{ - "commit_id": commitIDStr, + "commit_id": commitIDStr, + "dataset_id": datasetID, + "params_hash": paramsHash, }, Tracking: tracking, } @@ -1340,7 +814,7 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgs( err.Error(), ) } - h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) + h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash) } else { h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } @@ -1358,6 +832,89 @@ func (h *WSHandler) enqueueTaskAndRespondWithArgs( return conn.WriteMessage(websocket.BinaryMessage, packetData) } +// processAndEnqueueJob handles common experiment setup and task enqueueing +func (h *WSHandler) processAndEnqueueJob( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + tracking *queue.TrackingConfig, + resources *resourceRequest, +) error { + commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") + } + + helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) + return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, tracking, resources) +} + +// processAndEnqueueJobWithSnapshot handles experiment setup and task enqueueing for snapshot jobs +func (h *WSHandler) processAndEnqueueJobWithSnapshot( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + tracking *queue.TrackingConfig, + resources *resourceRequest, + snapshotID string, + snapshotSHA string, +) error { + commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") + } + + helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) + return h.enqueueTaskAndRespondWithSnapshot(conn, user, jobName, priority, commitID, tracking, resources, snapshotID, snapshotSHA) +} + +// processAndEnqueueJobWithArgs handles experiment setup and task enqueueing for jobs with args +func (h *WSHandler) processAndEnqueueJobWithArgs( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + args string, + force bool, + tracking *queue.TrackingConfig, + resources *resourceRequest, +) error { + commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") + } + + helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) + return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, args, force, tracking, resources) +} + +// processAndEnqueueJobWithArgsAndNote handles experiment setup for jobs with args and note +func (h *WSHandler) processAndEnqueueJobWithArgsAndNote( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + args string, + note string, + force bool, + tracking *queue.TrackingConfig, + resources *resourceRequest, +) error { + commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") + } + + helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) + return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, jobName, priority, commitID, args, note, force, tracking, resources) +} + func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( conn *websocket.Conn, user *auth.User, @@ -1372,7 +929,22 @@ func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) - prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) + + // Compute dataset_id from snapshot SHA (snapshot acts as dataset) + datasetID := "" + if strings.TrimSpace(snapshotSHA) != "" { + datasetID = snapshotSHA[:16] + } + // Snapshots don't have args, so params_hash is empty + paramsHash := "" + + // Check for duplicate tasks before proceeding + if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { + h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) + return h.sendDuplicateResponse(conn, existingTask) + } + + prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, @@ -1445,78 +1017,33 @@ func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( return conn.WriteMessage(websocket.BinaryMessage, packetData) } -type resourceRequest struct { - CPU int - MemoryGB int - GPU int - GPUMemory string -} +// resourceRequest is an alias to helpers.ResourceRequest for backward compatibility +type resourceRequest = helpers.ResourceRequest -// parseOptionalResourceRequest parses an optional tail encoding: -// [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] -// If payload is empty, returns nil. +// parseOptionalResourceRequest is an alias to helpers.ParseResourceRequest for backward compatibility func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) { - if len(payload) == 0 { + r, err := helpers.ParseResourceRequest(payload) + if err != nil { + return nil, err + } + // Type conversion is needed because Go doesn't automatically convert named types even with identical underlying structures + if r == nil { return nil, nil } - if len(payload) < 4 { - return nil, fmt.Errorf("resource payload too short") - } - cpu := int(payload[0]) - mem := int(payload[1]) - gpu := int(payload[2]) - gpuMemLen := int(payload[3]) - if gpuMemLen < 0 || len(payload) < 4+gpuMemLen { - return nil, fmt.Errorf("invalid gpu memory length") - } - gpuMem := "" - if gpuMemLen > 0 { - gpuMem = string(payload[4 : 4+gpuMemLen]) - } - return &resourceRequest{CPU: cpu, MemoryGB: mem, GPU: gpu, GPUMemory: gpuMem}, nil + return (*resourceRequest)(r), nil } func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16] - if len(payload) < 16 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinStatusRequest) + if err != nil { + return err + } + h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", payload[:16])) + + if err := h.requirePermission(user, PermJobsRead, conn); err != nil { + return err } - apiKeyHash := payload[:16] - h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", apiKeyHash)) - - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } - } - - // Check user permissions for viewing jobs - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read") - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "Insufficient permissions to view jobs", - "", - ) - } - // Get tasks with user filtering var tasks []*queue.Task if h.queue != nil { allTasks, err := h.queue.GetAllTasks() @@ -1525,22 +1052,17 @@ func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) er return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error()) } - // Filter tasks based on user permissions for _, task := range allTasks { - // If auth is disabled or admin can see all tasks if h.authConfig == nil || !h.authConfig.Enabled || user.Admin { tasks = append(tasks, task) continue } - - // Users can only see their own tasks if task.UserID == user.Name || task.CreatedBy == user.Name { tasks = append(tasks, task) } } } - // Build status response as raw JSON for CLI compatibility h.logger.Info("building status response") status := map[string]any{ "user": map[string]any{ @@ -1582,10 +1104,7 @@ func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) er } h.logger.Info("sending websocket JSON response", "len", len(jsonData)) - - // Send as binary protocol packet - packet := NewDataPacket("status", jsonData) - return h.sendResponsePacket(conn, packet) + return h.sendResponsePacket(conn, NewDataPacket("status", jsonData)) } // countTasksByStatus counts tasks by their status @@ -1600,144 +1119,79 @@ func countTasksByStatus(tasks []*queue.Task, status string) int { } func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][job_name_len:1][job_name:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinCancelJob) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { + return err } - // Parse 16-byte binary API key hash - apiKeyHash := payload[:16] - jobNameLen := int(payload[16]) - - if len(payload) < 17+jobNameLen { + jobNameLen := int(payload[ProtocolAPIKeyHashLen]) + if len(payload) < ProtocolAPIKeyHashLen+1+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } - jobName := string(payload[17 : 17+jobNameLen]) - + jobName := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+jobNameLen]) h.logger.Info("cancel job request", "job", jobName) - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } + if h.queue == nil { + h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) + return nil } - // Check user permissions for canceling jobs - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") + task, err := h.queue.GetTaskByName(jobName) + if err != nil { + h.logger.Error("task not found", "job", jobName, "error", err) + return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) + } + + if h.authConfig != nil && h.authConfig.Enabled && !user.Admin && + task.UserID != user.Name && task.CreatedBy != user.Name { + h.logger.Error( + "unauthorized job cancellation attempt", + "user", user.Name, + "job", jobName, + "task_owner", task.UserID, + ) return h.sendErrorPacket( conn, ErrorCodePermissionDenied, - "Insufficient permissions to cancel jobs", + "You can only cancel your own jobs", "", ) } - // Find the task and verify ownership - if h.queue != nil { - task, err := h.queue.GetTaskByName(jobName) - if err != nil { - h.logger.Error("task not found", "job", jobName, "error", err) - return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) - } - // Check if user can cancel this task (admin or owner) - if h.authConfig != nil && - h.authConfig.Enabled && - !user.Admin && - task.UserID != user.Name && - task.CreatedBy != user.Name { - h.logger.Error( - "unauthorized job cancellation attempt", - "user", user.Name, - "job", jobName, - "task_owner", task.UserID, - ) - return h.sendErrorPacket( - conn, - ErrorCodePermissionDenied, - "You can only cancel your own jobs", - "", - ) - } - // Cancel the task - if err := h.queue.CancelTask(task.ID); err != nil { - h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err) - return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error()) - } - - h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name) - } else { - h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) + if err := h.queue.CancelTask(task.ID); err != nil { + h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err) + return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error()) } - packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)) - packetData, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize packet", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeServerOverloaded, - "Internal error", - "Failed to serialize response", - ) - } - return conn.WriteMessage(websocket.BinaryMessage, packetData) + h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name) + return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))) } func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][prune_type:1][value:4] - if len(payload) < 21 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinPrune) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { + return err } - // Parse 16-byte binary API key hash - apiKeyHash := payload[:16] - pruneType := payload[16] - value := binary.BigEndian.Uint32(payload[17:21]) + pruneType := payload[ProtocolAPIKeyHashLen] + value := binary.BigEndian.Uint32(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+5]) h.logger.Info("prune request", "type", pruneType, "value", value) - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - h.logger.Error("api key verification failed", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - - // Convert prune parameters var keepCount int var olderThanDays int switch pruneType { case 0: - // keep N keepCount = int(value) - olderThanDays = 0 case 1: - // older than days - keepCount = 0 olderThanDays = int(value) default: return h.sendErrorPacket( @@ -1748,7 +1202,6 @@ func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error { ) } - // Perform pruning pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays) if err != nil { h.logger.Error("prune failed", "error", err) @@ -1759,20 +1212,19 @@ func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error { } h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned) - - // Send structured success response - packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))) - return h.sendResponsePacket(conn, packet) + return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))) } func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var] - if len(payload) < 51 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinLogMetric) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { + return err } - apiKeyHash := payload[:16] - commitID := payload[16:36] + commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen] step := int(binary.BigEndian.Uint32(payload[36:40])) valueBits := binary.BigEndian.Uint64(payload[40:48]) value := math.Float64frombits(valueBits) @@ -1784,19 +1236,6 @@ func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error name := string(payload[49 : 49+nameLen]) - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - h.logger.Error("api key verification failed", "error", err) - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil { h.logger.Error("failed to log metric", "error", err) return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) @@ -1806,25 +1245,15 @@ func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error } func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][commit_id:20] - if len(payload) < 36 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "") + user, err := h.authenticate(conn, payload, ProtocolMinGetExperiment) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsRead, conn); err != nil { + return err } - apiKeyHash := payload[:16] - commitID := payload[16:36] - - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } + commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen] meta, err := h.expManager.ReadMetadata(fmt.Sprintf("%x", commitID)) if err != nil { @@ -1838,7 +1267,7 @@ func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) er var dbMeta *storage.ExperimentWithMetadata if h.db != nil { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := helpers.DBContextShort() defer cancel() m, err := h.db.GetExperimentWithMetadata(ctx, fmt.Sprintf("%x", commitID)) if err == nil { @@ -1866,3 +1295,72 @@ func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) er return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData)) } + +// handleGetLogs handles requests to fetch logs for a task/run +func (h *WSHandler) handleGetLogs(conn *websocket.Conn, payload []byte) error { + user, err := h.authenticate(conn, payload, ProtocolMinGetLogs) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsRead, conn); err != nil { + return err + } + + targetIDLen := int(payload[ProtocolAPIKeyHashLen]) + if len(payload) < ProtocolMinGetLogs+targetIDLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", fmt.Sprintf("got %d, need %d", len(payload), ProtocolMinGetLogs+targetIDLen)) + } + + targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen]) + h.logger.Info("get logs request", "target_id", targetID, "user", user.Name) + + // TODO: Implement actual log fetching from storage + // For now, return a stub response + response := map[string]interface{}{ + "target_id": targetID, + "logs": "[Stub] Log content would appear here\nLine 1: Log output\nLine 2: More output\n", + "truncated": false, + "total_lines": 3, + } + + responseData, err := json.Marshal(response) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error()) + } + + return h.sendResponsePacket(conn, NewDataPacket("logs", responseData)) +} + +// handleStreamLogs handles requests to stream logs in real-time +func (h *WSHandler) handleStreamLogs(conn *websocket.Conn, payload []byte) error { + user, err := h.authenticate(conn, payload, ProtocolMinStreamLogs) + if err != nil { + return err + } + if err := h.requirePermission(user, PermJobsRead, conn); err != nil { + return err + } + + targetIDLen := int(payload[ProtocolAPIKeyHashLen]) + if len(payload) < ProtocolMinStreamLogs+targetIDLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", "") + } + + targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen]) + h.logger.Info("stream logs request", "target_id", targetID, "user", user.Name) + + // TODO: Implement actual log streaming + // For now, return a stub response indicating streaming started + response := map[string]interface{}{ + "target_id": targetID, + "streaming": true, + "message": "[Stub] Log streaming would start here. This feature is not yet fully implemented.", + } + + responseData, err := json.Marshal(response) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error()) + } + + return h.sendResponsePacket(conn, NewDataPacket("logs_stream", responseData)) +} diff --git a/internal/api/ws_jupyter.go b/internal/api/ws_jupyter.go index 7870513..578c089 100644 --- a/internal/api/ws_jupyter.go +++ b/internal/api/ws_jupyter.go @@ -9,44 +9,16 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/api/helpers" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/queue" ) +// JupyterTaskErrorCode returns the error code for a Jupyter task. +// This is kept for backward compatibility and delegates to the helper. func JupyterTaskErrorCode(t *queue.Task) byte { - if t == nil { - return ErrorCodeUnknownError - } - status := strings.ToLower(strings.TrimSpace(t.Status)) - errStr := strings.ToLower(strings.TrimSpace(t.Error)) - - if status == "cancelled" { - return ErrorCodeJobCancelled - } - if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") { - return ErrorCodeOutOfMemory - } - if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") { - return ErrorCodeDiskFull - } - if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") { - return ErrorCodeServiceUnavailable - } - if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") { - return ErrorCodeTimeout - } - if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") { - return ErrorCodeNetworkError - } - if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") { - return ErrorCodeInvalidConfiguration - } - - // Default for worker-side execution failures. - if status == "failed" { - return ErrorCodeJobExecutionFailed - } - return ErrorCodeUnknownError + mapper := helpers.NewTaskErrorMapper() + return byte(mapper.MapJupyterError(t)) } type jupyterTaskOutput struct { @@ -58,37 +30,15 @@ type jupyterTaskOutput struct { } func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16][name_len:1][name:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "") - } - - apiKeyHash := payload[:16] - - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - user, err := h.validateWSUser(apiKeyHash) + user, err := h.authenticate(conn, payload, 18) if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) + return err } - if user != nil && !user.HasPermission("jupyter:manage") { - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + if err := h.requirePermission(user, PermJupyterManage, conn); err != nil { + return err } - offset := 16 + offset := ProtocolAPIKeyHashLen nameLen := int(payload[offset]) offset++ if len(payload) < offset+nameLen { @@ -183,13 +133,11 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") } - offset := 16 - nameLen := int(payload[offset]) - offset++ - if len(payload) < offset+nameLen { + p := helpers.NewPayloadParser(payload, 16) + name, err := p.ParseLengthPrefixedString() + if err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") } - name := string(payload[offset : offset+nameLen]) name = strings.TrimSpace(name) if name == "" { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "") @@ -217,7 +165,7 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by out := strings.TrimSpace(result.Output) if out == "" { - return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]"))) + return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{}))) } var payloadOut jupyterTaskOutput if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { @@ -228,7 +176,7 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload)) } - return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]"))) + return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{}))) } func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) { @@ -427,16 +375,12 @@ func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) erro return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") } - offset := 16 - idLen := int(payload[offset]) - offset++ - - if len(payload) < offset+idLen { + p := helpers.NewPayloadParser(payload, 16) + serviceID, err := p.ParseLengthPrefixedString() + if err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "") } - serviceID := string(payload[offset : offset+idLen]) - meta := map[string]string{ jupyterTaskActionKey: jupyterActionStop, jupyterServiceIDKey: strings.TrimSpace(serviceID), @@ -466,19 +410,17 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er apiKeyHash := payload[:16] - offset := 16 - idLen := int(payload[offset]) - offset++ - if len(payload) < offset+idLen { + p := helpers.NewPayloadParser(payload, 16) + serviceID, err := p.ParseLengthPrefixedString() + if err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "") } - serviceID := string(payload[offset : offset+idLen]) - offset += idLen // Optional: purge flag (1 byte). Default false for trash-first behavior. purge := false - if len(payload) > offset { - purge = payload[offset] == 0x01 + if p.HasRemaining() { + purgeByte, _ := p.ParseByte() + purge = purgeByte == 0x01 } if h.authConfig != nil && h.authConfig.Enabled { @@ -528,34 +470,12 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er } func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:16] - if len(payload) < 16 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "") - } - - apiKeyHash := payload[:16] - - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) - } - } - user, err := h.validateWSUser(apiKeyHash) + user, err := h.authenticate(conn, payload, ProtocolMinDatasetList) if err != nil { - return h.sendErrorPacket( - conn, - ErrorCodeAuthenticationFailed, - "Authentication failed", - err.Error(), - ) + return err } - if user != nil && !user.HasPermission("jupyter:read") { - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + if err := h.requirePermission(user, PermJupyterRead, conn); err != nil { + return err } meta := map[string]string{ @@ -578,18 +498,15 @@ func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) erro out := strings.TrimSpace(result.Output) if out == "" { - empty, _ := json.Marshal([]any{}) - return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", empty)) + return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{}))) } var payloadOut jupyterTaskOutput if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { - // Always return an array payload (even if empty) so clients can render a stable table. payload := payloadOut.Services if len(payload) == 0 { payload = []byte("[]") } return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload)) } - // Fallback: return empty array on unexpected output. - return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]"))) + return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{}))) } diff --git a/internal/api/ws_protocol.go b/internal/api/ws_protocol.go new file mode 100644 index 0000000..e86f8ff --- /dev/null +++ b/internal/api/ws_protocol.go @@ -0,0 +1,43 @@ +package api + +// Protocol constants for WebSocket binary protocol +// All handlers use [api_key_hash:16] as the first 16 bytes of every payload +const ( + // Auth header size (present in all payloads) + ProtocolAPIKeyHashLen = 16 + + // Commit ID size (20 bytes hex = 40 char string) + ProtocolCommitIDLen = 20 + + // Minimum payload sizes for each operation + ProtocolMinStatusRequest = ProtocolAPIKeyHashLen // [api_key_hash:16] + ProtocolMinCancelJob = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][job_name_len:1] + ProtocolMinPrune = ProtocolAPIKeyHashLen + 5 // [api_key_hash:16][prune_type:1][value:4] + ProtocolMinDatasetList = ProtocolAPIKeyHashLen // [api_key_hash:16] + ProtocolMinDatasetRegister = ProtocolAPIKeyHashLen + 3 // [api_key_hash:16][name_len:1][url_len:2] + ProtocolMinDatasetInfo = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][name_len:1] + ProtocolMinDatasetSearch = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][term_len:1] + ProtocolMinLogMetric = ProtocolAPIKeyHashLen + 25 // [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1] + ProtocolMinGetExperiment = ProtocolAPIKeyHashLen + 20 // [api_key_hash:16][commit_id:20] + ProtocolMinQueueJob = ProtocolAPIKeyHashLen + 21 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1] + ProtocolMinQueueJobWithSnapshot = ProtocolAPIKeyHashLen + 23 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][snap_id_len:1] + ProtocolMinQueueJobWithTracking = ProtocolAPIKeyHashLen + 23 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][tracking_len:2] + ProtocolMinQueueJobWithNote = ProtocolAPIKeyHashLen + 26 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][args_len:2][note_len:2][force:1] + ProtocolMinQueueJobWithArgs = ProtocolAPIKeyHashLen + 24 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][args_len:2][force:1] + ProtocolMinAnnotateRun = ProtocolAPIKeyHashLen + 20 // [api_key_hash:16][job_name_len:1][author_len:1][note_len:2] + ProtocolMinSetRunNarrative = ProtocolAPIKeyHashLen + 20 // [api_key_hash:16][job_name_len:1][patch_len:2] + + // Logs and debug minimum payload sizes + ProtocolMinGetLogs = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][target_id_len:1] + ProtocolMinStreamLogs = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][target_id_len:1] + ProtocolMinAttachDebug = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][target_id_len:1] + + // Permission constants + PermJobsCreate = "jobs:create" + PermJobsRead = "jobs:read" + PermJobsUpdate = "jobs:update" + PermDatasetsRead = "datasets:read" + PermDatasetsCreate = "datasets:create" + PermJupyterManage = "jupyter:manage" + PermJupyterRead = "jupyter:read" +) diff --git a/internal/api/ws_validate.go b/internal/api/ws_validate.go index 63af9e9..bc3308e 100644 --- a/internal/api/ws_validate.go +++ b/internal/api/ws_validate.go @@ -1,7 +1,6 @@ package api import ( - "encoding/hex" "encoding/json" "fmt" "os" @@ -11,6 +10,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/api/helpers" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/manifest" @@ -228,20 +228,17 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) } // Validate commit id format - if len(commitID) != 40 { + if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok { r.OK = false - r.Errors = append(r.Errors, "invalid commit_id length") - } else if _, err := hex.DecodeString(commitID); err != nil { - r.OK = false - r.Errors = append(r.Errors, "invalid commit_id hex") + r.Errors = append(r.Errors, errMsg) } // Experiment manifest integrity // TODO(context): Extend report to include per-file diff list on mismatch (bounded output). if r.OK { - if err := h.expManager.ValidateManifest(commitID); err != nil { + if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok { r.OK = false - r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()} + r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details} r.Errors = append(r.Errors, "experiment manifest validation failed") } else { r.Checks["experiment_manifest"] = validateCheck{OK: true} @@ -251,29 +248,13 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) // Deps manifest presence + hash // TODO(context): Allow client to declare which dependency manifest is authoritative. filesPath := h.expManager.GetFilesPath(commitID) - depName, depErr := worker.SelectDependencyManifest(filesPath) - if depErr != nil { + depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID) + if depErrs != nil { r.OK = false - r.Checks["deps_manifest"] = validateCheck{ - OK: false, - Details: depErr.Error(), - } - r.Errors = append(r.Errors, "deps manifest missing") + r.Checks["deps_manifest"] = validateCheck(depCheck) + r.Errors = append(r.Errors, depErrs...) } else { - sha, err := fileSHA256Hex(filepath.Join(filesPath, depName)) - if err != nil { - r.OK = false - r.Checks["deps_manifest"] = validateCheck{ - OK: false, - Details: err.Error(), - } - r.Errors = append(r.Errors, "deps manifest hash failed") - } else { - r.Checks["deps_manifest"] = validateCheck{ - OK: true, - Actual: depName + ":" + sha, - } - } + r.Checks["deps_manifest"] = validateCheck(depCheck) } // Compare against expected task metadata if available. @@ -339,158 +320,58 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) } } - if strings.TrimSpace(rm.TaskID) == "" { - r.OK = false - r.Errors = append(r.Errors, "run manifest missing task_id") - r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID} - } else if rm.TaskID != task.ID { + // Validate task ID using helper + taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID) + r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck) + if !taskIDCheck.OK { r.OK = false r.Errors = append(r.Errors, "run manifest task_id mismatch") - r.Checks["run_manifest_task_id"] = validateCheck{ - OK: false, - Expected: task.ID, - Actual: rm.TaskID, - } - } else { - r.Checks["run_manifest_task_id"] = validateCheck{ - OK: true, - Expected: task.ID, - Actual: rm.TaskID, - } } - commitWant := strings.TrimSpace(task.Metadata["commit_id"]) - commitGot := strings.TrimSpace(rm.CommitID) - if commitWant != "" && commitGot != "" && commitWant != commitGot { + // Validate commit ID using helper + commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"]) + r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck) + if !commitCheck.OK { r.OK = false r.Errors = append(r.Errors, "run manifest commit_id mismatch") - r.Checks["run_manifest_commit_id"] = validateCheck{ - OK: false, - Expected: commitWant, - Actual: commitGot, - } - } else if commitWant != "" { - r.Checks["run_manifest_commit_id"] = validateCheck{ - OK: true, - Expected: commitWant, - Actual: commitGot, - } } + // Validate deps provenance using helper depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"]) depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) depGotName := strings.TrimSpace(rm.DepsManifestName) depGotSHA := strings.TrimSpace(rm.DepsManifestSHA) - if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" { - expectedDep := depWantName + ":" + depWantSHA - actualDep := depGotName + ":" + depGotSHA - if depWantName != depGotName || depWantSHA != depGotSHA { - r.OK = false - r.Errors = append(r.Errors, "run manifest deps provenance mismatch") - r.Checks["run_manifest_deps"] = validateCheck{ - OK: false, - Expected: expectedDep, - Actual: actualDep, - } - } else { - r.Checks["run_manifest_deps"] = validateCheck{ - OK: true, - Expected: expectedDep, - Actual: actualDep, - } - } + depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA) + r.Checks["run_manifest_deps"] = validateCheck(depsCheck) + if !depsCheck.OK { + r.OK = false + r.Errors = append(r.Errors, "run manifest deps provenance mismatch") } + // Validate snapshot using helpers if strings.TrimSpace(task.SnapshotID) != "" { snapWantID := strings.TrimSpace(task.SnapshotID) snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"]) snapGotID := strings.TrimSpace(rm.SnapshotID) snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256) - if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID { + + snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID) + r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck) + if !snapIDCheck.OK { r.OK = false r.Errors = append(r.Errors, "run manifest snapshot_id mismatch") - r.Checks["run_manifest_snapshot_id"] = validateCheck{ - OK: false, - Expected: snapWantID, - Actual: snapGotID, - } - } else { - r.Checks["run_manifest_snapshot_id"] = validateCheck{ - OK: true, - Expected: snapWantID, - Actual: snapGotID, - } } - if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA { + + snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA) + r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck) + if !snapSHACheck.OK { r.OK = false r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch") - r.Checks["run_manifest_snapshot_sha256"] = validateCheck{ - OK: false, - Expected: snapWantSHA, - Actual: snapGotSHA, - } - } else if snapWantSHA != "" { - r.Checks["run_manifest_snapshot_sha256"] = validateCheck{ - OK: true, - Expected: snapWantSHA, - Actual: snapGotSHA, - } - } - } - - statusLower := strings.ToLower(strings.TrimSpace(task.Status)) - lifecycleOK := true - details := "" - - switch statusLower { - case "running": - if rm.StartedAt.IsZero() { - lifecycleOK = false - details = "missing started_at for running task" - } - if !rm.EndedAt.IsZero() { - lifecycleOK = false - if details == "" { - details = "ended_at must be empty for running task" - } - } - if rm.ExitCode != nil { - lifecycleOK = false - if details == "" { - details = "exit_code must be empty for running task" - } - } - case "completed", "failed": - if rm.StartedAt.IsZero() { - lifecycleOK = false - details = "missing started_at for completed/failed task" - } - if rm.EndedAt.IsZero() { - lifecycleOK = false - if details == "" { - details = "missing ended_at for completed/failed task" - } - } - if rm.ExitCode == nil { - lifecycleOK = false - if details == "" { - details = "missing exit_code for completed/failed task" - } - } - if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) { - lifecycleOK = false - if details == "" { - details = "ended_at is before started_at" - } - } - case "queued", "pending": - // queued/pending tasks may not have started yet. - if !rm.EndedAt.IsZero() || rm.ExitCode != nil { - lifecycleOK = false - details = "queued/pending task should not have ended_at/exit_code" } } + // Validate lifecycle using helper + lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status) if lifecycleOK { r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true} } else { @@ -535,7 +416,7 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) r.Errors = append(r.Errors, "missing expected deps manifest provenance") r.Checks["expected_deps_manifest"] = validateCheck{OK: false} } else if depName != "" { - sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName)) + sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName)) ok := (wantDep == depName && wantDepSha == sha) if !ok { r.OK = false