fetch_ml/native/dataset_hash/dataset_hash.cpp
Jeremie Fraeys d408a60eb1
Some checks failed
Documentation / build-and-publish (push) Waiting to run
Test / test (push) Waiting to run
Checkout test / test (push) Successful in 5s
CI with Native Libraries / test-native (push) Has been cancelled
CI with Native Libraries / build-release (push) Has been cancelled
ci: push all workflow updates
2026-02-12 13:28:15 -05:00

517 lines
15 KiB
C++

#include "dataset_hash.h"

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <string>
#include <thread>
#include <vector>
// Platform-specific includes for SIMD
#if defined(__x86_64__) || defined(_M_X64)
#include <cpuid.h>
#include <immintrin.h>
#define HAS_X86_SIMD
#elif defined(__aarch64__) || defined(_M_ARM64)
#include <arm_neon.h>
#define HAS_ARM_SIMD
#endif
namespace fs = std::filesystem;
// SHA-256 round constants: entry K[i] is added into round i of the 64-round
// compression loop in Sha256Generic::transform.
static const uint32_t K[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};
// Identifies which SHA-256 block-transform backend a hasher dispatches to.
// The value is chosen by detect_best_impl().
enum class Sha256Impl {
// Portable C++ implementation (Sha256Generic); used when nothing else is selected
GENERIC,
// x86-64 SHA extensions; selected when CPUID leaf 7 reports EBX bit 29
SHA_NI,
// AArch64 builds; selected unconditionally when compiling for ARM64
ARMV8_CRYPTO
};
// Chooses the SHA-256 backend for the machine the code is running on.
//
// Returns:
//   - ARMV8_CRYPTO on every AArch64 build (no runtime probe is performed).
//   - SHA_NI on x86-64 when CPUID leaf 7, sub-leaf 0 reports the SHA feature
//     (EBX bit 29).
//   - GENERIC otherwise.
//
// NOTE(review): the ARM branch does not verify that the CPU actually has the
// crypto extensions. That is harmless while Sha256ARMv8 forwards to the
// generic code, but a runtime check is needed before real intrinsics land.
static Sha256Impl detect_best_impl() {
#if defined(HAS_ARM_SIMD)
    return Sha256Impl::ARMV8_CRYPTO;
#elif defined(HAS_X86_SIMD)
    unsigned int eax = 0;
    unsigned int ebx = 0;
    unsigned int ecx = 0;
    unsigned int edx = 0;
    // Leaf 7 is sub-leaf indexed: ECX must be 0 on input. __get_cpuid() leaves
    // ECX unspecified, so the sub-leaf-aware variant is required here.
    if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
        constexpr unsigned int kShaFeatureBit = 1u << 29;
        if ((ebx & kShaFeatureBit) != 0) {
            return Sha256Impl::SHA_NI;
        }
    }
    return Sha256Impl::GENERIC;
#else
    return Sha256Impl::GENERIC;
