feat(core): API, worker, queue, and manifest improvements

- Add protocol buffer optimizations (internal/api/protocol.go)
- Add filesystem queue backend (internal/queue/filesystem_queue.go)
- Add run manifest support (internal/manifest/run_manifest.go)
- Worker and jupyter task refinements
- Exported test wrappers for benchmarking
This commit is contained in:
Jeremie Fraeys 2026-02-12 12:05:17 -05:00
parent 8e3fa94322
commit 2e701340e5
No known key found for this signature in database
16 changed files with 1877 additions and 46 deletions

View file

@ -308,7 +308,6 @@ func GetErrorMessage(code byte) string {
return "Resource not found"
case ErrorCodeResourceAlreadyExists:
return "Resource already exists"
case ErrorCodeServerOverloaded:
return "Server is overloaded"
case ErrorCodeDatabaseError:
@ -319,7 +318,6 @@ func GetErrorMessage(code byte) string {
return "Storage error occurred"
case ErrorCodeTimeout:
return "Operation timed out"
case ErrorCodeJobNotFound:
return "Job not found"
case ErrorCodeJobAlreadyRunning:
@ -330,7 +328,6 @@ func GetErrorMessage(code byte) string {
return "Job execution failed"
case ErrorCodeJobCancelled:
return "Job was cancelled"
case ErrorCodeOutOfMemory:
return "Server out of memory"
case ErrorCodeDiskFull:
@ -339,7 +336,6 @@ func GetErrorMessage(code byte) string {
return "Invalid server configuration"
case ErrorCodeServiceUnavailable:
return "Service temporarily unavailable"
default:
return "Unknown error code"
}

View file

@ -159,6 +159,8 @@ func (s *Server) initTaskQueue() error {
RedisPassword: s.config.Redis.Password,
RedisDB: s.config.Redis.DB,
SQLitePath: s.config.Queue.SQLitePath,
FilesystemPath: s.config.Queue.FilesystemPath,
FallbackToFilesystem: s.config.Queue.FallbackToFilesystem,
MetricsFlushInterval: 0,
}

View file

@ -15,8 +15,10 @@ import (
)
type QueueConfig struct {
Backend string `yaml:"backend"`
SQLitePath string `yaml:"sqlite_path"`
Backend string `yaml:"backend"`
SQLitePath string `yaml:"sqlite_path"`
FilesystemPath string `yaml:"filesystem_path"`
FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
}
// ServerConfig holds all server configuration
@ -172,8 +174,8 @@ func (c *ServerConfig) Validate() error {
backend = "redis"
c.Queue.Backend = backend
}
if backend != "redis" && backend != "sqlite" {
return fmt.Errorf("queue.backend must be one of 'redis' or 'sqlite'")
if backend != "redis" && backend != "sqlite" && backend != "filesystem" {
return fmt.Errorf("queue.backend must be one of 'redis', 'sqlite', or 'filesystem'")
}
if backend == "sqlite" {
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
@ -184,6 +186,15 @@ func (c *ServerConfig) Validate() error {
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
}
}
if backend == "filesystem" || c.Queue.FallbackToFilesystem {
if strings.TrimSpace(c.Queue.FilesystemPath) == "" {
c.Queue.FilesystemPath = filepath.Join(c.DataDir, "queue-fs")
}
c.Queue.FilesystemPath = config.ExpandPath(c.Queue.FilesystemPath)
if !filepath.IsAbs(c.Queue.FilesystemPath) {
c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath)
}
}
return nil
}

View file

@ -35,11 +35,16 @@ const (
OpcodeGetExperiment = 0x0B
OpcodeQueueJobWithTracking = 0x0C
OpcodeQueueJobWithSnapshot = 0x17
OpcodeQueueJobWithArgs = 0x1A
OpcodeQueueJobWithNote = 0x1B
OpcodeAnnotateRun = 0x1C
OpcodeSetRunNarrative = 0x1D
OpcodeStartJupyter = 0x0D
OpcodeStopJupyter = 0x0E
OpcodeRemoveJupyter = 0x18
OpcodeRestoreJupyter = 0x19
OpcodeListJupyter = 0x0F
OpcodeListJupyterPackages = 0x1E
OpcodeValidateRequest = 0x16
)
@ -243,6 +248,14 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
return h.handleQueueJobWithTracking(conn, payload)
case OpcodeQueueJobWithSnapshot:
return h.handleQueueJobWithSnapshot(conn, payload)
case OpcodeQueueJobWithArgs:
return h.handleQueueJobWithArgs(conn, payload)
case OpcodeQueueJobWithNote:
return h.handleQueueJobWithNote(conn, payload)
case OpcodeAnnotateRun:
return h.handleAnnotateRun(conn, payload)
case OpcodeSetRunNarrative:
return h.handleSetRunNarrative(conn, payload)
case OpcodeStatusRequest:
return h.handleStatusRequest(conn, payload)
case OpcodeCancelJob:
@ -271,6 +284,8 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
return h.handleRestoreJupyter(conn, payload)
case OpcodeListJupyter:
return h.handleListJupyter(conn, payload)
case OpcodeListJupyterPackages:
return h.handleListJupyterPackages(conn, payload)
case OpcodeValidateRequest:
return h.handleValidateRequest(conn, payload)
default:

View file

@ -18,14 +18,220 @@ import (
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/telemetry"
"github.com/jfraeys/fetch_ml/internal/worker"
)
func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
if len(payload) < 16+1+1+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
}
apiKeyHash := payload[:16]
offset := 16
jobNameLen := int(payload[offset])
offset += 1
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
}
jobName := string(payload[offset : offset+jobNameLen])
offset += jobNameLen
authorLen := int(payload[offset])
offset += 1
if authorLen < 0 || len(payload) < offset+authorLen+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
}
author := string(payload[offset : offset+authorLen])
offset += authorLen
noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if noteLen <= 0 || len(payload) < offset+noteLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
}
note := string(payload[offset : offset+noteLen])
// Validate API key and get user information
var user *auth.User
var err error
if h.authConfig != nil {
user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
} else {
user = &auth.User{
Name: "default",
Admin: true,
Roles: []string{"admin"},
Permissions: map[string]bool{
"*": true,
},
}
}
// Permission model: if auth is enabled, require jobs:update to mutate shared run artifacts.
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to annotate runs", "")
}
if err := container.ValidateJobName(jobName); err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
}
base := strings.TrimSpace(h.expManager.BasePath())
if base == "" {
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
}
jobPaths := config.NewJobPaths(base)
typedRoots := []struct {
root string
}{
{root: jobPaths.RunningPath()},
{root: jobPaths.PendingPath()},
{root: jobPaths.FinishedPath()},
{root: jobPaths.FailedPath()},
}
var manifestDir string
for _, item := range typedRoots {
dir := filepath.Join(item.root, jobName)
if info, err := os.Stat(dir); err == nil && info.IsDir() {
if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
manifestDir = dir
break
}
}
}
if manifestDir == "" {
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
}
rm, err := manifest.LoadFromDir(manifestDir)
if err != nil || rm == nil {
return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
}
// Default author to authenticated user if empty.
if strings.TrimSpace(author) == "" {
author = user.Name
}
rm.AddAnnotation(time.Now().UTC(), author, note)
if err := rm.WriteToDir(manifestDir); err != nil {
return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
}
return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added"))
}
func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var]
if len(payload) < 16+1+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
}
apiKeyHash := payload[:16]
offset := 16
jobNameLen := int(payload[offset])
offset += 1
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
}
jobName := string(payload[offset : offset+jobNameLen])
offset += jobNameLen
patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if patchLen <= 0 || len(payload) < offset+patchLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "")
}
patchJSON := payload[offset : offset+patchLen]
var user *auth.User
var err error
if h.authConfig != nil {
user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
} else {
user = &auth.User{
Name: "default",
Admin: true,
Roles: []string{"admin"},
Permissions: map[string]bool{
"*": true,
},
}
}
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to update run narrative", "")
}
if err := container.ValidateJobName(jobName); err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
}
base := strings.TrimSpace(h.expManager.BasePath())
if base == "" {
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
}
jobPaths := config.NewJobPaths(base)
typedRoots := []struct{ root string }{
{root: jobPaths.RunningPath()},
{root: jobPaths.PendingPath()},
{root: jobPaths.FinishedPath()},
{root: jobPaths.FailedPath()},
}
var manifestDir string
for _, item := range typedRoots {
dir := filepath.Join(item.root, jobName)
if info, err := os.Stat(dir); err == nil && info.IsDir() {
if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
manifestDir = dir
break
}
}
}
if manifestDir == "" {
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
}
rm, err := manifest.LoadFromDir(manifestDir)
if err != nil || rm == nil {
return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
}
var patch manifest.NarrativePatch
if err := json.Unmarshal(patchJSON, &patch); err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error())
}
rm.ApplyNarrativePatch(patch)
if err := rm.WriteToDir(manifestDir); err != nil {
return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
}
return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated"))
}
func fileSHA256Hex(path string) (string, error) {
f, err := os.Open(filepath.Clean(path))
if err != nil {
@ -668,6 +874,294 @@ func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []b
return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources)
}
type queueJobWithArgsPayload struct {
apiKeyHash []byte
commitID []byte
priority int64
jobName string
args string
resources *resourceRequest
}
type queueJobWithNotePayload struct {
apiKeyHash []byte
commitID []byte
priority int64
jobName string
args string
note string
resources *resourceRequest
}
func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) {
// Protocol:
// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
// [args_len:2][args:var][note_len:2][note:var][resources?:var]
if len(payload) < 42 {
return nil, fmt.Errorf("queue job with note payload too short")
}
apiKeyHash := payload[:16]
commitID := payload[16:36]
priority := int64(payload[36])
jobNameLen := int(payload[37])
if len(payload) < 38+jobNameLen+2 {
return nil, fmt.Errorf("invalid job name length")
}
jobName := string(payload[38 : 38+jobNameLen])
offset := 38 + jobNameLen
argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if argsLen < 0 || len(payload) < offset+argsLen+2 {
return nil, fmt.Errorf("invalid args length")
}
args := ""
if argsLen > 0 {
args = string(payload[offset : offset+argsLen])
}
offset += argsLen
noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if noteLen < 0 || len(payload) < offset+noteLen {
return nil, fmt.Errorf("invalid note length")
}
note := ""
if noteLen > 0 {
note = string(payload[offset : offset+noteLen])
}
offset += noteLen
resources, resErr := parseOptionalResourceRequest(payload[offset:])
if resErr != nil {
return nil, resErr
}
return &queueJobWithNotePayload{
apiKeyHash: apiKeyHash,
commitID: commitID,
priority: priority,
jobName: jobName,
args: args,
note: note,
resources: resources,
}, nil
}
func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) {
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][resources?:var]
if len(payload) < 40 {
return nil, fmt.Errorf("queue job with args payload too short")
}
apiKeyHash := payload[:16]
commitID := payload[16:36]
priority := int64(payload[36])
jobNameLen := int(payload[37])
if len(payload) < 38+jobNameLen+2 {
return nil, fmt.Errorf("invalid job name length")
}
jobName := string(payload[38 : 38+jobNameLen])
offset := 38 + jobNameLen
argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
offset += 2
if argsLen < 0 || len(payload) < offset+argsLen {
return nil, fmt.Errorf("invalid args length")
}
args := ""
if argsLen > 0 {
args = string(payload[offset : offset+argsLen])
}
offset += argsLen
resources, resErr := parseOptionalResourceRequest(payload[offset:])
if resErr != nil {
return nil, resErr
}
return &queueJobWithArgsPayload{
apiKeyHash: apiKeyHash,
commitID: commitID,
priority: priority,
jobName: jobName,
args: args,
resources: resources,
}, nil
}
func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error {
p, err := parseQueueJobWithArgsPayload(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error())
}
h.logger.Info("queue job request",
"job", p.jobName,
"priority", p.priority,
"commit_id", fmt.Sprintf("%x", p.commitID),
)
// Validate API key and get user information
var user *auth.User
if h.authConfig != nil {
user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
} else {
user = &auth.User{
Name: "default",
Admin: true,
Roles: []string{"admin"},
Permissions: map[string]bool{
"*": true,
},
}
}
if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
h.logger.Info(
"job queued",
"job", p.jobName,
"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)),
"user", user.Name,
)
} else {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
return h.sendErrorPacket(
conn,
ErrorCodePermissionDenied,
"Insufficient permissions to create jobs",
"",
)
}
commitIDStr := fmt.Sprintf("%x", p.commitID)
if _, err := telemetry.ExecWithMetrics(
h.logger,
"experiment.create",
50*time.Millisecond,
func() (string, error) {
return "", h.expManager.CreateExperiment(commitIDStr)
},
); err != nil {
h.logger.Error("failed to create experiment directory", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
}
meta := &experiment.Metadata{
CommitID: commitIDStr,
JobName: p.jobName,
User: user.Name,
Timestamp: time.Now().Unix(),
}
if _, err := telemetry.ExecWithMetrics(
h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
return "", h.expManager.WriteMetadata(meta)
}); err != nil {
h.logger.Error("failed to save experiment metadata", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
}
if _, err := telemetry.ExecWithMetrics(
h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
}); err != nil {
h.logger.Error("failed to ensure minimal experiment files", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error())
}
return h.enqueueTaskAndRespondWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, nil, p.resources)
}
func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error {
p, err := parseQueueJobWithNotePayload(payload)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error())
}
h.logger.Info("queue job request",
"job", p.jobName,
"priority", p.priority,
"commit_id", fmt.Sprintf("%x", p.commitID),
)
var user *auth.User
if h.authConfig != nil {
user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
} else {
user = &auth.User{
Name: "default",
Admin: true,
Roles: []string{"admin"},
Permissions: map[string]bool{
"*": true,
},
}
}
if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
h.logger.Info(
"job queued",
"job", p.jobName,
"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)),
"user", user.Name,
)
} else {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
return h.sendErrorPacket(
conn,
ErrorCodePermissionDenied,
"Insufficient permissions to create jobs",
"",
)
}
commitIDStr := fmt.Sprintf("%x", p.commitID)
if _, err := telemetry.ExecWithMetrics(
h.logger,
"experiment.create",
50*time.Millisecond,
func() (string, error) {
return "", h.expManager.CreateExperiment(commitIDStr)
},
); err != nil {
h.logger.Error("failed to create experiment directory", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
}
meta := &experiment.Metadata{
CommitID: commitIDStr,
JobName: p.jobName,
User: user.Name,
Timestamp: time.Now().Unix(),
}
if _, err := telemetry.ExecWithMetrics(
h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
return "", h.expManager.WriteMetadata(meta)
}); err != nil {
h.logger.Error("failed to save experiment metadata", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
}
if _, err := telemetry.ExecWithMetrics(
h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
}); err != nil {
h.logger.Error("failed to ensure minimal experiment files", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error())
}
return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, nil, p.resources)
}
// enqueueTaskAndRespond enqueues a task and sends a success response.
func (h *WSHandler) enqueueTaskAndRespond(
conn *websocket.Conn,
@ -677,6 +1171,112 @@ func (h *WSHandler) enqueueTaskAndRespond(
commitID []byte,
tracking *queue.TrackingConfig,
resources *resourceRequest,
) error {
return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", tracking, resources)
}
func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote(
conn *websocket.Conn,
user *auth.User,
jobName string,
priority int64,
commitID []byte,
args string,
note string,
tracking *queue.TrackingConfig,
resources *resourceRequest,
) error {
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
commitIDStr := fmt.Sprintf("%x", commitID)
prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr)
if provErr != nil {
h.logger.Error("failed to compute expected provenance; refusing to enqueue",
"commit_id", commitIDStr,
"error", provErr)
return h.sendErrorPacket(
conn,
ErrorCodeStorageError,
"Failed to compute expected provenance",
provErr.Error(),
)
}
if h.queue != nil {
taskID := uuid.New().String()
task := &queue.Task{
ID: taskID,
JobName: jobName,
Args: strings.TrimSpace(args),
Status: "queued",
Priority: priority,
CreatedAt: time.Now(),
UserID: user.Name,
Username: user.Name,
CreatedBy: user.Name,
Metadata: map[string]string{
"commit_id": commitIDStr,
},
Tracking: tracking,
}
if strings.TrimSpace(note) != "" {
task.Metadata["note"] = strings.TrimSpace(note)
}
for k, v := range prov {
if v != "" {
task.Metadata[k] = v
}
}
if resources != nil {
task.CPU = resources.CPU
task.MemoryGB = resources.MemoryGB
task.GPU = resources.GPU
task.GPUMemory = resources.GPUMemory
}
if _, err := telemetry.ExecWithMetrics(
h.logger,
"queue.add_task",
20*time.Millisecond,
func() (string, error) {
return "", h.queue.AddTask(task)
},
); err != nil {
h.logger.Error("failed to enqueue task", "error", err)
return h.sendErrorPacket(
conn,
ErrorCodeDatabaseError,
"Failed to enqueue task",
err.Error(),
)
}
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
} else {
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
}
packetData, err := packet.Serialize()
if err != nil {
h.logger.Error("failed to serialize packet", "error", err)
return h.sendErrorPacket(
conn,
ErrorCodeServerOverloaded,
"Internal error",
"Failed to serialize response",
)
}
return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
func (h *WSHandler) enqueueTaskAndRespondWithArgs(
conn *websocket.Conn,
user *auth.User,
jobName string,
priority int64,
commitID []byte,
args string,
tracking *queue.TrackingConfig,
resources *resourceRequest,
) error {
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
@ -700,7 +1300,7 @@ func (h *WSHandler) enqueueTaskAndRespond(
task := &queue.Task{
ID: taskID,
JobName: jobName,
Args: "",
Args: strings.TrimSpace(args),
Status: "queued",
Priority: priority,
CreatedAt: time.Now(),

View file

@ -13,10 +13,47 @@ import (
"github.com/jfraeys/fetch_ml/internal/queue"
)
func JupyterTaskErrorCode(t *queue.Task) byte {
if t == nil {
return ErrorCodeUnknownError
}
status := strings.ToLower(strings.TrimSpace(t.Status))
errStr := strings.ToLower(strings.TrimSpace(t.Error))
if status == "cancelled" {
return ErrorCodeJobCancelled
}
if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") {
return ErrorCodeOutOfMemory
}
if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") {
return ErrorCodeDiskFull
}
if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") {
return ErrorCodeServiceUnavailable
}
if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") {
return ErrorCodeTimeout
}
if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") {
return ErrorCodeNetworkError
}
if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") {
return ErrorCodeInvalidConfiguration
}
// Default for worker-side execution failures.
if status == "failed" {
return ErrorCodeJobExecutionFailed
}
return ErrorCodeUnknownError
}
type jupyterTaskOutput struct {
Type string `json:"type"`
Service json.RawMessage `json:"service,omitempty"`
Services json.RawMessage `json:"services,omitempty"`
Packages json.RawMessage `json:"packages,omitempty"`
RestorePath string `json:"restore_path,omitempty"`
}
@ -76,7 +113,7 @@ func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) e
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
}
if result.Status != "completed" {
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
}
msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name))
@ -102,18 +139,98 @@ const (
jupyterTaskTypeKey = "task_type"
jupyterTaskTypeValue = "jupyter"
jupyterTaskActionKey = "jupyter_action"
jupyterActionStart = "start"
jupyterActionStop = "stop"
jupyterActionRemove = "remove"
jupyterActionRestore = "restore"
jupyterActionList = "list"
jupyterTaskActionKey = "jupyter_action"
jupyterActionStart = "start"
jupyterActionStop = "stop"
jupyterActionRemove = "remove"
jupyterActionRestore = "restore"
jupyterActionList = "list"
jupyterActionListPkgs = "list_packages"
jupyterNameKey = "jupyter_name"
jupyterWorkspaceKey = "jupyter_workspace"
jupyterServiceIDKey = "jupyter_service_id"
)
func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][name_len:1][name:var]
if len(payload) < 18 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter packages payload too short", "")
}
apiKeyHash := payload[:16]
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
}
user, err := h.validateWSUser(apiKeyHash)
if err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Authentication failed",
err.Error(),
)
}
if user != nil && !user.HasPermission("jupyter:read") {
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
}
offset := 16
nameLen := int(payload[offset])
offset++
if len(payload) < offset+nameLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
}
name := string(payload[offset : offset+nameLen])
name = strings.TrimSpace(name)
if name == "" {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "")
}
meta := map[string]string{
jupyterTaskActionKey: jupyterActionListPkgs,
jupyterNameKey: name,
}
jobName := fmt.Sprintf("jupyter-packages-%s", name)
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
if err != nil {
h.logger.Error("failed to enqueue jupyter packages list", "error", err)
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter packages list", "")
}
result, err := h.waitForTask(taskID, 2*time.Minute)
if err != nil {
h.logger.Error("failed waiting for jupyter packages list", "error", err)
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
}
if result.Status != "completed" {
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter packages", strings.TrimSpace(result.Error))
}
out := strings.TrimSpace(result.Output)
if out == "" {
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
}
var payloadOut jupyterTaskOutput
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
payload := payloadOut.Packages
if len(payload) == 0 {
payload = []byte("[]")
}
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload))
}
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
}
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
if h.queue == nil {
return "", fmt.Errorf("task queue not configured")
@ -260,7 +377,7 @@ func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) err
if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
}
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details)
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to start Jupyter service", details)
}
msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name))
@ -336,7 +453,7 @@ func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) erro
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
}
if result.Status != "completed" {
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
}
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
}
@ -405,7 +522,7 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
}
if result.Status != "completed" {
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
}
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
}
@ -456,7 +573,7 @@ func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) erro
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
}
if result.Status != "completed" {
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error))
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter services", strings.TrimSpace(result.Error))
}
out := strings.TrimSpace(result.Output)

