refactor: improve API structure and WebSocket protocol

- Extract WebSocket protocol handling to dedicated module
- Add helper functions for DB operations, validation, and responses
- Improve WebSocket frame handling and opcodes
- Refactor dataset, job, and Jupyter handlers
- Add duplicate detection processing
This commit is contained in:
Jeremie Fraeys 2026-02-16 20:38:12 -05:00
parent 1147958e15
commit b05470b30a
No known key found for this signature in database
13 changed files with 1663 additions and 1370 deletions

View file

@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
@ -62,9 +63,8 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
"message": "Database status check not implemented",
}
jsonBytes, _ := json.Marshal(response)
w.WriteHeader(http.StatusOK)
if _, err := w.Write(jsonBytes); err != nil {
if _, err := w.Write(helpers.MarshalJSONOrEmpty(response)); err != nil {
h.logger.Error("failed to write response", "error", err)
}
}
@ -105,13 +105,8 @@ func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request)
// listJupyterServices lists all Jupyter services
func (h *Handlers) listJupyterServices(w http.ResponseWriter, _ *http.Request) {
services := h.jupyterServiceMgr.ListServices()
jsonBytes, err := json.Marshal(services)
if err != nil {
http.Error(w, "Failed to marshal services", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
if _, err := w.Write(jsonBytes); err != nil {
if _, err := w.Write(helpers.MarshalJSONOrEmpty(services)); err != nil {
h.logger.Error("failed to write response", "error", err)
}
}
@ -131,13 +126,8 @@ func (h *Handlers) startJupyterService(w http.ResponseWriter, r *http.Request) {
return
}
jsonBytes, err := json.Marshal(service)
if err != nil {
http.Error(w, "Failed to marshal service", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
if _, err := w.Write(jsonBytes); err != nil {
if _, err := w.Write(helpers.MarshalJSONOrEmpty(service)); err != nil {
h.logger.Error("failed to write response", "error", err)
}
}

View file

@ -0,0 +1,49 @@
// Package helpers provides shared utilities for WebSocket handlers.
package helpers
import (
"context"
"time"
)
// DBContext builds a context for a database operation that is cancelled
// automatically once timeout has elapsed. The context is rooted at
// context.Background, so it is independent of any request-scoped context.
// The caller must invoke the returned cancel function to release the timer.
func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) {
	root := context.Background()
	ctx, cancel := context.WithTimeout(root, timeout)
	return ctx, cancel
}
// DBContextShort returns a background-rooted context that expires after
// 3 seconds, intended for quick DB operations. The caller must invoke the
// returned cancel function.
func DBContextShort() (context.Context, context.CancelFunc) {
	const shortTimeout = 3 * time.Second
	return context.WithTimeout(context.Background(), shortTimeout)
}
// DBContextMedium returns a background-rooted context that expires after
// 5 seconds, intended for standard DB operations. The caller must invoke the
// returned cancel function.
func DBContextMedium() (context.Context, context.CancelFunc) {
	const mediumTimeout = 5 * time.Second
	return context.WithTimeout(context.Background(), mediumTimeout)
}
// DBContextLong returns a background-rooted context that expires after
// 10 seconds, intended for complex DB operations. The caller must invoke the
// returned cancel function.
func DBContextLong() (context.Context, context.CancelFunc) {
	const longTimeout = 10 * time.Second
	return context.WithTimeout(context.Background(), longTimeout)
}
// StringSliceContains reports whether item is equal to at least one element
// of slice. A nil or empty slice never contains anything.
func StringSliceContains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == item {
			return true
		}
	}
	return false
}
// StringSliceFilter returns a new slice holding, in their original order,
// the elements of slice for which predicate returns true. The result is
// always non-nil, even when nothing matches; the input is not modified.
func StringSliceFilter(slice []string, predicate func(string) bool) []string {
	kept := []string{}
	for i := range slice {
		if candidate := slice[i]; predicate(candidate) {
			kept = append(kept, candidate)
		}
	}
	return kept
}

View file

@ -0,0 +1,193 @@
// Package helpers provides shared utilities for WebSocket handlers.
package helpers
import (
"context"
"fmt"
"time"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/telemetry"
)
// ExperimentSetupResult bundles the outcome of an experiment setup run:
// the hex commit ID, the generated manifest and any error.
// NOTE(review): the setup helpers visible in this file return (string, error)
// directly rather than this struct — confirm it is still used by callers.
type ExperimentSetupResult struct {
// CommitIDStr is the commit ID rendered as a string.
CommitIDStr string
// Manifest is the experiment manifest associated with the setup, if any.
Manifest *experiment.Manifest
// Err is the error produced by setup, or nil.
Err error
}
// RunExperimentSetup performs the full experiment setup for a commit:
// create the experiment directory, write metadata, ensure the minimal files
// exist, then generate and write the content-integrity manifest.
//
// commitID is rendered as lowercase hex and used as the experiment ID.
// On success the hex commit ID is returned. On failure the step that failed
// is logged through logger and a wrapped error is returned with an empty ID.
func RunExperimentSetup(
	logger *logging.Logger,
	expMgr *experiment.Manager,
	commitID []byte,
	jobName string,
	userName string,
) (string, error) {
	// The directory/metadata/minimal-files steps are exactly the manifest-less
	// setup; delegate so the two entry points cannot drift apart. That call
	// already logs and wraps its own failures.
	commitIDStr, err := RunExperimentSetupWithoutManifest(logger, expMgr, commitID, jobName, userName)
	if err != nil {
		return "", err
	}

	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.generate_manifest", 100*time.Millisecond,
		func() (string, error) {
			manifest, err := expMgr.GenerateManifest(commitIDStr)
			if err != nil {
				return "", fmt.Errorf("failed to generate manifest: %w", err)
			}
			return "", expMgr.WriteManifest(manifest)
		},
	); err != nil {
		logger.Error("failed to generate/write manifest", "error", err)
		return "", fmt.Errorf("failed to generate content integrity manifest: %w", err)
	}

	return commitIDStr, nil
}
// RunExperimentSetupWithoutManifest runs the setup steps that precede
// manifest generation: create the experiment directory, write its metadata
// and ensure the minimal experiment files exist. It is used for jobs with
// args/note, where manifest generation is deferred.
//
// commitID is rendered as lowercase hex and returned on success. The steps
// run in order and stop at the first failure, which is logged through logger
// and returned wrapped, together with an empty ID.
func RunExperimentSetupWithoutManifest(
	logger *logging.Logger,
	expMgr *experiment.Manager,
	commitID []byte,
	jobName string,
	userName string,
) (string, error) {
	id := fmt.Sprintf("%x", commitID)

	steps := []struct {
		metric  string // telemetry operation name
		logMsg  string // message logged on failure
		wrapMsg string // prefix of the returned error
		run     func() error
	}{
		{
			metric:  "experiment.create",
			logMsg:  "failed to create experiment directory",
			wrapMsg: "failed to create experiment directory",
			run:     func() error { return expMgr.CreateExperiment(id) },
		},
		{
			metric:  "experiment.write_metadata",
			logMsg:  "failed to save experiment metadata",
			wrapMsg: "failed to save experiment metadata",
			run: func() error {
				// Built only when this step runs, so the timestamp is taken
				// after the directory has been created.
				return expMgr.WriteMetadata(&experiment.Metadata{
					CommitID:  id,
					JobName:   jobName,
					User:      userName,
					Timestamp: time.Now().Unix(),
				})
			},
		},
		{
			metric:  "experiment.ensure_minimal_files",
			logMsg:  "failed to ensure minimal experiment files",
			wrapMsg: "failed to initialize experiment files",
			run:     func() error { return EnsureMinimalExperimentFiles(expMgr, id) },
		},
	}

	for _, step := range steps {
		run := step.run
		_, err := telemetry.ExecWithMetrics(
			logger, step.metric, 50*time.Millisecond,
			func() (string, error) { return "", run() },
		)
		if err != nil {
			logger.Error(step.logMsg, "error", err)
			return "", fmt.Errorf("%s: %w", step.wrapMsg, err)
		}
	}
	return id, nil
}
// UpsertExperimentDBAsync records the experiment in the database without
// blocking the caller. It is fire-and-forget: the upsert runs in its own
// goroutine under a 3-second timeout, the row is written with status
// "pending", and a failure is only logged. A nil db makes this a no-op.
func UpsertExperimentDBAsync(
	logger *logging.Logger,
	db *storage.DB,
	commitIDStr string,
	jobName string,
	userName string,
) {
	if db == nil {
		return
	}

	upsert := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()

		row := &storage.Experiment{ID: commitIDStr, Name: jobName, Status: "pending", UserID: userName}
		write := func() (string, error) { return "", db.UpsertExperiment(ctx, row) }

		_, err := telemetry.ExecWithMetrics(logger, "db.experiments.upsert", 50*time.Millisecond, write)
		if err != nil {
			logger.Error("failed to upsert experiment row", "error", err)
		}
	}
	go upsert()
}
// TaskEnqueueResult bundles the outcome of enqueueing a task: the ID of the
// task and any error.
// NOTE(review): nothing visible in this file constructs this type — confirm
// how callers populate it.
type TaskEnqueueResult struct {
// TaskID identifies the enqueued task.
TaskID string
// Err is the error produced while enqueueing, or nil.
Err error
}
// BuildTaskMetadata assembles the standard metadata map for a task. The
// result always carries "commit_id", "dataset_id" and "params_hash" (even
// when empty), followed by every non-empty entry of prov. A prov entry whose
// key matches one of the standard keys replaces that value. prov may be nil.
func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[string]string) map[string]string {
	fields := make(map[string]string)
	fields["commit_id"] = commitIDStr
	fields["dataset_id"] = datasetID
	fields["params_hash"] = paramsHash

	for key, value := range prov {
		if value == "" {
			continue
		}
		fields[key] = value
	}
	return fields
}
// BuildSnapshotTaskMetadata assembles the metadata map for a snapshot job.
// The result always carries "commit_id" and "snapshot_sha256" (even when
// empty), followed by every non-empty entry of prov. A prov entry whose key
// matches one of those two keys replaces that value. prov may be nil.
func BuildSnapshotTaskMetadata(commitIDStr, snapshotSHA string, prov map[string]string) map[string]string {
	fields := make(map[string]string)
	fields["commit_id"] = commitIDStr
	fields["snapshot_sha256"] = snapshotSHA

	for key, value := range prov {
		if value == "" {
			continue
		}
		fields[key] = value
	}
	return fields
}
// ApplyResourceRequest copies the CPU, memory, GPU and GPU-memory settings
// from resources onto task. A nil resources leaves task untouched. task
// itself is dereferenced whenever resources is non-nil, so it must not be
// nil in that case.
func ApplyResourceRequest(task *queue.Task, resources *ResourceRequest) {
	if resources == nil {
		return
	}
	task.CPU, task.MemoryGB = resources.CPU, resources.MemoryGB
	task.GPU, task.GPUMemory = resources.GPU, resources.GPUMemory
}

