feat(worker): add Jupyter/vLLM plugins and process isolation

Extend worker capabilities with new execution plugins and security features:
- Jupyter plugin for notebook-based ML experiments
- vLLM plugin for LLM inference workloads
- Cross-platform process isolation (Unix/Windows)
- Network policy enforcement with platform-specific implementations
- Service manager integration for lifecycle management
- Scheduler backend integration for queue coordination

Update lifecycle management:
- Enhanced runloop with state transitions
- Service manager integration for plugin coordination
- Improved state persistence and recovery

Add test coverage:
- Unit tests for Jupyter and vLLM plugins
- Updated worker execution tests
This commit is contained in:
Jeremie Fraeys 2026-02-26 12:03:59 -05:00
parent a981e89005
commit 95adcba437
No known key found for this signature in database
13 changed files with 1928 additions and 16 deletions

View file

@ -0,0 +1,227 @@
package queue
import (
"encoding/json"
"fmt"
"time"
"github.com/jfraeys/fetch_ml/internal/scheduler"
)
// SchedulerConn defines the interface for scheduler connection
// (satisfied by *scheduler.SchedulerConn — see Conn()). Send is
// fire-and-forget: it returns no error, so delivery failures are not
// visible to callers of this backend.
type SchedulerConn interface {
	Send(msg scheduler.Message)
}
// SchedulerBackend implements queue.Backend by communicating with a scheduler over WebSocket
type SchedulerBackend struct {
	conn         SchedulerConn        // outbound message channel to the scheduler
	pendingTasks chan *Task           // cap-1 buffer for tasks pushed via OnJobAssign
	prewarmHint  *Task                // last hint from OnPrewarmHint; read by PeekNextTask
	localSlots   scheduler.SlotStatus // slot usage advertised in ReadyForWork messages
	// NOTE(review): prewarmHint and localSlots are written/read without a
	// mutex — confirm all access happens on the run-loop goroutine.
}
// NewSchedulerBackend constructs a SchedulerBackend that talks to the
// scheduler through conn. The pending-task channel is buffered with a
// single slot, matching the one-task-at-a-time push model of OnJobAssign.
func NewSchedulerBackend(conn SchedulerConn) *SchedulerBackend {
	backend := &SchedulerBackend{conn: conn}
	backend.pendingTasks = make(chan *Task, 1)
	return backend
}
// GetNextTaskWithLeaseBlocking implements queue.Backend.
// It advertises readiness (with the current slot status) to the scheduler
// and then waits up to blockTimeout for a pushed task. Returning
// (nil, nil) on timeout mirrors an empty local queue, so the RunLoop
// simply retries. leaseDuration is unused here: in distributed mode the
// scheduler owns leases.
func (sb *SchedulerBackend) GetNextTaskWithLeaseBlocking(
	workerID string,
	leaseDuration time.Duration,
	blockTimeout time.Duration,
) (*Task, error) {
	// Signal readiness with current slot status
	sb.conn.Send(scheduler.Message{
		Type: scheduler.MsgReadyForWork,
		Payload: mustMarshal(scheduler.ReadyPayload{
			WorkerID: workerID,
			Slots:    sb.localSlots,
			Reason:   "polling",
		}),
	})
	// Wait for a scheduler push or the timeout. Use an explicit timer
	// (rather than time.After) so the timer resource is released as soon
	// as a task arrives instead of lingering until expiry.
	timer := time.NewTimer(blockTimeout)
	defer timer.Stop()
	select {
	case task := <-sb.pendingTasks:
		return task, nil
	case <-timer.C:
		return nil, nil // RunLoop retries — same behaviour as empty queue
	}
}
// OnJobAssign is called when the scheduler pushes a job assignment.
// The spec is converted into a Task and handed to the waiting
// GetNextTaskWithLeaseBlocking via the buffered channel; if the buffer is
// already occupied the assignment is dropped (should not occur with cap 1).
func (sb *SchedulerBackend) OnJobAssign(spec *scheduler.JobSpec) {
	assigned := &Task{
		ID:       spec.ID,
		JobName:  "distributed-job",
		Priority: 1,
		Status:   "assigned",
	}
	select {
	case sb.pendingTasks <- assigned:
		// Delivered to the poller.
	default:
		// Channel full, drop (shouldn't happen with buffer of 1)
	}
}
// UpdateTask implements queue.Backend — forwards task state changes to
// the scheduler as a JobResult message. Always returns nil.
func (sb *SchedulerBackend) UpdateTask(task *Task) error {
	payload := scheduler.JobResultPayload{
		TaskID: task.ID,
		State:  task.Status,
	}
	// Copying unconditionally matches the original conditional assignment:
	// the field's zero value is the empty string either way.
	payload.Error = task.Error
	sb.conn.Send(scheduler.Message{
		Type:    scheduler.MsgJobResult,
		Payload: mustMarshal(payload),
	})
	return nil
}
// PeekNextTask implements queue.Backend — returns the most recent prewarm
// hint, or nil when none has been received. Never returns an error.
func (sb *SchedulerBackend) PeekNextTask() (*Task, error) {
	// A nil hint and an explicit nil return are the same value, so no
	// branch is required.
	return sb.prewarmHint, nil
}
// OnPrewarmHint stores a prewarm hint pushed by the scheduler so that
// PeekNextTask can expose it to the prewarm machinery. The snapshot
// identifiers travel in the task metadata.
func (sb *SchedulerBackend) OnPrewarmHint(hint scheduler.PrewarmHintPayload) {
	meta := map[string]string{
		"snapshot_id":  hint.SnapshotID,
		"snapshot_sha": hint.SnapshotSHA,
	}
	sb.prewarmHint = &Task{
		ID:       hint.TaskID,
		Status:   "prewarm_hint",
		Metadata: meta,
	}
}
// UpdateSlots records the worker's current batch-slot usage; the values
// are advertised to the scheduler on the next ReadyForWork message.
func (sb *SchedulerBackend) UpdateSlots(running, max int) {
	status := scheduler.SlotStatus{BatchInUse: running, BatchTotal: max}
	sb.localSlots = status
}
// Required queue.Backend methods (stub implementations for now).
// In distributed mode the scheduler owns leases, retries, dead-lettering,
// metrics, and prewarm bookkeeping, so these either delegate, no-op, or
// return a descriptive "not supported" error.

// AddTask is unsupported: jobs must be submitted to the scheduler directly.
func (sb *SchedulerBackend) AddTask(task *Task) error {
	return fmt.Errorf("AddTask not supported in distributed mode - submit to scheduler instead")
}

// GetNextTask is unsupported: callers must use the blocking lease variant.
func (sb *SchedulerBackend) GetNextTask() (*Task, error) {
	return nil, fmt.Errorf("GetNextTask not supported - use GetNextTaskWithLeaseBlocking")
}

// GetNextTaskWithLease delegates to the blocking variant with a zero
// timeout, which effectively makes it a non-blocking poll.
func (sb *SchedulerBackend) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
	return sb.GetNextTaskWithLeaseBlocking(workerID, leaseDuration, 0)
}

// RenewLease is a no-op: lease ownership lives on the scheduler.
func (sb *SchedulerBackend) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
	// Distributed mode: scheduler manages leases
	return nil
}

// ReleaseLease is a no-op: lease ownership lives on the scheduler.
func (sb *SchedulerBackend) ReleaseLease(taskID string, workerID string) error {
	// Distributed mode: scheduler manages leases
	return nil
}

// RetryTask reports a "retry" result to the scheduler, which owns requeueing.
func (sb *SchedulerBackend) RetryTask(task *Task) error {
	sb.conn.Send(scheduler.Message{
		Type:    scheduler.MsgJobResult,
		Payload: mustMarshal(scheduler.JobResultPayload{TaskID: task.ID, State: "retry"}),
	})
	return nil
}

// MoveToDeadLetterQueue is a no-op; the scheduler performs dead-lettering.
func (sb *SchedulerBackend) MoveToDeadLetterQueue(task *Task, reason string) error {
	return nil // Scheduler handles this
}

// GetTask is unsupported in distributed mode.
func (sb *SchedulerBackend) GetTask(taskID string) (*Task, error) {
	return nil, fmt.Errorf("GetTask not supported in distributed mode")
}

// GetAllTasks is unsupported in distributed mode.
func (sb *SchedulerBackend) GetAllTasks() ([]*Task, error) {
	return nil, fmt.Errorf("GetAllTasks not supported in distributed mode")
}

// GetTaskByName is unsupported in distributed mode.
func (sb *SchedulerBackend) GetTaskByName(jobName string) (*Task, error) {
	return nil, fmt.Errorf("GetTaskByName not supported in distributed mode")
}

// CancelTask is unsupported here; cancellation goes through the scheduler API.
func (sb *SchedulerBackend) CancelTask(taskID string) error {
	return fmt.Errorf("CancelTask not supported - use scheduler API")
}

// UpdateTaskWithMetrics forwards to UpdateTask; the action tag is ignored.
func (sb *SchedulerBackend) UpdateTaskWithMetrics(task *Task, action string) error {
	return sb.UpdateTask(task)
}

// RecordMetric is a no-op; metrics are aggregated by the scheduler.
func (sb *SchedulerBackend) RecordMetric(jobName, metric string, value float64) error {
	return nil // Metrics handled by scheduler
}

// Heartbeat is a no-op; the underlying SchedulerConn heartbeats itself.
func (sb *SchedulerBackend) Heartbeat(workerID string) error {
	return nil // Heartbeat handled by SchedulerConn
}

// QueueDepth always reports zero; actual depth is tracked by the scheduler.
func (sb *SchedulerBackend) QueueDepth() (int64, error) {
	return 0, nil // Queue depth managed by scheduler
}

