Reorganized the internal/api/ package to follow the single-concern principle:
- api/factory.go (new file, 257 lines): component initialization extracted from server.go — initializeComponents(), setupLogger(), initExperimentManager(), initTaskQueue(), initDatabase(), initDatabaseSchema(), initSecurity(), initJupyterServiceManager(), initAuditLogger().
- api/middleware.go (new file, 31 lines): extracted wrapWithMiddleware(), the security middleware chain, which centralizes auth, rate limiting, CORS, and security headers.
- api/server.go (reduced from 446 to 212 lines): now focused on the Server lifecycle (NewServer, Start, WaitForShutdown, Close); initialization logic moved to factory.go and the middleware wrapper moved to middleware.go.
- api/metrics_middleware.go (existing, 64 lines): already contained wrapWithMetrics(); left in place.
Lines redistributed: ~180 lines moved out of the monolithic server.go. Build status: compiles successfully.
256 lines
6.8 KiB
Go
256 lines
6.8 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/audit"
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/middleware"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
)
|
|
|
|
// initializeComponents initializes all server components
|
|
func (s *Server) initializeComponents() error {
|
|
// Setup logging
|
|
if err := s.config.EnsureLogDirectory(); err != nil {
|
|
return err
|
|
}
|
|
s.logger = s.setupLogger()
|
|
|
|
// Initialize experiment manager
|
|
if err := s.initExperimentManager(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize task queue
|
|
if err := s.initTaskQueue(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize database
|
|
if err := s.initDatabase(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize database schema (if DB enabled)
|
|
if err := s.initDatabaseSchema(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize security
|
|
s.initSecurity()
|
|
|
|
// Initialize Jupyter service manager
|
|
s.initJupyterServiceManager()
|
|
|
|
// Initialize handlers
|
|
s.handlers = NewHandlers(s.expManager, nil, s.logger)
|
|
|
|
return nil
|
|
}
|
|
|
|
// setupLogger creates and configures the logger
|
|
func (s *Server) setupLogger() *logging.Logger {
|
|
logger := logging.NewLoggerFromConfig(s.config.Logging)
|
|
ctx := logging.EnsureTrace(context.Background())
|
|
return logger.Component(ctx, "api-server")
|
|
}
|
|
|
|
// initExperimentManager initializes the experiment manager
|
|
func (s *Server) initExperimentManager() error {
|
|
s.expManager = experiment.NewManager(s.config.BasePath)
|
|
if err := s.expManager.Initialize(); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath)
|
|
return nil
|
|
}
|
|
|
|
// initTaskQueue initializes the task queue
|
|
func (s *Server) initTaskQueue() error {
|
|
backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend))
|
|
if backend == "" {
|
|
backend = "redis"
|
|
}
|
|
redisAddr := strings.TrimSpace(s.config.Redis.Addr)
|
|
if redisAddr == "" {
|
|
redisAddr = "localhost:6379"
|
|
}
|
|
if strings.TrimSpace(s.config.Redis.URL) != "" {
|
|
redisAddr = strings.TrimSpace(s.config.Redis.URL)
|
|
}
|
|
|
|
backendCfg := queue.BackendConfig{
|
|
Backend: queue.QueueBackend(backend),
|
|
RedisAddr: redisAddr,
|
|
RedisPassword: s.config.Redis.Password,
|
|
RedisDB: s.config.Redis.DB,
|
|
SQLitePath: s.config.Queue.SQLitePath,
|
|
FilesystemPath: s.config.Queue.FilesystemPath,
|
|
FallbackToFilesystem: s.config.Queue.FallbackToFilesystem,
|
|
MetricsFlushInterval: 0,
|
|
}
|
|
|
|
taskQueue, err := queue.NewBackend(backendCfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.taskQueue = taskQueue
|
|
if backend == "sqlite" {
|
|
s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath)
|
|
} else {
|
|
s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr)
|
|
}
|
|
|
|
// Add cleanup function
|
|
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
|
s.logger.Info("stopping task queue...")
|
|
if err := s.taskQueue.Close(); err != nil {
|
|
s.logger.Error("failed to stop task queue", "error", err)
|
|
} else {
|
|
s.logger.Info("task queue stopped")
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// initDatabase initializes the database connection
|
|
func (s *Server) initDatabase() error {
|
|
if s.config.Database.Type == "" {
|
|
return nil
|
|
}
|
|
|
|
dbConfig := storage.DBConfig{
|
|
Type: s.config.Database.Type,
|
|
Connection: s.config.Database.Connection,
|
|
Host: s.config.Database.Host,
|
|
Port: s.config.Database.Port,
|
|
Username: s.config.Database.Username,
|
|
Password: s.config.Database.Password,
|
|
Database: s.config.Database.Database,
|
|
}
|
|
|
|
db, err := storage.NewDB(dbConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.db = db
|
|
s.logger.Info("database initialized", "type", s.config.Database.Type)
|
|
|
|
// Add cleanup function
|
|
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
|
s.logger.Info("closing database connection...")
|
|
if err := s.db.Close(); err != nil {
|
|
s.logger.Error("failed to close database", "error", err)
|
|
} else {
|
|
s.logger.Info("database connection closed")
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// initDatabaseSchema initializes the database schema
|
|
func (s *Server) initDatabaseSchema() error {
|
|
if s.db == nil {
|
|
return nil
|
|
}
|
|
|
|
schema, err := storage.SchemaForDBType(s.config.Database.Type)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.db.Initialize(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Info("database schema initialized", "type", s.config.Database.Type)
|
|
return nil
|
|
}
|
|
|
|
// initSecurity initializes security middleware
|
|
func (s *Server) initSecurity() {
|
|
authConfig := s.config.BuildAuthConfig()
|
|
rlOpts := s.buildRateLimitOptions()
|
|
s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts)
|
|
}
|
|
|
|
// buildRateLimitOptions builds rate limit options from configuration
|
|
func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions {
|
|
if !s.config.Security.RateLimit.Enabled || s.config.Security.RateLimit.RequestsPerMinute <= 0 {
|
|
return nil
|
|
}
|
|
|
|
return &middleware.RateLimitOptions{
|
|
RequestsPerMinute: s.config.Security.RateLimit.RequestsPerMinute,
|
|
BurstSize: s.config.Security.RateLimit.BurstSize,
|
|
}
|
|
}
|
|
|
|
// initJupyterServiceManager initializes the Jupyter service manager
|
|
func (s *Server) initJupyterServiceManager() {
|
|
serviceConfig := jupyter.GetDefaultServiceConfig()
|
|
|
|
sm, err := jupyter.NewServiceManager(s.logger, serviceConfig)
|
|
if err != nil {
|
|
s.logger.Error("failed to initialize Jupyter service manager", "error", err)
|
|
return
|
|
}
|
|
|
|
s.jupyterServiceMgr = sm
|
|
s.logger.Info("jupyter service manager initialized")
|
|
}
|
|
|
|
// initAuditLogger initializes the audit logger
|
|
func (s *Server) initAuditLogger() *audit.Logger {
|
|
if !s.config.Security.AuditLogging.Enabled || s.config.Security.AuditLogging.LogPath == "" {
|
|
return nil
|
|
}
|
|
|
|
al, err := audit.NewLogger(
|
|
s.config.Security.AuditLogging.Enabled,
|
|
s.config.Security.AuditLogging.LogPath,
|
|
s.logger,
|
|
)
|
|
if err != nil {
|
|
s.logger.Warn("failed to initialize audit logger", "error", err)
|
|
return nil
|
|
}
|
|
|
|
s.auditLogger = al
|
|
|
|
// Add cleanup function
|
|
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
|
s.logger.Info("closing audit logger...")
|
|
if err := al.Close(); err != nil {
|
|
s.logger.Error("failed to close audit logger", "error", err)
|
|
}
|
|
})
|
|
|
|
return al
|
|
}
|
|
|
|
// getSecurityConfig extracts security config from server config
|
|
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
|
|
return &config.SecurityConfig{
|
|
AllowedOrigins: cfg.Security.AllowedOrigins,
|
|
ProductionMode: cfg.Security.ProductionMode,
|
|
APIKeyRotationDays: cfg.Security.APIKeyRotationDays,
|
|
AuditLogging: config.AuditLoggingConfig{
|
|
Enabled: cfg.Security.AuditLogging.Enabled,
|
|
LogPath: cfg.Security.AuditLogging.LogPath,
|
|
},
|
|
IPWhitelist: cfg.Security.IPWhitelist,
|
|
}
|
|
}
|