fetch_ml/cmd/tui/internal/services/services.go
Jeremie Fraeys 6580917ba8
refactor: extract domain types and consolidate error system (Phases 1-2)
Phase 1: Extract Domain Types
=============================
- Create internal/domain/ package with canonical types:
  - domain/task.go: Task, Attempt structs
  - domain/tracking.go: TrackingConfig and MLflow/TensorBoard/Wandb configs
  - domain/dataset.go: DatasetSpec
  - domain/status.go: JobStatus constants
  - domain/errors.go: FailureClass system with classification functions
  - domain/doc.go: package documentation

- Update queue/task.go to re-export domain types (backward compatibility)
- Update TUI model/state.go to use domain types via type aliases
- Simplify TUI services: remove ~60 lines of conversion functions

Phase 2: Delete ErrorCategory System
====================================
- Remove deprecated ErrorCategory type and constants
- Remove TaskError struct and related functions
- Remove mapping functions: ClassifyError, IsRetryable, GetUserMessage, RetryDelay
- Update all queue implementations to use domain.FailureClass directly:
  - queue/metrics.go: RecordTaskFailure/Retry now take FailureClass
  - queue/queue.go: RetryTask uses domain.ClassifyFailure
  - queue/filesystem_queue.go: RetryTask and MoveToDeadLetterQueue updated
  - queue/sqlite_queue.go: RetryTask and MoveToDeadLetterQueue updated

Lines eliminated: ~190 lines of conversion and mapping code
Result: Single source of truth for domain types and error classification
2026-02-17 12:34:28 -05:00

211 lines
5.7 KiB
Go

