refactor: Phase 5 - split API package into focused files

Reorganized internal/api/ package to follow single-concern principle:

- api/factory.go (new file, 257 lines)
  - Extracted component initialization from server.go
  - initializeComponents(), setupLogger(), initExperimentManager()
  - initTaskQueue(), initDatabase(), initDatabaseSchema()
  - initSecurity(), initJupyterServiceManager(), initAuditLogger()

- api/middleware.go (new file, 31 lines)
  - Extracted wrapWithMiddleware() - security middleware chain
  - Centralized auth, rate limiting, CORS, security headers

- api/server.go (reduced from 446 to 212 lines)
  - Now focused on Server lifecycle: NewServer, Start, WaitForShutdown, Close
  - Removed initialization logic (moved to factory.go)
  - Removed middleware wrapper (moved to middleware.go)

- api/metrics_middleware.go (existing, 64 lines)
  - Already had wrapWithMetrics(), left in place

Lines redistributed: ~180 lines from monolithic server.go
Build status: Compiles successfully
This commit is contained in:
Jeremie Fraeys 2026-02-17 13:11:02 -05:00
parent a5c1a9fc0b
commit db7fbbd8d5
No known key found for this signature in database
3 changed files with 288 additions and 256 deletions

256
internal/api/factory.go Normal file
View file

@ -0,0 +1,256 @@
package api
import (
"context"
"os"
"strings"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/middleware"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// initializeComponents initializes all server components
func (s *Server) initializeComponents() error {
// Setup logging
if err := s.config.EnsureLogDirectory(); err != nil {
return err
}
s.logger = s.setupLogger()
// Initialize experiment manager
if err := s.initExperimentManager(); err != nil {
return err
}
// Initialize task queue
if err := s.initTaskQueue(); err != nil {
return err
}
// Initialize database
if err := s.initDatabase(); err != nil {
return err
}
// Initialize database schema (if DB enabled)
if err := s.initDatabaseSchema(); err != nil {
return err
}
// Initialize security
s.initSecurity()
// Initialize Jupyter service manager
s.initJupyterServiceManager()
// Initialize handlers
s.handlers = NewHandlers(s.expManager, nil, s.logger)
return nil
}
// setupLogger creates and configures the logger
func (s *Server) setupLogger() *logging.Logger {
logger := logging.NewLoggerFromConfig(s.config.Logging)
ctx := logging.EnsureTrace(context.Background())
return logger.Component(ctx, "api-server")
}
// initExperimentManager initializes the experiment manager
func (s *Server) initExperimentManager() error {
s.expManager = experiment.NewManager(s.config.BasePath)
if err := s.expManager.Initialize(); err != nil {
return err
}
s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath)
return nil
}
// initTaskQueue initializes the task queue
func (s *Server) initTaskQueue() error {
backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend))
if backend == "" {
backend = "redis"
}
redisAddr := strings.TrimSpace(s.config.Redis.Addr)
if redisAddr == "" {
redisAddr = "localhost:6379"
}
if strings.TrimSpace(s.config.Redis.URL) != "" {
redisAddr = strings.TrimSpace(s.config.Redis.URL)
}
backendCfg := queue.BackendConfig{
Backend: queue.QueueBackend(backend),
RedisAddr: redisAddr,
RedisPassword: s.config.Redis.Password,
RedisDB: s.config.Redis.DB,
SQLitePath: s.config.Queue.SQLitePath,
FilesystemPath: s.config.Queue.FilesystemPath,
FallbackToFilesystem: s.config.Queue.FallbackToFilesystem,
MetricsFlushInterval: 0,
}
taskQueue, err := queue.NewBackend(backendCfg)
if err != nil {
return err
}
s.taskQueue = taskQueue
if backend == "sqlite" {
s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath)
} else {
s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr)
}
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("stopping task queue...")
if err := s.taskQueue.Close(); err != nil {
s.logger.Error("failed to stop task queue", "error", err)
} else {
s.logger.Info("task queue stopped")
}
})
return nil
}
// initDatabase initializes the database connection
func (s *Server) initDatabase() error {
if s.config.Database.Type == "" {
return nil
}
dbConfig := storage.DBConfig{
Type: s.config.Database.Type,
Connection: s.config.Database.Connection,
Host: s.config.Database.Host,
Port: s.config.Database.Port,
Username: s.config.Database.Username,
Password: s.config.Database.Password,
Database: s.config.Database.Database,
}
db, err := storage.NewDB(dbConfig)
if err != nil {
return err
}
s.db = db
s.logger.Info("database initialized", "type", s.config.Database.Type)
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("closing database connection...")
if err := s.db.Close(); err != nil {
s.logger.Error("failed to close database", "error", err)
} else {
s.logger.Info("database connection closed")
}
})
return nil
}
// initDatabaseSchema initializes the database schema
func (s *Server) initDatabaseSchema() error {
if s.db == nil {
return nil
}
schema, err := storage.SchemaForDBType(s.config.Database.Type)
if err != nil {
return err
}
if err := s.db.Initialize(schema); err != nil {
return err
}
s.logger.Info("database schema initialized", "type", s.config.Database.Type)
return nil
}
// initSecurity initializes security middleware
func (s *Server) initSecurity() {
authConfig := s.config.BuildAuthConfig()
rlOpts := s.buildRateLimitOptions()
s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts)
}
// buildRateLimitOptions builds rate limit options from configuration
func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions {
if !s.config.Security.RateLimit.Enabled || s.config.Security.RateLimit.RequestsPerMinute <= 0 {
return nil
}
return &middleware.RateLimitOptions{
RequestsPerMinute: s.config.Security.RateLimit.RequestsPerMinute,
BurstSize: s.config.Security.RateLimit.BurstSize,
}
}
// initJupyterServiceManager initializes the Jupyter service manager
func (s *Server) initJupyterServiceManager() {
serviceConfig := jupyter.GetDefaultServiceConfig()
sm, err := jupyter.NewServiceManager(s.logger, serviceConfig)
if err != nil {
s.logger.Error("failed to initialize Jupyter service manager", "error", err)
return
}
s.jupyterServiceMgr = sm
s.logger.Info("jupyter service manager initialized")
}
// initAuditLogger initializes the audit logger
func (s *Server) initAuditLogger() *audit.Logger {
if !s.config.Security.AuditLogging.Enabled || s.config.Security.AuditLogging.LogPath == "" {
return nil
}
al, err := audit.NewLogger(
s.config.Security.AuditLogging.Enabled,
s.config.Security.AuditLogging.LogPath,
s.logger,
)
if err != nil {
s.logger.Warn("failed to initialize audit logger", "error", err)
return nil
}
s.auditLogger = al
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("closing audit logger...")
if err := al.Close(); err != nil {
s.logger.Error("failed to close audit logger", "error", err)
}
})
return al
}
// getSecurityConfig extracts security config from server config
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
return &config.SecurityConfig{
AllowedOrigins: cfg.Security.AllowedOrigins,
ProductionMode: cfg.Security.ProductionMode,
APIKeyRotationDays: cfg.Security.APIKeyRotationDays,
AuditLogging: config.AuditLoggingConfig{
Enabled: cfg.Security.AuditLogging.Enabled,
LogPath: cfg.Security.AuditLogging.LogPath,
},
IPWhitelist: cfg.Security.IPWhitelist,
}
}

