fetch_ml/tests/integration/scheduler/distributed_test.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

248 lines
5.7 KiB
Go

package scheduler_test
import (
"encoding/json"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/scheduler"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDistributedRoundTrip drives a mock worker through the handshake with a
// live hub: authenticated WebSocket dial, registration (expecting an ack),
// one heartbeat, and a ready-for-work signal. It then checks that the hub does
// not drop the connection during a short observation window.
//
// No job is submitted here; only the worker-side handshake is exercised.
func TestDistributedRoundTrip(t *testing.T) {
	// Create scheduler hub with token auth configured
	testToken := "test-token-123"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:              "localhost:0",
		StateDir:              t.TempDir(),
		DefaultBatchSlots:     4,
		AcceptanceTimeoutSecs: 5,
		WorkerTokens: map[string]string{
			testToken: "test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler
	err = hub.Start()
	require.NoError(t, err)

	// Get scheduler address - use the actual listening address
	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Create mock worker connection with auth token
	workerID := "test-worker"
	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	// Start receive goroutine. recvCh is closed when the read side fails
	// (which includes the peer closing the connection); done lets the
	// goroutine exit instead of blocking on a send once the test has returned.
	recvCh := make(chan scheduler.Message, 10)
	done := make(chan struct{})
	defer close(done)
	go func() {
		defer close(recvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			select {
			case recvCh <- msg:
			case <-done:
				return
			}
		}
	}()

	// Register worker
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 0,
				GPUType:  "",
			},
		}),
	})
	require.NoError(t, err)

	// Wait for ack. A closed channel yields a zero-value message, so check ok
	// to report a dropped connection as such rather than as a type mismatch.
	select {
	case msg, ok := <-recvCh:
		require.True(t, ok, "connection closed before registration ack")
		require.Equal(t, scheduler.MsgAck, msg.Type)
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for registration ack")
	}

	// Send heartbeat to show we're alive
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgHeartbeat,
		Payload: mustMarshal(scheduler.HeartbeatPayload{
			WorkerID: workerID,
			Slots: scheduler.SlotStatus{
				BatchTotal: 4,
				BatchInUse: 0,
			},
		}),
	})
	require.NoError(t, err)

	// Signal ready for work
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgReadyForWork,
		Payload: mustMarshal(scheduler.ReadyPayload{
			WorkerID: workerID,
			Slots: scheduler.SlotStatus{
				BatchTotal: 4,
				BatchInUse: 0,
			},
			Reason: "polling",
		}),
	})
	require.NoError(t, err)

	// Observe the connection for a short window. Any messages the hub pushes
	// are drained and ignored; the only failure is the read side shutting down.
	idle := time.After(200 * time.Millisecond)
	for {
		select {
		case _, ok := <-recvCh:
			require.True(t, ok, "connection closed after ready-for-work")
		case <-idle:
			return
		}
	}
}
// TestWorkerRegistration connects an authenticated worker to a live hub,
// sends a registration that advertises GPU/CPU/memory capabilities, and
// expects the hub to answer with an ack message.
func TestWorkerRegistration(t *testing.T) {
	testToken := "reg-test-token"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:          "localhost:0",
		StateDir:          t.TempDir(),
		DefaultBatchSlots: 4,
		WorkerTokens: map[string]string{
			testToken: "reg-test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler
	err = hub.Start()
	require.NoError(t, err)

	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Connect worker with auth token
	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	// Receive channel. It is closed when the read side fails; done lets the
	// reader exit instead of blocking on a send after the test has returned.
	recvCh := make(chan scheduler.Message, 10)
	done := make(chan struct{})
	defer close(done)
	go func() {
		defer close(recvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			select {
			case recvCh <- msg:
			case <-done:
				return
			}
		}
	}()

	// Register with capabilities
	workerID := "reg-test-worker"
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 2,
				GPUType:  "nvidia",
				CPUCount: 8,
				MemoryGB: 32.0,
			},
		}),
	})
	require.NoError(t, err)

	// Expect ack. Check ok so that a dropped connection is reported as such
	// instead of as a zero-value message with the wrong type.
	select {
	case msg, ok := <-recvCh:
		require.True(t, ok, "connection closed before registration ack")
		assert.Equal(t, scheduler.MsgAck, msg.Type)
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ack")
	}
}
// TestHeartbeat registers a worker with a live hub and then sends a series of
// heartbeats reporting increasing batch-slot usage. It checks that the
// registration is acked and that the hub keeps the connection open throughout.
//
// The hub's internal view of the reported slots is not inspected here.
func TestHeartbeat(t *testing.T) {
	testToken := "hb-test-token"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:          "localhost:0",
		StateDir:          t.TempDir(),
		DefaultBatchSlots: 4,
		WorkerTokens: map[string]string{
			testToken: "hb-test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler
	err = hub.Start()
	require.NoError(t, err)

	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	// Read incoming messages in the background so that a connection dropped
	// by the hub becomes observable (recvCh is closed on any read error).
	// done lets the reader exit instead of blocking once the test has returned.
	recvCh := make(chan scheduler.Message, 10)
	done := make(chan struct{})
	defer close(done)
	go func() {
		defer close(recvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			select {
			case recvCh <- msg:
			case <-done:
				return
			}
		}
	}()

	workerID := "hb-test-worker"

	// Register first
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 0,
			},
		}),
	})
	require.NoError(t, err)

	// Heartbeats are only meaningful once registration has been accepted.
	select {
	case msg, ok := <-recvCh:
		require.True(t, ok, "connection closed before registration ack")
		require.Equal(t, scheduler.MsgAck, msg.Type)
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for registration ack")
	}

	// Send multiple heartbeats
	slots := []scheduler.SlotStatus{
		{BatchTotal: 4, BatchInUse: 0},
		{BatchTotal: 4, BatchInUse: 1},
		{BatchTotal: 4, BatchInUse: 2},
	}
	for _, slot := range slots {
		err = conn.WriteJSON(scheduler.Message{
			Type: scheduler.MsgHeartbeat,
			Payload: mustMarshal(scheduler.HeartbeatPayload{
				WorkerID: workerID,
				Slots:    slot,
			}),
		})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
	}

	// Connection should remain healthy: drain anything the hub sent and fail
	// only if the read side shut down during or shortly after the heartbeats.
	idle := time.After(100 * time.Millisecond)
	for {
		select {
		case _, ok := <-recvCh:
			require.True(t, ok, "connection closed while sending heartbeats")
		case <-idle:
			return
		}
	}
}
// mustMarshal returns the JSON encoding of v.
//
// It panics if v cannot be encoded, so a malformed test fixture fails loudly
// at the call site instead of being sent as an empty payload.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return b
}