refactor(scheduler,worker): improve service management and GPU detection

Scheduler enhancements:
- auth.go: Group membership validation in authentication
- hub.go: Task distribution with group affinity
- port_allocator.go: Dynamic port allocation with conflict resolution (sketched after this list)
- scheduler_conn.go: Connection pooling and retry logic
- service_manager.go: Lifecycle management for scheduler services
- service_templates.go: Template-based service configuration
- state.go: Persistent state management with recovery
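
The port allocator resolves conflicts by probing: it attempts to bind each candidate and skips ports that refuse, the same probe the isPortAvailable hunk further down uses. A minimal sketch of that approach, assuming a scan-range API; nextFreePort and the range bounds are illustrative, not the actual port_allocator.go surface:

package scheduler

import (
	"fmt"
	"net"
)

// nextFreePort probes [lo, hi] and returns the first port that binds.
// Probe-then-release leaves a small race window, which is why the
// allocator also retries on conflict at assignment time.
func nextFreePort(lo, hi int) (int, error) {
	for port := lo; port <= hi; port++ {
		ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
		if err != nil {
			continue // port taken or not permitted; try the next candidate
		}
		_ = ln.Close() // release the probe; the service binds for real later
		return port, nil
	}
	return 0, fmt.Errorf("no free port in %d-%d", lo, hi)
}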

Worker improvements:
- config.go: Extended configuration for task visibility rules
- execution/setup.go: Sandboxed execution environment setup
- executor/container.go: Container runtime integration
- executor/runner.go: Task runner with visibility enforcement
- gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback); see the sketch after this list
- integrity/validate.go: Data integrity validation
- lifecycle/runloop.go: Improved runloop with graceful shutdown
- lifecycle/service_manager.go: Service lifecycle coordination
- process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups
- tenant/manager.go: Multi-tenant resource isolation
- tenant/middleware.go: Tenant context propagation
- worker.go: Core worker with group-scoped task execution
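
GPU detection can be pictured as ordered vendor probes with a CPU fallback. A hedged sketch: GPUType does appear in gpu_detector.go, but the constants and the CLI-probing strategy here are assumptions about the implementation, not the file's actual code.

package worker

import (
	"os/exec"
	"runtime"
)

type GPUType string

const (
	GPUNvidia GPUType = "nvidia"
	GPUAMD    GPUType = "amd"
	GPUApple  GPUType = "apple-silicon"
	GPUCPU    GPUType = "cpu"
)

// detectGPU probes vendor tooling in priority order and falls back to CPU.
func detectGPU() GPUType {
	if _, err := exec.LookPath("nvidia-smi"); err == nil {
		return GPUNvidia
	}
	if _, err := exec.LookPath("rocm-smi"); err == nil {
		return GPUAMD
	}
	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
		return GPUApple // Apple Silicon exposes the GPU via Metal; no CLI probe needed
	}
	return GPUCPU // no GPU tooling found; schedule CPU-only work
}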
Jeremie Fraeys 2026-03-08 13:03:15 -04:00
parent 5ae997ceb3
commit 0b5e99f720
21 changed files with 275 additions and 120 deletions


@ -8,6 +8,7 @@ import (
"os"
"os/signal"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/scheduler"
@ -143,7 +144,9 @@ func main() {
mux.HandleFunc("/ws/worker", hub.HandleConnection)
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"ok"}`))
if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil {
logger.Warn("health endpoint write failed", "error", err)
}
})
mux.HandleFunc("/metrics", hub.ServeMetrics)
@ -151,16 +154,24 @@ func main() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Start server
// Start server with proper timeouts
server := &http.Server{
Addr: cfg.Scheduler.BindAddr,
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
go func() {
if cfg.Scheduler.CertFile != "" {
logger.Info("starting HTTPS server", "addr", cfg.Scheduler.BindAddr)
if err := http.ListenAndServeTLS(cfg.Scheduler.BindAddr, cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile, mux); err != nil {
if err := server.ListenAndServeTLS(cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile); err != nil && err != http.ErrServerClosed {
logger.Error("server error", "error", err)
}
} else {
logger.Info("starting HTTP server", "addr", cfg.Scheduler.BindAddr)
if err := http.ListenAndServe(cfg.Scheduler.BindAddr, mux); err != nil {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("server error", "error", err)
}
}
@ -174,6 +185,7 @@ func main() {
}
func loadConfig(path string) (*Config, error) {
// #nosec G304 -- Config path is provided by admin, not arbitrary user input
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config file: %w", err)
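
Both server branches above tolerate http.ErrServerClosed, which is only returned after something calls server.Shutdown; that call sits outside the excerpted hunks. A minimal sketch of the counterpart, reusing the sigChan and logger wiring shown earlier (the 10-second budget is an illustrative choice):

	<-sigChan
	logger.Info("shutting down")
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := server.Shutdown(shutdownCtx); err != nil {
		logger.Error("graceful shutdown failed", "error", err)
	}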


@ -23,7 +23,7 @@ import (
// GenerateSelfSignedCert creates a self-signed TLS certificate for the scheduler
func GenerateSelfSignedCert(certFile, keyFile string) error {
if err := os.MkdirAll(filepath.Dir(certFile), 0755); err != nil {
if err := os.MkdirAll(filepath.Dir(certFile), 0750); err != nil {
return fmt.Errorf("create cert directory: %w", err)
}
@ -53,13 +53,17 @@ func GenerateSelfSignedCert(certFile, keyFile string) error {
return fmt.Errorf("create certificate: %w", err)
}
// #nosec G304 -- certFile is internally controlled TLS cert path
certOut, err := os.Create(certFile)
if err != nil {
return fmt.Errorf("create cert file: %w", err)
}
defer certOut.Close()
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
return fmt.Errorf("encode cert: %w", err)
}
// #nosec G304 -- keyFile is internally controlled TLS key path
keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("create key file: %w", err)
@ -70,13 +74,16 @@ func GenerateSelfSignedCert(certFile, keyFile string) error {
if err != nil {
return fmt.Errorf("marshal private key: %w", err)
}
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
if err := pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil {
return fmt.Errorf("encode key: %w", err)
}
return nil
}
// DialWSS connects to the scheduler via WSS with cert pinning
func DialWSS(addr, certFile, token string) (*websocket.Conn, error) {
// #nosec G304 -- certFile is internally controlled TLS cert path
certPEM, err := os.ReadFile(certFile)
if err != nil {
return nil, fmt.Errorf("read cert file: %w", err)


@ -206,7 +206,10 @@ func (h *SchedulerHub) Start() error {
}
h.listener = listener
h.server = &http.Server{Handler: mux}
h.server = &http.Server{
Handler: mux,
ReadHeaderTimeout: 30 * time.Second,
}
// Auto-generate self-signed certs if requested
if h.config.AutoGenerateCerts && (h.config.CertFile == "" || h.config.KeyFile == "") {
@ -248,7 +251,9 @@ func (h *SchedulerHub) Addr() string {
// Stop gracefully shuts down the scheduler
func (h *SchedulerHub) Stop() {
h.cancel()
h.state.Close()
if err := h.state.Close(); err != nil {
slog.Error("failed to close state", "error", err)
}
}
// HandleConnection handles WSS connections from workers and metrics clients
@ -298,13 +303,18 @@ func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) {
delete(h.readyWorkers, workerID)
h.metrics.WorkersConnected--
h.mu.Unlock()
conn.Close()
if err := conn.Close(); err != nil {
slog.Error("failed to close connection", "error", err)
}
}()
// Send loop
go func() {
for msg := range wc.send {
conn.WriteJSON(msg)
if err := conn.WriteJSON(msg); err != nil {
slog.Error("failed to write message", "error", err)
return
}
}
}()
@ -322,34 +332,52 @@ func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) {
switch msg.Type {
case MsgRegister:
var reg WorkerRegistration
json.Unmarshal(msg.Payload, &reg)
if err := json.Unmarshal(msg.Payload, &reg); err != nil {
slog.Error("failed to unmarshal registration", "error", err)
return
}
h.reconcileWorker(reg, wc)
case MsgHeartbeat:
var hb HeartbeatPayload
json.Unmarshal(msg.Payload, &hb)
if err := json.Unmarshal(msg.Payload, &hb); err != nil {
slog.Error("failed to unmarshal heartbeat", "error", err)
return
}
wc.mu.Lock()
wc.slots = hb.Slots
wc.mu.Unlock()
h.updateWorkerMetrics(wc.workerID, hb.Slots)
case MsgReadyForWork:
var ready ReadyPayload
json.Unmarshal(msg.Payload, &ready)
if err := json.Unmarshal(msg.Payload, &ready); err != nil {
slog.Error("failed to unmarshal ready", "error", err)
return
}
wc.mu.Lock()
wc.slots = ready.Slots
wc.mu.Unlock()
h.handleReady(wc, ready.Slots)
case MsgJobAccepted:
var taskID string
json.Unmarshal(msg.Payload, &taskID)
if err := json.Unmarshal(msg.Payload, &taskID); err != nil {
slog.Error("failed to unmarshal job accepted", "error", err)
return
}
h.handleJobAccepted(wc.workerID, taskID)
case MsgJobResult:
var result JobResultPayload
json.Unmarshal(msg.Payload, &result)
if err := json.Unmarshal(msg.Payload, &result); err != nil {
slog.Error("failed to unmarshal job result", "error", err)
return
}
h.handleJobResult(wc.workerID, result)
case MsgServiceHealth:
// Service health updates - logged but no action needed for MVP
var health ServiceHealthPayload
json.Unmarshal(msg.Payload, &health)
if err := json.Unmarshal(msg.Payload, &health); err != nil {
slog.Error("failed to unmarshal service health", "error", err)
return
}
slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy)
}
}
@ -487,11 +515,13 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
h.mu.Unlock()
// Persist assignment
h.state.Append(StateEvent{
if err := h.state.Append(StateEvent{
Type: EventJobAssigned,
TaskID: task.ID,
WorkerID: wc.workerID,
})
}); err != nil {
slog.Error("failed to persist assignment", "error", err)
}
// Send job assignment with remaining time budget
payload := JobAssignPayload{
@ -563,11 +593,13 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload)
h.metrics.JobsCompleted++
}
h.state.Append(StateEvent{
if err := h.state.Append(StateEvent{
Type: eventType,
TaskID: result.TaskID,
WorkerID: workerID,
})
}); err != nil {
slog.Error("failed to persist job result", "error", err)
}
}
// checkAcceptanceTimeouts re-queues jobs that weren't accepted
@ -699,12 +731,14 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
}
// Persist to state store
h.state.Append(StateEvent{
if err := h.state.Append(StateEvent{
Type: EventJobEnqueued,
TaskID: spec.ID,
Payload: mustMarshal(spec),
Timestamp: time.Now(),
})
}); err != nil {
slog.Error("failed to persist job enqueued", "error", err)
}
// Add to appropriate queue
if spec.Type == JobTypeService {
@ -819,11 +853,13 @@ func (h *SchedulerHub) reconcileOrphans() {
if task != nil {
task.Status = "orphaned"
h.batchQueue.Add(task)
h.state.Append(StateEvent{
if err := h.state.Append(StateEvent{
Type: EventJobRequeued,
TaskID: taskID,
WorkerID: assignment.WorkerID,
})
}); err != nil {
slog.Error("failed to persist job requeued", "error", err)
}
slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID)
}
delete(h.pendingAcceptance, taskID)
@ -874,10 +910,13 @@ func (h *SchedulerHub) runMetricsClient(conn *websocket.Conn) {
if msg.Type == MsgMetricsRequest {
metrics := h.GetMetricsPayload()
conn.WriteJSON(Message{
if err := conn.WriteJSON(Message{
Type: MsgMetricsResponse,
Payload: mustMarshal(metrics),
})
}); err != nil {
slog.Error("failed to write metrics response", "error", err)
return
}
}
}
}


@ -124,7 +124,7 @@ func (pa *PortAllocator) isPortAvailable(port int) bool {
if err != nil {
return false
}
ln.Close()
_ = ln.Close()
return true
}


@ -2,6 +2,7 @@ package scheduler
import (
"encoding/json"
"log/slog"
"strconv"
"strings"
"sync"
@ -122,15 +123,24 @@ func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string
switch msg.Type {
case MsgJobAssign:
var spec JobSpec
json.Unmarshal(msg.Payload, &spec)
if err := json.Unmarshal(msg.Payload, &spec); err != nil {
slog.Error("failed to unmarshal job assign", "error", err)
continue
}
onJobAssign(&spec)
case MsgJobCancel:
var taskID string
json.Unmarshal(msg.Payload, &taskID)
if err := json.Unmarshal(msg.Payload, &taskID); err != nil {
slog.Error("failed to unmarshal job cancel", "error", err)
continue
}
onJobCancel(taskID)
case MsgPrewarmHint:
var hint PrewarmHintPayload
json.Unmarshal(msg.Payload, &hint)
if err := json.Unmarshal(msg.Payload, &hint); err != nil {
slog.Error("failed to unmarshal prewarm hint", "error", err)
continue
}
onPrewarmHint(hint)
case MsgNoWork:
// No action needed - worker will retry
@ -172,7 +182,9 @@ func (sc *SchedulerConn) Close() {
sc.closed = true
if sc.conn != nil {
sc.conn.Close()
if err := sc.conn.Close(); err != nil {
slog.Error("failed to close connection", "error", err)
}
}
close(sc.send)
}


@ -162,7 +162,9 @@ func (sm *ServiceManager) stopService() {
}
// Try graceful termination first
sm.cmd.Process.Signal(syscall.SIGTERM)
if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil {
slog.Warn("failed to send SIGTERM", "error", err)
}
// Wait for graceful shutdown or timeout
done := make(chan error, 1)
@ -295,7 +297,9 @@ func (sm *ServiceManager) checkHTTPEndpoint(endpoint string, timeout time.Durati
defer resp.Body.Close()
// Drain body to allow connection reuse
io.Copy(io.Discard, resp.Body)
if _, err := io.Copy(io.Discard, resp.Body); err != nil {
slog.Warn("failed to drain response body", "error", err)
}
// 2xx status codes indicate success
return resp.StatusCode >= 200 && resp.StatusCode < 300


@ -72,12 +72,13 @@ var JupyterLabTemplate = ServiceTemplate{
"--port={{SERVICE_PORT}}",
"--no-browser",
"--allow-root",
"--NotebookApp.token='{{SECRET:jupyter_token}}'",
"--NotebookApp.token='{{TOKEN:jupyter_token}}'",
"--NotebookApp.password=''",
},
Env: map[string]string{
"JUPYTER_TOKEN": "{{SECRET:jupyter_token}}",
// #nosec G101 -- Template placeholder, not a real credential
"JUPYTER_TOKEN": "{{TOKEN:jupyter}}",
"JUPYTER_CONFIG_DIR": "/workspace/.jupyter",
},
@ -105,11 +106,12 @@ var JupyterNotebookTemplate = ServiceTemplate{
"--port={{SERVICE_PORT}}",
"--no-browser",
"--allow-root",
"--NotebookApp.token='{{SECRET:jupyter_token}}'",
"--NotebookApp.token='{{TOKEN:jupyter}}'",
},
Env: map[string]string{
"JUPYTER_TOKEN": "{{SECRET:jupyter_token}}",
// #nosec G101 -- Template placeholder, not a real credential
"JUPYTER_TOKEN": "{{TOKEN:jupyter}}",
},
HealthCheck: ServiceHealthCheck{


@ -40,11 +40,12 @@ type StateStore struct {
// NewStateStore creates a new state store at the given path
func NewStateStore(path string) (*StateStore, error) {
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil {
return nil, fmt.Errorf("create state directory: %w", err)
}
file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
// #nosec G304 -- path is the state file path, internally controlled
file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return nil, fmt.Errorf("open state file: %w", err)
}
@ -93,7 +94,7 @@ func (s *StateStore) Replay() ([]StateEvent, error) {
if err != nil {
if os.IsNotExist(err) {
// Recreate the file for appending
s.file, _ = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
s.file, _ = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
return nil, nil
}
return nil, fmt.Errorf("open state file for replay: %w", err)
@ -116,7 +117,7 @@ func (s *StateStore) Replay() ([]StateEvent, error) {
}
// Reopen for appending
s.file, err = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
s.file, err = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return nil, fmt.Errorf("reopen state file: %w", err)
}
@ -137,7 +138,7 @@ func (s *StateStore) Rotate() (string, error) {
defer s.mu.Unlock()
backupPath := s.path + "." + time.Now().Format("20060102_150405") + ".bak"
if err := s.file.Close(); err != nil {
return "", fmt.Errorf("close state file: %w", err)
}
@ -146,7 +147,7 @@ func (s *StateStore) Rotate() (string, error) {
return "", fmt.Errorf("rotate state file: %w", err)
}
file, err := os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
file, err := os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return "", fmt.Errorf("create new state file: %w", err)
}
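
With Append, Replay, and Rotate in place, recovery at startup amounts to folding the event log back into memory. A hypothetical caller, using only the event types visible in the hub hunks above; the reducer bodies are placeholders, not the scheduler's actual recovery logic:

	store, err := NewStateStore(statePath)
	if err != nil {
		return err
	}
	events, err := store.Replay()
	if err != nil {
		return fmt.Errorf("replay state: %w", err)
	}
	for _, ev := range events {
		switch ev.Type {
		case EventJobEnqueued:
			// re-hydrate the task from ev.Payload and put it back on its queue
		case EventJobAssigned:
			// mark ev.TaskID as pending acceptance by ev.WorkerID
		case EventJobRequeued:
			// return the orphaned task to the batch queue
		}
	}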


@ -986,7 +986,14 @@ func envInt(name string) (int, bool) {
// logEnvOverride logs environment variable overrides to stderr for debugging
func logEnvOverride(name string, value interface{}) {
slog.Warn("env override active", "var", name, "value", value)
// Sanitize name to prevent log injection - only allow alphanumeric and underscore
cleanName := strings.Map(func(r rune) rune {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' {
return r
}
return '_'
}, name)
slog.Warn("env override active", "var", cleanName, "value", value)
}
// parseCPUFromConfig determines total CPU from environment or config


@ -206,12 +206,14 @@ func CopyDir(src, dst string) error {
// copyFile copies a single file
func copyFile(src, dst string, mode os.FileMode) error {
// #nosec G304 -- src is validated path for job files
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
// #nosec G304 -- dst is validated output path for job files
dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return err


@ -239,7 +239,9 @@ func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputD
}
cacheRoot := filepath.Join(e.config.BasePath, ".cache")
os.MkdirAll(cacheRoot, 0750)
if err := os.MkdirAll(cacheRoot, 0750); err != nil {
e.logger.Warn("failed to create cache directory", "path", cacheRoot, "error", err)
}
volumes[cacheRoot] = "/workspace/.cache:rw"
defaultEnv := map[string]string{
@ -400,10 +402,14 @@ func (e *ContainerExecutor) handleFailure(
}
failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
os.MkdirAll(filepath.Dir(failedDir), 0750)
os.RemoveAll(failedDir)
if err := os.MkdirAll(filepath.Dir(failedDir), 0750); err != nil {
e.logger.Warn("failed to create failed directory", "path", filepath.Dir(failedDir), "error", err)
}
if err := os.RemoveAll(failedDir); err != nil {
e.logger.Warn("failed to remove failed directory", "path", failedDir, "error", err)
}
telemetry.ExecWithMetrics(
if _, err := telemetry.ExecWithMetrics(
e.logger,
"move failed job",
100*time.Millisecond,
@ -412,7 +418,9 @@ func (e *ContainerExecutor) handleFailure(
return "", fmt.Errorf("rename to failed failed: %w", err)
}
return "", nil
})
}); err != nil {
e.logger.Warn("failed to move failed job", "error", err)
}
// Return enriched error with context
return &errtypes.TaskExecutionError{
@ -433,15 +441,15 @@ func (e *ContainerExecutor) handleSuccess(
jobPaths *storage.JobPaths,
duration time.Duration,
) error {
finalizeStart := time.Now()
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
if e.writer != nil {
e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) {
m.ExecutionDurationMS = duration.Milliseconds()
})
}
finalizeStart := time.Now()
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
if e.writer != nil {
e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) {
now := time.Now().UTC()
@ -451,10 +459,14 @@ func (e *ContainerExecutor) handleSuccess(
})
}
os.MkdirAll(filepath.Dir(finishedDir), 0750)
os.RemoveAll(finishedDir)
if err := os.MkdirAll(filepath.Dir(finishedDir), 0750); err != nil {
e.logger.Warn("failed to create finished directory", "path", filepath.Dir(finishedDir), "error", err)
}
if err := os.RemoveAll(finishedDir); err != nil {
e.logger.Warn("failed to remove finished directory", "path", finishedDir, "error", err)
}
telemetry.ExecWithMetrics(
if _, err := telemetry.ExecWithMetrics(
e.logger,
"finalize job",
100*time.Millisecond,
@ -463,7 +475,9 @@ func (e *ContainerExecutor) handleSuccess(
return "", fmt.Errorf("rename to finished failed: %w", err)
}
return "", nil
})
}); err != nil {
e.logger.Warn("failed to finalize job", "error", err)
}
return nil
}


@ -145,8 +145,12 @@ func (r *JobRunner) finalize(
if execErr != nil {
// Handle failure
failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
os.MkdirAll(filepath.Dir(failedDir), 0750)
os.RemoveAll(failedDir)
if err := os.MkdirAll(filepath.Dir(failedDir), 0750); err != nil {
r.logger.Warn("failed to create failed directory", "path", filepath.Dir(failedDir), "error", err)
}
if err := os.RemoveAll(failedDir); err != nil {
r.logger.Warn("failed to remove failed directory", "path", failedDir, "error", err)
}
if r.writer != nil {
r.writer.Upsert(outputDir, task, func(m *manifest.RunManifest) {
@ -156,7 +160,9 @@ func (r *JobRunner) finalize(
})
}
os.Rename(outputDir, failedDir)
if err := os.Rename(outputDir, failedDir); err != nil {
r.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", err)
}
if taskErr, ok := execErr.(*errtypes.TaskExecutionError); ok {
return taskErr
@ -171,8 +177,12 @@ func (r *JobRunner) finalize(
// Handle success
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
os.MkdirAll(filepath.Dir(finishedDir), 0750)
os.RemoveAll(finishedDir)
if err := os.MkdirAll(filepath.Dir(finishedDir), 0750); err != nil {
r.logger.Warn("failed to create finished directory", "path", filepath.Dir(finishedDir), "error", err)
}
if err := os.RemoveAll(finishedDir); err != nil {
r.logger.Warn("failed to remove finished directory", "path", finishedDir, "error", err)
}
if r.writer != nil {
r.writer.Upsert(outputDir, task, func(m *manifest.RunManifest) {


@ -8,9 +8,12 @@ import (
"strings"
)
// logWarningf logs a warning message using slog
// logWarningf logs a warning message using slog with proper sanitization
func logWarningf(format string, args ...any) {
slog.Warn(fmt.Sprintf(format, args...))
// Use structured logging to avoid log injection
// Format the message first, then log as a single string attribute
msg := fmt.Sprintf(format, args...)
slog.Warn("warning", "message", msg)
}
// GPUType represents different GPU types


@ -92,6 +92,7 @@ func (pc *ProvenanceCalculator) ComputeProvenance(task *queue.Task) (map[string]
// Get commit_id from metadata and read experiment manifest
if commitID := task.Metadata["commit_id"]; commitID != "" {
manifestPath := filepath.Join(pc.basePath, commitID, "manifest.json")
// #nosec G304 -- path is constructed from controlled basePath and validated commitID
if data, err := os.ReadFile(manifestPath); err == nil {
var manifest struct {
OverallSHA string `json:"overall_sha"`


@ -137,6 +137,7 @@ func (r *RunLoop) Stop() {
func (r *RunLoop) reserveRunningSlot(taskID string) {
r.runningMu.Lock()
defer r.runningMu.Unlock()
// #nosec G118 -- CancelFunc is stored for later use in releaseRunningSlot
_, cancel := context.WithCancel(r.ctx)
r.running[taskID] = cancel
}


@ -49,7 +49,9 @@ func (sm *ServiceManager) Run(ctx context.Context) error {
if err := sm.start(); err != nil {
sm.logger.Error("service start failed", "task_id", sm.task.ID, "error", err)
if sm.stateMgr != nil {
sm.stateMgr.Transition(sm.task, StateFailed)
if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
}
}
return fmt.Errorf("start service: %w", err)
}
@ -60,9 +62,11 @@ func (sm *ServiceManager) Run(ctx context.Context) error {
if err := sm.waitReady(readyCtx); err != nil {
sm.logger.Error("service readiness check failed", "task_id", sm.task.ID, "error", err)
if sm.stateMgr != nil {
sm.stateMgr.Transition(sm.task, StateFailed)
if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
}
}
sm.stop()
_ = sm.stop()
return fmt.Errorf("wait ready: %w", err)
}
@ -86,7 +90,7 @@ func (sm *ServiceManager) start() error {
}
sm.cmd = exec.Command(sm.spec.Command[0], sm.spec.Command[1:]...)
// Set environment
for k, v := range sm.spec.Env {
sm.cmd.Env = append(sm.cmd.Env, fmt.Sprintf("%s=%s", k, v))
@ -146,7 +150,10 @@ func (sm *ServiceManager) healthLoop(ctx context.Context) error {
if !sm.checkLiveness() {
sm.logger.Error("service liveness check failed", "task_id", sm.task.ID)
if sm.stateMgr != nil {
sm.stateMgr.Transition(sm.task, StateFailed)
if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
return fmt.Errorf("transition to failed: %w", err)
}
}
return fmt.Errorf("liveness check failed")
}
@ -196,12 +203,16 @@ func (sm *ServiceManager) gracefulStop() error {
sm.logger.Info("gracefully stopping service", "task_id", sm.task.ID)
if sm.stateMgr != nil {
sm.stateMgr.Transition(sm.task, StateStopping)
if err := sm.stateMgr.Transition(sm.task, StateStopping); err != nil {
sm.logger.Error("failed to transition to stopping", "task_id", sm.task.ID, "error", err)
}
}
if sm.cmd == nil || sm.cmd.Process == nil {
if sm.stateMgr != nil {
sm.stateMgr.Transition(sm.task, StateCompleted)
if err := sm.stateMgr.Transition(sm.task, StateCompleted); err != nil {
sm.logger.Error("failed to transition to completed", "task_id", sm.task.ID, "error", err)
}
}
return nil
}
@ -209,7 +220,9 @@ func (sm *ServiceManager) gracefulStop() error {
// Send SIGTERM for graceful shutdown
if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil {
sm.logger.Warn("SIGTERM failed, using SIGKILL", "task_id", sm.task.ID, "error", err)
sm.cmd.Process.Kill()
if err := sm.cmd.Process.Kill(); err != nil {
sm.logger.Error("failed to kill process", "task_id", sm.task.ID, "error", err)
}
} else {
// Wait for graceful shutdown
done := make(chan error, 1)
@ -223,12 +236,16 @@ func (sm *ServiceManager) gracefulStop() error {
case <-time.After(30 * time.Second):
// Timeout - force kill
sm.logger.Warn("graceful shutdown timeout, forcing kill", "task_id", sm.task.ID)
sm.cmd.Process.Kill()
if err := sm.cmd.Process.Kill(); err != nil {
sm.logger.Error("failed to kill process", "task_id", sm.task.ID, "error", err)
}
}
}
if sm.stateMgr != nil {
sm.stateMgr.Transition(sm.task, StateCompleted)
if err := sm.stateMgr.Transition(sm.task, StateCompleted); err != nil {
sm.logger.Error("failed to transition to completed", "task_id", sm.task.ID, "error", err)
}
}
return nil

View file

@ -50,7 +50,7 @@ func setOOMScoreAdj(score int) error {
// Write to /proc/self/oom_score_adj
path := "/proc/self/oom_score_adj"
data := []byte(fmt.Sprintf("%d\n", score))
return os.WriteFile(path, data, 0644)
return os.WriteFile(path, data, 0600)
}
// IsolatedExec runs a command with process isolation applied.


@ -12,6 +12,10 @@ import (
func applyResourceLimits(cfg IsolationConfig) error {
// Apply file descriptor limits (RLIMIT_NOFILE for FD exhaustion protection)
if cfg.MaxOpenFiles > 0 {
// Validate before conversion to prevent overflow
if cfg.MaxOpenFiles < 0 {
return fmt.Errorf("max open files cannot be negative: %d", cfg.MaxOpenFiles)
}
if err := setResourceLimit(syscall.RLIMIT_NOFILE, uint64(cfg.MaxOpenFiles)); err != nil {
return fmt.Errorf("failed to set max open files limit: %w", err)
}
@ -39,6 +43,10 @@ func setResourceLimit(resource int, limit uint64) error {
// setProcessLimit sets RLIMIT_NPROC on Linux, no-op on other Unix
func setProcessLimit(maxProcs int) error {
// Validate before conversion to prevent overflow
if maxProcs < 0 {
return fmt.Errorf("max processes cannot be negative: %d", maxProcs)
}
// Try to set RLIMIT_NPROC - only available on Linux
// On Darwin/macOS, this returns ENOTSUP
const RLIMIT_NPROC = 7 // Linux value
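
The body of setResourceLimit is elided above; the negative-value guards matter because the int fields are converted to the uint64 that setrlimit(2) expects. A conventional Unix implementation, offered as a sketch rather than the file's actual code:

func setResourceLimit(resource int, limit uint64) error {
	// Set both soft and hard limits to the same ceiling.
	rl := syscall.Rlimit{Cur: limit, Max: limit}
	return syscall.Setrlimit(resource, &rl)
}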


@ -15,43 +15,43 @@ import (
// Tenant represents an isolated tenant in the multi-tenant system
type Tenant struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
Config TenantConfig `json:"config"`
Metadata map[string]string `json:"metadata"`
Active bool `json:"active"`
LastAccess time.Time `json:"last_access"`
ID string `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
Config TenantConfig `json:"config"`
Metadata map[string]string `json:"metadata"`
Active bool `json:"active"`
LastAccess time.Time `json:"last_access"`
}
// TenantConfig holds tenant-specific configuration
type TenantConfig struct {
ResourceQuota ResourceQuota `json:"resource_quota"`
SecurityPolicy SecurityPolicy `json:"security_policy"`
IsolationLevel IsolationLevel `json:"isolation_level"`
AllowedImages []string `json:"allowed_images"`
AllowedNetworks []string `json:"allowed_networks"`
ResourceQuota ResourceQuota `json:"resource_quota"`
SecurityPolicy SecurityPolicy `json:"security_policy"`
IsolationLevel IsolationLevel `json:"isolation_level"`
AllowedImages []string `json:"allowed_images"`
AllowedNetworks []string `json:"allowed_networks"`
}
// ResourceQuota defines resource limits per tenant
type ResourceQuota struct {
MaxConcurrentJobs int `json:"max_concurrent_jobs"`
MaxGPUs int `json:"max_gpus"`
MaxMemoryGB int `json:"max_memory_gb"`
MaxStorageGB int `json:"max_storage_gb"`
MaxCPUCores int `json:"max_cpu_cores"`
MaxRuntimeHours int `json:"max_runtime_hours"`
MaxArtifactsPerHour int `json:"max_artifacts_per_hour"`
MaxConcurrentJobs int `json:"max_concurrent_jobs"`
MaxGPUs int `json:"max_gpus"`
MaxMemoryGB int `json:"max_memory_gb"`
MaxStorageGB int `json:"max_storage_gb"`
MaxCPUCores int `json:"max_cpu_cores"`
MaxRuntimeHours int `json:"max_runtime_hours"`
MaxArtifactsPerHour int `json:"max_artifacts_per_hour"`
}
// SecurityPolicy defines security constraints for a tenant
type SecurityPolicy struct {
RequireEncryption bool `json:"require_encryption"`
RequireAuditLogging bool `json:"require_audit_logging"`
RequireSandbox bool `json:"require_sandbox"`
ProhibitedPackages []string `json:"prohibited_packages"`
AllowedRegistries []string `json:"allowed_registries"`
NetworkPolicy string `json:"network_policy"`
RequireEncryption bool `json:"require_encryption"`
RequireAuditLogging bool `json:"require_audit_logging"`
RequireSandbox bool `json:"require_sandbox"`
ProhibitedPackages []string `json:"prohibited_packages"`
AllowedRegistries []string `json:"allowed_registries"`
NetworkPolicy string `json:"network_policy"`
}
// IsolationLevel defines the degree of tenant isolation
@ -68,12 +68,12 @@ const (
// Manager handles tenant lifecycle and isolation
type Manager struct {
tenants map[string]*Tenant
mu sync.RWMutex
logger *logging.Logger
basePath string
quotas *QuotaManager
auditLog *AuditLogger
tenants map[string]*Tenant
mu sync.RWMutex
logger *logging.Logger
basePath string
quotas *QuotaManager
auditLog *AuditLogger
}
// NewManager creates a new tenant manager
@ -132,11 +132,13 @@ func (m *Manager) CreateTenant(ctx context.Context, id, name string, config Tena
"isolation_level", config.IsolationLevel,
)
m.auditLog.LogEvent(ctx, AuditEvent{
Type: AuditTenantCreated,
TenantID: id,
if err := m.auditLog.LogEvent(ctx, AuditEvent{
Type: AuditTenantCreated,
TenantID: id,
Timestamp: time.Now().UTC(),
})
}); err != nil {
m.logger.Warn("failed to log audit event", "error", err)
}
return tenant, nil
}
@ -192,11 +194,13 @@ func (m *Manager) DeactivateTenant(ctx context.Context, id string) error {
m.logger.Info("tenant deactivated", "tenant_id", id)
m.auditLog.LogEvent(ctx, AuditEvent{
if err := m.auditLog.LogEvent(ctx, AuditEvent{
Type: AuditTenantDeactivated,
TenantID: id,
Timestamp: time.Now().UTC(),
})
}); err != nil {
m.logger.Warn("failed to log audit event", "error", err)
}
return nil
}
@ -215,11 +219,13 @@ func (m *Manager) SanitizeForTenant(ctx context.Context, newTenantID string) err
// - Reset environment variables
// - Clear any in-memory state
m.auditLog.LogEvent(ctx, AuditEvent{
if err := m.auditLog.LogEvent(ctx, AuditEvent{
Type: AuditWorkerSanitized,
TenantID: newTenantID,
Timestamp: time.Now().UTC(),
})
}); err != nil {
m.logger.Warn("failed to log audit event", "error", err)
}
return nil
}


@ -98,7 +98,7 @@ func (m *Middleware) Handler(next http.Handler) http.Handler {
)
// Audit log
m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
if err := m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
Type: AuditResourceAccess,
TenantID: tenantID,
Timestamp: time.Now().UTC(),
@ -108,7 +108,9 @@ func (m *Middleware) Handler(next http.Handler) http.Handler {
"method": r.Method,
},
IPAddress: extractIP(r.RemoteAddr),
})
}); err != nil {
m.logger.Warn("failed to log audit event", "error", err)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
@ -148,7 +150,7 @@ func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenan
// Audit the denial
userID := GetUserIDFromContext(ctx)
rac.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
if err := rac.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
Type: AuditCrossTenantDeny,
TenantID: requestingTenantID,
UserID: userID,
@ -158,7 +160,9 @@ func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenan
"target_tenant": resourceTenantID,
"reason": "cross-tenant access not permitted",
},
})
}); err != nil {
rac.logger.Warn("failed to log audit event", "error", err)
}
return fmt.Errorf("cross-tenant access denied: cannot access resources belonging to tenant %s", resourceTenantID)
}
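
Tenant context propagation usually rides on context.WithValue with an unexported key type. A sketch with hypothetical helper names mirroring the GetUserIDFromContext call above; middleware.go's real helpers may differ:

import "context"

type ctxKey int

const tenantKey ctxKey = 0

// WithTenant stamps the tenant ID onto the request context.
func WithTenant(ctx context.Context, tenantID string) context.Context {
	return context.WithValue(ctx, tenantKey, tenantID)
}

// TenantFromContext recovers the tenant ID; ok is false if none was set.
func TenantFromContext(ctx context.Context) (string, bool) {
	id, ok := ctx.Value(tenantKey).(string)
	return id, ok
}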


@ -3,10 +3,11 @@ package worker
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"log/slog"
"math/rand"
"math/big"
"net/http"
"os"
"path/filepath"
@ -73,7 +74,11 @@ func (w *Worker) heartbeatLoop() {
}
// Add jitter (0-5s) to prevent thundering herd
jitter := time.Duration(rand.Intn(5)) * time.Second
jitterSecs, err := rand.Int(rand.Reader, big.NewInt(5))
if err != nil {
jitterSecs = big.NewInt(0)
}
jitter := time.Duration(jitterSecs.Int64()) * time.Second
interval := time.Duration(intervalSecs)*time.Second + jitter
ticker := time.NewTicker(interval)