package api

import (
	"context"
	"errors"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/middleware"
	"github.com/jfraeys/fetch_ml/internal/prommetrics"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
)

// Server represents the API server. It owns the HTTP server, the components
// the handlers depend on, and the list of cleanup functions that release
// those components on shutdown.
type Server struct {
	config            *ServerConfig
	httpServer        *http.Server
	logger            *logging.Logger
	expManager        *experiment.Manager
	taskQueue         queue.Backend
	db                *storage.DB // nil when no database type is configured
	handlers          *Handlers
	sec               *middleware.SecurityMiddleware
	cleanupFuncs      []func()                // run once, in registration order, by runCleanup
	jupyterServiceMgr *jupyter.ServiceManager // nil when initialization failed
	auditLogger       *audit.Logger           // nil when audit logging is disabled or failed to start
	promMetrics       *prommetrics.Metrics    // Prometheus metrics; nil when disabled
}

// NewServer creates a new API server from the configuration file at
// configPath. It loads and validates the configuration, initializes all
// components and prepares (but does not start) the HTTP server.
//
// If component initialization fails partway through, any resources that were
// already opened are released before the error is returned.
func NewServer(configPath string) (*Server, error) {
	// Load configuration
	cfg, err := LoadServerConfig(configPath)
	if err != nil {
		return nil, err
	}

	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	server := &Server{
		config: cfg,
	}

	// Initialize components
	if err := server.initializeComponents(); err != nil {
		// Release whatever was opened before the failing step so a failed
		// construction does not leak queue or database connections.
		server.runCleanup()
		return nil, err
	}

	// Setup HTTP server
	server.setupHTTPServer()

	return server, nil
}

// initializeComponents initializes all server components in dependency
// order. The logger is set up first because every later step logs through it.
func (s *Server) initializeComponents() error {
	// Setup logging
	if err := s.config.EnsureLogDirectory(); err != nil {
		return err
	}
	s.logger = s.setupLogger()

	// Initialize experiment manager
	if err := s.initExperimentManager(); err != nil {
		return err
	}

	// Initialize task queue
	if err := s.initTaskQueue(); err != nil {
		return err
	}

	// Initialize database
	if err := s.initDatabase(); err != nil {
		return err
	}

	// Initialize database schema (if DB enabled)
	if err := s.initDatabaseSchema(); err != nil {
		return err
	}

	// Initialize security
	s.initSecurity()

	// Initialize Jupyter service manager
	s.initJupyterServiceManager()

	// Initialize handlers
	// NOTE(review): the second argument is passed as nil here; confirm
	// against NewHandlers that this is intentional.
	s.handlers = NewHandlers(s.expManager, nil, s.logger)

	return nil
}

// initDatabaseSchema applies the schema matching the configured database
// type. It is a no-op when no database connection was opened.
func (s *Server) initDatabaseSchema() error {
	if s.db == nil {
		return nil
	}

	schema, err := storage.SchemaForDBType(s.config.Database.Type)
	if err != nil {
		return err
	}
	if err := s.db.Initialize(schema); err != nil {
		return err
	}

	s.logger.Info("database schema initialized", "type", s.config.Database.Type)
	return nil
}

// setupLogger creates the logger from the logging configuration and scopes
// it to the "api-server" component.
func (s *Server) setupLogger() *logging.Logger {
	logger := logging.NewLoggerFromConfig(s.config.Logging)
	ctx := logging.EnsureTrace(context.Background())
	return logger.Component(ctx, "api-server")
}

// initExperimentManager creates and initializes the experiment manager
// rooted at the configured base path.
func (s *Server) initExperimentManager() error {
	s.expManager = experiment.NewManager(s.config.BasePath)
	if err := s.expManager.Initialize(); err != nil {
		return err
	}
	s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath)
	return nil
}

// initTaskQueue creates the task queue backend and registers a cleanup
// function that closes it.
//
// Defaults: an empty backend name becomes "redis" and an empty Redis address
// becomes "localhost:6379". A non-empty Redis URL takes precedence over the
// Redis address.
func (s *Server) initTaskQueue() error {
	backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend))
	if backend == "" {
		backend = "redis"
	}

	redisAddr := strings.TrimSpace(s.config.Redis.Addr)
	if redisAddr == "" {
		redisAddr = "localhost:6379"
	}
	if redisURL := strings.TrimSpace(s.config.Redis.URL); redisURL != "" {
		redisAddr = redisURL
	}

	backendCfg := queue.BackendConfig{
		Backend:              queue.QueueBackend(backend),
		RedisAddr:            redisAddr,
		RedisPassword:        s.config.Redis.Password,
		RedisDB:              s.config.Redis.DB,
		SQLitePath:           s.config.Queue.SQLitePath,
		FilesystemPath:       s.config.Queue.FilesystemPath,
		FallbackToFilesystem: s.config.Queue.FallbackToFilesystem,
		MetricsFlushInterval: 0,
	}

	taskQueue, err := queue.NewBackend(backendCfg)
	if err != nil {
		return err
	}
	s.taskQueue = taskQueue

	if backend == "sqlite" {
		s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath)
	} else {
		s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr)
	}

	// Add cleanup function
	s.cleanupFuncs = append(s.cleanupFuncs, func() {
		s.logger.Info("stopping task queue...")
		if err := s.taskQueue.Close(); err != nil {
			s.logger.Error("failed to stop task queue", "error", err)
		} else {
			s.logger.Info("task queue stopped")
		}
	})

	return nil
}

