- Add protocol buffer optimizations (internal/api/protocol.go) - Add filesystem queue backend (internal/queue/filesystem_queue.go) - Add run manifest support (internal/manifest/run_manifest.go) - Refine worker and Jupyter task handling - Export test wrappers for benchmarking
143 lines
4.4 KiB
Go
143 lines
4.4 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
)
|
|
|
|
// Metadata keys and values used to tag and route Jupyter management tasks.
const (
	// A task is treated as a Jupyter task when its metadata entry for
	// jupyterTaskTypeKey (whitespace-trimmed) equals jupyterTaskTypeValue.
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter"

	// jupyterTaskActionKey names the metadata entry selecting which
	// jupyterAction* operation to perform.
	jupyterTaskActionKey = "jupyter_action"
)

// Supported values for the jupyterTaskActionKey metadata entry.
const (
	jupyterActionStart    = "start"
	jupyterActionStop     = "stop"
	jupyterActionRemove   = "remove"
	jupyterActionRestore  = "restore"
	jupyterActionList     = "list"
	jupyterActionListPkgs = "list_packages"
)

// Metadata keys carrying per-action parameters, plus the Type tag placed
// on every jupyterTaskOutput payload.
const (
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"

	jupyterTaskOutputType = "jupyter_output"
)
|
|
|
|
type jupyterTaskOutput struct {
|
|
Type string `json:"type"`
|
|
Service *jupyter.JupyterService `json:"service,omitempty"`
|
|
Services []*jupyter.JupyterService `json:"services"`
|
|
Packages []jupyter.InstalledPackage `json:"packages,omitempty"`
|
|
RestorePath string `json:"restore_path,omitempty"`
|
|
}
|
|
|
|
func isJupyterTask(task *queue.Task) bool {
|
|
if task == nil || task.Metadata == nil {
|
|
return false
|
|
}
|
|
return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue
|
|
}
|
|
|
|
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
|
if w == nil {
|
|
return nil, fmt.Errorf("worker is nil")
|
|
}
|
|
if task == nil {
|
|
return nil, fmt.Errorf("task is nil")
|
|
}
|
|
if w.jupyter == nil {
|
|
return nil, fmt.Errorf("jupyter manager not configured")
|
|
}
|
|
if task.Metadata == nil {
|
|
return nil, fmt.Errorf("missing task metadata")
|
|
}
|
|
|
|
action := strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey]))
|
|
if action == "" {
|
|
return nil, fmt.Errorf("missing jupyter action")
|
|
}
|
|
|
|
// Validate job name since it is used as the task status key and shows up in logs.
|
|
if err := container.ValidateJobName(task.JobName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
|
|
defer cancel()
|
|
|
|
switch action {
|
|
case jupyterActionStart:
|
|
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
|
ws := strings.TrimSpace(task.Metadata[jupyterWorkspaceKey])
|
|
if name == "" {
|
|
return nil, fmt.Errorf("missing jupyter name")
|
|
}
|
|
if ws == "" {
|
|
return nil, fmt.Errorf("missing jupyter workspace")
|
|
}
|
|
service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Service: service}
|
|
return json.Marshal(out)
|
|
case jupyterActionStop:
|
|
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
|
|
if serviceID == "" {
|
|
return nil, fmt.Errorf("missing jupyter service id")
|
|
}
|
|
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
|
|
return nil, err
|
|
}
|
|
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
|
|
return json.Marshal(out)
|
|
case jupyterActionRemove:
|
|
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
|
|
if serviceID == "" {
|
|
return nil, fmt.Errorf("missing jupyter service id")
|
|
}
|
|
purge := strings.EqualFold(strings.TrimSpace(task.Metadata["jupyter_purge"]), "true")
|
|
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
|
return nil, err
|
|
}
|
|
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
|
|
return json.Marshal(out)
|
|
case jupyterActionList:
|
|
services := w.jupyter.ListServices()
|
|
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
|
|
return json.Marshal(out)
|
|
case jupyterActionListPkgs:
|
|
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
|
if name == "" {
|
|
return nil, fmt.Errorf("missing jupyter name")
|
|
}
|
|
pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Packages: pkgs}
|
|
return json.Marshal(out)
|
|
case jupyterActionRestore:
|
|
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
|
|
if name == "" {
|
|
return nil, fmt.Errorf("missing jupyter name")
|
|
}
|
|
restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := jupyterTaskOutput{Type: jupyterTaskOutputType, RestorePath: restoredPath}
|
|
return json.Marshal(out)
|
|
default:
|
|
return nil, fmt.Errorf("invalid jupyter action: %s", action)
|
|
}
|
|
}
|
|
|
|
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
|
return w.runJupyterTask(ctx, task)
|
|
}
|