package api

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/middleware"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
)

// Server represents the API server. It owns the HTTP server together with
// the components built from the loaded configuration.
type Server struct {
	config            *ServerConfig
	httpServer        *http.Server
	logger            *logging.Logger
	expManager        *experiment.Manager
	taskQueue         *queue.TaskQueue
	db                *storage.DB // nil when no database type is configured
	handlers          *Handlers
	sec               *middleware.SecurityMiddleware
	cleanupFuncs      []func()                // run on shutdown, in registration order
	jupyterServiceMgr *jupyter.ServiceManager // nil when its initialization failed
}

// NewServer creates a new API server from the configuration file at
// configPath. It loads and validates the configuration, initializes every
// component and prepares (but does not start) the HTTP server.
//
// If component initialization fails partway, the cleanup functions registered
// by the components that did come up are run before the error is returned, so
// no connection is left open.
func NewServer(configPath string) (*Server, error) {
	cfg, err := LoadServerConfig(configPath)
	if err != nil {
		return nil, fmt.Errorf("loading server config %q: %w", configPath, err)
	}

	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("validating server config: %w", err)
	}

	server := &Server{
		config: cfg,
	}

	if err := server.initializeComponents(); err != nil {
		// Release whatever was already opened (e.g. the task queue when the
		// database step fails).
		server.runCleanupFuncs()
		return nil, fmt.Errorf("initializing server components: %w", err)
	}

	server.setupHTTPServer()

	return server, nil
}

// initializeComponents initializes all server components in dependency
// order: logger first (later steps log through it), then the experiment
// manager, task queue, database, security middleware, Jupyter service manager
// and finally the HTTP handlers.
func (s *Server) initializeComponents() error {
	if err := s.config.EnsureLogDirectory(); err != nil {
		return fmt.Errorf("ensuring log directory: %w", err)
	}
	s.logger = s.setupLogger()

	if err := s.initExperimentManager(); err != nil {
		return err
	}

	if err := s.initTaskQueue(); err != nil {
		return err
	}

	if err := s.initDatabase(); err != nil {
		return err
	}

	s.initSecurity()

	// Best effort: a failure here is logged and leaves jupyterServiceMgr nil.
	s.initJupyterServiceManager()

	s.handlers = NewHandlers(s.expManager, s.jupyterServiceMgr, s.logger)

	return nil
}

// setupLogger creates the logger from the logging configuration and scopes it
// to the "api-server" component.
func (s *Server) setupLogger() *logging.Logger {
	logger := logging.NewLoggerFromConfig(s.config.Logging)
	ctx := logging.EnsureTrace(context.Background())
	return logger.Component(ctx, "api-server")
}

// initExperimentManager creates the experiment manager rooted at the
// configured base path and initializes it.
func (s *Server) initExperimentManager() error {
	s.expManager = experiment.NewManager(s.config.BasePath)
	if err := s.expManager.Initialize(); err != nil {
		return fmt.Errorf("initializing experiment manager: %w", err)
	}
	s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath)
	return nil
}

// initTaskQueue creates the task queue and registers a cleanup function that
// closes it. The Redis address falls back to "localhost:6379" when unset; a
// configured Redis URL takes precedence over the address.
func (s *Server) initTaskQueue() error {
	queueCfg := queue.Config{
		RedisAddr:     s.config.Redis.Addr,
		RedisPassword: s.config.Redis.Password,
		RedisDB:       s.config.Redis.DB,
	}
	if queueCfg.RedisAddr == "" {
		queueCfg.RedisAddr = "localhost:6379"
	}
	if s.config.Redis.URL != "" {
		queueCfg.RedisAddr = s.config.Redis.URL
	}

	taskQueue, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		return fmt.Errorf("creating task queue: %w", err)
	}
	s.taskQueue = taskQueue

	// NOTE(review): when Redis.URL is set this logs the URL verbatim; if URLs
	// may embed credentials, confirm they are redacted before logging.
	s.logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)

	s.cleanupFuncs = append(s.cleanupFuncs, func() {
		s.logger.Info("stopping task queue...")
		if err := s.taskQueue.Close(); err != nil {
			s.logger.Error("failed to stop task queue", "error", err)
			return
		}
		s.logger.Info("task queue stopped")
	})

	return nil
}

// initDatabase opens the database connection and registers a cleanup function
// that closes it. It is a no-op when no database type is configured.
func (s *Server) initDatabase() error {
	if s.config.Database.Type == "" {
		return nil
	}

	dbConfig := storage.DBConfig{
		Type:       s.config.Database.Type,
		Connection: s.config.Database.Connection,
		Host:       s.config.Database.Host,
		Port:       s.config.Database.Port,
		Username:   s.config.Database.Username,
		Password:   s.config.Database.Password,
		Database:   s.config.Database.Database,
	}

	db, err := storage.NewDB(dbConfig)
	if err != nil {
		return fmt.Errorf("opening %s database: %w", s.config.Database.Type, err)
	}
	s.db = db

	s.logger.Info("database initialized", "type", s.config.Database.Type)

	s.cleanupFuncs = append(s.cleanupFuncs, func() {
		s.logger.Info("closing database connection...")
		if err := s.db.Close(); err != nil {
			s.logger.Error("failed to close database", "error", err)
			return
		}
		s.logger.Info("database connection closed")
	})

	return nil
}

