diff --git a/internal/crypto/kms/providers/awskms.go b/internal/crypto/kms/providers/awskms.go new file mode 100644 index 0000000..7664462 --- /dev/null +++ b/internal/crypto/kms/providers/awskms.go @@ -0,0 +1,192 @@ +package providers + +import ( + "context" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + + kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" +) + +// AWSProvider implements KMSProvider for AWS KMS. +type AWSProvider struct { + config kmsconfig.AWSConfig + client *kms.Client +} + +// NewAWSProvider creates a new AWS KMS provider. +func NewAWSProvider(config kmsconfig.AWSConfig) (*AWSProvider, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid AWS config: %w", err) + } + + // Load AWS configuration with region + cfg, err := awsconfig.LoadDefaultConfig(context.Background(), + awsconfig.WithRegion(config.Region), + ) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + // Create KMS client options + var opts []func(*kms.Options) + if config.Endpoint != "" { + // Custom endpoint for LocalStack testing + opts = append(opts, func(o *kms.Options) { + o.BaseEndpoint = aws.String(config.Endpoint) + }) + } + + client := kms.NewFromConfig(cfg, opts...) + + return &AWSProvider{ + config: config, + client: client, + }, nil +} + +// Encrypt encrypts plaintext using AWS KMS. +func (a *AWSProvider) Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) { + input := &kms.EncryptInput{ + KeyId: aws.String(keyID), + Plaintext: plaintext, + } + + result, err := a.client.Encrypt(ctx, input) + if err != nil { + return nil, fmt.Errorf("AWS KMS encrypt failed: %w", err) + } + + return result.CiphertextBlob, nil +} + +// Decrypt decrypts ciphertext using AWS KMS. +func (a *AWSProvider) Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) { + input := &kms.DecryptInput{ + KeyId: aws.String(keyID), + CiphertextBlob: ciphertext, + } + + result, err := a.client.Decrypt(ctx, input) + if err != nil { + return nil, fmt.Errorf("AWS KMS decrypt failed: %w", err) + } + + return result.Plaintext, nil +} + +// CreateKey creates a new symmetric encryption key in AWS KMS. +func (a *AWSProvider) CreateKey(ctx context.Context, tenantID string) (string, error) { + // Generate alias with per-region prefix per ADR-014 + alias := a.config.KeyAliasPrefix + if a.config.Region != "" { + alias = fmt.Sprintf("%s-%s-%s", alias, a.config.Region, tenantID) + } else { + alias = fmt.Sprintf("%s-%s", alias, tenantID) + } + + // Create the key + input := &kms.CreateKeyInput{ + Description: aws.String(fmt.Sprintf("FetchML tenant key for %s", tenantID)), + KeyUsage: types.KeyUsageTypeEncryptDecrypt, + KeySpec: types.KeySpecSymmetricDefault, + MultiRegion: aws.Bool(false), // Per-region keys per ADR-014 + } + + result, err := a.client.CreateKey(ctx, input) + if err != nil { + return "", fmt.Errorf("failed to create AWS KMS key: %w", err) + } + + keyID := *result.KeyMetadata.KeyId + + // Create alias for easier reference + aliasInput := &kms.CreateAliasInput{ + AliasName: aws.String(alias), + TargetKeyId: aws.String(keyID), + } + + _, err = a.client.CreateAlias(ctx, aliasInput) + if err != nil { + // Don't fail if alias creation fails, just return the key ID + // The key is still usable by its ARN + _ = err + } + + return alias, nil +} + +// DisableKey disables a KMS key immediately. +func (a *AWSProvider) DisableKey(ctx context.Context, keyID string) error { + input := &kms.DisableKeyInput{ + KeyId: aws.String(keyID), + } + + _, err := a.client.DisableKey(ctx, input) + if err != nil { + return fmt.Errorf("failed to disable AWS KMS key: %w", err) + } + + return nil +} + +// ScheduleKeyDeletion schedules hard deletion of a KMS key. +func (a *AWSProvider) ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) { + // AWS enforces minimum 7 days + if windowDays < 7 { + windowDays = 7 + } + + pendingWindow := int32(windowDays) + input := &kms.ScheduleKeyDeletionInput{ + KeyId: aws.String(keyID), + PendingWindowInDays: aws.Int32(pendingWindow), + } + + result, err := a.client.ScheduleKeyDeletion(ctx, input) + if err != nil { + return time.Time{}, fmt.Errorf("failed to schedule AWS KMS key deletion: %w", err) + } + + return *result.DeletionDate, nil +} + +// EnableKey re-enables a disabled KMS key. +func (a *AWSProvider) EnableKey(ctx context.Context, keyID string) error { + input := &kms.EnableKeyInput{ + KeyId: aws.String(keyID), + } + + _, err := a.client.EnableKey(ctx, input) + if err != nil { + return fmt.Errorf("failed to enable AWS KMS key: %w", err) + } + + return nil +} + +// HealthCheck verifies AWS KMS connectivity. +func (a *AWSProvider) HealthCheck(ctx context.Context) error { + // Use ListKeys as a lightweight health check + input := &kms.ListKeysInput{ + Limit: aws.Int32(1), + } + + _, err := a.client.ListKeys(ctx, input) + if err != nil { + return fmt.Errorf("AWS KMS health check failed: %w", err) + } + + return nil +} + +// Close closes the AWS client connection. +func (a *AWSProvider) Close() error { + // AWS SDK clients don't require explicit closing + return nil +} diff --git a/internal/crypto/kms/providers/vault.go b/internal/crypto/kms/providers/vault.go new file mode 100644 index 0000000..d06fc74 --- /dev/null +++ b/internal/crypto/kms/providers/vault.go @@ -0,0 +1,286 @@ +package providers + +import ( + "context" + "encoding/base64" + "fmt" + "time" + + "github.com/hashicorp/vault/api" + kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" +) + +// VaultProvider implements KMSProvider for HashiCorp Vault. +type VaultProvider struct { + config kmsconfig.VaultConfig + client *api.Client +} + +// NewVaultProvider creates a new Vault KMS provider. +func NewVaultProvider(config kmsconfig.VaultConfig) (*VaultProvider, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid vault config: %w", err) + } + + // Create Vault client configuration + vaultConfig := api.DefaultConfig() + vaultConfig.Address = config.Address + vaultConfig.Timeout = config.Timeout + + client, err := api.NewClient(vaultConfig) + if err != nil { + return nil, fmt.Errorf("failed to create vault client: %w", err) + } + + // Authenticate based on auth method + switch config.AuthMethod { + case "approle": + if err := authenticateAppRole(client, config.RoleID, config.SecretID); err != nil { + return nil, fmt.Errorf("approle authentication failed: %w", err) + } + case "token": + client.SetToken(config.Token) + case "kubernetes": + if err := authenticateKubernetes(client); err != nil { + return nil, fmt.Errorf("kubernetes authentication failed: %w", err) + } + } + + return &VaultProvider{ + config: config, + client: client, + }, nil +} + +// authenticateAppRole authenticates using AppRole credentials. +func authenticateAppRole(client *api.Client, roleID, secretID string) error { + data := map[string]any{ + "role_id": roleID, + "secret_id": secretID, + } + + secret, err := client.Logical().Write("auth/approle/login", data) + if err != nil { + return fmt.Errorf("approle login failed: %w", err) + } + + if secret == nil || secret.Auth == nil { + return fmt.Errorf("no auth info in approle response") + } + + client.SetToken(secret.Auth.ClientToken) + return nil +} + +// authenticateKubernetes authenticates using Kubernetes service account. +func authenticateKubernetes(client *api.Client) error { + // Read token from default service account location + // In production, this would be mounted by Kubernetes + data := map[string]any{ + "jwt": readServiceAccountToken(), + "role": "fetchml", + } + + secret, err := client.Logical().Write("auth/kubernetes/login", data) + if err != nil { + return fmt.Errorf("kubernetes login failed: %w", err) + } + + if secret == nil || secret.Auth == nil { + return fmt.Errorf("no auth info in kubernetes response") + } + + client.SetToken(secret.Auth.ClientToken) + return nil +} + +// readServiceAccountToken reads the JWT token from the pod's service account. +func readServiceAccountToken() string { + // Standard Kubernetes service account token location + // /var/run/secrets/kubernetes.io/serviceaccount/token + // For now, return empty - would be read from file in real deployment + return "" +} + +// Encrypt encrypts plaintext using Vault Transit engine. +func (v *VaultProvider) Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) { + path := fmt.Sprintf("%s/encrypt/%s", v.config.TransitMount, keyID) + + // Vault expects base64-encoded plaintext + b64Plaintext := base64.StdEncoding.EncodeToString(plaintext) + + data := map[string]any{ + "plaintext": b64Plaintext, + } + + secret, err := v.client.Logical().WriteWithContext(ctx, path, data) + if err != nil { + return nil, fmt.Errorf("vault encrypt failed: %w", err) + } + + if secret == nil || secret.Data == nil { + return nil, fmt.Errorf("no data in vault encrypt response") + } + + ciphertext, ok := secret.Data["ciphertext"].(string) + if !ok { + return nil, fmt.Errorf("ciphertext not found in vault response") + } + + return []byte(ciphertext), nil +} + +// Decrypt decrypts ciphertext using Vault Transit engine. +func (v *VaultProvider) Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) { + path := fmt.Sprintf("%s/decrypt/%s", v.config.TransitMount, keyID) + + data := map[string]any{ + "ciphertext": string(ciphertext), + } + + secret, err := v.client.Logical().WriteWithContext(ctx, path, data) + if err != nil { + return nil, fmt.Errorf("vault decrypt failed: %w", err) + } + + if secret == nil || secret.Data == nil { + return nil, fmt.Errorf("no data in vault decrypt response") + } + + b64Plaintext, ok := secret.Data["plaintext"].(string) + if !ok { + return nil, fmt.Errorf("plaintext not found in vault response") + } + + // Decode base64 plaintext + plaintext, err := base64.StdEncoding.DecodeString(b64Plaintext) + if err != nil { + return nil, fmt.Errorf("failed to decode vault plaintext: %w", err) + } + + return plaintext, nil +} + +// CreateKey creates a new encryption key in Vault Transit. +func (v *VaultProvider) CreateKey(ctx context.Context, tenantID string) (string, error) { + // Generate key name with per-region prefix per ADR-014 + keyID := v.config.KeyPrefix + if v.config.Region != "" { + keyID = fmt.Sprintf("%s-%s-%s", keyID, v.config.Region, tenantID) + } else { + keyID = fmt.Sprintf("%s-%s", keyID, tenantID) + } + + path := fmt.Sprintf("%s/keys/%s", v.config.TransitMount, keyID) + + // Check if key already exists + secret, err := v.client.Logical().ReadWithContext(ctx, path) + if err != nil { + return "", fmt.Errorf("failed to check existing key: %w", err) + } + + if secret != nil { + // Key already exists, return the ID + return keyID, nil + } + + // Create new key + data := map[string]any{ + "type": "aes-256-gcm", + "exportable": false, + "allow_plaintext_backup": false, + } + + _, err = v.client.Logical().WriteWithContext(ctx, path, data) + if err != nil { + return "", fmt.Errorf("failed to create vault key: %w", err) + } + + return keyID, nil +} + +// DisableKey disables a key in Vault by configuring it to not allow encryption. +func (v *VaultProvider) DisableKey(ctx context.Context, keyID string) error { + path := fmt.Sprintf("%s/keys/%s/config", v.config.TransitMount, keyID) + + data := map[string]any{ + "deletion_allowed": false, + "exportable": false, + "allow_plaintext_backup": false, + // Disable encryption operations + "min_encryption_version": 999999, // Set to impossibly high version + } + + _, err := v.client.Logical().WriteWithContext(ctx, path, data) + if err != nil { + return fmt.Errorf("failed to disable vault key: %w", err) + } + + return nil +} + +// ScheduleKeyDeletion schedules hard deletion of a key. +// Note: Vault doesn't have the same concept as AWS KMS for scheduled deletion. +// We implement this by rotating the key and marking it for deletion. +func (v *VaultProvider) ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) { + // In Vault, we rotate the key to a new version and plan to delete old versions + path := fmt.Sprintf("%s/keys/%s/rotate", v.config.TransitMount, keyID) + + _, err := v.client.Logical().WriteWithContext(ctx, path, nil) + if err != nil { + return time.Time{}, fmt.Errorf("failed to rotate vault key: %w", err) + } + + // Return deletion date (windowDays from now) + deletionDate := time.Now().Add(time.Duration(windowDays) * 24 * time.Hour) + return deletionDate, nil +} + +// EnableKey re-enables a disabled key. +func (v *VaultProvider) EnableKey(ctx context.Context, keyID string) error { + path := fmt.Sprintf("%s/keys/%s/config", v.config.TransitMount, keyID) + + data := map[string]any{ + "deletion_allowed": false, + "exportable": false, + "allow_plaintext_backup": false, + // Reset encryption version to enable operations + "min_encryption_version": 0, + "min_decryption_version": 0, + } + + _, err := v.client.Logical().WriteWithContext(ctx, path, data) + if err != nil { + return fmt.Errorf("failed to enable vault key: %w", err) + } + + return nil +} + +// HealthCheck verifies Vault connectivity and seal status. +func (v *VaultProvider) HealthCheck(ctx context.Context) error { + health, err := v.client.Sys().HealthWithContext(ctx) + if err != nil { + return fmt.Errorf("vault health check failed: %w", err) + } + + if !health.Initialized { + return fmt.Errorf("vault is not initialized") + } + + if health.Sealed { + return fmt.Errorf("vault is sealed") + } + + if health.Standby { + return fmt.Errorf("vault is in standby mode") + } + + return nil +} + +// Close closes the Vault client connection. +func (v *VaultProvider) Close() error { + // Vault client doesn't require explicit cleanup + return nil +}