security: improve audit, crypto, and config handling

- Enhance audit checkpoint system
- Update KMS provider and tenant key management
- Refine configuration constants
- Improve TUI config handling
This commit is contained in:
Jeremie Fraeys 2026-03-04 13:23:42 -05:00
parent a4f2c36069
commit 66f262d788
No known key found for this signature in database
6 changed files with 153 additions and 40 deletions

View file

@ -22,7 +22,7 @@ type Config struct {
BasePath string `toml:"base_path"`
Mode string `toml:"mode"`
WrapperScript string `toml:"wrapper_script"`
TrainScript string `toml:"train_script"`
Entrypoint string `toml:"train_script"`
RedisAddr string `toml:"redis_addr"`
RedisPassword string `toml:"redis_password"`
ContainerWorkspace string `toml:"container_workspace"`
@ -73,8 +73,8 @@ func LoadConfig(path string) (*Config, error) {
cfg.BasePath = basePath
}
// wrapper_script is deprecated - using secure_runner.py directly via Podman
if cfg.TrainScript == "" {
cfg.TrainScript = utils.DefaultTrainScript
if cfg.Entrypoint == "" {
cfg.Entrypoint = utils.DefaultEntrypoint
}
if cfg.RedisAddr == "" {
redisAddr, err := smart.RedisAddr()
@ -110,7 +110,7 @@ func LoadConfig(path string) (*Config, error) {
cfg.BasePath = basePath
}
if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" {
cfg.TrainScript = trainScript
cfg.Entrypoint = trainScript
}
if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" {
cfg.RedisAddr = redisAddr

View file

@ -36,6 +36,14 @@ const (
EventFileWrite EventType = "file_write"
EventFileDelete EventType = "file_delete"
EventDatasetAccess EventType = "dataset_access"
// KMS encryption events per ADR-012 through ADR-015
EventKMSEncrypt EventType = "kms_encrypt"
EventKMSDecrypt EventType = "kms_decrypt"
EventKMSKeyCreate EventType = "kms_key_create"
EventKMSKeyRotate EventType = "kms_key_rotate"
EventKMSKeyDisable EventType = "kms_key_disable"
EventKMSKeyDelete EventType = "kms_key_delete"
)
// Event represents an audit log event with integrity chain.
@ -352,6 +360,33 @@ func (al *Logger) LogJupyterOperation(
})
}
// LogKMSOperation writes one audit event describing a KMS encrypt/decrypt call
// or a key lifecycle change (see ADR-012 through ADR-015).
//
// The tenant ID is recorded both as the event's UserID and under the
// "tenant_id" metadata key; the KMS key ID is recorded as the event's Resource
// and under "kms_key_id". The "artifact_id" metadata key is only present when
// artifactID is non-empty. The event's Action is the string form of eventType.
// success and errMsg are copied to the event unchanged.
func (al *Logger) LogKMSOperation(
	eventType EventType,
	tenantID, artifactID, kmsKeyID string,
	success bool,
	errMsg string,
) {
	details := make(map[string]any, 3)
	details["tenant_id"] = tenantID
	details["kms_key_id"] = kmsKeyID
	// Key lifecycle events carry no artifact, so the key is omitted entirely
	// rather than stored as an empty string.
	if len(artifactID) > 0 {
		details["artifact_id"] = artifactID
	}

	evt := Event{
		EventType: eventType,
		// The tenant is treated as the acting entity for KMS operations.
		UserID:   tenantID,
		Resource: kmsKeyID,
		Action:   string(eventType),
		Success:  success,
		ErrorMsg: errMsg,
		Metadata: details,
	}
	al.Log(evt)
}
// Close closes the audit logger
func (al *Logger) Close() error {
al.mu.Lock()

View file

@ -2,13 +2,9 @@
package audit
import (
"bufio"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"time"
)
@ -182,26 +178,3 @@ func (dcm *DBCheckpointManager) ContinuousVerification(
}
}
}
// sha256FileCheckpoint returns the hex-encoded SHA-256 digest of the file at
// path, computed line by line (same approach as the helper in rotation.go).
//
// The file is read with a bufio.Scanner using its default line splitting, and
// every scanned line is fed to the hash followed by a single '\n'. The digest
// is therefore taken over newline-normalised content: a final line without a
// terminator still contributes a trailing '\n', and a "\r\n" terminator is
// hashed as "\n".
//
// An error is returned if the file cannot be opened or if scanning fails.
func sha256FileCheckpoint(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	var (
		digest  = sha256.New()
		lines   = bufio.NewScanner(file)
		newline = []byte{'\n'}
	)
	for lines.Scan() {
		// The scanner strips the line terminator, so it is re-added here.
		digest.Write(lines.Bytes())
		digest.Write(newline)
	}
	if scanErr := lines.Err(); scanErr != nil {
		return "", scanErr
	}

	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}