View file

@ -12,7 +12,7 @@ import (
"github.com/jfraeys/fetch_ml/internal/logging"
)
var defaultBlockedPackages = []string{"requests", "urllib3", "httpx"}
var defaultBlockedPackages = []string{}
func DefaultBlockedPackages() []string {
return append([]string{}, defaultBlockedPackages...)

View file

@ -203,6 +203,36 @@ type ServiceManager struct {
services map[string]*JupyterService
workspaceMetadataMgr *WorkspaceMetadataManager
securityMgr *SecurityManager
startupBlockedPkgs []string
}
func splitPackageList(value string) []string {
value = strings.TrimSpace(value)
if value == "" {
return nil
}
parts := strings.Split(value, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
out = append(out, p)
}
return out
}
func startupBlockedPackages(installBlocked []string) []string {
val, ok := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
if !ok {
return append([]string{}, installBlocked...)
}
val = strings.TrimSpace(val)
if val == "" || strings.EqualFold(val, "off") || strings.EqualFold(val, "none") || strings.EqualFold(val, "disabled") {
return []string{}
}
return splitPackageList(val)
}
// ServiceConfig holds configuration for Jupyter services
@ -270,6 +300,12 @@ type JupyterService struct {
Metadata map[string]string `json:"metadata"`
}
type InstalledPackage struct {
Name string `json:"name"`
Version string `json:"version"`
Source string `json:"source"`
}
// StartRequest defines parameters for starting a Jupyter service
type StartRequest struct {
Name string `json:"name"`
@ -316,6 +352,7 @@ func NewServiceManager(logger *logging.Logger, config *ServiceConfig) (*ServiceM
}
securityMgr := NewSecurityManager(logger, securityConfig)
startupBlockedPkgs := startupBlockedPackages(securityConfig.BlockedPackages)
sm := &ServiceManager{
logger: logger,
@ -324,6 +361,7 @@ func NewServiceManager(logger *logging.Logger, config *ServiceConfig) (*ServiceM
services: make(map[string]*JupyterService),
workspaceMetadataMgr: workspaceMetadataMgr,
securityMgr: securityMgr,
startupBlockedPkgs: startupBlockedPkgs,
}
// Load existing services
@ -421,6 +459,10 @@ func (sm *ServiceManager) StartService(
// checkPackageBlacklist validates that no blacklisted packages are installed in the container
func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID string) error {
if len(sm.startupBlockedPkgs) == 0 {
return nil
}
// Get list of installed packages from the container
// Try both pip and conda package managers
packages, err := sm.getInstalledPackages(ctx, containerID)
@ -430,19 +472,14 @@ func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID
return nil
}
// Check each installed package against the blacklist
// Check each installed package against the startup blacklist
var blockedPackages []string
for _, pkg := range packages {
// Create a package request for validation
pkgReq := &PackageRequest{
PackageName: pkg,
RequestedBy: "system",
Channel: "",
Version: "",
}
if err := sm.securityMgr.ValidatePackageRequest(pkgReq); err != nil {
blockedPackages = append(blockedPackages, pkg)
for _, blocked := range sm.startupBlockedPkgs {
if strings.EqualFold(blocked, pkg) {
blockedPackages = append(blockedPackages, pkg)
break
}
}
}
@ -450,7 +487,7 @@ func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID
if len(blockedPackages) > 0 {
return fmt.Errorf("container startup failed: blacklisted packages detected: %v. "+
"These packages are blocked by security policy. "+
"Remove them from the image or use FETCHML_JUPYTER_BLOCKED_PACKAGES to configure the blacklist",
"Remove them from the image or use FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES to configure the startup blacklist",
blockedPackages)
}
@ -508,6 +545,88 @@ func (sm *ServiceManager) parsePipList(output string) []string {
return packages
}
func (sm *ServiceManager) serviceByName(name string) *JupyterService {
name = strings.TrimSpace(name)
if name == "" {
return nil
}
for _, svc := range sm.services {
if svc == nil {
continue
}
if strings.EqualFold(strings.TrimSpace(svc.Name), name) {
return svc
}
}
return nil
}
func (sm *ServiceManager) listInstalledPackages(ctx context.Context, containerID string) ([]InstalledPackage, error) {
var pkgs []InstalledPackage
// pip
pipJSON, err := sm.podman.ExecContainer(ctx, containerID, []string{"pip", "list", "--format=json"})
if err == nil {
var parsed []struct {
Name string `json:"name"`
Version string `json:"version"`
}
if json.Unmarshal([]byte(pipJSON), &parsed) == nil {
for _, p := range parsed {
name := strings.TrimSpace(p.Name)
if name == "" {
continue
}
pkgs = append(pkgs, InstalledPackage{Name: name, Version: strings.TrimSpace(p.Version), Source: "pip"})
}
}
}
// conda
condaJSON, err := sm.podman.ExecContainer(ctx, containerID, []string{"conda", "list", "--json"})
if err == nil {
var parsed []struct {
Name string `json:"name"`
Version string `json:"version"`
}
if json.Unmarshal([]byte(condaJSON), &parsed) == nil {
for _, p := range parsed {
name := strings.TrimSpace(p.Name)
if name == "" {
continue
}
pkgs = append(pkgs, InstalledPackage{Name: name, Version: strings.TrimSpace(p.Version), Source: "conda"})
}
}
}
seen := make(map[string]bool)
out := make([]InstalledPackage, 0, len(pkgs))
for _, p := range pkgs {
key := strings.ToLower(strings.TrimSpace(p.Name)) + ":" + strings.ToLower(strings.TrimSpace(p.Source))
if seen[key] {
continue
}
seen[key] = true
out = append(out, p)
}
return out, nil
}
func (sm *ServiceManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]InstalledPackage, error) {
if sm == nil {
return nil, fmt.Errorf("service manager is nil")
}
svc := sm.serviceByName(serviceName)
if svc == nil {
return nil, fmt.Errorf("service %s not found", strings.TrimSpace(serviceName))
}
if strings.TrimSpace(svc.ContainerID) == "" {
return nil, fmt.Errorf("service container not available")
}
return sm.listInstalledPackages(ctx, svc.ContainerID)
}
// parseCondaList parses conda list --export output
func (sm *ServiceManager) parseCondaList(output string) []string {
var packages []string

View file

@ -0,0 +1,57 @@
package jupyter
import (
"os"
"testing"
)
func TestStartupBlockedPackages_DefaultInheritsInstallBlocked(t *testing.T) {
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
_, hadStartup := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
_ = os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
if hadStartup {
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
} else {
defer os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
}
cfg := DefaultEnhancedSecurityConfigFromEnv()
startup := startupBlockedPackages(cfg.BlockedPackages)
if len(startup) != 2 {
t.Fatalf("expected startup list to inherit 2 items, got %d", len(startup))
}
}
func TestStartupBlockedPackages_Disabled(t *testing.T) {
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
_ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "off")
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
cfg := DefaultEnhancedSecurityConfigFromEnv()
startup := startupBlockedPackages(cfg.BlockedPackages)
if len(startup) != 0 {
t.Fatalf("expected startup list to be disabled, got %d", len(startup))
}
}
func TestStartupBlockedPackages_ExplicitList(t *testing.T) {
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
_ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "aiohttp")
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
cfg := DefaultEnhancedSecurityConfigFromEnv()
startup := startupBlockedPackages(cfg.BlockedPackages)
if len(startup) != 1 || startup[0] != "aiohttp" {
t.Fatalf("expected explicit startup list [aiohttp], got %v", startup)
}
}

View file

@ -0,0 +1,226 @@
package manifest
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/fileutil"
)
const runManifestFilename = "run_manifest.json"
type Annotation struct {
Timestamp time.Time `json:"timestamp"`
Author string `json:"author,omitempty"`
Note string `json:"note"`
}
func (a *Annotation) UnmarshalJSON(data []byte) error {
type annotationWire struct {
Timestamp *time.Time `json:"timestamp,omitempty"`
TS *time.Time `json:"ts,omitempty"`
Author string `json:"author,omitempty"`
Note string `json:"note"`
}
var w annotationWire
if err := json.Unmarshal(data, &w); err != nil {
return err
}
if w.Timestamp != nil {
a.Timestamp = *w.Timestamp
} else if w.TS != nil {
a.Timestamp = *w.TS
}
a.Author = w.Author
a.Note = w.Note
return nil
}
type Narrative struct {
Hypothesis string `json:"hypothesis,omitempty"`
Context string `json:"context,omitempty"`
Intent string `json:"intent,omitempty"`
ExpectedOutcome string `json:"expected_outcome,omitempty"`
ParentRun string `json:"parent_run,omitempty"`
ExperimentGroup string `json:"experiment_group,omitempty"`
Tags []string `json:"tags,omitempty"`
}
type NarrativePatch struct {
Hypothesis *string `json:"hypothesis,omitempty"`
Context *string `json:"context,omitempty"`
Intent *string `json:"intent,omitempty"`
ExpectedOutcome *string `json:"expected_outcome,omitempty"`
ParentRun *string `json:"parent_run,omitempty"`
ExperimentGroup *string `json:"experiment_group,omitempty"`
Tags *[]string `json:"tags,omitempty"`
}
type ArtifactFile struct {
Path string `json:"path"`
SizeBytes int64 `json:"size_bytes"`
Modified time.Time `json:"modified"`
}
type Artifacts struct {
DiscoveryTime time.Time `json:"discovery_time"`
Files []ArtifactFile `json:"files,omitempty"`
TotalSizeBytes int64 `json:"total_size_bytes,omitempty"`
}
// RunManifest is a best-effort, self-contained provenance record for a run.
// It is written to <run_dir>/run_manifest.json.
type RunManifest struct {
RunID string `json:"run_id"`
TaskID string `json:"task_id"`
JobName string `json:"job_name"`
CreatedAt time.Time `json:"created_at"`
StartedAt time.Time `json:"started_at,omitempty"`
EndedAt time.Time `json:"ended_at,omitempty"`
Annotations []Annotation `json:"annotations,omitempty"`
Narrative *Narrative `json:"narrative,omitempty"`
Artifacts *Artifacts `json:"artifacts,omitempty"`
CommitID string `json:"commit_id,omitempty"`
ExperimentManifestSHA string `json:"experiment_manifest_sha,omitempty"`
DepsManifestName string `json:"deps_manifest_name,omitempty"`
DepsManifestSHA string `json:"deps_manifest_sha,omitempty"`
TrainScriptPath string `json:"train_script_path,omitempty"`
WorkerVersion string `json:"worker_version,omitempty"`
PodmanImage string `json:"podman_image,omitempty"`
ImageDigest string `json:"image_digest,omitempty"`
SnapshotID string `json:"snapshot_id,omitempty"`
SnapshotSHA256 string `json:"snapshot_sha256,omitempty"`
Command string `json:"command,omitempty"`
Args string `json:"args,omitempty"`
ExitCode *int `json:"exit_code,omitempty"`
Error string `json:"error,omitempty"`
StagingDurationMS int64 `json:"staging_duration_ms,omitempty"`
ExecutionDurationMS int64 `json:"execution_duration_ms,omitempty"`
FinalizeDurationMS int64 `json:"finalize_duration_ms,omitempty"`
TotalDurationMS int64 `json:"total_duration_ms,omitempty"`
GPUDevices []string `json:"gpu_devices,omitempty"`
WorkerHost string `json:"worker_host,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
func NewRunManifest(runID, taskID, jobName string, createdAt time.Time) *RunManifest {
m := &RunManifest{
RunID: runID,
TaskID: taskID,
JobName: jobName,
CreatedAt: createdAt,
Metadata: make(map[string]string),
}
return m
}
func ManifestPath(dir string) string {
return filepath.Join(dir, runManifestFilename)
}
func (m *RunManifest) WriteToDir(dir string) error {
if m == nil {
return fmt.Errorf("run manifest is nil")
}
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return fmt.Errorf("marshal run manifest: %w", err)
}
if err := fileutil.SecureFileWrite(ManifestPath(dir), data, 0640); err != nil {
return fmt.Errorf("write run manifest: %w", err)
}
return nil
}
func LoadFromDir(dir string) (*RunManifest, error) {
data, err := fileutil.SecureFileRead(ManifestPath(dir))
if err != nil {
return nil, fmt.Errorf("read run manifest: %w", err)
}
var m RunManifest
if err := json.Unmarshal(data, &m); err != nil {
return nil, fmt.Errorf("parse run manifest: %w", err)
}
return &m, nil
}
func (m *RunManifest) MarkStarted(t time.Time) {
m.StartedAt = t
}
func (m *RunManifest) MarkFinished(t time.Time, exitCode *int, execErr error) {
m.EndedAt = t
m.ExitCode = exitCode
if execErr != nil {
m.Error = execErr.Error()
} else {
m.Error = ""
}
if !m.StartedAt.IsZero() {
m.TotalDurationMS = m.EndedAt.Sub(m.StartedAt).Milliseconds()
}
}
func (m *RunManifest) AddAnnotation(ts time.Time, author, note string) {
if m == nil {
return
}
n := strings.TrimSpace(note)
if n == "" {
return
}
a := Annotation{
Timestamp: ts,
Author: strings.TrimSpace(author),
Note: n,
}
m.Annotations = append(m.Annotations, a)
}
func (m *RunManifest) ApplyNarrativePatch(p NarrativePatch) {
if m == nil {
return
}
if m.Narrative == nil {
m.Narrative = &Narrative{}
}
if p.Hypothesis != nil {
m.Narrative.Hypothesis = strings.TrimSpace(*p.Hypothesis)
}
if p.Context != nil {
m.Narrative.Context = strings.TrimSpace(*p.Context)
}
if p.Intent != nil {
m.Narrative.Intent = strings.TrimSpace(*p.Intent)
}
if p.ExpectedOutcome != nil {
m.Narrative.ExpectedOutcome = strings.TrimSpace(*p.ExpectedOutcome)
}
if p.ParentRun != nil {
m.Narrative.ParentRun = strings.TrimSpace(*p.ParentRun)
}
if p.ExperimentGroup != nil {
m.Narrative.ExperimentGroup = strings.TrimSpace(*p.ExperimentGroup)
}
if p.Tags != nil {
clean := make([]string, 0, len(*p.Tags))
for _, t := range *p.Tags {
t = strings.TrimSpace(t)
if t == "" {
continue
}
clean = append(clean, t)
}
m.Narrative.Tags = clean
}
}

View file

@ -2,6 +2,8 @@ package queue
import (
"errors"
"fmt"
"strings"
"time"
)
@ -48,6 +50,7 @@ type QueueBackend string
const (
QueueBackendRedis QueueBackend = "redis"
QueueBackendSQLite QueueBackend = "sqlite"
QueueBackendFS QueueBackend = "filesystem"
)
type BackendConfig struct {
@ -56,20 +59,49 @@ type BackendConfig struct {
RedisPassword string
RedisDB int
SQLitePath string
FilesystemPath string
FallbackToFilesystem bool
MetricsFlushInterval time.Duration
}
func NewBackend(cfg BackendConfig) (Backend, error) {
mkFallback := func(err error) (Backend, error) {
if !cfg.FallbackToFilesystem {
return nil, err
}
if strings.TrimSpace(cfg.FilesystemPath) == "" {
return nil, fmt.Errorf("filesystem queue path is required for fallback")
}
fsq, fsErr := NewFilesystemQueue(cfg.FilesystemPath)
if fsErr != nil {
return nil, fmt.Errorf("filesystem queue fallback init failed: %w", fsErr)
}
return fsq, nil
}
switch cfg.Backend {
case QueueBackendFS:
if strings.TrimSpace(cfg.FilesystemPath) == "" {
return nil, fmt.Errorf("filesystem queue path is required")
}
return NewFilesystemQueue(cfg.FilesystemPath)
case "", QueueBackendRedis:
return NewTaskQueue(Config{
b, err := NewTaskQueue(Config{
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
MetricsFlushInterval: cfg.MetricsFlushInterval,
})
if err != nil {
return mkFallback(err)
}
return b, nil
case QueueBackendSQLite:
return NewSQLiteQueue(cfg.SQLitePath)
b, err := NewSQLiteQueue(cfg.SQLitePath)
if err != nil {
return mkFallback(err)
}
return b, nil
default:
return nil, ErrInvalidQueueBackend
}

View file

@ -0,0 +1,572 @@
package queue
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
type FilesystemQueue struct {
root string
ctx context.Context
cancel context.CancelFunc
}
type filesystemQueueIndex struct {
Version int `json:"version"`
UpdatedAt string `json:"updated_at"`
Tasks []filesystemQueueIndexTask `json:"tasks"`
}
type filesystemQueueIndexTask struct {
ID string `json:"id"`
Priority int64 `json:"priority"`
CreatedAt string `json:"created_at"`
}
func NewFilesystemQueue(root string) (*FilesystemQueue, error) {
root = strings.TrimSpace(root)
if root == "" {
return nil, fmt.Errorf("filesystem queue root is required")
}
root = filepath.Clean(root)
if err := os.MkdirAll(filepath.Join(root, "pending", "entries"), 0750); err != nil {
return nil, err
}
for _, d := range []string{"running", "finished", "failed"} {
if err := os.MkdirAll(filepath.Join(root, d), 0750); err != nil {
return nil, err
}
}
ctx, cancel := context.WithCancel(context.Background())
q := &FilesystemQueue{root: root, ctx: ctx, cancel: cancel}
_ = q.rebuildIndex()
return q, nil
}
func (q *FilesystemQueue) Close() error {
q.cancel()
return nil
}
func (q *FilesystemQueue) AddTask(task *Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if strings.TrimSpace(task.ID) == "" {
return fmt.Errorf("task id is required")
}
if strings.TrimSpace(task.JobName) == "" {
return fmt.Errorf("job name is required")
}
if task.MaxRetries == 0 {
task.MaxRetries = defaultMaxRetries
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
if strings.TrimSpace(task.Status) == "" {
task.Status = "queued"
}
if task.Status != "queued" {
// For filesystem backend we only enqueue queued tasks.
// Other status updates should go through UpdateTask.
task.Status = "queued"
}
payload, err := json.Marshal(task)
if err != nil {
return err
}
path := q.pendingEntryPath(task.ID)
if err := writeFileAtomic(path, payload, 0640); err != nil {
return err
}
TasksQueued.Inc()
if depth, derr := q.QueueDepth(); derr == nil {
UpdateQueueDepth(depth)
}
_ = q.rebuildIndex()
return nil
}
func (q *FilesystemQueue) GetNextTask() (*Task, error) {
return q.claimNext("", 0, false)
}
func (q *FilesystemQueue) PeekNextTask() (*Task, error) {
return q.claimNext("", 0, true)
}
func (q *FilesystemQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
return q.claimNext(workerID, leaseDuration, false)
}
func (q *FilesystemQueue) GetNextTaskWithLeaseBlocking(
workerID string,
leaseDuration, blockTimeout time.Duration,
) (*Task, error) {
if blockTimeout <= 0 {
blockTimeout = defaultBlockTimeout
}
deadline := time.Now().Add(blockTimeout)
for {
t, err := q.claimNext(workerID, leaseDuration, false)
if err != nil {
return nil, err
}
if t != nil {
return t, nil
}
if time.Now().After(deadline) {
return nil, nil
}
select {
case <-q.ctx.Done():
return nil, q.ctx.Err()
case <-time.After(50 * time.Millisecond):
}
}
}
func (q *FilesystemQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
// Single-worker friendly best-effort: update task lease fields if present.
t, err := q.GetTask(taskID)
if err != nil {
return err
}
if t.LeasedBy != "" && workerID != "" && t.LeasedBy != workerID {
return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
}
if leaseDuration == 0 {
leaseDuration = defaultLeaseDuration
}
exp := time.Now().UTC().Add(leaseDuration)
t.LeaseExpiry = &exp
if workerID != "" {
t.LeasedBy = workerID
}
RecordLeaseRenewal(workerID)
return q.UpdateTask(t)
}
func (q *FilesystemQueue) ReleaseLease(taskID string, workerID string) error {
t, err := q.GetTask(taskID)
if err != nil {
return err
}
if t.LeasedBy != "" && workerID != "" && t.LeasedBy != workerID {
return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
}
t.LeaseExpiry = nil
t.LeasedBy = ""
return q.UpdateTask(t)
}
func (q *FilesystemQueue) RetryTask(task *Task) error {
if task.RetryCount >= task.MaxRetries {
RecordDLQAddition("max_retries")
return q.MoveToDeadLetterQueue(task, "max retries exceeded")
}
errorCategory := ErrorUnknown
if task.Error != "" {
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
}
if !IsRetryable(errorCategory) {
RecordDLQAddition(string(errorCategory))
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
}
task.RetryCount++
task.Status = "queued"
task.LastError = task.Error
task.Error = ""
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
task.NextRetry = &nextRetry
task.LeaseExpiry = nil
task.LeasedBy = ""
RecordTaskRetry(task.JobName, errorCategory)
return q.AddTask(task)
}
func (q *FilesystemQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
if task == nil {
return fmt.Errorf("task is nil")
}
task.Status = "failed"
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
return q.UpdateTask(task)
}
func (q *FilesystemQueue) GetTask(taskID string) (*Task, error) {
path, err := q.findTaskPath(taskID)
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var t Task
if err := json.Unmarshal(data, &t); err != nil {
return nil, err
}
return &t, nil
}
func (q *FilesystemQueue) GetAllTasks() ([]*Task, error) {
paths := make([]string, 0, 32)
for _, p := range []string{
filepath.Join(q.root, "pending", "entries"),
filepath.Join(q.root, "running"),
filepath.Join(q.root, "finished"),
filepath.Join(q.root, "failed"),
} {
entries, err := os.ReadDir(p)
if err != nil {
continue
}
for _, e := range entries {
if e.IsDir() {
continue
}
if !strings.HasSuffix(e.Name(), ".json") {
continue
}
paths = append(paths, filepath.Join(p, e.Name()))
}
}
out := make([]*Task, 0, len(paths))
for _, path := range paths {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var t Task
if err := json.Unmarshal(data, &t); err != nil {
continue
}
out = append(out, &t)
}
return out, nil
}
func (q *FilesystemQueue) GetTaskByName(jobName string) (*Task, error) {
jobName = strings.TrimSpace(jobName)
if jobName == "" {
return nil, fmt.Errorf("job name is required")
}
tasks, err := q.GetAllTasks()
if err != nil {
return nil, err
}
var best *Task
for _, t := range tasks {
if t == nil || t.JobName != jobName {
continue
}
if best == nil || t.CreatedAt.After(best.CreatedAt) {
best = t
}
}
if best == nil {
return nil, os.ErrNotExist
}
return best, nil
}
func (q *FilesystemQueue) CancelTask(taskID string) error {
t, err := q.GetTask(taskID)
if err != nil {
return err
}
t.Status = "cancelled"
now := time.Now().UTC()
t.EndedAt = &now
return q.UpdateTask(t)
}
func (q *FilesystemQueue) UpdateTask(task *Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if strings.TrimSpace(task.ID) == "" {
return fmt.Errorf("task id is required")
}
if strings.TrimSpace(task.Status) == "" {
return fmt.Errorf("task status is required")
}
payload, err := json.Marshal(task)
if err != nil {
return err
}
dst := q.pathForStatus(task.Status, task.ID)
if dst == "" {
// For statuses we don't map yet, keep it in running.
dst = filepath.Join(q.root, "running", task.ID+".json")
}
if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil {
return err
}
// Best-effort: remove any other copies before writing.
_ = q.removeTaskFromAllDirs(task.ID)
if err := writeFileAtomic(dst, payload, 0640); err != nil {
return err
}
if depth, derr := q.QueueDepth(); derr == nil {
UpdateQueueDepth(depth)
}
_ = q.rebuildIndex()
return nil
}
func (q *FilesystemQueue) UpdateTaskWithMetrics(task *Task, _ string) error {
return q.UpdateTask(task)
}
func (q *FilesystemQueue) RecordMetric(_, _ string, _ float64) error {
return nil
}
func (q *FilesystemQueue) Heartbeat(_ string) error {
return nil
}
func (q *FilesystemQueue) QueueDepth() (int64, error) {
entries, err := os.ReadDir(filepath.Join(q.root, "pending", "entries"))
if err != nil {
return 0, err
}
var n int64
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".json") {
n++
}
}
return n, nil
}
func (q *FilesystemQueue) SetWorkerPrewarmState(_ PrewarmState) error { return nil }
func (q *FilesystemQueue) ClearWorkerPrewarmState(_ string) error { return nil }
func (q *FilesystemQueue) GetWorkerPrewarmState(_ string) (*PrewarmState, error) {
return nil, nil
}
func (q *FilesystemQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
return nil, nil
}
func (q *FilesystemQueue) SignalPrewarmGC() error { return nil }
func (q *FilesystemQueue) PrewarmGCRequestValue() (string, error) {
return "", nil
}
func (q *FilesystemQueue) claimNext(workerID string, leaseDuration time.Duration, peek bool) (*Task, error) {
pendingDir := filepath.Join(q.root, "pending", "entries")
entries, err := os.ReadDir(pendingDir)
if err != nil {
return nil, err
}
candidates := make([]*Task, 0, len(entries))
paths := make(map[string]string, len(entries))
for _, e := range entries {
if e.IsDir() {
continue
}
if !strings.HasSuffix(e.Name(), ".json") {
continue
}
path := filepath.Join(pendingDir, e.Name())
data, err := os.ReadFile(path)
if err != nil {
continue
}
var t Task
if err := json.Unmarshal(data, &t); err != nil {
continue
}
if t.NextRetry != nil && time.Now().UTC().Before(t.NextRetry.UTC()) {
continue
}
candidates = append(candidates, &t)
paths[t.ID] = path
}
if len(candidates) == 0 {
return nil, nil
}
sort.Slice(candidates, func(i, j int) bool {
if candidates[i].Priority != candidates[j].Priority {
return candidates[i].Priority > candidates[j].Priority
}
return candidates[i].CreatedAt.Before(candidates[j].CreatedAt)
})
chosen := candidates[0]
if peek {
return chosen, nil
}
src := paths[chosen.ID]
if src == "" {
return nil, nil
}
dst := filepath.Join(q.root, "running", chosen.ID+".json")
if err := os.Rename(src, dst); err != nil {
// Another process might have claimed it.
if errors.Is(err, os.ErrNotExist) {
return nil, nil
}
return nil, err
}
// Refresh from the moved file to avoid race on content.
data, err := os.ReadFile(dst)
if err != nil {
return nil, err
}
var t Task
if err := json.Unmarshal(data, &t); err != nil {
return nil, err
}
now := time.Now().UTC()
if leaseDuration == 0 {
leaseDuration = defaultLeaseDuration
}
exp := now.Add(leaseDuration)
t.LeaseExpiry = &exp
if strings.TrimSpace(workerID) != "" {
t.LeasedBy = workerID
}
// Note: status transitions are handled by worker UpdateTask calls.
payload, err := json.Marshal(&t)
if err == nil {
_ = writeFileAtomic(dst, payload, 0640)
}
if depth, derr := q.QueueDepth(); derr == nil {
UpdateQueueDepth(depth)
}
_ = q.rebuildIndex()
return &t, nil
}
func (q *FilesystemQueue) pendingEntryPath(taskID string) string {
return filepath.Join(q.root, "pending", "entries", taskID+".json")
}
func (q *FilesystemQueue) pathForStatus(status, taskID string) string {
switch status {
case "queued":
return q.pendingEntryPath(taskID)
case "running":
return filepath.Join(q.root, "running", taskID+".json")
case "completed", "finished":
return filepath.Join(q.root, "finished", taskID+".json")
case "failed", "cancelled":
return filepath.Join(q.root, "failed", taskID+".json")
default:
return ""
}
}
func (q *FilesystemQueue) findTaskPath(taskID string) (string, error) {
paths := []string{
q.pendingEntryPath(taskID),
filepath.Join(q.root, "running", taskID+".json"),
filepath.Join(q.root, "finished", taskID+".json"),
filepath.Join(q.root, "failed", taskID+".json"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", os.ErrNotExist
}
func (q *FilesystemQueue) removeTaskFromAllDirs(taskID string) error {
paths := []string{
q.pendingEntryPath(taskID),
filepath.Join(q.root, "running", taskID+".json"),
filepath.Join(q.root, "finished", taskID+".json"),
filepath.Join(q.root, "failed", taskID+".json"),
}
var outErr error
for _, p := range paths {
if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) {
outErr = err
}
}
return outErr
}
func (q *FilesystemQueue) rebuildIndex() error {
pendingDir := filepath.Join(q.root, "pending", "entries")
entries, err := os.ReadDir(pendingDir)
if err != nil {
return err
}
idx := filesystemQueueIndex{Version: 1, UpdatedAt: time.Now().UTC().Format(time.RFC3339)}
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") {
continue
}
data, err := os.ReadFile(filepath.Join(pendingDir, e.Name()))
if err != nil {
continue
}
var t Task
if err := json.Unmarshal(data, &t); err != nil {
continue
}
idx.Tasks = append(idx.Tasks, filesystemQueueIndexTask{ID: t.ID, Priority: t.Priority, CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339Nano)})
}
sort.Slice(idx.Tasks, func(i, j int) bool {
if idx.Tasks[i].Priority != idx.Tasks[j].Priority {
return idx.Tasks[i].Priority > idx.Tasks[j].Priority
}
return idx.Tasks[i].CreatedAt < idx.Tasks[j].CreatedAt
})
payload, err := json.MarshalIndent(&idx, "", " ")
if err != nil {
return err
}
path := filepath.Join(q.root, "pending", ".queue.json")
return writeFileAtomic(path, payload, 0640)
}
func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0750); err != nil {
return err
}
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, perm); err != nil {
return err
}
return os.Rename(tmp, path)
}

View file

@ -23,8 +23,10 @@ const (
)
type QueueConfig struct {
Backend string `yaml:"backend"`
SQLitePath string `yaml:"sqlite_path"`
Backend string `yaml:"backend"`
SQLitePath string `yaml:"sqlite_path"`
FilesystemPath string `yaml:"filesystem_path"`
FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
}
// Config holds worker configuration.
@ -203,6 +205,12 @@ func LoadConfig(path string) (*Config, error) {
}
cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath)
}
if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem {
if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" {
cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs")
}
cfg.Queue.FilesystemPath = config.ExpandPath(cfg.Queue.FilesystemPath)
}
if strings.TrimSpace(cfg.GPUVendor) == "" {
if cfg.AppleGPU.Enabled {
@ -254,8 +262,8 @@ func (c *Config) Validate() error {
backend = string(queue.QueueBackendRedis)
c.Queue.Backend = backend
}
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) {
return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite)
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) {
return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS)
}
if backend == string(queue.QueueBackendSQLite) {
@ -267,6 +275,15 @@ func (c *Config) Validate() error {
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
}
}
if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem {
if strings.TrimSpace(c.Queue.FilesystemPath) == "" {
return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled")
}
c.Queue.FilesystemPath = config.ExpandPath(c.Queue.FilesystemPath)
if !filepath.IsAbs(c.Queue.FilesystemPath) {
c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath)
}
}
if c.RedisAddr != "" {
addr := strings.TrimSpace(c.RedisAddr)

View file

@ -45,6 +45,7 @@ type JupyterManager interface {
RemoveService(ctx context.Context, serviceID string, purge bool) error
RestoreWorkspace(ctx context.Context, name string) (string, error)
ListServices() []*jupyter.JupyterService
ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
// isValidName validates that input strings contain only safe characters.
@ -382,6 +383,8 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
SQLitePath: cfg.Queue.SQLitePath,
FilesystemPath: cfg.Queue.FilesystemPath,
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
MetricsFlushInterval: cfg.MetricsFlushInterval,
}
queueClient, err := queue.NewBackend(backendCfg)

View file

@ -253,6 +253,11 @@ func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevice
if err := w.stageExperimentFiles(task, jobDir); err != nil {
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
if a, aerr := scanArtifacts(jobDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
now := time.Now().UTC()
exitCode := 1
m.MarkFinished(now, &exitCode, err)
@ -271,6 +276,11 @@ func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevice
}
if err := w.stageSnapshot(ctx, task, jobDir); err != nil {
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
if a, aerr := scanArtifacts(jobDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
now := time.Now().UTC()
exitCode := 1
m.MarkFinished(now, &exitCode, err)
@ -586,6 +596,11 @@ func (w *Worker) executeJob(
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
now := time.Now().UTC()
m.ExecutionDurationMS = execDuration.Milliseconds()
if a, aerr := scanArtifacts(outputDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
if err != nil {
exitCode := 1
m.MarkFinished(now, &exitCode, err)
@ -832,6 +847,35 @@ func (w *Worker) executeContainerJob(
if trackingEnv == nil {
trackingEnv = make(map[string]string)
}
cacheRoot := filepath.Join(w.config.BasePath, ".cache")
if err := os.MkdirAll(cacheRoot, 0755); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "cache_setup",
Err: err,
}
}
if volumes == nil {
volumes = make(map[string]string)
}
volumes[cacheRoot] = "/workspace/.cache:rw"
defaultEnv := map[string]string{
"HF_HOME": "/workspace/.cache/huggingface",
"TRANSFORMERS_CACHE": "/workspace/.cache/huggingface/hub",
"HF_DATASETS_CACHE": "/workspace/.cache/huggingface/datasets",
"TORCH_HOME": "/workspace/.cache/torch",
"TORCH_HUB_DIR": "/workspace/.cache/torch/hub",
"KERAS_HOME": "/workspace/.cache/keras",
"CUDA_CACHE_PATH": "/workspace/.cache/cuda",
"PIP_CACHE_DIR": "/workspace/.cache/pip",
}
for k, v := range defaultEnv {
if _, ok := trackingEnv[k]; ok {
continue
}
trackingEnv[k] = v
}
if strings.TrimSpace(visibleEnvVar) != "" {
trackingEnv[visibleEnvVar] = strings.TrimSpace(visibleDevices)
}
@ -841,9 +885,6 @@ func (w *Worker) executeContainerJob(
if strings.TrimSpace(task.SnapshotID) != "" {
trackingEnv["FETCH_ML_SNAPSHOT_ID"] = strings.TrimSpace(task.SnapshotID)
}
if volumes == nil {
volumes = make(map[string]string)
}
volumes[snap] = "/snapshot:ro"
}
@ -932,6 +973,11 @@ func (w *Worker) executeContainerJob(
now := time.Now().UTC()
exitCode := 1
m.ExecutionDurationMS = containerDuration.Milliseconds()
if a, aerr := scanArtifacts(outputDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
m.MarkFinished(now, &exitCode, err)
})
// Move job to failed directory
@ -981,6 +1027,11 @@ func (w *Worker) executeContainerJob(
now := time.Now().UTC()
exitCode := 0
m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds()
if a, aerr := scanArtifacts(outputDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
m.MarkFinished(now, &exitCode, nil)
})
if _, moveErr := telemetry.ExecWithMetrics(

View file

@ -21,6 +21,7 @@ const (
jupyterActionRemove = "remove"
jupyterActionRestore = "restore"
jupyterActionList = "list"
jupyterActionListPkgs = "list_packages"
jupyterNameKey = "jupyter_name"
jupyterWorkspaceKey = "jupyter_workspace"
jupyterServiceIDKey = "jupyter_service_id"
@ -28,10 +29,11 @@ const (
)
type jupyterTaskOutput struct {
Type string `json:"type"`
Service *jupyter.JupyterService `json:"service,omitempty"`
Services []*jupyter.JupyterService `json:"services"`
RestorePath string `json:"restore_path,omitempty"`
Type string `json:"type"`
Service *jupyter.JupyterService `json:"service,omitempty"`
Services []*jupyter.JupyterService `json:"services"`
Packages []jupyter.InstalledPackage `json:"packages,omitempty"`
RestorePath string `json:"restore_path,omitempty"`
}
func isJupyterTask(task *queue.Task) bool {
@ -109,6 +111,17 @@ func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
services := w.jupyter.ListServices()
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
return json.Marshal(out)
case jupyterActionListPkgs:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
if name == "" {
return nil, fmt.Errorf("missing jupyter name")
}
pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
if err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Packages: pkgs}
return json.Marshal(out)
case jupyterActionRestore:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
if name == "" {