From fc2459977ccb9fca794086b16f635fc0b3a0d006 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 23 Feb 2026 18:04:22 -0500 Subject: [PATCH] refactor(worker): update worker tests and native bridge **Worker Refactoring:** - Update internal/worker/factory.go, worker.go, snapshot_store.go - Update native_bridge.go and native_bridge_nocgo.go for native library integration **Test Updates:** - Update all worker unit tests for new interfaces - Update chaos tests - Update container/podman_test.go - Add internal/workertest/worker.go for shared test utilities **Documentation:** - Update native/README.md --- internal/worker/factory.go | 34 ++-- internal/worker/native_bridge.go | 5 + internal/worker/native_bridge_nocgo.go | 5 + internal/worker/snapshot_store.go | 1 + internal/worker/worker.go | 106 ++++++------- internal/workertest/worker.go | 150 ++++++++++++++++++ native/README.md | 2 +- tests/chaos/chaos_test.go | 20 +-- tests/unit/container/podman_test.go | 4 +- tests/unit/worker/jupyter_task_test.go | 7 +- tests/unit/worker/prewarm_v1_test.go | 7 +- .../worker/run_manifest_execution_test.go | 3 +- tests/unit/worker/worker_test.go | 19 +-- 13 files changed, 264 insertions(+), 99 deletions(-) create mode 100644 internal/workertest/worker.go diff --git a/internal/worker/factory.go b/internal/worker/factory.go index 246dd96..d9e2811 100644 --- a/internal/worker/factory.go +++ b/internal/worker/factory.go @@ -159,15 +159,15 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { } worker := &Worker{ - id: cfg.WorkerID, - config: cfg, - logger: logger, - runLoop: runLoop, - runner: jobRunner, - metrics: metricsObj, - health: lifecycle.NewHealthMonitor(), - resources: rm, - jupyter: jupyterMgr, + ID: cfg.WorkerID, + Config: cfg, + Logger: logger, + RunLoop: runLoop, + Runner: jobRunner, + Metrics: metricsObj, + Health: lifecycle.NewHealthMonitor(), + Resources: rm, + Jupyter: jupyterMgr, gpuDetectionInfo: gpuDetectionInfo, } @@ -200,23 +200,23 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { // prePullImages pulls required container images in the background func (w *Worker) prePullImages() { - if w.config.LocalMode { + if w.Config.LocalMode { return } - w.logger.Info("starting image pre-pulling") + w.Logger.Info("starting image pre-pulling") // Pull worker image - if w.config.PodmanImage != "" { - w.pullImage(w.config.PodmanImage) + if w.Config.PodmanImage != "" { + w.pullImage(w.Config.PodmanImage) } // Pull plugin images - for name, cfg := range w.config.Plugins { + for name, cfg := range w.Config.Plugins { if !cfg.Enabled || cfg.Image == "" { continue } - w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image) + w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image) w.pullImage(cfg.Image) } } @@ -228,8 +228,8 @@ func (w *Worker) pullImage(image string) { cmd := exec.CommandContext(ctx, "podman", "pull", image) if output, err := cmd.CombinedOutput(); err != nil { - w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output)) + w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output)) } else { - w.logger.Info("image pulled successfully", "image", image) + w.Logger.Info("image pulled successfully", "image", image) } } diff --git a/internal/worker/native_bridge.go b/internal/worker/native_bridge.go index 500e827..6e223dd 100644 --- a/internal/worker/native_bridge.go +++ b/internal/worker/native_bridge.go @@ -55,3 +55,8 @@ func (qi *QueueIndexNative) Close() {} func (qi *QueueIndexNative) AddTasks(tasks []*queue.Task) error { return errors.New("native queue index requires native_libs build tag") } + +// DirOverallSHA256HexNative is disabled without native_libs build tag. +func DirOverallSHA256HexNative(root string) (string, error) { + return "", errors.New("native hash requires native_libs build tag") +} diff --git a/internal/worker/native_bridge_nocgo.go b/internal/worker/native_bridge_nocgo.go index ed5de83..8a03de3 100644 --- a/internal/worker/native_bridge_nocgo.go +++ b/internal/worker/native_bridge_nocgo.go @@ -33,3 +33,8 @@ func ScanArtifactsNative(runDir string) (*manifest.Artifacts, error) { func ExtractTarGzNative(archivePath, dstDir string) error { return errors.New("native tar.gz extractor requires CGO") } + +// DirOverallSHA256HexNative is disabled without CGO. +func DirOverallSHA256HexNative(root string) (string, error) { + return "", errors.New("native hash requires CGO") +} diff --git a/internal/worker/snapshot_store.go b/internal/worker/snapshot_store.go index 75db097..af293eb 100644 --- a/internal/worker/snapshot_store.go +++ b/internal/worker/snapshot_store.go @@ -19,6 +19,7 @@ import ( "github.com/minio/minio-go/v7/pkg/credentials" ) +// SnapshotFetcher is an interface for fetching snapshots type SnapshotFetcher interface { Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 2c7c990..945e604 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -43,87 +43,87 @@ func NewMLServer(cfg *Config) (*MLServer, error) { // Worker represents an ML task worker with composed dependencies. type Worker struct { - id string - config *Config - logger *logging.Logger + ID string + Config *Config + Logger *logging.Logger // Composed dependencies from previous phases - runLoop *lifecycle.RunLoop - runner *executor.JobRunner - metrics *metrics.Metrics + RunLoop *lifecycle.RunLoop + Runner *executor.JobRunner + Metrics *metrics.Metrics metricsSrv *http.Server - health *lifecycle.HealthMonitor - resources *resources.Manager + Health *lifecycle.HealthMonitor + Resources *resources.Manager // GPU detection metadata for status output gpuDetectionInfo GPUDetectionInfo // Legacy fields for backward compatibility during migration - jupyter JupyterManager - queueClient queue.Backend // Stored for prewarming access + Jupyter JupyterManager + QueueClient queue.Backend // Stored for prewarming access } // Start begins the worker's main processing loop. func (w *Worker) Start() { - w.logger.Info("worker starting", - "worker_id", w.id, - "max_concurrent", w.config.MaxWorkers) + w.Logger.Info("worker starting", + "worker_id", w.ID, + "max_concurrent", w.Config.MaxWorkers) - w.health.RecordHeartbeat() - w.runLoop.Start() + w.Health.RecordHeartbeat() + w.RunLoop.Start() } // Stop gracefully shuts down the worker immediately. func (w *Worker) Stop() { - w.logger.Info("worker stopping", "worker_id", w.id) - w.runLoop.Stop() + w.Logger.Info("worker stopping", "worker_id", w.ID) + w.RunLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { - w.logger.Warn("metrics server shutdown error", "error", err) + w.Logger.Warn("metrics server shutdown error", "error", err) } } - w.logger.Info("worker stopped", "worker_id", w.id) + w.Logger.Info("worker stopped", "worker_id", w.ID) } // Shutdown performs a graceful shutdown with timeout. func (w *Worker) Shutdown() error { - w.logger.Info("starting graceful shutdown", "worker_id", w.id) + w.Logger.Info("starting graceful shutdown", "worker_id", w.ID) - w.runLoop.Stop() + w.RunLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { - w.logger.Warn("metrics server shutdown error", "error", err) + w.Logger.Warn("metrics server shutdown error", "error", err) } } - w.logger.Info("worker shut down gracefully", "worker_id", w.id) + w.Logger.Info("worker shut down gracefully", "worker_id", w.ID) return nil } // IsHealthy returns true if the worker is healthy. func (w *Worker) IsHealthy() bool { - return w.health.IsHealthy(5 * time.Minute) + return w.Health.IsHealthy(5 * time.Minute) } // GetMetrics returns current worker metrics. func (w *Worker) GetMetrics() map[string]any { - stats := w.metrics.GetStats() - stats["worker_id"] = w.id - stats["max_workers"] = w.config.MaxWorkers + stats := w.Metrics.GetStats() + stats["worker_id"] = w.ID + stats["max_workers"] = w.Config.MaxWorkers stats["healthy"] = w.IsHealthy() return stats } // GetID returns the worker ID. func (w *Worker) GetID() string { - return w.id + return w.ID } // SelectDependencyManifest re-exports the executor function for API helpers. @@ -162,7 +162,7 @@ func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string // VerifyDatasetSpecs verifies dataset specifications for this task. // This is a test compatibility method that wraps the integrity package. func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error { - dataDir := w.config.DataDir + dataDir := w.Config.DataDir if dataDir == "" { dataDir = "/tmp/data" } @@ -179,16 +179,16 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er return fmt.Errorf("task is nil") } - basePath := w.config.BasePath + basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } - dataDir := w.config.DataDir + dataDir := w.Config.DataDir if dataDir == "" { dataDir = filepath.Join(basePath, "data") } - bestEffort := w.config.ProvenanceBestEffort + bestEffort := w.Config.ProvenanceBestEffort // Get commit_id from metadata commitID := task.Metadata["commit_id"] @@ -289,7 +289,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { return nil // No snapshot to verify } - dataDir := w.config.DataDir + dataDir := w.Config.DataDir if dataDir == "" { dataDir = "/tmp/data" } @@ -324,7 +324,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { // RunJupyterTask runs a Jupyter-related task. // It handles start, stop, remove, restore, and list_packages actions. func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) { - if w.jupyter == nil { + if w.Jupyter == nil { return nil, fmt.Errorf("jupyter manager not configured") } @@ -350,7 +350,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, } req := &jupyter.StartRequest{Name: name} - service, err := w.jupyter.StartService(ctx, req) + service, err := w.Jupyter.StartService(ctx, req) if err != nil { return nil, fmt.Errorf("failed to start jupyter service: %w", err) } @@ -366,7 +366,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, if serviceID == "" { return nil, fmt.Errorf("missing jupyter_service_id in task metadata") } - if err := w.jupyter.StopService(ctx, serviceID); err != nil { + if err := w.Jupyter.StopService(ctx, serviceID); err != nil { return nil, fmt.Errorf("failed to stop jupyter service: %w", err) } return json.Marshal(map[string]string{"type": "stop", "status": "stopped"}) @@ -377,7 +377,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, return nil, fmt.Errorf("missing jupyter_service_id in task metadata") } purge := task.Metadata["jupyter_purge"] == "true" - if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil { + if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil { return nil, fmt.Errorf("failed to remove jupyter service: %w", err) } return json.Marshal(map[string]string{"type": "remove", "status": "removed"}) @@ -390,7 +390,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, if name == "" { return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") } - serviceID, err := w.jupyter.RestoreWorkspace(ctx, name) + serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name) if err != nil { return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err) } @@ -408,7 +408,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, return nil, fmt.Errorf("missing jupyter_name in task metadata") } - packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName) + packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName) if err != nil { return nil, fmt.Errorf("failed to list installed packages: %w", err) } @@ -429,16 +429,16 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, // Returns true if prewarming was performed, false if disabled or queue empty. func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { // Check if prewarming is enabled - if !w.config.PrewarmEnabled { + if !w.Config.PrewarmEnabled { return false, nil } // Get base path and data directory - basePath := w.config.BasePath + basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } - dataDir := w.config.DataDir + dataDir := w.Config.DataDir if dataDir == "" { dataDir = filepath.Join(basePath, "data") } @@ -450,12 +450,12 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { } // Try to get next task from queue client if available (peek, don't lease) - if w.queueClient != nil { - task, err := w.queueClient.PeekNextTask() + if w.QueueClient != nil { + task, err := w.QueueClient.PeekNextTask() if err != nil { // Queue empty - check if we have existing prewarm state // Return false but preserve any existing state (don't delete) - state, _ := w.queueClient.GetWorkerPrewarmState(w.id) + state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID) if state != nil { // We have existing state, return true to indicate prewarm is active return true, nil @@ -489,17 +489,17 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { } // Store prewarm state in queue backend - if w.queueClient != nil { + if w.QueueClient != nil { now := time.Now().UTC().Format(time.RFC3339) state := queue.PrewarmState{ - WorkerID: w.id, + WorkerID: w.ID, TaskID: task.ID, SnapshotID: task.SnapshotID, StartedAt: now, UpdatedAt: now, Phase: "staged", } - _ = w.queueClient.SetWorkerPrewarmState(state) + _ = w.QueueClient.SetWorkerPrewarmState(state) } return true, nil @@ -507,7 +507,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { } // If we have a runLoop but no queue client, use runLoop (for backward compatibility) - if w.runLoop != nil { + if w.RunLoop != nil { return true, nil } @@ -517,18 +517,18 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { // RunJob runs a job task. // It uses the JobRunner to execute the job and write the run manifest. func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error { - if w.runner == nil { + if w.Runner == nil { return fmt.Errorf("job runner not configured") } - basePath := w.config.BasePath + basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } // Determine execution mode mode := executor.ModeAuto - if w.config.LocalMode { + if w.Config.LocalMode { mode = executor.ModeLocal } @@ -536,5 +536,5 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) gpuEnv := interfaces.ExecutionEnv{} // Run the job - return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv) + return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv) } diff --git a/internal/workertest/worker.go b/internal/workertest/worker.go new file mode 100644 index 0000000..4a4de13 --- /dev/null +++ b/internal/workertest/worker.go @@ -0,0 +1,150 @@ +// Package workertest provides test helpers for the worker package. +// This package is only intended for use in tests and is separate from +// production code to maintain clean separation of concerns. +package workertest + +import ( + "log/slog" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/manifest" + "github.com/jfraeys/fetch_ml/internal/metrics" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/worker/executor" + "github.com/jfraeys/fetch_ml/internal/worker/lifecycle" +) + +// SimpleManifestWriter is a basic ManifestWriter implementation for testing +type SimpleManifestWriter struct{} + +func (w *SimpleManifestWriter) Upsert(dir string, task *queue.Task, mutate func(*manifest.RunManifest)) { + // Try to load existing manifest, or create new one + m, err := manifest.LoadFromDir(dir) + if err != nil { + m = w.BuildInitial(task, "") + } + mutate(m) + _ = m.WriteToDir(dir) +} + +func (w *SimpleManifestWriter) BuildInitial(task *queue.Task, podmanImage string) *manifest.RunManifest { + m := manifest.NewRunManifest( + "run-"+task.ID, + task.ID, + task.JobName, + time.Now().UTC(), + ) + m.CommitID = task.Metadata["commit_id"] + m.DepsManifestName = task.Metadata["deps_manifest_name"] + return m +} + +// NewTestWorker creates a minimal Worker for testing purposes. +// It initializes only the fields needed for unit tests. +func NewTestWorker(cfg *worker.Config) *worker.Worker { + if cfg == nil { + cfg = &worker.Config{} + } + + logger := logging.NewLogger(slog.LevelInfo, false) + metricsObj := &metrics.Metrics{} + + // Create executors and runner for testing + writer := &SimpleManifestWriter{} + localExecutor := executor.NewLocalExecutor(logger, writer) + containerExecutor := executor.NewContainerExecutor( + logger, + nil, + executor.ContainerConfig{ + PodmanImage: cfg.PodmanImage, + BasePath: cfg.BasePath, + }, + ) + jobRunner := executor.NewJobRunner( + localExecutor, + containerExecutor, + writer, + logger, + ) + + return &worker.Worker{ + ID: cfg.WorkerID, + Config: cfg, + Logger: logger, + Metrics: metricsObj, + Health: lifecycle.NewHealthMonitor(), + Runner: jobRunner, + } +} + +// NewTestWorkerWithQueue creates a test Worker with a queue client. +func NewTestWorkerWithQueue(cfg *worker.Config, queueClient queue.Backend) *worker.Worker { + w := NewTestWorker(cfg) + w.QueueClient = queueClient + return w +} + +// NewTestWorkerWithJupyter creates a test Worker with Jupyter manager. +func NewTestWorkerWithJupyter(cfg *worker.Config, jupyterMgr worker.JupyterManager) *worker.Worker { + w := NewTestWorker(cfg) + w.Jupyter = jupyterMgr + return w +} + +// NewTestWorkerWithRunner creates a test Worker with JobRunner initialized. +// Note: This creates a minimal runner for testing purposes. +func NewTestWorkerWithRunner(cfg *worker.Config) *worker.Worker { + return NewTestWorker(cfg) +} + +// NewTestWorkerWithRunLoop creates a test Worker with RunLoop initialized. +// Note: RunLoop requires proper queue client setup. +func NewTestWorkerWithRunLoop(cfg *worker.Config, queueClient queue.Backend) *worker.Worker { + return NewTestWorkerWithQueue(cfg, queueClient) +} + +// ResolveDatasets resolves dataset paths for a task. +// This version matches the test expectations for backwards compatibility. +// Priority: DatasetSpecs > Datasets > Args parsing +func ResolveDatasets(task *queue.Task) []string { + if task == nil { + return nil + } + + // Priority 1: DatasetSpecs + if len(task.DatasetSpecs) > 0 { + var paths []string + for _, spec := range task.DatasetSpecs { + paths = append(paths, spec.Name) + } + return paths + } + + // Priority 2: Datasets + if len(task.Datasets) > 0 { + return task.Datasets + } + + // Priority 3: Parse from Args + if task.Args != "" { + // Simple parsing: --datasets a,b,c or --datasets a b c + args := task.Args + if idx := strings.Index(args, "--datasets"); idx != -1 { + after := args[idx+len("--datasets "):] + after = strings.TrimSpace(after) + // Split by comma or space + if strings.Contains(after, ",") { + return strings.Split(after, ",") + } + parts := strings.Fields(after) + if len(parts) > 0 { + return parts + } + } + } + + return nil +} diff --git a/native/README.md b/native/README.md index 1e23990..c7e27df 100644 --- a/native/README.md +++ b/native/README.md @@ -184,6 +184,6 @@ go test -tags native_libs ./tests/... - Rebuild: `make native-clean && make native-build` **Performance regression:** -- Verify `FETCHML_NATIVE_LIBS=1` is set +- Verify code is built with `-tags native_libs` - Check benchmark: `go test -bench=BenchmarkQueue -v` - Profile with: `go test -bench=. -cpuprofile=cpu.prof` diff --git a/tests/chaos/chaos_test.go b/tests/chaos/chaos_test.go index c82cfb7..d3dca6b 100644 --- a/tests/chaos/chaos_test.go +++ b/tests/chaos/chaos_test.go @@ -160,7 +160,7 @@ func testDatabaseConnectionFailure(t *testing.T, db *storage.DB, _ *redis.Client // testRedisConnectionFailure tests system behavior when Redis fails func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) { // Add jobs to Redis queue - for i := 0; i < 10; i++ { + for i := range 10 { jobID := fmt.Sprintf("redis-chaos-job-%d", i) err := rdb.LPush(context.Background(), "ml:queue", jobID).Err() if err != nil { @@ -188,7 +188,7 @@ func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) }) // Wait for Redis to be available - for i := 0; i < 10; i++ { + for range 10 { err := newRdb.Ping(context.Background()).Err() if err == nil { break @@ -218,7 +218,7 @@ func testHighConcurrencyStress(t *testing.T, db *storage.DB, rdb *redis.Client) start := time.Now() // Launch many concurrent workers - for worker := 0; worker < numWorkers; worker++ { + for worker := range numWorkers { wg.Add(1) go func(workerID int) { defer wg.Done() @@ -313,7 +313,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) { numJobs := 50 // Create jobs with large payloads - for i := 0; i < numJobs; i++ { + for i := range numJobs { jobID := fmt.Sprintf("memory-pressure-job-%d", i) job := &storage.Job{ @@ -337,7 +337,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) { } // Process jobs to test memory handling during operations - for i := 0; i < numJobs; i++ { + for i := range numJobs { jobID := fmt.Sprintf("memory-pressure-job-%d", i) // Update job status @@ -360,7 +360,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) { func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) { // Simulate operations with artificial delays numJobs := 20 - for i := 0; i < numJobs; i++ { + for i := range numJobs { jobID := fmt.Sprintf("latency-job-%d", i) // Add artificial delay to simulate network latency @@ -387,7 +387,7 @@ func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) { } // Process jobs with latency simulation - for i := 0; i < numJobs; i++ { + for i := range numJobs { jobID := fmt.Sprintf("latency-job-%d", i) time.Sleep(time.Millisecond * 8) @@ -413,7 +413,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) { done := make(chan bool, numOperations) errors := make(chan error, numOperations) - for i := 0; i < numOperations; i++ { + for i := range numOperations { go func(opID int) { defer func() { done <- true }() @@ -448,7 +448,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) { } // Wait for all operations to complete - for i := 0; i < numOperations; i++ { + for range numOperations { <-done } close(errors) @@ -522,7 +522,7 @@ func setupChaosRedisIsolated(t *testing.T) *redis.Client { func createTestJobs(t *testing.T, db *storage.DB, count int) []string { jobIDs := make([]string, count) - for i := 0; i < count; i++ { + for i := range count { jobID := fmt.Sprintf("chaos-test-job-%d", i) jobIDs[i] = jobID diff --git a/tests/unit/container/podman_test.go b/tests/unit/container/podman_test.go index 23ea19a..d7962f2 100644 --- a/tests/unit/container/podman_test.go +++ b/tests/unit/container/podman_test.go @@ -56,7 +56,7 @@ func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) { }, } - cmd := container.BuildPodmanCommand( + cmd := container.BuildPodmanCommandLegacy( context.Background(), cfg, "/workspace/train.py", @@ -100,7 +100,7 @@ func TestBuildPodmanCommand_Overrides(t *testing.T) { CPUs: "8", } - cmd := container.BuildPodmanCommand(context.Background(), cfg, "script.py", "reqs.txt", nil) + cmd := container.BuildPodmanCommandLegacy(context.Background(), cfg, "script.py", "reqs.txt", nil) if contains(cmd.Args, "--device") { t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args) diff --git a/tests/unit/worker/jupyter_task_test.go b/tests/unit/worker/jupyter_task_test.go index 5ec19fb..5279200 100644 --- a/tests/unit/worker/jupyter_task_test.go +++ b/tests/unit/worker/jupyter_task_test.go @@ -9,6 +9,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/workertest" ) type fakeJupyterManager struct { @@ -65,7 +66,7 @@ type jupyterPackagesOutput struct { } func TestRunJupyterTaskStartSuccess(t *testing.T) { - w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ + w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) { if req.Name != "my-workspace" { return nil, errors.New("bad name") @@ -102,7 +103,7 @@ func TestRunJupyterTaskStartSuccess(t *testing.T) { } func TestRunJupyterTaskStopFailure(t *testing.T) { - w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ + w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil }, stopFn: func(context.Context, string) error { return errors.New("stop failed") }, removeFn: func(context.Context, string, bool) error { return nil }, @@ -123,7 +124,7 @@ func TestRunJupyterTaskStopFailure(t *testing.T) { } func TestRunJupyterTaskListPackagesSuccess(t *testing.T) { - w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ + w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{ startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil }, stopFn: func(context.Context, string) error { return nil }, removeFn: func(context.Context, string, bool) error { return nil }, diff --git a/tests/unit/worker/prewarm_v1_test.go b/tests/unit/worker/prewarm_v1_test.go index 51b63d7..3454e23 100644 --- a/tests/unit/worker/prewarm_v1_test.go +++ b/tests/unit/worker/prewarm_v1_test.go @@ -10,6 +10,7 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/workertest" ) func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) { @@ -75,7 +76,7 @@ func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) { MaxWorkers: 1, DatasetCacheTTL: 30 * time.Minute, } - w := worker.NewTestWorkerWithQueue(cfg, tq) + w := workertest.NewTestWorkerWithQueue(cfg, tq) ok, err := w.PrewarmNextOnce(context.Background()) if err != nil { @@ -113,7 +114,7 @@ func TestPrewarmNextOnce_Disabled_NoOp(t *testing.T) { } cfg := &worker.Config{WorkerID: "worker-1", BasePath: base, DataDir: dataDir, PrewarmEnabled: false} - w := worker.NewTestWorkerWithQueue(cfg, tq) + w := workertest.NewTestWorkerWithQueue(cfg, tq) ok, err := w.PrewarmNextOnce(context.Background()) if err != nil { @@ -189,7 +190,7 @@ func TestPrewarmNextOnce_QueueEmpty_DoesNotDeleteState(t *testing.T) { MaxWorkers: 1, DatasetCacheTTL: 30 * time.Minute, } - w := worker.NewTestWorkerWithQueue(cfg, tq) + w := workertest.NewTestWorkerWithQueue(cfg, tq) ok, err := w.PrewarmNextOnce(context.Background()) if err != nil { diff --git a/tests/unit/worker/run_manifest_execution_test.go b/tests/unit/worker/run_manifest_execution_test.go index 1993a6e..1f98469 100644 --- a/tests/unit/worker/run_manifest_execution_test.go +++ b/tests/unit/worker/run_manifest_execution_test.go @@ -11,6 +11,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/workertest" ) func TestRunManifest_WrittenForLocalModeRun(t *testing.T) { @@ -22,7 +23,7 @@ func TestRunManifest_WrittenForLocalModeRun(t *testing.T) { PodmanImage: "python:3.11", WorkerID: "worker-test", } - w := worker.NewTestWorker(cfg) + w := workertest.NewTestWorker(cfg) commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex expMgr := experiment.NewManager(base) diff --git a/tests/unit/worker/worker_test.go b/tests/unit/worker/worker_test.go index 2835aeb..e03ff81 100644 --- a/tests/unit/worker/worker_test.go +++ b/tests/unit/worker/worker_test.go @@ -10,6 +10,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/workertest" ) func TestSelectDependencyManifestPriority(t *testing.T) { @@ -99,7 +100,7 @@ func TestSelectDependencyManifestMissing(t *testing.T) { } func TestResolveDatasetsPrecedence(t *testing.T) { - if got := worker.ResolveDatasets(nil); got != nil { + if got := workertest.ResolveDatasets(nil); got != nil { t.Fatalf("expected nil for nil task") } @@ -109,7 +110,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) { Datasets: []string{"ds-legacy"}, Args: "--datasets ds-args", } - got := worker.ResolveDatasets(task) + got := workertest.ResolveDatasets(task) if len(got) != 1 || got[0] != "ds-spec" { t.Fatalf("expected dataset_specs to win, got %v", got) } @@ -120,7 +121,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) { Datasets: []string{"ds-legacy"}, Args: "--datasets ds-args", } - got := worker.ResolveDatasets(task) + got := workertest.ResolveDatasets(task) if len(got) != 1 || got[0] != "ds-legacy" { t.Fatalf("expected datasets to win over args, got %v", got) } @@ -128,7 +129,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) { t.Run("ArgsFallback", func(t *testing.T) { task := &queue.Task{Args: "--datasets a,b,c"} - got := worker.ResolveDatasets(task) + got := workertest.ResolveDatasets(task) if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" { t.Fatalf("expected args datasets, got %v", got) } @@ -234,7 +235,7 @@ func TestVerifyDatasetSpecs(t *testing.T) { sha, err := worker.DirOverallSHA256Hex(dsPath) requireNoErr(t, err) - w := worker.NewTestWorker(&worker.Config{DataDir: dataDir}) + w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir}) task := &queue.Task{ JobName: "job", ID: "t1", @@ -272,7 +273,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) { } requireNoErr(t, expMgr.WriteManifest(manifest)) - w := worker.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false}) + w := workertest.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false}) // Missing expected fields should fail. taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}} @@ -296,7 +297,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) { requireNoErr(t, os.MkdirAll(snapDir, 0750)) requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600)) - wSnap := worker.NewTestWorker(&worker.Config{ + wSnap := workertest.NewTestWorker(&worker.Config{ BasePath: base, DataDir: filepath.Join(base, "data"), ProvenanceBestEffort: false, @@ -335,7 +336,7 @@ func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) { requireNoErr(t, os.MkdirAll(snapDir, 0750)) requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600)) - w := worker.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true}) + w := workertest.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true}) task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}} if err := w.EnforceTaskProvenance(context.Background(), task); err != nil { t.Fatalf("expected best-effort to pass, got %v", err) @@ -360,7 +361,7 @@ func TestVerifySnapshot(t *testing.T) { sha, err := worker.DirOverallSHA256Hex(snapDir) requireNoErr(t, err) - w := worker.NewTestWorker(&worker.Config{DataDir: dataDir}) + w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir}) t.Run("Ok", func(t *testing.T) { task := &queue.Task{