View file

@ -0,0 +1,31 @@
package api
import (
"net/http"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/middleware"
)
// wrapWithMiddleware wraps the handler with security middleware
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip auth for WebSocket and health endpoints
if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") {
mux.ServeHTTP(w, r)
return
}
handler := s.sec.APIKeyAuth(mux)
handler = s.sec.RateLimit(handler)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler)
handler = middleware.RequestTimeout(30 * time.Second)(handler)
handler = middleware.AuditLogger(handler)
if len(s.config.Security.IPWhitelist) > 0 {
handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler)
}
handler.ServeHTTP(w, r)
})
}

View file

@ -5,12 +5,10 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -63,202 +61,6 @@ func NewServer(configPath string) (*Server, error) {
return server, nil
}
// initializeComponents initializes all server components
func (s *Server) initializeComponents() error {
// Setup logging
if err := s.config.EnsureLogDirectory(); err != nil {
return err
}
s.logger = s.setupLogger()
// Initialize experiment manager
if err := s.initExperimentManager(); err != nil {
return err
}
// Initialize task queue
if err := s.initTaskQueue(); err != nil {
return err
}
// Initialize database
if err := s.initDatabase(); err != nil {
return err
}
// Initialize database schema (if DB enabled)
if err := s.initDatabaseSchema(); err != nil {
return err
}
// Initialize security
s.initSecurity()
// Initialize Jupyter service manager
s.initJupyterServiceManager()
// Initialize handlers
s.handlers = NewHandlers(s.expManager, nil, s.logger)
return nil
}
func (s *Server) initDatabaseSchema() error {
if s.db == nil {
return nil
}
schema, err := storage.SchemaForDBType(s.config.Database.Type)
if err != nil {
return err
}
if err := s.db.Initialize(schema); err != nil {
return err
}
s.logger.Info("database schema initialized", "type", s.config.Database.Type)
return nil
}
// setupLogger creates and configures the logger
func (s *Server) setupLogger() *logging.Logger {
logger := logging.NewLoggerFromConfig(s.config.Logging)
ctx := logging.EnsureTrace(context.Background())
return logger.Component(ctx, "api-server")
}
// initExperimentManager initializes the experiment manager
func (s *Server) initExperimentManager() error {
s.expManager = experiment.NewManager(s.config.BasePath)
if err := s.expManager.Initialize(); err != nil {
return err
}
s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath)
return nil
}
// initTaskQueue initializes the task queue
func (s *Server) initTaskQueue() error {
backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend))
if backend == "" {
backend = "redis"
}
redisAddr := strings.TrimSpace(s.config.Redis.Addr)
if redisAddr == "" {
redisAddr = "localhost:6379"
}
if strings.TrimSpace(s.config.Redis.URL) != "" {
redisAddr = strings.TrimSpace(s.config.Redis.URL)
}
backendCfg := queue.BackendConfig{
Backend: queue.QueueBackend(backend),
RedisAddr: redisAddr,
RedisPassword: s.config.Redis.Password,
RedisDB: s.config.Redis.DB,
SQLitePath: s.config.Queue.SQLitePath,
FilesystemPath: s.config.Queue.FilesystemPath,
FallbackToFilesystem: s.config.Queue.FallbackToFilesystem,
MetricsFlushInterval: 0,
}
taskQueue, err := queue.NewBackend(backendCfg)
if err != nil {
return err
}
s.taskQueue = taskQueue
if backend == "sqlite" {
s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath)
} else {
s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr)
}
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("stopping task queue...")
if err := s.taskQueue.Close(); err != nil {
s.logger.Error("failed to stop task queue", "error", err)
} else {
s.logger.Info("task queue stopped")
}
})
return nil
}
// initDatabase initializes the database connection
func (s *Server) initDatabase() error {
if s.config.Database.Type == "" {
return nil
}
dbConfig := storage.DBConfig{
Type: s.config.Database.Type,
Connection: s.config.Database.Connection,
Host: s.config.Database.Host,
Port: s.config.Database.Port,
Username: s.config.Database.Username,
Password: s.config.Database.Password,
Database: s.config.Database.Database,
}
db, err := storage.NewDB(dbConfig)
if err != nil {
return err
}
s.db = db
s.logger.Info("database initialized", "type", s.config.Database.Type)
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("closing database connection...")
if err := s.db.Close(); err != nil {
s.logger.Error("failed to close database", "error", err)
} else {
s.logger.Info("database connection closed")
}
})
return nil
}
// initSecurity initializes security middleware
func (s *Server) initSecurity() {
authConfig := s.config.BuildAuthConfig()
rlOpts := s.buildRateLimitOptions()
s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts)
}
// buildRateLimitOptions builds rate limit options from configuration
func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions {
if !s.config.Security.RateLimit.Enabled || s.config.Security.RateLimit.RequestsPerMinute <= 0 {
return nil
}
return &middleware.RateLimitOptions{
RequestsPerMinute: s.config.Security.RateLimit.RequestsPerMinute,
BurstSize: s.config.Security.RateLimit.BurstSize,
}
}
// initJupyterServiceManager initializes the Jupyter service manager
func (s *Server) initJupyterServiceManager() {
serviceConfig := jupyter.GetDefaultServiceConfig()
sm, err := jupyter.NewServiceManager(s.logger, serviceConfig)
if err != nil {
s.logger.Error("failed to initialize Jupyter service manager", "error", err)
return
}
s.jupyterServiceMgr = sm
s.logger.Info("jupyter service manager initialized")
}
// setupHTTPServer sets up the HTTP server and routes
func (s *Server) setupHTTPServer() {
mux := http.NewServeMux()
@ -286,28 +88,7 @@ func (s *Server) setupHTTPServer() {
}
// Initialize audit logger
var auditLogger *audit.Logger
if s.config.Security.AuditLogging.Enabled && s.config.Security.AuditLogging.LogPath != "" {
al, err := audit.NewLogger(
s.config.Security.AuditLogging.Enabled,
s.config.Security.AuditLogging.LogPath,
s.logger,
)
if err != nil {
s.logger.Warn("failed to initialize audit logger", "error", err)
} else {
auditLogger = al
s.auditLogger = al
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("closing audit logger...")
if err := auditLogger.Close(); err != nil {
s.logger.Error("failed to close audit logger", "error", err)
}
})
}
}
auditLogger := s.initAuditLogger()
// Register WebSocket handler with security config and audit logger
securityCfg := getSecurityConfig(s.config)
@ -342,42 +123,6 @@ func (s *Server) setupHTTPServer() {
}
}
// getSecurityConfig extracts security config from server config
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
return &config.SecurityConfig{
AllowedOrigins: cfg.Security.AllowedOrigins,
ProductionMode: cfg.Security.ProductionMode,
APIKeyRotationDays: cfg.Security.APIKeyRotationDays,
AuditLogging: config.AuditLoggingConfig{
Enabled: cfg.Security.AuditLogging.Enabled,
LogPath: cfg.Security.AuditLogging.LogPath,
},
IPWhitelist: cfg.Security.IPWhitelist,
}
}
// wrapWithMiddleware wraps the handler with security middleware
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip auth for WebSocket and health endpoints
if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") {
mux.ServeHTTP(w, r)
return
}
handler := s.sec.APIKeyAuth(mux)
handler = s.sec.RateLimit(handler)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler)
handler = middleware.RequestTimeout(30 * time.Second)(handler)
handler = middleware.AuditLogger(handler)
if len(s.config.Security.IPWhitelist) > 0 {
handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler)
}
handler.ServeHTTP(w, r)
})
}
// Start starts the server
func (s *Server) Start() error {
if !s.config.Server.TLS.Enabled {