fetch_ml/internal/api/ws_jobs.go
Jeremie Fraeys 2e701340e5
feat(core): API, worker, queue, and manifest improvements
- Add protocol buffer optimizations (internal/api/protocol.go)
- Add filesystem queue backend (internal/queue/filesystem_queue.go)
- Add run manifest support (internal/manifest/run_manifest.go)
- Worker and jupyter task refinements
- Exported test wrappers for benchmarking
2026-02-12 12:05:17 -05:00

1868 lines
52 KiB
Go

package api
import (
"context"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/telemetry"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// handleAnnotateRun appends a timestamped annotation (author + note) to the
// run manifest of the named job.
//
// Binary protocol:
//
//	[api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
//
// All failures are reported to the client as error packets; the returned
// error reflects only a failure to write to the websocket.
func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
	// Minimum payload: hash + two 1-byte length fields + the 2-byte note length.
	if len(payload) < 16+1+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
	}
	apiKeyHash := payload[:16]
	offset := 16
	jobNameLen := int(payload[offset])
	offset += 1
	// Bounds check also reserves the author length byte and the 2-byte note
	// length that must follow the job name.
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen
	authorLen := int(payload[offset])
	offset += 1
	// authorLen may legitimately be zero: the author then defaults to the
	// authenticated user further below. (authorLen < 0 can never trigger since
	// it is derived from a byte; kept byte-identical here.)
	if authorLen < 0 || len(payload) < offset+authorLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
	}
	author := string(payload[offset : offset+authorLen])
	offset += authorLen
	noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// A note is mandatory — an annotation without text is rejected.
	if noteLen <= 0 || len(payload) < offset+noteLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
	}
	note := string(payload[offset : offset+noteLen])
	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled: every caller acts as the all-powerful default user.
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Permission model: if auth is enabled, require jobs:update to mutate shared run artifacts.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to annotate runs", "")
	}
	// Reject job names that could escape the job directories (path traversal etc.).
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}
	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}
	jobPaths := config.NewJobPaths(base)
	// The run may be in any lifecycle state; probe each state root in order.
	typedRoots := []struct {
		root string
	}{
		{root: jobPaths.RunningPath()},
		{root: jobPaths.PendingPath()},
		{root: jobPaths.FinishedPath()},
		{root: jobPaths.FailedPath()},
	}
	var manifestDir string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			// Only accept a directory that actually contains a run manifest.
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				manifestDir = dir
				break
			}
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
	}
	rm, err := manifest.LoadFromDir(manifestDir)
	if err != nil || rm == nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
	}
	// Default author to authenticated user if empty.
	if strings.TrimSpace(author) == "" {
		author = user.Name
	}
	rm.AddAnnotation(time.Now().UTC(), author, note)
	if err := rm.WriteToDir(manifestDir); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added"))
}
// handleSetRunNarrative applies a JSON narrative patch to the run manifest of
// the named job.
//
// Binary protocol:
//
//	[api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var]
//
// All failures are reported to the client as error packets; the returned
// error reflects only a failure to write to the websocket.
func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var]
	// Minimum payload: hash + 1-byte name length + 2-byte patch length.
	if len(payload) < 16+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
	}
	apiKeyHash := payload[:16]
	offset := 16
	jobNameLen := int(payload[offset])
	offset += 1
	// Bounds check also reserves the 2-byte patch length behind the job name.
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen
	patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// An empty patch is rejected — there would be nothing to apply.
	if patchLen <= 0 || len(payload) < offset+patchLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "")
	}
	patchJSON := payload[offset : offset+patchLen]
	// Resolve the caller; with auth disabled everyone is the default admin.
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Mutating shared run artifacts requires jobs:update when auth is enabled.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to update run narrative", "")
	}
	// Reject job names that could escape the job directories.
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}
	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}
	jobPaths := config.NewJobPaths(base)
	// The run may be in any lifecycle state; probe each state root in order.
	typedRoots := []struct{ root string }{
		{root: jobPaths.RunningPath()},
		{root: jobPaths.PendingPath()},
		{root: jobPaths.FinishedPath()},
		{root: jobPaths.FailedPath()},
	}
	var manifestDir string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			// Only accept a directory that actually contains a run manifest.
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				manifestDir = dir
				break
			}
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
	}
	rm, err := manifest.LoadFromDir(manifestDir)
	if err != nil || rm == nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
	}
	// Decode the patch only after the target manifest is known to exist.
	var patch manifest.NarrativePatch
	if err := json.Unmarshal(patchJSON, &patch); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error())
	}
	rm.ApplyNarrativePatch(patch)
	if err := rm.WriteToDir(manifestDir); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated"))
}
func fileSHA256Hex(path string) (string, error) {
f, err := os.Open(filepath.Clean(path))
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
// expectedProvenanceForCommit computes the provenance values recorded on a
// queued task for the given experiment commit: the experiment manifest's
// overall SHA and, when one can be determined, the name and SHA-256 of the
// selected dependency manifest file.
//
// It returns an error when the experiment manifest cannot be read or lacks
// an overall SHA. The dependency-manifest entries are best-effort and are
// silently omitted when selection or hashing fails.
func expectedProvenanceForCommit(
	expMgr *experiment.Manager,
	commitID string,
) (map[string]string, error) {
	out := map[string]string{}
	// Named expManifest (not "manifest") to avoid shadowing the imported
	// manifest package, which this file also uses.
	expManifest, err := expMgr.ReadManifest(commitID)
	if err != nil {
		return nil, err
	}
	if expManifest == nil || expManifest.OverallSHA == "" {
		return nil, fmt.Errorf("missing manifest overall_sha")
	}
	out["experiment_manifest_overall_sha"] = expManifest.OverallSHA
	filesPath := expMgr.GetFilesPath(commitID)
	depName, err := worker.SelectDependencyManifest(filesPath)
	if err == nil && strings.TrimSpace(depName) != "" {
		depPath := filepath.Join(filesPath, depName)
		// Best-effort: skip the dependency entries if hashing fails.
		sha, err := fileSHA256Hex(depPath)
		if err == nil && strings.TrimSpace(sha) != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}
	return out, nil
}
// ensureMinimalExperimentFiles guarantees that the experiment's files
// directory exists and contains at least a train.py and a requirements.txt,
// writing placeholder contents for whichever of the two is absent. Existing
// files are never overwritten.
func ensureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error {
	if expMgr == nil {
		return fmt.Errorf("missing experiment manager")
	}
	commitID = strings.TrimSpace(commitID)
	if commitID == "" {
		return fmt.Errorf("missing commit id")
	}
	filesPath := expMgr.GetFilesPath(commitID)
	if err := os.MkdirAll(filesPath, 0750); err != nil {
		return err
	}
	// Placeholder contents for the minimal set of files a run expects.
	placeholders := []struct {
		name    string
		content string
	}{
		{name: "train.py", content: "print('ok')\n"},
		{name: "requirements.txt", content: "numpy==1.0.0\n"},
	}
	for _, ph := range placeholders {
		target := filepath.Join(filesPath, ph.name)
		// Only create the file when a Stat reports it as missing; any other
		// Stat outcome (exists, or a different error) leaves it untouched.
		if _, statErr := os.Stat(target); !os.IsNotExist(statErr) {
			continue
		}
		if err := fileutil.SecureFileWrite(target, []byte(ph.content), 0640); err != nil {
			return err
		}
	}
	return nil
}
// handleQueueJob validates the caller, prepares the experiment directory and
// its content-integrity manifest, records the experiment row asynchronously,
// and enqueues the job.
//
// Binary protocol:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][resources?:var]
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
	// Fixed prefix is 38 bytes (16 + 20 + 1 + 1).
	if len(payload) < 38 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	if len(payload) < 38+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	// Any trailing bytes encode an optional resource request.
	resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:])
	if resErr != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			"invalid resource request",
			resErr.Error(),
		)
	}
	h.logger.Info("queue job request",
		"job", jobName,
		"priority", priority,
		"commit_id", fmt.Sprintf("%x", commitID),
	)
	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Check user permissions
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued",
			"job", jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}
	// Create experiment directory and metadata (optimized)
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID))
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to create experiment directory",
			err.Error(),
		)
	}
	meta := &experiment.Metadata{
		CommitID:  fmt.Sprintf("%x", commitID),
		JobName:   jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to save experiment metadata",
			err.Error(),
		)
	}
	// Generate and write content integrity manifest
	commitIDStr := fmt.Sprintf("%x", commitID)
	// Placeholder files must exist before the manifest is generated, so the
	// manifest always covers a non-empty file set.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to initialize experiment files",
			err.Error(),
		)
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) {
			manifest, err := h.expManager.GenerateManifest(commitIDStr)
			if err != nil {
				return "", fmt.Errorf("failed to generate manifest: %w", err)
			}
			if err := h.expManager.WriteManifest(manifest); err != nil {
				return "", fmt.Errorf("failed to write manifest: %w", err)
			}
			return "", nil
		}); err != nil {
		h.logger.Error("failed to generate/write manifest", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to generate content integrity manifest",
			err.Error(),
		)
	}
	// Add user info to experiment metadata (deferred for performance)
	// Best-effort DB upsert off the request path, bounded by a 3s timeout.
	go func() {
		if h.db != nil {
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			defer cancel()
			exp := &storage.Experiment{
				ID:     fmt.Sprintf("%x", commitID),
				Name:   jobName,
				Status: "pending",
				UserID: user.Name,
			}
			if _, err := telemetry.ExecWithMetrics(
				h.logger,
				"db.experiments.upsert",
				50*time.Millisecond,
				func() (string, error) {
					return "", h.db.UpsertExperiment(ctx, exp)
				},
			); err != nil {
				h.logger.Error("failed to upsert experiment row", "error", err)
			}
		}
	}()
	// NOTE(review): duplicates the "job queued" log emitted in the permission
	// branch above — consider removing one of the two.
	h.logger.Info(
		"job queued",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)),
		"user", user.Name,
	)
	return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, nil, resources)
}
// handleQueueJobWithSnapshot queues a job that should run against a specific
// code snapshot, identified by a snapshot ID and its content SHA.
//
// Binary protocol:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//	[snap_id_len:1][snap_id:var][snap_sha_len:1][snap_sha:var][resources?:var]
func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
	// Fixed 38-byte prefix plus at least the two snapshot length bytes.
	if len(payload) < 40 {
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			"queue job with snapshot payload too short",
			"",
		)
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// Reserve room for the job name plus the snapshot id length byte and at
	// least one more byte.
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	snapIDLen := int(payload[offset])
	offset++
	// Snapshot id is mandatory; bounds check also reserves the sha length byte.
	if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "")
	}
	snapshotID := string(payload[offset : offset+snapIDLen])
	offset += snapIDLen
	snapSHALen := int(payload[offset])
	offset++
	// Snapshot SHA is mandatory as well.
	if snapSHALen < 1 || len(payload) < offset+snapSHALen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot sha length", "")
	}
	snapshotSHA := string(payload[offset : offset+snapSHALen])
	offset += snapSHALen
	// Any trailing bytes encode an optional resource request.
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			"invalid resource request",
			resErr.Error(),
		)
	}
	h.logger.Info("queue job with snapshot request",
		"job", jobName,
		"priority", priority,
		"commit_id", fmt.Sprintf("%x", commitID),
		"snapshot_id", snapshotID,
	)
	// Resolve the caller; with auth disabled everyone is the default admin.
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// jobs:create is required when auth is enabled.
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued",
			"job", jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}
	// Create the experiment directory before writing anything into it.
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID))
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to create experiment directory",
			err.Error(),
		)
	}
	meta := &experiment.Metadata{
		CommitID:  fmt.Sprintf("%x", commitID),
		JobName:   jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to save experiment metadata",
			err.Error(),
		)
	}
	commitIDStr := fmt.Sprintf("%x", commitID)
	// Placeholder files must exist before the manifest is generated.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to initialize experiment files",
			err.Error(),
		)
	}
	// Generate and persist the content-integrity manifest for the commit.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) {
			manifest, err := h.expManager.GenerateManifest(commitIDStr)
			if err != nil {
				return "", fmt.Errorf("failed to generate manifest: %w", err)
			}
			if err := h.expManager.WriteManifest(manifest); err != nil {
				return "", fmt.Errorf("failed to write manifest: %w", err)
			}
			return "", nil
		}); err != nil {
		h.logger.Error("failed to generate/write manifest", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to generate content integrity manifest",
			err.Error(),
		)
	}
	// Best-effort DB upsert off the request path, bounded by a 3s timeout.
	go func() {
		if h.db != nil {
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			defer cancel()
			exp := &storage.Experiment{
				ID:     fmt.Sprintf("%x", commitID),
				Name:   jobName,
				Status: "pending",
				UserID: user.Name,
			}
			if _, err := telemetry.ExecWithMetrics(
				h.logger,
				"db.experiments.upsert",
				50*time.Millisecond,
				func() (string, error) {
					return "", h.db.UpsertExperiment(ctx, exp)
				},
			); err != nil {
				h.logger.Error("failed to upsert experiment row", "error", err)
			}
		}
	}()
	// NOTE(review): duplicates the "job queued" log emitted in the permission
	// branch above.
	h.logger.Info(
		"job queued",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)),
		"user", user.Name,
	)
	return h.enqueueTaskAndRespondWithSnapshot(
		conn,
		user,
		jobName,
		priority,
		commitID,
		nil,
		resources,
		snapshotID,
		snapshotSHA,
	)
}
// handleQueueJobWithTracking queues a job with optional tracking configuration.
//
// Binary protocol:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//	[tracking_json_len:2][tracking_json:var][resources?:var]
//
// The flow mirrors handleQueueJob: authenticate, authorize (jobs:create),
// create the experiment directory and metadata, ensure placeholder files,
// write the content-integrity manifest, upsert the experiment row off the
// request path, then enqueue.
func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 38+2 { // minimum with zero-length tracking JSON
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			"queue job with tracking payload too short",
			"",
		)
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// Ensure we have job name and two bytes for tracking length
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// trackingLen comes from a uint16 and can never be negative; only the
	// upper bound needs checking.
	if len(payload) < offset+trackingLen {
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			"invalid tracking json length",
			"",
		)
	}
	// A zero-length tracking blob means "no tracking config" (nil).
	var trackingCfg *queue.TrackingConfig
	if trackingLen > 0 {
		var cfg queue.TrackingConfig
		if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error())
		}
		trackingCfg = &cfg
	}
	offset += trackingLen
	// Any trailing bytes encode an optional resource request.
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			"invalid resource request",
			resErr.Error(),
		)
	}
	h.logger.Info("queue job with tracking request",
		"job", jobName,
		"priority", priority,
		"commit_id", fmt.Sprintf("%x", commitID),
	)
	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Check user permissions
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued (with tracking)",
			"job", jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}
	// Create experiment directory and metadata (optimized)
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID))
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to create experiment directory",
			err.Error(),
		)
	}
	meta := &experiment.Metadata{
		CommitID:  fmt.Sprintf("%x", commitID),
		JobName:   jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to save experiment metadata",
			err.Error(),
		)
	}
	// Generate and write content integrity manifest
	commitIDStr := fmt.Sprintf("%x", commitID)
	// Consistency fix: like handleQueueJob and handleQueueJobWithSnapshot,
	// ensure the placeholder files exist BEFORE generating the manifest so
	// the manifest never covers an empty file set. This step was previously
	// missing from this handler.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to initialize experiment files",
			err.Error(),
		)
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) {
			manifest, err := h.expManager.GenerateManifest(commitIDStr)
			if err != nil {
				return "", fmt.Errorf("failed to generate manifest: %w", err)
			}
			if err := h.expManager.WriteManifest(manifest); err != nil {
				return "", fmt.Errorf("failed to write manifest: %w", err)
			}
			return "", nil
		}); err != nil {
		h.logger.Error("failed to generate/write manifest", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to generate content integrity manifest",
			err.Error(),
		)
	}
	// Add user info to experiment metadata (deferred for performance)
	go func() {
		if h.db != nil {
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			defer cancel()
			exp := &storage.Experiment{
				ID:     fmt.Sprintf("%x", commitID),
				Name:   jobName,
				Status: "pending",
				UserID: user.Name,
			}
			if _, err := telemetry.ExecWithMetrics(
				h.logger,
				"db.experiments.upsert",
				50*time.Millisecond,
				func() (string, error) {
					return "", h.db.UpsertExperiment(ctx, exp)
				},
			); err != nil {
				h.logger.Error("failed to upsert experiment row", "error", err)
			}
		}
	}()
	return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources)
}
// queueJobWithArgsPayload is the decoded form of the "queue job with args"
// binary request (see parseQueueJobWithArgsPayload).
type queueJobWithArgsPayload struct {
	apiKeyHash []byte           // 16-byte API key hash from the wire
	commitID   []byte           // 20-byte raw commit identifier (hex-encoded by consumers)
	priority   int64            // scheduling priority (single byte on the wire)
	jobName    string           // job name; validated by the handler
	args       string           // optional args blob; empty when absent
	resources  *resourceRequest // optional resource request; nil when absent
}
// queueJobWithNotePayload is the decoded form of the "queue job with note"
// binary request (see parseQueueJobWithNotePayload).
type queueJobWithNotePayload struct {
	apiKeyHash []byte           // 16-byte API key hash from the wire
	commitID   []byte           // 20-byte raw commit identifier (hex-encoded by consumers)
	priority   int64            // scheduling priority (single byte on the wire)
	jobName    string           // job name; validated by the handler
	args       string           // optional args blob; empty when absent
	note       string           // optional free-form note; empty when absent
	resources  *resourceRequest // optional resource request; nil when absent
}
// parseQueueJobWithNotePayload decodes the binary payload for the
// "queue job with note" request.
//
// Protocol:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//	[args_len:2][args:var][note_len:2][note:var][resources?:var]
//
// Both args and note may be empty; trailing bytes, if any, are parsed as an
// optional resource request.
func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) {
	// Fixed prefix (16+20+1+1 = 38 bytes) plus the two 2-byte length fields.
	if len(payload) < 42 {
		return nil, fmt.Errorf("queue job with note payload too short")
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// Reserve room for the job name plus the 2-byte args length behind it.
	if len(payload) < 38+jobNameLen+2 {
		return nil, fmt.Errorf("invalid job name length")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// argsLen is derived from a uint16 and can never be negative, so only the
	// upper bound (including the trailing 2-byte note length) is checked.
	if len(payload) < offset+argsLen+2 {
		return nil, fmt.Errorf("invalid args length")
	}
	args := ""
	if argsLen > 0 {
		args = string(payload[offset : offset+argsLen])
	}
	offset += argsLen
	noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// Same reasoning: noteLen is never negative.
	if len(payload) < offset+noteLen {
		return nil, fmt.Errorf("invalid note length")
	}
	note := ""
	if noteLen > 0 {
		note = string(payload[offset : offset+noteLen])
	}
	offset += noteLen
	// Any remaining bytes describe an optional resource request.
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return nil, resErr
	}
	return &queueJobWithNotePayload{
		apiKeyHash: apiKeyHash,
		commitID:   commitID,
		priority:   priority,
		jobName:    jobName,
		args:       args,
		note:       note,
		resources:  resources,
	}, nil
}
// parseQueueJobWithArgsPayload decodes the binary payload for the
// "queue job with args" request.
//
// Protocol:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//	[args_len:2][args:var][resources?:var]
//
// args may be empty; trailing bytes, if any, are parsed as an optional
// resource request.
func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) {
	// Fixed prefix (16+20+1+1 = 38 bytes) plus the 2-byte args length.
	if len(payload) < 40 {
		return nil, fmt.Errorf("queue job with args payload too short")
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// Reserve room for the job name plus the 2-byte args length behind it.
	if len(payload) < 38+jobNameLen+2 {
		return nil, fmt.Errorf("invalid job name length")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// argsLen is derived from a uint16 and can never be negative, so only the
	// upper bound needs checking.
	if len(payload) < offset+argsLen {
		return nil, fmt.Errorf("invalid args length")
	}
	args := ""
	if argsLen > 0 {
		args = string(payload[offset : offset+argsLen])
	}
	offset += argsLen
	// Any remaining bytes describe an optional resource request.
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return nil, resErr
	}
	return &queueJobWithArgsPayload{
		apiKeyHash: apiKeyHash,
		commitID:   commitID,
		priority:   priority,
		jobName:    jobName,
		args:       args,
		resources:  resources,
	}, nil
}
// handleQueueJobWithArgs queues a job that carries an optional args blob.
// The payload format is decoded by parseQueueJobWithArgsPayload; the flow
// then mirrors handleQueueJob (authenticate, authorize, prepare experiment
// directory, metadata, and placeholder files) before enqueueing.
func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error {
	p, err := parseQueueJobWithArgsPayload(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error())
	}
	h.logger.Info("queue job request",
		"job", p.jobName,
		"priority", p.priority,
		"commit_id", fmt.Sprintf("%x", p.commitID),
	)
	// Validate API key and get user information
	var user *auth.User
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled: act as the all-powerful default user.
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// jobs:create is required when auth is enabled.
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued",
			"job", p.jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}
	commitIDStr := fmt.Sprintf("%x", p.commitID)
	// Create the experiment directory before writing anything into it.
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(commitIDStr)
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   p.jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
	}
	// Seed placeholder files so downstream steps see a non-empty file set.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error())
	}
	return h.enqueueTaskAndRespondWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, nil, p.resources)
}
// handleQueueJobWithNote queues a job that carries an optional args blob and
// an optional free-form note. The payload format is decoded by
// parseQueueJobWithNotePayload; the flow then mirrors handleQueueJobWithArgs.
func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error {
	p, err := parseQueueJobWithNotePayload(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error())
	}
	h.logger.Info("queue job request",
		"job", p.jobName,
		"priority", p.priority,
		"commit_id", fmt.Sprintf("%x", p.commitID),
	)
	// Resolve the caller; with auth disabled everyone is the default admin.
	var user *auth.User
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// jobs:create is required when auth is enabled.
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued",
			"job", p.jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}
	commitIDStr := fmt.Sprintf("%x", p.commitID)
	// Create the experiment directory before writing anything into it.
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(commitIDStr)
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   p.jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
	}
	// Seed placeholder files so downstream steps see a non-empty file set.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error())
	}
	return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, nil, p.resources)
}
// enqueueTaskAndRespond enqueues a task and sends a success response.
// It is a thin convenience wrapper around enqueueTaskAndRespondWithArgs
// that passes no extra CLI args; all validation, provenance checks, and
// response writing happen in the delegate.
func (h *WSHandler) enqueueTaskAndRespond(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", tracking, resources)
}
// enqueueTaskAndRespondWithArgsAndNote enqueues a job task carrying optional
// CLI args and a free-form note, then confirms success over the websocket.
// Enqueueing is refused when expected provenance for the commit cannot be
// computed; a missing queue is logged but still reported as success (the
// historical behavior).
func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	note string,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	response := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	commitHex := fmt.Sprintf("%x", commitID)
	prov, err := expectedProvenanceForCommit(h.expManager, commitHex)
	if err != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitHex,
			"error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			err.Error(),
		)
	}
	if h.queue == nil {
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	} else {
		// Assemble task metadata: commit, optional note, then provenance
		// (provenance entries may overwrite earlier keys, as before).
		md := map[string]string{"commit_id": commitHex}
		if trimmed := strings.TrimSpace(note); trimmed != "" {
			md["note"] = trimmed
		}
		for key, val := range prov {
			if val != "" {
				md[key] = val
			}
		}
		task := &queue.Task{
			ID:        uuid.New().String(),
			JobName:   jobName,
			Args:      strings.TrimSpace(args),
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata:  md,
			Tracking:  tracking,
		}
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) { return "", h.queue.AddTask(task) },
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeDatabaseError,
				"Failed to enqueue task",
				err.Error(),
			)
		}
		h.logger.Info("task enqueued", "task_id", task.ID, "job", jobName, "user", user.Name)
	}
	data, err := response.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, data)
}
// enqueueTaskAndRespondWithArgs enqueues a job task carrying optional CLI
// args and confirms success over the websocket. Enqueueing is refused when
// expected provenance for the commit cannot be computed; a missing queue is
// logged but still reported as success (the historical behavior).
func (h *WSHandler) enqueueTaskAndRespondWithArgs(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	response := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	commitHex := fmt.Sprintf("%x", commitID)
	prov, err := expectedProvenanceForCommit(h.expManager, commitHex)
	if err != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitHex,
			"error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			err.Error(),
		)
	}
	if h.queue == nil {
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	} else {
		// Metadata starts with the commit; non-empty provenance entries are
		// layered on top (and may overwrite, as before).
		md := map[string]string{"commit_id": commitHex}
		for key, val := range prov {
			if val != "" {
				md[key] = val
			}
		}
		task := &queue.Task{
			ID:        uuid.New().String(),
			JobName:   jobName,
			Args:      strings.TrimSpace(args),
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata:  md,
			Tracking:  tracking,
		}
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) { return "", h.queue.AddTask(task) },
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeDatabaseError,
				"Failed to enqueue task",
				err.Error(),
			)
		}
		h.logger.Info("task enqueued", "task_id", task.ID, "job", jobName, "user", user.Name)
	}
	data, err := response.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, data)
}
// enqueueTaskAndRespondWithSnapshot enqueues a job task bound to a code
// snapshot (ID plus SHA-256 recorded in metadata) and confirms success over
// the websocket. Enqueueing is refused when expected provenance for the
// commit cannot be computed; a missing queue is logged but still reported as
// success (the historical behavior).
func (h *WSHandler) enqueueTaskAndRespondWithSnapshot(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
	snapshotID string,
	snapshotSHA string,
) error {
	response := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	commitHex := fmt.Sprintf("%x", commitID)
	prov, err := expectedProvenanceForCommit(h.expManager, commitHex)
	if err != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitHex,
			"error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			err.Error(),
		)
	}
	if h.queue == nil {
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	} else {
		// Commit and snapshot digest first; non-empty provenance entries are
		// layered on top (and may overwrite, as before).
		md := map[string]string{
			"commit_id":       commitHex,
			"snapshot_sha256": strings.TrimSpace(snapshotSHA),
		}
		for key, val := range prov {
			if val != "" {
				md[key] = val
			}
		}
		task := &queue.Task{
			ID:         uuid.New().String(),
			JobName:    jobName,
			Args:       "",
			Status:     "queued",
			Priority:   priority,
			CreatedAt:  time.Now(),
			UserID:     user.Name,
			Username:   user.Name,
			CreatedBy:  user.Name,
			SnapshotID: strings.TrimSpace(snapshotID),
			Metadata:   md,
			Tracking:   tracking,
		}
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) { return "", h.queue.AddTask(task) },
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
		}
		h.logger.Info("task enqueued", "task_id", task.ID, "job", jobName, "user", user.Name)
	}
	data, err := response.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, data)
}
// resourceRequest captures optional per-task resource requirements decoded
// from the wire protocol's trailing resource bytes. Zero values mean
// "unspecified".
type resourceRequest struct {
	CPU       int    // requested CPU cores
	MemoryGB  int    // requested memory in GB
	GPU       int    // requested GPU count
	GPUMemory string // requested GPU memory spec (e.g. "40G"), "" if absent
}

