fetch_ml/tests/fixtures/test_utils.go
Jeremie Fraeys 5f53104fcd
test: modernize test suite for streamlined infrastructure
- Update E2E tests for consolidated docker-compose.test.yml
- Remove references to obsolete logs-debug.yml
- Enhance test fixtures and utilities
- Improve integration test coverage for KMS, queue, scheduler
- Update unit tests for config constants and worker execution
- Modernize cleanup-status.sh with new Makefile targets
2026-03-04 13:24:24 -05:00

565 lines
14 KiB
Go

// Package tests provides test utilities and fixtures.
package tests
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
"github.com/google/uuid"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/redis/go-redis/v9"
"gopkg.in/yaml.v3"
)
// TestSchema is the SQL DDL shared by tests that need a database.
//
// It declares four tables, each with CREATE TABLE IF NOT EXISTS so the
// statement batch can be executed more than once:
//   - jobs: one row per job, keyed by id, with status defaulting to 'pending'.
//   - workers: one row per worker, keyed by id, with status defaulting to 'active'.
//   - job_metrics: one value per (job_id, metric_name); rows reference jobs(id)
//     with ON DELETE CASCADE.
//   - system_metrics: values keyed by (metric_name, timestamp).
//
// NOTE(review): the dialect looks like SQLite; whether the foreign key is
// actually enforced depends on how the test database is opened - confirm
// against the callers.
const TestSchema = `
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
job_name TEXT NOT NULL,
args TEXT,
status TEXT NOT NULL DEFAULT 'pending',
priority INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
ended_at DATETIME,
worker_id TEXT,
error TEXT,
datasets TEXT,
metadata TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS workers (
id TEXT PRIMARY KEY,
hostname TEXT NOT NULL,
last_heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP,
status TEXT NOT NULL DEFAULT 'active',
current_jobs INTEGER DEFAULT 0,
max_jobs INTEGER DEFAULT 1,
metadata TEXT
);
CREATE TABLE IF NOT EXISTS job_metrics (
job_id TEXT NOT NULL,
metric_name TEXT NOT NULL,
metric_value TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (job_id, metric_name),
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS system_metrics (
metric_name TEXT,
metric_value TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (metric_name, timestamp)
);
`
// Config holds the Redis connection settings used by the test helpers.
// It is decoded from YAML by LoadConfig and consumed by NewTaskQueue.
type Config struct {
// RedisAddr is the host:port of the Redis server; LoadConfig substitutes
// "localhost:6379" when the YAML leaves it empty.
RedisAddr string `yaml:"redis_addr"`
// RedisPassword is passed to the Redis client as-is (empty means none set).
RedisPassword string `yaml:"redis_password"`
// RedisDB is the numeric Redis database index handed to the client.
RedisDB int `yaml:"redis_db"`
}
// Task is the JSON-serialisable record that TaskQueue stores in Redis for
// each queued job.
type Task struct {
// CreatedAt is stamped with the current time by EnqueueTask.
CreatedAt time.Time `json:"created_at"`
// StartedAt is optional; nothing in this file sets it.
StartedAt *time.Time `json:"started_at,omitempty"`
// EndedAt is optional; CancelTask sets it to the cancellation time.
EndedAt *time.Time `json:"ended_at,omitempty"`
// ID is the task's unique identifier (a UUID string assigned by EnqueueTask).
ID string `json:"id"`
// JobName also names the per-job status hash in Redis.
JobName string `json:"job_name"`
// Args is stored verbatim; its format is defined by the caller.
Args string `json:"args"`
// Status is "queued" after EnqueueTask and "cancelled" after CancelTask;
// other values are whatever callers pass through UpdateTask.
Status string `json:"status"`
// WorkerID is optional and not populated by this file.
WorkerID string `json:"worker_id,omitempty"`
// Error is optional and not populated by this file.
Error string `json:"error,omitempty"`
// Priority is used as the sorted-set score; GetNextTask pops the highest.
Priority int64 `json:"priority"`
}
// TaskQueue is a small Redis-backed priority queue used by tests.
// Construct it with NewTaskQueue and release it with Close.
type TaskQueue struct {
// client is the Redis connection every method issues commands on.
client *redis.Client
// ctx is passed to every Redis call; NewTaskQueue sets it to
// context.Background().
ctx context.Context
}
// Redis key names and prefixes used by TaskQueue.
const (
// taskQueueKey is the sorted set of pending task IDs, scored by priority.
taskQueueKey = "ml:queue"
// taskPrefix + task ID is the string key holding the JSON-encoded Task.
taskPrefix = "ml:task:"
// taskStatusPrefix + job name is the hash holding that job's status fields.
taskStatusPrefix = "ml:status:"
// jobMetricsPrefix + job name is the hash holding that job's metrics.
jobMetricsPrefix = "ml:metrics:"
)
// NewTaskQueue connects to the Redis server described by cfg and returns a
// TaskQueue backed by that connection.
//
// The connection is verified with a PING before returning. An error is
// returned when cfg is nil or the ping fails; in the latter case the client is
// closed first so that a failed construction does not leak its connection
// pool. On success the caller owns the queue and must call Close.
func NewTaskQueue(cfg *Config) (*TaskQueue, error) {
	if cfg == nil {
		return nil, fmt.Errorf("redis connection failed: config is nil")
	}

	rdb := redis.NewClient(&redis.Options{
		Addr:     cfg.RedisAddr,
		Password: cfg.RedisPassword,
		DB:       cfg.RedisDB,
	})

	ctx := context.Background()
	if err := rdb.Ping(ctx).Err(); err != nil {
		// Best effort: the ping failure is the error worth reporting.
		_ = rdb.Close()
		return nil, fmt.Errorf("redis connection failed: %w", err)
	}

	return &TaskQueue{client: rdb, ctx: ctx}, nil
}
// EnsureRedis ensures a Redis instance is reachable on localhost:6379.
//
// If a server already answers PING, nothing is started and the returned
// cleanup is a no-op. Otherwise a daemonized redis-server is launched on port
// 6379 and the returned cleanup shuts it down via redis-cli (errors ignored).
//
// The test is skipped when redis-server is not in PATH, and failed when the
// server cannot be launched or does not answer PING within the startup
// timeout. Callers should defer the returned function.
func EnsureRedis(t *testing.T) (cleanup func()) {
	t.Helper()

	const (
		redisAddr    = "localhost:6379"
		startTimeout = 5 * time.Second
		pollInterval = 100 * time.Millisecond
	)

	// The client is only used to probe the server, so release it on every
	// exit path (including t.Skip / t.Fatalf, which still run defers).
	rdb := redis.NewClient(&redis.Options{Addr: redisAddr})
	defer func() { _ = rdb.Close() }()

	probeCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := rdb.Ping(probeCtx).Err(); err == nil {
		// Redis is already running; there is nothing for us to tear down.
		return func() {}
	}

	t.Logf("Starting temporary Redis on %s", redisAddr)

	if _, err := exec.LookPath("redis-server"); err != nil {
		t.Skip("Skipping: redis-server not available in PATH")
	}

	start := exec.CommandContext(
		context.Background(),
		"redis-server",
		"--daemonize",
		"yes",
		"--port",
		"6379",
	)
	if out, err := start.CombinedOutput(); err != nil {
		t.Fatalf("Failed to start temporary Redis: %v; output: %s", err, string(out))
	}

	shutdown := func() {
		stop := exec.CommandContext(context.Background(), "redis-cli", "-p", "6379", "shutdown")
		_ = stop.Run() // best effort: the server may already be gone
	}

	// Poll instead of sleeping a fixed interval: returns as soon as the server
	// is ready and tolerates slow machines up to startTimeout.
	deadline := time.Now().Add(startTimeout)
	for {
		pingCtx, cancelPing := context.WithTimeout(context.Background(), time.Second)
		err := rdb.Ping(pingCtx).Err()
		cancelPing()
		if err == nil {
			break
		}
		if time.Now().After(deadline) {
			// Do not leave a half-started daemon behind for later tests.
			shutdown()
			t.Fatalf("Temporary Redis failed to start: %v", err)
		}
		time.Sleep(pollInterval)
	}

	return shutdown
}
// EnqueueTask creates a task in the "queued" state and registers it in Redis.
//
// The task gets a fresh UUID and the current time as CreatedAt. Three writes
// are sent in one pipeline: the JSON-encoded task under its task key, the task
// ID in the priority sorted set (score = priority), and the "status" and
// "task_id" fields of the status hash for jobName. It returns the stored task,
// or the marshalling / pipeline error.
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) {
	queued := new(Task)
	queued.ID = uuid.New().String()
	queued.JobName = jobName
	queued.Args = args
	queued.Status = "queued"
	queued.Priority = priority
	queued.CreatedAt = time.Now()

	encoded, err := json.Marshal(queued)
	if err != nil {
		return nil, err
	}

	batch := tq.client.Pipeline()
	batch.Set(tq.ctx, taskPrefix+queued.ID, encoded, 0)
	batch.ZAdd(tq.ctx, taskQueueKey, redis.Z{Member: queued.ID, Score: float64(priority)})
	batch.HSet(tq.ctx, taskStatusPrefix+jobName, "status", "queued", "task_id", queued.ID)
	if _, err = batch.Exec(tq.ctx); err != nil {
		return nil, err
	}

	return queued, nil
}
// GetNextTask pops the highest-priority task ID from the queue and returns
// the corresponding task.
//
// It returns (nil, nil) when the queue is empty. An error is returned when
// the pop fails, when the popped member is not a string, or when the task
// record cannot be loaded. Note that the ID has already been removed from the
// queue by the time the record is loaded, so a load failure does not re-queue
// it.
func (tq *TaskQueue) GetNextTask() (*Task, error) {
	popped, err := tq.client.ZPopMax(tq.ctx, taskQueueKey, 1).Result()
	if err != nil {
		return nil, err
	}
	if len(popped) == 0 {
		return nil, nil
	}

	// Member is an interface value; check the assertion rather than letting an
	// unexpected type panic the test process.
	taskID, ok := popped[0].Member.(string)
	if !ok {
		return nil, fmt.Errorf("unexpected queue member type %T", popped[0].Member)
	}

	return tq.GetTask(taskID)
}
// GetTask loads the task stored under taskID and decodes it from JSON.
//
// Any error from the Redis GET (including the one reported for a key that
// does not exist) or from JSON decoding is returned unchanged.
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
	key := taskPrefix + taskID

	raw, err := tq.client.Get(tq.ctx, key).Result()
	if err != nil {
		return nil, err
	}

	loaded := new(Task)
	if err = json.Unmarshal([]byte(raw), loaded); err != nil {
		return nil, err
	}
	return loaded, nil
}
// UpdateTask persists task and refreshes its job's status hash.
//
// In one pipeline it overwrites the JSON record stored under the task's key
// and sets the "status" and "updated_at" (RFC 3339, current time) fields of
// the status hash for task.JobName. It returns the marshalling or pipeline
// error, if any.
func (tq *TaskQueue) UpdateTask(task *Task) error {
	encoded, err := json.Marshal(task)
	if err != nil {
		return err
	}

	recordKey := taskPrefix + task.ID
	statusKey := taskStatusPrefix + task.JobName

	batch := tq.client.Pipeline()
	batch.Set(tq.ctx, recordKey, encoded, 0)
	batch.HSet(tq.ctx, statusKey, "status", task.Status, "updated_at", time.Now().Format(time.RFC3339))

	if _, err = batch.Exec(tq.ctx); err != nil {
		return err
	}
	return nil
}
// CancelTask marks the task identified by taskID as "cancelled" and removes it
// from the pending queue.
//
// The task is loaded first; if it cannot be loaded the error is returned and
// nothing is written. The dequeue, the rewritten task record (Status
// "cancelled", EndedAt = now) and the status-hash update ("status",
// "updated_at") are then submitted as a single MULTI/EXEC transaction, so the
// task can never be observed as cancelled while still sitting in the queue,
// and a failure does not leave only part of the writes applied.
func (tq *TaskQueue) CancelTask(taskID string) error {
	task, err := tq.GetTask(taskID)
	if err != nil {
		return err
	}

	now := time.Now()
	task.Status = "cancelled"
	task.EndedAt = &now

	encoded, err := json.Marshal(task)
	if err != nil {
		return err
	}

	tx := tq.client.TxPipeline()
	tx.ZRem(tq.ctx, taskQueueKey, taskID)
	tx.Set(tq.ctx, taskPrefix+task.ID, encoded, 0)
	// Same hash fields that UpdateTask maintains for the job.
	tx.HSet(
		tq.ctx,
		taskStatusPrefix+task.JobName,
		"status",
		task.Status,
		"updated_at",
		now.Format(time.RFC3339),
	)
	_, err = tx.Exec(tq.ctx)
	return err
}
// GetJobStatus returns every field of the status hash stored for jobName
// (as written by EnqueueTask and UpdateTask), or the Redis error.
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
	statusKey := taskStatusPrefix + jobName
	fields := tq.client.HGetAll(tq.ctx, statusKey)
	return fields.Result()
}
// RecordMetric stores value under the field named metric in the metrics hash
// for jobName, and sets the hash's "timestamp" field to the current Unix time
// in the same HSET. Because "timestamp" shares the hash with the metrics, a
// metric that is itself named "timestamp" would collide with it.
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
	metricsKey := jobMetricsPrefix + jobName
	write := tq.client.HSet(tq.ctx, metricsKey, metric, value, "timestamp", time.Now().Unix())
	return write.Err()
}
// GetMetrics returns every field of the metrics hash stored for jobName
// (the recorded metrics plus the "timestamp" field), or the Redis error.
func (tq *TaskQueue) GetMetrics(jobName string) (map[string]string, error) {
	metricsKey := jobMetricsPrefix + jobName
	fields := tq.client.HGetAll(tq.ctx, metricsKey)
	return fields.Result()
}
// Close releases the Redis client owned by the queue and reports any error
// from doing so.
func (tq *TaskQueue) Close() error {
	if err := tq.client.Close(); err != nil {
		return err
	}
	return nil
}
// ManageScript drives a service management script (manage.sh) by invoking it
// with a single sub-command argument.
type ManageScript struct {
	// path is the executable to run.
	path string
	// dir is the working directory for the script; empty means the child
	// inherits the current process's working directory.
	dir string
}

