fetch_ml/internal/scheduler/template.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

245 lines
6.6 KiB
Go

package scheduler
import (
"fmt"
"os"
"regexp"
"strings"
)
// TemplateContext holds the values substituted into job specifications by
// Resolve, ResolveCommand, ResolveEnv and ResolveJobSpec.
// Template variables are resolved at dispatch time on the worker.
//
// Supported variables (and the field each one reads):
//
//	{{HEAD_ADDR}}    - HeadAddr: hostname of the rank-0 worker (multi-node)
//	{{WORLD_SIZE}}   - WorldSize: total node count (multi-node)
//	{{NODE_RANK}}    - NodeRank: 0-based rank of this worker (multi-node)
//	{{GPU_COUNT}}    - GPUCount: number of GPUs available on this worker
//	{{SERVICE_PORT}} - ServicePort: port assigned by PortAllocator (service jobs)
//	{{HOSTNAME}}     - Hostname: this worker's hostname
//	{{TASK_ID}}      - TaskID: the task/job ID
//	{{SECRET:name}}  - Secrets[name]: secret from the worker's secret store
//
// A placeholder whose value is unavailable is left in the output unchanged.
// Examples are an empty HeadAddr, a zero WorldSize or ServicePort, a secret
// missing from Secrets, or an unrecognised variable name.
type TemplateContext struct {
HeadAddr string // Rank-0 worker hostname; "" leaves {{HEAD_ADDR}} unresolved
WorldSize int // Total nodes; 0 leaves {{WORLD_SIZE}} unresolved
NodeRank int // This worker's 0-based rank; always substituted, including 0
GPUCount int // GPUs available; always substituted, including 0
ServicePort int // Assigned port (service jobs); 0 leaves {{SERVICE_PORT}} unresolved
Hostname string // This worker's hostname; when "" Resolve falls back to os.Hostname()
TaskID string // Task/job ID; always substituted, even when ""
Secrets map[string]string // Secret store keyed by the name in {{SECRET:name}}; may be nil (SetSecret allocates it)
}
// templatePattern recognises a placeholder of the form {{NAME}} or
// {{NAME:arg}}. Submatch 1 is NAME (one or more word characters). The
// optional submatch 2 is arg, which may be anything except '}'. It is used
// for {{SECRET:name}}.
var templatePattern = regexp.MustCompile(`\{\{(\w+)(?::([^}]+))?\}\}`)
// Resolve substitutes the template variables documented on TemplateContext
// in input and returns the result.
//
// A placeholder is left in the output verbatim when its value is not
// available. That happens when HEAD_ADDR is empty, WORLD_SIZE or
// SERVICE_PORT is zero, a SECRET name is absent from Secrets, the variable
// name is unknown, or Hostname is empty and os.Hostname cannot supply one.
// NODE_RANK, GPU_COUNT and TASK_ID are always substituted, even when
// zero or empty.
//
// Resolve only reads tc; it never writes to it. The error result is
// currently always nil. It is part of the signature so callers already
// propagate failures.
func (tc *TemplateContext) Resolve(input string) (string, error) {
	if !strings.Contains(input, "{{") {
		return input, nil // No templates to resolve
	}
	// Single regex pass. Each entry is
	// [start, end, nameStart, nameEnd, argStart, argEnd]; the arg indices are
	// -1 when the optional ":arg" suffix is absent.
	matches := templatePattern.FindAllStringSubmatchIndex(input, -1)
	if len(matches) == 0 {
		return input, nil
	}
	var b strings.Builder
	b.Grow(len(input))
	prev := 0
	for _, m := range matches {
		b.WriteString(input[prev:m[0]])
		name := input[m[2]:m[3]]
		arg := ""
		if m[4] >= 0 {
			arg = input[m[4]:m[5]]
		}
		if val, ok := tc.resolveVar(name, arg); ok {
			b.WriteString(val)
		} else {
			b.WriteString(input[m[0]:m[1]]) // keep the placeholder unresolved
		}
		prev = m[1]
	}
	b.WriteString(input[prev:])
	return b.String(), nil
}

