// Package plugins provides service plugin implementations for fetch_ml.
// This file contains the vLLM inference server plugin.
package plugins

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"slices"
	"strconv"
	"strings"
	"time"
)

// VLLMPlugin exposes the vLLM OpenAI-compatible inference server as a
// fetch_ml service. It provides the scheduler service template,
// configuration validation, command-line construction and HTTP health
// checks for vLLM instances.
type VLLMPlugin struct {
	logger       *slog.Logger
	baseImage    string
	envOverrides map[string]string
}

// VLLMConfig contains the configuration for a vLLM service instance.
type VLLMConfig struct {
	// Model name or path (required).
	Model string `json:"model"`
	// Tensor parallel size (number of GPUs). Values <= 1 omit the flag.
	TensorParallelSize int `json:"tensor_parallel_size,omitempty"`
	// Port to run on (0 = auto-allocate). BuildCommand takes the port as a
	// separate argument and does not read this field.
	Port int `json:"port,omitempty"`
	// Passed to vLLM as --max-model-len; 0 omits the flag.
	MaxModelLen int `json:"max_model_len,omitempty"`
	// GPU memory utilization (0.0 - 1.0); 0 omits the flag.
	GpuMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"`
	// Quantization method (awq, gptq, squeezellm, fp8, or empty for none).
	Quantization string `json:"quantization,omitempty"`
	// Trust remote code for custom models.
	TrustRemoteCode bool `json:"trust_remote_code,omitempty"`
	// Additional CLI arguments appended verbatim to the vLLM command.
	ExtraArgs []string `json:"extra_args,omitempty"`
}

// VLLMStatus represents the current status of a vLLM service.
//
// CheckHealth only populates Healthy and Ready; the remaining fields are
// left at their zero values. Uptime is a time.Duration and therefore
// marshals to JSON as an integer number of nanoseconds.
type VLLMStatus struct {
	Healthy     bool          `json:"healthy"`
	Ready       bool          `json:"ready"`
	ModelLoaded bool          `json:"model_loaded"`
	Uptime      time.Duration `json:"uptime"`
	Requests    int64         `json:"requests_served"`
}

// NewVLLMPlugin creates a new vLLM plugin instance.
//
// A nil logger falls back to slog.Default(); an empty baseImage defaults to
// "vllm/vllm-openai:latest".
func NewVLLMPlugin(logger *slog.Logger, baseImage string) *VLLMPlugin {
	if logger == nil {
		logger = slog.Default()
	}
	if baseImage == "" {
		baseImage = "vllm/vllm-openai:latest"
	}
	return &VLLMPlugin{
		logger:       logger,
		baseImage:    baseImage,
		envOverrides: make(map[string]string),
	}
}

// GetServiceTemplate returns the service template for vLLM.
//
// The {{NAME}} placeholders and {{#NAME}}...{{/NAME}} sections are filled in
// by whatever renders ServiceTemplate elsewhere in fetch_ml; the section
// syntax looks mustache-like (presumably "emit only when NAME is set" —
// confirm against the renderer). Note that sections span several argv
// entries, so the renderer must operate across the Command slice.
func (p *VLLMPlugin) GetServiceTemplate() ServiceTemplate {
	return ServiceTemplate{
		JobType:  "service",
		SlotPool: "service",
		GPUCount: 1, // Minimum 1 GPU, can be more with tensor parallelism
		Command: []string{
			"python", "-m", "vllm.entrypoints.openai.api_server",
			"--model", "{{MODEL_NAME}}",
			"--port", "{{SERVICE_PORT}}",
			"--host", "0.0.0.0",
			"{{#GPU_COUNT}}--tensor-parallel-size", "{{GPU_COUNT}}", "{{/GPU_COUNT}}",
			"{{#QUANTIZATION}}--quantization", "{{QUANTIZATION}}", "{{/QUANTIZATION}}",
			"{{#TRUST_REMOTE_CODE}}--trust-remote-code", "{{/TRUST_REMOTE_CODE}}",
		},
		Env: map[string]string{
			"CUDA_VISIBLE_DEVICES": "{{GPU_DEVICES}}",
			"VLLM_LOGGING_LEVEL":   "INFO",
		},
		HealthCheck: ServiceHealthCheck{
			Liveness:  "http://localhost:{{SERVICE_PORT}}/health",
			Readiness: "http://localhost:{{SERVICE_PORT}}/health",
			Interval:  30,
			Timeout:   10,
		},
		Mounts: []ServiceMount{
			{Source: "{{MODEL_CACHE}}", Destination: "/root/.cache/huggingface"},
			{Source: "{{WORKSPACE}}", Destination: "/workspace"},
		},
		Resources: ResourceRequirements{
			MinGPU:    1,
			MaxGPU:    8,  // Support up to 8-GPU tensor parallelism
			MinMemory: 16, // 16GB minimum
		},
	}
}

