fetch_ml/internal/crypto/kms/providers/awskms.go
Jeremie Fraeys 7c03c8b5bd
feat(kms): add HashiCorp Vault and AWS KMS providers
Implement VaultProvider with Transit engine:
- AppRole, Kubernetes, and Token authentication
- Encrypt/Decrypt via /transit/encrypt and /transit/decrypt
- Key lifecycle via /transit/keys API
- Health check via /sys/health

Implement AWSProvider with SDK v2:
- Per-region key naming with alias prefix
- Encrypt/Decrypt via KMS SDK
- Key lifecycle (CreateKey, Disable, ScheduleDeletion, Enable)
- AWS endpoint support for LocalStack testing
2026-03-03 19:14:21 -05:00

192 lines
4.9 KiB
Go

package providers
import (
"context"
"fmt"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
)
// AWSProvider implements KMSProvider for AWS KMS.
//
// Instances are built by NewAWSProvider, which validates the configuration
// and constructs the SDK v2 KMS client stored here. All key operations
// (Encrypt, Decrypt, CreateKey, DisableKey, EnableKey, ScheduleKeyDeletion,
// HealthCheck) are forwarded to that client.
type AWSProvider struct {
	// config is the validated provider configuration; its Region,
	// KeyAliasPrefix and Endpoint fields are read when building the client
	// and when naming new keys.
	config kmsconfig.AWSConfig
	// client is the AWS SDK v2 KMS client used for every API call.
	client *kms.Client
}
// NewAWSProvider validates config and returns an AWSProvider backed by a KMS
// client built from awsconfig.LoadDefaultConfig, pinned to config.Region.
//
// When config.Endpoint is non-empty, the client's base endpoint is overridden
// with it so that calls go to a custom KMS implementation such as LocalStack.
//
// It returns an error if config fails validation or if the AWS configuration
// cannot be loaded.
func NewAWSProvider(config kmsconfig.AWSConfig) (*AWSProvider, error) {
	if validationErr := config.Validate(); validationErr != nil {
		return nil, fmt.Errorf("invalid AWS config: %w", validationErr)
	}

	sdkCfg, loadErr := awsconfig.LoadDefaultConfig(
		context.Background(),
		awsconfig.WithRegion(config.Region),
	)
	if loadErr != nil {
		return nil, fmt.Errorf("failed to load AWS config: %w", loadErr)
	}

	var clientOpts []func(*kms.Options)
	if endpoint := config.Endpoint; endpoint != "" {
		// Redirect all KMS traffic to the configured endpoint (LocalStack testing).
		clientOpts = append(clientOpts, func(o *kms.Options) {
			o.BaseEndpoint = aws.String(endpoint)
		})
	}

	return &AWSProvider{
		config: config,
		client: kms.NewFromConfig(sdkCfg, clientOpts...),
	}, nil
}
// Encrypt asks AWS KMS to encrypt plaintext under the key identified by keyID
// and returns the ciphertext blob produced by KMS. keyID is forwarded
// unchanged as the request's KeyId.
func (a *AWSProvider) Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) {
	out, err := a.client.Encrypt(ctx, &kms.EncryptInput{
		KeyId:     aws.String(keyID),
		Plaintext: plaintext,
	})
	if err != nil {
		return nil, fmt.Errorf("AWS KMS encrypt failed: %w", err)
	}
	return out.CiphertextBlob, nil
}
// Decrypt sends ciphertext to AWS KMS and returns the recovered plaintext.
// keyID is forwarded as the request's KeyId, pinning the decryption request
// to that key rather than relying only on the metadata inside the ciphertext.
func (a *AWSProvider) Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) {
	req := &kms.DecryptInput{
		KeyId:          aws.String(keyID),
		CiphertextBlob: ciphertext,
	}

	resp, decryptErr := a.client.Decrypt(ctx, req)
	if decryptErr != nil {
		return nil, fmt.Errorf("AWS KMS decrypt failed: %w", decryptErr)
	}
	return resp.Plaintext, nil
}
// CreateKey creates a new single-Region symmetric encrypt/decrypt key in AWS
// KMS for tenantID and tries to attach a per-region alias to it (ADR-014).
//
// The alias is "<KeyAliasPrefix>-<Region>-<tenantID>", or
// "<KeyAliasPrefix>-<tenantID>" when no region is configured.
//
// Return value:
//   - the alias, when both the key and the alias were created;
//   - the raw KMS key ID, when the key was created but the alias could not be
//     attached. The key is still usable by its ID. Returning the alias in that
//     case would hand callers a name that either does not exist or, if it was
//     already taken, still points at a different (older) key.
//
// An error is returned only if the key itself could not be created.
func (a *AWSProvider) CreateKey(ctx context.Context, tenantID string) (string, error) {
	// Per-region alias naming per ADR-014.
	alias := fmt.Sprintf("%s-%s", a.config.KeyAliasPrefix, tenantID)
	if a.config.Region != "" {
		alias = fmt.Sprintf("%s-%s-%s", a.config.KeyAliasPrefix, a.config.Region, tenantID)
	}

	result, err := a.client.CreateKey(ctx, &kms.CreateKeyInput{
		Description: aws.String(fmt.Sprintf("FetchML tenant key for %s", tenantID)),
		KeyUsage:    types.KeyUsageTypeEncryptDecrypt,
		KeySpec:     types.KeySpecSymmetricDefault,
		MultiRegion: aws.Bool(false), // Per-region keys per ADR-014
	})
	if err != nil {
		return "", fmt.Errorf("failed to create AWS KMS key: %w", err)
	}
	// Guard against a malformed response instead of panicking on a nil pointer.
	if result.KeyMetadata == nil || result.KeyMetadata.KeyId == nil {
		return "", fmt.Errorf("failed to create AWS KMS key: response missing key metadata")
	}
	keyID := *result.KeyMetadata.KeyId

	if _, err := a.client.CreateAlias(ctx, &kms.CreateAliasInput{
		AliasName:   aws.String(alias),
		TargetKeyId: aws.String(keyID),
	}); err != nil {
		// Best effort: the key exists and is usable by its ID, so do not fail
		// the whole operation. Hand back the key ID, not the alias, because the
		// alias is either missing or bound to another key.
		return keyID, nil
	}
	return alias, nil
}
// DisableKey disables the KMS key identified by keyID. The change is applied
// by a single DisableKey API call and can be reversed with EnableKey.
func (a *AWSProvider) DisableKey(ctx context.Context, keyID string) error {
	if _, err := a.client.DisableKey(ctx, &kms.DisableKeyInput{KeyId: aws.String(keyID)}); err != nil {
		return fmt.Errorf("failed to disable AWS KMS key: %w", err)
	}
	return nil
}
// ScheduleKeyDeletion schedules hard deletion of the KMS key identified by
// keyID after a waiting period of windowDays. It returns the date after which
// AWS will delete the key.
//
// AWS KMS only accepts a waiting period of 7 to 30 days. windowDays is clamped
// into that range (below 7 becomes 7, above 30 becomes 30) rather than letting
// the request fail validation. Clamping also keeps the int32 conversion from
// overflowing for very large inputs.
//
// AWS omits DeletionDate from the response for some keys (for example
// multi-Region primary keys that still have replicas). In that case, instead of
// dereferencing a nil pointer, the function returns now + window. This is a
// lower bound on when deletion can actually happen.
func (a *AWSProvider) ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) {
	const (
		minPendingWindowDays = 7  // AWS-enforced minimum
		maxPendingWindowDays = 30 // AWS-enforced maximum
	)
	if windowDays < minPendingWindowDays {
		windowDays = minPendingWindowDays
	} else if windowDays > maxPendingWindowDays {
		windowDays = maxPendingWindowDays
	}

	result, err := a.client.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{
		KeyId:               aws.String(keyID),
		PendingWindowInDays: aws.Int32(int32(windowDays)),
	})
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to schedule AWS KMS key deletion: %w", err)
	}
	if result.DeletionDate == nil {
		// Deletion was scheduled, but AWS did not report a date; estimate it.
		return time.Now().AddDate(0, 0, windowDays), nil
	}
	return *result.DeletionDate, nil
}
// EnableKey turns a previously disabled KMS key (see DisableKey) back on so it
// can be used again. Any error from the KMS EnableKey call is wrapped and
// returned.
func (a *AWSProvider) EnableKey(ctx context.Context, keyID string) error {
	req := &kms.EnableKeyInput{KeyId: aws.String(keyID)}
	if _, enableErr := a.client.EnableKey(ctx, req); enableErr != nil {
		return fmt.Errorf("failed to enable AWS KMS key: %w", enableErr)
	}
	return nil
}
// HealthCheck probes AWS KMS by issuing a ListKeys request capped at a single
// result. Any failure of that request, such as a network, credential or
// permission problem, is wrapped and returned.
func (a *AWSProvider) HealthCheck(ctx context.Context) error {
	probe := &kms.ListKeysInput{Limit: aws.Int32(1)}
	if _, probeErr := a.client.ListKeys(ctx, probe); probeErr != nil {
		return fmt.Errorf("AWS KMS health check failed: %w", probeErr)
	}
	return nil
}
// Close satisfies the provider interface's shutdown hook. The AWS SDK v2 KMS
// client needs no explicit teardown, so this is a no-op that always returns
// nil.
func (a *AWSProvider) Close() error { return nil }