// resolveVar returns the substitution for the template variable name and
// reports whether it is available. arg is the optional ":arg" suffix and is
// consulted only by SECRET.
func (tc *TemplateContext) resolveVar(name, arg string) (string, bool) {
	switch name {
	case "HEAD_ADDR":
		return tc.HeadAddr, tc.HeadAddr != ""
	case "WORLD_SIZE":
		return fmt.Sprintf("%d", tc.WorldSize), tc.WorldSize != 0
	case "NODE_RANK":
		return fmt.Sprintf("%d", tc.NodeRank), true
	case "GPU_COUNT":
		return fmt.Sprintf("%d", tc.GPUCount), true
	case "SERVICE_PORT":
		return fmt.Sprintf("%d", tc.ServicePort), tc.ServicePort != 0
	case "HOSTNAME":
		if tc.Hostname != "" {
			return tc.Hostname, true
		}
		// Look the hostname up locally instead of caching it on tc. Writing
		// to the receiver here would race with other goroutines resolving
		// against the same context.
		host, err := os.Hostname()
		return host, err == nil && host != ""
	case "TASK_ID":
		return tc.TaskID, true
	case "SECRET":
		val, ok := tc.Secrets[arg] // reading a nil map is safe
		return val, ok
	default:
		return "", false // unknown variable
	}
}
// ResolveCommand returns a new slice with each element of cmd passed
// through Resolve. cmd itself is left untouched. If any element fails to
// resolve, the error is wrapped with that element's index and no slice is
// returned.
func (tc *TemplateContext) ResolveCommand(cmd []string) ([]string, error) {
	out := make([]string, len(cmd))
	for idx := range cmd {
		arg, err := tc.Resolve(cmd[idx])
		if err != nil {
			return nil, fmt.Errorf("resolve arg %d: %w", idx, err)
		}
		out[idx] = arg
	}
	return out, nil
}
// ResolveEnv returns a new map with the same keys as env and every value
// passed through Resolve. Keys are not templated. If a value fails to
// resolve, the error is wrapped with its key and no map is returned.
func (tc *TemplateContext) ResolveEnv(env map[string]string) (map[string]string, error) {
	out := make(map[string]string, len(env))
	for name, raw := range env {
		val, err := tc.Resolve(raw)
		if err != nil {
			return nil, fmt.Errorf("resolve env %s: %w", name, err)
		}
		out[name] = val
	}
	return out, nil
}
// ResolveJobSpec returns a new JobSpec in which template variables have been
// substituted. This covers Command, Env, Prolog, Epilog and the HealthCheck
// liveness/readiness endpoints. spec itself is not modified.
//
// Every other field is carried over from spec by value. Reference-typed
// fields not listed here are therefore shared with spec. Metadata and
// HealthCheck are copied so the result does not alias them. Command, Env,
// Prolog and Epilog are freshly built by the resolvers and are nil when
// empty in spec.
//
// An error is returned if spec is nil or if any substitution fails.
func (tc *TemplateContext) ResolveJobSpec(spec *JobSpec) (*JobSpec, error) {
	if spec == nil {
		return nil, fmt.Errorf("resolve job spec: nil spec")
	}

	// Start from a value copy so that no JobSpec field is silently dropped.
	// Then reset the fields this function rebuilds below.
	resolved := *spec
	resolved.Command = nil
	resolved.Env = nil
	resolved.Prolog = nil
	resolved.Epilog = nil
	resolved.HealthCheck = nil

	// Copy metadata so callers can mutate the result independently.
	resolved.Metadata = make(map[string]string, len(spec.Metadata))
	for k, v := range spec.Metadata {
		resolved.Metadata[k] = v
	}

	if len(spec.Command) > 0 {
		cmd, err := tc.ResolveCommand(spec.Command)
		if err != nil {
			return nil, fmt.Errorf("resolve command: %w", err)
		}
		resolved.Command = cmd
	}

	if len(spec.Env) > 0 {
		env, err := tc.ResolveEnv(spec.Env)
		if err != nil {
			return nil, fmt.Errorf("resolve env: %w", err)
		}
		resolved.Env = env
	}

	if len(spec.Prolog) > 0 {
		prolog, err := tc.ResolveCommand(spec.Prolog)
		if err != nil {
			return nil, fmt.Errorf("resolve prolog: %w", err)
		}
		resolved.Prolog = prolog
	}

	if len(spec.Epilog) > 0 {
		epilog, err := tc.ResolveCommand(spec.Epilog)
		if err != nil {
			return nil, fmt.Errorf("resolve epilog: %w", err)
		}
		resolved.Epilog = epilog
	}

	if spec.HealthCheck != nil {
		// Copy the whole struct so fields other than the endpoints survive.
		hc := *spec.HealthCheck
		if hc.LivenessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.LivenessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve liveness endpoint: %w", err)
			}
			hc.LivenessEndpoint = endpoint
		}
		if hc.ReadinessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.ReadinessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve readiness endpoint: %w", err)
			}
			hc.ReadinessEndpoint = endpoint
		}
		resolved.HealthCheck = &hc
	}

	return &resolved, nil
}
// NewMultiNodeContext builds a TemplateContext for one worker of a
// multi-node job.
//
// It sets the task ID, the rank-0 head address, the world size, this
// worker's rank and its GPU count. Hostname is taken from os.Hostname, and
// Secrets starts as an empty, non-nil map.
func NewMultiNodeContext(taskID, headAddr string, worldSize, nodeRank, gpuCount int) *TemplateContext {
	tc := &TemplateContext{
		TaskID:    taskID,
		HeadAddr:  headAddr,
		WorldSize: worldSize,
		NodeRank:  nodeRank,
		GPUCount:  gpuCount,
		Secrets:   map[string]string{},
	}
	// Best effort: if the lookup fails, Hostname keeps whatever os.Hostname
	// returned. Resolve then tries os.Hostname again when it sees "".
	tc.Hostname, _ = os.Hostname()
	return tc
}
// NewServiceContext builds a TemplateContext for a service job.
//
// It sets the task ID, the port assigned to the service and the GPU count.
// Hostname is taken from os.Hostname, and Secrets starts as an empty,
// non-nil map. The multi-node fields are left at their zero values.
func NewServiceContext(taskID string, servicePort, gpuCount int) *TemplateContext {
	tc := &TemplateContext{
		TaskID:      taskID,
		ServicePort: servicePort,
		GPUCount:    gpuCount,
		Secrets:     map[string]string{},
	}
	// Best effort: if the lookup fails, Hostname keeps whatever os.Hostname
	// returned. Resolve then tries os.Hostname again when it sees "".
	tc.Hostname, _ = os.Hostname()
	return tc
}
// SetSecret stores value under name so that {{SECRET:name}} resolves to it.
// An existing entry with the same name is overwritten. The Secrets map is
// allocated on first use, so this works on a zero-value TemplateContext.
func (tc *TemplateContext) SetSecret(name, value string) {
	store := tc.Secrets
	if store == nil {
		store = make(map[string]string)
		tc.Secrets = store
	}
	store[name] = value
}