refactor(worker): update worker tests and native bridge
**Worker Refactoring:** - Update internal/worker/factory.go, worker.go, snapshot_store.go - Update native_bridge.go and native_bridge_nocgo.go for native library integration **Test Updates:** - Update all worker unit tests for new interfaces - Update chaos tests - Update container/podman_test.go - Add internal/workertest/worker.go for shared test utilities **Documentation:** - Update native/README.md
This commit is contained in:
parent
4b8df60e83
commit
fc2459977c
13 changed files with 264 additions and 99 deletions
|
|
@ -159,15 +159,15 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
worker := &Worker{
|
worker := &Worker{
|
||||||
id: cfg.WorkerID,
|
ID: cfg.WorkerID,
|
||||||
config: cfg,
|
Config: cfg,
|
||||||
logger: logger,
|
Logger: logger,
|
||||||
runLoop: runLoop,
|
RunLoop: runLoop,
|
||||||
runner: jobRunner,
|
Runner: jobRunner,
|
||||||
metrics: metricsObj,
|
Metrics: metricsObj,
|
||||||
health: lifecycle.NewHealthMonitor(),
|
Health: lifecycle.NewHealthMonitor(),
|
||||||
resources: rm,
|
Resources: rm,
|
||||||
jupyter: jupyterMgr,
|
Jupyter: jupyterMgr,
|
||||||
gpuDetectionInfo: gpuDetectionInfo,
|
gpuDetectionInfo: gpuDetectionInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -200,23 +200,23 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||||
|
|
||||||
// prePullImages pulls required container images in the background
|
// prePullImages pulls required container images in the background
|
||||||
func (w *Worker) prePullImages() {
|
func (w *Worker) prePullImages() {
|
||||||
if w.config.LocalMode {
|
if w.Config.LocalMode {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.logger.Info("starting image pre-pulling")
|
w.Logger.Info("starting image pre-pulling")
|
||||||
|
|
||||||
// Pull worker image
|
// Pull worker image
|
||||||
if w.config.PodmanImage != "" {
|
if w.Config.PodmanImage != "" {
|
||||||
w.pullImage(w.config.PodmanImage)
|
w.pullImage(w.Config.PodmanImage)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pull plugin images
|
// Pull plugin images
|
||||||
for name, cfg := range w.config.Plugins {
|
for name, cfg := range w.Config.Plugins {
|
||||||
if !cfg.Enabled || cfg.Image == "" {
|
if !cfg.Enabled || cfg.Image == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
|
||||||
w.pullImage(cfg.Image)
|
w.pullImage(cfg.Image)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -228,8 +228,8 @@ func (w *Worker) pullImage(image string) {
|
||||||
|
|
||||||
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
cmd := exec.CommandContext(ctx, "podman", "pull", image)
|
||||||
if output, err := cmd.CombinedOutput(); err != nil {
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
|
||||||
} else {
|
} else {
|
||||||
w.logger.Info("image pulled successfully", "image", image)
|
w.Logger.Info("image pulled successfully", "image", image)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -55,3 +55,8 @@ func (qi *QueueIndexNative) Close() {}
|
||||||
func (qi *QueueIndexNative) AddTasks(tasks []*queue.Task) error {
|
func (qi *QueueIndexNative) AddTasks(tasks []*queue.Task) error {
|
||||||
return errors.New("native queue index requires native_libs build tag")
|
return errors.New("native queue index requires native_libs build tag")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DirOverallSHA256HexNative is disabled without native_libs build tag.
|
||||||
|
func DirOverallSHA256HexNative(root string) (string, error) {
|
||||||
|
return "", errors.New("native hash requires native_libs build tag")
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,3 +33,8 @@ func ScanArtifactsNative(runDir string) (*manifest.Artifacts, error) {
|
||||||
func ExtractTarGzNative(archivePath, dstDir string) error {
|
func ExtractTarGzNative(archivePath, dstDir string) error {
|
||||||
return errors.New("native tar.gz extractor requires CGO")
|
return errors.New("native tar.gz extractor requires CGO")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DirOverallSHA256HexNative is disabled without CGO.
|
||||||
|
func DirOverallSHA256HexNative(root string) (string, error) {
|
||||||
|
return "", errors.New("native hash requires CGO")
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SnapshotFetcher is an interface for fetching snapshots
|
||||||
type SnapshotFetcher interface {
|
type SnapshotFetcher interface {
|
||||||
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
|
Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -43,87 +43,87 @@ func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||||
|
|
||||||
// Worker represents an ML task worker with composed dependencies.
|
// Worker represents an ML task worker with composed dependencies.
|
||||||
type Worker struct {
|
type Worker struct {
|
||||||
id string
|
ID string
|
||||||
config *Config
|
Config *Config
|
||||||
logger *logging.Logger
|
Logger *logging.Logger
|
||||||
|
|
||||||
// Composed dependencies from previous phases
|
// Composed dependencies from previous phases
|
||||||
runLoop *lifecycle.RunLoop
|
RunLoop *lifecycle.RunLoop
|
||||||
runner *executor.JobRunner
|
Runner *executor.JobRunner
|
||||||
metrics *metrics.Metrics
|
Metrics *metrics.Metrics
|
||||||
metricsSrv *http.Server
|
metricsSrv *http.Server
|
||||||
health *lifecycle.HealthMonitor
|
Health *lifecycle.HealthMonitor
|
||||||
resources *resources.Manager
|
Resources *resources.Manager
|
||||||
|
|
||||||
// GPU detection metadata for status output
|
// GPU detection metadata for status output
|
||||||
gpuDetectionInfo GPUDetectionInfo
|
gpuDetectionInfo GPUDetectionInfo
|
||||||
|
|
||||||
// Legacy fields for backward compatibility during migration
|
// Legacy fields for backward compatibility during migration
|
||||||
jupyter JupyterManager
|
Jupyter JupyterManager
|
||||||
queueClient queue.Backend // Stored for prewarming access
|
QueueClient queue.Backend // Stored for prewarming access
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins the worker's main processing loop.
|
// Start begins the worker's main processing loop.
|
||||||
func (w *Worker) Start() {
|
func (w *Worker) Start() {
|
||||||
w.logger.Info("worker starting",
|
w.Logger.Info("worker starting",
|
||||||
"worker_id", w.id,
|
"worker_id", w.ID,
|
||||||
"max_concurrent", w.config.MaxWorkers)
|
"max_concurrent", w.Config.MaxWorkers)
|
||||||
|
|
||||||
w.health.RecordHeartbeat()
|
w.Health.RecordHeartbeat()
|
||||||
w.runLoop.Start()
|
w.RunLoop.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop gracefully shuts down the worker immediately.
|
// Stop gracefully shuts down the worker immediately.
|
||||||
func (w *Worker) Stop() {
|
func (w *Worker) Stop() {
|
||||||
w.logger.Info("worker stopping", "worker_id", w.id)
|
w.Logger.Info("worker stopping", "worker_id", w.ID)
|
||||||
w.runLoop.Stop()
|
w.RunLoop.Stop()
|
||||||
|
|
||||||
if w.metricsSrv != nil {
|
if w.metricsSrv != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||||
w.logger.Warn("metrics server shutdown error", "error", err)
|
w.Logger.Warn("metrics server shutdown error", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
w.logger.Info("worker stopped", "worker_id", w.id)
|
w.Logger.Info("worker stopped", "worker_id", w.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown performs a graceful shutdown with timeout.
|
// Shutdown performs a graceful shutdown with timeout.
|
||||||
func (w *Worker) Shutdown() error {
|
func (w *Worker) Shutdown() error {
|
||||||
w.logger.Info("starting graceful shutdown", "worker_id", w.id)
|
w.Logger.Info("starting graceful shutdown", "worker_id", w.ID)
|
||||||
|
|
||||||
w.runLoop.Stop()
|
w.RunLoop.Stop()
|
||||||
|
|
||||||
if w.metricsSrv != nil {
|
if w.metricsSrv != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
if err := w.metricsSrv.Shutdown(ctx); err != nil {
|
||||||
w.logger.Warn("metrics server shutdown error", "error", err)
|
w.Logger.Warn("metrics server shutdown error", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
w.logger.Info("worker shut down gracefully", "worker_id", w.id)
|
w.Logger.Info("worker shut down gracefully", "worker_id", w.ID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsHealthy returns true if the worker is healthy.
|
// IsHealthy returns true if the worker is healthy.
|
||||||
func (w *Worker) IsHealthy() bool {
|
func (w *Worker) IsHealthy() bool {
|
||||||
return w.health.IsHealthy(5 * time.Minute)
|
return w.Health.IsHealthy(5 * time.Minute)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMetrics returns current worker metrics.
|
// GetMetrics returns current worker metrics.
|
||||||
func (w *Worker) GetMetrics() map[string]any {
|
func (w *Worker) GetMetrics() map[string]any {
|
||||||
stats := w.metrics.GetStats()
|
stats := w.Metrics.GetStats()
|
||||||
stats["worker_id"] = w.id
|
stats["worker_id"] = w.ID
|
||||||
stats["max_workers"] = w.config.MaxWorkers
|
stats["max_workers"] = w.Config.MaxWorkers
|
||||||
stats["healthy"] = w.IsHealthy()
|
stats["healthy"] = w.IsHealthy()
|
||||||
return stats
|
return stats
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the worker ID.
|
// GetID returns the worker ID.
|
||||||
func (w *Worker) GetID() string {
|
func (w *Worker) GetID() string {
|
||||||
return w.id
|
return w.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectDependencyManifest re-exports the executor function for API helpers.
|
// SelectDependencyManifest re-exports the executor function for API helpers.
|
||||||
|
|
@ -162,7 +162,7 @@ func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string
|
||||||
// VerifyDatasetSpecs verifies dataset specifications for this task.
|
// VerifyDatasetSpecs verifies dataset specifications for this task.
|
||||||
// This is a test compatibility method that wraps the integrity package.
|
// This is a test compatibility method that wraps the integrity package.
|
||||||
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
|
||||||
dataDir := w.config.DataDir
|
dataDir := w.Config.DataDir
|
||||||
if dataDir == "" {
|
if dataDir == "" {
|
||||||
dataDir = "/tmp/data"
|
dataDir = "/tmp/data"
|
||||||
}
|
}
|
||||||
|
|
@ -179,16 +179,16 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er
|
||||||
return fmt.Errorf("task is nil")
|
return fmt.Errorf("task is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
basePath := w.config.BasePath
|
basePath := w.Config.BasePath
|
||||||
if basePath == "" {
|
if basePath == "" {
|
||||||
basePath = "/tmp"
|
basePath = "/tmp"
|
||||||
}
|
}
|
||||||
dataDir := w.config.DataDir
|
dataDir := w.Config.DataDir
|
||||||
if dataDir == "" {
|
if dataDir == "" {
|
||||||
dataDir = filepath.Join(basePath, "data")
|
dataDir = filepath.Join(basePath, "data")
|
||||||
}
|
}
|
||||||
|
|
||||||
bestEffort := w.config.ProvenanceBestEffort
|
bestEffort := w.Config.ProvenanceBestEffort
|
||||||
|
|
||||||
// Get commit_id from metadata
|
// Get commit_id from metadata
|
||||||
commitID := task.Metadata["commit_id"]
|
commitID := task.Metadata["commit_id"]
|
||||||
|
|
@ -289,7 +289,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
||||||
return nil // No snapshot to verify
|
return nil // No snapshot to verify
|
||||||
}
|
}
|
||||||
|
|
||||||
dataDir := w.config.DataDir
|
dataDir := w.Config.DataDir
|
||||||
if dataDir == "" {
|
if dataDir == "" {
|
||||||
dataDir = "/tmp/data"
|
dataDir = "/tmp/data"
|
||||||
}
|
}
|
||||||
|
|
@ -324,7 +324,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
||||||
// RunJupyterTask runs a Jupyter-related task.
|
// RunJupyterTask runs a Jupyter-related task.
|
||||||
// It handles start, stop, remove, restore, and list_packages actions.
|
// It handles start, stop, remove, restore, and list_packages actions.
|
||||||
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
||||||
if w.jupyter == nil {
|
if w.Jupyter == nil {
|
||||||
return nil, fmt.Errorf("jupyter manager not configured")
|
return nil, fmt.Errorf("jupyter manager not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -350,7 +350,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &jupyter.StartRequest{Name: name}
|
req := &jupyter.StartRequest{Name: name}
|
||||||
service, err := w.jupyter.StartService(ctx, req)
|
service, err := w.Jupyter.StartService(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
|
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -366,7 +366,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
||||||
if serviceID == "" {
|
if serviceID == "" {
|
||||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||||
}
|
}
|
||||||
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
|
if err := w.Jupyter.StopService(ctx, serviceID); err != nil {
|
||||||
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
|
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
|
||||||
}
|
}
|
||||||
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
|
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
|
||||||
|
|
@ -377,7 +377,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
||||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||||
}
|
}
|
||||||
purge := task.Metadata["jupyter_purge"] == "true"
|
purge := task.Metadata["jupyter_purge"] == "true"
|
||||||
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
||||||
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
|
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
|
||||||
}
|
}
|
||||||
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
|
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
|
||||||
|
|
@ -390,7 +390,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
||||||
}
|
}
|
||||||
serviceID, err := w.jupyter.RestoreWorkspace(ctx, name)
|
serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
|
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -408,7 +408,7 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
||||||
return nil, fmt.Errorf("missing jupyter_name in task metadata")
|
return nil, fmt.Errorf("missing jupyter_name in task metadata")
|
||||||
}
|
}
|
||||||
|
|
||||||
packages, err := w.jupyter.ListInstalledPackages(ctx, serviceName)
|
packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to list installed packages: %w", err)
|
return nil, fmt.Errorf("failed to list installed packages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -429,16 +429,16 @@ func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte,
|
||||||
// Returns true if prewarming was performed, false if disabled or queue empty.
|
// Returns true if prewarming was performed, false if disabled or queue empty.
|
||||||
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||||
// Check if prewarming is enabled
|
// Check if prewarming is enabled
|
||||||
if !w.config.PrewarmEnabled {
|
if !w.Config.PrewarmEnabled {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get base path and data directory
|
// Get base path and data directory
|
||||||
basePath := w.config.BasePath
|
basePath := w.Config.BasePath
|
||||||
if basePath == "" {
|
if basePath == "" {
|
||||||
basePath = "/tmp"
|
basePath = "/tmp"
|
||||||
}
|
}
|
||||||
dataDir := w.config.DataDir
|
dataDir := w.Config.DataDir
|
||||||
if dataDir == "" {
|
if dataDir == "" {
|
||||||
dataDir = filepath.Join(basePath, "data")
|
dataDir = filepath.Join(basePath, "data")
|
||||||
}
|
}
|
||||||
|
|
@ -450,12 +450,12 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to get next task from queue client if available (peek, don't lease)
|
// Try to get next task from queue client if available (peek, don't lease)
|
||||||
if w.queueClient != nil {
|
if w.QueueClient != nil {
|
||||||
task, err := w.queueClient.PeekNextTask()
|
task, err := w.QueueClient.PeekNextTask()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Queue empty - check if we have existing prewarm state
|
// Queue empty - check if we have existing prewarm state
|
||||||
// Return false but preserve any existing state (don't delete)
|
// Return false but preserve any existing state (don't delete)
|
||||||
state, _ := w.queueClient.GetWorkerPrewarmState(w.id)
|
state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID)
|
||||||
if state != nil {
|
if state != nil {
|
||||||
// We have existing state, return true to indicate prewarm is active
|
// We have existing state, return true to indicate prewarm is active
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|
@ -489,17 +489,17 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store prewarm state in queue backend
|
// Store prewarm state in queue backend
|
||||||
if w.queueClient != nil {
|
if w.QueueClient != nil {
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
now := time.Now().UTC().Format(time.RFC3339)
|
||||||
state := queue.PrewarmState{
|
state := queue.PrewarmState{
|
||||||
WorkerID: w.id,
|
WorkerID: w.ID,
|
||||||
TaskID: task.ID,
|
TaskID: task.ID,
|
||||||
SnapshotID: task.SnapshotID,
|
SnapshotID: task.SnapshotID,
|
||||||
StartedAt: now,
|
StartedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
Phase: "staged",
|
Phase: "staged",
|
||||||
}
|
}
|
||||||
_ = w.queueClient.SetWorkerPrewarmState(state)
|
_ = w.QueueClient.SetWorkerPrewarmState(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|
@ -507,7 +507,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
|
// If we have a runLoop but no queue client, use runLoop (for backward compatibility)
|
||||||
if w.runLoop != nil {
|
if w.RunLoop != nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -517,18 +517,18 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
||||||
// RunJob runs a job task.
|
// RunJob runs a job task.
|
||||||
// It uses the JobRunner to execute the job and write the run manifest.
|
// It uses the JobRunner to execute the job and write the run manifest.
|
||||||
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
|
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error {
|
||||||
if w.runner == nil {
|
if w.Runner == nil {
|
||||||
return fmt.Errorf("job runner not configured")
|
return fmt.Errorf("job runner not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
basePath := w.config.BasePath
|
basePath := w.Config.BasePath
|
||||||
if basePath == "" {
|
if basePath == "" {
|
||||||
basePath = "/tmp"
|
basePath = "/tmp"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine execution mode
|
// Determine execution mode
|
||||||
mode := executor.ModeAuto
|
mode := executor.ModeAuto
|
||||||
if w.config.LocalMode {
|
if w.Config.LocalMode {
|
||||||
mode = executor.ModeLocal
|
mode = executor.ModeLocal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -536,5 +536,5 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
|
||||||
gpuEnv := interfaces.ExecutionEnv{}
|
gpuEnv := interfaces.ExecutionEnv{}
|
||||||
|
|
||||||
// Run the job
|
// Run the job
|
||||||
return w.runner.Run(ctx, task, basePath, mode, w.config.LocalMode, gpuEnv)
|
return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
150
internal/workertest/worker.go
Normal file
150
internal/workertest/worker.go
Normal file
|
|
@ -0,0 +1,150 @@
|
||||||
|
// Package workertest provides test helpers for the worker package.
|
||||||
|
// This package is only intended for use in tests and is separate from
|
||||||
|
// production code to maintain clean separation of concerns.
|
||||||
|
package workertest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SimpleManifestWriter is a basic ManifestWriter implementation for testing
|
||||||
|
type SimpleManifestWriter struct{}
|
||||||
|
|
||||||
|
func (w *SimpleManifestWriter) Upsert(dir string, task *queue.Task, mutate func(*manifest.RunManifest)) {
|
||||||
|
// Try to load existing manifest, or create new one
|
||||||
|
m, err := manifest.LoadFromDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
m = w.BuildInitial(task, "")
|
||||||
|
}
|
||||||
|
mutate(m)
|
||||||
|
_ = m.WriteToDir(dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *SimpleManifestWriter) BuildInitial(task *queue.Task, podmanImage string) *manifest.RunManifest {
|
||||||
|
m := manifest.NewRunManifest(
|
||||||
|
"run-"+task.ID,
|
||||||
|
task.ID,
|
||||||
|
task.JobName,
|
||||||
|
time.Now().UTC(),
|
||||||
|
)
|
||||||
|
m.CommitID = task.Metadata["commit_id"]
|
||||||
|
m.DepsManifestName = task.Metadata["deps_manifest_name"]
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestWorker creates a minimal Worker for testing purposes.
|
||||||
|
// It initializes only the fields needed for unit tests.
|
||||||
|
func NewTestWorker(cfg *worker.Config) *worker.Worker {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &worker.Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||||
|
metricsObj := &metrics.Metrics{}
|
||||||
|
|
||||||
|
// Create executors and runner for testing
|
||||||
|
writer := &SimpleManifestWriter{}
|
||||||
|
localExecutor := executor.NewLocalExecutor(logger, writer)
|
||||||
|
containerExecutor := executor.NewContainerExecutor(
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
executor.ContainerConfig{
|
||||||
|
PodmanImage: cfg.PodmanImage,
|
||||||
|
BasePath: cfg.BasePath,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
jobRunner := executor.NewJobRunner(
|
||||||
|
localExecutor,
|
||||||
|
containerExecutor,
|
||||||
|
writer,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &worker.Worker{
|
||||||
|
ID: cfg.WorkerID,
|
||||||
|
Config: cfg,
|
||||||
|
Logger: logger,
|
||||||
|
Metrics: metricsObj,
|
||||||
|
Health: lifecycle.NewHealthMonitor(),
|
||||||
|
Runner: jobRunner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestWorkerWithQueue creates a test Worker with a queue client.
|
||||||
|
func NewTestWorkerWithQueue(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
|
||||||
|
w := NewTestWorker(cfg)
|
||||||
|
w.QueueClient = queueClient
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestWorkerWithJupyter creates a test Worker with Jupyter manager.
|
||||||
|
func NewTestWorkerWithJupyter(cfg *worker.Config, jupyterMgr worker.JupyterManager) *worker.Worker {
|
||||||
|
w := NewTestWorker(cfg)
|
||||||
|
w.Jupyter = jupyterMgr
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestWorkerWithRunner creates a test Worker with JobRunner initialized.
|
||||||
|
// Note: This creates a minimal runner for testing purposes.
|
||||||
|
func NewTestWorkerWithRunner(cfg *worker.Config) *worker.Worker {
|
||||||
|
return NewTestWorker(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestWorkerWithRunLoop creates a test Worker with RunLoop initialized.
|
||||||
|
// Note: RunLoop requires proper queue client setup.
|
||||||
|
func NewTestWorkerWithRunLoop(cfg *worker.Config, queueClient queue.Backend) *worker.Worker {
|
||||||
|
return NewTestWorkerWithQueue(cfg, queueClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveDatasets resolves dataset paths for a task.
|
||||||
|
// This version matches the test expectations for backwards compatibility.
|
||||||
|
// Priority: DatasetSpecs > Datasets > Args parsing
|
||||||
|
func ResolveDatasets(task *queue.Task) []string {
|
||||||
|
if task == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 1: DatasetSpecs
|
||||||
|
if len(task.DatasetSpecs) > 0 {
|
||||||
|
var paths []string
|
||||||
|
for _, spec := range task.DatasetSpecs {
|
||||||
|
paths = append(paths, spec.Name)
|
||||||
|
}
|
||||||
|
return paths
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 2: Datasets
|
||||||
|
if len(task.Datasets) > 0 {
|
||||||
|
return task.Datasets
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 3: Parse from Args
|
||||||
|
if task.Args != "" {
|
||||||
|
// Simple parsing: --datasets a,b,c or --datasets a b c
|
||||||
|
args := task.Args
|
||||||
|
if idx := strings.Index(args, "--datasets"); idx != -1 {
|
||||||
|
after := args[idx+len("--datasets "):]
|
||||||
|
after = strings.TrimSpace(after)
|
||||||
|
// Split by comma or space
|
||||||
|
if strings.Contains(after, ",") {
|
||||||
|
return strings.Split(after, ",")
|
||||||
|
}
|
||||||
|
parts := strings.Fields(after)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -184,6 +184,6 @@ go test -tags native_libs ./tests/...
|
||||||
- Rebuild: `make native-clean && make native-build`
|
- Rebuild: `make native-clean && make native-build`
|
||||||
|
|
||||||
**Performance regression:**
|
**Performance regression:**
|
||||||
- Verify `FETCHML_NATIVE_LIBS=1` is set
|
- Verify code is built with `-tags native_libs`
|
||||||
- Check benchmark: `go test -bench=BenchmarkQueue -v`
|
- Check benchmark: `go test -bench=BenchmarkQueue -v`
|
||||||
- Profile with: `go test -bench=. -cpuprofile=cpu.prof`
|
- Profile with: `go test -bench=. -cpuprofile=cpu.prof`
|
||||||
|
|
|
||||||
|
|
@ -160,7 +160,7 @@ func testDatabaseConnectionFailure(t *testing.T, db *storage.DB, _ *redis.Client
|
||||||
// testRedisConnectionFailure tests system behavior when Redis fails
|
// testRedisConnectionFailure tests system behavior when Redis fails
|
||||||
func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) {
|
func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client) {
|
||||||
// Add jobs to Redis queue
|
// Add jobs to Redis queue
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
jobID := fmt.Sprintf("redis-chaos-job-%d", i)
|
jobID := fmt.Sprintf("redis-chaos-job-%d", i)
|
||||||
err := rdb.LPush(context.Background(), "ml:queue", jobID).Err()
|
err := rdb.LPush(context.Background(), "ml:queue", jobID).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -188,7 +188,7 @@ func testRedisConnectionFailure(t *testing.T, _ *storage.DB, rdb *redis.Client)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Wait for Redis to be available
|
// Wait for Redis to be available
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
err := newRdb.Ping(context.Background()).Err()
|
err := newRdb.Ping(context.Background()).Err()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
|
|
@ -218,7 +218,7 @@ func testHighConcurrencyStress(t *testing.T, db *storage.DB, rdb *redis.Client)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// Launch many concurrent workers
|
// Launch many concurrent workers
|
||||||
for worker := 0; worker < numWorkers; worker++ {
|
for worker := range numWorkers {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(workerID int) {
|
go func(workerID int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
@ -313,7 +313,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
numJobs := 50
|
numJobs := 50
|
||||||
|
|
||||||
// Create jobs with large payloads
|
// Create jobs with large payloads
|
||||||
for i := 0; i < numJobs; i++ {
|
for i := range numJobs {
|
||||||
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
|
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
|
||||||
|
|
||||||
job := &storage.Job{
|
job := &storage.Job{
|
||||||
|
|
@ -337,7 +337,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process jobs to test memory handling during operations
|
// Process jobs to test memory handling during operations
|
||||||
for i := 0; i < numJobs; i++ {
|
for i := range numJobs {
|
||||||
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
|
jobID := fmt.Sprintf("memory-pressure-job-%d", i)
|
||||||
|
|
||||||
// Update job status
|
// Update job status
|
||||||
|
|
@ -360,7 +360,7 @@ func testMemoryPressure(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
// Simulate operations with artificial delays
|
// Simulate operations with artificial delays
|
||||||
numJobs := 20
|
numJobs := 20
|
||||||
for i := 0; i < numJobs; i++ {
|
for i := range numJobs {
|
||||||
jobID := fmt.Sprintf("latency-job-%d", i)
|
jobID := fmt.Sprintf("latency-job-%d", i)
|
||||||
|
|
||||||
// Add artificial delay to simulate network latency
|
// Add artificial delay to simulate network latency
|
||||||
|
|
@ -387,7 +387,7 @@ func testNetworkLatency(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process jobs with latency simulation
|
// Process jobs with latency simulation
|
||||||
for i := 0; i < numJobs; i++ {
|
for i := range numJobs {
|
||||||
jobID := fmt.Sprintf("latency-job-%d", i)
|
jobID := fmt.Sprintf("latency-job-%d", i)
|
||||||
|
|
||||||
time.Sleep(time.Millisecond * 8)
|
time.Sleep(time.Millisecond * 8)
|
||||||
|
|
@ -413,7 +413,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
done := make(chan bool, numOperations)
|
done := make(chan bool, numOperations)
|
||||||
errors := make(chan error, numOperations)
|
errors := make(chan error, numOperations)
|
||||||
|
|
||||||
for i := 0; i < numOperations; i++ {
|
for i := range numOperations {
|
||||||
go func(opID int) {
|
go func(opID int) {
|
||||||
defer func() { done <- true }()
|
defer func() { done <- true }()
|
||||||
|
|
||||||
|
|
@ -448,7 +448,7 @@ func testResourceExhaustion(t *testing.T, db *storage.DB, rdb *redis.Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all operations to complete
|
// Wait for all operations to complete
|
||||||
for i := 0; i < numOperations; i++ {
|
for range numOperations {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
close(errors)
|
close(errors)
|
||||||
|
|
@ -522,7 +522,7 @@ func setupChaosRedisIsolated(t *testing.T) *redis.Client {
|
||||||
|
|
||||||
func createTestJobs(t *testing.T, db *storage.DB, count int) []string {
|
func createTestJobs(t *testing.T, db *storage.DB, count int) []string {
|
||||||
jobIDs := make([]string, count)
|
jobIDs := make([]string, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
jobID := fmt.Sprintf("chaos-test-job-%d", i)
|
jobID := fmt.Sprintf("chaos-test-job-%d", i)
|
||||||
jobIDs[i] = jobID
|
jobIDs[i] = jobID
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := container.BuildPodmanCommand(
|
cmd := container.BuildPodmanCommandLegacy(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
cfg,
|
cfg,
|
||||||
"/workspace/train.py",
|
"/workspace/train.py",
|
||||||
|
|
@ -100,7 +100,7 @@ func TestBuildPodmanCommand_Overrides(t *testing.T) {
|
||||||
CPUs: "8",
|
CPUs: "8",
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := container.BuildPodmanCommand(context.Background(), cfg, "script.py", "reqs.txt", nil)
|
cmd := container.BuildPodmanCommandLegacy(context.Background(), cfg, "script.py", "reqs.txt", nil)
|
||||||
|
|
||||||
if contains(cmd.Args, "--device") {
|
if contains(cmd.Args, "--device") {
|
||||||
t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)
|
t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fakeJupyterManager struct {
|
type fakeJupyterManager struct {
|
||||||
|
|
@ -65,7 +66,7 @@ type jupyterPackagesOutput struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunJupyterTaskStartSuccess(t *testing.T) {
|
func TestRunJupyterTaskStartSuccess(t *testing.T) {
|
||||||
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||||
startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
|
startFn: func(_ context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) {
|
||||||
if req.Name != "my-workspace" {
|
if req.Name != "my-workspace" {
|
||||||
return nil, errors.New("bad name")
|
return nil, errors.New("bad name")
|
||||||
|
|
@ -102,7 +103,7 @@ func TestRunJupyterTaskStartSuccess(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunJupyterTaskStopFailure(t *testing.T) {
|
func TestRunJupyterTaskStopFailure(t *testing.T) {
|
||||||
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||||
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
||||||
stopFn: func(context.Context, string) error { return errors.New("stop failed") },
|
stopFn: func(context.Context, string) error { return errors.New("stop failed") },
|
||||||
removeFn: func(context.Context, string, bool) error { return nil },
|
removeFn: func(context.Context, string, bool) error { return nil },
|
||||||
|
|
@ -123,7 +124,7 @@ func TestRunJupyterTaskStopFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
|
func TestRunJupyterTaskListPackagesSuccess(t *testing.T) {
|
||||||
w := worker.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
w := workertest.NewTestWorkerWithJupyter(&worker.Config{}, &fakeJupyterManager{
|
||||||
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
startFn: func(context.Context, *jupyter.StartRequest) (*jupyter.JupyterService, error) { return nil, nil },
|
||||||
stopFn: func(context.Context, string) error { return nil },
|
stopFn: func(context.Context, string) error { return nil },
|
||||||
removeFn: func(context.Context, string, bool) error { return nil },
|
removeFn: func(context.Context, string, bool) error { return nil },
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/alicebob/miniredis/v2"
|
"github.com/alicebob/miniredis/v2"
|
||||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
|
func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
|
||||||
|
|
@ -75,7 +76,7 @@ func TestPrewarmNextOnce_Snapshot_WritesPrewarmDir(t *testing.T) {
|
||||||
MaxWorkers: 1,
|
MaxWorkers: 1,
|
||||||
DatasetCacheTTL: 30 * time.Minute,
|
DatasetCacheTTL: 30 * time.Minute,
|
||||||
}
|
}
|
||||||
w := worker.NewTestWorkerWithQueue(cfg, tq)
|
w := workertest.NewTestWorkerWithQueue(cfg, tq)
|
||||||
|
|
||||||
ok, err := w.PrewarmNextOnce(context.Background())
|
ok, err := w.PrewarmNextOnce(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -113,7 +114,7 @@ func TestPrewarmNextOnce_Disabled_NoOp(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := &worker.Config{WorkerID: "worker-1", BasePath: base, DataDir: dataDir, PrewarmEnabled: false}
|
cfg := &worker.Config{WorkerID: "worker-1", BasePath: base, DataDir: dataDir, PrewarmEnabled: false}
|
||||||
w := worker.NewTestWorkerWithQueue(cfg, tq)
|
w := workertest.NewTestWorkerWithQueue(cfg, tq)
|
||||||
|
|
||||||
ok, err := w.PrewarmNextOnce(context.Background())
|
ok, err := w.PrewarmNextOnce(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -189,7 +190,7 @@ func TestPrewarmNextOnce_QueueEmpty_DoesNotDeleteState(t *testing.T) {
|
||||||
MaxWorkers: 1,
|
MaxWorkers: 1,
|
||||||
DatasetCacheTTL: 30 * time.Minute,
|
DatasetCacheTTL: 30 * time.Minute,
|
||||||
}
|
}
|
||||||
w := worker.NewTestWorkerWithQueue(cfg, tq)
|
w := workertest.NewTestWorkerWithQueue(cfg, tq)
|
||||||
|
|
||||||
ok, err := w.PrewarmNextOnce(context.Background())
|
ok, err := w.PrewarmNextOnce(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
|
func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
|
||||||
|
|
@ -22,7 +23,7 @@ func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
|
||||||
PodmanImage: "python:3.11",
|
PodmanImage: "python:3.11",
|
||||||
WorkerID: "worker-test",
|
WorkerID: "worker-test",
|
||||||
}
|
}
|
||||||
w := worker.NewTestWorker(cfg)
|
w := workertest.NewTestWorker(cfg)
|
||||||
|
|
||||||
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
|
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
|
||||||
expMgr := experiment.NewManager(base)
|
expMgr := experiment.NewManager(base)
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/workertest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSelectDependencyManifestPriority(t *testing.T) {
|
func TestSelectDependencyManifestPriority(t *testing.T) {
|
||||||
|
|
@ -99,7 +100,7 @@ func TestSelectDependencyManifestMissing(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveDatasetsPrecedence(t *testing.T) {
|
func TestResolveDatasetsPrecedence(t *testing.T) {
|
||||||
if got := worker.ResolveDatasets(nil); got != nil {
|
if got := workertest.ResolveDatasets(nil); got != nil {
|
||||||
t.Fatalf("expected nil for nil task")
|
t.Fatalf("expected nil for nil task")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -109,7 +110,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
|
||||||
Datasets: []string{"ds-legacy"},
|
Datasets: []string{"ds-legacy"},
|
||||||
Args: "--datasets ds-args",
|
Args: "--datasets ds-args",
|
||||||
}
|
}
|
||||||
got := worker.ResolveDatasets(task)
|
got := workertest.ResolveDatasets(task)
|
||||||
if len(got) != 1 || got[0] != "ds-spec" {
|
if len(got) != 1 || got[0] != "ds-spec" {
|
||||||
t.Fatalf("expected dataset_specs to win, got %v", got)
|
t.Fatalf("expected dataset_specs to win, got %v", got)
|
||||||
}
|
}
|
||||||
|
|
@ -120,7 +121,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
|
||||||
Datasets: []string{"ds-legacy"},
|
Datasets: []string{"ds-legacy"},
|
||||||
Args: "--datasets ds-args",
|
Args: "--datasets ds-args",
|
||||||
}
|
}
|
||||||
got := worker.ResolveDatasets(task)
|
got := workertest.ResolveDatasets(task)
|
||||||
if len(got) != 1 || got[0] != "ds-legacy" {
|
if len(got) != 1 || got[0] != "ds-legacy" {
|
||||||
t.Fatalf("expected datasets to win over args, got %v", got)
|
t.Fatalf("expected datasets to win over args, got %v", got)
|
||||||
}
|
}
|
||||||
|
|
@ -128,7 +129,7 @@ func TestResolveDatasetsPrecedence(t *testing.T) {
|
||||||
|
|
||||||
t.Run("ArgsFallback", func(t *testing.T) {
|
t.Run("ArgsFallback", func(t *testing.T) {
|
||||||
task := &queue.Task{Args: "--datasets a,b,c"}
|
task := &queue.Task{Args: "--datasets a,b,c"}
|
||||||
got := worker.ResolveDatasets(task)
|
got := workertest.ResolveDatasets(task)
|
||||||
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
|
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
|
||||||
t.Fatalf("expected args datasets, got %v", got)
|
t.Fatalf("expected args datasets, got %v", got)
|
||||||
}
|
}
|
||||||
|
|
@ -234,7 +235,7 @@ func TestVerifyDatasetSpecs(t *testing.T) {
|
||||||
sha, err := worker.DirOverallSHA256Hex(dsPath)
|
sha, err := worker.DirOverallSHA256Hex(dsPath)
|
||||||
requireNoErr(t, err)
|
requireNoErr(t, err)
|
||||||
|
|
||||||
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
|
w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
|
||||||
task := &queue.Task{
|
task := &queue.Task{
|
||||||
JobName: "job",
|
JobName: "job",
|
||||||
ID: "t1",
|
ID: "t1",
|
||||||
|
|
@ -272,7 +273,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
|
||||||
}
|
}
|
||||||
requireNoErr(t, expMgr.WriteManifest(manifest))
|
requireNoErr(t, expMgr.WriteManifest(manifest))
|
||||||
|
|
||||||
w := worker.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
|
w := workertest.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
|
||||||
|
|
||||||
// Missing expected fields should fail.
|
// Missing expected fields should fail.
|
||||||
taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
|
taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
|
||||||
|
|
@ -296,7 +297,7 @@ func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
|
||||||
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
||||||
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
||||||
|
|
||||||
wSnap := worker.NewTestWorker(&worker.Config{
|
wSnap := workertest.NewTestWorker(&worker.Config{
|
||||||
BasePath: base,
|
BasePath: base,
|
||||||
DataDir: filepath.Join(base, "data"),
|
DataDir: filepath.Join(base, "data"),
|
||||||
ProvenanceBestEffort: false,
|
ProvenanceBestEffort: false,
|
||||||
|
|
@ -335,7 +336,7 @@ func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
|
||||||
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
||||||
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
||||||
|
|
||||||
w := worker.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
|
w := workertest.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
|
||||||
task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}}
|
task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}}
|
||||||
if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
|
if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
|
||||||
t.Fatalf("expected best-effort to pass, got %v", err)
|
t.Fatalf("expected best-effort to pass, got %v", err)
|
||||||
|
|
@ -360,7 +361,7 @@ func TestVerifySnapshot(t *testing.T) {
|
||||||
sha, err := worker.DirOverallSHA256Hex(snapDir)
|
sha, err := worker.DirOverallSHA256Hex(snapDir)
|
||||||
requireNoErr(t, err)
|
requireNoErr(t, err)
|
||||||
|
|
||||||
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
|
w := workertest.NewTestWorker(&worker.Config{DataDir: dataDir})
|
||||||
|
|
||||||
t.Run("Ok", func(t *testing.T) {
|
t.Run("Ok", func(t *testing.T) {
|
||||||
task := &queue.Task{
|
task := &queue.Task{
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue