- Move ci-test.sh and setup.sh to scripts/ - Trim docs/src/zig-cli.md to current structure - Replace hardcoded secrets with placeholders in configs - Update .gitignore to block .env*, secrets/, keys, build artifacts - Slim README.md to reflect current CLI/TUI split - Add cleanup trap to ci-test.sh - Ensure no secrets are committed
327 lines
8.2 KiB
Go
327 lines
8.2 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/middleware"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
)
|
|
|
|
// Server represents the API server
|
|
type Server struct {
|
|
config *ServerConfig
|
|
httpServer *http.Server
|
|
logger *logging.Logger
|
|
expManager *experiment.Manager
|
|
taskQueue *queue.TaskQueue
|
|
db *storage.DB
|
|
handlers *Handlers
|
|
sec *middleware.SecurityMiddleware
|
|
cleanupFuncs []func()
|
|
jupyterServiceMgr *jupyter.ServiceManager
|
|
}
|
|
|
|
// NewServer creates a new API server
|
|
func NewServer(configPath string) (*Server, error) {
|
|
// Load configuration
|
|
cfg, err := LoadServerConfig(configPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := cfg.Validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
server := &Server{
|
|
config: cfg,
|
|
}
|
|
|
|
// Initialize components
|
|
if err := server.initializeComponents(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Setup HTTP server
|
|
server.setupHTTPServer()
|
|
|
|
return server, nil
|
|
}
|
|
|
|
// initializeComponents initializes all server components
|
|
func (s *Server) initializeComponents() error {
|
|
// Setup logging
|
|
if err := s.config.EnsureLogDirectory(); err != nil {
|
|
return err
|
|
}
|
|
s.logger = s.setupLogger()
|
|
|
|
// Initialize experiment manager
|
|
if err := s.initExperimentManager(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize task queue
|
|
if err := s.initTaskQueue(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize database
|
|
if err := s.initDatabase(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize security
|
|
s.initSecurity()
|
|
|
|
// Initialize Jupyter service manager
|
|
s.initJupyterServiceManager()
|
|
|
|
// Initialize handlers
|
|
s.handlers = NewHandlers(s.expManager, s.jupyterServiceMgr, s.logger)
|
|
|
|
return nil
|
|
}
|
|
|
|
// setupLogger creates and configures the logger
|
|
func (s *Server) setupLogger() *logging.Logger {
|
|
logger := logging.NewLoggerFromConfig(s.config.Logging)
|
|
ctx := logging.EnsureTrace(context.Background())
|
|
return logger.Component(ctx, "api-server")
|
|
}
|
|
|
|
// initExperimentManager initializes the experiment manager
|
|
func (s *Server) initExperimentManager() error {
|
|
s.expManager = experiment.NewManager(s.config.BasePath)
|
|
if err := s.expManager.Initialize(); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Info("experiment manager initialized", "base_path", s.config.BasePath)
|
|
return nil
|
|
}
|
|
|
|
// initTaskQueue initializes the task queue
|
|
func (s *Server) initTaskQueue() error {
|
|
queueCfg := queue.Config{
|
|
RedisAddr: s.config.Redis.Addr,
|
|
RedisPassword: s.config.Redis.Password,
|
|
RedisDB: s.config.Redis.DB,
|
|
}
|
|
|
|
if queueCfg.RedisAddr == "" {
|
|
queueCfg.RedisAddr = "localhost:6379"
|
|
}
|
|
if s.config.Redis.URL != "" {
|
|
queueCfg.RedisAddr = s.config.Redis.URL
|
|
}
|
|
|
|
taskQueue, err := queue.NewTaskQueue(queueCfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.taskQueue = taskQueue
|
|
s.logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
|
|
|
|
// Add cleanup function
|
|
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
|
s.logger.Info("stopping task queue...")
|
|
if err := s.taskQueue.Close(); err != nil {
|
|
s.logger.Error("failed to stop task queue", "error", err)
|
|
} else {
|
|
s.logger.Info("task queue stopped")
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// initDatabase initializes the database connection
|
|
func (s *Server) initDatabase() error {
|
|
if s.config.Database.Type == "" {
|
|
return nil
|
|
}
|
|
|
|
dbConfig := storage.DBConfig{
|
|
Type: s.config.Database.Type,
|
|
Connection: s.config.Database.Connection,
|
|
Host: s.config.Database.Host,
|
|
Port: s.config.Database.Port,
|
|
Username: s.config.Database.Username,
|
|
Password: s.config.Database.Password,
|
|
Database: s.config.Database.Database,
|
|
}
|
|
|
|
db, err := storage.NewDB(dbConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.db = db
|
|
s.logger.Info("database initialized", "type", s.config.Database.Type)
|
|
|
|
// Add cleanup function
|
|
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
|
s.logger.Info("closing database connection...")
|
|
if err := s.db.Close(); err != nil {
|
|
s.logger.Error("failed to close database", "error", err)
|
|
} else {
|
|
s.logger.Info("database connection closed")
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// initSecurity initializes security middleware
|
|
func (s *Server) initSecurity() {
|
|
authConfig := s.config.BuildAuthConfig()
|
|
rlOpts := s.buildRateLimitOptions()
|
|
s.sec = middleware.NewSecurityMiddleware(authConfig, os.Getenv("JWT_SECRET"), rlOpts)
|
|
}
|
|
|
|
// buildRateLimitOptions builds rate limit options from configuration
|
|
func (s *Server) buildRateLimitOptions() *middleware.RateLimitOptions {
|
|
if !s.config.Security.RateLimit.Enabled || s.config.Security.RateLimit.RequestsPerMinute <= 0 {
|
|
return nil
|
|
}
|
|
|
|
return &middleware.RateLimitOptions{
|
|
RequestsPerMinute: s.config.Security.RateLimit.RequestsPerMinute,
|
|
BurstSize: s.config.Security.RateLimit.BurstSize,
|
|
}
|
|
}
|
|
|
|
// initJupyterServiceManager initializes the Jupyter service manager
|
|
func (s *Server) initJupyterServiceManager() {
|
|
serviceConfig := jupyter.GetDefaultServiceConfig()
|
|
|
|
sm, err := jupyter.NewServiceManager(s.logger, serviceConfig)
|
|
if err != nil {
|
|
s.logger.Error("failed to initialize Jupyter service manager", "error", err)
|
|
return
|
|
}
|
|
|
|
s.jupyterServiceMgr = sm
|
|
s.logger.Info("jupyter service manager initialized")
|
|
}
|
|
|
|
// setupHTTPServer sets up the HTTP server and routes
|
|
func (s *Server) setupHTTPServer() {
|
|
mux := http.NewServeMux()
|
|
|
|
// Register WebSocket handler
|
|
wsHandler := NewWSHandler(s.config.BuildAuthConfig(), s.logger, s.expManager, s.taskQueue)
|
|
mux.Handle("/ws", wsHandler)
|
|
|
|
// Register HTTP handlers
|
|
s.handlers.RegisterHandlers(mux)
|
|
|
|
// Wrap with middleware
|
|
finalHandler := s.wrapWithMiddleware(mux)
|
|
|
|
s.httpServer = &http.Server{
|
|
Addr: s.config.Server.Address,
|
|
Handler: finalHandler,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
}
|
|
|
|
// wrapWithMiddleware wraps the handler with security middleware
|
|
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/ws" {
|
|
mux.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
handler := s.sec.APIKeyAuth(mux)
|
|
handler = s.sec.RateLimit(handler)
|
|
handler = middleware.SecurityHeaders(handler)
|
|
handler = middleware.CORS(handler)
|
|
handler = middleware.RequestTimeout(30 * time.Second)(handler)
|
|
handler = middleware.AuditLogger(handler)
|
|
if len(s.config.Security.IPWhitelist) > 0 {
|
|
handler = s.sec.IPWhitelist(s.config.Security.IPWhitelist)(handler)
|
|
}
|
|
handler.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// Start starts the server
|
|
func (s *Server) Start() error {
|
|
if !s.config.Server.TLS.Enabled {
|
|
s.logger.Warn(
|
|
"TLS disabled for API server; do not use this configuration in production",
|
|
"address", s.config.Server.Address,
|
|
)
|
|
}
|
|
|
|
go func() {
|
|
var err error
|
|
if s.config.Server.TLS.Enabled {
|
|
s.logger.Info("starting HTTPS server", "address", s.config.Server.Address)
|
|
err = s.httpServer.ListenAndServeTLS(
|
|
s.config.Server.TLS.CertFile,
|
|
s.config.Server.TLS.KeyFile,
|
|
)
|
|
} else {
|
|
s.logger.Info("starting HTTP server", "address", s.config.Server.Address)
|
|
err = s.httpServer.ListenAndServe()
|
|
}
|
|
|
|
if err != nil && err != http.ErrServerClosed {
|
|
s.logger.Error("server failed", "error", err)
|
|
}
|
|
os.Exit(1)
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// WaitForShutdown waits for shutdown signals and gracefully shuts down the server
|
|
func (s *Server) WaitForShutdown() {
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
sig := <-sigChan
|
|
s.logger.Info("received shutdown signal", "signal", sig)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
s.logger.Info("shutting down http server...")
|
|
if err := s.httpServer.Shutdown(ctx); err != nil {
|
|
s.logger.Error("server shutdown error", "error", err)
|
|
} else {
|
|
s.logger.Info("http server shutdown complete")
|
|
}
|
|
|
|
// Run cleanup functions
|
|
for _, cleanup := range s.cleanupFuncs {
|
|
cleanup()
|
|
}
|
|
|
|
s.logger.Info("api server stopped")
|
|
}
|
|
|
|
// Close cleans up server resources
|
|
func (s *Server) Close() error {
|
|
// Run all cleanup functions
|
|
for _, cleanup := range s.cleanupFuncs {
|
|
cleanup()
|
|
}
|
|
return nil
|
|
}
|