// initDatabase opens the database connection and registers a cleanup
// function that closes it. It is a no-op when no database type is configured,
// in which case s.db stays nil.
func (s *Server) initDatabase() error {
	if s.config.Database.Type == "" {
		return nil
	}

	dbConfig := storage.DBConfig{
		Type:       s.config.Database.Type,
		Connection: s.config.Database.Connection,
		Host:       s.config.Database.Host,
		Port:       s.config.Database.Port,
		Username:   s.config.Database.Username,
		Password:   s.config.Database.Password,
		Database:   s.config.Database.Database,
	}

	db, err := storage.NewDB(dbConfig)
	if err != nil {
		return err
	}
	s.db = db

	s.logger.Info("database initialized", "type", s.config.Database.Type)

	// Add cleanup function
	s.cleanupFuncs = append(s.cleanupFuncs, func() {
		s.logger.Info("closing database connection...")
		if err := s.db.Close(); err != nil {
			s.logger.Error("failed to close database", "error", err)
		} else {
			s.logger.Info("database connection closed")
		}
	})

	return nil
}

// initSecurity builds the security middleware from the auth configuration,
// the JWT_SECRET environment variable and the rate limit options.
func (s *Server) initSecurity() {
	authConfig := s.config.BuildAuthConfig()
	rlOpts := s.buildRateLimitOptions()
	s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts)
}

// buildRateLimitOptions builds rate limit options from configuration. It
// returns nil when rate limiting is disabled or the configured requests per
// minute is not positive.
func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions {
	rl := s.config.Security.RateLimit
	if !rl.Enabled || rl.RequestsPerMinute <= 0 {
		return nil
	}

	return &middleware.RateLimitOptions{
		RequestsPerMinute: rl.RequestsPerMinute,
		BurstSize:         rl.BurstSize,
	}
}

// initJupyterServiceManager creates the Jupyter service manager with the
// default service configuration. Failure is logged and tolerated: the server
// continues with a nil manager.
func (s *Server) initJupyterServiceManager() {
	serviceConfig := jupyter.GetDefaultServiceConfig()
	sm, err := jupyter.NewServiceManager(s.logger, serviceConfig)
	if err != nil {
		s.logger.Error("failed to initialize Jupyter service manager", "error", err)
		return
	}
	s.jupyterServiceMgr = sm
	s.logger.Info("jupyter service manager initialized")
}

// setupHTTPServer registers all routes on a fresh mux, wraps the mux with
// the security and metrics middleware, and stores the resulting http.Server
// in s.httpServer. It does not start listening.
func (s *Server) setupHTTPServer() {
	mux := http.NewServeMux()

	// Initialize Prometheus metrics (if enabled)
	if s.config.Monitoring.Prometheus.Enabled {
		s.promMetrics = prommetrics.New()
		s.logger.Info("prometheus metrics initialized")

		// Register metrics endpoint
		metricsPath := s.config.Monitoring.Prometheus.Path
		if metricsPath == "" {
			metricsPath = "/metrics"
		}
		mux.Handle(metricsPath, s.promMetrics.Handler())
		s.logger.Info("metrics endpoint registered", "path", metricsPath)
	}

	// Initialize health check handler
	if s.config.Monitoring.HealthChecks.Enabled {
		healthHandler := NewHealthHandler(s)
		healthHandler.RegisterRoutes(mux)
		mux.HandleFunc("/health/ok", s.handlers.handleHealth)
		s.logger.Info("health check endpoints registered")
	}

	// Initialize audit logger. A failure here is logged and tolerated; the
	// WebSocket handler then receives a nil audit logger.
	var auditLogger *audit.Logger
	if s.config.Security.AuditLogging.Enabled && s.config.Security.AuditLogging.LogPath != "" {
		al, err := audit.NewLogger(
			s.config.Security.AuditLogging.Enabled,
			s.config.Security.AuditLogging.LogPath,
			s.logger,
		)
		if err != nil {
			s.logger.Warn("failed to initialize audit logger", "error", err)
		} else {
			auditLogger = al
			s.auditLogger = al

			// Add cleanup function
			s.cleanupFuncs = append(s.cleanupFuncs, func() {
				s.logger.Info("closing audit logger...")
				if err := auditLogger.Close(); err != nil {
					s.logger.Error("failed to close audit logger", "error", err)
				}
			})
		}
	}

	// Register WebSocket handler with security config and audit logger
	securityCfg := getSecurityConfig(s.config)
	wsHandler := NewWSHandler(
		s.config.BuildAuthConfig(),
		s.logger,
		s.expManager,
		s.config.DataDir,
		s.taskQueue,
		s.db,
		s.jupyterServiceMgr,
		securityCfg,
		auditLogger,
	)

	// The WebSocket handler is registered directly on the mux; metrics are
	// applied to the whole handler chain below, not to this route alone.
	mux.Handle("/ws", wsHandler)

	// Register HTTP handlers
	s.handlers.RegisterHandlers(mux)

	// Wrap with middleware
	finalHandler := s.wrapWithMiddleware(mux)
	finalHandler = s.wrapWithMetrics(finalHandler)

	s.httpServer = &http.Server{
		Addr:         s.config.Server.Address,
		Handler:      finalHandler,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  120 * time.Second,
	}
}

