From 95adcba43717861a69c526cc3a38f0cb80a648e8 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 26 Feb 2026 12:03:59 -0500 Subject: [PATCH] feat(worker): add Jupyter/vLLM plugins and process isolation Extend worker capabilities with new execution plugins and security features: - Jupyter plugin for notebook-based ML experiments - vLLM plugin for LLM inference workloads - Cross-platform process isolation (Unix/Windows) - Network policy enforcement with platform-specific implementations - Service manager integration for lifecycle management - Scheduler backend integration for queue coordination Update lifecycle management: - Enhanced runloop with state transitions - Service manager integration for plugin coordination - Improved state persistence and recovery Add test coverage: - Unit tests for Jupyter and vLLM plugins - Updated worker execution tests --- internal/queue/scheduler_backend.go | 227 ++++++++++++++ internal/worker/lifecycle/runloop.go | 18 +- internal/worker/lifecycle/service_manager.go | 248 ++++++++++++++++ internal/worker/lifecycle/states.go | 24 +- internal/worker/plugins/jupyter.go | 133 +++++++++ internal/worker/plugins/vllm.go | 279 ++++++++++++++++++ internal/worker/process/isolation.go | 63 ++++ internal/worker/process/isolation_unix.go | 122 ++++++++ internal/worker/process/isolation_windows.go | 41 +++ internal/worker/process/network_policy.go | 278 +++++++++++++++++ .../worker/process/network_policy_windows.go | 91 ++++++ .../unit/worker/plugins/jupyter_task_test.go | 163 ++++++++++ tests/unit/worker/plugins/vllm_test.go | 257 ++++++++++++++++ 13 files changed, 1928 insertions(+), 16 deletions(-) create mode 100644 internal/queue/scheduler_backend.go create mode 100644 internal/worker/lifecycle/service_manager.go create mode 100644 internal/worker/plugins/jupyter.go create mode 100644 internal/worker/plugins/vllm.go create mode 100644 internal/worker/process/isolation.go create mode 100644 internal/worker/process/isolation_unix.go create 
mode 100644 internal/worker/process/isolation_windows.go create mode 100644 internal/worker/process/network_policy.go create mode 100644 internal/worker/process/network_policy_windows.go create mode 100644 tests/unit/worker/plugins/jupyter_task_test.go create mode 100644 tests/unit/worker/plugins/vllm_test.go diff --git a/internal/queue/scheduler_backend.go b/internal/queue/scheduler_backend.go new file mode 100644 index 0000000..efdb312 --- /dev/null +++ b/internal/queue/scheduler_backend.go @@ -0,0 +1,227 @@ +package queue + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" +) + +// SchedulerConn defines the interface for scheduler connection +type SchedulerConn interface { + Send(msg scheduler.Message) +} + +// SchedulerBackend implements queue.Backend by communicating with a scheduler over WebSocket +type SchedulerBackend struct { + conn SchedulerConn + pendingTasks chan *Task + prewarmHint *Task + localSlots scheduler.SlotStatus +} + +// NewSchedulerBackend creates a new scheduler backend +func NewSchedulerBackend(conn SchedulerConn) *SchedulerBackend { + return &SchedulerBackend{ + conn: conn, + pendingTasks: make(chan *Task, 1), + } +} + +// GetNextTaskWithLeaseBlocking implements queue.Backend +func (sb *SchedulerBackend) GetNextTaskWithLeaseBlocking( + workerID string, + leaseDuration time.Duration, + blockTimeout time.Duration, +) (*Task, error) { + // Signal readiness with current slot status + sb.conn.Send(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: mustMarshal(scheduler.ReadyPayload{ + WorkerID: workerID, + Slots: sb.localSlots, + Reason: "polling", + }), + }) + + // Wait for scheduler push or timeout + select { + case task := <-sb.pendingTasks: + return task, nil + case <-time.After(blockTimeout): + return nil, nil // RunLoop retries — same behaviour as empty queue + } +} + +// OnJobAssign is called when the scheduler pushes a job assignment +func (sb *SchedulerBackend) 
OnJobAssign(spec *scheduler.JobSpec) { + task := &Task{ + ID: spec.ID, + JobName: "distributed-job", + Priority: 1, + Status: "assigned", + } + select { + case sb.pendingTasks <- task: + default: + // Channel full, drop (shouldn't happen with buffer of 1) + } +} + +// UpdateTask implements queue.Backend - forwards state changes to scheduler +func (sb *SchedulerBackend) UpdateTask(task *Task) error { + // Notify scheduler of state change + result := scheduler.JobResultPayload{ + TaskID: task.ID, + State: task.Status, + } + if task.Error != "" { + result.Error = task.Error + } + sb.conn.Send(scheduler.Message{ + Type: scheduler.MsgJobResult, + Payload: mustMarshal(result), + }) + return nil +} + +// PeekNextTask implements queue.Backend - returns last prewarm hint +func (sb *SchedulerBackend) PeekNextTask() (*Task, error) { + if sb.prewarmHint == nil { + return nil, nil + } + return sb.prewarmHint, nil +} + +// OnPrewarmHint stores a prewarm hint from the scheduler +func (sb *SchedulerBackend) OnPrewarmHint(hint scheduler.PrewarmHintPayload) { + sb.prewarmHint = &Task{ + ID: hint.TaskID, + Status: "prewarm_hint", + Metadata: map[string]string{ + "snapshot_id": hint.SnapshotID, + "snapshot_sha": hint.SnapshotSHA, + }, + } +} + +// UpdateSlots updates the local slot status for readiness signaling +func (sb *SchedulerBackend) UpdateSlots(running, max int) { + sb.localSlots = scheduler.SlotStatus{ + BatchTotal: max, + BatchInUse: running, + } +} + +// Required queue.Backend methods (stub implementations for now) + +func (sb *SchedulerBackend) AddTask(task *Task) error { + return fmt.Errorf("AddTask not supported in distributed mode - submit to scheduler instead") +} + +func (sb *SchedulerBackend) GetNextTask() (*Task, error) { + return nil, fmt.Errorf("GetNextTask not supported - use GetNextTaskWithLeaseBlocking") +} + +func (sb *SchedulerBackend) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) { + return 
sb.GetNextTaskWithLeaseBlocking(workerID, leaseDuration, 0) +} + +func (sb *SchedulerBackend) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error { + // Distributed mode: scheduler manages leases + return nil +} + +func (sb *SchedulerBackend) ReleaseLease(taskID string, workerID string) error { + // Distributed mode: scheduler manages leases + return nil +} + +func (sb *SchedulerBackend) RetryTask(task *Task) error { + sb.conn.Send(scheduler.Message{ + Type: scheduler.MsgJobResult, + Payload: mustMarshal(scheduler.JobResultPayload{TaskID: task.ID, State: "retry"}), + }) + return nil +} + +func (sb *SchedulerBackend) MoveToDeadLetterQueue(task *Task, reason string) error { + return nil // Scheduler handles this +} + +func (sb *SchedulerBackend) GetTask(taskID string) (*Task, error) { + return nil, fmt.Errorf("GetTask not supported in distributed mode") +} + +func (sb *SchedulerBackend) GetAllTasks() ([]*Task, error) { + return nil, fmt.Errorf("GetAllTasks not supported in distributed mode") +} + +func (sb *SchedulerBackend) GetTaskByName(jobName string) (*Task, error) { + return nil, fmt.Errorf("GetTaskByName not supported in distributed mode") +} + +func (sb *SchedulerBackend) CancelTask(taskID string) error { + return fmt.Errorf("CancelTask not supported - use scheduler API") +} + +func (sb *SchedulerBackend) UpdateTaskWithMetrics(task *Task, action string) error { + return sb.UpdateTask(task) +} + +func (sb *SchedulerBackend) RecordMetric(jobName, metric string, value float64) error { + return nil // Metrics handled by scheduler +} + +func (sb *SchedulerBackend) Heartbeat(workerID string) error { + return nil // Heartbeat handled by SchedulerConn +} + +func (sb *SchedulerBackend) QueueDepth() (int64, error) { + return 0, nil // Queue depth managed by scheduler +} + +func (sb *SchedulerBackend) SetWorkerPrewarmState(state PrewarmState) error { + return nil +} + +func (sb *SchedulerBackend) ClearWorkerPrewarmState(workerID string) error { 
+ return nil +} + +func (sb *SchedulerBackend) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) { + return nil, nil +} + +func (sb *SchedulerBackend) GetAllWorkerPrewarmStates() ([]PrewarmState, error) { + return nil, nil +} + +func (sb *SchedulerBackend) SignalPrewarmGC() error { + return nil +} + +func (sb *SchedulerBackend) PrewarmGCRequestValue() (string, error) { + return "", nil +} + +func (sb *SchedulerBackend) Close() error { + return nil +} + +// Conn returns the underlying scheduler connection for heartbeat +func (sb *SchedulerBackend) Conn() *scheduler.SchedulerConn { + if conn, ok := sb.conn.(*scheduler.SchedulerConn); ok { + return conn + } + return nil +} + +func mustMarshal(v any) []byte { + b, _ := json.Marshal(v) + return b +} + +// Compile-time check - ensure SchedulerBackend implements the interface +var _ Backend = (*SchedulerBackend)(nil) diff --git a/internal/worker/lifecycle/runloop.go b/internal/worker/lifecycle/runloop.go index ba6d687..c9e0546 100644 --- a/internal/worker/lifecycle/runloop.go +++ b/internal/worker/lifecycle/runloop.go @@ -28,18 +28,16 @@ type RunLoopConfig struct { // RunLoop manages the main worker processing loop type RunLoop struct { - config RunLoopConfig - queue queue.Backend - executor TaskExecutor - metrics MetricsRecorder - logger Logger - stateMgr *StateManager - - // State management - running map[string]context.CancelFunc - runningMu sync.RWMutex + queue queue.Backend + executor TaskExecutor + metrics MetricsRecorder + logger Logger ctx context.Context + stateMgr *StateManager + running map[string]context.CancelFunc cancel context.CancelFunc + config RunLoopConfig + runningMu sync.RWMutex } // MetricsRecorder defines the contract for recording metrics diff --git a/internal/worker/lifecycle/service_manager.go b/internal/worker/lifecycle/service_manager.go new file mode 100644 index 0000000..59f5ab8 --- /dev/null +++ b/internal/worker/lifecycle/service_manager.go @@ -0,0 +1,248 @@ +// Package 
lifecycle provides service job lifecycle management for long-running services. +package lifecycle + +import ( + "context" + "fmt" + "net/http" + "os/exec" + "syscall" + "time" + + "github.com/jfraeys/fetch_ml/internal/domain" + "github.com/jfraeys/fetch_ml/internal/scheduler" +) + +// ServiceManager handles the lifecycle of long-running service jobs (Jupyter, vLLM) +type ServiceManager struct { + task *domain.Task + spec ServiceSpec + cmd *exec.Cmd + port int + stateMgr *StateManager + logger Logger + healthCheck *scheduler.HealthCheck +} + +// ServiceSpec defines the specification for a service job +type ServiceSpec struct { + Command []string + Env map[string]string + Port int + HealthCheck *scheduler.HealthCheck +} + +// NewServiceManager creates a new service manager +func NewServiceManager(task *domain.Task, spec ServiceSpec, port int, stateMgr *StateManager, logger Logger) *ServiceManager { + return &ServiceManager{ + task: task, + spec: spec, + port: port, + stateMgr: stateMgr, + logger: logger, + healthCheck: spec.HealthCheck, + } +} + +// Run executes the service lifecycle: start, wait for ready, health loop +func (sm *ServiceManager) Run(ctx context.Context) error { + if err := sm.start(); err != nil { + sm.logger.Error("service start failed", "task_id", sm.task.ID, "error", err) + if sm.stateMgr != nil { + sm.stateMgr.Transition(sm.task, StateFailed) + } + return fmt.Errorf("start service: %w", err) + } + + // Wait for readiness + readyCtx, cancel := context.WithTimeout(ctx, 120*time.Second) + defer cancel() + if err := sm.waitReady(readyCtx); err != nil { + sm.logger.Error("service readiness check failed", "task_id", sm.task.ID, "error", err) + if sm.stateMgr != nil { + sm.stateMgr.Transition(sm.task, StateFailed) + } + sm.stop() + return fmt.Errorf("wait ready: %w", err) + } + + // Transition to serving state + if sm.stateMgr != nil { + if err := sm.stateMgr.Transition(sm.task, StateServing); err != nil { + sm.logger.Error("failed to transition to 
serving", "task_id", sm.task.ID, "error", err) + } + } + + sm.logger.Info("service is serving", "task_id", sm.task.ID, "port", sm.port) + + // Run health loop + return sm.healthLoop(ctx) +} + +// start launches the service process +func (sm *ServiceManager) start() error { + if len(sm.spec.Command) == 0 { + return fmt.Errorf("no command specified") + } + + sm.cmd = exec.Command(sm.spec.Command[0], sm.spec.Command[1:]...) + + sm.cmd.Env = sm.cmd.Environ() // inherit parent env; a non-nil Env would otherwise replace it entirely + for k, v := range sm.spec.Env { + sm.cmd.Env = append(sm.cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + + // Start process + if err := sm.cmd.Start(); err != nil { + return fmt.Errorf("start process: %w", err) + } + + return nil +} + +// waitReady blocks until the service passes readiness check or timeout +func (sm *ServiceManager) waitReady(ctx context.Context) error { + if sm.healthCheck == nil || sm.healthCheck.ReadinessEndpoint == "" { + // No readiness check - assume ready immediately + return nil + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if sm.checkReadiness() { + return nil + } + } + } +} + +// healthLoop monitors service health until context cancellation +func (sm *ServiceManager) healthLoop(ctx context.Context) error { + if sm.healthCheck == nil { + // No health check - just wait for context cancellation + <-ctx.Done() + return sm.gracefulStop() + } + + interval := time.Duration(sm.healthCheck.IntervalSecs) * time.Second + if interval == 0 { + interval = 15 * time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return sm.gracefulStop() + case <-ticker.C: + if !sm.checkLiveness() { + sm.logger.Error("service liveness check failed", "task_id", sm.task.ID) + if sm.stateMgr != nil { + sm.stateMgr.Transition(sm.task, StateFailed) + } + return fmt.Errorf("liveness check failed") + } + } + } +} + +// checkLiveness returns true if the service process
is alive +func (sm *ServiceManager) checkLiveness() bool { + if sm.cmd == nil || sm.cmd.Process == nil { + return false + } + + // Check process state + if sm.healthCheck != nil && sm.healthCheck.LivenessEndpoint != "" { + // HTTP liveness check + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(sm.healthCheck.LivenessEndpoint) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == 200 + } + + // Process existence check + return sm.cmd.Process.Signal(syscall.Signal(0)) == nil +} + +// checkReadiness returns true if the service is ready to accept traffic +func (sm *ServiceManager) checkReadiness() bool { + if sm.healthCheck == nil || sm.healthCheck.ReadinessEndpoint == "" { + return true + } + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(sm.healthCheck.ReadinessEndpoint) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == 200 +} + +// gracefulStop stops the service gracefully with a timeout +func (sm *ServiceManager) gracefulStop() error { + sm.logger.Info("gracefully stopping service", "task_id", sm.task.ID) + + if sm.stateMgr != nil { + sm.stateMgr.Transition(sm.task, StateStopping) + } + + if sm.cmd == nil || sm.cmd.Process == nil { + if sm.stateMgr != nil { + sm.stateMgr.Transition(sm.task, StateCompleted) + } + return nil + } + + // Send SIGTERM for graceful shutdown + if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil { + sm.logger.Warn("SIGTERM failed, using SIGKILL", "task_id", sm.task.ID, "error", err) + sm.cmd.Process.Kill() + } else { + // Wait for graceful shutdown + done := make(chan error, 1) + go func() { + done <- sm.cmd.Wait() + }() + + select { + case <-done: + // Graceful shutdown completed + case <-time.After(30 * time.Second): + // Timeout - force kill + sm.logger.Warn("graceful shutdown timeout, forcing kill", "task_id", sm.task.ID) + sm.cmd.Process.Kill() + } + } + + if sm.stateMgr != nil { + 
sm.stateMgr.Transition(sm.task, StateCompleted) + } + + return nil +} + +// stop forcefully stops the service +func (sm *ServiceManager) stop() error { + if sm.cmd == nil || sm.cmd.Process == nil { + return nil + } + return sm.cmd.Process.Kill() +} + +// GetPort returns the assigned port for the service +func (sm *ServiceManager) GetPort() int { + return sm.port +} diff --git a/internal/worker/lifecycle/states.go b/internal/worker/lifecycle/states.go index 607d660..08cdeb2 100644 --- a/internal/worker/lifecycle/states.go +++ b/internal/worker/lifecycle/states.go @@ -27,18 +27,30 @@ const ( StateCompleted TaskState = "completed" // StateFailed indicates the task failed during execution. StateFailed TaskState = "failed" + // StateCancelled indicates the task was cancelled by user or system. + StateCancelled TaskState = "cancelled" + // StateOrphaned indicates the worker was lost - task will be silently re-queued. + StateOrphaned TaskState = "orphaned" + // StateServing indicates a service job (Jupyter/vLLM) is up and ready. + StateServing TaskState = "serving" + // StateStopping indicates a service job is in graceful shutdown. + StateStopping TaskState = "stopping" ) // ValidTransitions defines the allowed state transitions. // The key is the "from" state, the value is a list of valid "to" states. // This enforces that state transitions follow the expected lifecycle. 
var ValidTransitions = map[TaskState][]TaskState{ - StateQueued: {StatePreparing, StateFailed}, - StatePreparing: {StateRunning, StateFailed}, - StateRunning: {StateCollecting, StateFailed}, - StateCollecting: {StateCompleted, StateFailed}, + StateQueued: {StatePreparing, StateFailed, StateCancelled}, + StatePreparing: {StateRunning, StateFailed, StateCancelled}, + StateRunning: {StateCollecting, StateFailed, StateOrphaned, StateServing, StateCancelled}, + StateCollecting: {StateCompleted, StateFailed, StateCancelled}, StateCompleted: {}, StateFailed: {}, + StateCancelled: {}, + StateOrphaned: {StateQueued}, + StateServing: {StateStopping, StateFailed, StateOrphaned, StateCancelled}, + StateStopping: {StateCompleted, StateFailed}, } // StateTransitionError is returned when an invalid state transition is attempted. @@ -53,8 +65,8 @@ func (e StateTransitionError) Error() string { // StateManager manages task state transitions with audit logging. type StateManager struct { - enabled bool auditor *audit.Logger + enabled bool } // NewStateManager creates a new state manager with the given audit logger. @@ -116,7 +128,7 @@ func (sm *StateManager) validateTransition(from, to TaskState) error { // IsTerminalState returns true if the state is terminal (no further transitions allowed). func IsTerminalState(state TaskState) bool { - return state == StateCompleted || state == StateFailed + return state == StateCompleted || state == StateFailed || state == StateCancelled } // CanTransition returns true if a transition from -> to is valid. diff --git a/internal/worker/plugins/jupyter.go b/internal/worker/plugins/jupyter.go new file mode 100644 index 0000000..49e0457 --- /dev/null +++ b/internal/worker/plugins/jupyter.go @@ -0,0 +1,133 @@ +// Package plugins provides framework-specific extensions to the worker. +// This implements the prolog/epilog model where plugins hook into task lifecycle events. 
+package plugins + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// JupyterManager interface for jupyter service management +type JupyterManager interface { + StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) + StopService(ctx context.Context, serviceID string) error + RemoveService(ctx context.Context, serviceID string, purge bool) error + RestoreWorkspace(ctx context.Context, name string) (string, error) + ListServices() []*jupyter.JupyterService + ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) +} + +// TaskRunner executes framework-specific tasks +// This is the minimal interface needed by plugins to execute tasks +type TaskRunner interface { + GetJupyterManager() JupyterManager +} + +// RunJupyterTask runs a Jupyter-related task. +// It handles start, stop, remove, restore, and list_packages actions. +// This is a plugin function that extends the core worker with Jupyter support. 
+func RunJupyterTask(ctx context.Context, runner TaskRunner, task *queue.Task) ([]byte, error) { + jm := runner.GetJupyterManager() + if jm == nil { + return nil, fmt.Errorf("jupyter manager not configured") + } + + action := task.Metadata["jupyter_action"] + if action == "" { + action = "start" // Default action + } + + switch action { + case "start": + name := task.Metadata["jupyter_name"] + if name == "" { + name = task.Metadata["jupyter_workspace"] + } + if name == "" { + // Extract from jobName if format is "jupyter-" + if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" { + name = task.JobName[8:] + } + } + if name == "" { + return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") + } + + req := &jupyter.StartRequest{Name: name} + service, err := jm.StartService(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to start jupyter service: %w", err) + } + + output := map[string]any{ + "type": "start", + "service": service, + } + return json.Marshal(output) + + case "stop": + serviceID := task.Metadata["jupyter_service_id"] + if serviceID == "" { + return nil, fmt.Errorf("missing jupyter_service_id in task metadata") + } + if err := jm.StopService(ctx, serviceID); err != nil { + return nil, fmt.Errorf("failed to stop jupyter service: %w", err) + } + return json.Marshal(map[string]string{"type": "stop", "status": "stopped"}) + + case "remove": + serviceID := task.Metadata["jupyter_service_id"] + if serviceID == "" { + return nil, fmt.Errorf("missing jupyter_service_id in task metadata") + } + purge := task.Metadata["jupyter_purge"] == "true" + if err := jm.RemoveService(ctx, serviceID, purge); err != nil { + return nil, fmt.Errorf("failed to remove jupyter service: %w", err) + } + return json.Marshal(map[string]string{"type": "remove", "status": "removed"}) + + case "restore": + name := task.Metadata["jupyter_name"] + if name == "" { + name = task.Metadata["jupyter_workspace"] + } + if name == "" { + return nil, 
fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") + } + serviceID, err := jm.RestoreWorkspace(ctx, name) + if err != nil { + return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err) + } + return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID}) + + case "list_packages": + serviceName := task.Metadata["jupyter_name"] + if serviceName == "" { + // Extract from jobName if format is "jupyter-packages-" + if len(task.JobName) > 17 && task.JobName[:17] == "jupyter-packages-" { + serviceName = task.JobName[17:] + } + } + if serviceName == "" { + return nil, fmt.Errorf("missing jupyter_name in task metadata") + } + + packages, err := jm.ListInstalledPackages(ctx, serviceName) + if err != nil { + return nil, fmt.Errorf("failed to list installed packages: %w", err) + } + + output := map[string]any{ + "type": "list_packages", + "packages": packages, + } + return json.Marshal(output) + + default: + return nil, fmt.Errorf("unknown jupyter action: %s", action) + } +} diff --git a/internal/worker/plugins/vllm.go b/internal/worker/plugins/vllm.go new file mode 100644 index 0000000..c239283 --- /dev/null +++ b/internal/worker/plugins/vllm.go @@ -0,0 +1,279 @@ +// Package plugins provides service plugin implementations for fetch_ml. +// This file contains the vLLM inference server plugin. +package plugins + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" +) + +// VLLMPlugin implements the vLLM OpenAI-compatible inference server as a fetch_ml service. +// This plugin manages the lifecycle of vLLM, including startup, health monitoring, +// and graceful shutdown.
+type VLLMPlugin struct { + logger *slog.Logger + baseImage string + envOverrides map[string]string +} + +// VLLMConfig contains the configuration for a vLLM service instance +type VLLMConfig struct { + // Model name or path (required) + Model string `json:"model"` + + // Tensor parallel size (number of GPUs) + TensorParallelSize int `json:"tensor_parallel_size,omitempty"` + + // Port to run on (0 = auto-allocate) + Port int `json:"port,omitempty"` + + // Maximum number of tokens in the cache + MaxModelLen int `json:"max_model_len,omitempty"` + + // GPU memory utilization (0.0 - 1.0) + GpuMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` + + // Quantization method (awq, gptq, squeezellm, or empty for none) + Quantization string `json:"quantization,omitempty"` + + // Trust remote code for custom models + TrustRemoteCode bool `json:"trust_remote_code,omitempty"` + + // Additional CLI arguments to pass to vLLM + ExtraArgs []string `json:"extra_args,omitempty"` +} + +// VLLMStatus represents the current status of a vLLM service +type VLLMStatus struct { + Healthy bool `json:"healthy"` + Ready bool `json:"ready"` + ModelLoaded bool `json:"model_loaded"` + Uptime time.Duration `json:"uptime"` + Requests int64 `json:"requests_served"` +} + +// NewVLLMPlugin creates a new vLLM plugin instance +func NewVLLMPlugin(logger *slog.Logger, baseImage string) *VLLMPlugin { + if baseImage == "" { + baseImage = "vllm/vllm-openai:latest" + } + return &VLLMPlugin{ + logger: logger, + baseImage: baseImage, + envOverrides: make(map[string]string), + } +} + +// GetServiceTemplate returns the service template for vLLM +func (p *VLLMPlugin) GetServiceTemplate() ServiceTemplate { + return ServiceTemplate{ + JobType: "service", + SlotPool: "service", + GPUCount: 1, // Minimum 1 GPU, can be more with tensor parallelism + + Command: []string{ + "python", "-m", "vllm.entrypoints.openai.api_server", + "--model", "{{MODEL_NAME}}", + "--port", "{{SERVICE_PORT}}", + "--host", 
"0.0.0.0", + "{{#GPU_COUNT}}--tensor-parallel-size", "{{GPU_COUNT}}", "{{/GPU_COUNT}}", + "{{#QUANTIZATION}}--quantization", "{{QUANTIZATION}}", "{{/QUANTIZATION}}", + "{{#TRUST_REMOTE_CODE}}--trust-remote-code", "{{/TRUST_REMOTE_CODE}}", + }, + + Env: map[string]string{ + "CUDA_VISIBLE_DEVICES": "{{GPU_DEVICES}}", + "VLLM_LOGGING_LEVEL": "INFO", + }, + + HealthCheck: ServiceHealthCheck{ + Liveness: "http://localhost:{{SERVICE_PORT}}/health", + Readiness: "http://localhost:{{SERVICE_PORT}}/health", + Interval: 30, + Timeout: 10, + }, + + Mounts: []ServiceMount{ + {Source: "{{MODEL_CACHE}}", Destination: "/root/.cache/huggingface"}, + {Source: "{{WORKSPACE}}", Destination: "/workspace"}, + }, + + Resources: ResourceRequirements{ + MinGPU: 1, + MaxGPU: 8, // Support up to 8-GPU tensor parallelism + MinMemory: 16, // 16GB minimum + }, + } +} + +// ValidateConfig validates the vLLM configuration +func (p *VLLMPlugin) ValidateConfig(config *VLLMConfig) error { + if config.Model == "" { + return fmt.Errorf("vllm: model name is required") + } + + if config.GpuMemoryUtilization < 0 || config.GpuMemoryUtilization > 1 { + return fmt.Errorf("vllm: gpu_memory_utilization must be between 0.0 and 1.0") + } + + validQuantizations := []string{"", "awq", "gptq", "squeezellm", "fp8"} + found := false + for _, q := range validQuantizations { + if config.Quantization == q { + found = true + break + } + } + if !found { + return fmt.Errorf("vllm: unsupported quantization: %s", config.Quantization) + } + + return nil +} + +// BuildCommand builds the vLLM command line from configuration +func (p *VLLMPlugin) BuildCommand(config *VLLMConfig, port int) []string { + args := []string{ + "python", "-m", "vllm.entrypoints.openai.api_server", + "--model", config.Model, + "--port", fmt.Sprintf("%d", port), + "--host", "0.0.0.0", + } + + if config.TensorParallelSize > 1 { + args = append(args, "--tensor-parallel-size", fmt.Sprintf("%d", config.TensorParallelSize)) + } + + if config.MaxModelLen > 0 
{ + args = append(args, "--max-model-len", fmt.Sprintf("%d", config.MaxModelLen)) + } + + if config.GpuMemoryUtilization > 0 { + args = append(args, "--gpu-memory-utilization", fmt.Sprintf("%.2f", config.GpuMemoryUtilization)) + } + + if config.Quantization != "" { + args = append(args, "--quantization", config.Quantization) + } + + if config.TrustRemoteCode { + args = append(args, "--trust-remote-code") + } + + // Add any extra arguments + args = append(args, config.ExtraArgs...) + + return args +} + +// CheckHealth performs a health check against a running vLLM instance +func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMStatus, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", endpoint+"/health", nil) + if err != nil { + return nil, fmt.Errorf("vllm health check: %w", err) + } + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + return &VLLMStatus{Healthy: false, Ready: false}, nil + } + defer resp.Body.Close() + + // Parse metrics endpoint for detailed status + metricsReq, _ := http.NewRequestWithContext(ctx, "GET", endpoint+"/metrics", nil) + metricsResp, err := client.Do(metricsReq) + + status := &VLLMStatus{ + Healthy: resp.StatusCode == 200, + Ready: resp.StatusCode == 200, + } + + if err == nil { + defer metricsResp.Body.Close() // close even on non-200 so the connection can be reused + // Could parse Prometheus metrics (when metricsResp.StatusCode == 200) for detailed stats + } + + return status, nil +} + +// RunVLLMTask executes a vLLM service task +// This is called by the worker when it receives a vLLM job assignment +func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metadata map[string]string) ([]byte, error) { + logger = logger.With("plugin", "vllm", "task_id", taskID) + + // Parse vLLM configuration from task metadata + configJSON, ok := metadata["vllm_config"] + if !ok { + return nil, fmt.Errorf("vllm: missing
'vllm_config' in task metadata") + } + + var config VLLMConfig + if err := json.Unmarshal([]byte(configJSON), &config); err != nil { + return nil, fmt.Errorf("vllm: invalid config: %w", err) + } + + plugin := NewVLLMPlugin(logger, "") + if err := plugin.ValidateConfig(&config); err != nil { + return nil, err + } + + logger.Info("starting vllm service", + "model", config.Model, + "tensor_parallel", config.TensorParallelSize, + ) + + // Port allocation and command building would happen here + // The actual execution goes through the standard service runner + + // Return success output for the task + output := map[string]interface{}{ + "status": "starting", + "model": config.Model, + "plugin": "vllm", + } + + return json.Marshal(output) +} + +// ServiceTemplate is re-exported here for plugin use +type ServiceTemplate = struct { + JobType string `json:"job_type"` + SlotPool string `json:"slot_pool"` + GPUCount int `json:"gpu_count"` + Command []string `json:"command"` + Env map[string]string `json:"env"` + HealthCheck ServiceHealthCheck `json:"health_check"` + Mounts []ServiceMount `json:"mounts,omitempty"` + Resources ResourceRequirements `json:"resources,omitempty"` +} + +// ServiceHealthCheck is re-exported for plugin use +type ServiceHealthCheck = struct { + Liveness string `json:"liveness"` + Readiness string `json:"readiness"` + Interval int `json:"interval"` + Timeout int `json:"timeout"` +} + +// ServiceMount is re-exported for plugin use +type ServiceMount = struct { + Source string `json:"source"` + Destination string `json:"destination"` + ReadOnly bool `json:"readonly,omitempty"` +} + +// ResourceRequirements defines resource needs for a service +type ResourceRequirements struct { + MinGPU int `json:"min_gpu,omitempty"` + MaxGPU int `json:"max_gpu,omitempty"` + MinMemory int `json:"min_memory_gb,omitempty"` // GB + MaxMemory int `json:"max_memory_gb,omitempty"` // GB +} diff --git a/internal/worker/process/isolation.go b/internal/worker/process/isolation.go 
new file mode 100644 index 0000000..b9e0d67 --- /dev/null +++ b/internal/worker/process/isolation.go @@ -0,0 +1,63 @@ +// Package process provides process isolation and resource limiting for HIPAA compliance. +// Implements Worker Process Isolation controls. +package process + +import ( + "fmt" + "os" + "runtime" + "syscall" +) + +// IsolationConfig holds process isolation parameters +type IsolationConfig struct { + MaxProcesses int // Fork bomb protection (RLIMIT_NPROC on Linux) + MaxOpenFiles int // FD exhaustion protection (RLIMIT_NOFILE) + DisableSwap bool // Prevent swap exfiltration + OOMScoreAdj int // OOM killer priority adjustment (Linux only) +} + +// ApplyIsolation applies process isolation controls to the current process. +// This should be called after forking but before execing the target command. +func ApplyIsolation(cfg IsolationConfig) error { + // Apply resource limits (platform-specific) + if err := applyResourceLimits(cfg); err != nil { + return err + } + + // OOM score adjustment (only on Linux) + if cfg.OOMScoreAdj != 0 && runtime.GOOS == "linux" { + if err := setOOMScoreAdj(cfg.OOMScoreAdj); err != nil { + return fmt.Errorf("failed to set OOM score adjustment: %w", err) + } + } + + // Disable swap (Linux only) - requires CAP_SYS_RESOURCE or root + if cfg.DisableSwap && runtime.GOOS == "linux" { + if err := disableSwap(); err != nil { + // Log but don't fail - swap disabling requires privileges + // This is best-effort security hardening + } + } + + return nil +} + +// setOOMScoreAdj adjusts the OOM killer score (Linux only) +// Lower values = less likely to be killed (negative is "never kill") +// Higher values = more likely to be killed +func setOOMScoreAdj(score int) error { + // Write to /proc/self/oom_score_adj + path := "/proc/self/oom_score_adj" + data := []byte(fmt.Sprintf("%d\n", score)) + return os.WriteFile(path, data, 0644) +} + +// IsolatedExec runs a command with process isolation applied. 
+// This is a helper for container execution that applies limits before exec. +func IsolatedExec(argv []string, cfg IsolationConfig) error { + if err := ApplyIsolation(cfg); err != nil { + return err + } + return syscall.Exec(argv[0], argv, os.Environ()) +} diff --git a/internal/worker/process/isolation_unix.go b/internal/worker/process/isolation_unix.go new file mode 100644 index 0000000..4f129ff --- /dev/null +++ b/internal/worker/process/isolation_unix.go @@ -0,0 +1,122 @@ +//go:build !windows +// +build !windows + +package process + +import ( + "fmt" + "syscall" +) + +// applyResourceLimits sets Unix/Linux resource limits +func applyResourceLimits(cfg IsolationConfig) error { + // Apply file descriptor limits (RLIMIT_NOFILE for FD exhaustion protection) + if cfg.MaxOpenFiles > 0 { + if err := setResourceLimit(syscall.RLIMIT_NOFILE, uint64(cfg.MaxOpenFiles)); err != nil { + return fmt.Errorf("failed to set max open files limit: %w", err) + } + } + + // Apply process limits if available (Linux only) + if cfg.MaxProcesses > 0 { + if err := setProcessLimit(cfg.MaxProcesses); err != nil { + // Log but don't fail - this is defense in depth + return fmt.Errorf("failed to set max processes limit: %w", err) + } + } + + return nil +} + +// setResourceLimit sets a soft and hard rlimit for the current process +func setResourceLimit(resource int, limit uint64) error { + rl := &syscall.Rlimit{ + Cur: limit, + Max: limit, + } + return syscall.Setrlimit(resource, rl) +} + +// setProcessLimit sets RLIMIT_NPROC on Linux, no-op on other Unix +func setProcessLimit(maxProcs int) error { + // Try to set RLIMIT_NPROC - only available on Linux + // On Darwin/macOS, this returns ENOTSUP + const RLIMIT_NPROC = 7 // Linux value + rl := &syscall.Rlimit{ + Cur: uint64(maxProcs), + Max: uint64(maxProcs), + } + err := syscall.Setrlimit(RLIMIT_NPROC, rl) + if err != nil { + // ENOTSUP means not supported (macOS) + if err == syscall.ENOTSUP || err == syscall.EINVAL { + return nil // Silently 
ignore on platforms that don't support it + } + return err + } + return nil +} + +// disableSwap attempts to lock memory to prevent swapping (mlockall) +// This is best-effort and requires CAP_IPC_LOCK capability +func disableSwap() error { + // MCL_CURRENT: lock all current pages + // MCL_FUTURE: lock all future pages + const MCL_CURRENT = 0x1 + const MCL_FUTURE = 0x2 + + // Note: mlockall requires CAP_IPC_LOCK capability + // If this fails, we log but continue (defense in depth) + return syscall.Mlockall(MCL_CURRENT | MCL_FUTURE) +} + +// GetCurrentLimits returns the current rlimit values for diagnostics +func GetCurrentLimits() (map[string]uint64, error) { + limits := make(map[string]uint64) + + // Get NOFILE limit (available on all platforms) + var nofile syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &nofile); err != nil { + return nil, fmt.Errorf("failed to get NOFILE limit: %w", err) + } + limits["NOFILE_soft"] = nofile.Cur + limits["NOFILE_hard"] = nofile.Max + + // Get platform-specific limits + getPlatformLimits(limits) + + return limits, nil +} + +// getPlatformLimits adds platform-specific limits to the map +func getPlatformLimits(limits map[string]uint64) { + // Get virtual memory limit (AS) + var as syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_AS, &as); err == nil { + limits["AS_soft"] = as.Cur + limits["AS_hard"] = as.Max + } + + // Get data segment limit + var data syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_DATA, &data); err == nil { + limits["DATA_soft"] = data.Cur + limits["DATA_hard"] = data.Max + } + + // Try to get RLIMIT_NPROC (Linux only) + const RLIMIT_NPROC = 7 + var nproc syscall.Rlimit + if err := syscall.Getrlimit(RLIMIT_NPROC, &nproc); err == nil { + limits["NPROC_soft"] = nproc.Cur + limits["NPROC_hard"] = nproc.Max + } + + // Try to get RLIMIT_RSS (Linux only) + const RLIMIT_RSS = 5 + var rss syscall.Rlimit + if err := syscall.Getrlimit(RLIMIT_RSS, &rss); err == nil { + 
limits["RSS_soft"] = rss.Cur + limits["RSS_hard"] = rss.Max + } +} diff --git a/internal/worker/process/isolation_windows.go b/internal/worker/process/isolation_windows.go new file mode 100644 index 0000000..01d6b80 --- /dev/null +++ b/internal/worker/process/isolation_windows.go @@ -0,0 +1,41 @@ +//go:build windows +// +build windows + +package process + +import "fmt" + +// applyResourceLimits is a no-op on Windows as rlimits are Unix-specific. +// Windows uses Job Objects for process limits, which is more complex. +func applyResourceLimits(cfg IsolationConfig) error { + // Windows doesn't support setrlimit - would need Job Objects + // For now, log that limits are not enforced on Windows + if cfg.MaxOpenFiles > 0 || cfg.MaxProcesses > 0 { + // Process limits not implemented on Windows yet + // TODO: Use Windows Job Objects for process limits + return fmt.Errorf("process isolation limits not implemented on Windows (max_open_files=%d, max_processes=%d)", + cfg.MaxOpenFiles, cfg.MaxProcesses) + } + return nil +} + +// disableSwap is a no-op on Windows +func disableSwap() error { + // Windows doesn't support mlockall + // Would need VirtualLock API + return nil +} + +// GetCurrentLimits returns empty limits on Windows +func GetCurrentLimits() (map[string]uint64, error) { + limits := make(map[string]uint64) + getPlatformLimits(limits) + return limits, nil +} + +// getPlatformLimits adds Windows-specific limits (none currently) +func getPlatformLimits(limits map[string]uint64) { + // Windows doesn't have rlimit equivalent for these + // Could use GetProcessWorkingSetSize for memory limits + limits["WINDOWS"] = 1 // Marker to indicate Windows platform +} diff --git a/internal/worker/process/network_policy.go b/internal/worker/process/network_policy.go new file mode 100644 index 0000000..3e59273 --- /dev/null +++ b/internal/worker/process/network_policy.go @@ -0,0 +1,278 @@ +// Package process provides process isolation and security enforcement for worker tasks. 
+// This file implements Network Micro-Segmentation enforcement hooks. +//go:build linux +// +build linux + +package process + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// NetworkPolicy defines network segmentation rules for a task +type NetworkPolicy struct { + // Mode is the network isolation mode: "none", "bridge", "container", "host" + Mode string + + // AllowedEndpoints is a list of allowed network endpoints (host:port format) + // Only used when Mode is "bridge" or "container" + AllowedEndpoints []string + + // BlockedSubnets is a list of CIDR ranges to block + BlockedSubnets []string + + // DNSResolution controls DNS resolution (true = allow, false = block) + DNSResolution bool + + // OutboundTraffic controls outbound connections (true = allow, false = block) + OutboundTraffic bool + + // InboundTraffic controls inbound connections (true = allow, false = block) + InboundTraffic bool +} + +// DefaultNetworkPolicy returns a hardened default network policy +// This implements Network Micro-Segmentation +func DefaultNetworkPolicy() NetworkPolicy { + return NetworkPolicy{ + Mode: "none", + AllowedEndpoints: []string{}, + BlockedSubnets: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}, + DNSResolution: false, + OutboundTraffic: false, + InboundTraffic: false, + } +} + +// HIPAACompliantPolicy returns a network policy suitable for HIPAA compliance +// This blocks all external network access except specific allowlisted endpoints +func HIPAACompliantPolicy(allowlist []string) NetworkPolicy { + return NetworkPolicy{ + Mode: "bridge", + AllowedEndpoints: allowlist, + BlockedSubnets: []string{"0.0.0.0/0"}, // Block everything by default + DNSResolution: len(allowlist) > 0, // Only allow DNS if endpoints specified + OutboundTraffic: len(allowlist) > 0, // Only allow outbound if endpoints specified + InboundTraffic: false, // Never allow inbound + } +} + +// Validate checks the network policy for security violations +func (np 
*NetworkPolicy) Validate() error {
	// Validate mode against the known isolation modes.
	validModes := map[string]bool{
		"none":      true,
		"bridge":    true,
		"container": true,
		"host":      true,
	}
	if !validModes[np.Mode] {
		return fmt.Errorf("invalid network mode: %q", np.Mode)
	}

	// Host networking bypasses all segmentation; never allow it.
	if np.Mode == "host" {
		return fmt.Errorf("host network mode is not allowed for security reasons")
	}

	// Validate allowed endpoints format (host:port).
	for _, endpoint := range np.AllowedEndpoints {
		if !isValidEndpoint(endpoint) {
			return fmt.Errorf("invalid endpoint format: %q (expected host:port)", endpoint)
		}
	}

	// Validate CIDR blocks (addr/prefix).
	for _, cidr := range np.BlockedSubnets {
		if !isValidCIDR(cidr) {
			return fmt.Errorf("invalid CIDR format: %q", cidr)
		}
	}

	return nil
}

