fetch_ml/tests/load/load_test.go
Jeremie Fraeys ea15af1833 Fix multi-user authentication and clean up debug code
- Fix YAML tags in auth config struct (json -> yaml)
- Update CLI configs to use pre-hashed API keys
- Remove double hashing in WebSocket client
- Fix port mapping (9102 -> 9103) in CLI commands
- Update permission keys to use jobs:read, jobs:create, etc.
- Clean up all debug logging from CLI and server
- All user roles now authenticate correctly:
  * Admin: Can queue jobs and see all jobs
  * Researcher: Can queue jobs and see own jobs
  * Analyst: Can see status (read-only access)

Multi-user authentication is now fully functional.
2025-12-06 12:35:32 -05:00

744 lines
20 KiB
Go

package load
import (
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
	fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
	"github.com/redis/go-redis/v9"
	"golang.org/x/time/rate"
)
// Test flags selecting which load suite or profiling scenario to execute.
// They are normalized exactly once by ensureLoadTestFlagsParsed before use.
var (
loadSuite = flag.String("load-suite", "small", "load test suite to run: small, medium, or full")
loadProfileScenario = flag.String("load-profile", "", "run a specific profiling scenario (e.g. medium)")
profileNoRate = flag.Bool("profile-norate", false, "disable rate limiting for profiling")
)
// loadSuiteStep is one named step of a load suite; run executes the step
// against the test server's base URL inside a subtest.
type loadSuiteStep struct {
name string
run func(t *testing.T, baseURL string)
}
// scenarioDefinition pairs a human-readable scenario name with the load
// test configuration that drives it.
type scenarioDefinition struct {
name string
config LoadTestConfig
}
// standardScenarios maps scenario keys (as accepted by the -load-profile
// flag) to predefined load configurations of increasing intensity.
var standardScenarios = map[string]scenarioDefinition{
"light": {
name: "LightLoad",
config: LoadTestConfig{
Concurrency: 10,
Duration: 30 * time.Second,
RampUpTime: 5 * time.Second,
RequestsPerSec: 50,
PayloadSize: 1024,
Endpoint: "/api/v1/jobs",
Method: "POST",
Headers: map[string]string{"Content-Type": "application/json"},
},
},
"medium": {
name: "MediumLoad",
config: LoadTestConfig{
Concurrency: 50,
Duration: 60 * time.Second,
RampUpTime: 10 * time.Second,
RequestsPerSec: 200,
PayloadSize: 4096,
Endpoint: "/api/v1/jobs",
Method: "POST",
Headers: map[string]string{"Content-Type": "application/json"},
},
},
"heavy": {
name: "HeavyLoad",
config: LoadTestConfig{
Concurrency: 100,
Duration: 120 * time.Second,
RampUpTime: 20 * time.Second,
RequestsPerSec: 500,
PayloadSize: 8192,
Endpoint: "/api/v1/jobs",
Method: "POST",
Headers: map[string]string{"Content-Type": "application/json"},
},
},
}
// scenarioStep adapts a scenarioDefinition into a loadSuiteStep that runs
// the scenario against the suite's base URL.
func scenarioStep(def scenarioDefinition) loadSuiteStep {
	step := loadSuiteStep{name: def.name}
	step.run = func(t *testing.T, baseURL string) {
		runLoadTestScenario(t, baseURL, def.name, def.config)
	}
	return step
}
// loadTestEnvironment holds the per-test environment handed to suite steps.
type loadTestEnvironment struct {
baseURL string
}
var flagInit sync.Once
// ensureLoadTestFlagsParsed parses the test flags exactly once and
// normalizes the suite/scenario selections. An unknown suite name silently
// falls back to the "small" suite.
func ensureLoadTestFlagsParsed() {
	flagInit.Do(func() {
		if !flag.Parsed() {
			flag.Parse()
		}
		if loadSuite != nil {
			key := normalizeKey(*loadSuite)
			if _, known := loadSuites[key]; !known {
				key = "small"
			}
			*loadSuite = key
		}
		if loadProfileScenario != nil {
			*loadProfileScenario = normalizeKey(*loadProfileScenario)
		}
	})
}
// normalizeKey lowercases a flag value and strips surrounding whitespace so
// lookups into the suite/scenario maps are case- and space-insensitive.
func normalizeKey(value string) string {
	return strings.ToLower(strings.TrimSpace(value))
}
// availableSuites returns the registered suite names in sorted order, for
// use in error messages.
func availableSuites() []string {
	names := make([]string, 0, len(loadSuites))
	for suite := range loadSuites {
		names = append(names, suite)
	}
	sort.Strings(names)
	return names
}
// availableProfileScenarios returns the registered profiling scenario keys
// in sorted order, for use in skip/error messages.
func availableProfileScenarios() []string {
	keys := make([]string, 0, len(standardScenarios))
	for key := range standardScenarios {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
// setupLoadTestEnvironment provisions everything a load test needs — a
// Redis connection, a temporary SQLite database initialized with the test
// schema, and a stub HTTP server — registering cleanup on t for all of
// them. Skips the test when Redis is unavailable.
func setupLoadTestEnvironment(t *testing.T) *loadTestEnvironment {
tempDir := t.TempDir()
rdb := setupLoadTestRedis(t)
// Defensive: setupLoadTestRedis already skips on failure, but guard
// against a nil client in case it returns anyway.
if rdb == nil {
t.Skip("Redis not available for load tests")
}
dbPath := filepath.Join(tempDir, "load.db")
db, err := storage.NewDBFromPath(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := db.Initialize(fixtures.TestSchema); err != nil {
t.Fatalf("Failed to initialize database: %v", err)
}
server := setupLoadTestServer(db, rdb)
// Server is closed before the DB (cleanups run LIFO).
t.Cleanup(server.Close)
return &loadTestEnvironment{baseURL: server.URL}
}
// loadSuites maps -load-suite values to the ordered steps each suite runs.
// Larger suites are supersets of smaller ones plus extra stress scenarios.
var loadSuites = map[string][]loadSuiteStep{
"small": {
scenarioStep(standardScenarios["light"]),
{name: "SpikeTest", run: runSpikeTest},
},
"medium": {
scenarioStep(standardScenarios["light"]),
scenarioStep(standardScenarios["medium"]),
{name: "SpikeTest", run: runSpikeTest},
},
"full": {
scenarioStep(standardScenarios["light"]),
scenarioStep(standardScenarios["medium"]),
scenarioStep(standardScenarios["heavy"]),
{name: "SpikeTest", run: runSpikeTest},
{name: "EnduranceTest", run: runEnduranceTest},
{name: "StressTest", run: runStressTest},
},
}
// LoadTestConfig defines load testing parameters.
// RequestsPerSec <= 0 disables rate limiting; Concurrency <= 0 is
// normalized to a single worker by the runner.
type LoadTestConfig struct {
Concurrency int // Number of concurrent users/workers
Duration time.Duration // How long to run the test
RampUpTime time.Duration // Time to ramp up to full concurrency
RequestsPerSec int // Target requests per second
PayloadSize int // Size of test payloads in bytes
Endpoint string // API endpoint to test
Method string // HTTP method
Headers map[string]string // Extra HTTP headers applied to every request
}
// LoadTestResults contains test results and metrics, aggregated by
// calculateMetrics after all workers have finished.
type LoadTestResults struct {
TotalRequests int64 `json:"total_requests"`
SuccessfulReqs int64 `json:"successful_requests"`
FailedReqs int64 `json:"failed_requests"`
Latencies []time.Duration `json:"latencies"` // sorted ascending once calculateMetrics has run
MinLatency time.Duration `json:"min_latency"`
MaxLatency time.Duration `json:"max_latency"`
AvgLatency time.Duration `json:"avg_latency"`
P95Latency time.Duration `json:"p95_latency"`
P99Latency time.Duration `json:"p99_latency"`
Throughput float64 `json:"throughput_rps"` // total requests / test duration
ErrorRate float64 `json:"error_rate_percent"`
TestDuration time.Duration `json:"test_duration"`
Errors []string `json:"errors"` // formatted "Worker N: msg" entries
}
// LoadTestRunner executes load tests. Counter fields inside Results are
// updated atomically by workers; the slices below are mutex-guarded.
type LoadTestRunner struct {
Config LoadTestConfig
BaseURL string
Client *http.Client
Results *LoadTestResults
latencies []time.Duration // all samples merged from workers; guarded by latencyMu
latencyMu sync.Mutex // protects latencies
errorMu sync.Mutex // protects Results.Errors
}
// NewLoadTestRunner builds a runner for the given scenario configuration,
// wiring up an HTTP client whose connection pool is sized to the requested
// concurrency (at least one worker).
func NewLoadTestRunner(baseURL string, config LoadTestConfig) *LoadTestRunner {
	workers := config.Concurrency
	if workers <= 0 {
		workers = 1
	}

	poolSize := workers * 4
	client := &http.Client{
		Timeout: 30 * time.Second,
		Transport: &http.Transport{
			MaxIdleConns:        poolSize,
			MaxIdleConnsPerHost: poolSize,
			MaxConnsPerHost:     poolSize,
			IdleConnTimeout:     90 * time.Second,
			DisableCompression:  true,
		},
	}

	// Pre-size the latency buffer for the expected number of samples.
	sampleHint := config.RequestsPerSec * int(config.Duration/time.Second)
	if sampleHint <= 0 {
		sampleHint = workers * 2
	}

	return &LoadTestRunner{
		Config:  config,
		BaseURL: baseURL,
		Client:  client,
		Results: &LoadTestResults{
			Latencies: []time.Duration{},
			Errors:    []string{},
		},
		latencies: make([]time.Duration, 0, sampleHint),
	}
}
// TestLoadTestSuite runs the load suite selected by the -load-suite flag as
// a sequence of subtests against a freshly provisioned test environment.
func TestLoadTestSuite(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping load tests in short mode")
	}
	ensureLoadTestFlagsParsed()

	env := setupLoadTestEnvironment(t)

	suiteName := *loadSuite
	plan, found := loadSuites[suiteName]
	if !found || len(plan) == 0 {
		t.Fatalf("unknown load suite %q; available suites: %v", suiteName, availableSuites())
	}

	t.Logf("Running %s load suite (%d steps)", suiteName, len(plan))
	for _, step := range plan {
		step := step // per-iteration copy for the subtest closure (pre-Go 1.22 semantics)
		t.Run(step.name, func(t *testing.T) {
			step.run(t, env.baseURL)
		})
	}
}
// TestLoadProfileScenario runs a single scenario (chosen via -load-profile,
// defaulting to "medium") and logs metrics without asserting thresholds —
// intended for use under the Go profiler.
func TestLoadProfileScenario(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping load profiling in short mode")
	}
	ensureLoadTestFlagsParsed()

	key := *loadProfileScenario
	if key == "" {
		key = "medium"
	}
	key = normalizeKey(key)
	scenario, found := standardScenarios[key]
	if !found {
		t.Skipf("unknown profile scenario %q; available scenarios: %v", key, availableProfileScenarios())
	}

	env := setupLoadTestEnvironment(t)

	cfg := scenario.config
	if *profileNoRate {
		// Let workers run unthrottled so the profile captures server cost.
		cfg.RequestsPerSec = 0
		t.Log("Profiling mode: rate limiting disabled")
	}

	results := NewLoadTestRunner(env.baseURL, cfg).Run()

	t.Logf("Profiling %s scenario (no assertions):", scenario.name)
	t.Logf(" Total requests: %d", results.TotalRequests)
	t.Logf(" Successful: %d", results.SuccessfulReqs)
	t.Logf(" Failed: %d", results.FailedReqs)
	t.Logf(" Throughput: %.2f RPS", results.Throughput)
	t.Logf(" Error rate: %.2f%%", results.ErrorRate)
	t.Logf(" Avg latency: %v", results.AvgLatency)
	t.Logf(" P95 latency: %v", results.P95Latency)
	t.Logf(" P99 latency: %v", results.P99Latency)
}
// runLoadTestScenario executes one load scenario against baseURL, logs the
// collected metrics, and validates them against the scenario's thresholds.
func runLoadTestScenario(t *testing.T, baseURL, scenarioName string, config LoadTestConfig) {
	t.Logf("Starting load test scenario: %s", scenarioName)

	results := NewLoadTestRunner(baseURL, config).Run()

	t.Logf("Load test results for %s:", scenarioName)
	t.Logf(" Total requests: %d", results.TotalRequests)
	t.Logf(" Successful: %d", results.SuccessfulReqs)
	t.Logf(" Failed: %d", results.FailedReqs)
	t.Logf(" Throughput: %.2f RPS", results.Throughput)
	t.Logf(" Error rate: %.2f%%", results.ErrorRate)
	t.Logf(" Avg latency: %v", results.AvgLatency)
	t.Logf(" P95 latency: %v", results.P95Latency)
	t.Logf(" P99 latency: %v", results.P99Latency)

	validateLoadTestResults(t, scenarioName, results)
}
// Run executes the load test: it ramps up Config.Concurrency workers over
// Config.RampUpTime, lets them issue requests (optionally rate limited to
// Config.RequestsPerSec) until Config.Duration elapses, then aggregates
// metrics into Results and returns them.
func (ltr *LoadTestRunner) Run() *LoadTestResults {
	start := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), ltr.Config.Duration)
	defer cancel()

	concurrency := ltr.Config.Concurrency
	if concurrency <= 0 {
		concurrency = 1
	}

	// Shared token-bucket limiter across all workers. Burst is sized to the
	// NORMALIZED worker count: the original used the raw Config.Concurrency,
	// so a zero/negative configured concurrency produced a zero-burst
	// limiter whose Wait call always failed. (The unused effectiveRPS
	// computation that previously lived here has been removed.)
	var limiter *rate.Limiter
	if ltr.Config.RequestsPerSec > 0 {
		limiter = rate.NewLimiter(rate.Limit(ltr.Config.RequestsPerSec), concurrency)
	}

	// Stagger worker start times so full concurrency is reached after
	// RampUpTime.
	rampUpInterval := time.Duration(0)
	if ltr.Config.RampUpTime > 0 {
		rampUpInterval = ltr.Config.RampUpTime / time.Duration(concurrency)
	}

	var wg sync.WaitGroup
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go ltr.worker(ctx, &wg, limiter, rampUpInterval*time.Duration(i), i)
	}
	wg.Wait()

	ltr.Results.TestDuration = time.Since(start)
	ltr.calculateMetrics()
	return ltr.Results
}
// worker issues requests until ctx is cancelled. Latencies and error
// messages are batched in per-worker buffers and merged into the shared
// runner state once on exit to minimize lock contention; success/failure
// counters are updated atomically per request.
func (ltr *LoadTestRunner) worker(ctx context.Context, wg *sync.WaitGroup, limiter *rate.Limiter, rampDelay time.Duration, workerID int) {
	defer wg.Done()

	// Honor this worker's ramp-up delay, bailing out early on cancellation.
	if rampDelay > 0 {
		select {
		case <-time.After(rampDelay):
		case <-ctx.Done():
			return
		}
	}

	latencies := make([]time.Duration, 0, 256)
	errors := make([]string, 0, 32)
	// BUG FIX: the original `defer ltr.flushWorkerBuffers(latencies, errors)`
	// evaluated its arguments at defer time, capturing the empty (len 0)
	// slice headers — later appends reassign the locals, so nothing was ever
	// flushed and Latencies/Errors stayed empty. Deferring a closure reads
	// the final slice values on exit instead.
	defer func() { ltr.flushWorkerBuffers(latencies, errors) }()

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		if limiter != nil {
			if err := limiter.Wait(ctx); err != nil {
				// Context cancelled or deadline exceeded while waiting.
				return
			}
		}
		latency, success, errMsg := ltr.makeRequest(ctx, workerID)
		latencies = append(latencies, latency)
		if success {
			atomic.AddInt64(&ltr.Results.SuccessfulReqs, 1)
		} else {
			atomic.AddInt64(&ltr.Results.FailedReqs, 1)
			if errMsg != "" {
				errors = append(errors, fmt.Sprintf("Worker %d: %s", workerID, errMsg))
			}
		}
		atomic.AddInt64(&ltr.Results.TotalRequests, 1)
	}
}
// flushWorkerBuffers merges a worker's locally buffered latency samples and
// error messages into the shared, mutex-guarded runner state. Empty buffers
// skip their lock entirely.
func (ltr *LoadTestRunner) flushWorkerBuffers(workerLatencies []time.Duration, workerErrors []string) {
	if len(workerLatencies) != 0 {
		ltr.latencyMu.Lock()
		ltr.latencies = append(ltr.latencies, workerLatencies...)
		ltr.latencyMu.Unlock()
	}
	if len(workerErrors) != 0 {
		ltr.errorMu.Lock()
		ltr.Results.Errors = append(ltr.Results.Errors, workerErrors...)
		ltr.errorMu.Unlock()
	}
}
// makeRequest performs a single HTTP request against the configured
// endpoint and reports its latency, whether it succeeded (2xx/3xx status),
// and an error message for failures.
func (ltr *LoadTestRunner) makeRequest(ctx context.Context, workerID int) (time.Duration, bool, string) {
	start := time.Now()

	// GET requests carry no body; everything else gets a generated payload.
	var body io.Reader
	if ltr.Config.Method != "GET" {
		body = bytes.NewReader(ltr.generatePayload(workerID))
	}
	req, err := http.NewRequestWithContext(ctx, ltr.Config.Method, ltr.BaseURL+ltr.Config.Endpoint, body)
	if err != nil {
		return time.Since(start), false, fmt.Sprintf("Failed to create request: %v", err)
	}
	for key, value := range ltr.Config.Headers {
		req.Header.Set(key, value)
	}

	resp, err := ltr.Client.Do(req)
	if err != nil {
		return time.Since(start), false, fmt.Sprintf("Request failed: %v", err)
	}
	// FIX: drain the body before closing so the transport can reuse the
	// keep-alive connection. The original closed an unread body, which
	// prevents connection reuse and skews a load tester's own throughput.
	defer func() {
		_, _ = io.Copy(io.Discard, resp.Body)
		_ = resp.Body.Close()
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		return time.Since(start), false, fmt.Sprintf("HTTP %d", resp.StatusCode)
	}
	return time.Since(start), true, ""
}
// generatePayload builds the JSON request body for a worker. GET requests
// carry no body, so nil is returned for them.
func (ltr *LoadTestRunner) generatePayload(workerID int) []byte {
	if ltr.Config.Method == "GET" {
		return nil
	}
	args := map[string]interface{}{
		"model":     "test-model",
		"data":      generateRandomData(ltr.Config.PayloadSize),
		"worker_id": workerID,
	}
	body := map[string]interface{}{
		"job_name": fmt.Sprintf("load-test-job-%d-%d", workerID, time.Now().UnixNano()),
		"args":     args,
		"priority": workerID % 3,
	}
	// Marshal of a map of JSON-safe values cannot fail in practice.
	data, _ := json.Marshal(body)
	return data
}
// calculateMetrics derives min/max/avg/percentile latency, throughput, and
// error rate from the latencies collected by the workers. It must run after
// all workers have flushed their buffers (i.e. after wg.Wait in Run).
func (ltr *LoadTestRunner) calculateMetrics() {
	n := len(ltr.latencies)
	if n == 0 {
		return
	}

	// Work on a sorted copy so percentile lookups are simple index math.
	sorted := make([]time.Duration, n)
	copy(sorted, ltr.latencies)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })

	ltr.Results.MinLatency = sorted[0]
	ltr.Results.MaxLatency = sorted[n-1]

	var sum time.Duration
	for _, sample := range sorted {
		sum += sample
	}
	ltr.Results.AvgLatency = sum / time.Duration(n)

	if idx := int(float64(n) * 0.95); idx < n {
		ltr.Results.P95Latency = sorted[idx]
	}
	if idx := int(float64(n) * 0.99); idx < n {
		ltr.Results.P99Latency = sorted[idx]
	}
	ltr.Results.Latencies = sorted

	ltr.Results.Throughput = float64(ltr.Results.TotalRequests) / ltr.Results.TestDuration.Seconds()
	if ltr.Results.TotalRequests > 0 {
		ltr.Results.ErrorRate = float64(ltr.Results.FailedReqs) / float64(ltr.Results.TotalRequests) * 100
	}
}
// runSpikeTest hammers the server with a near-instant ramp-up to 200
// workers, tolerating a higher error rate (20%) than steady-state tests.
func runSpikeTest(t *testing.T, baseURL string) {
	t.Log("Running spike test")

	cfg := LoadTestConfig{
		Concurrency:    200,
		Duration:       30 * time.Second,
		RampUpTime:     1 * time.Second, // near-instant ramp simulates a traffic spike
		RequestsPerSec: 1000,
		PayloadSize:    2048,
		Endpoint:       "/api/v1/jobs",
		Method:         "POST",
		Headers:        map[string]string{"Content-Type": "application/json"},
	}

	results := NewLoadTestRunner(baseURL, cfg).Run()

	t.Logf("Spike test results:")
	t.Logf(" Throughput: %.2f RPS", results.Throughput)
	t.Logf(" Error rate: %.2f%%", results.ErrorRate)
	t.Logf(" P99 latency: %v", results.P99Latency)

	// A spike may degrade service, but errors beyond 20% fail the test.
	if results.ErrorRate > 20.0 {
		t.Errorf("Spike test error rate too high: %.2f%%", results.ErrorRate)
	}
}
// runEnduranceTest exercises the server at moderate load for an extended
// ten-minute window, expecting stable behavior (error rate under 5%).
func runEnduranceTest(t *testing.T, baseURL string) {
	t.Log("Running endurance test")

	cfg := LoadTestConfig{
		Concurrency:    25,
		Duration:       10 * time.Minute, // long window to surface slow degradation
		RampUpTime:     30 * time.Second,
		RequestsPerSec: 100,
		PayloadSize:    4096,
		Endpoint:       "/api/v1/jobs",
		Method:         "POST",
		Headers:        map[string]string{"Content-Type": "application/json"},
	}

	results := NewLoadTestRunner(baseURL, cfg).Run()

	t.Logf("Endurance test results:")
	t.Logf(" Total requests: %d", results.TotalRequests)
	t.Logf(" Throughput: %.2f RPS", results.Throughput)
	t.Logf(" Error rate: %.2f%%", results.ErrorRate)
	t.Logf(" Avg latency: %v", results.AvgLatency)

	if results.ErrorRate > 5.0 {
		t.Errorf("Endurance test error rate too high: %.2f%%", results.ErrorRate)
	}
}
// runStressTest raises concurrency in steps of 100 (up to 500) until the
// error rate exceeds 50%, identifying the system's breaking point.
func runStressTest(t *testing.T, baseURL string) {
	t.Log("Running stress test")

	const (
		startConcurrency = 100
		maxConcurrency   = 500
		stepSize         = 100
	)
	for workers := startConcurrency; workers <= maxConcurrency; workers += stepSize {
		cfg := LoadTestConfig{
			Concurrency:    workers,
			Duration:       60 * time.Second,
			RampUpTime:     10 * time.Second,
			RequestsPerSec: workers * 5,
			PayloadSize:    8192,
			Endpoint:       "/api/v1/jobs",
			Method:         "POST",
			Headers:        map[string]string{"Content-Type": "application/json"},
		}

		results := NewLoadTestRunner(baseURL, cfg).Run()

		t.Logf("Stress test at concurrency %d:", workers)
		t.Logf(" Throughput: %.2f RPS", results.Throughput)
		t.Logf(" Error rate: %.2f%%", results.ErrorRate)

		// Stop escalating once the system is clearly past its limit.
		if results.ErrorRate > 50.0 {
			t.Logf("System breaking point reached at concurrency %d", workers)
			break
		}
	}
}
// validateLoadTestResults asserts scenario-specific thresholds on error
// rate, P99 latency, and throughput. Scenarios without registered
// thresholds are not validated.
func validateLoadTestResults(t *testing.T, scenarioName string, results *LoadTestResults) {
	type thresholds struct {
		maxErrorRate  float64 // percent
		maxP99Millis  float64 // milliseconds
		minThroughput float64 // requests per second
	}
	limits := map[string]thresholds{
		"LightLoad":  {maxErrorRate: 1.0, maxP99Millis: 100.0, minThroughput: 40.0},
		"MediumLoad": {maxErrorRate: 2.0, maxP99Millis: 200.0, minThroughput: 180.0},
		"HeavyLoad":  {maxErrorRate: 5.0, maxP99Millis: 500.0, minThroughput: 450.0},
	}

	limit, ok := limits[scenarioName]
	if !ok {
		return
	}
	if results.ErrorRate > limit.maxErrorRate {
		t.Errorf("%s error rate too high: %.2f%% (max: %.2f%%)", scenarioName, results.ErrorRate, limit.maxErrorRate)
	}
	if float64(results.P99Latency.Nanoseconds())/1e6 > limit.maxP99Millis {
		t.Errorf("%s P99 latency too high: %v (max: %.0fms)", scenarioName, results.P99Latency, limit.maxP99Millis)
	}
	if results.Throughput < limit.minThroughput {
		t.Errorf("%s throughput too low: %.2f RPS (min: %.2f RPS)", scenarioName, results.Throughput, limit.minThroughput)
	}
}
// Helper functions

