// Package load implements configurable HTTP load-test suites (small,
// medium, full), standalone profiling scenarios, and benchmarks for
// the job-submission API.
package load

import (
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"math"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
	fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
	"github.com/redis/go-redis/v9"
	"golang.org/x/time/rate"
)

// Command-line knobs selecting which suite/scenario runs and whether
// profiling disables the client-side rate limiter.
var (
	loadSuite           = flag.String("load-suite", "small", "load test suite to run: small, medium, or full")
	loadProfileScenario = flag.String("load-profile", "", "run a specific profiling scenario (e.g. medium)")
	profileNoRate       = flag.Bool("profile-norate", false, "disable rate limiting for profiling")
)

// loadSuiteStep is one named, runnable step of a suite.
type loadSuiteStep struct {
	run  func(t *testing.T, baseURL string)
	name string
}

// scenarioDefinition pairs a human-readable scenario name with its
// load configuration.
type scenarioDefinition struct {
	config LoadTestConfig
	name   string
}

// standardScenarios are the reusable light/medium/heavy load profiles
// shared by the suites and the profiling entry point.
var standardScenarios = map[string]scenarioDefinition{
	"light": {
		name: "LightLoad",
		config: LoadTestConfig{
			Concurrency:    10,
			Duration:       30 * time.Second,
			RampUpTime:     5 * time.Second,
			RequestsPerSec: 50,
			PayloadSize:    1024,
			Endpoint:       "/api/v1/jobs",
			Method:         "POST",
			Headers:        map[string]string{"Content-Type": "application/json"},
		},
	},
	"medium": {
		name: "MediumLoad",
		config: LoadTestConfig{
			Concurrency:    50,
			Duration:       60 * time.Second,
			RampUpTime:     10 * time.Second,
			RequestsPerSec: 200,
			PayloadSize:    4096,
			Endpoint:       "/api/v1/jobs",
			Method:         "POST",
			Headers:        map[string]string{"Content-Type": "application/json"},
		},
	},
	"heavy": {
		name: "HeavyLoad",
		config: LoadTestConfig{
			Concurrency:    100,
			Duration:       120 * time.Second,
			RampUpTime:     20 * time.Second,
			RequestsPerSec: 500,
			PayloadSize:    8192,
			Endpoint:       "/api/v1/jobs",
			Method:         "POST",
			Headers:        map[string]string{"Content-Type": "application/json"},
		},
	},
}

// scenarioStep adapts a scenarioDefinition into a suite step that runs
// it through runLoadTestScenario (which also validates thresholds).
func scenarioStep(def scenarioDefinition) loadSuiteStep {
	return loadSuiteStep{
		name: def.name,
		run: func(t *testing.T, baseURL string) {
			runLoadTestScenario(t, baseURL, def.name, def.config)
		},
	}
}

// loadTestEnvironment carries the handle tests need: the base URL of
// the provisioned test server.
type loadTestEnvironment struct {
	baseURL string
}

// flagInit guards one-time flag parsing/normalization.
var flagInit sync.Once

// ensureLoadTestFlagsParsed parses (if needed) and normalizes the
// load-test flags exactly once. Unknown suite names silently fall back
// to "small"; the profile scenario key is lower-cased and trimmed.
func ensureLoadTestFlagsParsed() {
	flagInit.Do(func() {
		if !flag.Parsed() {
			flag.Parse()
		}
		if loadSuite != nil {
			suiteKey := normalizeKey(*loadSuite)
			if _, ok := loadSuites[suiteKey]; !ok {
				suiteKey = "small"
			}
			*loadSuite = suiteKey
		}
		if loadProfileScenario != nil {
			*loadProfileScenario = normalizeKey(*loadProfileScenario)
		}
	})
}

// normalizeKey lower-cases and trims a flag value so lookups are
// case/whitespace-insensitive.
func normalizeKey(value string) string {
	value = strings.TrimSpace(strings.ToLower(value))
	return value
}