// SetWorkerPrewarmState is a no-op in distributed mode.
func (sb *SchedulerBackend) SetWorkerPrewarmState(state PrewarmState) error {
	return nil
}

// ClearWorkerPrewarmState is a no-op in distributed mode.
func (sb *SchedulerBackend) ClearWorkerPrewarmState(workerID string) error {
	return nil
}

// GetWorkerPrewarmState always reports no state in distributed mode.
func (sb *SchedulerBackend) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
	return nil, nil
}

// GetAllWorkerPrewarmStates always reports no states in distributed mode.
func (sb *SchedulerBackend) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	return nil, nil
}

// SignalPrewarmGC is a no-op in distributed mode.
func (sb *SchedulerBackend) SignalPrewarmGC() error {
	return nil
}

// PrewarmGCRequestValue always reports an empty value in distributed mode.
func (sb *SchedulerBackend) PrewarmGCRequestValue() (string, error) {
	return "", nil
}

// Close is a no-op; the connection lifecycle is owned by the caller.
func (sb *SchedulerBackend) Close() error {
	return nil
}
// Conn returns the underlying *scheduler.SchedulerConn when the backend
// was built with one (used for heartbeating); otherwise it returns nil.
func (sb *SchedulerBackend) Conn() *scheduler.SchedulerConn {
	conn, ok := sb.conn.(*scheduler.SchedulerConn)
	if !ok {
		return nil
	}
	return conn
}
// mustMarshal serializes v to JSON, panicking on failure. The payloads
// marshalled here are plain structs that cannot legitimately fail to
// encode, so an error indicates a programmer bug (e.g. an unsupported
// field type) — panicking matches the Must* naming convention. The
// previous version silently discarded the error and could send a nil
// payload to the scheduler.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(fmt.Sprintf("mustMarshal: %v", err))
	}
	return b
}
// Compile-time check - ensure SchedulerBackend implements the interface
var _ Backend = (*SchedulerBackend)(nil)

View file

