feat(kms): implement core KMS infrastructure with DEK cache

Add KMSProvider interface for external key management systems:
- Encrypt/Decrypt operations for DEK wrapping
- Key lifecycle management (Create, Disable, ScheduleDeletion, Enable)
- HealthCheck and Close methods

Implement MemoryProvider for development/testing:
- XOR encryption with HMAC-SHA256 authentication
- Secure random key generation using crypto/rand
- MAC verification to detect wrong keys

Implement DEKCache per ADR-012:
- 15-minute TTL with configurable grace window (1 hour)
- LRU eviction with 1000 entry limit
- Cache key includes (tenantID, artifactID, kmsKeyID) for isolation
- Thread-safe operations with RWMutex
- Secure memory wiping on eviction/cleanup

Add config package with types:
- ProviderType enum (vault, aws, memory)
- VaultConfig with AppRole/Kubernetes/Token auth
- AWSConfig with region and alias prefix
- CacheConfig with TTL, MaxEntries, GraceWindow
- Validation methods for all config types
This commit is contained in:
Jeremie Fraeys 2026-03-03 19:13:55 -05:00
parent da104367d6
commit cb25677695
No known key found for this signature in database
3 changed files with 610 additions and 0 deletions

View file

@ -0,0 +1,277 @@
package kms
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sync"
"time"
)
// DEKCache implements an in-process cache for unwrapped DEKs per ADR-012.
// - TTL: 15 minutes per entry by default (actual value comes from CacheConfig.TTL)
// - Max size: 1000 entries by default (CacheConfig.MaxEntries), LRU eviction
// - Scope: In-process only, never serialized to disk
//
// All fields are guarded by mu. Entries are keyed by the hash produced by
// cacheKey, and evictionList holds those same keys in recency order.
type DEKCache struct {
mu sync.RWMutex
entries map[string]*cacheEntry
ttl time.Duration
maxEntries int
graceWindow time.Duration // Per ADR-013: extra lifetime granted past ttl when the caller reports the KMS as unavailable
evictionList []string // Simple LRU: front = oldest, back = newest
}
// cacheEntry holds a cached DEK with metadata.
// The identifying fields are kept in clear text next to the DEK so that
// tenant-wide operations (see Flush) can match entries without reversing the
// hashed map key.
type cacheEntry struct {
dek []byte // private copy of the unwrapped DEK; zeroed when the entry is evicted
tenantID string
artifactID string
kmsKeyID string
createdAt time.Time // set on Put; TTL and grace-window age are measured from here
lastAccess time.Time // refreshed on every successful Get
evicted bool // set once the entry has been removed from the cache
}
// NewDEKCache builds an empty DEK cache. TTL, capacity and grace window are
// copied verbatim from cfg; no validation or defaulting is applied here.
func NewDEKCache(cfg CacheConfig) *DEKCache {
	cache := new(DEKCache)
	cache.ttl = cfg.TTL
	cache.maxEntries = cfg.MaxEntries
	cache.graceWindow = cfg.GraceWindow
	cache.entries = map[string]*cacheEntry{}
	cache.evictionList = []string{}
	return cache
}
// cacheKey derives the map key for a (tenantID, artifactID, kmsKeyID) triple.
//
// The IDs are hashed so the raw identifiers are not used as map keys. Each
// component is written with a length prefix before hashing: a plain
// "a/b/c"-style join is ambiguous when an ID itself contains the separator
// (tenant "a/b" + artifact "c" would collide with tenant "a" + artifact "b/c"),
// which would let two different tenants share a cache slot.
//
// Returns the hex-encoded SHA-256 digest (64 characters).
func cacheKey(tenantID, artifactID, kmsKeyID string) string {
	h := sha256.New()
	for _, part := range [...]string{tenantID, artifactID, kmsKeyID} {
		// hash.Hash.Write never returns an error, so Fprintf cannot fail here.
		_, _ = fmt.Fprintf(h, "%d:%s", len(part), part)
	}
	return hex.EncodeToString(h.Sum(nil))
}
// Get retrieves a DEK from the cache if present and not expired.
//
// It returns (nil, false) when the entry is missing, evicted, or expired.
// Per ADR-013 an expired entry is still returned when kmsUnavailable is true
// and the entry is no more than graceWindow past its TTL; the caller is
// responsible for logging that the grace period was used.
//
// On a hit the entry's lastAccess is refreshed, the key is moved to the
// most-recently-used end of the eviction list, and a copy of the DEK is
// returned so callers cannot modify the cached bytes.
func (c *DEKCache) Get(tenantID, artifactID, kmsKeyID string, kmsUnavailable bool) ([]byte, bool) {
	key := cacheKey(tenantID, artifactID, kmsKeyID)

	// One exclusive lock for the whole operation. A hit always mutates
	// lastAccess and the LRU list, and the entry (including its DEK bytes,
	// which eviction zeroes in place) must not be inspected or copied while
	// another goroutine may be evicting it.
	c.mu.Lock()
	defer c.mu.Unlock()

	entry, exists := c.entries[key]
	if !exists || entry.evicted {
		return nil, false
	}

	now := time.Now()
	if age := now.Sub(entry.createdAt); age > c.ttl {
		// Per ADR-013: the grace window only applies during KMS unavailability.
		if !kmsUnavailable {
			return nil, false
		}
		if age-c.ttl > c.graceWindow {
			return nil, false
		}
	}

	entry.lastAccess = now
	c.updateLRU(key)

	// Return a copy of the DEK to prevent external modification.
	dekCopy := make([]byte, len(entry.dek))
	copy(dekCopy, entry.dek)
	return dekCopy, true
}
// Put stores a copy of dek in the cache under (tenantID, artifactID, kmsKeyID).
//
// An empty DEK is rejected with an error. If the key is already cached, the
// old DEK bytes are zeroed, the entry is replaced, and the key is moved to the
// most-recently-used position; no other entry is evicted and the eviction list
// gains no duplicate. Otherwise, when the cache is at capacity, the least
// recently used entry is evicted before the new one is added.
func (c *DEKCache) Put(tenantID, artifactID, kmsKeyID string, dek []byte) error {
	if len(dek) == 0 {
		return fmt.Errorf("cannot cache empty DEK")
	}
	key := cacheKey(tenantID, artifactID, kmsKeyID)

	// Copy the DEK to prevent external modification.
	dekCopy := make([]byte, len(dek))
	copy(dekCopy, dek)

	c.mu.Lock()
	defer c.mu.Unlock()

	now := time.Now()
	fresh := &cacheEntry{
		dek:        dekCopy,
		tenantID:   tenantID,
		artifactID: artifactID,
		kmsKeyID:   kmsKeyID,
		createdAt:  now,
		lastAccess: now,
	}

	if old, exists := c.entries[key]; exists {
		// Replacing in place: wipe the superseded key material and keep
		// exactly one occurrence of the key in the LRU list.
		for i := range old.dek {
			old.dek[i] = 0
		}
		old.evicted = true
		c.entries[key] = fresh
		c.updateLRU(key)
		return nil
	}

	// Only a genuinely new key can push the cache over its limit.
	if len(c.entries) >= c.maxEntries {
		c.evictLRU()
	}
	c.entries[key] = fresh
	// Add to LRU list (newest at back).
	c.evictionList = append(c.evictionList, key)
	return nil
}
// Flush removes all DEKs for a specific tenant from the cache.
// Called on key rotation events and tenant offboarding per ADR-012.
//
// The DEK bytes of every removed entry are zeroed before the entry is dropped,
// matching what Clear, evictLRU and cleanupExpired do, and the eviction list
// is rebuilt so it only references entries that are still cached.
func (c *DEKCache) Flush(tenantID string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for key, entry := range c.entries {
		if entry.tenantID != tenantID {
			continue
		}
		// Securely wipe the key material before releasing the reference.
		for i := range entry.dek {
			entry.dek[i] = 0
		}
		entry.evicted = true
		delete(c.entries, key)
	}

	// Rebuild the eviction list without the flushed entries.
	kept := make([]string, 0, len(c.entries))
	for _, key := range c.evictionList {
		if _, exists := c.entries[key]; exists {
			kept = append(kept, key)
		}
	}
	c.evictionList = kept
}
// Clear empties the cache: every cached DEK is overwritten with zeros, then
// the entry map and the eviction list are replaced with fresh empty ones.
func (c *DEKCache) Clear() {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Zero the key material first so it does not linger until the GC runs.
	for id := range c.entries {
		secret := c.entries[id].dek
		for pos := 0; pos < len(secret); pos++ {
			secret[pos] = 0
		}
	}

	c.entries = map[string]*cacheEntry{}
	c.evictionList = []string{}
}
// Stats reports the current entry count together with the configured
// capacity, TTL and grace window, read under the read lock.
func (c *DEKCache) Stats() CacheStats {
	var snapshot CacheStats

	c.mu.RLock()
	snapshot.Size = len(c.entries)
	snapshot.MaxSize = c.maxEntries
	snapshot.TTL = c.ttl
	snapshot.GraceWindow = c.graceWindow
	c.mu.RUnlock()

	return snapshot
}
// CacheStats holds cache statistics as returned by DEKCache.Stats.
type CacheStats struct {
Size int // number of entries currently in the cache
MaxSize int // configured entry limit
TTL time.Duration // configured per-entry TTL
GraceWindow time.Duration // configured grace window past the TTL
}
// updateLRU marks key as most recently used by moving it to the back of the
// eviction list. A key that is not yet in the list is simply appended.
// The caller must hold the write lock.
func (c *DEKCache) updateLRU(key string) {
	// Locate the first occurrence of the key, if any.
	pos := -1
	for idx := range c.evictionList {
		if c.evictionList[idx] == key {
			pos = idx
			break
		}
	}

	// Splice it out of its current position.
	if pos >= 0 {
		c.evictionList = append(c.evictionList[:pos], c.evictionList[pos+1:]...)
	}

	// Re-add at the back (most recent).
	c.evictionList = append(c.evictionList, key)
}
// evictLRU removes the least recently used entry (front of the eviction list),
// zeroing its DEK bytes before dropping it. The caller must hold the write lock.
//
// Keys at the front of the list that no longer have a cached entry are
// discarded and the scan continues, so a call frees one slot whenever at
// least one live entry remains. Without this, a stale key would turn the
// eviction into a no-op and let the cache grow past its limit.
func (c *DEKCache) evictLRU() {
	for len(c.evictionList) > 0 {
		// Pop the oldest key (front of list).
		oldestKey := c.evictionList[0]
		c.evictionList = c.evictionList[1:]

		entry, exists := c.entries[oldestKey]
		if !exists {
			// Stale list entry: nothing to evict for this key, try the next.
			continue
		}

		// Securely wipe DEK bytes.
		for i := range entry.dek {
			entry.dek[i] = 0
		}
		entry.evicted = true
		delete(c.entries, oldestKey)
		return
	}
}
// cleanupExpired drops every entry older than ttl + graceWindow, zeroing its
// DEK first, and then rebuilds the eviction list from the surviving keys.
// It is meant to be invoked periodically (StartCleanup does so from a
// background goroutine).
func (c *DEKCache) cleanupExpired() {
	c.mu.Lock()
	defer c.mu.Unlock()

	sweepTime := time.Now()
	for id, item := range c.entries {
		if sweepTime.Sub(item.createdAt) <= c.ttl+c.graceWindow {
			continue
		}
		// Past TTL and grace window: wipe the key material, then drop it.
		for pos := range item.dek {
			item.dek[pos] = 0
		}
		item.evicted = true
		delete(c.entries, id)
	}

	// Keep only keys that still have a cached entry, preserving their order.
	survivors := make([]string, 0, len(c.entries))
	for _, id := range c.evictionList {
		if _, alive := c.entries[id]; !alive {
			continue
		}
		survivors = append(survivors, id)
	}
	c.evictionList = survivors
}
// StartCleanup starts a background goroutine that calls cleanupExpired every
// interval. It returns a channel; closing that channel stops the goroutine.
//
// A non-positive interval starts no goroutine: time.NewTicker panics on such a
// duration, and a panic inside the background goroutine would terminate the
// whole process. The returned channel is still valid and may be closed as usual.
func (c *DEKCache) StartCleanup(interval time.Duration) chan struct{} {
	stop := make(chan struct{})
	if interval <= 0 {
		return stop
	}

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				c.cleanupExpired()
			case <-stop:
				return
			}
		}
	}()
	return stop
}

