diff --git a/tests/fixtures/test_helpers.go b/tests/fixtures/test_helpers.go
new file mode 100644
index 0000000..254dc89
--- /dev/null
+++ b/tests/fixtures/test_helpers.go
@@ -0,0 +1,237 @@
+// Package tests provides test fixtures and helpers
+package tests
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/jfraeys/fetch_ml/internal/scheduler"
+)
+
+// WaitForEvent waits for a specific event type from the scheduler's state events
+// with a timeout. It polls the state events until the event is found or the timeout
+// expires. Returns the matching event and true if found, or nil and false on timeout.
+func WaitForEvent(
+	t *testing.T,
+	hub *scheduler.SchedulerHub,
+	eventType scheduler.StateEventType,
+	timeout time.Duration,
+) (*scheduler.StateEvent, bool) {
+	t.Helper()
+
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		events, err := hub.GetStateEvents()
+		if err != nil {
+			t.Logf("WaitForEvent: error getting state events: %v", err)
+			time.Sleep(50 * time.Millisecond)
+			continue
+		}
+
+		for _, event := range events {
+			if event.Type == eventType {
+				return &event, true
+			}
+		}
+
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	return nil, false
+}
+
+// WaitForEventWithFilter waits for a specific event type that also matches a filter function
+func WaitForEventWithFilter(
+	t *testing.T,
+	hub *scheduler.SchedulerHub,
+	eventType scheduler.StateEventType,
+	filter func(scheduler.StateEvent) bool,
+	timeout time.Duration,
+) (*scheduler.StateEvent, bool) {
+	t.Helper()
+
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		events, err := hub.GetStateEvents()
+		if err != nil {
+			t.Logf("WaitForEventWithFilter: error getting state events: %v", err)
+			time.Sleep(50 * time.Millisecond)
+			continue
+		}
+
+		for _, event := range events {
+			if event.Type == eventType && filter(event) {
+				return &event, true
+			}
+		}
+
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	return nil, false
+}
+
+// WaitForTaskStatus waits for a task to reach a specific status
+func WaitForTaskStatus(
+	t *testing.T,
+	hub *scheduler.SchedulerHub,
+	taskID string,
+	status string,
+	timeout time.Duration,
+) bool {
+	t.Helper()
+
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		task := hub.GetTask(taskID)
+		if task != nil && task.Status == status {
+			return true
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	return false
+}
+
+// WaitForMetric waits for a metric to satisfy a condition
+func WaitForMetric(
+	t *testing.T,
+	hub *scheduler.SchedulerHub,
+	metricKey string,
+	condition func(interface{}) bool,
+	timeout time.Duration,
+) bool {
+	t.Helper()
+
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		metrics := hub.GetMetricsPayload()
+		if value, ok := metrics[metricKey]; ok {
+			if condition(value) {
+				return true
+			}
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	return false
+}
+
+// PollWithTimeout repeatedly calls a function until it returns true or the timeout expires
+func PollWithTimeout(
+	t *testing.T,
+	name string,
+	fn func() bool,
+	timeout time.Duration,
+	interval time.Duration,
+) bool {
+	t.Helper()
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			t.Logf("PollWithTimeout %s: timeout after %v", name, timeout)
+			return false
+		case <-ticker.C:
+			if fn() {
+				return true
+			}
+		}
+	}
+}
+
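Taken together, these helpers exist so tests can wait on scheduler state instead of sprinkling bare `time.Sleep` calls. A minimal sketch of the intended call pattern follows; the `newTestHub` constructor, the job ID, and the `"completed"` status string are illustrative placeholders, not part of this package:

```go
func TestJobReachesCompleted(t *testing.T) {
	hub := newTestHub(t) // hypothetical helper that builds a test SchedulerHub

	// Fail fast if the scheduler never emits the event under test.
	if _, ok := WaitForEvent(t, hub, scheduler.EventJobRequeued, 5*time.Second); !ok {
		t.Fatal("no requeue event observed")
	}

	// Poll the task table until the job reports the expected status.
	if !WaitForTaskStatus(t, hub, "job-1", "completed", 30*time.Second) {
		t.Fatalf("job-1 never reached status %q", "completed")
	}
}
```

+// AssertEventReceived asserts that an event of the specified type was received
+func 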
AssertEventReceived( + t *testing.T, + hub *scheduler.SchedulerHub, + eventType scheduler.StateEventType, + timeout time.Duration, +) *scheduler.StateEvent { + t.Helper() + + event, found := WaitForEvent(t, hub, eventType, timeout) + if !found { + t.Fatalf("Expected event type %v within %v, but was not received", eventType, timeout) + } + return event +} + +// AssertTaskStatus asserts that a task reaches the expected status +func AssertTaskStatus( + t *testing.T, + hub *scheduler.SchedulerHub, + taskID string, + expectedStatus string, + timeout time.Duration, +) { + t.Helper() + + if !WaitForTaskStatus(t, hub, taskID, expectedStatus, timeout) { + task := hub.GetTask(taskID) + if task == nil { + t.Fatalf("Task %s not found (expected status: %s)", taskID, expectedStatus) + } + t.Fatalf("Task %s has status %s, expected %s (timeout: %v)", + taskID, task.Status, expectedStatus, timeout) + } +} + +// WaitForCondition waits for a condition to be true with a timeout +// Returns true if condition was met, false if timeout +func WaitForCondition( + t *testing.T, + name string, + condition func() bool, + timeout time.Duration, +) bool { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if condition() { + return true + } + time.Sleep(50 * time.Millisecond) + } + + t.Logf("WaitForCondition %s: timeout after %v", name, timeout) + return false +} + +// RetryWithBackoff retries an operation with exponential backoff +func RetryWithBackoff( + t *testing.T, + name string, + maxRetries int, + baseDelay time.Duration, + fn func() error, +) error { + t.Helper() + + var err error + delay := baseDelay + + for i := 0; i < maxRetries; i++ { + err = fn() + if err == nil { + return nil + } + + t.Logf("RetryWithBackoff %s: attempt %d/%d failed: %v", name, i+1, maxRetries, err) + + if i < maxRetries-1 { + time.Sleep(delay) + delay *= 2 // exponential backoff + } + } + + return fmt.Errorf("%s failed after %d attempts: %w", name, maxRetries, err) +} diff --git a/tests/long_running/scheduler_long_running_test.go b/tests/long_running/scheduler_long_running_test.go new file mode 100644 index 0000000..7455412 --- /dev/null +++ b/tests/long_running/scheduler_long_running_test.go @@ -0,0 +1,364 @@ +// Package longrunning provides long-running tests for the scheduler +// These tests are designed for nightly CI runs and are advisory only - failures +// alert the team but don't block releases. +// +// To run long-running tests: go test -v ./tests/long_running/... -run TestLongRunning +// These tests are skipped in short mode (go test -short) +// +// Environment variables: +// +// LONG_RUNNING_DURATION - Override test duration (e.g., "5m" for 5 minutes) +// LONG_RUNNING_WORKERS - Number of workers to use (default varies by test) +package longrunning + +import ( + "fmt" + "os" + "runtime" + "sync" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + fixtures "github.com/jfraeys/fetch_ml/tests/fixtures" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// getDuration returns the test duration from environment or default +func getDuration(defaultDuration time.Duration) time.Duration { + if d := os.Getenv("LONG_RUNNING_DURATION"); d != "" { + if parsed, err := time.ParseDuration(d); err == nil { + return parsed + } + } + return defaultDuration +} + +// TestLongRunning_MemoryLeak monitors heap growth over extended period +// Validates that the scheduler doesn't leak memory under sustained load. 
+func TestLongRunning_MemoryLeak(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long-running test in short mode") + } + + duration := getDuration(2 * time.Minute) // Default 2 min for CI, use LONG_RUNNING_DURATION for longer + cfg := fixtures.DefaultHubConfig() + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + numWorkers := 10 + workers := make([]*fixtures.MockWorker, numWorkers) + + // Create initial workers (use bench-worker-* pattern with valid tokens) + for i := 0; i < numWorkers; i++ { + workerID := fmt.Sprintf("bench-worker-%d", i) + workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + CPUCount: 8, + }) + } + + // Submit some initial jobs + for i := 0; i < 50; i++ { + fixture.SubmitJob(scheduler.JobSpec{ + ID: fmt.Sprintf("memleak-job-%d", i), + Type: scheduler.JobTypeBatch, + SlotPool: "batch", + GPUCount: 1, + }) + } + + // Signal workers ready + for _, w := range workers { + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready") + } + + // Collect baseline memory + runtime.GC() + var m1 runtime.MemStats + runtime.ReadMemStats(&m1) + t.Logf("Baseline heap: %d bytes", m1.HeapAlloc) + + // Run for duration, cycling workers and jobs + start := time.Now() + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + cycle := 0 + for range ticker.C { + if time.Since(start) >= duration { + break + } + cycle++ + + // Every 30 seconds, submit new batch of jobs + if cycle%3 == 0 { + for i := 0; i < 10; i++ { + fixture.SubmitJob(scheduler.JobSpec{ + ID: fmt.Sprintf("memleak-job-cycle%d-%d", cycle, i), + Type: scheduler.JobTypeBatch, + SlotPool: "batch", + GPUCount: 1, + }) + } + } + + // Every minute, recycle half the workers + if cycle%6 == 0 { + for i := 0; i < numWorkers/2; i++ { + workers[i].Close() + workerID := fmt.Sprintf("bench-worker-%d", (i+cycle*10)%1000) + workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + }) + workers[i].SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready") + } + } + + // Process any pending messages to keep connections alive + for _, w := range workers { + select { + case msg := <-w.RecvCh: + if msg.Type == scheduler.MsgJobAssign { + w.AcceptJob("") + w.CompleteJob("", 0, "") + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready") + } + default: + } + } + } + + // Cleanup workers + for _, w := range workers { + w.Close() + } + + // Final memory check + runtime.GC() + time.Sleep(200 * time.Millisecond) + var m2 runtime.MemStats + runtime.ReadMemStats(&m2) + + elapsed := time.Since(start) + growth := int64(m2.HeapAlloc) - int64(m1.HeapAlloc) + growthPerMinute := float64(growth) / elapsed.Minutes() + + t.Logf("Memory leak test completed: %v elapsed", elapsed) + t.Logf("Heap growth: %d bytes (%.0f bytes/min)", growth, growthPerMinute) + + // Allow 5MB per minute growth max + maxGrowthPerMinute := float64(5 * 1024 * 1024) + assert.Less(t, growthPerMinute, maxGrowthPerMinute, + "memory growth should be less than 5MB/min (possible leak)") +} + +// TestLongRunning_OrphanRecovery simulates worker deaths periodically +// Validates orphan recovery remains stable over extended period. 
+func TestLongRunning_OrphanRecovery(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long-running test in short mode") + } + + duration := getDuration(2 * time.Minute) + cfg := fixtures.DefaultHubConfig() + // Use short grace periods for testing + cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{ + scheduler.TierTraining: 2 * time.Second, + scheduler.TierDataProcessing: 1 * time.Second, + } + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + // Track orphan events + var orphanMu sync.Mutex + orphanCount := 0 + requeueCount := 0 + + done := make(chan struct{}) + go func() { + defer close(done) + + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + start := time.Now() + cycle := 0 + + for time.Since(start) < duration { + <-ticker.C + cycle++ + + // Create workers for this cycle (use bench-worker-* pattern with valid tokens) + workers := make([]*fixtures.MockWorker, 3) + for i := range 3 { + workerID := fmt.Sprintf("bench-multi-worker-%d", (cycle*3+i)%1000) + workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + }) + } + + // Submit jobs + jobIDs := make([]string, 3) + for i := range 3 { + jobID := fmt.Sprintf("orphan-job-cycle%d-%d", cycle, i) + jobIDs[i] = jobID + fixture.SubmitJob(scheduler.JobSpec{ + ID: jobID, + Type: scheduler.JobTypeBatch, + SlotPool: "batch", + GPUCount: 2, + JobTier: scheduler.TierTraining, + }) + } + + // Signal ready and accept jobs + for i, w := range workers { + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready") + // Wait for assignment + select { + case msg := <-w.RecvCh: + if msg.Type == scheduler.MsgJobAssign { + w.AcceptJob(jobIDs[i]) + } + case <-time.After(2 * time.Second): + t.Logf("Timeout waiting for job assignment in cycle %d", cycle) + } + } + + // Kill workers abruptly (orphan the jobs) + for _, w := range workers { + w.Close() + } + + orphanMu.Lock() + orphanCount += len(workers) + orphanMu.Unlock() + } + }() + + <-done + + // Trigger final orphan reconciliation + fixture.Hub.TriggerReconcileOrphans() + time.Sleep(500 * time.Millisecond) + + // Check state for requeue events + events, err := fixture.Hub.GetStateEvents() + require.NoError(t, err) + + for _, ev := range events { + if ev.Type == scheduler.EventJobRequeued { + requeueCount++ + } + } + + t.Logf("Orphan recovery test: %d orphans created, %d jobs requeued", orphanCount, requeueCount) + assert.Greater(t, requeueCount, 0, "should have requeued some orphaned jobs") +} + +// TestLongRunning_WebSocketStability maintains multiple connections for extended period +// Validates WebSocket connections remain stable without unexpected disconnections. 
+func TestLongRunning_WebSocketStability(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long-running test in short mode") + } + + duration := getDuration(1 * time.Minute) + cfg := fixtures.DefaultHubConfig() + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + numWorkers := 20 + workers := make([]*fixtures.MockWorker, numWorkers) + disconnectCounts := make([]int, numWorkers) + + // Create all workers (use bench-worker-* pattern which has valid tokens 0-999) + for i := range numWorkers { + workerID := fmt.Sprintf("bench-worker-%d", i%1000) + workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + CPUCount: 8, + }) + } + + t.Logf("Created %d workers, monitoring for %v", numWorkers, duration) + + // Monitor connections and send heartbeats + start := time.Now() + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + cycle := 0 + for time.Since(start) < duration { + <-ticker.C + cycle++ + + for i := range workers { + w := workers[i] + // Check if worker disconnected + select { + case <-w.Done: + // Worker disconnected unexpectedly - reconnect + disconnectCounts[i]++ + workerID := fmt.Sprintf("bench-worker-%d", (i+disconnectCounts[i]*100)%1000) + workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + }) + workers[i].SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "reconnected") + w = workers[i] // Update w to new worker + default: + // Send heartbeat to keep alive + w.SendHeartbeat(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}) + } + + // Drain any messages + select { + case msg := <-w.RecvCh: + if msg.Type == scheduler.MsgJobAssign { + w.AcceptJob("") + w.CompleteJob("", 0, "") + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready") + } + default: + } + } + + // Every 30 seconds, submit some jobs to keep scheduler active + if cycle%6 == 0 { + for i := range 10 { + fixture.SubmitJob(scheduler.JobSpec{ + ID: fmt.Sprintf("stability-job-%d-%d", cycle, i), + Type: scheduler.JobTypeBatch, + SlotPool: "batch", + GPUCount: 1, + }) + } + } + } + + // Cleanup + for _, w := range workers { + w.Close() + } + + totalDisconnects := 0 + for _, count := range disconnectCounts { + totalDisconnects += count + } + + t.Logf("WebSocket stability test completed: %v elapsed", time.Since(start)) + t.Logf("Total unexpected disconnects: %d (%.1f%% of connections)", + totalDisconnects, float64(totalDisconnects)/float64(numWorkers)*100) + + // Allow 1 disconnect per 10 workers over the test period + maxAllowedDisconnects := numWorkers / 10 + assert.LessOrEqual(t, totalDisconnects, maxAllowedDisconnects, + "unexpected disconnects should be minimal") +} diff --git a/tests/stress/scheduler_stress_test.go b/tests/stress/scheduler_stress_test.go new file mode 100644 index 0000000..cc27d8b --- /dev/null +++ b/tests/stress/scheduler_stress_test.go @@ -0,0 +1,308 @@ +// Package stress provides stress tests for the scheduler +// These tests validate scheduler behavior under high load and burst conditions. +// +// To run stress tests: go test -v ./tests/stress/... 
-run TestStress +// These tests are skipped in short mode (go test -short) +package stress + +import ( + "encoding/json" + "fmt" + "runtime" + "sync" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + fixtures "github.com/jfraeys/fetch_ml/tests/fixtures" + "github.com/stretchr/testify/assert" +) + +// TestStress_WorkerConnectBurst tests 30 sequential WebSocket connections +// Validates that the scheduler can handle burst worker connections without failure. +func TestStress_WorkerConnectBurst(t *testing.T) { + if testing.Short() { + t.Skip("Skipping stress test in short mode") + } + + cfg := fixtures.DefaultHubConfig() + cfg.DefaultBatchSlots = 4 + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + numWorkers := 30 + workers := make([]*fixtures.MockWorker, 0, numWorkers) + latencies := make([]time.Duration, 0, numWorkers) + + // Connect workers sequentially with minimal delay + start := time.Now() + for i := 0; i < numWorkers; i++ { + workerStart := time.Now() + workerID := fmt.Sprintf("bench-worker-%d", i) + + worker := fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + CPUCount: 8, + }) + workers = append(workers, worker) + latencies = append(latencies, time.Since(workerStart)) + + // Small yield to avoid overwhelming the scheduler + if i%10 == 9 { + time.Sleep(10 * time.Millisecond) + } + } + totalTime := time.Since(start) + + // Cleanup all workers + for _, w := range workers { + w.Close() + } + + // Validate p99 latency is under 100ms + p99 := calculateP99(latencies) + t.Logf("Worker connect burst: %d workers in %v, p99 latency: %v", numWorkers, totalTime, p99) + assert.Less(t, p99, 100*time.Millisecond, "p99 connection latency should be under 100ms") + assert.Less(t, totalTime, 5*time.Second, "total connect time should be under 5s") +} + +// TestStress_JobSubmissionBurst tests 1K job submissions +// Validates that the scheduler can handle burst job submissions without queue overflow. 
+func TestStress_JobSubmissionBurst(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping stress test in short mode")
+	}
+
+	cfg := fixtures.DefaultHubConfig()
+	fixture := fixtures.NewSchedulerTestFixture(t, cfg)
+	defer fixture.Cleanup()
+
+	// Create a single worker to receive assignments (use bench-worker-* pattern which has tokens 0-999)
+	worker := fixture.CreateWorker("bench-worker-100", scheduler.WorkerCapabilities{
+		GPUBackend: scheduler.BackendNVIDIA,
+		GPUCount:   8,
+		CPUCount:   16,
+	})
+	defer worker.Close()
+
+	numJobs := 1000
+	start := time.Now()
+
+	// Submit 1K jobs
+	for i := range numJobs {
+		jobID := fmt.Sprintf("burst-job-%d", i)
+		fixture.SubmitJob(scheduler.JobSpec{
+			ID:       jobID,
+			Type:     scheduler.JobTypeBatch,
+			SlotPool: "batch",
+			GPUCount: 1,
+			JobTier:  scheduler.TierTraining,
+		})
+	}
+	submitTime := time.Since(start)
+
+	t.Logf("Submitted %d jobs in %v (%.0f jobs/sec)", numJobs, submitTime, float64(numJobs)/submitTime.Seconds())
+
+	// Signal worker ready to process some jobs
+	worker.SignalReady(scheduler.SlotStatus{BatchTotal: 8, BatchInUse: 0}, "ready")
+
+	// Wait for and accept some assignments
+	accepted := 0
+	done := time.After(3 * time.Second)
+	for accepted < 10 {
+		select {
+		case <-done:
+			goto doneAccepting
+		default:
+			select {
+			case msg := <-worker.RecvCh:
+				if msg.Type == scheduler.MsgJobAssign {
+					var payload scheduler.JobAssignPayload
+					_ = json.Unmarshal(msg.Payload, &payload)
+					worker.AcceptJob(payload.Spec.ID)
+					accepted++
+				}
+			case <-time.After(100 * time.Millisecond):
+				worker.SignalReady(scheduler.SlotStatus{BatchTotal: 8, BatchInUse: accepted}, "still_ready")
+			}
+		}
+	}
+doneAccepting:
+
+	t.Logf("Worker accepted %d jobs from burst queue", accepted)
+	assert.Greater(t, accepted, 0, "worker should receive at least some job assignments")
+}
+
+// TestStress_WorkerChurn tests rapid connect/disconnect cycles
+// Validates that the scheduler properly cleans up resources and doesn't leak memory.
+func TestStress_WorkerChurn(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping stress test in short mode")
+	}
+
+	cfg := fixtures.DefaultHubConfig()
+	fixture := fixtures.NewSchedulerTestFixture(t, cfg)
+	defer fixture.Cleanup()
+
+	cycles := 50
+	var m1, m2 runtime.MemStats
+
+	runtime.GC()
+	runtime.ReadMemStats(&m1)
+
+	for i := range cycles {
+		// Create worker - the fixture has dynamic tokens for the bench-worker-* pattern
+		workerID := fmt.Sprintf("bench-worker-%d", i)
+		worker := fixtures.NewMockWorker(t, fixture.Hub, workerID)
+		worker.Register(scheduler.WorkerCapabilities{
+			GPUBackend: scheduler.BackendNVIDIA,
+			GPUCount:   4,
+		})
+
+		// Brief connection
+		time.Sleep(20 * time.Millisecond)
+
+		// Close worker
+		worker.Close()
+
+		// Small delay between cycles
+		if i%10 == 9 {
+			time.Sleep(50 * time.Millisecond)
+		}
+	}
+
+	// Force GC and check memory
+	runtime.GC()
+	time.Sleep(100 * time.Millisecond)
+	runtime.ReadMemStats(&m2)
+
+	// Allow 10MB growth for 50 cycles (200KB per cycle max)
+	growth := int64(m2.HeapAlloc) - int64(m1.HeapAlloc)
+	maxGrowth := int64(10 * 1024 * 1024) // 10MB
+
+	t.Logf("Worker churn: %d cycles, heap growth: %d bytes", cycles, growth)
+	assert.Less(t, growth, maxGrowth, "memory growth should be bounded (possible leak)")
+}
+
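An aside on the percentile math used by these tests: `calculateP99`, defined at the end of this file, sorts with a nested O(n²) loop, which is fine for 30 samples but will drag once burst sizes grow. A sketch of an equivalent helper built on the standard library's `sort.Slice`, generalized to any percentile - a suggested alternative, not part of the patch:

```go
package stress

import (
	"sort"
	"time"
)

// percentile returns the pth percentile (0-100) of the given samples.
// It sorts a copy so the caller's slice is left untouched.
func percentile(samples []time.Duration, p int) time.Duration {
	if len(samples) == 0 {
		return 0
	}
	sorted := make([]time.Duration, len(samples))
	copy(sorted, samples)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })

	idx := (len(sorted) * p) / 100
	if idx >= len(sorted) {
		idx = len(sorted) - 1
	}
	return sorted[idx]
}
```

+// TestStress_ConcurrentScheduling tests job queue contention with multiple workers
+// Validates fair scheduling and lack of race conditions under concurrent load.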
+func TestStress_ConcurrentScheduling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping stress test in short mode") + } + + cfg := fixtures.DefaultHubConfig() + cfg.DefaultBatchSlots = 4 + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + numWorkers := 10 + jobsPerWorker := 20 + + // Create workers + workers := make([]*fixtures.MockWorker, numWorkers) + for i := range numWorkers { + workerID := fmt.Sprintf("bench-multi-worker-%d", i) + workers[i] = fixture.CreateWorker(workerID, scheduler.WorkerCapabilities{ + GPUBackend: scheduler.BackendNVIDIA, + GPUCount: 4, + CPUCount: 8, + }) + } + + // Submit jobs concurrently + var wg sync.WaitGroup + for i := range numWorkers { + wg.Add(1) + go func(workerIdx int) { + defer wg.Done() + for j := 0; j < jobsPerWorker; j++ { + jobID := fmt.Sprintf("concurrent-job-w%d-j%d", workerIdx, j) + fixture.SubmitJob(scheduler.JobSpec{ + ID: jobID, + Type: scheduler.JobTypeBatch, + SlotPool: "batch", + GPUCount: 1, + JobTier: scheduler.TierDataProcessing, + }) + } + }(i) + } + wg.Wait() + + totalJobs := numWorkers * jobsPerWorker + t.Logf("Submitted %d jobs from %d workers concurrently", totalJobs, numWorkers) + + // Signal all workers ready and collect some assignments + var assignWg sync.WaitGroup + assignmentCounts := make([]int, numWorkers) + + for i, worker := range workers { + assignWg.Add(1) + go func(idx int, w *fixtures.MockWorker) { + defer assignWg.Done() + + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "ready") + + // Collect assignments for 500ms + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + select { + case msg := <-w.RecvCh: + if msg.Type == scheduler.MsgJobAssign { + assignmentCounts[idx]++ + var payload scheduler.JobAssignPayload + _ = json.Unmarshal(msg.Payload, &payload) + w.AcceptJob(payload.Spec.ID) + // Signal ready again after accepting + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 1}, "processing") + } + case <-time.After(50 * time.Millisecond): + // Ping ready status + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "still_ready") + } + } + }(i, worker) + } + assignWg.Wait() + + totalAssigned := 0 + for _, count := range assignmentCounts { + totalAssigned += count + } + + t.Logf("Workers received %d total assignments", totalAssigned) + assert.Greater(t, totalAssigned, 0, "should have some job assignments") + + // Cleanup + for _, w := range workers { + w.Close() + } +} + +// calculateP99 returns the 99th percentile latency from a slice of durations +func calculateP99(latencies []time.Duration) time.Duration { + if len(latencies) == 0 { + return 0 + } + + // Simple sort-based approach (not efficient for large N, but fine for stress tests) + sorted := make([]time.Duration, len(latencies)) + copy(sorted, latencies) + for i := range sorted { + for j := i + 1; j < len(sorted); j++ { + if sorted[i] > sorted[j] { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + idx := (len(sorted) * 99) / 100 + if idx >= len(sorted) { + idx = len(sorted) - 1 + } + return sorted[idx] +} diff --git a/tests/testreport/reporter.go b/tests/testreport/reporter.go new file mode 100644 index 0000000..59c9961 --- /dev/null +++ b/tests/testreport/reporter.go @@ -0,0 +1,278 @@ +// Package testreport provides structured test reporting and output +package testreport + +import ( + "encoding/json" + "fmt" + "os" + "strings" + "testing" + "time" +) + +// TestResult represents a single test result +type TestResult struct { + 
Name string `json:"name"` + Package string `json:"package"` + Status string `json:"status"` // pass, fail, skip + Duration time.Duration `json:"duration"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` +} + +// TestSuite represents a collection of test results +type TestSuite struct { + Name string `json:"name"` + Package string `json:"package"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + Tests []TestResult `json:"tests"` +} + +// Summary provides aggregate statistics +type Summary struct { + Total int `json:"total"` + Passed int `json:"passed"` + Failed int `json:"failed"` + Skipped int `json:"skipped"` + Duration time.Duration `json:"duration"` +} + +// Reporter handles test reporting +type Reporter struct { + suite TestSuite + current *TestResult + testMap map[string]*TestResult +} + +// NewReporter creates a new test reporter +func NewReporter(name, pkg string) *Reporter { + return &Reporter{ + suite: TestSuite{ + Name: name, + Package: pkg, + StartTime: time.Now(), + Tests: []TestResult{}, + }, + testMap: make(map[string]*TestResult), + } +} + +// StartTest records the start of a test +func (r *Reporter) StartTest(name string) { + result := &TestResult{ + Name: name, + Package: r.suite.Package, + StartTime: time.Now(), + Status: "running", + } + r.current = result + r.testMap[name] = result +} + +// EndTest records the end of a test +func (r *Reporter) EndTest(name string, status string, err error) { + if r.current == nil || r.current.Name != name { + r.current = r.testMap[name] + } + if r.current == nil { + return + } + + r.current.EndTime = time.Now() + r.current.Duration = r.current.EndTime.Sub(r.current.StartTime) + r.current.Status = status + + if err != nil { + r.current.Error = err.Error() + } + + r.suite.Tests = append(r.suite.Tests, *r.current) + r.current = nil +} + +// RecordOutput captures test output +func (r *Reporter) RecordOutput(output string) { + if r.current != nil { + r.current.Output += output + "\n" + } +} + +// Summary generates aggregate statistics +func (r *Reporter) Summary() Summary { + s := Summary{ + Total: len(r.suite.Tests), + } + + for _, t := range r.suite.Tests { + switch t.Status { + case "pass": + s.Passed++ + case "fail": + s.Failed++ + case "skip": + s.Skipped++ + } + } + + return s +} + +// ToJSON exports the test suite as JSON +func (r *Reporter) ToJSON() ([]byte, error) { + r.suite.EndTime = time.Now() + return json.MarshalIndent(r.suite, "", " ") +} + +// SaveToFile writes the test report to a file +func (r *Reporter) SaveToFile(path string) error { + data, err := r.ToJSON() + if err != nil { + return err + } + return os.WriteFile(path, data, 0644) +} + +// ReportToEnv outputs report path to environment for CI +func (r *Reporter) ReportToEnv() { + if path := os.Getenv("TEST_REPORT_PATH"); path != "" { + r.SaveToFile(path) + fmt.Fprintf(os.Stderr, "Test report saved to: %s\n", path) + } +} + +// FlakyTestTracker tracks potentially flaky tests +type FlakyTestTracker struct { + runs map[string][]bool // test name -> []passed +} + +// NewFlakyTestTracker creates a new flaky test tracker +func NewFlakyTestTracker() *FlakyTestTracker { + return &FlakyTestTracker{ + runs: make(map[string][]bool), + } +} + +// RecordResult records a test result +func (ft *FlakyTestTracker) RecordResult(name string, passed bool) { + ft.runs[name] = append(ft.runs[name], passed) +} + +// IsFlaky returns true if a test 
has inconsistent results +func (ft *FlakyTestTracker) IsFlaky(name string) bool { + runs := ft.runs[name] + if len(runs) < 3 { + return false + } + + // Check for mixed results + passed := 0 + failed := 0 + for _, r := range runs { + if r { + passed++ + } else { + failed++ + } + } + + // Flaky if both passed and failed exist + return passed > 0 && failed > 0 +} + +// GetFlakyTests returns all tests that appear flaky +func (ft *FlakyTestTracker) GetFlakyTests() []string { + var flaky []string + for name := range ft.runs { + if ft.IsFlaky(name) { + flaky = append(flaky, name) + } + } + return flaky +} + +// Report generates a flaky test report +func (ft *FlakyTestTracker) Report() string { + flaky := ft.GetFlakyTests() + if len(flaky) == 0 { + return "No flaky tests detected" + } + + var report strings.Builder + report.WriteString("Potentially Flaky Tests:\n") + for _, name := range flaky { + runs := ft.runs[name] + passed := 0 + for _, r := range runs { + if r { + passed++ + } + } + fmt.Fprintf(&report, " - %s: %d/%d passed (%.1f%%)\n", + name, passed, len(runs), float64(passed)*100/float64(len(runs))) + } + return report.String() +} + +// TestTimer provides timing utilities for tests +type TestTimer struct { + start time.Time + duration time.Duration +} + +// NewTestTimer creates a new test timer +func NewTestTimer() *TestTimer { + return &TestTimer{start: time.Now()} +} + +// Elapsed returns elapsed time +func (tt *TestTimer) Elapsed() time.Duration { + return time.Since(tt.start) +} + +// CheckBudget checks if test is within time budget +func (tt *TestTimer) CheckBudget(budget time.Duration, t *testing.T) bool { + elapsed := tt.Elapsed() + if elapsed > budget { + t.Logf("WARNING: Test exceeded time budget: %v > %v", elapsed, budget) + return false + } + return true +} + +// PerformanceRegression tracks performance metrics +type PerformanceRegression struct { + metrics map[string][]float64 // metric name -> values +} + +// NewPerformanceRegression creates a new tracker +func NewPerformanceRegression() *PerformanceRegression { + return &PerformanceRegression{ + metrics: make(map[string][]float64), + } +} + +// Record records a metric value +func (pr *PerformanceRegression) Record(name string, value float64) { + pr.metrics[name] = append(pr.metrics[name], value) +} + +// CheckRegression checks if current value regresses from baseline +func (pr *PerformanceRegression) CheckRegression(name string, current float64, threshold float64) bool { + values := pr.metrics[name] + if len(values) < 3 { + return false // Not enough data + } + + // Calculate average of previous runs + var sum float64 + for _, v := range values { + sum += v + } + avg := sum / float64(len(values)) + + // Regression if current is worse than threshold * average + return current > avg*threshold +} diff --git a/tests/testreport/reporter_test.go b/tests/testreport/reporter_test.go new file mode 100644 index 0000000..5443063 --- /dev/null +++ b/tests/testreport/reporter_test.go @@ -0,0 +1,134 @@ +package testreport_test + +import ( + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/tests/testreport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReporter_BasicFlow(t *testing.T) { + r := testreport.NewReporter("unit-test", "test/package") + + // Record a passing test + r.StartTest("TestPass") + time.Sleep(10 * time.Millisecond) + r.EndTest("TestPass", "pass", nil) + + // Record a failing test + r.StartTest("TestFail") + time.Sleep(5 * 
time.Millisecond)
+	r.EndTest("TestFail", "fail", errors.New("test error"))
+
+	// Record a skipped test
+	r.StartTest("TestSkip")
+	r.EndTest("TestSkip", "skip", nil)
+
+	// Check summary
+	summary := r.Summary()
+	assert.Equal(t, 3, summary.Total)
+	assert.Equal(t, 1, summary.Passed)
+	assert.Equal(t, 1, summary.Failed)
+	assert.Equal(t, 1, summary.Skipped)
+}
+
+func TestReporter_JSONOutput(t *testing.T) {
+	r := testreport.NewReporter("json-test", "test/package")
+
+	r.StartTest("TestOne")
+	r.RecordOutput("some output")
+	r.EndTest("TestOne", "pass", nil)
+
+	jsonData, err := r.ToJSON()
+	require.NoError(t, err)
+	assert.Contains(t, string(jsonData), "json-test")
+	assert.Contains(t, string(jsonData), "TestOne")
+	assert.Contains(t, string(jsonData), "some output")
+}
+
+func TestReporter_SaveToFile(t *testing.T) {
+	r := testreport.NewReporter("file-test", "test/package")
+
+	r.StartTest("TestOne")
+	r.EndTest("TestOne", "pass", nil)
+
+	tmpDir := t.TempDir()
+	path := filepath.Join(tmpDir, "report.json")
+
+	err := r.SaveToFile(path)
+	require.NoError(t, err)
+
+	data, err := os.ReadFile(path)
+	require.NoError(t, err)
+	assert.Contains(t, string(data), "file-test")
+}
+
+func TestFlakyTestTracker(t *testing.T) {
+	ft := testreport.NewFlakyTestTracker()
+
+	// Record consistent passes
+	ft.RecordResult("stable-pass", true)
+	ft.RecordResult("stable-pass", true)
+	ft.RecordResult("stable-pass", true)
+
+	// Record consistent failures
+	ft.RecordResult("stable-fail", false)
+	ft.RecordResult("stable-fail", false)
+
+	// Record mixed results (flaky)
+	ft.RecordResult("flaky-test", true)
+	ft.RecordResult("flaky-test", false)
+	ft.RecordResult("flaky-test", true)
+
+	// stable-pass is consistent and stable-fail has fewer than 3 runs - neither is flaky
+	assert.False(t, ft.IsFlaky("stable-pass"))
+	assert.False(t, ft.IsFlaky("stable-fail"))
+
+	// Flaky with mixed results
+	assert.True(t, ft.IsFlaky("flaky-test"))
+
+	// Get all flaky tests
+	flaky := ft.GetFlakyTests()
+	require.Len(t, flaky, 1)
+	assert.Equal(t, "flaky-test", flaky[0])
+
+	// Check report
+	report := ft.Report()
+	assert.Contains(t, report, "flaky-test")
+}
+
+func TestTestTimer(t *testing.T) {
+	timer := testreport.NewTestTimer()
+
+	// Should be very small initially
+	elapsed := timer.Elapsed()
+	assert.Less(t, elapsed, 100*time.Millisecond)
+
+	// Budget check should pass
+	passed := timer.CheckBudget(1*time.Second, t)
+	assert.True(t, passed)
+}
+
+func TestPerformanceRegression(t *testing.T) {
+	pr := testreport.NewPerformanceRegression()
+
+	// Record baseline values
+	pr.Record("latency", 100.0)
+	pr.Record("latency", 110.0)
+	pr.Record("latency", 105.0)
+
+	// Current value within threshold
+	assert.False(t, pr.CheckRegression("latency", 120.0, 1.5))
+
+	// Current value regresses
+	assert.True(t, pr.CheckRegression("latency", 200.0, 1.5))
+
+	// Not enough data for new metric
+	pr.Record("new-metric", 50.0)
+	assert.False(t, pr.CheckRegression("new-metric", 1000.0, 2.0))
+}
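To close, a sketch of how a test might wire the reporter pieces together; the suite and package names are illustrative, and only the exported APIs from the diff above are used:

```go
package testreport_test

import (
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/tests/testreport"
)

func TestWithReporting(t *testing.T) {
	r := testreport.NewReporter("scheduler-suite", "internal/scheduler") // names illustrative
	r.StartTest(t.Name())
	timer := testreport.NewTestTimer()

	// ... exercise the system under test here ...

	timer.CheckBudget(30*time.Second, t) // log a warning if the test ran long

	status := "pass"
	if t.Failed() {
		status = "fail"
	}
	r.EndTest(t.Name(), status, nil)
	r.ReportToEnv() // writes the JSON report when TEST_REPORT_PATH is set
}
```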