diff --git a/cmd/tui/internal/config/config.go b/cmd/tui/internal/config/config.go index 11b274d..0edfd98 100644 --- a/cmd/tui/internal/config/config.go +++ b/cmd/tui/internal/config/config.go @@ -22,7 +22,7 @@ type Config struct { BasePath string `toml:"base_path"` Mode string `toml:"mode"` WrapperScript string `toml:"wrapper_script"` - TrainScript string `toml:"train_script"` + Entrypoint string `toml:"train_script"` RedisAddr string `toml:"redis_addr"` RedisPassword string `toml:"redis_password"` ContainerWorkspace string `toml:"container_workspace"` @@ -73,8 +73,8 @@ func LoadConfig(path string) (*Config, error) { cfg.BasePath = basePath } // wrapper_script is deprecated - using secure_runner.py directly via Podman - if cfg.TrainScript == "" { - cfg.TrainScript = utils.DefaultTrainScript + if cfg.Entrypoint == "" { + cfg.Entrypoint = utils.DefaultEntrypoint } if cfg.RedisAddr == "" { redisAddr, err := smart.RedisAddr() @@ -110,7 +110,7 @@ func LoadConfig(path string) (*Config, error) { cfg.BasePath = basePath } if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" { - cfg.TrainScript = trainScript + cfg.Entrypoint = trainScript } if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" { cfg.RedisAddr = redisAddr diff --git a/internal/audit/audit.go b/internal/audit/audit.go index aca4b4b..705df29 100644 --- a/internal/audit/audit.go +++ b/internal/audit/audit.go @@ -36,6 +36,14 @@ const ( EventFileWrite EventType = "file_write" EventFileDelete EventType = "file_delete" EventDatasetAccess EventType = "dataset_access" + + // KMS encryption events per ADR-012 through ADR-015 + EventKMSEncrypt EventType = "kms_encrypt" + EventKMSDecrypt EventType = "kms_decrypt" + EventKMSKeyCreate EventType = "kms_key_create" + EventKMSKeyRotate EventType = "kms_key_rotate" + EventKMSKeyDisable EventType = "kms_key_disable" + EventKMSKeyDelete EventType = "kms_key_delete" ) // Event represents an audit log event with integrity chain. @@ -352,6 +360,33 @@ func (al *Logger) LogJupyterOperation( }) } +// LogKMSOperation logs a KMS encryption/decryption or key management operation. +// Per ADR-012 through ADR-015: All key operations must be logged with tenant ID. +func (al *Logger) LogKMSOperation( + eventType EventType, + tenantID, artifactID, kmsKeyID string, + success bool, + errMsg string, +) { + metadata := map[string]any{ + "tenant_id": tenantID, + "kms_key_id": kmsKeyID, + } + if artifactID != "" { + metadata["artifact_id"] = artifactID + } + + al.Log(Event{ + EventType: eventType, + UserID: tenantID, // Tenant is the entity performing the operation + Resource: kmsKeyID, + Action: string(eventType), + Success: success, + ErrorMsg: errMsg, + Metadata: metadata, + }) +} + // Close closes the audit logger func (al *Logger) Close() error { al.mu.Lock() diff --git a/internal/audit/checkpoint.go b/internal/audit/checkpoint.go index 3f4027c..c0bbb9c 100644 --- a/internal/audit/checkpoint.go +++ b/internal/audit/checkpoint.go @@ -2,13 +2,9 @@ package audit import ( - "bufio" "context" - "crypto/sha256" "database/sql" - "encoding/hex" "fmt" - "os" "path/filepath" "time" ) @@ -182,26 +178,3 @@ func (dcm *DBCheckpointManager) ContinuousVerification( } } } - -// sha256File computes the SHA256 hash of a file (reused from rotation.go) -func sha256FileCheckpoint(path string) (string, error) { - f, err := os.Open(path) - if err != nil { - return "", err - } - defer f.Close() - - h := sha256.New() - scanner := bufio.NewScanner(f) - for scanner.Scan() { - // Hash the raw line including newline - h.Write(scanner.Bytes()) - h.Write([]byte{'\n'}) - } - - if err := scanner.Err(); err != nil { - return "", err - } - - return hex.EncodeToString(h.Sum(nil)), nil -} diff --git a/internal/config/constants.go b/internal/config/constants.go index e6663f2..6eaf921 100644 --- a/internal/config/constants.go +++ b/internal/config/constants.go @@ -6,7 +6,7 @@ const ( DefaultRedisPort = 6379 DefaultRedisAddr = "localhost:6379" DefaultBasePath = "/mnt/nas/jobs" - DefaultTrainScript = "train.py" + DefaultEntrypoint = "train.py" DefaultDataDir = "/data/active" DefaultLocalDataDir = "./data/active" DefaultNASDataDir = "/mnt/datasets" diff --git a/internal/crypto/kms/provider.go b/internal/crypto/kms/provider.go index f8f3afa..a5b3b13 100644 --- a/internal/crypto/kms/provider.go +++ b/internal/crypto/kms/provider.go @@ -16,20 +16,40 @@ import ( ) // KMSProvider defines the interface for external KMS operations. +// Root keys are stored in the KMS; DEKs are generated locally and wrapped +// by the KMS root key. type KMSProvider interface { + // Encrypt encrypts plaintext (typically a DEK) using the specified key ID. + // The key ID is a tenant-scoped KMS key identifier. Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) + + // Decrypt decrypts ciphertext (typically a wrapped DEK) using the specified key ID. Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) + + // CreateKey creates a new KMS key for a tenant. Returns the key ID. CreateKey(ctx context.Context, tenantID string) (string, error) + + // DisableKey disables a KMS key immediately (used in offboarding per ADR-015). DisableKey(ctx context.Context, keyID string) error + + // ScheduleKeyDeletion schedules hard deletion after the retention window (per ADR-015). + // Returns the deletion date. ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) + + // EnableKey re-enables a disabled key (requires approval workflow per ADR-015). EnableKey(ctx context.Context, keyID string) error + + // HealthCheck verifies KMS connectivity and returns any error. HealthCheck(ctx context.Context) error + + // Close closes the KMS provider connection and releases resources. Close() error } -// Provider type aliases to avoid duplication. +// ProviderType identifies the KMS provider implementation. type ProviderType = config.ProviderType +// Provider type constants from config package. const ( ProviderTypeVault = config.ProviderTypeVault ProviderTypeAWS = config.ProviderTypeAWS @@ -41,7 +61,7 @@ type ProviderFactory struct { config Config } -// Config aliases. +// Config aliases from config package. type Config = config.Config type VaultConfig = config.VaultConfig type AWSConfig = config.AWSConfig @@ -52,6 +72,11 @@ func DefaultCacheConfig() CacheConfig { return config.DefaultCacheConfig() } +// NewProviderFactory creates a new provider factory with the given config. +func NewProviderFactory(cfg Config) *ProviderFactory { + return &ProviderFactory{config: cfg} +} + // CreateProvider instantiates a KMS provider based on the configuration. func (f *ProviderFactory) CreateProvider() (KMSProvider, error) { switch f.config.Provider { diff --git a/internal/crypto/tenant_keys.go b/internal/crypto/tenant_keys.go index e2426be..474bedc 100644 --- a/internal/crypto/tenant_keys.go +++ b/internal/crypto/tenant_keys.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/crypto/kms" ) @@ -35,15 +36,17 @@ type TenantKeyManager struct { cache *kms.DEKCache // In-process DEK cache per ADR-012 config kms.Config // KMS configuration ctx context.Context // Background context for operations + audit *audit.Logger // Audit logger for key operations per ADR-012 } // NewTenantKeyManager creates a new tenant key manager with KMS integration. -func NewTenantKeyManager(provider kms.KMSProvider, cache *kms.DEKCache, config kms.Config) *TenantKeyManager { +func NewTenantKeyManager(provider kms.KMSProvider, cache *kms.DEKCache, config kms.Config, auditLogger *audit.Logger) *TenantKeyManager { return &TenantKeyManager{ kms: provider, cache: cache, config: config, ctx: context.Background(), + audit: auditLogger, } } @@ -63,6 +66,11 @@ func (km *TenantKeyManager) ProvisionTenant(tenantID string) (*KeyHierarchy, err h := sha256.Sum256([]byte(tenantID + time.Now().String())) rootKeyID := hex.EncodeToString(h[:8]) // First 8 bytes as ID + // Log key creation per ADR-012 + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSKeyCreate, tenantID, "", kmsKeyID, true, "") + } + return &KeyHierarchy{ TenantID: tenantID, RootKeyID: rootKeyID, @@ -85,7 +93,17 @@ func (km *TenantKeyManager) RotateTenantKey(tenantID string, hierarchy *KeyHiera km.cache.Flush(tenantID) // Provision new key - return km.ProvisionTenant(tenantID) + newHierarchy, err := km.ProvisionTenant(tenantID) + if err != nil { + return nil, err + } + + // Log key rotation per ADR-012 + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSKeyRotate, tenantID, "", newHierarchy.KMSKeyID, true, "") + } + + return newHierarchy, nil } // RevokeTenant disables and schedules deletion of all keys for a tenant. @@ -96,12 +114,22 @@ func (km *TenantKeyManager) RevokeTenant(hierarchy *KeyHierarchy) error { return fmt.Errorf("failed to disable key: %w", err) } + // Log key disable per ADR-015 + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSKeyDisable, hierarchy.TenantID, "", hierarchy.KMSKeyID, true, "") + } + // Schedule hard deletion after 90 days per ADR-015 _, err := km.kms.ScheduleKeyDeletion(km.ctx, hierarchy.KMSKeyID, 90) if err != nil { return fmt.Errorf("failed to schedule key deletion: %w", err) } + // Log key deletion scheduled per ADR-015 + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSKeyDelete, hierarchy.TenantID, "", hierarchy.KMSKeyID, true, "") + } + // Flush DEK cache for this tenant km.cache.Flush(hierarchy.TenantID) @@ -204,11 +232,11 @@ type WrappedDEK struct { // NewTestTenantKeyManager creates a tenant key manager with memory provider for testing. // This provides backward compatibility for existing tests. -func NewTestTenantKeyManager() *TenantKeyManager { +func NewTestTenantKeyManager(auditLogger *audit.Logger) *TenantKeyManager { provider := kms.NewMemoryProvider() cache := kms.NewDEKCache(kms.DefaultCacheConfig()) config := kms.Config{Provider: kms.ProviderTypeMemory} - return NewTenantKeyManager(provider, cache, config) + return NewTenantKeyManager(provider, cache, config, auditLogger) } // EncryptArtifact encrypts artifact data using a tenant-specific DEK. @@ -216,12 +244,18 @@ func (km *TenantKeyManager) EncryptArtifact(tenantID, artifactID, kmsKeyID strin // Generate a new DEK for this artifact wrappedDEK, err := km.GenerateDataEncryptionKey(tenantID, artifactID, kmsKeyID) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, err } // Get the DEK (from cache or unwrap) dek, err := km.UnwrapDataEncryptionKey(wrappedDEK, kmsKeyID) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, err } defer func() { @@ -234,21 +268,35 @@ func (km *TenantKeyManager) EncryptArtifact(tenantID, artifactID, kmsKeyID strin // Encrypt the data with the DEK block, err := aes.NewCipher(dek) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to create cipher: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to create GCM: %w", err) } nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to generate nonce: %w", err) } ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + // Log successful encryption per ADR-012 + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSEncrypt, tenantID, artifactID, kmsKeyID, true, "") + } + return &EncryptedArtifact{ Ciphertext: base64.StdEncoding.EncodeToString(ciphertext), DEK: wrappedDEK, @@ -259,9 +307,15 @@ func (km *TenantKeyManager) EncryptArtifact(tenantID, artifactID, kmsKeyID strin // DecryptArtifact decrypts artifact data using its wrapped DEK. func (km *TenantKeyManager) DecryptArtifact(encrypted *EncryptedArtifact, kmsKeyID string) ([]byte, error) { + tenantID := encrypted.DEK.TenantID + artifactID := encrypted.DEK.ArtifactID + // Unwrap the DEK dek, err := km.UnwrapDataEncryptionKey(encrypted.DEK, kmsKeyID) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to unwrap DEK: %w", err) } defer func() { @@ -273,26 +327,52 @@ func (km *TenantKeyManager) DecryptArtifact(encrypted *EncryptedArtifact, kmsKey // Decrypt the data ciphertext, err := base64.StdEncoding.DecodeString(encrypted.Ciphertext) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to decode ciphertext: %w", err) } block, err := aes.NewCipher(dek) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to create cipher: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } return nil, fmt.Errorf("failed to create GCM: %w", err) } nonceSize := gcm.NonceSize() if len(ciphertext) < nonceSize { - return nil, fmt.Errorf("ciphertext too short") + err := fmt.Errorf("ciphertext too short") + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } + return nil, err } nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] - return gcm.Open(nil, nonce, ciphertext, nil) + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, false, err.Error()) + } + return nil, err + } + + // Log successful decryption per ADR-012 + if km.audit != nil { + km.audit.LogKMSOperation(audit.EventKMSDecrypt, tenantID, artifactID, kmsKeyID, true, "") + } + + return plaintext, nil } // EncryptedArtifact represents an encrypted artifact with its wrapped DEK