// isValidEndpoint checks if an endpoint string is valid (host:port format).
// Fixes over the original: rejects empty host (":80"), empty port ("host:",
// which previously passed because parsePort("") returned 0 with no error),
// and out-of-range ports.
// NOTE(review): colon-splitting rejects bracketed IPv6 literals; switch to
// net.SplitHostPort if IPv6 endpoints are ever required — confirm with callers.
func isValidEndpoint(endpoint string) bool {
	host, portStr, ok := strings.Cut(endpoint, ":")
	if !ok || host == "" || strings.Contains(portStr, ":") {
		return false
	}
	port, err := parsePort(portStr)
	if err != nil {
		return false
	}
	// TCP/UDP ports are 1-65535.
	return port >= 1 && port <= 65535
}

// isValidCIDR performs basic CIDR validation (addr/prefix).
// Fixes over the original: rejects an empty address or prefix and prefixes
// greater than 32 (IPv4). The address octets are still not deep-validated.
func isValidCIDR(cidr string) bool {
	addr, prefixStr, ok := strings.Cut(cidr, "/")
	if !ok || addr == "" {
		return false
	}
	prefix, err := parsePort(prefixStr)
	if err != nil {
		return false
	}
	return prefix <= 32
}

// parsePort parses a non-negative decimal string (helper for validation).
// Unlike the original loop, the empty string is an error, and absurdly long
// digit runs are rejected before the accumulator can overflow.
func parsePort(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("empty number")
	}
	n := 0
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid port")
		}
		n = n*10 + int(c-'0')
		if n > 1<<20 {
			return 0, fmt.Errorf("number out of range")
		}
	}
	return n, nil
}