View file

@ -0,0 +1,168 @@
// Package config provides KMS configuration types shared across KMS providers.
package config
import (
"fmt"
"time"
)
// ProviderType names a KMS provider implementation.
type ProviderType string

// Recognised provider types.
const (
	ProviderTypeVault  ProviderType = "vault"
	ProviderTypeAWS    ProviderType = "aws"
	ProviderTypeMemory ProviderType = "memory" // Development only
)

// IsValid reports whether t is one of the recognised provider types.
func (t ProviderType) IsValid() bool {
	return t == ProviderTypeVault || t == ProviderTypeAWS || t == ProviderTypeMemory
}
// Config holds KMS provider configuration.
// Provider selects which of the provider-specific sections is validated by
// Validate; Cache configures the DEK cache and falls back to
// DefaultCacheConfig values for fields left at zero.
type Config struct {
Provider ProviderType `yaml:"provider"`
Vault VaultConfig `yaml:"vault,omitempty"`
AWS AWSConfig `yaml:"aws,omitempty"`
Cache CacheConfig `yaml:"cache,omitempty"`
}
// VaultConfig holds HashiCorp Vault-specific configuration.
// See VaultConfig.Validate for the required fields and the defaults it fills in.
type VaultConfig struct {
Address string `yaml:"address"` // required
AuthMethod string `yaml:"auth_method"` // one of "approle", "kubernetes", "token"
RoleID string `yaml:"role_id"` // required for "approle"
SecretID string `yaml:"secret_id"` // required for "approle"
Token string `yaml:"token"` // required for "token"
TransitMount string `yaml:"transit_mount"` // defaults to "transit"
KeyPrefix string `yaml:"key_prefix"` // defaults to "fetchml-tenant"
Region string `yaml:"region"` // optional; when set it is embedded in tenant key IDs
Timeout time.Duration `yaml:"timeout"` // defaults to 30 seconds
}
// AWSConfig holds AWS KMS-specific configuration.
// See AWSConfig.Validate for the required fields and the defaults it fills in.
type AWSConfig struct {
Region string `yaml:"region"` // required
KeyAliasPrefix string `yaml:"key_alias_prefix"` // defaults to "alias/fetchml"
RoleARN string `yaml:"role_arn,omitempty"` // not validated here; presumably an IAM role to assume — confirm against the AWS provider
Endpoint string `yaml:"endpoint,omitempty"` // not validated here; presumably a custom KMS endpoint override — confirm against the AWS provider
}
// CacheConfig holds DEK cache configuration per ADR-012.
//
// NOTE(review): the YAML keys end in "_minutes" but the fields are
// time.Duration values; confirm how the YAML decoder in use interprets a bare
// number for these keys before relying on the unit implied by the key name.
type CacheConfig struct {
	TTL         time.Duration `yaml:"ttl_minutes"`
	MaxEntries  int           `yaml:"max_entries"`
	GraceWindow time.Duration `yaml:"grace_window_minutes"`
}

