refactor(worker): update worker tests and native bridge

**Worker Refactoring:**
- Update internal/worker/factory.go, worker.go, snapshot_store.go: export Worker fields (ID, Config, Logger, RunLoop, Runner, Metrics, Health, Resources, Jupyter, QueueClient) so test helpers can construct workers from outside the package, and document SnapshotFetcher
- Update native_bridge.go and native_bridge_nocgo.go: add DirOverallSHA256HexNative stubs for builds without native library support

**Test Updates:**
- Update all worker unit tests to build workers through the new workertest helpers
- Modernize chaos tests to Go 1.22 integer range loops
- Update container/podman_test.go to call BuildPodmanCommandLegacy
- Add internal/workertest/worker.go for shared test utilities

**Documentation:**
- Update native/README.md troubleshooting to reference the `-tags native_libs` build flag
Jeremie Fraeys 2026-02-23 18:04:22 -05:00
parent 4b8df60e83
commit fc2459977c
13 changed files with 264 additions and 99 deletions
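The heart of the change is that Worker's fields are now exported and test construction moves to a dedicated internal/workertest package. A minimal sketch of a test written against the new helpers — the worker ID here is illustrative, not taken from this commit:

```go
package worker_test

import (
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker"
	"github.com/jfraeys/fetch_ml/internal/workertest"
)

// Builds a throwaway worker through the shared helper and checks the
// exported accessor; "w-example" is a placeholder ID.
func TestWorkerID(t *testing.T) {
	w := workertest.NewTestWorker(&worker.Config{WorkerID: "w-example"})
	if got := w.GetID(); got != "w-example" {
		t.Fatalf("GetID() = %q, want %q", got, "w-example")
	}
}
```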


@@ -159,15 +159,15 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
 	}
 	worker := &Worker{
-		id:               cfg.WorkerID,
-		config:           cfg,
-		logger:           logger,
-		runLoop:          runLoop,
-		runner:           jobRunner,
-		metrics:          metricsObj,
-		health:           lifecycle.NewHealthMonitor(),
-		resources:        rm,
-		jupyter:          jupyterMgr,
+		ID:               cfg.WorkerID,
+		Config:           cfg,
+		Logger:           logger,
+		RunLoop:          runLoop,
+		Runner:           jobRunner,
+		Metrics:          metricsObj,
+		Health:           lifecycle.NewHealthMonitor(),
+		Resources:        rm,
+		Jupyter:          jupyterMgr,
 		gpuDetectionInfo: gpuDetectionInfo,
 	}
@@ -200,23 +200,23 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
 // prePullImages pulls required container images in the background
 func (w *Worker) prePullImages() {
-	if w.config.LocalMode {
+	if w.Config.LocalMode {
 		return
 	}
-	w.logger.Info("starting image pre-pulling")
+	w.Logger.Info("starting image pre-pulling")
 
 	// Pull worker image
-	if w.config.PodmanImage != "" {
-		w.pullImage(w.config.PodmanImage)
+	if w.Config.PodmanImage != "" {
+		w.pullImage(w.Config.PodmanImage)
 	}
 
 	// Pull plugin images
-	for name, cfg := range w.config.Plugins {
+	for name, cfg := range w.Config.Plugins {
 		if !cfg.Enabled || cfg.Image == "" {
 			continue
 		}
-		w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
+		w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
 		w.pullImage(cfg.Image)
 	}
 }
@@ -228,8 +228,8 @@ func (w *Worker) pullImage(image string) {
 	cmd := exec.CommandContext(ctx, "podman", "pull", image)
 	if output, err := cmd.CombinedOutput(); err != nil {
-		w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
+		w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
 	} else {
-		w.logger.Info("image pulled successfully", "image", image)
+		w.Logger.Info("image pulled successfully", "image", image)
 	}
 }
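The pre-pull path above shells out to `podman pull` under a context deadline. A self-contained sketch of that pattern, assuming only the standard library and a locally installed podman (the image name and the five-minute timeout are illustrative):

```go
package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

// pullImage runs `podman pull` with a bounded timeout, mirroring the
// CombinedOutput error handling in the diff above.
func pullImage(image string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()
	out, err := exec.CommandContext(ctx, "podman", "pull", image).CombinedOutput()
	if err != nil {
		return fmt.Errorf("pull %s: %w (output: %s)", image, err, out)
	}
	return nil
}

func main() {
	fmt.Println(pullImage("docker.io/library/python:3.11"))
}
```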


@@ -55,3 +55,8 @@ func (qi *QueueIndexNative) Close() {}
 func (qi *QueueIndexNative) AddTasks(tasks []*queue.Task) error {
 	return errors.New("native queue index requires native_libs build tag")
 }
+
+// DirOverallSHA256HexNative is disabled without native_libs build tag.
+func DirOverallSHA256HexNative(root string) (string, error) {
+	return "", errors.New("native hash requires native_libs build tag")
+}


@@ -33,3 +33,8 @@ func ScanArtifactsNative(runDir string) (*manifest.Artifacts, error) {
 func ExtractTarGzNative(archivePath, dstDir string) error {
 	return errors.New("native tar.gz extractor requires CGO")
 }
+
+// DirOverallSHA256HexNative is disabled without CGO.
+func DirOverallSHA256HexNative(root string) (string, error) {
+	return "", errors.New("native hash requires CGO")
+}
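Both stubs gate a native directory hash behind build flags, and the test diffs below call a pure-Go `worker.DirOverallSHA256Hex` for the same job. This repo's implementation isn't shown in this commit, but a deterministic directory digest typically has the shape of the following sketch — the path-plus-contents framing is an assumption, not this project's exact scheme:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"sort"
)

// dirOverallSHA256Hex walks root, sorts the file list for a reproducible
// order, then hashes each file's relative path and contents into one digest.
func dirOverallSHA256Hex(root string) (string, error) {
	var files []string
	err := filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !d.IsDir() {
			files = append(files, p)
		}
		return nil
	})
	if err != nil {
		return "", err
	}
	sort.Strings(files) // fixed order so the digest is deterministic

	h := sha256.New()
	for _, p := range files {
		rel, _ := filepath.Rel(root, p)
		io.WriteString(h, rel)
		f, err := os.Open(p)
		if err != nil {
			return "", err
		}
		if _, err := io.Copy(h, f); err != nil {
			f.Close()
			return "", err
		}
		f.Close()
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}

func main() {
	sum, err := dirOverallSHA256Hex(".")
	fmt.Println(sum, err)
}
```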


@@ -19,6 +19,7 @@ import (
 	"github.com/minio/minio-go/v7/pkg/credentials"
 )
 
+// SnapshotFetcher is an interface for fetching snapshots
 type SnapshotFetcher interface {
 	Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
 }
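With fetching behind a one-method interface, snapshot code can be exercised without a real MinIO endpoint. A hypothetical in-memory fake satisfying the same shape:

```go
package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
)

// fakeFetcher is a hypothetical in-memory stand-in for SnapshotFetcher,
// keyed by "bucket/key"; it is not part of this commit.
type fakeFetcher struct {
	objects map[string][]byte
}

func (f *fakeFetcher) Get(_ context.Context, bucket, key string) (io.ReadCloser, error) {
	b, ok := f.objects[bucket+"/"+key]
	if !ok {
		return nil, fmt.Errorf("object %s/%s not found", bucket, key)
	}
	return io.NopCloser(bytes.NewReader(b)), nil
}

func main() {
	f := &fakeFetcher{objects: map[string][]byte{"snapshots/snap1": []byte("payload")}}
	rc, err := f.Get(context.Background(), "snapshots", "snap1")
	if err != nil {
		panic(err)
	}
	defer rc.Close()
	data, _ := io.ReadAll(rc)
	fmt.Printf("%s\n", data)
}
```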


@@ -43,87 +43,87 @@ func NewMLServer(cfg *Config) (*MLServer, error) {
 // Worker represents an ML task worker with composed dependencies.
 type Worker struct {
-	id     string
-	config *Config
-	logger *logging.Logger
+	ID     string
+	Config *Config
+	Logger *logging.Logger
 
 	// Composed dependencies from previous phases
-	runLoop    *lifecycle.RunLoop
-	runner     *executor.JobRunner
-	metrics    *metrics.Metrics
+	RunLoop    *lifecycle.RunLoop
+	Runner     *executor.JobRunner
+	Metrics    *metrics.Metrics
 	metricsSrv *http.Server
-	health     *lifecycle.HealthMonitor
-	resources  *resources.Manager
+	Health     *lifecycle.HealthMonitor
+	Resources  *resources.Manager
 
 	// GPU detection metadata for status output
 	gpuDetectionInfo GPUDetectionInfo
 
 	// Legacy fields for backward compatibility during migration
-	jupyter     JupyterManager
-	queueClient queue.Backend // Stored for prewarming access
+	Jupyter     JupyterManager
+	QueueClient queue.Backend // Stored for prewarming access
 }
 
 // Start begins the worker's main processing loop.
 func (w *Worker) Start() {
-	w.logger.Info("worker starting",
-		"worker_id", w.id,
-		"max_concurrent", w.config.MaxWorkers)
-	w.health.RecordHeartbeat()
-	w.runLoop.Start()
+	w.Logger.Info("worker starting",
+		"worker_id", w.ID,
+		"max_concurrent", w.Config.MaxWorkers)
+	w.Health.RecordHeartbeat()
+	w.RunLoop.Start()
 }
 
 // Stop gracefully shuts down the worker immediately.
 func (w *Worker) Stop() {
-	w.logger.Info("worker stopping", "worker_id", w.id)
-	w.runLoop.Stop()
+	w.Logger.Info("worker stopping", "worker_id", w.ID)
+	w.RunLoop.Stop()
 	if w.metricsSrv != nil {
 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 		defer cancel()
 		if err := w.metricsSrv.Shutdown(ctx); err != nil {
-			w.logger.Warn("metrics server shutdown error", "error", err)
+			w.Logger.Warn("metrics server shutdown error", "error", err)
 		}
 	}
-	w.logger.Info("worker stopped", "worker_id", w.id)
+	w.Logger.Info("worker stopped", "worker_id", w.ID)
 }
 
 // Shutdown performs a graceful shutdown with timeout.
 func (w *Worker) Shutdown() error {
-	w.logger.Info("starting graceful shutdown", "worker_id", w.id)
-	w.runLoop.Stop()
+	w.Logger.Info("starting graceful shutdown", "worker_id", w.ID)
+	w.RunLoop.Stop()
 	if w.metricsSrv != nil {
 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 		defer cancel()
 		if err := w.metricsSrv.Shutdown(ctx); err != nil {
-			w.logger.Warn("metrics server shutdown error", "error", err)
+			w.Logger.Warn("metrics server shutdown error", "error", err)
 		}
 	}
-	w.logger.Info("worker shut down gracefully", "worker_id", w.id)
+	w.Logger.Info("worker shut down gracefully", "worker_id", w.ID)
 	return nil
 }
 
 // IsHealthy returns true if the worker is healthy.
 func (w *Worker) IsHealthy() bool {
-	return w.health.IsHealthy(5 * time.Minute)
+	return w.Health.IsHealthy(5 * time.Minute)
 }
 
 // GetMetrics returns current worker metrics.
 func (w *Worker) GetMetrics() map[string]any {
-	stats := w.metrics.GetStats()
-	stats["worker_id"] = w.id
-	stats["max_workers"] = w.config.MaxWorkers
+	stats := w.Metrics.GetStats()
+	stats["worker_id"] = w.ID
+	stats["max_workers"] = w.Config.MaxWorkers
 	stats["healthy"] = w.IsHealthy()
 	return stats
 }
 
 // GetID returns the worker ID.
 func (w *Worker) GetID() string {
-	return w.id
+	return w.ID
 }
 
 // SelectDependencyManifest re-exports the executor function for API helpers.
@@ -162,7 +162,7 @@ func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string
 // VerifyDatasetSpecs verifies dataset specifications for this task.
 // This is a test compatibility method that wraps the integrity package.
 func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
-	dataDir := w.config.DataDir
+	dataDir := w.Config.DataDir
 	if dataDir == "" {
 		dataDir = "/tmp/data"
 	}
@@ -179,16 +179,16 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er
 		return fmt.Errorf("task is nil")
 	}
 
-	basePath := w.config.BasePath
+	basePath := w.Config.BasePath
 	if basePath == "" {
 		basePath = "/tmp"
 	}
-	dataDir := w.config.DataDir
+	dataDir := w.Config.DataDir
 	if dataDir == "" {
 		dataDir = filepath.Join(basePath, "data")
 	}
-	bestEffort := w.config.ProvenanceBestEffort
+	bestEffort := w.Config.ProvenanceBestEffort
 
 	// Get commit_id from metadata
 	commitID := task.Metadata["commit_id"]
@@ -289,7 +289,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
 		return nil // No snapshot to verify
 	}
 
-	dataDir := w.config.DataDir
+	dataDir := w.Config.DataDir
 	if dataDir == "" {
 		dataDir = "/tmp/data"
 	}
@@ -324,7 +324,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
 // RunJupyterTask runs a Jupyter-related task.
 // It handles start, stop, remove, restore, and list_packages actions.
 func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
-	if w.jupyter == nil {
+	if w.Jupyter == nil {
 		return nil, fmt.Errorf("jupyter manager not configured")
 	}
@@ -350,7 +350,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
 		}
 
 		req := &jupyter.StartRequest{Name: name}
-		service, err := w.jupyter.StartService(ctx, req)
+		service, err := w.Jupyter.StartService(ctx, req)
 		if err != nil {
 			return nil, fmt.Errorf("failed to start jupyter service: %w", err)
 		}
@@ -366,7 +366,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
 		if serviceID == "" {
 			return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
 		}
-		if err := w.jupyter.StopService(ctx, serviceID); err != nil {
+		if err := w.Jupyter.StopService(ctx, serviceID); err != nil {
 			return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
 		}
 		return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
@@ -377,7 +377,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
 			return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
 		}
 		purge := task.Metadata["jupyter_purge"] == "true"
-		if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
+		if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil {
 			return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
 		}
 		return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
@@ -390,7 +390,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
 		if name == "" {
 			return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
 		}
-		serviceID, err := w.jupyter.RestoreWorkspace(ctx, name)
+		serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name)
 		if err != nil {
 			return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
 		}
@@ -408,7 +408,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
 			return nil, fmt.Errorf("missing jupyter_name in task metadata")
 		}
 
-		packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName)
+		packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName)
 		if err != nil {
 			return nil, fmt.Errorf("failed to list installed packages: %w", err)
 		}
@@ -429,16 +429,16 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
 // Returns true if prewarming was performed, false if disabled or queue empty.
 func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
 	// Check if prewarming is enabled
-	if !w.config.PrewarmEnabled {
+	if !w.Config.PrewarmEnabled {
 		return false, nil
 	}
 
 	// Get base path and data directory
-	basePath := w.config.BasePath
+	basePath := w.Config.BasePath
 	if basePath == "" {
 		basePath = "/tmp"
 	}
-	dataDir := w.config.DataDir
+	dataDir := w.Config.DataDir
 	if dataDir == "" {
 		dataDir = filepath.Join(basePath, "data")
 	}
@@ -450,12 +450,12 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
 	}
 
 	// Try to get next task from queue client if available (peek, don't lease)
-	if w.queueClient != nil {
-		task, err := w.queueClient.PeekNextTask()
+	if w.QueueClient != nil {
+		task, err := w.QueueClient.PeekNextTask()
 		if err != nil {
 			// Queue empty - check if we have existing prewarm state
 			// Return false but preserve any existing state (don't delete)
-			state, _ := w.queueClient.GetWorkerPrewarmState(w.id)
+			state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID)
 			if state != nil {
 				// We have existing state, return true to indicate prewarm is active
 				return true, nil
@@ -489,17 +489,17 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
 		}
 
 		// Store prewarm state in queue backend
-		if w.queueClient != nil {
+		if w.QueueClient != nil {
 			now := time.Now().UTC().Format(time.RFC3339)
 			state := queue.PrewarmState{
-				WorkerID:   w.id,
+				WorkerID:   w.ID,
 				TaskID:     task.ID,
 				SnapshotID: task.SnapshotID,
 				StartedAt:  now,
 				UpdatedAt:  now,
 				Phase:      "staged",
 			}
-			_ = w.queueClient.SetWorkerPrewarmState(state)
+			_ = w.QueueClient.SetWorkerPrewarmState(state)
 		}
 
 		return true, nil
@@ -507,7 +507,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
 	}
 
 	// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
-	if w.runLoop != nil {
+	if w.RunLoop != nil {
 		return true, nil
 	}
@@ -517,18 +517,18 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
 // RunJob runs a job task.
 // It uses the JobRunner to execute the job and write the run manifest.
 func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
-	if w.runner == nil {
+	if w.Runner == nil {
 		return fmt.Errorf("job runner not configured")
 	}
 
-	basePath := w.config.BasePath
+	basePath := w.Config.BasePath
 	if basePath == "" {
 		basePath = "/tmp"
 	}
 
 	// Determine execution mode
 	mode := executor.ModeAuto
-	if w.config.LocalMode {
+	if w.Config.LocalMode {
 		mode = executor.ModeLocal
 	}
@@ -536,5 +536,5 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
 	gpuEnv := interfaces.ExecutionEnv{}
 
 	// Run the job
-	return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv)
+	return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
 }
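With Start, Stop, and Shutdown exported as above, embedding the worker in a process is straightforward. A hypothetical wiring sketch; the Config values NewWorker actually requires are not visible in this diff, so the near-empty literal is a placeholder:

```go
package main

import (
	"log"
	"os"
	"os/signal"
	"syscall"

	"github.com/jfraeys/fetch_ml/internal/worker"
)

// Runs the worker until SIGTERM/Ctrl-C, then shuts it down gracefully.
func main() {
	w, err := worker.NewWorker(&worker.Config{WorkerID: "w-local"}, "")
	if err != nil {
		log.Fatalf("new worker: %v", err)
	}

	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGTERM, os.Interrupt)

	go w.Start()
	<-sig

	if err := w.Shutdown(); err != nil {
		log.Printf("shutdown error: %v", err)
	}
}
```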


@@ -0,0 +1,150 @@
+// Package workertest provides test helpers for the worker package.
+// This package is only intended for use in tests and is separate from
+// production code to maintain clean separation of concerns.
+package workertest
+
+import (
+	"log/slog"
+	"strings"
+	"time"
+
+	"github.com/jfraeys/fetch_ml/internal/logging"
+	"github.com/jfraeys/fetch_ml/internal/manifest"
+	"github.com/jfraeys/fetch_ml/internal/metrics"
+	"github.com/jfraeys/fetch_ml/internal/queue"
+	"github.com/jfraeys/fetch_ml/internal/worker"
+	"github.com/jfraeys/fetch_ml/internal/worker/executor"
+	"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
+)
+
+// SimpleManifestWriter is a basic ManifestWriter implementation for testing
+type SimpleManifestWriter struct{}
+
+func (w *SimpleManifestWriter) Upsert(dir string, task *queue.Task, mutate func(*manifest.RunManifest)) {
+	// Try to load existing manifest, or create new one
+	m, err := manifest.LoadFromDir(dir)
+	if err != nil {
+		m = w.BuildInitial(task, "")
+	}
+	mutate(m)
+	_ = m.WriteToDir(dir)
+}
+
+func (w *SimpleManifestWriter) BuildInitial(task *queue.Task, podmanImage string) *manifest.RunManifest {
+	m := manifest.NewRunManifest(
+		"run-"+task.ID,
+		task.ID,
+		task.JobName,
+		time.Now().UTC(),
+	)
+	m.CommitID = task.Metadata["commit_id"]
+	m.DepsManifestName = task.Metadata["deps_manifest_name"]
+	return m
+}
+
+// NewTestWorker creates a minimal Worker for testing purposes.
+// It initializes only the fields needed for unit tests.
+func NewTestWorker(cfg *worker.Config) *worker.Worker {
+	if cfg == nil {
+		cfg = &worker.Config{}
+	}
+	logger := logging.NewLogger(slog.LevelInfo, false)
+	metricsObj := &metrics.Metrics{}
+
+	// Create executors and runner for testing
+	writer := &SimpleManifestWriter{}
+	localExecutor := executor.NewLocalExecutor(logger, writer)
+	containerExecutor := executor.NewContainerExecutor(
+		logger,
+		nil,
+		executor.ContainerConfig{
+			PodmanImage: cfg.PodmanImage,
+			BasePath:    cfg.BasePath,
+		},
+	)
+	jobRunner := executor.NewJobRunner(
+		localExecutor,
+		containerExecutor,
+		writer,
+		logger,
+	)
+
+	return &worker.Worker{
+		ID:      cfg.WorkerID,
+		Config:  cfg,
+		Logger:  logger,
+		Metrics: metricsObj,
+		Health:  lifecycle.NewHealthMonitor(),
+		Runner:  jobRunner,
+	}
+}
+
+// NewTestWorkerWithQueue creates a test Worker with a queue client.
+func NewTestWorkerWithQueue(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
+	w := NewTestWorker(cfg)
+	w.QueueClient = queueClient
+	return w
+}
+
+// NewTestWorkerWithJupyter creates a test Worker with Jupyter manager.
+func NewTestWorkerWithJupyter(cfg *worker.Config, jupyterMgr worker.JupyterManager) *worker.Worker {
+	w := NewTestWorker(cfg)
+	w.Jupyter = jupyterMgr
+	return w
+}
+
+// NewTestWorkerWithRunner creates a test Worker with JobRunner initialized.
+// Note: This creates a minimal runner for testing purposes.
+func NewTestWorkerWithRunner(cfg *worker.Config) *worker.Worker {
+	return NewTestWorker(cfg)
+}
+
+// NewTestWorkerWithRunLoop creates a test Worker with RunLoop initialized.
+// Note: RunLoop requires proper queue client setup.
+func NewTestWorkerWithRunLoop(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
+	return NewTestWorkerWithQueue(cfg, queueClient)
+}
+
+// ResolveDatasets resolves dataset paths for a task.
+// This version matches the test expectations for backwards compatibility.
+// Priority: DatasetSpecs > Datasets > Args parsing
+func ResolveDatasets(task *queue.Task) []string {
+	if task == nil {
+		return nil
+	}
+
+	// Priority 1: DatasetSpecs
+	if len(task.DatasetSpecs) > 0 {
+		var paths []string
+		for _, spec := range task.DatasetSpecs {
+			paths = append(paths, spec.Name)
+		}
+		return paths
+	}
+
+	// Priority 2: Datasets
+	if len(task.Datasets) > 0 {
+		return task.Datasets
+	}
+
+	// Priority 3: Parse from Args
+	if task.Args != "" {
+		// Simple parsing: --datasets a,b,c or --datasets a b c
+		args := task.Args
+		if idx := strings.Index(args, "--datasets"); idx != -1 {
+			after := args[idx+len("--datasets "):]
+			after = strings.TrimSpace(after)
+			// Split by comma or space
+			if strings.Contains(after, ",") {
+				return strings.Split(after, ",")
+			}
+			parts := strings.Fields(after)
+			if len(parts) > 0 {
+				return parts
+			}
+		}
+	}
+
+	return nil
+}
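ResolveDatasets implements a three-tier precedence (DatasetSpecs, then Datasets, then `--datasets` in Args), which the test diff near the end of this commit exercises. A compact illustration against this module's internal packages, using only task fields the tests confirm exist:

```go
package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/workertest"
)

func main() {
	// Datasets beats Args parsing when both are present.
	task := &queue.Task{
		Datasets: []string{"ds-legacy"},
		Args:     "--datasets a,b,c",
	}
	fmt.Println(workertest.ResolveDatasets(task)) // [ds-legacy]

	// With only Args set, the --datasets flag is parsed.
	fmt.Println(workertest.ResolveDatasets(&queue.Task{Args: "--datasets a,b,c"})) // [a b c]
}
```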


@@ -184,6 +184,6 @@ go test -tags native_libs ./tests/...
 - Rebuild: `make native-clean && make native-build`
 
 **Performance regression:**
-- Verify `FETCHML_NATIVE_LIBS=1` is set
+- Verify code is built with `-tags native_libs`
 - Check benchmark: `go test -bench=BenchmarkQueue -v`
 - Profile with: `go test -bench=. -cpuprofile=cpu.prof`
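One practical follow-on, not a documented step in this README: if the benchmark still regresses, run the tagged and untagged builds back-to-back — e.g. `go test -tags native_libs -bench=BenchmarkQueue ./...` versus the same command without the tag — and compare the two sets of numbers to confirm the native path is actually the one being measured.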


@@ -160,7 +160,7 @@ func testDatabaseConnectionFailure(t *testing.T, db *storage.DB, _ *redis.Client
 // testRedisConnectionFailure tests system behavior when Redis fails
 func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) {
 	// Add jobs to Redis queue
-	for i := 0; i < 10; i++ {
+	for i := range 10 {
 		jobID := fmt.Sprintf("redis-chaos-job-%d", i)
 		err := rdb.LPush(context.Background(), "ml:queue", jobID).Err()
 		if err != nil {
@@ -188,7 +188,7 @@ func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client)
 	})
 
 	// Wait for Redis to be available
-	for i := 0; i < 10; i++ {
+	for range 10 {
 		err := newRdb.Ping(context.Background()).Err()
 		if err == nil {
 			break
@@ -218,7 +218,7 @@ func testHighConcurrencyStress(t *testing.T, db *storage.DB, rdb *redis.Client)
 	start := time.Now()
 
 	// Launch many concurrent workers
-	for worker := 0; worker < numWorkers; worker++ {
+	for worker := range numWorkers {
 		wg.Add(1)
 		go func(workerID int) {
 			defer wg.Done()
@@ -313,7 +313,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
 	numJobs := 50
 
 	// Create jobs with large payloads
-	for i := 0; i < numJobs; i++ {
+	for i := range numJobs {
 		jobID := fmt.Sprintf("memory-pressure-job-%d", i)
 		job := &storage.Job{
@@ -337,7 +337,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
 	}
 
 	// Process jobs to test memory handling during operations
-	for i := 0; i < numJobs; i++ {
+	for i := range numJobs {
 		jobID := fmt.Sprintf("memory-pressure-job-%d", i)
 
 		// Update job status
@@ -360,7 +360,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
 func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
 	// Simulate operations with artificial delays
 	numJobs := 20
 
-	for i := 0; i < numJobs; i++ {
+	for i := range numJobs {
 		jobID := fmt.Sprintf("latency-job-%d", i)
 
 		// Add artificial delay to simulate network latency
@@ -387,7 +387,7 @@ func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
 	}
 
 	// Process jobs with latency simulation
-	for i := 0; i < numJobs; i++ {
+	for i := range numJobs {
 		jobID := fmt.Sprintf("latency-job-%d", i)
 		time.Sleep(time.Millisecond * 8)
@@ -413,7 +413,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
 	done := make(chan bool, numOperations)
 	errors := make(chan error, numOperations)
 
-	for i := 0; i < numOperations; i++ {
+	for i := range numOperations {
 		go func(opID int) {
 			defer func() { done <- true }()
@@ -448,7 +448,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
 	}
 
 	// Wait for all operations to complete
-	for i := 0; i < numOperations; i++ {
+	for range numOperations {
 		<-done
 	}
 	close(errors)
@@ -522,7 +522,7 @@ func setupChaosRedisIsolated(t *testing.T) *redis.Client {
 func createTestJobs(t *testing.T, db *storage.DB, count int) []string {
 	jobIDs := make([]string, count)
 
-	for i := 0; i < count; i++ {
+	for i := range count {
 		jobID := fmt.Sprintf("chaos-test-job-%d", i)
 		jobIDs[i] = jobID

@@ -56,7 +56,7 @@ func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
 		},
 	}
 
-	cmd := container.BuildPodmanCommand(
+	cmd := container.BuildPodmanCommandLegacy(
 		context.Background(),
 		cfg,
 		"/workspace/train.py",
@@ -100,7 +100,7 @@ func TestBuildPodmanCommand_Overrides(t *testing.T) {
 		CPUs: "8",
 	}
 
-	cmd := container.BuildPodmanCommand(context.Background(), cfg, "script.py", "reqs.txt", nil)
+	cmd := container.BuildPodmanCommandLegacy(context.Background(), cfg, "script.py", "reqs.txt", nil)
 
 	if contains(cmd.Args, "--device") {
 		t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)


@@ -9,6 +9,7 @@ import (
 	"github.com/jfraeys/fetch_ml/internal/jupyter"
 	"github.com/jfraeys/fetch_ml/internal/queue"
 	"github.com/jfraeys/fetch_ml/internal/worker"
+	"github.com/jfraeys/fetch_ml/internal/workertest"
 )
 
 type fakeJupyterManager struct {
@@ -65,7 +66,7 @@ type jupyterPackagesOutput struct {
 }
 
 func TestRunJupyterTaskStartSuccess(t *testing.T) {
-	w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
+	w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
 		startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
 			if req.Name != "my-workspace" {
 				return nil, errors.New("bad name")
@@ -102,7 +103,7 @@ func TestRunJupyterTaskStartSuccess(t *testing.T) {
 }
 
 func TestRunJupyterTaskStopFailure(t *testing.T) {
-	w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
+	w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
 		startFn:  func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
 		stopFn:   func(context.Context, string) error { return errors.New("stop failed") },
 		removeFn: func(context.Context, string, bool) error { return nil },
@@ -123,7 +124,7 @@ func TestRunJupyterTaskStopFailure(t *testing.T) {
 }
 
 func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
-	w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
+	w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
 		startFn:  func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
 		stopFn:   func(context.Context, string) error { return nil },
 		removeFn: func(context.Context, string, bool) error { return nil },


@@ -10,6 +10,7 @@ import (
 	"github.com/alicebob/miniredis/v2"
 	"github.com/jfraeys/fetch_ml/internal/queue"
 	"github.com/jfraeys/fetch_ml/internal/worker"
+	"github.com/jfraeys/fetch_ml/internal/workertest"
 )
 
 func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
@@ -75,7 +76,7 @@ func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
 		MaxWorkers:      1,
 		DatasetCacheTTL: 30 * time.Minute,
 	}
-	w := worker.NewTestWorkerWithQueue(cfg, tq)
+	w := workertest.NewTestWorkerWithQueue(cfg, tq)
 
 	ok, err := w.PrewarmNextOnce(context.Background())
 	if err != nil {
@@ -113,7 +114,7 @@ func TestPrewarmNextOnce_Disabled_NoOp(t *testing.T) {
 	}
 
 	cfg := &worker.Config{WorkerID: "worker-1", BasePath: base, DataDir: dataDir, PrewarmEnabled: false}
-	w := worker.NewTestWorkerWithQueue(cfg, tq)
+	w := workertest.NewTestWorkerWithQueue(cfg, tq)
 
 	ok, err := w.PrewarmNextOnce(context.Background())
 	if err != nil {
@@ -189,7 +190,7 @@ func TestPrewarmNextOnce_QueueEmpty_DoesNotDeleteState(t *testing.T) {
 		MaxWorkers:      1,
 		DatasetCacheTTL: 30 * time.Minute,
 	}
-	w := worker.NewTestWorkerWithQueue(cfg, tq)
+	w := workertest.NewTestWorkerWithQueue(cfg, tq)
 
 	ok, err := w.PrewarmNextOnce(context.Background())
 	if err != nil {


@@ -11,6 +11,7 @@ import (
 	"github.com/jfraeys/fetch_ml/internal/manifest"
 	"github.com/jfraeys/fetch_ml/internal/queue"
 	"github.com/jfraeys/fetch_ml/internal/worker"
+	"github.com/jfraeys/fetch_ml/internal/workertest"
 )
 
 func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
@@ -22,7 +23,7 @@ func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
 		PodmanImage: "python:3.11",
 		WorkerID:    "worker-test",
 	}
-	w := worker.NewTestWorker(cfg)
+	w := workertest.NewTestWorker(cfg)
 
 	commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
 	expMgr := experiment.NewManager(base)


@@ -10,6 +10,7 @@ import (
 	"github.com/jfraeys/fetch_ml/internal/experiment"
 	"github.com/jfraeys/fetch_ml/internal/queue"
 	"github.com/jfraeys/fetch_ml/internal/worker"
+	"github.com/jfraeys/fetch_ml/internal/workertest"
 )
 
 func TestSelectDependencyManifestPriority(t *testing.T) {
@@ -99,7 +100,7 @@ func TestSelectDependencyManifestMissing(t *testing.T) {
 }
 
 func TestResolveDatasetsPrecedence(t *testing.T) {
-	if got := worker.ResolveDatasets(nil); got != nil {
+	if got := workertest.ResolveDatasets(nil); got != nil {
 		t.Fatalf("expected nil for nil task")
 	}
@@ -109,7 +110,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
 		Datasets: []string{"ds-legacy"},
 		Args:     "--datasets ds-args",
 	}
-	got := worker.ResolveDatasets(task)
+	got := workertest.ResolveDatasets(task)
 	if len(got) != 1 || got[0] != "ds-spec" {
 		t.Fatalf("expected dataset_specs to win, got %v", got)
 	}
@@ -120,7 +121,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
 		Datasets: []string{"ds-legacy"},
 		Args:     "--datasets ds-args",
 	}
-	got := worker.ResolveDatasets(task)
+	got := workertest.ResolveDatasets(task)
 	if len(got) != 1 || got[0] != "ds-legacy" {
 		t.Fatalf("expected datasets to win over args, got %v", got)
 	}
@@ -128,7 +129,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
 	t.Run("ArgsFallback", func(t *testing.T) {
 		task := &queue.Task{Args: "--datasets a,b,c"}
-		got := worker.ResolveDatasets(task)
+		got := workertest.ResolveDatasets(task)
 		if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
 			t.Fatalf("expected args datasets, got %v", got)
 		}
@@ -234,7 +235,7 @@ func TestVerifyDatasetSpecs(t *testing.T) {
 	sha, err := worker.DirOverallSHA256Hex(dsPath)
 	requireNoErr(t, err)
 
-	w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
+	w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
 	task := &queue.Task{
 		JobName: "job",
 		ID:      "t1",
@@ -272,7 +273,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
 	}
 	requireNoErr(t, expMgr.WriteManifest(manifest))
 
-	w := worker.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
+	w := workertest.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
 
 	// Missing expected fields should fail.
 	taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
@@ -296,7 +297,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
 	requireNoErr(t, os.MkdirAll(snapDir, 0750))
 	requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
 
-	wSnap := worker.NewTestWorker(&worker.Config{
+	wSnap := workertest.NewTestWorker(&worker.Config{
 		BasePath:             base,
 		DataDir:              filepath.Join(base, "data"),
 		ProvenanceBestEffort: false,
@@ -335,7 +336,7 @@ func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
 	requireNoErr(t, os.MkdirAll(snapDir, 0750))
 	requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
 
-	w := worker.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
+	w := workertest.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
 	task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}}
 	if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
 		t.Fatalf("expected best-effort to pass, got %v", err)
@@ -360,7 +361,7 @@ func TestVerifySnapshot(t *testing.T) {
 	sha, err := worker.DirOverallSHA256Hex(snapDir)
 	requireNoErr(t, err)
 
-	w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
+	w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
 
 	t.Run("Ok", func(t *testing.T) {
 		task := &queue.Task{