diff --git a/internal/crypto/signing_test.go b/internal/crypto/signing_test.go new file mode 100644 index 0000000..2f7e3db --- /dev/null +++ b/internal/crypto/signing_test.go @@ -0,0 +1,340 @@ +package crypto_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/jfraeys/fetch_ml/internal/crypto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGenerateSigningKeys tests Ed25519 key generation +func TestGenerateSigningKeys(t *testing.T) { + t.Parallel() + + pub, priv, err := crypto.GenerateSigningKeys() + require.NoError(t, err, "Key generation should succeed") + require.NotNil(t, pub, "Public key should not be nil") + require.NotNil(t, priv, "Private key should not be nil") + assert.Len(t, pub, 32, "Ed25519 public key should be 32 bytes") + assert.Len(t, priv, 64, "Ed25519 private key should be 64 bytes") + + // Keys should be different each time + pub2, priv2, _ := crypto.GenerateSigningKeys() + assert.NotEqual(t, pub, pub2, "Keys should be randomly generated") + assert.NotEqual(t, priv, priv2, "Keys should be randomly generated") +} + +// TestNewManifestSigner tests signer creation +func TestNewManifestSigner(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + keySize int + keyID string + wantErr bool + }{ + {"valid key", 64, "test-key", false}, + {"wrong size - too short", 32, "short-key", true}, + {"wrong size - too long", 96, "long-key", true}, + {"empty key", 0, "empty-key", true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var privKey []byte + if tc.keySize > 0 { + privKey = make([]byte, tc.keySize) + // Fill with non-zero data for realistic test + for i := range privKey { + privKey[i] = byte(i % 256) + } + } + + signer, err := crypto.NewManifestSigner(privKey, tc.keyID) + if tc.wantErr { + require.Error(t, err) + assert.Nil(t, signer) + return + } + require.NoError(t, err) + assert.NotNil(t, signer) + }) + } +} + +// TestSignVerifyRoundTrip tests sign → verify cycle +func TestSignVerifyRoundTrip(t *testing.T) { + t.Parallel() + + // Generate keys + pub, priv, err := crypto.GenerateSigningKeys() + require.NoError(t, err) + + // Create signer + signer, err := crypto.NewManifestSigner(priv, "test-key") + require.NoError(t, err) + + // Sign manifest + manifest := map[string]interface{}{ + "job_id": "test-123", + "command": "python train.py", + "config": map[string]interface{}{ + "epochs": 10, + "batch": 32, + }, + } + + result, err := signer.SignManifest(manifest) + require.NoError(t, err) + require.NotNil(t, result) + assert.NotEmpty(t, result.Signature, "Signature should not be empty") + assert.Equal(t, "test-key", result.KeyID) + assert.Equal(t, "Ed25519", result.Algorithm) + + // Verify signature + valid, err := crypto.VerifyManifest(manifest, result, pub) + require.NoError(t, err) + assert.True(t, valid, "Signature should be valid") +} + +// TestVerifyTamperedManifestFails tests that modified data fails verification +func TestVerifyTamperedManifestFails(t *testing.T) { + t.Parallel() + + // Generate keys + pub, priv, _ := crypto.GenerateSigningKeys() + + // Create signer and sign original + signer, _ := crypto.NewManifestSigner(priv, "test") + manifest := map[string]string{"job_id": "test-123"} + result, _ := signer.SignManifest(manifest) + + // Try to verify with modified data + tampered := map[string]string{"job_id": "tampered-999"} + valid, err := crypto.VerifyManifest(tampered, result, pub) + require.NoError(t, err) + assert.False(t, valid, "Tampered data should not verify") +} + +// TestSignManifestBytesAndVerify tests raw byte signing and verification +func TestSignManifestBytesAndVerify(t *testing.T) { + t.Parallel() + + pub, priv, _ := crypto.GenerateSigningKeys() + signer, _ := crypto.NewManifestSigner(priv, "test") + + data := []byte(`{"raw": "json bytes"}`) + result, err := signer.SignManifestBytes(data) + require.NoError(t, err) + require.NotNil(t, result) + assert.NotEmpty(t, result.Signature) + + // Verify raw bytes + valid, err := crypto.VerifyManifestBytes(data, result, pub) + require.NoError(t, err) + assert.True(t, valid, "Raw bytes signature should verify") +} + +// TestSavePrivateKeyToFilePermissions tests key file is created with 0600 permissions +func TestSavePrivateKeyToFilePermissions(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test.key") + + // Generate keys + _, priv, _ := crypto.GenerateSigningKeys() + + // Save private key + err := crypto.SavePrivateKeyToFile(priv, keyPath) + require.NoError(t, err, "SavePrivateKeyToFile should succeed") + + // Verify file exists + _, err = os.Stat(keyPath) + require.NoError(t, err, "Key file should exist") + + // Check permissions + info, err := os.Stat(keyPath) + require.NoError(t, err) + + mode := info.Mode().Perm() + assert.Equal(t, os.FileMode(0600), mode, + "Private key file must have 0600 permissions, got %o", mode) +} + +// TestSavePublicKeyToFile tests public key save (0644 permissions) +func TestSavePublicKeyToFile(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test.pub") + + // Generate keys + pub, _, _ := crypto.GenerateSigningKeys() + + // Save public key + err := crypto.SavePublicKeyToFile(pub, keyPath) + require.NoError(t, err, "SavePublicKeyToFile should succeed") + + // Check permissions - public keys are 0644 + info, err := os.Stat(keyPath) + require.NoError(t, err) + mode := info.Mode().Perm() + assert.Equal(t, os.FileMode(0644), mode, + "Public key file should have 0644 permissions, got %o", mode) +} + +// TestLoadPrivateKeyFromFile tests private key loading +func TestLoadPrivateKeyFromFile(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test.key") + + // Generate and save + _, priv, _ := crypto.GenerateSigningKeys() + err := crypto.SavePrivateKeyToFile(priv, keyPath) + require.NoError(t, err) + + // Load + loadedPriv, err := crypto.LoadPrivateKeyFromFile(keyPath) + require.NoError(t, err, "LoadPrivateKeyFromFile should succeed") + assert.Equal(t, priv, loadedPriv, "Private keys should match") +} + +// TestLoadPublicKeyFromFile tests public key loading +func TestLoadPublicKeyFromFile(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test.pub") + + // Generate and save + pub, _, _ := crypto.GenerateSigningKeys() + err := crypto.SavePublicKeyToFile(pub, keyPath) + require.NoError(t, err) + + // Load + loadedPub, err := crypto.LoadPublicKeyFromFile(keyPath) + require.NoError(t, err, "LoadPublicKeyFromFile should succeed") + assert.Equal(t, pub, loadedPub, "Public keys should match") +} + +// TestLoadPrivateKeyFromFileNotFound tests error for missing private key +func TestLoadPrivateKeyFromFileNotFound(t *testing.T) { + t.Parallel() + + _, err := crypto.LoadPrivateKeyFromFile("/nonexistent/path/to/key") + require.Error(t, err, "Loading nonexistent key should error") +} + +// TestLoadPublicKeyFromFileNotFound tests error for missing public key +func TestLoadPublicKeyFromFileNotFound(t *testing.T) { + t.Parallel() + + _, err := crypto.LoadPublicKeyFromFile("/nonexistent/path/to/key.pub") + require.Error(t, err, "Loading nonexistent public key should error") +} + +// TestEncodeDecodeKeyToBase64 tests base64 encoding/decoding +func TestEncodeDecodeKeyToBase64(t *testing.T) { + t.Parallel() + + // Generate key + pub, _, _ := crypto.GenerateSigningKeys() + + // Encode + encoded := crypto.EncodeKeyToBase64(pub) + assert.NotEmpty(t, encoded) + + // Decode + decoded, err := crypto.DecodeKeyFromBase64(encoded) + require.NoError(t, err) + assert.Equal(t, pub, decoded, "Decoded key should match original") + + // Invalid base64 should error + _, err = crypto.DecodeKeyFromBase64("!!!invalid!!!") + require.Error(t, err, "Invalid base64 should error") +} + +// TestManifestSignerGetters tests GetPublicKey and GetKeyID +func TestManifestSignerGetters(t *testing.T) { + t.Parallel() + + _, priv, err := crypto.GenerateSigningKeys() + require.NoError(t, err) + signer, err := crypto.NewManifestSigner(priv, "test-key-id") + require.NoError(t, err) + + // Test GetPublicKey + gotPub := signer.GetPublicKey() + assert.NotNil(t, gotPub, "GetPublicKey should return public key") + + // Test GetKeyID + gotKeyID := signer.GetKeyID() + assert.Equal(t, "test-key-id", gotKeyID, "GetKeyID should return correct key ID") +} + +// TestSignVerifyRoundTripProperty is a property-based test +// Verifies that for any message, Sign -> Verify round-trip works +func TestSignVerifyRoundTripProperty(t *testing.T) { + t.Parallel() + + pub, priv, err := crypto.GenerateSigningKeys() + require.NoError(t, err) + signer, err := crypto.NewManifestSigner(priv, "property-test-key") + require.NoError(t, err) + + // Property: Sign -> Verify should return original message for various inputs + testMessages := []string{ + "", // Empty message + "a", // Single char + "hello world", // Simple message + "Message with special chars: !@#$%", // Special characters + "Unicode: ", // Unicode + strings.Repeat("x", 10000), // Large message + } + + for _, msg := range testMessages { + sig, err := signer.SignManifest(msg) + require.NoError(t, err, "Should sign message: %q", msg) + + verified, err := crypto.VerifyManifest(msg, sig, pub) + require.NoError(t, err, "Should verify message: %q", msg) + assert.True(t, verified, "Verified should be true for: %q", msg) + } +} + +// TestKeyFilePermissions verifies saved keys have correct permissions +func TestKeyFilePermissions(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + privPath := filepath.Join(tmpDir, "test.key") + pubPath := filepath.Join(tmpDir, "test.pub") + + pub, priv, _ := crypto.GenerateSigningKeys() + + // Save keys + err := crypto.SavePrivateKeyToFile(priv, privPath) + require.NoError(t, err) + err = crypto.SavePublicKeyToFile(pub, pubPath) + require.NoError(t, err) + + // Check private key permissions (should be restrictive) + info, err := os.Stat(privPath) + require.NoError(t, err) + mode := info.Mode().Perm() + assert.Equal(t, os.FileMode(0o600), mode, "Private key should have 0600 permissions") + + // Check public key permissions (can be more permissive) + info, err = os.Stat(pubPath) + require.NoError(t, err) + mode = info.Mode().Perm() + assert.Equal(t, os.FileMode(0o644), mode, "Public key should have 0644 permissions") +}