// DefaultCacheConfig returns the ADR-012/013 defaults: a 15-minute TTL,
// at most 1000 entries, and a one-hour grace window.
func DefaultCacheConfig() CacheConfig {
	var defaults CacheConfig
	defaults.TTL = 15 * time.Minute
	defaults.MaxEntries = 1000
	defaults.GraceWindow = time.Hour
	return defaults
}
// Validate checks the configuration for errors and fills in cache defaults.
//
// It rejects an unknown provider type, delegates to the Vault or AWS section's
// own Validate for those providers (the memory provider has no section to
// check), and rejects negative cache settings: a negative TTL would make every
// entry look expired and a negative MaxEntries would force an eviction on
// every insert. Cache fields left at zero are replaced by the values from
// DefaultCacheConfig.
//
// Note that Validate mutates the receiver: defaults are written back into c.
func (c *Config) Validate() error {
	if !c.Provider.IsValid() {
		return fmt.Errorf("invalid KMS provider: %s", c.Provider)
	}

	switch c.Provider {
	case ProviderTypeVault:
		if err := c.Vault.Validate(); err != nil {
			return fmt.Errorf("vault config: %w", err)
		}
	case ProviderTypeAWS:
		if err := c.AWS.Validate(); err != nil {
			return fmt.Errorf("aws config: %w", err)
		}
	}

	switch {
	case c.Cache.TTL < 0:
		return fmt.Errorf("cache config: ttl must not be negative")
	case c.Cache.MaxEntries < 0:
		return fmt.Errorf("cache config: max_entries must not be negative")
	case c.Cache.GraceWindow < 0:
		return fmt.Errorf("cache config: grace window must not be negative")
	}

	// Apply defaults for cache fields that were left unset.
	defaults := DefaultCacheConfig()
	if c.Cache.TTL == 0 {
		c.Cache.TTL = defaults.TTL
	}
	if c.Cache.MaxEntries == 0 {
		c.Cache.MaxEntries = defaults.MaxEntries
	}
	if c.Cache.GraceWindow == 0 {
		c.Cache.GraceWindow = defaults.GraceWindow
	}
	return nil
}
// Validate verifies the Vault settings and fills in defaults.
//
// The address is mandatory. The auth method must be "approle" (needs both
// role_id and secret_id), "kubernetes" (no extra fields checked), or "token"
// (needs a token). On success, an empty transit mount, key prefix or timeout
// is replaced by "transit", "fetchml-tenant" and 30 seconds respectively.
func (v *VaultConfig) Validate() error {
	if len(v.Address) == 0 {
		return fmt.Errorf("vault address is required")
	}

	if method := v.AuthMethod; method == "approle" {
		if haveCreds := v.RoleID != "" && v.SecretID != ""; !haveCreds {
			return fmt.Errorf("approle auth requires role_id and secret_id")
		}
	} else if method == "token" {
		if len(v.Token) == 0 {
			return fmt.Errorf("token auth requires token")
		}
	} else if method != "kubernetes" {
		// "kubernetes" relies on the service account token and needs no
		// fields here; anything else is unknown.
		return fmt.Errorf("invalid auth_method: %s", v.AuthMethod)
	}

	// Fill in defaults for anything left unset.
	const (
		fallbackMount   = "transit"
		fallbackPrefix  = "fetchml-tenant"
		fallbackTimeout = 30 * time.Second
	)
	if len(v.TransitMount) == 0 {
		v.TransitMount = fallbackMount
	}
	if len(v.KeyPrefix) == 0 {
		v.KeyPrefix = fallbackPrefix
	}
	if v.Timeout == 0 {
		v.Timeout = fallbackTimeout
	}
	return nil
}
// Validate verifies the AWS settings: the region is mandatory, and an empty
// key alias prefix is replaced by "alias/fetchml".
func (a *AWSConfig) Validate() error {
	if len(a.Region) == 0 {
		return fmt.Errorf("AWS region is required")
	}

	const fallbackAliasPrefix = "alias/fetchml"
	if len(a.KeyAliasPrefix) == 0 {
		a.KeyAliasPrefix = fallbackAliasPrefix
	}
	return nil
}
// KeyIDForTenant builds the KMS key ID for tenantID according to the provider.
//
// For Vault and AWS the result is "<prefix>-<region>-<tenantID>" when that
// provider's region is set, and "<prefix>-<tenantID>" otherwise, where the
// prefix is Vault.KeyPrefix or AWS.KeyAliasPrefix. For any other provider the
// tenant ID is returned unchanged.
func (c *Config) KeyIDForTenant(tenantID string) string {
	var prefix, region string
	switch c.Provider {
	case ProviderTypeVault:
		prefix, region = c.Vault.KeyPrefix, c.Vault.Region
	case ProviderTypeAWS:
		prefix, region = c.AWS.KeyAliasPrefix, c.AWS.Region
	default:
		return tenantID
	}

	if region == "" {
		return fmt.Sprintf("%s-%s", prefix, tenantID)
	}
	return fmt.Sprintf("%s-%s-%s", prefix, region, tenantID)
}

