From 17170667e2507c9ed656d2507977b9b7a30ef534 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 12 Mar 2026 12:05:02 -0400 Subject: [PATCH] feat(worker): improve lifecycle management and vLLM plugin Lifecycle improvements: - runloop.go: replace 50ms polling loop with a channel-based semaphore for worker-slot management - service_manager.go: reuse a single shared http.Client for liveness/readiness checks - states.go: replace interface{} with any in transition audit metadata Container execution: - container.go: track the job cache directory on the executor and remove it in both handleSuccess and handleFailure to prevent disk bloat vLLM plugin updates: - vllm.go: replace interface{} with any and clarify health-check comments Configuration: - config.go: replace interface{} with any in logEnvOverride - snapshot_store.go: refactor snapshot work-dir cleanup into a named deferred function Tests: - vllm_test.go: replace interface{} with any --- internal/worker/config.go | 2 +- internal/worker/executor/container.go | 33 ++++++++++++++------ internal/worker/lifecycle/runloop.go | 22 ++++++++----- internal/worker/lifecycle/service_manager.go | 11 ++++--- internal/worker/lifecycle/states.go | 2 +- internal/worker/plugins/vllm.go | 4 +-- internal/worker/snapshot_store.go | 6 +++- tests/unit/worker/plugins/vllm_test.go | 2 +- 8 files changed, 55 insertions(+), 27 deletions(-) diff --git a/internal/worker/config.go b/internal/worker/config.go index 525afa8..eae654a 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -985,7 +985,7 @@ func envInt(name string) (int, bool) { } // logEnvOverride logs environment variable overrides to stderr for debugging -func logEnvOverride(name string, value interface{}) { +func logEnvOverride(name string, value any) { // Sanitize name to prevent log injection - only allow alphanumeric and underscore cleanName := strings.Map(func(r rune) rune { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' { 
diff --git a/internal/worker/executor/container.go b/internal/worker/executor/container.go index 9962400..e28172a 100644 --- a/internal/worker/executor/container.go +++ b/internal/worker/executor/container.go @@ -54,11 +54,12 @@ type SandboxConfig interface { // ContainerExecutor executes jobs in containers using podman type ContainerExecutor struct { - logger *logging.Logger - writer interfaces.ManifestWriter - registry *tracking.Registry - envPool EnvironmentPool - config ContainerConfig + logger *logging.Logger + writer interfaces.ManifestWriter + registry *tracking.Registry + envPool EnvironmentPool + config ContainerConfig + cacheRoot string // Track cache directory for cleanup } // EnvironmentPool interface for environment image pooling @@ -238,11 +239,11 @@ func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputD delete(trackingEnv, "TENSORBOARD_HOST_LOG_DIR") } - cacheRoot := filepath.Join(e.config.BasePath, ".cache") - if err := os.MkdirAll(cacheRoot, 0750); err != nil { - e.logger.Warn("failed to create cache directory", "path", cacheRoot, "error", err) + e.cacheRoot = filepath.Join(e.config.BasePath, ".cache") + if err := os.MkdirAll(e.cacheRoot, 0750); err != nil { + e.logger.Warn("failed to create cache directory", "path", e.cacheRoot, "error", err) } - volumes[cacheRoot] = "/workspace/.cache:rw" + volumes[e.cacheRoot] = "/workspace/.cache:rw" defaultEnv := map[string]string{ "HF_HOME": "/workspace/.cache/huggingface", @@ -392,6 +393,13 @@ func (e *ContainerExecutor) handleFailure( runErr error, duration time.Duration, ) error { + // Clean up cache directory on failure to prevent disk bloat + if e.cacheRoot != "" { + if err := os.RemoveAll(e.cacheRoot); err != nil { + e.logger.Warn("failed to clean up cache directory", "path", e.cacheRoot, "error", err) + } + } + if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { now := time.Now().UTC() @@ -441,6 +449,13 @@ func (e *ContainerExecutor) 
handleSuccess( jobPaths *storage.JobPaths, duration time.Duration, ) error { + // Clean up cache directory on success to prevent disk bloat + if e.cacheRoot != "" { + if err := os.RemoveAll(e.cacheRoot); err != nil { + e.logger.Warn("failed to clean up cache directory", "path", e.cacheRoot, "error", err) + } + } + finalizeStart := time.Now() finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) diff --git a/internal/worker/lifecycle/runloop.go b/internal/worker/lifecycle/runloop.go index 333615b..f937f9d 100644 --- a/internal/worker/lifecycle/runloop.go +++ b/internal/worker/lifecycle/runloop.go @@ -38,6 +38,7 @@ type RunLoop struct { cancel context.CancelFunc config RunLoopConfig runningMu sync.RWMutex + slotCh chan struct{} // Semaphore for controlling concurrency } // MetricsRecorder defines the contract for recording metrics @@ -76,6 +77,7 @@ func NewRunLoop( running: make(map[string]context.CancelFunc), ctx: ctx, cancel: cancel, + slotCh: make(chan struct{}, config.MaxWorkers), // Buffered channel as semaphore } } @@ -92,12 +94,8 @@ func (r *RunLoop) Start() { r.logger.Info("shutdown signal received, waiting for tasks") r.waitForTasks() return - default: - } - - if r.runningCount() >= r.config.MaxWorkers { - time.Sleep(50 * time.Millisecond) - continue + case r.slotCh <- struct{}{}: // Acquire slot (blocks when at capacity) + // Slot acquired, proceed to fetch task } queueStart := time.Now() @@ -111,6 +109,7 @@ func (r *RunLoop) Start() { r.metrics.RecordQueueLatency(queueLatency) if err != nil { + <-r.slotCh // Release slot on error if err == context.DeadlineExceeded { continue } @@ -119,11 +118,12 @@ func (r *RunLoop) Start() { } if task == nil { + <-r.slotCh // Release slot, no task continue } r.reserveRunningSlot(task.ID) - go r.executeTask(task) + go r.executeTaskWithSlot(task) } } @@ -230,3 +230,11 @@ func (r *RunLoop) executeTask(task *queue.Task) { _ = r.queue.ReleaseLease(task.ID, r.config.WorkerID) } + +// executeTaskWithSlot wraps 
executeTask and releases the semaphore slot +func (r *RunLoop) executeTaskWithSlot(task *queue.Task) { + defer func() { + <-r.slotCh // Release slot when done + }() + r.executeTask(task) +} diff --git a/internal/worker/lifecycle/service_manager.go b/internal/worker/lifecycle/service_manager.go index 6eead94..9dbe0fe 100644 --- a/internal/worker/lifecycle/service_manager.go +++ b/internal/worker/lifecycle/service_manager.go @@ -22,6 +22,7 @@ type ServiceManager struct { stateMgr *StateManager logger Logger healthCheck *scheduler.HealthCheck + httpClient *http.Client // Reusable HTTP client for health checks } // ServiceSpec defines the specification for a service job @@ -41,6 +42,7 @@ func NewServiceManager(task *domain.Task, spec ServiceSpec, port int, stateMgr * stateMgr: stateMgr, logger: logger, healthCheck: spec.HealthCheck, + httpClient: &http.Client{Timeout: 5 * time.Second}, // Shared client with connection reuse } } @@ -169,9 +171,8 @@ func (sm *ServiceManager) checkLiveness() bool { // Check process state if sm.healthCheck != nil && sm.healthCheck.LivenessEndpoint != "" { - // HTTP liveness check - client := &http.Client{Timeout: 5 * time.Second} - resp, err := client.Get(sm.healthCheck.LivenessEndpoint) + // HTTP liveness check using shared client + resp, err := sm.httpClient.Get(sm.healthCheck.LivenessEndpoint) if err != nil { return false } @@ -189,8 +190,8 @@ func (sm *ServiceManager) checkReadiness() bool { return true } - client := &http.Client{Timeout: 5 * time.Second} - resp, err := client.Get(sm.healthCheck.ReadinessEndpoint) + // Use shared HTTP client for connection reuse + resp, err := sm.httpClient.Get(sm.healthCheck.ReadinessEndpoint) if err != nil { return false } diff --git a/internal/worker/lifecycle/states.go b/internal/worker/lifecycle/states.go index 08cdeb2..36ef9b1 100644 --- a/internal/worker/lifecycle/states.go +++ b/internal/worker/lifecycle/states.go @@ -96,7 +96,7 @@ func (sm *StateManager) Transition(task *domain.Task, to 
TaskState) error { Resource: task.ID, Action: "task_state_change", Success: true, - Metadata: map[string]interface{}{ + Metadata: map[string]any{ "job_name": task.JobName, "old_state": string(from), "new_state": string(to), diff --git a/internal/worker/plugins/vllm.go b/internal/worker/plugins/vllm.go index c239283..e9c8f64 100644 --- a/internal/worker/plugins/vllm.go +++ b/internal/worker/plugins/vllm.go @@ -187,7 +187,7 @@ func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMSta } defer resp.Body.Close() - // Parse metrics endpoint for detailed status + // Parse metrics endpoint for detailed status using the same client with context metricsReq, _ := http.NewRequestWithContext(ctx, "GET", endpoint+"/metrics", nil) metricsResp, err := client.Do(metricsReq) @@ -234,7 +234,7 @@ func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metada // The actual execution goes through the standard service runner // Return success output for the task - output := map[string]interface{}{ + output := map[string]any{ "status": "starting", "model": config.Model, "plugin": "vllm", diff --git a/internal/worker/snapshot_store.go b/internal/worker/snapshot_store.go index af293eb..bac8d94 100644 --- a/internal/worker/snapshot_store.go +++ b/internal/worker/snapshot_store.go @@ -150,7 +150,11 @@ func ResolveSnapshot( if err != nil { return "", err } - defer func() { _ = os.RemoveAll(workDir) }() + // Ensure cleanup happens even on panic using deferred function + cleanup := func() { + _ = os.RemoveAll(workDir) + } + defer cleanup() archivePath := filepath.Join(workDir, "snapshot.tar.gz") f, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) diff --git a/tests/unit/worker/plugins/vllm_test.go b/tests/unit/worker/plugins/vllm_test.go index 2f84b9d..9eb39b0 100644 --- a/tests/unit/worker/plugins/vllm_test.go +++ b/tests/unit/worker/plugins/vllm_test.go @@ -246,7 +246,7 @@ func TestRunVLLMTask(t *testing.T) { output, 
err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata) require.NoError(t, err) - var result map[string]interface{} + var result map[string]any err = json.Unmarshal(output, &result) require.NoError(t, err)