// Package main implements the API server: it loads a YAML configuration,
// wires up the experiment manager, task queue, optional database and the
// WebSocket/HTTP endpoints, and serves until SIGINT/SIGTERM is received.
package main

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/middleware"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
	"gopkg.in/yaml.v3"
)

// Config is the top-level API server configuration as read from YAML.
// Its structure is intended to match the worker config.
type Config struct {
	BasePath string          `yaml:"base_path"`
	Auth     auth.AuthConfig `yaml:"auth"`
	Server   ServerConfig    `yaml:"server"`
	Security SecurityConfig  `yaml:"security"`
	Redis    RedisConfig     `yaml:"redis"`
	Database DatabaseConfig  `yaml:"database"`
	Logging  logging.Config  `yaml:"logging"`
}

// RedisConfig holds the Redis settings used by the task queue.
// When URL is set it is used as the queue address instead of Addr.
type RedisConfig struct {
	Addr     string `yaml:"addr"`
	Password string `yaml:"password"`
	DB       int    `yaml:"db"`
	URL      string `yaml:"url"`
}

// DatabaseConfig holds the optional database settings. An empty Type
// means no database is configured.
type DatabaseConfig struct {
	Type       string `yaml:"type"`
	Connection string `yaml:"connection"`
	Host       string `yaml:"host"`
	Port       int    `yaml:"port"`
	Username   string `yaml:"username"`
	Password   string `yaml:"password"`
	Database   string `yaml:"database"`
}

// SecurityConfig groups the security-related settings.
type SecurityConfig struct {
	RateLimit     RateLimitConfig `yaml:"rate_limit"`
	IPWhitelist   []string        `yaml:"ip_whitelist"`
	FailedLockout LockoutConfig   `yaml:"failed_login_lockout"`
}

// RateLimitConfig describes request rate limiting settings.
type RateLimitConfig struct {
	Enabled           bool `yaml:"enabled"`
	RequestsPerMinute int  `yaml:"requests_per_minute"`
	BurstSize         int  `yaml:"burst_size"`
}

// LockoutConfig describes failed-login lockout settings.
type LockoutConfig struct {
	Enabled         bool   `yaml:"enabled"`
	MaxAttempts     int    `yaml:"max_attempts"`
	LockoutDuration string `yaml:"lockout_duration"`
}

// ServerConfig holds the listen address and TLS settings.
type ServerConfig struct {
	Address string    `yaml:"address"`
	TLS     TLSConfig `yaml:"tls"`
}

// TLSConfig holds the certificate and key used when TLS is enabled.
type TLSConfig struct {
	Enabled  bool   `yaml:"enabled"`
	CertFile string `yaml:"cert_file"`
	KeyFile  string `yaml:"key_file"`
}

// LoadConfig reads the YAML file at path and decodes it into a Config.
// It returns an error if the file cannot be read or is not valid YAML.
func LoadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading config %q: %w", path, err)
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parsing config %q: %w", path, err)
	}
	return &cfg, nil
}

// schemaPathFor returns the schema file for a supported database type and
// false when the type is not supported.
func schemaPathFor(dbType string) (string, bool) {
	switch dbType {
	case "sqlite":
		return "internal/storage/schema.sql", true
	case "postgres", "postgresql":
		return "internal/storage/schema_postgres.sql", true
	}
	return "", false
}

// dbDisplayName returns a human-readable name for a database type, falling
// back to the raw type string for anything unrecognised.
func dbDisplayName(dbType string) string {
	switch dbType {
	case "sqlite":
		return "SQLite"
	case "postgres", "postgresql":
		return "PostgreSQL"
	}
	return dbType
}