// ApplyNetworkPolicy applies network policy enforcement to a podman command.
// It validates the policy, appends the network mode, and (for restricted
// bridge setups) injects marker env vars and a DNS-disabling resolv.conf
// bind mount. Returns the augmented argument list.
func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) {
	if err := policy.Validate(); err != nil {
		return nil, fmt.Errorf("invalid network policy: %w", err)
	}

	// Copy baseArgs into a fresh slice so appends cannot clobber a backing
	// array shared with the caller (the original appended in place).
	args := make([]string, 0, len(baseArgs)+8)
	args = append(args, baseArgs...)
	args = append(args, "--network", policy.Mode)

	// For bridge mode with an allowlist, external firewall rules (see
	// SetupExternalFirewall) restrict connectivity; these env vars inform
	// the container that it is running restricted.
	if policy.Mode == "bridge" && len(policy.AllowedEndpoints) > 0 {
		args = append(args, "-e", "FETCHML_NETWORK_RESTRICTED=1")
		if !policy.DNSResolution {
			args = append(args, "-e", "FETCHML_DNS_DISABLED=1")
		}
	}

	// Disable DNS by bind-mounting an empty resolv.conf. Best-effort, as
	// before: if the file cannot be created, the mount is simply skipped.
	if !policy.DNSResolution {
		if emptyResolv, err := createEmptyResolvConf(); err == nil {
			args = append(args, "-v", fmt.Sprintf("%s:/etc/resolv.conf:ro", emptyResolv))
		}
	}

	return args, nil
}

