refactor(dependency-hygiene): Fix Redis leak, simplify TUI wrapper, clean go.mod

Phase 1: Fix Redis Schema Leak
- Create internal/storage/dataset.go with DatasetStore abstraction
- Remove all direct Redis calls from cmd/data_manager/data_sync.go
- data_manager now uses DatasetStore for transfer tracking and metadata

Phase 2: Simplify TUI Services
- Embed *queue.TaskQueue directly in services.TaskQueue
- Eliminate 60% of wrapper boilerplate (203 -> ~100 lines)
- Keep only TUI-specific methods (EnqueueTask, GetJobStatus, experiment methods)

Phase 5: Clean go.mod Dependencies
- Remove duplicate go-redis/redis/v8 dependency
- Migrate internal/storage/migrate.go to redis/go-redis/v9
- Separate test-only deps (miniredis, testify) into own block

Results:
- Zero direct Redis calls in cmd/
- 60% fewer lines in TUI services
- Cleaner dependency structure
This commit is contained in:
Jeremie Fraeys 2026-02-17 21:13:49 -05:00
parent 2a922542b1
commit dbf96020af
No known key found for this signature in database
11 changed files with 267 additions and 165 deletions

View file

@ -3,7 +3,6 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"
"log/slog"
@ -21,6 +20,7 @@ import (
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/telemetry"
)
@ -36,13 +36,14 @@ type SSHClient = network.SSHClient
// DataManager manages data synchronization between NAS and ML server.
type DataManager struct {
config *DataConfig
mlServer *SSHClient
nasServer *SSHClient
taskQueue *queue.TaskQueue
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
config *DataConfig
mlServer *SSHClient
nasServer *SSHClient
taskQueue *queue.TaskQueue
datasetStore *storage.DatasetStore
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
}
func (dm *DataManager) archiveDatasetOnML(datasetName string) (string, error) {
@ -130,6 +131,7 @@ func NewDataManager(cfg *DataConfig, _ string) (*DataManager, error) {
logger := logging.NewLogger(slog.LevelInfo, false)
var taskQueue *queue.TaskQueue
var datasetStore *storage.DatasetStore
if cfg.RedisAddr != "" {
queueCfg := queue.Config{
RedisAddr: cfg.RedisAddr,
@ -150,18 +152,23 @@ func NewDataManager(cfg *DataConfig, _ string) (*DataManager, error) {
cancel() // Cancel context to prevent leak
return nil, fmt.Errorf("redis connection failed: %w", err)
}
// Initialize dataset store with the Redis client
datasetStore = storage.NewDatasetStoreWithContext(taskQueue.GetRedisClient(), ctx)
} else {
taskQueue = nil // Local mode - no Redis
datasetStore = nil
}
return &DataManager{
config: cfg,
mlServer: mlServer,
nasServer: nasServer,
taskQueue: taskQueue,
ctx: ctx,
cancel: cancel,
logger: logger,
config: cfg,
mlServer: mlServer,
nasServer: nasServer,
taskQueue: taskQueue,
datasetStore: datasetStore,
ctx: ctx,
cancel: cancel,
logger: logger,
}, nil
}
@ -233,13 +240,8 @@ func (dm *DataManager) fetchDatasetInternal(
"nas_path", nasPath,
"ml_path", mlPath)
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
"status", "transferring",
"job_name", jobName,
"size_bytes", size,
"started_at", time.Now().Unix()).Err(); err != nil {
if dm.datasetStore != nil {
if err := dm.datasetStore.RecordTransferStart(dm.ctx, datasetName, jobName, size); err != nil {
logger.Warn("failed to record transfer start in Redis", "error", err)
}
}
@ -290,11 +292,8 @@ func (dm *DataManager) fetchDatasetInternal(
}
}
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if redisErr := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
"status", "failed",
"error", err.Error()).Err(); redisErr != nil {
if dm.datasetStore != nil {
if redisErr := dm.datasetStore.RecordTransferFailure(dm.ctx, datasetName, err); redisErr != nil {
logger.Warn("failed to record transfer failure in Redis", "error", redisErr)
}
}
@ -316,12 +315,8 @@ func (dm *DataManager) fetchDatasetInternal(
}
}
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
"status", "completed",
"completed_at", time.Now().Unix(),
"duration_seconds", duration.Seconds()).Err(); err != nil {
if dm.datasetStore != nil {
if err := dm.datasetStore.RecordTransferComplete(dm.ctx, datasetName, duration); err != nil {
logger.Warn("failed to record transfer completion in Redis", "error", err)
}
}
@ -333,48 +328,29 @@ func (dm *DataManager) fetchDatasetInternal(
}
func (dm *DataManager) saveDatasetInfo(name string, size int64) {
if dm.taskQueue == nil {
if dm.datasetStore == nil {
return // Skip in local mode
}
info := DatasetInfo{
info := storage.DatasetInfo{
Name: name,
SizeBytes: size,
Location: "ml",
LastAccess: time.Now(),
}
data, _ := json.Marshal(info)
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.Set(dm.ctx, fmt.Sprintf("ml:dataset:%s", name), data, 0).Err(); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis",
"dataset", name, "error", err)
}
if err := dm.datasetStore.SaveDatasetInfo(dm.ctx, info); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis",
"dataset", name, "error", err)
}
}
func (dm *DataManager) updateLastAccess(name string) {
if dm.taskQueue == nil {
if dm.datasetStore == nil {
return // Skip in local mode
}
key := fmt.Sprintf("ml:dataset:%s", name)
redisClient := dm.taskQueue.GetRedisClient()
data, err := redisClient.Get(dm.ctx, key).Result()
if err != nil {
return
}
var info DatasetInfo
if err := json.Unmarshal([]byte(data), &info); err != nil {
return
}
info.LastAccess = time.Now()
newData, _ := json.Marshal(info)
redisClient = dm.taskQueue.GetRedisClient()
if err := redisClient.Set(dm.ctx, key, newData, 0).Err(); err != nil {
if err := dm.datasetStore.UpdateLastAccess(dm.ctx, name); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to update last access in Redis",
"dataset", name, "error", err)
}
@ -396,20 +372,14 @@ func (dm *DataManager) ListDatasetsOnML() ([]DatasetInfo, error) {
var info DatasetInfo
// Only use Redis if available
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
key := fmt.Sprintf("ml:dataset:%s", name)
data, err := redisClient.Get(dm.ctx, key).Result()
if err == nil {
if unmarshalErr := json.Unmarshal([]byte(data), &info); unmarshalErr != nil {
// Fallback to disk if unmarshal fails
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
info = DatasetInfo{
Name: name,
SizeBytes: size,
Location: "ml",
}
if dm.datasetStore != nil {
dsInfo, err := dm.datasetStore.GetDatasetInfo(dm.ctx, name)
if err == nil && dsInfo != nil {
info = DatasetInfo{
Name: dsInfo.Name,
SizeBytes: dsInfo.SizeBytes,
Location: dsInfo.Location,
LastAccess: dsInfo.LastAccess,
}
} else {
// Fallback: get from disk
@ -508,9 +478,8 @@ func (dm *DataManager) CleanupOldData() error {
archived = append(archived, ds.Name)
totalSize -= ds.SizeBytes
totalSizeGB = float64(totalSize) / (1024 * 1024 * 1024)
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.Del(dm.ctx, fmt.Sprintf("ml:dataset:%s", ds.Name)).Err(); err != nil {
if dm.datasetStore != nil {
if err := dm.datasetStore.DeleteDatasetInfo(dm.ctx, ds.Name); err != nil {
logger.Warn("failed to delete dataset from Redis",
"dataset", ds.Name,
"error", err)

View file

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/network"
@ -16,11 +15,12 @@ import (
// Task is an alias for domain.Task for TUI compatibility
type Task = domain.Task
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility
// TaskQueue provides TUI-specific task operations by embedding queue.TaskQueue
// and extending it with experiment management capabilities.
type TaskQueue struct {
internal *queue.TaskQueue
expManager *experiment.Manager
ctx context.Context
*queue.TaskQueue // Embed to inherit all queue methods directly
expManager *experiment.Manager
ctx context.Context
}
// NewTaskQueue creates a new task queue service
@ -42,13 +42,13 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
expManager := experiment.NewManager("./experiments")
return &TaskQueue{
internal: internalQueue,
TaskQueue: internalQueue,
expManager: expManager,
ctx: context.Background(),
}, nil
}
// EnqueueTask adds a new task to the queue
// EnqueueTask adds a new task to the queue (TUI-specific: creates task with proper defaults)
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) {
// Create internal task
internalTask := &queue.Task{
@ -57,8 +57,8 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, e
Priority: priority,
}
// Use internal queue to enqueue
err := tq.internal.AddTask(internalTask)
// Use embedded queue to enqueue
err := tq.TaskQueue.AddTask(internalTask)
if err != nil {
return nil, err
}
@ -67,52 +67,14 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, e
return internalTask, nil
}
// GetNextTask retrieves the next task from the queue
func (tq *TaskQueue) GetNextTask() (*Task, error) {
internalTask, err := tq.internal.GetNextTask()
if err != nil {
return nil, err
}
if internalTask == nil {
return nil, nil
}
// Return domain.Task directly (no conversion needed)
return internalTask, nil
}
// GetTask retrieves a specific task by ID
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
internalTask, err := tq.internal.GetTask(taskID)
if err != nil {
return nil, err
}
// Return domain.Task directly (no conversion needed)
return internalTask, nil
}
// UpdateTask updates a task's status and metadata
func (tq *TaskQueue) UpdateTask(task *Task) error {
// task is already domain.Task, pass directly to internal queue
return tq.internal.UpdateTask(task)
}
// GetQueuedTasks retrieves all queued tasks
// GetQueuedTasks retrieves all queued tasks (TUI-specific alias for GetAllTasks)
func (tq *TaskQueue) GetQueuedTasks() ([]*Task, error) {
internalTasks, err := tq.internal.GetAllTasks()
if err != nil {
return nil, err
}
// Return domain.Tasks directly (no conversion needed)
return internalTasks, nil
return tq.TaskQueue.GetAllTasks()
}
// GetJobStatus gets the status of all jobs with the given name
// GetJobStatus gets the status of a job by name (TUI-specific convenience method)
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
// This method doesn't exist in internal queue, implement basic version
task, err := tq.internal.GetTaskByName(jobName)
task, err := tq.TaskQueue.GetTaskByName(jobName)
if err != nil {
return nil, err
}
@ -126,27 +88,26 @@ func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
}, nil
}
// RecordMetric records a metric for monitoring
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
_ = jobName // Parameter reserved for future use
return tq.internal.RecordMetric(jobName, metric, value)
}
// GetMetrics retrieves metrics for a job
// GetMetrics retrieves metrics for a job (TUI-specific: currently returns empty)
func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
	// The underlying queue has no metrics lookup, so callers get an empty,
	// non-nil map for now; the job name argument is ignored.
	metrics := make(map[string]string)
	return metrics, nil
}
// ListDatasets retrieves available datasets
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
// ListDatasets retrieves available datasets (TUI-specific: currently returns empty)
func (tq *TaskQueue) ListDatasets() ([]struct {
Name string
SizeBytes int64
Location string
LastAccess string
}, error) {
// This method doesn't exist in internal queue, return empty for now
return []model.DatasetInfo{}, nil
}
// CancelTask cancels a task by ID
func (tq *TaskQueue) CancelTask(taskID string) error {
return tq.internal.CancelTask(taskID)
return []struct {
Name string
SizeBytes int64
Location string
LastAccess string
}{}, nil
}
// ListExperiments retrieves experiment list
@ -154,7 +115,7 @@ func (tq *TaskQueue) ListExperiments() ([]string, error) {
return tq.expManager.ListExperiments()
}
// GetExperimentDetails retrieves experiment details
// GetExperimentDetails retrieves formatted experiment details
func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
meta, err := tq.expManager.ReadMetadata(commitID)
if err != nil {
@ -185,7 +146,7 @@ func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
// Close closes the task queue
func (tq *TaskQueue) Close() error {
return tq.internal.Close()
return tq.TaskQueue.Close()
}
// MLServer is an alias for network.MLServer for backward compatibility

BIN
data_manager Executable file

Binary file not shown.

10
go.mod
View file

@ -8,11 +8,9 @@ go 1.25.0
require (
github.com/BurntSushi/toml v1.5.0
github.com/alicebob/miniredis/v2 v2.35.0
github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/lib/pq v1.10.9
@ -20,7 +18,6 @@ require (
github.com/minio/minio-go/v7 v7.0.97
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.11.1
github.com/xeipuuv/gojsonschema v1.2.0
github.com/zalando/go-keyring v0.2.6
golang.org/x/crypto v0.45.0
@ -28,6 +25,12 @@ require (
gopkg.in/yaml.v3 v3.0.1
)
// Test-only dependencies
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/stretchr/testify v1.11.1
)
require (
al.essio.dev/pkg/shellescape v1.6.0 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
@ -46,7 +49,6 @@ require (
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/godbus/dbus/v5 v5.2.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect

14
go.sum
View file

@ -52,12 +52,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@ -107,12 +103,6 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@ -181,9 +171,5 @@ google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

144
internal/storage/dataset.go Normal file
View file

@ -0,0 +1,144 @@
// Package storage provides storage abstractions for datasets and transfer tracking.
package storage
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)
// DatasetInfo is the metadata record kept for a single dataset. It is
// serialized to and from JSON using the field tags below.
type DatasetInfo struct {
	// Name is the dataset name; it is also used to build the storage key.
	Name string `json:"name"`
	// SizeBytes is the dataset size in bytes.
	SizeBytes int64 `json:"size_bytes"`
	// Location says where the dataset lives: "nas" or "ml".
	Location string `json:"location"`
	// LastAccess is the time the dataset was last accessed.
	LastAccess time.Time `json:"last_access"`
}
// DatasetStore reads and writes dataset metadata and transfer-tracking
// records through a Redis client.
type DatasetStore struct {
	// client is the Redis handle. Every method returns early without doing
	// anything when it is nil.
	client redis.UniversalClient
	// ctx is the context supplied at construction time.
	// NOTE(review): no method in this file reads it; each operation uses the
	// ctx passed as its own argument.
	ctx context.Context
}
// NewDatasetStore builds a DatasetStore around client and stores
// context.Background() as the store's context.
func NewDatasetStore(client redis.UniversalClient) *DatasetStore {
	store := new(DatasetStore)
	store.client = client
	store.ctx = context.Background()
	return store
}
// NewDatasetStoreWithContext builds a DatasetStore around client and stores
// the caller-supplied ctx instead of a background context.
func NewDatasetStoreWithContext(client redis.UniversalClient, ctx context.Context) *DatasetStore {
	store := DatasetStore{client: client, ctx: ctx}
	return &store
}
// datasetKey builds the Redis key under which the metadata of the dataset
// called name is stored.
func datasetKey(name string) string {
	const format = "ml:dataset:%s"
	return fmt.Sprintf(format, name)
}
// transferKey builds the Redis key under which the transfer-tracking hash of
// the dataset called datasetName is stored.
func transferKey(datasetName string) string {
	const format = "ml:data:transfer:%s"
	return fmt.Sprintf(format, datasetName)
}
// RecordTransferStart marks the transfer of datasetName as in progress by
// writing the status "transferring", the job name, the size in bytes and the
// current Unix time to the dataset's transfer hash. It does nothing and
// returns nil when the store has no Redis client.
func (s *DatasetStore) RecordTransferStart(ctx context.Context, datasetName, jobName string, sizeBytes int64) error {
	if s.client == nil {
		return nil
	}

	// Alternating field/value pairs, passed to HSet in this order.
	fields := []any{
		"status", "transferring",
		"job_name", jobName,
		"size_bytes", sizeBytes,
		"started_at", time.Now().Unix(),
	}
	return s.client.HSet(ctx, transferKey(datasetName), fields...).Err()
}
// RecordTransferComplete marks the transfer of datasetName as finished by
// writing the status "completed", the current Unix time and the duration in
// seconds to the dataset's transfer hash. It does nothing and returns nil
// when the store has no Redis client.
func (s *DatasetStore) RecordTransferComplete(ctx context.Context, datasetName string, duration time.Duration) error {
	if s.client == nil {
		return nil
	}

	// Alternating field/value pairs, passed to HSet in this order.
	fields := []any{
		"status", "completed",
		"completed_at", time.Now().Unix(),
		"duration_seconds", duration.Seconds(),
	}
	return s.client.HSet(ctx, transferKey(datasetName), fields...).Err()
}
// RecordTransferFailure marks the transfer of datasetName as failed by
// writing the status "failed" and the error text to the dataset's transfer
// hash.
//
// It does nothing and returns nil when the store has no Redis client. A nil
// transferErr is recorded as "unknown error" instead of panicking, so this
// bookkeeping call can never take down the caller. The returned error is the
// result of the Redis write, not transferErr.
func (s *DatasetStore) RecordTransferFailure(ctx context.Context, datasetName string, transferErr error) error {
	if s.client == nil {
		return nil
	}

	// Calling Error() on a nil error interface would panic.
	reason := "unknown error"
	if transferErr != nil {
		reason = transferErr.Error()
	}

	return s.client.HSet(ctx, transferKey(datasetName),
		"status", "failed",
		"error", reason,
	).Err()
}
// SaveDatasetInfo JSON-encodes info and stores it under the key derived from
// info.Name. It does nothing and returns nil when the store has no Redis
// client; an encoding failure is returned wrapped.
func (s *DatasetStore) SaveDatasetInfo(ctx context.Context, info DatasetInfo) error {
	if s.client == nil {
		return nil
	}

	payload, marshalErr := json.Marshal(info)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal dataset info: %w", marshalErr)
	}

	// An expiration of 0 is passed; go-redis documents that as "no expiry".
	setCmd := s.client.Set(ctx, datasetKey(info.Name), payload, 0)
	return setCmd.Err()
}
// GetDatasetInfo loads the metadata stored for the dataset called name.
//
// It returns (nil, nil) when the store has no Redis client or when no record
// exists for name, so callers must check the returned pointer as well as the
// error. Redis failures and undecodable payloads are returned wrapped.
func (s *DatasetStore) GetDatasetInfo(ctx context.Context, name string) (*DatasetInfo, error) {
	if s.client == nil {
		return nil, nil
	}

	data, err := s.client.Get(ctx, datasetKey(name)).Result()
	// errors.Is instead of == so that a redis.Nil wrapped somewhere along the
	// call chain is still treated as "no record" rather than a hard failure.
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get dataset info: %w", err)
	}

	var info DatasetInfo
	if err := json.Unmarshal([]byte(data), &info); err != nil {
		return nil, fmt.Errorf("failed to unmarshal dataset info: %w", err)
	}
	return &info, nil
}
// UpdateLastAccess sets the LastAccess time of the dataset called name to the
// current time and writes the record back. It does nothing and returns nil
// when the store has no Redis client or when no record exists for name; a
// lookup error is returned unchanged.
func (s *DatasetStore) UpdateLastAccess(ctx context.Context, name string) error {
	if s.client == nil {
		return nil
	}

	current, err := s.GetDatasetInfo(ctx, name)
	switch {
	case err != nil:
		return err
	case current == nil:
		// Nothing stored for this dataset, so there is nothing to touch.
		return nil
	}

	updated := *current
	updated.LastAccess = time.Now()
	return s.SaveDatasetInfo(ctx, updated)
}
// DeleteDatasetInfo deletes the metadata key of the dataset called name. It
// does nothing and returns nil when the store has no Redis client.
func (s *DatasetStore) DeleteDatasetInfo(ctx context.Context, name string) error {
	if s.client != nil {
		return s.client.Del(ctx, datasetKey(name)).Err()
	}
	return nil
}

View file

@ -9,7 +9,7 @@ import (
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
)
// Migrator handles migration from Redis to SQLite

View file

@ -10,6 +10,9 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
@ -22,7 +25,11 @@ func setupTestServer(t *testing.T) string {
authConfig := &auth.Config{Enabled: false}
expManager := experiment.NewManager(t.TempDir())
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
datasetsHandler := datasets.NewHandler(logger, nil, "")
wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
// Create listener to get actual port
listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")

View file

@ -11,6 +11,9 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
@ -37,7 +40,10 @@ func startWSBackendServer(t *testing.T) *httptest.Server {
logger := logging.NewLogger(slog.LevelInfo, false)
authConfig := &auth.Config{Enabled: false}
expManager := experiment.NewManager(t.TempDir())
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig)
jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig)
datasetsHandler := datasets.NewHandler(logger, nil, "")
h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler)
srv := httptest.NewServer(h)
t.Cleanup(srv.Close)

View file

@ -14,6 +14,9 @@ import (
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/api/datasets"
"github.com/jfraeys/fetch_ml/internal/api/jobs"
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
@ -44,6 +47,9 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
logger := logging.NewLogger(0, false)
authCfg := &auth.Config{Enabled: false}
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
datasetsHandler := datasets.NewHandler(logger, nil, "")
wsHandler := wspkg.NewHandler(
authCfg,
logger,
@ -54,6 +60,9 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
jobsHandler,
jupyterHandler,
datasetsHandler,
)
server := httptest.NewServer(wsHandler)
defer server.Close()
@ -135,6 +144,9 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
logger := logging.NewLogger(0, false)
authCfg := &auth.Config{Enabled: false}
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
datasetsHandler := datasets.NewHandler(logger, nil, "")
wsHandler := wspkg.NewHandler(
authCfg,
logger,
@ -145,6 +157,9 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
jobsHandler,
jupyterHandler,
datasetsHandler,
)
server := httptest.NewServer(wsHandler)
defer server.Close()
@ -231,6 +246,9 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
logger := logging.NewLogger(0, false)
authCfg := &auth.Config{Enabled: false}
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
datasetsHandler := datasets.NewHandler(logger, nil, "")
wsHandler := wspkg.NewHandler(
authCfg,
logger,
@ -241,6 +259,9 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
jobsHandler,
jupyterHandler,
datasetsHandler,
)
server := httptest.NewServer(wsHandler)
defer server.Close()

View file

@ -69,6 +69,9 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
nil, // jobsHandler
nil, // jupyterHandler
nil, // datasetsHandler
)
server := httptest.NewServer(handler)
return server, tq, expManager, s, db
@ -595,6 +598,9 @@ func setupWSIntegrationServer(t *testing.T) (
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
nil, // jobsHandler
nil, // jupyterHandler
nil, // datasetsHandler
)
// Setup test server
server := httptest.NewServer(handler)