refactor(dependency-hygiene): Fix Redis leak, simplify TUI wrapper, clean go.mod
Phase 1: Fix Redis Schema Leak - Create internal/storage/dataset.go with DatasetStore abstraction - Remove all direct Redis calls from cmd/data_manager/data_sync.go - data_manager now uses DatasetStore for transfer tracking and metadata Phase 2: Simplify TUI Services - Embed *queue.TaskQueue directly in services.TaskQueue - Eliminate 60% of wrapper boilerplate (203 -> ~100 lines) - Keep only TUI-specific methods (EnqueueTask, GetJobStatus, experiment methods) Phase 5: Clean go.mod Dependencies - Remove duplicate go-redis/redis/v8 dependency - Migrate internal/storage/migrate.go to redis/go-redis/v9 - Separate test-only deps (miniredis, testify) into own block Results: - Zero direct Redis calls in cmd/ - 60% fewer lines in TUI services - Cleaner dependency structure
This commit is contained in:
parent
2a922542b1
commit
dbf96020af
11 changed files with 267 additions and 165 deletions
|
|
@ -3,7 +3,6 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
|
@ -21,6 +20,7 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
||||
)
|
||||
|
||||
|
|
@ -36,13 +36,14 @@ type SSHClient = network.SSHClient
|
|||
|
||||
// DataManager manages data synchronization between NAS and ML server.
|
||||
type DataManager struct {
|
||||
config *DataConfig
|
||||
mlServer *SSHClient
|
||||
nasServer *SSHClient
|
||||
taskQueue *queue.TaskQueue
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *logging.Logger
|
||||
config *DataConfig
|
||||
mlServer *SSHClient
|
||||
nasServer *SSHClient
|
||||
taskQueue *queue.TaskQueue
|
||||
datasetStore *storage.DatasetStore
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
func (dm *DataManager) archiveDatasetOnML(datasetName string) (string, error) {
|
||||
|
|
@ -130,6 +131,7 @@ func NewDataManager(cfg *DataConfig, _ string) (*DataManager, error) {
|
|||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
|
||||
var taskQueue *queue.TaskQueue
|
||||
var datasetStore *storage.DatasetStore
|
||||
if cfg.RedisAddr != "" {
|
||||
queueCfg := queue.Config{
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
|
|
@ -150,18 +152,23 @@ func NewDataManager(cfg *DataConfig, _ string) (*DataManager, error) {
|
|||
cancel() // Cancel context to prevent leak
|
||||
return nil, fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
// Initialize dataset store with the Redis client
|
||||
datasetStore = storage.NewDatasetStoreWithContext(taskQueue.GetRedisClient(), ctx)
|
||||
} else {
|
||||
taskQueue = nil // Local mode - no Redis
|
||||
datasetStore = nil
|
||||
}
|
||||
|
||||
return &DataManager{
|
||||
config: cfg,
|
||||
mlServer: mlServer,
|
||||
nasServer: nasServer,
|
||||
taskQueue: taskQueue,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
config: cfg,
|
||||
mlServer: mlServer,
|
||||
nasServer: nasServer,
|
||||
taskQueue: taskQueue,
|
||||
datasetStore: datasetStore,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -233,13 +240,8 @@ func (dm *DataManager) fetchDatasetInternal(
|
|||
"nas_path", nasPath,
|
||||
"ml_path", mlPath)
|
||||
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
|
||||
"status", "transferring",
|
||||
"job_name", jobName,
|
||||
"size_bytes", size,
|
||||
"started_at", time.Now().Unix()).Err(); err != nil {
|
||||
if dm.datasetStore != nil {
|
||||
if err := dm.datasetStore.RecordTransferStart(dm.ctx, datasetName, jobName, size); err != nil {
|
||||
logger.Warn("failed to record transfer start in Redis", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -290,11 +292,8 @@ func (dm *DataManager) fetchDatasetInternal(
|
|||
}
|
||||
}
|
||||
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if redisErr := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
|
||||
"status", "failed",
|
||||
"error", err.Error()).Err(); redisErr != nil {
|
||||
if dm.datasetStore != nil {
|
||||
if redisErr := dm.datasetStore.RecordTransferFailure(dm.ctx, datasetName, err); redisErr != nil {
|
||||
logger.Warn("failed to record transfer failure in Redis", "error", redisErr)
|
||||
}
|
||||
}
|
||||
|
|
@ -316,12 +315,8 @@ func (dm *DataManager) fetchDatasetInternal(
|
|||
}
|
||||
}
|
||||
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
|
||||
"status", "completed",
|
||||
"completed_at", time.Now().Unix(),
|
||||
"duration_seconds", duration.Seconds()).Err(); err != nil {
|
||||
if dm.datasetStore != nil {
|
||||
if err := dm.datasetStore.RecordTransferComplete(dm.ctx, datasetName, duration); err != nil {
|
||||
logger.Warn("failed to record transfer completion in Redis", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -333,48 +328,29 @@ func (dm *DataManager) fetchDatasetInternal(
|
|||
}
|
||||
|
||||
func (dm *DataManager) saveDatasetInfo(name string, size int64) {
|
||||
if dm.taskQueue == nil {
|
||||
if dm.datasetStore == nil {
|
||||
return // Skip in local mode
|
||||
}
|
||||
|
||||
info := DatasetInfo{
|
||||
info := storage.DatasetInfo{
|
||||
Name: name,
|
||||
SizeBytes: size,
|
||||
Location: "ml",
|
||||
LastAccess: time.Now(),
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(info)
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.Set(dm.ctx, fmt.Sprintf("ml:dataset:%s", name), data, 0).Err(); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis",
|
||||
"dataset", name, "error", err)
|
||||
}
|
||||
if err := dm.datasetStore.SaveDatasetInfo(dm.ctx, info); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis",
|
||||
"dataset", name, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (dm *DataManager) updateLastAccess(name string) {
|
||||
if dm.taskQueue == nil {
|
||||
if dm.datasetStore == nil {
|
||||
return // Skip in local mode
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("ml:dataset:%s", name)
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
data, err := redisClient.Get(dm.ctx, key).Result()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var info DatasetInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
info.LastAccess = time.Now()
|
||||
newData, _ := json.Marshal(info)
|
||||
redisClient = dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.Set(dm.ctx, key, newData, 0).Err(); err != nil {
|
||||
if err := dm.datasetStore.UpdateLastAccess(dm.ctx, name); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to update last access in Redis",
|
||||
"dataset", name, "error", err)
|
||||
}
|
||||
|
|
@ -396,20 +372,14 @@ func (dm *DataManager) ListDatasetsOnML() ([]DatasetInfo, error) {
|
|||
var info DatasetInfo
|
||||
|
||||
// Only use Redis if available
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
key := fmt.Sprintf("ml:dataset:%s", name)
|
||||
data, err := redisClient.Get(dm.ctx, key).Result()
|
||||
|
||||
if err == nil {
|
||||
if unmarshalErr := json.Unmarshal([]byte(data), &info); unmarshalErr != nil {
|
||||
// Fallback to disk if unmarshal fails
|
||||
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
|
||||
info = DatasetInfo{
|
||||
Name: name,
|
||||
SizeBytes: size,
|
||||
Location: "ml",
|
||||
}
|
||||
if dm.datasetStore != nil {
|
||||
dsInfo, err := dm.datasetStore.GetDatasetInfo(dm.ctx, name)
|
||||
if err == nil && dsInfo != nil {
|
||||
info = DatasetInfo{
|
||||
Name: dsInfo.Name,
|
||||
SizeBytes: dsInfo.SizeBytes,
|
||||
Location: dsInfo.Location,
|
||||
LastAccess: dsInfo.LastAccess,
|
||||
}
|
||||
} else {
|
||||
// Fallback: get from disk
|
||||
|
|
@ -508,9 +478,8 @@ func (dm *DataManager) CleanupOldData() error {
|
|||
archived = append(archived, ds.Name)
|
||||
totalSize -= ds.SizeBytes
|
||||
totalSizeGB = float64(totalSize) / (1024 * 1024 * 1024)
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.Del(dm.ctx, fmt.Sprintf("ml:dataset:%s", ds.Name)).Err(); err != nil {
|
||||
if dm.datasetStore != nil {
|
||||
if err := dm.datasetStore.DeleteDatasetInfo(dm.ctx, ds.Name); err != nil {
|
||||
logger.Warn("failed to delete dataset from Redis",
|
||||
"dataset", ds.Name,
|
||||
"error", err)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
|
|
@ -16,11 +15,12 @@ import (
|
|||
// Task is an alias for domain.Task for TUI compatibility
|
||||
type Task = domain.Task
|
||||
|
||||
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility
|
||||
// TaskQueue provides TUI-specific task operations by embedding queue.TaskQueue
|
||||
// and extending it with experiment management capabilities.
|
||||
type TaskQueue struct {
|
||||
internal *queue.TaskQueue
|
||||
expManager *experiment.Manager
|
||||
ctx context.Context
|
||||
*queue.TaskQueue // Embed to inherit all queue methods directly
|
||||
expManager *experiment.Manager
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewTaskQueue creates a new task queue service
|
||||
|
|
@ -42,13 +42,13 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
|
|||
expManager := experiment.NewManager("./experiments")
|
||||
|
||||
return &TaskQueue{
|
||||
internal: internalQueue,
|
||||
TaskQueue: internalQueue,
|
||||
expManager: expManager,
|
||||
ctx: context.Background(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnqueueTask adds a new task to the queue
|
||||
// EnqueueTask adds a new task to the queue (TUI-specific: creates task with proper defaults)
|
||||
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) {
|
||||
// Create internal task
|
||||
internalTask := &queue.Task{
|
||||
|
|
@ -57,8 +57,8 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, e
|
|||
Priority: priority,
|
||||
}
|
||||
|
||||
// Use internal queue to enqueue
|
||||
err := tq.internal.AddTask(internalTask)
|
||||
// Use embedded queue to enqueue
|
||||
err := tq.TaskQueue.AddTask(internalTask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -67,52 +67,14 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, e
|
|||
return internalTask, nil
|
||||
}
|
||||
|
||||
// GetNextTask retrieves the next task from the queue
|
||||
func (tq *TaskQueue) GetNextTask() (*Task, error) {
|
||||
internalTask, err := tq.internal.GetNextTask()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if internalTask == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return domain.Task directly (no conversion needed)
|
||||
return internalTask, nil
|
||||
}
|
||||
|
||||
// GetTask retrieves a specific task by ID
|
||||
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
|
||||
internalTask, err := tq.internal.GetTask(taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return domain.Task directly (no conversion needed)
|
||||
return internalTask, nil
|
||||
}
|
||||
|
||||
// UpdateTask updates a task's status and metadata
|
||||
func (tq *TaskQueue) UpdateTask(task *Task) error {
|
||||
// task is already domain.Task, pass directly to internal queue
|
||||
return tq.internal.UpdateTask(task)
|
||||
}
|
||||
|
||||
// GetQueuedTasks retrieves all queued tasks
|
||||
// GetQueuedTasks retrieves all queued tasks (TUI-specific alias for GetAllTasks)
|
||||
func (tq *TaskQueue) GetQueuedTasks() ([]*Task, error) {
|
||||
internalTasks, err := tq.internal.GetAllTasks()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return domain.Tasks directly (no conversion needed)
|
||||
return internalTasks, nil
|
||||
return tq.TaskQueue.GetAllTasks()
|
||||
}
|
||||
|
||||
// GetJobStatus gets the status of all jobs with the given name
|
||||
// GetJobStatus gets the status of a job by name (TUI-specific convenience method)
|
||||
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
|
||||
// This method doesn't exist in internal queue, implement basic version
|
||||
task, err := tq.internal.GetTaskByName(jobName)
|
||||
task, err := tq.TaskQueue.GetTaskByName(jobName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -126,27 +88,26 @@ func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// RecordMetric records a metric for monitoring
|
||||
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
|
||||
_ = jobName // Parameter reserved for future use
|
||||
return tq.internal.RecordMetric(jobName, metric, value)
|
||||
}
|
||||
|
||||
// GetMetrics retrieves metrics for a job
|
||||
// GetMetrics retrieves metrics for a job (TUI-specific: currently returns empty)
|
||||
func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
|
||||
// This method doesn't exist in internal queue, return empty for now
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
|
||||
// ListDatasets retrieves available datasets
|
||||
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
|
||||
// ListDatasets retrieves available datasets (TUI-specific: currently returns empty)
|
||||
func (tq *TaskQueue) ListDatasets() ([]struct {
|
||||
Name string
|
||||
SizeBytes int64
|
||||
Location string
|
||||
LastAccess string
|
||||
}, error) {
|
||||
// This method doesn't exist in internal queue, return empty for now
|
||||
return []model.DatasetInfo{}, nil
|
||||
}
|
||||
|
||||
// CancelTask cancels a task by ID
|
||||
func (tq *TaskQueue) CancelTask(taskID string) error {
|
||||
return tq.internal.CancelTask(taskID)
|
||||
return []struct {
|
||||
Name string
|
||||
SizeBytes int64
|
||||
Location string
|
||||
LastAccess string
|
||||
}{}, nil
|
||||
}
|
||||
|
||||
// ListExperiments retrieves experiment list
|
||||
|
|
@ -154,7 +115,7 @@ func (tq *TaskQueue) ListExperiments() ([]string, error) {
|
|||
return tq.expManager.ListExperiments()
|
||||
}
|
||||
|
||||
// GetExperimentDetails retrieves experiment details
|
||||
// GetExperimentDetails retrieves formatted experiment details
|
||||
func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
|
||||
meta, err := tq.expManager.ReadMetadata(commitID)
|
||||
if err != nil {
|
||||
|
|
@ -185,7 +146,7 @@ func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
|
|||
|
||||
// Close closes the task queue
|
||||
func (tq *TaskQueue) Close() error {
|
||||
return tq.internal.Close()
|
||||
return tq.TaskQueue.Close()
|
||||
}
|
||||
|
||||
// MLServer is an alias for network.MLServer for backward compatibility
|
||||
|
|
|
|||
BIN
data_manager
Executable file
BIN
data_manager
Executable file
Binary file not shown.
10
go.mod
10
go.mod
|
|
@ -8,11 +8,9 @@ go 1.25.0
|
|||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.5.0
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/charmbracelet/bubbles v0.21.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/lib/pq v1.10.9
|
||||
|
|
@ -20,7 +18,6 @@ require (
|
|||
github.com/minio/minio-go/v7 v7.0.97
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/xeipuuv/gojsonschema v1.2.0
|
||||
github.com/zalando/go-keyring v0.2.6
|
||||
golang.org/x/crypto v0.45.0
|
||||
|
|
@ -28,6 +25,12 @@ require (
|
|||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
// Test-only dependencies
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require (
|
||||
al.essio.dev/pkg/shellescape v1.6.0 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
|
|
@ -46,7 +49,6 @@ require (
|
|||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-ini/ini v1.67.0 // indirect
|
||||
github.com/godbus/dbus/v5 v5.2.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
|
|
|
|||
14
go.sum
14
go.sum
|
|
@ -52,12 +52,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
||||
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
|
|
@ -107,12 +103,6 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
|||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
|
||||
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
|
@ -181,9 +171,5 @@ google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j
|
|||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
|||
144
internal/storage/dataset.go
Normal file
144
internal/storage/dataset.go
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
// Package storage provides storage abstractions for datasets and transfer tracking.
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// DatasetInfo contains information about a dataset.
|
||||
type DatasetInfo struct {
|
||||
Name string `json:"name"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
Location string `json:"location"` // "nas" or "ml"
|
||||
LastAccess time.Time `json:"last_access"`
|
||||
}
|
||||
|
||||
// DatasetStore manages dataset metadata and transfer tracking.
|
||||
type DatasetStore struct {
|
||||
client redis.UniversalClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewDatasetStore creates a new DatasetStore with the given Redis client.
|
||||
func NewDatasetStore(client redis.UniversalClient) *DatasetStore {
|
||||
return &DatasetStore{
|
||||
client: client,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewDatasetStoreWithContext creates a new DatasetStore with a custom context.
|
||||
func NewDatasetStoreWithContext(client redis.UniversalClient, ctx context.Context) *DatasetStore {
|
||||
return &DatasetStore{
|
||||
client: client,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// datasetKey returns the Redis key for dataset info.
|
||||
func datasetKey(name string) string {
|
||||
return fmt.Sprintf("ml:dataset:%s", name)
|
||||
}
|
||||
|
||||
// transferKey returns the Redis key for transfer tracking.
|
||||
func transferKey(datasetName string) string {
|
||||
return fmt.Sprintf("ml:data:transfer:%s", datasetName)
|
||||
}
|
||||
|
||||
// RecordTransferStart records the start of a dataset transfer.
|
||||
func (s *DatasetStore) RecordTransferStart(ctx context.Context, datasetName, jobName string, sizeBytes int64) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
return s.client.HSet(ctx, transferKey(datasetName),
|
||||
"status", "transferring",
|
||||
"job_name", jobName,
|
||||
"size_bytes", sizeBytes,
|
||||
"started_at", time.Now().Unix(),
|
||||
).Err()
|
||||
}
|
||||
|
||||
// RecordTransferComplete records the successful completion of a dataset transfer.
|
||||
func (s *DatasetStore) RecordTransferComplete(ctx context.Context, datasetName string, duration time.Duration) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
return s.client.HSet(ctx, transferKey(datasetName),
|
||||
"status", "completed",
|
||||
"completed_at", time.Now().Unix(),
|
||||
"duration_seconds", duration.Seconds(),
|
||||
).Err()
|
||||
}
|
||||
|
||||
// RecordTransferFailure records a failed dataset transfer.
|
||||
func (s *DatasetStore) RecordTransferFailure(ctx context.Context, datasetName string, transferErr error) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
return s.client.HSet(ctx, transferKey(datasetName),
|
||||
"status", "failed",
|
||||
"error", transferErr.Error(),
|
||||
).Err()
|
||||
}
|
||||
|
||||
// SaveDatasetInfo saves dataset metadata to Redis.
|
||||
func (s *DatasetStore) SaveDatasetInfo(ctx context.Context, info DatasetInfo) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal dataset info: %w", err)
|
||||
}
|
||||
return s.client.Set(ctx, datasetKey(info.Name), data, 0).Err()
|
||||
}
|
||||
|
||||
// GetDatasetInfo retrieves dataset metadata from Redis.
|
||||
func (s *DatasetStore) GetDatasetInfo(ctx context.Context, name string) (*DatasetInfo, error) {
|
||||
if s.client == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := s.client.Get(ctx, datasetKey(name)).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset info: %w", err)
|
||||
}
|
||||
|
||||
var info DatasetInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal dataset info: %w", err)
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// UpdateLastAccess updates the last access time for a dataset.
|
||||
func (s *DatasetStore) UpdateLastAccess(ctx context.Context, name string) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
info, err := s.GetDatasetInfo(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info == nil {
|
||||
return nil // No record to update
|
||||
}
|
||||
|
||||
info.LastAccess = time.Now()
|
||||
return s.SaveDatasetInfo(ctx, *info)
|
||||
}
|
||||
|
||||
// DeleteDatasetInfo removes dataset metadata from Redis.
|
||||
func (s *DatasetStore) DeleteDatasetInfo(ctx context.Context, name string) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
return s.client.Del(ctx, datasetKey(name)).Err()
|
||||
}
|
||||
|
|
@ -9,7 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Migrator handles migration from Redis to SQLite
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/ws"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
|
|
@ -22,7 +25,11 @@ func setupTestServer(t *testing.T) string {
|
|||
authConfig := &auth.Config{Enabled: false}
|
||||
expManager := experiment.NewManager(t.TempDir())
|
||||
|
||||
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
|
||||
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
|
||||
// Create listener to get actual port
|
||||
listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/ws"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
|
|
@ -37,7 +40,10 @@ func startWSBackendServer(t *testing.T) *httptest.Server {
|
|||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
authConfig := &auth.Config{Enabled: false}
|
||||
expManager := experiment.NewManager(t.TempDir())
|
||||
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
|
||||
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
|
||||
|
||||
srv := httptest.NewServer(h)
|
||||
t.Cleanup(srv.Close)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ import (
|
|||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
|
|
@ -44,6 +47,9 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
authCfg,
|
||||
logger,
|
||||
|
|
@ -54,6 +60,9 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
@ -135,6 +144,9 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
authCfg,
|
||||
logger,
|
||||
|
|
@ -145,6 +157,9 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
@ -231,6 +246,9 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|||
|
||||
logger := logging.NewLogger(0, false)
|
||||
authCfg := &auth.Config{Enabled: false}
|
||||
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
|
||||
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
||||
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
||||
wsHandler := wspkg.NewHandler(
|
||||
authCfg,
|
||||
logger,
|
||||
|
|
@ -241,6 +259,9 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
|
|||
|
|
@ -69,6 +69,9 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
|
|||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
nil, // jobsHandler
|
||||
nil, // jupyterHandler
|
||||
nil, // datasetsHandler
|
||||
)
|
||||
server := httptest.NewServer(handler)
|
||||
return server, tq, expManager, s, db
|
||||
|
|
@ -595,6 +598,9 @@ func setupWSIntegrationServer(t *testing.T) (
|
|||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
nil, // jobsHandler
|
||||
nil, // jupyterHandler
|
||||
nil, // datasetsHandler
|
||||
)
|
||||
// Setup test server
|
||||
server := httptest.NewServer(handler)
|
||||
|
|
|
|||
Loading…
Reference in a new issue