401 lines
11 KiB
Go
401 lines
11 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
|
|
)
|
|
|
|
const (
|
|
redisAddr = "localhost:6379"
|
|
redisDB = 0
|
|
)
|
|
|
|
// TestWorkerLocalMode tests worker functionality with zero-install workflow
|
|
func TestWorkerLocalMode(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Setup test environment
|
|
testDir := t.TempDir()
|
|
|
|
// Create job directory structure
|
|
jobBaseDir := filepath.Join(testDir, "jobs")
|
|
pendingDir := filepath.Join(jobBaseDir, "pending")
|
|
runningDir := filepath.Join(jobBaseDir, "running")
|
|
finishedDir := filepath.Join(jobBaseDir, "finished")
|
|
|
|
for _, dir := range []string{pendingDir, runningDir, finishedDir} {
|
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
|
t.Fatalf("Failed to create directory %s: %v", dir, err)
|
|
}
|
|
}
|
|
|
|
// Create standard ML experiment
|
|
jobDir := filepath.Join(pendingDir, "worker_test")
|
|
if err := os.MkdirAll(jobDir, 0750); err != nil {
|
|
t.Fatalf("Failed to create job directory: %v", err)
|
|
}
|
|
|
|
// Create standard ML project files
|
|
trainScript := filepath.Join(jobDir, "train.py")
|
|
requirementsFile := filepath.Join(jobDir, "requirements.txt")
|
|
|
|
// Create train.py (zero-install style)
|
|
trainCode := `#!/usr/bin/env python3
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Train ML model")
|
|
parser.add_argument("--epochs", type=int, default=2, help="Number of epochs")
|
|
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
|
|
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
|
parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Setup logging
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Create output directory
|
|
output_dir = Path(args.output_dir)
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
logger.info(f"Training: {args.epochs} epochs, lr={args.lr}")
|
|
|
|
# Simulate training
|
|
for epoch in range(args.epochs):
|
|
loss = 1.0 - (epoch * 0.3)
|
|
accuracy = 0.3 + (epoch * 0.3)
|
|
logger.info(f"Epoch {epoch + 1}: loss={loss:.2f}, acc={accuracy:.2f}")
|
|
time.sleep(0.01)
|
|
|
|
# Save results
|
|
results = {
|
|
"final_accuracy": accuracy,
|
|
"final_loss": loss,
|
|
"epochs": args.epochs
|
|
}
|
|
|
|
with open(output_dir / "results.json", "w") as f:
|
|
json.dump(results, f)
|
|
|
|
logger.info("Training complete!")
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
`
|
|
|
|
//nolint:gosec // G306: Script needs execute permissions
|
|
if err := os.WriteFile(trainScript, []byte(trainCode), 0750); err != nil {
|
|
t.Fatalf("Failed to create train.py: %v", err)
|
|
}
|
|
|
|
// Create requirements.txt
|
|
requirements := `torch>=1.9.0
|
|
numpy>=1.21.0
|
|
`
|
|
|
|
if err := os.WriteFile(requirementsFile, []byte(requirements), 0600); err != nil {
|
|
t.Fatalf("Failed to create requirements.txt: %v", err)
|
|
}
|
|
|
|
// Test local execution
|
|
server := tests.NewMLServer()
|
|
|
|
t.Run("LocalCommandExecution", func(t *testing.T) {
|
|
// Test basic command execution
|
|
output, err := server.Exec("echo 'test'")
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute command: %v", err)
|
|
}
|
|
|
|
expected := "test\n"
|
|
if output != expected {
|
|
t.Errorf("Expected output '%s', got '%s'", expected, output)
|
|
}
|
|
})
|
|
|
|
t.Run("JobDirectoryOperations", func(t *testing.T) {
|
|
// Test directory listing
|
|
output, err := server.Exec(fmt.Sprintf("ls -1 %s", pendingDir))
|
|
if err != nil {
|
|
t.Fatalf("Failed to list directory: %v", err)
|
|
}
|
|
|
|
if output != "worker_test\n" {
|
|
t.Errorf("Expected 'worker_test', got '%s'", output)
|
|
}
|
|
|
|
// Test file operations
|
|
output, err = server.Exec(fmt.Sprintf("test -f %s && echo 'exists'", trainScript))
|
|
if err != nil {
|
|
t.Fatalf("Failed to test file existence: %v", err)
|
|
}
|
|
|
|
if output != "exists\n" {
|
|
t.Errorf("Expected 'exists', got '%s'", output)
|
|
}
|
|
})
|
|
|
|
t.Run("ZeroInstallJobExecution", func(t *testing.T) {
|
|
// Create output directory
|
|
outputDir := filepath.Join(jobDir, "results")
|
|
if err := os.MkdirAll(outputDir, 0750); err != nil {
|
|
t.Fatalf("Failed to create output directory: %v", err)
|
|
}
|
|
|
|
// Execute job (zero-install style - direct Python execution)
|
|
cmd := fmt.Sprintf("cd %s && python3 train.py --epochs 1 --output_dir %s",
|
|
jobDir, outputDir)
|
|
|
|
output, err := server.Exec(cmd)
|
|
if err != nil {
|
|
t.Logf("Command execution failed (expected in test environment): %v", err)
|
|
t.Logf("Output: %s", output)
|
|
}
|
|
|
|
// Create expected results manually for testing
|
|
resultsFile := filepath.Join(outputDir, "results.json")
|
|
resultsJSON := `{
|
|
"final_accuracy": 0.6,
|
|
"final_loss": 0.7,
|
|
"epochs": 1
|
|
}`
|
|
|
|
if err := os.WriteFile(resultsFile, []byte(resultsJSON), 0600); err != nil {
|
|
t.Fatalf("Failed to create results file: %v", err)
|
|
}
|
|
|
|
// Check if results file was created
|
|
if _, err := os.Stat(resultsFile); os.IsNotExist(err) {
|
|
t.Errorf("Results file should exist: %s", resultsFile)
|
|
}
|
|
|
|
// Verify job completed successfully
|
|
if output == "" {
|
|
t.Log("No output from job execution (expected in test environment)")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestWorkerConfiguration tests worker configuration loading
|
|
func TestWorkerConfiguration(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
t.Run("DefaultConfig", func(t *testing.T) {
|
|
cfg := &tests.Config{}
|
|
|
|
// Test defaults
|
|
if cfg.RedisAddr == "" {
|
|
cfg.RedisAddr = redisAddr
|
|
}
|
|
if cfg.RedisDB == 0 {
|
|
cfg.RedisDB = redisDB
|
|
}
|
|
|
|
if cfg.RedisAddr != redisAddr {
|
|
t.Errorf("Expected default Redis address '%s', got '%s'", redisAddr, cfg.RedisAddr)
|
|
}
|
|
|
|
if cfg.RedisDB != redisDB {
|
|
t.Errorf("Expected default Redis DB %d, got %d", redisDB, cfg.RedisDB)
|
|
}
|
|
})
|
|
|
|
t.Run("ConfigFileLoading", func(t *testing.T) {
|
|
// Create temporary config file
|
|
configContent := `
|
|
redis_addr: "localhost:6379"
|
|
redis_password: ""
|
|
redis_db: 1
|
|
`
|
|
|
|
configFile := filepath.Join(t.TempDir(), "test_config.yaml")
|
|
if err := os.WriteFile(configFile, []byte(configContent), 0600); err != nil {
|
|
t.Fatalf("Failed to create config file: %v", err)
|
|
}
|
|
|
|
// Load config
|
|
cfg, err := tests.LoadConfig(configFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to load config: %v", err)
|
|
}
|
|
|
|
if cfg.RedisAddr != redisAddr {
|
|
t.Errorf("Expected Redis address '%s', got '%s'", redisAddr, cfg.RedisAddr)
|
|
}
|
|
|
|
if cfg.RedisDB != 1 {
|
|
t.Errorf("Expected Redis DB 1, got %d", cfg.RedisDB)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestWorkerTaskProcessing tests worker task processing capabilities
|
|
func TestWorkerTaskProcessing(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
ctx := context.Background()
|
|
redisDBTaskProcessing := 16
|
|
|
|
// Setup test Redis using fixtures
|
|
redisHelper, err := tests.NewRedisHelper(redisAddr, redisDBTaskProcessing)
|
|
if err != nil {
|
|
t.Skipf("Redis not available, skipping test: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = redisHelper.FlushDB()
|
|
_ = redisHelper.Close()
|
|
}()
|
|
|
|
if pingErr := redisHelper.GetClient().Ping(ctx).Err(); pingErr != nil {
|
|
t.Skipf("Redis not available, skipping test: %v", pingErr)
|
|
}
|
|
|
|
// Create task queue
|
|
taskQueue, err := tests.NewTaskQueue(&tests.Config{
|
|
RedisAddr: redisAddr,
|
|
RedisDB: redisDBTaskProcessing,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create task queue: %v", err)
|
|
}
|
|
defer func() { _ = taskQueue.Close() }()
|
|
|
|
t.Run("TaskLifecycle", func(t *testing.T) {
|
|
// Create a task
|
|
task, err := taskQueue.EnqueueTask("lifecycle_test", "--epochs 2", 5)
|
|
if err != nil {
|
|
t.Fatalf("Failed to enqueue task: %v", err)
|
|
}
|
|
|
|
// Verify initial state
|
|
if task.Status != "queued" {
|
|
t.Errorf("Expected status 'queued', got '%s'", task.Status)
|
|
}
|
|
|
|
// Get task from queue
|
|
nextTask, nextErr := taskQueue.GetNextTask()
|
|
if nextErr != nil {
|
|
t.Fatalf("Failed to get next task: %v", nextErr)
|
|
}
|
|
|
|
if nextTask.ID != task.ID {
|
|
t.Errorf("Expected task ID %s, got %s", task.ID, nextTask.ID)
|
|
}
|
|
|
|
// Update to running
|
|
now := time.Now()
|
|
nextTask.Status = "running"
|
|
nextTask.StartedAt = &now
|
|
nextTask.WorkerID = "test-worker"
|
|
|
|
if updateErr := taskQueue.UpdateTask(nextTask); updateErr != nil {
|
|
t.Fatalf("Failed to update task: %v", updateErr)
|
|
}
|
|
|
|
// Verify running state
|
|
retrievedTask, err := taskQueue.GetTask(nextTask.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve task: %v", err)
|
|
}
|
|
|
|
if retrievedTask.Status != "running" {
|
|
t.Errorf("Expected status 'running', got '%s'", retrievedTask.Status)
|
|
}
|
|
|
|
if retrievedTask.WorkerID != "test-worker" {
|
|
t.Errorf("Expected worker ID 'test-worker', got '%s'", retrievedTask.WorkerID)
|
|
}
|
|
|
|
// Update to completed
|
|
endTime := time.Now()
|
|
retrievedTask.Status = statusCompleted
|
|
retrievedTask.EndedAt = &endTime
|
|
|
|
if updateErr := taskQueue.UpdateTask(retrievedTask); updateErr != nil {
|
|
t.Fatalf("Failed to update task to completed: %v", updateErr)
|
|
}
|
|
|
|
// Verify completed state
|
|
finalTask, err := taskQueue.GetTask(retrievedTask.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve final task: %v", err)
|
|
}
|
|
|
|
if finalTask.Status != "completed" {
|
|
t.Errorf("Expected status 'completed', got '%s'", finalTask.Status)
|
|
}
|
|
|
|
if finalTask.StartedAt == nil {
|
|
t.Error("StartedAt should not be nil")
|
|
}
|
|
|
|
if finalTask.EndedAt == nil {
|
|
t.Error("EndedAt should not be nil")
|
|
}
|
|
|
|
// Verify task duration
|
|
duration := finalTask.EndedAt.Sub(*finalTask.StartedAt)
|
|
if duration < 0 {
|
|
t.Error("Duration should be positive")
|
|
}
|
|
})
|
|
|
|
t.Run("TaskMetrics", func(t *testing.T) {
|
|
// Create a task for metrics testing
|
|
_, enqueueErr := taskQueue.EnqueueTask("metrics_test", "--lr 0.01", 5)
|
|
if enqueueErr != nil {
|
|
t.Fatalf("Failed to enqueue task: %v", enqueueErr)
|
|
}
|
|
|
|
// Process the task
|
|
nextTask, nextErr := taskQueue.GetNextTask()
|
|
if nextErr != nil {
|
|
t.Fatalf("Failed to get next task: %v", nextErr)
|
|
}
|
|
|
|
// Simulate task completion with metrics
|
|
now := time.Now()
|
|
nextTask.Status = "completed"
|
|
nextTask.StartedAt = &now
|
|
endTime := now.Add(5 * time.Second) // Simulate 5 second execution
|
|
nextTask.EndedAt = &endTime
|
|
|
|
if updateErr := taskQueue.UpdateTask(nextTask); updateErr != nil {
|
|
t.Fatalf("Failed to update task: %v", updateErr)
|
|
}
|
|
|
|
// Record metrics
|
|
duration := nextTask.EndedAt.Sub(*nextTask.StartedAt).Seconds()
|
|
if metricErr := taskQueue.RecordMetric(nextTask.JobName, "execution_time", duration); metricErr != nil {
|
|
t.Fatalf("Failed to record execution time: %v", metricErr)
|
|
}
|
|
|
|
if metricErr := taskQueue.RecordMetric(nextTask.JobName, "accuracy", 0.95); metricErr != nil {
|
|
t.Fatalf("Failed to record accuracy: %v", metricErr)
|
|
}
|
|
|
|
// Verify metrics
|
|
metrics, metricsErr := taskQueue.GetMetrics(nextTask.JobName)
|
|
if metricsErr != nil {
|
|
t.Fatalf("Failed to get metrics: %v", metricsErr)
|
|
}
|
|
|
|
if metrics["execution_time"] != "5" {
|
|
t.Errorf("Expected execution time '5', got '%s'", metrics["execution_time"])
|
|
}
|
|
|
|
if metrics["accuracy"] != "0.95" {
|
|
t.Errorf("Expected accuracy '0.95', got '%s'", metrics["accuracy"])
|
|
}
|
|
})
|
|
}
|