#endif
}
// Portable SHA-256 block transform written in plain C++ (no CPU-specific
// instructions). The other backends in this file currently forward to it.
class Sha256Generic {
    // Rotates a 32-bit word right by `bits` (callers pass 0 < bits < 32).
    static uint32_t rotr(uint32_t word, unsigned bits) {
        return (word >> bits) | (word << (32u - bits));
    }

public:
    // Mixes one 64-byte message block into the running hash state.
    //
    //   state - 8 words of chaining value, updated in place
    //   data  - exactly 64 bytes of input, read as big-endian 32-bit words
    static void transform(uint32_t* state, const uint8_t* data) {
        uint32_t schedule[64];

        // The first 16 schedule words are the block itself, big-endian.
        for (int t = 0; t < 16; ++t) {
            const uint8_t* p = data + 4 * t;
            schedule[t] = (static_cast<uint32_t>(p[0]) << 24) |
                          (static_cast<uint32_t>(p[1]) << 16) |
                          (static_cast<uint32_t>(p[2]) << 8) |
                          static_cast<uint32_t>(p[3]);
        }

        // The remaining 48 words are derived from earlier ones.
        for (int t = 16; t < 64; ++t) {
            const uint32_t w15 = schedule[t - 15];
            const uint32_t w2 = schedule[t - 2];
            const uint32_t small_sigma0 = rotr(w15, 7) ^ rotr(w15, 18) ^ (w15 >> 3);
            const uint32_t small_sigma1 = rotr(w2, 17) ^ rotr(w2, 19) ^ (w2 >> 10);
            schedule[t] = schedule[t - 16] + small_sigma0 + schedule[t - 7] + small_sigma1;
        }

        // Working variables start from the current chaining value.
        uint32_t a = state[0];
        uint32_t b = state[1];
        uint32_t c = state[2];
        uint32_t d = state[3];
        uint32_t e = state[4];
        uint32_t f = state[5];
        uint32_t g = state[6];
        uint32_t h = state[7];

        // 64 rounds of compression.
        for (int t = 0; t < 64; ++t) {
            const uint32_t big_sigma1 = rotr(e, 6) ^ rotr(e, 11) ^ rotr(e, 25);
            const uint32_t choose = (e & f) ^ (~e & g);
            const uint32_t t1 = h + big_sigma1 + choose + K[t] + schedule[t];
            const uint32_t big_sigma0 = rotr(a, 2) ^ rotr(a, 13) ^ rotr(a, 22);
            const uint32_t majority = (a & b) ^ (a & c) ^ (b & c);
            const uint32_t t2 = big_sigma0 + majority;

            h = g;
            g = f;
            f = e;
            e = d + t1;
            d = c;
            c = b;
            b = a;
            a = t1 + t2;
        }

        // Fold the round output back into the chaining value.
        state[0] += a;
        state[1] += b;
        state[2] += c;
        state[3] += d;
        state[4] += e;
        state[5] += f;
        state[6] += g;
        state[7] += h;
    }
};
// Intel SHA-NI backend (placeholder - a real implementation needs the SHA
// intrinsics named in the TODO below). Only compiled on x86-64 builds.
#if defined(HAS_X86_SIMD)
class Sha256SHA_NI {
public:
// Processes one 64-byte block into `state`; same contract as
// Sha256Generic::transform, to which it currently forwards unchanged.
static void transform(uint32_t* state, const uint8_t* data) {
// For now, fall back to generic (full SHA-NI impl is complex)
// TODO: Implement with _mm_sha256msg1_epu32, _mm_sha256msg2_epu32, etc.
Sha256Generic::transform(state, data);
}
};
#endif
// ARMv8 crypto backend (placeholder - a real implementation needs the SHA-256
// intrinsics named in the TODO below). Only compiled on AArch64 builds.
#if defined(HAS_ARM_SIMD)
class Sha256ARMv8 {
public:
// Processes one 64-byte block into `state`; same contract as
// Sha256Generic::transform, to which it currently forwards unchanged.
static void transform(uint32_t* state, const uint8_t* data) {
// For now, fall back to generic (full ARMv8 impl needs sha256su0, sha256su1, sha256h, sha256h2)
// TODO: Implement with vsha256su0q_u32, vsha256su1q_u32, vsha256hq_u32, vsha256h2q_u32
Sha256Generic::transform(state, data);
}
};
#endif
// SHA256 hasher class
class Sha256Hasher {
uint32_t state[8];
uint8_t buffer[64];
size_t buffer_len;
uint64_t total_len;
Sha256Impl impl;
public:
Sha256Hasher() : buffer_len(0), total_len(0) {
// Initial hash values
state[0] = 0x6a09e667;
state[1] = 0xbb67ae85;
state[2] = 0x3c6ef372;
state[3] = 0xa54ff53a;
state[4] = 0x510e527f;
state[5] = 0x9b05688c;
state[6] = 0x1f83d9ab;
state[7] = 0x5be0cd19;
impl = detect_best_impl();
}
void update(const uint8_t* data, size_t len) {
total_len += len;
// Fill buffer if there's pending data
if (buffer_len > 0) {
size_t to_copy = std::min(len, 64 - buffer_len);
std::memcpy(buffer + buffer_len, data, to_copy);
buffer_len += to_copy;
data += to_copy;
len -= to_copy;
if (buffer_len == 64) {
transform(buffer);
buffer_len = 0;
}
}
// Process full blocks
while (len >= 64) {
transform(data);
data += 64;
len -= 64;
}
// Store remaining data in buffer
if (len > 0) {
std::memcpy(buffer, data, len);
buffer_len = len;
}
}
void finalize(uint8_t* out) {
// Padding
uint64_t bit_len = total_len * 8;
buffer[buffer_len++] = 0x80;
if (buffer_len > 56) {
while (buffer_len < 64) buffer[buffer_len++] = 0;
transform(buffer);
buffer_len = 0;
}
while (buffer_len < 56) buffer[buffer_len++] = 0;
// Append length (big-endian)
for (int i = 7; i >= 0; --i) {
buffer[56 + (7 - i)] = (bit_len >> (i * 8)) & 0xff;
}
transform(buffer);
// Output (big-endian)
for (int i = 0; i < 8; ++i) {
out[i * 4] = (state[i] >> 24) & 0xff;
out[i * 4 + 1] = (state[i] >> 16) & 0xff;
out[i * 4 + 2] = (state[i] >> 8) & 0xff;
out[i * 4 + 3] = state[i] & 0xff;
}
}
static std::string bytes_to_hex(const uint8_t* data) {
static const char hex[] = "0123456789abcdef";
std::string result;
result.reserve(64);
for (int i = 0; i < 32; ++i) {
result += hex[(data[i] >> 4) & 0xf];
result += hex[data[i] & 0xf];
}
return result;
}
private:
void transform(const uint8_t* data) {
switch (impl) {
#if defined(HAS_X86_SIMD)
case Sha256Impl::SHA_NI:
Sha256SHA_NI::transform(state, data);
break;
#endif
#if defined(HAS_ARM_SIMD)
case Sha256Impl::ARMV8_CRYPTO:
Sha256ARMv8::transform(state, data);
break;
#endif
default:
Sha256Generic::transform(state, data);
}
}
};
// Fixed-size thread pool for parallel hashing.
//
// Tasks are run in FIFO order by a set of worker threads created in the
// constructor. The destructor lets the workers drain every task that is still
// queued and then joins them, so all enqueued work has finished once the pool
// is destroyed.
//
// Tasks must not throw: an exception escaping a task leaves the worker
// thread's entry function and terminates the process.
class ThreadPool {
    std::vector<std::thread> workers;
    std::queue<std::function<void()>> tasks;
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop = false;

