// test_sha256_arm_kat.cpp - Known-answer test for ARMv8 SHA256
// Verifies ARMv8 output matches generic implementation for NIST test vectors

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Only compile on ARM64
#if defined(__aarch64__) || defined(_M_ARM64)

// SHA-256 known-answer vectors.
// "abc", the 448-bit message and the one-million-'a' message are the SHA-256
// examples from FIPS 180-2 Appendix B (republished by NIST as "Examples with
// Intermediate Values"). The empty-message digest is the well-known SHA-256
// of zero bytes. Together they cover: no input, a short single block, the
// padding spill into a second block, and many full blocks.

// Input: "abc" (3 bytes) -- single block, short-message padding.
// Expected: ba7816bf 8f01cfea 414140de 5dae2223 b00361a3 96177a9c b410ff61 f20015ad
static const uint8_t test_input_abc[] = {'a', 'b', 'c'};
static const uint8_t expected_abc[32] = {
    0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea,
    0x41, 0x41, 0x40, 0xde, 0x5d, 0xae, 0x22, 0x23,
    0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c,
    0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad
};

// Input: empty string (0 bytes)
// Expected: e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855
static const uint8_t expected_empty[32] = {
    0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14,
    0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
    0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c,
    0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55
};

// Input: 448-bit (56-byte) message. After the 0x80 terminator there is no
// room left for the 8-byte length field, so padding spills into a second
// block -- the `len > 56` branch of sha256_full().
// Expected: 248d6a61 d2064338 e5c02693 0c3e6039 a33ce459 64ff2167 f6ecedd4 19db06c1
static const char test_input_448[] =
    "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq";
static_assert(sizeof(test_input_448) - 1 == 56, "448-bit vector must be 56 bytes");
static const uint8_t expected_448[32] = {
    0x24, 0x8d, 0x6a, 0x61, 0xd2, 0x06, 0x43, 0x38,
    0xe5, 0xc0, 0x26, 0x93, 0x0c, 0x3e, 0x60, 0x39,
    0xa3, 0x3c, 0xe4, 0x59, 0x64, 0xff, 0x21, 0x67,
    0xf6, 0xec, 0xed, 0xd4, 0x19, 0xdb, 0x06, 0xc1
};

// Input: 1,000,000 repetitions of 'a'. That is exactly 15625 full blocks
// through the main loop, followed by a padding-only final block.
// Expected: cdc76e5c 9914fb92 81a1c7e2 84d73e67 f1809a48 a497200e 046d39cc c7112cd0
static uint8_t million_a[1000000];  // filled with 'a' at the start of main()
static const uint8_t expected_million_a[32] = {
    0xcd, 0xc7, 0x6e, 0x5c, 0x99, 0x14, 0xfb, 0x92,
    0x81, 0xa1, 0xc7, 0xe2, 0x84, 0xd7, 0x3e, 0x67,
    0xf1, 0x80, 0x9a, 0x48, 0xa4, 0x97, 0x20, 0x0e,
    0x04, 0x6d, 0x39, 0xcc, 0xc7, 0x11, 0x2c, 0xd0
};

// Include the implementations.
// NOTE(review): this header is expected to provide H0 (initial hash state)
// and transform_generic(); confirm against sha256_base.h.
#include "../dataset_hash/crypto/sha256_base.h"

// Forward declaration - uses C++ linkage
using TransformFunc = void (*)(uint32_t* state, const uint8_t* data);
TransformFunc detect_armv8_transform(void);

// Wrapper for generic transform to match signature
static void generic_wrapper(uint32_t* state, const uint8_t* data)
{
    transform_generic(state, data);
}

