Move schema ownership to infrastructure layer: - Redis keys: config/constants.go -> queue/keys.go (TaskQueueKey, TaskPrefix, etc.) - Filesystem paths: config/paths.go -> storage/paths.go (JobPaths) - Create config/shared.go with RedisConfig, SSHConfig - Update all imports: worker/, api/helpers, api/ws_jobs, api/ws_validate - Clean up: remove duplicates from queue/task.go, queue/queue.go, config/paths.go Build status: Compiles successfully
1365 lines
42 KiB
Go
1365 lines
42 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
|
)
|
|
|
|
func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
|
|
if len(payload) < 16+1+1+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
|
|
}
|
|
|
|
offset := 16
|
|
|
|
jobNameLen := int(payload[offset])
|
|
offset += 1
|
|
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
jobName := string(payload[offset : offset+jobNameLen])
|
|
offset += jobNameLen
|
|
|
|
authorLen := int(payload[offset])
|
|
offset += 1
|
|
if authorLen < 0 || len(payload) < offset+authorLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
|
|
}
|
|
author := string(payload[offset : offset+authorLen])
|
|
offset += authorLen
|
|
|
|
noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
|
offset += 2
|
|
if noteLen <= 0 || len(payload) < offset+noteLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
|
|
}
|
|
note := string(payload[offset : offset+noteLen])
|
|
|
|
user, err := h.authenticate(conn, payload, 16)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := container.ValidateJobName(jobName); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
|
|
}
|
|
|
|
base := strings.TrimSpace(h.expManager.BasePath())
|
|
if base == "" {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
|
|
}
|
|
|
|
jobPaths := storage.NewJobPaths(base)
|
|
typedRoots := []struct{ root string }{
|
|
{root: jobPaths.RunningPath()},
|
|
{root: jobPaths.PendingPath()},
|
|
{root: jobPaths.FinishedPath()},
|
|
{root: jobPaths.FailedPath()},
|
|
}
|
|
|
|
var manifestDir string
|
|
for _, item := range typedRoots {
|
|
dir := filepath.Join(item.root, jobName)
|
|
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
|
if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
|
|
manifestDir = dir
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if manifestDir == "" {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
|
|
}
|
|
|
|
rm, err := manifest.LoadFromDir(manifestDir)
|
|
if err != nil || rm == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
|
|
}
|
|
|
|
if strings.TrimSpace(author) == "" {
|
|
author = user.Name
|
|
}
|
|
rm.AddAnnotation(time.Now().UTC(), author, note)
|
|
if err := rm.WriteToDir(manifestDir); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added"))
|
|
}
|
|
|
|
func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var]
|
|
if len(payload) < 16+1+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
|
|
}
|
|
|
|
offset := 16
|
|
|
|
jobNameLen := int(payload[offset])
|
|
offset += 1
|
|
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
jobName := string(payload[offset : offset+jobNameLen])
|
|
offset += jobNameLen
|
|
|
|
patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
|
offset += 2
|
|
if patchLen <= 0 || len(payload) < offset+patchLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "")
|
|
}
|
|
patchJSON := payload[offset : offset+patchLen]
|
|
|
|
user, err := h.authenticate(conn, payload, 16)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
|
|
return err
|
|
}
|
|
if err := container.ValidateJobName(jobName); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
|
|
}
|
|
|
|
base := strings.TrimSpace(h.expManager.BasePath())
|
|
if base == "" {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
|
|
}
|
|
|
|
jobPaths := storage.NewJobPaths(base)
|
|
typedRoots := []struct{ root string }{
|
|
{root: jobPaths.RunningPath()},
|
|
{root: jobPaths.PendingPath()},
|
|
{root: jobPaths.FinishedPath()},
|
|
{root: jobPaths.FailedPath()},
|
|
}
|
|
|
|
var manifestDir string
|
|
for _, item := range typedRoots {
|
|
dir := filepath.Join(item.root, jobName)
|
|
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
|
if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
|
|
manifestDir = dir
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if manifestDir == "" {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
|
|
}
|
|
|
|
rm, err := manifest.LoadFromDir(manifestDir)
|
|
if err != nil || rm == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
|
|
}
|
|
|
|
var patch manifest.NarrativePatch
|
|
if err := json.Unmarshal(patchJSON, &patch); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error())
|
|
}
|
|
rm.ApplyNarrativePatch(patch)
|
|
if err := rm.WriteToDir(manifestDir); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated"))
|
|
}
|
|
|
|
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
|
// Parse payload first
|
|
if len(payload) < ProtocolMinQueueJob {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
|
|
}
|
|
|
|
commitID := payload[16:36]
|
|
priority := int64(payload[36])
|
|
jobNameLen := int(payload[37])
|
|
|
|
if len(payload) < 38+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
|
|
jobName := string(payload[38 : 38+jobNameLen])
|
|
resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:])
|
|
if resErr != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error())
|
|
}
|
|
|
|
h.logger.Info("queue job request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID))
|
|
|
|
// Authenticate and authorize
|
|
user, err := h.authenticate(conn, payload, ProtocolMinQueueJob)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, nil, resources)
|
|
}
|
|
|
|
// handleQueueJobWithSnapshot queues a job associated with a snapshot,
// identified by an ID and a SHA digest.
//
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
//           [snap_id_len:1][snap_id:var][snap_sha_len:1][snap_sha:var][resources?:var]
func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
	if len(payload) < ProtocolMinQueueJobWithSnapshot {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with snapshot payload too short", "")
	}

	// Fixed-offset header fields.
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])

	// +2 reserves room for the two 1-byte length prefixes that follow the name.
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen

	// Snapshot ID: 1-byte length then the ID bytes; must be non-empty.
	snapIDLen := int(payload[offset])
	offset++
	if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "")
	}
	snapshotID := string(payload[offset : offset+snapIDLen])
	offset += snapIDLen

	// Snapshot SHA: 1-byte length then the digest; must be non-empty.
	snapSHALen := int(payload[offset])
	offset++
	if snapSHALen < 1 || len(payload) < offset+snapSHALen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot sha length", "")
	}
	snapshotSHA := string(payload[offset : offset+snapSHALen])
	offset += snapSHALen

	// Anything left over is an optional resource-request section.
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error())
	}

	h.logger.Info("queue job with snapshot request", "job", jobName, "priority", priority,
		"commit_id", fmt.Sprintf("%x", commitID), "snapshot_id", snapshotID)

	// Authentication happens after parsing so malformed payloads are
	// rejected without touching the auth backend.
	user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithSnapshot)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
		return err
	}

	return h.processAndEnqueueJobWithSnapshot(conn, user, jobName, priority, commitID, nil, resources, snapshotID, snapshotSHA)
}
|
|
|
|
func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error {
|
|
if len(payload) < ProtocolMinQueueJobWithTracking {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with tracking payload too short", "")
|
|
}
|
|
|
|
commitID := payload[16:36]
|
|
priority := int64(payload[36])
|
|
jobNameLen := int(payload[37])
|
|
|
|
if len(payload) < 38+jobNameLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
|
|
jobName := string(payload[38 : 38+jobNameLen])
|
|
offset := 38 + jobNameLen
|
|
trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
|
offset += 2
|
|
|
|
if trackingLen < 0 || len(payload) < offset+trackingLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json length", "")
|
|
}
|
|
|
|
var trackingCfg *queue.TrackingConfig
|
|
if trackingLen > 0 {
|
|
var cfg queue.TrackingConfig
|
|
if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error())
|
|
}
|
|
trackingCfg = &cfg
|
|
}
|
|
|
|
offset += trackingLen
|
|
resources, resErr := parseOptionalResourceRequest(payload[offset:])
|
|
if resErr != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error())
|
|
}
|
|
|
|
h.logger.Info("queue job with tracking request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID))
|
|
|
|
user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithTracking)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, trackingCfg, resources)
|
|
}
|
|
|
|
// queueJobWithArgsPayload is the decoded form of a "queue job with args"
// websocket request. Wire layout:
// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][force:1][resources?:var]
type queueJobWithArgsPayload struct {
	// apiKeyHash is the 16-byte client key hash used for authentication.
	apiKeyHash []byte
	// commitID is the 20-byte commit identifier (rendered as hex for logs/metadata).
	commitID []byte
	// priority is the queue priority, decoded from a single byte.
	priority int64
	// jobName is the requested job's name.
	jobName string
	// args is the raw CLI-arguments string for the job.
	args string
	// force, when true, skips the duplicate-task check on enqueue.
	force bool
	// resources holds optional resource requirements; nil when absent.
	resources *resourceRequest
}
|
|
|
|
// queueJobWithNotePayload is the decoded form of a "queue job with note"
// websocket request. Wire layout:
// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
// [args_len:2][args:var][note_len:2][note:var][force:1][resources?:var]
type queueJobWithNotePayload struct {
	// apiKeyHash is the 16-byte client key hash used for authentication.
	apiKeyHash []byte
	// commitID is the 20-byte commit identifier (rendered as hex for logs/metadata).
	commitID []byte
	// priority is the queue priority, decoded from a single byte.
	priority int64
	// jobName is the requested job's name.
	jobName string
	// args is the raw CLI-arguments string for the job.
	args string
	// note is an optional free-form note attached to the task's metadata.
	note string
	// force, when true, skips the duplicate-task check on enqueue.
	force bool
	// resources holds optional resource requirements; nil when absent.
	resources *resourceRequest
}
|
|
|
|
func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) {
|
|
// Protocol:
|
|
// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
|
// [args_len:2][args:var][note_len:2][note:var][force:1][resources?:var]
|
|
if len(payload) < 43 {
|
|
return nil, fmt.Errorf("queue job with note payload too short")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
commitID := payload[16:36]
|
|
priority := int64(payload[36])
|
|
|
|
p := helpers.NewPayloadParser(payload, 37)
|
|
|
|
jobName, err := p.ParseLengthPrefixedString()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid job name: %w", err)
|
|
}
|
|
|
|
args, err := p.ParseUint16PrefixedString()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid args: %w", err)
|
|
}
|
|
|
|
note, err := p.ParseUint16PrefixedString()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid note: %w", err)
|
|
}
|
|
|
|
force, err := p.ParseBool()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("missing force flag: %w", err)
|
|
}
|
|
|
|
resources, resErr := helpers.ParseResourceRequest(p.Remaining())
|
|
if resErr != nil {
|
|
return nil, resErr
|
|
}
|
|
|
|
return &queueJobWithNotePayload{
|
|
apiKeyHash: apiKeyHash,
|
|
commitID: commitID,
|
|
priority: priority,
|
|
jobName: jobName,
|
|
args: args,
|
|
note: note,
|
|
force: force,
|
|
resources: resources,
|
|
}, nil
|
|
}
|
|
|
|
func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) {
|
|
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][force:1][resources?:var]
|
|
if len(payload) < 41 {
|
|
return nil, fmt.Errorf("queue job with args payload too short")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
commitID := payload[16:36]
|
|
priority := int64(payload[36])
|
|
|
|
p := helpers.NewPayloadParser(payload, 37)
|
|
|
|
jobName, err := p.ParseLengthPrefixedString()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid job name: %w", err)
|
|
}
|
|
|
|
args, err := p.ParseUint16PrefixedString()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid args: %w", err)
|
|
}
|
|
|
|
force, err := p.ParseBool()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("missing force flag: %w", err)
|
|
}
|
|
|
|
resources, resErr := helpers.ParseResourceRequest(p.Remaining())
|
|
if resErr != nil {
|
|
return nil, resErr
|
|
}
|
|
|
|
return &queueJobWithArgsPayload{
|
|
apiKeyHash: apiKeyHash,
|
|
commitID: commitID,
|
|
priority: priority,
|
|
jobName: jobName,
|
|
args: args,
|
|
force: force,
|
|
resources: resources,
|
|
}, nil
|
|
}
|
|
|
|
func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error {
|
|
p, err := parseQueueJobWithArgsPayload(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error())
|
|
}
|
|
|
|
h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID))
|
|
|
|
user, err := h.authenticateWithHash(conn, p.apiKeyHash)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
return h.processAndEnqueueJobWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, p.force, nil, p.resources)
|
|
}
|
|
|
|
func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error {
|
|
p, err := parseQueueJobWithNotePayload(payload)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error())
|
|
}
|
|
|
|
h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID))
|
|
|
|
user, err := h.authenticateWithHash(conn, p.apiKeyHash)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsCreate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
return h.processAndEnqueueJobWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, p.force, nil, p.resources)
|
|
}
|
|
|
|
// findDuplicateTask searches for an existing task with the same composite key
|
|
// (commit_id + dataset_id + params_hash) to detect truly identical experiments
|
|
func (h *WSHandler) findDuplicateTask(commitIDStr, datasetID, paramsHash string) *queue.Task {
|
|
if h.queue == nil {
|
|
return nil
|
|
}
|
|
|
|
tasks, err := h.queue.GetAllTasks()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
for _, task := range tasks {
|
|
if task.Metadata == nil {
|
|
continue
|
|
}
|
|
// Check all three components of the composite key
|
|
if task.Metadata["commit_id"] == commitIDStr &&
|
|
task.Metadata["dataset_id"] == datasetID &&
|
|
task.Metadata["params_hash"] == paramsHash {
|
|
return task
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// sendDuplicateResponse sends a data packet response for duplicate jobs
|
|
func (h *WSHandler) sendDuplicateResponse(conn *websocket.Conn, existingTask *queue.Task) error {
|
|
response := map[string]interface{}{
|
|
"duplicate": true,
|
|
"existing_id": existingTask.ID,
|
|
"status": existingTask.Status,
|
|
"queued_by": existingTask.CreatedBy,
|
|
"queued_at": existingTask.CreatedAt.Unix(),
|
|
}
|
|
|
|
// Add duration for completed tasks
|
|
if existingTask.Status == "completed" && existingTask.EndedAt != nil {
|
|
duration := existingTask.EndedAt.Sub(existingTask.CreatedAt).Seconds()
|
|
response["duration_seconds"] = int64(duration)
|
|
|
|
// Try to get metrics for completed tasks
|
|
if h.expManager != nil {
|
|
commitID := existingTask.Metadata["commit_id"]
|
|
if metrics, err := h.expManager.GetMetrics(commitID); err == nil && len(metrics) > 0 {
|
|
metricsMap := make(map[string]interface{})
|
|
for _, m := range metrics {
|
|
metricsMap[m.Name] = m.Value
|
|
}
|
|
response["metrics"] = metricsMap
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add error reason for failed tasks with full failure classification
|
|
if existingTask.Status == "failed" && existingTask.Error != "" {
|
|
response["error_reason"] = existingTask.Error
|
|
|
|
// Classify failure using exit codes, signals, and error context
|
|
failureClass := queue.FailureUnknown
|
|
exitCode := 0
|
|
signalName := ""
|
|
|
|
// Extract exit code from error or metadata
|
|
if code, ok := existingTask.Metadata["exit_code"]; ok {
|
|
fmt.Sscanf(code, "%d", &exitCode)
|
|
}
|
|
if sig, ok := existingTask.Metadata["signal"]; ok {
|
|
signalName = sig
|
|
}
|
|
|
|
// Get log tail for classification if available
|
|
logTail := existingTask.Error
|
|
if existingTask.LastError != "" {
|
|
logTail = existingTask.LastError
|
|
}
|
|
|
|
// Classify failure directly using signals, exit codes, and log content
|
|
// Note: failureClass declared above at line 536, just reassign here
|
|
|
|
// Override with signal-based classification if available
|
|
if signalName == "SIGKILL" || signalName == "9" {
|
|
failureClass = queue.FailureInfrastructure
|
|
} else if exitCode != 0 {
|
|
// Use the new ClassifyFailure with error log content
|
|
logContent := existingTask.Error
|
|
if existingTask.LastError != "" {
|
|
logContent = existingTask.LastError
|
|
}
|
|
failureClass = queue.ClassifyFailure(exitCode, nil, logContent)
|
|
}
|
|
|
|
response["failure_class"] = string(failureClass)
|
|
response["exit_code"] = exitCode
|
|
response["signal"] = signalName
|
|
response["log_tail"] = logTail
|
|
|
|
// Add user-facing suggestion
|
|
response["suggestion"] = queue.GetFailureSuggestion(failureClass, logTail)
|
|
|
|
// Add retry information with class-specific policy
|
|
response["retry_count"] = existingTask.RetryCount
|
|
response["retry_cap"] = 3
|
|
response["auto_retryable"] = queue.ShouldAutoRetry(failureClass, existingTask.RetryCount)
|
|
|
|
// Add attempts history if available
|
|
if len(existingTask.Attempts) > 0 {
|
|
response["attempts"] = existingTask.Attempts
|
|
}
|
|
}
|
|
|
|
responseData, err := json.Marshal(response)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize duplicate response", err.Error())
|
|
}
|
|
|
|
packet := NewDataPacket("duplicate", responseData)
|
|
return h.sendResponsePacket(conn, packet)
|
|
}
|
|
|
|
// enqueueTaskAndRespond enqueues a task and sends a success response.
|
|
func (h *WSHandler) enqueueTaskAndRespond(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
) error {
|
|
return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", false, tracking, resources)
|
|
}
|
|
|
|
// enqueueTaskAndRespondWithArgsAndNote performs duplicate detection,
// provenance computation, and task creation for a queue-job request that
// carries CLI args and an optional note, then writes the outcome back to the
// client over the websocket connection.
//
// Unless force is set, a task with the same (commit_id, dataset_id,
// params_hash) composite key short-circuits into a duplicate response.
// A non-blank note is stored under the task's "note" metadata key.
func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	note string,
	force bool,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	// Success packet is built up front; it is only sent after every step
	// below has succeeded.
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))

	commitIDStr := fmt.Sprintf("%x", commitID)

	// Compute dataset_id and params_hash from existing data
	paramsHash := helpers.ComputeParamsHash(args)
	// Note: dataset_id will be empty here since we don't have DatasetSpecs yet
	// It will be populated when the task is actually created with datasets
	datasetID := ""

	// Check for duplicate tasks before proceeding (skip if force=true)
	if !force {
		if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil {
			h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status)
			return h.sendDuplicateResponse(conn, existingTask)
		}
	} else {
		h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr)
	}

	// Provenance failure is fatal: a task without expected provenance is
	// refused rather than enqueued with incomplete metadata.
	prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr)
	if provErr != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitIDStr,
			"error", provErr)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			provErr.Error(),
		)
	}

	// Only enqueue when a queue backend is configured; otherwise fall
	// through and still report success (logged as a warning below).
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      strings.TrimSpace(args),
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id":   commitIDStr,
				"dataset_id":  datasetID,
				"params_hash": paramsHash,
			},
			Tracking: tracking,
		}
		// Blank notes are omitted entirely rather than stored as "".
		if strings.TrimSpace(note) != "" {
			task.Metadata["note"] = strings.TrimSpace(note)
		}
		// Merge non-empty provenance entries into task metadata.
		for k, v := range prov {
			if v != "" {
				task.Metadata[k] = v
			}
		}
		// Optional resource requirements.
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}

		// AddTask is wrapped in telemetry with a 20ms latency budget.
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) {
				return "", h.queue.AddTask(task)
			},
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeDatabaseError,
				"Failed to enqueue task",
				err.Error(),
			)
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash)
	} else {
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}

	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
