diff --git a/cmd/data_manager/data_sync.go b/cmd/data_manager/data_sync.go index fadbd26..24c4cdf 100644 --- a/cmd/data_manager/data_sync.go +++ b/cmd/data_manager/data_sync.go @@ -3,7 +3,6 @@ package main import ( "context" - "encoding/json" "fmt" "log" "log/slog" @@ -21,6 +20,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/network" "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/telemetry" ) @@ -36,13 +36,14 @@ type SSHClient = network.SSHClient // DataManager manages data synchronization between NAS and ML server. type DataManager struct { - config *DataConfig - mlServer *SSHClient - nasServer *SSHClient - taskQueue *queue.TaskQueue - ctx context.Context - cancel context.CancelFunc - logger *logging.Logger + config *DataConfig + mlServer *SSHClient + nasServer *SSHClient + taskQueue *queue.TaskQueue + datasetStore *storage.DatasetStore + ctx context.Context + cancel context.CancelFunc + logger *logging.Logger } func (dm *DataManager) archiveDatasetOnML(datasetName string) (string, error) { @@ -130,6 +131,7 @@ func NewDataManager(cfg *DataConfig, _ string) (*DataManager, error) { logger := logging.NewLogger(slog.LevelInfo, false) var taskQueue *queue.TaskQueue + var datasetStore *storage.DatasetStore if cfg.RedisAddr != "" { queueCfg := queue.Config{ RedisAddr: cfg.RedisAddr, @@ -150,18 +152,23 @@ func NewDataManager(cfg *DataConfig, _ string) (*DataManager, error) { cancel() // Cancel context to prevent leak return nil, fmt.Errorf("redis connection failed: %w", err) } + + // Initialize dataset store with the Redis client + datasetStore = storage.NewDatasetStoreWithContext(taskQueue.GetRedisClient(), ctx) } else { taskQueue = nil // Local mode - no Redis + datasetStore = nil } return &DataManager{ - config: cfg, - mlServer: mlServer, - nasServer: nasServer, - taskQueue: taskQueue, - ctx: ctx, - cancel: cancel, - logger: logger, + config: cfg, + mlServer: mlServer, + nasServer: nasServer, + taskQueue: taskQueue, + datasetStore: datasetStore, + ctx: ctx, + cancel: cancel, + logger: logger, }, nil } @@ -233,13 +240,8 @@ func (dm *DataManager) fetchDatasetInternal( "nas_path", nasPath, "ml_path", mlPath) - if dm.taskQueue != nil { - redisClient := dm.taskQueue.GetRedisClient() - if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName), - "status", "transferring", - "job_name", jobName, - "size_bytes", size, - "started_at", time.Now().Unix()).Err(); err != nil { + if dm.datasetStore != nil { + if err := dm.datasetStore.RecordTransferStart(dm.ctx, datasetName, jobName, size); err != nil { logger.Warn("failed to record transfer start in Redis", "error", err) } } @@ -290,11 +292,8 @@ func (dm *DataManager) fetchDatasetInternal( } } - if dm.taskQueue != nil { - redisClient := dm.taskQueue.GetRedisClient() - if redisErr := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName), - "status", "failed", - "error", err.Error()).Err(); redisErr != nil { + if dm.datasetStore != nil { + if redisErr := dm.datasetStore.RecordTransferFailure(dm.ctx, datasetName, err); redisErr != nil { logger.Warn("failed to record transfer failure in Redis", "error", redisErr) } } @@ -316,12 +315,8 @@ func (dm *DataManager) fetchDatasetInternal( } } - if dm.taskQueue != nil { - redisClient := dm.taskQueue.GetRedisClient() - if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName), - "status", "completed", - "completed_at", time.Now().Unix(), - "duration_seconds", duration.Seconds()).Err(); err != nil { + if dm.datasetStore != nil { + if err := dm.datasetStore.RecordTransferComplete(dm.ctx, datasetName, duration); err != nil { logger.Warn("failed to record transfer completion in Redis", "error", err) } } @@ -333,48 +328,29 @@ func (dm *DataManager) fetchDatasetInternal( } func (dm *DataManager) saveDatasetInfo(name string, size int64) { - if dm.taskQueue == nil { + if dm.datasetStore == nil { return // Skip in local mode } - info := DatasetInfo{ + info := storage.DatasetInfo{ Name: name, SizeBytes: size, Location: "ml", LastAccess: time.Now(), } - data, _ := json.Marshal(info) - if dm.taskQueue != nil { - redisClient := dm.taskQueue.GetRedisClient() - if err := redisClient.Set(dm.ctx, fmt.Sprintf("ml:dataset:%s", name), data, 0).Err(); err != nil { - dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis", - "dataset", name, "error", err) - } + if err := dm.datasetStore.SaveDatasetInfo(dm.ctx, info); err != nil { + dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis", + "dataset", name, "error", err) } } func (dm *DataManager) updateLastAccess(name string) { - if dm.taskQueue == nil { + if dm.datasetStore == nil { return // Skip in local mode } - key := fmt.Sprintf("ml:dataset:%s", name) - redisClient := dm.taskQueue.GetRedisClient() - data, err := redisClient.Get(dm.ctx, key).Result() - if err != nil { - return - } - - var info DatasetInfo - if err := json.Unmarshal([]byte(data), &info); err != nil { - return - } - - info.LastAccess = time.Now() - newData, _ := json.Marshal(info) - redisClient = dm.taskQueue.GetRedisClient() - if err := redisClient.Set(dm.ctx, key, newData, 0).Err(); err != nil { + if err := dm.datasetStore.UpdateLastAccess(dm.ctx, name); err != nil { dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to update last access in Redis", "dataset", name, "error", err) } @@ -396,20 +372,14 @@ func (dm *DataManager) ListDatasetsOnML() ([]DatasetInfo, error) { var info DatasetInfo // Only use Redis if available - if dm.taskQueue != nil { - redisClient := dm.taskQueue.GetRedisClient() - key := fmt.Sprintf("ml:dataset:%s", name) - data, err := redisClient.Get(dm.ctx, key).Result() - - if err == nil { - if unmarshalErr := json.Unmarshal([]byte(data), &info); unmarshalErr != nil { - // Fallback to disk if unmarshal fails - size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name)) - info = DatasetInfo{ - Name: name, - SizeBytes: size, - Location: "ml", - } + if dm.datasetStore != nil { + dsInfo, err := dm.datasetStore.GetDatasetInfo(dm.ctx, name) + if err == nil && dsInfo != nil { + info = DatasetInfo{ + Name: dsInfo.Name, + SizeBytes: dsInfo.SizeBytes, + Location: dsInfo.Location, + LastAccess: dsInfo.LastAccess, } } else { // Fallback: get from disk @@ -508,9 +478,8 @@ func (dm *DataManager) CleanupOldData() error { archived = append(archived, ds.Name) totalSize -= ds.SizeBytes totalSizeGB = float64(totalSize) / (1024 * 1024 * 1024) - if dm.taskQueue != nil { - redisClient := dm.taskQueue.GetRedisClient() - if err := redisClient.Del(dm.ctx, fmt.Sprintf("ml:dataset:%s", ds.Name)).Err(); err != nil { + if dm.datasetStore != nil { + if err := dm.datasetStore.DeleteDatasetInfo(dm.ctx, ds.Name); err != nil { logger.Warn("failed to delete dataset from Redis", "dataset", ds.Name, "error", err) diff --git a/cmd/tui/internal/services/services.go b/cmd/tui/internal/services/services.go index 7a114c7..cf51a47 100644 --- a/cmd/tui/internal/services/services.go +++ b/cmd/tui/internal/services/services.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/jfraeys/fetch_ml/cmd/tui/internal/config" - "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" "github.com/jfraeys/fetch_ml/internal/domain" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/network" @@ -16,11 +15,12 @@ import ( // Task is an alias for domain.Task for TUI compatibility type Task = domain.Task -// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility +// TaskQueue provides TUI-specific task operations by embedding queue.TaskQueue +// and extending it with experiment management capabilities. type TaskQueue struct { - internal *queue.TaskQueue - expManager *experiment.Manager - ctx context.Context + *queue.TaskQueue // Embed to inherit all queue methods directly + expManager *experiment.Manager + ctx context.Context } // NewTaskQueue creates a new task queue service @@ -42,13 +42,13 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) { expManager := experiment.NewManager("./experiments") return &TaskQueue{ - internal: internalQueue, + TaskQueue: internalQueue, expManager: expManager, ctx: context.Background(), }, nil } -// EnqueueTask adds a new task to the queue +// EnqueueTask adds a new task to the queue (TUI-specific: creates task with proper defaults) func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, error) { // Create internal task internalTask := &queue.Task{ @@ -57,8 +57,8 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, e Priority: priority, } - // Use internal queue to enqueue - err := tq.internal.AddTask(internalTask) + // Use embedded queue to enqueue + err := tq.TaskQueue.AddTask(internalTask) if err != nil { return nil, err } @@ -67,52 +67,14 @@ func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*Task, e return internalTask, nil } -// GetNextTask retrieves the next task from the queue -func (tq *TaskQueue) GetNextTask() (*Task, error) { - internalTask, err := tq.internal.GetNextTask() - if err != nil { - return nil, err - } - if internalTask == nil { - return nil, nil - } - - // Return domain.Task directly (no conversion needed) - return internalTask, nil -} - -// GetTask retrieves a specific task by ID -func (tq *TaskQueue) GetTask(taskID string) (*Task, error) { - internalTask, err := tq.internal.GetTask(taskID) - if err != nil { - return nil, err - } - - // Return domain.Task directly (no conversion needed) - return internalTask, nil -} - -// UpdateTask updates a task's status and metadata -func (tq *TaskQueue) UpdateTask(task *Task) error { - // task is already domain.Task, pass directly to internal queue - return tq.internal.UpdateTask(task) -} - -// GetQueuedTasks retrieves all queued tasks +// GetQueuedTasks retrieves all queued tasks (TUI-specific alias for GetAllTasks) func (tq *TaskQueue) GetQueuedTasks() ([]*Task, error) { - internalTasks, err := tq.internal.GetAllTasks() - if err != nil { - return nil, err - } - - // Return domain.Tasks directly (no conversion needed) - return internalTasks, nil + return tq.TaskQueue.GetAllTasks() } -// GetJobStatus gets the status of all jobs with the given name +// GetJobStatus gets the status of a job by name (TUI-specific convenience method) func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) { - // This method doesn't exist in internal queue, implement basic version - task, err := tq.internal.GetTaskByName(jobName) + task, err := tq.TaskQueue.GetTaskByName(jobName) if err != nil { return nil, err } @@ -126,27 +88,26 @@ func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) { }, nil } -// RecordMetric records a metric for monitoring -func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error { - _ = jobName // Parameter reserved for future use - return tq.internal.RecordMetric(jobName, metric, value) -} - -// GetMetrics retrieves metrics for a job +// GetMetrics retrieves metrics for a job (TUI-specific: currently returns empty) func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) { // This method doesn't exist in internal queue, return empty for now return map[string]string{}, nil } -// ListDatasets retrieves available datasets -func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) { +// ListDatasets retrieves available datasets (TUI-specific: currently returns empty) +func (tq *TaskQueue) ListDatasets() ([]struct { + Name string + SizeBytes int64 + Location string + LastAccess string +}, error) { // This method doesn't exist in internal queue, return empty for now - return []model.DatasetInfo{}, nil -} - -// CancelTask cancels a task by ID -func (tq *TaskQueue) CancelTask(taskID string) error { - return tq.internal.CancelTask(taskID) + return []struct { + Name string + SizeBytes int64 + Location string + LastAccess string + }{}, nil } // ListExperiments retrieves experiment list @@ -154,7 +115,7 @@ func (tq *TaskQueue) ListExperiments() ([]string, error) { return tq.expManager.ListExperiments() } -// GetExperimentDetails retrieves experiment details +// GetExperimentDetails retrieves formatted experiment details func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) { meta, err := tq.expManager.ReadMetadata(commitID) if err != nil { @@ -185,7 +146,7 @@ func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) { // Close closes the task queue func (tq *TaskQueue) Close() error { - return tq.internal.Close() + return tq.TaskQueue.Close() } // MLServer is an alias for network.MLServer for backward compatibility diff --git a/data_manager b/data_manager new file mode 100755 index 0000000..8915a34 Binary files /dev/null and b/data_manager differ diff --git a/go.mod b/go.mod index 81bd5ef..6a6803c 100644 --- a/go.mod +++ b/go.mod @@ -8,11 +8,9 @@ go 1.25.0 require ( github.com/BurntSushi/toml v1.5.0 - github.com/alicebob/miniredis/v2 v2.35.0 github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 - github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/lib/pq v1.10.9 @@ -20,7 +18,6 @@ require ( github.com/minio/minio-go/v7 v7.0.97 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.17.2 - github.com/stretchr/testify v1.11.1 github.com/xeipuuv/gojsonschema v1.2.0 github.com/zalando/go-keyring v0.2.6 golang.org/x/crypto v0.45.0 @@ -28,6 +25,12 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +// Test-only dependencies +require ( + github.com/alicebob/miniredis/v2 v2.35.0 + github.com/stretchr/testify v1.11.1 +) + require ( al.essio.dev/pkg/shellescape v1.6.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect @@ -46,7 +49,6 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/godbus/dbus/v5 v5.2.0 // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/go.sum b/go.sum index 8b9cfdf..147c82f 100644 --- a/go.sum +++ b/go.sum @@ -52,12 +52,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= -github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= -github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= -github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= -github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -107,12 +103,6 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= -github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= -github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -181,9 +171,5 @@ google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/storage/dataset.go b/internal/storage/dataset.go new file mode 100644 index 0000000..033417a --- /dev/null +++ b/internal/storage/dataset.go @@ -0,0 +1,144 @@ +// Package storage provides storage abstractions for datasets and transfer tracking. +package storage + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +// DatasetInfo contains information about a dataset. +type DatasetInfo struct { + Name string `json:"name"` + SizeBytes int64 `json:"size_bytes"` + Location string `json:"location"` // "nas" or "ml" + LastAccess time.Time `json:"last_access"` +} + +// DatasetStore manages dataset metadata and transfer tracking. +type DatasetStore struct { + client redis.UniversalClient + ctx context.Context +} + +// NewDatasetStore creates a new DatasetStore with the given Redis client. +func NewDatasetStore(client redis.UniversalClient) *DatasetStore { + return &DatasetStore{ + client: client, + ctx: context.Background(), + } +} + +// NewDatasetStoreWithContext creates a new DatasetStore with a custom context. +func NewDatasetStoreWithContext(client redis.UniversalClient, ctx context.Context) *DatasetStore { + return &DatasetStore{ + client: client, + ctx: ctx, + } +} + +// datasetKey returns the Redis key for dataset info. +func datasetKey(name string) string { + return fmt.Sprintf("ml:dataset:%s", name) +} + +// transferKey returns the Redis key for transfer tracking. +func transferKey(datasetName string) string { + return fmt.Sprintf("ml:data:transfer:%s", datasetName) +} + +// RecordTransferStart records the start of a dataset transfer. +func (s *DatasetStore) RecordTransferStart(ctx context.Context, datasetName, jobName string, sizeBytes int64) error { + if s.client == nil { + return nil + } + return s.client.HSet(ctx, transferKey(datasetName), + "status", "transferring", + "job_name", jobName, + "size_bytes", sizeBytes, + "started_at", time.Now().Unix(), + ).Err() +} + +// RecordTransferComplete records the successful completion of a dataset transfer. +func (s *DatasetStore) RecordTransferComplete(ctx context.Context, datasetName string, duration time.Duration) error { + if s.client == nil { + return nil + } + return s.client.HSet(ctx, transferKey(datasetName), + "status", "completed", + "completed_at", time.Now().Unix(), + "duration_seconds", duration.Seconds(), + ).Err() +} + +// RecordTransferFailure records a failed dataset transfer. +func (s *DatasetStore) RecordTransferFailure(ctx context.Context, datasetName string, transferErr error) error { + if s.client == nil { + return nil + } + return s.client.HSet(ctx, transferKey(datasetName), + "status", "failed", + "error", transferErr.Error(), + ).Err() +} + +// SaveDatasetInfo saves dataset metadata to Redis. +func (s *DatasetStore) SaveDatasetInfo(ctx context.Context, info DatasetInfo) error { + if s.client == nil { + return nil + } + data, err := json.Marshal(info) + if err != nil { + return fmt.Errorf("failed to marshal dataset info: %w", err) + } + return s.client.Set(ctx, datasetKey(info.Name), data, 0).Err() +} + +// GetDatasetInfo retrieves dataset metadata from Redis. +func (s *DatasetStore) GetDatasetInfo(ctx context.Context, name string) (*DatasetInfo, error) { + if s.client == nil { + return nil, nil + } + data, err := s.client.Get(ctx, datasetKey(name)).Result() + if err == redis.Nil { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get dataset info: %w", err) + } + + var info DatasetInfo + if err := json.Unmarshal([]byte(data), &info); err != nil { + return nil, fmt.Errorf("failed to unmarshal dataset info: %w", err) + } + return &info, nil +} + +// UpdateLastAccess updates the last access time for a dataset. +func (s *DatasetStore) UpdateLastAccess(ctx context.Context, name string) error { + if s.client == nil { + return nil + } + info, err := s.GetDatasetInfo(ctx, name) + if err != nil { + return err + } + if info == nil { + return nil // No record to update + } + + info.LastAccess = time.Now() + return s.SaveDatasetInfo(ctx, *info) +} + +// DeleteDatasetInfo removes dataset metadata from Redis. +func (s *DatasetStore) DeleteDatasetInfo(ctx context.Context, name string) error { + if s.client == nil { + return nil + } + return s.client.Del(ctx, datasetKey(name)).Err() +} diff --git a/internal/storage/migrate.go b/internal/storage/migrate.go index fa771fc..2fa5d98 100644 --- a/internal/storage/migrate.go +++ b/internal/storage/migrate.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) // Migrator handles migration from Redis to SQLite diff --git a/tests/e2e/websocket_e2e_test.go b/tests/e2e/websocket_e2e_test.go index 672b6d1..7a0bf1a 100644 --- a/tests/e2e/websocket_e2e_test.go +++ b/tests/e2e/websocket_e2e_test.go @@ -10,6 +10,9 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/api/datasets" + "github.com/jfraeys/fetch_ml/internal/api/jobs" + jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" @@ -22,7 +25,11 @@ func setupTestServer(t *testing.T) string { authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) + datasetsHandler := datasets.NewHandler(logger, nil, "") + + wsHandler := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) // Create listener to get actual port listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") diff --git a/tests/e2e/wss_reverse_proxy_e2e_test.go b/tests/e2e/wss_reverse_proxy_e2e_test.go index bc474b2..3dbc53a 100644 --- a/tests/e2e/wss_reverse_proxy_e2e_test.go +++ b/tests/e2e/wss_reverse_proxy_e2e_test.go @@ -11,6 +11,9 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/api/datasets" + "github.com/jfraeys/fetch_ml/internal/api/jobs" + jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" @@ -37,7 +40,10 @@ func startWSBackendServer(t *testing.T) *httptest.Server { logger := logging.NewLogger(slog.LevelInfo, false) authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) - h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil) + jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig) + jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) + datasetsHandler := datasets.NewHandler(logger, nil, "") + h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) srv := httptest.NewServer(h) t.Cleanup(srv.Close) diff --git a/tests/integration/websocket_queue_integration_test.go b/tests/integration/websocket_queue_integration_test.go index 8179b52..a9ba9ab 100644 --- a/tests/integration/websocket_queue_integration_test.go +++ b/tests/integration/websocket_queue_integration_test.go @@ -14,6 +14,9 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/api" + "github.com/jfraeys/fetch_ml/internal/api/datasets" + "github.com/jfraeys/fetch_ml/internal/api/jobs" + jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" wspkg "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" @@ -44,6 +47,9 @@ func TestWebSocketQueueEndToEnd(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) + jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) + datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( authCfg, logger, @@ -54,6 +60,9 @@ func TestWebSocketQueueEndToEnd(t *testing.T) { nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger + jobsHandler, + jupyterHandler, + datasetsHandler, ) server := httptest.NewServer(wsHandler) defer server.Close() @@ -135,6 +144,9 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) + jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) + datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( authCfg, logger, @@ -145,6 +157,9 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) { nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger + jobsHandler, + jupyterHandler, + datasetsHandler, ) server := httptest.NewServer(wsHandler) defer server.Close() @@ -231,6 +246,9 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} + jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg) + jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg) + datasetsHandler := datasets.NewHandler(logger, nil, "") wsHandler := wspkg.NewHandler( authCfg, logger, @@ -241,6 +259,9 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger + jobsHandler, + jupyterHandler, + datasetsHandler, ) server := httptest.NewServer(wsHandler) defer server.Close() diff --git a/tests/integration/ws_handler_integration_test.go b/tests/integration/ws_handler_integration_test.go index e922f56..a798d3b 100644 --- a/tests/integration/ws_handler_integration_test.go +++ b/tests/integration/ws_handler_integration_test.go @@ -69,6 +69,9 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) ( nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger + nil, // jobsHandler + nil, // jupyterHandler + nil, // datasetsHandler ) server := httptest.NewServer(handler) return server, tq, expManager, s, db @@ -595,6 +598,9 @@ func setupWSIntegrationServer(t *testing.T) ( nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger + nil, // jobsHandler + nil, // jupyterHandler + nil, // datasetsHandler ) // Setup test server server := httptest.NewServer(handler)