// createEmptyResolvConf creates an empty resolv.conf for bind-mounting.
// Security fix: the original reused a fixed, predictable path directly under
// os.TempDir() and skipped creation if the file already existed — any local
// user could pre-create that path (or plant a symlink) and have their content
// mounted into the container (TOCTOU). A fresh private temp directory makes
// the path unpredictable and owned by this process.
func createEmptyResolvConf() (string, error) {
	dir, err := os.MkdirTemp("", "fetchml-resolv-")
	if err != nil {
		return "", err
	}
	path := filepath.Join(dir, "resolv.conf")
	if err := os.WriteFile(path, []byte{}, 0444); err != nil {
		return "", err
	}
	return path, nil
}

// SetupExternalFirewall sets up external firewall rules for a container.
// This is called after the container starts to enforce egress filtering.
// NOTE: requires root or CAP_NET_ADMIN (uses nsenter + iptables in the
// container's network namespace).
func SetupExternalFirewall(containerID string, policy NetworkPolicy) error {
	if len(policy.BlockedSubnets) == 0 && len(policy.AllowedEndpoints) == 0 {
		return nil // nothing to enforce
	}

	// Resolve the container's PID so nsenter can target its network namespace.
	pid, err := getContainerPID(containerID)
	if err != nil {
		return fmt.Errorf("failed to get container PID: %w", err)
	}

	// Drop all outbound traffic by default. Allow-rules are inserted ahead
	// of this rule below ("-I OUTPUT 1"), so ordering is correct.
	if !policy.OutboundTraffic {
		cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-A", "OUTPUT", "-j", "DROP")
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to block outbound traffic: %w", err)
		}
	}

	// Punch holes for each allowlisted endpoint.
	for _, endpoint := range policy.AllowedEndpoints {
		host, port := parseEndpoint(endpoint)
		if host == "" {
			// Malformed entries were rejected by Validate; skip defensively.
			continue
		}
		cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-I", "OUTPUT", "1",
			"-p", "tcp", "-d", host, "--dport", port, "-j", "ACCEPT")
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to allow endpoint %s: %w", endpoint, err)
		}
	}

	return nil
}