View file

@ -0,0 +1,165 @@
// Package kms provides Key Management System (KMS) integrations for external
// key management providers (HashiCorp Vault, AWS KMS, etc.).
// This implements the KMS integration per ADR-012 through ADR-015.
package kms
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"fmt"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/providers"
)
// KMSProvider defines the interface for external KMS operations.
// Implementations visible from this file are the Vault and AWS providers
// (constructed in CreateProvider) and the in-memory MemoryProvider.
type KMSProvider interface {
// Encrypt encrypts plaintext under the key identified by keyID and returns the ciphertext.
Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error)
// Decrypt reverses Encrypt for the key identified by keyID and returns the plaintext.
Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error)
// CreateKey creates a key for tenantID and returns its key ID.
CreateKey(ctx context.Context, tenantID string) (string, error)
// DisableKey disables the key identified by keyID.
DisableKey(ctx context.Context, keyID string) error
// ScheduleKeyDeletion schedules deletion of keyID after windowDays days and
// returns a time (MemoryProvider returns now + windowDays; presumably the
// planned deletion date for other providers — confirm per provider).
ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error)
// EnableKey re-enables the key identified by keyID.
EnableKey(ctx context.Context, keyID string) error
// HealthCheck returns a non-nil error when the provider is unhealthy.
HealthCheck(ctx context.Context) error
// Close releases any resources held by the provider.
Close() error
}
// Provider type aliases to avoid duplication.
// ProviderType is an alias (not a new type) for config.ProviderType, so values
// are interchangeable between the kms and config packages.
type ProviderType = config.ProviderType
// Re-exported provider type constants from the config package.
const (
ProviderTypeVault = config.ProviderTypeVault
ProviderTypeAWS = config.ProviderTypeAWS
ProviderTypeMemory = config.ProviderTypeMemory
)
// ProviderFactory creates KMS providers from configuration.
// CreateProvider reads config.Provider to choose the implementation.
type ProviderFactory struct {
config Config // configuration the factory builds providers from
}
// Config aliases.
// These are type aliases for the definitions in the config package, so the
// kms package can refer to them without the package qualifier.
type Config = config.Config
type VaultConfig = config.VaultConfig
type AWSConfig = config.AWSConfig
type CacheConfig = config.CacheConfig
// DefaultCacheConfig re-exports from config package.
// It returns exactly what config.DefaultCacheConfig returns.
func DefaultCacheConfig() CacheConfig {
return config.DefaultCacheConfig()
}
// CreateProvider builds the KMS provider selected by the factory's
// configuration: Vault and AWS are constructed by the providers package from
// their respective config sections, "memory" yields a fresh MemoryProvider,
// and any other provider type results in an error.
func (f *ProviderFactory) CreateProvider() (KMSProvider, error) {
	kind := f.config.Provider

	if kind == ProviderTypeVault {
		return providers.NewVaultProvider(f.config.Vault)
	}
	if kind == ProviderTypeAWS {
		return providers.NewAWSProvider(f.config.AWS)
	}
	if kind == ProviderTypeMemory {
		return NewMemoryProvider(), nil
	}
	return nil, fmt.Errorf("unsupported KMS provider: %s", kind)
}
// MemoryProvider implements KMSProvider for development/testing.
// Root keys are stored in-memory and are lost when the process exits.
// NOT for production use: the cipher is a repeating-key XOR authenticated with
// HMAC-SHA256 under the same key, and the key map is accessed without any
// synchronization.
type MemoryProvider struct {
	keys map[string][]byte // keyID -> root key
}