View file

@ -0,0 +1,129 @@
// Package helpers provides shared utilities for WebSocket handlers.
package helpers
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// ComputeDatasetID derives a short dataset identifier: the first 16 hex
// characters of the SHA-256 over the concatenated identifiers.
//
// Each dataset spec contributes its Checksum, or its Name when the checksum
// is empty; specs with neither are skipped. If the specs yield no
// identifier at all, the plain dataset names are hashed instead. With
// nothing to hash the result is "". Identifiers are concatenated without a
// separator and in input order, so the ID is order-sensitive.
func ComputeDatasetID(datasetSpecs []queue.DatasetSpec, datasets []string) string {
	shortDigest := func(parts []string) string {
		hasher := sha256.New()
		for _, part := range parts {
			hasher.Write([]byte(part))
		}
		return hex.EncodeToString(hasher.Sum(nil))[:16]
	}

	var identifiers []string
	for _, spec := range datasetSpecs {
		switch {
		case spec.Checksum != "":
			identifiers = append(identifiers, spec.Checksum)
		case spec.Name != "":
			identifiers = append(identifiers, spec.Name)
		}
	}

	if len(identifiers) > 0 {
		return shortDigest(identifiers)
	}
	if len(datasets) > 0 {
		return shortDigest(datasets)
	}
	return ""
}
// ComputeParamsHash returns the first 16 hex characters of the SHA-256 of
// args with surrounding whitespace removed. Blank or whitespace-only args
// yield "".
func ComputeParamsHash(args string) string {
	trimmed := strings.TrimSpace(args)
	if trimmed == "" {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])[:16]
}
// FileSHA256Hex streams the file at path (after filepath.Clean) through
// SHA-256 and returns the full digest as lowercase hex. Open and read
// errors are returned unchanged together with an empty string.
func FileSHA256Hex(path string) (string, error) {
	file, openErr := os.Open(filepath.Clean(path))
	if openErr != nil {
		return "", openErr
	}
	// The handle is read-only, so a Close failure cannot lose data and is
	// deliberately ignored.
	defer func() { _ = file.Close() }()

	hasher := sha256.New()
	_, copyErr := io.Copy(hasher, file)
	if copyErr != nil {
		return "", copyErr
	}

	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest), nil
}
// ExpectedProvenanceForCommit builds the provenance metadata expected for a
// commit from what is stored on disk.
//
// The manifest's overall SHA is mandatory: a read failure, a nil manifest or
// an empty OverallSHA is returned as an error with a nil map. The dependency
// manifest is best-effort: "deps_manifest_name" and "deps_manifest_sha256"
// are added only when a dependency manifest can be selected and hashed, and
// any failure there is silently skipped.
func ExpectedProvenanceForCommit(
	expMgr *experiment.Manager,
	commitID string,
) (map[string]string, error) {
	manifest, err := expMgr.ReadManifest(commitID)
	if err != nil {
		return nil, err
	}
	if manifest == nil || manifest.OverallSHA == "" {
		return nil, fmt.Errorf("missing manifest overall_sha")
	}

	provenance := map[string]string{
		"experiment_manifest_overall_sha": manifest.OverallSHA,
	}

	filesDir := expMgr.GetFilesPath(commitID)
	depName, depErr := worker.SelectDependencyManifest(filesDir)
	if depErr != nil || strings.TrimSpace(depName) == "" {
		return provenance, nil
	}

	depSHA, hashErr := FileSHA256Hex(filepath.Join(filesDir, depName))
	if hashErr != nil || strings.TrimSpace(depSHA) == "" {
		return provenance, nil
	}

	provenance["deps_manifest_name"] = depName
	provenance["deps_manifest_sha256"] = depSHA
	return provenance, nil
}
// EnsureMinimalExperimentFiles makes sure the files directory of a commit
// exists (mode 0750) and contains placeholder train.py and requirements.txt
// files (mode 0640). A placeholder is written only when os.Stat reports the
// file as not existing; existing files are left as they are.
//
// It returns an error for a nil expMgr, a blank commitID, or any failure to
// create the directory or write a placeholder.
func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error {
	if expMgr == nil {
		return fmt.Errorf("missing experiment manager")
	}
	commitID = strings.TrimSpace(commitID)
	if commitID == "" {
		return fmt.Errorf("missing commit id")
	}

	filesDir := expMgr.GetFilesPath(commitID)
	if err := os.MkdirAll(filesDir, 0750); err != nil {
		return err
	}

	placeholders := []struct {
		name    string
		content string
	}{
		{name: "train.py", content: "print('ok')\n"},
		{name: "requirements.txt", content: "numpy==1.0.0\n"},
	}
	for _, ph := range placeholders {
		target := filepath.Join(filesDir, ph.name)
		if _, statErr := os.Stat(target); !os.IsNotExist(statErr) {
			// Present, or not inspectable: either way it is not overwritten.
			continue
		}
		if err := fileutil.SecureFileWrite(target, []byte(ph.content), 0640); err != nil {
			return err
		}
	}
	return nil
}

