fetch_ml/internal/api/ws_jobs.go
Jeremie Fraeys d1bef0a450
refactor: Phase 3 - fix config/storage boundaries
Move schema ownership to infrastructure layer:

- Redis keys: config/constants.go -> queue/keys.go (TaskQueueKey, TaskPrefix, etc.)

- Filesystem paths: config/paths.go -> storage/paths.go (JobPaths)

- Create config/shared.go with RedisConfig, SSHConfig

- Update all imports: worker/, api/helpers, api/ws_jobs, api/ws_validate

- Clean up: remove duplicates from queue/task.go, queue/queue.go, config/paths.go

Build status: Compiles successfully
2026-02-17 12:49:53 -05:00

1365 lines
42 KiB
Go

package api
import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"

	"github.com/jfraeys/fetch_ml/internal/api/helpers"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/manifest"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/jfraeys/fetch_ml/internal/telemetry"
)
// handleAnnotateRun appends a timestamped annotation (author + note) to a
// run's on-disk manifest, locating the run directory across all lifecycle
// states (running/pending/finished/failed).
func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
	// Minimum size: key hash + job name length byte + author length byte + note length (uint16).
	if len(payload) < 16+1+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
	}
	offset := 16
	jobNameLen := int(payload[offset])
	offset += 1
	// Bounds check covers the job name plus the remaining fixed-size prefixes (author len + note len).
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen
	authorLen := int(payload[offset])
	offset += 1
	// Author may be empty; it defaults to the authenticated user's name further below.
	if authorLen < 0 || len(payload) < offset+authorLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
	}
	author := string(payload[offset : offset+authorLen])
	offset += authorLen
	noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if noteLen <= 0 || len(payload) < offset+noteLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
	}
	note := string(payload[offset : offset+noteLen])
	// NOTE: authentication happens after payload parsing; malformed payloads
	// are rejected with a protocol error before the key is checked.
	user, err := h.authenticate(conn, payload, 16)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}
	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}
	// Search each lifecycle root for a job directory that actually contains a manifest.
	jobPaths := storage.NewJobPaths(base)
	typedRoots := []struct{ root string }{
		{root: jobPaths.RunningPath()},
		{root: jobPaths.PendingPath()},
		{root: jobPaths.FinishedPath()},
		{root: jobPaths.FailedPath()},
	}
	var manifestDir string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				manifestDir = dir
				break
			}
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
	}
	rm, err := manifest.LoadFromDir(manifestDir)
	if err != nil || rm == nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
	}
	// Fall back to the authenticated user when the client sent no author.
	if strings.TrimSpace(author) == "" {
		author = user.Name
	}
	rm.AddAnnotation(time.Now().UTC(), author, note)
	if err := rm.WriteToDir(manifestDir); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added"))
}
// handleSetRunNarrative applies a JSON narrative patch to a run's on-disk
// manifest, locating the run directory across all lifecycle states
// (running/pending/finished/failed).
func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var]
	// Minimum size: key hash + job name length byte + patch length (uint16).
	if len(payload) < 16+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
	}
	offset := 16
	jobNameLen := int(payload[offset])
	offset += 1
	// Bounds check covers the job name plus the 2-byte patch length prefix.
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen
	patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if patchLen <= 0 || len(payload) < offset+patchLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "")
	}
	patchJSON := payload[offset : offset+patchLen]
	// NOTE: authentication happens after payload parsing; malformed payloads
	// are rejected with a protocol error before the key is checked.
	user, err := h.authenticate(conn, payload, 16)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}
	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}
	// Search each lifecycle root for a job directory that actually contains a manifest.
	jobPaths := storage.NewJobPaths(base)
	typedRoots := []struct{ root string }{
		{root: jobPaths.RunningPath()},
		{root: jobPaths.PendingPath()},
		{root: jobPaths.FinishedPath()},
		{root: jobPaths.FailedPath()},
	}
	var manifestDir string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				manifestDir = dir
				break
			}
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
	}
	rm, err := manifest.LoadFromDir(manifestDir)
	if err != nil || rm == nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
	}
	// Decode the patch only after the manifest is known to exist and load.
	var patch manifest.NarrativePatch
	if err := json.Unmarshal(patchJSON, &patch); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error())
	}
	rm.ApplyNarrativePatch(patch)
	if err := rm.WriteToDir(manifestDir); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated"))
}
// handleQueueJob parses a basic queue-job request, authenticates/authorizes
// the caller, and enqueues the job.
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][resources?:var]
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	// Parse payload first
	if len(payload) < ProtocolMinQueueJob {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	if len(payload) < 38+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	// Optional trailing resource request (nil when the client sent none).
	resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:])
	if resErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error())
	}
	h.logger.Info("queue job request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID))
	// Authenticate and authorize
	user, err := h.authenticate(conn, payload, ProtocolMinQueueJob)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
		return err
	}
	// No tracking config on this entry point.
	return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, nil, resources)
}
// handleQueueJobWithSnapshot parses a queue-job request that pins a data
// snapshot (ID + SHA), authenticates/authorizes the caller, and enqueues it.
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//           [snap_id_len:1][snap_id:var][snap_sha_len:1][snap_sha:var][resources?:var]
func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
	if len(payload) < ProtocolMinQueueJobWithSnapshot {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with snapshot payload too short", "")
	}
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// +2 reserves room for the two 1-byte snapshot length prefixes.
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	snapIDLen := int(payload[offset])
	offset++
	// Snapshot ID must be non-empty; +1 reserves the SHA length byte.
	if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "")
	}
	snapshotID := string(payload[offset : offset+snapIDLen])
	offset += snapIDLen
	snapSHALen := int(payload[offset])
	offset++
	// Snapshot SHA must be non-empty as well.
	if snapSHALen < 1 || len(payload) < offset+snapSHALen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot sha length", "")
	}
	snapshotSHA := string(payload[offset : offset+snapSHALen])
	offset += snapSHALen
	// Optional trailing resource request (nil when the client sent none).
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error())
	}
	h.logger.Info("queue job with snapshot request", "job", jobName, "priority", priority,
		"commit_id", fmt.Sprintf("%x", commitID), "snapshot_id", snapshotID)
	user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithSnapshot)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
		return err
	}
	return h.processAndEnqueueJobWithSnapshot(conn, user, jobName, priority, commitID, nil, resources, snapshotID, snapshotSHA)
}
// handleQueueJobWithTracking parses a queue-job request carrying an optional
// JSON tracking configuration, authenticates/authorizes the caller, and
// enqueues the job.
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//           [tracking_len:2][tracking_json:var][resources?:var]
func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error {
	if len(payload) < ProtocolMinQueueJobWithTracking {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with tracking payload too short", "")
	}
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// +2 reserves room for the uint16 tracking-length prefix.
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if trackingLen < 0 || len(payload) < offset+trackingLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json length", "")
	}
	// A zero-length tracking section means "no tracking config" (nil).
	var trackingCfg *queue.TrackingConfig
	if trackingLen > 0 {
		var cfg queue.TrackingConfig
		if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error())
		}
		trackingCfg = &cfg
	}
	offset += trackingLen
	// Optional trailing resource request (nil when the client sent none).
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error())
	}
	h.logger.Info("queue job with tracking request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID))
	user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithTracking)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
		return err
	}
	return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, trackingCfg, resources)
}
// queueJobWithArgsPayload holds the decoded fields of a queue-job-with-args
// request (see parseQueueJobWithArgsPayload for the wire layout).
type queueJobWithArgsPayload struct {
	apiKeyHash []byte           // 16-byte API key hash used for authentication
	commitID   []byte           // 20-byte commit identifier
	priority   int64            // scheduling priority (single byte on the wire)
	jobName    string           // requested job name (validated downstream)
	args       string           // argument string passed to the job
	force      bool             // when true, skip the duplicate-task check
	resources  *resourceRequest // optional resource requirements; nil if absent
}
// queueJobWithNotePayload holds the decoded fields of a queue-job-with-note
// request (see parseQueueJobWithNotePayload for the wire layout).
type queueJobWithNotePayload struct {
	apiKeyHash []byte           // 16-byte API key hash used for authentication
	commitID   []byte           // 20-byte commit identifier
	priority   int64            // scheduling priority (single byte on the wire)
	jobName    string           // requested job name (validated downstream)
	args       string           // argument string passed to the job
	note       string           // free-form note stored in task metadata
	force      bool             // when true, skip the duplicate-task check
	resources  *resourceRequest // optional resource requirements; nil if absent
}
// parseQueueJobWithNotePayload decodes the queue-job-with-note wire format:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//	[args_len:2][args:var][note_len:2][note:var][force:1][resources?:var]
func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) {
	// 37-byte fixed header plus minimum space for the variable sections.
	if len(payload) < 43 {
		return nil, fmt.Errorf("queue job with note payload too short")
	}
	out := &queueJobWithNotePayload{
		apiKeyHash: payload[:16],
		commitID:   payload[16:36],
		priority:   int64(payload[36]),
	}
	parser := helpers.NewPayloadParser(payload, 37)
	var err error
	if out.jobName, err = parser.ParseLengthPrefixedString(); err != nil {
		return nil, fmt.Errorf("invalid job name: %w", err)
	}
	if out.args, err = parser.ParseUint16PrefixedString(); err != nil {
		return nil, fmt.Errorf("invalid args: %w", err)
	}
	if out.note, err = parser.ParseUint16PrefixedString(); err != nil {
		return nil, fmt.Errorf("invalid note: %w", err)
	}
	if out.force, err = parser.ParseBool(); err != nil {
		return nil, fmt.Errorf("missing force flag: %w", err)
	}
	// Whatever trails the force flag is an optional resource request.
	resources, resErr := helpers.ParseResourceRequest(parser.Remaining())
	if resErr != nil {
		return nil, resErr
	}
	out.resources = resources
	return out, nil
}
// parseQueueJobWithArgsPayload decodes the queue-job-with-args wire format:
//
//	[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//	[args_len:2][args:var][force:1][resources?:var]
func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) {
	// 37-byte fixed header plus minimum space for the variable sections.
	if len(payload) < 41 {
		return nil, fmt.Errorf("queue job with args payload too short")
	}
	out := &queueJobWithArgsPayload{
		apiKeyHash: payload[:16],
		commitID:   payload[16:36],
		priority:   int64(payload[36]),
	}
	parser := helpers.NewPayloadParser(payload, 37)
	var err error
	if out.jobName, err = parser.ParseLengthPrefixedString(); err != nil {
		return nil, fmt.Errorf("invalid job name: %w", err)
	}
	if out.args, err = parser.ParseUint16PrefixedString(); err != nil {
		return nil, fmt.Errorf("invalid args: %w", err)
	}
	if out.force, err = parser.ParseBool(); err != nil {
		return nil, fmt.Errorf("missing force flag: %w", err)
	}
	// Whatever trails the force flag is an optional resource request.
	resources, resErr := helpers.ParseResourceRequest(parser.Remaining())
	if resErr != nil {
		return nil, resErr
	}
	out.resources = resources
	return out, nil
}
// handleQueueJobWithArgs decodes a queue-job-with-args request, then
// authenticates, authorizes, and enqueues the job.
func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error {
	parsed, parseErr := parseQueueJobWithArgsPayload(payload)
	if parseErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", parseErr.Error())
	}
	h.logger.Info("queue job request", "job", parsed.jobName, "priority", parsed.priority, "commit_id", fmt.Sprintf("%x", parsed.commitID))
	user, authErr := h.authenticateWithHash(conn, parsed.apiKeyHash)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(user, PermJobsCreate, conn); permErr != nil {
		return permErr
	}
	return h.processAndEnqueueJobWithArgs(conn, user, parsed.jobName, parsed.priority, parsed.commitID, parsed.args, parsed.force, nil, parsed.resources)
}
// handleQueueJobWithNote decodes a queue-job-with-note request, then
// authenticates, authorizes, and enqueues the job with its note attached.
func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error {
	parsed, parseErr := parseQueueJobWithNotePayload(payload)
	if parseErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", parseErr.Error())
	}
	h.logger.Info("queue job request", "job", parsed.jobName, "priority", parsed.priority, "commit_id", fmt.Sprintf("%x", parsed.commitID))
	user, authErr := h.authenticateWithHash(conn, parsed.apiKeyHash)
	if authErr != nil {
		return authErr
	}
	if permErr := h.requirePermission(user, PermJobsCreate, conn); permErr != nil {
		return permErr
	}
	return h.processAndEnqueueJobWithArgsAndNote(conn, user, parsed.jobName, parsed.priority, parsed.commitID, parsed.args, parsed.note, parsed.force, nil, parsed.resources)
}
// findDuplicateTask scans the queue for a task whose composite key
// (commit_id + dataset_id + params_hash) matches, which indicates a truly
// identical experiment was already submitted. It returns nil when there is
// no match, the queue is unavailable, or listing tasks fails.
func (h *WSHandler) findDuplicateTask(commitIDStr, datasetID, paramsHash string) *queue.Task {
	if h.queue == nil {
		return nil
	}
	tasks, err := h.queue.GetAllTasks()
	if err != nil {
		return nil
	}
	for _, candidate := range tasks {
		meta := candidate.Metadata
		if meta == nil {
			continue
		}
		// All three components of the composite key must agree.
		sameKey := meta["commit_id"] == commitIDStr &&
			meta["dataset_id"] == datasetID &&
			meta["params_hash"] == paramsHash
		if sameKey {
			return candidate
		}
	}
	return nil
}
// sendDuplicateResponse sends a data packet describing an already-queued task
// that matches the requested job's composite key, so the client can decide
// whether to re-submit with force. Completed tasks get duration and metrics;
// failed tasks get a full failure classification with retry guidance.
func (h *WSHandler) sendDuplicateResponse(conn *websocket.Conn, existingTask *queue.Task) error {
	response := map[string]interface{}{
		"duplicate":   true,
		"existing_id": existingTask.ID,
		"status":      existingTask.Status,
		"queued_by":   existingTask.CreatedBy,
		"queued_at":   existingTask.CreatedAt.Unix(),
	}
	// Add duration (and metrics, when available) for completed tasks.
	if existingTask.Status == "completed" && existingTask.EndedAt != nil {
		duration := existingTask.EndedAt.Sub(existingTask.CreatedAt).Seconds()
		response["duration_seconds"] = int64(duration)
		if h.expManager != nil {
			commitID := existingTask.Metadata["commit_id"]
			if metrics, err := h.expManager.GetMetrics(commitID); err == nil && len(metrics) > 0 {
				metricsMap := make(map[string]interface{}, len(metrics))
				for _, m := range metrics {
					metricsMap[m.Name] = m.Value
				}
				response["metrics"] = metricsMap
			}
		}
	}
	// Add error reason for failed tasks with full failure classification.
	if existingTask.Status == "failed" && existingTask.Error != "" {
		response["error_reason"] = existingTask.Error
		exitCode := 0
		signalName := ""
		// Exit code / signal come from task metadata on a best-effort basis:
		// missing or malformed values leave the zero defaults in place.
		if code, ok := existingTask.Metadata["exit_code"]; ok {
			if parsed, err := strconv.Atoi(code); err == nil {
				exitCode = parsed
			}
		}
		if sig, ok := existingTask.Metadata["signal"]; ok {
			signalName = sig
		}
		// Prefer the most recent error text both as the reported log tail and
		// as the classifier input (they are the same value by design).
		logTail := existingTask.Error
		if existingTask.LastError != "" {
			logTail = existingTask.LastError
		}
		// Classify the failure: SIGKILL is treated as infrastructure;
		// otherwise a non-zero exit code is classified from the log content.
		failureClass := queue.FailureUnknown
		if signalName == "SIGKILL" || signalName == "9" {
			failureClass = queue.FailureInfrastructure
		} else if exitCode != 0 {
			failureClass = queue.ClassifyFailure(exitCode, nil, logTail)
		}
		response["failure_class"] = string(failureClass)
		response["exit_code"] = exitCode
		response["signal"] = signalName
		response["log_tail"] = logTail
		// User-facing suggestion derived from the classification.
		response["suggestion"] = queue.GetFailureSuggestion(failureClass, logTail)
		// Retry information with class-specific policy.
		response["retry_count"] = existingTask.RetryCount
		response["retry_cap"] = 3
		response["auto_retryable"] = queue.ShouldAutoRetry(failureClass, existingTask.RetryCount)
		if len(existingTask.Attempts) > 0 {
			response["attempts"] = existingTask.Attempts
		}
	}
	responseData, err := json.Marshal(response)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize duplicate response", err.Error())
	}
	packet := NewDataPacket("duplicate", responseData)
	return h.sendResponsePacket(conn, packet)
}
// enqueueTaskAndRespond enqueues a task and sends a success response.
// It is a convenience wrapper around enqueueTaskAndRespondWithArgs with
// empty args and the force flag disabled (duplicate check always runs).
func (h *WSHandler) enqueueTaskAndRespond(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", false, tracking, resources)
}
// enqueueTaskAndRespondWithArgsAndNote builds a task from the given job
// parameters (args plus an optional note stored in metadata), enqueues it,
// and writes a success packet to the connection. Unless force is set, an
// existing task with the same composite key (commit_id + dataset_id +
// params_hash) short-circuits into a duplicate response.
// NOTE(review): this is a near-duplicate of enqueueTaskAndRespondWithArgs
// that additionally records the note — candidates for consolidation.
func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	note string,
	force bool,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	// Success packet is prepared up front; it is only sent if enqueueing succeeds.
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	commitIDStr := fmt.Sprintf("%x", commitID)
	// Compute dataset_id and params_hash from existing data
	paramsHash := helpers.ComputeParamsHash(args)
	// Note: dataset_id will be empty here since we don't have DatasetSpecs yet
	// It will be populated when the task is actually created with datasets
	datasetID := ""
	// Check for duplicate tasks before proceeding (skip if force=true)
	if !force {
		if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil {
			h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status)
			return h.sendDuplicateResponse(conn, existingTask)
		}
	} else {
		h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr)
	}
	// Provenance is mandatory: refuse to enqueue when it cannot be computed.
	prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr)
	if provErr != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitIDStr,
			"error", provErr)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			provErr.Error(),
		)
	}
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      strings.TrimSpace(args),
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id":   commitIDStr,
				"dataset_id":  datasetID,
				"params_hash": paramsHash,
			},
			Tracking: tracking,
		}
		// Only record a note when the client actually sent one.
		if strings.TrimSpace(note) != "" {
			task.Metadata["note"] = strings.TrimSpace(note)
		}
		// Merge non-empty provenance fields into task metadata.
		for k, v := range prov {
			if v != "" {
				task.Metadata[k] = v
			}
		}
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}
		// AddTask is timed; latencies over 20ms are surfaced via telemetry.
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) {
				return "", h.queue.AddTask(task)
			},
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeDatabaseError,
				"Failed to enqueue task",
				err.Error(),
			)
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash)
	} else {
		// Best-effort: without a queue the job is dropped but the client still
		// receives the success packet below.
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// enqueueTaskAndRespondWithArgs builds a task from the given job parameters,
// enqueues it, and writes a success packet to the connection. Unless force
// is set, an existing task with the same composite key (commit_id +
// dataset_id + params_hash) short-circuits into a duplicate response.
func (h *WSHandler) enqueueTaskAndRespondWithArgs(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	force bool,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	// Success packet is prepared up front; it is only sent if enqueueing succeeds.
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	commitIDStr := fmt.Sprintf("%x", commitID)
	// Compute dataset_id and params_hash from existing data
	paramsHash := helpers.ComputeParamsHash(args)
	// Note: dataset_id will be empty here since we don't have DatasetSpecs yet
	// It will be populated when the task is actually created with datasets
	datasetID := ""
	// Check for duplicate tasks before proceeding (skip if force=true)
	if !force {
		if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil {
			h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status)
			return h.sendDuplicateResponse(conn, existingTask)
		}
	} else {
		h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr)
	}
	// Provenance is mandatory: refuse to enqueue when it cannot be computed.
	prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr)
	if provErr != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitIDStr,
			"dataset_id", datasetID,
			"params_hash", paramsHash,
			"error", provErr)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			provErr.Error(),
		)
	}
	// Enqueue task if queue is available
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      strings.TrimSpace(args),
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id":   commitIDStr,
				"dataset_id":  datasetID,
				"params_hash": paramsHash,
			},
			Tracking: tracking,
		}
		// Merge non-empty provenance fields into task metadata.
		for k, v := range prov {
			if v != "" {
				task.Metadata[k] = v
			}
		}
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}
		// AddTask is timed; latencies over 20ms are surfaced via telemetry.
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) {
				return "", h.queue.AddTask(task)
			},
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeDatabaseError,
				"Failed to enqueue task",
				err.Error(),
			)
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash)
	} else {
		// Best-effort: without a queue the job is dropped but the client still
		// receives the success packet below.
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// processAndEnqueueJob handles common experiment setup and task enqueueing:
// it runs the experiment setup for the commit, asynchronously upserts the
// experiment record in the DB, then enqueues the task and responds.
func (h *WSHandler) processAndEnqueueJob(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
	}
	// DB upsert is fire-and-forget; enqueueing does not wait on it.
	helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
	return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, tracking, resources)
}
// processAndEnqueueJobWithSnapshot handles experiment setup and task
// enqueueing for snapshot jobs: experiment setup, async DB upsert, then
// snapshot-aware enqueueing.
func (h *WSHandler) processAndEnqueueJobWithSnapshot(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
	snapshotID string,
	snapshotSHA string,
) error {
	commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
	}
	// DB upsert is fire-and-forget; enqueueing does not wait on it.
	helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
	return h.enqueueTaskAndRespondWithSnapshot(conn, user, jobName, priority, commitID, tracking, resources, snapshotID, snapshotSHA)
}
// processAndEnqueueJobWithArgs handles experiment setup and task enqueueing
// for jobs with args. Unlike processAndEnqueueJob it uses the
// without-manifest setup variant.
func (h *WSHandler) processAndEnqueueJobWithArgs(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	force bool,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
	}
	// DB upsert is fire-and-forget; enqueueing does not wait on it.
	helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
	return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, args, force, tracking, resources)
}
// processAndEnqueueJobWithArgsAndNote handles experiment setup for jobs with
// args and a note. Uses the without-manifest setup variant, like
// processAndEnqueueJobWithArgs.
func (h *WSHandler) processAndEnqueueJobWithArgsAndNote(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	note string,
	force bool,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
	}
	// DB upsert is fire-and-forget; enqueueing does not wait on it.
	helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
	return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, jobName, priority, commitID, args, note, force, tracking, resources)
}
// enqueueTaskAndRespondWithSnapshot builds a snapshot-backed task, enqueues
// it, and writes a success packet to the connection. The snapshot SHA serves
// as the dataset identifier for duplicate detection; snapshots carry no args,
// so params_hash is empty.
func (h *WSHandler) enqueueTaskAndRespondWithSnapshot(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
	snapshotID string,
	snapshotSHA string,
) error {
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	commitIDStr := fmt.Sprintf("%x", commitID)
	// Compute dataset_id from snapshot SHA (snapshot acts as dataset).
	// Guard the prefix slice: the wire protocol only guarantees the SHA is at
	// least 1 byte, so taking [:16] of a shorter value would panic. Slice the
	// trimmed value so stray whitespace never ends up in the identifier.
	datasetID := ""
	if sha := strings.TrimSpace(snapshotSHA); sha != "" {
		if len(sha) > 16 {
			datasetID = sha[:16]
		} else {
			datasetID = sha
		}
	}
	// Snapshots don't have args, so params_hash is empty
	paramsHash := ""
	// Check for duplicate tasks before proceeding (no force flag on this path).
	if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil {
		h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status)
		return h.sendDuplicateResponse(conn, existingTask)
	}
	// Provenance is mandatory: refuse to enqueue when it cannot be computed.
	prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr)
	if provErr != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitIDStr,
			"error", provErr)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			provErr.Error(),
		)
	}
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:         taskID,
			JobName:    jobName,
			Args:       "",
			Status:     "queued",
			Priority:   priority,
			CreatedAt:  time.Now(),
			UserID:     user.Name,
			Username:   user.Name,
			CreatedBy:  user.Name,
			SnapshotID: strings.TrimSpace(snapshotID),
			Metadata: map[string]string{
				"commit_id":       commitIDStr,
				"snapshot_sha256": strings.TrimSpace(snapshotSHA),
			},
			Tracking: tracking,
		}
		// Merge non-empty provenance fields into task metadata.
		for k, v := range prov {
			if v != "" {
				task.Metadata[k] = v
			}
		}
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}
		// AddTask is timed; latencies over 20ms are surfaced via telemetry.
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) {
				return "", h.queue.AddTask(task)
			},
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
	} else {
		// Best-effort: without a queue the job is dropped but the client still
		// receives the success packet below.
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// resourceRequest is a type alias (not a distinct named type) of
// helpers.ResourceRequest, kept for backward compatibility so existing code
// in this package can keep using the unexported name.
type resourceRequest = helpers.ResourceRequest
// parseOptionalResourceRequest delegates to helpers.ParseResourceRequest,
// kept for backward compatibility. Because resourceRequest is declared as a
// type alias of helpers.ResourceRequest (not a distinct named type), the
// helper's result is already of type *resourceRequest and can be returned
// directly — no conversion or nil special-casing is needed.
func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) {
	return helpers.ParseResourceRequest(payload)
}
// handleStatusRequest returns a JSON status document: the caller's identity,
// task counts by status, the (visibility-filtered) task list, and — when
// available — per-worker prewarm states.
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinStatusRequest)
	if err != nil {
		return err
	}
	h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", payload[:16]))
	if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
		return err
	}
	var tasks []*queue.Task
	if h.queue != nil {
		allTasks, err := h.queue.GetAllTasks()
		if err != nil {
			h.logger.Error("failed to get tasks", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
		}
		// Visibility: admins (or auth-disabled deployments) see everything;
		// other users only see tasks they own or created.
		for _, task := range allTasks {
			if h.authConfig == nil || !h.authConfig.Enabled || user.Admin {
				tasks = append(tasks, task)
				continue
			}
			if task.UserID == user.Name || task.CreatedBy == user.Name {
				tasks = append(tasks, task)
			}
		}
	}
	h.logger.Info("building status response")
	status := map[string]any{
		"user": map[string]any{
			"name":  user.Name,
			"admin": user.Admin,
			"roles": user.Roles,
		},
		"tasks": map[string]any{
			"total":     len(tasks),
			"queued":    countTasksByStatus(tasks, "queued"),
			"running":   countTasksByStatus(tasks, "running"),
			"failed":    countTasksByStatus(tasks, "failed"),
			"completed": countTasksByStatus(tasks, "completed"),
		},
		"queue": tasks,
	}
	// Prewarm states are best-effort: errors are silently skipped. Sort by
	// (worker, task) for deterministic output.
	if h.queue != nil {
		if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil {
			sort.Slice(states, func(i, j int) bool {
				if states[i].WorkerID != states[j].WorkerID {
					return states[i].WorkerID < states[j].WorkerID
				}
				return states[i].TaskID < states[j].TaskID
			})
			status["prewarm"] = states
		}
	}
	h.logger.Info("serializing JSON response")
	jsonData, err := json.Marshal(status)
	if err != nil {
		h.logger.Error("failed to marshal JSON", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	h.logger.Info("sending websocket JSON response", "len", len(jsonData))
	return h.sendResponsePacket(conn, NewDataPacket("status", jsonData))
}
// countTasksByStatus returns how many of the given tasks currently carry the
// supplied status string.
func countTasksByStatus(tasks []*queue.Task, status string) int {
	n := 0
	for _, t := range tasks {
		if t.Status == status {
			n++
		}
	}
	return n
}
// handleCancelJob cancels a queued or running job by name. When auth is
// enabled, non-admin users may only cancel jobs they own or created.
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinCancelJob)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}

	// Payload layout after the API-key hash: [name_len:1][name:var].
	nameLen := int(payload[ProtocolAPIKeyHashLen])
	nameStart := ProtocolAPIKeyHashLen + 1
	if len(payload) < nameStart+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[nameStart : nameStart+nameLen])
	h.logger.Info("cancel job request", "job", jobName)

	if h.queue == nil {
		h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
		return nil
	}

	task, err := h.queue.GetTaskByName(jobName)
	if err != nil {
		h.logger.Error("task not found", "job", jobName, "error", err)
		return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
	}

	// Ownership check: enforced only when auth is enabled and the caller is
	// not an admin.
	authEnforced := h.authConfig != nil && h.authConfig.Enabled && !user.Admin
	if authEnforced && task.UserID != user.Name && task.CreatedBy != user.Name {
		h.logger.Error(
			"unauthorized job cancellation attempt",
			"user", user.Name,
			"job", jobName,
			"task_owner", task.UserID,
		)
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"You can only cancel your own jobs",
			"",
		)
	}

	if err := h.queue.CancelTask(task.ID); err != nil {
		h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
		return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
	}
	h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)))
}
// handlePrune removes old experiments: prune type 0 keeps the N most recent,
// prune type 1 deletes experiments older than N days.
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinPrune)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}

	// Payload layout after the API-key hash: [prune_type:1][value:4 BE].
	pruneType := payload[ProtocolAPIKeyHashLen]
	value := binary.BigEndian.Uint32(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+5])
	h.logger.Info("prune request", "type", pruneType, "value", value)

	var keepCount, olderThanDays int
	switch pruneType {
	case 0:
		keepCount = int(value)
	case 1:
		olderThanDays = int(value)
	default:
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			fmt.Sprintf("invalid prune type: %d", pruneType),
			"",
		)
	}

	pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}

	// Best-effort nudge for workers to GC prewarmed state; a failure here
	// must not affect the prune result, so the error is deliberately ignored.
	if h.queue != nil {
		_ = h.queue.SignalPrewarmGC()
	}
	h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))))
}
// handleLogMetric records a single metric observation for an experiment.
//
// Payload layout after the API-key hash:
//
//	[commit_id:ProtocolCommitIDLen][step:4 BE][value:8 BE IEEE-754 bits]
//	[name_len:1][name:var]
//
// Offsets are derived from the protocol constants (the previous version
// hardcoded 36/40/48/49, which could silently drift from the constants used
// elsewhere in this file).
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinLogMetric)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}

	base := ProtocolAPIKeyHashLen + ProtocolCommitIDLen
	commitID := payload[ProtocolAPIKeyHashLen:base]
	step := int(binary.BigEndian.Uint32(payload[base : base+4]))
	value := math.Float64frombits(binary.BigEndian.Uint64(payload[base+4 : base+12]))
	nameLen := int(payload[base+12])
	nameStart := base + 13
	if len(payload) < nameStart+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
	}
	name := string(payload[nameStart : nameStart+nameLen])

	if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil {
		h.logger.Error("failed to log metric", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}
// handleGetExperiment returns filesystem metadata and metrics for an
// experiment identified by its commit ID, plus reproducibility metadata from
// the database when available.
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinGetExperiment)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
		return err
	}

	// Hex-encode the commit ID once; it is used for all three lookups below
	// (the previous version re-formatted the same bytes three times).
	commitHex := fmt.Sprintf("%x", payload[ProtocolAPIKeyHashLen:ProtocolAPIKeyHashLen+ProtocolCommitIDLen])

	meta, err := h.expManager.ReadMetadata(commitHex)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
	}
	metrics, err := h.expManager.GetMetrics(commitHex)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
	}

	// DB reproducibility metadata is optional: a lookup failure is treated
	// as "not available" rather than an error.
	var dbMeta *storage.ExperimentWithMetadata
	if h.db != nil {
		ctx, cancel := helpers.DBContextShort()
		defer cancel()
		if m, err := h.db.GetExperimentWithMetadata(ctx, commitHex); err == nil {
			dbMeta = m
		}
	}

	response := map[string]any{
		"metadata": meta,
		"metrics":  metrics,
	}
	if dbMeta != nil {
		response["reproducibility"] = dbMeta
	}
	responseData, err := json.Marshal(response)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData))
}
// handleGetLogs handles requests to fetch logs for a task/run.
func (h *WSHandler) handleGetLogs(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinGetLogs)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
		return err
	}

	// Payload layout after the API-key hash: [target_len:1][target_id:var].
	idLen := int(payload[ProtocolAPIKeyHashLen])
	if len(payload) < ProtocolMinGetLogs+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", fmt.Sprintf("got %d, need %d", len(payload), ProtocolMinGetLogs+idLen))
	}
	idStart := ProtocolAPIKeyHashLen + 1
	targetID := string(payload[idStart : idStart+idLen])
	h.logger.Info("get logs request", "target_id", targetID, "user", user.Name)

	// TODO: Implement actual log fetching from storage
	// For now, return a stub response
	response := map[string]any{
		"target_id":   targetID,
		"logs":        "[Stub] Log content would appear here\nLine 1: Log output\nLine 2: More output\n",
		"truncated":   false,
		"total_lines": 3,
	}
	responseData, err := json.Marshal(response)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error())
	}
	return h.sendResponsePacket(conn, NewDataPacket("logs", responseData))
}
// handleStreamLogs handles requests to stream logs in real-time.
func (h *WSHandler) handleStreamLogs(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinStreamLogs)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
		return err
	}

	// Payload layout after the API-key hash: [target_len:1][target_id:var].
	idLen := int(payload[ProtocolAPIKeyHashLen])
	if len(payload) < ProtocolMinStreamLogs+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", "")
	}
	idStart := ProtocolAPIKeyHashLen + 1
	targetID := string(payload[idStart : idStart+idLen])
	h.logger.Info("stream logs request", "target_id", targetID, "user", user.Name)

	// TODO: Implement actual log streaming
	// For now, return a stub response indicating streaming started
	response := map[string]any{
		"target_id": targetID,
		"streaming": true,
		"message":   "[Stub] Log streaming would start here. This feature is not yet fully implemented.",
	}
	responseData, err := json.Marshal(response)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error())
	}
	return h.sendResponsePacket(conn, NewDataPacket("logs_stream", responseData))
}