fetch_ml/internal/crypto/kms/provider.go
Jeremie Fraeys 37c4d4e9c7
feat(crypto,auth): harden KMS and improve permission handling
KMS improvements:
- cache.go: add LRU eviction with memory-bounded caches
- provider.go: refactor provider initialization and key rotation
- tenant_keys.go: per-tenant key isolation with envelope encryption

Auth layer updates:
- hybrid.go: refine hybrid auth flow for API key + JWT
- permissions_loader.go: faster permission caching with hot-reload
- validator.go: stricter validation with detailed error messages

Security middleware:
- security.go: add rate limiting headers and CORS refinement

Testing and benchmarks:
- Add KMS cache and protocol unit tests
- Add KMS benchmark tests for encryption throughput
- Update KMS integration tests for tenant isolation
2026-03-12 12:04:32 -04:00

179 lines
5.9 KiB
Go

// Package kms provides Key Management System (KMS) integrations for external
// key management providers (HashiCorp Vault, AWS KMS, etc.).
// This implements the KMS integration per ADR-012 through ADR-015.
package kms
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"fmt"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/providers"
)
// KMSProvider is the contract every external KMS backend must satisfy.
// Root keys live inside the KMS; data-encryption keys (DEKs) are generated
// locally and wrapped by the KMS root key.
type KMSProvider interface {
	// --- Wrapping operations ---

	// Encrypt wraps plaintext (typically a DEK) under the key named by keyID,
	// a tenant-scoped KMS key identifier, and returns the ciphertext.
	Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error)

	// Decrypt unwraps ciphertext (typically a wrapped DEK) with the key named
	// by keyID and returns the recovered plaintext.
	Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error)

	// --- Key lifecycle ---

	// CreateKey provisions a new KMS key for tenantID and returns its key ID.
	CreateKey(ctx context.Context, tenantID string) (string, error)

	// EnableKey re-enables a previously disabled key (requires the approval
	// workflow per ADR-015).
	EnableKey(ctx context.Context, keyID string) error

	// DisableKey disables a key immediately (used in offboarding per ADR-015).
	DisableKey(ctx context.Context, keyID string) error

	// ScheduleKeyDeletion schedules hard deletion of the key once the
	// retention window of windowDays has passed (per ADR-015) and returns the
	// deletion date.
	ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error)

	// --- Connection management ---

	// HealthCheck verifies KMS connectivity, returning any error encountered.
	HealthCheck(ctx context.Context) error

	// Close closes the provider connection and releases its resources.
	Close() error
}
// ProviderType identifies the KMS provider implementation.
// It is a type alias (not a new type) for config.ProviderType, so values of
// the two are interchangeable without conversion.
type ProviderType = config.ProviderType
// Provider type constants re-exported from the config package.
// Each value is matched by ProviderFactory.CreateProvider to choose a backend.
const (
// ProviderTypeVault selects the provider built by providers.NewVaultProvider.
ProviderTypeVault = config.ProviderTypeVault
// ProviderTypeAWS selects the provider built by providers.NewAWSProvider.
ProviderTypeAWS = config.ProviderTypeAWS
// ProviderTypeMemory selects the in-process MemoryProvider (development/testing).
ProviderTypeMemory = config.ProviderTypeMemory
)
// ProviderFactory creates KMS providers from configuration.
// Construct it with NewProviderFactory and call CreateProvider to obtain the
// backend selected by the stored configuration.
type ProviderFactory struct {
// config is the KMS configuration; CreateProvider reads its Provider, Vault
// and AWS fields.
config config.Config
}
// NewProviderFactory returns a ProviderFactory that holds its own copy of cfg
// (the argument is passed and stored by value). The configuration is only
// consulted later, when CreateProvider is called.
func NewProviderFactory(cfg config.Config) *ProviderFactory {
	factory := new(ProviderFactory)
	factory.config = cfg
	return factory
}
// CreateProvider instantiates the KMS provider selected by the factory's
// configuration (config.Provider):
//
//   - config.ProviderTypeVault  -> providers.NewVaultProvider(config.Vault)
//   - config.ProviderTypeAWS    -> providers.NewAWSProvider(config.AWS)
//   - config.ProviderTypeMemory -> NewMemoryProvider()
//
// Any other value yields an "unsupported KMS provider" error.
//
// On failure the returned KMSProvider is always a literal nil. Constructor
// results are deliberately not forwarded straight into the interface return:
// if a constructor returns a nil concrete pointer together with an error, the
// pointer would be boxed into a non-nil KMSProvider and a caller checking
// `provider != nil` would be misled.
func (f *ProviderFactory) CreateProvider() (KMSProvider, error) {
	switch f.config.Provider {
	case config.ProviderTypeVault:
		p, err := providers.NewVaultProvider(f.config.Vault)
		if err != nil {
			return nil, err
		}
		return p, nil
	case config.ProviderTypeAWS:
		p, err := providers.NewAWSProvider(f.config.AWS)
		if err != nil {
			return nil, err
		}
		return p, nil
	case config.ProviderTypeMemory:
		return NewMemoryProvider(), nil
	default:
		return nil, fmt.Errorf("unsupported KMS provider: %s", f.config.Provider)
	}
}
// MemoryProvider implements KMSProvider for development/testing.
// Root keys are stored in-memory. NOT for production use.
//
// NOTE(review): the methods in this file read and write the keys map without
// any locking — confirm callers never share one instance across goroutines,
// or add synchronization.
type MemoryProvider struct {
keys map[string][]byte // keyID -> root key
}
// NewMemoryProvider builds an in-memory KMS provider for development use.
// The returned provider starts with an empty, writable key store.
func NewMemoryProvider() *MemoryProvider {
	store := map[string][]byte{}
	return &MemoryProvider{keys: store}
}
// Encrypt masks plaintext with the root key registered under keyID and
// appends an authentication tag.
//
// Output layout: len(plaintext) bytes of plaintext XORed with the repeating
// root key, followed by the 32-byte HMAC-SHA256 (keyed with the same root key)
// of those masked bytes. The input slice is not modified.
//
// Returns a "key not found" error when keyID is not in the store.
func (m *MemoryProvider) Encrypt(_ context.Context, keyID string, plaintext []byte) ([]byte, error) {
	rootKey, ok := m.keys[keyID]
	if !ok {
		return nil, fmt.Errorf("key not found: %s", keyID)
	}

	// Reserve room for the tag up front so it can be appended in place.
	size := len(plaintext)
	sealed := make([]byte, size, size+sha256.Size)
	for idx, b := range plaintext {
		sealed[idx] = b ^ rootKey[idx%len(rootKey)]
	}

	// The tag covers the masked bytes only; Sum appends it to sealed.
	tagger := hmac.New(sha256.New, rootKey)
	tagger.Write(sealed)
	return tagger.Sum(sealed), nil
}
// Decrypt authenticates and unmasks a blob produced by Encrypt using the root
// key registered under keyID.
//
// The final 32 bytes of ciphertext are treated as an HMAC-SHA256 tag over the
// preceding bytes; the tag is compared in constant time (hmac.Equal) before
// any unmasking happens. The input slice is not modified.
//
// Errors, in the order they are checked:
//   - "key not found" when keyID is not in the store;
//   - "ciphertext too short" when there is no room for a 32-byte tag;
//   - "MAC verification failed" when the tag does not match.
func (m *MemoryProvider) Decrypt(_ context.Context, keyID string, ciphertext []byte) ([]byte, error) {
	rootKey, ok := m.keys[keyID]
	if !ok {
		return nil, fmt.Errorf("key not found: %s", keyID)
	}

	// split is where the masked payload ends and the tag begins.
	split := len(ciphertext) - sha256.Size
	if split < 0 {
		return nil, fmt.Errorf("ciphertext too short")
	}
	payload, tag := ciphertext[:split], ciphertext[split:]

	verifier := hmac.New(sha256.New, rootKey)
	verifier.Write(payload)
	if want := verifier.Sum(nil); !hmac.Equal(tag, want) {
		return nil, fmt.Errorf("MAC verification failed: wrong key or corrupted data")
	}

	// XOR is its own inverse, so unmasking mirrors Encrypt.
	recovered := make([]byte, split)
	for idx, b := range payload {
		recovered[idx] = b ^ rootKey[idx%len(rootKey)]
	}
	return recovered, nil
}
// CreateKey generates a fresh 32-byte random root key for tenantID, stores it
// in memory, and returns its key ID.
//
// The ID has the form "memory-<tenantID>-<n>", where n starts at the current
// Unix time in nanoseconds. If that ID is already taken (two calls for the same
// tenant within one clock tick, which is possible on coarse clocks), n is
// incremented until a free ID is found, so an existing root key is never
// silently overwritten.
//
// The key is used by Encrypt/Decrypt as both the XOR mask and the HMAC-SHA256
// key; it is not an AES key.
//
// A zero-value MemoryProvider is accepted: the key store is created on first
// use. An error is returned only when the system random source fails.
func (m *MemoryProvider) CreateKey(_ context.Context, tenantID string) (string, error) {
	key := make([]byte, 32)
	if _, err := rand.Read(key); err != nil {
		return "", fmt.Errorf("failed to generate random key: %w", err)
	}

	// Writing to a nil map panics; make the zero value usable.
	if m.keys == nil {
		m.keys = make(map[string][]byte)
	}

	// Probe forward from the current timestamp until an unused ID is found.
	for stamp := time.Now().UnixNano(); ; stamp++ {
		keyID := fmt.Sprintf("memory-%s-%d", tenantID, stamp)
		if _, taken := m.keys[keyID]; taken {
			continue
		}
		m.keys[keyID] = key
		return keyID, nil
	}
}
// DisableKey disables a key (no-op in memory provider).
// Both arguments are ignored and the result is always nil. No disabled state
// is recorded, so Encrypt and Decrypt keep accepting the key afterwards.
func (m *MemoryProvider) DisableKey(_ context.Context, _ string) error {
return nil
}
// ScheduleKeyDeletion removes keyID from the in-memory store right away; the
// memory provider does not actually wait out the retention window. Deleting an
// unknown keyID is harmless.
//
// The returned time is the nominal deletion date: the current time plus
// windowDays 24-hour days. The error is always nil.
func (m *MemoryProvider) ScheduleKeyDeletion(_ context.Context, keyID string, windowDays int) (time.Time, error) {
	delete(m.keys, keyID)

	const day = 24 * time.Hour
	retention := time.Duration(windowDays) * day
	return time.Now().Add(retention), nil
}
// EnableKey re-enables a disabled key (no-op in memory provider).
// Both arguments are ignored and the result is always nil. It does not restore
// a key that ScheduleKeyDeletion has already removed from the store.
func (m *MemoryProvider) EnableKey(_ context.Context, _ string) error {
return nil
}
// HealthCheck always returns healthy for memory provider.
// The context is ignored and the result is always nil: there is no remote
// backend whose connectivity could be probed.
func (m *MemoryProvider) HealthCheck(_ context.Context) error {
return nil
}
// Close releases resources (no-op for memory provider).
// It always returns nil and leaves the stored keys untouched, so the provider
// remains usable after Close.
func (m *MemoryProvider) Close() error {
return nil
}