package worker

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/queue"
)

// Task metadata keys and values that mark a task as a Jupyter task and select
// the action it performs, plus the type tag written into every task output.
const (
	jupyterTaskTypeKey    = "task_type"
	jupyterTaskTypeValue  = "jupyter"
	jupyterTaskActionKey  = "jupyter_action"
	jupyterActionStart    = "start"
	jupyterActionStop     = "stop"
	jupyterActionRemove   = "remove"
	jupyterActionRestore  = "restore"
	jupyterActionList     = "list"
	jupyterActionListPkgs = "list_packages"
	jupyterNameKey        = "jupyter_name"
	jupyterWorkspaceKey   = "jupyter_workspace"
	jupyterServiceIDKey   = "jupyter_service_id"
	jupyterPurgeKey       = "jupyter_purge"
	jupyterTaskOutputType = "jupyter_output"
)

// jupyterTaskOutput is the JSON payload returned by runJupyterTask. Each action
// fills in only the field relevant to it. Services carries no omitempty tag, so
// the "services" key is always present and encodes as null when the slice is
// nil.
type jupyterTaskOutput struct {
	Type        string                     `json:"type"`
	Service     *jupyter.JupyterService    `json:"service,omitempty"`
	Services    []*jupyter.JupyterService  `json:"services"`
	Packages    []jupyter.InstalledPackage `json:"packages,omitempty"`
	RestorePath string                     `json:"restore_path,omitempty"`
}

// isJupyterTask reports whether task carries metadata whose task type, after
// trimming whitespace, equals the Jupyter task type. A nil task or nil metadata
// yields false.
func isJupyterTask(task *queue.Task) bool {
	if task == nil || task.Metadata == nil {
		return false
	}
	return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue
}

// requiredJupyterMetadata returns the whitespace-trimmed metadata value stored
// under key. It returns a "missing <label>" error when that value is empty.
// The caller must have checked that task and task.Metadata are usable.
func requiredJupyterMetadata(task *queue.Task, key, label string) (string, error) {
	value := strings.TrimSpace(task.Metadata[key])
	if value == "" {
		return "", fmt.Errorf("missing %s", label)
	}
	return value, nil
}

// marshalJupyterOutput stamps out with the Jupyter output type and encodes it
// as JSON.
func marshalJupyterOutput(out jupyterTaskOutput) ([]byte, error) {
	out.Type = jupyterTaskOutputType
	return json.Marshal(out)
}

// runJupyterTask executes the Jupyter action named in task.Metadata and returns
// the JSON-encoded jupyterTaskOutput.
//
// It returns an error when the worker, the task, the Jupyter manager or the
// task metadata is nil, when the action is missing or unknown, when the job
// name fails validation, when a metadata value required by the action is
// empty, or when the Jupyter manager call fails. Manager errors are wrapped
// with %w, so errors.Is and errors.As still see the underlying cause.
//
// The action is matched case-insensitively after trimming whitespace. All
// manager calls that accept a context share a single two-minute timeout
// derived from ctx.
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	if w == nil {
		return nil, fmt.Errorf("worker is nil")
	}
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	if w.jupyter == nil {
		return nil, fmt.Errorf("jupyter manager not configured")
	}
	if task.Metadata == nil {
		return nil, fmt.Errorf("missing task metadata")
	}

	action := strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey]))
	if action == "" {
		return nil, fmt.Errorf("missing jupyter action")
	}

	// Validate job name since it is used as the task status key and shows up in logs.
	if err := container.ValidateJobName(task.JobName); err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
	defer cancel()

	switch action {
	case jupyterActionStart:
		name, err := requiredJupyterMetadata(task, jupyterNameKey, "jupyter name")
		if err != nil {
			return nil, err
		}
		ws, err := requiredJupyterMetadata(task, jupyterWorkspaceKey, "jupyter workspace")
		if err != nil {
			return nil, err
		}
		service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
		if err != nil {
			return nil, fmt.Errorf("starting jupyter service %q: %w", name, err)
		}
		return marshalJupyterOutput(jupyterTaskOutput{Service: service})

	case jupyterActionStop:
		serviceID, err := requiredJupyterMetadata(task, jupyterServiceIDKey, "jupyter service id")
		if err != nil {
			return nil, err
		}
		if err := w.jupyter.StopService(ctx, serviceID); err != nil {
			return nil, fmt.Errorf("stopping jupyter service %q: %w", serviceID, err)
		}
		return marshalJupyterOutput(jupyterTaskOutput{})

	case jupyterActionRemove:
		serviceID, err := requiredJupyterMetadata(task, jupyterServiceIDKey, "jupyter service id")
		if err != nil {
			return nil, err
		}
		// Only a case-insensitive "true" enables purging; any other value,
		// including an absent key, leaves it disabled.
		purge := strings.EqualFold(strings.TrimSpace(task.Metadata[jupyterPurgeKey]), "true")
		if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
			return nil, fmt.Errorf("removing jupyter service %q: %w", serviceID, err)
		}
		return marshalJupyterOutput(jupyterTaskOutput{})

	case jupyterActionList:
		return marshalJupyterOutput(jupyterTaskOutput{Services: w.jupyter.ListServices()})

	case jupyterActionListPkgs:
		name, err := requiredJupyterMetadata(task, jupyterNameKey, "jupyter name")
		if err != nil {
			return nil, err
		}
		pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("listing packages for jupyter service %q: %w", name, err)
		}
		return marshalJupyterOutput(jupyterTaskOutput{Packages: pkgs})

	case jupyterActionRestore:
		name, err := requiredJupyterMetadata(task, jupyterNameKey, "jupyter name")
		if err != nil {
			return nil, err
		}
		restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("restoring jupyter workspace %q: %w", name, err)
		}
		return marshalJupyterOutput(jupyterTaskOutput{RestorePath: restoredPath})

	default:
		return nil, fmt.Errorf("invalid jupyter action: %s", action)
	}
}

// RunJupyterTask is the exported entry point for runJupyterTask; it forwards
// ctx and task unchanged and returns that call's result.
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
	return w.runJupyterTask(ctx, task)
}