feat(core): API, worker, queue, and manifest improvements
- Add protocol buffer optimizations (internal/api/protocol.go)
- Add filesystem queue backend (internal/queue/filesystem_queue.go)
- Add run manifest support (internal/manifest/run_manifest.go)
- Worker and Jupyter task refinements
- Exported test wrappers for benchmarking
This commit is contained in:
parent
8e3fa94322
commit
2e701340e5
16 changed files with 1877 additions and 46 deletions
|
|
@ -308,7 +308,6 @@ func GetErrorMessage(code byte) string {
|
|||
return "Resource not found"
|
||||
case ErrorCodeResourceAlreadyExists:
|
||||
return "Resource already exists"
|
||||
|
||||
case ErrorCodeServerOverloaded:
|
||||
return "Server is overloaded"
|
||||
case ErrorCodeDatabaseError:
|
||||
|
|
@ -319,7 +318,6 @@ func GetErrorMessage(code byte) string {
|
|||
return "Storage error occurred"
|
||||
case ErrorCodeTimeout:
|
||||
return "Operation timed out"
|
||||
|
||||
case ErrorCodeJobNotFound:
|
||||
return "Job not found"
|
||||
case ErrorCodeJobAlreadyRunning:
|
||||
|
|
@ -330,7 +328,6 @@ func GetErrorMessage(code byte) string {
|
|||
return "Job execution failed"
|
||||
case ErrorCodeJobCancelled:
|
||||
return "Job was cancelled"
|
||||
|
||||
case ErrorCodeOutOfMemory:
|
||||
return "Server out of memory"
|
||||
case ErrorCodeDiskFull:
|
||||
|
|
@ -339,7 +336,6 @@ func GetErrorMessage(code byte) string {
|
|||
return "Invalid server configuration"
|
||||
case ErrorCodeServiceUnavailable:
|
||||
return "Service temporarily unavailable"
|
||||
|
||||
default:
|
||||
return "Unknown error code"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,6 +159,8 @@ func (s *Server) initTaskQueue() error {
|
|||
RedisPassword: s.config.Redis.Password,
|
||||
RedisDB: s.config.Redis.DB,
|
||||
SQLitePath: s.config.Queue.SQLitePath,
|
||||
FilesystemPath: s.config.Queue.FilesystemPath,
|
||||
FallbackToFilesystem: s.config.Queue.FallbackToFilesystem,
|
||||
MetricsFlushInterval: 0,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,10 @@ import (
|
|||
)
|
||||
|
||||
type QueueConfig struct {
|
||||
Backend string `yaml:"backend"`
|
||||
SQLitePath string `yaml:"sqlite_path"`
|
||||
Backend string `yaml:"backend"`
|
||||
SQLitePath string `yaml:"sqlite_path"`
|
||||
FilesystemPath string `yaml:"filesystem_path"`
|
||||
FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
|
||||
}
|
||||
|
||||
// ServerConfig holds all server configuration
|
||||
|
|
@ -172,8 +174,8 @@ func (c *ServerConfig) Validate() error {
|
|||
backend = "redis"
|
||||
c.Queue.Backend = backend
|
||||
}
|
||||
if backend != "redis" && backend != "sqlite" {
|
||||
return fmt.Errorf("queue.backend must be one of 'redis' or 'sqlite'")
|
||||
if backend != "redis" && backend != "sqlite" && backend != "filesystem" {
|
||||
return fmt.Errorf("queue.backend must be one of 'redis', 'sqlite', or 'filesystem'")
|
||||
}
|
||||
if backend == "sqlite" {
|
||||
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
|
||||
|
|
@ -184,6 +186,15 @@ func (c *ServerConfig) Validate() error {
|
|||
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
|
||||
}
|
||||
}
|
||||
if backend == "filesystem" || c.Queue.FallbackToFilesystem {
|
||||
if strings.TrimSpace(c.Queue.FilesystemPath) == "" {
|
||||
c.Queue.FilesystemPath = filepath.Join(c.DataDir, "queue-fs")
|
||||
}
|
||||
c.Queue.FilesystemPath = config.ExpandPath(c.Queue.FilesystemPath)
|
||||
if !filepath.IsAbs(c.Queue.FilesystemPath) {
|
||||
c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,11 +35,16 @@ const (
|
|||
OpcodeGetExperiment = 0x0B
|
||||
OpcodeQueueJobWithTracking = 0x0C
|
||||
OpcodeQueueJobWithSnapshot = 0x17
|
||||
OpcodeQueueJobWithArgs = 0x1A
|
||||
OpcodeQueueJobWithNote = 0x1B
|
||||
OpcodeAnnotateRun = 0x1C
|
||||
OpcodeSetRunNarrative = 0x1D
|
||||
OpcodeStartJupyter = 0x0D
|
||||
OpcodeStopJupyter = 0x0E
|
||||
OpcodeRemoveJupyter = 0x18
|
||||
OpcodeRestoreJupyter = 0x19
|
||||
OpcodeListJupyter = 0x0F
|
||||
OpcodeListJupyterPackages = 0x1E
|
||||
OpcodeValidateRequest = 0x16
|
||||
)
|
||||
|
||||
|
|
@ -243,6 +248,14 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
|||
return h.handleQueueJobWithTracking(conn, payload)
|
||||
case OpcodeQueueJobWithSnapshot:
|
||||
return h.handleQueueJobWithSnapshot(conn, payload)
|
||||
case OpcodeQueueJobWithArgs:
|
||||
return h.handleQueueJobWithArgs(conn, payload)
|
||||
case OpcodeQueueJobWithNote:
|
||||
return h.handleQueueJobWithNote(conn, payload)
|
||||
case OpcodeAnnotateRun:
|
||||
return h.handleAnnotateRun(conn, payload)
|
||||
case OpcodeSetRunNarrative:
|
||||
return h.handleSetRunNarrative(conn, payload)
|
||||
case OpcodeStatusRequest:
|
||||
return h.handleStatusRequest(conn, payload)
|
||||
case OpcodeCancelJob:
|
||||
|
|
@ -271,6 +284,8 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
|||
return h.handleRestoreJupyter(conn, payload)
|
||||
case OpcodeListJupyter:
|
||||
return h.handleListJupyter(conn, payload)
|
||||
case OpcodeListJupyterPackages:
|
||||
return h.handleListJupyterPackages(conn, payload)
|
||||
case OpcodeValidateRequest:
|
||||
return h.handleValidateRequest(conn, payload)
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -18,14 +18,220 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// handleAnnotateRun appends a free-form annotation to an existing run
// manifest. The caller must be authenticated (when auth is enabled) and hold
// the "jobs:update" permission; an empty author defaults to the
// authenticated user's name.
func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
	// Minimum size: the 16-byte key hash plus the three length fields (1+1+2).
	if len(payload) < 16+1+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
	}

	apiKeyHash := payload[:16]
	offset := 16

	// Job name: 1-byte length prefix, must be non-empty, and the remaining
	// payload must still have room for the author length (1) and note length (2).
	jobNameLen := int(payload[offset])
	offset += 1
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen

	// Author: 1-byte length prefix; an empty author is allowed (defaulted below).
	authorLen := int(payload[offset])
	offset += 1
	if authorLen < 0 || len(payload) < offset+authorLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
	}
	author := string(payload[offset : offset+authorLen])
	offset += authorLen

	// Note: 2-byte big-endian length prefix, must be non-empty.
	noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if noteLen <= 0 || len(payload) < offset+noteLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
	}
	note := string(payload[offset : offset+noteLen])

	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// No auth configured: fall back to a permissive default identity.
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	// Permission model: if auth is enabled, require jobs:update to mutate shared run artifacts.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to annotate runs", "")
	}

	// Reject names that could escape the job directory layout.
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}

	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}

	// A run directory may live under any lifecycle root; probe each in order
	// (running first, failed last) and use the first one holding a manifest.
	jobPaths := config.NewJobPaths(base)
	typedRoots := []struct {
		root string
	}{
		{root: jobPaths.RunningPath()},
		{root: jobPaths.PendingPath()},
		{root: jobPaths.FinishedPath()},
		{root: jobPaths.FailedPath()},
	}

	var manifestDir string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				manifestDir = dir
				break
			}
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
	}

	rm, err := manifest.LoadFromDir(manifestDir)
	if err != nil || rm == nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
	}

	// Default author to authenticated user if empty.
	if strings.TrimSpace(author) == "" {
		author = user.Name
	}
	rm.AddAnnotation(time.Now().UTC(), author, note)
	if err := rm.WriteToDir(manifestDir); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
	}

	return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added"))
}
|
||||
|
||||
// handleSetRunNarrative applies a JSON narrative patch to an existing run
// manifest. Requires the "jobs:update" permission when auth is enabled.
func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var]
	// Minimum size: 16-byte key hash plus the two length fields (1+2).
	if len(payload) < 16+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
	}

	apiKeyHash := payload[:16]
	offset := 16

	// Job name: 1-byte length prefix, must be non-empty, with room left for
	// the 2-byte patch length field.
	jobNameLen := int(payload[offset])
	offset += 1
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen

	// Patch: 2-byte big-endian length prefix, must be non-empty JSON.
	patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if patchLen <= 0 || len(payload) < offset+patchLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "")
	}
	patchJSON := payload[offset : offset+patchLen]

	// Validate API key and get user information.
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// No auth configured: fall back to a permissive default identity.
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	// Mutating shared run artifacts requires jobs:update when auth is enabled.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to update run narrative", "")
	}
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}

	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}

	// Probe each job lifecycle root in order; first directory holding a
	// manifest wins.
	jobPaths := config.NewJobPaths(base)
	typedRoots := []struct{ root string }{
		{root: jobPaths.RunningPath()},
		{root: jobPaths.PendingPath()},
		{root: jobPaths.FinishedPath()},
		{root: jobPaths.FailedPath()},
	}

	var manifestDir string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				manifestDir = dir
				break
			}
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "")
	}

	rm, err := manifest.LoadFromDir(manifestDir)
	if err != nil || rm == nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err))
	}

	// Decode the patch only after auth and lookup succeed.
	var patch manifest.NarrativePatch
	if err := json.Unmarshal(patchJSON, &patch); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error())
	}
	rm.ApplyNarrativePatch(patch)
	if err := rm.WriteToDir(manifestDir); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error())
	}

	return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated"))
}
|
||||
|
||||
func fileSHA256Hex(path string) (string, error) {
|
||||
f, err := os.Open(filepath.Clean(path))
|
||||
if err != nil {
|
||||
|
|
@ -668,6 +874,294 @@ func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []b
|
|||
return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources)
|
||||
}
|
||||
|
||||
// queueJobWithArgsPayload is the decoded form of an OpcodeQueueJobWithArgs
// request (see parseQueueJobWithArgsPayload for the wire layout).
type queueJobWithArgsPayload struct {
	apiKeyHash []byte           // 16-byte API key hash used for authentication
	commitID   []byte           // 20-byte commit identifier (hex-encoded for display)
	priority   int64            // queue priority, taken from a single wire byte
	jobName    string           // job to enqueue
	args       string           // optional job arguments (may be empty)
	resources  *resourceRequest // optional resource request; nil when absent
}
|
||||
|
||||
// queueJobWithNotePayload is the decoded form of an OpcodeQueueJobWithNote
// request (see parseQueueJobWithNotePayload for the wire layout). It extends
// queueJobWithArgsPayload with a free-form note.
type queueJobWithNotePayload struct {
	apiKeyHash []byte           // 16-byte API key hash used for authentication
	commitID   []byte           // 20-byte commit identifier (hex-encoded for display)
	priority   int64            // queue priority, taken from a single wire byte
	jobName    string           // job to enqueue
	args       string           // optional job arguments (may be empty)
	note       string           // optional note attached to the task metadata
	resources  *resourceRequest // optional resource request; nil when absent
}
|
||||
|
||||
func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) {
|
||||
// Protocol:
|
||||
// [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
||||
// [args_len:2][args:var][note_len:2][note:var][resources?:var]
|
||||
if len(payload) < 42 {
|
||||
return nil, fmt.Errorf("queue job with note payload too short")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
commitID := payload[16:36]
|
||||
priority := int64(payload[36])
|
||||
jobNameLen := int(payload[37])
|
||||
if len(payload) < 38+jobNameLen+2 {
|
||||
return nil, fmt.Errorf("invalid job name length")
|
||||
}
|
||||
jobName := string(payload[38 : 38+jobNameLen])
|
||||
|
||||
offset := 38 + jobNameLen
|
||||
argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if argsLen < 0 || len(payload) < offset+argsLen+2 {
|
||||
return nil, fmt.Errorf("invalid args length")
|
||||
}
|
||||
args := ""
|
||||
if argsLen > 0 {
|
||||
args = string(payload[offset : offset+argsLen])
|
||||
}
|
||||
offset += argsLen
|
||||
|
||||
noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if noteLen < 0 || len(payload) < offset+noteLen {
|
||||
return nil, fmt.Errorf("invalid note length")
|
||||
}
|
||||
note := ""
|
||||
if noteLen > 0 {
|
||||
note = string(payload[offset : offset+noteLen])
|
||||
}
|
||||
offset += noteLen
|
||||
|
||||
resources, resErr := parseOptionalResourceRequest(payload[offset:])
|
||||
if resErr != nil {
|
||||
return nil, resErr
|
||||
}
|
||||
|
||||
return &queueJobWithNotePayload{
|
||||
apiKeyHash: apiKeyHash,
|
||||
commitID: commitID,
|
||||
priority: priority,
|
||||
jobName: jobName,
|
||||
args: args,
|
||||
note: note,
|
||||
resources: resources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) {
|
||||
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][resources?:var]
|
||||
if len(payload) < 40 {
|
||||
return nil, fmt.Errorf("queue job with args payload too short")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
commitID := payload[16:36]
|
||||
priority := int64(payload[36])
|
||||
jobNameLen := int(payload[37])
|
||||
if len(payload) < 38+jobNameLen+2 {
|
||||
return nil, fmt.Errorf("invalid job name length")
|
||||
}
|
||||
jobName := string(payload[38 : 38+jobNameLen])
|
||||
|
||||
offset := 38 + jobNameLen
|
||||
argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if argsLen < 0 || len(payload) < offset+argsLen {
|
||||
return nil, fmt.Errorf("invalid args length")
|
||||
}
|
||||
args := ""
|
||||
if argsLen > 0 {
|
||||
args = string(payload[offset : offset+argsLen])
|
||||
}
|
||||
offset += argsLen
|
||||
|
||||
resources, resErr := parseOptionalResourceRequest(payload[offset:])
|
||||
if resErr != nil {
|
||||
return nil, resErr
|
||||
}
|
||||
|
||||
return &queueJobWithArgsPayload{
|
||||
apiKeyHash: apiKeyHash,
|
||||
commitID: commitID,
|
||||
priority: priority,
|
||||
jobName: jobName,
|
||||
args: args,
|
||||
resources: resources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleQueueJobWithArgs authenticates the caller, prepares the experiment
// directory/metadata for the given commit, and enqueues the job with the
// supplied arguments. Requires "jobs:create" when auth is enabled.
func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error {
	p, err := parseQueueJobWithArgsPayload(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error())
	}

	h.logger.Info("queue job request",
		"job", p.jobName,
		"priority", p.priority,
		"commit_id", fmt.Sprintf("%x", p.commitID),
	)

	// Validate API key and get user information
	var user *auth.User
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// No auth configured: fall back to a permissive default identity.
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	// Permission gate: allowed when auth is off, disabled, or the user holds
	// jobs:create. NOTE(review): the "job queued" log fires before the task is
	// actually enqueued below — confirm whether that ordering is intentional.
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued",
			"job", p.jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}

	// Create the experiment directory, timing the call against a 50ms budget.
	commitIDStr := fmt.Sprintf("%x", p.commitID)
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(commitIDStr)
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}

	// Persist experiment metadata (who queued what, and when).
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   p.jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
	}

	// Ensure the minimal set of experiment files exists before enqueueing.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error())
	}

	// nil tracking config: args-only requests carry no tracking settings.
	return h.enqueueTaskAndRespondWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, nil, p.resources)
}
|
||||
|
||||
// handleQueueJobWithNote authenticates the caller, prepares the experiment
// directory/metadata for the given commit, and enqueues the job with the
// supplied arguments and note. Requires "jobs:create" when auth is enabled.
func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error {
	p, err := parseQueueJobWithNotePayload(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error())
	}

	h.logger.Info("queue job request",
		"job", p.jobName,
		"priority", p.priority,
		"commit_id", fmt.Sprintf("%x", p.commitID),
	)

	// Validate API key and get user information.
	var user *auth.User
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// No auth configured: fall back to a permissive default identity.
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	// Permission gate: allowed when auth is off, disabled, or the user holds
	// jobs:create. NOTE(review): the "job queued" log fires before the task is
	// actually enqueued below — confirm whether that ordering is intentional.
	if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
		h.logger.Info(
			"job queued",
			"job", p.jobName,
			"path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)),
			"user", user.Name,
		)
	} else {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to create jobs",
			"",
		)
	}

	// Create the experiment directory, timing the call against a 50ms budget.
	commitIDStr := fmt.Sprintf("%x", p.commitID)
	if _, err := telemetry.ExecWithMetrics(
		h.logger,
		"experiment.create",
		50*time.Millisecond,
		func() (string, error) {
			return "", h.expManager.CreateExperiment(commitIDStr)
		},
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}

	// Persist experiment metadata (who queued what, and when).
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   p.jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
			return "", h.expManager.WriteMetadata(meta)
		}); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
	}

	// Ensure the minimal set of experiment files exists before enqueueing.
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) {
			return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr)
		}); err != nil {
		h.logger.Error("failed to ensure minimal experiment files", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error())
	}

	// nil tracking config: note requests carry no tracking settings.
	return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, nil, p.resources)
}
|
||||
|
||||
// enqueueTaskAndRespond enqueues a task and sends a success response.
|
||||
func (h *WSHandler) enqueueTaskAndRespond(
|
||||
conn *websocket.Conn,
|
||||
|
|
@ -677,6 +1171,112 @@ func (h *WSHandler) enqueueTaskAndRespond(
|
|||
commitID []byte,
|
||||
tracking *queue.TrackingConfig,
|
||||
resources *resourceRequest,
|
||||
) error {
|
||||
return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", tracking, resources)
|
||||
}
|
||||
|
||||
// enqueueTaskAndRespondWithArgsAndNote builds a queue.Task for jobName (with
// optional args, note, tracking config, and resource request), enqueues it,
// and writes a success packet back on conn. Provenance metadata is computed
// first and any provenance failure aborts the enqueue. If the queue is not
// initialized the job is skipped with a warning but a success packet is still
// sent.
func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote(
	conn *websocket.Conn,
	user *auth.User,
	jobName string,
	priority int64,
	commitID []byte,
	args string,
	note string,
	tracking *queue.TrackingConfig,
	resources *resourceRequest,
) error {
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))

	// Compute expected provenance up front; refuse to enqueue if it fails so
	// workers never run a task whose provenance cannot be verified.
	commitIDStr := fmt.Sprintf("%x", commitID)
	prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr)
	if provErr != nil {
		h.logger.Error("failed to compute expected provenance; refusing to enqueue",
			"commit_id", commitIDStr,
			"error", provErr)
		return h.sendErrorPacket(
			conn,
			ErrorCodeStorageError,
			"Failed to compute expected provenance",
			provErr.Error(),
		)
	}

	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      strings.TrimSpace(args),
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id": commitIDStr,
			},
			Tracking: tracking,
		}
		// Attach the note only when non-blank.
		if strings.TrimSpace(note) != "" {
			task.Metadata["note"] = strings.TrimSpace(note)
		}
		// Merge non-empty provenance entries into the task metadata.
		for k, v := range prov {
			if v != "" {
				task.Metadata[k] = v
			}
		}
		// Optional resource request overrides.
		if resources != nil {
			task.CPU = resources.CPU
			task.MemoryGB = resources.MemoryGB
			task.GPU = resources.GPU
			task.GPUMemory = resources.GPUMemory
		}

		// Enqueue, timing the call against a 20ms budget.
		if _, err := telemetry.ExecWithMetrics(
			h.logger,
			"queue.add_task",
			20*time.Millisecond,
			func() (string, error) {
				return "", h.queue.AddTask(task)
			},
		); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(
				conn,
				ErrorCodeDatabaseError,
				"Failed to enqueue task",
				err.Error(),
			)
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
	} else {
		// Best-effort: no queue configured, log and still acknowledge.
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}

	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Internal error",
			"Failed to serialize response",
		)
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
|
||||
|
||||
func (h *WSHandler) enqueueTaskAndRespondWithArgs(
|
||||
conn *websocket.Conn,
|
||||
user *auth.User,
|
||||
jobName string,
|
||||
priority int64,
|
||||
commitID []byte,
|
||||
args string,
|
||||
tracking *queue.TrackingConfig,
|
||||
resources *resourceRequest,
|
||||
) error {
|
||||
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
|
||||
|
||||
|
|
@ -700,7 +1300,7 @@ func (h *WSHandler) enqueueTaskAndRespond(
|
|||
task := &queue.Task{
|
||||
ID: taskID,
|
||||
JobName: jobName,
|
||||
Args: "",
|
||||
Args: strings.TrimSpace(args),
|
||||
Status: "queued",
|
||||
Priority: priority,
|
||||
CreatedAt: time.Now(),
|
||||
|
|
|
|||
|
|
@ -13,10 +13,47 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
func JupyterTaskErrorCode(t *queue.Task) byte {
|
||||
if t == nil {
|
||||
return ErrorCodeUnknownError
|
||||
}
|
||||
status := strings.ToLower(strings.TrimSpace(t.Status))
|
||||
errStr := strings.ToLower(strings.TrimSpace(t.Error))
|
||||
|
||||
if status == "cancelled" {
|
||||
return ErrorCodeJobCancelled
|
||||
}
|
||||
if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") {
|
||||
return ErrorCodeOutOfMemory
|
||||
}
|
||||
if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") {
|
||||
return ErrorCodeDiskFull
|
||||
}
|
||||
if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") {
|
||||
return ErrorCodeServiceUnavailable
|
||||
}
|
||||
if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") {
|
||||
return ErrorCodeTimeout
|
||||
}
|
||||
if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") {
|
||||
return ErrorCodeNetworkError
|
||||
}
|
||||
if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") {
|
||||
return ErrorCodeInvalidConfiguration
|
||||
}
|
||||
|
||||
// Default for worker-side execution failures.
|
||||
if status == "failed" {
|
||||
return ErrorCodeJobExecutionFailed
|
||||
}
|
||||
return ErrorCodeUnknownError
|
||||
}
|
||||
|
||||
// jupyterTaskOutput is the JSON envelope a worker emits for a Jupyter task.
// The Service/Services/Packages fields are kept as raw JSON so decoding of
// the variant-specific body can be deferred to the caller.
type jupyterTaskOutput struct {
	Type        string          `json:"type"`                   // output variant discriminator
	Service     json.RawMessage `json:"service,omitempty"`      // single service description (deferred decode)
	Services    json.RawMessage `json:"services,omitempty"`     // list of services (deferred decode)
	Packages    json.RawMessage `json:"packages,omitempty"`     // installed package list (deferred decode)
	RestorePath string          `json:"restore_path,omitempty"` // path a workspace was restored to, if any
}
|
||||
|
||||
|
|
@ -76,7 +113,7 @@ func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) e
|
|||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name))
|
||||
|
|
@ -102,18 +139,98 @@ const (
|
|||
jupyterTaskTypeKey = "task_type"
|
||||
jupyterTaskTypeValue = "jupyter"
|
||||
|
||||
jupyterTaskActionKey = "jupyter_action"
|
||||
jupyterActionStart = "start"
|
||||
jupyterActionStop = "stop"
|
||||
jupyterActionRemove = "remove"
|
||||
jupyterActionRestore = "restore"
|
||||
jupyterActionList = "list"
|
||||
jupyterTaskActionKey = "jupyter_action"
|
||||
jupyterActionStart = "start"
|
||||
jupyterActionStop = "stop"
|
||||
jupyterActionRemove = "remove"
|
||||
jupyterActionRestore = "restore"
|
||||
jupyterActionList = "list"
|
||||
jupyterActionListPkgs = "list_packages"
|
||||
|
||||
jupyterNameKey = "jupyter_name"
|
||||
jupyterWorkspaceKey = "jupyter_workspace"
|
||||
jupyterServiceIDKey = "jupyter_service_id"
|
||||
)
|
||||
|
||||
func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter packages payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:read") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if len(payload) < offset+nameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "")
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionListPkgs,
|
||||
jupyterNameKey: name,
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-packages-%s", name)
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter packages list", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter packages list", "")
|
||||
}
|
||||
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter packages list", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter packages", strings.TrimSpace(result.Error))
|
||||
}
|
||||
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out == "" {
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
|
||||
}
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
||||
payload := payloadOut.Packages
|
||||
if len(payload) == 0 {
|
||||
payload = []byte("[]")
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload))
|
||||
}
|
||||
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]")))
|
||||
}
|
||||
|
||||
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
|
||||
if h.queue == nil {
|
||||
return "", fmt.Errorf("task queue not configured")
|
||||
|
|
@ -260,7 +377,7 @@ func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) err
|
|||
if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
|
||||
}
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details)
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to start Jupyter service", details)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name))
|
||||
|
|
@ -336,7 +453,7 @@ func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) erro
|
|||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
|
||||
}
|
||||
|
|
@ -405,7 +522,7 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er
|
|||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
|
||||
}
|
||||
|
|
@ -456,7 +573,7 @@ func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) erro
|
|||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error))
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter services", strings.TrimSpace(result.Error))
|
||||
}
|
||||
|
||||
out := strings.TrimSpace(result.Output)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
var defaultBlockedPackages = []string{"requests", "urllib3", "httpx"}
|
||||
var defaultBlockedPackages = []string{}
|
||||
|
||||
func DefaultBlockedPackages() []string {
|
||||
return append([]string{}, defaultBlockedPackages...)
|
||||
|
|
|
|||
|
|
@ -203,6 +203,36 @@ type ServiceManager struct {
|
|||
services map[string]*JupyterService
|
||||
workspaceMetadataMgr *WorkspaceMetadataManager
|
||||
securityMgr *SecurityManager
|
||||
startupBlockedPkgs []string
|
||||
}
|
||||
|
||||
// splitPackageList parses a comma-separated package list into a slice,
// trimming whitespace around each entry and dropping empty ones.
// A blank (or all-whitespace) input yields nil.
func splitPackageList(value string) []string {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return nil
	}
	pieces := strings.Split(trimmed, ",")
	result := make([]string, 0, len(pieces))
	for _, piece := range pieces {
		if piece = strings.TrimSpace(piece); piece != "" {
			result = append(result, piece)
		}
	}
	return result
}
|
||||
|
||||
func startupBlockedPackages(installBlocked []string) []string {
|
||||
val, ok := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
if !ok {
|
||||
return append([]string{}, installBlocked...)
|
||||
}
|
||||
val = strings.TrimSpace(val)
|
||||
if val == "" || strings.EqualFold(val, "off") || strings.EqualFold(val, "none") || strings.EqualFold(val, "disabled") {
|
||||
return []string{}
|
||||
}
|
||||
return splitPackageList(val)
|
||||
}
|
||||
|
||||
// ServiceConfig holds configuration for Jupyter services
|
||||
|
|
@ -270,6 +300,12 @@ type JupyterService struct {
|
|||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// InstalledPackage describes one package discovered inside a Jupyter
// service container.
type InstalledPackage struct {
	Name    string `json:"name"`    // package name as reported by the manager
	Version string `json:"version"` // installed version string
	Source  string `json:"source"`  // reporting package manager: "pip" or "conda"
}
|
||||
|
||||
// StartRequest defines parameters for starting a Jupyter service
|
||||
type StartRequest struct {
|
||||
Name string `json:"name"`
|
||||
|
|
@ -316,6 +352,7 @@ func NewServiceManager(logger *logging.Logger, config *ServiceConfig) (*ServiceM
|
|||
}
|
||||
|
||||
securityMgr := NewSecurityManager(logger, securityConfig)
|
||||
startupBlockedPkgs := startupBlockedPackages(securityConfig.BlockedPackages)
|
||||
|
||||
sm := &ServiceManager{
|
||||
logger: logger,
|
||||
|
|
@ -324,6 +361,7 @@ func NewServiceManager(logger *logging.Logger, config *ServiceConfig) (*ServiceM
|
|||
services: make(map[string]*JupyterService),
|
||||
workspaceMetadataMgr: workspaceMetadataMgr,
|
||||
securityMgr: securityMgr,
|
||||
startupBlockedPkgs: startupBlockedPkgs,
|
||||
}
|
||||
|
||||
// Load existing services
|
||||
|
|
@ -421,6 +459,10 @@ func (sm *ServiceManager) StartService(
|
|||
|
||||
// checkPackageBlacklist validates that no blacklisted packages are installed in the container
|
||||
func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID string) error {
|
||||
if len(sm.startupBlockedPkgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get list of installed packages from the container
|
||||
// Try both pip and conda package managers
|
||||
packages, err := sm.getInstalledPackages(ctx, containerID)
|
||||
|
|
@ -430,19 +472,14 @@ func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID
|
|||
return nil
|
||||
}
|
||||
|
||||
// Check each installed package against the blacklist
|
||||
// Check each installed package against the startup blacklist
|
||||
var blockedPackages []string
|
||||
for _, pkg := range packages {
|
||||
// Create a package request for validation
|
||||
pkgReq := &PackageRequest{
|
||||
PackageName: pkg,
|
||||
RequestedBy: "system",
|
||||
Channel: "",
|
||||
Version: "",
|
||||
}
|
||||
|
||||
if err := sm.securityMgr.ValidatePackageRequest(pkgReq); err != nil {
|
||||
blockedPackages = append(blockedPackages, pkg)
|
||||
for _, blocked := range sm.startupBlockedPkgs {
|
||||
if strings.EqualFold(blocked, pkg) {
|
||||
blockedPackages = append(blockedPackages, pkg)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -450,7 +487,7 @@ func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID
|
|||
if len(blockedPackages) > 0 {
|
||||
return fmt.Errorf("container startup failed: blacklisted packages detected: %v. "+
|
||||
"These packages are blocked by security policy. "+
|
||||
"Remove them from the image or use FETCHML_JUPYTER_BLOCKED_PACKAGES to configure the blacklist",
|
||||
"Remove them from the image or use FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES to configure the startup blacklist",
|
||||
blockedPackages)
|
||||
}
|
||||
|
||||
|
|
@ -508,6 +545,88 @@ func (sm *ServiceManager) parsePipList(output string) []string {
|
|||
return packages
|
||||
}
|
||||
|
||||
func (sm *ServiceManager) serviceByName(name string) *JupyterService {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
for _, svc := range sm.services {
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(svc.Name), name) {
|
||||
return svc
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServiceManager) listInstalledPackages(ctx context.Context, containerID string) ([]InstalledPackage, error) {
|
||||
var pkgs []InstalledPackage
|
||||
|
||||
// pip
|
||||
pipJSON, err := sm.podman.ExecContainer(ctx, containerID, []string{"pip", "list", "--format=json"})
|
||||
if err == nil {
|
||||
var parsed []struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if json.Unmarshal([]byte(pipJSON), &parsed) == nil {
|
||||
for _, p := range parsed {
|
||||
name := strings.TrimSpace(p.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
pkgs = append(pkgs, InstalledPackage{Name: name, Version: strings.TrimSpace(p.Version), Source: "pip"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// conda
|
||||
condaJSON, err := sm.podman.ExecContainer(ctx, containerID, []string{"conda", "list", "--json"})
|
||||
if err == nil {
|
||||
var parsed []struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if json.Unmarshal([]byte(condaJSON), &parsed) == nil {
|
||||
for _, p := range parsed {
|
||||
name := strings.TrimSpace(p.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
pkgs = append(pkgs, InstalledPackage{Name: name, Version: strings.TrimSpace(p.Version), Source: "conda"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
out := make([]InstalledPackage, 0, len(pkgs))
|
||||
for _, p := range pkgs {
|
||||
key := strings.ToLower(strings.TrimSpace(p.Name)) + ":" + strings.ToLower(strings.TrimSpace(p.Source))
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (sm *ServiceManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]InstalledPackage, error) {
|
||||
if sm == nil {
|
||||
return nil, fmt.Errorf("service manager is nil")
|
||||
}
|
||||
svc := sm.serviceByName(serviceName)
|
||||
if svc == nil {
|
||||
return nil, fmt.Errorf("service %s not found", strings.TrimSpace(serviceName))
|
||||
}
|
||||
if strings.TrimSpace(svc.ContainerID) == "" {
|
||||
return nil, fmt.Errorf("service container not available")
|
||||
}
|
||||
return sm.listInstalledPackages(ctx, svc.ContainerID)
|
||||
}
|
||||
|
||||
// parseCondaList parses conda list --export output
|
||||
func (sm *ServiceManager) parseCondaList(output string) []string {
|
||||
var packages []string
|
||||
|
|
|
|||
57
internal/jupyter/startup_blacklist_test.go
Normal file
57
internal/jupyter/startup_blacklist_test.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package jupyter
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStartupBlockedPackages_DefaultInheritsInstallBlocked(t *testing.T) {
|
||||
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
||||
_, hadStartup := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
|
||||
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
|
||||
_ = os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
|
||||
if hadStartup {
|
||||
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
|
||||
} else {
|
||||
defer os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
}
|
||||
|
||||
cfg := DefaultEnhancedSecurityConfigFromEnv()
|
||||
startup := startupBlockedPackages(cfg.BlockedPackages)
|
||||
if len(startup) != 2 {
|
||||
t.Fatalf("expected startup list to inherit 2 items, got %d", len(startup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupBlockedPackages_Disabled(t *testing.T) {
|
||||
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
||||
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "off")
|
||||
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
|
||||
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
|
||||
|
||||
cfg := DefaultEnhancedSecurityConfigFromEnv()
|
||||
startup := startupBlockedPackages(cfg.BlockedPackages)
|
||||
if len(startup) != 0 {
|
||||
t.Fatalf("expected startup list to be disabled, got %d", len(startup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupBlockedPackages_ExplicitList(t *testing.T) {
|
||||
oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES")
|
||||
oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3")
|
||||
_ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "aiohttp")
|
||||
defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall)
|
||||
defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup)
|
||||
|
||||
cfg := DefaultEnhancedSecurityConfigFromEnv()
|
||||
startup := startupBlockedPackages(cfg.BlockedPackages)
|
||||
if len(startup) != 1 || startup[0] != "aiohttp" {
|
||||
t.Fatalf("expected explicit startup list [aiohttp], got %v", startup)
|
||||
}
|
||||
}
|
||||
226
internal/manifest/run_manifest.go
Normal file
226
internal/manifest/run_manifest.go
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
package manifest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
)
|
||||
|
||||
const runManifestFilename = "run_manifest.json"
|
||||
|
||||
// Annotation is a timestamped free-form note attached to a run.
type Annotation struct {
	Timestamp time.Time `json:"timestamp"`
	Author    string    `json:"author,omitempty"`
	Note      string    `json:"note"`
}

// UnmarshalJSON accepts either "timestamp" or the short key "ts" for the
// annotation time; "timestamp" wins when both are present. Author and Note
// are copied through unchanged.
func (a *Annotation) UnmarshalJSON(data []byte) error {
	var wire struct {
		Timestamp *time.Time `json:"timestamp,omitempty"`
		TS        *time.Time `json:"ts,omitempty"`
		Author    string     `json:"author,omitempty"`
		Note      string     `json:"note"`
	}
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}
	switch {
	case wire.Timestamp != nil:
		a.Timestamp = *wire.Timestamp
	case wire.TS != nil:
		a.Timestamp = *wire.TS
	}
	a.Author = wire.Author
	a.Note = wire.Note
	return nil
}
|
||||
|
||||
type Narrative struct {
|
||||
Hypothesis string `json:"hypothesis,omitempty"`
|
||||
Context string `json:"context,omitempty"`
|
||||
Intent string `json:"intent,omitempty"`
|
||||
ExpectedOutcome string `json:"expected_outcome,omitempty"`
|
||||
ParentRun string `json:"parent_run,omitempty"`
|
||||
ExperimentGroup string `json:"experiment_group,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
type NarrativePatch struct {
|
||||
Hypothesis *string `json:"hypothesis,omitempty"`
|
||||
Context *string `json:"context,omitempty"`
|
||||
Intent *string `json:"intent,omitempty"`
|
||||
ExpectedOutcome *string `json:"expected_outcome,omitempty"`
|
||||
ParentRun *string `json:"parent_run,omitempty"`
|
||||
ExperimentGroup *string `json:"experiment_group,omitempty"`
|
||||
Tags *[]string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
type ArtifactFile struct {
|
||||
Path string `json:"path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
Modified time.Time `json:"modified"`
|
||||
}
|
||||
|
||||
type Artifacts struct {
|
||||
DiscoveryTime time.Time `json:"discovery_time"`
|
||||
Files []ArtifactFile `json:"files,omitempty"`
|
||||
TotalSizeBytes int64 `json:"total_size_bytes,omitempty"`
|
||||
}
|
||||
|
||||
// RunManifest is a best-effort, self-contained provenance record for a run.
// It is written to <run_dir>/run_manifest.json.
//
// NOTE(review): `omitempty` on the non-pointer time.Time fields (StartedAt,
// EndedAt) has no effect — encoding/json does not treat a zero time.Time as
// empty, so those keys are always emitted. Changing them to *time.Time would
// alter the on-disk schema, so they are left as-is.
type RunManifest struct {
	// Identity of the run and of the task/job that produced it.
	RunID     string    `json:"run_id"`
	TaskID    string    `json:"task_id"`
	JobName   string    `json:"job_name"`
	CreatedAt time.Time `json:"created_at"`
	StartedAt time.Time `json:"started_at,omitempty"`
	EndedAt   time.Time `json:"ended_at,omitempty"`

	// Human context: free-form notes, experiment narrative, and discovered
	// output files.
	Annotations []Annotation `json:"annotations,omitempty"`
	Narrative   *Narrative   `json:"narrative,omitempty"`
	Artifacts   *Artifacts   `json:"artifacts,omitempty"`

	// Provenance of the inputs: code revision, manifests, training script.
	CommitID              string `json:"commit_id,omitempty"`
	ExperimentManifestSHA string `json:"experiment_manifest_sha,omitempty"`
	DepsManifestName      string `json:"deps_manifest_name,omitempty"`
	DepsManifestSHA       string `json:"deps_manifest_sha,omitempty"`
	TrainScriptPath       string `json:"train_script_path,omitempty"`

	// Execution environment.
	WorkerVersion string `json:"worker_version,omitempty"`
	PodmanImage   string `json:"podman_image,omitempty"`
	ImageDigest   string `json:"image_digest,omitempty"`

	// Workspace snapshot used for the run, if any.
	SnapshotID     string `json:"snapshot_id,omitempty"`
	SnapshotSHA256 string `json:"snapshot_sha256,omitempty"`

	// Invocation and outcome. ExitCode is a pointer so "not finished" is
	// distinguishable from exit code 0.
	Command  string `json:"command,omitempty"`
	Args     string `json:"args,omitempty"`
	ExitCode *int   `json:"exit_code,omitempty"`
	Error    string `json:"error,omitempty"`

	// Per-phase wall-clock timings, in milliseconds.
	StagingDurationMS   int64 `json:"staging_duration_ms,omitempty"`
	ExecutionDurationMS int64 `json:"execution_duration_ms,omitempty"`
	FinalizeDurationMS  int64 `json:"finalize_duration_ms,omitempty"`
	TotalDurationMS     int64 `json:"total_duration_ms,omitempty"`

	// Placement and free-form extras.
	GPUDevices []string          `json:"gpu_devices,omitempty"`
	WorkerHost string            `json:"worker_host,omitempty"`
	Metadata   map[string]string `json:"metadata,omitempty"`
}
|
||||
|
||||
func NewRunManifest(runID, taskID, jobName string, createdAt time.Time) *RunManifest {
|
||||
m := &RunManifest{
|
||||
RunID: runID,
|
||||
TaskID: taskID,
|
||||
JobName: jobName,
|
||||
CreatedAt: createdAt,
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ManifestPath returns the location of the run manifest file inside dir.
func ManifestPath(dir string) string {
	return filepath.Join(dir, runManifestFilename)
}
|
||||
|
||||
// WriteToDir serializes the manifest as indented JSON and writes it to
// <dir>/run_manifest.json via fileutil.SecureFileWrite with 0640 permissions.
// A nil receiver is rejected with an error rather than panicking.
func (m *RunManifest) WriteToDir(dir string) error {
	if m == nil {
		return fmt.Errorf("run manifest is nil")
	}
	data, err := json.MarshalIndent(m, "", " ")
	if err != nil {
		return fmt.Errorf("marshal run manifest: %w", err)
	}
	if err := fileutil.SecureFileWrite(ManifestPath(dir), data, 0640); err != nil {
		return fmt.Errorf("write run manifest: %w", err)
	}
	return nil
}
|
||||
|
||||
func LoadFromDir(dir string) (*RunManifest, error) {
|
||||
data, err := fileutil.SecureFileRead(ManifestPath(dir))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read run manifest: %w", err)
|
||||
}
|
||||
var m RunManifest
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil, fmt.Errorf("parse run manifest: %w", err)
|
||||
}
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *RunManifest) MarkStarted(t time.Time) {
|
||||
m.StartedAt = t
|
||||
}
|
||||
|
||||
func (m *RunManifest) MarkFinished(t time.Time, exitCode *int, execErr error) {
|
||||
m.EndedAt = t
|
||||
m.ExitCode = exitCode
|
||||
if execErr != nil {
|
||||
m.Error = execErr.Error()
|
||||
} else {
|
||||
m.Error = ""
|
||||
}
|
||||
if !m.StartedAt.IsZero() {
|
||||
m.TotalDurationMS = m.EndedAt.Sub(m.StartedAt).Milliseconds()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RunManifest) AddAnnotation(ts time.Time, author, note string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
n := strings.TrimSpace(note)
|
||||
if n == "" {
|
||||
return
|
||||
}
|
||||
a := Annotation{
|
||||
Timestamp: ts,
|
||||
Author: strings.TrimSpace(author),
|
||||
Note: n,
|
||||
}
|
||||
m.Annotations = append(m.Annotations, a)
|
||||
}
|
||||
|
||||
func (m *RunManifest) ApplyNarrativePatch(p NarrativePatch) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if m.Narrative == nil {
|
||||
m.Narrative = &Narrative{}
|
||||
}
|
||||
if p.Hypothesis != nil {
|
||||
m.Narrative.Hypothesis = strings.TrimSpace(*p.Hypothesis)
|
||||
}
|
||||
if p.Context != nil {
|
||||
m.Narrative.Context = strings.TrimSpace(*p.Context)
|
||||
}
|
||||
if p.Intent != nil {
|
||||
m.Narrative.Intent = strings.TrimSpace(*p.Intent)
|
||||
}
|
||||
if p.ExpectedOutcome != nil {
|
||||
m.Narrative.ExpectedOutcome = strings.TrimSpace(*p.ExpectedOutcome)
|
||||
}
|
||||
if p.ParentRun != nil {
|
||||
m.Narrative.ParentRun = strings.TrimSpace(*p.ParentRun)
|
||||
}
|
||||
if p.ExperimentGroup != nil {
|
||||
m.Narrative.ExperimentGroup = strings.TrimSpace(*p.ExperimentGroup)
|
||||
}
|
||||
if p.Tags != nil {
|
||||
clean := make([]string, 0, len(*p.Tags))
|
||||
for _, t := range *p.Tags {
|
||||
t = strings.TrimSpace(t)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
clean = append(clean, t)
|
||||
}
|
||||
m.Narrative.Tags = clean
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@ package queue
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -48,6 +50,7 @@ type QueueBackend string
|
|||
const (
|
||||
QueueBackendRedis QueueBackend = "redis"
|
||||
QueueBackendSQLite QueueBackend = "sqlite"
|
||||
QueueBackendFS QueueBackend = "filesystem"
|
||||
)
|
||||
|
||||
type BackendConfig struct {
|
||||
|
|
@ -56,20 +59,49 @@ type BackendConfig struct {
|
|||
RedisPassword string
|
||||
RedisDB int
|
||||
SQLitePath string
|
||||
FilesystemPath string
|
||||
FallbackToFilesystem bool
|
||||
MetricsFlushInterval time.Duration
|
||||
}
|
||||
|
||||
func NewBackend(cfg BackendConfig) (Backend, error) {
|
||||
mkFallback := func(err error) (Backend, error) {
|
||||
if !cfg.FallbackToFilesystem {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(cfg.FilesystemPath) == "" {
|
||||
return nil, fmt.Errorf("filesystem queue path is required for fallback")
|
||||
}
|
||||
fsq, fsErr := NewFilesystemQueue(cfg.FilesystemPath)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("filesystem queue fallback init failed: %w", fsErr)
|
||||
}
|
||||
return fsq, nil
|
||||
}
|
||||
|
||||
switch cfg.Backend {
|
||||
case QueueBackendFS:
|
||||
if strings.TrimSpace(cfg.FilesystemPath) == "" {
|
||||
return nil, fmt.Errorf("filesystem queue path is required")
|
||||
}
|
||||
return NewFilesystemQueue(cfg.FilesystemPath)
|
||||
case "", QueueBackendRedis:
|
||||
return NewTaskQueue(Config{
|
||||
b, err := NewTaskQueue(Config{
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
})
|
||||
if err != nil {
|
||||
return mkFallback(err)
|
||||
}
|
||||
return b, nil
|
||||
case QueueBackendSQLite:
|
||||
return NewSQLiteQueue(cfg.SQLitePath)
|
||||
b, err := NewSQLiteQueue(cfg.SQLitePath)
|
||||
if err != nil {
|
||||
return mkFallback(err)
|
||||
}
|
||||
return b, nil
|
||||
default:
|
||||
return nil, ErrInvalidQueueBackend
|
||||
}
|
||||
|
|
|
|||
572
internal/queue/filesystem_queue.go
Normal file
572
internal/queue/filesystem_queue.go
Normal file
|
|
@ -0,0 +1,572 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FilesystemQueue is a directory-backed task queue. Task JSON files live in
// state subdirectories under root (pending/entries, running, finished,
// failed) and move between them as tasks progress.
type FilesystemQueue struct {
	root   string             // cleaned queue root directory
	ctx    context.Context    // queue lifetime; cancelled by Close
	cancel context.CancelFunc // wakes/stops blocking polls
}

// filesystemQueueIndex is the on-disk JSON index over pending tasks.
type filesystemQueueIndex struct {
	Version   int                        `json:"version"`
	UpdatedAt string                     `json:"updated_at"`
	Tasks     []filesystemQueueIndexTask `json:"tasks"`
}

// filesystemQueueIndexTask is a single pending-task entry in the index.
type filesystemQueueIndexTask struct {
	ID        string `json:"id"`
	Priority  int64  `json:"priority"`
	CreatedAt string `json:"created_at"`
}
|
||||
|
||||
// NewFilesystemQueue creates (or reopens) a filesystem-backed queue rooted at
// root, ensuring the state directory layout exists and rebuilding the pending
// index from entries already on disk.
func NewFilesystemQueue(root string) (*FilesystemQueue, error) {
	root = strings.TrimSpace(root)
	if root == "" {
		return nil, fmt.Errorf("filesystem queue root is required")
	}
	root = filepath.Clean(root)
	// pending/entries holds queued task files.
	if err := os.MkdirAll(filepath.Join(root, "pending", "entries"), 0750); err != nil {
		return nil, err
	}
	for _, d := range []string{"running", "finished", "failed"} {
		if err := os.MkdirAll(filepath.Join(root, d), 0750); err != nil {
			return nil, err
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	q := &FilesystemQueue{root: root, ctx: ctx, cancel: cancel}
	// Best-effort: an index rebuild failure does not prevent opening the queue.
	_ = q.rebuildIndex()
	return q, nil
}
|
||||
|
||||
// Close cancels the queue context, waking any blocking polls.
// It always returns nil.
func (q *FilesystemQueue) Close() error {
	q.cancel()
	return nil
}
|
||||
|
||||
func (q *FilesystemQueue) AddTask(task *Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if strings.TrimSpace(task.ID) == "" {
|
||||
return fmt.Errorf("task id is required")
|
||||
}
|
||||
if strings.TrimSpace(task.JobName) == "" {
|
||||
return fmt.Errorf("job name is required")
|
||||
}
|
||||
if task.MaxRetries == 0 {
|
||||
task.MaxRetries = defaultMaxRetries
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if strings.TrimSpace(task.Status) == "" {
|
||||
task.Status = "queued"
|
||||
}
|
||||
if task.Status != "queued" {
|
||||
// For filesystem backend we only enqueue queued tasks.
|
||||
// Other status updates should go through UpdateTask.
|
||||
task.Status = "queued"
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := q.pendingEntryPath(task.ID)
|
||||
if err := writeFileAtomic(path, payload, 0640); err != nil {
|
||||
return err
|
||||
}
|
||||
TasksQueued.Inc()
|
||||
if depth, derr := q.QueueDepth(); derr == nil {
|
||||
UpdateQueueDepth(depth)
|
||||
}
|
||||
_ = q.rebuildIndex()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextTask claims the next pending task without a worker ID or lease.
func (q *FilesystemQueue) GetNextTask() (*Task, error) {
	return q.claimNext("", 0, false)
}

// PeekNextTask returns the next pending task without claiming it.
func (q *FilesystemQueue) PeekNextTask() (*Task, error) {
	return q.claimNext("", 0, true)
}

// GetNextTaskWithLease claims the next pending task on behalf of workerID
// with the given lease duration.
func (q *FilesystemQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
	return q.claimNext(workerID, leaseDuration, false)
}
|
||||
|
||||
func (q *FilesystemQueue) GetNextTaskWithLeaseBlocking(
|
||||
workerID string,
|
||||
leaseDuration, blockTimeout time.Duration,
|
||||
) (*Task, error) {
|
||||
if blockTimeout <= 0 {
|
||||
blockTimeout = defaultBlockTimeout
|
||||
}
|
||||
deadline := time.Now().Add(blockTimeout)
|
||||
for {
|
||||
t, err := q.claimNext(workerID, leaseDuration, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t != nil {
|
||||
return t, nil
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return nil, nil
|
||||
}
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return nil, q.ctx.Err()
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RenewLease extends a task's lease expiry by leaseDuration (defaulted when
// zero). It refuses to renew a lease held by a different worker. A non-empty
// workerID also takes (or re-asserts) ownership of the lease.
func (q *FilesystemQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
	// Single-worker friendly best-effort: update task lease fields if present.
	t, err := q.GetTask(taskID)
	if err != nil {
		return err
	}
	if t.LeasedBy != "" && workerID != "" && t.LeasedBy != workerID {
		return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
	}
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	exp := time.Now().UTC().Add(leaseDuration)
	t.LeaseExpiry = &exp
	if workerID != "" {
		t.LeasedBy = workerID
	}
	// NOTE(review): the renewal metric is recorded before UpdateTask, so a
	// failed persist still counts as a renewal — confirm this is intended.
	RecordLeaseRenewal(workerID)
	return q.UpdateTask(t)
}
|
||||
|
||||
func (q *FilesystemQueue) ReleaseLease(taskID string, workerID string) error {
|
||||
t, err := q.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t.LeasedBy != "" && workerID != "" && t.LeasedBy != workerID {
|
||||
return fmt.Errorf("task leased by different worker: %s", t.LeasedBy)
|
||||
}
|
||||
t.LeaseExpiry = nil
|
||||
t.LeasedBy = ""
|
||||
return q.UpdateTask(t)
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) RetryTask(task *Task) error {
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
RecordDLQAddition("max_retries")
|
||||
return q.MoveToDeadLetterQueue(task, "max retries exceeded")
|
||||
}
|
||||
|
||||
errorCategory := ErrorUnknown
|
||||
if task.Error != "" {
|
||||
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
|
||||
}
|
||||
if !IsRetryable(errorCategory) {
|
||||
RecordDLQAddition(string(errorCategory))
|
||||
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
|
||||
}
|
||||
|
||||
task.RetryCount++
|
||||
task.Status = "queued"
|
||||
task.LastError = task.Error
|
||||
task.Error = ""
|
||||
|
||||
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
|
||||
nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
|
||||
task.NextRetry = &nextRetry
|
||||
task.LeaseExpiry = nil
|
||||
task.LeasedBy = ""
|
||||
|
||||
RecordTaskRetry(task.JobName, errorCategory)
|
||||
return q.AddTask(task)
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
|
||||
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
|
||||
return q.UpdateTask(task)
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) GetTask(taskID string) (*Task, error) {
|
||||
path, err := q.findTaskPath(taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var t Task
|
||||
if err := json.Unmarshal(data, &t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) GetAllTasks() ([]*Task, error) {
|
||||
paths := make([]string, 0, 32)
|
||||
for _, p := range []string{
|
||||
filepath.Join(q.root, "pending", "entries"),
|
||||
filepath.Join(q.root, "running"),
|
||||
filepath.Join(q.root, "finished"),
|
||||
filepath.Join(q.root, "failed"),
|
||||
} {
|
||||
entries, err := os.ReadDir(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(e.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
paths = append(paths, filepath.Join(p, e.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]*Task, 0, len(paths))
|
||||
for _, path := range paths {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var t Task
|
||||
if err := json.Unmarshal(data, &t); err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, &t)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) GetTaskByName(jobName string) (*Task, error) {
|
||||
jobName = strings.TrimSpace(jobName)
|
||||
if jobName == "" {
|
||||
return nil, fmt.Errorf("job name is required")
|
||||
}
|
||||
tasks, err := q.GetAllTasks()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var best *Task
|
||||
for _, t := range tasks {
|
||||
if t == nil || t.JobName != jobName {
|
||||
continue
|
||||
}
|
||||
if best == nil || t.CreatedAt.After(best.CreatedAt) {
|
||||
best = t
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
return best, nil
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) CancelTask(taskID string) error {
|
||||
t, err := q.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Status = "cancelled"
|
||||
now := time.Now().UTC()
|
||||
t.EndedAt = &now
|
||||
return q.UpdateTask(t)
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) UpdateTask(task *Task) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
if strings.TrimSpace(task.ID) == "" {
|
||||
return fmt.Errorf("task id is required")
|
||||
}
|
||||
if strings.TrimSpace(task.Status) == "" {
|
||||
return fmt.Errorf("task status is required")
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst := q.pathForStatus(task.Status, task.ID)
|
||||
if dst == "" {
|
||||
// For statuses we don't map yet, keep it in running.
|
||||
dst = filepath.Join(q.root, "running", task.ID+".json")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Best-effort: remove any other copies before writing.
|
||||
_ = q.removeTaskFromAllDirs(task.ID)
|
||||
if err := writeFileAtomic(dst, payload, 0640); err != nil {
|
||||
return err
|
||||
}
|
||||
if depth, derr := q.QueueDepth(); derr == nil {
|
||||
UpdateQueueDepth(depth)
|
||||
}
|
||||
_ = q.rebuildIndex()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTaskWithMetrics persists the task exactly like UpdateTask; the
// metrics payload argument is ignored by the filesystem backend.
func (q *FilesystemQueue) UpdateTaskWithMetrics(task *Task, _ string) error {
	return q.UpdateTask(task)
}
|
||||
|
||||
// RecordMetric is a no-op for the filesystem backend: metric samples are
// not persisted to disk. It exists to satisfy the queue backend interface.
func (q *FilesystemQueue) RecordMetric(_, _ string, _ float64) error {
	return nil
}
|
||||
|
||||
// Heartbeat is a no-op for the filesystem backend; worker liveness is not
// tracked on disk. It exists to satisfy the queue backend interface.
func (q *FilesystemQueue) Heartbeat(_ string) error {
	return nil
}
|
||||
|
||||
func (q *FilesystemQueue) QueueDepth() (int64, error) {
|
||||
entries, err := os.ReadDir(filepath.Join(q.root, "pending", "entries"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var n int64
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(e.Name(), ".json") {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// The filesystem backend does not persist prewarm state; the methods below
// satisfy the queue backend interface with no-ops so callers can use this
// backend interchangeably with ones that do track prewarm state.

// SetWorkerPrewarmState is a no-op for the filesystem backend.
func (q *FilesystemQueue) SetWorkerPrewarmState(_ PrewarmState) error { return nil }

// ClearWorkerPrewarmState is a no-op for the filesystem backend.
func (q *FilesystemQueue) ClearWorkerPrewarmState(_ string) error { return nil }

// GetWorkerPrewarmState always reports no recorded state (nil, nil).
func (q *FilesystemQueue) GetWorkerPrewarmState(_ string) (*PrewarmState, error) {
	return nil, nil
}

// GetAllWorkerPrewarmStates always reports no recorded states (nil, nil).
func (q *FilesystemQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	return nil, nil
}

// SignalPrewarmGC is a no-op for the filesystem backend.
func (q *FilesystemQueue) SignalPrewarmGC() error { return nil }

// PrewarmGCRequestValue always reports an empty request value.
func (q *FilesystemQueue) PrewarmGCRequestValue() (string, error) {
	return "", nil
}
|
||||
|
||||
// claimNext scans the pending directory for the highest-priority claimable
// task and, unless peek is true, atomically claims it by renaming its file
// into the running directory and stamping a lease.
//
// Claim safety relies on os.Rename: if another process renames the file
// first, this process's rename fails with ErrNotExist and (nil, nil) is
// returned. (nil, nil) also means "nothing claimable right now".
func (q *FilesystemQueue) claimNext(workerID string, leaseDuration time.Duration, peek bool) (*Task, error) {
	pendingDir := filepath.Join(q.root, "pending", "entries")
	entries, err := os.ReadDir(pendingDir)
	if err != nil {
		return nil, err
	}

	// Collect claimable candidates; remember each task's file path by ID.
	candidates := make([]*Task, 0, len(entries))
	paths := make(map[string]string, len(entries))
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		if !strings.HasSuffix(e.Name(), ".json") {
			continue
		}
		path := filepath.Join(pendingDir, e.Name())
		data, err := os.ReadFile(path)
		if err != nil {
			// Entry may have been claimed/removed since ReadDir; skip.
			continue
		}
		var t Task
		if err := json.Unmarshal(data, &t); err != nil {
			continue
		}
		// Skip tasks whose retry backoff has not elapsed yet.
		if t.NextRetry != nil && time.Now().UTC().Before(t.NextRetry.UTC()) {
			continue
		}
		candidates = append(candidates, &t)
		paths[t.ID] = path
	}

	if len(candidates) == 0 {
		return nil, nil
	}

	// Highest priority first; ties broken by oldest creation time (FIFO).
	sort.Slice(candidates, func(i, j int) bool {
		if candidates[i].Priority != candidates[j].Priority {
			return candidates[i].Priority > candidates[j].Priority
		}
		return candidates[i].CreatedAt.Before(candidates[j].CreatedAt)
	})

	chosen := candidates[0]
	if peek {
		// Peek mode: report the would-be claim without taking it.
		return chosen, nil
	}

	src := paths[chosen.ID]
	if src == "" {
		return nil, nil
	}
	// Rename is the claim: only one process can win this rename.
	dst := filepath.Join(q.root, "running", chosen.ID+".json")
	if err := os.Rename(src, dst); err != nil {
		// Another process might have claimed it.
		if errors.Is(err, os.ErrNotExist) {
			return nil, nil
		}
		return nil, err
	}

	// Refresh from the moved file to avoid race on content.
	data, err := os.ReadFile(dst)
	if err != nil {
		return nil, err
	}
	var t Task
	if err := json.Unmarshal(data, &t); err != nil {
		return nil, err
	}

	// Stamp the lease on the freshly claimed task.
	now := time.Now().UTC()
	if leaseDuration == 0 {
		leaseDuration = defaultLeaseDuration
	}
	exp := now.Add(leaseDuration)
	t.LeaseExpiry = &exp
	if strings.TrimSpace(workerID) != "" {
		t.LeasedBy = workerID
	}
	// Note: status transitions are handled by worker UpdateTask calls.

	// Persist the lease stamp back; best-effort (claim already succeeded).
	payload, err := json.Marshal(&t)
	if err == nil {
		_ = writeFileAtomic(dst, payload, 0640)
	}
	if depth, derr := q.QueueDepth(); derr == nil {
		UpdateQueueDepth(depth)
	}
	// Index rebuild is best-effort; the task files are the source of truth.
	_ = q.rebuildIndex()
	return &t, nil
}
|
||||
|
||||
// pendingEntryPath returns the on-disk path of a pending task's JSON file.
func (q *FilesystemQueue) pendingEntryPath(taskID string) string {
	return filepath.Join(q.root, "pending", "entries", taskID+".json")
}
|
||||
|
||||
func (q *FilesystemQueue) pathForStatus(status, taskID string) string {
|
||||
switch status {
|
||||
case "queued":
|
||||
return q.pendingEntryPath(taskID)
|
||||
case "running":
|
||||
return filepath.Join(q.root, "running", taskID+".json")
|
||||
case "completed", "finished":
|
||||
return filepath.Join(q.root, "finished", taskID+".json")
|
||||
case "failed", "cancelled":
|
||||
return filepath.Join(q.root, "failed", taskID+".json")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) findTaskPath(taskID string) (string, error) {
|
||||
paths := []string{
|
||||
q.pendingEntryPath(taskID),
|
||||
filepath.Join(q.root, "running", taskID+".json"),
|
||||
filepath.Join(q.root, "finished", taskID+".json"),
|
||||
filepath.Join(q.root, "failed", taskID+".json"),
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) removeTaskFromAllDirs(taskID string) error {
|
||||
paths := []string{
|
||||
q.pendingEntryPath(taskID),
|
||||
filepath.Join(q.root, "running", taskID+".json"),
|
||||
filepath.Join(q.root, "finished", taskID+".json"),
|
||||
filepath.Join(q.root, "failed", taskID+".json"),
|
||||
}
|
||||
var outErr error
|
||||
for _, p := range paths {
|
||||
if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
outErr = err
|
||||
}
|
||||
}
|
||||
return outErr
|
||||
}
|
||||
|
||||
func (q *FilesystemQueue) rebuildIndex() error {
|
||||
pendingDir := filepath.Join(q.root, "pending", "entries")
|
||||
entries, err := os.ReadDir(pendingDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idx := filesystemQueueIndex{Version: 1, UpdatedAt: time.Now().UTC().Format(time.RFC3339)}
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(pendingDir, e.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var t Task
|
||||
if err := json.Unmarshal(data, &t); err != nil {
|
||||
continue
|
||||
}
|
||||
idx.Tasks = append(idx.Tasks, filesystemQueueIndexTask{ID: t.ID, Priority: t.Priority, CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339Nano)})
|
||||
}
|
||||
|
||||
sort.Slice(idx.Tasks, func(i, j int) bool {
|
||||
if idx.Tasks[i].Priority != idx.Tasks[j].Priority {
|
||||
return idx.Tasks[i].Priority > idx.Tasks[j].Priority
|
||||
}
|
||||
return idx.Tasks[i].CreatedAt < idx.Tasks[j].CreatedAt
|
||||
})
|
||||
|
||||
payload, err := json.MarshalIndent(&idx, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path := filepath.Join(q.root, "pending", ".queue.json")
|
||||
return writeFileAtomic(path, payload, 0640)
|
||||
}
|
||||
|
||||
func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, perm); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmp, path)
|
||||
}
|
||||
|
|
@ -23,8 +23,10 @@ const (
|
|||
)
|
||||
|
||||
type QueueConfig struct {
|
||||
Backend string `yaml:"backend"`
|
||||
SQLitePath string `yaml:"sqlite_path"`
|
||||
Backend string `yaml:"backend"`
|
||||
SQLitePath string `yaml:"sqlite_path"`
|
||||
FilesystemPath string `yaml:"filesystem_path"`
|
||||
FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
|
||||
}
|
||||
|
||||
// Config holds worker configuration.
|
||||
|
|
@ -203,6 +205,12 @@ func LoadConfig(path string) (*Config, error) {
|
|||
}
|
||||
cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem {
|
||||
if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" {
|
||||
cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs")
|
||||
}
|
||||
cfg.Queue.FilesystemPath = config.ExpandPath(cfg.Queue.FilesystemPath)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.GPUVendor) == "" {
|
||||
if cfg.AppleGPU.Enabled {
|
||||
|
|
@ -254,8 +262,8 @@ func (c *Config) Validate() error {
|
|||
backend = string(queue.QueueBackendRedis)
|
||||
c.Queue.Backend = backend
|
||||
}
|
||||
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) {
|
||||
return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite)
|
||||
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) {
|
||||
return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS)
|
||||
}
|
||||
|
||||
if backend == string(queue.QueueBackendSQLite) {
|
||||
|
|
@ -267,6 +275,15 @@ func (c *Config) Validate() error {
|
|||
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
|
||||
}
|
||||
}
|
||||
if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem {
|
||||
if strings.TrimSpace(c.Queue.FilesystemPath) == "" {
|
||||
return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled")
|
||||
}
|
||||
c.Queue.FilesystemPath = config.ExpandPath(c.Queue.FilesystemPath)
|
||||
if !filepath.IsAbs(c.Queue.FilesystemPath) {
|
||||
c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath)
|
||||
}
|
||||
}
|
||||
|
||||
if c.RedisAddr != "" {
|
||||
addr := strings.TrimSpace(c.RedisAddr)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ type JupyterManager interface {
|
|||
RemoveService(ctx context.Context, serviceID string, purge bool) error
|
||||
RestoreWorkspace(ctx context.Context, name string) (string, error)
|
||||
ListServices() []*jupyter.JupyterService
|
||||
ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
|
||||
}
|
||||
|
||||
// isValidName validates that input strings contain only safe characters.
|
||||
|
|
@ -382,6 +383,8 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
SQLitePath: cfg.Queue.SQLitePath,
|
||||
FilesystemPath: cfg.Queue.FilesystemPath,
|
||||
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
}
|
||||
queueClient, err := queue.NewBackend(backendCfg)
|
||||
|
|
|
|||
|
|
@ -253,6 +253,11 @@ func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevice
|
|||
|
||||
if err := w.stageExperimentFiles(task, jobDir); err != nil {
|
||||
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
|
||||
if a, aerr := scanArtifacts(jobDir); aerr == nil {
|
||||
m.Artifacts = a
|
||||
} else {
|
||||
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
exitCode := 1
|
||||
m.MarkFinished(now, &exitCode, err)
|
||||
|
|
@ -271,6 +276,11 @@ func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevice
|
|||
}
|
||||
if err := w.stageSnapshot(ctx, task, jobDir); err != nil {
|
||||
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
|
||||
if a, aerr := scanArtifacts(jobDir); aerr == nil {
|
||||
m.Artifacts = a
|
||||
} else {
|
||||
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
exitCode := 1
|
||||
m.MarkFinished(now, &exitCode, err)
|
||||
|
|
@ -586,6 +596,11 @@ func (w *Worker) executeJob(
|
|||
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
|
||||
now := time.Now().UTC()
|
||||
m.ExecutionDurationMS = execDuration.Milliseconds()
|
||||
if a, aerr := scanArtifacts(outputDir); aerr == nil {
|
||||
m.Artifacts = a
|
||||
} else {
|
||||
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
|
||||
}
|
||||
if err != nil {
|
||||
exitCode := 1
|
||||
m.MarkFinished(now, &exitCode, err)
|
||||
|
|
@ -832,6 +847,35 @@ func (w *Worker) executeContainerJob(
|
|||
if trackingEnv == nil {
|
||||
trackingEnv = make(map[string]string)
|
||||
}
|
||||
cacheRoot := filepath.Join(w.config.BasePath, ".cache")
|
||||
if err := os.MkdirAll(cacheRoot, 0755); err != nil {
|
||||
return &errtypes.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "cache_setup",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
if volumes == nil {
|
||||
volumes = make(map[string]string)
|
||||
}
|
||||
volumes[cacheRoot] = "/workspace/.cache:rw"
|
||||
defaultEnv := map[string]string{
|
||||
"HF_HOME": "/workspace/.cache/huggingface",
|
||||
"TRANSFORMERS_CACHE": "/workspace/.cache/huggingface/hub",
|
||||
"HF_DATASETS_CACHE": "/workspace/.cache/huggingface/datasets",
|
||||
"TORCH_HOME": "/workspace/.cache/torch",
|
||||
"TORCH_HUB_DIR": "/workspace/.cache/torch/hub",
|
||||
"KERAS_HOME": "/workspace/.cache/keras",
|
||||
"CUDA_CACHE_PATH": "/workspace/.cache/cuda",
|
||||
"PIP_CACHE_DIR": "/workspace/.cache/pip",
|
||||
}
|
||||
for k, v := range defaultEnv {
|
||||
if _, ok := trackingEnv[k]; ok {
|
||||
continue
|
||||
}
|
||||
trackingEnv[k] = v
|
||||
}
|
||||
if strings.TrimSpace(visibleEnvVar) != "" {
|
||||
trackingEnv[visibleEnvVar] = strings.TrimSpace(visibleDevices)
|
||||
}
|
||||
|
|
@ -841,9 +885,6 @@ func (w *Worker) executeContainerJob(
|
|||
if strings.TrimSpace(task.SnapshotID) != "" {
|
||||
trackingEnv["FETCH_ML_SNAPSHOT_ID"] = strings.TrimSpace(task.SnapshotID)
|
||||
}
|
||||
if volumes == nil {
|
||||
volumes = make(map[string]string)
|
||||
}
|
||||
volumes[snap] = "/snapshot:ro"
|
||||
}
|
||||
|
||||
|
|
@ -932,6 +973,11 @@ func (w *Worker) executeContainerJob(
|
|||
now := time.Now().UTC()
|
||||
exitCode := 1
|
||||
m.ExecutionDurationMS = containerDuration.Milliseconds()
|
||||
if a, aerr := scanArtifacts(outputDir); aerr == nil {
|
||||
m.Artifacts = a
|
||||
} else {
|
||||
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
|
||||
}
|
||||
m.MarkFinished(now, &exitCode, err)
|
||||
})
|
||||
// Move job to failed directory
|
||||
|
|
@ -981,6 +1027,11 @@ func (w *Worker) executeContainerJob(
|
|||
now := time.Now().UTC()
|
||||
exitCode := 0
|
||||
m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds()
|
||||
if a, aerr := scanArtifacts(outputDir); aerr == nil {
|
||||
m.Artifacts = a
|
||||
} else {
|
||||
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
|
||||
}
|
||||
m.MarkFinished(now, &exitCode, nil)
|
||||
})
|
||||
if _, moveErr := telemetry.ExecWithMetrics(
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const (
|
|||
jupyterActionRemove = "remove"
|
||||
jupyterActionRestore = "restore"
|
||||
jupyterActionList = "list"
|
||||
jupyterActionListPkgs = "list_packages"
|
||||
jupyterNameKey = "jupyter_name"
|
||||
jupyterWorkspaceKey = "jupyter_workspace"
|
||||
jupyterServiceIDKey = "jupyter_service_id"
|
||||
|
|
@ -28,10 +29,11 @@ const (
|
|||
)
|
||||
|
||||
type jupyterTaskOutput struct {
|
||||
Type string `json:"type"`
|
||||
Service *jupyter.JupyterService `json:"service,omitempty"`
|
||||
Services []*jupyter.JupyterService `json:"services"`
|
||||
RestorePath string `json:"restore_path,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Service *jupyter.JupyterService `json:"service,omitempty"`
|
||||
Services []*jupyter.JupyterService `json:"services"`
|
||||
Packages []jupyter.InstalledPackage `json:"packages,omitempty"`
|
||||
RestorePath string `json:"restore_path,omitempty"`
|
||||
}
|
||||
|
||||
func isJupyterTask(task *queue.Task) bool {
|
||||
|
|
@ -109,6 +111,17 @@ func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
services := w.jupyter.ListServices()
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionListPkgs:
|
||||
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter name")
|
||||
}
|
||||
pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Packages: pkgs}
|
||||
return json.Marshal(out)
|
||||
case jupyterActionRestore:
|
||||
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
||||
if name == "" {
|
||||
|
|
|
|||
Loading…
Reference in a new issue