feat: implement Go backend with comprehensive API and internal packages

- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
Jeremie Fraeys 2025-12-04 16:53:53 -05:00
parent c5049a2fdf
commit 803677be57
62 changed files with 13354 additions and 0 deletions

cmd/api-server/README.md (new file)

@ -0,0 +1,32 @@
# API Server
WebSocket API server for the ML CLI tool...
## Usage
```bash
./bin/api-server --config configs/config-dev.yaml --listen :9100
```
## Endpoints
- `GET /health` - Health check
- `WS /ws` - WebSocket endpoint for CLI communication
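The server in this commit also registers a `GET /db-status` endpoint (see `cmd/api-server/main.go`). A quick smoke test, assuming the server listens on `:9100` as in the usage example above:
```bash
# Health check; expects "OK"
curl http://localhost:9100/health

# Database status as JSON (returns 503 if no database is configured)
curl http://localhost:9100/db-status
```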
## Binary Protocol
See [CLI README](../../cli/README.md#websocket-protocol) for protocol details.
## Configuration
Uses the same configuration file as the worker. The experiment base path is read from the `base_path` configuration key.
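A minimal sketch of the shared configuration, using only keys visible in the `Config` struct of `cmd/api-server/main.go`; the values below are placeholders, not a recommended setup:
```yaml
base_path: /tmp/ml-experiments

server:
  address: ":9100"
  tls:
    enabled: false

redis:
  addr: "localhost:6379"

database:
  type: "sqlite"
  connection: "fetchml.db"
```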
## Example
```bash
# Start API server
./bin/api-server --listen :9100
# In another terminal, test with CLI
./cli/zig-out/bin/ml status
```

cmd/api-server/main.go (new file)

@ -0,0 +1,363 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/middleware"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"gopkg.in/yaml.v3"
)
// Config structure matching worker config
type Config struct {
BasePath string `yaml:"base_path"`
Auth auth.AuthConfig `yaml:"auth"`
Server ServerConfig `yaml:"server"`
Security SecurityConfig `yaml:"security"`
Redis RedisConfig `yaml:"redis"`
Database DatabaseConfig `yaml:"database"`
Logging logging.Config `yaml:"logging"`
}
type RedisConfig struct {
Addr string `yaml:"addr"`
Password string `yaml:"password"`
DB int `yaml:"db"`
URL string `yaml:"url"`
}
type DatabaseConfig struct {
Type string `yaml:"type"`
Connection string `yaml:"connection"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Username string `yaml:"username"`
Password string `yaml:"password"`
Database string `yaml:"database"`
}
type SecurityConfig struct {
RateLimit RateLimitConfig `yaml:"rate_limit"`
IPWhitelist []string `yaml:"ip_whitelist"`
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
}
type RateLimitConfig struct {
Enabled bool `yaml:"enabled"`
RequestsPerMinute int `yaml:"requests_per_minute"`
BurstSize int `yaml:"burst_size"`
}
type LockoutConfig struct {
Enabled bool `yaml:"enabled"`
MaxAttempts int `yaml:"max_attempts"`
LockoutDuration string `yaml:"lockout_duration"`
}
type ServerConfig struct {
Address string `yaml:"address"`
TLS TLSConfig `yaml:"tls"`
}
type TLSConfig struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
}
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
return &cfg, nil
}
func main() {
// Parse flags
configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
apiKey := flag.String("api-key", "", "API key for authentication")
flag.Parse()
// Load config
resolvedConfig, err := config.ResolveConfigPath(*configFile)
if err != nil {
log.Fatalf("Failed to resolve config: %v", err)
}
cfg, err := LoadConfig(resolvedConfig)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Ensure log directory exists
if cfg.Logging.File != "" {
logDir := filepath.Dir(cfg.Logging.File)
log.Printf("Creating log directory: %s", logDir)
if err := os.MkdirAll(logDir, 0755); err != nil {
log.Fatalf("Failed to create log directory: %v", err)
}
}
// Setup logging
logger := logging.NewLoggerFromConfig(cfg.Logging)
ctx := logging.EnsureTrace(context.Background())
logger = logger.Component(ctx, "api-server")
// Setup experiment manager
basePath := cfg.BasePath
if basePath == "" {
basePath = "/tmp/ml-experiments"
}
expManager := experiment.NewManager(basePath)
log.Printf("Initializing experiment manager with base_path: %s", basePath)
if err := expManager.Initialize(); err != nil {
logger.Fatal("failed to initialize experiment manager", "error", err)
}
logger.Info("experiment manager initialized", "base_path", basePath)
// Setup auth
var authCfg *auth.AuthConfig
if cfg.Auth.Enabled {
authCfg = &cfg.Auth
logger.Info("authentication enabled")
}
// Setup HTTP server with security middleware
mux := http.NewServeMux()
// Convert API keys from map to slice for security middleware
apiKeys := make([]string, 0, len(cfg.Auth.APIKeys))
for username := range cfg.Auth.APIKeys {
// For now, use username as the key (in production, this should be the actual API key)
apiKeys = append(apiKeys, string(username))
}
// Create security middleware
sec := middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"))
// Setup TaskQueue
queueCfg := queue.Config{
RedisAddr: cfg.Redis.Addr,
RedisPassword: cfg.Redis.Password,
RedisDB: cfg.Redis.DB,
}
if queueCfg.RedisAddr == "" {
queueCfg.RedisAddr = config.DefaultRedisAddr
}
// Support URL format for Redis
if cfg.Redis.URL != "" {
queueCfg.RedisAddr = cfg.Redis.URL
}
taskQueue, err := queue.NewTaskQueue(queueCfg)
if err != nil {
logger.Error("failed to initialize task queue", "error", err)
// We continue without queue, but queue operations will fail
} else {
logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
defer func() {
logger.Info("stopping task queue...")
if err := taskQueue.Close(); err != nil {
logger.Error("failed to stop task queue", "error", err)
} else {
logger.Info("task queue stopped")
}
}()
}
// Setup database if configured
var db *storage.DB
if cfg.Database.Type != "" {
dbConfig := storage.DBConfig{
Type: cfg.Database.Type,
Connection: cfg.Database.Connection,
Host: cfg.Database.Host,
Port: cfg.Database.Port,
Username: cfg.Database.Username,
Password: cfg.Database.Password,
Database: cfg.Database.Database,
}
db, err = storage.NewDB(dbConfig)
if err != nil {
logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err)
} else {
// Load appropriate database schema
var schemaPath string
if cfg.Database.Type == "sqlite" {
schemaPath = "internal/storage/schema.sql"
} else if cfg.Database.Type == "postgres" || cfg.Database.Type == "postgresql" {
schemaPath = "internal/storage/schema_postgres.sql"
} else {
logger.Error("unsupported database type", "type", cfg.Database.Type)
db.Close()
db = nil
}
if db != nil && schemaPath != "" {
schema, err := os.ReadFile(schemaPath)
if err != nil {
logger.Error("failed to read database schema file", "path", schemaPath, "error", err)
db.Close()
db = nil
} else {
if err := db.Initialize(string(schema)); err != nil {
logger.Error("failed to initialize database schema", "error", err)
db.Close()
db = nil
} else {
logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection)
defer func() {
logger.Info("closing database connection...")
if err := db.Close(); err != nil {
logger.Error("failed to close database", "error", err)
} else {
logger.Info("database connection closed")
}
}()
}
}
}
}
}
// Setup WebSocket handler with authentication
wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)
// WebSocket endpoint - no middleware to avoid hijacking issues
mux.Handle("/ws", wsHandler)
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "OK\n")
})
// Database status endpoint
mux.HandleFunc("/db-status", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if db != nil {
// Test database connection with a simple query
var result struct {
Status string `json:"status"`
Type string `json:"type"`
Path string `json:"path"`
Message string `json:"message"`
}
result.Status = "connected"
result.Type = "sqlite"
result.Path = cfg.Database.Connection
result.Message = "SQLite database is operational"
// Test a simple query to verify connectivity
if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
result.Status = "error"
result.Message = fmt.Sprintf("Database query failed: %v", err)
}
jsonBytes, _ := json.Marshal(result)
w.Write(jsonBytes)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
}
})
// Apply security middleware to all routes except WebSocket
// Create separate handlers for WebSocket vs other routes
var finalHandler http.Handler = mux
// Wrap non-websocket routes with security middleware
finalHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ws" {
mux.ServeHTTP(w, r)
} else {
// Apply middleware chain for non-WebSocket routes
handler := sec.RateLimit(mux)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(handler)
handler = middleware.RequestTimeout(30 * time.Second)(handler)
// Apply audit logger and IP whitelist only to non-WebSocket routes
handler = middleware.AuditLogger(handler)
if len(cfg.Security.IPWhitelist) > 0 {
handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler)
}
handler.ServeHTTP(w, r)
}
})
var handler http.Handler = finalHandler
server := &http.Server{
Addr: cfg.Server.Address,
Handler: handler,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
if !cfg.Server.TLS.Enabled {
logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
}
// Start server in goroutine
go func() {
// Serve with TLS if configured; exit only on an unexpected server error.
var srvErr error
if cfg.Server.TLS.Enabled {
logger.Info("starting HTTPS server", "address", cfg.Server.Address)
srvErr = server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
} else {
logger.Info("starting HTTP server", "address", cfg.Server.Address)
srvErr = server.ListenAndServe()
}
// http.ErrServerClosed is expected after a graceful Shutdown; anything else is fatal.
if srvErr != nil && srvErr != http.ErrServerClosed {
logger.Error("server failed", "error", srvErr)
os.Exit(1)
}
}()
// Setup graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigChan
logger.Info("received shutdown signal", "signal", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
logger.Info("shutting down http server...")
if err := server.Shutdown(ctx); err != nil {
logger.Error("server shutdown error", "error", err)
} else {
logger.Info("http server shutdown complete")
}
logger.Info("api server stopped")
_ = apiKey // Parsed but not yet wired into auth; kept for future use
}

cmd/configlint/main.go (new file)

@ -0,0 +1,116 @@
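// Command configlint validates YAML configuration files against a JSON Schema
// (itself written in YAML; see --schema). It prints one line per file and
// exits non-zero if any file fails validation.
//
// Example invocation (paths are illustrative):
//
//	configlint --schema configs/schema.yaml configs/config-dev.yaml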
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/xeipuuv/gojsonschema"
"gopkg.in/yaml.v3"
)
func main() {
var (
schemaPath string
failFast bool
)
flag.StringVar(&schemaPath, "schema", "configs/schema.yaml", "Path to JSON schema in YAML format")
flag.BoolVar(&failFast, "fail-fast", false, "Stop on first error")
flag.Parse()
if flag.NArg() == 0 {
log.Fatalf("usage: configlint [--schema path] [--fail-fast] <config files...>")
}
schemaLoader, err := loadSchema(schemaPath)
if err != nil {
log.Fatalf("failed to load schema: %v", err)
}
var hadError bool
for _, configPath := range flag.Args() {
if err := validateConfig(schemaLoader, configPath); err != nil {
hadError = true
fmt.Fprintf(os.Stderr, "configlint: %s: %v\n", configPath, err)
if failFast {
os.Exit(1)
}
}
}
if hadError {
os.Exit(1)
}
fmt.Println("All configuration files are valid.")
}
func loadSchema(schemaPath string) (gojsonschema.JSONLoader, error) {
data, err := os.ReadFile(schemaPath)
if err != nil {
return nil, err
}
var schemaYAML interface{}
if err := yaml.Unmarshal(data, &schemaYAML); err != nil {
return nil, err
}
schemaJSON, err := json.Marshal(schemaYAML)
if err != nil {
return nil, err
}
tmpFile, err := os.CreateTemp("", "fetchml-schema-*.json")
if err != nil {
return nil, err
}
defer tmpFile.Close()
if _, err := tmpFile.Write(schemaJSON); err != nil {
return nil, err
}
return gojsonschema.NewReferenceLoader("file://" + filepath.ToSlash(tmpFile.Name())), nil
}
func validateConfig(schemaLoader gojsonschema.JSONLoader, configPath string) error {
data, err := os.ReadFile(configPath)
if err != nil {
return err
}
var configYAML interface{}
if err := yaml.Unmarshal(data, &configYAML); err != nil {
return fmt.Errorf("failed to parse YAML: %w", err)
}
configJSON, err := json.Marshal(configYAML)
if err != nil {
return err
}
result, err := gojsonschema.Validate(schemaLoader, gojsonschema.NewBytesLoader(configJSON))
if err != nil {
return err
}
if result.Valid() {
fmt.Printf("%s: valid\n", configPath)
return nil
}
var builder strings.Builder
for _, issue := range result.Errors() {
builder.WriteString("- ")
builder.WriteString(issue.String())
builder.WriteByte('\n')
}
return fmt.Errorf("validation failed:\n%s", builder.String())
}


@ -0,0 +1,132 @@
package main
import (
"fmt"
"os"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"gopkg.in/yaml.v3"
)
// DataConfig holds the configuration for the data manager.
type DataConfig struct {
// ML Server (where training runs)
MLHost string `yaml:"ml_host"`
MLUser string `yaml:"ml_user"`
MLSSHKey string `yaml:"ml_ssh_key"`
MLPort int `yaml:"ml_port"`
MLDataDir string `yaml:"ml_data_dir"` // e.g., /data/active
// NAS (where datasets are stored)
NASHost string `yaml:"nas_host"`
NASUser string `yaml:"nas_user"`
NASSSHKey string `yaml:"nas_ssh_key"`
NASPort int `yaml:"nas_port"`
NASDataDir string `yaml:"nas_data_dir"` // e.g., /mnt/datasets
// Redis
RedisAddr string `yaml:"redis_addr"`
RedisPassword string `yaml:"redis_password"`
RedisDB int `yaml:"redis_db"`
// Authentication
Auth auth.AuthConfig `yaml:"auth"`
// Cleanup settings
MaxAgeHours int `yaml:"max_age_hours"` // Delete data older than X hours
MaxSizeGB int `yaml:"max_size_gb"` // Keep total size under X GB
CleanupInterval int `yaml:"cleanup_interval_min"` // Run cleanup every X minutes
// Podman integration
PodmanImage string `yaml:"podman_image"`
ContainerWorkspace string `yaml:"container_workspace"`
ContainerResults string `yaml:"container_results"`
GPUAccess bool `yaml:"gpu_access"`
}
func LoadDataConfig(path string) (*DataConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg DataConfig
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
// Defaults
if cfg.MLPort == 0 {
cfg.MLPort = config.DefaultSSHPort
}
if cfg.NASPort == 0 {
cfg.NASPort = config.DefaultSSHPort
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = config.DefaultRedisAddr
}
// Set default MLDataDir - use ./data/active for local/dev, /data/active for production
if cfg.MLDataDir == "" {
if cfg.MLHost == "" {
// Local mode - use local data directory
cfg.MLDataDir = config.DefaultLocalDataDir
} else {
// Production mode - use /data/active
cfg.MLDataDir = config.DefaultDataDir
}
}
if cfg.NASDataDir == "" {
cfg.NASDataDir = config.DefaultNASDataDir
}
// Expand paths
cfg.MLDataDir = config.ExpandPath(cfg.MLDataDir)
cfg.NASDataDir = config.ExpandPath(cfg.NASDataDir)
if cfg.MaxAgeHours == 0 {
cfg.MaxAgeHours = config.DefaultMaxAgeHours
}
if cfg.MaxSizeGB == 0 {
cfg.MaxSizeGB = config.DefaultMaxSizeGB
}
if cfg.CleanupInterval == 0 {
cfg.CleanupInterval = config.DefaultCleanupInterval
}
return &cfg, nil
}
// Validate implements utils.Validator interface
func (c *DataConfig) Validate() error {
if c.MLPort != 0 {
if err := config.ValidatePort(c.MLPort); err != nil {
return fmt.Errorf("invalid ML SSH port: %w", err)
}
}
if c.NASPort != 0 {
if err := config.ValidatePort(c.NASPort); err != nil {
return fmt.Errorf("invalid NAS SSH port: %w", err)
}
}
if c.RedisAddr != "" {
if err := config.ValidateRedisAddr(c.RedisAddr); err != nil {
return fmt.Errorf("invalid Redis configuration: %w", err)
}
}
if c.MaxAgeHours < 1 {
return fmt.Errorf("max_age_hours must be at least 1, got %d", c.MaxAgeHours)
}
if c.MaxSizeGB < 1 {
return fmt.Errorf("max_size_gb must be at least 1, got %d", c.MaxSizeGB)
}
if c.CleanupInterval < 1 {
return fmt.Errorf("cleanup_interval must be at least 1, got %d", c.CleanupInterval)
}
return nil
}


@ -0,0 +1,775 @@
// data_manager.go - Fetch data from NAS to ML server on-demand
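//
// Subcommands (mirroring the usage text printed by main):
//
//	data_manager [--config <path>] [--api-key <key>] fetch <job-name> <dataset> [dataset...]
//	data_manager [--config <path>] [--api-key <key>] list
//	data_manager [--config <path>] [--api-key <key>] cleanup
//	data_manager [--config <path>] [--api-key <key>] validate <dataset>
//	data_manager [--config <path>] [--api-key <key>] daemon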
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"log/slog"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/errors"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/telemetry"
)
// SSHClient alias for convenience
type SSHClient = network.SSHClient
type DataManager struct {
config *DataConfig
mlServer *SSHClient
nasServer *SSHClient
taskQueue *queue.TaskQueue
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
}
type DataFetchRequest struct {
JobName string `json:"job_name"`
Datasets []string `json:"datasets"` // Dataset names to fetch
Priority int `json:"priority"`
RequestedAt time.Time `json:"requested_at"`
}
type DatasetInfo struct {
Name string `json:"name"`
SizeBytes int64 `json:"size_bytes"`
Location string `json:"location"` // "nas" or "ml"
LastAccess time.Time `json:"last_access"`
}
func NewDataManager(cfg *DataConfig, apiKey string) (*DataManager, error) {
mlServer, err := network.NewSSHClient(cfg.MLHost, cfg.MLUser, cfg.MLSSHKey, cfg.MLPort, "")
if err != nil {
return nil, fmt.Errorf("ML server connection failed: %w", err)
}
defer func() {
if err != nil {
if closeErr := mlServer.Close(); closeErr != nil {
log.Printf("Warning: failed to close ML server connection: %v", closeErr)
}
}
}()
nasServer, err := network.NewSSHClient(cfg.NASHost, cfg.NASUser, cfg.NASSSHKey, cfg.NASPort, "")
if err != nil {
return nil, fmt.Errorf("NAS connection failed: %w", err)
}
defer func() {
if err != nil {
if closeErr := nasServer.Close(); closeErr != nil {
log.Printf("Warning: failed to close NAS server connection: %v", closeErr)
}
}
}()
// Create MLDataDir if it doesn't exist (for production without NAS)
if cfg.MLDataDir != "" {
if _, err := mlServer.Exec(fmt.Sprintf("mkdir -p %s", cfg.MLDataDir)); err != nil {
logger := logging.NewLogger(slog.LevelInfo, false)
logger.Job(context.Background(), "data_manager", "").Error("Failed to create ML data directory", "dir", cfg.MLDataDir, "error", err)
}
}
// Setup Redis using internal queue
ctx, cancel := context.WithCancel(context.Background())
logger := logging.NewLogger(slog.LevelInfo, false)
var taskQueue *queue.TaskQueue
if cfg.RedisAddr != "" {
queueCfg := queue.Config{
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
}
var err error
taskQueue, err = queue.NewTaskQueue(queueCfg)
if err != nil {
// FIXED: Check error return values for cleanup
if closeErr := mlServer.Close(); closeErr != nil {
logger.Warn("failed to close ML server during error cleanup", "error", closeErr)
}
if closeErr := nasServer.Close(); closeErr != nil {
logger.Warn("failed to close NAS server during error cleanup", "error", closeErr)
}
cancel() // Cancel context to prevent leak
return nil, fmt.Errorf("redis connection failed: %w", err)
}
} else {
taskQueue = nil // Local mode - no Redis
}
return &DataManager{
config: cfg,
mlServer: mlServer,
nasServer: nasServer,
taskQueue: taskQueue,
ctx: ctx,
cancel: cancel,
logger: logger,
}, nil
}
func (dm *DataManager) FetchDataset(jobName, datasetName string) error {
ctx, cancel := context.WithTimeout(dm.ctx, 30*time.Minute)
defer cancel()
return network.RetryForNetworkOperations(ctx, func() error {
return dm.fetchDatasetInternal(ctx, jobName, datasetName)
})
}
func (dm *DataManager) fetchDatasetInternal(ctx context.Context, jobName, datasetName string) error {
if err := container.ValidateJobName(datasetName); err != nil {
return &errors.DataFetchError{
Dataset: datasetName,
JobName: jobName,
Err: fmt.Errorf("invalid dataset name: %w", err),
}
}
logger := dm.logger.Job(ctx, jobName, "")
logger.Info("fetching dataset", "dataset", datasetName)
// Validate dataset size and run cleanup if needed
if err := dm.ValidateDatasetWithCleanup(datasetName); err != nil {
return &errors.DataFetchError{
Dataset: datasetName,
JobName: jobName,
Err: fmt.Errorf("dataset size validation failed: %w", err),
}
}
nasPath := filepath.Join(dm.config.NASDataDir, datasetName)
mlPath := filepath.Join(dm.config.MLDataDir, datasetName)
// Check if dataset exists on NAS
if !dm.nasServer.FileExists(nasPath) {
return &errors.DataFetchError{
Dataset: datasetName,
JobName: jobName,
Err: fmt.Errorf("dataset not found on NAS"),
}
}
// Check if already on ML server
if dm.mlServer.FileExists(mlPath) {
logger.Info("dataset already on ML server", "dataset", datasetName)
dm.updateLastAccess(datasetName)
return nil
}
// Get size for progress tracking
size, err := dm.nasServer.GetFileSize(nasPath)
if err != nil {
logger.Warn("could not get dataset size", "dataset", datasetName, "error", err)
size = 0
}
sizeGB := float64(size) / (1024 * 1024 * 1024)
logger.Info("transferring dataset",
"dataset", datasetName,
"size_gb", sizeGB,
"nas_path", nasPath,
"ml_path", mlPath)
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
"status", "transferring",
"job_name", jobName,
"size_bytes", size,
"started_at", time.Now().Unix()).Err(); err != nil {
logger.Warn("failed to record transfer start in Redis", "error", err)
}
}
// Use local copy for local mode, rsync for remote mode
var rsyncCmd string
if dm.config.NASHost == "" || dm.config.NASUser == "" {
// Local mode - use cp
rsyncCmd = fmt.Sprintf("mkdir -p %s && cp -r %s %s/", dm.config.MLDataDir, nasPath, mlPath)
} else {
// Remote mode - use rsync over SSH
rsyncCmd = fmt.Sprintf(
"mkdir -p %s && rsync -avz --progress %s@%s:%s/ %s/",
dm.config.MLDataDir,
dm.config.NASUser,
dm.config.NASHost,
nasPath,
mlPath,
)
}
ioBefore, ioErr := telemetry.ReadProcessIO()
start := time.Now()
out, err := telemetry.ExecWithMetrics(dm.logger, "dataset transfer", time.Since(start), func() (string, error) {
return dm.nasServer.ExecContext(ctx, rsyncCmd)
})
duration := time.Since(start)
if err != nil {
logger.Error("transfer failed",
"dataset", datasetName,
"duration", duration,
"error", err,
"output", out)
if ioErr == nil {
if after, readErr := telemetry.ReadProcessIO(); readErr == nil {
delta := telemetry.DiffIO(ioBefore, after)
logger.Debug("transfer io stats",
"dataset", datasetName,
"read_bytes", delta.ReadBytes,
"write_bytes", delta.WriteBytes)
}
}
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if redisErr := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
"status", "failed",
"error", err.Error()).Err(); redisErr != nil {
logger.Warn("failed to record transfer failure in Redis", "error", redisErr)
}
}
return err
}
logger.Info("transfer complete",
"dataset", datasetName,
"duration", duration,
"size_gb", sizeGB)
if ioErr == nil {
if after, readErr := telemetry.ReadProcessIO(); readErr == nil {
delta := telemetry.DiffIO(ioBefore, after)
logger.Debug("transfer io stats",
"dataset", datasetName,
"read_bytes", delta.ReadBytes,
"write_bytes", delta.WriteBytes)
}
}
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
"status", "completed",
"completed_at", time.Now().Unix(),
"duration_seconds", duration.Seconds()).Err(); err != nil {
logger.Warn("failed to record transfer completion in Redis", "error", err)
}
}
// Track dataset metadata
dm.saveDatasetInfo(datasetName, size)
return nil
}
func (dm *DataManager) saveDatasetInfo(name string, size int64) {
if dm.taskQueue == nil {
return // Skip in local mode
}
info := DatasetInfo{
Name: name,
SizeBytes: size,
Location: "ml",
LastAccess: time.Now(),
}
data, _ := json.Marshal(info)
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.Set(dm.ctx, fmt.Sprintf("ml:dataset:%s", name), data, 0).Err(); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis",
"dataset", name, "error", err)
}
}
}
func (dm *DataManager) updateLastAccess(name string) {
if dm.taskQueue == nil {
return // Skip in local mode
}
key := fmt.Sprintf("ml:dataset:%s", name)
redisClient := dm.taskQueue.GetRedisClient()
data, err := redisClient.Get(dm.ctx, key).Result()
if err != nil {
return
}
var info DatasetInfo
if err := json.Unmarshal([]byte(data), &info); err != nil {
return
}
info.LastAccess = time.Now()
newData, _ := json.Marshal(info)
redisClient = dm.taskQueue.GetRedisClient()
if err := redisClient.Set(dm.ctx, key, newData, 0).Err(); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to update last access in Redis",
"dataset", name, "error", err)
}
}
// ListDatasetsOnML returns a list of all datasets currently stored on the ML server.
func (dm *DataManager) ListDatasetsOnML() ([]DatasetInfo, error) {
out, err := dm.mlServer.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", dm.config.MLDataDir))
if err != nil {
return nil, err
}
var datasets []DatasetInfo
for name := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
if name == "" {
continue
}
var info DatasetInfo
// Only use Redis if available
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
key := fmt.Sprintf("ml:dataset:%s", name)
data, err := redisClient.Get(dm.ctx, key).Result()
if err == nil {
if unmarshalErr := json.Unmarshal([]byte(data), &info); unmarshalErr != nil {
// Fallback to disk if unmarshal fails
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
info = DatasetInfo{
Name: name,
SizeBytes: size,
Location: "ml",
}
}
} else {
// Fallback: get from disk
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
info = DatasetInfo{
Name: name,
SizeBytes: size,
Location: "ml",
}
}
} else {
// Local mode: get from disk
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
info = DatasetInfo{
Name: name,
SizeBytes: size,
Location: "ml",
}
}
datasets = append(datasets, info)
}
return datasets, nil
}
func (dm *DataManager) CleanupOldData() error {
logger := dm.logger.Job(dm.ctx, "data_manager", "")
logger.Info("running data cleanup")
datasets, err := dm.ListDatasetsOnML()
if err != nil {
return err
}
var totalSize int64
for _, ds := range datasets {
totalSize += ds.SizeBytes
}
totalSizeGB := float64(totalSize) / (1024 * 1024 * 1024)
logger.Info("current storage usage",
"total_size_gb", totalSizeGB,
"dataset_count", len(datasets))
// Delete datasets older than max age or if over size limit
maxAge := time.Duration(dm.config.MaxAgeHours) * time.Hour
maxSize := int64(dm.config.MaxSizeGB) * 1024 * 1024 * 1024
var deleted []string
for _, ds := range datasets {
shouldDelete := false
// Check age
if !ds.LastAccess.IsZero() && time.Since(ds.LastAccess) > maxAge {
logger.Info("dataset is old, marking for deletion",
"dataset", ds.Name,
"last_access", ds.LastAccess,
"age_hours", time.Since(ds.LastAccess).Hours())
shouldDelete = true
}
// Check if over size limit
if totalSize > maxSize {
logger.Info("over size limit, deleting oldest dataset",
"dataset", ds.Name,
"current_size_gb", totalSizeGB,
"max_size_gb", dm.config.MaxSizeGB)
shouldDelete = true
}
if shouldDelete {
path := filepath.Join(dm.config.MLDataDir, ds.Name)
logger.Info("deleting dataset", "dataset", ds.Name, "path", path)
if _, err := dm.mlServer.Exec(fmt.Sprintf("rm -rf %s", path)); err != nil {
logger.Error("failed to delete dataset",
"dataset", ds.Name,
"error", err)
continue
}
deleted = append(deleted, ds.Name)
totalSize -= ds.SizeBytes
// FIXED: Remove from Redis only if available, with error handling
if dm.taskQueue != nil {
redisClient := dm.taskQueue.GetRedisClient()
if err := redisClient.Del(dm.ctx, fmt.Sprintf("ml:dataset:%s", ds.Name)).Err(); err != nil {
logger.Warn("failed to delete dataset from Redis",
"dataset", ds.Name,
"error", err)
}
}
}
}
if len(deleted) > 0 {
logger.Info("cleanup complete",
"deleted_count", len(deleted),
"deleted_datasets", deleted)
} else {
logger.Info("cleanup complete", "deleted_count", 0)
}
return nil
}
// GetAvailableDiskSpace returns available disk space in bytes
func (dm *DataManager) GetAvailableDiskSpace() int64 {
logger := dm.logger.Job(dm.ctx, "data_manager", "")
// Check disk space on ML server
cmd := "df -k " + dm.config.MLDataDir + " | tail -1 | awk '{print $4}'"
output, err := dm.mlServer.Exec(cmd)
if err != nil {
logger.Error("failed to get disk space", "error", err)
return 0
}
// Parse KB to bytes
var freeKB int64
_, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &freeKB)
if err != nil {
logger.Error("failed to parse disk space", "error", err, "output", output)
return 0
}
return freeKB * 1024 // Convert KB to bytes
}
// GetDatasetInfo returns information about a dataset from NAS
func (dm *DataManager) GetDatasetInfo(datasetName string) (*DatasetInfo, error) {
// Check if dataset exists on NAS
nasPath := filepath.Join(dm.config.NASDataDir, datasetName)
cmd := fmt.Sprintf("test -d %s && echo 'exists'", nasPath)
output, err := dm.nasServer.Exec(cmd)
if err != nil || strings.TrimSpace(output) != "exists" {
return nil, fmt.Errorf("dataset %s not found on NAS", datasetName)
}
// Get dataset size
cmd = fmt.Sprintf("du -sb %s | cut -f1", nasPath)
output, err = dm.nasServer.Exec(cmd)
if err != nil {
return nil, fmt.Errorf("failed to get dataset size: %w", err)
}
var sizeBytes int64
_, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &sizeBytes)
if err != nil {
return nil, fmt.Errorf("failed to parse dataset size: %w", err)
}
// Get last modification time as proxy for last access
cmd = fmt.Sprintf("stat -c %%Y %s", nasPath)
output, err = dm.nasServer.Exec(cmd)
if err != nil {
return nil, fmt.Errorf("failed to get dataset timestamp: %w", err)
}
var modTime int64
_, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &modTime)
if err != nil {
return nil, fmt.Errorf("failed to parse timestamp: %w", err)
}
return &DatasetInfo{
Name: datasetName,
SizeBytes: sizeBytes,
Location: "nas",
LastAccess: time.Unix(modTime, 0),
}, nil
}
// ValidateDatasetWithCleanup checks if dataset fits and runs cleanup if needed
func (dm *DataManager) ValidateDatasetWithCleanup(datasetName string) error {
logger := dm.logger.Job(dm.ctx, "data_manager", "")
// Get dataset info
info, err := dm.GetDatasetInfo(datasetName)
if err != nil {
return fmt.Errorf("failed to get dataset info: %w", err)
}
// Check current available space
availableSpace := dm.GetAvailableDiskSpace()
logger.Info("dataset size validation",
"dataset", datasetName,
"dataset_size_gb", float64(info.SizeBytes)/(1024*1024*1024),
"available_gb", float64(availableSpace)/(1024*1024*1024))
// If enough space, proceed
if info.SizeBytes <= availableSpace {
logger.Info("sufficient space available", "dataset", datasetName)
return nil
}
// Try cleanup first
logger.Info("insufficient space, running cleanup",
"dataset", datasetName,
"required_gb", float64(info.SizeBytes)/(1024*1024*1024),
"available_gb", float64(availableSpace)/(1024*1024*1024))
if err := dm.CleanupOldData(); err != nil {
return fmt.Errorf("cleanup failed: %w", err)
}
// Check space again after cleanup
availableSpace = dm.GetAvailableDiskSpace()
logger.Info("space after cleanup",
"available_gb", float64(availableSpace)/(1024*1024*1024))
// If now enough space, proceed
if info.SizeBytes <= availableSpace {
logger.Info("cleanup freed enough space", "dataset", datasetName)
return nil
}
// Still not enough space
return fmt.Errorf("dataset %s (%.2fGB) too large for available space (%.2fGB) even after cleanup",
datasetName,
float64(info.SizeBytes)/(1024*1024*1024),
float64(availableSpace)/(1024*1024*1024))
}
func (dm *DataManager) StartCleanupLoop() {
logger := dm.logger.Job(dm.ctx, "data_manager", "")
ticker := time.NewTicker(time.Duration(dm.config.CleanupInterval) * time.Minute)
go func() {
defer ticker.Stop()
for {
select {
case <-dm.ctx.Done():
logger.Info("cleanup loop stopping")
return
case <-ticker.C:
if err := dm.CleanupOldData(); err != nil {
logger.Error("cleanup error", "error", err)
}
}
}
}()
}
// Close gracefully shuts down the DataManager, stopping the cleanup loop and
// closing all connections to ML server, NAS server, and Redis.
func (dm *DataManager) Close() {
dm.cancel() // Cancel context to stop cleanup loop
// Wait a moment for cleanup loop to finish
time.Sleep(100 * time.Millisecond)
if dm.mlServer != nil {
if err := dm.mlServer.Close(); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing ML server connection", "error", err)
}
}
if dm.nasServer != nil {
if err := dm.nasServer.Close(); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing NAS server connection", "error", err)
}
}
if dm.taskQueue != nil {
if err := dm.taskQueue.Close(); err != nil {
dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing Redis connection", "error", err)
}
}
}
func main() {
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
if err := auth.ValidateAuthFlags(authFlags); err != nil {
log.Fatalf("Authentication flag error: %v", err)
}
// Get API key from various sources
apiKey := auth.GetAPIKeyFromSources(authFlags)
configFile := "configs/config-local.yaml"
if authFlags.ConfigFile != "" {
configFile = authFlags.ConfigFile
}
// Parse command line args
if len(os.Args) < 2 {
fmt.Println("Usage:")
fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] fetch <job-name> <dataset> [dataset...]")
fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] list")
fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] cleanup")
fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] validate <dataset>")
fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] daemon")
fmt.Println()
auth.PrintAuthHelp()
os.Exit(1)
}
// Check for --config flag
if len(os.Args) >= 3 && os.Args[1] == "--config" {
configFile = os.Args[2]
// Shift args
os.Args = append([]string{os.Args[0]}, os.Args[3:]...)
}
cfg, err := LoadDataConfig(configFile)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Validate authentication configuration
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
log.Fatalf("Invalid authentication configuration: %v", err)
}
// Validate configuration
if err := cfg.Validate(); err != nil {
log.Fatalf("Invalid configuration: %v", err)
}
// Test authentication if enabled
if cfg.Auth.Enabled && apiKey != "" {
user, err := cfg.Auth.ValidateAPIKey(apiKey)
if err != nil {
log.Fatalf("Authentication failed: %v", err)
}
log.Printf("Data manager authenticated as user: %s (admin: %v)", user.Name, user.Admin)
} else if cfg.Auth.Enabled {
log.Fatal("Authentication required but no API key provided")
}
dm, err := NewDataManager(cfg, apiKey)
if err != nil {
log.Fatalf("Failed to create data manager: %v", err)
}
defer dm.Close()
cmd := os.Args[1]
switch cmd {
case "fetch":
if len(os.Args) < 4 {
log.Fatal("Usage: data_manager fetch <job-name> <dataset> [dataset...]")
}
jobName := os.Args[2]
datasets := os.Args[3:]
for _, dataset := range datasets {
if err := dm.FetchDataset(jobName, dataset); err != nil {
dm.logger.Job(context.Background(), jobName, "").Error("failed to fetch dataset",
"dataset", dataset,
"error", err)
}
}
case "list":
datasets, err := dm.ListDatasetsOnML()
if err != nil {
log.Fatalf("Failed to list datasets: %v", err)
}
fmt.Println("Datasets on ML server:")
fmt.Println("======================")
var totalSize int64
for _, ds := range datasets {
sizeMB := float64(ds.SizeBytes) / (1024 * 1024)
lastAccess := "unknown"
if !ds.LastAccess.IsZero() {
lastAccess = ds.LastAccess.Format("2006-01-02 15:04:05")
}
fmt.Printf("%-30s %10.2f MB Last access: %s\n", ds.Name, sizeMB, lastAccess)
totalSize += ds.SizeBytes
}
fmt.Printf("\nTotal: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
case "validate":
if len(os.Args) < 3 {
log.Fatal("Usage: data_manager validate <dataset>")
}
dataset := os.Args[2]
fmt.Printf("Validating dataset: %s\n", dataset)
if err := dm.ValidateDatasetWithCleanup(dataset); err != nil {
log.Fatalf("Validation failed: %v", err)
}
fmt.Printf("✅ Dataset %s can be downloaded\n", dataset)
case "cleanup":
if err := dm.CleanupOldData(); err != nil {
log.Fatalf("Cleanup failed: %v", err)
}
case "daemon":
logger := dm.logger.Job(context.Background(), "data_manager", "")
logger.Info("starting data manager daemon")
dm.StartCleanupLoop()
logger.Info("cleanup configuration",
"interval_minutes", cfg.CleanupInterval,
"max_age_hours", cfg.MaxAgeHours,
"max_size_gb", cfg.MaxSizeGB)
// Handle graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigChan
logger.Info("received shutdown signal", "signal", sig)
dm.Close()
logger.Info("data manager shut down gracefully")
default:
log.Fatalf("Unknown command: %s", cmd)
}
}

cmd/tui/README.md (new file)

@ -0,0 +1,282 @@
# FetchML TUI - Terminal User Interface
An interactive terminal dashboard for managing ML experiments, monitoring system resources, and controlling job execution.
## Features
### 📊 Real-time Monitoring
- **Job Status** - Track pending, running, finished, and failed jobs
- **GPU Metrics** - Monitor GPU utilization, memory, and temperature
- **Container Status** - View running Podman/Docker containers
- **Task Queue** - See queued tasks with priorities and status
### 🎮 Interactive Controls
- **Queue Jobs** - Submit jobs with custom arguments and priorities
- **View Logs** - Real-time log viewing for running jobs
- **Cancel Tasks** - Stop running tasks
- **Delete Jobs** - Remove pending jobs
- **Mark Failed** - Manually mark stuck jobs as failed
### ⚙️ Settings Management
- **API Key Configuration** - Set and update API keys on the fly
- **In-memory Storage** - Settings persist for the session
### 🎨 Modern UI
- **Clean Design** - Dark-mode friendly with adaptive colors
- **Responsive Layout** - Adjusts to terminal size
- **Context-aware Help** - Shows relevant shortcuts for each view
- **Mouse Support** - Optional mouse navigation
## Quick Start
### Running the TUI
```bash
# Using make (recommended)
make tui-dev # Dev mode (remote server)
make tui # Local mode
# Direct execution with CLI config (TOML)
./bin/tui --config ~/.ml/config.toml
# With custom TOML config
./bin/tui --config path/to/config.toml
```
### First Time Setup
1. **Build the binary**
```bash
make build
```
2. **Get your API key**
```bash
./bin/user_manager --config configs/config_dev.yaml --cmd generate-key --username your_name
```
3. **Launch the TUI**
```bash
make tui-dev
```
## Keyboard Shortcuts
### Navigation
| Key | Action |
|-----|--------|
| `1` | Switch to Job List view |
| `g` | Switch to GPU Status view |
| `l` | View logs for selected job |
| `v` | Switch to Task Queue view |
| `o` | Switch to Container Status view |
| `s` | Open Settings |
| `h` or `?` | Toggle help screen |
### Job Management
| Key | Action |
|-----|--------|
| `t` | Queue selected job (default args) |
| `a` | Queue job with custom arguments |
| `c` | Cancel running task |
| `d` | Delete pending job |
| `f` | Mark running job as failed |
### System
| Key | Action |
|-----|--------|
| `r` | Refresh all data |
| `G` | Refresh GPU status only |
| `q` or `Ctrl+C` | Quit |
### Settings View
| Key | Action |
|-----|--------|
| `↑`/`↓` or `j`/`k` | Navigate options |
| `Enter` | Select/Save |
| `Esc` | Exit settings |
## Views
### Job List (Default)
- Shows all jobs across all statuses
- Filter with `/` key
- Navigate with arrow keys or `j`/`k`
- Select and press `l` to view logs
### GPU Status
- Real-time GPU metrics (nvidia-smi)
- macOS GPU info (system_profiler)
- Utilization, memory, temperature
### Container Status
- Running Podman/Docker containers
- Container health and status
- System info (Podman/Docker version)
### Task Queue
- All queued tasks with priorities
- Task status and creation time
- Running duration for active tasks
### Logs
- Last 200 lines of job output
- Auto-scroll to bottom
- Refreshes with job status
### Settings
- View current API key status
- Update API key
- Save configuration (in-memory)
## Terminal Compatibility
The TUI is built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) and works on all modern terminals:
### ✅ Fully Supported
- **WezTerm** (recommended)
- **Alacritty**
- **Kitty**
- **iTerm2** (macOS)
- **Terminal.app** (macOS)
- **Windows Terminal**
- **GNOME Terminal**
- **Konsole**
### ✅ Multiplexers
- **tmux**
- **screen**
### Features
- ✅ 256 colors
- ✅ True color (24-bit)
- ✅ Mouse support
- ✅ Alt screen buffer
- ✅ Adaptive colors (light/dark themes)
### Key Components
- **Model** - Pure data structures (State, Job, Task)
- **View** - Rendering functions (no business logic)
- **Controller** - Message handling and state updates
- **Services** - SSH/Redis communication
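The components above map onto Bubble Tea's Model/Update/View cycle. A minimal, illustrative sketch of that pattern (placeholder job names, not the actual FetchML TUI code):
```go
package main

import (
	"fmt"
	"os"

	tea "github.com/charmbracelet/bubbletea"
)

// model is the pure state ("Model" in the layout above).
type model struct {
	jobs     []string
	selected int
}

// Init returns the first command to run (none here).
func (m model) Init() tea.Cmd { return nil }

// Update plays the "Controller" role: handle a message, return the new state.
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	if key, ok := msg.(tea.KeyMsg); ok {
		switch key.String() {
		case "q", "ctrl+c":
			return m, tea.Quit
		case "j", "down":
			if m.selected < len(m.jobs)-1 {
				m.selected++
			}
		case "k", "up":
			if m.selected > 0 {
				m.selected--
			}
		}
	}
	return m, nil
}

// View is pure rendering: state in, string out, no business logic.
func (m model) View() string {
	out := "Jobs\n"
	for i, job := range m.jobs {
		cursor := "  "
		if i == m.selected {
			cursor = "> "
		}
		out += cursor + job + "\n"
	}
	return out + "\nj/k: move  q: quit\n"
}

func main() {
	m := model{jobs: []string{"train-resnet", "eval-bert"}}
	if _, err := tea.NewProgram(m).Run(); err != nil {
		fmt.Fprintln(os.Stderr, "error:", err)
		os.Exit(1)
	}
}
```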
## Configuration
The TUI uses the same TOML configuration file as the CLI:
```toml
# ~/.ml/config.toml
worker_host = "localhost"
worker_user = "your_user"
worker_base = "~/ml_jobs"
worker_port = 22
api_key = "your_api_key_here"
```
For CLI usage, run `ml init` to create a default configuration file.
See [Configuration Documentation](../docs/documentation.md#configuration) for details.
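The CLI config loader in this commit also honors environment-variable overrides (`FETCH_ML_CLI_*`, with legacy `ML_*` fallbacks), so individual values can be overridden per shell session, for example:
```bash
# Override the worker host and API key for this session only
export FETCH_ML_CLI_HOST=ml-server.local
export FETCH_ML_CLI_API_KEY=your_api_key_here
./bin/tui --config ~/.ml/config.toml
```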
## Troubleshooting
### TUI doesn't start
```bash
# Check if binary exists
ls -la bin/tui
# Rebuild if needed
make build
# Check CLI config
cat ~/.ml/config.toml
```
### Authentication errors
```bash
# Verify CLI config exists
ls -la ~/.ml/config.toml
# Initialize CLI config if needed
ml init
# Test connection
./bin/tui --config ~/.ml/config.toml
```
### Display issues
```bash
# Check terminal type
echo $TERM
# Should be xterm-256color or similar
# If not, set it:
export TERM=xterm-256color
```
### Connection issues
```bash
# Test SSH connection
ssh your_user@your_server
# Test Redis connection
redis-cli ping
```
## Development
### Building
```bash
# Build TUI only
go build -o bin/tui ./cmd/tui
# Build all binaries
make build
```
### Testing
```bash
# Run with verbose logging
./bin/tui --config ~/.ml/config.toml 2>tui.log
# Check logs
tail -f tui.log
```
### Code Organization
- Keep files under 300 lines
- Separate concerns (MVC pattern)
- Use descriptive function names
- Add comments for complex logic
## Tips & Tricks
### Efficient Workflow
1. Keep TUI open in one terminal
2. Edit code in another terminal
3. Use `r` to refresh after changes
4. Use `h` to quickly reference shortcuts
### Custom Arguments
When queuing jobs with `a`:
```
--epochs 100 --lr 0.001 --priority 5
```
### Monitoring
- Use `G` for quick GPU refresh (faster than `r`)
- Check queue with `v` before queuing new jobs
- Use `l` to debug failed jobs
### Settings
- Update API key without restarting
- Changes are in-memory only
- Restart TUI to reset
## See Also
- [Main Documentation](../docs/documentation.md)
- [Worker Documentation](../cmd/worker/README.md)
- [Configuration Guide](../docs/documentation.md#configuration)
- [Bubble Tea Documentation](https://github.com/charmbracelet/bubbletea)


@ -0,0 +1,492 @@
package config
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/auth"
utils "github.com/jfraeys/fetch_ml/internal/config"
"github.com/stretchr/testify/assert/yaml"
)
// CLIConfig represents the TOML config structure used by the CLI
type CLIConfig struct {
WorkerHost string `toml:"worker_host"`
WorkerUser string `toml:"worker_user"`
WorkerBase string `toml:"worker_base"`
WorkerPort int `toml:"worker_port"`
APIKey string `toml:"api_key"`
// User context (filled after authentication)
CurrentUser *UserContext `toml:"-"`
}
// UserContext represents the authenticated user information
type UserContext struct {
Name string `json:"name"`
Admin bool `json:"admin"`
Roles []string `json:"roles"`
Permissions map[string]bool `json:"permissions"`
}
// LoadCLIConfig loads the CLI's TOML configuration from the provided path.
// If path is empty, ~/.ml/config.toml is used. The resolved path is returned.
// Automatically migrates from YAML config if TOML doesn't exist.
// Environment variables with FETCH_ML_CLI_ prefix override config file values.
func LoadCLIConfig(configPath string) (*CLIConfig, string, error) {
if configPath == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, "", fmt.Errorf("failed to get home directory: %w", err)
}
configPath = filepath.Join(home, ".ml", "config.toml")
} else {
configPath = utils.ExpandPath(configPath)
if !filepath.IsAbs(configPath) {
if abs, err := filepath.Abs(configPath); err == nil {
configPath = abs
}
}
}
// Check if TOML config exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
// Try to migrate from YAML
yamlPath := strings.TrimSuffix(configPath, ".toml") + ".yaml"
if migratedPath, err := migrateFromYAML(yamlPath, configPath); err == nil {
log.Printf("Migrated configuration from %s to %s", yamlPath, migratedPath)
configPath = migratedPath
} else {
return nil, configPath, fmt.Errorf("CLI config not found at %s (run 'ml init' first)", configPath)
}
} else if err != nil {
return nil, configPath, fmt.Errorf("cannot access CLI config %s: %w", configPath, err)
}
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
log.Printf("Warning: %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, configPath, fmt.Errorf("failed to read CLI config: %w", err)
}
config := &CLIConfig{}
if err := parseTOML(data, config); err != nil {
return nil, configPath, fmt.Errorf("failed to parse CLI config: %w", err)
}
if err := config.Validate(); err != nil {
return nil, configPath, err
}
// Apply environment variable overrides with FETCH_ML_CLI_ prefix
if host := os.Getenv("FETCH_ML_CLI_HOST"); host != "" {
config.WorkerHost = host
}
if user := os.Getenv("FETCH_ML_CLI_USER"); user != "" {
config.WorkerUser = user
}
if base := os.Getenv("FETCH_ML_CLI_BASE"); base != "" {
config.WorkerBase = base
}
if port := os.Getenv("FETCH_ML_CLI_PORT"); port != "" {
if p, err := parseInt(port); err == nil {
config.WorkerPort = p
}
}
if apiKey := os.Getenv("FETCH_ML_CLI_API_KEY"); apiKey != "" {
config.APIKey = apiKey
}
// Also support legacy ML_ prefix for backward compatibility
if host := os.Getenv("ML_HOST"); host != "" && config.WorkerHost == "" {
config.WorkerHost = host
}
if user := os.Getenv("ML_USER"); user != "" && config.WorkerUser == "" {
config.WorkerUser = user
}
if base := os.Getenv("ML_BASE"); base != "" && config.WorkerBase == "" {
config.WorkerBase = base
}
if port := os.Getenv("ML_PORT"); port != "" && config.WorkerPort == 0 {
if p, err := parseInt(port); err == nil {
config.WorkerPort = p
}
}
if apiKey := os.Getenv("ML_API_KEY"); apiKey != "" && config.APIKey == "" {
config.APIKey = apiKey
}
return config, configPath, nil
}
// parseTOML is a simple TOML parser for the CLI config format
func parseTOML(data []byte, config *CLIConfig) error {
lines := strings.Split(string(data), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
// Remove quotes if present
if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) {
value = value[1 : len(value)-1]
}
switch key {
case "worker_host":
config.WorkerHost = value
case "worker_user":
config.WorkerUser = value
case "worker_base":
config.WorkerBase = value
case "worker_port":
if p, err := parseInt(value); err == nil {
config.WorkerPort = p
}
case "api_key":
config.APIKey = value
}
}
return nil
}
// ToTUIConfig converts CLI config to TUI config structure
func (c *CLIConfig) ToTUIConfig() *Config {
// Get smart defaults for current environment
smart := utils.GetSmartDefaults()
tuiConfig := &Config{
Host: c.WorkerHost,
User: c.WorkerUser,
Port: c.WorkerPort,
BasePath: c.WorkerBase,
// Set defaults for TUI-specific fields using smart defaults
RedisAddr: smart.RedisAddr(),
RedisDB: 0,
PodmanImage: "ml-worker:latest",
ContainerWorkspace: utils.DefaultContainerWorkspace,
ContainerResults: utils.DefaultContainerResults,
GPUAccess: false,
}
// Set up auth config with CLI API key
tuiConfig.Auth = auth.AuthConfig{
Enabled: true,
APIKeys: map[auth.Username]auth.APIKeyEntry{
"cli_user": {
Hash: auth.APIKeyHash(hashAPIKey(c.APIKey)),
Admin: true,
Roles: []string{"user", "admin"},
Permissions: map[string]bool{
"read": true,
"write": true,
"delete": true,
},
},
},
}
// Set known hosts path
tuiConfig.KnownHosts = smart.KnownHostsPath()
return tuiConfig
}
// Validate validates the CLI config
func (c *CLIConfig) Validate() error {
var errors []string
if c.WorkerHost == "" {
errors = append(errors, "worker_host is required")
} else if len(strings.TrimSpace(c.WorkerHost)) == 0 {
errors = append(errors, "worker_host cannot be empty or whitespace")
}
if c.WorkerUser == "" {
errors = append(errors, "worker_user is required")
} else if len(strings.TrimSpace(c.WorkerUser)) == 0 {
errors = append(errors, "worker_user cannot be empty or whitespace")
}
if c.WorkerBase == "" {
errors = append(errors, "worker_base is required")
} else {
// Expand and validate path
c.WorkerBase = utils.ExpandPath(c.WorkerBase)
if !filepath.IsAbs(c.WorkerBase) {
errors = append(errors, "worker_base must be an absolute path")
}
}
if c.WorkerPort == 0 {
errors = append(errors, "worker_port is required")
} else if err := utils.ValidatePort(c.WorkerPort); err != nil {
errors = append(errors, fmt.Sprintf("invalid worker_port: %v", err))
}
if c.APIKey == "" {
errors = append(errors, "api_key is required")
} else if len(c.APIKey) < 16 {
errors = append(errors, "api_key must be at least 16 characters")
}
if len(errors) > 0 {
return fmt.Errorf("validation failed: %s", strings.Join(errors, "; "))
}
return nil
}
// AuthenticateWithServer validates the API key and sets user context
func (c *CLIConfig) AuthenticateWithServer() error {
if c.APIKey == "" {
return fmt.Errorf("no API key configured")
}
// Create temporary auth config for validation
authConfig := &auth.AuthConfig{
Enabled: true,
APIKeys: map[auth.Username]auth.APIKeyEntry{
"temp": {
Hash: auth.APIKeyHash(auth.HashAPIKey(c.APIKey)),
Admin: false,
},
},
}
// Validate API key and get user info
user, err := authConfig.ValidateAPIKey(auth.HashAPIKey(c.APIKey))
if err != nil {
return fmt.Errorf("API key validation failed: %w", err)
}
// Set user context
c.CurrentUser = &UserContext{
Name: user.Name,
Admin: user.Admin,
Roles: user.Roles,
Permissions: user.Permissions,
}
return nil
}
// CheckPermission checks if the current user has a specific permission
func (c *CLIConfig) CheckPermission(permission string) bool {
if c.CurrentUser == nil {
return false
}
// Admin users have all permissions
if c.CurrentUser.Admin {
return true
}
// Check explicit permission
if c.CurrentUser.Permissions[permission] {
return true
}
// Check wildcard permission
if c.CurrentUser.Permissions["*"] {
return true
}
return false
}
// CanViewJob checks if user can view a specific job
func (c *CLIConfig) CanViewJob(jobUserID string) bool {
if c.CurrentUser == nil {
return false
}
// Admin can view all jobs
if c.CurrentUser.Admin {
return true
}
// Users can view their own jobs
return jobUserID == c.CurrentUser.Name
}
// CanModifyJob checks if user can modify a specific job
func (c *CLIConfig) CanModifyJob(jobUserID string) bool {
if c.CurrentUser == nil {
return false
}
// Need jobs:update permission
if !c.CheckPermission("jobs:update") {
return false
}
// Admin can modify all jobs
if c.CurrentUser.Admin {
return true
}
// Users can only modify their own jobs
return jobUserID == c.CurrentUser.Name
}
// migrateFromYAML migrates configuration from YAML to TOML format
func migrateFromYAML(yamlPath, tomlPath string) (string, error) {
// Check if YAML file exists
if _, err := os.Stat(yamlPath); os.IsNotExist(err) {
return "", fmt.Errorf("YAML config not found at %s", yamlPath)
}
// Read YAML config
data, err := os.ReadFile(yamlPath)
if err != nil {
return "", fmt.Errorf("failed to read YAML config: %w", err)
}
// Parse YAML to extract relevant fields
var yamlConfig map[string]interface{}
if err := yaml.Unmarshal(data, &yamlConfig); err != nil {
return "", fmt.Errorf("failed to parse YAML config: %w", err)
}
// Create CLI config from YAML data
cliConfig := &CLIConfig{}
// Extract values with fallbacks
if host, ok := yamlConfig["host"].(string); ok {
cliConfig.WorkerHost = host
}
if user, ok := yamlConfig["user"].(string); ok {
cliConfig.WorkerUser = user
}
if base, ok := yamlConfig["base_path"].(string); ok {
cliConfig.WorkerBase = base
}
if port, ok := yamlConfig["port"].(int); ok {
cliConfig.WorkerPort = port
}
// Try to extract API key from auth section
if auth, ok := yamlConfig["auth"].(map[string]interface{}); ok {
if apiKeys, ok := auth["api_keys"].(map[string]interface{}); ok {
for _, keyEntry := range apiKeys {
if keyMap, ok := keyEntry.(map[string]interface{}); ok {
if hash, ok := keyMap["hash"].(string); ok {
cliConfig.APIKey = hash // Note: This is the hash, not the actual key
break
}
}
}
}
}
// Validate migrated config
if err := cliConfig.Validate(); err != nil {
return "", fmt.Errorf("migrated config validation failed: %w", err)
}
// Generate TOML content
tomlContent := fmt.Sprintf(`# Fetch ML CLI Configuration
# Migrated from YAML configuration
worker_host = "%s"
worker_user = "%s"
worker_base = "%s"
worker_port = %d
api_key = "%s"
`,
cliConfig.WorkerHost,
cliConfig.WorkerUser,
cliConfig.WorkerBase,
cliConfig.WorkerPort,
cliConfig.APIKey,
)
// Create directory if it doesn't exist
if err := os.MkdirAll(filepath.Dir(tomlPath), 0755); err != nil {
return "", fmt.Errorf("failed to create config directory: %w", err)
}
// Write TOML file
if err := os.WriteFile(tomlPath, []byte(tomlContent), 0600); err != nil {
return "", fmt.Errorf("failed to write TOML config: %w", err)
}
return tomlPath, nil
}
// ConfigExists checks if a CLI configuration file exists
func ConfigExists(configPath string) bool {
if configPath == "" {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath = filepath.Join(home, ".ml", "config.toml")
}
_, err := os.Stat(configPath)
return !os.IsNotExist(err)
}
// GenerateDefaultConfig creates a default TOML configuration file
func GenerateDefaultConfig(configPath string) error {
// Create directory if it doesn't exist
if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Generate default configuration
defaultContent := `# Fetch ML CLI Configuration
# This file contains connection settings for the ML platform
# Worker connection settings
worker_host = "localhost" # Hostname or IP of the worker
worker_user = "your_username" # SSH username for the worker
worker_base = "~/ml_jobs" # Base directory for ML jobs on worker
worker_port = 22 # SSH port (default: 22)
# Authentication
api_key = "your_api_key_here" # Your API key (get from admin)
# Environment variable overrides:
# ML_HOST, ML_USER, ML_BASE, ML_PORT, ML_API_KEY
`
// Write configuration file
if err := os.WriteFile(configPath, []byte(defaultContent), 0600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
// Set proper permissions
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
log.Printf("Warning: %v", err)
}
return nil
}
func hashAPIKey(apiKey string) string {
if apiKey == "" {
return ""
}
return auth.HashAPIKey(apiKey)
}


@ -0,0 +1,194 @@
package config
import (
"testing"
)
func TestCLIConfig_CheckPermission(t *testing.T) {
tests := []struct {
name string
config *CLIConfig
permission string
want bool
}{
{
name: "Admin has all permissions",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "admin",
Admin: true,
},
},
permission: "any:permission",
want: true,
},
{
name: "User with explicit permission",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user",
Admin: false,
Permissions: map[string]bool{"jobs:create": true},
},
},
permission: "jobs:create",
want: true,
},
{
name: "User without permission",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user",
Admin: false,
Permissions: map[string]bool{"jobs:read": true},
},
},
permission: "jobs:create",
want: false,
},
{
name: "No current user",
config: &CLIConfig{
CurrentUser: nil,
},
permission: "jobs:create",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.config.CheckPermission(tt.permission)
if got != tt.want {
t.Errorf("CheckPermission() = %v, want %v", got, tt.want)
}
})
}
}
func TestCLIConfig_CanViewJob(t *testing.T) {
tests := []struct {
name string
config *CLIConfig
jobUserID string
want bool
}{
{
name: "Admin can view any job",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "admin",
Admin: true,
},
},
jobUserID: "other_user",
want: true,
},
{
name: "User can view own job",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user1",
Admin: false,
},
},
jobUserID: "user1",
want: true,
},
{
name: "User cannot view other's job",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user1",
Admin: false,
},
},
jobUserID: "user2",
want: false,
},
{
name: "No current user cannot view",
config: &CLIConfig{
CurrentUser: nil,
},
jobUserID: "user1",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.config.CanViewJob(tt.jobUserID)
if got != tt.want {
t.Errorf("CanViewJob() = %v, want %v", got, tt.want)
}
})
}
}
func TestCLIConfig_CanModifyJob(t *testing.T) {
tests := []struct {
name string
config *CLIConfig
jobUserID string
want bool
}{
{
name: "Admin can modify any job",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "admin",
Admin: true,
Permissions: map[string]bool{"jobs:update": true},
},
},
jobUserID: "other_user",
want: true,
},
{
name: "User with permission can modify own job",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user1",
Admin: false,
Permissions: map[string]bool{"jobs:update": true},
},
},
jobUserID: "user1",
want: true,
},
{
name: "User without permission cannot modify",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user1",
Admin: false,
Permissions: map[string]bool{"jobs:read": true},
},
},
jobUserID: "user1",
want: false,
},
{
name: "User cannot modify other's job",
config: &CLIConfig{
CurrentUser: &UserContext{
Name: "user1",
Admin: false,
Permissions: map[string]bool{"jobs:update": true},
},
},
jobUserID: "user2",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.config.CanModifyJob(tt.jobUserID)
if got != tt.want {
t.Errorf("CanModifyJob() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,145 @@
package config
import (
"fmt"
"os"
"path/filepath"
"github.com/BurntSushi/toml"
"github.com/jfraeys/fetch_ml/internal/auth"
utils "github.com/jfraeys/fetch_ml/internal/config"
)
// Config holds TUI configuration
type Config struct {
Host string `toml:"host"`
User string `toml:"user"`
SSHKey string `toml:"ssh_key"`
Port int `toml:"port"`
BasePath string `toml:"base_path"`
WrapperScript string `toml:"wrapper_script"`
TrainScript string `toml:"train_script"`
RedisAddr string `toml:"redis_addr"`
RedisPassword string `toml:"redis_password"`
RedisDB int `toml:"redis_db"`
KnownHosts string `toml:"known_hosts"`
// Authentication
Auth auth.AuthConfig `toml:"auth"`
// Podman settings
PodmanImage string `toml:"podman_image"`
ContainerWorkspace string `toml:"container_workspace"`
ContainerResults string `toml:"container_results"`
GPUAccess bool `toml:"gpu_access"`
}
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg Config
if _, err := toml.Decode(string(data), &cfg); err != nil {
return nil, err
}
// Get smart defaults for current environment
smart := utils.GetSmartDefaults()
if cfg.Port == 0 {
cfg.Port = utils.DefaultSSHPort
}
if cfg.Host == "" {
cfg.Host = smart.Host()
}
if cfg.BasePath == "" {
cfg.BasePath = smart.BasePath()
}
// wrapper_script is deprecated - using secure_runner.py directly via Podman
if cfg.TrainScript == "" {
cfg.TrainScript = utils.DefaultTrainScript
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = smart.RedisAddr()
}
if cfg.KnownHosts == "" {
cfg.KnownHosts = smart.KnownHostsPath()
}
// Apply environment variable overrides with FETCH_ML_TUI_ prefix
if host := os.Getenv("FETCH_ML_TUI_HOST"); host != "" {
cfg.Host = host
}
if user := os.Getenv("FETCH_ML_TUI_USER"); user != "" {
cfg.User = user
}
if sshKey := os.Getenv("FETCH_ML_TUI_SSH_KEY"); sshKey != "" {
cfg.SSHKey = sshKey
}
if port := os.Getenv("FETCH_ML_TUI_PORT"); port != "" {
if p, err := parseInt(port); err == nil {
cfg.Port = p
}
}
if basePath := os.Getenv("FETCH_ML_TUI_BASE_PATH"); basePath != "" {
cfg.BasePath = basePath
}
if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" {
cfg.TrainScript = trainScript
}
if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" {
cfg.RedisAddr = redisAddr
}
if redisPassword := os.Getenv("FETCH_ML_TUI_REDIS_PASSWORD"); redisPassword != "" {
cfg.RedisPassword = redisPassword
}
if redisDB := os.Getenv("FETCH_ML_TUI_REDIS_DB"); redisDB != "" {
if db, err := parseInt(redisDB); err == nil {
cfg.RedisDB = db
}
}
if knownHosts := os.Getenv("FETCH_ML_TUI_KNOWN_HOSTS"); knownHosts != "" {
cfg.KnownHosts = knownHosts
}
return &cfg, nil
}
// Validate implements utils.Validator interface
func (c *Config) Validate() error {
if c.Port != 0 {
if err := utils.ValidatePort(c.Port); err != nil {
return fmt.Errorf("invalid SSH port: %w", err)
}
}
if c.BasePath != "" {
// Convert relative paths to absolute
c.BasePath = utils.ExpandPath(c.BasePath)
if !filepath.IsAbs(c.BasePath) {
c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath)
}
}
if c.RedisAddr != "" {
if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
return fmt.Errorf("invalid Redis configuration: %w", err)
}
}
return nil
}
func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") }
func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") }
func (c *Config) FinishedPath() string { return filepath.Join(c.BasePath, "finished") }
func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") }
// parseInt parses a string to integer
func parseInt(s string) (int, error) {
var result int
_, err := fmt.Sscanf(s, "%d", &result)
return result, err
}
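A minimal sketch of the precedence implied by `LoadConfig` above: file values are filled in with smart defaults, then `FETCH_ML_TUI_*` environment variables override both. The config path here is hypothetical.

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
)

func main() {
	// The environment override wins regardless of what the file or defaults say.
	os.Setenv("FETCH_ML_TUI_REDIS_ADDR", "127.0.0.1:6380")

	cfg, err := config.LoadConfig("testdata/tui.toml") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	if err := cfg.Validate(); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.RedisAddr) // 127.0.0.1:6380
}
```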

View file

@ -0,0 +1,384 @@
package controller
import (
"fmt"
"path/filepath"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
// Message types for async operations
type (
JobsLoadedMsg []model.Job
TasksLoadedMsg []*model.Task
GpuLoadedMsg string
ContainerLoadedMsg string
LogLoadedMsg string
QueueLoadedMsg string
SettingsContentMsg string
SettingsUpdateMsg struct{}
StatusMsg struct {
Text string
Level string
}
TickMsg time.Time
)
// Command factories for loading data
func (c *Controller) loadAllData() tea.Cmd {
return tea.Batch(
c.loadJobs(),
c.loadQueue(),
c.loadGPU(),
c.loadContainer(),
)
}
func (c *Controller) loadJobs() tea.Cmd {
return func() tea.Msg {
type jobResult struct {
jobs []model.Job
err error
}
resultChan := make(chan jobResult, 1)
go func() {
var jobs []model.Job
statusChan := make(chan []model.Job, 4)
for _, status := range []model.JobStatus{model.StatusPending, model.StatusRunning, model.StatusFinished, model.StatusFailed} {
go func(s model.JobStatus) {
path := c.getPathForStatus(s)
names := c.server.ListDir(path)
var statusJobs []model.Job
for _, name := range names {
jobStatus, _ := c.taskQueue.GetJobStatus(name)
taskID := jobStatus["task_id"]
priority := int64(0)
if p, ok := jobStatus["priority"]; ok {
_, err := fmt.Sscanf(p, "%d", &priority)
if err != nil {
priority = 0
}
}
statusJobs = append(statusJobs, model.Job{
Name: name,
Status: s,
TaskID: taskID,
Priority: priority,
})
}
statusChan <- statusJobs
}(status)
}
for range 4 {
jobs = append(jobs, <-statusChan...)
}
resultChan <- jobResult{jobs: jobs, err: nil}
}()
result := <-resultChan
if result.err != nil {
return StatusMsg{Text: "Failed to load jobs: " + result.err.Error(), Level: "error"}
}
return JobsLoadedMsg(result.jobs)
}
}
func (c *Controller) loadQueue() tea.Cmd {
return func() tea.Msg {
tasks, err := c.taskQueue.GetQueuedTasks()
if err != nil {
c.logger.Error("failed to load queue", "error", err)
return StatusMsg{Text: "Failed to load queue: " + err.Error(), Level: "error"}
}
c.logger.Info("loaded queue", "task_count", len(tasks))
return TasksLoadedMsg(tasks)
}
}
func (c *Controller) loadGPU() tea.Cmd {
return func() tea.Msg {
type gpuResult struct {
content string
err error
}
resultChan := make(chan gpuResult, 1)
go func() {
cmd := "nvidia-smi --query-gpu=index,name,utilization.gpu,memory.used,memory.total,temperature.gpu --format=csv,noheader,nounits"
out, err := c.server.Exec(cmd)
if err == nil && strings.TrimSpace(out) != "" {
var formatted strings.Builder
formatted.WriteString("GPU Status\n")
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
lines := strings.Split(strings.TrimSpace(out), "\n")
for _, line := range lines {
parts := strings.Split(line, ", ")
if len(parts) >= 6 {
formatted.WriteString(fmt.Sprintf("🎮 GPU %s: %s\n", parts[0], parts[1]))
formatted.WriteString(fmt.Sprintf(" Utilization: %s%%\n", parts[2]))
formatted.WriteString(fmt.Sprintf(" Memory: %s/%s MB\n", parts[3], parts[4]))
formatted.WriteString(fmt.Sprintf(" Temperature: %s°C\n\n", parts[5]))
}
}
c.logger.Info("loaded GPU status", "type", "nvidia")
resultChan <- gpuResult{content: formatted.String(), err: nil}
return
}
cmd = "system_profiler SPDisplaysDataType | grep 'Chipset Model\\|VRAM' | head -2"
out, err = c.server.Exec(cmd)
if err != nil {
c.logger.Warn("GPU info unavailable", "error", err)
resultChan <- gpuResult{content: "⚠️ GPU info unavailable\n\nRun on a system with nvidia-smi or macOS GPU", err: err}
return
}
var formatted strings.Builder
formatted.WriteString("GPU Status (macOS)\n")
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
lines := strings.Split(strings.TrimSpace(out), "\n")
for _, line := range lines {
if strings.Contains(line, "Chipset Model") || strings.Contains(line, "VRAM") {
formatted.WriteString("🎮 " + strings.TrimSpace(line) + "\n")
}
}
formatted.WriteString("\n💡 Note: nvidia-smi not available on macOS\n")
c.logger.Info("loaded GPU status", "type", "macos")
resultChan <- gpuResult{content: formatted.String(), err: nil}
}()
result := <-resultChan
return GpuLoadedMsg(result.content)
}
}
func (c *Controller) loadContainer() tea.Cmd {
return func() tea.Msg {
resultChan := make(chan string, 1)
go func() {
var formatted strings.Builder
formatted.WriteString("Container Status\n")
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
formatted.WriteString("📋 Configuration:\n")
formatted.WriteString(fmt.Sprintf(" Image: %s\n", c.config.PodmanImage))
formatted.WriteString(fmt.Sprintf(" GPU: %v\n", c.config.GPUAccess))
formatted.WriteString(fmt.Sprintf(" Workspace: %s\n", c.config.ContainerWorkspace))
formatted.WriteString(fmt.Sprintf(" Results: %s\n\n", c.config.ContainerResults))
cmd := "podman ps -a --format '{{.Names}}|{{.Status}}|{{.Image}}'"
out, err := c.server.Exec(cmd)
if err == nil && strings.TrimSpace(out) != "" {
formatted.WriteString("🐳 Running Containers (Podman):\n")
lines := strings.Split(strings.TrimSpace(out), "\n")
for _, line := range lines {
parts := strings.Split(line, "|")
if len(parts) >= 3 {
status := "🟢"
if strings.Contains(parts[1], "Exited") {
status = "🔴"
}
formatted.WriteString(fmt.Sprintf(" %s %s\n", status, parts[0]))
formatted.WriteString(fmt.Sprintf(" Status: %s\n", parts[1]))
formatted.WriteString(fmt.Sprintf(" Image: %s\n\n", parts[2]))
}
}
} else {
cmd = "docker ps -a --format '{{.Names}}|{{.Status}}|{{.Image}}'"
out, err = c.server.Exec(cmd)
if err == nil && strings.TrimSpace(out) != "" {
formatted.WriteString("🐳 Running Containers (Docker):\n")
lines := strings.Split(strings.TrimSpace(out), "\n")
for _, line := range lines {
parts := strings.Split(line, "|")
if len(parts) >= 3 {
status := "🟢"
if strings.Contains(parts[1], "Exited") {
status = "🔴"
}
formatted.WriteString(fmt.Sprintf(" %s %s\n", status, parts[0]))
formatted.WriteString(fmt.Sprintf(" Status: %s\n", parts[1]))
formatted.WriteString(fmt.Sprintf(" Image: %s\n\n", parts[2]))
}
}
} else {
formatted.WriteString("⚠️ No containers found\n")
}
}
formatted.WriteString("💻 System Info:\n")
if podmanVersion, err := c.server.Exec("podman --version"); err == nil {
formatted.WriteString(fmt.Sprintf(" Podman: %s\n", strings.TrimSpace(podmanVersion)))
} else if dockerVersion, err := c.server.Exec("docker --version"); err == nil {
formatted.WriteString(fmt.Sprintf(" Docker: %s\n", strings.TrimSpace(dockerVersion)))
} else {
formatted.WriteString(" ⚠️ Container engine not available\n")
}
c.logger.Info("loaded container status")
resultChan <- formatted.String()
}()
return ContainerLoadedMsg(<-resultChan)
}
}
func (c *Controller) loadLog(jobName string) tea.Cmd {
return func() tea.Msg {
resultChan := make(chan string, 1)
go func() {
statusChan := make(chan string, 3)
for _, status := range []model.JobStatus{model.StatusRunning, model.StatusFinished, model.StatusFailed} {
go func(s model.JobStatus) {
logPath := filepath.Join(c.getPathForStatus(s), jobName, "output.log")
if c.server.RemoteExists(logPath) {
content := c.server.TailFile(logPath, 200)
statusChan <- content
} else {
statusChan <- ""
}
}(status)
}
for range 3 {
result := <-statusChan
if result != "" {
var formatted strings.Builder
formatted.WriteString(fmt.Sprintf("📋 Log: %s\n", jobName))
formatted.WriteString(strings.Repeat("═", 60) + "\n\n")
formatted.WriteString(result)
resultChan <- formatted.String()
return
}
}
resultChan <- fmt.Sprintf("⚠️ No log found for %s\n\nJob may not have started yet.", jobName)
}()
return LogLoadedMsg(<-resultChan)
}
}
func (c *Controller) queueJob(jobName string, args string) tea.Cmd {
return func() tea.Msg {
resultChan := make(chan StatusMsg, 1)
go func() {
priority := int64(5)
if strings.Contains(args, "--priority") {
	// Parse the flag wherever it appears in the args string (e.g. after --epochs).
	fields := strings.Fields(args)
	err := fmt.Errorf("missing value after --priority")
	for i, f := range fields {
		if f == "--priority" && i+1 < len(fields) {
			_, err = fmt.Sscanf(fields[i+1], "%d", &priority)
			break
		}
	}
	if err != nil {
		c.logger.Error("invalid priority argument", "args", args, "error", err)
		resultChan <- StatusMsg{
			Text:  fmt.Sprintf("Invalid priority: %v", err),
			Level: "error",
		}
		return
	}
}
task, err := c.taskQueue.EnqueueTask(jobName, args, priority)
if err != nil {
c.logger.Error("failed to queue job", "job_name", jobName, "error", err)
resultChan <- StatusMsg{
Text: fmt.Sprintf("Failed to queue %s: %v", jobName, err),
Level: "error",
}
return
}
c.logger.Info("job queued", "job_name", jobName, "task_id", task.ID[:8], "priority", priority)
resultChan <- StatusMsg{
Text: fmt.Sprintf("✓ Queued: %s (ID: %s, P:%d)", jobName, task.ID[:8], priority),
Level: "success",
}
}()
return <-resultChan
}
}
func (c *Controller) deleteJob(jobName string) tea.Cmd {
return func() tea.Msg {
jobPath := filepath.Join(c.config.PendingPath(), jobName)
if _, err := c.server.Exec(fmt.Sprintf("rm -rf %s", jobPath)); err != nil {
return StatusMsg{Text: fmt.Sprintf("Failed to delete %s: %v", jobName, err), Level: "error"}
}
return StatusMsg{Text: fmt.Sprintf("✓ Deleted: %s", jobName), Level: "success"}
}
}
func (c *Controller) markFailed(jobName string) tea.Cmd {
return func() tea.Msg {
src := filepath.Join(c.config.RunningPath(), jobName)
dst := filepath.Join(c.config.FailedPath(), jobName)
if _, err := c.server.Exec(fmt.Sprintf("mv %s %s", src, dst)); err != nil {
return StatusMsg{Text: fmt.Sprintf("Failed to mark failed: %v", err), Level: "error"}
}
return StatusMsg{Text: fmt.Sprintf("⚠ Marked failed: %s", jobName), Level: "warning"}
}
}
func (c *Controller) cancelTask(taskID string) tea.Cmd {
return func() tea.Msg {
if err := c.taskQueue.CancelTask(taskID); err != nil {
c.logger.Error("failed to cancel task", "task_id", taskID[:8], "error", err)
return StatusMsg{Text: fmt.Sprintf("Cancel failed: %v", err), Level: "error"}
}
c.logger.Info("task cancelled", "task_id", taskID[:8])
return StatusMsg{Text: fmt.Sprintf("✓ Cancelled: %s", taskID[:8]), Level: "success"}
}
}
func (c *Controller) showQueue(m model.State) tea.Cmd {
return func() tea.Msg {
var content strings.Builder
content.WriteString("Task Queue\n")
content.WriteString(strings.Repeat("═", 60) + "\n\n")
if len(m.QueuedTasks) == 0 {
content.WriteString("📭 No tasks in queue\n")
} else {
for i, task := range m.QueuedTasks {
statusIcon := "⏳"
if task.Status == "running" {
statusIcon = "▶"
}
content.WriteString(fmt.Sprintf("%d. %s %s [ID: %s]\n",
i+1, statusIcon, task.JobName, task.ID[:8]))
content.WriteString(fmt.Sprintf(" Priority: %d | Status: %s\n",
task.Priority, task.Status))
if task.Args != "" {
content.WriteString(fmt.Sprintf(" Args: %s\n", task.Args))
}
content.WriteString(fmt.Sprintf(" Created: %s\n",
task.CreatedAt.Format("2006-01-02 15:04:05")))
if task.StartedAt != nil {
duration := time.Since(*task.StartedAt)
content.WriteString(fmt.Sprintf(" Running for: %s\n",
duration.Round(time.Second)))
}
content.WriteString("\n")
}
}
return QueueLoadedMsg(content.String())
}
}
func tickCmd() tea.Cmd {
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
return TickMsg(t)
})
}
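A hedged sketch of the command/message pattern these factories follow: a `tea.Cmd` runs blocking work off the UI goroutine and reports back with a typed message that `Update()` switches on. The command name and the `true` shell no-op are invented for illustration; only `c.server.Exec` and `StatusMsg` come from the code above.

```go
package controller

import (
	tea "github.com/charmbracelet/bubbletea"
)

// pingWorker is a hypothetical command showing the same shape as loadJobs/loadGPU:
// do the blocking call inside the returned func, then hand Update() a message.
func (c *Controller) pingWorker() tea.Cmd {
	return func() tea.Msg {
		if _, err := c.server.Exec("true"); err != nil {
			return StatusMsg{Text: "Worker unreachable: " + err.Error(), Level: "error"}
		}
		return StatusMsg{Text: "✓ Worker reachable", Level: "success"}
	}
}
```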

View file

@ -0,0 +1,302 @@
package controller
import (
"fmt"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/services"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// Controller handles all business logic and state updates
type Controller struct {
config *config.Config
server *services.MLServer
taskQueue *services.TaskQueue
logger *logging.Logger
}
// New creates a new Controller instance
func New(cfg *config.Config, srv *services.MLServer, tq *services.TaskQueue, logger *logging.Logger) *Controller {
return &Controller{
config: cfg,
server: srv,
taskQueue: tq,
logger: logger,
}
}
// Init initializes the TUI and returns initial commands
func (c *Controller) Init() tea.Cmd {
return tea.Batch(
tea.SetWindowTitle("FetchML"),
c.loadAllData(),
tickCmd(),
)
}
// Update handles all messages and updates the state
func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
// Handle input mode (for queuing jobs with args)
if m.InputMode {
switch msg.String() {
case "enter":
args := m.Input.Value()
m.Input.SetValue("")
m.InputMode = false
if job := getSelectedJob(m); job != nil {
cmds = append(cmds, c.queueJob(job.Name, args))
}
return m, tea.Batch(cmds...)
case "esc":
m.InputMode = false
m.Input.SetValue("")
return m, nil
}
var cmd tea.Cmd
m.Input, cmd = m.Input.Update(msg)
return m, cmd
}
// Handle settings-specific keys
if m.ActiveView == model.ViewModeSettings {
switch msg.String() {
case "up", "k":
if m.SettingsIndex > 1 { // Skip index 0 (Status)
m.SettingsIndex--
cmds = append(cmds, c.updateSettingsContent(m))
if m.SettingsIndex == 1 {
m.ApiKeyInput.Focus()
} else {
m.ApiKeyInput.Blur()
}
}
case "down", "j":
if m.SettingsIndex < 2 {
m.SettingsIndex++
cmds = append(cmds, c.updateSettingsContent(m))
if m.SettingsIndex == 1 {
m.ApiKeyInput.Focus()
} else {
m.ApiKeyInput.Blur()
}
}
case "enter":
if cmd := c.handleSettingsAction(&m); cmd != nil {
cmds = append(cmds, cmd)
}
case "esc":
m.ActiveView = model.ViewModeJobs
m.ApiKeyInput.Blur()
}
if m.SettingsIndex == 1 { // API Key input field
var cmd tea.Cmd
m.ApiKeyInput, cmd = m.ApiKeyInput.Update(msg)
cmds = append(cmds, cmd)
// Force update settings view to show typed characters immediately
cmds = append(cmds, c.updateSettingsContent(m))
}
return m, tea.Batch(cmds...)
}
// Handle global keys
switch {
case key.Matches(msg, m.Keys.Quit):
return m, tea.Quit
case key.Matches(msg, m.Keys.Refresh):
m.IsLoading = true
m.Status = "Refreshing all data..."
m.LastRefresh = time.Now()
cmds = append(cmds, c.loadAllData())
case key.Matches(msg, m.Keys.RefreshGPU):
m.Status = "Refreshing GPU status..."
cmds = append(cmds, c.loadGPU())
case key.Matches(msg, m.Keys.Trigger):
if job := getSelectedJob(m); job != nil {
cmds = append(cmds, c.queueJob(job.Name, ""))
}
case key.Matches(msg, m.Keys.TriggerArgs):
if job := getSelectedJob(m); job != nil {
m.InputMode = true
m.Input.Focus()
}
case key.Matches(msg, m.Keys.ViewQueue):
m.ActiveView = model.ViewModeQueue
cmds = append(cmds, c.showQueue(m))
case key.Matches(msg, m.Keys.ViewContainer):
m.ActiveView = model.ViewModeContainer
cmds = append(cmds, c.loadContainer())
case key.Matches(msg, m.Keys.ViewGPU):
m.ActiveView = model.ViewModeGPU
cmds = append(cmds, c.loadGPU())
case key.Matches(msg, m.Keys.ViewJobs):
m.ActiveView = model.ViewModeJobs
case key.Matches(msg, m.Keys.ViewSettings):
m.ActiveView = model.ViewModeSettings
m.SettingsIndex = 1 // Start at Input field, skip Status
m.ApiKeyInput.Focus()
cmds = append(cmds, c.updateSettingsContent(m))
case key.Matches(msg, m.Keys.ViewExperiments):
m.ActiveView = model.ViewModeExperiments
cmds = append(cmds, c.loadExperiments())
case key.Matches(msg, m.Keys.Cancel):
if job := getSelectedJob(m); job != nil && job.TaskID != "" {
cmds = append(cmds, c.cancelTask(job.TaskID))
}
case key.Matches(msg, m.Keys.Delete):
if job := getSelectedJob(m); job != nil && job.Status == model.StatusPending {
cmds = append(cmds, c.deleteJob(job.Name))
}
case key.Matches(msg, m.Keys.MarkFailed):
if job := getSelectedJob(m); job != nil && job.Status == model.StatusRunning {
cmds = append(cmds, c.markFailed(job.Name))
}
case key.Matches(msg, m.Keys.Help):
m.ShowHelp = !m.ShowHelp
}
case tea.WindowSizeMsg:
m.Width = msg.Width
m.Height = msg.Height
// Update component sizes
h, v := 4, 2 // docStyle.GetFrameSize() approx
listHeight := msg.Height - v - 8
m.JobList.SetSize(msg.Width/3-h, listHeight)
panelWidth := msg.Width*2/3 - h - 2
panelHeight := (listHeight - 6) / 3
m.GpuView.Width = panelWidth
m.GpuView.Height = panelHeight
m.ContainerView.Width = panelWidth
m.ContainerView.Height = panelHeight
m.QueueView.Width = panelWidth
m.QueueView.Height = listHeight - 4
m.SettingsView.Width = panelWidth
m.SettingsView.Height = listHeight - 4
m.ExperimentsView.Width = panelWidth
m.ExperimentsView.Height = listHeight - 4
case JobsLoadedMsg:
m.Jobs = []model.Job(msg)
calculateJobStats(&m)
items := make([]list.Item, len(m.Jobs))
for i, job := range m.Jobs {
items[i] = job
}
cmds = append(cmds, m.JobList.SetItems(items))
m.Status = formatStatus(m)
m.IsLoading = false
case TasksLoadedMsg:
m.QueuedTasks = []*model.Task(msg)
m.Status = formatStatus(m)
case GpuLoadedMsg:
m.GpuView.SetContent(string(msg))
m.GpuView.GotoTop()
case ContainerLoadedMsg:
m.ContainerView.SetContent(string(msg))
m.ContainerView.GotoTop()
case QueueLoadedMsg:
m.QueueView.SetContent(string(msg))
m.QueueView.GotoTop()
case SettingsContentMsg:
m.SettingsView.SetContent(string(msg))
case ExperimentsLoadedMsg:
m.ExperimentsView.SetContent(string(msg))
m.ExperimentsView.GotoTop()
case SettingsUpdateMsg:
// Settings content was updated, just trigger a re-render
case StatusMsg:
if msg.Level == "error" {
m.ErrorMsg = msg.Text
m.Status = "Error occurred - check status"
} else {
m.ErrorMsg = ""
m.Status = msg.Text
}
case TickMsg:
var spinCmd tea.Cmd
m.Spinner, spinCmd = m.Spinner.Update(msg)
cmds = append(cmds, spinCmd)
// Auto-refresh every 10 seconds
if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
m.LastRefresh = time.Now()
cmds = append(cmds, c.loadAllData())
}
cmds = append(cmds, tickCmd())
default:
var spinCmd tea.Cmd
m.Spinner, spinCmd = m.Spinner.Update(msg)
cmds = append(cmds, spinCmd)
}
// Update all bubble components
var cmd tea.Cmd
m.JobList, cmd = m.JobList.Update(msg)
cmds = append(cmds, cmd)
m.GpuView, cmd = m.GpuView.Update(msg)
cmds = append(cmds, cmd)
m.ContainerView, cmd = m.ContainerView.Update(msg)
cmds = append(cmds, cmd)
m.QueueView, cmd = m.QueueView.Update(msg)
cmds = append(cmds, cmd)
m.ExperimentsView, cmd = m.ExperimentsView.Update(msg)
cmds = append(cmds, cmd)
return m, tea.Batch(cmds...)
}
// ExperimentsLoadedMsg is sent when experiments are loaded
type ExperimentsLoadedMsg string
func (c *Controller) loadExperiments() tea.Cmd {
return func() tea.Msg {
commitIDs, err := c.taskQueue.ListExperiments()
if err != nil {
return StatusMsg{Level: "error", Text: fmt.Sprintf("Failed to list experiments: %v", err)}
}
if len(commitIDs) == 0 {
return ExperimentsLoadedMsg("Experiments:\n\nNo experiments found.")
}
var output string
output += "Experiments:\n\n"
for _, commitID := range commitIDs {
details, err := c.taskQueue.GetExperimentDetails(commitID)
if err != nil {
output += fmt.Sprintf("Error loading %s: %v\n\n", commitID, err)
continue
}
output += details + "\n----------------------------------------\n\n"
}
return ExperimentsLoadedMsg(output)
}
}

View file

@ -0,0 +1,69 @@
package controller
import (
"fmt"
"strings"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
// Helper functions
func (c *Controller) getPathForStatus(status model.JobStatus) string {
switch status {
case model.StatusPending:
return c.config.PendingPath()
case model.StatusRunning:
return c.config.RunningPath()
case model.StatusFinished:
return c.config.FinishedPath()
case model.StatusFailed:
return c.config.FailedPath()
}
return ""
}
func getSelectedJob(m model.State) *model.Job {
if item := m.JobList.SelectedItem(); item != nil {
if job, ok := item.(model.Job); ok {
return &job
}
}
return nil
}
func calculateJobStats(m *model.State) {
m.JobStats = make(map[model.JobStatus]int)
for _, job := range m.Jobs {
m.JobStats[job.Status]++
}
}
func formatStatus(m model.State) string {
var parts []string
if len(m.Jobs) > 0 {
stats := []string{}
if count := m.JobStats[model.StatusPending]; count > 0 {
stats = append(stats, fmt.Sprintf("⏸ %d", count))
}
if count := m.JobStats[model.StatusRunning]; count > 0 {
stats = append(stats, fmt.Sprintf("▶ %d", count))
}
if count := m.JobStats[model.StatusFinished]; count > 0 {
stats = append(stats, fmt.Sprintf("✓ %d", count))
}
if count := m.JobStats[model.StatusFailed]; count > 0 {
stats = append(stats, fmt.Sprintf("✗ %d", count))
}
parts = append(parts, strings.Join(stats, " | "))
}
if len(m.QueuedTasks) > 0 {
parts = append(parts, fmt.Sprintf("Queue: %d", len(m.QueuedTasks)))
}
parts = append(parts, fmt.Sprintf("Updated: %s", m.LastRefresh.Format("15:04:05")))
return strings.Join(parts, " • ")
}
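For reference, a sketch (in the same package, since the helpers are unexported) of the status line these helpers produce for a small snapshot of state; the job names and timestamp are made up.

```go
package controller

import (
	"fmt"
	"time"

	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)

func exampleStatusLine() {
	m := model.State{
		Jobs: []model.Job{
			{Name: "train-a", Status: model.StatusPending},
			{Name: "train-b", Status: model.StatusPending},
			{Name: "train-c", Status: model.StatusRunning},
		},
		QueuedTasks: []*model.Task{{}, {}, {}},
		LastRefresh: time.Date(2025, 12, 4, 15, 4, 5, 0, time.UTC),
	}
	calculateJobStats(&m)
	fmt.Println(formatStatus(m))
	// Prints: ⏸ 2 | ▶ 1 • Queue: 3 • Updated: 15:04:05
}
```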

View file

@ -0,0 +1,126 @@
package controller
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
// Settings-related command factories and handlers
func (c *Controller) updateSettingsContent(m model.State) tea.Cmd {
var content strings.Builder
// API Key Status section
statusStyle := lipgloss.NewStyle().
Border(lipgloss.NormalBorder()).
BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}). // borderfg
Padding(0, 1).
Width(m.SettingsView.Width - 4)
if m.SettingsIndex == 0 {
statusStyle = statusStyle.
BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"}) // activeBorderfg
}
statusContent := fmt.Sprintf("%s API Key Status\n%s",
getSettingsIndicator(m, 0),
getAPIKeyStatus(m))
content.WriteString(statusStyle.Render(statusContent))
content.WriteString("\n")
// API Key Input section
inputStyle := lipgloss.NewStyle().
Border(lipgloss.NormalBorder()).
BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}).
Padding(0, 1).
Width(m.SettingsView.Width - 4)
if m.SettingsIndex == 1 {
inputStyle = inputStyle.
BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"})
}
inputContent := fmt.Sprintf("%s Enter New API Key\n%s",
getSettingsIndicator(m, 1),
m.ApiKeyInput.View())
content.WriteString(inputStyle.Render(inputContent))
content.WriteString("\n")
// Save Configuration section
saveStyle := lipgloss.NewStyle().
Border(lipgloss.NormalBorder()).
BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}).
Padding(0, 1).
Width(m.SettingsView.Width - 4)
if m.SettingsIndex == 2 {
saveStyle = saveStyle.
BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"})
}
saveContent := fmt.Sprintf("%s Save Configuration\n[Enter]",
getSettingsIndicator(m, 2))
content.WriteString(saveStyle.Render(saveContent))
content.WriteString("\n")
// Current API Key display
keyStyle := lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#666", Dark: "#999"}).
Italic(true)
keyContent := fmt.Sprintf("Current API Key: %s", maskAPIKey(m.ApiKey))
content.WriteString(keyStyle.Render(keyContent))
return func() tea.Msg { return SettingsContentMsg(content.String()) }
}
func (c *Controller) handleSettingsAction(m *model.State) tea.Cmd {
switch m.SettingsIndex {
case 0: // API Key Status - do nothing
return nil
case 1: // Enter New API Key - do nothing, Enter key disabled
return nil
case 2: // Save Configuration
if m.ApiKeyInput.Value() != "" {
m.ApiKey = m.ApiKeyInput.Value()
m.ApiKeyInput.SetValue("")
m.Status = "Configuration saved (in-memory only)"
return c.updateSettingsContent(*m)
} else if m.ApiKey != "" {
m.Status = "Configuration saved (in-memory only)"
} else {
m.ErrorMsg = "No API key to save"
}
}
return nil
}
// Helper functions for settings
func getSettingsIndicator(m model.State, index int) string {
if index == m.SettingsIndex {
return "▶"
}
return " "
}
func getAPIKeyStatus(m model.State) string {
if m.ApiKey != "" {
return "✓ API Key is set\n" + maskAPIKey(m.ApiKey)
}
return "⚠ No API Key configured"
}
func maskAPIKey(key string) string {
if key == "" {
return "(not set)"
}
if len(key) <= 8 {
return "****"
}
return key[:4] + "****" + key[len(key)-4:]
}
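A quick illustration of `maskAPIKey`'s truncation rules (same package, since the helper is unexported); the key values are invented.

```go
package controller

import "fmt"

func exampleMaskAPIKey() {
	fmt.Println(maskAPIKey(""))                  // (not set)
	fmt.Println(maskAPIKey("short"))             // **** (8 characters or fewer)
	fmt.Println(maskAPIKey("sk_1234567890abcd")) // sk_1****abcd
}
```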

View file

@ -0,0 +1,206 @@
package model
import (
"fmt"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
"github.com/charmbracelet/lipgloss"
)
type ViewMode int
const (
ViewModeJobs ViewMode = iota
ViewModeGPU
ViewModeQueue
ViewModeContainer
ViewModeSettings
ViewModeDatasets
ViewModeExperiments
)
type JobStatus string
const (
StatusPending JobStatus = "pending"
StatusQueued JobStatus = "queued"
StatusRunning JobStatus = "running"
StatusFinished JobStatus = "finished"
StatusFailed JobStatus = "failed"
)
type Job struct {
Name string
Status JobStatus
TaskID string
Priority int64
}
func (j Job) Title() string { return j.Name }
func (j Job) Description() string {
icon := map[JobStatus]string{
StatusPending: "⏸",
StatusQueued: "⏳",
StatusRunning: "▶",
StatusFinished: "✓",
StatusFailed: "✗",
}[j.Status]
pri := ""
if j.Priority > 0 {
pri = fmt.Sprintf(" [P%d]", j.Priority)
}
return fmt.Sprintf("%s %s%s", icon, j.Status, pri)
}
func (j Job) FilterValue() string { return j.Name }
type Task struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"`
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
Error string `json:"error,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
type DatasetInfo struct {
Name string `json:"name"`
SizeBytes int64 `json:"size_bytes"`
Location string `json:"location"`
LastAccess time.Time `json:"last_access"`
}
// State holds the application state
type State struct {
Jobs []Job
QueuedTasks []*Task
Datasets []DatasetInfo
JobList list.Model
GpuView viewport.Model
ContainerView viewport.Model
QueueView viewport.Model
SettingsView viewport.Model
DatasetView viewport.Model
ExperimentsView viewport.Model
Input textinput.Model
ApiKeyInput textinput.Model
Status string
ErrorMsg string
InputMode bool
Width int
Height int
ShowHelp bool
Spinner spinner.Model
ActiveView ViewMode
LastRefresh time.Time
IsLoading bool
JobStats map[JobStatus]int
ApiKey string
SettingsIndex int
Keys KeyMap
}
type KeyMap struct {
Refresh key.Binding
Trigger key.Binding
TriggerArgs key.Binding
ViewQueue key.Binding
ViewContainer key.Binding
ViewGPU key.Binding
ViewJobs key.Binding
ViewDatasets key.Binding
ViewExperiments key.Binding
ViewSettings key.Binding
Cancel key.Binding
Delete key.Binding
MarkFailed key.Binding
RefreshGPU key.Binding
Help key.Binding
Quit key.Binding
}
var Keys = KeyMap{
Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
ViewQueue: key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")),
ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")),
Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
Quit: key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")),
}
func InitialState(apiKey string) State {
items := []list.Item{}
delegate := list.NewDefaultDelegate()
delegate.Styles.SelectedTitle = delegate.Styles.SelectedTitle.
Foreground(lipgloss.Color("170")).
Bold(true)
delegate.Styles.SelectedDesc = delegate.Styles.SelectedDesc.
Foreground(lipgloss.Color("246"))
jobList := list.New(items, delegate, 0, 0)
jobList.Title = "ML Jobs & Queue"
jobList.SetShowStatusBar(true)
jobList.SetFilteringEnabled(true)
jobList.SetShowHelp(false)
// Style initialization lives here because it is part of the model state setup.
jobList.Styles.Title = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#2980b9", Dark: "#7aa2f7"}).
Padding(0, 0, 1, 0)
input := textinput.New()
input.Placeholder = "Args: --epochs 100 --lr 0.001 --priority 5"
input.Width = 60
input.CharLimit = 200
apiKeyInput := textinput.New()
apiKeyInput.Placeholder = "Enter API key..."
apiKeyInput.Width = 40
apiKeyInput.CharLimit = 200
s := spinner.New()
s.Spinner = spinner.Dot
s.Style = lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#2980b9", Dark: "#7aa2f7"})
return State{
JobList: jobList,
GpuView: viewport.New(0, 0),
ContainerView: viewport.New(0, 0),
QueueView: viewport.New(0, 0),
SettingsView: viewport.New(0, 0),
DatasetView: viewport.New(0, 0),
ExperimentsView: viewport.New(0, 0),
Input: input,
ApiKeyInput: apiKeyInput,
Status: "Connected",
InputMode: false,
ShowHelp: false,
Spinner: s,
ActiveView: ViewModeJobs,
LastRefresh: time.Now(),
IsLoading: false,
JobStats: make(map[JobStatus]int),
ApiKey: apiKey,
SettingsIndex: 0,
Keys: Keys,
}
}
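As a small illustration of how `Job` satisfies the bubbles `list.Item` interface, a made-up job renders like this:

```go
package model

import "fmt"

func exampleJobItem() {
	j := Job{Name: "resnet50-train", Status: StatusRunning, TaskID: "abc123", Priority: 3}
	fmt.Println(j.Title())       // resnet50-train
	fmt.Println(j.Description()) // ▶ running [P3]
	fmt.Println(j.FilterValue()) // resnet50-train
}
```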

View file

@ -0,0 +1,237 @@
package services
import (
"context"
"fmt"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility
type TaskQueue struct {
internal *queue.TaskQueue
expManager *experiment.Manager
ctx context.Context
}
func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
// Create internal queue config
queueCfg := queue.Config{
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
}
internalQueue, err := queue.NewTaskQueue(queueCfg)
if err != nil {
return nil, fmt.Errorf("failed to create task queue: %w", err)
}
// Initialize experiment manager
// TODO: Get base path from config
expManager := experiment.NewManager("./experiments")
return &TaskQueue{
internal: internalQueue,
expManager: expManager,
ctx: context.Background(),
}, nil
}
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.Task, error) {
// Create internal task
internalTask := &queue.Task{
JobName: jobName,
Args: args,
Priority: priority,
}
// Use internal queue to enqueue
err := tq.internal.AddTask(internalTask)
if err != nil {
return nil, err
}
// Convert to TUI model
return &model.Task{
ID: internalTask.ID,
JobName: internalTask.JobName,
Args: internalTask.Args,
Status: "queued",
Priority: int64(internalTask.Priority),
CreatedAt: internalTask.CreatedAt,
Metadata: internalTask.Metadata,
}, nil
}
func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
internalTask, err := tq.internal.GetNextTask()
if err != nil {
return nil, err
}
if internalTask == nil {
return nil, nil
}
// Convert to TUI model
return &model.Task{
ID: internalTask.ID,
JobName: internalTask.JobName,
Args: internalTask.Args,
Status: internalTask.Status,
Priority: internalTask.Priority,
CreatedAt: internalTask.CreatedAt,
Metadata: internalTask.Metadata,
}, nil
}
func (tq *TaskQueue) GetTask(taskID string) (*model.Task, error) {
internalTask, err := tq.internal.GetTask(taskID)
if err != nil {
return nil, err
}
// Convert to TUI model
return &model.Task{
ID: internalTask.ID,
JobName: internalTask.JobName,
Args: internalTask.Args,
Status: internalTask.Status,
Priority: internalTask.Priority,
CreatedAt: internalTask.CreatedAt,
Metadata: internalTask.Metadata,
}, nil
}
func (tq *TaskQueue) UpdateTask(task *model.Task) error {
// Convert to internal task
internalTask := &queue.Task{
ID: task.ID,
JobName: task.JobName,
Args: task.Args,
Status: task.Status,
Priority: task.Priority,
CreatedAt: task.CreatedAt,
Metadata: task.Metadata,
}
return tq.internal.UpdateTask(internalTask)
}
func (tq *TaskQueue) GetQueuedTasks() ([]*model.Task, error) {
internalTasks, err := tq.internal.GetAllTasks()
if err != nil {
return nil, err
}
// Convert to TUI models
tasks := make([]*model.Task, len(internalTasks))
for i, task := range internalTasks {
tasks[i] = &model.Task{
ID: task.ID,
JobName: task.JobName,
Args: task.Args,
Status: task.Status,
Priority: task.Priority,
CreatedAt: task.CreatedAt,
Metadata: task.Metadata,
}
}
return tasks, nil
}
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
// The internal queue has no direct equivalent; derive a basic status from the task record.
task, err := tq.internal.GetTaskByName(jobName)
if err != nil {
return nil, err
}
if task == nil {
return map[string]string{"status": "not_found"}, nil
}
return map[string]string{
"status": task.Status,
"task_id": task.ID,
}, nil
}
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
return tq.internal.RecordMetric(jobName, metric, value)
}
func (tq *TaskQueue) GetMetrics(jobName string) (map[string]string, error) {
// The internal queue has no equivalent yet; return an empty map for now.
return map[string]string{}, nil
}
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
// The internal queue has no equivalent yet; return an empty list for now.
return []model.DatasetInfo{}, nil
}
func (tq *TaskQueue) CancelTask(taskID string) error {
return tq.internal.CancelTask(taskID)
}
func (tq *TaskQueue) ListExperiments() ([]string, error) {
return tq.expManager.ListExperiments()
}
func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
meta, err := tq.expManager.ReadMetadata(commitID)
if err != nil {
return "", err
}
metrics, err := tq.expManager.GetMetrics(commitID)
if err != nil {
return "", err
}
output := fmt.Sprintf("Experiment: %s\n", meta.JobName)
output += fmt.Sprintf("Commit ID: %s\n", meta.CommitID)
output += fmt.Sprintf("User: %s\n", meta.User)
output += fmt.Sprintf("Timestamp: %d\n\n", meta.Timestamp)
output += "Metrics:\n"
if len(metrics) == 0 {
output += " No metrics logged.\n"
} else {
for _, m := range metrics {
output += fmt.Sprintf(" %s: %.4f (Step: %d)\n", m.Name, m.Value, m.Step)
}
}
return output, nil
}
func (tq *TaskQueue) Close() error {
return tq.internal.Close()
}
// MLServer wraps network.SSHClient for backward compatibility
type MLServer struct {
*network.SSHClient
addr string
}
func NewMLServer(cfg *config.Config) (*MLServer, error) {
// Local mode: skip SSH entirely
if cfg.Host == "" {
client, _ := network.NewSSHClient("", "", "", 0, "")
return &MLServer{SSHClient: client, addr: "localhost"}, nil
}
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
if err != nil {
return nil, err
}
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
return &MLServer{SSHClient: client, addr: addr}, nil
}
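A hedged sketch of wiring the services layer outside the TUI using the wrappers above; the config path is hypothetical and error handling is illustrative only.

```go
package main

import (
	"fmt"
	"log"

	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/services"
)

func main() {
	cfg, err := config.LoadConfig("testdata/tui.toml") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}

	tq, err := services.NewTaskQueue(cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer tq.Close()

	// Queue a job with priority 5 and no extra arguments.
	task, err := tq.EnqueueTask("resnet50-train", "", 5)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("queued %s as task %s\n", task.JobName, task.ID)
}
```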

View file

@ -0,0 +1,255 @@
package view
import (
"strings"
"github.com/charmbracelet/lipgloss"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
const (
headerfgLight = "#d35400"
headerfgDark = "#ff9e64"
activeBorderfgLight = "#3498db"
activeBorderfgDark = "#7aa2f7"
errorbgLight = "#fee"
errorbgDark = "#633"
errorfgLight = "#a00"
errorfgDark = "#faa"
titlefgLight = "#d35400"
titlefgDark = "#ff9e64"
statusfgLight = "#2e3440"
statusfgDark = "#d8dee9"
statusbgLight = "#e5e9f0"
statusbgDark = "#2e3440"
borderfgLight = "#d8dee9"
borderfgDark = "#4c566a"
helpfgLight = "#4c566a"
helpfgDark = "#88c0d0"
)
var (
docStyle = lipgloss.NewStyle().Margin(1, 2)
activeBorderStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.AdaptiveColor{Light: activeBorderfgLight, Dark: activeBorderfgDark}).
Padding(1, 2)
errorStyle = lipgloss.NewStyle().
Background(lipgloss.AdaptiveColor{Light: errorbgLight, Dark: errorbgDark}).
Foreground(lipgloss.AdaptiveColor{Light: errorfgLight, Dark: errorfgDark}).
Padding(0, 1).
Bold(true)
titleStyle = (lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: titlefgLight, Dark: titlefgDark}).
MarginBottom(1))
statusStyle = (lipgloss.NewStyle().
Background(lipgloss.AdaptiveColor{Light: statusbgLight, Dark: statusbgDark}).
Foreground(lipgloss.AdaptiveColor{Light: statusfgLight, Dark: statusfgDark}).
Padding(0, 1))
borderStyle = (lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.AdaptiveColor{Light: borderfgLight, Dark: borderfgDark}).
Padding(0, 1))
helpStyle = (lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: helpfgLight, Dark: helpfgDark}))
)
func Render(m model.State) string {
if m.Width == 0 {
return "Loading..."
}
// Title
title := titleStyle.Width(m.Width - 4).Render("🤖 ML Experiment Manager")
// Left panel - Job list (30% width)
leftWidth := int(float64(m.Width) * 0.3)
leftPanel := getJobListPanel(m, leftWidth)
// Right panel - Dynamic content (70% width)
rightWidth := m.Width - leftWidth - 4
rightPanel := getRightPanel(m, rightWidth)
// Main content
main := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel)
// Status bar
statusBar := getStatusBar(m)
// Error bar (if present)
var errorBar string
if m.ErrorMsg != "" {
errorBar = errorStyle.Width(m.Width - 4).Render("⚠ Error: " + m.ErrorMsg)
}
// Help view (toggleable)
var helpView string
if m.ShowHelp {
helpView = helpStyle.Width(m.Width-4).
Padding(1, 2).
Render(helpText(m))
}
// Quick help bar
quickHelp := helpStyle.Width(m.Width - 4).Render(getQuickHelp(m))
// Compose final layout
parts := []string{title, main, statusBar}
if errorBar != "" {
parts = append(parts, errorBar)
}
if helpView != "" {
parts = append(parts, helpView)
}
parts = append(parts, quickHelp)
return docStyle.Render(lipgloss.JoinVertical(lipgloss.Left, parts...))
}
func getJobListPanel(m model.State, width int) string {
style := borderStyle
if m.ActiveView == model.ViewModeJobs {
style = activeBorderStyle
}
// Size the job list before rendering. The state is passed by value, so calling
// SetSize on this local copy only affects the rendered output; the controller
// remains responsible for persisting sizes on WindowSizeMsg.
h, v := style.GetFrameSize()
m.JobList.SetSize(width-h, m.Height-v-4) // Adjust height for title/help/status
// Custom empty state
if len(m.JobList.Items()) == 0 {
return style.Width(width - h).Render(
lipgloss.JoinVertical(lipgloss.Left,
m.JobList.Styles.Title.Render(m.JobList.Title),
"\n No jobs found.",
" Press 't' to queue.",
),
)
}
return style.Width(width - h).Render(m.JobList.View())
}
func getRightPanel(m model.State, width int) string {
var content string
var viewTitle string
style := borderStyle
switch m.ActiveView {
case model.ViewModeGPU:
style = activeBorderStyle
viewTitle = "🎮 GPU Status"
content = m.GpuView.View()
case model.ViewModeContainer:
style = activeBorderStyle
viewTitle = "🐳 Container Status"
content = m.ContainerView.View()
case model.ViewModeQueue:
style = activeBorderStyle
viewTitle = "⏳ Task Queue"
content = m.QueueView.View()
case model.ViewModeSettings:
style = activeBorderStyle
viewTitle = "⚙️ Settings"
content = m.SettingsView.View()
case model.ViewModeExperiments:
style = activeBorderStyle
viewTitle = "🧪 Experiments"
content = m.ExperimentsView.View()
default:
viewTitle = "📊 System Overview"
content = getOverviewPanel(m)
}
header := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: headerfgLight, Dark: headerfgDark}).
Render(viewTitle)
h, _ := style.GetFrameSize()
return style.Width(width - h).Render(header + "\n\n" + content)
}
func getOverviewPanel(m model.State) string {
var sections []string
sections = append(sections, "🎮 GPU\n"+strings.Repeat("─", 40))
sections = append(sections, m.GpuView.View())
sections = append(sections, "\n🐳 Containers\n"+strings.Repeat("─", 40))
sections = append(sections, m.ContainerView.View())
return strings.Join(sections, "\n")
}
func getStatusBar(m model.State) string {
spinnerStr := m.Spinner.View()
if !m.IsLoading {
if m.ShowHelp {
spinnerStr = "?"
} else {
spinnerStr = "●"
}
}
statusText := m.Status
if m.ShowHelp {
statusText = "Press 'h' to hide help"
}
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText)
}
func helpText(m model.State) string {
if m.ActiveView == model.ViewModeSettings {
return `
Settings Shortcuts
Navigation
  j/k, ↑/↓ : Move selection
Enter : Edit / Save
Esc : Exit Settings
General
h or ? : Toggle this help q/Ctrl+C : Quit
`
}
return `
Keyboard Shortcuts
Navigation
  j/k, ↑/↓ : Move selection      / : Filter jobs
1 : Job list view 2 : Datasets view
3 : Experiments view v : Queue view
g : GPU view o : Container view
s : Settings view
Actions
t : Queue job a : Queue w/ args
c : Cancel task d : Delete pending
f : Mark as failed r : Refresh all
G : Refresh GPU only
General
h or ? : Toggle this help q/Ctrl+C : Quit
`
}
func getQuickHelp(m model.State) string {
if m.ActiveView == model.ViewModeSettings {
return " ↑/↓:move enter:select esc:exit settings q:quit"
}
return " h:help 1:jobs 2:datasets 3:experiments v:queue g:gpu o:containers s:settings t:queue r:refresh q:quit"
}

204
cmd/tui/main.go Normal file
View file

@ -0,0 +1,204 @@
// Package main implements the ML TUI
package main
import (
"log"
"os"
"os/signal"
"syscall"
tea "github.com/charmbracelet/bubbletea"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/controller"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/services"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/view"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/logging"
)
type AppModel struct {
state model.State
controller *controller.Controller
}
func (m AppModel) Init() tea.Cmd {
return m.controller.Init()
}
func (m AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
newState, cmd := m.controller.Update(msg, m.state)
m.state = newState
return m, cmd
}
func (m AppModel) View() string {
return view.Render(m.state)
}
func main() {
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
if err := auth.ValidateAuthFlags(authFlags); err != nil {
log.Fatalf("Authentication flag error: %v", err)
}
// Get API key from various sources
apiKey := auth.GetAPIKeyFromSources(authFlags)
var (
cfg *config.Config
cliConfig *config.CLIConfig
cliConfPath string
)
configFlag := authFlags.ConfigFile
// Only support TOML configuration
var err error
cliConfig, cliConfPath, err = config.LoadCLIConfig(configFlag)
if err != nil {
if configFlag != "" {
log.Fatalf("Failed to load TOML config %s: %v", configFlag, err)
} else {
// Provide helpful error message for data scientists
log.Printf("=== Fetch ML TUI - Configuration Required ===")
log.Printf("")
log.Printf("Error: %v", err)
log.Printf("")
log.Printf("To get started with the TUI, you need to initialize your configuration:")
log.Printf("")
log.Printf("Option 1: Using the Zig CLI (Recommended)")
log.Printf(" 1. Build the CLI: cd cli && make build")
log.Printf(" 2. Initialize config: ./cli/zig-out/bin/ml init")
log.Printf(" 3. Edit ~/.ml/config.toml with your settings")
log.Printf(" 4. Run TUI: ./bin/tui")
log.Printf("")
log.Printf("Option 2: Manual Configuration")
log.Printf(" 1. Create directory: mkdir -p ~/.ml")
log.Printf(" 2. Create config: touch ~/.ml/config.toml")
log.Printf(" 3. Add your settings to the file")
log.Printf(" 4. Run TUI: ./bin/tui")
log.Printf("")
log.Printf("Example ~/.ml/config.toml:")
log.Printf(" worker_host = \"localhost\"")
log.Printf(" worker_user = \"your_username\"")
log.Printf(" worker_base = \"~/ml_jobs\"")
log.Printf(" worker_port = 22")
log.Printf(" api_key = \"your_api_key_here\"")
log.Printf("")
log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs")
os.Exit(1)
}
}
cfg = cliConfig.ToTUIConfig()
log.Printf("Loaded TOML configuration from %s", cliConfPath)
// Validate authentication configuration
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
log.Fatalf("Invalid authentication configuration: %v", err)
}
if err := cfg.Validate(); err != nil {
log.Fatalf("Invalid configuration: %v", err)
}
// Test authentication if enabled
if cfg.Auth.Enabled {
// Use API key from CLI config if available, otherwise use from flags
var effectiveAPIKey string
if cliConfig != nil && cliConfig.APIKey != "" {
effectiveAPIKey = cliConfig.APIKey
} else if apiKey != "" {
effectiveAPIKey = apiKey
} else {
log.Fatal("Authentication required but no API key provided")
}
if _, err := cfg.Auth.ValidateAPIKey(effectiveAPIKey); err != nil {
log.Fatalf("Authentication failed: %v", err)
}
}
srv, err := services.NewMLServer(cfg)
if err != nil {
log.Fatalf("Failed to connect to server: %v", err)
}
defer func() {
if err := srv.Close(); err != nil {
log.Printf("server close error: %v", err)
}
}()
tq, err := services.NewTaskQueue(cfg)
if err != nil {
log.Fatalf("Failed to connect to Redis: %v", err)
}
defer func() {
if err := tq.Close(); err != nil {
log.Printf("task queue close error: %v", err)
}
}()
// Initialize the logger here and pass it to the controller.
// Default to error level; logging.NewLogger takes (slog.Level, bool).
logger := logging.NewLogger(slog.LevelError, false)
// Initialize State and Controller
var effectiveAPIKey string
if cliConfig != nil && cliConfig.APIKey != "" {
effectiveAPIKey = cliConfig.APIKey
} else {
effectiveAPIKey = apiKey
}
initialState := model.InitialState(effectiveAPIKey)
ctrl := controller.New(cfg, srv, tq, logger)
appModel := AppModel{
state: initialState,
controller: ctrl,
}
// Run TUI app
p := tea.NewProgram(appModel, tea.WithAltScreen(), tea.WithMouseAllMotion())
// p.Run() normally restores the terminal itself; we also call p.ReleaseTerminal()
// explicitly below so the alt screen is exited on error paths.
// Set up signal handling for graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Run program and handle signals
go func() {
<-sigChan
p.Quit()
}()
if _, err := p.Run(); err != nil {
// Attempt to restore terminal before logging fatal error
p.ReleaseTerminal()
log.Fatalf("Error running TUI: %v", err)
}
// Explicitly restore terminal after program exits
p.ReleaseTerminal()
}

175
cmd/user_manager/main.go Normal file
View file

@ -0,0 +1,175 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"strings"
"github.com/jfraeys/fetch_ml/internal/auth"
"gopkg.in/yaml.v3"
)
type ConfigWithAuth struct {
Auth auth.AuthConfig `yaml:"auth"`
}
func main() {
var (
configFile = flag.String("config", "", "Configuration file path")
command = flag.String("cmd", "", "Command: generate-key, list-users, hash-key")
username = flag.String("username", "", "Username for generate-key")
role = flag.String("role", "", "Role for generate-key")
admin = flag.Bool("admin", false, "Admin flag for generate-key")
apiKey = flag.String("key", "", "API key to hash")
)
flag.Parse()
if *configFile == "" || *command == "" {
fmt.Println("Usage: user_manager --config <config.yaml> --cmd <command> [options]")
fmt.Println("Commands: generate-key, list-users, hash-key")
os.Exit(1)
}
switch *command {
case "generate-key":
if *username == "" {
log.Fatal("Usage: --cmd generate-key --username <name> [--admin] [--role <role>]")
}
// Load config
data, err := os.ReadFile(*configFile)
if err != nil {
log.Fatalf("Failed to read config: %v", err)
}
var config ConfigWithAuth
if err := yaml.Unmarshal(data, &config); err != nil {
log.Fatalf("Failed to parse config: %v", err)
}
// Generate API key
apiKey := auth.GenerateAPIKey()
// Setup user
if config.Auth.APIKeys == nil {
config.Auth.APIKeys = make(map[auth.Username]auth.APIKeyEntry)
}
adminStatus := *admin
roles := []string{"viewer"}
permissions := make(map[string]bool)
if !adminStatus && *role == "" {
fmt.Printf("Make user '%s' an admin? (y/N): ", *username)
var response string
fmt.Scanln(&response)
adminStatus = strings.ToLower(strings.TrimSpace(response)) == "y"
}
if adminStatus {
roles = []string{"admin"}
permissions["*"] = true
} else if *role != "" {
roles = []string{*role}
rolePerms := getRolePermissions(*role)
for perm, value := range rolePerms {
permissions[perm] = value
}
}
// Save user
config.Auth.APIKeys[auth.Username(*username)] = auth.APIKeyEntry{
Hash: auth.APIKeyHash(auth.HashAPIKey(apiKey)),
Admin: adminStatus,
Roles: roles,
Permissions: permissions,
}
data, err = yaml.Marshal(config)
if err != nil {
log.Fatalf("Failed to marshal config: %v", err)
}
if err := os.WriteFile(*configFile, data, 0600); err != nil {
log.Fatalf("Failed to write config: %v", err)
}
fmt.Printf("Generated API key for user '%s':\nKey: %s\n", *username, apiKey)
case "list-users":
data, err := os.ReadFile(*configFile)
if err != nil {
log.Fatalf("Failed to read config: %v", err)
}
var config ConfigWithAuth
if err := yaml.Unmarshal(data, &config); err != nil {
log.Fatalf("Failed to parse config: %v", err)
}
fmt.Println("Configured Users:")
fmt.Println("=================")
for username, entry := range config.Auth.APIKeys {
fmt.Printf("User: %s\n", string(username))
fmt.Printf(" Admin: %v\n", entry.Admin)
if len(entry.Roles) > 0 {
fmt.Printf(" Roles: %v\n", entry.Roles)
}
if len(entry.Permissions) > 0 {
fmt.Printf(" Permissions: %d\n", len(entry.Permissions))
}
fmt.Printf(" Key Hash: %s...\n\n", string(entry.Hash)[:8])
}
case "hash-key":
if *apiKey == "" {
log.Fatal("Usage: --cmd hash-key --key <api-key>")
}
hash := auth.HashAPIKey(*apiKey)
fmt.Printf("Hash: %s\n", hash)
default:
log.Fatalf("Unknown command: %s", *command)
}
}
// getRolePermissions returns permissions for a role
func getRolePermissions(role string) map[string]bool {
rolePermissions := map[string]map[string]bool{
"admin": {
"*": true,
},
"data_scientist": {
"jobs:create": true,
"jobs:read": true,
"jobs:update": true,
"data:read": true,
"models:read": true,
},
"data_engineer": {
"data:create": true,
"data:read": true,
"data:update": true,
"data:delete": true,
},
"viewer": {
"jobs:read": true,
"data:read": true,
"models:read": true,
"metrics:read": true,
},
"operator": {
"jobs:read": true,
"jobs:update": true,
"metrics:read": true,
"system:read": true,
},
}
if perms, exists := rolePermissions[role]; exists {
return perms
}
return make(map[string]bool)
}
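
The role-to-permission maps above are only written into the config file; at request time a required permission string is checked against the user's map, with `*` acting as the admin wildcard. A minimal, self-contained sketch of that lookup (illustrative only; the real check lives in `internal/auth` and its semantics may differ):

```go
package main

import "fmt"

// hasPermission is a hypothetical illustration of how a permission map
// produced by getRolePermissions could be consulted: exact match first,
// then the "*" wildcard granted to admins.
func hasPermission(perms map[string]bool, required string) bool {
	if perms[required] {
		return true
	}
	return perms["*"]
}

func main() {
	viewer := map[string]bool{"jobs:read": true, "data:read": true}
	admin := map[string]bool{"*": true}
	fmt.Println(hasPermission(viewer, "jobs:create")) // false
	fmt.Println(hasPermission(admin, "jobs:create"))  // true
}
```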

173
cmd/worker/worker_config.go Normal file
View file

@ -0,0 +1,173 @@
package main
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/google/uuid"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"gopkg.in/yaml.v3"
)
const (
defaultMetricsFlushInterval = 500 * time.Millisecond
datasetCacheDefaultTTL = 30 * time.Minute
)
// Config holds worker configuration
type Config struct {
Host string `yaml:"host"`
User string `yaml:"user"`
SSHKey string `yaml:"ssh_key"`
Port int `yaml:"port"`
BasePath string `yaml:"base_path"`
TrainScript string `yaml:"train_script"`
RedisAddr string `yaml:"redis_addr"`
RedisPassword string `yaml:"redis_password"`
RedisDB int `yaml:"redis_db"`
KnownHosts string `yaml:"known_hosts"`
WorkerID string `yaml:"worker_id"`
MaxWorkers int `yaml:"max_workers"`
PollInterval int `yaml:"poll_interval_seconds"`
// Authentication
Auth auth.AuthConfig `yaml:"auth"`
// Metrics exporter
Metrics MetricsConfig `yaml:"metrics"`
// Metrics buffering
MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`
// Data management
DataManagerPath string `yaml:"data_manager_path"`
AutoFetchData bool `yaml:"auto_fetch_data"`
DataDir string `yaml:"data_dir"`
DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`
// Podman execution
PodmanImage string `yaml:"podman_image"`
ContainerWorkspace string `yaml:"container_workspace"`
ContainerResults string `yaml:"container_results"`
GPUAccess bool `yaml:"gpu_access"`
// Task lease and retry settings
TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min)
HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // How often to renew lease (default: 1min)
MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Graceful shutdown timeout (default: 5min)
}
// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
Enabled bool `yaml:"enabled"`
ListenAddr string `yaml:"listen_addr"`
}
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
// Get smart defaults for current environment
smart := config.GetSmartDefaults()
if cfg.Port == 0 {
cfg.Port = config.DefaultSSHPort
}
if cfg.Host == "" {
cfg.Host = smart.Host()
}
if cfg.BasePath == "" {
cfg.BasePath = smart.BasePath()
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = smart.RedisAddr()
}
if cfg.KnownHosts == "" {
cfg.KnownHosts = smart.KnownHostsPath()
}
if cfg.WorkerID == "" {
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
}
if cfg.MaxWorkers == 0 {
cfg.MaxWorkers = smart.MaxWorkers()
}
if cfg.PollInterval == 0 {
cfg.PollInterval = smart.PollInterval()
}
if cfg.DataManagerPath == "" {
cfg.DataManagerPath = "./data_manager"
}
if cfg.DataDir == "" {
if cfg.Host == "" || !cfg.AutoFetchData {
cfg.DataDir = config.DefaultLocalDataDir
} else {
cfg.DataDir = smart.DataDir()
}
}
if cfg.Metrics.ListenAddr == "" {
cfg.Metrics.ListenAddr = ":9100"
}
if cfg.MetricsFlushInterval == 0 {
cfg.MetricsFlushInterval = defaultMetricsFlushInterval
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
// Set lease and retry defaults
if cfg.TaskLeaseDuration == 0 {
cfg.TaskLeaseDuration = 30 * time.Minute
}
if cfg.HeartbeatInterval == 0 {
cfg.HeartbeatInterval = 1 * time.Minute
}
if cfg.MaxRetries == 0 {
cfg.MaxRetries = 3
}
if cfg.GracefulTimeout == 0 {
cfg.GracefulTimeout = 5 * time.Minute
}
return &cfg, nil
}
// Validate implements config.Validator interface
func (c *Config) Validate() error {
if c.Port != 0 {
if err := config.ValidatePort(c.Port); err != nil {
return fmt.Errorf("invalid SSH port: %w", err)
}
}
if c.BasePath != "" {
// Convert relative paths to absolute
c.BasePath = config.ExpandPath(c.BasePath)
if !filepath.IsAbs(c.BasePath) {
c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
}
}
if c.RedisAddr != "" {
if err := config.ValidateRedisAddr(c.RedisAddr); err != nil {
return fmt.Errorf("invalid Redis configuration: %w", err)
}
}
if c.MaxWorkers < 1 {
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
}
return nil
}
// Task struct and Redis constants moved to internal/queue
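
LoadConfig applies a default only where the YAML leaves the zero value, so an operator can ship a near-empty file and still get a working worker. A self-contained sketch of that zero-value defaulting pattern using a hypothetical two-field stand-in (not the actual Config type above):

```go
package main

import (
	"fmt"
	"time"

	"gopkg.in/yaml.v3"
)

// miniConfig is a hypothetical stand-in used only to illustrate the
// defaulting pattern in LoadConfig above.
type miniConfig struct {
	MaxWorkers        int           `yaml:"max_workers"`
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"`
}

func main() {
	var cfg miniConfig
	// Only max_workers is set in the file; the lease duration stays zero.
	if err := yaml.Unmarshal([]byte("max_workers: 4\n"), &cfg); err != nil {
		panic(err)
	}
	if cfg.MaxWorkers == 0 {
		cfg.MaxWorkers = 2
	}
	if cfg.TaskLeaseDuration == 0 {
		cfg.TaskLeaseDuration = 30 * time.Minute
	}
	fmt.Println(cfg.MaxWorkers, cfg.TaskLeaseDuration) // 4 30m0s
}
```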

883
cmd/worker/worker_server.go Normal file
View file

@ -0,0 +1,883 @@
// Package main implements the ML task worker
package main
import (
"context"
"fmt"
"log"
"log/slog"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/errors"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/telemetry"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// MLServer wraps network.SSHClient for backward compatibility
type MLServer struct {
*network.SSHClient
}
func NewMLServer(cfg *Config) (*MLServer, error) {
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
if err != nil {
return nil, err
}
return &MLServer{SSHClient: client}, nil
}
type Worker struct {
id string
config *Config
server *MLServer
queue *queue.TaskQueue
running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
runningMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
metrics *metrics.Metrics
metricsSrv *http.Server
datasetCache map[string]time.Time
datasetCacheMu sync.RWMutex
datasetCacheTTL time.Duration
// Graceful shutdown fields
shutdownCh chan struct{}
activeTasks sync.Map // map[string]*queue.Task - track active tasks
gracefulWait sync.WaitGroup
}
func (w *Worker) setupMetricsExporter() error {
if !w.config.Metrics.Enabled {
return nil
}
reg := prometheus.NewRegistry()
reg.MustRegister(
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
collectors.NewGoCollector(),
)
labels := prometheus.Labels{"worker_id": w.id}
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_processed_total",
Help: "Total tasks processed successfully by this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.TasksProcessed.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_failed_total",
Help: "Total tasks failed by this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.TasksFailed.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_active",
Help: "Number of tasks currently running on this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.runningCount())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_tasks_queued",
Help: "Latest observed queue depth from Redis.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.QueuedTasks.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_data_transferred_bytes_total",
Help: "Total bytes transferred while fetching datasets.",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.DataTransferred.Load())
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_data_fetch_time_seconds_total",
Help: "Total time spent fetching datasets (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_execution_time_seconds_total",
Help: "Total execution time for completed tasks (seconds).",
ConstLabels: labels,
}, func() float64 {
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_worker_max_concurrency",
Help: "Configured maximum concurrent tasks for this worker.",
ConstLabels: labels,
}, func() float64 {
return float64(w.config.MaxWorkers)
}))
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
srv := &http.Server{
Addr: w.config.Metrics.ListenAddr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
w.metricsSrv = srv
go func() {
w.logger.Info("metrics exporter listening",
"addr", w.config.Metrics.ListenAddr,
"enabled", w.config.Metrics.Enabled)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
w.logger.Warn("metrics exporter stopped",
"error", err)
}
}()
return nil
}
func NewWorker(cfg *Config, apiKey string) (*Worker, error) {
srv, err := NewMLServer(cfg)
if err != nil {
return nil, err
}
queueCfg := queue.Config{
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
RedisDB: cfg.RedisDB,
MetricsFlushInterval: cfg.MetricsFlushInterval,
}
taskQueue, err := queue.NewTaskQueue(queueCfg)
if err != nil {
return nil, err
}
// Create data_dir if it doesn't exist (for production without NAS)
if cfg.DataDir != "" {
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
}
}
ctx, cancel := context.WithCancel(context.Background())
ctx = logging.EnsureTrace(ctx)
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
baseLogger := logging.NewLogger(slog.LevelInfo, false)
logger := baseLogger.Component(ctx, "worker")
workerMetrics := &metrics.Metrics{}
worker := &Worker{
id: cfg.WorkerID,
config: cfg,
server: srv,
queue: taskQueue,
running: make(map[string]context.CancelFunc),
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
ctx: ctx,
cancel: cancel,
logger: logger,
metrics: workerMetrics,
shutdownCh: make(chan struct{}),
}
if err := worker.setupMetricsExporter(); err != nil {
return nil, err
}
return worker, nil
}
func (w *Worker) Start() {
w.logger.Info("worker started",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers,
"poll_interval", w.config.PollInterval)
go w.heartbeat()
for {
select {
case <-w.ctx.Done():
w.logger.Info("shutdown signal received, waiting for tasks")
w.waitForTasks()
return
default:
}
if w.runningCount() >= w.config.MaxWorkers {
time.Sleep(50 * time.Millisecond)
continue
}
queueStart := time.Now()
task, err := w.queue.GetNextTaskWithLease(w.config.WorkerID, w.config.TaskLeaseDuration)
queueLatency := time.Since(queueStart)
if err != nil {
if err == context.DeadlineExceeded {
continue
}
w.logger.Error("error fetching task",
"worker_id", w.id,
"error", err)
continue
}
if task == nil {
if queueLatency > 200*time.Millisecond {
w.logger.Debug("queue poll latency",
"latency_ms", queueLatency.Milliseconds())
}
continue
}
if depth, derr := w.queue.QueueDepth(); derr == nil {
if queueLatency > 100*time.Millisecond || depth > 0 {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"remaining_depth", depth)
}
} else if queueLatency > 100*time.Millisecond {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"depth_error", derr)
}
go w.executeTaskWithLease(task)
}
}
func (w *Worker) heartbeat() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
if err := w.queue.Heartbeat(w.id); err != nil {
w.logger.Warn("heartbeat failed",
"worker_id", w.id,
"error", err)
}
}
}
}
// fetchDatasets fetches the task's datasets via the configured data_manager, skipping datasets that are still fresh in the local cache
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
logger := w.logger.Job(ctx, task.JobName, task.ID)
logger.Info("fetching datasets",
"worker_id", w.id,
"dataset_count", len(task.Datasets))
for _, dataset := range task.Datasets {
if w.datasetIsFresh(dataset) {
logger.Debug("skipping cached dataset",
"dataset", dataset)
continue
}
// Check for cancellation before each dataset fetch
select {
case <-w.ctx.Done():
return fmt.Errorf("dataset fetch cancelled: %w", w.ctx.Err())
default:
}
logger.Info("fetching dataset",
"worker_id", w.id,
"dataset", dataset)
// Create command with context for cancellation support
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
cmd := exec.CommandContext(cmdCtx,
w.config.DataManagerPath,
"fetch",
task.JobName,
dataset,
)
output, err := cmd.CombinedOutput()
cancel() // Clean up context
if err != nil {
return &errors.DataFetchError{
Dataset: dataset,
JobName: task.JobName,
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
}
}
logger.Info("dataset ready",
"worker_id", w.id,
"dataset", dataset)
w.markDatasetFetched(dataset)
}
return nil
}
func (w *Worker) runJob(task *queue.Task) error {
// Validate job name to prevent path traversal
if err := container.ValidateJobName(task.JobName); err != nil {
return &errors.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
jobPaths := config.NewJobPaths(w.config.BasePath)
jobDir := filepath.Join(jobPaths.PendingPath(), task.JobName)
outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName)
logFile := filepath.Join(outputDir, "output.log")
// Sanitize paths
jobDir, err := container.SanitizePath(jobDir)
if err != nil {
return &errors.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
outputDir, err = container.SanitizePath(outputDir)
if err != nil {
return &errors.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
// Create output directory
if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return "", fmt.Errorf("mkdir failed: %w", err)
}
return "", nil
}); err != nil {
return &errors.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create output dir: %w", err),
}
}
// Move job from pending to running
stagingStart := time.Now()
if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) {
if err := os.Rename(jobDir, outputDir); err != nil {
return "", fmt.Errorf("rename failed: %w", err)
}
return "", nil
}); err != nil {
return &errors.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to move job: %w", err),
}
}
stagingDuration := time.Since(stagingStart)
if w.config.PodmanImage == "" {
return &errors.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: fmt.Errorf("podman_image must be configured"),
}
}
containerWorkspace := w.config.ContainerWorkspace
if containerWorkspace == "" {
containerWorkspace = config.DefaultContainerWorkspace
}
containerResults := w.config.ContainerResults
if containerResults == "" {
containerResults = config.DefaultContainerResults
}
podmanCfg := container.PodmanConfig{
Image: w.config.PodmanImage,
Workspace: filepath.Join(outputDir, "code"),
Results: filepath.Join(outputDir, "results"),
ContainerWorkspace: containerWorkspace,
ContainerResults: containerResults,
GPUAccess: w.config.GPUAccess,
}
scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript)
requirementsPath := filepath.Join(containerWorkspace, "requirements.txt")
var extraArgs []string
if task.Args != "" {
extraArgs = strings.Fields(task.Args)
}
ioBefore, ioErr := telemetry.ReadProcessIO()
podmanCmd := container.BuildPodmanCommand(podmanCfg, scriptPath, requirementsPath, extraArgs)
logFileHandle, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err == nil {
defer logFileHandle.Close()
podmanCmd.Stdout = logFileHandle
podmanCmd.Stderr = logFileHandle
} else {
w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err)
}
w.logger.Info("executing podman job",
"job", task.JobName,
"image", w.config.PodmanImage,
"workspace", podmanCfg.Workspace,
"results", podmanCfg.Results)
containerStart := time.Now()
if err := podmanCmd.Run(); err != nil {
containerDuration := time.Since(containerStart)
// Move job to failed directory
failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) {
if err := os.Rename(outputDir, failedDir); err != nil {
return "", fmt.Errorf("rename to failed failed: %w", err)
}
return "", nil
}); moveErr != nil {
w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr)
}
if ioErr == nil {
if after, err := telemetry.ReadProcessIO(); err == nil {
delta := telemetry.DiffIO(ioBefore, after)
w.logger.Debug("worker io stats",
"job", task.JobName,
"read_bytes", delta.ReadBytes,
"write_bytes", delta.WriteBytes)
}
}
w.logger.Info("job timing (failure)",
"job", task.JobName,
"staging_ms", stagingDuration.Milliseconds(),
"container_ms", containerDuration.Milliseconds(),
"finalize_ms", 0,
"total_ms", time.Since(stagingStart).Milliseconds(),
)
return fmt.Errorf("execution failed: %w", err)
}
containerDuration := time.Since(containerStart)
finalizeStart := time.Now()
// Move job to finished directory
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) {
if err := os.Rename(outputDir, finishedDir); err != nil {
return "", fmt.Errorf("rename to finished failed: %w", err)
}
return "", nil
}); moveErr != nil {
w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr)
}
finalizeDuration := time.Since(finalizeStart)
totalDuration := time.Since(stagingStart)
var ioDelta telemetry.IOStats
if ioErr == nil {
if after, err := telemetry.ReadProcessIO(); err == nil {
ioDelta = telemetry.DiffIO(ioBefore, after)
}
}
w.logger.Info("job timing",
"job", task.JobName,
"staging_ms", stagingDuration.Milliseconds(),
"container_ms", containerDuration.Milliseconds(),
"finalize_ms", finalizeDuration.Milliseconds(),
"total_ms", totalDuration.Milliseconds(),
"io_read_bytes", ioDelta.ReadBytes,
"io_write_bytes", ioDelta.WriteBytes,
)
return nil
}
func parseDatasets(args string) []string {
if !strings.Contains(args, "--datasets") {
return nil
}
parts := strings.Fields(args)
for i, part := range parts {
if part == "--datasets" && i+1 < len(parts) {
return strings.Split(parts[i+1], ",")
}
}
return nil
}
func (w *Worker) waitForTasks() {
timeout := time.After(5 * time.Minute)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-timeout:
w.logger.Warn("shutdown timeout, force stopping",
"running_tasks", len(w.running))
return
case <-ticker.C:
count := w.runningCount()
if count == 0 {
w.logger.Info("all tasks completed, shutting down")
return
}
w.logger.Debug("waiting for tasks to complete",
"remaining", count)
}
}
}
func (w *Worker) runningCount() int {
w.runningMu.RLock()
defer w.runningMu.RUnlock()
return len(w.running)
}
func (w *Worker) datasetIsFresh(dataset string) bool {
w.datasetCacheMu.RLock()
defer w.datasetCacheMu.RUnlock()
expires, ok := w.datasetCache[dataset]
return ok && time.Now().Before(expires)
}
func (w *Worker) markDatasetFetched(dataset string) {
expires := time.Now().Add(w.datasetCacheTTL)
w.datasetCacheMu.Lock()
w.datasetCache[dataset] = expires
w.datasetCacheMu.Unlock()
}
func (w *Worker) GetMetrics() map[string]any {
stats := w.metrics.GetStats()
stats["worker_id"] = w.id
stats["max_workers"] = w.config.MaxWorkers
return stats
}
func (w *Worker) Stop() {
w.cancel()
w.waitForTasks()
// Close connections, logging (rather than ignoring) any errors
if err := w.server.Close(); err != nil {
w.logger.Warn("error closing server connection", "error", err)
}
if err := w.queue.Close(); err != nil {
w.logger.Warn("error closing queue connection", "error", err)
}
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics exporter shutdown error", "error", err)
}
}
w.logger.Info("worker stopped", "worker_id", w.id)
}
// executeTaskWithLease runs a task under a Redis lease, renewing it via heartbeats and retrying transient failures
func (w *Worker) executeTaskWithLease(task *queue.Task) {
// Track task for graceful shutdown
w.gracefulWait.Add(1)
w.activeTasks.Store(task.ID, task)
defer w.gracefulWait.Done()
defer w.activeTasks.Delete(task.ID)
// Create task-specific context with timeout
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
defer taskCancel()
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
logger.Info("starting task",
"worker_id", w.id,
"datasets", task.Datasets,
"priority", task.Priority)
// Record task start
w.metrics.RecordTaskStart()
defer w.metrics.RecordTaskCompletion()
// Check for context cancellation
select {
case <-taskCtx.Done():
logger.Info("task cancelled before execution")
return
default:
}
// Parse datasets from task arguments
if task.Datasets == nil {
task.Datasets = parseDatasets(task.Args)
}
// Start heartbeat goroutine
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer cancelHeartbeat()
go w.heartbeatLoop(heartbeatCtx, task.ID)
// Update task status
task.Status = "running"
now := time.Now()
task.StartedAt = &now
task.WorkerID = w.id
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
logger.Error("failed to update task status", "error", err)
w.metrics.RecordTaskFailure()
return
}
if w.config.AutoFetchData && len(task.Datasets) > 0 {
if err := w.fetchDatasets(taskCtx, task); err != nil {
logger.Error("data fetch failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
err := w.queue.UpdateTask(task)
if err != nil {
logger.Error("failed to update task status after data fetch failure", "error", err)
}
w.metrics.RecordTaskFailure()
return
}
}
// Execute job with panic recovery
var execErr error
func() {
defer func() {
if r := recover(); r != nil {
execErr = fmt.Errorf("panic during execution: %v", r)
}
}()
execErr = w.runJob(task)
}()
// Finalize task
endTime := time.Now()
task.EndedAt = &endTime
if execErr != nil {
task.Error = execErr.Error()
// Check if transient error (network, timeout, etc)
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
w.logger.Warn("task failed with transient error, will retry",
"task_id", task.ID,
"error", execErr,
"retry_count", task.RetryCount)
if err := w.queue.RetryTask(task); err != nil {
logger.Error("failed to requeue task for retry", "error", err)
}
} else {
task.Status = "failed"
if err := w.queue.UpdateTaskWithMetrics(task, "final"); err != nil {
logger.Error("failed to record task failure", "error", err)
}
}
} else {
task.Status = "completed"
if err := w.queue.UpdateTaskWithMetrics(task, "final"); err != nil {
logger.Error("failed to record task completion", "error", err)
}
}
// Release lease
if err := w.queue.ReleaseLease(task.ID, w.config.WorkerID); err != nil {
logger.Error("failed to release lease", "error", err)
}
}
// heartbeatLoop periodically renews the task lease and the worker heartbeat until the task finishes
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
ticker := time.NewTicker(w.config.HeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
return
}
// Also update worker heartbeat
if err := w.queue.Heartbeat(w.config.WorkerID); err != nil {
w.logger.Warn("worker heartbeat failed", "worker_id", w.config.WorkerID, "error", err)
}
}
}
}
// Shutdown waits for active tasks to finish, up to GracefulTimeout, then releases any remaining leases and closes the queue
func (w *Worker) Shutdown() error {
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
// Wait for active tasks with timeout
done := make(chan struct{})
go func() {
w.gracefulWait.Wait()
close(done)
}()
timeout := time.After(w.config.GracefulTimeout)
select {
case <-done:
w.logger.Info("all tasks completed, shutdown successful")
case <-timeout:
w.logger.Warn("graceful shutdown timeout, releasing active leases")
w.releaseAllLeases()
}
return w.queue.Close()
}
// releaseAllLeases releases the lease of every task still tracked as active
func (w *Worker) releaseAllLeases() {
w.activeTasks.Range(func(key, value interface{}) bool {
taskID := key.(string)
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
}
return true
})
}
// countActiveTasks counts tasks currently tracked for graceful shutdown
func (w *Worker) countActiveTasks() int {
count := 0
w.activeTasks.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
func isTransientError(err error) bool {
if err == nil {
return false
}
// Check if error is transient (network, timeout, resource unavailable, etc)
errStr := err.Error()
transientIndicators := []string{
"connection refused",
"timeout",
"temporary failure",
"resource temporarily unavailable",
"no such host",
"network unreachable",
}
for _, indicator := range transientIndicators {
if strings.Contains(strings.ToLower(errStr), indicator) {
return true
}
}
return false
}
func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
if err := auth.ValidateAuthFlags(authFlags); err != nil {
log.Fatalf("Authentication flag error: %v", err)
}
// Get API key from various sources
apiKey := auth.GetAPIKeyFromSources(authFlags)
// Load configuration
configPath := "config-local.yaml"
if authFlags.ConfigFile != "" {
configPath = authFlags.ConfigFile
}
resolvedConfig, err := config.ResolveConfigPath(configPath)
if err != nil {
log.Fatalf("%v", err)
}
cfg, err := LoadConfig(resolvedConfig)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Validate authentication configuration
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
log.Fatalf("Invalid authentication configuration: %v", err)
}
// Validate configuration
if err := cfg.Validate(); err != nil {
log.Fatalf("Invalid configuration: %v", err)
}
// Test authentication if enabled
if cfg.Auth.Enabled && apiKey != "" {
user, err := cfg.Auth.ValidateAPIKey(apiKey)
if err != nil {
log.Fatalf("Authentication failed: %v", err)
}
log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin)
} else if cfg.Auth.Enabled {
log.Fatal("Authentication required but no API key provided")
}
worker, err := NewWorker(cfg, apiKey)
if err != nil {
log.Fatalf("Failed to create worker: %v", err)
}
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go worker.Start()
sig := <-sigChan
log.Printf("Received signal: %v", sig)
// Use graceful shutdown
if err := worker.Shutdown(); err != nil {
log.Printf("Graceful shutdown error: %v", err)
worker.Stop() // Fallback to force stop
} else {
log.Println("Worker shut down gracefully")
}
}

View file

@ -0,0 +1,78 @@
package main
import (
"fmt"
"log"
"os"
"github.com/jfraeys/fetch_ml/internal/auth"
"gopkg.in/yaml.v3"
)
// Example: How to integrate auth into TUI startup
func checkAuth(configFile string) error {
// Load config
data, err := os.ReadFile(configFile)
if err != nil {
return fmt.Errorf("failed to read config: %w", err)
}
var cfg struct {
Auth auth.AuthConfig `yaml:"auth"`
}
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse config: %w", err)
}
// If auth disabled, proceed normally
if !cfg.Auth.Enabled {
fmt.Println("🔓 Authentication disabled - proceeding normally")
return nil
}
// Check for API key
apiKey := os.Getenv("FETCH_ML_API_KEY")
if apiKey == "" {
apiKey = getAPIKeyFromUser()
}
// Validate API key
user, err := cfg.Auth.ValidateAPIKey(apiKey)
if err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
fmt.Printf("🔐 Authenticated as: %s", user.Name)
if user.Admin {
fmt.Println(" (admin)")
} else {
fmt.Println()
}
return nil
}
func getAPIKeyFromUser() string {
fmt.Print("🔑 Enter API key: ")
var key string
fmt.Scanln(&key)
return key
}
// Example usage in main()
func exampleMain() {
configFile := "config_dev.yaml"
// Check authentication first
if err := checkAuth(configFile); err != nil {
log.Fatalf("Authentication failed: %v", err)
}
// Proceed with normal TUI initialization
fmt.Println("Starting TUI...")
}
func main() {
exampleMain()
}

View file

@ -0,0 +1,117 @@
package api
import (
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/queue"
)
func TestUserPermissions(t *testing.T) {
authConfig := &auth.AuthConfig{
Enabled: true,
APIKeys: map[auth.Username]auth.APIKeyEntry{
"admin": {
Hash: auth.APIKeyHash(auth.HashAPIKey("admin_key")),
Admin: true,
},
"scientist": {
Hash: auth.APIKeyHash(auth.HashAPIKey("ds_key")),
Admin: false,
Permissions: map[string]bool{
"jobs:create": true,
"jobs:read": true,
"jobs:update": true,
},
},
},
}
tests := []struct {
name string
apiKey string
permission string
want bool
}{
{"Admin can create", "admin_key", "jobs:create", true},
{"Scientist can create", "ds_key", "jobs:create", true},
{"Invalid key fails", "invalid_key", "jobs:create", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, err := authConfig.ValidateAPIKey(tt.apiKey)
if tt.apiKey == "invalid_key" {
if err == nil {
t.Error("Expected error for invalid API key")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
got := user.HasPermission(tt.permission)
if got != tt.want {
t.Errorf("HasPermission() = %v, want %v", got, tt.want)
}
})
}
}
func TestTaskOwnership(t *testing.T) {
tasks := []*queue.Task{
{
ID: "task1",
JobName: "user1_job",
UserID: "user1",
CreatedBy: "user1",
CreatedAt: time.Now(),
},
{
ID: "task2",
JobName: "user2_job",
UserID: "user2",
CreatedBy: "user2",
CreatedAt: time.Now(),
},
}
users := map[string]*auth.User{
"user1": {Name: "user1", Admin: false},
"user2": {Name: "user2", Admin: false},
"admin": {Name: "admin", Admin: true},
}
tests := []struct {
name string
userName string
task *queue.Task
want bool
}{
{"User can view own task", "user1", tasks[0], true},
{"User cannot view other task", "user1", tasks[1], false},
{"Admin can view any task", "admin", tasks[1], true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user := users[tt.userName]
canAccess := false
if user.Admin {
canAccess = true
} else if tt.task.UserID == user.Name || tt.task.CreatedBy == user.Name {
canAccess = true
}
if canAccess != tt.want {
t.Errorf("Access = %v, want %v", canAccess, tt.want)
}
})
}
}

305
internal/api/protocol.go Normal file
View file

@ -0,0 +1,305 @@
package api
import (
"encoding/binary"
"encoding/json"
"fmt"
"time"
)
// Response packet types
const (
PacketTypeSuccess = 0x00
PacketTypeError = 0x01
PacketTypeProgress = 0x02
PacketTypeStatus = 0x03
PacketTypeData = 0x04
PacketTypeLog = 0x05
)
// Error codes
const (
ErrorCodeUnknownError = 0x00
ErrorCodeInvalidRequest = 0x01
ErrorCodeAuthenticationFailed = 0x02
ErrorCodePermissionDenied = 0x03
ErrorCodeResourceNotFound = 0x04
ErrorCodeResourceAlreadyExists = 0x05
ErrorCodeServerOverloaded = 0x10
ErrorCodeDatabaseError = 0x11
ErrorCodeNetworkError = 0x12
ErrorCodeStorageError = 0x13
ErrorCodeTimeout = 0x14
ErrorCodeJobNotFound = 0x20
ErrorCodeJobAlreadyRunning = 0x21
ErrorCodeJobFailedToStart = 0x22
ErrorCodeJobExecutionFailed = 0x23
ErrorCodeJobCancelled = 0x24
ErrorCodeOutOfMemory = 0x30
ErrorCodeDiskFull = 0x31
ErrorCodeInvalidConfiguration = 0x32
ErrorCodeServiceUnavailable = 0x33
)
// Progress types
const (
ProgressTypePercentage = 0x00
ProgressTypeStage = 0x01
ProgressTypeMessage = 0x02
ProgressTypeBytesTransferred = 0x03
)
// Log levels
const (
LogLevelDebug = 0x00
LogLevelInfo = 0x01
LogLevelWarn = 0x02
LogLevelError = 0x03
)
// ResponsePacket represents a structured response packet
type ResponsePacket struct {
PacketType byte
Timestamp uint64
// Success fields
SuccessMessage string
// Error fields
ErrorCode byte
ErrorMessage string
ErrorDetails string
// Progress fields
ProgressType byte
ProgressValue uint32
ProgressTotal uint32
ProgressMessage string
// Status fields
StatusData string
// Data fields
DataType string
DataPayload []byte
// Log fields
LogLevel byte
LogMessage string
}
// NewSuccessPacket creates a success response packet
func NewSuccessPacket(message string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeSuccess,
Timestamp: uint64(time.Now().Unix()),
SuccessMessage: message,
}
}
// NewSuccessPacketWithPayload creates a success response packet with a JSON payload
func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponsePacket {
// Convert payload to JSON for the DataPayload field; fall back to a plain
// success packet if the payload cannot be encoded
payloadBytes, err := json.Marshal(payload)
if err != nil {
return NewSuccessPacket(message)
}
return &ResponsePacket{
PacketType: PacketTypeData,
Timestamp: uint64(time.Now().Unix()),
SuccessMessage: message,
DataType: "status",
DataPayload: payloadBytes,
}
}
// NewErrorPacket creates an error response packet
func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeError,
Timestamp: uint64(time.Now().Unix()),
ErrorCode: errorCode,
ErrorMessage: message,
ErrorDetails: details,
}
}
// NewProgressPacket creates a progress response packet
func NewProgressPacket(progressType byte, value uint32, total uint32, message string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeProgress,
Timestamp: uint64(time.Now().Unix()),
ProgressType: progressType,
ProgressValue: value,
ProgressTotal: total,
ProgressMessage: message,
}
}
// NewStatusPacket creates a status response packet
func NewStatusPacket(data string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeStatus,
Timestamp: uint64(time.Now().Unix()),
StatusData: data,
}
}
// NewDataPacket creates a data response packet
func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeData,
Timestamp: uint64(time.Now().Unix()),
DataType: dataType,
DataPayload: payload,
}
}
// NewLogPacket creates a log response packet
func NewLogPacket(level byte, message string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeLog,
Timestamp: uint64(time.Now().Unix()),
LogLevel: level,
LogMessage: message,
}
}
// Serialize converts the packet to binary format
func (p *ResponsePacket) Serialize() ([]byte, error) {
var buf []byte
// Packet type
buf = append(buf, p.PacketType)
// Timestamp (8 bytes, big-endian)
timestampBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timestampBytes, p.Timestamp)
buf = append(buf, timestampBytes...)
// Packet-specific data
switch p.PacketType {
case PacketTypeSuccess:
buf = append(buf, serializeString(p.SuccessMessage)...)
case PacketTypeError:
buf = append(buf, p.ErrorCode)
buf = append(buf, serializeString(p.ErrorMessage)...)
buf = append(buf, serializeString(p.ErrorDetails)...)
case PacketTypeProgress:
buf = append(buf, p.ProgressType)
valueBytes := make([]byte, 4)
binary.BigEndian.PutUint32(valueBytes, p.ProgressValue)
buf = append(buf, valueBytes...)
totalBytes := make([]byte, 4)
binary.BigEndian.PutUint32(totalBytes, p.ProgressTotal)
buf = append(buf, totalBytes...)
buf = append(buf, serializeString(p.ProgressMessage)...)
case PacketTypeStatus:
buf = append(buf, serializeString(p.StatusData)...)
case PacketTypeData:
buf = append(buf, serializeString(p.DataType)...)
buf = append(buf, serializeBytes(p.DataPayload)...)
case PacketTypeLog:
buf = append(buf, p.LogLevel)
buf = append(buf, serializeString(p.LogMessage)...)
default:
return nil, fmt.Errorf("unknown packet type: %d", p.PacketType)
}
return buf, nil
}
// serializeString writes a string with a 2-byte big-endian length prefix.
// Strings longer than 65535 bytes are truncated so the prefix stays accurate.
func serializeString(s string) []byte {
if len(s) > 65535 {
s = s[:65535]
}
length := uint16(len(s))
buf := make([]byte, 2+len(s))
binary.BigEndian.PutUint16(buf[:2], length)
copy(buf[2:], s)
return buf
}
// serializeBytes writes bytes with 4-byte length prefix
func serializeBytes(b []byte) []byte {
length := uint32(len(b))
buf := make([]byte, 4+len(b))
binary.BigEndian.PutUint32(buf[:4], length)
copy(buf[4:], b)
return buf
}
// GetErrorMessage returns a human-readable error message for an error code
func GetErrorMessage(code byte) string {
switch code {
case ErrorCodeUnknownError:
return "Unknown error occurred"
case ErrorCodeInvalidRequest:
return "Invalid request format"
case ErrorCodeAuthenticationFailed:
return "Authentication failed"
case ErrorCodePermissionDenied:
return "Permission denied"
case ErrorCodeResourceNotFound:
return "Resource not found"
case ErrorCodeResourceAlreadyExists:
return "Resource already exists"
case ErrorCodeServerOverloaded:
return "Server is overloaded"
case ErrorCodeDatabaseError:
return "Database error occurred"
case ErrorCodeNetworkError:
return "Network error occurred"
case ErrorCodeStorageError:
return "Storage error occurred"
case ErrorCodeTimeout:
return "Operation timed out"
case ErrorCodeJobNotFound:
return "Job not found"
case ErrorCodeJobAlreadyRunning:
return "Job is already running"
case ErrorCodeJobFailedToStart:
return "Job failed to start"
case ErrorCodeJobExecutionFailed:
return "Job execution failed"
case ErrorCodeJobCancelled:
return "Job was cancelled"
case ErrorCodeOutOfMemory:
return "Server out of memory"
case ErrorCodeDiskFull:
return "Server disk full"
case ErrorCodeInvalidConfiguration:
return "Invalid server configuration"
case ErrorCodeServiceUnavailable:
return "Service temporarily unavailable"
default:
return "Unknown error code"
}
}
// GetLogLevelName returns the name for a log level
func GetLogLevelName(level byte) string {
switch level {
case LogLevelDebug:
return "DEBUG"
case LogLevelInfo:
return "INFO"
case LogLevelWarn:
return "WARN"
case LogLevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
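
The wire format above is: one packet-type byte, an 8-byte big-endian Unix timestamp, then packet-specific fields, where strings carry a 2-byte big-endian length prefix. A minimal client-side sketch that decodes a success packet as produced by Serialize (illustrative only; a real client also needs the other packet types):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// decodeSuccess parses [type:1][timestamp:8][msg_len:2][msg] as produced by
// Serialize for PacketTypeSuccess (0x00). Illustrative sketch only.
func decodeSuccess(buf []byte) (uint64, string, error) {
	if len(buf) < 11 || buf[0] != 0x00 {
		return 0, "", fmt.Errorf("not a success packet")
	}
	ts := binary.BigEndian.Uint64(buf[1:9])
	msgLen := int(binary.BigEndian.Uint16(buf[9:11]))
	if len(buf) < 11+msgLen {
		return 0, "", fmt.Errorf("truncated message")
	}
	return ts, string(buf[11 : 11+msgLen]), nil
}

func main() {
	// Hand-built success packet: type 0x00, timestamp 0, message "ok".
	pkt := []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0x00, 0x02, 'o', 'k'}
	ts, msg, err := decodeSuccess(pkt)
	fmt.Println(ts, msg, err) // 0 ok <nil>
}
```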

606
internal/api/ws.go Normal file
View file

@ -0,0 +1,606 @@
package api
import (
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"golang.org/x/crypto/acme/autocert"
)
// Opcodes for binary WebSocket protocol
const (
OpcodeQueueJob = 0x01
OpcodeStatusRequest = 0x02
OpcodeCancelJob = 0x03
OpcodePrune = 0x04
OpcodeLogMetric = 0x0A
OpcodeGetExperiment = 0x0B
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
// Allow localhost and homelab origins for development
origin := r.Header.Get("Origin")
if origin == "" {
return true // Allow same-origin requests
}
// Parse origin URL
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false
}
// Allow localhost and local network origins
host := parsedOrigin.Host
return strings.HasSuffix(host, ":8080") ||
strings.HasSuffix(host, ":8081") ||
strings.HasPrefix(host, "localhost") ||
strings.HasPrefix(host, "127.0.0.1") ||
strings.HasPrefix(host, "192.168.") ||
strings.HasPrefix(host, "10.") ||
strings.HasPrefix(host, "172.")
},
}
type WSHandler struct {
authConfig *auth.AuthConfig
logger *logging.Logger
expManager *experiment.Manager
queue *queue.TaskQueue
}
func NewWSHandler(authConfig *auth.AuthConfig, logger *logging.Logger, expManager *experiment.Manager, taskQueue *queue.TaskQueue) *WSHandler {
return &WSHandler{
authConfig: authConfig,
logger: logger,
expManager: expManager,
queue: taskQueue,
}
}
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check API key before upgrading WebSocket
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
// Also check Authorization header
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
}
}
// Validate API key if authentication is enabled
if h.authConfig != nil && h.authConfig.Enabled {
if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
h.logger.Warn("websocket authentication failed", "error", err)
http.Error(w, "Invalid API key", http.StatusUnauthorized)
return
}
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
h.logger.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
h.logger.Info("websocket connection established", "remote", r.RemoteAddr)
for {
messageType, message, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
h.logger.Error("websocket read error", "error", err)
}
break
}
if messageType != websocket.BinaryMessage {
h.logger.Warn("received non-binary message")
continue
}
if err := h.handleMessage(conn, message); err != nil {
h.logger.Error("message handling error", "error", err)
// Send error response
_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
}
}
}
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
if len(message) < 1 {
return fmt.Errorf("message too short")
}
opcode := message[0]
payload := message[1:]
switch opcode {
case OpcodeQueueJob:
return h.handleQueueJob(conn, payload)
case OpcodeStatusRequest:
return h.handleStatusRequest(conn, payload)
case OpcodeCancelJob:
return h.handleCancelJob(conn, payload)
case OpcodePrune:
return h.handlePrune(conn, payload)
case OpcodeLogMetric:
return h.handleLogMetric(conn, payload)
case OpcodeGetExperiment:
return h.handleGetExperiment(conn, payload)
default:
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
}
}
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
if len(payload) < 130 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
}
apiKeyHash := string(payload[:64])
commitID := string(payload[64:128])
priority := int64(payload[128])
jobNameLen := int(payload[129])
if len(payload) < 130+jobNameLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
}
jobName := string(payload[130 : 130+jobNameLen])
h.logger.Info("queue job request",
"job", jobName,
"priority", priority,
"commit_id", commitID,
)
// Validate API key and get user information
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
// Check user permissions; success is logged once the task is actually enqueued below
if h.authConfig.Enabled && !user.HasPermission("jobs:create") {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
}
// Create experiment directory and metadata
if err := h.expManager.CreateExperiment(commitID); err != nil {
h.logger.Error("failed to create experiment directory", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
}
// Add user info to experiment metadata
meta := &experiment.Metadata{
CommitID: commitID,
JobName: jobName,
User: user.Name,
Timestamp: time.Now().Unix(),
}
if err := h.expManager.WriteMetadata(meta); err != nil {
h.logger.Error("failed to save experiment metadata", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
}
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
// Enqueue task if queue is available
if h.queue != nil {
taskID := uuid.New().String()
task := &queue.Task{
ID: taskID,
JobName: jobName,
Args: "", // TODO: Add args support
Status: "queued",
Priority: priority,
CreatedAt: time.Now(),
UserID: user.Name,
Username: user.Name,
CreatedBy: user.Name,
Metadata: map[string]string{
"commit_id": commitID,
"user_id": user.Name,
"username": user.Name,
},
}
if err := h.queue.AddTask(task); err != nil {
h.logger.Error("failed to enqueue task", "error", err)
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
}
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
} else {
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
}
packetData, err := packet.Serialize()
if err != nil {
h.logger.Error("failed to serialize packet", "error", err)
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
}
return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:64]
if len(payload) < 64 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
}
apiKeyHash := string(payload[0:64])
h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")
// Validate API key and get user information
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
// Check user permissions for viewing jobs
if h.authConfig.Enabled && !user.HasPermission("jobs:read") {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
}
// Get tasks with user filtering
var tasks []*queue.Task
if h.queue != nil {
allTasks, err := h.queue.GetAllTasks()
if err != nil {
h.logger.Error("failed to get tasks", "error", err)
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
}
// Filter tasks based on user permissions
for _, task := range allTasks {
// If auth is disabled or admin can see all tasks
if !h.authConfig.Enabled || user.Admin {
tasks = append(tasks, task)
continue
}
// Users can only see their own tasks
if task.UserID == user.Name || task.CreatedBy == user.Name {
tasks = append(tasks, task)
}
}
}
// Build status response with user-specific data
status := map[string]interface{}{
"user": map[string]interface{}{
"name": user.Name,
"admin": user.Admin,
"roles": user.Roles,
},
"tasks": map[string]interface{}{
"total": len(tasks),
"queued": countTasksByStatus(tasks, "queued"),
"running": countTasksByStatus(tasks, "running"),
"failed": countTasksByStatus(tasks, "failed"),
"completed": countTasksByStatus(tasks, "completed"),
},
"queue": tasks, // Include filtered tasks
}
packet := NewSuccessPacketWithPayload("Status retrieved", status)
packetData, err := packet.Serialize()
if err != nil {
h.logger.Error("failed to serialize packet", "error", err)
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
}
return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// countTasksByStatus counts tasks by their status
func countTasksByStatus(tasks []*queue.Task, status string) int {
count := 0
for _, task := range tasks {
if task.Status == status {
count++
}
}
return count
}
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
if len(payload) < 65 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
}
// Parse 64-byte hex API key hash
apiKeyHash := string(payload[0:64])
jobNameLen := int(payload[64])
if len(payload) < 65+jobNameLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
}
jobName := string(payload[65 : 65+jobNameLen])
h.logger.Info("cancel job request", "job", jobName)
// Validate API key and get user information
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
if err != nil {
h.logger.Error("invalid api key", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
}
// Check user permissions for canceling jobs
if h.authConfig.Enabled && !user.HasPermission("jobs:update") {
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
}
// Find the task and verify ownership
if h.queue != nil {
task, err := h.queue.GetTaskByName(jobName)
if err != nil {
h.logger.Error("task not found", "job", jobName, "error", err)
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
}
// Check if user can cancel this task (admin or owner)
canCancel := !h.authConfig.Enabled || user.Admin || task.UserID == user.Name || task.CreatedBy == user.Name
if !canCancel {
h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
}
// Cancel the task
if err := h.queue.CancelTask(task.ID); err != nil {
h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
}
h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
} else {
h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
}
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
packetData, err := packet.Serialize()
if err != nil {
h.logger.Error("failed to serialize packet", "error", err)
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
}
return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:64][prune_type:1][value:4]
if len(payload) < 69 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
}
// Parse 64-byte hex API key hash
apiKeyHash := string(payload[0:64])
pruneType := payload[64]
value := binary.BigEndian.Uint32(payload[65:69])
h.logger.Info("prune request", "type", pruneType, "value", value)
// Verify API key
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
h.logger.Error("api key verification failed", "error", err)
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
}
}
// Convert prune parameters
var keepCount int
var olderThanDays int
switch pruneType {
case 0:
// keep N
keepCount = int(value)
olderThanDays = 0
case 1:
// older than days
keepCount = 0
olderThanDays = int(value)
default:
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "")
}
// Perform pruning
pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
if err != nil {
h.logger.Error("prune failed", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
}
h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
// Send structured success response
packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))
return h.sendResponsePacket(conn, packet)
}
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
if len(payload) < 141 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
}
apiKeyHash := string(payload[:64])
commitID := string(payload[64:128])
step := int(binary.BigEndian.Uint32(payload[128:132]))
valueBits := binary.BigEndian.Uint64(payload[132:140])
value := math.Float64frombits(valueBits)
nameLen := int(payload[140])
if len(payload) < 141+nameLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
}
name := string(payload[141 : 141+nameLen])
// Verify API key
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
}
}
if err := h.expManager.LogMetric(commitID, name, value, step); err != nil {
h.logger.Error("failed to log metric", "error", err)
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
}
return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:64][commit_id:64]
if len(payload) < 128 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
}
apiKeyHash := string(payload[:64])
commitID := string(payload[64:128])
// Verify API key
if h.authConfig != nil && h.authConfig.Enabled {
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
}
}
meta, err := h.expManager.ReadMetadata(commitID)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
}
metrics, err := h.expManager.GetMetrics(commitID)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
}
response := map[string]interface{}{
"metadata": meta,
"metrics": metrics,
}
return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response))
}
// Helper to hash API key for comparison
func HashAPIKey(apiKey string) string {
hash := sha256.Sum256([]byte(apiKey))
return hex.EncodeToString(hash[:])
}
// SetupTLSConfig creates TLS configuration for WebSocket server
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
var server *http.Server
if certFile != "" && keyFile != "" {
// Use provided certificates
server = &http.Server{
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
},
}
} else if host != "" {
// Use Let's Encrypt with autocert
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(host),
Cache: autocert.DirCache("/var/www/.cache"),
}
server = &http.Server{
TLSConfig: certManager.TLSConfig(),
}
}
return server, nil
}
// verifyAPIKeyHash verifies the provided hex hash against stored API keys
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
if h.authConfig == nil || !h.authConfig.Enabled {
return nil // No auth required
}
// For now, just check if it's a valid 64-char hex string
if len(hexHash) != 64 {
return fmt.Errorf("invalid api key hash length")
}
// Check against stored API keys
for _, entry := range h.authConfig.APIKeys {
if string(entry.Hash) == hexHash {
return nil // Valid API key found
}
}
return fmt.Errorf("invalid api key")
}
// sendErrorPacket sends an error response packet
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
packet := NewErrorPacket(errorCode, message, details)
return h.sendResponsePacket(conn, packet)
}
// sendResponsePacket sends a structured response packet
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
data, err := packet.Serialize()
if err != nil {
h.logger.Error("failed to serialize response packet", "error", err)
// Fallback to simple error response
return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
}
return conn.WriteMessage(websocket.BinaryMessage, data)
}
// sendErrorResponse removed (unused)
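
The handler validates the raw API key taken from the X-API-Key (or Bearer) header before upgrading, while the binary frames themselves carry a 64-character SHA-256 hex hash of the key followed by opcode-specific fields. A minimal Go client sketch under those assumptions, sending a status_request frame; the endpoint, port, and key are placeholders, and it mirrors the framing exercised by the test below with auth disabled:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"log"
	"net/http"

	"github.com/gorilla/websocket"
)

func main() {
	// Placeholder endpoint and key; real values come from your deployment.
	header := http.Header{}
	header.Set("X-API-Key", "example-api-key")

	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:9100/ws", header)
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	// status_request frame: [opcode:1][api_key_hash:64 hex chars]
	hash := sha256.Sum256([]byte("example-api-key"))
	frame := append([]byte{0x02}, []byte(hex.EncodeToString(hash[:]))...)
	if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
		log.Fatalf("write: %v", err)
	}

	_, resp, err := conn.ReadMessage()
	if err != nil {
		log.Fatalf("read: %v", err)
	}
	if len(resp) > 0 {
		log.Printf("response packet type: 0x%02x (%d bytes)", resp[0], len(resp))
	}
}
```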

335
internal/api/ws_test.go Normal file
View file

@ -0,0 +1,335 @@
package api
import (
"encoding/binary"
"math"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestServer(t *testing.T) (*httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis) {
// Setup miniredis
s, err := miniredis.Run()
require.NoError(t, err)
// Setup TaskQueue
queueCfg := queue.Config{
RedisAddr: s.Addr(),
MetricsFlushInterval: 10 * time.Millisecond,
}
tq, err := queue.NewTaskQueue(queueCfg)
require.NoError(t, err)
// Setup dependencies
logger := logging.NewLogger(0, false)
expManager := experiment.NewManager(t.TempDir())
authCfg := &auth.AuthConfig{Enabled: false}
// Create handler
handler := NewWSHandler(authCfg, logger, expManager, tq)
// Setup test server
server := httptest.NewServer(handler)
return server, tq, expManager, s
}
func connectWS(t *testing.T, serverURL string) *websocket.Conn {
wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
require.NoError(t, err)
return ws
}
func TestWSHandler_QueueJob(t *testing.T) {
server, tq, _, s := setupTestServer(t)
defer server.Close()
defer tq.Close()
defer s.Close()
ws := connectWS(t, server.URL)
defer ws.Close()
// Prepare queue_job message
// Protocol: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
opcode := byte(OpcodeQueueJob)
apiKeyHash := make([]byte, 64)
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
commitID := make([]byte, 64)
copy(commitID, []byte(strings.Repeat("a", 64)))
priority := byte(5)
jobName := "test-job"
jobNameLen := byte(len(jobName))
var msg []byte
msg = append(msg, opcode)
msg = append(msg, apiKeyHash...)
msg = append(msg, commitID...)
msg = append(msg, priority)
msg = append(msg, jobNameLen)
msg = append(msg, []byte(jobName)...)
// Send message
err := ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
// Read response
_, resp, err := ws.ReadMessage()
require.NoError(t, err)
// Verify success response (PacketTypeSuccess = 0x00)
assert.Equal(t, byte(PacketTypeSuccess), resp[0])
// Verify task in Redis
time.Sleep(100 * time.Millisecond)
task, err := tq.GetNextTask()
require.NoError(t, err)
require.NotNil(t, task)
assert.Equal(t, jobName, task.JobName)
}
func TestWSHandler_StatusRequest(t *testing.T) {
server, tq, _, s := setupTestServer(t)
defer server.Close()
defer tq.Close()
defer s.Close()
// Add a task to queue
task := &queue.Task{
ID: "task-1",
JobName: "job-1",
Status: "queued",
Priority: 10,
CreatedAt: time.Now(),
UserID: "user",
CreatedBy: "user",
}
err := tq.AddTask(task)
require.NoError(t, err)
ws := connectWS(t, server.URL)
defer ws.Close()
// Prepare status_request message
// Protocol: [opcode:1][api_key_hash:64]
opcode := byte(OpcodeStatusRequest)
apiKeyHash := make([]byte, 64)
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
var msg []byte
msg = append(msg, opcode)
msg = append(msg, apiKeyHash...)
// Send message
err = ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
// Read response
_, resp, err := ws.ReadMessage()
require.NoError(t, err)
// Verify success response (PacketTypeData = 0x04 for status with payload)
assert.Equal(t, byte(PacketTypeData), resp[0])
}
func TestWSHandler_CancelJob(t *testing.T) {
server, tq, _, s := setupTestServer(t)
defer server.Close()
defer tq.Close()
defer s.Close()
// Add a task to queue
task := &queue.Task{
ID: "task-1",
JobName: "job-to-cancel",
Status: "queued",
Priority: 10,
CreatedAt: time.Now(),
UserID: "user", // Auth disabled so this matches any user
CreatedBy: "user",
}
err := tq.AddTask(task)
require.NoError(t, err)
ws := connectWS(t, server.URL)
defer ws.Close()
// Prepare cancel_job message
// Protocol: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var]
opcode := byte(OpcodeCancelJob)
apiKeyHash := make([]byte, 64)
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
jobName := "job-to-cancel"
jobNameLen := byte(len(jobName))
var msg []byte
msg = append(msg, opcode)
msg = append(msg, apiKeyHash...)
msg = append(msg, jobNameLen)
msg = append(msg, []byte(jobName)...)
// Send message
err = ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
// Read response
_, resp, err := ws.ReadMessage()
require.NoError(t, err)
// Verify success response
assert.Equal(t, byte(PacketTypeSuccess), resp[0])
// Verify task cancelled
updatedTask, err := tq.GetTask("task-1")
require.NoError(t, err)
assert.Equal(t, "cancelled", updatedTask.Status)
}
func TestWSHandler_Prune(t *testing.T) {
server, tq, expManager, s := setupTestServer(t)
defer server.Close()
defer tq.Close()
defer s.Close()
// Create some experiments
_ = expManager.CreateExperiment("commit-1")
_ = expManager.CreateExperiment("commit-2")
ws := connectWS(t, server.URL)
defer ws.Close()
// Prepare prune message
// Protocol: [opcode:1][api_key_hash:64][prune_type:1][value:4]
opcode := byte(OpcodePrune)
apiKeyHash := make([]byte, 64)
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
pruneType := byte(0) // Keep N
value := uint32(1) // Keep 1
valueBytes := make([]byte, 4)
binary.BigEndian.PutUint32(valueBytes, value)
var msg []byte
msg = append(msg, opcode)
msg = append(msg, apiKeyHash...)
msg = append(msg, pruneType)
msg = append(msg, valueBytes...)
// Send message
err := ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
// Read response
_, resp, err := ws.ReadMessage()
require.NoError(t, err)
// Verify success response
assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}
func TestWSHandler_LogMetric(t *testing.T) {
server, tq, expManager, s := setupTestServer(t)
defer server.Close()
defer tq.Close()
defer s.Close()
// Create experiment
commitIDStr := strings.Repeat("a", 64)
err := expManager.CreateExperiment(commitIDStr)
require.NoError(t, err)
ws := connectWS(t, server.URL)
defer ws.Close()
// Prepare log_metric message
// Protocol: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
opcode := byte(OpcodeLogMetric)
apiKeyHash := make([]byte, 64)
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
commitID := []byte(commitIDStr)
step := uint32(100)
value := 0.95
valueBits := math.Float64bits(value)
metricName := "accuracy"
nameLen := byte(len(metricName))
stepBytes := make([]byte, 4)
binary.BigEndian.PutUint32(stepBytes, step)
valueBytes := make([]byte, 8)
binary.BigEndian.PutUint64(valueBytes, valueBits)
var msg []byte
msg = append(msg, opcode)
msg = append(msg, apiKeyHash...)
msg = append(msg, commitID...)
msg = append(msg, stepBytes...)
msg = append(msg, valueBytes...)
msg = append(msg, nameLen)
msg = append(msg, []byte(metricName)...)
// Send message
err = ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
// Read response
_, resp, err := ws.ReadMessage()
require.NoError(t, err)
// Verify success response
assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}
func TestWSHandler_GetExperiment(t *testing.T) {
server, tq, expManager, s := setupTestServer(t)
defer server.Close()
defer tq.Close()
defer s.Close()
// Create experiment and metadata
commitIDStr := strings.Repeat("a", 64)
err := expManager.CreateExperiment(commitIDStr)
require.NoError(t, err)
meta := &experiment.Metadata{
CommitID: commitIDStr,
JobName: "test-job",
}
err = expManager.WriteMetadata(meta)
require.NoError(t, err)
ws := connectWS(t, server.URL)
defer ws.Close()
// Prepare get_experiment message
// Protocol: [opcode:1][api_key_hash:64][commit_id:64]
opcode := byte(OpcodeGetExperiment)
apiKeyHash := make([]byte, 64)
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
commitID := []byte(commitIDStr)
var msg []byte
msg = append(msg, opcode)
msg = append(msg, apiKeyHash...)
msg = append(msg, commitID...)
// Send message
err = ws.WriteMessage(websocket.BinaryMessage, msg)
require.NoError(t, err)
// Read response
_, resp, err := ws.ReadMessage()
require.NoError(t, err)
// Verify success response (PacketTypeData)
assert.Equal(t, byte(PacketTypeData), resp[0])
}

258
internal/auth/api_key.go Normal file

@ -0,0 +1,258 @@
package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"log"
"net/http"
"os"
"strings"
"time"
)
// User represents an authenticated user
type User struct {
Name string `json:"name"`
Admin bool `json:"admin"`
Roles []string `json:"roles"`
Permissions map[string]bool `json:"permissions"`
}
// APIKeyHash represents a SHA256 hash of an API key
type APIKeyHash string
// APIKeyEntry represents an API key configuration
type APIKeyEntry struct {
Hash APIKeyHash `json:"hash"`
Admin bool `json:"admin"`
Roles []string `json:"roles,omitempty"`
Permissions map[string]bool `json:"permissions,omitempty"`
}
// Username represents a user identifier
type Username string
// AuthConfig represents the authentication configuration
type AuthConfig struct {
Enabled bool `json:"enabled"`
APIKeys map[Username]APIKeyEntry `json:"api_keys"`
}
// AuthStore interface for different authentication backends
type AuthStore interface {
ValidateAPIKey(ctx context.Context, key string) (*User, error)
CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error
RevokeAPIKey(ctx context.Context, userID string) error
ListUsers(ctx context.Context) ([]UserInfo, error)
}
// contextKey is the type for context keys
type contextKey string
const userContextKey = contextKey("user")
// ValidateAPIKey validates an API key and returns user information
func (c *AuthConfig) ValidateAPIKey(key string) (*User, error) {
if !c.Enabled {
// Auth disabled - return default admin user for development with full permissions
return &User{Name: "default", Admin: true, Permissions: map[string]bool{"*": true}}, nil
}
keyHash := HashAPIKey(key)
for username, entry := range c.APIKeys {
if string(entry.Hash) == keyHash {
// Build user with role and permission inheritance
user := &User{
Name: string(username),
Admin: entry.Admin,
Roles: entry.Roles,
Permissions: make(map[string]bool),
}
// Copy explicit permissions
for perm, value := range entry.Permissions {
user.Permissions[perm] = value
}
// Add role-based permissions
rolePerms := getRolePermissions(entry.Roles)
for perm, value := range rolePerms {
if _, exists := user.Permissions[perm]; !exists {
user.Permissions[perm] = value
}
}
// Admin gets all permissions
if entry.Admin {
user.Permissions["*"] = true
}
return user, nil
}
}
return nil, fmt.Errorf("invalid API key")
}
// AuthMiddleware creates HTTP middleware for API key authentication
func (c *AuthConfig) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !c.Enabled {
if os.Getenv("FETCH_ML_ALLOW_INSECURE_AUTH") != "1" || os.Getenv("FETCH_ML_DEBUG") != "1" {
http.Error(w, "Unauthorized: Authentication disabled", http.StatusUnauthorized)
return
}
log.Println("WARNING: Insecure authentication bypass enabled: FETCH_ML_ALLOW_INSECURE_AUTH=1 and FETCH_ML_DEBUG=1; do NOT use this configuration in production.")
ctx := context.WithValue(r.Context(), userContextKey, &User{Name: "default", Admin: true, Permissions: map[string]bool{"*": true}})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Only accept API key from header - no query parameters for security
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
http.Error(w, "Unauthorized: Missing API key in X-API-Key header", http.StatusUnauthorized)
return
}
user, err := c.ValidateAPIKey(apiKey)
if err != nil {
http.Error(w, "Unauthorized: Invalid API key", http.StatusUnauthorized)
return
}
// Add user to context
ctx := context.WithValue(r.Context(), userContextKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetUserFromContext retrieves user from request context
func GetUserFromContext(ctx context.Context) *User {
if user, ok := ctx.Value(userContextKey).(*User); ok {
return user
}
return nil
}
// RequireAdmin creates middleware that requires admin privileges
func RequireAdmin(next http.Handler) http.Handler {
return RequirePermission("system:admin")(next)
}
// RequirePermission creates middleware that requires specific permission
func RequirePermission(permission string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := GetUserFromContext(r.Context())
if user == nil {
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
return
}
if !user.HasPermission(permission) {
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}
// HasPermission checks if user has a specific permission
func (u *User) HasPermission(permission string) bool {
// Wildcard permission grants all access
if u.Permissions["*"] {
return true
}
// Direct permission check
if u.Permissions[permission] {
return true
}
// Hierarchical permission check (e.g., "jobs:create" matches "jobs")
parts := strings.Split(permission, ":")
for i := 1; i < len(parts); i++ {
partial := strings.Join(parts[:i], ":")
if u.Permissions[partial] {
return true
}
}
return false
}
// HasRole checks if user has a specific role
func (u *User) HasRole(role string) bool {
for _, userRole := range u.Roles {
if userRole == role {
return true
}
}
return false
}
// getRolePermissions returns permissions for given roles
func getRolePermissions(roles []string) map[string]bool {
permissions := make(map[string]bool)
// Use YAML permission manager if available
if pm := GetGlobalPermissionManager(); pm != nil && pm.loaded {
for _, role := range roles {
rolePerms := pm.GetRolePermissions(role)
for perm, value := range rolePerms {
permissions[perm] = value
}
}
return permissions
}
// Fallback to inline permissions
rolePermissions := map[string]map[string]bool{
"admin": {"*": true},
"data_scientist": {
"jobs:create": true, "jobs:read": true, "jobs:update": true,
"data:read": true, "models:read": true,
},
"data_engineer": {
"data:create": true, "data:read": true, "data:update": true, "data:delete": true,
},
"viewer": {
"jobs:read": true, "data:read": true, "models:read": true, "metrics:read": true,
},
"operator": {
"jobs:read": true, "jobs:update": true, "metrics:read": true, "system:read": true,
},
}
for _, role := range roles {
if rolePerms, exists := rolePermissions[role]; exists {
for perm, value := range rolePerms {
permissions[perm] = value
}
}
}
return permissions
}
// GenerateAPIKey generates a new random API key
func GenerateAPIKey() string {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
// crypto/rand should essentially never fail; fall back to a time-derived digest rather than aborting
return fmt.Sprintf("%x", sha256.Sum256([]byte(time.Now().String())))
}
return hex.EncodeToString(buf)
}
// HashAPIKey creates a SHA256 hash of an API key
func HashAPIKey(key string) string {
hash := sha256.Sum256([]byte(key))
return hex.EncodeToString(hash[:])
}
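// Usage sketch (illustrative only; authCfg is an *AuthConfig loaded from
// configuration, and jobsHandler/adminHandler and the routes are hypothetical):
//
//	mux := http.NewServeMux()
//	mux.Handle("/jobs", RequirePermission("jobs:read")(jobsHandler))
//	mux.Handle("/admin", RequireAdmin(adminHandler))
//	http.ListenAndServe(":9100", authCfg.AuthMiddleware(mux))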


@ -0,0 +1,229 @@
package auth
import (
"testing"
)
func TestHashAPIKey(t *testing.T) {
tests := []struct {
name string
key string
expected string
}{
{
name: "known hash",
key: "password",
expected: "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8",
},
{
name: "another known hash",
key: "test",
expected: "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := HashAPIKey(tt.key)
if got != tt.expected {
t.Errorf("HashAPIKey() = %v, want %v", got, tt.expected)
}
})
}
}
func TestHashAPIKeyConsistency(t *testing.T) {
key := "my-secret-key"
hash1 := HashAPIKey(key)
hash2 := HashAPIKey(key)
if hash1 != hash2 {
t.Errorf("HashAPIKey() not consistent: %v != %v", hash1, hash2)
}
if len(hash1) != 64 {
t.Errorf("HashAPIKey() wrong length: got %d, want 64", len(hash1))
}
}
func TestGenerateAPIKey(t *testing.T) {
// Test that it generates keys
key1 := GenerateAPIKey()
if len(key1) != 64 {
t.Errorf("GenerateAPIKey() length = %d, want 64", len(key1))
}
// Test uniqueness (keys come from crypto/rand, so two calls should differ)
key2 := GenerateAPIKey()
if key1 == key2 {
t.Errorf("GenerateAPIKey() not unique: both generated %s", key1)
}
}
func TestUserHasPermission(t *testing.T) {
tests := []struct {
name string
user *User
permission string
want bool
}{
{
name: "wildcard grants all",
user: &User{
Permissions: map[string]bool{"*": true},
},
permission: "anything",
want: true,
},
{
name: "direct permission",
user: &User{
Permissions: map[string]bool{"jobs:create": true},
},
permission: "jobs:create",
want: true,
},
{
name: "hierarchical permission match",
user: &User{
Permissions: map[string]bool{"jobs": true},
},
permission: "jobs:create",
want: true,
},
{
name: "no permission",
user: &User{
Permissions: map[string]bool{"jobs:read": true},
},
permission: "jobs:create",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.user.HasPermission(tt.permission)
if got != tt.want {
t.Errorf("HasPermission() = %v, want %v", got, tt.want)
}
})
}
}
func TestUserHasRole(t *testing.T) {
tests := []struct {
name string
user *User
role string
want bool
}{
{
name: "has role",
user: &User{
Roles: []string{"admin", "user"},
},
role: "admin",
want: true,
},
{
name: "does not have role",
user: &User{
Roles: []string{"user"},
},
role: "admin",
want: false,
},
{
name: "empty roles",
user: &User{
Roles: []string{},
},
role: "admin",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.user.HasRole(tt.role)
if got != tt.want {
t.Errorf("HasRole() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthConfigValidateAPIKey(t *testing.T) {
config := &AuthConfig{
Enabled: true,
APIKeys: map[Username]APIKeyEntry{
"testuser": {
Hash: APIKeyHash(HashAPIKey("test-key")),
Admin: false,
Roles: []string{"user"},
Permissions: map[string]bool{
"jobs:read": true,
},
},
"admin": {
Hash: APIKeyHash(HashAPIKey("admin-key")),
Admin: true,
},
},
}
tests := []struct {
name string
key string
wantErr bool
wantAdmin bool
}{
{
name: "valid user key",
key: "test-key",
wantErr: false,
wantAdmin: false,
},
{
name: "valid admin key",
key: "admin-key",
wantErr: false,
wantAdmin: true,
},
{
name: "invalid key",
key: "wrong-key",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, err := config.ValidateAPIKey(tt.key)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateAPIKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && user.Admin != tt.wantAdmin {
t.Errorf("ValidateAPIKey() admin = %v, want %v", user.Admin, tt.wantAdmin)
}
})
}
}
func TestAuthConfigDisabled(t *testing.T) {
config := &AuthConfig{
Enabled: false,
}
user, err := config.ValidateAPIKey("any-key")
if err != nil {
t.Errorf("ValidateAPIKey() with auth disabled should not error: %v", err)
}
if !user.Admin {
t.Error("ValidateAPIKey() with auth disabled should return admin user")
}
}

210
internal/auth/database.go Normal file

@ -0,0 +1,210 @@
package auth
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"time"
_ "github.com/mattn/go-sqlite3"
)
// DatabaseAuthStore implements authentication using SQLite database
type DatabaseAuthStore struct {
db *sql.DB
}
// APIKeyRecord represents an API key in the database
type APIKeyRecord struct {
ID int `json:"id"`
UserID string `json:"user_id"`
KeyHash string `json:"key_hash"`
Admin bool `json:"admin"`
Roles string `json:"roles"` // JSON array
Permissions string `json:"permissions"` // JSON object
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
RevokedAt *time.Time `json:"revoked_at,omitempty"`
}
// NewDatabaseAuthStore creates a new database-backed auth store
func NewDatabaseAuthStore(dbPath string) (*DatabaseAuthStore, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
store := &DatabaseAuthStore{db: db}
if err := store.init(); err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err)
}
return store, nil
}
// init creates the necessary database tables
func (s *DatabaseAuthStore) init() error {
query := `
CREATE TABLE IF NOT EXISTS api_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL UNIQUE,
key_hash TEXT NOT NULL UNIQUE,
admin BOOLEAN NOT NULL DEFAULT FALSE,
roles TEXT NOT NULL DEFAULT '[]',
permissions TEXT NOT NULL DEFAULT '{}',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME,
revoked_at DATETIME,
CHECK (json_valid(roles)),
CHECK (json_valid(permissions))
);
CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_api_keys_user ON api_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(revoked_at, COALESCE(expires_at, '9999-12-31'));
`
_, err := s.db.Exec(query)
return err
}
// ValidateAPIKey checks if an API key is valid and returns user info
func (s *DatabaseAuthStore) ValidateAPIKey(ctx context.Context, key string) (*User, error) {
keyHash := HashAPIKey(key)
query := `
SELECT user_id, admin, roles, permissions, expires_at, revoked_at
FROM api_keys
WHERE key_hash = ?
AND (revoked_at IS NULL OR revoked_at > ?)
AND (expires_at IS NULL OR expires_at > ?)
`
var userID string
var admin bool
var rolesJSON, permissionsJSON string
var expiresAt, revokedAt sql.NullTime
now := time.Now()
err := s.db.QueryRowContext(ctx, query, keyHash, now, now).Scan(
&userID, &admin, &rolesJSON, &permissionsJSON, &expiresAt, &revokedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("invalid API key")
}
return nil, fmt.Errorf("database error: %w", err)
}
// Parse roles
var roles []string
if err := json.Unmarshal([]byte(rolesJSON), &roles); err != nil {
log.Printf("Failed to parse roles for user %s: %v", userID, err)
roles = []string{}
}
// Parse permissions
var permissions map[string]bool
if err := json.Unmarshal([]byte(permissionsJSON), &permissions); err != nil {
log.Printf("Failed to parse permissions for user %s: %v", userID, err)
permissions = make(map[string]bool)
}
// Admin gets all permissions
if admin {
permissions["*"] = true
}
return &User{
Name: userID,
Admin: admin,
Roles: roles,
Permissions: permissions,
}, nil
}
// CreateAPIKey creates a new API key in the database
func (s *DatabaseAuthStore) CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error {
rolesJSON, err := json.Marshal(roles)
if err != nil {
return fmt.Errorf("failed to marshal roles: %w", err)
}
permissionsJSON, err := json.Marshal(permissions)
if err != nil {
return fmt.Errorf("failed to marshal permissions: %w", err)
}
query := `
INSERT INTO api_keys (user_id, key_hash, admin, roles, permissions, expires_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
key_hash = excluded.key_hash,
admin = excluded.admin,
roles = excluded.roles,
permissions = excluded.permissions,
expires_at = excluded.expires_at,
revoked_at = NULL
`
_, err = s.db.ExecContext(ctx, query, userID, keyHash, admin, rolesJSON, permissionsJSON, expiresAt)
return err
}
// RevokeAPIKey revokes an API key
func (s *DatabaseAuthStore) RevokeAPIKey(ctx context.Context, userID string) error {
query := `UPDATE api_keys SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?`
_, err := s.db.ExecContext(ctx, query, userID)
return err
}
// ListUsers returns all active users
func (s *DatabaseAuthStore) ListUsers(ctx context.Context) ([]APIKeyRecord, error) {
query := `
SELECT id, user_id, key_hash, admin, roles, permissions, created_at, expires_at, revoked_at
FROM api_keys
WHERE revoked_at IS NULL
ORDER BY created_at DESC
`
rows, err := s.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query users: %w", err)
}
defer rows.Close()
var users []APIKeyRecord
for rows.Next() {
var user APIKeyRecord
err := rows.Scan(
&user.ID, &user.UserID, &user.KeyHash, &user.Admin,
&user.Roles, &user.Permissions, &user.CreatedAt,
&user.ExpiresAt, &user.RevokedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan user: %w", err)
}
users = append(users, user)
}
return users, nil
}
// CleanupExpiredKeys removes expired and revoked keys
func (s *DatabaseAuthStore) CleanupExpiredKeys(ctx context.Context) error {
query := `
DELETE FROM api_keys
WHERE (revoked_at IS NOT NULL AND revoked_at < datetime('now', '-30 days'))
OR (expires_at IS NOT NULL AND expires_at < datetime('now', '-7 days'))
`
_, err := s.db.ExecContext(ctx, query)
return err
}
// Close closes the database connection
func (s *DatabaseAuthStore) Close() error {
return s.db.Close()
}
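// Usage sketch (illustrative; the database path, user name, roles and
// permissions are examples):
//
//	store, err := NewDatabaseAuthStore("db/fetch_ml.db")
//	if err != nil { /* handle error */ }
//	defer store.Close()
//
//	rawKey := GenerateAPIKey()
//	err = store.CreateAPIKey(ctx, "alice", HashAPIKey(rawKey), false,
//		[]string{"data_scientist"}, map[string]bool{"jobs:create": true}, nil)
//
//	user, err := store.ValidateAPIKey(ctx, rawKey)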

122
internal/auth/flags.go Normal file

@ -0,0 +1,122 @@
package auth
import (
"flag"
"fmt"
"log"
"os"
"strings"
)
// AuthFlags holds authentication-related command line flags
type AuthFlags struct {
APIKey string
APIKeyFile string
ConfigFile string
EnableAuth bool
ShowHelp bool
}
// ParseAuthFlags parses authentication command line flags
func ParseAuthFlags() *AuthFlags {
flags := &AuthFlags{}
flag.StringVar(&flags.APIKey, "api-key", "", "API key for authentication")
flag.StringVar(&flags.APIKeyFile, "api-key-file", "", "Path to file containing API key")
flag.StringVar(&flags.ConfigFile, "config", "", "Configuration file path")
flag.BoolVar(&flags.EnableAuth, "enable-auth", false, "Enable authentication")
flag.BoolVar(&flags.ShowHelp, "auth-help", false, "Show authentication help")
// Suppress the default flag usage output; authentication help is printed via PrintAuthHelp instead
flag.Usage = func() {}
flag.Parse()
return flags
}
// GetAPIKeyFromSources gets API key from multiple sources in priority order
func GetAPIKeyFromSources(flags *AuthFlags) string {
// 1. Command line flag (highest priority)
if flags.APIKey != "" {
return flags.APIKey
}
// 2. Explicit file flag
if flags.APIKeyFile != "" {
contents, readErr := os.ReadFile(flags.APIKeyFile)
if readErr == nil {
return strings.TrimSpace(string(contents))
}
log.Printf("Warning: Could not read API key file %s: %v", flags.APIKeyFile, readErr)
}
// 3. Environment variable
if envKey := os.Getenv("FETCH_ML_API_KEY"); envKey != "" {
return envKey
}
// 4. File-based key (for automated scripts)
if fileKey := os.Getenv("FETCH_ML_API_KEY_FILE"); fileKey != "" {
content, err := os.ReadFile(fileKey)
if err == nil {
return strings.TrimSpace(string(content))
}
log.Printf("Warning: Could not read API key file %s: %v", fileKey, err)
}
return ""
}
// ValidateAuthFlags validates parsed authentication flags
func ValidateAuthFlags(flags *AuthFlags) error {
if flags.ShowHelp {
PrintAuthHelp()
os.Exit(0)
}
if flags.APIKeyFile != "" {
if _, err := os.Stat(flags.APIKeyFile); err != nil {
return fmt.Errorf("api key file not found: %s", flags.APIKeyFile)
}
if err := CheckConfigFilePermissions(flags.APIKeyFile); err != nil {
log.Printf("Warning: %v", err)
}
}
// If config file is specified, check if it exists
if flags.ConfigFile != "" {
if _, err := os.Stat(flags.ConfigFile); err != nil {
return fmt.Errorf("config file not found: %s", flags.ConfigFile)
}
// Check file permissions
if err := CheckConfigFilePermissions(flags.ConfigFile); err != nil {
log.Printf("Warning: %v", err)
}
}
return nil
}
// PrintAuthHelp prints authentication-specific help
func PrintAuthHelp() {
fmt.Println("Authentication Options:")
fmt.Println(" --api-key <key> API key for authentication")
fmt.Println(" --api-key-file <path> Read API key from file")
fmt.Println(" --config <file> Configuration file path")
fmt.Println(" --enable-auth Enable authentication (if disabled)")
fmt.Println(" --auth-help Show this help")
fmt.Println()
fmt.Println("Environment Variables:")
fmt.Println(" FETCH_ML_API_KEY API key for authentication")
fmt.Println(" FETCH_ML_API_KEY_FILE File containing API key")
fmt.Println(" FETCH_ML_ENV Environment (development/production)")
fmt.Println(" FETCH_ML_ALLOW_INSECURE_AUTH Allow insecure auth (dev only)")
fmt.Println()
fmt.Println("Security Notes:")
fmt.Println(" - API keys in command line may be visible in process lists")
fmt.Println(" - Environment variables are preferred for automated scripts")
fmt.Println(" - File-based keys should have restricted permissions (600)")
fmt.Println(" - Authentication is mandatory in production environments")
}
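// Typical wiring in a main package (sketch; error handling abbreviated):
//
//	flags := ParseAuthFlags()
//	if err := ValidateAuthFlags(flags); err != nil {
//		log.Fatal(err)
//	}
//	apiKey := GetAPIKeyFromSources(flags)
//	if apiKey == "" && flags.EnableAuth {
//		log.Fatal("authentication enabled but no API key provided")
//	}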

275
internal/auth/hybrid.go Normal file

@ -0,0 +1,275 @@
package auth
import (
"context"
"fmt"
"log"
"sync"
"time"
)
// HybridAuthStore combines file-based and database authentication
// Falls back to file config if database is not available
type HybridAuthStore struct {
fileStore *AuthConfig
dbStore *DatabaseAuthStore
useDB bool
mu sync.RWMutex
}
// NewHybridAuthStore creates a hybrid auth store
func NewHybridAuthStore(config *AuthConfig, dbPath string) (*HybridAuthStore, error) {
hybrid := &HybridAuthStore{
fileStore: config,
useDB: false,
}
// Try to initialize database store
if dbPath != "" {
dbStore, err := NewDatabaseAuthStore(dbPath)
if err != nil {
log.Printf("Failed to initialize database auth store, falling back to file: %v", err)
} else {
hybrid.dbStore = dbStore
hybrid.useDB = true
log.Printf("Using database authentication store")
}
}
// If database is available, migrate file-based keys to database
if hybrid.useDB && config.Enabled && len(config.APIKeys) > 0 {
if err := hybrid.migrateFileToDatabase(context.Background()); err != nil {
log.Printf("Failed to migrate file keys to database: %v", err)
}
}
return hybrid, nil
}
// ValidateAPIKey validates an API key using either database or file store
func (h *HybridAuthStore) ValidateAPIKey(ctx context.Context, key string) (*User, error) {
h.mu.RLock()
useDB := h.useDB
h.mu.RUnlock()
if useDB {
user, err := h.dbStore.ValidateAPIKey(ctx, key)
if err == nil {
return user, nil
}
// If database fails, fall back to file store
log.Printf("Database auth failed, falling back to file store: %v", err)
return h.fileStore.ValidateAPIKey(key)
}
// Use file store
return h.fileStore.ValidateAPIKey(key)
}
// CreateAPIKey creates an API key using the preferred store
func (h *HybridAuthStore) CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error {
h.mu.RLock()
useDB := h.useDB
h.mu.RUnlock()
if useDB {
err := h.dbStore.CreateAPIKey(ctx, userID, keyHash, admin, roles, permissions, expiresAt)
if err == nil {
return nil
}
// If database fails, fall back to file store
log.Printf("Database key creation failed, using file store: %v", err)
return h.createFileAPIKey(userID, keyHash, admin, roles, permissions)
}
// Use file store
return h.createFileAPIKey(userID, keyHash, admin, roles, permissions)
}
// createFileAPIKey creates an API key in the file store
func (h *HybridAuthStore) createFileAPIKey(userID string, keyHash string, admin bool, roles []string, permissions map[string]bool) error {
h.mu.Lock()
defer h.mu.Unlock()
if h.fileStore.APIKeys == nil {
h.fileStore.APIKeys = make(map[Username]APIKeyEntry)
}
h.fileStore.APIKeys[Username(userID)] = APIKeyEntry{
Hash: APIKeyHash(keyHash),
Admin: admin,
Roles: roles,
Permissions: permissions,
}
return nil
}
// RevokeAPIKey revokes an API key
func (h *HybridAuthStore) RevokeAPIKey(ctx context.Context, userID string) error {
h.mu.RLock()
useDB := h.useDB
h.mu.RUnlock()
if useDB {
err := h.dbStore.RevokeAPIKey(ctx, userID)
if err == nil {
return nil
}
log.Printf("Database key revocation failed: %v", err)
}
// Remove from file store
h.mu.Lock()
delete(h.fileStore.APIKeys, Username(userID))
h.mu.Unlock()
return nil
}
// ListUsers returns all users from the active store
func (h *HybridAuthStore) ListUsers(ctx context.Context) ([]UserInfo, error) {
h.mu.RLock()
useDB := h.useDB
h.mu.RUnlock()
if useDB {
records, err := h.dbStore.ListUsers(ctx)
if err == nil {
users := make([]UserInfo, len(records))
for i, record := range records {
users[i] = UserInfo{
UserID: record.UserID,
Admin: record.Admin,
KeyHash: record.KeyHash,
Created: record.CreatedAt,
Expires: record.ExpiresAt,
Revoked: record.RevokedAt,
}
}
return users, nil
}
log.Printf("Database user listing failed: %v", err)
}
// Use file store
return h.listFileUsers()
}
// UserInfo represents user information for listing
type UserInfo struct {
UserID string `json:"user_id"`
Admin bool `json:"admin"`
KeyHash string `json:"key_hash"`
Created time.Time `json:"created"`
Expires *time.Time `json:"expires,omitempty"`
Revoked *time.Time `json:"revoked,omitempty"`
}
// listFileUsers returns users from file store
func (h *HybridAuthStore) listFileUsers() ([]UserInfo, error) {
h.mu.RLock()
defer h.mu.RUnlock()
var users []UserInfo
for username, entry := range h.fileStore.APIKeys {
users = append(users, UserInfo{
UserID: string(username),
Admin: entry.Admin,
KeyHash: string(entry.Hash),
Created: time.Now(), // File store doesn't track creation time
})
}
return users, nil
}
// migrateFileToDatabase migrates file-based keys to database
func (h *HybridAuthStore) migrateFileToDatabase(ctx context.Context) error {
if h.fileStore == nil || len(h.fileStore.APIKeys) == 0 {
return nil
}
log.Printf("Migrating %d API keys from file to database...", len(h.fileStore.APIKeys))
for username, entry := range h.fileStore.APIKeys {
userID := string(username)
err := h.dbStore.CreateAPIKey(ctx, userID, string(entry.Hash), entry.Admin, entry.Roles, entry.Permissions, nil)
if err != nil {
log.Printf("Failed to migrate key for user %s: %v", userID, err)
continue
}
log.Printf("Migrated key for user: %s", userID)
}
log.Printf("Migration completed. Consider removing keys from config file.")
return nil
}
// SwitchToDatabase attempts to switch to database authentication
func (h *HybridAuthStore) SwitchToDatabase(dbPath string) error {
dbStore, err := NewDatabaseAuthStore(dbPath)
if err != nil {
return fmt.Errorf("failed to create database store: %w", err)
}
h.mu.Lock()
defer h.mu.Unlock()
// Close existing database if any
if h.dbStore != nil {
h.dbStore.Close()
}
h.dbStore = dbStore
h.useDB = true
// Migrate existing keys
if h.fileStore.Enabled && len(h.fileStore.APIKeys) > 0 {
if err := h.migrateFileToDatabase(context.Background()); err != nil {
log.Printf("Migration warning: %v", err)
}
}
return nil
}
// Close closes the database connection
func (h *HybridAuthStore) Close() error {
h.mu.Lock()
defer h.mu.Unlock()
if h.dbStore != nil {
return h.dbStore.Close()
}
return nil
}
// GetDatabaseStats returns database statistics
func (h *HybridAuthStore) GetDatabaseStats(ctx context.Context) (map[string]interface{}, error) {
h.mu.RLock()
useDB := h.useDB
h.mu.RUnlock()
if !useDB {
return map[string]interface{}{
"store_type": "file",
"users": len(h.fileStore.APIKeys),
}, nil
}
users, err := h.dbStore.ListUsers(ctx)
if err != nil {
return nil, err
}
return map[string]interface{}{
"store_type": "database",
"users": len(users),
"path": "db/fetch_ml.db",
}, nil
}
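// Usage sketch (illustrative; cfg is an *AuthConfig and the SQLite path is an
// example - when the path is empty or unusable the store transparently stays on
// the file backend):
//
//	store, err := NewHybridAuthStore(cfg, "db/fetch_ml.db")
//	if err != nil { /* handle error */ }
//	defer store.Close()
//	user, err := store.ValidateAPIKey(ctx, rawKey)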

167
internal/auth/keychain.go Normal file

@ -0,0 +1,167 @@
package auth
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/zalando/go-keyring"
)
// KeychainManager provides secure storage for API keys.
type KeychainManager struct {
primary systemKeyring
fallback *fileKeyStore
}
// systemKeyring abstracts go-keyring for easier testing.
type systemKeyring interface {
Set(service, account, secret string) error
Get(service, account string) (string, error)
Delete(service, account string) error
}
type goKeyring struct{}
func (goKeyring) Set(service, account, secret string) error {
return keyring.Set(service, account, secret)
}
func (goKeyring) Get(service, account string) (string, error) {
return keyring.Get(service, account)
}
func (goKeyring) Delete(service, account string) error {
return keyring.Delete(service, account)
}
// NewKeychainManager returns a manager backed by the OS keyring with a secure file fallback.
func NewKeychainManager() *KeychainManager {
return newKeychainManagerWithKeyring(goKeyring{}, defaultFallbackDir())
}
func newKeychainManagerWithKeyring(kr systemKeyring, fallbackDir string) *KeychainManager {
if fallbackDir == "" {
fallbackDir = defaultFallbackDir()
}
return &KeychainManager{
primary: kr,
fallback: newFileKeyStore(fallbackDir),
}
}
func defaultFallbackDir() string {
home, err := os.UserHomeDir()
if err != nil || home == "" {
return filepath.Join(os.TempDir(), "fetch_ml", "keys")
}
return filepath.Join(home, ".fetch_ml", "keys")
}
// StoreAPIKey stores the key in the OS keyring, falling back to a protected file when needed.
func (km *KeychainManager) StoreAPIKey(service, account, key string) error {
if err := km.primary.Set(service, account, key); err == nil {
return nil
} else if errors.Is(err, keyring.ErrUnsupportedPlatform) || errors.Is(err, keyring.ErrNotFound) {
return km.fallback.store(service, account, key)
} else if fallbackErr := km.fallback.store(service, account, key); fallbackErr == nil {
return nil
}
return fmt.Errorf("failed to store API key")
}
// GetAPIKey retrieves a key from the OS keyring or fallback store.
func (km *KeychainManager) GetAPIKey(service, account string) (string, error) {
secret, err := km.primary.Get(service, account)
if err == nil {
return secret, nil
}
if errors.Is(err, keyring.ErrUnsupportedPlatform) || errors.Is(err, keyring.ErrNotFound) {
return km.fallback.get(service, account)
}
// Unknown error - try fallback before surfacing
if fallbackSecret, ferr := km.fallback.get(service, account); ferr == nil {
return fallbackSecret, nil
}
return "", fmt.Errorf("failed to retrieve API key")
}
// DeleteAPIKey removes a key from both stores.
func (km *KeychainManager) DeleteAPIKey(service, account string) error {
if err := km.primary.Delete(service, account); err != nil && !errors.Is(err, keyring.ErrNotFound) && !errors.Is(err, keyring.ErrUnsupportedPlatform) {
return fmt.Errorf("failed to delete API key: %w", err)
}
if err := km.fallback.delete(service, account); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
return nil
}
// IsAvailable reports whether the OS keyring backend is usable.
func (km *KeychainManager) IsAvailable() bool {
_, err := km.primary.Get("fetch_ml_probe", fmt.Sprintf("probe_%d", time.Now().UnixNano()))
return err == nil || !errors.Is(err, keyring.ErrUnsupportedPlatform)
}
// ListAvailableMethods returns backends the manager can use.
func (km *KeychainManager) ListAvailableMethods() []string {
methods := []string{}
if km.IsAvailable() {
methods = append(methods, "OS keyring")
}
methods = append(methods, fmt.Sprintf("Encrypted file (%s)", km.fallback.baseDir))
return methods
}
// fileKeyStore stores secrets with 0600 permissions as a fallback.
type fileKeyStore struct {
baseDir string
mu sync.Mutex
}
func newFileKeyStore(baseDir string) *fileKeyStore {
return &fileKeyStore{baseDir: baseDir}
}
func (f *fileKeyStore) store(service, account, secret string) error {
f.mu.Lock()
defer f.mu.Unlock()
if err := os.MkdirAll(f.baseDir, 0o700); err != nil {
return fmt.Errorf("failed to prepare key store: %w", err)
}
path := f.path(service, account)
return os.WriteFile(path, []byte(secret), 0o600)
}
func (f *fileKeyStore) get(service, account string) (string, error) {
f.mu.Lock()
defer f.mu.Unlock()
data, err := os.ReadFile(f.path(service, account))
if err != nil {
return "", err
}
return string(data), nil
}
func (f *fileKeyStore) delete(service, account string) error {
f.mu.Lock()
defer f.mu.Unlock()
path := f.path(service, account)
if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
return nil
}
func (f *fileKeyStore) path(service, account string) string {
return filepath.Join(f.baseDir, fmt.Sprintf("%s_%s.key", sanitize(service), sanitize(account)))
}
func sanitize(value string) string {
replacer := strings.NewReplacer("/", "_", "\\", "_", "..", "_", " ", "_", "\t", "_")
return replacer.Replace(value)
}
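// Usage sketch (service and account names are examples):
//
//	km := NewKeychainManager()
//	if err := km.StoreAPIKey("fetch_ml", "alice", rawKey); err != nil { /* handle error */ }
//	key, err := km.GetAPIKey("fetch_ml", "alice")
//	_ = km.DeleteAPIKey("fetch_ml", "alice")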


@ -0,0 +1,129 @@
package auth
import (
"errors"
"os"
"path/filepath"
"testing"
"github.com/zalando/go-keyring"
)
type fakeKeyring struct {
secrets map[string]string
setErr error
getErr error
deleteErr error
}
func newFakeKeyring() *fakeKeyring {
return &fakeKeyring{secrets: make(map[string]string)}
}
func (f *fakeKeyring) Set(service, account, secret string) error {
if f.setErr != nil {
return f.setErr
}
f.secrets[key(service, account)] = secret
return nil
}
func (f *fakeKeyring) Get(service, account string) (string, error) {
if f.getErr != nil {
return "", f.getErr
}
if secret, ok := f.secrets[key(service, account)]; ok {
return secret, nil
}
return "", keyring.ErrNotFound
}
func (f *fakeKeyring) Delete(service, account string) error {
if f.deleteErr != nil {
return f.deleteErr
}
delete(f.secrets, key(service, account))
return nil
}
func key(service, account string) string {
return service + ":" + account
}
func newTestManager(t *testing.T, kr systemKeyring) (*KeychainManager, string) {
t.Helper()
baseDir := t.TempDir()
return newKeychainManagerWithKeyring(kr, baseDir), baseDir
}
func TestKeychainStoreAndGetPrimary(t *testing.T) {
kr := newFakeKeyring()
km, baseDir := newTestManager(t, kr)
if err := km.StoreAPIKey("fetch-ml", "alice", "super-secret"); err != nil {
t.Fatalf("StoreAPIKey failed: %v", err)
}
got, err := km.GetAPIKey("fetch-ml", "alice")
if err != nil {
t.Fatalf("GetAPIKey failed: %v", err)
}
if got != "super-secret" {
t.Fatalf("expected secret to be stored in primary keyring")
}
// Ensure fallback file was not created when primary succeeds
path := filepath.Join(baseDir, filepath.Base(km.fallback.path("fetch-ml", "alice")))
if _, err := os.Stat(path); !errors.Is(err, os.ErrNotExist) {
t.Fatalf("expected no fallback file, got err=%v", err)
}
}
func TestKeychainFallbackWhenUnsupported(t *testing.T) {
kr := newFakeKeyring()
kr.setErr = keyring.ErrUnsupportedPlatform
kr.getErr = keyring.ErrUnsupportedPlatform
kr.deleteErr = keyring.ErrUnsupportedPlatform
km, _ := newTestManager(t, kr)
if err := km.StoreAPIKey("fetch-ml", "bob", "fallback-secret"); err != nil {
t.Fatalf("StoreAPIKey should fallback: %v", err)
}
got, err := km.GetAPIKey("fetch-ml", "bob")
if err != nil {
t.Fatalf("GetAPIKey should use fallback: %v", err)
}
if got != "fallback-secret" {
t.Fatalf("expected fallback secret, got %s", got)
}
}
func TestKeychainDeleteRemovesFallback(t *testing.T) {
kr := newFakeKeyring()
kr.deleteErr = keyring.ErrNotFound
km, _ := newTestManager(t, kr)
if err := km.fallback.store("fetch-ml", "carol", "temp"); err != nil {
t.Fatalf("failed to seed fallback store: %v", err)
}
if err := km.DeleteAPIKey("fetch-ml", "carol"); err != nil {
t.Fatalf("DeleteAPIKey failed: %v", err)
}
if _, err := km.fallback.get("fetch-ml", "carol"); !errors.Is(err, os.ErrNotExist) {
t.Fatalf("expected fallback secret removed, err=%v", err)
}
}
func TestListAvailableMethodsIncludesFallback(t *testing.T) {
kr := newFakeKeyring()
kr.getErr = keyring.ErrUnsupportedPlatform
km, _ := newTestManager(t, kr)
methods := km.ListAvailableMethods()
if len(methods) != 1 || methods[0] == "OS keyring" {
t.Fatalf("expected only fallback method, got %v", methods)
}
}


@ -0,0 +1,192 @@
package auth
import (
"fmt"
"strings"
)
// Permission constants for type safety
const (
// Job permissions
PermissionJobsCreate = "jobs:create"
PermissionJobsRead = "jobs:read"
PermissionJobsUpdate = "jobs:update"
PermissionJobsDelete = "jobs:delete"
// Data permissions
PermissionDataCreate = "data:create"
PermissionDataRead = "data:read"
PermissionDataUpdate = "data:update"
PermissionDataDelete = "data:delete"
// Model permissions
PermissionModelsCreate = "models:create"
PermissionModelsRead = "models:read"
PermissionModelsUpdate = "models:update"
PermissionModelsDelete = "models:delete"
// System permissions
PermissionSystemConfig = "system:config"
PermissionSystemMetrics = "system:metrics"
PermissionSystemLogs = "system:logs"
PermissionSystemUsers = "system:users"
// Wildcard permission
PermissionAll = "*"
)
// Role constants
const (
RoleAdmin = "admin"
RoleDataScientist = "data_scientist"
RoleDataEngineer = "data_engineer"
RoleViewer = "viewer"
RoleOperator = "operator"
)
// PermissionGroup represents a group of related permissions
type PermissionGroup struct {
Name string
Permissions []string
Description string
}
// Built-in permission groups
var PermissionGroups = map[string]PermissionGroup{
"full_access": {
Name: "Full Access",
Permissions: []string{PermissionAll},
Description: "Complete system access",
},
"job_management": {
Name: "Job Management",
Permissions: []string{PermissionJobsCreate, PermissionJobsRead, PermissionJobsUpdate, PermissionJobsDelete},
Description: "Create, read, update, and delete ML jobs",
},
"data_access": {
Name: "Data Access",
Permissions: []string{PermissionDataRead, PermissionDataCreate, PermissionDataUpdate, PermissionDataDelete},
Description: "Access and manage datasets",
},
"readonly": {
Name: "Read Only",
Permissions: []string{PermissionJobsRead, PermissionDataRead, PermissionModelsRead, PermissionSystemMetrics},
Description: "View-only access to system resources",
},
"system_admin": {
Name: "System Administration",
Permissions: []string{PermissionSystemConfig, PermissionSystemLogs, PermissionSystemUsers, PermissionSystemMetrics},
Description: "System configuration and user management",
},
}
// GetPermissionGroup returns a permission group by name
func GetPermissionGroup(name string) (PermissionGroup, bool) {
group, exists := PermissionGroups[name]
return group, exists
}
// ValidatePermission checks if a permission string is valid
func ValidatePermission(permission string) error {
if permission == PermissionAll {
return nil
}
// Check if permission matches known patterns
validPrefixes := []string{"jobs:", "data:", "models:", "system:"}
for _, prefix := range validPrefixes {
if strings.HasPrefix(permission, prefix) {
return nil
}
}
return fmt.Errorf("invalid permission format: %s", permission)
}
// ValidateRole checks if a role is valid
func ValidateRole(role string) error {
validRoles := []string{RoleAdmin, RoleDataScientist, RoleDataEngineer, RoleViewer, RoleOperator}
for _, validRole := range validRoles {
if role == validRole {
return nil
}
}
return fmt.Errorf("invalid role: %s", role)
}
// ExpandPermissionGroups converts permission group names to actual permissions
func ExpandPermissionGroups(groups []string) ([]string, error) {
var permissions []string
for _, groupName := range groups {
if groupName == PermissionAll {
return []string{PermissionAll}, nil
}
group, exists := GetPermissionGroup(groupName)
if !exists {
return nil, fmt.Errorf("unknown permission group: %s", groupName)
}
permissions = append(permissions, group.Permissions...)
}
// Remove duplicates
unique := make(map[string]bool)
for _, perm := range permissions {
unique[perm] = true
}
result := make([]string, 0, len(unique))
for perm := range unique {
result = append(result, perm)
}
return result, nil
}
// PermissionCheckResult represents the result of a permission check
type PermissionCheckResult struct {
Allowed bool `json:"allowed"`
Permission string `json:"permission"`
User string `json:"user"`
Roles []string `json:"roles"`
Missing []string `json:"missing,omitempty"`
}
// CheckMultiplePermissions checks multiple permissions at once
func (u *User) CheckMultiplePermissions(permissions []string) []PermissionCheckResult {
results := make([]PermissionCheckResult, len(permissions))
for i, permission := range permissions {
allowed := u.HasPermission(permission)
missing := []string{}
if !allowed {
missing = []string{permission}
}
results[i] = PermissionCheckResult{
Allowed: allowed,
Permission: permission,
User: u.Name,
Roles: u.Roles,
Missing: missing,
}
}
return results
}
// GetEffectivePermissions returns all effective permissions for a user
func (u *User) GetEffectivePermissions() []string {
if u.Permissions[PermissionAll] {
return []string{PermissionAll}
}
permissions := make([]string, 0, len(u.Permissions))
for perm := range u.Permissions {
permissions = append(permissions, perm)
}
return permissions
}
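// Usage sketch (group names come from PermissionGroups above; user is an
// authenticated *User):
//
//	perms, err := ExpandPermissionGroups([]string{"job_management", "readonly"})
//	// perms holds the deduplicated union of both groups
//
//	for _, result := range user.CheckMultiplePermissions([]string{PermissionJobsCreate, PermissionDataDelete}) {
//		if !result.Allowed { /* report result.Missing */ }
//	}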


@ -0,0 +1,295 @@
package auth
import (
"fmt"
"os"
"sync"
"gopkg.in/yaml.v3"
)
// PermissionConfig represents the permissions configuration
type PermissionConfig struct {
Roles map[string]RoleConfig `yaml:"roles"`
Groups map[string]GroupConfig `yaml:"groups"`
Hierarchy map[string]HierarchyConfig `yaml:"hierarchy"`
Defaults DefaultsConfig `yaml:"defaults"`
}
// RoleConfig defines a role and its permissions
type RoleConfig struct {
Description string `yaml:"description"`
Permissions []string `yaml:"permissions"`
}
// GroupConfig defines a permission group
type GroupConfig struct {
Description string `yaml:"description"`
Inherits []string `yaml:"inherits"`
Permissions []string `yaml:"permissions"`
}
// HierarchyConfig defines resource hierarchy
type HierarchyConfig struct {
Children map[string]interface{} `yaml:"children"`
Special map[string]string `yaml:"special"`
}
// DefaultsConfig defines default settings
type DefaultsConfig struct {
NewUserRole string `yaml:"new_user_role"`
AdminUsers []string `yaml:"admin_users"`
}
// PermissionManager manages permissions from YAML file
type PermissionManager struct {
config *PermissionConfig
rolePerms map[string]map[string]bool
groupPerms map[string]map[string]bool
mu sync.RWMutex
loaded bool
}
// NewPermissionManager creates a new permission manager
func NewPermissionManager(configPath string) (*PermissionManager, error) {
pm := &PermissionManager{}
if err := pm.loadConfig(configPath); err != nil {
return nil, fmt.Errorf("failed to load permissions: %w", err)
}
return pm, nil
}
// loadConfig loads permissions from YAML file
func (pm *PermissionManager) loadConfig(configPath string) error {
pm.mu.Lock()
defer pm.mu.Unlock()
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read permissions file: %w", err)
}
var config PermissionConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return fmt.Errorf("failed to parse permissions file: %w", err)
}
pm.config = &config
pm.rolePerms = make(map[string]map[string]bool)
pm.groupPerms = make(map[string]map[string]bool)
// Process role permissions
for roleName, role := range config.Roles {
perms := make(map[string]bool)
for _, perm := range role.Permissions {
perms[perm] = true
}
pm.rolePerms[roleName] = perms
}
// Process group permissions
for groupName, group := range config.Groups {
perms := make(map[string]bool)
// Add direct permissions
for _, perm := range group.Permissions {
perms[perm] = true
}
// Inherit permissions from other roles/groups
for _, inherit := range group.Inherits {
if rolePerms, exists := pm.rolePerms[inherit]; exists {
for perm, value := range rolePerms {
perms[perm] = value
}
}
if groupPerms, exists := pm.groupPerms[inherit]; exists {
for perm, value := range groupPerms {
perms[perm] = value
}
}
}
pm.groupPerms[groupName] = perms
}
pm.loaded = true
return nil
}
// GetRolePermissions returns permissions for a role
func (pm *PermissionManager) GetRolePermissions(role string) map[string]bool {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return make(map[string]bool)
}
if perms, exists := pm.rolePerms[role]; exists {
result := make(map[string]bool)
for perm, value := range perms {
result[perm] = value
}
return result
}
return make(map[string]bool)
}
// GetGroupPermissions returns permissions for a group
func (pm *PermissionManager) GetGroupPermissions(group string) map[string]bool {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return make(map[string]bool)
}
if perms, exists := pm.groupPerms[group]; exists {
result := make(map[string]bool)
for perm, value := range perms {
result[perm] = value
}
return result
}
return make(map[string]bool)
}
// GetAllRoles returns all available roles
func (pm *PermissionManager) GetAllRoles() map[string]RoleConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return make(map[string]RoleConfig)
}
result := make(map[string]RoleConfig)
for name, role := range pm.config.Roles {
result[name] = role
}
return result
}
// GetAllGroups returns all available groups
func (pm *PermissionManager) GetAllGroups() map[string]GroupConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return make(map[string]GroupConfig)
}
result := make(map[string]GroupConfig)
for name, group := range pm.config.Groups {
result[name] = group
}
return result
}
// GetDefaultRole returns the default role for new users
func (pm *PermissionManager) GetDefaultRole() string {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded || pm.config.Defaults.NewUserRole == "" {
return "viewer"
}
return pm.config.Defaults.NewUserRole
}
// IsAdminUser checks if a username should have admin rights
func (pm *PermissionManager) IsAdminUser(username string) bool {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return false
}
for _, adminUser := range pm.config.Defaults.AdminUsers {
if adminUser == username {
return true
}
}
return false
}
// ReloadConfig reloads the permissions configuration
func (pm *PermissionManager) ReloadConfig(configPath string) error {
return pm.loadConfig(configPath)
}
// ValidatePermission checks if a permission string is valid
func (pm *PermissionManager) ValidatePermission(permission string) bool {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return false
}
// Wildcard is always valid
if permission == "*" {
return true
}
// Check if permission matches any defined role permissions
for _, rolePerms := range pm.rolePerms {
if _, exists := rolePerms[permission]; exists {
return true
}
}
// Check if permission matches any group permissions
for _, groupPerms := range pm.groupPerms {
if _, exists := groupPerms[permission]; exists {
return true
}
}
return false
}
// GetPermissionHierarchy returns the hierarchy for a resource
func (pm *PermissionManager) GetPermissionHierarchy(resource string) map[string]interface{} {
pm.mu.RLock()
defer pm.mu.RUnlock()
if !pm.loaded {
return make(map[string]interface{})
}
if hierarchy, exists := pm.config.Hierarchy[resource]; exists {
return hierarchy.Children
}
return make(map[string]interface{})
}
// Global permission manager instance
var globalPermissionManager *PermissionManager
var permissionManagerOnce sync.Once
// GetGlobalPermissionManager returns the global permission manager
func GetGlobalPermissionManager() *PermissionManager {
permissionManagerOnce.Do(func() {
// Try to load from default location
if pm, err := NewPermissionManager("configs/schema/permissions.yaml"); err == nil {
globalPermissionManager = pm
} else {
// Fallback to empty manager
globalPermissionManager = &PermissionManager{
rolePerms: make(map[string]map[string]bool),
groupPerms: make(map[string]map[string]bool),
loaded: false,
}
}
})
return globalPermissionManager
}
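// A minimal permissions.yaml sketch matching the structs above (role, group and
// user names are examples, not shipped defaults):
//
//	roles:
//	  viewer:
//	    description: "Read-only access"
//	    permissions: ["jobs:read", "metrics:read"]
//	groups:
//	  ops:
//	    description: "Operations"
//	    inherits: ["viewer"]
//	    permissions: ["jobs:update"]
//	defaults:
//	  new_user_role: viewer
//	  admin_users: ["root"]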

100
internal/auth/validator.go Normal file

@ -0,0 +1,100 @@
package auth
import (
"fmt"
"log"
"os"
"strings"
)
// ValidateAuthConfig enforces authentication requirements
func (c *AuthConfig) ValidateAuthConfig() error {
// Check if we're in production environment
isProduction := os.Getenv("FETCH_ML_ENV") == "production"
if isProduction {
if !c.Enabled {
return fmt.Errorf("authentication must be enabled in production environment")
}
if len(c.APIKeys) == 0 {
return fmt.Errorf("at least one API key must be configured in production")
}
// Ensure at least one admin user exists
hasAdmin := false
for _, entry := range c.APIKeys {
if entry.Admin {
hasAdmin = true
break
}
}
if !hasAdmin {
return fmt.Errorf("at least one admin user must be configured in production")
}
// Check for insecure development override
if os.Getenv("FETCH_ML_ALLOW_INSECURE_AUTH") == "1" {
log.Printf("WARNING: FETCH_ML_ALLOW_INSECURE_AUTH is enabled in production - this is insecure")
}
}
// Validate API key format
for username, entry := range c.APIKeys {
if string(username) == "" {
return fmt.Errorf("empty username not allowed")
}
if entry.Hash == "" {
return fmt.Errorf("user %s has empty API key hash", username)
}
// Validate hash format (should be 64 hex chars for SHA256)
if len(entry.Hash) != 64 {
return fmt.Errorf("user %s has invalid API key hash format", username)
}
// Check hash contains only hex characters
for _, char := range entry.Hash {
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) {
return fmt.Errorf("user %s has invalid API key hash characters", username)
}
}
}
return nil
}
// CheckConfigFilePermissions ensures config files have secure permissions
func CheckConfigFilePermissions(configPath string) error {
info, err := os.Stat(configPath)
if err != nil {
return fmt.Errorf("cannot stat config file: %w", err)
}
// Check file permissions: allow owner access plus at most group read (600 or 640)
perm := info.Mode().Perm()
if perm&0037 != 0 {
return fmt.Errorf("config file %s has insecure permissions: %o (should be 600 or 640)", configPath, perm)
}
return nil
}
// SanitizeConfig removes sensitive information for logging
func (c *AuthConfig) SanitizeConfig() map[string]interface{} {
sanitized := map[string]interface{}{
"enabled": c.Enabled,
"users": make(map[string]interface{}),
}
for username := range c.APIKeys {
sanitized["users"].(map[string]interface{})[string(username)] = map[string]interface{}{
"admin": c.APIKeys[username].Admin,
"hash": strings.Repeat("*", 8) + "...", // Show only prefix
}
}
return sanitized
}
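// Startup sketch (illustrative; authCfg and configPath come from the caller):
//
//	if err := CheckConfigFilePermissions(configPath); err != nil {
//		log.Printf("Warning: %v", err)
//	}
//	if err := authCfg.ValidateAuthConfig(); err != nil {
//		log.Fatalf("auth config invalid: %v", err)
//	}
//	log.Printf("auth config: %+v", authCfg.SanitizeConfig())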


@ -0,0 +1,54 @@
package config
// Default configuration values (legacy - use SmartDefaults for new code)
const (
DefaultSSHPort = 22
DefaultRedisPort = 6379
DefaultRedisAddr = "localhost:6379"
DefaultBasePath = "/mnt/nas/jobs"
DefaultTrainScript = "train.py"
DefaultDataDir = "/data/active"
DefaultLocalDataDir = "./data/active"
DefaultNASDataDir = "/mnt/datasets"
DefaultMaxWorkers = 2
DefaultPollInterval = 5
DefaultMaxAgeHours = 24
DefaultMaxSizeGB = 100
DefaultCleanupInterval = 60
)
// Redis key prefixes
const (
RedisTaskQueueKey = "ml:queue"
RedisTaskPrefix = "ml:task:"
RedisJobMetricsPrefix = "ml:metrics:"
RedisTaskStatusPrefix = "ml:status:"
RedisDatasetPrefix = "ml:dataset:"
RedisWorkerHeartbeat = "ml:workers:heartbeat"
)
// Task status constants
const (
TaskStatusQueued = "queued"
TaskStatusRunning = "running"
TaskStatusCompleted = "completed"
TaskStatusFailed = "failed"
TaskStatusCancelled = "cancelled"
)
// Job status constants
const (
JobStatusPending = "pending"
JobStatusQueued = "queued"
JobStatusRunning = "running"
JobStatusFinished = "finished"
JobStatusFailed = "failed"
)
// Podman defaults
const (
DefaultPodmanMemory = "8g"
DefaultPodmanCPUs = "2"
DefaultContainerWorkspace = "/workspace"
DefaultContainerResults = "/workspace/results"
)

73
internal/config/paths.go Normal file
View file

@ -0,0 +1,73 @@
// Package config provides shared utilities for the fetch_ml project.
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// ExpandPath expands environment variables and tilde in a path
func ExpandPath(path string) string {
if path == "" {
return ""
}
expanded := os.ExpandEnv(path)
if strings.HasPrefix(expanded, "~") {
home, err := os.UserHomeDir()
if err == nil {
expanded = filepath.Join(home, expanded[1:])
}
}
return expanded
}
// ResolveConfigPath resolves a config file path, checking multiple locations
func ResolveConfigPath(path string) (string, error) {
candidates := []string{path}
if !filepath.IsAbs(path) {
candidates = append(candidates, filepath.Join("configs", path))
}
var checked []string
for _, candidate := range candidates {
resolved := ExpandPath(candidate)
checked = append(checked, resolved)
if _, err := os.Stat(resolved); err == nil {
return resolved, nil
}
}
return "", fmt.Errorf("config file not found (looked in %s)", strings.Join(checked, ", "))
}
// JobPaths provides helper methods for job directory paths
type JobPaths struct {
BasePath string
}
// NewJobPaths creates a new JobPaths instance
func NewJobPaths(basePath string) *JobPaths {
return &JobPaths{BasePath: basePath}
}
// PendingPath returns the path to pending jobs directory
func (j *JobPaths) PendingPath() string {
return filepath.Join(j.BasePath, "pending")
}
// RunningPath returns the path to running jobs directory
func (j *JobPaths) RunningPath() string {
return filepath.Join(j.BasePath, "running")
}
// FinishedPath returns the path to finished jobs directory
func (j *JobPaths) FinishedPath() string {
return filepath.Join(j.BasePath, "finished")
}
// FailedPath returns the path to failed jobs directory
func (j *JobPaths) FailedPath() string {
return filepath.Join(j.BasePath, "failed")
}
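
A small usage sketch for `JobPaths`, assuming the job directories should be pre-created with 0755 like elsewhere in this commit; the base path is an example.

```go
package main

import (
	"log"
	"os"

	"github.com/jfraeys/fetch_ml/internal/config"
)

func main() {
	paths := config.NewJobPaths(config.ExpandPath("~/ml-jobs"))
	for _, dir := range []string{
		paths.PendingPath(), paths.RunningPath(),
		paths.FinishedPath(), paths.FailedPath(),
	} {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			log.Fatalf("create %s: %v", dir, err)
		}
	}
}
```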

View file

@ -0,0 +1,222 @@
package config
import (
"os"
"path/filepath"
"runtime"
"strings"
)
// EnvironmentProfile represents the deployment environment
type EnvironmentProfile int
const (
ProfileLocal EnvironmentProfile = iota
ProfileContainer
ProfileCI
ProfileProduction
)
// DetectEnvironment determines the current environment profile
func DetectEnvironment() EnvironmentProfile {
// CI detection
if os.Getenv("CI") != "" || os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("GITLAB_CI") != "" {
return ProfileCI
}
// Container detection
if _, err := os.Stat("/.dockerenv"); err == nil {
return ProfileContainer
}
if os.Getenv("KUBERNETES_SERVICE_HOST") != "" {
return ProfileContainer
}
if os.Getenv("CONTAINER") != "" {
return ProfileContainer
}
// Production detection (customizable)
if os.Getenv("FETCH_ML_ENV") == "production" || os.Getenv("ENV") == "production" {
return ProfileProduction
}
// Default to local development
return ProfileLocal
}
// SmartDefaults provides environment-aware default values
type SmartDefaults struct {
Profile EnvironmentProfile
}
// GetSmartDefaults returns defaults for the current environment
func GetSmartDefaults() *SmartDefaults {
return &SmartDefaults{
Profile: DetectEnvironment(),
}
}
// Host returns the appropriate default host
func (s *SmartDefaults) Host() string {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "host.docker.internal" // Docker Desktop/Colima
case ProfileProduction:
return "0.0.0.0"
default: // ProfileLocal
return "localhost"
}
}
// BasePath returns the appropriate default base path
func (s *SmartDefaults) BasePath() string {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/ml-experiments"
case ProfileProduction:
return "/var/lib/fetch_ml/experiments"
default: // ProfileLocal
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, "ml-experiments")
}
return "./ml-experiments"
}
}
// DataDir returns the appropriate default data directory
func (s *SmartDefaults) DataDir() string {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/data"
case ProfileProduction:
return "/var/lib/fetch_ml/data"
default: // ProfileLocal
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, "ml-data")
}
return "./data"
}
}
// RedisAddr returns the appropriate default Redis address
func (s *SmartDefaults) RedisAddr() string {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "redis:6379" // Service name in containers
case ProfileProduction:
return "redis:6379"
default: // ProfileLocal
return "localhost:6379"
}
}
// SSHKeyPath returns the appropriate default SSH key path
func (s *SmartDefaults) SSHKeyPath() string {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/.ssh/id_rsa"
case ProfileProduction:
return "/etc/fetch_ml/ssh/id_rsa"
default: // ProfileLocal
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".ssh", "id_rsa")
}
return "~/.ssh/id_rsa"
}
}
// KnownHostsPath returns the appropriate default known_hosts path
func (s *SmartDefaults) KnownHostsPath() string {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/.ssh/known_hosts"
case ProfileProduction:
return "/etc/fetch_ml/ssh/known_hosts"
default: // ProfileLocal
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".ssh", "known_hosts")
}
return "~/.ssh/known_hosts"
}
}
// LogLevel returns the appropriate default log level
func (s *SmartDefaults) LogLevel() string {
switch s.Profile {
case ProfileCI:
return "debug" // More verbose for CI debugging
case ProfileProduction:
return "info"
default: // ProfileLocal, ProfileContainer
return "info"
}
}
// MaxWorkers returns the appropriate default worker count
func (s *SmartDefaults) MaxWorkers() int {
switch s.Profile {
case ProfileCI:
return 1 // Conservative for CI
case ProfileProduction:
return runtime.NumCPU() // Scale with CPU cores
default: // ProfileLocal, ProfileContainer
return 2 // Reasonable default for local dev
}
}
// PollInterval returns the appropriate default poll interval in seconds
func (s *SmartDefaults) PollInterval() int {
switch s.Profile {
case ProfileCI:
return 1 // Fast polling for quick tests
case ProfileProduction:
return 10 // Conservative for production
default: // ProfileLocal, ProfileContainer
return 5 // Balanced default
}
}
// IsInContainer returns true if running in a container environment
func (s *SmartDefaults) IsInContainer() bool {
return s.Profile == ProfileContainer || s.Profile == ProfileCI
}
// IsProduction returns true if this is a production environment
func (s *SmartDefaults) IsProduction() bool {
return s.Profile == ProfileProduction
}
// IsCI returns true if this is a CI environment
func (s *SmartDefaults) IsCI() bool {
return s.Profile == ProfileCI
}
// ExpandPath expands ~ and environment variables in paths
func (s *SmartDefaults) ExpandPath(path string) string {
if strings.HasPrefix(path, "~/") {
if home, err := os.UserHomeDir(); err == nil {
path = filepath.Join(home, path[2:])
}
}
// Expand environment variables
path = os.ExpandEnv(path)
return path
}
// GetEnvironmentDescription returns a human-readable description
func (s *SmartDefaults) GetEnvironmentDescription() string {
switch s.Profile {
case ProfileLocal:
return "Local Development"
case ProfileContainer:
return "Container Environment"
case ProfileCI:
return "CI/CD Environment"
case ProfileProduction:
return "Production Environment"
default:
return "Unknown Environment"
}
}
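
An illustrative sketch of querying the smart defaults at startup; output values depend on the detected environment.

```go
package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/config"
)

func main() {
	d := config.GetSmartDefaults()
	fmt.Println("environment:", d.GetEnvironmentDescription())
	fmt.Println("base path:  ", d.BasePath())
	fmt.Println("redis addr: ", d.RedisAddr())
	fmt.Println("workers:    ", d.MaxWorkers())
}
```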

View file

@ -0,0 +1,69 @@
// Package config provides shared configuration utilities for the fetch_ml
// project, including path helpers, environment-aware defaults, and
// validation functions.
package config
import (
"fmt"
"net"
"os"
)
// Validator is an interface for types that can validate themselves.
type Validator interface {
Validate() error
}
// ValidateConfig validates a configuration struct that implements the Validator interface.
func ValidateConfig(v Validator) error {
return v.Validate()
}
// ValidatePort checks if a port number is within the valid range (1-65535).
func ValidatePort(port int) error {
if port < 1 || port > 65535 {
return fmt.Errorf("invalid port: %d (must be 1-65535)", port)
}
return nil
}
// ValidateDirectory checks if a path exists and is a directory.
func ValidateDirectory(path string) error {
if path == "" {
return fmt.Errorf("path cannot be empty")
}
expanded := ExpandPath(path)
info, err := os.Stat(expanded)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("directory does not exist: %s", expanded)
}
return fmt.Errorf("cannot access directory %s: %w", expanded, err)
}
if !info.IsDir() {
return fmt.Errorf("path is not a directory: %s", expanded)
}
return nil
}
// ValidateRedisAddr validates a Redis address in the format "host:port".
func ValidateRedisAddr(addr string) error {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("invalid redis address format: %w", err)
}
if host == "" {
return fmt.Errorf("redis host cannot be empty")
}
var portInt int
if _, err := fmt.Sscanf(port, "%d", &portInt); err != nil {
return fmt.Errorf("invalid port number %q: %w", port, err)
}
return ValidatePort(portInt)
}
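
A short sketch showing how these validators might be combined; the address, port, and directory are examples only.

```go
package main

import (
	"log"

	"github.com/jfraeys/fetch_ml/internal/config"
)

func main() {
	if err := config.ValidateRedisAddr("localhost:6379"); err != nil {
		log.Fatalf("redis addr: %v", err)
	}
	if err := config.ValidatePort(9100); err != nil {
		log.Fatalf("port: %v", err)
	}
	// ValidateDirectory expands ~ and $VARS before checking.
	if err := config.ValidateDirectory("~/ml-experiments"); err != nil {
		log.Printf("warning: %v", err)
	}
}
```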

View file

@ -0,0 +1,105 @@
// Package container provides helpers for running ML experiments in Podman containers.
package container
import (
"fmt"
"os/exec"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/config"
)
// PodmanConfig holds configuration for Podman container execution
type PodmanConfig struct {
Image string
Workspace string
Results string
ContainerWorkspace string
ContainerResults string
GPUAccess bool
Memory string
CPUs string
}
// BuildPodmanCommand builds a Podman command for executing ML experiments
func BuildPodmanCommand(cfg PodmanConfig, scriptPath, requirementsPath string, extraArgs []string) *exec.Cmd {
args := []string{
"run", "--rm",
"--security-opt", "no-new-privileges",
"--cap-drop", "ALL",
}
if cfg.Memory != "" {
args = append(args, "--memory", cfg.Memory)
} else {
args = append(args, "--memory", config.DefaultPodmanMemory)
}
if cfg.CPUs != "" {
args = append(args, "--cpus", cfg.CPUs)
} else {
args = append(args, "--cpus", config.DefaultPodmanCPUs)
}
args = append(args, "--userns", "keep-id")
// Mount workspace
workspaceMount := fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace)
args = append(args, "-v", workspaceMount)
// Mount results
resultsMount := fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults)
args = append(args, "-v", resultsMount)
if cfg.GPUAccess {
args = append(args, "--device", "/dev/dri")
}
// Image and command
args = append(args, cfg.Image,
"--workspace", cfg.ContainerWorkspace,
"--requirements", requirementsPath,
"--script", scriptPath,
)
// Add extra arguments via --args flag
if len(extraArgs) > 0 {
args = append(args, "--args")
args = append(args, extraArgs...)
}
return exec.Command("podman", args...)
}
// SanitizePath ensures a path is safe to use (prevents path traversal)
func SanitizePath(path string) (string, error) {
// Clean the path to remove any .. or . components
cleaned := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleaned, "..") {
return "", fmt.Errorf("path traversal detected: %s", path)
}
return cleaned, nil
}
// ValidateJobName validates a job name is safe
func ValidateJobName(jobName string) error {
if jobName == "" {
return fmt.Errorf("job name cannot be empty")
}
// Check for dangerous characters
if strings.ContainsAny(jobName, "/\\<>:\"|?*") {
return fmt.Errorf("job name contains invalid characters: %s", jobName)
}
// Check for path traversal
if strings.Contains(jobName, "..") {
return fmt.Errorf("job name contains path traversal: %s", jobName)
}
return nil
}
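
A hedged usage sketch for `BuildPodmanCommand`; the image name and host paths are placeholders, not project defaults.

```go
package main

import (
	"log"

	"github.com/jfraeys/fetch_ml/internal/container"
)

func main() {
	// Placeholder image and mounts for illustration only.
	cfg := container.PodmanConfig{
		Image:              "fetch-ml-runner:latest",
		Workspace:          "/srv/jobs/demo/files",
		Results:            "/srv/jobs/demo/results",
		ContainerWorkspace: "/workspace",
		ContainerResults:   "/workspace/results",
		GPUAccess:          false,
	}
	cmd := container.BuildPodmanCommand(cfg, "train.py", "requirements.txt",
		[]string{"--epochs", "10"})
	log.Printf("would run: %s", cmd.String())
}
```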

39
internal/errors/errors.go Normal file
View file

@ -0,0 +1,39 @@
// Package errors defines typed errors shared across the fetch_ml worker.
package errors
import (
"fmt"
)
// DataFetchError represents an error that occurred while fetching a dataset
// from the NAS to the ML server.
type DataFetchError struct {
Dataset string
JobName string
Err error
}
func (e *DataFetchError) Error() string {
return fmt.Sprintf("failed to fetch dataset %s for job %s: %v",
e.Dataset, e.JobName, e.Err)
}
func (e *DataFetchError) Unwrap() error {
return e.Err
}
type TaskExecutionError struct {
TaskID string
JobName string
Phase string // "data_fetch", "execution", "cleanup"
Err error
}
func (e *TaskExecutionError) Error() string {
id := e.TaskID
if len(id) > 8 {
id = id[:8] // show a short prefix for long task IDs
}
return fmt.Sprintf("task %s (%s) failed during %s: %v",
id, e.JobName, e.Phase, e.Err)
}
func (e *TaskExecutionError) Unwrap() error {
return e.Err
}
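
A brief sketch of how callers could unwrap these typed errors with `errors.As`; the dataset, job name, and rsync message are illustrative.

```go
package main

import (
	"errors"
	"fmt"

	ferrors "github.com/jfraeys/fetch_ml/internal/errors"
)

func main() {
	err := fmt.Errorf("worker loop: %w", &ferrors.DataFetchError{
		Dataset: "cifar10",
		JobName: "resnet-baseline",
		Err:     errors.New("rsync exited with status 12"),
	})

	var dfErr *ferrors.DataFetchError
	if errors.As(err, &dfErr) {
		fmt.Println("retryable data fetch failure for:", dfErr.Dataset)
	}
}
```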

View file

@ -0,0 +1,343 @@
package experiment
import (
"encoding/binary"
"fmt"
"math"
"os"
"path/filepath"
"time"
)
// Metadata represents experiment metadata stored in meta.bin
type Metadata struct {
CommitID string
Timestamp int64
JobName string
User string
}
// Manager handles experiment storage and metadata
type Manager struct {
basePath string
}
func NewManager(basePath string) *Manager {
return &Manager{
basePath: basePath,
}
}
// Initialize ensures the experiment directory exists
func (m *Manager) Initialize() error {
if err := os.MkdirAll(m.basePath, 0755); err != nil {
return fmt.Errorf("failed to create experiment base directory: %w", err)
}
return nil
}
// GetExperimentPath returns the path for a given commit ID
func (m *Manager) GetExperimentPath(commitID string) string {
return filepath.Join(m.basePath, commitID)
}
// GetFilesPath returns the path to the files directory for an experiment
func (m *Manager) GetFilesPath(commitID string) string {
return filepath.Join(m.GetExperimentPath(commitID), "files")
}
// GetMetadataPath returns the path to meta.bin for an experiment
func (m *Manager) GetMetadataPath(commitID string) string {
return filepath.Join(m.GetExperimentPath(commitID), "meta.bin")
}
// ExperimentExists checks if an experiment with the given commit ID exists
func (m *Manager) ExperimentExists(commitID string) bool {
path := m.GetExperimentPath(commitID)
info, err := os.Stat(path)
return err == nil && info.IsDir()
}
// CreateExperiment creates the directory structure for a new experiment
func (m *Manager) CreateExperiment(commitID string) error {
filesPath := m.GetFilesPath(commitID)
if err := os.MkdirAll(filesPath, 0755); err != nil {
return fmt.Errorf("failed to create experiment directory: %w", err)
}
return nil
}
// WriteMetadata writes experiment metadata to meta.bin
func (m *Manager) WriteMetadata(meta *Metadata) error {
path := m.GetMetadataPath(meta.CommitID)
// Binary format (each string is length-prefixed and limited to 255 bytes):
// [version:1][timestamp:8][commit_id_len:1][commit_id:var][job_name_len:1][job_name:var][user_len:1][user:var]
if len(meta.CommitID) > 255 || len(meta.JobName) > 255 || len(meta.User) > 255 {
return fmt.Errorf("metadata field exceeds 255 bytes")
}
buf := make([]byte, 0, 256)
// Version
buf = append(buf, 0x01)
// Timestamp
ts := make([]byte, 8)
binary.BigEndian.PutUint64(ts, uint64(meta.Timestamp))
buf = append(buf, ts...)
// Commit ID
buf = append(buf, byte(len(meta.CommitID)))
buf = append(buf, []byte(meta.CommitID)...)
// Job Name
buf = append(buf, byte(len(meta.JobName)))
buf = append(buf, []byte(meta.JobName)...)
// User
buf = append(buf, byte(len(meta.User)))
buf = append(buf, []byte(meta.User)...)
return os.WriteFile(path, buf, 0644)
}
// ReadMetadata reads experiment metadata from meta.bin
func (m *Manager) ReadMetadata(commitID string) (*Metadata, error) {
path := m.GetMetadataPath(commitID)
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read metadata: %w", err)
}
if len(data) < 10 {
return nil, fmt.Errorf("metadata file too short")
}
meta := &Metadata{}
offset := 0
// Version
version := data[offset]
offset++
if version != 0x01 {
return nil, fmt.Errorf("unsupported metadata version: %d", version)
}
// Timestamp
meta.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8]))
offset += 8
// Commit ID
commitIDLen := int(data[offset])
offset++
if offset+commitIDLen > len(data) {
return nil, fmt.Errorf("metadata file truncated")
}
meta.CommitID = string(data[offset : offset+commitIDLen])
offset += commitIDLen
// Job Name
if offset >= len(data) {
return meta, nil
}
jobNameLen := int(data[offset])
offset++
if offset+jobNameLen > len(data) {
return meta, nil
}
meta.JobName = string(data[offset : offset+jobNameLen])
offset += jobNameLen
// User
if offset >= len(data) {
return meta, nil
}
userLen := int(data[offset])
offset++
if offset+userLen > len(data) {
return meta, nil
}
meta.User = string(data[offset : offset+userLen])
return meta, nil
}
// ListExperiments returns all experiment commit IDs
func (m *Manager) ListExperiments() ([]string, error) {
entries, err := os.ReadDir(m.basePath)
if err != nil {
return nil, fmt.Errorf("failed to read experiments directory: %w", err)
}
var commitIDs []string
for _, entry := range entries {
if entry.IsDir() {
commitIDs = append(commitIDs, entry.Name())
}
}
return commitIDs, nil
}
// PruneExperiments removes old experiments based on retention policy
func (m *Manager) PruneExperiments(keepCount int, olderThanDays int) ([]string, error) {
commitIDs, err := m.ListExperiments()
if err != nil {
return nil, err
}
type experiment struct {
commitID string
timestamp int64
}
var experiments []experiment
for _, commitID := range commitIDs {
meta, err := m.ReadMetadata(commitID)
if err != nil {
continue // Skip experiments with invalid metadata
}
experiments = append(experiments, experiment{
commitID: commitID,
timestamp: meta.Timestamp,
})
}
// Sort by timestamp (newest first)
for i := 0; i < len(experiments); i++ {
for j := i + 1; j < len(experiments); j++ {
if experiments[j].timestamp > experiments[i].timestamp {
experiments[i], experiments[j] = experiments[j], experiments[i]
}
}
}
var pruned []string
cutoffTime := time.Now().AddDate(0, 0, -olderThanDays).Unix()
for i, exp := range experiments {
shouldPrune := false
// Keep the newest N experiments
if i >= keepCount {
shouldPrune = true
}
// Also prune if older than threshold
if olderThanDays > 0 && exp.timestamp < cutoffTime {
shouldPrune = true
}
if shouldPrune {
expPath := m.GetExperimentPath(exp.commitID)
if err := os.RemoveAll(expPath); err != nil {
// Log but continue
continue
}
pruned = append(pruned, exp.commitID)
}
}
return pruned, nil
}
// Metric represents a single data point in an experiment
type Metric struct {
Name string `json:"name"`
Value float64 `json:"value"`
Step int `json:"step"`
Timestamp int64 `json:"timestamp"`
}
// GetMetricsPath returns the path to metrics.bin for an experiment
func (m *Manager) GetMetricsPath(commitID string) string {
return filepath.Join(m.GetExperimentPath(commitID), "metrics.bin")
}
// LogMetric appends a metric to the experiment's metrics file
func (m *Manager) LogMetric(commitID string, name string, value float64, step int) error {
path := m.GetMetricsPath(commitID)
// Binary format for each metric:
// [timestamp:8][step:4][value:8][name_len:1][name:var]
buf := make([]byte, 0, 64)
// Timestamp
ts := make([]byte, 8)
binary.BigEndian.PutUint64(ts, uint64(time.Now().Unix()))
buf = append(buf, ts...)
// Step
st := make([]byte, 4)
binary.BigEndian.PutUint32(st, uint32(step))
buf = append(buf, st...)
// Value (float64)
val := make([]byte, 8)
binary.BigEndian.PutUint64(val, math.Float64bits(value))
buf = append(buf, val...)
// Name
if len(name) > 255 {
name = name[:255]
}
buf = append(buf, byte(len(name)))
buf = append(buf, []byte(name)...)
// Append to file
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("failed to open metrics file: %w", err)
}
defer f.Close()
if _, err := f.Write(buf); err != nil {
return fmt.Errorf("failed to write metric: %w", err)
}
return nil
}
// GetMetrics reads all metrics for an experiment
func (m *Manager) GetMetrics(commitID string) ([]Metric, error) {
path := m.GetMetricsPath(commitID)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return []Metric{}, nil
}
return nil, fmt.Errorf("failed to read metrics file: %w", err)
}
var metrics []Metric
offset := 0
for offset < len(data) {
if offset+21 > len(data) { // Min size check
break
}
m := Metric{}
// Timestamp
m.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8]))
offset += 8
// Step
m.Step = int(binary.BigEndian.Uint32(data[offset : offset+4]))
offset += 4
// Value
bits := binary.BigEndian.Uint64(data[offset : offset+8])
m.Value = math.Float64frombits(bits)
offset += 8
// Name
nameLen := int(data[offset])
offset++
if offset+nameLen > len(data) {
break
}
m.Name = string(data[offset : offset+nameLen])
offset += nameLen
metrics = append(metrics, m)
}
return metrics, nil
}
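
A minimal end-to-end sketch of the experiment manager, assuming a throwaway base path; the commit ID and metric values are made up.

```go
package main

import (
	"log"
	"time"

	"github.com/jfraeys/fetch_ml/internal/experiment"
)

func main() {
	m := experiment.NewManager("/tmp/ml-experiments") // placeholder path
	if err := m.Initialize(); err != nil {
		log.Fatal(err)
	}

	const commit = "abc123def456"
	if err := m.CreateExperiment(commit); err != nil {
		log.Fatal(err)
	}
	if err := m.WriteMetadata(&experiment.Metadata{
		CommitID:  commit,
		Timestamp: time.Now().Unix(),
		JobName:   "resnet-baseline",
		User:      "demo",
	}); err != nil {
		log.Fatal(err)
	}
	if err := m.LogMetric(commit, "val_loss", 0.42, 1); err != nil {
		log.Fatal(err)
	}
	metrics, _ := m.GetMetrics(commit)
	log.Printf("logged %d metric(s)", len(metrics))
}
```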

View file

@ -0,0 +1,52 @@
package logging
import (
"log/slog"
"os"
"strings"
)
// Config holds logging configuration
type Config struct {
Level string `yaml:"level"`
File string `yaml:"file"`
AuditLog string `yaml:"audit_log"`
}
// LevelFromEnv reads LOG_LEVEL (if set) and returns the matching slog level.
// Accepted values: debug, info, warn, error. Defaults to info.
func LevelFromEnv() slog.Level {
return parseLevel(os.Getenv("LOG_LEVEL"), slog.LevelInfo)
}
func parseLevel(value string, defaultLevel slog.Level) slog.Level {
switch strings.ToLower(strings.TrimSpace(value)) {
case "debug":
return slog.LevelDebug
case "warn", "warning":
return slog.LevelWarn
case "error":
return slog.LevelError
case "info", "":
return slog.LevelInfo
default:
return defaultLevel
}
}
// NewConfiguredLogger creates a logger using the level configured via LOG_LEVEL.
// JSON/text output is still controlled by LOG_FORMAT in NewLogger.
func NewConfiguredLogger() *Logger {
return NewLogger(LevelFromEnv(), false)
}
// NewLoggerFromConfig creates a logger from configuration
func NewLoggerFromConfig(cfg Config) *Logger {
level := parseLevel(cfg.Level, slog.LevelInfo)
if cfg.File != "" {
return NewFileLogger(level, false, cfg.File)
}
return NewLogger(level, false)
}

172
internal/logging/logging.go Normal file
View file

@ -0,0 +1,172 @@
package logging
import (
"context"
"io"
"log/slog"
"os"
"path/filepath"
"time"
"github.com/google/uuid"
)
type ctxKey string
const (
CtxTraceID ctxKey = "trace_id"
CtxSpanID ctxKey = "span_id"
CtxWorker ctxKey = "worker_id"
CtxJob ctxKey = "job_name"
CtxTask ctxKey = "task_id"
)
type Logger struct {
*slog.Logger
}
// NewLogger creates a logger that writes to stderr (development mode)
func NewLogger(level slog.Level, jsonOutput bool) *Logger {
opts := &slog.HandlerOptions{
Level: level,
AddSource: os.Getenv("LOG_ADD_SOURCE") == "1",
}
var handler slog.Handler
if jsonOutput || os.Getenv("LOG_FORMAT") == "json" {
handler = slog.NewJSONHandler(os.Stderr, opts)
} else {
handler = NewColorTextHandler(os.Stderr, opts)
}
return &Logger{slog.New(handler)}
}
// NewFileLogger creates a logger that writes to a file only (production mode)
func NewFileLogger(level slog.Level, jsonOutput bool, logFile string) *Logger {
opts := &slog.HandlerOptions{
Level: level,
AddSource: os.Getenv("LOG_ADD_SOURCE") == "1",
}
// Create log directory if it doesn't exist
if logFile != "" {
logDir := filepath.Dir(logFile)
if err := os.MkdirAll(logDir, 0755); err != nil {
// Fallback to stderr only if directory creation fails
return NewLogger(level, jsonOutput)
}
}
// Open log file
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
// Fallback to stderr only if file creation fails
return NewLogger(level, jsonOutput)
}
// Write to file only (production)
var handler slog.Handler
if jsonOutput || os.Getenv("LOG_FORMAT") == "json" {
handler = slog.NewJSONHandler(file, opts)
} else {
handler = slog.NewTextHandler(file, opts)
}
return &Logger{slog.New(handler)}
}
// Inject trace + span if missing
func EnsureTrace(ctx context.Context) context.Context {
if ctx.Value(CtxTraceID) == nil {
ctx = context.WithValue(ctx, CtxTraceID, uuid.NewString())
}
if ctx.Value(CtxSpanID) == nil {
ctx = context.WithValue(ctx, CtxSpanID, uuid.NewString())
}
return ctx
}
func (l *Logger) WithContext(ctx context.Context, args ...any) *Logger {
if trace := ctx.Value(CtxTraceID); trace != nil {
args = append(args, "trace_id", trace)
}
if span := ctx.Value(CtxSpanID); span != nil {
args = append(args, "span_id", span)
}
if worker := ctx.Value(CtxWorker); worker != nil {
args = append(args, "worker_id", worker)
}
if job := ctx.Value(CtxJob); job != nil {
args = append(args, "job_name", job)
}
if task := ctx.Value(CtxTask); task != nil {
args = append(args, "task_id", task)
}
return &Logger{Logger: l.With(args...)}
}
func CtxWithWorker(ctx context.Context, worker string) context.Context {
return context.WithValue(ctx, CtxWorker, worker)
}
func CtxWithJob(ctx context.Context, job string) context.Context {
return context.WithValue(ctx, CtxJob, job)
}
func CtxWithTask(ctx context.Context, task string) context.Context {
return context.WithValue(ctx, CtxTask, task)
}
func (l *Logger) Component(ctx context.Context, name string) *Logger {
return l.WithContext(ctx, "component", name)
}
func (l *Logger) Worker(ctx context.Context, workerID string) *Logger {
return l.WithContext(ctx, "worker_id", workerID)
}
func (l *Logger) Job(ctx context.Context, job string, task string) *Logger {
return l.WithContext(ctx, "job_name", job, "task_id", task)
}
func (l *Logger) Fatal(msg string, args ...any) {
l.Error(msg, args...)
os.Exit(1)
}
func (l *Logger) Panic(msg string, args ...any) {
l.Error(msg, args...)
panic(msg)
}
// -----------------------------------------------------
// Colorized human-friendly console logs
// -----------------------------------------------------
type ColorTextHandler struct {
slog.Handler
}
func NewColorTextHandler(w io.Writer, opts *slog.HandlerOptions) slog.Handler {
base := slog.NewTextHandler(w, opts)
return &ColorTextHandler{Handler: base}
}
func (h *ColorTextHandler) Handle(ctx context.Context, r slog.Record) error {
// Add uniform timestamp (override default)
r.Time = time.Now()
switch r.Level {
case slog.LevelDebug:
r.Add("lvl_color", "\033[34mDBG\033[0m")
case slog.LevelInfo:
r.Add("lvl_color", "\033[32mINF\033[0m")
case slog.LevelWarn:
r.Add("lvl_color", "\033[33mWRN\033[0m")
case slog.LevelError:
r.Add("lvl_color", "\033[31mERR\033[0m")
}
return h.Handler.Handle(ctx, r)
}
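
A short sketch of the context-aware logging flow, showing how trace and worker fields propagate; the worker and job names are examples.

```go
package main

import (
	"context"
	"log/slog"

	"github.com/jfraeys/fetch_ml/internal/logging"
)

func main() {
	logger := logging.NewLogger(slog.LevelInfo, false)

	ctx := logging.EnsureTrace(context.Background())
	ctx = logging.CtxWithWorker(ctx, "worker-1")
	ctx = logging.CtxWithJob(ctx, "resnet-baseline")

	// trace_id, span_id, worker_id and job_name are attached automatically.
	logger.WithContext(ctx).Info("task picked up", "queue_depth", 3)
}
```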

View file

@ -0,0 +1,80 @@
package logging
import (
"regexp"
"strings"
)
// Patterns for sensitive data
var (
// API keys: 32+ hex characters
apiKeyPattern = regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`)
// JWT tokens
jwtPattern = regexp.MustCompile(`eyJ[a-zA-Z0-9_-]{10,}\.eyJ[a-zA-Z0-9_-]{10,}\.[a-zA-Z0-9_-]{10,}`)
// Password-like fields in logs
passwordPattern = regexp.MustCompile(`(?i)(password|passwd|pwd|secret|token|key)["']?\s*[:=]\s*["']?([^"'\s,}]+)`)
// Redis URLs with passwords
redisPasswordPattern = regexp.MustCompile(`redis://:[^@]+@`)
)
// SanitizeLogMessage removes sensitive data from log messages
func SanitizeLogMessage(message string) string {
// Redact API keys
message = apiKeyPattern.ReplaceAllString(message, "[REDACTED-API-KEY]")
// Redact JWT tokens
message = jwtPattern.ReplaceAllString(message, "[REDACTED-JWT]")
// Redact password-like fields
message = passwordPattern.ReplaceAllStringFunc(message, func(match string) string {
parts := passwordPattern.FindStringSubmatch(match)
if len(parts) >= 2 {
return parts[1] + "=[REDACTED]"
}
return match
})
// Redact Redis passwords from URLs
message = redisPasswordPattern.ReplaceAllString(message, "redis://:[REDACTED]@")
return message
}
// SanitizeArgs removes sensitive data from structured log arguments
func SanitizeArgs(args []any) []any {
sanitized := make([]any, len(args))
copy(sanitized, args)
for i := 0; i < len(sanitized)-1; i += 2 {
// Check if this is a key-value pair
key, okKey := sanitized[i].(string)
value, okValue := sanitized[i+1].(string)
if okKey && okValue {
lowerKey := strings.ToLower(key)
// Redact sensitive fields
if strings.Contains(lowerKey, "password") ||
strings.Contains(lowerKey, "secret") ||
strings.Contains(lowerKey, "token") ||
strings.Contains(lowerKey, "key") ||
strings.Contains(lowerKey, "api") {
sanitized[i+1] = "[REDACTED]"
} else if strings.HasPrefix(value, "redis://") {
sanitized[i+1] = SanitizeLogMessage(value)
}
}
}
return sanitized
}
// RedactAPIKey masks an API key for logging (shows first/last 4 chars)
func RedactAPIKey(key string) string {
if len(key) <= 8 {
return "[REDACTED]"
}
return key[:4] + "..." + key[len(key)-4:]
}
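
A small sketch of the sanitizers in action; the secrets shown are obviously fake.

```go
package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/logging"
)

func main() {
	msg := "connecting to redis://:supersecret@redis:6379 with token=abc123def456"
	fmt.Println(logging.SanitizeLogMessage(msg))

	args := logging.SanitizeArgs([]any{"api_key", "0123456789abcdef", "job", "demo"})
	fmt.Println(args) // the api_key value is replaced with [REDACTED]
}
```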

View file

@ -0,0 +1,71 @@
// Package metrics provides lightweight in-process counters for the fetch_ml worker.
package metrics
import (
"sync/atomic"
"time"
)
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}
type Metrics struct {
TasksProcessed atomic.Int64
TasksFailed atomic.Int64
DataFetchTime atomic.Int64 // Total nanoseconds
ExecutionTime atomic.Int64
DataTransferred atomic.Int64 // Total bytes
ActiveTasks atomic.Int64
QueuedTasks atomic.Int64
}
func (m *Metrics) RecordTaskSuccess(duration time.Duration) {
m.TasksProcessed.Add(1)
m.ExecutionTime.Add(duration.Nanoseconds())
}
func (m *Metrics) RecordTaskFailure() {
m.TasksFailed.Add(1)
}
func (m *Metrics) RecordTaskStart() {
m.ActiveTasks.Add(1)
}
// RecordTaskCompletion decrements the number of active tasks. It is safe to call
// even if no tasks are currently recorded; the caller should ensure calls are
// balanced with RecordTaskStart.
func (m *Metrics) RecordTaskCompletion() {
m.ActiveTasks.Add(-1)
}
func (m *Metrics) RecordDataTransfer(bytes int64, duration time.Duration) {
m.DataTransferred.Add(bytes)
m.DataFetchTime.Add(duration.Nanoseconds())
}
func (m *Metrics) SetQueuedTasks(count int64) {
m.QueuedTasks.Store(count)
}
func (m *Metrics) GetStats() map[string]any {
processed := m.TasksProcessed.Load()
failed := m.TasksFailed.Load()
dataTransferred := m.DataTransferred.Load()
dataFetchTime := m.DataFetchTime.Load()
return map[string]any{
"tasks_processed": processed,
"tasks_failed": failed,
"active_tasks": m.ActiveTasks.Load(),
"queued_tasks": m.QueuedTasks.Load(),
"success_rate": float64(processed-failed) / float64(max(processed, 1)),
"avg_exec_time": time.Duration(m.ExecutionTime.Load() / max(processed, 1)),
"data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024),
"avg_fetch_time": time.Duration(dataFetchTime / max(processed, 1)),
}
}
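
A quick usage sketch for the in-process metrics counters; the durations and byte counts are made up.

```go
package main

import (
	"fmt"
	"time"

	"github.com/jfraeys/fetch_ml/internal/metrics"
)

func main() {
	var m metrics.Metrics

	m.RecordTaskStart()
	m.RecordTaskSuccess(90 * time.Second)
	m.RecordDataTransfer(512<<20, 12*time.Second) // 512 MiB fetched
	m.RecordTaskCompletion()

	fmt.Printf("%+v\n", m.GetStats())
}
```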

View file

@ -0,0 +1,259 @@
package middleware
import (
"context"
"log"
"net/http"
"strings"
"time"
"golang.org/x/time/rate"
)
// SecurityMiddleware provides comprehensive security features
type SecurityMiddleware struct {
rateLimiter *rate.Limiter
apiKeys map[string]bool
jwtSecret []byte
}
func NewSecurityMiddleware(apiKeys []string, jwtSecret string) *SecurityMiddleware {
keyMap := make(map[string]bool)
for _, key := range apiKeys {
keyMap[key] = true
}
return &SecurityMiddleware{
rateLimiter: rate.NewLimiter(rate.Every(time.Second), 10), // 60 requests per minute (1/s), burst of 10
apiKeys: keyMap,
jwtSecret: []byte(jwtSecret),
}
}
// Rate limiting middleware
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !sm.rateLimiter.Allow() {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
// API key authentication
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
// Also check Authorization header
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
}
}
if !sm.apiKeys[apiKey] {
http.Error(w, "Invalid API key", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
// Security headers middleware
func SecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Prevent clickjacking
w.Header().Set("X-Frame-Options", "DENY")
// Prevent MIME type sniffing
w.Header().Set("X-Content-Type-Options", "nosniff")
// Enable XSS protection
w.Header().Set("X-XSS-Protection", "1; mode=block")
// Content Security Policy
w.Header().Set("Content-Security-Policy", "default-src 'self'")
// Referrer policy
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// HSTS (HTTPS only)
if r.TLS != nil {
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
}
next.ServeHTTP(w, r)
})
}
// IP whitelist middleware
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
parsedIP := net.ParseIP(clientIP)
// Check if client IP is in whitelist
allowed := false
for _, ip := range allowedIPs {
if strings.Contains(ip, "/") {
// CIDR notation - match against the parsed network
_, ipNet, err := net.ParseCIDR(ip)
if err == nil && parsedIP != nil && ipNet.Contains(parsedIP) {
allowed = true
break
}
} else {
if clientIP == ip {
allowed = true
break
}
}
}
if !allowed {
http.Error(w, "IP not whitelisted", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}
// CORS middleware with restrictive defaults
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Only allow specific origins in production
allowedOrigins := []string{
"https://ml-experiments.example.com",
"https://app.example.com",
}
isAllowed := false
for _, allowed := range allowedOrigins {
if origin == allowed {
isAllowed = true
break
}
}
if isAllowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "86400")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
// Request timeout middleware
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}
// Request size limiter
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > maxSize {
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
return
}
next.ServeHTTP(w, r)
})
}
}
// Security audit logging
func AuditLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
path := r.URL.Path
raw := r.URL.RawQuery
// Wrap response writer to capture status code
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
// Process request
next.ServeHTTP(wrapped, r)
// Log after processing
latency := time.Since(start)
clientIP := getClientIP(r)
method := r.Method
statusCode := wrapped.statusCode
if raw != "" {
path = path + "?" + raw
}
// Log security-relevant events
if statusCode >= 400 || method == "DELETE" || strings.Contains(path, "/admin") {
// Log to security audit system
logSecurityEvent(map[string]interface{}{
"timestamp": start.Unix(),
"client_ip": clientIP,
"method": method,
"path": path,
"status": statusCode,
"latency": latency,
"user_agent": r.UserAgent(),
"referer": r.Referer(),
})
}
})
}
// Helper to get client IP
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}
// Response writer wrapper to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
func logSecurityEvent(event map[string]interface{}) {
// Implementation would send to security monitoring system
// For now, just log (in production, use proper logging)
log.Printf("SECURITY AUDIT: %s %s %s %v", event["client_ip"], event["method"], event["path"], event["status"])
}
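
A hedged sketch of one possible middleware chain for the API server; the listen address, API key, and JWT secret are placeholders, and the ordering is a design choice rather than a requirement.

```go
package main

import (
	"net/http"
	"time"

	"github.com/jfraeys/fetch_ml/internal/middleware"
)

func main() {
	sm := middleware.NewSecurityMiddleware([]string{"example-api-key"}, "example-jwt-secret")

	mux := http.NewServeMux()
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Outermost middleware runs first.
	handler := middleware.SecurityHeaders(
		middleware.AuditLogger(
			middleware.RequestTimeout(30*time.Second)(
				middleware.RequestSizeLimit(1<<20)(
					sm.RateLimit(sm.APIKeyAuth(mux)),
				),
			),
		),
	)
	_ = http.ListenAndServe(":9100", handler)
}
```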

73
internal/network/retry.go Normal file
View file

@ -0,0 +1,73 @@
// Package network provides retry helpers and SSH utilities for the fetch_ml project.
package network
import (
"context"
"math"
"time"
)
type RetryConfig struct {
MaxAttempts int
InitialDelay time.Duration
MaxDelay time.Duration
Multiplier float64
}
func DefaultRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Second,
MaxDelay: 30 * time.Second,
Multiplier: 2.0,
}
}
func Retry(ctx context.Context, cfg RetryConfig, fn func() error) error {
var lastErr error
delay := cfg.InitialDelay
for attempt := 0; attempt < cfg.MaxAttempts; attempt++ {
if err := fn(); err == nil {
return nil
} else {
lastErr = err
}
if attempt < cfg.MaxAttempts-1 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
delay = time.Duration(math.Min(
float64(delay)*cfg.Multiplier,
float64(cfg.MaxDelay),
))
}
}
}
return lastErr
}
// RetryWithBackoff provides a convenient wrapper for common retry scenarios
func RetryWithBackoff(ctx context.Context, maxAttempts int, operation func() error) error {
cfg := RetryConfig{
MaxAttempts: maxAttempts,
InitialDelay: 200 * time.Millisecond,
MaxDelay: 2 * time.Second,
Multiplier: 2.0,
}
return Retry(ctx, cfg, operation)
}
// RetryForNetworkOperations is optimized for network-related operations
func RetryForNetworkOperations(ctx context.Context, operation func() error) error {
cfg := RetryConfig{
MaxAttempts: 5,
InitialDelay: 200 * time.Millisecond,
MaxDelay: 5 * time.Second,
Multiplier: 1.5,
}
return Retry(ctx, cfg, operation)
}
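
A usage sketch of the retry helpers wrapping an idempotent HTTP health check; the endpoint is a placeholder.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"github.com/jfraeys/fetch_ml/internal/network"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Placeholder endpoint; any idempotent operation works here.
	err := network.RetryForNetworkOperations(ctx, func() error {
		resp, err := http.Get("http://localhost:9100/health")
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("unexpected status: %d", resp.StatusCode)
		}
		return nil
	})
	if err != nil {
		fmt.Println("gave up after retries:", err)
	}
}
```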

304
internal/network/ssh.go Normal file
View file

@ -0,0 +1,304 @@
// SSH client supporting remote execution over SSH and a local fallback mode.
package network
import (
"context"
"fmt"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/config"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)
// SSHClient provides SSH connection and command execution
type SSHClient struct {
client *ssh.Client
host string
}
// NewSSHClient creates a new SSH client. If host or keyPath is empty, returns a local-mode client.
// knownHostsPath is optional - if provided, will use known_hosts verification
func NewSSHClient(host, user, keyPath string, port int, knownHostsPath string) (*SSHClient, error) {
if host == "" || keyPath == "" {
// Local mode - no SSH connection needed
return &SSHClient{client: nil, host: ""}, nil
}
keyPath = config.ExpandPath(keyPath)
if strings.HasPrefix(keyPath, "~") {
home, _ := os.UserHomeDir()
keyPath = filepath.Join(home, keyPath[1:])
}
key, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read SSH key: %w", err)
}
var signer ssh.Signer
if signer, err = ssh.ParsePrivateKey(key); err != nil {
if _, ok := err.(*ssh.PassphraseMissingError); ok {
// Try to use ssh-agent for passphrase-protected keys
if agentSigner, agentErr := sshAgentSigner(); agentErr == nil {
signer = agentSigner
} else {
return nil, fmt.Errorf("SSH key is passphrase protected and ssh-agent unavailable: %w", err)
}
} else {
return nil, fmt.Errorf("failed to parse SSH key: %w", err)
}
}
hostKeyCallback := ssh.InsecureIgnoreHostKey()
if knownHostsPath != "" {
knownHostsPath = config.ExpandPath(knownHostsPath)
if _, err := os.Stat(knownHostsPath); err == nil {
callback, err := knownhosts.New(knownHostsPath)
if err != nil {
log.Printf("Warning: failed to parse known_hosts: %v; using insecure host key verification", err)
} else {
hostKeyCallback = callback
}
} else {
log.Printf("Warning: known_hosts not usable at %s (%v); using insecure host key verification", knownHostsPath, err)
}
}
sshConfig := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: hostKeyCallback,
Timeout: 10 * time.Second,
HostKeyAlgorithms: []string{
ssh.KeyAlgoRSA,
ssh.KeyAlgoRSASHA256,
ssh.KeyAlgoRSASHA512,
ssh.KeyAlgoED25519,
ssh.KeyAlgoECDSA256,
ssh.KeyAlgoECDSA384,
ssh.KeyAlgoECDSA521,
},
}
addr := fmt.Sprintf("%s:%d", host, port)
client, err := ssh.Dial("tcp", addr, sshConfig)
if err != nil {
return nil, fmt.Errorf("SSH connection failed: %w", err)
}
return &SSHClient{client: client, host: host}, nil
}
// Exec executes a command remotely via SSH or locally if in local mode
func (c *SSHClient) Exec(cmd string) (string, error) {
return c.ExecContext(context.Background(), cmd)
}
// ExecContext executes a command with context support for cancellation and timeouts
func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) {
if c.client == nil {
// Local mode - execute command locally with context
execCmd := exec.CommandContext(ctx, "sh", "-c", cmd)
output, err := execCmd.CombinedOutput()
return string(output), err
}
session, err := c.client.NewSession()
if err != nil {
return "", fmt.Errorf("create session: %w", err)
}
defer func() {
if closeErr := session.Close(); closeErr != nil {
// Session may already be closed, so we just log at debug level
log.Printf("session close error (may be expected): %v", closeErr)
}
}()
// Run command with context cancellation
type result struct {
output string
err error
}
resultCh := make(chan result, 1)
go func() {
output, err := session.CombinedOutput(cmd)
resultCh <- result{string(output), err}
}()
select {
case <-ctx.Done():
// FIXED: Check error return value
if sigErr := session.Signal(ssh.SIGTERM); sigErr != nil {
log.Printf("failed to send SIGTERM: %v", sigErr)
}
// Give process time to cleanup gracefully
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
select {
case res := <-resultCh:
// Command finished during graceful shutdown
return res.output, fmt.Errorf("command cancelled: %w (output: %s)", ctx.Err(), res.output)
case <-timer.C:
if closeErr := session.Close(); closeErr != nil {
log.Printf("failed to force close session: %v", closeErr)
}
// Wait a bit more for final result
select {
case res := <-resultCh:
return res.output, fmt.Errorf("command cancelled and force closed: %w (output: %s)", ctx.Err(), res.output)
case <-time.After(5 * time.Second):
return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
}
}
case res := <-resultCh:
return res.output, res.err
}
}
// FileExists checks if a file exists remotely or locally
func (c *SSHClient) FileExists(path string) bool {
if c.client == nil {
// Local mode - check file locally
_, err := os.Stat(path)
return err == nil
}
out, err := c.Exec(fmt.Sprintf("test -e %s && echo 'exists'", path))
if err != nil {
return false
}
return strings.Contains(strings.TrimSpace(out), "exists")
}
// GetFileSize gets the size of a file or directory remotely or locally
func (c *SSHClient) GetFileSize(path string) (int64, error) {
if c.client == nil {
// Local mode - get size locally
var size int64
err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
size += info.Size()
return nil
})
return size, err
}
out, err := c.Exec(fmt.Sprintf("du -sb %s | cut -f1", path))
if err != nil {
return 0, err
}
var size int64
if _, err := fmt.Sscanf(strings.TrimSpace(out), "%d", &size); err != nil {
return 0, fmt.Errorf("failed to parse file size from output %q: %w", out, err)
}
return size, nil
}
// RemoteExists checks if a remote path exists (alias for FileExists for compatibility)
func (c *SSHClient) RemoteExists(path string) bool {
return c.FileExists(path)
}
// ListDir lists directory contents remotely or locally
func (c *SSHClient) ListDir(path string) []string {
if c.client == nil {
// Local mode
entries, err := os.ReadDir(path)
if err != nil {
return nil
}
var items []string
for _, entry := range entries {
items = append(items, entry.Name())
}
return items
}
out, err := c.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", path))
if err != nil {
return nil
}
var items []string
for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
if line != "" {
items = append(items, line)
}
}
return items
}
// TailFile gets the last N lines of a file remotely or locally
func (c *SSHClient) TailFile(path string, lines int) string {
if c.client == nil {
// Local mode - read file and return last N lines
data, err := os.ReadFile(path)
if err != nil {
return ""
}
fileLines := strings.Split(string(data), "\n")
if len(fileLines) > lines {
fileLines = fileLines[len(fileLines)-lines:]
}
return strings.Join(fileLines, "\n")
}
out, err := c.Exec(fmt.Sprintf("tail -n %d %s 2>/dev/null", lines, path))
if err != nil {
return ""
}
return out
}
// Close closes the SSH connection
func (c *SSHClient) Close() error {
if c.client != nil {
return c.client.Close()
}
return nil
}
// sshAgentSigner attempts to get a signer from ssh-agent
func sshAgentSigner() (ssh.Signer, error) {
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
if sshAuthSock == "" {
return nil, fmt.Errorf("SSH_AUTH_SOCK not set")
}
conn, err := net.Dial("unix", sshAuthSock)
if err != nil {
return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err)
}
defer func() {
if closeErr := conn.Close(); closeErr != nil {
log.Printf("warning: failed to close ssh-agent connection: %v", closeErr)
}
}()
agentClient := agent.NewClient(conn)
signers, err := agentClient.Signers()
if err != nil {
return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
}
if len(signers) == 0 {
return nil, fmt.Errorf("no signers available in ssh-agent")
}
return signers[0], nil
}
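
An illustrative sketch of the SSH client; the host, user, key, and dataset path are placeholders, and passing an empty host or key path falls back to local mode as described above.

```go
package main

import (
	"log"

	"github.com/jfraeys/fetch_ml/internal/network"
)

func main() {
	// Placeholder connection details; empty host/key would give local mode.
	client, err := network.NewSSHClient(
		"nas.local", "mluser", "~/.ssh/id_rsa", 22, "~/.ssh/known_hosts")
	if err != nil {
		log.Fatalf("ssh: %v", err)
	}
	defer client.Close()

	if client.FileExists("/mnt/datasets/cifar10") {
		size, _ := client.GetFileSize("/mnt/datasets/cifar10")
		log.Printf("dataset present, %d bytes", size)
	}
}
```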

84
internal/network/ssh_pool.go Executable file
View file

@ -0,0 +1,84 @@
// SSH connection pool with bounded concurrency.
package network
import (
"context"
"sync"
"github.com/jfraeys/fetch_ml/internal/logging"
)
type SSHPool struct {
factory func() (*SSHClient, error)
pool chan *SSHClient
active int
maxConns int
mu sync.Mutex
logger *logging.Logger
}
func NewSSHPool(maxConns int, factory func() (*SSHClient, error), logger *logging.Logger) *SSHPool {
return &SSHPool{
factory: factory,
pool: make(chan *SSHClient, maxConns),
maxConns: maxConns,
logger: logger,
}
}
func (p *SSHPool) Get(ctx context.Context) (*SSHClient, error) {
select {
case conn := <-p.pool:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
default:
p.mu.Lock()
if p.active < p.maxConns {
p.active++
p.mu.Unlock()
conn, err := p.factory()
if err != nil {
p.mu.Lock()
p.active-- // factory failed; release the reserved slot
p.mu.Unlock()
}
return conn, err
}
p.mu.Unlock()
// Wait for available connection
select {
case conn := <-p.pool:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func (p *SSHPool) Put(conn *SSHClient) {
select {
case p.pool <- conn:
default:
// Pool is full, close connection
err := conn.Close()
if err != nil {
p.logger.Warn("failed to close SSH connection", "error", err)
}
p.mu.Lock()
p.active--
p.mu.Unlock()
}
}
func (p *SSHPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
// Close all connections in the pool
close(p.pool)
for conn := range p.pool {
err := conn.Close()
if err != nil {
p.logger.Warn("failed to close SSH connection", "error", err)
}
}
// Reset active count
p.active = 0
}
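
A small sketch of the pool in local mode; the factory deliberately passes empty host and key so no real SSH connection is made.

```go
package main

import (
	"context"
	"log/slog"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/network"
)

func main() {
	logger := logging.NewLogger(slog.LevelInfo, false)

	// Local-mode factory for illustration; real factories would dial a host.
	pool := network.NewSSHPool(4, func() (*network.SSHClient, error) {
		return network.NewSSHClient("", "", "", 22, "")
	}, logger)
	defer pool.Close()

	conn, err := pool.Get(context.Background())
	if err != nil {
		logger.Fatal("get connection", "error", err)
	}
	out, _ := conn.Exec("echo hello")
	logger.Info("command output", "out", out)
	pool.Put(conn)
}
```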

215
internal/queue/errors.go Normal file
View file

@ -0,0 +1,215 @@
package queue
import (
"errors"
"fmt"
"strings"
)
// ErrorCategory represents the type of error encountered
type ErrorCategory string
const (
ErrorNetwork ErrorCategory = "network" // Network connectivity issues
ErrorResource ErrorCategory = "resource" // Resource exhaustion (OOM, disk full)
ErrorRateLimit ErrorCategory = "rate_limit" // Rate limiting or throttling
ErrorAuth ErrorCategory = "auth" // Authentication/authorization failures
ErrorValidation ErrorCategory = "validation" // Input validation errors
ErrorTimeout ErrorCategory = "timeout" // Operation timeout
ErrorPermanent ErrorCategory = "permanent" // Non-retryable errors
ErrorUnknown ErrorCategory = "unknown" // Unclassified errors
)
// TaskError wraps an error with category and context
type TaskError struct {
Category ErrorCategory
Message string
Cause error
Context map[string]string
}
func (e *TaskError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("[%s] %s: %v", e.Category, e.Message, e.Cause)
}
return fmt.Sprintf("[%s] %s", e.Category, e.Message)
}
func (e *TaskError) Unwrap() error {
return e.Cause
}
// NewTaskError creates a new categorized error
func NewTaskError(category ErrorCategory, message string, cause error) *TaskError {
return &TaskError{
Category: category,
Message: message,
Cause: cause,
Context: make(map[string]string),
}
}
// ClassifyError categorizes an error for retry logic
func ClassifyError(err error) ErrorCategory {
if err == nil {
return ErrorUnknown
}
// Check if already classified
var taskErr *TaskError
if errors.As(err, &taskErr) {
return taskErr.Category
}
errStr := strings.ToLower(err.Error())
// Network errors (retryable)
networkIndicators := []string{
"connection refused",
"connection reset",
"connection timeout",
"no route to host",
"network unreachable",
"temporary failure",
"dns",
"dial tcp",
"i/o timeout",
}
for _, indicator := range networkIndicators {
if strings.Contains(errStr, indicator) {
return ErrorNetwork
}
}
// Resource errors (retryable after delay)
resourceIndicators := []string{
"out of memory",
"oom",
"no space left",
"disk full",
"resource temporarily unavailable",
"too many open files",
"cannot allocate memory",
}
for _, indicator := range resourceIndicators {
if strings.Contains(errStr, indicator) {
return ErrorResource
}
}
// Rate limiting (retryable with backoff)
rateLimitIndicators := []string{
"rate limit",
"too many requests",
"throttle",
"quota exceeded",
"429",
}
for _, indicator := range rateLimitIndicators {
if strings.Contains(errStr, indicator) {
return ErrorRateLimit
}
}
// Timeout errors (retryable)
timeoutIndicators := []string{
"timeout",
"deadline exceeded",
"context deadline",
}
for _, indicator := range timeoutIndicators {
if strings.Contains(errStr, indicator) {
return ErrorTimeout
}
}
// Authentication errors (not retryable)
authIndicators := []string{
"unauthorized",
"forbidden",
"authentication failed",
"invalid credentials",
"access denied",
"401",
"403",
}
for _, indicator := range authIndicators {
if strings.Contains(errStr, indicator) {
return ErrorAuth
}
}
// Validation errors (not retryable)
validationIndicators := []string{
"invalid input",
"validation failed",
"bad request",
"malformed",
"400",
}
for _, indicator := range validationIndicators {
if strings.Contains(errStr, indicator) {
return ErrorValidation
}
}
// Default to unknown
return ErrorUnknown
}
// IsRetryable determines if an error category should be retried
func IsRetryable(category ErrorCategory) bool {
switch category {
case ErrorNetwork, ErrorResource, ErrorRateLimit, ErrorTimeout, ErrorUnknown:
return true
case ErrorAuth, ErrorValidation, ErrorPermanent:
return false
default:
return false
}
}
// GetUserMessage returns a user-friendly error message with suggestions
func GetUserMessage(category ErrorCategory, err error) string {
messages := map[ErrorCategory]string{
ErrorNetwork: "Network connectivity issue. Please check your network connection and try again.",
ErrorResource: "System resource exhausted. The system may be under heavy load. Try again later or contact support.",
ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.",
ErrorAuth: "Authentication failed. Please check your API key or credentials.",
ErrorValidation: "Invalid input. Please review your request and correct any errors.",
ErrorTimeout: "Operation timed out. The task may be too complex or the system is slow. Try again or simplify the request.",
ErrorPermanent: "A permanent error occurred. This task cannot be retried automatically.",
ErrorUnknown: "An unexpected error occurred. If this persists, please contact support.",
}
baseMsg := messages[category]
if err != nil {
return fmt.Sprintf("%s (Details: %v)", baseMsg, err)
}
return baseMsg
}
// RetryDelay calculates the retry delay based on error category and retry count
func RetryDelay(category ErrorCategory, retryCount int) int {
switch category {
case ErrorRateLimit:
// Longer backoff for rate limits
return min(300, 10*(1<<retryCount)) // 10s, 20s, 40s, 80s, up to 300s
case ErrorResource:
// Medium backoff for resource issues
return min(120, 5*(1<<retryCount)) // 5s, 10s, 20s, 40s, up to 120s
case ErrorNetwork, ErrorTimeout:
// Standard exponential backoff
return 1 << retryCount // 1s, 2s, 4s, 8s, etc
default:
// Default exponential backoff
return 1 << retryCount
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
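
A brief sketch tying the classification helpers together; the error string is a typical Redis dial failure.

```go
package main

import (
	"errors"
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/queue"
)

func main() {
	err := errors.New("dial tcp 10.0.0.5:6379: i/o timeout")

	category := queue.ClassifyError(err)
	fmt.Println("category:    ", category)
	fmt.Println("retryable:   ", queue.IsRetryable(category))
	fmt.Println("retry in (s):", queue.RetryDelay(category, 2))
	fmt.Println(queue.GetUserMessage(category, err))
}
```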

118
internal/queue/metrics.go Normal file
View file

@ -0,0 +1,118 @@
package queue
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
// Queue metrics
QueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
Name: "fetch_ml_queue_depth",
Help: "Number of tasks in the queue",
})
TasksQueued = promauto.NewCounter(prometheus.CounterOpts{
Name: "fetch_ml_tasks_queued_total",
Help: "Total number of tasks queued",
})
// Task execution metrics
TaskDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "fetch_ml_task_duration_seconds",
Help: "Task execution duration in seconds",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600}, // 1s to 1h
}, []string{"job_name", "status"})
TasksCompleted = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "fetch_ml_tasks_completed_total",
Help: "Total number of completed tasks",
}, []string{"job_name", "status"})
// Error metrics
TaskFailures = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "fetch_ml_task_failures_total",
Help: "Total number of failed tasks by error category",
}, []string{"job_name", "error_category"})
TaskRetries = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "fetch_ml_task_retries_total",
Help: "Total number of task retries",
}, []string{"job_name", "error_category"})
// Lease metrics
LeaseExpirations = promauto.NewCounter(prometheus.CounterOpts{
Name: "fetch_ml_lease_expirations_total",
Help: "Total number of expired leases reclaimed",
})
LeaseRenewals = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "fetch_ml_lease_renewals_total",
Help: "Total number of successful lease renewals",
}, []string{"worker_id"})
// Dead letter queue metrics
DLQSize = promauto.NewGauge(prometheus.GaugeOpts{
Name: "fetch_ml_dlq_size",
Help: "Number of tasks in dead letter queue",
})
DLQAdditions = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "fetch_ml_dlq_additions_total",
Help: "Total number of tasks moved to DLQ",
}, []string{"reason"})
// Worker metrics
ActiveTasks = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "fetch_ml_active_tasks",
Help: "Number of currently executing tasks",
}, []string{"worker_id"})
WorkerHeartbeats = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "fetch_ml_worker_heartbeats_total",
Help: "Total number of worker heartbeats",
}, []string{"worker_id"})
)
// RecordTaskStart records when a task starts
func RecordTaskStart(jobName, workerID string) {
ActiveTasks.WithLabelValues(workerID).Inc()
}
// RecordTaskEnd records when a task completes
func RecordTaskEnd(jobName, workerID, status string, durationSeconds float64) {
ActiveTasks.WithLabelValues(workerID).Dec()
TaskDuration.WithLabelValues(jobName, status).Observe(durationSeconds)
TasksCompleted.WithLabelValues(jobName, status).Inc()
}
// RecordTaskFailure records a task failure with error category
func RecordTaskFailure(jobName string, errorCategory ErrorCategory) {
TaskFailures.WithLabelValues(jobName, string(errorCategory)).Inc()
}
// RecordTaskRetry records a task retry
func RecordTaskRetry(jobName string, errorCategory ErrorCategory) {
TaskRetries.WithLabelValues(jobName, string(errorCategory)).Inc()
}
// RecordLeaseExpiration records a lease expiration
func RecordLeaseExpiration() {
LeaseExpirations.Inc()
}
// RecordLeaseRenewal records a successful lease renewal
func RecordLeaseRenewal(workerID string) {
LeaseRenewals.WithLabelValues(workerID).Inc()
}
// RecordDLQAddition records a task being moved to DLQ
func RecordDLQAddition(reason string) {
DLQAdditions.WithLabelValues(reason).Inc()
DLQSize.Inc()
}
// UpdateQueueDepth updates the current queue depth gauge
func UpdateQueueDepth(depth int64) {
QueueDepth.Set(float64(depth))
}
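
These collectors register on the Prometheus default registry via `promauto`, so exposing them only requires mounting `promhttp.Handler()`. A standalone sketch (not part of this commit; the job name, worker ID, and port are placeholders):

```go
package main

import (
	"net/http"
	"time"

	"github.com/prometheus/client_golang/prometheus/promhttp"

	"github.com/jfraeys/fetch_ml/internal/queue"
)

func main() {
	// Record one task lifecycle so the histograms and counters have data.
	start := time.Now()
	queue.RecordTaskStart("example_job", "worker-1")
	// ... run the task ...
	queue.RecordTaskEnd("example_job", "worker-1", "completed", time.Since(start).Seconds())

	// Serve the default registry, which promauto populated above.
	http.Handle("/metrics", promhttp.Handler())
	_ = http.ListenAndServe(":2112", nil)
}
```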

558
internal/queue/queue.go Normal file
View file

@ -0,0 +1,558 @@
package queue
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
const (
defaultMetricsFlushInterval = 500 * time.Millisecond
defaultLeaseDuration = 30 * time.Minute
defaultMaxRetries = 3
)
// TaskQueue manages ML experiment tasks via Redis
type TaskQueue struct {
client *redis.Client
ctx context.Context
cancel context.CancelFunc
metricsCh chan metricEvent
metricsDone chan struct{}
flushEvery time.Duration
}
type metricEvent struct {
JobName string
Metric string
Value float64
}
// Config holds configuration for TaskQueue
type Config struct {
RedisAddr string
RedisPassword string
RedisDB int
MetricsFlushInterval time.Duration
}
func NewTaskQueue(cfg Config) (*TaskQueue, error) {
var opts *redis.Options
var err error
if len(cfg.RedisAddr) > 8 && cfg.RedisAddr[:8] == "redis://" {
opts, err = redis.ParseURL(cfg.RedisAddr)
if err != nil {
return nil, fmt.Errorf("invalid redis url: %w", err)
}
} else {
opts = &redis.Options{
Addr: cfg.RedisAddr,
Password: cfg.RedisPassword,
DB: cfg.RedisDB,
}
}
rdb := redis.NewClient(opts)
ctx, cancel := context.WithCancel(context.Background())
if err := rdb.Ping(ctx).Err(); err != nil {
cancel()
return nil, fmt.Errorf("redis connection failed: %w", err)
}
flushEvery := cfg.MetricsFlushInterval
if flushEvery == 0 {
flushEvery = defaultMetricsFlushInterval
}
tq := &TaskQueue{
client: rdb,
ctx: ctx,
cancel: cancel,
metricsCh: make(chan metricEvent, 256),
metricsDone: make(chan struct{}),
flushEvery: flushEvery,
}
go tq.runMetricsBuffer()
go tq.runLeaseReclamation() // Start lease reclamation background job
return tq, nil
}
// AddTask adds a new task to the queue with default retry settings
func (tq *TaskQueue) AddTask(task *Task) error {
// Set default retry settings if not specified
if task.MaxRetries == 0 {
task.MaxRetries = defaultMaxRetries
}
taskData, err := json.Marshal(task)
if err != nil {
return fmt.Errorf("failed to marshal task: %w", err)
}
pipe := tq.client.Pipeline()
// Store task data
pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
// Add to priority queue (ZSET)
// Use priority as score (higher priority = higher score)
pipe.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
Score: float64(task.Priority),
Member: task.ID,
})
// Initialize status
pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
"status", task.Status,
"task_id", task.ID,
"updated_at", time.Now().Format(time.RFC3339))
_, err = pipe.Exec(tq.ctx)
if err != nil {
return fmt.Errorf("failed to enqueue task: %w", err)
}
// Record metrics
TasksQueued.Inc()
// Update queue depth
depth, _ := tq.QueueDepth()
UpdateQueueDepth(depth)
return nil
}
// GetNextTask gets the next task without lease (backward compatible)
func (tq *TaskQueue) GetNextTask() (*Task, error) {
result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, nil
}
taskID := result[0].Member.(string)
return tq.GetTask(taskID)
}
// GetNextTaskWithLease gets the next task and acquires a lease
func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
if leaseDuration == 0 {
leaseDuration = defaultLeaseDuration
}
// Pop highest priority task
result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, nil
}
taskID := result[0].Member.(string)
task, err := tq.GetTask(taskID)
if err != nil {
// Re-queue the task if we can't fetch it
tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
Score: result[0].Score,
Member: taskID,
})
return nil, err
}
// Acquire lease
now := time.Now()
leaseExpiry := now.Add(leaseDuration)
task.LeaseExpiry = &leaseExpiry
task.LeasedBy = workerID
// Update task with lease
if err := tq.UpdateTask(task); err != nil {
// Re-queue if update fails
tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
Score: result[0].Score,
Member: taskID,
})
return nil, err
}
return task, nil
}
// RenewLease renews the lease on a task (heartbeat)
func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
if leaseDuration == 0 {
leaseDuration = defaultLeaseDuration
}
task, err := tq.GetTask(taskID)
if err != nil {
return err
}
// Verify the worker owns the lease
if task.LeasedBy != workerID {
return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
}
// Renew lease
leaseExpiry := time.Now().Add(leaseDuration)
task.LeaseExpiry = &leaseExpiry
// Record renewal metric
RecordLeaseRenewal(workerID)
return tq.UpdateTask(task)
}
// ReleaseLease releases the lease on a task
func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
task, err := tq.GetTask(taskID)
if err != nil {
return err
}
// Verify the worker owns the lease
if task.LeasedBy != workerID {
return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
}
// Clear lease
task.LeaseExpiry = nil
task.LeasedBy = ""
return tq.UpdateTask(task)
}
// RetryTask re-queues a failed task with smart backoff based on error category
func (tq *TaskQueue) RetryTask(task *Task) error {
if task.RetryCount >= task.MaxRetries {
// Move to dead letter queue
RecordDLQAddition("max_retries")
return tq.MoveToDeadLetterQueue(task, "max retries exceeded")
}
// Classify the error if it exists
var errorCategory ErrorCategory = ErrorUnknown
if task.Error != "" {
errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
}
// Check if error is retryable
if !IsRetryable(errorCategory) {
RecordDLQAddition(string(errorCategory))
return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
}
task.RetryCount++
task.Status = "queued"
task.LastError = task.Error // Preserve last error
task.Error = "" // Clear current error
// Calculate smart backoff based on error category
backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
task.NextRetry = &nextRetry
// Clear lease
task.LeaseExpiry = nil
task.LeasedBy = ""
// Record retry metrics
RecordTaskRetry(task.JobName, errorCategory)
// Re-queue with same priority
return tq.AddTask(task)
}
// MoveToDeadLetterQueue moves a task to the dead letter queue
func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
task.Status = "failed"
task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)
taskData, err := json.Marshal(task)
if err != nil {
return err
}
// Store in dead letter queue with timestamp
key := "task:dlq:" + task.ID
// Record metrics
RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))
pipe := tq.client.Pipeline()
pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour)
pipe.ZRem(tq.ctx, TaskQueueKey, task.ID)
_, err = pipe.Exec(tq.ctx)
return err
}
// runLeaseReclamation reclaims expired leases once per minute
func (tq *TaskQueue) runLeaseReclamation() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-tq.ctx.Done():
return
case <-ticker.C:
if err := tq.reclaimExpiredLeases(); err != nil {
// Reclamation errors are non-fatal; try again on the next tick
continue
}
}
}
}
// reclaimExpiredLeases finds and re-queues tasks with expired leases
func (tq *TaskQueue) reclaimExpiredLeases() error {
// Scan for all task keys
iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator()
now := time.Now()
for iter.Next(tq.ctx) {
taskKey := iter.Val()
taskID := taskKey[len(TaskPrefix):]
task, err := tq.GetTask(taskID)
if err != nil {
continue
}
// Check if lease expired and task is still running
if task.LeaseExpiry != nil && task.LeaseExpiry.Before(now) && task.Status == "running" {
// Lease expired - retry or fail the task
task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy)
// Record lease expiration
RecordLeaseExpiration()
if task.RetryCount < task.MaxRetries {
// Retry the task
if err := tq.RetryTask(task); err != nil {
continue
}
} else {
// Max retries exceeded - move to DLQ
if err := tq.MoveToDeadLetterQueue(task, "lease expiry after max retries"); err != nil {
continue
}
}
}
}
return iter.Err()
}
// GetTask retrieves a task by ID
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
data, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result()
if err != nil {
return nil, err
}
var task Task
if err := json.Unmarshal([]byte(data), &task); err != nil {
return nil, err
}
return &task, nil
}
// GetAllTasks retrieves all tasks from the queue
func (tq *TaskQueue) GetAllTasks() ([]*Task, error) {
// Get all task keys
keys, err := tq.client.Keys(tq.ctx, TaskPrefix+"*").Result()
if err != nil {
return nil, err
}
var tasks []*Task
for _, key := range keys {
data, err := tq.client.Get(tq.ctx, key).Result()
if err != nil {
continue // Skip tasks that can't be retrieved
}
var task Task
if err := json.Unmarshal([]byte(data), &task); err != nil {
continue // Skip malformed tasks
}
tasks = append(tasks, &task)
}
return tasks, nil
}
// GetTaskByName retrieves a task by its job name
func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) {
tasks, err := tq.GetAllTasks()
if err != nil {
return nil, err
}
for _, task := range tasks {
if task.JobName == jobName {
return task, nil
}
}
return nil, fmt.Errorf("task with job name '%s' not found", jobName)
}
// CancelTask marks a task as cancelled
func (tq *TaskQueue) CancelTask(taskID string) error {
task, err := tq.GetTask(taskID)
if err != nil {
return err
}
// Update task status to cancelled
task.Status = "cancelled"
now := time.Now()
task.EndedAt = &now
return tq.UpdateTask(task)
}
// UpdateTask updates a task in Redis
func (tq *TaskQueue) UpdateTask(task *Task) error {
taskData, err := json.Marshal(task)
if err != nil {
return err
}
pipe := tq.client.Pipeline()
pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
"status", task.Status,
"task_id", task.ID,
"updated_at", time.Now().Format(time.RFC3339))
_, err = pipe.Exec(tq.ctx)
return err
}
// UpdateTaskWithMetrics updates task and records metrics
func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error {
if err := tq.UpdateTask(task); err != nil {
return err
}
metricName := "tasks_" + action
return tq.RecordMetric(task.JobName, metricName, 1)
}
// RecordMetric records a metric value
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
evt := metricEvent{JobName: jobName, Metric: metric, Value: value}
select {
case tq.metricsCh <- evt:
return nil
default:
return tq.writeMetrics(jobName, map[string]float64{metric: value})
}
}
// Heartbeat records worker heartbeat
func (tq *TaskQueue) Heartbeat(workerID string) error {
return tq.client.HSet(tq.ctx, WorkerHeartbeat,
workerID, time.Now().Unix()).Err()
}
// QueueDepth returns the number of pending tasks
func (tq *TaskQueue) QueueDepth() (int64, error) {
return tq.client.ZCard(tq.ctx, TaskQueueKey).Result()
}
// Close closes the task queue and cleans up resources
func (tq *TaskQueue) Close() error {
tq.cancel()
<-tq.metricsDone // Wait for metrics buffer to finish
return tq.client.Close()
}
// GetRedisClient returns the underlying Redis client for direct access
func (tq *TaskQueue) GetRedisClient() *redis.Client {
return tq.client
}
// WaitForNextTask waits for next task with timeout
func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) {
if ctx == nil {
ctx = tq.ctx
}
result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
member, ok := result.Member.(string)
if !ok {
return nil, fmt.Errorf("unexpected task id type %T", result.Member)
}
return tq.GetTask(member)
}
// runMetricsBuffer buffers and flushes metrics
func (tq *TaskQueue) runMetricsBuffer() {
defer close(tq.metricsDone)
ticker := time.NewTicker(tq.flushEvery)
defer ticker.Stop()
pending := make(map[string]map[string]float64)
flush := func() {
for job, metrics := range pending {
if err := tq.writeMetrics(job, metrics); err != nil {
continue
}
delete(pending, job)
}
}
for {
select {
case <-tq.ctx.Done():
flush()
return
case evt, ok := <-tq.metricsCh:
if !ok {
flush()
return
}
if _, exists := pending[evt.JobName]; !exists {
pending[evt.JobName] = make(map[string]float64)
}
pending[evt.JobName][evt.Metric] = evt.Value
case <-ticker.C:
flush()
}
}
}
// writeMetrics writes metrics to Redis
func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error {
if len(metrics) == 0 {
return nil
}
key := JobMetricsPrefix + jobName
args := make([]any, 0, len(metrics)*2+2)
args = append(args, "timestamp", time.Now().Unix())
for metric, value := range metrics {
args = append(args, metric, value)
}
return tq.client.HSet(context.Background(), key, args...).Err()
}
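
Putting the lease API together, here is a sketch of a worker loop (illustrative only — the Redis address, worker ID, and intervals are placeholders; the real worker binary lives elsewhere in the repo):

```go
package main

import (
	"log"
	"time"

	"github.com/jfraeys/fetch_ml/internal/queue"
)

func main() {
	tq, err := queue.NewTaskQueue(queue.Config{RedisAddr: "localhost:6379"})
	if err != nil {
		log.Fatal(err)
	}
	defer tq.Close()

	const workerID = "worker-1"
	for {
		task, err := tq.GetNextTaskWithLease(workerID, 30*time.Minute)
		if err != nil || task == nil {
			time.Sleep(time.Second)
			continue
		}

		// Renew the lease periodically while the task runs.
		stop := make(chan struct{})
		go func() {
			t := time.NewTicker(10 * time.Minute)
			defer t.Stop()
			for {
				select {
				case <-stop:
					return
				case <-t.C:
					_ = tq.RenewLease(task.ID, workerID, 30*time.Minute)
				}
			}
		}()

		// ... execute the task here ...

		close(stop)
		_ = tq.ReleaseLease(task.ID, workerID)
	}
}
```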

View file

@ -0,0 +1,152 @@
package queue
import (
"testing"
"time"
)
func TestTask_UserFields(t *testing.T) {
task := &Task{
UserID: "testuser",
Username: "testuser",
CreatedBy: "testuser",
}
if task.UserID != "testuser" {
t.Errorf("Expected UserID to be 'testuser', got '%s'", task.UserID)
}
if task.Username != "testuser" {
t.Errorf("Expected Username to be 'testuser', got '%s'", task.Username)
}
if task.CreatedBy != "testuser" {
t.Errorf("Expected CreatedBy to be 'testuser', got '%s'", task.CreatedBy)
}
}
func TestTaskQueue_UserFiltering(t *testing.T) {
// Setup test Redis configuration
queueCfg := Config{
RedisAddr: "localhost:6379",
RedisDB: 15, // Use dedicated test DB
}
// Create task queue
taskQueue, err := NewTaskQueue(queueCfg)
if err != nil {
t.Skip("Redis not available for integration testing")
return
}
defer taskQueue.Close()
// Clear test database
taskQueue.client.FlushDB(taskQueue.ctx)
// Create test tasks with different users
tasks := []*Task{
{
ID: "task1",
JobName: "user1_job1",
Status: "queued",
UserID: "user1",
CreatedBy: "user1",
CreatedAt: time.Now(),
},
{
ID: "task2",
JobName: "user1_job2",
Status: "running",
UserID: "user1",
CreatedBy: "user1",
CreatedAt: time.Now(),
},
{
ID: "task3",
JobName: "user2_job1",
Status: "queued",
UserID: "user2",
CreatedBy: "user2",
CreatedAt: time.Now(),
},
{
ID: "task4",
JobName: "admin_job",
Status: "completed",
UserID: "admin",
CreatedBy: "admin",
CreatedAt: time.Now(),
},
}
// Add tasks to queue
for _, task := range tasks {
err := taskQueue.AddTask(task)
if err != nil {
t.Fatalf("Failed to add task %s: %v", task.ID, err)
}
}
// Test GetAllTasks
allTasks, err := taskQueue.GetAllTasks()
if err != nil {
t.Fatalf("Failed to get all tasks: %v", err)
}
if len(allTasks) != len(tasks) {
t.Errorf("Expected %d tasks, got %d", len(tasks), len(allTasks))
}
// Test user filtering logic
filterTasksForUser := func(tasks []*Task, userID string) []*Task {
var filtered []*Task
for _, task := range tasks {
if task.UserID == userID || task.CreatedBy == userID {
filtered = append(filtered, task)
}
}
return filtered
}
// Test filtering for user1 (should get 2 tasks)
user1Tasks := filterTasksForUser(allTasks, "user1")
if len(user1Tasks) != 2 {
t.Errorf("Expected 2 tasks for user1, got %d", len(user1Tasks))
}
// Test filtering for user2 (should get 1 task)
user2Tasks := filterTasksForUser(allTasks, "user2")
if len(user2Tasks) != 1 {
t.Errorf("Expected 1 task for user2, got %d", len(user2Tasks))
}
// Test filtering for admin (should get 1 task)
adminTasks := filterTasksForUser(allTasks, "admin")
if len(adminTasks) != 1 {
t.Errorf("Expected 1 task for admin, got %d", len(adminTasks))
}
// Test GetTaskByName
task, err := taskQueue.GetTaskByName("user1_job1")
if err != nil {
t.Errorf("Failed to get task by name: %v", err)
}
if task == nil || task.UserID != "user1" {
t.Error("Got wrong task or nil task")
}
// Test CancelTask
err = taskQueue.CancelTask("task1")
if err != nil {
t.Errorf("Failed to cancel task: %v", err)
}
// Verify task was cancelled
cancelledTask, err := taskQueue.GetTask("task1")
if err != nil {
t.Errorf("Failed to get cancelled task: %v", err)
}
if cancelledTask.Status != "cancelled" {
t.Errorf("Expected status 'cancelled', got '%s'", cancelledTask.Status)
}
}

View file

@ -0,0 +1,193 @@
package queue
import (
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTaskQueue(t *testing.T) {
// Start miniredis
s, err := miniredis.Run()
if err != nil {
t.Fatalf("failed to start miniredis: %v", err)
}
defer s.Close()
// Create TaskQueue
cfg := Config{
RedisAddr: s.Addr(),
MetricsFlushInterval: 10 * time.Millisecond, // Fast flush for testing
}
tq, err := NewTaskQueue(cfg)
assert.NoError(t, err)
defer tq.Close()
t.Run("AddTask", func(t *testing.T) {
task := &Task{
ID: "task-1",
JobName: "job-1",
Status: "queued",
Priority: 10,
CreatedAt: time.Now(),
}
err = tq.AddTask(task)
assert.NoError(t, err)
// Verify task is in Redis
// Check ZSET
score, err := s.ZScore(TaskQueueKey, "task-1")
assert.NoError(t, err)
assert.Equal(t, float64(10), score)
})
t.Run("GetNextTask", func(t *testing.T) {
// Add another task
task := &Task{
ID: "task-2",
JobName: "job-2",
Status: "queued",
Priority: 20, // Higher priority
CreatedAt: time.Now(),
}
err = tq.AddTask(task)
assert.NoError(t, err)
// Should get task-2 first due to higher priority
nextTask, err := tq.GetNextTask()
assert.NoError(t, err)
assert.NotNil(t, nextTask)
assert.Equal(t, "task-2", nextTask.ID)
// Verify task is removed from ZSET
_, err = tq.client.ZScore(tq.ctx, TaskQueueKey, "task-2").Result()
assert.Equal(t, redis.Nil, err)
})
t.Run("GetNextTaskWithLease", func(t *testing.T) {
task := &Task{
ID: "task-lease",
JobName: "job-lease",
Status: "queued",
Priority: 15,
CreatedAt: time.Now(),
}
err := tq.AddTask(task)
require.NoError(t, err)
workerID := "worker-1"
leaseDuration := 1 * time.Minute
leasedTask, err := tq.GetNextTaskWithLease(workerID, leaseDuration)
require.NoError(t, err)
require.NotNil(t, leasedTask)
assert.Equal(t, "task-lease", leasedTask.ID)
assert.Equal(t, workerID, leasedTask.LeasedBy)
assert.NotNil(t, leasedTask.LeaseExpiry)
assert.True(t, leasedTask.LeaseExpiry.After(time.Now()))
})
t.Run("RenewLease", func(t *testing.T) {
taskID := "task-lease"
workerID := "worker-1"
// Get initial expiry
task, err := tq.GetTask(taskID)
require.NoError(t, err)
initialExpiry := task.LeaseExpiry
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Renew lease
err = tq.RenewLease(taskID, workerID, 1*time.Minute)
require.NoError(t, err)
// Verify expiry updated
task, err = tq.GetTask(taskID)
require.NoError(t, err)
assert.True(t, task.LeaseExpiry.After(*initialExpiry))
})
t.Run("ReleaseLease", func(t *testing.T) {
taskID := "task-lease"
workerID := "worker-1"
err := tq.ReleaseLease(taskID, workerID)
require.NoError(t, err)
task, err := tq.GetTask(taskID)
require.NoError(t, err)
assert.Nil(t, task.LeaseExpiry)
assert.Empty(t, task.LeasedBy)
})
t.Run("RetryTask", func(t *testing.T) {
task := &Task{
ID: "task-retry",
JobName: "job-retry",
Status: "failed",
Priority: 10,
CreatedAt: time.Now(),
MaxRetries: 3,
RetryCount: 0,
Error: "some transient error",
}
// Add task directly to verify retry logic
err := tq.AddTask(task)
require.NoError(t, err)
// Simulate failure and retry
task.Error = "connection timeout"
err = tq.RetryTask(task)
require.NoError(t, err)
// Verify task updated
updatedTask, err := tq.GetTask(task.ID)
require.NoError(t, err)
assert.Equal(t, 1, updatedTask.RetryCount)
assert.Equal(t, "queued", updatedTask.Status)
assert.Empty(t, updatedTask.Error)
assert.Equal(t, "connection timeout", updatedTask.LastError)
assert.NotNil(t, updatedTask.NextRetry)
})
t.Run("DLQ", func(t *testing.T) {
task := &Task{
ID: "task-dlq",
JobName: "job-dlq",
Status: "failed",
Priority: 10,
CreatedAt: time.Now(),
MaxRetries: 1,
RetryCount: 1, // Already at max retries
Error: "fatal error",
}
err := tq.AddTask(task)
require.NoError(t, err)
// Retry should move to DLQ
err = tq.RetryTask(task)
require.NoError(t, err)
// Verify removed from main queue
_, err = tq.client.ZScore(tq.ctx, TaskQueueKey, task.ID).Result()
assert.Equal(t, redis.Nil, err)
// Verify in DLQ
dlqKey := "task:dlq:" + task.ID
exists := s.Exists(dlqKey)
assert.True(t, exists)
// Verify DLQ content
val, err := s.Get(dlqKey)
require.NoError(t, err)
assert.Contains(t, val, "max retries exceeded")
})
}

47
internal/queue/task.go Normal file
View file

@ -0,0 +1,47 @@
package queue
import (
"time"
"github.com/jfraeys/fetch_ml/internal/config"
)
// Task represents an ML experiment task
type Task struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"` // queued, running, completed, failed
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
// User ownership and permissions
UserID string `json:"user_id"` // User who owns this task
Username string `json:"username"` // Username for display
CreatedBy string `json:"created_by"` // User who submitted the task
// Lease management for task resilience
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease
// Retry management
RetryCount int `json:"retry_count"` // Number of retry attempts made
MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3)
LastError string `json:"last_error,omitempty"` // Last error encountered
NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
}
// Redis key constants
var (
TaskQueueKey = config.RedisTaskQueueKey
TaskPrefix = config.RedisTaskPrefix
TaskStatusPrefix = config.RedisTaskStatusPrefix
WorkerHeartbeat = config.RedisWorkerHeartbeat
JobMetricsPrefix = config.RedisJobMetricsPrefix
)
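
For illustration, a sketch of populating the ownership and retry fields when enqueueing work on behalf of a user (values are placeholders; `MaxRetries` is left at zero so `AddTask` applies its default):

```go
package main

import (
	"log"
	"time"

	"github.com/jfraeys/fetch_ml/internal/queue"
)

// enqueueExample is not part of this commit; it only shows how a Task is built.
func enqueueExample(tq *queue.TaskQueue) {
	task := &queue.Task{
		ID:        "task-123",
		JobName:   "resnet_training",
		Args:      "--epochs 10 --lr 0.001",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "alice",
		Username:  "alice",
		CreatedBy: "alice",
	}
	if err := tq.AddTask(task); err != nil {
		log.Printf("enqueue failed: %v", err)
	}
}
```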

433
internal/storage/db.go Normal file
View file

@ -0,0 +1,433 @@
package storage
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
type DBConfig struct {
Type string
Connection string
Host string
Port int
Username string
Password string
Database string
}
type DB struct {
conn *sql.DB
dbType string
}
func NewDB(config DBConfig) (*DB, error) {
var conn *sql.DB
var err error
switch strings.ToLower(config.Type) {
case "sqlite":
conn, err = sql.Open("sqlite3", config.Connection)
if err != nil {
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
}
// Enable foreign keys
if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
// Enable WAL mode for better concurrency
if _, err := conn.Exec("PRAGMA journal_mode = WAL"); err != nil {
return nil, fmt.Errorf("failed to enable WAL mode: %w", err)
}
case "postgres":
connStr := buildPostgresConnectionString(config)
conn, err = sql.Open("postgres", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
}
case "postgresql":
// Handle "postgresql" as alias for "postgres"
connStr := buildPostgresConnectionString(config)
conn, err = sql.Open("postgres", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
}
default:
return nil, fmt.Errorf("unsupported database type: %s", config.Type)
}
return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil
}
func buildPostgresConnectionString(config DBConfig) string {
if config.Connection != "" {
return config.Connection
}
var connStr strings.Builder
connStr.WriteString("host=")
if config.Host != "" {
connStr.WriteString(config.Host)
} else {
connStr.WriteString("localhost")
}
if config.Port > 0 {
connStr.WriteString(fmt.Sprintf(" port=%d", config.Port))
} else {
connStr.WriteString(" port=5432")
}
if config.Username != "" {
connStr.WriteString(fmt.Sprintf(" user=%s", config.Username))
}
if config.Password != "" {
connStr.WriteString(fmt.Sprintf(" password=%s", config.Password))
}
if config.Database != "" {
connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database))
} else {
connStr.WriteString(" dbname=fetch_ml")
}
connStr.WriteString(" sslmode=disable")
return connStr.String()
}
// Legacy constructor for backward compatibility
func NewDBFromPath(dbPath string) (*DB, error) {
return NewDB(DBConfig{
Type: "sqlite",
Connection: dbPath,
})
}
type Job struct {
ID string `json:"id"`
JobName string `json:"job_name"`
Args string `json:"args"`
Status string `json:"status"`
Priority int64 `json:"priority"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
WorkerID string `json:"worker_id,omitempty"`
Error string `json:"error,omitempty"`
Datasets []string `json:"datasets,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type Worker struct {
ID string `json:"id"`
Hostname string `json:"hostname"`
LastHeartbeat time.Time `json:"last_heartbeat"`
Status string `json:"status"`
CurrentJobs int `json:"current_jobs"`
MaxJobs int `json:"max_jobs"`
Metadata map[string]string `json:"metadata,omitempty"`
}
func (db *DB) Initialize(schema string) error {
if _, err := db.conn.Exec(schema); err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
return nil
}
func (db *DB) Close() error {
return db.conn.Close()
}
// Job operations
func (db *DB) CreateJob(job *Job) error {
datasetsJSON, _ := json.Marshal(job.Datasets)
metadataJSON, _ := json.Marshal(job.Metadata)
var query string
if db.dbType == "sqlite" {
query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)`
} else {
query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)`
}
_, err := db.conn.Exec(query, job.ID, job.JobName, job.Args, job.Status,
job.Priority, string(datasetsJSON), string(metadataJSON))
if err != nil {
return fmt.Errorf("failed to create job: %w", err)
}
return nil
}
func (db *DB) GetJob(id string) (*Job, error) {
var query string
if db.dbType == "sqlite" {
query = `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs WHERE id = ?`
} else {
query = `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs WHERE id = $1`
}
var job Job
var datasetsJSON, metadataJSON string
var workerID sql.NullString
var errorMsg sql.NullString
err := db.conn.QueryRow(query, id).Scan(
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
&job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID,
&errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to get job: %w", err)
}
if workerID.Valid {
job.WorkerID = workerID.String
}
if errorMsg.Valid {
job.Error = errorMsg.String
}
json.Unmarshal([]byte(datasetsJSON), &job.Datasets)
json.Unmarshal([]byte(metadataJSON), &job.Metadata)
return &job, nil
}
func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error {
var query string
if db.dbType == "sqlite" {
query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?,
started_at = CASE WHEN ? = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = ?`
} else {
query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3,
started_at = CASE WHEN $4 = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END,
ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
WHERE id = $6`
}
_, err := db.conn.Exec(query, status, workerID, errorMsg, status, status, id)
if err != nil {
return fmt.Errorf("failed to update job status: %w", err)
}
return nil
}
func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
var query string
if db.dbType == "sqlite" {
query = `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs`
} else {
query = `SELECT id, job_name, args, status, priority, created_at, started_at,
ended_at, worker_id, error, datasets, metadata, updated_at
FROM jobs`
}
var args []interface{}
if status != "" {
if db.dbType == "sqlite" {
query += " WHERE status = ?"
} else {
query += " WHERE status = $1"
}
args = append(args, status)
}
query += " ORDER BY created_at DESC"
if limit > 0 {
if db.dbType == "sqlite" {
query += " LIMIT ?"
} else {
query += fmt.Sprintf(" LIMIT $%d", len(args)+1)
}
args = append(args, limit)
}
rows, err := db.conn.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("failed to list jobs: %w", err)
}
defer rows.Close()
var jobs []*Job
for rows.Next() {
var job Job
var datasetsJSON, metadataJSON string
var workerID sql.NullString
var errorMsg sql.NullString
err := rows.Scan(&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
&job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID,
&errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan job: %w", err)
}
if workerID.Valid {
job.WorkerID = workerID.String
}
if errorMsg.Valid {
job.Error = errorMsg.String
}
json.Unmarshal([]byte(datasetsJSON), &job.Datasets)
json.Unmarshal([]byte(metadataJSON), &job.Metadata)
jobs = append(jobs, &job)
}
return jobs, nil
}
// Worker operations
func (db *DB) RegisterWorker(worker *Worker) error {
metadataJSON, _ := json.Marshal(worker.Metadata)
var query string
if db.dbType == "sqlite" {
query = `INSERT OR REPLACE INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
VALUES (?, ?, ?, ?, ?, ?)`
} else {
query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (id) DO UPDATE SET
hostname = EXCLUDED.hostname,
status = EXCLUDED.status,
current_jobs = EXCLUDED.current_jobs,
max_jobs = EXCLUDED.max_jobs,
metadata = EXCLUDED.metadata`
}
_, err := db.conn.Exec(query, worker.ID, worker.Hostname, worker.Status,
worker.CurrentJobs, worker.MaxJobs, string(metadataJSON))
if err != nil {
return fmt.Errorf("failed to register worker: %w", err)
}
return nil
}
func (db *DB) UpdateWorkerHeartbeat(workerID string) error {
var query string
if db.dbType == "sqlite" {
query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = ?`
} else {
query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = $1`
}
_, err := db.conn.Exec(query, workerID)
if err != nil {
return fmt.Errorf("failed to update worker heartbeat: %w", err)
}
return nil
}
func (db *DB) GetActiveWorkers() ([]*Worker, error) {
var query string
if db.dbType == "sqlite" {
query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')`
} else {
query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
}
rows, err := db.conn.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to get active workers: %w", err)
}
defer rows.Close()
var workers []*Worker
for rows.Next() {
var worker Worker
var metadataJSON string
err := rows.Scan(&worker.ID, &worker.Hostname, &worker.LastHeartbeat,
&worker.Status, &worker.CurrentJobs, &worker.MaxJobs, &metadataJSON)
if err != nil {
return nil, fmt.Errorf("failed to scan worker: %w", err)
}
json.Unmarshal([]byte(metadataJSON), &worker.Metadata)
workers = append(workers, &worker)
}
return workers, nil
}
// Metrics operations
func (db *DB) RecordJobMetric(jobID, metricName, metricValue string) error {
var query string
if db.dbType == "sqlite" {
query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES (?, ?, ?)`
} else {
query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES ($1, $2, $3)`
}
_, err := db.conn.Exec(query, jobID, metricName, metricValue)
if err != nil {
return fmt.Errorf("failed to record job metric: %w", err)
}
return nil
}
func (db *DB) RecordSystemMetric(metricName, metricValue string) error {
var query string
if db.dbType == "sqlite" {
query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES (?, ?)`
} else {
query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES ($1, $2)`
}
_, err := db.conn.Exec(query, metricName, metricValue)
if err != nil {
return fmt.Errorf("failed to record system metric: %w", err)
}
return nil
}
func (db *DB) GetJobMetrics(jobID string) (map[string]string, error) {
var query string
if db.dbType == "sqlite" {
query = `SELECT metric_name, metric_value FROM job_metrics
WHERE job_id = ? ORDER BY timestamp DESC`
} else {
query = `SELECT metric_name, metric_value FROM job_metrics
WHERE job_id = $1 ORDER BY timestamp DESC`
}
rows, err := db.conn.Query(query, jobID)
if err != nil {
return nil, fmt.Errorf("failed to get job metrics: %w", err)
}
defer rows.Close()
metrics := make(map[string]string)
for rows.Next() {
var name, value string
if err := rows.Scan(&name, &value); err != nil {
return nil, fmt.Errorf("failed to scan metric: %w", err)
}
metrics[name] = value
}
return metrics, nil
}
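
A usage sketch for the storage layer (not part of this commit; the database path, schema location, and IDs are assumptions based on the package layout):

```go
package main

import (
	"log"
	"os"

	"github.com/jfraeys/fetch_ml/internal/storage"
)

func main() {
	db, err := storage.NewDB(storage.DBConfig{Type: "sqlite", Connection: "fetch_ml.db"})
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Load and apply the SQLite schema shipped with the package.
	schema, err := os.ReadFile("internal/storage/schema.sql")
	if err != nil {
		log.Fatal(err)
	}
	if err := db.Initialize(string(schema)); err != nil {
		log.Fatal(err)
	}

	// Persist a job and move it through its lifecycle.
	job := &storage.Job{ID: "job-1", JobName: "train", Status: "pending", Priority: 1}
	if err := db.CreateJob(job); err != nil {
		log.Fatal(err)
	}
	if err := db.UpdateJobStatus("job-1", "running", "worker-1", ""); err != nil {
		log.Fatal(err)
	}
}
```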

212
internal/storage/db_test.go Normal file
View file

@ -0,0 +1,212 @@
package storage
import (
"os"
"testing"
)
func TestDB(t *testing.T) {
// Use a temporary database
dbPath := t.TempDir() + "/test.db"
// Initialize database
db, err := NewDBFromPath(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// Initialize schema
schema, err := os.ReadFile("schema.sql")
if err != nil {
t.Fatalf("Failed to read schema: %v", err)
}
if err := db.Initialize(string(schema)); err != nil {
t.Fatalf("Failed to initialize schema: %v", err)
}
// Test job creation
job := &Job{
ID: "test-job-1",
JobName: "test_experiment",
Args: "--epochs 10 --lr 0.001",
Status: "pending",
Priority: 1,
Datasets: []string{"dataset1", "dataset2"},
Metadata: map[string]string{"gpu": "true", "memory": "8GB"},
}
if err := db.CreateJob(job); err != nil {
t.Fatalf("Failed to create job: %v", err)
}
// Verify job exists in database
var count int
err = db.conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE id = ?", "test-job-1").Scan(&count)
if err != nil {
t.Fatalf("Failed to verify job creation: %v", err)
}
if count != 1 {
t.Fatalf("Expected 1 job in database, got %d", count)
}
// Test job retrieval
retrievedJob, err := db.GetJob("test-job-1")
if err != nil {
t.Fatalf("Failed to get job: %v", err)
}
if retrievedJob.ID != job.ID {
t.Errorf("Expected job ID %s, got %s", job.ID, retrievedJob.ID)
}
if retrievedJob.JobName != job.JobName {
t.Errorf("Expected job name %s, got %s", job.JobName, retrievedJob.JobName)
}
if len(retrievedJob.Datasets) != 2 {
t.Errorf("Expected 2 datasets, got %d", len(retrievedJob.Datasets))
}
if retrievedJob.Metadata["gpu"] != "true" {
t.Errorf("Expected gpu=true, got %s", retrievedJob.Metadata["gpu"])
}
// Test job status update
if err := db.UpdateJobStatus("test-job-1", "running", "worker-1", ""); err != nil {
t.Fatalf("Failed to update job status: %v", err)
}
// Verify status update
updatedJob, err := db.GetJob("test-job-1")
if err != nil {
t.Fatalf("Failed to get updated job: %v", err)
}
if updatedJob.Status != "running" {
t.Errorf("Expected status running, got %s", updatedJob.Status)
}
if updatedJob.WorkerID != "worker-1" {
t.Errorf("Expected worker ID worker-1, got %s", updatedJob.WorkerID)
}
if updatedJob.StartedAt == nil {
t.Error("Expected StartedAt to be set")
}
// Test worker registration
worker := &Worker{
ID: "worker-1",
Hostname: "test-host",
Status: "active",
CurrentJobs: 0,
MaxJobs: 2,
Metadata: map[string]string{"cpu": "8", "memory": "16GB"},
}
if err := db.RegisterWorker(worker); err != nil {
t.Fatalf("Failed to register worker: %v", err)
}
// Test worker heartbeat
if err := db.UpdateWorkerHeartbeat("worker-1"); err != nil {
t.Fatalf("Failed to update worker heartbeat: %v", err)
}
// Test metrics recording
if err := db.RecordJobMetric("test-job-1", "accuracy", "0.95"); err != nil {
t.Fatalf("Failed to record job metric: %v", err)
}
if err := db.RecordSystemMetric("cpu_usage", "75"); err != nil {
t.Fatalf("Failed to record system metric: %v", err)
}
// Test metrics retrieval
metrics, err := db.GetJobMetrics("test-job-1")
if err != nil {
t.Fatalf("Failed to get job metrics: %v", err)
}
if metrics["accuracy"] != "0.95" {
t.Errorf("Expected accuracy 0.95, got %s", metrics["accuracy"])
}
// Test job listing
jobs, err := db.ListJobs("", 10)
if err != nil {
t.Fatalf("Failed to list jobs: %v", err)
}
t.Logf("Found %d jobs", len(jobs))
for i, job := range jobs {
t.Logf("Job %d: ID=%s, Status=%s", i, job.ID, job.Status)
}
if len(jobs) != 1 {
t.Errorf("Expected 1 job, got %d", len(jobs))
return
}
if jobs[0].ID != "test-job-1" {
t.Errorf("Expected job ID test-job-1, got %s", jobs[0].ID)
return
}
// Test active workers
workers, err := db.GetActiveWorkers()
if err != nil {
t.Fatalf("Failed to get active workers: %v", err)
}
if len(workers) != 1 {
t.Errorf("Expected 1 active worker, got %d", len(workers))
}
if workers[0].ID != "worker-1" {
t.Errorf("Expected worker ID worker-1, got %s", workers[0].ID)
}
}
func TestDBConstraints(t *testing.T) {
dbPath := t.TempDir() + "/test_constraints.db"
db, err := NewDBFromPath(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
schema, err := os.ReadFile("schema.sql")
if err != nil {
t.Fatalf("Failed to read schema: %v", err)
}
if err := db.Initialize(string(schema)); err != nil {
t.Fatalf("Failed to initialize schema: %v", err)
}
// Test duplicate job ID
job := &Job{
ID: "duplicate-test",
JobName: "test",
Status: "pending",
}
if err := db.CreateJob(job); err != nil {
t.Fatalf("Failed to create first job: %v", err)
}
// Should fail on duplicate
if err := db.CreateJob(job); err == nil {
t.Error("Expected error when creating duplicate job")
}
// Test getting non-existent job
_, err = db.GetJob("non-existent")
if err == nil {
t.Error("Expected error when getting non-existent job")
}
}

257
internal/storage/migrate.go Normal file
View file

@ -0,0 +1,257 @@
package storage
import (
"context"
"encoding/json"
"fmt"
"log"
"strconv"
"strings"
"time"
"github.com/go-redis/redis/v8"
)
// Migrator handles migration from Redis to SQLite
type Migrator struct {
redisClient *redis.Client
sqliteDB *DB
}
func NewMigrator(redisAddr, sqlitePath string) (*Migrator, error) {
// Connect to Redis
rdb := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
// Connect to SQLite
db, err := NewDBFromPath(sqlitePath)
if err != nil {
return nil, fmt.Errorf("failed to connect to SQLite: %w", err)
}
return &Migrator{
redisClient: rdb,
sqliteDB: db,
}, nil
}
func (m *Migrator) Close() error {
if err := m.sqliteDB.Close(); err != nil {
return err
}
return m.redisClient.Close()
}
// MigrateJobs migrates job data from Redis to SQLite
func (m *Migrator) MigrateJobs(ctx context.Context) error {
log.Println("Starting job migration from Redis to SQLite...")
// Get all job keys from Redis
jobKeys, err := m.redisClient.Keys(ctx, "job:*").Result()
if err != nil {
return fmt.Errorf("failed to get job keys from Redis: %w", err)
}
for _, jobKey := range jobKeys {
jobData, err := m.redisClient.HGetAll(ctx, jobKey).Result()
if err != nil {
log.Printf("Failed to get job data for %s: %v", jobKey, err)
continue
}
// Parse job data
job := &Job{
ID: jobKey[4:], // Remove "job:" prefix
JobName: jobData["job_name"],
Args: jobData["args"],
Status: jobData["status"],
Priority: parsePriority(jobData["priority"]),
WorkerID: jobData["worker_id"],
Error: jobData["error"],
}
// Parse timestamps
if createdAtStr := jobData["created_at"]; createdAtStr != "" {
if ts, err := time.Parse(time.RFC3339, createdAtStr); err == nil {
job.CreatedAt = ts
}
}
if startedAtStr := jobData["started_at"]; startedAtStr != "" {
if ts, err := time.Parse(time.RFC3339, startedAtStr); err == nil {
job.StartedAt = &ts
}
}
if endedAtStr := jobData["ended_at"]; endedAtStr != "" {
if ts, err := time.Parse(time.RFC3339, endedAtStr); err == nil {
job.EndedAt = &ts
}
}
// Parse JSON fields
if datasetsStr := jobData["datasets"]; datasetsStr != "" {
json.Unmarshal([]byte(datasetsStr), &job.Datasets)
}
if metadataStr := jobData["metadata"]; metadataStr != "" {
json.Unmarshal([]byte(metadataStr), &job.Metadata)
}
// Insert into SQLite
if err := m.sqliteDB.CreateJob(job); err != nil {
log.Printf("Failed to create job %s in SQLite: %v", job.ID, err)
continue
}
log.Printf("Migrated job: %s", job.ID)
}
log.Printf("Migrated %d jobs from Redis to SQLite", len(jobKeys))
return nil
}
// MigrateMetrics migrates metrics from Redis to SQLite
func (m *Migrator) MigrateMetrics(ctx context.Context) error {
log.Println("Starting metrics migration from Redis to SQLite...")
// Get all metric keys from Redis
metricKeys, err := m.redisClient.Keys(ctx, "metrics:*").Result()
if err != nil {
return fmt.Errorf("failed to get metric keys from Redis: %w", err)
}
for _, metricKey := range metricKeys {
metricData, err := m.redisClient.HGetAll(ctx, metricKey).Result()
if err != nil {
log.Printf("Failed to get metric data for %s: %v", metricKey, err)
continue
}
// Parse metric key format: metrics:job:job_id or metrics:system
parts := parseMetricKey(metricKey)
if len(parts) < 2 {
continue
}
metricType := parts[1] // "job" or "system"
for name, value := range metricData {
if metricType == "job" && len(parts) == 3 {
// Job metric
jobID := parts[2]
if err := m.sqliteDB.RecordJobMetric(jobID, name, value); err != nil {
log.Printf("Failed to record job metric %s for job %s: %v", name, jobID, err)
}
} else if metricType == "system" {
// System metric
if err := m.sqliteDB.RecordSystemMetric(name, value); err != nil {
log.Printf("Failed to record system metric %s: %v", name, err)
}
}
}
}
log.Printf("Migrated %d metric keys from Redis to SQLite", len(metricKeys))
return nil
}
// MigrateWorkers migrates worker data from Redis to SQLite
func (m *Migrator) MigrateWorkers(ctx context.Context) error {
log.Println("Starting worker migration from Redis to SQLite...")
// Get all worker keys from Redis
workerKeys, err := m.redisClient.Keys(ctx, "worker:*").Result()
if err != nil {
return fmt.Errorf("failed to get worker keys from Redis: %w", err)
}
for _, workerKey := range workerKeys {
workerData, err := m.redisClient.HGetAll(ctx, workerKey).Result()
if err != nil {
log.Printf("Failed to get worker data for %s: %v", workerKey, err)
continue
}
worker := &Worker{
ID: strings.TrimPrefix(workerKey, "worker:"),
Hostname: workerData["hostname"],
Status: workerData["status"],
CurrentJobs: parseInt(workerData["current_jobs"]),
MaxJobs: parseInt(workerData["max_jobs"]),
}
// Parse heartbeat
if heartbeatStr := workerData["last_heartbeat"]; heartbeatStr != "" {
if ts, err := time.Parse(time.RFC3339, heartbeatStr); err == nil {
worker.LastHeartbeat = ts
}
}
// Parse metadata
if metadataStr := workerData["metadata"]; metadataStr != "" {
json.Unmarshal([]byte(metadataStr), &worker.Metadata)
}
// Insert into SQLite
if err := m.sqliteDB.RegisterWorker(worker); err != nil {
log.Printf("Failed to register worker %s in SQLite: %v", worker.ID, err)
continue
}
log.Printf("Migrated worker: %s", worker.ID)
}
log.Printf("Migrated %d workers from Redis to SQLite", len(workerKeys))
return nil
}
// MigrateAll performs complete migration from Redis to SQLite
func (m *Migrator) MigrateAll(ctx context.Context) error {
log.Println("Starting complete migration from Redis to SQLite...")
// Test connections
if err := m.redisClient.Ping(ctx).Err(); err != nil {
return fmt.Errorf("failed to connect to Redis: %w", err)
}
// Run migrations in order
if err := m.MigrateJobs(ctx); err != nil {
return fmt.Errorf("job migration failed: %w", err)
}
if err := m.MigrateWorkers(ctx); err != nil {
return fmt.Errorf("worker migration failed: %w", err)
}
if err := m.MigrateMetrics(ctx); err != nil {
return fmt.Errorf("metrics migration failed: %w", err)
}
log.Println("Migration completed successfully!")
return nil
}
// Helper functions
// parsePriority parses a priority value stored in Redis, assuming a base-10 integer; falls back to 0.
func parsePriority(s string) int64 {
if s == "" {
return 0
}
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0
}
return v
}
// parseInt parses an integer field stored in Redis, assuming base-10; falls back to 0.
func parseInt(s string) int {
if s == "" {
return 0
}
v, err := strconv.Atoi(s)
if err != nil {
return 0
}
return v
}
func parseMetricKey(key string) []string {
// Simple split - adjust based on your Redis key format
parts := strings.Split(key, ":")
return parts
}
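
A minimal driver for the migrator (illustrative; the Redis address and SQLite path are placeholders, and it assumes the SQLite schema has already been initialized):

```go
package main

import (
	"context"
	"log"

	"github.com/jfraeys/fetch_ml/internal/storage"
)

func main() {
	m, err := storage.NewMigrator("localhost:6379", "fetch_ml.db")
	if err != nil {
		log.Fatal(err)
	}
	defer m.Close()

	// Jobs, workers, and metrics are copied in that order.
	if err := m.MigrateAll(context.Background()); err != nil {
		log.Fatal(err)
	}
}
```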

View file

@ -0,0 +1,61 @@
-- SQLite schema for Fetch ML job persistence
-- Complements Redis for task queuing
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
job_name TEXT NOT NULL,
args TEXT,
status TEXT NOT NULL DEFAULT 'pending',
priority INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
ended_at DATETIME,
worker_id TEXT,
error TEXT,
datasets TEXT, -- JSON array
metadata TEXT, -- JSON object
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS job_metrics (
job_id TEXT,
metric_name TEXT,
metric_value TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (job_id, metric_name, timestamp),
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS workers (
id TEXT PRIMARY KEY,
hostname TEXT,
last_heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP,
status TEXT DEFAULT 'active',
current_jobs INTEGER DEFAULT 0,
max_jobs INTEGER DEFAULT 1,
metadata TEXT -- JSON object
);
CREATE TABLE IF NOT EXISTS system_metrics (
metric_name TEXT,
metric_value TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (metric_name, timestamp)
);
-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
CREATE INDEX IF NOT EXISTS idx_jobs_worker_id ON jobs(worker_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);
-- Triggers to update timestamps
CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp
AFTER UPDATE ON jobs
FOR EACH ROW
BEGIN
UPDATE jobs SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;

View file

@ -0,0 +1,68 @@
-- PostgreSQL schema for Fetch ML job persistence
-- Complements Redis for task queuing
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
job_name TEXT NOT NULL,
args TEXT,
status TEXT NOT NULL DEFAULT 'pending',
priority INTEGER DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP WITH TIME ZONE,
ended_at TIMESTAMP WITH TIME ZONE,
worker_id TEXT,
error TEXT,
datasets TEXT, -- JSON array
metadata TEXT, -- JSON object
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS job_metrics (
job_id TEXT,
metric_name TEXT,
metric_value TEXT,
timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (job_id, metric_name, timestamp),
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS workers (
id TEXT PRIMARY KEY,
hostname TEXT,
last_heartbeat TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
status TEXT DEFAULT 'active',
current_jobs INTEGER DEFAULT 0,
max_jobs INTEGER DEFAULT 1,
metadata TEXT -- JSON object
);
CREATE TABLE IF NOT EXISTS system_metrics (
metric_name TEXT,
metric_value TEXT,
timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (metric_name, timestamp)
);
-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
CREATE INDEX IF NOT EXISTS idx_jobs_worker_id ON jobs(worker_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);
-- Function to update updated_at timestamp
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
-- Trigger to update timestamps
CREATE TRIGGER update_jobs_timestamp
BEFORE UPDATE ON jobs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();

View file

@ -0,0 +1,77 @@
package telemetry
import (
"bufio"
"os"
"strconv"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
)
type IOStats struct {
ReadBytes uint64
WriteBytes uint64
}
func ReadProcessIO() (IOStats, error) {
f, err := os.Open("/proc/self/io")
if err != nil {
return IOStats{}, err
}
defer f.Close()
var stats IOStats
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "read_bytes:") {
stats.ReadBytes = parseUintField(line)
}
if strings.HasPrefix(line, "write_bytes:") {
stats.WriteBytes = parseUintField(line)
}
}
if err := scanner.Err(); err != nil {
return IOStats{}, err
}
return stats, nil
}
func DiffIO(before, after IOStats) IOStats {
var delta IOStats
if after.ReadBytes >= before.ReadBytes {
delta.ReadBytes = after.ReadBytes - before.ReadBytes
}
if after.WriteBytes >= before.WriteBytes {
delta.WriteBytes = after.WriteBytes - before.WriteBytes
}
return delta
}
func parseUintField(line string) uint64 {
parts := strings.Split(line, ":")
if len(parts) != 2 {
return 0
}
value, err := strconv.ParseUint(strings.TrimSpace(parts[1]), 10, 64)
if err != nil {
return 0
}
return value
}
func ExecWithMetrics(logger *logging.Logger, description string, threshold time.Duration, fn func() (string, error)) (string, error) {
start := time.Now()
out, err := fn()
duration := time.Since(start)
if duration > threshold {
fields := []any{"latency_ms", duration.Milliseconds(), "command", description}
if err != nil {
fields = append(fields, "error", err)
}
logger.Debug("ssh exec", fields...)
}
return out, err
}
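
A sketch of sampling the I/O counters around a workload (Linux-only, since it reads `/proc/self/io`; the workload itself is elided):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jfraeys/fetch_ml/internal/telemetry"
)

func main() {
	before, err := telemetry.ReadProcessIO()
	if err != nil {
		log.Fatal(err)
	}

	// ... run the workload being measured ...

	after, err := telemetry.ReadProcessIO()
	if err != nil {
		log.Fatal(err)
	}
	delta := telemetry.DiffIO(before, after)
	fmt.Printf("read=%d bytes write=%d bytes\n", delta.ReadBytes, delta.WriteBytes)
}
```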