- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
363 lines
10 KiB
Go
363 lines
10 KiB
Go
package main
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/middleware"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
	"gopkg.in/yaml.v3"
)
|
|
|
|
// Config structure matching worker config
|
|
type Config struct {
|
|
BasePath string `yaml:"base_path"`
|
|
Auth auth.AuthConfig `yaml:"auth"`
|
|
Server ServerConfig `yaml:"server"`
|
|
Security SecurityConfig `yaml:"security"`
|
|
Redis RedisConfig `yaml:"redis"`
|
|
Database DatabaseConfig `yaml:"database"`
|
|
Logging logging.Config `yaml:"logging"`
|
|
}
|
|
|
|
// RedisConfig holds the Redis connection settings for the task queue, read
// from the "redis" section of the config file.
//
// Addr, Password and DB are passed to the queue as-is; main substitutes
// config.DefaultRedisAddr when Addr is empty. When URL is non-empty it
// replaces Addr as the address handed to the queue.
type RedisConfig struct {
	Addr     string `yaml:"addr"`
	Password string `yaml:"password"`
	DB       int    `yaml:"db"`
	URL      string `yaml:"url"`
}
|
|
|
|
// DatabaseConfig holds the optional database settings, read from the
// "database" section of the config file.
//
// An empty Type disables the database entirely. main recognises "sqlite",
// "postgres" and "postgresql" when choosing a schema file; any other value is
// rejected at startup. All fields are copied unchanged into storage.DBConfig.
type DatabaseConfig struct {
	Type       string `yaml:"type"`
	Connection string `yaml:"connection"`
	Host       string `yaml:"host"`
	Port       int    `yaml:"port"`
	Username   string `yaml:"username"`
	Password   string `yaml:"password"`
	Database   string `yaml:"database"`
}
|
|
|
|
type SecurityConfig struct {
|
|
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
|
IPWhitelist []string `yaml:"ip_whitelist"`
|
|
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
|
|
}
|
|
|
|
// RateLimitConfig describes request rate limiting, read from the
// "security.rate_limit" section of the config file.
//
// NOTE(review): nothing in this file reads these values; presumably they are
// intended for the rate-limit middleware — confirm against its constructor.
type RateLimitConfig struct {
	Enabled           bool `yaml:"enabled"`
	RequestsPerMinute int  `yaml:"requests_per_minute"`
	BurstSize         int  `yaml:"burst_size"`
}
|
|
|
|
// LockoutConfig describes the failed-login lockout policy, read from the
// "security.failed_login_lockout" section of the config file.
//
// LockoutDuration is kept as the raw string from the file; it is not parsed
// in this file (presumably a Go duration such as "15m" — TODO confirm).
type LockoutConfig struct {
	Enabled         bool   `yaml:"enabled"`
	MaxAttempts     int    `yaml:"max_attempts"`
	LockoutDuration string `yaml:"lockout_duration"`
}
|
|
|
|
type ServerConfig struct {
|
|
Address string `yaml:"address"`
|
|
TLS TLSConfig `yaml:"tls"`
|
|
}
|
|
|
|
// TLSConfig holds the TLS settings, read from the "server.tls" section of the
// config file.
//
// When Enabled is true the server is started with ListenAndServeTLS using
// CertFile and KeyFile; otherwise it serves plain HTTP and logs a warning.
type TLSConfig struct {
	Enabled  bool   `yaml:"enabled"`
	CertFile string `yaml:"cert_file"`
	KeyFile  string `yaml:"key_file"`
}
|
|
|
|
func LoadConfig(path string) (*Config, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var cfg Config
|
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
func main() {
|
|
// Parse flags
|
|
configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
|
|
apiKey := flag.String("api-key", "", "API key for authentication")
|
|
flag.Parse()
|
|
|
|
// Load config
|
|
resolvedConfig, err := config.ResolveConfigPath(*configFile)
|
|
if err != nil {
|
|
log.Fatalf("Failed to resolve config: %v", err)
|
|
}
|
|
|
|
cfg, err := LoadConfig(resolvedConfig)
|
|
if err != nil {
|
|
log.Fatalf("Failed to load config: %v", err)
|
|
}
|
|
|
|
// Ensure log directory exists
|
|
if cfg.Logging.File != "" {
|
|
logDir := filepath.Dir(cfg.Logging.File)
|
|
log.Printf("Creating log directory: %s", logDir)
|
|
if err := os.MkdirAll(logDir, 0755); err != nil {
|
|
log.Fatalf("Failed to create log directory: %v", err)
|
|
}
|
|
}
|
|
|
|
// Setup logging
|
|
logger := logging.NewLoggerFromConfig(cfg.Logging)
|
|
ctx := logging.EnsureTrace(context.Background())
|
|
logger = logger.Component(ctx, "api-server")
|
|
|
|
// Setup experiment manager
|
|
basePath := cfg.BasePath
|
|
if basePath == "" {
|
|
basePath = "/tmp/ml-experiments"
|
|
}
|
|
expManager := experiment.NewManager(basePath)
|
|
log.Printf("Initializing experiment manager with base_path: %s", basePath)
|
|
if err := expManager.Initialize(); err != nil {
|
|
logger.Fatal("failed to initialize experiment manager", "error", err)
|
|
}
|
|
logger.Info("experiment manager initialized", "base_path", basePath)
|
|
|
|
// Setup auth
|
|
var authCfg *auth.AuthConfig
|
|
if cfg.Auth.Enabled {
|
|
authCfg = &cfg.Auth
|
|
logger.Info("authentication enabled")
|
|
}
|
|
|
|
// Setup HTTP server with security middleware
|
|
mux := http.NewServeMux()
|
|
|
|
// Convert API keys from map to slice for security middleware
|
|
apiKeys := make([]string, 0, len(cfg.Auth.APIKeys))
|
|
for username := range cfg.Auth.APIKeys {
|
|
// For now, use username as the key (in production, this should be the actual API key)
|
|
apiKeys = append(apiKeys, string(username))
|
|
}
|
|
|
|
// Create security middleware
|
|
sec := middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"))
|
|
|
|
// Setup TaskQueue
|
|
queueCfg := queue.Config{
|
|
RedisAddr: cfg.Redis.Addr,
|
|
RedisPassword: cfg.Redis.Password,
|
|
RedisDB: cfg.Redis.DB,
|
|
}
|
|
if queueCfg.RedisAddr == "" {
|
|
queueCfg.RedisAddr = config.DefaultRedisAddr
|
|
}
|
|
// Support URL format for Redis
|
|
if cfg.Redis.URL != "" {
|
|
queueCfg.RedisAddr = cfg.Redis.URL
|
|
}
|
|
|
|
taskQueue, err := queue.NewTaskQueue(queueCfg)
|
|
if err != nil {
|
|
logger.Error("failed to initialize task queue", "error", err)
|
|
// We continue without queue, but queue operations will fail
|
|
} else {
|
|
logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
|
|
defer func() {
|
|
logger.Info("stopping task queue...")
|
|
if err := taskQueue.Close(); err != nil {
|
|
logger.Error("failed to stop task queue", "error", err)
|
|
} else {
|
|
logger.Info("task queue stopped")
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Setup database if configured
|
|
var db *storage.DB
|
|
if cfg.Database.Type != "" {
|
|
dbConfig := storage.DBConfig{
|
|
Type: cfg.Database.Type,
|
|
Connection: cfg.Database.Connection,
|
|
Host: cfg.Database.Host,
|
|
Port: cfg.Database.Port,
|
|
Username: cfg.Database.Username,
|
|
Password: cfg.Database.Password,
|
|
Database: cfg.Database.Database,
|
|
}
|
|
|
|
db, err = storage.NewDB(dbConfig)
|
|
if err != nil {
|
|
logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err)
|
|
} else {
|
|
// Load appropriate database schema
|
|
var schemaPath string
|
|
if cfg.Database.Type == "sqlite" {
|
|
schemaPath = "internal/storage/schema.sql"
|
|
} else if cfg.Database.Type == "postgres" || cfg.Database.Type == "postgresql" {
|
|
schemaPath = "internal/storage/schema_postgres.sql"
|
|
} else {
|
|
logger.Error("unsupported database type", "type", cfg.Database.Type)
|
|
db.Close()
|
|
db = nil
|
|
}
|
|
|
|
if db != nil && schemaPath != "" {
|
|
schema, err := os.ReadFile(schemaPath)
|
|
if err != nil {
|
|
logger.Error("failed to read database schema file", "path", schemaPath, "error", err)
|
|
db.Close()
|
|
db = nil
|
|
} else {
|
|
if err := db.Initialize(string(schema)); err != nil {
|
|
logger.Error("failed to initialize database schema", "error", err)
|
|
db.Close()
|
|
db = nil
|
|
} else {
|
|
logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection)
|
|
defer func() {
|
|
logger.Info("closing database connection...")
|
|
if err := db.Close(); err != nil {
|
|
logger.Error("failed to close database", "error", err)
|
|
} else {
|
|
logger.Info("database connection closed")
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Setup WebSocket handler with authentication
|
|
wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)
|
|
|
|
// WebSocket endpoint - no middleware to avoid hijacking issues
|
|
mux.Handle("/ws", wsHandler)
|
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprintf(w, "OK\n")
|
|
})
|
|
|
|
// Database status endpoint
|
|
mux.HandleFunc("/db-status", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if db != nil {
|
|
// Test database connection with a simple query
|
|
var result struct {
|
|
Status string `json:"status"`
|
|
Type string `json:"type"`
|
|
Path string `json:"path"`
|
|
Message string `json:"message"`
|
|
}
|
|
result.Status = "connected"
|
|
result.Type = "sqlite"
|
|
result.Path = cfg.Database.Connection
|
|
result.Message = "SQLite database is operational"
|
|
|
|
// Test a simple query to verify connectivity
|
|
if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
|
|
result.Status = "error"
|
|
result.Message = fmt.Sprintf("Database query failed: %v", err)
|
|
}
|
|
|
|
jsonBytes, _ := json.Marshal(result)
|
|
w.Write(jsonBytes)
|
|
} else {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
|
|
}
|
|
})
|
|
|
|
// Apply security middleware to all routes except WebSocket
|
|
// Create separate handlers for WebSocket vs other routes
|
|
var finalHandler http.Handler = mux
|
|
|
|
// Wrap non-websocket routes with security middleware
|
|
finalHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/ws" {
|
|
mux.ServeHTTP(w, r)
|
|
} else {
|
|
// Apply middleware chain for non-WebSocket routes
|
|
handler := sec.RateLimit(mux)
|
|
handler = middleware.SecurityHeaders(handler)
|
|
handler = middleware.CORS(handler)
|
|
handler = middleware.RequestTimeout(30 * time.Second)(handler)
|
|
|
|
// Apply audit logger and IP whitelist only to non-WebSocket routes
|
|
handler = middleware.AuditLogger(handler)
|
|
if len(cfg.Security.IPWhitelist) > 0 {
|
|
handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler)
|
|
}
|
|
|
|
handler.ServeHTTP(w, r)
|
|
}
|
|
})
|
|
|
|
var handler http.Handler = finalHandler
|
|
|
|
server := &http.Server{
|
|
Addr: cfg.Server.Address,
|
|
Handler: handler,
|
|
ReadTimeout: 15 * time.Second,
|
|
WriteTimeout: 15 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
}
|
|
|
|
if !cfg.Server.TLS.Enabled {
|
|
logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
|
|
}
|
|
|
|
// Start server in goroutine
|
|
go func() {
|
|
// Setup TLS if configured
|
|
if cfg.Server.TLS.Enabled {
|
|
logger.Info("starting HTTPS server", "address", cfg.Server.Address)
|
|
if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed {
|
|
logger.Error("HTTPS server failed", "error", err)
|
|
}
|
|
} else {
|
|
logger.Info("starting HTTP server", "address", cfg.Server.Address)
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
logger.Error("HTTP server failed", "error", err)
|
|
}
|
|
}
|
|
os.Exit(1)
|
|
}()
|
|
|
|
// Setup graceful shutdown
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
sig := <-sigChan
|
|
logger.Info("received shutdown signal", "signal", sig)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
logger.Info("shutting down http server...")
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
logger.Error("server shutdown error", "error", err)
|
|
} else {
|
|
logger.Info("http server shutdown complete")
|
|
}
|
|
|
|
logger.Info("api server stopped")
|
|
|
|
_ = expManager // Use expManager to avoid unused warning
|
|
_ = apiKey // Will be used for auth later
|
|
}
|