Refactor plugins to use interface for testability: - Add PodmanInterface to container package (StartContainer, StopContainer, RemoveContainer) - Update MLflow plugin to use container.PodmanInterface - Update TensorBoard plugin to use container.PodmanInterface - Add comprehensive mocked tests for all three plugins (wandb, mlflow, tensorboard) - Coverage increased from 18% to 91.4%
269 lines
6.9 KiB
Go
269 lines
6.9 KiB
Go
package kms_test
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
|
|
"github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestNewProviderFactory tests factory creation
|
|
func TestNewProviderFactory(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cfg := config.Config{
|
|
Provider: config.ProviderTypeMemory,
|
|
}
|
|
|
|
factory := kms.NewProviderFactory(cfg)
|
|
require.NotNil(t, factory)
|
|
}
|
|
|
|
// TestCreateProviderMemory tests creating memory provider
|
|
func TestCreateProviderMemory(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cfg := config.Config{
|
|
Provider: config.ProviderTypeMemory,
|
|
}
|
|
|
|
factory := kms.NewProviderFactory(cfg)
|
|
provider, err := factory.CreateProvider()
|
|
require.NoError(t, err)
|
|
require.NotNil(t, provider)
|
|
|
|
// Verify it's a memory provider
|
|
err = provider.HealthCheck(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
err = provider.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// TestCreateProviderUnsupported tests unsupported provider type
|
|
func TestCreateProviderUnsupported(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cfg := config.Config{
|
|
Provider: "unsupported",
|
|
}
|
|
|
|
factory := kms.NewProviderFactory(cfg)
|
|
_, err := factory.CreateProvider()
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "unsupported KMS provider")
|
|
}
|
|
|
|
// TestMemoryProviderCreateKey tests key creation
|
|
func TestMemoryProviderCreateKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
keyID, err := provider.CreateKey(ctx, "tenant-1")
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, keyID)
|
|
assert.Contains(t, keyID, "memory-tenant-1")
|
|
}
|
|
|
|
// TestMemoryProviderEncryptDecrypt tests encryption and decryption
|
|
func TestMemoryProviderEncryptDecrypt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Create a key
|
|
keyID, err := provider.CreateKey(ctx, "tenant-1")
|
|
require.NoError(t, err)
|
|
|
|
// Encrypt data
|
|
plaintext := []byte("secret data to encrypt")
|
|
ciphertext, err := provider.Encrypt(ctx, keyID, plaintext)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, ciphertext)
|
|
|
|
// Ciphertext should be different from plaintext
|
|
assert.NotEqual(t, plaintext, ciphertext)
|
|
|
|
// Decrypt data
|
|
decrypted, err := provider.Decrypt(ctx, keyID, ciphertext)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, plaintext, decrypted)
|
|
}
|
|
|
|
// TestMemoryProviderEncryptKeyNotFound tests encryption with nonexistent key
|
|
func TestMemoryProviderEncryptKeyNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
_, err := provider.Encrypt(ctx, "nonexistent-key", []byte("data"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "key not found")
|
|
}
|
|
|
|
// TestMemoryProviderDecryptKeyNotFound tests decryption with nonexistent key
|
|
func TestMemoryProviderDecryptKeyNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
_, err := provider.Decrypt(ctx, "nonexistent-key", []byte("data"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "key not found")
|
|
}
|
|
|
|
// TestMemoryProviderDecryptCiphertextTooShort tests decryption with short ciphertext
|
|
func TestMemoryProviderDecryptCiphertextTooShort(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Create a key
|
|
keyID, err := provider.CreateKey(ctx, "tenant-1")
|
|
require.NoError(t, err)
|
|
|
|
// Try to decrypt data that's too short
|
|
_, err = provider.Decrypt(ctx, keyID, []byte("short"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "ciphertext too short")
|
|
}
|
|
|
|
// TestMemoryProviderDecryptMACVerificationFailed tests MAC verification failure
|
|
func TestMemoryProviderDecryptMACVerificationFailed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Create two different keys
|
|
keyID1, err := provider.CreateKey(ctx, "tenant-1")
|
|
require.NoError(t, err)
|
|
keyID2, err := provider.CreateKey(ctx, "tenant-2")
|
|
require.NoError(t, err)
|
|
|
|
// Encrypt with key1
|
|
plaintext := []byte("secret data")
|
|
ciphertext, err := provider.Encrypt(ctx, keyID1, plaintext)
|
|
require.NoError(t, err)
|
|
|
|
// Try to decrypt with key2 (should fail MAC verification)
|
|
_, err = provider.Decrypt(ctx, keyID2, ciphertext)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "MAC verification failed")
|
|
}
|
|
|
|
// TestMemoryProviderDisableKey tests disabling a key
|
|
func TestMemoryProviderDisableKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Create a key
|
|
keyID, err := provider.CreateKey(ctx, "tenant-1")
|
|
require.NoError(t, err)
|
|
|
|
// Disable the key (no-op in memory provider)
|
|
err = provider.DisableKey(ctx, keyID)
|
|
require.NoError(t, err)
|
|
|
|
// Key should still work for memory provider
|
|
plaintext := []byte("data")
|
|
ciphertext, err := provider.Encrypt(ctx, keyID, plaintext)
|
|
require.NoError(t, err)
|
|
|
|
decrypted, err := provider.Decrypt(ctx, keyID, ciphertext)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, plaintext, decrypted)
|
|
}
|
|
|
|
// TestMemoryProviderEnableKey tests enabling a key
|
|
func TestMemoryProviderEnableKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Enable a key (no-op in memory provider)
|
|
err := provider.EnableKey(ctx, "any-key")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// TestMemoryProviderScheduleKeyDeletion tests key deletion scheduling
|
|
func TestMemoryProviderScheduleKeyDeletion(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Create a key
|
|
keyID, err := provider.CreateKey(ctx, "tenant-1")
|
|
require.NoError(t, err)
|
|
|
|
// Schedule deletion
|
|
deletionDate, err := provider.ScheduleKeyDeletion(ctx, keyID, 7)
|
|
require.NoError(t, err)
|
|
assert.WithinDuration(t, time.Now().Add(7*24*time.Hour), deletionDate, time.Second)
|
|
|
|
// Key should be deleted
|
|
_, err = provider.Encrypt(ctx, keyID, []byte("data"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "key not found")
|
|
}
|
|
|
|
// TestMemoryProviderHealthCheck tests health check
|
|
func TestMemoryProviderHealthCheck(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
err := provider.HealthCheck(ctx)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// TestMemoryProviderClose tests closing provider
|
|
func TestMemoryProviderClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := kms.NewMemoryProvider()
|
|
err := provider.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// TestProviderTypeConstants tests provider type constants
|
|
func TestProviderTypeConstants(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, config.ProviderType("vault"), kms.ProviderTypeVault)
|
|
assert.Equal(t, config.ProviderType("aws"), kms.ProviderTypeAWS)
|
|
assert.Equal(t, config.ProviderType("memory"), kms.ProviderTypeMemory)
|
|
}
|