Add new scheduler component for distributed ML workload orchestration: - Hub-based coordination for multi-worker clusters - Pacing controller for rate limiting job submissions - Priority queue with preemption support - Port allocator for dynamic service discovery - Protocol handlers for worker-scheduler communication - Service manager with OS-specific implementations - Connection management and state persistence - Template system for service deployment Includes comprehensive test suite: - Unit tests for all core components - Integration tests for distributed scenarios - Benchmark tests for performance validation - Mock fixtures for isolated testing Refs: scheduler-architecture.md
245 lines
6.6 KiB
Go
245 lines
6.6 KiB
Go
package scheduler
|
|
|
|
import (
	"fmt"
	"os"
	"regexp"
	"strconv"
	"strings"
)
|
|
|
|
// TemplateResolver handles variable substitution in job specifications.
// Template variables are resolved at dispatch time on the worker.
//
// Supported variables:
//
//	{{HEAD_ADDR}}    - Hostname of rank-0 worker (for multi-node)
//	{{WORLD_SIZE}}   - Total node count (for multi-node)
//	{{NODE_RANK}}    - 0-based rank of this worker (for multi-node)
//	{{GPU_COUNT}}    - Number of GPUs available on this worker
//	{{SERVICE_PORT}} - Port assigned by PortAllocator (for service jobs)
//	{{HOSTNAME}}     - This worker's hostname
//	{{TASK_ID}}      - The task/job ID
//	{{SECRET:name}}  - Secret from worker's secret store

// TemplateContext provides the values for template substitution.
type TemplateContext struct {
	HeadAddr    string            // Rank-0 worker hostname (multi-node)
	WorldSize   int               // Total nodes (multi-node)
	NodeRank    int               // This worker's rank (multi-node)
	GPUCount    int               // GPUs available
	ServicePort int               // Assigned port (service jobs)
	Hostname    string            // This worker's hostname
	TaskID      string            // Task/job ID
	Secrets     map[string]string // Secret store
}

var (
	// templatePattern matches {{VAR}} or {{SECRET:name}}.
	// Group 1 is the variable name; group 2 is the optional ":arg" part.
	templatePattern = regexp.MustCompile(`\{\{(\w+)(?::([^}]+))?\}\}`)
)

