refactor(worker): update worker tests and native bridge
**Worker Refactoring:**
- Update internal/worker/factory.go, worker.go, snapshot_store.go
- Update native_bridge.go and native_bridge_nocgo.go for native library integration

**Test Updates:**
- Update all worker unit tests for new interfaces
- Update chaos tests
- Update container/podman_test.go
- Add internal/workertest/worker.go for shared test utilities

**Documentation:**
- Update native/README.md
This commit is contained in:
parent
4b8df60e83
commit
fc2459977c
13 changed files with 264 additions and 99 deletions
|
|
@ -159,15 +159,15 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
}
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
runLoop: runLoop,
|
||||
runner: jobRunner,
|
||||
metrics: metricsObj,
|
||||
health: lifecycle.NewHealthMonitor(),
|
||||
resources: rm,
|
||||
jupyter: jupyterMgr,
|
||||
ID: cfg.WorkerID,
|
||||
Config: cfg,
|
||||
Logger: logger,
|
||||
RunLoop: runLoop,
|
||||
Runner: jobRunner,
|
||||
Metrics: metricsObj,
|
||||
Health: lifecycle.NewHealthMonitor(),
|
||||
Resources: rm,
|
||||
Jupyter: jupyterMgr,
|
||||
gpuDetectionInfo: gpuDetectionInfo,
|
||||
}
|
||||
|
||||
|
|
@ -200,23 +200,23 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
|
||||
// prePullImages pulls required container images in the background
|
||||
func (w *Worker) prePullImages() {
|
||||
if w.config.LocalMode {
|
||||
if w.Config.LocalMode {
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Info("starting image pre-pulling")
|
||||
w.Logger.Info("starting image pre-pulling")
|
||||
|
||||
// Pull worker image
|
||||
if w.config.PodmanImage != "" {
|
||||
w.pullImage(w.config.PodmanImage)
|
||||
if w.Config.PodmanImage != "" {
|
||||
w.pullImage(w.Config.PodmanImage)
|
||||
}
|
||||
|
||||
// Pull plugin images
|
||||
for name, cfg := range w.config.Plugins {
|
||||
for name, cfg := range w.Config.Plugins {
|
||||
if !cfg.Enabled || cfg.Image == "" {
|
||||
continue
|
||||
}
|
||||
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
||||
w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
||||
w.pullImage(cfg.Image)
|
||||
}
|
||||
}
|
||||
|
|
@ -228,8 +228,8 @@ func (w *Worker) pullImage(image string) {
|
|||
|
||||
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
||||
w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
||||
} else {
|
||||
w.logger.Info("image pulled successfully", "image", image)
|
||||
w.Logger.Info("image pulled successfully", "image", image)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,3 +55,8 @@ func (qi *QueueIndexNative) Close() {}
|
|||
func (qi *QueueIndexNative) AddTasks(tasks []*queue.Task) error {
|
||||
return errors.New("native queue index requires native_libs build tag")
|
||||
}
|
||||
|
||||
// DirOverallSHA256HexNative is the fallback compiled in when the native_libs
// build tag is absent. It never inspects root: the result is always an empty
// digest paired with an error explaining that native hashing is unavailable.
func DirOverallSHA256HexNative(root string) (string, error) {
	var digest string
	unavailable := errors.New("native hash requires native_libs build tag")
	return digest, unavailable
}
|
||||
|
|
|
|||
|
|
@ -33,3 +33,8 @@ func ScanArtifactsNative(runDir string) (*manifest.Artifacts, error) {
|
|||
// ExtractTarGzNative is the stub used when the package is built without CGO.
// Both arguments are ignored and an error stating that the native tar.gz
// extractor needs CGO is always returned.
func ExtractTarGzNative(archivePath, dstDir string) error {
	unavailable := errors.New("native tar.gz extractor requires CGO")
	return unavailable
}
|
||||
|
||||
// DirOverallSHA256HexNative is the fallback compiled in when CGO is not
// available. It never inspects root: the result is always an empty digest
// paired with an error explaining that native hashing requires CGO.
func DirOverallSHA256HexNative(root string) (string, error) {
	var digest string
	unavailable := errors.New("native hash requires CGO")
	return digest, unavailable
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
// SnapshotFetcher is an interface for fetching snapshots.
// It abstracts the backend a snapshot object is read from behind a single
// Get call, so callers depend only on this method set.
// NOTE(review): presumably satisfied by an object-store client (this file
// imports minio-go) — confirm against the implementations.
|
||||
type SnapshotFetcher interface {
|
||||
// Get returns a readable stream for the object addressed by bucket and key,
// or an error.
// NOTE(review): the caller is presumably responsible for closing the
// returned io.ReadCloser — confirm against the call sites.
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,87 +43,87 @@ func NewMLServer(cfg *Config) (*MLServer, error) {
|
|||
|
||||
// Worker represents an ML task worker with composed dependencies.
|
||||
type Worker struct {
|
||||
id string
|
||||
config *Config
|
||||
logger *logging.Logger
|
||||
ID string
|
||||
Config *Config
|
||||
Logger *logging.Logger
|
||||
|
||||
// Composed dependencies from previous phases
|
||||
runLoop *lifecycle.RunLoop
|
||||
runner *executor.JobRunner
|
||||
metrics *metrics.Metrics
|
||||
RunLoop *lifecycle.RunLoop
|
||||
Runner *executor.JobRunner
|
||||
Metrics *metrics.Metrics
|
||||
metricsSrv *http.Server
|
||||
health *lifecycle.HealthMonitor
|
||||
resources *resources.Manager
|
||||
Health *lifecycle.HealthMonitor
|
||||
Resources *resources.Manager
|
||||
|
||||
// GPU detection metadata for status output
|
||||
gpuDetectionInfo GPUDetectionInfo
|
||||
|
||||
// Legacy fields for backward compatibility during migration
|
||||
jupyter JupyterManager
|
||||
queueClient queue.Backend // Stored for prewarming access
|
||||
Jupyter JupyterManager
|
||||
QueueClient queue.Backend // Stored for prewarming access
|
||||
}
|
||||
|
||||
// Start begins the worker's main processing loop.
|
||||
func (w *Worker) Start() {
|
||||
w.logger.Info("worker starting",
|
||||
"worker_id", w.id,
|
||||
"max_concurrent", w.config.MaxWorkers)
|
||||
w.Logger.Info("worker starting",
|
||||
"worker_id", w.ID,
|
||||
"max_concurrent", w.Config.MaxWorkers)
|
||||
|
||||
w.health.RecordHeartbeat()
|
||||
w.runLoop.Start()
|
||||
w.Health.RecordHeartbeat()
|
||||
w.RunLoop.Start()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the worker immediately.
|
||||
func (w *Worker) Stop() {
|
||||
w.logger.Info("worker stopping", "worker_id", w.id)
|
||||
w.runLoop.Stop()
|
||||
w.Logger.Info("worker stopping", "worker_id", w.ID)
|
||||
w.RunLoop.Stop()
|
||||
|
||||
if w.metricsSrv != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||
w.logger.Warn("metrics server shutdown error", "error", err)
|
||||
w.Logger.Warn("metrics server shutdown error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Info("worker stopped", "worker_id", w.id)
|
||||
w.Logger.Info("worker stopped", "worker_id", w.ID)
|
||||
}
|
||||
|
||||
// Shutdown performs a graceful shutdown with timeout.
|
||||
func (w *Worker) Shutdown() error {
|
||||
w.logger.Info("starting graceful shutdown", "worker_id", w.id)
|
||||
w.Logger.Info("starting graceful shutdown", "worker_id", w.ID)
|
||||
|
||||
w.runLoop.Stop()
|
||||
w.RunLoop.Stop()
|
||||
|
||||
if w.metricsSrv != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||
w.logger.Warn("metrics server shutdown error", "error", err)
|
||||
w.Logger.Warn("metrics server shutdown error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Info("worker shut down gracefully", "worker_id", w.id)
|
||||
w.Logger.Info("worker shut down gracefully", "worker_id", w.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the worker is healthy.
|
||||
func (w *Worker) IsHealthy() bool {
|
||||
return w.health.IsHealthy(5 * time.Minute)
|
||||
return w.Health.IsHealthy(5 * time.Minute)
|
||||
}
|
||||
|
||||
// GetMetrics returns current worker metrics.
|
||||
func (w *Worker) GetMetrics() map[string]any {
|
||||
stats := w.metrics.GetStats()
|
||||
stats["worker_id"] = w.id
|
||||
stats["max_workers"] = w.config.MaxWorkers
|
||||
stats := w.Metrics.GetStats()
|
||||
stats["worker_id"] = w.ID
|
||||
stats["max_workers"] = w.Config.MaxWorkers
|
||||
stats["healthy"] = w.IsHealthy()
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetID returns the worker ID.
|
||||
func (w *Worker) GetID() string {
|
||||
return w.id
|
||||
return w.ID
|
||||
}
|
||||
|
||||
// SelectDependencyManifest re-exports the executor function for API helpers.
|
||||
|
|
@ -162,7 +162,7 @@ func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string
|
|||
// VerifyDatasetSpecs verifies dataset specifications for this task.
|
||||
// This is a test compatibility method that wraps the integrity package.
|
||||
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
||||
dataDir := w.config.DataDir
|
||||
dataDir := w.Config.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = "/tmp/data"
|
||||
}
|
||||
|
|
@ -179,16 +179,16 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er
|
|||
return fmt.Errorf("task is nil")
|
||||
}
|
||||
|
||||
basePath := w.config.BasePath
|
||||
basePath := w.Config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/tmp"
|
||||
}
|
||||
dataDir := w.config.DataDir
|
||||
dataDir := w.Config.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = filepath.Join(basePath, "data")
|
||||
}
|
||||
|
||||
bestEffort := w.config.ProvenanceBestEffort
|
||||
bestEffort := w.Config.ProvenanceBestEffort
|
||||
|
||||
// Get commit_id from metadata
|
||||
commitID := task.Metadata["commit_id"]
|
||||
|
|
@ -289,7 +289,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
|||
return nil // No snapshot to verify
|
||||
}
|
||||
|
||||
dataDir := w.config.DataDir
|
||||
dataDir := w.Config.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = "/tmp/data"
|
||||
}
|
||||
|
|
@ -324,7 +324,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
|||
// RunJupyterTask runs a Jupyter-related task.
|
||||
// It handles start, stop, remove, restore, and list_packages actions.
|
||||
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
||||
if w.jupyter == nil {
|
||||
if w.Jupyter == nil {
|
||||
return nil, fmt.Errorf("jupyter manager not configured")
|
||||
}
|
||||
|
||||
|
|
@ -350,7 +350,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
}
|
||||
|
||||
req := &jupyter.StartRequest{Name: name}
|
||||
service, err := w.jupyter.StartService(ctx, req)
|
||||
service, err := w.Jupyter.StartService(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
|
||||
}
|
||||
|
|
@ -366,7 +366,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||
}
|
||||
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
|
||||
if err := w.Jupyter.StopService(ctx, serviceID); err != nil {
|
||||
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
|
||||
|
|
@ -377,7 +377,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||
}
|
||||
purge := task.Metadata["jupyter_purge"] == "true"
|
||||
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
||||
if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
|
||||
|
|
@ -390,7 +390,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
||||
}
|
||||
serviceID, err := w.jupyter.RestoreWorkspace(ctx, name)
|
||||
serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
|
||||
}
|
||||
|
|
@ -408,7 +408,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
return nil, fmt.Errorf("missing jupyter_name in task metadata")
|
||||
}
|
||||
|
||||
packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName)
|
||||
packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list installed packages: %w", err)
|
||||
}
|
||||
|
|
@ -429,16 +429,16 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
|||
// Returns true if prewarming was performed, false if disabled or queue empty.
|
||||
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||
// Check if prewarming is enabled
|
||||
if !w.config.PrewarmEnabled {
|
||||
if !w.Config.PrewarmEnabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Get base path and data directory
|
||||
basePath := w.config.BasePath
|
||||
basePath := w.Config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/tmp"
|
||||
}
|
||||
dataDir := w.config.DataDir
|
||||
dataDir := w.Config.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = filepath.Join(basePath, "data")
|
||||
}
|
||||
|
|
@ -450,12 +450,12 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|||
}
|
||||
|
||||
// Try to get next task from queue client if available (peek, don't lease)
|
||||
if w.queueClient != nil {
|
||||
task, err := w.queueClient.PeekNextTask()
|
||||
if w.QueueClient != nil {
|
||||
task, err := w.QueueClient.PeekNextTask()
|
||||
if err != nil {
|
||||
// Queue empty - check if we have existing prewarm state
|
||||
// Return false but preserve any existing state (don't delete)
|
||||
state, _ := w.queueClient.GetWorkerPrewarmState(w.id)
|
||||
state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID)
|
||||
if state != nil {
|
||||
// We have existing state, return true to indicate prewarm is active
|
||||
return true, nil
|
||||
|
|
@ -489,17 +489,17 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|||
}
|
||||
|
||||
// Store prewarm state in queue backend
|
||||
if w.queueClient != nil {
|
||||
if w.QueueClient != nil {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
state := queue.PrewarmState{
|
||||
WorkerID: w.id,
|
||||
WorkerID: w.ID,
|
||||
TaskID: task.ID,
|
||||
SnapshotID: task.SnapshotID,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now,
|
||||
Phase: "staged",
|
||||
}
|
||||
_ = w.queueClient.SetWorkerPrewarmState(state)
|
||||
_ = w.QueueClient.SetWorkerPrewarmState(state)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
|
@ -507,7 +507,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|||
}
|
||||
|
||||
// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
|
||||
if w.runLoop != nil {
|
||||
if w.RunLoop != nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
|
@ -517,18 +517,18 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|||
// RunJob runs a job task.
|
||||
// It uses the JobRunner to execute the job and write the run manifest.
|
||||
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
|
||||
if w.runner == nil {
|
||||
if w.Runner == nil {
|
||||
return fmt.Errorf("job runner not configured")
|
||||
}
|
||||
|
||||
basePath := w.config.BasePath
|
||||
basePath := w.Config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/tmp"
|
||||
}
|
||||
|
||||
// Determine execution mode
|
||||
mode := executor.ModeAuto
|
||||
if w.config.LocalMode {
|
||||
if w.Config.LocalMode {
|
||||
mode = executor.ModeLocal
|
||||
}
|
||||
|
||||
|
|
@ -536,5 +536,5 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
|
|||
gpuEnv := interfaces.ExecutionEnv{}
|
||||
|
||||
// Run the job
|
||||
return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv)
|
||||
return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
|
||||
}
|
||||
|
|
|
|||
150
internal/workertest/worker.go
Normal file
150
internal/workertest/worker.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
// Package workertest provides test helpers for the worker package.
|
||||
// This package is only intended for use in tests and is separate from
|
||||
// production code to maintain clean separation of concerns.
|
||||
package workertest
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
||||
)
|
||||
|
||||
// SimpleManifestWriter is a basic ManifestWriter implementation for testing
|
||||
type SimpleManifestWriter struct{}
|
||||
|
||||
func (w *SimpleManifestWriter) Upsert(dir string, task *queue.Task, mutate func(*manifest.RunManifest)) {
|
||||
// Try to load existing manifest, or create new one
|
||||
m, err := manifest.LoadFromDir(dir)
|
||||
if err != nil {
|
||||
m = w.BuildInitial(task, "")
|
||||
}
|
||||
mutate(m)
|
||||
_ = m.WriteToDir(dir)
|
||||
}
|
||||
|
||||
func (w *SimpleManifestWriter) BuildInitial(task *queue.Task, podmanImage string) *manifest.RunManifest {
|
||||
m := manifest.NewRunManifest(
|
||||
"run-"+task.ID,
|
||||
task.ID,
|
||||
task.JobName,
|
||||
time.Now().UTC(),
|
||||
)
|
||||
m.CommitID = task.Metadata["commit_id"]
|
||||
m.DepsManifestName = task.Metadata["deps_manifest_name"]
|
||||
return m
|
||||
}
|
||||
|
||||
// NewTestWorker creates a minimal Worker for testing purposes.
|
||||
// It initializes only the fields needed for unit tests.
|
||||
func NewTestWorker(cfg *worker.Config) *worker.Worker {
|
||||
if cfg == nil {
|
||||
cfg = &worker.Config{}
|
||||
}
|
||||
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
metricsObj := &metrics.Metrics{}
|
||||
|
||||
// Create executors and runner for testing
|
||||
writer := &SimpleManifestWriter{}
|
||||
localExecutor := executor.NewLocalExecutor(logger, writer)
|
||||
containerExecutor := executor.NewContainerExecutor(
|
||||
logger,
|
||||
nil,
|
||||
executor.ContainerConfig{
|
||||
PodmanImage: cfg.PodmanImage,
|
||||
BasePath: cfg.BasePath,
|
||||
},
|
||||
)
|
||||
jobRunner := executor.NewJobRunner(
|
||||
localExecutor,
|
||||
containerExecutor,
|
||||
writer,
|
||||
logger,
|
||||
)
|
||||
|
||||
return &worker.Worker{
|
||||
ID: cfg.WorkerID,
|
||||
Config: cfg,
|
||||
Logger: logger,
|
||||
Metrics: metricsObj,
|
||||
Health: lifecycle.NewHealthMonitor(),
|
||||
Runner: jobRunner,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestWorkerWithQueue creates a test Worker with a queue client.
|
||||
func NewTestWorkerWithQueue(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
|
||||
w := NewTestWorker(cfg)
|
||||
w.QueueClient = queueClient
|
||||
return w
|
||||
}
|
||||
|
||||
// NewTestWorkerWithJupyter creates a test Worker with Jupyter manager.
|
||||
func NewTestWorkerWithJupyter(cfg *worker.Config, jupyterMgr worker.JupyterManager) *worker.Worker {
|
||||
w := NewTestWorker(cfg)
|
||||
w.Jupyter = jupyterMgr
|
||||
return w
|
||||
}
|
||||
|
||||
// NewTestWorkerWithRunner creates a test Worker with JobRunner initialized.
|
||||
// Note: This creates a minimal runner for testing purposes.
|
||||
func NewTestWorkerWithRunner(cfg *worker.Config) *worker.Worker {
|
||||
return NewTestWorker(cfg)
|
||||
}
|
||||
|
||||
// NewTestWorkerWithRunLoop creates a test Worker with RunLoop initialized.
|
||||
// Note: RunLoop requires proper queue client setup.
|
||||
func NewTestWorkerWithRunLoop(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
|
||||
return NewTestWorkerWithQueue(cfg, queueClient)
|
||||
}
|
||||
|
||||
// ResolveDatasets resolves dataset paths for a task.
|
||||
// This version matches the test expectations for backwards compatibility.
|
||||
// Priority: DatasetSpecs > Datasets > Args parsing
|
||||
func ResolveDatasets(task *queue.Task) []string {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Priority 1: DatasetSpecs
|
||||
if len(task.DatasetSpecs) > 0 {
|
||||
var paths []string
|
||||
for _, spec := range task.DatasetSpecs {
|
||||
paths = append(paths, spec.Name)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
// Priority 2: Datasets
|
||||
if len(task.Datasets) > 0 {
|
||||
return task.Datasets
|
||||
}
|
||||
|
||||
// Priority 3: Parse from Args
|
||||
if task.Args != "" {
|
||||
// Simple parsing: --datasets a,b,c or --datasets a b c
|
||||
args := task.Args
|
||||
if idx := strings.Index(args, "--datasets"); idx != -1 {
|
||||
after := args[idx+len("--datasets "):]
|
||||
after = strings.TrimSpace(after)
|
||||
// Split by comma or space
|
||||
if strings.Contains(after, ",") {
|
||||
return strings.Split(after, ",")
|
||||
}
|
||||
parts := strings.Fields(after)
|
||||
if len(parts) > 0 {
|
||||
return parts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -184,6 +184,6 @@ go test -tags native_libs ./tests/...
|
|||
- Rebuild: `make native-clean && make native-build`
|
||||
|
||||
**Performance regression:**
|
||||
- Verify `FETCHML_NATIVE_LIBS=1` is set
|
||||
- Verify code is built with `-tags native_libs`
|
||||
- Check benchmark: `go test -bench=BenchmarkQueue -v`
|
||||
- Profile with: `go test -bench=. -cpuprofile=cpu.prof`
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ func testDatabaseConnectionFailure(t *testing.T, db *storage.DB, _ *redis.Client
|
|||
// testRedisConnectionFailure tests system behavior when Redis fails
|
||||
func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) {
|
||||
// Add jobs to Redis queue
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
jobID := fmt.Sprintf("redis-chaos-job-%d", i)
|
||||
err := rdb.LPush(context.Background(), "ml:queue", jobID).Err()
|
||||
if err != nil {
|
||||
|
|
@ -188,7 +188,7 @@ func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client)
|
|||
})
|
||||
|
||||
// Wait for Redis to be available
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
err := newRdb.Ping(context.Background()).Err()
|
||||
if err == nil {
|
||||
break
|
||||
|
|
@ -218,7 +218,7 @@ func testHighConcurrencyStress(t *testing.T, db *storage.DB, rdb *redis.Client)
|
|||
start := time.Now()
|
||||
|
||||
// Launch many concurrent workers
|
||||
for worker := 0; worker < numWorkers; worker++ {
|
||||
for worker := range numWorkers {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
|
@ -313,7 +313,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
|||
numJobs := 50
|
||||
|
||||
// Create jobs with large payloads
|
||||
for i := 0; i < numJobs; i++ {
|
||||
for i := range numJobs {
|
||||
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
|
||||
|
||||
job := &storage.Job{
|
||||
|
|
@ -337,7 +337,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
|||
}
|
||||
|
||||
// Process jobs to test memory handling during operations
|
||||
for i := 0; i < numJobs; i++ {
|
||||
for i := range numJobs {
|
||||
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
|
||||
|
||||
// Update job status
|
||||
|
|
@ -360,7 +360,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
|||
func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||
// Simulate operations with artificial delays
|
||||
numJobs := 20
|
||||
for i := 0; i < numJobs; i++ {
|
||||
for i := range numJobs {
|
||||
jobID := fmt.Sprintf("latency-job-%d", i)
|
||||
|
||||
// Add artificial delay to simulate network latency
|
||||
|
|
@ -387,7 +387,7 @@ func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
|||
}
|
||||
|
||||
// Process jobs with latency simulation
|
||||
for i := 0; i < numJobs; i++ {
|
||||
for i := range numJobs {
|
||||
jobID := fmt.Sprintf("latency-job-%d", i)
|
||||
|
||||
time.Sleep(time.Millisecond * 8)
|
||||
|
|
@ -413,7 +413,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
|||
done := make(chan bool, numOperations)
|
||||
errors := make(chan error, numOperations)
|
||||
|
||||
for i := 0; i < numOperations; i++ {
|
||||
for i := range numOperations {
|
||||
go func(opID int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
|
|
@ -448,7 +448,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
|||
}
|
||||
|
||||
// Wait for all operations to complete
|
||||
for i := 0; i < numOperations; i++ {
|
||||
for range numOperations {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
|
@ -522,7 +522,7 @@ func setupChaosRedisIsolated(t *testing.T) *redis.Client {
|
|||
|
||||
func createTestJobs(t *testing.T, db *storage.DB, count int) []string {
|
||||
jobIDs := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
for i := range count {
|
||||
jobID := fmt.Sprintf("chaos-test-job-%d", i)
|
||||
jobIDs[i] = jobID
|
||||
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
cmd := container.BuildPodmanCommand(
|
||||
cmd := container.BuildPodmanCommandLegacy(
|
||||
context.Background(),
|
||||
cfg,
|
||||
"/workspace/train.py",
|
||||
|
|
@ -100,7 +100,7 @@ func TestBuildPodmanCommand_Overrides(t *testing.T) {
|
|||
CPUs: "8",
|
||||
}
|
||||
|
||||
cmd := container.BuildPodmanCommand(context.Background(), cfg, "script.py", "reqs.txt", nil)
|
||||
cmd := container.BuildPodmanCommandLegacy(context.Background(), cfg, "script.py", "reqs.txt", nil)
|
||||
|
||||
if contains(cmd.Args, "--device") {
|
||||
t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||
)
|
||||
|
||||
type fakeJupyterManager struct {
|
||||
|
|
@ -65,7 +66,7 @@ type jupyterPackagesOutput struct {
|
|||
}
|
||||
|
||||
func TestRunJupyterTaskStartSuccess(t *testing.T) {
|
||||
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||
startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
|
||||
if req.Name != "my-workspace" {
|
||||
return nil, errors.New("bad name")
|
||||
|
|
@ -102,7 +103,7 @@ func TestRunJupyterTaskStartSuccess(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRunJupyterTaskStopFailure(t *testing.T) {
|
||||
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
||||
stopFn: func(context.Context, string) error { return errors.New("stop failed") },
|
||||
removeFn: func(context.Context, string, bool) error { return nil },
|
||||
|
|
@ -123,7 +124,7 @@ func TestRunJupyterTaskStopFailure(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
|
||||
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
||||
stopFn: func(context.Context, string) error { return nil },
|
||||
removeFn: func(context.Context, string, bool) error { return nil },
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||
)
|
||||
|
||||
func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
|
||||
|
|
@ -75,7 +76,7 @@ func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
|
|||
MaxWorkers: 1,
|
||||
DatasetCacheTTL: 30 * time.Minute,
|
||||
}
|
||||
w := worker.NewTestWorkerWithQueue(cfg, tq)
|
||||
w := workertest.NewTestWorkerWithQueue(cfg, tq)
|
||||
|
||||
ok, err := w.PrewarmNextOnce(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -113,7 +114,7 @@ func TestPrewarmNextOnce_Disabled_NoOp(t *testing.T) {
|
|||
}
|
||||
|
||||
cfg := &worker.Config{WorkerID: "worker-1", BasePath: base, DataDir: dataDir, PrewarmEnabled: false}
|
||||
w := worker.NewTestWorkerWithQueue(cfg, tq)
|
||||
w := workertest.NewTestWorkerWithQueue(cfg, tq)
|
||||
|
||||
ok, err := w.PrewarmNextOnce(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -189,7 +190,7 @@ func TestPrewarmNextOnce_QueueEmpty_DoesNotDeleteState(t *testing.T) {
|
|||
MaxWorkers: 1,
|
||||
DatasetCacheTTL: 30 * time.Minute,
|
||||
}
|
||||
w := worker.NewTestWorkerWithQueue(cfg, tq)
|
||||
w := workertest.NewTestWorkerWithQueue(cfg, tq)
|
||||
|
||||
ok, err := w.PrewarmNextOnce(context.Background())
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||
)
|
||||
|
||||
func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
|
||||
|
|
@ -22,7 +23,7 @@ func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
|
|||
PodmanImage: "python:3.11",
|
||||
WorkerID: "worker-test",
|
||||
}
|
||||
w := worker.NewTestWorker(cfg)
|
||||
w := workertest.NewTestWorker(cfg)
|
||||
|
||||
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
|
||||
expMgr := experiment.NewManager(base)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||
)
|
||||
|
||||
func TestSelectDependencyManifestPriority(t *testing.T) {
|
||||
|
|
@ -99,7 +100,7 @@ func TestSelectDependencyManifestMissing(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestResolveDatasetsPrecedence(t *testing.T) {
|
||||
if got := worker.ResolveDatasets(nil); got != nil {
|
||||
if got := workertest.ResolveDatasets(nil); got != nil {
|
||||
t.Fatalf("expected nil for nil task")
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +110,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
|
|||
Datasets: []string{"ds-legacy"},
|
||||
Args: "--datasets ds-args",
|
||||
}
|
||||
got := worker.ResolveDatasets(task)
|
||||
got := workertest.ResolveDatasets(task)
|
||||
if len(got) != 1 || got[0] != "ds-spec" {
|
||||
t.Fatalf("expected dataset_specs to win, got %v", got)
|
||||
}
|
||||
|
|
@ -120,7 +121,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
|
|||
Datasets: []string{"ds-legacy"},
|
||||
Args: "--datasets ds-args",
|
||||
}
|
||||
got := worker.ResolveDatasets(task)
|
||||
got := workertest.ResolveDatasets(task)
|
||||
if len(got) != 1 || got[0] != "ds-legacy" {
|
||||
t.Fatalf("expected datasets to win over args, got %v", got)
|
||||
}
|
||||
|
|
@ -128,7 +129,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
|
|||
|
||||
t.Run("ArgsFallback", func(t *testing.T) {
|
||||
task := &queue.Task{Args: "--datasets a,b,c"}
|
||||
got := worker.ResolveDatasets(task)
|
||||
got := workertest.ResolveDatasets(task)
|
||||
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
|
||||
t.Fatalf("expected args datasets, got %v", got)
|
||||
}
|
||||
|
|
@ -234,7 +235,7 @@ func TestVerifyDatasetSpecs(t *testing.T) {
|
|||
sha, err := worker.DirOverallSHA256Hex(dsPath)
|
||||
requireNoErr(t, err)
|
||||
|
||||
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
|
||||
w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
|
||||
task := &queue.Task{
|
||||
JobName: "job",
|
||||
ID: "t1",
|
||||
|
|
@ -272,7 +273,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
|
|||
}
|
||||
requireNoErr(t, expMgr.WriteManifest(manifest))
|
||||
|
||||
w := worker.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
|
||||
w := workertest.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
|
||||
|
||||
// Missing expected fields should fail.
|
||||
taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
|
||||
|
|
@ -296,7 +297,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
|
|||
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
||||
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
||||
|
||||
wSnap := worker.NewTestWorker(&worker.Config{
|
||||
wSnap := workertest.NewTestWorker(&worker.Config{
|
||||
BasePath: base,
|
||||
DataDir: filepath.Join(base, "data"),
|
||||
ProvenanceBestEffort: false,
|
||||
|
|
@ -335,7 +336,7 @@ func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
|
|||
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
||||
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
||||
|
||||
w := worker.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
|
||||
w := workertest.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
|
||||
task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}}
|
||||
if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
|
||||
t.Fatalf("expected best-effort to pass, got %v", err)
|
||||
|
|
@ -360,7 +361,7 @@ func TestVerifySnapshot(t *testing.T) {
|
|||
sha, err := worker.DirOverallSHA256Hex(snapDir)
|
||||
requireNoErr(t, err)
|
||||
|
||||
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
|
||||
w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
|
||||
|
||||
t.Run("Ok", func(t *testing.T) {
|
||||
task := &queue.Task{
|
||||
|
|
|
|||
Loading…
Reference in a new issue