// getContainerPID retrieves the PID of a running container via podman inspect.
func getContainerPID(containerID string) (string, error) {
	cmd := exec.Command("podman", "inspect", "-f", "{{.State.Pid}}", containerID)
	output, err := cmd.Output()
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(string(output)), nil
}

// parseEndpoint splits an endpoint string into host and port.
// Returns empty strings for anything that is not exactly host:port.
func parseEndpoint(endpoint string) (host, port string) {
	h, p, ok := strings.Cut(endpoint, ":")
	if !ok || strings.Contains(p, ":") {
		return "", ""
	}
	return h, p
}

// NetworkPolicyFromSandbox creates a NetworkPolicy from sandbox configuration.
func NetworkPolicyFromSandbox(
	networkMode string,
	allowedEndpoints []string,
	blockedSubnets []string,
) NetworkPolicy {
	// Use defaults if not specified
	if networkMode == "" {
		networkMode = "none"
	}
	if len(blockedSubnets)
== 0 { + blockedSubnets = DefaultNetworkPolicy().BlockedSubnets + } + + return NetworkPolicy{ + Mode: networkMode, + AllowedEndpoints: allowedEndpoints, + BlockedSubnets: blockedSubnets, + DNSResolution: networkMode != "none" && len(allowedEndpoints) > 0, + OutboundTraffic: networkMode != "none" && len(allowedEndpoints) > 0, + InboundTraffic: false, // Never allow inbound by default + } +} diff --git a/internal/worker/process/network_policy_windows.go b/internal/worker/process/network_policy_windows.go new file mode 100644 index 0000000..e330d97 --- /dev/null +++ b/internal/worker/process/network_policy_windows.go @@ -0,0 +1,91 @@ +// Package process provides process isolation and security enforcement for worker tasks. +// This file implements Network Micro-Segmentation enforcement hooks (Windows stub). +//go:build windows +// +build windows + +package process + +import ( + "fmt" +) + +// NetworkPolicy defines network segmentation rules for a task +// (Windows stub - policy enforcement handled differently on Windows) +type NetworkPolicy struct { + Mode string + AllowedEndpoints []string + BlockedSubnets []string + DNSResolution bool + OutboundTraffic bool + InboundTraffic bool +} + +// DefaultNetworkPolicy returns a hardened default network policy (Windows stub) +func DefaultNetworkPolicy() NetworkPolicy { + return NetworkPolicy{ + Mode: "none", + AllowedEndpoints: []string{}, + BlockedSubnets: []string{}, + DNSResolution: false, + OutboundTraffic: false, + InboundTraffic: false, + } +} + +// HIPAACompliantPolicy returns a network policy suitable for HIPAA compliance (Windows stub) +func HIPAACompliantPolicy(allowlist []string) NetworkPolicy { + return NetworkPolicy{ + Mode: "none", + AllowedEndpoints: allowlist, + BlockedSubnets: []string{}, + DNSResolution: false, + OutboundTraffic: false, + InboundTraffic: false, + } +} + +// Validate checks the network policy for security violations (Windows stub) +func (np *NetworkPolicy) Validate() error { + // On Windows, 
only "none" mode is supported without additional tooling + if np.Mode != "none" && np.Mode != "" { + return fmt.Errorf("network mode %q not supported on Windows (use 'none' or implement via Windows Firewall)", np.Mode) + } + return nil +} + +// ApplyNetworkPolicy applies network policy enforcement (Windows stub) +func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) { + if err := policy.Validate(); err != nil { + return nil, fmt.Errorf("invalid network policy: %w", err) + } + + // On Windows, just set the network mode + args := append(baseArgs, "--network", policy.Mode) + return args, nil +} + +// SetupExternalFirewall sets up external firewall rules (Windows stub - no-op) +func SetupExternalFirewall(containerID string, policy NetworkPolicy) error { + // Windows firewall integration would require PowerShell or netsh + // For now, this is a no-op - rely on container runtime's default restrictions + return nil +} + +// NetworkPolicyFromSandbox creates a NetworkPolicy from sandbox configuration (Windows stub) +func NetworkPolicyFromSandbox( + networkMode string, + allowedEndpoints []string, + blockedSubnets []string, +) NetworkPolicy { + if networkMode == "" { + networkMode = "none" + } + return NetworkPolicy{ + Mode: networkMode, + AllowedEndpoints: allowedEndpoints, + BlockedSubnets: blockedSubnets, + DNSResolution: false, + OutboundTraffic: false, + InboundTraffic: false, + } +} diff --git a/tests/unit/worker/plugins/jupyter_task_test.go b/tests/unit/worker/plugins/jupyter_task_test.go new file mode 100644 index 0000000..b3f3e6a --- /dev/null +++ b/tests/unit/worker/plugins/jupyter_task_test.go @@ -0,0 +1,163 @@ +package plugins__test + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/worker/plugins" + tests 
"github.com/jfraeys/fetch_ml/tests/fixtures" +) + +type fakeJupyterManager struct { + startFn func(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) + stopFn func(ctx context.Context, serviceID string) error + removeFn func(ctx context.Context, serviceID string, purge bool) error + restoreFn func(ctx context.Context, name string) (string, error) + listFn func() []*jupyter.JupyterService + listPkgsFn func(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) +} + +func (f *fakeJupyterManager) StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) { + return f.startFn(ctx, req) +} + +func (f *fakeJupyterManager) StopService(ctx context.Context, serviceID string) error { + return f.stopFn(ctx, serviceID) +} + +func (f *fakeJupyterManager) RemoveService(ctx context.Context, serviceID string, purge bool) error { + return f.removeFn(ctx, serviceID, purge) +} + +func (f *fakeJupyterManager) RestoreWorkspace(ctx context.Context, name string) (string, error) { + return f.restoreFn(ctx, name) +} + +func (f *fakeJupyterManager) ListServices() []*jupyter.JupyterService { + return f.listFn() +} + +func (f *fakeJupyterManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) { + if f.listPkgsFn == nil { + return nil, nil + } + return f.listPkgsFn(ctx, serviceName) +} + +type jupyterOutput struct { + Type string `json:"type"` + Service *struct { + Name string `json:"name"` + URL string `json:"url"` + } `json:"service"` +} + +type jupyterPackagesOutput struct { + Type string `json:"type"` + Packages []struct { + Name string `json:"name"` + Version string `json:"version"` + Source string `json:"source"` + } `json:"packages"` +} + +func TestRunJupyterTaskStartSuccess(t *testing.T) { + w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ + startFn: func(_ context.Context, req *jupyter.StartRequest) 
(*jupyter.JupyterService, error) { + if req.Name != "my-workspace" { + return nil, errors.New("bad name") + } + return &jupyter.JupyterService{Name: req.Name, URL: "http://127.0.0.1:8888"}, nil + }, + stopFn: func(context.Context, string) error { return nil }, + removeFn: func(context.Context, string, bool) error { return nil }, + restoreFn: func(context.Context, string) (string, error) { return "", nil }, + listFn: func() []*jupyter.JupyterService { return nil }, + listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil }, + }) + + task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{ + "task_type": "jupyter", + "jupyter_action": "start", + "jupyter_name": "my-workspace", + "jupyter_workspace": "my-workspace", + }} + out, err := plugins.RunJupyterTask(context.Background(), w, task) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if len(out) == 0 { + t.Fatalf("expected output") + } + var decoded jupyterOutput + if err := json.Unmarshal(out, &decoded); err != nil { + t.Fatalf("expected valid JSON, got %v", err) + } + if decoded.Service == nil || decoded.Service.Name != "my-workspace" { + t.Fatalf("expected service name to be my-workspace, got %#v", decoded.Service) + } +} + +func TestRunJupyterTaskStopFailure(t *testing.T) { + w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ + startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil }, + stopFn: func(context.Context, string) error { return errors.New("stop failed") }, + removeFn: func(context.Context, string, bool) error { return nil }, + restoreFn: func(context.Context, string) (string, error) { return "", nil }, + listFn: func() []*jupyter.JupyterService { return nil }, + listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil }, + }) + + task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: 
map[string]string{ + "task_type": "jupyter", + "jupyter_action": "stop", + "jupyter_service_id": "svc-1", + }} + _, err := plugins.RunJupyterTask(context.Background(), w, task) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestRunJupyterTaskListPackagesSuccess(t *testing.T) { + w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ + startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil }, + stopFn: func(context.Context, string) error { return nil }, + removeFn: func(context.Context, string, bool) error { return nil }, + restoreFn: func(context.Context, string) (string, error) { return "", nil }, + listFn: func() []*jupyter.JupyterService { return nil }, + listPkgsFn: func(_ context.Context, serviceName string) ([]jupyter.InstalledPackage, error) { + if serviceName != "my-workspace" { + return nil, errors.New("bad service") + } + return []jupyter.InstalledPackage{ + {Name: "numpy", Version: "1.26.0", Source: "pip"}, + {Name: "pandas", Version: "2.1.0", Source: "conda"}, + }, nil + }, + }) + + task := &queue.Task{JobName: "jupyter-packages-my-workspace", Metadata: map[string]string{ + "task_type": "jupyter", + "jupyter_action": "list_packages", + "jupyter_name": "my-workspace", + }} + + out, err := plugins.RunJupyterTask(context.Background(), w, task) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + var decoded jupyterPackagesOutput + if err := json.Unmarshal(out, &decoded); err != nil { + t.Fatalf("expected valid JSON, got %v", err) + } + if len(decoded.Packages) != 2 { + t.Fatalf("expected 2 packages, got %d", len(decoded.Packages)) + } +} diff --git a/tests/unit/worker/plugins/vllm_test.go b/tests/unit/worker/plugins/vllm_test.go new file mode 100644 index 0000000..2f84b9d --- /dev/null +++ b/tests/unit/worker/plugins/vllm_test.go @@ -0,0 +1,257 @@ +package plugins__test + +import ( + "context" + "encoding/json" + "log/slog" + "os" + "testing" + + 
"github.com/jfraeys/fetch_ml/internal/worker/plugins" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewVLLMPlugin(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + + // Test with default image + p1 := plugins.NewVLLMPlugin(logger, "") + assert.NotNil(t, p1) + + // Test with custom image + p2 := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0") + assert.NotNil(t, p2) +} + +func TestVLLMPluginValidateConfig(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + plugin := plugins.NewVLLMPlugin(logger, "") + + tests := []struct { + name string + config plugins.VLLMConfig + wantErr bool + }{ + { + name: "valid config", + config: plugins.VLLMConfig{ + Model: "meta-llama/Llama-2-7b", + TensorParallelSize: 1, + GpuMemoryUtilization: 0.9, + }, + wantErr: false, + }, + { + name: "missing model", + config: plugins.VLLMConfig{ + Model: "", + }, + wantErr: true, + }, + { + name: "invalid gpu memory high", + config: plugins.VLLMConfig{ + Model: "test-model", + GpuMemoryUtilization: 1.5, + }, + wantErr: true, + }, + { + name: "invalid gpu memory low", + config: plugins.VLLMConfig{ + Model: "test-model", + GpuMemoryUtilization: -0.1, + }, + wantErr: true, + }, + { + name: "valid quantization awq", + config: plugins.VLLMConfig{ + Model: "test-model", + Quantization: "awq", + }, + wantErr: false, + }, + { + name: "valid quantization gptq", + config: plugins.VLLMConfig{ + Model: "test-model", + Quantization: "gptq", + }, + wantErr: false, + }, + { + name: "invalid quantization", + config: plugins.VLLMConfig{ + Model: "test-model", + Quantization: "invalid", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := plugin.ValidateConfig(&tt.config) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestVLLMPluginBuildCommand(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, 
nil)) + plugin := plugins.NewVLLMPlugin(logger, "") + + tests := []struct { + name string + config plugins.VLLMConfig + port int + expected []string + }{ + { + name: "basic command", + config: plugins.VLLMConfig{ + Model: "meta-llama/Llama-2-7b", + }, + port: 8000, + expected: []string{ + "python", "-m", "vllm.entrypoints.openai.api_server", + "--model", "meta-llama/Llama-2-7b", + "--port", "8000", + "--host", "0.0.0.0", + }, + }, + { + name: "with tensor parallelism", + config: plugins.VLLMConfig{ + Model: "meta-llama/Llama-2-70b", + TensorParallelSize: 4, + }, + port: 8000, + expected: []string{ + "--tensor-parallel-size", "4", + }, + }, + { + name: "with quantization", + config: plugins.VLLMConfig{ + Model: "test-model", + Quantization: "awq", + }, + port: 8000, + expected: []string{ + "--quantization", "awq", + }, + }, + { + name: "with trust remote code", + config: plugins.VLLMConfig{ + Model: "custom-model", + TrustRemoteCode: true, + }, + port: 8000, + expected: []string{ + "--trust-remote-code", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := plugin.BuildCommand(&tt.config, tt.port) + for _, exp := range tt.expected { + assert.Contains(t, cmd, exp) + } + assert.Contains(t, cmd, "--model") + assert.Contains(t, cmd, "--port") + assert.Contains(t, cmd, "8000") + }) + } +} + +func TestVLLMPluginGetServiceTemplate(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + plugin := plugins.NewVLLMPlugin(logger, "") + + template := plugin.GetServiceTemplate() + + assert.Equal(t, "service", template.JobType) + assert.Equal(t, "service", template.SlotPool) + assert.Equal(t, 1, template.GPUCount) + + // Verify health check endpoints + assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Liveness) + assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Readiness) + assert.Equal(t, 30, template.HealthCheck.Interval) + assert.Equal(t, 10, 
template.HealthCheck.Timeout) + + // Verify mounts + require.Len(t, template.Mounts, 2) + assert.Equal(t, "{{MODEL_CACHE}}", template.Mounts[0].Source) + assert.Equal(t, "/root/.cache/huggingface", template.Mounts[0].Destination) + + // Verify resource requirements + assert.Equal(t, 1, template.Resources.MinGPU) + assert.Equal(t, 8, template.Resources.MaxGPU) + assert.Equal(t, 16, template.Resources.MinMemory) +} + +func TestRunVLLMTask(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + ctx := context.Background() + + t.Run("missing config", func(t *testing.T) { + metadata := map[string]string{} + _, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing 'vllm_config'") + }) + + t.Run("invalid config json", func(t *testing.T) { + metadata := map[string]string{ + "vllm_config": "invalid json", + } + _, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid config") + }) + + t.Run("invalid config values", func(t *testing.T) { + config := plugins.VLLMConfig{ + Model: "", // Missing model + } + configJSON, _ := json.Marshal(config) + metadata := map[string]string{ + "vllm_config": string(configJSON), + } + _, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "model name is required") + }) + + t.Run("successful start", func(t *testing.T) { + config := plugins.VLLMConfig{ + Model: "meta-llama/Llama-2-7b", + TensorParallelSize: 1, + } + configJSON, _ := json.Marshal(config) + metadata := map[string]string{ + "vllm_config": string(configJSON), + } + output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(output, &result) + require.NoError(t, err) + + assert.Equal(t, "starting", result["status"]) + assert.Equal(t, "meta-llama/Llama-2-7b", 
result["model"]) + assert.Equal(t, "vllm", result["plugin"]) + }) +}