fetch_ml/tests/integration/scheduler/gang_service_test.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

316 lines
8.1 KiB
Go

package scheduler_test
import (
"encoding/json"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/scheduler"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMultiNodeGangAllocation validates the 2-node torchrun scenario: two
// authenticated workers register, a job with NodeCount 2 is submitted, both
// workers report ready, and each must receive an assignment for the same job
// with WORLD_SIZE "2" and a distinct NODE_RANK.
func TestMultiNodeGangAllocation(t *testing.T) {
	// Upper bound for any single receive. It is deliberately longer than the
	// GangAllocTimeoutSecs (10) configured below so this guard is never the
	// first timer to expire.
	const recvTimeout = 15 * time.Second

	// awaitMsg returns the next message from ch, failing the test (instead of
	// hanging until the global go-test timeout) when nothing arrives or the
	// reader goroutine closed the channel because the connection dropped.
	awaitMsg := func(ch <-chan scheduler.Message, what string) scheduler.Message {
		t.Helper()
		select {
		case msg, ok := <-ch:
			require.True(t, ok, "connection closed while waiting for %s", what)
			return msg
		case <-time.After(recvTimeout):
			t.Fatalf("timed out after %s waiting for %s", recvTimeout, what)
		}
		return scheduler.Message{}
	}

	// Create scheduler hub with gang timeout and auth tokens
	tokens := map[string]string{
		"worker1-token": "worker-1",
		"worker2-token": "worker-2",
	}
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:             "localhost:0",
		StateDir:             t.TempDir(),
		DefaultBatchSlots:    4,
		GangAllocTimeoutSecs: 10,
		WorkerTokens:         tokens,
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler
	err = hub.Start()
	require.NoError(t, err)

	// Get scheduler address
	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Create two worker connections with auth
	worker1, recv1 := createTestWorkerWithToken(t, wsURL, "worker-1", "worker1-token")
	worker2, recv2 := createTestWorkerWithToken(t, wsURL, "worker-2", "worker2-token")
	defer worker1.Close()
	defer worker2.Close()

	workers := []struct {
		conn *websocket.Conn
		recv <-chan scheduler.Message
		id   string
	}{
		{worker1, recv1, "worker-1"},
		{worker2, recv2, "worker-2"},
	}

	// Register both workers; each registration must be acknowledged before
	// the next step.
	for _, w := range workers {
		err := w.conn.WriteJSON(scheduler.Message{
			Type: scheduler.MsgRegister,
			Payload: mustMarshal(scheduler.WorkerRegistration{
				ID: w.id,
				Capabilities: scheduler.WorkerCapabilities{
					GPUCount: 0,
				},
			}),
		})
		require.NoError(t, err, "sending registration for %s", w.id)

		msg := awaitMsg(w.recv, "registration ack for "+w.id)
		require.Equal(t, scheduler.MsgAck, msg.Type)
	}

	// Submit multi-node job (2 nodes)
	jobID := "gang-job-001"
	err = hub.SubmitJob(scheduler.JobSpec{
		ID:        jobID,
		Type:      scheduler.JobTypeBatch,
		SlotPool:  "batch",
		NodeCount: 2,
		Command:   []string{"torchrun", "--nnodes=2", "train.py"},
	})
	require.NoError(t, err)

	// Both workers signal ready
	for _, w := range workers {
		err := w.conn.WriteJSON(scheduler.Message{
			Type: scheduler.MsgReadyForWork,
			Payload: mustMarshal(scheduler.ReadyPayload{
				WorkerID: w.id,
				Slots:    scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0},
			}),
		})
		require.NoError(t, err, "sending ready-for-work for %s", w.id)
	}

	// Both workers should receive job assignments
	msg1 := awaitMsg(recv1, "job assignment for worker-1")
	msg2 := awaitMsg(recv2, "job assignment for worker-2")
	require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
	require.Equal(t, scheduler.MsgJobAssign, msg2.Type)

	// Verify both got the same job. Decode failures are fatal: asserting on a
	// zero-valued spec would only produce a confusing secondary failure.
	var spec1, spec2 scheduler.JobSpec
	require.NoError(t, json.Unmarshal(msg1.Payload, &spec1))
	require.NoError(t, json.Unmarshal(msg2.Payload, &spec2))
	assert.Equal(t, jobID, spec1.ID)
	assert.Equal(t, jobID, spec2.ID)

	// Verify ranks are different
	assert.NotEqual(t, spec1.Env["NODE_RANK"], spec2.Env["NODE_RANK"])
	assert.Equal(t, "2", spec1.Env["WORLD_SIZE"])
	assert.Equal(t, "2", spec2.Env["WORLD_SIZE"])
}
// TestServiceLifecycle validates the service job flow from a worker's point
// of view: register, receive the assignment for a submitted service job,
// accept it, report health three times, and then confirm the hub reports the
// task with status "running".
func TestServiceLifecycle(t *testing.T) {
	// Upper bound for any single receive so a silent hub fails the test
	// quickly instead of hanging until the global go-test timeout.
	const recvTimeout = 5 * time.Second

	// awaitMsg returns the next message from ch, failing the test on timeout
	// or when the reader goroutine closed the channel (connection dropped).
	awaitMsg := func(ch <-chan scheduler.Message, what string) scheduler.Message {
		t.Helper()
		select {
		case msg, ok := <-ch:
			require.True(t, ok, "connection closed while waiting for %s", what)
			return msg
		case <-time.After(recvTimeout):
			t.Fatalf("timed out after %s waiting for %s", recvTimeout, what)
		}
		return scheduler.Message{}
	}

	testToken := "service-test-token"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:          "localhost:0",
		StateDir:          t.TempDir(),
		DefaultBatchSlots: 4,
		WorkerTokens: map[string]string{
			testToken: "service-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	err = hub.Start()
	require.NoError(t, err)

	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Create worker with auth
	conn, recvCh := createTestWorkerWithToken(t, wsURL, "service-worker", testToken)
	defer conn.Close()

	// Register
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: "service-worker",
		}),
	})
	require.NoError(t, err, "sending registration")

	msg := awaitMsg(recvCh, "registration ack")
	require.Equal(t, scheduler.MsgAck, msg.Type)

	// Submit service job
	jobID := "service-001"
	err = hub.SubmitJob(scheduler.JobSpec{
		ID:       jobID,
		Type:     scheduler.JobTypeService,
		SlotPool: "service",
		Command:  []string{"python", "-m", "http.server", "8080"},
	})
	require.NoError(t, err)

	// Signal ready
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgReadyForWork,
		Payload: mustMarshal(scheduler.ReadyPayload{
			WorkerID: "service-worker",
			Slots:    scheduler.SlotStatus{ServiceTotal: 4, ServiceInUse: 0},
		}),
	})
	require.NoError(t, err, "sending ready-for-work")

	// Should receive job assignment
	assignMsg := awaitMsg(recvCh, "service job assignment")
	require.Equal(t, scheduler.MsgJobAssign, assignMsg.Type)

	// Send job accepted
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgJobAccepted,
		Payload: mustMarshal(map[string]string{
			"task_id": jobID,
		}),
	})
	require.NoError(t, err, "sending job accepted")

	// Send periodic health updates. The pause after each update also gives
	// the hub time to process the preceding messages before the status check.
	for i := 0; i < 3; i++ {
		err := conn.WriteJSON(scheduler.Message{
			Type: scheduler.MsgServiceHealth,
			Payload: mustMarshal(scheduler.ServiceHealthPayload{
				TaskID:  jobID,
				Healthy: true,
				Message: "healthy",
			}),
		})
		require.NoError(t, err, "sending health update %d", i)
		time.Sleep(50 * time.Millisecond)
	}

	// Verify task exists and is running
	task := hub.GetTask(jobID)
	require.NotNil(t, task)
	assert.Equal(t, "running", task.Status)
}
// TestStarvationPrevention validates that a low-priority job is still
// scheduled: with one worker and two queued batch jobs, the first assignment
// must be the high-priority job and, once that job is reported completed and
// the worker signals ready again, the second must be the low-priority job.
func TestStarvationPrevention(t *testing.T) {
	// Upper bound for any single receive so a silent hub fails the test
	// quickly instead of hanging until the global go-test timeout.
	const recvTimeout = 5 * time.Second

	// awaitMsg returns the next message from ch, failing the test on timeout
	// or when the reader goroutine closed the channel (connection dropped).
	awaitMsg := func(ch <-chan scheduler.Message, what string) scheduler.Message {
		t.Helper()
		select {
		case msg, ok := <-ch:
			require.True(t, ok, "connection closed while waiting for %s", what)
			return msg
		case <-time.After(recvTimeout):
			t.Fatalf("timed out after %s waiting for %s", recvTimeout, what)
		}
		return scheduler.Message{}
	}

	testToken := "starvation-test-token"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:                "localhost:0",
		StateDir:                t.TempDir(),
		DefaultBatchSlots:       2,
		StarvationThresholdMins: 1, // 1 minute for testing
		WorkerTokens: map[string]string{
			testToken: "starvation-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	err = hub.Start()
	require.NoError(t, err)

	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Create worker with auth
	conn, recvCh := createTestWorkerWithToken(t, wsURL, "starvation-worker", testToken)
	defer conn.Close()

	// The same ready-for-work message is sent twice below.
	readyMsg := scheduler.Message{
		Type: scheduler.MsgReadyForWork,
		Payload: mustMarshal(scheduler.ReadyPayload{
			WorkerID: "starvation-worker",
			Slots:    scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0},
		}),
	}

	// Register
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: "starvation-worker",
		}),
	})
	require.NoError(t, err, "sending registration")

	msg := awaitMsg(recvCh, "registration ack")
	require.Equal(t, scheduler.MsgAck, msg.Type)

	// Submit high-priority job
	err = hub.SubmitJob(scheduler.JobSpec{
		ID:   "high-priority-job",
		Type: scheduler.JobTypeBatch,
		Env:  map[string]string{"priority": "100"},
	})
	require.NoError(t, err)

	// Submit low-priority job
	err = hub.SubmitJob(scheduler.JobSpec{
		ID:   "low-priority-job",
		Type: scheduler.JobTypeBatch,
		Env:  map[string]string{"priority": "1"},
	})
	require.NoError(t, err)

	// Signal ready - should get high priority job first
	require.NoError(t, conn.WriteJSON(readyMsg), "sending first ready-for-work")

	// First assignment should be high priority
	msg1 := awaitMsg(recvCh, "first job assignment")
	require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
	var spec1 scheduler.JobSpec
	require.NoError(t, json.Unmarshal(msg1.Payload, &spec1))
	assert.Equal(t, "high-priority-job", spec1.ID)

	// Complete first job
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgJobResult,
		Payload: mustMarshal(scheduler.JobResultPayload{
			TaskID: "high-priority-job",
			State:  "completed",
		}),
	})
	require.NoError(t, err, "sending job result")

	// Signal ready again
	require.NoError(t, conn.WriteJSON(readyMsg), "sending second ready-for-work")

	// Should get low priority job
	msg2 := awaitMsg(recvCh, "second job assignment")
	require.Equal(t, scheduler.MsgJobAssign, msg2.Type)
	var spec2 scheduler.JobSpec
	require.NoError(t, json.Unmarshal(msg2.Payload, &spec2))
	assert.Equal(t, "low-priority-job", spec2.ID)
}
// createTestWorkerWithToken dials wsURL with an "Authorization: Bearer <token>"
// header and starts a goroutine that decodes incoming JSON messages into the
// returned channel.
//
// The channel is buffered (10) and is closed when a read fails, e.g. after the
// connection is closed, so receivers observe ok == false at that point.
// workerID is not sent on the wire here; it only labels the failure message
// if the dial fails. The test is failed immediately when the dial fails.
//
// The connection is also closed via t.Cleanup, so it is released even when
// the caller fails before registering its own deferred Close.
func createTestWorkerWithToken(t *testing.T, wsURL, workerID, token string) (*websocket.Conn, <-chan scheduler.Message) {
	t.Helper()

	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+token)

	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err, "dialing %s as %s", wsURL, workerID)

	recvCh := make(chan scheduler.Message, 10)
	done := make(chan struct{})
	t.Cleanup(func() {
		// Unblock the reader if it is parked on a full recvCh, then make sure
		// the socket is released. Callers usually close conn themselves first,
		// so any error from this extra Close is intentionally ignored.
		close(done)
		_ = conn.Close()
	})

	go func() {
		// Only this goroutine sends on recvCh, so it is the one to close it.
		defer close(recvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			select {
			case recvCh <- msg:
			case <-done:
				return
			}
		}
	}()

	return conn, recvCh
}