fetch_ml/tests/unit/worker/plugins/vllm_test.go
Jeremie Fraeys 17170667e2
feat(worker): improve lifecycle management and vLLM plugin
Lifecycle improvements:
- runloop.go: refined state machine with better error recovery
- service_manager.go: service dependency management and health checks
- states.go: add states for capability advertisement and draining

Container execution:
- container.go: improved OCI runtime integration with supply chain checks
- Add image verification and signature validation
- Better resource limits enforcement for GPU/memory

vLLM plugin updates:
- vllm.go: support for vLLM 0.3+ with new engine arguments
- Add quantization-aware scheduling (AWQ, GPTQ, FP8)
- Improve model download and caching logic

Configuration:
- config.go: add capability advertisement configuration
- snapshot_store.go: improve snapshot management for checkpointing
2026-03-12 12:05:02 -04:00

257 lines
6.2 KiB
Go

package plugins__test
import (
	"context"
	"encoding/json"
	"log/slog"
	"os"
	"strconv"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker/plugins"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
func TestNewVLLMPlugin(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Test with default image
p1 := plugins.NewVLLMPlugin(logger, "")
assert.NotNil(t, p1)
// Test with custom image
p2 := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0")
assert.NotNil(t, p2)
}
// TestVLLMPluginValidateConfig exercises ValidateConfig across the
// interesting boundaries: a fully valid config, a missing model name,
// GPU memory utilization outside (0, 1], and supported vs. unsupported
// quantization schemes.
func TestVLLMPluginValidateConfig(t *testing.T) {
	log := slog.New(slog.NewTextHandler(os.Stderr, nil))
	p := plugins.NewVLLMPlugin(log, "")

	cases := []struct {
		name    string
		cfg     plugins.VLLMConfig
		wantErr bool
	}{
		{name: "valid config", cfg: plugins.VLLMConfig{Model: "meta-llama/Llama-2-7b", TensorParallelSize: 1, GpuMemoryUtilization: 0.9}},
		{name: "missing model", cfg: plugins.VLLMConfig{Model: ""}, wantErr: true},
		{name: "invalid gpu memory high", cfg: plugins.VLLMConfig{Model: "test-model", GpuMemoryUtilization: 1.5}, wantErr: true},
		{name: "invalid gpu memory low", cfg: plugins.VLLMConfig{Model: "test-model", GpuMemoryUtilization: -0.1}, wantErr: true},
		{name: "valid quantization awq", cfg: plugins.VLLMConfig{Model: "test-model", Quantization: "awq"}},
		{name: "valid quantization gptq", cfg: plugins.VLLMConfig{Model: "test-model", Quantization: "gptq"}},
		{name: "invalid quantization", cfg: plugins.VLLMConfig{Model: "test-model", Quantization: "invalid"}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := p.ValidateConfig(&tc.cfg)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestVLLMPluginBuildCommand checks that BuildCommand emits the expected
// vLLM CLI arguments for a range of configurations. Each case asserts only
// its distinguishing flags (not the full argv), then a set of invariants
// every generated command must satisfy: a model flag, a port flag, and the
// requested port value itself.
func TestVLLMPluginBuildCommand(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	plugin := plugins.NewVLLMPlugin(logger, "")
	tests := []struct {
		name     string
		config   plugins.VLLMConfig
		port     int
		expected []string
	}{
		{
			name: "basic command",
			config: plugins.VLLMConfig{
				Model: "meta-llama/Llama-2-7b",
			},
			port: 8000,
			expected: []string{
				"python", "-m", "vllm.entrypoints.openai.api_server",
				"--model", "meta-llama/Llama-2-7b",
				"--port", "8000",
				"--host", "0.0.0.0",
			},
		},
		{
			name: "with tensor parallelism",
			config: plugins.VLLMConfig{
				Model:              "meta-llama/Llama-2-70b",
				TensorParallelSize: 4,
			},
			port: 8000,
			expected: []string{
				"--tensor-parallel-size", "4",
			},
		},
		{
			name: "with quantization",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "awq",
			},
			port: 8000,
			expected: []string{
				"--quantization", "awq",
			},
		},
		{
			name: "with trust remote code",
			config: plugins.VLLMConfig{
				Model:           "custom-model",
				TrustRemoteCode: true,
			},
			port: 8000,
			expected: []string{
				"--trust-remote-code",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cmd := plugin.BuildCommand(&tt.config, tt.port)
			for _, exp := range tt.expected {
				assert.Contains(t, cmd, exp)
			}
			// Invariants common to every command.
			assert.Contains(t, cmd, "--model")
			assert.Contains(t, cmd, "--port")
			// Derive the expected port from the table entry instead of the
			// previous hard-coded "8000", so a future case using a different
			// port is actually verified rather than silently passing.
			assert.Contains(t, cmd, strconv.Itoa(tt.port))
		})
	}
}
// TestVLLMPluginGetServiceTemplate pins the shape of the service template:
// job placement, health-check endpoints and timing, mount layout, and
// GPU/memory resource bounds.
func TestVLLMPluginGetServiceTemplate(t *testing.T) {
	log := slog.New(slog.NewTextHandler(os.Stderr, nil))
	tpl := plugins.NewVLLMPlugin(log, "").GetServiceTemplate()

	// Placement: runs as a long-lived service with a single GPU slot.
	assert.Equal(t, "service", tpl.JobType)
	assert.Equal(t, "service", tpl.SlotPool)
	assert.Equal(t, 1, tpl.GPUCount)

	// Health checks: liveness and readiness share the /health endpoint.
	const healthURL = "http://localhost:{{SERVICE_PORT}}/health"
	assert.Equal(t, healthURL, tpl.HealthCheck.Liveness)
	assert.Equal(t, healthURL, tpl.HealthCheck.Readiness)
	assert.Equal(t, 30, tpl.HealthCheck.Interval)
	assert.Equal(t, 10, tpl.HealthCheck.Timeout)

	// Mounts: first entry maps the model cache into the HF cache dir.
	require.Len(t, tpl.Mounts, 2)
	assert.Equal(t, "{{MODEL_CACHE}}", tpl.Mounts[0].Source)
	assert.Equal(t, "/root/.cache/huggingface", tpl.Mounts[0].Destination)

	// Resource envelope.
	assert.Equal(t, 1, tpl.Resources.MinGPU)
	assert.Equal(t, 8, tpl.Resources.MaxGPU)
	assert.Equal(t, 16, tpl.Resources.MinMemory)
}
// TestRunVLLMTask covers the RunVLLMTask entry point: absent metadata key,
// malformed JSON, semantically invalid config, and the happy path where the
// returned payload reports a starting vLLM service.
func TestRunVLLMTask(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	ctx := context.Background()
	t.Run("missing config", func(t *testing.T) {
		metadata := map[string]string{}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "missing 'vllm_config'")
	})
	t.Run("invalid config json", func(t *testing.T) {
		metadata := map[string]string{
			"vllm_config": "invalid json",
		}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid config")
	})
	t.Run("invalid config values", func(t *testing.T) {
		config := plugins.VLLMConfig{
			Model: "", // Missing model
		}
		// Fail loudly if fixture construction itself breaks, rather than
		// feeding an empty payload in via a discarded Marshal error.
		configJSON, err := json.Marshal(config)
		require.NoError(t, err)
		metadata := map[string]string{
			"vllm_config": string(configJSON),
		}
		_, err = plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "model name is required")
	})
	t.Run("successful start", func(t *testing.T) {
		config := plugins.VLLMConfig{
			Model:              "meta-llama/Llama-2-7b",
			TensorParallelSize: 1,
		}
		configJSON, err := json.Marshal(config)
		require.NoError(t, err)
		metadata := map[string]string{
			"vllm_config": string(configJSON),
		}
		output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		require.NoError(t, err)
		// The task output is a JSON document describing the launched service.
		var result map[string]any
		err = json.Unmarshal(output, &result)
		require.NoError(t, err)
		assert.Equal(t, "starting", result["status"])
		assert.Equal(t, "meta-llama/Llama-2-7b", result["model"])
		assert.Equal(t, "vllm", result["plugin"])
	})
}