fetch_ml/internal/worker/jupyter_task.go
Jeremie Fraeys 2e701340e5
feat(core): API, worker, queue, and manifest improvements
- Add protocol buffer optimizations (internal/api/protocol.go)
- Add filesystem queue backend (internal/queue/filesystem_queue.go)
- Add run manifest support (internal/manifest/run_manifest.go)
- Worker and jupyter task refinements
- Exported test wrappers for benchmarking
2026-02-12 12:05:17 -05:00

143 lines
4.4 KiB
Go

package worker
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// Metadata entry that marks a queue task as a Jupyter task.
const (
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter"
)

// Metadata entry selecting the Jupyter action, and the actions understood by
// runJupyterTask.
const (
	jupyterTaskActionKey = "jupyter_action"

	jupyterActionStart    = "start"
	jupyterActionStop     = "stop"
	jupyterActionRemove   = "remove"
	jupyterActionRestore  = "restore"
	jupyterActionList     = "list"
	jupyterActionListPkgs = "list_packages"
)

// Metadata keys carrying the arguments of the individual actions.
const (
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"
)

// jupyterTaskOutputType is the value of the "type" field in the JSON output
// produced for Jupyter tasks.
const jupyterTaskOutputType = "jupyter_output"
// jupyterTaskOutput is the JSON payload returned by runJupyterTask.
// Which optional field is populated depends on the action that was run.
type jupyterTaskOutput struct {
// Type is set to jupyterTaskOutputType by every action.
Type string `json:"type"`
// Service is the service returned by the start action.
Service *jupyter.JupyterService `json:"service,omitempty"`
// Services is the result of the list action. The tag has no omitempty, so
// the "services" key is always encoded; a nil slice is encoded as null.
Services []*jupyter.JupyterService `json:"services"`
// Packages is the result of the list_packages action.
Packages []jupyter.InstalledPackage `json:"packages,omitempty"`
// RestorePath is the path returned by the restore action.
RestorePath string `json:"restore_path,omitempty"`
}
// isJupyterTask reports whether task is marked as a Jupyter task, i.e. its
// metadata value under jupyterTaskTypeKey equals jupyterTaskTypeValue once
// surrounding whitespace is trimmed. A nil task or nil metadata yields false.
func isJupyterTask(task *queue.Task) bool {
	if task == nil {
		return false
	}
	// Reading a nil map yields "", which can never equal the non-empty
	// type value, so missing metadata needs no separate check.
	kind := strings.TrimSpace(task.Metadata[jupyterTaskTypeKey])
	return kind == jupyterTaskTypeValue
}
// runJupyterTask executes the Jupyter management action described by the
// task's metadata and returns the JSON-encoded jupyterTaskOutput.
//
// The action is read from Metadata[jupyterTaskActionKey] (trimmed and
// lower-cased). Each action reads its own arguments from the metadata:
//
//   - start:         jupyter_name and jupyter_workspace
//   - stop:          jupyter_service_id
//   - remove:        jupyter_service_id, plus optional jupyter_purge
//     (case-insensitive "true" enables purging)
//   - list:          no arguments
//   - list_packages: jupyter_name
//   - restore:       jupyter_name
//
// An error is returned when the worker, the task, the jupyter manager or the
// metadata is nil, when the action or a required argument is empty, when the
// job name is rejected by container.ValidateJobName, when the action is not
// one of the above, or when the jupyter manager call fails. Manager failures
// are wrapped with %w together with the action and its target, so callers can
// tell which operation failed while errors.Is / errors.As still reach the
// underlying cause.
//
// The action runs under a context derived from ctx with a two-minute timeout.
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	const (
		// actionTimeout bounds a single Jupyter action.
		actionTimeout = 2 * time.Minute
		// purgeKey is the optional metadata flag read by the remove action.
		purgeKey = "jupyter_purge"
	)

	if w == nil {
		return nil, fmt.Errorf("worker is nil")
	}
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	if w.jupyter == nil {
		return nil, fmt.Errorf("jupyter manager not configured")
	}
	if task.Metadata == nil {
		return nil, fmt.Errorf("missing task metadata")
	}

	// meta returns the whitespace-trimmed metadata value for key.
	meta := func(key string) string {
		return strings.TrimSpace(task.Metadata[key])
	}
	// encode stamps the output type and serializes the result.
	encode := func(out jupyterTaskOutput) ([]byte, error) {
		out.Type = jupyterTaskOutputType
		return json.Marshal(out)
	}

	action := strings.ToLower(meta(jupyterTaskActionKey))
	if action == "" {
		return nil, fmt.Errorf("missing jupyter action")
	}

	// Validate job name since it is used as the task status key and shows up in logs.
	if err := container.ValidateJobName(task.JobName); err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(ctx, actionTimeout)
	defer cancel()

	switch action {
	case jupyterActionStart:
		name := meta(jupyterNameKey)
		ws := meta(jupyterWorkspaceKey)
		if name == "" {
			return nil, fmt.Errorf("missing jupyter name")
		}
		if ws == "" {
			return nil, fmt.Errorf("missing jupyter workspace")
		}
		service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
		if err != nil {
			return nil, fmt.Errorf("starting jupyter service %q: %w", name, err)
		}
		return encode(jupyterTaskOutput{Service: service})

	case jupyterActionStop:
		serviceID := meta(jupyterServiceIDKey)
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter service id")
		}
		if err := w.jupyter.StopService(ctx, serviceID); err != nil {
			return nil, fmt.Errorf("stopping jupyter service %q: %w", serviceID, err)
		}
		return encode(jupyterTaskOutput{})

	case jupyterActionRemove:
		serviceID := meta(jupyterServiceIDKey)
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter service id")
		}
		purge := strings.EqualFold(meta(purgeKey), "true")
		if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
			return nil, fmt.Errorf("removing jupyter service %q: %w", serviceID, err)
		}
		return encode(jupyterTaskOutput{})

	case jupyterActionList:
		return encode(jupyterTaskOutput{Services: w.jupyter.ListServices()})

	case jupyterActionListPkgs:
		name := meta(jupyterNameKey)
		if name == "" {
			return nil, fmt.Errorf("missing jupyter name")
		}
		pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("listing packages for jupyter service %q: %w", name, err)
		}
		return encode(jupyterTaskOutput{Packages: pkgs})

	case jupyterActionRestore:
		name := meta(jupyterNameKey)
		if name == "" {
			return nil, fmt.Errorf("missing jupyter name")
		}
		restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("restoring jupyter workspace %q: %w", name, err)
		}
		return encode(jupyterTaskOutput{RestorePath: restoredPath})

	default:
		return nil, fmt.Errorf("invalid jupyter action: %s", action)
	}
}
// RunJupyterTask is the exported form of runJupyterTask. It forwards ctx and
// task unchanged and returns the encoded output and error as-is.
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	payload, err := w.runJupyterTask(ctx, task)
	return payload, err
}