- Add end-to-end tests for complete workflow validation
- Include integration tests for API and database interactions
- Add unit tests for all major components and utilities
- Include performance tests for payload handling
- Add CLI API integration tests
- Include Podman container integration tests
- Add WebSocket and queue execution tests
- Include shell script tests for setup validation

Provides comprehensive test coverage, ensuring platform reliability and functionality across all components and interactions.
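As a quick sanity check, the end-to-end test in this file can be run on its own with the standard Go toolchain. This is a sketch only: the ./tests/... package path is an assumption about the repository layout, and the test expects a Redis instance on localhost:6379 (it skips itself when Redis is unreachable).

    # Run only the end-to-end integration test; it is skipped when Redis is not available
    go test -v -run TestIntegrationE2E ./tests/...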
292 lines
8.2 KiB
Go
package tests

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"testing"
	"time"

	tests "github.com/jfraeys/fetch_ml/tests/fixtures"
)

// TestIntegrationE2E tests the complete end-to-end workflow
func TestIntegrationE2E(t *testing.T) {
	t.Parallel() // Enable parallel execution

	testDir := t.TempDir()
	ctx := context.Background()

	// Create test job directory structure
	jobBaseDir := filepath.Join(testDir, "jobs")
	pendingDir := filepath.Join(jobBaseDir, "pending")
	runningDir := filepath.Join(jobBaseDir, "running")
	finishedDir := filepath.Join(jobBaseDir, "finished")

	for _, dir := range []string{pendingDir, runningDir, finishedDir} {
		if err := os.MkdirAll(dir, 0755); err != nil {
			t.Fatalf("Failed to create directory %s: %v", dir, err)
		}
	}

	// Create standard ML experiment (zero-install style)
	jobDir := filepath.Join(pendingDir, "test_job")
	if err := os.MkdirAll(jobDir, 0755); err != nil {
		t.Fatalf("Failed to create job directory: %v", err)
	}

	// Create standard ML project files
	trainScript := filepath.Join(jobDir, "train.py")
	requirementsFile := filepath.Join(jobDir, "requirements.txt")
	readmeFile := filepath.Join(jobDir, "README.md")

	// Create train.py (standard ML script)
	trainCode := `#!/usr/bin/env python3
import argparse
import json
import logging
import time
from pathlib import Path

def main():
    parser = argparse.ArgumentParser(description="Train ML model")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--data_dir", type=str, help="Data directory")
    parser.add_argument("--datasets", type=str, help="Comma-separated datasets")

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Starting training: {args.epochs} epochs, lr={args.lr}, batch_size={args.batch_size}")

    if args.datasets:
        logger.info(f"Using datasets: {args.datasets}")

    # Simulate training
    for epoch in range(args.epochs):
        loss = 1.0 - (epoch * 0.08)
        accuracy = 0.4 + (epoch * 0.055)

        logger.info(f"Epoch {epoch + 1}/{args.epochs}: loss={loss:.4f}, accuracy={accuracy:.4f}")
        time.sleep(0.01) # Minimal delay for testing

    # Save results
    results = {
        "model_type": "test_model",
        "epochs_trained": args.epochs,
        "learning_rate": args.lr,
        "batch_size": args.batch_size,
        "final_accuracy": accuracy,
        "final_loss": loss,
        "datasets": args.datasets.split(",") if args.datasets else []
    }

    results_file = output_dir / "results.json"
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    logger.info(f"Training completed! Results saved to {results_file}")

if __name__ == "__main__":
    main()
`

	if err := os.WriteFile(trainScript, []byte(trainCode), 0755); err != nil {
		t.Fatalf("Failed to create train.py: %v", err)
	}

	// Create requirements.txt
	requirements := `torch>=1.9.0
numpy>=1.21.0
scikit-learn>=1.0.0
`

	if err := os.WriteFile(requirementsFile, []byte(requirements), 0644); err != nil {
		t.Fatalf("Failed to create requirements.txt: %v", err)
	}

	// Create README.md
	readme := `# Test Experiment

This is a test experiment for integration testing.

## Usage
python train.py --epochs 2 --lr 0.01 --output_dir ./results
`

	if err := os.WriteFile(readmeFile, []byte(readme), 0644); err != nil {
		t.Fatalf("Failed to create README.md: %v", err)
	}

	// Setup test Redis using fixtures
	redisHelper, err := tests.NewRedisHelper("localhost:6379", 15)
	if err != nil {
		t.Skipf("Redis not available, skipping integration test: %v", err)
	}
	defer redisHelper.Close()

	// Test Redis connection
	if err := redisHelper.GetClient().Ping(ctx).Err(); err != nil {
		t.Skipf("Redis not available, skipping integration test: %v", err)
	}

	// Create task queue
	taskQueue, err := tests.NewTaskQueue(&tests.Config{
		RedisAddr: "localhost:6379",
		RedisDB:   15,
	})
	if err != nil {
		t.Fatalf("Failed to create task queue: %v", err)
	}
	defer taskQueue.Close()

	// Create ML server (local mode)
	mlServer := tests.NewMLServer()

	// Test 1: Enqueue task (as would happen from TUI)
	task, err := taskQueue.EnqueueTask("test_job", "--epochs 2 --lr 0.01", 5)
	if err != nil {
		t.Fatalf("Failed to enqueue task: %v", err)
	}

	if task.ID == "" {
		t.Fatal("Task ID should not be empty")
	}

	if task.JobName != "test_job" {
		t.Errorf("Expected job name 'test_job', got '%s'", task.JobName)
	}

	if task.Status != "queued" {
		t.Errorf("Expected status 'queued', got '%s'", task.Status)
	}

	// Test 2: Get next task (as worker would)
	nextTask, err := taskQueue.GetNextTask()
	if err != nil {
		t.Fatalf("Failed to get next task: %v", err)
	}

	if nextTask == nil {
		t.Fatal("Should have retrieved a task")
	}

	if nextTask.ID != task.ID {
		t.Errorf("Expected task ID %s, got %s", task.ID, nextTask.ID)
	}

	// Test 3: Update task status to running
	now := time.Now()
	nextTask.Status = "running"
	nextTask.StartedAt = &now

	if err := taskQueue.UpdateTask(nextTask); err != nil {
		t.Fatalf("Failed to update task: %v", err)
	}

	// Test 4: Execute job (zero-install style)
	if err := executeZeroInstallJob(mlServer, nextTask, jobBaseDir, trainScript); err != nil {
		t.Fatalf("Failed to execute job: %v", err)
	}

	// Test 5: Update task status to completed
	endTime := time.Now()
	nextTask.Status = "completed"
	nextTask.EndedAt = &endTime

	if err := taskQueue.UpdateTask(nextTask); err != nil {
		t.Fatalf("Failed to update final task status: %v", err)
	}

	// Test 6: Verify results
	retrievedTask, err := taskQueue.GetTask(nextTask.ID)
	if err != nil {
		t.Fatalf("Failed to retrieve completed task: %v", err)
	}

	if retrievedTask.Status != "completed" {
		t.Errorf("Expected status 'completed', got '%s'", retrievedTask.Status)
	}

	if retrievedTask.StartedAt == nil {
		t.Error("StartedAt should not be nil")
	}

	if retrievedTask.EndedAt == nil {
		t.Error("EndedAt should not be nil")
	}

	// Test 7: Check job status
	jobStatus, err := taskQueue.GetJobStatus("test_job")
	if err != nil {
		t.Fatalf("Failed to get job status: %v", err)
	}

	if jobStatus["status"] != "completed" {
		t.Errorf("Expected job status 'completed', got '%s'", jobStatus["status"])
	}

	// Test 8: Record and check metrics
	if err := taskQueue.RecordMetric("test_job", "accuracy", 0.95); err != nil {
		t.Fatalf("Failed to record metric: %v", err)
	}

	metrics, err := taskQueue.GetMetrics("test_job")
	if err != nil {
		t.Fatalf("Failed to get metrics: %v", err)
	}

	if metrics["accuracy"] != "0.95" {
		t.Errorf("Expected accuracy '0.95', got '%s'", metrics["accuracy"])
	}

	t.Log("End-to-end test completed successfully")
}

// executeZeroInstallJob simulates zero-install job execution
func executeZeroInstallJob(server *tests.MLServer, task *tests.Task, baseDir, trainScript string) error {
	// Move job to running directory
	pendingPath := filepath.Join(baseDir, "pending", task.JobName)
	runningPath := filepath.Join(baseDir, "running", task.JobName)

	if err := os.Rename(pendingPath, runningPath); err != nil {
		return fmt.Errorf("failed to move job to running: %w", err)
	}

	// Execute the job (zero-install style - direct Python execution)
	outputDir := filepath.Join(runningPath, "results")
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}

	cmd := fmt.Sprintf("cd %s && python3 %s --output_dir %s %s",
		runningPath,
		filepath.Base(trainScript),
		outputDir,
		task.Args,
	)

	output, err := server.Exec(cmd)
	if err != nil {
		return fmt.Errorf("job execution failed: %w, output: %s", err, output)
	}

	// Move to finished directory
	finishedPath := filepath.Join(baseDir, "finished", task.JobName)
	if err := os.Rename(runningPath, finishedPath); err != nil {
		return fmt.Errorf("failed to move job to finished: %w", err)
	}

	return nil
}