refactor(worker): update worker tests and native bridge

**Worker Refactoring:**
- Update internal/worker/factory.go, worker.go, snapshot_store.go
- Update native_bridge.go and native_bridge_nocgo.go for native library integration

**Test Updates:**
- Update all worker unit tests for new interfaces
- Update chaos tests
- Update container/podman_test.go
- Add internal/workertest/worker.go for shared test utilities

**Documentation:**
- Update native/README.md
This commit is contained in:
Jeremie Fraeys 2026-02-23 18:04:22 -05:00
parent 4b8df60e83
commit fc2459977c
No known key found for this signature in database
13 changed files with 264 additions and 99 deletions

View file

@ -159,15 +159,15 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
}
worker := &Worker{
id: cfg.WorkerID,
config: cfg,
logger: logger,
runLoop: runLoop,
runner: jobRunner,
metrics: metricsObj,
health: lifecycle.NewHealthMonitor(),
resources: rm,
jupyter: jupyterMgr,
ID: cfg.WorkerID,
Config: cfg,
Logger: logger,
RunLoop: runLoop,
Runner: jobRunner,
Metrics: metricsObj,
Health: lifecycle.NewHealthMonitor(),
Resources: rm,
Jupyter: jupyterMgr,
gpuDetectionInfo: gpuDetectionInfo,
}
@ -200,23 +200,23 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
// prePullImages pulls required container images in the background
func (w *Worker) prePullImages() {
if w.config.LocalMode {
if w.Config.LocalMode {
return
}
w.logger.Info("starting image pre-pulling")
w.Logger.Info("starting image pre-pulling")
// Pull worker image
if w.config.PodmanImage != "" {
w.pullImage(w.config.PodmanImage)
if w.Config.PodmanImage != "" {
w.pullImage(w.Config.PodmanImage)
}
// Pull plugin images
for name, cfg := range w.config.Plugins {
for name, cfg := range w.Config.Plugins {
if !cfg.Enabled || cfg.Image == "" {
continue
}
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
w.pullImage(cfg.Image)
}
}
@ -228,8 +228,8 @@ func (w *Worker) pullImage(image string) {
cmd := exec.CommandContext(ctx, "podman", "pull", image)
if output, err := cmd.CombinedOutput(); err != nil {
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
} else {
w.logger.Info("image pulled successfully", "image", image)
w.Logger.Info("image pulled successfully", "image", image)
}
}

View file

@ -55,3 +55,8 @@ func (qi *QueueIndexNative) Close() {}
func (qi *QueueIndexNative) AddTasks(tasks []*queue.Task) error {
return errors.New("native queue index requires native_libs build tag")
}
// DirOverallSHA256HexNative is disabled without native_libs build tag.
func DirOverallSHA256HexNative(root string) (string, error) {
return "", errors.New("native hash requires native_libs build tag")
}

View file

@ -33,3 +33,8 @@ func ScanArtifactsNative(runDir string) (*manifest.Artifacts, error) {
func ExtractTarGzNative(archivePath, dstDir string) error {
return errors.New("native tar.gz extractor requires CGO")
}
// DirOverallSHA256HexNative is disabled without CGO.
func DirOverallSHA256HexNative(root string) (string, error) {
return "", errors.New("native hash requires CGO")
}

View file

@ -19,6 +19,7 @@ import (
"github.com/minio/minio-go/v7/pkg/credentials"
)
// SnapshotFetcher is an interface for fetching snapshots
type SnapshotFetcher interface {
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
}

View file

@ -43,87 +43,87 @@ func NewMLServer(cfg *Config) (*MLServer, error) {
// Worker represents an ML task worker with composed dependencies.
type Worker struct {
id string
config *Config
logger *logging.Logger
ID string
Config *Config
Logger *logging.Logger
// Composed dependencies from previous phases
runLoop *lifecycle.RunLoop
runner *executor.JobRunner
metrics *metrics.Metrics
RunLoop *lifecycle.RunLoop
Runner *executor.JobRunner
Metrics *metrics.Metrics
metricsSrv *http.Server
health *lifecycle.HealthMonitor
resources *resources.Manager
Health *lifecycle.HealthMonitor
Resources *resources.Manager
// GPU detection metadata for status output
gpuDetectionInfo GPUDetectionInfo
// Legacy fields for backward compatibility during migration
jupyter JupyterManager
queueClient queue.Backend // Stored for prewarming access
Jupyter JupyterManager
QueueClient queue.Backend // Stored for prewarming access
}
// Start begins the worker's main processing loop.
func (w *Worker) Start() {
w.logger.Info("worker starting",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers)
w.Logger.Info("worker starting",
"worker_id", w.ID,
"max_concurrent", w.Config.MaxWorkers)
w.health.RecordHeartbeat()
w.runLoop.Start()
w.Health.RecordHeartbeat()
w.RunLoop.Start()
}
// Stop gracefully shuts down the worker immediately.
func (w *Worker) Stop() {
w.logger.Info("worker stopping", "worker_id", w.id)
w.runLoop.Stop()
w.Logger.Info("worker stopping", "worker_id", w.ID)
w.RunLoop.Stop()
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics server shutdown error", "error", err)
w.Logger.Warn("metrics server shutdown error", "error", err)
}
}
w.logger.Info("worker stopped", "worker_id", w.id)
w.Logger.Info("worker stopped", "worker_id", w.ID)
}
// Shutdown performs a graceful shutdown with timeout.
func (w *Worker) Shutdown() error {
w.logger.Info("starting graceful shutdown", "worker_id", w.id)
w.Logger.Info("starting graceful shutdown", "worker_id", w.ID)
w.runLoop.Stop()
w.RunLoop.Stop()
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics server shutdown error", "error", err)
w.Logger.Warn("metrics server shutdown error", "error", err)
}
}
w.logger.Info("worker shut down gracefully", "worker_id", w.id)
w.Logger.Info("worker shut down gracefully", "worker_id", w.ID)
return nil
}
// IsHealthy returns true if the worker is healthy.
func (w *Worker) IsHealthy() bool {
return w.health.IsHealthy(5 * time.Minute)
return w.Health.IsHealthy(5 * time.Minute)
}
// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
stats := w.metrics.GetStats()
stats["worker_id"] = w.id
stats["max_workers"] = w.config.MaxWorkers
stats := w.Metrics.GetStats()
stats["worker_id"] = w.ID
stats["max_workers"] = w.Config.MaxWorkers
stats["healthy"] = w.IsHealthy()
return stats
}
// GetID returns the worker ID.
func (w *Worker) GetID() string {
return w.id
return w.ID
}
// SelectDependencyManifest re-exports the executor function for API helpers.
@ -162,7 +162,7 @@ func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string
// VerifyDatasetSpecs verifies dataset specifications for this task.
// This is a test compatibility method that wraps the integrity package.
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
dataDir := w.config.DataDir
dataDir := w.Config.DataDir
if dataDir == "" {
dataDir = "/tmp/data"
}
@ -179,16 +179,16 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er
return fmt.Errorf("task is nil")
}
basePath := w.config.BasePath
basePath := w.Config.BasePath
if basePath == "" {
basePath = "/tmp"
}
dataDir := w.config.DataDir
dataDir := w.Config.DataDir
if dataDir == "" {
dataDir = filepath.Join(basePath, "data")
}
bestEffort := w.config.ProvenanceBestEffort
bestEffort := w.Config.ProvenanceBestEffort
// Get commit_id from metadata
commitID := task.Metadata["commit_id"]
@ -289,7 +289,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
return nil // No snapshot to verify
}
dataDir := w.config.DataDir
dataDir := w.Config.DataDir
if dataDir == "" {
dataDir = "/tmp/data"
}
@ -324,7 +324,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
// RunJupyterTask runs a Jupyter-related task.
// It handles start, stop, remove, restore, and list_packages actions.
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
if w.jupyter == nil {
if w.Jupyter == nil {
return nil, fmt.Errorf("jupyter manager not configured")
}
@ -350,7 +350,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
}
req := &jupyter.StartRequest{Name: name}
service, err := w.jupyter.StartService(ctx, req)
service, err := w.Jupyter.StartService(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
}
@ -366,7 +366,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
}
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
if err := w.Jupyter.StopService(ctx, serviceID); err != nil {
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
}
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
@ -377,7 +377,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
}
purge := task.Metadata["jupyter_purge"] == "true"
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil {
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
}
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
@ -390,7 +390,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
if name == "" {
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
}
serviceID, err := w.jupyter.RestoreWorkspace(ctx, name)
serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name)
if err != nil {
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
}
@ -408,7 +408,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
return nil, fmt.Errorf("missing jupyter_name in task metadata")
}
packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName)
packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName)
if err != nil {
return nil, fmt.Errorf("failed to list installed packages: %w", err)
}
@ -429,16 +429,16 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
// Returns true if prewarming was performed, false if disabled or queue empty.
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
// Check if prewarming is enabled
if !w.config.PrewarmEnabled {
if !w.Config.PrewarmEnabled {
return false, nil
}
// Get base path and data directory
basePath := w.config.BasePath
basePath := w.Config.BasePath
if basePath == "" {
basePath = "/tmp"
}
dataDir := w.config.DataDir
dataDir := w.Config.DataDir
if dataDir == "" {
dataDir = filepath.Join(basePath, "data")
}
@ -450,12 +450,12 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
}
// Try to get next task from queue client if available (peek, don't lease)
if w.queueClient != nil {
task, err := w.queueClient.PeekNextTask()
if w.QueueClient != nil {
task, err := w.QueueClient.PeekNextTask()
if err != nil {
// Queue empty - check if we have existing prewarm state
// Return false but preserve any existing state (don't delete)
state, _ := w.queueClient.GetWorkerPrewarmState(w.id)
state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID)
if state != nil {
// We have existing state, return true to indicate prewarm is active
return true, nil
@ -489,17 +489,17 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
}
// Store prewarm state in queue backend
if w.queueClient != nil {
if w.QueueClient != nil {
now := time.Now().UTC().Format(time.RFC3339)
state := queue.PrewarmState{
WorkerID: w.id,
WorkerID: w.ID,
TaskID: task.ID,
SnapshotID: task.SnapshotID,
StartedAt: now,
UpdatedAt: now,
Phase: "staged",
}
_ = w.queueClient.SetWorkerPrewarmState(state)
_ = w.QueueClient.SetWorkerPrewarmState(state)
}
return true, nil
@ -507,7 +507,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
}
// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
if w.runLoop != nil {
if w.RunLoop != nil {
return true, nil
}
@ -517,18 +517,18 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
// RunJob runs a job task.
// It uses the JobRunner to execute the job and write the run manifest.
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
if w.runner == nil {
if w.Runner == nil {
return fmt.Errorf("job runner not configured")
}
basePath := w.config.BasePath
basePath := w.Config.BasePath
if basePath == "" {
basePath = "/tmp"
}
// Determine execution mode
mode := executor.ModeAuto
if w.config.LocalMode {
if w.Config.LocalMode {
mode = executor.ModeLocal
}
@ -536,5 +536,5 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
gpuEnv := interfaces.ExecutionEnv{}
// Run the job
return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv)
return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
}

View file

@ -0,0 +1,150 @@
// Package workertest provides test helpers for the worker package.
// This package is only intended for use in tests and is separate from
// production code to maintain clean separation of concerns.
package workertest
import (
"log/slog"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)
// SimpleManifestWriter is a minimal ManifestWriter implementation for tests.
// It reads an existing manifest from disk when one is present, otherwise
// builds a fresh one, then applies the caller's mutation and writes back.
type SimpleManifestWriter struct{}

// Upsert loads the manifest in dir (building an initial one from task if
// none can be loaded), applies mutate, and writes the result back to dir.
// The write error is deliberately ignored: tests assert on content directly.
func (w *SimpleManifestWriter) Upsert(dir string, task *queue.Task, mutate func(*manifest.RunManifest)) {
	existing, loadErr := manifest.LoadFromDir(dir)
	if loadErr != nil {
		existing = w.BuildInitial(task, "")
	}
	mutate(existing)
	_ = existing.WriteToDir(dir) // best-effort write; tests verify content elsewhere
}

// BuildInitial constructs a new RunManifest for task. The podmanImage
// argument is accepted for interface compatibility but is not recorded.
func (w *SimpleManifestWriter) BuildInitial(task *queue.Task, podmanImage string) *manifest.RunManifest {
	started := time.Now().UTC()
	m := manifest.NewRunManifest("run-"+task.ID, task.ID, task.JobName, started)
	m.CommitID = task.Metadata["commit_id"]
	m.DepsManifestName = task.Metadata["deps_manifest_name"]
	return m
}
// NewTestWorker creates a minimal Worker for testing purposes. Only the
// dependencies exercised by unit tests are wired up (logger, metrics,
// health monitor, and a job runner); everything else is left zero-valued.
func NewTestWorker(cfg *worker.Config) *worker.Worker {
	if cfg == nil {
		cfg = &worker.Config{}
	}

	log := logging.NewLogger(slog.LevelInfo, false)
	writer := &SimpleManifestWriter{}

	// Assemble a job runner backed by both a local and a container executor.
	local := executor.NewLocalExecutor(log, writer)
	containerCfg := executor.ContainerConfig{
		PodmanImage: cfg.PodmanImage,
		BasePath:    cfg.BasePath,
	}
	containerExec := executor.NewContainerExecutor(log, nil, containerCfg)
	runner := executor.NewJobRunner(local, containerExec, writer, log)

	return &worker.Worker{
		ID:      cfg.WorkerID,
		Config:  cfg,
		Logger:  log,
		Metrics: &metrics.Metrics{},
		Health:  lifecycle.NewHealthMonitor(),
		Runner:  runner,
	}
}
// NewTestWorkerWithQueue creates a test Worker whose QueueClient is set to
// the provided queue backend.
func NewTestWorkerWithQueue(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
	tw := NewTestWorker(cfg)
	tw.QueueClient = queueClient
	return tw
}
// NewTestWorkerWithJupyter creates a test Worker whose Jupyter manager is
// set to the provided implementation.
func NewTestWorkerWithJupyter(cfg *worker.Config, jupyterMgr worker.JupyterManager) *worker.Worker {
	tw := NewTestWorker(cfg)
	tw.Jupyter = jupyterMgr
	return tw
}
// NewTestWorkerWithRunner creates a test Worker with a JobRunner initialized.
// NewTestWorker already wires a minimal runner, so this is a plain alias
// kept for call-site readability in tests.
func NewTestWorkerWithRunner(cfg *worker.Config) *worker.Worker {
	return NewTestWorker(cfg)
}
// NewTestWorkerWithRunLoop creates a test Worker intended for run-loop
// scenarios. Note: it currently only attaches the queue client; a real
// RunLoop requires proper queue client setup by the caller.
func NewTestWorkerWithRunLoop(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
	return NewTestWorkerWithQueue(cfg, queueClient)
}
// ResolveDatasets resolves dataset names for a task.
// This version matches the test expectations for backwards compatibility.
//
// Priority order:
//  1. task.DatasetSpecs (each spec contributes its Name)
//  2. task.Datasets (legacy list, returned as-is)
//  3. a "--datasets" flag parsed out of task.Args, accepting either a
//     comma-separated value ("--datasets a,b,c") or space-separated values
//     ("--datasets a b c")
//
// Returns nil for a nil task or when no datasets can be resolved.
func ResolveDatasets(task *queue.Task) []string {
	if task == nil {
		return nil
	}

	// Priority 1: explicit dataset specs win over everything else.
	if len(task.DatasetSpecs) > 0 {
		paths := make([]string, 0, len(task.DatasetSpecs))
		for _, spec := range task.DatasetSpecs {
			paths = append(paths, spec.Name)
		}
		return paths
	}

	// Priority 2: the legacy Datasets list.
	if len(task.Datasets) > 0 {
		return task.Datasets
	}

	// Priority 3: parse the flag out of the raw Args string.
	if task.Args != "" {
		args := task.Args
		if idx := strings.Index(args, "--datasets"); idx != -1 {
			// Skip past the flag and its separator character. The original
			// code sliced unconditionally at idx+len("--datasets ") which
			// panics when Args ends exactly at "--datasets" (flag with no
			// value); bound-check the start index to fix that.
			start := idx + len("--datasets ")
			if start > len(args) {
				return nil // flag present but no value follows
			}
			after := strings.TrimSpace(args[start:])
			// Split by comma or by whitespace.
			if strings.Contains(after, ",") {
				return strings.Split(after, ",")
			}
			if parts := strings.Fields(after); len(parts) > 0 {
				return parts
			}
		}
	}
	return nil
}

View file

@ -184,6 +184,6 @@ go test -tags native_libs ./tests/...
- Rebuild: `make native-clean && make native-build`
**Performance regression:**
- Verify `FETCHML_NATIVE_LIBS=1` is set
- Verify code is built with `-tags native_libs`
- Check benchmark: `go test -bench=BenchmarkQueue -v`
- Profile with: `go test -bench=. -cpuprofile=cpu.prof`

View file

@ -160,7 +160,7 @@ func testDatabaseConnectionFailure(t *testing.T, db *storage.DB, _ *redis.Client
// testRedisConnectionFailure tests system behavior when Redis fails
func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) {
// Add jobs to Redis queue
for i := 0; i < 10; i++ {
for i := range 10 {
jobID := fmt.Sprintf("redis-chaos-job-%d", i)
err := rdb.LPush(context.Background(), "ml:queue", jobID).Err()
if err != nil {
@ -188,7 +188,7 @@ func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client)
})
// Wait for Redis to be available
for i := 0; i < 10; i++ {
for range 10 {
err := newRdb.Ping(context.Background()).Err()
if err == nil {
break
@ -218,7 +218,7 @@ func testHighConcurrencyStress(t *testing.T, db *storage.DB, rdb *redis.Client)
start := time.Now()
// Launch many concurrent workers
for worker := 0; worker < numWorkers; worker++ {
for worker := range numWorkers {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
@ -313,7 +313,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
numJobs := 50
// Create jobs with large payloads
for i := 0; i < numJobs; i++ {
for i := range numJobs {
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
job := &storage.Job{
@ -337,7 +337,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
}
// Process jobs to test memory handling during operations
for i := 0; i < numJobs; i++ {
for i := range numJobs {
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
// Update job status
@ -360,7 +360,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
// Simulate operations with artificial delays
numJobs := 20
for i := 0; i < numJobs; i++ {
for i := range numJobs {
jobID := fmt.Sprintf("latency-job-%d", i)
// Add artificial delay to simulate network latency
@ -387,7 +387,7 @@ func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
}
// Process jobs with latency simulation
for i := 0; i < numJobs; i++ {
for i := range numJobs {
jobID := fmt.Sprintf("latency-job-%d", i)
time.Sleep(time.Millisecond * 8)
@ -413,7 +413,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
done := make(chan bool, numOperations)
errors := make(chan error, numOperations)
for i := 0; i < numOperations; i++ {
for i := range numOperations {
go func(opID int) {
defer func() { done <- true }()
@ -448,7 +448,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
}
// Wait for all operations to complete
for i := 0; i < numOperations; i++ {
for range numOperations {
<-done
}
close(errors)
@ -522,7 +522,7 @@ func setupChaosRedisIsolated(t *testing.T) *redis.Client {
func createTestJobs(t *testing.T, db *storage.DB, count int) []string {
jobIDs := make([]string, count)
for i := 0; i < count; i++ {
for i := range count {
jobID := fmt.Sprintf("chaos-test-job-%d", i)
jobIDs[i] = jobID

View file

@ -56,7 +56,7 @@ func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
},
}
cmd := container.BuildPodmanCommand(
cmd := container.BuildPodmanCommandLegacy(
context.Background(),
cfg,
"/workspace/train.py",
@ -100,7 +100,7 @@ func TestBuildPodmanCommand_Overrides(t *testing.T) {
CPUs: "8",
}
cmd := container.BuildPodmanCommand(context.Background(), cfg, "script.py", "reqs.txt", nil)
cmd := container.BuildPodmanCommandLegacy(context.Background(), cfg, "script.py", "reqs.txt", nil)
if contains(cmd.Args, "--device") {
t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)

View file

@ -9,6 +9,7 @@ import (
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/workertest"
)
type fakeJupyterManager struct {
@ -65,7 +66,7 @@ type jupyterPackagesOutput struct {
}
func TestRunJupyterTaskStartSuccess(t *testing.T) {
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
if req.Name != "my-workspace" {
return nil, errors.New("bad name")
@ -102,7 +103,7 @@ func TestRunJupyterTaskStartSuccess(t *testing.T) {
}
func TestRunJupyterTaskStopFailure(t *testing.T) {
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
stopFn: func(context.Context, string) error { return errors.New("stop failed") },
removeFn: func(context.Context, string, bool) error { return nil },
@ -123,7 +124,7 @@ func TestRunJupyterTaskStopFailure(t *testing.T) {
}
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
stopFn: func(context.Context, string) error { return nil },
removeFn: func(context.Context, string, bool) error { return nil },

View file

@ -10,6 +10,7 @@ import (
"github.com/alicebob/miniredis/v2"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/workertest"
)
func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
@ -75,7 +76,7 @@ func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
MaxWorkers: 1,
DatasetCacheTTL: 30 * time.Minute,
}
w := worker.NewTestWorkerWithQueue(cfg, tq)
w := workertest.NewTestWorkerWithQueue(cfg, tq)
ok, err := w.PrewarmNextOnce(context.Background())
if err != nil {
@ -113,7 +114,7 @@ func TestPrewarmNextOnce_Disabled_NoOp(t *testing.T) {
}
cfg := &worker.Config{WorkerID: "worker-1", BasePath: base, DataDir: dataDir, PrewarmEnabled: false}
w := worker.NewTestWorkerWithQueue(cfg, tq)
w := workertest.NewTestWorkerWithQueue(cfg, tq)
ok, err := w.PrewarmNextOnce(context.Background())
if err != nil {
@ -189,7 +190,7 @@ func TestPrewarmNextOnce_QueueEmpty_DoesNotDeleteState(t *testing.T) {
MaxWorkers: 1,
DatasetCacheTTL: 30 * time.Minute,
}
w := worker.NewTestWorkerWithQueue(cfg, tq)
w := workertest.NewTestWorkerWithQueue(cfg, tq)
ok, err := w.PrewarmNextOnce(context.Background())
if err != nil {

View file

@ -11,6 +11,7 @@ import (
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/workertest"
)
func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
@ -22,7 +23,7 @@ func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
PodmanImage: "python:3.11",
WorkerID: "worker-test",
}
w := worker.NewTestWorker(cfg)
w := workertest.NewTestWorker(cfg)
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
expMgr := experiment.NewManager(base)

View file

@ -10,6 +10,7 @@ import (
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/workertest"
)
func TestSelectDependencyManifestPriority(t *testing.T) {
@ -99,7 +100,7 @@ func TestSelectDependencyManifestMissing(t *testing.T) {
}
func TestResolveDatasetsPrecedence(t *testing.T) {
if got := worker.ResolveDatasets(nil); got != nil {
if got := workertest.ResolveDatasets(nil); got != nil {
t.Fatalf("expected nil for nil task")
}
@ -109,7 +110,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
Datasets: []string{"ds-legacy"},
Args: "--datasets ds-args",
}
got := worker.ResolveDatasets(task)
got := workertest.ResolveDatasets(task)
if len(got) != 1 || got[0] != "ds-spec" {
t.Fatalf("expected dataset_specs to win, got %v", got)
}
@ -120,7 +121,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
Datasets: []string{"ds-legacy"},
Args: "--datasets ds-args",
}
got := worker.ResolveDatasets(task)
got := workertest.ResolveDatasets(task)
if len(got) != 1 || got[0] != "ds-legacy" {
t.Fatalf("expected datasets to win over args, got %v", got)
}
@ -128,7 +129,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
t.Run("ArgsFallback", func(t *testing.T) {
task := &queue.Task{Args: "--datasets a,b,c"}
got := worker.ResolveDatasets(task)
got := workertest.ResolveDatasets(task)
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
t.Fatalf("expected args datasets, got %v", got)
}
@ -234,7 +235,7 @@ func TestVerifyDatasetSpecs(t *testing.T) {
sha, err := worker.DirOverallSHA256Hex(dsPath)
requireNoErr(t, err)
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
task := &queue.Task{
JobName: "job",
ID: "t1",
@ -272,7 +273,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
}
requireNoErr(t, expMgr.WriteManifest(manifest))
w := worker.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
w := workertest.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
// Missing expected fields should fail.
taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
@ -296,7 +297,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
requireNoErr(t, os.MkdirAll(snapDir, 0750))
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
wSnap := worker.NewTestWorker(&worker.Config{
wSnap := workertest.NewTestWorker(&worker.Config{
BasePath: base,
DataDir: filepath.Join(base, "data"),
ProvenanceBestEffort: false,
@ -335,7 +336,7 @@ func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
requireNoErr(t, os.MkdirAll(snapDir, 0750))
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
w := worker.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
w := workertest.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}}
if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
t.Fatalf("expected best-effort to pass, got %v", err)
@ -360,7 +361,7 @@ func TestVerifySnapshot(t *testing.T) {
sha, err := worker.DirOverallSHA256Hex(snapDir)
requireNoErr(t, err)
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
t.Run("Ok", func(t *testing.T) {
task := &queue.Task{