package plugins__test

import (
	"context"
	"encoding/json"
	"errors"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/worker"
	"github.com/jfraeys/fetch_ml/internal/worker/plugins"
	tests "github.com/jfraeys/fetch_ml/tests/fixtures"
)

// fakeJupyterManager is a test double passed to tests.NewTestWorkerWithJupyter.
// Each method delegates to the matching callback. A nil callback falls back to
// a no-op that returns zero values and a nil error, so a test only needs to set
// the callbacks it actually exercises.
type fakeJupyterManager struct {
	startFn    func(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	stopFn     func(ctx context.Context, serviceID string) error
	removeFn   func(ctx context.Context, serviceID string, purge bool) error
	restoreFn  func(ctx context.Context, name string) (string, error)
	listFn     func() []*jupyter.JupyterService
	listPkgsFn func(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}

// StartService delegates to startFn, or returns (nil, nil) when it is unset.
func (f *fakeJupyterManager) StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
	if f.startFn == nil {
		return nil, nil
	}
	return f.startFn(ctx, req)
}

// StopService delegates to stopFn, or returns nil when it is unset.
func (f *fakeJupyterManager) StopService(ctx context.Context, serviceID string) error {
	if f.stopFn == nil {
		return nil
	}
	return f.stopFn(ctx, serviceID)
}

// RemoveService delegates to removeFn, or returns nil when it is unset.
func (f *fakeJupyterManager) RemoveService(ctx context.Context, serviceID string, purge bool) error {
	if f.removeFn == nil {
		return nil
	}
	return f.removeFn(ctx, serviceID, purge)
}

// RestoreWorkspace delegates to restoreFn, or returns ("", nil) when it is unset.
func (f *fakeJupyterManager) RestoreWorkspace(ctx context.Context, name string) (string, error) {
	if f.restoreFn == nil {
		return "", nil
	}
	return f.restoreFn(ctx, name)
}

// ListServices delegates to listFn, or returns nil when it is unset.
func (f *fakeJupyterManager) ListServices() []*jupyter.JupyterService {
	if f.listFn == nil {
		return nil
	}
	return f.listFn()
}

// ListInstalledPackages delegates to listPkgsFn, or returns (nil, nil) when it
// is unset.
func (f *fakeJupyterManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
	if f.listPkgsFn == nil {
		return nil, nil
	}
	return f.listPkgsFn(ctx, serviceName)
}

// jupyterOutput holds the fields these tests decode from the JSON that
// RunJupyterTask returns for a "start" action.
type jupyterOutput struct {
	Type    string `json:"type"`
	Service *struct {
		Name string `json:"name"`
		URL  string `json:"url"`
	} `json:"service"`
}

// jupyterPackagesOutput holds the fields these tests decode from the JSON that
// RunJupyterTask returns for a "list_packages" action.
type jupyterPackagesOutput struct {
	Type     string `json:"type"`
	Packages []struct {
		Name    string `json:"name"`
		Version string `json:"version"`
		Source  string `json:"source"`
	} `json:"packages"`
}

// TestRunJupyterTaskStartSuccess checks that a "start" task forwards the
// requested name to StartService and returns JSON naming the started service.
func TestRunJupyterTaskStartSuccess(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
			if req.Name != "my-workspace" {
				return nil, errors.New("bad name")
			}
			return &jupyter.JupyterService{Name: req.Name, URL: "http://127.0.0.1:8888"}, nil
		},
	})

	task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
		"task_type":         "jupyter",
		"jupyter_action":    "start",
		"jupyter_name":      "my-workspace",
		"jupyter_workspace": "my-workspace",
	}}

	out, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	if len(out) == 0 {
		t.Fatalf("expected output")
	}

	var decoded jupyterOutput
	if err := json.Unmarshal(out, &decoded); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	if decoded.Service == nil || decoded.Service.Name != "my-workspace" {
		t.Fatalf("expected service name to be my-workspace, got %#v", decoded.Service)
	}
}

// TestRunJupyterTaskStopFailure checks that an error from StopService is
// surfaced by RunJupyterTask for a "stop" task.
func TestRunJupyterTaskStopFailure(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		stopFn: func(context.Context, string) error { return errors.New("stop failed") },
	})

	task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
		"task_type":          "jupyter",
		"jupyter_action":     "stop",
		"jupyter_service_id": "svc-1",
	}}

	if _, err := plugins.RunJupyterTask(context.Background(), w, task); err == nil {
		t.Fatalf("expected error")
	}
}

// TestRunJupyterTaskListPackagesSuccess checks that a "list_packages" task
// queries the named service and reports every package the manager returned.
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
	installed := []jupyter.InstalledPackage{
		{Name: "numpy", Version: "1.26.0", Source: "pip"},
		{Name: "pandas", Version: "2.1.0", Source: "conda"},
	}

	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		listPkgsFn: func(_ context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
			if serviceName != "my-workspace" {
				return nil, errors.New("bad service")
			}
			return installed, nil
		},
	})

	task := &queue.Task{JobName: "jupyter-packages-my-workspace", Metadata: map[string]string{
		"task_type":      "jupyter",
		"jupyter_action": "list_packages",
		"jupyter_name":   "my-workspace",
	}}

	out, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}

	var decoded jupyterPackagesOutput
	if err := json.Unmarshal(out, &decoded); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	if len(decoded.Packages) != len(installed) {
		t.Fatalf("expected %d packages, got %d", len(installed), len(decoded.Packages))
	}

	// Compare by name so the check does not depend on output ordering.
	want := make(map[string]jupyter.InstalledPackage, len(installed))
	for _, p := range installed {
		want[p.Name] = p
	}
	for _, got := range decoded.Packages {
		exp, ok := want[got.Name]
		if !ok {
			t.Errorf("unexpected package %q in output", got.Name)
			continue
		}
		if got.Version != exp.Version || got.Source != exp.Source {
			t.Errorf("package %q: got version=%q source=%q, want version=%q source=%q",
				got.Name, got.Version, got.Source, exp.Version, exp.Source)
		}
	}
}