From db7fbbd8d585eb1380ecaf5a5a6e332f385df014 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Tue, 17 Feb 2026 13:11:02 -0500 Subject: [PATCH] refactor: Phase 5 - split API package into focused files Reorganized internal/api/ package to follow single-concern principle: - api/factory.go (new file, 256 lines) - Extracted component initialization from server.go - initializeComponents(), setupLogger(), initExperimentManager() - initTaskQueue(), initDatabase(), initDatabaseSchema() - initSecurity(), initJupyterServiceManager(), initAuditLogger() - api/middleware.go (new file, 31 lines) - Extracted wrapWithMiddleware() - security middleware chain - Centralized auth, rate limiting, CORS, security headers - api/server.go (reduced from 446 to 212 lines) - Now focused on Server lifecycle: NewServer, Start, WaitForShutdown, Close - Removed initialization logic (moved to factory.go) - Removed middleware wrapper (moved to middleware.go) - api/metrics_middleware.go (existing, 64 lines) - Already had wrapWithMetrics(), left in place Lines redistributed: ~180 lines from monolithic server.go Build status: Compiles successfully --- internal/api/factory.go | 256 ++++++++++++++++++++++++++++++++++++ internal/api/middleware.go | 31 +++++ internal/api/server.go | 257 +------------------------------------ 3 files changed, 288 insertions(+), 256 deletions(-) create mode 100644 internal/api/factory.go create mode 100644 internal/api/middleware.go diff --git a/internal/api/factory.go b/internal/api/factory.go new file mode 100644 index 0000000..e4e22d7 --- /dev/null +++ b/internal/api/factory.go @@ -0,0 +1,256 @@ +package api + +import ( + "context" + "os" + "strings" + + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/middleware" + 
"github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/storage" +) + +// initializeComponents initializes all server components +func (s *Server) initializeComponents() error { + // Setup logging + if err := s.config.EnsureLogDirectory(); err != nil { + return err + } + s.logger = s.setupLogger() + + // Initialize experiment manager + if err := s.initExperimentManager(); err != nil { + return err + } + + // Initialize task queue + if err := s.initTaskQueue(); err != nil { + return err + } + + // Initialize database + if err := s.initDatabase(); err != nil { + return err + } + + // Initialize database schema (if DB enabled) + if err := s.initDatabaseSchema(); err != nil { + return err + } + + // Initialize security + s.initSecurity() + + // Initialize Jupyter service manager + s.initJupyterServiceManager() + + // Initialize handlers + s.handlers = NewHandlers(s.expManager, nil, s.logger) + + return nil +} + +// setupLogger creates and configures the logger +func (s *Server) setupLogger() *logging.Logger { + logger := logging.NewLoggerFromConfig(s.config.Logging) + ctx := logging.EnsureTrace(context.Background()) + return logger.Component(ctx, "api-server") +} + +// initExperimentManager initializes the experiment manager +func (s *Server) initExperimentManager() error { + s.expManager = experiment.NewManager(s.config.BasePath) + if err := s.expManager.Initialize(); err != nil { + return err + } + + s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath) + return nil +} + +// initTaskQueue initializes the task queue +func (s *Server) initTaskQueue() error { + backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend)) + if backend == "" { + backend = "redis" + } + redisAddr := strings.TrimSpace(s.config.Redis.Addr) + if redisAddr == "" { + redisAddr = "localhost:6379" + } + if strings.TrimSpace(s.config.Redis.URL) != "" { + redisAddr = strings.TrimSpace(s.config.Redis.URL) + } + + backendCfg := 
queue.BackendConfig{ + Backend: queue.QueueBackend(backend), + RedisAddr: redisAddr, + RedisPassword: s.config.Redis.Password, + RedisDB: s.config.Redis.DB, + SQLitePath: s.config.Queue.SQLitePath, + FilesystemPath: s.config.Queue.FilesystemPath, + FallbackToFilesystem: s.config.Queue.FallbackToFilesystem, + MetricsFlushInterval: 0, + } + + taskQueue, err := queue.NewBackend(backendCfg) + if err != nil { + return err + } + + s.taskQueue = taskQueue + if backend == "sqlite" { + s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath) + } else { + s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr) + } + + // Add cleanup function + s.cleanupFuncs = append(s.cleanupFuncs, func() { + s.logger.Info("stopping task queue...") + if err := s.taskQueue.Close(); err != nil { + s.logger.Error("failed to stop task queue", "error", err) + } else { + s.logger.Info("task queue stopped") + } + }) + + return nil +} + +// initDatabase initializes the database connection +func (s *Server) initDatabase() error { + if s.config.Database.Type == "" { + return nil + } + + dbConfig := storage.DBConfig{ + Type: s.config.Database.Type, + Connection: s.config.Database.Connection, + Host: s.config.Database.Host, + Port: s.config.Database.Port, + Username: s.config.Database.Username, + Password: s.config.Database.Password, + Database: s.config.Database.Database, + } + + db, err := storage.NewDB(dbConfig) + if err != nil { + return err + } + + s.db = db + s.logger.Info("database initialized", "type", s.config.Database.Type) + + // Add cleanup function + s.cleanupFuncs = append(s.cleanupFuncs, func() { + s.logger.Info("closing database connection...") + if err := s.db.Close(); err != nil { + s.logger.Error("failed to close database", "error", err) + } else { + s.logger.Info("database connection closed") + } + }) + + return nil +} + +// initDatabaseSchema initializes the database schema +func (s *Server) 
initDatabaseSchema() error { + if s.db == nil { + return nil + } + + schema, err := storage.SchemaForDBType(s.config.Database.Type) + if err != nil { + return err + } + + if err := s.db.Initialize(schema); err != nil { + return err + } + + s.logger.Info("database schema initialized", "type", s.config.Database.Type) + return nil +} + +// initSecurity initializes security middleware +func (s *Server) initSecurity() { + authConfig := s.config.BuildAuthConfig() + rlOpts := s.buildRateLimitOptions() + s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts) +} + +// buildRateLimitOptions builds rate limit options from configuration +func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions { + if !s.config.Security.RateLimit.Enabled || s.config.Security.RateLimit.RequestsPerMinute <= 0 { + return nil + } + + return &middleware.RateLimitOptions{ + RequestsPerMinute: s.config.Security.RateLimit.RequestsPerMinute, + BurstSize: s.config.Security.RateLimit.BurstSize, + } +} + +// initJupyterServiceManager initializes the Jupyter service manager +func (s *Server) initJupyterServiceManager() { + serviceConfig := jupyter.GetDefaultServiceConfig() + + sm, err := jupyter.NewServiceManager(s.logger, serviceConfig) + if err != nil { + s.logger.Error("failed to initialize Jupyter service manager", "error", err) + return + } + + s.jupyterServiceMgr = sm + s.logger.Info("jupyter service manager initialized") +} + +// initAuditLogger initializes the audit logger +func (s *Server) initAuditLogger() *audit.Logger { + if !s.config.Security.AuditLogging.Enabled || s.config.Security.AuditLogging.LogPath == "" { + return nil + } + + al, err := audit.NewLogger( + s.config.Security.AuditLogging.Enabled, + s.config.Security.AuditLogging.LogPath, + s.logger, + ) + if err != nil { + s.logger.Warn("failed to initialize audit logger", "error", err) + return nil + } + + s.auditLogger = al + + // Add cleanup function + s.cleanupFuncs = append(s.cleanupFuncs, 
func() { + s.logger.Info("closing audit logger...") + if err := al.Close(); err != nil { + s.logger.Error("failed to close audit logger", "error", err) + } + }) + + return al +} + +// getSecurityConfig extracts security config from server config +func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig { + return &config.SecurityConfig{ + AllowedOrigins: cfg.Security.AllowedOrigins, + ProductionMode: cfg.Security.ProductionMode, + APIKeyRotationDays: cfg.Security.APIKeyRotationDays, + AuditLogging: config.AuditLoggingConfig{ + Enabled: cfg.Security.AuditLogging.Enabled, + LogPath: cfg.Security.AuditLogging.LogPath, + }, + IPWhitelist: cfg.Security.IPWhitelist, + } +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go new file mode 100644 index 0000000..dede950 --- /dev/null +++ b/internal/api/middleware.go @@ -0,0 +1,31 @@ +package api + +import ( + "net/http" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/middleware" +) + +// wrapWithMiddleware wraps the handler with security middleware +func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip auth for WebSocket and health endpoints + if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") { + mux.ServeHTTP(w, r) + return + } + + handler := s.sec.APIKeyAuth(mux) + handler = s.sec.RateLimit(handler) + handler = middleware.SecurityHeaders(handler) + handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler) + handler = middleware.RequestTimeout(30 * time.Second)(handler) + handler = middleware.AuditLogger(handler) + if len(s.config.Security.IPWhitelist) > 0 { + handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler) + } + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/api/server.go b/internal/api/server.go index 001c47e..8e18fce 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -5,12 +5,10 @@ import ( "net/http" "os" 
"os/signal" - "strings" "syscall" "time" "github.com/jfraeys/fetch_ml/internal/audit" - "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" @@ -63,202 +61,6 @@ func NewServer(configPath string) (*Server, error) { return server, nil } -// initializeComponents initializes all server components -func (s *Server) initializeComponents() error { - // Setup logging - if err := s.config.EnsureLogDirectory(); err != nil { - return err - } - s.logger = s.setupLogger() - - // Initialize experiment manager - if err := s.initExperimentManager(); err != nil { - return err - } - - // Initialize task queue - if err := s.initTaskQueue(); err != nil { - return err - } - - // Initialize database - if err := s.initDatabase(); err != nil { - return err - } - - // Initialize database schema (if DB enabled) - if err := s.initDatabaseSchema(); err != nil { - return err - } - - // Initialize security - s.initSecurity() - - // Initialize Jupyter service manager - s.initJupyterServiceManager() - - // Initialize handlers - s.handlers = NewHandlers(s.expManager, nil, s.logger) - - return nil -} - -func (s *Server) initDatabaseSchema() error { - if s.db == nil { - return nil - } - - schema, err := storage.SchemaForDBType(s.config.Database.Type) - if err != nil { - return err - } - - if err := s.db.Initialize(schema); err != nil { - return err - } - - s.logger.Info("database schema initialized", "type", s.config.Database.Type) - return nil -} - -// setupLogger creates and configures the logger -func (s *Server) setupLogger() *logging.Logger { - logger := logging.NewLoggerFromConfig(s.config.Logging) - ctx := logging.EnsureTrace(context.Background()) - return logger.Component(ctx, "api-server") -} - -// initExperimentManager initializes the experiment manager -func (s *Server) initExperimentManager() error { - s.expManager = 
experiment.NewManager(s.config.BasePath) - if err := s.expManager.Initialize(); err != nil { - return err - } - - s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath) - return nil -} - -// initTaskQueue initializes the task queue -func (s *Server) initTaskQueue() error { - backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend)) - if backend == "" { - backend = "redis" - } - redisAddr := strings.TrimSpace(s.config.Redis.Addr) - if redisAddr == "" { - redisAddr = "localhost:6379" - } - if strings.TrimSpace(s.config.Redis.URL) != "" { - redisAddr = strings.TrimSpace(s.config.Redis.URL) - } - - backendCfg := queue.BackendConfig{ - Backend: queue.QueueBackend(backend), - RedisAddr: redisAddr, - RedisPassword: s.config.Redis.Password, - RedisDB: s.config.Redis.DB, - SQLitePath: s.config.Queue.SQLitePath, - FilesystemPath: s.config.Queue.FilesystemPath, - FallbackToFilesystem: s.config.Queue.FallbackToFilesystem, - MetricsFlushInterval: 0, - } - - taskQueue, err := queue.NewBackend(backendCfg) - if err != nil { - return err - } - - s.taskQueue = taskQueue - if backend == "sqlite" { - s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath) - } else { - s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr) - } - - // Add cleanup function - s.cleanupFuncs = append(s.cleanupFuncs, func() { - s.logger.Info("stopping task queue...") - if err := s.taskQueue.Close(); err != nil { - s.logger.Error("failed to stop task queue", "error", err) - } else { - s.logger.Info("task queue stopped") - } - }) - - return nil -} - -// initDatabase initializes the database connection -func (s *Server) initDatabase() error { - if s.config.Database.Type == "" { - return nil - } - - dbConfig := storage.DBConfig{ - Type: s.config.Database.Type, - Connection: s.config.Database.Connection, - Host: s.config.Database.Host, - Port: s.config.Database.Port, - Username: 
s.config.Database.Username, - Password: s.config.Database.Password, - Database: s.config.Database.Database, - } - - db, err := storage.NewDB(dbConfig) - if err != nil { - return err - } - - s.db = db - s.logger.Info("database initialized", "type", s.config.Database.Type) - - // Add cleanup function - s.cleanupFuncs = append(s.cleanupFuncs, func() { - s.logger.Info("closing database connection...") - if err := s.db.Close(); err != nil { - s.logger.Error("failed to close database", "error", err) - } else { - s.logger.Info("database connection closed") - } - }) - - return nil -} - -// initSecurity initializes security middleware -func (s *Server) initSecurity() { - authConfig := s.config.BuildAuthConfig() - rlOpts := s.buildRateLimitOptions() - s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts) -} - -// buildRateLimitOptions builds rate limit options from configuration -func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions { - if !s.config.Security.RateLimit.Enabled || s.config.Security.RateLimit.RequestsPerMinute <= 0 { - return nil - } - - return &middleware.RateLimitOptions{ - RequestsPerMinute: s.config.Security.RateLimit.RequestsPerMinute, - BurstSize: s.config.Security.RateLimit.BurstSize, - } -} - -// initJupyterServiceManager initializes the Jupyter service manager -func (s *Server) initJupyterServiceManager() { - serviceConfig := jupyter.GetDefaultServiceConfig() - - sm, err := jupyter.NewServiceManager(s.logger, serviceConfig) - if err != nil { - s.logger.Error("failed to initialize Jupyter service manager", "error", err) - return - } - - s.jupyterServiceMgr = sm - s.logger.Info("jupyter service manager initialized") -} - // setupHTTPServer sets up the HTTP server and routes func (s *Server) setupHTTPServer() { mux := http.NewServeMux() @@ -286,28 +88,7 @@ func (s *Server) setupHTTPServer() { } // Initialize audit logger - var auditLogger *audit.Logger - if s.config.Security.AuditLogging.Enabled && 
s.config.Security.AuditLogging.LogPath != "" { - al, err := audit.NewLogger( - s.config.Security.AuditLogging.Enabled, - s.config.Security.AuditLogging.LogPath, - s.logger, - ) - if err != nil { - s.logger.Warn("failed to initialize audit logger", "error", err) - } else { - auditLogger = al - s.auditLogger = al - - // Add cleanup function - s.cleanupFuncs = append(s.cleanupFuncs, func() { - s.logger.Info("closing audit logger...") - if err := auditLogger.Close(); err != nil { - s.logger.Error("failed to close audit logger", "error", err) - } - }) - } - } + auditLogger := s.initAuditLogger() // Register WebSocket handler with security config and audit logger securityCfg := getSecurityConfig(s.config) @@ -342,42 +123,6 @@ func (s *Server) setupHTTPServer() { } } -// getSecurityConfig extracts security config from server config -func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig { - return &config.SecurityConfig{ - AllowedOrigins: cfg.Security.AllowedOrigins, - ProductionMode: cfg.Security.ProductionMode, - APIKeyRotationDays: cfg.Security.APIKeyRotationDays, - AuditLogging: config.AuditLoggingConfig{ - Enabled: cfg.Security.AuditLogging.Enabled, - LogPath: cfg.Security.AuditLogging.LogPath, - }, - IPWhitelist: cfg.Security.IPWhitelist, - } -} - -// wrapWithMiddleware wraps the handler with security middleware -func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Skip auth for WebSocket and health endpoints - if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") { - mux.ServeHTTP(w, r) - return - } - - handler := s.sec.APIKeyAuth(mux) - handler = s.sec.RateLimit(handler) - handler = middleware.SecurityHeaders(handler) - handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler) - handler = middleware.RequestTimeout(30 * time.Second)(handler) - handler = middleware.AuditLogger(handler) - if len(s.config.Security.IPWhitelist) > 0 { - 
handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler) - } - handler.ServeHTTP(w, r) - }) -} - // Start starts the server func (s *Server) Start() error { if !s.config.Server.TLS.Enabled {