fetch_ml/internal/worker/plugins/vllm.go
Jeremie Fraeys 95adcba437
feat(worker): add Jupyter/vLLM plugins and process isolation
Extend worker capabilities with new execution plugins and security features:
- Jupyter plugin for notebook-based ML experiments
- vLLM plugin for LLM inference workloads
- Cross-platform process isolation (Unix/Windows)
- Network policy enforcement with platform-specific implementations
- Service manager integration for lifecycle management
- Scheduler backend integration for queue coordination

Update lifecycle management:
- Enhanced runloop with state transitions
- Service manager integration for plugin coordination
- Improved state persistence and recovery

Add test coverage:
- Unit tests for Jupyter and vLLM plugins
- Updated worker execution tests
2026-02-26 12:03:59 -05:00

279 lines
8.2 KiB
Go

// Package plugins provides service plugin implementations for fetch_ml.
// This file contains the vLLM inference server plugin.
package plugins
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"slices"
	"strconv"
	"time"
)
// VLLMPlugin implements the vLLM OpenAI-compatible inference server as a fetch_ml service.
// This plugin manages the lifecycle of vLLM, including startup, health monitoring,
// and graceful shutdown.
//
// The zero value is not usable; construct instances with NewVLLMPlugin.
type VLLMPlugin struct {
// Structured logger; methods derive child loggers with plugin context.
logger *slog.Logger
// Container image used to run vLLM (defaults to vllm/vllm-openai:latest in NewVLLMPlugin).
baseImage string
// Extra environment variables layered over the template's Env.
// NOTE(review): populated by NewVLLMPlugin but not read in this file — presumably consumed elsewhere; verify.
envOverrides map[string]string
}
// VLLMConfig contains the configuration for a vLLM service instance.
// It is decoded from the "vllm_config" JSON value in task metadata
// (see RunVLLMTask) and validated by ValidateConfig.
type VLLMConfig struct {
// Model name or path (required)
Model string `json:"model"`
// Tensor parallel size (number of GPUs); only passed to vLLM when > 1 (see BuildCommand)
TensorParallelSize int `json:"tensor_parallel_size,omitempty"`
// Port to run on (0 = auto-allocate)
Port int `json:"port,omitempty"`
// Maximum model context length in tokens (vLLM --max-model-len); 0 uses vLLM's default
MaxModelLen int `json:"max_model_len,omitempty"`
// GPU memory utilization (0.0 - 1.0); 0 uses vLLM's default
GpuMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"`
// Quantization method (awq, gptq, squeezellm, fp8, or empty for none)
Quantization string `json:"quantization,omitempty"`
// Trust remote code for custom models (vLLM --trust-remote-code)
TrustRemoteCode bool `json:"trust_remote_code,omitempty"`
// Additional CLI arguments appended verbatim to the vLLM command line
ExtraArgs []string `json:"extra_args,omitempty"`
}
// VLLMStatus represents the current status of a vLLM service as observed
// by CheckHealth. Healthy and Ready are both derived from the /health
// endpoint returning HTTP 200 in this file; the remaining fields are
// placeholders for metrics-derived data.
type VLLMStatus struct {
// True when the /health endpoint answered with HTTP 200.
Healthy bool `json:"healthy"`
// True when the service is ready to accept requests (currently mirrors Healthy).
Ready bool `json:"ready"`
// NOTE(review): never set in this file — presumably filled from /metrics elsewhere; verify.
ModelLoaded bool `json:"model_loaded"`
// NOTE(review): never set in this file.
Uptime time.Duration `json:"uptime"`
// NOTE(review): never set in this file.
Requests int64 `json:"requests_served"`
}
// NewVLLMPlugin constructs a vLLM plugin instance. An empty baseImage
// selects the public vLLM OpenAI-compatible container image by default.
func NewVLLMPlugin(logger *slog.Logger, baseImage string) *VLLMPlugin {
	image := baseImage
	if image == "" {
		image = "vllm/vllm-openai:latest"
	}
	plugin := &VLLMPlugin{
		logger:       logger,
		baseImage:    image,
		envOverrides: map[string]string{},
	}
	return plugin
}
// GetServiceTemplate returns the declarative service template for vLLM.
// The {{...}} placeholders are substituted by the service runner at launch
// time; {{#X}}...{{/X}} sections are included only when X is set.
func (p *VLLMPlugin) GetServiceTemplate() ServiceTemplate {
	tpl := ServiceTemplate{
		JobType:  "service",
		SlotPool: "service",
		GPUCount: 1, // Minimum 1 GPU, can be more with tensor parallelism
	}
	tpl.Command = []string{
		"python", "-m", "vllm.entrypoints.openai.api_server",
		"--model", "{{MODEL_NAME}}",
		"--port", "{{SERVICE_PORT}}",
		"--host", "0.0.0.0",
		"{{#GPU_COUNT}}--tensor-parallel-size", "{{GPU_COUNT}}", "{{/GPU_COUNT}}",
		"{{#QUANTIZATION}}--quantization", "{{QUANTIZATION}}", "{{/QUANTIZATION}}",
		"{{#TRUST_REMOTE_CODE}}--trust-remote-code", "{{/TRUST_REMOTE_CODE}}",
	}
	tpl.Env = map[string]string{
		"CUDA_VISIBLE_DEVICES": "{{GPU_DEVICES}}",
		"VLLM_LOGGING_LEVEL":   "INFO",
	}
	// Liveness and readiness both probe vLLM's /health endpoint.
	tpl.HealthCheck = ServiceHealthCheck{
		Liveness:  "http://localhost:{{SERVICE_PORT}}/health",
		Readiness: "http://localhost:{{SERVICE_PORT}}/health",
		Interval:  30,
		Timeout:   10,
	}
	tpl.Mounts = []ServiceMount{
		{Source: "{{MODEL_CACHE}}", Destination: "/root/.cache/huggingface"},
		{Source: "{{WORKSPACE}}", Destination: "/workspace"},
	}
	tpl.Resources = ResourceRequirements{
		MinGPU:    1,
		MaxGPU:    8,  // Support up to 8-GPU tensor parallelism
		MinMemory: 16, // 16GB minimum
	}
	return tpl
}
// ValidateConfig validates the vLLM configuration before launch.
// It checks that a model is specified, that tensor_parallel_size is
// non-negative, that gpu_memory_utilization lies in [0.0, 1.0], and that
// the quantization method (empty means "none") is one of the supported
// values. Returns a descriptive error for the first violation found.
func (p *VLLMPlugin) ValidateConfig(config *VLLMConfig) error {
	if config.Model == "" {
		return fmt.Errorf("vllm: model name is required")
	}
	// A negative tensor-parallel size is never meaningful; 0 and 1 both
	// mean "no tensor parallelism" (BuildCommand only emits the flag for > 1).
	if config.TensorParallelSize < 0 {
		return fmt.Errorf("vllm: tensor_parallel_size must be non-negative, got %d", config.TensorParallelSize)
	}
	if config.GpuMemoryUtilization < 0 || config.GpuMemoryUtilization > 1 {
		return fmt.Errorf("vllm: gpu_memory_utilization must be between 0.0 and 1.0")
	}
	// Empty string means quantization is disabled.
	validQuantizations := []string{"", "awq", "gptq", "squeezellm", "fp8"}
	if !slices.Contains(validQuantizations, config.Quantization) {
		return fmt.Errorf("vllm: unsupported quantization: %s", config.Quantization)
	}
	return nil
}
// BuildCommand builds the vLLM command line from a validated configuration.
// port is the concrete port to bind (resolved by the caller; config.Port may
// be 0 for auto-allocation). Optional flags are appended only when the
// corresponding config field is set; ExtraArgs are appended verbatim last so
// they can override earlier flags.
func (p *VLLMPlugin) BuildCommand(config *VLLMConfig, port int) []string {
	args := []string{
		"python", "-m", "vllm.entrypoints.openai.api_server",
		"--model", config.Model,
		// strconv is the idiomatic (and cheaper) way to stringify plain ints.
		"--port", strconv.Itoa(port),
		"--host", "0.0.0.0",
	}
	// A size of 1 is vLLM's default, so the flag is only emitted for > 1.
	if config.TensorParallelSize > 1 {
		args = append(args, "--tensor-parallel-size", strconv.Itoa(config.TensorParallelSize))
	}
	if config.MaxModelLen > 0 {
		args = append(args, "--max-model-len", strconv.Itoa(config.MaxModelLen))
	}
	if config.GpuMemoryUtilization > 0 {
		// Two decimal places, matching the original %.2f formatting.
		args = append(args, "--gpu-memory-utilization", strconv.FormatFloat(config.GpuMemoryUtilization, 'f', 2, 64))
	}
	if config.Quantization != "" {
		args = append(args, "--quantization", config.Quantization)
	}
	if config.TrustRemoteCode {
		args = append(args, "--trust-remote-code")
	}
	// Add any extra arguments last so callers can override defaults.
	args = append(args, config.ExtraArgs...)
	return args
}
// CheckHealth performs a health check against a running vLLM instance at
// endpoint (e.g. "http://host:port"). A transport-level failure is reported
// as an unhealthy status with a nil error; a non-nil error is returned only
// when the request itself cannot be constructed. The overall check is bounded
// by a 10s context deadline plus a 5s per-request client timeout.
func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMStatus, error) {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint+"/health", nil)
	if err != nil {
		return nil, fmt.Errorf("vllm health check: %w", err)
	}
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		// An unreachable server is a legitimate "unhealthy" observation,
		// not an error in the check itself.
		return &VLLMStatus{Healthy: false, Ready: false}, nil
	}
	defer resp.Body.Close()

	status := &VLLMStatus{
		Healthy: resp.StatusCode == http.StatusOK,
		Ready:   resp.StatusCode == http.StatusOK,
	}

	// Best-effort probe of the metrics endpoint for detailed stats.
	// Bug fix: the original leaked metricsResp.Body whenever the request
	// succeeded with a non-200 status (the close was guarded by the status
	// check); it also ignored the request-construction error.
	if metricsReq, mErr := http.NewRequestWithContext(ctx, http.MethodGet, endpoint+"/metrics", nil); mErr == nil {
		if metricsResp, doErr := client.Do(metricsReq); doErr == nil {
			// Could parse Prometheus metrics here for detailed stats.
			metricsResp.Body.Close()
		}
	}
	return status, nil
}
// RunVLLMTask executes a vLLM service task.
// This is called by the worker when it receives a vLLM job assignment.
// It expects metadata["vllm_config"] to hold a JSON-encoded VLLMConfig,
// validates it, and returns a JSON "starting" acknowledgement; the actual
// long-running service execution happens in the standard service runner.
//
// Returns an error when the config key is missing, the JSON is malformed,
// or validation fails.
func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metadata map[string]string) ([]byte, error) {
	logger = logger.With("plugin", "vllm", "task_id", taskID)

	// Parse vLLM configuration from task metadata.
	configJSON, ok := metadata["vllm_config"]
	if !ok {
		return nil, fmt.Errorf("vllm: missing 'vllm_config' in task metadata")
	}
	var config VLLMConfig
	if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
		return nil, fmt.Errorf("vllm: invalid config: %w", err)
	}

	plugin := NewVLLMPlugin(logger, "")
	if err := plugin.ValidateConfig(&config); err != nil {
		return nil, err
	}

	logger.Info("starting vllm service",
		"model", config.Model,
		"tensor_parallel", config.TensorParallelSize,
	)

	// Port allocation and command building would happen here.
	// The actual execution goes through the standard service runner.

	// Return a success acknowledgement for the task.
	output := map[string]any{
		"status": "starting",
		"model":  config.Model,
		"plugin": "vllm",
	}
	return json.Marshal(output)
}
// ServiceTemplate is re-exported here for plugin use.
// It describes everything the scheduler needs to launch a service job:
// the slot pool, GPU count, command line (with {{...}} placeholders),
// environment, health checks, mounts, and resource bounds.
// NOTE(review): this is a type alias to an anonymous struct — presumably to
// mirror a type declared elsewhere without importing it; verify they stay in sync.
type ServiceTemplate = struct {
JobType string `json:"job_type"`
SlotPool string `json:"slot_pool"`
GPUCount int `json:"gpu_count"`
Command []string `json:"command"`
Env map[string]string `json:"env"`
HealthCheck ServiceHealthCheck `json:"health_check"`
Mounts []ServiceMount `json:"mounts,omitempty"`
Resources ResourceRequirements `json:"resources,omitempty"`
}
// ServiceHealthCheck is re-exported for plugin use.
// Liveness and Readiness are probe URLs (with {{...}} placeholders);
// Interval and Timeout are in seconds, matching the values set in
// GetServiceTemplate (30s interval, 10s timeout).
type ServiceHealthCheck = struct {
Liveness string `json:"liveness"`
Readiness string `json:"readiness"`
Interval int `json:"interval"`
Timeout int `json:"timeout"`
}
// ServiceMount is re-exported for plugin use.
// It maps a host path (Source, may contain {{...}} placeholders) to a
// path inside the service container (Destination).
type ServiceMount = struct {
Source string `json:"source"`
Destination string `json:"destination"`
ReadOnly bool `json:"readonly,omitempty"`
}
// ResourceRequirements defines resource needs for a service.
// GPU counts are whole devices; memory values are in gigabytes.
// Zero values mean "no constraint" (fields are omitted from JSON).
type ResourceRequirements struct {
MinGPU int `json:"min_gpu,omitempty"`
MaxGPU int `json:"max_gpu,omitempty"`
MinMemory int `json:"min_memory_gb,omitempty"` // GB
MaxMemory int `json:"max_memory_gb,omitempty"` // GB
}