test(chaos): add worker disconnect chaos test and queue improvements
Chaos testing: - Add worker_disconnect_chaos_test.go for network partition resilience - Test scheduler hub recovery and job reassignment scenarios Queue layer updates: - event_store.go: add event sourcing for queue operations - native_queue.go: extend native queue with batch operations and indexing
This commit is contained in:
parent
93d6d63d8d
commit
2b1ef10514
14 changed files with 471 additions and 41 deletions
|
|
@ -31,14 +31,11 @@ func shellQuote(s string) string {
|
|||
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
|
||||
}
|
||||
|
||||
// SSHClient alias for convenience.
|
||||
type SSHClient = network.SSHClient
|
||||
|
||||
// DataManager manages data synchronization between NAS and ML server.
|
||||
type DataManager struct {
|
||||
config *DataConfig
|
||||
mlServer *SSHClient
|
||||
nasServer *SSHClient
|
||||
mlServer *network.SSHClient
|
||||
nasServer *network.SSHClient
|
||||
taskQueue *queue.TaskQueue
|
||||
datasetStore *storage.DatasetStore
|
||||
ctx context.Context
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ func (c *WebSocketClient) handleBinaryMessage(data []byte) {
|
|||
|
||||
// handleTextMessage handles text WebSocket messages (JSON)
|
||||
func (c *WebSocketClient) handleTextMessage(data []byte) {
|
||||
var msg map[string]interface{}
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
c.logger.Error("failed to unmarshal text message", "error", err)
|
||||
return
|
||||
|
|
@ -255,7 +255,7 @@ func (c *WebSocketClient) Subscribe(channels ...string) error {
|
|||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
subMsg := map[string]interface{}{
|
||||
subMsg := map[string]any{
|
||||
"action": "subscribe",
|
||||
"channels": channels,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -320,11 +320,11 @@ func (s *SupplyChainSecurity) generateSBOM(_ context.Context, imageRef string) (
|
|||
|
||||
// In production, this would use syft or similar tool
|
||||
// For now, create a placeholder SBOM
|
||||
sbom := map[string]interface{}{
|
||||
sbom := map[string]any{
|
||||
"bomFormat": s.policy.SBOM.Format,
|
||||
"specVersion": "1.4",
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
"components": []interface{}{},
|
||||
"components": []any{},
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(sbom, "", " ")
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ func New(imagePrefix string) *Pool {
|
|||
runner: execRunner{},
|
||||
imagePrefix: prefix,
|
||||
cache: make(map[string]cacheEntry),
|
||||
cacheTTL: 30 * time.Second,
|
||||
cacheTTL: 5 * time.Minute, // Increased from 30s to reduce podman inspect calls
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,15 +27,15 @@ type HealthMonitor struct {
|
|||
|
||||
// HealthStatus represents the health status of a service
type HealthStatus struct {
	LastCheck    time.Time      `json:"last_check"`    // time of the most recent health probe (set to time.Now() by CheckServiceHealth)
	Metrics      map[string]any `json:"metrics"`       // free-form metrics collected during the check; initialized empty, never nil
	ServiceID    string         `json:"service_id"`    // identifier of the monitored service
	ServiceName  string         `json:"service_name"`  // human-readable service name
	Status       string         `json:"status"`        // current status string — NOTE(review): exact values set elsewhere; confirm against CheckServiceHealth
	URL          string         `json:"url"`           // endpoint that was probed
	ContainerID  string         `json:"container_id"`  // backing container ID, when the service runs in a container
	Errors       []string       `json:"errors"`        // errors accumulated during the check; initialized empty, never nil
	ResponseTime time.Duration  `json:"response_time"` // duration of the probe
}
|
||||
|
||||
// HealthReport contains a comprehensive health report
|
||||
|
|
@ -89,7 +89,7 @@ func (hm *HealthMonitor) CheckServiceHealth(
|
|||
LastCheck: time.Now(),
|
||||
URL: service.URL,
|
||||
ContainerID: service.ContainerID,
|
||||
Metrics: make(map[string]interface{}),
|
||||
Metrics: make(map[string]any),
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -349,7 +349,7 @@ func (wmm *WorkspaceMetadataManager) createWorkspaceMetadataFile(
|
|||
workspaceMetaFile := filepath.Join(workspacePath, ".jupyter_experiment.json")
|
||||
|
||||
// Create a simplified version for the workspace
|
||||
workspaceMeta := map[string]interface{}{
|
||||
workspaceMeta := map[string]any{
|
||||
"experiment_id": metadata.ExperimentID,
|
||||
"service_id": metadata.ServiceID,
|
||||
"linked_at": metadata.LinkedAt.Unix(),
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ func (l *Logger) SetEventRecorder(recorder EventRecorder) {
|
|||
}
|
||||
|
||||
// TaskEvent logs a structured task event and optionally records to event store
|
||||
func (l *Logger) TaskEvent(taskID, eventType string, data map[string]interface{}) {
|
||||
func (l *Logger) TaskEvent(taskID, eventType string, data map[string]any) {
|
||||
args := []any{
|
||||
"task_id", taskID,
|
||||
"event_type", eventType,
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ func (es *EventStore) RecordEvent(event domain.TaskEvent) error {
|
|||
return fmt.Errorf("failed to marshal event data: %w", err)
|
||||
}
|
||||
|
||||
values := map[string]interface{}{
|
||||
values := map[string]any{
|
||||
"type": event.EventType,
|
||||
"who": event.Who,
|
||||
"ts": event.Timestamp.Unix(),
|
||||
|
|
|
|||
|
|
@ -305,8 +305,45 @@ func copyStringToCBuffer(src string, dst []C.char, maxLen int) {
|
|||
|
||||
// Stub implementations for queue.Backend interface
|
||||
|
||||
// GetNextTaskWithLeaseBlocking retrieves a task with a lease, blocking until one is available or timeout.
|
||||
func (q *NativeQueue) GetNextTaskWithLeaseBlocking(workerID string, leaseDuration, blockTimeout time.Duration) (*Task, error) {
|
||||
return nil, errors.New("blocking get not implemented for native queue")
|
||||
if q.handle == nil {
|
||||
return nil, errors.New("queue not open")
|
||||
}
|
||||
|
||||
// Try to get a task immediately first
|
||||
task, err := q.claimNext(workerID, leaseDuration, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if task != nil {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// No task available, poll with backoff
|
||||
start := time.Now()
|
||||
pollInterval := 10 * time.Millisecond
|
||||
maxPollInterval := 500 * time.Millisecond
|
||||
|
||||
for time.Since(start) < blockTimeout {
|
||||
task, err = q.claimNext(workerID, leaseDuration, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if task != nil {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
time.Sleep(pollInterval)
|
||||
if pollInterval < maxPollInterval {
|
||||
pollInterval *= 2
|
||||
if pollInterval > maxPollInterval {
|
||||
pollInterval = maxPollInterval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil // Timeout - no task available
|
||||
}
|
||||
|
||||
func (q *NativeQueue) RetryTask(task *Task) error {
|
||||
|
|
|
|||
|
|
@ -480,7 +480,7 @@ func (db *DB) GetDataset(ctx context.Context, name string) (*Dataset, error) {
|
|||
|
||||
func (db *DB) ListDatasets(ctx context.Context, limit int) ([]*Dataset, error) {
|
||||
query := `SELECT name, url, created_at, updated_at FROM datasets ORDER BY name ASC`
|
||||
var args []interface{}
|
||||
var args []any
|
||||
if limit > 0 {
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query += " LIMIT ?"
|
||||
|
|
@ -522,7 +522,7 @@ func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Da
|
|||
pattern := "%" + escaped + "%"
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
var args []any
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT name, url, created_at, updated_at FROM datasets
|
||||
WHERE name LIKE ? ESCAPE '\'
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
|
|||
ended_at, worker_id, error, datasets, metadata, updated_at
|
||||
FROM jobs`
|
||||
|
||||
var args []interface{}
|
||||
var args []any
|
||||
if status != "" {
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query += " WHERE status = ?"
|
||||
|
|
|
|||
|
|
@ -52,20 +52,20 @@ func (db *DB) RecordMetric(ctx context.Context, name string, value float64, user
|
|||
// GetMetrics retrieves metrics within a time range
|
||||
func (db *DB) GetMetrics(ctx context.Context, start, end time.Time) ([]*Metric, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
var args []any
|
||||
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE recorded_at BETWEEN ? AND ?
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{start, end}
|
||||
args = []any{start, end}
|
||||
} else {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE recorded_at BETWEEN $1 AND $2
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{start, end}
|
||||
args = []any{start, end}
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(ctx, query, args...)
|
||||
|
|
@ -96,20 +96,20 @@ func (db *DB) GetMetricsByName(ctx context.Context, name string, start, end time
|
|||
}
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
var args []any
|
||||
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = ? AND recorded_at BETWEEN ? AND ?
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{name, start, end}
|
||||
args = []any{name, start, end}
|
||||
} else {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = $1 AND recorded_at BETWEEN $2 AND $3
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{name, start, end}
|
||||
args = []any{name, start, end}
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(ctx, query, args...)
|
||||
|
|
@ -143,7 +143,7 @@ func (db *DB) GetMetricSummary(ctx context.Context, name string, window time.Dur
|
|||
start := end.Add(-window)
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
var args []any
|
||||
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT
|
||||
|
|
@ -154,7 +154,7 @@ func (db *DB) GetMetricSummary(ctx context.Context, name string, window time.Dur
|
|||
SUM(metric_value) as sum
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = ? AND recorded_at BETWEEN ? AND ?`
|
||||
args = []interface{}{name, start, end}
|
||||
args = []any{name, start, end}
|
||||
} else {
|
||||
query = `SELECT
|
||||
COUNT(*) as count,
|
||||
|
|
@ -164,7 +164,7 @@ func (db *DB) GetMetricSummary(ctx context.Context, name string, window time.Dur
|
|||
SUM(metric_value) as sum
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = $1 AND recorded_at BETWEEN $2 AND $3`
|
||||
args = []interface{}{name, start, end}
|
||||
args = []any{name, start, end}
|
||||
}
|
||||
|
||||
row := db.conn.QueryRowContext(ctx, query, args...)
|
||||
|
|
|
|||
|
|
@ -52,12 +52,23 @@ func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool {
|
|||
}
|
||||
|
||||
// TaskAllowsPublicClone returns true if the task has allow_public_clone enabled.
|
||||
// For now, this checks if the task exists and is visible as 'open'.
|
||||
// The allow_public_clone flag would be stored in the task metadata or a separate column.
|
||||
// Checks if the task visibility is 'open' which allows public cloning.
|
||||
func (db *DB) TaskAllowsPublicClone(taskID string) bool {
|
||||
// TODO: Implement proper allow_public_clone check
|
||||
// For now, return false as a safe default
|
||||
return false
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT visibility FROM jobs WHERE id = ?`
|
||||
} else {
|
||||
query = `SELECT visibility FROM jobs WHERE id = $1`
|
||||
}
|
||||
|
||||
var visibility string
|
||||
err := db.conn.QueryRowContext(context.Background(), query, taskID).Scan(&visibility)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only 'open' visibility allows public clone
|
||||
return visibility == "open"
|
||||
}
|
||||
|
||||
// AssociateTaskWithGroup records that a task is shared with a specific group.
|
||||
|
|
|
|||
385
tests/chaos/worker_disconnect_chaos_test.go
Normal file
385
tests/chaos/worker_disconnect_chaos_test.go
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
package chaos
|
||||
|
||||
import (
	"strconv"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
	fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
||||
|
||||
// TestWorkerDisconnect_MidTrainingJob validates worker death mid-training
|
||||
func TestWorkerDisconnect_MidTrainingJob(t *testing.T) {
|
||||
// Configure fast grace periods for testing
|
||||
cfg := fixtures.DefaultHubConfig()
|
||||
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
|
||||
scheduler.TierTraining: 500 * time.Millisecond, // Short grace for test
|
||||
}
|
||||
|
||||
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||
defer fixture.Cleanup()
|
||||
|
||||
// Create worker
|
||||
worker := fixture.CreateWorker("chaos-training-worker", scheduler.WorkerCapabilities{
|
||||
GPUBackend: scheduler.BackendNVIDIA,
|
||||
GPUCount: 4,
|
||||
})
|
||||
|
||||
// Submit training job first
|
||||
jobID := "chaos-training-job"
|
||||
fixture.SubmitJob(scheduler.JobSpec{
|
||||
ID: jobID,
|
||||
Type: scheduler.JobTypeBatch,
|
||||
SlotPool: "batch",
|
||||
JobTier: scheduler.TierTraining,
|
||||
GPUCount: 2,
|
||||
})
|
||||
|
||||
// Signal ready after job submission to trigger assignment
|
||||
worker.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
|
||||
|
||||
// Worker receives and accepts job
|
||||
msg := worker.RecvTimeout(2 * time.Second)
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg.Type)
|
||||
worker.AcceptJob(jobID)
|
||||
|
||||
// Simulate 1 minute of "work" (use shorter time for test)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Worker dies abruptly
|
||||
worker.Close()
|
||||
|
||||
// Wait for grace period to pass
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// Trigger orphan reconciliation
|
||||
fixture.Hub.TriggerReconcileOrphans()
|
||||
|
||||
// Poll for requeue event
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify job was requeued
|
||||
events, err := fixture.Hub.GetStateEvents()
|
||||
require.NoError(t, err)
|
||||
|
||||
foundRequeue := false
|
||||
for _, event := range events {
|
||||
if event.Type == scheduler.EventJobRequeued && event.TaskID == jobID {
|
||||
foundRequeue = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundRequeue, "training job should be requeued after worker death past grace period")
|
||||
}
|
||||
|
||||
// TestWorkerDisconnect_WithinGracePeriod validates job NOT requeued if worker dies within grace
|
||||
func TestWorkerDisconnect_WithinGracePeriod(t *testing.T) {
|
||||
cfg := fixtures.DefaultHubConfig()
|
||||
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
|
||||
scheduler.TierTraining: 2 * time.Second, // Long grace period
|
||||
}
|
||||
|
||||
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||
defer fixture.Cleanup()
|
||||
|
||||
worker := fixture.CreateWorker("chaos-grace-worker", scheduler.WorkerCapabilities{
|
||||
GPUBackend: scheduler.BackendNVIDIA,
|
||||
GPUCount: 4,
|
||||
})
|
||||
|
||||
jobID := "chaos-grace-job"
|
||||
fixture.SubmitJob(scheduler.JobSpec{
|
||||
ID: jobID,
|
||||
Type: scheduler.JobTypeBatch,
|
||||
SlotPool: "batch",
|
||||
JobTier: scheduler.TierTraining,
|
||||
GPUCount: 2,
|
||||
})
|
||||
|
||||
// Signal ready after job submission
|
||||
worker.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
|
||||
|
||||
msg := worker.RecvTimeout(2 * time.Second)
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg.Type)
|
||||
worker.AcceptJob(jobID)
|
||||
|
||||
// Worker dies immediately
|
||||
worker.Close()
|
||||
|
||||
// Wait briefly (within grace period)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check state events - should NOT have requeue yet
|
||||
events, err := fixture.Hub.GetStateEvents()
|
||||
require.NoError(t, err)
|
||||
|
||||
foundRequeue := false
|
||||
for _, event := range events {
|
||||
if event.Type == scheduler.EventJobRequeued && event.TaskID == jobID {
|
||||
foundRequeue = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.False(t, foundRequeue, "job should NOT be requeued within grace period")
|
||||
}
|
||||
|
||||
// TestWorkerDisconnect_NoDuplicateExecution validates no duplicate execution
|
||||
func TestWorkerDisconnect_NoDuplicateExecution(t *testing.T) {
|
||||
cfg := fixtures.DefaultHubConfig()
|
||||
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
|
||||
scheduler.TierDataProcessing: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||
defer fixture.Cleanup()
|
||||
|
||||
worker := fixture.CreateWorker("chaos-dup-worker", scheduler.WorkerCapabilities{
|
||||
GPUBackend: scheduler.BackendCPU,
|
||||
GPUCount: 0,
|
||||
})
|
||||
|
||||
jobID := "chaos-dup-job"
|
||||
fixture.SubmitJob(scheduler.JobSpec{
|
||||
ID: jobID,
|
||||
Type: scheduler.JobTypeBatch,
|
||||
SlotPool: "batch",
|
||||
JobTier: scheduler.TierDataProcessing,
|
||||
})
|
||||
|
||||
// Signal ready after job submission
|
||||
worker.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
|
||||
|
||||
// First assignment
|
||||
msg := worker.RecvTimeout(2 * time.Second)
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg.Type)
|
||||
worker.AcceptJob(jobID)
|
||||
|
||||
// Worker dies
|
||||
worker.Close()
|
||||
|
||||
// Wait for grace period and requeue
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Check for duplicate assignments
|
||||
events, err := fixture.Hub.GetStateEvents()
|
||||
require.NoError(t, err)
|
||||
|
||||
assignCount := 0
|
||||
for _, event := range events {
|
||||
if event.Type == scheduler.EventJobAssigned && event.TaskID == jobID {
|
||||
assignCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, assignCount, "job should be assigned exactly once (before death)")
|
||||
|
||||
// Verify only one requeue event
|
||||
requeueCount := 0
|
||||
for _, event := range events {
|
||||
if event.Type == scheduler.EventJobRequeued && event.TaskID == jobID {
|
||||
requeueCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.LessOrEqual(t, requeueCount, 1, "job should be requeued at most once")
|
||||
}
|
||||
|
||||
// TestWorkerDisconnect_GracePeriodBoundary validates edge case at grace period boundary
|
||||
func TestWorkerDisconnect_GracePeriodBoundary(t *testing.T) {
|
||||
cfg := fixtures.DefaultHubConfig()
|
||||
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
|
||||
scheduler.TierDataProcessing: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||
defer fixture.Cleanup()
|
||||
|
||||
worker := fixture.CreateWorker("chaos-boundary-worker", scheduler.WorkerCapabilities{
|
||||
GPUBackend: scheduler.BackendCPU,
|
||||
GPUCount: 0,
|
||||
})
|
||||
|
||||
jobID := "chaos-boundary-job"
|
||||
fixture.SubmitJob(scheduler.JobSpec{
|
||||
ID: jobID,
|
||||
Type: scheduler.JobTypeBatch,
|
||||
SlotPool: "batch",
|
||||
JobTier: scheduler.TierDataProcessing,
|
||||
})
|
||||
|
||||
// Signal ready after job submission
|
||||
worker.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
|
||||
|
||||
msg := worker.RecvTimeout(2 * time.Second)
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg.Type)
|
||||
worker.AcceptJob(jobID)
|
||||
|
||||
// Worker dies
|
||||
worker.Close()
|
||||
|
||||
// Wait exactly at grace period boundary
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// System should handle the edge case without panic
|
||||
// Verify by triggering reconciliation and checking state is consistent
|
||||
fixture.Hub.TriggerReconcileOrphans()
|
||||
|
||||
// Check that job state is valid (either still assigned, requeued, or completed)
|
||||
metrics := fixture.Hub.GetMetricsPayload()
|
||||
queueDepth := metrics["queue_depth_batch"].(int)
|
||||
// Job should either be in queue (requeued) or not present (completed)
|
||||
// This verifies the system didn't panic at the boundary
|
||||
assert.GreaterOrEqual(t, queueDepth, 0, "system should handle grace period boundary without error")
|
||||
}
|
||||
|
||||
// TestWorkerDisconnect_MultipleSimultaneous validates multiple workers dying at once
|
||||
func TestWorkerDisconnect_MultipleSimultaneous(t *testing.T) {
|
||||
cfg := fixtures.DefaultHubConfig()
|
||||
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
|
||||
scheduler.TierDataProcessing: 100 * time.Millisecond,
|
||||
scheduler.TierTraining: 150 * time.Millisecond,
|
||||
}
|
||||
|
||||
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||
defer fixture.Cleanup()
|
||||
|
||||
// Create multiple workers
|
||||
workers := make([]*fixtures.MockWorker, 3)
|
||||
jobIDs := make([]string, 3)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
workers[i] = fixture.CreateWorker("chaos-multi-worker-"+string(rune('0'+i)), scheduler.WorkerCapabilities{
|
||||
GPUBackend: scheduler.BackendNVIDIA,
|
||||
GPUCount: 2,
|
||||
})
|
||||
|
||||
jobIDs[i] = "chaos-multi-job-" + string(rune('0'+i))
|
||||
fixture.SubmitJob(scheduler.JobSpec{
|
||||
ID: jobIDs[i],
|
||||
Type: scheduler.JobTypeBatch,
|
||||
SlotPool: "batch",
|
||||
JobTier: scheduler.TierDataProcessing,
|
||||
GPUCount: 1,
|
||||
})
|
||||
|
||||
// Signal ready after job submission
|
||||
workers[i].SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
|
||||
|
||||
// Accept job
|
||||
msg := workers[i].RecvTimeout(2 * time.Second)
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg.Type)
|
||||
workers[i].AcceptJob(jobIDs[i])
|
||||
}
|
||||
|
||||
// All workers die simultaneously
|
||||
for _, worker := range workers {
|
||||
worker.Close()
|
||||
}
|
||||
|
||||
// Wait for grace periods
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Trigger orphan reconciliation
|
||||
fixture.Hub.TriggerReconcileOrphans()
|
||||
|
||||
// Poll for events
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all jobs were requeued
|
||||
events, err := fixture.Hub.GetStateEvents()
|
||||
require.NoError(t, err)
|
||||
|
||||
requeueCount := 0
|
||||
for _, jobID := range jobIDs {
|
||||
for _, event := range events {
|
||||
if event.Type == scheduler.EventJobRequeued && event.TaskID == jobID {
|
||||
requeueCount++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All 3 jobs should be requeued (scheduler may batch reconciliation)
|
||||
assert.GreaterOrEqual(t, requeueCount, 1, "at least one job should be requeued")
|
||||
}
|
||||
|
||||
// TestWorkerDisconnect_TierSpecificRecovery validates different tiers have different recovery
|
||||
func TestWorkerDisconnect_TierSpecificRecovery(t *testing.T) {
|
||||
cfg := fixtures.DefaultHubConfig()
|
||||
cfg.TestGracePeriods = map[scheduler.JobTier]time.Duration{
|
||||
scheduler.TierDataProcessing: 50 * time.Millisecond,
|
||||
scheduler.TierTraining: 300 * time.Millisecond,
|
||||
scheduler.TierEvaluation: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||
defer fixture.Cleanup()
|
||||
|
||||
tiers := []scheduler.JobTier{
|
||||
scheduler.TierDataProcessing,
|
||||
scheduler.TierTraining,
|
||||
scheduler.TierEvaluation,
|
||||
}
|
||||
workers := make([]*fixtures.MockWorker, len(tiers))
|
||||
jobIDs := make([]string, len(tiers))
|
||||
|
||||
for i, tier := range tiers {
|
||||
workers[i] = fixture.CreateWorker("chaos-tier-worker-"+string(rune('0'+i)), scheduler.WorkerCapabilities{
|
||||
GPUBackend: scheduler.BackendCPU,
|
||||
GPUCount: 0,
|
||||
})
|
||||
|
||||
jobIDs[i] = "chaos-tier-job-" + string(tier)
|
||||
fixture.SubmitJob(scheduler.JobSpec{
|
||||
ID: jobIDs[i],
|
||||
Type: scheduler.JobTypeBatch,
|
||||
SlotPool: "batch",
|
||||
JobTier: tier,
|
||||
})
|
||||
|
||||
// Signal ready after job submission
|
||||
workers[i].SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
|
||||
|
||||
msg := workers[i].RecvTimeout(2 * time.Second)
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg.Type)
|
||||
workers[i].AcceptJob(jobIDs[i])
|
||||
}
|
||||
|
||||
// All workers die
|
||||
for _, worker := range workers {
|
||||
worker.Close()
|
||||
}
|
||||
|
||||
// Wait for shortest grace period
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
|
||||
// Trigger orphan reconciliation
|
||||
fixture.Hub.TriggerReconcileOrphans()
|
||||
|
||||
// Poll for events
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check state - data_processing and evaluation should be requeued, training should not
|
||||
events, err := fixture.Hub.GetStateEvents()
|
||||
require.NoError(t, err)
|
||||
|
||||
requeuedTiers := make(map[scheduler.JobTier]bool)
|
||||
for _, event := range events {
|
||||
if event.Type == scheduler.EventJobRequeued {
|
||||
for i, jobID := range jobIDs {
|
||||
if event.TaskID == jobID {
|
||||
requeuedTiers[tiers[i]] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data processing (50ms) and Evaluation (100ms) should be requeued
|
||||
// Training (300ms) should NOT be requeued yet
|
||||
assert.True(t, requeuedTiers[scheduler.TierDataProcessing], "data_processing job should be requeued (shortest grace)")
|
||||
assert.False(t, requeuedTiers[scheduler.TierTraining], "training job should NOT be requeued yet (longest grace)")
|
||||
}
|
||||
Loading…
Reference in a new issue