feat: implement Go backend with comprehensive API and internal packages
- Add API server with WebSocket support and REST endpoints - Implement authentication system with API keys and permissions - Add task queue system with Redis backend and error handling - Include storage layer with database migrations and schemas - Add comprehensive logging, metrics, and telemetry - Implement security middleware and network utilities - Add experiment management and container orchestration - Include configuration management with smart defaults
This commit is contained in:
parent
c5049a2fdf
commit
803677be57
62 changed files with 13354 additions and 0 deletions
32
cmd/api-server/README.md
Normal file
32
cmd/api-server/README.md
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# API Server
|
||||
|
||||
WebSocket API server for the ML CLI tool...
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
./bin/api-server --config configs/config-dev.yaml --listen :9100
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
- `GET /health` - Health check
|
||||
- `WS /ws` - WebSocket endpoint for CLI communication
|
||||
|
||||
## Binary Protocol
|
||||
|
||||
See [CLI README](../../cli/README.md#websocket-protocol) for protocol details.
|
||||
|
||||
## Configuration
|
||||
|
||||
Uses the same configuration file as the worker. Experiment base path is read from `base_path` configuration key.
|
||||
|
||||
## Example
|
||||
|
||||
```bash
|
||||
# Start API server
|
||||
./bin/api-server --listen :9100
|
||||
|
||||
# In another terminal, test with CLI
|
||||
./cli/zig-out/bin/ml status
|
||||
```
|
||||
363
cmd/api-server/main.go
Normal file
363
cmd/api-server/main.go
Normal file
|
|
@ -0,0 +1,363 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/middleware"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config is the top-level API-server configuration, deserialized from the
// same YAML file the worker uses (see LoadConfig).
type Config struct {
	BasePath string          `yaml:"base_path"` // experiment storage root; defaults to /tmp/ml-experiments in main
	Auth     auth.AuthConfig `yaml:"auth"`
	Server   ServerConfig    `yaml:"server"`
	Security SecurityConfig  `yaml:"security"`
	Redis    RedisConfig     `yaml:"redis"`
	Database DatabaseConfig  `yaml:"database"`
	Logging  logging.Config  `yaml:"logging"`
}

// RedisConfig describes how to reach Redis. If URL is set it takes
// precedence over Addr (see main's queue setup).
type RedisConfig struct {
	Addr     string `yaml:"addr"`
	Password string `yaml:"password"`
	DB       int    `yaml:"db"`
	URL      string `yaml:"url"`
}

// DatabaseConfig selects and parameterizes the storage backend.
// Type is expected to be "sqlite", "postgres", or "postgresql" (main rejects
// anything else). Connection is used for SQLite-style DSN/path setups;
// Host/Port/Username/Password/Database for server databases.
type DatabaseConfig struct {
	Type       string `yaml:"type"`
	Connection string `yaml:"connection"`
	Host       string `yaml:"host"`
	Port       int    `yaml:"port"`
	Username   string `yaml:"username"`
	Password   string `yaml:"password"`
	Database   string `yaml:"database"`
}

// SecurityConfig groups HTTP-facing protections applied via middleware.
type SecurityConfig struct {
	RateLimit     RateLimitConfig `yaml:"rate_limit"`
	IPWhitelist   []string        `yaml:"ip_whitelist"`
	FailedLockout LockoutConfig   `yaml:"failed_login_lockout"`
}

// RateLimitConfig configures request rate limiting.
type RateLimitConfig struct {
	Enabled           bool `yaml:"enabled"`
	RequestsPerMinute int  `yaml:"requests_per_minute"`
	BurstSize         int  `yaml:"burst_size"`
}

// LockoutConfig configures lockout after repeated failed logins.
// LockoutDuration is a string (presumably a time.Duration-parsable value —
// TODO confirm against the consumer).
type LockoutConfig struct {
	Enabled         bool   `yaml:"enabled"`
	MaxAttempts     int    `yaml:"max_attempts"`
	LockoutDuration string `yaml:"lockout_duration"`
}

// ServerConfig holds the listen address and TLS settings for the HTTP server.
type ServerConfig struct {
	Address string    `yaml:"address"`
	TLS     TLSConfig `yaml:"tls"`
}

// TLSConfig enables HTTPS with the given certificate/key pair.
type TLSConfig struct {
	Enabled  bool   `yaml:"enabled"`
	CertFile string `yaml:"cert_file"`
	KeyFile  string `yaml:"key_file"`
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// main wires together config, logging, the experiment manager, the task
// queue, optional database storage, and the HTTP/WebSocket server, then
// blocks until SIGINT/SIGTERM and shuts down gracefully.
func main() {
	// Parse flags.
	configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
	apiKey := flag.String("api-key", "", "API key for authentication")
	flag.Parse()

	// Load config (ResolveConfigPath presumably searches standard locations —
	// TODO confirm its semantics).
	resolvedConfig, err := config.ResolveConfigPath(*configFile)
	if err != nil {
		log.Fatalf("Failed to resolve config: %v", err)
	}

	cfg, err := LoadConfig(resolvedConfig)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Ensure log directory exists before the logger tries to open the file.
	if cfg.Logging.File != "" {
		logDir := filepath.Dir(cfg.Logging.File)
		log.Printf("Creating log directory: %s", logDir)
		if err := os.MkdirAll(logDir, 0755); err != nil {
			log.Fatalf("Failed to create log directory: %v", err)
		}
	}

	// Setup structured logging with a trace context for this process.
	logger := logging.NewLoggerFromConfig(cfg.Logging)
	ctx := logging.EnsureTrace(context.Background())
	logger = logger.Component(ctx, "api-server")

	// Setup experiment manager rooted at base_path (world-readable /tmp
	// fallback for local development).
	basePath := cfg.BasePath
	if basePath == "" {
		basePath = "/tmp/ml-experiments"
	}
	expManager := experiment.NewManager(basePath)
	log.Printf("Initializing experiment manager with base_path: %s", basePath)
	if err := expManager.Initialize(); err != nil {
		logger.Fatal("failed to initialize experiment manager", "error", err)
	}
	logger.Info("experiment manager initialized", "base_path", basePath)

	// Auth config is passed to the WebSocket handler only when enabled;
	// nil means "no auth".
	var authCfg *auth.AuthConfig
	if cfg.Auth.Enabled {
		authCfg = &cfg.Auth
		logger.Info("authentication enabled")
	}

	// Setup HTTP server with security middleware.
	mux := http.NewServeMux()

	// Convert API keys from map to slice for the security middleware.
	apiKeys := make([]string, 0, len(cfg.Auth.APIKeys))
	for username := range cfg.Auth.APIKeys {
		// For now, use username as the key (in production, this should be the
		// actual API key) — NOTE(review): this passes map KEYS, not key values;
		// confirm intended contract with NewSecurityMiddleware.
		apiKeys = append(apiKeys, string(username))
	}

	// Create security middleware (JWT secret comes from the environment).
	sec := middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"))

	// Setup TaskQueue. Redis URL (if set) wins over addr; addr falls back to
	// the package default.
	queueCfg := queue.Config{
		RedisAddr:     cfg.Redis.Addr,
		RedisPassword: cfg.Redis.Password,
		RedisDB:       cfg.Redis.DB,
	}
	if queueCfg.RedisAddr == "" {
		queueCfg.RedisAddr = config.DefaultRedisAddr
	}
	// Support URL format for Redis.
	if cfg.Redis.URL != "" {
		queueCfg.RedisAddr = cfg.Redis.URL
	}

	taskQueue, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		logger.Error("failed to initialize task queue", "error", err)
		// We continue without queue, but queue operations will fail.
	} else {
		logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
		// NOTE(review): this defer (and the db one below) will NOT run if the
		// server goroutine calls os.Exit(1) — confirm that is acceptable.
		defer func() {
			logger.Info("stopping task queue...")
			if err := taskQueue.Close(); err != nil {
				logger.Error("failed to stop task queue", "error", err)
			} else {
				logger.Info("task queue stopped")
			}
		}()
	}

	// Setup database if configured. db stays nil on any failure so the
	// /db-status endpoint can report "disconnected".
	var db *storage.DB
	if cfg.Database.Type != "" {
		dbConfig := storage.DBConfig{
			Type:       cfg.Database.Type,
			Connection: cfg.Database.Connection,
			Host:       cfg.Database.Host,
			Port:       cfg.Database.Port,
			Username:   cfg.Database.Username,
			Password:   cfg.Database.Password,
			Database:   cfg.Database.Database,
		}

		db, err = storage.NewDB(dbConfig)
		if err != nil {
			logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err)
		} else {
			// Load the schema file matching the backend. Paths are relative
			// to the working directory — NOTE(review): breaks if the binary
			// is not launched from the repo root; confirm deployment layout.
			var schemaPath string
			if cfg.Database.Type == "sqlite" {
				schemaPath = "internal/storage/schema.sql"
			} else if cfg.Database.Type == "postgres" || cfg.Database.Type == "postgresql" {
				schemaPath = "internal/storage/schema_postgres.sql"
			} else {
				logger.Error("unsupported database type", "type", cfg.Database.Type)
				db.Close()
				db = nil
			}

			if db != nil && schemaPath != "" {
				schema, err := os.ReadFile(schemaPath)
				if err != nil {
					logger.Error("failed to read database schema file", "path", schemaPath, "error", err)
					db.Close()
					db = nil
				} else {
					if err := db.Initialize(string(schema)); err != nil {
						logger.Error("failed to initialize database schema", "error", err)
						db.Close()
						db = nil
					} else {
						logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection)
						defer func() {
							logger.Info("closing database connection...")
							if err := db.Close(); err != nil {
								logger.Error("failed to close database", "error", err)
							} else {
								logger.Info("database connection closed")
							}
						}()
					}
				}
			}
		}
	}

	// Setup WebSocket handler with authentication.
	wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)

	// WebSocket endpoint - no middleware to avoid hijacking issues.
	mux.Handle("/ws", wsHandler)
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintf(w, "OK\n")
	})

	// Database status endpoint: reports connectivity by issuing a trivial
	// write through RecordSystemMetric.
	mux.HandleFunc("/db-status", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if db != nil {
			// Test database connection with a simple query.
			var result struct {
				Status  string `json:"status"`
				Type    string `json:"type"`
				Path    string `json:"path"`
				Message string `json:"message"`
			}
			result.Status = "connected"
			// NOTE(review): Type is hard-coded "sqlite" even when the
			// configured backend is postgres — likely should be
			// cfg.Database.Type.
			result.Type = "sqlite"
			result.Path = cfg.Database.Connection
			result.Message = "SQLite database is operational"

			// Test a simple query to verify connectivity.
			if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
				result.Status = "error"
				result.Message = fmt.Sprintf("Database query failed: %v", err)
			}

			jsonBytes, _ := json.Marshal(result) // marshal of a plain string struct cannot fail
			w.Write(jsonBytes)
		} else {
			w.WriteHeader(http.StatusServiceUnavailable)
			fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
		}
	})

	// Apply security middleware to all routes except WebSocket.
	// Create separate handlers for WebSocket vs other routes.
	var finalHandler http.Handler = mux

	// Wrap non-websocket routes with security middleware.
	// NOTE(review): the middleware chain is re-constructed on EVERY request;
	// building it once outside this closure would be cheaper and equivalent
	// if the constructors are stateless — confirm before changing.
	finalHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/ws" {
			mux.ServeHTTP(w, r)
		} else {
			// Apply middleware chain for non-WebSocket routes.
			handler := sec.RateLimit(mux)
			handler = middleware.SecurityHeaders(handler)
			handler = middleware.CORS(handler)
			handler = middleware.RequestTimeout(30 * time.Second)(handler)

			// Apply audit logger and IP whitelist only to non-WebSocket routes.
			handler = middleware.AuditLogger(handler)
			if len(cfg.Security.IPWhitelist) > 0 {
				handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler)
			}

			handler.ServeHTTP(w, r)
		}
	})

	var handler http.Handler = finalHandler

	server := &http.Server{
		Addr:         cfg.Server.Address,
		Handler:      handler,
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	if !cfg.Server.TLS.Enabled {
		logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
	}

	// Start server in goroutine; the main goroutine blocks on signals below.
	go func() {
		// Setup TLS if configured.
		if cfg.Server.TLS.Enabled {
			logger.Info("starting HTTPS server", "address", cfg.Server.Address)
			if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed {
				logger.Error("HTTPS server failed", "error", err)
			}
		} else {
			logger.Info("starting HTTP server", "address", cfg.Server.Address)
			if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				logger.Error("HTTP server failed", "error", err)
			}
		}
		// NOTE(review): os.Exit skips all deferred cleanup (task queue, db)
		// and also fires after a normal Shutdown — review intended behavior.
		os.Exit(1)
	}()

	// Setup graceful shutdown on SIGINT/SIGTERM.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	sig := <-sigChan
	logger.Info("received shutdown signal", "signal", sig)

	// Give in-flight requests up to 5s to complete.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	logger.Info("shutting down http server...")
	if err := server.Shutdown(ctx); err != nil {
		logger.Error("server shutdown error", "error", err)
	} else {
		logger.Info("http server shutdown complete")
	}

	logger.Info("api server stopped")

	// NOTE(review): expManager IS used (NewWSHandler above), so this blank
	// assignment is dead code; apiKey is genuinely unused so far.
	_ = expManager // Use expManager to avoid unused warning
	_ = apiKey     // Will be used for auth later
}
|
||||
116
cmd/configlint/main.go
Normal file
116
cmd/configlint/main.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/xeipuuv/gojsonschema"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
schemaPath string
|
||||
failFast bool
|
||||
)
|
||||
|
||||
flag.StringVar(&schemaPath, "schema", "configs/schema.yaml", "Path to JSON schema in YAML format")
|
||||
flag.BoolVar(&failFast, "fail-fast", false, "Stop on first error")
|
||||
flag.Parse()
|
||||
|
||||
if flag.NArg() == 0 {
|
||||
log.Fatalf("usage: configlint [--schema path] [--fail-fast] <config files...>")
|
||||
}
|
||||
|
||||
schemaLoader, err := loadSchema(schemaPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load schema: %v", err)
|
||||
}
|
||||
|
||||
var hadError bool
|
||||
for _, configPath := range flag.Args() {
|
||||
if err := validateConfig(schemaLoader, configPath); err != nil {
|
||||
hadError = true
|
||||
fmt.Fprintf(os.Stderr, "configlint: %s: %v\n", configPath, err)
|
||||
if failFast {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hadError {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println("All configuration files are valid.")
|
||||
}
|
||||
|
||||
func loadSchema(schemaPath string) (gojsonschema.JSONLoader, error) {
|
||||
data, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var schemaYAML interface{}
|
||||
if err := yaml.Unmarshal(data, &schemaYAML); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schemaJSON, err := json.Marshal(schemaYAML)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "fetchml-schema-*.json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
if _, err := tmpFile.Write(schemaJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gojsonschema.NewReferenceLoader("file://" + filepath.ToSlash(tmpFile.Name())), nil
|
||||
}
|
||||
|
||||
func validateConfig(schemaLoader gojsonschema.JSONLoader, configPath string) error {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var configYAML interface{}
|
||||
if err := yaml.Unmarshal(data, &configYAML); err != nil {
|
||||
return fmt.Errorf("failed to parse YAML: %w", err)
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(configYAML)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := gojsonschema.Validate(schemaLoader, gojsonschema.NewBytesLoader(configJSON))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result.Valid() {
|
||||
fmt.Printf("%s: valid\n", configPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for _, issue := range result.Errors() {
|
||||
builder.WriteString("- ")
|
||||
builder.WriteString(issue.String())
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
|
||||
return fmt.Errorf("validation failed:\n%s", builder.String())
|
||||
}
|
||||
132
cmd/data_manager/data_manager_config.go
Normal file
132
cmd/data_manager/data_manager_config.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
// DataConfig holds the configuration for the data manager
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// DataConfig holds the configuration for the data manager: SSH endpoints for
// the ML server and NAS, Redis settings, auth, cleanup thresholds, and
// Podman container settings. Loaded from YAML by LoadDataConfig, which also
// applies defaults.
type DataConfig struct {
	// ML Server (where training runs)
	MLHost    string `yaml:"ml_host"` // empty means local mode
	MLUser    string `yaml:"ml_user"`
	MLSSHKey  string `yaml:"ml_ssh_key"`
	MLPort    int    `yaml:"ml_port"`
	MLDataDir string `yaml:"ml_data_dir"` // e.g., /data/active

	// NAS (where datasets are stored)
	NASHost    string `yaml:"nas_host"`
	NASUser    string `yaml:"nas_user"`
	NASSSHKey  string `yaml:"nas_ssh_key"`
	NASPort    int    `yaml:"nas_port"`
	NASDataDir string `yaml:"nas_data_dir"` // e.g., /mnt/datasets

	// Redis (empty addr means local mode: no queue/metadata)
	RedisAddr     string `yaml:"redis_addr"`
	RedisPassword string `yaml:"redis_password"`
	RedisDB       int    `yaml:"redis_db"`

	// Authentication
	Auth auth.AuthConfig `yaml:"auth"`

	// Cleanup settings
	MaxAgeHours     int `yaml:"max_age_hours"`        // Delete data older than X hours
	MaxSizeGB       int `yaml:"max_size_gb"`          // Keep total size under X GB
	CleanupInterval int `yaml:"cleanup_interval_min"` // Run cleanup every X minutes

	// Podman integration
	PodmanImage        string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults   string `yaml:"container_results"`
	GPUAccess          bool   `yaml:"gpu_access"`
}
|
||||
|
||||
func LoadDataConfig(path string) (*DataConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg DataConfig
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Defaults
|
||||
if cfg.MLPort == 0 {
|
||||
cfg.MLPort = config.DefaultSSHPort
|
||||
}
|
||||
if cfg.NASPort == 0 {
|
||||
cfg.NASPort = config.DefaultSSHPort
|
||||
}
|
||||
if cfg.RedisAddr == "" {
|
||||
cfg.RedisAddr = config.DefaultRedisAddr
|
||||
}
|
||||
// Set default MLDataDir - use ./data/active for local/dev, /data/active for production
|
||||
if cfg.MLDataDir == "" {
|
||||
if cfg.MLHost == "" {
|
||||
// Local mode - use local data directory
|
||||
cfg.MLDataDir = config.DefaultLocalDataDir
|
||||
} else {
|
||||
// Production mode - use /data/active
|
||||
cfg.MLDataDir = config.DefaultDataDir
|
||||
}
|
||||
}
|
||||
if cfg.NASDataDir == "" {
|
||||
cfg.NASDataDir = config.DefaultNASDataDir
|
||||
}
|
||||
|
||||
// Expand paths
|
||||
cfg.MLDataDir = config.ExpandPath(cfg.MLDataDir)
|
||||
cfg.NASDataDir = config.ExpandPath(cfg.NASDataDir)
|
||||
if cfg.MaxAgeHours == 0 {
|
||||
cfg.MaxAgeHours = config.DefaultMaxAgeHours
|
||||
}
|
||||
if cfg.MaxSizeGB == 0 {
|
||||
cfg.MaxSizeGB = config.DefaultMaxSizeGB
|
||||
}
|
||||
if cfg.CleanupInterval == 0 {
|
||||
cfg.CleanupInterval = config.DefaultCleanupInterval
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Validate implements utils.Validator interface
|
||||
func (c *DataConfig) Validate() error {
|
||||
if c.MLPort != 0 {
|
||||
if err := config.ValidatePort(c.MLPort); err != nil {
|
||||
return fmt.Errorf("invalid ML SSH port: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.NASPort != 0 {
|
||||
if err := config.ValidatePort(c.NASPort); err != nil {
|
||||
return fmt.Errorf("invalid NAS SSH port: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.RedisAddr != "" {
|
||||
if err := config.ValidateRedisAddr(c.RedisAddr); err != nil {
|
||||
return fmt.Errorf("invalid Redis configuration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.MaxAgeHours < 1 {
|
||||
return fmt.Errorf("max_age_hours must be at least 1, got %d", c.MaxAgeHours)
|
||||
}
|
||||
|
||||
if c.MaxSizeGB < 1 {
|
||||
return fmt.Errorf("max_size_gb must be at least 1, got %d", c.MaxSizeGB)
|
||||
}
|
||||
|
||||
if c.CleanupInterval < 1 {
|
||||
return fmt.Errorf("cleanup_interval must be at least 1, got %d", c.CleanupInterval)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
775
cmd/data_manager/data_sync.go
Normal file
775
cmd/data_manager/data_sync.go
Normal file
|
|
@ -0,0 +1,775 @@
|
|||
// data_manager.go - Fetch data from NAS to ML server on-demand
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/errors"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
||||
)
|
||||
|
||||
// SSHClient alias for convenience.
type SSHClient = network.SSHClient

// DataManager stages datasets from the NAS onto the ML server on demand and
// optionally records transfer/metadata state in Redis via the task queue.
type DataManager struct {
	config    *DataConfig
	mlServer  *SSHClient       // SSH connection to the ML (training) host
	nasServer *SSHClient       // SSH connection to the NAS (dataset storage)
	taskQueue *queue.TaskQueue // nil in local mode (no Redis configured)
	ctx       context.Context  // manager lifetime; derived from Background in NewDataManager
	cancel    context.CancelFunc
	logger    *logging.Logger
}
|
||||
|
||||
// DataFetchRequest describes a request to stage one or more datasets onto
// the ML server on behalf of a job.
type DataFetchRequest struct {
	JobName     string    `json:"job_name"`
	Datasets    []string  `json:"datasets"` // Dataset names to fetch
	Priority    int       `json:"priority"`
	RequestedAt time.Time `json:"requested_at"`
}

// DatasetInfo is the per-dataset metadata record kept in Redis under
// "ml:dataset:<name>" (see saveDatasetInfo / updateLastAccess).
type DatasetInfo struct {
	Name       string    `json:"name"`
	SizeBytes  int64     `json:"size_bytes"`
	Location   string    `json:"location"` // "nas" or "ml"
	LastAccess time.Time `json:"last_access"`
}
|
||||
|
||||
func NewDataManager(cfg *DataConfig, apiKey string) (*DataManager, error) {
|
||||
mlServer, err := network.NewSSHClient(cfg.MLHost, cfg.MLUser, cfg.MLSSHKey, cfg.MLPort, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ML server connection failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := mlServer.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close ML server connection: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
nasServer, err := network.NewSSHClient(cfg.NASHost, cfg.NASUser, cfg.NASSSHKey, cfg.NASPort, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NAS connection failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := nasServer.Close(); closeErr != nil {
|
||||
log.Printf("Warning: failed to close NAS server connection: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create MLDataDir if it doesn't exist (for production without NAS)
|
||||
if cfg.MLDataDir != "" {
|
||||
if _, err := mlServer.Exec(fmt.Sprintf("mkdir -p %s", cfg.MLDataDir)); err != nil {
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
logger.Job(context.Background(), "data_manager", "").Error("Failed to create ML data directory", "dir", cfg.MLDataDir, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup Redis using internal queue
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logger := logging.NewLogger(slog.LevelInfo, false)
|
||||
|
||||
var taskQueue *queue.TaskQueue
|
||||
if cfg.RedisAddr != "" {
|
||||
queueCfg := queue.Config{
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
}
|
||||
|
||||
var err error
|
||||
taskQueue, err = queue.NewTaskQueue(queueCfg)
|
||||
if err != nil {
|
||||
// FIXED: Check error return values for cleanup
|
||||
if closeErr := mlServer.Close(); closeErr != nil {
|
||||
logger.Warn("failed to close ML server during error cleanup", "error", closeErr)
|
||||
}
|
||||
if closeErr := nasServer.Close(); closeErr != nil {
|
||||
logger.Warn("failed to close NAS server during error cleanup", "error", closeErr)
|
||||
}
|
||||
cancel() // Cancel context to prevent leak
|
||||
return nil, fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
taskQueue = nil // Local mode - no Redis
|
||||
}
|
||||
|
||||
return &DataManager{
|
||||
config: cfg,
|
||||
mlServer: mlServer,
|
||||
nasServer: nasServer,
|
||||
taskQueue: taskQueue,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (dm *DataManager) FetchDataset(jobName, datasetName string) error {
|
||||
ctx, cancel := context.WithTimeout(dm.ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
return network.RetryForNetworkOperations(ctx, func() error {
|
||||
return dm.fetchDatasetInternal(ctx, jobName, datasetName)
|
||||
})
|
||||
}
|
||||
|
||||
// fetchDatasetInternal performs one transfer attempt: validates the dataset
// name and size budget, checks presence on NAS and ML server, copies the
// data (local cp or remote rsync), records progress in Redis when available,
// and stores dataset metadata on success. Returns *errors.DataFetchError for
// validation/lookup failures, or the raw transfer error.
func (dm *DataManager) fetchDatasetInternal(ctx context.Context, jobName, datasetName string) error {
	// Reuse the job-name validator to reject shell-unsafe dataset names —
	// important because datasetName is interpolated into a shell command
	// below.
	if err := container.ValidateJobName(datasetName); err != nil {
		return &errors.DataFetchError{
			Dataset: datasetName,
			JobName: jobName,
			Err:     fmt.Errorf("invalid dataset name: %w", err),
		}
	}

	logger := dm.logger.Job(ctx, jobName, "")
	logger.Info("fetching dataset", "dataset", datasetName)

	// Validate dataset size and run cleanup if needed.
	if err := dm.ValidateDatasetWithCleanup(datasetName); err != nil {
		return &errors.DataFetchError{
			Dataset: datasetName,
			JobName: jobName,
			Err:     fmt.Errorf("dataset size validation failed: %w", err),
		}
	}

	nasPath := filepath.Join(dm.config.NASDataDir, datasetName)
	mlPath := filepath.Join(dm.config.MLDataDir, datasetName)

	// Check if dataset exists on NAS.
	if !dm.nasServer.FileExists(nasPath) {
		return &errors.DataFetchError{
			Dataset: datasetName,
			JobName: jobName,
			Err:     fmt.Errorf("dataset not found on NAS"),
		}
	}

	// Already on the ML server: just refresh the access timestamp.
	if dm.mlServer.FileExists(mlPath) {
		logger.Info("dataset already on ML server", "dataset", datasetName)
		dm.updateLastAccess(datasetName)
		return nil
	}

	// Get size for progress tracking; treat failure as unknown (0).
	size, err := dm.nasServer.GetFileSize(nasPath)
	if err != nil {
		logger.Warn("could not get dataset size", "dataset", datasetName, "error", err)
		size = 0
	}

	sizeGB := float64(size) / (1024 * 1024 * 1024)
	logger.Info("transferring dataset",
		"dataset", datasetName,
		"size_gb", sizeGB,
		"nas_path", nasPath,
		"ml_path", mlPath)

	// Best-effort: mark the transfer as started in Redis.
	if dm.taskQueue != nil {
		redisClient := dm.taskQueue.GetRedisClient()
		if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
			"status", "transferring",
			"job_name", jobName,
			"size_bytes", size,
			"started_at", time.Now().Unix()).Err(); err != nil {
			logger.Warn("failed to record transfer start in Redis", "error", err)
		}
	}

	// Use local copy for local mode, rsync for remote mode. Both commands
	// run on the NAS side via SSH.
	var rsyncCmd string
	if dm.config.NASHost == "" || dm.config.NASUser == "" {
		// Local mode - use cp.
		rsyncCmd = fmt.Sprintf("mkdir -p %s && cp -r %s %s/", dm.config.MLDataDir, nasPath, mlPath)
	} else {
		// Remote mode - use rsync over SSH.
		rsyncCmd = fmt.Sprintf(
			"mkdir -p %s && rsync -avz --progress %s@%s:%s/ %s/",
			dm.config.MLDataDir,
			dm.config.NASUser,
			dm.config.NASHost,
			nasPath,
			mlPath,
		)
	}

	// Snapshot process I/O counters so we can report the delta afterwards.
	ioBefore, ioErr := telemetry.ReadProcessIO()
	start := time.Now()
	// NOTE(review): time.Since(start) is evaluated NOW (≈0), not after the
	// transfer — if ExecWithMetrics expects the operation duration, this
	// argument is always zero; confirm its contract.
	out, err := telemetry.ExecWithMetrics(dm.logger, "dataset transfer", time.Since(start), func() (string, error) {
		return dm.nasServer.ExecContext(ctx, rsyncCmd)
	})
	duration := time.Since(start)

	if err != nil {
		logger.Error("transfer failed",
			"dataset", datasetName,
			"duration", duration,
			"error", err,
			"output", out)

		// Report I/O delta even on failure, when counters were readable.
		if ioErr == nil {
			if after, readErr := telemetry.ReadProcessIO(); readErr == nil {
				delta := telemetry.DiffIO(ioBefore, after)
				logger.Debug("transfer io stats",
					"dataset", datasetName,
					"read_bytes", delta.ReadBytes,
					"write_bytes", delta.WriteBytes)
			}
		}

		// Best-effort: mark the transfer as failed in Redis.
		if dm.taskQueue != nil {
			redisClient := dm.taskQueue.GetRedisClient()
			if redisErr := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
				"status", "failed",
				"error", err.Error()).Err(); redisErr != nil {
				logger.Warn("failed to record transfer failure in Redis", "error", redisErr)
			}
		}
		return err
	}

	logger.Info("transfer complete",
		"dataset", datasetName,
		"duration", duration,
		"size_gb", sizeGB)

	if ioErr == nil {
		if after, readErr := telemetry.ReadProcessIO(); readErr == nil {
			delta := telemetry.DiffIO(ioBefore, after)
			logger.Debug("transfer io stats",
				"dataset", datasetName,
				"read_bytes", delta.ReadBytes,
				"write_bytes", delta.WriteBytes)
		}
	}

	// Best-effort: mark the transfer as completed in Redis.
	if dm.taskQueue != nil {
		redisClient := dm.taskQueue.GetRedisClient()
		if err := redisClient.HSet(dm.ctx, fmt.Sprintf("ml:data:transfer:%s", datasetName),
			"status", "completed",
			"completed_at", time.Now().Unix(),
			"duration_seconds", duration.Seconds()).Err(); err != nil {
			logger.Warn("failed to record transfer completion in Redis", "error", err)
		}
	}

	// Track dataset metadata for later cleanup/listing.
	dm.saveDatasetInfo(datasetName, size)

	return nil
}
|
||||
|
||||
func (dm *DataManager) saveDatasetInfo(name string, size int64) {
|
||||
if dm.taskQueue == nil {
|
||||
return // Skip in local mode
|
||||
}
|
||||
|
||||
info := DatasetInfo{
|
||||
Name: name,
|
||||
SizeBytes: size,
|
||||
Location: "ml",
|
||||
LastAccess: time.Now(),
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(info)
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.Set(dm.ctx, fmt.Sprintf("ml:dataset:%s", name), data, 0).Err(); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to save dataset info to Redis",
|
||||
"dataset", name, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dm *DataManager) updateLastAccess(name string) {
|
||||
if dm.taskQueue == nil {
|
||||
return // Skip in local mode
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("ml:dataset:%s", name)
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
data, err := redisClient.Get(dm.ctx, key).Result()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var info DatasetInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
info.LastAccess = time.Now()
|
||||
newData, _ := json.Marshal(info)
|
||||
redisClient = dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.Set(dm.ctx, key, newData, 0).Err(); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("failed to update last access in Redis",
|
||||
"dataset", name, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ListDatasetsOnML returns a list of all datasets currently stored on the ML server.
|
||||
func (dm *DataManager) ListDatasetsOnML() ([]DatasetInfo, error) {
|
||||
out, err := dm.mlServer.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", dm.config.MLDataDir))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var datasets []DatasetInfo
|
||||
for name := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var info DatasetInfo
|
||||
|
||||
// Only use Redis if available
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
key := fmt.Sprintf("ml:dataset:%s", name)
|
||||
data, err := redisClient.Get(dm.ctx, key).Result()
|
||||
|
||||
if err == nil {
|
||||
if unmarshalErr := json.Unmarshal([]byte(data), &info); unmarshalErr != nil {
|
||||
// Fallback to disk if unmarshal fails
|
||||
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
|
||||
info = DatasetInfo{
|
||||
Name: name,
|
||||
SizeBytes: size,
|
||||
Location: "ml",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: get from disk
|
||||
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
|
||||
info = DatasetInfo{
|
||||
Name: name,
|
||||
SizeBytes: size,
|
||||
Location: "ml",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Local mode: get from disk
|
||||
size, _ := dm.mlServer.GetFileSize(filepath.Join(dm.config.MLDataDir, name))
|
||||
info = DatasetInfo{
|
||||
Name: name,
|
||||
SizeBytes: size,
|
||||
Location: "ml",
|
||||
}
|
||||
}
|
||||
|
||||
datasets = append(datasets, info)
|
||||
}
|
||||
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
func (dm *DataManager) CleanupOldData() error {
|
||||
logger := dm.logger.Job(dm.ctx, "data_manager", "")
|
||||
logger.Info("running data cleanup")
|
||||
|
||||
datasets, err := dm.ListDatasetsOnML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, ds := range datasets {
|
||||
totalSize += ds.SizeBytes
|
||||
}
|
||||
|
||||
totalSizeGB := float64(totalSize) / (1024 * 1024 * 1024)
|
||||
logger.Info("current storage usage",
|
||||
"total_size_gb", totalSizeGB,
|
||||
"dataset_count", len(datasets))
|
||||
|
||||
// Delete datasets older than max age or if over size limit
|
||||
maxAge := time.Duration(dm.config.MaxAgeHours) * time.Hour
|
||||
maxSize := int64(dm.config.MaxSizeGB) * 1024 * 1024 * 1024
|
||||
|
||||
var deleted []string
|
||||
for _, ds := range datasets {
|
||||
shouldDelete := false
|
||||
|
||||
// Check age
|
||||
if !ds.LastAccess.IsZero() && time.Since(ds.LastAccess) > maxAge {
|
||||
logger.Info("dataset is old, marking for deletion",
|
||||
"dataset", ds.Name,
|
||||
"last_access", ds.LastAccess,
|
||||
"age_hours", time.Since(ds.LastAccess).Hours())
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
// Check if over size limit
|
||||
if totalSize > maxSize {
|
||||
logger.Info("over size limit, deleting oldest dataset",
|
||||
"dataset", ds.Name,
|
||||
"current_size_gb", totalSizeGB,
|
||||
"max_size_gb", dm.config.MaxSizeGB)
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
if shouldDelete {
|
||||
path := filepath.Join(dm.config.MLDataDir, ds.Name)
|
||||
logger.Info("deleting dataset", "dataset", ds.Name, "path", path)
|
||||
|
||||
if _, err := dm.mlServer.Exec(fmt.Sprintf("rm -rf %s", path)); err != nil {
|
||||
logger.Error("failed to delete dataset",
|
||||
"dataset", ds.Name,
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
deleted = append(deleted, ds.Name)
|
||||
totalSize -= ds.SizeBytes
|
||||
|
||||
// FIXED: Remove from Redis only if available, with error handling
|
||||
if dm.taskQueue != nil {
|
||||
redisClient := dm.taskQueue.GetRedisClient()
|
||||
if err := redisClient.Del(dm.ctx, fmt.Sprintf("ml:dataset:%s", ds.Name)).Err(); err != nil {
|
||||
logger.Warn("failed to delete dataset from Redis",
|
||||
"dataset", ds.Name,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(deleted) > 0 {
|
||||
logger.Info("cleanup complete",
|
||||
"deleted_count", len(deleted),
|
||||
"deleted_datasets", deleted)
|
||||
} else {
|
||||
logger.Info("cleanup complete", "deleted_count", 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableDiskSpace returns available disk space in bytes
|
||||
func (dm *DataManager) GetAvailableDiskSpace() int64 {
|
||||
logger := dm.logger.Job(dm.ctx, "data_manager", "")
|
||||
|
||||
// Check disk space on ML server
|
||||
cmd := "df -k " + dm.config.MLDataDir + " | tail -1 | awk '{print $4}'"
|
||||
output, err := dm.mlServer.Exec(cmd)
|
||||
if err != nil {
|
||||
logger.Error("failed to get disk space", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Parse KB to bytes
|
||||
var freeKB int64
|
||||
_, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &freeKB)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse disk space", "error", err, "output", output)
|
||||
return 0
|
||||
}
|
||||
|
||||
return freeKB * 1024 // Convert KB to bytes
|
||||
}
|
||||
|
||||
// GetDatasetInfo returns information about a dataset from NAS
|
||||
func (dm *DataManager) GetDatasetInfo(datasetName string) (*DatasetInfo, error) {
|
||||
// Check if dataset exists on NAS
|
||||
nasPath := filepath.Join(dm.config.NASDataDir, datasetName)
|
||||
cmd := fmt.Sprintf("test -d %s && echo 'exists'", nasPath)
|
||||
output, err := dm.nasServer.Exec(cmd)
|
||||
if err != nil || strings.TrimSpace(output) != "exists" {
|
||||
return nil, fmt.Errorf("dataset %s not found on NAS", datasetName)
|
||||
}
|
||||
|
||||
// Get dataset size
|
||||
cmd = fmt.Sprintf("du -sb %s | cut -f1", nasPath)
|
||||
output, err = dm.nasServer.Exec(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset size: %w", err)
|
||||
}
|
||||
|
||||
var sizeBytes int64
|
||||
_, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &sizeBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse dataset size: %w", err)
|
||||
}
|
||||
|
||||
// Get last modification time as proxy for last access
|
||||
cmd = fmt.Sprintf("stat -c %%Y %s", nasPath)
|
||||
output, err = dm.nasServer.Exec(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset timestamp: %w", err)
|
||||
}
|
||||
|
||||
var modTime int64
|
||||
_, err = fmt.Sscanf(strings.TrimSpace(output), "%d", &modTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse timestamp: %w", err)
|
||||
}
|
||||
|
||||
return &DatasetInfo{
|
||||
Name: datasetName,
|
||||
SizeBytes: sizeBytes,
|
||||
Location: "nas",
|
||||
LastAccess: time.Unix(modTime, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateDatasetWithCleanup checks if dataset fits and runs cleanup if needed
|
||||
func (dm *DataManager) ValidateDatasetWithCleanup(datasetName string) error {
|
||||
logger := dm.logger.Job(dm.ctx, "data_manager", "")
|
||||
|
||||
// Get dataset info
|
||||
info, err := dm.GetDatasetInfo(datasetName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get dataset info: %w", err)
|
||||
}
|
||||
|
||||
// Check current available space
|
||||
availableSpace := dm.GetAvailableDiskSpace()
|
||||
|
||||
logger.Info("dataset size validation",
|
||||
"dataset", datasetName,
|
||||
"dataset_size_gb", float64(info.SizeBytes)/(1024*1024*1024),
|
||||
"available_gb", float64(availableSpace)/(1024*1024*1024))
|
||||
|
||||
// If enough space, proceed
|
||||
if info.SizeBytes <= availableSpace {
|
||||
logger.Info("sufficient space available", "dataset", datasetName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try cleanup first
|
||||
logger.Info("insufficient space, running cleanup",
|
||||
"dataset", datasetName,
|
||||
"required_gb", float64(info.SizeBytes)/(1024*1024*1024),
|
||||
"available_gb", float64(availableSpace)/(1024*1024*1024))
|
||||
|
||||
if err := dm.CleanupOldData(); err != nil {
|
||||
return fmt.Errorf("cleanup failed: %w", err)
|
||||
}
|
||||
|
||||
// Check space again after cleanup
|
||||
availableSpace = dm.GetAvailableDiskSpace()
|
||||
logger.Info("space after cleanup",
|
||||
"available_gb", float64(availableSpace)/(1024*1024*1024))
|
||||
|
||||
// If now enough space, proceed
|
||||
if info.SizeBytes <= availableSpace {
|
||||
logger.Info("cleanup freed enough space", "dataset", datasetName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Still not enough space
|
||||
return fmt.Errorf("dataset %s (%.2fGB) too large for available space (%.2fGB) even after cleanup",
|
||||
datasetName,
|
||||
float64(info.SizeBytes)/(1024*1024*1024),
|
||||
float64(availableSpace)/(1024*1024*1024))
|
||||
}
|
||||
|
||||
func (dm *DataManager) StartCleanupLoop() {
|
||||
logger := dm.logger.Job(dm.ctx, "data_manager", "")
|
||||
ticker := time.NewTicker(time.Duration(dm.config.CleanupInterval) * time.Minute)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-dm.ctx.Done():
|
||||
logger.Info("cleanup loop stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := dm.CleanupOldData(); err != nil {
|
||||
logger.Error("cleanup error", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the DataManager, stopping the cleanup loop and
|
||||
// closing all connections to ML server, NAS server, and Redis.
|
||||
func (dm *DataManager) Close() {
|
||||
dm.cancel() // Cancel context to stop cleanup loop
|
||||
|
||||
// Wait a moment for cleanup loop to finish
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if dm.mlServer != nil {
|
||||
if err := dm.mlServer.Close(); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing ML server connection", "error", err)
|
||||
}
|
||||
}
|
||||
if dm.nasServer != nil {
|
||||
if err := dm.nasServer.Close(); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing NAS server connection", "error", err)
|
||||
}
|
||||
}
|
||||
if dm.taskQueue != nil {
|
||||
if err := dm.taskQueue.Close(); err != nil {
|
||||
dm.logger.Job(dm.ctx, "data_manager", "").Warn("error closing Redis connection", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// main is the data_manager CLI entry point. It parses auth flags and an
// optional leading --config flag, loads and validates configuration,
// optionally authenticates the supplied API key, builds a DataManager, and
// dispatches on the first positional command: fetch | list | validate |
// cleanup | daemon.
func main() {
	// Parse authentication flags
	authFlags := auth.ParseAuthFlags()
	if err := auth.ValidateAuthFlags(authFlags); err != nil {
		log.Fatalf("Authentication flag error: %v", err)
	}

	// Get API key from various sources
	apiKey := auth.GetAPIKeyFromSources(authFlags)

	// Config file: default, overridden by the auth flags' --config if set.
	configFile := "configs/config-local.yaml"
	if authFlags.ConfigFile != "" {
		configFile = authFlags.ConfigFile
	}

	// Parse command line args
	if len(os.Args) < 2 {
		fmt.Println("Usage:")
		fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] fetch <job-name> <dataset> [dataset...]")
		fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] list")
		fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] cleanup")
		fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] validate <dataset>")
		fmt.Println(" data_manager [--config configs/config-local.yaml] [--api-key <key>] daemon")
		fmt.Println()
		auth.PrintAuthHelp()
		os.Exit(1)
	}

	// Check for --config flag.
	// NOTE(review): --config is only recognized when it is the FIRST
	// argument, and os.Args is mutated in place to drop it — confirm this
	// does not conflict with auth.ParseAuthFlags above, which has already
	// read os.Args.
	if len(os.Args) >= 3 && os.Args[1] == "--config" {
		configFile = os.Args[2]
		// Shift args
		os.Args = append([]string{os.Args[0]}, os.Args[3:]...)
	}

	cfg, err := LoadDataConfig(configFile)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Validate authentication configuration
	if err := cfg.Auth.ValidateAuthConfig(); err != nil {
		log.Fatalf("Invalid authentication configuration: %v", err)
	}

	// Validate configuration
	if err := cfg.Validate(); err != nil {
		log.Fatalf("Invalid configuration: %v", err)
	}

	// Test authentication if enabled: a provided key must validate; an
	// enabled auth config with no key at all is fatal.
	if cfg.Auth.Enabled && apiKey != "" {
		user, err := cfg.Auth.ValidateAPIKey(apiKey)
		if err != nil {
			log.Fatalf("Authentication failed: %v", err)
		}
		log.Printf("Data manager authenticated as user: %s (admin: %v)", user.Name, user.Admin)
	} else if cfg.Auth.Enabled {
		log.Fatal("Authentication required but no API key provided")
	}

	dm, err := NewDataManager(cfg, apiKey)
	if err != nil {
		log.Fatalf("Failed to create data manager: %v", err)
	}
	defer dm.Close()

	// First positional argument selects the subcommand.
	cmd := os.Args[1]

	switch cmd {
	case "fetch":
		// fetch <job-name> <dataset> [dataset...] — failures per dataset
		// are logged and the loop continues with the next one.
		if len(os.Args) < 4 {
			log.Fatal("Usage: data_manager fetch <job-name> <dataset> [dataset...]")
		}
		jobName := os.Args[2]
		datasets := os.Args[3:]

		for _, dataset := range datasets {
			if err := dm.FetchDataset(jobName, dataset); err != nil {
				dm.logger.Job(context.Background(), jobName, "").Error("failed to fetch dataset",
					"dataset", dataset,
					"error", err)
			}
		}

	case "list":
		// Print a table of datasets on the ML server plus a size total.
		datasets, err := dm.ListDatasetsOnML()
		if err != nil {
			log.Fatalf("Failed to list datasets: %v", err)
		}

		fmt.Println("Datasets on ML server:")
		fmt.Println("======================")
		var totalSize int64
		for _, ds := range datasets {
			sizeMB := float64(ds.SizeBytes) / (1024 * 1024)
			lastAccess := "unknown"
			if !ds.LastAccess.IsZero() {
				lastAccess = ds.LastAccess.Format("2006-01-02 15:04:05")
			}
			fmt.Printf("%-30s %10.2f MB Last access: %s\n", ds.Name, sizeMB, lastAccess)
			totalSize += ds.SizeBytes
		}
		fmt.Printf("\nTotal: %.2f GB\n", float64(totalSize)/(1024*1024*1024))

	case "validate":
		// validate <dataset> — checks the dataset would fit, running
		// cleanup if needed.
		if len(os.Args) < 3 {
			log.Fatal("Usage: data_manager validate <dataset>")
		}
		dataset := os.Args[2]

		fmt.Printf("Validating dataset: %s\n", dataset)
		if err := dm.ValidateDatasetWithCleanup(dataset); err != nil {
			log.Fatalf("Validation failed: %v", err)
		}
		fmt.Printf("✅ Dataset %s can be downloaded\n", dataset)

	case "cleanup":
		// One immediate cleanup pass.
		if err := dm.CleanupOldData(); err != nil {
			log.Fatalf("Cleanup failed: %v", err)
		}

	case "daemon":
		// Run the periodic cleanup loop until SIGINT/SIGTERM.
		logger := dm.logger.Job(context.Background(), "data_manager", "")
		logger.Info("starting data manager daemon")
		dm.StartCleanupLoop()
		logger.Info("cleanup configuration",
			"interval_minutes", cfg.CleanupInterval,
			"max_age_hours", cfg.MaxAgeHours,
			"max_size_gb", cfg.MaxSizeGB)

		// Handle graceful shutdown
		sigChan := make(chan os.Signal, 1)
		signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

		sig := <-sigChan
		logger.Info("received shutdown signal", "signal", sig)
		// dm.Close() is also deferred above; Close guards each connection
		// with a nil check, so the explicit call here runs first.
		dm.Close()
		logger.Info("data manager shut down gracefully")

	default:
		log.Fatalf("Unknown command: %s", cmd)
	}
}
|
||||
282
cmd/tui/README.md
Normal file
282
cmd/tui/README.md
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
# FetchML TUI - Terminal User Interface
|
||||
|
||||
An interactive terminal dashboard for managing ML experiments, monitoring system resources, and controlling job execution.
|
||||
|
||||
## Features
|
||||
|
||||
### 📊 Real-time Monitoring
|
||||
- **Job Status** - Track pending, running, finished, and failed jobs
|
||||
- **GPU Metrics** - Monitor GPU utilization, memory, and temperature
|
||||
- **Container Status** - View running Podman/Docker containers
|
||||
- **Task Queue** - See queued tasks with priorities and status
|
||||
|
||||
### 🎮 Interactive Controls
|
||||
- **Queue Jobs** - Submit jobs with custom arguments and priorities
|
||||
- **View Logs** - Real-time log viewing for running jobs
|
||||
- **Cancel Tasks** - Stop running tasks
|
||||
- **Delete Jobs** - Remove pending jobs
|
||||
- **Mark Failed** - Manually mark stuck jobs as failed
|
||||
|
||||
### ⚙️ Settings Management
|
||||
- **API Key Configuration** - Set and update API keys on the fly
|
||||
- **In-memory Storage** - Settings persist for the session
|
||||
|
||||
### 🎨 Modern UI
|
||||
- **Clean Design** - Dark-mode friendly with adaptive colors
|
||||
- **Responsive Layout** - Adjusts to terminal size
|
||||
- **Context-aware Help** - Shows relevant shortcuts for each view
|
||||
- **Mouse Support** - Optional mouse navigation
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Running the TUI
|
||||
|
||||
```bash
|
||||
# Using make (recommended)
|
||||
make tui-dev # Dev mode (remote server)
|
||||
make tui # Local mode
|
||||
|
||||
# Direct execution with CLI config (TOML)
|
||||
./bin/tui --config ~/.ml/config.toml
|
||||
|
||||
# With custom TOML config
|
||||
./bin/tui --config path/to/config.toml
|
||||
```
|
||||
|
||||
### First Time Setup
|
||||
|
||||
1. **Build the binary**
|
||||
```bash
|
||||
make build
|
||||
```
|
||||
|
||||
2. **Get your API key**
|
||||
```bash
|
||||
./bin/user_manager --config configs/config_dev.yaml --cmd generate-key --username your_name
|
||||
```
|
||||
|
||||
3. **Launch the TUI**
|
||||
```bash
|
||||
make tui-dev
|
||||
```
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
### Navigation
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `1` | Switch to Job List view |
|
||||
| `g` | Switch to GPU Status view |
|
||||
| `l` | View logs for selected job |
|
||||
| `v` | Switch to Task Queue view |
|
||||
| `o` | Switch to Container Status view |
|
||||
| `s` | Open Settings |
|
||||
| `h` or `?` | Toggle help screen |
|
||||
|
||||
### Job Management
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `t` | Queue selected job (default args) |
|
||||
| `a` | Queue job with custom arguments |
|
||||
| `c` | Cancel running task |
|
||||
| `d` | Delete pending job |
|
||||
| `f` | Mark running job as failed |
|
||||
|
||||
### System
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `r` | Refresh all data |
|
||||
| `G` | Refresh GPU status only |
|
||||
| `q` or `Ctrl+C` | Quit |
|
||||
|
||||
### Settings View
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `↑`/`↓` or `j`/`k` | Navigate options |
|
||||
| `Enter` | Select/Save |
|
||||
| `Esc` | Exit settings |
|
||||
|
||||
## Views
|
||||
|
||||
### Job List (Default)
|
||||
- Shows all jobs across all statuses
|
||||
- Filter with `/` key
|
||||
- Navigate with arrow keys or `j`/`k`
|
||||
- Select and press `l` to view logs
|
||||
|
||||
### GPU Status
|
||||
- Real-time GPU metrics (nvidia-smi)
|
||||
- macOS GPU info (system_profiler)
|
||||
- Utilization, memory, temperature
|
||||
|
||||
### Container Status
|
||||
- Running Podman/Docker containers
|
||||
- Container health and status
|
||||
- System info (Podman/Docker version)
|
||||
|
||||
### Task Queue
|
||||
- All queued tasks with priorities
|
||||
- Task status and creation time
|
||||
- Running duration for active tasks
|
||||
|
||||
### Logs
|
||||
- Last 200 lines of job output
|
||||
- Auto-scroll to bottom
|
||||
- Refreshes with job status
|
||||
|
||||
### Settings
|
||||
- View current API key status
|
||||
- Update API key
|
||||
- Save configuration (in-memory)
|
||||
|
||||
## Terminal Compatibility
|
||||
|
||||
The TUI is built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) and works on all modern terminals:
|
||||
|
||||
### ✅ Fully Supported
|
||||
- **WezTerm** (recommended)
|
||||
- **Alacritty**
|
||||
- **Kitty**
|
||||
- **iTerm2** (macOS)
|
||||
- **Terminal.app** (macOS)
|
||||
- **Windows Terminal**
|
||||
- **GNOME Terminal**
|
||||
- **Konsole**
|
||||
|
||||
### ✅ Multiplexers
|
||||
- **tmux**
|
||||
- **screen**
|
||||
|
||||
### Features
|
||||
- ✅ 256 colors
|
||||
- ✅ True color (24-bit)
|
||||
- ✅ Mouse support
|
||||
- ✅ Alt screen buffer
|
||||
- ✅ Adaptive colors (light/dark themes)
|
||||
|
||||
|
||||
### Key Components
|
||||
|
||||
- **Model** - Pure data structures (State, Job, Task)
|
||||
- **View** - Rendering functions (no business logic)
|
||||
- **Controller** - Message handling and state updates
|
||||
- **Services** - SSH/Redis communication
|
||||
|
||||
## Configuration
|
||||
|
||||
The TUI uses TOML configuration format for CLI settings:
|
||||
|
||||
```toml
|
||||
# ~/.ml/config.toml
|
||||
worker_host = "localhost"
|
||||
worker_user = "your_user"
|
||||
worker_base = "~/ml_jobs"
|
||||
worker_port = 22
|
||||
api_key = "your_api_key_here"
|
||||
```
|
||||
|
||||
For CLI usage, run `ml init` to create a default configuration file.
|
||||
|
||||
See [Configuration Documentation](../docs/documentation.md#configuration) for details.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### TUI doesn't start
|
||||
```bash
|
||||
# Check if binary exists
|
||||
ls -la bin/tui
|
||||
|
||||
# Rebuild if needed
|
||||
make build
|
||||
|
||||
# Check CLI config
|
||||
cat ~/.ml/config.toml
|
||||
```
|
||||
|
||||
### Authentication errors
|
||||
```bash
|
||||
# Verify CLI config exists
|
||||
ls -la ~/.ml/config.toml
|
||||
|
||||
# Initialize CLI config if needed
|
||||
ml init
|
||||
|
||||
# Test connection
|
||||
./bin/tui --config ~/.ml/config.toml
|
||||
```
|
||||
|
||||
### Display issues
|
||||
```bash
|
||||
# Check terminal type
|
||||
echo $TERM
|
||||
|
||||
# Should be xterm-256color or similar
|
||||
# If not, set it:
|
||||
export TERM=xterm-256color
|
||||
```
|
||||
|
||||
### Connection issues
|
||||
```bash
|
||||
# Test SSH connection
|
||||
ssh your_user@your_server
|
||||
|
||||
# Test Redis connection
|
||||
redis-cli ping
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Building
|
||||
```bash
|
||||
# Build TUI only
|
||||
go build -o bin/tui ./cmd/tui
|
||||
|
||||
# Build all binaries
|
||||
make build
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Run with verbose logging
|
||||
./bin/tui --config ~/.ml/config.toml 2>tui.log
|
||||
|
||||
# Check logs
|
||||
tail -f tui.log
|
||||
```
|
||||
|
||||
### Code Organization
|
||||
- Keep files under 300 lines
|
||||
- Separate concerns (MVC pattern)
|
||||
- Use descriptive function names
|
||||
- Add comments for complex logic
|
||||
|
||||
## Tips & Tricks
|
||||
|
||||
### Efficient Workflow
|
||||
1. Keep TUI open in one terminal
|
||||
2. Edit code in another terminal
|
||||
3. Use `r` to refresh after changes
|
||||
4. Use `h` to quickly reference shortcuts
|
||||
|
||||
### Custom Arguments
|
||||
When queuing jobs with `a`:
|
||||
```
|
||||
--epochs 100 --lr 0.001 --priority 5
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
- Use `G` for quick GPU refresh (faster than `r`)
|
||||
- Check queue with `v` before queuing new jobs
|
||||
- Use `l` to debug failed jobs
|
||||
|
||||
### Settings
|
||||
- Update API key without restarting
|
||||
- Changes are in-memory only
|
||||
- Restart TUI to reset
|
||||
|
||||
## See Also
|
||||
|
||||
- [Main Documentation](../docs/documentation.md)
|
||||
- [Worker Documentation](../cmd/worker/README.md)
|
||||
- [Configuration Guide](../docs/documentation.md#configuration)
|
||||
- [Bubble Tea Documentation](https://github.com/charmbracelet/bubbletea)
|
||||
492
cmd/tui/internal/config/cli_config.go
Normal file
492
cmd/tui/internal/config/cli_config.go
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
utils "github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/stretchr/testify/assert/yaml"
|
||||
)
|
||||
|
||||
// CLIConfig represents the TOML config structure used by the CLI
// (by default ~/.ml/config.toml). The toml tags map to the flat
// `key = value` lines understood by parseTOML.
type CLIConfig struct {
	WorkerHost string `toml:"worker_host"` // worker host the CLI connects to
	WorkerUser string `toml:"worker_user"` // worker user name
	WorkerBase string `toml:"worker_base"` // base path on the worker; expanded and must be absolute (see Validate)
	WorkerPort int    `toml:"worker_port"` // worker port; checked with utils.ValidatePort
	APIKey     string `toml:"api_key"`     // API key; Validate requires at least 16 characters

	// User context (filled after authentication)
	CurrentUser *UserContext `toml:"-"`
}
|
||||
|
||||
// UserContext represents the authenticated user information.
// It is populated by AuthenticateWithServer from the auth package's
// user record and is never serialized to the TOML config (tag "-" on
// the owning field).
type UserContext struct {
	Name        string          `json:"name"`        // user name
	Admin       bool            `json:"admin"`       // whether the user has admin rights
	Roles       []string        `json:"roles"`       // role names assigned to the user
	Permissions map[string]bool `json:"permissions"` // permission name -> granted
}
|
||||
|
||||
// LoadCLIConfig loads the CLI's TOML configuration from the provided path.
|
||||
// If path is empty, ~/.ml/config.toml is used. The resolved path is returned.
|
||||
// Automatically migrates from YAML config if TOML doesn't exist.
|
||||
// Environment variables with FETCH_ML_CLI_ prefix override config file values.
|
||||
func LoadCLIConfig(configPath string) (*CLIConfig, string, error) {
|
||||
if configPath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
configPath = filepath.Join(home, ".ml", "config.toml")
|
||||
} else {
|
||||
configPath = utils.ExpandPath(configPath)
|
||||
if !filepath.IsAbs(configPath) {
|
||||
if abs, err := filepath.Abs(configPath); err == nil {
|
||||
configPath = abs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if TOML config exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
// Try to migrate from YAML
|
||||
yamlPath := strings.TrimSuffix(configPath, ".toml") + ".yaml"
|
||||
if migratedPath, err := migrateFromYAML(yamlPath, configPath); err == nil {
|
||||
log.Printf("Migrated configuration from %s to %s", yamlPath, migratedPath)
|
||||
configPath = migratedPath
|
||||
} else {
|
||||
return nil, configPath, fmt.Errorf("CLI config not found at %s (run 'ml init' first)", configPath)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, configPath, fmt.Errorf("cannot access CLI config %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, configPath, fmt.Errorf("failed to read CLI config: %w", err)
|
||||
}
|
||||
|
||||
config := &CLIConfig{}
|
||||
if err := parseTOML(data, config); err != nil {
|
||||
return nil, configPath, fmt.Errorf("failed to parse CLI config: %w", err)
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, configPath, err
|
||||
}
|
||||
|
||||
// Apply environment variable overrides with FETCH_ML_CLI_ prefix
|
||||
if host := os.Getenv("FETCH_ML_CLI_HOST"); host != "" {
|
||||
config.WorkerHost = host
|
||||
}
|
||||
if user := os.Getenv("FETCH_ML_CLI_USER"); user != "" {
|
||||
config.WorkerUser = user
|
||||
}
|
||||
if base := os.Getenv("FETCH_ML_CLI_BASE"); base != "" {
|
||||
config.WorkerBase = base
|
||||
}
|
||||
if port := os.Getenv("FETCH_ML_CLI_PORT"); port != "" {
|
||||
if p, err := parseInt(port); err == nil {
|
||||
config.WorkerPort = p
|
||||
}
|
||||
}
|
||||
if apiKey := os.Getenv("FETCH_ML_CLI_API_KEY"); apiKey != "" {
|
||||
config.APIKey = apiKey
|
||||
}
|
||||
|
||||
// Also support legacy ML_ prefix for backward compatibility
|
||||
if host := os.Getenv("ML_HOST"); host != "" && config.WorkerHost == "" {
|
||||
config.WorkerHost = host
|
||||
}
|
||||
if user := os.Getenv("ML_USER"); user != "" && config.WorkerUser == "" {
|
||||
config.WorkerUser = user
|
||||
}
|
||||
if base := os.Getenv("ML_BASE"); base != "" && config.WorkerBase == "" {
|
||||
config.WorkerBase = base
|
||||
}
|
||||
if port := os.Getenv("ML_PORT"); port != "" && config.WorkerPort == 0 {
|
||||
if p, err := parseInt(port); err == nil {
|
||||
config.WorkerPort = p
|
||||
}
|
||||
}
|
||||
if apiKey := os.Getenv("ML_API_KEY"); apiKey != "" && config.APIKey == "" {
|
||||
config.APIKey = apiKey
|
||||
}
|
||||
|
||||
return config, configPath, nil
|
||||
}
|
||||
|
||||
// parseTOML is a simple TOML parser for the CLI config format
|
||||
func parseTOML(data []byte, config *CLIConfig) error {
|
||||
lines := strings.Split(string(data), "\n")
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
// Remove quotes if present
|
||||
if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "worker_host":
|
||||
config.WorkerHost = value
|
||||
case "worker_user":
|
||||
config.WorkerUser = value
|
||||
case "worker_base":
|
||||
config.WorkerBase = value
|
||||
case "worker_port":
|
||||
if p, err := parseInt(value); err == nil {
|
||||
config.WorkerPort = p
|
||||
}
|
||||
case "api_key":
|
||||
config.APIKey = value
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToTUIConfig converts CLI config to TUI config structure.
// Connection fields (host/user/port/base path) are copied from the CLI
// config; TUI-only fields (Redis address, container image/paths, known-hosts
// path) are filled from environment smart defaults. The CLI API key is
// installed as a single admin entry named "cli_user" with read/write/delete
// permissions in the resulting auth config.
func (c *CLIConfig) ToTUIConfig() *Config {
	// Get smart defaults for current environment
	smart := utils.GetSmartDefaults()

	tuiConfig := &Config{
		Host:     c.WorkerHost,
		User:     c.WorkerUser,
		Port:     c.WorkerPort,
		BasePath: c.WorkerBase,

		// Set defaults for TUI-specific fields using smart defaults
		RedisAddr:          smart.RedisAddr(),
		RedisDB:            0,
		PodmanImage:        "ml-worker:latest",
		ContainerWorkspace: utils.DefaultContainerWorkspace,
		ContainerResults:   utils.DefaultContainerResults,
		GPUAccess:          false,
	}

	// Set up auth config with CLI API key.
	// NOTE(review): this uses the package-local hashAPIKey while
	// AuthenticateWithServer uses auth.HashAPIKey — confirm the two hash
	// functions agree, otherwise keys validated by one path will not
	// match entries written by the other.
	tuiConfig.Auth = auth.AuthConfig{
		Enabled: true,
		APIKeys: map[auth.Username]auth.APIKeyEntry{
			"cli_user": {
				Hash:  auth.APIKeyHash(hashAPIKey(c.APIKey)),
				Admin: true,
				Roles: []string{"user", "admin"},
				Permissions: map[string]bool{
					"read":   true,
					"write":  true,
					"delete": true,
				},
			},
		},
	}

	// Set known hosts path
	tuiConfig.KnownHosts = smart.KnownHostsPath()

	return tuiConfig
}
|
||||
|
||||
// Validate validates the CLI config
|
||||
func (c *CLIConfig) Validate() error {
|
||||
var errors []string
|
||||
|
||||
if c.WorkerHost == "" {
|
||||
errors = append(errors, "worker_host is required")
|
||||
} else if len(strings.TrimSpace(c.WorkerHost)) == 0 {
|
||||
errors = append(errors, "worker_host cannot be empty or whitespace")
|
||||
}
|
||||
|
||||
if c.WorkerUser == "" {
|
||||
errors = append(errors, "worker_user is required")
|
||||
} else if len(strings.TrimSpace(c.WorkerUser)) == 0 {
|
||||
errors = append(errors, "worker_user cannot be empty or whitespace")
|
||||
}
|
||||
|
||||
if c.WorkerBase == "" {
|
||||
errors = append(errors, "worker_base is required")
|
||||
} else {
|
||||
// Expand and validate path
|
||||
c.WorkerBase = utils.ExpandPath(c.WorkerBase)
|
||||
if !filepath.IsAbs(c.WorkerBase) {
|
||||
errors = append(errors, "worker_base must be an absolute path")
|
||||
}
|
||||
}
|
||||
|
||||
if c.WorkerPort == 0 {
|
||||
errors = append(errors, "worker_port is required")
|
||||
} else if err := utils.ValidatePort(c.WorkerPort); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("invalid worker_port: %v", err))
|
||||
}
|
||||
|
||||
if c.APIKey == "" {
|
||||
errors = append(errors, "api_key is required")
|
||||
} else if len(c.APIKey) < 16 {
|
||||
errors = append(errors, "api_key must be at least 16 characters")
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("validation failed: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthenticateWithServer validates the API key and sets user context
|
||||
func (c *CLIConfig) AuthenticateWithServer() error {
|
||||
if c.APIKey == "" {
|
||||
return fmt.Errorf("no API key configured")
|
||||
}
|
||||
|
||||
// Create temporary auth config for validation
|
||||
authConfig := &auth.AuthConfig{
|
||||
Enabled: true,
|
||||
APIKeys: map[auth.Username]auth.APIKeyEntry{
|
||||
"temp": {
|
||||
Hash: auth.APIKeyHash(auth.HashAPIKey(c.APIKey)),
|
||||
Admin: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Validate API key and get user info
|
||||
user, err := authConfig.ValidateAPIKey(auth.HashAPIKey(c.APIKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("API key validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Set user context
|
||||
c.CurrentUser = &UserContext{
|
||||
Name: user.Name,
|
||||
Admin: user.Admin,
|
||||
Roles: user.Roles,
|
||||
Permissions: user.Permissions,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPermission checks if the current user has a specific permission
|
||||
func (c *CLIConfig) CheckPermission(permission string) bool {
|
||||
if c.CurrentUser == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Admin users have all permissions
|
||||
if c.CurrentUser.Admin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check explicit permission
|
||||
if c.CurrentUser.Permissions[permission] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check wildcard permission
|
||||
if c.CurrentUser.Permissions["*"] {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CanViewJob checks if user can view a specific job
|
||||
func (c *CLIConfig) CanViewJob(jobUserID string) bool {
|
||||
if c.CurrentUser == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Admin can view all jobs
|
||||
if c.CurrentUser.Admin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Users can view their own jobs
|
||||
return jobUserID == c.CurrentUser.Name
|
||||
}
|
||||
|
||||
// CanModifyJob checks if user can modify a specific job
|
||||
func (c *CLIConfig) CanModifyJob(jobUserID string) bool {
|
||||
if c.CurrentUser == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Need jobs:update permission
|
||||
if !c.CheckPermission("jobs:update") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Admin can modify all jobs
|
||||
if c.CurrentUser.Admin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Users can only modify their own jobs
|
||||
return jobUserID == c.CurrentUser.Name
|
||||
}
|
||||
|
||||
// migrateFromYAML migrates configuration from YAML to TOML format.
//
// It reads yamlPath, extracts host/user/base_path/port plus (best-effort) an
// API key entry, validates the result, and writes a TOML file to tomlPath
// (0600). Returns the path written.
//
// NOTE(review): several weaknesses worth confirming with the author:
//   - The "API key" copied below is the stored *hash*, not the original key
//     (see the inline comment), so the migrated config will likely fail
//     authentication against the server.
//   - Go map iteration order is random, so with multiple api_keys entries the
//     one migrated is arbitrary.
//   - Values are interpolated into the TOML template without escaping; a
//     quote character in any field would corrupt the output file.
func migrateFromYAML(yamlPath, tomlPath string) (string, error) {
	// Check if YAML file exists
	if _, err := os.Stat(yamlPath); os.IsNotExist(err) {
		return "", fmt.Errorf("YAML config not found at %s", yamlPath)
	}

	// Read YAML config
	data, err := os.ReadFile(yamlPath)
	if err != nil {
		return "", fmt.Errorf("failed to read YAML config: %w", err)
	}

	// Parse YAML to extract relevant fields
	var yamlConfig map[string]interface{}
	if err := yaml.Unmarshal(data, &yamlConfig); err != nil {
		return "", fmt.Errorf("failed to parse YAML config: %w", err)
	}

	// Create CLI config from YAML data
	cliConfig := &CLIConfig{}

	// Extract values with fallbacks
	if host, ok := yamlConfig["host"].(string); ok {
		cliConfig.WorkerHost = host
	}
	if user, ok := yamlConfig["user"].(string); ok {
		cliConfig.WorkerUser = user
	}
	if base, ok := yamlConfig["base_path"].(string); ok {
		cliConfig.WorkerBase = base
	}
	if port, ok := yamlConfig["port"].(int); ok {
		cliConfig.WorkerPort = port
	}

	// Try to extract API key from auth section
	// NOTE(review): the local `auth` variable below shadows the imported
	// `auth` package within this block.
	if auth, ok := yamlConfig["auth"].(map[string]interface{}); ok {
		if apiKeys, ok := auth["api_keys"].(map[string]interface{}); ok {
			// Map iteration order is random: with multiple entries, which
			// key is migrated is arbitrary.
			for _, keyEntry := range apiKeys {
				if keyMap, ok := keyEntry.(map[string]interface{}); ok {
					if hash, ok := keyMap["hash"].(string); ok {
						cliConfig.APIKey = hash // Note: This is the hash, not the actual key
						break
					}
				}
			}
		}
	}

	// Validate migrated config
	if err := cliConfig.Validate(); err != nil {
		return "", fmt.Errorf("migrated config validation failed: %w", err)
	}

	// Generate TOML content
	// NOTE(review): values are not TOML-escaped before interpolation.
	tomlContent := fmt.Sprintf(`# Fetch ML CLI Configuration
# Migrated from YAML configuration

worker_host = "%s"
worker_user = "%s"
worker_base = "%s"
worker_port = %d
api_key = "%s"
`,
		cliConfig.WorkerHost,
		cliConfig.WorkerUser,
		cliConfig.WorkerBase,
		cliConfig.WorkerPort,
		cliConfig.APIKey,
	)

	// Create directory if it doesn't exist
	if err := os.MkdirAll(filepath.Dir(tomlPath), 0755); err != nil {
		return "", fmt.Errorf("failed to create config directory: %w", err)
	}

	// Write TOML file (0600: config contains a credential)
	if err := os.WriteFile(tomlPath, []byte(tomlContent), 0600); err != nil {
		return "", fmt.Errorf("failed to write TOML config: %w", err)
	}

	return tomlPath, nil
}
|
||||
|
||||
// ConfigExists checks if a CLI configuration file exists
|
||||
func ConfigExists(configPath string) bool {
|
||||
if configPath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
configPath = filepath.Join(home, ".ml", "config.toml")
|
||||
}
|
||||
|
||||
_, err := os.Stat(configPath)
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
// GenerateDefaultConfig creates a default TOML configuration file
|
||||
func GenerateDefaultConfig(configPath string) error {
|
||||
// Create directory if it doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate default configuration
|
||||
defaultContent := `# Fetch ML CLI Configuration
|
||||
# This file contains connection settings for the ML platform
|
||||
|
||||
# Worker connection settings
|
||||
worker_host = "localhost" # Hostname or IP of the worker
|
||||
worker_user = "your_username" # SSH username for the worker
|
||||
worker_base = "~/ml_jobs" # Base directory for ML jobs on worker
|
||||
worker_port = 22 # SSH port (default: 22)
|
||||
|
||||
# Authentication
|
||||
api_key = "your_api_key_here" # Your API key (get from admin)
|
||||
|
||||
# Environment variable overrides:
|
||||
# ML_HOST, ML_USER, ML_BASE, ML_PORT, ML_API_KEY
|
||||
`
|
||||
|
||||
// Write configuration file
|
||||
if err := os.WriteFile(configPath, []byte(defaultContent), 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
|
||||
// Set proper permissions
|
||||
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashAPIKey(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
return auth.HashAPIKey(apiKey)
|
||||
}
|
||||
194
cmd/tui/internal/config/cli_config_test.go
Normal file
194
cmd/tui/internal/config/cli_config_test.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCLIConfig_CheckPermission covers the permission matrix: admin override,
// explicit grant, missing grant, and the unauthenticated (nil user) case.
// NOTE(review): the "*" wildcard-grant path is not covered; consider a case.
func TestCLIConfig_CheckPermission(t *testing.T) {
	tests := []struct {
		name       string
		config     *CLIConfig
		permission string
		want       bool
	}{
		{
			name: "Admin has all permissions",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:  "admin",
					Admin: true,
				},
			},
			permission: "any:permission",
			want:       true,
		},
		{
			name: "User with explicit permission",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:        "user",
					Admin:       false,
					Permissions: map[string]bool{"jobs:create": true},
				},
			},
			permission: "jobs:create",
			want:       true,
		},
		{
			name: "User without permission",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:        "user",
					Admin:       false,
					Permissions: map[string]bool{"jobs:read": true},
				},
			},
			permission: "jobs:create",
			want:       false,
		},
		{
			name: "No current user",
			config: &CLIConfig{
				CurrentUser: nil,
			},
			permission: "jobs:create",
			want:       false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.config.CheckPermission(tt.permission)
			if got != tt.want {
				t.Errorf("CheckPermission() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestCLIConfig_CanViewJob verifies job visibility rules: admins see any
// job, users see only their own, and a nil CurrentUser sees nothing.
func TestCLIConfig_CanViewJob(t *testing.T) {
	tests := []struct {
		name      string
		config    *CLIConfig
		jobUserID string
		want      bool
	}{
		{
			name: "Admin can view any job",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:  "admin",
					Admin: true,
				},
			},
			jobUserID: "other_user",
			want:      true,
		},
		{
			name: "User can view own job",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:  "user1",
					Admin: false,
				},
			},
			jobUserID: "user1",
			want:      true,
		},
		{
			name: "User cannot view other's job",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:  "user1",
					Admin: false,
				},
			},
			jobUserID: "user2",
			want:      false,
		},
		{
			name: "No current user cannot view",
			config: &CLIConfig{
				CurrentUser: nil,
			},
			jobUserID: "user1",
			want:      false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.config.CanViewJob(tt.jobUserID)
			if got != tt.want {
				t.Errorf("CanViewJob() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestCLIConfig_CanModifyJob verifies modification rules: jobs:update is
// required, admins may modify any job, and regular users only their own.
// NOTE(review): the nil-CurrentUser case is not covered here (it is covered
// indirectly via CheckPermission); consider adding it for symmetry.
func TestCLIConfig_CanModifyJob(t *testing.T) {
	tests := []struct {
		name      string
		config    *CLIConfig
		jobUserID string
		want      bool
	}{
		{
			name: "Admin can modify any job",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:        "admin",
					Admin:       true,
					Permissions: map[string]bool{"jobs:update": true},
				},
			},
			jobUserID: "other_user",
			want:      true,
		},
		{
			name: "User with permission can modify own job",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:        "user1",
					Admin:       false,
					Permissions: map[string]bool{"jobs:update": true},
				},
			},
			jobUserID: "user1",
			want:      true,
		},
		{
			name: "User without permission cannot modify",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:        "user1",
					Admin:       false,
					Permissions: map[string]bool{"jobs:read": true},
				},
			},
			jobUserID: "user1",
			want:      false,
		},
		{
			name: "User cannot modify other's job",
			config: &CLIConfig{
				CurrentUser: &UserContext{
					Name:        "user1",
					Admin:       false,
					Permissions: map[string]bool{"jobs:update": true},
				},
			},
			jobUserID: "user2",
			want:      false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.config.CanModifyJob(tt.jobUserID)
			if got != tt.want {
				t.Errorf("CanModifyJob() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
145
cmd/tui/internal/config/config.go
Normal file
145
cmd/tui/internal/config/config.go
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
package config
|
||||
|
||||
import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"

	"github.com/BurntSushi/toml"
	"github.com/jfraeys/fetch_ml/internal/auth"
	utils "github.com/jfraeys/fetch_ml/internal/config"
)
||||
|
||||
// Config holds TUI configuration loaded from TOML (see LoadConfig for
// defaulting and FETCH_ML_TUI_* environment overrides).
type Config struct {
	Host          string `toml:"host"`           // worker host (defaults via smart defaults)
	User          string `toml:"user"`           // remote username
	SSHKey        string `toml:"ssh_key"`        // SSH key path — presumably private key; confirm with services.MLServer
	Port          int    `toml:"port"`           // SSH port; defaults to utils.DefaultSSHPort
	BasePath      string `toml:"base_path"`      // job base directory (see PendingPath etc.)
	WrapperScript string `toml:"wrapper_script"` // deprecated — see note in LoadConfig
	TrainScript   string `toml:"train_script"`   // defaults to utils.DefaultTrainScript
	RedisAddr     string `toml:"redis_addr"`     // Redis address, validated by ValidateRedisAddr
	RedisPassword string `toml:"redis_password"` // optional Redis password
	RedisDB       int    `toml:"redis_db"`       // Redis database index
	KnownHosts    string `toml:"known_hosts"`    // SSH known_hosts path; smart default if empty

	// Authentication
	Auth auth.AuthConfig `toml:"auth"`

	// Podman settings
	PodmanImage        string `toml:"podman_image"`        // container image for jobs
	ContainerWorkspace string `toml:"container_workspace"` // workspace mount inside container
	ContainerResults   string `toml:"container_results"`   // results mount inside container
	GPUAccess          bool   `toml:"gpu_access"`          // whether containers get GPU access
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if _, err := toml.Decode(string(data), &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get smart defaults for current environment
|
||||
smart := utils.GetSmartDefaults()
|
||||
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = utils.DefaultSSHPort
|
||||
}
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = smart.Host()
|
||||
}
|
||||
if cfg.BasePath == "" {
|
||||
cfg.BasePath = smart.BasePath()
|
||||
}
|
||||
// wrapper_script is deprecated - using secure_runner.py directly via Podman
|
||||
if cfg.TrainScript == "" {
|
||||
cfg.TrainScript = utils.DefaultTrainScript
|
||||
}
|
||||
if cfg.RedisAddr == "" {
|
||||
cfg.RedisAddr = smart.RedisAddr()
|
||||
}
|
||||
if cfg.KnownHosts == "" {
|
||||
cfg.KnownHosts = smart.KnownHostsPath()
|
||||
}
|
||||
|
||||
// Apply environment variable overrides with FETCH_ML_TUI_ prefix
|
||||
if host := os.Getenv("FETCH_ML_TUI_HOST"); host != "" {
|
||||
cfg.Host = host
|
||||
}
|
||||
if user := os.Getenv("FETCH_ML_TUI_USER"); user != "" {
|
||||
cfg.User = user
|
||||
}
|
||||
if sshKey := os.Getenv("FETCH_ML_TUI_SSH_KEY"); sshKey != "" {
|
||||
cfg.SSHKey = sshKey
|
||||
}
|
||||
if port := os.Getenv("FETCH_ML_TUI_PORT"); port != "" {
|
||||
if p, err := parseInt(port); err == nil {
|
||||
cfg.Port = p
|
||||
}
|
||||
}
|
||||
if basePath := os.Getenv("FETCH_ML_TUI_BASE_PATH"); basePath != "" {
|
||||
cfg.BasePath = basePath
|
||||
}
|
||||
if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" {
|
||||
cfg.TrainScript = trainScript
|
||||
}
|
||||
if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" {
|
||||
cfg.RedisAddr = redisAddr
|
||||
}
|
||||
if redisPassword := os.Getenv("FETCH_ML_TUI_REDIS_PASSWORD"); redisPassword != "" {
|
||||
cfg.RedisPassword = redisPassword
|
||||
}
|
||||
if redisDB := os.Getenv("FETCH_ML_TUI_REDIS_DB"); redisDB != "" {
|
||||
if db, err := parseInt(redisDB); err == nil {
|
||||
cfg.RedisDB = db
|
||||
}
|
||||
}
|
||||
if knownHosts := os.Getenv("FETCH_ML_TUI_KNOWN_HOSTS"); knownHosts != "" {
|
||||
cfg.KnownHosts = knownHosts
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Validate implements utils.Validator interface
|
||||
func (c *Config) Validate() error {
|
||||
if c.Port != 0 {
|
||||
if err := utils.ValidatePort(c.Port); err != nil {
|
||||
return fmt.Errorf("invalid SSH port: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.BasePath != "" {
|
||||
// Convert relative paths to absolute
|
||||
c.BasePath = utils.ExpandPath(c.BasePath)
|
||||
if !filepath.IsAbs(c.BasePath) {
|
||||
c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath)
|
||||
}
|
||||
}
|
||||
|
||||
if c.RedisAddr != "" {
|
||||
if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
|
||||
return fmt.Errorf("invalid Redis configuration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lifecycle directories under BasePath. Jobs live in exactly one of these at
// a time and are moved between them as they progress (see markFailed in the
// controller, which moves running → failed).
func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") }
func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") }
func (c *Config) FinishedPath() string { return filepath.Join(c.BasePath, "finished") }
func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") }
|
||||
|
||||
// parseInt parses s as a base-10 integer.
//
// Fix: the previous fmt.Sscanf implementation silently accepted trailing
// garbage (e.g. "12abc" parsed as 12); strconv.Atoi rejects such input, so
// malformed FETCH_ML_TUI_PORT / FETCH_ML_TUI_REDIS_DB values are ignored
// instead of being partially applied.
func parseInt(s string) (int, error) {
	return strconv.Atoi(s)
}
|
||||
384
cmd/tui/internal/controller/commands.go
Normal file
384
cmd/tui/internal/controller/commands.go
Normal file
|
|
@ -0,0 +1,384 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
)
|
||||
|
||||
// Message types for async operations, delivered back to the Update loop by
// the tea.Cmd factories below.
type (
	JobsLoadedMsg      []model.Job   // result of loadJobs (all lifecycle states)
	TasksLoadedMsg     []*model.Task // result of loadQueue
	GpuLoadedMsg       string        // formatted GPU status text (loadGPU)
	ContainerLoadedMsg string        // formatted container status text (loadContainer)
	LogLoadedMsg       string        // formatted job log text (loadLog)
	QueueLoadedMsg     string        // formatted queue listing text (showQueue)
	SettingsContentMsg string        // rendered settings panel content
	SettingsUpdateMsg  struct{}      // signal to refresh the settings view
	// StatusMsg is a transient status-bar message.
	StatusMsg struct {
		Text  string
		Level string // "success", "warning", or "error" (as emitted in this file)
	}
	TickMsg time.Time // periodic one-second tick (see tickCmd)
)
|
||||
|
||||
// Command factories for loading data
|
||||
|
||||
func (c *Controller) loadAllData() tea.Cmd {
|
||||
return tea.Batch(
|
||||
c.loadJobs(),
|
||||
c.loadQueue(),
|
||||
c.loadGPU(),
|
||||
c.loadContainer(),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Controller) loadJobs() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
type jobResult struct {
|
||||
jobs []model.Job
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan jobResult, 1)
|
||||
go func() {
|
||||
var jobs []model.Job
|
||||
statusChan := make(chan []model.Job, 4)
|
||||
|
||||
for _, status := range []model.JobStatus{model.StatusPending, model.StatusRunning, model.StatusFinished, model.StatusFailed} {
|
||||
go func(s model.JobStatus) {
|
||||
path := c.getPathForStatus(s)
|
||||
names := c.server.ListDir(path)
|
||||
var statusJobs []model.Job
|
||||
for _, name := range names {
|
||||
jobStatus, _ := c.taskQueue.GetJobStatus(name)
|
||||
taskID := jobStatus["task_id"]
|
||||
priority := int64(0)
|
||||
if p, ok := jobStatus["priority"]; ok {
|
||||
_, err := fmt.Sscanf(p, "%d", &priority)
|
||||
if err != nil {
|
||||
priority = 0
|
||||
}
|
||||
}
|
||||
statusJobs = append(statusJobs, model.Job{
|
||||
Name: name,
|
||||
Status: s,
|
||||
TaskID: taskID,
|
||||
Priority: priority,
|
||||
})
|
||||
}
|
||||
statusChan <- statusJobs
|
||||
}(status)
|
||||
}
|
||||
|
||||
for range 4 {
|
||||
jobs = append(jobs, <-statusChan...)
|
||||
}
|
||||
|
||||
resultChan <- jobResult{jobs: jobs, err: nil}
|
||||
}()
|
||||
|
||||
result := <-resultChan
|
||||
if result.err != nil {
|
||||
return StatusMsg{Text: "Failed to load jobs: " + result.err.Error(), Level: "error"}
|
||||
}
|
||||
return JobsLoadedMsg(result.jobs)
|
||||
}
|
||||
}
|
||||
|
||||
// loadQueue returns a tea.Cmd that fetches the currently queued tasks.
// On success it emits TasksLoadedMsg; on failure an error-level StatusMsg.
func (c *Controller) loadQueue() tea.Cmd {
	return func() tea.Msg {
		tasks, err := c.taskQueue.GetQueuedTasks()
		if err != nil {
			c.logger.Error("failed to load queue", "error", err)
			return StatusMsg{Text: "Failed to load queue: " + err.Error(), Level: "error"}
		}
		c.logger.Info("loaded queue", "task_count", len(tasks))
		return TasksLoadedMsg(tasks)
	}
}
|
||||
|
||||
func (c *Controller) loadGPU() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
type gpuResult struct {
|
||||
content string
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan gpuResult, 1)
|
||||
go func() {
|
||||
cmd := "nvidia-smi --query-gpu=index,name,utilization.gpu,memory.used,memory.total,temperature.gpu --format=csv,noheader,nounits"
|
||||
out, err := c.server.Exec(cmd)
|
||||
if err == nil && strings.TrimSpace(out) != "" {
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString("GPU Status\n")
|
||||
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
for _, line := range lines {
|
||||
parts := strings.Split(line, ", ")
|
||||
if len(parts) >= 6 {
|
||||
formatted.WriteString(fmt.Sprintf("🎮 GPU %s: %s\n", parts[0], parts[1]))
|
||||
formatted.WriteString(fmt.Sprintf(" Utilization: %s%%\n", parts[2]))
|
||||
formatted.WriteString(fmt.Sprintf(" Memory: %s/%s MB\n", parts[3], parts[4]))
|
||||
formatted.WriteString(fmt.Sprintf(" Temperature: %s°C\n\n", parts[5]))
|
||||
}
|
||||
}
|
||||
c.logger.Info("loaded GPU status", "type", "nvidia")
|
||||
resultChan <- gpuResult{content: formatted.String(), err: nil}
|
||||
return
|
||||
}
|
||||
|
||||
cmd = "system_profiler SPDisplaysDataType | grep 'Chipset Model\\|VRAM' | head -2"
|
||||
out, err = c.server.Exec(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn("GPU info unavailable", "error", err)
|
||||
resultChan <- gpuResult{content: "⚠️ GPU info unavailable\n\nRun on a system with nvidia-smi or macOS GPU", err: err}
|
||||
return
|
||||
}
|
||||
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString("GPU Status (macOS)\n")
|
||||
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "Chipset Model") || strings.Contains(line, "VRAM") {
|
||||
formatted.WriteString("🎮 " + strings.TrimSpace(line) + "\n")
|
||||
}
|
||||
}
|
||||
formatted.WriteString("\n💡 Note: nvidia-smi not available on macOS\n")
|
||||
|
||||
c.logger.Info("loaded GPU status", "type", "macos")
|
||||
resultChan <- gpuResult{content: formatted.String(), err: nil}
|
||||
}()
|
||||
|
||||
result := <-resultChan
|
||||
return GpuLoadedMsg(result.content)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadContainer() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
resultChan := make(chan string, 1)
|
||||
go func() {
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString("Container Status\n")
|
||||
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
|
||||
formatted.WriteString("📋 Configuration:\n")
|
||||
formatted.WriteString(fmt.Sprintf(" Image: %s\n", c.config.PodmanImage))
|
||||
formatted.WriteString(fmt.Sprintf(" GPU: %v\n", c.config.GPUAccess))
|
||||
formatted.WriteString(fmt.Sprintf(" Workspace: %s\n", c.config.ContainerWorkspace))
|
||||
formatted.WriteString(fmt.Sprintf(" Results: %s\n\n", c.config.ContainerResults))
|
||||
|
||||
cmd := "podman ps -a --format '{{.Names}}|{{.Status}}|{{.Image}}'"
|
||||
out, err := c.server.Exec(cmd)
|
||||
if err == nil && strings.TrimSpace(out) != "" {
|
||||
formatted.WriteString("🐳 Running Containers (Podman):\n")
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
for _, line := range lines {
|
||||
parts := strings.Split(line, "|")
|
||||
if len(parts) >= 3 {
|
||||
status := "🟢"
|
||||
if strings.Contains(parts[1], "Exited") {
|
||||
status = "🔴"
|
||||
}
|
||||
formatted.WriteString(fmt.Sprintf(" %s %s\n", status, parts[0]))
|
||||
formatted.WriteString(fmt.Sprintf(" Status: %s\n", parts[1]))
|
||||
formatted.WriteString(fmt.Sprintf(" Image: %s\n\n", parts[2]))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cmd = "docker ps -a --format '{{.Names}}|{{.Status}}|{{.Image}}'"
|
||||
out, err = c.server.Exec(cmd)
|
||||
if err == nil && strings.TrimSpace(out) != "" {
|
||||
formatted.WriteString("🐳 Running Containers (Docker):\n")
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
for _, line := range lines {
|
||||
parts := strings.Split(line, "|")
|
||||
if len(parts) >= 3 {
|
||||
status := "🟢"
|
||||
if strings.Contains(parts[1], "Exited") {
|
||||
status = "🔴"
|
||||
}
|
||||
formatted.WriteString(fmt.Sprintf(" %s %s\n", status, parts[0]))
|
||||
formatted.WriteString(fmt.Sprintf(" Status: %s\n", parts[1]))
|
||||
formatted.WriteString(fmt.Sprintf(" Image: %s\n\n", parts[2]))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
formatted.WriteString("⚠️ No containers found\n")
|
||||
}
|
||||
}
|
||||
|
||||
formatted.WriteString("💻 System Info:\n")
|
||||
if podmanVersion, err := c.server.Exec("podman --version"); err == nil {
|
||||
formatted.WriteString(fmt.Sprintf(" Podman: %s\n", strings.TrimSpace(podmanVersion)))
|
||||
} else if dockerVersion, err := c.server.Exec("docker --version"); err == nil {
|
||||
formatted.WriteString(fmt.Sprintf(" Docker: %s\n", strings.TrimSpace(dockerVersion)))
|
||||
} else {
|
||||
formatted.WriteString(" ⚠️ Container engine not available\n")
|
||||
}
|
||||
|
||||
c.logger.Info("loaded container status")
|
||||
resultChan <- formatted.String()
|
||||
}()
|
||||
|
||||
return ContainerLoadedMsg(<-resultChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadLog(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
resultChan := make(chan string, 1)
|
||||
go func() {
|
||||
statusChan := make(chan string, 3)
|
||||
|
||||
for _, status := range []model.JobStatus{model.StatusRunning, model.StatusFinished, model.StatusFailed} {
|
||||
go func(s model.JobStatus) {
|
||||
logPath := filepath.Join(c.getPathForStatus(s), jobName, "output.log")
|
||||
if c.server.RemoteExists(logPath) {
|
||||
content := c.server.TailFile(logPath, 200)
|
||||
statusChan <- content
|
||||
} else {
|
||||
statusChan <- ""
|
||||
}
|
||||
}(status)
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
result := <-statusChan
|
||||
if result != "" {
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString(fmt.Sprintf("📋 Log: %s\n", jobName))
|
||||
formatted.WriteString(strings.Repeat("═", 60) + "\n\n")
|
||||
formatted.WriteString(result)
|
||||
resultChan <- formatted.String()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resultChan <- fmt.Sprintf("⚠️ No log found for %s\n\nJob may not have started yet.", jobName)
|
||||
}()
|
||||
|
||||
return LogLoadedMsg(<-resultChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) queueJob(jobName string, args string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
resultChan := make(chan StatusMsg, 1)
|
||||
go func() {
|
||||
priority := int64(5)
|
||||
if strings.Contains(args, "--priority") {
|
||||
_, err := fmt.Sscanf(args, "--priority %d", &priority)
|
||||
if err != nil {
|
||||
c.logger.Error("invalid priority argument", "args", args, "error", err)
|
||||
resultChan <- StatusMsg{
|
||||
Text: fmt.Sprintf("Invalid priority: %v", err),
|
||||
Level: "error",
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
task, err := c.taskQueue.EnqueueTask(jobName, args, priority)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to queue job", "job_name", jobName, "error", err)
|
||||
resultChan <- StatusMsg{
|
||||
Text: fmt.Sprintf("Failed to queue %s: %v", jobName, err),
|
||||
Level: "error",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Info("job queued", "job_name", jobName, "task_id", task.ID[:8], "priority", priority)
|
||||
resultChan <- StatusMsg{
|
||||
Text: fmt.Sprintf("✓ Queued: %s (ID: %s, P:%d)", jobName, task.ID[:8], priority),
|
||||
Level: "success",
|
||||
}
|
||||
}()
|
||||
|
||||
return <-resultChan
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) deleteJob(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
jobPath := filepath.Join(c.config.PendingPath(), jobName)
|
||||
if _, err := c.server.Exec(fmt.Sprintf("rm -rf %s", jobPath)); err != nil {
|
||||
return StatusMsg{Text: fmt.Sprintf("Failed to delete %s: %v", jobName, err), Level: "error"}
|
||||
}
|
||||
return StatusMsg{Text: fmt.Sprintf("✓ Deleted: %s", jobName), Level: "success"}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) markFailed(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
src := filepath.Join(c.config.RunningPath(), jobName)
|
||||
dst := filepath.Join(c.config.FailedPath(), jobName)
|
||||
if _, err := c.server.Exec(fmt.Sprintf("mv %s %s", src, dst)); err != nil {
|
||||
return StatusMsg{Text: fmt.Sprintf("Failed to mark failed: %v", err), Level: "error"}
|
||||
}
|
||||
return StatusMsg{Text: fmt.Sprintf("⚠ Marked failed: %s", jobName), Level: "warning"}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) cancelTask(taskID string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if err := c.taskQueue.CancelTask(taskID); err != nil {
|
||||
c.logger.Error("failed to cancel task", "task_id", taskID[:8], "error", err)
|
||||
return StatusMsg{Text: fmt.Sprintf("Cancel failed: %v", err), Level: "error"}
|
||||
}
|
||||
c.logger.Info("task cancelled", "task_id", taskID[:8])
|
||||
return StatusMsg{Text: fmt.Sprintf("✓ Cancelled: %s", taskID[:8]), Level: "success"}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) showQueue(m model.State) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
var content strings.Builder
|
||||
content.WriteString("Task Queue\n")
|
||||
content.WriteString(strings.Repeat("═", 60) + "\n\n")
|
||||
|
||||
if len(m.QueuedTasks) == 0 {
|
||||
content.WriteString("📭 No tasks in queue\n")
|
||||
} else {
|
||||
for i, task := range m.QueuedTasks {
|
||||
statusIcon := "⏳"
|
||||
if task.Status == "running" {
|
||||
statusIcon = "▶"
|
||||
}
|
||||
|
||||
content.WriteString(fmt.Sprintf("%d. %s %s [ID: %s]\n",
|
||||
i+1, statusIcon, task.JobName, task.ID[:8]))
|
||||
content.WriteString(fmt.Sprintf(" Priority: %d | Status: %s\n",
|
||||
task.Priority, task.Status))
|
||||
if task.Args != "" {
|
||||
content.WriteString(fmt.Sprintf(" Args: %s\n", task.Args))
|
||||
}
|
||||
content.WriteString(fmt.Sprintf(" Created: %s\n",
|
||||
task.CreatedAt.Format("2006-01-02 15:04:05")))
|
||||
|
||||
if task.StartedAt != nil {
|
||||
duration := time.Since(*task.StartedAt)
|
||||
content.WriteString(fmt.Sprintf(" Running for: %s\n",
|
||||
duration.Round(time.Second)))
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return QueueLoadedMsg(content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func tickCmd() tea.Cmd {
|
||||
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
|
||||
return TickMsg(t)
|
||||
})
|
||||
}
|
||||
302
cmd/tui/internal/controller/controller.go
Normal file
302
cmd/tui/internal/controller/controller.go
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/services"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// Controller handles all business logic and state updates
type Controller struct {
	config    *config.Config      // resolved TUI configuration (paths, images, Redis)
	server    *services.MLServer  // remote worker access (Exec, ListDir, TailFile, RemoteExists)
	taskQueue *services.TaskQueue // task queue client (enqueue/cancel/status)
	logger    *logging.Logger     // structured logger
}
|
||||
|
||||
// New creates a new Controller instance
|
||||
func New(cfg *config.Config, srv *services.MLServer, tq *services.TaskQueue, logger *logging.Logger) *Controller {
|
||||
return &Controller{
|
||||
config: cfg,
|
||||
server: srv,
|
||||
taskQueue: tq,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the TUI and returns initial commands
|
||||
func (c *Controller) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
tea.SetWindowTitle("FetchML"),
|
||||
c.loadAllData(),
|
||||
tickCmd(),
|
||||
)
|
||||
}
|
||||
|
||||
// Update handles all messages and updates the state. It is the single
// bubbletea dispatch point: key handling (input mode, settings mode, then
// global bindings), window resizing, data-loaded messages, status/tick
// messages, and finally forwarding the message to the bubble components.
//
// NOTE(review): the trailing component-forwarding section updates JobList,
// GpuView, ContainerView, QueueView and ExperimentsView but not SettingsView
// or DatasetView — confirm whether that omission is intentional.
func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
	var cmds []tea.Cmd

	switch msg := msg.(type) {
	case tea.KeyMsg:
		// Handle input mode (for queuing jobs with args)
		if m.InputMode {
			switch msg.String() {
			case "enter":
				// Queue the selected job with the typed args, then leave input mode.
				args := m.Input.Value()
				m.Input.SetValue("")
				m.InputMode = false
				if job := getSelectedJob(m); job != nil {
					cmds = append(cmds, c.queueJob(job.Name, args))
				}
				return m, tea.Batch(cmds...)
			case "esc":
				// Abandon input mode without queuing anything.
				m.InputMode = false
				m.Input.SetValue("")
				return m, nil
			}
			// Any other key is typed into the args field.
			var cmd tea.Cmd
			m.Input, cmd = m.Input.Update(msg)
			return m, cmd
		}

		// Handle settings-specific keys
		if m.ActiveView == model.ViewModeSettings {
			switch msg.String() {
			case "up", "k":
				if m.SettingsIndex > 1 { // Skip index 0 (Status)
					m.SettingsIndex--
					cmds = append(cmds, c.updateSettingsContent(m))
					if m.SettingsIndex == 1 {
						m.ApiKeyInput.Focus()
					} else {
						m.ApiKeyInput.Blur()
					}
				}
			case "down", "j":
				if m.SettingsIndex < 2 {
					m.SettingsIndex++
					cmds = append(cmds, c.updateSettingsContent(m))
					if m.SettingsIndex == 1 {
						m.ApiKeyInput.Focus()
					} else {
						m.ApiKeyInput.Blur()
					}
				}
			case "enter":
				if cmd := c.handleSettingsAction(&m); cmd != nil {
					cmds = append(cmds, cmd)
				}
			case "esc":
				m.ActiveView = model.ViewModeJobs
				m.ApiKeyInput.Blur()
			}
			if m.SettingsIndex == 1 { // API Key input field
				var cmd tea.Cmd
				m.ApiKeyInput, cmd = m.ApiKeyInput.Update(msg)
				cmds = append(cmds, cmd)
				// Force update settings view to show typed characters immediately
				cmds = append(cmds, c.updateSettingsContent(m))
			}
			return m, tea.Batch(cmds...)
		}

		// Handle global keys
		switch {
		case key.Matches(msg, m.Keys.Quit):
			return m, tea.Quit
		case key.Matches(msg, m.Keys.Refresh):
			m.IsLoading = true
			m.Status = "Refreshing all data..."
			m.LastRefresh = time.Now()
			cmds = append(cmds, c.loadAllData())
		case key.Matches(msg, m.Keys.RefreshGPU):
			m.Status = "Refreshing GPU status..."
			cmds = append(cmds, c.loadGPU())
		case key.Matches(msg, m.Keys.Trigger):
			// Queue the selected job immediately, with no extra args.
			if job := getSelectedJob(m); job != nil {
				cmds = append(cmds, c.queueJob(job.Name, ""))
			}
		case key.Matches(msg, m.Keys.TriggerArgs):
			// Enter input mode so the user can type args before queuing.
			if job := getSelectedJob(m); job != nil {
				m.InputMode = true
				m.Input.Focus()
			}
		case key.Matches(msg, m.Keys.ViewQueue):
			m.ActiveView = model.ViewModeQueue
			cmds = append(cmds, c.showQueue(m))
		case key.Matches(msg, m.Keys.ViewContainer):
			m.ActiveView = model.ViewModeContainer
			cmds = append(cmds, c.loadContainer())
		case key.Matches(msg, m.Keys.ViewGPU):
			m.ActiveView = model.ViewModeGPU
			cmds = append(cmds, c.loadGPU())
		case key.Matches(msg, m.Keys.ViewJobs):
			m.ActiveView = model.ViewModeJobs
		case key.Matches(msg, m.Keys.ViewSettings):
			m.ActiveView = model.ViewModeSettings
			m.SettingsIndex = 1 // Start at Input field, skip Status
			m.ApiKeyInput.Focus()
			cmds = append(cmds, c.updateSettingsContent(m))
		case key.Matches(msg, m.Keys.ViewExperiments):
			m.ActiveView = model.ViewModeExperiments
			cmds = append(cmds, c.loadExperiments())
		case key.Matches(msg, m.Keys.Cancel):
			// Cancel only makes sense for jobs that have a queued task ID.
			if job := getSelectedJob(m); job != nil && job.TaskID != "" {
				cmds = append(cmds, c.cancelTask(job.TaskID))
			}
		case key.Matches(msg, m.Keys.Delete):
			if job := getSelectedJob(m); job != nil && job.Status == model.StatusPending {
				cmds = append(cmds, c.deleteJob(job.Name))
			}
		case key.Matches(msg, m.Keys.MarkFailed):
			if job := getSelectedJob(m); job != nil && job.Status == model.StatusRunning {
				cmds = append(cmds, c.markFailed(job.Name))
			}
		case key.Matches(msg, m.Keys.Help):
			m.ShowHelp = !m.ShowHelp
		}

	case tea.WindowSizeMsg:
		m.Width = msg.Width
		m.Height = msg.Height

		// Update component sizes
		h, v := 4, 2 // docStyle.GetFrameSize() approx
		listHeight := msg.Height - v - 8
		m.JobList.SetSize(msg.Width/3-h, listHeight)

		panelWidth := msg.Width*2/3 - h - 2
		panelHeight := (listHeight - 6) / 3

		m.GpuView.Width = panelWidth
		m.GpuView.Height = panelHeight
		m.ContainerView.Width = panelWidth
		m.ContainerView.Height = panelHeight
		m.QueueView.Width = panelWidth
		m.QueueView.Height = listHeight - 4
		m.SettingsView.Width = panelWidth
		m.SettingsView.Height = listHeight - 4
		m.ExperimentsView.Width = panelWidth
		m.ExperimentsView.Height = listHeight - 4

	case JobsLoadedMsg:
		m.Jobs = []model.Job(msg)
		calculateJobStats(&m)
		items := make([]list.Item, len(m.Jobs))
		for i, job := range m.Jobs {
			items[i] = job
		}
		cmds = append(cmds, m.JobList.SetItems(items))
		m.Status = formatStatus(m)
		m.IsLoading = false

	case TasksLoadedMsg:
		m.QueuedTasks = []*model.Task(msg)
		m.Status = formatStatus(m)

	case GpuLoadedMsg:
		m.GpuView.SetContent(string(msg))
		m.GpuView.GotoTop()

	case ContainerLoadedMsg:
		m.ContainerView.SetContent(string(msg))
		m.ContainerView.GotoTop()

	case QueueLoadedMsg:
		m.QueueView.SetContent(string(msg))
		m.QueueView.GotoTop()

	case SettingsContentMsg:
		m.SettingsView.SetContent(string(msg))

	case ExperimentsLoadedMsg:
		m.ExperimentsView.SetContent(string(msg))
		m.ExperimentsView.GotoTop()

	case SettingsUpdateMsg:
		// Settings content was updated, just trigger a re-render

	case StatusMsg:
		if msg.Level == "error" {
			m.ErrorMsg = msg.Text
			m.Status = "Error occurred - check status"
		} else {
			m.ErrorMsg = ""
			m.Status = msg.Text
		}

	case TickMsg:
		var spinCmd tea.Cmd
		m.Spinner, spinCmd = m.Spinner.Update(msg)
		cmds = append(cmds, spinCmd)

		// Auto-refresh every 10 seconds
		if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
			m.LastRefresh = time.Now()
			cmds = append(cmds, c.loadAllData())
		}
		cmds = append(cmds, tickCmd())

	default:
		var spinCmd tea.Cmd
		m.Spinner, spinCmd = m.Spinner.Update(msg)
		cmds = append(cmds, spinCmd)
	}

	// Update all bubble components
	var cmd tea.Cmd
	m.JobList, cmd = m.JobList.Update(msg)
	cmds = append(cmds, cmd)

	m.GpuView, cmd = m.GpuView.Update(msg)
	cmds = append(cmds, cmd)

	m.ContainerView, cmd = m.ContainerView.Update(msg)
	cmds = append(cmds, cmd)

	m.QueueView, cmd = m.QueueView.Update(msg)
	cmds = append(cmds, cmd)

	m.ExperimentsView, cmd = m.ExperimentsView.Update(msg)
	cmds = append(cmds, cmd)

	return m, tea.Batch(cmds...)
}
|
||||
|
||||
// ExperimentsLoadedMsg is sent when experiments are loaded; it carries the
// pre-rendered text that Update places into the experiments viewport.
type ExperimentsLoadedMsg string
|
||||
|
||||
func (c *Controller) loadExperiments() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
commitIDs, err := c.taskQueue.ListExperiments()
|
||||
if err != nil {
|
||||
return StatusMsg{Level: "error", Text: fmt.Sprintf("Failed to list experiments: %v", err)}
|
||||
}
|
||||
|
||||
if len(commitIDs) == 0 {
|
||||
return ExperimentsLoadedMsg("Experiments:\n\nNo experiments found.")
|
||||
}
|
||||
|
||||
var output string
|
||||
output += "Experiments:\n\n"
|
||||
|
||||
for _, commitID := range commitIDs {
|
||||
details, err := c.taskQueue.GetExperimentDetails(commitID)
|
||||
if err != nil {
|
||||
output += fmt.Sprintf("Error loading %s: %v\n\n", commitID, err)
|
||||
continue
|
||||
}
|
||||
output += details + "\n----------------------------------------\n\n"
|
||||
}
|
||||
|
||||
return ExperimentsLoadedMsg(output)
|
||||
}
|
||||
}
|
||||
69
cmd/tui/internal/controller/helpers.go
Normal file
69
cmd/tui/internal/controller/helpers.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
)
|
||||
|
||||
// Helper functions
|
||||
|
||||
func (c *Controller) getPathForStatus(status model.JobStatus) string {
|
||||
switch status {
|
||||
case model.StatusPending:
|
||||
return c.config.PendingPath()
|
||||
case model.StatusRunning:
|
||||
return c.config.RunningPath()
|
||||
case model.StatusFinished:
|
||||
return c.config.FinishedPath()
|
||||
case model.StatusFailed:
|
||||
return c.config.FailedPath()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getSelectedJob(m model.State) *model.Job {
|
||||
if item := m.JobList.SelectedItem(); item != nil {
|
||||
if job, ok := item.(model.Job); ok {
|
||||
return &job
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateJobStats(m *model.State) {
|
||||
m.JobStats = make(map[model.JobStatus]int)
|
||||
for _, job := range m.Jobs {
|
||||
m.JobStats[job.Status]++
|
||||
}
|
||||
}
|
||||
|
||||
func formatStatus(m model.State) string {
|
||||
var parts []string
|
||||
|
||||
if len(m.Jobs) > 0 {
|
||||
stats := []string{}
|
||||
if count := m.JobStats[model.StatusPending]; count > 0 {
|
||||
stats = append(stats, fmt.Sprintf("⏸ %d", count))
|
||||
}
|
||||
if count := m.JobStats[model.StatusRunning]; count > 0 {
|
||||
stats = append(stats, fmt.Sprintf("▶ %d", count))
|
||||
}
|
||||
if count := m.JobStats[model.StatusFinished]; count > 0 {
|
||||
stats = append(stats, fmt.Sprintf("✓ %d", count))
|
||||
}
|
||||
if count := m.JobStats[model.StatusFailed]; count > 0 {
|
||||
stats = append(stats, fmt.Sprintf("✗ %d", count))
|
||||
}
|
||||
parts = append(parts, strings.Join(stats, " | "))
|
||||
}
|
||||
|
||||
if len(m.QueuedTasks) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("Queue: %d", len(m.QueuedTasks)))
|
||||
}
|
||||
|
||||
parts = append(parts, fmt.Sprintf("Updated: %s", m.LastRefresh.Format("15:04:05")))
|
||||
|
||||
return strings.Join(parts, " • ")
|
||||
}
|
||||
126
cmd/tui/internal/controller/settings.go
Normal file
126
cmd/tui/internal/controller/settings.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
)
|
||||
|
||||
// Settings-related command factories and handlers
|
||||
|
||||
func (c *Controller) updateSettingsContent(m model.State) tea.Cmd {
|
||||
var content strings.Builder
|
||||
|
||||
// API Key Status section
|
||||
statusStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.NormalBorder()).
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}). // borderfg
|
||||
Padding(0, 1).
|
||||
Width(m.SettingsView.Width - 4)
|
||||
|
||||
if m.SettingsIndex == 0 {
|
||||
statusStyle = statusStyle.
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"}) // activeBorderfg
|
||||
}
|
||||
|
||||
statusContent := fmt.Sprintf("%s API Key Status\n%s",
|
||||
getSettingsIndicator(m, 0),
|
||||
getAPIKeyStatus(m))
|
||||
content.WriteString(statusStyle.Render(statusContent))
|
||||
content.WriteString("\n")
|
||||
|
||||
// API Key Input section
|
||||
inputStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.NormalBorder()).
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}).
|
||||
Padding(0, 1).
|
||||
Width(m.SettingsView.Width - 4)
|
||||
|
||||
if m.SettingsIndex == 1 {
|
||||
inputStyle = inputStyle.
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"})
|
||||
}
|
||||
|
||||
inputContent := fmt.Sprintf("%s Enter New API Key\n%s",
|
||||
getSettingsIndicator(m, 1),
|
||||
m.ApiKeyInput.View())
|
||||
content.WriteString(inputStyle.Render(inputContent))
|
||||
content.WriteString("\n")
|
||||
|
||||
// Save Configuration section
|
||||
saveStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.NormalBorder()).
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "#d8dee9", Dark: "#4c566a"}).
|
||||
Padding(0, 1).
|
||||
Width(m.SettingsView.Width - 4)
|
||||
|
||||
if m.SettingsIndex == 2 {
|
||||
saveStyle = saveStyle.
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "#3498db", Dark: "#7aa2f7"})
|
||||
}
|
||||
|
||||
saveContent := fmt.Sprintf("%s Save Configuration\n[Enter]",
|
||||
getSettingsIndicator(m, 2))
|
||||
content.WriteString(saveStyle.Render(saveContent))
|
||||
content.WriteString("\n")
|
||||
|
||||
// Current API Key display
|
||||
keyStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#666", Dark: "#999"}).
|
||||
Italic(true)
|
||||
|
||||
keyContent := fmt.Sprintf("Current API Key: %s", maskAPIKey(m.ApiKey))
|
||||
content.WriteString(keyStyle.Render(keyContent))
|
||||
|
||||
return func() tea.Msg { return SettingsContentMsg(content.String()) }
|
||||
}
|
||||
|
||||
func (c *Controller) handleSettingsAction(m *model.State) tea.Cmd {
|
||||
switch m.SettingsIndex {
|
||||
case 0: // API Key Status - do nothing
|
||||
return nil
|
||||
case 1: // Enter New API Key - do nothing, Enter key disabled
|
||||
return nil
|
||||
case 2: // Save Configuration
|
||||
if m.ApiKeyInput.Value() != "" {
|
||||
m.ApiKey = m.ApiKeyInput.Value()
|
||||
m.ApiKeyInput.SetValue("")
|
||||
m.Status = "Configuration saved (in-memory only)"
|
||||
return c.updateSettingsContent(*m)
|
||||
} else if m.ApiKey != "" {
|
||||
m.Status = "Configuration saved (in-memory only)"
|
||||
} else {
|
||||
m.ErrorMsg = "No API key to save"
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions for settings
|
||||
|
||||
func getSettingsIndicator(m model.State, index int) string {
|
||||
if index == m.SettingsIndex {
|
||||
return "▶"
|
||||
}
|
||||
return " "
|
||||
}
|
||||
|
||||
func getAPIKeyStatus(m model.State) string {
|
||||
if m.ApiKey != "" {
|
||||
return "✓ API Key is set\n" + maskAPIKey(m.ApiKey)
|
||||
}
|
||||
return "⚠ No API Key configured"
|
||||
}
|
||||
|
||||
// maskAPIKey renders an API key safely for display: "(not set)" for an empty
// key, a fully masked "****" for keys of 8 characters or fewer, and
// first-four + "****" + last-four otherwise.
func maskAPIKey(key string) string {
	switch n := len(key); {
	case n == 0:
		return "(not set)"
	case n <= 8:
		return "****"
	default:
		return key[:4] + "****" + key[n-4:]
	}
}
|
||||
206
cmd/tui/internal/model/state.go
Normal file
206
cmd/tui/internal/model/state.go
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// ViewMode identifies which right-hand panel the TUI is currently showing.
type ViewMode int

const (
	ViewModeJobs ViewMode = iota // default view: the job list
	ViewModeGPU
	ViewModeQueue
	ViewModeContainer
	ViewModeSettings
	ViewModeDatasets
	ViewModeExperiments
)
|
||||
|
||||
// JobStatus is the lifecycle state of a job as displayed in the TUI.
type JobStatus string

const (
	StatusPending  JobStatus = "pending"
	StatusQueued   JobStatus = "queued"
	StatusRunning  JobStatus = "running"
	StatusFinished JobStatus = "finished"
	StatusFailed   JobStatus = "failed"
)

// Job is one entry in the TUI job list. Title, Description and FilterValue
// satisfy the bubbles list-item contract.
type Job struct {
	Name     string
	Status   JobStatus
	TaskID   string
	Priority int64
}

// Title returns the list entry's primary label.
func (j Job) Title() string { return j.Name }

// Description renders "<icon> <status>" plus an optional " [P<n>]" suffix
// when the job carries a positive priority.
func (j Job) Description() string {
	var icon string
	switch j.Status {
	case StatusPending:
		icon = "⏸"
	case StatusQueued:
		icon = "⏳"
	case StatusRunning:
		icon = "▶"
	case StatusFinished:
		icon = "✓"
	case StatusFailed:
		icon = "✗"
	}
	suffix := ""
	if j.Priority > 0 {
		suffix = fmt.Sprintf(" [P%d]", j.Priority)
	}
	return fmt.Sprintf("%s %s%s", icon, j.Status, suffix)
}

// FilterValue makes list filtering match on the job name.
func (j Job) FilterValue() string { return j.Name }
|
||||
|
||||
// Task is the TUI-side representation of a queued unit of work, mirrored
// from the internal queue task (see services.TaskQueue conversions).
type Task struct {
	ID        string            `json:"id"`
	JobName   string            `json:"job_name"`
	Args      string            `json:"args"`
	Status    string            `json:"status"`
	Priority  int64             `json:"priority"`
	CreatedAt time.Time         `json:"created_at"`
	StartedAt *time.Time        `json:"started_at,omitempty"` // nil until the task starts running
	EndedAt   *time.Time        `json:"ended_at,omitempty"`   // nil until the task finishes
	Error     string            `json:"error,omitempty"`
	Metadata  map[string]string `json:"metadata,omitempty"`
}
|
||||
|
||||
// DatasetInfo summarizes a dataset for the datasets view.
type DatasetInfo struct {
	Name       string    `json:"name"`
	SizeBytes  int64     `json:"size_bytes"`
	Location   string    `json:"location"`
	LastAccess time.Time `json:"last_access"`
}
|
||||
|
||||
// State holds the application state: loaded data, the bubble components for
// each panel, and UI bookkeeping. It is passed by value through the
// bubbletea Update loop and rebuilt on each message.
type State struct {
	// Data loaded from the backend.
	Jobs        []Job
	QueuedTasks []*Task
	Datasets    []DatasetInfo
	// Bubble components, one per panel.
	JobList         list.Model
	GpuView         viewport.Model
	ContainerView   viewport.Model
	QueueView       viewport.Model
	SettingsView    viewport.Model
	DatasetView     viewport.Model
	ExperimentsView viewport.Model
	Input           textinput.Model // args entry when queuing a job
	ApiKeyInput     textinput.Model // API key entry in the settings panel
	// UI bookkeeping.
	Status        string // status-bar text
	ErrorMsg      string // non-empty when the error bar should show
	InputMode     bool   // true while the args input captures keys
	Width         int
	Height        int
	ShowHelp      bool
	Spinner       spinner.Model
	ActiveView    ViewMode
	LastRefresh   time.Time
	IsLoading     bool
	JobStats      map[JobStatus]int // per-status counts, kept by calculateJobStats
	ApiKey        string            // current API key (in-memory only)
	SettingsIndex int               // selected row in the settings panel (0=status, 1=input, 2=save)
	Keys          KeyMap
}
|
||||
|
||||
// KeyMap declares every key binding the TUI responds to; the help view is
// derived from each binding's metadata.
type KeyMap struct {
	Refresh         key.Binding
	Trigger         key.Binding
	TriggerArgs     key.Binding
	ViewQueue       key.Binding
	ViewContainer   key.Binding
	ViewGPU         key.Binding
	ViewJobs        key.Binding
	ViewDatasets    key.Binding
	ViewExperiments key.Binding
	ViewSettings    key.Binding
	Cancel          key.Binding
	Delete          key.Binding
	MarkFailed      key.Binding
	RefreshGPU      key.Binding
	Help            key.Binding
	Quit            key.Binding
}
|
||||
|
||||
// Keys is the default binding set; InitialState wires it into every State.
var Keys = KeyMap{
	Refresh:         key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
	Trigger:         key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
	TriggerArgs:     key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
	ViewQueue:       key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")),
	ViewContainer:   key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
	ViewGPU:         key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
	ViewJobs:        key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
	ViewDatasets:    key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
	ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
	Cancel:          key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")),
	Delete:          key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
	MarkFailed:      key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
	RefreshGPU:      key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
	ViewSettings:    key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
	Help:            key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
	Quit:            key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")),
}
|
||||
|
||||
func InitialState(apiKey string) State {
|
||||
items := []list.Item{}
|
||||
delegate := list.NewDefaultDelegate()
|
||||
delegate.Styles.SelectedTitle = delegate.Styles.SelectedTitle.
|
||||
Foreground(lipgloss.Color("170")).
|
||||
Bold(true)
|
||||
delegate.Styles.SelectedDesc = delegate.Styles.SelectedDesc.
|
||||
Foreground(lipgloss.Color("246"))
|
||||
|
||||
jobList := list.New(items, delegate, 0, 0)
|
||||
jobList.Title = "ML Jobs & Queue"
|
||||
jobList.SetShowStatusBar(true)
|
||||
jobList.SetFilteringEnabled(true)
|
||||
jobList.SetShowHelp(false)
|
||||
// Styles will be set in View or here?
|
||||
// Keeping style initialization here as it's part of the model state setup
|
||||
jobList.Styles.Title = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#2980b9", Dark: "#7aa2f7"}).
|
||||
Padding(0, 0, 1, 0)
|
||||
|
||||
input := textinput.New()
|
||||
input.Placeholder = "Args: --epochs 100 --lr 0.001 --priority 5"
|
||||
input.Width = 60
|
||||
input.CharLimit = 200
|
||||
|
||||
apiKeyInput := textinput.New()
|
||||
apiKeyInput.Placeholder = "Enter API key..."
|
||||
apiKeyInput.Width = 40
|
||||
apiKeyInput.CharLimit = 200
|
||||
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Dot
|
||||
s.Style = lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#2980b9", Dark: "#7aa2f7"})
|
||||
|
||||
return State{
|
||||
JobList: jobList,
|
||||
GpuView: viewport.New(0, 0),
|
||||
ContainerView: viewport.New(0, 0),
|
||||
QueueView: viewport.New(0, 0),
|
||||
SettingsView: viewport.New(0, 0),
|
||||
DatasetView: viewport.New(0, 0),
|
||||
ExperimentsView: viewport.New(0, 0),
|
||||
Input: input,
|
||||
ApiKeyInput: apiKeyInput,
|
||||
Status: "Connected",
|
||||
InputMode: false,
|
||||
ShowHelp: false,
|
||||
Spinner: s,
|
||||
ActiveView: ViewModeJobs,
|
||||
LastRefresh: time.Now(),
|
||||
IsLoading: false,
|
||||
JobStats: make(map[JobStatus]int),
|
||||
ApiKey: apiKey,
|
||||
SettingsIndex: 0,
|
||||
Keys: Keys,
|
||||
}
|
||||
}
|
||||
237
cmd/tui/internal/services/services.go
Normal file
237
cmd/tui/internal/services/services.go
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
package services
|
||||
|
||||
import (
	"context"
	"fmt"
	"strings"

	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
)
|
||||
|
||||
// TaskQueue wraps the internal queue.TaskQueue for TUI compatibility,
// converting between queue.Task and model.Task and exposing experiment data
// through the bundled experiment manager.
type TaskQueue struct {
	internal   *queue.TaskQueue    // Redis-backed queue (see NewTaskQueue)
	expManager *experiment.Manager // experiment metadata/metrics store
	ctx        context.Context     // NOTE(review): stored context is unused in the visible methods — confirm need
}
|
||||
|
||||
func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
|
||||
// Create internal queue config
|
||||
queueCfg := queue.Config{
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
}
|
||||
|
||||
internalQueue, err := queue.NewTaskQueue(queueCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create task queue: %w", err)
|
||||
}
|
||||
|
||||
// Initialize experiment manager
|
||||
// TODO: Get base path from config
|
||||
expManager := experiment.NewManager("./experiments")
|
||||
|
||||
return &TaskQueue{
|
||||
internal: internalQueue,
|
||||
expManager: expManager,
|
||||
ctx: context.Background(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) EnqueueTask(jobName, args string, priority int64) (*model.Task, error) {
|
||||
// Create internal task
|
||||
internalTask := &queue.Task{
|
||||
JobName: jobName,
|
||||
Args: args,
|
||||
Priority: priority,
|
||||
}
|
||||
|
||||
// Use internal queue to enqueue
|
||||
err := tq.internal.AddTask(internalTask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to TUI model
|
||||
return &model.Task{
|
||||
ID: internalTask.ID,
|
||||
JobName: internalTask.JobName,
|
||||
Args: internalTask.Args,
|
||||
Status: "queued",
|
||||
Priority: int64(internalTask.Priority),
|
||||
CreatedAt: internalTask.CreatedAt,
|
||||
Metadata: internalTask.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) GetNextTask() (*model.Task, error) {
|
||||
internalTask, err := tq.internal.GetNextTask()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if internalTask == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Convert to TUI model
|
||||
return &model.Task{
|
||||
ID: internalTask.ID,
|
||||
JobName: internalTask.JobName,
|
||||
Args: internalTask.Args,
|
||||
Status: internalTask.Status,
|
||||
Priority: internalTask.Priority,
|
||||
CreatedAt: internalTask.CreatedAt,
|
||||
Metadata: internalTask.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) GetTask(taskID string) (*model.Task, error) {
|
||||
internalTask, err := tq.internal.GetTask(taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to TUI model
|
||||
return &model.Task{
|
||||
ID: internalTask.ID,
|
||||
JobName: internalTask.JobName,
|
||||
Args: internalTask.Args,
|
||||
Status: internalTask.Status,
|
||||
Priority: internalTask.Priority,
|
||||
CreatedAt: internalTask.CreatedAt,
|
||||
Metadata: internalTask.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) UpdateTask(task *model.Task) error {
|
||||
// Convert to internal task
|
||||
internalTask := &queue.Task{
|
||||
ID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Args: task.Args,
|
||||
Status: task.Status,
|
||||
Priority: task.Priority,
|
||||
CreatedAt: task.CreatedAt,
|
||||
Metadata: task.Metadata,
|
||||
}
|
||||
|
||||
return tq.internal.UpdateTask(internalTask)
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) GetQueuedTasks() ([]*model.Task, error) {
|
||||
internalTasks, err := tq.internal.GetAllTasks()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert to TUI models
|
||||
tasks := make([]*model.Task, len(internalTasks))
|
||||
for i, task := range internalTasks {
|
||||
tasks[i] = &model.Task{
|
||||
ID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Args: task.Args,
|
||||
Status: task.Status,
|
||||
Priority: task.Priority,
|
||||
CreatedAt: task.CreatedAt,
|
||||
Metadata: task.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue) GetJobStatus(jobName string) (map[string]string, error) {
|
||||
// This method doesn't exist in internal queue, implement basic version
|
||||
task, err := tq.internal.GetTaskByName(jobName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if task == nil {
|
||||
return map[string]string{"status": "not_found"}, nil
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"status": task.Status,
|
||||
"task_id": task.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RecordMetric forwards a single metric value for jobName to the internal queue.
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
	return tq.internal.RecordMetric(jobName, metric, value)
}
|
||||
|
||||
// GetMetrics returns per-job metrics for the TUI. The wrapped internal queue
// has no metrics-retrieval API yet, so this stub always returns an empty map
// and nil error.
func (tq *TaskQueue) GetMetrics(jobName string) (map[string]string, error) {
	// This method doesn't exist in internal queue, return empty for now
	return map[string]string{}, nil
}
|
||||
|
||||
// ListDatasets returns dataset summaries for the datasets view. The wrapped
// internal queue has no dataset API yet, so this stub returns an empty slice.
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
	// This method doesn't exist in internal queue, return empty for now
	return []model.DatasetInfo{}, nil
}
|
||||
|
||||
// CancelTask asks the internal queue to cancel the task with the given ID.
func (tq *TaskQueue) CancelTask(taskID string) error {
	return tq.internal.CancelTask(taskID)
}
|
||||
|
||||
// ListExperiments returns the commit IDs of all recorded experiments.
func (tq *TaskQueue) ListExperiments() ([]string, error) {
	return tq.expManager.ListExperiments()
}
|
||||
|
||||
func (tq *TaskQueue) GetExperimentDetails(commitID string) (string, error) {
|
||||
meta, err := tq.expManager.ReadMetadata(commitID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
metrics, err := tq.expManager.GetMetrics(commitID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
output := fmt.Sprintf("Experiment: %s\n", meta.JobName)
|
||||
output += fmt.Sprintf("Commit ID: %s\n", meta.CommitID)
|
||||
output += fmt.Sprintf("User: %s\n", meta.User)
|
||||
output += fmt.Sprintf("Timestamp: %d\n\n", meta.Timestamp)
|
||||
output += "Metrics:\n"
|
||||
|
||||
if len(metrics) == 0 {
|
||||
output += " No metrics logged.\n"
|
||||
} else {
|
||||
for _, m := range metrics {
|
||||
output += fmt.Sprintf(" %s: %.4f (Step: %d)\n", m.Name, m.Value, m.Step)
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// Close releases the internal queue's resources (e.g. its Redis connection).
func (tq *TaskQueue) Close() error {
	return tq.internal.Close()
}
|
||||
|
||||
// MLServer wraps network.SSHClient for backward compatibility, embedding the
// client so its methods are promoted onto MLServer.
type MLServer struct {
	*network.SSHClient
	addr string // "host:port" of the remote server, or "localhost" in local mode
}
|
||||
|
||||
// NewMLServer builds the server handle used by the TUI. An empty cfg.Host
// selects local mode, where no real SSH connection is made.
func NewMLServer(cfg *config.Config) (*MLServer, error) {
	// Local mode: skip SSH entirely
	if cfg.Host == "" {
		// Error deliberately ignored: with all-empty arguments no connection
		// is attempted — presumably NewSSHClient returns a usable local stub
		// here; TODO confirm against the network package.
		client, _ := network.NewSSHClient("", "", "", 0, "")
		return &MLServer{SSHClient: client, addr: "localhost"}, nil
	}

	client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
	if err != nil {
		return nil, err
	}
	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
	return &MLServer{SSHClient: client, addr: addr}, nil
}
|
||||
255
cmd/tui/internal/view/view.go
Normal file
255
cmd/tui/internal/view/view.go
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
package view
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
)
|
||||
|
||||
// Light/Dark hex pairs for the view's adaptive palette; each pair feeds a
// lipgloss.AdaptiveColor so the TUI renders sensibly on both backgrounds.
const (
	headerfgLight       = "#d35400"
	headerfgDark        = "#ff9e64"
	activeBorderfgLight = "#3498db" // border of the panel that has focus
	activeBorderfgDark  = "#7aa2f7"
	errorbgLight        = "#fee"
	errorbgDark         = "#633"
	errorfgLight        = "#a00"
	errorfgDark         = "#faa"
	titlefgLight        = "#d35400"
	titlefgDark         = "#ff9e64"
	statusfgLight       = "#2e3440"
	statusfgDark        = "#d8dee9"
	statusbgLight       = "#e5e9f0"
	statusbgDark        = "#2e3440"
	borderfgLight       = "#d8dee9" // border of unfocused panels
	borderfgDark        = "#4c566a"
	helpfgLight         = "#4c566a"
	helpfgDark          = "#88c0d0"
)
|
||||
|
||||
// Shared lipgloss styles used by Render and the panel helpers.
var (
	// docStyle frames the whole screen with an outer margin.
	docStyle = lipgloss.NewStyle().Margin(1, 2)

	// activeBorderStyle marks the panel that currently has focus.
	activeBorderStyle = lipgloss.NewStyle().
				Border(lipgloss.RoundedBorder()).
				BorderForeground(lipgloss.AdaptiveColor{Light: activeBorderfgLight, Dark: activeBorderfgDark}).
				Padding(1, 2)

	// errorStyle renders the error bar shown when State.ErrorMsg is set.
	errorStyle = lipgloss.NewStyle().
			Background(lipgloss.AdaptiveColor{Light: errorbgLight, Dark: errorbgDark}).
			Foreground(lipgloss.AdaptiveColor{Light: errorfgLight, Dark: errorfgDark}).
			Padding(0, 1).
			Bold(true)

	// titleStyle renders the application title line.
	titleStyle = (lipgloss.NewStyle().
			Bold(true).
			Foreground(lipgloss.AdaptiveColor{Light: titlefgLight, Dark: titlefgDark}).
			MarginBottom(1))

	// statusStyle renders the status bar.
	statusStyle = (lipgloss.NewStyle().
			Background(lipgloss.AdaptiveColor{Light: statusbgLight, Dark: statusbgDark}).
			Foreground(lipgloss.AdaptiveColor{Light: statusfgLight, Dark: statusfgDark}).
			Padding(0, 1))

	// borderStyle frames unfocused panels.
	borderStyle = (lipgloss.NewStyle().
			Border(lipgloss.RoundedBorder()).
			BorderForeground(lipgloss.AdaptiveColor{Light: borderfgLight, Dark: borderfgDark}).
			Padding(0, 1))

	// helpStyle renders both the toggleable help view and the quick-help footer.
	helpStyle = (lipgloss.NewStyle().
			Foreground(lipgloss.AdaptiveColor{Light: helpfgLight, Dark: helpfgDark}))
)
|
||||
|
||||
func Render(m model.State) string {
|
||||
if m.Width == 0 {
|
||||
return "Loading..."
|
||||
}
|
||||
|
||||
// Title
|
||||
title := titleStyle.Width(m.Width - 4).Render("🤖 ML Experiment Manager")
|
||||
|
||||
// Left panel - Job list (30% width)
|
||||
leftWidth := int(float64(m.Width) * 0.3)
|
||||
leftPanel := getJobListPanel(m, leftWidth)
|
||||
|
||||
// Right panel - Dynamic content (70% width)
|
||||
rightWidth := m.Width - leftWidth - 4
|
||||
rightPanel := getRightPanel(m, rightWidth)
|
||||
|
||||
// Main content
|
||||
main := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel)
|
||||
|
||||
// Status bar
|
||||
statusBar := getStatusBar(m)
|
||||
|
||||
// Error bar (if present)
|
||||
var errorBar string
|
||||
if m.ErrorMsg != "" {
|
||||
errorBar = errorStyle.Width(m.Width - 4).Render("⚠ Error: " + m.ErrorMsg)
|
||||
}
|
||||
|
||||
// Help view (toggleable)
|
||||
var helpView string
|
||||
if m.ShowHelp {
|
||||
helpView = helpStyle.Width(m.Width-4).
|
||||
Padding(1, 2).
|
||||
Render(helpText(m))
|
||||
}
|
||||
|
||||
// Quick help bar
|
||||
quickHelp := helpStyle.Width(m.Width - 4).Render(getQuickHelp(m))
|
||||
|
||||
// Compose final layout
|
||||
parts := []string{title, main, statusBar}
|
||||
if errorBar != "" {
|
||||
parts = append(parts, errorBar)
|
||||
}
|
||||
if helpView != "" {
|
||||
parts = append(parts, helpView)
|
||||
}
|
||||
parts = append(parts, quickHelp)
|
||||
|
||||
return docStyle.Render(lipgloss.JoinVertical(lipgloss.Left, parts...))
|
||||
}
|
||||
|
||||
// getJobListPanel renders the left-hand job list, using the active
// border style when the jobs view has keyboard focus. A friendly empty
// state is shown when the list has no items.
func getJobListPanel(m model.State, width int) string {
	style := borderStyle
	if m.ActiveView == model.ViewModeJobs {
		style = activeBorderStyle
	}
	// m is passed by value, so resizing m.JobList here only affects this
	// render pass; the controller owns the persistent list state. The
	// resize is still needed so list.Model.View() draws at the right size.
	h, v := style.GetFrameSize()
	m.JobList.SetSize(width-h, m.Height-v-4) // Adjust height for title/help/status

	// Custom empty state
	if len(m.JobList.Items()) == 0 {
		return style.Width(width - h).Render(
			lipgloss.JoinVertical(lipgloss.Left,
				m.JobList.Styles.Title.Render(m.JobList.Title),
				"\n No jobs found.",
				" Press 't' to queue.",
			),
		)
	}

	return style.Width(width - h).Render(m.JobList.View())
}
|
||||
|
||||
func getRightPanel(m model.State, width int) string {
|
||||
var content string
|
||||
var viewTitle string
|
||||
style := borderStyle
|
||||
|
||||
switch m.ActiveView {
|
||||
case model.ViewModeGPU:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "🎮 GPU Status"
|
||||
content = m.GpuView.View()
|
||||
case model.ViewModeContainer:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "🐳 Container Status"
|
||||
content = m.ContainerView.View()
|
||||
case model.ViewModeQueue:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "⏳ Task Queue"
|
||||
content = m.QueueView.View()
|
||||
case model.ViewModeSettings:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "⚙️ Settings"
|
||||
content = m.SettingsView.View()
|
||||
case model.ViewModeExperiments:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "🧪 Experiments"
|
||||
content = m.ExperimentsView.View()
|
||||
default:
|
||||
viewTitle = "📊 System Overview"
|
||||
content = getOverviewPanel(m)
|
||||
}
|
||||
|
||||
header := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: headerfgLight, Dark: headerfgDark}).
|
||||
Render(viewTitle)
|
||||
|
||||
h, _ := style.GetFrameSize()
|
||||
return style.Width(width - h).Render(header + "\n\n" + content)
|
||||
}
|
||||
|
||||
func getOverviewPanel(m model.State) string {
|
||||
var sections []string
|
||||
|
||||
sections = append(sections, "🎮 GPU\n"+strings.Repeat("─", 40))
|
||||
sections = append(sections, m.GpuView.View())
|
||||
sections = append(sections, "\n🐳 Containers\n"+strings.Repeat("─", 40))
|
||||
sections = append(sections, m.ContainerView.View())
|
||||
|
||||
return strings.Join(sections, "\n")
|
||||
}
|
||||
|
||||
func getStatusBar(m model.State) string {
|
||||
spinnerStr := m.Spinner.View()
|
||||
if !m.IsLoading {
|
||||
if m.ShowHelp {
|
||||
spinnerStr = "?"
|
||||
} else {
|
||||
spinnerStr = "●"
|
||||
}
|
||||
}
|
||||
statusText := m.Status
|
||||
if m.ShowHelp {
|
||||
statusText = "Press 'h' to hide help"
|
||||
}
|
||||
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText)
|
||||
}
|
||||
|
||||
// helpText returns the full keyboard-shortcut reference rendered in the
// help overlay. The settings view gets a reduced table; every other view
// gets the complete one.
func helpText(m model.State) string {
	if m.ActiveView == model.ViewModeSettings {
		return `╔═══════════════════════════════════════════════════════════════╗
║ Settings Shortcuts ║
╠═══════════════════════════════════════════════════════════════╣
║ Navigation ║
║ j/k, ↑/↓ : Move selection ║
║ Enter : Edit / Save ║
║ Esc : Exit Settings ║
║ ║
║ General ║
║ h or ? : Toggle this help q/Ctrl+C : Quit ║
╚═══════════════════════════════════════════════════════════════╝`
	}

	return `╔═══════════════════════════════════════════════════════════════╗
║ Keyboard Shortcuts ║
╠═══════════════════════════════════════════════════════════════╣
║ Navigation ║
║ j/k, ↑/↓ : Move selection / : Filter jobs ║
║ 1 : Job list view 2 : Datasets view ║
║ 3 : Experiments view v : Queue view ║
║ g : GPU view o : Container view ║
║ s : Settings view ║
║ ║
║ Actions ║
║ t : Queue job a : Queue w/ args ║
║ c : Cancel task d : Delete pending ║
║ f : Mark as failed r : Refresh all ║
║ G : Refresh GPU only ║
║ ║
║ General ║
║ h or ? : Toggle this help q/Ctrl+C : Quit ║
╚═══════════════════════════════════════════════════════════════╝`
}
|
||||
|
||||
func getQuickHelp(m model.State) string {
|
||||
if m.ActiveView == model.ViewModeSettings {
|
||||
return " ↑/↓:move enter:select esc:exit settings q:quit"
|
||||
}
|
||||
return " h:help 1:jobs 2:datasets 3:experiments v:queue g:gpu o:containers s:settings t:queue r:refresh q:quit"
|
||||
}
|
||||
204
cmd/tui/main.go
Normal file
204
cmd/tui/main.go
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
// Package main implements the ML TUI
|
||||
package main
|
||||
|
||||
import (
	"log"
	"log/slog"
	"os"
	"os/signal"
	"syscall"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/controller"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/services"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/view"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/logging"
)
|
||||
|
||||
// AppModel is the bubbletea root model. It holds the current application
// state and delegates all init/update logic to the controller; rendering
// is delegated to the view package.
type AppModel struct {
	state      model.State
	controller *controller.Controller
}
|
||||
|
||||
// Init implements tea.Model; it returns the controller's startup command.
func (m AppModel) Init() tea.Cmd {
	return m.controller.Init()
}
|
||||
|
||||
func (m AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
newState, cmd := m.controller.Update(msg, m.state)
|
||||
m.state = newState
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
// View implements tea.Model by rendering the current state to a string.
func (m AppModel) View() string {
	return view.Render(m.state)
}
|
||||
|
||||
func main() {
|
||||
// Parse authentication flags
|
||||
authFlags := auth.ParseAuthFlags()
|
||||
if err := auth.ValidateAuthFlags(authFlags); err != nil {
|
||||
log.Fatalf("Authentication flag error: %v", err)
|
||||
}
|
||||
|
||||
// Get API key from various sources
|
||||
apiKey := auth.GetAPIKeyFromSources(authFlags)
|
||||
|
||||
var (
|
||||
cfg *config.Config
|
||||
cliConfig *config.CLIConfig
|
||||
cliConfPath string
|
||||
)
|
||||
|
||||
configFlag := authFlags.ConfigFile
|
||||
|
||||
// Only support TOML configuration
|
||||
var err error
|
||||
cliConfig, cliConfPath, err = config.LoadCLIConfig(configFlag)
|
||||
if err != nil {
|
||||
if configFlag != "" {
|
||||
log.Fatalf("Failed to load TOML config %s: %v", configFlag, err)
|
||||
} else {
|
||||
// Provide helpful error message for data scientists
|
||||
log.Printf("=== Fetch ML TUI - Configuration Required ===")
|
||||
log.Printf("")
|
||||
log.Printf("Error: %v", err)
|
||||
log.Printf("")
|
||||
log.Printf("To get started with the TUI, you need to initialize your configuration:")
|
||||
log.Printf("")
|
||||
log.Printf("Option 1: Using the Zig CLI (Recommended)")
|
||||
log.Printf(" 1. Build the CLI: cd cli && make build")
|
||||
log.Printf(" 2. Initialize config: ./cli/zig-out/bin/ml init")
|
||||
log.Printf(" 3. Edit ~/.ml/config.toml with your settings")
|
||||
log.Printf(" 4. Run TUI: ./bin/tui")
|
||||
log.Printf("")
|
||||
log.Printf("Option 2: Manual Configuration")
|
||||
log.Printf(" 1. Create directory: mkdir -p ~/.ml")
|
||||
log.Printf(" 2. Create config: touch ~/.ml/config.toml")
|
||||
log.Printf(" 3. Add your settings to the file")
|
||||
log.Printf(" 4. Run TUI: ./bin/tui")
|
||||
log.Printf("")
|
||||
log.Printf("Example ~/.ml/config.toml:")
|
||||
log.Printf(" worker_host = \"localhost\"")
|
||||
log.Printf(" worker_user = \"your_username\"")
|
||||
log.Printf(" worker_base = \"~/ml_jobs\"")
|
||||
log.Printf(" worker_port = 22")
|
||||
log.Printf(" api_key = \"your_api_key_here\"")
|
||||
log.Printf("")
|
||||
log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
cfg = cliConfig.ToTUIConfig()
|
||||
log.Printf("Loaded TOML configuration from %s", cliConfPath)
|
||||
|
||||
// Validate authentication configuration
|
||||
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
|
||||
log.Fatalf("Invalid authentication configuration: %v", err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
log.Fatalf("Invalid configuration: %v", err)
|
||||
}
|
||||
|
||||
// Test authentication if enabled
|
||||
if cfg.Auth.Enabled {
|
||||
// Use API key from CLI config if available, otherwise use from flags
|
||||
var effectiveAPIKey string
|
||||
if cliConfig != nil && cliConfig.APIKey != "" {
|
||||
effectiveAPIKey = cliConfig.APIKey
|
||||
} else if apiKey != "" {
|
||||
effectiveAPIKey = apiKey
|
||||
} else {
|
||||
log.Fatal("Authentication required but no API key provided")
|
||||
}
|
||||
|
||||
if _, err := cfg.Auth.ValidateAPIKey(effectiveAPIKey); err != nil {
|
||||
log.Fatalf("Authentication failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
srv, err := services.NewMLServer(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to server: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := srv.Close(); err != nil {
|
||||
log.Printf("server close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tq, err := services.NewTaskQueue(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to Redis: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := tq.Close(); err != nil {
|
||||
log.Printf("task queue close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize logger
|
||||
// Note: In original code, logger was created inside initialModel.
|
||||
// Here we create it and pass it to controller.
|
||||
// We use slog.LevelError as default from original code.
|
||||
// But original code imported "log/slog".
|
||||
// We use internal/logging package.
|
||||
// Check logging package signature.
|
||||
// Original: logger := logging.NewLogger(slog.LevelError, false)
|
||||
// We need to import "log/slog" in main if we use slog constants.
|
||||
// Or use logging package constants if available.
|
||||
// Let's check logging package.
|
||||
// Assuming logging.NewLogger takes (slog.Level, bool).
|
||||
// I'll import "log/slog".
|
||||
|
||||
// Wait, I need to import "log/slog"
|
||||
logger := logging.NewLogger(-4, false) // -4 is slog.LevelError value. Or I can import log/slog.
|
||||
|
||||
// Initialize State and Controller
|
||||
var effectiveAPIKey string
|
||||
if cliConfig != nil && cliConfig.APIKey != "" {
|
||||
effectiveAPIKey = cliConfig.APIKey
|
||||
} else {
|
||||
effectiveAPIKey = apiKey
|
||||
}
|
||||
|
||||
initialState := model.InitialState(effectiveAPIKey)
|
||||
ctrl := controller.New(cfg, srv, tq, logger)
|
||||
|
||||
appModel := AppModel{
|
||||
state: initialState,
|
||||
controller: ctrl,
|
||||
}
|
||||
|
||||
// Run TUI app
|
||||
p := tea.NewProgram(appModel, tea.WithAltScreen(), tea.WithMouseAllMotion())
|
||||
|
||||
// Ensure we restore the terminal even if panic or error occurs
|
||||
// Note: p.Run() usually handles this, but explicit cleanup is safer
|
||||
// if we want to ensure the alt screen is exited.
|
||||
// We can't defer p.ReleaseTerminal() here because p is created here.
|
||||
// But we can defer a function that calls it.
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Run program and handle signals
|
||||
go func() {
|
||||
<-sigChan
|
||||
p.Quit()
|
||||
}()
|
||||
|
||||
if _, err := p.Run(); err != nil {
|
||||
// Attempt to restore terminal before logging fatal error
|
||||
p.ReleaseTerminal()
|
||||
log.Fatalf("Error running TUI: %v", err)
|
||||
}
|
||||
|
||||
// Explicitly restore terminal after program exits
|
||||
p.ReleaseTerminal()
|
||||
}
|
||||
175
cmd/user_manager/main.go
Normal file
175
cmd/user_manager/main.go
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigWithAuth is the subset of the worker YAML configuration this
// tool reads: only the auth section is decoded into typed form.
// NOTE(review): marshalling this struct back to disk discards every
// other top-level key in the file — writers must merge into the full
// document instead.
type ConfigWithAuth struct {
	Auth auth.AuthConfig `yaml:"auth"`
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
configFile = flag.String("config", "", "Configuration file path")
|
||||
command = flag.String("cmd", "", "Command: generate-key, list-users, hash-key")
|
||||
username = flag.String("username", "", "Username for generate-key")
|
||||
role = flag.String("role", "", "Role for generate-key")
|
||||
admin = flag.Bool("admin", false, "Admin flag for generate-key")
|
||||
apiKey = flag.String("key", "", "API key to hash")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *configFile == "" || *command == "" {
|
||||
fmt.Println("Usage: user_manager --config <config.yaml> --cmd <command> [options]")
|
||||
fmt.Println("Commands: generate-key, list-users, hash-key")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch *command {
|
||||
case "generate-key":
|
||||
if *username == "" {
|
||||
log.Fatal("Usage: --cmd generate-key --username <name> [--admin] [--role <role>]")
|
||||
}
|
||||
|
||||
// Load config
|
||||
data, err := os.ReadFile(*configFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read config: %v", err)
|
||||
}
|
||||
|
||||
var config ConfigWithAuth
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
log.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Generate API key
|
||||
apiKey := auth.GenerateAPIKey()
|
||||
|
||||
// Setup user
|
||||
if config.Auth.APIKeys == nil {
|
||||
config.Auth.APIKeys = make(map[auth.Username]auth.APIKeyEntry)
|
||||
}
|
||||
|
||||
adminStatus := *admin
|
||||
roles := []string{"viewer"}
|
||||
permissions := make(map[string]bool)
|
||||
|
||||
if !adminStatus && *role == "" {
|
||||
fmt.Printf("Make user '%s' an admin? (y/N): ", *username)
|
||||
var response string
|
||||
fmt.Scanln(&response)
|
||||
adminStatus = strings.ToLower(strings.TrimSpace(response)) == "y"
|
||||
}
|
||||
|
||||
if adminStatus {
|
||||
roles = []string{"admin"}
|
||||
permissions["*"] = true
|
||||
} else if *role != "" {
|
||||
roles = []string{*role}
|
||||
rolePerms := getRolePermissions(*role)
|
||||
for perm, value := range rolePerms {
|
||||
permissions[perm] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Save user
|
||||
config.Auth.APIKeys[auth.Username(*username)] = auth.APIKeyEntry{
|
||||
Hash: auth.APIKeyHash(auth.HashAPIKey(apiKey)),
|
||||
Admin: adminStatus,
|
||||
Roles: roles,
|
||||
Permissions: permissions,
|
||||
}
|
||||
|
||||
data, err = yaml.Marshal(config)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to marshal config: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(*configFile, data, 0600); err != nil {
|
||||
log.Fatalf("Failed to write config: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Generated API key for user '%s':\nKey: %s\n", *username, apiKey)
|
||||
|
||||
case "list-users":
|
||||
data, err := os.ReadFile(*configFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read config: %v", err)
|
||||
}
|
||||
|
||||
var config ConfigWithAuth
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
log.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Configured Users:")
|
||||
fmt.Println("=================")
|
||||
for username, entry := range config.Auth.APIKeys {
|
||||
fmt.Printf("User: %s\n", string(username))
|
||||
fmt.Printf(" Admin: %v\n", entry.Admin)
|
||||
if len(entry.Roles) > 0 {
|
||||
fmt.Printf(" Roles: %v\n", entry.Roles)
|
||||
}
|
||||
if len(entry.Permissions) > 0 {
|
||||
fmt.Printf(" Permissions: %d\n", len(entry.Permissions))
|
||||
}
|
||||
fmt.Printf(" Key Hash: %s...\n\n", string(entry.Hash)[:8])
|
||||
}
|
||||
|
||||
case "hash-key":
|
||||
if *apiKey == "" {
|
||||
log.Fatal("Usage: --cmd hash-key --key <api-key>")
|
||||
}
|
||||
hash := auth.HashAPIKey(*apiKey)
|
||||
fmt.Printf("Hash: %s\n", hash)
|
||||
|
||||
default:
|
||||
log.Fatalf("Unknown command: %s", *command)
|
||||
}
|
||||
}
|
||||
|
||||
// getRolePermissions returns the permission set associated with a role
// name. Unknown roles yield an empty (but non-nil) map.
func getRolePermissions(role string) map[string]bool {
	switch role {
	case "admin":
		// Wildcard: admins may do everything.
		return map[string]bool{"*": true}
	case "data_scientist":
		return map[string]bool{
			"jobs:create": true,
			"jobs:read":   true,
			"jobs:update": true,
			"data:read":   true,
			"models:read": true,
		}
	case "data_engineer":
		return map[string]bool{
			"data:create": true,
			"data:read":   true,
			"data:update": true,
			"data:delete": true,
		}
	case "viewer":
		return map[string]bool{
			"jobs:read":    true,
			"data:read":    true,
			"models:read":  true,
			"metrics:read": true,
		}
	case "operator":
		return map[string]bool{
			"jobs:read":    true,
			"jobs:update":  true,
			"metrics:read": true,
			"system:read":  true,
		}
	default:
		return make(map[string]bool)
	}
}
|
||||
173
cmd/worker/worker_config.go
Normal file
173
cmd/worker/worker_config.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Fallback values applied by LoadConfig when the YAML leaves the
// corresponding duration fields unset.
const (
	defaultMetricsFlushInterval = 500 * time.Millisecond
	datasetCacheDefaultTTL      = 30 * time.Minute
)
|
||||
|
||||
// Config holds worker configuration
// as loaded from YAML. Zero-valued fields are filled in with
// environment-aware defaults by LoadConfig.
type Config struct {
	// SSH connection to the ML host.
	Host        string `yaml:"host"`
	User        string `yaml:"user"`
	SSHKey      string `yaml:"ssh_key"`
	Port        int    `yaml:"port"`
	BasePath    string `yaml:"base_path"`
	TrainScript string `yaml:"train_script"`
	// Redis-backed task queue connection.
	RedisAddr     string `yaml:"redis_addr"`
	RedisPassword string `yaml:"redis_password"`
	RedisDB       int    `yaml:"redis_db"`
	KnownHosts    string `yaml:"known_hosts"`
	// Worker identity and concurrency/polling behaviour.
	WorkerID     string `yaml:"worker_id"`
	MaxWorkers   int    `yaml:"max_workers"`
	PollInterval int    `yaml:"poll_interval_seconds"`

	// Authentication
	Auth auth.AuthConfig `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`
	// Metrics buffering
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management
	DataManagerPath string        `yaml:"data_manager_path"`
	AutoFetchData   bool          `yaml:"auto_fetch_data"`
	DataDir         string        `yaml:"data_dir"`
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	// Podman execution
	PodmanImage        string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults   string `yaml:"container_results"`
	GPUAccess          bool   `yaml:"gpu_access"`

	// Task lease and retry settings
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // How often to renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Graceful shutdown timeout (default: 5min)
}
|
||||
|
||||
// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
	Enabled    bool   `yaml:"enabled"`     // serve /metrics when true
	ListenAddr string `yaml:"listen_addr"` // exporter bind address (default ":9100")
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get smart defaults for current environment
|
||||
smart := config.GetSmartDefaults()
|
||||
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = config.DefaultSSHPort
|
||||
}
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = smart.Host()
|
||||
}
|
||||
if cfg.BasePath == "" {
|
||||
cfg.BasePath = smart.BasePath()
|
||||
}
|
||||
if cfg.RedisAddr == "" {
|
||||
cfg.RedisAddr = smart.RedisAddr()
|
||||
}
|
||||
if cfg.KnownHosts == "" {
|
||||
cfg.KnownHosts = smart.KnownHostsPath()
|
||||
}
|
||||
if cfg.WorkerID == "" {
|
||||
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
|
||||
}
|
||||
if cfg.MaxWorkers == 0 {
|
||||
cfg.MaxWorkers = smart.MaxWorkers()
|
||||
}
|
||||
if cfg.PollInterval == 0 {
|
||||
cfg.PollInterval = smart.PollInterval()
|
||||
}
|
||||
if cfg.DataManagerPath == "" {
|
||||
cfg.DataManagerPath = "./data_manager"
|
||||
}
|
||||
if cfg.DataDir == "" {
|
||||
if cfg.Host == "" || !cfg.AutoFetchData {
|
||||
cfg.DataDir = config.DefaultLocalDataDir
|
||||
} else {
|
||||
cfg.DataDir = smart.DataDir()
|
||||
}
|
||||
}
|
||||
if cfg.Metrics.ListenAddr == "" {
|
||||
cfg.Metrics.ListenAddr = ":9100"
|
||||
}
|
||||
if cfg.MetricsFlushInterval == 0 {
|
||||
cfg.MetricsFlushInterval = defaultMetricsFlushInterval
|
||||
}
|
||||
if cfg.DatasetCacheTTL == 0 {
|
||||
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
||||
}
|
||||
|
||||
// Set lease and retry defaults
|
||||
if cfg.TaskLeaseDuration == 0 {
|
||||
cfg.TaskLeaseDuration = 30 * time.Minute
|
||||
}
|
||||
if cfg.HeartbeatInterval == 0 {
|
||||
cfg.HeartbeatInterval = 1 * time.Minute
|
||||
}
|
||||
if cfg.MaxRetries == 0 {
|
||||
cfg.MaxRetries = 3
|
||||
}
|
||||
if cfg.GracefulTimeout == 0 {
|
||||
cfg.GracefulTimeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Validate implements config.Validator interface
//
// It checks SSH port, Redis address and worker-count sanity.
// NOTE(review): it also normalizes c.BasePath in place (expansion plus
// making it absolute), so validation mutates the receiver — confirm
// callers rely on that side effect before refactoring it out.
func (c *Config) Validate() error {
	if c.Port != 0 {
		if err := config.ValidatePort(c.Port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}

	if c.BasePath != "" {
		// Convert relative paths to absolute
		c.BasePath = config.ExpandPath(c.BasePath)
		if !filepath.IsAbs(c.BasePath) {
			c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
		}
	}

	if c.RedisAddr != "" {
		if err := config.ValidateRedisAddr(c.RedisAddr); err != nil {
			return fmt.Errorf("invalid Redis configuration: %w", err)
		}
	}

	// LoadConfig defaults this, but a hand-built Config must still be sane.
	if c.MaxWorkers < 1 {
		return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
	}

	return nil
}
|
||||
|
||||
// Task struct and Redis constants moved to internal/queue
|
||||
883
cmd/worker/worker_server.go
Normal file
883
cmd/worker/worker_server.go
Normal file
|
|
@ -0,0 +1,883 @@
|
|||
// Package main implements the ML task worker
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/errors"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// MLServer wraps network.SSHClient for backward compatibility
// with earlier code that used a dedicated server type; all SSH client
// methods are promoted through the embedded pointer.
type MLServer struct {
	*network.SSHClient
}
|
||||
|
||||
func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MLServer{SSHClient: client}, nil
|
||||
}
|
||||
|
||||
// Worker pulls leased tasks from the Redis-backed queue and executes
// them over SSH, exporting Prometheus metrics and supporting graceful
// shutdown of in-flight tasks.
type Worker struct {
	id         string // worker identity used for leases and heartbeats
	config     *Config
	server     *MLServer
	queue      *queue.TaskQueue
	running    map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
	runningMu  sync.RWMutex                  // guards running
	ctx        context.Context               // worker lifetime; cancel() stops the poll loop
	cancel     context.CancelFunc
	logger     *logging.Logger
	metrics    *metrics.Metrics
	metricsSrv *http.Server // Prometheus exporter, nil when metrics are disabled

	// Records when each dataset was last fetched; presumably consulted to
	// skip re-fetches within datasetCacheTTL — confirm in the fetch path.
	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup
}
|
||||
|
||||
func (w *Worker) setupMetricsExporter() error {
|
||||
if !w.config.Metrics.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
reg.MustRegister(
|
||||
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||
collectors.NewGoCollector(),
|
||||
)
|
||||
|
||||
labels := prometheus.Labels{"worker_id": w.id}
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_processed_total",
|
||||
Help: "Total tasks processed successfully by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksProcessed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_failed_total",
|
||||
Help: "Total tasks failed by this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.TasksFailed.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_active",
|
||||
Help: "Number of tasks currently running on this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.runningCount())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_tasks_queued",
|
||||
Help: "Latest observed queue depth from Redis.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.QueuedTasks.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_transferred_bytes_total",
|
||||
Help: "Total bytes transferred while fetching datasets.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataTransferred.Load())
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_data_fetch_time_seconds_total",
|
||||
Help: "Total time spent fetching datasets (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_execution_time_seconds_total",
|
||||
Help: "Total execution time for completed tasks (seconds).",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second)
|
||||
}))
|
||||
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Name: "fetchml_worker_max_concurrency",
|
||||
Help: "Configured maximum concurrent tasks for this worker.",
|
||||
ConstLabels: labels,
|
||||
}, func() float64 {
|
||||
return float64(w.config.MaxWorkers)
|
||||
}))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: w.config.Metrics.ListenAddr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
w.metricsSrv = srv
|
||||
go func() {
|
||||
w.logger.Info("metrics exporter listening",
|
||||
"addr", w.config.Metrics.ListenAddr,
|
||||
"enabled", w.config.Metrics.Enabled)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
w.logger.Warn("metrics exporter stopped",
|
||||
"error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewWorker(cfg *Config, apiKey string) (*Worker, error) {
|
||||
srv, err := NewMLServer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queueCfg := queue.Config{
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
RedisDB: cfg.RedisDB,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
}
|
||||
queue, err := queue.NewTaskQueue(queueCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create data_dir if it doesn't exist (for production without NAS)
|
||||
if cfg.DataDir != "" {
|
||||
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
|
||||
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx = logging.EnsureTrace(ctx)
|
||||
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
|
||||
|
||||
baseLogger := logging.NewLogger(slog.LevelInfo, false)
|
||||
logger := baseLogger.Component(ctx, "worker")
|
||||
metrics := &metrics.Metrics{}
|
||||
|
||||
worker := &Worker{
|
||||
id: cfg.WorkerID,
|
||||
config: cfg,
|
||||
server: srv,
|
||||
queue: queue,
|
||||
running: make(map[string]context.CancelFunc),
|
||||
datasetCache: make(map[string]time.Time),
|
||||
datasetCacheTTL: cfg.DatasetCacheTTL,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
metrics: metrics,
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := worker.setupMetricsExporter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
// Start runs the worker's main polling loop until the root context is
// cancelled. It launches the worker-liveness heartbeat, then repeatedly:
// checks for shutdown, enforces the MaxWorkers concurrency cap, leases the
// next task from the queue, and dispatches it on its own goroutine.
// On shutdown it drains in-flight tasks via waitForTasks before returning.
func (w *Worker) Start() {
	w.logger.Info("worker started",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers,
		"poll_interval", w.config.PollInterval)

	go w.heartbeat()

	for {
		// Non-blocking shutdown check at the top of every iteration.
		select {
		case <-w.ctx.Done():
			w.logger.Info("shutdown signal received, waiting for tasks")
			w.waitForTasks()
			return
		default:
		}

		// Back off briefly while at the concurrency cap instead of polling Redis.
		if w.runningCount() >= w.config.MaxWorkers {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		queueStart := time.Now()
		task, err := w.queue.GetNextTaskWithLease(w.config.WorkerID, w.config.TaskLeaseDuration)
		queueLatency := time.Since(queueStart)
		if err != nil {
			// DeadlineExceeded just means an empty poll window — not an error.
			if err == context.DeadlineExceeded {
				continue
			}
			w.logger.Error("error fetching task",
				"worker_id", w.id,
				"error", err)
			continue
		}

		// nil task with no error: queue is empty; only log unusually slow polls.
		if task == nil {
			if queueLatency > 200*time.Millisecond {
				w.logger.Debug("queue poll latency",
					"latency_ms", queueLatency.Milliseconds())
			}
			continue
		}

		// Best-effort queue-depth telemetry; derr stays in scope for the else-if.
		if depth, derr := w.queue.QueueDepth(); derr == nil {
			if queueLatency > 100*time.Millisecond || depth > 0 {
				w.logger.Debug("queue fetch metrics",
					"latency_ms", queueLatency.Milliseconds(),
					"remaining_depth", depth)
			}
		} else if queueLatency > 100*time.Millisecond {
			w.logger.Debug("queue fetch metrics",
				"latency_ms", queueLatency.Milliseconds(),
				"depth_error", derr)
		}

		// Each task runs concurrently; the cap is enforced at the top of the loop.
		go w.executeTaskWithLease(task)
	}
}
|
||||
|
||||
func (w *Worker) heartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.queue.Heartbeat(w.id); err != nil {
|
||||
w.logger.Warn("heartbeat failed",
|
||||
"worker_id", w.id,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NEW: Fetch datasets using data_manager
|
||||
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
|
||||
logger := w.logger.Job(ctx, task.JobName, task.ID)
|
||||
logger.Info("fetching datasets",
|
||||
"worker_id", w.id,
|
||||
"dataset_count", len(task.Datasets))
|
||||
|
||||
for _, dataset := range task.Datasets {
|
||||
if w.datasetIsFresh(dataset) {
|
||||
logger.Debug("skipping cached dataset",
|
||||
"dataset", dataset)
|
||||
continue
|
||||
}
|
||||
// Check for cancellation before each dataset fetch
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return fmt.Errorf("dataset fetch cancelled: %w", w.ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("fetching dataset",
|
||||
"worker_id", w.id,
|
||||
"dataset", dataset)
|
||||
|
||||
// Create command with context for cancellation support
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
cmd := exec.CommandContext(cmdCtx,
|
||||
w.config.DataManagerPath,
|
||||
"fetch",
|
||||
task.JobName,
|
||||
dataset,
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
cancel() // Clean up context
|
||||
|
||||
if err != nil {
|
||||
return &errors.DataFetchError{
|
||||
Dataset: dataset,
|
||||
JobName: task.JobName,
|
||||
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("dataset ready",
|
||||
"worker_id", w.id,
|
||||
"dataset", dataset)
|
||||
w.markDatasetFetched(dataset)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Worker) runJob(task *queue.Task) error {
|
||||
// Validate job name to prevent path traversal
|
||||
if err := container.ValidateJobName(task.JobName); err != nil {
|
||||
return &errors.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "validation",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
jobPaths := config.NewJobPaths(w.config.BasePath)
|
||||
jobDir := filepath.Join(jobPaths.PendingPath(), task.JobName)
|
||||
outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName)
|
||||
logFile := filepath.Join(outputDir, "output.log")
|
||||
|
||||
// Sanitize paths
|
||||
jobDir, err := container.SanitizePath(jobDir)
|
||||
if err != nil {
|
||||
return &errors.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "validation",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
outputDir, err = container.SanitizePath(outputDir)
|
||||
if err != nil {
|
||||
return &errors.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "validation",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Create output directory
|
||||
if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) {
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("mkdir failed: %w", err)
|
||||
}
|
||||
return "", nil
|
||||
}); err != nil {
|
||||
return &errors.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "setup",
|
||||
Err: fmt.Errorf("failed to create output dir: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
// Move job from pending to running
|
||||
stagingStart := time.Now()
|
||||
if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) {
|
||||
if err := os.Rename(jobDir, outputDir); err != nil {
|
||||
return "", fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
return "", nil
|
||||
}); err != nil {
|
||||
return &errors.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "setup",
|
||||
Err: fmt.Errorf("failed to move job: %w", err),
|
||||
}
|
||||
}
|
||||
stagingDuration := time.Since(stagingStart)
|
||||
|
||||
if w.config.PodmanImage == "" {
|
||||
return &errors.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
Phase: "validation",
|
||||
Err: fmt.Errorf("podman_image must be configured"),
|
||||
}
|
||||
}
|
||||
|
||||
containerWorkspace := w.config.ContainerWorkspace
|
||||
if containerWorkspace == "" {
|
||||
containerWorkspace = config.DefaultContainerWorkspace
|
||||
}
|
||||
containerResults := w.config.ContainerResults
|
||||
if containerResults == "" {
|
||||
containerResults = config.DefaultContainerResults
|
||||
}
|
||||
|
||||
podmanCfg := container.PodmanConfig{
|
||||
Image: w.config.PodmanImage,
|
||||
Workspace: filepath.Join(outputDir, "code"),
|
||||
Results: filepath.Join(outputDir, "results"),
|
||||
ContainerWorkspace: containerWorkspace,
|
||||
ContainerResults: containerResults,
|
||||
GPUAccess: w.config.GPUAccess,
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript)
|
||||
requirementsPath := filepath.Join(containerWorkspace, "requirements.txt")
|
||||
|
||||
var extraArgs []string
|
||||
if task.Args != "" {
|
||||
extraArgs = strings.Fields(task.Args)
|
||||
}
|
||||
|
||||
ioBefore, ioErr := telemetry.ReadProcessIO()
|
||||
podmanCmd := container.BuildPodmanCommand(podmanCfg, scriptPath, requirementsPath, extraArgs)
|
||||
logFileHandle, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err == nil {
|
||||
podmanCmd.Stdout = logFileHandle
|
||||
podmanCmd.Stderr = logFileHandle
|
||||
} else {
|
||||
w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err)
|
||||
}
|
||||
|
||||
w.logger.Info("executing podman job",
|
||||
"job", task.JobName,
|
||||
"image", w.config.PodmanImage,
|
||||
"workspace", podmanCfg.Workspace,
|
||||
"results", podmanCfg.Results)
|
||||
|
||||
containerStart := time.Now()
|
||||
if err := podmanCmd.Run(); err != nil {
|
||||
containerDuration := time.Since(containerStart)
|
||||
// Move job to failed directory
|
||||
failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
|
||||
if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) {
|
||||
if err := os.Rename(outputDir, failedDir); err != nil {
|
||||
return "", fmt.Errorf("rename to failed failed: %w", err)
|
||||
}
|
||||
return "", nil
|
||||
}); moveErr != nil {
|
||||
w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr)
|
||||
}
|
||||
|
||||
if ioErr == nil {
|
||||
if after, err := telemetry.ReadProcessIO(); err == nil {
|
||||
delta := telemetry.DiffIO(ioBefore, after)
|
||||
w.logger.Debug("worker io stats",
|
||||
"job", task.JobName,
|
||||
"read_bytes", delta.ReadBytes,
|
||||
"write_bytes", delta.WriteBytes)
|
||||
}
|
||||
}
|
||||
w.logger.Info("job timing (failure)",
|
||||
"job", task.JobName,
|
||||
"staging_ms", stagingDuration.Milliseconds(),
|
||||
"container_ms", containerDuration.Milliseconds(),
|
||||
"finalize_ms", 0,
|
||||
"total_ms", time.Since(stagingStart).Milliseconds(),
|
||||
)
|
||||
return fmt.Errorf("execution failed: %w", err)
|
||||
}
|
||||
containerDuration := time.Since(containerStart)
|
||||
|
||||
finalizeStart := time.Now()
|
||||
// Move job to finished directory
|
||||
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
|
||||
if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) {
|
||||
if err := os.Rename(outputDir, finishedDir); err != nil {
|
||||
return "", fmt.Errorf("rename to finished failed: %w", err)
|
||||
}
|
||||
return "", nil
|
||||
}); moveErr != nil {
|
||||
w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr)
|
||||
}
|
||||
finalizeDuration := time.Since(finalizeStart)
|
||||
totalDuration := time.Since(stagingStart)
|
||||
var ioDelta telemetry.IOStats
|
||||
if ioErr == nil {
|
||||
if after, err := telemetry.ReadProcessIO(); err == nil {
|
||||
ioDelta = telemetry.DiffIO(ioBefore, after)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Info("job timing",
|
||||
"job", task.JobName,
|
||||
"staging_ms", stagingDuration.Milliseconds(),
|
||||
"container_ms", containerDuration.Milliseconds(),
|
||||
"finalize_ms", finalizeDuration.Milliseconds(),
|
||||
"total_ms", totalDuration.Milliseconds(),
|
||||
"io_read_bytes", ioDelta.ReadBytes,
|
||||
"io_write_bytes", ioDelta.WriteBytes,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseDatasets extracts the dataset list from a task's raw argument string.
// The value is a comma-separated list; both the space-separated form
// ("--datasets a,b") and the equals form ("--datasets=a,b") are accepted.
// Returns nil when no datasets flag is present or the flag has no value.
func parseDatasets(args string) []string {
	// Cheap pre-check before tokenizing.
	if !strings.Contains(args, "--datasets") {
		return nil
	}

	parts := strings.Fields(args)
	for i, part := range parts {
		// Space-separated form: --datasets a,b,c
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
		// Equals form: --datasets=a,b,c (previously unrecognized).
		if strings.HasPrefix(part, "--datasets=") {
			if value := strings.TrimPrefix(part, "--datasets="); value != "" {
				return strings.Split(value, ",")
			}
		}
	}

	return nil
}
|
||||
|
||||
func (w *Worker) waitForTasks() {
|
||||
timeout := time.After(5 * time.Minute)
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
w.logger.Warn("shutdown timeout, force stopping",
|
||||
"running_tasks", len(w.running))
|
||||
return
|
||||
case <-ticker.C:
|
||||
count := w.runningCount()
|
||||
if count == 0 {
|
||||
w.logger.Info("all tasks completed, shutting down")
|
||||
return
|
||||
}
|
||||
w.logger.Debug("waiting for tasks to complete",
|
||||
"remaining", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) runningCount() int {
|
||||
w.runningMu.RLock()
|
||||
defer w.runningMu.RUnlock()
|
||||
return len(w.running)
|
||||
}
|
||||
|
||||
func (w *Worker) datasetIsFresh(dataset string) bool {
|
||||
w.datasetCacheMu.RLock()
|
||||
defer w.datasetCacheMu.RUnlock()
|
||||
expires, ok := w.datasetCache[dataset]
|
||||
return ok && time.Now().Before(expires)
|
||||
}
|
||||
|
||||
func (w *Worker) markDatasetFetched(dataset string) {
|
||||
expires := time.Now().Add(w.datasetCacheTTL)
|
||||
w.datasetCacheMu.Lock()
|
||||
w.datasetCache[dataset] = expires
|
||||
w.datasetCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Worker) GetMetrics() map[string]any {
|
||||
stats := w.metrics.GetStats()
|
||||
stats["worker_id"] = w.id
|
||||
stats["max_workers"] = w.config.MaxWorkers
|
||||
return stats
|
||||
}
|
||||
|
||||
// Stop force-stops the worker: it cancels the root context (unblocking the
// Start loop and the liveness heartbeat), waits — bounded by waitForTasks'
// internal deadline — for in-flight tasks, then tears down the server and
// queue connections and the metrics exporter. The teardown order matters:
// tasks drain before their transports are closed.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	// FIXED: Check error return values
	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}
	if w.metricsSrv != nil {
		// Give in-flight scrapes up to 5s to complete.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}
	w.logger.Info("worker stopped", "worker_id", w.id)
}
|
||||
|
||||
// Execute task with lease management and retry:
|
||||
func (w *Worker) executeTaskWithLease(task *queue.Task) {
|
||||
// Track task for graceful shutdown
|
||||
w.gracefulWait.Add(1)
|
||||
w.activeTasks.Store(task.ID, task)
|
||||
defer w.gracefulWait.Done()
|
||||
defer w.activeTasks.Delete(task.ID)
|
||||
|
||||
// Create task-specific context with timeout
|
||||
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
|
||||
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
|
||||
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
|
||||
|
||||
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
|
||||
defer taskCancel()
|
||||
|
||||
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
|
||||
logger.Info("starting task",
|
||||
"worker_id", w.id,
|
||||
"datasets", task.Datasets,
|
||||
"priority", task.Priority)
|
||||
|
||||
// Record task start
|
||||
w.metrics.RecordTaskStart()
|
||||
defer w.metrics.RecordTaskCompletion()
|
||||
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-taskCtx.Done():
|
||||
logger.Info("task cancelled before execution")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Parse datasets from task arguments
|
||||
if task.Datasets == nil {
|
||||
task.Datasets = parseDatasets(task.Args)
|
||||
}
|
||||
|
||||
// Start heartbeat goroutine
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer cancelHeartbeat()
|
||||
|
||||
go w.heartbeatLoop(heartbeatCtx, task.ID)
|
||||
|
||||
// Update task status
|
||||
task.Status = "running"
|
||||
now := time.Now()
|
||||
task.StartedAt = &now
|
||||
task.WorkerID = w.id
|
||||
|
||||
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
|
||||
logger.Error("failed to update task status", "error", err)
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
|
||||
if w.config.AutoFetchData && len(task.Datasets) > 0 {
|
||||
if err := w.fetchDatasets(taskCtx, task); err != nil {
|
||||
logger.Error("data fetch failed", "error", err)
|
||||
task.Status = "failed"
|
||||
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
err := w.queue.UpdateTask(task)
|
||||
if err != nil {
|
||||
logger.Error("failed to update task status after data fetch failure", "error", err)
|
||||
}
|
||||
w.metrics.RecordTaskFailure()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Execute job with panic recovery
|
||||
var execErr error
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
execErr = fmt.Errorf("panic during execution: %v", r)
|
||||
}
|
||||
}()
|
||||
execErr = w.runJob(task)
|
||||
}()
|
||||
|
||||
// Finalize task
|
||||
endTime := time.Now()
|
||||
task.EndedAt = &endTime
|
||||
|
||||
if execErr != nil {
|
||||
task.Error = execErr.Error()
|
||||
|
||||
// Check if transient error (network, timeout, etc)
|
||||
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
|
||||
w.logger.Warn("task failed with transient error, will retry",
|
||||
"task_id", task.ID,
|
||||
"error", execErr,
|
||||
"retry_count", task.RetryCount)
|
||||
w.queue.RetryTask(task)
|
||||
} else {
|
||||
task.Status = "failed"
|
||||
w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
w.queue.UpdateTaskWithMetrics(task, "final")
|
||||
}
|
||||
|
||||
// Release lease
|
||||
w.queue.ReleaseLease(task.ID, w.config.WorkerID)
|
||||
}
|
||||
|
||||
// Heartbeat loop to renew lease:
|
||||
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
|
||||
ticker := time.NewTicker(w.config.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
|
||||
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
|
||||
return
|
||||
}
|
||||
// Also update worker heartbeat
|
||||
w.queue.Heartbeat(w.config.WorkerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Graceful shutdown:
|
||||
func (w *Worker) Shutdown() error {
|
||||
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
|
||||
|
||||
// Wait for active tasks with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.gracefulWait.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
timeout := time.After(w.config.GracefulTimeout)
|
||||
select {
|
||||
case <-done:
|
||||
w.logger.Info("all tasks completed, shutdown successful")
|
||||
case <-timeout:
|
||||
w.logger.Warn("graceful shutdown timeout, releasing active leases")
|
||||
w.releaseAllLeases()
|
||||
}
|
||||
|
||||
return w.queue.Close()
|
||||
}
|
||||
|
||||
// Release all active leases:
|
||||
func (w *Worker) releaseAllLeases() {
|
||||
w.activeTasks.Range(func(key, value interface{}) bool {
|
||||
taskID := key.(string)
|
||||
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
|
||||
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions:
|
||||
func (w *Worker) countActiveTasks() int {
|
||||
count := 0
|
||||
w.activeTasks.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
func isTransientError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Check if error is transient (network, timeout, resource unavailable, etc)
|
||||
errStr := err.Error()
|
||||
transientIndicators := []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"resource temporarily unavailable",
|
||||
"no such host",
|
||||
"network unreachable",
|
||||
}
|
||||
for _, indicator := range transientIndicators {
|
||||
if strings.Contains(strings.ToLower(errStr), indicator) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// main boots the worker binary: parse and validate auth flags, resolve and
// load configuration, optionally authenticate against the configured API
// keys, start the worker, and run until SIGINT/SIGTERM triggers a graceful
// shutdown (falling back to a forced Stop on error).
func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	// Parse authentication flags
	authFlags := auth.ParseAuthFlags()
	if err := auth.ValidateAuthFlags(authFlags); err != nil {
		log.Fatalf("Authentication flag error: %v", err)
	}

	// Get API key from various sources
	apiKey := auth.GetAPIKeyFromSources(authFlags)

	// Load configuration; --config overrides the local default.
	configPath := "config-local.yaml"
	if authFlags.ConfigFile != "" {
		configPath = authFlags.ConfigFile
	}

	resolvedConfig, err := config.ResolveConfigPath(configPath)
	if err != nil {
		log.Fatalf("%v", err)
	}

	cfg, err := LoadConfig(resolvedConfig)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Validate authentication configuration
	if err := cfg.Auth.ValidateAuthConfig(); err != nil {
		log.Fatalf("Invalid authentication configuration: %v", err)
	}

	// Validate configuration
	if err := cfg.Validate(); err != nil {
		log.Fatalf("Invalid configuration: %v", err)
	}

	// Test authentication if enabled; a missing key with auth on is fatal.
	if cfg.Auth.Enabled && apiKey != "" {
		user, err := cfg.Auth.ValidateAPIKey(apiKey)
		if err != nil {
			log.Fatalf("Authentication failed: %v", err)
		}
		log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin)
	} else if cfg.Auth.Enabled {
		log.Fatal("Authentication required but no API key provided")
	}

	worker, err := NewWorker(cfg, apiKey)
	if err != nil {
		log.Fatalf("Failed to create worker: %v", err)
	}

	// Block on OS signals while the worker loop runs in the background.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	go worker.Start()

	sig := <-sigChan
	log.Printf("Received signal: %v", sig)

	// Use graceful shutdown
	if err := worker.Shutdown(); err != nil {
		log.Printf("Graceful shutdown error: %v", err)
		worker.Stop() // Fallback to force stop
	} else {
		log.Println("Worker shut down gracefully")
	}
}
|
||||
78
examples/auth_integration_example.go
Normal file
78
examples/auth_integration_example.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Example: How to integrate auth into TUI startup.
//
// checkAuth loads the auth section of the given YAML config file and, when
// authentication is enabled, resolves an API key (FETCH_ML_API_KEY env var
// first, interactive prompt as fallback) and validates it. Returns nil when
// auth is disabled or the key validates; otherwise an error describing what
// failed (read, parse, or validation).
func checkAuth(configFile string) error {
	// Load config
	data, err := os.ReadFile(configFile)
	if err != nil {
		return fmt.Errorf("failed to read config: %w", err)
	}

	// Only the auth subsection matters here; other keys are ignored.
	var cfg struct {
		Auth auth.AuthConfig `yaml:"auth"`
	}

	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("failed to parse config: %w", err)
	}

	// If auth disabled, proceed normally
	if !cfg.Auth.Enabled {
		fmt.Println("🔓 Authentication disabled - proceeding normally")
		return nil
	}

	// Check for API key: environment first, then interactive prompt.
	apiKey := os.Getenv("FETCH_ML_API_KEY")
	if apiKey == "" {
		apiKey = getAPIKeyFromUser()
	}

	// Validate API key
	user, err := cfg.Auth.ValidateAPIKey(apiKey)
	if err != nil {
		return fmt.Errorf("authentication failed: %w", err)
	}

	fmt.Printf("🔐 Authenticated as: %s", user.Name)
	if user.Admin {
		fmt.Println(" (admin)")
	} else {
		fmt.Println()
	}

	return nil
}
|
||||
|
||||
// getAPIKeyFromUser prompts on stdout and reads one whitespace-delimited
// token from stdin. A failed or empty read yields "" — validation upstream
// rejects it — so Scanln's error is deliberately ignored.
func getAPIKeyFromUser() string {
	fmt.Print("🔑 Enter API key: ")
	var apiKey string
	fmt.Scanln(&apiKey)
	return apiKey
}
|
||||
|
||||
// Example usage in main()
|
||||
func exampleMain() {
|
||||
configFile := "config_dev.yaml"
|
||||
|
||||
// Check authentication first
|
||||
if err := checkAuth(configFile); err != nil {
|
||||
log.Fatalf("Authentication failed: %v", err)
|
||||
}
|
||||
|
||||
// Proceed with normal TUI initialization
|
||||
fmt.Println("Starting TUI...")
|
||||
}
|
||||
|
||||
// main delegates to exampleMain so this integration example can be run
// directly with `go run`.
func main() {
	exampleMain()
}
|
||||
117
internal/api/permissions_test.go
Normal file
117
internal/api/permissions_test.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// TestUserPermissions verifies API-key validation and permission lookup
// against an in-memory AuthConfig covering three cases: an admin key, a
// non-admin key with explicit job permissions, and an invalid key.
func TestUserPermissions(t *testing.T) {
	authConfig := &auth.AuthConfig{
		Enabled: true,
		APIKeys: map[auth.Username]auth.APIKeyEntry{
			// Admin entry: no explicit permission map needed.
			"admin": {
				Hash:  auth.APIKeyHash(auth.HashAPIKey("admin_key")),
				Admin: true,
			},
			// Non-admin entry scoped to job CRUD (minus delete).
			"scientist": {
				Hash:  auth.APIKeyHash(auth.HashAPIKey("ds_key")),
				Admin: false,
				Permissions: map[string]bool{
					"jobs:create": true,
					"jobs:read":   true,
					"jobs:update": true,
				},
			},
		},
	}

	tests := []struct {
		name       string
		apiKey     string
		permission string
		want       bool
	}{
		{"Admin can create", "admin_key", "jobs:create", true},
		{"Scientist can create", "ds_key", "jobs:create", true},
		{"Invalid key fails", "invalid_key", "jobs:create", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			user, err := authConfig.ValidateAPIKey(tt.apiKey)

			// The invalid-key case asserts on the validation error and stops;
			// permissions are only checked for keys that validate.
			if tt.apiKey == "invalid_key" {
				if err == nil {
					t.Error("Expected error for invalid API key")
				}
				return
			}

			if err != nil {
				t.Errorf("Unexpected error: %v", err)
				return
			}

			got := user.HasPermission(tt.permission)
			if got != tt.want {
				t.Errorf("HasPermission() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestTaskOwnership checks the task visibility rule used by the API layer:
// a non-admin user may access a task only if they own it (UserID) or created
// it (CreatedBy); an admin may access any task.
func TestTaskOwnership(t *testing.T) {
	tasks := []*queue.Task{
		{
			ID:        "task1",
			JobName:   "user1_job",
			UserID:    "user1",
			CreatedBy: "user1",
			CreatedAt: time.Now(),
		},
		{
			ID:        "task2",
			JobName:   "user2_job",
			UserID:    "user2",
			CreatedBy: "user2",
			CreatedAt: time.Now(),
		},
	}

	users := map[string]*auth.User{
		"user1": {Name: "user1", Admin: false},
		"user2": {Name: "user2", Admin: false},
		"admin": {Name: "admin", Admin: true},
	}

	tests := []struct {
		name     string
		userName string
		task     *queue.Task
		want     bool
	}{
		{"User can view own task", "user1", tasks[0], true},
		{"User cannot view other task", "user1", tasks[1], false},
		{"Admin can view any task", "admin", tasks[1], true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			user := users[tt.userName]

			// Mirror of the server-side access rule: admin, owner, or creator.
			canAccess := false
			if user.Admin {
				canAccess = true
			} else if tt.task.UserID == user.Name || tt.task.CreatedBy == user.Name {
				canAccess = true
			}

			if canAccess != tt.want {
				t.Errorf("Access = %v, want %v", canAccess, tt.want)
			}
		})
	}
}
|
||||
305
internal/api/protocol.go
Normal file
305
internal/api/protocol.go
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Response packet types — first byte of every serialized packet; selects
// which field group of ResponsePacket is populated.
const (
	PacketTypeSuccess  = 0x00
	PacketTypeError    = 0x01
	PacketTypeProgress = 0x02
	PacketTypeStatus   = 0x03
	PacketTypeData     = 0x04
	PacketTypeLog      = 0x05
)

// Error codes carried by PacketTypeError packets, grouped by range:
// 0x00–0x0F request/auth, 0x10–0x1F infrastructure, 0x20–0x2F job
// lifecycle, 0x30–0x3F resource/configuration.
const (
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	ErrorCodeServerOverloaded = 0x10
	ErrorCodeDatabaseError    = 0x11
	ErrorCodeNetworkError     = 0x12
	ErrorCodeStorageError     = 0x13
	ErrorCodeTimeout          = 0x14

	ErrorCodeJobNotFound        = 0x20
	ErrorCodeJobAlreadyRunning  = 0x21
	ErrorCodeJobFailedToStart   = 0x22
	ErrorCodeJobExecutionFailed = 0x23
	ErrorCodeJobCancelled       = 0x24

	ErrorCodeOutOfMemory          = 0x30
	ErrorCodeDiskFull             = 0x31
	ErrorCodeInvalidConfiguration = 0x32
	ErrorCodeServiceUnavailable   = 0x33
)

// Progress types — interpretation of ProgressValue/ProgressTotal in a
// PacketTypeProgress packet.
const (
	ProgressTypePercentage       = 0x00
	ProgressTypeStage            = 0x01
	ProgressTypeMessage          = 0x02
	ProgressTypeBytesTransferred = 0x03
)

// Log levels for PacketTypeLog packets, ordered by increasing severity.
const (
	LogLevelDebug = 0x00
	LogLevelInfo  = 0x01
	LogLevelWarn  = 0x02
	LogLevelError = 0x03
)

// ResponsePacket represents a structured response packet.
//
// It is a tagged union: PacketType selects which of the field groups below
// is meaningful, and Serialize emits only that group. Timestamp is seconds
// since the Unix epoch, set at construction by the New*Packet helpers.
type ResponsePacket struct {
	PacketType byte
	Timestamp  uint64

	// Success fields (PacketTypeSuccess)
	SuccessMessage string

	// Error fields (PacketTypeError); ErrorCode is one of the ErrorCode*
	// constants above.
	ErrorCode    byte
	ErrorMessage string
	ErrorDetails string

	// Progress fields (PacketTypeProgress); ProgressType is one of the
	// ProgressType* constants above.
	ProgressType    byte
	ProgressValue   uint32
	ProgressTotal   uint32
	ProgressMessage string

	// Status fields (PacketTypeStatus)
	StatusData string

	// Data fields (PacketTypeData)
	DataType    string
	DataPayload []byte

	// Log fields (PacketTypeLog); LogLevel is one of the LogLevel*
	// constants above.
	LogLevel   byte
	LogMessage string
}
|
||||
|
||||
// NewSuccessPacket creates a success response packet
|
||||
func NewSuccessPacket(message string) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeSuccess,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
SuccessMessage: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSuccessPacketWithPayload creates a success response packet with JSON payload
|
||||
func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponsePacket {
|
||||
// Convert payload to JSON for the DataPayload field
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeData,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
SuccessMessage: message,
|
||||
DataType: "status",
|
||||
DataPayload: payloadBytes,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorPacket creates an error response packet
|
||||
func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeError,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
ErrorCode: errorCode,
|
||||
ErrorMessage: message,
|
||||
ErrorDetails: details,
|
||||
}
|
||||
}
|
||||
|
||||
// NewProgressPacket creates a progress response packet
|
||||
func NewProgressPacket(progressType byte, value uint32, total uint32, message string) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeProgress,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
ProgressType: progressType,
|
||||
ProgressValue: value,
|
||||
ProgressTotal: total,
|
||||
ProgressMessage: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewStatusPacket creates a status response packet
|
||||
func NewStatusPacket(data string) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeStatus,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
StatusData: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDataPacket creates a data response packet
|
||||
func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeData,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
DataType: dataType,
|
||||
DataPayload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
// NewLogPacket creates a log response packet
|
||||
func NewLogPacket(level byte, message string) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeLog,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
LogLevel: level,
|
||||
LogMessage: message,
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize converts the packet to binary format.
//
// Wire layout: 1 byte packet type, 8 bytes big-endian Unix timestamp, then
// type-specific fields in declaration order. Strings are emitted with a
// 2-byte big-endian length prefix (serializeString) and byte payloads with a
// 4-byte prefix (serializeBytes). Returns an error only for an unknown
// PacketType.
func (p *ResponsePacket) Serialize() ([]byte, error) {
	var buf []byte

	// Packet type
	buf = append(buf, p.PacketType)

	// Timestamp (8 bytes, big-endian)
	timestampBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(timestampBytes, p.Timestamp)
	buf = append(buf, timestampBytes...)

	// Packet-specific data: only the field group selected by PacketType is
	// written; the other groups are ignored even if populated.
	switch p.PacketType {
	case PacketTypeSuccess:
		buf = append(buf, serializeString(p.SuccessMessage)...)

	case PacketTypeError:
		buf = append(buf, p.ErrorCode)
		buf = append(buf, serializeString(p.ErrorMessage)...)
		buf = append(buf, serializeString(p.ErrorDetails)...)

	case PacketTypeProgress:
		buf = append(buf, p.ProgressType)
		valueBytes := make([]byte, 4)
		binary.BigEndian.PutUint32(valueBytes, p.ProgressValue)
		buf = append(buf, valueBytes...)

		totalBytes := make([]byte, 4)
		binary.BigEndian.PutUint32(totalBytes, p.ProgressTotal)
		buf = append(buf, totalBytes...)

		buf = append(buf, serializeString(p.ProgressMessage)...)

	case PacketTypeStatus:
		buf = append(buf, serializeString(p.StatusData)...)

	case PacketTypeData:
		buf = append(buf, serializeString(p.DataType)...)
		buf = append(buf, serializeBytes(p.DataPayload)...)

	case PacketTypeLog:
		buf = append(buf, p.LogLevel)
		buf = append(buf, serializeString(p.LogMessage)...)

	default:
		return nil, fmt.Errorf("unknown packet type: %d", p.PacketType)
	}

	return buf, nil
}
|
||||
|
||||
// serializeString encodes s as a 2-byte big-endian length prefix followed by
// the raw bytes. The wire format can describe at most 65535 bytes; longer
// input is truncated to that limit so the prefix always matches the data that
// follows. (Previously the length was cast with uint16(len(s)), which wrapped
// around for oversized strings while the full string was still copied,
// corrupting the byte stream for every packet after it.)
func serializeString(s string) []byte {
	const maxLen = 1<<16 - 1
	if len(s) > maxLen {
		s = s[:maxLen]
	}
	buf := make([]byte, 2+len(s))
	binary.BigEndian.PutUint16(buf[:2], uint16(len(s)))
	copy(buf[2:], s)
	return buf
}
|
||||
|
||||
// serializeBytes encodes b as a 4-byte big-endian length prefix followed by
// the raw bytes.
func serializeBytes(b []byte) []byte {
	out := make([]byte, 4, 4+len(b))
	binary.BigEndian.PutUint32(out, uint32(len(b)))
	return append(out, b...)
}
|
||||
|
||||
// GetErrorMessage returns a human-readable error message for an error code
|
||||
func GetErrorMessage(code byte) string {
|
||||
switch code {
|
||||
case ErrorCodeUnknownError:
|
||||
return "Unknown error occurred"
|
||||
case ErrorCodeInvalidRequest:
|
||||
return "Invalid request format"
|
||||
case ErrorCodeAuthenticationFailed:
|
||||
return "Authentication failed"
|
||||
case ErrorCodePermissionDenied:
|
||||
return "Permission denied"
|
||||
case ErrorCodeResourceNotFound:
|
||||
return "Resource not found"
|
||||
case ErrorCodeResourceAlreadyExists:
|
||||
return "Resource already exists"
|
||||
|
||||
case ErrorCodeServerOverloaded:
|
||||
return "Server is overloaded"
|
||||
case ErrorCodeDatabaseError:
|
||||
return "Database error occurred"
|
||||
case ErrorCodeNetworkError:
|
||||
return "Network error occurred"
|
||||
case ErrorCodeStorageError:
|
||||
return "Storage error occurred"
|
||||
case ErrorCodeTimeout:
|
||||
return "Operation timed out"
|
||||
|
||||
case ErrorCodeJobNotFound:
|
||||
return "Job not found"
|
||||
case ErrorCodeJobAlreadyRunning:
|
||||
return "Job is already running"
|
||||
case ErrorCodeJobFailedToStart:
|
||||
return "Job failed to start"
|
||||
case ErrorCodeJobExecutionFailed:
|
||||
return "Job execution failed"
|
||||
case ErrorCodeJobCancelled:
|
||||
return "Job was cancelled"
|
||||
|
||||
case ErrorCodeOutOfMemory:
|
||||
return "Server out of memory"
|
||||
case ErrorCodeDiskFull:
|
||||
return "Server disk full"
|
||||
case ErrorCodeInvalidConfiguration:
|
||||
return "Invalid server configuration"
|
||||
case ErrorCodeServiceUnavailable:
|
||||
return "Service temporarily unavailable"
|
||||
|
||||
default:
|
||||
return "Unknown error code"
|
||||
}
|
||||
}
|
||||
|
||||
// GetLogLevelName returns the name for a log level
|
||||
func GetLogLevelName(level byte) string {
|
||||
switch level {
|
||||
case LogLevelDebug:
|
||||
return "DEBUG"
|
||||
case LogLevelInfo:
|
||||
return "INFO"
|
||||
case LogLevelWarn:
|
||||
return "WARN"
|
||||
case LogLevelError:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
606
internal/api/ws.go
Normal file
606
internal/api/ws.go
Normal file
|
|
@ -0,0 +1,606 @@
|
|||
package api
|
||||
|
||||
import (
	"crypto/sha256"
	"crypto/subtle"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"math"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"golang.org/x/crypto/acme/autocert"
)
|
||||
|
||||
// Opcodes for binary WebSocket protocol. Each incoming binary message
// starts with one of these bytes; handleMessage dispatches on it.
const (
	OpcodeQueueJob      = 0x01 // enqueue a job (handleQueueJob)
	OpcodeStatusRequest = 0x02 // fetch user/queue status (handleStatusRequest)
	OpcodeCancelJob     = 0x03 // cancel a task by job name (handleCancelJob)
	OpcodePrune         = 0x04 // prune experiments (handlePrune)
	OpcodeLogMetric     = 0x0A // record a metric sample (handleLogMetric)
	OpcodeGetExperiment = 0x0B // fetch experiment metadata + metrics (handleGetExperiment)
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// Allow localhost and homelab origins for development
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // Allow same-origin requests
|
||||
}
|
||||
|
||||
// Parse origin URL
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Allow localhost and local network origins
|
||||
host := parsedOrigin.Host
|
||||
return strings.HasSuffix(host, ":8080") ||
|
||||
strings.HasSuffix(host, ":8081") ||
|
||||
strings.HasPrefix(host, "localhost") ||
|
||||
strings.HasPrefix(host, "127.0.0.1") ||
|
||||
strings.HasPrefix(host, "192.168.") ||
|
||||
strings.HasPrefix(host, "10.") ||
|
||||
strings.HasPrefix(host, "172.")
|
||||
},
|
||||
}
|
||||
|
||||
// WSHandler serves the binary WebSocket protocol used by the CLI. It
// authenticates clients, dispatches opcodes to handler methods, and bridges
// requests to the experiment manager and the task queue.
type WSHandler struct {
	authConfig *auth.AuthConfig    // API-key auth settings; nil or Enabled=false disables auth checks
	logger     *logging.Logger     // structured logger for connection/request events
	expManager *experiment.Manager // experiment directory, metadata, and metrics store
	queue      *queue.TaskQueue    // Redis-backed task queue; may be nil (enqueue becomes a no-op warn)
}
|
||||
|
||||
func NewWSHandler(authConfig *auth.AuthConfig, logger *logging.Logger, expManager *experiment.Manager, taskQueue *queue.TaskQueue) *WSHandler {
|
||||
return &WSHandler{
|
||||
authConfig: authConfig,
|
||||
logger: logger,
|
||||
expManager: expManager,
|
||||
queue: taskQueue,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP upgrades an HTTP request to a WebSocket connection and runs the
// binary message loop until the client disconnects.
//
// Authentication happens BEFORE the upgrade: the API key is read from the
// X-API-Key header, falling back to "Authorization: Bearer <key>", and
// rejected requests never reach the WebSocket layer.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Check API key before upgrading WebSocket
	apiKey := r.Header.Get("X-API-Key")
	if apiKey == "" {
		// Also check Authorization header
		authHeader := r.Header.Get("Authorization")
		if strings.HasPrefix(authHeader, "Bearer ") {
			apiKey = strings.TrimPrefix(authHeader, "Bearer ")
		}
	}

	// Validate API key if authentication is enabled
	if h.authConfig != nil && h.authConfig.Enabled {
		if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	defer conn.Close()

	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)

	// Message loop: one binary request frame per iteration, handled
	// synchronously. Exits on any read error (including client close).
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Normal closes (going away / abnormal closure) are expected;
			// only log anything else as an error.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}

		// The protocol is binary-only; silently skip other frame types.
		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}

		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send error response. Best-effort: a write failure here will
			// surface as a read error on the next loop iteration.
			_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
		}
	}
}
|
||||
|
||||
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
||||
if len(message) < 1 {
|
||||
return fmt.Errorf("message too short")
|
||||
}
|
||||
|
||||
opcode := message[0]
|
||||
payload := message[1:]
|
||||
|
||||
switch opcode {
|
||||
case OpcodeQueueJob:
|
||||
return h.handleQueueJob(conn, payload)
|
||||
case OpcodeStatusRequest:
|
||||
return h.handleStatusRequest(conn, payload)
|
||||
case OpcodeCancelJob:
|
||||
return h.handleCancelJob(conn, payload)
|
||||
case OpcodePrune:
|
||||
return h.handlePrune(conn, payload)
|
||||
case OpcodeLogMetric:
|
||||
return h.handleLogMetric(conn, payload)
|
||||
case OpcodeGetExperiment:
|
||||
return h.handleGetExperiment(conn, payload)
|
||||
default:
|
||||
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
|
||||
if len(payload) < 130 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := string(payload[:64])
|
||||
commitID := string(payload[64:128])
|
||||
priority := int64(payload[128])
|
||||
jobNameLen := int(payload[129])
|
||||
|
||||
if len(payload) < 130+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
|
||||
jobName := string(payload[130 : 130+jobNameLen])
|
||||
|
||||
h.logger.Info("queue job request",
|
||||
"job", jobName,
|
||||
"priority", priority,
|
||||
"commit_id", commitID,
|
||||
)
|
||||
|
||||
// Validate API key and get user information
|
||||
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
|
||||
// Check user permissions
|
||||
if !h.authConfig.Enabled || user.HasPermission("jobs:create") {
|
||||
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
|
||||
} else {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
|
||||
}
|
||||
|
||||
// Create experiment directory and metadata
|
||||
if err := h.expManager.CreateExperiment(commitID); err != nil {
|
||||
h.logger.Error("failed to create experiment directory", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
|
||||
}
|
||||
|
||||
// Add user info to experiment metadata
|
||||
meta := &experiment.Metadata{
|
||||
CommitID: commitID,
|
||||
JobName: jobName,
|
||||
User: user.Name,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if err := h.expManager.WriteMetadata(meta); err != nil {
|
||||
h.logger.Error("failed to save experiment metadata", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
|
||||
|
||||
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
|
||||
|
||||
// Enqueue task if queue is available
|
||||
if h.queue != nil {
|
||||
taskID := uuid.New().String()
|
||||
task := &queue.Task{
|
||||
ID: taskID,
|
||||
JobName: jobName,
|
||||
Args: "", // TODO: Add args support
|
||||
Status: "queued",
|
||||
Priority: priority,
|
||||
CreatedAt: time.Now(),
|
||||
UserID: user.Name,
|
||||
Username: user.Name,
|
||||
CreatedBy: user.Name,
|
||||
Metadata: map[string]string{
|
||||
"commit_id": commitID,
|
||||
"user_id": user.Name,
|
||||
"username": user.Name,
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.queue.AddTask(task); err != nil {
|
||||
h.logger.Error("failed to enqueue task", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
|
||||
}
|
||||
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
|
||||
} else {
|
||||
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
|
||||
}
|
||||
|
||||
packetData, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize packet", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
||||
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64]
|
||||
if len(payload) < 64 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := string(payload[0:64])
|
||||
h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")
|
||||
|
||||
// Validate API key and get user information
|
||||
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
|
||||
// Check user permissions for viewing jobs
|
||||
if !h.authConfig.Enabled || user.HasPermission("jobs:read") {
|
||||
// Continue with status request
|
||||
} else {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
|
||||
}
|
||||
|
||||
// Get tasks with user filtering
|
||||
var tasks []*queue.Task
|
||||
if h.queue != nil {
|
||||
allTasks, err := h.queue.GetAllTasks()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get tasks", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
|
||||
}
|
||||
|
||||
// Filter tasks based on user permissions
|
||||
for _, task := range allTasks {
|
||||
// If auth is disabled or admin can see all tasks
|
||||
if !h.authConfig.Enabled || user.Admin {
|
||||
tasks = append(tasks, task)
|
||||
continue
|
||||
}
|
||||
|
||||
// Users can only see their own tasks
|
||||
if task.UserID == user.Name || task.CreatedBy == user.Name {
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build status response with user-specific data
|
||||
status := map[string]interface{}{
|
||||
"user": map[string]interface{}{
|
||||
"name": user.Name,
|
||||
"admin": user.Admin,
|
||||
"roles": user.Roles,
|
||||
},
|
||||
"tasks": map[string]interface{}{
|
||||
"total": len(tasks),
|
||||
"queued": countTasksByStatus(tasks, "queued"),
|
||||
"running": countTasksByStatus(tasks, "running"),
|
||||
"failed": countTasksByStatus(tasks, "failed"),
|
||||
"completed": countTasksByStatus(tasks, "completed"),
|
||||
},
|
||||
"queue": tasks, // Include filtered tasks
|
||||
}
|
||||
|
||||
packet := NewSuccessPacketWithPayload("Status retrieved", status)
|
||||
packetData, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize packet", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
||||
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
||||
}
|
||||
|
||||
// countTasksByStatus counts tasks by their status
|
||||
func countTasksByStatus(tasks []*queue.Task, status string) int {
|
||||
count := 0
|
||||
for _, task := range tasks {
|
||||
if task.Status == status {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
|
||||
if len(payload) < 65 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
|
||||
}
|
||||
|
||||
// Parse 64-byte hex API key hash
|
||||
apiKeyHash := string(payload[0:64])
|
||||
jobNameLen := int(payload[64])
|
||||
|
||||
if len(payload) < 65+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
|
||||
jobName := string(payload[65 : 65+jobNameLen])
|
||||
|
||||
h.logger.Info("cancel job request", "job", jobName)
|
||||
|
||||
// Validate API key and get user information
|
||||
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
|
||||
// Check user permissions for canceling jobs
|
||||
if !h.authConfig.Enabled || user.HasPermission("jobs:update") {
|
||||
// Continue with cancel request
|
||||
} else {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
|
||||
}
|
||||
|
||||
// Find the task and verify ownership
|
||||
if h.queue != nil {
|
||||
task, err := h.queue.GetTaskByName(jobName)
|
||||
if err != nil {
|
||||
h.logger.Error("task not found", "job", jobName, "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
|
||||
}
|
||||
|
||||
// Check if user can cancel this task (admin or owner)
|
||||
if !h.authConfig.Enabled || user.Admin || task.UserID == user.Name || task.CreatedBy == user.Name {
|
||||
// User can cancel the task
|
||||
} else {
|
||||
h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
|
||||
}
|
||||
|
||||
// Cancel the task
|
||||
if err := h.queue.CancelTask(task.ID); err != nil {
|
||||
h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
|
||||
} else {
|
||||
h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
|
||||
}
|
||||
|
||||
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
|
||||
packetData, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize packet", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
||||
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
||||
}
|
||||
|
||||
// handlePrune processes an OpcodePrune request.
//
// Payload layout: [api_key_hash:64][prune_type:1][value:4]
// prune_type 0 = "keep N" (value passed as keepCount), 1 = "older than N
// days" (value passed as olderThanDays); value is a big-endian uint32.
//
// NOTE(review): unlike handleQueueJob/handleCancelJob, this path only checks
// the raw key hash via verifyAPIKeyHash and has no per-permission gate —
// confirm pruning is intentionally available to any authenticated key.
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 69 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}

	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
	pruneType := payload[64]
	value := binary.BigEndian.Uint32(payload[65:69])

	h.logger.Info("prune request", "type", pruneType, "value", value)

	// Verify API key
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			h.logger.Error("api key verification failed", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}

	// Convert prune parameters: exactly one of keepCount / olderThanDays is
	// non-zero, selected by pruneType.
	var keepCount int
	var olderThanDays int

	switch pruneType {
	case 0:
		// keep N
		keepCount = int(value)
		olderThanDays = 0
	case 1:
		// older than days
		keepCount = 0
		olderThanDays = int(value)
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "")
	}

	// Perform pruning
	pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}

	h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)

	// Send structured success response
	packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))
	return h.sendResponsePacket(conn, packet)
}
|
||||
|
||||
// handleLogMetric processes an OpcodeLogMetric request.
//
// Payload layout: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
// step is a big-endian uint32; value is an IEEE-754 float64 carried as a
// big-endian uint64.
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 141 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	step := int(binary.BigEndian.Uint32(payload[128:132]))
	valueBits := binary.BigEndian.Uint64(payload[132:140])
	value := math.Float64frombits(valueBits)
	nameLen := int(payload[140])

	if len(payload) < 141+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
	}

	name := string(payload[141 : 141+nameLen])

	// Verify API key (raw hash comparison only; no permission check here)
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}

	// Persist the metric sample for the experiment identified by commitID.
	if err := h.expManager.LogMetric(commitID, name, value, step); err != nil {
		h.logger.Error("failed to log metric", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
	}

	return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}
|
||||
|
||||
// handleGetExperiment processes an OpcodeGetExperiment request.
//
// Payload layout: [api_key_hash:64][commit_id:64]
//
// Replies with the experiment's stored metadata and its metrics wrapped in
// a success packet.
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 128 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])

	// Verify API key (raw hash comparison only; no permission check here)
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}

	meta, err := h.expManager.ReadMetadata(commitID)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
	}

	metrics, err := h.expManager.GetMetrics(commitID)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
	}

	response := map[string]interface{}{
		"metadata": meta,
		"metrics":  metrics,
	}

	return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response))
}
|
||||
|
||||
// HashAPIKey returns the lowercase hex-encoded SHA-256 digest of apiKey —
// the hash format clients send on the wire in place of the raw key.
func HashAPIKey(apiKey string) string {
	sum := sha256.Sum256([]byte(apiKey))
	return hex.EncodeToString(sum[:])
}
|
||||
|
||||
// SetupTLSConfig creates TLS configuration for WebSocket server
|
||||
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
|
||||
var server *http.Server
|
||||
|
||||
if certFile != "" && keyFile != "" {
|
||||
// Use provided certificates
|
||||
server = &http.Server{
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
},
|
||||
},
|
||||
}
|
||||
} else if host != "" {
|
||||
// Use Let's Encrypt with autocert
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(host),
|
||||
Cache: autocert.DirCache("/var/www/.cache"),
|
||||
}
|
||||
|
||||
server = &http.Server{
|
||||
TLSConfig: certManager.TLSConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// verifyAPIKeyHash verifies the provided hex hash against stored API keys
|
||||
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
|
||||
if h.authConfig == nil || !h.authConfig.Enabled {
|
||||
return nil // No auth required
|
||||
}
|
||||
|
||||
// For now, just check if it's a valid 64-char hex string
|
||||
if len(hexHash) != 64 {
|
||||
return fmt.Errorf("invalid api key hash length")
|
||||
}
|
||||
|
||||
// Check against stored API keys
|
||||
for username, entry := range h.authConfig.APIKeys {
|
||||
if string(entry.Hash) == hexHash {
|
||||
_ = username // Username found but not needed for verification
|
||||
return nil // Valid API key found
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid api key")
|
||||
}
|
||||
|
||||
// sendErrorPacket sends an error response packet
|
||||
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
|
||||
packet := NewErrorPacket(errorCode, message, details)
|
||||
return h.sendResponsePacket(conn, packet)
|
||||
}
|
||||
|
||||
// sendResponsePacket sends a structured response packet
|
||||
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
|
||||
data, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize response packet", "error", err)
|
||||
// Fallback to simple error response
|
||||
return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
|
||||
}
|
||||
|
||||
return conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
// sendErrorResponse removed (unused)
|
||||
335
internal/api/ws_test.go
Normal file
335
internal/api/ws_test.go
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTestServer wires a WSHandler (auth disabled) into an httptest server
// backed by a miniredis-based task queue and a temp-dir experiment manager.
// Callers are responsible for closing the returned server, queue, and redis.
func setupTestServer(t *testing.T) (*httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis) {
	// Setup miniredis
	s, err := miniredis.Run()
	require.NoError(t, err)

	// Setup TaskQueue
	queueCfg := queue.Config{
		RedisAddr:            s.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	}
	tq, err := queue.NewTaskQueue(queueCfg)
	require.NoError(t, err)

	// Setup dependencies. Auth is disabled so the handlers skip every
	// permission gate and accept any 64-byte key hash.
	logger := logging.NewLogger(0, false)
	expManager := experiment.NewManager(t.TempDir())
	authCfg := &auth.AuthConfig{Enabled: false}

	// Create handler
	handler := NewWSHandler(authCfg, logger, expManager, tq)

	// Setup test server
	server := httptest.NewServer(handler)

	return server, tq, expManager, s
}
|
||||
|
||||
func connectWS(t *testing.T, serverURL string) *websocket.Conn {
|
||||
wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
|
||||
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
require.NoError(t, err)
|
||||
return ws
|
||||
}
|
||||
|
||||
// TestWSHandler_QueueJob drives a full OpcodeQueueJob round trip over a real
// WebSocket: it hand-builds the binary request frame, expects a success
// packet back, and verifies the task landed in the miniredis-backed queue.
func TestWSHandler_QueueJob(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare queue_job message
	// Protocol: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
	opcode := byte(OpcodeQueueJob)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	commitID := make([]byte, 64)
	copy(commitID, []byte(strings.Repeat("a", 64)))
	priority := byte(5)
	jobName := "test-job"
	jobNameLen := byte(len(jobName))

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, priority)
	msg = append(msg, jobNameLen)
	msg = append(msg, []byte(jobName)...)

	// Send message
	err := ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response (PacketTypeSuccess = 0x00)
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])

	// Verify task in Redis; short sleep gives the queue time to persist.
	time.Sleep(100 * time.Millisecond)
	task, err := tq.GetNextTask()
	require.NoError(t, err)
	require.NotNil(t, task)
	assert.Equal(t, jobName, task.JobName)
}
|
||||
|
||||
// TestWSHandler_StatusRequest seeds one queued task, sends an
// OpcodeStatusRequest frame, and expects a data packet (status + payload)
// in response.
func TestWSHandler_StatusRequest(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Add a task to queue
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-1",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
	}
	err := tq.AddTask(task)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare status_request message
	// Protocol: [opcode:1][api_key_hash:64]
	opcode := byte(OpcodeStatusRequest)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response (PacketTypeData = 0x04 for status with payload)
	assert.Equal(t, byte(PacketTypeData), resp[0])
}
|
||||
|
||||
// TestWSHandler_CancelJob seeds a queued task, sends an OpcodeCancelJob
// frame naming it, expects a success packet, and verifies the task's
// status flipped to "cancelled" in the queue.
func TestWSHandler_CancelJob(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Add a task to queue
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-to-cancel",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user", // Auth disabled so this matches any user
		CreatedBy: "user",
	}
	err := tq.AddTask(task)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare cancel_job message
	// Protocol: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var]
	opcode := byte(OpcodeCancelJob)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	jobName := "job-to-cancel"
	jobNameLen := byte(len(jobName))

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, jobNameLen)
	msg = append(msg, []byte(jobName)...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])

	// Verify task cancelled
	updatedTask, err := tq.GetTask("task-1")
	require.NoError(t, err)
	assert.Equal(t, "cancelled", updatedTask.Status)
}
|
||||
|
||||
// TestWSHandler_Prune creates two experiments and issues a prune request
// (type 0 = "keep N", N = 1), expecting a success packet.
//
// NOTE(review): only the response code is asserted; that the extra experiment
// was actually removed is not verified here.
func TestWSHandler_Prune(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create some experiments
	_ = expManager.CreateExperiment("commit-1")
	_ = expManager.CreateExperiment("commit-2")

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare prune message
	// Protocol: [opcode:1][api_key_hash:64][prune_type:1][value:4]
	opcode := byte(OpcodePrune)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	pruneType := byte(0) // Keep N
	value := uint32(1)   // Keep 1
	// The 4-byte value is big-endian on the wire.
	valueBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(valueBytes, value)

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, pruneType)
	msg = append(msg, valueBytes...)

	// Send message
	err := ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}
|
||||
|
||||
// TestWSHandler_LogMetric logs one metric sample (step 100, value 0.95,
// name "accuracy") against an existing experiment and expects a success
// packet. The float value travels as IEEE-754 bits in big-endian order.
func TestWSHandler_LogMetric(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create experiment — commit IDs are fixed-width 64-character strings.
	commitIDStr := strings.Repeat("a", 64)
	err := expManager.CreateExperiment(commitIDStr)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare log_metric message
	// Protocol: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
	opcode := byte(OpcodeLogMetric)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	commitID := []byte(commitIDStr)
	step := uint32(100)
	value := 0.95
	// Serialize the float64 via its raw bit pattern.
	valueBits := math.Float64bits(value)
	metricName := "accuracy"
	nameLen := byte(len(metricName))

	stepBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(stepBytes, step)
	valueBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(valueBytes, valueBits)

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, stepBytes...)
	msg = append(msg, valueBytes...)
	msg = append(msg, nameLen)
	msg = append(msg, []byte(metricName)...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}
|
||||
|
||||
// TestWSHandler_GetExperiment creates an experiment with metadata, requests
// it over the WebSocket by commit ID, and expects a data packet in response.
func TestWSHandler_GetExperiment(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create experiment and metadata
	commitIDStr := strings.Repeat("a", 64)
	err := expManager.CreateExperiment(commitIDStr)
	require.NoError(t, err)

	meta := &experiment.Metadata{
		CommitID: commitIDStr,
		JobName:  "test-job",
	}
	err = expManager.WriteMetadata(meta)
	require.NoError(t, err)

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Prepare get_experiment message
	// Protocol: [opcode:1][api_key_hash:64][commit_id:64]
	opcode := byte(OpcodeGetExperiment)
	apiKeyHash := make([]byte, 64)
	copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
	commitID := []byte(commitIDStr)

	var msg []byte
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)

	// Send message
	err = ws.WriteMessage(websocket.BinaryMessage, msg)
	require.NoError(t, err)

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Verify success response (PacketTypeData)
	assert.Equal(t, byte(PacketTypeData), resp[0])
}
|
||||
258
internal/auth/api_key.go
Normal file
258
internal/auth/api_key.go
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// User represents an authenticated user.
type User struct {
	Name  string   `json:"name"`
	Admin bool     `json:"admin"` // admins implicitly hold every permission ("*")
	Roles []string `json:"roles"` // role names, resolved via getRolePermissions
	// Permissions is the effective set: explicit grants merged with
	// role-derived grants at validation time.
	Permissions map[string]bool `json:"permissions"`
}

// APIKeyHash represents a SHA256 hash of an API key (hex-encoded, 64 chars).
type APIKeyHash string

// APIKeyEntry represents an API key configuration. Only the hash of the key
// is stored; the raw key never appears in config.
type APIKeyEntry struct {
	Hash        APIKeyHash      `json:"hash"`
	Admin       bool            `json:"admin"`
	Roles       []string        `json:"roles,omitempty"`
	Permissions map[string]bool `json:"permissions,omitempty"`
}

// Username represents a user identifier.
type Username string

// AuthConfig represents the authentication configuration.
type AuthConfig struct {
	Enabled bool                     `json:"enabled"` // when false, ValidateAPIKey returns a default admin
	APIKeys map[Username]APIKeyEntry `json:"api_keys"`
}

// AuthStore interface for different authentication backends.
type AuthStore interface {
	ValidateAPIKey(ctx context.Context, key string) (*User, error)
	CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error
	RevokeAPIKey(ctx context.Context, userID string) error
	// NOTE(review): UserInfo is not defined in this file, and
	// DatabaseAuthStore.ListUsers returns []APIKeyRecord instead — confirm
	// which return type this contract intends.
	ListUsers(ctx context.Context) ([]UserInfo, error)
}

// contextKey is the type for context keys (unexported to avoid collisions
// with keys from other packages).
type contextKey string

// userContextKey is where AuthMiddleware stores the authenticated *User.
const userContextKey = contextKey("user")
|
||||
|
||||
// ValidateAPIKey validates an API key and returns user information
|
||||
func (c *AuthConfig) ValidateAPIKey(key string) (*User, error) {
|
||||
if !c.Enabled {
|
||||
// Auth disabled - return default admin user for development
|
||||
return &User{Name: "default", Admin: true}, nil
|
||||
}
|
||||
|
||||
keyHash := HashAPIKey(key)
|
||||
|
||||
for username, entry := range c.APIKeys {
|
||||
if string(entry.Hash) == keyHash {
|
||||
// Build user with role and permission inheritance
|
||||
user := &User{
|
||||
Name: string(username),
|
||||
Admin: entry.Admin,
|
||||
Roles: entry.Roles,
|
||||
Permissions: make(map[string]bool),
|
||||
}
|
||||
|
||||
// Copy explicit permissions
|
||||
for perm, value := range entry.Permissions {
|
||||
user.Permissions[perm] = value
|
||||
}
|
||||
|
||||
// Add role-based permissions
|
||||
rolePerms := getRolePermissions(entry.Roles)
|
||||
for perm, value := range rolePerms {
|
||||
if _, exists := user.Permissions[perm]; !exists {
|
||||
user.Permissions[perm] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Admin gets all permissions
|
||||
if entry.Admin {
|
||||
user.Permissions["*"] = true
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid API key")
|
||||
}
|
||||
|
||||
// AuthMiddleware creates HTTP middleware for API key authentication.
//
// When authentication is disabled, requests are rejected unless BOTH
// FETCH_ML_ALLOW_INSECURE_AUTH=1 and FETCH_ML_DEBUG=1 are set, in which case
// a default admin user is injected (development only). Otherwise the key is
// read from the X-API-Key header — never from query parameters, which keeps
// keys out of access logs — and the validated user is attached to the
// request context under userContextKey.
func (c *AuthConfig) AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !c.Enabled {
			// Reject unless both insecure-bypass env vars are explicitly "1".
			if os.Getenv("FETCH_ML_ALLOW_INSECURE_AUTH") != "1" || os.Getenv("FETCH_ML_DEBUG") != "1" {
				http.Error(w, "Unauthorized: Authentication disabled", http.StatusUnauthorized)
				return
			}
			log.Println("WARNING: Insecure authentication bypass enabled: FETCH_ML_ALLOW_INSECURE_AUTH=1 and FETCH_ML_DEBUG=1; do NOT use this configuration in production.")
			ctx := context.WithValue(r.Context(), userContextKey, &User{Name: "default", Admin: true})
			next.ServeHTTP(w, r.WithContext(ctx))
			return
		}

		// Only accept API key from header - no query parameters for security
		apiKey := r.Header.Get("X-API-Key")
		if apiKey == "" {
			http.Error(w, "Unauthorized: Missing API key in X-API-Key header", http.StatusUnauthorized)
			return
		}

		user, err := c.ValidateAPIKey(apiKey)
		if err != nil {
			// Generic message: do not leak whether the key exists.
			http.Error(w, "Unauthorized: Invalid API key", http.StatusUnauthorized)
			return
		}

		// Add user to context for downstream handlers (GetUserFromContext).
		ctx := context.WithValue(r.Context(), userContextKey, user)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
|
||||
|
||||
// GetUserFromContext retrieves user from request context
|
||||
func GetUserFromContext(ctx context.Context) *User {
|
||||
if user, ok := ctx.Value(userContextKey).(*User); ok {
|
||||
return user
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequireAdmin creates middleware that requires admin privileges
|
||||
func RequireAdmin(next http.Handler) http.Handler {
|
||||
return RequirePermission("system:admin")(next)
|
||||
}
|
||||
|
||||
// RequirePermission creates middleware that requires specific permission
|
||||
func RequirePermission(permission string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasPermission(permission) {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HasPermission checks if user has a specific permission
|
||||
func (u *User) HasPermission(permission string) bool {
|
||||
// Wildcard permission grants all access
|
||||
if u.Permissions["*"] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Direct permission check
|
||||
if u.Permissions[permission] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Hierarchical permission check (e.g., "jobs:create" matches "jobs")
|
||||
parts := strings.Split(permission, ":")
|
||||
for i := 1; i < len(parts); i++ {
|
||||
partial := strings.Join(parts[:i], ":")
|
||||
if u.Permissions[partial] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HasRole checks if user has a specific role
|
||||
func (u *User) HasRole(role string) bool {
|
||||
for _, userRole := range u.Roles {
|
||||
if userRole == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getRolePermissions returns permissions for given roles
|
||||
func getRolePermissions(roles []string) map[string]bool {
|
||||
permissions := make(map[string]bool)
|
||||
|
||||
// Use YAML permission manager if available
|
||||
if pm := GetGlobalPermissionManager(); pm != nil && pm.loaded {
|
||||
for _, role := range roles {
|
||||
rolePerms := pm.GetRolePermissions(role)
|
||||
for perm, value := range rolePerms {
|
||||
permissions[perm] = value
|
||||
}
|
||||
}
|
||||
return permissions
|
||||
}
|
||||
|
||||
// Fallback to inline permissions
|
||||
rolePermissions := map[string]map[string]bool{
|
||||
"admin": {"*": true},
|
||||
"data_scientist": {
|
||||
"jobs:create": true, "jobs:read": true, "jobs:update": true,
|
||||
"data:read": true, "models:read": true,
|
||||
},
|
||||
"data_engineer": {
|
||||
"data:create": true, "data:read": true, "data:update": true, "data:delete": true,
|
||||
},
|
||||
"viewer": {
|
||||
"jobs:read": true, "data:read": true, "models:read": true, "metrics:read": true,
|
||||
},
|
||||
"operator": {
|
||||
"jobs:read": true, "jobs:update": true, "metrics:read": true, "system:read": true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
if rolePerms, exists := rolePermissions[role]; exists {
|
||||
for perm, value := range rolePerms {
|
||||
permissions[perm] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// GenerateAPIKey generates a new random API key: 32 bytes from crypto/rand,
// hex-encoded to a 64-character string.
//
// Fix: the previous fallback derived a key from the current time when
// crypto/rand failed, silently issuing a predictable (guessable) credential.
// Per the crypto/rand contract, Read does not fail on supported platforms;
// if it ever does, the process is in no state to mint credentials, so panic
// instead of returning a weak key.
func GenerateAPIKey() string {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("auth: crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(buf)
}
|
||||
|
||||
// HashAPIKey returns the hex-encoded SHA-256 digest of an API key. Only
// these digests — never raw keys — are stored in configuration or the
// database.
func HashAPIKey(key string) string {
	sum := sha256.Sum256([]byte(key))
	digest := sum[:]
	return hex.EncodeToString(digest)
}
|
||||
229
internal/auth/api_key_test.go
Normal file
229
internal/auth/api_key_test.go
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestHashAPIKey checks HashAPIKey against known SHA-256 golden values.
func TestHashAPIKey(t *testing.T) {
	tests := []struct {
		name     string
		key      string
		expected string // hex-encoded SHA-256 digest of key
	}{
		{
			name:     "known hash",
			key:      "password",
			expected: "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8",
		},
		{
			name:     "another known hash",
			key:      "test",
			expected: "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := HashAPIKey(tt.key)
			if got != tt.expected {
				t.Errorf("HashAPIKey() = %v, want %v", got, tt.expected)
			}
		})
	}
}
|
||||
|
||||
func TestHashAPIKeyConsistency(t *testing.T) {
|
||||
key := "my-secret-key"
|
||||
hash1 := HashAPIKey(key)
|
||||
hash2 := HashAPIKey(key)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("HashAPIKey() not consistent: %v != %v", hash1, hash2)
|
||||
}
|
||||
|
||||
if len(hash1) != 64 {
|
||||
t.Errorf("HashAPIKey() wrong length: got %d, want 64", len(hash1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAPIKey(t *testing.T) {
|
||||
// Test that it generates keys
|
||||
key1 := GenerateAPIKey()
|
||||
|
||||
if len(key1) != 64 {
|
||||
t.Errorf("GenerateAPIKey() length = %d, want 64", len(key1))
|
||||
}
|
||||
|
||||
// Test uniqueness (timing-based, should be different)
|
||||
key2 := GenerateAPIKey()
|
||||
|
||||
if key1 == key2 {
|
||||
t.Errorf("GenerateAPIKey() not unique: both generated %s", key1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHasPermission covers the three grant paths of HasPermission
// (wildcard, direct, hierarchical prefix) plus the denial case.
func TestUserHasPermission(t *testing.T) {
	tests := []struct {
		name       string
		user       *User
		permission string
		want       bool
	}{
		{
			name: "wildcard grants all",
			user: &User{
				Permissions: map[string]bool{"*": true},
			},
			permission: "anything",
			want:       true,
		},
		{
			name: "direct permission",
			user: &User{
				Permissions: map[string]bool{"jobs:create": true},
			},
			permission: "jobs:create",
			want:       true,
		},
		{
			// A grant on the "jobs" prefix covers "jobs:create".
			name: "hierarchical permission match",
			user: &User{
				Permissions: map[string]bool{"jobs": true},
			},
			permission: "jobs:create",
			want:       true,
		},
		{
			// A sibling grant ("jobs:read") must not imply "jobs:create".
			name: "no permission",
			user: &User{
				Permissions: map[string]bool{"jobs:read": true},
			},
			permission: "jobs:create",
			want:       false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.user.HasPermission(tt.permission)
			if got != tt.want {
				t.Errorf("HasPermission() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestUserHasRole checks role membership including the empty-list case.
func TestUserHasRole(t *testing.T) {
	tests := []struct {
		name string
		user *User
		role string
		want bool
	}{
		{
			name: "has role",
			user: &User{
				Roles: []string{"admin", "user"},
			},
			role: "admin",
			want: true,
		},
		{
			name: "does not have role",
			user: &User{
				Roles: []string{"user"},
			},
			role: "admin",
			want: false,
		},
		{
			name: "empty roles",
			user: &User{
				Roles: []string{},
			},
			role: "admin",
			want: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.user.HasRole(tt.role)
			if got != tt.want {
				t.Errorf("HasRole() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestAuthConfigValidateAPIKey validates file-config key lookup: a normal
// user key, an admin key, and a key that matches no configured hash.
func TestAuthConfigValidateAPIKey(t *testing.T) {
	config := &AuthConfig{
		Enabled: true,
		APIKeys: map[Username]APIKeyEntry{
			"testuser": {
				Hash:  APIKeyHash(HashAPIKey("test-key")),
				Admin: false,
				Roles: []string{"user"},
				Permissions: map[string]bool{
					"jobs:read": true,
				},
			},
			"admin": {
				Hash:  APIKeyHash(HashAPIKey("admin-key")),
				Admin: true,
			},
		},
	}

	tests := []struct {
		name      string
		key       string
		wantErr   bool
		wantAdmin bool
	}{
		{
			name:      "valid user key",
			key:       "test-key",
			wantErr:   false,
			wantAdmin: false,
		},
		{
			name:      "valid admin key",
			key:       "admin-key",
			wantErr:   false,
			wantAdmin: true,
		},
		{
			name:    "invalid key",
			key:     "wrong-key",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			user, err := config.ValidateAPIKey(tt.key)
			if (err != nil) != tt.wantErr {
				t.Errorf("ValidateAPIKey() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !tt.wantErr && user.Admin != tt.wantAdmin {
				t.Errorf("ValidateAPIKey() admin = %v, want %v", user.Admin, tt.wantAdmin)
			}
		})
	}
}
|
||||
|
||||
func TestAuthConfigDisabled(t *testing.T) {
|
||||
config := &AuthConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
user, err := config.ValidateAPIKey("any-key")
|
||||
if err != nil {
|
||||
t.Errorf("ValidateAPIKey() with auth disabled should not error: %v", err)
|
||||
}
|
||||
if !user.Admin {
|
||||
t.Error("ValidateAPIKey() with auth disabled should return admin user")
|
||||
}
|
||||
}
|
||||
210
internal/auth/database.go
Normal file
210
internal/auth/database.go
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// DatabaseAuthStore implements authentication using SQLite database.
type DatabaseAuthStore struct {
	db *sql.DB // SQLite handle; the schema is created by init()
}

// APIKeyRecord represents an API key in the database.
type APIKeyRecord struct {
	ID      int    `json:"id"`
	UserID  string `json:"user_id"`
	KeyHash string `json:"key_hash"` // SHA-256 hex digest; raw keys are never stored
	Admin   bool   `json:"admin"`
	Roles       string     `json:"roles"`       // JSON array
	Permissions string     `json:"permissions"` // JSON object
	CreatedAt   time.Time  `json:"created_at"`
	ExpiresAt   *time.Time `json:"expires_at,omitempty"` // nil = never expires
	RevokedAt   *time.Time `json:"revoked_at,omitempty"` // nil = still active
}
|
||||
|
||||
// NewDatabaseAuthStore creates a new database-backed auth store
|
||||
func NewDatabaseAuthStore(dbPath string) (*DatabaseAuthStore, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
store := &DatabaseAuthStore{db: db}
|
||||
if err := store.init(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// init creates the necessary database tables and indexes. All statements use
// IF NOT EXISTS, so repeated calls are idempotent.
func (s *DatabaseAuthStore) init() error {
	query := `
	CREATE TABLE IF NOT EXISTS api_keys (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		user_id TEXT NOT NULL UNIQUE,
		key_hash TEXT NOT NULL UNIQUE,
		admin BOOLEAN NOT NULL DEFAULT FALSE,
		roles TEXT NOT NULL DEFAULT '[]',
		permissions TEXT NOT NULL DEFAULT '{}',
		created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
		expires_at DATETIME,
		revoked_at DATETIME,
		CHECK (json_valid(roles)),
		CHECK (json_valid(permissions))
	);

	CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
	CREATE INDEX IF NOT EXISTS idx_api_keys_user ON api_keys(user_id);
	CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(revoked_at, COALESCE(expires_at, '9999-12-31'));
	`

	_, err := s.db.Exec(query)
	return err
}
|
||||
|
||||
// ValidateAPIKey checks if an API key is valid and returns user info.
//
// The key is hashed and looked up among non-revoked, non-expired records.
// Malformed roles/permissions JSON is logged and treated as empty rather
// than failing the login; admins are additionally granted the "*" wildcard.
func (s *DatabaseAuthStore) ValidateAPIKey(ctx context.Context, key string) (*User, error) {
	keyHash := HashAPIKey(key)

	// Revocation/expiry are filtered in SQL so stale keys never reach Go code.
	query := `
	SELECT user_id, admin, roles, permissions, expires_at, revoked_at
	FROM api_keys
	WHERE key_hash = ?
	AND (revoked_at IS NULL OR revoked_at > ?)
	AND (expires_at IS NULL OR expires_at > ?)
	`

	var userID string
	var admin bool
	var rolesJSON, permissionsJSON string
	var expiresAt, revokedAt sql.NullTime
	now := time.Now()

	err := s.db.QueryRowContext(ctx, query, keyHash, now, now).Scan(
		&userID, &admin, &rolesJSON, &permissionsJSON, &expiresAt, &revokedAt,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			// Deliberately generic: do not reveal whether the key exists.
			return nil, fmt.Errorf("invalid API key")
		}
		return nil, fmt.Errorf("database error: %w", err)
	}

	// Parse roles (stored as a JSON array string).
	var roles []string
	if err := json.Unmarshal([]byte(rolesJSON), &roles); err != nil {
		log.Printf("Failed to parse roles for user %s: %v", userID, err)
		roles = []string{}
	}

	// Parse permissions (stored as a JSON object string).
	var permissions map[string]bool
	if err := json.Unmarshal([]byte(permissionsJSON), &permissions); err != nil {
		log.Printf("Failed to parse permissions for user %s: %v", userID, err)
		permissions = make(map[string]bool)
	}

	// Admin gets all permissions
	if admin {
		permissions["*"] = true
	}

	return &User{
		Name:        userID,
		Admin:       admin,
		Roles:       roles,
		Permissions: permissions,
	}, nil
}
|
||||
|
||||
// CreateAPIKey creates a new API key in the database
|
||||
func (s *DatabaseAuthStore) CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error {
|
||||
rolesJSON, err := json.Marshal(roles)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal roles: %w", err)
|
||||
}
|
||||
|
||||
permissionsJSON, err := json.Marshal(permissions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal permissions: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO api_keys (user_id, key_hash, admin, roles, permissions, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
key_hash = excluded.key_hash,
|
||||
admin = excluded.admin,
|
||||
roles = excluded.roles,
|
||||
permissions = excluded.permissions,
|
||||
expires_at = excluded.expires_at,
|
||||
revoked_at = NULL
|
||||
`
|
||||
|
||||
_, err = s.db.ExecContext(ctx, query, userID, keyHash, admin, rolesJSON, permissionsJSON, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// RevokeAPIKey revokes an API key
|
||||
func (s *DatabaseAuthStore) RevokeAPIKey(ctx context.Context, userID string) error {
|
||||
query := `UPDATE api_keys SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?`
|
||||
_, err := s.db.ExecContext(ctx, query, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListUsers returns all active users
|
||||
func (s *DatabaseAuthStore) ListUsers(ctx context.Context) ([]APIKeyRecord, error) {
|
||||
query := `
|
||||
SELECT id, user_id, key_hash, admin, roles, permissions, created_at, expires_at, revoked_at
|
||||
FROM api_keys
|
||||
WHERE revoked_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query users: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []APIKeyRecord
|
||||
for rows.Next() {
|
||||
var user APIKeyRecord
|
||||
err := rows.Scan(
|
||||
&user.ID, &user.UserID, &user.KeyHash, &user.Admin,
|
||||
&user.Roles, &user.Permissions, &user.CreatedAt,
|
||||
&user.ExpiresAt, &user.RevokedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan user: %w", err)
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CleanupExpiredKeys removes expired and revoked keys
|
||||
func (s *DatabaseAuthStore) CleanupExpiredKeys(ctx context.Context) error {
|
||||
query := `
|
||||
DELETE FROM api_keys
|
||||
WHERE (revoked_at IS NOT NULL AND revoked_at < datetime('now', '-30 days'))
|
||||
OR (expires_at IS NOT NULL AND expires_at < datetime('now', '-7 days'))
|
||||
`
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection pool; the store must not
// be used afterwards.
func (s *DatabaseAuthStore) Close() error {
	return s.db.Close()
}
|
||||
122
internal/auth/flags.go
Normal file
122
internal/auth/flags.go
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AuthFlags holds authentication-related command line flags.
type AuthFlags struct {
	APIKey     string // --api-key; raw key (may be visible in process lists)
	APIKeyFile string // --api-key-file; path to a file containing the key
	ConfigFile string // --config; configuration file path
	EnableAuth bool   // --enable-auth
	ShowHelp   bool   // --auth-help; ValidateAuthFlags prints help and exits
}
|
||||
|
||||
// ParseAuthFlags parses authentication command line flags from os.Args.
//
// NOTE(review): flag.Usage is replaced with a no-op, which also silences the
// default -h/--help output for the entire program — confirm this is
// intended. Calling this function twice panics (duplicate flag
// registration on the default FlagSet).
func ParseAuthFlags() *AuthFlags {
	flags := &AuthFlags{}

	flag.StringVar(&flags.APIKey, "api-key", "", "API key for authentication")
	flag.StringVar(&flags.APIKeyFile, "api-key-file", "", "Path to file containing API key")
	flag.StringVar(&flags.ConfigFile, "config", "", "Configuration file path")
	flag.BoolVar(&flags.EnableAuth, "enable-auth", false, "Enable authentication")
	flag.BoolVar(&flags.ShowHelp, "auth-help", false, "Show authentication help")

	// Custom help flag that doesn't exit
	flag.Usage = func() {}

	flag.Parse()

	return flags
}
|
||||
|
||||
// GetAPIKeyFromSources gets API key from multiple sources in priority order
|
||||
func GetAPIKeyFromSources(flags *AuthFlags) string {
|
||||
// 1. Command line flag (highest priority)
|
||||
if flags.APIKey != "" {
|
||||
return flags.APIKey
|
||||
}
|
||||
|
||||
// 2. Explicit file flag
|
||||
if flags.APIKeyFile != "" {
|
||||
contents, readErr := os.ReadFile(flags.APIKeyFile)
|
||||
if readErr == nil {
|
||||
return strings.TrimSpace(string(contents))
|
||||
}
|
||||
log.Printf("Warning: Could not read API key file %s: %v", flags.APIKeyFile, readErr)
|
||||
}
|
||||
|
||||
// 3. Environment variable
|
||||
if envKey := os.Getenv("FETCH_ML_API_KEY"); envKey != "" {
|
||||
return envKey
|
||||
}
|
||||
|
||||
// 4. File-based key (for automated scripts)
|
||||
if fileKey := os.Getenv("FETCH_ML_API_KEY_FILE"); fileKey != "" {
|
||||
content, err := os.ReadFile(fileKey)
|
||||
if err == nil {
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
log.Printf("Warning: Could not read API key file %s: %v", fileKey, err)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ValidateAuthFlags validates parsed authentication flags
|
||||
func ValidateAuthFlags(flags *AuthFlags) error {
|
||||
if flags.ShowHelp {
|
||||
PrintAuthHelp()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if flags.APIKeyFile != "" {
|
||||
if _, err := os.Stat(flags.APIKeyFile); err != nil {
|
||||
return fmt.Errorf("api key file not found: %s", flags.APIKeyFile)
|
||||
}
|
||||
if err := CheckConfigFilePermissions(flags.APIKeyFile); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If config file is specified, check if it exists
|
||||
if flags.ConfigFile != "" {
|
||||
if _, err := os.Stat(flags.ConfigFile); err != nil {
|
||||
return fmt.Errorf("config file not found: %s", flags.ConfigFile)
|
||||
}
|
||||
|
||||
// Check file permissions
|
||||
if err := CheckConfigFilePermissions(flags.ConfigFile); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrintAuthHelp prints authentication-specific help text to stdout.
func PrintAuthHelp() {
	// Emitting line by line keeps the output identical to a sequence of
	// Println calls (each entry gets its own trailing newline).
	for _, line := range []string{
		"Authentication Options:",
		"  --api-key <key>        API key for authentication",
		"  --api-key-file <path>  Read API key from file",
		"  --config <file>        Configuration file path",
		"  --enable-auth          Enable authentication (if disabled)",
		"  --auth-help            Show this help",
		"",
		"Environment Variables:",
		"  FETCH_ML_API_KEY       API key for authentication",
		"  FETCH_ML_API_KEY_FILE  File containing API key",
		"  FETCH_ML_ENV           Environment (development/production)",
		"  FETCH_ML_ALLOW_INSECURE_AUTH  Allow insecure auth (dev only)",
		"",
		"Security Notes:",
		"  - API keys in command line may be visible in process lists",
		"  - Environment variables are preferred for automated scripts",
		"  - File-based keys should have restricted permissions (600)",
		"  - Authentication is mandatory in production environments",
	} {
		fmt.Println(line)
	}
}
|
||||
275
internal/auth/hybrid.go
Normal file
275
internal/auth/hybrid.go
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridAuthStore combines file-based and database authentication.
// Falls back to file config if database is not available.
type HybridAuthStore struct {
	fileStore *AuthConfig        // always present; used whenever the DB is unavailable
	dbStore   *DatabaseAuthStore // nil when no database was configured or it failed to open
	useDB     bool               // guarded by mu; true when dbStore is usable
	mu        sync.RWMutex       // protects useDB
}
|
||||
|
||||
// NewHybridAuthStore creates a hybrid auth store.
//
// When dbPath is non-empty it attempts to open the database store; on
// failure it logs and falls back to the file configuration. When the
// database is usable and file-based keys exist, they are migrated into the
// database (a migration failure is logged, not fatal).
func NewHybridAuthStore(config *AuthConfig, dbPath string) (*HybridAuthStore, error) {
	hybrid := &HybridAuthStore{
		fileStore: config,
		useDB:     false,
	}

	// Try to initialize database store
	if dbPath != "" {
		dbStore, err := NewDatabaseAuthStore(dbPath)
		if err != nil {
			log.Printf("Failed to initialize database auth store, falling back to file: %v", err)
		} else {
			hybrid.dbStore = dbStore
			hybrid.useDB = true
			log.Printf("Using database authentication store")
		}
	}

	// If database is available, migrate file-based keys to database
	if hybrid.useDB && config.Enabled && len(config.APIKeys) > 0 {
		if err := hybrid.migrateFileToDatabase(context.Background()); err != nil {
			log.Printf("Failed to migrate file keys to database: %v", err)
		}
	}

	return hybrid, nil
}
|
||||
|
||||
// ValidateAPIKey validates an API key using either database or file store
|
||||
func (h *HybridAuthStore) ValidateAPIKey(ctx context.Context, key string) (*User, error) {
|
||||
h.mu.RLock()
|
||||
useDB := h.useDB
|
||||
h.mu.RUnlock()
|
||||
|
||||
if useDB {
|
||||
user, err := h.dbStore.ValidateAPIKey(ctx, key)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// If database fails, fall back to file store
|
||||
log.Printf("Database auth failed, falling back to file store: %v", err)
|
||||
return h.fileStore.ValidateAPIKey(key)
|
||||
}
|
||||
|
||||
// Use file store
|
||||
return h.fileStore.ValidateAPIKey(key)
|
||||
}
|
||||
|
||||
// CreateAPIKey creates an API key using the preferred store.
//
// With the database active, a database failure falls back to the file store.
// NOTE(review): the file-store path drops expiresAt entirely —
// createFileAPIKey does not accept it — so fallback-created keys never
// expire; confirm that is acceptable.
func (h *HybridAuthStore) CreateAPIKey(ctx context.Context, userID string, keyHash string, admin bool, roles []string, permissions map[string]bool, expiresAt *time.Time) error {
	h.mu.RLock()
	useDB := h.useDB
	h.mu.RUnlock()

	if useDB {
		err := h.dbStore.CreateAPIKey(ctx, userID, keyHash, admin, roles, permissions, expiresAt)
		if err == nil {
			return nil
		}

		// If database fails, fall back to file store
		log.Printf("Database key creation failed, using file store: %v", err)
		return h.createFileAPIKey(userID, keyHash, admin, roles, permissions)
	}

	// Use file store
	return h.createFileAPIKey(userID, keyHash, admin, roles, permissions)
}

// createFileAPIKey creates an API key in the file store.
// It lazily initializes the APIKeys map and overwrites any existing
// entry for the same user. Always returns nil.
func (h *HybridAuthStore) createFileAPIKey(userID string, keyHash string, admin bool, roles []string, permissions map[string]bool) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.fileStore.APIKeys == nil {
		h.fileStore.APIKeys = make(map[Username]APIKeyEntry)
	}

	h.fileStore.APIKeys[Username(userID)] = APIKeyEntry{
		Hash:        APIKeyHash(keyHash),
		Admin:       admin,
		Roles:       roles,
		Permissions: permissions,
	}

	return nil
}
|
||||
|
||||
// RevokeAPIKey revokes an API key.
//
// When the database is active and revocation succeeds there, the file store
// is left untouched. On database failure (or file-only mode) the entry is
// deleted from the file store instead. Always returns nil from the
// file-store path; deleting a missing key is a no-op.
func (h *HybridAuthStore) RevokeAPIKey(ctx context.Context, userID string) error {
	h.mu.RLock()
	useDB := h.useDB
	h.mu.RUnlock()

	if useDB {
		err := h.dbStore.RevokeAPIKey(ctx, userID)
		if err == nil {
			return nil
		}

		// Database failed — log and fall through to the file store.
		log.Printf("Database key revocation failed: %v", err)
	}

	// Remove from file store
	h.mu.Lock()
	delete(h.fileStore.APIKeys, Username(userID))
	h.mu.Unlock()

	return nil
}
|
||||
|
||||
// ListUsers returns all users from the active store.
// Database records are mapped into UserInfo; on database error the call
// logs and falls back to the file store listing.
func (h *HybridAuthStore) ListUsers(ctx context.Context) ([]UserInfo, error) {
	h.mu.RLock()
	useDB := h.useDB
	h.mu.RUnlock()

	if useDB {
		records, err := h.dbStore.ListUsers(ctx)
		if err == nil {
			users := make([]UserInfo, len(records))
			for i, record := range records {
				users[i] = UserInfo{
					UserID:  record.UserID,
					Admin:   record.Admin,
					KeyHash: record.KeyHash,
					Created: record.CreatedAt,
					Expires: record.ExpiresAt,
					Revoked: record.RevokedAt,
				}
			}
			return users, nil
		}

		log.Printf("Database user listing failed: %v", err)
	}

	// Use file store
	return h.listFileUsers()
}

// UserInfo represents user information for listing.
type UserInfo struct {
	UserID  string     `json:"user_id"`
	Admin   bool       `json:"admin"`
	KeyHash string     `json:"key_hash"` // hashed key material, never the raw key
	Created time.Time  `json:"created"`
	Expires *time.Time `json:"expires,omitempty"` // nil means no expiry recorded
	Revoked *time.Time `json:"revoked,omitempty"` // nil means not revoked
}
|
||||
|
||||
// listFileUsers returns users from file store.
//
// NOTE(review): Created is stamped with time.Now() because the file store
// records no creation timestamp — callers may misread this as the real
// creation time. Expires/Revoked are always nil from this path.
func (h *HybridAuthStore) listFileUsers() ([]UserInfo, error) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	var users []UserInfo
	for username, entry := range h.fileStore.APIKeys {
		users = append(users, UserInfo{
			UserID:  string(username),
			Admin:   entry.Admin,
			KeyHash: string(entry.Hash),
			Created: time.Now(), // File store doesn't track creation time
		})
	}

	return users, nil
}
|
||||
|
||||
// migrateFileToDatabase migrates file-based keys to database.
//
// Migration is best-effort: per-key failures are logged and skipped, and
// the function always returns nil. Migrated keys are created with a nil
// expiry (the file store has none). The file entries are NOT removed —
// the log message asks the operator to do that manually.
// Note: callers are responsible for locking; NewHybridAuthStore calls this
// before the store is shared, and SwitchToDatabase calls it under h.mu.
func (h *HybridAuthStore) migrateFileToDatabase(ctx context.Context) error {
	if h.fileStore == nil || len(h.fileStore.APIKeys) == 0 {
		return nil
	}

	log.Printf("Migrating %d API keys from file to database...", len(h.fileStore.APIKeys))

	for username, entry := range h.fileStore.APIKeys {
		userID := string(username)
		err := h.dbStore.CreateAPIKey(ctx, userID, string(entry.Hash), entry.Admin, entry.Roles, entry.Permissions, nil)
		if err != nil {
			log.Printf("Failed to migrate key for user %s: %v", userID, err)
			continue
		}
		log.Printf("Migrated key for user: %s", userID)
	}

	log.Printf("Migration completed. Consider removing keys from config file.")
	return nil
}
|
||||
|
||||
// SwitchToDatabase attempts to switch to database authentication.
//
// Opens a new database store at dbPath, closes any previously open one
// (its Close error is discarded), marks the database active, and migrates
// existing file-based keys (migration problems are only logged).
func (h *HybridAuthStore) SwitchToDatabase(dbPath string) error {
	dbStore, err := NewDatabaseAuthStore(dbPath)
	if err != nil {
		return fmt.Errorf("failed to create database store: %w", err)
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	// Close existing database if any
	if h.dbStore != nil {
		h.dbStore.Close()
	}

	h.dbStore = dbStore
	h.useDB = true

	// Migrate existing keys
	if h.fileStore.Enabled && len(h.fileStore.APIKeys) > 0 {
		if err := h.migrateFileToDatabase(context.Background()); err != nil {
			log.Printf("Migration warning: %v", err)
		}
	}

	return nil
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (h *HybridAuthStore) Close() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.dbStore != nil {
|
||||
return h.dbStore.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDatabaseStats returns database statistics.
//
// In file mode: {"store_type": "file", "users": N}.
// In database mode: {"store_type": "database", "users": N, "path": ...}.
// NOTE(review): the reported "path" is hardcoded to "db/fetch_ml.db" and
// may not match the dbPath actually passed to the constructor — verify.
func (h *HybridAuthStore) GetDatabaseStats(ctx context.Context) (map[string]interface{}, error) {
	h.mu.RLock()
	useDB := h.useDB
	h.mu.RUnlock()

	if !useDB {
		return map[string]interface{}{
			"store_type": "file",
			"users":      len(h.fileStore.APIKeys),
		}, nil
	}

	// User count requires a full listing; there is no cheaper count API here.
	users, err := h.dbStore.ListUsers(ctx)
	if err != nil {
		return nil, err
	}

	return map[string]interface{}{
		"store_type": "database",
		"users":      len(users),
		"path":       "db/fetch_ml.db",
	}, nil
}
|
||||
167
internal/auth/keychain.go
Normal file
167
internal/auth/keychain.go
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zalando/go-keyring"
|
||||
)
|
||||
|
||||
// KeychainManager provides secure storage for API keys.
// It prefers the OS keyring and falls back to a permission-restricted
// file store when the keyring is unavailable.
type KeychainManager struct {
	primary  systemKeyring // OS keyring backend (go-keyring in production)
	fallback *fileKeyStore // 0600-permission file store used when the keyring fails
}

// systemKeyring abstracts go-keyring for easier testing.
type systemKeyring interface {
	Set(service, account, secret string) error
	Get(service, account string) (string, error)
	Delete(service, account string) error
}

// goKeyring is the production systemKeyring implementation; each method
// delegates directly to the zalando/go-keyring package.
type goKeyring struct{}

func (goKeyring) Set(service, account, secret string) error {
	return keyring.Set(service, account, secret)
}

func (goKeyring) Get(service, account string) (string, error) {
	return keyring.Get(service, account)
}

func (goKeyring) Delete(service, account string) error {
	return keyring.Delete(service, account)
}

// NewKeychainManager returns a manager backed by the OS keyring with a secure file fallback.
func NewKeychainManager() *KeychainManager {
	return newKeychainManagerWithKeyring(goKeyring{}, defaultFallbackDir())
}

// newKeychainManagerWithKeyring wires an arbitrary keyring implementation
// and fallback directory together; tests use this to inject fakes.
// An empty fallbackDir selects the default location.
func newKeychainManagerWithKeyring(kr systemKeyring, fallbackDir string) *KeychainManager {
	if fallbackDir == "" {
		fallbackDir = defaultFallbackDir()
	}
	return &KeychainManager{
		primary:  kr,
		fallback: newFileKeyStore(fallbackDir),
	}
}

// defaultFallbackDir returns ~/.fetch_ml/keys, or a temp-dir location when
// the home directory cannot be determined.
func defaultFallbackDir() string {
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return filepath.Join(os.TempDir(), "fetch_ml", "keys")
	}
	return filepath.Join(home, ".fetch_ml", "keys")
}
|
||||
|
||||
// StoreAPIKey stores the key in the OS keyring, falling back to a protected file when needed.
|
||||
func (km *KeychainManager) StoreAPIKey(service, account, key string) error {
|
||||
if err := km.primary.Set(service, account, key); err == nil {
|
||||
return nil
|
||||
} else if errors.Is(err, keyring.ErrUnsupportedPlatform) || errors.Is(err, keyring.ErrNotFound) {
|
||||
return km.fallback.store(service, account, key)
|
||||
} else if fallbackErr := km.fallback.store(service, account, key); fallbackErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to store API key")
|
||||
}
|
||||
|
||||
// GetAPIKey retrieves a key from the OS keyring or fallback store.
|
||||
func (km *KeychainManager) GetAPIKey(service, account string) (string, error) {
|
||||
secret, err := km.primary.Get(service, account)
|
||||
if err == nil {
|
||||
return secret, nil
|
||||
}
|
||||
if errors.Is(err, keyring.ErrUnsupportedPlatform) || errors.Is(err, keyring.ErrNotFound) {
|
||||
return km.fallback.get(service, account)
|
||||
}
|
||||
// Unknown error - try fallback before surfacing
|
||||
if fallbackSecret, ferr := km.fallback.get(service, account); ferr == nil {
|
||||
return fallbackSecret, nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to retrieve API key")
|
||||
}
|
||||
|
||||
// DeleteAPIKey removes a key from both stores.
// Missing entries and an unsupported keyring are not treated as errors;
// both stores are always attempted.
func (km *KeychainManager) DeleteAPIKey(service, account string) error {
	if err := km.primary.Delete(service, account); err != nil && !errors.Is(err, keyring.ErrNotFound) && !errors.Is(err, keyring.ErrUnsupportedPlatform) {
		return fmt.Errorf("failed to delete API key: %w", err)
	}
	if err := km.fallback.delete(service, account); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return nil
}

// IsAvailable reports whether the OS keyring backend is usable.
// It probes with a Get on a unique, never-stored account name; ErrNotFound
// therefore means "backend works" — only ErrUnsupportedPlatform (or a nil
// error, which cannot happen for a fresh name) marks it unusable.
func (km *KeychainManager) IsAvailable() bool {
	_, err := km.primary.Get("fetch_ml_probe", fmt.Sprintf("probe_%d", time.Now().UnixNano()))
	return err == nil || !errors.Is(err, keyring.ErrUnsupportedPlatform)
}

// ListAvailableMethods returns backends the manager can use.
// The file fallback is always listed; the OS keyring only when available.
// NOTE(review): the fallback is labelled "Encrypted file", but fileKeyStore
// writes the secret in plaintext with 0600 permissions — the label overstates
// the protection; confirm or fix the wording.
func (km *KeychainManager) ListAvailableMethods() []string {
	methods := []string{}
	if km.IsAvailable() {
		methods = append(methods, "OS keyring")
	}
	methods = append(methods, fmt.Sprintf("Encrypted file (%s)", km.fallback.baseDir))
	return methods
}
|
||||
|
||||
// fileKeyStore stores secrets with 0600 permissions as a fallback.
// Secrets are written as plaintext files, one per (service, account) pair;
// the directory itself is created with 0700. All operations are serialized
// by mu.
type fileKeyStore struct {
	baseDir string     // directory holding one .key file per secret
	mu      sync.Mutex // serializes store/get/delete
}

// newFileKeyStore returns a store rooted at baseDir; the directory is
// created lazily on first store.
func newFileKeyStore(baseDir string) *fileKeyStore {
	return &fileKeyStore{baseDir: baseDir}
}

// store writes the secret to its per-key file (0600), creating the base
// directory (0700) if needed. An existing file is overwritten.
func (f *fileKeyStore) store(service, account, secret string) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	if err := os.MkdirAll(f.baseDir, 0o700); err != nil {
		return fmt.Errorf("failed to prepare key store: %w", err)
	}
	path := f.path(service, account)
	return os.WriteFile(path, []byte(secret), 0o600)
}

// get reads the secret back; a missing file surfaces as os.ErrNotExist
// (wrapped in *os.PathError), which callers test with errors.Is.
func (f *fileKeyStore) get(service, account string) (string, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	data, err := os.ReadFile(f.path(service, account))
	if err != nil {
		return "", err
	}
	return string(data), nil
}

// delete removes the secret file; a missing file is not an error.
func (f *fileKeyStore) delete(service, account string) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	path := f.path(service, account)
	if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return nil
}

// path maps a (service, account) pair to its sanitized file name under baseDir.
func (f *fileKeyStore) path(service, account string) string {
	return filepath.Join(f.baseDir, fmt.Sprintf("%s_%s.key", sanitize(service), sanitize(account)))
}
|
||||
|
||||
// sanitizeReplacer maps path-hostile characters (separators, parent-dir
// sequences, whitespace) to underscores. Hoisted to package level so
// sanitize does not rebuild the replacer on every call.
var sanitizeReplacer = strings.NewReplacer("/", "_", "\\", "_", "..", "_", " ", "_", "\t", "_")

// sanitize makes a service/account value safe for use as a file-name
// component by replacing path separators, "..", spaces, and tabs with "_".
func sanitize(value string) string {
	return sanitizeReplacer.Replace(value)
}
|
||||
129
internal/auth/keychain_test.go
Normal file
129
internal/auth/keychain_test.go
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/zalando/go-keyring"
|
||||
)
|
||||
|
||||
// fakeKeyring is an in-memory systemKeyring test double. Each operation
// can be forced to fail by setting the corresponding *Err field.
type fakeKeyring struct {
	secrets   map[string]string // keyed by "service:account"
	setErr    error             // returned by Set when non-nil
	getErr    error             // returned by Get when non-nil
	deleteErr error             // returned by Delete when non-nil
}

// newFakeKeyring returns an empty fake with no forced errors.
func newFakeKeyring() *fakeKeyring {
	return &fakeKeyring{secrets: make(map[string]string)}
}

func (f *fakeKeyring) Set(service, account, secret string) error {
	if f.setErr != nil {
		return f.setErr
	}
	f.secrets[key(service, account)] = secret
	return nil
}

// Get mirrors go-keyring: missing entries yield keyring.ErrNotFound.
func (f *fakeKeyring) Get(service, account string) (string, error) {
	if f.getErr != nil {
		return "", f.getErr
	}
	if secret, ok := f.secrets[key(service, account)]; ok {
		return secret, nil
	}
	return "", keyring.ErrNotFound
}

func (f *fakeKeyring) Delete(service, account string) error {
	if f.deleteErr != nil {
		return f.deleteErr
	}
	delete(f.secrets, key(service, account))
	return nil
}

// key builds the map key used by fakeKeyring.
func key(service, account string) string {
	return service + ":" + account
}

// newTestManager builds a KeychainManager over the given keyring with a
// per-test temp directory as the file fallback.
func newTestManager(t *testing.T, kr systemKeyring) (*KeychainManager, string) {
	t.Helper()
	baseDir := t.TempDir()
	return newKeychainManagerWithKeyring(kr, baseDir), baseDir
}
|
||||
|
||||
// TestKeychainStoreAndGetPrimary verifies the happy path: a working
// keyring stores and returns the secret, and no fallback file is written.
func TestKeychainStoreAndGetPrimary(t *testing.T) {
	kr := newFakeKeyring()
	km, baseDir := newTestManager(t, kr)

	if err := km.StoreAPIKey("fetch-ml", "alice", "super-secret"); err != nil {
		t.Fatalf("StoreAPIKey failed: %v", err)
	}

	got, err := km.GetAPIKey("fetch-ml", "alice")
	if err != nil {
		t.Fatalf("GetAPIKey failed: %v", err)
	}
	if got != "super-secret" {
		t.Fatalf("expected secret to be stored in primary keyring")
	}

	// Ensure fallback file was not created when primary succeeds
	path := filepath.Join(baseDir, filepath.Base(km.fallback.path("fetch-ml", "alice")))
	if _, err := os.Stat(path); !errors.Is(err, os.ErrNotExist) {
		t.Fatalf("expected no fallback file, got err=%v", err)
	}
}
|
||||
|
||||
// TestKeychainFallbackWhenUnsupported verifies that an unsupported keyring
// platform routes both store and retrieval through the file fallback.
func TestKeychainFallbackWhenUnsupported(t *testing.T) {
	kr := newFakeKeyring()
	kr.setErr = keyring.ErrUnsupportedPlatform
	kr.getErr = keyring.ErrUnsupportedPlatform
	kr.deleteErr = keyring.ErrUnsupportedPlatform
	km, _ := newTestManager(t, kr)

	if err := km.StoreAPIKey("fetch-ml", "bob", "fallback-secret"); err != nil {
		t.Fatalf("StoreAPIKey should fallback: %v", err)
	}

	got, err := km.GetAPIKey("fetch-ml", "bob")
	if err != nil {
		t.Fatalf("GetAPIKey should use fallback: %v", err)
	}
	if got != "fallback-secret" {
		t.Fatalf("expected fallback secret, got %s", got)
	}
}

// TestKeychainDeleteRemovesFallback verifies DeleteAPIKey removes a secret
// that lives only in the file fallback, and tolerates keyring ErrNotFound.
func TestKeychainDeleteRemovesFallback(t *testing.T) {
	kr := newFakeKeyring()
	kr.deleteErr = keyring.ErrNotFound
	km, _ := newTestManager(t, kr)

	// Seed the fallback store directly, bypassing the keyring.
	if err := km.fallback.store("fetch-ml", "carol", "temp"); err != nil {
		t.Fatalf("failed to seed fallback store: %v", err)
	}

	if err := km.DeleteAPIKey("fetch-ml", "carol"); err != nil {
		t.Fatalf("DeleteAPIKey failed: %v", err)
	}

	if _, err := km.fallback.get("fetch-ml", "carol"); !errors.Is(err, os.ErrNotExist) {
		t.Fatalf("expected fallback secret removed, err=%v", err)
	}
}
|
||||
|
||||
func TestListAvailableMethodsIncludesFallback(t *testing.T) {
|
||||
kr := newFakeKeyring()
|
||||
kr.getErr = keyring.ErrUnsupportedPlatform
|
||||
km, _ := newTestManager(t, kr)
|
||||
|
||||
methods := km.ListAvailableMethods()
|
||||
if len(methods) != 1 || methods[0] == "OS keyring" {
|
||||
t.Fatalf("expected only fallback method, got %v", methods)
|
||||
}
|
||||
}
|
||||
192
internal/auth/permissions.go
Normal file
192
internal/auth/permissions.go
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Permission constants for type safety.
// Permissions follow a "resource:action" scheme, plus a wildcard.
const (
	// Job permissions
	PermissionJobsCreate = "jobs:create"
	PermissionJobsRead   = "jobs:read"
	PermissionJobsUpdate = "jobs:update"
	PermissionJobsDelete = "jobs:delete"

	// Data permissions
	PermissionDataCreate = "data:create"
	PermissionDataRead   = "data:read"
	PermissionDataUpdate = "data:update"
	PermissionDataDelete = "data:delete"

	// Model permissions
	PermissionModelsCreate = "models:create"
	PermissionModelsRead   = "models:read"
	PermissionModelsUpdate = "models:update"
	PermissionModelsDelete = "models:delete"

	// System permissions
	PermissionSystemConfig  = "system:config"
	PermissionSystemMetrics = "system:metrics"
	PermissionSystemLogs    = "system:logs"
	PermissionSystemUsers   = "system:users"

	// Wildcard permission: grants everything.
	PermissionAll = "*"
)

// Role constants: the set of roles ValidateRole accepts.
const (
	RoleAdmin         = "admin"
	RoleDataScientist = "data_scientist"
	RoleDataEngineer  = "data_engineer"
	RoleViewer        = "viewer"
	RoleOperator      = "operator"
)

// PermissionGroup represents a group of related permissions.
type PermissionGroup struct {
	Name        string
	Permissions []string
	Description string
}

// PermissionGroups is the built-in set of permission groups, keyed by the
// group identifier used in ExpandPermissionGroups.
var PermissionGroups = map[string]PermissionGroup{
	"full_access": {
		Name:        "Full Access",
		Permissions: []string{PermissionAll},
		Description: "Complete system access",
	},
	"job_management": {
		Name:        "Job Management",
		Permissions: []string{PermissionJobsCreate, PermissionJobsRead, PermissionJobsUpdate, PermissionJobsDelete},
		Description: "Create, read, update, and delete ML jobs",
	},
	"data_access": {
		Name:        "Data Access",
		Permissions: []string{PermissionDataRead, PermissionDataCreate, PermissionDataUpdate, PermissionDataDelete},
		Description: "Access and manage datasets",
	},
	"readonly": {
		Name:        "Read Only",
		Permissions: []string{PermissionJobsRead, PermissionDataRead, PermissionModelsRead, PermissionSystemMetrics},
		Description: "View-only access to system resources",
	},
	"system_admin": {
		Name:        "System Administration",
		Permissions: []string{PermissionSystemConfig, PermissionSystemLogs, PermissionSystemUsers, PermissionSystemMetrics},
		Description: "System configuration and user management",
	},
}

// GetPermissionGroup returns a permission group by name, and whether it exists.
func GetPermissionGroup(name string) (PermissionGroup, bool) {
	group, exists := PermissionGroups[name]
	return group, exists
}
|
||||
|
||||
// ValidatePermission checks if a permission string is valid
|
||||
func ValidatePermission(permission string) error {
|
||||
if permission == PermissionAll {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if permission matches known patterns
|
||||
validPrefixes := []string{"jobs:", "data:", "models:", "system:"}
|
||||
for _, prefix := range validPrefixes {
|
||||
if strings.HasPrefix(permission, prefix) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid permission format: %s", permission)
|
||||
}
|
||||
|
||||
// ValidateRole checks if a role is valid
|
||||
func ValidateRole(role string) error {
|
||||
validRoles := []string{RoleAdmin, RoleDataScientist, RoleDataEngineer, RoleViewer, RoleOperator}
|
||||
for _, validRole := range validRoles {
|
||||
if role == validRole {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid role: %s", role)
|
||||
}
|
||||
|
||||
// ExpandPermissionGroups converts permission group names to actual permissions.
//
// A literal "*" anywhere in groups short-circuits to just the wildcard.
// Unknown group names are an error. Duplicates across groups are removed;
// the returned order is nondeterministic (it comes from map iteration).
func ExpandPermissionGroups(groups []string) ([]string, error) {
	var permissions []string

	for _, groupName := range groups {
		if groupName == PermissionAll {
			return []string{PermissionAll}, nil
		}

		group, exists := GetPermissionGroup(groupName)
		if !exists {
			return nil, fmt.Errorf("unknown permission group: %s", groupName)
		}

		permissions = append(permissions, group.Permissions...)
	}

	// Remove duplicates
	unique := make(map[string]bool)
	for _, perm := range permissions {
		unique[perm] = true
	}

	result := make([]string, 0, len(unique))
	for perm := range unique {
		result = append(result, perm)
	}

	return result, nil
}
|
||||
|
||||
// PermissionCheckResult represents the result of a permission check.
type PermissionCheckResult struct {
	Allowed    bool     `json:"allowed"`
	Permission string   `json:"permission"`
	User       string   `json:"user"`
	Roles      []string `json:"roles"`
	Missing    []string `json:"missing,omitempty"` // the denied permission, when Allowed is false
}

// CheckMultiplePermissions checks multiple permissions at once via
// User.HasPermission, producing one result per requested permission in the
// same order. Missing contains only the single checked permission on denial.
func (u *User) CheckMultiplePermissions(permissions []string) []PermissionCheckResult {
	results := make([]PermissionCheckResult, len(permissions))

	for i, permission := range permissions {
		allowed := u.HasPermission(permission)
		missing := []string{}
		if !allowed {
			missing = []string{permission}
		}

		results[i] = PermissionCheckResult{
			Allowed:    allowed,
			Permission: permission,
			User:       u.Name,
			Roles:      u.Roles,
			Missing:    missing,
		}
	}

	return results
}

// GetEffectivePermissions returns all effective permissions for a user.
// A wildcard grant collapses the answer to just ["*"]; otherwise the keys
// of the user's permission map are returned (order nondeterministic, and
// entries are included regardless of their boolean value).
func (u *User) GetEffectivePermissions() []string {
	if u.Permissions[PermissionAll] {
		return []string{PermissionAll}
	}

	permissions := make([]string, 0, len(u.Permissions))
	for perm := range u.Permissions {
		permissions = append(permissions, perm)
	}

	return permissions
}
|
||||
295
internal/auth/permissions_loader.go
Normal file
295
internal/auth/permissions_loader.go
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// PermissionConfig represents the permissions configuration as loaded
// from the YAML file.
type PermissionConfig struct {
	Roles     map[string]RoleConfig      `yaml:"roles"`
	Groups    map[string]GroupConfig     `yaml:"groups"`
	Hierarchy map[string]HierarchyConfig `yaml:"hierarchy"`
	Defaults  DefaultsConfig             `yaml:"defaults"`
}

// RoleConfig defines a role and its permissions.
type RoleConfig struct {
	Description string   `yaml:"description"`
	Permissions []string `yaml:"permissions"`
}

// GroupConfig defines a permission group. Inherits may name roles or
// other groups whose permissions are folded in.
type GroupConfig struct {
	Description string   `yaml:"description"`
	Inherits    []string `yaml:"inherits"`
	Permissions []string `yaml:"permissions"`
}

// HierarchyConfig defines resource hierarchy.
type HierarchyConfig struct {
	Children map[string]interface{} `yaml:"children"`
	Special  map[string]string      `yaml:"special"`
}

// DefaultsConfig defines default settings.
type DefaultsConfig struct {
	NewUserRole string   `yaml:"new_user_role"` // role assigned to new users
	AdminUsers  []string `yaml:"admin_users"`   // usernames granted admin
}

// PermissionManager manages permissions from YAML file.
// All fields are guarded by mu; loaded is false until a config has been
// successfully parsed, and every accessor returns empty results before then.
type PermissionManager struct {
	config     *PermissionConfig
	rolePerms  map[string]map[string]bool // role name -> permission set
	groupPerms map[string]map[string]bool // group name -> resolved permission set
	mu         sync.RWMutex
	loaded     bool
}

// NewPermissionManager creates a new permission manager, loading and
// resolving the YAML config at configPath; a load failure is returned
// rather than producing a half-initialized manager.
func NewPermissionManager(configPath string) (*PermissionManager, error) {
	pm := &PermissionManager{}

	if err := pm.loadConfig(configPath); err != nil {
		return nil, fmt.Errorf("failed to load permissions: %w", err)
	}

	return pm, nil
}
|
||||
|
||||
// loadConfig loads permissions from YAML file
|
||||
func (pm *PermissionManager) loadConfig(configPath string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read permissions file: %w", err)
|
||||
}
|
||||
|
||||
var config PermissionConfig
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse permissions file: %w", err)
|
||||
}
|
||||
|
||||
pm.config = &config
|
||||
pm.rolePerms = make(map[string]map[string]bool)
|
||||
pm.groupPerms = make(map[string]map[string]bool)
|
||||
|
||||
// Process role permissions
|
||||
for roleName, role := range config.Roles {
|
||||
perms := make(map[string]bool)
|
||||
for _, perm := range role.Permissions {
|
||||
perms[perm] = true
|
||||
}
|
||||
pm.rolePerms[roleName] = perms
|
||||
}
|
||||
|
||||
// Process group permissions
|
||||
for groupName, group := range config.Groups {
|
||||
perms := make(map[string]bool)
|
||||
|
||||
// Add direct permissions
|
||||
for _, perm := range group.Permissions {
|
||||
perms[perm] = true
|
||||
}
|
||||
|
||||
// Inherit permissions from other roles/groups
|
||||
for _, inherit := range group.Inherits {
|
||||
if rolePerms, exists := pm.rolePerms[inherit]; exists {
|
||||
for perm, value := range rolePerms {
|
||||
perms[perm] = value
|
||||
}
|
||||
}
|
||||
if groupPerms, exists := pm.groupPerms[inherit]; exists {
|
||||
for perm, value := range groupPerms {
|
||||
perms[perm] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pm.groupPerms[groupName] = perms
|
||||
}
|
||||
|
||||
pm.loaded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRolePermissions returns permissions for a role.
// The result is a defensive copy; unknown roles (or an unloaded manager)
// yield an empty, non-nil map.
func (pm *PermissionManager) GetRolePermissions(role string) map[string]bool {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return make(map[string]bool)
	}

	if perms, exists := pm.rolePerms[role]; exists {
		// Copy so callers cannot mutate internal state.
		result := make(map[string]bool)
		for perm, value := range perms {
			result[perm] = value
		}
		return result
	}

	return make(map[string]bool)
}

// GetGroupPermissions returns permissions for a group.
// Same contract as GetRolePermissions: defensive copy, empty map for
// unknown groups or an unloaded manager.
func (pm *PermissionManager) GetGroupPermissions(group string) map[string]bool {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return make(map[string]bool)
	}

	if perms, exists := pm.groupPerms[group]; exists {
		// Copy so callers cannot mutate internal state.
		result := make(map[string]bool)
		for perm, value := range perms {
			result[perm] = value
		}
		return result
	}

	return make(map[string]bool)
}
|
||||
|
||||
// GetAllRoles returns all available roles as a shallow copy of the config
// map (RoleConfig values still share their Permissions slices). An
// unloaded manager yields an empty map.
func (pm *PermissionManager) GetAllRoles() map[string]RoleConfig {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return make(map[string]RoleConfig)
	}

	result := make(map[string]RoleConfig)
	for name, role := range pm.config.Roles {
		result[name] = role
	}
	return result
}

// GetAllGroups returns all available groups as a shallow copy of the
// config map. An unloaded manager yields an empty map.
func (pm *PermissionManager) GetAllGroups() map[string]GroupConfig {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return make(map[string]GroupConfig)
	}

	result := make(map[string]GroupConfig)
	for name, group := range pm.config.Groups {
		result[name] = group
	}
	return result
}

// GetDefaultRole returns the default role for new users, falling back to
// "viewer" when no config is loaded or none is configured.
func (pm *PermissionManager) GetDefaultRole() string {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded || pm.config.Defaults.NewUserRole == "" {
		return "viewer"
	}

	return pm.config.Defaults.NewUserRole
}
|
||||
|
||||
// IsAdminUser checks if a username should have admin rights by scanning
// the configured defaults.admin_users list (exact match, case-sensitive).
// Always false for an unloaded manager.
func (pm *PermissionManager) IsAdminUser(username string) bool {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return false
	}

	for _, adminUser := range pm.config.Defaults.AdminUsers {
		if adminUser == username {
			return true
		}
	}
	return false
}

// ReloadConfig reloads the permissions configuration from configPath.
// Note: the path need not be the one used at construction time.
func (pm *PermissionManager) ReloadConfig(configPath string) error {
	return pm.loadConfig(configPath)
}
|
||||
|
||||
// ValidatePermission checks if a permission string is valid: it is the
// wildcard, or it appears in at least one loaded role or group. Always
// false when no config is loaded.
func (pm *PermissionManager) ValidatePermission(permission string) bool {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return false
	}

	// Wildcard is always valid
	if permission == "*" {
		return true
	}

	// Check if permission matches any defined role permissions
	for _, rolePerms := range pm.rolePerms {
		if _, exists := rolePerms[permission]; exists {
			return true
		}
	}

	// Check if permission matches any group permissions
	for _, groupPerms := range pm.groupPerms {
		if _, exists := groupPerms[permission]; exists {
			return true
		}
	}

	return false
}

// GetPermissionHierarchy returns the hierarchy children for a resource,
// or an empty map for unknown resources / an unloaded manager.
// NOTE: the returned map is the internal config map, not a copy —
// callers must not mutate it.
func (pm *PermissionManager) GetPermissionHierarchy(resource string) map[string]interface{} {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	if !pm.loaded {
		return make(map[string]interface{})
	}

	if hierarchy, exists := pm.config.Hierarchy[resource]; exists {
		return hierarchy.Children
	}

	return make(map[string]interface{})
}
|
||||
|
||||
// Global permission manager instance, created lazily on first use.
var globalPermissionManager *PermissionManager

// permissionManagerOnce guards the one-time initialization of
// globalPermissionManager.
var permissionManagerOnce sync.Once

// GetGlobalPermissionManager returns the process-wide permission manager,
// loading it from the default config location on first call. If loading
// fails, an empty manager (loaded=false, so lookups deny/return empty) is
// installed instead, ensuring callers never receive nil.
func GetGlobalPermissionManager() *PermissionManager {
	permissionManagerOnce.Do(func() {
		// Try to load from default location
		if pm, err := NewPermissionManager("configs/schema/permissions.yaml"); err == nil {
			globalPermissionManager = pm
		} else {
			// Fallback to empty manager
			globalPermissionManager = &PermissionManager{
				rolePerms:  make(map[string]map[string]bool),
				groupPerms: make(map[string]map[string]bool),
				loaded:     false,
			}
		}
	})
	return globalPermissionManager
}
|
||||
100
internal/auth/validator.go
Normal file
100
internal/auth/validator.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateAuthConfig enforces authentication requirements
|
||||
func (c *AuthConfig) ValidateAuthConfig() error {
|
||||
// Check if we're in production environment
|
||||
isProduction := os.Getenv("FETCH_ML_ENV") == "production"
|
||||
|
||||
if isProduction {
|
||||
if !c.Enabled {
|
||||
return fmt.Errorf("authentication must be enabled in production environment")
|
||||
}
|
||||
|
||||
if len(c.APIKeys) == 0 {
|
||||
return fmt.Errorf("at least one API key must be configured in production")
|
||||
}
|
||||
|
||||
// Ensure at least one admin user exists
|
||||
hasAdmin := false
|
||||
for _, entry := range c.APIKeys {
|
||||
if entry.Admin {
|
||||
hasAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAdmin {
|
||||
return fmt.Errorf("at least one admin user must be configured in production")
|
||||
}
|
||||
|
||||
// Check for insecure development override
|
||||
if os.Getenv("FETCH_ML_ALLOW_INSECURE_AUTH") == "1" {
|
||||
log.Printf("WARNING: FETCH_ML_ALLOW_INSECURE_AUTH is enabled in production - this is insecure")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate API key format
|
||||
for username, entry := range c.APIKeys {
|
||||
if string(username) == "" {
|
||||
return fmt.Errorf("empty username not allowed")
|
||||
}
|
||||
|
||||
if entry.Hash == "" {
|
||||
return fmt.Errorf("user %s has empty API key hash", username)
|
||||
}
|
||||
|
||||
// Validate hash format (should be 64 hex chars for SHA256)
|
||||
if len(entry.Hash) != 64 {
|
||||
return fmt.Errorf("user %s has invalid API key hash format", username)
|
||||
}
|
||||
|
||||
// Check hash contains only hex characters
|
||||
for _, char := range entry.Hash {
|
||||
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) {
|
||||
return fmt.Errorf("user %s has invalid API key hash characters", username)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckConfigFilePermissions ensures config files have secure permissions
|
||||
func CheckConfigFilePermissions(configPath string) error {
|
||||
info, err := os.Stat(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat config file: %w", err)
|
||||
}
|
||||
|
||||
// Check file permissions (should be 600 or 640)
|
||||
perm := info.Mode().Perm()
|
||||
if perm&0077 != 0 {
|
||||
return fmt.Errorf("config file %s has insecure permissions: %o (should be 600 or 640)", configPath, perm)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizeConfig removes sensitive information for logging
|
||||
func (c *AuthConfig) SanitizeConfig() map[string]interface{} {
|
||||
sanitized := map[string]interface{}{
|
||||
"enabled": c.Enabled,
|
||||
"users": make(map[string]interface{}),
|
||||
}
|
||||
|
||||
for username := range c.APIKeys {
|
||||
sanitized["users"].(map[string]interface{})[string(username)] = map[string]interface{}{
|
||||
"admin": c.APIKeys[username].Admin,
|
||||
"hash": strings.Repeat("*", 8) + "...", // Show only prefix
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
54
internal/config/constants.go
Normal file
54
internal/config/constants.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package config
|
||||
|
||||
// Default configuration values (legacy - use SmartDefaults for new code)
const (
	DefaultSSHPort         = 22
	DefaultRedisPort       = 6379
	DefaultRedisAddr       = "localhost:6379"
	DefaultBasePath        = "/mnt/nas/jobs"
	DefaultTrainScript     = "train.py"
	DefaultDataDir         = "/data/active"
	DefaultLocalDataDir    = "./data/active"
	DefaultNASDataDir      = "/mnt/datasets"
	DefaultMaxWorkers      = 2
	DefaultPollInterval    = 5 // seconds (matches SmartDefaults.PollInterval)
	DefaultMaxAgeHours     = 24
	DefaultMaxSizeGB       = 100
	DefaultCleanupInterval = 60 // NOTE(review): units not stated here — confirm (seconds vs minutes) at call sites
)

// Redis key prefixes used by the task queue and worker coordination layer.
const (
	RedisTaskQueueKey     = "ml:queue"
	RedisTaskPrefix       = "ml:task:"
	RedisJobMetricsPrefix = "ml:metrics:"
	RedisTaskStatusPrefix = "ml:status:"
	RedisDatasetPrefix    = "ml:dataset:"
	RedisWorkerHeartbeat  = "ml:workers:heartbeat"
)

// Task status constants (lifecycle of a queued task).
const (
	TaskStatusQueued    = "queued"
	TaskStatusRunning   = "running"
	TaskStatusCompleted = "completed"
	TaskStatusFailed    = "failed"
	TaskStatusCancelled = "cancelled"
)

// Job status constants (lifecycle of a job on disk).
const (
	JobStatusPending  = "pending"
	JobStatusQueued   = "queued"
	JobStatusRunning  = "running"
	JobStatusFinished = "finished"
	JobStatusFailed   = "failed"
)

// Podman defaults for container resource limits and mount points.
const (
	DefaultPodmanMemory       = "8g"
	DefaultPodmanCPUs         = "2"
	DefaultContainerWorkspace = "/workspace"
	DefaultContainerResults   = "/workspace/results"
)
|
||||
73
internal/config/paths.go
Normal file
73
internal/config/paths.go
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
// Package config provides shared utilities for the fetch_ml project.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExpandPath expands environment variables and tilde in a path
|
||||
func ExpandPath(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
expanded := os.ExpandEnv(path)
|
||||
if strings.HasPrefix(expanded, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
expanded = filepath.Join(home, expanded[1:])
|
||||
}
|
||||
}
|
||||
return expanded
|
||||
}
|
||||
|
||||
// ResolveConfigPath resolves a config file path, checking multiple locations
|
||||
func ResolveConfigPath(path string) (string, error) {
|
||||
candidates := []string{path}
|
||||
if !filepath.IsAbs(path) {
|
||||
candidates = append(candidates, filepath.Join("configs", path))
|
||||
}
|
||||
|
||||
var checked []string
|
||||
for _, candidate := range candidates {
|
||||
resolved := ExpandPath(candidate)
|
||||
checked = append(checked, resolved)
|
||||
if _, err := os.Stat(resolved); err == nil {
|
||||
return resolved, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("config file not found (looked in %s)", strings.Join(checked, ", "))
|
||||
}
|
||||
|
||||
// JobPaths provides helper methods for job directory paths
|
||||
type JobPaths struct {
|
||||
BasePath string
|
||||
}
|
||||
|
||||
// NewJobPaths creates a new JobPaths instance
|
||||
func NewJobPaths(basePath string) *JobPaths {
|
||||
return &JobPaths{BasePath: basePath}
|
||||
}
|
||||
|
||||
// PendingPath returns the path to pending jobs directory
|
||||
func (j *JobPaths) PendingPath() string {
|
||||
return filepath.Join(j.BasePath, "pending")
|
||||
}
|
||||
|
||||
// RunningPath returns the path to running jobs directory
|
||||
func (j *JobPaths) RunningPath() string {
|
||||
return filepath.Join(j.BasePath, "running")
|
||||
}
|
||||
|
||||
// FinishedPath returns the path to finished jobs directory
|
||||
func (j *JobPaths) FinishedPath() string {
|
||||
return filepath.Join(j.BasePath, "finished")
|
||||
}
|
||||
|
||||
// FailedPath returns the path to failed jobs directory
|
||||
func (j *JobPaths) FailedPath() string {
|
||||
return filepath.Join(j.BasePath, "failed")
|
||||
}
|
||||
222
internal/config/smart_defaults.go
Normal file
222
internal/config/smart_defaults.go
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EnvironmentProfile represents the deployment environment
type EnvironmentProfile int

const (
	ProfileLocal EnvironmentProfile = iota
	ProfileContainer
	ProfileCI
	ProfileProduction
)

// DetectEnvironment determines the current environment profile by probing
// well-known environment variables and filesystem markers.
// Precedence: CI > container > production > local.
func DetectEnvironment() EnvironmentProfile {
	// CI systems advertise themselves via environment variables.
	for _, v := range []string{"CI", "GITHUB_ACTIONS", "GITLAB_CI"} {
		if os.Getenv(v) != "" {
			return ProfileCI
		}
	}

	// Containers: Docker drops /.dockerenv; Kubernetes injects service
	// env vars; CONTAINER is an explicit opt-in marker.
	if _, err := os.Stat("/.dockerenv"); err == nil {
		return ProfileContainer
	}
	if os.Getenv("KUBERNETES_SERVICE_HOST") != "" || os.Getenv("CONTAINER") != "" {
		return ProfileContainer
	}

	// Production is declared explicitly via FETCH_ML_ENV or ENV.
	if os.Getenv("FETCH_ML_ENV") == "production" || os.Getenv("ENV") == "production" {
		return ProfileProduction
	}

	// Default to local development.
	return ProfileLocal
}
|
||||
|
||||
// SmartDefaults provides environment-aware default values
|
||||
type SmartDefaults struct {
|
||||
Profile EnvironmentProfile
|
||||
}
|
||||
|
||||
// GetSmartDefaults returns defaults for the current environment
|
||||
func GetSmartDefaults() *SmartDefaults {
|
||||
return &SmartDefaults{
|
||||
Profile: DetectEnvironment(),
|
||||
}
|
||||
}
|
||||
|
||||
// Host returns the appropriate default host
|
||||
func (s *SmartDefaults) Host() string {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "host.docker.internal" // Docker Desktop/Colima
|
||||
case ProfileProduction:
|
||||
return "0.0.0.0"
|
||||
default: // ProfileLocal
|
||||
return "localhost"
|
||||
}
|
||||
}
|
||||
|
||||
// BasePath returns the appropriate default base path
|
||||
func (s *SmartDefaults) BasePath() string {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/ml-experiments"
|
||||
case ProfileProduction:
|
||||
return "/var/lib/fetch_ml/experiments"
|
||||
default: // ProfileLocal
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, "ml-experiments")
|
||||
}
|
||||
return "./ml-experiments"
|
||||
}
|
||||
}
|
||||
|
||||
// DataDir returns the appropriate default data directory
|
||||
func (s *SmartDefaults) DataDir() string {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/data"
|
||||
case ProfileProduction:
|
||||
return "/var/lib/fetch_ml/data"
|
||||
default: // ProfileLocal
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, "ml-data")
|
||||
}
|
||||
return "./data"
|
||||
}
|
||||
}
|
||||
|
||||
// RedisAddr returns the appropriate default Redis address
|
||||
func (s *SmartDefaults) RedisAddr() string {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "redis:6379" // Service name in containers
|
||||
case ProfileProduction:
|
||||
return "redis:6379"
|
||||
default: // ProfileLocal
|
||||
return "localhost:6379"
|
||||
}
|
||||
}
|
||||
|
||||
// SSHKeyPath returns the appropriate default SSH key path
|
||||
func (s *SmartDefaults) SSHKeyPath() string {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/.ssh/id_rsa"
|
||||
case ProfileProduction:
|
||||
return "/etc/fetch_ml/ssh/id_rsa"
|
||||
default: // ProfileLocal
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, ".ssh", "id_rsa")
|
||||
}
|
||||
return "~/.ssh/id_rsa"
|
||||
}
|
||||
}
|
||||
|
||||
// KnownHostsPath returns the appropriate default known_hosts path
|
||||
func (s *SmartDefaults) KnownHostsPath() string {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/.ssh/known_hosts"
|
||||
case ProfileProduction:
|
||||
return "/etc/fetch_ml/ssh/known_hosts"
|
||||
default: // ProfileLocal
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, ".ssh", "known_hosts")
|
||||
}
|
||||
return "~/.ssh/known_hosts"
|
||||
}
|
||||
}
|
||||
|
||||
// LogLevel returns the appropriate default log level
|
||||
func (s *SmartDefaults) LogLevel() string {
|
||||
switch s.Profile {
|
||||
case ProfileCI:
|
||||
return "debug" // More verbose for CI debugging
|
||||
case ProfileProduction:
|
||||
return "info"
|
||||
default: // ProfileLocal, ProfileContainer
|
||||
return "info"
|
||||
}
|
||||
}
|
||||
|
||||
// MaxWorkers returns the appropriate default worker count
|
||||
func (s *SmartDefaults) MaxWorkers() int {
|
||||
switch s.Profile {
|
||||
case ProfileCI:
|
||||
return 1 // Conservative for CI
|
||||
case ProfileProduction:
|
||||
return runtime.NumCPU() // Scale with CPU cores
|
||||
default: // ProfileLocal, ProfileContainer
|
||||
return 2 // Reasonable default for local dev
|
||||
}
|
||||
}
|
||||
|
||||
// PollInterval returns the appropriate default poll interval in seconds
|
||||
func (s *SmartDefaults) PollInterval() int {
|
||||
switch s.Profile {
|
||||
case ProfileCI:
|
||||
return 1 // Fast polling for quick tests
|
||||
case ProfileProduction:
|
||||
return 10 // Conservative for production
|
||||
default: // ProfileLocal, ProfileContainer
|
||||
return 5 // Balanced default
|
||||
}
|
||||
}
|
||||
|
||||
// IsInContainer returns true if running in a container environment.
// CI counts as containerized because CI runners use the container paths.
func (s *SmartDefaults) IsInContainer() bool {
	return s.Profile == ProfileContainer || s.Profile == ProfileCI
}

// IsProduction returns true if this is a production environment.
func (s *SmartDefaults) IsProduction() bool {
	return s.Profile == ProfileProduction
}

// IsCI returns true if this is a CI environment.
func (s *SmartDefaults) IsCI() bool {
	return s.Profile == ProfileCI
}
|
||||
|
||||
// ExpandPath expands ~ and environment variables in paths
|
||||
func (s *SmartDefaults) ExpandPath(path string) string {
|
||||
if strings.HasPrefix(path, "~/") {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
path = filepath.Join(home, path[2:])
|
||||
}
|
||||
}
|
||||
|
||||
// Expand environment variables
|
||||
path = os.ExpandEnv(path)
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// GetEnvironmentDescription returns a human-readable description
|
||||
func (s *SmartDefaults) GetEnvironmentDescription() string {
|
||||
switch s.Profile {
|
||||
case ProfileLocal:
|
||||
return "Local Development"
|
||||
case ProfileContainer:
|
||||
return "Container Environment"
|
||||
case ProfileCI:
|
||||
return "CI/CD Environment"
|
||||
case ProfileProduction:
|
||||
return "Production Environment"
|
||||
default:
|
||||
return "Unknown Environment"
|
||||
}
|
||||
}
|
||||
69
internal/config/validation.go
Normal file
69
internal/config/validation.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
// Package config provides shared configuration utilities for the fetch_ml
// project, including path helpers and validation functions.
|
||||
package config
|
||||
|
||||
import (
	"fmt"
	"net"
	"os"
	"strconv"
)
|
||||
|
||||
// Validator is an interface for types that can validate themselves.
// Implementations return nil when the value is well-formed.
type Validator interface {
	Validate() error
}

// ValidateConfig validates a configuration struct that implements the
// Validator interface, returning whatever error the value reports.
func ValidateConfig(v Validator) error {
	return v.Validate()
}
|
||||
|
||||
// ValidatePort checks that port lies in the valid range 1-65535.
func ValidatePort(port int) error {
	if 1 <= port && port <= 65535 {
		return nil
	}
	return fmt.Errorf("invalid port: %d (must be 1-65535)", port)
}
|
||||
|
||||
// ValidateDirectory checks if a path exists and is a directory.
|
||||
func ValidateDirectory(path string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("path cannot be empty")
|
||||
}
|
||||
|
||||
expanded := ExpandPath(path)
|
||||
info, err := os.Stat(expanded)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("directory does not exist: %s", expanded)
|
||||
}
|
||||
return fmt.Errorf("cannot access directory %s: %w", expanded, err)
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("path is not a directory: %s", expanded)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRedisAddr validates a Redis address in the format "host:port".
|
||||
func ValidateRedisAddr(addr string) error {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid redis address format: %w", err)
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return fmt.Errorf("redis host cannot be empty")
|
||||
}
|
||||
|
||||
var portInt int
|
||||
if _, err := fmt.Sscanf(port, "%d", &portInt); err != nil {
|
||||
return fmt.Errorf("invalid port number %q: %w", port, err)
|
||||
}
|
||||
|
||||
return ValidatePort(portInt)
|
||||
}
|
||||
105
internal/container/podman.go
Normal file
105
internal/container/podman.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// Package container provides Podman helpers for running fetch_ml experiment jobs.
|
||||
package container
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
)
|
||||
|
||||
// PodmanConfig holds configuration for Podman container execution.
type PodmanConfig struct {
	Image              string // container image to run
	Workspace          string // host path mounted read-write as the job workspace
	Results            string // host path mounted read-write for job results
	ContainerWorkspace string // workspace mount point inside the container
	ContainerResults   string // results mount point inside the container
	GPUAccess          bool   // when true, /dev/dri is exposed to the container
	Memory             string // memory limit (e.g. "8g"); empty uses config.DefaultPodmanMemory
	CPUs               string // CPU limit (e.g. "2"); empty uses config.DefaultPodmanCPUs
}
|
||||
|
||||
// BuildPodmanCommand builds a Podman command for executing ML experiments
|
||||
func BuildPodmanCommand(cfg PodmanConfig, scriptPath, requirementsPath string, extraArgs []string) *exec.Cmd {
|
||||
args := []string{
|
||||
"run", "--rm",
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--cap-drop", "ALL",
|
||||
}
|
||||
|
||||
if cfg.Memory != "" {
|
||||
args = append(args, "--memory", cfg.Memory)
|
||||
} else {
|
||||
args = append(args, "--memory", config.DefaultPodmanMemory)
|
||||
}
|
||||
|
||||
if cfg.CPUs != "" {
|
||||
args = append(args, "--cpus", cfg.CPUs)
|
||||
} else {
|
||||
args = append(args, "--cpus", config.DefaultPodmanCPUs)
|
||||
}
|
||||
|
||||
args = append(args, "--userns", "keep-id")
|
||||
|
||||
// Mount workspace
|
||||
workspaceMount := fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace)
|
||||
args = append(args, "-v", workspaceMount)
|
||||
|
||||
// Mount results
|
||||
resultsMount := fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults)
|
||||
args = append(args, "-v", resultsMount)
|
||||
|
||||
if cfg.GPUAccess {
|
||||
args = append(args, "--device", "/dev/dri")
|
||||
}
|
||||
|
||||
// Image and command
|
||||
args = append(args, cfg.Image,
|
||||
"--workspace", cfg.ContainerWorkspace,
|
||||
"--requirements", requirementsPath,
|
||||
"--script", scriptPath,
|
||||
)
|
||||
|
||||
// Add extra arguments via --args flag
|
||||
if len(extraArgs) > 0 {
|
||||
args = append(args, "--args")
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
|
||||
return exec.Command("podman", args...)
|
||||
}
|
||||
|
||||
// SanitizePath ensures a path is safe to use (prevents path traversal)
|
||||
func SanitizePath(path string) (string, error) {
|
||||
// Clean the path to remove any .. or . components
|
||||
cleaned := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleaned, "..") {
|
||||
return "", fmt.Errorf("path traversal detected: %s", path)
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
// ValidateJobName checks that a job name is non-empty and free of path
// separators, shell/filesystem-unsafe characters, and traversal sequences.
func ValidateJobName(jobName string) error {
	switch {
	case jobName == "":
		return fmt.Errorf("job name cannot be empty")
	case strings.ContainsAny(jobName, "/\\<>:\"|?*"):
		return fmt.Errorf("job name contains invalid characters: %s", jobName)
	case strings.Contains(jobName, ".."):
		return fmt.Errorf("job name contains path traversal: %s", jobName)
	}
	return nil
}
|
||||
39
internal/errors/errors.go
Normal file
39
internal/errors/errors.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Package errors defines typed errors for the fetch_ml task pipeline.
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DataFetchError represents an error that occurred while fetching a dataset
|
||||
// from the NAS to the ML server.
|
||||
type DataFetchError struct {
|
||||
Dataset string
|
||||
JobName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *DataFetchError) Error() string {
|
||||
return fmt.Sprintf("failed to fetch dataset %s for job %s: %v",
|
||||
e.Dataset, e.JobName, e.Err)
|
||||
}
|
||||
|
||||
func (e *DataFetchError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type TaskExecutionError struct {
|
||||
TaskID string
|
||||
JobName string
|
||||
Phase string // "data_fetch", "execution", "cleanup"
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *TaskExecutionError) Error() string {
|
||||
return fmt.Sprintf("task %s (%s) failed during %s: %v",
|
||||
e.TaskID[:8], e.JobName, e.Phase, e.Err)
|
||||
}
|
||||
|
||||
func (e *TaskExecutionError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
343
internal/experiment/manager.go
Normal file
343
internal/experiment/manager.go
Normal file
|
|
@ -0,0 +1,343 @@
|
|||
package experiment
|
||||
|
||||
import (
	"encoding/binary"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"sort"
	"time"
)
|
||||
|
||||
// Metadata represents experiment metadata stored in meta.bin.
type Metadata struct {
	CommitID  string // experiment identifier; also used as the experiment directory name
	Timestamp int64  // Unix seconds; serialized big-endian in meta.bin
	JobName   string
	User      string
}
|
||||
|
||||
// Manager handles experiment storage and metadata
|
||||
type Manager struct {
|
||||
basePath string
|
||||
}
|
||||
|
||||
func NewManager(basePath string) *Manager {
|
||||
return &Manager{
|
||||
basePath: basePath,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize ensures the experiment directory exists
|
||||
func (m *Manager) Initialize() error {
|
||||
if err := os.MkdirAll(m.basePath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create experiment base directory: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExperimentPath returns the path for a given commit ID
|
||||
func (m *Manager) GetExperimentPath(commitID string) string {
|
||||
return filepath.Join(m.basePath, commitID)
|
||||
}
|
||||
|
||||
// GetFilesPath returns the path to the files directory for an experiment
|
||||
func (m *Manager) GetFilesPath(commitID string) string {
|
||||
return filepath.Join(m.GetExperimentPath(commitID), "files")
|
||||
}
|
||||
|
||||
// GetMetadataPath returns the path to meta.bin for an experiment
|
||||
func (m *Manager) GetMetadataPath(commitID string) string {
|
||||
return filepath.Join(m.GetExperimentPath(commitID), "meta.bin")
|
||||
}
|
||||
|
||||
// ExperimentExists checks if an experiment with the given commit ID exists
|
||||
func (m *Manager) ExperimentExists(commitID string) bool {
|
||||
path := m.GetExperimentPath(commitID)
|
||||
info, err := os.Stat(path)
|
||||
return err == nil && info.IsDir()
|
||||
}
|
||||
|
||||
// CreateExperiment creates the directory structure for a new experiment
|
||||
func (m *Manager) CreateExperiment(commitID string) error {
|
||||
filesPath := m.GetFilesPath(commitID)
|
||||
|
||||
if err := os.MkdirAll(filesPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create experiment directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteMetadata writes experiment metadata to meta.bin
|
||||
func (m *Manager) WriteMetadata(meta *Metadata) error {
|
||||
path := m.GetMetadataPath(meta.CommitID)
|
||||
|
||||
// Binary format:
|
||||
// [version:1][timestamp:8][commit_id_len:1][commit_id:var][job_name_len:1][job_name:var][user_len:1][user:var]
|
||||
|
||||
buf := make([]byte, 0, 256)
|
||||
|
||||
// Version
|
||||
buf = append(buf, 0x01)
|
||||
|
||||
// Timestamp
|
||||
ts := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(ts, uint64(meta.Timestamp))
|
||||
buf = append(buf, ts...)
|
||||
|
||||
// Commit ID
|
||||
buf = append(buf, byte(len(meta.CommitID)))
|
||||
buf = append(buf, []byte(meta.CommitID)...)
|
||||
|
||||
// Job Name
|
||||
buf = append(buf, byte(len(meta.JobName)))
|
||||
buf = append(buf, []byte(meta.JobName)...)
|
||||
|
||||
// User
|
||||
buf = append(buf, byte(len(meta.User)))
|
||||
buf = append(buf, []byte(meta.User)...)
|
||||
|
||||
return os.WriteFile(path, buf, 0644)
|
||||
}
|
||||
|
||||
// ReadMetadata reads experiment metadata from meta.bin
|
||||
func (m *Manager) ReadMetadata(commitID string) (*Metadata, error) {
|
||||
path := m.GetMetadataPath(commitID)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read metadata: %w", err)
|
||||
}
|
||||
|
||||
if len(data) < 10 {
|
||||
return nil, fmt.Errorf("metadata file too short")
|
||||
}
|
||||
|
||||
meta := &Metadata{}
|
||||
offset := 0
|
||||
|
||||
// Version
|
||||
version := data[offset]
|
||||
offset++
|
||||
if version != 0x01 {
|
||||
return nil, fmt.Errorf("unsupported metadata version: %d", version)
|
||||
}
|
||||
|
||||
// Timestamp
|
||||
meta.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8]))
|
||||
offset += 8
|
||||
|
||||
// Commit ID
|
||||
commitIDLen := int(data[offset])
|
||||
offset++
|
||||
meta.CommitID = string(data[offset : offset+commitIDLen])
|
||||
offset += commitIDLen
|
||||
|
||||
// Job Name
|
||||
if offset >= len(data) {
|
||||
return meta, nil
|
||||
}
|
||||
jobNameLen := int(data[offset])
|
||||
offset++
|
||||
meta.JobName = string(data[offset : offset+jobNameLen])
|
||||
offset += jobNameLen
|
||||
|
||||
// User
|
||||
if offset >= len(data) {
|
||||
return meta, nil
|
||||
}
|
||||
userLen := int(data[offset])
|
||||
offset++
|
||||
meta.User = string(data[offset : offset+userLen])
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// ListExperiments returns all experiment commit IDs
|
||||
func (m *Manager) ListExperiments() ([]string, error) {
|
||||
entries, err := os.ReadDir(m.basePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read experiments directory: %w", err)
|
||||
}
|
||||
|
||||
var commitIDs []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
commitIDs = append(commitIDs, entry.Name())
|
||||
}
|
||||
}
|
||||
|
||||
return commitIDs, nil
|
||||
}
|
||||
|
||||
// PruneExperiments removes old experiments based on retention policy
|
||||
func (m *Manager) PruneExperiments(keepCount int, olderThanDays int) ([]string, error) {
|
||||
commitIDs, err := m.ListExperiments()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type experiment struct {
|
||||
commitID string
|
||||
timestamp int64
|
||||
}
|
||||
|
||||
var experiments []experiment
|
||||
for _, commitID := range commitIDs {
|
||||
meta, err := m.ReadMetadata(commitID)
|
||||
if err != nil {
|
||||
continue // Skip experiments with invalid metadata
|
||||
}
|
||||
experiments = append(experiments, experiment{
|
||||
commitID: commitID,
|
||||
timestamp: meta.Timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by timestamp (newest first)
|
||||
for i := 0; i < len(experiments); i++ {
|
||||
for j := i + 1; j < len(experiments); j++ {
|
||||
if experiments[j].timestamp > experiments[i].timestamp {
|
||||
experiments[i], experiments[j] = experiments[j], experiments[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pruned []string
|
||||
cutoffTime := time.Now().AddDate(0, 0, -olderThanDays).Unix()
|
||||
|
||||
for i, exp := range experiments {
|
||||
shouldPrune := false
|
||||
|
||||
// Keep the newest N experiments
|
||||
if i >= keepCount {
|
||||
shouldPrune = true
|
||||
}
|
||||
|
||||
// Also prune if older than threshold
|
||||
if olderThanDays > 0 && exp.timestamp < cutoffTime {
|
||||
shouldPrune = true
|
||||
}
|
||||
|
||||
if shouldPrune {
|
||||
expPath := m.GetExperimentPath(exp.commitID)
|
||||
if err := os.RemoveAll(expPath); err != nil {
|
||||
// Log but continue
|
||||
continue
|
||||
}
|
||||
pruned = append(pruned, exp.commitID)
|
||||
}
|
||||
}
|
||||
|
||||
return pruned, nil
|
||||
}
|
||||
|
||||
// Metric represents a single data point in an experiment's metric stream.
type Metric struct {
	Name      string  `json:"name"`      // metric identifier; truncated to 255 bytes on disk by LogMetric
	Value     float64 `json:"value"`
	Step      int     `json:"step"`      // stored as uint32 on disk
	Timestamp int64   `json:"timestamp"` // Unix seconds, assigned when the metric is logged
}
|
||||
|
||||
// GetMetricsPath returns the path to metrics.bin for an experiment,
// located alongside meta.bin in the experiment directory.
func (m *Manager) GetMetricsPath(commitID string) string {
	return filepath.Join(m.GetExperimentPath(commitID), "metrics.bin")
}
|
||||
|
||||
// LogMetric appends a metric to the experiment's metrics file
|
||||
func (m *Manager) LogMetric(commitID string, name string, value float64, step int) error {
|
||||
path := m.GetMetricsPath(commitID)
|
||||
|
||||
// Binary format for each metric:
|
||||
// [timestamp:8][step:4][value:8][name_len:1][name:var]
|
||||
|
||||
buf := make([]byte, 0, 64)
|
||||
|
||||
// Timestamp
|
||||
ts := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(ts, uint64(time.Now().Unix()))
|
||||
buf = append(buf, ts...)
|
||||
|
||||
// Step
|
||||
st := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(st, uint32(step))
|
||||
buf = append(buf, st...)
|
||||
|
||||
// Value (float64)
|
||||
val := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(val, math.Float64bits(value))
|
||||
buf = append(buf, val...)
|
||||
|
||||
// Name
|
||||
if len(name) > 255 {
|
||||
name = name[:255]
|
||||
}
|
||||
buf = append(buf, byte(len(name)))
|
||||
buf = append(buf, []byte(name)...)
|
||||
|
||||
// Append to file
|
||||
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open metrics file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.Write(buf); err != nil {
|
||||
return fmt.Errorf("failed to write metric: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetrics reads all metrics for an experiment
|
||||
func (m *Manager) GetMetrics(commitID string) ([]Metric, error) {
|
||||
path := m.GetMetricsPath(commitID)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return []Metric{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read metrics file: %w", err)
|
||||
}
|
||||
|
||||
var metrics []Metric
|
||||
offset := 0
|
||||
|
||||
for offset < len(data) {
|
||||
if offset+21 > len(data) { // Min size check
|
||||
break
|
||||
}
|
||||
|
||||
m := Metric{}
|
||||
|
||||
// Timestamp
|
||||
m.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8]))
|
||||
offset += 8
|
||||
|
||||
// Step
|
||||
m.Step = int(binary.BigEndian.Uint32(data[offset : offset+4]))
|
||||
offset += 4
|
||||
|
||||
// Value
|
||||
bits := binary.BigEndian.Uint64(data[offset : offset+8])
|
||||
m.Value = math.Float64frombits(bits)
|
||||
offset += 8
|
||||
|
||||
// Name
|
||||
nameLen := int(data[offset])
|
||||
offset++
|
||||
|
||||
if offset+nameLen > len(data) {
|
||||
break
|
||||
}
|
||||
m.Name = string(data[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
||||
metrics = append(metrics, m)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
52
internal/logging/config.go
Normal file
52
internal/logging/config.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Config holds logging configuration, typically decoded from YAML.
type Config struct {
	Level    string `yaml:"level"`     // debug|info|warn|error (see parseLevel); empty means info
	File     string `yaml:"file"`      // log file path; empty logs to stderr
	AuditLog string `yaml:"audit_log"` // audit log path; not consumed here — TODO confirm usage
}
|
||||
|
||||
// LevelFromEnv reads LOG_LEVEL (if set) and returns the matching slog level.
// Accepted values: debug, info, warn/warning, error (case-insensitive,
// surrounding whitespace ignored). Empty or unrecognized values default
// to info.
func LevelFromEnv() slog.Level {
	return parseLevel(os.Getenv("LOG_LEVEL"), slog.LevelInfo)
}
|
||||
|
||||
func parseLevel(value string, defaultLevel slog.Level) slog.Level {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "warn", "warning":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
case "info", "":
|
||||
return slog.LevelInfo
|
||||
default:
|
||||
return defaultLevel
|
||||
}
|
||||
}
|
||||
|
||||
// NewConfiguredLogger creates a logger using the level configured via the
// LOG_LEVEL environment variable. JSON/text output selection is still
// controlled by LOG_FORMAT inside NewLogger.
func NewConfiguredLogger() *Logger {
	return NewLogger(LevelFromEnv(), false)
}
|
||||
|
||||
// NewLoggerFromConfig creates a logger from configuration
|
||||
func NewLoggerFromConfig(cfg Config) *Logger {
|
||||
level := parseLevel(cfg.Level, slog.LevelInfo)
|
||||
|
||||
if cfg.File != "" {
|
||||
return NewFileLogger(level, false, cfg.File)
|
||||
}
|
||||
|
||||
return NewLogger(level, false)
|
||||
}
|
||||
172
internal/logging/logging.go
Normal file
172
internal/logging/logging.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ctxKey is a private context-key type so these keys cannot collide with
// values stored by other packages.
type ctxKey string

// Context keys for the correlation fields that WithContext copies into
// log output.
const (
	CtxTraceID ctxKey = "trace_id"
	CtxSpanID  ctxKey = "span_id"
	CtxWorker  ctxKey = "worker_id"
	CtxJob     ctxKey = "job_name"
	CtxTask    ctxKey = "task_id"
)
|
||||
|
||||
// Logger wraps *slog.Logger with context-propagation helpers and
// Fatal/Panic conveniences.
type Logger struct {
	*slog.Logger
}
|
||||
|
||||
// NewLogger creates a logger that writes to stderr (development mode)
|
||||
func NewLogger(level slog.Level, jsonOutput bool) *Logger {
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: os.Getenv("LOG_ADD_SOURCE") == "1",
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
if jsonOutput || os.Getenv("LOG_FORMAT") == "json" {
|
||||
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||
} else {
|
||||
handler = NewColorTextHandler(os.Stderr, opts)
|
||||
}
|
||||
|
||||
return &Logger{slog.New(handler)}
|
||||
}
|
||||
|
||||
// NewFileLogger creates a logger that writes to a file only (production mode)
|
||||
func NewFileLogger(level slog.Level, jsonOutput bool, logFile string) *Logger {
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: os.Getenv("LOG_ADD_SOURCE") == "1",
|
||||
}
|
||||
|
||||
// Create log directory if it doesn't exist
|
||||
if logFile != "" {
|
||||
logDir := filepath.Dir(logFile)
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
// Fallback to stderr only if directory creation fails
|
||||
return NewLogger(level, jsonOutput)
|
||||
}
|
||||
}
|
||||
|
||||
// Open log file
|
||||
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
// Fallback to stderr only if file creation fails
|
||||
return NewLogger(level, jsonOutput)
|
||||
}
|
||||
|
||||
// Write to file only (production)
|
||||
var handler slog.Handler
|
||||
if jsonOutput || os.Getenv("LOG_FORMAT") == "json" {
|
||||
handler = slog.NewJSONHandler(file, opts)
|
||||
} else {
|
||||
handler = slog.NewTextHandler(file, opts)
|
||||
}
|
||||
|
||||
return &Logger{slog.New(handler)}
|
||||
}
|
||||
|
||||
// Inject trace + span if missing
|
||||
func EnsureTrace(ctx context.Context) context.Context {
|
||||
if ctx.Value(CtxTraceID) == nil {
|
||||
ctx = context.WithValue(ctx, CtxTraceID, uuid.NewString())
|
||||
}
|
||||
if ctx.Value(CtxSpanID) == nil {
|
||||
ctx = context.WithValue(ctx, CtxSpanID, uuid.NewString())
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (l *Logger) WithContext(ctx context.Context, args ...any) *Logger {
|
||||
if trace := ctx.Value(CtxTraceID); trace != nil {
|
||||
args = append(args, "trace_id", trace)
|
||||
}
|
||||
if span := ctx.Value(CtxSpanID); span != nil {
|
||||
args = append(args, "span_id", span)
|
||||
}
|
||||
if worker := ctx.Value(CtxWorker); worker != nil {
|
||||
args = append(args, "worker_id", worker)
|
||||
}
|
||||
if job := ctx.Value(CtxJob); job != nil {
|
||||
args = append(args, "job_name", job)
|
||||
}
|
||||
if task := ctx.Value(CtxTask); task != nil {
|
||||
args = append(args, "task_id", task)
|
||||
}
|
||||
return &Logger{Logger: l.With(args...)}
|
||||
}
|
||||
|
||||
// CtxWithWorker returns ctx annotated with a worker ID for log correlation.
func CtxWithWorker(ctx context.Context, worker string) context.Context {
	return context.WithValue(ctx, CtxWorker, worker)
}

// CtxWithJob returns ctx annotated with a job name for log correlation.
func CtxWithJob(ctx context.Context, job string) context.Context {
	return context.WithValue(ctx, CtxJob, job)
}

// CtxWithTask returns ctx annotated with a task ID for log correlation.
func CtxWithTask(ctx context.Context, task string) context.Context {
	return context.WithValue(ctx, CtxTask, task)
}
|
||||
|
||||
// Component returns a logger tagged with a component name plus any
// correlation fields present in ctx.
func (l *Logger) Component(ctx context.Context, name string) *Logger {
	return l.WithContext(ctx, "component", name)
}

// Worker returns a logger tagged with a worker ID.
// NOTE(review): if ctx also carries CtxWorker, "worker_id" appears twice
// in the output — confirm this is acceptable.
func (l *Logger) Worker(ctx context.Context, workerID string) *Logger {
	return l.WithContext(ctx, "worker_id", workerID)
}

// Job returns a logger tagged with a job name and task ID (same duplicate
// caveat as Worker when ctx carries CtxJob/CtxTask).
func (l *Logger) Job(ctx context.Context, job string, task string) *Logger {
	return l.WithContext(ctx, "job_name", job, "task_id", task)
}
|
||||
|
||||
// Fatal logs at error level and terminates the process with exit code 1.
// Deferred functions do NOT run (os.Exit).
func (l *Logger) Fatal(msg string, args ...any) {
	l.Error(msg, args...)
	os.Exit(1)
}
|
||||
|
||||
// Panic logs at error level and then panics with the message, allowing
// callers up the stack to recover.
func (l *Logger) Panic(msg string, args ...any) {
	l.Error(msg, args...)
	panic(msg)
}
|
||||
|
||||
// -----------------------------------------------------
|
||||
// Colorized human-friendly console logs
|
||||
// -----------------------------------------------------
|
||||
|
||||
type ColorTextHandler struct {
|
||||
slog.Handler
|
||||
}
|
||||
|
||||
func NewColorTextHandler(w io.Writer, opts *slog.HandlerOptions) slog.Handler {
|
||||
base := slog.NewTextHandler(w, opts)
|
||||
return &ColorTextHandler{Handler: base}
|
||||
}
|
||||
|
||||
func (h *ColorTextHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
// Add uniform timestamp (override default)
|
||||
r.Time = time.Now()
|
||||
|
||||
switch r.Level {
|
||||
case slog.LevelDebug:
|
||||
r.Add("lvl_color", "\033[34mDBG\033[0m")
|
||||
case slog.LevelInfo:
|
||||
r.Add("lvl_color", "\033[32mINF\033[0m")
|
||||
case slog.LevelWarn:
|
||||
r.Add("lvl_color", "\033[33mWRN\033[0m")
|
||||
case slog.LevelError:
|
||||
r.Add("lvl_color", "\033[31mERR\033[0m")
|
||||
}
|
||||
|
||||
return h.Handler.Handle(ctx, r)
|
||||
}
|
||||
80
internal/logging/sanitize.go
Normal file
80
internal/logging/sanitize.go
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Patterns for sensitive data that must never reach log output.
var (
	// API keys: 32+ hex characters.
	apiKeyPattern = regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`)

	// JWT tokens (header.payload.signature base64url segments).
	jwtPattern = regexp.MustCompile(`eyJ[a-zA-Z0-9_-]{10,}\.eyJ[a-zA-Z0-9_-]{10,}\.[a-zA-Z0-9_-]{10,}`)

	// Password-like key/value fields embedded in free-form log text.
	passwordPattern = regexp.MustCompile(`(?i)(password|passwd|pwd|secret|token|key)["']?\s*[:=]\s*["']?([^"'\s,}]+)`)

	// Redis URLs carrying an inline password.
	redisPasswordPattern = regexp.MustCompile(`redis://:[^@]+@`)
)

// SanitizeLogMessage scrubs sensitive data from a log message. Redactions
// are applied in a fixed order: API keys, JWTs, password-like fields, then
// Redis URL passwords.
func SanitizeLogMessage(message string) string {
	redactField := func(match string) string {
		if parts := passwordPattern.FindStringSubmatch(match); len(parts) >= 2 {
			// Keep the field name, drop the value.
			return parts[1] + "=[REDACTED]"
		}
		return match
	}

	message = apiKeyPattern.ReplaceAllString(message, "[REDACTED-API-KEY]")
	message = jwtPattern.ReplaceAllString(message, "[REDACTED-JWT]")
	message = passwordPattern.ReplaceAllStringFunc(message, redactField)
	return redisPasswordPattern.ReplaceAllString(message, "redis://:[REDACTED]@")
}
|
||||
|
||||
// SanitizeArgs removes sensitive data from structured log arguments
|
||||
func SanitizeArgs(args []any) []any {
|
||||
sanitized := make([]any, len(args))
|
||||
copy(sanitized, args)
|
||||
|
||||
for i := 0; i < len(sanitized)-1; i += 2 {
|
||||
// Check if this is a key-value pair
|
||||
key, okKey := sanitized[i].(string)
|
||||
value, okValue := sanitized[i+1].(string)
|
||||
|
||||
if okKey && okValue {
|
||||
lowerKey := strings.ToLower(key)
|
||||
// Redact sensitive fields
|
||||
if strings.Contains(lowerKey, "password") ||
|
||||
strings.Contains(lowerKey, "secret") ||
|
||||
strings.Contains(lowerKey, "token") ||
|
||||
strings.Contains(lowerKey, "key") ||
|
||||
strings.Contains(lowerKey, "api") {
|
||||
sanitized[i+1] = "[REDACTED]"
|
||||
} else if strings.HasPrefix(value, "redis://") {
|
||||
sanitized[i+1] = SanitizeLogMessage(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// RedactAPIKey returns a loggable form of key: the first and last four
// characters with the middle elided. Keys of 8 bytes or fewer are fully
// redacted, since showing 8 of 8 characters would reveal the whole key.
func RedactAPIKey(key string) string {
	const visible = 4
	if len(key) <= 2*visible {
		return "[REDACTED]"
	}
	return key[:visible] + "..." + key[len(key)-visible:]
}
|
||||
71
internal/metrics/metrics.go
Normal file
71
internal/metrics/metrics.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
// Package metrics provides lightweight atomic counters for worker telemetry in the fetch_ml project.
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func max(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TasksProcessed atomic.Int64
|
||||
TasksFailed atomic.Int64
|
||||
DataFetchTime atomic.Int64 // Total nanoseconds
|
||||
ExecutionTime atomic.Int64
|
||||
DataTransferred atomic.Int64 // Total bytes
|
||||
ActiveTasks atomic.Int64
|
||||
QueuedTasks atomic.Int64
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordTaskSuccess(duration time.Duration) {
|
||||
m.TasksProcessed.Add(1)
|
||||
m.ExecutionTime.Add(duration.Nanoseconds())
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordTaskFailure() {
|
||||
m.TasksFailed.Add(1)
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordTaskStart() {
|
||||
m.ActiveTasks.Add(1)
|
||||
}
|
||||
|
||||
// RecordTaskCompletion decrements the number of active tasks. It is safe to call
|
||||
// even if no tasks are currently recorded; the caller should ensure calls are
|
||||
// balanced with RecordTaskStart.
|
||||
func (m *Metrics) RecordTaskCompletion() {
|
||||
m.ActiveTasks.Add(-1)
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordDataTransfer(bytes int64, duration time.Duration) {
|
||||
m.DataTransferred.Add(bytes)
|
||||
m.DataFetchTime.Add(duration.Nanoseconds())
|
||||
}
|
||||
|
||||
func (m *Metrics) SetQueuedTasks(count int64) {
|
||||
m.QueuedTasks.Store(count)
|
||||
}
|
||||
|
||||
func (m *Metrics) GetStats() map[string]any {
|
||||
processed := m.TasksProcessed.Load()
|
||||
failed := m.TasksFailed.Load()
|
||||
dataTransferred := m.DataTransferred.Load()
|
||||
dataFetchTime := m.DataFetchTime.Load()
|
||||
|
||||
return map[string]any{
|
||||
"tasks_processed": processed,
|
||||
"tasks_failed": failed,
|
||||
"active_tasks": m.ActiveTasks.Load(),
|
||||
"queued_tasks": m.QueuedTasks.Load(),
|
||||
"success_rate": float64(processed-failed) / float64(max(processed, 1)),
|
||||
"avg_exec_time": time.Duration(m.ExecutionTime.Load() / max(processed, 1)),
|
||||
"data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024),
|
||||
"avg_fetch_time": time.Duration(dataFetchTime / max(processed, 1)),
|
||||
}
|
||||
}
|
||||
259
internal/middleware/security.go
Normal file
259
internal/middleware/security.go
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
package middleware
|
||||
|
||||
import (
	"context"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/time/rate"
)
|
||||
|
||||
// SecurityMiddleware provides comprehensive security features: rate
// limiting, API-key authentication, and IP whitelisting.
type SecurityMiddleware struct {
	rateLimiter *rate.Limiter   // single GLOBAL limiter shared by all clients
	apiKeys     map[string]bool // set of valid API keys
	jwtSecret   []byte          // held for JWT validation; unused in this file — TODO confirm
}

// NewSecurityMiddleware builds the middleware from a list of valid API
// keys and a JWT secret.
func NewSecurityMiddleware(apiKeys []string, jwtSecret string) *SecurityMiddleware {
	keyMap := make(map[string]bool)
	for _, key := range apiKeys {
		keyMap[key] = true
	}

	return &SecurityMiddleware{
		// rate.Limit(60) means 60 requests per SECOND with a burst of 10.
		// NOTE(review): the original comment claimed "per minute"; if that
		// was the intent, this should be rate.Every(time.Second).
		rateLimiter: rate.NewLimiter(rate.Limit(60), 10),
		apiKeys:     keyMap,
		jwtSecret:   []byte(jwtSecret),
	}
}
|
||||
|
||||
// Rate limiting middleware
|
||||
func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !sm.rateLimiter.Allow() {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// API key authentication
|
||||
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
apiKey := r.Header.Get("X-API-Key")
|
||||
if apiKey == "" {
|
||||
// Also check Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
if !sm.apiKeys[apiKey] {
|
||||
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
// SecurityHeaders attaches standard browser-hardening headers to every
// response: anti-clickjacking, MIME-sniffing prevention, the legacy XSS
// filter, a same-origin CSP, and a strict referrer policy. HSTS is added
// only on TLS connections.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()

		static := map[string]string{
			"X-Frame-Options":         "DENY",
			"X-Content-Type-Options":  "nosniff",
			"X-XSS-Protection":        "1; mode=block",
			"Content-Security-Policy": "default-src 'self'",
			"Referrer-Policy":         "strict-origin-when-cross-origin",
		}
		for name, value := range static {
			headers.Set(name, value)
		}

		// HSTS only makes sense over HTTPS.
		if r.TLS != nil {
			headers.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}

		next.ServeHTTP(w, r)
	})
}
|
||||
|
||||
// IP whitelist middleware
|
||||
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
|
||||
// Check if client IP is in whitelist
|
||||
allowed := false
|
||||
for _, ip := range allowedIPs {
|
||||
if strings.Contains(ip, "/") {
|
||||
// CIDR notation - would need proper IP net parsing
|
||||
if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if clientIP == ip {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
// CORS applies a restrictive cross-origin policy with a fixed allowlist of
// production origins. Preflight (OPTIONS) requests are answered directly
// with 204.
// NOTE(review): the allowlist is hard-coded; consider making it
// configurable.
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")

		// Only allow specific origins in production.
		allowedOrigins := []string{
			"https://ml-experiments.example.com",
			"https://app.example.com",
		}

		// The response varies per Origin, so shared caches must key on it;
		// without this, one origin's CORS response could be served to
		// another.
		w.Header().Add("Vary", "Origin")

		for _, allowed := range allowedOrigins {
			if origin == allowed {
				w.Header().Set("Access-Control-Allow-Origin", origin)
				break
			}
		}

		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
		w.Header().Set("Access-Control-Allow-Credentials", "true")
		w.Header().Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		next.ServeHTTP(w, r)
	})
}
|
||||
|
||||
// Request timeout middleware
|
||||
func RequestTimeout(timeout time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Request size limiter
|
||||
func RequestSizeLimit(maxSize int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength > maxSize {
|
||||
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Security audit logging
|
||||
func AuditLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
path := r.URL.Path
|
||||
raw := r.URL.RawQuery
|
||||
|
||||
// Wrap response writer to capture status code
|
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
// Process request
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
// Log after processing
|
||||
latency := time.Since(start)
|
||||
clientIP := getClientIP(r)
|
||||
method := r.Method
|
||||
statusCode := wrapped.statusCode
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
// Log security-relevant events
|
||||
if statusCode >= 400 || method == "DELETE" || strings.Contains(path, "/admin") {
|
||||
// Log to security audit system
|
||||
logSecurityEvent(map[string]interface{}{
|
||||
"timestamp": start.Unix(),
|
||||
"client_ip": clientIP,
|
||||
"method": method,
|
||||
"path": path,
|
||||
"status": statusCode,
|
||||
"latency": latency,
|
||||
"user_agent": r.UserAgent(),
|
||||
"referer": r.Referer(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper to get client IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the list
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
// responseWriter wraps http.ResponseWriter to capture the status code for
// audit logging. It defaults to 200 at construction, matching handlers
// that write a body without ever calling WriteHeader.
// NOTE(review): optional interfaces (http.Flusher, http.Hijacker) are not
// forwarded — confirm no wrapped handler depends on them.
type responseWriter struct {
	http.ResponseWriter
	statusCode int
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (rw *responseWriter) WriteHeader(code int) {
	rw.statusCode = code
	rw.ResponseWriter.WriteHeader(code)
}
|
// logSecurityEvent emits a security audit record. Currently a stub that
// writes a subset of the fields (client_ip, method, path, status) to the
// standard logger; intended to be replaced by a real monitoring sink.
func logSecurityEvent(event map[string]interface{}) {
	// Implementation would send to security monitoring system
	// For now, just log (in production, use proper logging)
	log.Printf("SECURITY AUDIT: %s %s %s %v", event["client_ip"], event["method"], event["path"], event["status"])
}
|
||||
73
internal/network/retry.go
Normal file
73
internal/network/retry.go
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
// Package network provides retry and SSH connectivity helpers for the fetch_ml project.
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RetryConfig struct {
|
||||
MaxAttempts int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
}
|
||||
|
||||
func DefaultRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
func Retry(ctx context.Context, cfg RetryConfig, fn func() error) error {
|
||||
var lastErr error
|
||||
delay := cfg.InitialDelay
|
||||
|
||||
for attempt := 0; attempt < cfg.MaxAttempts; attempt++ {
|
||||
if err := fn(); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if attempt < cfg.MaxAttempts-1 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(math.Min(
|
||||
float64(delay)*cfg.Multiplier,
|
||||
float64(cfg.MaxDelay),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// RetryWithBackoff provides a convenient wrapper for common retry scenarios
|
||||
func RetryWithBackoff(ctx context.Context, maxAttempts int, operation func() error) error {
|
||||
cfg := RetryConfig{
|
||||
MaxAttempts: maxAttempts,
|
||||
InitialDelay: 200 * time.Millisecond,
|
||||
MaxDelay: 2 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
}
|
||||
return Retry(ctx, cfg, operation)
|
||||
}
|
||||
|
||||
// RetryForNetworkOperations is optimized for network-related operations
|
||||
func RetryForNetworkOperations(ctx context.Context, operation func() error) error {
|
||||
cfg := RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 200 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
Multiplier: 1.5,
|
||||
}
|
||||
return Retry(ctx, cfg, operation)
|
||||
}
|
||||
304
internal/network/ssh.go
Normal file
304
internal/network/ssh.go
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
// Package network provides retry and SSH connectivity helpers for the fetch_ml project.
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
)
|
||||
|
||||
// SSHClient provides SSH connection and command execution
|
||||
type SSHClient struct {
|
||||
client *ssh.Client
|
||||
host string
|
||||
}
|
||||
|
||||
// NewSSHClient creates a new SSH client. If host or keyPath is empty, returns a local-mode client.
|
||||
// knownHostsPath is optional - if provided, will use known_hosts verification
|
||||
func NewSSHClient(host, user, keyPath string, port int, knownHostsPath string) (*SSHClient, error) {
|
||||
if host == "" || keyPath == "" {
|
||||
// Local mode - no SSH connection needed
|
||||
return &SSHClient{client: nil, host: ""}, nil
|
||||
}
|
||||
|
||||
keyPath = config.ExpandPath(keyPath)
|
||||
if strings.HasPrefix(keyPath, "~") {
|
||||
home, _ := os.UserHomeDir()
|
||||
keyPath = filepath.Join(home, keyPath[1:])
|
||||
}
|
||||
|
||||
key, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read SSH key: %w", err)
|
||||
}
|
||||
|
||||
var signer ssh.Signer
|
||||
if signer, err = ssh.ParsePrivateKey(key); err != nil {
|
||||
if _, ok := err.(*ssh.PassphraseMissingError); ok {
|
||||
// Try to use ssh-agent for passphrase-protected keys
|
||||
if agentSigner, agentErr := sshAgentSigner(); agentErr == nil {
|
||||
signer = agentSigner
|
||||
} else {
|
||||
return nil, fmt.Errorf("SSH key is passphrase protected and ssh-agent unavailable: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to parse SSH key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
hostKeyCallback := ssh.InsecureIgnoreHostKey()
|
||||
if knownHostsPath != "" {
|
||||
knownHostsPath = config.ExpandPath(knownHostsPath)
|
||||
if _, err := os.Stat(knownHostsPath); err == nil {
|
||||
callback, err := knownhosts.New(knownHostsPath)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to parse known_hosts: %v; using insecure host key verification", err)
|
||||
} else {
|
||||
hostKeyCallback = callback
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
log.Printf("Warning: known_hosts not found at %s; using insecure host key verification", knownHostsPath)
|
||||
}
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
Timeout: 10 * time.Second,
|
||||
HostKeyAlgorithms: []string{
|
||||
ssh.KeyAlgoRSA,
|
||||
ssh.KeyAlgoRSASHA256,
|
||||
ssh.KeyAlgoRSASHA512,
|
||||
ssh.KeyAlgoED25519,
|
||||
ssh.KeyAlgoECDSA256,
|
||||
ssh.KeyAlgoECDSA384,
|
||||
ssh.KeyAlgoECDSA521,
|
||||
},
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
client, err := ssh.Dial("tcp", addr, sshConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH connection failed: %w", err)
|
||||
}
|
||||
|
||||
return &SSHClient{client: client, host: host}, nil
|
||||
}
|
// Exec executes a command remotely via SSH (or locally in local mode)
// without cancellation, delegating to ExecContext with a background
// context.
func (c *SSHClient) Exec(cmd string) (string, error) {
	return c.ExecContext(context.Background(), cmd)
}
|
// ExecContext executes a command with context support for cancellation and
// timeouts. In local mode the command runs through `sh -c` with the
// context wired into exec.CommandContext. In remote mode the command runs
// in a new SSH session; on cancellation it sends SIGTERM, waits up to 5s
// for graceful exit, force-closes the session, then waits up to another 5s
// for the goroutine's final result. Combined stdout+stderr is returned in
// both modes.
func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) {
	if c.client == nil {
		// Local mode - execute command locally with context
		execCmd := exec.CommandContext(ctx, "sh", "-c", cmd)
		output, err := execCmd.CombinedOutput()
		return string(output), err
	}

	session, err := c.client.NewSession()
	if err != nil {
		return "", fmt.Errorf("create session: %w", err)
	}
	defer func() {
		if closeErr := session.Close(); closeErr != nil {
			// Session may already be closed, so we just log at debug level
			log.Printf("session close error (may be expected): %v", closeErr)
		}
	}()

	// Run the command in a goroutine so we can race it against ctx.
	type result struct {
		output string
		err    error
	}
	// Buffered so the goroutine can always deliver and never leaks.
	resultCh := make(chan result, 1)

	go func() {
		output, err := session.CombinedOutput(cmd)
		resultCh <- result{string(output), err}
	}()

	select {
	case <-ctx.Done():
		// Ask the remote process to terminate gracefully first.
		if sigErr := session.Signal(ssh.SIGTERM); sigErr != nil {
			log.Printf("failed to send SIGTERM: %v", sigErr)
		}

		// Give process time to cleanup gracefully
		timer := time.NewTimer(5 * time.Second)
		defer timer.Stop()

		select {
		case res := <-resultCh:
			// Command finished during graceful shutdown
			return res.output, fmt.Errorf("command cancelled: %w (output: %s)", ctx.Err(), res.output)
		case <-timer.C:
			// Grace period expired: tear the session down.
			if closeErr := session.Close(); closeErr != nil {
				log.Printf("failed to force close session: %v", closeErr)
			}

			// Wait a bit more for final result
			select {
			case res := <-resultCh:
				return res.output, fmt.Errorf("command cancelled and force closed: %w (output: %s)", ctx.Err(), res.output)
			case <-time.After(5 * time.Second):
				return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
			}
		}
	case res := <-resultCh:
		return res.output, res.err
	}
}
|
||||
|
||||
// FileExists checks if a file exists remotely or locally
|
||||
func (c *SSHClient) FileExists(path string) bool {
|
||||
if c.client == nil {
|
||||
// Local mode - check file locally
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
out, err := c.Exec(fmt.Sprintf("test -e %s && echo 'exists'", path))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.TrimSpace(out), "exists")
|
||||
}
|
||||
|
||||
// GetFileSize gets the size of a file or directory remotely or locally
|
||||
func (c *SSHClient) GetFileSize(path string) (int64, error) {
|
||||
if c.client == nil {
|
||||
// Local mode - get size locally
|
||||
var size int64
|
||||
err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size += info.Size()
|
||||
return nil
|
||||
})
|
||||
return size, err
|
||||
}
|
||||
|
||||
out, err := c.Exec(fmt.Sprintf("du -sb %s | cut -f1", path))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var size int64
|
||||
if _, err := fmt.Sscanf(strings.TrimSpace(out), "%d", &size); err != nil {
|
||||
return 0, fmt.Errorf("failed to parse file size from output %q: %w", out, err)
|
||||
}
|
||||
return size, nil
|
||||
}
|
// RemoteExists checks if a remote path exists (alias for FileExists kept
// for compatibility with older callers).
func (c *SSHClient) RemoteExists(path string) bool {
	return c.FileExists(path)
}
|
||||
|
||||
// ListDir lists directory contents remotely or locally
|
||||
func (c *SSHClient) ListDir(path string) []string {
|
||||
if c.client == nil {
|
||||
// Local mode
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var items []string
|
||||
for _, entry := range entries {
|
||||
items = append(items, entry.Name())
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
out, err := c.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", path))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var items []string
|
||||
for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
|
||||
if line != "" {
|
||||
items = append(items, line)
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// TailFile gets the last N lines of a file remotely or locally
|
||||
func (c *SSHClient) TailFile(path string, lines int) string {
|
||||
if c.client == nil {
|
||||
// Local mode - read file and return last N lines
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fileLines := strings.Split(string(data), "\n")
|
||||
if len(fileLines) > lines {
|
||||
fileLines = fileLines[len(fileLines)-lines:]
|
||||
}
|
||||
return strings.Join(fileLines, "\n")
|
||||
}
|
||||
|
||||
out, err := c.Exec(fmt.Sprintf("tail -n %d %s 2>/dev/null", lines, path))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Close closes the SSH connection
|
||||
func (c *SSHClient) Close() error {
|
||||
if c.client != nil {
|
||||
return c.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshAgentSigner attempts to get a signer from ssh-agent
|
||||
func sshAgentSigner() (ssh.Signer, error) {
|
||||
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
|
||||
if sshAuthSock == "" {
|
||||
return nil, fmt.Errorf("SSH_AUTH_SOCK not set")
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", sshAuthSock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("warning: failed to close ssh-agent connection: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
agentClient := agent.NewClient(conn)
|
||||
signers, err := agentClient.Signers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
|
||||
}
|
||||
|
||||
if len(signers) == 0 {
|
||||
return nil, fmt.Errorf("no signers available in ssh-agent")
|
||||
}
|
||||
|
||||
return signers[0], nil
|
||||
}
|
||||
84
internal/network/ssh_pool.go
Executable file
84
internal/network/ssh_pool.go
Executable file
|
|
@ -0,0 +1,84 @@
|
|||
// Package network provides networking utilities (SSH clients and connection pooling) for the fetch_ml project.
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// SSHPool is a bounded pool of reusable SSH client connections.
// Connections are created lazily via factory, up to maxConns in total.
type SSHPool struct {
	factory  func() (*SSHClient, error) // creates a new connection when none are idle
	pool     chan *SSHClient            // idle connections available for reuse
	active   int                        // connections currently created; guarded by mu
	maxConns int                        // upper bound on total live connections
	mu       sync.Mutex                 // protects active
	logger   *logging.Logger
}
|
||||
|
||||
func NewSSHPool(maxConns int, factory func() (*SSHClient, error), logger *logging.Logger) *SSHPool {
|
||||
return &SSHPool{
|
||||
factory: factory,
|
||||
pool: make(chan *SSHClient, maxConns),
|
||||
maxConns: maxConns,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHPool) Get(ctx context.Context) (*SSHClient, error) {
|
||||
select {
|
||||
case conn := <-p.pool:
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
p.mu.Lock()
|
||||
if p.active < p.maxConns {
|
||||
p.active++
|
||||
p.mu.Unlock()
|
||||
return p.factory()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
// Wait for available connection
|
||||
select {
|
||||
case conn := <-p.pool:
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHPool) Put(conn *SSHClient) {
|
||||
select {
|
||||
case p.pool <- conn:
|
||||
default:
|
||||
// Pool is full, close connection
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
p.logger.Warn("failed to close SSH connection", "error", err)
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.active--
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHPool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Close all connections in the pool
|
||||
close(p.pool)
|
||||
for conn := range p.pool {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
p.logger.Warn("failed to close SSH connection", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset active count
|
||||
p.active = 0
|
||||
}
|
||||
215
internal/queue/errors.go
Normal file
215
internal/queue/errors.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrorCategory represents the type of error encountered. It drives retry
// and backoff decisions (see IsRetryable and RetryDelay).
type ErrorCategory string

const (
	ErrorNetwork    ErrorCategory = "network"    // Network connectivity issues (retryable)
	ErrorResource   ErrorCategory = "resource"   // Resource exhaustion (OOM, disk full); retryable after delay
	ErrorRateLimit  ErrorCategory = "rate_limit" // Rate limiting or throttling (retryable with long backoff)
	ErrorAuth       ErrorCategory = "auth"       // Authentication/authorization failures (not retryable)
	ErrorValidation ErrorCategory = "validation" // Input validation errors (not retryable)
	ErrorTimeout    ErrorCategory = "timeout"    // Operation timeout (retryable)
	ErrorPermanent  ErrorCategory = "permanent"  // Non-retryable errors
	ErrorUnknown    ErrorCategory = "unknown"    // Unclassified errors (retried optimistically)
)
|
||||
|
||||
// TaskError wraps an error with category and context.
type TaskError struct {
	Category ErrorCategory     // classification used by retry logic
	Message  string            // human-readable summary
	Cause    error             // underlying error, if any; exposed via Unwrap
	Context  map[string]string // optional key/value debugging context
}
|
||||
|
||||
func (e *TaskError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("[%s] %s: %v", e.Category, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("[%s] %s", e.Category, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped cause, enabling errors.Is / errors.As traversal.
func (e *TaskError) Unwrap() error {
	return e.Cause
}
|
||||
|
||||
// NewTaskError creates a new categorized error
|
||||
func NewTaskError(category ErrorCategory, message string, cause error) *TaskError {
|
||||
return &TaskError{
|
||||
Category: category,
|
||||
Message: message,
|
||||
Cause: cause,
|
||||
Context: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// ClassifyError categorizes an error for retry logic
|
||||
func ClassifyError(err error) ErrorCategory {
|
||||
if err == nil {
|
||||
return ErrorUnknown
|
||||
}
|
||||
|
||||
// Check if already classified
|
||||
var taskErr *TaskError
|
||||
if errors.As(err, &taskErr) {
|
||||
return taskErr.Category
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
|
||||
// Network errors (retryable)
|
||||
networkIndicators := []string{
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"connection timeout",
|
||||
"no route to host",
|
||||
"network unreachable",
|
||||
"temporary failure",
|
||||
"dns",
|
||||
"dial tcp",
|
||||
"i/o timeout",
|
||||
}
|
||||
for _, indicator := range networkIndicators {
|
||||
if strings.Contains(errStr, indicator) {
|
||||
return ErrorNetwork
|
||||
}
|
||||
}
|
||||
|
||||
// Resource errors (retryable after delay)
|
||||
resourceIndicators := []string{
|
||||
"out of memory",
|
||||
"oom",
|
||||
"no space left",
|
||||
"disk full",
|
||||
"resource temporarily unavailable",
|
||||
"too many open files",
|
||||
"cannot allocate memory",
|
||||
}
|
||||
for _, indicator := range resourceIndicators {
|
||||
if strings.Contains(errStr, indicator) {
|
||||
return ErrorResource
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting (retryable with backoff)
|
||||
rateLimitIndicators := []string{
|
||||
"rate limit",
|
||||
"too many requests",
|
||||
"throttle",
|
||||
"quota exceeded",
|
||||
"429",
|
||||
}
|
||||
for _, indicator := range rateLimitIndicators {
|
||||
if strings.Contains(errStr, indicator) {
|
||||
return ErrorRateLimit
|
||||
}
|
||||
}
|
||||
|
||||
// Timeout errors (retryable)
|
||||
timeoutIndicators := []string{
|
||||
"timeout",
|
||||
"deadline exceeded",
|
||||
"context deadline",
|
||||
}
|
||||
for _, indicator := range timeoutIndicators {
|
||||
if strings.Contains(errStr, indicator) {
|
||||
return ErrorTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// Authentication errors (not retryable)
|
||||
authIndicators := []string{
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"authentication failed",
|
||||
"invalid credentials",
|
||||
"access denied",
|
||||
"401",
|
||||
"403",
|
||||
}
|
||||
for _, indicator := range authIndicators {
|
||||
if strings.Contains(errStr, indicator) {
|
||||
return ErrorAuth
|
||||
}
|
||||
}
|
||||
|
||||
// Validation errors (not retryable)
|
||||
validationIndicators := []string{
|
||||
"invalid input",
|
||||
"validation failed",
|
||||
"bad request",
|
||||
"malformed",
|
||||
"400",
|
||||
}
|
||||
for _, indicator := range validationIndicators {
|
||||
if strings.Contains(errStr, indicator) {
|
||||
return ErrorValidation
|
||||
}
|
||||
}
|
||||
|
||||
// Default to unknown
|
||||
return ErrorUnknown
|
||||
}
|
||||
|
||||
// IsRetryable determines if an error category should be retried
|
||||
func IsRetryable(category ErrorCategory) bool {
|
||||
switch category {
|
||||
case ErrorNetwork, ErrorResource, ErrorRateLimit, ErrorTimeout, ErrorUnknown:
|
||||
return true
|
||||
case ErrorAuth, ErrorValidation, ErrorPermanent:
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserMessage returns a user-friendly error message with suggestions
|
||||
func GetUserMessage(category ErrorCategory, err error) string {
|
||||
messages := map[ErrorCategory]string{
|
||||
ErrorNetwork: "Network connectivity issue. Please check your network connection and try again.",
|
||||
ErrorResource: "System resource exhausted. The system may be under heavy load. Try again later or contact support.",
|
||||
ErrorRateLimit: "Rate limit exceeded. Please wait a moment before retrying.",
|
||||
ErrorAuth: "Authentication failed. Please check your API key or credentials.",
|
||||
ErrorValidation: "Invalid input. Please review your request and correct any errors.",
|
||||
ErrorTimeout: "Operation timed out. The task may be too complex or the system is slow. Try again or simplify the request.",
|
||||
ErrorPermanent: "A permanent error occurred. This task cannot be retried automatically.",
|
||||
ErrorUnknown: "An unexpected error occurred. If this persists, please contact support.",
|
||||
}
|
||||
|
||||
baseMsg := messages[category]
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%s (Details: %v)", baseMsg, err)
|
||||
}
|
||||
return baseMsg
|
||||
}
|
||||
|
||||
// RetryDelay calculates the retry delay based on error category and retry count
|
||||
func RetryDelay(category ErrorCategory, retryCount int) int {
|
||||
switch category {
|
||||
case ErrorRateLimit:
|
||||
// Longer backoff for rate limits
|
||||
return min(300, 10*(1<<retryCount)) // 10s, 20s, 40s, 80s, up to 300s
|
||||
case ErrorResource:
|
||||
// Medium backoff for resource issues
|
||||
return min(120, 5*(1<<retryCount)) // 5s, 10s, 20s, 40s, up to 120s
|
||||
case ErrorNetwork, ErrorTimeout:
|
||||
// Standard exponential backoff
|
||||
return 1 << retryCount // 1s, 2s, 4s, 8s, etc
|
||||
default:
|
||||
// Default exponential backoff
|
||||
return 1 << retryCount
|
||||
}
|
||||
}
|
||||
|
||||
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
118
internal/queue/metrics.go
Normal file
118
internal/queue/metrics.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
var (
	// Queue metrics.
	QueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "fetch_ml_queue_depth",
		Help: "Number of tasks in the queue",
	})

	TasksQueued = promauto.NewCounter(prometheus.CounterOpts{
		Name: "fetch_ml_tasks_queued_total",
		Help: "Total number of tasks queued",
	})

	// Task execution metrics.
	TaskDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "fetch_ml_task_duration_seconds",
		Help:    "Task execution duration in seconds",
		Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600}, // 1s to 1h
	}, []string{"job_name", "status"})

	TasksCompleted = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "fetch_ml_tasks_completed_total",
		Help: "Total number of completed tasks",
	}, []string{"job_name", "status"})

	// Error metrics.
	TaskFailures = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "fetch_ml_task_failures_total",
		Help: "Total number of failed tasks by error category",
	}, []string{"job_name", "error_category"})

	TaskRetries = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "fetch_ml_task_retries_total",
		Help: "Total number of task retries",
	}, []string{"job_name", "error_category"})

	// Lease metrics.
	LeaseExpirations = promauto.NewCounter(prometheus.CounterOpts{
		Name: "fetch_ml_lease_expirations_total",
		Help: "Total number of expired leases reclaimed",
	})

	LeaseRenewals = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "fetch_ml_lease_renewals_total",
		Help: "Total number of successful lease renewals",
	}, []string{"worker_id"})

	// Dead letter queue metrics.
	// NOTE(review): DLQSize is only ever incremented (RecordDLQAddition);
	// nothing visible decrements it, so the gauge can drift upward.
	DLQSize = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "fetch_ml_dlq_size",
		Help: "Number of tasks in dead letter queue",
	})

	DLQAdditions = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "fetch_ml_dlq_additions_total",
		Help: "Total number of tasks moved to DLQ",
	}, []string{"reason"})

	// Worker metrics.
	ActiveTasks = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Name: "fetch_ml_active_tasks",
		Help: "Number of currently executing tasks",
	}, []string{"worker_id"})

	// NOTE(review): no code in this package increments WorkerHeartbeats
	// (TaskQueue.Heartbeat writes to Redis only) — confirm a caller exists.
	WorkerHeartbeats = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "fetch_ml_worker_heartbeats_total",
		Help: "Total number of worker heartbeats",
	}, []string{"worker_id"})
)
|
||||
|
||||
// RecordTaskStart records when a task starts by incrementing the worker's
// active-task gauge. jobName is currently unused; it is kept for API
// symmetry with RecordTaskEnd.
func RecordTaskStart(jobName, workerID string) {
	ActiveTasks.WithLabelValues(workerID).Inc()
}
|
||||
|
||||
// RecordTaskEnd records when a task completes: it decrements the worker's
// active-task gauge, observes the duration histogram, and bumps the
// completion counter, labeling the latter two by job name and final status.
func RecordTaskEnd(jobName, workerID, status string, durationSeconds float64) {
	ActiveTasks.WithLabelValues(workerID).Dec()
	TaskDuration.WithLabelValues(jobName, status).Observe(durationSeconds)
	TasksCompleted.WithLabelValues(jobName, status).Inc()
}
|
||||
|
||||
// RecordTaskFailure records a task failure, labeled by job name and the
// classified error category.
func RecordTaskFailure(jobName string, errorCategory ErrorCategory) {
	TaskFailures.WithLabelValues(jobName, string(errorCategory)).Inc()
}
|
||||
|
||||
// RecordTaskRetry records a task retry, labeled by job name and the error
// category that triggered it.
func RecordTaskRetry(jobName string, errorCategory ErrorCategory) {
	TaskRetries.WithLabelValues(jobName, string(errorCategory)).Inc()
}
|
||||
|
||||
// RecordLeaseExpiration records that one expired lease was reclaimed.
func RecordLeaseExpiration() {
	LeaseExpirations.Inc()
}
|
||||
|
||||
// RecordLeaseRenewal records a successful lease renewal for the given worker.
func RecordLeaseRenewal(workerID string) {
	LeaseRenewals.WithLabelValues(workerID).Inc()
}
|
||||
|
||||
// RecordDLQAddition records a task being moved to the dead letter queue,
// labeled by reason, and grows the DLQ size gauge.
// NOTE(review): the gauge is never decremented when DLQ entries expire
// (30-day TTL) or are reprocessed, so it can drift upward over time.
func RecordDLQAddition(reason string) {
	DLQAdditions.WithLabelValues(reason).Inc()
	DLQSize.Inc()
}
|
||||
|
||||
// UpdateQueueDepth sets the queue-depth gauge to the current number of
// pending tasks.
func UpdateQueueDepth(depth int64) {
	QueueDepth.Set(float64(depth))
}
|
||||
558
internal/queue/queue.go
Normal file
558
internal/queue/queue.go
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
	defaultMetricsFlushInterval = 500 * time.Millisecond // how often buffered metrics are written to Redis
	defaultLeaseDuration        = 30 * time.Minute       // lease length when callers pass 0
	defaultMaxRetries           = 3                      // applied by AddTask when a task's MaxRetries is unset
)
|
||||
|
||||
// TaskQueue manages ML experiment tasks via Redis.
// It runs two background goroutines (metrics buffering and lease
// reclamation) that live until Close is called.
type TaskQueue struct {
	client      *redis.Client
	ctx         context.Context    // lifetime of background goroutines; NOTE(review): storing a ctx in a struct is discouraged — consider per-call contexts
	cancel      context.CancelFunc // stops the background goroutines
	metricsCh   chan metricEvent   // buffered feed into runMetricsBuffer
	metricsDone chan struct{}      // closed when runMetricsBuffer has flushed and exited
	flushEvery  time.Duration      // metrics flush interval
}
|
||||
|
||||
// metricEvent is a single buffered metric sample destined for a job's
// Redis metrics hash.
type metricEvent struct {
	JobName string  // job the sample belongs to
	Metric  string  // metric name (hash field)
	Value   float64 // sample value; within a flush window the last write wins
}
|
||||
|
||||
// Config holds configuration for TaskQueue.
type Config struct {
	RedisAddr            string        // host:port pair, or a redis:// URL (parsed via redis.ParseURL)
	RedisPassword        string        // ignored when RedisAddr is a URL
	RedisDB              int           // ignored when RedisAddr is a URL
	MetricsFlushInterval time.Duration // 0 means defaultMetricsFlushInterval
}
|
||||
|
||||
func NewTaskQueue(cfg Config) (*TaskQueue, error) {
|
||||
var opts *redis.Options
|
||||
var err error
|
||||
|
||||
if len(cfg.RedisAddr) > 8 && cfg.RedisAddr[:8] == "redis://" {
|
||||
opts, err = redis.ParseURL(cfg.RedisAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid redis url: %w", err)
|
||||
}
|
||||
} else {
|
||||
opts = &redis.Options{
|
||||
Addr: cfg.RedisAddr,
|
||||
Password: cfg.RedisPassword,
|
||||
DB: cfg.RedisDB,
|
||||
}
|
||||
}
|
||||
|
||||
rdb := redis.NewClient(opts)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
flushEvery := cfg.MetricsFlushInterval
|
||||
if flushEvery == 0 {
|
||||
flushEvery = defaultMetricsFlushInterval
|
||||
}
|
||||
|
||||
tq := &TaskQueue{
|
||||
client: rdb,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
metricsCh: make(chan metricEvent, 256),
|
||||
metricsDone: make(chan struct{}),
|
||||
flushEvery: flushEvery,
|
||||
}
|
||||
|
||||
go tq.runMetricsBuffer()
|
||||
go tq.runLeaseReclamation() // Start lease reclamation background job
|
||||
|
||||
return tq, nil
|
||||
}
|
||||
|
||||
// AddTask adds a new task to the queue with default retry settings
|
||||
func (tq *TaskQueue) AddTask(task *Task) error {
|
||||
// Set default retry settings if not specified
|
||||
if task.MaxRetries == 0 {
|
||||
task.MaxRetries = defaultMaxRetries
|
||||
}
|
||||
|
||||
taskData, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal task: %w", err)
|
||||
}
|
||||
|
||||
pipe := tq.client.Pipeline()
|
||||
|
||||
// Store task data
|
||||
pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
|
||||
|
||||
// Add to priority queue (ZSET)
|
||||
// Use priority as score (higher priority = higher score)
|
||||
pipe.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
|
||||
Score: float64(task.Priority),
|
||||
Member: task.ID,
|
||||
})
|
||||
|
||||
// Initialize status
|
||||
pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
|
||||
"status", task.Status,
|
||||
"task_id", task.ID,
|
||||
"updated_at", time.Now().Format(time.RFC3339))
|
||||
|
||||
_, err = pipe.Exec(tq.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to enqueue task: %w", err)
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
TasksQueued.Inc()
|
||||
|
||||
// Update queue depth
|
||||
depth, _ := tq.QueueDepth()
|
||||
UpdateQueueDepth(depth)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextTask gets the next task without lease (backward compatible)
|
||||
func (tq *TaskQueue) GetNextTask() (*Task, error) {
|
||||
result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
taskID := result[0].Member.(string)
|
||||
return tq.GetTask(taskID)
|
||||
}
|
||||
|
||||
// GetNextTaskWithLease gets the next task and acquires a lease
|
||||
func (tq *TaskQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) {
|
||||
if leaseDuration == 0 {
|
||||
leaseDuration = defaultLeaseDuration
|
||||
}
|
||||
|
||||
// Pop highest priority task
|
||||
result, err := tq.client.ZPopMax(tq.ctx, TaskQueueKey, 1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
taskID := result[0].Member.(string)
|
||||
task, err := tq.GetTask(taskID)
|
||||
if err != nil {
|
||||
// Re-queue the task if we can't fetch it
|
||||
tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
|
||||
Score: result[0].Score,
|
||||
Member: taskID,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Acquire lease
|
||||
now := time.Now()
|
||||
leaseExpiry := now.Add(leaseDuration)
|
||||
task.LeaseExpiry = &leaseExpiry
|
||||
task.LeasedBy = workerID
|
||||
|
||||
// Update task with lease
|
||||
if err := tq.UpdateTask(task); err != nil {
|
||||
// Re-queue if update fails
|
||||
tq.client.ZAdd(tq.ctx, TaskQueueKey, redis.Z{
|
||||
Score: result[0].Score,
|
||||
Member: taskID,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// RenewLease renews the lease on a task (heartbeat)
|
||||
func (tq *TaskQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error {
|
||||
if leaseDuration == 0 {
|
||||
leaseDuration = defaultLeaseDuration
|
||||
}
|
||||
|
||||
task, err := tq.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the worker owns the lease
|
||||
if task.LeasedBy != workerID {
|
||||
return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
|
||||
}
|
||||
|
||||
// Renew lease
|
||||
leaseExpiry := time.Now().Add(leaseDuration)
|
||||
task.LeaseExpiry = &leaseExpiry
|
||||
|
||||
// Record renewal metric
|
||||
RecordLeaseRenewal(workerID)
|
||||
|
||||
return tq.UpdateTask(task)
|
||||
}
|
||||
|
||||
// ReleaseLease releases the lease on a task
|
||||
func (tq *TaskQueue) ReleaseLease(taskID string, workerID string) error {
|
||||
task, err := tq.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the worker owns the lease
|
||||
if task.LeasedBy != workerID {
|
||||
return fmt.Errorf("task leased by different worker: %s", task.LeasedBy)
|
||||
}
|
||||
|
||||
// Clear lease
|
||||
task.LeaseExpiry = nil
|
||||
task.LeasedBy = ""
|
||||
|
||||
return tq.UpdateTask(task)
|
||||
}
|
||||
|
||||
// RetryTask re-queues a failed task with smart backoff based on error
// category. Tasks that have exhausted MaxRetries, or whose error is
// classified as non-retryable, are moved to the dead letter queue instead.
func (tq *TaskQueue) RetryTask(task *Task) error {
	if task.RetryCount >= task.MaxRetries {
		// Move to dead letter queue
		RecordDLQAddition("max_retries")
		return tq.MoveToDeadLetterQueue(task, "max retries exceeded")
	}

	// Classify the error if it exists
	var errorCategory ErrorCategory = ErrorUnknown
	if task.Error != "" {
		errorCategory = ClassifyError(fmt.Errorf("%s", task.Error))
	}

	// Check if error is retryable
	if !IsRetryable(errorCategory) {
		RecordDLQAddition(string(errorCategory))
		return tq.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory))
	}

	task.RetryCount++
	task.Status = "queued"
	task.LastError = task.Error // Preserve last error
	task.Error = ""             // Clear current error

	// Calculate smart backoff based on error category
	backoffSeconds := RetryDelay(errorCategory, task.RetryCount)
	nextRetry := time.Now().Add(time.Duration(backoffSeconds) * time.Second)
	task.NextRetry = &nextRetry

	// Clear lease
	task.LeaseExpiry = nil
	task.LeasedBy = ""

	// Record retry metrics
	RecordTaskRetry(task.JobName, errorCategory)

	// Re-queue with same priority.
	// NOTE(review): AddTask puts the task straight back into the ZSET, so
	// NextRetry is advisory only — nothing here delays the re-queue until
	// that time. Confirm workers check NextRetry before executing.
	return tq.AddTask(task)
}
|
||||
|
||||
// MoveToDeadLetterQueue marks the task failed, stores it under a DLQ key
// with a 30-day TTL, and removes it from the pending queue.
// NOTE(review): the "task:dlq:" prefix is a literal while other keys use
// shared constants — consider introducing a DLQPrefix constant.
// NOTE(review): RecordDLQAddition is invoked by some callers (RetryTask)
// but not others (reclaimExpiredLeases), so DLQ metrics may undercount.
func (tq *TaskQueue) MoveToDeadLetterQueue(task *Task, reason string) error {
	task.Status = "failed"
	task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError)

	taskData, err := json.Marshal(task)
	if err != nil {
		return err
	}

	// Store in dead letter queue with timestamp
	key := "task:dlq:" + task.ID

	// Record metrics
	RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError)))

	pipe := tq.client.Pipeline()
	pipe.Set(tq.ctx, key, taskData, 30*24*time.Hour)
	pipe.ZRem(tq.ctx, TaskQueueKey, task.ID)
	_, err = pipe.Exec(tq.ctx)
	return err
}
|
||||
|
||||
// runLeaseReclamation reclaims expired leases every 1 minute
|
||||
func (tq *TaskQueue) runLeaseReclamation() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tq.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := tq.reclaimExpiredLeases(); err != nil {
|
||||
// Log error but continue
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reclaimExpiredLeases scans all task keys and, for running tasks whose
// lease has expired, either retries them or moves them to the DLQ when
// retries are exhausted.
// NOTE(review): the TaskPrefix+"*" pattern presumably also matches other
// keys sharing the prefix (e.g. "task:dlq:*" or status hashes). Those
// either fail GetTask or are filtered by the status == "running" check
// below, so they are skipped — confirm against the key constants.
func (tq *TaskQueue) reclaimExpiredLeases() error {
	// Scan for all task keys
	iter := tq.client.Scan(tq.ctx, 0, TaskPrefix+"*", 100).Iterator()
	now := time.Now()

	for iter.Next(tq.ctx) {
		taskKey := iter.Val()
		taskID := taskKey[len(TaskPrefix):]

		task, err := tq.GetTask(taskID)
		if err != nil {
			continue
		}

		// Check if lease expired and task is still running
		if task.LeaseExpiry != nil && task.LeaseExpiry.Before(now) && task.Status == "running" {
			// Lease expired - retry or fail the task
			task.Error = fmt.Sprintf("worker %s lease expired", task.LeasedBy)

			// Record lease expiration
			RecordLeaseExpiration()

			if task.RetryCount < task.MaxRetries {
				// Retry the task
				if err := tq.RetryTask(task); err != nil {
					continue
				}
			} else {
				// Max retries exceeded - move to DLQ
				if err := tq.MoveToDeadLetterQueue(task, "lease expiry after max retries"); err != nil {
					continue
				}
			}
		}
	}

	return iter.Err()
}
|
||||
|
||||
// GetTask retrieves a task by ID
|
||||
func (tq *TaskQueue) GetTask(taskID string) (*Task, error) {
|
||||
data, err := tq.client.Get(tq.ctx, TaskPrefix+taskID).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var task Task
|
||||
if err := json.Unmarshal([]byte(data), &task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// GetAllTasks retrieves all tasks from the queue
|
||||
func (tq *TaskQueue) GetAllTasks() ([]*Task, error) {
|
||||
// Get all task keys
|
||||
keys, err := tq.client.Keys(tq.ctx, TaskPrefix+"*").Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tasks []*Task
|
||||
for _, key := range keys {
|
||||
data, err := tq.client.Get(tq.ctx, key).Result()
|
||||
if err != nil {
|
||||
continue // Skip tasks that can't be retrieved
|
||||
}
|
||||
|
||||
var task Task
|
||||
if err := json.Unmarshal([]byte(data), &task); err != nil {
|
||||
continue // Skip malformed tasks
|
||||
}
|
||||
|
||||
tasks = append(tasks, &task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// GetTaskByName retrieves a task by its job name
|
||||
func (tq *TaskQueue) GetTaskByName(jobName string) (*Task, error) {
|
||||
tasks, err := tq.GetAllTasks()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
if task.JobName == jobName {
|
||||
return task, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("task with job name '%s' not found", jobName)
|
||||
}
|
||||
|
||||
// CancelTask marks a task as cancelled
|
||||
func (tq *TaskQueue) CancelTask(taskID string) error {
|
||||
task, err := tq.GetTask(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update task status to cancelled
|
||||
task.Status = "cancelled"
|
||||
now := time.Now()
|
||||
task.EndedAt = &now
|
||||
|
||||
return tq.UpdateTask(task)
|
||||
}
|
||||
|
||||
// UpdateTask updates a task in Redis
|
||||
func (tq *TaskQueue) UpdateTask(task *Task) error {
|
||||
taskData, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := tq.client.Pipeline()
|
||||
pipe.Set(tq.ctx, TaskPrefix+task.ID, taskData, 7*24*time.Hour)
|
||||
pipe.HSet(tq.ctx, TaskStatusPrefix+task.JobName,
|
||||
"status", task.Status,
|
||||
"task_id", task.ID,
|
||||
"updated_at", time.Now().Format(time.RFC3339))
|
||||
|
||||
_, err = pipe.Exec(tq.ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTaskWithMetrics updates task and records metrics
|
||||
func (tq *TaskQueue) UpdateTaskWithMetrics(task *Task, action string) error {
|
||||
if err := tq.UpdateTask(task); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metricName := "tasks_" + action
|
||||
return tq.RecordMetric(task.JobName, metricName, 1)
|
||||
}
|
||||
|
||||
// RecordMetric records a metric value
|
||||
func (tq *TaskQueue) RecordMetric(jobName, metric string, value float64) error {
|
||||
evt := metricEvent{JobName: jobName, Metric: metric, Value: value}
|
||||
select {
|
||||
case tq.metricsCh <- evt:
|
||||
return nil
|
||||
default:
|
||||
return tq.writeMetrics(jobName, map[string]float64{metric: value})
|
||||
}
|
||||
}
|
||||
|
||||
// Heartbeat records the worker's liveness by writing the current unix
// timestamp into the shared heartbeat hash, keyed by worker ID.
func (tq *TaskQueue) Heartbeat(workerID string) error {
	return tq.client.HSet(tq.ctx, WorkerHeartbeat,
		workerID, time.Now().Unix()).Err()
}
|
||||
|
||||
// QueueDepth returns the number of pending tasks (the cardinality of the
// priority ZSET).
func (tq *TaskQueue) QueueDepth() (int64, error) {
	return tq.client.ZCard(tq.ctx, TaskQueueKey).Result()
}
|
||||
|
||||
// Close cancels the queue's context (stopping the lease-reclamation
// goroutine), waits for the metrics buffer to perform its final flush and
// exit, then closes the Redis client.
func (tq *TaskQueue) Close() error {
	tq.cancel()
	<-tq.metricsDone // Wait for metrics buffer to finish
	return tq.client.Close()
}
|
||||
|
||||
// GetRedisClient returns the underlying Redis client for direct access.
// Callers share the queue's connection; they must not Close it themselves.
func (tq *TaskQueue) GetRedisClient() *redis.Client {
	return tq.client
}
|
||||
|
||||
// WaitForNextTask waits for next task with timeout
|
||||
func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) (*Task, error) {
|
||||
if ctx == nil {
|
||||
ctx = tq.ctx
|
||||
}
|
||||
result, err := tq.client.BZPopMax(ctx, timeout, TaskQueueKey).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
member, ok := result.Member.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected task id type %T", result.Member)
|
||||
}
|
||||
return tq.GetTask(member)
|
||||
}
|
||||
|
||||
// runMetricsBuffer aggregates metric events from metricsCh and writes them
// to Redis every flushEvery interval, performing a final flush when the
// context is cancelled or the channel closes. metricsDone is closed on exit
// so Close can wait for the final flush.
// NOTE(review): aggregation is last-write-wins per (job, metric) within a
// flush window — fine for gauge-style values, lossy for counter-style ones.
func (tq *TaskQueue) runMetricsBuffer() {
	defer close(tq.metricsDone)
	ticker := time.NewTicker(tq.flushEvery)
	defer ticker.Stop()
	pending := make(map[string]map[string]float64)
	// flush writes each job's accumulated metrics; jobs whose write fails
	// stay in pending and are retried on the next flush.
	flush := func() {
		for job, metrics := range pending {
			if err := tq.writeMetrics(job, metrics); err != nil {
				continue
			}
			delete(pending, job)
		}
	}

	for {
		select {
		case <-tq.ctx.Done():
			flush()
			return
		case evt, ok := <-tq.metricsCh:
			if !ok {
				flush()
				return
			}
			if _, exists := pending[evt.JobName]; !exists {
				pending[evt.JobName] = make(map[string]float64)
			}
			pending[evt.JobName][evt.Metric] = evt.Value
		case <-ticker.C:
			flush()
		}
	}
}
|
||||
|
||||
// writeMetrics writes metrics to Redis
|
||||
func (tq *TaskQueue) writeMetrics(jobName string, metrics map[string]float64) error {
|
||||
if len(metrics) == 0 {
|
||||
return nil
|
||||
}
|
||||
key := JobMetricsPrefix + jobName
|
||||
args := make([]any, 0, len(metrics)*2+2)
|
||||
args = append(args, "timestamp", time.Now().Unix())
|
||||
for metric, value := range metrics {
|
||||
args = append(args, metric, value)
|
||||
}
|
||||
return tq.client.HSet(context.Background(), key, args...).Err()
|
||||
}
|
||||
152
internal/queue/queue_permissions_test.go
Normal file
152
internal/queue/queue_permissions_test.go
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTask_UserFields(t *testing.T) {
|
||||
task := &Task{
|
||||
UserID: "testuser",
|
||||
Username: "testuser",
|
||||
CreatedBy: "testuser",
|
||||
}
|
||||
|
||||
if task.UserID != "testuser" {
|
||||
t.Errorf("Expected UserID to be 'testuser', got '%s'", task.UserID)
|
||||
}
|
||||
|
||||
if task.Username != "testuser" {
|
||||
t.Errorf("Expected Username to be 'testuser', got '%s'", task.Username)
|
||||
}
|
||||
|
||||
if task.CreatedBy != "testuser" {
|
||||
t.Errorf("Expected CreatedBy to be 'testuser', got '%s'", task.CreatedBy)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTaskQueue_UserFiltering is an integration test against a real
// Redis instance (skipped when Redis is unreachable). It seeds tasks
// owned by three users, then exercises GetAllTasks, per-user filtering,
// GetTaskByName, and CancelTask.
//
// NOTE(review): this flushes Redis DB 15 — anything else using that DB
// on localhost:6379 will be wiped when the test runs.
func TestTaskQueue_UserFiltering(t *testing.T) {
	// Setup test Redis configuration
	queueCfg := Config{
		RedisAddr: "localhost:6379",
		RedisDB:   15, // Use dedicated test DB
	}

	// Create task queue
	taskQueue, err := NewTaskQueue(queueCfg)
	if err != nil {
		t.Skip("Redis not available for integration testing")
		return
	}
	defer taskQueue.Close()

	// Clear test database
	taskQueue.client.FlushDB(taskQueue.ctx)

	// Create test tasks with different users
	tasks := []*Task{
		{
			ID:        "task1",
			JobName:   "user1_job1",
			Status:    "queued",
			UserID:    "user1",
			CreatedBy: "user1",
			CreatedAt: time.Now(),
		},
		{
			ID:        "task2",
			JobName:   "user1_job2",
			Status:    "running",
			UserID:    "user1",
			CreatedBy: "user1",
			CreatedAt: time.Now(),
		},
		{
			ID:        "task3",
			JobName:   "user2_job1",
			Status:    "queued",
			UserID:    "user2",
			CreatedBy: "user2",
			CreatedAt: time.Now(),
		},
		{
			ID:        "task4",
			JobName:   "admin_job",
			Status:    "completed",
			UserID:    "admin",
			CreatedBy: "admin",
			CreatedAt: time.Now(),
		},
	}

	// Add tasks to queue
	for _, task := range tasks {
		err := taskQueue.AddTask(task)
		if err != nil {
			t.Fatalf("Failed to add task %s: %v", task.ID, err)
		}
	}

	// Test GetAllTasks
	allTasks, err := taskQueue.GetAllTasks()
	if err != nil {
		t.Fatalf("Failed to get all tasks: %v", err)
	}

	if len(allTasks) != len(tasks) {
		t.Errorf("Expected %d tasks, got %d", len(tasks), len(allTasks))
	}

	// Test user filtering logic: a task is visible to a user who either
	// owns it (UserID) or submitted it (CreatedBy).
	filterTasksForUser := func(tasks []*Task, userID string) []*Task {
		var filtered []*Task
		for _, task := range tasks {
			if task.UserID == userID || task.CreatedBy == userID {
				filtered = append(filtered, task)
			}
		}
		return filtered
	}

	// Test filtering for user1 (should get 2 tasks)
	user1Tasks := filterTasksForUser(allTasks, "user1")
	if len(user1Tasks) != 2 {
		t.Errorf("Expected 2 tasks for user1, got %d", len(user1Tasks))
	}

	// Test filtering for user2 (should get 1 task)
	user2Tasks := filterTasksForUser(allTasks, "user2")
	if len(user2Tasks) != 1 {
		t.Errorf("Expected 1 task for user2, got %d", len(user2Tasks))
	}

	// Test filtering for admin (should get 1 task)
	adminTasks := filterTasksForUser(allTasks, "admin")
	if len(adminTasks) != 1 {
		t.Errorf("Expected 1 task for admin, got %d", len(adminTasks))
	}

	// Test GetTaskByName
	task, err := taskQueue.GetTaskByName("user1_job1")
	if err != nil {
		t.Errorf("Failed to get task by name: %v", err)
	}
	if task == nil || task.UserID != "user1" {
		t.Error("Got wrong task or nil task")
	}

	// Test CancelTask
	err = taskQueue.CancelTask("task1")
	if err != nil {
		t.Errorf("Failed to cancel task: %v", err)
	}

	// Verify task was cancelled
	cancelledTask, err := taskQueue.GetTask("task1")
	if err != nil {
		t.Errorf("Failed to get cancelled task: %v", err)
	}
	if cancelledTask.Status != "cancelled" {
		t.Errorf("Expected status 'cancelled', got '%s'", cancelledTask.Status)
	}
}
|
||||
193
internal/queue/queue_test.go
Normal file
193
internal/queue/queue_test.go
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestTaskQueue exercises the queue end-to-end against an in-process
// miniredis server: enqueue, priority-ordered pop, leasing, lease
// renewal and release, retry bookkeeping, and dead-lettering.
//
// NOTE(review): the subtests share state (e.g. "task-lease" created in
// GetNextTaskWithLease is reused by RenewLease/ReleaseLease), so they
// must run in declaration order and cannot be run individually.
func TestTaskQueue(t *testing.T) {
	// Start miniredis
	s, err := miniredis.Run()
	if err != nil {
		t.Fatalf("failed to start miniredis: %v", err)
	}
	defer s.Close()

	// Create TaskQueue
	cfg := Config{
		RedisAddr:            s.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond, // Fast flush for testing
	}
	tq, err := NewTaskQueue(cfg)
	assert.NoError(t, err)
	defer tq.Close()

	t.Run("AddTask", func(t *testing.T) {
		task := &Task{
			ID:        "task-1",
			JobName:   "job-1",
			Status:    "queued",
			Priority:  10,
			CreatedAt: time.Now(),
		}
		err = tq.AddTask(task)
		assert.NoError(t, err)

		// Verify task is in Redis
		// Check ZSET: the task ID should be scored by its priority.
		score, err := s.ZScore(TaskQueueKey, "task-1")
		assert.NoError(t, err)
		assert.Equal(t, float64(10), score)
	})

	t.Run("GetNextTask", func(t *testing.T) {
		// Add another task
		task := &Task{
			ID:        "task-2",
			JobName:   "job-2",
			Status:    "queued",
			Priority:  20, // Higher priority
			CreatedAt: time.Now(),
		}
		err = tq.AddTask(task)
		assert.NoError(t, err)

		// Should get task-2 first due to higher priority
		nextTask, err := tq.GetNextTask()
		assert.NoError(t, err)
		assert.NotNil(t, nextTask)
		assert.Equal(t, "task-2", nextTask.ID)

		// Verify task is removed from ZSET
		_, err = tq.client.ZScore(tq.ctx, TaskQueueKey, "task-2").Result()
		assert.Equal(t, redis.Nil, err)
	})

	t.Run("GetNextTaskWithLease", func(t *testing.T) {
		task := &Task{
			ID:        "task-lease",
			JobName:   "job-lease",
			Status:    "queued",
			Priority:  15,
			CreatedAt: time.Now(),
		}
		err := tq.AddTask(task)
		require.NoError(t, err)

		workerID := "worker-1"
		leaseDuration := 1 * time.Minute

		// Popping with a lease should record the worker and a future expiry.
		leasedTask, err := tq.GetNextTaskWithLease(workerID, leaseDuration)
		require.NoError(t, err)
		require.NotNil(t, leasedTask)
		assert.Equal(t, "task-lease", leasedTask.ID)
		assert.Equal(t, workerID, leasedTask.LeasedBy)
		assert.NotNil(t, leasedTask.LeaseExpiry)
		assert.True(t, leasedTask.LeaseExpiry.After(time.Now()))
	})

	t.Run("RenewLease", func(t *testing.T) {
		taskID := "task-lease"
		workerID := "worker-1"

		// Get initial expiry
		task, err := tq.GetTask(taskID)
		require.NoError(t, err)
		initialExpiry := task.LeaseExpiry

		// Wait a bit so the renewed expiry is strictly later.
		time.Sleep(10 * time.Millisecond)

		// Renew lease
		err = tq.RenewLease(taskID, workerID, 1*time.Minute)
		require.NoError(t, err)

		// Verify expiry updated
		task, err = tq.GetTask(taskID)
		require.NoError(t, err)
		assert.True(t, task.LeaseExpiry.After(*initialExpiry))
	})

	t.Run("ReleaseLease", func(t *testing.T) {
		taskID := "task-lease"
		workerID := "worker-1"

		err := tq.ReleaseLease(taskID, workerID)
		require.NoError(t, err)

		// Release must clear both the expiry and the holder.
		task, err := tq.GetTask(taskID)
		require.NoError(t, err)
		assert.Nil(t, task.LeaseExpiry)
		assert.Empty(t, task.LeasedBy)
	})

	t.Run("RetryTask", func(t *testing.T) {
		task := &Task{
			ID:         "task-retry",
			JobName:    "job-retry",
			Status:     "failed",
			Priority:   10,
			CreatedAt:  time.Now(),
			MaxRetries: 3,
			RetryCount: 0,
			Error:      "some transient error",
		}

		// Add task directly to verify retry logic
		err := tq.AddTask(task)
		require.NoError(t, err)

		// Simulate failure and retry
		task.Error = "connection timeout"
		err = tq.RetryTask(task)
		require.NoError(t, err)

		// Verify task updated: retry count bumped, re-queued, error moved
		// to LastError, and a backoff time scheduled.
		updatedTask, err := tq.GetTask(task.ID)
		require.NoError(t, err)
		assert.Equal(t, 1, updatedTask.RetryCount)
		assert.Equal(t, "queued", updatedTask.Status)
		assert.Empty(t, updatedTask.Error)
		assert.Equal(t, "connection timeout", updatedTask.LastError)
		assert.NotNil(t, updatedTask.NextRetry)
	})

	t.Run("DLQ", func(t *testing.T) {
		task := &Task{
			ID:         "task-dlq",
			JobName:    "job-dlq",
			Status:     "failed",
			Priority:   10,
			CreatedAt:  time.Now(),
			MaxRetries: 1,
			RetryCount: 1, // Already at max retries
			Error:      "fatal error",
		}

		err := tq.AddTask(task)
		require.NoError(t, err)

		// Retry should move to DLQ
		err = tq.RetryTask(task)
		require.NoError(t, err)

		// Verify removed from main queue
		_, err = tq.client.ZScore(tq.ctx, TaskQueueKey, task.ID).Result()
		assert.Equal(t, redis.Nil, err)

		// Verify in DLQ
		dlqKey := "task:dlq:" + task.ID
		exists := s.Exists(dlqKey)
		assert.True(t, exists)

		// Verify DLQ content
		val, err := s.Get(dlqKey)
		require.NoError(t, err)
		assert.Contains(t, val, "max retries exceeded")
	})
}
|
||||
47
internal/queue/task.go
Normal file
47
internal/queue/task.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
)
|
||||
|
||||
// Task represents an ML experiment task as stored in Redis.
// It carries the job description, lifecycle timestamps, ownership,
// lease state for worker resilience, and retry bookkeeping.
type Task struct {
	ID        string            `json:"id"`
	JobName   string            `json:"job_name"`
	Args      string            `json:"args"`
	Status    string            `json:"status"` // queued, running, completed, failed
	Priority  int64             `json:"priority"`
	CreatedAt time.Time         `json:"created_at"`
	StartedAt *time.Time        `json:"started_at,omitempty"`
	EndedAt   *time.Time        `json:"ended_at,omitempty"`
	WorkerID  string            `json:"worker_id,omitempty"`
	Error     string            `json:"error,omitempty"`
	Datasets  []string          `json:"datasets,omitempty"`
	Metadata  map[string]string `json:"metadata,omitempty"`

	// User ownership and permissions
	UserID    string `json:"user_id"`    // User who owns this task
	Username  string `json:"username"`   // Username for display
	CreatedBy string `json:"created_by"` // User who submitted the task

	// Lease management for task resilience
	LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // When task lease expires
	LeasedBy    string     `json:"leased_by,omitempty"`    // Worker ID holding lease

	// Retry management
	RetryCount int        `json:"retry_count"`          // Number of retry attempts made
	MaxRetries int        `json:"max_retries"`          // Maximum retry limit (default 3)
	LastError  string     `json:"last_error,omitempty"` // Last error encountered
	NextRetry  *time.Time `json:"next_retry,omitempty"` // When to retry next (exponential backoff)
}
|
||||
|
||||
// Redis key constants.
// Resolved once from the central config package so the queue and any
// other packages reading these keys agree on naming.
var (
	TaskQueueKey     = config.RedisTaskQueueKey
	TaskPrefix       = config.RedisTaskPrefix
	TaskStatusPrefix = config.RedisTaskStatusPrefix
	WorkerHeartbeat  = config.RedisWorkerHeartbeat
	JobMetricsPrefix = config.RedisJobMetricsPrefix
)
|
||||
433
internal/storage/db.go
Normal file
433
internal/storage/db.go
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// DBConfig describes how to open the backing database.
// For SQLite only Type and Connection (the file path) are used; the
// remaining fields feed the PostgreSQL DSN when Connection is empty.
type DBConfig struct {
	Type       string // "sqlite", "postgres", or "postgresql" (case-insensitive)
	Connection string // full DSN / SQLite file path; overrides the fields below
	Host       string // postgres host; defaults to "localhost"
	Port       int    // postgres port; defaults to 5432 when <= 0
	Username   string // postgres user; omitted from the DSN when empty
	Password   string // postgres password; omitted from the DSN when empty
	Database   string // postgres database; defaults to "fetch_ml"
}
|
||||
|
||||
// DB wraps a sql.DB together with the lowercased backend type
// ("sqlite" / "postgres" / "postgresql"), which selects the placeholder
// style (? vs $n) used by the query helpers.
type DB struct {
	conn   *sql.DB
	dbType string
}
|
||||
|
||||
func NewDB(config DBConfig) (*DB, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
|
||||
switch strings.ToLower(config.Type) {
|
||||
case "sqlite":
|
||||
conn, err = sql.Open("sqlite3", config.Connection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||
}
|
||||
// Enable foreign keys
|
||||
if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||
}
|
||||
// Enable WAL mode for better concurrency
|
||||
if _, err := conn.Exec("PRAGMA journal_mode = WAL"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable WAL mode: %w", err)
|
||||
}
|
||||
case "postgres":
|
||||
connStr := buildPostgresConnectionString(config)
|
||||
conn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
case "postgresql":
|
||||
// Handle "postgresql" as alias for "postgres"
|
||||
connStr := buildPostgresConnectionString(config)
|
||||
conn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", config.Type)
|
||||
}
|
||||
|
||||
return &DB{conn: conn, dbType: strings.ToLower(config.Type)}, nil
|
||||
}
|
||||
|
||||
func buildPostgresConnectionString(config DBConfig) string {
|
||||
if config.Connection != "" {
|
||||
return config.Connection
|
||||
}
|
||||
|
||||
var connStr strings.Builder
|
||||
connStr.WriteString("host=")
|
||||
if config.Host != "" {
|
||||
connStr.WriteString(config.Host)
|
||||
} else {
|
||||
connStr.WriteString("localhost")
|
||||
}
|
||||
|
||||
if config.Port > 0 {
|
||||
connStr.WriteString(fmt.Sprintf(" port=%d", config.Port))
|
||||
} else {
|
||||
connStr.WriteString(" port=5432")
|
||||
}
|
||||
|
||||
if config.Username != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" user=%s", config.Username))
|
||||
}
|
||||
|
||||
if config.Password != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" password=%s", config.Password))
|
||||
}
|
||||
|
||||
if config.Database != "" {
|
||||
connStr.WriteString(fmt.Sprintf(" dbname=%s", config.Database))
|
||||
} else {
|
||||
connStr.WriteString(" dbname=fetch_ml")
|
||||
}
|
||||
|
||||
connStr.WriteString(" sslmode=disable")
|
||||
return connStr.String()
|
||||
}
|
||||
|
||||
// NewDBFromPath opens a SQLite database at dbPath.
// Legacy constructor kept for backward compatibility; prefer NewDB.
func NewDBFromPath(dbPath string) (*DB, error) {
	return NewDB(DBConfig{
		Type:       "sqlite",
		Connection: dbPath,
	})
}
|
||||
|
||||
// Job is the relational representation of a queued/completed job.
// Datasets and Metadata are persisted as JSON text columns; timestamps
// are managed by the schema and UpdateJobStatus.
type Job struct {
	ID        string            `json:"id"`
	JobName   string            `json:"job_name"`
	Args      string            `json:"args"`
	Status    string            `json:"status"`
	Priority  int64             `json:"priority"`
	CreatedAt time.Time         `json:"created_at"`
	StartedAt *time.Time        `json:"started_at,omitempty"`
	EndedAt   *time.Time        `json:"ended_at,omitempty"`
	WorkerID  string            `json:"worker_id,omitempty"`
	Error     string            `json:"error,omitempty"`
	Datasets  []string          `json:"datasets,omitempty"`
	Metadata  map[string]string `json:"metadata,omitempty"`
	UpdatedAt time.Time         `json:"updated_at"`
}
|
||||
|
||||
// Worker is the relational representation of a registered worker.
// Liveness is tracked via LastHeartbeat (see GetActiveWorkers).
type Worker struct {
	ID            string            `json:"id"`
	Hostname      string            `json:"hostname"`
	LastHeartbeat time.Time         `json:"last_heartbeat"`
	Status        string            `json:"status"`
	CurrentJobs   int               `json:"current_jobs"`
	MaxJobs       int               `json:"max_jobs"`
	Metadata      map[string]string `json:"metadata,omitempty"`
}
|
||||
|
||||
// Initialize executes the given schema SQL (DDL) against the database,
// typically the contents of schema.sql.
func (db *DB) Initialize(schema string) error {
	if _, err := db.conn.Exec(schema); err != nil {
		return fmt.Errorf("failed to initialize database: %w", err)
	}
	return nil
}
|
||||
|
||||
// Close releases the underlying connection pool.
func (db *DB) Close() error {
	return db.conn.Close()
}
|
||||
|
||||
// Job operations

// CreateJob inserts a new job row. Datasets and Metadata are stored as
// JSON text; created_at/updated_at are left to the schema's defaults.
// Fails (e.g. on a duplicate ID) with a wrapped driver error.
func (db *DB) CreateJob(job *Job) error {
	// Marshal errors are deliberately ignored: a []string and a
	// map[string]string cannot fail to encode.
	datasetsJSON, _ := json.Marshal(job.Datasets)
	metadataJSON, _ := json.Marshal(job.Metadata)

	var query string
	if db.dbType == "sqlite" {
		query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
		VALUES (?, ?, ?, ?, ?, ?, ?)`
	} else {
		query = `INSERT INTO jobs (id, job_name, args, status, priority, datasets, metadata)
		VALUES ($1, $2, $3, $4, $5, $6, $7)`
	}

	_, err := db.conn.Exec(query, job.ID, job.JobName, job.Args, job.Status,
		job.Priority, string(datasetsJSON), string(metadataJSON))
	if err != nil {
		return fmt.Errorf("failed to create job: %w", err)
	}
	return nil
}
|
||||
|
||||
// GetJob loads one job by ID, decoding the JSON datasets/metadata
// columns and unwrapping nullable worker_id/error columns.
// Returns a wrapped error when the row does not exist (sql.ErrNoRows
// from QueryRow.Scan).
func (db *DB) GetJob(id string) (*Job, error) {
	var query string
	if db.dbType == "sqlite" {
		query = `SELECT id, job_name, args, status, priority, created_at, started_at,
		ended_at, worker_id, error, datasets, metadata, updated_at
		FROM jobs WHERE id = ?`
	} else {
		query = `SELECT id, job_name, args, status, priority, created_at, started_at,
		ended_at, worker_id, error, datasets, metadata, updated_at
		FROM jobs WHERE id = $1`
	}

	var job Job
	var datasetsJSON, metadataJSON string
	var workerID sql.NullString
	var errorMsg sql.NullString

	err := db.conn.QueryRow(query, id).Scan(
		&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
		&job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID,
		&errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt)

	if err != nil {
		return nil, fmt.Errorf("failed to get job: %w", err)
	}

	if workerID.Valid {
		job.WorkerID = workerID.String
	}

	if errorMsg.Valid {
		job.Error = errorMsg.String
	}

	// Best-effort decode: invalid or empty JSON simply leaves the
	// slice/map nil rather than failing the whole lookup.
	json.Unmarshal([]byte(datasetsJSON), &job.Datasets)
	json.Unmarshal([]byte(metadataJSON), &job.Metadata)

	return &job, nil
}
|
||||
|
||||
// UpdateJobStatus sets status, worker_id, and error on a job, and
// stamps lifecycle timestamps as a side effect: started_at is set the
// first time the status becomes 'running', and ended_at the first time
// it becomes 'completed' or 'failed' (CASE WHEN ... IS NULL guards
// keep existing timestamps).
//
// The status value is bound three times because the SQL references it
// in the SET clause and in both CASE conditions.
func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error {
	var query string
	if db.dbType == "sqlite" {
		query = `UPDATE jobs SET status = ?, worker_id = ?, error = ?,
		started_at = CASE WHEN ? = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END,
		ended_at = CASE WHEN ? IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
		WHERE id = ?`
	} else {
		query = `UPDATE jobs SET status = $1, worker_id = $2, error = $3,
		started_at = CASE WHEN $4 = 'running' AND started_at IS NULL THEN CURRENT_TIMESTAMP ELSE started_at END,
		ended_at = CASE WHEN $5 IN ('completed', 'failed') AND ended_at IS NULL THEN CURRENT_TIMESTAMP ELSE ended_at END
		WHERE id = $6`
	}

	_, err := db.conn.Exec(query, status, workerID, errorMsg, status, status, id)
	if err != nil {
		return fmt.Errorf("failed to update job status: %w", err)
	}
	return nil
}
|
||||
|
||||
func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
|
||||
var query string
|
||||
if db.dbType == "sqlite" {
|
||||
query = `SELECT id, job_name, args, status, priority, created_at, started_at,
|
||||
ended_at, worker_id, error, datasets, metadata, updated_at
|
||||
FROM jobs`
|
||||
} else {
|
||||
query = `SELECT id, job_name, args, status, priority, created_at, started_at,
|
||||
ended_at, worker_id, error, datasets, metadata, updated_at
|
||||
FROM jobs`
|
||||
}
|
||||
|
||||
var args []interface{}
|
||||
if status != "" {
|
||||
if db.dbType == "sqlite" {
|
||||
query += " WHERE status = ?"
|
||||
} else {
|
||||
query += " WHERE status = $1"
|
||||
}
|
||||
args = append(args, status)
|
||||
}
|
||||
query += " ORDER BY created_at DESC"
|
||||
if limit > 0 {
|
||||
if db.dbType == "sqlite" {
|
||||
query += " LIMIT ?"
|
||||
} else {
|
||||
query += fmt.Sprintf(" LIMIT $%d", len(args)+1)
|
||||
}
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := db.conn.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list jobs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var jobs []*Job
|
||||
for rows.Next() {
|
||||
var job Job
|
||||
var datasetsJSON, metadataJSON string
|
||||
var workerID sql.NullString
|
||||
var errorMsg sql.NullString
|
||||
|
||||
err := rows.Scan(&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
||||
&job.CreatedAt, &job.StartedAt, &job.EndedAt, &workerID,
|
||||
&errorMsg, &datasetsJSON, &metadataJSON, &job.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan job: %w", err)
|
||||
}
|
||||
|
||||
if workerID.Valid {
|
||||
job.WorkerID = workerID.String
|
||||
}
|
||||
|
||||
if errorMsg.Valid {
|
||||
job.Error = errorMsg.String
|
||||
}
|
||||
|
||||
json.Unmarshal([]byte(datasetsJSON), &job.Datasets)
|
||||
json.Unmarshal([]byte(metadataJSON), &job.Metadata)
|
||||
|
||||
jobs = append(jobs, &job)
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// Worker operations

// RegisterWorker upserts a worker row keyed by ID.
// SQLite uses INSERT OR REPLACE; PostgreSQL uses ON CONFLICT DO UPDATE.
// NOTE(review): INSERT OR REPLACE deletes and re-inserts the row, so on
// SQLite any column not listed here (e.g. last_heartbeat) reverts to
// its schema default, whereas the PostgreSQL upsert preserves it —
// confirm this asymmetry is acceptable.
func (db *DB) RegisterWorker(worker *Worker) error {
	// map[string]string cannot fail to encode.
	metadataJSON, _ := json.Marshal(worker.Metadata)

	var query string
	if db.dbType == "sqlite" {
		query = `INSERT OR REPLACE INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
		VALUES (?, ?, ?, ?, ?, ?)`
	} else {
		query = `INSERT INTO workers (id, hostname, status, current_jobs, max_jobs, metadata)
		VALUES ($1, $2, $3, $4, $5, $6)
		ON CONFLICT (id) DO UPDATE SET
		hostname = EXCLUDED.hostname,
		status = EXCLUDED.status,
		current_jobs = EXCLUDED.current_jobs,
		max_jobs = EXCLUDED.max_jobs,
		metadata = EXCLUDED.metadata`
	}

	_, err := db.conn.Exec(query, worker.ID, worker.Hostname, worker.Status,
		worker.CurrentJobs, worker.MaxJobs, string(metadataJSON))
	if err != nil {
		return fmt.Errorf("failed to register worker: %w", err)
	}
	return nil
}
|
||||
|
||||
func (db *DB) UpdateWorkerHeartbeat(workerID string) error {
|
||||
var query string
|
||||
if db.dbType == "sqlite" {
|
||||
query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = ?`
|
||||
} else {
|
||||
query = `UPDATE workers SET last_heartbeat = CURRENT_TIMESTAMP WHERE id = $1`
|
||||
}
|
||||
|
||||
_, err := db.conn.Exec(query, workerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update worker heartbeat: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveWorkers returns workers whose status is 'active' AND whose
// last heartbeat is within the past 30 seconds (liveness window is
// hard-coded in the SQL for both backends).
func (db *DB) GetActiveWorkers() ([]*Worker, error) {
	var query string
	if db.dbType == "sqlite" {
		query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
		FROM workers WHERE status = 'active' AND last_heartbeat > datetime('now', '-30 seconds')`
	} else {
		query = `SELECT id, hostname, last_heartbeat, status, current_jobs, max_jobs, metadata
		FROM workers WHERE status = 'active' AND last_heartbeat > NOW() - INTERVAL '30 seconds'`
	}

	rows, err := db.conn.Query(query)
	if err != nil {
		return nil, fmt.Errorf("failed to get active workers: %w", err)
	}
	defer rows.Close()

	var workers []*Worker
	for rows.Next() {
		var worker Worker
		var metadataJSON string

		err := rows.Scan(&worker.ID, &worker.Hostname, &worker.LastHeartbeat,
			&worker.Status, &worker.CurrentJobs, &worker.MaxJobs, &metadataJSON)
		if err != nil {
			return nil, fmt.Errorf("failed to scan worker: %w", err)
		}

		// Best-effort decode; invalid JSON leaves Metadata nil.
		json.Unmarshal([]byte(metadataJSON), &worker.Metadata)
		workers = append(workers, &worker)
	}

	return workers, nil
}
|
||||
|
||||
// Metrics operations
|
||||
func (db *DB) RecordJobMetric(jobID, metricName, metricValue string) error {
|
||||
var query string
|
||||
if db.dbType == "sqlite" {
|
||||
query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES (?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO job_metrics (job_id, metric_name, metric_value) VALUES ($1, $2, $3)`
|
||||
}
|
||||
|
||||
_, err := db.conn.Exec(query, jobID, metricName, metricValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record job metric: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordSystemMetric appends one system-wide (name, value) metric
// sample; the row's timestamp comes from the schema default.
func (db *DB) RecordSystemMetric(metricName, metricValue string) error {
	var query string
	if db.dbType == "sqlite" {
		query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES (?, ?)`
	} else {
		query = `INSERT INTO system_metrics (metric_name, metric_value) VALUES ($1, $2)`
	}

	_, err := db.conn.Exec(query, metricName, metricValue)
	if err != nil {
		return fmt.Errorf("failed to record system metric: %w", err)
	}
	return nil
}
|
||||
|
||||
func (db *DB) GetJobMetrics(jobID string) (map[string]string, error) {
|
||||
var query string
|
||||
if db.dbType == "sqlite" {
|
||||
query = `SELECT metric_name, metric_value FROM job_metrics
|
||||
WHERE job_id = ? ORDER BY timestamp DESC`
|
||||
} else {
|
||||
query = `SELECT metric_name, metric_value FROM job_metrics
|
||||
WHERE job_id = $1 ORDER BY timestamp DESC`
|
||||
}
|
||||
|
||||
rows, err := db.conn.Query(query, jobID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get job metrics: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
metrics := make(map[string]string)
|
||||
for rows.Next() {
|
||||
var name, value string
|
||||
if err := rows.Scan(&name, &value); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan metric: %w", err)
|
||||
}
|
||||
metrics[name] = value
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
212
internal/storage/db_test.go
Normal file
212
internal/storage/db_test.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDB is an end-to-end smoke test of the SQLite storage layer:
// schema init, job CRUD, status transitions, worker registration and
// heartbeat, and metric recording/retrieval. It reads schema.sql from
// the package directory.
func TestDB(t *testing.T) {
	// Use a temporary database
	dbPath := t.TempDir() + "/test.db"

	// Initialize database
	db, err := NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()

	// Initialize schema
	schema, err := os.ReadFile("schema.sql")
	if err != nil {
		t.Fatalf("Failed to read schema: %v", err)
	}

	if err := db.Initialize(string(schema)); err != nil {
		t.Fatalf("Failed to initialize schema: %v", err)
	}

	// Test job creation
	job := &Job{
		ID:       "test-job-1",
		JobName:  "test_experiment",
		Args:     "--epochs 10 --lr 0.001",
		Status:   "pending",
		Priority: 1,
		Datasets: []string{"dataset1", "dataset2"},
		Metadata: map[string]string{"gpu": "true", "memory": "8GB"},
	}

	if err := db.CreateJob(job); err != nil {
		t.Fatalf("Failed to create job: %v", err)
	}

	// Verify job exists in database (raw query bypasses GetJob to check
	// the row itself landed).
	var count int
	err = db.conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE id = ?", "test-job-1").Scan(&count)
	if err != nil {
		t.Fatalf("Failed to verify job creation: %v", err)
	}
	if count != 1 {
		t.Fatalf("Expected 1 job in database, got %d", count)
	}

	// Test job retrieval
	retrievedJob, err := db.GetJob("test-job-1")
	if err != nil {
		t.Fatalf("Failed to get job: %v", err)
	}

	if retrievedJob.ID != job.ID {
		t.Errorf("Expected job ID %s, got %s", job.ID, retrievedJob.ID)
	}

	if retrievedJob.JobName != job.JobName {
		t.Errorf("Expected job name %s, got %s", job.JobName, retrievedJob.JobName)
	}

	// Datasets/metadata survive the JSON round-trip.
	if len(retrievedJob.Datasets) != 2 {
		t.Errorf("Expected 2 datasets, got %d", len(retrievedJob.Datasets))
	}

	if retrievedJob.Metadata["gpu"] != "true" {
		t.Errorf("Expected gpu=true, got %s", retrievedJob.Metadata["gpu"])
	}

	// Test job status update: 'running' should also set started_at.
	if err := db.UpdateJobStatus("test-job-1", "running", "worker-1", ""); err != nil {
		t.Fatalf("Failed to update job status: %v", err)
	}

	// Verify status update
	updatedJob, err := db.GetJob("test-job-1")
	if err != nil {
		t.Fatalf("Failed to get updated job: %v", err)
	}

	if updatedJob.Status != "running" {
		t.Errorf("Expected status running, got %s", updatedJob.Status)
	}

	if updatedJob.WorkerID != "worker-1" {
		t.Errorf("Expected worker ID worker-1, got %s", updatedJob.WorkerID)
	}

	if updatedJob.StartedAt == nil {
		t.Error("Expected StartedAt to be set")
	}

	// Test worker registration
	worker := &Worker{
		ID:          "worker-1",
		Hostname:    "test-host",
		Status:      "active",
		CurrentJobs: 0,
		MaxJobs:     2,
		Metadata:    map[string]string{"cpu": "8", "memory": "16GB"},
	}

	if err := db.RegisterWorker(worker); err != nil {
		t.Fatalf("Failed to register worker: %v", err)
	}

	// Test worker heartbeat (also keeps the worker inside the 30s
	// liveness window checked below).
	if err := db.UpdateWorkerHeartbeat("worker-1"); err != nil {
		t.Fatalf("Failed to update worker heartbeat: %v", err)
	}

	// Test metrics recording
	if err := db.RecordJobMetric("test-job-1", "accuracy", "0.95"); err != nil {
		t.Fatalf("Failed to record job metric: %v", err)
	}

	if err := db.RecordSystemMetric("cpu_usage", "75"); err != nil {
		t.Fatalf("Failed to record system metric: %v", err)
	}

	// Test metrics retrieval
	metrics, err := db.GetJobMetrics("test-job-1")
	if err != nil {
		t.Fatalf("Failed to get job metrics: %v", err)
	}

	if metrics["accuracy"] != "0.95" {
		t.Errorf("Expected accuracy 0.95, got %s", metrics["accuracy"])
	}

	// Test job listing
	jobs, err := db.ListJobs("", 10)
	if err != nil {
		t.Fatalf("Failed to list jobs: %v", err)
	}

	t.Logf("Found %d jobs", len(jobs))
	for i, job := range jobs {
		t.Logf("Job %d: ID=%s, Status=%s", i, job.ID, job.Status)
	}

	if len(jobs) != 1 {
		t.Errorf("Expected 1 job, got %d", len(jobs))
		return
	}

	if jobs[0].ID != "test-job-1" {
		t.Errorf("Expected job ID test-job-1, got %s", jobs[0].ID)
		return
	}

	// Test active workers
	workers, err := db.GetActiveWorkers()
	if err != nil {
		t.Fatalf("Failed to get active workers: %v", err)
	}

	if len(workers) != 1 {
		t.Errorf("Expected 1 active worker, got %d", len(workers))
	}

	if workers[0].ID != "worker-1" {
		t.Errorf("Expected worker ID worker-1, got %s", workers[0].ID)
	}
}
|
||||
|
||||
func TestDBConstraints(t *testing.T) {
|
||||
dbPath := t.TempDir() + "/test_constraints.db"
|
||||
|
||||
db, err := NewDBFromPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
schema, err := os.ReadFile("schema.sql")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read schema: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Initialize(string(schema)); err != nil {
|
||||
t.Fatalf("Failed to initialize schema: %v", err)
|
||||
}
|
||||
|
||||
// Test duplicate job ID
|
||||
job := &Job{
|
||||
ID: "duplicate-test",
|
||||
JobName: "test",
|
||||
Status: "pending",
|
||||
}
|
||||
|
||||
if err := db.CreateJob(job); err != nil {
|
||||
t.Fatalf("Failed to create first job: %v", err)
|
||||
}
|
||||
|
||||
// Should fail on duplicate
|
||||
if err := db.CreateJob(job); err == nil {
|
||||
t.Error("Expected error when creating duplicate job")
|
||||
}
|
||||
|
||||
// Test getting non-existent job
|
||||
_, err = db.GetJob("non-existent")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting non-existent job")
|
||||
}
|
||||
}
|
||||
257
internal/storage/migrate.go
Normal file
257
internal/storage/migrate.go
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
package storage
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/go-redis/redis/v8"
)
|
||||
|
||||
// Migrator handles migration from Redis to SQLite.
// It owns both connections; release them with Close when done.
type Migrator struct {
	redisClient *redis.Client // source: Redis hashes (job:*, worker:*, metrics:*)
	sqliteDB    *DB           // destination: SQLite persistence layer
}
|
||||
|
||||
func NewMigrator(redisAddr, sqlitePath string) (*Migrator, error) {
|
||||
// Connect to Redis
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: redisAddr,
|
||||
})
|
||||
|
||||
// Connect to SQLite
|
||||
db, err := NewDBFromPath(sqlitePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to SQLite: %w", err)
|
||||
}
|
||||
|
||||
return &Migrator{
|
||||
redisClient: rdb,
|
||||
sqliteDB: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Migrator) Close() error {
|
||||
if err := m.sqliteDB.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.redisClient.Close()
|
||||
}
|
||||
|
||||
// MigrateJobs migrates job data from Redis to SQLite
|
||||
func (m *Migrator) MigrateJobs(ctx context.Context) error {
|
||||
log.Println("Starting job migration from Redis to SQLite...")
|
||||
|
||||
// Get all job keys from Redis
|
||||
jobKeys, err := m.redisClient.Keys(ctx, "job:*").Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get job keys from Redis: %w", err)
|
||||
}
|
||||
|
||||
for _, jobKey := range jobKeys {
|
||||
jobData, err := m.redisClient.HGetAll(ctx, jobKey).Result()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get job data for %s: %v", jobKey, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse job data
|
||||
job := &Job{
|
||||
ID: jobKey[4:], // Remove "job:" prefix
|
||||
JobName: jobData["job_name"],
|
||||
Args: jobData["args"],
|
||||
Status: jobData["status"],
|
||||
Priority: parsePriority(jobData["priority"]),
|
||||
WorkerID: jobData["worker_id"],
|
||||
Error: jobData["error"],
|
||||
}
|
||||
|
||||
// Parse timestamps
|
||||
if createdAtStr := jobData["created_at"]; createdAtStr != "" {
|
||||
if ts, err := time.Parse(time.RFC3339, createdAtStr); err == nil {
|
||||
job.CreatedAt = ts
|
||||
}
|
||||
}
|
||||
|
||||
if startedAtStr := jobData["started_at"]; startedAtStr != "" {
|
||||
if ts, err := time.Parse(time.RFC3339, startedAtStr); err == nil {
|
||||
job.StartedAt = &ts
|
||||
}
|
||||
}
|
||||
|
||||
if endedAtStr := jobData["ended_at"]; endedAtStr != "" {
|
||||
if ts, err := time.Parse(time.RFC3339, endedAtStr); err == nil {
|
||||
job.EndedAt = &ts
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON fields
|
||||
if datasetsStr := jobData["datasets"]; datasetsStr != "" {
|
||||
json.Unmarshal([]byte(datasetsStr), &job.Datasets)
|
||||
}
|
||||
|
||||
if metadataStr := jobData["metadata"]; metadataStr != "" {
|
||||
json.Unmarshal([]byte(metadataStr), &job.Metadata)
|
||||
}
|
||||
|
||||
// Insert into SQLite
|
||||
if err := m.sqliteDB.CreateJob(job); err != nil {
|
||||
log.Printf("Failed to create job %s in SQLite: %v", job.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Migrated job: %s", job.ID)
|
||||
}
|
||||
|
||||
log.Printf("Migrated %d jobs from Redis to SQLite", len(jobKeys))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateMetrics migrates metrics from Redis to SQLite
|
||||
func (m *Migrator) MigrateMetrics(ctx context.Context) error {
|
||||
log.Println("Starting metrics migration from Redis to SQLite...")
|
||||
|
||||
// Get all metric keys from Redis
|
||||
metricKeys, err := m.redisClient.Keys(ctx, "metrics:*").Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get metric keys from Redis: %w", err)
|
||||
}
|
||||
|
||||
for _, metricKey := range metricKeys {
|
||||
metricData, err := m.redisClient.HGetAll(ctx, metricKey).Result()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get metric data for %s: %v", metricKey, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse metric key format: metrics:job:job_id or metrics:system
|
||||
parts := parseMetricKey(metricKey)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
metricType := parts[1] // "job" or "system"
|
||||
|
||||
for name, value := range metricData {
|
||||
if metricType == "job" && len(parts) == 3 {
|
||||
// Job metric
|
||||
jobID := parts[2]
|
||||
if err := m.sqliteDB.RecordJobMetric(jobID, name, value); err != nil {
|
||||
log.Printf("Failed to record job metric %s for job %s: %v", name, jobID, err)
|
||||
}
|
||||
} else if metricType == "system" {
|
||||
// System metric
|
||||
if err := m.sqliteDB.RecordSystemMetric(name, value); err != nil {
|
||||
log.Printf("Failed to record system metric %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Migrated %d metric keys from Redis to SQLite", len(metricKeys))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateWorkers migrates worker data from Redis to SQLite
|
||||
func (m *Migrator) MigrateWorkers(ctx context.Context) error {
|
||||
log.Println("Starting worker migration from Redis to SQLite...")
|
||||
|
||||
// Get all worker keys from Redis
|
||||
workerKeys, err := m.redisClient.Keys(ctx, "worker:*").Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get worker keys from Redis: %w", err)
|
||||
}
|
||||
|
||||
for _, workerKey := range workerKeys {
|
||||
workerData, err := m.redisClient.HGetAll(ctx, workerKey).Result()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get worker data for %s: %v", workerKey, err)
|
||||
continue
|
||||
}
|
||||
|
||||
worker := &Worker{
|
||||
ID: workerKey[8:], // Remove "worker:" prefix
|
||||
Hostname: workerData["hostname"],
|
||||
Status: workerData["status"],
|
||||
CurrentJobs: parseInt(workerData["current_jobs"]),
|
||||
MaxJobs: parseInt(workerData["max_jobs"]),
|
||||
}
|
||||
|
||||
// Parse heartbeat
|
||||
if heartbeatStr := workerData["last_heartbeat"]; heartbeatStr != "" {
|
||||
if ts, err := time.Parse(time.RFC3339, heartbeatStr); err == nil {
|
||||
worker.LastHeartbeat = ts
|
||||
}
|
||||
}
|
||||
|
||||
// Parse metadata
|
||||
if metadataStr := workerData["metadata"]; metadataStr != "" {
|
||||
json.Unmarshal([]byte(metadataStr), &worker.Metadata)
|
||||
}
|
||||
|
||||
// Insert into SQLite
|
||||
if err := m.sqliteDB.RegisterWorker(worker); err != nil {
|
||||
log.Printf("Failed to register worker %s in SQLite: %v", worker.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Migrated worker: %s", worker.ID)
|
||||
}
|
||||
|
||||
log.Printf("Migrated %d workers from Redis to SQLite", len(workerKeys))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateAll performs complete migration from Redis to SQLite
|
||||
func (m *Migrator) MigrateAll(ctx context.Context) error {
|
||||
log.Println("Starting complete migration from Redis to SQLite...")
|
||||
|
||||
// Test connections
|
||||
if err := m.redisClient.Ping(ctx).Err(); err != nil {
|
||||
return fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
// Run migrations in order
|
||||
if err := m.MigrateJobs(ctx); err != nil {
|
||||
return fmt.Errorf("job migration failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.MigrateWorkers(ctx); err != nil {
|
||||
return fmt.Errorf("worker migration failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.MigrateMetrics(ctx); err != nil {
|
||||
return fmt.Errorf("metrics migration failed: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Migration completed successfully!")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions

// parsePriority converts a Redis priority string to int64. The original
// was a stub that always returned 0; this implementation parses a
// base-10 integer and falls back to 0 for empty or malformed input.
func parsePriority(s string) int64 {
	if s == "" {
		return 0
	}
	v, err := strconv.ParseInt(strings.TrimSpace(s), 10, 64)
	if err != nil {
		return 0
	}
	return v
}
|
||||
|
||||
// parseInt converts a Redis counter string to int. The original was a
// stub that always returned 0; this implementation parses a base-10
// integer and falls back to 0 for empty or malformed input.
func parseInt(s string) int {
	if s == "" {
		return 0
	}
	v, err := strconv.Atoi(strings.TrimSpace(s))
	if err != nil {
		return 0
	}
	return v
}
|
||||
|
||||
// parseMetricKey splits a Redis metric key of the form
// "metrics:<type>[:<id>]" into its colon-separated components.
func parseMetricKey(key string) []string {
	return strings.Split(key, ":")
}
|
||||
61
internal/storage/schema.sql
Normal file
61
internal/storage/schema.sql
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
-- SQLite schema for Fetch ML job persistence
-- Complements Redis for task queuing

-- Core job table: one row per submitted job, mirroring the Redis
-- "job:<id>" hashes.
CREATE TABLE IF NOT EXISTS jobs (
    id TEXT PRIMARY KEY,
    job_name TEXT NOT NULL,
    args TEXT,
    status TEXT NOT NULL DEFAULT 'pending',
    priority INTEGER DEFAULT 0,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    started_at DATETIME,
    ended_at DATETIME,
    worker_id TEXT,
    error TEXT,
    datasets TEXT, -- JSON array
    metadata TEXT, -- JSON object
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

-- Per-job metric samples; the composite key permits repeated samples
-- of the same metric over time, and rows are removed with their job.
CREATE TABLE IF NOT EXISTS job_metrics (
    job_id TEXT,
    metric_name TEXT,
    metric_value TEXT,
    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (job_id, metric_name, timestamp),
    FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
);

-- Registered workers with heartbeat and capacity state.
CREATE TABLE IF NOT EXISTS workers (
    id TEXT PRIMARY KEY,
    hostname TEXT,
    last_heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP,
    status TEXT DEFAULT 'active',
    current_jobs INTEGER DEFAULT 0,
    max_jobs INTEGER DEFAULT 1,
    metadata TEXT -- JSON object
);

-- Host-level metric samples not tied to a particular job.
CREATE TABLE IF NOT EXISTS system_metrics (
    metric_name TEXT,
    metric_value TEXT,
    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (metric_name, timestamp)
);

-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
CREATE INDEX IF NOT EXISTS idx_jobs_worker_id ON jobs(worker_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);

-- Triggers to update timestamps.
-- NOTE(review): this trigger issues an UPDATE on the same table. With
-- SQLite's default PRAGMA recursive_triggers=OFF it will not re-fire;
-- confirm the application never enables that pragma.
CREATE TRIGGER IF NOT EXISTS update_jobs_timestamp
AFTER UPDATE ON jobs
FOR EACH ROW
BEGIN
    UPDATE jobs SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;
|
||||
68
internal/storage/schema_postgres.sql
Normal file
68
internal/storage/schema_postgres.sql
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
-- PostgreSQL schema for Fetch ML job persistence
-- Complements Redis for task queuing

-- Core job table: one row per submitted job, mirroring the Redis
-- "job:<id>" hashes. Timestamps are timezone-aware.
CREATE TABLE IF NOT EXISTS jobs (
    id TEXT PRIMARY KEY,
    job_name TEXT NOT NULL,
    args TEXT,
    status TEXT NOT NULL DEFAULT 'pending',
    priority INTEGER DEFAULT 0,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    started_at TIMESTAMP WITH TIME ZONE,
    ended_at TIMESTAMP WITH TIME ZONE,
    worker_id TEXT,
    error TEXT,
    datasets TEXT, -- JSON array
    metadata TEXT, -- JSON object
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);

-- Per-job metric samples; the composite key permits repeated samples
-- of the same metric over time, and rows are removed with their job.
CREATE TABLE IF NOT EXISTS job_metrics (
    job_id TEXT,
    metric_name TEXT,
    metric_value TEXT,
    timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (job_id, metric_name, timestamp),
    FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
);

-- Registered workers with heartbeat and capacity state.
CREATE TABLE IF NOT EXISTS workers (
    id TEXT PRIMARY KEY,
    hostname TEXT,
    last_heartbeat TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    status TEXT DEFAULT 'active',
    current_jobs INTEGER DEFAULT 0,
    max_jobs INTEGER DEFAULT 1,
    metadata TEXT -- JSON object
);

-- Host-level metric samples not tied to a particular job.
CREATE TABLE IF NOT EXISTS system_metrics (
    metric_name TEXT,
    metric_value TEXT,
    timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (metric_name, timestamp)
);

-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at);
CREATE INDEX IF NOT EXISTS idx_jobs_worker_id ON jobs(worker_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_job_id ON job_metrics(job_id);
CREATE INDEX IF NOT EXISTS idx_job_metrics_timestamp ON job_metrics(timestamp);
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat ON workers(last_heartbeat);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics(timestamp);

-- Function to update updated_at timestamp (runs before each row update).
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = CURRENT_TIMESTAMP;
    RETURN NEW;
END;
$$ language 'plpgsql';

-- Trigger to update timestamps
-- NOTE(review): unlike the SQLite schema, this trigger is not created
-- with IF NOT EXISTS; re-running the script against an existing
-- database will fail here. Verify how the schema is applied.
CREATE TRIGGER update_jobs_timestamp
BEFORE UPDATE ON jobs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
|
||||
77
internal/telemetry/telemetry.go
Normal file
77
internal/telemetry/telemetry.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package telemetry
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// IOStats holds cumulative I/O byte counters for the current process,
// as read from /proc/self/io (Linux only).
type IOStats struct {
	// ReadBytes is the cumulative "read_bytes" counter.
	ReadBytes uint64
	// WriteBytes is the cumulative "write_bytes" counter.
	WriteBytes uint64
}
|
||||
|
||||
func ReadProcessIO() (IOStats, error) {
|
||||
f, err := os.Open("/proc/self/io")
|
||||
if err != nil {
|
||||
return IOStats{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var stats IOStats
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "read_bytes:") {
|
||||
stats.ReadBytes = parseUintField(line)
|
||||
}
|
||||
if strings.HasPrefix(line, "write_bytes:") {
|
||||
stats.WriteBytes = parseUintField(line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return IOStats{}, err
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func DiffIO(before, after IOStats) IOStats {
|
||||
var delta IOStats
|
||||
if after.ReadBytes >= before.ReadBytes {
|
||||
delta.ReadBytes = after.ReadBytes - before.ReadBytes
|
||||
}
|
||||
if after.WriteBytes >= before.WriteBytes {
|
||||
delta.WriteBytes = after.WriteBytes - before.WriteBytes
|
||||
}
|
||||
return delta
|
||||
}
|
||||
|
||||
// parseUintField extracts the numeric value from a "name: value" line
// of /proc/self/io, returning 0 when the line is malformed or the
// value is not a valid unsigned integer.
func parseUintField(line string) uint64 {
	_, raw, found := strings.Cut(line, ":")
	if !found {
		return 0
	}
	value, err := strconv.ParseUint(strings.TrimSpace(raw), 10, 64)
	if err != nil {
		return 0
	}
	return value
}
|
||||
|
||||
func ExecWithMetrics(logger *logging.Logger, description string, threshold time.Duration, fn func() (string, error)) (string, error) {
|
||||
start := time.Now()
|
||||
out, err := fn()
|
||||
duration := time.Since(start)
|
||||
if duration > threshold {
|
||||
fields := []any{"latency_ms", duration.Milliseconds(), "command", description}
|
||||
if err != nil {
|
||||
fields = append(fields, "error", err)
|
||||
}
|
||||
logger.Debug("ssh exec", fields...)
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
Loading…
Reference in a new issue