// dbStatusHandler serves /db-status. With a nil db it answers 503; otherwise
// it records a test metric to verify connectivity and reports the outcome as
// JSON (HTTP 200 in both the "connected" and "error" cases).
func dbStatusHandler(db *storage.DB, dbCfg DatabaseConfig) http.HandlerFunc {
	type status struct {
		Status  string `json:"status"`
		Type    string `json:"type"`
		Path    string `json:"path"`
		Message string `json:"message"`
	}

	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if db == nil {
			w.WriteHeader(http.StatusServiceUnavailable)
			fmt.Fprint(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
			return
		}

		result := status{
			Status:  "connected",
			Type:    dbCfg.Type,
			Message: dbDisplayName(dbCfg.Type) + " database is operational",
		}
		// Only SQLite's connection value is a plain file path; connection
		// strings for other backends may embed credentials, so they are
		// not echoed to clients.
		if dbCfg.Type == "sqlite" {
			result.Path = dbCfg.Connection
		}

		// Run a trivial write to verify connectivity.
		if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
			result.Status = "error"
			result.Message = fmt.Sprintf("Database query failed: %v", err)
		}

		payload, err := json.Marshal(result)
		if err != nil {
			http.Error(w, `{"status":"error","message":"failed to encode status"}`, http.StatusInternalServerError)
			return
		}
		// A write error here means the client went away; nothing to recover.
		_, _ = w.Write(payload)
	}
}

// main runs the server and exits with its status code. All work happens in
// run so that deferred cleanup executes before the process exits.
func main() {
	os.Exit(run())
}

