refactor(queue): integrate scheduler backend and storage improvements
Update queue and storage systems for scheduler integration:
- Queue backend with scheduler coordination
- Filesystem queue with batch operations
- Deduplication with tenant-aware keys
- Storage layer with audit logging hooks
- Domain models (Task, Events, Errors) with scheduler fields
- Database layer with tenant isolation
- Dataset storage with integrity checks
This commit is contained in:
parent
6b2c377680
commit
6866ba9366
14 changed files with 211 additions and 141 deletions
|
|
@ -81,16 +81,16 @@ func ClassifyFailure(exitCode int, signal os.Signal, logTail string) FailureClas
|
|||
|
||||
// FailureInfo contains complete failure context for the manifest
|
||||
type FailureInfo struct {
|
||||
Context map[string]string `json:"context,omitempty"`
|
||||
Class FailureClass `json:"class"`
|
||||
ExitCode int `json:"exit_code,omitempty"`
|
||||
Signal string `json:"signal,omitempty"`
|
||||
LogTail string `json:"log_tail,omitempty"`
|
||||
Suggestion string `json:"suggestion,omitempty"`
|
||||
AutoRetried bool `json:"auto_retried,omitempty"`
|
||||
ClassifiedAt string `json:"classified_at,omitempty"`
|
||||
ExitCode int `json:"exit_code,omitempty"`
|
||||
RetryCount int `json:"retry_count,omitempty"`
|
||||
RetryCap int `json:"retry_cap,omitempty"`
|
||||
ClassifiedAt string `json:"classified_at,omitempty"`
|
||||
Context map[string]string `json:"context,omitempty"`
|
||||
AutoRetried bool `json:"auto_retried,omitempty"`
|
||||
}
|
||||
|
||||
// GetFailureSuggestion returns user guidance based on failure class
|
||||
|
|
|
|||
|
|
@ -30,22 +30,11 @@ const (
|
|||
// TaskEvent represents an event in a task's lifecycle.
|
||||
// Events are stored in Redis Streams for append-only audit trails.
|
||||
type TaskEvent struct {
|
||||
// TaskID is the unique identifier of the task.
|
||||
TaskID string `json:"task_id"`
|
||||
|
||||
// EventType indicates what happened (queued, started, completed, etc.).
|
||||
EventType TaskEventType `json:"event_type"`
|
||||
|
||||
// Timestamp when the event occurred.
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
|
||||
// Data contains event-specific data (JSON-encoded).
|
||||
// For "started": {"worker_id": "worker-1", "image": "pytorch:latest"}
|
||||
// For "failed": {"error": "OOM", "phase": "execution"}
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
|
||||
// Who triggered this event (worker ID, user ID, or system).
|
||||
Who string `json:"who"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
TaskID string `json:"task_id"`
|
||||
EventType TaskEventType `json:"event_type"`
|
||||
Who string `json:"who"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// EventDataStarted contains data for the "started" event.
|
||||
|
|
@ -64,8 +53,8 @@ type EventDataFailed struct {
|
|||
|
||||
// EventDataGPUAssigned contains data for the "gpu_assigned" event.
|
||||
type EventDataGPUAssigned struct {
|
||||
GPUDevices []string `json:"gpu_devices"`
|
||||
GPUEnvVar string `json:"gpu_env_var,omitempty"`
|
||||
GPUDevices []string `json:"gpu_devices"`
|
||||
}
|
||||
|
||||
// NewTaskEvent creates a new task event with the current timestamp.
|
||||
|
|
|
|||
|
|
@ -8,65 +8,48 @@ import (
|
|||
|
||||
// Task represents an ML experiment task
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"` // queued, running, completed, failed
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
// SnapshotID references the experiment snapshot (code + deps) for this task.
|
||||
// Currently stores an opaque identifier. Future: verify checksum/digest before execution
|
||||
// to ensure reproducibility and detect tampering.
|
||||
SnapshotID string `json:"snapshot_id,omitempty"`
|
||||
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
|
||||
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
// Datasets is kept for backward compatibility (legacy callers).
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
// Resource requests (optional, 0 means unspecified)
|
||||
CPU int `json:"cpu,omitempty"`
|
||||
MemoryGB int `json:"memory_gb,omitempty"`
|
||||
GPU int `json:"gpu,omitempty"`
|
||||
GPUMemory string `json:"gpu_memory,omitempty"`
|
||||
|
||||
// User ownership and permissions
|
||||
UserID string `json:"user_id"` // User who owns this task
|
||||
Username string `json:"username"` // Username for display
|
||||
CreatedBy string `json:"created_by"` // User who submitted the task
|
||||
|
||||
// Lease management for task resilience
|
||||
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
|
||||
LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease
|
||||
|
||||
// Retry management
|
||||
RetryCount int `json:"retry_count"` // Number of retry attempts made
|
||||
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
|
||||
LastError string `json:"last_error,omitempty"` // Last error encountered
|
||||
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
|
||||
|
||||
// Attempt tracking - complete history of all execution attempts
|
||||
Attempts []Attempt `json:"attempts,omitempty"`
|
||||
|
||||
// Optional tracking configuration for this task
|
||||
Tracking *TrackingConfig `json:"tracking,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
Tracking *TrackingConfig `json:"tracking,omitempty"`
|
||||
NextRetry *time.Time `json:"next_retry,omitempty"`
|
||||
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
Username string `json:"username"`
|
||||
LeasedBy string `json:"leased_by,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
SnapshotID string `json:"snapshot_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Args string `json:"args"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
JobName string `json:"job_name"`
|
||||
GPUMemory string `json:"gpu_memory,omitempty"`
|
||||
UserID string `json:"user_id"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Attempts []Attempt `json:"attempts,omitempty"`
|
||||
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
MemoryGB int `json:"memory_gb,omitempty"`
|
||||
CPU int `json:"cpu,omitempty"`
|
||||
GPU int `json:"gpu,omitempty"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
Priority int64 `json:"priority"`
|
||||
}
|
||||
|
||||
// Attempt represents a single execution attempt of a task
|
||||
type Attempt struct {
|
||||
Attempt int `json:"attempt"` // Attempt number (1-indexed)
|
||||
StartedAt time.Time `json:"started_at"` // When attempt started
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"` // When attempt ended (if completed)
|
||||
WorkerID string `json:"worker_id,omitempty"` // Which worker ran this attempt
|
||||
Status string `json:"status"` // running, completed, failed
|
||||
FailureClass FailureClass `json:"failure_class,omitempty"` // Failure classification (if failed)
|
||||
ExitCode int `json:"exit_code,omitempty"` // Process exit code
|
||||
Signal string `json:"signal,omitempty"` // Termination signal (if any)
|
||||
Error string `json:"error,omitempty"` // Error message (if failed)
|
||||
LogTail string `json:"log_tail,omitempty"` // Last N lines of log output
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
FailureClass FailureClass `json:"failure_class,omitempty"`
|
||||
Signal string `json:"signal,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LogTail string `json:"log_tail,omitempty"`
|
||||
Attempt int `json:"attempt"`
|
||||
ExitCode int `json:"exit_code,omitempty"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,22 +9,22 @@ type TrackingConfig struct {
|
|||
|
||||
// MLflowTrackingConfig controls MLflow integration.
|
||||
type MLflowTrackingConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
TrackingURI string `json:"tracking_uri,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "sidecar" | "remote" | "disabled"
|
||||
TrackingURI string `json:"tracking_uri,omitempty"` // Explicit tracking URI for remote mode
|
||||
}
|
||||
|
||||
// TensorBoardTrackingConfig controls TensorBoard integration.
|
||||
type TensorBoardTrackingConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "sidecar" | "disabled"
|
||||
}
|
||||
|
||||
// WandbTrackingConfig controls Weights & Biases integration.
|
||||
type WandbTrackingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"` // "remote" | "disabled"
|
||||
Mode string `json:"mode,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Entity string `json:"entity,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||
)
|
||||
|
||||
var ErrInvalidQueueBackend = errors.New("invalid queue backend")
|
||||
|
|
@ -48,24 +50,38 @@ type Backend interface {
|
|||
type QueueBackend string
|
||||
|
||||
const (
|
||||
QueueBackendRedis QueueBackend = "redis"
|
||||
QueueBackendSQLite QueueBackend = "sqlite"
|
||||
QueueBackendFS QueueBackend = "filesystem"
|
||||
QueueBackendNative QueueBackend = "native" // Native C++ queue_index (requires -tags native_libs)
|
||||
QueueBackendRedis QueueBackend = "redis"
|
||||
QueueBackendSQLite QueueBackend = "sqlite"
|
||||
QueueBackendFS QueueBackend = "filesystem"
|
||||
QueueBackendNative QueueBackend = "native" // Native C++ queue_index (requires -tags native_libs)
|
||||
QueueBackendScheduler QueueBackend = "scheduler" // Distributed mode via WebSocket
|
||||
)
|
||||
|
||||
// SchedulerConfig carries the settings used to reach the scheduler when the
// queue runs in distributed mode.
type SchedulerConfig struct {
	// Address is the scheduler address (e.g., "192.168.1.10:7777").
	Address string
	// Cert is the path to the scheduler's TLS certificate.
	Cert string
	// Token is the worker authentication token.
	Token string
}
|
||||
|
||||
type BackendConfig struct {
|
||||
Mode string // "standalone" | "distributed"
|
||||
Backend QueueBackend
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
SQLitePath string
|
||||
FilesystemPath string
|
||||
FallbackToFilesystem bool
|
||||
RedisDB int
|
||||
MetricsFlushInterval time.Duration
|
||||
FallbackToFilesystem bool
|
||||
Scheduler SchedulerConfig // Config for distributed mode
|
||||
}
|
||||
|
||||
func NewBackend(cfg BackendConfig) (Backend, error) {
|
||||
// Distributed mode: use SchedulerBackend
|
||||
if cfg.Mode == "distributed" {
|
||||
return NewSchedulerBackendFromConfig(cfg.Scheduler)
|
||||
}
|
||||
|
||||
mkFallback := func(err error) (Backend, error) {
|
||||
if !cfg.FallbackToFilesystem {
|
||||
return nil, err
|
||||
|
|
@ -112,3 +128,15 @@ func NewBackend(cfg BackendConfig) (Backend, error) {
|
|||
return nil, ErrInvalidQueueBackend
|
||||
}
|
||||
}
|
||||
|
||||
// NewSchedulerBackendFromConfig creates a SchedulerBackend from config
|
||||
func NewSchedulerBackendFromConfig(cfg SchedulerConfig) (Backend, error) {
|
||||
if cfg.Address == "" {
|
||||
return nil, fmt.Errorf("scheduler address is required for distributed mode")
|
||||
}
|
||||
conn := scheduler.NewSchedulerConn(cfg.Address, cfg.Cert, cfg.Token, "", scheduler.WorkerCapabilities{})
|
||||
if err := conn.Connect(); err != nil {
|
||||
return nil, fmt.Errorf("connect to scheduler: %w", err)
|
||||
}
|
||||
return NewSchedulerBackend(conn), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ var ErrAlreadyQueued = fmt.Errorf("job already queued with this commit")
|
|||
|
||||
// CommitDedup tracks recently queued commits to prevent duplicate submissions
|
||||
type CommitDedup struct {
|
||||
mu sync.RWMutex
|
||||
commits map[string]time.Time // key: "job_name:commit_id" -> queued_at
|
||||
commits map[string]time.Time
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCommitDedup creates a new commit deduplication tracker
|
||||
|
|
|
|||
|
|
@ -8,28 +8,77 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
)
|
||||
|
||||
// validTaskID is the allowlist pattern for task IDs: between 1 and 128
// characters, each an ASCII letter, digit, underscore, or hyphen. Because task
// IDs are used to build file names, anything outside this set (slashes,
// backslashes, dots, null bytes, ...) is rejected.
var validTaskID = regexp.MustCompile(`^[a-zA-Z0-9_\-]{1,128}$`)

// validateTaskID returns nil when id satisfies the validTaskID allowlist.
//
// An empty id produces the dedicated "task ID is required" error; any other
// non-matching id produces an error that quotes the offending id and the
// pattern it must match.
func validateTaskID(id string) error {
	switch {
	case id == "":
		return errors.New("task ID is required")
	case validTaskID.MatchString(id):
		return nil
	default:
		return fmt.Errorf("invalid task ID %q: must match %s", id, validTaskID.String())
	}
}
|
||||
|
||||
// writeTaskFile writes task data with O_NOFOLLOW, fsync, and cleanup on error.
|
||||
// This prevents symlink attacks and ensures crash safety.
|
||||
func writeTaskFile(path string, data []byte) error {
|
||||
// Use O_NOFOLLOW to prevent following symlinks (TOCTOU protection)
|
||||
f, err := fileutil.OpenFileNoFollow(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open (symlink rejected): %w", err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(data); err != nil {
|
||||
f.Close()
|
||||
os.Remove(path) // remove partial write
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
// CRITICAL: fsync ensures data is flushed to disk before returning
|
||||
if err := f.Sync(); err != nil {
|
||||
f.Close()
|
||||
os.Remove(path) // remove unsynced file
|
||||
return fmt.Errorf("fsync: %w", err)
|
||||
}
|
||||
|
||||
// Close can fail on some filesystems (NFS, network-backed volumes)
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(path) // remove file if close failed
|
||||
return fmt.Errorf("close: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Queue implements a filesystem-based task queue
|
||||
type Queue struct {
|
||||
root string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
root string
|
||||
}
|
||||
|
||||
type queueIndex struct {
|
||||
Version int `json:"version"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Tasks []queueIndexTask `json:"tasks"`
|
||||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
type queueIndexTask struct {
|
||||
ID string `json:"id"`
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Priority int64 `json:"priority"`
|
||||
}
|
||||
|
||||
// NewQueue creates a new filesystem queue instance
|
||||
|
|
@ -65,29 +114,42 @@ func (q *Queue) AddTask(task *domain.Task) error {
|
|||
if task == nil {
|
||||
return errors.New("task is nil")
|
||||
}
|
||||
if task.ID == "" {
|
||||
return errors.New("task ID is required")
|
||||
if err := validateTaskID(task.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pendingDir := filepath.Join(q.root, "pending", "entries")
|
||||
taskFile := filepath.Join(pendingDir, task.ID+".json")
|
||||
|
||||
// SECURITY: Verify resolved path is still inside pendingDir (symlink/traversal check)
|
||||
resolvedDir, err := filepath.EvalSymlinks(pendingDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve pending dir: %w", err)
|
||||
}
|
||||
resolvedFile, err := filepath.EvalSymlinks(filepath.Dir(taskFile))
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve task dir: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(resolvedFile+string(filepath.Separator), resolvedDir+string(filepath.Separator)) {
|
||||
return fmt.Errorf("task path %q escapes queue root", taskFile)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal task: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(taskFile, data, 0640); err != nil {
|
||||
// SECURITY: Write with fsync + cleanup on error + O_NOFOLLOW to prevent symlink attacks
|
||||
if err := writeTaskFile(taskFile, data); err != nil {
|
||||
return fmt.Errorf("failed to write task file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTask retrieves a task by ID
|
||||
func (q *Queue) GetTask(id string) (*domain.Task, error) {
|
||||
if id == "" {
|
||||
return nil, errors.New("task ID is required")
|
||||
if err := validateTaskID(id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Search in all directories
|
||||
|
|
@ -140,6 +202,9 @@ func (q *Queue) ListTasks() ([]*domain.Task, error) {
|
|||
|
||||
// CancelTask cancels a task
|
||||
func (q *Queue) CancelTask(id string) error {
|
||||
if err := validateTaskID(id); err != nil {
|
||||
return err
|
||||
}
|
||||
// Remove from pending if exists
|
||||
pendingFile := filepath.Join(q.root, "pending", "entries", id+".json")
|
||||
if _, err := os.Stat(pendingFile); err == nil {
|
||||
|
|
@ -150,8 +215,11 @@ func (q *Queue) CancelTask(id string) error {
|
|||
|
||||
// UpdateTask updates a task
|
||||
func (q *Queue) UpdateTask(task *domain.Task) error {
|
||||
if task == nil || task.ID == "" {
|
||||
return errors.New("task is nil or missing ID")
|
||||
if task == nil {
|
||||
return errors.New("task is nil")
|
||||
}
|
||||
if err := validateTaskID(task.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find current location
|
||||
|
|
@ -173,7 +241,11 @@ func (q *Queue) UpdateTask(task *domain.Task) error {
|
|||
return fmt.Errorf("failed to marshal task: %w", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(currentFile, data, 0640)
|
||||
// SECURITY: Write with O_NOFOLLOW + fsync + cleanup on error
|
||||
if err := writeTaskFile(currentFile, data); err != nil {
|
||||
return fmt.Errorf("failed to write task file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rebuildIndex rebuilds the queue index
|
||||
|
|
|
|||
|
|
@ -13,24 +13,25 @@ import (
|
|||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
)
|
||||
|
||||
type FilesystemQueue struct {
|
||||
root string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
root string
|
||||
}
|
||||
|
||||
type filesystemQueueIndex struct {
|
||||
Version int `json:"version"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Tasks []filesystemQueueIndexTask `json:"tasks"`
|
||||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
type filesystemQueueIndexTask struct {
|
||||
ID string `json:"id"`
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Priority int64 `json:"priority"`
|
||||
}
|
||||
|
||||
func NewFilesystemQueue(root string) (*FilesystemQueue, error) {
|
||||
|
|
@ -572,9 +573,6 @@ func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
|
|||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, perm); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmp, path)
|
||||
// SECURITY: Use WriteFileSafe for atomic write with fsync
|
||||
return fileutil.WriteFileSafe(path, data, perm)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ const (
|
|||
|
||||
// TaskQueue manages ML experiment tasks via Redis
|
||||
type TaskQueue struct {
|
||||
client *redis.Client
|
||||
ctx context.Context
|
||||
client *redis.Client
|
||||
cancel context.CancelFunc
|
||||
metricsCh chan metricEvent
|
||||
metricsDone chan struct{}
|
||||
dedup *CommitDedup
|
||||
flushEvery time.Duration
|
||||
dedup *CommitDedup // Tracks recently queued commits
|
||||
}
|
||||
|
||||
type metricEvent struct {
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ import (
|
|||
|
||||
// DatasetInfo contains information about a dataset.
|
||||
type DatasetInfo struct {
|
||||
Name string `json:"name"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
Location string `json:"location"` // "nas" or "ml"
|
||||
LastAccess time.Time `json:"last_access"`
|
||||
Name string `json:"name"`
|
||||
Location string `json:"location"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
}
|
||||
|
||||
// DatasetStore manages dataset metadata and transfer tracking.
|
||||
|
|
|
|||
|
|
@ -35,10 +35,10 @@ type DBConfig struct {
|
|||
Type string
|
||||
Connection string
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Port int
|
||||
}
|
||||
|
||||
// DB wraps a database connection with type information.
|
||||
|
|
|
|||
|
|
@ -11,17 +11,18 @@ import (
|
|||
)
|
||||
|
||||
type Experiment struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Status string `json:"status"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
WorkspaceID string `json:"workspace_id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ExperimentEnvironment struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
PythonVersion string `json:"python_version"`
|
||||
CUDAVersion string `json:"cuda_version,omitempty"`
|
||||
SystemOS string `json:"system_os"`
|
||||
|
|
@ -30,16 +31,15 @@ type ExperimentEnvironment struct {
|
|||
RequirementsHash string `json:"requirements_hash"`
|
||||
CondaEnvHash string `json:"conda_env_hash,omitempty"`
|
||||
Dependencies json.RawMessage `json:"dependencies,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type ExperimentGitInfo struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CommitSHA string `json:"commit_sha"`
|
||||
Branch string `json:"branch"`
|
||||
RemoteURL string `json:"remote_url"`
|
||||
IsDirty bool `json:"is_dirty"`
|
||||
DiffPatch string `json:"diff_patch,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
IsDirty bool `json:"is_dirty"`
|
||||
}
|
||||
|
||||
type ExperimentSeeds struct {
|
||||
|
|
@ -51,17 +51,17 @@ type ExperimentSeeds struct {
|
|||
}
|
||||
|
||||
type Dataset struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type ExperimentWithMetadata struct {
|
||||
Experiment Experiment `json:"experiment"`
|
||||
Environment *ExperimentEnvironment `json:"environment,omitempty"`
|
||||
GitInfo *ExperimentGitInfo `json:"git_info,omitempty"`
|
||||
Seeds *ExperimentSeeds `json:"seeds,omitempty"`
|
||||
Experiment Experiment `json:"experiment"`
|
||||
}
|
||||
|
||||
func (db *DB) UpsertExperiment(ctx context.Context, exp *Experiment) error {
|
||||
|
|
|
|||
|
|
@ -10,30 +10,30 @@ import (
|
|||
|
||||
// Job represents a machine learning job in the system.
|
||||
type Job struct {
|
||||
ID string `json:"id"`
|
||||
JobName string `json:"job_name"`
|
||||
Args string `json:"args"`
|
||||
Status string `json:"status"`
|
||||
Priority int64 `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
ID string `json:"id"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Args string `json:"args"`
|
||||
JobName string `json:"job_name"`
|
||||
Datasets []string `json:"datasets,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Priority int64 `json:"priority"`
|
||||
}
|
||||
|
||||
// Worker represents a worker node in the system.
|
||||
type Worker struct {
|
||||
LastHeartbeat time.Time `json:"last_heartbeat"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Hostname string `json:"hostname"`
|
||||
LastHeartbeat time.Time `json:"last_heartbeat"`
|
||||
Status string `json:"status"`
|
||||
CurrentJobs int `json:"current_jobs"`
|
||||
MaxJobs int `json:"max_jobs"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// CreateJob inserts a new job into the database.
|
||||
|
|
|
|||
|
|
@ -8,23 +8,23 @@ import (
|
|||
|
||||
// Metric represents a recorded metric from WebSocket connections
|
||||
type Metric struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value float64 `json:"value"`
|
||||
User string `json:"user,omitempty"`
|
||||
RecordedAt time.Time `json:"recorded_at"`
|
||||
Name string `json:"name"`
|
||||
User string `json:"user,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Value float64 `json:"value"`
|
||||
}
|
||||
|
||||
// MetricSummary represents aggregated metric statistics
|
||||
type MetricSummary struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Name string `json:"name"`
|
||||
Count int64 `json:"count"`
|
||||
Avg float64 `json:"avg"`
|
||||
Min float64 `json:"min"`
|
||||
Max float64 `json:"max"`
|
||||
Sum float64 `json:"sum"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
}
|
||||
|
||||
// RecordMetric records a metric to the database
|
||||
|
|
|
|||
Loading…
Reference in a new issue