// ValidateConfig validates the vLLM configuration.
//
// It returns an error when config is nil, Model is empty,
// GpuMemoryUtilization is outside [0, 1] (NaN included), Quantization is not
// one of "", "awq", "gptq", "squeezellm" or "fp8", TensorParallelSize or
// MaxModelLen is negative, or Port is outside [0, 65535].
func (p *VLLMPlugin) ValidateConfig(config *VLLMConfig) error {
	if config == nil {
		return fmt.Errorf("vllm: config is required")
	}
	if config.Model == "" {
		return fmt.Errorf("vllm: model name is required")
	}
	// Written as a negated range check so NaN (which fails every comparison)
	// is rejected too.
	if u := config.GpuMemoryUtilization; !(u >= 0 && u <= 1) {
		return fmt.Errorf("vllm: gpu_memory_utilization must be between 0.0 and 1.0")
	}
	validQuantizations := []string{"", "awq", "gptq", "squeezellm", "fp8"}
	if !slices.Contains(validQuantizations, config.Quantization) {
		return fmt.Errorf("vllm: unsupported quantization: %s", config.Quantization)
	}
	if config.TensorParallelSize < 0 {
		return fmt.Errorf("vllm: tensor_parallel_size must not be negative")
	}
	if config.MaxModelLen < 0 {
		return fmt.Errorf("vllm: max_model_len must not be negative")
	}
	if config.Port < 0 || config.Port > 65535 {
		return fmt.Errorf("vllm: port must be between 0 and 65535")
	}
	return nil
}

// BuildCommand builds the vLLM command line from configuration.
//
// port is used for --port; config.Port is not consulted. Optional flags are
// emitted only when set: --tensor-parallel-size when TensorParallelSize > 1,
// --max-model-len and --gpu-memory-utilization when positive (the latter
// formatted with two decimals), --quantization when non-empty, and
// --trust-remote-code when true. ExtraArgs are appended last, verbatim.
func (p *VLLMPlugin) BuildCommand(config *VLLMConfig, port int) []string {
	// 9 fixed tokens + at most 9 optional flag tokens + extras.
	args := make([]string, 0, 18+len(config.ExtraArgs))
	args = append(args,
		"python", "-m", "vllm.entrypoints.openai.api_server",
		"--model", config.Model,
		"--port", strconv.Itoa(port),
		"--host", "0.0.0.0",
	)

	if config.TensorParallelSize > 1 {
		args = append(args, "--tensor-parallel-size", strconv.Itoa(config.TensorParallelSize))
	}
	if config.MaxModelLen > 0 {
		args = append(args, "--max-model-len", strconv.Itoa(config.MaxModelLen))
	}
	if config.GpuMemoryUtilization > 0 {
		args = append(args, "--gpu-memory-utilization",
			strconv.FormatFloat(config.GpuMemoryUtilization, 'f', 2, 64))
	}
	if config.Quantization != "" {
		args = append(args, "--quantization", config.Quantization)
	}
	if config.TrustRemoteCode {
		args = append(args, "--trust-remote-code")
	}

	// Add any extra arguments.
	return append(args, config.ExtraArgs...)
}

