Extend worker capabilities with new execution plugins and security features:
- Jupyter plugin for notebook-based ML experiments
- vLLM plugin for LLM inference workloads
- Cross-platform process isolation (Unix/Windows)
- Network policy enforcement with platform-specific implementations
- Service manager integration for lifecycle management
- Scheduler backend integration for queue coordination

Update lifecycle management:
- Enhanced run loop with state transitions
- Service manager integration for plugin coordination
- Improved state persistence and recovery

Add test coverage:
- Unit tests for the Jupyter and vLLM plugins
- Updated worker execution tests
163 lines
5.8 KiB
Go
163 lines
5.8 KiB
Go
package plugins__test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
|
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
|
|
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
|
|
)
|
|
|
|
type fakeJupyterManager struct {
|
|
startFn func(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
|
|
stopFn func(ctx context.Context, serviceID string) error
|
|
removeFn func(ctx context.Context, serviceID string, purge bool) error
|
|
restoreFn func(ctx context.Context, name string) (string, error)
|
|
listFn func() []*jupyter.JupyterService
|
|
listPkgsFn func(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
|
|
}
|
|
|
|
func (f *fakeJupyterManager) StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
|
|
return f.startFn(ctx, req)
|
|
}
|
|
|
|
func (f *fakeJupyterManager) StopService(ctx context.Context, serviceID string) error {
|
|
return f.stopFn(ctx, serviceID)
|
|
}
|
|
|
|
func (f *fakeJupyterManager) RemoveService(ctx context.Context, serviceID string, purge bool) error {
|
|
return f.removeFn(ctx, serviceID, purge)
|
|
}
|
|
|
|
func (f *fakeJupyterManager) RestoreWorkspace(ctx context.Context, name string) (string, error) {
|
|
return f.restoreFn(ctx, name)
|
|
}
|
|
|
|
func (f *fakeJupyterManager) ListServices() []*jupyter.JupyterService {
|
|
return f.listFn()
|
|
}
|
|
|
|
func (f *fakeJupyterManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
|
|
if f.listPkgsFn == nil {
|
|
return nil, nil
|
|
}
|
|
return f.listPkgsFn(ctx, serviceName)
|
|
}
|
|
|
|
// JSON shapes used to decode the output of plugins.RunJupyterTask in the
// tests below. Only the fields the tests inspect are declared; any other
// keys in the payload are ignored by encoding/json.
type (
	// jupyterOutput is the decoded result of a service-oriented action
	// (e.g. "start"). Service is nil when the payload has no "service" key.
	jupyterOutput struct {
		Type    string `json:"type"`
		Service *struct {
			Name string `json:"name"` // service / workspace name
			URL  string `json:"url"`  // address the service is reachable at
		} `json:"service"`
	}

	// jupyterPackagesOutput is the decoded result of the "list_packages"
	// action: one entry per installed package.
	jupyterPackagesOutput struct {
		Type     string `json:"type"`
		Packages []struct {
			Name    string `json:"name"`
			Version string `json:"version"`
			Source  string `json:"source"` // installer that provided the package, e.g. "pip" or "conda"
		} `json:"packages"`
	}
)
|
|
|
|
func TestRunJupyterTaskStartSuccess(t *testing.T) {
|
|
w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
|
startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
|
|
if req.Name != "my-workspace" {
|
|
return nil, errors.New("bad name")
|
|
}
|
|
return &jupyter.JupyterService{Name: req.Name, URL: "http://127.0.0.1:8888"}, nil
|
|
},
|
|
stopFn: func(context.Context, string) error { return nil },
|
|
removeFn: func(context.Context, string, bool) error { return nil },
|
|
restoreFn: func(context.Context, string) (string, error) { return "", nil },
|
|
listFn: func() []*jupyter.JupyterService { return nil },
|
|
listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil },
|
|
})
|
|
|
|
task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
|
|
"task_type": "jupyter",
|
|
"jupyter_action": "start",
|
|
"jupyter_name": "my-workspace",
|
|
"jupyter_workspace": "my-workspace",
|
|
}}
|
|
out, err := plugins.RunJupyterTask(context.Background(), w, task)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
if len(out) == 0 {
|
|
t.Fatalf("expected output")
|
|
}
|
|
var decoded jupyterOutput
|
|
if err := json.Unmarshal(out, &decoded); err != nil {
|
|
t.Fatalf("expected valid JSON, got %v", err)
|
|
}
|
|
if decoded.Service == nil || decoded.Service.Name != "my-workspace" {
|
|
t.Fatalf("expected service name to be my-workspace, got %#v", decoded.Service)
|
|
}
|
|
}
|
|
|
|
func TestRunJupyterTaskStopFailure(t *testing.T) {
|
|
w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
|
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
|
stopFn: func(context.Context, string) error { return errors.New("stop failed") },
|
|
removeFn: func(context.Context, string, bool) error { return nil },
|
|
restoreFn: func(context.Context, string) (string, error) { return "", nil },
|
|
listFn: func() []*jupyter.JupyterService { return nil },
|
|
listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil },
|
|
})
|
|
|
|
task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
|
|
"task_type": "jupyter",
|
|
"jupyter_action": "stop",
|
|
"jupyter_service_id": "svc-1",
|
|
}}
|
|
_, err := plugins.RunJupyterTask(context.Background(), w, task)
|
|
if err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
}
|
|
|
|
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
|
|
w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
|
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
|
stopFn: func(context.Context, string) error { return nil },
|
|
removeFn: func(context.Context, string, bool) error { return nil },
|
|
restoreFn: func(context.Context, string) (string, error) { return "", nil },
|
|
listFn: func() []*jupyter.JupyterService { return nil },
|
|
listPkgsFn: func(_ context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
|
|
if serviceName != "my-workspace" {
|
|
return nil, errors.New("bad service")
|
|
}
|
|
return []jupyter.InstalledPackage{
|
|
{Name: "numpy", Version: "1.26.0", Source: "pip"},
|
|
{Name: "pandas", Version: "2.1.0", Source: "conda"},
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
task := &queue.Task{JobName: "jupyter-packages-my-workspace", Metadata: map[string]string{
|
|
"task_type": "jupyter",
|
|
"jupyter_action": "list_packages",
|
|
"jupyter_name": "my-workspace",
|
|
}}
|
|
|
|
out, err := plugins.RunJupyterTask(context.Background(), w, task)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
|
|
var decoded jupyterPackagesOutput
|
|
if err := json.Unmarshal(out, &decoded); err != nil {
|
|
t.Fatalf("expected valid JSON, got %v", err)
|
|
}
|
|
if len(decoded.Packages) != 2 {
|
|
t.Fatalf("expected 2 packages, got %d", len(decoded.Packages))
|
|
}
|
|
}
|