diff --git a/cmd/api-server/README.md b/cmd/api-server/README.md new file mode 100644 index 0000000..513b330 --- /dev/null +++ b/cmd/api-server/README.md @@ -0,0 +1,32 @@ +# API Server + +WebSocket API server for the ML CLI tool... + +## Usage + +```bash +./bin/api-server --config configs/config-dev.yaml --listen :9100 +``` + +## Endpoints + +- `GET /health` - Health check +- `WS /ws` - WebSocket endpoint for CLI communication + +## Binary Protocol + +See [CLI README](../../cli/README.md#websocket-protocol) for protocol details. + +## Configuration + +Uses the same configuration file as the worker. Experiment base path is read from `base_path` configuration key. + +## Example + +```bash +# Start API server +./bin/api-server --listen :9100 + +# In another terminal, test with CLI +./cli/zig-out/bin/ml status +``` diff --git a/cmd/api-server/main.go b/cmd/api-server/main.go new file mode 100644 index 0000000..9c07666 --- /dev/null +++ b/cmd/api-server/main.go @@ -0,0 +1,363 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "github.com/jfraeys/fetch_ml/internal/api" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/middleware" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/storage" + "gopkg.in/yaml.v3" +) + +// Config structure matching worker config +type Config struct { + BasePath string `yaml:"base_path"` + Auth auth.AuthConfig `yaml:"auth"` + Server ServerConfig `yaml:"server"` + Security SecurityConfig `yaml:"security"` + Redis RedisConfig `yaml:"redis"` + Database DatabaseConfig `yaml:"database"` + Logging logging.Config `yaml:"logging"` +} + +type RedisConfig struct { + Addr string `yaml:"addr"` + Password string `yaml:"password"` + 
DB int `yaml:"db"` + URL string `yaml:"url"` +} + +type DatabaseConfig struct { + Type string `yaml:"type"` + Connection string `yaml:"connection"` + Host string `yaml:"host"` + Port int `yaml:"port"` + Username string `yaml:"username"` + Password string `yaml:"password"` + Database string `yaml:"database"` +} + +type SecurityConfig struct { + RateLimit RateLimitConfig `yaml:"rate_limit"` + IPWhitelist []string `yaml:"ip_whitelist"` + FailedLockout LockoutConfig `yaml:"failed_login_lockout"` +} + +type RateLimitConfig struct { + Enabled bool `yaml:"enabled"` + RequestsPerMinute int `yaml:"requests_per_minute"` + BurstSize int `yaml:"burst_size"` +} + +type LockoutConfig struct { + Enabled bool `yaml:"enabled"` + MaxAttempts int `yaml:"max_attempts"` + LockoutDuration string `yaml:"lockout_duration"` +} + +type ServerConfig struct { + Address string `yaml:"address"` + TLS TLSConfig `yaml:"tls"` +} + +type TLSConfig struct { + Enabled bool `yaml:"enabled"` + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` +} + +func LoadConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func main() { + // Parse flags + configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path") + apiKey := flag.String("api-key", "", "API key for authentication") + flag.Parse() + + // Load config + resolvedConfig, err := config.ResolveConfigPath(*configFile) + if err != nil { + log.Fatalf("Failed to resolve config: %v", err) + } + + cfg, err := LoadConfig(resolvedConfig) + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // Ensure log directory exists + if cfg.Logging.File != "" { + logDir := filepath.Dir(cfg.Logging.File) + log.Printf("Creating log directory: %s", logDir) + if err := os.MkdirAll(logDir, 0755); err != nil { + log.Fatalf("Failed to 
create log directory: %v", err) + } + } + + // Setup logging + logger := logging.NewLoggerFromConfig(cfg.Logging) + ctx := logging.EnsureTrace(context.Background()) + logger = logger.Component(ctx, "api-server") + + // Setup experiment manager + basePath := cfg.BasePath + if basePath == "" { + basePath = "/tmp/ml-experiments" + } + expManager := experiment.NewManager(basePath) + log.Printf("Initializing experiment manager with base_path: %s", basePath) + if err := expManager.Initialize(); err != nil { + logger.Fatal("failed to initialize experiment manager", "error", err) + } + logger.Info("experiment manager initialized", "base_path", basePath) + + // Setup auth + var authCfg *auth.AuthConfig + if cfg.Auth.Enabled { + authCfg = &cfg.Auth + logger.Info("authentication enabled") + } + + // Setup HTTP server with security middleware + mux := http.NewServeMux() + + // Convert API keys from map to slice for security middleware + apiKeys := make([]string, 0, len(cfg.Auth.APIKeys)) + for username := range cfg.Auth.APIKeys { + // For now, use username as the key (in production, this should be the actual API key) + apiKeys = append(apiKeys, string(username)) + } + + // Create security middleware + sec := middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET")) + + // Setup TaskQueue + queueCfg := queue.Config{ + RedisAddr: cfg.Redis.Addr, + RedisPassword: cfg.Redis.Password, + RedisDB: cfg.Redis.DB, + } + if queueCfg.RedisAddr == "" { + queueCfg.RedisAddr = config.DefaultRedisAddr + } + // Support URL format for Redis + if cfg.Redis.URL != "" { + queueCfg.RedisAddr = cfg.Redis.URL + } + + taskQueue, err := queue.NewTaskQueue(queueCfg) + if err != nil { + logger.Error("failed to initialize task queue", "error", err) + // We continue without queue, but queue operations will fail + } else { + logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr) + defer func() { + logger.Info("stopping task queue...") + if err := taskQueue.Close(); err != nil { + 
logger.Error("failed to stop task queue", "error", err) + } else { + logger.Info("task queue stopped") + } + }() + } + + // Setup database if configured + var db *storage.DB + if cfg.Database.Type != "" { + dbConfig := storage.DBConfig{ + Type: cfg.Database.Type, + Connection: cfg.Database.Connection, + Host: cfg.Database.Host, + Port: cfg.Database.Port, + Username: cfg.Database.Username, + Password: cfg.Database.Password, + Database: cfg.Database.Database, + } + + db, err = storage.NewDB(dbConfig) + if err != nil { + logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err) + } else { + // Load appropriate database schema + var schemaPath string + if cfg.Database.Type == "sqlite" { + schemaPath = "internal/storage/schema.sql" + } else if cfg.Database.Type == "postgres" || cfg.Database.Type == "postgresql" { + schemaPath = "internal/storage/schema_postgres.sql" + } else { + logger.Error("unsupported database type", "type", cfg.Database.Type) + db.Close() + db = nil + } + + if db != nil && schemaPath != "" { + schema, err := os.ReadFile(schemaPath) + if err != nil { + logger.Error("failed to read database schema file", "path", schemaPath, "error", err) + db.Close() + db = nil + } else { + if err := db.Initialize(string(schema)); err != nil { + logger.Error("failed to initialize database schema", "error", err) + db.Close() + db = nil + } else { + logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection) + defer func() { + logger.Info("closing database connection...") + if err := db.Close(); err != nil { + logger.Error("failed to close database", "error", err) + } else { + logger.Info("database connection closed") + } + }() + } + } + } + } + } + + // Setup WebSocket handler with authentication + wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue) + + // WebSocket endpoint - no middleware to avoid hijacking issues + mux.Handle("/ws", wsHandler) + mux.HandleFunc("/health", func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "OK\n") + }) + + // Database status endpoint + mux.HandleFunc("/db-status", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if db != nil { + // Test database connection with a simple query + var result struct { + Status string `json:"status"` + Type string `json:"type"` + Path string `json:"path"` + Message string `json:"message"` + } + result.Status = "connected" + result.Type = "sqlite" + result.Path = cfg.Database.Connection + result.Message = "SQLite database is operational" + + // Test a simple query to verify connectivity + if err := db.RecordSystemMetric("db_test", "ok"); err != nil { + result.Status = "error" + result.Message = fmt.Sprintf("Database query failed: %v", err) + } + + jsonBytes, _ := json.Marshal(result) + w.Write(jsonBytes) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`) + } + }) + + // Apply security middleware to all routes except WebSocket + // Create separate handlers for WebSocket vs other routes + var finalHandler http.Handler = mux + + // Wrap non-websocket routes with security middleware + finalHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ws" { + mux.ServeHTTP(w, r) + } else { + // Apply middleware chain for non-WebSocket routes + handler := sec.RateLimit(mux) + handler = middleware.SecurityHeaders(handler) + handler = middleware.CORS(handler) + handler = middleware.RequestTimeout(30 * time.Second)(handler) + + // Apply audit logger and IP whitelist only to non-WebSocket routes + handler = middleware.AuditLogger(handler) + if len(cfg.Security.IPWhitelist) > 0 { + handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler) + } + + handler.ServeHTTP(w, r) + } + }) + + var handler http.Handler = finalHandler + + server := &http.Server{ 
+ Addr: cfg.Server.Address, + Handler: handler, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + } + + if !cfg.Server.TLS.Enabled { + logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address) + } + + // Start server in goroutine + go func() { + // Setup TLS if configured + if cfg.Server.TLS.Enabled { + logger.Info("starting HTTPS server", "address", cfg.Server.Address) + if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed { + logger.Error("HTTPS server failed", "error", err) + } + } else { + logger.Info("starting HTTP server", "address", cfg.Server.Address) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Error("HTTP server failed", "error", err) + } + } + os.Exit(1) + }() + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + sig := <-sigChan + logger.Info("received shutdown signal", "signal", sig) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + logger.Info("shutting down http server...") + if err := server.Shutdown(ctx); err != nil { + logger.Error("server shutdown error", "error", err) + } else { + logger.Info("http server shutdown complete") + } + + logger.Info("api server stopped") + + _ = expManager // Use expManager to avoid unused warning + _ = apiKey // Will be used for auth later +} diff --git a/cmd/configlint/main.go b/cmd/configlint/main.go new file mode 100644 index 0000000..0b4befc --- /dev/null +++ b/cmd/configlint/main.go @@ -0,0 +1,116 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/xeipuuv/gojsonschema" + "gopkg.in/yaml.v3" +) + +func main() { + var ( + schemaPath string + failFast bool + ) + + flag.StringVar(&schemaPath, 
"schema", "configs/schema.yaml", "Path to JSON schema in YAML format") + flag.BoolVar(&failFast, "fail-fast", false, "Stop on first error") + flag.Parse() + + if flag.NArg() == 0 { + log.Fatalf("usage: configlint [--schema path] [--fail-fast] ") + } + + schemaLoader, err := loadSchema(schemaPath) + if err != nil { + log.Fatalf("failed to load schema: %v", err) + } + + var hadError bool + for _, configPath := range flag.Args() { + if err := validateConfig(schemaLoader, configPath); err != nil { + hadError = true + fmt.Fprintf(os.Stderr, "configlint: %s: %v\n", configPath, err) + if failFast { + os.Exit(1) + } + } + } + + if hadError { + os.Exit(1) + } + + fmt.Println("All configuration files are valid.") +} + +func loadSchema(schemaPath string) (gojsonschema.JSONLoader, error) { + data, err := os.ReadFile(schemaPath) + if err != nil { + return nil, err + } + + var schemaYAML interface{} + if err := yaml.Unmarshal(data, &schemaYAML); err != nil { + return nil, err + } + + schemaJSON, err := json.Marshal(schemaYAML) + if err != nil { + return nil, err + } + + tmpFile, err := os.CreateTemp("", "fetchml-schema-*.json") + if err != nil { + return nil, err + } + defer tmpFile.Close() + + if _, err := tmpFile.Write(schemaJSON); err != nil { + return nil, err + } + + return gojsonschema.NewReferenceLoader("file://" + filepath.ToSlash(tmpFile.Name())), nil +} + +func validateConfig(schemaLoader gojsonschema.JSONLoader, configPath string) error { + data, err := os.ReadFile(configPath) + if err != nil { + return err + } + + var configYAML interface{} + if err := yaml.Unmarshal(data, &configYAML); err != nil { + return fmt.Errorf("failed to parse YAML: %w", err) + } + + configJSON, err := json.Marshal(configYAML) + if err != nil { + return err + } + + result, err := gojsonschema.Validate(schemaLoader, gojsonschema.NewBytesLoader(configJSON)) + if err != nil { + return err + } + + if result.Valid() { + fmt.Printf("%s: valid\n", configPath) + return nil + } + + var builder 
strings.Builder + for _, issue := range result.Errors() { + builder.WriteString("- ") + builder.WriteString(issue.String()) + builder.WriteByte('\n') + } + + return fmt.Errorf("validation failed:\n%s", builder.String()) +} diff --git a/cmd/data_manager/data_manager_config.go b/cmd/data_manager/data_manager_config.go new file mode 100644 index 0000000..66e38a6 --- /dev/null +++ b/cmd/data_manager/data_manager_config.go @@ -0,0 +1,132 @@ +// DataConfig holds the configuration for the data manager +package main + +import ( + "fmt" + "os" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/config" + "gopkg.in/yaml.v3" +) + +type DataConfig struct { + // ML Server (where training runs) + MLHost string `yaml:"ml_host"` + MLUser string `yaml:"ml_user"` + MLSSHKey string `yaml:"ml_ssh_key"` + MLPort int `yaml:"ml_port"` + MLDataDir string `yaml:"ml_data_dir"` // e.g., /data/active + + // NAS (where datasets are stored) + NASHost string `yaml:"nas_host"` + NASUser string `yaml:"nas_user"` + NASSSHKey string `yaml:"nas_ssh_key"` + NASPort int `yaml:"nas_port"` + NASDataDir string `yaml:"nas_data_dir"` // e.g., /mnt/datasets + + // Redis + RedisAddr string `yaml:"redis_addr"` + RedisPassword string `yaml:"redis_password"` + RedisDB int `yaml:"redis_db"` + + // Authentication + Auth auth.AuthConfig `yaml:"auth"` + + // Cleanup settings + MaxAgeHours int `yaml:"max_age_hours"` // Delete data older than X hours + MaxSizeGB int `yaml:"max_size_gb"` // Keep total size under X GB + CleanupInterval int `yaml:"cleanup_interval_min"` // Run cleanup every X minutes + + // Podman integration + PodmanImage string `yaml:"podman_image"` + ContainerWorkspace string `yaml:"container_workspace"` + ContainerResults string `yaml:"container_results"` + GPUAccess bool `yaml:"gpu_access"` +} + +func LoadDataConfig(path string) (*DataConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var cfg DataConfig + if err := 
yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + // Defaults + if cfg.MLPort == 0 { + cfg.MLPort = config.DefaultSSHPort + } + if cfg.NASPort == 0 { + cfg.NASPort = config.DefaultSSHPort + } + if cfg.RedisAddr == "" { + cfg.RedisAddr = config.DefaultRedisAddr + } + // Set default MLDataDir - use ./data/active for local/dev, /data/active for production + if cfg.MLDataDir == "" { + if cfg.MLHost == "" { + // Local mode - use local data directory + cfg.MLDataDir = config.DefaultLocalDataDir + } else { + // Production mode - use /data/active + cfg.MLDataDir = config.DefaultDataDir + } + } + if cfg.NASDataDir == "" { + cfg.NASDataDir = config.DefaultNASDataDir + } + + // Expand paths + cfg.MLDataDir = config.ExpandPath(cfg.MLDataDir) + cfg.NASDataDir = config.ExpandPath(cfg.NASDataDir) + if cfg.MaxAgeHours == 0 { + cfg.MaxAgeHours = config.DefaultMaxAgeHours + } + if cfg.MaxSizeGB == 0 { + cfg.MaxSizeGB = config.DefaultMaxSizeGB + } + if cfg.CleanupInterval == 0 { + cfg.CleanupInterval = config.DefaultCleanupInterval + } + + return &cfg, nil +} + +// Validate implements utils.Validator interface +func (c *DataConfig) Validate() error { + if c.MLPort != 0 { + if err := config.ValidatePort(c.MLPort); err != nil { + return fmt.Errorf("invalid ML SSH port: %w", err) + } + } + + if c.NASPort != 0 { + if err := config.ValidatePort(c.NASPort); err != nil { + return fmt.Errorf("invalid NAS SSH port: %w", err) + } + } + + if c.RedisAddr != "" { + if err := config.ValidateRedisAddr(c.RedisAddr); err != nil { + return fmt.Errorf("invalid Redis configuration: %w", err) + } + } + + if c.MaxAgeHours < 1 { + return fmt.Errorf("max_age_hours must be at least 1, got %d", c.MaxAgeHours) + } + + if c.MaxSizeGB < 1 { + return fmt.Errorf("max_size_gb must be at least 1, got %d", c.MaxSizeGB) + } + + if c.CleanupInterval < 1 { + return fmt.Errorf("cleanup_interval must be at least 1, got %d", c.CleanupInterval) + } + + return nil +} diff --git 
a/cmd/data_manager/data_sync.go b/cmd/data_manager/data_sync.go new file mode 100644 index 0000000..03ee322 --- /dev/null +++ b/cmd/data_manager/data_sync.go @@ -0,0 +1,775 @@ +// data_manager.go - Fetch data from NAS to ML server on-demand +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "log/slog" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/errors" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/network" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/telemetry" +) + +// SSHClient alias for convenience +type SSHClient = network.SSHClient + +type DataManager struct { + config *DataConfig + mlServer *SSHClient + nasServer *SSHClient + taskQueue *queue.TaskQueue + ctx context.Context + cancel context.CancelFunc + logger *logging.Logger +} + +type DataFetchRequest struct { + JobName string `json:"job_name"` + Datasets []string `json:"datasets"` // Dataset names to fetch + Priority int `json:"priority"` + RequestedAt time.Time `json:"requested_at"` +} + +type DatasetInfo struct { + Name string `json:"name"` + SizeBytes int64 `json:"size_bytes"` + Location string `json:"location"` // "nas" or "ml" + LastAccess time.Time `json:"last_access"` +} + +func NewDataManager(cfg *DataConfig, apiKey string) (*DataManager, error) { + mlServer, err := network.NewSSHClient(cfg.MLHost, cfg.MLUser, cfg.MLSSHKey, cfg.MLPort, "") + if err != nil { + return nil, fmt.Errorf("ML server connection failed: %w", err) + } + defer func() { + if err != nil { + if closeErr := mlServer.Close(); closeErr != nil { + log.Printf("Warning: failed to close ML server connection: %v", closeErr) + } + } + }() + + nasServer, err := network.NewSSHClient(cfg.NASHost, cfg.NASUser, cfg.NASSSHKey, cfg.NASPort, "") + if err != nil { + return 
nil, fmt.Errorf("NAS connection failed: %w", err) + } + defer func() { + if err != nil { + if closeErr := nasServer.Close(); closeErr != nil { + log.Printf("Warning: failed to close NAS server connection: %v", closeErr) + } + } + }() + + // Create MLDataDir if it doesn't exist (for production without NAS) + if cfg.MLDataDir != "" { + if _, err := mlServer.Exec(fmt.Sprintf("mkdir -p %s", cfg.MLDataDir)); err != nil { + logger := logging.NewLogger(slog.LevelInfo, false) + logger.Job(context.Background(), "data_manager", "").Error("Failed to create ML data directory", "dir", cfg.MLDataDir, "error", err) + } + } + + // Setup Redis using internal queue + ctx, cancel := context.WithCancel(context.Background()) + logger := logging.NewLogger(slog.LevelInfo, false) + + var taskQueue *queue.TaskQueue + if cfg.RedisAddr != "" { + queueCfg := queue.Config{ + RedisAddr: cfg.RedisAddr, + RedisPassword: cfg.RedisPassword, + RedisDB: cfg.RedisDB, + } + + var err error + taskQueue, err = queue.NewTaskQueue(queueCfg) + if err != nil { + // FIXED: Check error return values for cleanup + if closeErr := mlServer.Close(); closeErr != nil { + logger.Warn("failed to close ML server during error cleanup", "error", closeErr) + } + if closeErr := nasServer.Close(); closeErr != nil { + logger.Warn("failed to close NAS server during error cleanup", "error", closeErr) + } + cancel() // Cancel context to prevent leak + return nil, fmt.Errorf("redis connection failed: %w", err) + } + } else { + taskQueue = nil // Local mode - no Redis + } + + return &DataManager{ + config: cfg, + mlServer: mlServer, + nasServer: nasServer, + taskQueue: taskQueue, + ctx: ctx, + cancel: cancel, + logger: logger, + }, nil +} + +func (dm *DataManager) FetchDataset(jobName, datasetName string) error { + ctx, cancel := context.WithTimeout(dm.ctx, 30*time.Minute) + defer cancel() + + return network.RetryForNetworkOperations(ctx, func() error { + return dm.fetchDatasetInternal(ctx, jobName, datasetName) + }) +} + +func 
(dm *DataManager) fetchDatasetInternal(ctx context.Context, jobName, datasetName string) error { + if err := container.ValidateJobName(datasetName); err != nil { + return &errors.DataFetchError{ + Dataset: datasetName, + JobName: jobName, + Err: fmt.Errorf("invalid dataset name: %w", err), + } + } + + logger := dm.logger.Job(ctx, jobName, "") + logger.Info("fetching dataset", "dataset", datasetName) + + // Validate dataset size and run cleanup if needed + if err := dm.ValidateDatasetWithCleanup(datasetName); err != nil { + return &errors.DataFetchError{ + Dataset: datasetName, + JobName: jobName, + Err: fmt.Errorf("dataset size validation failed: %w", err), + } + } + + nasPath := filepath.Join(dm.config.NASDataDir, datasetName) + mlPath := filepath.Join(dm.config.MLDataDir, datasetName) + + // Check if dataset exists on NAS + if !dm.nasServer.FileExists(nasPath) { + return &errors.DataFetchError{ + Dataset: datasetName, + JobName: jobName, + Err: fmt.Errorf("dataset not found on NAS"), + } + } + + // Check if already on ML server + if dm.mlServer.FileExists(mlPath) { + logger.Info("dataset already on ML server", "dataset", datasetName) + dm.updateLastAccess(datasetName) + return nil + } + + // Get size for progress tracking + size, err := dm.nasServer.GetFileSize(nasPath) + if err != nil { + logger.Warn("could not get dataset size", "dataset", datasetName, "error", err) + size = 0 + } + + sizeGB := float64(size) / (1024 * 1024 * 1024) + logger.Info("transferring dataset", + "dataset", datasetName, + "size_gb", sizeGB, + "nas_path", nasPath, + "ml_path", mlPath) + + if dm.taskQueue != nil { + redisClient := dm.taskQueue.GetRedisClient() + if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName), + "status", "transferring", + "job_name", jobName, + "size_bytes", size, + "started_at", time.Now().Unix()).Err(); err != nil { + logger.Warn("failed to record transfer start in Redis", "error", err) + } + } + + // Use local copy for local mode, 
rsync for remote mode + var rsyncCmd string + if dm.config.NASHost == "" || dm.config.NASUser == "" { + // Local mode - use cp + rsyncCmd = fmt.Sprintf("mkdir -p %s && cp -r %s %s/", dm.config.MLDataDir, nasPath, mlPath) + } else { + // Remote mode - use rsync over SSH + rsyncCmd = fmt.Sprintf( + "mkdir -p %s && rsync -avz --progress %s@%s:%s/ %s/", + dm.config.MLDataDir, + dm.config.NASUser, + dm.config.NASHost, + nasPath, + mlPath, + ) + } + + ioBefore, ioErr := telemetry.ReadProcessIO() + start := time.Now() + out, err := telemetry.ExecWithMetrics(dm.logger, "dataset transfer", time.Since(start), func() (string, error) { + return dm.nasServer.ExecContext(ctx, rsyncCmd) + }) + duration := time.Since(start) + + if err != nil { + logger.Error("transfer failed", + "dataset", datasetName, + "duration", duration, + "error", err, + "output", out) + + if ioErr == nil { + if after, readErr := telemetry.ReadProcessIO(); readErr == nil { + delta := telemetry.DiffIO(ioBefore, after) + logger.Debug("transfer io stats", + "dataset", datasetName, + "read_bytes", delta.ReadBytes, + "write_bytes", delta.WriteBytes) + } + } + + if dm.taskQueue != nil { + redisClient := dm.taskQueue.GetRedisClient() + if redisErr := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName), + "status", "failed", + "error", err.Error()).Err(); redisErr != nil { + logger.Warn("failed to record transfer failure in Redis", "error", redisErr) + } + } + return err + } + + logger.Info("transfer complete", + "dataset", datasetName, + "duration", duration, + "size_gb", sizeGB) + + if ioErr == nil { + if after, readErr := telemetry.ReadProcessIO(); readErr == nil { + delta := telemetry.DiffIO(ioBefore, after) + logger.Debug("transfer io stats", + "dataset", datasetName, + "read_bytes", delta.ReadBytes, + "write_bytes", delta.WriteBytes) + } + } + + if dm.taskQueue != nil { + redisClient := dm.taskQueue.GetRedisClient() + if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", 
datasetName), + "status", "completed", + "completed_at", time.Now().Unix(), + "duration_seconds", duration.Seconds()).Err(); err != nil { + logger.Warn("failed to record transfer completion in Redis", "error", err) + } + } + + // Track dataset metadata + dm.saveDatasetInfo(datasetName, size) + + return nil +} + +func (dm *DataManager) saveDatasetInfo(name string, size int64) { + if dm.taskQueue == nil { + return // Skip in local mode + } + + info := DatasetInfo{ + Name: name, + SizeBytes: size, + Location: "ml", + LastAccess: time.Now(), + } + + data, _ := json.Marshal(info) + if dm.taskQueue != nil { + redisClient := dm.taskQueue.GetRedisClient() + if err := redisClient.Set(dm.ctx, fmt.Sprintf("ml:dataset:%s", name), data, 0).Err(); err != nil { + dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis", + "dataset", name, "error", err) + } + } +} + +func (dm *DataManager) updateLastAccess(name string) { + if dm.taskQueue == nil { + return // Skip in local mode + } + + key := fmt.Sprintf("ml:dataset:%s", name) + redisClient := dm.taskQueue.GetRedisClient() + data, err := redisClient.Get(dm.ctx, key).Result() + if err != nil { + return + } + + var info DatasetInfo + if err := json.Unmarshal([]byte(data), &info); err != nil { + return + } + + info.LastAccess = time.Now() + newData, _ := json.Marshal(info) + redisClient = dm.taskQueue.GetRedisClient() + if err := redisClient.Set(dm.ctx, key, newData, 0).Err(); err != nil { + dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to update last access in Redis", + "dataset", name, "error", err) + } +} + +// ListDatasetsOnML returns a list of all datasets currently stored on the ML server. 
+func (dm *DataManager) ListDatasetsOnML() ([]DatasetInfo, error) { + out, err := dm.mlServer.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", dm.config.MLDataDir)) + if err != nil { + return nil, err + } + + var datasets []DatasetInfo + for name := range strings.SplitSeq(strings.TrimSpace(out), "\n") { + if name == "" { + continue + } + + var info DatasetInfo + + // Only use Redis if available + if dm.taskQueue != nil { + redisClient := dm.taskQueue.GetRedisClient() + key := fmt.Sprintf("ml:dataset:%s", name) + data, err := redisClient.Get(dm.ctx, key).Result() + + if err == nil { + if unmarshalErr := json.Unmarshal([]byte(data), &info); unmarshalErr != nil { + // Fallback to disk if unmarshal fails + size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name)) + info = DatasetInfo{ + Name: name, + SizeBytes: size, + Location: "ml", + } + } + } else { + // Fallback: get from disk + size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name)) + info = DatasetInfo{ + Name: name, + SizeBytes: size, + Location: "ml", + } + } + } else { + // Local mode: get from disk + size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name)) + info = DatasetInfo{ + Name: name, + SizeBytes: size, + Location: "ml", + } + } + + datasets = append(datasets, info) + } + + return datasets, nil +} + +func (dm *DataManager) CleanupOldData() error { + logger := dm.logger.Job(dm.ctx, "data_manager", "") + logger.Info("running data cleanup") + + datasets, err := dm.ListDatasetsOnML() + if err != nil { + return err + } + + var totalSize int64 + for _, ds := range datasets { + totalSize += ds.SizeBytes + } + + totalSizeGB := float64(totalSize) / (1024 * 1024 * 1024) + logger.Info("current storage usage", + "total_size_gb", totalSizeGB, + "dataset_count", len(datasets)) + + // Delete datasets older than max age or if over size limit + maxAge := time.Duration(dm.config.MaxAgeHours) * time.Hour + maxSize := int64(dm.config.MaxSizeGB) * 1024 * 1024 * 1024 + + 
var deleted []string + for _, ds := range datasets { + shouldDelete := false + + // Check age + if !ds.LastAccess.IsZero() && time.Since(ds.LastAccess) > maxAge { + logger.Info("dataset is old, marking for deletion", + "dataset", ds.Name, + "last_access", ds.LastAccess, + "age_hours", time.Since(ds.LastAccess).Hours()) + shouldDelete = true + } + + // Check if over size limit + if totalSize > maxSize { + logger.Info("over size limit, deleting oldest dataset", + "dataset", ds.Name, + "current_size_gb", totalSizeGB, + "max_size_gb", dm.config.MaxSizeGB) + shouldDelete = true + } + + if shouldDelete { + path := filepath.Join(dm.config.MLDataDir, ds.Name) + logger.Info("deleting dataset", "dataset", ds.Name, "path", path) + + if _, err := dm.mlServer.Exec(fmt.Sprintf("rm -rf %s", path)); err != nil { + logger.Error("failed to delete dataset", + "dataset", ds.Name, + "error", err) + continue + } + + deleted = append(deleted, ds.Name) + totalSize -= ds.SizeBytes + + // FIXED: Remove from Redis only if available, with error handling + if dm.taskQueue != nil { + redisClient := dm.taskQueue.GetRedisClient() + if err := redisClient.Del(dm.ctx, fmt.Sprintf("ml:dataset:%s", ds.Name)).Err(); err != nil { + logger.Warn("failed to delete dataset from Redis", + "dataset", ds.Name, + "error", err) + } + } + } + } + + if len(deleted) > 0 { + logger.Info("cleanup complete", + "deleted_count", len(deleted), + "deleted_datasets", deleted) + } else { + logger.Info("cleanup complete", "deleted_count", 0) + } + + return nil +} + +// GetAvailableDiskSpace returns available disk space in bytes +func (dm *DataManager) GetAvailableDiskSpace() int64 { + logger := dm.logger.Job(dm.ctx, "data_manager", "") + + // Check disk space on ML server + cmd := "df -k " + dm.config.MLDataDir + " | tail -1 | awk '{print $4}'" + output, err := dm.mlServer.Exec(cmd) + if err != nil { + logger.Error("failed to get disk space", "error", err) + return 0 + } + + // Parse KB to bytes + var freeKB int64 + _, err = 
fmt.Sscanf(strings.TrimSpace(output), "%d", &freeKB) + if err != nil { + logger.Error("failed to parse disk space", "error", err, "output", output) + return 0 + } + + return freeKB * 1024 // Convert KB to bytes +} + +// GetDatasetInfo returns information about a dataset from NAS +func (dm *DataManager) GetDatasetInfo(datasetName string) (*DatasetInfo, error) { + // Check if dataset exists on NAS + nasPath := filepath.Join(dm.config.NASDataDir, datasetName) + cmd := fmt.Sprintf("test -d %s && echo 'exists'", nasPath) + output, err := dm.nasServer.Exec(cmd) + if err != nil || strings.TrimSpace(output) != "exists" { + return nil, fmt.Errorf("dataset %s not found on NAS", datasetName) + } + + // Get dataset size + cmd = fmt.Sprintf("du -sb %s | cut -f1", nasPath) + output, err = dm.nasServer.Exec(cmd) + if err != nil { + return nil, fmt.Errorf("failed to get dataset size: %w", err) + } + + var sizeBytes int64 + _, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &sizeBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse dataset size: %w", err) + } + + // Get last modification time as proxy for last access + cmd = fmt.Sprintf("stat -c %%Y %s", nasPath) + output, err = dm.nasServer.Exec(cmd) + if err != nil { + return nil, fmt.Errorf("failed to get dataset timestamp: %w", err) + } + + var modTime int64 + _, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &modTime) + if err != nil { + return nil, fmt.Errorf("failed to parse timestamp: %w", err) + } + + return &DatasetInfo{ + Name: datasetName, + SizeBytes: sizeBytes, + Location: "nas", + LastAccess: time.Unix(modTime, 0), + }, nil +} + +// ValidateDatasetWithCleanup checks if dataset fits and runs cleanup if needed +func (dm *DataManager) ValidateDatasetWithCleanup(datasetName string) error { + logger := dm.logger.Job(dm.ctx, "data_manager", "") + + // Get dataset info + info, err := dm.GetDatasetInfo(datasetName) + if err != nil { + return fmt.Errorf("failed to get dataset info: %w", err) + } + + // Check 
current available space + availableSpace := dm.GetAvailableDiskSpace() + + logger.Info("dataset size validation", + "dataset", datasetName, + "dataset_size_gb", float64(info.SizeBytes)/(1024*1024*1024), + "available_gb", float64(availableSpace)/(1024*1024*1024)) + + // If enough space, proceed + if info.SizeBytes <= availableSpace { + logger.Info("sufficient space available", "dataset", datasetName) + return nil + } + + // Try cleanup first + logger.Info("insufficient space, running cleanup", + "dataset", datasetName, + "required_gb", float64(info.SizeBytes)/(1024*1024*1024), + "available_gb", float64(availableSpace)/(1024*1024*1024)) + + if err := dm.CleanupOldData(); err != nil { + return fmt.Errorf("cleanup failed: %w", err) + } + + // Check space again after cleanup + availableSpace = dm.GetAvailableDiskSpace() + logger.Info("space after cleanup", + "available_gb", float64(availableSpace)/(1024*1024*1024)) + + // If now enough space, proceed + if info.SizeBytes <= availableSpace { + logger.Info("cleanup freed enough space", "dataset", datasetName) + return nil + } + + // Still not enough space + return fmt.Errorf("dataset %s (%.2fGB) too large for available space (%.2fGB) even after cleanup", + datasetName, + float64(info.SizeBytes)/(1024*1024*1024), + float64(availableSpace)/(1024*1024*1024)) +} + +func (dm *DataManager) StartCleanupLoop() { + logger := dm.logger.Job(dm.ctx, "data_manager", "") + ticker := time.NewTicker(time.Duration(dm.config.CleanupInterval) * time.Minute) + go func() { + defer ticker.Stop() + for { + select { + case <-dm.ctx.Done(): + logger.Info("cleanup loop stopping") + return + case <-ticker.C: + if err := dm.CleanupOldData(); err != nil { + logger.Error("cleanup error", "error", err) + } + } + } + }() +} + +// Close gracefully shuts down the DataManager, stopping the cleanup loop and +// closing all connections to ML server, NAS server, and Redis. 
+func (dm *DataManager) Close() { + dm.cancel() // Cancel context to stop cleanup loop + + // Wait a moment for cleanup loop to finish + time.Sleep(100 * time.Millisecond) + + if dm.mlServer != nil { + if err := dm.mlServer.Close(); err != nil { + dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing ML server connection", "error", err) + } + } + if dm.nasServer != nil { + if err := dm.nasServer.Close(); err != nil { + dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing NAS server connection", "error", err) + } + } + if dm.taskQueue != nil { + if err := dm.taskQueue.Close(); err != nil { + dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing Redis connection", "error", err) + } + } +} + +func main() { + // Parse authentication flags + authFlags := auth.ParseAuthFlags() + if err := auth.ValidateAuthFlags(authFlags); err != nil { + log.Fatalf("Authentication flag error: %v", err) + } + + // Get API key from various sources + apiKey := auth.GetAPIKeyFromSources(authFlags) + + configFile := "configs/config-local.yaml" + if authFlags.ConfigFile != "" { + configFile = authFlags.ConfigFile + } + + // Parse command line args + if len(os.Args) < 2 { + fmt.Println("Usage:") + fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key ] fetch [dataset...]") + fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key ] list") + fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key ] cleanup") + fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key ] validate ") + fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key ] daemon") + fmt.Println() + auth.PrintAuthHelp() + os.Exit(1) + } + + // Check for --config flag + if len(os.Args) >= 3 && os.Args[1] == "--config" { + configFile = os.Args[2] + // Shift args + os.Args = append([]string{os.Args[0]}, os.Args[3:]...) 
+ } + + cfg, err := LoadDataConfig(configFile) + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // Validate authentication configuration + if err := cfg.Auth.ValidateAuthConfig(); err != nil { + log.Fatalf("Invalid authentication configuration: %v", err) + } + + // Validate configuration + if err := cfg.Validate(); err != nil { + log.Fatalf("Invalid configuration: %v", err) + } + + // Test authentication if enabled + if cfg.Auth.Enabled && apiKey != "" { + user, err := cfg.Auth.ValidateAPIKey(apiKey) + if err != nil { + log.Fatalf("Authentication failed: %v", err) + } + log.Printf("Data manager authenticated as user: %s (admin: %v)", user.Name, user.Admin) + } else if cfg.Auth.Enabled { + log.Fatal("Authentication required but no API key provided") + } + + dm, err := NewDataManager(cfg, apiKey) + if err != nil { + log.Fatalf("Failed to create data manager: %v", err) + } + defer dm.Close() + + cmd := os.Args[1] + + switch cmd { + case "fetch": + if len(os.Args) < 4 { + log.Fatal("Usage: data_manager fetch [dataset...]") + } + jobName := os.Args[2] + datasets := os.Args[3:] + + for _, dataset := range datasets { + if err := dm.FetchDataset(jobName, dataset); err != nil { + dm.logger.Job(context.Background(), jobName, "").Error("failed to fetch dataset", + "dataset", dataset, + "error", err) + } + } + + case "list": + datasets, err := dm.ListDatasetsOnML() + if err != nil { + log.Fatalf("Failed to list datasets: %v", err) + } + + fmt.Println("Datasets on ML server:") + fmt.Println("======================") + var totalSize int64 + for _, ds := range datasets { + sizeMB := float64(ds.SizeBytes) / (1024 * 1024) + lastAccess := "unknown" + if !ds.LastAccess.IsZero() { + lastAccess = ds.LastAccess.Format("2006-01-02 15:04:05") + } + fmt.Printf("%-30s %10.2f MB Last access: %s\n", ds.Name, sizeMB, lastAccess) + totalSize += ds.SizeBytes + } + fmt.Printf("\nTotal: %.2f GB\n", float64(totalSize)/(1024*1024*1024)) + + case "validate": + if len(os.Args) 
< 3 { + log.Fatal("Usage: data_manager validate ") + } + dataset := os.Args[2] + + fmt.Printf("Validating dataset: %s\n", dataset) + if err := dm.ValidateDatasetWithCleanup(dataset); err != nil { + log.Fatalf("Validation failed: %v", err) + } + fmt.Printf("✅ Dataset %s can be downloaded\n", dataset) + + case "cleanup": + if err := dm.CleanupOldData(); err != nil { + log.Fatalf("Cleanup failed: %v", err) + } + + case "daemon": + logger := dm.logger.Job(context.Background(), "data_manager", "") + logger.Info("starting data manager daemon") + dm.StartCleanupLoop() + logger.Info("cleanup configuration", + "interval_minutes", cfg.CleanupInterval, + "max_age_hours", cfg.MaxAgeHours, + "max_size_gb", cfg.MaxSizeGB) + + // Handle graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + sig := <-sigChan + logger.Info("received shutdown signal", "signal", sig) + dm.Close() + logger.Info("data manager shut down gracefully") + + default: + log.Fatalf("Unknown command: %s", cmd) + } +} diff --git a/cmd/tui/README.md b/cmd/tui/README.md new file mode 100644 index 0000000..7970661 --- /dev/null +++ b/cmd/tui/README.md @@ -0,0 +1,282 @@ +# FetchML TUI - Terminal User Interface + +An interactive terminal dashboard for managing ML experiments, monitoring system resources, and controlling job execution. 
+ +## Features + +### 📊 Real-time Monitoring +- **Job Status** - Track pending, running, finished, and failed jobs +- **GPU Metrics** - Monitor GPU utilization, memory, and temperature +- **Container Status** - View running Podman/Docker containers +- **Task Queue** - See queued tasks with priorities and status + +### 🎮 Interactive Controls +- **Queue Jobs** - Submit jobs with custom arguments and priorities +- **View Logs** - Real-time log viewing for running jobs +- **Cancel Tasks** - Stop running tasks +- **Delete Jobs** - Remove pending jobs +- **Mark Failed** - Manually mark stuck jobs as failed + +### ⚙️ Settings Management +- **API Key Configuration** - Set and update API keys on the fly +- **In-memory Storage** - Settings persist for the session + +### 🎨 Modern UI +- **Clean Design** - Dark-mode friendly with adaptive colors +- **Responsive Layout** - Adjusts to terminal size +- **Context-aware Help** - Shows relevant shortcuts for each view +- **Mouse Support** - Optional mouse navigation + +## Quick Start + +### Running the TUI + +```bash +# Using make (recommended) +make tui-dev # Dev mode (remote server) +make tui # Local mode + +# Direct execution with CLI config (TOML) +./bin/tui --config ~/.ml/config.toml + +# With custom TOML config +./bin/tui --config path/to/config.toml +``` + +### First Time Setup + +1. **Build the binary** + ```bash + make build + ``` + +2. **Get your API key** + ```bash + ./bin/user_manager --config configs/config_dev.yaml --cmd generate-key --username your_name + ``` + +3. 
**Launch the TUI** + ```bash + make tui-dev + ``` + +## Keyboard Shortcuts + +### Navigation +| Key | Action | +|-----|--------| +| `1` | Switch to Job List view | +| `g` | Switch to GPU Status view | +| `l` | View logs for selected job | +| `v` | Switch to Task Queue view | +| `o` | Switch to Container Status view | +| `s` | Open Settings | +| `h` or `?` | Toggle help screen | + +### Job Management +| Key | Action | +|-----|--------| +| `t` | Queue selected job (default args) | +| `a` | Queue job with custom arguments | +| `c` | Cancel running task | +| `d` | Delete pending job | +| `f` | Mark running job as failed | + +### System +| Key | Action | +|-----|--------| +| `r` | Refresh all data | +| `G` | Refresh GPU status only | +| `q` or `Ctrl+C` | Quit | + +### Settings View +| Key | Action | +|-----|--------| +| `↑`/`↓` or `j`/`k` | Navigate options | +| `Enter` | Select/Save | +| `Esc` | Exit settings | + +## Views + +### Job List (Default) +- Shows all jobs across all statuses +- Filter with `/` key +- Navigate with arrow keys or `j`/`k` +- Select and press `l` to view logs + +### GPU Status +- Real-time GPU metrics (nvidia-smi) +- macOS GPU info (system_profiler) +- Utilization, memory, temperature + +### Container Status +- Running Podman/Docker containers +- Container health and status +- System info (Podman/Docker version) + +### Task Queue +- All queued tasks with priorities +- Task status and creation time +- Running duration for active tasks + +### Logs +- Last 200 lines of job output +- Auto-scroll to bottom +- Refreshes with job status + +### Settings +- View current API key status +- Update API key +- Save configuration (in-memory) + +## Terminal Compatibility + +The TUI is built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) and works on all modern terminals: + +### ✅ Fully Supported +- **WezTerm** (recommended) +- **Alacritty** +- **Kitty** +- **iTerm2** (macOS) +- **Terminal.app** (macOS) +- **Windows Terminal** +- **GNOME 
Terminal** +- **Konsole** + +### ✅ Multiplexers +- **tmux** +- **screen** + +### Features +- ✅ 256 colors +- ✅ True color (24-bit) +- ✅ Mouse support +- ✅ Alt screen buffer +- ✅ Adaptive colors (light/dark themes) + + +### Key Components + +- **Model** - Pure data structures (State, Job, Task) +- **View** - Rendering functions (no business logic) +- **Controller** - Message handling and state updates +- **Services** - SSH/Redis communication + +## Configuration + +The TUI uses TOML configuration format for CLI settings: + +```toml +# ~/.ml/config.toml +worker_host = "localhost" +worker_user = "your_user" +worker_base = "~/ml_jobs" +worker_port = 22 +api_key = "your_api_key_here" +``` + +For CLI usage, run `ml init` to create a default configuration file. + +See [Configuration Documentation](../docs/documentation.md#configuration) for details. + +## Troubleshooting + +### TUI doesn't start +```bash +# Check if binary exists +ls -la bin/tui + +# Rebuild if needed +make build + +# Check CLI config +cat ~/.ml/config.toml +``` + +### Authentication errors +```bash +# Verify CLI config exists +ls -la ~/.ml/config.toml + +# Initialize CLI config if needed +ml init + +# Test connection +./bin/tui --config ~/.ml/config.toml +``` + +### Display issues +```bash +# Check terminal type +echo $TERM + +# Should be xterm-256color or similar +# If not, set it: +export TERM=xterm-256color +``` + +### Connection issues +```bash +# Test SSH connection +ssh your_user@your_server + +# Test Redis connection +redis-cli ping +``` + +## Development + +### Building +```bash +# Build TUI only +go build -o bin/tui ./cmd/tui + +# Build all binaries +make build +``` + +### Testing +```bash +# Run with verbose logging +./bin/tui --config ~/.ml/config.toml 2>tui.log + +# Check logs +tail -f tui.log +``` + +### Code Organization +- Keep files under 300 lines +- Separate concerns (MVC pattern) +- Use descriptive function names +- Add comments for complex logic + +## Tips & Tricks + +### Efficient 
Workflow +1. Keep TUI open in one terminal +2. Edit code in another terminal +3. Use `r` to refresh after changes +4. Use `h` to quickly reference shortcuts + +### Custom Arguments +When queuing jobs with `a`: +``` +--epochs 100 --lr 0.001 --priority 5 +``` + +### Monitoring +- Use `G` for quick GPU refresh (faster than `r`) +- Check queue with `v` before queuing new jobs +- Use `l` to debug failed jobs + +### Settings +- Update API key without restarting +- Changes are in-memory only +- Restart TUI to reset + +## See Also + +- [Main Documentation](../docs/documentation.md) +- [Worker Documentation](../cmd/worker/README.md) +- [Configuration Guide](../docs/documentation.md#configuration) +- [Bubble Tea Documentation](https://github.com/charmbracelet/bubbletea) diff --git a/cmd/tui/internal/config/cli_config.go b/cmd/tui/internal/config/cli_config.go new file mode 100644 index 0000000..80fba08 --- /dev/null +++ b/cmd/tui/internal/config/cli_config.go @@ -0,0 +1,492 @@ +package config + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/jfraeys/fetch_ml/internal/auth" + utils "github.com/jfraeys/fetch_ml/internal/config" + "github.com/stretchr/testify/assert/yaml" +) + +// CLIConfig represents the TOML config structure used by the CLI +type CLIConfig struct { + WorkerHost string `toml:"worker_host"` + WorkerUser string `toml:"worker_user"` + WorkerBase string `toml:"worker_base"` + WorkerPort int `toml:"worker_port"` + APIKey string `toml:"api_key"` + + // User context (filled after authentication) + CurrentUser *UserContext `toml:"-"` +} + +// UserContext represents the authenticated user information +type UserContext struct { + Name string `json:"name"` + Admin bool `json:"admin"` + Roles []string `json:"roles"` + Permissions map[string]bool `json:"permissions"` +} + +// LoadCLIConfig loads the CLI's TOML configuration from the provided path. +// If path is empty, ~/.ml/config.toml is used. The resolved path is returned. 
+// Automatically migrates from YAML config if TOML doesn't exist. +// Environment variables with FETCH_ML_CLI_ prefix override config file values. +func LoadCLIConfig(configPath string) (*CLIConfig, string, error) { + if configPath == "" { + home, err := os.UserHomeDir() + if err != nil { + return nil, "", fmt.Errorf("failed to get home directory: %w", err) + } + configPath = filepath.Join(home, ".ml", "config.toml") + } else { + configPath = utils.ExpandPath(configPath) + if !filepath.IsAbs(configPath) { + if abs, err := filepath.Abs(configPath); err == nil { + configPath = abs + } + } + } + + // Check if TOML config exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + // Try to migrate from YAML + yamlPath := strings.TrimSuffix(configPath, ".toml") + ".yaml" + if migratedPath, err := migrateFromYAML(yamlPath, configPath); err == nil { + log.Printf("Migrated configuration from %s to %s", yamlPath, migratedPath) + configPath = migratedPath + } else { + return nil, configPath, fmt.Errorf("CLI config not found at %s (run 'ml init' first)", configPath) + } + } else if err != nil { + return nil, configPath, fmt.Errorf("cannot access CLI config %s: %w", configPath, err) + } + + if err := auth.CheckConfigFilePermissions(configPath); err != nil { + log.Printf("Warning: %v", err) + } + + data, err := os.ReadFile(configPath) + if err != nil { + return nil, configPath, fmt.Errorf("failed to read CLI config: %w", err) + } + + config := &CLIConfig{} + if err := parseTOML(data, config); err != nil { + return nil, configPath, fmt.Errorf("failed to parse CLI config: %w", err) + } + + if err := config.Validate(); err != nil { + return nil, configPath, err + } + + // Apply environment variable overrides with FETCH_ML_CLI_ prefix + if host := os.Getenv("FETCH_ML_CLI_HOST"); host != "" { + config.WorkerHost = host + } + if user := os.Getenv("FETCH_ML_CLI_USER"); user != "" { + config.WorkerUser = user + } + if base := os.Getenv("FETCH_ML_CLI_BASE"); base != "" { + 
config.WorkerBase = base + } + if port := os.Getenv("FETCH_ML_CLI_PORT"); port != "" { + if p, err := parseInt(port); err == nil { + config.WorkerPort = p + } + } + if apiKey := os.Getenv("FETCH_ML_CLI_API_KEY"); apiKey != "" { + config.APIKey = apiKey + } + + // Also support legacy ML_ prefix for backward compatibility + if host := os.Getenv("ML_HOST"); host != "" && config.WorkerHost == "" { + config.WorkerHost = host + } + if user := os.Getenv("ML_USER"); user != "" && config.WorkerUser == "" { + config.WorkerUser = user + } + if base := os.Getenv("ML_BASE"); base != "" && config.WorkerBase == "" { + config.WorkerBase = base + } + if port := os.Getenv("ML_PORT"); port != "" && config.WorkerPort == 0 { + if p, err := parseInt(port); err == nil { + config.WorkerPort = p + } + } + if apiKey := os.Getenv("ML_API_KEY"); apiKey != "" && config.APIKey == "" { + config.APIKey = apiKey + } + + return config, configPath, nil +} + +// parseTOML is a simple TOML parser for the CLI config format +func parseTOML(data []byte, config *CLIConfig) error { + lines := strings.Split(string(data), "\n") + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Remove quotes if present + if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) { + value = value[1 : len(value)-1] + } + + switch key { + case "worker_host": + config.WorkerHost = value + case "worker_user": + config.WorkerUser = value + case "worker_base": + config.WorkerBase = value + case "worker_port": + if p, err := parseInt(value); err == nil { + config.WorkerPort = p + } + case "api_key": + config.APIKey = value + } + } + + return nil +} + +// ToTUIConfig converts CLI config to TUI config structure +func (c *CLIConfig) ToTUIConfig() *Config { + // Get smart defaults for 
current environment + smart := utils.GetSmartDefaults() + + tuiConfig := &Config{ + Host: c.WorkerHost, + User: c.WorkerUser, + Port: c.WorkerPort, + BasePath: c.WorkerBase, + + // Set defaults for TUI-specific fields using smart defaults + RedisAddr: smart.RedisAddr(), + RedisDB: 0, + PodmanImage: "ml-worker:latest", + ContainerWorkspace: utils.DefaultContainerWorkspace, + ContainerResults: utils.DefaultContainerResults, + GPUAccess: false, + } + + // Set up auth config with CLI API key + tuiConfig.Auth = auth.AuthConfig{ + Enabled: true, + APIKeys: map[auth.Username]auth.APIKeyEntry{ + "cli_user": { + Hash: auth.APIKeyHash(hashAPIKey(c.APIKey)), + Admin: true, + Roles: []string{"user", "admin"}, + Permissions: map[string]bool{ + "read": true, + "write": true, + "delete": true, + }, + }, + }, + } + + // Set known hosts path + tuiConfig.KnownHosts = smart.KnownHostsPath() + + return tuiConfig +} + +// Validate validates the CLI config +func (c *CLIConfig) Validate() error { + var errors []string + + if c.WorkerHost == "" { + errors = append(errors, "worker_host is required") + } else if len(strings.TrimSpace(c.WorkerHost)) == 0 { + errors = append(errors, "worker_host cannot be empty or whitespace") + } + + if c.WorkerUser == "" { + errors = append(errors, "worker_user is required") + } else if len(strings.TrimSpace(c.WorkerUser)) == 0 { + errors = append(errors, "worker_user cannot be empty or whitespace") + } + + if c.WorkerBase == "" { + errors = append(errors, "worker_base is required") + } else { + // Expand and validate path + c.WorkerBase = utils.ExpandPath(c.WorkerBase) + if !filepath.IsAbs(c.WorkerBase) { + errors = append(errors, "worker_base must be an absolute path") + } + } + + if c.WorkerPort == 0 { + errors = append(errors, "worker_port is required") + } else if err := utils.ValidatePort(c.WorkerPort); err != nil { + errors = append(errors, fmt.Sprintf("invalid worker_port: %v", err)) + } + + if c.APIKey == "" { + errors = append(errors, "api_key is 
required") + } else if len(c.APIKey) < 16 { + errors = append(errors, "api_key must be at least 16 characters") + } + + if len(errors) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(errors, "; ")) + } + + return nil +} + +// AuthenticateWithServer validates the API key and sets user context +func (c *CLIConfig) AuthenticateWithServer() error { + if c.APIKey == "" { + return fmt.Errorf("no API key configured") + } + + // Create temporary auth config for validation + authConfig := &auth.AuthConfig{ + Enabled: true, + APIKeys: map[auth.Username]auth.APIKeyEntry{ + "temp": { + Hash: auth.APIKeyHash(auth.HashAPIKey(c.APIKey)), + Admin: false, + }, + }, + } + + // Validate API key and get user info + user, err := authConfig.ValidateAPIKey(auth.HashAPIKey(c.APIKey)) + if err != nil { + return fmt.Errorf("API key validation failed: %w", err) + } + + // Set user context + c.CurrentUser = &UserContext{ + Name: user.Name, + Admin: user.Admin, + Roles: user.Roles, + Permissions: user.Permissions, + } + + return nil +} + +// CheckPermission checks if the current user has a specific permission +func (c *CLIConfig) CheckPermission(permission string) bool { + if c.CurrentUser == nil { + return false + } + + // Admin users have all permissions + if c.CurrentUser.Admin { + return true + } + + // Check explicit permission + if c.CurrentUser.Permissions[permission] { + return true + } + + // Check wildcard permission + if c.CurrentUser.Permissions["*"] { + return true + } + + return false +} + +// CanViewJob checks if user can view a specific job +func (c *CLIConfig) CanViewJob(jobUserID string) bool { + if c.CurrentUser == nil { + return false + } + + // Admin can view all jobs + if c.CurrentUser.Admin { + return true + } + + // Users can view their own jobs + return jobUserID == c.CurrentUser.Name +} + +// CanModifyJob checks if user can modify a specific job +func (c *CLIConfig) CanModifyJob(jobUserID string) bool { + if c.CurrentUser == nil { + return false + } + 
+ // Need jobs:update permission + if !c.CheckPermission("jobs:update") { + return false + } + + // Admin can modify all jobs + if c.CurrentUser.Admin { + return true + } + + // Users can only modify their own jobs + return jobUserID == c.CurrentUser.Name +} + +// migrateFromYAML migrates configuration from YAML to TOML format +func migrateFromYAML(yamlPath, tomlPath string) (string, error) { + // Check if YAML file exists + if _, err := os.Stat(yamlPath); os.IsNotExist(err) { + return "", fmt.Errorf("YAML config not found at %s", yamlPath) + } + + // Read YAML config + data, err := os.ReadFile(yamlPath) + if err != nil { + return "", fmt.Errorf("failed to read YAML config: %w", err) + } + + // Parse YAML to extract relevant fields + var yamlConfig map[string]interface{} + if err := yaml.Unmarshal(data, &yamlConfig); err != nil { + return "", fmt.Errorf("failed to parse YAML config: %w", err) + } + + // Create CLI config from YAML data + cliConfig := &CLIConfig{} + + // Extract values with fallbacks + if host, ok := yamlConfig["host"].(string); ok { + cliConfig.WorkerHost = host + } + if user, ok := yamlConfig["user"].(string); ok { + cliConfig.WorkerUser = user + } + if base, ok := yamlConfig["base_path"].(string); ok { + cliConfig.WorkerBase = base + } + if port, ok := yamlConfig["port"].(int); ok { + cliConfig.WorkerPort = port + } + + // Try to extract API key from auth section + if auth, ok := yamlConfig["auth"].(map[string]interface{}); ok { + if apiKeys, ok := auth["api_keys"].(map[string]interface{}); ok { + for _, keyEntry := range apiKeys { + if keyMap, ok := keyEntry.(map[string]interface{}); ok { + if hash, ok := keyMap["hash"].(string); ok { + cliConfig.APIKey = hash // Note: This is the hash, not the actual key + break + } + } + } + } + } + + // Validate migrated config + if err := cliConfig.Validate(); err != nil { + return "", fmt.Errorf("migrated config validation failed: %w", err) + } + + // Generate TOML content + tomlContent := fmt.Sprintf(`# 
Fetch ML CLI Configuration +# Migrated from YAML configuration + +worker_host = "%s" +worker_user = "%s" +worker_base = "%s" +worker_port = %d +api_key = "%s" +`, + cliConfig.WorkerHost, + cliConfig.WorkerUser, + cliConfig.WorkerBase, + cliConfig.WorkerPort, + cliConfig.APIKey, + ) + + // Create directory if it doesn't exist + if err := os.MkdirAll(filepath.Dir(tomlPath), 0755); err != nil { + return "", fmt.Errorf("failed to create config directory: %w", err) + } + + // Write TOML file + if err := os.WriteFile(tomlPath, []byte(tomlContent), 0600); err != nil { + return "", fmt.Errorf("failed to write TOML config: %w", err) + } + + return tomlPath, nil +} + +// ConfigExists checks if a CLI configuration file exists +func ConfigExists(configPath string) bool { + if configPath == "" { + home, err := os.UserHomeDir() + if err != nil { + return false + } + configPath = filepath.Join(home, ".ml", "config.toml") + } + + _, err := os.Stat(configPath) + return !os.IsNotExist(err) +} + +// GenerateDefaultConfig creates a default TOML configuration file +func GenerateDefaultConfig(configPath string) error { + // Create directory if it doesn't exist + if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + // Generate default configuration + defaultContent := `# Fetch ML CLI Configuration +# This file contains connection settings for the ML platform + +# Worker connection settings +worker_host = "localhost" # Hostname or IP of the worker +worker_user = "your_username" # SSH username for the worker +worker_base = "~/ml_jobs" # Base directory for ML jobs on worker +worker_port = 22 # SSH port (default: 22) + +# Authentication +api_key = "your_api_key_here" # Your API key (get from admin) + +# Environment variable overrides: +# ML_HOST, ML_USER, ML_BASE, ML_PORT, ML_API_KEY +` + + // Write configuration file + if err := os.WriteFile(configPath, []byte(defaultContent), 0600); err != nil { + 
return fmt.Errorf("failed to write config file: %w", err) + } + + // Set proper permissions + if err := auth.CheckConfigFilePermissions(configPath); err != nil { + log.Printf("Warning: %v", err) + } + + return nil +} + +func hashAPIKey(apiKey string) string { + if apiKey == "" { + return "" + } + return auth.HashAPIKey(apiKey) +} diff --git a/cmd/tui/internal/config/cli_config_test.go b/cmd/tui/internal/config/cli_config_test.go new file mode 100644 index 0000000..5de2ece --- /dev/null +++ b/cmd/tui/internal/config/cli_config_test.go @@ -0,0 +1,194 @@ +package config + +import ( + "testing" +) + +func TestCLIConfig_CheckPermission(t *testing.T) { + tests := []struct { + name string + config *CLIConfig + permission string + want bool + }{ + { + name: "Admin has all permissions", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "admin", + Admin: true, + }, + }, + permission: "any:permission", + want: true, + }, + { + name: "User with explicit permission", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user", + Admin: false, + Permissions: map[string]bool{"jobs:create": true}, + }, + }, + permission: "jobs:create", + want: true, + }, + { + name: "User without permission", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user", + Admin: false, + Permissions: map[string]bool{"jobs:read": true}, + }, + }, + permission: "jobs:create", + want: false, + }, + { + name: "No current user", + config: &CLIConfig{ + CurrentUser: nil, + }, + permission: "jobs:create", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.config.CheckPermission(tt.permission) + if got != tt.want { + t.Errorf("CheckPermission() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCLIConfig_CanViewJob(t *testing.T) { + tests := []struct { + name string + config *CLIConfig + jobUserID string + want bool + }{ + { + name: "Admin can view any job", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "admin", 
+ Admin: true, + }, + }, + jobUserID: "other_user", + want: true, + }, + { + name: "User can view own job", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user1", + Admin: false, + }, + }, + jobUserID: "user1", + want: true, + }, + { + name: "User cannot view other's job", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user1", + Admin: false, + }, + }, + jobUserID: "user2", + want: false, + }, + { + name: "No current user cannot view", + config: &CLIConfig{ + CurrentUser: nil, + }, + jobUserID: "user1", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.config.CanViewJob(tt.jobUserID) + if got != tt.want { + t.Errorf("CanViewJob() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCLIConfig_CanModifyJob(t *testing.T) { + tests := []struct { + name string + config *CLIConfig + jobUserID string + want bool + }{ + { + name: "Admin can modify any job", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "admin", + Admin: true, + Permissions: map[string]bool{"jobs:update": true}, + }, + }, + jobUserID: "other_user", + want: true, + }, + { + name: "User with permission can modify own job", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user1", + Admin: false, + Permissions: map[string]bool{"jobs:update": true}, + }, + }, + jobUserID: "user1", + want: true, + }, + { + name: "User without permission cannot modify", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user1", + Admin: false, + Permissions: map[string]bool{"jobs:read": true}, + }, + }, + jobUserID: "user1", + want: false, + }, + { + name: "User cannot modify other's job", + config: &CLIConfig{ + CurrentUser: &UserContext{ + Name: "user1", + Admin: false, + Permissions: map[string]bool{"jobs:update": true}, + }, + }, + jobUserID: "user2", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.config.CanModifyJob(tt.jobUserID) + if got != tt.want { 
+ t.Errorf("CanModifyJob() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/tui/internal/config/config.go b/cmd/tui/internal/config/config.go new file mode 100644 index 0000000..15b6e45 --- /dev/null +++ b/cmd/tui/internal/config/config.go @@ -0,0 +1,145 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/BurntSushi/toml" + "github.com/jfraeys/fetch_ml/internal/auth" + utils "github.com/jfraeys/fetch_ml/internal/config" +) + +// Config holds TUI configuration +type Config struct { + Host string `toml:"host"` + User string `toml:"user"` + SSHKey string `toml:"ssh_key"` + Port int `toml:"port"` + BasePath string `toml:"base_path"` + WrapperScript string `toml:"wrapper_script"` + TrainScript string `toml:"train_script"` + RedisAddr string `toml:"redis_addr"` + RedisPassword string `toml:"redis_password"` + RedisDB int `toml:"redis_db"` + KnownHosts string `toml:"known_hosts"` + + // Authentication + Auth auth.AuthConfig `toml:"auth"` + + // Podman settings + PodmanImage string `toml:"podman_image"` + ContainerWorkspace string `toml:"container_workspace"` + ContainerResults string `toml:"container_results"` + GPUAccess bool `toml:"gpu_access"` +} + +func LoadConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var cfg Config + if _, err := toml.Decode(string(data), &cfg); err != nil { + return nil, err + } + + // Get smart defaults for current environment + smart := utils.GetSmartDefaults() + + if cfg.Port == 0 { + cfg.Port = utils.DefaultSSHPort + } + if cfg.Host == "" { + cfg.Host = smart.Host() + } + if cfg.BasePath == "" { + cfg.BasePath = smart.BasePath() + } + // wrapper_script is deprecated - using secure_runner.py directly via Podman + if cfg.TrainScript == "" { + cfg.TrainScript = utils.DefaultTrainScript + } + if cfg.RedisAddr == "" { + cfg.RedisAddr = smart.RedisAddr() + } + if cfg.KnownHosts == "" { + cfg.KnownHosts = smart.KnownHostsPath() + } + + // 
Apply environment variable overrides with FETCH_ML_TUI_ prefix + if host := os.Getenv("FETCH_ML_TUI_HOST"); host != "" { + cfg.Host = host + } + if user := os.Getenv("FETCH_ML_TUI_USER"); user != "" { + cfg.User = user + } + if sshKey := os.Getenv("FETCH_ML_TUI_SSH_KEY"); sshKey != "" { + cfg.SSHKey = sshKey + } + if port := os.Getenv("FETCH_ML_TUI_PORT"); port != "" { + if p, err := parseInt(port); err == nil { + cfg.Port = p + } + } + if basePath := os.Getenv("FETCH_ML_TUI_BASE_PATH"); basePath != "" { + cfg.BasePath = basePath + } + if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" { + cfg.TrainScript = trainScript + } + if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" { + cfg.RedisAddr = redisAddr + } + if redisPassword := os.Getenv("FETCH_ML_TUI_REDIS_PASSWORD"); redisPassword != "" { + cfg.RedisPassword = redisPassword + } + if redisDB := os.Getenv("FETCH_ML_TUI_REDIS_DB"); redisDB != "" { + if db, err := parseInt(redisDB); err == nil { + cfg.RedisDB = db + } + } + if knownHosts := os.Getenv("FETCH_ML_TUI_KNOWN_HOSTS"); knownHosts != "" { + cfg.KnownHosts = knownHosts + } + + return &cfg, nil +} + +// Validate implements utils.Validator interface +func (c *Config) Validate() error { + if c.Port != 0 { + if err := utils.ValidatePort(c.Port); err != nil { + return fmt.Errorf("invalid SSH port: %w", err) + } + } + + if c.BasePath != "" { + // Convert relative paths to absolute + c.BasePath = utils.ExpandPath(c.BasePath) + if !filepath.IsAbs(c.BasePath) { + c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath) + } + } + + if c.RedisAddr != "" { + if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil { + return fmt.Errorf("invalid Redis configuration: %w", err) + } + } + + return nil +} + +func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") } +func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") } +func (c *Config) FinishedPath() string { 
return filepath.Join(c.BasePath, "finished") } +func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") } + +// parseInt parses a string to integer +func parseInt(s string) (int, error) { + var result int + _, err := fmt.Sscanf(s, "%d", &result) + return result, err +} diff --git a/cmd/tui/internal/controller/commands.go b/cmd/tui/internal/controller/commands.go new file mode 100644 index 0000000..81e5d4c --- /dev/null +++ b/cmd/tui/internal/controller/commands.go @@ -0,0 +1,384 @@ +package controller + +import ( + "fmt" + "path/filepath" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" +) + +// Message types for async operations +type ( + JobsLoadedMsg []model.Job + TasksLoadedMsg []*model.Task + GpuLoadedMsg string + ContainerLoadedMsg string + LogLoadedMsg string + QueueLoadedMsg string + SettingsContentMsg string + SettingsUpdateMsg struct{} + StatusMsg struct { + Text string + Level string + } + TickMsg time.Time +) + +// Command factories for loading data + +func (c *Controller) loadAllData() tea.Cmd { + return tea.Batch( + c.loadJobs(), + c.loadQueue(), + c.loadGPU(), + c.loadContainer(), + ) +} + +func (c *Controller) loadJobs() tea.Cmd { + return func() tea.Msg { + type jobResult struct { + jobs []model.Job + err error + } + + resultChan := make(chan jobResult, 1) + go func() { + var jobs []model.Job + statusChan := make(chan []model.Job, 4) + + for _, status := range []model.JobStatus{model.StatusPending, model.StatusRunning, model.StatusFinished, model.StatusFailed} { + go func(s model.JobStatus) { + path := c.getPathForStatus(s) + names := c.server.ListDir(path) + var statusJobs []model.Job + for _, name := range names { + jobStatus, _ := c.taskQueue.GetJobStatus(name) + taskID := jobStatus["task_id"] + priority := int64(0) + if p, ok := jobStatus["priority"]; ok { + _, err := fmt.Sscanf(p, "%d", &priority) + if err != nil { + priority = 0 + } + } + 
statusJobs = append(statusJobs, model.Job{ + Name: name, + Status: s, + TaskID: taskID, + Priority: priority, + }) + } + statusChan <- statusJobs + }(status) + } + + for range 4 { + jobs = append(jobs, <-statusChan...) + } + + resultChan <- jobResult{jobs: jobs, err: nil} + }() + + result := <-resultChan + if result.err != nil { + return StatusMsg{Text: "Failed to load jobs: " + result.err.Error(), Level: "error"} + } + return JobsLoadedMsg(result.jobs) + } +} + +func (c *Controller) loadQueue() tea.Cmd { + return func() tea.Msg { + tasks, err := c.taskQueue.GetQueuedTasks() + if err != nil { + c.logger.Error("failed to load queue", "error", err) + return StatusMsg{Text: "Failed to load queue: " + err.Error(), Level: "error"} + } + c.logger.Info("loaded queue", "task_count", len(tasks)) + return TasksLoadedMsg(tasks) + } +} + +func (c *Controller) loadGPU() tea.Cmd { + return func() tea.Msg { + type gpuResult struct { + content string + err error + } + + resultChan := make(chan gpuResult, 1) + go func() { + cmd := "nvidia-smi --query-gpu=index,name,utilization.gpu,memory.used,memory.total,temperature.gpu --format=csv,noheader,nounits" + out, err := c.server.Exec(cmd) + if err == nil && strings.TrimSpace(out) != "" { + var formatted strings.Builder + formatted.WriteString("GPU Status\n") + formatted.WriteString(strings.Repeat("═", 50) + "\n\n") + lines := strings.Split(strings.TrimSpace(out), "\n") + for _, line := range lines { + parts := strings.Split(line, ", ") + if len(parts) >= 6 { + formatted.WriteString(fmt.Sprintf("🎮 GPU %s: %s\n", parts[0], parts[1])) + formatted.WriteString(fmt.Sprintf(" Utilization: %s%%\n", parts[2])) + formatted.WriteString(fmt.Sprintf(" Memory: %s/%s MB\n", parts[3], parts[4])) + formatted.WriteString(fmt.Sprintf(" Temperature: %s°C\n\n", parts[5])) + } + } + c.logger.Info("loaded GPU status", "type", "nvidia") + resultChan <- gpuResult{content: formatted.String(), err: nil} + return + } + + cmd = "system_profiler SPDisplaysDataType | 
grep 'Chipset Model\\|VRAM' | head -2" + out, err = c.server.Exec(cmd) + if err != nil { + c.logger.Warn("GPU info unavailable", "error", err) + resultChan <- gpuResult{content: "⚠️ GPU info unavailable\n\nRun on a system with nvidia-smi or macOS GPU", err: err} + return + } + + var formatted strings.Builder + formatted.WriteString("GPU Status (macOS)\n") + formatted.WriteString(strings.Repeat("═", 50) + "\n\n") + lines := strings.Split(strings.TrimSpace(out), "\n") + for _, line := range lines { + if strings.Contains(line, "Chipset Model") || strings.Contains(line, "VRAM") { + formatted.WriteString("🎮 " + strings.TrimSpace(line) + "\n") + } + } + formatted.WriteString("\n💡 Note: nvidia-smi not available on macOS\n") + + c.logger.Info("loaded GPU status", "type", "macos") + resultChan <- gpuResult{content: formatted.String(), err: nil} + }() + + result := <-resultChan + return GpuLoadedMsg(result.content) + } +} + +func (c *Controller) loadContainer() tea.Cmd { + return func() tea.Msg { + resultChan := make(chan string, 1) + go func() { + var formatted strings.Builder + formatted.WriteString("Container Status\n") + formatted.WriteString(strings.Repeat("═", 50) + "\n\n") + + formatted.WriteString("📋 Configuration:\n") + formatted.WriteString(fmt.Sprintf(" Image: %s\n", c.config.PodmanImage)) + formatted.WriteString(fmt.Sprintf(" GPU: %v\n", c.config.GPUAccess)) + formatted.WriteString(fmt.Sprintf(" Workspace: %s\n", c.config.ContainerWorkspace)) + formatted.WriteString(fmt.Sprintf(" Results: %s\n\n", c.config.ContainerResults)) + + cmd := "podman ps -a --format '{{.Names}}|{{.Status}}|{{.Image}}'" + out, err := c.server.Exec(cmd) + if err == nil && strings.TrimSpace(out) != "" { + formatted.WriteString("🐳 Running Containers (Podman):\n") + lines := strings.Split(strings.TrimSpace(out), "\n") + for _, line := range lines { + parts := strings.Split(line, "|") + if len(parts) >= 3 { + status := "🟢" + if strings.Contains(parts[1], "Exited") { + status = "🔴" + } + 
formatted.WriteString(fmt.Sprintf(" %s %s\n", status, parts[0])) + formatted.WriteString(fmt.Sprintf(" Status: %s\n", parts[1])) + formatted.WriteString(fmt.Sprintf(" Image: %s\n\n", parts[2])) + } + } + } else { + cmd = "docker ps -a --format '{{.Names}}|{{.Status}}|{{.Image}}'" + out, err = c.server.Exec(cmd) + if err == nil && strings.TrimSpace(out) != "" { + formatted.WriteString("🐳 Running Containers (Docker):\n") + lines := strings.Split(strings.TrimSpace(out), "\n") + for _, line := range lines { + parts := strings.Split(line, "|") + if len(parts) >= 3 { + status := "🟢" + if strings.Contains(parts[1], "Exited") { + status = "🔴" + } + formatted.WriteString(fmt.Sprintf(" %s %s\n", status, parts[0])) + formatted.WriteString(fmt.Sprintf(" Status: %s\n", parts[1])) + formatted.WriteString(fmt.Sprintf(" Image: %s\n\n", parts[2])) + } + } + } else { + formatted.WriteString("⚠️ No containers found\n") + } + } + + formatted.WriteString("💻 System Info:\n") + if podmanVersion, err := c.server.Exec("podman --version"); err == nil { + formatted.WriteString(fmt.Sprintf(" Podman: %s\n", strings.TrimSpace(podmanVersion))) + } else if dockerVersion, err := c.server.Exec("docker --version"); err == nil { + formatted.WriteString(fmt.Sprintf(" Docker: %s\n", strings.TrimSpace(dockerVersion))) + } else { + formatted.WriteString(" ⚠️ Container engine not available\n") + } + + c.logger.Info("loaded container status") + resultChan <- formatted.String() + }() + + return ContainerLoadedMsg(<-resultChan) + } +} + +func (c *Controller) loadLog(jobName string) tea.Cmd { + return func() tea.Msg { + resultChan := make(chan string, 1) + go func() { + statusChan := make(chan string, 3) + + for _, status := range []model.JobStatus{model.StatusRunning, model.StatusFinished, model.StatusFailed} { + go func(s model.JobStatus) { + logPath := filepath.Join(c.getPathForStatus(s), jobName, "output.log") + if c.server.RemoteExists(logPath) { + content := c.server.TailFile(logPath, 200) + statusChan 
<- content
+				} else {
+					statusChan <- ""
+				}
+			}(status)
+		}
+
+		// Three probes race; the first non-empty log wins.
+		for range 3 {
+			result := <-statusChan
+			if result != "" {
+				var formatted strings.Builder
+				formatted.WriteString(fmt.Sprintf("📋 Log: %s\n", jobName))
+				formatted.WriteString(strings.Repeat("═", 60) + "\n\n")
+				formatted.WriteString(result)
+				resultChan <- formatted.String()
+				return
+			}
+		}
+
+		resultChan <- fmt.Sprintf("⚠️ No log found for %s\n\nJob may not have started yet.", jobName)
+	}()
+
+		return LogLoadedMsg(<-resultChan)
+	}
+}
+
+// shortID returns the first eight characters of a task ID for display, or
+// the ID unchanged when it is shorter. The original sliced id[:8]
+// unconditionally, which panics on IDs shorter than eight bytes.
+func shortID(id string) string {
+	if len(id) < 8 {
+		return id
+	}
+	return id[:8]
+}
+
+// isSafeJobName reports whether name may be interpolated into a remote
+// shell command. Job names come from the UI / remote directory listings;
+// without this allowlist the `rm -rf` and `mv` commands below are shell
+// injection vectors.
+func isSafeJobName(name string) bool {
+	if name == "" || name == "." || name == ".." {
+		return false
+	}
+	for _, r := range name {
+		switch {
+		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9':
+		case r == '-' || r == '_' || r == '.':
+		default:
+			return false
+		}
+	}
+	return true
+}
+
+// queueJob enqueues jobName with optional args. A "--priority N" flag
+// anywhere inside args overrides the default priority of 5.
+func (c *Controller) queueJob(jobName string, args string) tea.Cmd {
+	return func() tea.Msg {
+		resultChan := make(chan StatusMsg, 1)
+		go func() {
+			priority := int64(5)
+			// FIX: fmt.Sscanf only matches from the start of its input, so the
+			// old Sscanf(args, "--priority %d", ...) failed whenever any other
+			// flag preceded --priority. Scan from the flag's position instead.
+			if idx := strings.Index(args, "--priority"); idx >= 0 {
+				if _, err := fmt.Sscanf(args[idx:], "--priority %d", &priority); err != nil {
+					c.logger.Error("invalid priority argument", "args", args, "error", err)
+					resultChan <- StatusMsg{
+						Text:  fmt.Sprintf("Invalid priority: %v", err),
+						Level: "error",
+					}
+					return
+				}
+			}
+
+			task, err := c.taskQueue.EnqueueTask(jobName, args, priority)
+			if err != nil {
+				c.logger.Error("failed to queue job", "job_name", jobName, "error", err)
+				resultChan <- StatusMsg{
+					Text:  fmt.Sprintf("Failed to queue %s: %v", jobName, err),
+					Level: "error",
+				}
+				return
+			}
+
+			c.logger.Info("job queued", "job_name", jobName, "task_id", shortID(task.ID), "priority", priority)
+			resultChan <- StatusMsg{
+				Text:  fmt.Sprintf("✓ Queued: %s (ID: %s, P:%d)", jobName, shortID(task.ID), priority),
+				Level: "success",
+			}
+		}()
+
+		return <-resultChan
+	}
+}
+
+// deleteJob removes a pending job directory on the remote host.
+func (c *Controller) deleteJob(jobName string) tea.Cmd {
+	return func() tea.Msg {
+		// SECURITY FIX: reject names that could break out of the quoted
+		// shell argument below (spaces, quotes, $, ;, path traversal, ...).
+		if !isSafeJobName(jobName) {
+			return StatusMsg{Text: fmt.Sprintf("Refusing to delete suspicious job name %q", jobName), Level: "error"}
+		}
+		jobPath := filepath.Join(c.config.PendingPath(), jobName)
+		if _, err := c.server.Exec(fmt.Sprintf("rm -rf '%s'", jobPath)); err != nil {
+			return StatusMsg{Text: fmt.Sprintf("Failed to delete %s: %v", jobName, err), Level: "error"}
+		}
+		return StatusMsg{Text: fmt.Sprintf("✓ Deleted: %s", jobName), Level: "success"}
+	}
+}
+
+// markFailed moves a running job's directory into the failed bucket.
+func (c *Controller) markFailed(jobName string) tea.Cmd {
+	return func() tea.Msg {
+		// SECURITY FIX: same shell-injection guard as deleteJob.
+		if !isSafeJobName(jobName) {
+			return StatusMsg{Text: fmt.Sprintf("Refusing to move suspicious job name %q", jobName), Level: "error"}
+		}
+		src := filepath.Join(c.config.RunningPath(), jobName)
+		dst := filepath.Join(c.config.FailedPath(), jobName)
+		if _, err := c.server.Exec(fmt.Sprintf("mv '%s' '%s'", src, dst)); err != nil {
+			return StatusMsg{Text: fmt.Sprintf("Failed to mark failed: %v", err), Level: "error"}
+		}
+		return StatusMsg{Text: fmt.Sprintf("⚠ Marked failed: %s", jobName), Level: "warning"}
+	}
+}
+
+// cancelTask asks the queue service to cancel taskID.
+func (c *Controller) cancelTask(taskID string) tea.Cmd {
+	return func() tea.Msg {
+		if err := c.taskQueue.CancelTask(taskID); err != nil {
+			c.logger.Error("failed to cancel task", "task_id", shortID(taskID), "error", err)
+			return StatusMsg{Text: fmt.Sprintf("Cancel failed: %v", err), Level: "error"}
+		}
+		c.logger.Info("task cancelled", "task_id", shortID(taskID))
+		return StatusMsg{Text: fmt.Sprintf("✓ Cancelled: %s", shortID(taskID)), Level: "success"}
+	}
+}
+
+// showQueue renders the queued-task list from the already-loaded state.
+func (c *Controller) showQueue(m model.State) tea.Cmd {
+	return func() tea.Msg {
+		var content strings.Builder
+		content.WriteString("Task Queue\n")
+		content.WriteString(strings.Repeat("═", 60) + "\n\n")
+
+		if len(m.QueuedTasks) == 0 {
+			content.WriteString("📭 No tasks in queue\n")
+		} else {
+			for i, task := range m.QueuedTasks {
+				statusIcon := "⏳"
+				if task.Status == "running" {
+					statusIcon = "▶"
+				}
+
+				content.WriteString(fmt.Sprintf("%d. %s %s [ID: %s]\n",
+					i+1, statusIcon, task.JobName, shortID(task.ID)))
+				content.WriteString(fmt.Sprintf("   Priority: %d | Status: %s\n",
+					task.Priority, task.Status))
+				if task.Args != "" {
+					content.WriteString(fmt.Sprintf("   Args: %s\n", task.Args))
+				}
+				content.WriteString(fmt.Sprintf("   Created: %s\n",
+					task.CreatedAt.Format("2006-01-02 15:04:05")))
+
+				if task.StartedAt != nil {
+					duration := time.Since(*task.StartedAt)
+					content.WriteString(fmt.Sprintf("   Running for: %s\n",
+						duration.Round(time.Second)))
+				}
+				content.WriteString("\n")
+			}
+		}
+
+		return QueueLoadedMsg(content.String())
+	}
+}
+
+// tickCmd emits a TickMsg once per second to drive the spinner and the
+// auto-refresh check in Update.
+func tickCmd() tea.Cmd {
+	return tea.Tick(time.Second, func(t time.Time) tea.Msg {
+		return TickMsg(t)
+	})
+}
diff --git a/cmd/tui/internal/controller/controller.go b/cmd/tui/internal/controller/controller.go
new file mode 100644
index 0000000..ee510fa
--- /dev/null
+++ b/cmd/tui/internal/controller/controller.go
+package controller
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/charmbracelet/bubbles/key"
+	"github.com/charmbracelet/bubbles/list"
+	tea "github.com/charmbracelet/bubbletea"
+	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
+	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
+	"github.com/jfraeys/fetch_ml/cmd/tui/internal/services"
+	"github.com/jfraeys/fetch_ml/internal/logging"
+)
+
+// Controller handles all business logic and state updates
+type Controller struct {
+	config    *config.Config
+	server    *services.MLServer
+	taskQueue *services.TaskQueue
+	logger    *logging.Logger
+}
+
+// New creates a new Controller instance
+func New(cfg *config.Config, srv *services.MLServer, tq *services.TaskQueue, logger *logging.Logger) *Controller {
+	return &Controller{
+		config:    cfg,
+		server:    srv,
+		taskQueue: tq,
+		logger:    logger,
+	}
+}
+
+// Init initializes the TUI and returns initial commands
+func (c *Controller) Init() tea.Cmd {
+	return tea.Batch(
+		tea.SetWindowTitle("FetchML"),
+		c.loadAllData(),
+		tickCmd(),
+	)
+}
+
+// Update 
handles all messages and updates the state
+// Routing order matters: text-input modes consume keys first, then the
+// settings view, then global bindings; finally every bubble component
+// gets the raw message. State is a value, so mutations only persist via
+// the returned copy.
+func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
+	var cmds []tea.Cmd
+
+	switch msg := msg.(type) {
+	case tea.KeyMsg:
+		// Handle input mode (for queuing jobs with args)
+		if m.InputMode {
+			switch msg.String() {
+			case "enter":
+				args := m.Input.Value()
+				m.Input.SetValue("")
+				m.InputMode = false
+				if job := getSelectedJob(m); job != nil {
+					cmds = append(cmds, c.queueJob(job.Name, args))
+				}
+				return m, tea.Batch(cmds...)
+			case "esc":
+				m.InputMode = false
+				m.Input.SetValue("")
+				return m, nil
+			}
+			// Any other key is forwarded to the text input only.
+			var cmd tea.Cmd
+			m.Input, cmd = m.Input.Update(msg)
+			return m, cmd
+		}
+
+		// Handle settings-specific keys
+		if m.ActiveView == model.ViewModeSettings {
+			switch msg.String() {
+			case "up", "k":
+				if m.SettingsIndex > 1 { // Skip index 0 (Status)
+					m.SettingsIndex--
+					cmds = append(cmds, c.updateSettingsContent(m))
+					if m.SettingsIndex == 1 {
+						m.ApiKeyInput.Focus()
+					} else {
+						m.ApiKeyInput.Blur()
+					}
+				}
+			case "down", "j":
+				if m.SettingsIndex < 2 {
+					m.SettingsIndex++
+					cmds = append(cmds, c.updateSettingsContent(m))
+					if m.SettingsIndex == 1 {
+						m.ApiKeyInput.Focus()
+					} else {
+						m.ApiKeyInput.Blur()
+					}
+				}
+			case "enter":
+				if cmd := c.handleSettingsAction(&m); cmd != nil {
+					cmds = append(cmds, cmd)
+				}
+			case "esc":
+				m.ActiveView = model.ViewModeJobs
+				m.ApiKeyInput.Blur()
+			}
+			if m.SettingsIndex == 1 { // API Key input field
+				var cmd tea.Cmd
+				m.ApiKeyInput, cmd = m.ApiKeyInput.Update(msg)
+				cmds = append(cmds, cmd)
+				// Force update settings view to show typed characters immediately
+				cmds = append(cmds, c.updateSettingsContent(m))
+			}
+			return m, tea.Batch(cmds...)
+		}
+
+		// Handle global keys
+		switch {
+		case key.Matches(msg, m.Keys.Quit):
+			return m, tea.Quit
+		case key.Matches(msg, m.Keys.Refresh):
+			m.IsLoading = true
+			m.Status = "Refreshing all data..."
+			m.LastRefresh = time.Now()
+			cmds = append(cmds, c.loadAllData())
+		case key.Matches(msg, m.Keys.RefreshGPU):
+			m.Status = "Refreshing GPU status..."
+			cmds = append(cmds, c.loadGPU())
+		case key.Matches(msg, m.Keys.Trigger):
+			if job := getSelectedJob(m); job != nil {
+				cmds = append(cmds, c.queueJob(job.Name, ""))
+			}
+		case key.Matches(msg, m.Keys.TriggerArgs):
+			// job is only nil-checked here; the name is resolved again when
+			// the input-mode "enter" branch above fires.
+			if job := getSelectedJob(m); job != nil {
+				m.InputMode = true
+				m.Input.Focus()
+			}
+		case key.Matches(msg, m.Keys.ViewQueue):
+			m.ActiveView = model.ViewModeQueue
+			cmds = append(cmds, c.showQueue(m))
+		case key.Matches(msg, m.Keys.ViewContainer):
+			m.ActiveView = model.ViewModeContainer
+			cmds = append(cmds, c.loadContainer())
+		case key.Matches(msg, m.Keys.ViewGPU):
+			m.ActiveView = model.ViewModeGPU
+			cmds = append(cmds, c.loadGPU())
+		case key.Matches(msg, m.Keys.ViewJobs):
+			m.ActiveView = model.ViewModeJobs
+		case key.Matches(msg, m.Keys.ViewSettings):
+			m.ActiveView = model.ViewModeSettings
+			m.SettingsIndex = 1 // Start at Input field, skip Status
+			m.ApiKeyInput.Focus()
+			cmds = append(cmds, c.updateSettingsContent(m))
+		case key.Matches(msg, m.Keys.ViewExperiments):
+			m.ActiveView = model.ViewModeExperiments
+			cmds = append(cmds, c.loadExperiments())
+		case key.Matches(msg, m.Keys.Cancel):
+			if job := getSelectedJob(m); job != nil && job.TaskID != "" {
+				cmds = append(cmds, c.cancelTask(job.TaskID))
+			}
+		case key.Matches(msg, m.Keys.Delete):
+			// Only pending jobs may be deleted from the TUI.
+			if job := getSelectedJob(m); job != nil && job.Status == model.StatusPending {
+				cmds = append(cmds, c.deleteJob(job.Name))
+			}
+		case key.Matches(msg, m.Keys.MarkFailed):
+			// Only running jobs can be manually marked failed.
+			if job := getSelectedJob(m); job != nil && job.Status == model.StatusRunning {
+				cmds = append(cmds, c.markFailed(job.Name))
+			}
+		case key.Matches(msg, m.Keys.Help):
+			m.ShowHelp = !m.ShowHelp
+		}
+
+	case tea.WindowSizeMsg:
+		m.Width = msg.Width
+		m.Height = msg.Height
+
+		// Update component sizes
+		h, v := 4, 2 // docStyle.GetFrameSize() approx
+		listHeight := msg.Height - v - 8
+		m.JobList.SetSize(msg.Width/3-h, listHeight)
+
+		panelWidth := msg.Width*2/3 - h - 2
+		panelHeight := (listHeight - 6) / 3
+
+		m.GpuView.Width = panelWidth
+		m.GpuView.Height = panelHeight
+		m.ContainerView.Width = panelWidth
+		m.ContainerView.Height = panelHeight
+		m.QueueView.Width = panelWidth
+		m.QueueView.Height = listHeight - 4
+		m.SettingsView.Width = panelWidth
+		m.SettingsView.Height = listHeight - 4
+		m.ExperimentsView.Width = panelWidth
+		m.ExperimentsView.Height = listHeight - 4
+
+	case JobsLoadedMsg:
+		m.Jobs = []model.Job(msg)
+		calculateJobStats(&m)
+		items := make([]list.Item, len(m.Jobs))
+		for i, job := range m.Jobs {
+			items[i] = job
+		}
+		cmds = append(cmds, m.JobList.SetItems(items))
+		m.Status = formatStatus(m)
+		m.IsLoading = false
+
+	case TasksLoadedMsg:
+		m.QueuedTasks = []*model.Task(msg)
+		m.Status = formatStatus(m)
+
+	case GpuLoadedMsg:
+		m.GpuView.SetContent(string(msg))
+		m.GpuView.GotoTop()
+
+	case ContainerLoadedMsg:
+		m.ContainerView.SetContent(string(msg))
+		m.ContainerView.GotoTop()
+
+	case QueueLoadedMsg:
+		m.QueueView.SetContent(string(msg))
+		m.QueueView.GotoTop()
+
+	case SettingsContentMsg:
+		m.SettingsView.SetContent(string(msg))
+
+	case ExperimentsLoadedMsg:
+		m.ExperimentsView.SetContent(string(msg))
+		m.ExperimentsView.GotoTop()
+
+	case SettingsUpdateMsg:
+		// Settings content was updated, just trigger a re-render
+
+	case StatusMsg:
+		if msg.Level == "error" {
+			m.ErrorMsg = msg.Text
+			m.Status = "Error occurred - check status"
+		} else {
+			m.ErrorMsg = ""
+			m.Status = msg.Text
+		}
+
+	case TickMsg:
+		var spinCmd tea.Cmd
+		m.Spinner, spinCmd = m.Spinner.Update(msg)
+		cmds = append(cmds, spinCmd)
+
+		// Auto-refresh every 10 seconds
+		if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
+			m.LastRefresh = time.Now()
+			cmds = append(cmds, c.loadAllData())
+		}
+		cmds = append(cmds, tickCmd())
+
+	default:
+		var spinCmd tea.Cmd
+		m.Spinner, spinCmd = m.Spinner.Update(msg)
+		cmds = append(cmds, spinCmd)
+	}
+
+	// Update all bubble components
+	var cmd tea.Cmd
+	m.JobList, cmd = m.JobList.Update(msg)
+	cmds = append(cmds, cmd)
+
+	m.GpuView, cmd = m.GpuView.Update(msg)
+	cmds = append(cmds, cmd)
+
+	m.ContainerView, cmd = m.ContainerView.Update(msg)
+	cmds = append(cmds, cmd)
+
+	m.QueueView, cmd = m.QueueView.Update(msg)
+	cmds = append(cmds, cmd)
+
+	m.ExperimentsView, cmd = m.ExperimentsView.Update(msg)
+	cmds = append(cmds, cmd)
+
+	return m, tea.Batch(cmds...)
+}
+
+// ExperimentsLoadedMsg is sent when experiments are loaded
+type ExperimentsLoadedMsg string
+
+// loadExperiments fetches every experiment's details from the task queue
+// and renders them into one scrollable text blob.
+func (c *Controller) loadExperiments() tea.Cmd {
+	return func() tea.Msg {
+		commitIDs, err := c.taskQueue.ListExperiments()
+		if err != nil {
+			return StatusMsg{Level: "error", Text: fmt.Sprintf("Failed to list experiments: %v", err)}
+		}
+
+		if len(commitIDs) == 0 {
+			return ExperimentsLoadedMsg("Experiments:\n\nNo experiments found.")
+		}
+
+		var output string
+		output += "Experiments:\n\n"
+
+		for _, commitID := range commitIDs {
+			details, err := c.taskQueue.GetExperimentDetails(commitID)
+			if err != nil {
+				// Per-experiment failures are reported inline, not fatal.
+				output += fmt.Sprintf("Error loading %s: %v\n\n", commitID, err)
+				continue
+			}
+			output += details + "\n----------------------------------------\n\n"
+		}
+
+		return ExperimentsLoadedMsg(output)
+	}
+}
diff --git a/cmd/tui/internal/controller/helpers.go b/cmd/tui/internal/controller/helpers.go
new file mode 100644
index 0000000..a4e30ed
--- /dev/null
+++ b/cmd/tui/internal/controller/helpers.go
+package controller
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
+)
+
+// Helper functions
+
+// getPathForStatus maps a job status to its remote directory; "" for an
+// unknown status (callers pass it to ListDir, which then sees an empty path).
+func (c *Controller) getPathForStatus(status model.JobStatus) string {
+	switch status {
+	case model.StatusPending:
+		return c.config.PendingPath()
+	case model.StatusRunning:
+		return c.config.RunningPath()
+	case model.StatusFinished:
+		return c.config.FinishedPath()
+	case model.StatusFailed:
+		return c.config.FailedPath()
+	}
+	return ""
+}
+
+// getSelectedJob returns a copy of the currently highlighted list entry,
+// or nil when nothing is selected.
+func getSelectedJob(m model.State) *model.Job {
+	if item := m.JobList.SelectedItem(); item != nil {
+		if job, ok := item.(model.Job); ok {
+			return &job
+		}
+	}
+	return nil
+}
+
+// calculateJobStats recounts jobs per status into m.JobStats.
+func calculateJobStats(m *model.State) {
+	m.JobStats = make(map[model.JobStatus]int)
+	for _, job := range m.Jobs {
+		m.JobStats[job.Status]++
+	}
+}
+
+// formatStatus builds the status-bar summary: per-status counts, queue
+// depth and the last-refresh timestamp.
+func formatStatus(m model.State) string {
+	var parts []string
+
+	if len(m.Jobs) > 0 {
+		stats := []string{}
+		if count := m.JobStats[model.StatusPending]; count > 0 {
+			stats = append(stats, fmt.Sprintf("⏸ %d", count))
+		}
+		if count := m.JobStats[model.StatusRunning]; count > 0 {
+			stats = append(stats, fmt.Sprintf("▶ %d", count))
+		}
+		if count := m.JobStats[model.StatusFinished]; count > 0 {
+			stats = append(stats, fmt.Sprintf("✓ %d", count))
+		}
+		if count := m.JobStats[model.StatusFailed]; count > 0 {
+			stats = append(stats, fmt.Sprintf("✗ %d", count))
+		}
+		parts = append(parts, strings.Join(stats, " | "))
+	}
+
+	if len(m.QueuedTasks) > 0 {
+		parts = append(parts, fmt.Sprintf("Queue: %d", len(m.QueuedTasks)))
+	}
+
+	parts = append(parts, fmt.Sprintf("Updated: %s", m.LastRefresh.Format("15:04:05")))
+
+	return strings.Join(parts, " • ")
+}
diff --git a/cmd/tui/internal/controller/settings.go b/cmd/tui/internal/controller/settings.go
new file mode 100644
index 0000000..9c013b1
--- /dev/null
+++ b/cmd/tui/internal/controller/settings.go
+package controller
+
+import (
+	"fmt"
+	"strings"
+
+	tea "github.com/charmbracelet/bubbletea"
+	"github.com/charmbracelet/lipgloss"
+	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
+)
+
+// Settings-related command factories and handlers
+
+// settingsBox builds the bordered box style shared by every settings
+// section; the border is highlighted when section `index` is selected.
+// Extracted because the original duplicated this style three times.
+func settingsBox(m model.State, index int) lipgloss.Style {
+	style := lipgloss.NewStyle().
+		Border(lipgloss.NormalBorder()).
+		BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}). // borderfg
+		Padding(0, 1).
+		Width(m.SettingsView.Width - 4)
+	if m.SettingsIndex == index {
+		style = style.BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"}) // activeBorderfg
+	}
+	return style
+}
+
+// updateSettingsContent re-renders the settings panel for the current state.
+func (c *Controller) updateSettingsContent(m model.State) tea.Cmd {
+	var content strings.Builder
+
+	// API Key Status section
+	statusContent := fmt.Sprintf("%s API Key Status\n%s",
+		getSettingsIndicator(m, 0),
+		getAPIKeyStatus(m))
+	content.WriteString(settingsBox(m, 0).Render(statusContent))
+	content.WriteString("\n")
+
+	// API Key Input section
+	inputContent := fmt.Sprintf("%s Enter New API Key\n%s",
+		getSettingsIndicator(m, 1),
+		m.ApiKeyInput.View())
+	content.WriteString(settingsBox(m, 1).Render(inputContent))
+	content.WriteString("\n")
+
+	// Save Configuration section
+	saveContent := fmt.Sprintf("%s Save Configuration\n[Enter]",
+		getSettingsIndicator(m, 2))
+	content.WriteString(settingsBox(m, 2).Render(saveContent))
+	content.WriteString("\n")
+
+	// Current API Key display (always masked)
+	keyStyle := lipgloss.NewStyle().
+		Foreground(lipgloss.AdaptiveColor{Light: "#666", Dark: "#999"}).
+		Italic(true)
+	keyContent := fmt.Sprintf("Current API Key: %s", maskAPIKey(m.ApiKey))
+	content.WriteString(keyStyle.Render(keyContent))
+
+	return func() tea.Msg { return SettingsContentMsg(content.String()) }
+}
+
+// handleSettingsAction reacts to Enter on the selected settings section.
+func (c *Controller) handleSettingsAction(m *model.State) tea.Cmd {
+	switch m.SettingsIndex {
+	case 0: // API Key Status - display only
+		return nil
+	case 1: // Enter New API Key - Enter is intentionally a no-op here
+		return nil
+	case 2: // Save Configuration
+		if m.ApiKeyInput.Value() != "" {
+			m.ApiKey = m.ApiKeyInput.Value()
+			m.ApiKeyInput.SetValue("")
+			m.Status = "Configuration saved (in-memory only)"
+			return c.updateSettingsContent(*m)
+		} else if m.ApiKey != "" {
+			m.Status = "Configuration saved (in-memory only)"
+			// FIX: refresh the panel here too; the original only refreshed
+			// in the branch above, leaving a stale view.
+			return c.updateSettingsContent(*m)
+		} else {
+			m.ErrorMsg = "No API key to save"
+		}
+	}
+	return nil
+}
+
+// Helper functions for settings
+
+// getSettingsIndicator returns "▶" for the selected section, a space otherwise.
+func getSettingsIndicator(m model.State, index int) string {
+	if index == m.SettingsIndex {
+		return "▶"
+	}
+	return " "
+}
+
+// getAPIKeyStatus describes whether an API key is configured (masked).
+func getAPIKeyStatus(m model.State) string {
+	if m.ApiKey != "" {
+		return "✓ API Key is set\n" + maskAPIKey(m.ApiKey)
+	}
+	return "⚠ No API Key configured"
+}
+
+// maskAPIKey renders key for display without leaking it.
+// SECURITY FIX: the original revealed key[:4] + key[len-4:] for any key
+// longer than 8 characters, which leaks 8 of the 9 characters of a
+// 9-char key. Only partially reveal keys long enough for that to be safe.
+func maskAPIKey(key string) string {
+	if key == "" {
+		return "(not set)"
+	}
+	if len(key) <= 12 {
+		return "****"
+	}
+	return key[:4] + "****" + key[len(key)-4:]
+}
diff --git a/cmd/tui/internal/model/state.go b/cmd/tui/internal/model/state.go
new file mode 100644
index 0000000..179c34e
--- /dev/null
+++ b/cmd/tui/internal/model/state.go
+package model
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/charmbracelet/bubbles/key"
+	"github.com/charmbracelet/bubbles/list"
+	"github.com/charmbracelet/bubbles/spinner"
+	"github.com/charmbracelet/bubbles/textinput"
+	"github.com/charmbracelet/bubbles/viewport"
+	"github.com/charmbracelet/lipgloss"
+)
+
+// ViewMode selects which panel the TUI is currently showing.
+type ViewMode int
+
+const (
+	ViewModeJobs ViewMode = iota
+	ViewModeGPU
+	ViewModeQueue
+	ViewModeContainer
+	ViewModeSettings
+	ViewModeDatasets
+	ViewModeExperiments
+)
+
+type JobStatus string + +const ( + StatusPending JobStatus = "pending" + StatusQueued JobStatus = "queued" + StatusRunning JobStatus = "running" + StatusFinished JobStatus = "finished" + StatusFailed JobStatus = "failed" +) + +type Job struct { + Name string + Status JobStatus + TaskID string + Priority int64 +} + +func (j Job) Title() string { return j.Name } +func (j Job) Description() string { + icon := map[JobStatus]string{ + StatusPending: "⏸", + StatusQueued: "⏳", + StatusRunning: "▶", + StatusFinished: "✓", + StatusFailed: "✗", + }[j.Status] + pri := "" + if j.Priority > 0 { + pri = fmt.Sprintf(" [P%d]", j.Priority) + } + return fmt.Sprintf("%s %s%s", icon, j.Status, pri) +} +func (j Job) FilterValue() string { return j.Name } + +type Task struct { + ID string `json:"id"` + JobName string `json:"job_name"` + Args string `json:"args"` + Status string `json:"status"` + Priority int64 `json:"priority"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + EndedAt *time.Time `json:"ended_at,omitempty"` + Error string `json:"error,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type DatasetInfo struct { + Name string `json:"name"` + SizeBytes int64 `json:"size_bytes"` + Location string `json:"location"` + LastAccess time.Time `json:"last_access"` +} + +// State holds the application state +type State struct { + Jobs []Job + QueuedTasks []*Task + Datasets []DatasetInfo + JobList list.Model + GpuView viewport.Model + ContainerView viewport.Model + QueueView viewport.Model + SettingsView viewport.Model + DatasetView viewport.Model + ExperimentsView viewport.Model + Input textinput.Model + ApiKeyInput textinput.Model + Status string + ErrorMsg string + InputMode bool + Width int + Height int + ShowHelp bool + Spinner spinner.Model + ActiveView ViewMode + LastRefresh time.Time + IsLoading bool + JobStats map[JobStatus]int + ApiKey string + SettingsIndex int + Keys KeyMap +} + +type KeyMap struct { 
+ Refresh key.Binding + Trigger key.Binding + TriggerArgs key.Binding + ViewQueue key.Binding + ViewContainer key.Binding + ViewGPU key.Binding + ViewJobs key.Binding + ViewDatasets key.Binding + ViewExperiments key.Binding + ViewSettings key.Binding + Cancel key.Binding + Delete key.Binding + MarkFailed key.Binding + RefreshGPU key.Binding + Help key.Binding + Quit key.Binding +} + +var Keys = KeyMap{ + Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")), + Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")), + TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")), + ViewQueue: key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")), + ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")), + ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")), + ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")), + ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")), + ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")), + Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")), + Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")), + MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")), + RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")), + ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")), + Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")), + Quit: key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")), +} + +func InitialState(apiKey string) State { + items := []list.Item{} + delegate := list.NewDefaultDelegate() + delegate.Styles.SelectedTitle = delegate.Styles.SelectedTitle. + Foreground(lipgloss.Color("170")). 
+ Bold(true) + delegate.Styles.SelectedDesc = delegate.Styles.SelectedDesc. + Foreground(lipgloss.Color("246")) + + jobList := list.New(items, delegate, 0, 0) + jobList.Title = "ML Jobs & Queue" + jobList.SetShowStatusBar(true) + jobList.SetFilteringEnabled(true) + jobList.SetShowHelp(false) + // Styles will be set in View or here? + // Keeping style initialization here as it's part of the model state setup + jobList.Styles.Title = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.AdaptiveColor{Light: "#2980b9", Dark: "#7aa2f7"}). + Padding(0, 0, 1, 0) + + input := textinput.New() + input.Placeholder = "Args: --epochs 100 --lr 0.001 --priority 5" + input.Width = 60 + input.CharLimit = 200 + + apiKeyInput := textinput.New() + apiKeyInput.Placeholder = "Enter API key..." + apiKeyInput.Width = 40 + apiKeyInput.CharLimit = 200 + + s := spinner.New() + s.Spinner = spinner.Dot + s.Style = lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#2980b9", Dark: "#7aa2f7"}) + + return State{ + JobList: jobList, + GpuView: viewport.New(0, 0), + ContainerView: viewport.New(0, 0), + QueueView: viewport.New(0, 0), + SettingsView: viewport.New(0, 0), + DatasetView: viewport.New(0, 0), + ExperimentsView: viewport.New(0, 0), + Input: input, + ApiKeyInput: apiKeyInput, + Status: "Connected", + InputMode: false, + ShowHelp: false, + Spinner: s, + ActiveView: ViewModeJobs, + LastRefresh: time.Now(), + IsLoading: false, + JobStats: make(map[JobStatus]int), + ApiKey: apiKey, + SettingsIndex: 0, + Keys: Keys, + } +} diff --git a/cmd/tui/internal/services/services.go b/cmd/tui/internal/services/services.go new file mode 100644 index 0000000..fe382b5 --- /dev/null +++ b/cmd/tui/internal/services/services.go @@ -0,0 +1,237 @@ +package services + +import ( + "context" + "fmt" + + "github.com/jfraeys/fetch_ml/cmd/tui/internal/config" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" + "github.com/jfraeys/fetch_ml/internal/experiment" + 
"github.com/jfraeys/fetch_ml/internal/network" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility +type TaskQueue struct { + internal *queue.TaskQueue + expManager *experiment.Manager + ctx context.Context +} + +func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) { + // Create internal queue config + queueCfg := queue.Config{ + RedisAddr: cfg.RedisAddr, + RedisPassword: cfg.RedisPassword, + RedisDB: cfg.RedisDB, + } + + internalQueue, err := queue.NewTaskQueue(queueCfg) + if err != nil { + return nil, fmt.Errorf("failed to create task queue: %w", err) + } + + // Initialize experiment manager + // TODO: Get base path from config + expManager := experiment.NewManager("./experiments") + + return &TaskQueue{ + internal: internalQueue, + expManager: expManager, + ctx: context.Background(), + }, nil +} + +func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.Task, error) { + // Create internal task + internalTask := &queue.Task{ + JobName: jobName, + Args: args, + Priority: priority, + } + + // Use internal queue to enqueue + err := tq.internal.AddTask(internalTask) + if err != nil { + return nil, err + } + + // Convert to TUI model + return &model.Task{ + ID: internalTask.ID, + JobName: internalTask.JobName, + Args: internalTask.Args, + Status: "queued", + Priority: int64(internalTask.Priority), + CreatedAt: internalTask.CreatedAt, + Metadata: internalTask.Metadata, + }, nil +} + +func (tq *TaskQueue) GetNextTask() (*model.Task, error) { + internalTask, err := tq.internal.GetNextTask() + if err != nil { + return nil, err + } + if internalTask == nil { + return nil, nil + } + + // Convert to TUI model + return &model.Task{ + ID: internalTask.ID, + JobName: internalTask.JobName, + Args: internalTask.Args, + Status: internalTask.Status, + Priority: internalTask.Priority, + CreatedAt: internalTask.CreatedAt, + Metadata: internalTask.Metadata, + }, nil +} + +func (tq 
*TaskQueue) GetTask(taskID string) (*model.Task, error) { + internalTask, err := tq.internal.GetTask(taskID) + if err != nil { + return nil, err + } + + // Convert to TUI model + return &model.Task{ + ID: internalTask.ID, + JobName: internalTask.JobName, + Args: internalTask.Args, + Status: internalTask.Status, + Priority: internalTask.Priority, + CreatedAt: internalTask.CreatedAt, + Metadata: internalTask.Metadata, + }, nil +} + +func (tq *TaskQueue) UpdateTask(task *model.Task) error { + // Convert to internal task + internalTask := &queue.Task{ + ID: task.ID, + JobName: task.JobName, + Args: task.Args, + Status: task.Status, + Priority: task.Priority, + CreatedAt: task.CreatedAt, + Metadata: task.Metadata, + } + + return tq.internal.UpdateTask(internalTask) +} + +func (tq *TaskQueue) GetQueuedTasks() ([]*model.Task, error) { + internalTasks, err := tq.internal.GetAllTasks() + if err != nil { + return nil, err + } + + // Convert to TUI models + tasks := make([]*model.Task, len(internalTasks)) + for i, task := range internalTasks { + tasks[i] = &model.Task{ + ID: task.ID, + JobName: task.JobName, + Args: task.Args, + Status: task.Status, + Priority: task.Priority, + CreatedAt: task.CreatedAt, + Metadata: task.Metadata, + } + } + + return tasks, nil +} + +func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) { + // This method doesn't exist in internal queue, implement basic version + task, err := tq.internal.GetTaskByName(jobName) + if err != nil { + return nil, err + } + if task == nil { + return map[string]string{"status": "not_found"}, nil + } + + return map[string]string{ + "status": task.Status, + "task_id": task.ID, + }, nil +} + +func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error { + return tq.internal.RecordMetric(jobName, metric, value) +} + +func (tq *TaskQueue) GetMetrics(jobName string) (map[string]string, error) { + // This method doesn't exist in internal queue, return empty for now + return 
map[string]string{}, nil +} + +func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) { + // This method doesn't exist in internal queue, return empty for now + return []model.DatasetInfo{}, nil +} + +func (tq *TaskQueue) CancelTask(taskID string) error { + return tq.internal.CancelTask(taskID) +} + +func (tq *TaskQueue) ListExperiments() ([]string, error) { + return tq.expManager.ListExperiments() +} + +func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) { + meta, err := tq.expManager.ReadMetadata(commitID) + if err != nil { + return "", err + } + + metrics, err := tq.expManager.GetMetrics(commitID) + if err != nil { + return "", err + } + + output := fmt.Sprintf("Experiment: %s\n", meta.JobName) + output += fmt.Sprintf("Commit ID: %s\n", meta.CommitID) + output += fmt.Sprintf("User: %s\n", meta.User) + output += fmt.Sprintf("Timestamp: %d\n\n", meta.Timestamp) + output += "Metrics:\n" + + if len(metrics) == 0 { + output += " No metrics logged.\n" + } else { + for _, m := range metrics { + output += fmt.Sprintf(" %s: %.4f (Step: %d)\n", m.Name, m.Value, m.Step) + } + } + + return output, nil +} + +func (tq *TaskQueue) Close() error { + return tq.internal.Close() +} + +// MLServer wraps network.SSHClient for backward compatibility +type MLServer struct { + *network.SSHClient + addr string +} + +func NewMLServer(cfg *config.Config) (*MLServer, error) { + // Local mode: skip SSH entirely + if cfg.Host == "" { + client, _ := network.NewSSHClient("", "", "", 0, "") + return &MLServer{SSHClient: client, addr: "localhost"}, nil + } + + client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) + if err != nil { + return nil, err + } + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + return &MLServer{SSHClient: client, addr: addr}, nil +} diff --git a/cmd/tui/internal/view/view.go b/cmd/tui/internal/view/view.go new file mode 100644 index 0000000..b6b8151 --- /dev/null +++ 
b/cmd/tui/internal/view/view.go @@ -0,0 +1,255 @@ +package view + +import ( + "strings" + + "github.com/charmbracelet/lipgloss" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" +) + +const ( + headerfgLight = "#d35400" + headerfgDark = "#ff9e64" + activeBorderfgLight = "#3498db" + activeBorderfgDark = "#7aa2f7" + errorbgLight = "#fee" + errorbgDark = "#633" + errorfgLight = "#a00" + errorfgDark = "#faa" + titlefgLight = "#d35400" + titlefgDark = "#ff9e64" + statusfgLight = "#2e3440" + statusfgDark = "#d8dee9" + statusbgLight = "#e5e9f0" + statusbgDark = "#2e3440" + borderfgLight = "#d8dee9" + borderfgDark = "#4c566a" + helpfgLight = "#4c566a" + helpfgDark = "#88c0d0" +) + +var ( + docStyle = lipgloss.NewStyle().Margin(1, 2) + + activeBorderStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.AdaptiveColor{Light: activeBorderfgLight, Dark: activeBorderfgDark}). + Padding(1, 2) + + errorStyle = lipgloss.NewStyle(). + Background(lipgloss.AdaptiveColor{Light: errorbgLight, Dark: errorbgDark}). + Foreground(lipgloss.AdaptiveColor{Light: errorfgLight, Dark: errorfgDark}). + Padding(0, 1). + Bold(true) + + titleStyle = (lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.AdaptiveColor{Light: titlefgLight, Dark: titlefgDark}). + MarginBottom(1)) + + statusStyle = (lipgloss.NewStyle(). + Background(lipgloss.AdaptiveColor{Light: statusbgLight, Dark: statusbgDark}). + Foreground(lipgloss.AdaptiveColor{Light: statusfgLight, Dark: statusfgDark}). + Padding(0, 1)) + + borderStyle = (lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.AdaptiveColor{Light: borderfgLight, Dark: borderfgDark}). + Padding(0, 1)) + + helpStyle = (lipgloss.NewStyle(). + Foreground(lipgloss.AdaptiveColor{Light: helpfgLight, Dark: helpfgDark})) +) + +func Render(m model.State) string { + if m.Width == 0 { + return "Loading..." 
+ } + + // Title + title := titleStyle.Width(m.Width - 4).Render("🤖 ML Experiment Manager") + + // Left panel - Job list (30% width) + leftWidth := int(float64(m.Width) * 0.3) + leftPanel := getJobListPanel(m, leftWidth) + + // Right panel - Dynamic content (70% width) + rightWidth := m.Width - leftWidth - 4 + rightPanel := getRightPanel(m, rightWidth) + + // Main content + main := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) + + // Status bar + statusBar := getStatusBar(m) + + // Error bar (if present) + var errorBar string + if m.ErrorMsg != "" { + errorBar = errorStyle.Width(m.Width - 4).Render("⚠ Error: " + m.ErrorMsg) + } + + // Help view (toggleable) + var helpView string + if m.ShowHelp { + helpView = helpStyle.Width(m.Width-4). + Padding(1, 2). + Render(helpText(m)) + } + + // Quick help bar + quickHelp := helpStyle.Width(m.Width - 4).Render(getQuickHelp(m)) + + // Compose final layout + parts := []string{title, main, statusBar} + if errorBar != "" { + parts = append(parts, errorBar) + } + if helpView != "" { + parts = append(parts, helpView) + } + parts = append(parts, quickHelp) + + return docStyle.Render(lipgloss.JoinVertical(lipgloss.Left, parts...)) +} + +func getJobListPanel(m model.State, width int) string { + style := borderStyle + if m.ActiveView == model.ViewModeJobs { + style = activeBorderStyle + } + // Ensure the job list has proper dimensions to prevent rendering issues + // Note: We can't modify the model here as it's passed by value, + // but the View() method of list.Model uses its internal state. + // Ideally, the controller should have set the size. + // For now, we assume the controller handles resizing or we act on a copy. + // But list.Model.SetSize modifies the model. + // Since we receive 'm' by value, modifications to m.JobList won't persist. + // However, we need to render it with the correct size. + // So we can modify our local copy 'm'. 
+ h, v := style.GetFrameSize() + m.JobList.SetSize(width-h, m.Height-v-4) // Adjust height for title/help/status + + // Custom empty state + if len(m.JobList.Items()) == 0 { + return style.Width(width - h).Render( + lipgloss.JoinVertical(lipgloss.Left, + m.JobList.Styles.Title.Render(m.JobList.Title), + "\n No jobs found.", + " Press 't' to queue.", + ), + ) + } + + return style.Width(width - h).Render(m.JobList.View()) +} + +func getRightPanel(m model.State, width int) string { + var content string + var viewTitle string + style := borderStyle + + switch m.ActiveView { + case model.ViewModeGPU: + style = activeBorderStyle + viewTitle = "🎮 GPU Status" + content = m.GpuView.View() + case model.ViewModeContainer: + style = activeBorderStyle + viewTitle = "🐳 Container Status" + content = m.ContainerView.View() + case model.ViewModeQueue: + style = activeBorderStyle + viewTitle = "⏳ Task Queue" + content = m.QueueView.View() + case model.ViewModeSettings: + style = activeBorderStyle + viewTitle = "⚙️ Settings" + content = m.SettingsView.View() + case model.ViewModeExperiments: + style = activeBorderStyle + viewTitle = "🧪 Experiments" + content = m.ExperimentsView.View() + default: + viewTitle = "📊 System Overview" + content = getOverviewPanel(m) + } + + header := lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.AdaptiveColor{Light: headerfgLight, Dark: headerfgDark}). 
+ Render(viewTitle) + + h, _ := style.GetFrameSize() + return style.Width(width - h).Render(header + "\n\n" + content) +} + +func getOverviewPanel(m model.State) string { + var sections []string + + sections = append(sections, "🎮 GPU\n"+strings.Repeat("─", 40)) + sections = append(sections, m.GpuView.View()) + sections = append(sections, "\n🐳 Containers\n"+strings.Repeat("─", 40)) + sections = append(sections, m.ContainerView.View()) + + return strings.Join(sections, "\n") +} + +func getStatusBar(m model.State) string { + spinnerStr := m.Spinner.View() + if !m.IsLoading { + if m.ShowHelp { + spinnerStr = "?" + } else { + spinnerStr = "●" + } + } + statusText := m.Status + if m.ShowHelp { + statusText = "Press 'h' to hide help" + } + return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText) +} + +func helpText(m model.State) string { + if m.ActiveView == model.ViewModeSettings { + return `╔═══════════════════════════════════════════════════════════════╗ +║ Settings Shortcuts ║ +╠═══════════════════════════════════════════════════════════════╣ +║ Navigation ║ +║ j/k, ↑/↓ : Move selection ║ +║ Enter : Edit / Save ║ +║ Esc : Exit Settings ║ +║ ║ +║ General ║ +║ h or ? : Toggle this help q/Ctrl+C : Quit ║ +╚═══════════════════════════════════════════════════════════════╝` + } + + return `╔═══════════════════════════════════════════════════════════════╗ +║ Keyboard Shortcuts ║ +╠═══════════════════════════════════════════════════════════════╣ +║ Navigation ║ +║ j/k, ↑/↓ : Move selection / : Filter jobs ║ +║ 1 : Job list view 2 : Datasets view ║ +║ 3 : Experiments view v : Queue view ║ +║ g : GPU view o : Container view ║ +║ s : Settings view ║ +║ ║ +║ Actions ║ +║ t : Queue job a : Queue w/ args ║ +║ c : Cancel task d : Delete pending ║ +║ f : Mark as failed r : Refresh all ║ +║ G : Refresh GPU only ║ +║ ║ +║ General ║ +║ h or ? 
: Toggle this help q/Ctrl+C : Quit ║ +╚═══════════════════════════════════════════════════════════════╝` +} + +func getQuickHelp(m model.State) string { + if m.ActiveView == model.ViewModeSettings { + return " ↑/↓:move enter:select esc:exit settings q:quit" + } + return " h:help 1:jobs 2:datasets 3:experiments v:queue g:gpu o:containers s:settings t:queue r:refresh q:quit" +} diff --git a/cmd/tui/main.go b/cmd/tui/main.go new file mode 100644 index 0000000..d906977 --- /dev/null +++ b/cmd/tui/main.go @@ -0,0 +1,204 @@ +// Package main implements the ML TUI +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + + tea "github.com/charmbracelet/bubbletea" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/config" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/controller" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/services" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/view" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/logging" +) + +type AppModel struct { + state model.State + controller *controller.Controller +} + +func (m AppModel) Init() tea.Cmd { + return m.controller.Init() +} + +func (m AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + newState, cmd := m.controller.Update(msg, m.state) + m.state = newState + return m, cmd +} + +func (m AppModel) View() string { + return view.Render(m.state) +} + +func main() { + // Parse authentication flags + authFlags := auth.ParseAuthFlags() + if err := auth.ValidateAuthFlags(authFlags); err != nil { + log.Fatalf("Authentication flag error: %v", err) + } + + // Get API key from various sources + apiKey := auth.GetAPIKeyFromSources(authFlags) + + var ( + cfg *config.Config + cliConfig *config.CLIConfig + cliConfPath string + ) + + configFlag := authFlags.ConfigFile + + // Only support TOML configuration + var err error + cliConfig, cliConfPath, err = config.LoadCLIConfig(configFlag) + if err != nil { + if 
configFlag != "" { + log.Fatalf("Failed to load TOML config %s: %v", configFlag, err) + } else { + // Provide helpful error message for data scientists + log.Printf("=== Fetch ML TUI - Configuration Required ===") + log.Printf("") + log.Printf("Error: %v", err) + log.Printf("") + log.Printf("To get started with the TUI, you need to initialize your configuration:") + log.Printf("") + log.Printf("Option 1: Using the Zig CLI (Recommended)") + log.Printf(" 1. Build the CLI: cd cli && make build") + log.Printf(" 2. Initialize config: ./cli/zig-out/bin/ml init") + log.Printf(" 3. Edit ~/.ml/config.toml with your settings") + log.Printf(" 4. Run TUI: ./bin/tui") + log.Printf("") + log.Printf("Option 2: Manual Configuration") + log.Printf(" 1. Create directory: mkdir -p ~/.ml") + log.Printf(" 2. Create config: touch ~/.ml/config.toml") + log.Printf(" 3. Add your settings to the file") + log.Printf(" 4. Run TUI: ./bin/tui") + log.Printf("") + log.Printf("Example ~/.ml/config.toml:") + log.Printf(" worker_host = \"localhost\"") + log.Printf(" worker_user = \"your_username\"") + log.Printf(" worker_base = \"~/ml_jobs\"") + log.Printf(" worker_port = 22") + log.Printf(" api_key = \"your_api_key_here\"") + log.Printf("") + log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs") + os.Exit(1) + } + } + + cfg = cliConfig.ToTUIConfig() + log.Printf("Loaded TOML configuration from %s", cliConfPath) + + // Validate authentication configuration + if err := cfg.Auth.ValidateAuthConfig(); err != nil { + log.Fatalf("Invalid authentication configuration: %v", err) + } + + if err := cfg.Validate(); err != nil { + log.Fatalf("Invalid configuration: %v", err) + } + + // Test authentication if enabled + if cfg.Auth.Enabled { + // Use API key from CLI config if available, otherwise use from flags + var effectiveAPIKey string + if cliConfig != nil && cliConfig.APIKey != "" { + effectiveAPIKey = cliConfig.APIKey + } else if apiKey != "" { + effectiveAPIKey = apiKey + } else { 
+ log.Fatal("Authentication required but no API key provided") + } + + if _, err := cfg.Auth.ValidateAPIKey(effectiveAPIKey); err != nil { + log.Fatalf("Authentication failed: %v", err) + } + } + + srv, err := services.NewMLServer(cfg) + if err != nil { + log.Fatalf("Failed to connect to server: %v", err) + } + defer func() { + if err := srv.Close(); err != nil { + log.Printf("server close error: %v", err) + } + }() + + tq, err := services.NewTaskQueue(cfg) + if err != nil { + log.Fatalf("Failed to connect to Redis: %v", err) + } + defer func() { + if err := tq.Close(); err != nil { + log.Printf("task queue close error: %v", err) + } + }() + + // Initialize logger + // Note: In original code, logger was created inside initialModel. + // Here we create it and pass it to controller. + // We use slog.LevelError as default from original code. + // But original code imported "log/slog". + // We use internal/logging package. + // Check logging package signature. + // Original: logger := logging.NewLogger(slog.LevelError, false) + // We need to import "log/slog" in main if we use slog constants. + // Or use logging package constants if available. + // Let's check logging package. + // Assuming logging.NewLogger takes (slog.Level, bool). + // I'll import "log/slog". + + // Wait, I need to import "log/slog" + logger := logging.NewLogger(-4, false) // -4 is slog.LevelError value. Or I can import log/slog. 
+ + // Initialize State and Controller + var effectiveAPIKey string + if cliConfig != nil && cliConfig.APIKey != "" { + effectiveAPIKey = cliConfig.APIKey + } else { + effectiveAPIKey = apiKey + } + + initialState := model.InitialState(effectiveAPIKey) + ctrl := controller.New(cfg, srv, tq, logger) + + appModel := AppModel{ + state: initialState, + controller: ctrl, + } + + // Run TUI app + p := tea.NewProgram(appModel, tea.WithAltScreen(), tea.WithMouseAllMotion()) + + // Ensure we restore the terminal even if panic or error occurs + // Note: p.Run() usually handles this, but explicit cleanup is safer + // if we want to ensure the alt screen is exited. + // We can't defer p.ReleaseTerminal() here because p is created here. + // But we can defer a function that calls it. + + // Set up signal handling for graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Run program and handle signals + go func() { + <-sigChan + p.Quit() + }() + + if _, err := p.Run(); err != nil { + // Attempt to restore terminal before logging fatal error + p.ReleaseTerminal() + log.Fatalf("Error running TUI: %v", err) + } + + // Explicitly restore terminal after program exits + p.ReleaseTerminal() +} diff --git a/cmd/user_manager/main.go b/cmd/user_manager/main.go new file mode 100644 index 0000000..428b140 --- /dev/null +++ b/cmd/user_manager/main.go @@ -0,0 +1,175 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "github.com/jfraeys/fetch_ml/internal/auth" + "gopkg.in/yaml.v3" +) + +type ConfigWithAuth struct { + Auth auth.AuthConfig `yaml:"auth"` +} + +func main() { + var ( + configFile = flag.String("config", "", "Configuration file path") + command = flag.String("cmd", "", "Command: generate-key, list-users, hash-key") + username = flag.String("username", "", "Username for generate-key") + role = flag.String("role", "", "Role for generate-key") + admin = flag.Bool("admin", false, "Admin flag for 
generate-key") + apiKey = flag.String("key", "", "API key to hash") + ) + flag.Parse() + + if *configFile == "" || *command == "" { + fmt.Println("Usage: user_manager --config --cmd [options]") + fmt.Println("Commands: generate-key, list-users, hash-key") + os.Exit(1) + } + + switch *command { + case "generate-key": + if *username == "" { + log.Fatal("Usage: --cmd generate-key --username [--admin] [--role ]") + } + + // Load config + data, err := os.ReadFile(*configFile) + if err != nil { + log.Fatalf("Failed to read config: %v", err) + } + + var config ConfigWithAuth + if err := yaml.Unmarshal(data, &config); err != nil { + log.Fatalf("Failed to parse config: %v", err) + } + + // Generate API key + apiKey := auth.GenerateAPIKey() + + // Setup user + if config.Auth.APIKeys == nil { + config.Auth.APIKeys = make(map[auth.Username]auth.APIKeyEntry) + } + + adminStatus := *admin + roles := []string{"viewer"} + permissions := make(map[string]bool) + + if !adminStatus && *role == "" { + fmt.Printf("Make user '%s' an admin? 
(y/N): ", *username) + var response string + fmt.Scanln(&response) + adminStatus = strings.ToLower(strings.TrimSpace(response)) == "y" + } + + if adminStatus { + roles = []string{"admin"} + permissions["*"] = true + } else if *role != "" { + roles = []string{*role} + rolePerms := getRolePermissions(*role) + for perm, value := range rolePerms { + permissions[perm] = value + } + } + + // Save user + config.Auth.APIKeys[auth.Username(*username)] = auth.APIKeyEntry{ + Hash: auth.APIKeyHash(auth.HashAPIKey(apiKey)), + Admin: adminStatus, + Roles: roles, + Permissions: permissions, + } + + data, err = yaml.Marshal(config) + if err != nil { + log.Fatalf("Failed to marshal config: %v", err) + } + + if err := os.WriteFile(*configFile, data, 0600); err != nil { + log.Fatalf("Failed to write config: %v", err) + } + + fmt.Printf("Generated API key for user '%s':\nKey: %s\n", *username, apiKey) + + case "list-users": + data, err := os.ReadFile(*configFile) + if err != nil { + log.Fatalf("Failed to read config: %v", err) + } + + var config ConfigWithAuth + if err := yaml.Unmarshal(data, &config); err != nil { + log.Fatalf("Failed to parse config: %v", err) + } + + fmt.Println("Configured Users:") + fmt.Println("=================") + for username, entry := range config.Auth.APIKeys { + fmt.Printf("User: %s\n", string(username)) + fmt.Printf(" Admin: %v\n", entry.Admin) + if len(entry.Roles) > 0 { + fmt.Printf(" Roles: %v\n", entry.Roles) + } + if len(entry.Permissions) > 0 { + fmt.Printf(" Permissions: %d\n", len(entry.Permissions)) + } + fmt.Printf(" Key Hash: %s...\n\n", string(entry.Hash)[:8]) + } + + case "hash-key": + if *apiKey == "" { + log.Fatal("Usage: --cmd hash-key --key ") + } + hash := auth.HashAPIKey(*apiKey) + fmt.Printf("Hash: %s\n", hash) + + default: + log.Fatalf("Unknown command: %s", *command) + } +} + +// getRolePermissions returns permissions for a role +func getRolePermissions(role string) map[string]bool { + rolePermissions := map[string]map[string]bool{ + 
"admin": { + "*": true, + }, + "data_scientist": { + "jobs:create": true, + "jobs:read": true, + "jobs:update": true, + "data:read": true, + "models:read": true, + }, + "data_engineer": { + "data:create": true, + "data:read": true, + "data:update": true, + "data:delete": true, + }, + "viewer": { + "jobs:read": true, + "data:read": true, + "models:read": true, + "metrics:read": true, + }, + "operator": { + "jobs:read": true, + "jobs:update": true, + "metrics:read": true, + "system:read": true, + }, + } + + if perms, exists := rolePermissions[role]; exists { + return perms + } + return make(map[string]bool) +} diff --git a/cmd/worker/worker_config.go b/cmd/worker/worker_config.go new file mode 100644 index 0000000..7aabf13 --- /dev/null +++ b/cmd/worker/worker_config.go @@ -0,0 +1,173 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/google/uuid" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/config" + "gopkg.in/yaml.v3" +) + +const ( + defaultMetricsFlushInterval = 500 * time.Millisecond + datasetCacheDefaultTTL = 30 * time.Minute +) + +// Config holds worker configuration +type Config struct { + Host string `yaml:"host"` + User string `yaml:"user"` + SSHKey string `yaml:"ssh_key"` + Port int `yaml:"port"` + BasePath string `yaml:"base_path"` + TrainScript string `yaml:"train_script"` + RedisAddr string `yaml:"redis_addr"` + RedisPassword string `yaml:"redis_password"` + RedisDB int `yaml:"redis_db"` + KnownHosts string `yaml:"known_hosts"` + WorkerID string `yaml:"worker_id"` + MaxWorkers int `yaml:"max_workers"` + PollInterval int `yaml:"poll_interval_seconds"` + + // Authentication + Auth auth.AuthConfig `yaml:"auth"` + + // Metrics exporter + Metrics MetricsConfig `yaml:"metrics"` + // Metrics buffering + MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"` + + // Data management + DataManagerPath string `yaml:"data_manager_path"` + AutoFetchData bool 
`yaml:"auto_fetch_data"` + DataDir string `yaml:"data_dir"` + DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"` + + // Podman execution + PodmanImage string `yaml:"podman_image"` + ContainerWorkspace string `yaml:"container_workspace"` + ContainerResults string `yaml:"container_results"` + GPUAccess bool `yaml:"gpu_access"` + + // Task lease and retry settings + TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min) + HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // How often to renew lease (default: 1min) + MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3) + GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Graceful shutdown timeout (default: 5min) +} + +// MetricsConfig controls the Prometheus exporter. +type MetricsConfig struct { + Enabled bool `yaml:"enabled"` + ListenAddr string `yaml:"listen_addr"` +} + +func LoadConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + // Get smart defaults for current environment + smart := config.GetSmartDefaults() + + if cfg.Port == 0 { + cfg.Port = config.DefaultSSHPort + } + if cfg.Host == "" { + cfg.Host = smart.Host() + } + if cfg.BasePath == "" { + cfg.BasePath = smart.BasePath() + } + if cfg.RedisAddr == "" { + cfg.RedisAddr = smart.RedisAddr() + } + if cfg.KnownHosts == "" { + cfg.KnownHosts = smart.KnownHostsPath() + } + if cfg.WorkerID == "" { + cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) + } + if cfg.MaxWorkers == 0 { + cfg.MaxWorkers = smart.MaxWorkers() + } + if cfg.PollInterval == 0 { + cfg.PollInterval = smart.PollInterval() + } + if cfg.DataManagerPath == "" { + cfg.DataManagerPath = "./data_manager" + } + if cfg.DataDir == "" { + if cfg.Host == "" || !cfg.AutoFetchData { + cfg.DataDir = config.DefaultLocalDataDir + } else { + 
cfg.DataDir = smart.DataDir() + } + } + if cfg.Metrics.ListenAddr == "" { + cfg.Metrics.ListenAddr = ":9100" + } + if cfg.MetricsFlushInterval == 0 { + cfg.MetricsFlushInterval = defaultMetricsFlushInterval + } + if cfg.DatasetCacheTTL == 0 { + cfg.DatasetCacheTTL = datasetCacheDefaultTTL + } + + // Set lease and retry defaults + if cfg.TaskLeaseDuration == 0 { + cfg.TaskLeaseDuration = 30 * time.Minute + } + if cfg.HeartbeatInterval == 0 { + cfg.HeartbeatInterval = 1 * time.Minute + } + if cfg.MaxRetries == 0 { + cfg.MaxRetries = 3 + } + if cfg.GracefulTimeout == 0 { + cfg.GracefulTimeout = 5 * time.Minute + } + + return &cfg, nil +} + +// Validate implements config.Validator interface +func (c *Config) Validate() error { + if c.Port != 0 { + if err := config.ValidatePort(c.Port); err != nil { + return fmt.Errorf("invalid SSH port: %w", err) + } + } + + if c.BasePath != "" { + // Convert relative paths to absolute + c.BasePath = config.ExpandPath(c.BasePath) + if !filepath.IsAbs(c.BasePath) { + c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath) + } + } + + if c.RedisAddr != "" { + if err := config.ValidateRedisAddr(c.RedisAddr); err != nil { + return fmt.Errorf("invalid Redis configuration: %w", err) + } + } + + if c.MaxWorkers < 1 { + return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers) + } + + return nil +} + +// Task struct and Redis constants moved to internal/queue diff --git a/cmd/worker/worker_server.go b/cmd/worker/worker_server.go new file mode 100644 index 0000000..d15b4d5 --- /dev/null +++ b/cmd/worker/worker_server.go @@ -0,0 +1,883 @@ +// Package main implements the ML task worker +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "net/http" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/container" + 
// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
	*network.SSHClient
}

// NewMLServer opens an SSH connection to the configured ML host using the
// worker config's credentials and known-hosts file.
func NewMLServer(cfg *Config) (*MLServer, error) {
	client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
	if err != nil {
		return nil, err
	}
	return &MLServer{SSHClient: client}, nil
}

// Worker polls the task queue, leases tasks, and executes them (via podman
// on the target host). It tracks in-flight tasks for graceful shutdown.
type Worker struct {
	id         string
	config     *Config
	server     *MLServer
	queue      *queue.TaskQueue
	running    map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
	runningMu  sync.RWMutex                  // guards running
	ctx        context.Context
	cancel     context.CancelFunc
	logger     *logging.Logger
	metrics    *metrics.Metrics
	metricsSrv *http.Server

	// datasetCache maps dataset name -> expiry time; fetches are skipped
	// while the entry is fresh (TTL from config).
	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup
}

// setupMetricsExporter registers worker gauges on a private Prometheus
// registry and serves them on the configured address. No-op when metrics
// are disabled. The HTTP server is stopped by Stop() via metricsSrv.
//
// NOTE(review): the "*_total" metrics are monotonically increasing counters
// but are exported via GaugeFunc; Prometheus convention would use
// CounterFunc — confirm whether downstream dashboards rely on the gauge
// type before changing.
func (w *Worker) setupMetricsExporter() error {
	if !w.config.Metrics.Enabled {
		return nil
	}

	// Private registry: avoids clashing with any default-registry metrics.
	reg := prometheus.NewRegistry()
	reg.MustRegister(
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
		collectors.NewGoCollector(),
	)

	// Every series carries the worker ID so multiple workers can share a
	// scrape target namespace.
	labels := prometheus.Labels{"worker_id": w.id}
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_processed_total",
		Help:        "Total tasks processed successfully by this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.TasksProcessed.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_failed_total",
		Help:        "Total tasks failed by this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.TasksFailed.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_active",
		Help:        "Number of tasks currently running on this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.runningCount())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_queued",
		Help:        "Latest observed queue depth from Redis.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.QueuedTasks.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_transferred_bytes_total",
		Help:        "Total bytes transferred while fetching datasets.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.DataTransferred.Load())
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_fetch_time_seconds_total",
		Help:        "Total time spent fetching datasets (seconds).",
		ConstLabels: labels,
	}, func() float64 {
		// Stored as nanoseconds; convert to seconds for Prometheus.
		return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_execution_time_seconds_total",
		Help:        "Total execution time for completed tasks (seconds).",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
	}))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_worker_max_concurrency",
		Help:        "Configured maximum concurrent tasks for this worker.",
		ConstLabels: labels,
	}, func() float64 {
		return float64(w.config.MaxWorkers)
	}))

	mux := http.NewServeMux()
	mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
+ Addr: w.config.Metrics.ListenAddr, + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + + w.metricsSrv = srv + go func() { + w.logger.Info("metrics exporter listening", + "addr", w.config.Metrics.ListenAddr, + "enabled", w.config.Metrics.Enabled) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + w.logger.Warn("metrics exporter stopped", + "error", err) + } + }() + + return nil +} + +func NewWorker(cfg *Config, apiKey string) (*Worker, error) { + srv, err := NewMLServer(cfg) + if err != nil { + return nil, err + } + + queueCfg := queue.Config{ + RedisAddr: cfg.RedisAddr, + RedisPassword: cfg.RedisPassword, + RedisDB: cfg.RedisDB, + MetricsFlushInterval: cfg.MetricsFlushInterval, + } + queue, err := queue.NewTaskQueue(queueCfg) + if err != nil { + return nil, err + } + + // Create data_dir if it doesn't exist (for production without NAS) + if cfg.DataDir != "" { + if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil { + log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err) + } + } + + ctx, cancel := context.WithCancel(context.Background()) + ctx = logging.EnsureTrace(ctx) + ctx = logging.CtxWithWorker(ctx, cfg.WorkerID) + + baseLogger := logging.NewLogger(slog.LevelInfo, false) + logger := baseLogger.Component(ctx, "worker") + metrics := &metrics.Metrics{} + + worker := &Worker{ + id: cfg.WorkerID, + config: cfg, + server: srv, + queue: queue, + running: make(map[string]context.CancelFunc), + datasetCache: make(map[string]time.Time), + datasetCacheTTL: cfg.DatasetCacheTTL, + ctx: ctx, + cancel: cancel, + logger: logger, + metrics: metrics, + shutdownCh: make(chan struct{}), + } + + if err := worker.setupMetricsExporter(); err != nil { + return nil, err + } + + return worker, nil +} + +func (w *Worker) Start() { + w.logger.Info("worker started", + "worker_id", w.id, + "max_concurrent", w.config.MaxWorkers, + "poll_interval", w.config.PollInterval) + + go w.heartbeat() + + for { + select { 
+ case <-w.ctx.Done(): + w.logger.Info("shutdown signal received, waiting for tasks") + w.waitForTasks() + return + default: + } + + if w.runningCount() >= w.config.MaxWorkers { + time.Sleep(50 * time.Millisecond) + continue + } + + queueStart := time.Now() + task, err := w.queue.GetNextTaskWithLease(w.config.WorkerID, w.config.TaskLeaseDuration) + queueLatency := time.Since(queueStart) + if err != nil { + if err == context.DeadlineExceeded { + continue + } + w.logger.Error("error fetching task", + "worker_id", w.id, + "error", err) + continue + } + + if task == nil { + if queueLatency > 200*time.Millisecond { + w.logger.Debug("queue poll latency", + "latency_ms", queueLatency.Milliseconds()) + } + continue + } + + if depth, derr := w.queue.QueueDepth(); derr == nil { + if queueLatency > 100*time.Millisecond || depth > 0 { + w.logger.Debug("queue fetch metrics", + "latency_ms", queueLatency.Milliseconds(), + "remaining_depth", depth) + } + } else if queueLatency > 100*time.Millisecond { + w.logger.Debug("queue fetch metrics", + "latency_ms", queueLatency.Milliseconds(), + "depth_error", derr) + } + + go w.executeTaskWithLease(task) + } +} + +func (w *Worker) heartbeat() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-w.ctx.Done(): + return + case <-ticker.C: + if err := w.queue.Heartbeat(w.id); err != nil { + w.logger.Warn("heartbeat failed", + "worker_id", w.id, + "error", err) + } + } + } +} + +// NEW: Fetch datasets using data_manager +func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error { + logger := w.logger.Job(ctx, task.JobName, task.ID) + logger.Info("fetching datasets", + "worker_id", w.id, + "dataset_count", len(task.Datasets)) + + for _, dataset := range task.Datasets { + if w.datasetIsFresh(dataset) { + logger.Debug("skipping cached dataset", + "dataset", dataset) + continue + } + // Check for cancellation before each dataset fetch + select { + case <-w.ctx.Done(): + return 
fmt.Errorf("dataset fetch cancelled: %w", w.ctx.Err()) + default: + } + + logger.Info("fetching dataset", + "worker_id", w.id, + "dataset", dataset) + + // Create command with context for cancellation support + cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) + cmd := exec.CommandContext(cmdCtx, + w.config.DataManagerPath, + "fetch", + task.JobName, + dataset, + ) + + output, err := cmd.CombinedOutput() + cancel() // Clean up context + + if err != nil { + return &errors.DataFetchError{ + Dataset: dataset, + JobName: task.JobName, + Err: fmt.Errorf("command failed: %w, output: %s", err, output), + } + } + + logger.Info("dataset ready", + "worker_id", w.id, + "dataset", dataset) + w.markDatasetFetched(dataset) + } + + return nil +} + +func (w *Worker) runJob(task *queue.Task) error { + // Validate job name to prevent path traversal + if err := container.ValidateJobName(task.JobName); err != nil { + return &errors.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: err, + } + } + + jobPaths := config.NewJobPaths(w.config.BasePath) + jobDir := filepath.Join(jobPaths.PendingPath(), task.JobName) + outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName) + logFile := filepath.Join(outputDir, "output.log") + + // Sanitize paths + jobDir, err := container.SanitizePath(jobDir) + if err != nil { + return &errors.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: err, + } + } + outputDir, err = container.SanitizePath(outputDir) + if err != nil { + return &errors.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: err, + } + } + + // Create output directory + if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) { + if err := os.MkdirAll(outputDir, 0755); err != nil { + return "", fmt.Errorf("mkdir failed: %w", err) + } + return "", nil + }); err != nil { + return 
&errors.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "setup", + Err: fmt.Errorf("failed to create output dir: %w", err), + } + } + + // Move job from pending to running + stagingStart := time.Now() + if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) { + if err := os.Rename(jobDir, outputDir); err != nil { + return "", fmt.Errorf("rename failed: %w", err) + } + return "", nil + }); err != nil { + return &errors.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "setup", + Err: fmt.Errorf("failed to move job: %w", err), + } + } + stagingDuration := time.Since(stagingStart) + + if w.config.PodmanImage == "" { + return &errors.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: fmt.Errorf("podman_image must be configured"), + } + } + + containerWorkspace := w.config.ContainerWorkspace + if containerWorkspace == "" { + containerWorkspace = config.DefaultContainerWorkspace + } + containerResults := w.config.ContainerResults + if containerResults == "" { + containerResults = config.DefaultContainerResults + } + + podmanCfg := container.PodmanConfig{ + Image: w.config.PodmanImage, + Workspace: filepath.Join(outputDir, "code"), + Results: filepath.Join(outputDir, "results"), + ContainerWorkspace: containerWorkspace, + ContainerResults: containerResults, + GPUAccess: w.config.GPUAccess, + } + + scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript) + requirementsPath := filepath.Join(containerWorkspace, "requirements.txt") + + var extraArgs []string + if task.Args != "" { + extraArgs = strings.Fields(task.Args) + } + + ioBefore, ioErr := telemetry.ReadProcessIO() + podmanCmd := container.BuildPodmanCommand(podmanCfg, scriptPath, requirementsPath, extraArgs) + logFileHandle, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err == nil { + podmanCmd.Stdout = logFileHandle + 
podmanCmd.Stderr = logFileHandle + } else { + w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err) + } + + w.logger.Info("executing podman job", + "job", task.JobName, + "image", w.config.PodmanImage, + "workspace", podmanCfg.Workspace, + "results", podmanCfg.Results) + + containerStart := time.Now() + if err := podmanCmd.Run(); err != nil { + containerDuration := time.Since(containerStart) + // Move job to failed directory + failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) + if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) { + if err := os.Rename(outputDir, failedDir); err != nil { + return "", fmt.Errorf("rename to failed failed: %w", err) + } + return "", nil + }); moveErr != nil { + w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr) + } + + if ioErr == nil { + if after, err := telemetry.ReadProcessIO(); err == nil { + delta := telemetry.DiffIO(ioBefore, after) + w.logger.Debug("worker io stats", + "job", task.JobName, + "read_bytes", delta.ReadBytes, + "write_bytes", delta.WriteBytes) + } + } + w.logger.Info("job timing (failure)", + "job", task.JobName, + "staging_ms", stagingDuration.Milliseconds(), + "container_ms", containerDuration.Milliseconds(), + "finalize_ms", 0, + "total_ms", time.Since(stagingStart).Milliseconds(), + ) + return fmt.Errorf("execution failed: %w", err) + } + containerDuration := time.Since(containerStart) + + finalizeStart := time.Now() + // Move job to finished directory + finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) + if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) { + if err := os.Rename(outputDir, finishedDir); err != nil { + return "", fmt.Errorf("rename to finished failed: %w", err) + } + return "", nil + }); moveErr != nil { + w.logger.Warn("failed to move job to finished dir", "job", 
task.JobName, "error", moveErr) + } + finalizeDuration := time.Since(finalizeStart) + totalDuration := time.Since(stagingStart) + var ioDelta telemetry.IOStats + if ioErr == nil { + if after, err := telemetry.ReadProcessIO(); err == nil { + ioDelta = telemetry.DiffIO(ioBefore, after) + } + } + + w.logger.Info("job timing", + "job", task.JobName, + "staging_ms", stagingDuration.Milliseconds(), + "container_ms", containerDuration.Milliseconds(), + "finalize_ms", finalizeDuration.Milliseconds(), + "total_ms", totalDuration.Milliseconds(), + "io_read_bytes", ioDelta.ReadBytes, + "io_write_bytes", ioDelta.WriteBytes, + ) + + return nil +} + +func parseDatasets(args string) []string { + if !strings.Contains(args, "--datasets") { + return nil + } + + parts := strings.Fields(args) + for i, part := range parts { + if part == "--datasets" && i+1 < len(parts) { + return strings.Split(parts[i+1], ",") + } + } + + return nil +} + +func (w *Worker) waitForTasks() { + timeout := time.After(5 * time.Minute) + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-timeout: + w.logger.Warn("shutdown timeout, force stopping", + "running_tasks", len(w.running)) + return + case <-ticker.C: + count := w.runningCount() + if count == 0 { + w.logger.Info("all tasks completed, shutting down") + return + } + w.logger.Debug("waiting for tasks to complete", + "remaining", count) + } + } +} + +func (w *Worker) runningCount() int { + w.runningMu.RLock() + defer w.runningMu.RUnlock() + return len(w.running) +} + +func (w *Worker) datasetIsFresh(dataset string) bool { + w.datasetCacheMu.RLock() + defer w.datasetCacheMu.RUnlock() + expires, ok := w.datasetCache[dataset] + return ok && time.Now().Before(expires) +} + +func (w *Worker) markDatasetFetched(dataset string) { + expires := time.Now().Add(w.datasetCacheTTL) + w.datasetCacheMu.Lock() + w.datasetCache[dataset] = expires + w.datasetCacheMu.Unlock() +} + +func (w *Worker) GetMetrics() map[string]any { + 
stats := w.metrics.GetStats() + stats["worker_id"] = w.id + stats["max_workers"] = w.config.MaxWorkers + return stats +} + +func (w *Worker) Stop() { + w.cancel() + w.waitForTasks() + + // FIXED: Check error return values + if err := w.server.Close(); err != nil { + w.logger.Warn("error closing server connection", "error", err) + } + if err := w.queue.Close(); err != nil { + w.logger.Warn("error closing queue connection", "error", err) + } + if w.metricsSrv != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := w.metricsSrv.Shutdown(ctx); err != nil { + w.logger.Warn("metrics exporter shutdown error", "error", err) + } + } + w.logger.Info("worker stopped", "worker_id", w.id) +} + +// Execute task with lease management and retry: +func (w *Worker) executeTaskWithLease(task *queue.Task) { + // Track task for graceful shutdown + w.gracefulWait.Add(1) + w.activeTasks.Store(task.ID, task) + defer w.gracefulWait.Done() + defer w.activeTasks.Delete(task.ID) + + // Create task-specific context with timeout + taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing + taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata + taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata + + taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour) + defer taskCancel() + + logger := w.logger.Job(taskCtx, task.JobName, task.ID) + logger.Info("starting task", + "worker_id", w.id, + "datasets", task.Datasets, + "priority", task.Priority) + + // Record task start + w.metrics.RecordTaskStart() + defer w.metrics.RecordTaskCompletion() + + // Check for context cancellation + select { + case <-taskCtx.Done(): + logger.Info("task cancelled before execution") + return + default: + } + + // Parse datasets from task arguments + if task.Datasets == nil { + task.Datasets = parseDatasets(task.Args) + } + + // Start heartbeat goroutine + heartbeatCtx, cancelHeartbeat := 
context.WithCancel(context.Background()) + defer cancelHeartbeat() + + go w.heartbeatLoop(heartbeatCtx, task.ID) + + // Update task status + task.Status = "running" + now := time.Now() + task.StartedAt = &now + task.WorkerID = w.id + + if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil { + logger.Error("failed to update task status", "error", err) + w.metrics.RecordTaskFailure() + return + } + + if w.config.AutoFetchData && len(task.Datasets) > 0 { + if err := w.fetchDatasets(taskCtx, task); err != nil { + logger.Error("data fetch failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Data fetch failed: %v", err) + endTime := time.Now() + task.EndedAt = &endTime + err := w.queue.UpdateTask(task) + if err != nil { + logger.Error("failed to update task status after data fetch failure", "error", err) + } + w.metrics.RecordTaskFailure() + return + } + } + + // Execute job with panic recovery + var execErr error + func() { + defer func() { + if r := recover(); r != nil { + execErr = fmt.Errorf("panic during execution: %v", r) + } + }() + execErr = w.runJob(task) + }() + + // Finalize task + endTime := time.Now() + task.EndedAt = &endTime + + if execErr != nil { + task.Error = execErr.Error() + + // Check if transient error (network, timeout, etc) + if isTransientError(execErr) && task.RetryCount < task.MaxRetries { + w.logger.Warn("task failed with transient error, will retry", + "task_id", task.ID, + "error", execErr, + "retry_count", task.RetryCount) + w.queue.RetryTask(task) + } else { + task.Status = "failed" + w.queue.UpdateTaskWithMetrics(task, "final") + } + } else { + task.Status = "completed" + w.queue.UpdateTaskWithMetrics(task, "final") + } + + // Release lease + w.queue.ReleaseLease(task.ID, w.config.WorkerID) +} + +// Heartbeat loop to renew lease: +func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) { + ticker := time.NewTicker(w.config.HeartbeatInterval) + defer ticker.Stop() + + for { + select { + case 
<-ctx.Done(): + return + case <-ticker.C: + if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil { + w.logger.Error("failed to renew lease", "task_id", taskID, "error", err) + return + } + // Also update worker heartbeat + w.queue.Heartbeat(w.config.WorkerID) + } + } +} + +// Graceful shutdown: +func (w *Worker) Shutdown() error { + w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks()) + + // Wait for active tasks with timeout + done := make(chan struct{}) + go func() { + w.gracefulWait.Wait() + close(done) + }() + + timeout := time.After(w.config.GracefulTimeout) + select { + case <-done: + w.logger.Info("all tasks completed, shutdown successful") + case <-timeout: + w.logger.Warn("graceful shutdown timeout, releasing active leases") + w.releaseAllLeases() + } + + return w.queue.Close() +} + +// Release all active leases: +func (w *Worker) releaseAllLeases() { + w.activeTasks.Range(func(key, value interface{}) bool { + taskID := key.(string) + if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil { + w.logger.Error("failed to release lease", "task_id", taskID, "error", err) + } + return true + }) +} + +// Helper functions: +func (w *Worker) countActiveTasks() int { + count := 0 + w.activeTasks.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count +} + +func isTransientError(err error) bool { + if err == nil { + return false + } + // Check if error is transient (network, timeout, resource unavailable, etc) + errStr := err.Error() + transientIndicators := []string{ + "connection refused", + "timeout", + "temporary failure", + "resource temporarily unavailable", + "no such host", + "network unreachable", + } + for _, indicator := range transientIndicators { + if strings.Contains(strings.ToLower(errStr), indicator) { + return true + } + } + return false +} + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + // Parse authentication flags + 
authFlags := auth.ParseAuthFlags() + if err := auth.ValidateAuthFlags(authFlags); err != nil { + log.Fatalf("Authentication flag error: %v", err) + } + + // Get API key from various sources + apiKey := auth.GetAPIKeyFromSources(authFlags) + + // Load configuration + configPath := "config-local.yaml" + if authFlags.ConfigFile != "" { + configPath = authFlags.ConfigFile + } + + resolvedConfig, err := config.ResolveConfigPath(configPath) + if err != nil { + log.Fatalf("%v", err) + } + + cfg, err := LoadConfig(resolvedConfig) + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // Validate authentication configuration + if err := cfg.Auth.ValidateAuthConfig(); err != nil { + log.Fatalf("Invalid authentication configuration: %v", err) + } + + // Validate configuration + if err := cfg.Validate(); err != nil { + log.Fatalf("Invalid configuration: %v", err) + } + + // Test authentication if enabled + if cfg.Auth.Enabled && apiKey != "" { + user, err := cfg.Auth.ValidateAPIKey(apiKey) + if err != nil { + log.Fatalf("Authentication failed: %v", err) + } + log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin) + } else if cfg.Auth.Enabled { + log.Fatal("Authentication required but no API key provided") + } + + worker, err := NewWorker(cfg, apiKey) + if err != nil { + log.Fatalf("Failed to create worker: %v", err) + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go worker.Start() + + sig := <-sigChan + log.Printf("Received signal: %v", sig) + + // Use graceful shutdown + if err := worker.Shutdown(); err != nil { + log.Printf("Graceful shutdown error: %v", err) + worker.Stop() // Fallback to force stop + } else { + log.Println("Worker shut down gracefully") + } +} diff --git a/examples/auth_integration_example.go b/examples/auth_integration_example.go new file mode 100644 index 0000000..6a1fe1f --- /dev/null +++ b/examples/auth_integration_example.go @@ -0,0 +1,78 @@ +package 
main + +import ( + "fmt" + "log" + "os" + + "github.com/jfraeys/fetch_ml/internal/auth" + "gopkg.in/yaml.v3" +) + +// Example: How to integrate auth into TUI startup +func checkAuth(configFile string) error { + // Load config + data, err := os.ReadFile(configFile) + if err != nil { + return fmt.Errorf("failed to read config: %w", err) + } + + var cfg struct { + Auth auth.AuthConfig `yaml:"auth"` + } + + if err := yaml.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("failed to parse config: %w", err) + } + + // If auth disabled, proceed normally + if !cfg.Auth.Enabled { + fmt.Println("🔓 Authentication disabled - proceeding normally") + return nil + } + + // Check for API key + apiKey := os.Getenv("FETCH_ML_API_KEY") + if apiKey == "" { + apiKey = getAPIKeyFromUser() + } + + // Validate API key + user, err := cfg.Auth.ValidateAPIKey(apiKey) + if err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + + fmt.Printf("🔐 Authenticated as: %s", user.Name) + if user.Admin { + fmt.Println(" (admin)") + } else { + fmt.Println() + } + + return nil +} + +func getAPIKeyFromUser() string { + fmt.Print("🔑 Enter API key: ") + var key string + fmt.Scanln(&key) + return key +} + +// Example usage in main() +func exampleMain() { + configFile := "config_dev.yaml" + + // Check authentication first + if err := checkAuth(configFile); err != nil { + log.Fatalf("Authentication failed: %v", err) + } + + // Proceed with normal TUI initialization + fmt.Println("Starting TUI...") +} + +func main() { + exampleMain() +} diff --git a/internal/api/permissions_test.go b/internal/api/permissions_test.go new file mode 100644 index 0000000..5cfd2b9 --- /dev/null +++ b/internal/api/permissions_test.go @@ -0,0 +1,117 @@ +package api + +import ( + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +func TestUserPermissions(t *testing.T) { + authConfig := &auth.AuthConfig{ + Enabled: true, + APIKeys: 
map[auth.Username]auth.APIKeyEntry{ + "admin": { + Hash: auth.APIKeyHash(auth.HashAPIKey("admin_key")), + Admin: true, + }, + "scientist": { + Hash: auth.APIKeyHash(auth.HashAPIKey("ds_key")), + Admin: false, + Permissions: map[string]bool{ + "jobs:create": true, + "jobs:read": true, + "jobs:update": true, + }, + }, + }, + } + + tests := []struct { + name string + apiKey string + permission string + want bool + }{ + {"Admin can create", "admin_key", "jobs:create", true}, + {"Scientist can create", "ds_key", "jobs:create", true}, + {"Invalid key fails", "invalid_key", "jobs:create", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, err := authConfig.ValidateAPIKey(tt.apiKey) + + if tt.apiKey == "invalid_key" { + if err == nil { + t.Error("Expected error for invalid API key") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + got := user.HasPermission(tt.permission) + if got != tt.want { + t.Errorf("HasPermission() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTaskOwnership(t *testing.T) { + tasks := []*queue.Task{ + { + ID: "task1", + JobName: "user1_job", + UserID: "user1", + CreatedBy: "user1", + CreatedAt: time.Now(), + }, + { + ID: "task2", + JobName: "user2_job", + UserID: "user2", + CreatedBy: "user2", + CreatedAt: time.Now(), + }, + } + + users := map[string]*auth.User{ + "user1": {Name: "user1", Admin: false}, + "user2": {Name: "user2", Admin: false}, + "admin": {Name: "admin", Admin: true}, + } + + tests := []struct { + name string + userName string + task *queue.Task + want bool + }{ + {"User can view own task", "user1", tasks[0], true}, + {"User cannot view other task", "user1", tasks[1], false}, + {"Admin can view any task", "admin", tasks[1], true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user := users[tt.userName] + + canAccess := false + if user.Admin { + canAccess = true + } else if tt.task.UserID == user.Name || 
tt.task.CreatedBy == user.Name { + canAccess = true + } + + if canAccess != tt.want { + t.Errorf("Access = %v, want %v", canAccess, tt.want) + } + }) + } +} diff --git a/internal/api/protocol.go b/internal/api/protocol.go new file mode 100644 index 0000000..378808f --- /dev/null +++ b/internal/api/protocol.go @@ -0,0 +1,305 @@ +package api + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "time" +) + +// Response packet types +const ( + PacketTypeSuccess = 0x00 + PacketTypeError = 0x01 + PacketTypeProgress = 0x02 + PacketTypeStatus = 0x03 + PacketTypeData = 0x04 + PacketTypeLog = 0x05 +) + +// Error codes +const ( + ErrorCodeUnknownError = 0x00 + ErrorCodeInvalidRequest = 0x01 + ErrorCodeAuthenticationFailed = 0x02 + ErrorCodePermissionDenied = 0x03 + ErrorCodeResourceNotFound = 0x04 + ErrorCodeResourceAlreadyExists = 0x05 + + ErrorCodeServerOverloaded = 0x10 + ErrorCodeDatabaseError = 0x11 + ErrorCodeNetworkError = 0x12 + ErrorCodeStorageError = 0x13 + ErrorCodeTimeout = 0x14 + + ErrorCodeJobNotFound = 0x20 + ErrorCodeJobAlreadyRunning = 0x21 + ErrorCodeJobFailedToStart = 0x22 + ErrorCodeJobExecutionFailed = 0x23 + ErrorCodeJobCancelled = 0x24 + + ErrorCodeOutOfMemory = 0x30 + ErrorCodeDiskFull = 0x31 + ErrorCodeInvalidConfiguration = 0x32 + ErrorCodeServiceUnavailable = 0x33 +) + +// Progress types +const ( + ProgressTypePercentage = 0x00 + ProgressTypeStage = 0x01 + ProgressTypeMessage = 0x02 + ProgressTypeBytesTransferred = 0x03 +) + +// Log levels +const ( + LogLevelDebug = 0x00 + LogLevelInfo = 0x01 + LogLevelWarn = 0x02 + LogLevelError = 0x03 +) + +// ResponsePacket represents a structured response packet +type ResponsePacket struct { + PacketType byte + Timestamp uint64 + + // Success fields + SuccessMessage string + + // Error fields + ErrorCode byte + ErrorMessage string + ErrorDetails string + + // Progress fields + ProgressType byte + ProgressValue uint32 + ProgressTotal uint32 + ProgressMessage string + + // Status fields + StatusData string 
+ + // Data fields + DataType string + DataPayload []byte + + // Log fields + LogLevel byte + LogMessage string +} + +// NewSuccessPacket creates a success response packet +func NewSuccessPacket(message string) *ResponsePacket { + return &ResponsePacket{ + PacketType: PacketTypeSuccess, + Timestamp: uint64(time.Now().Unix()), + SuccessMessage: message, + } +} + +// NewSuccessPacketWithPayload creates a success response packet with JSON payload +func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponsePacket { + // Convert payload to JSON for the DataPayload field + payloadBytes, _ := json.Marshal(payload) + + return &ResponsePacket{ + PacketType: PacketTypeData, + Timestamp: uint64(time.Now().Unix()), + SuccessMessage: message, + DataType: "status", + DataPayload: payloadBytes, + } +} + +// NewErrorPacket creates an error response packet +func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket { + return &ResponsePacket{ + PacketType: PacketTypeError, + Timestamp: uint64(time.Now().Unix()), + ErrorCode: errorCode, + ErrorMessage: message, + ErrorDetails: details, + } +} + +// NewProgressPacket creates a progress response packet +func NewProgressPacket(progressType byte, value uint32, total uint32, message string) *ResponsePacket { + return &ResponsePacket{ + PacketType: PacketTypeProgress, + Timestamp: uint64(time.Now().Unix()), + ProgressType: progressType, + ProgressValue: value, + ProgressTotal: total, + ProgressMessage: message, + } +} + +// NewStatusPacket creates a status response packet +func NewStatusPacket(data string) *ResponsePacket { + return &ResponsePacket{ + PacketType: PacketTypeStatus, + Timestamp: uint64(time.Now().Unix()), + StatusData: data, + } +} + +// NewDataPacket creates a data response packet +func NewDataPacket(dataType string, payload []byte) *ResponsePacket { + return &ResponsePacket{ + PacketType: PacketTypeData, + Timestamp: uint64(time.Now().Unix()), + DataType: dataType, + 
DataPayload: payload, + } +} + +// NewLogPacket creates a log response packet +func NewLogPacket(level byte, message string) *ResponsePacket { + return &ResponsePacket{ + PacketType: PacketTypeLog, + Timestamp: uint64(time.Now().Unix()), + LogLevel: level, + LogMessage: message, + } +} + +// Serialize converts the packet to binary format +func (p *ResponsePacket) Serialize() ([]byte, error) { + var buf []byte + + // Packet type + buf = append(buf, p.PacketType) + + // Timestamp (8 bytes, big-endian) + timestampBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timestampBytes, p.Timestamp) + buf = append(buf, timestampBytes...) + + // Packet-specific data + switch p.PacketType { + case PacketTypeSuccess: + buf = append(buf, serializeString(p.SuccessMessage)...) + + case PacketTypeError: + buf = append(buf, p.ErrorCode) + buf = append(buf, serializeString(p.ErrorMessage)...) + buf = append(buf, serializeString(p.ErrorDetails)...) + + case PacketTypeProgress: + buf = append(buf, p.ProgressType) + valueBytes := make([]byte, 4) + binary.BigEndian.PutUint32(valueBytes, p.ProgressValue) + buf = append(buf, valueBytes...) + + totalBytes := make([]byte, 4) + binary.BigEndian.PutUint32(totalBytes, p.ProgressTotal) + buf = append(buf, totalBytes...) + + buf = append(buf, serializeString(p.ProgressMessage)...) + + case PacketTypeStatus: + buf = append(buf, serializeString(p.StatusData)...) + + case PacketTypeData: + buf = append(buf, serializeString(p.DataType)...) + buf = append(buf, serializeBytes(p.DataPayload)...) + + case PacketTypeLog: + buf = append(buf, p.LogLevel) + buf = append(buf, serializeString(p.LogMessage)...) 
+ + default: + return nil, fmt.Errorf("unknown packet type: %d", p.PacketType) + } + + return buf, nil +} + +// serializeString writes a string with 2-byte length prefix +func serializeString(s string) []byte { + length := uint16(len(s)) + buf := make([]byte, 2+len(s)) + binary.BigEndian.PutUint16(buf[:2], length) + copy(buf[2:], s) + return buf +} + +// serializeBytes writes bytes with 4-byte length prefix +func serializeBytes(b []byte) []byte { + length := uint32(len(b)) + buf := make([]byte, 4+len(b)) + binary.BigEndian.PutUint32(buf[:4], length) + copy(buf[4:], b) + return buf +} + +// GetErrorMessage returns a human-readable error message for an error code +func GetErrorMessage(code byte) string { + switch code { + case ErrorCodeUnknownError: + return "Unknown error occurred" + case ErrorCodeInvalidRequest: + return "Invalid request format" + case ErrorCodeAuthenticationFailed: + return "Authentication failed" + case ErrorCodePermissionDenied: + return "Permission denied" + case ErrorCodeResourceNotFound: + return "Resource not found" + case ErrorCodeResourceAlreadyExists: + return "Resource already exists" + + case ErrorCodeServerOverloaded: + return "Server is overloaded" + case ErrorCodeDatabaseError: + return "Database error occurred" + case ErrorCodeNetworkError: + return "Network error occurred" + case ErrorCodeStorageError: + return "Storage error occurred" + case ErrorCodeTimeout: + return "Operation timed out" + + case ErrorCodeJobNotFound: + return "Job not found" + case ErrorCodeJobAlreadyRunning: + return "Job is already running" + case ErrorCodeJobFailedToStart: + return "Job failed to start" + case ErrorCodeJobExecutionFailed: + return "Job execution failed" + case ErrorCodeJobCancelled: + return "Job was cancelled" + + case ErrorCodeOutOfMemory: + return "Server out of memory" + case ErrorCodeDiskFull: + return "Server disk full" + case ErrorCodeInvalidConfiguration: + return "Invalid server configuration" + case ErrorCodeServiceUnavailable: + 
return "Service temporarily unavailable" + + default: + return "Unknown error code" + } +} + +// GetLogLevelName returns the name for a log level +func GetLogLevelName(level byte) string { + switch level { + case LogLevelDebug: + return "DEBUG" + case LogLevelInfo: + return "INFO" + case LogLevelWarn: + return "WARN" + case LogLevelError: + return "ERROR" + default: + return "UNKNOWN" + } +} diff --git a/internal/api/ws.go b/internal/api/ws.go new file mode 100644 index 0000000..df1d3f3 --- /dev/null +++ b/internal/api/ws.go @@ -0,0 +1,606 @@ +package api + +import ( + "crypto/sha256" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "net/http" + "net/url" + "strings" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/queue" + "golang.org/x/crypto/acme/autocert" +) + +// Opcodes for binary WebSocket protocol +const ( + OpcodeQueueJob = 0x01 + OpcodeStatusRequest = 0x02 + OpcodeCancelJob = 0x03 + OpcodePrune = 0x04 + OpcodeLogMetric = 0x0A + OpcodeGetExperiment = 0x0B +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + // Allow localhost and homelab origins for development + origin := r.Header.Get("Origin") + if origin == "" { + return true // Allow same-origin requests + } + + // Parse origin URL + parsedOrigin, err := url.Parse(origin) + if err != nil { + return false + } + + // Allow localhost and local network origins + host := parsedOrigin.Host + return strings.HasSuffix(host, ":8080") || + strings.HasSuffix(host, ":8081") || + strings.HasPrefix(host, "localhost") || + strings.HasPrefix(host, "127.0.0.1") || + strings.HasPrefix(host, "192.168.") || + strings.HasPrefix(host, "10.") || + strings.HasPrefix(host, "172.") + }, +} + +type WSHandler struct { + authConfig *auth.AuthConfig + logger 
*logging.Logger + expManager *experiment.Manager + queue *queue.TaskQueue +} + +func NewWSHandler(authConfig *auth.AuthConfig, logger *logging.Logger, expManager *experiment.Manager, taskQueue *queue.TaskQueue) *WSHandler { + return &WSHandler{ + authConfig: authConfig, + logger: logger, + expManager: expManager, + queue: taskQueue, + } +} + +func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Check API key before upgrading WebSocket + apiKey := r.Header.Get("X-API-Key") + if apiKey == "" { + // Also check Authorization header + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + apiKey = strings.TrimPrefix(authHeader, "Bearer ") + } + } + + // Validate API key if authentication is enabled + if h.authConfig != nil && h.authConfig.Enabled { + if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil { + h.logger.Warn("websocket authentication failed", "error", err) + http.Error(w, "Invalid API key", http.StatusUnauthorized) + return + } + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Error("websocket upgrade failed", "error", err) + return + } + defer conn.Close() + + h.logger.Info("websocket connection established", "remote", r.RemoteAddr) + + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + h.logger.Error("websocket read error", "error", err) + } + break + } + + if messageType != websocket.BinaryMessage { + h.logger.Warn("received non-binary message") + continue + } + + if err := h.handleMessage(conn, message); err != nil { + h.logger.Error("message handling error", "error", err) + // Send error response + _ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode + } + } +} + +func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { + if len(message) < 1 { + return fmt.Errorf("message too 
short")
	}

	opcode := message[0]
	payload := message[1:]

	switch opcode {
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, payload)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, payload)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, payload)
	case OpcodePrune:
		return h.handlePrune(conn, payload)
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, payload)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, payload)
	default:
		return fmt.Errorf("unknown opcode: 0x%02x", opcode)
	}
}

// handleQueueJob creates an experiment for commitID and (when a queue is
// configured) enqueues a task for jobName on behalf of the authenticated user.
// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 130 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	priority := int64(payload[128])
	jobNameLen := int(payload[129])

	if len(payload) < 130+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}

	jobName := string(payload[130 : 130+jobNameLen])

	h.logger.Info("queue job request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitID,
	)

	// Validate API key and get user information
	user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}

	// Permission guard. (The original logged "job queued" here, before any
	// work had been done and again after metadata was written; the premature
	// duplicate log is removed.)
	if h.authConfig.Enabled && !user.HasPermission("jobs:create") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
	}

	// Create experiment directory and metadata
	if err := h.expManager.CreateExperiment(commitID); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}

	// Add user info to experiment metadata
	meta := &experiment.Metadata{
		CommitID:  commitID,
		JobName:   jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}

	if err := h.expManager.WriteMetadata(meta); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
	}

	h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)

	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))

	// Enqueue task if queue is available
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      "", // TODO: Add args support
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id": commitID,
				"user_id":   user.Name,
				"username":  user.Name,
			},
		}

		if err := h.queue.AddTask(task); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
	} else {
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}

	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}

// handleStatusRequest returns the queue/task summary visible to the caller.
// Protocol: [api_key_hash:64]
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 64 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
	}

	apiKeyHash := string(payload[0:64])
	h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")

	// Validate API key and get user information
	user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}

	// Permission guard (inverted from the original's empty-branch if/else).
	if h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
	}

	// Get tasks with user filtering
	var tasks []*queue.Task
	if h.queue != nil {
		allTasks, err := h.queue.GetAllTasks()
		if err != nil {
			h.logger.Error("failed to get tasks", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
		}

		// Filter tasks based on user permissions
		for _, task := range allTasks {
			// If auth is disabled or admin can see all tasks
			if !h.authConfig.Enabled || user.Admin {
				tasks = append(tasks, task)
				continue
			}

			// Users can only see their own tasks
			if task.UserID == user.Name || task.CreatedBy == user.Name {
				tasks = append(tasks, task)
			}
		}
	}

	// Build status response with user-specific data
	status := map[string]interface{}{
		"user": map[string]interface{}{
			"name":  user.Name,
			"admin": user.Admin,
			"roles": user.Roles,
		},
		"tasks": map[string]interface{}{
			"total": len(tasks),
			"queued":
countTasksByStatus(tasks, "queued"),
			"running":   countTasksByStatus(tasks, "running"),
			"failed":    countTasksByStatus(tasks, "failed"),
			"completed": countTasksByStatus(tasks, "completed"),
		},
		"queue": tasks, // Include filtered tasks
	}

	packet := NewSuccessPacketWithPayload("Status retrieved", status)
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}

// countTasksByStatus counts tasks by their status.
func countTasksByStatus(tasks []*queue.Task, status string) int {
	count := 0
	for _, task := range tasks {
		if task.Status == status {
			count++
		}
	}
	return count
}

// handleCancelJob cancels the named task if the caller owns it or is admin.
// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 65 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
	}

	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
	jobNameLen := int(payload[64])

	if len(payload) < 65+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}

	jobName := string(payload[65 : 65+jobNameLen])

	h.logger.Info("cancel job request", "job", jobName)

	// Validate API key and get user information
	user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}

	// Permission guard (inverted from the original's empty-branch if/else;
	// behavior is unchanged).
	if h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
	}

	// Find the task and verify ownership
	if h.queue != nil {
		task, err := h.queue.GetTaskByName(jobName)
		if err != nil {
			h.logger.Error("task not found", "job", jobName, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
		}

		// Only the owner or an admin may cancel (no-op check when auth is off).
		if h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name {
			h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
			return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
		}

		// Cancel the task
		if err := h.queue.CancelTask(task.ID); err != nil {
			h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
		}

		h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	} else {
		h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
	}

	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}

// handlePrune deletes old experiments, either keeping the newest N or dropping
// those older than N days.
// Protocol: [api_key_hash:64][prune_type:1][value:4]
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 69 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}

	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
pruneType := payload[64] + value := binary.BigEndian.Uint32(payload[65:69]) + + h.logger.Info("prune request", "type", pruneType, "value", value) + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + h.logger.Error("api key verification failed", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error()) + } + } + + // Convert prune parameters + var keepCount int + var olderThanDays int + + switch pruneType { + case 0: + // keep N + keepCount = int(value) + olderThanDays = 0 + case 1: + // older than days + keepCount = 0 + olderThanDays = int(value) + default: + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "") + } + + // Perform pruning + pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays) + if err != nil { + h.logger.Error("prune failed", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error()) + } + + h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned) + + // Send structured success response + packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))) + return h.sendResponsePacket(conn, packet) +} + +func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var] + if len(payload) < 141 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "") + } + + apiKeyHash := string(payload[:64]) + commitID := string(payload[64:128]) + step := int(binary.BigEndian.Uint32(payload[128:132])) + valueBits := binary.BigEndian.Uint64(payload[132:140]) + value := math.Float64frombits(valueBits) + nameLen := int(payload[140]) + + if len(payload) < 141+nameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric 
name length", "") + } + + name := string(payload[141 : 141+nameLen]) + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error()) + } + } + + if err := h.expManager.LogMetric(commitID, name, value, step); err != nil { + h.logger.Error("failed to log metric", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) + } + + return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged")) +} + +func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:64][commit_id:64] + if len(payload) < 128 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "") + } + + apiKeyHash := string(payload[:64]) + commitID := string(payload[64:128]) + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error()) + } + } + + meta, err := h.expManager.ReadMetadata(commitID) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error()) + } + + metrics, err := h.expManager.GetMetrics(commitID) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error()) + } + + response := map[string]interface{}{ + "metadata": meta, + "metrics": metrics, + } + + return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response)) +} + +// Helper to hash API key for comparison +func HashAPIKey(apiKey string) string { + hash := sha256.Sum256([]byte(apiKey)) + return hex.EncodeToString(hash[:]) +} + +// SetupTLSConfig creates TLS configuration for WebSocket server +func SetupTLSConfig(certFile, 
keyFile string, host string) (*http.Server, error) { + var server *http.Server + + if certFile != "" && keyFile != "" { + // Use provided certificates + server = &http.Server{ + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + }, + } + } else if host != "" { + // Use Let's Encrypt with autocert + certManager := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(host), + Cache: autocert.DirCache("/var/www/.cache"), + } + + server = &http.Server{ + TLSConfig: certManager.TLSConfig(), + } + } + + return server, nil +} + +// verifyAPIKeyHash verifies the provided hex hash against stored API keys +func (h *WSHandler) verifyAPIKeyHash(hexHash string) error { + if h.authConfig == nil || !h.authConfig.Enabled { + return nil // No auth required + } + + // For now, just check if it's a valid 64-char hex string + if len(hexHash) != 64 { + return fmt.Errorf("invalid api key hash length") + } + + // Check against stored API keys + for username, entry := range h.authConfig.APIKeys { + if string(entry.Hash) == hexHash { + _ = username // Username found but not needed for verification + return nil // Valid API key found + } + } + + return fmt.Errorf("invalid api key") +} + +// sendErrorPacket sends an error response packet +func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error { + packet := NewErrorPacket(errorCode, message, details) + return h.sendResponsePacket(conn, packet) +} + +// sendResponsePacket sends a structured response packet +func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error { + data, err := packet.Serialize() + if err != nil { + h.logger.Error("failed to serialize response packet", "error", err) + // Fallback to simple error response + return 
conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
	}

	return conn.WriteMessage(websocket.BinaryMessage, data)
}

// sendErrorResponse removed (unused)
diff --git a/internal/api/ws_test.go b/internal/api/ws_test.go
new file mode 100644
index 0000000..252e1e4
--- /dev/null
+++ b/internal/api/ws_test.go
@@ -0,0 +1,335 @@
package api

import (
	"encoding/binary"
	"math"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// setupTestServer builds a WSHandler backed by an in-process miniredis queue,
// a temp-dir experiment manager and disabled auth, and serves it via httptest.
// Callers must Close the returned server, queue and miniredis instance.
func setupTestServer(t *testing.T) (*httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis) {
	// Setup miniredis
	s, err := miniredis.Run()
	require.NoError(t, err)

	// Setup TaskQueue
	queueCfg := queue.Config{
		RedisAddr:            s.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	}
	tq, err := queue.NewTaskQueue(queueCfg)
	require.NoError(t, err)

	// Setup dependencies
	logger := logging.NewLogger(0, false)
	expManager := experiment.NewManager(t.TempDir())
	authCfg := &auth.AuthConfig{Enabled: false} // auth off: handlers accept any key hash

	// Create handler
	handler := NewWSHandler(authCfg, logger, expManager, tq)

	// Setup test server
	server := httptest.NewServer(handler)

	return server, tq, expManager, s
}

// connectWS dials the test server's WebSocket endpoint (http:// -> ws://).
func connectWS(t *testing.T, serverURL string) *websocket.Conn {
	wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
	ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	require.NoError(t, err)
	return ws
}

// TestWSHandler_QueueJob sends an OpcodeQueueJob frame and verifies a success
// packet is returned and the task lands in the Redis-backed queue.
func TestWSHandler_QueueJob(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare queue_job message
	// Protocol: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
	opcode := byte(OpcodeQueueJob)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	commitID := make([]byte, 64)
	copy(commitID, []byte(strings.Repeat("a", 64)))
	priority := byte(5)
	jobName := "test-job"
	jobNameLen := byte(len(jobName))

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, priority)
	msg = append(msg, jobNameLen)
	msg = append(msg, []byte(jobName)...)

	// Send message
	err := ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response (PacketTypeSuccess = 0x00)
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])

	// Verify task in Redis
	time.Sleep(100 * time.Millisecond)
	task, err := tq.GetNextTask()
	require.NoError(t, err)
	require.NotNil(t, task)
	assert.Equal(t, jobName, task.JobName)
}

// TestWSHandler_StatusRequest verifies a status request returns a data packet.
func TestWSHandler_StatusRequest(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Add a task to queue
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-1",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
	}
	err := tq.AddTask(task)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare status_request message
	// Protocol: [opcode:1][api_key_hash:64]
	opcode := byte(OpcodeStatusRequest)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response (PacketTypeData = 0x04 for status with payload)
	assert.Equal(t, byte(PacketTypeData), resp[0])
}

// TestWSHandler_CancelJob verifies a queued task can be cancelled by name.
func TestWSHandler_CancelJob(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Add a task to queue
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-to-cancel",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user", // Auth disabled so this matches any user
		CreatedBy: "user",
	}
	err := tq.AddTask(task)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare cancel_job message
	// Protocol: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var]
	opcode := byte(OpcodeCancelJob)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	jobName := "job-to-cancel"
	jobNameLen := byte(len(jobName))

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, jobNameLen)
	msg = append(msg, []byte(jobName)...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])

	// Verify task cancelled
	updatedTask, err := tq.GetTask("task-1")
	require.NoError(t, err)
	assert.Equal(t, "cancelled", updatedTask.Status)
}

// TestWSHandler_Prune verifies a keep-N prune request succeeds.
func TestWSHandler_Prune(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create some experiments
	_ = expManager.CreateExperiment("commit-1")
	_ = expManager.CreateExperiment("commit-2")

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare prune message
	// Protocol: [opcode:1][api_key_hash:64][prune_type:1][value:4]
	opcode := byte(OpcodePrune)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	pruneType := byte(0) // Keep N
	value := uint32(1)   // Keep 1
	valueBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(valueBytes, value)

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, pruneType)
	msg = append(msg, valueBytes...)

	// Send message
	err := ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}

// TestWSHandler_LogMetric verifies a metric sample frame round-trips with a
// success packet.
func TestWSHandler_LogMetric(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create experiment
	commitIDStr := strings.Repeat("a", 64)
	err := expManager.CreateExperiment(commitIDStr)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare log_metric message
	// Protocol: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
	opcode := byte(OpcodeLogMetric)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	commitID := []byte(commitIDStr)
	step := uint32(100)
	value := 0.95
	valueBits := math.Float64bits(value)
	metricName := "accuracy"
	nameLen := byte(len(metricName))

	stepBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(stepBytes, step)
	valueBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(valueBytes, valueBits)

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, stepBytes...)
	msg = append(msg, valueBytes...)
	msg = append(msg, nameLen)
	msg = append(msg, []byte(metricName)...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}

// TestWSHandler_GetExperiment verifies metadata/metrics retrieval for an
// existing experiment returns a data packet.
func TestWSHandler_GetExperiment(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create experiment and metadata
	commitIDStr := strings.Repeat("a", 64)
	err := expManager.CreateExperiment(commitIDStr)
	require.NoError(t, err)

	meta := &experiment.Metadata{
		CommitID: commitIDStr,
		JobName:  "test-job",
	}
	err = expManager.WriteMetadata(meta)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare get_experiment message
	// Protocol: [opcode:1][api_key_hash:64][commit_id:64]
	opcode := byte(OpcodeGetExperiment)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	commitID := []byte(commitIDStr)

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
+ + // Send message + err = ws.WriteMessage(websocket.BinaryMessage, msg) + require.NoError(t, err) + + // Read response + _, resp, err := ws.ReadMessage() + require.NoError(t, err) + + // Verify success response (PacketTypeData) + assert.Equal(t, byte(PacketTypeData), resp[0]) +} diff --git a/internal/auth/api_key.go b/internal/auth/api_key.go new file mode 100644 index 0000000..181e8c1 --- /dev/null +++ b/internal/auth/api_key.go @@ -0,0 +1,258 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "strings" + "time" +) + +// User represents an authenticated user +type User struct { + Name string `json:"name"` + Admin bool `json:"admin"` + Roles []string `json:"roles"` + Permissions map[string]bool `json:"permissions"` +} + +// APIKeyHash represents a SHA256 hash of an API key +type APIKeyHash string + +// APIKeyEntry represents an API key configuration +type APIKeyEntry struct { + Hash APIKeyHash `json:"hash"` + Admin bool `json:"admin"` + Roles []string `json:"roles,omitempty"` + Permissions map[string]bool `json:"permissions,omitempty"` +} + +// Username represents a user identifier +type Username string + +// AuthConfig represents the authentication configuration +type AuthConfig struct { + Enabled bool `json:"enabled"` + APIKeys map[Username]APIKeyEntry `json:"api_keys"` +} + +// AuthStore interface for different authentication backends +type AuthStore interface { + ValidateAPIKey(ctx context.Context, key string) (*User, error) + CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error + RevokeAPIKey(ctx context.Context, userID string) error + ListUsers(ctx context.Context) ([]UserInfo, error) +} + +// contextKey is the type for context keys +type contextKey string + +const userContextKey = contextKey("user") + +// ValidateAPIKey validates an API key and returns user information +func (c 
*AuthConfig) ValidateAPIKey(key string) (*User, error) { + if !c.Enabled { + // Auth disabled - return default admin user for development + return &User{Name: "default", Admin: true}, nil + } + + keyHash := HashAPIKey(key) + + for username, entry := range c.APIKeys { + if string(entry.Hash) == keyHash { + // Build user with role and permission inheritance + user := &User{ + Name: string(username), + Admin: entry.Admin, + Roles: entry.Roles, + Permissions: make(map[string]bool), + } + + // Copy explicit permissions + for perm, value := range entry.Permissions { + user.Permissions[perm] = value + } + + // Add role-based permissions + rolePerms := getRolePermissions(entry.Roles) + for perm, value := range rolePerms { + if _, exists := user.Permissions[perm]; !exists { + user.Permissions[perm] = value + } + } + + // Admin gets all permissions + if entry.Admin { + user.Permissions["*"] = true + } + + return user, nil + } + } + + return nil, fmt.Errorf("invalid API key") +} + +// AuthMiddleware creates HTTP middleware for API key authentication +func (c *AuthConfig) AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !c.Enabled { + if os.Getenv("FETCH_ML_ALLOW_INSECURE_AUTH") != "1" || os.Getenv("FETCH_ML_DEBUG") != "1" { + http.Error(w, "Unauthorized: Authentication disabled", http.StatusUnauthorized) + return + } + log.Println("WARNING: Insecure authentication bypass enabled: FETCH_ML_ALLOW_INSECURE_AUTH=1 and FETCH_ML_DEBUG=1; do NOT use this configuration in production.") + ctx := context.WithValue(r.Context(), userContextKey, &User{Name: "default", Admin: true}) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + + // Only accept API key from header - no query parameters for security + apiKey := r.Header.Get("X-API-Key") + if apiKey == "" { + http.Error(w, "Unauthorized: Missing API key in X-API-Key header", http.StatusUnauthorized) + return + } + + user, err := c.ValidateAPIKey(apiKey) + 
if err != nil { + http.Error(w, "Unauthorized: Invalid API key", http.StatusUnauthorized) + return + } + + // Add user to context + ctx := context.WithValue(r.Context(), userContextKey, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetUserFromContext retrieves user from request context +func GetUserFromContext(ctx context.Context) *User { + if user, ok := ctx.Value(userContextKey).(*User); ok { + return user + } + return nil +} + +// RequireAdmin creates middleware that requires admin privileges +func RequireAdmin(next http.Handler) http.Handler { + return RequirePermission("system:admin")(next) +} + +// RequirePermission creates middleware that requires specific permission +func RequirePermission(permission string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) + return + } + + if !user.HasPermission(permission) { + http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// HasPermission checks if user has a specific permission +func (u *User) HasPermission(permission string) bool { + // Wildcard permission grants all access + if u.Permissions["*"] { + return true + } + + // Direct permission check + if u.Permissions[permission] { + return true + } + + // Hierarchical permission check (e.g., "jobs:create" matches "jobs") + parts := strings.Split(permission, ":") + for i := 1; i < len(parts); i++ { + partial := strings.Join(parts[:i], ":") + if u.Permissions[partial] { + return true + } + } + + return false +} + +// HasRole checks if user has a specific role +func (u *User) HasRole(role string) bool { + for _, userRole := range u.Roles { + if userRole == role { + return true + } + } + return false +} + +// getRolePermissions returns 
permissions for given roles +func getRolePermissions(roles []string) map[string]bool { + permissions := make(map[string]bool) + + // Use YAML permission manager if available + if pm := GetGlobalPermissionManager(); pm != nil && pm.loaded { + for _, role := range roles { + rolePerms := pm.GetRolePermissions(role) + for perm, value := range rolePerms { + permissions[perm] = value + } + } + return permissions + } + + // Fallback to inline permissions + rolePermissions := map[string]map[string]bool{ + "admin": {"*": true}, + "data_scientist": { + "jobs:create": true, "jobs:read": true, "jobs:update": true, + "data:read": true, "models:read": true, + }, + "data_engineer": { + "data:create": true, "data:read": true, "data:update": true, "data:delete": true, + }, + "viewer": { + "jobs:read": true, "data:read": true, "models:read": true, "metrics:read": true, + }, + "operator": { + "jobs:read": true, "jobs:update": true, "metrics:read": true, "system:read": true, + }, + } + + for _, role := range roles { + if rolePerms, exists := rolePermissions[role]; exists { + for perm, value := range rolePerms { + permissions[perm] = value + } + } + } + + return permissions +} + +// GenerateAPIKey generates a new random API key +func GenerateAPIKey() string { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("%x", sha256.Sum256([]byte(time.Now().String()))) + } + return hex.EncodeToString(buf) +} + +// HashAPIKey creates a SHA256 hash of an API key +func HashAPIKey(key string) string { + hash := sha256.Sum256([]byte(key)) + return hex.EncodeToString(hash[:]) +} diff --git a/internal/auth/api_key_test.go b/internal/auth/api_key_test.go new file mode 100644 index 0000000..e87b134 --- /dev/null +++ b/internal/auth/api_key_test.go @@ -0,0 +1,229 @@ +package auth + +import ( + "testing" +) + +func TestHashAPIKey(t *testing.T) { + tests := []struct { + name string + key string + expected string + }{ + { + name: "known hash", + key: "password", + 
expected: "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8", + }, + { + name: "another known hash", + key: "test", + expected: "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := HashAPIKey(tt.key) + if got != tt.expected { + t.Errorf("HashAPIKey() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestHashAPIKeyConsistency(t *testing.T) { + key := "my-secret-key" + hash1 := HashAPIKey(key) + hash2 := HashAPIKey(key) + + if hash1 != hash2 { + t.Errorf("HashAPIKey() not consistent: %v != %v", hash1, hash2) + } + + if len(hash1) != 64 { + t.Errorf("HashAPIKey() wrong length: got %d, want 64", len(hash1)) + } +} + +func TestGenerateAPIKey(t *testing.T) { + // Test that it generates keys + key1 := GenerateAPIKey() + + if len(key1) != 64 { + t.Errorf("GenerateAPIKey() length = %d, want 64", len(key1)) + } + + // Test uniqueness (timing-based, should be different) + key2 := GenerateAPIKey() + + if key1 == key2 { + t.Errorf("GenerateAPIKey() not unique: both generated %s", key1) + } +} + +func TestUserHasPermission(t *testing.T) { + tests := []struct { + name string + user *User + permission string + want bool + }{ + { + name: "wildcard grants all", + user: &User{ + Permissions: map[string]bool{"*": true}, + }, + permission: "anything", + want: true, + }, + { + name: "direct permission", + user: &User{ + Permissions: map[string]bool{"jobs:create": true}, + }, + permission: "jobs:create", + want: true, + }, + { + name: "hierarchical permission match", + user: &User{ + Permissions: map[string]bool{"jobs": true}, + }, + permission: "jobs:create", + want: true, + }, + { + name: "no permission", + user: &User{ + Permissions: map[string]bool{"jobs:read": true}, + }, + permission: "jobs:create", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.user.HasPermission(tt.permission) + if got != tt.want { 
+ t.Errorf("HasPermission() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUserHasRole(t *testing.T) { + tests := []struct { + name string + user *User + role string + want bool + }{ + { + name: "has role", + user: &User{ + Roles: []string{"admin", "user"}, + }, + role: "admin", + want: true, + }, + { + name: "does not have role", + user: &User{ + Roles: []string{"user"}, + }, + role: "admin", + want: false, + }, + { + name: "empty roles", + user: &User{ + Roles: []string{}, + }, + role: "admin", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.user.HasRole(tt.role) + if got != tt.want { + t.Errorf("HasRole() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthConfigValidateAPIKey(t *testing.T) { + config := &AuthConfig{ + Enabled: true, + APIKeys: map[Username]APIKeyEntry{ + "testuser": { + Hash: APIKeyHash(HashAPIKey("test-key")), + Admin: false, + Roles: []string{"user"}, + Permissions: map[string]bool{ + "jobs:read": true, + }, + }, + "admin": { + Hash: APIKeyHash(HashAPIKey("admin-key")), + Admin: true, + }, + }, + } + + tests := []struct { + name string + key string + wantErr bool + wantAdmin bool + }{ + { + name: "valid user key", + key: "test-key", + wantErr: false, + wantAdmin: false, + }, + { + name: "valid admin key", + key: "admin-key", + wantErr: false, + wantAdmin: true, + }, + { + name: "invalid key", + key: "wrong-key", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, err := config.ValidateAPIKey(tt.key) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateAPIKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && user.Admin != tt.wantAdmin { + t.Errorf("ValidateAPIKey() admin = %v, want %v", user.Admin, tt.wantAdmin) + } + }) + } +} + +func TestAuthConfigDisabled(t *testing.T) { + config := &AuthConfig{ + Enabled: false, + } + + user, err := config.ValidateAPIKey("any-key") + if err != nil { + 
t.Errorf("ValidateAPIKey() with auth disabled should not error: %v", err) + } + if !user.Admin { + t.Error("ValidateAPIKey() with auth disabled should return admin user") + } +} diff --git a/internal/auth/database.go b/internal/auth/database.go new file mode 100644 index 0000000..cae1fb0 --- /dev/null +++ b/internal/auth/database.go @@ -0,0 +1,210 @@ +package auth + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +// DatabaseAuthStore implements authentication using SQLite database +type DatabaseAuthStore struct { + db *sql.DB +} + +// APIKeyRecord represents an API key in the database +type APIKeyRecord struct { + ID int `json:"id"` + UserID string `json:"user_id"` + KeyHash string `json:"key_hash"` + Admin bool `json:"admin"` + Roles string `json:"roles"` // JSON array + Permissions string `json:"permissions"` // JSON object + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` +} + +// NewDatabaseAuthStore creates a new database-backed auth store +func NewDatabaseAuthStore(dbPath string) (*DatabaseAuthStore, error) { + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + store := &DatabaseAuthStore{db: db} + if err := store.init(); err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + return store, nil +} + +// init creates the necessary database tables +func (s *DatabaseAuthStore) init() error { + query := ` + CREATE TABLE IF NOT EXISTS api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL UNIQUE, + key_hash TEXT NOT NULL UNIQUE, + admin BOOLEAN NOT NULL DEFAULT FALSE, + roles TEXT NOT NULL DEFAULT '[]', + permissions TEXT NOT NULL DEFAULT '{}', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME, + revoked_at DATETIME, + CHECK 
(json_valid(roles)), + CHECK (json_valid(permissions)) + ); + + CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash); + CREATE INDEX IF NOT EXISTS idx_api_keys_user ON api_keys(user_id); + CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(revoked_at, COALESCE(expires_at, '9999-12-31')); + ` + + _, err := s.db.Exec(query) + return err +} + +// ValidateAPIKey checks if an API key is valid and returns user info +func (s *DatabaseAuthStore) ValidateAPIKey(ctx context.Context, key string) (*User, error) { + keyHash := HashAPIKey(key) + + query := ` + SELECT user_id, admin, roles, permissions, expires_at, revoked_at + FROM api_keys + WHERE key_hash = ? + AND (revoked_at IS NULL OR revoked_at > ?) + AND (expires_at IS NULL OR expires_at > ?) + ` + + var userID string + var admin bool + var rolesJSON, permissionsJSON string + var expiresAt, revokedAt sql.NullTime + now := time.Now() + + err := s.db.QueryRowContext(ctx, query, keyHash, now, now).Scan( + &userID, &admin, &rolesJSON, &permissionsJSON, &expiresAt, &revokedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("invalid API key") + } + return nil, fmt.Errorf("database error: %w", err) + } + + // Parse roles + var roles []string + if err := json.Unmarshal([]byte(rolesJSON), &roles); err != nil { + log.Printf("Failed to parse roles for user %s: %v", userID, err) + roles = []string{} + } + + // Parse permissions + var permissions map[string]bool + if err := json.Unmarshal([]byte(permissionsJSON), &permissions); err != nil { + log.Printf("Failed to parse permissions for user %s: %v", userID, err) + permissions = make(map[string]bool) + } + + // Admin gets all permissions + if admin { + permissions["*"] = true + } + + return &User{ + Name: userID, + Admin: admin, + Roles: roles, + Permissions: permissions, + }, nil +} + +// CreateAPIKey creates a new API key in the database +func (s *DatabaseAuthStore) CreateAPIKey(ctx context.Context, userID string, keyHash string, 
admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error { + rolesJSON, err := json.Marshal(roles) + if err != nil { + return fmt.Errorf("failed to marshal roles: %w", err) + } + + permissionsJSON, err := json.Marshal(permissions) + if err != nil { + return fmt.Errorf("failed to marshal permissions: %w", err) + } + + query := ` + INSERT INTO api_keys (user_id, key_hash, admin, roles, permissions, expires_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(user_id) DO UPDATE SET + key_hash = excluded.key_hash, + admin = excluded.admin, + roles = excluded.roles, + permissions = excluded.permissions, + expires_at = excluded.expires_at, + revoked_at = NULL + ` + + _, err = s.db.ExecContext(ctx, query, userID, keyHash, admin, rolesJSON, permissionsJSON, expiresAt) + return err +} + +// RevokeAPIKey revokes an API key +func (s *DatabaseAuthStore) RevokeAPIKey(ctx context.Context, userID string) error { + query := `UPDATE api_keys SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?` + _, err := s.db.ExecContext(ctx, query, userID) + return err +} + +// ListUsers returns all active users +func (s *DatabaseAuthStore) ListUsers(ctx context.Context) ([]APIKeyRecord, error) { + query := ` + SELECT id, user_id, key_hash, admin, roles, permissions, created_at, expires_at, revoked_at + FROM api_keys + WHERE revoked_at IS NULL + ORDER BY created_at DESC + ` + + rows, err := s.db.QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to query users: %w", err) + } + defer rows.Close() + + var users []APIKeyRecord + for rows.Next() { + var user APIKeyRecord + err := rows.Scan( + &user.ID, &user.UserID, &user.KeyHash, &user.Admin, + &user.Roles, &user.Permissions, &user.CreatedAt, + &user.ExpiresAt, &user.RevokedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan user: %w", err) + } + users = append(users, user) + } + + return users, nil +} + +// CleanupExpiredKeys removes expired and revoked keys +func (s *DatabaseAuthStore) 
CleanupExpiredKeys(ctx context.Context) error { + query := ` + DELETE FROM api_keys + WHERE (revoked_at IS NOT NULL AND revoked_at < datetime('now', '-30 days')) + OR (expires_at IS NOT NULL AND expires_at < datetime('now', '-7 days')) + ` + + _, err := s.db.ExecContext(ctx, query) + return err +} + +// Close closes the database connection +func (s *DatabaseAuthStore) Close() error { + return s.db.Close() +} diff --git a/internal/auth/flags.go b/internal/auth/flags.go new file mode 100644 index 0000000..205a074 --- /dev/null +++ b/internal/auth/flags.go @@ -0,0 +1,122 @@ +package auth + +import ( + "flag" + "fmt" + "log" + "os" + "strings" +) + +// AuthFlags holds authentication-related command line flags +type AuthFlags struct { + APIKey string + APIKeyFile string + ConfigFile string + EnableAuth bool + ShowHelp bool +} + +// ParseAuthFlags parses authentication command line flags +func ParseAuthFlags() *AuthFlags { + flags := &AuthFlags{} + + flag.StringVar(&flags.APIKey, "api-key", "", "API key for authentication") + flag.StringVar(&flags.APIKeyFile, "api-key-file", "", "Path to file containing API key") + flag.StringVar(&flags.ConfigFile, "config", "", "Configuration file path") + flag.BoolVar(&flags.EnableAuth, "enable-auth", false, "Enable authentication") + flag.BoolVar(&flags.ShowHelp, "auth-help", false, "Show authentication help") + + // Custom help flag that doesn't exit + flag.Usage = func() {} + + flag.Parse() + + return flags +} + +// GetAPIKeyFromSources gets API key from multiple sources in priority order +func GetAPIKeyFromSources(flags *AuthFlags) string { + // 1. Command line flag (highest priority) + if flags.APIKey != "" { + return flags.APIKey + } + + // 2. Explicit file flag + if flags.APIKeyFile != "" { + contents, readErr := os.ReadFile(flags.APIKeyFile) + if readErr == nil { + return strings.TrimSpace(string(contents)) + } + log.Printf("Warning: Could not read API key file %s: %v", flags.APIKeyFile, readErr) + } + + // 3. 
Environment variable + if envKey := os.Getenv("FETCH_ML_API_KEY"); envKey != "" { + return envKey + } + + // 4. File-based key (for automated scripts) + if fileKey := os.Getenv("FETCH_ML_API_KEY_FILE"); fileKey != "" { + content, err := os.ReadFile(fileKey) + if err == nil { + return strings.TrimSpace(string(content)) + } + log.Printf("Warning: Could not read API key file %s: %v", fileKey, err) + } + + return "" +} + +// ValidateAuthFlags validates parsed authentication flags +func ValidateAuthFlags(flags *AuthFlags) error { + if flags.ShowHelp { + PrintAuthHelp() + os.Exit(0) + } + + if flags.APIKeyFile != "" { + if _, err := os.Stat(flags.APIKeyFile); err != nil { + return fmt.Errorf("api key file not found: %s", flags.APIKeyFile) + } + if err := CheckConfigFilePermissions(flags.APIKeyFile); err != nil { + log.Printf("Warning: %v", err) + } + } + + // If config file is specified, check if it exists + if flags.ConfigFile != "" { + if _, err := os.Stat(flags.ConfigFile); err != nil { + return fmt.Errorf("config file not found: %s", flags.ConfigFile) + } + + // Check file permissions + if err := CheckConfigFilePermissions(flags.ConfigFile); err != nil { + log.Printf("Warning: %v", err) + } + } + + return nil +} + +// PrintAuthHelp prints authentication-specific help +func PrintAuthHelp() { + fmt.Println("Authentication Options:") + fmt.Println(" --api-key API key for authentication") + fmt.Println(" --api-key-file Read API key from file") + fmt.Println(" --config Configuration file path") + fmt.Println(" --enable-auth Enable authentication (if disabled)") + fmt.Println(" --auth-help Show this help") + fmt.Println() + fmt.Println("Environment Variables:") + fmt.Println(" FETCH_ML_API_KEY API key for authentication") + fmt.Println(" FETCH_ML_API_KEY_FILE File containing API key") + fmt.Println(" FETCH_ML_ENV Environment (development/production)") + fmt.Println(" FETCH_ML_ALLOW_INSECURE_AUTH Allow insecure auth (dev only)") + fmt.Println() + fmt.Println("Security 
Notes:") + fmt.Println(" - API keys in command line may be visible in process lists") + fmt.Println(" - Environment variables are preferred for automated scripts") + fmt.Println(" - File-based keys should have restricted permissions (600)") + fmt.Println(" - Authentication is mandatory in production environments") +} diff --git a/internal/auth/hybrid.go b/internal/auth/hybrid.go new file mode 100644 index 0000000..e5eb37d --- /dev/null +++ b/internal/auth/hybrid.go @@ -0,0 +1,275 @@ +package auth + +import ( + "context" + "fmt" + "log" + "sync" + "time" +) + +// HybridAuthStore combines file-based and database authentication +// Falls back to file config if database is not available +type HybridAuthStore struct { + fileStore *AuthConfig + dbStore *DatabaseAuthStore + useDB bool + mu sync.RWMutex +} + +// NewHybridAuthStore creates a hybrid auth store +func NewHybridAuthStore(config *AuthConfig, dbPath string) (*HybridAuthStore, error) { + hybrid := &HybridAuthStore{ + fileStore: config, + useDB: false, + } + + // Try to initialize database store + if dbPath != "" { + dbStore, err := NewDatabaseAuthStore(dbPath) + if err != nil { + log.Printf("Failed to initialize database auth store, falling back to file: %v", err) + } else { + hybrid.dbStore = dbStore + hybrid.useDB = true + log.Printf("Using database authentication store") + } + } + + // If database is available, migrate file-based keys to database + if hybrid.useDB && config.Enabled && len(config.APIKeys) > 0 { + if err := hybrid.migrateFileToDatabase(context.Background()); err != nil { + log.Printf("Failed to migrate file keys to database: %v", err) + } + } + + return hybrid, nil +} + +// ValidateAPIKey validates an API key using either database or file store +func (h *HybridAuthStore) ValidateAPIKey(ctx context.Context, key string) (*User, error) { + h.mu.RLock() + useDB := h.useDB + h.mu.RUnlock() + + if useDB { + user, err := h.dbStore.ValidateAPIKey(ctx, key) + if err == nil { + return user, nil + } + + // 
If database fails, fall back to file store + log.Printf("Database auth failed, falling back to file store: %v", err) + return h.fileStore.ValidateAPIKey(key) + } + + // Use file store + return h.fileStore.ValidateAPIKey(key) +} + +// CreateAPIKey creates an API key using the preferred store +func (h *HybridAuthStore) CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error { + h.mu.RLock() + useDB := h.useDB + h.mu.RUnlock() + + if useDB { + err := h.dbStore.CreateAPIKey(ctx, userID, keyHash, admin, roles, permissions, expiresAt) + if err == nil { + return nil + } + + // If database fails, fall back to file store + log.Printf("Database key creation failed, using file store: %v", err) + return h.createFileAPIKey(userID, keyHash, admin, roles, permissions) + } + + // Use file store + return h.createFileAPIKey(userID, keyHash, admin, roles, permissions) +} + +// createFileAPIKey creates an API key in the file store +func (h *HybridAuthStore) createFileAPIKey(userID string, keyHash string, admin bool, roles []string, permissions map[string]bool) error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.fileStore.APIKeys == nil { + h.fileStore.APIKeys = make(map[Username]APIKeyEntry) + } + + h.fileStore.APIKeys[Username(userID)] = APIKeyEntry{ + Hash: APIKeyHash(keyHash), + Admin: admin, + Roles: roles, + Permissions: permissions, + } + + return nil +} + +// RevokeAPIKey revokes an API key +func (h *HybridAuthStore) RevokeAPIKey(ctx context.Context, userID string) error { + h.mu.RLock() + useDB := h.useDB + h.mu.RUnlock() + + if useDB { + err := h.dbStore.RevokeAPIKey(ctx, userID) + if err == nil { + return nil + } + + log.Printf("Database key revocation failed: %v", err) + } + + // Remove from file store + h.mu.Lock() + delete(h.fileStore.APIKeys, Username(userID)) + h.mu.Unlock() + + return nil +} + +// ListUsers returns all users from the active store +func (h *HybridAuthStore) 
ListUsers(ctx context.Context) ([]UserInfo, error) { + h.mu.RLock() + useDB := h.useDB + h.mu.RUnlock() + + if useDB { + records, err := h.dbStore.ListUsers(ctx) + if err == nil { + users := make([]UserInfo, len(records)) + for i, record := range records { + users[i] = UserInfo{ + UserID: record.UserID, + Admin: record.Admin, + KeyHash: record.KeyHash, + Created: record.CreatedAt, + Expires: record.ExpiresAt, + Revoked: record.RevokedAt, + } + } + return users, nil + } + + log.Printf("Database user listing failed: %v", err) + } + + // Use file store + return h.listFileUsers() +} + +// UserInfo represents user information for listing +type UserInfo struct { + UserID string `json:"user_id"` + Admin bool `json:"admin"` + KeyHash string `json:"key_hash"` + Created time.Time `json:"created"` + Expires *time.Time `json:"expires,omitempty"` + Revoked *time.Time `json:"revoked,omitempty"` +} + +// listFileUsers returns users from file store +func (h *HybridAuthStore) listFileUsers() ([]UserInfo, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + var users []UserInfo + for username, entry := range h.fileStore.APIKeys { + users = append(users, UserInfo{ + UserID: string(username), + Admin: entry.Admin, + KeyHash: string(entry.Hash), + Created: time.Now(), // File store doesn't track creation time + }) + } + + return users, nil +} + +// migrateFileToDatabase migrates file-based keys to database +func (h *HybridAuthStore) migrateFileToDatabase(ctx context.Context) error { + if h.fileStore == nil || len(h.fileStore.APIKeys) == 0 { + return nil + } + + log.Printf("Migrating %d API keys from file to database...", len(h.fileStore.APIKeys)) + + for username, entry := range h.fileStore.APIKeys { + userID := string(username) + err := h.dbStore.CreateAPIKey(ctx, userID, string(entry.Hash), entry.Admin, entry.Roles, entry.Permissions, nil) + if err != nil { + log.Printf("Failed to migrate key for user %s: %v", userID, err) + continue + } + log.Printf("Migrated key for user: %s", 
userID) + } + + log.Printf("Migration completed. Consider removing keys from config file.") + return nil +} + +// SwitchToDatabase attempts to switch to database authentication +func (h *HybridAuthStore) SwitchToDatabase(dbPath string) error { + dbStore, err := NewDatabaseAuthStore(dbPath) + if err != nil { + return fmt.Errorf("failed to create database store: %w", err) + } + + h.mu.Lock() + defer h.mu.Unlock() + + // Close existing database if any + if h.dbStore != nil { + h.dbStore.Close() + } + + h.dbStore = dbStore + h.useDB = true + + // Migrate existing keys + if h.fileStore.Enabled && len(h.fileStore.APIKeys) > 0 { + if err := h.migrateFileToDatabase(context.Background()); err != nil { + log.Printf("Migration warning: %v", err) + } + } + + return nil +} + +// Close closes the database connection +func (h *HybridAuthStore) Close() error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.dbStore != nil { + return h.dbStore.Close() + } + return nil +} + +// GetDatabaseStats returns database statistics +func (h *HybridAuthStore) GetDatabaseStats(ctx context.Context) (map[string]interface{}, error) { + h.mu.RLock() + useDB := h.useDB + h.mu.RUnlock() + + if !useDB { + return map[string]interface{}{ + "store_type": "file", + "users": len(h.fileStore.APIKeys), + }, nil + } + + users, err := h.dbStore.ListUsers(ctx) + if err != nil { + return nil, err + } + + return map[string]interface{}{ + "store_type": "database", + "users": len(users), + "path": "db/fetch_ml.db", + }, nil +} diff --git a/internal/auth/keychain.go b/internal/auth/keychain.go new file mode 100644 index 0000000..fb5ebfa --- /dev/null +++ b/internal/auth/keychain.go @@ -0,0 +1,167 @@ +package auth + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/zalando/go-keyring" +) + +// KeychainManager provides secure storage for API keys. 
+type KeychainManager struct { + primary systemKeyring + fallback *fileKeyStore +} + +// systemKeyring abstracts go-keyring for easier testing. +type systemKeyring interface { + Set(service, account, secret string) error + Get(service, account string) (string, error) + Delete(service, account string) error +} + +type goKeyring struct{} + +func (goKeyring) Set(service, account, secret string) error { + return keyring.Set(service, account, secret) +} + +func (goKeyring) Get(service, account string) (string, error) { + return keyring.Get(service, account) +} + +func (goKeyring) Delete(service, account string) error { + return keyring.Delete(service, account) +} + +// NewKeychainManager returns a manager backed by the OS keyring with a secure file fallback. +func NewKeychainManager() *KeychainManager { + return newKeychainManagerWithKeyring(goKeyring{}, defaultFallbackDir()) +} + +func newKeychainManagerWithKeyring(kr systemKeyring, fallbackDir string) *KeychainManager { + if fallbackDir == "" { + fallbackDir = defaultFallbackDir() + } + return &KeychainManager{ + primary: kr, + fallback: newFileKeyStore(fallbackDir), + } +} + +func defaultFallbackDir() string { + home, err := os.UserHomeDir() + if err != nil || home == "" { + return filepath.Join(os.TempDir(), "fetch_ml", "keys") + } + return filepath.Join(home, ".fetch_ml", "keys") +} + +// StoreAPIKey stores the key in the OS keyring, falling back to a protected file when needed. +func (km *KeychainManager) StoreAPIKey(service, account, key string) error { + if err := km.primary.Set(service, account, key); err == nil { + return nil + } else if errors.Is(err, keyring.ErrUnsupportedPlatform) || errors.Is(err, keyring.ErrNotFound) { + return km.fallback.store(service, account, key) + } else if fallbackErr := km.fallback.store(service, account, key); fallbackErr == nil { + return nil + } + return fmt.Errorf("failed to store API key") +} + +// GetAPIKey retrieves a key from the OS keyring or fallback store. 
+func (km *KeychainManager) GetAPIKey(service, account string) (string, error) { + secret, err := km.primary.Get(service, account) + if err == nil { + return secret, nil + } + if errors.Is(err, keyring.ErrUnsupportedPlatform) || errors.Is(err, keyring.ErrNotFound) { + return km.fallback.get(service, account) + } + // Unknown error - try fallback before surfacing + if fallbackSecret, ferr := km.fallback.get(service, account); ferr == nil { + return fallbackSecret, nil + } + return "", fmt.Errorf("failed to retrieve API key") +} + +// DeleteAPIKey removes a key from both stores. +func (km *KeychainManager) DeleteAPIKey(service, account string) error { + if err := km.primary.Delete(service, account); err != nil && !errors.Is(err, keyring.ErrNotFound) && !errors.Is(err, keyring.ErrUnsupportedPlatform) { + return fmt.Errorf("failed to delete API key: %w", err) + } + if err := km.fallback.delete(service, account); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + return nil +} + +// IsAvailable reports whether the OS keyring backend is usable. +func (km *KeychainManager) IsAvailable() bool { + _, err := km.primary.Get("fetch_ml_probe", fmt.Sprintf("probe_%d", time.Now().UnixNano())) + return err == nil || !errors.Is(err, keyring.ErrUnsupportedPlatform) +} + +// ListAvailableMethods returns backends the manager can use. +func (km *KeychainManager) ListAvailableMethods() []string { + methods := []string{} + if km.IsAvailable() { + methods = append(methods, "OS keyring") + } + methods = append(methods, fmt.Sprintf("Encrypted file (%s)", km.fallback.baseDir)) + return methods +} + +// fileKeyStore stores secrets with 0600 permissions as a fallback. 
+type fileKeyStore struct { + baseDir string + mu sync.Mutex +} + +func newFileKeyStore(baseDir string) *fileKeyStore { + return &fileKeyStore{baseDir: baseDir} +} + +func (f *fileKeyStore) store(service, account, secret string) error { + f.mu.Lock() + defer f.mu.Unlock() + if err := os.MkdirAll(f.baseDir, 0o700); err != nil { + return fmt.Errorf("failed to prepare key store: %w", err) + } + path := f.path(service, account) + return os.WriteFile(path, []byte(secret), 0o600) +} + +func (f *fileKeyStore) get(service, account string) (string, error) { + f.mu.Lock() + defer f.mu.Unlock() + data, err := os.ReadFile(f.path(service, account)) + if err != nil { + return "", err + } + return string(data), nil +} + +func (f *fileKeyStore) delete(service, account string) error { + f.mu.Lock() + defer f.mu.Unlock() + path := f.path(service, account) + if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + return nil +} + +func (f *fileKeyStore) path(service, account string) string { + return filepath.Join(f.baseDir, fmt.Sprintf("%s_%s.key", sanitize(service), sanitize(account))) +} + +func sanitize(value string) string { + replacer := strings.NewReplacer("/", "_", "\\", "_", "..", "_", " ", "_", "\t", "_") + return replacer.Replace(value) +} diff --git a/internal/auth/keychain_test.go b/internal/auth/keychain_test.go new file mode 100644 index 0000000..62ff9da --- /dev/null +++ b/internal/auth/keychain_test.go @@ -0,0 +1,129 @@ +package auth + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/zalando/go-keyring" +) + +type fakeKeyring struct { + secrets map[string]string + setErr error + getErr error + deleteErr error +} + +func newFakeKeyring() *fakeKeyring { + return &fakeKeyring{secrets: make(map[string]string)} +} + +func (f *fakeKeyring) Set(service, account, secret string) error { + if f.setErr != nil { + return f.setErr + } + f.secrets[key(service, account)] = secret + return nil +} + +func (f *fakeKeyring) 
Get(service, account string) (string, error) { + if f.getErr != nil { + return "", f.getErr + } + if secret, ok := f.secrets[key(service, account)]; ok { + return secret, nil + } + return "", keyring.ErrNotFound +} + +func (f *fakeKeyring) Delete(service, account string) error { + if f.deleteErr != nil { + return f.deleteErr + } + delete(f.secrets, key(service, account)) + return nil +} + +func key(service, account string) string { + return service + ":" + account +} + +func newTestManager(t *testing.T, kr systemKeyring) (*KeychainManager, string) { + t.Helper() + baseDir := t.TempDir() + return newKeychainManagerWithKeyring(kr, baseDir), baseDir +} + +func TestKeychainStoreAndGetPrimary(t *testing.T) { + kr := newFakeKeyring() + km, baseDir := newTestManager(t, kr) + + if err := km.StoreAPIKey("fetch-ml", "alice", "super-secret"); err != nil { + t.Fatalf("StoreAPIKey failed: %v", err) + } + + got, err := km.GetAPIKey("fetch-ml", "alice") + if err != nil { + t.Fatalf("GetAPIKey failed: %v", err) + } + if got != "super-secret" { + t.Fatalf("expected secret to be stored in primary keyring") + } + + // Ensure fallback file was not created when primary succeeds + path := filepath.Join(baseDir, filepath.Base(km.fallback.path("fetch-ml", "alice"))) + if _, err := os.Stat(path); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected no fallback file, got err=%v", err) + } +} + +func TestKeychainFallbackWhenUnsupported(t *testing.T) { + kr := newFakeKeyring() + kr.setErr = keyring.ErrUnsupportedPlatform + kr.getErr = keyring.ErrUnsupportedPlatform + kr.deleteErr = keyring.ErrUnsupportedPlatform + km, _ := newTestManager(t, kr) + + if err := km.StoreAPIKey("fetch-ml", "bob", "fallback-secret"); err != nil { + t.Fatalf("StoreAPIKey should fallback: %v", err) + } + + got, err := km.GetAPIKey("fetch-ml", "bob") + if err != nil { + t.Fatalf("GetAPIKey should use fallback: %v", err) + } + if got != "fallback-secret" { + t.Fatalf("expected fallback secret, got %s", got) + } +} + 
+func TestKeychainDeleteRemovesFallback(t *testing.T) { + kr := newFakeKeyring() + kr.deleteErr = keyring.ErrNotFound + km, _ := newTestManager(t, kr) + + if err := km.fallback.store("fetch-ml", "carol", "temp"); err != nil { + t.Fatalf("failed to seed fallback store: %v", err) + } + + if err := km.DeleteAPIKey("fetch-ml", "carol"); err != nil { + t.Fatalf("DeleteAPIKey failed: %v", err) + } + + if _, err := km.fallback.get("fetch-ml", "carol"); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected fallback secret removed, err=%v", err) + } +} + +func TestListAvailableMethodsIncludesFallback(t *testing.T) { + kr := newFakeKeyring() + kr.getErr = keyring.ErrUnsupportedPlatform + km, _ := newTestManager(t, kr) + + methods := km.ListAvailableMethods() + if len(methods) != 1 || methods[0] == "OS keyring" { + t.Fatalf("expected only fallback method, got %v", methods) + } +} diff --git a/internal/auth/permissions.go b/internal/auth/permissions.go new file mode 100644 index 0000000..0a825af --- /dev/null +++ b/internal/auth/permissions.go @@ -0,0 +1,192 @@ +package auth + +import ( + "fmt" + "strings" +) + +// Permission constants for type safety +const ( + // Job permissions + PermissionJobsCreate = "jobs:create" + PermissionJobsRead = "jobs:read" + PermissionJobsUpdate = "jobs:update" + PermissionJobsDelete = "jobs:delete" + + // Data permissions + PermissionDataCreate = "data:create" + PermissionDataRead = "data:read" + PermissionDataUpdate = "data:update" + PermissionDataDelete = "data:delete" + + // Model permissions + PermissionModelsCreate = "models:create" + PermissionModelsRead = "models:read" + PermissionModelsUpdate = "models:update" + PermissionModelsDelete = "models:delete" + + // System permissions + PermissionSystemConfig = "system:config" + PermissionSystemMetrics = "system:metrics" + PermissionSystemLogs = "system:logs" + PermissionSystemUsers = "system:users" + + // Wildcard permission + PermissionAll = "*" +) + +// Role constants +const ( + RoleAdmin = 
"admin" + RoleDataScientist = "data_scientist" + RoleDataEngineer = "data_engineer" + RoleViewer = "viewer" + RoleOperator = "operator" +) + +// PermissionGroup represents a group of related permissions +type PermissionGroup struct { + Name string + Permissions []string + Description string +} + +// Built-in permission groups +var PermissionGroups = map[string]PermissionGroup{ + "full_access": { + Name: "Full Access", + Permissions: []string{PermissionAll}, + Description: "Complete system access", + }, + "job_management": { + Name: "Job Management", + Permissions: []string{PermissionJobsCreate, PermissionJobsRead, PermissionJobsUpdate, PermissionJobsDelete}, + Description: "Create, read, update, and delete ML jobs", + }, + "data_access": { + Name: "Data Access", + Permissions: []string{PermissionDataRead, PermissionDataCreate, PermissionDataUpdate, PermissionDataDelete}, + Description: "Access and manage datasets", + }, + "readonly": { + Name: "Read Only", + Permissions: []string{PermissionJobsRead, PermissionDataRead, PermissionModelsRead, PermissionSystemMetrics}, + Description: "View-only access to system resources", + }, + "system_admin": { + Name: "System Administration", + Permissions: []string{PermissionSystemConfig, PermissionSystemLogs, PermissionSystemUsers, PermissionSystemMetrics}, + Description: "System configuration and user management", + }, +} + +// GetPermissionGroup returns a permission group by name +func GetPermissionGroup(name string) (PermissionGroup, bool) { + group, exists := PermissionGroups[name] + return group, exists +} + +// ValidatePermission checks if a permission string is valid +func ValidatePermission(permission string) error { + if permission == PermissionAll { + return nil + } + + // Check if permission matches known patterns + validPrefixes := []string{"jobs:", "data:", "models:", "system:"} + for _, prefix := range validPrefixes { + if strings.HasPrefix(permission, prefix) { + return nil + } + } + + return fmt.Errorf("invalid 
permission format: %s", permission) +} + +// ValidateRole checks if a role is valid +func ValidateRole(role string) error { + validRoles := []string{RoleAdmin, RoleDataScientist, RoleDataEngineer, RoleViewer, RoleOperator} + for _, validRole := range validRoles { + if role == validRole { + return nil + } + } + return fmt.Errorf("invalid role: %s", role) +} + +// ExpandPermissionGroups converts permission group names to actual permissions +func ExpandPermissionGroups(groups []string) ([]string, error) { + var permissions []string + + for _, groupName := range groups { + if groupName == PermissionAll { + return []string{PermissionAll}, nil + } + + group, exists := GetPermissionGroup(groupName) + if !exists { + return nil, fmt.Errorf("unknown permission group: %s", groupName) + } + + permissions = append(permissions, group.Permissions...) + } + + // Remove duplicates + unique := make(map[string]bool) + for _, perm := range permissions { + unique[perm] = true + } + + result := make([]string, 0, len(unique)) + for perm := range unique { + result = append(result, perm) + } + + return result, nil +} + +// PermissionCheckResult represents the result of a permission check +type PermissionCheckResult struct { + Allowed bool `json:"allowed"` + Permission string `json:"permission"` + User string `json:"user"` + Roles []string `json:"roles"` + Missing []string `json:"missing,omitempty"` +} + +// CheckMultiplePermissions checks multiple permissions at once +func (u *User) CheckMultiplePermissions(permissions []string) []PermissionCheckResult { + results := make([]PermissionCheckResult, len(permissions)) + + for i, permission := range permissions { + allowed := u.HasPermission(permission) + missing := []string{} + if !allowed { + missing = []string{permission} + } + + results[i] = PermissionCheckResult{ + Allowed: allowed, + Permission: permission, + User: u.Name, + Roles: u.Roles, + Missing: missing, + } + } + + return results +} + +// GetEffectivePermissions returns all 
effective permissions for a user +func (u *User) GetEffectivePermissions() []string { + if u.Permissions[PermissionAll] { + return []string{PermissionAll} + } + + permissions := make([]string, 0, len(u.Permissions)) + for perm := range u.Permissions { + permissions = append(permissions, perm) + } + + return permissions +} diff --git a/internal/auth/permissions_loader.go b/internal/auth/permissions_loader.go new file mode 100644 index 0000000..4f70018 --- /dev/null +++ b/internal/auth/permissions_loader.go @@ -0,0 +1,295 @@ +package auth + +import ( + "fmt" + "os" + "sync" + + "gopkg.in/yaml.v3" +) + +// PermissionConfig represents the permissions configuration +type PermissionConfig struct { + Roles map[string]RoleConfig `yaml:"roles"` + Groups map[string]GroupConfig `yaml:"groups"` + Hierarchy map[string]HierarchyConfig `yaml:"hierarchy"` + Defaults DefaultsConfig `yaml:"defaults"` +} + +// RoleConfig defines a role and its permissions +type RoleConfig struct { + Description string `yaml:"description"` + Permissions []string `yaml:"permissions"` +} + +// GroupConfig defines a permission group +type GroupConfig struct { + Description string `yaml:"description"` + Inherits []string `yaml:"inherits"` + Permissions []string `yaml:"permissions"` +} + +// HierarchyConfig defines resource hierarchy +type HierarchyConfig struct { + Children map[string]interface{} `yaml:"children"` + Special map[string]string `yaml:"special"` +} + +// DefaultsConfig defines default settings +type DefaultsConfig struct { + NewUserRole string `yaml:"new_user_role"` + AdminUsers []string `yaml:"admin_users"` +} + +// PermissionManager manages permissions from YAML file +type PermissionManager struct { + config *PermissionConfig + rolePerms map[string]map[string]bool + groupPerms map[string]map[string]bool + mu sync.RWMutex + loaded bool +} + +// NewPermissionManager creates a new permission manager +func NewPermissionManager(configPath string) (*PermissionManager, error) { + pm := 
&PermissionManager{} + + if err := pm.loadConfig(configPath); err != nil { + return nil, fmt.Errorf("failed to load permissions: %w", err) + } + + return pm, nil +} + +// loadConfig loads permissions from YAML file +func (pm *PermissionManager) loadConfig(configPath string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read permissions file: %w", err) + } + + var config PermissionConfig + if err := yaml.Unmarshal(data, &config); err != nil { + return fmt.Errorf("failed to parse permissions file: %w", err) + } + + pm.config = &config + pm.rolePerms = make(map[string]map[string]bool) + pm.groupPerms = make(map[string]map[string]bool) + + // Process role permissions + for roleName, role := range config.Roles { + perms := make(map[string]bool) + for _, perm := range role.Permissions { + perms[perm] = true + } + pm.rolePerms[roleName] = perms + } + + // Process group permissions + for groupName, group := range config.Groups { + perms := make(map[string]bool) + + // Add direct permissions + for _, perm := range group.Permissions { + perms[perm] = true + } + + // Inherit permissions from other roles/groups + for _, inherit := range group.Inherits { + if rolePerms, exists := pm.rolePerms[inherit]; exists { + for perm, value := range rolePerms { + perms[perm] = value + } + } + if groupPerms, exists := pm.groupPerms[inherit]; exists { + for perm, value := range groupPerms { + perms[perm] = value + } + } + } + + pm.groupPerms[groupName] = perms + } + + pm.loaded = true + return nil +} + +// GetRolePermissions returns permissions for a role +func (pm *PermissionManager) GetRolePermissions(role string) map[string]bool { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return make(map[string]bool) + } + + if perms, exists := pm.rolePerms[role]; exists { + result := make(map[string]bool) + for perm, value := range perms { + result[perm] = value + } + return result + } + + return 
make(map[string]bool) +} + +// GetGroupPermissions returns permissions for a group +func (pm *PermissionManager) GetGroupPermissions(group string) map[string]bool { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return make(map[string]bool) + } + + if perms, exists := pm.groupPerms[group]; exists { + result := make(map[string]bool) + for perm, value := range perms { + result[perm] = value + } + return result + } + + return make(map[string]bool) +} + +// GetAllRoles returns all available roles +func (pm *PermissionManager) GetAllRoles() map[string]RoleConfig { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return make(map[string]RoleConfig) + } + + result := make(map[string]RoleConfig) + for name, role := range pm.config.Roles { + result[name] = role + } + return result +} + +// GetAllGroups returns all available groups +func (pm *PermissionManager) GetAllGroups() map[string]GroupConfig { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return make(map[string]GroupConfig) + } + + result := make(map[string]GroupConfig) + for name, group := range pm.config.Groups { + result[name] = group + } + return result +} + +// GetDefaultRole returns the default role for new users +func (pm *PermissionManager) GetDefaultRole() string { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded || pm.config.Defaults.NewUserRole == "" { + return "viewer" + } + + return pm.config.Defaults.NewUserRole +} + +// IsAdminUser checks if a username should have admin rights +func (pm *PermissionManager) IsAdminUser(username string) bool { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return false + } + + for _, adminUser := range pm.config.Defaults.AdminUsers { + if adminUser == username { + return true + } + } + return false +} + +// ReloadConfig reloads the permissions configuration +func (pm *PermissionManager) ReloadConfig(configPath string) error { + return pm.loadConfig(configPath) +} + +// ValidatePermission checks if a 
permission string is valid +func (pm *PermissionManager) ValidatePermission(permission string) bool { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return false + } + + // Wildcard is always valid + if permission == "*" { + return true + } + + // Check if permission matches any defined role permissions + for _, rolePerms := range pm.rolePerms { + if _, exists := rolePerms[permission]; exists { + return true + } + } + + // Check if permission matches any group permissions + for _, groupPerms := range pm.groupPerms { + if _, exists := groupPerms[permission]; exists { + return true + } + } + + return false +} + +// GetPermissionHierarchy returns the hierarchy for a resource +func (pm *PermissionManager) GetPermissionHierarchy(resource string) map[string]interface{} { + pm.mu.RLock() + defer pm.mu.RUnlock() + + if !pm.loaded { + return make(map[string]interface{}) + } + + if hierarchy, exists := pm.config.Hierarchy[resource]; exists { + return hierarchy.Children + } + + return make(map[string]interface{}) +} + +// Global permission manager instance +var globalPermissionManager *PermissionManager +var permissionManagerOnce sync.Once + +// GetGlobalPermissionManager returns the global permission manager +func GetGlobalPermissionManager() *PermissionManager { + permissionManagerOnce.Do(func() { + // Try to load from default location + if pm, err := NewPermissionManager("configs/schema/permissions.yaml"); err == nil { + globalPermissionManager = pm + } else { + // Fallback to empty manager + globalPermissionManager = &PermissionManager{ + rolePerms: make(map[string]map[string]bool), + groupPerms: make(map[string]map[string]bool), + loaded: false, + } + } + }) + return globalPermissionManager +} diff --git a/internal/auth/validator.go b/internal/auth/validator.go new file mode 100644 index 0000000..d7da43d --- /dev/null +++ b/internal/auth/validator.go @@ -0,0 +1,100 @@ +package auth + +import ( + "fmt" + "log" + "os" + "strings" +) + +// ValidateAuthConfig 
enforces authentication requirements +func (c *AuthConfig) ValidateAuthConfig() error { + // Check if we're in production environment + isProduction := os.Getenv("FETCH_ML_ENV") == "production" + + if isProduction { + if !c.Enabled { + return fmt.Errorf("authentication must be enabled in production environment") + } + + if len(c.APIKeys) == 0 { + return fmt.Errorf("at least one API key must be configured in production") + } + + // Ensure at least one admin user exists + hasAdmin := false + for _, entry := range c.APIKeys { + if entry.Admin { + hasAdmin = true + break + } + } + + if !hasAdmin { + return fmt.Errorf("at least one admin user must be configured in production") + } + + // Check for insecure development override + if os.Getenv("FETCH_ML_ALLOW_INSECURE_AUTH") == "1" { + log.Printf("WARNING: FETCH_ML_ALLOW_INSECURE_AUTH is enabled in production - this is insecure") + } + } + + // Validate API key format + for username, entry := range c.APIKeys { + if string(username) == "" { + return fmt.Errorf("empty username not allowed") + } + + if entry.Hash == "" { + return fmt.Errorf("user %s has empty API key hash", username) + } + + // Validate hash format (should be 64 hex chars for SHA256) + if len(entry.Hash) != 64 { + return fmt.Errorf("user %s has invalid API key hash format", username) + } + + // Check hash contains only hex characters + for _, char := range entry.Hash { + if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) { + return fmt.Errorf("user %s has invalid API key hash characters", username) + } + } + } + + return nil +} + +// CheckConfigFilePermissions ensures config files have secure permissions +func CheckConfigFilePermissions(configPath string) error { + info, err := os.Stat(configPath) + if err != nil { + return fmt.Errorf("cannot stat config file: %w", err) + } + + // Check file permissions (should be 600 or 640) + perm := info.Mode().Perm() + if perm&0077 != 0 { + return fmt.Errorf("config file %s 
has insecure permissions: %o (should be 600 or 640)", configPath, perm) + } + + return nil +} + +// SanitizeConfig removes sensitive information for logging +func (c *AuthConfig) SanitizeConfig() map[string]interface{} { + sanitized := map[string]interface{}{ + "enabled": c.Enabled, + "users": make(map[string]interface{}), + } + + for username := range c.APIKeys { + sanitized["users"].(map[string]interface{})[string(username)] = map[string]interface{}{ + "admin": c.APIKeys[username].Admin, + "hash": strings.Repeat("*", 8) + "...", // Show only prefix + } + } + + return sanitized +} diff --git a/internal/config/constants.go b/internal/config/constants.go new file mode 100644 index 0000000..1a38677 --- /dev/null +++ b/internal/config/constants.go @@ -0,0 +1,54 @@ +package config + +// Default configuration values (legacy - use SmartDefaults for new code) +const ( + DefaultSSHPort = 22 + DefaultRedisPort = 6379 + DefaultRedisAddr = "localhost:6379" + DefaultBasePath = "/mnt/nas/jobs" + DefaultTrainScript = "train.py" + DefaultDataDir = "/data/active" + DefaultLocalDataDir = "./data/active" + DefaultNASDataDir = "/mnt/datasets" + DefaultMaxWorkers = 2 + DefaultPollInterval = 5 + DefaultMaxAgeHours = 24 + DefaultMaxSizeGB = 100 + DefaultCleanupInterval = 60 +) + +// Redis key prefixes +const ( + RedisTaskQueueKey = "ml:queue" + RedisTaskPrefix = "ml:task:" + RedisJobMetricsPrefix = "ml:metrics:" + RedisTaskStatusPrefix = "ml:status:" + RedisDatasetPrefix = "ml:dataset:" + RedisWorkerHeartbeat = "ml:workers:heartbeat" +) + +// Task status constants +const ( + TaskStatusQueued = "queued" + TaskStatusRunning = "running" + TaskStatusCompleted = "completed" + TaskStatusFailed = "failed" + TaskStatusCancelled = "cancelled" +) + +// Job status constants +const ( + JobStatusPending = "pending" + JobStatusQueued = "queued" + JobStatusRunning = "running" + JobStatusFinished = "finished" + JobStatusFailed = "failed" +) + +// Podman defaults +const ( + DefaultPodmanMemory = "8g" + 
DefaultPodmanCPUs = "2" + DefaultContainerWorkspace = "/workspace" + DefaultContainerResults = "/workspace/results" +) diff --git a/internal/config/paths.go b/internal/config/paths.go new file mode 100644 index 0000000..13b4981 --- /dev/null +++ b/internal/config/paths.go @@ -0,0 +1,73 @@ +// Package config provides shared utilities for the fetch_ml project. +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// ExpandPath expands environment variables and tilde in a path +func ExpandPath(path string) string { + if path == "" { + return "" + } + expanded := os.ExpandEnv(path) + if strings.HasPrefix(expanded, "~") { + home, err := os.UserHomeDir() + if err == nil { + expanded = filepath.Join(home, expanded[1:]) + } + } + return expanded +} + +// ResolveConfigPath resolves a config file path, checking multiple locations +func ResolveConfigPath(path string) (string, error) { + candidates := []string{path} + if !filepath.IsAbs(path) { + candidates = append(candidates, filepath.Join("configs", path)) + } + + var checked []string + for _, candidate := range candidates { + resolved := ExpandPath(candidate) + checked = append(checked, resolved) + if _, err := os.Stat(resolved); err == nil { + return resolved, nil + } + } + + return "", fmt.Errorf("config file not found (looked in %s)", strings.Join(checked, ", ")) +} + +// JobPaths provides helper methods for job directory paths +type JobPaths struct { + BasePath string +} + +// NewJobPaths creates a new JobPaths instance +func NewJobPaths(basePath string) *JobPaths { + return &JobPaths{BasePath: basePath} +} + +// PendingPath returns the path to pending jobs directory +func (j *JobPaths) PendingPath() string { + return filepath.Join(j.BasePath, "pending") +} + +// RunningPath returns the path to running jobs directory +func (j *JobPaths) RunningPath() string { + return filepath.Join(j.BasePath, "running") +} + +// FinishedPath returns the path to finished jobs directory +func (j *JobPaths) 
FinishedPath() string { + return filepath.Join(j.BasePath, "finished") +} + +// FailedPath returns the path to failed jobs directory +func (j *JobPaths) FailedPath() string { + return filepath.Join(j.BasePath, "failed") +} diff --git a/internal/config/smart_defaults.go b/internal/config/smart_defaults.go new file mode 100644 index 0000000..b27e157 --- /dev/null +++ b/internal/config/smart_defaults.go @@ -0,0 +1,222 @@ +package config + +import ( + "os" + "path/filepath" + "runtime" + "strings" +) + +// EnvironmentProfile represents the deployment environment +type EnvironmentProfile int + +const ( + ProfileLocal EnvironmentProfile = iota + ProfileContainer + ProfileCI + ProfileProduction +) + +// DetectEnvironment determines the current environment profile +func DetectEnvironment() EnvironmentProfile { + // CI detection + if os.Getenv("CI") != "" || os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("GITLAB_CI") != "" { + return ProfileCI + } + + // Container detection + if _, err := os.Stat("/.dockerenv"); err == nil { + return ProfileContainer + } + if os.Getenv("KUBERNETES_SERVICE_HOST") != "" { + return ProfileContainer + } + if os.Getenv("CONTAINER") != "" { + return ProfileContainer + } + + // Production detection (customizable) + if os.Getenv("FETCH_ML_ENV") == "production" || os.Getenv("ENV") == "production" { + return ProfileProduction + } + + // Default to local development + return ProfileLocal +} + +// SmartDefaults provides environment-aware default values +type SmartDefaults struct { + Profile EnvironmentProfile +} + +// GetSmartDefaults returns defaults for the current environment +func GetSmartDefaults() *SmartDefaults { + return &SmartDefaults{ + Profile: DetectEnvironment(), + } +} + +// Host returns the appropriate default host +func (s *SmartDefaults) Host() string { + switch s.Profile { + case ProfileContainer, ProfileCI: + return "host.docker.internal" // Docker Desktop/Colima + case ProfileProduction: + return "0.0.0.0" + default: // ProfileLocal 
+ return "localhost" + } +} + +// BasePath returns the appropriate default base path +func (s *SmartDefaults) BasePath() string { + switch s.Profile { + case ProfileContainer, ProfileCI: + return "/workspace/ml-experiments" + case ProfileProduction: + return "/var/lib/fetch_ml/experiments" + default: // ProfileLocal + if home, err := os.UserHomeDir(); err == nil { + return filepath.Join(home, "ml-experiments") + } + return "./ml-experiments" + } +} + +// DataDir returns the appropriate default data directory +func (s *SmartDefaults) DataDir() string { + switch s.Profile { + case ProfileContainer, ProfileCI: + return "/workspace/data" + case ProfileProduction: + return "/var/lib/fetch_ml/data" + default: // ProfileLocal + if home, err := os.UserHomeDir(); err == nil { + return filepath.Join(home, "ml-data") + } + return "./data" + } +} + +// RedisAddr returns the appropriate default Redis address +func (s *SmartDefaults) RedisAddr() string { + switch s.Profile { + case ProfileContainer, ProfileCI: + return "redis:6379" // Service name in containers + case ProfileProduction: + return "redis:6379" + default: // ProfileLocal + return "localhost:6379" + } +} + +// SSHKeyPath returns the appropriate default SSH key path +func (s *SmartDefaults) SSHKeyPath() string { + switch s.Profile { + case ProfileContainer, ProfileCI: + return "/workspace/.ssh/id_rsa" + case ProfileProduction: + return "/etc/fetch_ml/ssh/id_rsa" + default: // ProfileLocal + if home, err := os.UserHomeDir(); err == nil { + return filepath.Join(home, ".ssh", "id_rsa") + } + return "~/.ssh/id_rsa" + } +} + +// KnownHostsPath returns the appropriate default known_hosts path +func (s *SmartDefaults) KnownHostsPath() string { + switch s.Profile { + case ProfileContainer, ProfileCI: + return "/workspace/.ssh/known_hosts" + case ProfileProduction: + return "/etc/fetch_ml/ssh/known_hosts" + default: // ProfileLocal + if home, err := os.UserHomeDir(); err == nil { + return filepath.Join(home, ".ssh", 
"known_hosts") + } + return "~/.ssh/known_hosts" + } +} + +// LogLevel returns the appropriate default log level +func (s *SmartDefaults) LogLevel() string { + switch s.Profile { + case ProfileCI: + return "debug" // More verbose for CI debugging + case ProfileProduction: + return "info" + default: // ProfileLocal, ProfileContainer + return "info" + } +} + +// MaxWorkers returns the appropriate default worker count +func (s *SmartDefaults) MaxWorkers() int { + switch s.Profile { + case ProfileCI: + return 1 // Conservative for CI + case ProfileProduction: + return runtime.NumCPU() // Scale with CPU cores + default: // ProfileLocal, ProfileContainer + return 2 // Reasonable default for local dev + } +} + +// PollInterval returns the appropriate default poll interval in seconds +func (s *SmartDefaults) PollInterval() int { + switch s.Profile { + case ProfileCI: + return 1 // Fast polling for quick tests + case ProfileProduction: + return 10 // Conservative for production + default: // ProfileLocal, ProfileContainer + return 5 // Balanced default + } +} + +// IsInContainer returns true if running in a container environment +func (s *SmartDefaults) IsInContainer() bool { + return s.Profile == ProfileContainer || s.Profile == ProfileCI +} + +// IsProduction returns true if this is a production environment +func (s *SmartDefaults) IsProduction() bool { + return s.Profile == ProfileProduction +} + +// IsCI returns true if this is a CI environment +func (s *SmartDefaults) IsCI() bool { + return s.Profile == ProfileCI +} + +// ExpandPath expands ~ and environment variables in paths +func (s *SmartDefaults) ExpandPath(path string) string { + if strings.HasPrefix(path, "~/") { + if home, err := os.UserHomeDir(); err == nil { + path = filepath.Join(home, path[2:]) + } + } + + // Expand environment variables + path = os.ExpandEnv(path) + + return path +} + +// GetEnvironmentDescription returns a human-readable description +func (s *SmartDefaults) GetEnvironmentDescription() 
string { + switch s.Profile { + case ProfileLocal: + return "Local Development" + case ProfileContainer: + return "Container Environment" + case ProfileCI: + return "CI/CD Environment" + case ProfileProduction: + return "Production Environment" + default: + return "Unknown Environment" + } +} diff --git a/internal/config/validation.go b/internal/config/validation.go new file mode 100644 index 0000000..7bdb3d3 --- /dev/null +++ b/internal/config/validation.go @@ -0,0 +1,69 @@ +// Package utils provides shared utilities for the fetch_ml project, +// including SSH clients, configuration helpers, logging, metrics, +// and validation functions. +package config + +import ( + "fmt" + "net" + "os" +) + +// Validator is an interface for types that can validate themselves. +type Validator interface { + Validate() error +} + +// ValidateConfig validates a configuration struct that implements the Validator interface. +func ValidateConfig(v Validator) error { + return v.Validate() +} + +// ValidatePort checks if a port number is within the valid range (1-65535). +func ValidatePort(port int) error { + if port < 1 || port > 65535 { + return fmt.Errorf("invalid port: %d (must be 1-65535)", port) + } + return nil +} + +// ValidateDirectory checks if a path exists and is a directory. +func ValidateDirectory(path string) error { + if path == "" { + return fmt.Errorf("path cannot be empty") + } + + expanded := ExpandPath(path) + info, err := os.Stat(expanded) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("directory does not exist: %s", expanded) + } + return fmt.Errorf("cannot access directory %s: %w", expanded, err) + } + + if !info.IsDir() { + return fmt.Errorf("path is not a directory: %s", expanded) + } + + return nil +} + +// ValidateRedisAddr validates a Redis address in the format "host:port". 
+func ValidateRedisAddr(addr string) error { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("invalid redis address format: %w", err) + } + + if host == "" { + return fmt.Errorf("redis host cannot be empty") + } + + var portInt int + if _, err := fmt.Sscanf(port, "%d", &portInt); err != nil { + return fmt.Errorf("invalid port number %q: %w", port, err) + } + + return ValidatePort(portInt) +} diff --git a/internal/container/podman.go b/internal/container/podman.go new file mode 100644 index 0000000..f2ff68f --- /dev/null +++ b/internal/container/podman.go @@ -0,0 +1,105 @@ +// Package utils provides shared utilities for the fetch_ml project. +package container + +import ( + "fmt" + "os/exec" + "path/filepath" + "strings" + + "github.com/jfraeys/fetch_ml/internal/config" +) + +// PodmanConfig holds configuration for Podman container execution +type PodmanConfig struct { + Image string + Workspace string + Results string + ContainerWorkspace string + ContainerResults string + GPUAccess bool + Memory string + CPUs string +} + +// BuildPodmanCommand builds a Podman command for executing ML experiments +func BuildPodmanCommand(cfg PodmanConfig, scriptPath, requirementsPath string, extraArgs []string) *exec.Cmd { + args := []string{ + "run", "--rm", + "--security-opt", "no-new-privileges", + "--cap-drop", "ALL", + } + + if cfg.Memory != "" { + args = append(args, "--memory", cfg.Memory) + } else { + args = append(args, "--memory", config.DefaultPodmanMemory) + } + + if cfg.CPUs != "" { + args = append(args, "--cpus", cfg.CPUs) + } else { + args = append(args, "--cpus", config.DefaultPodmanCPUs) + } + + args = append(args, "--userns", "keep-id") + + // Mount workspace + workspaceMount := fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace) + args = append(args, "-v", workspaceMount) + + // Mount results + resultsMount := fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults) + args = append(args, "-v", resultsMount) + + if 
cfg.GPUAccess { + args = append(args, "--device", "/dev/dri") + } + + // Image and command + args = append(args, cfg.Image, + "--workspace", cfg.ContainerWorkspace, + "--requirements", requirementsPath, + "--script", scriptPath, + ) + + // Add extra arguments via --args flag + if len(extraArgs) > 0 { + args = append(args, "--args") + args = append(args, extraArgs...) + } + + return exec.Command("podman", args...) +} + +// SanitizePath ensures a path is safe to use (prevents path traversal) +func SanitizePath(path string) (string, error) { + // Clean the path to remove any .. or . components + cleaned := filepath.Clean(path) + + // Check for path traversal attempts + if strings.Contains(cleaned, "..") { + return "", fmt.Errorf("path traversal detected: %s", path) + } + + return cleaned, nil +} + +// ValidateJobName validates a job name is safe +func ValidateJobName(jobName string) error { + if jobName == "" { + return fmt.Errorf("job name cannot be empty") + } + + // Check for dangerous characters + if strings.ContainsAny(jobName, "/\\<>:\"|?*") { + return fmt.Errorf("job name contains invalid characters: %s", jobName) + } + + // Check for path traversal + if strings.Contains(jobName, "..") { + return fmt.Errorf("job name contains path traversal: %s", jobName) + } + + return nil +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..2d7c921 --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,39 @@ +// Package utils provides shared utilities for the fetch_ml project. +package errors + +import ( + "fmt" +) + +// DataFetchError represents an error that occurred while fetching a dataset +// from the NAS to the ML server. 
+type DataFetchError struct { + Dataset string + JobName string + Err error +} + +func (e *DataFetchError) Error() string { + return fmt.Sprintf("failed to fetch dataset %s for job %s: %v", + e.Dataset, e.JobName, e.Err) +} + +func (e *DataFetchError) Unwrap() error { + return e.Err +} + +type TaskExecutionError struct { + TaskID string + JobName string + Phase string // "data_fetch", "execution", "cleanup" + Err error +} + +func (e *TaskExecutionError) Error() string { + return fmt.Sprintf("task %s (%s) failed during %s: %v", + e.TaskID[:8], e.JobName, e.Phase, e.Err) +} + +func (e *TaskExecutionError) Unwrap() error { + return e.Err +} diff --git a/internal/experiment/manager.go b/internal/experiment/manager.go new file mode 100644 index 0000000..c37b599 --- /dev/null +++ b/internal/experiment/manager.go @@ -0,0 +1,343 @@ +package experiment + +import ( + "encoding/binary" + "fmt" + "math" + "os" + "path/filepath" + "time" +) + +// Metadata represents experiment metadata stored in meta.bin +type Metadata struct { + CommitID string + Timestamp int64 + JobName string + User string +} + +// Manager handles experiment storage and metadata +type Manager struct { + basePath string +} + +func NewManager(basePath string) *Manager { + return &Manager{ + basePath: basePath, + } +} + +// Initialize ensures the experiment directory exists +func (m *Manager) Initialize() error { + if err := os.MkdirAll(m.basePath, 0755); err != nil { + return fmt.Errorf("failed to create experiment base directory: %w", err) + } + return nil +} + +// GetExperimentPath returns the path for a given commit ID +func (m *Manager) GetExperimentPath(commitID string) string { + return filepath.Join(m.basePath, commitID) +} + +// GetFilesPath returns the path to the files directory for an experiment +func (m *Manager) GetFilesPath(commitID string) string { + return filepath.Join(m.GetExperimentPath(commitID), "files") +} + +// GetMetadataPath returns the path to meta.bin for an experiment +func (m 
*Manager) GetMetadataPath(commitID string) string { + return filepath.Join(m.GetExperimentPath(commitID), "meta.bin") +} + +// ExperimentExists checks if an experiment with the given commit ID exists +func (m *Manager) ExperimentExists(commitID string) bool { + path := m.GetExperimentPath(commitID) + info, err := os.Stat(path) + return err == nil && info.IsDir() +} + +// CreateExperiment creates the directory structure for a new experiment +func (m *Manager) CreateExperiment(commitID string) error { + filesPath := m.GetFilesPath(commitID) + + if err := os.MkdirAll(filesPath, 0755); err != nil { + return fmt.Errorf("failed to create experiment directory: %w", err) + } + + return nil +} + +// WriteMetadata writes experiment metadata to meta.bin +func (m *Manager) WriteMetadata(meta *Metadata) error { + path := m.GetMetadataPath(meta.CommitID) + + // Binary format: + // [version:1][timestamp:8][commit_id_len:1][commit_id:var][job_name_len:1][job_name:var][user_len:1][user:var] + + buf := make([]byte, 0, 256) + + // Version + buf = append(buf, 0x01) + + // Timestamp + ts := make([]byte, 8) + binary.BigEndian.PutUint64(ts, uint64(meta.Timestamp)) + buf = append(buf, ts...) + + // Commit ID + buf = append(buf, byte(len(meta.CommitID))) + buf = append(buf, []byte(meta.CommitID)...) + + // Job Name + buf = append(buf, byte(len(meta.JobName))) + buf = append(buf, []byte(meta.JobName)...) + + // User + buf = append(buf, byte(len(meta.User))) + buf = append(buf, []byte(meta.User)...) 
+ + return os.WriteFile(path, buf, 0644) +} + +// ReadMetadata reads experiment metadata from meta.bin +func (m *Manager) ReadMetadata(commitID string) (*Metadata, error) { + path := m.GetMetadataPath(commitID) + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read metadata: %w", err) + } + + if len(data) < 10 { + return nil, fmt.Errorf("metadata file too short") + } + + meta := &Metadata{} + offset := 0 + + // Version + version := data[offset] + offset++ + if version != 0x01 { + return nil, fmt.Errorf("unsupported metadata version: %d", version) + } + + // Timestamp + meta.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8])) + offset += 8 + + // Commit ID + commitIDLen := int(data[offset]) + offset++ + meta.CommitID = string(data[offset : offset+commitIDLen]) + offset += commitIDLen + + // Job Name + if offset >= len(data) { + return meta, nil + } + jobNameLen := int(data[offset]) + offset++ + meta.JobName = string(data[offset : offset+jobNameLen]) + offset += jobNameLen + + // User + if offset >= len(data) { + return meta, nil + } + userLen := int(data[offset]) + offset++ + meta.User = string(data[offset : offset+userLen]) + + return meta, nil +} + +// ListExperiments returns all experiment commit IDs +func (m *Manager) ListExperiments() ([]string, error) { + entries, err := os.ReadDir(m.basePath) + if err != nil { + return nil, fmt.Errorf("failed to read experiments directory: %w", err) + } + + var commitIDs []string + for _, entry := range entries { + if entry.IsDir() { + commitIDs = append(commitIDs, entry.Name()) + } + } + + return commitIDs, nil +} + +// PruneExperiments removes old experiments based on retention policy +func (m *Manager) PruneExperiments(keepCount int, olderThanDays int) ([]string, error) { + commitIDs, err := m.ListExperiments() + if err != nil { + return nil, err + } + + type experiment struct { + commitID string + timestamp int64 + } + + var experiments []experiment + for _, commitID 
:= range commitIDs { + meta, err := m.ReadMetadata(commitID) + if err != nil { + continue // Skip experiments with invalid metadata + } + experiments = append(experiments, experiment{ + commitID: commitID, + timestamp: meta.Timestamp, + }) + } + + // Sort by timestamp (newest first) + for i := 0; i < len(experiments); i++ { + for j := i + 1; j < len(experiments); j++ { + if experiments[j].timestamp > experiments[i].timestamp { + experiments[i], experiments[j] = experiments[j], experiments[i] + } + } + } + + var pruned []string + cutoffTime := time.Now().AddDate(0, 0, -olderThanDays).Unix() + + for i, exp := range experiments { + shouldPrune := false + + // Keep the newest N experiments + if i >= keepCount { + shouldPrune = true + } + + // Also prune if older than threshold + if olderThanDays > 0 && exp.timestamp < cutoffTime { + shouldPrune = true + } + + if shouldPrune { + expPath := m.GetExperimentPath(exp.commitID) + if err := os.RemoveAll(expPath); err != nil { + // Log but continue + continue + } + pruned = append(pruned, exp.commitID) + } + } + + return pruned, nil +} + +// Metric represents a single data point in an experiment +type Metric struct { + Name string `json:"name"` + Value float64 `json:"value"` + Step int `json:"step"` + Timestamp int64 `json:"timestamp"` +} + +// GetMetricsPath returns the path to metrics.bin for an experiment +func (m *Manager) GetMetricsPath(commitID string) string { + return filepath.Join(m.GetExperimentPath(commitID), "metrics.bin") +} + +// LogMetric appends a metric to the experiment's metrics file +func (m *Manager) LogMetric(commitID string, name string, value float64, step int) error { + path := m.GetMetricsPath(commitID) + + // Binary format for each metric: + // [timestamp:8][step:4][value:8][name_len:1][name:var] + + buf := make([]byte, 0, 64) + + // Timestamp + ts := make([]byte, 8) + binary.BigEndian.PutUint64(ts, uint64(time.Now().Unix())) + buf = append(buf, ts...) 
+ + // Step + st := make([]byte, 4) + binary.BigEndian.PutUint32(st, uint32(step)) + buf = append(buf, st...) + + // Value (float64) + val := make([]byte, 8) + binary.BigEndian.PutUint64(val, math.Float64bits(value)) + buf = append(buf, val...) + + // Name + if len(name) > 255 { + name = name[:255] + } + buf = append(buf, byte(len(name))) + buf = append(buf, []byte(name)...) + + // Append to file + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("failed to open metrics file: %w", err) + } + defer f.Close() + + if _, err := f.Write(buf); err != nil { + return fmt.Errorf("failed to write metric: %w", err) + } + + return nil +} + +// GetMetrics reads all metrics for an experiment +func (m *Manager) GetMetrics(commitID string) ([]Metric, error) { + path := m.GetMetricsPath(commitID) + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return []Metric{}, nil + } + return nil, fmt.Errorf("failed to read metrics file: %w", err) + } + + var metrics []Metric + offset := 0 + + for offset < len(data) { + if offset+21 > len(data) { // Min size check + break + } + + m := Metric{} + + // Timestamp + m.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8])) + offset += 8 + + // Step + m.Step = int(binary.BigEndian.Uint32(data[offset : offset+4])) + offset += 4 + + // Value + bits := binary.BigEndian.Uint64(data[offset : offset+8]) + m.Value = math.Float64frombits(bits) + offset += 8 + + // Name + nameLen := int(data[offset]) + offset++ + + if offset+nameLen > len(data) { + break + } + m.Name = string(data[offset : offset+nameLen]) + offset += nameLen + + metrics = append(metrics, m) + } + + return metrics, nil +} diff --git a/internal/logging/config.go b/internal/logging/config.go new file mode 100644 index 0000000..d446824 --- /dev/null +++ b/internal/logging/config.go @@ -0,0 +1,52 @@ +package logging + +import ( + "log/slog" + "os" + "strings" +) + +// Config holds logging 
configuration +type Config struct { + Level string `yaml:"level"` + File string `yaml:"file"` + AuditLog string `yaml:"audit_log"` +} + +// LevelFromEnv reads LOG_LEVEL (if set) and returns the matching slog level. +// Accepted values: debug, info, warn, error. Defaults to info. +func LevelFromEnv() slog.Level { + return parseLevel(os.Getenv("LOG_LEVEL"), slog.LevelInfo) +} + +func parseLevel(value string, defaultLevel slog.Level) slog.Level { + switch strings.ToLower(strings.TrimSpace(value)) { + case "debug": + return slog.LevelDebug + case "warn", "warning": + return slog.LevelWarn + case "error": + return slog.LevelError + case "info", "": + return slog.LevelInfo + default: + return defaultLevel + } +} + +// NewConfiguredLogger creates a logger using the level configured via LOG_LEVEL. +// JSON/text output is still controlled by LOG_FORMAT in NewLogger. +func NewConfiguredLogger() *Logger { + return NewLogger(LevelFromEnv(), false) +} + +// NewLoggerFromConfig creates a logger from configuration +func NewLoggerFromConfig(cfg Config) *Logger { + level := parseLevel(cfg.Level, slog.LevelInfo) + + if cfg.File != "" { + return NewFileLogger(level, false, cfg.File) + } + + return NewLogger(level, false) +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..67b0e1e --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,172 @@ +package logging + +import ( + "context" + "io" + "log/slog" + "os" + "path/filepath" + "time" + + "github.com/google/uuid" +) + +type ctxKey string + +const ( + CtxTraceID ctxKey = "trace_id" + CtxSpanID ctxKey = "span_id" + CtxWorker ctxKey = "worker_id" + CtxJob ctxKey = "job_name" + CtxTask ctxKey = "task_id" +) + +type Logger struct { + *slog.Logger +} + +// NewLogger creates a logger that writes to stderr (development mode) +func NewLogger(level slog.Level, jsonOutput bool) *Logger { + opts := &slog.HandlerOptions{ + Level: level, + AddSource: os.Getenv("LOG_ADD_SOURCE") == "1", + 
} + + var handler slog.Handler + if jsonOutput || os.Getenv("LOG_FORMAT") == "json" { + handler = slog.NewJSONHandler(os.Stderr, opts) + } else { + handler = NewColorTextHandler(os.Stderr, opts) + } + + return &Logger{slog.New(handler)} +} + +// NewFileLogger creates a logger that writes to a file only (production mode) +func NewFileLogger(level slog.Level, jsonOutput bool, logFile string) *Logger { + opts := &slog.HandlerOptions{ + Level: level, + AddSource: os.Getenv("LOG_ADD_SOURCE") == "1", + } + + // Create log directory if it doesn't exist + if logFile != "" { + logDir := filepath.Dir(logFile) + if err := os.MkdirAll(logDir, 0755); err != nil { + // Fallback to stderr only if directory creation fails + return NewLogger(level, jsonOutput) + } + } + + // Open log file + file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + // Fallback to stderr only if file creation fails + return NewLogger(level, jsonOutput) + } + + // Write to file only (production) + var handler slog.Handler + if jsonOutput || os.Getenv("LOG_FORMAT") == "json" { + handler = slog.NewJSONHandler(file, opts) + } else { + handler = slog.NewTextHandler(file, opts) + } + + return &Logger{slog.New(handler)} +} + +// Inject trace + span if missing +func EnsureTrace(ctx context.Context) context.Context { + if ctx.Value(CtxTraceID) == nil { + ctx = context.WithValue(ctx, CtxTraceID, uuid.NewString()) + } + if ctx.Value(CtxSpanID) == nil { + ctx = context.WithValue(ctx, CtxSpanID, uuid.NewString()) + } + return ctx +} + +func (l *Logger) WithContext(ctx context.Context, args ...any) *Logger { + if trace := ctx.Value(CtxTraceID); trace != nil { + args = append(args, "trace_id", trace) + } + if span := ctx.Value(CtxSpanID); span != nil { + args = append(args, "span_id", span) + } + if worker := ctx.Value(CtxWorker); worker != nil { + args = append(args, "worker_id", worker) + } + if job := ctx.Value(CtxJob); job != nil { + args = append(args, "job_name", job) + 
} + if task := ctx.Value(CtxTask); task != nil { + args = append(args, "task_id", task) + } + return &Logger{Logger: l.With(args...)} +} + +func CtxWithWorker(ctx context.Context, worker string) context.Context { + return context.WithValue(ctx, CtxWorker, worker) +} + +func CtxWithJob(ctx context.Context, job string) context.Context { + return context.WithValue(ctx, CtxJob, job) +} + +func CtxWithTask(ctx context.Context, task string) context.Context { + return context.WithValue(ctx, CtxTask, task) +} + +func (l *Logger) Component(ctx context.Context, name string) *Logger { + return l.WithContext(ctx, "component", name) +} + +func (l *Logger) Worker(ctx context.Context, workerID string) *Logger { + return l.WithContext(ctx, "worker_id", workerID) +} + +func (l *Logger) Job(ctx context.Context, job string, task string) *Logger { + return l.WithContext(ctx, "job_name", job, "task_id", task) +} + +func (l *Logger) Fatal(msg string, args ...any) { + l.Error(msg, args...) + os.Exit(1) +} + +func (l *Logger) Panic(msg string, args ...any) { + l.Error(msg, args...) 
+ panic(msg) +} + +// ----------------------------------------------------- +// Colorized human-friendly console logs +// ----------------------------------------------------- + +type ColorTextHandler struct { + slog.Handler +} + +func NewColorTextHandler(w io.Writer, opts *slog.HandlerOptions) slog.Handler { + base := slog.NewTextHandler(w, opts) + return &ColorTextHandler{Handler: base} +} + +func (h *ColorTextHandler) Handle(ctx context.Context, r slog.Record) error { + // Add uniform timestamp (override default) + r.Time = time.Now() + + switch r.Level { + case slog.LevelDebug: + r.Add("lvl_color", "\033[34mDBG\033[0m") + case slog.LevelInfo: + r.Add("lvl_color", "\033[32mINF\033[0m") + case slog.LevelWarn: + r.Add("lvl_color", "\033[33mWRN\033[0m") + case slog.LevelError: + r.Add("lvl_color", "\033[31mERR\033[0m") + } + + return h.Handler.Handle(ctx, r) +} diff --git a/internal/logging/sanitize.go b/internal/logging/sanitize.go new file mode 100644 index 0000000..a9b2d64 --- /dev/null +++ b/internal/logging/sanitize.go @@ -0,0 +1,80 @@ +package logging + +import ( + "regexp" + "strings" +) + +// Patterns for sensitive data +var ( + // API keys: 32+ hex characters + apiKeyPattern = regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`) + + // JWT tokens + jwtPattern = regexp.MustCompile(`eyJ[a-zA-Z0-9_-]{10,}\.eyJ[a-zA-Z0-9_-]{10,}\.[a-zA-Z0-9_-]{10,}`) + + // Password-like fields in logs + passwordPattern = regexp.MustCompile(`(?i)(password|passwd|pwd|secret|token|key)["']?\s*[:=]\s*["']?([^"'\s,}]+)`) + + // Redis URLs with passwords + redisPasswordPattern = regexp.MustCompile(`redis://:[^@]+@`) +) + +// SanitizeLogMessage removes sensitive data from log messages +func SanitizeLogMessage(message string) string { + // Redact API keys + message = apiKeyPattern.ReplaceAllString(message, "[REDACTED-API-KEY]") + + // Redact JWT tokens + message = jwtPattern.ReplaceAllString(message, "[REDACTED-JWT]") + + // Redact password-like fields + message = 
passwordPattern.ReplaceAllStringFunc(message, func(match string) string { + parts := passwordPattern.FindStringSubmatch(match) + if len(parts) >= 2 { + return parts[1] + "=[REDACTED]" + } + return match + }) + + // Redact Redis passwords from URLs + message = redisPasswordPattern.ReplaceAllString(message, "redis://:[REDACTED]@") + + return message +} + +// SanitizeArgs removes sensitive data from structured log arguments +func SanitizeArgs(args []any) []any { + sanitized := make([]any, len(args)) + copy(sanitized, args) + + for i := 0; i < len(sanitized)-1; i += 2 { + // Check if this is a key-value pair + key, okKey := sanitized[i].(string) + value, okValue := sanitized[i+1].(string) + + if okKey && okValue { + lowerKey := strings.ToLower(key) + // Redact sensitive fields + if strings.Contains(lowerKey, "password") || + strings.Contains(lowerKey, "secret") || + strings.Contains(lowerKey, "token") || + strings.Contains(lowerKey, "key") || + strings.Contains(lowerKey, "api") { + sanitized[i+1] = "[REDACTED]" + } else if strings.HasPrefix(value, "redis://") { + sanitized[i+1] = SanitizeLogMessage(value) + } + } + } + + return sanitized +} + +// RedactAPIKey masks an API key for logging (shows first/last 4 chars) +func RedactAPIKey(key string) string { + if len(key) <= 8 { + return "[REDACTED]" + } + return key[:4] + "..." + key[len(key)-4:] +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..f11cbc8 --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,71 @@ +// Package utils provides shared utilities for the fetch_ml project. 
+package metrics + +import ( + "sync/atomic" + "time" +) + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} + +type Metrics struct { + TasksProcessed atomic.Int64 + TasksFailed atomic.Int64 + DataFetchTime atomic.Int64 // Total nanoseconds + ExecutionTime atomic.Int64 + DataTransferred atomic.Int64 // Total bytes + ActiveTasks atomic.Int64 + QueuedTasks atomic.Int64 +} + +func (m *Metrics) RecordTaskSuccess(duration time.Duration) { + m.TasksProcessed.Add(1) + m.ExecutionTime.Add(duration.Nanoseconds()) +} + +func (m *Metrics) RecordTaskFailure() { + m.TasksFailed.Add(1) +} + +func (m *Metrics) RecordTaskStart() { + m.ActiveTasks.Add(1) +} + +// RecordTaskCompletion decrements the number of active tasks. It is safe to call +// even if no tasks are currently recorded; the caller should ensure calls are +// balanced with RecordTaskStart. +func (m *Metrics) RecordTaskCompletion() { + m.ActiveTasks.Add(-1) +} + +func (m *Metrics) RecordDataTransfer(bytes int64, duration time.Duration) { + m.DataTransferred.Add(bytes) + m.DataFetchTime.Add(duration.Nanoseconds()) +} + +func (m *Metrics) SetQueuedTasks(count int64) { + m.QueuedTasks.Store(count) +} + +func (m *Metrics) GetStats() map[string]any { + processed := m.TasksProcessed.Load() + failed := m.TasksFailed.Load() + dataTransferred := m.DataTransferred.Load() + dataFetchTime := m.DataFetchTime.Load() + + return map[string]any{ + "tasks_processed": processed, + "tasks_failed": failed, + "active_tasks": m.ActiveTasks.Load(), + "queued_tasks": m.QueuedTasks.Load(), + "success_rate": float64(processed-failed) / float64(max(processed, 1)), + "avg_exec_time": time.Duration(m.ExecutionTime.Load() / max(processed, 1)), + "data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024), + "avg_fetch_time": time.Duration(dataFetchTime / max(processed, 1)), + } +} diff --git a/internal/middleware/security.go b/internal/middleware/security.go new file mode 100644 index 0000000..cc00eba --- /dev/null 
+++ b/internal/middleware/security.go @@ -0,0 +1,259 @@ +package middleware + +import ( + "context" + "log" + "net/http" + "strings" + "time" + + "golang.org/x/time/rate" +) + +// SecurityMiddleware provides comprehensive security features +type SecurityMiddleware struct { + rateLimiter *rate.Limiter + apiKeys map[string]bool + jwtSecret []byte +} + +func NewSecurityMiddleware(apiKeys []string, jwtSecret string) *SecurityMiddleware { + keyMap := make(map[string]bool) + for _, key := range apiKeys { + keyMap[key] = true + } + + return &SecurityMiddleware{ + rateLimiter: rate.NewLimiter(rate.Limit(60), 10), // 60 requests per minute, burst of 10 + apiKeys: keyMap, + jwtSecret: []byte(jwtSecret), + } +} + +// Rate limiting middleware +func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !sm.rateLimiter.Allow() { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +// API key authentication +func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiKey := r.Header.Get("X-API-Key") + if apiKey == "" { + // Also check Authorization header + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + apiKey = strings.TrimPrefix(authHeader, "Bearer ") + } + } + + if !sm.apiKeys[apiKey] { + http.Error(w, "Invalid API key", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +// Security headers middleware +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Prevent clickjacking + w.Header().Set("X-Frame-Options", "DENY") + // Prevent MIME type sniffing + w.Header().Set("X-Content-Type-Options", "nosniff") + // Enable XSS protection + w.Header().Set("X-XSS-Protection", "1; mode=block") + // 
Content Security Policy + w.Header().Set("Content-Security-Policy", "default-src 'self'") + // Referrer policy + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + // HSTS (HTTPS only) + if r.TLS != nil { + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + } + next.ServeHTTP(w, r) + }) +} + +// IP whitelist middleware +func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientIP := getClientIP(r) + + // Check if client IP is in whitelist + allowed := false + for _, ip := range allowedIPs { + if strings.Contains(ip, "/") { + // CIDR notation - would need proper IP net parsing + if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) { + allowed = true + break + } + } else { + if clientIP == ip { + allowed = true + break + } + } + } + + if !allowed { + http.Error(w, "IP not whitelisted", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// CORS middleware with restrictive defaults +func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // Only allow specific origins in production + allowedOrigins := []string{ + "https://ml-experiments.example.com", + "https://app.example.com", + } + + isAllowed := false + for _, allowed := range allowedOrigins { + if origin == allowed { + isAllowed = true + break + } + } + + if isAllowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Max-Age", "86400") + + if r.Method == "OPTIONS" { + 
w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} + +// Request timeout middleware +func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } +} + +// Request size limiter +func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength > maxSize { + http.Error(w, "Request too large", http.StatusRequestEntityTooLarge) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// Security audit logging +func AuditLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + path := r.URL.Path + raw := r.URL.RawQuery + + // Wrap response writer to capture status code + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // Process request + next.ServeHTTP(wrapped, r) + + // Log after processing + latency := time.Since(start) + clientIP := getClientIP(r) + method := r.Method + statusCode := wrapped.statusCode + + if raw != "" { + path = path + "?" 
+ raw + } + + // Log security-relevant events + if statusCode >= 400 || method == "DELETE" || strings.Contains(path, "/admin") { + // Log to security audit system + logSecurityEvent(map[string]interface{}{ + "timestamp": start.Unix(), + "client_ip": clientIP, + "method": method, + "path": path, + "status": statusCode, + "latency": latency, + "user_agent": r.UserAgent(), + "referer": r.Referer(), + }) + } + }) +} + +// Helper to get client IP +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the list + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + + // Fall back to RemoteAddr + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 { + return r.RemoteAddr[:idx] + } + return r.RemoteAddr +} + +// Response writer wrapper to capture status code +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func logSecurityEvent(event map[string]interface{}) { + // Implementation would send to security monitoring system + // For now, just log (in production, use proper logging) + log.Printf("SECURITY AUDIT: %s %s %s %v", event["client_ip"], event["method"], event["path"], event["status"]) +} diff --git a/internal/network/retry.go b/internal/network/retry.go new file mode 100644 index 0000000..8f3b826 --- /dev/null +++ b/internal/network/retry.go @@ -0,0 +1,73 @@ +// Package utils provides shared utilities for the fetch_ml project. 
+package network + +import ( + "context" + "math" + "time" +) + +type RetryConfig struct { + MaxAttempts int + InitialDelay time.Duration + MaxDelay time.Duration + Multiplier float64 +} + +func DefaultRetryConfig() RetryConfig { + return RetryConfig{ + MaxAttempts: 3, + InitialDelay: 1 * time.Second, + MaxDelay: 30 * time.Second, + Multiplier: 2.0, + } +} + +func Retry(ctx context.Context, cfg RetryConfig, fn func() error) error { + var lastErr error + delay := cfg.InitialDelay + + for attempt := 0; attempt < cfg.MaxAttempts; attempt++ { + if err := fn(); err == nil { + return nil + } else { + lastErr = err + } + + if attempt < cfg.MaxAttempts-1 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + delay = time.Duration(math.Min( + float64(delay)*cfg.Multiplier, + float64(cfg.MaxDelay), + )) + } + } + } + + return lastErr +} + +// RetryWithBackoff provides a convenient wrapper for common retry scenarios +func RetryWithBackoff(ctx context.Context, maxAttempts int, operation func() error) error { + cfg := RetryConfig{ + MaxAttempts: maxAttempts, + InitialDelay: 200 * time.Millisecond, + MaxDelay: 2 * time.Second, + Multiplier: 2.0, + } + return Retry(ctx, cfg, operation) +} + +// RetryForNetworkOperations is optimized for network-related operations +func RetryForNetworkOperations(ctx context.Context, operation func() error) error { + cfg := RetryConfig{ + MaxAttempts: 5, + InitialDelay: 200 * time.Millisecond, + MaxDelay: 5 * time.Second, + Multiplier: 1.5, + } + return Retry(ctx, cfg, operation) +} diff --git a/internal/network/ssh.go b/internal/network/ssh.go new file mode 100644 index 0000000..ade01dc --- /dev/null +++ b/internal/network/ssh.go @@ -0,0 +1,304 @@ +// Package utils provides shared utilities for the fetch_ml project. 
+package network
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"net"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/jfraeys/fetch_ml/internal/config"
+	"golang.org/x/crypto/ssh"
+	"golang.org/x/crypto/ssh/agent"
+	"golang.org/x/crypto/ssh/knownhosts"
+)
+
+// SSHClient provides SSH connection and command execution.
+// A nil client means "local mode": commands run on the local host.
+type SSHClient struct {
+	client *ssh.Client
+	host   string
+}
+
+// NewSSHClient creates a new SSH client. If host or keyPath is empty, returns a local-mode client.
+// knownHostsPath is optional - if provided, will use known_hosts verification.
+// NOTE(review): when knownHostsPath is empty, missing, or unparseable, the
+// connection falls back to ssh.InsecureIgnoreHostKey(), which accepts ANY
+// host key (MITM risk). Each fallback is at least logged below.
+func NewSSHClient(host, user, keyPath string, port int, knownHostsPath string) (*SSHClient, error) {
+	if host == "" || keyPath == "" {
+		// Local mode - no SSH connection needed
+		return &SSHClient{client: nil, host: ""}, nil
+	}
+
+	// ExpandPath presumably handles "~" already; the manual expansion below is
+	// kept as a defensive fallback — TODO confirm and drop one of the two.
+	keyPath = config.ExpandPath(keyPath)
+	if strings.HasPrefix(keyPath, "~") {
+		home, _ := os.UserHomeDir()
+		keyPath = filepath.Join(home, keyPath[1:])
+	}
+
+	key, err := os.ReadFile(keyPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read SSH key: %w", err)
+	}
+
+	var signer ssh.Signer
+	if signer, err = ssh.ParsePrivateKey(key); err != nil {
+		if _, ok := err.(*ssh.PassphraseMissingError); ok {
+			// Try to use ssh-agent for passphrase-protected keys
+			if agentSigner, agentErr := sshAgentSigner(); agentErr == nil {
+				signer = agentSigner
+			} else {
+				return nil, fmt.Errorf("SSH key is passphrase protected and ssh-agent unavailable: %w", err)
+			}
+		} else {
+			return nil, fmt.Errorf("failed to parse SSH key: %w", err)
+		}
+	}
+
+	hostKeyCallback := ssh.InsecureIgnoreHostKey()
+	if knownHostsPath != "" {
+		knownHostsPath = config.ExpandPath(knownHostsPath)
+		if _, err := os.Stat(knownHostsPath); err == nil {
+			callback, err := knownhosts.New(knownHostsPath)
+			if err != nil {
+				log.Printf("Warning: failed to parse known_hosts: %v; using insecure host key verification", err)
+			} else {
+				hostKeyCallback = callback
+			}
+		} else if os.IsNotExist(err) {
+			// FIX: the condition was previously `!os.IsNotExist(err)` while the
+			// message said "not found", so a genuinely missing file fell through
+			// silently and other stat failures were mislabelled. Now each case
+			// gets an accurate warning.
+			log.Printf("Warning: known_hosts not found at %s; using insecure host key verification", knownHostsPath)
+		} else {
+			log.Printf("Warning: failed to stat known_hosts at %s: %v; using insecure host key verification", knownHostsPath, err)
+		}
+	}
+
+	sshConfig := &ssh.ClientConfig{
+		User:            user,
+		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
+		HostKeyCallback: hostKeyCallback,
+		Timeout:         10 * time.Second,
+		HostKeyAlgorithms: []string{
+			ssh.KeyAlgoRSA,
+			ssh.KeyAlgoRSASHA256,
+			ssh.KeyAlgoRSASHA512,
+			ssh.KeyAlgoED25519,
+			ssh.KeyAlgoECDSA256,
+			ssh.KeyAlgoECDSA384,
+			ssh.KeyAlgoECDSA521,
+		},
+	}
+
+	addr := fmt.Sprintf("%s:%d", host, port)
+	client, err := ssh.Dial("tcp", addr, sshConfig)
+	if err != nil {
+		return nil, fmt.Errorf("SSH connection failed: %w", err)
+	}
+
+	return &SSHClient{client: client, host: host}, nil
+}
+
+// Exec executes a command remotely via SSH or locally if in local mode
+func (c *SSHClient) Exec(cmd string) (string, error) {
+	return c.ExecContext(context.Background(), cmd)
+}
+
+// ExecContext executes a command with context support for cancellation and timeouts
+func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) {
+	if c.client == nil {
+		// Local mode - execute command locally with context
+		execCmd := exec.CommandContext(ctx, "sh", "-c", cmd)
+		output, err := execCmd.CombinedOutput()
+		return string(output), err
+	}
+
+	session, err := c.client.NewSession()
+	if err != nil {
+		return "", fmt.Errorf("create session: %w", err)
+	}
+	defer func() {
+		if closeErr := session.Close(); closeErr != nil {
+			// Session may already be closed, so we just log at debug level
+			log.Printf("session close error (may be expected): %v", closeErr)
+		}
+	}()
+
+	// Run command with context cancellation
+	type result struct {
+		output string
+		err    error
+	}
+	resultCh := make(chan result, 1)
+
+	go func() {
+		output, err := session.CombinedOutput(cmd)
+		resultCh <- result{string(output), err}
+	}()
+
+	select {
+	case <-ctx.Done():
+		// FIXED: Check error return value
+		if sigErr := session.Signal(ssh.SIGTERM); sigErr != nil {
+			log.Printf("failed 
to send SIGTERM: %v", sigErr) + } + + // Give process time to cleanup gracefully + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + + select { + case res := <-resultCh: + // Command finished during graceful shutdown + return res.output, fmt.Errorf("command cancelled: %w (output: %s)", ctx.Err(), res.output) + case <-timer.C: + if closeErr := session.Close(); closeErr != nil { + log.Printf("failed to force close session: %v", closeErr) + } + + // Wait a bit more for final result + select { + case res := <-resultCh: + return res.output, fmt.Errorf("command cancelled and force closed: %w (output: %s)", ctx.Err(), res.output) + case <-time.After(5 * time.Second): + return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err()) + } + } + case res := <-resultCh: + return res.output, res.err + } +} + +// FileExists checks if a file exists remotely or locally +func (c *SSHClient) FileExists(path string) bool { + if c.client == nil { + // Local mode - check file locally + _, err := os.Stat(path) + return err == nil + } + + out, err := c.Exec(fmt.Sprintf("test -e %s && echo 'exists'", path)) + if err != nil { + return false + } + return strings.Contains(strings.TrimSpace(out), "exists") +} + +// GetFileSize gets the size of a file or directory remotely or locally +func (c *SSHClient) GetFileSize(path string) (int64, error) { + if c.client == nil { + // Local mode - get size locally + var size int64 + err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + size += info.Size() + return nil + }) + return size, err + } + + out, err := c.Exec(fmt.Sprintf("du -sb %s | cut -f1", path)) + if err != nil { + return 0, err + } + + var size int64 + if _, err := fmt.Sscanf(strings.TrimSpace(out), "%d", &size); err != nil { + return 0, fmt.Errorf("failed to parse file size from output %q: %w", out, err) + } + return size, nil +} + +// RemoteExists checks if a remote path exists (alias for FileExists 
for compatibility) +func (c *SSHClient) RemoteExists(path string) bool { + return c.FileExists(path) +} + +// ListDir lists directory contents remotely or locally +func (c *SSHClient) ListDir(path string) []string { + if c.client == nil { + // Local mode + entries, err := os.ReadDir(path) + if err != nil { + return nil + } + var items []string + for _, entry := range entries { + items = append(items, entry.Name()) + } + return items + } + + out, err := c.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", path)) + if err != nil { + return nil + } + + var items []string + for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") { + if line != "" { + items = append(items, line) + } + } + return items +} + +// TailFile gets the last N lines of a file remotely or locally +func (c *SSHClient) TailFile(path string, lines int) string { + if c.client == nil { + // Local mode - read file and return last N lines + data, err := os.ReadFile(path) + if err != nil { + return "" + } + fileLines := strings.Split(string(data), "\n") + if len(fileLines) > lines { + fileLines = fileLines[len(fileLines)-lines:] + } + return strings.Join(fileLines, "\n") + } + + out, err := c.Exec(fmt.Sprintf("tail -n %d %s 2>/dev/null", lines, path)) + if err != nil { + return "" + } + return out +} + +// Close closes the SSH connection +func (c *SSHClient) Close() error { + if c.client != nil { + return c.client.Close() + } + return nil +} + +// sshAgentSigner attempts to get a signer from ssh-agent +func sshAgentSigner() (ssh.Signer, error) { + sshAuthSock := os.Getenv("SSH_AUTH_SOCK") + if sshAuthSock == "" { + return nil, fmt.Errorf("SSH_AUTH_SOCK not set") + } + + conn, err := net.Dial("unix", sshAuthSock) + if err != nil { + return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err) + } + defer func() { + if closeErr := conn.Close(); closeErr != nil { + log.Printf("warning: failed to close ssh-agent connection: %v", closeErr) + } + }() + + agentClient := agent.NewClient(conn) + signers, err 
:= agentClient.Signers()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
+	}
+
+	if len(signers) == 0 {
+		return nil, fmt.Errorf("no signers available in ssh-agent")
+	}
+
+	return signers[0], nil
+}
diff --git a/internal/network/ssh_pool.go b/internal/network/ssh_pool.go
new file mode 100755
index 0000000..115085e
--- /dev/null
+++ b/internal/network/ssh_pool.go
@@ -0,0 +1,94 @@
+// Package network provides shared networking utilities for the fetch_ml project.
+package network
+
+import (
+	"context"
+	"sync"
+
+	"github.com/jfraeys/fetch_ml/internal/logging"
+)
+
+// SSHPool is a bounded pool of reusable SSH connections created by factory.
+type SSHPool struct {
+	factory  func() (*SSHClient, error)
+	pool     chan *SSHClient
+	active   int
+	maxConns int
+	mu       sync.Mutex
+	logger   *logging.Logger
+}
+
+func NewSSHPool(maxConns int, factory func() (*SSHClient, error), logger *logging.Logger) *SSHPool {
+	return &SSHPool{
+		factory:  factory,
+		pool:     make(chan *SSHClient, maxConns),
+		maxConns: maxConns,
+		logger:   logger,
+	}
+}
+
+func (p *SSHPool) Get(ctx context.Context) (*SSHClient, error) {
+	select {
+	case conn := <-p.pool:
+		return conn, nil
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+		p.mu.Lock()
+		if p.active < p.maxConns {
+			p.active++
+			p.mu.Unlock()
+			conn, err := p.factory()
+			if err != nil {
+				// FIX: release the slot reserved above; previously a failed
+				// dial leaked the active count and shrank pool capacity.
+				p.mu.Lock()
+				p.active--
+				p.mu.Unlock()
+				return nil, err
+			}
+			return conn, nil
+		}
+		p.mu.Unlock()
+
+		// Wait for available connection
+		select {
+		case conn := <-p.pool:
+			return conn, nil
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+}
+
+func (p *SSHPool) Put(conn *SSHClient) {
+	select {
+	case p.pool <- conn:
+	default:
+		// Pool is full, close connection
+		err := conn.Close()
+		if err != nil {
+			p.logger.Warn("failed to close SSH connection", "error", err)
+		}
+		p.mu.Lock()
+		p.active--
+		p.mu.Unlock()
+	}
+}
+
+func (p *SSHPool) Close() {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	// Close all connections in the pool
+	close(p.pool)
+	for conn := range p.pool {
+		err := conn.Close()
+		if err != nil {
+			p.logger.Warn("failed to close SSH connection", "error", err)
+		}
+	}
+
+	// Reset active 
count + p.active = 0 +} diff --git a/internal/queue/errors.go b/internal/queue/errors.go new file mode 100644 index 0000000..5203dd6 --- /dev/null +++ b/internal/queue/errors.go @@ -0,0 +1,215 @@ +package queue + +import ( + "errors" + "fmt" + "strings" +) + +// ErrorCategory represents the type of error encountered +type ErrorCategory string + +const ( + ErrorNetwork ErrorCategory = "network" // Network connectivity issues + ErrorResource ErrorCategory = "resource" // Resource exhaustion (OOM, disk full) + ErrorRateLimit ErrorCategory = "rate_limit" // Rate limiting or throttling + ErrorAuth ErrorCategory = "auth" // Authentication/authorization failures + ErrorValidation ErrorCategory = "validation" // Input validation errors + ErrorTimeout ErrorCategory = "timeout" // Operation timeout + ErrorPermanent ErrorCategory = "permanent" // Non-retryable errors + ErrorUnknown ErrorCategory = "unknown" // Unclassified errors +) + +// TaskError wraps an error with category and context +type TaskError struct { + Category ErrorCategory + Message string + Cause error + Context map[string]string +} + +func (e *TaskError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("[%s] %s: %v", e.Category, e.Message, e.Cause) + } + return fmt.Sprintf("[%s] %s", e.Category, e.Message) +} + +func (e *TaskError) Unwrap() error { + return e.Cause +} + +// NewTaskError creates a new categorized error +func NewTaskError(category ErrorCategory, message string, cause error) *TaskError { + return &TaskError{ + Category: category, + Message: message, + Cause: cause, + Context: make(map[string]string), + } +} + +// ClassifyError categorizes an error for retry logic +func ClassifyError(err error) ErrorCategory { + if err == nil { + return ErrorUnknown + } + + // Check if already classified + var taskErr *TaskError + if errors.As(err, &taskErr) { + return taskErr.Category + } + + errStr := strings.ToLower(err.Error()) + + // Network errors (retryable) + networkIndicators := []string{ + 
"connection refused", + "connection reset", + "connection timeout", + "no route to host", + "network unreachable", + "temporary failure", + "dns", + "dial tcp", + "i/o timeout", + } + for _, indicator := range networkIndicators { + if strings.Contains(errStr, indicator) { + return ErrorNetwork + } + } + + // Resource errors (retryable after delay) + resourceIndicators := []string{ + "out of memory", + "oom", + "no space left", + "disk full", + "resource temporarily unavailable", + "too many open files", + "cannot allocate memory", + } + for _, indicator := range resourceIndicators { + if strings.Contains(errStr, indicator) { + return ErrorResource + } + } + + // Rate limiting (retryable with backoff) + rateLimitIndicators := []string{ + "rate limit", + "too many requests", + "throttle", + "quota exceeded", + "429", + } + for _, indicator := range rateLimitIndicators { + if strings.Contains(errStr, indicator) { + return ErrorRateLimit + } + } + + // Timeout errors (retryable) + timeoutIndicators := []string{ + "timeout", + "deadline exceeded", + "context deadline", + } + for _, indicator := range timeoutIndicators { + if strings.Contains(errStr, indicator) { + return ErrorTimeout + } + } + + // Authentication errors (not retryable) + authIndicators := []string{ + "unauthorized", + "forbidden", + "authentication failed", + "invalid credentials", + "access denied", + "401", + "403", + } + for _, indicator := range authIndicators { + if strings.Contains(errStr, indicator) { + return ErrorAuth + } + } + + // Validation errors (not retryable) + validationIndicators := []string{ + "invalid input", + "validation failed", + "bad request", + "malformed", + "400", + } + for _, indicator := range validationIndicators { + if strings.Contains(errStr, indicator) { + return ErrorValidation + } + } + + // Default to unknown + return ErrorUnknown +} + +// IsRetryable determines if an error category should be retried +func IsRetryable(category ErrorCategory) bool { + switch category 
{ + case ErrorNetwork, ErrorResource, ErrorRateLimit, ErrorTimeout, ErrorUnknown: + return true + case ErrorAuth, ErrorValidation, ErrorPermanent: + return false + default: + return false + } +} + +// GetUserMessage returns a user-friendly error message with suggestions +func GetUserMessage(category ErrorCategory, err error) string { + messages := map[ErrorCategory]string{ + ErrorNetwork: "Network connectivity issue. Please check your network connection and try again.", + ErrorResource: "System resource exhausted. The system may be under heavy load. Try again later or contact support.", + ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.", + ErrorAuth: "Authentication failed. Please check your API key or credentials.", + ErrorValidation: "Invalid input. Please review your request and correct any errors.", + ErrorTimeout: "Operation timed out. The task may be too complex or the system is slow. Try again or simplify the request.", + ErrorPermanent: "A permanent error occurred. This task cannot be retried automatically.", + ErrorUnknown: "An unexpected error occurred. 
If this persists, please contact support.", + } + + baseMsg := messages[category] + if err != nil { + return fmt.Sprintf("%s (Details: %v)", baseMsg, err) + } + return baseMsg +} + +// RetryDelay calculates the retry delay based on error category and retry count +func RetryDelay(category ErrorCategory, retryCount int) int { + switch category { + case ErrorRateLimit: + // Longer backoff for rate limits + return min(300, 10*(1< 8 && cfg.RedisAddr[:8] == "redis://" { + opts, err = redis.ParseURL(cfg.RedisAddr) + if err != nil { + return nil, fmt.Errorf("invalid redis url: %w", err) + } + } else { + opts = &redis.Options{ + Addr: cfg.RedisAddr, + Password: cfg.RedisPassword, + DB: cfg.RedisDB, + } + } + + rdb := redis.NewClient(opts) + + ctx, cancel := context.WithCancel(context.Background()) + if err := rdb.Ping(ctx).Err(); err != nil { + cancel() + return nil, fmt.Errorf("redis connection failed: %w", err) + } + + flushEvery := cfg.MetricsFlushInterval + if flushEvery == 0 { + flushEvery = defaultMetricsFlushInterval + } + + tq := &TaskQueue{ + client: rdb, + ctx: ctx, + cancel: cancel, + metricsCh: make(chan metricEvent, 256), + metricsDone: make(chan struct{}), + flushEvery: flushEvery, + } + + go tq.runMetricsBuffer() + go tq.runLeaseReclamation() // Start lease reclamation background job + + return tq, nil +} + +// AddTask adds a new task to the queue with default retry settings +func (tq *TaskQueue) AddTask(task *Task) error { + // Set default retry settings if not specified + if task.MaxRetries == 0 { + task.MaxRetries = defaultMaxRetries + } + + taskData, err := json.Marshal(task) + if err != nil { + return fmt.Errorf("failed to marshal task: %w", err) + } + + pipe := tq.client.Pipeline() + + // Store task data + pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour) + + // Add to priority queue (ZSET) + // Use priority as score (higher priority = higher score) + pipe.ZAdd(tq.ctx, TaskQueueKey, redis.Z{ + Score: float64(task.Priority), + Member: 
task.ID,
+	})
+
+	// Initialize status
+	pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
+		"status", task.Status,
+		"task_id", task.ID,
+		"updated_at", time.Now().Format(time.RFC3339))
+
+	_, err = pipe.Exec(tq.ctx)
+	if err != nil {
+		return fmt.Errorf("failed to enqueue task: %w", err)
+	}
+
+	// Record metrics
+	TasksQueued.Inc()
+
+	// Update queue depth (best-effort; depth error deliberately ignored)
+	depth, _ := tq.QueueDepth()
+	UpdateQueueDepth(depth)
+
+	return nil
+}
+
+// GetNextTask gets the next task without lease (backward compatible)
+func (tq *TaskQueue) GetNextTask() (*Task, error) {
+	result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
+	if err != nil {
+		return nil, err
+	}
+	if len(result) == 0 {
+		return nil, nil
+	}
+
+	// FIX: guard the type assertion (consistent with WaitForNextTask);
+	// a bare assertion would panic on an unexpected member type.
+	taskID, ok := result[0].Member.(string)
+	if !ok {
+		return nil, fmt.Errorf("unexpected task id type %T", result[0].Member)
+	}
+	return tq.GetTask(taskID)
+}
+
+// GetNextTaskWithLease gets the next task and acquires a lease
+func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
+	if leaseDuration == 0 {
+		leaseDuration = defaultLeaseDuration
+	}
+
+	// Pop highest priority task
+	result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
+	if err != nil {
+		return nil, err
+	}
+	if len(result) == 0 {
+		return nil, nil
+	}
+
+	// FIX: guard the type assertion instead of panicking on bad data
+	// (same pattern as WaitForNextTask).
+	taskID, ok := result[0].Member.(string)
+	if !ok {
+		return nil, fmt.Errorf("unexpected task id type %T", result[0].Member)
+	}
+	task, err := tq.GetTask(taskID)
+	if err != nil {
+		// Re-queue the task if we can't fetch it
+		tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
+			Score:  result[0].Score,
+			Member: taskID,
+		})
+		return nil, err
+	}
+
+	// Acquire lease
+	now := time.Now()
+	leaseExpiry := now.Add(leaseDuration)
+	task.LeaseExpiry = &leaseExpiry
+	task.LeasedBy = workerID
+
+	// Update task with lease
+	if err := tq.UpdateTask(task); err != nil {
+		// Re-queue if update fails
+		tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
+			Score:  result[0].Score,
+			Member: taskID,
+		})
+		return nil, err
+	}
+
+	return task, nil
+}
+
+// RenewLease renews the lease on a task (heartbeat)
+func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration 
time.Duration) error { + if leaseDuration == 0 { + leaseDuration = defaultLeaseDuration + } + + task, err := tq.GetTask(taskID) + if err != nil { + return err + } + + // Verify the worker owns the lease + if task.LeasedBy != workerID { + return fmt.Errorf("task leased by different worker: %s", task.LeasedBy) + } + + // Renew lease + leaseExpiry := time.Now().Add(leaseDuration) + task.LeaseExpiry = &leaseExpiry + + // Record renewal metric + RecordLeaseRenewal(workerID) + + return tq.UpdateTask(task) +} + +// ReleaseLease releases the lease on a task +func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error { + task, err := tq.GetTask(taskID) + if err != nil { + return err + } + + // Verify the worker owns the lease + if task.LeasedBy != workerID { + return fmt.Errorf("task leased by different worker: %s", task.LeasedBy) + } + + // Clear lease + task.LeaseExpiry = nil + task.LeasedBy = "" + + return tq.UpdateTask(task) +} + +// RetryTask re-queues a failed task with smart backoff based on error category +func (tq *TaskQueue) RetryTask(task *Task) error { + if task.RetryCount >= task.MaxRetries { + // Move to dead letter queue + RecordDLQAddition("max_retries") + return tq.MoveToDeadLetterQueue(task, "max retries exceeded") + } + + // Classify the error if it exists + var errorCategory ErrorCategory = ErrorUnknown + if task.Error != "" { + errorCategory = ClassifyError(fmt.Errorf("%s", task.Error)) + } + + // Check if error is retryable + if !IsRetryable(errorCategory) { + RecordDLQAddition(string(errorCategory)) + return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory)) + } + + task.RetryCount++ + task.Status = "queued" + task.LastError = task.Error // Preserve last error + task.Error = "" // Clear current error + + // Calculate smart backoff based on error category + backoffSeconds := RetryDelay(errorCategory, task.RetryCount) + nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second) + 
task.NextRetry = &nextRetry + + // Clear lease + task.LeaseExpiry = nil + task.LeasedBy = "" + + // Record retry metrics + RecordTaskRetry(task.JobName, errorCategory) + + // Re-queue with same priority + return tq.AddTask(task) +} + +// MoveToDeadLetterQueue moves a task to the dead letter queue +func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error { + task.Status = "failed" + task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError) + + taskData, err := json.Marshal(task) + if err != nil { + return err + } + + // Store in dead letter queue with timestamp + key := "task:dlq:" + task.ID + + // Record metrics + RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError))) + + pipe := tq.client.Pipeline() + pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour) + pipe.ZRem(tq.ctx, TaskQueueKey, task.ID) + _, err = pipe.Exec(tq.ctx) + return err +} + +// runLeaseReclamation reclaims expired leases every 1 minute +func (tq *TaskQueue) runLeaseReclamation() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-tq.ctx.Done(): + return + case <-ticker.C: + if err := tq.reclaimExpiredLeases(); err != nil { + // Log error but continue + continue + } + } + } +} + +// reclaimExpiredLeases finds and re-queues tasks with expired leases +func (tq *TaskQueue) reclaimExpiredLeases() error { + // Scan for all task keys + iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator() + now := time.Now() + + for iter.Next(tq.ctx) { + taskKey := iter.Val() + taskID := taskKey[len(TaskPrefix):] + + task, err := tq.GetTask(taskID) + if err != nil { + continue + } + + // Check if lease expired and task is still running + if task.LeaseExpiry != nil && task.LeaseExpiry.Before(now) && task.Status == "running" { + // Lease expired - retry or fail the task + task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy) + + // Record lease expiration + RecordLeaseExpiration() + + if 
task.RetryCount < task.MaxRetries { + // Retry the task + if err := tq.RetryTask(task); err != nil { + continue + } + } else { + // Max retries exceeded - move to DLQ + if err := tq.MoveToDeadLetterQueue(task, "lease expiry after max retries"); err != nil { + continue + } + } + } + } + + return iter.Err() +} + +// GetTask retrieves a task by ID +func (tq *TaskQueue) GetTask(taskID string) (*Task, error) { + data, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result() + if err != nil { + return nil, err + } + + var task Task + if err := json.Unmarshal([]byte(data), &task); err != nil { + return nil, err + } + + return &task, nil +} + +// GetAllTasks retrieves all tasks from the queue +func (tq *TaskQueue) GetAllTasks() ([]*Task, error) { + // Get all task keys + keys, err := tq.client.Keys(tq.ctx, TaskPrefix+"*").Result() + if err != nil { + return nil, err + } + + var tasks []*Task + for _, key := range keys { + data, err := tq.client.Get(tq.ctx, key).Result() + if err != nil { + continue // Skip tasks that can't be retrieved + } + + var task Task + if err := json.Unmarshal([]byte(data), &task); err != nil { + continue // Skip malformed tasks + } + + tasks = append(tasks, &task) + } + + return tasks, nil +} + +// GetTaskByName retrieves a task by its job name +func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) { + tasks, err := tq.GetAllTasks() + if err != nil { + return nil, err + } + + for _, task := range tasks { + if task.JobName == jobName { + return task, nil + } + } + + return nil, fmt.Errorf("task with job name '%s' not found", jobName) +} + +// CancelTask marks a task as cancelled +func (tq *TaskQueue) CancelTask(taskID string) error { + task, err := tq.GetTask(taskID) + if err != nil { + return err + } + + // Update task status to cancelled + task.Status = "cancelled" + now := time.Now() + task.EndedAt = &now + + return tq.UpdateTask(task) +} + +// UpdateTask updates a task in Redis +func (tq *TaskQueue) UpdateTask(task *Task) error { + 
taskData, err := json.Marshal(task) + if err != nil { + return err + } + + pipe := tq.client.Pipeline() + pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour) + pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName, + "status", task.Status, + "task_id", task.ID, + "updated_at", time.Now().Format(time.RFC3339)) + + _, err = pipe.Exec(tq.ctx) + return err +} + +// UpdateTaskWithMetrics updates task and records metrics +func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error { + if err := tq.UpdateTask(task); err != nil { + return err + } + + metricName := "tasks_" + action + return tq.RecordMetric(task.JobName, metricName, 1) +} + +// RecordMetric records a metric value +func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error { + evt := metricEvent{JobName: jobName, Metric: metric, Value: value} + select { + case tq.metricsCh <- evt: + return nil + default: + return tq.writeMetrics(jobName, map[string]float64{metric: value}) + } +} + +// Heartbeat records worker heartbeat +func (tq *TaskQueue) Heartbeat(workerID string) error { + return tq.client.HSet(tq.ctx, WorkerHeartbeat, + workerID, time.Now().Unix()).Err() +} + +// QueueDepth returns the number of pending tasks +func (tq *TaskQueue) QueueDepth() (int64, error) { + return tq.client.ZCard(tq.ctx, TaskQueueKey).Result() +} + +// Close closes the task queue and cleans up resources +func (tq *TaskQueue) Close() error { + tq.cancel() + <-tq.metricsDone // Wait for metrics buffer to finish + return tq.client.Close() +} + +// GetRedisClient returns the underlying Redis client for direct access +func (tq *TaskQueue) GetRedisClient() *redis.Client { + return tq.client +} + +// WaitForNextTask waits for next task with timeout +func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) { + if ctx == nil { + ctx = tq.ctx + } + result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result() + if err == redis.Nil { + return nil, nil + } 
+ if err != nil { + return nil, err + } + member, ok := result.Member.(string) + if !ok { + return nil, fmt.Errorf("unexpected task id type %T", result.Member) + } + return tq.GetTask(member) +} + +// runMetricsBuffer buffers and flushes metrics +func (tq *TaskQueue) runMetricsBuffer() { + defer close(tq.metricsDone) + ticker := time.NewTicker(tq.flushEvery) + defer ticker.Stop() + pending := make(map[string]map[string]float64) + flush := func() { + for job, metrics := range pending { + if err := tq.writeMetrics(job, metrics); err != nil { + continue + } + delete(pending, job) + } + } + + for { + select { + case <-tq.ctx.Done(): + flush() + return + case evt, ok := <-tq.metricsCh: + if !ok { + flush() + return + } + if _, exists := pending[evt.JobName]; !exists { + pending[evt.JobName] = make(map[string]float64) + } + pending[evt.JobName][evt.Metric] = evt.Value + case <-ticker.C: + flush() + } + } +} + +// writeMetrics writes metrics to Redis +func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error { + if len(metrics) == 0 { + return nil + } + key := JobMetricsPrefix + jobName + args := make([]any, 0, len(metrics)*2+2) + args = append(args, "timestamp", time.Now().Unix()) + for metric, value := range metrics { + args = append(args, metric, value) + } + return tq.client.HSet(context.Background(), key, args...).Err() +} diff --git a/internal/queue/queue_permissions_test.go b/internal/queue/queue_permissions_test.go new file mode 100644 index 0000000..b842016 --- /dev/null +++ b/internal/queue/queue_permissions_test.go @@ -0,0 +1,152 @@ +package queue + +import ( + "testing" + "time" +) + +func TestTask_UserFields(t *testing.T) { + task := &Task{ + UserID: "testuser", + Username: "testuser", + CreatedBy: "testuser", + } + + if task.UserID != "testuser" { + t.Errorf("Expected UserID to be 'testuser', got '%s'", task.UserID) + } + + if task.Username != "testuser" { + t.Errorf("Expected Username to be 'testuser', got '%s'", task.Username) + } 
+ + if task.CreatedBy != "testuser" { + t.Errorf("Expected CreatedBy to be 'testuser', got '%s'", task.CreatedBy) + } +} + +func TestTaskQueue_UserFiltering(t *testing.T) { + // Setup test Redis configuration + queueCfg := Config{ + RedisAddr: "localhost:6379", + RedisDB: 15, // Use dedicated test DB + } + + // Create task queue + taskQueue, err := NewTaskQueue(queueCfg) + if err != nil { + t.Skip("Redis not available for integration testing") + return + } + defer taskQueue.Close() + + // Clear test database + taskQueue.client.FlushDB(taskQueue.ctx) + + // Create test tasks with different users + tasks := []*Task{ + { + ID: "task1", + JobName: "user1_job1", + Status: "queued", + UserID: "user1", + CreatedBy: "user1", + CreatedAt: time.Now(), + }, + { + ID: "task2", + JobName: "user1_job2", + Status: "running", + UserID: "user1", + CreatedBy: "user1", + CreatedAt: time.Now(), + }, + { + ID: "task3", + JobName: "user2_job1", + Status: "queued", + UserID: "user2", + CreatedBy: "user2", + CreatedAt: time.Now(), + }, + { + ID: "task4", + JobName: "admin_job", + Status: "completed", + UserID: "admin", + CreatedBy: "admin", + CreatedAt: time.Now(), + }, + } + + // Add tasks to queue + for _, task := range tasks { + err := taskQueue.AddTask(task) + if err != nil { + t.Fatalf("Failed to add task %s: %v", task.ID, err) + } + } + + // Test GetAllTasks + allTasks, err := taskQueue.GetAllTasks() + if err != nil { + t.Fatalf("Failed to get all tasks: %v", err) + } + + if len(allTasks) != len(tasks) { + t.Errorf("Expected %d tasks, got %d", len(tasks), len(allTasks)) + } + + // Test user filtering logic + filterTasksForUser := func(tasks []*Task, userID string) []*Task { + var filtered []*Task + for _, task := range tasks { + if task.UserID == userID || task.CreatedBy == userID { + filtered = append(filtered, task) + } + } + return filtered + } + + // Test filtering for user1 (should get 2 tasks) + user1Tasks := filterTasksForUser(allTasks, "user1") + if len(user1Tasks) != 2 { + 
t.Errorf("Expected 2 tasks for user1, got %d", len(user1Tasks)) + } + + // Test filtering for user2 (should get 1 task) + user2Tasks := filterTasksForUser(allTasks, "user2") + if len(user2Tasks) != 1 { + t.Errorf("Expected 1 task for user2, got %d", len(user2Tasks)) + } + + // Test filtering for admin (should get 1 task) + adminTasks := filterTasksForUser(allTasks, "admin") + if len(adminTasks) != 1 { + t.Errorf("Expected 1 task for admin, got %d", len(adminTasks)) + } + + // Test GetTaskByName + task, err := taskQueue.GetTaskByName("user1_job1") + if err != nil { + t.Errorf("Failed to get task by name: %v", err) + } + if task == nil || task.UserID != "user1" { + t.Error("Got wrong task or nil task") + } + + // Test CancelTask + err = taskQueue.CancelTask("task1") + if err != nil { + t.Errorf("Failed to cancel task: %v", err) + } + + // Verify task was cancelled + cancelledTask, err := taskQueue.GetTask("task1") + if err != nil { + t.Errorf("Failed to get cancelled task: %v", err) + } + if cancelledTask.Status != "cancelled" { + t.Errorf("Expected status 'cancelled', got '%s'", cancelledTask.Status) + } +} diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go new file mode 100644 index 0000000..48cab03 --- /dev/null +++ b/internal/queue/queue_test.go @@ -0,0 +1,193 @@ +package queue + +import ( + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTaskQueue(t *testing.T) { + // Start miniredis + s, err := miniredis.Run() + if err != nil { + t.Fatalf("failed to start miniredis: %v", err) + } + defer s.Close() + + // Create TaskQueue + cfg := Config{ + RedisAddr: s.Addr(), + MetricsFlushInterval: 10 * time.Millisecond, // Fast flush for testing + } + tq, err := NewTaskQueue(cfg) + assert.NoError(t, err) + defer tq.Close() + + t.Run("AddTask", func(t *testing.T) { + task := &Task{ + ID: "task-1", + JobName: "job-1", + 
Status: "queued", + Priority: 10, + CreatedAt: time.Now(), + } + err = tq.AddTask(task) + assert.NoError(t, err) + + // Verify task is in Redis + // Check ZSET + score, err := s.ZScore(TaskQueueKey, "task-1") + assert.NoError(t, err) + assert.Equal(t, float64(10), score) + }) + + t.Run("GetNextTask", func(t *testing.T) { + // Add another task + task := &Task{ + ID: "task-2", + JobName: "job-2", + Status: "queued", + Priority: 20, // Higher priority + CreatedAt: time.Now(), + } + err = tq.AddTask(task) + assert.NoError(t, err) + + // Should get task-2 first due to higher priority + nextTask, err := tq.GetNextTask() + assert.NoError(t, err) + assert.NotNil(t, nextTask) + assert.Equal(t, "task-2", nextTask.ID) + + // Verify task is removed from ZSET + _, err = tq.client.ZScore(tq.ctx, TaskQueueKey, "task-2").Result() + assert.Equal(t, redis.Nil, err) + }) + + t.Run("GetNextTaskWithLease", func(t *testing.T) { + task := &Task{ + ID: "task-lease", + JobName: "job-lease", + Status: "queued", + Priority: 15, + CreatedAt: time.Now(), + } + err := tq.AddTask(task) + require.NoError(t, err) + + workerID := "worker-1" + leaseDuration := 1 * time.Minute + + leasedTask, err := tq.GetNextTaskWithLease(workerID, leaseDuration) + require.NoError(t, err) + require.NotNil(t, leasedTask) + assert.Equal(t, "task-lease", leasedTask.ID) + assert.Equal(t, workerID, leasedTask.LeasedBy) + assert.NotNil(t, leasedTask.LeaseExpiry) + assert.True(t, leasedTask.LeaseExpiry.After(time.Now())) + }) + + t.Run("RenewLease", func(t *testing.T) { + taskID := "task-lease" + workerID := "worker-1" + + // Get initial expiry + task, err := tq.GetTask(taskID) + require.NoError(t, err) + initialExpiry := task.LeaseExpiry + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Renew lease + err = tq.RenewLease(taskID, workerID, 1*time.Minute) + require.NoError(t, err) + + // Verify expiry updated + task, err = tq.GetTask(taskID) + require.NoError(t, err) + assert.True(t, 
task.LeaseExpiry.After(*initialExpiry)) + }) + + t.Run("ReleaseLease", func(t *testing.T) { + taskID := "task-lease" + workerID := "worker-1" + + err := tq.ReleaseLease(taskID, workerID) + require.NoError(t, err) + + task, err := tq.GetTask(taskID) + require.NoError(t, err) + assert.Nil(t, task.LeaseExpiry) + assert.Empty(t, task.LeasedBy) + }) + + t.Run("RetryTask", func(t *testing.T) { + task := &Task{ + ID: "task-retry", + JobName: "job-retry", + Status: "failed", + Priority: 10, + CreatedAt: time.Now(), + MaxRetries: 3, + RetryCount: 0, + Error: "some transient error", + } + + // Add task directly to verify retry logic + err := tq.AddTask(task) + require.NoError(t, err) + + // Simulate failure and retry + task.Error = "connection timeout" + err = tq.RetryTask(task) + require.NoError(t, err) + + // Verify task updated + updatedTask, err := tq.GetTask(task.ID) + require.NoError(t, err) + assert.Equal(t, 1, updatedTask.RetryCount) + assert.Equal(t, "queued", updatedTask.Status) + assert.Empty(t, updatedTask.Error) + assert.Equal(t, "connection timeout", updatedTask.LastError) + assert.NotNil(t, updatedTask.NextRetry) + }) + + t.Run("DLQ", func(t *testing.T) { + task := &Task{ + ID: "task-dlq", + JobName: "job-dlq", + Status: "failed", + Priority: 10, + CreatedAt: time.Now(), + MaxRetries: 1, + RetryCount: 1, // Already at max retries + Error: "fatal error", + } + + err := tq.AddTask(task) + require.NoError(t, err) + + // Retry should move to DLQ + err = tq.RetryTask(task) + require.NoError(t, err) + + // Verify removed from main queue + _, err = tq.client.ZScore(tq.ctx, TaskQueueKey, task.ID).Result() + assert.Equal(t, redis.Nil, err) + + // Verify in DLQ + dlqKey := "task:dlq:" + task.ID + exists := s.Exists(dlqKey) + assert.True(t, exists) + + // Verify DLQ content + val, err := s.Get(dlqKey) + require.NoError(t, err) + assert.Contains(t, val, "max retries exceeded") + }) +} diff --git a/internal/queue/task.go b/internal/queue/task.go new file mode 100644 index 
0000000..25347f1 --- /dev/null +++ b/internal/queue/task.go @@ -0,0 +1,47 @@ +package queue + +import ( + "time" + + "github.com/jfraeys/fetch_ml/internal/config" +) + +// Task represents an ML experiment task +type Task struct { + ID string `json:"id"` + JobName string `json:"job_name"` + Args string `json:"args"` + Status string `json:"status"` // queued, running, completed, failed + Priority int64 `json:"priority"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + EndedAt *time.Time `json:"ended_at,omitempty"` + WorkerID string `json:"worker_id,omitempty"` + Error string `json:"error,omitempty"` + Datasets []string `json:"datasets,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + + // User ownership and permissions + UserID string `json:"user_id"` // User who owns this task + Username string `json:"username"` // Username for display + CreatedBy string `json:"created_by"` // User who submitted the task + + // Lease management for task resilience + LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires + LeasedBy string `json:"leased_by,omitempty"` // Worker ID holding lease + + // Retry management + RetryCount int `json:"retry_count"` // Number of retry attempts made + MaxRetries int `json:"max_retries"` // Maximum retry limit (default 3) + LastError string `json:"last_error,omitempty"` // Last error encountered + NextRetry *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff) +} + +// Redis key constants +var ( + TaskQueueKey = config.RedisTaskQueueKey + TaskPrefix = config.RedisTaskPrefix + TaskStatusPrefix = config.RedisTaskStatusPrefix + WorkerHeartbeat = config.RedisWorkerHeartbeat + JobMetricsPrefix = config.RedisJobMetricsPrefix +) diff --git a/internal/storage/db.go b/internal/storage/db.go new file mode 100644 index 0000000..163d0d6 --- /dev/null +++ b/internal/storage/db.go @@ -0,0 +1,433 @@ +package storage + +import ( + 
"database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +// DBConfig selects the backend ("sqlite", "postgres"/"postgresql") and its connection parameters. +// Connection, when set, takes precedence over the individual Host/Port/... fields. +type DBConfig struct { + Type string + Connection string + Host string + Port int + Username string + Password string + Database string +} + +// DB wraps a database/sql handle together with the lower-cased backend type, +// which query builders use to pick ?-style vs $n-style placeholders. +type DB struct { + conn *sql.DB + dbType string +} + +// NewDB opens a SQLite or PostgreSQL connection according to config. +// For SQLite it enables foreign keys and WAL mode; a failure there closes +// the handle before returning so the connection is not leaked. +func NewDB(config DBConfig) (*DB, error) { + var conn *sql.DB + var err error + + switch strings.ToLower(config.Type) { + case "sqlite": + conn, err = sql.Open("sqlite3", config.Connection) + if err != nil { + return nil, fmt.Errorf("failed to open SQLite database: %w", err) + } + // Enable foreign keys; close the handle on failure so it is not leaked. + if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to enable foreign keys: %w", err) + } + // Enable WAL mode for better concurrency. + if _, err := conn.Exec("PRAGMA journal_mode = WAL"); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to enable WAL mode: %w", err) + } + case "postgres", "postgresql": + // "postgresql" is accepted as an alias for "postgres". + connStr := buildPostgresConnectionString(config) + conn, err = sql.Open("postgres", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) + } + default: + return nil, fmt.Errorf("unsupported database type: %s", config.Type) + } + + return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil +} + +// buildPostgresConnectionString assembles a libpq key/value connection string, +// preferring an explicit Connection string and falling back to per-field values +// with localhost/5432/fetch_ml defaults. +func buildPostgresConnectionString(config DBConfig) string { + if config.Connection != "" { + return config.Connection + } + + var connStr strings.Builder + connStr.WriteString("host=") + if config.Host != "" { + connStr.WriteString(config.Host) + } else { + connStr.WriteString("localhost") + } + + if config.Port > 0 { + connStr.WriteString(fmt.Sprintf(" port=%d", config.Port)) + } else { + connStr.WriteString(" 
port=5432") + } + + if config.Username != "" { + connStr.WriteString(fmt.Sprintf(" user=%s", config.Username)) + } + + if config.Password != "" { + connStr.WriteString(fmt.Sprintf(" password=%s", config.Password)) + } + + if config.Database != "" { + connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database)) + } else { + connStr.WriteString(" dbname=fetch_ml") + } + + connStr.WriteString(" sslmode=disable") + return connStr.String() +} + +// Legacy constructor for backward compatibility +func NewDBFromPath(dbPath string) (*DB, error) { + return NewDB(DBConfig{ + Type: "sqlite", + Connection: dbPath, + }) +} + +type Job struct { + ID string `json:"id"` + JobName string `json:"job_name"` + Args string `json:"args"` + Status string `json:"status"` + Priority int64 `json:"priority"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + EndedAt *time.Time `json:"ended_at,omitempty"` + WorkerID string `json:"worker_id,omitempty"` + Error string `json:"error,omitempty"` + Datasets []string `json:"datasets,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + UpdatedAt time.Time `json:"updated_at"` +} + +type Worker struct { + ID string `json:"id"` + Hostname string `json:"hostname"` + LastHeartbeat time.Time `json:"last_heartbeat"` + Status string `json:"status"` + CurrentJobs int `json:"current_jobs"` + MaxJobs int `json:"max_jobs"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +func (db *DB) Initialize(schema string) error { + if _, err := db.conn.Exec(schema); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } + return nil +} + +func (db *DB) Close() error { + return db.conn.Close() +} + +// Job operations +func (db *DB) CreateJob(job *Job) error { + datasetsJSON, _ := json.Marshal(job.Datasets) + metadataJSON, _ := json.Marshal(job.Metadata) + + var query string + if db.dbType == "sqlite" { + query = `INSERT INTO jobs (id, job_name, args, status, priority, 
datasets, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?)` + } else { + query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7)` + } + + _, err := db.conn.Exec(query, job.ID, job.JobName, job.Args, job.Status, + job.Priority, string(datasetsJSON), string(metadataJSON)) + if err != nil { + return fmt.Errorf("failed to create job: %w", err) + } + return nil +} + +func (db *DB) GetJob(id string) (*Job, error) { + var query string + if db.dbType == "sqlite" { + query = `SELECT id, job_name, args, status, priority, created_at, started_at, + ended_at, worker_id, error, datasets, metadata, updated_at + FROM jobs WHERE id = ?` + } else { + query = `SELECT id, job_name, args, status, priority, created_at, started_at, + ended_at, worker_id, error, datasets, metadata, updated_at + FROM jobs WHERE id = $1` + } + + var job Job + var datasetsJSON, metadataJSON string + var workerID sql.NullString + var errorMsg sql.NullString + + err := db.conn.QueryRow(query, id).Scan( + &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, + &job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID, + &errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt) + + if err != nil { + return nil, fmt.Errorf("failed to get job: %w", err) + } + + if workerID.Valid { + job.WorkerID = workerID.String + } + + if errorMsg.Valid { + job.Error = errorMsg.String + } + + json.Unmarshal([]byte(datasetsJSON), &job.Datasets) + json.Unmarshal([]byte(metadataJSON), &job.Metadata) + + return &job, nil +} + +func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error { + var query string + if db.dbType == "sqlite" { + query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?, + started_at = CASE WHEN ? = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END, + ended_at = CASE WHEN ? 
IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END + WHERE id = ?` + } else { + query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3, + started_at = CASE WHEN $4 = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END, + ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END + WHERE id = $6` + } + + _, err := db.conn.Exec(query, status, workerID, errorMsg, status, status, id) + if err != nil { + return fmt.Errorf("failed to update job status: %w", err) + } + return nil +} + +func (db *DB) ListJobs(status string, limit int) ([]*Job, error) { + var query string + if db.dbType == "sqlite" { + query = `SELECT id, job_name, args, status, priority, created_at, started_at, + ended_at, worker_id, error, datasets, metadata, updated_at + FROM jobs` + } else { + query = `SELECT id, job_name, args, status, priority, created_at, started_at, + ended_at, worker_id, error, datasets, metadata, updated_at + FROM jobs` + } + + var args []interface{} + if status != "" { + if db.dbType == "sqlite" { + query += " WHERE status = ?" + } else { + query += " WHERE status = $1" + } + args = append(args, status) + } + query += " ORDER BY created_at DESC" + if limit > 0 { + if db.dbType == "sqlite" { + query += " LIMIT ?" + } else { + query += fmt.Sprintf(" LIMIT $%d", len(args)+1) + } + args = append(args, limit) + } + + rows, err := db.conn.Query(query, args...) 
+ if err != nil { + return nil, fmt.Errorf("failed to list jobs: %w", err) + } + defer rows.Close() + + var jobs []*Job + for rows.Next() { + var job Job + var datasetsJSON, metadataJSON string + var workerID sql.NullString + var errorMsg sql.NullString + + err := rows.Scan(&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, + &job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID, + &errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("failed to scan job: %w", err) + } + + if workerID.Valid { + job.WorkerID = workerID.String + } + + if errorMsg.Valid { + job.Error = errorMsg.String + } + + json.Unmarshal([]byte(datasetsJSON), &job.Datasets) + json.Unmarshal([]byte(metadataJSON), &job.Metadata) + + jobs = append(jobs, &job) + } + + return jobs, nil +} + +// Worker operations +func (db *DB) RegisterWorker(worker *Worker) error { + metadataJSON, _ := json.Marshal(worker.Metadata) + + var query string + if db.dbType == "sqlite" { + query = `INSERT OR REPLACE INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) + VALUES (?, ?, ?, ?, ?, ?)` + } else { + query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (id) DO UPDATE SET + hostname = EXCLUDED.hostname, + status = EXCLUDED.status, + current_jobs = EXCLUDED.current_jobs, + max_jobs = EXCLUDED.max_jobs, + metadata = EXCLUDED.metadata` + } + + _, err := db.conn.Exec(query, worker.ID, worker.Hostname, worker.Status, + worker.CurrentJobs, worker.MaxJobs, string(metadataJSON)) + if err != nil { + return fmt.Errorf("failed to register worker: %w", err) + } + return nil +} + +func (db *DB) UpdateWorkerHeartbeat(workerID string) error { + var query string + if db.dbType == "sqlite" { + query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = ?` + } else { + query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = $1` + } + + _, err := 
db.conn.Exec(query, workerID) + if err != nil { + return fmt.Errorf("failed to update worker heartbeat: %w", err) + } + return nil +} + +func (db *DB) GetActiveWorkers() ([]*Worker, error) { + var query string + if db.dbType == "sqlite" { + query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata + FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')` + } else { + query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata + FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'` + } + + rows, err := db.conn.Query(query) + if err != nil { + return nil, fmt.Errorf("failed to get active workers: %w", err) + } + defer rows.Close() + + var workers []*Worker + for rows.Next() { + var worker Worker + var metadataJSON string + + err := rows.Scan(&worker.ID, &worker.Hostname, &worker.LastHeartbeat, + &worker.Status, &worker.CurrentJobs, &worker.MaxJobs, &metadataJSON) + if err != nil { + return nil, fmt.Errorf("failed to scan worker: %w", err) + } + + json.Unmarshal([]byte(metadataJSON), &worker.Metadata) + workers = append(workers, &worker) + } + + return workers, nil +} + +// Metrics operations +func (db *DB) RecordJobMetric(jobID, metricName, metricValue string) error { + var query string + if db.dbType == "sqlite" { + query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES (?, ?, ?)` + } else { + query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES ($1, $2, $3)` + } + + _, err := db.conn.Exec(query, jobID, metricName, metricValue) + if err != nil { + return fmt.Errorf("failed to record job metric: %w", err) + } + return nil +} + +func (db *DB) RecordSystemMetric(metricName, metricValue string) error { + var query string + if db.dbType == "sqlite" { + query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES (?, ?)` + } else { + query = `INSERT INTO system_metrics (metric_name, metric_value) 
VALUES ($1, $2)` + } + + _, err := db.conn.Exec(query, metricName, metricValue) + if err != nil { + return fmt.Errorf("failed to record system metric: %w", err) + } + return nil +} + +func (db *DB) GetJobMetrics(jobID string) (map[string]string, error) { + var query string + if db.dbType == "sqlite" { + query = `SELECT metric_name, metric_value FROM job_metrics + WHERE job_id = ? ORDER BY timestamp DESC` + } else { + query = `SELECT metric_name, metric_value FROM job_metrics + WHERE job_id = $1 ORDER BY timestamp DESC` + } + + rows, err := db.conn.Query(query, jobID) + if err != nil { + return nil, fmt.Errorf("failed to get job metrics: %w", err) + } + defer rows.Close() + + metrics := make(map[string]string) + for rows.Next() { + var name, value string + if err := rows.Scan(&name, &value); err != nil { + return nil, fmt.Errorf("failed to scan metric: %w", err) + } + metrics[name] = value + } + + return metrics, nil +} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go new file mode 100644 index 0000000..9fe5f57 --- /dev/null +++ b/internal/storage/db_test.go @@ -0,0 +1,212 @@ +package storage + +import ( + "os" + "testing" +) + +func TestDB(t *testing.T) { + // Use a temporary database + dbPath := t.TempDir() + "/test.db" + + // Initialize database + db, err := NewDBFromPath(dbPath) + if err != nil { + t.Fatalf("Failed to create database: %v", err) + } + defer db.Close() + + // Initialize schema + schema, err := os.ReadFile("schema.sql") + if err != nil { + t.Fatalf("Failed to read schema: %v", err) + } + + if err := db.Initialize(string(schema)); err != nil { + t.Fatalf("Failed to initialize schema: %v", err) + } + + // Test job creation + job := &Job{ + ID: "test-job-1", + JobName: "test_experiment", + Args: "--epochs 10 --lr 0.001", + Status: "pending", + Priority: 1, + Datasets: []string{"dataset1", "dataset2"}, + Metadata: map[string]string{"gpu": "true", "memory": "8GB"}, + } + + if err := db.CreateJob(job); err != nil { + 
t.Fatalf("Failed to create job: %v", err) + } + + // Verify job exists in database + var count int + err = db.conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE id = ?", "test-job-1").Scan(&count) + if err != nil { + t.Fatalf("Failed to verify job creation: %v", err) + } + if count != 1 { + t.Fatalf("Expected 1 job in database, got %d", count) + } + + // Test job retrieval + retrievedJob, err := db.GetJob("test-job-1") + if err != nil { + t.Fatalf("Failed to get job: %v", err) + } + + if retrievedJob.ID != job.ID { + t.Errorf("Expected job ID %s, got %s", job.ID, retrievedJob.ID) + } + + if retrievedJob.JobName != job.JobName { + t.Errorf("Expected job name %s, got %s", job.JobName, retrievedJob.JobName) + } + + if len(retrievedJob.Datasets) != 2 { + t.Errorf("Expected 2 datasets, got %d", len(retrievedJob.Datasets)) + } + + if retrievedJob.Metadata["gpu"] != "true" { + t.Errorf("Expected gpu=true, got %s", retrievedJob.Metadata["gpu"]) + } + + // Test job status update + if err := db.UpdateJobStatus("test-job-1", "running", "worker-1", ""); err != nil { + t.Fatalf("Failed to update job status: %v", err) + } + + // Verify status update + updatedJob, err := db.GetJob("test-job-1") + if err != nil { + t.Fatalf("Failed to get updated job: %v", err) + } + + if updatedJob.Status != "running" { + t.Errorf("Expected status running, got %s", updatedJob.Status) + } + + if updatedJob.WorkerID != "worker-1" { + t.Errorf("Expected worker ID worker-1, got %s", updatedJob.WorkerID) + } + + if updatedJob.StartedAt == nil { + t.Error("Expected StartedAt to be set") + } + + // Test worker registration + worker := &Worker{ + ID: "worker-1", + Hostname: "test-host", + Status: "active", + CurrentJobs: 0, + MaxJobs: 2, + Metadata: map[string]string{"cpu": "8", "memory": "16GB"}, + } + + if err := db.RegisterWorker(worker); err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + // Test worker heartbeat + if err := db.UpdateWorkerHeartbeat("worker-1"); err != nil { + 
t.Fatalf("Failed to update worker heartbeat: %v", err) + } + + // Test metrics recording + if err := db.RecordJobMetric("test-job-1", "accuracy", "0.95"); err != nil { + t.Fatalf("Failed to record job metric: %v", err) + } + + if err := db.RecordSystemMetric("cpu_usage", "75"); err != nil { + t.Fatalf("Failed to record system metric: %v", err) + } + + // Test metrics retrieval + metrics, err := db.GetJobMetrics("test-job-1") + if err != nil { + t.Fatalf("Failed to get job metrics: %v", err) + } + + if metrics["accuracy"] != "0.95" { + t.Errorf("Expected accuracy 0.95, got %s", metrics["accuracy"]) + } + + // Test job listing + jobs, err := db.ListJobs("", 10) + if err != nil { + t.Fatalf("Failed to list jobs: %v", err) + } + + t.Logf("Found %d jobs", len(jobs)) + for i, job := range jobs { + t.Logf("Job %d: ID=%s, Status=%s", i, job.ID, job.Status) + } + + if len(jobs) != 1 { + t.Errorf("Expected 1 job, got %d", len(jobs)) + return + } + + if jobs[0].ID != "test-job-1" { + t.Errorf("Expected job ID test-job-1, got %s", jobs[0].ID) + return + } + + // Test active workers + workers, err := db.GetActiveWorkers() + if err != nil { + t.Fatalf("Failed to get active workers: %v", err) + } + + if len(workers) != 1 { + t.Errorf("Expected 1 active worker, got %d", len(workers)) + } + + if workers[0].ID != "worker-1" { + t.Errorf("Expected worker ID worker-1, got %s", workers[0].ID) + } +} + +func TestDBConstraints(t *testing.T) { + dbPath := t.TempDir() + "/test_constraints.db" + + db, err := NewDBFromPath(dbPath) + if err != nil { + t.Fatalf("Failed to create database: %v", err) + } + defer db.Close() + + schema, err := os.ReadFile("schema.sql") + if err != nil { + t.Fatalf("Failed to read schema: %v", err) + } + + if err := db.Initialize(string(schema)); err != nil { + t.Fatalf("Failed to initialize schema: %v", err) + } + + // Test duplicate job ID + job := &Job{ + ID: "duplicate-test", + JobName: "test", + Status: "pending", + } + + if err := db.CreateJob(job); err != 
nil { + t.Fatalf("Failed to create first job: %v", err) + } + + // Should fail on duplicate + if err := db.CreateJob(job); err == nil { + t.Error("Expected error when creating duplicate job") + } + + // Test getting non-existent job + _, err = db.GetJob("non-existent") + if err == nil { + t.Error("Expected error when getting non-existent job") + } +} diff --git a/internal/storage/migrate.go b/internal/storage/migrate.go new file mode 100644 index 0000000..a23771d --- /dev/null +++ b/internal/storage/migrate.go @@ -0,0 +1,257 @@ +package storage + +import ( + "encoding/json" + "fmt" + "log" + "strings" + "time" + + "context" + + "github.com/go-redis/redis/v8" +) + +// Migrator handles migration from Redis to SQLite +type Migrator struct { + redisClient *redis.Client + sqliteDB *DB +} + +func NewMigrator(redisAddr, sqlitePath string) (*Migrator, error) { + // Connect to Redis + rdb := redis.NewClient(&redis.Options{ + Addr: redisAddr, + }) + + // Connect to SQLite + db, err := NewDBFromPath(sqlitePath) + if err != nil { + return nil, fmt.Errorf("failed to connect to SQLite: %w", err) + } + + return &Migrator{ + redisClient: rdb, + sqliteDB: db, + }, nil +} + +func (m *Migrator) Close() error { + if err := m.sqliteDB.Close(); err != nil { + return err + } + return m.redisClient.Close() +} + +// MigrateJobs migrates job data from Redis to SQLite +func (m *Migrator) MigrateJobs(ctx context.Context) error { + log.Println("Starting job migration from Redis to SQLite...") + + // Get all job keys from Redis + jobKeys, err := m.redisClient.Keys(ctx, "job:*").Result() + if err != nil { + return fmt.Errorf("failed to get job keys from Redis: %w", err) + } + + for _, jobKey := range jobKeys { + jobData, err := m.redisClient.HGetAll(ctx, jobKey).Result() + if err != nil { + log.Printf("Failed to get job data for %s: %v", jobKey, err) + continue + } + + // Parse job data + job := &Job{ + ID: jobKey[4:], // Remove "job:" prefix + JobName: jobData["job_name"], + Args: 
jobData["args"], + Status: jobData["status"], + Priority: parsePriority(jobData["priority"]), + WorkerID: jobData["worker_id"], + Error: jobData["error"], + } + + // Parse timestamps + if createdAtStr := jobData["created_at"]; createdAtStr != "" { + if ts, err := time.Parse(time.RFC3339, createdAtStr); err == nil { + job.CreatedAt = ts + } + } + + if startedAtStr := jobData["started_at"]; startedAtStr != "" { + if ts, err := time.Parse(time.RFC3339, startedAtStr); err == nil { + job.StartedAt = &ts + } + } + + if endedAtStr := jobData["ended_at"]; endedAtStr != "" { + if ts, err := time.Parse(time.RFC3339, endedAtStr); err == nil { + job.EndedAt = &ts + } + } + + // Parse JSON fields + if datasetsStr := jobData["datasets"]; datasetsStr != "" { + json.Unmarshal([]byte(datasetsStr), &job.Datasets) + } + + if metadataStr := jobData["metadata"]; metadataStr != "" { + json.Unmarshal([]byte(metadataStr), &job.Metadata) + } + + // Insert into SQLite + if err := m.sqliteDB.CreateJob(job); err != nil { + log.Printf("Failed to create job %s in SQLite: %v", job.ID, err) + continue + } + + log.Printf("Migrated job: %s", job.ID) + } + + log.Printf("Migrated %d jobs from Redis to SQLite", len(jobKeys)) + return nil +} + +// MigrateMetrics migrates metrics from Redis to SQLite +func (m *Migrator) MigrateMetrics(ctx context.Context) error { + log.Println("Starting metrics migration from Redis to SQLite...") + + // Get all metric keys from Redis + metricKeys, err := m.redisClient.Keys(ctx, "metrics:*").Result() + if err != nil { + return fmt.Errorf("failed to get metric keys from Redis: %w", err) + } + + for _, metricKey := range metricKeys { + metricData, err := m.redisClient.HGetAll(ctx, metricKey).Result() + if err != nil { + log.Printf("Failed to get metric data for %s: %v", metricKey, err) + continue + } + + // Parse metric key format: metrics:job:job_id or metrics:system + parts := parseMetricKey(metricKey) + if len(parts) < 2 { + continue + } + + metricType := parts[1] // 
"job" or "system" + + for name, value := range metricData { + if metricType == "job" && len(parts) == 3 { + // Job metric + jobID := parts[2] + if err := m.sqliteDB.RecordJobMetric(jobID, name, value); err != nil { + log.Printf("Failed to record job metric %s for job %s: %v", name, jobID, err) + } + } else if metricType == "system" { + // System metric + if err := m.sqliteDB.RecordSystemMetric(name, value); err != nil { + log.Printf("Failed to record system metric %s: %v", name, err) + } + } + } + } + + log.Printf("Migrated %d metric keys from Redis to SQLite", len(metricKeys)) + return nil +} + +// MigrateWorkers migrates worker data from Redis to SQLite. +// Per-worker failures are logged and skipped so one bad record does not +// abort the whole migration. +func (m *Migrator) MigrateWorkers(ctx context.Context) error { + log.Println("Starting worker migration from Redis to SQLite...") + + // Get all worker keys from Redis + workerKeys, err := m.redisClient.Keys(ctx, "worker:*").Result() + if err != nil { + return fmt.Errorf("failed to get worker keys from Redis: %w", err) + } + + for _, workerKey := range workerKeys { + workerData, err := m.redisClient.HGetAll(ctx, workerKey).Result() + if err != nil { + log.Printf("Failed to get worker data for %s: %v", workerKey, err) + continue + } + + worker := &Worker{ + ID: workerKey[len("worker:"):], // Remove "worker:" prefix (the prefix is 7 bytes; the previous [8:] also dropped the ID's first byte) + Hostname: workerData["hostname"], + Status: workerData["status"], + CurrentJobs: parseInt(workerData["current_jobs"]), + MaxJobs: parseInt(workerData["max_jobs"]), + } + + // Parse heartbeat + if heartbeatStr := workerData["last_heartbeat"]; heartbeatStr != "" { + if ts, err := time.Parse(time.RFC3339, heartbeatStr); err == nil { + worker.LastHeartbeat = ts + } + } + + // Parse metadata + if metadataStr := workerData["metadata"]; metadataStr != "" { + json.Unmarshal([]byte(metadataStr), &worker.Metadata) + } + + // Insert into SQLite + if err := m.sqliteDB.RegisterWorker(worker); err != nil { + log.Printf("Failed to register worker %s in SQLite: %v", worker.ID, err) + continue + } + + log.Printf("Migrated worker: 
%s", worker.ID) + } + + log.Printf("Migrated %d workers from Redis to SQLite", len(workerKeys)) + return nil +} + +// MigrateAll performs complete migration from Redis to SQLite +func (m *Migrator) MigrateAll(ctx context.Context) error { + log.Println("Starting complete migration from Redis to SQLite...") + + // Test connections + if err := m.redisClient.Ping(ctx).Err(); err != nil { + return fmt.Errorf("failed to connect to Redis: %w", err) + } + + // Run migrations in order + if err := m.MigrateJobs(ctx); err != nil { + return fmt.Errorf("job migration failed: %w", err) + } + + if err := m.MigrateWorkers(ctx); err != nil { + return fmt.Errorf("worker migration failed: %w", err) + } + + if err := m.MigrateMetrics(ctx); err != nil { + return fmt.Errorf("metrics migration failed: %w", err) + } + + log.Println("Migration completed successfully!") + return nil +} + +// Helper functions + +// parsePriority parses a decimal priority string; empty or malformed input +// yields 0 (previously this was a stub that zeroed every migrated priority). +func parsePriority(s string) int64 { + if s == "" { + return 0 + } + var n int64 + if _, err := fmt.Sscanf(s, "%d", &n); err != nil { + return 0 + } + return n +} + +// parseInt parses a decimal integer string; empty or malformed input yields 0 +// (previously this was a stub that zeroed current_jobs/max_jobs on migration). +func parseInt(s string) int { + if s == "" { + return 0 + } + var n int + if _, err := fmt.Sscanf(s, "%d", &n); err != nil { + return 0 + } + return n +} + +// parseMetricKey splits a Redis metric key ("metrics:job:<id>" or +// "metrics:system") into its colon-separated parts. +func parseMetricKey(key string) []string { + // Simple split - adjust based on your Redis key format + parts := strings.Split(key, ":") + return parts +} diff --git a/internal/storage/schema.sql b/internal/storage/schema.sql new file mode 100644 index 0000000..ce415c4 --- /dev/null +++ b/internal/storage/schema.sql @@ -0,0 +1,61 @@ +-- SQLite schema for Fetch ML job persistence +-- Complements Redis for task queuing + +CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + job_name TEXT NOT NULL, + args TEXT, + status TEXT NOT NULL DEFAULT 'pending', + priority INTEGER DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + started_at DATETIME, + ended_at DATETIME, + worker_id TEXT, + error TEXT, + datasets TEXT, -- JSON array + metadata TEXT, -- JSON object + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE 
TABLE IF NOT EXISTS job_metrics ( + job_id TEXT, + metric_name TEXT, + metric_value TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (job_id, metric_name, timestamp), + FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS workers ( + id TEXT PRIMARY KEY, + hostname TEXT, + last_heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP, + status TEXT DEFAULT 'active', + current_jobs INTEGER DEFAULT 0, + max_jobs INTEGER DEFAULT 1, + metadata TEXT -- JSON object +); + +CREATE TABLE IF NOT EXISTS system_metrics ( + metric_name TEXT, + metric_value TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (metric_name, timestamp) +); + +-- Indexes for performance +CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); +CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at); +CREATE INDEX IF NOT EXISTS idx_jobs_worker_id ON jobs(worker_id); +CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id); +CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp); +CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat); +CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp); + +-- Triggers to update timestamps +CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp + AFTER UPDATE ON jobs + FOR EACH ROW + BEGIN + UPDATE jobs SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; + END; diff --git a/internal/storage/schema_postgres.sql b/internal/storage/schema_postgres.sql new file mode 100644 index 0000000..713ba40 --- /dev/null +++ b/internal/storage/schema_postgres.sql @@ -0,0 +1,68 @@ +-- PostgreSQL schema for Fetch ML job persistence +-- Complements Redis for task queuing + +CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + job_name TEXT NOT NULL, + args TEXT, + status TEXT NOT NULL DEFAULT 'pending', + priority INTEGER DEFAULT 0, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + started_at 
TIMESTAMP WITH TIME ZONE, + ended_at TIMESTAMP WITH TIME ZONE, + worker_id TEXT, + error TEXT, + datasets TEXT, -- JSON array + metadata TEXT, -- JSON object + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS job_metrics ( + job_id TEXT, + metric_name TEXT, + metric_value TEXT, + timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (job_id, metric_name, timestamp), + FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS workers ( + id TEXT PRIMARY KEY, + hostname TEXT, + last_heartbeat TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + status TEXT DEFAULT 'active', + current_jobs INTEGER DEFAULT 0, + max_jobs INTEGER DEFAULT 1, + metadata TEXT -- JSON object +); + +CREATE TABLE IF NOT EXISTS system_metrics ( + metric_name TEXT, + metric_value TEXT, + timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (metric_name, timestamp) +); + +-- Indexes for performance +CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); +CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at); +CREATE INDEX IF NOT EXISTS idx_jobs_worker_id ON jobs(worker_id); +CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id); +CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp); +CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat); +CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp); + +-- Function to update updated_at timestamp +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ language 'plpgsql'; + +-- Trigger to update timestamps +CREATE TRIGGER update_jobs_timestamp + BEFORE UPDATE ON jobs + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go new file mode 100644 
index 0000000..a3a2888 --- /dev/null +++ b/internal/telemetry/telemetry.go @@ -0,0 +1,77 @@ +package telemetry + +import ( + "bufio" + "os" + "strconv" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +type IOStats struct { + ReadBytes uint64 + WriteBytes uint64 +} + +func ReadProcessIO() (IOStats, error) { + f, err := os.Open("/proc/self/io") + if err != nil { + return IOStats{}, err + } + defer f.Close() + + var stats IOStats + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "read_bytes:") { + stats.ReadBytes = parseUintField(line) + } + if strings.HasPrefix(line, "write_bytes:") { + stats.WriteBytes = parseUintField(line) + } + } + if err := scanner.Err(); err != nil { + return IOStats{}, err + } + return stats, nil +} + +func DiffIO(before, after IOStats) IOStats { + var delta IOStats + if after.ReadBytes >= before.ReadBytes { + delta.ReadBytes = after.ReadBytes - before.ReadBytes + } + if after.WriteBytes >= before.WriteBytes { + delta.WriteBytes = after.WriteBytes - before.WriteBytes + } + return delta +} + +func parseUintField(line string) uint64 { + parts := strings.Split(line, ":") + if len(parts) != 2 { + return 0 + } + value, err := strconv.ParseUint(strings.TrimSpace(parts[1]), 10, 64) + if err != nil { + return 0 + } + return value +} + +func ExecWithMetrics(logger *logging.Logger, description string, threshold time.Duration, fn func() (string, error)) (string, error) { + start := time.Now() + out, err := fn() + duration := time.Since(start) + if duration > threshold { + fields := []any{"latency_ms", duration.Milliseconds(), "command", description} + if err != nil { + fields = append(fields, "error", err) + } + logger.Debug("ssh exec", fields...) + } + return out, err +}