fetch_ml/internal/worker/lifecycle/service_manager.go
Jeremie Fraeys 0b5e99f720
refactor(scheduler,worker): improve service management and GPU detection
Scheduler enhancements:
- auth.go: Group membership validation in authentication
- hub.go: Task distribution with group affinity
- port_allocator.go: Dynamic port allocation with conflict resolution
- scheduler_conn.go: Connection pooling and retry logic
- service_manager.go: Lifecycle management for scheduler services
- service_templates.go: Template-based service configuration
- state.go: Persistent state management with recovery

Worker improvements:
- config.go: Extended configuration for task visibility rules
- execution/setup.go: Sandboxed execution environment setup
- executor/container.go: Container runtime integration
- executor/runner.go: Task runner with visibility enforcement
- gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback)
- integrity/validate.go: Data integrity validation
- lifecycle/runloop.go: Improved runloop with graceful shutdown
- lifecycle/service_manager.go: Service lifecycle coordination
- process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups
- tenant/manager.go: Multi-tenant resource isolation
- tenant/middleware.go: Tenant context propagation
- worker.go: Core worker with group-scoped task execution
2026-03-08 13:03:15 -04:00

265 lines
7.2 KiB
Go

// Package lifecycle provides service job lifecycle management for long-running services.
package lifecycle
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"os/exec"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/domain"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
)
// ServiceManager handles the lifecycle of one long-running service job
// (for example Jupyter or vLLM): launching the process, waiting for it to
// become ready, monitoring it, and shutting it down.
type ServiceManager struct {
// task is the job being run; its ID tags every log line and it is the
// subject of every state transition.
task *domain.Task
// spec is the service specification supplied at construction.
spec ServiceSpec
// cmd is the launched process. It is nil until start has built the command.
cmd *exec.Cmd
// port is the port assigned to the service; it is reported by GetPort and
// included in the "service is serving" log entry.
port int
// stateMgr records state transitions. It may be nil, in which case every
// transition is skipped.
stateMgr *StateManager
// logger receives lifecycle log entries.
logger Logger
// healthCheck is copied from spec.HealthCheck by NewServiceManager. When it
// is nil the service is treated as ready immediately and no periodic
// liveness check runs.
healthCheck *scheduler.HealthCheck
}
// ServiceSpec defines the specification for a service job.
type ServiceSpec struct {
// Command is the argument vector: element 0 is the program to execute and
// the remaining elements are its arguments. It must not be empty.
Command []string
// Env holds environment variables handed to the process as KEY=VALUE entries.
Env map[string]string
// Port is not read anywhere in this file; the manager uses the port passed
// to NewServiceManager instead.
// NOTE(review): presumably the port the service wants to listen on; confirm against callers.
Port int
// HealthCheck configures readiness and liveness probing. It may be nil.
HealthCheck *scheduler.HealthCheck
}
// NewServiceManager builds a ServiceManager for task from the given spec,
// assigned port, state manager and logger. The health-check configuration is
// taken from spec.HealthCheck. No process is started here; the command is
// created later when the service is run.
func NewServiceManager(task *domain.Task, spec ServiceSpec, port int, stateMgr *StateManager, logger Logger) *ServiceManager {
	sm := new(ServiceManager)
	sm.task, sm.spec, sm.port = task, spec, port
	sm.stateMgr, sm.logger = stateMgr, logger
	sm.healthCheck = spec.HealthCheck
	return sm
}
// Run drives the full service lifecycle: it starts the process, waits for it
// to become ready, marks the task as serving, and then blocks in the health
// loop until ctx is cancelled or a liveness check fails.
//
// A start failure or a readiness failure (including the 120-second readiness
// timeout and cancellation of ctx while waiting) moves the task to
// StateFailed when a state manager is configured and is returned wrapped; a
// readiness failure additionally kills the process. In every other case the
// result of the health loop is returned.
func (sm *ServiceManager) Run(ctx context.Context) error {
	// markFailed records the failure with the state manager, if there is one.
	// A failed transition is only logged so that the error which caused the
	// failure is the one that reaches the caller.
	markFailed := func() {
		if sm.stateMgr == nil {
			return
		}
		if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
			sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
		}
	}

	if err := sm.start(); err != nil {
		sm.logger.Error("service start failed", "task_id", sm.task.ID, "error", err)
		markFailed()
		return fmt.Errorf("start service: %w", err)
	}

	// Bound the readiness wait. The timeout context is released as soon as the
	// wait resolves rather than deferred, because this function then blocks in
	// the health loop for the lifetime of the service and would otherwise keep
	// the timeout's resources alive for no reason.
	readyCtx, cancel := context.WithTimeout(ctx, 120*time.Second)
	err := sm.waitReady(readyCtx)
	cancel()
	if err != nil {
		sm.logger.Error("service readiness check failed", "task_id", sm.task.ID, "error", err)
		markFailed()
		// Best effort: the readiness error is what the caller needs to see,
		// so a failure to kill the process is deliberately not reported.
		_ = sm.stop()
		return fmt.Errorf("wait ready: %w", err)
	}

	// A failed transition to serving is logged but does not stop the service.
	if sm.stateMgr != nil {
		if err := sm.stateMgr.Transition(sm.task, StateServing); err != nil {
			sm.logger.Error("failed to transition to serving", "task_id", sm.task.ID, "error", err)
		}
	}
	sm.logger.Info("service is serving", "task_id", sm.task.ID, "port", sm.port)

	return sm.healthLoop(ctx)
}
// start launches the service process described by sm.spec and stores the
// command in sm.cmd.
//
// The child runs with this process's environment plus the entries from
// sm.spec.Env. os/exec keeps only the last value for a duplicated key, so an
// entry in sm.spec.Env overrides an inherited variable of the same name.
//
// It returns an error when the spec has no command or the process cannot be
// started.
func (sm *ServiceManager) start() error {
	if len(sm.spec.Command) == 0 {
		return fmt.Errorf("no command specified")
	}
	sm.cmd = exec.Command(sm.spec.Command[0], sm.spec.Command[1:]...)

	// A non-nil Env replaces the inherited environment wholesale, so the
	// overrides must be layered on top of os.Environ(); otherwise a single
	// configured variable would strip PATH, HOME and everything else from the
	// child. With no overrides Env stays nil and the child inherits as usual.
	if len(sm.spec.Env) > 0 {
		inherited := os.Environ()
		env := make([]string, 0, len(inherited)+len(sm.spec.Env))
		env = append(env, inherited...)
		for k, v := range sm.spec.Env {
			env = append(env, k+"="+v)
		}
		sm.cmd.Env = env
	}

	if err := sm.cmd.Start(); err != nil {
		return fmt.Errorf("start process: %w", err)
	}
	return nil
}
// waitReady polls the readiness check once per second until it succeeds or
// ctx ends. The first probe happens after the first one-second tick.
//
// It returns nil when the service is ready, or immediately when no readiness
// endpoint is configured; it returns ctx.Err() when ctx is cancelled or times
// out first.
func (sm *ServiceManager) waitReady(ctx context.Context) error {
	hc := sm.healthCheck
	if hc == nil || hc.ReadinessEndpoint == "" {
		// Nothing to probe, so the service counts as ready right away.
		return nil
	}

	poll := time.NewTicker(time.Second)
	defer poll.Stop()

	for {
		select {
		case <-poll.C:
			if !sm.checkReadiness() {
				continue
			}
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// healthLoop monitors the service until ctx is cancelled or a liveness check
// fails.
//
// On cancellation it performs a graceful stop and returns that result. With
// no health check configured it simply waits for cancellation. Otherwise it
// runs the liveness check every IntervalSecs seconds (15 seconds when the
// configured interval is not positive). On the first failed check it kills
// the process, moves the task to StateFailed when a state manager is
// configured, and returns an error; if that transition itself fails, the
// transition error is returned instead.
func (sm *ServiceManager) healthLoop(ctx context.Context) error {
	if sm.healthCheck == nil {
		<-ctx.Done()
		return sm.gracefulStop()
	}

	interval := time.Duration(sm.healthCheck.IntervalSecs) * time.Second
	// time.NewTicker panics on a non-positive duration, so a negative value
	// falls back to the default just like an unset one.
	if interval <= 0 {
		interval = 15 * time.Second
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return sm.gracefulStop()
		case <-ticker.C:
			if sm.checkLiveness() {
				continue
			}
			sm.logger.Error("service liveness check failed", "task_id", sm.task.ID)

			// Nothing else holds a handle to the process once this method
			// returns, so an unhealthy but still running service must be
			// killed here or it would keep running unmanaged.
			if err := sm.stop(); err != nil {
				sm.logger.Warn("failed to kill unhealthy service", "task_id", sm.task.ID, "error", err)
			}

			if sm.stateMgr != nil {
				if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
					sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
					return fmt.Errorf("transition to failed: %w", err)
				}
			}
			return fmt.Errorf("liveness check failed")
		}
	}
}
// checkLiveness reports whether the service currently looks alive.
//
// It returns false when no process has been started. When a liveness endpoint
// is configured, the service is alive only if an HTTP GET to it completes
// within five seconds with status 200. Otherwise it falls back to sending
// signal 0 to the process and reports whether that succeeds.
//
// NOTE(review): an exited child that has not yet been reaped with Wait may
// still accept signal 0; confirm the fallback actually detects a crashed service.
func (sm *ServiceManager) checkLiveness() bool {
	if sm.cmd == nil || sm.cmd.Process == nil {
		return false
	}

	hc := sm.healthCheck
	if hc == nil || hc.LivenessEndpoint == "" {
		// No endpoint to probe: only check that the process can be signalled.
		return sm.cmd.Process.Signal(syscall.Signal(0)) == nil
	}

	probe := http.Client{Timeout: 5 * time.Second}
	resp, err := probe.Get(hc.LivenessEndpoint)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
// checkReadiness reports whether the service is ready to accept traffic.
//
// With no readiness endpoint configured it always returns true. Otherwise the
// service is ready only if an HTTP GET to the endpoint completes within five
// seconds with status 200.
func (sm *ServiceManager) checkReadiness() bool {
	hc := sm.healthCheck
	if hc == nil || hc.ReadinessEndpoint == "" {
		return true
	}

	probe := http.Client{Timeout: 5 * time.Second}
	resp, err := probe.Get(hc.ReadinessEndpoint)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
// gracefulStop shuts the service down, preferring a clean exit.
//
// The task is moved to StateStopping, the process (if one was started) is
// sent SIGTERM and given 30 seconds to exit before being killed, and the task
// is then moved to StateCompleted. If SIGTERM cannot be delivered the process
// is killed straight away. Transition and kill failures are logged only; the
// method always returns nil.
func (sm *ServiceManager) gracefulStop() error {
	sm.logger.Info("gracefully stopping service", "task_id", sm.task.ID)

	if sm.stateMgr != nil {
		if err := sm.stateMgr.Transition(sm.task, StateStopping); err != nil {
			sm.logger.Error("failed to transition to stopping", "task_id", sm.task.ID, "error", err)
		}
	}

	// Only touch the process when one was actually started.
	if sm.cmd != nil && sm.cmd.Process != nil {
		proc := sm.cmd.Process
		if termErr := proc.Signal(syscall.SIGTERM); termErr == nil {
			// Wait in the background so the exit can race against the deadline.
			// The channel is buffered so the goroutine can finish even after
			// the deadline branch has been taken.
			exited := make(chan error, 1)
			go func() {
				exited <- sm.cmd.Wait()
			}()

			select {
			case <-time.After(30 * time.Second):
				sm.logger.Warn("graceful shutdown timeout, forcing kill", "task_id", sm.task.ID)
				if killErr := proc.Kill(); killErr != nil {
					sm.logger.Error("failed to kill process", "task_id", sm.task.ID, "error", killErr)
				}
			case <-exited:
				// The process exited within the grace period.
			}
		} else {
			sm.logger.Warn("SIGTERM failed, using SIGKILL", "task_id", sm.task.ID, "error", termErr)
			if killErr := proc.Kill(); killErr != nil {
				sm.logger.Error("failed to kill process", "task_id", sm.task.ID, "error", killErr)
			}
		}
	}

	// The task ends in the completed state whether or not a process existed.
	if sm.stateMgr != nil {
		if err := sm.stateMgr.Transition(sm.task, StateCompleted); err != nil {
			sm.logger.Error("failed to transition to completed", "task_id", sm.task.ID, "error", err)
		}
	}
	return nil
}
// stop forcefully kills the service process and reaps it.
//
// It is a no-op returning nil when no process has been started. It returns
// the error from Kill if the process could not be killed; otherwise it waits
// for the killed process to exit and returns nil.
func (sm *ServiceManager) stop() error {
	if sm.cmd == nil || sm.cmd.Process == nil {
		return nil
	}
	if err := sm.cmd.Process.Kill(); err != nil {
		return err
	}
	// After a successful Start, Wait must be called to release the resources
	// tied to the command; without it the killed child stays unreaped. The
	// error is ignored because for a killed process it only restates that the
	// process was terminated by a signal.
	_ = sm.cmd.Wait()
	return nil
}
// GetPort returns the port assigned to the service, i.e. the value passed to
// NewServiceManager. It does not change over the manager's lifetime in this file.
func (sm *ServiceManager) GetPort() int {
return sm.port
}