View file

@ -0,0 +1,121 @@
// Package helpers provides shared utilities for WebSocket handlers.
package helpers
import (
"encoding/binary"
"fmt"
)
// PayloadParser is a forward-only cursor over a binary WebSocket payload.
// Each Parse* method reads one field at the current offset and advances past
// it; multi-byte integers are big-endian. Every read is bounds-checked and
// reports malformed input as an error rather than panicking. The parser
// keeps a mutable offset without any locking.
type PayloadParser struct {
	payload []byte
	offset  int
}

// NewPayloadParser creates a parser over payload whose cursor starts at
// apiKeyHashLen, i.e. just after the leading API key hash. The payload is
// not copied. A start offset outside the payload is not rejected here; it
// makes every subsequent Parse* call return an error.
func NewPayloadParser(payload []byte, apiKeyHashLen int) *PayloadParser {
	return &PayloadParser{
		payload: payload,
		offset:  apiKeyHashLen,
	}
}

// available reports whether n more bytes can be read at the current offset.
// It compares against the remaining length instead of computing offset+n, so
// a huge n cannot overflow, and it rejects negative n and offsets that lie
// outside the payload.
func (p *PayloadParser) available(n int) bool {
	if n < 0 || p.offset < 0 || p.offset > len(p.payload) {
		return false
	}
	return n <= len(p.payload)-p.offset
}

// ParseByte parses a single byte and advances the offset.
func (p *PayloadParser) ParseByte() (byte, error) {
	if !p.available(1) {
		return 0, fmt.Errorf("payload too short at offset %d", p.offset)
	}
	b := p.payload[p.offset]
	p.offset++
	return b, nil
}

// ParseUint16 parses a 2-byte big-endian uint16 and advances the offset.
func (p *PayloadParser) ParseUint16() (uint16, error) {
	if !p.available(2) {
		return 0, fmt.Errorf("payload too short for uint16 at offset %d", p.offset)
	}
	v := binary.BigEndian.Uint16(p.payload[p.offset : p.offset+2])
	p.offset += 2
	return v, nil
}

// ParseLengthPrefixedString parses a string prefixed by a 1-byte length.
// Format: [length:1][string:var]
//
// If the length byte is present but the string body is truncated, an error
// is returned and the offset is left just after the length byte.
func (p *PayloadParser) ParseLengthPrefixedString() (string, error) {
	if !p.available(1) {
		return "", fmt.Errorf("payload too short for length at offset %d", p.offset)
	}
	length := int(p.payload[p.offset])
	p.offset++
	if !p.available(length) {
		return "", fmt.Errorf("payload too short for string of length %d at offset %d", length, p.offset)
	}
	str := string(p.payload[p.offset : p.offset+length])
	p.offset += length
	return str, nil
}

// ParseUint16PrefixedString parses a string prefixed by a 2-byte big-endian
// length.
// Format: [length:2][string:var]
//
// If the length is present but the string body is truncated, an error is
// returned and the offset is left just after the length field.
func (p *PayloadParser) ParseUint16PrefixedString() (string, error) {
	if !p.available(2) {
		return "", fmt.Errorf("payload too short for uint16 length at offset %d", p.offset)
	}
	length := int(binary.BigEndian.Uint16(p.payload[p.offset : p.offset+2]))
	p.offset += 2
	if !p.available(length) {
		return "", fmt.Errorf("payload too short for string of length %d at offset %d", length, p.offset)
	}
	str := string(p.payload[p.offset : p.offset+length])
	p.offset += length
	return str, nil
}

// Payload returns the underlying payload bytes.
func (p *PayloadParser) Payload() []byte {
	return p.payload
}

// Offset returns the current offset into the payload.
func (p *PayloadParser) Offset() int {
	return p.offset
}

// HasRemaining returns true if there are remaining bytes.
func (p *PayloadParser) HasRemaining() bool {
	return p.available(1)
}

// Remaining returns the bytes from the current offset to the end of the
// payload without advancing, or nil when nothing is left. The result shares
// memory with the payload.
func (p *PayloadParser) Remaining() []byte {
	if !p.available(1) {
		return nil
	}
	return p.payload[p.offset:]
}

// ParseBool parses a byte as a boolean (0 = false, non-zero = true).
func (p *PayloadParser) ParseBool() (bool, error) {
	b, err := p.ParseByte()
	if err != nil {
		return false, err
	}
	return b != 0, nil
}

// ParseFixedBytes parses exactly length bytes and advances the offset. The
// returned slice shares memory with the payload. A negative length, or one
// that exceeds the remaining payload, is reported as an error and leaves the
// offset unchanged.
func (p *PayloadParser) ParseFixedBytes(length int) ([]byte, error) {
	if length < 0 {
		return nil, fmt.Errorf("invalid negative length %d at offset %d", length, p.offset)
	}
	if !p.available(length) {
		return nil, fmt.Errorf("payload too short for %d bytes at offset %d", length, p.offset)
	}
	out := p.payload[p.offset : p.offset+length]
	p.offset += length
	return out, nil
}

View file

@ -0,0 +1,185 @@
// Package helpers provides shared utilities for WebSocket handlers.
package helpers
import (
"encoding/json"
"fmt"
"strings"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// ErrorCode is the single-byte WebSocket error code produced by
// TaskErrorMapper. The mapper returns raw numeric values (for example 0x24
// or 0x30) annotated with ErrorCode* names; the named constants themselves
// are not declared in this file.
type ErrorCode byte
// TaskErrorMapper translates task failures into ErrorCode values. It holds
// no state.
type TaskErrorMapper struct{}

// NewTaskErrorMapper returns a ready-to-use TaskErrorMapper.
func NewTaskErrorMapper() *TaskErrorMapper {
	return new(TaskErrorMapper)
}
// MapError classifies a task failure as an ErrorCode from the task's status
// and error text, both compared trimmed and lower-cased. Rules are tried in
// order and the first match wins:
//
//	status "cancelled"                              -> 0x24
//	"out of memory" / "oom"                         -> 0x30
//	"no space left" / "disk full"                   -> 0x31
//	"rate limit" / "too many requests" / "throttle" -> 0x33
//	"timed out" / "timeout" / "deadline"            -> 0x14
//	"connection refused" / "connection reset" /
//	"network unreachable"                           -> 0x12
//	"queue" together with "not configured"          -> 0x32
//	status "failed"                                 -> 0x23
//
// Text rules are plain substring matches. defaultCode is returned for a nil
// task or when no rule applies.
func (m *TaskErrorMapper) MapError(t *queue.Task, defaultCode ErrorCode) ErrorCode {
	if t == nil {
		return defaultCode
	}

	state := strings.ToLower(strings.TrimSpace(t.Status))
	msg := strings.ToLower(strings.TrimSpace(t.Error))

	// mentions reports whether the error text contains any of the needles.
	mentions := func(needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(msg, needle) {
				return true
			}
		}
		return false
	}

	switch {
	case state == "cancelled":
		return 0x24 // ErrorCodeJobCancelled
	case mentions("out of memory", "oom"):
		return 0x30 // ErrorCodeOutOfMemory
	case mentions("no space left", "disk full"):
		return 0x31 // ErrorCodeDiskFull
	case mentions("rate limit", "too many requests", "throttle"):
		return 0x33 // ErrorCodeServiceUnavailable
	case mentions("timed out", "timeout", "deadline"):
		return 0x14 // ErrorCodeTimeout
	case mentions("connection refused", "connection reset", "network unreachable"):
		return 0x12 // ErrorCodeNetworkError
	case mentions("queue") && mentions("not configured"):
		return 0x32 // ErrorCodeInvalidConfiguration
	case state == "failed":
		// Default for worker-side execution failures
		return 0x23 // ErrorCodeJobExecutionFailed
	default:
		return defaultCode
	}
}
// MapJupyterError classifies a Jupyter task failure as an ErrorCode. It
// applies exactly the same rules as MapError and falls back to 0x00
// (ErrorCodeUnknownError) for a nil task or when no rule matches.
func (m *TaskErrorMapper) MapJupyterError(t *queue.Task) ErrorCode {
	// Delegate rather than duplicate the rule chain, so the two mappings
	// cannot diverge.
	return m.MapError(t, 0x00) // ErrorCodeUnknownError
}
// ResourceRequest describes the compute resources requested for a task, as
// decoded by ParseResourceRequest.
type ResourceRequest struct {
	CPU       int    // value of the cpu byte
	MemoryGB  int    // value of the memory_gb byte
	GPU       int    // value of the gpu byte
	GPUMemory string // free-form GPU memory text; empty when not supplied
}