// NewManageScript returns a ManageScript for the executable at path that runs
// in the current working directory.
func NewManageScript(path string) *ManageScript {
	return NewManageScriptWithDir(path, "")
}

// NewManageScriptWithDir returns a ManageScript for the executable at path
// that runs with dir as its working directory.
func NewManageScriptWithDir(path, dir string) *ManageScript {
	ms := new(ManageScript)
	ms.path = path
	ms.dir = dir
	return ms
}

// setDir applies the configured working directory to cmd, leaving cmd
// untouched when none was configured.
func (ms *ManageScript) setDir(cmd *exec.Cmd) {
	if ms.dir == "" {
		return
	}
	cmd.Dir = ms.dir
}

// command builds the invocation "<path> <action>" with the working directory
// applied; the caller decides how to run it.
func (ms *ManageScript) command(action string) *exec.Cmd {
	//nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility
	cmd := exec.CommandContext(context.Background(), ms.path, action)
	ms.setDir(cmd)
	return cmd
}

// capture runs the given action and returns its combined stdout and stderr.
func (ms *ManageScript) capture(action string) (string, error) {
	out, err := ms.command(action).CombinedOutput()
	return string(out), err
}

// Status runs the script's "status" action and returns its combined output.
func (ms *ManageScript) Status() (string, error) {
	return ms.capture("status")
}

