fetch_ml/tests/benchmarks/ml_experiment_benchmark_test.go
Jeremie Fraeys be67cb77d3
test(benchmarks): update benchmark tests with job cleanup and improvements
**Payload Performance Test:**
- Add job cleanup after each iteration using DeleteJob()
- Ensure isolated memory measurements between test runs

**All Benchmark Tests:**
- General improvements and maintenance updates
2026-02-23 18:03:54 -05:00

458 lines
12 KiB
Go

package benchmarks
import (
"context"
"fmt"
"path/filepath"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/storage"
fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
"github.com/redis/go-redis/v9"
)
// BenchmarkMLExperimentExecution measures ML experiment performance across
// several experiment shapes (job count vs. payload size), plus concurrent
// execution and metrics-recording scenarios.
func BenchmarkMLExperimentExecution(b *testing.B) {
	// Setup test environment.
	tempDir := b.TempDir()

	// setupBenchmarkRedis never returns nil: it calls b.Fatalf on any failure
	// and registers shutdown via b.Cleanup, so no nil check or extra Close is
	// needed here (a deferred Close would only cause a redundant double close).
	rdb := setupBenchmarkRedis(b)

	db, err := storage.NewDBFromPath(filepath.Join(tempDir, "test.db"))
	if err != nil {
		b.Fatalf("Failed to create database: %v", err)
	}
	defer func() { _ = db.Close() }()

	// Initialize database schema.
	if err := db.Initialize(getMLSchema()); err != nil {
		b.Fatalf("Failed to initialize database: %v", err)
	}

	// NOTE(review): each b.Run sub-benchmark manages its own timer, so these
	// calls affect only the parent benchmark.
	b.ResetTimer()
	b.ReportAllocs()

	// Benchmark different experiment types.
	b.Run("SmallExperiment", func(b *testing.B) {
		benchmarkMLExperiment(b, db, rdb, "small", 100, 1024) // 100 jobs, 1KB each
	})
	b.Run("MediumExperiment", func(b *testing.B) {
		benchmarkMLExperiment(b, db, rdb, "medium", 50, 10240) // 50 jobs, 10KB each
	})
	b.Run("LargeExperiment", func(b *testing.B) {
		benchmarkMLExperiment(b, db, rdb, "large", 10, 102400) // 10 jobs, 100KB each
	})
	b.Run("ConcurrentExperiments", func(b *testing.B) {
		benchmarkConcurrentExperiments(b, db, rdb)
	})
	b.Run("ExperimentMetrics", func(b *testing.B) {
		benchmarkExperimentMetrics(b, db, rdb)
	})
}
// benchmarkMLExperiment drives one full experiment lifecycle per timed
// iteration: create the experiment and its jobs, execute them, then record a
// handful of experiment-level metrics.
//
// Uses b.Loop() (like the other benchmarks in this file) instead of the older
// b.N idiom, and drops the previously dead start/executionTime measurement
// that was discarded via `_ =`.
func benchmarkMLExperiment(b *testing.B, db *storage.DB, rdb *redis.Client, expType string, numJobs, payloadSize int) {
	m := &metrics.Metrics{}
	for i := 0; b.Loop(); i++ {
		// UnixNano suffix keeps iterations from colliding in the shared DB.
		expID := fmt.Sprintf("exp-%s-%d-%d", expType, i, time.Now().UnixNano())

		// Create experiment.
		if err := createMLExperiment(db, rdb, expID, numJobs, payloadSize); err != nil {
			b.Fatalf("Failed to create experiment: %v", err)
		}
		m.RecordTaskStart()

		// Simulate experiment execution.
		if err := executeMLExperiment(db, rdb, expID, numJobs); err != nil {
			b.Fatalf("Failed to execute experiment: %v", err)
		}
		m.RecordTaskCompletion()
		m.RecordDataTransfer(int64(numJobs*payloadSize), 0)

		// Record experiment metrics.
		for j := range 5 {
			metricName := fmt.Sprintf("metric_%d_%d", j, i)
			if err := db.RecordJobMetric(expID, metricName, fmt.Sprintf("%.2f", float64(j)*1.5)); err != nil {
				b.Errorf("Failed to record metric %s: %v", metricName, err)
			}
		}
	}
}
// benchmarkConcurrentExperiments measures the cost of running several ML
// experiments in parallel. Each timed iteration launches one goroutine per
// experiment and blocks until every goroutine signals completion.
func benchmarkConcurrentExperiments(b *testing.B, db *storage.DB, rdb *redis.Client) {
	numExperiments := 5
	jobsPerExperiment := 20
	payloadSize := 5120 // 5KB

	b.ResetTimer()

	for iter := 0; b.Loop(); iter++ {
		finished := make(chan bool, numExperiments)

		// Fan out: one goroutine per experiment.
		for idx := range numExperiments {
			go func(id int) {
				defer func() { finished <- true }()

				name := fmt.Sprintf("concurrent-exp-%d-%d-%d", iter, id, time.Now().UnixNano())
				if err := createMLExperiment(db, rdb, name, jobsPerExperiment, payloadSize); err != nil {
					b.Errorf("Failed to create experiment %d: %v", id, err)
					return
				}
				if err := executeMLExperiment(db, rdb, name, jobsPerExperiment); err != nil {
					b.Errorf("Failed to execute experiment %d: %v", id, err)
				}
			}(idx)
		}

		// Fan in: drain one completion signal per experiment before the next
		// iteration starts.
		for range numExperiments {
			<-finished
		}
	}
}
// benchmarkExperimentMetrics measures the throughput of RecordJobMetric
// against a fixed pool of pre-created jobs.
func benchmarkExperimentMetrics(b *testing.B, db *storage.DB, _ *redis.Client) {
	metricsPerJob, numJobs := 10, 100

	// Untimed setup: seed the database with completed jobs to attach metrics to.
	jobIDs := make([]string, numJobs)
	for idx := range jobIDs {
		jobIDs[idx] = fmt.Sprintf("metrics-job-%d-%d", idx, time.Now().UnixNano())
		if err := db.CreateJob(&storage.Job{
			ID:       jobIDs[idx],
			JobName:  fmt.Sprintf("Metrics Job %d", idx),
			Status:   "completed",
			Priority: 0,
		}); err != nil {
			b.Fatalf("Failed to create job %d: %v", idx, err)
		}
	}

	b.ResetTimer()
	b.ReportAllocs()

	// Timed section: write metricsPerJob metrics for every job, per iteration.
	for iter := 0; b.Loop(); iter++ {
		for _, id := range jobIDs {
			for k := range metricsPerJob {
				name := fmt.Sprintf("metric_%d_%d", k, iter)
				value := fmt.Sprintf("%.6f", float64(iter*k)*0.001)
				if err := db.RecordJobMetric(id, name, value); err != nil {
					b.Errorf("Failed to record metric %s for job %s: %v", name, id, err)
				}
			}
		}
	}
}
// BenchmarkDatasetOperations exercises dataset-related performance: creation,
// retrieval, and update, each as its own sub-benchmark against a shared DB.
func BenchmarkDatasetOperations(b *testing.B) {
	tempDir := b.TempDir()

	db, err := storage.NewDBFromPath(filepath.Join(tempDir, "test.db"))
	if err != nil {
		b.Fatalf("Failed to create database: %v", err)
	}
	defer func() { _ = db.Close() }()

	if err := db.Initialize(getMLSchema()); err != nil {
		b.Fatalf("Failed to initialize database: %v", err)
	}

	b.ResetTimer()
	b.ReportAllocs()

	// Table-driven sub-benchmarks, run in declaration order.
	cases := []struct {
		name string
		fn   func(*testing.B, *storage.DB)
	}{
		{"DatasetCreation", benchmarkDatasetCreation},
		{"DatasetRetrieval", benchmarkDatasetRetrieval},
		{"DatasetUpdate", benchmarkDatasetUpdate},
	}
	for _, tc := range cases {
		b.Run(tc.name, func(b *testing.B) { tc.fn(b, db) })
	}
}
// benchmarkDatasetCreation measures creating a dataset per iteration: a job
// row (the foreign-key parent) followed by three metadata metrics.
func benchmarkDatasetCreation(b *testing.B, db *storage.DB) {
	for iter := 0; b.Loop(); iter++ {
		id := fmt.Sprintf("dataset-%d-%d", iter, time.Now().UnixNano())

		// A job row must exist first to satisfy the foreign-key constraint.
		if err := db.CreateJob(&storage.Job{
			ID:       id,
			JobName:  fmt.Sprintf("Dataset %d", iter),
			Status:   "completed",
			Priority: 0,
		}); err != nil {
			b.Errorf("Failed to create dataset job %d: %v", iter, err)
			continue
		}

		// Dataset "creation" is modeled as a trio of metadata metrics.
		if err := db.RecordJobMetric(id, "dataset_size", fmt.Sprintf("%d", 1024*(iter+1))); err != nil {
			b.Errorf("Failed to create dataset %d: %v", iter, err)
		}
		if err := db.RecordJobMetric(id, "dataset_type", "training"); err != nil {
			b.Errorf("Failed to set dataset type %d: %v", iter, err)
		}
		if err := db.RecordJobMetric(id, "created_at", time.Now().Format(time.RFC3339)); err != nil {
			b.Errorf("Failed to set dataset timestamp %d: %v", iter, err)
		}
	}
}
// benchmarkDatasetRetrieval measures (simulated) dataset metadata lookups.
// Datasets are created up front; the timed loop then cycles through the IDs
// that were actually stored.
func benchmarkDatasetRetrieval(b *testing.B, db *storage.DB) {
	// Pre-create datasets, keeping the generated IDs: the previous version
	// rebuilt IDs as "dataset-%d" inside the timed loop, which could never
	// match the UnixNano-suffixed "dataset-%d-%d" IDs written here.
	numDatasets := 100
	datasetIDs := make([]string, numDatasets)
	for i := range numDatasets {
		datasetID := fmt.Sprintf("dataset-%d-%d", i, time.Now().UnixNano())
		datasetIDs[i] = datasetID

		// Create a job first (foreign-key parent for the metric rows).
		job := &storage.Job{
			ID:       datasetID,
			JobName:  fmt.Sprintf("Dataset %d", i),
			Status:   "completed",
			Priority: 0,
		}
		_ = db.CreateJob(job)
		_ = db.RecordJobMetric(datasetID, "dataset_size", fmt.Sprintf("%d", 1024*(i+1)))
		_ = db.RecordJobMetric(datasetID, "dataset_type", "training")
	}

	for i := 0; b.Loop(); i++ {
		datasetID := datasetIDs[i%numDatasets]
		// Simulate dataset metadata retrieval. In a real implementation this
		// would query the database; for now we only model the lookup cost.
		_ = datasetID
	}
}
// benchmarkDatasetUpdate measures repeated metadata writes against a fixed
// pool of pre-created datasets.
func benchmarkDatasetUpdate(b *testing.B, db *storage.DB) {
	// Untimed setup: seed the pool of datasets to update.
	poolSize := 50
	ids := make([]string, poolSize)
	for idx := range poolSize {
		id := fmt.Sprintf("dataset-%d-%d", idx, time.Now().UnixNano())
		ids[idx] = id

		// Job row first (foreign-key parent), then an initial size metric.
		_ = db.CreateJob(&storage.Job{
			ID:       id,
			JobName:  fmt.Sprintf("Dataset %d", idx),
			Status:   "completed",
			Priority: 0,
		})
		_ = db.RecordJobMetric(id, "dataset_size", fmt.Sprintf("%d", 1024))
	}

	// Timed section: each iteration writes a size metric plus a last-modified
	// timestamp to one dataset, round-robin across the pool.
	for iter := 0; b.Loop(); iter++ {
		id := ids[iter%poolSize]
		if err := db.RecordJobMetric(id, fmt.Sprintf("dataset_size_%d", iter), fmt.Sprintf("%d", 2048)); err != nil {
			b.Errorf("Failed to update dataset %d: %v", iter, err)
		}
		if err := db.RecordJobMetric(id, fmt.Sprintf("last_modified_%d", iter), time.Now().Format(time.RFC3339)); err != nil {
			b.Errorf("Failed to update timestamp %d: %v", iter, err)
		}
	}
}
// Helper functions