// ParseResourceRequest decodes an optional resource request.
// Format: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var]
//
// An empty payload means "no request" and yields (nil, nil). A payload
// shorter than the 4-byte header, or shorter than the declared GPU memory
// length, is an error. Bytes after the GPU memory field are ignored.
func ParseResourceRequest(payload []byte) (*ResourceRequest, error) {
	const headerLen = 4

	if len(payload) == 0 {
		return nil, nil
	}
	if len(payload) < headerLen {
		return nil, fmt.Errorf("resource payload too short")
	}

	// The length comes from a single byte, so it is always in 0..255 and
	// only the upper bound needs checking.
	gpuMemLen := int(payload[3])
	if len(payload) < headerLen+gpuMemLen {
		return nil, fmt.Errorf("invalid gpu memory length")
	}

	return &ResourceRequest{
		CPU:       int(payload[0]),
		MemoryGB:  int(payload[1]),
		GPU:       int(payload[2]),
		GPUMemory: string(payload[headerLen : headerLen+gpuMemLen]),
	}, nil
}
// JSONResponseBuilder wraps a value that is to be sent as a JSON response.
type JSONResponseBuilder struct {
	data interface{}
}

// NewJSONResponseBuilder returns a builder that will encode data.
func NewJSONResponseBuilder(data interface{}) *JSONResponseBuilder {
	builder := new(JSONResponseBuilder)
	builder.data = data
	return builder
}

// Build encodes the wrapped value with json.Marshal and returns the bytes
// together with any encoding error.
func (b *JSONResponseBuilder) Build() ([]byte, error) {
	encoded, err := json.Marshal(b.data)
	return encoded, err
}

// BuildOrEmpty encodes the wrapped value and, if encoding fails, returns the
// JSON empty array "[]" instead of an error.
func (b *JSONResponseBuilder) BuildOrEmpty() []byte {
	if encoded, err := b.Build(); err == nil {
		return encoded
	}
	return []byte("[]")
}
// StringPtr returns a pointer to a fresh copy of s. Each call yields a
// distinct pointer.
func StringPtr(s string) *string {
	ptr := new(string)
	*ptr = s
	return ptr
}
// IntPtr returns a pointer to a fresh copy of i. Each call yields a
// distinct pointer.
func IntPtr(i int) *int {
	ptr := new(int)
	*ptr = i
	return ptr
}
// MarshalJSONOrEmpty encodes data as JSON. If encoding fails the error is
// discarded and the JSON empty array "[]" is returned, so callers cannot
// tell a failed encoding apart from a genuinely empty list.
func MarshalJSONOrEmpty(data interface{}) []byte {
	encoded, err := json.Marshal(data)
	if err == nil {
		return encoded
	}
	return []byte("[]")
}
// MarshalJSONBytes encodes data with json.Marshal and passes both the bytes
// and the encoding error through to the caller unchanged.
func MarshalJSONBytes(data interface{}) ([]byte, error) {
	encoded, err := json.Marshal(data)
	return encoded, err
}
// IsEmptyJSON reports whether data carries no JSON content: after trimming
// surrounding whitespace it is empty, or exactly "null", "[]" or "{}".
// Containers with inner whitespace such as "[ ]" are not treated as empty.
func IsEmptyJSON(data []byte) bool {
	switch strings.TrimSpace(string(data)) {
	case "", "null", "[]", "{}":
		return true
	}
	return false
}

View file

