refactor(queue): integrate scheduler backend and storage improvements

Update queue and storage systems for scheduler integration:
- Queue backend with scheduler coordination
- Filesystem queue with batch operations
- Deduplication with tenant-aware keys
- Storage layer with audit logging hooks
- Domain models (Task, Events, Errors) with scheduler fields
- Database layer with tenant isolation
- Dataset storage with integrity checks
This commit is contained in:
Jeremie Fraeys 2026-02-26 12:06:46 -05:00
parent 6b2c377680
commit 6866ba9366
No known key found for this signature in database
14 changed files with 211 additions and 141 deletions

View file

@ -81,16 +81,16 @@ func ClassifyFailure(exitCode int, signal os.Signal, logTail string) FailureClas
// FailureInfo contains complete failure context for the manifest
type FailureInfo struct {
Context map[string]string `json:"context,omitempty"`
Class FailureClass `json:"class"`
ExitCode int `json:"exit_code,omitempty"`
Signal string `json:"signal,omitempty"`
LogTail string `json:"log_tail,omitempty"`
Suggestion string `json:"suggestion,omitempty"`
AutoRetried bool `json:"auto_retried,omitempty"`
ClassifiedAt string `json:"classified_at,omitempty"`
ExitCode int `json:"exit_code,omitempty"`
RetryCount int `json:"retry_count,omitempty"`
RetryCap int `json:"retry_cap,omitempty"`
ClassifiedAt string `json:"classified_at,omitempty"`
Context map[string]string `json:"context,omitempty"`
AutoRetried bool `json:"auto_retried,omitempty"`
}
// GetFailureSuggestion returns user guidance based on failure class

View file