// initSecurity builds the security middleware from the auth configuration,
// the JWT_SECRET environment variable and the rate limit options.
func (s *Server) initSecurity() {
	authConfig := s.config.BuildAuthConfig()
	rlOpts := s.buildRateLimitOptions()
	s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts)
}

// buildRateLimitOptions builds rate limit options from configuration. It
// returns nil when rate limiting is disabled or the configured requests per
// minute is not positive.
func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions {
	rl := s.config.Security.RateLimit
	if !rl.Enabled || rl.RequestsPerMinute <= 0 {
		return nil
	}

	return &middleware.RateLimitOptions{
		RequestsPerMinute: rl.RequestsPerMinute,
		BurstSize:         rl.BurstSize,
	}
}

// initJupyterServiceManager creates the Jupyter service manager with the
// default service configuration. Failure is logged and otherwise ignored, so
// the server starts without it and jupyterServiceMgr stays nil.
func (s *Server) initJupyterServiceManager() {
	serviceConfig := jupyter.GetDefaultServiceConfig()
	sm, err := jupyter.NewServiceManager(s.logger, serviceConfig)
	if err != nil {
		s.logger.Error("failed to initialize Jupyter service manager", "error", err)
		return
	}
	s.jupyterServiceMgr = sm
	s.logger.Info("jupyter service manager initialized")
}

// setupHTTPServer registers the WebSocket and HTTP handlers on a new mux,
// wraps the mux with middleware and builds the http.Server with its timeouts.
func (s *Server) setupHTTPServer() {
	mux := http.NewServeMux()

	wsHandler := NewWSHandler(s.config.BuildAuthConfig(), s.logger, s.expManager, s.taskQueue)
	mux.Handle("/ws", wsHandler)

	s.handlers.RegisterHandlers(mux)

	finalHandler := s.wrapWithMiddleware(mux)

	s.httpServer = &http.Server{
		Addr:         s.config.Server.Address,
		Handler:      finalHandler,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  120 * time.Second,
	}
}

// wrapWithMiddleware wraps the mux with the security middleware chain.
// Requests to "/ws" bypass the chain and go straight to the mux; everything
// else passes through, outermost first: IP whitelist (when configured), audit
// logging, request timeout, CORS, security headers, rate limiting and API key
// authentication.
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
	// Build the chain once rather than on every request: the configuration it
	// depends on is fixed for the lifetime of the server.
	handler := s.sec.APIKeyAuth(mux)
	handler = s.sec.RateLimit(handler)
	handler = middleware.SecurityHeaders(handler)
	handler = middleware.CORS(handler)
	handler = middleware.RequestTimeout(30 * time.Second)(handler)
	handler = middleware.AuditLogger(handler)
	if len(s.config.Security.IPWhitelist) > 0 {
		handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/ws" {
			mux.ServeHTTP(w, r)
			return
		}
		handler.ServeHTTP(w, r)
	})
}

// Start starts the server in a background goroutine and returns immediately;
// the returned error is always nil. Over TLS when enabled in the
// configuration, plain HTTP otherwise (with a warning).
//
// If serving fails for any reason other than a graceful shutdown, the error is
// logged and the process exits with status 1.
func (s *Server) Start() error {
	if !s.config.Server.TLS.Enabled {
		s.logger.Warn(
			"TLS disabled for API server; do not use this configuration in production",
			"address", s.config.Server.Address,
		)
	}

	go func() {
		var err error
		if s.config.Server.TLS.Enabled {
			s.logger.Info("starting HTTPS server", "address", s.config.Server.Address)
			err = s.httpServer.ListenAndServeTLS(
				s.config.Server.TLS.CertFile,
				s.config.Server.TLS.KeyFile,
			)
		} else {
			s.logger.Info("starting HTTP server", "address", s.config.Server.Address)
			err = s.httpServer.ListenAndServe()
		}

		// http.ErrServerClosed is the normal result of Shutdown; exiting here
		// in that case would kill the process in the middle of the graceful
		// shutdown and cleanup performed by WaitForShutdown.
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			s.logger.Error("server failed", "error", err)
			os.Exit(1)
		}
	}()

	return nil
}

// WaitForShutdown blocks until SIGINT or SIGTERM is received, then shuts the
// HTTP server down gracefully (5 second timeout) and runs the registered
// cleanup functions.
func (s *Server) WaitForShutdown() {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	sig := <-sigChan
	s.logger.Info("received shutdown signal", "signal", sig)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	s.logger.Info("shutting down http server...")
	if err := s.httpServer.Shutdown(ctx); err != nil {
		s.logger.Error("server shutdown error", "error", err)
	} else {
		s.logger.Info("http server shutdown complete")
	}

	s.runCleanupFuncs()

	s.logger.Info("api server stopped")
}

// Close cleans up server resources by running any cleanup functions that have
// not run yet. It is safe to call after WaitForShutdown; the second call is a
// no-op. It always returns nil.
func (s *Server) Close() error {
	s.runCleanupFuncs()
	return nil
}

// runCleanupFuncs runs the registered cleanup functions in registration order
// and clears the list so each runs at most once. It is not safe for concurrent
// use.
func (s *Server) runCleanupFuncs() {
	funcs := s.cleanupFuncs
	s.cleanupFuncs = nil
	for _, cleanup := range funcs {
		cleanup()
	}
}