@ -0,0 +1,237 @@
// Package helpers provides validation utilities for WebSocket handlers.
package helpers
import (
"encoding/hex"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// ValidateCommitIDFormat checks that commitID is exactly 40 hexadecimal
// characters. The length is checked first. On failure ok is false and errMsg
// is "invalid commit_id length" or "invalid commit_id hex"; on success
// errMsg is empty.
func ValidateCommitIDFormat(commitID string) (ok bool, errMsg string) {
	if n := len(commitID); n != 40 {
		return false, "invalid commit_id length"
	}
	_, decodeErr := hex.DecodeString(commitID)
	if decodeErr == nil {
		return true, ""
	}
	return false, "invalid commit_id hex"
}
// ValidateExperimentManifest runs the experiment manager's manifest
// validation for commitID. It returns (true, "") when validation passes and
// (false, <error text>) otherwise.
func ValidateExperimentManifest(expMgr *experiment.Manager, commitID string) (ok bool, details string) {
	err := expMgr.ValidateManifest(commitID)
	if err == nil {
		return true, ""
	}
	return false, err.Error()
}
// ValidateDepsManifest checks that the commit's files directory contains a
// dependency manifest and that it can be hashed.
//
// On success it returns the manifest name, a passing check whose Actual is
// "<name>:<sha256>", and nil messages. If no manifest can be selected the
// name is empty and the message is "deps manifest missing"; if hashing fails
// the name is kept and the message is "deps manifest hash failed". In both
// failure cases the check carries the underlying error text in Details.
func ValidateDepsManifest(
	expMgr *experiment.Manager,
	commitID string,
) (depName string, check ValidateCheck, errMsgs []string) {
	filesDir := expMgr.GetFilesPath(commitID)

	selected, selectErr := worker.SelectDependencyManifest(filesDir)
	if selectErr != nil {
		failed := ValidateCheck{OK: false, Details: selectErr.Error()}
		return "", failed, []string{"deps manifest missing"}
	}

	digest, hashErr := FileSHA256Hex(filepath.Join(filesDir, selected))
	if hashErr != nil {
		failed := ValidateCheck{OK: false, Details: hashErr.Error()}
		return selected, failed, []string{"deps manifest hash failed"}
	}

	return selected, ValidateCheck{OK: true, Actual: selected + ":" + digest}, nil
}
// ValidateCheck is the outcome of a single validation check. It is
// JSON-serializable; every field except OK is omitted from the output when
// empty.
type ValidateCheck struct {
// OK reports whether the check passed.
OK bool `json:"ok"`
// Expected is the value the check wanted to see, when one applies.
Expected string `json:"expected,omitempty"`
// Actual is the value the check observed, when one applies.
Actual string `json:"actual,omitempty"`
// Details carries free-form text such as an underlying error message.
Details string `json:"details,omitempty"`
}
// ValidateReport aggregates the results of a validation run. It is
// JSON-serializable; CommitID, TaskID, Errors and Warnings are omitted from
// the output when empty.
type ValidateReport struct {
// OK is the overall verdict.
OK bool `json:"ok"`
// CommitID is the commit the report refers to, when known.
CommitID string `json:"commit_id,omitempty"`
// TaskID is the task the report refers to, when known.
TaskID string `json:"task_id,omitempty"`
// Checks holds the individual check results keyed by check name.
Checks map[string]ValidateCheck `json:"checks"`
// Errors lists error messages collected during validation.
Errors []string `json:"errors,omitempty"`
// Warnings lists warning messages collected during validation.
Warnings []string `json:"warnings,omitempty"`
// TS is serialized as "ts"; presumably a timestamp string.
// NOTE(review): its format is not established in this file — confirm.
TS string `json:"ts"`
}
// NewValidateReport returns a report that starts out passing (OK is true)
// with an empty, non-nil Checks map. All other fields, including TS, are
// left at their zero values for the caller to fill in.
func NewValidateReport() ValidateReport {
	var report ValidateReport
	report.OK = true
	report.Checks = make(map[string]ValidateCheck)
	return report
}
// ShouldRequireRunManifest reports whether a run manifest is mandatory for
// the task: true when its status, trimmed and lower-cased, is "running",
// "completed" or "failed". A nil task or any other status yields false.
func ShouldRequireRunManifest(task *queue.Task) bool {
	if task == nil {
		return false
	}
	status := strings.ToLower(strings.TrimSpace(task.Status))
	return status == "running" || status == "completed" || status == "failed"
}
// ExpectedRunManifestBucketForStatus maps a task status (compared trimmed
// and lower-cased) to the bucket its run manifest is expected in:
//
//	"queued", "pending"     -> "pending"
//	"running"               -> "running"
//	"completed", "finished" -> "finished"
//	"failed"                -> "failed"
//
// The boolean is false, with an empty bucket, for any other status.
func ExpectedRunManifestBucketForStatus(status string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(status))

	if normalized == "queued" || normalized == "pending" {
		return "pending", true
	}
	if normalized == "running" || normalized == "failed" {
		// These statuses share their name with the bucket.
		return normalized, true
	}
	if normalized == "completed" || normalized == "finished" {
		return "finished", true
	}
	return "", false
}
// FindRunManifestDir locates the directory holding the run manifest of a
// job. Under basePath it probes the running, pending, finished and failed
// job roots, in that order, for a sub-directory named jobName that contains
// a manifest file, and returns the first hit with its bucket name.
//
// found is false, with empty dir and bucket, when basePath or jobName is
// blank or when no bucket holds such a directory.
func FindRunManifestDir(basePath string, jobName string) (dir string, bucket string, found bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}

	paths := config.NewJobPaths(basePath)
	// The two slices are parallel; their order is the probing order.
	buckets := []string{"running", "pending", "finished", "failed"}
	roots := []string{
		paths.RunningPath(),
		paths.PendingPath(),
		paths.FinishedPath(),
		paths.FailedPath(),
	}

	for i, root := range roots {
		candidate := filepath.Join(root, jobName)
		info, statErr := os.Stat(candidate)
		if statErr != nil || !info.IsDir() {
			continue
		}
		if _, manifestErr := os.Stat(manifest.ManifestPath(candidate)); manifestErr != nil {
			continue
		}
		return candidate, buckets[i], true
	}
	return "", "", false
}
// ValidateRunManifestLifecycle checks that the lifecycle fields of a run
// manifest are consistent with the task status (compared trimmed and
// lower-cased):
//
//   - "running": started_at set; ended_at and exit_code unset.
//   - "completed"/"failed": started_at, ended_at and exit_code all set, and
//     ended_at not before started_at.
//   - "queued"/"pending": ended_at and exit_code unset.
//
// Any other status passes without inspecting the manifest. The first
// violated rule is returned as (false, <reason>); otherwise (true, "").
func ValidateRunManifestLifecycle(rm *manifest.RunManifest, status string) (ok bool, details string) {
	switch strings.ToLower(strings.TrimSpace(status)) {
	case "running":
		switch {
		case rm.StartedAt.IsZero():
			return false, "missing started_at for running task"
		case !rm.EndedAt.IsZero():
			return false, "ended_at must be empty for running task"
		case rm.ExitCode != nil:
			return false, "exit_code must be empty for running task"
		}
	case "completed", "failed":
		switch {
		case rm.StartedAt.IsZero():
			return false, "missing started_at for completed/failed task"
		case rm.EndedAt.IsZero():
			return false, "missing ended_at for completed/failed task"
		case rm.ExitCode == nil:
			return false, "missing exit_code for completed/failed task"
		case rm.EndedAt.Before(rm.StartedAt):
			// Both timestamps are known to be set at this point.
			return false, "ended_at is before started_at"
		}
	case "queued", "pending":
		// queued/pending tasks may not have started yet.
		if finished := !rm.EndedAt.IsZero() || rm.ExitCode != nil; finished {
			return false, "queued/pending task should not have ended_at/exit_code"
		}
	}
	return true, ""
}
// ValidateTaskIDMatch checks that the run manifest names the expected task.
// A blank manifest task ID fails with only Expected filled in; otherwise the
// check passes exactly when the manifest task ID equals expectedTaskID, and
// both Expected and Actual are reported.
func ValidateTaskIDMatch(rm *manifest.RunManifest, expectedTaskID string) ValidateCheck {
	recorded := rm.TaskID
	if strings.TrimSpace(recorded) == "" {
		return ValidateCheck{OK: false, Expected: expectedTaskID}
	}
	return ValidateCheck{
		OK:       recorded == expectedTaskID,
		Expected: expectedTaskID,
		Actual:   recorded,
	}
}
// ValidateCommitIDMatch compares the commit ID recorded in a run manifest
// with the expected one, both trimmed. The check fails only when both are
// non-empty and differ. With no expected commit ID it passes with no values
// reported; otherwise Expected and Actual carry the trimmed IDs.
func ValidateCommitIDMatch(rmCommitID, expectedCommitID string) ValidateCheck {
	want := strings.TrimSpace(expectedCommitID)
	if want == "" {
		return ValidateCheck{OK: true}
	}
	got := strings.TrimSpace(rmCommitID)
	mismatch := got != "" && got != want
	return ValidateCheck{OK: !mismatch, Expected: want, Actual: got}
}
// ValidateDepsProvenance compares the expected and recorded dependency
// manifest, each identified by name and SHA. If any of the four values is
// empty there is nothing to compare and the check passes with no values
// reported. Otherwise it passes only when both name and SHA match, and
// Expected/Actual are reported as "<name>:<sha>".
func ValidateDepsProvenance(wantName, wantSHA, gotName, gotSHA string) ValidateCheck {
	for _, field := range []string{wantName, wantSHA, gotName, gotSHA} {
		if field == "" {
			return ValidateCheck{OK: true}
		}
	}
	return ValidateCheck{
		OK:       wantName == gotName && wantSHA == gotSHA,
		Expected: wantName + ":" + wantSHA,
		Actual:   gotName + ":" + gotSHA,
	}
}
// ValidateSnapshotID compares the expected snapshot ID with the one in the
// run manifest. The check fails only when both are non-empty and differ; if
// either is empty it passes. Expected and Actual always carry the inputs.
func ValidateSnapshotID(wantID, gotID string) ValidateCheck {
	comparable := wantID != "" && gotID != ""
	return ValidateCheck{
		OK:       !comparable || wantID == gotID,
		Expected: wantID,
		Actual:   gotID,
	}
}
// ValidateSnapshotSHA compares the expected snapshot SHA with the one in the
// run manifest. The check fails only when both are non-empty and differ; if
// either is empty it passes. Expected and Actual always carry the inputs.
func ValidateSnapshotSHA(wantSHA, gotSHA string) ValidateCheck {
	comparable := wantSHA != "" && gotSHA != ""
	return ValidateCheck{
		OK:       !comparable || wantSHA == gotSHA,
		Expected: wantSHA,
		Actual:   gotSHA,
	}
}
// ContainerStat is a function type for stat operations (for mocking in tests)
var ContainerStat = func(path string) (os.FileInfo, error) {
return os.Stat(path)
}

View file

