// Package tests provides test utilities and fixtures. package tests import ( "context" "encoding/json" "fmt" "io" "os" "os/exec" "path/filepath" "testing" "time" "github.com/google/uuid" "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/redis/go-redis/v9" "gopkg.in/yaml.v3" ) // TestSchema is the shared database schema for testing const TestSchema = ` CREATE TABLE IF NOT EXISTS jobs ( id TEXT PRIMARY KEY, job_name TEXT NOT NULL, args TEXT, status TEXT NOT NULL DEFAULT 'pending', priority INTEGER DEFAULT 0, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, started_at DATETIME, ended_at DATETIME, worker_id TEXT, error TEXT, datasets TEXT, metadata TEXT, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE TABLE IF NOT EXISTS workers ( id TEXT PRIMARY KEY, hostname TEXT NOT NULL, last_heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP, status TEXT NOT NULL DEFAULT 'active', current_jobs INTEGER DEFAULT 0, max_jobs INTEGER DEFAULT 1, metadata TEXT ); CREATE TABLE IF NOT EXISTS job_metrics ( job_id TEXT NOT NULL, metric_name TEXT NOT NULL, metric_value TEXT NOT NULL, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (job_id, metric_name), FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS system_metrics ( metric_name TEXT, metric_value TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (metric_name, timestamp) ); ` // Config holds test configuration type Config struct { RedisAddr string `yaml:"redis_addr"` RedisPassword string `yaml:"redis_password"` RedisDB int `yaml:"redis_db"` } // Task struct for testing type Task struct { ID string `json:"id"` JobName string `json:"job_name"` Args string `json:"args"` Status string `json:"status"` Priority int64 `json:"priority"` CreatedAt time.Time `json:"created_at"` StartedAt *time.Time `json:"started_at,omitempty"` EndedAt *time.Time `json:"ended_at,omitempty"` WorkerID string `json:"worker_id,omitempty"` Error string `json:"error,omitempty"` } // TaskQueue for testing type TaskQueue struct { client *redis.Client ctx context.Context } const ( taskQueueKey = "ml:queue" taskPrefix = "ml:task:" taskStatusPrefix = "ml:status:" jobMetricsPrefix = "ml:metrics:" ) // NewTaskQueue creates a new task queue for testing func NewTaskQueue(cfg *Config) (*TaskQueue, error) { rdb := redis.NewClient(&redis.Options{ Addr: cfg.RedisAddr, Password: cfg.RedisPassword, DB: cfg.RedisDB, }) ctx := context.Background() if err := rdb.Ping(ctx).Err(); err != nil { return nil, fmt.Errorf("redis connection failed: %w", err) } return &TaskQueue{client: rdb, ctx: ctx}, nil } // EnsureRedis ensures a Redis instance is running on localhost:6379. // If none is found, it starts a temporary instance and returns a cleanup function. func EnsureRedis(t *testing.T) (cleanup func()) { const redisAddr = "localhost:6379" // Try to connect first rdb := redis.NewClient(&redis.Options{Addr: redisAddr}) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := rdb.Ping(ctx).Err(); err == nil { // Redis is already running return func() {} } // Start temporary Redis t.Logf("Starting temporary Redis on %s", redisAddr) cmd := exec.CommandContext( context.Background(), "redis-server", "--daemonize", "yes", "--port", "6379", ) if out, err := cmd.CombinedOutput(); err != nil { t.Fatalf("Failed to start temporary Redis: %v; output: %s", err, string(out)) } // Give it a moment to start time.Sleep(1 * time.Second) // Verify it started if err := rdb.Ping(context.Background()).Err(); err != nil { t.Fatalf("Temporary Redis failed to start: %v", err) } // Return cleanup function return func() { shutdown := exec.CommandContext(context.Background(), "redis-cli", "-p", "6379", "shutdown") _ = shutdown.Run() // ignore errors } } // EnqueueTask adds a task to the queue func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) { task := &Task{ ID: uuid.New().String(), JobName: jobName, Args: args, Status: "queued", Priority: priority, CreatedAt: time.Now(), } taskData, err := json.Marshal(task) if err != nil { return nil, err } pipe := tq.client.Pipeline() pipe.Set(tq.ctx, taskPrefix+task.ID, taskData, 0) pipe.ZAdd(tq.ctx, taskQueueKey, redis.Z{Score: float64(priority), Member: task.ID}) pipe.HSet(tq.ctx, taskStatusPrefix+task.JobName, "status", "queued", "task_id", task.ID) if _, err := pipe.Exec(tq.ctx); err != nil { return nil, err } return task, nil } // GetNextTask retrieves the next highest priority task func (tq *TaskQueue) GetNextTask() (*Task, error) { result, err := tq.client.ZPopMax(tq.ctx, taskQueueKey, 1).Result() if err != nil { return nil, err } if len(result) == 0 { return nil, nil } taskID := result[0].Member.(string) return tq.GetTask(taskID) } // GetTask retrieves a task by ID func (tq *TaskQueue) GetTask(taskID string) (*Task, error) { data, err := tq.client.Get(tq.ctx, taskPrefix+taskID).Result() if err != nil { return nil, err } var task Task if err := json.Unmarshal([]byte(data), &task); err != nil { return nil, err } return &task, nil } // UpdateTask updates a task's status and metadata func (tq *TaskQueue) UpdateTask(task *Task) error { taskData, err := json.Marshal(task) if err != nil { return err } pipe := tq.client.Pipeline() pipe.Set(tq.ctx, taskPrefix+task.ID, taskData, 0) pipe.HSet( tq.ctx, taskStatusPrefix+task.JobName, "status", task.Status, "updated_at", time.Now().Format(time.RFC3339), ) _, err = pipe.Exec(tq.ctx) return err } // CancelTask cancels a task func (tq *TaskQueue) CancelTask(taskID string) error { task, err := tq.GetTask(taskID) if err != nil { return err } task.Status = "cancelled" now := time.Now() task.EndedAt = &now pipe := tq.client.Pipeline() pipe.ZRem(tq.ctx, taskQueueKey, taskID) if err := tq.UpdateTask(task); err != nil { return err } _, err = pipe.Exec(tq.ctx) return err } // GetJobStatus retrieves the status of a job func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) { return tq.client.HGetAll(tq.ctx, taskStatusPrefix+jobName).Result() } // RecordMetric records a metric for a job func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error { key := jobMetricsPrefix + jobName return tq.client.HSet(tq.ctx, key, metric, value, "timestamp", time.Now().Unix()).Err() } // GetMetrics retrieves all metrics for a job func (tq *TaskQueue) GetMetrics(jobName string) (map[string]string, error) { return tq.client.HGetAll(tq.ctx, jobMetricsPrefix+jobName).Result() } // Close closes the task queue func (tq *TaskQueue) Close() error { return tq.client.Close() } // ManageScript provides utilities for manage.sh operations type ManageScript struct { path string } // NewManageScript creates a new manage script utility func NewManageScript(path string) *ManageScript { return &ManageScript{path: path} } // Status gets the status of services func (ms *ManageScript) Status() (string, error) { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "status") output, err := cmd.CombinedOutput() return string(output), err } // Start starts the services func (ms *ManageScript) Start() error { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "start") return cmd.Run() } // Stop stops the services func (ms *ManageScript) Stop() error { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "stop") return cmd.Run() } // Cleanup cleans up any artifacts created by services func (ms *ManageScript) Cleanup() error { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "cleanup") return cmd.Run() } // StopAndCleanup ensures cleanup when called with defer func (ms *ManageScript) StopAndCleanup() { _ = ms.Stop() _ = ms.Cleanup() } // Health checks the health of services func (ms *ManageScript) Health() (string, error) { //nolint:gosec // G204: Subprocess launched with potential tainted input - this is a test utility cmd := exec.CommandContext(context.Background(), ms.path, "health") output, err := cmd.CombinedOutput() return string(output), err } // RedisHelper provides utilities for Redis operations type RedisHelper struct { client *redis.Client ctx context.Context } // NewRedisHelper creates a new Redis helper func NewRedisHelper(addr string, db int) (*RedisHelper, error) { rdb := redis.NewClient(&redis.Options{ Addr: addr, Password: "", DB: db, }) ctx := context.Background() if err := rdb.Ping(ctx).Err(); err != nil { return nil, fmt.Errorf("redis connection failed: %w", err) } return &RedisHelper{client: rdb, ctx: ctx}, nil } // Close closes the Redis connection func (rh *RedisHelper) Close() error { return rh.client.Close() } // FlushDB flushes the Redis database func (rh *RedisHelper) FlushDB() error { return rh.client.FlushDB(rh.ctx).Err() } // GetClient returns the underlying Redis client func (rh *RedisHelper) GetClient() *redis.Client { return rh.client } // ExamplesDir provides utilities for working with example projects type ExamplesDir struct { path string } // NewExamplesDir creates a new examples directory utility func NewExamplesDir(basePath string) *ExamplesDir { return &ExamplesDir{path: basePath} } // GetPath returns the path to an example project func (ed *ExamplesDir) GetPath(projectName string) string { return filepath.Join(ed.path, projectName) } // ListProjects returns a list of all example projects func (ed *ExamplesDir) ListProjects() ([]string, error) { entries, err := os.ReadDir(ed.path) if err != nil { return nil, err } var projects []string for _, entry := range entries { if entry.IsDir() { projects = append(projects, entry.Name()) } } return projects, nil } // CopyProject copies an example project to a destination func (ed *ExamplesDir) CopyProject(projectName, dest string) error { src := ed.GetPath(projectName) return CopyDir(src, dest) } // MLServer minimal implementation for testing type MLServer struct { client any // In real implementation this would be *ssh.Client } // NewMLServer creates a new MLServer instance for testing func NewMLServer() *MLServer { return &MLServer{ client: nil, // Local mode by default } } // Exec runs a command either locally or via SSH (stubbed for tests) func (s *MLServer) Exec(cmd string) (string, error) { if s.client == nil { // Local mode out, err := exec.CommandContext(context.Background(), "sh", "-c", cmd).CombinedOutput() return string(out), err } // SSH mode would be implemented here return "", fmt.Errorf("SSH mode not implemented in tests") } // Close closes the ML server connection func (s *MLServer) Close() error { return nil } // LoadConfig loads configuration for testing func LoadConfig(path string) (*Config, error) { data, err := fileutil.SecureFileRead(path) if err != nil { return nil, err } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, err } if cfg.RedisAddr == "" { cfg.RedisAddr = "localhost:6379" } if cfg.RedisDB == 0 { cfg.RedisDB = 0 } return &cfg, nil } // CopyDir copies a directory recursively func CopyDir(src, dst string) error { srcInfo, err := os.Stat(src) if err != nil { return err } // Create the destination directory with the same permissions as source if err := os.MkdirAll(dst, srcInfo.Mode()); err != nil { return err } entries, err := os.ReadDir(src) if err != nil { return err } for _, entry := range entries { srcPath := filepath.Join(src, entry.Name()) dstPath := filepath.Join(dst, entry.Name()) if entry.IsDir() { if err := CopyDir(srcPath, dstPath); err != nil { return err } } else { if err := copyFile(srcPath, dstPath); err != nil { return err } } } return nil } func copyFile(src, dst string) error { //nolint:gosec // G304: Potential file inclusion via variable - this is a test utility srcFile, err := os.Open(src) if err != nil { return err } defer func() { _ = srcFile.Close() }() srcInfo, err := srcFile.Stat() if err != nil { return err } //nolint:gosec // G304: Potential file inclusion via variable - this is a test utility dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode()) if err != nil { return err } defer func() { _ = dstFile.Close() }() _, err = io.Copy(dstFile, srcFile) return err } // CreateMLProject creates an ML project from a template func CreateMLProject(t *testing.T, testDir, projectName string, template MLProjectTemplate) { experimentDir := filepath.Join(testDir, projectName) if err := os.MkdirAll(experimentDir, 0750); err != nil { t.Fatalf("Failed to create experiment directory: %v", err) } // Create training script trainScript := filepath.Join(experimentDir, "train.py") if err := os.WriteFile(trainScript, []byte(template.TrainScript), 0600); err != nil { t.Fatalf("Failed to create train.py: %v", err) } // Create requirements.txt requirementsFile := filepath.Join(experimentDir, "requirements.txt") if err := os.WriteFile(requirementsFile, []byte(template.Requirements), 0600); err != nil { t.Fatalf("Failed to create requirements.txt: %v", err) } // Verify project structure if _, err := os.Stat(trainScript); os.IsNotExist(err) { t.Errorf("%s train.py should exist", template.Name) } if _, err := os.Stat(requirementsFile); os.IsNotExist(err) { t.Errorf("%s requirements.txt should exist", template.Name) } }