fetch_ml/tests/unit/worker/plugins/vllm_test.go
Jeremie Fraeys 95adcba437
feat(worker): add Jupyter/vLLM plugins and process isolation
Extend worker capabilities with new execution plugins and security features:
- Jupyter plugin for notebook-based ML experiments
- vLLM plugin for LLM inference workloads
- Cross-platform process isolation (Unix/Windows)
- Network policy enforcement with platform-specific implementations
- Service manager integration for lifecycle management
- Scheduler backend integration for queue coordination

Update lifecycle management:
- Enhanced runloop with state transitions
- Service manager integration for plugin coordination
- Improved state persistence and recovery

Add test coverage:
- Unit tests for Jupyter and vLLM plugins
- Updated worker execution tests
2026-02-26 12:03:59 -05:00

257 lines
6.2 KiB
Go

package plugins_test
import (
	"context"
	"encoding/json"
	"log/slog"
	"os"
	"strconv"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/jfraeys/fetch_ml/internal/worker/plugins"
)
func TestNewVLLMPlugin(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Test with default image
p1 := plugins.NewVLLMPlugin(logger, "")
assert.NotNil(t, p1)
// Test with custom image
p2 := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0")
assert.NotNil(t, p2)
}
// TestVLLMPluginValidateConfig exercises ValidateConfig across valid and
// invalid configurations: a fully-specified valid config, a missing model,
// GPU memory utilization outside (0, 1], and supported vs. unsupported
// quantization schemes.
func TestVLLMPluginValidateConfig(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	plugin := plugins.NewVLLMPlugin(logger, "")

	cases := []struct {
		name      string
		cfg       plugins.VLLMConfig
		expectErr bool
	}{
		{
			name: "valid config",
			cfg: plugins.VLLMConfig{
				Model:                "meta-llama/Llama-2-7b",
				TensorParallelSize:   1,
				GpuMemoryUtilization: 0.9,
			},
		},
		{
			name:      "missing model",
			cfg:       plugins.VLLMConfig{Model: ""},
			expectErr: true,
		},
		{
			name: "invalid gpu memory high",
			cfg: plugins.VLLMConfig{
				Model:                "test-model",
				GpuMemoryUtilization: 1.5,
			},
			expectErr: true,
		},
		{
			name: "invalid gpu memory low",
			cfg: plugins.VLLMConfig{
				Model:                "test-model",
				GpuMemoryUtilization: -0.1,
			},
			expectErr: true,
		},
		{
			name: "valid quantization awq",
			cfg:  plugins.VLLMConfig{Model: "test-model", Quantization: "awq"},
		},
		{
			name: "valid quantization gptq",
			cfg:  plugins.VLLMConfig{Model: "test-model", Quantization: "gptq"},
		},
		{
			name:      "invalid quantization",
			cfg:       plugins.VLLMConfig{Model: "test-model", Quantization: "invalid"},
			expectErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := plugin.ValidateConfig(&tc.cfg)
			if tc.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestVLLMPluginBuildCommand verifies that BuildCommand emits the expected
// CLI arguments for the vLLM OpenAI-compatible API server, including the
// optional tensor-parallelism, quantization, and trust-remote-code flags.
func TestVLLMPluginBuildCommand(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	plugin := plugins.NewVLLMPlugin(logger, "")

	tests := []struct {
		name     string
		config   plugins.VLLMConfig
		port     int
		expected []string // arguments that must appear somewhere in the built command
	}{
		{
			name: "basic command",
			config: plugins.VLLMConfig{
				Model: "meta-llama/Llama-2-7b",
			},
			port: 8000,
			expected: []string{
				"python", "-m", "vllm.entrypoints.openai.api_server",
				"--model", "meta-llama/Llama-2-7b",
				"--port", "8000",
				"--host", "0.0.0.0",
			},
		},
		{
			name: "with tensor parallelism",
			config: plugins.VLLMConfig{
				Model:              "meta-llama/Llama-2-70b",
				TensorParallelSize: 4,
			},
			port: 8000,
			expected: []string{
				"--tensor-parallel-size", "4",
			},
		},
		{
			name: "with quantization",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "awq",
			},
			port: 8000,
			expected: []string{
				"--quantization", "awq",
			},
		},
		{
			name: "with trust remote code",
			config: plugins.VLLMConfig{
				Model:           "custom-model",
				TrustRemoteCode: true,
			},
			port: 8000,
			expected: []string{
				"--trust-remote-code",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cmd := plugin.BuildCommand(&tt.config, tt.port)
			for _, exp := range tt.expected {
				assert.Contains(t, cmd, exp)
			}
			// Every variant must carry the model and port flags, and the
			// port value must match the case's own port. (Previously this
			// asserted a hard-coded "8000", which only worked because all
			// cases happened to use that port and would mask bugs for any
			// other value.)
			assert.Contains(t, cmd, "--model")
			assert.Contains(t, cmd, "--port")
			assert.Contains(t, cmd, strconv.Itoa(tt.port))
		})
	}
}
// TestVLLMPluginGetServiceTemplate checks the static service template:
// job/pool classification, GPU count, health-check endpoints and timing,
// the HuggingFace cache mount, and the resource envelope.
func TestVLLMPluginGetServiceTemplate(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	tmpl := plugins.NewVLLMPlugin(logger, "").GetServiceTemplate()

	// Job classification and GPU allocation.
	assert.Equal(t, "service", tmpl.JobType)
	assert.Equal(t, "service", tmpl.SlotPool)
	assert.Equal(t, 1, tmpl.GPUCount)

	// Liveness and readiness both probe the same templated /health URL.
	const healthURL = "http://localhost:{{SERVICE_PORT}}/health"
	assert.Equal(t, healthURL, tmpl.HealthCheck.Liveness)
	assert.Equal(t, healthURL, tmpl.HealthCheck.Readiness)
	assert.Equal(t, 30, tmpl.HealthCheck.Interval)
	assert.Equal(t, 10, tmpl.HealthCheck.Timeout)

	// First of the two mounts is the HuggingFace model cache.
	require.Len(t, tmpl.Mounts, 2)
	assert.Equal(t, "{{MODEL_CACHE}}", tmpl.Mounts[0].Source)
	assert.Equal(t, "/root/.cache/huggingface", tmpl.Mounts[0].Destination)

	// Resource requirements.
	assert.Equal(t, 1, tmpl.Resources.MinGPU)
	assert.Equal(t, 8, tmpl.Resources.MaxGPU)
	assert.Equal(t, 16, tmpl.Resources.MinMemory)
}
func TestRunVLLMTask(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
ctx := context.Background()
t.Run("missing config", func(t *testing.T) {
metadata := map[string]string{}
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing 'vllm_config'")
})
t.Run("invalid config json", func(t *testing.T) {
metadata := map[string]string{
"vllm_config": "invalid json",
}
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid config")
})
t.Run("invalid config values", func(t *testing.T) {
config := plugins.VLLMConfig{
Model: "", // Missing model
}
configJSON, _ := json.Marshal(config)
metadata := map[string]string{
"vllm_config": string(configJSON),
}
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), "model name is required")
})
t.Run("successful start", func(t *testing.T) {
config := plugins.VLLMConfig{
Model: "meta-llama/Llama-2-7b",
TensorParallelSize: 1,
}
configJSON, _ := json.Marshal(config)
metadata := map[string]string{
"vllm_config": string(configJSON),
}
output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
require.NoError(t, err)
var result map[string]interface{}
err = json.Unmarshal(output, &result)
require.NoError(t, err)
assert.Equal(t, "starting", result["status"])
assert.Equal(t, "meta-llama/Llama-2-7b", result["model"])
assert.Equal(t, "vllm", result["plugin"])
})
}