@ -1,40 +1,30 @@
package api
import (
"context"
"database/sql"
"encoding/binary"
"encoding/json"
"net/url"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/storage"
)
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16]
if len(payload) < 16 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "")
user, err := h.authenticate(conn, payload, ProtocolMinDatasetList)
if err != nil {
return err
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
return err
}
if h.db == nil {
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
if err := h.requireDB(conn); err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := helpers.DBContextShort()
defer cancel()
datasets, err := h.db.ListDatasets(ctx, 0)
@ -55,26 +45,18 @@ func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) erro
}
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
if len(payload) < 16+1+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "")
user, err := h.authenticate(conn, payload, ProtocolMinDatasetRegister)
if err != nil {
return err
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
if err := h.requirePermission(user, PermDatasetsCreate, conn); err != nil {
return err
}
if h.db == nil {
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
if err := h.requireDB(conn); err != nil {
return err
}
offset := 16
offset := ProtocolAPIKeyHashLen
nameLen := int(payload[offset])
offset++
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
@ -90,7 +72,6 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte)
}
urlStr := string(payload[offset : offset+urlLen])
// Minimal validation (server-side authoritative): name non-empty and url parseable.
if strings.TrimSpace(name) == "" {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
}
@ -98,7 +79,7 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte)
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := helpers.DBContextShort()
defer cancel()
if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil {
@ -108,26 +89,18 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte)
}
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][name_len:1][name:var]
if len(payload) < 16+1 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
user, err := h.authenticate(conn, payload, ProtocolMinDatasetInfo)
if err != nil {
return err
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
return err
}
if h.db == nil {
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
if err := h.requireDB(conn); err != nil {
return err
}
offset := 16
offset := ProtocolAPIKeyHashLen
nameLen := int(payload[offset])
offset++
if nameLen <= 0 || len(payload) < offset+nameLen {
@ -135,7 +108,7 @@ func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) erro
}
name := string(payload[offset : offset+nameLen])
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := helpers.DBContextShort()
defer cancel()
ds, err := h.db.GetDataset(ctx, name)
@ -159,26 +132,18 @@ func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) erro
}
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][term_len:1][term:var]
if len(payload) < 16+1 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
user, err := h.authenticate(conn, payload, ProtocolMinDatasetSearch)
if err != nil {
return err
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
return err
}
if h.db == nil {
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
if err := h.requireDB(conn); err != nil {
return err
}
offset := 16
offset := ProtocolAPIKeyHashLen
termLen := int(payload[offset])
offset++
if termLen < 0 || len(payload) < offset+termLen {
@ -187,7 +152,7 @@ func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) er
term := string(payload[offset : offset+termLen])
term = strings.TrimSpace(term)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := helpers.DBContextShort()
defer cancel()
datasets, err := h.db.SearchDatasets(ctx, term, 0)

View file

@ -46,6 +46,10 @@ const (
OpcodeListJupyter = 0x0F
OpcodeListJupyterPackages = 0x1E
OpcodeValidateRequest = 0x16
// Logs opcodes
OpcodeGetLogs = 0x20
OpcodeStreamLogs = 0x21
)
// createUpgrader creates a WebSocket upgrader with the given security configuration
@ -288,7 +292,88 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
return h.handleListJupyterPackages(conn, payload)
case OpcodeValidateRequest:
return h.handleValidateRequest(conn, payload)
case OpcodeGetLogs:
return h.handleGetLogs(conn, payload)
case OpcodeStreamLogs:
return h.handleStreamLogs(conn, payload)
default:
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
}
}
// AuthHandler is the signature of a WebSocket message handler that receives an
// authenticated user in addition to the connection and the raw message payload.
// It returns an error like the other message handlers in this package.
type AuthHandler func(conn *websocket.Conn, payload []byte, user *auth.User) error
// authenticate validates the leading API key hash of a raw payload and returns
// the user it resolves to.
//
// minLen is the minimum payload size for the calling operation. The payload
// must additionally hold ProtocolAPIKeyHashLen bytes, so a caller that passes a
// smaller minLen gets an error instead of an out-of-range slice panic.
//
// A nil error is always accompanied by a non-nil user. On any failure an error
// packet is sent to the client and a non-nil error is returned, so the caller's
// `if err != nil { return err }` guard actually stops the request.
//
// NOTE(review): sendErrorPacket appears to report only the transport error
// (handlers return its result directly as their final result), so its return
// value alone cannot signal that the request was rejected — confirm.
func (h *WSHandler) authenticate(conn *websocket.Conn, payload []byte, minLen int) (*auth.User, error) {
	need := minLen
	if need < ProtocolAPIKeyHashLen {
		need = ProtocolAPIKeyHashLen
	}
	if len(payload) < need {
		if sendErr := h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", ""); sendErr != nil {
			return nil, sendErr
		}
		return nil, fmt.Errorf("payload too short: got %d bytes, need at least %d", len(payload), need)
	}

	// Share the hash validation path with callers that pre-extract the hash.
	user, err := h.authenticateWithHash(conn, payload[:ProtocolAPIKeyHashLen])
	if err != nil {
		return nil, err
	}
	if user == nil {
		// Never hand a nil user to the caller: the permission check that
		// follows dereferences it.
		return nil, fmt.Errorf("authentication failed: no user resolved for API key")
	}
	return user, nil
}
// authenticateWithHash validates a pre-extracted API key hash and returns the
// user it resolves to.
//
// When no auth configuration is set, a built-in "default" admin user with the
// wildcard permission is returned without consulting the hash.
//
// A nil error is always accompanied by a non-nil user. When validation fails,
// an error packet is sent to the client and a non-nil error is returned so that
// callers stop processing the request.
//
// NOTE(review): sendErrorPacket appears to report only the transport error
// (handlers return its result directly as their final result), so its return
// value alone cannot signal that authentication failed — confirm.
func (h *WSHandler) authenticateWithHash(conn *websocket.Conn, apiKeyHash []byte) (*auth.User, error) {
	if h.authConfig == nil {
		return &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}, nil
	}

	user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		if sendErr := h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()); sendErr != nil {
			return nil, sendErr
		}
		return nil, fmt.Errorf("authentication failed: %w", err)
	}
	if user == nil {
		// Defensive: a validator that yields no user and no error must not let
		// the request through with a nil user.
		h.logger.Error("invalid api key", "error", "no user resolved")
		if sendErr := h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", ""); sendErr != nil {
			return nil, sendErr
		}
		return nil, fmt.Errorf("authentication failed: no user resolved for API key")
	}
	return user, nil
}
// requirePermission checks that user holds permission.
//
// The check is skipped (nil is returned) when no auth configuration is set or
// authentication is disabled. Otherwise a nil user, or a user lacking the
// permission, is rejected: the denial is logged, a permission-denied packet is
// sent on conn, and a non-nil error is returned so the caller's
// `if err != nil { return err }` guard stops the request before the protected
// operation runs.
//
// NOTE(review): sendErrorPacket appears to report only the transport error
// (handlers return its result directly as their final result), so its return
// value alone cannot signal that the request was denied — confirm.
func (h *WSHandler) requirePermission(
	user *auth.User,
	permission string,
	conn *websocket.Conn,
) error {
	if h.authConfig == nil || !h.authConfig.Enabled {
		return nil
	}
	if user != nil && user.HasPermission(permission) {
		return nil
	}

	// Fail closed on a nil user instead of dereferencing it.
	userName := "<nil>"
	if user != nil {
		userName = user.Name
	}
	h.logger.Error("insufficient permissions", "user", userName, "required", permission)

	sendErr := h.sendErrorPacket(
		conn,
		ErrorCodePermissionDenied,
		fmt.Sprintf("Insufficient permissions: %s", permission),
		"",
	)
	if sendErr != nil {
		return sendErr
	}
	return fmt.Errorf("permission denied: %s", permission)
}
// requireDB checks that a database is configured.
//
// It returns nil when h.db is set. Otherwise it sends a database-error packet
// on conn and returns a non-nil error, so the caller's
// `if err != nil { return err }` guard stops before h.db is used.
//
// NOTE(review): sendErrorPacket appears to report only the transport error
// (handlers return its result directly as their final result), so its return
// value alone cannot signal that the database is missing — confirm.
func (h *WSHandler) requireDB(conn *websocket.Conn) error {
	if h.db != nil {
		return nil
	}
	if sendErr := h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", ""); sendErr != nil {
		return sendErr
	}
	return fmt.Errorf("database not configured")
}