// availableSuites returns the sorted suite names, used in error output.
func availableSuites() []string {
	names := make([]string, 0, len(loadSuites))
	for name := range loadSuites {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// availableProfileScenarios returns the sorted scenario names, used in
// skip/error output.
func availableProfileScenarios() []string {
	names := make([]string, 0, len(standardScenarios))
	for name := range standardScenarios {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// setupLoadTestEnvironment provisions Redis, a file-backed database,
// and an HTTP test server for a load test run. The test is skipped
// when Redis is unavailable; all resources are released via t.Cleanup.
func setupLoadTestEnvironment(t *testing.T) *loadTestEnvironment {
	tempDir := t.TempDir()
	rdb := setupLoadTestRedis(t)
	if rdb == nil {
		t.Skip("Redis not available for load tests")
	}
	dbPath := filepath.Join(tempDir, "load.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })
	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}
	server := setupLoadTestServer(db, rdb)
	t.Cleanup(server.Close)
	return &loadTestEnvironment{baseURL: server.URL}
}

// loadSuites maps a suite name to its ordered steps. Larger suites are
// supersets of smaller ones plus longer-running stress/endurance steps.
var loadSuites = map[string][]loadSuiteStep{
	"small": {
		scenarioStep(standardScenarios["light"]),
		{name: "SpikeTest", run: runSpikeTest},
	},
	"medium": {
		scenarioStep(standardScenarios["light"]),
		scenarioStep(standardScenarios["medium"]),
		{name: "SpikeTest", run: runSpikeTest},
	},
	"full": {
		scenarioStep(standardScenarios["light"]),
		scenarioStep(standardScenarios["medium"]),
		scenarioStep(standardScenarios["heavy"]),
		{name: "SpikeTest", run: runSpikeTest},
		{name: "EnduranceTest", run: runEnduranceTest},
		{name: "StressTest", run: runStressTest},
	},
}

// LoadTestConfig defines load testing parameters
type LoadTestConfig struct
{
	Headers        map[string]string // extra request headers (e.g. Content-Type)
	Endpoint       string            // request path appended to the base URL
	Method         string            // HTTP verb; "GET" sends no body
	Concurrency    int               // number of concurrent workers
	Duration       time.Duration     // total test length
	RampUpTime     time.Duration     // window over which worker starts are staggered
	RequestsPerSec int               // global rate limit; 0 disables limiting
	PayloadSize    int               // size of the random data embedded in POST payloads
}

// LoadTestResults contains test results and metrics
type LoadTestResults struct {
	Latencies      []time.Duration `json:"latencies"`
	Errors         []string        `json:"errors"`
	AvgLatency     time.Duration   `json:"avg_latency"`
	FailedReqs     int64           `json:"failed_requests"`
	MinLatency     time.Duration   `json:"min_latency"`
	MaxLatency     time.Duration   `json:"max_latency"`
	TotalRequests  int64           `json:"total_requests"`
	P95Latency     time.Duration   `json:"p95_latency"`
	P99Latency     time.Duration   `json:"p99_latency"`
	Throughput     float64         `json:"throughput_rps"`
	ErrorRate      float64         `json:"error_rate_percent"`
	TestDuration   time.Duration   `json:"test_duration"`
	SuccessfulReqs int64           `json:"successful_requests"`
}

// LoadTestRunner executes load tests
type LoadTestRunner struct {
	Config    LoadTestConfig
	BaseURL   string
	Client    *http.Client
	Results   *LoadTestResults
	latencies []time.Duration // all observed latencies, merged from worker buffers
	latencyMu sync.Mutex      // guards latencies
	errorMu   sync.Mutex      // guards Results.Errors
}

// NewLoadTestRunner creates a new load test runner. The HTTP client's
// connection pool is sized to the (normalized) worker count so
// keep-alive connections are reused rather than re-dialed under load.
func NewLoadTestRunner(baseURL string, config LoadTestConfig) *LoadTestRunner {
	concurrency := config.Concurrency
	if concurrency <= 0 {
		concurrency = 1
	}
	transport := &http.Transport{
		MaxIdleConns:        concurrency * 4,
		MaxIdleConnsPerHost: concurrency * 4,
		MaxConnsPerHost:     concurrency * 4,
		IdleConnTimeout:     90 * time.Second,
		DisableCompression:  true,
	}
	client := &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
	}
	// Pre-size the latency buffer to the expected total sample count to
	// avoid repeated growth while the test runs.
	expectedSamples := config.RequestsPerSec * int(config.Duration/time.Second)
	if expectedSamples <= 0 {
		expectedSamples = concurrency * 2
	}
	runner := &LoadTestRunner{
		Config:  config,
		BaseURL: baseURL,
		Client:  client,
		Results: &LoadTestResults{
			Latencies: []time.Duration{},
			Errors:    []string{},
		},
	}
	runner.latencies = make([]time.Duration, 0, expectedSamples)
	return runner
}

// TestLoadTestSuite runs the suite selected via -load-suite as a
// series of named subtests against a freshly provisioned environment.
func TestLoadTestSuite(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping load tests in short mode")
	}
	ensureLoadTestFlagsParsed()
	env := setupLoadTestEnvironment(t)
	suiteKey := *loadSuite
	steps, ok := loadSuites[suiteKey]
	if !ok || len(steps) == 0 {
		t.Fatalf("unknown load suite %q; available suites: %v", suiteKey, availableSuites())
	}
	t.Logf("Running %s load suite (%d steps)", suiteKey, len(steps))
	for _, step := range steps {
		step := step // per-iteration copy for the subtest closure
		t.Run(step.name, func(t *testing.T) {
			step.run(t, env.baseURL)
		})
	}
}

// TestLoadProfileScenario runs one scenario (default "medium") and only
// logs its metrics — no assertions — so it can be used under the Go
// profiler. -profile-norate disables the client-side rate limiter.
func TestLoadProfileScenario(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping load profiling in short mode")
	}
	ensureLoadTestFlagsParsed()
	scenarioKey := *loadProfileScenario
	if scenarioKey == "" {
		scenarioKey = "medium"
	}
	scenarioKey = normalizeKey(scenarioKey)
	scenario, ok := standardScenarios[scenarioKey]
	if !ok {
		t.Skipf("unknown profile scenario %q; available scenarios: %v", scenarioKey, availableProfileScenarios())
	}
	env := setupLoadTestEnvironment(t)
	config := scenario.config
	if *profileNoRate {
		// RequestsPerSec == 0 means "no limiter" in LoadTestRunner.Run.
		config.RequestsPerSec = 0
		t.Log("Profiling mode: rate limiting disabled")
	}
	runner := NewLoadTestRunner(env.baseURL, config)
	results := runner.Run()
	t.Logf("Profiling %s scenario (no assertions):", scenario.name)
	t.Logf(" Total requests: %d", results.TotalRequests)
	t.Logf(" Successful: %d", results.SuccessfulReqs)
	t.Logf(" Failed: %d", results.FailedReqs)
	t.Logf(" Throughput: %.4f RPS", results.Throughput)
	t.Logf(" Error rate: %.2f%%", results.ErrorRate)
	t.Logf(" Avg latency: %v", results.AvgLatency)
	t.Logf(" P95 latency: %v", results.P95Latency)
	t.Logf(" P99 latency: %v", results.P99Latency)
}

// runLoadTestScenario executes a single load test scenario
func runLoadTestScenario(t *testing.T, baseURL, scenarioName string, config LoadTestConfig) {
	t.Logf("Starting load test scenario: %s", scenarioName)
	runner := NewLoadTestRunner(baseURL, config)
	results := runner.Run()
	if minThroughput, maxErrorRate, maxP99Latency, ok := loadTestCriteria(scenarioName); ok {
		t.Logf("Load test criteria for %s:", scenarioName)
		t.Logf(" Min throughput: %.2f RPS", minThroughput)
t.Logf(" Max error rate: %.2f%%", maxErrorRate) t.Logf(" Max P99 latency: %.0fms", maxP99Latency) } t.Logf("Load test config for %s:", scenarioName) t.Logf(" Concurrency: %d", config.Concurrency) t.Logf(" Duration: %v", config.Duration) t.Logf(" Ramp up: %v", config.RampUpTime) t.Logf(" Target RPS: %d", config.RequestsPerSec) t.Logf(" Payload size: %d", config.PayloadSize) t.Logf(" Method: %s", config.Method) t.Logf(" Endpoint: %s", config.Endpoint) t.Logf("Load test results for %s:", scenarioName) t.Logf(" Total requests: %d", results.TotalRequests) t.Logf(" Successful: %d", results.SuccessfulReqs) t.Logf(" Failed: %d", results.FailedReqs) t.Logf(" Test duration: %v", results.TestDuration) t.Logf(" Test duration (seconds): %.6f", results.TestDuration.Seconds()) t.Logf(" Throughput: %.4f RPS", results.Throughput) t.Logf(" Error rate: %.2f%%", results.ErrorRate) t.Logf(" Avg latency: %v", results.AvgLatency) t.Logf(" P95 latency: %v", results.P95Latency) t.Logf(" P99 latency: %v", results.P99Latency) // Validate results against thresholds validateLoadTestResults(t, scenarioName, results) } func loadTestCriteria(scenarioName string) (minThroughput, maxErrorRate, maxP99Latency float64, ok bool) { switch scenarioName { case "LightLoad": return 40.0, 1.0, 100.0, true case "MediumLoad": return 180.0, 2.0, 200.0, true case "HeavyLoad": return 450.0, 5.0, 500.0, true default: return 0, 0, 0, false } } // Run executes the load test func (ltr *LoadTestRunner) Run() *LoadTestResults { start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), ltr.Config.Duration) defer cancel() var wg sync.WaitGroup concurrency := ltr.Config.Concurrency if concurrency <= 0 { concurrency = 1 } var limiter *rate.Limiter if ltr.Config.RequestsPerSec > 0 { limiter = rate.NewLimiter(rate.Limit(ltr.Config.RequestsPerSec), ltr.Config.Concurrency) } // Ramp up workers gradually rampUpInterval := time.Duration(0) if concurrency > 0 && ltr.Config.RampUpTime > 0 { rampUpInterval = 
ltr.Config.RampUpTime / time.Duration(concurrency) } // Start request workers for i := 0; i < concurrency; i++ { wg.Add(1) go ltr.worker(ctx, &wg, limiter, rampUpInterval*time.Duration(i), i) } wg.Wait() ltr.Results.TestDuration = time.Since(start) ltr.calculateMetrics() return ltr.Results } // worker executes requests continuously func (ltr *LoadTestRunner) worker( ctx context.Context, wg *sync.WaitGroup, limiter *rate.Limiter, rampDelay time.Duration, workerID int, ) { defer wg.Done() if rampDelay > 0 { select { case <-time.After(rampDelay): case <-ctx.Done(): return } } latencies := make([]time.Duration, 0, 256) errors := make([]string, 0, 32) for { select { case <-ctx.Done(): ltr.flushWorkerBuffers(latencies, errors) return default: } if limiter != nil { if err := limiter.Wait(ctx); err != nil { ltr.flushWorkerBuffers(latencies, errors) return } } latency, success, errMsg := ltr.makeRequest(ctx, workerID) latencies = append(latencies, latency) if success { atomic.AddInt64(<r.Results.SuccessfulReqs, 1) } else { atomic.AddInt64(<r.Results.FailedReqs, 1) if errMsg != "" { errors = append(errors, fmt.Sprintf("Worker %d: %s", workerID, errMsg)) } } atomic.AddInt64(<r.Results.TotalRequests, 1) } } func (ltr *LoadTestRunner) flushWorkerBuffers(latencies []time.Duration, errors []string) { if len(latencies) > 0 { ltr.latencyMu.Lock() ltr.latencies = append(ltr.latencies, latencies...) ltr.latencyMu.Unlock() } if len(errors) > 0 { ltr.errorMu.Lock() ltr.Results.Errors = append(ltr.Results.Errors, errors...) 
ltr.errorMu.Unlock() } } // makeRequest performs a single HTTP request with retry logic func (ltr *LoadTestRunner) makeRequest(ctx context.Context, workerID int) (time.Duration, bool, string) { start := time.Now() // Create request payload payload := ltr.generatePayload(workerID) var req *http.Request var err error if ltr.Config.Method == "GET" { req, err = http.NewRequestWithContext(ctx, ltr.Config.Method, ltr.BaseURL+ltr.Config.Endpoint, nil) } else { req, err = http.NewRequestWithContext(ctx, ltr.Config.Method, ltr.BaseURL+ltr.Config.Endpoint, bytes.NewBuffer(payload)) } if err != nil { return time.Since(start), false, fmt.Sprintf("Failed to create request: %v", err) } // Set headers for key, value := range ltr.Config.Headers { req.Header.Set(key, value) } // Retry logic for transient failures maxRetries := 3 for attempt := 0; attempt <= maxRetries; attempt++ { if attempt > 0 { // Exponential backoff: 100ms, 200ms, 400ms backoff := time.Duration(100*int(math.Pow(2, float64(attempt-1)))) * time.Millisecond select { case <-time.After(backoff): case <-ctx.Done(): return time.Since(start), false, "context cancelled during retry backoff" } } resp, err := ltr.Client.Do(req) if err != nil { if attempt == maxRetries { return time.Since(start), false, fmt.Sprintf("Request failed after %d attempts: %v", maxRetries+1, err) } continue // Retry on network errors } success := resp.StatusCode >= 200 && resp.StatusCode < 400 if !success { resp.Body.Close() // Don't retry on client errors (4xx), only on server errors (5xx) if resp.StatusCode >= 400 && resp.StatusCode < 500 { return time.Since(start), false, fmt.Sprintf("Client error HTTP %d (not retried)", resp.StatusCode) } if attempt == maxRetries { return time.Since(start), false, fmt.Sprintf( "Server error HTTP %d after %d attempts", resp.StatusCode, maxRetries+1, ) } continue // Retry on server errors } resp.Body.Close() return time.Since(start), true, "" } return time.Since(start), false, "max retries exceeded" } // 
generatePayload creates test payload data func (ltr *LoadTestRunner) generatePayload(workerID int) []byte { if ltr.Config.Method == "GET" { return nil } payload := map[string]interface{}{ "job_name": fmt.Sprintf("load-test-job-%d-%d", workerID, time.Now().UnixNano()), "args": map[string]interface{}{ "model": "test-model", "data": generateRandomData(ltr.Config.PayloadSize), "worker_id": workerID, }, "priority": workerID % 3, } data, _ := json.Marshal(payload) return data } // calculateMetrics computes performance metrics from collected latencies func (ltr *LoadTestRunner) calculateMetrics() { if len(ltr.latencies) == 0 { return } // Sort latencies for percentile calculations sorted := make([]time.Duration, len(ltr.latencies)) copy(sorted, ltr.latencies) sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] }) ltr.Results.MinLatency = sorted[0] ltr.Results.MaxLatency = sorted[len(sorted)-1] // Calculate average var total time.Duration for _, latency := range sorted { total += latency } ltr.Results.AvgLatency = total / time.Duration(len(sorted)) // Calculate percentiles p95Index := int(float64(len(sorted)) * 0.95) p99Index := int(float64(len(sorted)) * 0.99) if p95Index < len(sorted) { ltr.Results.P95Latency = sorted[p95Index] } if p99Index < len(sorted) { ltr.Results.P99Latency = sorted[p99Index] } ltr.Results.Latencies = sorted // Calculate throughput and error rate testDurationSeconds := ltr.Results.TestDuration.Seconds() if testDurationSeconds > 0 { ltr.Results.Throughput = float64(ltr.Results.TotalRequests) / testDurationSeconds } else { ltr.Results.Throughput = 0 } if ltr.Results.TotalRequests > 0 { ltr.Results.ErrorRate = float64(ltr.Results.FailedReqs) / float64(ltr.Results.TotalRequests) * 100 } } // runSpikeTest tests system behavior under sudden load spikes func runSpikeTest(t *testing.T, baseURL string) { t.Log("Running spike test") config := LoadTestConfig{ Concurrency: 50, // Reduced from 200 Duration: 30 * time.Second, RampUpTime: 2 * 
time.Second, // Slower ramp-up from 1s RequestsPerSec: 200, // Reduced from 1000 PayloadSize: 2048, Endpoint: "/api/v1/jobs", Method: "POST", Headers: map[string]string{"Content-Type": "application/json"}, } runner := NewLoadTestRunner(baseURL, config) results := runner.Run() t.Logf("Spike test results:") t.Logf(" Throughput: %.4f RPS", results.Throughput) t.Logf(" Error rate: %.2f%%", results.ErrorRate) t.Logf(" P99 latency: %v", results.P99Latency) // Spike test should allow higher error rate but still maintain reasonable performance if results.ErrorRate > 20.0 { t.Errorf("Spike test error rate too high: %.2f%%", results.ErrorRate) } } // runEnduranceTest tests system performance over extended periods func runEnduranceTest(t *testing.T, baseURL string) { t.Log("Running endurance test") config := LoadTestConfig{ Concurrency: 25, Duration: 2 * time.Minute, // Reduced from 10 minutes RampUpTime: 15 * time.Second, // Reduced from 30s RequestsPerSec: 100, PayloadSize: 2048, // Reduced from 4096 Endpoint: "/api/v1/jobs", Method: "POST", Headers: map[string]string{"Content-Type": "application/json"}, } runner := NewLoadTestRunner(baseURL, config) results := runner.Run() t.Logf("Endurance test results:") t.Logf(" Total requests: %d", results.TotalRequests) t.Logf(" Throughput: %.4f RPS", results.Throughput) t.Logf(" Error rate: %.2f%%", results.ErrorRate) t.Logf(" Avg latency: %v", results.AvgLatency) // Endurance test should maintain stable performance if results.ErrorRate > 5.0 { t.Errorf("Endurance test error rate too high: %.2f%%", results.ErrorRate) } } // runStressTest tests system limits and breaking points func runStressTest(t *testing.T, baseURL string) { t.Log("Running stress test") // Gradually increase load until system breaks maxConcurrency := 200 // Reduced from 500 for concurrency := 50; concurrency <= maxConcurrency; concurrency += 50 { // Start from 50, increment by 50 config := LoadTestConfig{ Concurrency: concurrency, Duration: 20 * time.Second, // 
Reduced from 60s RampUpTime: 5 * time.Second, // Reduced from 10s RequestsPerSec: concurrency * 3, // Reduced from *5 PayloadSize: 4096, // Reduced from 8192 Endpoint: "/api/v1/jobs", Method: "POST", Headers: map[string]string{"Content-Type": "application/json"}, } runner := NewLoadTestRunner(baseURL, config) results := runner.Run() t.Logf("Stress test at concurrency %d:", concurrency) t.Logf(" Throughput: %.4f RPS", results.Throughput) t.Logf(" Error rate: %.2f%%", results.ErrorRate) // Stop test if error rate becomes too high if results.ErrorRate > 50.0 { t.Logf("System breaking point reached at concurrency %d", concurrency) break } } } // validateLoadTestResults checks if results meet performance criteria func validateLoadTestResults(t *testing.T, scenarioName string, results *LoadTestResults) { minThroughput, maxErrorRate, maxP99Latency, ok := loadTestCriteria(scenarioName) if !ok { return } if results.ErrorRate > maxErrorRate { t.Errorf("%s error rate too high: %.2f%% (max: %.2f%%)", scenarioName, results.ErrorRate, maxErrorRate) } if float64(results.P99Latency.Nanoseconds())/1e6 > maxP99Latency { t.Errorf("%s P99 latency too high: %v (max: %.0fms)", scenarioName, results.P99Latency, maxP99Latency) } if results.Throughput < minThroughput { t.Errorf("%s throughput too low: %.4f RPS (min: %.2f RPS)", scenarioName, results.Throughput, minThroughput) } } // Performance benchmarks for tracking improvements func BenchmarkLoadTestLightLoad(b *testing.B) { config := standardScenarios["light"].config for i := 0; i < b.N; i++ { b.StopTimer() server := setupLoadTestServer(nil, nil) baseURL := server.URL b.Cleanup(server.Close) runner := NewLoadTestRunner(baseURL, config) b.StartTimer() results := runner.Run() b.StopTimer() // Track key metrics b.ReportMetric(float64(results.Throughput), "RPS") b.ReportMetric(results.ErrorRate, "error_rate") b.ReportMetric(float64(results.P99Latency.Nanoseconds())/1e6, "P99_latency_ms") } } func setupLoadTestRedis(t *testing.T) 
*redis.Client { rdb := redis.NewClient(&redis.Options{ Addr: "localhost:6379", Password: "", DB: 7, // Use DB 7 for load tests }) ctx := context.Background() if err := rdb.Ping(ctx).Err(); err != nil { t.Skipf("Redis not available for load tests: %v", err) return nil } rdb.FlushDB(ctx) t.Cleanup(func() { rdb.FlushDB(ctx) _ = rdb.Close() }) return rdb } func setupLoadTestServer(_ *storage.DB, _ *redis.Client) *httptest.Server { mux := http.NewServeMux() // Simple API endpoints for load testing mux.HandleFunc("/api/v1/jobs", func(w http.ResponseWriter, r *http.Request) { if r.Method == "POST" { w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(map[string]string{"id": "test-job-id"}) } else { w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode([]map[string]string{{"id": "test-job-id", "status": "pending"}}) } }) mux.HandleFunc("/api/v1/jobs/", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(map[string]string{"status": "pending"}) }) server := httptest.NewUnstartedServer(mux) // Optimize server configuration for better throughput server.Config.ReadTimeout = 30 * time.Second server.Config.WriteTimeout = 30 * time.Second server.Config.IdleTimeout = 120 * time.Second server.Start() return server } func generateRandomData(size int) string { data := make([]byte, size) for i := range data { data[i] = byte(i % 256) } return string(data) }