fetch_ml/tests/long_running/scheduler_long_running_test.go
Jeremie Fraeys 6af85ddaf6
feat(tests): enable stress and long-running test suites
Stress Tests:
- TestStress_WorkerConnectBurst: 30 workers, p99 latency validation
- TestStress_JobSubmissionBurst: 1K job submissions
- TestStress_WorkerChurn: 50 connect/disconnect cycles, memory leak detection
- TestStress_ConcurrentScheduling: 10 workers x 20 jobs contention

Long-Running Tests:
- TestLongRunning_MemoryLeak: heap growth monitoring
- TestLongRunning_OrphanRecovery: worker death/requeue stability
- TestLongRunning_WebSocketStability: 20 worker connection stability

Infrastructure:
- Add testreport package with JSON output, flaky test tracking
- Add TestTimer for timing/budget enforcement
- Add WaitForEvent, WaitForTaskStatus helpers
- Fix worker IDs to use valid bench-worker token patterns
2026-03-12 14:05:45 -04:00

364 lines
10 KiB
Go

// Package longrunning provides long-running tests for the scheduler
// These tests are designed for nightly CI runs and are advisory only - failures
// alert the team but don't block releases.
//
// To run long-running tests: go test -v ./tests/long_running/... -run TestLongRunning
// These tests are skipped in short mode (go test -short)
//
// Environment variables:
//
// LONG_RUNNING_DURATION - Override test duration (e.g., "5m" for 5 minutes)
// LONG_RUNNING_WORKERS - Number of workers to use (default varies by test)
package longrunning
import (
"fmt"
"os"
"runtime"
"sync"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/scheduler"
fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// getDuration returns the test duration from environment or default
func getDuration(defaultDuration time.Duration) time.Duration {
if d := os.Getenv("LONG_RUNNING_DURATION"); d != "" {
if parsed, err := time.ParseDuration(d); err == nil {
return parsed
}
}
return defaultDuration
}
// TestLongRunning_MemoryLeak monitors heap growth over extended period
// Validates that the scheduler doesn't leak memory under sustained load.
//
// The test drives the scheduler with an initial batch of jobs, then cycles
// workers and submits fresh jobs on a 10s ticker for the configured duration.
// Heap usage is sampled before and after; growth above 5MB/min fails the
// (advisory) assertion.
func TestLongRunning_MemoryLeak(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping long-running test in short mode")
	}
	duration := getDuration(2 * time.Minute) // Default 2 min for CI, use LONG_RUNNING_DURATION for longer

	cfg := fixtures.DefaultHubConfig()
	fixture := fixtures.NewSchedulerTestFixture(t, cfg)
	defer fixture.Cleanup()

	numWorkers := 10
	workers := make([]*fixtures.MockWorker, numWorkers)

	// Create initial workers (use bench-worker-* pattern with valid tokens)
	for i := 0; i < numWorkers; i++ {
		workerID := fmt.Sprintf("bench-worker-%d", i)
		workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{
			GPUBackend: scheduler.BackendNVIDIA,
			GPUCount:   4,
			CPUCount:   8,
		})
	}

	// Submit some initial jobs
	for i := 0; i < 50; i++ {
		fixture.SubmitJob(scheduler.JobSpec{
			ID:       fmt.Sprintf("memleak-job-%d", i),
			Type:     scheduler.JobTypeBatch,
			SlotPool: "batch",
			GPUCount: 1,
		})
	}

	// Signal workers ready
	for _, w := range workers {
		w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready")
	}

	// Collect baseline memory. GC twice: the first pass may only queue
	// finalizers, the second collects the garbage they release, giving a
	// stable HeapAlloc reading.
	runtime.GC()
	runtime.GC()
	var m1 runtime.MemStats
	runtime.ReadMemStats(&m1)
	t.Logf("Baseline heap: %d bytes", m1.HeapAlloc)

	// Run for duration, cycling workers and jobs. The duration check runs
	// after each tick, so the loop may overrun by up to one 10s tick.
	start := time.Now()
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	cycle := 0
	for range ticker.C {
		if time.Since(start) >= duration {
			break
		}
		cycle++
		// Every 30 seconds, submit new batch of jobs
		if cycle%3 == 0 {
			for i := 0; i < 10; i++ {
				fixture.SubmitJob(scheduler.JobSpec{
					ID:       fmt.Sprintf("memleak-job-cycle%d-%d", cycle, i),
					Type:     scheduler.JobTypeBatch,
					SlotPool: "batch",
					GPUCount: 1,
				})
			}
		}
		// Every minute, recycle half the workers under fresh IDs
		// (offset keeps IDs within the valid bench-worker token range 0-999).
		if cycle%6 == 0 {
			for i := 0; i < numWorkers/2; i++ {
				workers[i].Close()
				workerID := fmt.Sprintf("bench-worker-%d", (i+cycle*10)%1000)
				workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{
					GPUBackend: scheduler.BackendNVIDIA,
					GPUCount:   4,
					CPUCount:   8, // match initial capabilities so capacity stays constant across recycles
				})
				workers[i].SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready")
			}
		}
		// Process any pending messages to keep connections alive. The empty
		// job ID presumably resolves to the mock's current assignment -
		// confirm against the MockWorker contract.
		for _, w := range workers {
			select {
			case msg := <-w.RecvCh:
				if msg.Type == scheduler.MsgJobAssign {
					w.AcceptJob("")
					w.CompleteJob("", 0, "")
					w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready")
				}
			default:
			}
		}
	}

	// Cleanup workers
	for _, w := range workers {
		w.Close()
	}

	// Final memory check: double GC as above, plus a short settle window
	// before sampling.
	runtime.GC()
	runtime.GC()
	time.Sleep(200 * time.Millisecond)
	var m2 runtime.MemStats
	runtime.ReadMemStats(&m2)

	elapsed := time.Since(start)
	growth := int64(m2.HeapAlloc) - int64(m1.HeapAlloc)
	growthPerMinute := float64(growth) / elapsed.Minutes()
	t.Logf("Memory leak test completed: %v elapsed", elapsed)
	t.Logf("Heap growth: %d bytes (%.0f bytes/min)", growth, growthPerMinute)

	// Allow 5MB per minute growth max
	maxGrowthPerMinute := float64(5 * 1024 * 1024)
	assert.Less(t, growthPerMinute, maxGrowthPerMinute,
		"memory growth should be less than 5MB/min (possible leak)")
}
// TestLongRunning_OrphanRecovery simulates worker deaths periodically
// Validates orphan recovery remains stable over extended period.
//
// Each 15s cycle: create 3 mock workers, submit 3 training-tier jobs, let
// the workers accept their assignments, then close the workers abruptly so
// the jobs become orphans. After the run, a final reconcile pass is
// triggered and the hub's event log is scanned for requeue events.
func TestLongRunning_OrphanRecovery(t *testing.T) {
if testing.Short() {
t.Skip("Skipping long-running test in short mode")
}
duration := getDuration(2 * time.Minute)
cfg := fixtures.DefaultHubConfig()
// Use short grace periods for testing so orphaned jobs requeue within
// seconds instead of the production grace window.
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
scheduler.TierTraining: 2 * time.Second,
scheduler.TierDataProcessing: 1 * time.Second,
}
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
defer fixture.Cleanup()
// Track orphan events. orphanCount is written only inside the goroutine
// below and read by the main goroutine only after <-done (which provides
// the happens-before edge), so the mutex is defensive rather than strictly
// required.
var orphanMu sync.Mutex
orphanCount := 0
requeueCount := 0
done := make(chan struct{})
go func() {
defer close(done)
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
start := time.Now()
cycle := 0
for time.Since(start) < duration {
// Block until the next tick; the final cycle may overrun the configured
// duration by up to one 15s tick.
<-ticker.C
cycle++
// Create workers for this cycle (use bench-worker-* pattern with valid tokens)
workers := make([]*fixtures.MockWorker, 3)
for i := range 3 {
workerID := fmt.Sprintf("bench-multi-worker-%d", (cycle*3+i)%1000)
workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{
GPUBackend: scheduler.BackendNVIDIA,
GPUCount: 4,
})
}
// Submit jobs
jobIDs := make([]string, 3)
for i := range 3 {
jobID := fmt.Sprintf("orphan-job-cycle%d-%d", cycle, i)
jobIDs[i] = jobID
fixture.SubmitJob(scheduler.JobSpec{
ID: jobID,
Type: scheduler.JobTypeBatch,
SlotPool: "batch",
GPUCount: 2,
JobTier: scheduler.TierTraining,
})
}
// Signal ready and accept jobs.
// NOTE(review): this assumes the scheduler assigns jobIDs[i] to worker i
// because workers signal ready sequentially; if dispatch ordering differs,
// AcceptJob may name the wrong job - confirm against scheduler dispatch order.
for i, w := range workers {
w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready")
// Wait for assignment
select {
case msg := <-w.RecvCh:
if msg.Type == scheduler.MsgJobAssign {
w.AcceptJob(jobIDs[i])
}
case <-time.After(2 * time.Second):
t.Logf("Timeout waiting for job assignment in cycle %d", cycle)
}
}
// Kill workers abruptly (orphan the jobs)
for _, w := range workers {
w.Close()
}
orphanMu.Lock()
orphanCount += len(workers)
orphanMu.Unlock()
}
}()
<-done
// Trigger final orphan reconciliation
fixture.Hub.TriggerReconcileOrphans()
// Give the hub a moment to record requeue events before reading state.
time.Sleep(500 * time.Millisecond)
// Check state for requeue events accumulated over all cycles.
events, err := fixture.Hub.GetStateEvents()
require.NoError(t, err)
for _, ev := range events {
if ev.Type == scheduler.EventJobRequeued {
requeueCount++
}
}
t.Logf("Orphan recovery test: %d orphans created, %d jobs requeued", orphanCount, requeueCount)
assert.Greater(t, requeueCount, 0, "should have requeued some orphaned jobs")
}
// TestLongRunning_WebSocketStability maintains multiple connections for extended period
// Validates WebSocket connections remain stable without unexpected disconnections.
//
// Workers are heartbeated every 5s; any worker whose Done channel closes is
// counted as an unexpected disconnect and replaced. Jobs are submitted every
// 30s to keep the scheduler active. The test fails if disconnects exceed
// one per ten workers over the run.
func TestLongRunning_WebSocketStability(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping long-running test in short mode")
	}
	duration := getDuration(1 * time.Minute)

	cfg := fixtures.DefaultHubConfig()
	fixture := fixtures.NewSchedulerTestFixture(t, cfg)
	defer fixture.Cleanup()

	numWorkers := 20
	workers := make([]*fixtures.MockWorker, numWorkers)
	disconnectCounts := make([]int, numWorkers)

	// Create all workers (use bench-worker-* pattern which has valid tokens 0-999)
	for i := range numWorkers {
		workerID := fmt.Sprintf("bench-worker-%d", i%1000)
		workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{
			GPUBackend: scheduler.BackendNVIDIA,
			GPUCount:   4,
			CPUCount:   8,
		})
	}
	t.Logf("Created %d workers, monitoring for %v", numWorkers, duration)

	// Monitor connections and send heartbeats
	start := time.Now()
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	cycle := 0
	for time.Since(start) < duration {
		<-ticker.C
		cycle++
		for i := range workers {
			w := workers[i]
			// Check if worker disconnected
			select {
			case <-w.Done:
				// Worker disconnected unexpectedly - reconnect under a fresh
				// ID (offset by disconnect count to stay in the valid 0-999
				// token range).
				disconnectCounts[i]++
				workerID := fmt.Sprintf("bench-worker-%d", (i+disconnectCounts[i]*100)%1000)
				workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{
					GPUBackend: scheduler.BackendNVIDIA,
					GPUCount:   4,
					CPUCount:   8, // match initial capabilities so reconnects don't degrade capacity
				})
				workers[i].SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "reconnected")
				w = workers[i] // Update w to new worker
			default:
				// Send heartbeat to keep alive
				w.SendHeartbeat(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0})
			}
			// Drain any messages. The empty job ID presumably resolves to the
			// mock's current assignment - confirm against the MockWorker contract.
			select {
			case msg := <-w.RecvCh:
				if msg.Type == scheduler.MsgJobAssign {
					w.AcceptJob("")
					w.CompleteJob("", 0, "")
					w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready")
				}
			default:
			}
		}
		// Every 30 seconds, submit some jobs to keep scheduler active
		if cycle%6 == 0 {
			for i := range 10 {
				fixture.SubmitJob(scheduler.JobSpec{
					ID:       fmt.Sprintf("stability-job-%d-%d", cycle, i),
					Type:     scheduler.JobTypeBatch,
					SlotPool: "batch",
					GPUCount: 1,
				})
			}
		}
	}

	// Cleanup
	for _, w := range workers {
		w.Close()
	}

	totalDisconnects := 0
	for _, count := range disconnectCounts {
		totalDisconnects += count
	}
	t.Logf("WebSocket stability test completed: %v elapsed", time.Since(start))
	t.Logf("Total unexpected disconnects: %d (%.1f%% of connections)",
		totalDisconnects, float64(totalDisconnects)/float64(numWorkers)*100)

	// Allow 1 disconnect per 10 workers over the test period
	maxAllowedDisconnects := numWorkers / 10
	assert.LessOrEqual(t, totalDisconnects, maxAllowedDisconnects,
		"unexpected disconnects should be minimal")
}