@ -28,18 +28,16 @@ type RunLoopConfig struct {
// RunLoop manages the main worker processing loop
type RunLoop struct {
config RunLoopConfig
queue queue.Backend
executor TaskExecutor
metrics MetricsRecorder
logger Logger
stateMgr *StateManager
// State management
running map[string]context.CancelFunc
runningMu sync.RWMutex
queue queue.Backend
executor TaskExecutor
metrics MetricsRecorder
logger Logger
ctx context.Context
stateMgr *StateManager
running map[string]context.CancelFunc
cancel context.CancelFunc
config RunLoopConfig
runningMu sync.RWMutex
}
// MetricsRecorder defines the contract for recording metrics

View file

@ -0,0 +1,248 @@
// Package lifecycle provides service job lifecycle management for long-running services.
package lifecycle
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"os/exec"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/domain"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
)
// ServiceManager handles the lifecycle of long-running service jobs (Jupyter, vLLM)
type ServiceManager struct {
	task        *domain.Task           // the service task being managed
	spec        ServiceSpec            // command/env/port/health-check specification
	cmd         *exec.Cmd              // running service process; nil until start()
	port        int                    // port assigned to the service
	stateMgr    *StateManager          // optional; nil disables state transitions
	logger      Logger                 // structured logger
	healthCheck *scheduler.HealthCheck // copied from spec.HealthCheck; nil disables probes
}
// ServiceSpec defines the specification for a service job
type ServiceSpec struct {
	Command     []string               // argv; Command[0] is the executable
	Env         map[string]string      // environment overrides for the child process
	Port        int                    // requested port; the effective port is ServiceManager.port
	HealthCheck *scheduler.HealthCheck // optional liveness/readiness configuration
}
// NewServiceManager creates a service manager for the given task and spec.
// stateMgr may be nil, in which case state transitions are skipped.
func NewServiceManager(task *domain.Task, spec ServiceSpec, port int, stateMgr *StateManager, logger Logger) *ServiceManager {
	sm := &ServiceManager{
		task:     task,
		spec:     spec,
		port:     port,
		stateMgr: stateMgr,
		logger:   logger,
	}
	sm.healthCheck = spec.HealthCheck
	return sm
}
// Run executes the service lifecycle: start the process, wait for it to
// become ready (bounded at two minutes), mark it serving, then monitor
// health until the context is cancelled or a liveness probe fails.
func (sm *ServiceManager) Run(ctx context.Context) error {
	if err := sm.start(); err != nil {
		sm.logger.Error("service start failed", "task_id", sm.task.ID, "error", err)
		if sm.stateMgr != nil {
			sm.stateMgr.Transition(sm.task, StateFailed)
		}
		return fmt.Errorf("start service: %w", err)
	}

	// Bound the readiness wait at two minutes.
	readyCtx, cancelReady := context.WithTimeout(ctx, 2*time.Minute)
	defer cancelReady()
	if err := sm.waitReady(readyCtx); err != nil {
		sm.logger.Error("service readiness check failed", "task_id", sm.task.ID, "error", err)
		if sm.stateMgr != nil {
			sm.stateMgr.Transition(sm.task, StateFailed)
		}
		sm.stop()
		return fmt.Errorf("wait ready: %w", err)
	}

	// Ready — record the serving state (best-effort) and announce it.
	if sm.stateMgr != nil {
		if terr := sm.stateMgr.Transition(sm.task, StateServing); terr != nil {
			sm.logger.Error("failed to transition to serving", "task_id", sm.task.ID, "error", terr)
		}
	}
	sm.logger.Info("service is serving", "task_id", sm.task.ID, "port", sm.port)

	// Block in the health loop until shutdown or failure.
	return sm.healthLoop(ctx)
}
// start launches the service process. The child inherits the parent
// environment plus any overrides from spec.Env.
//
// BUG FIX: the previous code appended overrides to a nil cmd.Env, which
// made the child's environment consist of ONLY the overrides (dropping
// PATH, HOME, etc.) whenever any override was set. Seeding from
// os.Environ() preserves inheritance; when no overrides exist, cmd.Env
// stays nil and exec inherits the parent environment as before.
func (sm *ServiceManager) start() error {
	if len(sm.spec.Command) == 0 {
		return fmt.Errorf("no command specified")
	}
	sm.cmd = exec.Command(sm.spec.Command[0], sm.spec.Command[1:]...)
	// Set environment: inherit, then apply overrides.
	if len(sm.spec.Env) > 0 {
		sm.cmd.Env = os.Environ()
		for k, v := range sm.spec.Env {
			sm.cmd.Env = append(sm.cmd.Env, fmt.Sprintf("%s=%s", k, v))
		}
	}
	// Start process
	if err := sm.cmd.Start(); err != nil {
		return fmt.Errorf("start process: %w", err)
	}
	return nil
}
// waitReady blocks until the service passes its readiness probe (polling
// once per second) or ctx expires, in which case ctx.Err() is returned.
// With no readiness endpoint configured the service is assumed ready.
func (sm *ServiceManager) waitReady(ctx context.Context) error {
	if sm.healthCheck == nil || sm.healthCheck.ReadinessEndpoint == "" {
		// No readiness check - assume ready immediately
		return nil
	}
	poll := time.NewTicker(1 * time.Second)
	defer poll.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-poll.C:
			if ready := sm.checkReadiness(); ready {
				return nil
			}
		}
	}
}
// healthLoop monitors service health until context cancellation.
// Without a configured health check it simply blocks on ctx and then
// stops the service gracefully. With one, it probes liveness every
// IntervalSecs seconds (default 15s) and fails the task on the FIRST
// failed probe — there is no retry/grace budget.
func (sm *ServiceManager) healthLoop(ctx context.Context) error {
	if sm.healthCheck == nil {
		// No health check - just wait for context cancellation
		<-ctx.Done()
		return sm.gracefulStop()
	}
	interval := time.Duration(sm.healthCheck.IntervalSecs) * time.Second
	if interval == 0 {
		// Default probe interval when the spec leaves it unset.
		interval = 15 * time.Second
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			// Normal shutdown path: stop the service and report its result.
			return sm.gracefulStop()
		case <-ticker.C:
			if !sm.checkLiveness() {
				sm.logger.Error("service liveness check failed", "task_id", sm.task.ID)
				if sm.stateMgr != nil {
					sm.stateMgr.Transition(sm.task, StateFailed)
				}
				// NOTE(review): the process is NOT killed here — confirm a
				// caller performs cleanup after a liveness failure.
				return fmt.Errorf("liveness check failed")
			}
		}
	}
}
// checkLiveness reports whether the service is still alive. It prefers the
// configured HTTP liveness endpoint; without one it falls back to a
// signal-0 process probe.
func (sm *ServiceManager) checkLiveness() bool {
	if sm.cmd == nil || sm.cmd.Process == nil {
		return false
	}
	useHTTP := sm.healthCheck != nil && sm.healthCheck.LivenessEndpoint != ""
	if !useHTTP {
		// Signal 0 performs existence/permission checks without actually
		// delivering a signal.
		return sm.cmd.Process.Signal(syscall.Signal(0)) == nil
	}
	// HTTP liveness check.
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(sm.healthCheck.LivenessEndpoint)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
// checkReadiness reports whether the service is ready to accept traffic.
// With no readiness endpoint configured it always reports ready.
func (sm *ServiceManager) checkReadiness() bool {
	hc := sm.healthCheck
	if hc == nil || hc.ReadinessEndpoint == "" {
		return true
	}
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(hc.ReadinessEndpoint)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
// gracefulStop stops the service gracefully: SIGTERM, wait up to 30s,
// then SIGKILL. The task is transitioned to StateStopping first and to
// StateCompleted once the process is gone (transitions are skipped when
// no state manager is configured).
func (sm *ServiceManager) gracefulStop() error {
	sm.logger.Info("gracefully stopping service", "task_id", sm.task.ID)
	if sm.stateMgr != nil {
		sm.stateMgr.Transition(sm.task, StateStopping)
	}
	if sm.cmd == nil || sm.cmd.Process == nil {
		if sm.stateMgr != nil {
			sm.stateMgr.Transition(sm.task, StateCompleted)
		}
		return nil
	}
	// Send SIGTERM for graceful shutdown
	if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil {
		sm.logger.Warn("SIGTERM failed, using SIGKILL", "task_id", sm.task.ID, "error", err)
		sm.cmd.Process.Kill()
		// BUG FIX: reap the killed process; without Wait it lingers as a
		// zombie until the worker exits.
		sm.cmd.Wait()
	} else {
		// Wait for graceful shutdown with a 30-second deadline.
		done := make(chan error, 1)
		go func() {
			done <- sm.cmd.Wait()
		}()
		select {
		case <-done:
			// Graceful shutdown completed
		case <-time.After(30 * time.Second):
			// Timeout - force kill; the Wait goroutine above reaps the
			// process once the kill takes effect.
			sm.logger.Warn("graceful shutdown timeout, forcing kill", "task_id", sm.task.ID)
			sm.cmd.Process.Kill()
		}
	}
	if sm.stateMgr != nil {
		sm.stateMgr.Transition(sm.task, StateCompleted)
	}
	return nil
}
// stop forcefully kills the service process; it is a no-op if the process
// never started.
func (sm *ServiceManager) stop() error {
	if sm.cmd != nil && sm.cmd.Process != nil {
		return sm.cmd.Process.Kill()
	}
	return nil
}
// GetPort reports the port assigned to this service instance.
func (sm *ServiceManager) GetPort() int { return sm.port }

View file

@ -27,18 +27,30 @@ const (
StateCompleted TaskState = "completed"
// StateFailed indicates the task failed during execution.
StateFailed TaskState = "failed"
// StateCancelled indicates the task was cancelled by user or system.
StateCancelled TaskState = "cancelled"
// StateOrphaned indicates the worker was lost - task will be silently re-queued.
StateOrphaned TaskState = "orphaned"
// StateServing indicates a service job (Jupyter/vLLM) is up and ready.
StateServing TaskState = "serving"
// StateStopping indicates a service job is in graceful shutdown.
StateStopping TaskState = "stopping"
)
// ValidTransitions defines the allowed state transitions.
// The key is the "from" state, the value is a list of valid "to" states.
// This enforces that state transitions follow the expected lifecycle.
// Terminal states (completed/failed/cancelled) allow no further moves;
// orphaned tasks may only be re-queued.
// NOTE: the previous map contained duplicate keys (pre- and post-change
// entries from a merge), which does not compile; only the updated entries
// are kept here.
var ValidTransitions = map[TaskState][]TaskState{
	StateQueued:     {StatePreparing, StateFailed, StateCancelled},
	StatePreparing:  {StateRunning, StateFailed, StateCancelled},
	StateRunning:    {StateCollecting, StateFailed, StateOrphaned, StateServing, StateCancelled},
	StateCollecting: {StateCompleted, StateFailed, StateCancelled},
	StateCompleted:  {},
	StateFailed:     {},
	StateCancelled:  {},
	StateOrphaned:   {StateQueued},
	StateServing:    {StateStopping, StateFailed, StateOrphaned, StateCancelled},
	StateStopping:   {StateCompleted, StateFailed},
}
// StateTransitionError is returned when an invalid state transition is attempted.
@ -53,8 +65,8 @@ func (e StateTransitionError) Error() string {
// StateManager manages task state transitions with audit logging.
// NOTE: the previous struct declared `enabled` twice (merge residue),
// which does not compile; the duplicate is removed here.
type StateManager struct {
	auditor *audit.Logger // destination for transition audit records
	enabled bool          // NOTE(review): presumably gates enforcement/auditing — confirm against NewStateManager
}
// NewStateManager creates a new state manager with the given audit logger.
@ -116,7 +128,7 @@ func (sm *StateManager) validateTransition(from, to TaskState) error {
// IsTerminalState returns true if the state is terminal (no further transitions allowed).
// Terminal states are completed, failed, and cancelled — matching the
// empty entries in ValidTransitions. The previous body contained two
// consecutive return statements (merge residue); only the updated one,
// which includes StateCancelled, is kept.
func IsTerminalState(state TaskState) bool {
	return state == StateCompleted || state == StateFailed || state == StateCancelled
}
// CanTransition returns true if a transition from -> to is valid.

View file

@ -0,0 +1,133 @@
// Package plugins provides framework-specific extensions to the worker.
// This implements the prolog/epilog model where plugins hook into task lifecycle events.
package plugins
import (
"context"
"encoding/json"
"fmt"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// JupyterManager interface for jupyter service management.
// It is the consumer-side abstraction over the jupyter package used by
// RunJupyterTask; the worker's TaskRunner supplies the implementation.
type JupyterManager interface {
	// StartService launches a Jupyter service for the request.
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	// StopService stops a running service by ID.
	StopService(ctx context.Context, serviceID string) error
	// RemoveService deletes a service; purge presumably also removes
	// persisted workspace data — confirm in the jupyter package.
	RemoveService(ctx context.Context, serviceID string, purge bool) error
	// RestoreWorkspace restores a named workspace and returns a service ID.
	RestoreWorkspace(ctx context.Context, name string) (string, error)
	// ListServices returns all known services.
	ListServices() []*jupyter.JupyterService
	// ListInstalledPackages lists packages installed in the named service.
	ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
// TaskRunner executes framework-specific tasks.
// This is the minimal interface needed by plugins to execute tasks:
// it only exposes access to the Jupyter manager, which may be nil when
// Jupyter support is not configured on the worker.
type TaskRunner interface {
	GetJupyterManager() JupyterManager
}
// RunJupyterTask runs a Jupyter-related task.
// It dispatches on the "jupyter_action" metadata key (default "start") and
// handles start, stop, remove, restore, and list_packages actions. The
// returned bytes are a JSON document describing the action's result.
// This is a plugin function that extends the core worker with Jupyter support.
func RunJupyterTask(ctx context.Context, runner TaskRunner, task *queue.Task) ([]byte, error) {
	jm := runner.GetJupyterManager()
	if jm == nil {
		return nil, fmt.Errorf("jupyter manager not configured")
	}
	action := task.Metadata["jupyter_action"]
	if action == "" {
		action = "start" // Default action
	}
	switch action {
	case "start":
		name := task.Metadata["jupyter_name"]
		if name == "" {
			name = task.Metadata["jupyter_workspace"]
		}
		if name == "" {
			// Extract from jobName if format is "jupyter-<name>"
			name = stripJobPrefix(task.JobName, "jupyter-")
		}
		if name == "" {
			return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
		}
		req := &jupyter.StartRequest{Name: name}
		service, err := jm.StartService(ctx, req)
		if err != nil {
			return nil, fmt.Errorf("failed to start jupyter service: %w", err)
		}
		output := map[string]any{
			"type":    "start",
			"service": service,
		}
		return json.Marshal(output)
	case "stop":
		serviceID := task.Metadata["jupyter_service_id"]
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
		}
		if err := jm.StopService(ctx, serviceID); err != nil {
			return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
		}
		return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
	case "remove":
		serviceID := task.Metadata["jupyter_service_id"]
		if serviceID == "" {
			return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
		}
		purge := task.Metadata["jupyter_purge"] == "true"
		if err := jm.RemoveService(ctx, serviceID, purge); err != nil {
			return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
		}
		return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
	case "restore":
		name := task.Metadata["jupyter_name"]
		if name == "" {
			name = task.Metadata["jupyter_workspace"]
		}
		if name == "" {
			return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
		}
		serviceID, err := jm.RestoreWorkspace(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
		}
		return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID})
	case "list_packages":
		serviceName := task.Metadata["jupyter_name"]
		if serviceName == "" {
			// Extract from jobName if format is "jupyter-packages-<name>".
			// BUG FIX: the previous code sliced 16 bytes of the job name but
			// compared against the 17-byte prefix "jupyter-packages-", so
			// this fallback could never match.
			serviceName = stripJobPrefix(task.JobName, "jupyter-packages-")
		}
		if serviceName == "" {
			return nil, fmt.Errorf("missing jupyter_name in task metadata")
		}
		packages, err := jm.ListInstalledPackages(ctx, serviceName)
		if err != nil {
			return nil, fmt.Errorf("failed to list installed packages: %w", err)
		}
		output := map[string]any{
			"type":     "list_packages",
			"packages": packages,
		}
		return json.Marshal(output)
	default:
		return nil, fmt.Errorf("unknown jupyter action: %s", action)
	}
}

// stripJobPrefix returns jobName with prefix removed when jobName starts
// with prefix and has content after it; otherwise it returns "".
func stripJobPrefix(jobName, prefix string) string {
	if len(jobName) > len(prefix) && jobName[:len(prefix)] == prefix {
		return jobName[len(prefix):]
	}
	return ""
}

View file

@ -0,0 +1,279 @@
// Package plugins provides service plugin implementations for fetch_ml.
// This file contains the vLLM inference server plugin.
package plugins
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
)
// VLLMPlugin implements the vLLM OpenAI-compatible inference server as a fetch_ml service.
// This plugin manages the lifecycle of vLLM, including startup, health monitoring,
// and graceful shutdown.
type VLLMPlugin struct {
	logger       *slog.Logger      // structured logger
	baseImage    string            // container image; defaults to vllm/vllm-openai:latest
	envOverrides map[string]string // extra env vars; initialized but not yet consumed in this file
}

// VLLMConfig contains the configuration for a vLLM service instance.
// It is decoded from the task's "vllm_config" metadata (see RunVLLMTask).
type VLLMConfig struct {
	// Model name or path (required)
	Model string `json:"model"`
	// Tensor parallel size (number of GPUs)
	TensorParallelSize int `json:"tensor_parallel_size,omitempty"`
	// Port to run on (0 = auto-allocate)
	Port int `json:"port,omitempty"`
	// Maximum number of tokens in the cache
	MaxModelLen int `json:"max_model_len,omitempty"`
	// GPU memory utilization (0.0 - 1.0)
	GpuMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"`
	// Quantization method (awq, gptq, squeezellm, or empty for none)
	Quantization string `json:"quantization,omitempty"`
	// Trust remote code for custom models
	TrustRemoteCode bool `json:"trust_remote_code,omitempty"`
	// Additional CLI arguments to pass to vLLM
	ExtraArgs []string `json:"extra_args,omitempty"`
}

// VLLMStatus represents the current status of a vLLM service as seen by
// CheckHealth. NOTE(review): ModelLoaded, Uptime, and Requests are never
// populated in this file — confirm whether callers depend on them.
type VLLMStatus struct {
	Healthy     bool          `json:"healthy"`
	Ready       bool          `json:"ready"`
	ModelLoaded bool          `json:"model_loaded"`
	Uptime      time.Duration `json:"uptime"`
	Requests    int64         `json:"requests_served"`
}
// NewVLLMPlugin creates a new vLLM plugin instance. An empty baseImage
// selects the default "vllm/vllm-openai:latest" image.
func NewVLLMPlugin(logger *slog.Logger, baseImage string) *VLLMPlugin {
	image := baseImage
	if image == "" {
		image = "vllm/vllm-openai:latest"
	}
	return &VLLMPlugin{
		logger:       logger,
		baseImage:    image,
		envOverrides: map[string]string{},
	}
}
// GetServiceTemplate returns the service template for vLLM.
// Placeholders ({{MODEL_NAME}}, {{SERVICE_PORT}}, {{GPU_DEVICES}}, …) and
// mustache-style sections ({{#X}}…{{/X}}) are expanded by a renderer
// elsewhere. NOTE(review): the section markers are embedded inside
// individual argv elements — confirm the renderer drops empty elements,
// otherwise literal "{{#…}}" strings leak into the command line.
func (p *VLLMPlugin) GetServiceTemplate() ServiceTemplate {
	return ServiceTemplate{
		JobType:  "service",
		SlotPool: "service",
		GPUCount: 1, // Minimum 1 GPU, can be more with tensor parallelism
		Command: []string{
			"python", "-m", "vllm.entrypoints.openai.api_server",
			"--model", "{{MODEL_NAME}}",
			"--port", "{{SERVICE_PORT}}",
			"--host", "0.0.0.0",
			"{{#GPU_COUNT}}--tensor-parallel-size", "{{GPU_COUNT}}", "{{/GPU_COUNT}}",
			"{{#QUANTIZATION}}--quantization", "{{QUANTIZATION}}", "{{/QUANTIZATION}}",
			"{{#TRUST_REMOTE_CODE}}--trust-remote-code", "{{/TRUST_REMOTE_CODE}}",
		},
		Env: map[string]string{
			"CUDA_VISIBLE_DEVICES": "{{GPU_DEVICES}}",
			"VLLM_LOGGING_LEVEL":   "INFO",
		},
		// Liveness and readiness both use vLLM's /health endpoint.
		HealthCheck: ServiceHealthCheck{
			Liveness:  "http://localhost:{{SERVICE_PORT}}/health",
			Readiness: "http://localhost:{{SERVICE_PORT}}/health",
			Interval:  30,
			Timeout:   10,
		},
		// HF model cache and workspace are mounted into the container.
		Mounts: []ServiceMount{
			{Source: "{{MODEL_CACHE}}", Destination: "/root/.cache/huggingface"},
			{Source: "{{WORKSPACE}}", Destination: "/workspace"},
		},
		Resources: ResourceRequirements{
			MinGPU:    1,
			MaxGPU:    8,  // Support up to 8-GPU tensor parallelism
			MinMemory: 16, // 16GB minimum
		},
	}
}
// ValidateConfig validates the vLLM configuration: a model is required,
// gpu_memory_utilization must lie in [0, 1], and the quantization method
// must be one of the supported set (or empty for none).
func (p *VLLMPlugin) ValidateConfig(config *VLLMConfig) error {
	if config.Model == "" {
		return fmt.Errorf("vllm: model name is required")
	}
	if config.GpuMemoryUtilization < 0 || config.GpuMemoryUtilization > 1 {
		return fmt.Errorf("vllm: gpu_memory_utilization must be between 0.0 and 1.0")
	}
	switch config.Quantization {
	case "", "awq", "gptq", "squeezellm", "fp8":
		return nil
	default:
		return fmt.Errorf("vllm: unsupported quantization: %s", config.Quantization)
	}
}
// BuildCommand builds the vLLM CLI invocation from the configuration,
// emitting optional flags only when their config fields are set, and
// appending caller-supplied ExtraArgs verbatim at the end.
func (p *VLLMPlugin) BuildCommand(config *VLLMConfig, port int) []string {
	cmd := []string{
		"python", "-m", "vllm.entrypoints.openai.api_server",
		"--model", config.Model,
		"--port", fmt.Sprintf("%d", port),
		"--host", "0.0.0.0",
	}
	if tp := config.TensorParallelSize; tp > 1 {
		cmd = append(cmd, "--tensor-parallel-size", fmt.Sprintf("%d", tp))
	}
	if ml := config.MaxModelLen; ml > 0 {
		cmd = append(cmd, "--max-model-len", fmt.Sprintf("%d", ml))
	}
	if util := config.GpuMemoryUtilization; util > 0 {
		cmd = append(cmd, "--gpu-memory-utilization", fmt.Sprintf("%.2f", util))
	}
	if q := config.Quantization; q != "" {
		cmd = append(cmd, "--quantization", q)
	}
	if config.TrustRemoteCode {
		cmd = append(cmd, "--trust-remote-code")
	}
	return append(cmd, config.ExtraArgs...)
}
// CheckHealth performs a health check against a running vLLM instance at
// endpoint (e.g. "http://localhost:8000"). A transport error is reported
// as an unhealthy status with a nil error; a non-nil error is returned
// only when the request itself cannot be constructed.
func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMStatus, error) {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, "GET", endpoint+"/health", nil)
	if err != nil {
		return nil, fmt.Errorf("vllm health check: %w", err)
	}
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return &VLLMStatus{Healthy: false, Ready: false}, nil
	}
	defer resp.Body.Close()
	status := &VLLMStatus{
		Healthy: resp.StatusCode == 200,
		Ready:   resp.StatusCode == 200,
	}
	// Probe the metrics endpoint for detailed status.
	// BUG FIX: the body was previously closed only on a 200 response,
	// leaking the connection for any other status; always close it now.
	metricsReq, merr := http.NewRequestWithContext(ctx, "GET", endpoint+"/metrics", nil)
	if merr == nil {
		if metricsResp, doErr := client.Do(metricsReq); doErr == nil {
			defer metricsResp.Body.Close()
			if metricsResp.StatusCode == 200 {
				// Could parse Prometheus metrics here for detailed stats
			}
		}
	}
	return status, nil
}
// RunVLLMTask executes a vLLM service task.
// It parses and validates the "vllm_config" task-metadata entry and
// returns a JSON acknowledgement; the actual service execution goes
// through the standard service runner. This is called by the worker when
// it receives a vLLM job assignment.
func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metadata map[string]string) ([]byte, error) {
	logger = logger.With("plugin", "vllm", "task_id", taskID)

	// Parse the vLLM configuration from task metadata.
	configJSON, ok := metadata["vllm_config"]
	if !ok {
		return nil, fmt.Errorf("vllm: missing 'vllm_config' in task metadata")
	}
	var cfg VLLMConfig
	if err := json.Unmarshal([]byte(configJSON), &cfg); err != nil {
		return nil, fmt.Errorf("vllm: invalid config: %w", err)
	}
	if err := NewVLLMPlugin(logger, "").ValidateConfig(&cfg); err != nil {
		return nil, err
	}

	logger.Info("starting vllm service",
		"model", cfg.Model,
		"tensor_parallel", cfg.TensorParallelSize,
	)

	// Port allocation and command building would happen here; the actual
	// execution goes through the standard service runner.
	return json.Marshal(map[string]any{
		"status": "starting",
		"model":  cfg.Model,
		"plugin": "vllm",
	})
}
// ServiceTemplate is re-exported here for plugin use (a type alias, so it
// is interchangeable with the canonical declaration).
type ServiceTemplate = struct {
	JobType     string               `json:"job_type"`
	SlotPool    string               `json:"slot_pool"`
	GPUCount    int                  `json:"gpu_count"`
	Command     []string             `json:"command"`
	Env         map[string]string    `json:"env"`
	HealthCheck ServiceHealthCheck   `json:"health_check"`
	Mounts      []ServiceMount       `json:"mounts,omitempty"`
	Resources   ResourceRequirements `json:"resources,omitempty"`
}