// CheckHealth probes a running vLLM instance at endpoint (a base URL such as
// "http://host:port"; a single trailing slash is ignored).
//
// Healthy and Ready are true iff GET /health returns 200 OK. A transport
// failure (connection refused, timeout, ...) is reported as an unhealthy
// status with a nil error; a non-nil error is returned only when the health
// request cannot be constructed. The whole probe is bounded by a 10s context
// timeout and each HTTP call by a 5s client timeout.
func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMStatus, error) {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	base := strings.TrimSuffix(endpoint, "/")
	client := &http.Client{Timeout: 5 * time.Second}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/health", nil)
	if err != nil {
		return nil, fmt.Errorf("vllm health check: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		// An unreachable server is reported as an unhealthy status, not an error.
		return &VLLMStatus{Healthy: false, Ready: false}, nil
	}
	healthy := resp.StatusCode == http.StatusOK
	// Drain and close now (rather than deferring) so the keep-alive connection
	// can be reused by the metrics probe; errors here cannot change the result,
	// which is already determined by the status code.
	_, _ = io.Copy(io.Discard, resp.Body)
	_ = resp.Body.Close()

	status := &VLLMStatus{Healthy: healthy, Ready: healthy}

	// Probe /metrics as well. Its contents are not parsed yet (Prometheus
	// metrics could populate the detailed status fields here), but the body
	// must be closed whatever the status code so the connection is not leaked.
	if metricsReq, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, base+"/metrics", nil); reqErr == nil {
		if metricsResp, doErr := client.Do(metricsReq); doErr == nil {
			_, _ = io.Copy(io.Discard, metricsResp.Body)
			_ = metricsResp.Body.Close()
		}
	}

	return status, nil
}

// RunVLLMTask executes a vLLM service task.
// This is called by the worker when it receives a vLLM job assignment.
//
// It reads a JSON-encoded VLLMConfig from metadata["vllm_config"], validates
// it, and returns a JSON object {"model", "plugin", "status": "starting"}.
// Errors are returned when the key is missing, the JSON is invalid, or
// validation fails. A nil logger falls back to slog.Default(). ctx is
// currently unused.
func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metadata map[string]string) ([]byte, error) {
	if logger == nil {
		logger = slog.Default()
	}
	logger = logger.With("plugin", "vllm", "task_id", taskID)

	// Parse vLLM configuration from task metadata (a nil map reads as missing).
	configJSON, ok := metadata["vllm_config"]
	if !ok {
		return nil, fmt.Errorf("vllm: missing 'vllm_config' in task metadata")
	}

	var config VLLMConfig
	if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
		return nil, fmt.Errorf("vllm: invalid config: %w", err)
	}

	plugin := NewVLLMPlugin(logger, "")
	if err := plugin.ValidateConfig(&config); err != nil {
		return nil, err
	}

	logger.Info("starting vllm service",
		"model", config.Model,
		"tensor_parallel", config.TensorParallelSize,
	)

	// Port allocation and command building would happen here.
	// The actual execution goes through the standard service runner.

	// Return success output for the task.
	output := map[string]any{
		"status": "starting",
		"model":  config.Model,
		"plugin": "vllm",
	}
	return json.Marshal(output)
}

// ServiceTemplate describes how a service job is launched. It is declared as
// an alias of an unnamed struct type (presumably mirroring a definition
// elsewhere in fetch_ml — confirm), so identically shaped struct values are
// assignable to it.
type ServiceTemplate = struct {
	JobType     string               `json:"job_type"`
	SlotPool    string               `json:"slot_pool"`
	GPUCount    int                  `json:"gpu_count"`
	Command     []string             `json:"command"`
	Env         map[string]string    `json:"env"`
	HealthCheck ServiceHealthCheck   `json:"health_check"`
	Mounts      []ServiceMount       `json:"mounts,omitempty"`
	Resources   ResourceRequirements `json:"resources,omitempty"`
}

// ServiceHealthCheck holds the liveness/readiness URLs and probe timing for a
// service. Interval and Timeout are plain ints; the unit is presumably
// seconds (GetServiceTemplate uses 30 and 10) — confirm with the scheduler.
type ServiceHealthCheck = struct {
	Liveness  string `json:"liveness"`
	Readiness string `json:"readiness"`
	Interval  int    `json:"interval"`
	Timeout   int    `json:"timeout"`
}

// ServiceMount maps a host Source path to a container Destination path.
type ServiceMount = struct {
	Source      string `json:"source"`
	Destination string `json:"destination"`
	ReadOnly    bool   `json:"readonly,omitempty"`
}

// ResourceRequirements defines resource needs for a service.
type ResourceRequirements struct {
	MinGPU    int `json:"min_gpu,omitempty"`
	MaxGPU    int `json:"max_gpu,omitempty"`
	MinMemory int `json:"min_memory_gb,omitempty"` // GB
	MaxMemory int `json:"max_memory_gb,omitempty"` // GB
}