Implement VaultProvider with the Transit engine:
- AppRole, Kubernetes, and Token authentication
- Encrypt/Decrypt via /transit/encrypt and /transit/decrypt
- Key lifecycle via the /transit/keys API
- Health check via /sys/health

Implement AWSProvider with AWS SDK v2:
- Per-region key naming with alias prefix
- Encrypt/Decrypt via the KMS SDK
- Key lifecycle (CreateKey, Disable, ScheduleDeletion, Enable)
- Custom endpoint support for LocalStack testing
192 lines
4.9 KiB
Go
192 lines
4.9 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/service/kms"
|
|
"github.com/aws/aws-sdk-go-v2/service/kms/types"
|
|
|
|
kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
|
|
)
|
|
|
|
// AWSProvider implements KMSProvider for AWS KMS.
|
|
type AWSProvider struct {
|
|
config kmsconfig.AWSConfig
|
|
client *kms.Client
|
|
}
|
|
|
|
// NewAWSProvider creates a new AWS KMS provider.
|
|
func NewAWSProvider(config kmsconfig.AWSConfig) (*AWSProvider, error) {
|
|
if err := config.Validate(); err != nil {
|
|
return nil, fmt.Errorf("invalid AWS config: %w", err)
|
|
}
|
|
|
|
// Load AWS configuration with region
|
|
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
|
|
awsconfig.WithRegion(config.Region),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
|
}
|
|
|
|
// Create KMS client options
|
|
var opts []func(*kms.Options)
|
|
if config.Endpoint != "" {
|
|
// Custom endpoint for LocalStack testing
|
|
opts = append(opts, func(o *kms.Options) {
|
|
o.BaseEndpoint = aws.String(config.Endpoint)
|
|
})
|
|
}
|
|
|
|
client := kms.NewFromConfig(cfg, opts...)
|
|
|
|
return &AWSProvider{
|
|
config: config,
|
|
client: client,
|
|
}, nil
|
|
}
|
|
|
|
// Encrypt asks AWS KMS to encrypt plaintext under the key identified by keyID.
// keyID is forwarded to KMS unchanged as the request's KeyId.
//
// It returns the ciphertext blob from the KMS response, or a wrapped error if
// the Encrypt call fails.
func (a *AWSProvider) Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) {
	out, err := a.client.Encrypt(ctx, &kms.EncryptInput{
		KeyId:     aws.String(keyID),
		Plaintext: plaintext,
	})
	if err != nil {
		return nil, fmt.Errorf("AWS KMS encrypt failed: %w", err)
	}
	return out.CiphertextBlob, nil
}
|
|
|
|
// Decrypt asks AWS KMS to decrypt ciphertext. keyID is sent as the request's
// KeyId, so KMS checks that the ciphertext was produced under that key.
//
// It returns the plaintext from the KMS response, or a wrapped error if the
// Decrypt call fails.
func (a *AWSProvider) Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) {
	out, err := a.client.Decrypt(ctx, &kms.DecryptInput{
		KeyId:          aws.String(keyID),
		CiphertextBlob: ciphertext,
	})
	if err != nil {
		return nil, fmt.Errorf("AWS KMS decrypt failed: %w", err)
	}
	return out.Plaintext, nil
}
|
|
|
|
// CreateKey creates a new single-region, symmetric encrypt/decrypt key in AWS
// KMS for tenantID and tries to attach an alias to it.
//
// The alias is built per ADR-014 as "<prefix>-<region>-<tenant>", or as
// "<prefix>-<tenant>" when no region is configured. The prefix comes from
// config.KeyAliasPrefix.
// NOTE(review): KMS alias names must begin with "alias/"; this assumes
// KeyAliasPrefix already includes it — confirm against the config defaults.
//
// Returns:
//   - the alias, when both the key and its alias were created;
//   - the new key's ID, when the key was created but the alias could not be
//     attached. The key is still usable by its ID; the alias would not resolve.
//   - an error, if the key itself could not be created.
func (a *AWSProvider) CreateKey(ctx context.Context, tenantID string) (string, error) {
	alias := a.config.KeyAliasPrefix
	if a.config.Region != "" {
		alias = fmt.Sprintf("%s-%s-%s", alias, a.config.Region, tenantID)
	} else {
		alias = fmt.Sprintf("%s-%s", alias, tenantID)
	}

	created, err := a.client.CreateKey(ctx, &kms.CreateKeyInput{
		Description: aws.String(fmt.Sprintf("FetchML tenant key for %s", tenantID)),
		KeyUsage:    types.KeyUsageTypeEncryptDecrypt,
		KeySpec:     types.KeySpecSymmetricDefault,
		MultiRegion: aws.Bool(false), // Per-region keys per ADR-014
	})
	if err != nil {
		return "", fmt.Errorf("failed to create AWS KMS key: %w", err)
	}
	// Avoid a nil-pointer panic if the response (e.g. from an emulator) lacks metadata.
	if created.KeyMetadata == nil || created.KeyMetadata.KeyId == nil {
		return "", fmt.Errorf("failed to create AWS KMS key: response did not include a key ID")
	}
	keyID := *created.KeyMetadata.KeyId

	if _, aliasErr := a.client.CreateAlias(ctx, &kms.CreateAliasInput{
		AliasName:   aws.String(alias),
		TargetKeyId: aws.String(keyID),
	}); aliasErr != nil {
		// Best effort: the key exists, but the alias does not. Returning the
		// alias here would give the caller a reference that cannot be resolved,
		// so return the key ID instead.
		return keyID, nil
	}

	return alias, nil
}
|
|
|
|
// DisableKey disables the KMS key identified by keyID, taking effect
// immediately. A disabled key can be turned back on with EnableKey.
//
// Any error from the KMS DisableKey call is returned wrapped.
func (a *AWSProvider) DisableKey(ctx context.Context, keyID string) error {
	if _, err := a.client.DisableKey(ctx, &kms.DisableKeyInput{KeyId: aws.String(keyID)}); err != nil {
		return fmt.Errorf("failed to disable AWS KMS key: %w", err)
	}
	return nil
}
|
|
|
|
// ScheduleKeyDeletion schedules hard deletion of a KMS key.
|
|
func (a *AWSProvider) ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) {
|
|
// AWS enforces minimum 7 days
|
|
if windowDays < 7 {
|
|
windowDays = 7
|
|
}
|
|
|
|
pendingWindow := int32(windowDays)
|
|
input := &kms.ScheduleKeyDeletionInput{
|
|
KeyId: aws.String(keyID),
|
|
PendingWindowInDays: aws.Int32(pendingWindow),
|
|
}
|
|
|
|
result, err := a.client.ScheduleKeyDeletion(ctx, input)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf("failed to schedule AWS KMS key deletion: %w", err)
|
|
}
|
|
|
|
return *result.DeletionDate, nil
|
|
}
|
|
|
|
// EnableKey re-enables the previously disabled KMS key identified by keyID;
// it is the inverse of DisableKey.
//
// Any error from the KMS EnableKey call is returned wrapped.
func (a *AWSProvider) EnableKey(ctx context.Context, keyID string) error {
	if _, err := a.client.EnableKey(ctx, &kms.EnableKeyInput{KeyId: aws.String(keyID)}); err != nil {
		return fmt.Errorf("failed to enable AWS KMS key: %w", err)
	}
	return nil
}
|
|
|
|
// HealthCheck verifies connectivity to AWS KMS. It issues a ListKeys request
// capped at one result, which is a cheap round trip. Any error from that
// request is reported as a failed health check.
func (a *AWSProvider) HealthCheck(ctx context.Context) error {
	probe := &kms.ListKeysInput{Limit: aws.Int32(1)}
	if _, err := a.client.ListKeys(ctx, probe); err != nil {
		return fmt.Errorf("AWS KMS health check failed: %w", err)
	}
	return nil
}
|
|
|
|
// Close closes the AWS client connection.
|
|
func (a *AWSProvider) Close() error {
|
|
// AWS SDK clients don't require explicit closing
|
|
return nil
|
|
}
|