// ServiceHealthCheck is re-exported for plugin use.
// Interval and Timeout are in seconds (see GetServiceTemplate usage).
type ServiceHealthCheck = struct {
	Liveness  string `json:"liveness"`
	Readiness string `json:"readiness"`
	Interval  int    `json:"interval"`
	Timeout   int    `json:"timeout"`
}

// ServiceMount is re-exported for plugin use; Source/Destination are
// host/container paths respectively.
type ServiceMount = struct {
	Source      string `json:"source"`
	Destination string `json:"destination"`
	ReadOnly    bool   `json:"readonly,omitempty"`
}

// ResourceRequirements defines resource needs for a service.
// Memory bounds are in gigabytes.
type ResourceRequirements struct {
	MinGPU    int `json:"min_gpu,omitempty"`
	MaxGPU    int `json:"max_gpu,omitempty"`
	MinMemory int `json:"min_memory_gb,omitempty"` // GB
	MaxMemory int `json:"max_memory_gb,omitempty"` // GB
}

View file

@ -0,0 +1,63 @@
// Package process provides process isolation and resource limiting for HIPAA compliance.
// Implements Worker Process Isolation controls.
package process
import (
"fmt"
"os"
"runtime"
"syscall"
)
// IsolationConfig holds process isolation parameters.
// Zero values mean "do not apply" for each control.
type IsolationConfig struct {
	MaxProcesses int  // Fork bomb protection (RLIMIT_NPROC on Linux)
	MaxOpenFiles int  // FD exhaustion protection (RLIMIT_NOFILE)
	DisableSwap  bool // Prevent swap exfiltration (mlockall; Linux, best-effort)
	OOMScoreAdj  int  // OOM killer priority adjustment (Linux only)
}
// ApplyIsolation applies process isolation controls to the current process.
// This should be called after forking but before execing the target command.
// Resource-limit and OOM-score failures are fatal; swap disabling is
// best-effort because it requires elevated privileges.
func ApplyIsolation(cfg IsolationConfig) error {
	// Apply resource limits (platform-specific)
	if err := applyResourceLimits(cfg); err != nil {
		return err
	}
	// OOM score adjustment (only on Linux)
	if cfg.OOMScoreAdj != 0 && runtime.GOOS == "linux" {
		if err := setOOMScoreAdj(cfg.OOMScoreAdj); err != nil {
			return fmt.Errorf("failed to set OOM score adjustment: %w", err)
		}
	}
	// Disable swap (Linux only) - requires CAP_SYS_RESOURCE or root.
	// Best-effort security hardening: the error is deliberately discarded
	// (the original left an empty error-check branch; the explicit ignore
	// makes that intent visible to vet/linters).
	if cfg.DisableSwap && runtime.GOOS == "linux" {
		_ = disableSwap()
	}
	return nil
}
// setOOMScoreAdj adjusts the OOM killer score (Linux only) by writing to
// /proc/self/oom_score_adj. Lower values make the process less likely to
// be killed (negative is "never kill"); higher values make it more likely.
func setOOMScoreAdj(score int) error {
	contents := fmt.Sprintf("%d\n", score)
	return os.WriteFile("/proc/self/oom_score_adj", []byte(contents), 0644)
}
// IsolatedExec runs a command with process isolation applied, replacing
// the current process image via exec(2); on success it never returns.
// This is a helper for container execution that applies limits before exec.
// NOTE(review): syscall.Exec performs no PATH lookup — argv[0] must be a
// resolvable path; confirm callers pass an absolute path.
func IsolatedExec(argv []string, cfg IsolationConfig) error {
	if err := ApplyIsolation(cfg); err != nil {
		return err
	}
	return syscall.Exec(argv[0], argv, os.Environ())
}