// Resolve substitutes template variables in a string.
//
// Variables whose value is unavailable (unset HEAD_ADDR/WORLD_SIZE/
// SERVICE_PORT, unknown names, missing secrets, or a hostname lookup
// failure) are left unresolved in the output rather than being replaced
// with an empty string. The error return is currently always nil; it is
// kept for interface stability.
func (tc *TemplateContext) Resolve(input string) (string, error) {
	// Fast path: no template markers present.
	if !strings.Contains(input, "{{") {
		return input, nil
	}

	result := templatePattern.ReplaceAllStringFunc(input, func(match string) string {
		// match was produced by templatePattern, so the submatch is
		// guaranteed to be [full, varName, secretName]; secretName is ""
		// unless the {{NAME:arg}} form was used.
		parts := templatePattern.FindStringSubmatch(match)
		if parts == nil {
			return match // defensive: should be unreachable
		}
		varName, secretName := parts[1], parts[2]

		switch varName {
		case "HEAD_ADDR":
			if tc.HeadAddr == "" {
				return match // keep unresolved if not set
			}
			return tc.HeadAddr
		case "WORLD_SIZE":
			if tc.WorldSize == 0 {
				return match
			}
			return strconv.Itoa(tc.WorldSize)
		case "NODE_RANK":
			// Rank 0 is a valid value, so no zero-check here.
			return strconv.Itoa(tc.NodeRank)
		case "GPU_COUNT":
			return strconv.Itoa(tc.GPUCount)
		case "SERVICE_PORT":
			if tc.ServicePort == 0 {
				return match
			}
			return strconv.Itoa(tc.ServicePort)
		case "HOSTNAME":
			if tc.Hostname == "" {
				// Lazily resolve and cache the hostname. On lookup failure,
				// keep the template unresolved instead of substituting "".
				hn, err := os.Hostname()
				if err != nil {
					return match
				}
				tc.Hostname = hn
			}
			return tc.Hostname
		case "TASK_ID":
			return tc.TaskID
		case "SECRET":
			if secretName == "" {
				return match // {{SECRET}} without a name is malformed
			}
			if val, ok := tc.Secrets[secretName]; ok {
				return val
			}
			return match // keep unresolved if secret not found
		default:
			return match // unknown variable - keep as-is
		}
	})

	return result, nil
}
|
|
|
|
// ResolveCommand resolves templates in a command slice
|
|
func (tc *TemplateContext) ResolveCommand(cmd []string) ([]string, error) {
|
|
result := make([]string, len(cmd))
|
|
for i, arg := range cmd {
|
|
resolved, err := tc.Resolve(arg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve arg %d: %w", i, err)
|
|
}
|
|
result[i] = resolved
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// ResolveEnv resolves templates in environment variables
|
|
func (tc *TemplateContext) ResolveEnv(env map[string]string) (map[string]string, error) {
|
|
result := make(map[string]string, len(env))
|
|
for k, v := range env {
|
|
resolved, err := tc.Resolve(v)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve env %s: %w", k, err)
|
|
}
|
|
result[k] = resolved
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// ResolveJobSpec resolves all templates in a JobSpec.
// Returns a new JobSpec with all template variables substituted; the
// input spec is not modified. Templates are resolved in Command, Env,
// Prolog, Epilog, and the HealthCheck endpoints.
func (tc *TemplateContext) ResolveJobSpec(spec *JobSpec) (*JobSpec, error) {
	// Copy the scalar fields and allocate a fresh Metadata map so the
	// resolved spec shares no mutable state with the input.
	// NOTE(review): only the fields listed here are copied — if JobSpec
	// gains new fields they will be silently dropped; confirm against the
	// JobSpec definition.
	resolved := &JobSpec{
		ID:          spec.ID,
		Type:        spec.Type,
		SlotPool:    spec.SlotPool,
		GPUCount:    spec.GPUCount,
		GPUType:     spec.GPUType,
		NodeCount:   spec.NodeCount,
		SnapshotID:  spec.SnapshotID,
		SnapshotSHA: spec.SnapshotSHA,
		Metadata:    make(map[string]string, len(spec.Metadata)),
	}

	// Copy metadata verbatim.
	// NOTE(review): metadata values are NOT template-resolved — confirm
	// this is intentional (templates in metadata pass through unchanged).
	for k, v := range spec.Metadata {
		resolved.Metadata[k] = v
	}

	// Resolve command arguments. A nil/empty command stays nil in the copy.
	if len(spec.Command) > 0 {
		cmd, err := tc.ResolveCommand(spec.Command)
		if err != nil {
			return nil, fmt.Errorf("resolve command: %w", err)
		}
		resolved.Command = cmd
	}

	// Resolve environment variable values (keys are untouched).
	if len(spec.Env) > 0 {
		env, err := tc.ResolveEnv(spec.Env)
		if err != nil {
			return nil, fmt.Errorf("resolve env: %w", err)
		}
		resolved.Env = env
	}

	// Resolve prolog command (runs before the job, per the field name —
	// semantics defined elsewhere).
	if len(spec.Prolog) > 0 {
		prolog, err := tc.ResolveCommand(spec.Prolog)
		if err != nil {
			return nil, fmt.Errorf("resolve prolog: %w", err)
		}
		resolved.Prolog = prolog
	}

	// Resolve epilog command.
	if len(spec.Epilog) > 0 {
		epilog, err := tc.ResolveCommand(spec.Epilog)
		if err != nil {
			return nil, fmt.Errorf("resolve epilog: %w", err)
		}
		resolved.Epilog = epilog
	}

	// Resolve health check endpoints (e.g. ones embedding
	// {{SERVICE_PORT}} or {{HOSTNAME}}). A new HealthCheck is built so the
	// input's is never mutated.
	if spec.HealthCheck != nil {
		hc := &HealthCheck{
			LivenessEndpoint:  spec.HealthCheck.LivenessEndpoint,
			ReadinessEndpoint: spec.HealthCheck.ReadinessEndpoint,
			IntervalSecs:      spec.HealthCheck.IntervalSecs,
		}

		if hc.LivenessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.LivenessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve liveness endpoint: %w", err)
			}
			hc.LivenessEndpoint = endpoint
		}

		if hc.ReadinessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.ReadinessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve readiness endpoint: %w", err)
			}
			hc.ReadinessEndpoint = endpoint
		}

		resolved.HealthCheck = hc
	}

	return resolved, nil
}
|
|
|
|
// NewMultiNodeContext creates a template context for a multi-node job
|
|
func NewMultiNodeContext(taskID, headAddr string, worldSize, nodeRank, gpuCount int) *TemplateContext {
|
|
hostname, _ := os.Hostname()
|
|
return &TemplateContext{
|
|
TaskID: taskID,
|
|
HeadAddr: headAddr,
|
|
WorldSize: worldSize,
|
|
NodeRank: nodeRank,
|
|
GPUCount: gpuCount,
|
|
Hostname: hostname,
|
|
Secrets: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// NewServiceContext creates a template context for a service job
|
|
func NewServiceContext(taskID string, servicePort, gpuCount int) *TemplateContext {
|
|
hostname, _ := os.Hostname()
|
|
return &TemplateContext{
|
|
TaskID: taskID,
|
|
ServicePort: servicePort,
|
|
GPUCount: gpuCount,
|
|
Hostname: hostname,
|
|
Secrets: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// SetSecret adds a secret to the context
|
|
func (tc *TemplateContext) SetSecret(name, value string) {
|
|
if tc.Secrets == nil {
|
|
tc.Secrets = make(map[string]string)
|
|
}
|
|
tc.Secrets[name] = value
|
|
}
|