// setupBenchmarkRedis starts an in-memory miniredis server and returns a
// client connected to it. On any failure it calls b.Fatalf, so the returned
// client is never nil. Both the client and the server are shut down via
// b.Cleanup, so callers do not need to close anything themselves.
func setupBenchmarkRedis(b *testing.B) *redis.Client {
	// Start in-memory Redis server.
	s, err := miniredis.Run()
	if err != nil {
		b.Fatalf("failed to start miniredis: %v", err)
	}

	rdb := redis.NewClient(&redis.Options{
		Addr: s.Addr(),
	})

	// Verify connectivity before handing the client out.
	ctx := context.Background()
	if err := rdb.Ping(ctx).Err(); err != nil {
		b.Fatalf("miniredis ping failed: %v", err)
	}

	b.Cleanup(func() {
		_ = rdb.Close() // best-effort shutdown; matches the file's error-discard style
		s.Close()
	})
	return rdb
}
// createMLExperiment persists an experiment job plus its child jobs, and
// enqueues each child on the Redis "ml:queue" list. Returns the first error
// encountered, wrapped with the failing step.
func createMLExperiment(db *storage.DB, rdb *redis.Client, expID string, numJobs, payloadSize int) error {
	ctx := context.Background()

	// Parent record describing the experiment itself.
	parent := &storage.Job{
		ID:       expID,
		JobName:  fmt.Sprintf("ML Experiment %s", expID),
		Status:   "running",
		Priority: 1,
		Args:     fmt.Sprintf(`{"experiment_id": "%s", "num_jobs": %d}`, expID, numJobs),
	}
	if err := db.CreateJob(parent); err != nil {
		return fmt.Errorf("failed to create experiment job: %w", err)
	}

	// Child jobs: one DB row plus one queue entry each.
	for idx := range numJobs {
		childID := fmt.Sprintf("%s-job-%d", expID, idx)
		child := &storage.Job{
			ID:       childID,
			JobName:  fmt.Sprintf("ML Job %s-%d", expID, idx),
			Status:   "pending",
			Priority: 1,
			Args:     generateMLPayload(payloadSize, idx),
		}
		if err := db.CreateJob(child); err != nil {
			return fmt.Errorf("failed to create job %d: %w", idx, err)
		}
		if err := rdb.LPush(ctx, "ml:queue", childID).Err(); err != nil {
			return fmt.Errorf("failed to queue job %d: %w", idx, err)
		}
	}
	return nil
}
// executeMLExperiment walks every job in the experiment through its
// lifecycle: running -> (simulated work) -> completed, records two metrics
// per job, pops the queue entry, and finally marks the experiment completed.
func executeMLExperiment(db *storage.DB, rdb *redis.Client, expID string, numJobs int) error {
	ctx := context.Background()

	for idx := range numJobs {
		jobID := fmt.Sprintf("%s-job-%d", expID, idx)
		worker := fmt.Sprintf("worker-%d", idx%5)

		// Mark the job running.
		if err := db.UpdateJobStatus(jobID, "running", worker, ""); err != nil {
			return fmt.Errorf("failed to start job %d: %w", idx, err)
		}

		// Stand-in for real ML computation.
		time.Sleep(time.Microsecond * time.Duration(10+idx%100))

		// Mark the job completed.
		if err := db.UpdateJobStatus(jobID, "completed", worker, ""); err != nil {
			return fmt.Errorf("failed to complete job %d: %w", idx, err)
		}

		// Record per-job metrics.
		if err := db.RecordJobMetric(jobID, "processing_time", fmt.Sprintf("%.3f", float64(10+idx%100)*0.001)); err != nil {
			return fmt.Errorf("failed to record processing time for job %d: %w", idx, err)
		}
		if err := db.RecordJobMetric(jobID, "memory_usage", fmt.Sprintf("%d", 1024*(idx+1))); err != nil {
			return fmt.Errorf("failed to record memory usage for job %d: %w", idx, err)
		}

		// Remove one queue entry to mirror a worker consuming the job.
		if _, err := rdb.LPop(ctx, "ml:queue").Result(); err != nil {
			return fmt.Errorf("failed to pop job %d: %w", idx, err)
		}
	}

	// The whole experiment is done once every child job has finished.
	if err := db.UpdateJobStatus(expID, "completed", "coordinator", ""); err != nil {
		return fmt.Errorf("failed to complete experiment: %w", err)
	}
	return nil
}
// generateMLPayload builds a deterministic, seed-dependent JSON document to
// use as a job argument blob. The embedded "data" field holds at most 100
// characters drawn from the requested size.
//
// The filler is restricted to lowercase ASCII letters so the result is
// always valid JSON: the previous raw-byte encoding ((i+seed)%256) could
// emit '"', '\' and control bytes, producing broken JSON despite the
// truncation. Assumes seed >= 0 (true for all callers in this file).
func generateMLPayload(size int, seed int) string {
	// Only the first 100 characters are embedded, so generate no more.
	n := min(size, 100)
	data := make([]byte, n)
	for i := range data {
		data[i] = byte('a' + (i+seed)%26) // JSON-safe, deterministic filler
	}
	return fmt.Sprintf(`{
	"model": "test-model",
	"data": "%s",
	"parameters": {
		"learning_rate": 0.001,
		"batch_size": 32,
		"epochs": 10
	},
	"seed": %d
}`, string(data), seed)
}
// getMLSchema returns the SQL schema used to initialize the benchmark
// database; it forwards the shared test fixture schema unchanged.
func getMLSchema() string {
return fixtures.TestSchema
}
// minInt returns the smaller of a and b.
//
// Kept for backward compatibility with existing callers; it now delegates to
// the built-in min (Go 1.21+), which this file's other features (b.Loop,
// range-over-int) already require.
func minInt(a, b int) int {
	return min(a, b)
}