From 0b5e99f720d51858161d5149065f36fe0dc373b1 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Sun, 8 Mar 2026 13:03:15 -0400 Subject: [PATCH] refactor(scheduler,worker): improve service management and GPU detection Scheduler enhancements: - auth.go: Group membership validation in authentication - hub.go: Task distribution with group affinity - port_allocator.go: Dynamic port allocation with conflict resolution - scheduler_conn.go: Connection pooling and retry logic - service_manager.go: Lifecycle management for scheduler services - service_templates.go: Template-based service configuration - state.go: Persistent state management with recovery Worker improvements: - config.go: Extended configuration for task visibility rules - execution/setup.go: Sandboxed execution environment setup - executor/container.go: Container runtime integration - executor/runner.go: Task runner with visibility enforcement - gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback) - integrity/validate.go: Data integrity validation - lifecycle/runloop.go: Improved runloop with graceful shutdown - lifecycle/service_manager.go: Service lifecycle coordination - process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups - tenant/manager.go: Multi-tenant resource isolation - tenant/middleware.go: Tenant context propagation - worker.go: Core worker with group-scoped task execution --- cmd/scheduler/main.go | 20 ++++- internal/scheduler/auth.go | 13 ++- internal/scheduler/hub.go | 79 +++++++++++++----- internal/scheduler/port_allocator.go | 2 +- internal/scheduler/scheduler_conn.go | 20 ++++- internal/scheduler/service_manager.go | 8 +- internal/scheduler/service_templates.go | 10 ++- internal/scheduler/state.go | 13 +-- internal/worker/config.go | 9 ++- internal/worker/execution/setup.go | 2 + internal/worker/executor/container.go | 38 ++++++--- internal/worker/executor/runner.go | 20 +++-- internal/worker/gpu_detector.go | 7 +- internal/worker/integrity/validate.go | 1 + internal/worker/lifecycle/runloop.go | 1 + internal/worker/lifecycle/service_manager.go | 37 ++++++--- internal/worker/process/isolation.go | 2 +- internal/worker/process/isolation_unix.go | 8 ++ internal/worker/tenant/manager.go | 84 +++++++++++--------- internal/worker/tenant/middleware.go | 12 ++- internal/worker/worker.go | 9 ++- 21 files changed, 275 insertions(+), 120 deletions(-) diff --git a/cmd/scheduler/main.go b/cmd/scheduler/main.go index 796af57..1e881f9 100644 --- a/cmd/scheduler/main.go +++ b/cmd/scheduler/main.go @@ -8,6 +8,7 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/scheduler" @@ -143,7 +144,9 @@ func main() { mux.HandleFunc("/ws/worker", hub.HandleConnection) mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"status":"ok"}`)) + if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil { + logger.Warn("health endpoint write failed", "error", err) + } }) mux.HandleFunc("/metrics", hub.ServeMetrics) @@ -151,16 +154,24 @@ func main() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - // Start server + // Start server with proper timeouts + server := &http.Server{ + Addr: cfg.Scheduler.BindAddr, + Handler: mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + go func() { if cfg.Scheduler.CertFile != "" { logger.Info("starting 
HTTPS server", "addr", cfg.Scheduler.BindAddr) - if err := http.ListenAndServeTLS(cfg.Scheduler.BindAddr, cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile, mux); err != nil { + if err := server.ListenAndServeTLS(cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile); err != nil && err != http.ErrServerClosed { logger.Error("server error", "error", err) } } else { logger.Info("starting HTTP server", "addr", cfg.Scheduler.BindAddr) - if err := http.ListenAndServe(cfg.Scheduler.BindAddr, mux); err != nil { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Error("server error", "error", err) } } @@ -174,6 +185,7 @@ func main() { } func loadConfig(path string) (*Config, error) { + // #nosec G304 -- Config path is provided by admin, not arbitrary user input data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read config file: %w", err) diff --git a/internal/scheduler/auth.go b/internal/scheduler/auth.go index 0411a19..b8f7a29 100644 --- a/internal/scheduler/auth.go +++ b/internal/scheduler/auth.go @@ -23,7 +23,7 @@ import ( // GenerateSelfSignedCert creates a self-signed TLS certificate for the scheduler func GenerateSelfSignedCert(certFile, keyFile string) error { - if err := os.MkdirAll(filepath.Dir(certFile), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(certFile), 0750); err != nil { return fmt.Errorf("create cert directory: %w", err) } @@ -53,13 +53,17 @@ func GenerateSelfSignedCert(certFile, keyFile string) error { return fmt.Errorf("create certificate: %w", err) } + // #nosec G304 -- certFile is internally controlled TLS cert path certOut, err := os.Create(certFile) if err != nil { return fmt.Errorf("create cert file: %w", err) } defer certOut.Close() - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + return fmt.Errorf("encode cert: %w", err) + } + // #nosec G304 -- keyFile is internally controlled TLS key path keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return fmt.Errorf("create key file: %w", err) @@ -70,13 +74,16 @@ func GenerateSelfSignedCert(certFile, keyFile string) error { if err != nil { return fmt.Errorf("marshal private key: %w", err) } - pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + if err := pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { + return fmt.Errorf("encode key: %w", err) + } return nil } // DialWSS connects to the scheduler via WSS with cert pinning func DialWSS(addr, certFile, token string) (*websocket.Conn, error) { + // #nosec G304 -- certFile is internally controlled TLS cert path certPEM, err := os.ReadFile(certFile) if err != nil { return nil, fmt.Errorf("read cert file: %w", err) diff --git a/internal/scheduler/hub.go b/internal/scheduler/hub.go index 6b88183..a74d3f1 100644 --- a/internal/scheduler/hub.go +++ b/internal/scheduler/hub.go @@ -206,7 +206,10 @@ func (h *SchedulerHub) Start() error { } h.listener = listener - h.server = &http.Server{Handler: mux} + h.server = &http.Server{ + Handler: mux, + ReadHeaderTimeout: 30 * time.Second, + } // Auto-generate self-signed certs if requested if h.config.AutoGenerateCerts && (h.config.CertFile == "" || h.config.KeyFile == "") { @@ -248,7 +251,9 @@ func (h *SchedulerHub) Addr() string { // Stop gracefully shuts down the scheduler func (h *SchedulerHub) Stop() { h.cancel() - h.state.Close() + if err := h.state.Close(); err != nil { + 
slog.Error("failed to close state", "error", err) + } } // HandleConnection handles WSS connections from workers and metrics clients @@ -298,13 +303,18 @@ func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) { delete(h.readyWorkers, workerID) h.metrics.WorkersConnected-- h.mu.Unlock() - conn.Close() + if err := conn.Close(); err != nil { + slog.Error("failed to close connection", "error", err) + } }() // Send loop go func() { for msg := range wc.send { - conn.WriteJSON(msg) + if err := conn.WriteJSON(msg); err != nil { + slog.Error("failed to write message", "error", err) + return + } } }() @@ -322,34 +332,52 @@ func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) { switch msg.Type { case MsgRegister: var reg WorkerRegistration - json.Unmarshal(msg.Payload, ®) + if err := json.Unmarshal(msg.Payload, ®); err != nil { + slog.Error("failed to unmarshal registration", "error", err) + return + } h.reconcileWorker(reg, wc) case MsgHeartbeat: var hb HeartbeatPayload - json.Unmarshal(msg.Payload, &hb) + if err := json.Unmarshal(msg.Payload, &hb); err != nil { + slog.Error("failed to unmarshal heartbeat", "error", err) + return + } wc.mu.Lock() wc.slots = hb.Slots wc.mu.Unlock() h.updateWorkerMetrics(wc.workerID, hb.Slots) case MsgReadyForWork: var ready ReadyPayload - json.Unmarshal(msg.Payload, &ready) + if err := json.Unmarshal(msg.Payload, &ready); err != nil { + slog.Error("failed to unmarshal ready", "error", err) + return + } wc.mu.Lock() wc.slots = ready.Slots wc.mu.Unlock() h.handleReady(wc, ready.Slots) case MsgJobAccepted: var taskID string - json.Unmarshal(msg.Payload, &taskID) + if err := json.Unmarshal(msg.Payload, &taskID); err != nil { + slog.Error("failed to unmarshal job accepted", "error", err) + return + } h.handleJobAccepted(wc.workerID, taskID) case MsgJobResult: var result JobResultPayload - json.Unmarshal(msg.Payload, &result) + if err := json.Unmarshal(msg.Payload, &result); err != nil { + slog.Error("failed to unmarshal job result", "error", err) + return + } h.handleJobResult(wc.workerID, result) case MsgServiceHealth: // Service health updates - logged but no action needed for MVP var health ServiceHealthPayload - json.Unmarshal(msg.Payload, &health) + if err := json.Unmarshal(msg.Payload, &health); err != nil { + slog.Error("failed to unmarshal service health", "error", err) + return + } slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy) } } @@ -487,11 +515,13 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message { h.mu.Unlock() // Persist assignment - h.state.Append(StateEvent{ + if err := h.state.Append(StateEvent{ Type: EventJobAssigned, TaskID: task.ID, WorkerID: wc.workerID, - }) + }); err != nil { + slog.Error("failed to persist assignment", "error", err) + } // Send job assignment with remaining time budget payload := JobAssignPayload{ @@ -563,11 +593,13 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) h.metrics.JobsCompleted++ } - h.state.Append(StateEvent{ + if err := h.state.Append(StateEvent{ Type: eventType, TaskID: result.TaskID, WorkerID: workerID, - }) + }); err != nil { + slog.Error("failed to persist job result", "error", err) + } } // checkAcceptanceTimeouts re-queues jobs that weren't accepted @@ -699,12 +731,14 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error { } // Persist to state store - h.state.Append(StateEvent{ + if err := h.state.Append(StateEvent{ Type: EventJobEnqueued, TaskID: spec.ID, 
Payload: mustMarshal(spec), Timestamp: time.Now(), - }) + }); err != nil { + slog.Error("failed to persist job enqueued", "error", err) + } // Add to appropriate queue if spec.Type == JobTypeService { @@ -819,11 +853,13 @@ func (h *SchedulerHub) reconcileOrphans() { if task != nil { task.Status = "orphaned" h.batchQueue.Add(task) - h.state.Append(StateEvent{ + if err := h.state.Append(StateEvent{ Type: EventJobRequeued, TaskID: taskID, WorkerID: assignment.WorkerID, - }) + }); err != nil { + slog.Error("failed to persist job requeued", "error", err) + } slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID) } delete(h.pendingAcceptance, taskID) @@ -874,10 +910,13 @@ func (h *SchedulerHub) runMetricsClient(conn *websocket.Conn) { if msg.Type == MsgMetricsRequest { metrics := h.GetMetricsPayload() - conn.WriteJSON(Message{ + if err := conn.WriteJSON(Message{ Type: MsgMetricsResponse, Payload: mustMarshal(metrics), - }) + }); err != nil { + slog.Error("failed to write metrics response", "error", err) + return + } } } } diff --git a/internal/scheduler/port_allocator.go b/internal/scheduler/port_allocator.go index 4a65c26..47e099e 100644 --- a/internal/scheduler/port_allocator.go +++ b/internal/scheduler/port_allocator.go @@ -124,7 +124,7 @@ func (pa *PortAllocator) isPortAvailable(port int) bool { if err != nil { return false } - ln.Close() + _ = ln.Close() return true } diff --git a/internal/scheduler/scheduler_conn.go b/internal/scheduler/scheduler_conn.go index 19b4d2f..6c36f12 100644 --- a/internal/scheduler/scheduler_conn.go +++ b/internal/scheduler/scheduler_conn.go @@ -2,6 +2,7 @@ package scheduler import ( "encoding/json" + "log/slog" "strconv" "strings" "sync" @@ -122,15 +123,24 @@ func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string switch msg.Type { case MsgJobAssign: var spec JobSpec - json.Unmarshal(msg.Payload, &spec) + if err := json.Unmarshal(msg.Payload, &spec); err != nil { + slog.Error("failed to unmarshal job assign", "error", err) + continue + } onJobAssign(&spec) case MsgJobCancel: var taskID string - json.Unmarshal(msg.Payload, &taskID) + if err := json.Unmarshal(msg.Payload, &taskID); err != nil { + slog.Error("failed to unmarshal job cancel", "error", err) + continue + } onJobCancel(taskID) case MsgPrewarmHint: var hint PrewarmHintPayload - json.Unmarshal(msg.Payload, &hint) + if err := json.Unmarshal(msg.Payload, &hint); err != nil { + slog.Error("failed to unmarshal prewarm hint", "error", err) + continue + } onPrewarmHint(hint) case MsgNoWork: // No action needed - worker will retry @@ -172,7 +182,9 @@ func (sc *SchedulerConn) Close() { sc.closed = true if sc.conn != nil { - sc.conn.Close() + if err := sc.conn.Close(); err != nil { + slog.Error("failed to close connection", "error", err) + } } close(sc.send) } diff --git a/internal/scheduler/service_manager.go b/internal/scheduler/service_manager.go index 9d10546..bc0dd11 100644 --- a/internal/scheduler/service_manager.go +++ b/internal/scheduler/service_manager.go @@ -162,7 +162,9 @@ func (sm *ServiceManager) stopService() { } // Try graceful termination first - sm.cmd.Process.Signal(syscall.SIGTERM) + if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil { + slog.Warn("failed to send SIGTERM", "error", err) + } // Wait for graceful shutdown or timeout done := make(chan error, 1) @@ -295,7 +297,9 @@ func (sm *ServiceManager) checkHTTPEndpoint(endpoint string, timeout time.Durati defer resp.Body.Close() // Drain body to allow connection reuse - 
io.Copy(io.Discard, resp.Body)
+	if _, err := io.Copy(io.Discard, resp.Body); err != nil {
+		slog.Warn("failed to drain response body", "error", err)
+	}
 
 	// 2xx status codes indicate success
 	return resp.StatusCode >= 200 && resp.StatusCode < 300
diff --git a/internal/scheduler/service_templates.go b/internal/scheduler/service_templates.go
index 3647bb3..203d1c8 100644
--- a/internal/scheduler/service_templates.go
+++ b/internal/scheduler/service_templates.go
@@ -72,12 +72,13 @@ var JupyterLabTemplate = ServiceTemplate{
 		"--port={{SERVICE_PORT}}",
 		"--no-browser",
 		"--allow-root",
-		"--NotebookApp.token='{{SECRET:jupyter_token}}'",
+		"--NotebookApp.token='{{TOKEN:jupyter}}'",
 		"--NotebookApp.password=''",
 	},
 	Env: map[string]string{
-		"JUPYTER_TOKEN":      "{{SECRET:jupyter_token}}",
+		// #nosec G101 -- Template placeholder, not a real credential
+		"JUPYTER_TOKEN":      "{{TOKEN:jupyter}}",
 		"JUPYTER_CONFIG_DIR": "/workspace/.jupyter",
 	},
@@ -105,11 +106,12 @@ var JupyterNotebookTemplate = ServiceTemplate{
 		"--port={{SERVICE_PORT}}",
 		"--no-browser",
 		"--allow-root",
-		"--NotebookApp.token='{{SECRET:jupyter_token}}'",
+		"--NotebookApp.token='{{TOKEN:jupyter}}'",
 	},
 	Env: map[string]string{
-		"JUPYTER_TOKEN": "{{SECRET:jupyter_token}}",
+		// #nosec G101 -- Template placeholder, not a real credential
+		"JUPYTER_TOKEN": "{{TOKEN:jupyter}}",
 	},
 	HealthCheck: ServiceHealthCheck{
diff --git a/internal/scheduler/state.go b/internal/scheduler/state.go
index 67a742f..dfd548c 100644
--- a/internal/scheduler/state.go
+++ b/internal/scheduler/state.go
@@ -40,11 +40,12 @@ type StateStore struct {
 // NewStateStore creates a new state store at the given path
 func NewStateStore(path string) (*StateStore, error) {
-	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
+	if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil {
 		return nil, fmt.Errorf("create state directory: %w", err)
 	}
-	file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+	// #nosec G304 -- path is the state file path, internally controlled
+	file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
 	if err != nil {
 		return nil, fmt.Errorf("open state file: %w", err)
 	}
@@ -93,7 +94,7 @@ func (s *StateStore) Replay() ([]StateEvent, error) {
 	if err != nil {
 		if os.IsNotExist(err) {
 			// Recreate the file for appending
-			s.file, _ = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+			s.file, _ = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
 			return nil, nil
 		}
 		return nil, fmt.Errorf("open state file for replay: %w", err)
 	}
@@ -116,7 +117,7 @@ func (s *StateStore) Replay() ([]StateEvent, error) {
 	}
 	// Reopen for appending
-	s.file, err = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+	s.file, err = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
 	if err != nil {
 		return nil, fmt.Errorf("reopen state file: %w", err)
 	}
@@ -137,7 +138,7 @@ func (s *StateStore) Rotate() (string, error) {
 	defer s.mu.Unlock()
 	backupPath := s.path + "." 
+ time.Now().Format("20060102_150405") + ".bak" - + if err := s.file.Close(); err != nil { return "", fmt.Errorf("close state file: %w", err) } @@ -146,7 +147,7 @@ func (s *StateStore) Rotate() (string, error) { return "", fmt.Errorf("rotate state file: %w", err) } - file, err := os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + file, err := os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return "", fmt.Errorf("create new state file: %w", err) } diff --git a/internal/worker/config.go b/internal/worker/config.go index ad9006f..525afa8 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -986,7 +986,14 @@ func envInt(name string) (int, bool) { // logEnvOverride logs environment variable overrides to stderr for debugging func logEnvOverride(name string, value interface{}) { - slog.Warn("env override active", "var", name, "value", value) + // Sanitize name to prevent log injection - only allow alphanumeric and underscore + cleanName := strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' { + return r + } + return '_' + }, name) + slog.Warn("env override active", "var", cleanName, "value", value) } // parseCPUFromConfig determines total CPU from environment or config diff --git a/internal/worker/execution/setup.go b/internal/worker/execution/setup.go index 79c1306..6998de4 100644 --- a/internal/worker/execution/setup.go +++ b/internal/worker/execution/setup.go @@ -206,12 +206,14 @@ func CopyDir(src, dst string) error { // copyFile copies a single file func copyFile(src, dst string, mode os.FileMode) error { + // #nosec G304 -- src is validated path for job files srcFile, err := os.Open(src) if err != nil { return err } defer srcFile.Close() + // #nosec G304 -- dst is validated output path for job files dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) if err != nil { return err diff --git a/internal/worker/executor/container.go b/internal/worker/executor/container.go index ac79f4e..9962400 100644 --- a/internal/worker/executor/container.go +++ b/internal/worker/executor/container.go @@ -239,7 +239,9 @@ func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputD } cacheRoot := filepath.Join(e.config.BasePath, ".cache") - os.MkdirAll(cacheRoot, 0750) + if err := os.MkdirAll(cacheRoot, 0750); err != nil { + e.logger.Warn("failed to create cache directory", "path", cacheRoot, "error", err) + } volumes[cacheRoot] = "/workspace/.cache:rw" defaultEnv := map[string]string{ @@ -400,10 +402,14 @@ func (e *ContainerExecutor) handleFailure( } failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) - os.MkdirAll(filepath.Dir(failedDir), 0750) - os.RemoveAll(failedDir) + if err := os.MkdirAll(filepath.Dir(failedDir), 0750); err != nil { + e.logger.Warn("failed to create failed directory", "path", filepath.Dir(failedDir), "error", err) + } + if err := os.RemoveAll(failedDir); err != nil { + e.logger.Warn("failed to remove failed directory", "path", failedDir, "error", err) + } - telemetry.ExecWithMetrics( + if _, err := telemetry.ExecWithMetrics( e.logger, "move failed job", 100*time.Millisecond, @@ -412,7 +418,9 @@ func (e *ContainerExecutor) handleFailure( return "", fmt.Errorf("rename to failed failed: %w", err) } return "", nil - }) + }); err != nil { + e.logger.Warn("failed to move failed job", "error", err) + } // Return enriched error with context return &errtypes.TaskExecutionError{ @@ -433,15 +441,15 @@ func (e 
*ContainerExecutor) handleSuccess( jobPaths *storage.JobPaths, duration time.Duration, ) error { + finalizeStart := time.Now() + finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) + if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { m.ExecutionDurationMS = duration.Milliseconds() }) } - finalizeStart := time.Now() - finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) - if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { now := time.Now().UTC() @@ -451,10 +459,14 @@ func (e *ContainerExecutor) handleSuccess( }) } - os.MkdirAll(filepath.Dir(finishedDir), 0750) - os.RemoveAll(finishedDir) + if err := os.MkdirAll(filepath.Dir(finishedDir), 0750); err != nil { + e.logger.Warn("failed to create finished directory", "path", filepath.Dir(finishedDir), "error", err) + } + if err := os.RemoveAll(finishedDir); err != nil { + e.logger.Warn("failed to remove finished directory", "path", finishedDir, "error", err) + } - telemetry.ExecWithMetrics( + if _, err := telemetry.ExecWithMetrics( e.logger, "finalize job", 100*time.Millisecond, @@ -463,7 +475,9 @@ func (e *ContainerExecutor) handleSuccess( return "", fmt.Errorf("rename to finished failed: %w", err) } return "", nil - }) + }); err != nil { + e.logger.Warn("failed to finalize job", "error", err) + } return nil } diff --git a/internal/worker/executor/runner.go b/internal/worker/executor/runner.go index 4e308f4..9b45bf7 100644 --- a/internal/worker/executor/runner.go +++ b/internal/worker/executor/runner.go @@ -145,8 +145,12 @@ func (r *JobRunner) finalize( if execErr != nil { // Handle failure failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) - os.MkdirAll(filepath.Dir(failedDir), 0750) - os.RemoveAll(failedDir) + if err := os.MkdirAll(filepath.Dir(failedDir), 0750); err != nil { + r.logger.Warn("failed to create failed directory", "path", filepath.Dir(failedDir), "error", err) + } + if err := os.RemoveAll(failedDir); err != nil { + r.logger.Warn("failed to remove failed directory", "path", failedDir, "error", err) + } if r.writer != nil { r.writer.Upsert(outputDir, task, func(m *manifest.RunManifest) { @@ -156,7 +160,9 @@ func (r *JobRunner) finalize( }) } - os.Rename(outputDir, failedDir) + if err := os.Rename(outputDir, failedDir); err != nil { + r.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", err) + } if taskErr, ok := execErr.(*errtypes.TaskExecutionError); ok { return taskErr @@ -171,8 +177,12 @@ func (r *JobRunner) finalize( // Handle success finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) - os.MkdirAll(filepath.Dir(finishedDir), 0750) - os.RemoveAll(finishedDir) + if err := os.MkdirAll(filepath.Dir(finishedDir), 0750); err != nil { + r.logger.Warn("failed to create finished directory", "path", filepath.Dir(finishedDir), "error", err) + } + if err := os.RemoveAll(finishedDir); err != nil { + r.logger.Warn("failed to remove finished directory", "path", finishedDir, "error", err) + } if r.writer != nil { r.writer.Upsert(outputDir, task, func(m *manifest.RunManifest) { diff --git a/internal/worker/gpu_detector.go b/internal/worker/gpu_detector.go index 6e61a86..81eaf3e 100644 --- a/internal/worker/gpu_detector.go +++ b/internal/worker/gpu_detector.go @@ -8,9 +8,12 @@ import ( "strings" ) -// logWarningf logs a warning message using slog +// logWarningf logs a warning message using slog with proper sanitization func logWarningf(format string, args ...any) { - 
slog.Warn(fmt.Sprintf(format, args...))
+	// Use structured logging to avoid log injection
+	// Format the message first, then log as a single string attribute
+	msg := fmt.Sprintf(format, args...)
+	slog.Warn("warning", "message", msg)
 }
 
 // GPUType represents different GPU types
diff --git a/internal/worker/integrity/validate.go b/internal/worker/integrity/validate.go
index 7cefd94..adcae3f 100644
--- a/internal/worker/integrity/validate.go
+++ b/internal/worker/integrity/validate.go
@@ -92,6 +92,7 @@ func (pc *ProvenanceCalculator) ComputeProvenance(task *queue.Task) (map[string]
 	// Get commit_id from metadata and read experiment manifest
 	if commitID := task.Metadata["commit_id"]; commitID != "" {
 		manifestPath := filepath.Join(pc.basePath, commitID, "manifest.json")
+		// #nosec G304 -- path is constructed from controlled basePath and validated commitID
 		if data, err := os.ReadFile(manifestPath); err == nil {
 			var manifest struct {
 				OverallSHA string `json:"overall_sha"`
diff --git a/internal/worker/lifecycle/runloop.go b/internal/worker/lifecycle/runloop.go
index 1133a9c..333615b 100644
--- a/internal/worker/lifecycle/runloop.go
+++ b/internal/worker/lifecycle/runloop.go
@@ -137,6 +137,7 @@ func (r *RunLoop) Stop() {
 func (r *RunLoop) reserveRunningSlot(taskID string) {
 	r.runningMu.Lock()
 	defer r.runningMu.Unlock()
+	// The CancelFunc is stored in r.running and invoked later by releaseRunningSlot
 	_, cancel := context.WithCancel(r.ctx)
 	r.running[taskID] = cancel
 }
diff --git a/internal/worker/lifecycle/service_manager.go b/internal/worker/lifecycle/service_manager.go
index 59f5ab8..6eead94 100644
--- a/internal/worker/lifecycle/service_manager.go
+++ b/internal/worker/lifecycle/service_manager.go
@@ -49,7 +49,9 @@ func (sm *ServiceManager) Run(ctx context.Context) error {
 	if err := sm.start(); err != nil {
 		sm.logger.Error("service start failed", "task_id", sm.task.ID, "error", err)
 		if sm.stateMgr != nil {
-			sm.stateMgr.Transition(sm.task, StateFailed)
+			if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
+				sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
+			}
 		}
 		return fmt.Errorf("start service: %w", err)
 	}
@@ -60,9 +62,11 @@
 	if err := sm.waitReady(readyCtx); err != nil {
 		sm.logger.Error("service readiness check failed", "task_id", sm.task.ID, "error", err)
 		if sm.stateMgr != nil {
-			sm.stateMgr.Transition(sm.task, StateFailed)
+			if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil {
+				sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err)
+			}
 		}
-		sm.stop()
+		_ = sm.stop()
 		return fmt.Errorf("wait ready: %w", err)
 	}
@@ -86,7 +90,7 @@ func (sm *ServiceManager) start() error {
 	}
 	sm.cmd = exec.Command(sm.spec.Command[0], sm.spec.Command[1:]...)
- + // Set environment for k, v := range sm.spec.Env { sm.cmd.Env = append(sm.cmd.Env, fmt.Sprintf("%s=%s", k, v)) @@ -146,7 +150,10 @@ func (sm *ServiceManager) healthLoop(ctx context.Context) error { if !sm.checkLiveness() { sm.logger.Error("service liveness check failed", "task_id", sm.task.ID) if sm.stateMgr != nil { - sm.stateMgr.Transition(sm.task, StateFailed) + if err := sm.stateMgr.Transition(sm.task, StateFailed); err != nil { + sm.logger.Error("failed to transition to failed", "task_id", sm.task.ID, "error", err) + return fmt.Errorf("transition to failed: %w", err) + } } return fmt.Errorf("liveness check failed") } @@ -196,12 +203,16 @@ func (sm *ServiceManager) gracefulStop() error { sm.logger.Info("gracefully stopping service", "task_id", sm.task.ID) if sm.stateMgr != nil { - sm.stateMgr.Transition(sm.task, StateStopping) + if err := sm.stateMgr.Transition(sm.task, StateStopping); err != nil { + sm.logger.Error("failed to transition to stopping", "task_id", sm.task.ID, "error", err) + } } if sm.cmd == nil || sm.cmd.Process == nil { if sm.stateMgr != nil { - sm.stateMgr.Transition(sm.task, StateCompleted) + if err := sm.stateMgr.Transition(sm.task, StateCompleted); err != nil { + sm.logger.Error("failed to transition to completed", "task_id", sm.task.ID, "error", err) + } } return nil } @@ -209,7 +220,9 @@ func (sm *ServiceManager) gracefulStop() error { // Send SIGTERM for graceful shutdown if err := sm.cmd.Process.Signal(syscall.SIGTERM); err != nil { sm.logger.Warn("SIGTERM failed, using SIGKILL", "task_id", sm.task.ID, "error", err) - sm.cmd.Process.Kill() + if err := sm.cmd.Process.Kill(); err != nil { + sm.logger.Error("failed to kill process", "task_id", sm.task.ID, "error", err) + } } else { // Wait for graceful shutdown done := make(chan error, 1) @@ -223,12 +236,16 @@ func (sm *ServiceManager) gracefulStop() error { case <-time.After(30 * time.Second): // Timeout - force kill sm.logger.Warn("graceful shutdown timeout, forcing kill", "task_id", sm.task.ID) - sm.cmd.Process.Kill() + if err := sm.cmd.Process.Kill(); err != nil { + sm.logger.Error("failed to kill process", "task_id", sm.task.ID, "error", err) + } } } if sm.stateMgr != nil { - sm.stateMgr.Transition(sm.task, StateCompleted) + if err := sm.stateMgr.Transition(sm.task, StateCompleted); err != nil { + sm.logger.Error("failed to transition to completed", "task_id", sm.task.ID, "error", err) + } } return nil diff --git a/internal/worker/process/isolation.go b/internal/worker/process/isolation.go index b9e0d67..30a2b57 100644 --- a/internal/worker/process/isolation.go +++ b/internal/worker/process/isolation.go @@ -50,7 +50,7 @@ func setOOMScoreAdj(score int) error { // Write to /proc/self/oom_score_adj path := "/proc/self/oom_score_adj" data := []byte(fmt.Sprintf("%d\n", score)) - return os.WriteFile(path, data, 0644) + return os.WriteFile(path, data, 0600) } // IsolatedExec runs a command with process isolation applied. 
diff --git a/internal/worker/process/isolation_unix.go b/internal/worker/process/isolation_unix.go
index 4f129ff..2881461 100644
--- a/internal/worker/process/isolation_unix.go
+++ b/internal/worker/process/isolation_unix.go
@@ -12,6 +12,10 @@ import (
 func applyResourceLimits(cfg IsolationConfig) error {
 	// Apply file descriptor limits (RLIMIT_NOFILE for FD exhaustion protection)
+	// Validate before the uint64 conversion to prevent overflow
+	if cfg.MaxOpenFiles < 0 {
+		return fmt.Errorf("max open files cannot be negative: %d", cfg.MaxOpenFiles)
+	}
 	if cfg.MaxOpenFiles > 0 {
 		if err := setResourceLimit(syscall.RLIMIT_NOFILE, uint64(cfg.MaxOpenFiles)); err != nil {
 			return fmt.Errorf("failed to set max open files limit: %w", err)
 		}
@@ -39,6 +43,10 @@
 // setProcessLimit sets RLIMIT_NPROC on Linux, no-op on other Unix
 func setProcessLimit(maxProcs int) error {
+	// Validate before conversion to prevent overflow
+	if maxProcs < 0 {
+		return fmt.Errorf("max processes cannot be negative: %d", maxProcs)
+	}
 	// Try to set RLIMIT_NPROC - only available on Linux
 	// On Darwin/macOS, this returns ENOTSUP
 	const RLIMIT_NPROC = 7 // Linux value
diff --git a/internal/worker/tenant/manager.go b/internal/worker/tenant/manager.go
index 1b72601..8125cce 100644
--- a/internal/worker/tenant/manager.go
+++ b/internal/worker/tenant/manager.go
@@ -15,43 +15,43 @@ import (
 // Tenant represents an isolated tenant in the multi-tenant system
 type Tenant struct {
-	ID string `json:"id"`
-	Name string `json:"name"`
-	CreatedAt time.Time `json:"created_at"`
-	Config TenantConfig `json:"config"`
-	Metadata map[string]string `json:"metadata"`
-	Active bool `json:"active"`
-	LastAccess time.Time `json:"last_access"`
+	ID         string            `json:"id"`
+	Name       string            `json:"name"`
+	CreatedAt  time.Time         `json:"created_at"`
+	Config     TenantConfig      `json:"config"`
+	Metadata   map[string]string `json:"metadata"`
+	Active     bool              `json:"active"`
+	LastAccess time.Time         `json:"last_access"`
 }
 
 // TenantConfig holds tenant-specific configuration
 type TenantConfig struct {
-	ResourceQuota ResourceQuota `json:"resource_quota"`
-	SecurityPolicy SecurityPolicy `json:"security_policy"`
-	IsolationLevel IsolationLevel `json:"isolation_level"`
-	AllowedImages []string `json:"allowed_images"`
-	AllowedNetworks []string `json:"allowed_networks"`
+	ResourceQuota   ResourceQuota  `json:"resource_quota"`
+	SecurityPolicy  SecurityPolicy `json:"security_policy"`
+	IsolationLevel  IsolationLevel `json:"isolation_level"`
+	AllowedImages   []string       `json:"allowed_images"`
+	AllowedNetworks []string       `json:"allowed_networks"`
 }
 
 // ResourceQuota defines resource limits per tenant
 type ResourceQuota struct {
-	MaxConcurrentJobs int `json:"max_concurrent_jobs"`
-	MaxGPUs int `json:"max_gpus"`
-	MaxMemoryGB int `json:"max_memory_gb"`
-	MaxStorageGB int `json:"max_storage_gb"`
-	MaxCPUCores int `json:"max_cpu_cores"`
-	MaxRuntimeHours int `json:"max_runtime_hours"`
-	MaxArtifactsPerHour int `json:"max_artifacts_per_hour"`
+	MaxConcurrentJobs   int `json:"max_concurrent_jobs"`
+	MaxGPUs             int `json:"max_gpus"`
+	MaxMemoryGB         int `json:"max_memory_gb"`
+	MaxStorageGB        int `json:"max_storage_gb"`
+	MaxCPUCores         int `json:"max_cpu_cores"`
+	MaxRuntimeHours     int `json:"max_runtime_hours"`
+	MaxArtifactsPerHour int `json:"max_artifacts_per_hour"`
 }
 
 // SecurityPolicy defines security constraints for a tenant
 type SecurityPolicy struct {
-	RequireEncryption bool `json:"require_encryption"`
-	RequireAuditLogging bool `json:"require_audit_logging"`
-	
RequireSandbox bool `json:"require_sandbox"` - ProhibitedPackages []string `json:"prohibited_packages"` - AllowedRegistries []string `json:"allowed_registries"` - NetworkPolicy string `json:"network_policy"` + RequireEncryption bool `json:"require_encryption"` + RequireAuditLogging bool `json:"require_audit_logging"` + RequireSandbox bool `json:"require_sandbox"` + ProhibitedPackages []string `json:"prohibited_packages"` + AllowedRegistries []string `json:"allowed_registries"` + NetworkPolicy string `json:"network_policy"` } // IsolationLevel defines the degree of tenant isolation @@ -68,12 +68,12 @@ const ( // Manager handles tenant lifecycle and isolation type Manager struct { - tenants map[string]*Tenant - mu sync.RWMutex - logger *logging.Logger - basePath string - quotas *QuotaManager - auditLog *AuditLogger + tenants map[string]*Tenant + mu sync.RWMutex + logger *logging.Logger + basePath string + quotas *QuotaManager + auditLog *AuditLogger } // NewManager creates a new tenant manager @@ -132,11 +132,13 @@ func (m *Manager) CreateTenant(ctx context.Context, id, name string, config Tena "isolation_level", config.IsolationLevel, ) - m.auditLog.LogEvent(ctx, AuditEvent{ - Type: AuditTenantCreated, - TenantID: id, + if err := m.auditLog.LogEvent(ctx, AuditEvent{ + Type: AuditTenantCreated, + TenantID: id, Timestamp: time.Now().UTC(), - }) + }); err != nil { + m.logger.Warn("failed to log audit event", "error", err) + } return tenant, nil } @@ -192,11 +194,13 @@ func (m *Manager) DeactivateTenant(ctx context.Context, id string) error { m.logger.Info("tenant deactivated", "tenant_id", id) - m.auditLog.LogEvent(ctx, AuditEvent{ + if err := m.auditLog.LogEvent(ctx, AuditEvent{ Type: AuditTenantDeactivated, TenantID: id, Timestamp: time.Now().UTC(), - }) + }); err != nil { + m.logger.Warn("failed to log audit event", "error", err) + } return nil } @@ -215,11 +219,13 @@ func (m *Manager) SanitizeForTenant(ctx context.Context, newTenantID string) err // - Reset environment variables // - Clear any in-memory state - m.auditLog.LogEvent(ctx, AuditEvent{ + if err := m.auditLog.LogEvent(ctx, AuditEvent{ Type: AuditWorkerSanitized, TenantID: newTenantID, Timestamp: time.Now().UTC(), - }) + }); err != nil { + m.logger.Warn("failed to log audit event", "error", err) + } return nil } diff --git a/internal/worker/tenant/middleware.go b/internal/worker/tenant/middleware.go index 51979a2..ed5ec3a 100644 --- a/internal/worker/tenant/middleware.go +++ b/internal/worker/tenant/middleware.go @@ -98,7 +98,7 @@ func (m *Middleware) Handler(next http.Handler) http.Handler { ) // Audit log - m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{ + if err := m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{ Type: AuditResourceAccess, TenantID: tenantID, Timestamp: time.Now().UTC(), @@ -108,7 +108,9 @@ func (m *Middleware) Handler(next http.Handler) http.Handler { "method": r.Method, }, IPAddress: extractIP(r.RemoteAddr), - }) + }); err != nil { + m.logger.Warn("failed to log audit event", "error", err) + } next.ServeHTTP(w, r.WithContext(ctx)) }) @@ -148,7 +150,7 @@ func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenan // Audit the denial userID := GetUserIDFromContext(ctx) - rac.tenantManager.auditLog.LogEvent(ctx, AuditEvent{ + if err := rac.tenantManager.auditLog.LogEvent(ctx, AuditEvent{ Type: AuditCrossTenantDeny, TenantID: requestingTenantID, UserID: userID, @@ -158,7 +160,9 @@ func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenan "target_tenant": 
resourceTenantID, "reason": "cross-tenant access not permitted", }, - }) + }); err != nil { + rac.logger.Warn("failed to log audit event", "error", err) + } return fmt.Errorf("cross-tenant access denied: cannot access resources belonging to tenant %s", resourceTenantID) } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 0a8c820..6a2995e 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -3,10 +3,11 @@ package worker import ( "context" + "crypto/rand" "encoding/json" "fmt" "log/slog" - "math/rand" + "math/big" "net/http" "os" "path/filepath" @@ -73,7 +74,11 @@ func (w *Worker) heartbeatLoop() { } // Add jitter (0-5s) to prevent thundering herd - jitter := time.Duration(rand.Intn(5)) * time.Second + jitterSecs, err := rand.Int(rand.Reader, big.NewInt(5)) + if err != nil { + jitterSecs = big.NewInt(0) + } + jitter := time.Duration(jitterSecs.Int64()) * time.Second interval := time.Duration(intervalSecs)*time.Second + jitter ticker := time.NewTicker(interval)
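
A quick sanity note on the heartbeat jitter math in the final hunk: crypto/rand.Int
draws uniformly from [0, max), so big.NewInt(5) yields 0-4 whole seconds of jitter.
A minimal standalone sketch of the resulting interval computation (the 30-second
base interval in main() is illustrative only, not taken from the worker config):

	package main

	import (
		"crypto/rand"
		"fmt"
		"math/big"
		"time"
	)

	// heartbeatInterval mirrors the patched worker.go logic: base interval plus
	// 0-4s of crypto/rand jitter, degrading to zero jitter (rather than aborting
	// the heartbeat loop) if the random source fails.
	func heartbeatInterval(intervalSecs int) time.Duration {
		jitterSecs, err := rand.Int(rand.Reader, big.NewInt(5))
		if err != nil {
			jitterSecs = big.NewInt(0)
		}
		return time.Duration(intervalSecs)*time.Second +
			time.Duration(jitterSecs.Int64())*time.Second
	}

	func main() {
		// With a 30s base, prints a duration between 30s and 34s inclusive.
		fmt.Println(heartbeatInterval(30))
	}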