// Package main implements the fetch_ml API server package main import ( "context" "encoding/json" "flag" "fmt" "log" "net/http" "os" "os/signal" "path/filepath" "syscall" "time" "github.com/jfraeys/fetch_ml/internal/api" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/middleware" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "gopkg.in/yaml.v3" ) // Config structure matching worker config. type Config struct { BasePath string `yaml:"base_path"` Auth auth.Config `yaml:"auth"` Server ServerConfig `yaml:"server"` Security SecurityConfig `yaml:"security"` Redis RedisConfig `yaml:"redis"` Database DatabaseConfig `yaml:"database"` Logging logging.Config `yaml:"logging"` Resources config.ResourceConfig `yaml:"resources"` } // RedisConfig holds Redis connection configuration. type RedisConfig struct { Addr string `yaml:"addr"` Password string `yaml:"password"` DB int `yaml:"db"` URL string `yaml:"url"` } // DatabaseConfig holds database connection configuration. type DatabaseConfig struct { Type string `yaml:"type"` Connection string `yaml:"connection"` Host string `yaml:"host"` Port int `yaml:"port"` Username string `yaml:"username"` Password string `yaml:"password"` Database string `yaml:"database"` } // SecurityConfig holds security-related configuration. type SecurityConfig struct { RateLimit RateLimitConfig `yaml:"rate_limit"` IPWhitelist []string `yaml:"ip_whitelist"` FailedLockout LockoutConfig `yaml:"failed_login_lockout"` } // RateLimitConfig holds rate limiting configuration. type RateLimitConfig struct { Enabled bool `yaml:"enabled"` RequestsPerMinute int `yaml:"requests_per_minute"` BurstSize int `yaml:"burst_size"` } // LockoutConfig holds failed login lockout configuration. type LockoutConfig struct { Enabled bool `yaml:"enabled"` MaxAttempts int `yaml:"max_attempts"` LockoutDuration string `yaml:"lockout_duration"` } // ServerConfig holds server configuration. type ServerConfig struct { Address string `yaml:"address"` TLS TLSConfig `yaml:"tls"` } // TLSConfig holds TLS configuration. type TLSConfig struct { Enabled bool `yaml:"enabled"` CertFile string `yaml:"cert_file"` KeyFile string `yaml:"key_file"` } // LoadConfig loads configuration from a YAML file. func LoadConfig(path string) (*Config, error) { data, err := fileutil.SecureFileRead(path) if err != nil { return nil, err } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, err } return &cfg, nil } func main() { configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path") apiKey := flag.String("api-key", "", "API key for authentication") flag.Parse() cfg, err := loadServerConfig(*configFile) if err != nil { log.Fatalf("Failed to load config: %v", err) } if err := ensureLogDirectory(cfg.Logging); err != nil { log.Fatalf("Failed to prepare log directory: %v", err) } logger := setupLogger(cfg.Logging) expManager, err := initExperimentManager(cfg.BasePath, logger) if err != nil { logger.Fatal("failed to initialize experiment manager", "error", err) } taskQueue, queueCleanup := initTaskQueue(cfg, logger) if queueCleanup != nil { defer queueCleanup() } db, dbCleanup := initDatabase(cfg, logger) if dbCleanup != nil { defer dbCleanup() } authCfg := buildAuthConfig(cfg.Auth, logger) sec := newSecurityMiddleware(cfg) mux := buildHTTPMux(cfg, logger, expManager, taskQueue, authCfg, db) finalHandler := wrapWithMiddleware(cfg, sec, mux) server := newHTTPServer(cfg, finalHandler) startServer(server, cfg, logger) waitForShutdown(server, logger) _ = apiKey // Reserved for future authentication enhancements } func loadServerConfig(path string) (*Config, error) { resolvedConfig, err := config.ResolveConfigPath(path) if err != nil { return nil, err } cfg, err := LoadConfig(resolvedConfig) if err != nil { return nil, err } cfg.Resources.ApplyDefaults() return cfg, nil } func ensureLogDirectory(cfg logging.Config) error { if cfg.File == "" { return nil } logDir := filepath.Dir(cfg.File) log.Printf("Creating log directory: %s", logDir) return os.MkdirAll(logDir, 0750) } func setupLogger(cfg logging.Config) *logging.Logger { logger := logging.NewLoggerFromConfig(cfg) ctx := logging.EnsureTrace(context.Background()) return logger.Component(ctx, "api-server") } func initExperimentManager(basePath string, logger *logging.Logger) (*experiment.Manager, error) { if basePath == "" { basePath = "/tmp/ml-experiments" } expManager := experiment.NewManager(basePath) log.Printf("Initializing experiment manager with base_path: %s", basePath) if err := expManager.Initialize(); err != nil { return nil, err } logger.Info("experiment manager initialized", "base_path", basePath) return expManager, nil } func buildAuthConfig(cfg auth.Config, logger *logging.Logger) *auth.Config { if !cfg.Enabled { return nil } logger.Info("authentication enabled") return &cfg } func newSecurityMiddleware(cfg *Config) *middleware.SecurityMiddleware { apiKeys := collectAPIKeys(cfg.Auth.APIKeys) rlOpts := buildRateLimitOptions(cfg.Security.RateLimit) return middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"), rlOpts) } func collectAPIKeys(keys map[auth.Username]auth.APIKeyEntry) []string { apiKeys := make([]string, 0, len(keys)) for username := range keys { apiKeys = append(apiKeys, string(username)) } return apiKeys } func buildRateLimitOptions(cfg RateLimitConfig) *middleware.RateLimitOptions { if !cfg.Enabled || cfg.RequestsPerMinute <= 0 { return nil } return &middleware.RateLimitOptions{ RequestsPerMinute: cfg.RequestsPerMinute, BurstSize: cfg.BurstSize, } } func initTaskQueue(cfg *Config, logger *logging.Logger) (*queue.TaskQueue, func()) { queueCfg := queue.Config{ RedisAddr: cfg.Redis.Addr, RedisPassword: cfg.Redis.Password, RedisDB: cfg.Redis.DB, } if queueCfg.RedisAddr == "" { queueCfg.RedisAddr = config.DefaultRedisAddr } if cfg.Redis.URL != "" { queueCfg.RedisAddr = cfg.Redis.URL } taskQueue, err := queue.NewTaskQueue(queueCfg) if err != nil { logger.Error("failed to initialize task queue", "error", err) return nil, nil } logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr) cleanup := func() { logger.Info("stopping task queue...") if err := taskQueue.Close(); err != nil { logger.Error("failed to stop task queue", "error", err) } else { logger.Info("task queue stopped") } } return taskQueue, cleanup } func initDatabase(cfg *Config, logger *logging.Logger) (*storage.DB, func()) { if cfg.Database.Type == "" { return nil, nil } dbConfig := storage.DBConfig{ Type: cfg.Database.Type, Connection: cfg.Database.Connection, Host: cfg.Database.Host, Port: cfg.Database.Port, Username: cfg.Database.Username, Password: cfg.Database.Password, Database: cfg.Database.Database, } db, err := storage.NewDB(dbConfig) if err != nil { logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err) return nil, nil } schemaPath := schemaPathForDB(cfg.Database.Type) if schemaPath == "" { logger.Error("unsupported database type", "type", cfg.Database.Type) _ = db.Close() return nil, nil } schema, err := fileutil.SecureFileRead(schemaPath) if err != nil { logger.Error("failed to read database schema file", "path", schemaPath, "error", err) _ = db.Close() return nil, nil } if err := db.Initialize(string(schema)); err != nil { logger.Error("failed to initialize database schema", "error", err) _ = db.Close() return nil, nil } logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection) cleanup := func() { logger.Info("closing database connection...") if err := db.Close(); err != nil { logger.Error("failed to close database", "error", err) } else { logger.Info("database connection closed") } } return db, cleanup } func schemaPathForDB(dbType string) string { switch dbType { case "sqlite": return "internal/storage/schema_sqlite.sql" case "postgres", "postgresql": return "internal/storage/schema_postgres.sql" default: return "" } } func buildHTTPMux( cfg *Config, logger *logging.Logger, expManager *experiment.Manager, taskQueue *queue.TaskQueue, authCfg *auth.Config, db *storage.DB, ) *http.ServeMux { mux := http.NewServeMux() wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue) mux.Handle("/ws", wsHandler) mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = fmt.Fprintf(w, "OK\n") }) mux.HandleFunc("/db-status", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") if db == nil { w.WriteHeader(http.StatusServiceUnavailable) _, _ = fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`) return } var result struct { Status string `json:"status"` Type string `json:"type"` Path string `json:"path"` Message string `json:"message"` } result.Status = "connected" result.Type = cfg.Database.Type result.Path = cfg.Database.Connection result.Message = fmt.Sprintf("%s database is operational", cfg.Database.Type) if err := db.RecordSystemMetric("db_test", "ok"); err != nil { result.Status = "error" result.Message = fmt.Sprintf("Database query failed: %v", err) } jsonBytes, _ := json.Marshal(result) _, _ = w.Write(jsonBytes) }) return mux } func wrapWithMiddleware(cfg *Config, sec *middleware.SecurityMiddleware, mux *http.ServeMux) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/ws" { mux.ServeHTTP(w, r) return } handler := sec.RateLimit(mux) handler = middleware.SecurityHeaders(handler) handler = middleware.CORS(handler) handler = middleware.RequestTimeout(30 * time.Second)(handler) handler = middleware.AuditLogger(handler) if len(cfg.Security.IPWhitelist) > 0 { handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler) } handler.ServeHTTP(w, r) }) } func newHTTPServer(cfg *Config, handler http.Handler) *http.Server { return &http.Server{ Addr: cfg.Server.Address, Handler: handler, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 120 * time.Second, } } func startServer(server *http.Server, cfg *Config, logger *logging.Logger) { if !cfg.Server.TLS.Enabled { logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address) } go func() { if cfg.Server.TLS.Enabled { logger.Info("starting HTTPS server", "address", cfg.Server.Address) if err := server.ListenAndServeTLS( cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile, ); err != nil && err != http.ErrServerClosed { logger.Error("HTTPS server failed", "error", err) } } else { logger.Info("starting HTTP server", "address", cfg.Server.Address) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Error("HTTP server failed", "error", err) } } os.Exit(1) }() } func waitForShutdown(server *http.Server, logger *logging.Logger) { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) sig := <-sigChan logger.Info("received shutdown signal", "signal", sig) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() logger.Info("shutting down http server...") if err := server.Shutdown(ctx); err != nil { logger.Error("server shutdown error", "error", err) } else { logger.Info("http server shutdown complete") } logger.Info("api server stopped") }