refactor: improve API structure and WebSocket protocol
- Extract WebSocket protocol handling to dedicated module - Add helper functions for DB operations, validation, and responses - Improve WebSocket frame handling and opcodes - Refactor dataset, job, and Jupyter handlers - Add duplicate detection processing
This commit is contained in:
parent
1147958e15
commit
b05470b30a
13 changed files with 1663 additions and 1370 deletions
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
|
|
@ -62,9 +63,8 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
|
|||
"message": "Database status check not implemented",
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(response)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(jsonBytes); err != nil {
|
||||
if _, err := w.Write(helpers.MarshalJSONOrEmpty(response)); err != nil {
|
||||
h.logger.Error("failed to write response", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -105,13 +105,8 @@ func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request)
|
|||
// listJupyterServices lists all Jupyter services
|
||||
func (h *Handlers) listJupyterServices(w http.ResponseWriter, _ *http.Request) {
|
||||
services := h.jupyterServiceMgr.ListServices()
|
||||
jsonBytes, err := json.Marshal(services)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to marshal services", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(jsonBytes); err != nil {
|
||||
if _, err := w.Write(helpers.MarshalJSONOrEmpty(services)); err != nil {
|
||||
h.logger.Error("failed to write response", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -131,13 +126,8 @@ func (h *Handlers) startJupyterService(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(service)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to marshal service", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if _, err := w.Write(jsonBytes); err != nil {
|
||||
if _, err := w.Write(helpers.MarshalJSONOrEmpty(service)); err != nil {
|
||||
h.logger.Error("failed to write response", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
49
internal/api/helpers/db_helpers.go
Normal file
49
internal/api/helpers/db_helpers.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
// Package helpers provides shared utilities for WebSocket handlers.
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DBContext provides a standard database operation context.
// It creates a context with the specified timeout and returns the context and
// cancel function. Callers must invoke the cancel function to release resources.
func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) {
	return context.WithTimeout(context.Background(), timeout)
}

// DBContextShort returns a short-lived context for quick DB operations (3 seconds).
func DBContextShort() (context.Context, context.CancelFunc) {
	return DBContext(3 * time.Second)
}

// DBContextMedium returns a medium-lived context for standard DB operations (5 seconds).
func DBContextMedium() (context.Context, context.CancelFunc) {
	return DBContext(5 * time.Second)
}

// DBContextLong returns a long-lived context for complex DB operations (10 seconds).
func DBContextLong() (context.Context, context.CancelFunc) {
	return DBContext(10 * time.Second)
}
|
||||
|
||||
// StringSliceContains reports whether item occurs in slice.
func StringSliceContains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}

// StringSliceFilter returns the elements of slice for which predicate returns
// true. The result is always non-nil, even when nothing matches.
func StringSliceFilter(slice []string, predicate func(string) bool) []string {
	kept := make([]string, 0, len(slice))
	for _, candidate := range slice {
		if predicate(candidate) {
			kept = append(kept, candidate)
		}
	}
	return kept
}
|
||||
193
internal/api/helpers/experiment_setup.go
Normal file
193
internal/api/helpers/experiment_setup.go
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
// Package helpers provides shared utilities for WebSocket handlers.
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
||||
)
|
||||
|
||||
// ExperimentSetupResult contains the result of experiment setup operations
type ExperimentSetupResult struct {
	// CommitIDStr is the hex-encoded commit ID of the experiment.
	CommitIDStr string
	// Manifest is the generated content-integrity manifest, if any.
	Manifest *experiment.Manifest
	// Err holds the first error encountered during setup, if any.
	Err error
}
|
||||
|
||||
// RunExperimentSetup performs the common experiment setup operations:
// create experiment dir, write metadata, ensure minimal files, generate manifest.
// Returns the commitID string and any error that occurred.
//
// Each step is wrapped in telemetry.ExecWithMetrics so slow operations are
// surfaced with a per-step duration budget; on failure the error is both
// logged and returned wrapped with caller-facing context.
func RunExperimentSetup(
	logger *logging.Logger,
	expMgr *experiment.Manager,
	commitID []byte,
	jobName string,
	userName string,
) (string, error) {
	// The commit ID arrives as raw bytes; downstream APIs use the hex form.
	commitIDStr := fmt.Sprintf("%x", commitID)

	// Step 1: create the on-disk experiment directory.
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.create", 50*time.Millisecond,
		func() (string, error) { return "", expMgr.CreateExperiment(commitIDStr) },
	); err != nil {
		logger.Error("failed to create experiment directory", "error", err)
		return "", fmt.Errorf("failed to create experiment directory: %w", err)
	}

	// Step 2: persist experiment metadata (commit, job, user, timestamp).
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   jobName,
		User:      userName,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.write_metadata", 50*time.Millisecond,
		func() (string, error) { return "", expMgr.WriteMetadata(meta) },
	); err != nil {
		logger.Error("failed to save experiment metadata", "error", err)
		return "", fmt.Errorf("failed to save experiment metadata: %w", err)
	}

	// Step 3: make sure the minimal file set (train.py, requirements.txt) exists.
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.ensure_minimal_files", 50*time.Millisecond,
		func() (string, error) { return "", EnsureMinimalExperimentFiles(expMgr, commitIDStr) },
	); err != nil {
		logger.Error("failed to ensure minimal experiment files", "error", err)
		return "", fmt.Errorf("failed to initialize experiment files: %w", err)
	}

	// Step 4: generate and persist the content-integrity manifest.
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.generate_manifest", 100*time.Millisecond,
		func() (string, error) {
			manifest, err := expMgr.GenerateManifest(commitIDStr)
			if err != nil {
				return "", fmt.Errorf("failed to generate manifest: %w", err)
			}
			return "", expMgr.WriteManifest(manifest)
		},
	); err != nil {
		logger.Error("failed to generate/write manifest", "error", err)
		return "", fmt.Errorf("failed to generate content integrity manifest: %w", err)
	}

	return commitIDStr, nil
}
|
||||
|
||||
// RunExperimentSetupWithoutManifest performs experiment setup without manifest generation.
// Used for jobs with args/note where manifest generation is deferred.
// The steps mirror RunExperimentSetup minus the final manifest step.
func RunExperimentSetupWithoutManifest(
	logger *logging.Logger,
	expMgr *experiment.Manager,
	commitID []byte,
	jobName string,
	userName string,
) (string, error) {
	// Hex-encode the raw commit ID for downstream APIs.
	commitIDStr := fmt.Sprintf("%x", commitID)

	// Step 1: create the on-disk experiment directory.
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.create", 50*time.Millisecond,
		func() (string, error) { return "", expMgr.CreateExperiment(commitIDStr) },
	); err != nil {
		logger.Error("failed to create experiment directory", "error", err)
		return "", fmt.Errorf("failed to create experiment directory: %w", err)
	}

	// Step 2: persist experiment metadata (commit, job, user, timestamp).
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   jobName,
		User:      userName,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.write_metadata", 50*time.Millisecond,
		func() (string, error) { return "", expMgr.WriteMetadata(meta) },
	); err != nil {
		logger.Error("failed to save experiment metadata", "error", err)
		return "", fmt.Errorf("failed to save experiment metadata: %w", err)
	}

	// Step 3: make sure the minimal file set (train.py, requirements.txt) exists.
	if _, err := telemetry.ExecWithMetrics(
		logger, "experiment.ensure_minimal_files", 50*time.Millisecond,
		func() (string, error) { return "", EnsureMinimalExperimentFiles(expMgr, commitIDStr) },
	); err != nil {
		logger.Error("failed to ensure minimal experiment files", "error", err)
		return "", fmt.Errorf("failed to initialize experiment files: %w", err)
	}

	return commitIDStr, nil
}
|
||||
|
||||
// UpsertExperimentDBAsync upserts experiment data to the database asynchronously.
|
||||
// This is a fire-and-forget operation that runs in a goroutine.
|
||||
func UpsertExperimentDBAsync(
|
||||
logger *logging.Logger,
|
||||
db *storage.DB,
|
||||
commitIDStr string,
|
||||
jobName string,
|
||||
userName string,
|
||||
) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
exp := &storage.Experiment{ID: commitIDStr, Name: jobName, Status: "pending", UserID: userName}
|
||||
if _, err := telemetry.ExecWithMetrics(logger, "db.experiments.upsert", 50*time.Millisecond,
|
||||
func() (string, error) { return "", db.UpsertExperiment(ctx, exp) }); err != nil {
|
||||
logger.Error("failed to upsert experiment row", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// TaskEnqueueResult contains the result of task enqueueing
type TaskEnqueueResult struct {
	// TaskID is the identifier assigned to the enqueued task.
	TaskID string
	// Err holds the enqueue error, if any.
	Err error
}
|
||||
|
||||
// BuildTaskMetadata creates the standard task metadata map.
// Provenance entries with empty values are dropped.
func BuildTaskMetadata(commitIDStr, datasetID, paramsHash string, prov map[string]string) map[string]string {
	out := map[string]string{
		"commit_id":   commitIDStr,
		"dataset_id":  datasetID,
		"params_hash": paramsHash,
	}
	for key, value := range prov {
		if value == "" {
			continue
		}
		out[key] = value
	}
	return out
}
|
||||
|
||||
// BuildSnapshotTaskMetadata creates task metadata for snapshot jobs.
// Provenance entries with empty values are dropped.
func BuildSnapshotTaskMetadata(commitIDStr, snapshotSHA string, prov map[string]string) map[string]string {
	out := map[string]string{
		"commit_id":       commitIDStr,
		"snapshot_sha256": snapshotSHA,
	}
	for key, value := range prov {
		if value == "" {
			continue
		}
		out[key] = value
	}
	return out
}
|
||||
|
||||
// ApplyResourceRequest applies resource request to a task.
|
||||
func ApplyResourceRequest(task *queue.Task, resources *ResourceRequest) {
|
||||
if resources != nil {
|
||||
task.CPU = resources.CPU
|
||||
task.MemoryGB = resources.MemoryGB
|
||||
task.GPU = resources.GPU
|
||||
task.GPUMemory = resources.GPUMemory
|
||||
}
|
||||
}
|
||||
129
internal/api/helpers/hash_helpers.go
Normal file
129
internal/api/helpers/hash_helpers.go
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
// Package helpers provides shared utilities for WebSocket handlers.
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// ComputeDatasetID computes a dataset ID from dataset specs or dataset names.
|
||||
func ComputeDatasetID(datasetSpecs []queue.DatasetSpec, datasets []string) string {
|
||||
if len(datasetSpecs) > 0 {
|
||||
var checksums []string
|
||||
for _, ds := range datasetSpecs {
|
||||
if ds.Checksum != "" {
|
||||
checksums = append(checksums, ds.Checksum)
|
||||
} else if ds.Name != "" {
|
||||
checksums = append(checksums, ds.Name)
|
||||
}
|
||||
}
|
||||
if len(checksums) > 0 {
|
||||
h := sha256.New()
|
||||
for _, cs := range checksums {
|
||||
h.Write([]byte(cs))
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil))[:16]
|
||||
}
|
||||
}
|
||||
if len(datasets) > 0 {
|
||||
h := sha256.New()
|
||||
for _, ds := range datasets {
|
||||
h.Write([]byte(ds))
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil))[:16]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ComputeParamsHash computes a short, stable hash of the args string.
// Leading/trailing whitespace is ignored; an effectively empty args string
// yields "". The result is the first 16 hex characters of the SHA-256 digest.
func ComputeParamsHash(args string) string {
	// Trim once and reuse: the original computed TrimSpace twice.
	trimmed := strings.TrimSpace(args)
	if trimmed == "" {
		return ""
	}
	h := sha256.New()
	h.Write([]byte(trimmed))
	return hex.EncodeToString(h.Sum(nil))[:16]
}
|
||||
|
||||
// FileSHA256Hex computes the SHA256 hash of a file.
// The full digest is returned as lowercase hex.
func FileSHA256Hex(path string) (string, error) {
	file, err := os.Open(filepath.Clean(path))
	if err != nil {
		return "", err
	}
	// Best-effort close: the file is only read.
	defer func() { _ = file.Close() }()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, file); copyErr != nil {
		return "", copyErr
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}
|
||||
|
||||
// ExpectedProvenanceForCommit computes expected provenance metadata for a commit.
// It always includes the experiment manifest's overall SHA under
// "experiment_manifest_overall_sha" and, best-effort, the dependency manifest
// name and hash under "deps_manifest_name"/"deps_manifest_sha256".
func ExpectedProvenanceForCommit(
	expMgr *experiment.Manager,
	commitID string,
) (map[string]string, error) {
	out := map[string]string{}
	manifest, err := expMgr.ReadManifest(commitID)
	if err != nil {
		return nil, err
	}
	if manifest == nil || manifest.OverallSHA == "" {
		return nil, fmt.Errorf("missing manifest overall_sha")
	}
	out["experiment_manifest_overall_sha"] = manifest.OverallSHA

	// Dependency-manifest provenance is best-effort: errors locating or
	// hashing the file are deliberately swallowed rather than surfaced.
	filesPath := expMgr.GetFilesPath(commitID)
	depName, err := worker.SelectDependencyManifest(filesPath)
	if err == nil && strings.TrimSpace(depName) != "" {
		depPath := filepath.Join(filesPath, depName)
		// Inner err intentionally shadows the outer one within this branch.
		sha, err := FileSHA256Hex(depPath)
		if err == nil && strings.TrimSpace(sha) != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}
	return out, nil
}
|
||||
|
||||
// EnsureMinimalExperimentFiles ensures minimal experiment files exist.
// It creates the experiment files directory if needed and writes placeholder
// train.py and requirements.txt files when they are absent; existing files are
// never overwritten. Returns an error when expMgr is nil, commitID is blank,
// or any filesystem operation fails.
func EnsureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error {
	if expMgr == nil {
		return fmt.Errorf("missing experiment manager")
	}
	commitID = strings.TrimSpace(commitID)
	if commitID == "" {
		return fmt.Errorf("missing commit id")
	}
	filesPath := expMgr.GetFilesPath(commitID)
	if err := os.MkdirAll(filesPath, 0750); err != nil {
		return err
	}

	// Placeholder training script, created only when missing.
	trainPath := filepath.Join(filesPath, "train.py")
	if _, err := os.Stat(trainPath); os.IsNotExist(err) {
		if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil {
			return err
		}
	}

	// Placeholder dependency list, created only when missing.
	reqPath := filepath.Join(filesPath, "requirements.txt")
	if _, err := os.Stat(reqPath); os.IsNotExist(err) {
		if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil {
			return err
		}
	}

	return nil
}
|
||||
121
internal/api/helpers/payload_parser.go
Normal file
121
internal/api/helpers/payload_parser.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
// Package helpers provides shared utilities for WebSocket handlers.
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// PayloadParser provides helpers for parsing binary WebSocket payloads.
// It keeps a cursor (offset) into the payload and advances it as fields are
// consumed. It is not safe for concurrent use.
type PayloadParser struct {
	payload []byte
	offset  int
}

// NewPayloadParser creates a new payload parser starting after the API key hash.
func NewPayloadParser(payload []byte, apiKeyHashLen int) *PayloadParser {
	return &PayloadParser{
		payload: payload,
		offset:  apiKeyHashLen,
	}
}

// ParseByte parses a single byte and advances the offset.
func (p *PayloadParser) ParseByte() (byte, error) {
	if p.offset >= len(p.payload) {
		return 0, fmt.Errorf("payload too short at offset %d", p.offset)
	}
	b := p.payload[p.offset]
	p.offset++
	return b, nil
}

// ParseUint16 parses a 2-byte big-endian uint16 and advances the offset.
func (p *PayloadParser) ParseUint16() (uint16, error) {
	if p.offset+2 > len(p.payload) {
		return 0, fmt.Errorf("payload too short for uint16 at offset %d", p.offset)
	}
	v := binary.BigEndian.Uint16(p.payload[p.offset : p.offset+2])
	p.offset += 2
	return v, nil
}

// ParseLengthPrefixedString parses a length-prefixed string.
// Format: [length:1][string:var]
func (p *PayloadParser) ParseLengthPrefixedString() (string, error) {
	if p.offset >= len(p.payload) {
		return "", fmt.Errorf("payload too short for length at offset %d", p.offset)
	}
	// A single length byte is always in [0, 255]; no negative check is needed
	// (the previous `length < 0` branch was unreachable).
	length := int(p.payload[p.offset])
	p.offset++
	if p.offset+length > len(p.payload) {
		return "", fmt.Errorf("payload too short for string of length %d at offset %d", length, p.offset)
	}
	str := string(p.payload[p.offset : p.offset+length])
	p.offset += length
	return str, nil
}

// ParseUint16PrefixedString parses a string prefixed by a 2-byte length.
// Format: [length:2][string:var]
func (p *PayloadParser) ParseUint16PrefixedString() (string, error) {
	if p.offset+2 > len(p.payload) {
		return "", fmt.Errorf("payload too short for uint16 length at offset %d", p.offset)
	}
	// A uint16 length is always in [0, 65535]; no negative check is needed
	// (the previous `length < 0` branch was unreachable).
	length := int(binary.BigEndian.Uint16(p.payload[p.offset : p.offset+2]))
	p.offset += 2
	if p.offset+length > len(p.payload) {
		return "", fmt.Errorf("payload too short for string of length %d at offset %d", length, p.offset)
	}
	str := string(p.payload[p.offset : p.offset+length])
	p.offset += length
	return str, nil
}

// Payload returns the underlying payload bytes.
func (p *PayloadParser) Payload() []byte {
	return p.payload
}

// Offset returns the current offset into the payload.
func (p *PayloadParser) Offset() int {
	return p.offset
}

// HasRemaining returns true if there are remaining bytes.
func (p *PayloadParser) HasRemaining() bool {
	return p.offset < len(p.payload)
}

// Remaining returns the remaining bytes in the payload from current offset,
// or nil when the payload is exhausted. The returned slice aliases the
// parser's payload buffer.
func (p *PayloadParser) Remaining() []byte {
	if p.offset >= len(p.payload) {
		return nil
	}
	return p.payload[p.offset:]
}

// ParseBool parses a byte as a boolean (0 = false, non-zero = true).
func (p *PayloadParser) ParseBool() (bool, error) {
	b, err := p.ParseByte()
	if err != nil {
		return false, err
	}
	return b != 0, nil
}

// ParseFixedBytes parses a fixed-length byte slice.
// The returned slice aliases the parser's payload buffer; copy it if it must
// outlive the payload.
func (p *PayloadParser) ParseFixedBytes(length int) ([]byte, error) {
	if p.offset+length > len(p.payload) {
		return nil, fmt.Errorf("payload too short for %d bytes at offset %d", length, p.offset)
	}
	bytes := p.payload[p.offset : p.offset+length]
	p.offset += length
	return bytes, nil
}
|
||||
185
internal/api/helpers/response_helpers.go
Normal file
185
internal/api/helpers/response_helpers.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
// Package helpers provides shared utilities for WebSocket handlers.
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// ErrorCode represents WebSocket error codes
|
||||
type ErrorCode byte
|
||||
|
||||
// TaskErrorMapper maps task errors to error codes
|
||||
type TaskErrorMapper struct{}
|
||||
|
||||
// NewTaskErrorMapper creates a new task error mapper
|
||||
func NewTaskErrorMapper() *TaskErrorMapper {
|
||||
return &TaskErrorMapper{}
|
||||
}
|
||||
|
||||
// MapError maps a task error to an error code based on status and error message
|
||||
func (m *TaskErrorMapper) MapError(t *queue.Task, defaultCode ErrorCode) ErrorCode {
|
||||
if t == nil {
|
||||
return defaultCode
|
||||
}
|
||||
status := strings.ToLower(strings.TrimSpace(t.Status))
|
||||
errStr := strings.ToLower(strings.TrimSpace(t.Error))
|
||||
|
||||
if status == "cancelled" {
|
||||
return 0x24 // ErrorCodeJobCancelled
|
||||
}
|
||||
if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") {
|
||||
return 0x30 // ErrorCodeOutOfMemory
|
||||
}
|
||||
if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") {
|
||||
return 0x31 // ErrorCodeDiskFull
|
||||
}
|
||||
if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") {
|
||||
return 0x33 // ErrorCodeServiceUnavailable
|
||||
}
|
||||
if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") {
|
||||
return 0x14 // ErrorCodeTimeout
|
||||
}
|
||||
if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") {
|
||||
return 0x12 // ErrorCodeNetworkError
|
||||
}
|
||||
if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") {
|
||||
return 0x32 // ErrorCodeInvalidConfiguration
|
||||
}
|
||||
|
||||
// Default for worker-side execution failures
|
||||
if status == "failed" {
|
||||
return 0x23 // ErrorCodeJobExecutionFailed
|
||||
}
|
||||
return defaultCode
|
||||
}
|
||||
|
||||
// MapJupyterError maps Jupyter task errors to error codes
|
||||
func (m *TaskErrorMapper) MapJupyterError(t *queue.Task) ErrorCode {
|
||||
if t == nil {
|
||||
return 0x00 // ErrorCodeUnknownError
|
||||
}
|
||||
status := strings.ToLower(strings.TrimSpace(t.Status))
|
||||
errStr := strings.ToLower(strings.TrimSpace(t.Error))
|
||||
|
||||
if status == "cancelled" {
|
||||
return 0x24 // ErrorCodeJobCancelled
|
||||
}
|
||||
if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") {
|
||||
return 0x30 // ErrorCodeOutOfMemory
|
||||
}
|
||||
if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") {
|
||||
return 0x31 // ErrorCodeDiskFull
|
||||
}
|
||||
if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") {
|
||||
return 0x33 // ErrorCodeServiceUnavailable
|
||||
}
|
||||
if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") {
|
||||
return 0x14 // ErrorCodeTimeout
|
||||
}
|
||||
if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") {
|
||||
return 0x12 // ErrorCodeNetworkError
|
||||
}
|
||||
if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") {
|
||||
return 0x32 // ErrorCodeInvalidConfiguration
|
||||
}
|
||||
|
||||
// Default for worker-side execution failures
|
||||
if status == "failed" {
|
||||
return 0x23 // ErrorCodeJobExecutionFailed
|
||||
}
|
||||
return 0x00 // ErrorCodeUnknownError
|
||||
}
|
||||
|
||||
// ResourceRequest represents resource requirements
type ResourceRequest struct {
	CPU       int    // requested CPU count
	MemoryGB  int    // requested memory in gigabytes
	GPU       int    // requested GPU count
	GPUMemory string // GPU memory spec as a free-form string
}