// Start runs the script's "start" action.
func (ms *ManageScript) Start() error {
	return ms.command("start").Run()
}

// Stop runs the script's "stop" action.
func (ms *ManageScript) Stop() error {
	return ms.command("stop").Run()
}

// Cleanup runs the script's "cleanup" action.
func (ms *ManageScript) Cleanup() error {
	return ms.command("cleanup").Run()
}

// StopAndCleanup runs Stop followed by Cleanup and deliberately ignores both
// errors, so it can be used directly in a defer.
func (ms *ManageScript) StopAndCleanup() {
	_ = ms.Stop()
	_ = ms.Cleanup()
}

// Health runs the script's "health" action and returns its combined output.
func (ms *ManageScript) Health() (string, error) {
	return ms.capture("health")
}
// RedisHelper is a thin wrapper around a Redis client for tests that need
// direct access to the database (for example to flush it between cases).
type RedisHelper struct {
	// client is the connection every helper method operates on.
	client *redis.Client
	// ctx is passed to Redis calls; NewRedisHelper sets it to
	// context.Background().
	ctx context.Context
}

// NewRedisHelper connects to the Redis server at addr (no password), selects
// database db, and verifies the connection with a PING.
//
// If the ping fails the client is closed before the error is returned, so a
// failed construction does not leak its connection pool. On success the
// caller owns the helper and must call Close.
func NewRedisHelper(addr string, db int) (*RedisHelper, error) {
	rdb := redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: "",
		DB:       db,
	})

	ctx := context.Background()
	if err := rdb.Ping(ctx).Err(); err != nil {
		// Best effort: the ping failure is the error worth reporting.
		_ = rdb.Close()
		return nil, fmt.Errorf("redis connection failed: %w", err)
	}

	return &RedisHelper{client: rdb, ctx: ctx}, nil
}

