package providers import ( "context" "encoding/base64" "fmt" "time" "github.com/hashicorp/vault/api" kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" ) // VaultProvider implements KMSProvider for HashiCorp Vault. type VaultProvider struct { config kmsconfig.VaultConfig client *api.Client } // NewVaultProvider creates a new Vault KMS provider. func NewVaultProvider(config kmsconfig.VaultConfig) (*VaultProvider, error) { if err := config.Validate(); err != nil { return nil, fmt.Errorf("invalid vault config: %w", err) } // Create Vault client configuration vaultConfig := api.DefaultConfig() vaultConfig.Address = config.Address vaultConfig.Timeout = config.Timeout client, err := api.NewClient(vaultConfig) if err != nil { return nil, fmt.Errorf("failed to create vault client: %w", err) } // Authenticate based on auth method switch config.AuthMethod { case "approle": if err := authenticateAppRole(client, config.RoleID, config.SecretID); err != nil { return nil, fmt.Errorf("approle authentication failed: %w", err) } case "token": client.SetToken(config.Token) case "kubernetes": if err := authenticateKubernetes(client); err != nil { return nil, fmt.Errorf("kubernetes authentication failed: %w", err) } } return &VaultProvider{ config: config, client: client, }, nil } // authenticateAppRole authenticates using AppRole credentials. func authenticateAppRole(client *api.Client, roleID, secretID string) error { data := map[string]any{ "role_id": roleID, "secret_id": secretID, } secret, err := client.Logical().Write("auth/approle/login", data) if err != nil { return fmt.Errorf("approle login failed: %w", err) } if secret == nil || secret.Auth == nil { return fmt.Errorf("no auth info in approle response") } client.SetToken(secret.Auth.ClientToken) return nil } // authenticateKubernetes authenticates using Kubernetes service account. func authenticateKubernetes(client *api.Client) error { // Read token from default service account location // In production, this would be mounted by Kubernetes data := map[string]any{ "jwt": readServiceAccountToken(), "role": "fetchml", } secret, err := client.Logical().Write("auth/kubernetes/login", data) if err != nil { return fmt.Errorf("kubernetes login failed: %w", err) } if secret == nil || secret.Auth == nil { return fmt.Errorf("no auth info in kubernetes response") } client.SetToken(secret.Auth.ClientToken) return nil } // readServiceAccountToken reads the JWT token from the pod's service account. func readServiceAccountToken() string { // Standard Kubernetes service account token location // /var/run/secrets/kubernetes.io/serviceaccount/token // For now, return empty - would be read from file in real deployment return "" } // Encrypt encrypts plaintext using Vault Transit engine. func (v *VaultProvider) Encrypt(ctx context.Context, keyID string, plaintext []byte) ([]byte, error) { path := fmt.Sprintf("%s/encrypt/%s", v.config.TransitMount, keyID) // Vault expects base64-encoded plaintext b64Plaintext := base64.StdEncoding.EncodeToString(plaintext) data := map[string]any{ "plaintext": b64Plaintext, } secret, err := v.client.Logical().WriteWithContext(ctx, path, data) if err != nil { return nil, fmt.Errorf("vault encrypt failed: %w", err) } if secret == nil || secret.Data == nil { return nil, fmt.Errorf("no data in vault encrypt response") } ciphertext, ok := secret.Data["ciphertext"].(string) if !ok { return nil, fmt.Errorf("ciphertext not found in vault response") } return []byte(ciphertext), nil } // Decrypt decrypts ciphertext using Vault Transit engine. func (v *VaultProvider) Decrypt(ctx context.Context, keyID string, ciphertext []byte) ([]byte, error) { path := fmt.Sprintf("%s/decrypt/%s", v.config.TransitMount, keyID) data := map[string]any{ "ciphertext": string(ciphertext), } secret, err := v.client.Logical().WriteWithContext(ctx, path, data) if err != nil { return nil, fmt.Errorf("vault decrypt failed: %w", err) } if secret == nil || secret.Data == nil { return nil, fmt.Errorf("no data in vault decrypt response") } b64Plaintext, ok := secret.Data["plaintext"].(string) if !ok { return nil, fmt.Errorf("plaintext not found in vault response") } // Decode base64 plaintext plaintext, err := base64.StdEncoding.DecodeString(b64Plaintext) if err != nil { return nil, fmt.Errorf("failed to decode vault plaintext: %w", err) } return plaintext, nil } // CreateKey creates a new encryption key in Vault Transit. func (v *VaultProvider) CreateKey(ctx context.Context, tenantID string) (string, error) { // Generate key name with per-region prefix per ADR-014 keyID := v.config.KeyPrefix if v.config.Region != "" { keyID = fmt.Sprintf("%s-%s-%s", keyID, v.config.Region, tenantID) } else { keyID = fmt.Sprintf("%s-%s", keyID, tenantID) } path := fmt.Sprintf("%s/keys/%s", v.config.TransitMount, keyID) // Check if key already exists secret, err := v.client.Logical().ReadWithContext(ctx, path) if err != nil { return "", fmt.Errorf("failed to check existing key: %w", err) } if secret != nil { // Key already exists, return the ID return keyID, nil } // Create new key data := map[string]any{ "type": "aes-256-gcm", "exportable": false, "allow_plaintext_backup": false, } _, err = v.client.Logical().WriteWithContext(ctx, path, data) if err != nil { return "", fmt.Errorf("failed to create vault key: %w", err) } return keyID, nil } // DisableKey disables a key in Vault by configuring it to not allow encryption. func (v *VaultProvider) DisableKey(ctx context.Context, keyID string) error { path := fmt.Sprintf("%s/keys/%s/config", v.config.TransitMount, keyID) data := map[string]any{ "deletion_allowed": false, "exportable": false, "allow_plaintext_backup": false, // Disable encryption operations "min_encryption_version": 999999, // Set to impossibly high version } _, err := v.client.Logical().WriteWithContext(ctx, path, data) if err != nil { return fmt.Errorf("failed to disable vault key: %w", err) } return nil } // ScheduleKeyDeletion schedules hard deletion of a key. // Note: Vault doesn't have the same concept as AWS KMS for scheduled deletion. // We implement this by rotating the key and marking it for deletion. func (v *VaultProvider) ScheduleKeyDeletion(ctx context.Context, keyID string, windowDays int) (time.Time, error) { // In Vault, we rotate the key to a new version and plan to delete old versions path := fmt.Sprintf("%s/keys/%s/rotate", v.config.TransitMount, keyID) _, err := v.client.Logical().WriteWithContext(ctx, path, nil) if err != nil { return time.Time{}, fmt.Errorf("failed to rotate vault key: %w", err) } // Return deletion date (windowDays from now) deletionDate := time.Now().Add(time.Duration(windowDays) * 24 * time.Hour) return deletionDate, nil } // EnableKey re-enables a disabled key. func (v *VaultProvider) EnableKey(ctx context.Context, keyID string) error { path := fmt.Sprintf("%s/keys/%s/config", v.config.TransitMount, keyID) data := map[string]any{ "deletion_allowed": false, "exportable": false, "allow_plaintext_backup": false, // Reset encryption version to enable operations "min_encryption_version": 0, "min_decryption_version": 0, } _, err := v.client.Logical().WriteWithContext(ctx, path, data) if err != nil { return fmt.Errorf("failed to enable vault key: %w", err) } return nil } // HealthCheck verifies Vault connectivity and seal status. func (v *VaultProvider) HealthCheck(ctx context.Context) error { health, err := v.client.Sys().HealthWithContext(ctx) if err != nil { return fmt.Errorf("vault health check failed: %w", err) } if !health.Initialized { return fmt.Errorf("vault is not initialized") } if health.Sealed { return fmt.Errorf("vault is sealed") } if health.Standby { return fmt.Errorf("vault is in standby mode") } return nil } // Close closes the Vault client connection. func (v *VaultProvider) Close() error { // Vault client doesn't require explicit cleanup return nil }