@ -30,22 +30,11 @@ const (
// TaskEvent represents an event in a task's lifecycle.
// Events are stored in Redis Streams for append-only audit trails.
type TaskEvent struct {
// TaskID is the unique identifier of the task.
TaskID string `json:"task_id"`
// EventType indicates what happened (queued, started, completed, etc.).
EventType TaskEventType `json:"event_type"`
// Timestamp when the event occurred.
Timestamp time.Time `json:"timestamp"`
// Data contains event-specific data (JSON-encoded).
// For "started": {"worker_id": "worker-1", "image": "pytorch:latest"}
// For "failed": {"error": "OOM", "phase": "execution"}
Data json.RawMessage `json:"data,omitempty"`
// Who triggered this event (worker ID, user ID, or system).
Who string `json:"who"`
Timestamp time.Time `json:"timestamp"`
TaskID string `json:"task_id"`
EventType TaskEventType `json:"event_type"`
Who string `json:"who"`
Data json.RawMessage `json:"data,omitempty"`
}
// EventDataStarted contains data for the "started" event.
@ -64,8 +53,8 @@ type EventDataFailed struct {
// EventDataGPUAssigned contains data for the "gpu_assigned" event.
type EventDataGPUAssigned struct {
GPUDevices []string `json:"gpu_devices"`
GPUEnvVar string `json:"gpu_env_var,omitempty"`
GPUDevices []string `json:"gpu_devices"`
}
// NewTaskEvent creates a new task event with the current timestamp.

View file

@ -8,65 +8,48 @@ import (
// Task represents an ML experiment task
type Task struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"` // queued, running, completed, failed
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Output string `json:"output,omitempty"`
// SnapshotID references the experiment snapshot (code + deps) for this task.
// Currently stores an opaque identifier. Future: verify checksum/digest before execution
// to ensure reproducibility and detect tampering.
SnapshotID string `json:"snapshot_id,omitempty"`
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
// Datasets is kept for backward compatibility (legacy callers).
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
// Resource requests (optional, 0 means unspecified)
CPU int `json:"cpu,omitempty"`
MemoryGB int `json:"memory_gb,omitempty"`
GPU int `json:"gpu,omitempty"`
GPUMemory string `json:"gpu_memory,omitempty"`
// User ownership and permissions
UserID string `json:"user_id"` // User who owns this task
Username string `json:"username"` // Username for display
CreatedBy string `json:"created_by"` // User who submitted the task
// Lease management for task resilience
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease
// Retry management
RetryCount int `json:"retry_count"` // Number of retry attempts made
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
LastError string `json:"last_error,omitempty"` // Last error encountered
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
// Attempt tracking - complete history of all execution attempts
Attempts []Attempt `json:"attempts,omitempty"`
// Optional tracking configuration for this task
Tracking *TrackingConfig `json:"tracking,omitempty"`
CreatedAt time.Time `json:"created_at"`
Metadata map[string]string `json:"metadata,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
Tracking *TrackingConfig `json:"tracking,omitempty"`
NextRetry *time.Time `json:"next_retry,omitempty"`
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
Username string `json:"username"`
LeasedBy string `json:"leased_by,omitempty"`
Error string `json:"error,omitempty"`
Output string `json:"output,omitempty"`
SnapshotID string `json:"snapshot_id,omitempty"`
Status string `json:"status"`
LastError string `json:"last_error,omitempty"`
ID string `json:"id"`
Args string `json:"args"`
WorkerID string `json:"worker_id,omitempty"`
JobName string `json:"job_name"`
GPUMemory string `json:"gpu_memory,omitempty"`
UserID string `json:"user_id"`
CreatedBy string `json:"created_by"`
Datasets []string `json:"datasets,omitempty"`
Attempts []Attempt `json:"attempts,omitempty"`
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
MemoryGB int `json:"memory_gb,omitempty"`
CPU int `json:"cpu,omitempty"`
GPU int `json:"gpu,omitempty"`
RetryCount int `json:"retry_count"`
MaxRetries int `json:"max_retries"`
Priority int64 `json:"priority"`
}
// Attempt represents a single execution attempt of a task
type Attempt struct {
Attempt int `json:"attempt"` // Attempt number (1-indexed)
StartedAt time.Time `json:"started_at"` // When attempt started
EndedAt *time.Time `json:"ended_at,omitempty"` // When attempt ended (if completed)
WorkerID string `json:"worker_id,omitempty"` // Which worker ran this attempt
Status string `json:"status"` // running, completed, failed
FailureClass FailureClass `json:"failure_class,omitempty"` // Failure classification (if failed)
ExitCode int `json:"exit_code,omitempty"` // Process exit code
Signal string `json:"signal,omitempty"` // Termination signal (if any)
Error string `json:"error,omitempty"` // Error message (if failed)
LogTail string `json:"log_tail,omitempty"` // Last N lines of log output
StartedAt time.Time `json:"started_at"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Status string `json:"status"`
FailureClass FailureClass `json:"failure_class,omitempty"`
Signal string `json:"signal,omitempty"`
Error string `json:"error,omitempty"`
LogTail string `json:"log_tail,omitempty"`
Attempt int `json:"attempt"`
ExitCode int `json:"exit_code,omitempty"`
}

View file

@ -9,22 +9,22 @@ type TrackingConfig struct {
// MLflowTrackingConfig controls MLflow integration.
type MLflowTrackingConfig struct {
Mode string `json:"mode,omitempty"`
TrackingURI string `json:"tracking_uri,omitempty"`
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` // "sidecar" | "remote" | "disabled"
TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
}
// TensorBoardTrackingConfig controls TensorBoard integration.
type TensorBoardTrackingConfig struct {
Mode string `json:"mode,omitempty"`
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` // "sidecar" | "disabled"
}
// WandbTrackingConfig controls Weights & Biases integration.
type WandbTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` // "remote" | "disabled"
Mode string `json:"mode,omitempty"`
APIKey string `json:"api_key,omitempty"`
Project string `json:"project,omitempty"`
Entity string `json:"entity,omitempty"`
Enabled bool `json:"enabled"`
}

View file

@ -5,6 +5,8 @@ import (
"fmt"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/scheduler"
)
var ErrInvalidQueueBackend = errors.New("invalid queue backend")
@ -48,24 +50,38 @@ type Backend interface {
type QueueBackend string
const (
QueueBackendRedis QueueBackend = "redis"
QueueBackendSQLite QueueBackend = "sqlite"
QueueBackendFS QueueBackend = "filesystem"
QueueBackendNative QueueBackend = "native" // Native C++ queue_index (requires -tags native_libs)
QueueBackendRedis QueueBackend = "redis"
QueueBackendSQLite QueueBackend = "sqlite"
QueueBackendFS QueueBackend = "filesystem"
QueueBackendNative QueueBackend = "native" // Native C++ queue_index (requires -tags native_libs)
QueueBackendScheduler QueueBackend = "scheduler" // Distributed mode via WebSocket
)
type SchedulerConfig struct {
Address string // Scheduler address (e.g., "192.168.1.10:7777")
Cert string // Path to scheduler's TLS certificate
Token string // Worker authentication token
}
type BackendConfig struct {
Mode string // "standalone" | "distributed"
Backend QueueBackend
RedisAddr string
RedisPassword string
RedisDB int
SQLitePath string
FilesystemPath string
FallbackToFilesystem bool
RedisDB int
MetricsFlushInterval time.Duration
FallbackToFilesystem bool
Scheduler SchedulerConfig // Config for distributed mode
}
func NewBackend(cfg BackendConfig) (Backend, error) {
// Distributed mode: use SchedulerBackend
if cfg.Mode == "distributed" {
return NewSchedulerBackendFromConfig(cfg.Scheduler)
}
mkFallback := func(err error) (Backend, error) {
if !cfg.FallbackToFilesystem {
return nil, err
@ -112,3 +128,15 @@ func NewBackend(cfg BackendConfig) (Backend, error) {
return nil, ErrInvalidQueueBackend
}
}
// NewSchedulerBackendFromConfig creates a SchedulerBackend from config
func NewSchedulerBackendFromConfig(cfg SchedulerConfig) (Backend, error) {
if cfg.Address == "" {
return nil, fmt.Errorf("scheduler address is required for distributed mode")
}
conn := scheduler.NewSchedulerConn(cfg.Address, cfg.Cert, cfg.Token, "", scheduler.WorkerCapabilities{})
if err := conn.Connect(); err != nil {
return nil, fmt.Errorf("connect to scheduler: %w", err)
}
return NewSchedulerBackend(conn), nil
}

View file

@ -11,9 +11,9 @@ var ErrAlreadyQueued = fmt.Errorf("job already queued with this commit")
// CommitDedup tracks recently queued commits to prevent duplicate submissions
type CommitDedup struct {
mu sync.RWMutex
commits map[string]time.Time // key: "job_name:commit_id" -> queued_at
commits map[string]time.Time
ttl time.Duration
mu sync.RWMutex
}
// NewCommitDedup creates a new commit deduplication tracker

View file

@ -8,28 +8,77 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/fileutil"
)
// validTaskID is an allowlist regex for task IDs.
// Only alphanumeric, underscore, and hyphen allowed. Max 128 chars.
// This prevents path traversal attacks (null bytes, slashes, backslashes, etc.)
var validTaskID = regexp.MustCompile(`^[a-zA-Z0-9_\-]{1,128}$`)
// validateTaskID checks if a task ID is valid according to the allowlist.
func validateTaskID(id string) error {
if id == "" {
return errors.New("task ID is required")
}
if !validTaskID.MatchString(id) {
return fmt.Errorf("invalid task ID %q: must match %s", id, validTaskID.String())
}
return nil
}
// writeTaskFile writes task data with O_NOFOLLOW, fsync, and cleanup on error.
// This prevents symlink attacks and ensures crash safety.
func writeTaskFile(path string, data []byte) error {
// Use O_NOFOLLOW to prevent following symlinks (TOCTOU protection)
f, err := fileutil.OpenFileNoFollow(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640)
if err != nil {
return fmt.Errorf("open (symlink rejected): %w", err)
}
if _, err := f.Write(data); err != nil {
f.Close()
os.Remove(path) // remove partial write
return fmt.Errorf("write: %w", err)
}
// CRITICAL: fsync ensures data is flushed to disk before returning
if err := f.Sync(); err != nil {
f.Close()
os.Remove(path) // remove unsynced file
return fmt.Errorf("fsync: %w", err)
}
// Close can fail on some filesystems (NFS, network-backed volumes)
if err := f.Close(); err != nil {
os.Remove(path) // remove file if close failed
return fmt.Errorf("close: %w", err)
}
return nil
}
// Queue implements a filesystem-based task queue
type Queue struct {
root string
ctx context.Context
cancel context.CancelFunc
root string
}
type queueIndex struct {
Version int `json:"version"`
UpdatedAt string `json:"updated_at"`
Tasks []queueIndexTask `json:"tasks"`
Version int `json:"version"`
}
type queueIndexTask struct {
ID string `json:"id"`
Priority int64 `json:"priority"`
CreatedAt string `json:"created_at"`
Priority int64 `json:"priority"`
}
// NewQueue creates a new filesystem queue instance
@ -65,29 +114,42 @@ func (q *Queue) AddTask(task *domain.Task) error {
if task == nil {
return errors.New("task is nil")
}
if task.ID == "" {
return errors.New("task ID is required")
if err := validateTaskID(task.ID); err != nil {
return err
}
pendingDir := filepath.Join(q.root, "pending", "entries")
taskFile := filepath.Join(pendingDir, task.ID+".json")
// SECURITY: Verify resolved path is still inside pendingDir (symlink/traversal check)
resolvedDir, err := filepath.EvalSymlinks(pendingDir)
if err != nil {
return fmt.Errorf("resolve pending dir: %w", err)
}
resolvedFile, err := filepath.EvalSymlinks(filepath.Dir(taskFile))
if err != nil {
return fmt.Errorf("resolve task dir: %w", err)
}
if !strings.HasPrefix(resolvedFile+string(filepath.Separator), resolvedDir+string(filepath.Separator)) {
return fmt.Errorf("task path %q escapes queue root", taskFile)
}
data, err := json.Marshal(task)
if err != nil {
return fmt.Errorf("failed to marshal task: %w", err)
}
if err := os.WriteFile(taskFile, data, 0640); err != nil {
// SECURITY: Write with fsync + cleanup on error + O_NOFOLLOW to prevent symlink attacks
if err := writeTaskFile(taskFile, data); err != nil {
return fmt.Errorf("failed to write task file: %w", err)
}
return nil
}
// GetTask retrieves a task by ID
func (q *Queue) GetTask(id string) (*domain.Task, error) {
if id == "" {
return nil, errors.New("task ID is required")
if err := validateTaskID(id); err != nil {
return nil, err
}
// Search in all directories
@ -140,6 +202,9 @@ func (q *Queue) ListTasks() ([]*domain.Task, error) {
// CancelTask cancels a task
func (q *Queue) CancelTask(id string) error {
if err := validateTaskID(id); err != nil {
return err
}
// Remove from pending if exists
pendingFile := filepath.Join(q.root, "pending", "entries", id+".json")
if _, err := os.Stat(pendingFile); err == nil {
@ -150,8 +215,11 @@ func (q *Queue) CancelTask(id string) error {
// UpdateTask updates a task
func (q *Queue) UpdateTask(task *domain.Task) error {
if task == nil || task.ID == "" {
return errors.New("task is nil or missing ID")
if task == nil {
return errors.New("task is nil")
}
if err := validateTaskID(task.ID); err != nil {
return err
}
// Find current location
@ -173,7 +241,11 @@ func (q *Queue) UpdateTask(task *domain.Task) error {
return fmt.Errorf("failed to marshal task: %w", err)
}
return os.WriteFile(currentFile, data, 0640)
// SECURITY: Write with O_NOFOLLOW + fsync + cleanup on error
if err := writeTaskFile(currentFile, data); err != nil {
return fmt.Errorf("failed to write task file: %w", err)
}
return nil
}
// rebuildIndex rebuilds the queue index

View file

@ -13,24 +13,25 @@ import (
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/fileutil"
)
type FilesystemQueue struct {
root string
ctx context.Context
cancel context.CancelFunc
root string
}
type filesystemQueueIndex struct {
Version int `json:"version"`
UpdatedAt string `json:"updated_at"`
Tasks []filesystemQueueIndexTask `json:"tasks"`
Version int `json:"version"`
}
type filesystemQueueIndexTask struct {
ID string `json:"id"`
Priority int64 `json:"priority"`
CreatedAt string `json:"created_at"`
Priority int64 `json:"priority"`
}
func NewFilesystemQueue(root string) (*FilesystemQueue, error) {
@ -572,9 +573,6 @@ func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
if err := os.MkdirAll(dir, 0750); err != nil {
return err
}
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, perm); err != nil {
return err
}
return os.Rename(tmp, path)
// SECURITY: Use WriteFileSafe for atomic write with fsync
return fileutil.WriteFileSafe(path, data, perm)
}

View file

@ -20,13 +20,13 @@ const (
// TaskQueue manages ML experiment tasks via Redis
type TaskQueue struct {
client *redis.Client
ctx context.Context
client *redis.Client
cancel context.CancelFunc
metricsCh chan metricEvent
metricsDone chan struct{}
dedup *CommitDedup
flushEvery time.Duration
dedup *CommitDedup // Tracks recently queued commits
}
type metricEvent struct {

View file

@ -12,10 +12,10 @@ import (
// DatasetInfo contains information about a dataset.
type DatasetInfo struct {
Name string `json:"name"`
SizeBytes int64 `json:"size_bytes"`
Location string `json:"location"` // "nas" or "ml"
LastAccess time.Time `json:"last_access"`
Name string `json:"name"`
Location string `json:"location"`
SizeBytes int64 `json:"size_bytes"`
}
// DatasetStore manages dataset metadata and transfer tracking.

View file

@ -35,10 +35,10 @@ type DBConfig struct {
Type string
Connection string
Host string
Port int
Username string
Password string
Database string
Port int
}
// DB wraps a database connection with type information.

View file

@ -11,17 +11,18 @@ import (
)
type Experiment struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Status string `json:"status"`
UserID string `json:"user_id,omitempty"`
WorkspaceID string `json:"workspace_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ExperimentEnvironment struct {
CreatedAt time.Time `json:"created_at"`
PythonVersion string `json:"python_version"`
CUDAVersion string `json:"cuda_version,omitempty"`
SystemOS string `json:"system_os"`
@ -30,16 +31,15 @@ type ExperimentEnvironment struct {
RequirementsHash string `json:"requirements_hash"`
CondaEnvHash string `json:"conda_env_hash,omitempty"`
Dependencies json.RawMessage `json:"dependencies,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
type ExperimentGitInfo struct {
CreatedAt time.Time `json:"created_at"`
CommitSHA string `json:"commit_sha"`
Branch string `json:"branch"`
RemoteURL string `json:"remote_url"`
IsDirty bool `json:"is_dirty"`
DiffPatch string `json:"diff_patch,omitempty"`
CreatedAt time.Time `json:"created_at"`
IsDirty bool `json:"is_dirty"`
}
type ExperimentSeeds struct {
@ -51,17 +51,17 @@ type ExperimentSeeds struct {
}
type Dataset struct {
Name string `json:"name"`
URL string `json:"url"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Name string `json:"name"`
URL string `json:"url"`
}
type ExperimentWithMetadata struct {
Experiment Experiment `json:"experiment"`
Environment *ExperimentEnvironment `json:"environment,omitempty"`
GitInfo *ExperimentGitInfo `json:"git_info,omitempty"`
Seeds *ExperimentSeeds `json:"seeds,omitempty"`
Experiment Experiment `json:"experiment"`
}
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {

View file

@ -10,30 +10,30 @@ import (
// Job represents a machine learning job in the system.
type Job struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"`
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
Metadata map[string]string `json:"metadata,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
Status string `json:"status"`
ID string `json:"id"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Args string `json:"args"`
JobName string `json:"job_name"`
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
Priority int64 `json:"priority"`
}
// Worker represents a worker node in the system.
type Worker struct {
LastHeartbeat time.Time `json:"last_heartbeat"`
Metadata map[string]string `json:"metadata,omitempty"`
ID string `json:"id"`
Hostname string `json:"hostname"`
LastHeartbeat time.Time `json:"last_heartbeat"`
Status string `json:"status"`
CurrentJobs int `json:"current_jobs"`
MaxJobs int `json:"max_jobs"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// CreateJob inserts a new job into the database.

View file

@ -8,23 +8,23 @@ import (
// Metric represents a recorded metric from WebSocket connections
type Metric struct {
ID int64 `json:"id"`
Name string `json:"name"`
Value float64 `json:"value"`
User string `json:"user,omitempty"`
RecordedAt time.Time `json:"recorded_at"`
Name string `json:"name"`
User string `json:"user,omitempty"`
ID int64 `json:"id"`
Value float64 `json:"value"`
}
// MetricSummary represents aggregated metric statistics
type MetricSummary struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Name string `json:"name"`
Count int64 `json:"count"`
Avg float64 `json:"avg"`
Min float64 `json:"min"`
Max float64 `json:"max"`
Sum float64 `json:"sum"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
}
// RecordMetric records a metric to the database