// NewMemoryProvider creates a new in-memory KMS provider for development.
func NewMemoryProvider() *MemoryProvider {
	return &MemoryProvider{
		keys: make(map[string][]byte),
	}
}

// Encrypt XORs plaintext with the root key stored under keyID (repeating the
// key as needed) and appends a 32-byte HMAC-SHA256 tag computed over the
// XORed bytes. The result is therefore len(plaintext)+32 bytes long.
// It returns an error if keyID is unknown.
func (m *MemoryProvider) Encrypt(_ context.Context, keyID string, plaintext []byte) ([]byte, error) {
	key, exists := m.keys[keyID]
	if !exists {
		return nil, fmt.Errorf("key not found: %s", keyID)
	}

	// Reserve room for the tag so the final append does not reallocate.
	ciphertext := make([]byte, len(plaintext), len(plaintext)+sha256.Size)
	for i := range plaintext {
		ciphertext[i] = plaintext[i] ^ key[i%len(key)]
	}

	// Append MAC for integrity.
	mac := hmac.New(sha256.New, key)
	mac.Write(ciphertext)
	return append(ciphertext, mac.Sum(nil)...), nil
}

// Decrypt verifies the trailing HMAC-SHA256 tag of ciphertext with the root
// key stored under keyID and, if it matches, XORs the remaining bytes back
// into the plaintext. It returns an error if keyID is unknown, if ciphertext
// is shorter than the tag, or if the tag does not verify (wrong key or
// modified data).
func (m *MemoryProvider) Decrypt(_ context.Context, keyID string, ciphertext []byte) ([]byte, error) {
	key, exists := m.keys[keyID]
	if !exists {
		return nil, fmt.Errorf("key not found: %s", keyID)
	}

	// The input must at least hold the MAC.
	if len(ciphertext) < sha256.Size {
		return nil, fmt.Errorf("ciphertext too short")
	}

	// Split ciphertext and MAC.
	split := len(ciphertext) - sha256.Size
	data, macSum := ciphertext[:split], ciphertext[split:]

	// Verify MAC (constant-time comparison via hmac.Equal).
	mac := hmac.New(sha256.New, key)
	mac.Write(data)
	if !hmac.Equal(macSum, mac.Sum(nil)) {
		return nil, fmt.Errorf("MAC verification failed: wrong key or corrupted data")
	}

	// XOR decrypt.
	plaintext := make([]byte, len(data))
	for i := range data {
		plaintext[i] = data[i] ^ key[i%len(key)]
	}
	return plaintext, nil
}

