// Integration tests for the scheduler component for distributed ML workload
// orchestration (refs: scheduler-architecture.md). The scheduler provides:
// hub-based coordination for multi-worker clusters, a pacing controller for
// rate limiting job submissions, a priority queue with preemption support,
// a port allocator for dynamic service discovery, protocol handlers for
// worker-scheduler communication, a service manager with OS-specific
// implementations, connection management and state persistence, and a
// template system for service deployment.
package scheduler_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestMultiNodeGangAllocation validates 2-node torchrun scenario
|
|
func TestMultiNodeGangAllocation(t *testing.T) {
|
|
// Create scheduler hub with gang timeout and auth tokens
|
|
tokens := map[string]string{
|
|
"worker1-token": "worker-1",
|
|
"worker2-token": "worker-2",
|
|
}
|
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
|
BindAddr: "localhost:0",
|
|
StateDir: t.TempDir(),
|
|
DefaultBatchSlots: 4,
|
|
GangAllocTimeoutSecs: 10,
|
|
WorkerTokens: tokens,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
defer hub.Stop()
|
|
|
|
// Start scheduler
|
|
err = hub.Start()
|
|
require.NoError(t, err)
|
|
|
|
// Get scheduler address
|
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
|
wsURL := u.String()
|
|
|
|
// Create two worker connections with auth
|
|
worker1, recv1 := createTestWorkerWithToken(t, wsURL, "worker-1", "worker1-token")
|
|
worker2, recv2 := createTestWorkerWithToken(t, wsURL, "worker-2", "worker2-token")
|
|
defer worker1.Close()
|
|
defer worker2.Close()
|
|
|
|
// Register both workers
|
|
workers := []struct {
|
|
conn *websocket.Conn
|
|
recv <-chan scheduler.Message
|
|
id string
|
|
}{
|
|
{worker1, recv1, "worker-1"},
|
|
{worker2, recv2, "worker-2"},
|
|
}
|
|
for _, w := range workers {
|
|
w.conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
|
ID: w.id,
|
|
Capabilities: scheduler.WorkerCapabilities{
|
|
GPUCount: 0,
|
|
},
|
|
}),
|
|
})
|
|
msg := <-w.recv
|
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
|
}
|
|
|
|
// Submit multi-node job (2 nodes)
|
|
jobID := "gang-job-001"
|
|
err = hub.SubmitJob(scheduler.JobSpec{
|
|
ID: jobID,
|
|
Type: scheduler.JobTypeBatch,
|
|
SlotPool: "batch",
|
|
NodeCount: 2,
|
|
Command: []string{"torchrun", "--nnodes=2", "train.py"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Both workers signal ready
|
|
for _, w := range []struct {
|
|
conn *websocket.Conn
|
|
id string
|
|
}{{worker1, "worker-1"}, {worker2, "worker-2"}} {
|
|
w.conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgReadyForWork,
|
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
|
WorkerID: w.id,
|
|
Slots: scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0},
|
|
}),
|
|
})
|
|
}
|
|
|
|
// Both workers should receive job assignments
|
|
msg1 := <-recv1
|
|
msg2 := <-recv2
|
|
|
|
require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
|
|
require.Equal(t, scheduler.MsgJobAssign, msg2.Type)
|
|
|
|
// Verify both got the same job
|
|
var spec1, spec2 scheduler.JobSpec
|
|
json.Unmarshal(msg1.Payload, &spec1)
|
|
json.Unmarshal(msg2.Payload, &spec2)
|
|
|
|
assert.Equal(t, jobID, spec1.ID)
|
|
assert.Equal(t, jobID, spec2.ID)
|
|
|
|
// Verify ranks are different
|
|
assert.NotEqual(t, spec1.Env["NODE_RANK"], spec2.Env["NODE_RANK"])
|
|
assert.Equal(t, "2", spec1.Env["WORLD_SIZE"])
|
|
assert.Equal(t, "2", spec2.Env["WORLD_SIZE"])
|
|
}
|
|
|
|
// TestServiceLifecycle validates service job start, health checks, and stop
|
|
func TestServiceLifecycle(t *testing.T) {
|
|
testToken := "service-test-token"
|
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
|
BindAddr: "localhost:0",
|
|
StateDir: t.TempDir(),
|
|
DefaultBatchSlots: 4,
|
|
WorkerTokens: map[string]string{
|
|
testToken: "service-worker",
|
|
},
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
defer hub.Stop()
|
|
|
|
err = hub.Start()
|
|
require.NoError(t, err)
|
|
|
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
|
wsURL := u.String()
|
|
|
|
// Create worker with auth
|
|
conn, recvCh := createTestWorkerWithToken(t, wsURL, "service-worker", testToken)
|
|
defer conn.Close()
|
|
|
|
// Register
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
|
ID: "service-worker",
|
|
}),
|
|
})
|
|
msg := <-recvCh
|
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
|
|
|
// Submit service job
|
|
jobID := "service-001"
|
|
err = hub.SubmitJob(scheduler.JobSpec{
|
|
ID: jobID,
|
|
Type: scheduler.JobTypeService,
|
|
SlotPool: "service",
|
|
Command: []string{"python", "-m", "http.server", "8080"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Signal ready
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgReadyForWork,
|
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
|
WorkerID: "service-worker",
|
|
Slots: scheduler.SlotStatus{ServiceTotal: 4, ServiceInUse: 0},
|
|
}),
|
|
})
|
|
|
|
// Should receive job assignment
|
|
assignMsg := <-recvCh
|
|
require.Equal(t, scheduler.MsgJobAssign, assignMsg.Type)
|
|
|
|
// Send job accepted
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgJobAccepted,
|
|
Payload: mustMarshal(map[string]string{
|
|
"task_id": jobID,
|
|
}),
|
|
})
|
|
|
|
// Send periodic health updates
|
|
for i := 0; i < 3; i++ {
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgServiceHealth,
|
|
Payload: mustMarshal(scheduler.ServiceHealthPayload{
|
|
TaskID: jobID,
|
|
Healthy: true,
|
|
Message: "healthy",
|
|
}),
|
|
})
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
|
|
// Verify task exists and is running
|
|
task := hub.GetTask(jobID)
|
|
require.NotNil(t, task)
|
|
assert.Equal(t, "running", task.Status)
|
|
}
|
|
|
|
// TestStarvationPrevention validates low-priority jobs eventually get scheduled
|
|
func TestStarvationPrevention(t *testing.T) {
|
|
testToken := "starvation-test-token"
|
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
|
BindAddr: "localhost:0",
|
|
StateDir: t.TempDir(),
|
|
DefaultBatchSlots: 2,
|
|
StarvationThresholdMins: 1, // 1 minute for testing
|
|
WorkerTokens: map[string]string{
|
|
testToken: "starvation-worker",
|
|
},
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
defer hub.Stop()
|
|
|
|
err = hub.Start()
|
|
require.NoError(t, err)
|
|
|
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
|
wsURL := u.String()
|
|
|
|
// Create worker with auth
|
|
conn, recvCh := createTestWorkerWithToken(t, wsURL, "starvation-worker", testToken)
|
|
defer conn.Close()
|
|
|
|
// Register
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
|
ID: "starvation-worker",
|
|
}),
|
|
})
|
|
msg := <-recvCh
|
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
|
|
|
// Submit high-priority job
|
|
err = hub.SubmitJob(scheduler.JobSpec{
|
|
ID: "high-priority-job",
|
|
Type: scheduler.JobTypeBatch,
|
|
Env: map[string]string{"priority": "100"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Submit low-priority job
|
|
err = hub.SubmitJob(scheduler.JobSpec{
|
|
ID: "low-priority-job",
|
|
Type: scheduler.JobTypeBatch,
|
|
Env: map[string]string{"priority": "1"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Signal ready - should get high priority job first
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgReadyForWork,
|
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
|
WorkerID: "starvation-worker",
|
|
Slots: scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0},
|
|
}),
|
|
})
|
|
|
|
// First assignment should be high priority
|
|
msg1 := <-recvCh
|
|
require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
|
|
|
|
var spec1 scheduler.JobSpec
|
|
json.Unmarshal(msg1.Payload, &spec1)
|
|
assert.Equal(t, "high-priority-job", spec1.ID)
|
|
|
|
// Complete first job
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgJobResult,
|
|
Payload: mustMarshal(scheduler.JobResultPayload{
|
|
TaskID: "high-priority-job",
|
|
State: "completed",
|
|
}),
|
|
})
|
|
|
|
// Signal ready again
|
|
conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgReadyForWork,
|
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
|
WorkerID: "starvation-worker",
|
|
Slots: scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0},
|
|
}),
|
|
})
|
|
|
|
// Should get low priority job
|
|
msg2 := <-recvCh
|
|
require.Equal(t, scheduler.MsgJobAssign, msg2.Type)
|
|
|
|
var spec2 scheduler.JobSpec
|
|
json.Unmarshal(msg2.Payload, &spec2)
|
|
assert.Equal(t, "low-priority-job", spec2.ID)
|
|
}
|
|
|
|
// Helper function to create test worker with token auth
|
|
func createTestWorkerWithToken(t *testing.T, wsURL, workerID, token string) (*websocket.Conn, <-chan scheduler.Message) {
|
|
headers := http.Header{}
|
|
headers.Set("Authorization", "Bearer "+token)
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
require.NoError(t, err)
|
|
|
|
recvCh := make(chan scheduler.Message, 10)
|
|
go func() {
|
|
for {
|
|
var msg scheduler.Message
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
close(recvCh)
|
|
return
|
|
}
|
|
recvCh <- msg
|
|
}
|
|
}()
|
|
|
|
return conn, recvCh
|
|
}
|