fetch_ml/internal/crypto/kms/provider_test.go
Jeremie Fraeys f827ee522a
test(tracking/plugins): add PodmanInterface and comprehensive plugin tests for 91% coverage
Refactor plugins to use interface for testability:
- Add PodmanInterface to container package (StartContainer, StopContainer, RemoveContainer)
- Update MLflow plugin to use container.PodmanInterface
- Update TensorBoard plugin to use container.PodmanInterface
- Add comprehensive mocked tests for all three plugins (wandb, mlflow, tensorboard)
- Coverage increased from 18% to 91.4%
2026-03-14 16:59:16 -04:00

269 lines
6.9 KiB
Go

package kms_test
import (
"context"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewProviderFactory checks that building a factory for the in-memory
// provider configuration yields a non-nil factory.
func TestNewProviderFactory(t *testing.T) {
	t.Parallel()

	memoryCfg := config.Config{Provider: config.ProviderTypeMemory}
	require.NotNil(t, kms.NewProviderFactory(memoryCfg))
}
// TestCreateProviderMemory tests creating memory provider
func TestCreateProviderMemory(t *testing.T) {
t.Parallel()
cfg := config.Config{
Provider: config.ProviderTypeMemory,
}
factory := kms.NewProviderFactory(cfg)
provider, err := factory.CreateProvider()
require.NoError(t, err)
require.NotNil(t, provider)
// Verify it's a memory provider
err = provider.HealthCheck(context.Background())
require.NoError(t, err)
err = provider.Close()
require.NoError(t, err)
}
// TestCreateProviderUnsupported checks that the factory refuses a provider
// type it does not recognize and reports it in the error message.
func TestCreateProviderUnsupported(t *testing.T) {
	t.Parallel()

	unknownCfg := config.Config{Provider: "unsupported"}
	_, createErr := kms.NewProviderFactory(unknownCfg).CreateProvider()

	require.Error(t, createErr)
	assert.Contains(t, createErr.Error(), "unsupported KMS provider")
}
// TestMemoryProviderCreateKey verifies that CreateKey returns a non-empty key
// ID containing "memory-tenant-1" for tenant "tenant-1".
func TestMemoryProviderCreateKey(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	keyID, err := provider.CreateKey(ctx, "tenant-1")
	require.NoError(t, err)
	require.NotEmpty(t, keyID)
	assert.Contains(t, keyID, "memory-tenant-1")
}
// TestMemoryProviderEncryptDecrypt verifies an encrypt/decrypt round trip:
// the ciphertext must not expose the plaintext, and decrypting it with the
// same key must return the original bytes.
func TestMemoryProviderEncryptDecrypt(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	keyID, err := provider.CreateKey(ctx, "tenant-1")
	require.NoError(t, err)

	plaintext := []byte("secret data to encrypt")
	ciphertext, err := provider.Encrypt(ctx, keyID, plaintext)
	require.NoError(t, err)
	require.NotNil(t, ciphertext)

	// A plain inequality check passes as soon as any header or tag bytes are
	// added to the output, even if the payload itself were stored verbatim.
	// Require that the plaintext does not appear anywhere in the ciphertext.
	assert.NotContains(t, string(ciphertext), string(plaintext))

	decrypted, err := provider.Decrypt(ctx, keyID, ciphertext)
	require.NoError(t, err)
	assert.Equal(t, plaintext, decrypted)
}
// TestMemoryProviderEncryptKeyNotFound verifies that encrypting under a key
// ID that was never created fails with a "key not found" error.
func TestMemoryProviderEncryptKeyNotFound(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	_, err := provider.Encrypt(ctx, "nonexistent-key", []byte("data"))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "key not found")
}
// TestMemoryProviderDecryptKeyNotFound verifies that decrypting under a key
// ID that was never created fails with a "key not found" error.
func TestMemoryProviderDecryptKeyNotFound(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	_, err := provider.Decrypt(ctx, "nonexistent-key", []byte("data"))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "key not found")
}
// TestMemoryProviderDecryptCiphertextTooShort verifies that decrypting an
// input shorter than the provider's minimum ciphertext length fails with a
// "ciphertext too short" error.
func TestMemoryProviderDecryptCiphertextTooShort(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	keyID, err := provider.CreateKey(ctx, "tenant-1")
	require.NoError(t, err)

	// A valid key is used so the failure comes from the length check rather
	// than from the key lookup.
	_, err = provider.Decrypt(ctx, keyID, []byte("short"))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "ciphertext too short")
}
// TestMemoryProviderDecryptMACVerificationFailed verifies that ciphertext
// produced under one tenant's key is rejected with a MAC verification error
// when decrypted under another tenant's key.
func TestMemoryProviderDecryptMACVerificationFailed(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	keyID1, err := provider.CreateKey(ctx, "tenant-1")
	require.NoError(t, err)
	keyID2, err := provider.CreateKey(ctx, "tenant-2")
	require.NoError(t, err)

	ciphertext, err := provider.Encrypt(ctx, keyID1, []byte("secret data"))
	require.NoError(t, err)

	// Decrypting with the wrong key must fail authentication, not return garbage.
	_, err = provider.Decrypt(ctx, keyID2, ciphertext)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "MAC verification failed")
}
// TestMemoryProviderDisableKey verifies that DisableKey succeeds and, because
// the memory provider treats it as a no-op, the key still round-trips data.
func TestMemoryProviderDisableKey(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	keyID, err := provider.CreateKey(ctx, "tenant-1")
	require.NoError(t, err)

	require.NoError(t, provider.DisableKey(ctx, keyID))

	// The key is expected to remain usable after DisableKey for this provider.
	plaintext := []byte("data")
	ciphertext, err := provider.Encrypt(ctx, keyID, plaintext)
	require.NoError(t, err)
	decrypted, err := provider.Decrypt(ctx, keyID, ciphertext)
	require.NoError(t, err)
	assert.Equal(t, plaintext, decrypted)
}
// TestMemoryProviderEnableKey verifies that EnableKey succeeds for the memory
// provider, which is expected to treat it as a no-op, even for an arbitrary
// key ID.
func TestMemoryProviderEnableKey(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	require.NoError(t, provider.EnableKey(ctx, "any-key"))
}
// TestMemoryProviderScheduleKeyDeletion verifies that scheduling deletion
// reports a deletion date the requested number of days in the future, and that
// the key can no longer be used to encrypt or to decrypt data produced before
// the deletion.
func TestMemoryProviderScheduleKeyDeletion(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })
	ctx := context.Background()

	keyID, err := provider.CreateKey(ctx, "tenant-1")
	require.NoError(t, err)

	// Produce ciphertext while the key still exists so decryption can be
	// checked after deletion.
	ciphertext, err := provider.Encrypt(ctx, keyID, []byte("data"))
	require.NoError(t, err)

	const pendingDays = 7
	pending := pendingDays * 24 * time.Hour

	// Bracket the call with timestamps instead of sampling time.Now() only
	// afterwards, so scheduler latency between the call and the check (likely
	// under t.Parallel or -race) cannot make the tolerance check flaky.
	before := time.Now()
	deletionDate, err := provider.ScheduleKeyDeletion(ctx, keyID, pendingDays)
	after := time.Now()
	require.NoError(t, err)

	earliest := before.Add(pending).Add(-time.Second)
	latest := after.Add(pending).Add(time.Second)
	assert.False(t, deletionDate.Before(earliest), "deletion date %v is before %v", deletionDate, earliest)
	assert.False(t, deletionDate.After(latest), "deletion date %v is after %v", deletionDate, latest)

	// The key is already unusable immediately after scheduling deletion.
	_, err = provider.Encrypt(ctx, keyID, []byte("data"))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "key not found")

	_, err = provider.Decrypt(ctx, keyID, ciphertext)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "key not found")
}
// TestMemoryProviderHealthCheck verifies that a freshly created memory
// provider reports healthy.
func TestMemoryProviderHealthCheck(t *testing.T) {
	t.Parallel()
	provider := kms.NewMemoryProvider()
	// Assert Close's error rather than discarding it in a bare defer.
	t.Cleanup(func() { assert.NoError(t, provider.Close()) })

	require.NoError(t, provider.HealthCheck(context.Background()))
}
// TestMemoryProviderClose tests closing provider
func TestMemoryProviderClose(t *testing.T) {
t.Parallel()
provider := kms.NewMemoryProvider()
err := provider.Close()
require.NoError(t, err)
}
// TestProviderTypeConstants tests provider type constants
func TestProviderTypeConstants(t *testing.T) {
t.Parallel()
assert.Equal(t, config.ProviderType("vault"), kms.ProviderTypeVault)
assert.Equal(t, config.ProviderType("aws"), kms.ProviderTypeAWS)
assert.Equal(t, config.ProviderType("memory"), kms.ProviderTypeMemory)
}