fetch_ml/cmd/api-server/main.go
Jeremie Fraeys ea15af1833 Fix multi-user authentication and clean up debug code
- Fix YAML tags in auth config struct (json -> yaml)
- Update CLI configs to use pre-hashed API keys
- Remove double hashing in WebSocket client
- Fix port mapping (9102 -> 9103) in CLI commands
- Update permission keys to use jobs:read, jobs:create, etc.
- Clean up all debug logging from CLI and server
- All user roles now authenticate correctly:
  * Admin: Can queue jobs and see all jobs
  * Researcher: Can queue jobs and see own jobs
  * Analyst: Can see status (read-only access)

Multi-user authentication is now fully functional.
2025-12-06 12:35:32 -05:00

444 lines
12 KiB
Go

// Package main implements the fetch_ml API server.
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/middleware"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"gopkg.in/yaml.v3"
)
// Config is the root of the API server's YAML configuration file. Each field
// is populated from the top-level YAML key named in its struct tag.
//
// The original author describes this layout as matching the worker config;
// the worker's definition is not visible here, so verify before relying on it.
type Config struct {
	// BasePath is the directory handed to the experiment manager. An empty
	// value is replaced by a default inside initExperimentManager.
	BasePath string `yaml:"base_path"`

	// Auth carries the authentication settings (enabled flag and API key table).
	Auth auth.Config `yaml:"auth"`

	// Server holds the listen address and TLS settings.
	Server ServerConfig `yaml:"server"`

	// Security holds rate limiting, IP whitelist and lockout settings.
	Security SecurityConfig `yaml:"security"`

	// Redis describes the Redis instance backing the task queue.
	Redis RedisConfig `yaml:"redis"`

	// Database describes the optional database; an empty Type disables it.
	Database DatabaseConfig `yaml:"database"`

	// Logging is passed to the logging package to build the server logger.
	Logging logging.Config `yaml:"logging"`

	// Resources has ApplyDefaults called on it right after the file is loaded.
	Resources config.ResourceConfig `yaml:"resources"`
}
// RedisConfig describes how to reach the Redis instance used by the task queue.
type RedisConfig struct {
	// Addr is the Redis address. When empty, initTaskQueue substitutes
	// config.DefaultRedisAddr.
	Addr string `yaml:"addr"`

	// Password is forwarded to the queue as the Redis password.
	Password string `yaml:"password"`

	// DB is forwarded to the queue as the Redis database number.
	DB int `yaml:"db"`

	// URL, when non-empty, takes precedence over Addr in initTaskQueue.
	URL string `yaml:"url"`
}
// DatabaseConfig describes the optional database backend. Its fields are
// copied one-to-one into storage.DBConfig by initDatabase.
type DatabaseConfig struct {
	// Type selects the backend. An empty value disables the database; only
	// "sqlite", "postgres" and "postgresql" have a schema file in this program.
	Type string `yaml:"type"`

	// Connection is the backend connection string. It is also echoed in the
	// startup log and in the /db-status response.
	Connection string `yaml:"connection"`

	// Host is the database host.
	Host string `yaml:"host"`

	// Port is the database port.
	Port int `yaml:"port"`

	// Username is the database user.
	Username string `yaml:"username"`

	// Password is the database password.
	Password string `yaml:"password"`

	// Database is the database name.
	Database string `yaml:"database"`
}
// SecurityConfig groups the security-related settings of the API server.
type SecurityConfig struct {
	// RateLimit configures request rate limiting.
	RateLimit RateLimitConfig `yaml:"rate_limit"`

	// IPWhitelist, when non-empty, enables the IP whitelist middleware for
	// every route except /ws.
	IPWhitelist []string `yaml:"ip_whitelist"`

	// FailedLockout configures lockout after failed logins.
	FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
}
// RateLimitConfig describes the request rate limit. Rate limiting options are
// only produced when Enabled is true and RequestsPerMinute is positive.
type RateLimitConfig struct {
	// Enabled switches rate limiting on.
	Enabled bool `yaml:"enabled"`

	// RequestsPerMinute is the allowed request rate; values <= 0 disable
	// rate limiting even when Enabled is true.
	RequestsPerMinute int `yaml:"requests_per_minute"`

	// BurstSize is forwarded unchanged to the middleware's rate limit options.
	BurstSize int `yaml:"burst_size"`
}
// LockoutConfig describes the failed-login lockout policy as read from YAML.
type LockoutConfig struct {
	// Enabled switches the lockout policy on.
	Enabled bool `yaml:"enabled"`

	// MaxAttempts is the configured attempt limit.
	MaxAttempts int `yaml:"max_attempts"`

	// LockoutDuration is kept as the raw string from the config file.
	// NOTE(review): presumably a duration such as "15m"; it is not parsed in
	// this file, so confirm the expected format with its consumer.
	LockoutDuration string `yaml:"lockout_duration"`
}
// ServerConfig describes where and how the HTTP server listens.
type ServerConfig struct {
	// Address is used verbatim as http.Server.Addr.
	Address string `yaml:"address"`

	// TLS selects between plain HTTP and HTTPS and names the key material.
	TLS TLSConfig `yaml:"tls"`
}
// TLSConfig describes the TLS setup of the API server.
type TLSConfig struct {
	// Enabled makes the server start with ListenAndServeTLS instead of
	// ListenAndServe.
	Enabled bool `yaml:"enabled"`

	// CertFile is the certificate path passed to ListenAndServeTLS.
	CertFile string `yaml:"cert_file"`

	// KeyFile is the private key path passed to ListenAndServeTLS.
	KeyFile string `yaml:"key_file"`
}
// LoadConfig reads the file at path through fileutil.SecureFileRead and
// decodes its contents as YAML into a freshly allocated Config.
//
// It returns the read or decode error unchanged; on error the returned
// *Config is nil.
func LoadConfig(path string) (*Config, error) {
	raw, readErr := fileutil.SecureFileRead(path)
	if readErr != nil {
		return nil, readErr
	}

	parsed := new(Config)
	if decodeErr := yaml.Unmarshal(raw, parsed); decodeErr != nil {
		return nil, decodeErr
	}
	return parsed, nil
}
// main wires the API server together: it parses flags, loads the config,
// prepares logging, initializes the experiment manager, task queue and
// optional database, builds the HTTP handler chain, starts the server in the
// background and then blocks until a shutdown signal has been handled.
//
// Queue and database cleanups are deferred, so they run after
// waitForShutdown returns.
func main() {
	configPath := flag.String("config", "configs/config-local.yaml", "Configuration file path")
	apiKey := flag.String("api-key", "", "API key for authentication")
	flag.Parse()
	_ = apiKey // Reserved for future authentication enhancements

	cfg, err := loadServerConfig(*configPath)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// The log directory must exist before the logger is constructed.
	if err := ensureLogDirectory(cfg.Logging); err != nil {
		log.Fatalf("Failed to prepare log directory: %v", err)
	}
	logger := setupLogger(cfg.Logging)

	experiments, err := initExperimentManager(cfg.BasePath, logger)
	if err != nil {
		logger.Fatal("failed to initialize experiment manager", "error", err)
	}

	// Both initializers return a nil cleanup when the dependency is
	// unavailable; the server still starts in that case.
	tasks, closeQueue := initTaskQueue(cfg, logger)
	if closeQueue != nil {
		defer closeQueue()
	}
	database, closeDB := initDatabase(cfg, logger)
	if closeDB != nil {
		defer closeDB()
	}

	authCfg := buildAuthConfig(cfg.Auth, logger)
	security := newSecurityMiddleware(cfg)
	routes := buildHTTPMux(cfg, logger, experiments, tasks, authCfg, database)
	server := newHTTPServer(cfg, wrapWithMiddleware(cfg, security, routes))

	startServer(server, cfg, logger)
	waitForShutdown(server, logger)
}
// loadServerConfig resolves path with config.ResolveConfigPath, loads the
// resulting file with LoadConfig and applies resource defaults to it.
//
// Any resolution or load error is returned unchanged with a nil *Config.
func loadServerConfig(path string) (*Config, error) {
	var (
		cfg *Config
		err error
	)

	if path, err = config.ResolveConfigPath(path); err != nil {
		return nil, err
	}
	if cfg, err = LoadConfig(path); err != nil {
		return nil, err
	}

	cfg.Resources.ApplyDefaults()
	return cfg, nil
}
// ensureLogDirectory creates the parent directory of the configured log file
// (mode 0750, including missing parents). It does nothing and returns nil
// when no log file is configured.
func ensureLogDirectory(cfg logging.Config) error {
	if cfg.File != "" {
		dir := filepath.Dir(cfg.File)
		log.Printf("Creating log directory: %s", dir)
		return os.MkdirAll(dir, 0750)
	}
	return nil
}
// setupLogger builds a logger from cfg and returns its "api-server"
// component, bound to a background context passed through
// logging.EnsureTrace.
func setupLogger(cfg logging.Config) *logging.Logger {
	// Calls run left to right: logger construction first, then EnsureTrace.
	return logging.NewLoggerFromConfig(cfg).Component(
		logging.EnsureTrace(context.Background()),
		"api-server",
	)
}
// initExperimentManager creates and initializes the experiment manager rooted
// at basePath, falling back to "/tmp/ml-experiments" when basePath is empty.
//
// It returns the Initialize error unchanged, with a nil manager, on failure.
func initExperimentManager(basePath string, logger *logging.Logger) (*experiment.Manager, error) {
	root := basePath
	if root == "" {
		root = "/tmp/ml-experiments"
	}

	manager := experiment.NewManager(root)
	log.Printf("Initializing experiment manager with base_path: %s", root)

	err := manager.Initialize()
	if err != nil {
		return nil, err
	}

	logger.Info("experiment manager initialized", "base_path", root)
	return manager, nil
}
// buildAuthConfig returns a pointer to a copy of cfg when authentication is
// enabled, logging that fact, and nil when it is disabled.
func buildAuthConfig(cfg auth.Config, logger *logging.Logger) *auth.Config {
	if cfg.Enabled {
		logger.Info("authentication enabled")
		// cfg was passed by value, so this points at a private copy.
		return &cfg
	}
	return nil
}
// newSecurityMiddleware constructs the security middleware from the configured
// API key table, the JWT_SECRET environment variable and the rate limit
// settings.
//
// NOTE(review): JWT_SECRET is passed through even when unset (empty string);
// confirm how the middleware treats an empty secret.
func newSecurityMiddleware(cfg *Config) *middleware.SecurityMiddleware {
	keys := collectAPIKeys(cfg.Auth.APIKeys)
	limits := buildRateLimitOptions(cfg.Security.RateLimit)
	jwtSecret := os.Getenv("JWT_SECRET")

	return middleware.NewSecurityMiddleware(keys, jwtSecret, limits)
}
// collectAPIKeys returns the keys of the API key table as plain strings, in
// map iteration order (i.e. unspecified order). The result is never nil.
//
// NOTE(review): the returned values are the usernames that index the table,
// not anything taken from the APIKeyEntry values — confirm this is what
// middleware.NewSecurityMiddleware expects as its key list.
func collectAPIKeys(keys map[auth.Username]auth.APIKeyEntry) []string {
	names := make([]string, len(keys))
	next := 0
	for user := range keys {
		names[next] = string(user)
		next++
	}
	return names
}
// buildRateLimitOptions converts the YAML rate limit settings into middleware
// options. It returns nil unless rate limiting is enabled and the configured
// requests-per-minute value is positive.
func buildRateLimitOptions(cfg RateLimitConfig) *middleware.RateLimitOptions {
	if cfg.Enabled && cfg.RequestsPerMinute > 0 {
		opts := middleware.RateLimitOptions{
			RequestsPerMinute: cfg.RequestsPerMinute,
			BurstSize:         cfg.BurstSize,
		}
		return &opts
	}
	return nil
}
// initTaskQueue connects the Redis-backed task queue described by cfg.Redis.
//
// Address selection: cfg.Redis.URL wins when set; otherwise cfg.Redis.Addr is
// used, falling back to config.DefaultRedisAddr when that is empty too.
//
// On success it returns the queue and a cleanup function that closes it. On
// failure the error is logged and (nil, nil) is returned, so callers must
// tolerate a nil queue.
func initTaskQueue(cfg *Config, logger *logging.Logger) (*queue.TaskQueue, func()) {
	addr := cfg.Redis.Addr
	switch {
	case cfg.Redis.URL != "":
		addr = cfg.Redis.URL
	case addr == "":
		addr = config.DefaultRedisAddr
	}

	tq, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:     addr,
		RedisPassword: cfg.Redis.Password,
		RedisDB:       cfg.Redis.DB,
	})
	if err != nil {
		logger.Error("failed to initialize task queue", "error", err)
		return nil, nil
	}
	logger.Info("task queue initialized", "redis_addr", addr)

	return tq, func() {
		logger.Info("stopping task queue...")
		if closeErr := tq.Close(); closeErr != nil {
			logger.Error("failed to stop task queue", "error", closeErr)
			return
		}
		logger.Info("task queue stopped")
	}
}
// initDatabase opens the database described by cfg.Database and applies the
// schema file that schemaPathForDB selects for its type.
//
// It returns (nil, nil) when no database type is configured, and also when
// opening, schema lookup, schema reading or schema application fails; each
// failure is logged and the connection, if already open, is closed. On
// success it returns the handle and a cleanup function that closes it.
//
// NOTE(review): the success log includes cfg.Database.Connection verbatim;
// confirm that connection strings never embed credentials.
func initDatabase(cfg *Config, logger *logging.Logger) (*storage.DB, func()) {
	settings := cfg.Database
	if settings.Type == "" {
		return nil, nil
	}

	conn, err := storage.NewDB(storage.DBConfig{
		Type:       settings.Type,
		Connection: settings.Connection,
		Host:       settings.Host,
		Port:       settings.Port,
		Username:   settings.Username,
		Password:   settings.Password,
		Database:   settings.Database,
	})
	if err != nil {
		logger.Error("failed to initialize database", "type", settings.Type, "error", err)
		return nil, nil
	}

	// From here on every failure path must release conn before returning.
	ddlPath := schemaPathForDB(settings.Type)
	if ddlPath == "" {
		logger.Error("unsupported database type", "type", settings.Type)
		_ = conn.Close()
		return nil, nil
	}

	ddl, err := fileutil.SecureFileRead(ddlPath)
	if err != nil {
		logger.Error("failed to read database schema file", "path", ddlPath, "error", err)
		_ = conn.Close()
		return nil, nil
	}

	if err = conn.Initialize(string(ddl)); err != nil {
		logger.Error("failed to initialize database schema", "error", err)
		_ = conn.Close()
		return nil, nil
	}

	logger.Info("database initialized", "type", settings.Type, "connection", settings.Connection)

	return conn, func() {
		logger.Info("closing database connection...")
		if closeErr := conn.Close(); closeErr != nil {
			logger.Error("failed to close database", "error", closeErr)
			return
		}
		logger.Info("database connection closed")
	}
}
// schemaPathForDB maps a database type name to the relative path of its SQL
// schema file. The match is case-sensitive; unknown types yield "".
func schemaPathForDB(dbType string) string {
	if dbType == "sqlite" {
		return "internal/storage/schema_sqlite.sql"
	}
	if dbType == "postgres" || dbType == "postgresql" {
		return "internal/storage/schema_postgres.sql"
	}
	return ""
}
// buildHTTPMux registers the server's routes on a new ServeMux:
//
//   - /ws        the WebSocket handler built by api.NewWSHandler
//   - /health    always answers 200 with "OK\n"
//   - /db-status JSON report on the database
//
// /db-status answers 503 with a fixed "disconnected" body when db is nil.
// Otherwise it probes the database by recording a "db_test" system metric
// and answers 200 with status "connected", or 503 with status "error" when
// the probe fails, so that clients checking only the HTTP status see the
// failure (matching the nil-database branch).
func buildHTTPMux(
	cfg *Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	taskQueue *queue.TaskQueue,
	authCfg *auth.Config,
	db *storage.DB,
) *http.ServeMux {
	mux := http.NewServeMux()

	wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)
	mux.Handle("/ws", wsHandler)

	mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = fmt.Fprintf(w, "OK\n")
	})

	mux.HandleFunc("/db-status", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")

		if db == nil {
			w.WriteHeader(http.StatusServiceUnavailable)
			_, _ = fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
			return
		}

		var result struct {
			Status  string `json:"status"`
			Type    string `json:"type"`
			Path    string `json:"path"`
			Message string `json:"message"`
		}
		result.Status = "connected"
		result.Type = cfg.Database.Type
		result.Path = cfg.Database.Connection
		result.Message = fmt.Sprintf("%s database is operational", cfg.Database.Type)

		// The probe is a write: every call records a "db_test" metric.
		httpStatus := http.StatusOK
		if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
			result.Status = "error"
			result.Message = fmt.Sprintf("Database query failed: %v", err)
			httpStatus = http.StatusServiceUnavailable
		}

		jsonBytes, err := json.Marshal(result)
		if err != nil {
			logger.Error("failed to encode database status", "error", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// The status line must be written before the body.
		w.WriteHeader(httpStatus)
		_, _ = w.Write(jsonBytes)
	})

	return mux
}
// wrapWithMiddleware returns the top-level handler for the server.
//
// Requests for exactly "/ws" go straight to mux with no middleware at all.
// Every other request passes through, from outermost to innermost: the IP
// whitelist (only when cfg.Security.IPWhitelist is non-empty), the audit
// logger, a 30 second request timeout, CORS, security headers and the rate
// limiter, before reaching mux.
//
// The chain depends only on cfg, sec and mux, so it is assembled once here
// instead of being rebuilt (and reallocated) on every request.
//
// NOTE(review): because /ws bypasses the chain, the IP whitelist and rate
// limiter do not apply to WebSocket connections — confirm this is intended.
func wrapWithMiddleware(cfg *Config, sec *middleware.SecurityMiddleware, mux *http.ServeMux) http.Handler {
	handler := sec.RateLimit(mux)
	handler = middleware.SecurityHeaders(handler)
	handler = middleware.CORS(handler)
	handler = middleware.RequestTimeout(30 * time.Second)(handler)
	handler = middleware.AuditLogger(handler)
	if len(cfg.Security.IPWhitelist) > 0 {
		handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/ws" {
			mux.ServeHTTP(w, r)
			return
		}
		handler.ServeHTTP(w, r)
	})
}
// newHTTPServer returns an http.Server bound to the configured address that
// serves handler with 30 second read and write timeouts and a 120 second
// idle timeout. The server is not started.
func newHTTPServer(cfg *Config, handler http.Handler) *http.Server {
	const (
		ioTimeout   = 30 * time.Second
		idleTimeout = 120 * time.Second
	)

	srv := new(http.Server)
	srv.Addr = cfg.Server.Address
	srv.Handler = handler
	srv.ReadTimeout = ioTimeout
	srv.WriteTimeout = ioTimeout
	srv.IdleTimeout = idleTimeout
	return srv
}
// startServer launches server in a background goroutine and returns
// immediately. It serves HTTPS with the configured certificate and key when
// TLS is enabled, and plain HTTP (after logging a warning) otherwise.
//
// If the listener fails for any reason other than an orderly shutdown, the
// error is logged and the process exits with status 1. When the server stops
// because Shutdown or Close was called (http.ErrServerClosed), the goroutine
// simply returns: exiting there would race with the in-progress graceful
// shutdown, skip the caller's deferred cleanups and report a failure status
// for a clean stop.
func startServer(server *http.Server, cfg *Config, logger *logging.Logger) {
	tlsEnabled := cfg.Server.TLS.Enabled
	if !tlsEnabled {
		logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
	}

	go func() {
		var (
			err     error
			failMsg string
		)
		if tlsEnabled {
			logger.Info("starting HTTPS server", "address", cfg.Server.Address)
			failMsg = "HTTPS server failed"
			err = server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
		} else {
			logger.Info("starting HTTP server", "address", cfg.Server.Address)
			failMsg = "HTTP server failed"
			err = server.ListenAndServe()
		}

		// ErrServerClosed is the normal result of a graceful shutdown.
		if err == nil || err == http.ErrServerClosed {
			return
		}

		logger.Error(failMsg, "error", err)
		os.Exit(1)
	}()
}
// waitForShutdown blocks until the process receives SIGINT or SIGTERM, then
// asks server to shut down gracefully, allowing at most 5 seconds. The
// outcome is logged; a shutdown error is not returned to the caller.
func waitForShutdown(server *http.Server, logger *logging.Logger) {
	// Buffered so a signal arriving before the receive is not dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

	received := <-signals
	logger.Info("received shutdown signal", "signal", received)

	shutdownCtx, release := context.WithTimeout(context.Background(), 5*time.Second)
	defer release()

	logger.Info("shutting down http server...")
	err := server.Shutdown(shutdownCtx)
	if err == nil {
		logger.Info("http server shutdown complete")
	} else {
		logger.Error("server shutdown error", "error", err)
	}

	logger.Info("api server stopped")
}