package scheduler

import (
	"fmt"
	"os"
	"regexp"
	"strings"
)

// Template resolution performs variable substitution in job specifications.
// Template variables are resolved at dispatch time on the worker.
//
// Supported variables:
//
//	{{HEAD_ADDR}}    - Hostname of rank-0 worker (for multi-node)
//	{{WORLD_SIZE}}   - Total node count (for multi-node)
//	{{NODE_RANK}}    - 0-based rank of this worker (for multi-node)
//	{{GPU_COUNT}}    - Number of GPUs available on this worker
//	{{SERVICE_PORT}} - Port assigned by PortAllocator (for service jobs)
//	{{HOSTNAME}}     - This worker's hostname
//	{{TASK_ID}}      - The task/job ID
//	{{SECRET:name}}  - Secret from worker's secret store

// TemplateContext provides the values for template substitution.
type TemplateContext struct {
	HeadAddr    string            // Rank-0 worker hostname (multi-node)
	WorldSize   int               // Total nodes (multi-node)
	NodeRank    int               // This worker's rank (multi-node)
	GPUCount    int               // GPUs available
	ServicePort int               // Assigned port (service jobs)
	Hostname    string            // This worker's hostname
	TaskID      string            // Task/job ID
	Secrets     map[string]string // Secret store
}

var (
	// templatePattern matches {{VAR}} and {{VAR:arg}}, where VAR is a run of
	// word characters and arg is any non-empty text that contains no '}'.
	// Capture group 1 is VAR; group 2 is arg (empty when absent).
	templatePattern = regexp.MustCompile(`\{\{(\w+)(?::([^}]+))?\}\}`)
)

// Resolve substitutes template variables in input and returns the result.
//
// A placeholder is left in the output verbatim when its variable is unknown
// or has no usable value: HEAD_ADDR is empty, WORLD_SIZE or SERVICE_PORT is
// zero, the named secret is missing, or the hostname cannot be determined.
// NODE_RANK, GPU_COUNT and TASK_ID are always substituted, even when they
// hold their zero value. The ":arg" suffix is only consulted for SECRET and
// is ignored for every other variable.
//
// Resolving {{HOSTNAME}} while tc.Hostname is empty queries os.Hostname and
// stores a successful result in tc.Hostname, so Resolve may write to the
// receiver.
//
// The returned error is nil on every path; it exists so that callers already
// handle failures should resolution ever need to report one.
func (tc *TemplateContext) Resolve(input string) (string, error) {
	if !strings.Contains(input, "{{") {
		return input, nil // Fast path: nothing that could be a template.
	}

	result := templatePattern.ReplaceAllStringFunc(input, func(match string) string {
		// ReplaceAllStringFunc only hands over the matched text, so match it
		// again to recover the capture groups.
		parts := templatePattern.FindStringSubmatch(match)
		if len(parts) < 3 {
			return match // Defensive: keep the original text if it does not re-match.
		}
		if val, ok := tc.lookupVar(parts[1], parts[2]); ok {
			return val
		}
		return match
	})

	return result, nil
}

// lookupVar returns the value for the template variable name, with arg being
// the optional text after ':' (used as the secret name for SECRET). The
// boolean is false when the placeholder should be left unresolved.
func (tc *TemplateContext) lookupVar(name, arg string) (string, bool) {
	switch name {
	case "HEAD_ADDR":
		return tc.HeadAddr, tc.HeadAddr != ""
	case "WORLD_SIZE":
		if tc.WorldSize == 0 {
			return "", false
		}
		return fmt.Sprintf("%d", tc.WorldSize), true
	case "NODE_RANK":
		// Zero is a legitimate rank, so it is always substituted.
		return fmt.Sprintf("%d", tc.NodeRank), true
	case "GPU_COUNT":
		return fmt.Sprintf("%d", tc.GPUCount), true
	case "SERVICE_PORT":
		if tc.ServicePort == 0 {
			return "", false
		}
		return fmt.Sprintf("%d", tc.ServicePort), true
	case "HOSTNAME":
		if tc.Hostname == "" {
			host, err := os.Hostname()
			if err != nil || host == "" {
				// Keep the placeholder rather than silently substituting
				// an empty hostname.
				return "", false
			}
			tc.Hostname = host // Cache so later lookups skip the system call.
		}
		return tc.Hostname, true
	case "TASK_ID":
		return tc.TaskID, true
	case "SECRET":
		// Reading from a nil map is safe and simply reports "not found".
		val, ok := tc.Secrets[arg]
		return val, ok
	default:
		return "", false // Unknown variable - keep as-is.
	}
}

// ResolveCommand resolves templates in every element of cmd and returns a
// new slice of the same length; cmd itself is not modified.
func (tc *TemplateContext) ResolveCommand(cmd []string) ([]string, error) {
	result := make([]string, len(cmd))
	for i, arg := range cmd {
		resolved, err := tc.Resolve(arg)
		if err != nil {
			return nil, fmt.Errorf("resolve arg %d: %w", i, err)
		}
		result[i] = resolved
	}
	return result, nil
}