// parseOptionalResourceRequest parses an optional tail encoding:
//
//	[cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var]
//
// An empty payload means no resource request and returns (nil, nil).
// It returns an error when the payload is present but shorter than the
// 4-byte fixed header, or when the declared gpu_mem length overruns the
// payload. Bytes past the declared gpu_mem are ignored.
func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) {
	if len(payload) == 0 {
		return nil, nil
	}
	if len(payload) < 4 {
		return nil, fmt.Errorf("resource payload too short")
	}
	req := &resourceRequest{
		CPU:      int(payload[0]),
		MemoryGB: int(payload[1]),
		GPU:      int(payload[2]),
	}
	// gpuMemLen comes from a single byte so it is always in [0, 255]; the
	// former negative-length check was unreachable and has been removed.
	gpuMemLen := int(payload[3])
	if len(payload) < 4+gpuMemLen {
		return nil, fmt.Errorf("invalid gpu memory length")
	}
	if gpuMemLen > 0 {
		req.GPUMemory = string(payload[4 : 4+gpuMemLen])
	}
	return req, nil
}
// handleStatusRequest reports queue status to the client: per-status task
// counts, the tasks visible to the requesting user, and worker prewarm
// states. Non-admin users only see tasks they own.
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16]
	if len(payload) < 16 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
	}
	keyHash := payload[:16]
	h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", keyHash))

	// Resolve the requesting user; with auth disabled, everyone acts as a
	// local admin.
	var user *auth.User
	if h.authConfig == nil {
		user = &auth.User{
			Name:        "default",
			Admin:       true,
			Roles:       []string{"admin"},
			Permissions: map[string]bool{"*": true},
		}
	} else {
		u, err := h.authConfig.ValidateAPIKeyHash(keyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
		user = u
	}

	authEnforced := h.authConfig != nil && h.authConfig.Enabled
	if authEnforced && !user.HasPermission("jobs:read") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to view jobs",
			"",
		)
	}

	// Collect tasks the user is allowed to see.
	var visible []*queue.Task
	if h.queue != nil {
		all, err := h.queue.GetAllTasks()
		if err != nil {
			h.logger.Error("failed to get tasks", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
		}
		for _, t := range all {
			// Admins (or auth disabled) see everything; others only their own.
			if !authEnforced || user.Admin || t.UserID == user.Name || t.CreatedBy == user.Name {
				visible = append(visible, t)
			}
		}
	}

	// Build status response as raw JSON for CLI compatibility.
	h.logger.Info("building status response")
	status := map[string]any{
		"user": map[string]any{
			"name":  user.Name,
			"admin": user.Admin,
			"roles": user.Roles,
		},
		"tasks": map[string]any{
			"total":     len(visible),
			"queued":    countTasksByStatus(visible, "queued"),
			"running":   countTasksByStatus(visible, "running"),
			"failed":    countTasksByStatus(visible, "failed"),
			"completed": countTasksByStatus(visible, "completed"),
		},
		"queue": visible,
	}
	if h.queue != nil {
		if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil {
			// Deterministic ordering: by worker, then task.
			sort.Slice(states, func(a, b int) bool {
				if states[a].WorkerID == states[b].WorkerID {
					return states[a].TaskID < states[b].TaskID
				}
				return states[a].WorkerID < states[b].WorkerID
			})
			status["prewarm"] = states
		}
	}

	h.logger.Info("serializing JSON response")
	jsonData, err := json.Marshal(status)
	if err != nil {
		h.logger.Error("failed to marshal JSON", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	h.logger.Info("sending websocket JSON response", "len", len(jsonData))
	// Send as binary protocol packet.
	return h.sendResponsePacket(conn, NewDataPacket("status", jsonData))
}
// countTasksByStatus returns how many of the given tasks currently carry the
// supplied status string.
func countTasksByStatus(tasks []*queue.Task, status string) int {
	n := 0
	for _, t := range tasks {
		if t.Status == status {
			n++
		}
	}
	return n
}
// handleCancelJob cancels a queued/running job by name on behalf of the
// authenticated user. Non-admin users may only cancel jobs they own; a
// missing queue is logged but still reported as success (the historical
// behavior).
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var]
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
	}
	keyHash := payload[:16]
	nameLen := int(payload[16])
	if len(payload) < 17+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[17 : 17+nameLen])
	h.logger.Info("cancel job request", "job", jobName)

	// Resolve the requesting user; with auth disabled, everyone acts as a
	// local admin.
	var user *auth.User
	if h.authConfig == nil {
		user = &auth.User{
			Name:        "default",
			Admin:       true,
			Roles:       []string{"admin"},
			Permissions: map[string]bool{"*": true},
		}
	} else {
		u, err := h.authConfig.ValidateAPIKeyHash(keyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
		user = u
	}

	authEnforced := h.authConfig != nil && h.authConfig.Enabled
	if authEnforced && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to cancel jobs",
			"",
		)
	}

	if h.queue == nil {
		h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
	} else {
		task, err := h.queue.GetTaskByName(jobName)
		if err != nil {
			h.logger.Error("task not found", "job", jobName, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
		}
		// Only admins (or auth disabled) may cancel someone else's task.
		owns := task.UserID == user.Name || task.CreatedBy == user.Name
		if authEnforced && !user.Admin && !owns {
			h.logger.Error(
				"unauthorized job cancellation attempt",
				"user", user.Name,
				"job", jobName,
				"task_owner", task.UserID,
			)
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only cancel your own jobs",
				"",
			)
		}
		if err := h.queue.CancelTask(task.ID); err != nil {
			h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
		}
		h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	}

	response := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
	data, err := response.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, data)
}
// handlePrune deletes old experiments, either keeping only the N most recent
// (type 0) or dropping those older than N days (type 1), then nudges the
// prewarm garbage collector so cached state for removed runs can be reclaimed.
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][prune_type:1][value:4]
	if len(payload) < 21 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}
	keyHash := payload[:16]
	mode := payload[16]
	value := binary.BigEndian.Uint32(payload[17:21])
	h.logger.Info("prune request", "type", mode, "value", value)

	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(keyHash); err != nil {
			h.logger.Error("api key verification failed", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}

	// Translate the wire encoding into prune parameters.
	var keep, maxAgeDays int
	switch mode {
	case 0: // keep the N most recent experiments
		keep = int(value)
	case 1: // remove experiments older than N days
		maxAgeDays = int(value)
	default:
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			fmt.Sprintf("invalid prune type: %d", mode),
			"",
		)
	}

	pruned, err := h.expManager.PruneExperiments(keep, maxAgeDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}
	if h.queue != nil {
		// Best effort; the GC signal result is intentionally ignored.
		_ = h.queue.SignalPrewarmGC()
	}
	h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))))
}
// handleLogMetric records one scalar metric value (name, step, float64 value)
// against an experiment identified by its 20-byte commit ID.
// Protocol: [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var]
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	// The fixed header is 49 bytes (16+20+4+8+1) and a metric needs at least
	// a one-byte name, so the minimum valid payload is 50 bytes. The previous
	// check required 51 bytes, which was off by one and rejected valid
	// single-character metric names.
	if len(payload) < 50 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	step := int(binary.BigEndian.Uint32(payload[36:40]))
	valueBits := binary.BigEndian.Uint64(payload[40:48])
	value := math.Float64frombits(valueBits)
	nameLen := int(payload[48])
	// Reject empty names (previously silently accepted) and declared lengths
	// that overrun the payload.
	if nameLen == 0 || len(payload) < 49+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
	}
	name := string(payload[49 : 49+nameLen])
	// Verify API key when authentication is enforced.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			h.logger.Error("api key verification failed", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil {
		h.logger.Error("failed to log metric", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}
// handleGetExperiment returns an experiment's metadata and metrics, enriched
// with reproducibility information from the database when one is configured
// and the lookup succeeds.
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][commit_id:20]
	if len(payload) < 36 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
	}
	keyHash := payload[:16]
	commitHex := fmt.Sprintf("%x", payload[16:36])

	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(keyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}

	meta, err := h.expManager.ReadMetadata(commitHex)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
	}
	metrics, err := h.expManager.GetMetrics(commitHex)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
	}

	// Reproducibility details are best effort: DB errors just omit the field.
	var dbMeta *storage.ExperimentWithMetadata
	if h.db != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		if m, err := h.db.GetExperimentWithMetadata(ctx, commitHex); err == nil {
			dbMeta = m
		}
	}

	response := map[string]any{
		"metadata": meta,
		"metrics":  metrics,
	}
	if dbMeta != nil {
		response["reproducibility"] = dbMeta
	}
	body, err := json.Marshal(response)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("experiment", body))
}