    // Body of every worker: pop and run tasks until asked to stop and the
    // queue is empty.
    void worker_loop() {
        for (;;) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(queue_mutex);
                condition.wait(lock, [this] { return stop || !tasks.empty(); });
                if (stop && tasks.empty()) return;
                task = std::move(tasks.front());
                tasks.pop();
            }
            // Run outside the lock so other workers can dequeue meanwhile.
            task();
        }
    }

    // Signals all workers to finish the remaining queue and joins them.
    void shutdown() {
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            stop = true;
        }
        condition.notify_all();
        for (auto& worker : workers) {
            if (worker.joinable()) worker.join();
        }
    }

public:
    // Starts `num_threads` workers. A request for 0 threads is raised to 1:
    // a pool without workers would accept tasks but never run them, leaving
    // any caller that waits for completion blocked forever.
    //
    // Throws whatever std::thread construction throws; workers that were
    // already started are stopped and joined before the exception propagates.
    explicit ThreadPool(size_t num_threads) {
        const size_t count = num_threads == 0 ? 1 : num_threads;
        workers.reserve(count);
        try {
            for (size_t i = 0; i < count; ++i) {
                workers.emplace_back([this] { worker_loop(); });
            }
        } catch (...) {
            // Destroying a joinable std::thread calls std::terminate, so the
            // partially built pool has to be wound down explicitly.
            shutdown();
            throw;
        }
    }

    // Workers capture `this`, so the pool must stay at a fixed address.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    // Runs every task still in the queue, then joins all workers.
    ~ThreadPool() {
        shutdown();
    }

    // Queues `task` for execution on one of the workers.
    void enqueue(std::function<void()> task) {
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            tasks.emplace(std::move(task));
        }
        condition.notify_one();
    }
};
// State behind the context handle that the C API functions receive.
struct fh_context {
    // Worker threads used by the directory and batch hashing entry points
    std::unique_ptr<ThreadPool> pool;
    // Number of workers the pool was created with (set by fh_init)
    uint32_t num_threads = 0;
    // Message describing the most recent failure; empty when none was recorded
    std::string last_error;
    // Chunk size for the buffered-read fallback in hash_file_mmap
    size_t buffer_size = 64 * 1024; // 64KB default
};
// Computes the SHA-256 digest of a file's contents.
//
// The file is memory-mapped and hashed in one pass. If mmap() fails, the file
// is read through a heap buffer of `buffer_size` bytes instead (64 KiB when
// `buffer_size` is 0).
//
//   path        - file to hash
//   buffer_size - chunk size for the buffered-read fallback
//
// Returns the digest as 64 lowercase hex characters, or an empty string when
// the file cannot be opened, stat'ed or read. A file whose reported size is 0
// yields the digest of the empty message without being read.
static std::string hash_file_mmap(const char* path, size_t buffer_size) {
    // Closes the descriptor on every exit path, including exceptions.
    struct FdCloser {
        int fd;
        ~FdCloser() {
            if (fd >= 0) close(fd);
        }
    };

    // O_CLOEXEC keeps the descriptor from leaking into child processes
    // spawned by the host application while worker threads are hashing.
    FdCloser file{open(path, O_RDONLY | O_CLOEXEC)};
    if (file.fd < 0) {
        return "";
    }

    struct stat st;
    if (fstat(file.fd, &st) < 0) {
        return "";
    }

    Sha256Hasher hasher;
    if (st.st_size > 0) {
        const size_t length = static_cast<size_t>(st.st_size);
        void* mapped = mmap(nullptr, length, PROT_READ, MAP_PRIVATE, file.fd, 0);
        if (mapped != MAP_FAILED) {
            // Process from mmap
            hasher.update(static_cast<const uint8_t*>(mapped), length);
            munmap(mapped, length);
        } else {
            // Fall back to buffered read
            std::vector<uint8_t> buffer(buffer_size > 0 ? buffer_size : 64 * 1024);
            for (;;) {
                const ssize_t n = read(file.fd, buffer.data(), buffer.size());
                if (n > 0) {
                    hasher.update(buffer.data(), static_cast<size_t>(n));
                } else if (n == 0) {
                    break; // end of file
                } else if (errno != EINTR) {
                    // A read error must not be mistaken for end-of-file:
                    // that would report the digest of truncated data.
                    return "";
                }
            }
        }
    }

    uint8_t result[32];
    hasher.finalize(result);
    return Sha256Hasher::bytes_to_hex(result);
}
// C API Implementation
// Creates a hashing context with its own worker pool.
//
//   num_threads - number of worker threads; 0 selects
//                 std::thread::hardware_concurrency(), falling back to 4 when
//                 that is unknown and capped at 8. An explicit non-zero value
//                 is used as given.
//
// Returns a context to be released with fh_cleanup(), or nullptr when the
// context or its threads cannot be created. No C++ exception leaves this
// function, since its callers go through the C API.
fh_context_t* fh_init(uint32_t num_threads) {
    if (num_threads == 0) {
        num_threads = std::thread::hardware_concurrency();
        if (num_threads == 0) num_threads = 4;
        if (num_threads > 8) num_threads = 8; // Cap at 8
    }
    try {
        // Owned by a unique_ptr until fully built so a failing pool
        // construction does not leak the context.
        auto ctx = std::make_unique<fh_context_t>();
        ctx->num_threads = num_threads;
        ctx->pool = std::make_unique<ThreadPool>(num_threads);
        return ctx.release();
    } catch (...) {
        return nullptr;
    }
}
// Destroys a context created by fh_init(). Deleting the context destroys its
// thread pool, whose destructor joins the worker threads. Passing nullptr is
// a no-op.
void fh_cleanup(fh_context_t* ctx) {
delete ctx;
}
// Hashes a single file with SHA-256.
//
//   ctx  - context from fh_init()
//   path - file to hash
//
// Returns a newly allocated, NUL-terminated string of 64 lowercase hex
// characters that the caller must release with fh_free_string(). Returns
// nullptr when an argument is null, when the file cannot be hashed (in which
// case the context's last error is set), or when memory runs out. No C++
// exception leaves this function.
char* fh_hash_file(fh_context_t* ctx, const char* path) {
    if (!ctx || !path) return nullptr;
    try {
        const std::string hash = hash_file_mmap(path, ctx->buffer_size);
        if (hash.empty()) {
            ctx->last_error = "Failed to hash file: " + std::string(path);
            return nullptr;
        }
        // Allocated with new[] because fh_free_string() releases with delete[].
        char* result = new char[hash.size() + 1];
        std::memcpy(result, hash.c_str(), hash.size() + 1);
        return result;
    } catch (...) {
        // Allocation failure while hashing or while building the result.
        return nullptr;
    }
}
// Computes one SHA-256 digest covering every regular file under a directory.
//
// The directory is walked recursively, the file paths are sorted so the order
// is deterministic, each file is hashed on the context's thread pool, and the
// per-file hex digests are concatenated in path order and hashed again.
// Only file contents feed the result; file names contribute solely through
// the ordering. A directory without regular files yields the digest of the
// empty message.
//
//   ctx  - context from fh_init()
//   path - directory to hash
//
// Returns a newly allocated, NUL-terminated 64-character hex string that the
// caller must release with fh_free_string(). Returns nullptr when an argument
// is null, when the directory cannot be scanned, when any file fails to hash
// (the context's last error is set in those cases), or when memory runs out.
// No C++ exception leaves this function.
char* fh_hash_directory(fh_context_t* ctx, const char* path) {
    if (!ctx || !path) return nullptr;
    try {
        // Collect all files
        std::vector<std::string> files;
        try {
            for (const auto& entry : fs::recursive_directory_iterator(path)) {
                if (entry.is_regular_file()) {
                    files.push_back(entry.path().string());
                }
            }
        } catch (...) {
            ctx->last_error = "Failed to scan directory";
            return nullptr;
        }

        // Sort for deterministic order
        std::sort(files.begin(), files.end());

        // Parallel hash all files
        std::vector<std::string> hashes(files.size());
        std::mutex done_mutex;
        std::condition_variable done_cv;
        size_t completed = 0;   // guarded by done_mutex
        bool has_error = false; // guarded by done_mutex
        size_t submitted = 0;   // touched by this thread only

        for (size_t i = 0; i < files.size(); ++i) {
            try {
                ctx->pool->enqueue([&, i]() {
                    // Nothing may escape a pool task, and completion must be
                    // reported even on failure or the waiter below never wakes.
                    bool ok = false;
                    try {
                        hashes[i] = hash_file_mmap(files[i].c_str(), ctx->buffer_size);
                        ok = !hashes[i].empty();
                    } catch (...) {
                        ok = false;
                    }
                    std::lock_guard<std::mutex> lock(done_mutex);
                    if (!ok) has_error = true;
                    ++completed;
                    // Notify while holding the lock: the waiter destroys the
                    // condition variable as soon as it observes completion.
                    done_cv.notify_one();
                });
                ++submitted;
            } catch (...) {
                // Could not queue this task; stop submitting, but still wait
                // for the tasks already queued since they reference locals.
                break;
            }
        }

        // Block until every submitted task has reported back.
        {
            std::unique_lock<std::mutex> lock(done_mutex);
            done_cv.wait(lock, [&] { return completed == submitted; });
        }

        if (submitted < files.size()) {
            ctx->last_error = "Failed to schedule hashing tasks";
            return nullptr;
        }
        if (has_error) {
            ctx->last_error = "Failed to hash some files";
            return nullptr;
        }

        // Combine hashes
        Sha256Hasher final_hasher;
        for (const auto& h : hashes) {
            final_hasher.update(reinterpret_cast<const uint8_t*>(h.data()), h.size());
        }
        uint8_t result[32];
        final_hasher.finalize(result);
        const std::string final_hash = Sha256Hasher::bytes_to_hex(result);

        // Allocated with new[] because fh_free_string() releases with delete[].
        char* out = new char[final_hash.size() + 1];
        std::memcpy(out, final_hash.c_str(), final_hash.size() + 1);
        return out;
    } catch (...) {
        // Only reachable while no pool task is in flight (allocation failures
        // before submission or after the wait).
        return nullptr;
    }
}
// Hashes `count` files in parallel on the context's thread pool.
//
//   ctx        - context from fh_init()
//   paths      - array of `count` file paths
//   count      - number of entries in `paths` and `out_hashes`
//   out_hashes - array of `count` caller-provided buffers; each one receives
//                64 hex characters plus a terminating NUL, so it must hold at
//                least 65 bytes
//
// For every entry that fails (null path, null buffer, unreadable file), the
// corresponding buffer - when present - is set to the empty string.
//
// Returns 0 when every file was hashed, -1 when an argument is null, `count`
// is 0, or at least one entry failed (the context's last error is then set).
// No C++ exception leaves this function.
int fh_hash_batch(fh_context_t* ctx, const char** paths, uint32_t count, char** out_hashes) {
    if (!ctx || !paths || !out_hashes || count == 0) return -1;

    std::mutex done_mutex;
    std::condition_variable done_cv;
    uint32_t completed = 0; // guarded by done_mutex
    bool has_error = false; // guarded by done_mutex
    uint32_t submitted = 0; // touched by this thread only

    for (uint32_t i = 0; i < count; ++i) {
        try {
            ctx->pool->enqueue([&, i]() {
                // Nothing may escape a pool task, and completion must be
                // reported even on failure or the waiter below never wakes.
                bool ok = false;
                try {
                    if (paths[i] && out_hashes[i]) {
                        const std::string hash = hash_file_mmap(paths[i], ctx->buffer_size);
                        if (!hash.empty()) {
                            std::strncpy(out_hashes[i], hash.c_str(), 64);
                            out_hashes[i][64] = '\0';
                            ok = true;
                        }
                    }
                } catch (...) {
                    ok = false;
                }
                if (!ok && out_hashes[i]) {
                    out_hashes[i][0] = '\0';
                }
                std::lock_guard<std::mutex> lock(done_mutex);
                if (!ok) has_error = true;
                ++completed;
                // Notify while holding the lock: the waiter destroys the
                // condition variable as soon as it observes completion.
                done_cv.notify_one();
            });
            ++submitted;
        } catch (...) {
            // Could not queue this task; stop submitting, but still wait for
            // the tasks already queued since they reference locals.
            break;
        }
    }

    // Wait for completion
    {
        std::unique_lock<std::mutex> lock(done_mutex);
        done_cv.wait(lock, [&] { return completed == submitted; });
    }

    // Entries that never got scheduled count as failures.
    for (uint32_t i = submitted; i < count; ++i) {
        if (out_hashes[i]) out_hashes[i][0] = '\0';
        has_error = true;
    }

    if (has_error) {
        try {
            ctx->last_error = "Failed to hash some files";
        } catch (...) {
            // Out of memory while recording the message; the return code
            // still reports the failure.
        }
        return -1;
    }
    return 0;
}
// Alias for fh_hash_directory(): forwards both arguments unchanged and returns
// its result, so the same ownership rule applies (release the string with
// fh_free_string()).
char* fh_hash_directory_combined(fh_context_t* ctx, const char* dir_path) {
// Same as fh_hash_directory
return fh_hash_directory(ctx, dir_path);
}
// Releases a string returned by fh_hash_file() or fh_hash_directory().
// Those strings are allocated with new[], hence delete[] here; passing
// nullptr is a no-op.
void fh_free_string(char* str) {
delete[] str;
}
// Returns the message recorded by the most recent failing call on `ctx`, or
// nullptr when `ctx` is null or no message has been recorded. The pointer
// refers to storage inside the context and is invalidated when the message
// is next overwritten or the context is destroyed.
const char* fh_last_error(fh_context_t* ctx) {
    const bool has_message = ctx != nullptr && !ctx->last_error.empty();
    return has_message ? ctx->last_error.c_str() : nullptr;
}
int fh_has_simd_sha256(void) {
Sha256Impl impl = detect_best_impl();
return (impl == Sha256Impl::SHA_NI || impl == Sha256Impl::ARMV8_CRYPTO) ? 1 : 0;
}
const char* fh_get_simd_impl_name(void) {
Sha256Impl impl = detect_best_impl();
switch (impl) {
#if defined(HAS_X86_SIMD)
case Sha256Impl::SHA_NI:
return "SHA-NI";
#endif
#if defined(HAS_ARM_SIMD)
case Sha256Impl::ARMV8_CRYPTO:
return "ARMv8";
#endif
default:
return "generic";
}
}