// Close closes the underlying Redis client.
func (rh *RedisHelper) Close() error {
	return rh.client.Close()
}

// FlushDB issues FLUSHDB on the helper's connection, which was opened against
// the database index given to NewRedisHelper.
func (rh *RedisHelper) FlushDB() error {
	return rh.client.FlushDB(rh.ctx).Err()
}

// GetClient returns the underlying Redis client. The helper still owns it:
// closing the helper closes the returned client as well.
func (rh *RedisHelper) GetClient() *redis.Client {
	return rh.client
}
// ExamplesDir locates example projects, each of which is a sub-directory of a
// common base directory.
type ExamplesDir struct {
	// path is the base directory containing one sub-directory per project.
	path string
}

// NewExamplesDir returns an ExamplesDir rooted at basePath. The path is not
// checked here; a bad path surfaces when ListProjects reads it.
func NewExamplesDir(basePath string) *ExamplesDir {
	ed := new(ExamplesDir)
	ed.path = basePath
	return ed
}

// GetPath returns the path of projectName inside the examples directory. It
// only joins path elements; the project is not checked for existence.
func (ed *ExamplesDir) GetPath(projectName string) string {
	return filepath.Join(ed.path, projectName)
}

// ListProjects returns the names of the immediate sub-directories of the
// examples directory, in the order os.ReadDir reports them. Non-directory
// entries are skipped. The result is nil when there are no sub-directories.
func (ed *ExamplesDir) ListProjects() ([]string, error) {
	listing, err := os.ReadDir(ed.path)
	if err != nil {
		return nil, err
	}

	var names []string
	for i := range listing {
		if !listing[i].IsDir() {
			continue
		}
		names = append(names, listing[i].Name())
	}
	return names, nil
}
// CopyProject recursively copies the example project named projectName into
// dest by delegating to CopyDir, and returns CopyDir's error.
func (ed *ExamplesDir) CopyProject(projectName, dest string) error {
	return CopyDir(ed.GetPath(projectName), dest)
}
// MLServer is a minimal stand-in for the real ML server connection, for use
// in tests.
type MLServer struct {
	// client is nil in local mode. In the real implementation this would be
	// an *ssh.Client.
	client any
}

