build: update Makefile test paths and clean up old test references
Update test paths to reflect new test locations:
- Change tests/unit/... to internal/... in test, test-unit, and verify-audit targets
- Update test-coverage target to use correct coverpkg paths
- Add coverage summary output to test-coverage target

Clean up deleted test files from old locations:
- Remove tests/unit/crypto/kms/cache_test.go (now in internal/auth/kms/)
- Remove tests/unit/crypto/kms/protocol_test.go (now in internal/auth/kms/)
- Remove tests/unit/resources/manager_test.go (now in internal/resources/)
parent d0266c4a90
commit 8a30acf661

4 changed files with 5 additions and 760 deletions

Makefile (9 changes)
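With the tests colocated under internal/, `go test ./internal/...` discovers each package's _test.go files directly, which is what the updated targets below rely on. A minimal illustration of that layout — the path and test name here are hypothetical, following the destinations listed in the commit message; the real moved files appear further down in the diff:

// internal/auth/kms/cache_test.go (illustrative only)
package kms_test

import "testing"

// A colocated test like this is picked up by the updated targets,
// e.g. `go test -v -short ./internal/...` in the test-unit recipe.
func TestColocatedLayout(t *testing.T) {
	t.Log("discovered via ./internal/... rather than ./tests/unit/...")
}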
@@ -177,14 +177,14 @@ verify-build:
 test: test-infra-up
 	@arch=$$(uname -m | sed 's/x86_64/amd64/'); os=$$(uname -s | tr '[:upper:]' '[:lower:]'); \
 	[ -f bin/cli/ml-$$os-$$arch ] || $(MAKE) build-cli
-	@go test -v ./tests/unit/... ./tests/integration/... ./tests/e2e/... 2>&1 | grep -v "redis: connection pool" | tee /tmp/test-all.txt || true
+	@go test -v ./internal/... ./tests/integration/... ./tests/e2e/... 2>&1 | grep -v "redis: connection pool" | tee /tmp/test-all.txt || true
 	@echo "\n=== Test Summary ==="
 	$(call test_summary,/tmp/test-all.txt)
 	@$(MAKE) test-infra-down
 	@cd cli && zig build test

 test-unit:
-	@go test -v -short ./tests/unit/... 2>&1 | tee /tmp/test-unit.txt || true
+	@go test -v -short ./internal/... 2>&1 | tee /tmp/test-unit.txt || true
 	@echo "\n=== Unit Test Summary ==="
 	$(call test_summary,/tmp/test-unit.txt)
 	@cd cli && zig build test
@@ -205,9 +205,10 @@ test-e2e: test-infra-up

 test-coverage:
 	@mkdir -p coverage
-	go test -coverprofile=coverage/coverage.out ./...
+	go test -coverprofile=coverage/coverage.out -coverpkg=./internal/...,./cmd/... ./internal/... ./tests/integration/... ./tests/e2e/...
 	go tool cover -html=coverage/coverage.out -o coverage/coverage.html
 	@echo "$(OK) Coverage report: coverage/coverage.html"
+	@go tool cover -func=coverage/coverage.out | tail -1

 consistency-test: build native-build
 	@echo "Running cross-implementation consistency tests..."
@@ -353,7 +354,7 @@ lint-custom:
 	@echo "$(OK) Custom lint complete"

 verify-audit:
-	@go test ./tests/unit/audit/... -run TestChainVerifier -v
+	@go test ./internal/audit/... -run TestChainVerifier -v
 	@echo "$(OK) Audit chain verification passed"

 verify-audit-chain:
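For background on the verify-audit target: a hash-chained audit log is verified by recomputing each record's hash and checking that it links to its predecessor, which is the property a ChainVerifier-style test pins down. A self-contained sketch of that idea — the types and helpers here are hypothetical, not the internal/audit API:

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
)

// Record is a hypothetical hash-chained audit entry.
type Record struct {
	PrevHash []byte // hash of the previous record (nil for the first)
	Payload  []byte // audit event data
	Hash     []byte // sha256(PrevHash || Payload)
}

func recordHash(r Record) []byte {
	h := sha256.New()
	h.Write(r.PrevHash)
	h.Write(r.Payload)
	return h.Sum(nil)
}

// VerifyChain recomputes every record's hash and checks the linkage.
func VerifyChain(records []Record) error {
	var prev []byte
	for i, r := range records {
		if !bytes.Equal(r.PrevHash, prev) {
			return fmt.Errorf("record %d: broken link to previous record", i)
		}
		if want := recordHash(r); !bytes.Equal(r.Hash, want) {
			return fmt.Errorf("record %d: stored hash mismatch", i)
		}
		prev = r.Hash
	}
	return nil
}

func main() {
	r0 := Record{Payload: []byte("genesis")}
	r0.Hash = recordHash(r0)
	r1 := Record{PrevHash: r0.Hash, Payload: []byte("event-1")}
	r1.Hash = recordHash(r1)
	fmt.Println(VerifyChain([]Record{r0, r1})) // <nil>
}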
tests/unit/crypto/kms/cache_test.go (deleted)
@@ -1,255 +0,0 @@
package kms_test

import (
	"bytes"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/crypto/kms"
	kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
)

// TestDEKCache_PutAndGet tests basic cache put and get operations.
func TestDEKCache_PutAndGet(t *testing.T) {
	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	tenantID := "tenant-1"
	artifactID := "artifact-1"
	dek := []byte("test-dek-data-12345678901234567890123456789012")

	// Put DEK in cache
	if err := cache.Put(tenantID, artifactID, "kms-key-1", dek); err != nil {
		t.Fatalf("Put failed: %v", err)
	}

	// Get DEK from cache (KMS available)
	retrieved, ok := cache.Get(tenantID, artifactID, "kms-key-1", false)
	if !ok {
		t.Fatal("Get returned false, expected true")
	}

	if !bytes.Equal(retrieved, dek) {
		t.Error("Retrieved DEK doesn't match original")
	}
}

// TestDEKCache_GetNonexistent tests getting a non-existent entry.
func TestDEKCache_GetNonexistent(t *testing.T) {
	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	_, ok := cache.Get("nonexistent", "nonexistent", "kms-key-1", false)
	if ok {
		t.Error("Get for non-existent key should return false")
	}
}

// TestDEKCache_TTLExpiry tests that entries expire after TTL.
func TestDEKCache_TTLExpiry(t *testing.T) {
	// Use very short TTL for testing
	config := kmsconfig.CacheConfig{
		TTL:         50 * time.Millisecond,
		MaxEntries:  100,
		GraceWindow: 100 * time.Millisecond,
	}
	cache := kms.NewDEKCache(config)
	defer cache.Clear()

	tenantID := "tenant-1"
	artifactID := "artifact-1"
	dek := []byte("test-dek-data-12345678901234567890123456789012")

	// Put DEK
	cache.Put(tenantID, artifactID, "kms-key-1", dek)

	// Should be available immediately
	_, ok := cache.Get(tenantID, artifactID, "kms-key-1", false)
	if !ok {
		t.Error("DEK should be available immediately after put")
	}

	// Wait for TTL to expire
	time.Sleep(60 * time.Millisecond)

	// Should not be available after TTL (KMS available)
	_, ok = cache.Get(tenantID, artifactID, "kms-key-1", false)
	if ok {
		t.Error("DEK should not be available after TTL expires")
	}

	// Should be available in grace window (KMS unavailable)
	_, ok = cache.Get(tenantID, artifactID, "kms-key-1", true)
	if !ok {
		t.Error("DEK should be available in grace window when KMS is unavailable")
	}

	// Wait for grace window to expire
	time.Sleep(150 * time.Millisecond)

	// Should not be available after grace window
	_, ok = cache.Get(tenantID, artifactID, "kms-key-1", true)
	if ok {
		t.Error("DEK should not be available after grace window expires")
	}
}

// TestDEKCache_LRUeviction tests LRU eviction when cache is full.
func TestDEKCache_LRUeviction(t *testing.T) {
	config := kmsconfig.CacheConfig{
		TTL:         1 * time.Hour, // Long TTL so eviction is due to size
		MaxEntries:  3,
		GraceWindow: 1 * time.Hour,
	}
	cache := kms.NewDEKCache(config)
	defer cache.Clear()

	// Add 3 entries (at capacity)
	for i := 0; i < 3; i++ {
		dek := []byte("dek-data-12345678901234567890123456789012-" + string(rune('0'+i)))
		cache.Put("tenant-1", string(rune('a'+i)), "kms-key-1", dek)
	}

	// Access first entry to make it recently used
	cache.Get("tenant-1", "a", "kms-key-1", false)

	// Add 4th entry (should evict 'b' as it's the oldest unaccessed)
	dek4 := []byte("dek-data-12345678901234567890123456789012-4")
	cache.Put("tenant-1", "d", "kms-key-1", dek4)

	// 'a' should still exist (was accessed)
	_, ok := cache.Get("tenant-1", "a", "kms-key-1", false)
	if !ok {
		t.Error("Entry 'a' should still exist after eviction")
	}

	// 'b' should be evicted
	_, ok = cache.Get("tenant-1", "b", "kms-key-1", false)
	if ok {
		t.Error("Entry 'b' should have been evicted")
	}

	// 'c' and 'd' should exist
	_, ok = cache.Get("tenant-1", "c", "kms-key-1", false)
	if !ok {
		t.Error("Entry 'c' should still exist")
	}
	_, ok = cache.Get("tenant-1", "d", "kms-key-1", false)
	if !ok {
		t.Error("Entry 'd' should exist")
	}
}

// TestDEKCache_Flush tests flushing entries for a specific tenant.
func TestDEKCache_Flush(t *testing.T) {
	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	// Add entries for two tenants
	cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1"))
	cache.Put("tenant-1", "artifact-2", "kms-key-1", []byte("dek-2"))
	cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-3"))

	// Flush tenant-1
	cache.Flush("tenant-1")

	// tenant-1 entries should be gone
	_, ok := cache.Get("tenant-1", "artifact-1", "kms-key-1", false)
	if ok {
		t.Error("tenant-1 artifact-1 should be flushed")
	}
	_, ok = cache.Get("tenant-1", "artifact-2", "kms-key-1", false)
	if ok {
		t.Error("tenant-1 artifact-2 should be flushed")
	}

	// tenant-2 entry should still exist
	_, ok = cache.Get("tenant-2", "artifact-1", "kms-key-2", false)
	if !ok {
		t.Error("tenant-2 artifact-1 should still exist")
	}
}

// TestDEKCache_Clear tests clearing all entries.
func TestDEKCache_Clear(t *testing.T) {
	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())

	// Add entries
	cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1"))
	cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-2"))

	// Clear
	cache.Clear()

	// All entries should be gone
	_, ok := cache.Get("tenant-1", "artifact-1", "kms-key-1", false)
	if ok {
		t.Error("All entries should be cleared")
	}
	_, ok = cache.Get("tenant-2", "artifact-1", "kms-key-2", false)
	if ok {
		t.Error("All entries should be cleared")
	}
}

// TestDEKCache_Stats tests cache statistics.
func TestDEKCache_Stats(t *testing.T) {
	config := kmsconfig.DefaultCacheConfig()
	cache := kms.NewDEKCache(config)
	defer cache.Clear()

	stats := cache.Stats()

	if stats.Size != 0 {
		t.Errorf("Initial size should be 0, got %d", stats.Size)
	}
	if stats.MaxSize != config.MaxEntries {
		t.Errorf("MaxSize should be %d, got %d", config.MaxEntries, stats.MaxSize)
	}
	if stats.TTL != config.TTL {
		t.Errorf("TTL should be %v, got %v", config.TTL, stats.TTL)
	}
	if stats.GraceWindow != config.GraceWindow {
		t.Errorf("GraceWindow should be %v, got %v", config.GraceWindow, stats.GraceWindow)
	}

	// Add entry
	cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1"))

	stats = cache.Stats()
	if stats.Size != 1 {
		t.Errorf("Size should be 1 after put, got %d", stats.Size)
	}
}

// TestDEKCache_EmptyDEK tests that empty DEK is rejected.
func TestDEKCache_EmptyDEK(t *testing.T) {
	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	err := cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte{})
	if err == nil {
		t.Error("Should reject empty DEK")
	}
}

// TestDEKCache_Isolation tests that DEKs are isolated between tenants.
func TestDEKCache_Isolation(t *testing.T) {
	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	// Same artifact ID, different tenants
	cache.Put("tenant-1", "shared-artifact", "kms-key-1", []byte("dek-for-tenant-1"))
	cache.Put("tenant-2", "shared-artifact", "kms-key-2", []byte("dek-for-tenant-2"))

	// Each tenant should get their own DEK
	d1, ok := cache.Get("tenant-1", "shared-artifact", "kms-key-1", false)
	if !ok || string(d1) != "dek-for-tenant-1" {
		t.Error("tenant-1 should get their own DEK")
	}

	d2, ok := cache.Get("tenant-2", "shared-artifact", "kms-key-2", false)
	if !ok || string(d2) != "dek-for-tenant-2" {
		t.Error("tenant-2 should get their own DEK")
	}
}
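The TTL and grace-window behavior exercised by TestDEKCache_TTLExpiry boils down to a two-tier lookup: a fresh entry is always served, and a stale entry is served only while the KMS is unreachable and the grace window is still open. A minimal sketch of that policy, assuming grace is measured from TTL expiry — the entry fields and function signature here are hypothetical, not the actual cache implementation:

package main

import (
	"fmt"
	"time"
)

type entry struct {
	dek      []byte
	storedAt time.Time
}

// lookup mirrors the Get(..., kmsUnavailable) semantics in the tests above.
func lookup(e entry, ttl, grace time.Duration, kmsUnavailable bool, now time.Time) ([]byte, bool) {
	age := now.Sub(e.storedAt)
	switch {
	case age <= ttl:
		return e.dek, true // fresh: always serve
	case kmsUnavailable && age <= ttl+grace:
		return e.dek, true // stale, but KMS is down: serve from the grace window
	default:
		return nil, false // expired, or KMS is up and should be re-queried
	}
}

func main() {
	e := entry{dek: []byte("dek"), storedAt: time.Now().Add(-70 * time.Millisecond)}
	_, ok := lookup(e, 50*time.Millisecond, 100*time.Millisecond, false, time.Now())
	fmt.Println(ok) // false: past TTL while the KMS is reachable
	_, ok = lookup(e, 50*time.Millisecond, 100*time.Millisecond, true, time.Now())
	fmt.Println(ok) // true: within the grace window while the KMS is unavailable
}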
tests/unit/crypto/kms/protocol_test.go (deleted)
@@ -1,335 +0,0 @@
package kms_test

import (
	"bytes"
	"context"
	"encoding/binary"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/crypto"
	"github.com/jfraeys/fetch_ml/internal/crypto/kms"
	kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
)

func TestProtocolSerialization(t *testing.T) {
	// Test success packet
	successPacket := api.NewSuccessPacket("Operation completed successfully")
	data, err := successPacket.Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize success packet: %v", err)
	}

	// Verify packet type
	if len(data) < 1 || data[0] != api.PacketTypeSuccess {
		t.Errorf("Expected packet type %d, got %d", api.PacketTypeSuccess, data[0])
	}

	// Verify timestamp is present (9 bytes minimum: 1 type + 8 timestamp)
	if len(data) < 9 {
		t.Errorf("Expected at least 9 bytes, got %d", len(data))
	}

	// Test error packet - uses string error code from errors package
	errorPacket := api.NewErrorPacket("AUTHENTICATION_FAILED", "Auth failed", "Invalid API key")
	data, err = errorPacket.Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize error packet: %v", err)
	}

	if len(data) < 1 || data[0] != api.PacketTypeError {
		t.Errorf("Expected packet type %d, got %d", api.PacketTypeError, data[0])
	}

	// Test progress packet
	progressPacket := api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...")
	data, err = progressPacket.Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize progress packet: %v", err)
	}

	if len(data) < 1 || data[0] != api.PacketTypeProgress {
		t.Errorf("Expected packet type %d, got %d", api.PacketTypeProgress, data[0])
	}

	// Test status packet
	statusPacket := api.NewStatusPacket(`{"workers":1,"queued":0}`)
	data, err = statusPacket.Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize status packet: %v", err)
	}

	if len(data) < 1 || data[0] != api.PacketTypeStatus {
		t.Errorf("Expected packet type %d, got %d", api.PacketTypeStatus, data[0])
	}
}

func TestByteCodeFromErrorCode(t *testing.T) {
	tests := map[string]byte{
		"UNKNOWN_ERROR":         api.ErrorCodeUnknownError,
		"AUTHENTICATION_FAILED": api.ErrorCodeAuthenticationFailed,
		"JOB_NOT_FOUND":         api.ErrorCodeJobNotFound,
		"SERVER_OVERLOADED":     api.ErrorCodeServerOverloaded,
		"INVALID_REQUEST":       api.ErrorCodeInvalidRequest,
		"BAD_REQUEST":           api.ErrorCodeInvalidRequest,
		"PERMISSION_DENIED":     api.ErrorCodePermissionDenied,
		"FORBIDDEN":             api.ErrorCodePermissionDenied,
	}

	for code, expectedByte := range tests {
		actual := api.ByteCodeFromErrorCode(code)
		if actual != expectedByte {
			t.Errorf("Expected byte %d for code '%s', got %d", expectedByte, code, actual)
		}
	}
}

func TestLogLevelMapping(t *testing.T) {
	tests := map[byte]string{
		api.LogLevelDebug: "DEBUG",
		api.LogLevelInfo:  "INFO",
		api.LogLevelWarn:  "WARN",
		api.LogLevelError: "ERROR",
	}

	for level, expected := range tests {
		actual := api.GetLogLevelName(level)
		if actual != expected {
			t.Errorf("Expected log level '%s' for level %d, got '%s'", expected, level, actual)
		}
	}
}

func TestTimestampConsistency(t *testing.T) {
	before := time.Now().Unix()

	packet := api.NewSuccessPacket("Test message")
	data, err := packet.Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize: %v", err)
	}

	after := time.Now().Unix()

	// Extract timestamp (bytes 1-8, big-endian)
	if len(data) < 9 {
		t.Fatalf("Packet too short: %d bytes", len(data))
	}

	timestamp := binary.BigEndian.Uint64(data[1:9])

	if timestamp < uint64(before) || timestamp > uint64(after) {
		t.Errorf("Timestamp %d not in expected range [%d, %d]", timestamp, before, after)
	}
}

// TestKMSProtocol_EncryptDecrypt tests the full KMS encryption/decryption protocol.
func TestKMSProtocol_EncryptDecrypt(t *testing.T) {
	// Create memory provider for testing
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	config := kmsconfig.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kmsconfig.DefaultCacheConfig(),
	}

	tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)

	// Provision tenant
	hierarchy, err := tkm.ProvisionTenant("protocol-test-tenant")
	if err != nil {
		t.Fatalf("ProvisionTenant failed: %v", err)
	}

	// Test data - simulate artifact data
	plaintext := []byte("sensitive model weights and training data")

	// Encrypt
	encrypted, err := tkm.EncryptArtifact("protocol-test-tenant", "model-v1", hierarchy.KMSKeyID, plaintext)
	if err != nil {
		t.Fatalf("EncryptArtifact failed: %v", err)
	}

	// Verify encrypted structure
	if encrypted.Ciphertext == "" {
		t.Error("Ciphertext should not be empty")
	}
	if encrypted.DEK == nil {
		t.Error("DEK should not be nil")
	}
	if encrypted.KMSKeyID != hierarchy.KMSKeyID {
		t.Error("KMSKeyID should match")
	}
	if encrypted.Algorithm != "AES-256-GCM" {
		t.Errorf("Algorithm should be AES-256-GCM, got %s", encrypted.Algorithm)
	}

	// Decrypt
	decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
	if err != nil {
		t.Fatalf("DecryptArtifact failed: %v", err)
	}

	// Verify round-trip
	if !bytes.Equal(decrypted, plaintext) {
		t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext)
	}
}

// TestKMSProtocol_MultiTenantIsolation verifies tenants cannot decrypt each other's data.
func TestKMSProtocol_MultiTenantIsolation(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	config := kmsconfig.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kmsconfig.DefaultCacheConfig(),
	}

	tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)

	// Provision two tenants
	tenant1, err := tkm.ProvisionTenant("tenant-1")
	if err != nil {
		t.Fatalf("Failed to provision tenant-1: %v", err)
	}

	tenant2, err := tkm.ProvisionTenant("tenant-2")
	if err != nil {
		t.Fatalf("Failed to provision tenant-2: %v", err)
	}

	// Encrypt data for tenant-1
	plaintext := []byte("tenant-1 secret data")
	encrypted, err := tkm.EncryptArtifact("tenant-1", "artifact-1", tenant1.KMSKeyID, plaintext)
	if err != nil {
		t.Fatalf("Encrypt failed: %v", err)
	}

	// Attempt to decrypt with tenant-2's key - should fail
	_, err = tkm.DecryptArtifact(encrypted, tenant2.KMSKeyID)
	if err == nil {
		t.Error("Tenant-2 should not be able to decrypt tenant-1's data (expected error)")
	}

	// Tenant-1 should still be able to decrypt
	decrypted, err := tkm.DecryptArtifact(encrypted, tenant1.KMSKeyID)
	if err != nil {
		t.Fatalf("Tenant-1 decrypt failed: %v", err)
	}

	if !bytes.Equal(decrypted, plaintext) {
		t.Error("Tenant-1 should decrypt their own data correctly")
	}
}

// TestKMSProtocol_CacheHit verifies cached DEKs work correctly.
func TestKMSProtocol_CacheHit(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	config := kmsconfig.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kmsconfig.DefaultCacheConfig(),
	}

	tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)

	hierarchy, _ := tkm.ProvisionTenant("cache-test")

	plaintext := []byte("test data for caching")

	// First encrypt
	encrypted, _ := tkm.EncryptArtifact("cache-test", "cached-artifact", hierarchy.KMSKeyID, plaintext)

	// Decrypt multiple times - should hit cache
	for i := 0; i < 3; i++ {
		decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
		if err != nil {
			t.Fatalf("Decrypt %d failed: %v", i, err)
		}
		if !bytes.Equal(decrypted, plaintext) {
			t.Errorf("Decrypt %d: data mismatch", i)
		}
	}

	// Verify cache has entries
	stats := cache.Stats()
	if stats.Size == 0 {
		t.Error("Cache should have entries after operations")
	}
}

// TestKMSProtocol_KeyRotation tests key rotation protocol.
func TestKMSProtocol_KeyRotation(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
	defer cache.Clear()

	config := kmsconfig.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kmsconfig.DefaultCacheConfig(),
	}

	tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)

	// Provision tenant
	hierarchy, _ := tkm.ProvisionTenant("rotation-test")
	oldKeyID := hierarchy.KMSKeyID

	// Rotate key
	newHierarchy, err := tkm.RotateTenantKey("rotation-test", hierarchy)
	if err != nil {
		t.Fatalf("Key rotation failed: %v", err)
	}

	if newHierarchy.KMSKeyID == oldKeyID {
		t.Error("New key should have different ID after rotation")
	}

	// Cache should be flushed after rotation
	stats := cache.Stats()
	if stats.Size != 0 {
		t.Error("Cache should be flushed after key rotation")
	}

	// Encrypt with new key
	plaintext2 := []byte("data encrypted with new key")
	encrypted2, _ := tkm.EncryptArtifact("rotation-test", "post-rotation", newHierarchy.KMSKeyID, plaintext2)

	// Decrypt with new key
	decrypted2, err := tkm.DecryptArtifact(encrypted2, newHierarchy.KMSKeyID)
	if err != nil {
		t.Fatalf("Decrypt with new key failed: %v", err)
	}

	if !bytes.Equal(decrypted2, plaintext2) {
		t.Error("Data encrypted with new key should decrypt correctly")
	}
}

// TestKMSProvider_HealthCheck tests health check protocol.
func TestKMSProvider_HealthCheck(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	ctx := context.Background()

	// Memory provider should always be healthy
	if err := provider.HealthCheck(ctx); err != nil {
		t.Errorf("Memory provider health check failed: %v", err)
	}
}
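The round-trips in TestKMSProtocol_EncryptDecrypt follow the standard envelope-encryption pattern: the artifact is sealed with a per-artifact DEK, and the DEK is wrapped by the tenant's KMS key. A self-contained sketch of that pattern using only the standard library — the wrap step here stands in for a real KMS call, and the helper names are invented, not the repository's API:

package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"fmt"
)

// seal encrypts plaintext with key using AES-256-GCM, prepending the nonce.
func seal(key, plaintext []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, err
	}
	return gcm.Seal(nonce, nonce, plaintext, nil), nil
}

// open reverses seal: split off the nonce, then decrypt and authenticate.
func open(key, sealed []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	nonce, ct := sealed[:gcm.NonceSize()], sealed[gcm.NonceSize():]
	return gcm.Open(nil, nonce, ct, nil)
}

func main() {
	// Tenant KEK: in the real flow this key never leaves the KMS provider.
	kek := make([]byte, 32)
	rand.Read(kek)

	// Per-artifact DEK.
	dek := make([]byte, 32)
	rand.Read(dek)

	// Encrypt the artifact with the DEK, then wrap the DEK with the KEK.
	ciphertext, _ := seal(dek, []byte("sensitive model weights"))
	wrappedDEK, _ := seal(kek, dek) // stand-in for a KMS Encrypt call

	// Decrypt: unwrap the DEK first, then open the artifact.
	dek2, _ := open(kek, wrappedDEK)
	plaintext, _ := open(dek2, ciphertext)
	fmt.Printf("%s\n", plaintext)
}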
tests/unit/resources/manager_test.go (deleted)
@@ -1,166 +0,0 @@
package resources_test

import (
	"context"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
	"github.com/stretchr/testify/require"
)

func TestManager_CPUAcquireBlocksUntilRelease(t *testing.T) {
	m, err := resources.NewManager(resources.Options{TotalCPU: 4, GPUCount: 0, SlotsPerGPU: 1})
	require.NoError(t, err)

	task1 := &queue.Task{CPU: 3}
	lease1, err := m.Acquire(context.Background(), task1)
	require.NoError(t, err)
	require.NotNil(t, lease1)

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err = m.Acquire(ctx, &queue.Task{CPU: 2})
	require.Error(t, err)

	lease1.Release()

	ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
	defer cancel2()
	lease2, err := m.Acquire(ctx2, &queue.Task{CPU: 2})
	require.NoError(t, err)
	require.NotNil(t, lease2)
	lease2.Release()
}

func TestManager_GPUSlotsAllowSharing(t *testing.T) {
	m, err := resources.NewManager(resources.Options{TotalCPU: 0, GPUCount: 1, SlotsPerGPU: 4})
	require.NoError(t, err)

	leases := make([]*resources.Lease, 0, 4)
	for i := 0; i < 4; i++ {
		l, err := m.Acquire(context.Background(), &queue.Task{GPU: 1, GPUMemory: "0.25"})
		require.NoError(t, err)
		leases = append(leases, l)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err = m.Acquire(ctx, &queue.Task{GPU: 1, GPUMemory: "0.25"})
	require.Error(t, err)

	for _, l := range leases {
		l.Release()
	}
}

func TestManager_MultiGPUExclusiveAllocation(t *testing.T) {
	m, err := resources.NewManager(resources.Options{TotalCPU: 0, GPUCount: 2, SlotsPerGPU: 1})
	require.NoError(t, err)

	lease, err := m.Acquire(context.Background(), &queue.Task{GPU: 2})
	require.NoError(t, err)
	require.Len(t, lease.GPUs(), 2)

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err = m.Acquire(ctx, &queue.Task{GPU: 1})
	require.Error(t, err)

	lease.Release()
}

func TestFormatCUDAVisibleDevices_NoLeaseDisablesGPU(t *testing.T) {
	require.Equal(t, "-1", resources.FormatCUDAVisibleDevices(nil))
}

func TestManager_GPUSlotsAllowSharing_Concurrent(t *testing.T) {
	m, err := resources.NewManager(resources.Options{TotalCPU: 0, GPUCount: 1, SlotsPerGPU: 4})
	require.NoError(t, err)

	started := make(chan struct{})
	release := make(chan struct{})

	errCh := make(chan error, 4)
	leases := make(chan *resources.Lease, 4)
	for i := 0; i < 4; i++ {
		go func() {
			<-started
			l, err := m.Acquire(context.Background(), &queue.Task{GPU: 1, GPUMemory: "0.25"})
			if err != nil {
				errCh <- err
				return
			}
			leases <- l
			<-release
			l.Release()
			errCh <- nil
		}()
	}
	close(started)

	deadline := time.After(500 * time.Millisecond)
	acquired := make([]*resources.Lease, 0, 4)
	for len(acquired) < 4 {
		select {
		case l := <-leases:
			acquired = append(acquired, l)
		case <-deadline:
			t.Fatalf("timed out waiting for leases; got %d", len(acquired))
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err = m.Acquire(ctx, &queue.Task{GPU: 1, GPUMemory: "0.25"})
	require.Error(t, err)

	close(release)
	for i := 0; i < 4; i++ {
		require.NoError(t, <-errCh)
	}
}

func TestManager_CPUOnlyNotBlockedWhenGPUSaturated(t *testing.T) {
	m, err := resources.NewManager(resources.Options{TotalCPU: 4, GPUCount: 1, SlotsPerGPU: 1})
	require.NoError(t, err)

	gpuLease, err := m.Acquire(context.Background(), &queue.Task{GPU: 1})
	require.NoError(t, err)
	defer gpuLease.Release()

	done := make(chan error, 1)
	go func() {
		lease, err := m.Acquire(context.Background(), &queue.Task{CPU: 1})
		if err == nil {
			lease.Release()
		}
		done <- err
	}()

	select {
	case err := <-done:
		require.NoError(t, err)
	case <-time.After(200 * time.Millisecond):
		t.Fatal("cpu-only acquire unexpectedly blocked by gpu saturation")
	}
}

func TestManager_AcquireMetrics_RecordWaitAndTimeout(t *testing.T) {
	m, err := resources.NewManager(resources.Options{TotalCPU: 1, GPUCount: 0, SlotsPerGPU: 1})
	require.NoError(t, err)

	lease, err := m.Acquire(context.Background(), &queue.Task{CPU: 1})
	require.NoError(t, err)
	defer lease.Release()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err = m.Acquire(ctx, &queue.Task{CPU: 1})
	require.Error(t, err)

	s := m.Snapshot()
	require.GreaterOrEqual(t, s.AcquireTotal, int64(2))
	require.GreaterOrEqual(t, s.AcquireTimeoutTotal, int64(1))
}
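The Acquire/Release semantics these tests pin down — blocking until capacity frees up, and honoring context timeouts — can be modeled with a buffered channel acting as a slot pool. A minimal sketch with invented names, not the internal/resources implementation:

package main

import (
	"context"
	"fmt"
	"time"
)

// slotPool hands out up to cap(slots) concurrent leases; Acquire blocks
// until a slot frees up or the context is done.
type slotPool struct{ slots chan struct{} }

func newSlotPool(n int) *slotPool {
	return &slotPool{slots: make(chan struct{}, n)}
}

func (p *slotPool) Acquire(ctx context.Context) (release func(), err error) {
	select {
	case p.slots <- struct{}{}:
		return func() { <-p.slots }, nil
	case <-ctx.Done():
		return nil, ctx.Err() // mirrors the timeout cases in the tests above
	}
}

func main() {
	pool := newSlotPool(1)
	release, _ := pool.Acquire(context.Background())

	// A second acquire times out while the only slot is held.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	if _, err := pool.Acquire(ctx); err != nil {
		fmt.Println("blocked as expected:", err)
	}

	release()
	if r, err := pool.Acquire(context.Background()); err == nil {
		fmt.Println("acquired after release")
		r()
	}
}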