File diff suppressed because it is too large Load diff

View file

@ -9,44 +9,16 @@ import (
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// JupyterTaskErrorCode returns the error code for a Jupyter task.
// This is kept for backward compatibility and delegates to the helper.
func JupyterTaskErrorCode(t *queue.Task) byte {
if t == nil {
return ErrorCodeUnknownError
}
status := strings.ToLower(strings.TrimSpace(t.Status))
errStr := strings.ToLower(strings.TrimSpace(t.Error))
if status == "cancelled" {
return ErrorCodeJobCancelled
}
if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") {
return ErrorCodeOutOfMemory
}
if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") {
return ErrorCodeDiskFull
}
if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") {
return ErrorCodeServiceUnavailable
}
if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") {
return ErrorCodeTimeout
}
if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") {
return ErrorCodeNetworkError
}
if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") {
return ErrorCodeInvalidConfiguration
}
// Default for worker-side execution failures.
if status == "failed" {
return ErrorCodeJobExecutionFailed
}
return ErrorCodeUnknownError
mapper := helpers.NewTaskErrorMapper()
return byte(mapper.MapJupyterError(t))
}
type jupyterTaskOutput struct {
@ -58,37 +30,15 @@ type jupyterTaskOutput struct {
}
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][name_len:1][name:var]
if len(payload) < 18 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
}
user, err := h.validateWSUser(apiKeyHash)
user, err := h.authenticate(conn, payload, 18)
if err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
return err
}
if user != nil && !user.HasPermission("jupyter:manage") {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
if err := h.requirePermission(user, PermJupyterManage, conn); err != nil {
return err
}
offset := 16
offset := ProtocolAPIKeyHashLen
nameLen := int(payload[offset])
offset++
if len(payload) < offset+nameLen {
@ -183,13 +133,11 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
}
offset := 16
nameLen := int(payload[offset])
offset++
if len(payload) < offset+nameLen {
p := helpers.NewPayloadParser(payload, 16)
name, err := p.ParseLengthPrefixedString()
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
}
name := string(payload[offset : offset+nameLen])
name = strings.TrimSpace(name)
if name == "" {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "")
@ -217,7 +165,7 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by
out := strings.TrimSpace(result.Output)
if out == "" {
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
}
var payloadOut jupyterTaskOutput
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
@ -228,7 +176,7 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload))
}
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
}
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
@ -427,16 +375,12 @@ func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) erro
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
}
offset := 16
idLen := int(payload[offset])
offset++
if len(payload) < offset+idLen {
p := helpers.NewPayloadParser(payload, 16)
serviceID, err := p.ParseLengthPrefixedString()
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
}
serviceID := string(payload[offset : offset+idLen])
meta := map[string]string{
jupyterTaskActionKey: jupyterActionStop,
jupyterServiceIDKey: strings.TrimSpace(serviceID),
@ -466,19 +410,17 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er
apiKeyHash := payload[:16]
offset := 16
idLen := int(payload[offset])
offset++
if len(payload) < offset+idLen {
p := helpers.NewPayloadParser(payload, 16)
serviceID, err := p.ParseLengthPrefixedString()
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
}
serviceID := string(payload[offset : offset+idLen])
offset += idLen
// Optional: purge flag (1 byte). Default false for trash-first behavior.
purge := false
if len(payload) > offset {
purge = payload[offset] == 0x01
if p.HasRemaining() {
purgeByte, _ := p.ParseByte()
purge = purgeByte == 0x01
}
if h.authConfig != nil && h.authConfig.Enabled {
@ -528,34 +470,12 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er
}
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16]
if len(payload) < 16 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "")
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
}
user, err := h.validateWSUser(apiKeyHash)
user, err := h.authenticate(conn, payload, ProtocolMinDatasetList)
if err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
return err
}
if user != nil && !user.HasPermission("jupyter:read") {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
if err := h.requirePermission(user, PermJupyterRead, conn); err != nil {
return err
}
meta := map[string]string{
@ -578,18 +498,15 @@ func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) erro
out := strings.TrimSpace(result.Output)
if out == "" {
empty, _ := json.Marshal([]any{})
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", empty))
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
}
var payloadOut jupyterTaskOutput
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
// Always return an array payload (even if empty) so clients can render a stable table.
payload := payloadOut.Services
if len(payload) == 0 {
payload = []byte("[]")
}
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload))
}
// Fallback: return empty array on unexpected output.
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
}

View file

@ -0,0 +1,43 @@
package api
// Protocol constants for the WebSocket binary protocol.
//
// Every payload starts with [api_key_hash:16]. Each ProtocolMin* value is the
// sum of the fixed-width fields in the layout written above it; variable-length
// bodies (names, URLs, notes, ...) contribute nothing to the minimum and must
// be bounds-checked by the handler that parses them.
//
// NOTE(review): the minimums are derived from the documented layouts so that a
// handler can always read every fixed-width field; confirm each layout against
// the parsing code of the corresponding handler.
const (
	// ProtocolAPIKeyHashLen is the size of the auth header present in all payloads.
	ProtocolAPIKeyHashLen = 16
	// ProtocolCommitIDLen is the size of a raw commit ID
	// (20 bytes, i.e. 40 characters when hex-encoded).
	ProtocolCommitIDLen = 20

	// Minimum payload sizes for each operation.

	// [api_key_hash:16]
	ProtocolMinStatusRequest = ProtocolAPIKeyHashLen
	// [api_key_hash:16][job_name_len:1]
	ProtocolMinCancelJob = ProtocolAPIKeyHashLen + 1
	// [api_key_hash:16][prune_type:1][value:4]
	ProtocolMinPrune = ProtocolAPIKeyHashLen + 1 + 4
	// [api_key_hash:16]
	ProtocolMinDatasetList = ProtocolAPIKeyHashLen
	// [api_key_hash:16][name_len:1][url_len:2]
	ProtocolMinDatasetRegister = ProtocolAPIKeyHashLen + 1 + 2
	// [api_key_hash:16][name_len:1]
	ProtocolMinDatasetInfo = ProtocolAPIKeyHashLen + 1
	// [api_key_hash:16][term_len:1]
	ProtocolMinDatasetSearch = ProtocolAPIKeyHashLen + 1
	// [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1]
	ProtocolMinLogMetric = ProtocolAPIKeyHashLen + ProtocolCommitIDLen + 4 + 8 + 1
	// [api_key_hash:16][commit_id:20]
	ProtocolMinGetExperiment = ProtocolAPIKeyHashLen + ProtocolCommitIDLen
	// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1]
	ProtocolMinQueueJob = ProtocolAPIKeyHashLen + ProtocolCommitIDLen + 1 + 1
	// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][snap_id_len:1]
	ProtocolMinQueueJobWithSnapshot = ProtocolAPIKeyHashLen + ProtocolCommitIDLen + 1 + 1 + 1
	// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][tracking_len:2]
	ProtocolMinQueueJobWithTracking = ProtocolAPIKeyHashLen + ProtocolCommitIDLen + 1 + 1 + 2
	// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][args_len:2][note_len:2][force:1]
	ProtocolMinQueueJobWithNote = ProtocolAPIKeyHashLen + ProtocolCommitIDLen + 1 + 1 + 2 + 2 + 1
	// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][args_len:2][force:1]
	ProtocolMinQueueJobWithArgs = ProtocolAPIKeyHashLen + ProtocolCommitIDLen + 1 + 1 + 2 + 1
	// [api_key_hash:16][job_name_len:1][author_len:1][note_len:2]
	ProtocolMinAnnotateRun = ProtocolAPIKeyHashLen + 1 + 1 + 2
	// [api_key_hash:16][job_name_len:1][patch_len:2]
	ProtocolMinSetRunNarrative = ProtocolAPIKeyHashLen + 1 + 2

	// Logs and debug minimum payload sizes.

	// [api_key_hash:16][target_id_len:1]
	ProtocolMinGetLogs = ProtocolAPIKeyHashLen + 1
	// [api_key_hash:16][target_id_len:1]
	ProtocolMinStreamLogs = ProtocolAPIKeyHashLen + 1
	// [api_key_hash:16][target_id_len:1]
	ProtocolMinAttachDebug = ProtocolAPIKeyHashLen + 1

	// Permission constants.
	PermJobsCreate     = "jobs:create"
	PermJobsRead       = "jobs:read"
	PermJobsUpdate     = "jobs:update"
	PermDatasetsRead   = "datasets:read"
	PermDatasetsCreate = "datasets:create"
	PermJupyterManage  = "jupyter:manage"
	PermJupyterRead    = "jupyter:read"
)

