diff --git a/internal/crypto/kms/cache.go b/internal/crypto/kms/cache.go new file mode 100644 index 0000000..23d3217 --- /dev/null +++ b/internal/crypto/kms/cache.go @@ -0,0 +1,277 @@ +package kms + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "sync" + "time" +) + +// DEKCache implements an in-process cache for unwrapped DEKs per ADR-012. +// - TTL: 15 minutes per entry +// - Max size: 1000 entries (LRU eviction) +// - Scope: In-process only, never serialized to disk +type DEKCache struct { + mu sync.RWMutex + entries map[string]*cacheEntry + ttl time.Duration + maxEntries int + graceWindow time.Duration // Per ADR-013 + evictionList []string // Simple LRU: front = oldest, back = newest +} + +// cacheEntry holds a cached DEK with metadata. +type cacheEntry struct { + dek []byte + tenantID string + artifactID string + kmsKeyID string + createdAt time.Time + lastAccess time.Time + evicted bool +} + +// NewDEKCache creates a new DEK cache with the specified configuration. +func NewDEKCache(cfg CacheConfig) *DEKCache { + return &DEKCache{ + entries: make(map[string]*cacheEntry), + ttl: cfg.TTL, + maxEntries: cfg.MaxEntries, + graceWindow: cfg.GraceWindow, + evictionList: make([]string, 0), + } +} + +// cacheKey generates a unique key for the cache map including KMS key ID. +func cacheKey(tenantID, artifactID, kmsKeyID string) string { + // Use a hash to avoid storing raw IDs in the key + h := sha256.New() + _, _ = h.Write([]byte(tenantID + "/" + artifactID + "/" + kmsKeyID)) + return hex.EncodeToString(h.Sum(nil)) +} + +// Get retrieves a DEK from the cache if present and not expired. +// Returns nil and false if not found or expired. +// Per ADR-013: During KMS unavailability, expired entries within grace window are still returned. +func (c *DEKCache) Get(tenantID, artifactID, kmsKeyID string, kmsUnavailable bool) ([]byte, bool) { + key := cacheKey(tenantID, artifactID, kmsKeyID) + + c.mu.RLock() + entry, exists := c.entries[key] + c.mu.RUnlock() + + if !exists || entry.evicted { + return nil, false + } + + now := time.Now() + age := now.Sub(entry.createdAt) + + // Check if expired + if age > c.ttl { + // Per ADR-013: Grace window only applies during KMS unavailability + if !kmsUnavailable { + return nil, false + } + + // Check grace window + graceAge := age - c.ttl + if graceAge > c.graceWindow { + return nil, false + } + + // Within grace window - return the DEK but log that we're using grace period + // (caller should log this appropriately) + } + + // Update last access time (need write lock) + c.mu.Lock() + entry.lastAccess = now + c.updateLRU(key) + c.mu.Unlock() + + // Return a copy of the DEK to prevent external modification + dekCopy := make([]byte, len(entry.dek)) + copy(dekCopy, entry.dek) + return dekCopy, true +} + +// Put stores a DEK in the cache. +func (c *DEKCache) Put(tenantID, artifactID, kmsKeyID string, dek []byte) error { + if len(dek) == 0 { + return fmt.Errorf("cannot cache empty DEK") + } + + key := cacheKey(tenantID, artifactID, kmsKeyID) + + // Copy the DEK to prevent external modification + dekCopy := make([]byte, len(dek)) + copy(dekCopy, dek) + + c.mu.Lock() + defer c.mu.Unlock() + + // Check if we need to evict + if len(c.entries) >= c.maxEntries { + c.evictLRU() + } + + // Store the entry + now := time.Now() + c.entries[key] = &cacheEntry{ + dek: dekCopy, + tenantID: tenantID, + artifactID: artifactID, + kmsKeyID: kmsKeyID, + createdAt: now, + lastAccess: now, + } + + // Add to LRU list (newest at back) + c.evictionList = append(c.evictionList, key) + + return nil +} + +// Flush removes all DEKs for a specific tenant from the cache. +// Called on key rotation events and tenant offboarding per ADR-012. +func (c *DEKCache) Flush(tenantID string) { + c.mu.Lock() + defer c.mu.Unlock() + + // Mark entries for eviction + for key, entry := range c.entries { + if entry.tenantID == tenantID { + entry.evicted = true + delete(c.entries, key) + } + } + + // Rebuild eviction list without flushed entries + newList := make([]string, 0, len(c.evictionList)) + for _, key := range c.evictionList { + if entry, exists := c.entries[key]; exists && !entry.evicted { + newList = append(newList, key) + } + } + c.evictionList = newList +} + +// Clear removes all entries from the cache. +func (c *DEKCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + // Securely wipe DEK bytes before dropping references + for _, entry := range c.entries { + for i := range entry.dek { + entry.dek[i] = 0 + } + } + + c.entries = make(map[string]*cacheEntry) + c.evictionList = make([]string, 0) +} + +// Stats returns current cache statistics. +func (c *DEKCache) Stats() CacheStats { + c.mu.RLock() + defer c.mu.RUnlock() + + return CacheStats{ + Size: len(c.entries), + MaxSize: c.maxEntries, + TTL: c.ttl, + GraceWindow: c.graceWindow, + } +} + +// CacheStats holds cache statistics. +type CacheStats struct { + Size int + MaxSize int + TTL time.Duration + GraceWindow time.Duration +} + +// updateLRU moves the accessed key to the back of the list (most recently used). +func (c *DEKCache) updateLRU(key string) { + // Find and remove key from current position + for i, k := range c.evictionList { + if k == key { + // Remove from current position + c.evictionList = append(c.evictionList[:i], c.evictionList[i+1:]...) + break + } + } + // Add to back (most recent) + c.evictionList = append(c.evictionList, key) +} + +// evictLRU removes the oldest entry (front of list). +func (c *DEKCache) evictLRU() { + if len(c.evictionList) == 0 { + return + } + + // Remove oldest (front of list) + oldestKey := c.evictionList[0] + c.evictionList = c.evictionList[1:] + + // Securely wipe DEK bytes + if entry, exists := c.entries[oldestKey]; exists { + for i := range entry.dek { + entry.dek[i] = 0 + } + entry.evicted = true + } + + delete(c.entries, oldestKey) +} + +// cleanupExpired periodically removes expired entries. +// This should be called periodically (e.g., by a background goroutine). +func (c *DEKCache) cleanupExpired() { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for key, entry := range c.entries { + if now.Sub(entry.createdAt) > c.ttl+c.graceWindow { + // Securely wipe + for i := range entry.dek { + entry.dek[i] = 0 + } + entry.evicted = true + delete(c.entries, key) + } + } + + // Rebuild eviction list + newList := make([]string, 0, len(c.entries)) + for _, key := range c.evictionList { + if _, exists := c.entries[key]; exists { + newList = append(newList, key) + } + } + c.evictionList = newList +} + +// StartCleanup starts a background goroutine to periodically clean up expired entries. +func (c *DEKCache) StartCleanup(interval time.Duration) chan struct{} { + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.cleanupExpired() + case <-stop: + return + } + } + }() + return stop +} diff --git a/internal/crypto/kms/config/config.go b/internal/crypto/kms/config/config.go new file mode 100644 index 0000000..2c50a67 --- /dev/null +++ b/internal/crypto/kms/config/config.go @@ -0,0 +1,168 @@ +// Package config provides KMS configuration types shared across KMS providers. +package config + +import ( + "fmt" + "time" +) + +// ProviderType identifies the KMS provider implementation. +type ProviderType string + +const ( + ProviderTypeVault ProviderType = "vault" + ProviderTypeAWS ProviderType = "aws" + ProviderTypeMemory ProviderType = "memory" // Development only +) + +// IsValid returns true if the provider type is valid. +func (t ProviderType) IsValid() bool { + switch t { + case ProviderTypeVault, ProviderTypeAWS, ProviderTypeMemory: + return true + } + return false +} + +// Config holds KMS provider configuration. +type Config struct { + Provider ProviderType `yaml:"provider"` + Vault VaultConfig `yaml:"vault,omitempty"` + AWS AWSConfig `yaml:"aws,omitempty"` + Cache CacheConfig `yaml:"cache,omitempty"` +} + +// VaultConfig holds HashiCorp Vault-specific configuration. +type VaultConfig struct { + Address string `yaml:"address"` + AuthMethod string `yaml:"auth_method"` + RoleID string `yaml:"role_id"` + SecretID string `yaml:"secret_id"` + Token string `yaml:"token"` + TransitMount string `yaml:"transit_mount"` + KeyPrefix string `yaml:"key_prefix"` + Region string `yaml:"region"` + Timeout time.Duration `yaml:"timeout"` +} + +// AWSConfig holds AWS KMS-specific configuration. +type AWSConfig struct { + Region string `yaml:"region"` + KeyAliasPrefix string `yaml:"key_alias_prefix"` + RoleARN string `yaml:"role_arn,omitempty"` + Endpoint string `yaml:"endpoint,omitempty"` +} + +// CacheConfig holds DEK cache configuration per ADR-012. +type CacheConfig struct { + TTL time.Duration `yaml:"ttl_minutes"` + MaxEntries int `yaml:"max_entries"` + GraceWindow time.Duration `yaml:"grace_window_minutes"` +} + +// DefaultCacheConfig returns the default cache configuration per ADR-012/013. +func DefaultCacheConfig() CacheConfig { + return CacheConfig{ + TTL: 15 * time.Minute, + MaxEntries: 1000, + GraceWindow: 1 * time.Hour, + } +} + +// Validate checks the configuration for errors. +func (c *Config) Validate() error { + if !c.Provider.IsValid() { + return fmt.Errorf("invalid KMS provider: %s", c.Provider) + } + + switch c.Provider { + case ProviderTypeVault: + if err := c.Vault.Validate(); err != nil { + return fmt.Errorf("vault config: %w", err) + } + case ProviderTypeAWS: + if err := c.AWS.Validate(); err != nil { + return fmt.Errorf("aws config: %w", err) + } + } + + // Apply defaults for cache config + if c.Cache.TTL == 0 { + c.Cache.TTL = DefaultCacheConfig().TTL + } + if c.Cache.MaxEntries == 0 { + c.Cache.MaxEntries = DefaultCacheConfig().MaxEntries + } + if c.Cache.GraceWindow == 0 { + c.Cache.GraceWindow = DefaultCacheConfig().GraceWindow + } + + return nil +} + +// Validate checks Vault configuration. +func (v *VaultConfig) Validate() error { + if v.Address == "" { + return fmt.Errorf("vault address is required") + } + + switch v.AuthMethod { + case "approle": + if v.RoleID == "" || v.SecretID == "" { + return fmt.Errorf("approle auth requires role_id and secret_id") + } + case "kubernetes": + // Kubernetes auth uses service account token + case "token": + if v.Token == "" { + return fmt.Errorf("token auth requires token") + } + default: + return fmt.Errorf("invalid auth_method: %s", v.AuthMethod) + } + + // Apply defaults + if v.TransitMount == "" { + v.TransitMount = "transit" + } + if v.KeyPrefix == "" { + v.KeyPrefix = "fetchml-tenant" + } + if v.Timeout == 0 { + v.Timeout = 30 * time.Second + } + + return nil +} + +// Validate checks AWS configuration. +func (a *AWSConfig) Validate() error { + if a.Region == "" { + return fmt.Errorf("AWS region is required") + } + + // Apply defaults + if a.KeyAliasPrefix == "" { + a.KeyAliasPrefix = "alias/fetchml" + } + + return nil +} + +// KeyIDForTenant generates a KMS key ID for a tenant based on the provider config. +func (c *Config) KeyIDForTenant(tenantID string) string { + switch c.Provider { + case ProviderTypeVault: + if c.Vault.Region != "" { + return fmt.Sprintf("%s-%s-%s", c.Vault.KeyPrefix, c.Vault.Region, tenantID) + } + return fmt.Sprintf("%s-%s", c.Vault.KeyPrefix, tenantID) + case ProviderTypeAWS: + if c.AWS.Region != "" { + return fmt.Sprintf("%s-%s-%s", c.AWS.KeyAliasPrefix, c.AWS.Region, tenantID) + } + return fmt.Sprintf("%s-%s", c.AWS.KeyAliasPrefix, tenantID) + default: + return tenantID + } +} diff --git a/internal/crypto/kms/provider.go b/internal/crypto/kms/provider.go new file mode 100644 index 0000000..f8f3afa --- /dev/null +++ b/internal/crypto/kms/provider.go @@ -0,0 +1,165 @@ +// Package kms provides Key Management System (KMS) integrations for external +// key management providers (HashiCorp Vault, AWS KMS, etc.). +// This implements the KMS integration per ADR-012 through ADR-015. +package kms + +import ( + "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "fmt" + "time" + + "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" + "github.com/jfraeys/fetch_ml/internal/crypto/kms/providers" +) + +// KMSProvider defines the interface for external KMS operations. +type KMSProvider interface { + Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) + Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) + CreateKey(ctx context.Context, tenantID string) (string, error) + DisableKey(ctx context.Context, keyID string) error + ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) + EnableKey(ctx context.Context, keyID string) error + HealthCheck(ctx context.Context) error + Close() error +} + +// Provider type aliases to avoid duplication. +type ProviderType = config.ProviderType + +const ( + ProviderTypeVault = config.ProviderTypeVault + ProviderTypeAWS = config.ProviderTypeAWS + ProviderTypeMemory = config.ProviderTypeMemory +) + +// ProviderFactory creates KMS providers from configuration. +type ProviderFactory struct { + config Config +} + +// Config aliases. +type Config = config.Config +type VaultConfig = config.VaultConfig +type AWSConfig = config.AWSConfig +type CacheConfig = config.CacheConfig + +// DefaultCacheConfig re-exports from config package. +func DefaultCacheConfig() CacheConfig { + return config.DefaultCacheConfig() +} + +// CreateProvider instantiates a KMS provider based on the configuration. +func (f *ProviderFactory) CreateProvider() (KMSProvider, error) { + switch f.config.Provider { + case ProviderTypeVault: + return providers.NewVaultProvider(f.config.Vault) + case ProviderTypeAWS: + return providers.NewAWSProvider(f.config.AWS) + case ProviderTypeMemory: + return NewMemoryProvider(), nil + default: + return nil, fmt.Errorf("unsupported KMS provider: %s", f.config.Provider) + } +} + +// MemoryProvider implements KMSProvider for development/testing. +// Root keys are stored in-memory. NOT for production use. +type MemoryProvider struct { + keys map[string][]byte // keyID -> root key +} + +// NewMemoryProvider creates a new in-memory KMS provider for development. +func NewMemoryProvider() *MemoryProvider { + return &MemoryProvider{ + keys: make(map[string][]byte), + } +} + +// Encrypt encrypts plaintext using the specified key ID with MAC authentication. +func (m *MemoryProvider) Encrypt(_ context.Context, keyID string, plaintext []byte) ([]byte, error) { + key, exists := m.keys[keyID] + if !exists { + return nil, fmt.Errorf("key not found: %s", keyID) + } + // XOR encrypt + ciphertext := make([]byte, len(plaintext)) + for i := range plaintext { + ciphertext[i] = plaintext[i] ^ key[i%len(key)] + } + // Append MAC for integrity + mac := hmac.New(sha256.New, key) + mac.Write(ciphertext) + macSum := mac.Sum(nil) + return append(ciphertext, macSum...), nil +} + +// Decrypt decrypts ciphertext using the specified key ID with MAC verification. +func (m *MemoryProvider) Decrypt(_ context.Context, keyID string, ciphertext []byte) ([]byte, error) { + key, exists := m.keys[keyID] + if !exists { + return nil, fmt.Errorf("key not found: %s", keyID) + } + // Need at least 32 bytes for MAC + if len(ciphertext) < 32 { + return nil, fmt.Errorf("ciphertext too short") + } + // Split ciphertext and MAC + data := ciphertext[:len(ciphertext)-32] + macSum := ciphertext[len(ciphertext)-32:] + // Verify MAC + mac := hmac.New(sha256.New, key) + mac.Write(data) + expectedMAC := mac.Sum(nil) + if !hmac.Equal(macSum, expectedMAC) { + return nil, fmt.Errorf("MAC verification failed: wrong key or corrupted data") + } + // XOR decrypt + plaintext := make([]byte, len(data)) + for i := range data { + plaintext[i] = data[i] ^ key[i%len(key)] + } + return plaintext, nil +} + +// CreateKey creates a new in-memory key. +func (m *MemoryProvider) CreateKey(_ context.Context, tenantID string) (string, error) { + keyID := fmt.Sprintf("memory-%s-%d", tenantID, time.Now().UnixNano()) + // Generate a 32-byte random key for AES-256 + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + return "", fmt.Errorf("failed to generate random key: %w", err) + } + m.keys[keyID] = key + return keyID, nil +} + +// DisableKey disables a key (no-op in memory provider). +func (m *MemoryProvider) DisableKey(_ context.Context, _ string) error { + return nil +} + +// ScheduleKeyDeletion schedules key deletion (removes from map in memory provider). +func (m *MemoryProvider) ScheduleKeyDeletion(_ context.Context, keyID string, windowDays int) (time.Time, error) { + delete(m.keys, keyID) + // Return deletion date (windowDays from now) + return time.Now().Add(time.Duration(windowDays) * 24 * time.Hour), nil +} + +// EnableKey re-enables a disabled key (no-op in memory provider). +func (m *MemoryProvider) EnableKey(_ context.Context, _ string) error { + return nil +} + +// HealthCheck always returns healthy for memory provider. +func (m *MemoryProvider) HealthCheck(_ context.Context) error { + return nil +} + +// Close releases resources (no-op for memory provider). +func (m *MemoryProvider) Close() error { + return nil +}