// ResolveEnv resolves templates in the values of env and returns a new map.
// Keys are copied unchanged; env itself is not modified.
func (tc *TemplateContext) ResolveEnv(env map[string]string) (map[string]string, error) {
	result := make(map[string]string, len(env))
	for k, v := range env {
		resolved, err := tc.Resolve(v)
		if err != nil {
			return nil, fmt.Errorf("resolve env %s: %w", k, err)
		}
		result[k] = resolved
	}
	return result, nil
}

// ResolveJobSpec returns a new JobSpec built from spec with template
// variables substituted. spec itself is not modified.
//
// ID, Type, SlotPool, GPUCount, GPUType, NodeCount, SnapshotID and
// SnapshotSHA are copied as-is, and Metadata is copied into a fresh map
// without resolution. Command, Prolog, Epilog and the values of Env are
// resolved when non-empty. When HealthCheck is set, a new HealthCheck is
// created with IntervalSecs copied and both endpoints resolved.
//
// NOTE(review): only the fields named above are carried over; if JobSpec or
// HealthCheck declare further fields they are left at their zero value in
// the result - confirm against the type definitions.
//
// A nil spec yields an error.
func (tc *TemplateContext) ResolveJobSpec(spec *JobSpec) (*JobSpec, error) {
	if spec == nil {
		return nil, fmt.Errorf("resolve job spec: nil spec")
	}

	resolved := &JobSpec{
		ID:          spec.ID,
		Type:        spec.Type,
		SlotPool:    spec.SlotPool,
		GPUCount:    spec.GPUCount,
		GPUType:     spec.GPUType,
		NodeCount:   spec.NodeCount,
		SnapshotID:  spec.SnapshotID,
		SnapshotSHA: spec.SnapshotSHA,
		Metadata:    make(map[string]string, len(spec.Metadata)),
	}

	// Copy metadata so the result does not share the map with spec.
	for k, v := range spec.Metadata {
		resolved.Metadata[k] = v
	}

	// Resolve command
	if len(spec.Command) > 0 {
		cmd, err := tc.ResolveCommand(spec.Command)
		if err != nil {
			return nil, fmt.Errorf("resolve command: %w", err)
		}
		resolved.Command = cmd
	}

	// Resolve environment
	if len(spec.Env) > 0 {
		env, err := tc.ResolveEnv(spec.Env)
		if err != nil {
			return nil, fmt.Errorf("resolve env: %w", err)
		}
		resolved.Env = env
	}

	// Resolve prolog
	if len(spec.Prolog) > 0 {
		prolog, err := tc.ResolveCommand(spec.Prolog)
		if err != nil {
			return nil, fmt.Errorf("resolve prolog: %w", err)
		}
		resolved.Prolog = prolog
	}

	// Resolve epilog
	if len(spec.Epilog) > 0 {
		epilog, err := tc.ResolveCommand(spec.Epilog)
		if err != nil {
			return nil, fmt.Errorf("resolve epilog: %w", err)
		}
		resolved.Epilog = epilog
	}

	// Resolve health check endpoints
	if spec.HealthCheck != nil {
		hc := &HealthCheck{
			LivenessEndpoint:  spec.HealthCheck.LivenessEndpoint,
			ReadinessEndpoint: spec.HealthCheck.ReadinessEndpoint,
			IntervalSecs:      spec.HealthCheck.IntervalSecs,
		}

		if hc.LivenessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.LivenessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve liveness endpoint: %w", err)
			}
			hc.LivenessEndpoint = endpoint
		}

		if hc.ReadinessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.ReadinessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve readiness endpoint: %w", err)
			}
			hc.ReadinessEndpoint = endpoint
		}

		resolved.HealthCheck = hc
	}

	return resolved, nil
}

// NewMultiNodeContext creates a template context for a multi-node job.
// ServicePort is left at zero and Secrets starts as an empty map.
func NewMultiNodeContext(taskID, headAddr string, worldSize, nodeRank, gpuCount int) *TemplateContext {
	// Best effort: if the lookup fails Hostname stays empty and Resolve
	// retries the lookup the first time {{HOSTNAME}} is used.
	hostname, _ := os.Hostname()
	return &TemplateContext{
		TaskID:    taskID,
		HeadAddr:  headAddr,
		WorldSize: worldSize,
		NodeRank:  nodeRank,
		GPUCount:  gpuCount,
		Hostname:  hostname,
		Secrets:   make(map[string]string),
	}
}

// NewServiceContext creates a template context for a service job.
// The multi-node fields (HeadAddr, WorldSize, NodeRank) are left at their
// zero values and Secrets starts as an empty map.
func NewServiceContext(taskID string, servicePort, gpuCount int) *TemplateContext {
	// Best effort: if the lookup fails Hostname stays empty and Resolve
	// retries the lookup the first time {{HOSTNAME}} is used.
	hostname, _ := os.Hostname()
	return &TemplateContext{
		TaskID:      taskID,
		ServicePort: servicePort,
		GPUCount:    gpuCount,
		Hostname:    hostname,
		Secrets:     make(map[string]string),
	}
}

// SetSecret adds a secret to the context, replacing any existing value
// stored under name. The Secrets map is created on first use, so the method
// also works on a zero-value TemplateContext.
func (tc *TemplateContext) SetSecret(name, value string) {
	if tc.Secrets == nil {
		tc.Secrets = make(map[string]string)
	}
	tc.Secrets[name] = value
}