// setupLoadTestRedis connects to the local Redis instance (DB 7, reserved
// for load tests), skipping the test when Redis is unreachable. The
// database is flushed before the test and again during cleanup.
func setupLoadTestRedis(t *testing.T) *redis.Client {
	client := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Password: "",
		DB:       7, // Use DB 7 for load tests
	})
	ctx := context.Background()
	if err := client.Ping(ctx).Err(); err != nil {
		t.Skipf("Redis not available for load tests: %v", err)
		return nil
	}
	client.FlushDB(ctx)
	t.Cleanup(func() {
		client.FlushDB(ctx)
		_ = client.Close()
	})
	return client
}
// setupLoadTestServer starts an httptest server exposing stub job
// endpoints. The DB and Redis handles are accepted for signature parity
// with the real server setup but are unused by the stubs.
func setupLoadTestServer(_ *storage.DB, _ *redis.Client) *httptest.Server {
	mux := http.NewServeMux()
	mux.HandleFunc("/api/v1/jobs", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case "POST":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]string{"id": "test-job-id"})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode([]map[string]string{{"id": "test-job-id", "status": "pending"}})
		}
	})
	mux.HandleFunc("/api/v1/jobs/", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_ = json.NewEncoder(w).Encode(map[string]string{"status": "pending"})
	})

	srv := httptest.NewUnstartedServer(mux)
	// Generous timeouts keep the stub from dropping connections under load.
	srv.Config.ReadTimeout = 30 * time.Second
	srv.Config.WriteTimeout = 30 * time.Second
	srv.Config.IdleTimeout = 120 * time.Second
	srv.Start()
	return srv
}
// generateRandomData produces a deterministic ASCII payload of exactly
// size bytes (empty for size <= 0, where the original panicked on negative
// sizes). Using printable ASCII — rather than the raw 0-255 byte pattern
// this previously produced — keeps the string valid UTF-8, so json.Marshal
// does not substitute multi-byte U+FFFD replacement characters and inflate
// the serialized payload well beyond Config.PayloadSize.
func generateRandomData(size int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if size <= 0 {
		return ""
	}
	data := make([]byte, size)
	for i := range data {
		data[i] = charset[i%len(charset)]
	}
	return string(data)
}