View file

@ -0,0 +1,122 @@
//go:build !windows
// +build !windows
package process
import (
"fmt"
"syscall"
)
// applyResourceLimits sets Unix/Linux resource limits from cfg.
// Zero/negative values mean "leave that limit unchanged".
func applyResourceLimits(cfg IsolationConfig) error {
	// Apply file descriptor limits (RLIMIT_NOFILE for FD exhaustion protection)
	if cfg.MaxOpenFiles > 0 {
		if err := setResourceLimit(syscall.RLIMIT_NOFILE, uint64(cfg.MaxOpenFiles)); err != nil {
			return fmt.Errorf("failed to set max open files limit: %w", err)
		}
	}
	// Apply process limits if available (Linux only). Note: despite the
	// defense-in-depth intent, a failure here IS returned to the caller;
	// setProcessLimit itself swallows not-supported errors (macOS), so
	// what surfaces is a genuine Linux failure.
	if cfg.MaxProcesses > 0 {
		if err := setProcessLimit(cfg.MaxProcesses); err != nil {
			return fmt.Errorf("failed to set max processes limit: %w", err)
		}
	}
	return nil
}
// setResourceLimit sets a soft and hard rlimit for the current process
func setResourceLimit(resource int, limit uint64) error {
rl := &syscall.Rlimit{
Cur: limit,
Max: limit,
}
return syscall.Setrlimit(resource, rl)
}
// setProcessLimit sets RLIMIT_NPROC on Linux and is a silent no-op on
// Unix platforms that reject the resource (e.g. macOS).
func setProcessLimit(maxProcs int) error {
	// RLIMIT_NPROC is Linux-specific; 7 is its resource number there.
	// On Darwin the call fails with ENOTSUP/EINVAL, which is ignored.
	const RLIMIT_NPROC = 7 // Linux value
	limit := uint64(maxProcs)
	err := syscall.Setrlimit(RLIMIT_NPROC, &syscall.Rlimit{Cur: limit, Max: limit})
	switch err {
	case nil, syscall.ENOTSUP, syscall.EINVAL:
		// Success, or unsupported on this platform: both are fine.
		return nil
	default:
		return err
	}
}
// disableSwap attempts to lock all current and future pages into memory
// via mlockall(2), preventing them from being written to swap. This is
// best-effort: it requires the CAP_IPC_LOCK capability (or root) and the
// caller is expected to tolerate failure.
func disableSwap() error {
	// MCL_CURRENT: lock all current pages
	// MCL_FUTURE: lock all future pages
	const MCL_CURRENT = 0x1
	const MCL_FUTURE = 0x2
	// Note: mlockall requires CAP_IPC_LOCK capability
	// If this fails, callers continue anyway (defense in depth).
	return syscall.Mlockall(MCL_CURRENT | MCL_FUTURE)
}
// GetCurrentLimits returns the current rlimit values for diagnostics,
// keyed as "<RESOURCE>_soft" / "<RESOURCE>_hard". NOFILE is mandatory;
// platform-specific limits are added best-effort.
func GetCurrentLimits() (map[string]uint64, error) {
	// NOFILE is available on every supported platform — a failure here is
	// treated as fatal.
	var nofile syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &nofile); err != nil {
		return nil, fmt.Errorf("failed to get NOFILE limit: %w", err)
	}
	limits := map[string]uint64{
		"NOFILE_soft": nofile.Cur,
		"NOFILE_hard": nofile.Max,
	}
	getPlatformLimits(limits)
	return limits, nil
}
// getPlatformLimits merges platform-specific rlimit readings into limits.
// Resources that cannot be read are silently skipped (best-effort
// diagnostics).
//
// BUG FIX: the previous code used 7 for RLIMIT_NPROC, but on Linux
// RLIMIT_NPROC is 6 and 7 is RLIMIT_NOFILE — so the NPROC_* keys actually
// reported the open-files limit.
func getPlatformLimits(limits map[string]uint64) {
	// Numeric values from <asm-generic/resource.h>; the Go stdlib exposes no
	// syscall constants for NPROC/RSS.
	const (
		rlimitNPROC = 6 // max processes (Linux)
		rlimitRSS   = 5 // resident set size (Linux)
	)
	read := func(resource int, name string) {
		var rl syscall.Rlimit
		if err := syscall.Getrlimit(resource, &rl); err == nil {
			limits[name+"_soft"] = rl.Cur
			limits[name+"_hard"] = rl.Max
		}
	}
	read(syscall.RLIMIT_AS, "AS")     // virtual address space
	read(syscall.RLIMIT_DATA, "DATA") // data segment
	read(rlimitNPROC, "NPROC")
	read(rlimitRSS, "RSS")
}

View file

@ -0,0 +1,41 @@
//go:build windows
// +build windows
package process
import "fmt"
// applyResourceLimits is the Windows counterpart of the Unix rlimit setup.
// Windows has no setrlimit; Job Objects would be required instead.
//
// NOTE(review): despite the original "log that limits are not enforced"
// comment, this function returns an error whenever limits are requested —
// i.e. it fails closed on Windows rather than logging. Comments updated to
// match the code.
func applyResourceLimits(cfg IsolationConfig) error {
	// Windows doesn't support setrlimit - would need Job Objects.
	// Fail closed: requesting limits on Windows is reported as an error
	// instead of being silently skipped.
	if cfg.MaxOpenFiles > 0 || cfg.MaxProcesses > 0 {
		// TODO: Use Windows Job Objects for process limits
		return fmt.Errorf("process isolation limits not implemented on Windows (max_open_files=%d, max_processes=%d)",
			cfg.MaxOpenFiles, cfg.MaxProcesses)
	}
	return nil
}
// disableSwap is a stub on Windows: mlockall(2) does not exist there, and the
// VirtualLock API is not wired up, so this always reports success without
// locking any memory.
func disableSwap() error {
	return nil
}
// GetCurrentLimits reports resource limits for diagnostics. Windows has no
// rlimits, so the result contains only whatever getPlatformLimits records
// (currently just a platform marker), and the error is always nil.
func GetCurrentLimits() (map[string]uint64, error) {
	limits := map[string]uint64{}
	getPlatformLimits(limits)
	return limits, nil
}
// getPlatformLimits records platform-specific limits. Windows exposes no
// rlimit equivalents (GetProcessWorkingSetSize could be used for memory
// later), so only a marker entry identifying the platform is written.
func getPlatformLimits(limits map[string]uint64) {
	limits["WINDOWS"] = 1 // Marker to indicate Windows platform
}

View file