// ParseResourceRequest parses an optional resource request from bytes.
// Format: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var]
// If payload is empty, returns nil.
func ParseResourceRequest(payload []byte) (*ResourceRequest, error) {
	if len(payload) == 0 {
		return nil, nil
	}
	if len(payload) < 4 {
		return nil, fmt.Errorf("resource payload too short")
	}
	cpu := int(payload[0])
	mem := int(payload[1])
	gpu := int(payload[2])
	// gpuMemLen comes from a single byte so it is always in [0, 255]; only
	// the upper bound needs checking (the old `< 0` branch was unreachable).
	gpuMemLen := int(payload[3])
	if len(payload) < 4+gpuMemLen {
		return nil, fmt.Errorf("invalid gpu memory length")
	}
	gpuMem := ""
	if gpuMemLen > 0 {
		gpuMem = string(payload[4 : 4+gpuMemLen])
	}
	return &ResourceRequest{CPU: cpu, MemoryGB: mem, GPU: gpu, GPUMemory: gpuMem}, nil
}
|
||||
|
||||
// JSONResponseBuilder helps build JSON data responses
type JSONResponseBuilder struct {
	data interface{}
}

// NewJSONResponseBuilder creates a new JSON response builder
func NewJSONResponseBuilder(data interface{}) *JSONResponseBuilder {
	return &JSONResponseBuilder{data: data}
}

