Extend worker capabilities with new execution plugins and security features: - Jupyter plugin for notebook-based ML experiments - vLLM plugin for LLM inference workloads - Cross-platform process isolation (Unix/Windows) - Network policy enforcement with platform-specific implementations - Service manager integration for lifecycle management - Scheduler backend integration for queue coordination Update lifecycle management: - Enhanced runloop with state transitions - Service manager integration for plugin coordination - Improved state persistence and recovery Add test coverage: - Unit tests for Jupyter and vLLM plugins - Updated worker execution tests
257 lines
6.2 KiB
Go
257 lines
6.2 KiB
Go
package plugins_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNewVLLMPlugin(t *testing.T) {
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
|
|
|
// Test with default image
|
|
p1 := plugins.NewVLLMPlugin(logger, "")
|
|
assert.NotNil(t, p1)
|
|
|
|
// Test with custom image
|
|
p2 := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0")
|
|
assert.NotNil(t, p2)
|
|
}
|
|
|
|
func TestVLLMPluginValidateConfig(t *testing.T) {
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
|
plugin := plugins.NewVLLMPlugin(logger, "")
|
|
|
|
tests := []struct {
|
|
name string
|
|
config plugins.VLLMConfig
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid config",
|
|
config: plugins.VLLMConfig{
|
|
Model: "meta-llama/Llama-2-7b",
|
|
TensorParallelSize: 1,
|
|
GpuMemoryUtilization: 0.9,
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "missing model",
|
|
config: plugins.VLLMConfig{
|
|
Model: "",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid gpu memory high",
|
|
config: plugins.VLLMConfig{
|
|
Model: "test-model",
|
|
GpuMemoryUtilization: 1.5,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid gpu memory low",
|
|
config: plugins.VLLMConfig{
|
|
Model: "test-model",
|
|
GpuMemoryUtilization: -0.1,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid quantization awq",
|
|
config: plugins.VLLMConfig{
|
|
Model: "test-model",
|
|
Quantization: "awq",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "valid quantization gptq",
|
|
config: plugins.VLLMConfig{
|
|
Model: "test-model",
|
|
Quantization: "gptq",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid quantization",
|
|
config: plugins.VLLMConfig{
|
|
Model: "test-model",
|
|
Quantization: "invalid",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := plugin.ValidateConfig(&tt.config)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVLLMPluginBuildCommand(t *testing.T) {
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
|
plugin := plugins.NewVLLMPlugin(logger, "")
|
|
|
|
tests := []struct {
|
|
name string
|
|
config plugins.VLLMConfig
|
|
port int
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "basic command",
|
|
config: plugins.VLLMConfig{
|
|
Model: "meta-llama/Llama-2-7b",
|
|
},
|
|
port: 8000,
|
|
expected: []string{
|
|
"python", "-m", "vllm.entrypoints.openai.api_server",
|
|
"--model", "meta-llama/Llama-2-7b",
|
|
"--port", "8000",
|
|
"--host", "0.0.0.0",
|
|
},
|
|
},
|
|
{
|
|
name: "with tensor parallelism",
|
|
config: plugins.VLLMConfig{
|
|
Model: "meta-llama/Llama-2-70b",
|
|
TensorParallelSize: 4,
|
|
},
|
|
port: 8000,
|
|
expected: []string{
|
|
"--tensor-parallel-size", "4",
|
|
},
|
|
},
|
|
{
|
|
name: "with quantization",
|
|
config: plugins.VLLMConfig{
|
|
Model: "test-model",
|
|
Quantization: "awq",
|
|
},
|
|
port: 8000,
|
|
expected: []string{
|
|
"--quantization", "awq",
|
|
},
|
|
},
|
|
{
|
|
name: "with trust remote code",
|
|
config: plugins.VLLMConfig{
|
|
Model: "custom-model",
|
|
TrustRemoteCode: true,
|
|
},
|
|
port: 8000,
|
|
expected: []string{
|
|
"--trust-remote-code",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cmd := plugin.BuildCommand(&tt.config, tt.port)
|
|
for _, exp := range tt.expected {
|
|
assert.Contains(t, cmd, exp)
|
|
}
|
|
assert.Contains(t, cmd, "--model")
|
|
assert.Contains(t, cmd, "--port")
|
|
assert.Contains(t, cmd, "8000")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVLLMPluginGetServiceTemplate(t *testing.T) {
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
|
plugin := plugins.NewVLLMPlugin(logger, "")
|
|
|
|
template := plugin.GetServiceTemplate()
|
|
|
|
assert.Equal(t, "service", template.JobType)
|
|
assert.Equal(t, "service", template.SlotPool)
|
|
assert.Equal(t, 1, template.GPUCount)
|
|
|
|
// Verify health check endpoints
|
|
assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Liveness)
|
|
assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Readiness)
|
|
assert.Equal(t, 30, template.HealthCheck.Interval)
|
|
assert.Equal(t, 10, template.HealthCheck.Timeout)
|
|
|
|
// Verify mounts
|
|
require.Len(t, template.Mounts, 2)
|
|
assert.Equal(t, "{{MODEL_CACHE}}", template.Mounts[0].Source)
|
|
assert.Equal(t, "/root/.cache/huggingface", template.Mounts[0].Destination)
|
|
|
|
// Verify resource requirements
|
|
assert.Equal(t, 1, template.Resources.MinGPU)
|
|
assert.Equal(t, 8, template.Resources.MaxGPU)
|
|
assert.Equal(t, 16, template.Resources.MinMemory)
|
|
}
|
|
|
|
func TestRunVLLMTask(t *testing.T) {
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
|
ctx := context.Background()
|
|
|
|
t.Run("missing config", func(t *testing.T) {
|
|
metadata := map[string]string{}
|
|
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "missing 'vllm_config'")
|
|
})
|
|
|
|
t.Run("invalid config json", func(t *testing.T) {
|
|
metadata := map[string]string{
|
|
"vllm_config": "invalid json",
|
|
}
|
|
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid config")
|
|
})
|
|
|
|
t.Run("invalid config values", func(t *testing.T) {
|
|
config := plugins.VLLMConfig{
|
|
Model: "", // Missing model
|
|
}
|
|
configJSON, _ := json.Marshal(config)
|
|
metadata := map[string]string{
|
|
"vllm_config": string(configJSON),
|
|
}
|
|
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "model name is required")
|
|
})
|
|
|
|
t.Run("successful start", func(t *testing.T) {
|
|
config := plugins.VLLMConfig{
|
|
Model: "meta-llama/Llama-2-7b",
|
|
TensorParallelSize: 1,
|
|
}
|
|
configJSON, _ := json.Marshal(config)
|
|
metadata := map[string]string{
|
|
"vllm_config": string(configJSON),
|
|
}
|
|
output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
|
require.NoError(t, err)
|
|
|
|
var result map[string]interface{}
|
|
err = json.Unmarshal(output, &result)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, "starting", result["status"])
|
|
assert.Equal(t, "meta-llama/Llama-2-7b", result["model"])
|
|
assert.Equal(t, "vllm", result["plugin"])
|
|
})
|
|
}
|