// getSecurityConfig copies the security-related fields of the server config
// into a config.SecurityConfig.
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
	return &config.SecurityConfig{
		AllowedOrigins:     cfg.Security.AllowedOrigins,
		ProductionMode:     cfg.Security.ProductionMode,
		APIKeyRotationDays: cfg.Security.APIKeyRotationDays,
		AuditLogging: config.AuditLoggingConfig{
			Enabled: cfg.Security.AuditLogging.Enabled,
			LogPath: cfg.Security.AuditLogging.LogPath,
		},
		IPWhitelist: cfg.Security.IPWhitelist,
	}
}

// wrapWithMiddleware wraps the mux with the security middleware chain.
//
// Requests to "/ws" and to any path starting with "/health" bypass the chain
// and are served by the mux directly. All other requests pass through the
// chain, which is built once here rather than on every request.
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
	// Each wrapper encloses the previous one, so the last one applied is the
	// outermost layer of the chain.
	secured := s.sec.APIKeyAuth(mux)
	secured = s.sec.RateLimit(secured)
	secured = middleware.SecurityHeaders(secured)
	secured = middleware.CORS(s.config.Security.AllowedOrigins)(secured)
	secured = middleware.RequestTimeout(30 * time.Second)(secured)
	secured = middleware.AuditLogger(secured)
	if len(s.config.Security.IPWhitelist) > 0 {
		secured = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(secured)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Skip auth for WebSocket and health endpoints
		if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") {
			mux.ServeHTTP(w, r)
			return
		}
		secured.ServeHTTP(w, r)
	})
}

// Start starts the server in a background goroutine and returns immediately.
// It always returns nil; a listener failure is logged and terminates the
// process with exit status 1.
//
// A graceful shutdown (http.ErrServerClosed) does not terminate the process,
// so WaitForShutdown can finish draining connections and run the cleanup
// functions.
func (s *Server) Start() error {
	tlsEnabled := s.config.Server.TLS.Enabled
	if !tlsEnabled {
		s.logger.Warn(
			"TLS disabled for API server; do not use this configuration in production",
			"address", s.config.Server.Address,
		)
	}

	go func() {
		var err error
		if tlsEnabled {
			s.logger.Info("starting HTTPS server", "address", s.config.Server.Address)
			err = s.httpServer.ListenAndServeTLS(
				s.config.Server.TLS.CertFile,
				s.config.Server.TLS.KeyFile,
			)
		} else {
			s.logger.Info("starting HTTP server", "address", s.config.Server.Address)
			err = s.httpServer.ListenAndServe()
		}

		// ListenAndServe returns ErrServerClosed as soon as Shutdown is
		// called; that is the normal path and must not kill the process.
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			s.logger.Error("server failed", "error", err)
			os.Exit(1)
		}
	}()

	return nil
}

// WaitForShutdown blocks until SIGINT or SIGTERM is received, then shuts the
// HTTP server down with a 5-second grace period and runs the registered
// cleanup functions.
func (s *Server) WaitForShutdown() {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	sig := <-sigChan
	s.logger.Info("received shutdown signal", "signal", sig)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	s.logger.Info("shutting down http server...")
	if err := s.httpServer.Shutdown(ctx); err != nil {
		s.logger.Error("server shutdown error", "error", err)
	} else {
		s.logger.Info("http server shutdown complete")
	}

	// Run cleanup functions
	s.runCleanup()

	s.logger.Info("api server stopped")
}

// Close cleans up server resources by running any cleanup functions that
// have not run yet. It is safe to call after WaitForShutdown; resources are
// not closed a second time. It always returns nil.
func (s *Server) Close() error {
	s.runCleanup()
	return nil
}

// runCleanup runs the registered cleanup functions in registration order and
// clears the list so that each function runs at most once, however many
// times runCleanup is called. It is not safe for concurrent use.
func (s *Server) runCleanup() {
	pending := s.cleanupFuncs
	s.cleanupFuncs = nil
	for _, cleanup := range pending {
		cleanup()
	}
}