Extend worker capabilities with new execution plugins and security features: - Jupyter plugin for notebook-based ML experiments - vLLM plugin for LLM inference workloads - Cross-platform process isolation (Unix/Windows) - Network policy enforcement with platform-specific implementations - Service manager integration for lifecycle management - Scheduler backend integration for queue coordination Update lifecycle management: - Enhanced runloop with state transitions - Service manager integration for plugin coordination - Improved state persistence and recovery Add test coverage: - Unit tests for Jupyter and vLLM plugins - Updated worker execution tests
279 lines
8.2 KiB
Go
// Package plugins provides service plugin implementations for fetch_ml.
|
|
// This file contains the vLLM inference server plugin.
|
|
package plugins
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"time"
)
|
|
|
|
// VLLMPlugin implements the vLLM OpenAI-compatible inference server as a fetch_ml service.
// This plugin manages the lifecycle of vLLM, including startup, health monitoring,
// and graceful shutdown.
type VLLMPlugin struct {
	logger *slog.Logger // structured logger for plugin events
	baseImage string // container image used to run vLLM; defaulted in NewVLLMPlugin when empty
	envOverrides map[string]string // extra environment variables; initialized but not applied in this file — confirm usage elsewhere
}
|
|
|
|
// VLLMConfig contains the configuration for a vLLM service instance.
// It is decoded from the task metadata key "vllm_config" (see RunVLLMTask).
type VLLMConfig struct {
	// Model name or path (required)
	Model string `json:"model"`

	// Tensor parallel size (number of GPUs); values > 1 add
	// --tensor-parallel-size to the command line (see BuildCommand)
	TensorParallelSize int `json:"tensor_parallel_size,omitempty"`

	// Port to run on (0 = auto-allocate)
	Port int `json:"port,omitempty"`

	// Maximum model context length in tokens (maps to vLLM's --max-model-len)
	MaxModelLen int `json:"max_model_len,omitempty"`

	// GPU memory utilization (0.0 - 1.0); 0 means use vLLM's default
	GpuMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"`

	// Quantization method (awq, gptq, squeezellm, fp8, or empty for none);
	// validated against the allow-list in ValidateConfig
	Quantization string `json:"quantization,omitempty"`

	// Trust remote code for custom models (adds --trust-remote-code)
	TrustRemoteCode bool `json:"trust_remote_code,omitempty"`

	// Additional CLI arguments to pass to vLLM verbatim, appended last
	ExtraArgs []string `json:"extra_args,omitempty"`
}
|
|
|
|
// VLLMStatus represents the current status of a vLLM service.
// It is produced by CheckHealth; only Healthy and Ready are populated there —
// ModelLoaded, Uptime and Requests are presumably filled by a metrics parser
// elsewhere (or not yet implemented) — confirm against callers.
type VLLMStatus struct {
	Healthy bool `json:"healthy"` // /health endpoint returned HTTP 200
	Ready bool `json:"ready"` // currently mirrors Healthy in CheckHealth
	ModelLoaded bool `json:"model_loaded"` // not set by CheckHealth in this file
	Uptime time.Duration `json:"uptime"` // not set by CheckHealth in this file
	Requests int64 `json:"requests_served"` // not set by CheckHealth in this file
}
|
|
|
|
// NewVLLMPlugin creates a new vLLM plugin instance
|
|
func NewVLLMPlugin(logger *slog.Logger, baseImage string) *VLLMPlugin {
|
|
if baseImage == "" {
|
|
baseImage = "vllm/vllm-openai:latest"
|
|
}
|
|
return &VLLMPlugin{
|
|
logger: logger,
|
|
baseImage: baseImage,
|
|
envOverrides: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// GetServiceTemplate returns the service template for vLLM.
//
// The template contains mustache-style placeholders ({{NAME}}) and section
// tags ({{#NAME}}...{{/NAME}}) that are expanded by the service runner before
// execution — NOTE(review): the section open/close tags are split across
// separate argv elements, so the expander presumably also removes empty
// elements after substitution; confirm against the template processor.
func (p *VLLMPlugin) GetServiceTemplate() ServiceTemplate {
	return ServiceTemplate{
		JobType: "service",
		SlotPool: "service",
		GPUCount: 1, // Minimum 1 GPU, can be more with tensor parallelism

		// OpenAI-compatible API server entry point; mirrors BuildCommand,
		// but driven by template substitution instead of a VLLMConfig.
		Command: []string{
			"python", "-m", "vllm.entrypoints.openai.api_server",
			"--model", "{{MODEL_NAME}}",
			"--port", "{{SERVICE_PORT}}",
			"--host", "0.0.0.0",
			"{{#GPU_COUNT}}--tensor-parallel-size", "{{GPU_COUNT}}", "{{/GPU_COUNT}}",
			"{{#QUANTIZATION}}--quantization", "{{QUANTIZATION}}", "{{/QUANTIZATION}}",
			"{{#TRUST_REMOTE_CODE}}--trust-remote-code", "{{/TRUST_REMOTE_CODE}}",
		},

		Env: map[string]string{
			"CUDA_VISIBLE_DEVICES": "{{GPU_DEVICES}}", // pin the container to its allocated GPUs
			"VLLM_LOGGING_LEVEL": "INFO",
		},

		// Both probes hit vLLM's /health endpoint; interval/timeout are
		// presumably seconds — confirm against ServiceHealthCheck consumers.
		HealthCheck: ServiceHealthCheck{
			Liveness: "http://localhost:{{SERVICE_PORT}}/health",
			Readiness: "http://localhost:{{SERVICE_PORT}}/health",
			Interval: 30,
			Timeout: 10,
		},

		// Persist the HF model cache across restarts and expose the workspace.
		Mounts: []ServiceMount{
			{Source: "{{MODEL_CACHE}}", Destination: "/root/.cache/huggingface"},
			{Source: "{{WORKSPACE}}", Destination: "/workspace"},
		},

		Resources: ResourceRequirements{
			MinGPU: 1,
			MaxGPU: 8, // Support up to 8-GPU tensor parallelism
			MinMemory: 16, // 16GB minimum
		},
	}
}
|
|
|
|
// ValidateConfig validates the vLLM configuration
|
|
func (p *VLLMPlugin) ValidateConfig(config *VLLMConfig) error {
|
|
if config.Model == "" {
|
|
return fmt.Errorf("vllm: model name is required")
|
|
}
|
|
|
|
if config.GpuMemoryUtilization < 0 || config.GpuMemoryUtilization > 1 {
|
|
return fmt.Errorf("vllm: gpu_memory_utilization must be between 0.0 and 1.0")
|
|
}
|
|
|
|
validQuantizations := []string{"", "awq", "gptq", "squeezellm", "fp8"}
|
|
found := false
|
|
for _, q := range validQuantizations {
|
|
if config.Quantization == q {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("vllm: unsupported quantization: %s", config.Quantization)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// BuildCommand builds the vLLM command line from configuration
|
|
func (p *VLLMPlugin) BuildCommand(config *VLLMConfig, port int) []string {
|
|
args := []string{
|
|
"python", "-m", "vllm.entrypoints.openai.api_server",
|
|
"--model", config.Model,
|
|
"--port", fmt.Sprintf("%d", port),
|
|
"--host", "0.0.0.0",
|
|
}
|
|
|
|
if config.TensorParallelSize > 1 {
|
|
args = append(args, "--tensor-parallel-size", fmt.Sprintf("%d", config.TensorParallelSize))
|
|
}
|
|
|
|
if config.MaxModelLen > 0 {
|
|
args = append(args, "--max-model-len", fmt.Sprintf("%d", config.MaxModelLen))
|
|
}
|
|
|
|
if config.GpuMemoryUtilization > 0 {
|
|
args = append(args, "--gpu-memory-utilization", fmt.Sprintf("%.2f", config.GpuMemoryUtilization))
|
|
}
|
|
|
|
if config.Quantization != "" {
|
|
args = append(args, "--quantization", config.Quantization)
|
|
}
|
|
|
|
if config.TrustRemoteCode {
|
|
args = append(args, "--trust-remote-code")
|
|
}
|
|
|
|
// Add any extra arguments
|
|
args = append(args, config.ExtraArgs...)
|
|
|
|
return args
|
|
}
|
|
|
|
// CheckHealth performs a health check against a running vLLM instance
|
|
func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMStatus, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", endpoint+"/health", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("vllm health check: %w", err)
|
|
}
|
|
|
|
client := &http.Client{Timeout: 5 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return &VLLMStatus{Healthy: false, Ready: false}, nil
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Parse metrics endpoint for detailed status
|
|
metricsReq, _ := http.NewRequestWithContext(ctx, "GET", endpoint+"/metrics", nil)
|
|
metricsResp, err := client.Do(metricsReq)
|
|
|
|
status := &VLLMStatus{
|
|
Healthy: resp.StatusCode == 200,
|
|
Ready: resp.StatusCode == 200,
|
|
}
|
|
|
|
if err == nil && metricsResp.StatusCode == 200 {
|
|
defer metricsResp.Body.Close()
|
|
// Could parse Prometheus metrics here for detailed stats
|
|
}
|
|
|
|
return status, nil
|
|
}
|
|
|
|
// RunVLLMTask executes a vLLM service task
|
|
// This is called by the worker when it receives a vLLM job assignment
|
|
func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metadata map[string]string) ([]byte, error) {
|
|
logger = logger.With("plugin", "vllm", "task_id", taskID)
|
|
|
|
// Parse vLLM configuration from task metadata
|
|
configJSON, ok := metadata["vllm_config"]
|
|
if !ok {
|
|
return nil, fmt.Errorf("vllm: missing 'vllm_config' in task metadata")
|
|
}
|
|
|
|
var config VLLMConfig
|
|
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
|
return nil, fmt.Errorf("vllm: invalid config: %w", err)
|
|
}
|
|
|
|
plugin := NewVLLMPlugin(logger, "")
|
|
if err := plugin.ValidateConfig(&config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info("starting vllm service",
|
|
"model", config.Model,
|
|
"tensor_parallel", config.TensorParallelSize,
|
|
)
|
|
|
|
// Port allocation and command building would happen here
|
|
// The actual execution goes through the standard service runner
|
|
|
|
// Return success output for the task
|
|
output := map[string]interface{}{
|
|
"status": "starting",
|
|
"model": config.Model,
|
|
"plugin": "vllm",
|
|
}
|
|
|
|
return json.Marshal(output)
|
|
}
|
|
|
|
// ServiceTemplate is re-exported here for plugin use.
// NOTE(review): this is a type alias to an anonymous struct rather than a
// defined type — presumably to stay assignment-compatible with an identical
// struct declared in another package without importing it; confirm before
// converting to a named struct type.
type ServiceTemplate = struct {
	JobType string `json:"job_type"` // scheduler job type (e.g. "service")
	SlotPool string `json:"slot_pool"` // slot pool the service is scheduled into
	GPUCount int `json:"gpu_count"` // number of GPUs to allocate
	Command []string `json:"command"` // argv, may contain {{...}} template placeholders
	Env map[string]string `json:"env"` // environment variables for the process
	HealthCheck ServiceHealthCheck `json:"health_check"` // liveness/readiness probe config
	Mounts []ServiceMount `json:"mounts,omitempty"` // bind mounts into the container
	Resources ResourceRequirements `json:"resources,omitempty"` // GPU/memory requirements
}
|
|
|
|
// ServiceHealthCheck is re-exported for plugin use.
// NOTE(review): alias to an anonymous struct, matching the ServiceTemplate
// pattern. Interval and Timeout are plain ints; the values used in
// GetServiceTemplate (30/10) suggest seconds — confirm against the runner.
type ServiceHealthCheck = struct {
	Liveness string `json:"liveness"` // URL probed to decide the process is alive
	Readiness string `json:"readiness"` // URL probed to decide the service can take traffic
	Interval int `json:"interval"` // probe interval (presumably seconds — confirm)
	Timeout int `json:"timeout"` // per-probe timeout (presumably seconds — confirm)
}
|
|
|
|
// ServiceMount is re-exported for plugin use.
// Describes a single bind mount from the host into the service container;
// Source may contain {{...}} template placeholders (see GetServiceTemplate).
type ServiceMount = struct {
	Source string `json:"source"` // host path (possibly templated)
	Destination string `json:"destination"` // path inside the container
	ReadOnly bool `json:"readonly,omitempty"` // mount read-only when true
}
|
|
|
|
// ResourceRequirements defines resource needs for a service.
// Zero values mean "no constraint" (fields are omitempty in JSON).
type ResourceRequirements struct {
	MinGPU int `json:"min_gpu,omitempty"` // minimum GPUs required
	MaxGPU int `json:"max_gpu,omitempty"` // maximum GPUs usable (e.g. tensor-parallel ceiling)
	MinMemory int `json:"min_memory_gb,omitempty"` // GB
	MaxMemory int `json:"max_memory_gb,omitempty"` // GB
}
|