refactor: extract domain types and consolidate error system (Phases 1-2)

Phase 1: Extract Domain Types
=============================
- Create internal/domain/ package with canonical types:
  - domain/task.go: Task, Attempt structs
  - domain/tracking.go: TrackingConfig and MLflow/TensorBoard/Wandb configs
  - domain/dataset.go: DatasetSpec
  - domain/status.go: JobStatus constants
  - domain/errors.go: FailureClass system with classification functions
  - domain/doc.go: package documentation

- Update queue/task.go to re-export domain types (backward compatibility)
- Update TUI model/state.go to use domain types via type aliases
- Simplify TUI services: remove ~60 lines of conversion functions

Phase 2: Delete ErrorCategory System
====================================
- Remove deprecated ErrorCategory type and constants
- Remove TaskError struct and related functions
- Remove mapping functions: ClassifyError, IsRetryable, GetUserMessage, RetryDelay
- Update all queue implementations to use domain.FailureClass directly:
  - queue/metrics.go: RecordTaskFailure/Retry now take FailureClass
  - queue/queue.go: RetryTask uses domain.ClassifyFailure
  - queue/filesystem_queue.go: RetryTask and MoveToDeadLetterQueue updated
  - queue/sqlite_queue.go: RetryTask and MoveToDeadLetterQueue updated

Eliminated: ~190 lines of conversion and mapping code
Result: Single source of truth for domain types and error classification
Jeremie Fraeys 2026-02-17 12:34:28 -05:00
parent e286fd7769
commit 6580917ba8
16 changed files with 428 additions and 603 deletions


@@ -4,17 +4,12 @@ This guide covers security best practices for deploying Fetch ML in a homelab en
## Quick Setup
Run the secure setup script:
Secure setup requires manual configuration:
```bash
./scripts/setup-secure-homelab.sh
```
This will:
- Generate secure API keys
- Create TLS certificates
- Set up secure configuration
- Create environment files with proper permissions
1. **Generate API keys**: Use the instructions in [API Security](#api-security) below
2. **Create TLS certificates**: Use OpenSSL commands in [Troubleshooting](#troubleshooting)
3. **Configure security**: Copy and edit `configs/api/homelab-secure.yaml`
4. **Set permissions**: Ensure `.api-keys` and `.env.secure` have 600 permissions
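The manual steps above can be sketched as shell commands. This is a hedged sketch, not the removed script: the filenames `server.key`/`server.crt` and the self-signed-cert flags are assumptions, and the real OpenSSL invocations for your deployment are the ones in [Troubleshooting](#troubleshooting):

```bash
# 1. Generate a random 256-bit API key (filename from this guide)
openssl rand -hex 32 > .api-keys

# 2. Create a self-signed TLS certificate (assumed filenames/flags;
#    adjust CN and validity for your network)
openssl req -x509 -newkey rsa:4096 -nodes \
  -keyout server.key -out server.crt \
  -days 365 -subj "/CN=localhost"

# 4. Lock down permissions to owner-only
chmod 600 .api-keys server.key
```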
## Security Features
@@ -54,24 +49,30 @@ This will:
## Deployment Options
### Option 1: Docker Compose (Recommended)
```bash
# Generate secure setup
./scripts/setup-secure-homelab.sh
# Configure secure setup manually (see Quick Setup above)
# Copy and edit the secure configuration
cp configs/api/homelab-secure.yaml configs/api/my-secure.yaml
# Edit with your API keys, TLS settings, and IP whitelist
# Deploy with security overlay
docker-compose -f docker-compose.yml -f docker-compose.homelab-secure.yml up -d
```
### Option 2: Direct Deployment
```bash
# Generate secure setup
./scripts/setup-secure-homelab.sh
# Configure secure setup manually (see Quick Setup above)
# Copy and edit the secure configuration
cp configs/api/homelab-secure.yaml configs/api/my-secure.yaml
# Edit with your API keys, TLS settings, and IP whitelist
# Load environment variables
source .env.secure
# Start server
./api-server -config configs/api/homelab-secure.yaml
./api-server -config configs/api/my-secure.yaml
```
## Security Checklist


@@ -11,8 +11,25 @@ import (
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
"github.com/charmbracelet/lipgloss"
"github.com/jfraeys/fetch_ml/internal/domain"
)
// Re-export domain types for TUI use
// Task represents a task in the TUI
type Task = domain.Task
// TrackingConfig specifies experiment tracking tools
type TrackingConfig = domain.TrackingConfig
// MLflowTrackingConfig controls MLflow integration
type MLflowTrackingConfig = domain.MLflowTrackingConfig
// TensorBoardTrackingConfig controls TensorBoard integration
type TensorBoardTrackingConfig = domain.TensorBoardTrackingConfig
// WandbTrackingConfig controls Weights & Biases integration
type WandbTrackingConfig = domain.WandbTrackingConfig
// ViewMode represents the current view mode in the TUI
type ViewMode int
@@ -69,50 +86,6 @@ func (j Job) Description() string {
// FilterValue returns the value used for filtering
func (j Job) FilterValue() string { return j.Name }
// Task represents a task in the TUI
type Task struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"`
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
Error string `json:"error,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Tracking *TrackingConfig `json:"tracking,omitempty"`
}
// TrackingConfig specifies experiment tracking tools
type TrackingConfig struct {
MLflow *MLflowTrackingConfig `json:"mlflow,omitempty"`
TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
Wandb *WandbTrackingConfig `json:"wandb,omitempty"`
}
// MLflowTrackingConfig controls MLflow integration
type MLflowTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"`
TrackingURI string `json:"tracking_uri,omitempty"`
}
// TensorBoardTrackingConfig controls TensorBoard integration
type TensorBoardTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"`
}
// WandbTrackingConfig controls Weights & Biases integration
type WandbTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"`
APIKey string `json:"api_key,omitempty"`
Project string `json:"project,omitempty"`
Entity string `json:"entity,omitempty"`
}
// DatasetInfo represents dataset information in the TUI
type DatasetInfo struct {
Name string `json:"name"`


@@ -7,11 +7,15 @@ import (
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// Task is an alias for domain.Task for TUI compatibility
type Task = domain.Task
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility
type TaskQueue struct {
internal *queue.TaskQueue
@@ -45,7 +49,7 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
}
// EnqueueTask adds a new task to the queue
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.Task, error) {
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) {
// Create internal task
internalTask := &queue.Task{
JobName: jobName,
@@ -59,21 +63,12 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.T
return nil, err
}
// Convert to TUI model
return &model.Task{
ID: internalTask.ID,
JobName: internalTask.JobName,
Args: internalTask.Args,
Status: "queued",
Priority: internalTask.Priority,
CreatedAt: internalTask.CreatedAt,
Metadata: internalTask.Metadata,
Tracking: convertTrackingToModel(internalTask.Tracking),
}, nil
// Return domain.Task directly (no conversion needed)
return internalTask, nil
}
// GetNextTask retrieves the next task from the queue
func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
func (tq *TaskQueue) GetNextTask() (*Task, error) {
internalTask, err := tq.internal.GetNextTask()
if err != nil {
return nil, err
@@ -82,79 +77,36 @@ func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
return nil, nil
}
// Convert to TUI model
return &model.Task{
ID: internalTask.ID,
JobName: internalTask.JobName,
Args: internalTask.Args,
Status: internalTask.Status,
Priority: internalTask.Priority,
CreatedAt: internalTask.CreatedAt,
Metadata: internalTask.Metadata,
Tracking: convertTrackingToModel(internalTask.Tracking),
}, nil
// Return domain.Task directly (no conversion needed)
return internalTask, nil
}
// GetTask retrieves a specific task by ID
func (tq *TaskQueue) GetTask(taskID string) (*model.Task, error) {
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
internalTask, err := tq.internal.GetTask(taskID)
if err != nil {
return nil, err
}
// Convert to TUI model
return &model.Task{
ID: internalTask.ID,
JobName: internalTask.JobName,
Args: internalTask.Args,
Status: internalTask.Status,
Priority: internalTask.Priority,
CreatedAt: internalTask.CreatedAt,
Metadata: internalTask.Metadata,
Tracking: convertTrackingToModel(internalTask.Tracking),
}, nil
// Return domain.Task directly (no conversion needed)
return internalTask, nil
}
// UpdateTask updates a task's status and metadata
func (tq *TaskQueue) UpdateTask(task *model.Task) error {
// Convert to internal task
internalTask := &queue.Task{
ID: task.ID,
JobName: task.JobName,
Args: task.Args,
Status: task.Status,
Priority: task.Priority,
CreatedAt: task.CreatedAt,
Metadata: task.Metadata,
Tracking: convertTrackingToInternal(task.Tracking),
}
return tq.internal.UpdateTask(internalTask)
func (tq *TaskQueue) UpdateTask(task *Task) error {
// task is already domain.Task, pass directly to internal queue
return tq.internal.UpdateTask(task)
}
// GetQueuedTasks retrieves all queued tasks
func (tq *TaskQueue) GetQueuedTasks() ([]*model.Task, error) {
func (tq *TaskQueue) GetQueuedTasks() ([]*Task, error) {
internalTasks, err := tq.internal.GetAllTasks()
if err != nil {
return nil, err
}
// Convert to TUI models
tasks := make([]*model.Task, len(internalTasks))
for i, task := range internalTasks {
tasks[i] = &model.Task{
ID: task.ID,
JobName: task.JobName,
Args: task.Args,
Status: task.Status,
Priority: task.Priority,
CreatedAt: task.CreatedAt,
Metadata: task.Metadata,
Tracking: convertTrackingToModel(task.Tracking),
}
}
return tasks, nil
// Return domain.Tasks directly (no conversion needed)
return internalTasks, nil
}
// GetJobStatus gets the status of all jobs with the given name
@@ -257,63 +209,3 @@ func NewMLServer(cfg *config.Config) (*MLServer, error) {
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
return &MLServer{SSHClient: client, addr: addr}, nil
}
func convertTrackingToModel(t *queue.TrackingConfig) *model.TrackingConfig {
if t == nil {
return nil
}
out := &model.TrackingConfig{}
if t.MLflow != nil {
out.MLflow = &model.MLflowTrackingConfig{
Enabled: t.MLflow.Enabled,
Mode: t.MLflow.Mode,
TrackingURI: t.MLflow.TrackingURI,
}
}
if t.TensorBoard != nil {
out.TensorBoard = &model.TensorBoardTrackingConfig{
Enabled: t.TensorBoard.Enabled,
Mode: t.TensorBoard.Mode,
}
}
if t.Wandb != nil {
out.Wandb = &model.WandbTrackingConfig{
Enabled: t.Wandb.Enabled,
Mode: t.Wandb.Mode,
APIKey: t.Wandb.APIKey,
Project: t.Wandb.Project,
Entity: t.Wandb.Entity,
}
}
return out
}
func convertTrackingToInternal(t *model.TrackingConfig) *queue.TrackingConfig {
if t == nil {
return nil
}
out := &queue.TrackingConfig{}
if t.MLflow != nil {
out.MLflow = &queue.MLflowTrackingConfig{
Enabled: t.MLflow.Enabled,
Mode: t.MLflow.Mode,
TrackingURI: t.MLflow.TrackingURI,
}
}
if t.TensorBoard != nil {
out.TensorBoard = &queue.TensorBoardTrackingConfig{
Enabled: t.TensorBoard.Enabled,
Mode: t.TensorBoard.Mode,
}
}
if t.Wandb != nil {
out.Wandb = &queue.WandbTrackingConfig{
Enabled: t.Wandb.Enabled,
Mode: t.Wandb.Mode,
APIKey: t.Wandb.APIKey,
Project: t.Wandb.Project,
Entity: t.Wandb.Entity,
}
}
return out
}


@@ -67,10 +67,11 @@ curl -f http://localhost:3100/ready
**Setup**:
```bash
# Run production setup script
sudo ./scripts/setup-monitoring-prod.sh /data/monitoring ml-user ml-group
# Set up monitoring provisioning (Grafana datasources/providers)
python3 scripts/setup_monitoring.py
# Start services
# Set up systemd services for production monitoring
# See systemd unit files in deployments/systemd/
sudo systemctl start prometheus loki promtail grafana
sudo systemctl enable prometheus loki promtail grafana
```


@@ -0,0 +1,9 @@
package domain
// DatasetSpec describes a dataset input with optional provenance fields.
type DatasetSpec struct {
Name string `json:"name"`
Version string `json:"version,omitempty"`
Checksum string `json:"checksum,omitempty"`
URI string `json:"uri,omitempty"`
}

internal/domain/doc.go (new file, 16 lines)

@@ -0,0 +1,16 @@
// Package domain provides core domain types for fetch_ml.
//
// This package contains the fundamental data structures used across the entire
// application. It has zero dependencies on other internal packages - only
// standard library imports are allowed.
//
// The types in this package represent:
// - Tasks and job execution (Task, Attempt)
// - Dataset specifications (DatasetSpec)
// - Experiment tracking configuration (TrackingConfig)
// - Job status enumeration (JobStatus)
// - Failure classification (FailureClass)
//
// Schema changes to these types will cause compile errors in all dependent
// packages, ensuring consistency across the codebase.
package domain

internal/domain/errors.go (new file, 150 lines)

@@ -0,0 +1,150 @@
package domain
import (
"os"
"strings"
"syscall"
)
// FailureClass represents the classification of a task failure
// Used to determine appropriate retry policy and user guidance
type FailureClass string
const (
FailureInfrastructure FailureClass = "infrastructure" // OOM kill, SIGKILL, node failure
FailureCode FailureClass = "code" // non-zero exit, exception, assertion
FailureData FailureClass = "data" // hash mismatch, dataset unreachable
FailureResource FailureClass = "resource" // GPU OOM, disk full, timeout
FailureUnknown FailureClass = "unknown" // cannot classify
)
// ClassifyFailure determines the failure class from exit signals, codes, and log output
func ClassifyFailure(exitCode int, signal os.Signal, logTail string) FailureClass {
logLower := strings.ToLower(logTail)
// Killed by OS — infrastructure failure
if signal == syscall.SIGKILL {
return FailureInfrastructure
}
// CUDA OOM or GPU resource issues
if strings.Contains(logLower, "cuda out of memory") ||
strings.Contains(logLower, "cuda error") ||
strings.Contains(logLower, "gpu oom") {
return FailureResource
}
// General OOM (non-GPU) — infrastructure
if strings.Contains(logLower, "out of memory") ||
strings.Contains(logLower, "oom") ||
strings.Contains(logLower, "cannot allocate memory") {
return FailureInfrastructure
}
// Dataset hash check failed — data failure
if strings.Contains(logLower, "hash mismatch") ||
strings.Contains(logLower, "checksum failed") ||
strings.Contains(logLower, "dataset not found") ||
strings.Contains(logLower, "dataset unreachable") {
return FailureData
}
// Disk/resource exhaustion
if strings.Contains(logLower, "no space left") ||
strings.Contains(logLower, "disk full") ||
strings.Contains(logLower, "disk quota exceeded") {
return FailureResource
}
// Timeout — resource (time budget exceeded)
if strings.Contains(logLower, "timeout") ||
strings.Contains(logLower, "deadline exceeded") ||
strings.Contains(logLower, "context deadline") {
return FailureResource
}
// Network issues — infrastructure
if strings.Contains(logLower, "connection refused") ||
strings.Contains(logLower, "connection reset") ||
strings.Contains(logLower, "no route to host") ||
strings.Contains(logLower, "network unreachable") {
return FailureInfrastructure
}
// Non-zero exit without specific signal — code failure
if exitCode != 0 {
return FailureCode
}
return FailureUnknown
}
// FailureInfo contains complete failure context for the manifest
type FailureInfo struct {
Class FailureClass `json:"class"`
ExitCode int `json:"exit_code,omitempty"`
Signal string `json:"signal,omitempty"`
LogTail string `json:"log_tail,omitempty"`
Suggestion string `json:"suggestion,omitempty"`
AutoRetried bool `json:"auto_retried,omitempty"`
RetryCount int `json:"retry_count,omitempty"`
RetryCap int `json:"retry_cap,omitempty"`
ClassifiedAt string `json:"classified_at,omitempty"`
Context map[string]string `json:"context,omitempty"`
}
// GetFailureSuggestion returns user guidance based on failure class
func GetFailureSuggestion(class FailureClass, logTail string) string {
switch class {
case FailureInfrastructure:
return "Infrastructure failure (node died, OOM kill). Auto-retry in progress."
case FailureCode:
return "Code error in training script. Fix before resubmitting."
case FailureData:
return "Data verification failed. Check dataset accessibility and hashes."
case FailureResource:
if strings.Contains(strings.ToLower(logTail), "cuda") {
return "GPU OOM. Increase --gpu-memory or use smaller batch size."
}
if strings.Contains(strings.ToLower(logTail), "disk") {
return "Disk full. Clean up storage or request more space."
}
return "Resource exhausted. Try with larger allocation or reduced load."
default:
return "Unknown failure. Review logs and contact support if persistent."
}
}
// ShouldAutoRetry determines if a failure class should auto-retry
// infrastructure: 3 retries transparent
// resource: 1 retry with backoff
// unknown: 1 retry (conservative - was retryable in old system)
// others: never auto-retry
func ShouldAutoRetry(class FailureClass, retryCount int) bool {
switch class {
case FailureInfrastructure:
return retryCount < 3
case FailureResource:
return retryCount < 1
case FailureUnknown:
// Unknown failures get 1 retry attempt (conservative, matches old behavior)
return retryCount < 1
default:
// code, data failures never auto-retry
return false
}
}
// RetryDelayForClass returns appropriate backoff for the failure class
func RetryDelayForClass(class FailureClass, retryCount int) int {
switch class {
case FailureInfrastructure:
// Immediate retry for infrastructure
return 0
case FailureResource:
// Short backoff for resource issues
return 30
default:
return 0
}
}
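The retry policy above can be exercised with a small self-contained sketch. Here `shouldAutoRetry` is an inlined copy of the policy shown in the diff (infrastructure: 3 transparent retries; resource and unknown: 1; code and data: never), not an import of the real package:

```go
package main

import "fmt"

type FailureClass string

const (
	FailureInfrastructure FailureClass = "infrastructure"
	FailureResource       FailureClass = "resource"
	FailureUnknown        FailureClass = "unknown"
)

// shouldAutoRetry mirrors domain.ShouldAutoRetry from the diff above.
func shouldAutoRetry(class FailureClass, retryCount int) bool {
	switch class {
	case FailureInfrastructure:
		return retryCount < 3
	case FailureResource, FailureUnknown:
		return retryCount < 1
	default:
		// code and data failures never auto-retry
		return false
	}
}

func main() {
	fmt.Println(shouldAutoRetry(FailureInfrastructure, 2)) // third retry still allowed
	fmt.Println(shouldAutoRetry(FailureResource, 1))       // resource cap already reached
}
```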

internal/domain/status.go (new file, 17 lines)

@@ -0,0 +1,17 @@
package domain
// JobStatus represents the status of a job
type JobStatus string
const (
StatusPending JobStatus = "pending"
StatusQueued JobStatus = "queued"
StatusRunning JobStatus = "running"
StatusCompleted JobStatus = "completed"
StatusFailed JobStatus = "failed"
)
// String returns the string representation of the status
func (s JobStatus) String() string {
return string(s)
}

internal/domain/task.go (new file, 71 lines)

@@ -0,0 +1,71 @@
// Package domain provides core domain types for fetch_ml.
// These types have zero internal dependencies and are used across all packages.
package domain
import (
"time"
)
// Task represents an ML experiment task
type Task struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"` // queued, running, completed, failed
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Output string `json:"output,omitempty"`
// TODO(phase1): SnapshotID is an opaque identifier only.
// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
SnapshotID string `json:"snapshot_id,omitempty"`
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
// Datasets is kept for backward compatibility (legacy callers).
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
// Resource requests (optional, 0 means unspecified)
CPU int `json:"cpu,omitempty"`
MemoryGB int `json:"memory_gb,omitempty"`
GPU int `json:"gpu,omitempty"`
GPUMemory string `json:"gpu_memory,omitempty"`
// User ownership and permissions
UserID string `json:"user_id"` // User who owns this task
Username string `json:"username"` // Username for display
CreatedBy string `json:"created_by"` // User who submitted the task
// Lease management for task resilience
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease
// Retry management
RetryCount int `json:"retry_count"` // Number of retry attempts made
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
LastError string `json:"last_error,omitempty"` // Last error encountered
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
// Attempt tracking - complete history of all execution attempts
Attempts []Attempt `json:"attempts,omitempty"`
// Optional tracking configuration for this task
Tracking *TrackingConfig `json:"tracking,omitempty"`
}
// Attempt represents a single execution attempt of a task
type Attempt struct {
Attempt int `json:"attempt"` // Attempt number (1-indexed)
StartedAt time.Time `json:"started_at"` // When attempt started
EndedAt *time.Time `json:"ended_at,omitempty"` // When attempt ended (if completed)
WorkerID string `json:"worker_id,omitempty"` // Which worker ran this attempt
Status string `json:"status"` // running, completed, failed
FailureClass FailureClass `json:"failure_class,omitempty"` // Failure classification (if failed)
ExitCode int `json:"exit_code,omitempty"` // Process exit code
Signal string `json:"signal,omitempty"` // Termination signal (if any)
Error string `json:"error,omitempty"` // Error message (if failed)
LogTail string `json:"log_tail,omitempty"` // Last N lines of log output
}


@@ -0,0 +1,30 @@
package domain
// TrackingConfig specifies experiment tracking tools to enable for a task.
type TrackingConfig struct {
MLflow *MLflowTrackingConfig `json:"mlflow,omitempty"`
TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
Wandb *WandbTrackingConfig `json:"wandb,omitempty"`
}
// MLflowTrackingConfig controls MLflow integration.
type MLflowTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` // "sidecar" | "remote" | "disabled"
TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
}
// TensorBoardTrackingConfig controls TensorBoard integration.
type TensorBoardTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` // "sidecar" | "disabled"
}
// WandbTrackingConfig controls Weights & Biases integration.
type WandbTrackingConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` // "remote" | "disabled"
APIKey string `json:"api_key,omitempty"`
Project string `json:"project,omitempty"`
Entity string `json:"entity,omitempty"`
}


@@ -2,284 +2,31 @@
package queue
import (
"errors"
"fmt"
"os"
"strings"
"syscall"
"github.com/jfraeys/fetch_ml/internal/domain"
)
// FailureClass represents the classification of a task failure
// Used to determine appropriate retry policy and user guidance
type FailureClass string
// Re-export from domain for backward compatibility
// Deprecated: Use internal/domain directly
type (
FailureClass = domain.FailureClass
FailureInfo = domain.FailureInfo
)
// Re-export functions from domain
// Deprecated: Use internal/domain directly
var (
ClassifyFailure = domain.ClassifyFailure
GetFailureSuggestion = domain.GetFailureSuggestion
ShouldAutoRetry = domain.ShouldAutoRetry
RetryDelayForClass = domain.RetryDelayForClass
)
// Re-export constants
// Deprecated: Use internal/domain directly
const (
FailureInfrastructure FailureClass = "infrastructure" // OOM kill, SIGKILL, node failure
FailureCode FailureClass = "code" // non-zero exit, exception, assertion
FailureData FailureClass = "data" // hash mismatch, dataset unreachable
FailureResource FailureClass = "resource" // GPU OOM, disk full, timeout
FailureUnknown FailureClass = "unknown" // cannot classify
FailureInfrastructure = domain.FailureInfrastructure
FailureCode = domain.FailureCode
FailureData = domain.FailureData
FailureResource = domain.FailureResource
FailureUnknown = domain.FailureUnknown
)
// ClassifyFailure determines the failure class from exit signals, codes, and log output
func ClassifyFailure(exitCode int, signal os.Signal, logTail string) FailureClass {
logLower := strings.ToLower(logTail)
// Killed by OS — infrastructure failure
if signal == syscall.SIGKILL {
return FailureInfrastructure
}
// CUDA OOM or GPU resource issues
if strings.Contains(logLower, "cuda out of memory") ||
strings.Contains(logLower, "cuda error") ||
strings.Contains(logLower, "gpu oom") {
return FailureResource
}
// General OOM (non-GPU) — infrastructure
if strings.Contains(logLower, "out of memory") ||
strings.Contains(logLower, "oom") ||
strings.Contains(logLower, "cannot allocate memory") {
return FailureInfrastructure
}
// Dataset hash check failed — data failure
if strings.Contains(logLower, "hash mismatch") ||
strings.Contains(logLower, "checksum failed") ||
strings.Contains(logLower, "dataset not found") ||
strings.Contains(logLower, "dataset unreachable") {
return FailureData
}
// Disk/resource exhaustion
if strings.Contains(logLower, "no space left") ||
strings.Contains(logLower, "disk full") ||
strings.Contains(logLower, "disk quota exceeded") {
return FailureResource
}
// Timeout — resource (time budget exceeded)
if strings.Contains(logLower, "timeout") ||
strings.Contains(logLower, "deadline exceeded") ||
strings.Contains(logLower, "context deadline") {
return FailureResource
}
// Network issues — infrastructure
if strings.Contains(logLower, "connection refused") ||
strings.Contains(logLower, "connection reset") ||
strings.Contains(logLower, "no route to host") ||
strings.Contains(logLower, "network unreachable") {
return FailureInfrastructure
}
// Non-zero exit without specific signal — code failure
if exitCode != 0 {
return FailureCode
}
return FailureUnknown
}
// FailureInfo contains complete failure context for the manifest
type FailureInfo struct {
Class FailureClass `json:"class"`
ExitCode int `json:"exit_code,omitempty"`
Signal string `json:"signal,omitempty"`
LogTail string `json:"log_tail,omitempty"`
Suggestion string `json:"suggestion,omitempty"`
AutoRetried bool `json:"auto_retried,omitempty"`
RetryCount int `json:"retry_count,omitempty"`
RetryCap int `json:"retry_cap,omitempty"`
ClassifiedAt string `json:"classified_at,omitempty"`
Context map[string]string `json:"context,omitempty"`
}
// GetFailureSuggestion returns user guidance based on failure class
func GetFailureSuggestion(class FailureClass, logTail string) string {
switch class {
case FailureInfrastructure:
return "Infrastructure failure (node died, OOM kill). Auto-retry in progress."
case FailureCode:
return "Code error in training script. Fix before resubmitting."
case FailureData:
return "Data verification failed. Check dataset accessibility and hashes."
case FailureResource:
if strings.Contains(strings.ToLower(logTail), "cuda") {
return "GPU OOM. Increase --gpu-memory or use smaller batch size."
}
if strings.Contains(strings.ToLower(logTail), "disk") {
return "Disk full. Clean up storage or request more space."
}
return "Resource exhausted. Try with larger allocation or reduced load."
default:
return "Unknown failure. Review logs and contact support if persistent."
}
}
// ShouldAutoRetry determines if a failure class should auto-retry
// infrastructure: 3 retries transparent
// resource: 1 retry with backoff
// unknown: 1 retry (conservative - was retryable in old system)
// others: never auto-retry
func ShouldAutoRetry(class FailureClass, retryCount int) bool {
switch class {
case FailureInfrastructure:
return retryCount < 3
case FailureResource:
return retryCount < 1
case FailureUnknown:
// Unknown failures get 1 retry attempt (conservative, matches old behavior)
return retryCount < 1
default:
// code, data failures never auto-retry
return false
}
}
// RetryDelayForClass returns appropriate backoff for the failure class
func RetryDelayForClass(class FailureClass, retryCount int) int {
switch class {
case FailureInfrastructure:
// Immediate retry for infrastructure
return 0
case FailureResource:
// Short backoff for resource issues
return 30
default:
return 0
}
}
// ErrorCategory represents the type of error encountered (DEPRECATED: use FailureClass)
type ErrorCategory string
// Error categories for task classification and retry logic
const (
ErrorNetwork ErrorCategory = "network" // Network connectivity issues
ErrorResource ErrorCategory = "resource" // Resource exhaustion (OOM, disk full)
ErrorRateLimit ErrorCategory = "rate_limit" // Rate limiting or throttling
ErrorAuth ErrorCategory = "auth" // Authentication/authorization failures
ErrorValidation ErrorCategory = "validation" // Input validation errors
ErrorTimeout ErrorCategory = "timeout" // Operation timeout
ErrorPermanent ErrorCategory = "permanent" // Non-retryable errors
ErrorUnknown ErrorCategory = "unknown" // Unclassified errors
)
// TaskError wraps an error with category and context
type TaskError struct {
Category ErrorCategory
Message string
Cause error
Context map[string]string
}
func (e *TaskError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("[%s] %s: %v", e.Category, e.Message, e.Cause)
}
return fmt.Sprintf("[%s] %s", e.Category, e.Message)
}
func (e *TaskError) Unwrap() error {
return e.Cause
}
// NewTaskError creates a new categorized error
func NewTaskError(category ErrorCategory, message string, cause error) *TaskError {
return &TaskError{
Category: category,
Message: message,
Cause: cause,
Context: make(map[string]string),
}
}
// ClassifyError categorizes an error for retry logic (DEPRECATED: use classifyFailure)
// This function now delegates to the more accurate classifyFailure
func ClassifyError(err error) ErrorCategory {
if err == nil {
return ErrorUnknown
}
// Check if already classified as TaskError
var taskErr *TaskError
if errors.As(err, &taskErr) {
return taskErr.Category
}
// Delegate to new FailureClass classification
failureClass := ClassifyFailure(0, nil, err.Error())
// Map FailureClass back to ErrorCategory for backward compatibility
switch failureClass {
	case FailureInfrastructure:
		return ErrorNetwork
	case FailureCode:
		return ErrorPermanent
	case FailureData:
		return ErrorValidation
	case FailureResource:
		return ErrorResource
	default:
		return ErrorUnknown
	}
}

// IsRetryable determines if an error category should be retried.
// Now delegates to ShouldAutoRetry with FailureClass mapping.
func IsRetryable(category ErrorCategory) bool {
	// Map ErrorCategory to FailureClass
	var failureClass FailureClass
	switch category {
	case ErrorNetwork:
		failureClass = FailureInfrastructure
	case ErrorResource, ErrorTimeout:
		failureClass = FailureResource
	case ErrorAuth, ErrorValidation, ErrorPermanent:
		failureClass = FailureCode
	default:
		failureClass = FailureUnknown
	}
	return ShouldAutoRetry(failureClass, 0)
}

// GetUserMessage returns a user-friendly error message with suggestions.
func GetUserMessage(category ErrorCategory, err error) string {
	messages := map[ErrorCategory]string{
		ErrorNetwork: "Network connectivity issue. Please check your network " +
			"connection and try again.",
		ErrorResource: "System resource exhausted. The system may be under heavy load. " +
			"Try again later or contact support.",
		ErrorRateLimit:  "Rate limit exceeded. Please wait a moment before retrying.",
		ErrorAuth:       "Authentication failed. Please check your API key or credentials.",
		ErrorValidation: "Invalid input. Please review your request and correct any errors.",
		ErrorTimeout: "Operation timed out. The task may be too complex or the system is slow. " +
			"Try again or simplify the request.",
		ErrorPermanent: "A permanent error occurred. This task cannot be retried automatically.",
		ErrorUnknown:   "An unexpected error occurred. If this persists, please contact support.",
	}

	baseMsg := messages[category]
	if err != nil {
		return fmt.Sprintf("%s (Details: %v)", baseMsg, err)
	}
	return baseMsg
}

// RetryDelay calculates the retry delay based on error category and retry count.
// Now delegates to RetryDelayForClass with FailureClass mapping.
func RetryDelay(category ErrorCategory, retryCount int) int {
	// Map ErrorCategory to FailureClass
	var failureClass FailureClass
	switch category {
	case ErrorNetwork:
		failureClass = FailureInfrastructure
	case ErrorResource, ErrorTimeout:
		failureClass = FailureResource
	default:
		failureClass = FailureUnknown
	}
	return RetryDelayForClass(failureClass, retryCount)
}
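The deleted shims above (`IsRetryable`, `RetryDelay`) share one ErrorCategory→FailureClass mapping, which is what the consolidation removes. A self-contained sketch of that mapping as a single function; the type and constant names mirror the diff, but the string values are illustrative, since the real declarations are not shown here:

```go
package main

import "fmt"

type ErrorCategory string
type FailureClass string

const (
	ErrorNetwork    ErrorCategory = "network"
	ErrorResource   ErrorCategory = "resource"
	ErrorTimeout    ErrorCategory = "timeout"
	ErrorAuth       ErrorCategory = "auth"
	ErrorValidation ErrorCategory = "validation"
	ErrorPermanent  ErrorCategory = "permanent"

	FailureInfrastructure FailureClass = "infrastructure"
	FailureResource       FailureClass = "resource"
	FailureCode           FailureClass = "code"
	FailureUnknown        FailureClass = "unknown"
)

// toFailureClass reproduces the switch duplicated in the deprecated shims:
// transient categories map to infrastructure/resource, caller mistakes to code.
func toFailureClass(c ErrorCategory) FailureClass {
	switch c {
	case ErrorNetwork:
		return FailureInfrastructure
	case ErrorResource, ErrorTimeout:
		return FailureResource
	case ErrorAuth, ErrorValidation, ErrorPermanent:
		return FailureCode
	default:
		return FailureUnknown
	}
}

func main() {
	fmt.Println(toFailureClass(ErrorNetwork)) // infrastructure
	fmt.Println(toFailureClass(ErrorTimeout)) // resource
}
```

Centralizing this switch in `internal/domain` is what lets both shims (and their callers) be deleted in one pass.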

queue/filesystem_queue.go

@@ -10,6 +10,8 @@ import (
 	"sort"
 	"strings"
 	"time"
+
+	"github.com/jfraeys/fetch_ml/internal/domain"
 )

 type FilesystemQueue struct {
@@ -177,13 +179,13 @@ func (q *FilesystemQueue) RetryTask(task *Task) error {
 		return q.MoveToDeadLetterQueue(task, "max retries exceeded")
 	}

-	errorCategory := ErrorUnknown
+	failureClass := domain.FailureUnknown
 	if task.Error != "" {
-		errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
+		failureClass = domain.ClassifyFailure(0, nil, task.Error)
 	}

-	if !IsRetryable(errorCategory) {
-		RecordDLQAddition(string(errorCategory))
-		return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
+	if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
+		RecordDLQAddition(string(failureClass))
+		return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
 	}
 	task.RetryCount++
@@ -191,13 +193,13 @@ func (q *FilesystemQueue) RetryTask(task *Task) error {
 	task.LastError = task.Error
 	task.Error = ""

-	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
+	backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
 	nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
 	task.NextRetry = &nextRetry
 	task.LeaseExpiry = nil
 	task.LeasedBy = ""

-	RecordTaskRetry(task.JobName, errorCategory)
+	RecordTaskRetry(task.JobName, failureClass)
 	return q.AddTask(task)
 }
@@ -207,7 +209,8 @@ func (q *FilesystemQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
 	}

 	task.Status = "failed"
 	task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
-	RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
+	failureClass := domain.ClassifyFailure(0, nil, task.LastError)
+	RecordTaskFailure(task.JobName, failureClass)
 	return q.UpdateTask(task)
 }
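The retry path above delegates delay computation to `domain.RetryDelayForClass`, whose implementation is not part of this diff. A plausible sketch of the assumed behavior — exponential backoff from a per-class base with a cap; the 30-second base and one-hour cap are illustrative, not taken from the codebase:

```go
package main

import (
	"fmt"
	"math"
)

// retryDelayForClass is a hypothetical stand-in for domain.RetryDelayForClass.
// Assumed: the delay doubles per attempt from a base, capped at one hour.
func retryDelayForClass(baseSeconds float64, retryCount int) int {
	const capSeconds = 3600.0
	delay := baseSeconds * math.Pow(2, float64(retryCount-1))
	return int(math.Min(delay, capSeconds))
}

func main() {
	// Delays for the first five attempts with a 30s base: 30, 60, 120, 240, 480.
	for attempt := 1; attempt <= 5; attempt++ {
		fmt.Println(retryDelayForClass(30, attempt))
	}
}
```

Whatever the actual curve, the point of the refactor is that all three queue backends now compute it through the same domain function instead of the deleted `RetryDelay` shim.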

queue/metrics.go

@@ -1,6 +1,7 @@
 package queue

 import (
+	"github.com/jfraeys/fetch_ml/internal/domain"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
 )
@@ -31,17 +32,17 @@
 		Help: "Total number of completed tasks",
 	}, []string{"job_name", "status"})

-	// TaskFailures tracks failed tasks by error category.
+	// TaskFailures tracks failed tasks by failure class.
 	TaskFailures = promauto.NewCounterVec(prometheus.CounterOpts{
 		Name: "fetch_ml_task_failures_total",
-		Help: "Total number of failed tasks by error category",
-	}, []string{"job_name", "error_category"})
+		Help: "Total number of failed tasks by failure class",
+	}, []string{"job_name", "failure_class"})

 	// TaskRetries tracks the total number of task retries.
 	TaskRetries = promauto.NewCounterVec(prometheus.CounterOpts{
 		Name: "fetch_ml_task_retries_total",
 		Help: "Total number of task retries",
-	}, []string{"job_name", "error_category"})
+	}, []string{"job_name", "failure_class"})

 	// LeaseExpirations tracks expired leases that were reclaimed.
 	LeaseExpirations = promauto.NewCounter(prometheus.CounterOpts{
@@ -92,14 +93,14 @@ func RecordTaskEnd(jobName, workerID, status string, durationSeconds float64) {
 	TasksCompleted.WithLabelValues(jobName, status).Inc()
 }

-// RecordTaskFailure records a task failure with error category
-func RecordTaskFailure(jobName string, errorCategory ErrorCategory) {
-	TaskFailures.WithLabelValues(jobName, string(errorCategory)).Inc()
+// RecordTaskFailure records a task failure with failure class
+func RecordTaskFailure(jobName string, failureClass domain.FailureClass) {
+	TaskFailures.WithLabelValues(jobName, string(failureClass)).Inc()
 }

 // RecordTaskRetry records a task retry
-func RecordTaskRetry(jobName string, errorCategory ErrorCategory) {
-	TaskRetries.WithLabelValues(jobName, string(errorCategory)).Inc()
+func RecordTaskRetry(jobName string, failureClass domain.FailureClass) {
+	TaskRetries.WithLabelValues(jobName, string(failureClass)).Inc()
 }

 // RecordLeaseExpiration records a lease expiration

queue/queue.go

@@ -7,6 +7,7 @@ import (
 	"strings"
 	"time"

+	"github.com/jfraeys/fetch_ml/internal/domain"
 	"github.com/redis/go-redis/v9"
 )
@@ -439,7 +440,7 @@ func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
 	return tq.UpdateTask(task)
 }

-// RetryTask re-queues a failed task with smart backoff based on error category
+// RetryTask re-queues a failed task with smart backoff based on failure class
 func (tq *TaskQueue) RetryTask(task *Task) error {
 	if task.RetryCount >= task.MaxRetries {
 		// Move to dead letter queue
@@ -448,15 +449,15 @@ func (tq *TaskQueue) RetryTask(task *Task) error {
 	}

 	// Classify the error if it exists
-	errorCategory := ErrorUnknown
+	failureClass := domain.FailureUnknown
 	if task.Error != "" {
-		errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
+		failureClass = domain.ClassifyFailure(0, nil, task.Error)
 	}

 	// Check if error is retryable
-	if !IsRetryable(errorCategory) {
-		RecordDLQAddition(string(errorCategory))
-		return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
+	if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
+		RecordDLQAddition(string(failureClass))
+		return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
 	}

 	task.RetryCount++
@@ -464,8 +465,8 @@ func (tq *TaskQueue) RetryTask(task *Task) error {
 	task.LastError = task.Error // Preserve last error
 	task.Error = ""             // Clear current error

-	// Calculate smart backoff based on error category
-	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
+	// Calculate smart backoff based on failure class
+	backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
 	nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
 	task.NextRetry = &nextRetry
@@ -474,7 +475,7 @@ func (tq *TaskQueue) RetryTask(task *Task) error {
 	task.LeasedBy = ""

 	// Record retry metrics
-	RecordTaskRetry(task.JobName, errorCategory)
+	RecordTaskRetry(task.JobName, failureClass)

 	// Re-queue with same priority
 	return tq.AddTask(task)
@@ -493,8 +494,9 @@ func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
 	// Store in dead letter queue with timestamp
 	key := "task:dlq:" + task.ID

-	// Record metrics
-	RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
+	// Record metrics using domain.FailureClass
+	failureClass := domain.ClassifyFailure(0, nil, task.LastError)
+	RecordTaskFailure(task.JobName, failureClass)

 	pipe := tq.client.Pipeline()
 	pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour)

queue/sqlite_queue.go

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"time"

+	"github.com/jfraeys/fetch_ml/internal/domain"
 	_ "github.com/mattn/go-sqlite3"
 )
@@ -495,13 +496,13 @@ func (q *SQLiteQueue) RetryTask(task *Task) error {
 		return q.MoveToDeadLetterQueue(task, "max retries exceeded")
 	}

-	errorCategory := ErrorUnknown
+	failureClass := domain.FailureUnknown
 	if task.Error != "" {
-		errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
+		failureClass = domain.ClassifyFailure(0, nil, task.Error)
 	}

-	if !IsRetryable(errorCategory) {
-		RecordDLQAddition(string(errorCategory))
-		return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
+	if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
+		RecordDLQAddition(string(failureClass))
+		return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
 	}
 	task.RetryCount++
@@ -509,13 +510,13 @@ func (q *SQLiteQueue) RetryTask(task *Task) error {
 	task.LastError = task.Error
 	task.Error = ""

-	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
+	backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
 	nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
 	task.NextRetry = &nextRetry
 	task.LeaseExpiry = nil
 	task.LeasedBy = ""

-	RecordTaskRetry(task.JobName, errorCategory)
+	RecordTaskRetry(task.JobName, failureClass)
 	return q.AddTask(task)
 }
@@ -523,7 +524,8 @@ func (q *SQLiteQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
 	task.Status = "failed"
 	task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
-	RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
+	failureClass := domain.ClassifyFailure(0, nil, task.LastError)
+	RecordTaskFailure(task.JobName, failureClass)

 	return q.UpdateTask(task)
 }

queue/task.go

@@ -1,111 +1,21 @@
 package queue

 import (
-	"time"
-
-	"github.com/jfraeys/fetch_ml/internal/config"
+	"github.com/jfraeys/fetch_ml/internal/domain"
 )

-// DatasetSpec describes a dataset input with optional provenance fields.
-type DatasetSpec struct {
-	Name     string `json:"name"`
-	Version  string `json:"version,omitempty"`
-	Checksum string `json:"checksum,omitempty"`
-	URI      string `json:"uri,omitempty"`
-}
-
-// Task represents an ML experiment task
-type Task struct {
-	ID        string     `json:"id"`
-	JobName   string     `json:"job_name"`
-	Args      string     `json:"args"`
-	Status    string     `json:"status"` // queued, running, completed, failed
-	Priority  int64      `json:"priority"`
-	CreatedAt time.Time  `json:"created_at"`
-	StartedAt *time.Time `json:"started_at,omitempty"`
-	EndedAt   *time.Time `json:"ended_at,omitempty"`
-	WorkerID  string     `json:"worker_id,omitempty"`
-	Error     string     `json:"error,omitempty"`
-	Output    string     `json:"output,omitempty"`
-
-	// TODO(phase1): SnapshotID is an opaque identifier only.
-	// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
-	SnapshotID string `json:"snapshot_id,omitempty"`
-
-	// DatasetSpecs is the preferred structured dataset input and should be authoritative.
-	DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
-	// Datasets is kept for backward compatibility (legacy callers).
-	Datasets []string          `json:"datasets,omitempty"`
-	Metadata map[string]string `json:"metadata,omitempty"`
-
-	// Resource requests (optional, 0 means unspecified)
-	CPU       int    `json:"cpu,omitempty"`
-	MemoryGB  int    `json:"memory_gb,omitempty"`
-	GPU       int    `json:"gpu,omitempty"`
-	GPUMemory string `json:"gpu_memory,omitempty"`
-
-	// User ownership and permissions
-	UserID    string `json:"user_id"`    // User who owns this task
-	Username  string `json:"username"`   // Username for display
-	CreatedBy string `json:"created_by"` // User who submitted the task
-
-	// Lease management for task resilience
-	LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
-	LeasedBy    string     `json:"leased_by,omitempty"`    // Worker ID holding lease
-
-	// Retry management
-	RetryCount int        `json:"retry_count"`          // Number of retry attempts made
-	MaxRetries int        `json:"max_retries"`          // Maximum retry limit (default 3)
-	LastError  string     `json:"last_error,omitempty"` // Last error encountered
-	NextRetry  *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
-
-	// Attempt tracking - complete history of all execution attempts
-	Attempts []Attempt `json:"attempts,omitempty"`
-
-	// Optional tracking configuration for this task
-	Tracking *TrackingConfig `json:"tracking,omitempty"`
-}
-
-// Attempt represents a single execution attempt of a task
-type Attempt struct {
-	Attempt      int          `json:"attempt"`                 // Attempt number (1-indexed)
-	StartedAt    time.Time    `json:"started_at"`              // When attempt started
-	EndedAt      *time.Time   `json:"ended_at,omitempty"`      // When attempt ended (if completed)
-	WorkerID     string       `json:"worker_id,omitempty"`     // Which worker ran this attempt
-	Status       string       `json:"status"`                  // running, completed, failed
-	FailureClass FailureClass `json:"failure_class,omitempty"` // Failure classification (if failed)
-	ExitCode     int          `json:"exit_code,omitempty"`     // Process exit code
-	Signal       string       `json:"signal,omitempty"`        // Termination signal (if any)
-	Error        string       `json:"error,omitempty"`         // Error message (if failed)
-	LogTail      string       `json:"log_tail,omitempty"`      // Last N lines of log output
-}
-
-// TrackingConfig specifies experiment tracking tools to enable for a task.
-type TrackingConfig struct {
-	MLflow      *MLflowTrackingConfig      `json:"mlflow,omitempty"`
-	TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
-	Wandb       *WandbTrackingConfig       `json:"wandb,omitempty"`
-}
-
-// MLflowTrackingConfig controls MLflow integration.
-type MLflowTrackingConfig struct {
-	Enabled     bool   `json:"enabled"`
-	Mode        string `json:"mode,omitempty"`         // "sidecar" | "remote" | "disabled"
-	TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
-}
-
-// TensorBoardTrackingConfig controls TensorBoard integration.
-type TensorBoardTrackingConfig struct {
-	Enabled bool   `json:"enabled"`
-	Mode    string `json:"mode,omitempty"` // "sidecar" | "disabled"
-}
-
-// WandbTrackingConfig controls Weights & Biases integration.
-type WandbTrackingConfig struct {
-	Enabled bool   `json:"enabled"`
-	Mode    string `json:"mode,omitempty"` // "remote" | "disabled"
-	APIKey  string `json:"api_key,omitempty"`
-	Project string `json:"project,omitempty"`
-	Entity  string `json:"entity,omitempty"`
-}
+// Re-export domain types for backward compatibility
+// Deprecated: Use internal/domain directly
+type (
+	Task                      = domain.Task
+	Attempt                   = domain.Attempt
+	DatasetSpec               = domain.DatasetSpec
+	TrackingConfig            = domain.TrackingConfig
+	MLflowTrackingConfig      = domain.MLflowTrackingConfig
+	TensorBoardTrackingConfig = domain.TensorBoardTrackingConfig
+	WandbTrackingConfig       = domain.WandbTrackingConfig
+)
// Redis key constants
var (