319 lines
8.3 KiB
Go
319 lines
8.3 KiB
Go
// Package services provides TUI service implementations.
|
|
package services
|
|
|
|
import (
	"context"
	"fmt"
	"strings"

	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
)
|
|
|
|
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility
|
|
type TaskQueue struct {
|
|
internal *queue.TaskQueue
|
|
expManager *experiment.Manager
|
|
ctx context.Context
|
|
}
|
|
|
|
// NewTaskQueue creates a new task queue service
|
|
func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
|
|
// Create internal queue config
|
|
queueCfg := queue.Config{
|
|
RedisAddr: cfg.RedisAddr,
|
|
RedisPassword: cfg.RedisPassword,
|
|
RedisDB: cfg.RedisDB,
|
|
}
|
|
|
|
internalQueue, err := queue.NewTaskQueue(queueCfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create task queue: %w", err)
|
|
}
|
|
|
|
// Initialize experiment manager
|
|
// TODO: Get base path from config
|
|
expManager := experiment.NewManager("./experiments")
|
|
|
|
return &TaskQueue{
|
|
internal: internalQueue,
|
|
expManager: expManager,
|
|
ctx: context.Background(),
|
|
}, nil
|
|
}
|
|
|
|
// EnqueueTask adds a new task to the queue
|
|
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.Task, error) {
|
|
// Create internal task
|
|
internalTask := &queue.Task{
|
|
JobName: jobName,
|
|
Args: args,
|
|
Priority: priority,
|
|
}
|
|
|
|
// Use internal queue to enqueue
|
|
err := tq.internal.AddTask(internalTask)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert to TUI model
|
|
return &model.Task{
|
|
ID: internalTask.ID,
|
|
JobName: internalTask.JobName,
|
|
Args: internalTask.Args,
|
|
Status: "queued",
|
|
Priority: internalTask.Priority,
|
|
CreatedAt: internalTask.CreatedAt,
|
|
Metadata: internalTask.Metadata,
|
|
Tracking: convertTrackingToModel(internalTask.Tracking),
|
|
}, nil
|
|
}
|
|
|
|
// GetNextTask retrieves the next task from the queue
|
|
func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
|
|
internalTask, err := tq.internal.GetNextTask()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if internalTask == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
// Convert to TUI model
|
|
return &model.Task{
|
|
ID: internalTask.ID,
|
|
JobName: internalTask.JobName,
|
|
Args: internalTask.Args,
|
|
Status: internalTask.Status,
|
|
Priority: internalTask.Priority,
|
|
CreatedAt: internalTask.CreatedAt,
|
|
Metadata: internalTask.Metadata,
|
|
Tracking: convertTrackingToModel(internalTask.Tracking),
|
|
}, nil
|
|
}
|
|
|
|
// GetTask retrieves a specific task by ID
|
|
func (tq *TaskQueue) GetTask(taskID string) (*model.Task, error) {
|
|
internalTask, err := tq.internal.GetTask(taskID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert to TUI model
|
|
return &model.Task{
|
|
ID: internalTask.ID,
|
|
JobName: internalTask.JobName,
|
|
Args: internalTask.Args,
|
|
Status: internalTask.Status,
|
|
Priority: internalTask.Priority,
|
|
CreatedAt: internalTask.CreatedAt,
|
|
Metadata: internalTask.Metadata,
|
|
Tracking: convertTrackingToModel(internalTask.Tracking),
|
|
}, nil
|
|
}
|
|
|
|
// UpdateTask updates a task's status and metadata
|
|
func (tq *TaskQueue) UpdateTask(task *model.Task) error {
|
|
// Convert to internal task
|
|
internalTask := &queue.Task{
|
|
ID: task.ID,
|
|
JobName: task.JobName,
|
|
Args: task.Args,
|
|
Status: task.Status,
|
|
Priority: task.Priority,
|
|
CreatedAt: task.CreatedAt,
|
|
Metadata: task.Metadata,
|
|
Tracking: convertTrackingToInternal(task.Tracking),
|
|
}
|
|
|
|
return tq.internal.UpdateTask(internalTask)
|
|
}
|
|
|
|
// GetQueuedTasks retrieves all queued tasks
|
|
func (tq *TaskQueue) GetQueuedTasks() ([]*model.Task, error) {
|
|
internalTasks, err := tq.internal.GetAllTasks()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert to TUI models
|
|
tasks := make([]*model.Task, len(internalTasks))
|
|
for i, task := range internalTasks {
|
|
tasks[i] = &model.Task{
|
|
ID: task.ID,
|
|
JobName: task.JobName,
|
|
Args: task.Args,
|
|
Status: task.Status,
|
|
Priority: task.Priority,
|
|
CreatedAt: task.CreatedAt,
|
|
Metadata: task.Metadata,
|
|
Tracking: convertTrackingToModel(task.Tracking),
|
|
}
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
// GetJobStatus gets the status of all jobs with the given name
|
|
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
|
|
// This method doesn't exist in internal queue, implement basic version
|
|
task, err := tq.internal.GetTaskByName(jobName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if task == nil {
|
|
return map[string]string{"status": "not_found"}, nil
|
|
}
|
|
|
|
return map[string]string{
|
|
"status": task.Status,
|
|
"task_id": task.ID,
|
|
}, nil
|
|
}
|
|
|
|
// RecordMetric records a metric for monitoring
|
|
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
|
|
_ = jobName // Parameter reserved for future use
|
|
return tq.internal.RecordMetric(jobName, metric, value)
|
|
}
|
|
|
|
// GetMetrics retrieves metrics for a job
|
|
func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
|
|
// This method doesn't exist in internal queue, return empty for now
|
|
return map[string]string{}, nil
|
|
}
|
|
|
|
// ListDatasets retrieves available datasets
|
|
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
|
|
// This method doesn't exist in internal queue, return empty for now
|
|
return []model.DatasetInfo{}, nil
|
|
}
|
|
|
|
// CancelTask cancels a task by ID
|
|
func (tq *TaskQueue) CancelTask(taskID string) error {
|
|
return tq.internal.CancelTask(taskID)
|
|
}
|
|
|
|
// ListExperiments retrieves experiment list
|
|
func (tq *TaskQueue) ListExperiments() ([]string, error) {
|
|
return tq.expManager.ListExperiments()
|
|
}
|
|
|
|
// GetExperimentDetails retrieves experiment details
|
|
func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
|
|
meta, err := tq.expManager.ReadMetadata(commitID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
metrics, err := tq.expManager.GetMetrics(commitID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
output := fmt.Sprintf("Experiment: %s\n", meta.JobName)
|
|
output += fmt.Sprintf("Commit ID: %s\n", meta.CommitID)
|
|
output += fmt.Sprintf("User: %s\n", meta.User)
|
|
output += fmt.Sprintf("Timestamp: %d\n\n", meta.Timestamp)
|
|
output += "Metrics:\n"
|
|
|
|
if len(metrics) == 0 {
|
|
output += " No metrics logged.\n"
|
|
} else {
|
|
for _, m := range metrics {
|
|
output += fmt.Sprintf(" %s: %.4f (Step: %d)\n", m.Name, m.Value, m.Step)
|
|
}
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
// Close closes the task queue
|
|
func (tq *TaskQueue) Close() error {
|
|
return tq.internal.Close()
|
|
}
|
|
|
|
// MLServer wraps network.SSHClient for backward compatibility
|
|
type MLServer struct {
|
|
*network.SSHClient
|
|
addr string
|
|
}
|
|
|
|
// NewMLServer creates a new ML server connection
|
|
func NewMLServer(cfg *config.Config) (*MLServer, error) {
|
|
// Local mode: skip SSH entirely
|
|
if cfg.Host == "" {
|
|
client, _ := network.NewSSHClient("", "", "", 0, "")
|
|
return &MLServer{SSHClient: client, addr: "localhost"}, nil
|
|
}
|
|
|
|
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
|
return &MLServer{SSHClient: client, addr: addr}, nil
|
|
}
|
|
|
|
func convertTrackingToModel(t *queue.TrackingConfig) *model.TrackingConfig {
|
|
if t == nil {
|
|
return nil
|
|
}
|
|
out := &model.TrackingConfig{}
|
|
if t.MLflow != nil {
|
|
out.MLflow = &model.MLflowTrackingConfig{
|
|
Enabled: t.MLflow.Enabled,
|
|
Mode: t.MLflow.Mode,
|
|
TrackingURI: t.MLflow.TrackingURI,
|
|
}
|
|
}
|
|
if t.TensorBoard != nil {
|
|
out.TensorBoard = &model.TensorBoardTrackingConfig{
|
|
Enabled: t.TensorBoard.Enabled,
|
|
Mode: t.TensorBoard.Mode,
|
|
}
|
|
}
|
|
if t.Wandb != nil {
|
|
out.Wandb = &model.WandbTrackingConfig{
|
|
Enabled: t.Wandb.Enabled,
|
|
Mode: t.Wandb.Mode,
|
|
APIKey: t.Wandb.APIKey,
|
|
Project: t.Wandb.Project,
|
|
Entity: t.Wandb.Entity,
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func convertTrackingToInternal(t *model.TrackingConfig) *queue.TrackingConfig {
|
|
if t == nil {
|
|
return nil
|
|
}
|
|
out := &queue.TrackingConfig{}
|
|
if t.MLflow != nil {
|
|
out.MLflow = &queue.MLflowTrackingConfig{
|
|
Enabled: t.MLflow.Enabled,
|
|
Mode: t.MLflow.Mode,
|
|
TrackingURI: t.MLflow.TrackingURI,
|
|
}
|
|
}
|
|
if t.TensorBoard != nil {
|
|
out.TensorBoard = &queue.TensorBoardTrackingConfig{
|
|
Enabled: t.TensorBoard.Enabled,
|
|
Mode: t.TensorBoard.Mode,
|
|
}
|
|
}
|
|
if t.Wandb != nil {
|
|
out.Wandb = &queue.WandbTrackingConfig{
|
|
Enabled: t.Wandb.Enabled,
|
|
Mode: t.Wandb.Mode,
|
|
APIKey: t.Wandb.APIKey,
|
|
Project: t.Wandb.Project,
|
|
Entity: t.Wandb.Entity,
|
|
}
|
|
}
|
|
return out
|
|
}
|