- Fix YAML tags in auth config struct (json -> yaml) - Update CLI configs to use pre-hashed API keys - Remove double hashing in WebSocket client - Fix port mapping (9102 -> 9103) in CLI commands - Update permission keys to use jobs:read, jobs:create, etc. - Clean up all debug logging from CLI and server - All user roles now authenticate correctly: * Admin: Can queue jobs and see all jobs * Researcher: Can queue jobs and see own jobs * Analyst: Can see status (read-only access) Multi-user authentication is now fully functional.
444 lines
12 KiB
Go
444 lines
12 KiB
Go
// Package main implements the fetch_ml API server
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/api"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/middleware"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// Config structure matching worker config.
|
|
type Config struct {
|
|
BasePath string `yaml:"base_path"`
|
|
Auth auth.Config `yaml:"auth"`
|
|
Server ServerConfig `yaml:"server"`
|
|
Security SecurityConfig `yaml:"security"`
|
|
Redis RedisConfig `yaml:"redis"`
|
|
Database DatabaseConfig `yaml:"database"`
|
|
Logging logging.Config `yaml:"logging"`
|
|
Resources config.ResourceConfig `yaml:"resources"`
|
|
}
|
|
|
|
// RedisConfig describes how to reach the Redis instance backing the task
// queue.
type RedisConfig struct {
	// Addr is the Redis address; a default is substituted when it is empty.
	Addr string `yaml:"addr"`
	// Password is the Redis password.
	Password string `yaml:"password"`
	// DB is the Redis database number.
	DB int `yaml:"db"`
	// URL, when non-empty, is used in place of Addr.
	URL string `yaml:"url"`
}
|
|
|
|
// DatabaseConfig describes the database the API server connects to.
type DatabaseConfig struct {
	// Type selects the database backend ("sqlite", "postgres" or
	// "postgresql" have schema files); empty disables the database.
	Type string `yaml:"type"`
	// Connection is the backend-specific connection value.
	Connection string `yaml:"connection"`
	// Host is the database server host.
	Host string `yaml:"host"`
	// Port is the database server port.
	Port int `yaml:"port"`
	// Username is the database login name.
	Username string `yaml:"username"`
	// Password is the database login password.
	Password string `yaml:"password"`
	// Database is the name of the database to use.
	Database string `yaml:"database"`
}
|
|
|
|
// SecurityConfig holds security-related configuration.
|
|
type SecurityConfig struct {
|
|
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
|
IPWhitelist []string `yaml:"ip_whitelist"`
|
|
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
|
|
}
|
|
|
|
// RateLimitConfig describes request rate limiting.
type RateLimitConfig struct {
	// Enabled turns rate limiting on.
	Enabled bool `yaml:"enabled"`
	// RequestsPerMinute is the allowed request rate; values <= 0 leave rate
	// limiting unconfigured even when Enabled is true.
	RequestsPerMinute int `yaml:"requests_per_minute"`
	// BurstSize is the configured burst allowance.
	BurstSize int `yaml:"burst_size"`
}
|
|
|
|
// LockoutConfig describes the lockout applied after repeated failed logins.
type LockoutConfig struct {
	// Enabled turns the lockout on.
	Enabled bool `yaml:"enabled"`
	// MaxAttempts is the configured number of failed attempts.
	MaxAttempts int `yaml:"max_attempts"`
	// LockoutDuration is kept as the raw configured string.
	// NOTE(review): presumably a Go duration string such as "15m" - confirm
	// against the code that consumes it.
	LockoutDuration string `yaml:"lockout_duration"`
}
|
|
|
|
// ServerConfig holds server configuration.
|
|
type ServerConfig struct {
|
|
Address string `yaml:"address"`
|
|
TLS TLSConfig `yaml:"tls"`
|
|
}
|
|
|
|
// TLSConfig describes the TLS settings of the HTTP server.
type TLSConfig struct {
	// Enabled switches the server to HTTPS.
	Enabled bool `yaml:"enabled"`
	// CertFile is the path of the certificate file.
	CertFile string `yaml:"cert_file"`
	// KeyFile is the path of the private key file.
	KeyFile string `yaml:"key_file"`
}
|
|
|
|
// LoadConfig loads configuration from a YAML file.
|
|
func LoadConfig(path string) (*Config, error) {
|
|
data, err := fileutil.SecureFileRead(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var cfg Config
|
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
func main() {
|
|
configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
|
|
apiKey := flag.String("api-key", "", "API key for authentication")
|
|
flag.Parse()
|
|
|
|
cfg, err := loadServerConfig(*configFile)
|
|
if err != nil {
|
|
log.Fatalf("Failed to load config: %v", err)
|
|
}
|
|
|
|
if err := ensureLogDirectory(cfg.Logging); err != nil {
|
|
log.Fatalf("Failed to prepare log directory: %v", err)
|
|
}
|
|
|
|
logger := setupLogger(cfg.Logging)
|
|
|
|
expManager, err := initExperimentManager(cfg.BasePath, logger)
|
|
if err != nil {
|
|
logger.Fatal("failed to initialize experiment manager", "error", err)
|
|
}
|
|
|
|
taskQueue, queueCleanup := initTaskQueue(cfg, logger)
|
|
if queueCleanup != nil {
|
|
defer queueCleanup()
|
|
}
|
|
|
|
db, dbCleanup := initDatabase(cfg, logger)
|
|
if dbCleanup != nil {
|
|
defer dbCleanup()
|
|
}
|
|
|
|
authCfg := buildAuthConfig(cfg.Auth, logger)
|
|
sec := newSecurityMiddleware(cfg)
|
|
|
|
mux := buildHTTPMux(cfg, logger, expManager, taskQueue, authCfg, db)
|
|
finalHandler := wrapWithMiddleware(cfg, sec, mux)
|
|
server := newHTTPServer(cfg, finalHandler)
|
|
|
|
startServer(server, cfg, logger)
|
|
waitForShutdown(server, logger)
|
|
|
|
_ = apiKey // Reserved for future authentication enhancements
|
|
}
|
|
|
|
func loadServerConfig(path string) (*Config, error) {
|
|
resolvedConfig, err := config.ResolveConfigPath(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cfg, err := LoadConfig(resolvedConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cfg.Resources.ApplyDefaults()
|
|
return cfg, nil
|
|
}
|
|
|
|
func ensureLogDirectory(cfg logging.Config) error {
|
|
if cfg.File == "" {
|
|
return nil
|
|
}
|
|
|
|
logDir := filepath.Dir(cfg.File)
|
|
log.Printf("Creating log directory: %s", logDir)
|
|
return os.MkdirAll(logDir, 0750)
|
|
}
|
|
|
|
func setupLogger(cfg logging.Config) *logging.Logger {
|
|
logger := logging.NewLoggerFromConfig(cfg)
|
|
ctx := logging.EnsureTrace(context.Background())
|
|
return logger.Component(ctx, "api-server")
|
|
}
|
|
|
|
func initExperimentManager(basePath string, logger *logging.Logger) (*experiment.Manager, error) {
|
|
if basePath == "" {
|
|
basePath = "/tmp/ml-experiments"
|
|
}
|
|
|
|
expManager := experiment.NewManager(basePath)
|
|
log.Printf("Initializing experiment manager with base_path: %s", basePath)
|
|
if err := expManager.Initialize(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info("experiment manager initialized", "base_path", basePath)
|
|
return expManager, nil
|
|
}
|
|
|
|
func buildAuthConfig(cfg auth.Config, logger *logging.Logger) *auth.Config {
|
|
if !cfg.Enabled {
|
|
return nil
|
|
}
|
|
|
|
logger.Info("authentication enabled")
|
|
return &cfg
|
|
}
|
|
|
|
func newSecurityMiddleware(cfg *Config) *middleware.SecurityMiddleware {
|
|
apiKeys := collectAPIKeys(cfg.Auth.APIKeys)
|
|
rlOpts := buildRateLimitOptions(cfg.Security.RateLimit)
|
|
return middleware.NewSecurityMiddleware(apiKeys, os.Getenv("JWT_SECRET"), rlOpts)
|
|
}
|
|
|
|
func collectAPIKeys(keys map[auth.Username]auth.APIKeyEntry) []string {
|
|
apiKeys := make([]string, 0, len(keys))
|
|
for username := range keys {
|
|
apiKeys = append(apiKeys, string(username))
|
|
}
|
|
return apiKeys
|
|
}
|
|
|
|
func buildRateLimitOptions(cfg RateLimitConfig) *middleware.RateLimitOptions {
|
|
if !cfg.Enabled || cfg.RequestsPerMinute <= 0 {
|
|
return nil
|
|
}
|
|
|
|
return &middleware.RateLimitOptions{
|
|
RequestsPerMinute: cfg.RequestsPerMinute,
|
|
BurstSize: cfg.BurstSize,
|
|
}
|
|
}
|
|
|
|
func initTaskQueue(cfg *Config, logger *logging.Logger) (*queue.TaskQueue, func()) {
|
|
queueCfg := queue.Config{
|
|
RedisAddr: cfg.Redis.Addr,
|
|
RedisPassword: cfg.Redis.Password,
|
|
RedisDB: cfg.Redis.DB,
|
|
}
|
|
if queueCfg.RedisAddr == "" {
|
|
queueCfg.RedisAddr = config.DefaultRedisAddr
|
|
}
|
|
if cfg.Redis.URL != "" {
|
|
queueCfg.RedisAddr = cfg.Redis.URL
|
|
}
|
|
|
|
taskQueue, err := queue.NewTaskQueue(queueCfg)
|
|
if err != nil {
|
|
logger.Error("failed to initialize task queue", "error", err)
|
|
return nil, nil
|
|
}
|
|
|
|
logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
|
|
cleanup := func() {
|
|
logger.Info("stopping task queue...")
|
|
if err := taskQueue.Close(); err != nil {
|
|
logger.Error("failed to stop task queue", "error", err)
|
|
} else {
|
|
logger.Info("task queue stopped")
|
|
}
|
|
}
|
|
return taskQueue, cleanup
|
|
}
|
|
|
|
func initDatabase(cfg *Config, logger *logging.Logger) (*storage.DB, func()) {
|
|
if cfg.Database.Type == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
dbConfig := storage.DBConfig{
|
|
Type: cfg.Database.Type,
|
|
Connection: cfg.Database.Connection,
|
|
Host: cfg.Database.Host,
|
|
Port: cfg.Database.Port,
|
|
Username: cfg.Database.Username,
|
|
Password: cfg.Database.Password,
|
|
Database: cfg.Database.Database,
|
|
}
|
|
|
|
db, err := storage.NewDB(dbConfig)
|
|
if err != nil {
|
|
logger.Error("failed to initialize database", "type", cfg.Database.Type, "error", err)
|
|
return nil, nil
|
|
}
|
|
|
|
schemaPath := schemaPathForDB(cfg.Database.Type)
|
|
if schemaPath == "" {
|
|
logger.Error("unsupported database type", "type", cfg.Database.Type)
|
|
_ = db.Close()
|
|
return nil, nil
|
|
}
|
|
|
|
schema, err := fileutil.SecureFileRead(schemaPath)
|
|
if err != nil {
|
|
logger.Error("failed to read database schema file", "path", schemaPath, "error", err)
|
|
_ = db.Close()
|
|
return nil, nil
|
|
}
|
|
|
|
if err := db.Initialize(string(schema)); err != nil {
|
|
logger.Error("failed to initialize database schema", "error", err)
|
|
_ = db.Close()
|
|
return nil, nil
|
|
}
|
|
|
|
logger.Info("database initialized", "type", cfg.Database.Type, "connection", cfg.Database.Connection)
|
|
cleanup := func() {
|
|
logger.Info("closing database connection...")
|
|
if err := db.Close(); err != nil {
|
|
logger.Error("failed to close database", "error", err)
|
|
} else {
|
|
logger.Info("database connection closed")
|
|
}
|
|
}
|
|
return db, cleanup
|
|
}
|
|
|
|
// schemaPathForDB returns the relative path of the SQL schema file for the
// given database type, or "" when the type is not supported. The match is
// exact and case-sensitive.
func schemaPathForDB(dbType string) string {
	if dbType == "sqlite" {
		return "internal/storage/schema_sqlite.sql"
	}
	if dbType == "postgres" || dbType == "postgresql" {
		return "internal/storage/schema_postgres.sql"
	}
	return ""
}
|
|
|
|
func buildHTTPMux(
|
|
cfg *Config,
|
|
logger *logging.Logger,
|
|
expManager *experiment.Manager,
|
|
taskQueue *queue.TaskQueue,
|
|
authCfg *auth.Config,
|
|
db *storage.DB,
|
|
) *http.ServeMux {
|
|
mux := http.NewServeMux()
|
|
wsHandler := api.NewWSHandler(authCfg, logger, expManager, taskQueue)
|
|
|
|
mux.Handle("/ws", wsHandler)
|
|
mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = fmt.Fprintf(w, "OK\n")
|
|
})
|
|
|
|
mux.HandleFunc("/db-status", func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if db == nil {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
_, _ = fmt.Fprintf(w, `{"status":"disconnected","message":"Database not configured or failed to initialize"}`)
|
|
return
|
|
}
|
|
|
|
var result struct {
|
|
Status string `json:"status"`
|
|
Type string `json:"type"`
|
|
Path string `json:"path"`
|
|
Message string `json:"message"`
|
|
}
|
|
result.Status = "connected"
|
|
result.Type = cfg.Database.Type
|
|
result.Path = cfg.Database.Connection
|
|
result.Message = fmt.Sprintf("%s database is operational", cfg.Database.Type)
|
|
|
|
if err := db.RecordSystemMetric("db_test", "ok"); err != nil {
|
|
result.Status = "error"
|
|
result.Message = fmt.Sprintf("Database query failed: %v", err)
|
|
}
|
|
|
|
jsonBytes, _ := json.Marshal(result)
|
|
_, _ = w.Write(jsonBytes)
|
|
})
|
|
|
|
return mux
|
|
}
|
|
|
|
func wrapWithMiddleware(cfg *Config, sec *middleware.SecurityMiddleware, mux *http.ServeMux) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/ws" {
|
|
mux.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
handler := sec.RateLimit(mux)
|
|
handler = middleware.SecurityHeaders(handler)
|
|
handler = middleware.CORS(handler)
|
|
handler = middleware.RequestTimeout(30 * time.Second)(handler)
|
|
handler = middleware.AuditLogger(handler)
|
|
if len(cfg.Security.IPWhitelist) > 0 {
|
|
handler = sec.IPWhitelist(cfg.Security.IPWhitelist)(handler)
|
|
}
|
|
handler.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func newHTTPServer(cfg *Config, handler http.Handler) *http.Server {
|
|
return &http.Server{
|
|
Addr: cfg.Server.Address,
|
|
Handler: handler,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
}
|
|
|
|
func startServer(server *http.Server, cfg *Config, logger *logging.Logger) {
|
|
if !cfg.Server.TLS.Enabled {
|
|
logger.Warn("TLS disabled for API server; do not use this configuration in production", "address", cfg.Server.Address)
|
|
}
|
|
|
|
go func() {
|
|
if cfg.Server.TLS.Enabled {
|
|
logger.Info("starting HTTPS server", "address", cfg.Server.Address)
|
|
if err := server.ListenAndServeTLS(
|
|
cfg.Server.TLS.CertFile,
|
|
cfg.Server.TLS.KeyFile,
|
|
); err != nil && err != http.ErrServerClosed {
|
|
logger.Error("HTTPS server failed", "error", err)
|
|
}
|
|
} else {
|
|
logger.Info("starting HTTP server", "address", cfg.Server.Address)
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
logger.Error("HTTP server failed", "error", err)
|
|
}
|
|
}
|
|
os.Exit(1)
|
|
}()
|
|
}
|
|
|
|
func waitForShutdown(server *http.Server, logger *logging.Logger) {
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
sig := <-sigChan
|
|
logger.Info("received shutdown signal", "signal", sig)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
logger.Info("shutting down http server...")
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
logger.Error("server shutdown error", "error", err)
|
|
} else {
|
|
logger.Info("http server shutdown complete")
|
|
}
|
|
|
|
logger.Info("api server stopped")
|
|
}
|