// NewMLServer returns an MLServer in local mode (no client set).
func NewMLServer() *MLServer {
	return new(MLServer)
}

// Exec runs cmd and returns its combined stdout and stderr.
//
// In local mode the command is executed through "sh -c"; the error is
// whatever running the shell reports (for example a non-zero exit status).
// When a client is set, Exec returns an error because SSH execution is
// stubbed out in tests.
func (s *MLServer) Exec(cmd string) (string, error) {
	if s.client != nil {
		return "", fmt.Errorf("SSH mode not implemented in tests")
	}

	shell := exec.CommandContext(context.Background(), "sh", "-c", cmd)
	raw, err := shell.CombinedOutput()
	return string(raw), err
}

// Close is a no-op that always returns nil.
func (s *MLServer) Close() error {
	return nil
}
// LoadConfig reads the YAML file at path and decodes it into a Config.
//
// The file is read through fileutil.SecureFileRead and decoded with
// yaml.Unmarshal; an error from either step is returned unchanged. When the
// file does not set redis_addr, it defaults to "localhost:6379". RedisDB needs
// no default: its zero value already selects database 0.
func LoadConfig(path string) (*Config, error) {
	data, err := fileutil.SecureFileRead(path)
	if err != nil {
		return nil, err
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}

	if cfg.RedisAddr == "" {
		cfg.RedisAddr = "localhost:6379"
	}

	return &cfg, nil
}
// CopyDir recursively copies the directory tree rooted at src into dst.
//
// dst (and any missing parents) is created with the mode of src; existing
// files in dst with the same names are truncated and overwritten. Entries
// that are not directories are copied with copyFile, which opens the source
// path and therefore follows symbolic links.
//
// An error is returned if src cannot be stat'ed or is not a directory (in
// which case dst is not created), or as soon as any entry fails to copy; a
// partial copy is left in place in that case.
func CopyDir(src, dst string) error {
	srcInfo, err := os.Stat(src)
	if err != nil {
		return err
	}
	// Reject non-directories up front so dst is not created as a side effect
	// of a call that is going to fail anyway.
	if !srcInfo.IsDir() {
		return fmt.Errorf("copy dir: source %q is not a directory", src)
	}

	if err := os.MkdirAll(dst, srcInfo.Mode()); err != nil {
		return err
	}

	entries, err := os.ReadDir(src)
	if err != nil {
		return err
	}

	for _, entry := range entries {
		from := filepath.Join(src, entry.Name())
		to := filepath.Join(dst, entry.Name())

		if entry.IsDir() {
			err = CopyDir(from, to)
		} else {
			err = copyFile(from, to)
		}
		if err != nil {
			return err
		}
	}
	return nil
}

