fetch_ml/internal/api/server.go
Jeremie Fraeys db7fbbd8d5
refactor: Phase 5 - split API package into focused files
Reorganized internal/api/ package to follow single-concern principle:

- api/factory.go (new file, 257 lines)
  - Extracted component initialization from server.go
  - initializeComponents(), setupLogger(), initExperimentManager()
  - initTaskQueue(), initDatabase(), initDatabaseSchema()
  - initSecurity(), initJupyterServiceManager(), initAuditLogger()

- api/middleware.go (new file, 31 lines)
  - Extracted wrapWithMiddleware() - security middleware chain
  - Centralized auth, rate limiting, CORS, security headers

- api/server.go (reduced from 446 to 212 lines)
  - Now focused on Server lifecycle: NewServer, Start, WaitForShutdown, Close
  - Removed initialization logic (moved to factory.go)
  - Removed middleware wrapper (moved to middleware.go)

- api/metrics_middleware.go (existing, 64 lines)
  - Already had wrapWithMetrics(), left in place

Lines redistributed: ~180 lines from monolithic server.go
Build status: Compiles successfully
2026-02-17 13:11:02 -05:00

190 lines
4.7 KiB
Go

package api
import (
"context"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/middleware"
"github.com/jfraeys/fetch_ml/internal/prommetrics"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// Server represents the API server.
//
// It bundles the loaded configuration, the underlying http.Server and the
// components handed to the request handlers. NewServer stores config, calls
// initializeComponents (defined outside this file; presumably it populates the
// component fields below) and then setupHTTPServer, which builds httpServer.
type Server struct {
// config is the configuration loaded and validated by NewServer.
config *ServerConfig
// httpServer is built by setupHTTPServer and driven by Start/WaitForShutdown.
httpServer *http.Server
// logger receives the lifecycle log messages emitted by this type.
logger *logging.Logger
// expManager is passed to the WebSocket handler.
expManager *experiment.Manager
// taskQueue is passed to the WebSocket handler.
taskQueue queue.Backend
// db is passed to the WebSocket handler.
db *storage.DB
// handlers registers the HTTP routes and serves /health/ok.
handlers *Handlers
// sec is not referenced in this file; presumably consumed by
// wrapWithMiddleware — TODO confirm.
sec *middleware.SecurityMiddleware
// cleanupFuncs are run, in slice order, by WaitForShutdown and by Close.
cleanupFuncs []func()
// jupyterServiceMgr is passed to the WebSocket handler.
jupyterServiceMgr *jupyter.ServiceManager
// auditLogger is not assigned in this file: setupHTTPServer hands the result
// of initAuditLogger to the WebSocket handler through a local variable.
// NOTE(review): confirm whether this field is populated elsewhere.
auditLogger *audit.Logger
// promMetrics is set by setupHTTPServer only when Prometheus monitoring is
// enabled in the configuration; otherwise it stays nil.
promMetrics *prommetrics.Metrics // Prometheus metrics
}
// NewServer builds a ready-to-start API server from the configuration file at
// configPath.
//
// The configuration is loaded and validated first; the server components are
// then initialized and the HTTP server is assembled. The first error
// encountered is returned unchanged together with a nil *Server.
func NewServer(configPath string) (*Server, error) {
	cfg, err := LoadServerConfig(configPath)
	if err != nil {
		return nil, err
	}
	if err = cfg.Validate(); err != nil {
		return nil, err
	}

	srv := &Server{config: cfg}
	if err = srv.initializeComponents(); err != nil {
		return nil, err
	}

	// Routes and the http.Server can only be built once the components exist.
	srv.setupHTTPServer()
	return srv, nil
}
// setupHTTPServer builds the route table and the underlying http.Server.
//
// Routes are registered in this order: the optional Prometheus endpoint, the
// optional health check routes, the WebSocket endpoint at /ws, and finally the
// routes owned by s.handlers. The finished mux is wrapped by
// wrapWithMiddleware and then by wrapWithMetrics before it is installed as the
// server handler.
func (s *Server) setupHTTPServer() {
	const (
		readWriteTimeout = 30 * time.Second
		idleTimeout      = 120 * time.Second
	)

	router := http.NewServeMux()

	// Prometheus is optional; the endpoint falls back to /metrics when the
	// configuration leaves the path empty.
	if s.config.Monitoring.Prometheus.Enabled {
		s.promMetrics = prommetrics.New()
		s.logger.Info("prometheus metrics initialized")

		endpoint := "/metrics"
		if configured := s.config.Monitoring.Prometheus.Path; configured != "" {
			endpoint = configured
		}
		router.Handle(endpoint, s.promMetrics.Handler())
		s.logger.Info("metrics endpoint registered", "path", endpoint)
	}

	// Health checks are optional as well; /health/ok is served by s.handlers.
	if s.config.Monitoring.HealthChecks.Enabled {
		health := NewHealthHandler(s)
		health.RegisterRoutes(router)
		router.HandleFunc("/health/ok", s.handlers.handleHealth)
		s.logger.Info("health check endpoints registered")
	}

	// The WebSocket handler receives the audit logger and the security
	// settings derived from the configuration.
	audit := s.initAuditLogger()
	secCfg := getSecurityConfig(s.config)
	ws := NewWSHandler(
		s.config.BuildAuthConfig(),
		s.logger,
		s.expManager,
		s.config.DataDir,
		s.taskQueue,
		s.db,
		s.jupyterServiceMgr,
		secCfg,
		audit,
	)

	// The WebSocket endpoint is mounted on the mux as-is; middleware and
	// metrics are applied to the mux as a whole further down.
	router.Handle("/ws", ws)

	// Remaining HTTP routes.
	s.handlers.RegisterHandlers(router)

	// Middleware chain first, metrics wrapper outermost.
	wrapped := s.wrapWithMetrics(s.wrapWithMiddleware(router))

	s.httpServer = &http.Server{
		Addr:         s.config.Server.Address,
		Handler:      wrapped,
		ReadTimeout:  readWriteTimeout,
		WriteTimeout: readWriteTimeout,
		IdleTimeout:  idleTimeout,
	}
}
// Start launches the listener in a background goroutine and returns
// immediately; the returned error is always nil.
//
// HTTPS is served with the configured certificate and key when TLS is
// enabled, plain HTTP otherwise (a warning is logged in that case). If the
// listener stops for any reason other than http.ErrServerClosed, the failure
// is logged and the process exits with status 1.
func (s *Server) Start() error {
	if !s.config.Server.TLS.Enabled {
		s.logger.Warn(
			"TLS disabled for API server; do not use this configuration in production",
			"address", s.config.Server.Address,
		)
	}

	go func() {
		var err error
		if s.config.Server.TLS.Enabled {
			s.logger.Info("starting HTTPS server", "address", s.config.Server.Address)
			err = s.httpServer.ListenAndServeTLS(
				s.config.Server.TLS.CertFile,
				s.config.Server.TLS.KeyFile,
			)
		} else {
			s.logger.Info("starting HTTP server", "address", s.config.Server.Address)
			err = s.httpServer.ListenAndServe()
		}

		// http.ErrServerClosed is what ListenAndServe returns after Shutdown
		// or Close, i.e. a normal stop. Exiting in that case would kill the
		// process while WaitForShutdown is still running cleanup, so only a
		// genuine listener failure terminates the process.
		if err != nil && err != http.ErrServerClosed {
			s.logger.Error("server failed", "error", err)
			os.Exit(1)
		}
	}()

	return nil
}
// WaitForShutdown blocks until the process receives SIGINT or SIGTERM, then
// shuts the HTTP server down with a 5-second grace period and runs the
// registered cleanup functions in slice order.
//
// The cleanup list is detached from the server before it is run, so a later
// call to Close does not execute the same functions a second time.
func (s *Server) WaitForShutdown() {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	// Unregister the channel on return so signals are no longer relayed to it.
	defer signal.Stop(sigChan)

	sig := <-sigChan
	s.logger.Info("received shutdown signal", "signal", sig)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	s.logger.Info("shutting down http server...")
	if err := s.httpServer.Shutdown(ctx); err != nil {
		s.logger.Error("server shutdown error", "error", err)
	} else {
		s.logger.Info("http server shutdown complete")
	}

	// Take ownership of the cleanup list before running it; this keeps each
	// cleanup function from being executed more than once.
	pending := s.cleanupFuncs
	s.cleanupFuncs = nil
	for _, cleanup := range pending {
		cleanup()
	}

	s.logger.Info("api server stopped")
}
// Close releases server resources by running every registered cleanup
// function in slice order. It always returns nil.
//
// The cleanup list is detached before it is run, so repeated calls to Close
// execute each cleanup function at most once.
func (s *Server) Close() error {
	pending := s.cleanupFuncs
	s.cleanupFuncs = nil
	for _, cleanup := range pending {
		cleanup()
	}
	return nil
}