Add tests for GetPublicKey (returns the correct public key), GetKeyID (returns the correct key ID), and a property-based Sign -> Verify round trip over various message types (empty, single char, Unicode, large messages). Coverage: GetPublicKey 100%, GetKeyID 100%.
340 lines
9.8 KiB
Go
340 lines
9.8 KiB
Go
package crypto_test
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/crypto"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestGenerateSigningKeys tests Ed25519 key generation
|
|
func TestGenerateSigningKeys(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pub, priv, err := crypto.GenerateSigningKeys()
|
|
require.NoError(t, err, "Key generation should succeed")
|
|
require.NotNil(t, pub, "Public key should not be nil")
|
|
require.NotNil(t, priv, "Private key should not be nil")
|
|
assert.Len(t, pub, 32, "Ed25519 public key should be 32 bytes")
|
|
assert.Len(t, priv, 64, "Ed25519 private key should be 64 bytes")
|
|
|
|
// Keys should be different each time
|
|
pub2, priv2, _ := crypto.GenerateSigningKeys()
|
|
assert.NotEqual(t, pub, pub2, "Keys should be randomly generated")
|
|
assert.NotEqual(t, priv, priv2, "Keys should be randomly generated")
|
|
}
|
|
|
|
// TestNewManifestSigner tests signer creation
|
|
func TestNewManifestSigner(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
keySize int
|
|
keyID string
|
|
wantErr bool
|
|
}{
|
|
{"valid key", 64, "test-key", false},
|
|
{"wrong size - too short", 32, "short-key", true},
|
|
{"wrong size - too long", 96, "long-key", true},
|
|
{"empty key", 0, "empty-key", true},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
var privKey []byte
|
|
if tc.keySize > 0 {
|
|
privKey = make([]byte, tc.keySize)
|
|
// Fill with non-zero data for realistic test
|
|
for i := range privKey {
|
|
privKey[i] = byte(i % 256)
|
|
}
|
|
}
|
|
|
|
signer, err := crypto.NewManifestSigner(privKey, tc.keyID)
|
|
if tc.wantErr {
|
|
require.Error(t, err)
|
|
assert.Nil(t, signer)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, signer)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSignVerifyRoundTrip tests sign → verify cycle
|
|
func TestSignVerifyRoundTrip(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Generate keys
|
|
pub, priv, err := crypto.GenerateSigningKeys()
|
|
require.NoError(t, err)
|
|
|
|
// Create signer
|
|
signer, err := crypto.NewManifestSigner(priv, "test-key")
|
|
require.NoError(t, err)
|
|
|
|
// Sign manifest
|
|
manifest := map[string]interface{}{
|
|
"job_id": "test-123",
|
|
"command": "python train.py",
|
|
"config": map[string]interface{}{
|
|
"epochs": 10,
|
|
"batch": 32,
|
|
},
|
|
}
|
|
|
|
result, err := signer.SignManifest(manifest)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.NotEmpty(t, result.Signature, "Signature should not be empty")
|
|
assert.Equal(t, "test-key", result.KeyID)
|
|
assert.Equal(t, "Ed25519", result.Algorithm)
|
|
|
|
// Verify signature
|
|
valid, err := crypto.VerifyManifest(manifest, result, pub)
|
|
require.NoError(t, err)
|
|
assert.True(t, valid, "Signature should be valid")
|
|
}
|
|
|
|
// TestVerifyTamperedManifestFails tests that modified data fails verification
|
|
func TestVerifyTamperedManifestFails(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Generate keys
|
|
pub, priv, _ := crypto.GenerateSigningKeys()
|
|
|
|
// Create signer and sign original
|
|
signer, _ := crypto.NewManifestSigner(priv, "test")
|
|
manifest := map[string]string{"job_id": "test-123"}
|
|
result, _ := signer.SignManifest(manifest)
|
|
|
|
// Try to verify with modified data
|
|
tampered := map[string]string{"job_id": "tampered-999"}
|
|
valid, err := crypto.VerifyManifest(tampered, result, pub)
|
|
require.NoError(t, err)
|
|
assert.False(t, valid, "Tampered data should not verify")
|
|
}
|
|
|
|
// TestSignManifestBytesAndVerify tests raw byte signing and verification
|
|
func TestSignManifestBytesAndVerify(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pub, priv, _ := crypto.GenerateSigningKeys()
|
|
signer, _ := crypto.NewManifestSigner(priv, "test")
|
|
|
|
data := []byte(`{"raw": "json bytes"}`)
|
|
result, err := signer.SignManifestBytes(data)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.NotEmpty(t, result.Signature)
|
|
|
|
// Verify raw bytes
|
|
valid, err := crypto.VerifyManifestBytes(data, result, pub)
|
|
require.NoError(t, err)
|
|
assert.True(t, valid, "Raw bytes signature should verify")
|
|
}
|
|
|
|
// TestSavePrivateKeyToFilePermissions tests key file is created with 0600 permissions
|
|
func TestSavePrivateKeyToFilePermissions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
keyPath := filepath.Join(tmpDir, "test.key")
|
|
|
|
// Generate keys
|
|
_, priv, _ := crypto.GenerateSigningKeys()
|
|
|
|
// Save private key
|
|
err := crypto.SavePrivateKeyToFile(priv, keyPath)
|
|
require.NoError(t, err, "SavePrivateKeyToFile should succeed")
|
|
|
|
// Verify file exists
|
|
_, err = os.Stat(keyPath)
|
|
require.NoError(t, err, "Key file should exist")
|
|
|
|
// Check permissions
|
|
info, err := os.Stat(keyPath)
|
|
require.NoError(t, err)
|
|
|
|
mode := info.Mode().Perm()
|
|
assert.Equal(t, os.FileMode(0600), mode,
|
|
"Private key file must have 0600 permissions, got %o", mode)
|
|
}
|
|
|
|
// TestSavePublicKeyToFile tests public key save (0644 permissions)
|
|
func TestSavePublicKeyToFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
keyPath := filepath.Join(tmpDir, "test.pub")
|
|
|
|
// Generate keys
|
|
pub, _, _ := crypto.GenerateSigningKeys()
|
|
|
|
// Save public key
|
|
err := crypto.SavePublicKeyToFile(pub, keyPath)
|
|
require.NoError(t, err, "SavePublicKeyToFile should succeed")
|
|
|
|
// Check permissions - public keys are 0644
|
|
info, err := os.Stat(keyPath)
|
|
require.NoError(t, err)
|
|
mode := info.Mode().Perm()
|
|
assert.Equal(t, os.FileMode(0644), mode,
|
|
"Public key file should have 0644 permissions, got %o", mode)
|
|
}
|
|
|
|
// TestLoadPrivateKeyFromFile tests private key loading
|
|
func TestLoadPrivateKeyFromFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
keyPath := filepath.Join(tmpDir, "test.key")
|
|
|
|
// Generate and save
|
|
_, priv, _ := crypto.GenerateSigningKeys()
|
|
err := crypto.SavePrivateKeyToFile(priv, keyPath)
|
|
require.NoError(t, err)
|
|
|
|
// Load
|
|
loadedPriv, err := crypto.LoadPrivateKeyFromFile(keyPath)
|
|
require.NoError(t, err, "LoadPrivateKeyFromFile should succeed")
|
|
assert.Equal(t, priv, loadedPriv, "Private keys should match")
|
|
}
|
|
|
|
// TestLoadPublicKeyFromFile tests public key loading
|
|
func TestLoadPublicKeyFromFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
keyPath := filepath.Join(tmpDir, "test.pub")
|
|
|
|
// Generate and save
|
|
pub, _, _ := crypto.GenerateSigningKeys()
|
|
err := crypto.SavePublicKeyToFile(pub, keyPath)
|
|
require.NoError(t, err)
|
|
|
|
// Load
|
|
loadedPub, err := crypto.LoadPublicKeyFromFile(keyPath)
|
|
require.NoError(t, err, "LoadPublicKeyFromFile should succeed")
|
|
assert.Equal(t, pub, loadedPub, "Public keys should match")
|
|
}
|
|
|
|
// TestLoadPrivateKeyFromFileNotFound tests error for missing private key
|
|
func TestLoadPrivateKeyFromFileNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := crypto.LoadPrivateKeyFromFile("/nonexistent/path/to/key")
|
|
require.Error(t, err, "Loading nonexistent key should error")
|
|
}
|
|
|
|
// TestLoadPublicKeyFromFileNotFound tests error for missing public key
|
|
func TestLoadPublicKeyFromFileNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := crypto.LoadPublicKeyFromFile("/nonexistent/path/to/key.pub")
|
|
require.Error(t, err, "Loading nonexistent public key should error")
|
|
}
|
|
|
|
// TestEncodeDecodeKeyToBase64 tests base64 encoding/decoding
|
|
func TestEncodeDecodeKeyToBase64(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Generate key
|
|
pub, _, _ := crypto.GenerateSigningKeys()
|
|
|
|
// Encode
|
|
encoded := crypto.EncodeKeyToBase64(pub)
|
|
assert.NotEmpty(t, encoded)
|
|
|
|
// Decode
|
|
decoded, err := crypto.DecodeKeyFromBase64(encoded)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, pub, decoded, "Decoded key should match original")
|
|
|
|
// Invalid base64 should error
|
|
_, err = crypto.DecodeKeyFromBase64("!!!invalid!!!")
|
|
require.Error(t, err, "Invalid base64 should error")
|
|
}
|
|
|
|
// TestManifestSignerGetters tests GetPublicKey and GetKeyID
|
|
func TestManifestSignerGetters(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, priv, err := crypto.GenerateSigningKeys()
|
|
require.NoError(t, err)
|
|
signer, err := crypto.NewManifestSigner(priv, "test-key-id")
|
|
require.NoError(t, err)
|
|
|
|
// Test GetPublicKey
|
|
gotPub := signer.GetPublicKey()
|
|
assert.NotNil(t, gotPub, "GetPublicKey should return public key")
|
|
|
|
// Test GetKeyID
|
|
gotKeyID := signer.GetKeyID()
|
|
assert.Equal(t, "test-key-id", gotKeyID, "GetKeyID should return correct key ID")
|
|
}
|
|
|
|
// TestSignVerifyRoundTripProperty is a property-based test
|
|
// Verifies that for any message, Sign -> Verify round-trip works
|
|
func TestSignVerifyRoundTripProperty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pub, priv, err := crypto.GenerateSigningKeys()
|
|
require.NoError(t, err)
|
|
signer, err := crypto.NewManifestSigner(priv, "property-test-key")
|
|
require.NoError(t, err)
|
|
|
|
// Property: Sign -> Verify should return original message for various inputs
|
|
testMessages := []string{
|
|
"", // Empty message
|
|
"a", // Single char
|
|
"hello world", // Simple message
|
|
"Message with special chars: !@#$%", // Special characters
|
|
"Unicode: ", // Unicode
|
|
strings.Repeat("x", 10000), // Large message
|
|
}
|
|
|
|
for _, msg := range testMessages {
|
|
sig, err := signer.SignManifest(msg)
|
|
require.NoError(t, err, "Should sign message: %q", msg)
|
|
|
|
verified, err := crypto.VerifyManifest(msg, sig, pub)
|
|
require.NoError(t, err, "Should verify message: %q", msg)
|
|
assert.True(t, verified, "Verified should be true for: %q", msg)
|
|
}
|
|
}
|
|
|
|
// TestKeyFilePermissions verifies saved keys have correct permissions
|
|
func TestKeyFilePermissions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
privPath := filepath.Join(tmpDir, "test.key")
|
|
pubPath := filepath.Join(tmpDir, "test.pub")
|
|
|
|
pub, priv, _ := crypto.GenerateSigningKeys()
|
|
|
|
// Save keys
|
|
err := crypto.SavePrivateKeyToFile(priv, privPath)
|
|
require.NoError(t, err)
|
|
err = crypto.SavePublicKeyToFile(pub, pubPath)
|
|
require.NoError(t, err)
|
|
|
|
// Check private key permissions (should be restrictive)
|
|
info, err := os.Stat(privPath)
|
|
require.NoError(t, err)
|
|
mode := info.Mode().Perm()
|
|
assert.Equal(t, os.FileMode(0o600), mode, "Private key should have 0600 permissions")
|
|
|
|
// Check public key permissions (can be more permissive)
|
|
info, err = os.Stat(pubPath)
|
|
require.NoError(t, err)
|
|
mode = info.Mode().Perm()
|
|
assert.Equal(t, os.FileMode(0o644), mode, "Public key should have 0644 permissions")
|
|
}
|