// Package services provides TUI service implementations
package services
import (
	"context"
	"fmt"
	"strings"

	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
	"github.com/jfraeys/fetch_ml/internal/domain"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
)
// Task is an alias for domain.Task for TUI compatibility.
// Because it is an alias rather than a new type, tasks returned by the
// internal queue are handed back to TUI callers without any conversion.
type Task = domain.Task
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility.
// It also holds the experiment manager that backs the experiment-related
// methods (ListExperiments, GetExperimentDetails).
type TaskQueue struct {
// internal is the underlying queue that task operations are delegated to.
internal *queue.TaskQueue
// expManager serves experiment listing and metadata/metric lookups.
expManager *experiment.Manager
// ctx is initialised to context.Background() by NewTaskQueue.
// NOTE(review): no method visible in this file reads it — confirm it is still needed.
ctx context.Context
}
// NewTaskQueue builds the TUI task queue service from cfg.
//
// The internal queue is created from the Redis address, password and DB
// number found in cfg. If that fails, the error is returned wrapped and no
// service is created. On success an experiment manager is attached as well.
func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
	backend, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:     cfg.RedisAddr,
		RedisPassword: cfg.RedisPassword,
		RedisDB:       cfg.RedisDB,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create task queue: %w", err)
	}

	svc := &TaskQueue{
		internal: backend,
		// TODO: Get base path from config
		expManager: experiment.NewManager("./experiments"),
		ctx:        context.Background(),
	}
	return svc, nil
}
// EnqueueTask submits a new task with the given job name, arguments and
// priority to the internal queue.
//
// On success the same task value that was handed to the queue is returned;
// on failure the queue's error is returned unchanged and the task is nil.
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) {
	task := &queue.Task{JobName: jobName, Args: args, Priority: priority}
	if err := tq.internal.AddTask(task); err != nil {
		return nil, err
	}
	// The queue's task type is returned as *Task directly; no conversion step.
	return task, nil
}
// GetNextTask asks the internal queue for its next task.
//
// It returns (nil, err) when the queue reports an error and (nil, nil) when
// the queue returns no task; otherwise the task is passed through as-is.
func (tq *TaskQueue) GetNextTask() (*Task, error) {
	next, err := tq.internal.GetNextTask()
	if err != nil || next == nil {
		// err is nil in the "no task" case, so this yields (nil, nil) there.
		return nil, err
	}
	return next, nil
}
// GetTask looks up a single task by its ID in the internal queue.
// Any lookup error is returned unchanged with a nil task.
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
	found, err := tq.internal.GetTask(taskID)
	if err != nil {
		return nil, err
	}
	return found, nil
}
// UpdateTask forwards task to the internal queue's UpdateTask and returns
// its result. The task is passed through untouched.
func (tq *TaskQueue) UpdateTask(task *Task) error {
	backend := tq.internal
	return backend.UpdateTask(task)
}
// GetQueuedTasks returns the tasks reported by the internal queue's
// GetAllTasks call. On error it returns a nil slice and the error unchanged.
func (tq *TaskQueue) GetQueuedTasks() ([]*Task, error) {
	tasks, err := tq.internal.GetAllTasks()
	if err != nil {
		return nil, err
	}
	return tasks, nil
}
// GetJobStatus reports the status of the task the internal queue returns
// for jobName.
//
// The result map always has a "status" key. When a task is found it also has
// "task_id"; when the queue returns no task, "status" is "not_found". A
// lookup error is returned unchanged with a nil map.
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
	task, err := tq.internal.GetTaskByName(jobName)
	switch {
	case err != nil:
		return nil, err
	case task == nil:
		return map[string]string{"status": "not_found"}, nil
	}

	status := make(map[string]string, 2)
	status["status"] = task.Status
	status["task_id"] = task.ID
	return status, nil
}
// RecordMetric records a metric value for a job by forwarding jobName,
// metric and value to the internal queue's RecordMetric, and returns that
// call's error.
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
	// jobName is passed through to the queue, so the former blank assignment
	// marking it as "reserved for future use" was dead and has been dropped.
	return tq.internal.RecordMetric(jobName, metric, value)
}
// GetMetrics is a placeholder: the job name argument is ignored and an
// empty, non-nil map is always returned with a nil error.
func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
	// The internal queue has no equivalent call yet, so there is nothing to report.
	empty := make(map[string]string)
	return empty, nil
}
// ListDatasets is a placeholder: it always returns an empty, non-nil slice
// of model.DatasetInfo and a nil error.
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
	// The internal queue has no equivalent call yet, so there is nothing to list.
	none := make([]model.DatasetInfo, 0)
	return none, nil
}
// CancelTask asks the internal queue to cancel the task identified by
// taskID and returns the queue's error, if any.
func (tq *TaskQueue) CancelTask(taskID string) error {
	backend := tq.internal
	return backend.CancelTask(taskID)
}
// ListExperiments returns whatever the experiment manager's
// ListExperiments call yields, error included.
func (tq *TaskQueue) ListExperiments() ([]string, error) {
	mgr := tq.expManager
	return mgr.ListExperiments()
}
// GetExperimentDetails renders a plain-text summary of the experiment
// identified by commitID.
//
// The summary starts with the job name, commit ID, user and timestamp read
// from the experiment's metadata, followed by a "Metrics:" section with one
// line per logged metric (or a "No metrics logged." line when there are
// none). If reading the metadata or the metrics fails, an empty string and
// that error are returned.
func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
	meta, err := tq.expManager.ReadMetadata(commitID)
	if err != nil {
		return "", err
	}
	metrics, err := tq.expManager.GetMetrics(commitID)
	if err != nil {
		return "", err
	}

	// Accumulate into a builder instead of repeated string concatenation so
	// the cost stays linear in the number of metrics.
	var b strings.Builder
	fmt.Fprintf(&b, "Experiment: %s\n", meta.JobName)
	fmt.Fprintf(&b, "Commit ID: %s\n", meta.CommitID)
	fmt.Fprintf(&b, "User: %s\n", meta.User)
	fmt.Fprintf(&b, "Timestamp: %d\n\n", meta.Timestamp)
	b.WriteString("Metrics:\n")
	if len(metrics) == 0 {
		b.WriteString(" No metrics logged.\n")
	}
	for _, m := range metrics {
		fmt.Fprintf(&b, " %s: %.4f (Step: %d)\n", m.Name, m.Value, m.Step)
	}
	return b.String(), nil
}
// Close closes the internal queue and returns the error from that call.
func (tq *TaskQueue) Close() error {
	backend := tq.internal
	return backend.Close()
}
// MLServer wraps network.SSHClient for backward compatibility
type MLServer struct {
// Embedded, so the SSH client's methods are promoted onto MLServer.
*network.SSHClient
// addr is "localhost" in local mode, otherwise "host:port" as assembled by NewMLServer.
addr string
}
// NewMLServer creates a new ML server connection from cfg.
//
// When cfg.Host is empty the server runs in local mode: the SSH client is
// constructed with empty parameters and the address is "localhost".
// Otherwise the client is built from the host, user, key, port and
// known-hosts settings in cfg, and the address is "host:port".
//
// An error from the client constructor is returned in both modes; no
// MLServer is returned alongside an error.
func NewMLServer(cfg *config.Config) (*MLServer, error) {
	// Local mode: skip SSH entirely
	if cfg.Host == "" {
		client, err := network.NewSSHClient("", "", "", 0, "")
		if err != nil {
			// Previously discarded; surface it so a failed construction
			// cannot yield an MLServer wrapping an unusable client.
			return nil, fmt.Errorf("creating local-mode client: %w", err)
		}
		return &MLServer{SSHClient: client, addr: "localhost"}, nil
	}

	client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
	if err != nil {
		return nil, err
	}
	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
	return &MLServer{SSHClient: client, addr: addr}, nil
}