// copyFile copies the contents of src to dst, creating dst with src's mode
// or truncating it if it already exists.
//
// The destination's Close error is returned rather than discarded, because
// a failed close can mean the written data never reached the file.
func copyFile(src, dst string) error {
	//nolint:gosec // G304: Potential file inclusion via variable - this is a test utility
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	// Read side: a close error cannot lose data, so it is ignored.
	defer func() { _ = in.Close() }()

	info, err := in.Stat()
	if err != nil {
		return err
	}

	//nolint:gosec // G304: Potential file inclusion via variable - this is a test utility
	out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
	if err != nil {
		return err
	}

	if _, err := io.Copy(out, in); err != nil {
		_ = out.Close() // the copy error is the one worth reporting
		return err
	}
	return out.Close()
}
// CreateMLProject materialises an ML project from template under
// testDir/projectName.
//
// It creates the project directory (mode 0750) and writes two files with mode
// 0600: train.py from template.Entrypoint and requirements.txt from
// template.Requirements. Any failure to create the directory or write a file
// aborts the test via t.Fatalf. Afterwards both files are stat'ed and any
// stat failure is reported with t.Errorf, labelled with template.Name.
func CreateMLProject(t *testing.T, testDir, projectName string, template MLProjectTemplate) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	experimentDir := filepath.Join(testDir, projectName)
	if err := os.MkdirAll(experimentDir, 0750); err != nil {
		t.Fatalf("Failed to create experiment directory: %v", err)
	}

	// Create training script
	trainScript := filepath.Join(experimentDir, "train.py")
	if err := os.WriteFile(trainScript, []byte(template.Entrypoint), 0600); err != nil {
		t.Fatalf("Failed to create train.py: %v", err)
	}

	// Create requirements.txt
	requirementsFile := filepath.Join(experimentDir, "requirements.txt")
	if err := os.WriteFile(requirementsFile, []byte(template.Requirements), 0600); err != nil {
		t.Fatalf("Failed to create requirements.txt: %v", err)
	}

	// Verify project structure. Any stat error (not only "does not exist")
	// means the file cannot be relied on, so report all of them.
	if _, err := os.Stat(trainScript); err != nil {
		t.Errorf("%s train.py should exist: %v", template.Name, err)
	}
	if _, err := os.Stat(requirementsFile); err != nil {
		t.Errorf("%s requirements.txt should exist: %v", template.Name, err)
	}
}