package benchmarks import ( "context" "fmt" "path/filepath" "testing" "time" "github.com/alicebob/miniredis/v2" "github.com/jfraeys/fetch_ml/internal/metrics" "github.com/jfraeys/fetch_ml/internal/storage" fixtures "github.com/jfraeys/fetch_ml/tests/fixtures" "github.com/redis/go-redis/v9" ) // BenchmarkMLExperimentExecution measures ML experiment performance func BenchmarkMLExperimentExecution(b *testing.B) { // Setup test environment tempDir := b.TempDir() rdb := setupBenchmarkRedis(b) if rdb == nil { b.Skip("Redis not available") } defer func() { _ = rdb.Close() }() db, err := storage.NewDBFromPath(filepath.Join(tempDir, "test.db")) if err != nil { b.Fatalf("Failed to create database: %v", err) } defer func() { _ = db.Close() }() // Initialize database schema err = db.Initialize(getMLSchema()) if err != nil { b.Fatalf("Failed to initialize database: %v", err) } b.ResetTimer() b.ReportAllocs() // Benchmark different experiment types b.Run("SmallExperiment", func(b *testing.B) { benchmarkMLExperiment(b, db, rdb, "small", 100, 1024) // 100 jobs, 1KB each }) b.Run("MediumExperiment", func(b *testing.B) { benchmarkMLExperiment(b, db, rdb, "medium", 50, 10240) // 50 jobs, 10KB each }) b.Run("LargeExperiment", func(b *testing.B) { benchmarkMLExperiment(b, db, rdb, "large", 10, 102400) // 10 jobs, 100KB each }) b.Run("ConcurrentExperiments", func(b *testing.B) { benchmarkConcurrentExperiments(b, db, rdb) }) b.Run("ExperimentMetrics", func(b *testing.B) { benchmarkExperimentMetrics(b, db, rdb) }) } // benchmarkMLExperiment tests ML experiment execution performance func benchmarkMLExperiment(b *testing.B, db *storage.DB, rdb *redis.Client, expType string, numJobs, payloadSize int) { m := &metrics.Metrics{} for i := 0; i < b.N; i++ { expID := fmt.Sprintf("exp-%s-%d-%d", expType, i, time.Now().UnixNano()) // Create experiment start := time.Now() err := createMLExperiment(db, rdb, expID, numJobs, payloadSize) if err != nil { b.Fatalf("Failed to create experiment: %v", err) } m.RecordTaskStart() // Simulate experiment execution err = executeMLExperiment(db, rdb, expID, numJobs) if err != nil { b.Fatalf("Failed to execute experiment: %v", err) } m.RecordTaskCompletion() executionTime := time.Since(start) m.RecordDataTransfer(int64(numJobs*payloadSize), 0) _ = executionTime // Use executionTime to avoid unused variable warning // Record experiment metrics for j := range 5 { metricName := fmt.Sprintf("metric_%d_%d", j, i) err := db.RecordJobMetric(expID, metricName, fmt.Sprintf("%.2f", float64(j)*1.5)) if err != nil { b.Errorf("Failed to record metric %s: %v", metricName, err) } } } } // benchmarkConcurrentExperiments tests concurrent experiment execution func benchmarkConcurrentExperiments(b *testing.B, db *storage.DB, rdb *redis.Client) { numExperiments := 5 jobsPerExperiment := 20 payloadSize := 5120 // 5KB b.ResetTimer() // Create experiments concurrently for i := 0; b.Loop(); i++ { done := make(chan bool, numExperiments) for exp := range numExperiments { go func(expID int) { defer func() { done <- true }() expName := fmt.Sprintf("concurrent-exp-%d-%d-%d", i, expID, time.Now().UnixNano()) err := createMLExperiment(db, rdb, expName, jobsPerExperiment, payloadSize) if err != nil { b.Errorf("Failed to create experiment %d: %v", expID, err) return } err = executeMLExperiment(db, rdb, expName, jobsPerExperiment) if err != nil { b.Errorf("Failed to execute experiment %d: %v", expID, err) } }(exp) } // Wait for all experiments to complete for j := 0; j < numExperiments; j++ { <-done } } } // benchmarkExperimentMetrics tests metrics recording performance func benchmarkExperimentMetrics(b *testing.B, db *storage.DB, _ *redis.Client) { metricsPerJob := 10 numJobs := 100 // Create test jobs jobIDs := make([]string, numJobs) for i := range jobIDs { jobIDs[i] = fmt.Sprintf("metrics-job-%d-%d", i, time.Now().UnixNano()) job := &storage.Job{ ID: jobIDs[i], JobName: fmt.Sprintf("Metrics Job %d", i), Status: "completed", Priority: 0, } err := db.CreateJob(job) if err != nil { b.Fatalf("Failed to create job %d: %v", i, err) } } b.ResetTimer() b.ReportAllocs() // Record metrics for all jobs for i := 0; b.Loop(); i++ { for _, jobID := range jobIDs { for j := range metricsPerJob { metricName := fmt.Sprintf("metric_%d_%d", j, i) metricValue := fmt.Sprintf("%.6f", float64(i*j)*0.001) err := db.RecordJobMetric(jobID, metricName, metricValue) if err != nil { b.Errorf("Failed to record metric %s for job %s: %v", metricName, jobID, err) } } } } } // BenchmarkDatasetOperations tests dataset-related performance func BenchmarkDatasetOperations(b *testing.B) { tempDir := b.TempDir() db, err := storage.NewDBFromPath(filepath.Join(tempDir, "test.db")) if err != nil { b.Fatalf("Failed to create database: %v", err) } defer func() { _ = db.Close() }() err = db.Initialize(getMLSchema()) if err != nil { b.Fatalf("Failed to initialize database: %v", err) } b.ResetTimer() b.ReportAllocs() b.Run("DatasetCreation", func(b *testing.B) { benchmarkDatasetCreation(b, db) }) b.Run("DatasetRetrieval", func(b *testing.B) { benchmarkDatasetRetrieval(b, db) }) b.Run("DatasetUpdate", func(b *testing.B) { benchmarkDatasetUpdate(b, db) }) } func benchmarkDatasetCreation(b *testing.B, db *storage.DB) { for i := 0; b.Loop(); i++ { datasetID := fmt.Sprintf("dataset-%d-%d", i, time.Now().UnixNano()) // Create a job first for foreign key constraint job := &storage.Job{ ID: datasetID, JobName: fmt.Sprintf("Dataset %d", i), Status: "completed", Priority: 0, } err := db.CreateJob(job) if err != nil { b.Errorf("Failed to create dataset job %d: %v", i, err) continue } // Simulate dataset creation with metadata err = db.RecordJobMetric(datasetID, "dataset_size", fmt.Sprintf("%d", 1024*(i+1))) if err != nil { b.Errorf("Failed to create dataset %d: %v", i, err) } err = db.RecordJobMetric(datasetID, "dataset_type", "training") if err != nil { b.Errorf("Failed to set dataset type %d: %v", i, err) } err = db.RecordJobMetric(datasetID, "created_at", time.Now().Format(time.RFC3339)) if err != nil { b.Errorf("Failed to set dataset timestamp %d: %v", i, err) } } } func benchmarkDatasetRetrieval(b *testing.B, db *storage.DB) { // Pre-create datasets numDatasets := 100 for i := range numDatasets { datasetID := fmt.Sprintf("dataset-%d-%d", i, time.Now().UnixNano()) // Create a job first job := &storage.Job{ ID: datasetID, JobName: fmt.Sprintf("Dataset %d", i), Status: "completed", Priority: 0, } _ = db.CreateJob(job) _ = db.RecordJobMetric(datasetID, "dataset_size", fmt.Sprintf("%d", 1024*(i+1))) _ = db.RecordJobMetric(datasetID, "dataset_type", "training") } for i := 0; b.Loop(); i++ { datasetID := fmt.Sprintf("dataset-%d", i%numDatasets) // Simulate dataset metadata retrieval // In a real implementation, this would query the database // For benchmarking, we'll simulate the lookup cost _ = datasetID } } func benchmarkDatasetUpdate(b *testing.B, db *storage.DB) { // Pre-create datasets numDatasets := 50 datasetIDs := make([]string, numDatasets) for i := range numDatasets { datasetID := fmt.Sprintf("dataset-%d-%d", i, time.Now().UnixNano()) datasetIDs[i] = datasetID // Create a job first job := &storage.Job{ ID: datasetID, JobName: fmt.Sprintf("Dataset %d", i), Status: "completed", Priority: 0, } _ = db.CreateJob(job) _ = db.RecordJobMetric(datasetID, "dataset_size", fmt.Sprintf("%d", 1024)) } for i := 0; b.Loop(); i++ { datasetID := datasetIDs[i%numDatasets] // Update dataset metadata err := db.RecordJobMetric(datasetID, fmt.Sprintf("dataset_size_%d", i), fmt.Sprintf("%d", 2048)) if err != nil { b.Errorf("Failed to update dataset %d: %v", i, err) } err = db.RecordJobMetric(datasetID, fmt.Sprintf("last_modified_%d", i), time.Now().Format(time.RFC3339)) if err != nil { b.Errorf("Failed to update timestamp %d: %v", i, err) } } } // Helper functions func setupBenchmarkRedis(b *testing.B) *redis.Client { // Start in-memory Redis server s, err := miniredis.Run() if err != nil { b.Fatalf("failed to start miniredis: %v", err) } rdb := redis.NewClient(&redis.Options{ Addr: s.Addr(), }) ctx := context.Background() if err := rdb.Ping(ctx).Err(); err != nil { b.Fatalf("miniredis ping failed: %v", err) } b.Cleanup(func() { rdb.Close() s.Close() }) return rdb } func createMLExperiment(db *storage.DB, rdb *redis.Client, expID string, numJobs, payloadSize int) error { ctx := context.Background() // Create experiment metadata expJob := &storage.Job{ ID: expID, JobName: fmt.Sprintf("ML Experiment %s", expID), Status: "running", Priority: 1, Args: fmt.Sprintf(`{"experiment_id": "%s", "num_jobs": %d}`, expID, numJobs), } err := db.CreateJob(expJob) if err != nil { return fmt.Errorf("failed to create experiment job: %w", err) } // Create individual jobs for the experiment for i := range numJobs { jobID := fmt.Sprintf("%s-job-%d", expID, i) payload := generateMLPayload(payloadSize, i) job := &storage.Job{ ID: jobID, JobName: fmt.Sprintf("ML Job %s-%d", expID, i), Status: "pending", Priority: 1, Args: payload, } err = db.CreateJob(job) if err != nil { return fmt.Errorf("failed to create job %d: %w", i, err) } // Queue job in Redis err = rdb.LPush(ctx, "ml:queue", jobID).Err() if err != nil { return fmt.Errorf("failed to queue job %d: %w", i, err) } } return nil } func executeMLExperiment(db *storage.DB, rdb *redis.Client, expID string, numJobs int) error { ctx := context.Background() // Process all jobs in the experiment for i := range numJobs { jobID := fmt.Sprintf("%s-job-%d", expID, i) // Update job status to running err := db.UpdateJobStatus(jobID, "running", fmt.Sprintf("worker-%d", i%5), "") if err != nil { return fmt.Errorf("failed to start job %d: %w", i, err) } // Simulate processing time (in real scenario, this would be ML computation) time.Sleep(time.Microsecond * time.Duration(10+i%100)) // Update job status to completed err = db.UpdateJobStatus(jobID, "completed", fmt.Sprintf("worker-%d", i%5), "") if err != nil { return fmt.Errorf("failed to complete job %d: %w", i, err) } // Record job metrics err = db.RecordJobMetric(jobID, "processing_time", fmt.Sprintf("%.3f", float64(10+i%100)*0.001)) if err != nil { return fmt.Errorf("failed to record processing time for job %d: %w", i, err) } err = db.RecordJobMetric(jobID, "memory_usage", fmt.Sprintf("%d", 1024*(i+1))) if err != nil { return fmt.Errorf("failed to record memory usage for job %d: %w", i, err) } // Pop from queue _, err = rdb.LPop(ctx, "ml:queue").Result() if err != nil { return fmt.Errorf("failed to pop job %d: %w", i, err) } } // Update experiment status err := db.UpdateJobStatus(expID, "completed", "coordinator", "") if err != nil { return fmt.Errorf("failed to complete experiment: %w", err) } return nil } func generateMLPayload(size int, seed int) string { data := make([]byte, size) for i := range data { data[i] = byte((i + seed) % 256) } return fmt.Sprintf(`{ "model": "test-model", "data": "%s", "parameters": { "learning_rate": 0.001, "batch_size": 32, "epochs": 10 }, "seed": %d }`, string(data[:minInt(len(data), 100)]), seed) // Truncate data for JSON safety } func getMLSchema() string { return fixtures.TestSchema } func minInt(a, b int) int { if a < b { return a } return b }