View file

@ -1,7 +1,6 @@
package api
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
@ -11,6 +10,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/manifest"
@ -228,20 +228,17 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
}
// Validate commit id format
if len(commitID) != 40 {
if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok {
r.OK = false
r.Errors = append(r.Errors, "invalid commit_id length")
} else if _, err := hex.DecodeString(commitID); err != nil {
r.OK = false
r.Errors = append(r.Errors, "invalid commit_id hex")
r.Errors = append(r.Errors, errMsg)
}
// Experiment manifest integrity
// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
if r.OK {
if err := h.expManager.ValidateManifest(commitID); err != nil {
if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok {
r.OK = false
r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()}
r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details}
r.Errors = append(r.Errors, "experiment manifest validation failed")
} else {
r.Checks["experiment_manifest"] = validateCheck{OK: true}
@ -251,29 +248,13 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
// Deps manifest presence + hash
// TODO(context): Allow client to declare which dependency manifest is authoritative.
filesPath := h.expManager.GetFilesPath(commitID)
depName, depErr := worker.SelectDependencyManifest(filesPath)
if depErr != nil {
depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID)
if depErrs != nil {
r.OK = false
r.Checks["deps_manifest"] = validateCheck{
OK: false,
Details: depErr.Error(),
}
r.Errors = append(r.Errors, "deps manifest missing")
r.Checks["deps_manifest"] = validateCheck(depCheck)
r.Errors = append(r.Errors, depErrs...)
} else {
sha, err := fileSHA256Hex(filepath.Join(filesPath, depName))
if err != nil {
r.OK = false
r.Checks["deps_manifest"] = validateCheck{
OK: false,
Details: err.Error(),
}
r.Errors = append(r.Errors, "deps manifest hash failed")
} else {
r.Checks["deps_manifest"] = validateCheck{
OK: true,
Actual: depName + ":" + sha,
}
}
r.Checks["deps_manifest"] = validateCheck(depCheck)
}
// Compare against expected task metadata if available.
@ -339,158 +320,58 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
}
}
if strings.TrimSpace(rm.TaskID) == "" {
r.OK = false
r.Errors = append(r.Errors, "run manifest missing task_id")
r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID}
} else if rm.TaskID != task.ID {
// Validate task ID using helper
taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID)
r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck)
if !taskIDCheck.OK {
r.OK = false
r.Errors = append(r.Errors, "run manifest task_id mismatch")
r.Checks["run_manifest_task_id"] = validateCheck{
OK: false,
Expected: task.ID,
Actual: rm.TaskID,
}
} else {
r.Checks["run_manifest_task_id"] = validateCheck{
OK: true,
Expected: task.ID,
Actual: rm.TaskID,
}
}
commitWant := strings.TrimSpace(task.Metadata["commit_id"])
commitGot := strings.TrimSpace(rm.CommitID)
if commitWant != "" && commitGot != "" && commitWant != commitGot {
// Validate commit ID using helper
commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"])
r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck)
if !commitCheck.OK {
r.OK = false
r.Errors = append(r.Errors, "run manifest commit_id mismatch")
r.Checks["run_manifest_commit_id"] = validateCheck{
OK: false,
Expected: commitWant,
Actual: commitGot,
}
} else if commitWant != "" {
r.Checks["run_manifest_commit_id"] = validateCheck{
OK: true,
Expected: commitWant,
Actual: commitGot,
}
}
// Validate deps provenance using helper
depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
depGotName := strings.TrimSpace(rm.DepsManifestName)
depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" {
expectedDep := depWantName + ":" + depWantSHA
actualDep := depGotName + ":" + depGotSHA
if depWantName != depGotName || depWantSHA != depGotSHA {
r.OK = false
r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
r.Checks["run_manifest_deps"] = validateCheck{
OK: false,
Expected: expectedDep,
Actual: actualDep,
}
} else {
r.Checks["run_manifest_deps"] = validateCheck{
OK: true,
Expected: expectedDep,
Actual: actualDep,
}
}
depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA)
r.Checks["run_manifest_deps"] = validateCheck(depsCheck)
if !depsCheck.OK {
r.OK = false
r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
}
// Validate snapshot using helpers
if strings.TrimSpace(task.SnapshotID) != "" {
snapWantID := strings.TrimSpace(task.SnapshotID)
snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
snapGotID := strings.TrimSpace(rm.SnapshotID)
snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)
if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID {
snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID)
r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck)
if !snapIDCheck.OK {
r.OK = false
r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
r.Checks["run_manifest_snapshot_id"] = validateCheck{
OK: false,
Expected: snapWantID,
Actual: snapGotID,
}
} else {
r.Checks["run_manifest_snapshot_id"] = validateCheck{
OK: true,
Expected: snapWantID,
Actual: snapGotID,
}
}
if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA {
snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA)
r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck)
if !snapSHACheck.OK {
r.OK = false
r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
OK: false,
Expected: snapWantSHA,
Actual: snapGotSHA,
}
} else if snapWantSHA != "" {
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
OK: true,
Expected: snapWantSHA,
Actual: snapGotSHA,
}
}
}
statusLower := strings.ToLower(strings.TrimSpace(task.Status))
lifecycleOK := true
details := ""
switch statusLower {
case "running":
if rm.StartedAt.IsZero() {
lifecycleOK = false
details = "missing started_at for running task"
}
if !rm.EndedAt.IsZero() {
lifecycleOK = false
if details == "" {
details = "ended_at must be empty for running task"
}
}
if rm.ExitCode != nil {
lifecycleOK = false
if details == "" {
details = "exit_code must be empty for running task"
}
}
case "completed", "failed":
if rm.StartedAt.IsZero() {
lifecycleOK = false
details = "missing started_at for completed/failed task"
}
if rm.EndedAt.IsZero() {
lifecycleOK = false
if details == "" {
details = "missing ended_at for completed/failed task"
}
}
if rm.ExitCode == nil {
lifecycleOK = false
if details == "" {
details = "missing exit_code for completed/failed task"
}
}
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
lifecycleOK = false
if details == "" {
details = "ended_at is before started_at"
}
}
case "queued", "pending":
// queued/pending tasks may not have started yet.
if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
lifecycleOK = false
details = "queued/pending task should not have ended_at/exit_code"
}
}
// Validate lifecycle using helper
lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status)
if lifecycleOK {
r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
} else {
@ -535,7 +416,7 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
r.Errors = append(r.Errors, "missing expected deps manifest provenance")
r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
} else if depName != "" {
sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName))
sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName))
ok := (wantDep == depName && wantDepSha == sha)
if !ok {
r.OK = false