package plugins__test

import (
	"context"
	"encoding/json"
	"log/slog"
	"os"
	"strconv"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker/plugins"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// newVLLMTestLogger returns the text logger (writing to stderr) that every
// test in this file hands to the vLLM plugin.
func newVLLMTestLogger() *slog.Logger {
	return slog.New(slog.NewTextHandler(os.Stderr, nil))
}

// TestNewVLLMPlugin checks that the constructor returns a non-nil plugin both
// with an empty image argument and with an explicit custom image.
func TestNewVLLMPlugin(t *testing.T) {
	logger := newVLLMTestLogger()

	// Test with default image
	p1 := plugins.NewVLLMPlugin(logger, "")
	assert.NotNil(t, p1)

	// Test with custom image
	p2 := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0")
	assert.NotNil(t, p2)
}

// TestVLLMPluginValidateConfig table-tests ValidateConfig: a model name is
// required, GPU memory utilization outside [0, 1] is rejected, and only the
// listed quantization methods are accepted.
func TestVLLMPluginValidateConfig(t *testing.T) {
	plugin := plugins.NewVLLMPlugin(newVLLMTestLogger(), "")

	tests := []struct {
		name    string
		config  plugins.VLLMConfig
		wantErr bool
	}{
		{
			name: "valid config",
			config: plugins.VLLMConfig{
				Model:                "meta-llama/Llama-2-7b",
				TensorParallelSize:   1,
				GpuMemoryUtilization: 0.9,
			},
			wantErr: false,
		},
		{
			name: "missing model",
			config: plugins.VLLMConfig{
				Model: "",
			},
			wantErr: true,
		},
		{
			name: "invalid gpu memory high",
			config: plugins.VLLMConfig{
				Model:                "test-model",
				GpuMemoryUtilization: 1.5,
			},
			wantErr: true,
		},
		{
			name: "invalid gpu memory low",
			config: plugins.VLLMConfig{
				Model:                "test-model",
				GpuMemoryUtilization: -0.1,
			},
			wantErr: true,
		},
		{
			name: "valid quantization awq",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "awq",
			},
			wantErr: false,
		},
		{
			name: "valid quantization gptq",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "gptq",
			},
			wantErr: false,
		},
		{
			name: "invalid quantization",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "invalid",
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Subtests run sequentially, so taking the address of the
			// per-iteration copy is safe on every Go version.
			err := plugin.ValidateConfig(&tt.config)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}

// TestVLLMPluginBuildCommand checks that BuildCommand emits the expected
// arguments for each config and always includes the model and requested port.
func TestVLLMPluginBuildCommand(t *testing.T) {
	plugin := plugins.NewVLLMPlugin(newVLLMTestLogger(), "")

	tests := []struct {
		name     string
		config   plugins.VLLMConfig
		port     int
		expected []string
	}{
		{
			name: "basic command",
			config: plugins.VLLMConfig{
				Model: "meta-llama/Llama-2-7b",
			},
			port: 8000,
			expected: []string{
				"python", "-m", "vllm.entrypoints.openai.api_server",
				"--model", "meta-llama/Llama-2-7b",
				"--port", "8000",
				"--host", "0.0.0.0",
			},
		},
		{
			name: "with tensor parallelism",
			config: plugins.VLLMConfig{
				Model:              "meta-llama/Llama-2-70b",
				TensorParallelSize: 4,
			},
			port: 8000,
			expected: []string{
				"--tensor-parallel-size", "4",
			},
		},
		{
			name: "with quantization",
			config: plugins.VLLMConfig{
				Model:        "test-model",
				Quantization: "awq",
			},
			port: 8000,
			expected: []string{
				"--quantization", "awq",
			},
		},
		{
			name: "with trust remote code",
			config: plugins.VLLMConfig{
				Model:           "custom-model",
				TrustRemoteCode: true,
			},
			port: 8000,
			expected: []string{
				"--trust-remote-code",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cmd := plugin.BuildCommand(&tt.config, tt.port)

			for _, exp := range tt.expected {
				assert.Contains(t, cmd, exp)
			}

			// Every command must name the configured model and listen on the
			// port this case requested; the port string is derived from the
			// case rather than hard-coded so new cases cannot silently pass.
			assert.Contains(t, cmd, "--model")
			assert.Contains(t, cmd, tt.config.Model)
			assert.Contains(t, cmd, "--port")
			assert.Contains(t, cmd, strconv.Itoa(tt.port))
		})
	}
}

// TestVLLMPluginGetServiceTemplate pins the service template: job type and
// slot pool, GPU count, health-check endpoints and timings, the model-cache
// mount, and resource bounds.
func TestVLLMPluginGetServiceTemplate(t *testing.T) {
	plugin := plugins.NewVLLMPlugin(newVLLMTestLogger(), "")

	tmpl := plugin.GetServiceTemplate()

	assert.Equal(t, "service", tmpl.JobType)
	assert.Equal(t, "service", tmpl.SlotPool)
	assert.Equal(t, 1, tmpl.GPUCount)

	// Verify health check endpoints
	assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", tmpl.HealthCheck.Liveness)
	assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", tmpl.HealthCheck.Readiness)
	assert.Equal(t, 30, tmpl.HealthCheck.Interval)
	assert.Equal(t, 10, tmpl.HealthCheck.Timeout)

	// Verify mounts; require.Len guards the index below against a panic.
	require.Len(t, tmpl.Mounts, 2)
	assert.Equal(t, "{{MODEL_CACHE}}", tmpl.Mounts[0].Source)
	assert.Equal(t, "/root/.cache/huggingface", tmpl.Mounts[0].Destination)

	// Verify resource requirements
	assert.Equal(t, 1, tmpl.Resources.MinGPU)
	assert.Equal(t, 8, tmpl.Resources.MaxGPU)
	assert.Equal(t, 16, tmpl.Resources.MinMemory)
}

// TestRunVLLMTask covers RunVLLMTask's handling of the "vllm_config" metadata
// entry: missing, malformed JSON, semantically invalid, and a valid config
// whose JSON output reports the starting status, model, and plugin name.
func TestRunVLLMTask(t *testing.T) {
	logger := newVLLMTestLogger()
	ctx := context.Background()

	// metadataFor marshals cfg into the "vllm_config" entry RunVLLMTask
	// reads, failing the test if marshaling fails instead of ignoring it.
	metadataFor := func(t *testing.T, cfg plugins.VLLMConfig) map[string]string {
		t.Helper()
		configJSON, err := json.Marshal(cfg)
		require.NoError(t, err)
		return map[string]string{
			"vllm_config": string(configJSON),
		}
	}

	t.Run("missing config", func(t *testing.T) {
		metadata := map[string]string{}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		// require (not assert): err.Error() below would panic on a nil error.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "missing 'vllm_config'")
	})

	t.Run("invalid config json", func(t *testing.T) {
		metadata := map[string]string{
			"vllm_config": "invalid json",
		}
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid config")
	})

	t.Run("invalid config values", func(t *testing.T) {
		metadata := metadataFor(t, plugins.VLLMConfig{
			Model: "", // Missing model
		})
		_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "model name is required")
	})

	t.Run("successful start", func(t *testing.T) {
		metadata := metadataFor(t, plugins.VLLMConfig{
			Model:              "meta-llama/Llama-2-7b",
			TensorParallelSize: 1,
		})

		output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
		require.NoError(t, err)

		var result map[string]interface{}
		require.NoError(t, json.Unmarshal(output, &result))

		assert.Equal(t, "starting", result["status"])
		assert.Equal(t, "meta-llama/Llama-2-7b", result["model"])
		assert.Equal(t, "vllm", result["plugin"])
	})
}