@ -0,0 +1,278 @@
// Package process provides process isolation and security enforcement for worker tasks.
// This file implements Network Micro-Segmentation enforcement hooks.
//go:build linux
// +build linux
package process
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
)
// NetworkPolicy defines network segmentation rules for a task.
// Use DefaultNetworkPolicy for a hardened deny-by-default configuration;
// Validate rejects malformed values and the "host" mode.
type NetworkPolicy struct {
	// Mode is the network isolation mode: "none", "bridge", "container", "host"
	// ("host" is recognized by Validate but always rejected).
	Mode string
	// AllowedEndpoints is a list of allowed network endpoints (host:port format)
	// Only used when Mode is "bridge" or "container"
	AllowedEndpoints []string
	// BlockedSubnets is a list of CIDR ranges to block
	BlockedSubnets []string
	// DNSResolution controls DNS resolution (true = allow, false = block)
	DNSResolution bool
	// OutboundTraffic controls outbound connections (true = allow, false = block)
	OutboundTraffic bool
	// InboundTraffic controls inbound connections (true = allow, false = block)
	InboundTraffic bool
}
// DefaultNetworkPolicy returns the hardened deny-by-default policy used for
// network micro-segmentation: network mode "none", DNS and all traffic
// disabled, and the RFC 1918 private ranges explicitly blocked.
func DefaultNetworkPolicy() NetworkPolicy {
	policy := NetworkPolicy{
		Mode:             "none",
		AllowedEndpoints: []string{},
	}
	// All boolean toggles default to false (deny); only the blocklist is set.
	policy.BlockedSubnets = []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}
	return policy
}
// HIPAACompliantPolicy returns a network policy suitable for HIPAA-style
// deployments: everything is blocked by default (0.0.0.0/0), DNS and outbound
// traffic are enabled only when an explicit endpoint allowlist is supplied,
// and inbound traffic is always denied.
func HIPAACompliantPolicy(allowlist []string) NetworkPolicy {
	hasAllowlist := len(allowlist) > 0
	return NetworkPolicy{
		Mode:             "bridge",
		AllowedEndpoints: allowlist,
		BlockedSubnets:   []string{"0.0.0.0/0"}, // block everything by default
		DNSResolution:    hasAllowlist,
		OutboundTraffic:  hasAllowlist,
		InboundTraffic:   false, // never allow inbound
	}
}
// Validate checks the network policy for malformed values and insecure
// settings: the mode must be one of the known modes (and never "host"),
// endpoints must be in host:port form, and blocked subnets must look like
// CIDR ranges.
func (np *NetworkPolicy) Validate() error {
	switch np.Mode {
	case "none", "bridge", "container":
		// supported isolation modes
	case "host":
		// "host" is a recognized mode but shares the host's network namespace,
		// so it is always rejected in production.
		return fmt.Errorf("host network mode is not allowed for security reasons")
	default:
		return fmt.Errorf("invalid network mode: %q", np.Mode)
	}
	for _, endpoint := range np.AllowedEndpoints {
		if !isValidEndpoint(endpoint) {
			return fmt.Errorf("invalid endpoint format: %q (expected host:port)", endpoint)
		}
	}
	for _, cidr := range np.BlockedSubnets {
		if !isValidCIDR(cidr) {
			return fmt.Errorf("invalid CIDR format: %q", cidr)
		}
	}
	return nil
}
// isValidEndpoint reports whether endpoint is in host:port form with a
// plausible port number (1-65535).
//
// Fixes over the original: an empty host (":80"), an empty port ("db:" —
// previously accepted because the digit parser returned 0 for ""), and
// out-of-range ports ("db:0", "db:99999") are now rejected.
func isValidEndpoint(endpoint string) bool {
	parts := strings.Split(endpoint, ":")
	if len(parts) != 2 {
		return false // also rejects bare hosts and raw IPv6 literals
	}
	host, portStr := parts[0], parts[1]
	if host == "" || portStr == "" {
		return false
	}
	port, err := parsePort(portStr)
	if err != nil {
		return false
	}
	return port >= 1 && port <= 65535
}
// isValidCIDR performs structural validation of a CIDR string
// ("address/prefix").
//
// Fixes over the original: an empty prefix ("10.0.0.0/") and an empty address
// ("/8") were previously accepted, and the prefix length was not
// range-checked. The address portion is still not fully parsed — this remains
// a structural check, not a full net.ParseCIDR validation.
func isValidCIDR(cidr string) bool {
	parts := strings.Split(cidr, "/")
	if len(parts) != 2 {
		return false
	}
	addr, prefixStr := parts[0], parts[1]
	if addr == "" || prefixStr == "" {
		return false
	}
	prefix, err := parsePort(prefixStr)
	if err != nil {
		return false
	}
	// 128 covers IPv6 prefixes; IPv4 prefixes (0-32) are a subset.
	return prefix >= 0 && prefix <= 128
}
// parsePort parses a non-negative decimal string; it is used to validate both
// endpoint ports and CIDR prefix lengths.
//
// Fixes over the original: the empty string previously parsed successfully as
// 0, and arbitrarily long digit strings could overflow int. Empty input,
// non-digit characters, and values above 65535 now all return an error (the
// cap also acts as the overflow guard).
func parsePort(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("invalid port")
	}
	port := 0
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid port")
		}
		port = port*10 + int(c-'0')
		if port > 65535 {
			return 0, fmt.Errorf("invalid port")
		}
	}
	return port, nil
}
// ApplyNetworkPolicy validates the policy and returns baseArgs extended with
// podman network options: the --network mode, restriction marker env vars for
// restricted bridge mode, and an empty resolv.conf bind mount when DNS is
// disabled. External iptables rules (SetupExternalFirewall) provide the
// actual egress filtering.
//
// Fix over the original: baseArgs is copied before appending so that the
// appends here can never clobber elements sharing the caller's backing array.
func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) {
	if err := policy.Validate(); err != nil {
		return nil, fmt.Errorf("invalid network policy: %w", err)
	}
	// Defensive copy of the caller's slice (slice-aliasing hazard).
	args := make([]string, len(baseArgs), len(baseArgs)+8)
	copy(args, baseArgs)
	args = append(args, "--network", policy.Mode)
	// For bridge mode with specific restrictions, external firewall rules will
	// restrict connectivity; env vars inform the container it is restricted.
	if policy.Mode == "bridge" && len(policy.AllowedEndpoints) > 0 {
		args = append(args, "-e", "FETCHML_NETWORK_RESTRICTED=1")
		if !policy.DNSResolution {
			args = append(args, "-e", "FETCHML_DNS_DISABLED=1")
		}
	}
	// Disable DNS by bind-mounting an empty resolv.conf. Deliberately
	// best-effort: if the file cannot be created, the mount is skipped and DNS
	// stays enabled — NOTE(review): confirm callers accept this fallback.
	if !policy.DNSResolution {
		if emptyResolv, err := createEmptyResolvConf(); err == nil {
			args = append(args, "-v", fmt.Sprintf("%s:/etc/resolv.conf:ro", emptyResolv))
		}
	}
	return args, nil
}
// createEmptyResolvConf returns the path of an empty resolv.conf in the
// system temp directory, suitable for bind-mounting to disable DNS.
//
// Fix over the original: the file lives at a fixed, shared path, and the old
// code reused any pre-existing file without inspecting it — a non-empty file
// planted there would have been mounted verbatim, silently re-enabling DNS.
// Now the file is reused only if it is a regular, zero-length file; otherwise
// it is (re)written as empty.
func createEmptyResolvConf() (string, error) {
	path := filepath.Join(os.TempDir(), "empty-resolv.conf")
	if info, err := os.Stat(path); err == nil && info.Mode().IsRegular() && info.Size() == 0 {
		return path, nil // already a regular empty file; safe to reuse
	}
	if err := os.WriteFile(path, []byte{}, 0644); err != nil {
		return "", err
	}
	return path, nil
}
// SetupExternalFirewall installs iptables rules inside a running container's
// network namespace (via nsenter) to enforce egress filtering. It is called
// after the container starts and requires root or CAP_NET_ADMIN.
//
// BUG FIX: policy.BlockedSubnets participated in the early-return check but
// no rules were ever created for it — blocked subnets were silently
// unenforced. DROP rules are now appended for each blocked CIDR. Allowlisted
// endpoints are inserted at the head of OUTPUT, so an allowed endpoint takes
// precedence over both the default DROP and any subnet DROP.
func SetupExternalFirewall(containerID string, policy NetworkPolicy) error {
	if len(policy.BlockedSubnets) == 0 && len(policy.AllowedEndpoints) == 0 {
		return nil // No rules to apply
	}
	// Target the container's network namespace by PID.
	pid, err := getContainerPID(containerID)
	if err != nil {
		return fmt.Errorf("failed to get container PID: %w", err)
	}
	// Block all outbound traffic by default.
	if !policy.OutboundTraffic {
		cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-A", "OUTPUT", "-j", "DROP")
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to block outbound traffic: %w", err)
		}
	}
	// Drop traffic to each explicitly blocked subnet (appended, so the
	// allowlist inserts below still win).
	for _, cidr := range policy.BlockedSubnets {
		cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-A", "OUTPUT",
			"-d", cidr, "-j", "DROP")
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to block subnet %s: %w", cidr, err)
		}
	}
	// Allow specific endpoints ahead of every DROP rule.
	for _, endpoint := range policy.AllowedEndpoints {
		host, port := parseEndpoint(endpoint)
		if host != "" {
			cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-I", "OUTPUT", "1",
				"-p", "tcp", "-d", host, "--dport", port, "-j", "ACCEPT")
			if err := cmd.Run(); err != nil {
				return fmt.Errorf("failed to allow endpoint %s: %w", endpoint, err)
			}
		}
	}
	return nil
}
// getContainerPID asks podman for the host PID of a running container,
// returned as a trimmed decimal string suitable for nsenter's -t flag.
func getContainerPID(containerID string) (string, error) {
	out, err := exec.Command("podman", "inspect", "-f", "{{.State.Pid}}", containerID).Output()
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(string(out)), nil
}
// parseEndpoint splits a "host:port" string into its two parts. Inputs with
// zero or more than one colon (bare hosts, raw IPv6 literals) yield two empty
// strings.
func parseEndpoint(endpoint string) (host, port string) {
	// Exactly one separator means a well-formed host:port pair.
	if strings.Count(endpoint, ":") != 1 {
		return "", ""
	}
	sep := strings.Index(endpoint, ":")
	return endpoint[:sep], endpoint[sep+1:]
}
// NetworkPolicyFromSandbox builds a NetworkPolicy from raw sandbox settings,
// applying hardened defaults: mode "none" when unspecified, the default
// blocked subnets when none are given, and DNS/outbound enabled only for a
// networked mode with an explicit endpoint allowlist. Inbound is always
// denied.
func NetworkPolicyFromSandbox(
	networkMode string,
	allowedEndpoints []string,
	blockedSubnets []string,
) NetworkPolicy {
	mode := networkMode
	if mode == "" {
		mode = "none"
	}
	subnets := blockedSubnets
	if len(subnets) == 0 {
		subnets = DefaultNetworkPolicy().BlockedSubnets
	}
	egressAllowed := mode != "none" && len(allowedEndpoints) > 0
	return NetworkPolicy{
		Mode:             mode,
		AllowedEndpoints: allowedEndpoints,
		BlockedSubnets:   subnets,
		DNSResolution:    egressAllowed,
		OutboundTraffic:  egressAllowed,
		InboundTraffic:   false, // never allow inbound by default
	}
}