// Hashes `len` bytes at `in` into the 32-byte digest `out`. Every 64-byte
// compression step goes through `transform`, so one padding routine feeds
// both implementations under test. `in` may be null only when len == 0.
static void sha256_full(uint8_t* out, const uint8_t* in, size_t len, TransformFunc transform)
{
    uint32_t state[8];
    uint8_t buffer[64];
    // Message length in bits, captured before `len` is consumed below.
    const uint64_t bitlen = static_cast<uint64_t>(len) * 8;

    // Init
    memcpy(state, H0, sizeof state);

    // Process full blocks
    size_t i = 0;
    while (len >= 64) {
        transform(state, in + i);
        i += 64;
        len -= 64;
    }

    // Final block. The copy is guarded because memcpy with a null source is
    // undefined even for zero bytes, and the empty-string case passes nullptr.
    if (len > 0) {
        memcpy(buffer, in + i, len);
    }
    buffer[len++] = 0x80;
    if (len > 56) {
        // No room for the 8-byte length: zero-fill, compress, start a fresh block.
        memset(buffer + len, 0, 64 - len);
        transform(state, buffer);
        len = 0;
    }
    memset(buffer + len, 0, 56 - len);

    // Append length (big-endian)
    for (int j = 0; j < 8; j++) {
        buffer[63 - j] = static_cast<uint8_t>(bitlen >> (j * 8));
    }
    transform(state, buffer);

    // Store result (big-endian)
    for (int j = 0; j < 8; j++) {
        out[j * 4 + 0] = static_cast<uint8_t>(state[j] >> 24);
        out[j * 4 + 1] = static_cast<uint8_t>(state[j] >> 16);
        out[j * 4 + 2] = static_cast<uint8_t>(state[j] >> 8);
        out[j * 4 + 3] = static_cast<uint8_t>(state[j]);
    }
}

// Prints a labelled 32-byte digest as hex, grouped into 32-bit words.
static void print_digest(const char* label, const uint8_t digest[32])
{
    printf("%s:", label);
    for (int k = 0; k < 32; k++) {
        printf("%s%02x", (k % 4 == 0) ? " " : "", digest[k]);
    }
    printf("\n");
}

// Runs one known-answer case through both transforms. Passes only when the
// ARMv8 digest equals `expected` AND equals the generic digest. On failure,
// all three digests are dumped so the diverging side is visible.
static bool check_case(const char* name, const uint8_t* in, size_t len,
                       const uint8_t expected[32], TransformFunc armv8)
{
    uint8_t result_armv8[32];
    uint8_t result_generic[32];

    sha256_full(result_armv8, in, len, armv8);
    sha256_full(result_generic, in, len, generic_wrapper);

    if (memcmp(result_armv8, expected, 32) == 0 &&
        memcmp(result_armv8, result_generic, 32) == 0) {
        printf("PASS: %s hash\n", name);
        return true;
    }

    printf("FAIL: %s hash\n", name);
    print_digest("  expected", expected);
    print_digest("  armv8   ", result_armv8);
    print_digest("  generic ", result_generic);
    return false;
}

int main()
{
    TransformFunc armv8 = detect_armv8_transform();
    if (!armv8) {
        printf("SKIP: ARMv8 not available on this platform\n");
        return 0;
    }

    memset(million_a, 'a', sizeof million_a);

    struct KatCase {
        const char* name;
        const uint8_t* input;
        size_t len;
        const uint8_t* expected;
    };
    const KatCase cases[] = {
        { "Empty string", nullptr, 0, expected_empty },
        { "\"abc\"", test_input_abc, sizeof test_input_abc, expected_abc },
        { "448-bit message", reinterpret_cast<const uint8_t*>(test_input_448),
          sizeof test_input_448 - 1, expected_448 },
        { "1,000,000 x 'a'", million_a, sizeof million_a, expected_million_a },
    };

    int passed = 0;
    int failed = 0;
    for (const KatCase& c : cases) {
        if (check_case(c.name, c.input, c.len, c.expected, armv8)) {
            passed++;
        } else {
            failed++;
        }
    }

    printf("\nResults: %d passed, %d failed\n", passed, failed);
    return failed > 0 ? 1 : 0;
}

#else  // Not ARM64

int main()
{
    printf("SKIP: ARMv8 tests only run on aarch64\n");
    return 0;
}

#endif