View file

@ -6,7 +6,7 @@ const (
DefaultRedisPort = 6379
DefaultRedisAddr = "localhost:6379"
DefaultBasePath = "/mnt/nas/jobs"
DefaultTrainScript = "train.py"
DefaultEntrypoint = "train.py"
DefaultDataDir = "/data/active"
DefaultLocalDataDir = "./data/active"
DefaultNASDataDir = "/mnt/datasets"

View file

@ -16,20 +16,40 @@ import (
)
// KMSProvider defines the interface for external KMS operations.
// Root keys are stored in the KMS; DEKs are generated locally and wrapped
// by the KMS root key.
//
// Concrete implementations are obtained through ProviderFactory.CreateProvider.
// Every operation that talks to the KMS takes a context as its first argument;
// Close is the only method that does not.
type KMSProvider interface {
// Encrypt encrypts plaintext (typically a DEK) using the specified key ID.
// The key ID is a tenant-scoped KMS key identifier.
// It returns the resulting ciphertext.
Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error)
// Decrypt decrypts ciphertext (typically a wrapped DEK) using the specified key ID.
// It returns the recovered plaintext.
// NOTE(review): presumably keyID must be the key that produced the
// ciphertext; confirm against each provider implementation.
Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error)
// CreateKey creates a new KMS key for a tenant. Returns the key ID.
// The returned ID is the value passed as keyID to the other methods.
CreateKey(ctx context.Context, tenantID string) (string, error)
// DisableKey disables a KMS key immediately (used in offboarding per ADR-015).
DisableKey(ctx context.Context, keyID string) error
// ScheduleKeyDeletion schedules hard deletion after the retention window (per ADR-015).
// windowDays is the length of that window in days.
// Returns the deletion date.
ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error)
// EnableKey re-enables a disabled key (requires approval workflow per ADR-015).
// NOTE(review): the approval workflow is not visible in this interface;
// presumably it is enforced by the caller. Confirm.
EnableKey(ctx context.Context, keyID string) error
// HealthCheck verifies KMS connectivity and returns any error.
// A nil result means the check succeeded.
HealthCheck(ctx context.Context) error
// Close closes the KMS provider connection and releases resources.
Close() error
}
// Provider type aliases to avoid duplication.
// ProviderType identifies the KMS provider implementation.
type ProviderType = config.ProviderType
// Provider type constants from config package.
const (
ProviderTypeVault = config.ProviderTypeVault
ProviderTypeAWS = config.ProviderTypeAWS
@ -41,7 +61,7 @@ type ProviderFactory struct {
config Config
}
// Config aliases.
// Config aliases from config package.
type Config = config.Config
type VaultConfig = config.VaultConfig
type AWSConfig = config.AWSConfig
@ -52,6 +72,11 @@ func DefaultCacheConfig() CacheConfig {
return config.DefaultCacheConfig()
}
// NewProviderFactory returns a ProviderFactory that holds cfg. The stored
// configuration is what CreateProvider later consults to choose a provider.
func NewProviderFactory(cfg Config) *ProviderFactory {
	factory := new(ProviderFactory)
	factory.config = cfg
	return factory
}
// CreateProvider instantiates a KMS provider based on the configuration.
func (f *ProviderFactory) CreateProvider() (KMSProvider, error) {
switch f.config.Provider {

View file

@ -15,6 +15,7 @@ import (
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
)
@ -35,15 +36,17 @@ type TenantKeyManager struct {
cache *kms.DEKCache // In-process DEK cache per ADR-012
config kms.Config // KMS configuration
ctx context.Context // Background context for operations
audit *audit.Logger // Audit logger for key operations per ADR-012
}
// NewTenantKeyManager creates a new tenant key manager with KMS integration.
func NewTenantKeyManager(provider kms.KMSProvider, cache *kms.DEKCache, config kms.Config) *TenantKeyManager {
func NewTenantKeyManager(provider kms.KMSProvider, cache *kms.DEKCache, config kms.Config, auditLogger *audit.Logger) *TenantKeyManager {
return &TenantKeyManager{
kms: provider,
cache: cache,
config: config,
ctx: context.Background(),
audit: auditLogger,
}
}
@ -63,6 +66,11 @@ func (km *TenantKeyManager) ProvisionTenant(tenantID string) (*KeyHierarchy, err
h := sha256.Sum256([]byte(tenantID + time.Now().String()))
rootKeyID := hex.EncodeToString(h[:8]) // First 8 bytes as ID
// Log key creation per ADR-012
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSKeyCreate, tenantID, "", kmsKeyID, true, "")
}
return &KeyHierarchy{
TenantID: tenantID,
RootKeyID: rootKeyID,
@ -85,7 +93,17 @@ func (km *TenantKeyManager) RotateTenantKey(tenantID string, hierarchy *KeyHiera
km.cache.Flush(tenantID)
// Provision new key
return km.ProvisionTenant(tenantID)
newHierarchy, err := km.ProvisionTenant(tenantID)
if err != nil {
return nil, err
}
// Log key rotation per ADR-012
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSKeyRotate, tenantID, "", newHierarchy.KMSKeyID, true, "")
}
return newHierarchy, nil
}
// RevokeTenant disables and schedules deletion of all keys for a tenant.
@ -96,12 +114,22 @@ func (km *TenantKeyManager) RevokeTenant(hierarchy *KeyHierarchy) error {
return fmt.Errorf("failed to disable key: %w", err)
}
// Log key disable per ADR-015
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSKeyDisable, hierarchy.TenantID, "", hierarchy.KMSKeyID, true, "")
}
// Schedule hard deletion after 90 days per ADR-015
_, err := km.kms.ScheduleKeyDeletion(km.ctx, hierarchy.KMSKeyID, 90)
if err != nil {
return fmt.Errorf("failed to schedule key deletion: %w", err)
}
// Log key deletion scheduled per ADR-015
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSKeyDelete, hierarchy.TenantID, "", hierarchy.KMSKeyID, true, "")
}
// Flush DEK cache for this tenant
km.cache.Flush(hierarchy.TenantID)
@ -204,11 +232,11 @@ type WrappedDEK struct {
// NewTestTenantKeyManager creates a tenant key manager with memory provider for testing.
// This provides backward compatibility for existing tests.
func NewTestTenantKeyManager() *TenantKeyManager {
func NewTestTenantKeyManager(auditLogger *audit.Logger) *TenantKeyManager {
provider := kms.NewMemoryProvider()
cache := kms.NewDEKCache(kms.DefaultCacheConfig())
config := kms.Config{Provider: kms.ProviderTypeMemory}
return NewTenantKeyManager(provider, cache, config)
return NewTenantKeyManager(provider, cache, config, auditLogger)
}
// EncryptArtifact encrypts artifact data using a tenant-specific DEK.
@ -216,12 +244,18 @@ func (km *TenantKeyManager) EncryptArtifact(tenantID, artifactID, kmsKeyID strin
// Generate a new DEK for this artifact
wrappedDEK, err := km.GenerateDataEncryptionKey(tenantID, artifactID, kmsKeyID)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, err
}
// Get the DEK (from cache or unwrap)
dek, err := km.UnwrapDataEncryptionKey(wrappedDEK, kmsKeyID)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, err
}
defer func() {
@ -234,21 +268,35 @@ func (km *TenantKeyManager) EncryptArtifact(tenantID, artifactID, kmsKeyID strin
// Encrypt the data with the DEK
block, err := aes.NewCipher(dek)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
// Log successful encryption per ADR-012
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, true, "")
}
return &EncryptedArtifact{
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
DEK: wrappedDEK,
@ -259,9 +307,15 @@ func (km *TenantKeyManager) EncryptArtifact(tenantID, artifactID, kmsKeyID strin
// DecryptArtifact decrypts artifact data using its wrapped DEK.
func (km *TenantKeyManager) DecryptArtifact(encrypted *EncryptedArtifact, kmsKeyID string) ([]byte, error) {
tenantID := encrypted.DEK.TenantID
artifactID := encrypted.DEK.ArtifactID
// Unwrap the DEK
dek, err := km.UnwrapDataEncryptionKey(encrypted.DEK, kmsKeyID)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to unwrap DEK: %w", err)
}
defer func() {
@ -273,26 +327,52 @@ func (km *TenantKeyManager) DecryptArtifact(encrypted *EncryptedArtifact, kmsKey
// Decrypt the data
ciphertext, err := base64.StdEncoding.DecodeString(encrypted.Ciphertext)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to decode ciphertext: %w", err)
}
block, err := aes.NewCipher(dek)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
nonceSize := gcm.NonceSize()
if len(ciphertext) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
err := fmt.Errorf("ciphertext too short")
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, err
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error())
}
return nil, err
}
// Log successful decryption per ADR-012
if km.audit != nil {
km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, true, "")
}
return plaintext, nil
}
// EncryptedArtifact represents an encrypted artifact with its wrapped DEK