View file

@ -0,0 +1,91 @@
// Package process provides process isolation and security enforcement for worker tasks.
// This file implements Network Micro-Segmentation enforcement hooks (Windows stub).
//go:build windows
// +build windows
package process
import (
"fmt"
)
// NetworkPolicy defines network segmentation rules for a task
// (Windows stub - policy enforcement handled differently on Windows).
// Field meanings mirror the Unix implementation; on Windows only Mode "none"
// (or empty) passes Validate, so the remaining fields are carried through but
// not enforced here.
type NetworkPolicy struct {
	Mode             string   // isolation mode; only "none"/"" is accepted by Validate on Windows
	AllowedEndpoints []string // host:port allowlist (carried, unenforced on Windows)
	BlockedSubnets   []string // CIDR blocklist (carried, unenforced on Windows)
	DNSResolution    bool     // DNS toggle (carried, unenforced on Windows)
	OutboundTraffic  bool     // outbound toggle (carried, unenforced on Windows)
	InboundTraffic   bool     // inbound toggle (carried, unenforced on Windows)
}
// DefaultNetworkPolicy returns the hardened default policy (Windows stub):
// network mode "none", empty endpoint/subnet lists, and every DNS/traffic
// toggle off.
func DefaultNetworkPolicy() NetworkPolicy {
	policy := NetworkPolicy{Mode: "none"}
	policy.AllowedEndpoints = []string{}
	policy.BlockedSubnets = []string{}
	// Boolean toggles stay at their zero value (false = deny).
	return policy
}
// HIPAACompliantPolicy returns a HIPAA-oriented policy (Windows stub): the
// allowlist is preserved for bookkeeping, but the mode stays "none" and all
// DNS/traffic toggles remain off since enforcement is unavailable here.
func HIPAACompliantPolicy(allowlist []string) NetworkPolicy {
	return NetworkPolicy{
		Mode:             "none",
		AllowedEndpoints: allowlist,
		BlockedSubnets:   []string{},
		DNSResolution:    false,
		OutboundTraffic:  false,
		InboundTraffic:   false,
	}
}
// Validate checks the network policy (Windows stub). Without Windows
// Firewall integration only mode "none" is supported; an empty mode is
// treated as unset and accepted.
func (np *NetworkPolicy) Validate() error {
	switch np.Mode {
	case "", "none":
		return nil
	}
	return fmt.Errorf("network mode %q not supported on Windows (use 'none' or implement via Windows Firewall)", np.Mode)
}
// ApplyNetworkPolicy validates the policy and appends the container network
// option to baseArgs (Windows stub — no firewall enforcement).
//
// Fixes over the original: (1) an empty Mode passes Validate but was emitted
// verbatim as `--network ""`; it is now normalized to "none". (2) baseArgs is
// copied before appending so the caller's slice backing array is never
// mutated.
func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) {
	if err := policy.Validate(); err != nil {
		return nil, fmt.Errorf("invalid network policy: %w", err)
	}
	mode := policy.Mode
	if mode == "" {
		mode = "none" // Validate accepts "", but the runtime needs a concrete mode
	}
	args := make([]string, len(baseArgs), len(baseArgs)+2)
	copy(args, baseArgs)
	return append(args, "--network", mode), nil
}
// SetupExternalFirewall sets up external firewall rules (Windows stub).
// Always a successful no-op: Windows Firewall integration would require
// PowerShell or netsh, so enforcement currently relies on the container
// runtime's default restrictions. Both parameters are ignored.
func SetupExternalFirewall(containerID string, policy NetworkPolicy) error {
	// Windows firewall integration would require PowerShell or netsh
	// For now, this is a no-op - rely on container runtime's default restrictions
	return nil
}
// NetworkPolicyFromSandbox builds a NetworkPolicy from sandbox settings
// (Windows stub). The mode defaults to "none" when unspecified, endpoint and
// subnet lists are carried through unchanged, and every DNS/traffic toggle is
// left off.
func NetworkPolicyFromSandbox(
	networkMode string,
	allowedEndpoints []string,
	blockedSubnets []string,
) NetworkPolicy {
	mode := networkMode
	if mode == "" {
		mode = "none"
	}
	// Boolean toggles stay at their zero value (false).
	return NetworkPolicy{
		Mode:             mode,
		AllowedEndpoints: allowedEndpoints,
		BlockedSubnets:   blockedSubnets,
	}
}

View file

@ -0,0 +1,163 @@
package plugins__test
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
)
// fakeJupyterManager is a configurable test double for the worker's Jupyter
// manager dependency: each interface method delegates to the corresponding
// function field so individual tests can script per-call behavior.
type fakeJupyterManager struct {
	startFn    func(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) // backs StartService
	stopFn     func(ctx context.Context, serviceID string) error                                     // backs StopService
	removeFn   func(ctx context.Context, serviceID string, purge bool) error                         // backs RemoveService
	restoreFn  func(ctx context.Context, name string) (string, error)                                // backs RestoreWorkspace
	listFn     func() []*jupyter.JupyterService                                                      // backs ListServices
	listPkgsFn func(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)     // backs ListInstalledPackages
}
// StartService delegates to startFn. Consistency fix: a nil hook is now a
// successful no-op (as ListInstalledPackages already did), so fixtures may
// leave it unset without a nil-dereference panic.
func (f *fakeJupyterManager) StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
	if f.startFn == nil {
		return nil, nil
	}
	return f.startFn(ctx, req)
}
// StopService delegates to stopFn. Consistency fix: a nil hook is a
// successful no-op instead of a panic.
func (f *fakeJupyterManager) StopService(ctx context.Context, serviceID string) error {
	if f.stopFn == nil {
		return nil
	}
	return f.stopFn(ctx, serviceID)
}
// RemoveService delegates to removeFn. Consistency fix: a nil hook is a
// successful no-op instead of a panic.
func (f *fakeJupyterManager) RemoveService(ctx context.Context, serviceID string, purge bool) error {
	if f.removeFn == nil {
		return nil
	}
	return f.removeFn(ctx, serviceID, purge)
}
// RestoreWorkspace delegates to restoreFn. Consistency fix: a nil hook
// returns an empty path and no error instead of panicking.
func (f *fakeJupyterManager) RestoreWorkspace(ctx context.Context, name string) (string, error) {
	if f.restoreFn == nil {
		return "", nil
	}
	return f.restoreFn(ctx, name)
}
// ListServices delegates to listFn. Consistency fix: a nil hook yields an
// empty result instead of panicking.
func (f *fakeJupyterManager) ListServices() []*jupyter.JupyterService {
	if f.listFn == nil {
		return nil
	}
	return f.listFn()
}
// ListInstalledPackages delegates to listPkgsFn, treating a nil hook as an
// empty successful result so tests need not always provide one.
func (f *fakeJupyterManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
	if fn := f.listPkgsFn; fn != nil {
		return fn(ctx, serviceName)
	}
	return nil, nil
}
// jupyterOutput mirrors the JSON emitted by RunJupyterTask for start-style
// actions; only the fields asserted by the tests are decoded.
type jupyterOutput struct {
	Type    string `json:"type"` // action/result discriminator
	Service *struct {
		Name string `json:"name"`
		URL  string `json:"url"`
	} `json:"service"` // nil when the output carries no service record
}
// jupyterPackagesOutput mirrors the JSON emitted by RunJupyterTask for the
// "list_packages" action; only the asserted fields are decoded.
type jupyterPackagesOutput struct {
	Type     string `json:"type"` // action/result discriminator
	Packages []struct {
		Name    string `json:"name"`
		Version string `json:"version"`
		Source  string `json:"source"` // e.g. "pip" or "conda" in fixtures below
	} `json:"packages"`
}
// TestRunJupyterTaskStartSuccess drives RunJupyterTask through the "start"
// action and verifies the fake manager's service is echoed back in the
// plugin's JSON output.
func TestRunJupyterTaskStartSuccess(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		// Only startFn inspects its input; the remaining hooks are benign stubs.
		startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
			if req.Name != "my-workspace" {
				return nil, errors.New("bad name")
			}
			return &jupyter.JupyterService{Name: req.Name, URL: "http://127.0.0.1:8888"}, nil
		},
		stopFn:     func(context.Context, string) error { return nil },
		removeFn:   func(context.Context, string, bool) error { return nil },
		restoreFn:  func(context.Context, string) (string, error) { return "", nil },
		listFn:     func() []*jupyter.JupyterService { return nil },
		listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil },
	})
	// Metadata keys route the task to the jupyter plugin's "start" action.
	task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
		"task_type":         "jupyter",
		"jupyter_action":    "start",
		"jupyter_name":      "my-workspace",
		"jupyter_workspace": "my-workspace",
	}}
	out, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	if len(out) == 0 {
		t.Fatalf("expected output")
	}
	// The output must be valid JSON carrying the started service's name.
	var decoded jupyterOutput
	if err := json.Unmarshal(out, &decoded); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	if decoded.Service == nil || decoded.Service.Name != "my-workspace" {
		t.Fatalf("expected service name to be my-workspace, got %#v", decoded.Service)
	}
}
// TestRunJupyterTaskStopFailure verifies that a StopService error from the
// manager is propagated by RunJupyterTask for the "stop" action.
func TestRunJupyterTaskStopFailure(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
		// The failure under test: stopping the service errors out.
		stopFn:     func(context.Context, string) error { return errors.New("stop failed") },
		removeFn:   func(context.Context, string, bool) error { return nil },
		restoreFn:  func(context.Context, string) (string, error) { return "", nil },
		listFn:     func() []*jupyter.JupyterService { return nil },
		listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil },
	})
	task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
		"task_type":          "jupyter",
		"jupyter_action":     "stop",
		"jupyter_service_id": "svc-1",
	}}
	_, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err == nil {
		t.Fatalf("expected error")
	}
}
// TestRunJupyterTaskListPackagesSuccess drives RunJupyterTask through the
// "list_packages" action and verifies the fake manager's package list is
// serialized into the JSON output.
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		startFn:   func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
		stopFn:    func(context.Context, string) error { return nil },
		removeFn:  func(context.Context, string, bool) error { return nil },
		restoreFn: func(context.Context, string) (string, error) { return "", nil },
		listFn:    func() []*jupyter.JupyterService { return nil },
		// Returns two fixture packages only for the expected service name.
		listPkgsFn: func(_ context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
			if serviceName != "my-workspace" {
				return nil, errors.New("bad service")
			}
			return []jupyter.InstalledPackage{
				{Name: "numpy", Version: "1.26.0", Source: "pip"},
				{Name: "pandas", Version: "2.1.0", Source: "conda"},
			}, nil
		},
	})
	task := &queue.Task{JobName: "jupyter-packages-my-workspace", Metadata: map[string]string{
		"task_type":      "jupyter",
		"jupyter_action": "list_packages",
		"jupyter_name":   "my-workspace",
	}}
	out, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	var decoded jupyterPackagesOutput
	if err := json.Unmarshal(out, &decoded); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	// Both fixture packages should round-trip through the plugin output.
	if len(decoded.Packages) != 2 {
		t.Fatalf("expected 2 packages, got %d", len(decoded.Packages))
	}
}

