feat(worker): add Jupyter/vLLM plugins and process isolation
Extend worker capabilities with new execution plugins and security features: - Jupyter plugin for notebook-based ML experiments - vLLM plugin for LLM inference workloads - Cross-platform process isolation (Unix/Windows) - Network policy enforcement with platform-specific implementations - Service manager integration for lifecycle management - Scheduler backend integration for queue coordination Update lifecycle management: - Enhanced runloop with state transitions - Service manager integration for plugin coordination - Improved state persistence and recovery Add test coverage: - Unit tests for Jupyter and vLLM plugins - Updated worker execution tests
This commit is contained in:
parent
a981e89005
commit
95adcba437
13 changed files with 1928 additions and 16 deletions
227
internal/queue/scheduler_backend.go
Normal file
227
internal/queue/scheduler_backend.go
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||
)
|
||||
|
||||
// SchedulerConn defines the interface for scheduler connection.
// It abstracts the WebSocket transport so SchedulerBackend can be
// exercised with a fake connection in tests.
type SchedulerConn interface {
	// Send transmits a message to the scheduler. It is fire-and-forget:
	// no delivery error is surfaced to the caller.
	Send(msg scheduler.Message)
}
|
||||
|
||||
// SchedulerBackend implements queue.Backend by communicating with a scheduler over WebSocket.
type SchedulerBackend struct {
	conn         SchedulerConn        // transport to the scheduler (fire-and-forget sends)
	pendingTasks chan *Task           // buffered (cap 1) hand-off from OnJobAssign to the run loop
	prewarmHint  *Task                // last prewarm hint pushed by the scheduler; nil when none
	localSlots   scheduler.SlotStatus // latest slot usage, attached to readiness messages

	// NOTE(review): prewarmHint and localSlots are written by connection
	// callbacks (OnPrewarmHint, UpdateSlots) and read by the run loop with
	// no synchronization — confirm both sides run on the same goroutine,
	// otherwise this is a data race.
}
|
||||
|
||||
// NewSchedulerBackend creates a new scheduler backend
|
||||
func NewSchedulerBackend(conn SchedulerConn) *SchedulerBackend {
|
||||
return &SchedulerBackend{
|
||||
conn: conn,
|
||||
pendingTasks: make(chan *Task, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// GetNextTaskWithLeaseBlocking implements queue.Backend
|
||||
func (sb *SchedulerBackend) GetNextTaskWithLeaseBlocking(
|
||||
workerID string,
|
||||
leaseDuration time.Duration,
|
||||
blockTimeout time.Duration,
|
||||
) (*Task, error) {
|
||||
// Signal readiness with current slot status
|
||||
sb.conn.Send(scheduler.Message{
|
||||
Type: scheduler.MsgReadyForWork,
|
||||
Payload: mustMarshal(scheduler.ReadyPayload{
|
||||
WorkerID: workerID,
|
||||
Slots: sb.localSlots,
|
||||
Reason: "polling",
|
||||
}),
|
||||
})
|
||||
|
||||
// Wait for scheduler push or timeout
|
||||
select {
|
||||
case task := <-sb.pendingTasks:
|
||||
return task, nil
|
||||
case <-time.After(blockTimeout):
|
||||
return nil, nil // RunLoop retries — same behaviour as empty queue
|
||||
}
|
||||
}
|
||||
|
||||
// OnJobAssign is called when the scheduler pushes a job assignment
|
||||
func (sb *SchedulerBackend) OnJobAssign(spec *scheduler.JobSpec) {
|
||||
task := &Task{
|
||||
ID: spec.ID,
|
||||
JobName: "distributed-job",
|
||||
Priority: 1,
|
||||
Status: "assigned",
|
||||
}
|
||||
select {
|
||||
case sb.pendingTasks <- task:
|
||||
default:
|
||||
// Channel full, drop (shouldn't happen with buffer of 1)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTask implements queue.Backend - forwards state changes to scheduler
|
||||
func (sb *SchedulerBackend) UpdateTask(task *Task) error {
|
||||
// Notify scheduler of state change
|
||||
result := scheduler.JobResultPayload{
|
||||
TaskID: task.ID,
|
||||
State: task.Status,
|
||||
}
|
||||
if task.Error != "" {
|
||||
result.Error = task.Error
|
||||
}
|
||||
sb.conn.Send(scheduler.Message{
|
||||
Type: scheduler.MsgJobResult,
|
||||
Payload: mustMarshal(result),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeekNextTask implements queue.Backend - returns last prewarm hint
|
||||
func (sb *SchedulerBackend) PeekNextTask() (*Task, error) {
|
||||
if sb.prewarmHint == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return sb.prewarmHint, nil
|
||||
}
|
||||
|
||||
// OnPrewarmHint stores a prewarm hint from the scheduler
|
||||
func (sb *SchedulerBackend) OnPrewarmHint(hint scheduler.PrewarmHintPayload) {
|
||||
sb.prewarmHint = &Task{
|
||||
ID: hint.TaskID,
|
||||
Status: "prewarm_hint",
|
||||
Metadata: map[string]string{
|
||||
"snapshot_id": hint.SnapshotID,
|
||||
"snapshot_sha": hint.SnapshotSHA,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSlots updates the local slot status for readiness signaling
|
||||
func (sb *SchedulerBackend) UpdateSlots(running, max int) {
|
||||
sb.localSlots = scheduler.SlotStatus{
|
||||
BatchTotal: max,
|
||||
BatchInUse: running,
|
||||
}
|
||||
}
|
||||
|
||||
// Required queue.Backend methods (stub implementations for now)

// AddTask is unsupported: in distributed mode jobs are submitted to the
// scheduler, which pushes assignments back to workers.
func (sb *SchedulerBackend) AddTask(task *Task) error {
	return fmt.Errorf("AddTask not supported in distributed mode - submit to scheduler instead")
}

// GetNextTask is unsupported; callers must use the blocking, lease-aware
// variant so readiness is signalled to the scheduler.
func (sb *SchedulerBackend) GetNextTask() (*Task, error) {
	return nil, fmt.Errorf("GetNextTask not supported - use GetNextTaskWithLeaseBlocking")
}

// GetNextTaskWithLease delegates to the blocking variant with a zero
// block timeout, i.e. a single near-immediate poll.
func (sb *SchedulerBackend) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
	return sb.GetNextTaskWithLeaseBlocking(workerID, leaseDuration, 0)
}
|
||||
|
||||
// RenewLease is a no-op: in distributed mode the scheduler manages leases.
func (sb *SchedulerBackend) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
	// Distributed mode: scheduler manages leases
	return nil
}

// ReleaseLease is a no-op: in distributed mode the scheduler manages leases.
func (sb *SchedulerBackend) ReleaseLease(taskID string, workerID string) error {
	// Distributed mode: scheduler manages leases
	return nil
}
|
||||
|
||||
func (sb *SchedulerBackend) RetryTask(task *Task) error {
|
||||
sb.conn.Send(scheduler.Message{
|
||||
Type: scheduler.MsgJobResult,
|
||||
Payload: mustMarshal(scheduler.JobResultPayload{TaskID: task.ID, State: "retry"}),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// MoveToDeadLetterQueue is a no-op on the worker side.
func (sb *SchedulerBackend) MoveToDeadLetterQueue(task *Task, reason string) error {
	return nil // Scheduler handles this
}

// GetTask is unsupported: the worker has no view of scheduler-held tasks.
func (sb *SchedulerBackend) GetTask(taskID string) (*Task, error) {
	return nil, fmt.Errorf("GetTask not supported in distributed mode")
}

// GetAllTasks is unsupported in distributed mode.
func (sb *SchedulerBackend) GetAllTasks() ([]*Task, error) {
	return nil, fmt.Errorf("GetAllTasks not supported in distributed mode")
}

// GetTaskByName is unsupported in distributed mode.
func (sb *SchedulerBackend) GetTaskByName(jobName string) (*Task, error) {
	return nil, fmt.Errorf("GetTaskByName not supported in distributed mode")
}

// CancelTask is unsupported: cancellation goes through the scheduler API.
func (sb *SchedulerBackend) CancelTask(taskID string) error {
	return fmt.Errorf("CancelTask not supported - use scheduler API")
}

// UpdateTaskWithMetrics forwards to UpdateTask; the action detail is not
// reported separately in distributed mode.
func (sb *SchedulerBackend) UpdateTaskWithMetrics(task *Task, action string) error {
	return sb.UpdateTask(task)
}
|
||||
|
||||
// RecordMetric is a no-op on the worker side.
func (sb *SchedulerBackend) RecordMetric(jobName, metric string, value float64) error {
	return nil // Metrics handled by scheduler
}

// Heartbeat is a no-op: liveness is signalled over the SchedulerConn itself.
func (sb *SchedulerBackend) Heartbeat(workerID string) error {
	return nil // Heartbeat handled by SchedulerConn
}

// QueueDepth always reports zero: the real depth is only known to the scheduler.
func (sb *SchedulerBackend) QueueDepth() (int64, error) {
	return 0, nil // Queue depth managed by scheduler
}
|
||||
|
||||
// The prewarm-state API below is stubbed out: in distributed mode the
// scheduler tracks prewarm state and garbage collection centrally, so the
// worker-side backend has nothing to persist.

func (sb *SchedulerBackend) SetWorkerPrewarmState(state PrewarmState) error {
	return nil
}

func (sb *SchedulerBackend) ClearWorkerPrewarmState(workerID string) error {
	return nil
}

func (sb *SchedulerBackend) GetWorkerPrewarmState(workerID string) (*PrewarmState, error) {
	return nil, nil
}

func (sb *SchedulerBackend) GetAllWorkerPrewarmStates() ([]PrewarmState, error) {
	return nil, nil
}

func (sb *SchedulerBackend) SignalPrewarmGC() error {
	return nil
}

func (sb *SchedulerBackend) PrewarmGCRequestValue() (string, error) {
	return "", nil
}

// Close is a no-op. NOTE(review): the underlying SchedulerConn is not
// closed here — confirm its owner is responsible for closing it.
func (sb *SchedulerBackend) Close() error {
	return nil
}
|
||||
|
||||
// Conn returns the underlying scheduler connection for heartbeat
|
||||
func (sb *SchedulerBackend) Conn() *scheduler.SchedulerConn {
|
||||
if conn, ok := sb.conn.(*scheduler.SchedulerConn); ok {
|
||||
return conn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mustMarshal serializes v to JSON, discarding any marshal error. The
// payload types passed here are plain structs of primitive fields, so
// marshaling cannot fail in practice; on the error path json.Marshal
// returns nil, which is what the caller receives.
func mustMarshal(v any) []byte {
	data, _ := json.Marshal(v)
	return data
}
|
||||
|
||||
// Compile-time check - ensure SchedulerBackend implements the interface
|
||||
var _ Backend = (*SchedulerBackend)(nil)
|
||||
|
|
@ -28,18 +28,16 @@ type RunLoopConfig struct {
|
|||
|
||||
// RunLoop manages the main worker processing loop
|
||||
type RunLoop struct {
|
||||
config RunLoopConfig
|
||||
queue queue.Backend
|
||||
executor TaskExecutor
|
||||
metrics MetricsRecorder
|
||||
logger Logger
|
||||
stateMgr *StateManager
|
||||
|
||||
// State management
|
||||
running map[string]context.CancelFunc
|
||||
runningMu sync.RWMutex
|
||||
queue queue.Backend
|
||||
executor TaskExecutor
|
||||
metrics MetricsRecorder
|
||||
logger Logger
|
||||
ctx context.Context
|
||||
stateMgr *StateManager
|
||||
running map[string]context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
config RunLoopConfig
|
||||
runningMu sync.RWMutex
|
||||
}
|
||||
|
||||
// MetricsRecorder defines the contract for recording metrics
|
||||
|
|
|
|||
248
internal/worker/lifecycle/service_manager.go
Normal file
248
internal/worker/lifecycle/service_manager.go
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
// Package lifecycle provides service job lifecycle management for long-running services.
|
||||
package lifecycle
|
||||
|
||||
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"os/exec"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/domain"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
)
|
||||
|
||||
// ServiceManager handles the lifecycle of long-running service jobs (Jupyter, vLLM)
type ServiceManager struct {
	task        *domain.Task           // the service task being managed
	spec        ServiceSpec            // command, env and health-check configuration
	cmd         *exec.Cmd              // the running service process; nil until start() succeeds
	port        int                    // port assigned to the service at construction
	stateMgr    *StateManager          // optional; state transitions are skipped when nil
	logger      Logger                 // structured logger for lifecycle events
	healthCheck *scheduler.HealthCheck // copied from spec.HealthCheck; nil disables probing
}
|
||||
|
||||
// ServiceSpec defines the specification for a service job
type ServiceSpec struct {
	Command []string          // argv; Command[0] is the executable
	Env     map[string]string // environment overrides applied at process start
	// NOTE(review): Port is not read by ServiceManager in this file — the
	// port passed to NewServiceManager wins; confirm which one is intended.
	Port        int
	HealthCheck *scheduler.HealthCheck // nil disables readiness/liveness checks
}
|
||||
|
||||
// NewServiceManager creates a new service manager
|
||||
func NewServiceManager(task *domain.Task, spec ServiceSpec, port int, stateMgr *StateManager, logger Logger) *ServiceManager {
|
||||
return &ServiceManager{
|
||||
task: task,
|
||||
spec: spec,
|
||||
port: port,
|
||||
stateMgr: stateMgr,
|
||||
logger: logger,
|
||||
healthCheck: spec.HealthCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes the service lifecycle: start the process, wait for it to
// become ready, then monitor health until ctx is cancelled or a check
// fails. On start or readiness failure the task is transitioned to
// StateFailed (when a state manager is configured) and a wrapped error is
// returned; on readiness failure the already-started process is killed.
func (sm *ServiceManager) Run(ctx context.Context) error {
	if err := sm.start(); err != nil {
		sm.logger.Error("service start failed", "task_id", sm.task.ID, "error", err)
		if sm.stateMgr != nil {
			sm.stateMgr.Transition(sm.task, StateFailed)
		}
		return fmt.Errorf("start service: %w", err)
	}

	// Wait for readiness, bounded to 120 seconds on top of ctx.
	readyCtx, cancel := context.WithTimeout(ctx, 120*time.Second)
	defer cancel()
	if err := sm.waitReady(readyCtx); err != nil {
		sm.logger.Error("service readiness check failed", "task_id", sm.task.ID, "error", err)
		if sm.stateMgr != nil {
			sm.stateMgr.Transition(sm.task, StateFailed)
		}
		// Kill the process we started; it never became ready.
		sm.stop()
		return fmt.Errorf("wait ready: %w", err)
	}

	// Transition to serving state; a transition failure is logged but does
	// not bring the service down.
	if sm.stateMgr != nil {
		if err := sm.stateMgr.Transition(sm.task, StateServing); err != nil {
			sm.logger.Error("failed to transition to serving", "task_id", sm.task.ID, "error", err)
		}
	}

	sm.logger.Info("service is serving", "task_id", sm.task.ID, "port", sm.port)

	// Run health loop until cancellation or liveness failure.
	return sm.healthLoop(ctx)
}
|
||||
|
||||
// start launches the service process
|
||||
func (sm *ServiceManager) start() error {
|
||||
if len(sm.spec.Command) == 0 {
|
||||
return fmt.Errorf("no command specified")
|
||||
}
|
||||
|
||||
sm.cmd = exec.Command(sm.spec.Command[0], sm.spec.Command[1:]...)
|
||||
|
||||
// Set environment
|
||||
for k, v := range sm.spec.Env {
|
||||
sm.cmd.Env = append(sm.cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
|
||||
// Start process
|
||||
if err := sm.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitReady blocks until the service passes readiness check or timeout
|
||||
func (sm *ServiceManager) waitReady(ctx context.Context) error {
|
||||
if sm.healthCheck == nil || sm.healthCheck.ReadinessEndpoint == "" {
|
||||
// No readiness check - assume ready immediately
|
||||
return nil
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
if sm.checkReadiness() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthLoop monitors service health until ctx is cancelled, at which
// point the service is stopped gracefully. With no health check configured
// it simply waits for cancellation. A failed liveness probe transitions
// the task to StateFailed and returns an error WITHOUT attempting a
// graceful stop (the process is presumed dead or unresponsive).
func (sm *ServiceManager) healthLoop(ctx context.Context) error {
	if sm.healthCheck == nil {
		// No health check - just wait for context cancellation
		<-ctx.Done()
		return sm.gracefulStop()
	}

	// Probe interval from config; default to 15s when unset.
	interval := time.Duration(sm.healthCheck.IntervalSecs) * time.Second
	if interval == 0 {
		interval = 15 * time.Second
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return sm.gracefulStop()
		case <-ticker.C:
			if !sm.checkLiveness() {
				sm.logger.Error("service liveness check failed", "task_id", sm.task.ID)
				if sm.stateMgr != nil {
					sm.stateMgr.Transition(sm.task, StateFailed)
				}
				return fmt.Errorf("liveness check failed")
			}
		}
	}
}
|
||||
|
||||
// checkLiveness returns true if the service is alive. When a liveness
// endpoint is configured it performs an HTTP GET and requires a 200;
// otherwise it falls back to a process-existence probe.
func (sm *ServiceManager) checkLiveness() bool {
	if sm.cmd == nil || sm.cmd.Process == nil {
		return false
	}

	if sm.healthCheck != nil && sm.healthCheck.LivenessEndpoint != "" {
		// HTTP liveness check
		client := &http.Client{Timeout: 5 * time.Second}
		resp, err := client.Get(sm.healthCheck.LivenessEndpoint)
		if err != nil {
			return false
		}
		defer resp.Body.Close()
		return resp.StatusCode == 200
	}

	// Process existence check: signal 0 delivers nothing but reports
	// whether the process can still be signalled.
	// NOTE(review): Unix-specific — confirm this file is excluded from
	// Windows builds (it also uses syscall.SIGTERM in gracefulStop).
	return sm.cmd.Process.Signal(syscall.Signal(0)) == nil
}
|
||||
|
||||
// checkReadiness returns true if the service is ready to accept traffic
|
||||
func (sm *ServiceManager) checkReadiness() bool {
|
||||
if sm.healthCheck == nil || sm.healthCheck.ReadinessEndpoint == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Get(sm.healthCheck.ReadinessEndpoint)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
|
||||
// gracefulStop stops the service with SIGTERM, escalating to SIGKILL after
// a 30-second grace period (or immediately when SIGTERM cannot be
// delivered). The task is transitioned StateStopping -> StateCompleted
// around the shutdown; transition errors are deliberately ignored so the
// shutdown proceeds regardless.
func (sm *ServiceManager) gracefulStop() error {
	sm.logger.Info("gracefully stopping service", "task_id", sm.task.ID)

	if sm.stateMgr != nil {
		sm.stateMgr.Transition(sm.task, StateStopping)
	}

	// Nothing to stop if the process never started.
	if sm.cmd == nil || sm.cmd.Process == nil {
		if sm.stateMgr != nil {
			sm.stateMgr.Transition(sm.task, StateCompleted)
		}
		return nil
	}

	// Send SIGTERM for graceful shutdown
	if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil {
		sm.logger.Warn("SIGTERM failed, using SIGKILL", "task_id", sm.task.ID, "error", err)
		sm.cmd.Process.Kill()
	} else {
		// Wait for graceful shutdown in a goroutine so it can be bounded.
		done := make(chan error, 1)
		go func() {
			done <- sm.cmd.Wait()
		}()

		select {
		case <-done:
			// Graceful shutdown completed
		case <-time.After(30 * time.Second):
			// Timeout - force kill. The Wait goroutine unblocks once the
			// process dies, so it does not leak.
			sm.logger.Warn("graceful shutdown timeout, forcing kill", "task_id", sm.task.ID)
			sm.cmd.Process.Kill()
		}
	}

	if sm.stateMgr != nil {
		sm.stateMgr.Transition(sm.task, StateCompleted)
	}

	return nil
}
|
||||
|
||||
// stop forcefully stops the service
|
||||
func (sm *ServiceManager) stop() error {
|
||||
if sm.cmd == nil || sm.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
// GetPort returns the port assigned to the service. The port is fixed at
// construction time (see NewServiceManager).
func (sm *ServiceManager) GetPort() int {
	return sm.port
}
|
||||
|
|
@ -27,18 +27,30 @@ const (
|
|||
StateCompleted TaskState = "completed"
|
||||
// StateFailed indicates the task failed during execution.
|
||||
StateFailed TaskState = "failed"
|
||||
// StateCancelled indicates the task was cancelled by user or system.
|
||||
StateCancelled TaskState = "cancelled"
|
||||
// StateOrphaned indicates the worker was lost - task will be silently re-queued.
|
||||
StateOrphaned TaskState = "orphaned"
|
||||
// StateServing indicates a service job (Jupyter/vLLM) is up and ready.
|
||||
StateServing TaskState = "serving"
|
||||
// StateStopping indicates a service job is in graceful shutdown.
|
||||
StateStopping TaskState = "stopping"
|
||||
)
|
||||
|
||||
// ValidTransitions defines the allowed state transitions.
|
||||
// The key is the "from" state, the value is a list of valid "to" states.
|
||||
// This enforces that state transitions follow the expected lifecycle.
|
||||
var ValidTransitions = map[TaskState][]TaskState{
|
||||
StateQueued: {StatePreparing, StateFailed},
|
||||
StatePreparing: {StateRunning, StateFailed},
|
||||
StateRunning: {StateCollecting, StateFailed},
|
||||
StateCollecting: {StateCompleted, StateFailed},
|
||||
StateQueued: {StatePreparing, StateFailed, StateCancelled},
|
||||
StatePreparing: {StateRunning, StateFailed, StateCancelled},
|
||||
StateRunning: {StateCollecting, StateFailed, StateOrphaned, StateServing, StateCancelled},
|
||||
StateCollecting: {StateCompleted, StateFailed, StateCancelled},
|
||||
StateCompleted: {},
|
||||
StateFailed: {},
|
||||
StateCancelled: {},
|
||||
StateOrphaned: {StateQueued},
|
||||
StateServing: {StateStopping, StateFailed, StateOrphaned, StateCancelled},
|
||||
StateStopping: {StateCompleted, StateFailed},
|
||||
}
|
||||
|
||||
// StateTransitionError is returned when an invalid state transition is attempted.
|
||||
|
|
@ -53,8 +65,8 @@ func (e StateTransitionError) Error() string {
|
|||
|
||||
// StateManager manages task state transitions with audit logging.
|
||||
type StateManager struct {
|
||||
enabled bool
|
||||
auditor *audit.Logger
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewStateManager creates a new state manager with the given audit logger.
|
||||
|
|
@ -116,7 +128,7 @@ func (sm *StateManager) validateTransition(from, to TaskState) error {
|
|||
|
||||
// IsTerminalState returns true if the state is terminal (no further transitions allowed).
|
||||
func IsTerminalState(state TaskState) bool {
|
||||
return state == StateCompleted || state == StateFailed
|
||||
return state == StateCompleted || state == StateFailed || state == StateCancelled
|
||||
}
|
||||
|
||||
// CanTransition returns true if a transition from -> to is valid.
|
||||
|
|
|
|||
133
internal/worker/plugins/jupyter.go
Normal file
133
internal/worker/plugins/jupyter.go
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
// Package plugins provides framework-specific extensions to the worker.
|
||||
// This implements the prolog/epilog model where plugins hook into task lifecycle events.
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// JupyterManager is the surface of the jupyter service manager used by the
// plugin. It is declared here, at the consumer, so the worker can inject a
// fake implementation in tests.
type JupyterManager interface {
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	StopService(ctx context.Context, serviceID string) error
	RemoveService(ctx context.Context, serviceID string, purge bool) error
	RestoreWorkspace(ctx context.Context, name string) (string, error)
	ListServices() []*jupyter.JupyterService
	ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
|
||||
|
||||
// TaskRunner executes framework-specific tasks.
// This is the minimal interface needed by plugins to execute tasks;
// GetJupyterManager may return nil when Jupyter support is not configured.
type TaskRunner interface {
	GetJupyterManager() JupyterManager
}
|
||||
|
||||
// RunJupyterTask runs a Jupyter-related task.
|
||||
// It handles start, stop, remove, restore, and list_packages actions.
|
||||
// This is a plugin function that extends the core worker with Jupyter support.
|
||||
func RunJupyterTask(ctx context.Context, runner TaskRunner, task *queue.Task) ([]byte, error) {
|
||||
jm := runner.GetJupyterManager()
|
||||
if jm == nil {
|
||||
return nil, fmt.Errorf("jupyter manager not configured")
|
||||
}
|
||||
|
||||
action := task.Metadata["jupyter_action"]
|
||||
if action == "" {
|
||||
action = "start" // Default action
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "start":
|
||||
name := task.Metadata["jupyter_name"]
|
||||
if name == "" {
|
||||
name = task.Metadata["jupyter_workspace"]
|
||||
}
|
||||
if name == "" {
|
||||
// Extract from jobName if format is "jupyter-<name>"
|
||||
if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" {
|
||||
name = task.JobName[8:]
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
||||
}
|
||||
|
||||
req := &jupyter.StartRequest{Name: name}
|
||||
service, err := jm.StartService(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
|
||||
}
|
||||
|
||||
output := map[string]any{
|
||||
"type": "start",
|
||||
"service": service,
|
||||
}
|
||||
return json.Marshal(output)
|
||||
|
||||
case "stop":
|
||||
serviceID := task.Metadata["jupyter_service_id"]
|
||||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||
}
|
||||
if err := jm.StopService(ctx, serviceID); err != nil {
|
||||
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
|
||||
|
||||
case "remove":
|
||||
serviceID := task.Metadata["jupyter_service_id"]
|
||||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||
}
|
||||
purge := task.Metadata["jupyter_purge"] == "true"
|
||||
if err := jm.RemoveService(ctx, serviceID, purge); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
|
||||
|
||||
case "restore":
|
||||
name := task.Metadata["jupyter_name"]
|
||||
if name == "" {
|
||||
name = task.Metadata["jupyter_workspace"]
|
||||
}
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
||||
}
|
||||
serviceID, err := jm.RestoreWorkspace(ctx, name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID})
|
||||
|
||||
case "list_packages":
|
||||
serviceName := task.Metadata["jupyter_name"]
|
||||
if serviceName == "" {
|
||||
// Extract from jobName if format is "jupyter-packages-<name>"
|
||||
if len(task.JobName) > 16 && task.JobName[:16] == "jupyter-packages-" {
|
||||
serviceName = task.JobName[16:]
|
||||
}
|
||||
}
|
||||
if serviceName == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name in task metadata")
|
||||
}
|
||||
|
||||
packages, err := jm.ListInstalledPackages(ctx, serviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list installed packages: %w", err)
|
||||
}
|
||||
|
||||
output := map[string]any{
|
||||
"type": "list_packages",
|
||||
"packages": packages,
|
||||
}
|
||||
return json.Marshal(output)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown jupyter action: %s", action)
|
||||
}
|
||||
}
|
||||
279
internal/worker/plugins/vllm.go
Normal file
279
internal/worker/plugins/vllm.go
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
// Package plugins provides service plugin implementations for fetch_ml.
|
||||
// This file contains the vLLM inference server plugin.
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VLLMPlugin implements the vLLM OpenAI-compatible inference server as a fetch_ml service.
// This plugin manages the lifecycle of vLLM, including startup, health monitoring,
// and graceful shutdown.
type VLLMPlugin struct {
	logger       *slog.Logger      // structured logger for plugin events
	baseImage    string            // container image for the server; defaulted in NewVLLMPlugin
	envOverrides map[string]string // extra env vars for the service (initialized empty; no setter in this file)
}
|
||||
|
||||
// VLLMConfig contains the configuration for a vLLM service instance.
// It is decoded from the JSON carried in task metadata under "vllm_config".
type VLLMConfig struct {
	// Model name or path (required)
	Model string `json:"model"`

	// Tensor parallel size (number of GPUs); only emitted as a flag when > 1
	TensorParallelSize int `json:"tensor_parallel_size,omitempty"`

	// Port to run on (0 = auto-allocate)
	Port int `json:"port,omitempty"`

	// Maximum model length in tokens (mapped to --max-model-len);
	// 0 leaves the server default
	MaxModelLen int `json:"max_model_len,omitempty"`

	// GPU memory utilization (0.0 - 1.0); 0 leaves the server default
	GpuMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"`

	// Quantization method (awq, gptq, squeezellm, fp8, or empty for none);
	// must match the set accepted by ValidateConfig
	Quantization string `json:"quantization,omitempty"`

	// Trust remote code for custom models
	TrustRemoteCode bool `json:"trust_remote_code,omitempty"`

	// Additional CLI arguments appended verbatim to the vLLM command line
	ExtraArgs []string `json:"extra_args,omitempty"`
}
|
||||
|
||||
// VLLMStatus represents the current status of a vLLM service as observed
// by CheckHealth.
type VLLMStatus struct {
	Healthy     bool          `json:"healthy"`         // /health returned HTTP 200
	Ready       bool          `json:"ready"`           // mirrors Healthy in the current implementation
	ModelLoaded bool          `json:"model_loaded"`    // not yet populated by CheckHealth
	Uptime      time.Duration `json:"uptime"`          // not yet populated by CheckHealth
	Requests    int64         `json:"requests_served"` // not yet populated by CheckHealth
}
|
||||
|
||||
// NewVLLMPlugin creates a new vLLM plugin instance
|
||||
func NewVLLMPlugin(logger *slog.Logger, baseImage string) *VLLMPlugin {
|
||||
if baseImage == "" {
|
||||
baseImage = "vllm/vllm-openai:latest"
|
||||
}
|
||||
return &VLLMPlugin{
|
||||
logger: logger,
|
||||
baseImage: baseImage,
|
||||
envOverrides: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// GetServiceTemplate returns the service template for vLLM. The {{...}}
// placeholders are substituted by the service runner at launch time.
// NOTE(review): the {{#GPU_COUNT}}...{{/GPU_COUNT}}-style conditional
// markers appear as individual argv entries — confirm the template engine
// strips the marker tokens (and any resulting empty strings) before exec.
func (p *VLLMPlugin) GetServiceTemplate() ServiceTemplate {
	return ServiceTemplate{
		JobType:  "service",
		SlotPool: "service",
		GPUCount: 1, // Minimum 1 GPU, can be more with tensor parallelism

		Command: []string{
			"python", "-m", "vllm.entrypoints.openai.api_server",
			"--model", "{{MODEL_NAME}}",
			"--port", "{{SERVICE_PORT}}",
			"--host", "0.0.0.0",
			"{{#GPU_COUNT}}--tensor-parallel-size", "{{GPU_COUNT}}", "{{/GPU_COUNT}}",
			"{{#QUANTIZATION}}--quantization", "{{QUANTIZATION}}", "{{/QUANTIZATION}}",
			"{{#TRUST_REMOTE_CODE}}--trust-remote-code", "{{/TRUST_REMOTE_CODE}}",
		},

		Env: map[string]string{
			"CUDA_VISIBLE_DEVICES": "{{GPU_DEVICES}}",
			"VLLM_LOGGING_LEVEL":   "INFO",
		},

		HealthCheck: ServiceHealthCheck{
			Liveness:  "http://localhost:{{SERVICE_PORT}}/health",
			Readiness: "http://localhost:{{SERVICE_PORT}}/health",
			Interval:  30,
			Timeout:   10,
		},

		Mounts: []ServiceMount{
			{Source: "{{MODEL_CACHE}}", Destination: "/root/.cache/huggingface"},
			{Source: "{{WORKSPACE}}", Destination: "/workspace"},
		},

		Resources: ResourceRequirements{
			MinGPU:    1,
			MaxGPU:    8,  // Support up to 8-GPU tensor parallelism
			MinMemory: 16, // 16GB minimum
		},
	}
}
|
||||
|
||||
// ValidateConfig validates the vLLM configuration
|
||||
func (p *VLLMPlugin) ValidateConfig(config *VLLMConfig) error {
|
||||
if config.Model == "" {
|
||||
return fmt.Errorf("vllm: model name is required")
|
||||
}
|
||||
|
||||
if config.GpuMemoryUtilization < 0 || config.GpuMemoryUtilization > 1 {
|
||||
return fmt.Errorf("vllm: gpu_memory_utilization must be between 0.0 and 1.0")
|
||||
}
|
||||
|
||||
validQuantizations := []string{"", "awq", "gptq", "squeezellm", "fp8"}
|
||||
found := false
|
||||
for _, q := range validQuantizations {
|
||||
if config.Quantization == q {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("vllm: unsupported quantization: %s", config.Quantization)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildCommand builds the vLLM command line from configuration
|
||||
func (p *VLLMPlugin) BuildCommand(config *VLLMConfig, port int) []string {
|
||||
args := []string{
|
||||
"python", "-m", "vllm.entrypoints.openai.api_server",
|
||||
"--model", config.Model,
|
||||
"--port", fmt.Sprintf("%d", port),
|
||||
"--host", "0.0.0.0",
|
||||
}
|
||||
|
||||
if config.TensorParallelSize > 1 {
|
||||
args = append(args, "--tensor-parallel-size", fmt.Sprintf("%d", config.TensorParallelSize))
|
||||
}
|
||||
|
||||
if config.MaxModelLen > 0 {
|
||||
args = append(args, "--max-model-len", fmt.Sprintf("%d", config.MaxModelLen))
|
||||
}
|
||||
|
||||
if config.GpuMemoryUtilization > 0 {
|
||||
args = append(args, "--gpu-memory-utilization", fmt.Sprintf("%.2f", config.GpuMemoryUtilization))
|
||||
}
|
||||
|
||||
if config.Quantization != "" {
|
||||
args = append(args, "--quantization", config.Quantization)
|
||||
}
|
||||
|
||||
if config.TrustRemoteCode {
|
||||
args = append(args, "--trust-remote-code")
|
||||
}
|
||||
|
||||
// Add any extra arguments
|
||||
args = append(args, config.ExtraArgs...)
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// CheckHealth performs a health check against a running vLLM instance
|
||||
func (p *VLLMPlugin) CheckHealth(ctx context.Context, endpoint string) (*VLLMStatus, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint+"/health", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vllm health check: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &VLLMStatus{Healthy: false, Ready: false}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Parse metrics endpoint for detailed status
|
||||
metricsReq, _ := http.NewRequestWithContext(ctx, "GET", endpoint+"/metrics", nil)
|
||||
metricsResp, err := client.Do(metricsReq)
|
||||
|
||||
status := &VLLMStatus{
|
||||
Healthy: resp.StatusCode == 200,
|
||||
Ready: resp.StatusCode == 200,
|
||||
}
|
||||
|
||||
if err == nil && metricsResp.StatusCode == 200 {
|
||||
defer metricsResp.Body.Close()
|
||||
// Could parse Prometheus metrics here for detailed stats
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// RunVLLMTask executes a vLLM service task
|
||||
// This is called by the worker when it receives a vLLM job assignment
|
||||
func RunVLLMTask(ctx context.Context, logger *slog.Logger, taskID string, metadata map[string]string) ([]byte, error) {
|
||||
logger = logger.With("plugin", "vllm", "task_id", taskID)
|
||||
|
||||
// Parse vLLM configuration from task metadata
|
||||
configJSON, ok := metadata["vllm_config"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("vllm: missing 'vllm_config' in task metadata")
|
||||
}
|
||||
|
||||
var config VLLMConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
||||
return nil, fmt.Errorf("vllm: invalid config: %w", err)
|
||||
}
|
||||
|
||||
plugin := NewVLLMPlugin(logger, "")
|
||||
if err := plugin.ValidateConfig(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("starting vllm service",
|
||||
"model", config.Model,
|
||||
"tensor_parallel", config.TensorParallelSize,
|
||||
)
|
||||
|
||||
// Port allocation and command building would happen here
|
||||
// The actual execution goes through the standard service runner
|
||||
|
||||
// Return success output for the task
|
||||
output := map[string]interface{}{
|
||||
"status": "starting",
|
||||
"model": config.Model,
|
||||
"plugin": "vllm",
|
||||
}
|
||||
|
||||
return json.Marshal(output)
|
||||
}
|
||||
|
||||
// ServiceTemplate is re-exported here for plugin use.
// It describes everything needed to launch a long-running service job:
// scheduling hints, the container command and environment, health probes,
// optional mounts, and resource requirements.
// NOTE: declared as a type ALIAS to an anonymous struct, so any structurally
// identical struct converts implicitly.
type ServiceTemplate = struct {
	JobType     string               `json:"job_type"`  // scheduler job type identifier
	SlotPool    string               `json:"slot_pool"` // slot pool the service is scheduled into
	GPUCount    int                  `json:"gpu_count"` // number of GPUs requested
	Command     []string             `json:"command"`   // container entrypoint plus arguments
	Env         map[string]string    `json:"env"`       // environment variables for the service
	HealthCheck ServiceHealthCheck   `json:"health_check"`
	Mounts      []ServiceMount       `json:"mounts,omitempty"`
	Resources   ResourceRequirements `json:"resources,omitempty"`
}

// ServiceHealthCheck is re-exported for plugin use.
// Liveness/readiness probe targets plus timing.
// NOTE(review): Interval/Timeout units are not visible here — presumably
// seconds; confirm against the canonical definition.
type ServiceHealthCheck = struct {
	Liveness  string `json:"liveness"`  // liveness probe target
	Readiness string `json:"readiness"` // readiness probe target
	Interval  int    `json:"interval"`
	Timeout   int    `json:"timeout"`
}

// ServiceMount is re-exported for plugin use.
// A single bind mount from host Source into the container at Destination.
type ServiceMount = struct {
	Source      string `json:"source"`
	Destination string `json:"destination"`
	ReadOnly    bool   `json:"readonly,omitempty"`
}

// ResourceRequirements defines resource needs for a service.
// Zero values mean "no constraint" (omitted from JSON via omitempty).
type ResourceRequirements struct {
	MinGPU    int `json:"min_gpu,omitempty"`
	MaxGPU    int `json:"max_gpu,omitempty"`
	MinMemory int `json:"min_memory_gb,omitempty"` // GB
	MaxMemory int `json:"max_memory_gb,omitempty"` // GB
}
|
||||
63
internal/worker/process/isolation.go
Normal file
63
internal/worker/process/isolation.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// Package process provides process isolation and resource limiting for HIPAA compliance.
|
||||
// Implements Worker Process Isolation controls.
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// IsolationConfig holds process isolation parameters applied by ApplyIsolation.
// Zero values disable the corresponding control.
type IsolationConfig struct {
	MaxProcesses int  // Fork bomb protection (RLIMIT_NPROC on Linux); 0 = leave unchanged
	MaxOpenFiles int  // FD exhaustion protection (RLIMIT_NOFILE); 0 = leave unchanged
	DisableSwap  bool // Prevent swap exfiltration (Linux mlockall, best-effort)
	OOMScoreAdj  int  // OOM killer priority adjustment (Linux only); 0 = no change
}
|
||||
|
||||
// ApplyIsolation applies process isolation controls to the current process.
|
||||
// This should be called after forking but before execing the target command.
|
||||
func ApplyIsolation(cfg IsolationConfig) error {
|
||||
// Apply resource limits (platform-specific)
|
||||
if err := applyResourceLimits(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// OOM score adjustment (only on Linux)
|
||||
if cfg.OOMScoreAdj != 0 && runtime.GOOS == "linux" {
|
||||
if err := setOOMScoreAdj(cfg.OOMScoreAdj); err != nil {
|
||||
return fmt.Errorf("failed to set OOM score adjustment: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Disable swap (Linux only) - requires CAP_SYS_RESOURCE or root
|
||||
if cfg.DisableSwap && runtime.GOOS == "linux" {
|
||||
if err := disableSwap(); err != nil {
|
||||
// Log but don't fail - swap disabling requires privileges
|
||||
// This is best-effort security hardening
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setOOMScoreAdj adjusts the OOM killer score for this process (Linux only)
// by writing to /proc/self/oom_score_adj.
// Lower values = less likely to be killed (negative is "never kill");
// higher values = more likely to be killed.
func setOOMScoreAdj(score int) error {
	payload := fmt.Sprintf("%d\n", score)
	return os.WriteFile("/proc/self/oom_score_adj", []byte(payload), 0644)
}
|
||||
|
||||
// IsolatedExec runs a command with process isolation applied.
// This is a helper for container execution that applies limits before exec.
//
// NOTE: syscall.Exec REPLACES the current process image and never returns on
// success; a returned error means either isolation setup or the exec failed.
// NOTE(review): argv[0] is used verbatim as the executable path — no PATH
// lookup is performed here; confirm callers resolve the path first.
func IsolatedExec(argv []string, cfg IsolationConfig) error {
	if err := ApplyIsolation(cfg); err != nil {
		return err
	}
	return syscall.Exec(argv[0], argv, os.Environ())
}
|
||||
122
internal/worker/process/isolation_unix.go
Normal file
122
internal/worker/process/isolation_unix.go
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// applyResourceLimits sets Unix/Linux resource limits on the current process.
// Returns the first error encountered; limits with a zero config value are
// left untouched.
func applyResourceLimits(cfg IsolationConfig) error {
	// Apply file descriptor limits (RLIMIT_NOFILE for FD exhaustion protection)
	if cfg.MaxOpenFiles > 0 {
		if err := setResourceLimit(syscall.RLIMIT_NOFILE, uint64(cfg.MaxOpenFiles)); err != nil {
			return fmt.Errorf("failed to set max open files limit: %w", err)
		}
	}

	// Apply process limits if available (Linux only).
	// NOTE: a failure here IS fatal — setProcessLimit already swallows
	// not-supported errors (ENOTSUP/EINVAL), so anything that surfaces from
	// it is a genuine failure. (A previous comment claimed "log but don't
	// fail", which did not match the code.)
	if cfg.MaxProcesses > 0 {
		if err := setProcessLimit(cfg.MaxProcesses); err != nil {
			return fmt.Errorf("failed to set max processes limit: %w", err)
		}
	}

	return nil
}
|
||||
|
||||
// setResourceLimit sets a soft and hard rlimit for the current process
|
||||
func setResourceLimit(resource int, limit uint64) error {
|
||||
rl := &syscall.Rlimit{
|
||||
Cur: limit,
|
||||
Max: limit,
|
||||
}
|
||||
return syscall.Setrlimit(resource, rl)
|
||||
}
|
||||
|
||||
// setProcessLimit sets RLIMIT_NPROC on Linux; silently a no-op on Unix
// variants where the resource is unsupported (macOS returns ENOTSUP/EINVAL).
func setProcessLimit(maxProcs int) error {
	const rlimitNPROC = 7 // Linux resource number; not exported by syscall
	limit := uint64(maxProcs)
	err := syscall.Setrlimit(rlimitNPROC, &syscall.Rlimit{Cur: limit, Max: limit})
	switch err {
	case nil, syscall.ENOTSUP, syscall.EINVAL:
		// Success, or the platform simply has no NPROC resource — ignore.
		return nil
	default:
		return err
	}
}
|
||||
|
||||
// disableSwap locks all current and future pages into RAM via mlockall so
// task memory can never be written to swap.
// Best-effort: requires the CAP_IPC_LOCK capability; callers treat failure
// as non-fatal defense in depth.
func disableSwap() error {
	const (
		mclCurrent = 0x1 // MCL_CURRENT: lock all pages mapped now
		mclFuture  = 0x2 // MCL_FUTURE: lock all pages mapped later
	)
	return syscall.Mlockall(mclCurrent | mclFuture)
}
|
||||
|
||||
// GetCurrentLimits returns the current rlimit values for diagnostics
|
||||
func GetCurrentLimits() (map[string]uint64, error) {
|
||||
limits := make(map[string]uint64)
|
||||
|
||||
// Get NOFILE limit (available on all platforms)
|
||||
var nofile syscall.Rlimit
|
||||
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &nofile); err != nil {
|
||||
return nil, fmt.Errorf("failed to get NOFILE limit: %w", err)
|
||||
}
|
||||
limits["NOFILE_soft"] = nofile.Cur
|
||||
limits["NOFILE_hard"] = nofile.Max
|
||||
|
||||
// Get platform-specific limits
|
||||
getPlatformLimits(limits)
|
||||
|
||||
return limits, nil
|
||||
}
|
||||
|
||||
// getPlatformLimits adds platform-specific limits to the map.
// Each resource is read best-effort: Getrlimit simply errors on platforms
// that lack the resource (NPROC/RSS are Linux-specific numbers) and the
// entry is skipped.
func getPlatformLimits(limits map[string]uint64) {
	// Linux-only resource numbers, not exported by package syscall.
	const (
		rlimitNPROC = 7
		rlimitRSS   = 5
	)

	// IMPROVEMENT: the four copy-pasted Getrlimit stanzas are collapsed into
	// a single table-driven loop; keys and behavior are unchanged.
	resources := []struct {
		name     string
		resource int
	}{
		{"AS", syscall.RLIMIT_AS},     // virtual memory
		{"DATA", syscall.RLIMIT_DATA}, // data segment
		{"NPROC", rlimitNPROC},        // process count (Linux only)
		{"RSS", rlimitRSS},            // resident set size (Linux only)
	}

	for _, r := range resources {
		var rl syscall.Rlimit
		if err := syscall.Getrlimit(r.resource, &rl); err == nil {
			limits[r.name+"_soft"] = rl.Cur
			limits[r.name+"_hard"] = rl.Max
		}
	}
}
|
||||
41
internal/worker/process/isolation_windows.go
Normal file
41
internal/worker/process/isolation_windows.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package process
|
||||
|
||||
import "fmt"
|
||||
|
||||
// applyResourceLimits is a no-op on Windows as rlimits are Unix-specific.
|
||||
// Windows uses Job Objects for process limits, which is more complex.
|
||||
func applyResourceLimits(cfg IsolationConfig) error {
|
||||
// Windows doesn't support setrlimit - would need Job Objects
|
||||
// For now, log that limits are not enforced on Windows
|
||||
if cfg.MaxOpenFiles > 0 || cfg.MaxProcesses > 0 {
|
||||
// Process limits not implemented on Windows yet
|
||||
// TODO: Use Windows Job Objects for process limits
|
||||
return fmt.Errorf("process isolation limits not implemented on Windows (max_open_files=%d, max_processes=%d)",
|
||||
cfg.MaxOpenFiles, cfg.MaxProcesses)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableSwap is a no-op on Windows: there is no mlockall equivalent.
// Pinning pages would require the VirtualLock API, which is not wired up.
func disableSwap() error {
	// Windows doesn't support mlockall
	// Would need VirtualLock API
	return nil
}
}
|
||||
|
||||
// GetCurrentLimits returns empty limits on Windows
|
||||
func GetCurrentLimits() (map[string]uint64, error) {
|
||||
limits := make(map[string]uint64)
|
||||
getPlatformLimits(limits)
|
||||
return limits, nil
|
||||
}
|
||||
|
||||
// getPlatformLimits adds Windows-specific limits (none currently).
// There is no rlimit analogue on Windows, so it only stamps a marker entry
// identifying the platform for diagnostics output.
func getPlatformLimits(limits map[string]uint64) {
	// GetProcessWorkingSetSize could supply memory limits in the future.
	limits["WINDOWS"] = 1 // Marker to indicate Windows platform
}
|
||||
278
internal/worker/process/network_policy.go
Normal file
278
internal/worker/process/network_policy.go
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
// Package process provides process isolation and security enforcement for worker tasks.
|
||||
// This file implements Network Micro-Segmentation enforcement hooks.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NetworkPolicy defines network segmentation rules for a task.
// It is validated by Validate, translated to container arguments by
// ApplyNetworkPolicy, and enforced externally by SetupExternalFirewall.
type NetworkPolicy struct {
	// Mode is the network isolation mode: "none", "bridge", "container", "host".
	// "host" is always rejected by Validate for security reasons.
	Mode string

	// AllowedEndpoints is a list of allowed network endpoints (host:port format)
	// Only used when Mode is "bridge" or "container"
	AllowedEndpoints []string

	// BlockedSubnets is a list of CIDR ranges to block
	BlockedSubnets []string

	// DNSResolution controls DNS resolution (true = allow, false = block)
	DNSResolution bool

	// OutboundTraffic controls outbound connections (true = allow, false = block)
	OutboundTraffic bool

	// InboundTraffic controls inbound connections (true = allow, false = block)
	InboundTraffic bool
}
|
||||
|
||||
// DefaultNetworkPolicy returns a hardened default network policy
|
||||
// This implements Network Micro-Segmentation
|
||||
func DefaultNetworkPolicy() NetworkPolicy {
|
||||
return NetworkPolicy{
|
||||
Mode: "none",
|
||||
AllowedEndpoints: []string{},
|
||||
BlockedSubnets: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"},
|
||||
DNSResolution: false,
|
||||
OutboundTraffic: false,
|
||||
InboundTraffic: false,
|
||||
}
|
||||
}
|
||||
|
||||
// HIPAACompliantPolicy returns a network policy suitable for HIPAA compliance
|
||||
// This blocks all external network access except specific allowlisted endpoints
|
||||
func HIPAACompliantPolicy(allowlist []string) NetworkPolicy {
|
||||
return NetworkPolicy{
|
||||
Mode: "bridge",
|
||||
AllowedEndpoints: allowlist,
|
||||
BlockedSubnets: []string{"0.0.0.0/0"}, // Block everything by default
|
||||
DNSResolution: len(allowlist) > 0, // Only allow DNS if endpoints specified
|
||||
OutboundTraffic: len(allowlist) > 0, // Only allow outbound if endpoints specified
|
||||
InboundTraffic: false, // Never allow inbound
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks the network policy for security violations
|
||||
func (np *NetworkPolicy) Validate() error {
|
||||
// Validate mode
|
||||
validModes := map[string]bool{
|
||||
"none": true,
|
||||
"bridge": true,
|
||||
"container": true,
|
||||
"host": true,
|
||||
}
|
||||
if !validModes[np.Mode] {
|
||||
return fmt.Errorf("invalid network mode: %q", np.Mode)
|
||||
}
|
||||
|
||||
// Block host network mode in production
|
||||
if np.Mode == "host" {
|
||||
return fmt.Errorf("host network mode is not allowed for security reasons")
|
||||
}
|
||||
|
||||
// Validate allowed endpoints format
|
||||
for _, endpoint := range np.AllowedEndpoints {
|
||||
if !isValidEndpoint(endpoint) {
|
||||
return fmt.Errorf("invalid endpoint format: %q (expected host:port)", endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate CIDR blocks
|
||||
for _, cidr := range np.BlockedSubnets {
|
||||
if !isValidCIDR(cidr) {
|
||||
return fmt.Errorf("invalid CIDR format: %q", cidr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidEndpoint checks if an endpoint string is valid (host:port format).
// BUG FIX: the port must now be a non-empty number in 1..65535 — previously
// "host:" validated because parsePort("") returned (0, nil).
func isValidEndpoint(endpoint string) bool {
	if endpoint == "" {
		return false
	}
	parts := strings.Split(endpoint, ":")
	if len(parts) != 2 {
		return false
	}
	port, err := parsePort(parts[1])
	if err != nil {
		return false
	}
	return port > 0 && port <= 65535
}

// isValidCIDR performs basic CIDR validation: "addr/prefix" with a numeric,
// non-empty prefix. The address part is not inspected.
func isValidCIDR(cidr string) bool {
	parts := strings.Split(cidr, "/")
	if len(parts) != 2 {
		return false
	}
	// parsePort doubles as a plain digit parser for the prefix; an empty
	// prefix ("10.0.0.0/") is now rejected.
	if _, err := parsePort(parts[1]); err != nil {
		return false
	}
	return true
}

// parsePort parses a non-negative decimal string. It is shared between
// endpoint-port and CIDR-prefix validation, so it does not range-check;
// callers apply their own bounds.
// BUG FIX: rejects the empty string (previously parsed as 0) and guards
// against integer overflow on absurdly long digit strings.
func parsePort(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("invalid port")
	}
	port := 0
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid port")
		}
		port = port*10 + int(c-'0')
		if port > 1<<30 {
			return 0, fmt.Errorf("invalid port")
		}
	}
	return port, nil
}
|
||||
|
||||
// ApplyNetworkPolicy applies network policy enforcement to a podman command
|
||||
// This creates iptables rules and returns the modified command with network options
|
||||
func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) {
|
||||
if err := policy.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid network policy: %w", err)
|
||||
}
|
||||
|
||||
// Apply network mode
|
||||
args := append(baseArgs, "--network", policy.Mode)
|
||||
|
||||
// For bridge mode with specific restrictions, we need to create a custom network
|
||||
if policy.Mode == "bridge" && len(policy.AllowedEndpoints) > 0 {
|
||||
// Add additional network restrictions via iptables (applied externally)
|
||||
// The container will be started with the bridge network, but external
|
||||
// firewall rules will restrict its connectivity
|
||||
|
||||
// Set environment variables to inform the container about network restrictions
|
||||
args = append(args, "-e", "FETCHML_NETWORK_RESTRICTED=1")
|
||||
if !policy.DNSResolution {
|
||||
args = append(args, "-e", "FETCHML_DNS_DISABLED=1")
|
||||
}
|
||||
}
|
||||
|
||||
// Disable DNS if required (via /etc/resolv.conf bind mount)
|
||||
if !policy.DNSResolution {
|
||||
// Mount empty resolv.conf to disable DNS
|
||||
emptyResolv, err := createEmptyResolvConf()
|
||||
if err == nil {
|
||||
args = append(args, "-v", fmt.Sprintf("%s:/etc/resolv.conf:ro", emptyResolv))
|
||||
}
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// createEmptyResolvConf creates a temporary empty resolv.conf file
|
||||
func createEmptyResolvConf() (string, error) {
|
||||
tmpDir := os.TempDir()
|
||||
path := filepath.Join(tmpDir, "empty-resolv.conf")
|
||||
|
||||
// Create empty file if it doesn't exist
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(path, []byte{}, 0644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// SetupExternalFirewall installs iptables rules inside a container's network
// namespace to enforce egress filtering. It is called after the container
// starts.
//
// NOTE: this requires root or CAP_NET_ADMIN (it shells out to nsenter +
// iptables) and is meant to be called from a privileged helper or init
// container.
// NOTE(review): rule ordering matters — the blanket OUTPUT DROP is appended
// first, then per-endpoint ACCEPT rules are INSERTED at position 1 so they
// match before the drop.
func SetupExternalFirewall(containerID string, policy NetworkPolicy) error {
	// Nothing to enforce if there is neither a denylist nor an allowlist.
	if len(policy.BlockedSubnets) == 0 && len(policy.AllowedEndpoints) == 0 {
		return nil // No rules to apply
	}

	// Resolve the container's init PID so nsenter can target its netns.
	pid, err := getContainerPID(containerID)
	if err != nil {
		return fmt.Errorf("failed to get container PID: %w", err)
	}

	// Block all outbound traffic by default.
	if !policy.OutboundTraffic {
		cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-A", "OUTPUT", "-j", "DROP")
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to block outbound traffic: %w", err)
		}
	}

	// Punch holes for each allowlisted endpoint (TCP only).
	// NOTE(review): endpoints whose host part parses as empty are silently
	// skipped — confirm that is intended rather than an error.
	for _, endpoint := range policy.AllowedEndpoints {
		host, port := parseEndpoint(endpoint)
		if host != "" {
			cmd := exec.Command("nsenter", "-t", pid, "-n", "iptables", "-I", "OUTPUT", "1",
				"-p", "tcp", "-d", host, "--dport", port, "-j", "ACCEPT")
			if err := cmd.Run(); err != nil {
				return fmt.Errorf("failed to allow endpoint %s: %w", endpoint, err)
			}
		}
	}

	return nil
}
|
||||
|
||||
// getContainerPID asks podman for the host PID of a running container,
// returned as a decimal string with surrounding whitespace stripped.
func getContainerPID(containerID string) (string, error) {
	out, err := exec.Command("podman", "inspect", "-f", "{{.State.Pid}}", containerID).Output()
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(string(out)), nil
}
|
||||
|
||||
// parseEndpoint splits a "host:port" string into its two components.
// Anything that is not exactly one host and one port (no colon, or more
// than one colon) yields two empty strings.
func parseEndpoint(endpoint string) (host, port string) {
	i := strings.Index(endpoint, ":")
	if i < 0 || strings.Contains(endpoint[i+1:], ":") {
		return "", ""
	}
	return endpoint[:i], endpoint[i+1:]
}
|
||||
|
||||
// NetworkPolicyFromSandbox creates a NetworkPolicy from sandbox configuration
|
||||
func NetworkPolicyFromSandbox(
|
||||
networkMode string,
|
||||
allowedEndpoints []string,
|
||||
blockedSubnets []string,
|
||||
) NetworkPolicy {
|
||||
// Use defaults if not specified
|
||||
if networkMode == "" {
|
||||
networkMode = "none"
|
||||
}
|
||||
if len(blockedSubnets) == 0 {
|
||||
blockedSubnets = DefaultNetworkPolicy().BlockedSubnets
|
||||
}
|
||||
|
||||
return NetworkPolicy{
|
||||
Mode: networkMode,
|
||||
AllowedEndpoints: allowedEndpoints,
|
||||
BlockedSubnets: blockedSubnets,
|
||||
DNSResolution: networkMode != "none" && len(allowedEndpoints) > 0,
|
||||
OutboundTraffic: networkMode != "none" && len(allowedEndpoints) > 0,
|
||||
InboundTraffic: false, // Never allow inbound by default
|
||||
}
|
||||
}
|
||||
91
internal/worker/process/network_policy_windows.go
Normal file
91
internal/worker/process/network_policy_windows.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
// Package process provides process isolation and security enforcement for worker tasks.
|
||||
// This file implements Network Micro-Segmentation enforcement hooks (Windows stub).
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NetworkPolicy defines network segmentation rules for a task
// (Windows stub - policy enforcement handled differently on Windows).
// Fields mirror the Unix version so cross-platform callers compile unchanged;
// only Validate and ApplyNetworkPolicy act on them here.
type NetworkPolicy struct {
	Mode             string   // isolation mode; only "none"/"" pass Validate on Windows
	AllowedEndpoints []string // host:port allowlist (not enforced on Windows)
	BlockedSubnets   []string // CIDR denylist (not enforced on Windows)
	DNSResolution    bool     // allow DNS lookups
	OutboundTraffic  bool     // allow outbound connections
	InboundTraffic   bool     // allow inbound connections
}
|
||||
|
||||
// DefaultNetworkPolicy returns a hardened default network policy (Windows stub)
|
||||
func DefaultNetworkPolicy() NetworkPolicy {
|
||||
return NetworkPolicy{
|
||||
Mode: "none",
|
||||
AllowedEndpoints: []string{},
|
||||
BlockedSubnets: []string{},
|
||||
DNSResolution: false,
|
||||
OutboundTraffic: false,
|
||||
InboundTraffic: false,
|
||||
}
|
||||
}
|
||||
|
||||
// HIPAACompliantPolicy returns a network policy suitable for HIPAA compliance (Windows stub)
|
||||
func HIPAACompliantPolicy(allowlist []string) NetworkPolicy {
|
||||
return NetworkPolicy{
|
||||
Mode: "none",
|
||||
AllowedEndpoints: allowlist,
|
||||
BlockedSubnets: []string{},
|
||||
DNSResolution: false,
|
||||
OutboundTraffic: false,
|
||||
InboundTraffic: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks the network policy for security violations (Windows stub)
|
||||
func (np *NetworkPolicy) Validate() error {
|
||||
// On Windows, only "none" mode is supported without additional tooling
|
||||
if np.Mode != "none" && np.Mode != "" {
|
||||
return fmt.Errorf("network mode %q not supported on Windows (use 'none' or implement via Windows Firewall)", np.Mode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyNetworkPolicy applies network policy enforcement (Windows stub)
|
||||
func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) {
|
||||
if err := policy.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid network policy: %w", err)
|
||||
}
|
||||
|
||||
// On Windows, just set the network mode
|
||||
args := append(baseArgs, "--network", policy.Mode)
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// SetupExternalFirewall sets up external firewall rules (Windows stub - no-op).
// Real enforcement would require PowerShell or netsh integration; until then
// the container runtime's default restrictions are the only protection.
func SetupExternalFirewall(containerID string, policy NetworkPolicy) error {
	// Windows firewall integration would require PowerShell or netsh
	// For now, this is a no-op - rely on container runtime's default restrictions
	return nil
}
|
||||
|
||||
// NetworkPolicyFromSandbox creates a NetworkPolicy from sandbox configuration (Windows stub)
|
||||
func NetworkPolicyFromSandbox(
|
||||
networkMode string,
|
||||
allowedEndpoints []string,
|
||||
blockedSubnets []string,
|
||||
) NetworkPolicy {
|
||||
if networkMode == "" {
|
||||
networkMode = "none"
|
||||
}
|
||||
return NetworkPolicy{
|
||||
Mode: networkMode,
|
||||
AllowedEndpoints: allowedEndpoints,
|
||||
BlockedSubnets: blockedSubnets,
|
||||
DNSResolution: false,
|
||||
OutboundTraffic: false,
|
||||
InboundTraffic: false,
|
||||
}
|
||||
}
|
||||
163
tests/unit/worker/plugins/jupyter_task_test.go
Normal file
163
tests/unit/worker/plugins/jupyter_task_test.go
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
package plugins__test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
|
||||
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
|
||||
)
|
||||
|
||||
// fakeJupyterManager is a configurable test double for the worker's Jupyter
// manager dependency. Each field supplies the behavior of one interface
// method; listPkgsFn is optional (nil-checked by its method), the rest are
// invoked directly by the methods that use them.
type fakeJupyterManager struct {
	startFn    func(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	stopFn     func(ctx context.Context, serviceID string) error
	removeFn   func(ctx context.Context, serviceID string, purge bool) error
	restoreFn  func(ctx context.Context, name string) (string, error)
	listFn     func() []*jupyter.JupyterService
	listPkgsFn func(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
|
||||
|
||||
// StartService delegates to the injected startFn.
// NOTE(review): unlike ListInstalledPackages below, these delegating methods
// do not nil-check their callback fields — tests must set every callback
// they exercise or the call panics.
func (f *fakeJupyterManager) StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
	return f.startFn(ctx, req)
}

// StopService delegates to the injected stopFn.
func (f *fakeJupyterManager) StopService(ctx context.Context, serviceID string) error {
	return f.stopFn(ctx, serviceID)
}

// RemoveService delegates to the injected removeFn.
func (f *fakeJupyterManager) RemoveService(ctx context.Context, serviceID string, purge bool) error {
	return f.removeFn(ctx, serviceID, purge)
}

// RestoreWorkspace delegates to the injected restoreFn.
func (f *fakeJupyterManager) RestoreWorkspace(ctx context.Context, name string) (string, error) {
	return f.restoreFn(ctx, name)
}

// ListServices delegates to the injected listFn.
func (f *fakeJupyterManager) ListServices() []*jupyter.JupyterService {
	return f.listFn()
}

// ListInstalledPackages delegates to listPkgsFn; a nil callback means an
// empty inventory rather than a panic.
func (f *fakeJupyterManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
	if f.listPkgsFn == nil {
		return nil, nil
	}
	return f.listPkgsFn(ctx, serviceName)
}
|
||||
|
||||
// jupyterOutput mirrors the JSON envelope the tests decode from
// RunJupyterTask output for start actions: a type tag plus the started
// service's name and URL.
type jupyterOutput struct {
	Type    string `json:"type"`
	Service *struct {
		Name string `json:"name"`
		URL  string `json:"url"`
	} `json:"service"`
}

// jupyterPackagesOutput mirrors the JSON envelope decoded from
// RunJupyterTask output for the list_packages action: a type tag plus the
// flattened package inventory.
type jupyterPackagesOutput struct {
	Type     string `json:"type"`
	Packages []struct {
		Name    string `json:"name"`
		Version string `json:"version"`
		Source  string `json:"source"`
	} `json:"packages"`
}
|
||||
|
||||
// TestRunJupyterTaskStartSuccess drives the "start" action through
// RunJupyterTask with a fake manager and checks that the JSON output carries
// the started service back to the caller.
func TestRunJupyterTaskStartSuccess(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		// The fake asserts the start request was built from the task metadata below.
		startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
			if req.Name != "my-workspace" {
				return nil, errors.New("bad name")
			}
			return &jupyter.JupyterService{Name: req.Name, URL: "http://127.0.0.1:8888"}, nil
		},
		stopFn:     func(context.Context, string) error { return nil },
		removeFn:   func(context.Context, string, bool) error { return nil },
		restoreFn:  func(context.Context, string) (string, error) { return "", nil },
		listFn:     func() []*jupyter.JupyterService { return nil },
		listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil },
	})

	// These metadata keys are the plugin's wire contract for a start action.
	task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
		"task_type":         "jupyter",
		"jupyter_action":    "start",
		"jupyter_name":      "my-workspace",
		"jupyter_workspace": "my-workspace",
	}}
	out, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	if len(out) == 0 {
		t.Fatalf("expected output")
	}
	// The task output must be JSON describing the started service.
	var decoded jupyterOutput
	if err := json.Unmarshal(out, &decoded); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	if decoded.Service == nil || decoded.Service.Name != "my-workspace" {
		t.Fatalf("expected service name to be my-workspace, got %#v", decoded.Service)
	}
}
|
||||
|
||||
// TestRunJupyterTaskStopFailure checks that a manager error from StopService
// propagates out of RunJupyterTask instead of being swallowed.
func TestRunJupyterTaskStopFailure(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
		// The stop callback fails; the task must surface this error.
		stopFn:     func(context.Context, string) error { return errors.New("stop failed") },
		removeFn:   func(context.Context, string, bool) error { return nil },
		restoreFn:  func(context.Context, string) (string, error) { return "", nil },
		listFn:     func() []*jupyter.JupyterService { return nil },
		listPkgsFn: func(context.Context, string) ([]jupyter.InstalledPackage, error) { return nil, nil },
	})

	// Stop actions address the service by ID rather than by name.
	task := &queue.Task{JobName: "jupyter-my-workspace", Metadata: map[string]string{
		"task_type":          "jupyter",
		"jupyter_action":     "stop",
		"jupyter_service_id": "svc-1",
	}}
	_, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err == nil {
		t.Fatalf("expected error")
	}
}
|
||||
|
||||
// TestRunJupyterTaskListPackagesSuccess checks the list_packages action: the
// fake manager's package inventory must round-trip through the JSON task
// output.
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
	w := tests.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
		startFn:   func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
		stopFn:    func(context.Context, string) error { return nil },
		removeFn:  func(context.Context, string, bool) error { return nil },
		restoreFn: func(context.Context, string) (string, error) { return "", nil },
		listFn:    func() []*jupyter.JupyterService { return nil },
		// The fake asserts the query targets the workspace named in metadata
		// and returns a two-package inventory from mixed sources.
		listPkgsFn: func(_ context.Context, serviceName string) ([]jupyter.InstalledPackage, error) {
			if serviceName != "my-workspace" {
				return nil, errors.New("bad service")
			}
			return []jupyter.InstalledPackage{
				{Name: "numpy", Version: "1.26.0", Source: "pip"},
				{Name: "pandas", Version: "2.1.0", Source: "conda"},
			}, nil
		},
	})

	task := &queue.Task{JobName: "jupyter-packages-my-workspace", Metadata: map[string]string{
		"task_type":      "jupyter",
		"jupyter_action": "list_packages",
		"jupyter_name":   "my-workspace",
	}}

	out, err := plugins.RunJupyterTask(context.Background(), w, task)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}

	// The inventory must arrive intact in the JSON output.
	var decoded jupyterPackagesOutput
	if err := json.Unmarshal(out, &decoded); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	if len(decoded.Packages) != 2 {
		t.Fatalf("expected 2 packages, got %d", len(decoded.Packages))
	}
}
|
||||
257
tests/unit/worker/plugins/vllm_test.go
Normal file
257
tests/unit/worker/plugins/vllm_test.go
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
package plugins__test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewVLLMPlugin(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Test with default image
|
||||
p1 := plugins.NewVLLMPlugin(logger, "")
|
||||
assert.NotNil(t, p1)
|
||||
|
||||
// Test with custom image
|
||||
p2 := plugins.NewVLLMPlugin(logger, "custom/vllm:1.0")
|
||||
assert.NotNil(t, p2)
|
||||
}
|
||||
|
||||
func TestVLLMPluginValidateConfig(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
plugin := plugins.NewVLLMPlugin(logger, "")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config plugins.VLLMConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "meta-llama/Llama-2-7b",
|
||||
TensorParallelSize: 1,
|
||||
GpuMemoryUtilization: 0.9,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid gpu memory high",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "test-model",
|
||||
GpuMemoryUtilization: 1.5,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid gpu memory low",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "test-model",
|
||||
GpuMemoryUtilization: -0.1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid quantization awq",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "test-model",
|
||||
Quantization: "awq",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid quantization gptq",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "test-model",
|
||||
Quantization: "gptq",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid quantization",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "test-model",
|
||||
Quantization: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := plugin.ValidateConfig(&tt.config)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVLLMPluginBuildCommand(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
plugin := plugins.NewVLLMPlugin(logger, "")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config plugins.VLLMConfig
|
||||
port int
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "meta-llama/Llama-2-7b",
|
||||
},
|
||||
port: 8000,
|
||||
expected: []string{
|
||||
"python", "-m", "vllm.entrypoints.openai.api_server",
|
||||
"--model", "meta-llama/Llama-2-7b",
|
||||
"--port", "8000",
|
||||
"--host", "0.0.0.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tensor parallelism",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "meta-llama/Llama-2-70b",
|
||||
TensorParallelSize: 4,
|
||||
},
|
||||
port: 8000,
|
||||
expected: []string{
|
||||
"--tensor-parallel-size", "4",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with quantization",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "test-model",
|
||||
Quantization: "awq",
|
||||
},
|
||||
port: 8000,
|
||||
expected: []string{
|
||||
"--quantization", "awq",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with trust remote code",
|
||||
config: plugins.VLLMConfig{
|
||||
Model: "custom-model",
|
||||
TrustRemoteCode: true,
|
||||
},
|
||||
port: 8000,
|
||||
expected: []string{
|
||||
"--trust-remote-code",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := plugin.BuildCommand(&tt.config, tt.port)
|
||||
for _, exp := range tt.expected {
|
||||
assert.Contains(t, cmd, exp)
|
||||
}
|
||||
assert.Contains(t, cmd, "--model")
|
||||
assert.Contains(t, cmd, "--port")
|
||||
assert.Contains(t, cmd, "8000")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVLLMPluginGetServiceTemplate(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
plugin := plugins.NewVLLMPlugin(logger, "")
|
||||
|
||||
template := plugin.GetServiceTemplate()
|
||||
|
||||
assert.Equal(t, "service", template.JobType)
|
||||
assert.Equal(t, "service", template.SlotPool)
|
||||
assert.Equal(t, 1, template.GPUCount)
|
||||
|
||||
// Verify health check endpoints
|
||||
assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Liveness)
|
||||
assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/health", template.HealthCheck.Readiness)
|
||||
assert.Equal(t, 30, template.HealthCheck.Interval)
|
||||
assert.Equal(t, 10, template.HealthCheck.Timeout)
|
||||
|
||||
// Verify mounts
|
||||
require.Len(t, template.Mounts, 2)
|
||||
assert.Equal(t, "{{MODEL_CACHE}}", template.Mounts[0].Source)
|
||||
assert.Equal(t, "/root/.cache/huggingface", template.Mounts[0].Destination)
|
||||
|
||||
// Verify resource requirements
|
||||
assert.Equal(t, 1, template.Resources.MinGPU)
|
||||
assert.Equal(t, 8, template.Resources.MaxGPU)
|
||||
assert.Equal(t, 16, template.Resources.MinMemory)
|
||||
}
|
||||
|
||||
func TestRunVLLMTask(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("missing config", func(t *testing.T) {
|
||||
metadata := map[string]string{}
|
||||
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing 'vllm_config'")
|
||||
})
|
||||
|
||||
t.Run("invalid config json", func(t *testing.T) {
|
||||
metadata := map[string]string{
|
||||
"vllm_config": "invalid json",
|
||||
}
|
||||
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid config")
|
||||
})
|
||||
|
||||
t.Run("invalid config values", func(t *testing.T) {
|
||||
config := plugins.VLLMConfig{
|
||||
Model: "", // Missing model
|
||||
}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
metadata := map[string]string{
|
||||
"vllm_config": string(configJSON),
|
||||
}
|
||||
_, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "model name is required")
|
||||
})
|
||||
|
||||
t.Run("successful start", func(t *testing.T) {
|
||||
config := plugins.VLLMConfig{
|
||||
Model: "meta-llama/Llama-2-7b",
|
||||
TensorParallelSize: 1,
|
||||
}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
metadata := map[string]string{
|
||||
"vllm_config": string(configJSON),
|
||||
}
|
||||
output, err := plugins.RunVLLMTask(ctx, logger, "task-123", metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(output, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "starting", result["status"])
|
||||
assert.Equal(t, "meta-llama/Llama-2-7b", result["model"])
|
||||
assert.Equal(t, "vllm", result["plugin"])
|
||||
})
|
||||
}
|
||||
Loading…
Reference in a new issue