fetch_ml/cmd/api-server/main.go
Jeremie Fraeys 803677be57 feat: implement Go backend with comprehensive API and internal packages
- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
2025-12-04 16:53:53 -05:00

363 lines
10 KiB
Go

package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/middleware"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"gopkg.in/yaml.v3"
)
// Config is the API server's top-level configuration as decoded from YAML by
// LoadConfig. Its layout matches the worker's configuration file.
type Config struct {
	BasePath string          `yaml:"base_path"`
	Auth     auth.AuthConfig `yaml:"auth"`
	Server   ServerConfig    `yaml:"server"`
	Security SecurityConfig  `yaml:"security"`
	Redis    RedisConfig     `yaml:"redis"`
	Database DatabaseConfig  `yaml:"database"`
	Logging  logging.Config  `yaml:"logging"`
}
type RedisConfig struct {
Addr string `yaml:"addr"`
Password string `yaml:"password"`
DB int `yaml:"db"`
URL string `yaml:"url"`
}
type DatabaseConfig struct {
Type string `yaml:"type"`
Connection string `yaml:"connection"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Username string `yaml:"username"`
Password string `yaml:"password"`
Database string `yaml:"database"`
}
type SecurityConfig struct {
RateLimit RateLimitConfig `yaml:"rate_limit"`
IPWhitelist []string `yaml:"ip_whitelist"`
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
}
type RateLimitConfig struct {
Enabled bool `yaml:"enabled"`
RequestsPerMinute int `yaml:"requests_per_minute"`
BurstSize int `yaml:"burst_size"`
}
type LockoutConfig struct {
Enabled bool `yaml:"enabled"`
MaxAttempts int `yaml:"max_attempts"`
LockoutDuration string `yaml:"lockout_duration"`
}
type ServerConfig struct {
Address string `yaml:"address"`
TLS TLSConfig `yaml:"tls"`
}
type TLSConfig struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
}
// LoadConfig reads the YAML file at path and decodes it into a Config.
// Keys that Config does not declare are ignored. Errors are wrapped with %w,
// so callers can still match the underlying cause (e.g. errors.Is with
// fs.ErrNotExist).
func LoadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		// The *PathError from os.ReadFile already names the file.
		return nil, fmt.Errorf("reading config: %w", err)
	}
	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		// YAML errors only carry line numbers, so add the file name.
		return nil, fmt.Errorf("parsing config %q: %w", path, err)
	}
	return &cfg, nil
}
func main() {
// Parse flags
configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
apiKey := flag.String("api-key", "", "API key for authentication")
flag.Parse()
// Load config
resolvedConfig, err := config.ResolveConfigPath(*configFile)
if err != nil {
log.Fatalf("Failed to resolve config: %v", err)
}
cfg, err := LoadConfig(resolvedConfig)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Ensure log directory exists
if cfg.Logging.File != "" {
logDir := filepath.Dir(cfg.Logging.File)
log.Printf("Creating log directory: %s", logDir)
if err := os.MkdirAll(logDir, 0755); err != nil {
log.Fatalf("Failed to create log directory: %v", err)
}
}
// Setup logging
logger := logging.NewLoggerFromConfig(cfg.Logging)
ctx := logging.EnsureTrace(context.Background())
logger = logger.Component(ctx, "api-server")
// Setup experiment manager
basePath := cfg.BasePath
if basePath == "" {
basePath = "/tmp/ml-experiments"
}
expManager := experiment.NewManager(basePath)
log.Printf("Initializing experiment manager with base_path: %s", basePath)
if err := expManager.Initialize(); err != nil {
logger.Fatal("failed to initialize experiment manager", "error", err)
}
logger.Info("experiment manager initialized", "base_path", basePath)
// Setup auth
var authCfg *auth.AuthConfig
if cfg.Auth.Enabled {
authCfg = &cfg.Auth
logger.Info("authentication enabled")
}
// Setup HTTP server with security middleware
mux := http.NewServeMux()
// Convert API keys from map to slice for security middleware
apiKeys := make([]string, 0, len(cfg.Auth.APIKeys))
for username := range cfg.Auth.APIKeys {
// For now, use username as the key (in production, this should be the actual API key)
apiKeys = append(apiKeys, string(username))
}
// Create security middleware
sec := middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"))
// Setup TaskQueue
queueCfg := queue.Config{
RedisAddr: cfg.Redis.Addr,
RedisPassword: cfg.Redis.Password,
RedisDB: cfg.Redis.DB,
}
if queueCfg.RedisAddr == "" {
queueCfg.RedisAddr = config.DefaultRedisAddr
}
// Support URL format for Redis
if cfg.Redis.URL != "" {
queueCfg.RedisAddr = cfg.Redis.URL
}
taskQueue, err := queue.NewTaskQueue(queueCfg)
if err != nil {
logger.Error("failed to initialize task queue", "error", err)
// We continue without queue, but queue operations will fail
} else {
logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
defer func() {
logger.Info("stopping task queue...")
if err := taskQueue.Close(); err != nil {
logger.Error("failed to stop task queue", "error", err)
} else {
logger.Info("task queue stopped")
}
}()
}
// Setup database if configured
var db *storage.DB
if cfg.Database.Type != "" {
dbConfig := storage.DBConfig{
Type: cfg.Database.Type,
Connection: cfg.Database.Connection,
Host: cfg.Database.Host,
Port: cfg.Database.Port,
Username: cfg.Database.Username,
Password: cfg.Database.Password,
Database: cfg.Database.Database,
}
db, err = storage.NewDB(dbConfig)
if err != nil {
logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err)
} else {
// Load appropriate database schema
var schemaPath string
if cfg.Database.Type == "sqlite" {
schemaPath = "internal/storage/schema.sql"
} else if cfg.Database.Type == "postgres" || cfg.Database.Type == "postgresql" {
schemaPath = "internal/storage/schema_postgres.sql"
} else {
logger.Error("unsupported database type", "type", cfg.Database.Type)
db.Close()
db = nil
}
if db != nil && schemaPath != "" {
schema, err := os.ReadFile(schemaPath)
if err != nil {
logger.Error("failed to read database schema file", "path", schemaPath, "error", err)
db.Close()
db = nil
} else {
if err := db.Initialize(string(schema)); err != nil {
logger.Error("failed to initialize database schema", "error", err)
db.Close()
db = nil
} else {
logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection)
defer func() {
logger.Info("closing database connection...")
if err := db.Close(); err != nil {
logger.Error("failed to close database", "error", err)
} else {
logger.Info("database connection closed")
}
}()
}
}
}
}
}
// Setup WebSocket handler with authentication
wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)
// WebSocket endpoint - no middleware to avoid hijacking issues
mux.Handle("/ws", wsHandler)
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "OK\n")
})
// Database status endpoint
mux.HandleFunc("/db-status", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if db != nil {
// Test database connection with a simple query
var result struct {
Status string `json:"status"`
Type string `json:"type"`
Path string `json:"path"`
Message string `json:"message"`
}
result.Status = "connected"
result.Type = "sqlite"
result.Path = cfg.Database.Connection
result.Message = "SQLite database is operational"
// Test a simple query to verify connectivity
if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
result.Status = "error"
result.Message = fmt.Sprintf("Database query failed: %v", err)
}
jsonBytes, _ := json.Marshal(result)
w.Write(jsonBytes)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
}
})
// Apply security middleware to all routes except WebSocket
// Create separate handlers for WebSocket vs other routes
var finalHandler http.Handler = mux
// Wrap non-websocket routes with security middleware
finalHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ws" {
mux.ServeHTTP(w, r)
} else {
// Apply middleware chain for non-WebSocket routes
handler := sec.RateLimit(mux)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(handler)
handler = middleware.RequestTimeout(30 * time.Second)(handler)
// Apply audit logger and IP whitelist only to non-WebSocket routes
handler = middleware.AuditLogger(handler)
if len(cfg.Security.IPWhitelist) > 0 {
handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler)
}
handler.ServeHTTP(w, r)
}
})
var handler http.Handler = finalHandler
server := &http.Server{
Addr: cfg.Server.Address,
Handler: handler,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
if !cfg.Server.TLS.Enabled {
logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
}
// Start server in goroutine
go func() {
// Setup TLS if configured
if cfg.Server.TLS.Enabled {
logger.Info("starting HTTPS server", "address", cfg.Server.Address)
if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed {
logger.Error("HTTPS server failed", "error", err)
}
} else {
logger.Info("starting HTTP server", "address", cfg.Server.Address)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server failed", "error", err)
}
}
os.Exit(1)
}()
// Setup graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigChan
logger.Info("received shutdown signal", "signal", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
logger.Info("shutting down http server...")
if err := server.Shutdown(ctx); err != nil {
logger.Error("server shutdown error", "error", err)
} else {
logger.Info("http server shutdown complete")
}
logger.Info("api server stopped")
_ = expManager // Use expManager to avoid unused warning
_ = apiKey // Will be used for auth later
}