// Build marshals the data to JSON
func (b *JSONResponseBuilder) Build() ([]byte, error) {
	return json.Marshal(b.data)
}

// BuildOrEmpty marshals the data to JSON or returns empty array on error
func (b *JSONResponseBuilder) BuildOrEmpty() []byte {
	encoded, err := b.Build()
	if err != nil {
		return []byte("[]")
	}
	return encoded
}
|
||||
|
||||
// StringPtr returns a pointer to a string
func StringPtr(s string) *string {
	out := s
	return &out
}

// IntPtr returns a pointer to an int
func IntPtr(i int) *int {
	out := i
	return &out
}

// MarshalJSONOrEmpty marshals data to JSON or returns empty array on error
func MarshalJSONOrEmpty(data interface{}) []byte {
	if encoded, err := json.Marshal(data); err == nil {
		return encoded
	}
	return []byte("[]")
}

// MarshalJSONBytes marshals data to JSON bytes with error handling
func MarshalJSONBytes(data interface{}) ([]byte, error) {
	return json.Marshal(data)
}
|
||||
|
||||
// IsEmptyJSON checks if JSON data is empty or "null"
func IsEmptyJSON(data []byte) bool {
	if len(data) == 0 {
		return true
	}
	// "null", "[]", "{}" and all-whitespace all count as empty.
	switch strings.TrimSpace(string(data)) {
	case "", "null", "[]", "{}":
		return true
	default:
		return false
	}
}
|
||||
237
internal/api/helpers/validation_helpers.go
Normal file
237
internal/api/helpers/validation_helpers.go
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
// Package helpers provides validation utilities for WebSocket handlers.
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// ValidateCommitIDFormat validates the commit ID format (40 hex chars).
// On failure it returns false plus a short machine-friendly error message.
func ValidateCommitIDFormat(commitID string) (ok bool, errMsg string) {
	if len(commitID) != 40 {
		return false, "invalid commit_id length"
	}
	_, decodeErr := hex.DecodeString(commitID)
	if decodeErr != nil {
		return false, "invalid commit_id hex"
	}
	return true, ""
}
|
||||
|
||||
// ValidateExperimentManifest validates the experiment manifest integrity
|
||||
func ValidateExperimentManifest(expMgr *experiment.Manager, commitID string) (ok bool, details string) {
|
||||
if err := expMgr.ValidateManifest(commitID); err != nil {
|
||||
return false, err.Error()
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// ValidateDepsManifest validates the dependency manifest presence and hash.
// On success the returned check's Actual field holds "name:sha256"; on
// failure errMsgs carries a summary suitable for a ValidateReport.
func ValidateDepsManifest(
	expMgr *experiment.Manager,
	commitID string,
) (depName string, check ValidateCheck, errMsgs []string) {
	filesPath := expMgr.GetFilesPath(commitID)
	depName, depErr := worker.SelectDependencyManifest(filesPath)
	if depErr != nil {
		// No dependency manifest could be selected for this experiment.
		return "", ValidateCheck{OK: false, Details: depErr.Error()}, []string{"deps manifest missing"}
	}

	sha, err := FileSHA256Hex(filepath.Join(filesPath, depName))
	if err != nil {
		// Manifest exists but could not be hashed (e.g. unreadable file).
		return depName, ValidateCheck{OK: false, Details: err.Error()}, []string{"deps manifest hash failed"}
	}
	return depName, ValidateCheck{OK: true, Actual: depName + ":" + sha}, nil
}
|
||||
|
||||
// ValidateCheck represents a single validation check result.
type ValidateCheck struct {
	OK       bool   `json:"ok"`
	Expected string `json:"expected,omitempty"`
	Actual   string `json:"actual,omitempty"`
	Details  string `json:"details,omitempty"`
}

// ValidateReport represents an aggregated validation report.
type ValidateReport struct {
	OK       bool                     `json:"ok"`
	CommitID string                   `json:"commit_id,omitempty"`
	TaskID   string                   `json:"task_id,omitempty"`
	Checks   map[string]ValidateCheck `json:"checks"`
	Errors   []string                 `json:"errors,omitempty"`
	Warnings []string                 `json:"warnings,omitempty"`
	TS       string                   `json:"ts"`
}

// NewValidateReport creates a new validation report that starts out OK with
// an empty (non-nil) check map.
func NewValidateReport() ValidateReport {
	report := ValidateReport{OK: true}
	report.Checks = make(map[string]ValidateCheck)
	return report
}
|
||||
|
||||
// ShouldRequireRunManifest returns true if run manifest should be required for the given status
|
||||
func ShouldRequireRunManifest(task *queue.Task) bool {
|
||||
if task == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(strings.TrimSpace(task.Status))
|
||||
switch s {
|
||||
case "running", "completed", "failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ExpectedRunManifestBucketForStatus returns the expected bucket for a given
// status and whether the status maps to any bucket at all. Matching is
// case-insensitive and ignores surrounding whitespace.
func ExpectedRunManifestBucketForStatus(status string) (string, bool) {
	buckets := map[string]string{
		"queued":    "pending",
		"pending":   "pending",
		"running":   "running",
		"completed": "finished",
		"finished":  "finished",
		"failed":    "failed",
	}
	bucket, ok := buckets[strings.ToLower(strings.TrimSpace(status))]
	return bucket, ok
}
|
||||
|
||||
// FindRunManifestDir finds the run manifest directory for a job.
// It probes the running, pending, finished, and failed buckets (in that
// priority order) and returns the first job directory that both exists and
// contains a run manifest file, along with the bucket it was found in.
// Returns found=false when basePath/jobName is blank or no bucket matches.
func FindRunManifestDir(basePath string, jobName string) (dir string, bucket string, found bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}
	jobPaths := config.NewJobPaths(basePath)
	typedRoots := []struct {
		bucket string
		root   string
	}{
		{bucket: "running", root: jobPaths.RunningPath()},
		{bucket: "pending", root: jobPaths.PendingPath()},
		{bucket: "finished", root: jobPaths.FinishedPath()},
		{bucket: "failed", root: jobPaths.FailedPath()},
	}
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			// A bucket only counts if the manifest file itself is present.
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				return dir, item.bucket, true
			}
		}
	}
	return "", "", false
}
|
||||
|
||||
// ValidateRunManifestLifecycle validates the run manifest lifecycle fields
|
||||
func ValidateRunManifestLifecycle(rm *manifest.RunManifest, status string) (ok bool, details string) {
|
||||
statusLower := strings.ToLower(strings.TrimSpace(status))
|
||||
|
||||
switch statusLower {
|
||||
case "running":
|
||||
if rm.StartedAt.IsZero() {
|
||||
return false, "missing started_at for running task"
|
||||
}
|
||||
if !rm.EndedAt.IsZero() {
|
||||
return false, "ended_at must be empty for running task"
|
||||
}
|
||||
if rm.ExitCode != nil {
|
||||
return false, "exit_code must be empty for running task"
|
||||
}
|
||||
case "completed", "failed":
|
||||
if rm.StartedAt.IsZero() {
|
||||
return false, "missing started_at for completed/failed task"
|
||||
}
|
||||
if rm.EndedAt.IsZero() {
|
||||
return false, "missing ended_at for completed/failed task"
|
||||
}
|
||||
if rm.ExitCode == nil {
|
||||
return false, "missing exit_code for completed/failed task"
|
||||
}
|
||||
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
|
||||
return false, "ended_at is before started_at"
|
||||
}
|
||||
case "queued", "pending":
|
||||
// queued/pending tasks may not have started yet.
|
||||
if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
|
||||
return false, "queued/pending task should not have ended_at/exit_code"
|
||||
}
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// ValidateTaskIDMatch validates the task ID in the run manifest matches the expected task
|
||||
func ValidateTaskIDMatch(rm *manifest.RunManifest, expectedTaskID string) ValidateCheck {
|
||||
if strings.TrimSpace(rm.TaskID) == "" {
|
||||
return ValidateCheck{OK: false, Expected: expectedTaskID}
|
||||
}
|
||||
if rm.TaskID != expectedTaskID {
|
||||
return ValidateCheck{OK: false, Expected: expectedTaskID, Actual: rm.TaskID}
|
||||
}
|
||||
return ValidateCheck{OK: true, Expected: expectedTaskID, Actual: rm.TaskID}
|
||||
}
|
||||
|
||||
// ValidateCommitIDMatch validates the commit ID in the run manifest matches the expected commit
|
||||
func ValidateCommitIDMatch(rmCommitID, expectedCommitID string) ValidateCheck {
|
||||
want := strings.TrimSpace(expectedCommitID)
|
||||
got := strings.TrimSpace(rmCommitID)
|
||||
if want != "" && got != "" && want != got {
|
||||
return ValidateCheck{OK: false, Expected: want, Actual: got}
|
||||
}
|
||||
if want != "" {
|
||||
return ValidateCheck{OK: true, Expected: want, Actual: got}
|
||||
}
|
||||
return ValidateCheck{OK: true}
|
||||
}
|
||||
|
||||
// ValidateDepsProvenance validates the dependency manifest provenance
|
||||
func ValidateDepsProvenance(wantName, wantSHA, gotName, gotSHA string) ValidateCheck {
|
||||
if wantName == "" || wantSHA == "" || gotName == "" || gotSHA == "" {
|
||||
return ValidateCheck{OK: true}
|
||||
}
|
||||
expected := wantName + ":" + wantSHA
|
||||
actual := gotName + ":" + gotSHA
|
||||
if wantName != gotName || wantSHA != gotSHA {
|
||||
return ValidateCheck{OK: false, Expected: expected, Actual: actual}
|
||||
}
|
||||
return ValidateCheck{OK: true, Expected: expected, Actual: actual}
|
||||
}
|
||||
|
||||
// ValidateSnapshotID validates the snapshot ID in the run manifest
|
||||
func ValidateSnapshotID(wantID, gotID string) ValidateCheck {
|
||||
if wantID == "" || gotID == "" {
|
||||
return ValidateCheck{OK: true, Expected: wantID, Actual: gotID}
|
||||
}
|
||||
if wantID != gotID {
|
||||
return ValidateCheck{OK: false, Expected: wantID, Actual: gotID}
|
||||
}
|
||||
return ValidateCheck{OK: true, Expected: wantID, Actual: gotID}
|
||||
}
|
||||
|
||||
// ValidateSnapshotSHA validates the snapshot SHA in the run manifest
|
||||
func ValidateSnapshotSHA(wantSHA, gotSHA string) ValidateCheck {
|
||||
if wantSHA == "" || gotSHA == "" {
|
||||
return ValidateCheck{OK: true, Expected: wantSHA, Actual: gotSHA}
|
||||
}
|
||||
if wantSHA != gotSHA {
|
||||
return ValidateCheck{OK: false, Expected: wantSHA, Actual: gotSHA}
|
||||
}
|
||||
return ValidateCheck{OK: true, Expected: wantSHA, Actual: gotSHA}
|
||||
}
|
||||
|
||||
// ContainerStat is a function variable for stat operations (for mocking in
// tests). It defaults to os.Stat, whose signature it matches exactly.
var ContainerStat func(string) (os.FileInfo, error) = os.Stat
|
||||
|
|
@ -1,40 +1,30 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16]
|
||||
if len(payload) < 16 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "")
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
datasets, err := h.db.ListDatasets(ctx, 0)
|
||||
|
|
@ -55,26 +45,18 @@ func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) erro
|
|||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "")
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetRegister)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsCreate, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := 16
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
|
||||
|
|
@ -90,7 +72,6 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte)
|
|||
}
|
||||
urlStr := string(payload[offset : offset+urlLen])
|
||||
|
||||
// Minimal validation (server-side authoritative): name non-empty and url parseable.
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
|
||||
}
|
||||
|
|
@ -98,7 +79,7 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte)
|
|||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil {
|
||||
|
|
@ -108,26 +89,18 @@ func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte)
|
|||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := 16
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen {
|
||||
|
|
@ -135,7 +108,7 @@ func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) erro
|
|||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
ds, err := h.db.GetDataset(ctx, name)
|
||||
|
|
@ -159,26 +132,18 @@ func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) erro
|
|||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][term_len:1][term:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetSearch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := 16
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
termLen := int(payload[offset])
|
||||
offset++
|
||||
if termLen < 0 || len(payload) < offset+termLen {
|
||||
|
|
@ -187,7 +152,7 @@ func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) er
|
|||
term := string(payload[offset : offset+termLen])
|
||||
term = strings.TrimSpace(term)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
datasets, err := h.db.SearchDatasets(ctx, term, 0)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ const (
|
|||
OpcodeListJupyter = 0x0F
|
||||
OpcodeListJupyterPackages = 0x1E
|
||||
OpcodeValidateRequest = 0x16
|
||||
|
||||
// Logs opcodes
|
||||
OpcodeGetLogs = 0x20
|
||||
OpcodeStreamLogs = 0x21
|
||||
)
|
||||
|
||||
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
||||
|
|
@ -288,7 +292,88 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
|||
return h.handleListJupyterPackages(conn, payload)
|
||||
case OpcodeValidateRequest:
|
||||
return h.handleValidateRequest(conn, payload)
|
||||
case OpcodeGetLogs:
|
||||
return h.handleGetLogs(conn, payload)
|
||||
case OpcodeStreamLogs:
|
||||
return h.handleStreamLogs(conn, payload)
|
||||
default:
|
||||
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthHandler is a handler function that receives an authenticated user
// along with the connection and raw payload, for handlers that run after
// API-key validation has succeeded.
type AuthHandler func(conn *websocket.Conn, payload []byte, user *auth.User) error
|
||||
|
||||
// authenticate validates the API key from raw payload and returns the user
|
||||
func (h *WSHandler) authenticate(conn *websocket.Conn, payload []byte, minLen int) (*auth.User, error) {
|
||||
if len(payload) < minLen {
|
||||
return nil, h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil {
|
||||
user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// authenticateWithHash validates a pre-extracted API key hash
|
||||
func (h *WSHandler) authenticateWithHash(conn *websocket.Conn, apiKeyHash []byte) (*auth.User, error) {
|
||||
if h.authConfig != nil {
|
||||
user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// requirePermission checks if the user has the required permission
|
||||
func (h *WSHandler) requirePermission(
|
||||
user *auth.User,
|
||||
permission string,
|
||||
conn *websocket.Conn,
|
||||
) error {
|
||||
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission(permission) {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", permission)
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodePermissionDenied,
|
||||
fmt.Sprintf("Insufficient permissions: %s", permission),
|
||||
"",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// requireDB checks if the database is configured
|
||||
func (h *WSHandler) requireDB(conn *websocket.Conn) error {
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -9,44 +9,16 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// JupyterTaskErrorCode returns the error code for a Jupyter task.
|
||||
// This is kept for backward compatibility and delegates to the helper.
|
||||
func JupyterTaskErrorCode(t *queue.Task) byte {
|
||||
if t == nil {
|
||||
return ErrorCodeUnknownError
|
||||
}
|
||||
status := strings.ToLower(strings.TrimSpace(t.Status))
|
||||
errStr := strings.ToLower(strings.TrimSpace(t.Error))
|
||||
|
||||
if status == "cancelled" {
|
||||
return ErrorCodeJobCancelled
|
||||
}
|
||||
if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") {
|
||||
return ErrorCodeOutOfMemory
|
||||
}
|
||||
if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") {
|
||||
return ErrorCodeDiskFull
|
||||
}
|
||||
if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") {
|
||||
return ErrorCodeServiceUnavailable
|
||||
}
|
||||
if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") {
|
||||
return ErrorCodeTimeout
|
||||
}
|
||||
if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") {
|
||||
return ErrorCodeNetworkError
|
||||
}
|
||||
if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") {
|
||||
return ErrorCodeInvalidConfiguration
|
||||
}
|
||||
|
||||
// Default for worker-side execution failures.
|
||||
if status == "failed" {
|
||||
return ErrorCodeJobExecutionFailed
|
||||
}
|
||||
return ErrorCodeUnknownError
|
||||
mapper := helpers.NewTaskErrorMapper()
|
||||
return byte(mapper.MapJupyterError(t))
|
||||
}
|
||||
|
||||
type jupyterTaskOutput struct {
|
||||
|
|
@ -58,37 +30,15 @@ type jupyterTaskOutput struct {
|
|||
}
|
||||
|
||||
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
user, err := h.authenticate(conn, payload, 18)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:manage") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
if err := h.requirePermission(user, PermJupyterManage, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := 16
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if len(payload) < offset+nameLen {
|
||||
|
|
@ -183,13 +133,11 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by
|
|||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if len(payload) < offset+nameLen {
|
||||
p := helpers.NewPayloadParser(payload, 16)
|
||||
name, err := p.ParseLengthPrefixedString()
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "")
|
||||
|
|
@ -217,7 +165,7 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by
|
|||
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out == "" {
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
|
||||
}
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
||||
|
|
@ -228,7 +176,7 @@ func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []by
|
|||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload))
|
||||
}
|
||||
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
|
||||
}
|
||||
|
||||
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
|
||||
|
|
@ -427,16 +375,12 @@ func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) erro
|
|||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
idLen := int(payload[offset])
|
||||
offset++
|
||||
|
||||
if len(payload) < offset+idLen {
|
||||
p := helpers.NewPayloadParser(payload, 16)
|
||||
serviceID, err := p.ParseLengthPrefixedString()
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
||||
}
|
||||
|
||||
serviceID := string(payload[offset : offset+idLen])
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionStop,
|
||||
jupyterServiceIDKey: strings.TrimSpace(serviceID),
|
||||
|
|
@ -466,19 +410,17 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er
|
|||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
offset := 16
|
||||
idLen := int(payload[offset])
|
||||
offset++
|
||||
if len(payload) < offset+idLen {
|
||||
p := helpers.NewPayloadParser(payload, 16)
|
||||
serviceID, err := p.ParseLengthPrefixedString()
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
||||
}
|
||||
serviceID := string(payload[offset : offset+idLen])
|
||||
offset += idLen
|
||||
|
||||
// Optional: purge flag (1 byte). Default false for trash-first behavior.
|
||||
purge := false
|
||||
if len(payload) > offset {
|
||||
purge = payload[offset] == 0x01
|
||||
if p.HasRemaining() {
|
||||
purgeByte, _ := p.ParseByte()
|
||||
purge = purgeByte == 0x01
|
||||
}
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
|
|
@ -528,34 +470,12 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er
|
|||
}
|
||||
|
||||
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16]
|
||||
if len(payload) < 16 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetList)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:read") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
if err := h.requirePermission(user, PermJupyterRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
|
|
@ -578,18 +498,15 @@ func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) erro
|
|||
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out == "" {
|
||||
empty, _ := json.Marshal([]any{})
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", empty))
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
|
||||
}
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
||||
// Always return an array payload (even if empty) so clients can render a stable table.
|
||||
payload := payloadOut.Services
|
||||
if len(payload) == 0 {
|
||||
payload = []byte("[]")
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload))
|
||||
}
|
||||
// Fallback: return empty array on unexpected output.
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
|
||||
}
|
||||
|
|
|
|||
43
internal/api/ws_protocol.go
Normal file
43
internal/api/ws_protocol.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package api