// CreateKey generates a new 32-byte random root key with crypto/rand, stores
// it in memory and returns its ID, formatted as "memory-<tenantID>-<n>" where
// n starts at the current Unix time in nanoseconds.
//
// If that ID is already taken (two calls within one clock tick), n is
// incremented until the ID is free, so an existing key is never overwritten.
// The key map is created on demand, so a zero-value MemoryProvider works too.
func (m *MemoryProvider) CreateKey(_ context.Context, tenantID string) (string, error) {
	key := make([]byte, 32)
	if _, err := rand.Read(key); err != nil {
		return "", fmt.Errorf("failed to generate random key: %w", err)
	}

	if m.keys == nil {
		m.keys = make(map[string][]byte)
	}

	stamp := time.Now().UnixNano()
	keyID := fmt.Sprintf("memory-%s-%d", tenantID, stamp)
	for {
		if _, taken := m.keys[keyID]; !taken {
			break
		}
		stamp++
		keyID = fmt.Sprintf("memory-%s-%d", tenantID, stamp)
	}

	m.keys[keyID] = key
	return keyID, nil
}

// DisableKey disables a key (no-op in memory provider).
func (m *MemoryProvider) DisableKey(_ context.Context, _ string) error {
	return nil
}

// ScheduleKeyDeletion removes the key from the map immediately (the memory
// provider has no pending-deletion state) and returns the nominal deletion
// date, windowDays days from now. Deleting an unknown key is not an error.
func (m *MemoryProvider) ScheduleKeyDeletion(_ context.Context, keyID string, windowDays int) (time.Time, error) {
	delete(m.keys, keyID)
	return time.Now().Add(time.Duration(windowDays) * 24 * time.Hour), nil
}

// EnableKey re-enables a disabled key (no-op in memory provider).
func (m *MemoryProvider) EnableKey(_ context.Context, _ string) error {
	return nil
}

// HealthCheck always returns healthy for memory provider.
func (m *MemoryProvider) HealthCheck(_ context.Context) error {
	return nil
}

// Close releases resources (no-op for memory provider).
func (m *MemoryProvider) Close() error {
	return nil
}