// run configures and serves the API until a shutdown signal arrives or the
// HTTP server fails. It returns 0 after a clean shutdown and 1 when the
// server fails or the graceful shutdown itself reports an error.
func run() int {
	// Parse flags
	configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
	apiKey := flag.String("api-key", "", "API key for authentication")
	flag.Parse()
	_ = apiKey // Will be used for auth later

	// Load config
	resolvedConfig, err := config.ResolveConfigPath(*configFile)
	if err != nil {
		log.Fatalf("Failed to resolve config: %v", err)
	}
	cfg, err := LoadConfig(resolvedConfig)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Ensure log directory exists
	if cfg.Logging.File != "" {
		logDir := filepath.Dir(cfg.Logging.File)
		log.Printf("Creating log directory: %s", logDir)
		if err := os.MkdirAll(logDir, 0755); err != nil {
			log.Fatalf("Failed to create log directory: %v", err)
		}
	}

	// Setup logging
	logger := logging.NewLoggerFromConfig(cfg.Logging)
	ctx := logging.EnsureTrace(context.Background())
	logger = logger.Component(ctx, "api-server")

	// Setup experiment manager
	basePath := cfg.BasePath
	if basePath == "" {
		basePath = "/tmp/ml-experiments"
	}
	expManager := experiment.NewManager(basePath)
	log.Printf("Initializing experiment manager with base_path: %s", basePath)
	if err := expManager.Initialize(); err != nil {
		logger.Fatal("failed to initialize experiment manager", "error", err)
	}
	logger.Info("experiment manager initialized", "base_path", basePath)

	// Setup auth; a nil config is passed to the WebSocket handler when
	// authentication is disabled.
	var authCfg *auth.AuthConfig
	if cfg.Auth.Enabled {
		authCfg = &cfg.Auth
		logger.Info("authentication enabled")
	}

	mux := http.NewServeMux()

	// Convert API keys from map to slice for security middleware.
	apiKeys := make([]string, 0, len(cfg.Auth.APIKeys))
	for username := range cfg.Auth.APIKeys {
		// For now, use username as the key (in production, this should be the actual API key)
		apiKeys = append(apiKeys, string(username))
	}
	sec := middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"))

	// Setup TaskQueue
	queueCfg := queue.Config{
		RedisAddr:     cfg.Redis.Addr,
		RedisPassword: cfg.Redis.Password,
		RedisDB:       cfg.Redis.DB,
	}
	if queueCfg.RedisAddr == "" {
		queueCfg.RedisAddr = config.DefaultRedisAddr
	}
	// Support URL format for Redis
	if cfg.Redis.URL != "" {
		queueCfg.RedisAddr = cfg.Redis.URL
	}

	taskQueue, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		// We continue without queue, but queue operations will fail.
		logger.Error("failed to initialize task queue", "error", err)
	} else {
		logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
		defer func() {
			logger.Info("stopping task queue...")
			if err := taskQueue.Close(); err != nil {
				logger.Error("failed to stop task queue", "error", err)
			} else {
				logger.Info("task queue stopped")
			}
		}()
	}

	// Setup database if configured. db stays nil whenever any step fails so
	// that /db-status reports the database as unavailable.
	var db *storage.DB
	if cfg.Database.Type != "" {
		db, err = storage.NewDB(storage.DBConfig{
			Type:       cfg.Database.Type,
			Connection: cfg.Database.Connection,
			Host:       cfg.Database.Host,
			Port:       cfg.Database.Port,
			Username:   cfg.Database.Username,
			Password:   cfg.Database.Password,
			Database:   cfg.Database.Database,
		})
		if err != nil {
			logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err)
			db = nil
		} else {
			ready := false
			if schemaPath, ok := schemaPathFor(cfg.Database.Type); !ok {
				logger.Error("unsupported database type", "type", cfg.Database.Type)
			} else if schema, err := os.ReadFile(schemaPath); err != nil {
				logger.Error("failed to read database schema file", "path", schemaPath, "error", err)
			} else if err := db.Initialize(string(schema)); err != nil {
				logger.Error("failed to initialize database schema", "error", err)
			} else {
				ready = true
			}

			if ready {
				logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection)
				defer func() {
					logger.Info("closing database connection...")
					if err := db.Close(); err != nil {
						logger.Error("failed to close database", "error", err)
					} else {
						logger.Info("database connection closed")
					}
				}()
			} else {
				if err := db.Close(); err != nil {
					logger.Error("failed to close database", "error", err)
				}
				db = nil
			}
		}
	}

	// Setup WebSocket handler with authentication
	wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)

	// WebSocket endpoint - no middleware to avoid hijacking issues
	mux.Handle("/ws", wsHandler)

	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, "OK\n")
	})

	// Database status endpoint
	mux.HandleFunc("/db-status", dbStatusHandler(db, cfg.Database))

	// Build the middleware chain once at startup rather than on every
	// request, so any state held by the middleware is shared across requests.
	secured := sec.RateLimit(mux)
	secured = middleware.SecurityHeaders(secured)
	secured = middleware.CORS(secured)
	secured = middleware.RequestTimeout(30 * time.Second)(secured)
	// Audit logger and IP whitelist apply only to non-WebSocket routes.
	secured = middleware.AuditLogger(secured)
	if len(cfg.Security.IPWhitelist) > 0 {
		secured = sec.IPWhitelist(cfg.Security.IPWhitelist)(secured)
	}

	// Route /ws straight to the mux; everything else goes through the chain.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/ws" {
			mux.ServeHTTP(w, r)
			return
		}
		secured.ServeHTTP(w, r)
	})

	server := &http.Server{
		Addr:         cfg.Server.Address,
		Handler:      handler,
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	if !cfg.Server.TLS.Enabled {
		logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
	}

	// Start the server in a goroutine. Only genuine failures are reported;
	// http.ErrServerClosed is the expected result of Shutdown. The channel is
	// buffered so the goroutine never blocks if nobody is listening anymore.
	serverErr := make(chan error, 1)
	go func() {
		var err error
		scheme := "HTTP"
		if cfg.Server.TLS.Enabled {
			scheme = "HTTPS"
			logger.Info("starting HTTPS server", "address", cfg.Server.Address)
			err = server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
		} else {
			logger.Info("starting HTTP server", "address", cfg.Server.Address)
			err = server.ListenAndServe()
		}
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error(scheme+" server failed", "error", err)
			serverErr <- err
		}
	}()

	// Setup graceful shutdown
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	select {
	case <-serverErr:
		// The server could not start or died; the error is already logged.
		// Returning lets the deferred queue/database cleanup run.
		return 1
	case sig := <-sigChan:
		logger.Info("received shutdown signal", "signal", sig)
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	exitCode := 0
	logger.Info("shutting down http server...")
	if err := server.Shutdown(shutdownCtx); err != nil {
		logger.Error("server shutdown error", "error", err)
		exitCode = 1
	} else {
		logger.Info("http server shutdown complete")
	}

	logger.Info("api server stopped")
	return exitCode
}