|
|
|
|
func (h *WSHandler) enqueueTaskAndRespondWithArgs(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
args string,
|
|
force bool,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
) error {
|
|
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
|
|
|
|
commitIDStr := fmt.Sprintf("%x", commitID)
|
|
|
|
// Compute dataset_id and params_hash from existing data
|
|
paramsHash := helpers.ComputeParamsHash(args)
|
|
// Note: dataset_id will be empty here since we don't have DatasetSpecs yet
|
|
// It will be populated when the task is actually created with datasets
|
|
datasetID := ""
|
|
|
|
// Check for duplicate tasks before proceeding (skip if force=true)
|
|
if !force {
|
|
if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil {
|
|
h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status)
|
|
return h.sendDuplicateResponse(conn, existingTask)
|
|
}
|
|
} else {
|
|
h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr)
|
|
}
|
|
|
|
prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr)
|
|
if provErr != nil {
|
|
h.logger.Error("failed to compute expected provenance; refusing to enqueue",
|
|
"commit_id", commitIDStr,
|
|
"dataset_id", datasetID,
|
|
"params_hash", paramsHash,
|
|
"error", provErr)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeStorageError,
|
|
"Failed to compute expected provenance",
|
|
provErr.Error(),
|
|
)
|
|
}
|
|
|
|
// Enqueue task if queue is available
|
|
if h.queue != nil {
|
|
taskID := uuid.New().String()
|
|
task := &queue.Task{
|
|
ID: taskID,
|
|
JobName: jobName,
|
|
Args: strings.TrimSpace(args),
|
|
Status: "queued",
|
|
Priority: priority,
|
|
CreatedAt: time.Now(),
|
|
UserID: user.Name,
|
|
Username: user.Name,
|
|
CreatedBy: user.Name,
|
|
Metadata: map[string]string{
|
|
"commit_id": commitIDStr,
|
|
"dataset_id": datasetID,
|
|
"params_hash": paramsHash,
|
|
},
|
|
Tracking: tracking,
|
|
}
|
|
for k, v := range prov {
|
|
if v != "" {
|
|
task.Metadata[k] = v
|
|
}
|
|
}
|
|
if resources != nil {
|
|
task.CPU = resources.CPU
|
|
task.MemoryGB = resources.MemoryGB
|
|
task.GPU = resources.GPU
|
|
task.GPUMemory = resources.GPUMemory
|
|
}
|
|
|
|
if _, err := telemetry.ExecWithMetrics(
|
|
h.logger,
|
|
"queue.add_task",
|
|
20*time.Millisecond,
|
|
func() (string, error) {
|
|
return "", h.queue.AddTask(task)
|
|
},
|
|
); err != nil {
|
|
h.logger.Error("failed to enqueue task", "error", err)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeDatabaseError,
|
|
"Failed to enqueue task",
|
|
err.Error(),
|
|
)
|
|
}
|
|
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash)
|
|
} else {
|
|
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
|
|
}
|
|
|
|
packetData, err := packet.Serialize()
|
|
if err != nil {
|
|
h.logger.Error("failed to serialize packet", "error", err)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Internal error",
|
|
"Failed to serialize response",
|
|
)
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
|
}
|
|
|
|
// processAndEnqueueJob handles common experiment setup and task enqueueing
|
|
func (h *WSHandler) processAndEnqueueJob(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
) error {
|
|
commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
|
|
}
|
|
|
|
helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
|
|
return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, tracking, resources)
|
|
}
|
|
|
|
// processAndEnqueueJobWithSnapshot handles experiment setup and task enqueueing for snapshot jobs
|
|
func (h *WSHandler) processAndEnqueueJobWithSnapshot(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
snapshotID string,
|
|
snapshotSHA string,
|
|
) error {
|
|
commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
|
|
}
|
|
|
|
helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
|
|
return h.enqueueTaskAndRespondWithSnapshot(conn, user, jobName, priority, commitID, tracking, resources, snapshotID, snapshotSHA)
|
|
}
|
|
|
|
// processAndEnqueueJobWithArgs handles experiment setup and task enqueueing for jobs with args
|
|
func (h *WSHandler) processAndEnqueueJobWithArgs(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
args string,
|
|
force bool,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
) error {
|
|
commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
|
|
}
|
|
|
|
helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
|
|
return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, args, force, tracking, resources)
|
|
}
|
|
|
|
// processAndEnqueueJobWithArgsAndNote handles experiment setup for jobs with args and note
|
|
func (h *WSHandler) processAndEnqueueJobWithArgsAndNote(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
args string,
|
|
note string,
|
|
force bool,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
) error {
|
|
commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "")
|
|
}
|
|
|
|
helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name)
|
|
return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, jobName, priority, commitID, args, note, force, tracking, resources)
|
|
}
|
|
|
|
func (h *WSHandler) enqueueTaskAndRespondWithSnapshot(
|
|
conn *websocket.Conn,
|
|
user *auth.User,
|
|
jobName string,
|
|
priority int64,
|
|
commitID []byte,
|
|
tracking *queue.TrackingConfig,
|
|
resources *resourceRequest,
|
|
snapshotID string,
|
|
snapshotSHA string,
|
|
) error {
|
|
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
|
|
|
|
commitIDStr := fmt.Sprintf("%x", commitID)
|
|
|
|
// Compute dataset_id from snapshot SHA (snapshot acts as dataset)
|
|
datasetID := ""
|
|
if strings.TrimSpace(snapshotSHA) != "" {
|
|
datasetID = snapshotSHA[:16]
|
|
}
|
|
// Snapshots don't have args, so params_hash is empty
|
|
paramsHash := ""
|
|
|
|
// Check for duplicate tasks before proceeding
|
|
if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil {
|
|
h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status)
|
|
return h.sendDuplicateResponse(conn, existingTask)
|
|
}
|
|
|
|
prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr)
|
|
if provErr != nil {
|
|
h.logger.Error("failed to compute expected provenance; refusing to enqueue",
|
|
"commit_id", commitIDStr,
|
|
"error", provErr)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeStorageError,
|
|
"Failed to compute expected provenance",
|
|
provErr.Error(),
|
|
)
|
|
}
|
|
|
|
if h.queue != nil {
|
|
taskID := uuid.New().String()
|
|
task := &queue.Task{
|
|
ID: taskID,
|
|
JobName: jobName,
|
|
Args: "",
|
|
Status: "queued",
|
|
Priority: priority,
|
|
CreatedAt: time.Now(),
|
|
UserID: user.Name,
|
|
Username: user.Name,
|
|
CreatedBy: user.Name,
|
|
SnapshotID: strings.TrimSpace(snapshotID),
|
|
Metadata: map[string]string{
|
|
"commit_id": commitIDStr,
|
|
"snapshot_sha256": strings.TrimSpace(snapshotSHA),
|
|
},
|
|
Tracking: tracking,
|
|
}
|
|
for k, v := range prov {
|
|
if v != "" {
|
|
task.Metadata[k] = v
|
|
}
|
|
}
|
|
if resources != nil {
|
|
task.CPU = resources.CPU
|
|
task.MemoryGB = resources.MemoryGB
|
|
task.GPU = resources.GPU
|
|
task.GPUMemory = resources.GPUMemory
|
|
}
|
|
|
|
if _, err := telemetry.ExecWithMetrics(
|
|
h.logger,
|
|
"queue.add_task",
|
|
20*time.Millisecond,
|
|
func() (string, error) {
|
|
return "", h.queue.AddTask(task)
|
|
},
|
|
); err != nil {
|
|
h.logger.Error("failed to enqueue task", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
|
|
}
|
|
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
|
|
} else {
|
|
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
|
|
}
|
|
|
|
packetData, err := packet.Serialize()
|
|
if err != nil {
|
|
h.logger.Error("failed to serialize packet", "error", err)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Internal error",
|
|
"Failed to serialize response",
|
|
)
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
|
}
|
|
|
|
// resourceRequest is an alias to helpers.ResourceRequest for backward compatibility
// with pre-refactor code in this package that still uses the unexported name.
// Being a type alias (`=`), it is the identical type, not a distinct named type.
type resourceRequest = helpers.ResourceRequest
|
|
|
|
// parseOptionalResourceRequest is an alias to helpers.ParseResourceRequest for backward compatibility
|
|
func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) {
|
|
r, err := helpers.ParseResourceRequest(payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Type conversion is needed because Go doesn't automatically convert named types even with identical underlying structures
|
|
if r == nil {
|
|
return nil, nil
|
|
}
|
|
return (*resourceRequest)(r), nil
|
|
}
|
|
|
|
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinStatusRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", payload[:16]))
|
|
|
|
if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
var tasks []*queue.Task
|
|
if h.queue != nil {
|
|
allTasks, err := h.queue.GetAllTasks()
|
|
if err != nil {
|
|
h.logger.Error("failed to get tasks", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
|
|
}
|
|
|
|
for _, task := range allTasks {
|
|
if h.authConfig == nil || !h.authConfig.Enabled || user.Admin {
|
|
tasks = append(tasks, task)
|
|
continue
|
|
}
|
|
if task.UserID == user.Name || task.CreatedBy == user.Name {
|
|
tasks = append(tasks, task)
|
|
}
|
|
}
|
|
}
|
|
|
|
h.logger.Info("building status response")
|
|
status := map[string]any{
|
|
"user": map[string]any{
|
|
"name": user.Name,
|
|
"admin": user.Admin,
|
|
"roles": user.Roles,
|
|
},
|
|
"tasks": map[string]any{
|
|
"total": len(tasks),
|
|
"queued": countTasksByStatus(tasks, "queued"),
|
|
"running": countTasksByStatus(tasks, "running"),
|
|
"failed": countTasksByStatus(tasks, "failed"),
|
|
"completed": countTasksByStatus(tasks, "completed"),
|
|
},
|
|
"queue": tasks,
|
|
}
|
|
if h.queue != nil {
|
|
if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil {
|
|
sort.Slice(states, func(i, j int) bool {
|
|
if states[i].WorkerID != states[j].WorkerID {
|
|
return states[i].WorkerID < states[j].WorkerID
|
|
}
|
|
return states[i].TaskID < states[j].TaskID
|
|
})
|
|
status["prewarm"] = states
|
|
}
|
|
}
|
|
|
|
h.logger.Info("serializing JSON response")
|
|
jsonData, err := json.Marshal(status)
|
|
if err != nil {
|
|
h.logger.Error("failed to marshal JSON", "error", err)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Internal error",
|
|
"Failed to serialize response",
|
|
)
|
|
}
|
|
|
|
h.logger.Info("sending websocket JSON response", "len", len(jsonData))
|
|
return h.sendResponsePacket(conn, NewDataPacket("status", jsonData))
|
|
}
|
|
|
|
// countTasksByStatus counts tasks by their status
|
|
func countTasksByStatus(tasks []*queue.Task, status string) int {
|
|
count := 0
|
|
for _, task := range tasks {
|
|
if task.Status == status {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinCancelJob)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
jobNameLen := int(payload[ProtocolAPIKeyHashLen])
|
|
if len(payload) < ProtocolAPIKeyHashLen+1+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
|
|
jobName := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+jobNameLen])
|
|
h.logger.Info("cancel job request", "job", jobName)
|
|
|
|
if h.queue == nil {
|
|
h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
|
|
return nil
|
|
}
|
|
|
|
task, err := h.queue.GetTaskByName(jobName)
|
|
if err != nil {
|
|
h.logger.Error("task not found", "job", jobName, "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
|
|
}
|
|
|
|
if h.authConfig != nil && h.authConfig.Enabled && !user.Admin &&
|
|
task.UserID != user.Name && task.CreatedBy != user.Name {
|
|
h.logger.Error(
|
|
"unauthorized job cancellation attempt",
|
|
"user", user.Name,
|
|
"job", jobName,
|
|
"task_owner", task.UserID,
|
|
)
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodePermissionDenied,
|
|
"You can only cancel your own jobs",
|
|
"",
|
|
)
|
|
}
|
|
|
|
if err := h.queue.CancelTask(task.ID); err != nil {
|
|
h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
|
|
}
|
|
|
|
h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
|
|
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)))
|
|
}
|
|
|
|
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinPrune)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
pruneType := payload[ProtocolAPIKeyHashLen]
|
|
value := binary.BigEndian.Uint32(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+5])
|
|
|
|
h.logger.Info("prune request", "type", pruneType, "value", value)
|
|
|
|
var keepCount int
|
|
var olderThanDays int
|
|
|
|
switch pruneType {
|
|
case 0:
|
|
keepCount = int(value)
|
|
case 1:
|
|
olderThanDays = int(value)
|
|
default:
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeInvalidRequest,
|
|
fmt.Sprintf("invalid prune type: %d", pruneType),
|
|
"",
|
|
)
|
|
}
|
|
|
|
pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
|
|
if err != nil {
|
|
h.logger.Error("prune failed", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
|
|
}
|
|
if h.queue != nil {
|
|
_ = h.queue.SignalPrewarmGC()
|
|
}
|
|
|
|
h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
|
|
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))))
|
|
}
|
|
|
|
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinLogMetric)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen]
|
|
step := int(binary.BigEndian.Uint32(payload[36:40]))
|
|
valueBits := binary.BigEndian.Uint64(payload[40:48])
|
|
value := math.Float64frombits(valueBits)
|
|
nameLen := int(payload[48])
|
|
|
|
if len(payload) < 49+nameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
|
|
}
|
|
|
|
name := string(payload[49 : 49+nameLen])
|
|
|
|
if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil {
|
|
h.logger.Error("failed to log metric", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
|
|
}
|
|
|
|
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinGetExperiment)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen]
|
|
|
|
meta, err := h.expManager.ReadMetadata(fmt.Sprintf("%x", commitID))
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
|
|
}
|
|
|
|
metrics, err := h.expManager.GetMetrics(fmt.Sprintf("%x", commitID))
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
|
|
}
|
|
|
|
var dbMeta *storage.ExperimentWithMetadata
|
|
if h.db != nil {
|
|
ctx, cancel := helpers.DBContextShort()
|
|
defer cancel()
|
|
m, err := h.db.GetExperimentWithMetadata(ctx, fmt.Sprintf("%x", commitID))
|
|
if err == nil {
|
|
dbMeta = m
|
|
}
|
|
}
|
|
|
|
response := map[string]interface{}{
|
|
"metadata": meta,
|
|
"metrics": metrics,
|
|
}
|
|
if dbMeta != nil {
|
|
response["reproducibility"] = dbMeta
|
|
}
|
|
|
|
responseData, err := json.Marshal(response)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Failed to serialize response",
|
|
err.Error(),
|
|
)
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData))
|
|
}
|
|
|
|
// handleGetLogs handles requests to fetch logs for a task/run
|
|
func (h *WSHandler) handleGetLogs(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinGetLogs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
targetIDLen := int(payload[ProtocolAPIKeyHashLen])
|
|
if len(payload) < ProtocolMinGetLogs+targetIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", fmt.Sprintf("got %d, need %d", len(payload), ProtocolMinGetLogs+targetIDLen))
|
|
}
|
|
|
|
targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen])
|
|
h.logger.Info("get logs request", "target_id", targetID, "user", user.Name)
|
|
|
|
// TODO: Implement actual log fetching from storage
|
|
// For now, return a stub response
|
|
response := map[string]interface{}{
|
|
"target_id": targetID,
|
|
"logs": "[Stub] Log content would appear here\nLine 1: Log output\nLine 2: More output\n",
|
|
"truncated": false,
|
|
"total_lines": 3,
|
|
}
|
|
|
|
responseData, err := json.Marshal(response)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error())
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewDataPacket("logs", responseData))
|
|
}
|
|
|
|
// handleStreamLogs handles requests to stream logs in real-time
|
|
func (h *WSHandler) handleStreamLogs(conn *websocket.Conn, payload []byte) error {
|
|
user, err := h.authenticate(conn, payload, ProtocolMinStreamLogs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
|
|
return err
|
|
}
|
|
|
|
targetIDLen := int(payload[ProtocolAPIKeyHashLen])
|
|
if len(payload) < ProtocolMinStreamLogs+targetIDLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", "")
|
|
}
|
|
|
|
targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen])
|
|
h.logger.Info("stream logs request", "target_id", targetID, "user", user.Name)
|
|
|
|
// TODO: Implement actual log streaming
|
|
// For now, return a stub response indicating streaming started
|
|
response := map[string]interface{}{
|
|
"target_id": targetID,
|
|
"streaming": true,
|
|
"message": "[Stub] Log streaming would start here. This feature is not yet fully implemented.",
|
|
}
|
|
|
|
responseData, err := json.Marshal(response)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error())
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewDataPacket("logs_stream", responseData))
|
|
}
|