// Protocol constants for WebSocket binary protocol.
// All handlers use [api_key_hash:16] as the first 16 bytes of every payload.
const (
	// Auth header size (present in all payloads)
	ProtocolAPIKeyHashLen = 16

	// Commit ID size (20 bytes hex = 40 char string)
	ProtocolCommitIDLen = 20

	// Minimum payload sizes for each operation.
	//
	// NOTE(review): several values below are smaller than the byte sizes the
	// field layouts in their comments add up to (e.g. ProtocolMinAnnotateRun
	// adds 20 for fields totalling 4, ProtocolMinLogMetric adds 25 for
	// fields totalling 33) — confirm whether the minimums are deliberately
	// loose/tight or the comments and values have drifted apart.
	ProtocolMinStatusRequest        = ProtocolAPIKeyHashLen      // [api_key_hash:16]
	ProtocolMinCancelJob            = ProtocolAPIKeyHashLen + 1  // [api_key_hash:16][job_name_len:1]
	ProtocolMinPrune                = ProtocolAPIKeyHashLen + 5  // [api_key_hash:16][prune_type:1][value:4]
	ProtocolMinDatasetList          = ProtocolAPIKeyHashLen      // [api_key_hash:16]
	ProtocolMinDatasetRegister      = ProtocolAPIKeyHashLen + 3  // [api_key_hash:16][name_len:1][url_len:2]
	ProtocolMinDatasetInfo          = ProtocolAPIKeyHashLen + 1  // [api_key_hash:16][name_len:1]
	ProtocolMinDatasetSearch        = ProtocolAPIKeyHashLen + 1  // [api_key_hash:16][term_len:1]
	ProtocolMinLogMetric            = ProtocolAPIKeyHashLen + 25 // [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1]
	ProtocolMinGetExperiment        = ProtocolAPIKeyHashLen + 20 // [api_key_hash:16][commit_id:20]
	ProtocolMinQueueJob             = ProtocolAPIKeyHashLen + 21 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1]
	ProtocolMinQueueJobWithSnapshot = ProtocolAPIKeyHashLen + 23 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][snap_id_len:1]
	ProtocolMinQueueJobWithTracking = ProtocolAPIKeyHashLen + 23 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][tracking_len:2]
	ProtocolMinQueueJobWithNote     = ProtocolAPIKeyHashLen + 26 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][args_len:2][note_len:2][force:1]
	ProtocolMinQueueJobWithArgs     = ProtocolAPIKeyHashLen + 24 // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][args_len:2][force:1]
	ProtocolMinAnnotateRun          = ProtocolAPIKeyHashLen + 20 // [api_key_hash:16][job_name_len:1][author_len:1][note_len:2]
	ProtocolMinSetRunNarrative      = ProtocolAPIKeyHashLen + 20 // [api_key_hash:16][job_name_len:1][patch_len:2]

	// Logs and debug minimum payload sizes
	ProtocolMinGetLogs     = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][target_id_len:1]
	ProtocolMinStreamLogs  = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][target_id_len:1]
	ProtocolMinAttachDebug = ProtocolAPIKeyHashLen + 1 // [api_key_hash:16][target_id_len:1]

	// Permission constants used by requirePermission checks.
	PermJobsCreate    = "jobs:create"
	PermJobsRead      = "jobs:read"
	PermJobsUpdate    = "jobs:update"
	PermDatasetsRead  = "datasets:read"
	PermDatasetsCreate = "datasets:create"
	PermJupyterManage = "jupyter:manage"
	PermJupyterRead   = "jupyter:read"
)
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
|
@ -11,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
|
|
@ -228,20 +228,17 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
|
|||
}
|
||||
|
||||
// Validate commit id format
|
||||
if len(commitID) != 40 {
|
||||
if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "invalid commit_id length")
|
||||
} else if _, err := hex.DecodeString(commitID); err != nil {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "invalid commit_id hex")
|
||||
r.Errors = append(r.Errors, errMsg)
|
||||
}
|
||||
|
||||
// Experiment manifest integrity
|
||||
// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
|
||||
if r.OK {
|
||||
if err := h.expManager.ValidateManifest(commitID); err != nil {
|
||||
if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok {
|
||||
r.OK = false
|
||||
r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()}
|
||||
r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details}
|
||||
r.Errors = append(r.Errors, "experiment manifest validation failed")
|
||||
} else {
|
||||
r.Checks["experiment_manifest"] = validateCheck{OK: true}
|
||||
|
|
@ -251,29 +248,13 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
|
|||
// Deps manifest presence + hash
|
||||
// TODO(context): Allow client to declare which dependency manifest is authoritative.
|
||||
filesPath := h.expManager.GetFilesPath(commitID)
|
||||
depName, depErr := worker.SelectDependencyManifest(filesPath)
|
||||
if depErr != nil {
|
||||
depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID)
|
||||
if depErrs != nil {
|
||||
r.OK = false
|
||||
r.Checks["deps_manifest"] = validateCheck{
|
||||
OK: false,
|
||||
Details: depErr.Error(),
|
||||
}
|
||||
r.Errors = append(r.Errors, "deps manifest missing")
|
||||
r.Checks["deps_manifest"] = validateCheck(depCheck)
|
||||
r.Errors = append(r.Errors, depErrs...)
|
||||
} else {
|
||||
sha, err := fileSHA256Hex(filepath.Join(filesPath, depName))
|
||||
if err != nil {
|
||||
r.OK = false
|
||||
r.Checks["deps_manifest"] = validateCheck{
|
||||
OK: false,
|
||||
Details: err.Error(),
|
||||
}
|
||||
r.Errors = append(r.Errors, "deps manifest hash failed")
|
||||
} else {
|
||||
r.Checks["deps_manifest"] = validateCheck{
|
||||
OK: true,
|
||||
Actual: depName + ":" + sha,
|
||||
}
|
||||
}
|
||||
r.Checks["deps_manifest"] = validateCheck(depCheck)
|
||||
}
|
||||
|
||||
// Compare against expected task metadata if available.
|
||||
|
|
@ -339,158 +320,58 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
|
|||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(rm.TaskID) == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest missing task_id")
|
||||
r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID}
|
||||
} else if rm.TaskID != task.ID {
|
||||
// Validate task ID using helper
|
||||
taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID)
|
||||
r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck)
|
||||
if !taskIDCheck.OK {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest task_id mismatch")
|
||||
r.Checks["run_manifest_task_id"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: task.ID,
|
||||
Actual: rm.TaskID,
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_task_id"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: task.ID,
|
||||
Actual: rm.TaskID,
|
||||
}
|
||||
}
|
||||
|
||||
commitWant := strings.TrimSpace(task.Metadata["commit_id"])
|
||||
commitGot := strings.TrimSpace(rm.CommitID)
|
||||
if commitWant != "" && commitGot != "" && commitWant != commitGot {
|
||||
// Validate commit ID using helper
|
||||
commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"])
|
||||
r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck)
|
||||
if !commitCheck.OK {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest commit_id mismatch")
|
||||
r.Checks["run_manifest_commit_id"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: commitWant,
|
||||
Actual: commitGot,
|
||||
}
|
||||
} else if commitWant != "" {
|
||||
r.Checks["run_manifest_commit_id"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: commitWant,
|
||||
Actual: commitGot,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate deps provenance using helper
|
||||
depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
|
||||
depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
|
||||
depGotName := strings.TrimSpace(rm.DepsManifestName)
|
||||
depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
|
||||
if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" {
|
||||
expectedDep := depWantName + ":" + depWantSHA
|
||||
actualDep := depGotName + ":" + depGotSHA
|
||||
if depWantName != depGotName || depWantSHA != depGotSHA {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
|
||||
r.Checks["run_manifest_deps"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: expectedDep,
|
||||
Actual: actualDep,
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_deps"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: expectedDep,
|
||||
Actual: actualDep,
|
||||
}
|
||||
}
|
||||
depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA)
|
||||
r.Checks["run_manifest_deps"] = validateCheck(depsCheck)
|
||||
if !depsCheck.OK {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
|
||||
}
|
||||
|
||||
// Validate snapshot using helpers
|
||||
if strings.TrimSpace(task.SnapshotID) != "" {
|
||||
snapWantID := strings.TrimSpace(task.SnapshotID)
|
||||
snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
|
||||
snapGotID := strings.TrimSpace(rm.SnapshotID)
|
||||
snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)
|
||||
if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID {
|
||||
|
||||
snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID)
|
||||
r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck)
|
||||
if !snapIDCheck.OK {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
|
||||
r.Checks["run_manifest_snapshot_id"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: snapWantID,
|
||||
Actual: snapGotID,
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_snapshot_id"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: snapWantID,
|
||||
Actual: snapGotID,
|
||||
}
|
||||
}
|
||||
if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA {
|
||||
|
||||
snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA)
|
||||
r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck)
|
||||
if !snapSHACheck.OK {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
|
||||
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: snapWantSHA,
|
||||
Actual: snapGotSHA,
|
||||
}
|
||||
} else if snapWantSHA != "" {
|
||||
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: snapWantSHA,
|
||||
Actual: snapGotSHA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
statusLower := strings.ToLower(strings.TrimSpace(task.Status))
|
||||
lifecycleOK := true
|
||||
details := ""
|
||||
|
||||
switch statusLower {
|
||||
case "running":
|
||||
if rm.StartedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
details = "missing started_at for running task"
|
||||
}
|
||||
if !rm.EndedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "ended_at must be empty for running task"
|
||||
}
|
||||
}
|
||||
if rm.ExitCode != nil {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "exit_code must be empty for running task"
|
||||
}
|
||||
}
|
||||
case "completed", "failed":
|
||||
if rm.StartedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
details = "missing started_at for completed/failed task"
|
||||
}
|
||||
if rm.EndedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "missing ended_at for completed/failed task"
|
||||
}
|
||||
}
|
||||
if rm.ExitCode == nil {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "missing exit_code for completed/failed task"
|
||||
}
|
||||
}
|
||||
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "ended_at is before started_at"
|
||||
}
|
||||
}
|
||||
case "queued", "pending":
|
||||
// queued/pending tasks may not have started yet.
|
||||
if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
|
||||
lifecycleOK = false
|
||||
details = "queued/pending task should not have ended_at/exit_code"
|
||||
}
|
||||
}
|
||||
|
||||
// Validate lifecycle using helper
|
||||
lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status)
|
||||
if lifecycleOK {
|
||||
r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
|
||||
} else {
|
||||
|
|
@ -535,7 +416,7 @@ func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte)
|
|||
r.Errors = append(r.Errors, "missing expected deps manifest provenance")
|
||||
r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
|
||||
} else if depName != "" {
|
||||
sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName))
|
||||
sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName))
|
||||
ok := (wantDep == depName && wantDepSha == sha)
|
||||
if !ok {
|
||||
r.OK = false
|
||||
|
|
|
|||
Loading…
Reference in a new issue