View file

@ -0,0 +1,257 @@
package plugins__test
import (
"context"
"encoding/json"
"log/slog"
"os"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewVLLMPlugin verifies the constructor yields a non-nil plugin both for
// the default image (empty string) and for an explicitly supplied image.
func TestNewVLLMPlugin(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	defaultImage := plugins.NewVLLMPlugin(logger, "")
	assert.NotNil(t, defaultImage)
	customImage := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0")
	assert.NotNil(t, customImage)
}
// TestVLLMPluginValidateConfig exercises ValidateConfig over a table of
// configs. Per the cases below: a named model with sane values passes, an
// empty model is rejected, GpuMemoryUtilization values of 1.5 and -0.1 are
// rejected, quantization "awq"/"gptq" are accepted, and an unknown
// quantization string is rejected.
func TestVLLMPluginValidateConfig(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	plugin := plugins.NewVLLMPlugin(logger, "")
	tests := []struct {
		name    string
		config  plugins.VLLMConfig
		wantErr bool
	}{
		{
			name: "valid config",
			config: plugins.VLLMConfig{
				Model:                "meta-llama/Llama-2-7b",
				TensorParallelSize:   1,
				GpuMemoryUtilization: 0.9,
			},
			wantErr: false,
		},
		{
			name: "missing model",
			config: plugins.VLLMConfig{
				Model: "",
			},
			wantErr: true,
		},
		{
			name: "invalid gpu memory high",
			config: plugins.VLLMConfig{
				Model:                "test-model",
				GpuMemoryUtilization: 1.5,
			},
			wantErr: true,
		},
		{
			name: "invalid gpu memory low",
			config: plugins.VLLMConfig{
				Model:                "test-model",
				GpuMemoryUtilization: -0.1,
			},
			wantErr: true,
		},
		{
			name: "valid quantization awq",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "awq",
			},
			wantErr: false,
		},
		{
			name: "valid quantization gptq",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "gptq",
			},
			wantErr: false,
		},
		{
			name: "invalid quantization",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "invalid",
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := plugin.ValidateConfig(&tt.config)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestVLLMPluginBuildCommand checks the generated vLLM server command line.
// Each case only asserts that the expected tokens are CONTAINED in the
// result (plus the always-present --model/--port flags); it does not pin the
// full argument order.
func TestVLLMPluginBuildCommand(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	plugin := plugins.NewVLLMPlugin(logger, "")
	tests := []struct {
		name     string
		config   plugins.VLLMConfig
		port     int
		expected []string
	}{
		{
			name: "basic command",
			config: plugins.VLLMConfig{
				Model: "meta-llama/Llama-2-7b",
			},
			port: 8000,
			expected: []string{
				"python", "-m", "vllm.entrypoints.openai.api_server",
				"--model", "meta-llama/Llama-2-7b",
				"--port", "8000",
				"--host", "0.0.0.0",
			},
		},
		{
			name: "with tensor parallelism",
			config: plugins.VLLMConfig{
				Model:              "meta-llama/Llama-2-70b",
				TensorParallelSize: 4,
			},
			port: 8000,
			expected: []string{
				"--tensor-parallel-size", "4",
			},
		},
		{
			name: "with quantization",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "awq",
			},
			port: 8000,
			expected: []string{
				"--quantization", "awq",
			},
		},
		{
			name: "with trust remote code",
			config: plugins.VLLMConfig{
				Model:           "custom-model",
				TrustRemoteCode: true,
			},
			port: 8000,
			expected: []string{
				"--trust-remote-code",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cmd := plugin.BuildCommand(&tt.config, tt.port)
			for _, exp := range tt.expected {
				assert.Contains(t, cmd, exp)
			}
			// Model and port flags must be emitted regardless of the case.
			assert.Contains(t, cmd, "--model")
			assert.Contains(t, cmd, "--port")
			assert.Contains(t, cmd, "8000")
		})
	}
}
// TestVLLMPluginGetServiceTemplate pins the service template's scheduling
// metadata, health-check endpoints/timings, mount layout, and resource
// bounds as currently produced by GetServiceTemplate.
func TestVLLMPluginGetServiceTemplate(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	plugin := plugins.NewVLLMPlugin(logger, "")
	template := plugin.GetServiceTemplate()
	assert.Equal(t, "service", template.JobType)
	assert.Equal(t, "service", template.SlotPool)
	assert.Equal(t, 1, template.GPUCount)
	// Verify health check endpoints (both probes use /health with the
	// {{SERVICE_PORT}} placeholder) and their interval/timeout seconds.
	assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Liveness)
	assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Readiness)
	assert.Equal(t, 30, template.HealthCheck.Interval)
	assert.Equal(t, 10, template.HealthCheck.Timeout)
	// Verify mounts: the first must map the model cache placeholder onto the
	// HuggingFace cache directory.
	require.Len(t, template.Mounts, 2)
	assert.Equal(t, "{{MODEL_CACHE}}", template.Mounts[0].Source)
	assert.Equal(t, "/root/.cache/huggingface", template.Mounts[0].Destination)
	// Verify resource requirements (GPU count bounds and minimum memory).
	assert.Equal(t, 1, template.Resources.MinGPU)
	assert.Equal(t, 8, template.Resources.MaxGPU)
	assert.Equal(t, 16, template.Resources.MinMemory)
}
// TestRunVLLMTask covers RunVLLMTask's metadata handling: missing or
// malformed "vllm_config" entries and configs failing validation must error
// with the messages asserted below, while a valid config yields a JSON
// payload describing the starting service.
func TestRunVLLMTask(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	ctx := context.Background()
	t.Run("missing config", func(t *testing.T) {
		metadata := map[string]string{}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "missing 'vllm_config'")
	})
	t.Run("invalid config json", func(t *testing.T) {
		metadata := map[string]string{
			"vllm_config": "invalid json",
		}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid config")
	})
	t.Run("invalid config values", func(t *testing.T) {
		config := plugins.VLLMConfig{
			Model: "", // Missing model
		}
		configJSON, _ := json.Marshal(config)
		metadata := map[string]string{
			"vllm_config": string(configJSON),
		}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "model name is required")
	})
	t.Run("successful start", func(t *testing.T) {
		config := plugins.VLLMConfig{
			Model:              "meta-llama/Llama-2-7b",
			TensorParallelSize: 1,
		}
		configJSON, _ := json.Marshal(config)
		metadata := map[string]string{
			"vllm_config": string(configJSON),
		}
		output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		require.NoError(t, err)
		// The success payload is JSON reporting status/model/plugin.
		var result map[string]interface{}
		err = json.Unmarshal(output, &result)
		require.NoError(t, err)
		assert.Equal(t, "starting", result["status"])
		assert.Equal(t, "meta-llama/Llama-2-7b", result["model"])
		assert.Equal(t, "vllm", result["plugin"])
	})
}