refactor: extract domain types and consolidate error system (Phases 1-2)
Phase 1: Extract Domain Types ============================= - Create internal/domain/ package with canonical types: - domain/task.go: Task, Attempt structs - domain/tracking.go: TrackingConfig and MLflow/TensorBoard/Wandb configs - domain/dataset.go: DatasetSpec - domain/status.go: JobStatus constants - domain/errors.go: FailureClass system with classification functions - domain/doc.go: package documentation - Update queue/task.go to re-export domain types (backward compatibility) - Update TUI model/state.go to use domain types via type aliases - Simplify TUI services: remove ~60 lines of conversion functions Phase 2: Delete ErrorCategory System ==================================== - Remove deprecated ErrorCategory type and constants - Remove TaskError struct and related functions - Remove mapping functions: ClassifyError, IsRetryable, GetUserMessage, RetryDelay - Update all queue implementations to use domain.FailureClass directly: - queue/metrics.go: RecordTaskFailure/Retry now take FailureClass - queue/queue.go: RetryTask uses domain.ClassifyFailure - queue/filesystem_queue.go: RetryTask and MoveToDeadLetterQueue updated - queue/sqlite_queue.go: RetryTask and MoveToDeadLetterQueue updated Lines eliminated: ~190 lines of conversion and mapping code Result: Single source of truth for domain types and error classification
This commit is contained in:
parent
e286fd7769
commit
6580917ba8
16 changed files with 428 additions and 603 deletions
31
SECURITY.md
31
SECURITY.md
|
|
@ -4,17 +4,12 @@ This guide covers security best practices for deploying Fetch ML in a homelab en
|
|||
|
||||
## Quick Setup
|
||||
|
||||
Run the secure setup script:
|
||||
Secure setup requires manual configuration:
|
||||
|
||||
```bash
|
||||
./scripts/setup-secure-homelab.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Generate secure API keys
|
||||
- Create TLS certificates
|
||||
- Set up secure configuration
|
||||
- Create environment files with proper permissions
|
||||
1. **Generate API keys**: Use the instructions in [API Security](#api-security) below
|
||||
2. **Create TLS certificates**: Use OpenSSL commands in [Troubleshooting](#troubleshooting)
|
||||
3. **Configure security**: Copy and edit `configs/api/homelab-secure.yaml`
|
||||
4. **Set permissions**: Ensure `.api-keys` and `.env.secure` have 600 permissions
|
||||
|
||||
## Security Features
|
||||
|
||||
|
|
@ -54,24 +49,30 @@ This will:
|
|||
## Deployment Options
|
||||
|
||||
### Option 1: Docker Compose (Recommended)
|
||||
|
||||
```bash
|
||||
# Generate secure setup
|
||||
./scripts/setup-secure-homelab.sh
|
||||
# Configure secure setup manually (see Quick Setup above)
|
||||
# Copy and edit the secure configuration
|
||||
cp configs/api/homelab-secure.yaml configs/api/my-secure.yaml
|
||||
# Edit with your API keys, TLS settings, and IP whitelist
|
||||
|
||||
# Deploy with security overlay
|
||||
docker-compose -f docker-compose.yml -f docker-compose.homelab-secure.yml up -d
|
||||
```
|
||||
|
||||
### Option 2: Direct Deployment
|
||||
|
||||
```bash
|
||||
# Generate secure setup
|
||||
./scripts/setup-secure-homelab.sh
|
||||
# Configure secure setup manually (see Quick Setup above)
|
||||
# Copy and edit the secure configuration
|
||||
cp configs/api/homelab-secure.yaml configs/api/my-secure.yaml
|
||||
# Edit with your API keys, TLS settings, and IP whitelist
|
||||
|
||||
# Load environment variables
|
||||
source .env.secure
|
||||
|
||||
# Start server
|
||||
./api-server -config configs/api/homelab-secure.yaml
|
||||
./api-server -config configs/api/my-secure.yaml
|
||||
```
|
||||
|
||||
## Security Checklist
|
||||
|
|
|
|||
|
|
@ -11,8 +11,25 @@ import (
|
|||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
)
|
||||
|
||||
// Re-export domain types for TUI use
|
||||
// Task represents a task in the TUI
|
||||
type Task = domain.Task
|
||||
|
||||
// TrackingConfig specifies experiment tracking tools
|
||||
type TrackingConfig = domain.TrackingConfig
|
||||
|
||||
// MLflowTrackingConfig controls MLflow integration
|
||||
type MLflowTrackingConfig = domain.MLflowTrackingConfig
|
||||
|
||||
// TensorBoardTrackingConfig controls TensorBoard integration
|
||||
type TensorBoardTrackingConfig = domain.TensorBoardTrackingConfig
|
||||
|
||||
// WandbTrackingConfig controls Weights & Biases integration
|
||||
type WandbTrackingConfig = domain.WandbTrackingConfig
|
||||
|
||||
// ViewMode represents the current view mode in the TUI
|
||||
type ViewMode int
|
||||
|
||||
|
|
@ -69,50 +86,6 @@ func (j Job) Description() string {
|
|||
// FilterValue returns the value used for filtering
|
||||
func (j Job) FilterValue() string { return j.Name }
|
||||
|
||||
// Task represents a task in the TUI
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"`
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Tracking *TrackingConfig `json:"tracking,omitempty"`
|
||||
}
|
||||
|
||||
// TrackingConfig specifies experiment tracking tools
|
||||
type TrackingConfig struct {
|
||||
MLflow *MLflowTrackingConfig `json:"mlflow,omitempty"`
|
||||
TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
|
||||
Wandb *WandbTrackingConfig `json:"wandb,omitempty"`
|
||||
}
|
||||
|
||||
// MLflowTrackingConfig controls MLflow integration
|
||||
type MLflowTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
TrackingURI string `json:"tracking_uri,omitempty"`
|
||||
}
|
||||
|
||||
// TensorBoardTrackingConfig controls TensorBoard integration
|
||||
type TensorBoardTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
}
|
||||
|
||||
// WandbTrackingConfig controls Weights & Biases integration
|
||||
type WandbTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Entity string `json:"entity,omitempty"`
|
||||
}
|
||||
|
||||
// DatasetInfo represents dataset information in the TUI
|
||||
type DatasetInfo struct {
|
||||
Name string `json:"name"`
|
||||
|
|
|
|||
|
|
@ -7,11 +7,15 @@ import (
|
|||
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// Task is an alias for domain.Task for TUI compatibility
|
||||
type Task = domain.Task
|
||||
|
||||
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility
|
||||
type TaskQueue struct {
|
||||
internal *queue.TaskQueue
|
||||
|
|
@ -45,7 +49,7 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
|
|||
}
|
||||
|
||||
// EnqueueTask adds a new task to the queue
|
||||
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.Task, error) {
|
||||
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) {
|
||||
// Create internal task
|
||||
internalTask := &queue.Task{
|
||||
JobName: jobName,
|
||||
|
|
@ -59,21 +63,12 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.T
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to TUI model
|
||||
return &model.Task{
|
||||
ID: internalTask.ID,
|
||||
JobName: internalTask.JobName,
|
||||
Args: internalTask.Args,
|
||||
Status: "queued",
|
||||
Priority: internalTask.Priority,
|
||||
CreatedAt: internalTask.CreatedAt,
|
||||
Metadata: internalTask.Metadata,
|
||||
Tracking: convertTrackingToModel(internalTask.Tracking),
|
||||
}, nil
|
||||
// Return domain.Task directly (no conversion needed)
|
||||
return internalTask, nil
|
||||
}
|
||||
|
||||
// GetNextTask retrieves the next task from the queue
|
||||
func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
|
||||
func (tq *TaskQueue) GetNextTask() (*Task, error) {
|
||||
internalTask, err := tq.internal.GetNextTask()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -82,79 +77,36 @@ func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Convert to TUI model
|
||||
return &model.Task{
|
||||
ID: internalTask.ID,
|
||||
JobName: internalTask.JobName,
|
||||
Args: internalTask.Args,
|
||||
Status: internalTask.Status,
|
||||
Priority: internalTask.Priority,
|
||||
CreatedAt: internalTask.CreatedAt,
|
||||
Metadata: internalTask.Metadata,
|
||||
Tracking: convertTrackingToModel(internalTask.Tracking),
|
||||
}, nil
|
||||
// Return domain.Task directly (no conversion needed)
|
||||
return internalTask, nil
|
||||
}
|
||||
|
||||
// GetTask retrieves a specific task by ID
|
||||
func (tq *TaskQueue) GetTask(taskID string) (*model.Task, error) {
|
||||
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
|
||||
internalTask, err := tq.internal.GetTask(taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to TUI model
|
||||
return &model.Task{
|
||||
ID: internalTask.ID,
|
||||
JobName: internalTask.JobName,
|
||||
Args: internalTask.Args,
|
||||
Status: internalTask.Status,
|
||||
Priority: internalTask.Priority,
|
||||
CreatedAt: internalTask.CreatedAt,
|
||||
Metadata: internalTask.Metadata,
|
||||
Tracking: convertTrackingToModel(internalTask.Tracking),
|
||||
}, nil
|
||||
// Return domain.Task directly (no conversion needed)
|
||||
return internalTask, nil
|
||||
}
|
||||
|
||||
// UpdateTask updates a task's status and metadata
|
||||
func (tq *TaskQueue) UpdateTask(task *model.Task) error {
|
||||
// Convert to internal task
|
||||
internalTask := &queue.Task{
|
||||
ID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Args: task.Args,
|
||||
Status: task.Status,
|
||||
Priority: task.Priority,
|
||||
CreatedAt: task.CreatedAt,
|
||||
Metadata: task.Metadata,
|
||||
Tracking: convertTrackingToInternal(task.Tracking),
|
||||
}
|
||||
|
||||
return tq.internal.UpdateTask(internalTask)
|
||||
func (tq *TaskQueue) UpdateTask(task *Task) error {
|
||||
// task is already domain.Task, pass directly to internal queue
|
||||
return tq.internal.UpdateTask(task)
|
||||
}
|
||||
|
||||
// GetQueuedTasks retrieves all queued tasks
|
||||
func (tq *TaskQueue) GetQueuedTasks() ([]*model.Task, error) {
|
||||
func (tq *TaskQueue) GetQueuedTasks() ([]*Task, error) {
|
||||
internalTasks, err := tq.internal.GetAllTasks()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to TUI models
|
||||
tasks := make([]*model.Task, len(internalTasks))
|
||||
for i, task := range internalTasks {
|
||||
tasks[i] = &model.Task{
|
||||
ID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Args: task.Args,
|
||||
Status: task.Status,
|
||||
Priority: task.Priority,
|
||||
CreatedAt: task.CreatedAt,
|
||||
Metadata: task.Metadata,
|
||||
Tracking: convertTrackingToModel(task.Tracking),
|
||||
}
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
// Return domain.Tasks directly (no conversion needed)
|
||||
return internalTasks, nil
|
||||
}
|
||||
|
||||
// GetJobStatus gets the status of all jobs with the given name
|
||||
|
|
@ -257,63 +209,3 @@ func NewMLServer(cfg *config.Config) (*MLServer, error) {
|
|||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
return &MLServer{SSHClient: client, addr: addr}, nil
|
||||
}
|
||||
|
||||
func convertTrackingToModel(t *queue.TrackingConfig) *model.TrackingConfig {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
out := &model.TrackingConfig{}
|
||||
if t.MLflow != nil {
|
||||
out.MLflow = &model.MLflowTrackingConfig{
|
||||
Enabled: t.MLflow.Enabled,
|
||||
Mode: t.MLflow.Mode,
|
||||
TrackingURI: t.MLflow.TrackingURI,
|
||||
}
|
||||
}
|
||||
if t.TensorBoard != nil {
|
||||
out.TensorBoard = &model.TensorBoardTrackingConfig{
|
||||
Enabled: t.TensorBoard.Enabled,
|
||||
Mode: t.TensorBoard.Mode,
|
||||
}
|
||||
}
|
||||
if t.Wandb != nil {
|
||||
out.Wandb = &model.WandbTrackingConfig{
|
||||
Enabled: t.Wandb.Enabled,
|
||||
Mode: t.Wandb.Mode,
|
||||
APIKey: t.Wandb.APIKey,
|
||||
Project: t.Wandb.Project,
|
||||
Entity: t.Wandb.Entity,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func convertTrackingToInternal(t *model.TrackingConfig) *queue.TrackingConfig {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
out := &queue.TrackingConfig{}
|
||||
if t.MLflow != nil {
|
||||
out.MLflow = &queue.MLflowTrackingConfig{
|
||||
Enabled: t.MLflow.Enabled,
|
||||
Mode: t.MLflow.Mode,
|
||||
TrackingURI: t.MLflow.TrackingURI,
|
||||
}
|
||||
}
|
||||
if t.TensorBoard != nil {
|
||||
out.TensorBoard = &queue.TensorBoardTrackingConfig{
|
||||
Enabled: t.TensorBoard.Enabled,
|
||||
Mode: t.TensorBoard.Mode,
|
||||
}
|
||||
}
|
||||
if t.Wandb != nil {
|
||||
out.Wandb = &queue.WandbTrackingConfig{
|
||||
Enabled: t.Wandb.Enabled,
|
||||
Mode: t.Wandb.Mode,
|
||||
APIKey: t.Wandb.APIKey,
|
||||
Project: t.Wandb.Project,
|
||||
Entity: t.Wandb.Entity,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,10 +67,11 @@ curl -f http://localhost:3100/ready
|
|||
|
||||
**Setup**:
|
||||
```bash
|
||||
# Run production setup script
|
||||
sudo ./scripts/setup-monitoring-prod.sh /data/monitoring ml-user ml-group
|
||||
# Set up monitoring provisioning (Grafana datasources/providers)
|
||||
python3 scripts/setup_monitoring.py
|
||||
|
||||
# Start services
|
||||
# Set up systemd services for production monitoring
|
||||
# See systemd unit files in deployments/systemd/
|
||||
sudo systemctl start prometheus loki promtail grafana
|
||||
sudo systemctl enable prometheus loki promtail grafana
|
||||
```
|
||||
|
|
|
|||
9
internal/domain/dataset.go
Normal file
9
internal/domain/dataset.go
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
package domain
|
||||
|
||||
// DatasetSpec describes a dataset input with optional provenance fields.
|
||||
type DatasetSpec struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Checksum string `json:"checksum,omitempty"`
|
||||
URI string `json:"uri,omitempty"`
|
||||
}
|
||||
16
internal/domain/doc.go
Normal file
16
internal/domain/doc.go
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Package domain provides core domain types for fetch_ml.
|
||||
//
|
||||
// This package contains the fundamental data structures used across the entire
|
||||
// application. It has zero dependencies on other internal packages - only
|
||||
// standard library imports are allowed.
|
||||
//
|
||||
// The types in this package represent:
|
||||
// - Tasks and job execution (Task, Attempt)
|
||||
// - Dataset specifications (DatasetSpec)
|
||||
// - Experiment tracking configuration (TrackingConfig)
|
||||
// - Job status enumeration (JobStatus)
|
||||
// - Failure classification (FailureClass)
|
||||
//
|
||||
// Schema changes to these types will cause compile errors in all dependent
|
||||
// packages, ensuring consistency across the codebase.
|
||||
package domain
|
||||
150
internal/domain/errors.go
Normal file
150
internal/domain/errors.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// FailureClass represents the classification of a task failure
|
||||
// Used to determine appropriate retry policy and user guidance
|
||||
type FailureClass string
|
||||
|
||||
const (
|
||||
FailureInfrastructure FailureClass = "infrastructure" // OOM kill, SIGKILL, node failure
|
||||
FailureCode FailureClass = "code" // non-zero exit, exception, assertion
|
||||
FailureData FailureClass = "data" // hash mismatch, dataset unreachable
|
||||
FailureResource FailureClass = "resource" // GPU OOM, disk full, timeout
|
||||
FailureUnknown FailureClass = "unknown" // cannot classify
|
||||
)
|
||||
|
||||
// ClassifyFailure determines the failure class from exit signals, codes, and log output
|
||||
func ClassifyFailure(exitCode int, signal os.Signal, logTail string) FailureClass {
|
||||
logLower := strings.ToLower(logTail)
|
||||
|
||||
// Killed by OS — infrastructure failure
|
||||
if signal == syscall.SIGKILL {
|
||||
return FailureInfrastructure
|
||||
}
|
||||
|
||||
// CUDA OOM or GPU resource issues
|
||||
if strings.Contains(logLower, "cuda out of memory") ||
|
||||
strings.Contains(logLower, "cuda error") ||
|
||||
strings.Contains(logLower, "gpu oom") {
|
||||
return FailureResource
|
||||
}
|
||||
|
||||
// General OOM (non-GPU) — infrastructure
|
||||
if strings.Contains(logLower, "out of memory") ||
|
||||
strings.Contains(logLower, "oom") ||
|
||||
strings.Contains(logLower, "cannot allocate memory") {
|
||||
return FailureInfrastructure
|
||||
}
|
||||
|
||||
// Dataset hash check failed — data failure
|
||||
if strings.Contains(logLower, "hash mismatch") ||
|
||||
strings.Contains(logLower, "checksum failed") ||
|
||||
strings.Contains(logLower, "dataset not found") ||
|
||||
strings.Contains(logLower, "dataset unreachable") {
|
||||
return FailureData
|
||||
}
|
||||
|
||||
// Disk/resource exhaustion
|
||||
if strings.Contains(logLower, "no space left") ||
|
||||
strings.Contains(logLower, "disk full") ||
|
||||
strings.Contains(logLower, "disk quota exceeded") {
|
||||
return FailureResource
|
||||
}
|
||||
|
||||
// Timeout — resource (time budget exceeded)
|
||||
if strings.Contains(logLower, "timeout") ||
|
||||
strings.Contains(logLower, "deadline exceeded") ||
|
||||
strings.Contains(logLower, "context deadline") {
|
||||
return FailureResource
|
||||
}
|
||||
|
||||
// Network issues — infrastructure
|
||||
if strings.Contains(logLower, "connection refused") ||
|
||||
strings.Contains(logLower, "connection reset") ||
|
||||
strings.Contains(logLower, "no route to host") ||
|
||||
strings.Contains(logLower, "network unreachable") {
|
||||
return FailureInfrastructure
|
||||
}
|
||||
|
||||
// Non-zero exit without specific signal — code failure
|
||||
if exitCode != 0 {
|
||||
return FailureCode
|
||||
}
|
||||
|
||||
return FailureUnknown
|
||||
}
|
||||
|
||||
// FailureInfo contains complete failure context for the manifest
|
||||
type FailureInfo struct {
|
||||
Class FailureClass `json:"class"`
|
||||
ExitCode int `json:"exit_code,omitempty"`
|
||||
Signal string `json:"signal,omitempty"`
|
||||
LogTail string `json:"log_tail,omitempty"`
|
||||
Suggestion string `json:"suggestion,omitempty"`
|
||||
AutoRetried bool `json:"auto_retried,omitempty"`
|
||||
RetryCount int `json:"retry_count,omitempty"`
|
||||
RetryCap int `json:"retry_cap,omitempty"`
|
||||
ClassifiedAt string `json:"classified_at,omitempty"`
|
||||
Context map[string]string `json:"context,omitempty"`
|
||||
}
|
||||
|
||||
// GetFailureSuggestion returns user guidance based on failure class
|
||||
func GetFailureSuggestion(class FailureClass, logTail string) string {
|
||||
switch class {
|
||||
case FailureInfrastructure:
|
||||
return "Infrastructure failure (node died, OOM kill). Auto-retry in progress."
|
||||
case FailureCode:
|
||||
return "Code error in training script. Fix before resubmitting."
|
||||
case FailureData:
|
||||
return "Data verification failed. Check dataset accessibility and hashes."
|
||||
case FailureResource:
|
||||
if strings.Contains(strings.ToLower(logTail), "cuda") {
|
||||
return "GPU OOM. Increase --gpu-memory or use smaller batch size."
|
||||
}
|
||||
if strings.Contains(strings.ToLower(logTail), "disk") {
|
||||
return "Disk full. Clean up storage or request more space."
|
||||
}
|
||||
return "Resource exhausted. Try with larger allocation or reduced load."
|
||||
default:
|
||||
return "Unknown failure. Review logs and contact support if persistent."
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldAutoRetry determines if a failure class should auto-retry
|
||||
// infrastructure: 3 retries transparent
|
||||
// resource: 1 retry with backoff
|
||||
// unknown: 1 retry (conservative - was retryable in old system)
|
||||
// others: never auto-retry
|
||||
func ShouldAutoRetry(class FailureClass, retryCount int) bool {
|
||||
switch class {
|
||||
case FailureInfrastructure:
|
||||
return retryCount < 3
|
||||
case FailureResource:
|
||||
return retryCount < 1
|
||||
case FailureUnknown:
|
||||
// Unknown failures get 1 retry attempt (conservative, matches old behavior)
|
||||
return retryCount < 1
|
||||
default:
|
||||
// code, data failures never auto-retry
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RetryDelayForClass returns appropriate backoff for the failure class
|
||||
func RetryDelayForClass(class FailureClass, retryCount int) int {
|
||||
switch class {
|
||||
case FailureInfrastructure:
|
||||
// Immediate retry for infrastructure
|
||||
return 0
|
||||
case FailureResource:
|
||||
// Short backoff for resource issues
|
||||
return 30
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
17
internal/domain/status.go
Normal file
17
internal/domain/status.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
package domain
|
||||
|
||||
// JobStatus represents the status of a job
|
||||
type JobStatus string
|
||||
|
||||
const (
|
||||
StatusPending JobStatus = "pending"
|
||||
StatusQueued JobStatus = "queued"
|
||||
StatusRunning JobStatus = "running"
|
||||
StatusCompleted JobStatus = "completed"
|
||||
StatusFailed JobStatus = "failed"
|
||||
)
|
||||
|
||||
// String returns the string representation of the status
|
||||
func (s JobStatus) String() string {
|
||||
return string(s)
|
||||
}
|
||||
71
internal/domain/task.go
Normal file
71
internal/domain/task.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
// Package domain provides core domain types for fetch_ml.
|
||||
// These types have zero internal dependencies and are used across all packages.
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task represents an ML experiment task
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"` // queued, running, completed, failed
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
// TODO(phase1): SnapshotID is an opaque identifier only.
|
||||
// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
|
||||
SnapshotID string `json:"snapshot_id,omitempty"`
|
||||
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
|
||||
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
// Datasets is kept for backward compatibility (legacy callers).
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
// Resource requests (optional, 0 means unspecified)
|
||||
CPU int `json:"cpu,omitempty"`
|
||||
MemoryGB int `json:"memory_gb,omitempty"`
|
||||
GPU int `json:"gpu,omitempty"`
|
||||
GPUMemory string `json:"gpu_memory,omitempty"`
|
||||
|
||||
// User ownership and permissions
|
||||
UserID string `json:"user_id"` // User who owns this task
|
||||
Username string `json:"username"` // Username for display
|
||||
CreatedBy string `json:"created_by"` // User who submitted the task
|
||||
|
||||
// Lease management for task resilience
|
||||
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
|
||||
LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease
|
||||
|
||||
// Retry management
|
||||
RetryCount int `json:"retry_count"` // Number of retry attempts made
|
||||
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
|
||||
LastError string `json:"last_error,omitempty"` // Last error encountered
|
||||
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
|
||||
|
||||
// Attempt tracking - complete history of all execution attempts
|
||||
Attempts []Attempt `json:"attempts,omitempty"`
|
||||
|
||||
// Optional tracking configuration for this task
|
||||
Tracking *TrackingConfig `json:"tracking,omitempty"`
|
||||
}
|
||||
|
||||
// Attempt represents a single execution attempt of a task
|
||||
type Attempt struct {
|
||||
Attempt int `json:"attempt"` // Attempt number (1-indexed)
|
||||
StartedAt time.Time `json:"started_at"` // When attempt started
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"` // When attempt ended (if completed)
|
||||
WorkerID string `json:"worker_id,omitempty"` // Which worker ran this attempt
|
||||
Status string `json:"status"` // running, completed, failed
|
||||
FailureClass FailureClass `json:"failure_class,omitempty"` // Failure classification (if failed)
|
||||
ExitCode int `json:"exit_code,omitempty"` // Process exit code
|
||||
Signal string `json:"signal,omitempty"` // Termination signal (if any)
|
||||
Error string `json:"error,omitempty"` // Error message (if failed)
|
||||
LogTail string `json:"log_tail,omitempty"` // Last N lines of log output
|
||||
}
|
||||
30
internal/domain/tracking.go
Normal file
30
internal/domain/tracking.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package domain
|
||||
|
||||
// TrackingConfig specifies experiment tracking tools to enable for a task.
|
||||
type TrackingConfig struct {
|
||||
MLflow *MLflowTrackingConfig `json:"mlflow,omitempty"`
|
||||
TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
|
||||
Wandb *WandbTrackingConfig `json:"wandb,omitempty"`
|
||||
}
|
||||
|
||||
// MLflowTrackingConfig controls MLflow integration.
|
||||
type MLflowTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "sidecar" | "remote" | "disabled"
|
||||
TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
|
||||
}
|
||||
|
||||
// TensorBoardTrackingConfig controls TensorBoard integration.
|
||||
type TensorBoardTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "sidecar" | "disabled"
|
||||
}
|
||||
|
||||
// WandbTrackingConfig controls Weights & Biases integration.
|
||||
type WandbTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "remote" | "disabled"
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Entity string `json:"entity,omitempty"`
|
||||
}
|
||||
|
|
@ -2,284 +2,31 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
)
|
||||
|
||||
// FailureClass represents the classification of a task failure
|
||||
// Used to determine appropriate retry policy and user guidance
|
||||
type FailureClass string
|
||||
// Re-export from domain for backward compatibility
|
||||
// Deprecated: Use internal/domain directly
|
||||
type (
|
||||
FailureClass = domain.FailureClass
|
||||
FailureInfo = domain.FailureInfo
|
||||
)
|
||||
|
||||
// Re-export functions from domain
|
||||
// Deprecated: Use internal/domain directly
|
||||
var (
|
||||
ClassifyFailure = domain.ClassifyFailure
|
||||
GetFailureSuggestion = domain.GetFailureSuggestion
|
||||
ShouldAutoRetry = domain.ShouldAutoRetry
|
||||
RetryDelayForClass = domain.RetryDelayForClass
|
||||
)
|
||||
|
||||
// Re-export constants
|
||||
// Deprecated: Use internal/domain directly
|
||||
const (
|
||||
FailureInfrastructure FailureClass = "infrastructure" // OOM kill, SIGKILL, node failure
|
||||
FailureCode FailureClass = "code" // non-zero exit, exception, assertion
|
||||
FailureData FailureClass = "data" // hash mismatch, dataset unreachable
|
||||
FailureResource FailureClass = "resource" // GPU OOM, disk full, timeout
|
||||
FailureUnknown FailureClass = "unknown" // cannot classify
|
||||
FailureInfrastructure = domain.FailureInfrastructure
|
||||
FailureCode = domain.FailureCode
|
||||
FailureData = domain.FailureData
|
||||
FailureResource = domain.FailureResource
|
||||
FailureUnknown = domain.FailureUnknown
|
||||
)
|
||||
|
||||
// ClassifyFailure determines the failure class from exit signals, codes, and log output
|
||||
func ClassifyFailure(exitCode int, signal os.Signal, logTail string) FailureClass {
|
||||
logLower := strings.ToLower(logTail)
|
||||
|
||||
// Killed by OS — infrastructure failure
|
||||
if signal == syscall.SIGKILL {
|
||||
return FailureInfrastructure
|
||||
}
|
||||
|
||||
// CUDA OOM or GPU resource issues
|
||||
if strings.Contains(logLower, "cuda out of memory") ||
|
||||
strings.Contains(logLower, "cuda error") ||
|
||||
strings.Contains(logLower, "gpu oom") {
|
||||
return FailureResource
|
||||
}
|
||||
|
||||
// General OOM (non-GPU) — infrastructure
|
||||
if strings.Contains(logLower, "out of memory") ||
|
||||
strings.Contains(logLower, "oom") ||
|
||||
strings.Contains(logLower, "cannot allocate memory") {
|
||||
return FailureInfrastructure
|
||||
}
|
||||
|
||||
// Dataset hash check failed — data failure
|
||||
if strings.Contains(logLower, "hash mismatch") ||
|
||||
strings.Contains(logLower, "checksum failed") ||
|
||||
strings.Contains(logLower, "dataset not found") ||
|
||||
strings.Contains(logLower, "dataset unreachable") {
|
||||
return FailureData
|
||||
}
|
||||
|
||||
// Disk/resource exhaustion
|
||||
if strings.Contains(logLower, "no space left") ||
|
||||
strings.Contains(logLower, "disk full") ||
|
||||
strings.Contains(logLower, "disk quota exceeded") {
|
||||
return FailureResource
|
||||
}
|
||||
|
||||
// Timeout — resource (time budget exceeded)
|
||||
if strings.Contains(logLower, "timeout") ||
|
||||
strings.Contains(logLower, "deadline exceeded") ||
|
||||
strings.Contains(logLower, "context deadline") {
|
||||
return FailureResource
|
||||
}
|
||||
|
||||
// Network issues — infrastructure
|
||||
if strings.Contains(logLower, "connection refused") ||
|
||||
strings.Contains(logLower, "connection reset") ||
|
||||
strings.Contains(logLower, "no route to host") ||
|
||||
strings.Contains(logLower, "network unreachable") {
|
||||
return FailureInfrastructure
|
||||
}
|
||||
|
||||
// Non-zero exit without specific signal — code failure
|
||||
if exitCode != 0 {
|
||||
return FailureCode
|
||||
}
|
||||
|
||||
return FailureUnknown
|
||||
}
|
||||
|
||||
// FailureInfo contains complete failure context for the manifest
|
||||
type FailureInfo struct {
|
||||
Class FailureClass `json:"class"`
|
||||
ExitCode int `json:"exit_code,omitempty"`
|
||||
Signal string `json:"signal,omitempty"`
|
||||
LogTail string `json:"log_tail,omitempty"`
|
||||
Suggestion string `json:"suggestion,omitempty"`
|
||||
AutoRetried bool `json:"auto_retried,omitempty"`
|
||||
RetryCount int `json:"retry_count,omitempty"`
|
||||
RetryCap int `json:"retry_cap,omitempty"`
|
||||
ClassifiedAt string `json:"classified_at,omitempty"`
|
||||
Context map[string]string `json:"context,omitempty"`
|
||||
}
|
||||
|
||||
// GetFailureSuggestion returns user guidance based on failure class
|
||||
func GetFailureSuggestion(class FailureClass, logTail string) string {
|
||||
switch class {
|
||||
case FailureInfrastructure:
|
||||
return "Infrastructure failure (node died, OOM kill). Auto-retry in progress."
|
||||
case FailureCode:
|
||||
return "Code error in training script. Fix before resubmitting."
|
||||
case FailureData:
|
||||
return "Data verification failed. Check dataset accessibility and hashes."
|
||||
case FailureResource:
|
||||
if strings.Contains(strings.ToLower(logTail), "cuda") {
|
||||
return "GPU OOM. Increase --gpu-memory or use smaller batch size."
|
||||
}
|
||||
if strings.Contains(strings.ToLower(logTail), "disk") {
|
||||
return "Disk full. Clean up storage or request more space."
|
||||
}
|
||||
return "Resource exhausted. Try with larger allocation or reduced load."
|
||||
default:
|
||||
return "Unknown failure. Review logs and contact support if persistent."
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldAutoRetry determines if a failure class should auto-retry
|
||||
// infrastructure: 3 retries transparent
|
||||
// resource: 1 retry with backoff
|
||||
// unknown: 1 retry (conservative - was retryable in old system)
|
||||
// others: never auto-retry
|
||||
func ShouldAutoRetry(class FailureClass, retryCount int) bool {
|
||||
switch class {
|
||||
case FailureInfrastructure:
|
||||
return retryCount < 3
|
||||
case FailureResource:
|
||||
return retryCount < 1
|
||||
case FailureUnknown:
|
||||
// Unknown failures get 1 retry attempt (conservative, matches old behavior)
|
||||
return retryCount < 1
|
||||
default:
|
||||
// code, data failures never auto-retry
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RetryDelayForClass returns appropriate backoff for the failure class
|
||||
func RetryDelayForClass(class FailureClass, retryCount int) int {
|
||||
switch class {
|
||||
case FailureInfrastructure:
|
||||
// Immediate retry for infrastructure
|
||||
return 0
|
||||
case FailureResource:
|
||||
// Short backoff for resource issues
|
||||
return 30
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorCategory represents the type of error encountered (DEPRECATED: use FailureClass)
|
||||
type ErrorCategory string
|
||||
|
||||
// Error categories for task classification and retry logic
|
||||
const (
|
||||
ErrorNetwork ErrorCategory = "network" // Network connectivity issues
|
||||
ErrorResource ErrorCategory = "resource" // Resource exhaustion (OOM, disk full)
|
||||
ErrorRateLimit ErrorCategory = "rate_limit" // Rate limiting or throttling
|
||||
ErrorAuth ErrorCategory = "auth" // Authentication/authorization failures
|
||||
ErrorValidation ErrorCategory = "validation" // Input validation errors
|
||||
ErrorTimeout ErrorCategory = "timeout" // Operation timeout
|
||||
ErrorPermanent ErrorCategory = "permanent" // Non-retryable errors
|
||||
ErrorUnknown ErrorCategory = "unknown" // Unclassified errors
|
||||
)
|
||||
|
||||
// TaskError wraps an error with category and context
|
||||
type TaskError struct {
|
||||
Category ErrorCategory
|
||||
Message string
|
||||
Cause error
|
||||
Context map[string]string
|
||||
}
|
||||
|
||||
func (e *TaskError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("[%s] %s: %v", e.Category, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("[%s] %s", e.Category, e.Message)
|
||||
}
|
||||
|
||||
func (e *TaskError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// NewTaskError creates a new categorized error
|
||||
func NewTaskError(category ErrorCategory, message string, cause error) *TaskError {
|
||||
return &TaskError{
|
||||
Category: category,
|
||||
Message: message,
|
||||
Cause: cause,
|
||||
Context: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// ClassifyError categorizes an error for retry logic (DEPRECATED: use classifyFailure)
|
||||
// This function now delegates to the more accurate classifyFailure
|
||||
func ClassifyError(err error) ErrorCategory {
|
||||
if err == nil {
|
||||
return ErrorUnknown
|
||||
}
|
||||
|
||||
// Check if already classified as TaskError
|
||||
var taskErr *TaskError
|
||||
if errors.As(err, &taskErr) {
|
||||
return taskErr.Category
|
||||
}
|
||||
|
||||
// Delegate to new FailureClass classification
|
||||
failureClass := ClassifyFailure(0, nil, err.Error())
|
||||
|
||||
// Map FailureClass back to ErrorCategory for backward compatibility
|
||||
switch failureClass {
|
||||
case FailureInfrastructure:
|
||||
return ErrorNetwork
|
||||
case FailureCode:
|
||||
return ErrorPermanent
|
||||
case FailureData:
|
||||
return ErrorValidation
|
||||
case FailureResource:
|
||||
return ErrorResource
|
||||
default:
|
||||
return ErrorUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// IsRetryable determines if an error category should be retried
|
||||
// Now delegates to ShouldAutoRetry with FailureClass mapping
|
||||
func IsRetryable(category ErrorCategory) bool {
|
||||
// Map ErrorCategory to FailureClass
|
||||
var failureClass FailureClass
|
||||
switch category {
|
||||
case ErrorNetwork:
|
||||
failureClass = FailureInfrastructure
|
||||
case ErrorResource, ErrorTimeout:
|
||||
failureClass = FailureResource
|
||||
case ErrorAuth, ErrorValidation, ErrorPermanent:
|
||||
failureClass = FailureCode
|
||||
default:
|
||||
failureClass = FailureUnknown
|
||||
}
|
||||
return ShouldAutoRetry(failureClass, 0)
|
||||
}
|
||||
|
||||
// GetUserMessage returns a user-friendly error message with suggestions
|
||||
func GetUserMessage(category ErrorCategory, err error) string {
|
||||
messages := map[ErrorCategory]string{
|
||||
ErrorNetwork: "Network connectivity issue. Please check your network " +
|
||||
"connection and try again.",
|
||||
ErrorResource: "System resource exhausted. The system may be under heavy load. " +
|
||||
"Try again later or contact support.",
|
||||
ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.",
|
||||
ErrorAuth: "Authentication failed. Please check your API key or credentials.",
|
||||
ErrorValidation: "Invalid input. Please review your request and correct any errors.",
|
||||
ErrorTimeout: "Operation timed out. The task may be too complex or the system is slow. " +
|
||||
"Try again or simplify the request.",
|
||||
ErrorPermanent: "A permanent error occurred. This task cannot be retried automatically.",
|
||||
ErrorUnknown: "An unexpected error occurred. If this persists, please contact support.",
|
||||
}
|
||||
|
||||
baseMsg := messages[category]
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%s (Details: %v)", baseMsg, err)
|
||||
}
|
||||
return baseMsg
|
||||
}
|
||||
|
||||
// RetryDelay calculates the retry delay based on error category and retry count
|
||||
// Now delegates to RetryDelayForClass with FailureClass mapping
|
||||
func RetryDelay(category ErrorCategory, retryCount int) int {
|
||||
// Map ErrorCategory to FailureClass
|
||||
var failureClass FailureClass
|
||||
switch category {
|
||||
case ErrorNetwork:
|
||||
failureClass = FailureInfrastructure
|
||||
case ErrorResource, ErrorTimeout:
|
||||
failureClass = FailureResource
|
||||
default:
|
||||
failureClass = FailureUnknown
|
||||
}
|
||||
return RetryDelayForClass(failureClass, retryCount)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
)
|
||||
|
||||
type FilesystemQueue struct {
|
||||
|
|
@ -177,13 +179,13 @@ func (q *FilesystemQueue) RetryTask(task *Task) error {
|
|||
return q.MoveToDeadLetterQueue(task, "max retries exceeded")
|
||||
}
|
||||
|
||||
errorCategory := ErrorUnknown
|
||||
failureClass := domain.FailureUnknown
|
||||
if task.Error != "" {
|
||||
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
|
||||
failureClass = domain.ClassifyFailure(0, nil, task.Error)
|
||||
}
|
||||
if !IsRetryable(errorCategory) {
|
||||
RecordDLQAddition(string(errorCategory))
|
||||
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
|
||||
if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
|
||||
RecordDLQAddition(string(failureClass))
|
||||
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
|
||||
}
|
||||
|
||||
task.RetryCount++
|
||||
|
|
@ -191,13 +193,13 @@ func (q *FilesystemQueue) RetryTask(task *Task) error {
|
|||
task.LastError = task.Error
|
||||
task.Error = ""
|
||||
|
||||
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
|
||||
backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
|
||||
nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
|
||||
task.NextRetry = &nextRetry
|
||||
task.LeaseExpiry = nil
|
||||
task.LeasedBy = ""
|
||||
|
||||
RecordTaskRetry(task.JobName, errorCategory)
|
||||
RecordTaskRetry(task.JobName, failureClass)
|
||||
return q.AddTask(task)
|
||||
}
|
||||
|
||||
|
|
@ -207,7 +209,8 @@ func (q *FilesystemQueue) MoveToDeadLetterQueue(task *Task, reason string) error
|
|||
}
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
|
||||
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
|
||||
failureClass := domain.ClassifyFailure(0, nil, task.LastError)
|
||||
RecordTaskFailure(task.JobName, failureClass)
|
||||
return q.UpdateTask(task)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
|
@ -31,17 +32,17 @@ var (
|
|||
Help: "Total number of completed tasks",
|
||||
}, []string{"job_name", "status"})
|
||||
|
||||
// TaskFailures tracks failed tasks by error category.
|
||||
// TaskFailures tracks failed tasks by failure class.
|
||||
TaskFailures = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "fetch_ml_task_failures_total",
|
||||
Help: "Total number of failed tasks by error category",
|
||||
}, []string{"job_name", "error_category"})
|
||||
Help: "Total number of failed tasks by failure class",
|
||||
}, []string{"job_name", "failure_class"})
|
||||
|
||||
// TaskRetries tracks the total number of task retries.
|
||||
TaskRetries = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "fetch_ml_task_retries_total",
|
||||
Help: "Total number of task retries",
|
||||
}, []string{"job_name", "error_category"})
|
||||
}, []string{"job_name", "failure_class"})
|
||||
|
||||
// LeaseExpirations tracks expired leases that were reclaimed.
|
||||
LeaseExpirations = promauto.NewCounter(prometheus.CounterOpts{
|
||||
|
|
@ -92,14 +93,14 @@ func RecordTaskEnd(jobName, workerID, status string, durationSeconds float64) {
|
|||
TasksCompleted.WithLabelValues(jobName, status).Inc()
|
||||
}
|
||||
|
||||
// RecordTaskFailure records a task failure with error category
|
||||
func RecordTaskFailure(jobName string, errorCategory ErrorCategory) {
|
||||
TaskFailures.WithLabelValues(jobName, string(errorCategory)).Inc()
|
||||
// RecordTaskFailure records a task failure with failure class
|
||||
func RecordTaskFailure(jobName string, failureClass domain.FailureClass) {
|
||||
TaskFailures.WithLabelValues(jobName, string(failureClass)).Inc()
|
||||
}
|
||||
|
||||
// RecordTaskRetry records a task retry
|
||||
func RecordTaskRetry(jobName string, errorCategory ErrorCategory) {
|
||||
TaskRetries.WithLabelValues(jobName, string(errorCategory)).Inc()
|
||||
func RecordTaskRetry(jobName string, failureClass domain.FailureClass) {
|
||||
TaskRetries.WithLabelValues(jobName, string(failureClass)).Inc()
|
||||
}
|
||||
|
||||
// RecordLeaseExpiration records a lease expiration
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
|
|
@ -439,7 +440,7 @@ func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
|
|||
return tq.UpdateTask(task)
|
||||
}
|
||||
|
||||
// RetryTask re-queues a failed task with smart backoff based on error category
|
||||
// RetryTask re-queues a failed task with smart backoff based on failure class
|
||||
func (tq *TaskQueue) RetryTask(task *Task) error {
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
// Move to dead letter queue
|
||||
|
|
@ -448,15 +449,15 @@ func (tq *TaskQueue) RetryTask(task *Task) error {
|
|||
}
|
||||
|
||||
// Classify the error if it exists
|
||||
errorCategory := ErrorUnknown
|
||||
failureClass := domain.FailureUnknown
|
||||
if task.Error != "" {
|
||||
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
|
||||
failureClass = domain.ClassifyFailure(0, nil, task.Error)
|
||||
}
|
||||
|
||||
// Check if error is retryable
|
||||
if !IsRetryable(errorCategory) {
|
||||
RecordDLQAddition(string(errorCategory))
|
||||
return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
|
||||
if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
|
||||
RecordDLQAddition(string(failureClass))
|
||||
return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
|
||||
}
|
||||
|
||||
task.RetryCount++
|
||||
|
|
@ -464,8 +465,8 @@ func (tq *TaskQueue) RetryTask(task *Task) error {
|
|||
task.LastError = task.Error // Preserve last error
|
||||
task.Error = "" // Clear current error
|
||||
|
||||
// Calculate smart backoff based on error category
|
||||
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
|
||||
// Calculate smart backoff based on failure class
|
||||
backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
|
||||
nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
|
||||
task.NextRetry = &nextRetry
|
||||
|
||||
|
|
@ -474,7 +475,7 @@ func (tq *TaskQueue) RetryTask(task *Task) error {
|
|||
task.LeasedBy = ""
|
||||
|
||||
// Record retry metrics
|
||||
RecordTaskRetry(task.JobName, errorCategory)
|
||||
RecordTaskRetry(task.JobName, failureClass)
|
||||
|
||||
// Re-queue with same priority
|
||||
return tq.AddTask(task)
|
||||
|
|
@ -493,8 +494,9 @@ func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
|
|||
// Store in dead letter queue with timestamp
|
||||
key := "task:dlq:" + task.ID
|
||||
|
||||
// Record metrics
|
||||
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
|
||||
// Record metrics using domain.FailureClass
|
||||
failureClass := domain.ClassifyFailure(0, nil, task.LastError)
|
||||
RecordTaskFailure(task.JobName, failureClass)
|
||||
|
||||
pipe := tq.client.Pipeline()
|
||||
pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
|
|
@ -495,13 +496,13 @@ func (q *SQLiteQueue) RetryTask(task *Task) error {
|
|||
return q.MoveToDeadLetterQueue(task, "max retries exceeded")
|
||||
}
|
||||
|
||||
errorCategory := ErrorUnknown
|
||||
failureClass := domain.FailureUnknown
|
||||
if task.Error != "" {
|
||||
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
|
||||
failureClass = domain.ClassifyFailure(0, nil, task.Error)
|
||||
}
|
||||
if !IsRetryable(errorCategory) {
|
||||
RecordDLQAddition(string(errorCategory))
|
||||
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
|
||||
if !domain.ShouldAutoRetry(failureClass, task.RetryCount) {
|
||||
RecordDLQAddition(string(failureClass))
|
||||
return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", failureClass))
|
||||
}
|
||||
|
||||
task.RetryCount++
|
||||
|
|
@ -509,13 +510,13 @@ func (q *SQLiteQueue) RetryTask(task *Task) error {
|
|||
task.LastError = task.Error
|
||||
task.Error = ""
|
||||
|
||||
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
|
||||
backoffSeconds := domain.RetryDelayForClass(failureClass, task.RetryCount)
|
||||
nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second)
|
||||
task.NextRetry = &nextRetry
|
||||
task.LeaseExpiry = nil
|
||||
task.LeasedBy = ""
|
||||
|
||||
RecordTaskRetry(task.JobName, errorCategory)
|
||||
RecordTaskRetry(task.JobName, failureClass)
|
||||
return q.AddTask(task)
|
||||
}
|
||||
|
||||
|
|
@ -523,7 +524,8 @@ func (q *SQLiteQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
|
|||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
|
||||
|
||||
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
|
||||
failureClass := domain.ClassifyFailure(0, nil, task.LastError)
|
||||
RecordTaskFailure(task.JobName, failureClass)
|
||||
return q.UpdateTask(task)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,111 +1,21 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
)
|
||||
|
||||
// DatasetSpec describes a dataset input with optional provenance fields.
|
||||
type DatasetSpec struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Checksum string `json:"checksum,omitempty"`
|
||||
URI string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
// Task represents an ML experiment task
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"` // queued, running, completed, failed
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
// TODO(phase1): SnapshotID is an opaque identifier only.
|
||||
// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
|
||||
SnapshotID string `json:"snapshot_id,omitempty"`
|
||||
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
|
||||
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
// Datasets is kept for backward compatibility (legacy callers).
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
// Resource requests (optional, 0 means unspecified)
|
||||
CPU int `json:"cpu,omitempty"`
|
||||
MemoryGB int `json:"memory_gb,omitempty"`
|
||||
GPU int `json:"gpu,omitempty"`
|
||||
GPUMemory string `json:"gpu_memory,omitempty"`
|
||||
|
||||
// User ownership and permissions
|
||||
UserID string `json:"user_id"` // User who owns this task
|
||||
Username string `json:"username"` // Username for display
|
||||
CreatedBy string `json:"created_by"` // User who submitted the task
|
||||
|
||||
// Lease management for task resilience
|
||||
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
|
||||
LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease
|
||||
|
||||
// Retry management
|
||||
RetryCount int `json:"retry_count"` // Number of retry attempts made
|
||||
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
|
||||
LastError string `json:"last_error,omitempty"` // Last error encountered
|
||||
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
|
||||
|
||||
// Attempt tracking - complete history of all execution attempts
|
||||
Attempts []Attempt `json:"attempts,omitempty"`
|
||||
|
||||
// Optional tracking configuration for this task
|
||||
Tracking *TrackingConfig `json:"tracking,omitempty"`
|
||||
}
|
||||
|
||||
// Attempt represents a single execution attempt of a task
|
||||
type Attempt struct {
|
||||
Attempt int `json:"attempt"` // Attempt number (1-indexed)
|
||||
StartedAt time.Time `json:"started_at"` // When attempt started
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"` // When attempt ended (if completed)
|
||||
WorkerID string `json:"worker_id,omitempty"` // Which worker ran this attempt
|
||||
Status string `json:"status"` // running, completed, failed
|
||||
FailureClass FailureClass `json:"failure_class,omitempty"` // Failure classification (if failed)
|
||||
ExitCode int `json:"exit_code,omitempty"` // Process exit code
|
||||
Signal string `json:"signal,omitempty"` // Termination signal (if any)
|
||||
Error string `json:"error,omitempty"` // Error message (if failed)
|
||||
LogTail string `json:"log_tail,omitempty"` // Last N lines of log output
|
||||
}
|
||||
|
||||
// TrackingConfig specifies experiment tracking tools to enable for a task.
|
||||
type TrackingConfig struct {
|
||||
MLflow *MLflowTrackingConfig `json:"mlflow,omitempty"`
|
||||
TensorBoard *TensorBoardTrackingConfig `json:"tensorboard,omitempty"`
|
||||
Wandb *WandbTrackingConfig `json:"wandb,omitempty"`
|
||||
}
|
||||
|
||||
// MLflowTrackingConfig controls MLflow integration.
|
||||
type MLflowTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "sidecar" | "remote" | "disabled"
|
||||
TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
|
||||
}
|
||||
|
||||
// TensorBoardTrackingConfig controls TensorBoard integration.
|
||||
type TensorBoardTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "sidecar" | "disabled"
|
||||
}
|
||||
|
||||
// WandbTrackingConfig controls Weights & Biases integration.
|
||||
type WandbTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "remote" | "disabled"
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Entity string `json:"entity,omitempty"`
|
||||
}
|
||||
// Re-export domain types for backward compatibility
|
||||
// Deprecated: Use internal/domain directly
|
||||
type (
|
||||
Task = domain.Task
|
||||
Attempt = domain.Attempt
|
||||
DatasetSpec = domain.DatasetSpec
|
||||
TrackingConfig = domain.TrackingConfig
|
||||
MLflowTrackingConfig = domain.MLflowTrackingConfig
|
||||
TensorBoardTrackingConfig = domain.TensorBoardTrackingConfig
|
||||
WandbTrackingConfig = domain.WandbTrackingConfig
|
||||
)
|
||||
|
||||
// Redis key constants
|
||||
var (
|
||||
|
|
|
|||
Loading…
Reference in a new issue