feat: add manifest signing and native hashing support

- Integrate RunManifest.Validate with existing Validator
- Add manifest Sign() and Verify() methods
- Add native C++ hashing libraries (dataset_hash, queue_index)
- Add native bridge for Go/C++ integration
- Add deduplication support in queue
This commit is contained in:
Jeremie Fraeys 2026-02-19 15:34:39 -05:00
parent a3f9bf8731
commit 37aad7ae87
No known key found for this signature in database
13 changed files with 775 additions and 74 deletions

View file

@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto"
"github.com/jfraeys/fetch_ml/internal/fileutil"
)
@ -121,6 +122,11 @@ type RunManifest struct {
GPUDevices []string `json:"gpu_devices,omitempty"`
WorkerHost string `json:"worker_host,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
// Signature fields for tamper detection
Signature string `json:"signature,omitempty"`
SignerKeyID string `json:"signer_key_id,omitempty"`
SigAlg string `json:"sig_alg,omitempty"`
}
func NewRunManifest(runID, taskID, jobName string, createdAt time.Time) *RunManifest {
@ -234,3 +240,60 @@ func (m *RunManifest) ApplyNarrativePatch(p NarrativePatch) {
m.Narrative.Tags = clean
}
}
// Sign signs the manifest in place using the provided signer.
//
// On success the Signature, SignerKeyID, and SigAlg fields are populated
// from the signing result. Returns an error if the manifest or signer is
// nil, or if signing fails.
func (m *RunManifest) Sign(signer *crypto.ManifestSigner) error {
	if m == nil {
		return fmt.Errorf("cannot sign nil manifest")
	}
	// Guard against a nil signer; previously this would panic inside
	// SignManifest instead of returning a descriptive error.
	if signer == nil {
		return fmt.Errorf("cannot sign with nil signer")
	}
	result, err := signer.SignManifest(m)
	if err != nil {
		return fmt.Errorf("failed to sign manifest: %w", err)
	}
	m.Signature = result.Signature
	m.SignerKeyID = result.KeyID
	m.SigAlg = result.Algorithm
	return nil
}
// Verify checks the manifest's signature against the given public key.
// It fails fast when the manifest is nil or carries no signature; otherwise
// it reconstructs the signing result from the stored fields and delegates
// the cryptographic check to the crypto package.
func (m *RunManifest) Verify(publicKey []byte) (bool, error) {
	switch {
	case m == nil:
		return false, fmt.Errorf("cannot verify nil manifest")
	case m.Signature == "":
		return false, fmt.Errorf("manifest has no signature")
	}
	// Rebuild the signing result from the persisted manifest fields.
	sig := crypto.SigningResult{
		Signature: m.Signature,
		KeyID:     m.SignerKeyID,
		Algorithm: m.SigAlg,
	}
	return crypto.VerifyManifest(m, &sig, publicKey)
}
// IsSigned reports whether the manifest carries a signature.
// A nil manifest is treated as unsigned.
func (m *RunManifest) IsSigned() bool {
	if m == nil {
		return false
	}
	return m.Signature != ""
}
// Validate checks manifest completeness using the standard Validator.
// It delegates to Validator.Validate() so manifest-level and validator-level
// checks stay consistent.
func (m *RunManifest) Validate() error {
	return NewValidator().Validate(m)
}
// ValidateStrict performs strict validation including optional provenance
// fields. It delegates to Validator.ValidateStrict() for consistency.
func (m *RunManifest) ValidateStrict() error {
	return NewValidator().ValidateStrict(m)
}

78
internal/queue/dedup.go Normal file
View file

@ -0,0 +1,78 @@
package queue
import (
"fmt"
"sync"
"time"
)
// ErrAlreadyQueued is returned when a job with the same commit was recently queued
var ErrAlreadyQueued = fmt.Errorf("job already queued with this commit")

// CommitDedup tracks recently queued commits to prevent duplicate submissions.
// Entries live for a fixed TTL; Cleanup must be invoked periodically to evict
// stale ones. All methods are safe for concurrent use.
type CommitDedup struct {
	mu      sync.RWMutex
	commits map[string]time.Time // "job_name:commit_id" -> time it was queued
	ttl     time.Duration
}

// NewCommitDedup creates a new commit deduplication tracker.
// A non-positive ttl selects the default of one hour.
func NewCommitDedup(ttl time.Duration) *CommitDedup {
	const defaultTTL = 1 * time.Hour
	if ttl <= 0 {
		ttl = defaultTTL
	}
	return &CommitDedup{
		commits: map[string]time.Time{},
		ttl:     ttl,
	}
}

// IsDuplicate reports whether jobName+commitID was queued within the TTL.
// Expired entries are treated as absent (they are physically removed only
// by Cleanup).
func (d *CommitDedup) IsDuplicate(jobName, commitID string) bool {
	d.mu.RLock()
	queuedAt, found := d.commits[d.key(jobName, commitID)]
	d.mu.RUnlock()
	return found && time.Since(queuedAt) < d.ttl
}

// MarkQueued records that jobName+commitID was just queued.
func (d *CommitDedup) MarkQueued(jobName, commitID string) {
	k := d.key(jobName, commitID)
	d.mu.Lock()
	d.commits[k] = time.Now()
	d.mu.Unlock()
}

// Cleanup removes entries older than the TTL. Call it periodically
// (e.g. every 5 minutes) to bound memory use.
func (d *CommitDedup) Cleanup() {
	d.mu.Lock()
	defer d.mu.Unlock()
	now := time.Now()
	for k, queuedAt := range d.commits {
		if now.Sub(queuedAt) > d.ttl {
			delete(d.commits, k)
		}
	}
}

// key builds the map key for a job+commit pair.
func (d *CommitDedup) key(jobName, commitID string) string {
	return jobName + ":" + commitID
}

// Size returns the number of tracked commits (for metrics/debugging).
func (d *CommitDedup) Size() int {
	d.mu.RLock()
	n := len(d.commits)
	d.mu.RUnlock()
	return n
}

View file

@ -26,6 +26,7 @@ type TaskQueue struct {
metricsCh chan metricEvent
metricsDone chan struct{}
flushEvery time.Duration
dedup *CommitDedup // Tracks recently queued commits
}
type metricEvent struct {
@ -103,10 +104,12 @@ func NewTaskQueue(cfg Config) (*TaskQueue, error) {
metricsCh: make(chan metricEvent, 256),
metricsDone: make(chan struct{}),
flushEvery: flushEvery,
dedup: NewCommitDedup(1 * time.Hour), // 1 hour default TTL for commit dedup
}
go tq.runMetricsBuffer()
go tq.runLeaseReclamation() // Start lease reclamation background job
go tq.runDedupCleanup() // Start dedup cleanup background job
return tq, nil
}
@ -732,7 +735,38 @@ func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration)
return tq.GetTask(member)
}
// runMetricsBuffer buffers and flushes metrics
// runDedupCleanup periodically cleans up expired dedup entries every 5 minutes.
// It runs as a background goroutine (started in NewTaskQueue) and exits when
// the queue's context is cancelled.
func (tq *TaskQueue) runDedupCleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-tq.ctx.Done():
			// Queue is shutting down; stop the cleanup loop.
			return
		case <-ticker.C:
			tq.dedup.Cleanup()
		}
	}
}
// AddTaskDedup adds a task with commit deduplication check.
// Returns ErrAlreadyQueued if the same job+commit was recently queued.
// An empty commitID bypasses deduplication entirely and behaves like AddTask.
//
// NOTE(review): the IsDuplicate check and the later MarkQueued are not one
// atomic operation, so two concurrent callers with the same job+commit can
// both pass the check and enqueue. Dedup here is best-effort, not a strict
// guarantee — confirm that is acceptable for callers.
func (tq *TaskQueue) AddTaskDedup(task *Task, commitID string) error {
	if commitID != "" && tq.dedup.IsDuplicate(task.JobName, commitID) {
		return ErrAlreadyQueued
	}
	if err := tq.AddTask(task); err != nil {
		return err
	}
	// Mark as queued on success only, so a failed AddTask does not suppress
	// a later retry with the same commit.
	if commitID != "" {
		tq.dedup.MarkQueued(task.JobName, commitID)
	}
	return nil
}
func (tq *TaskQueue) runMetricsBuffer() {
defer close(tq.metricsDone)
ticker := time.NewTicker(tq.flushEvery)

View file

@ -1,19 +1,15 @@
package worker
import "github.com/jfraeys/fetch_ml/internal/worker/integrity"
import (
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
// UseNativeLibs controls whether to use C++ implementations.
// Set FETCHML_NATIVE_LIBS=1 to enable native libraries.
// This is defined here so it's available regardless of build tags.
var UseNativeLibs = false
// dirOverallSHA256Hex selects implementation based on toggle.
// This file has no CGo imports so it compiles even when CGO is disabled.
// The actual implementations are in native_bridge.go (native) and integrity package (Go).
// dirOverallSHA256Hex uses native implementation when compiled with -tags native_libs.
// The build tag selects between native_bridge.go (stub) and native_bridge_libs.go (real).
// No runtime configuration needed - build determines behavior.
// dirOverallSHA256Hex computes the overall SHA-256 (hex-encoded) of the
// directory tree rooted at root, dispatching between implementations.
//
// When UseNativeLibs is false, the pure-Go integrity package is used.
// Otherwise the call is forwarded to dirOverallSHA256HexNative, whose
// real/stub variant is selected at build time (see comments below).
func dirOverallSHA256Hex(root string) (string, error) {
	if !UseNativeLibs {
		// Pure-Go fallback path (also the default).
		return integrity.DirOverallSHA256Hex(root)
	}
	// native_bridge_libs.go provides this when built with -tags native_libs
	// native_bridge.go provides stub that falls back to Go
	return dirOverallSHA256HexNative(root)
}

View file

@ -104,8 +104,9 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
return "", fmt.Errorf("not a directory")
}
// Collect all files
// Collect all files with size info
var files []string
var totalSize int64
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
@ -118,6 +119,11 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
return err
}
files = append(files, rel)
// Track total size for optimization decisions
if info, err := d.Info(); err == nil {
totalSize += info.Size()
}
return nil
})
if err != nil {
@ -183,3 +189,30 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
}
return fmt.Sprintf("%x", overall.Sum(nil)), nil
}
// EstimateDirSize returns the total size of directory contents in bytes.
// If root is a regular file, its own size is returned. Per-file stat errors
// during the walk are skipped (best-effort, e.g. a file removed mid-walk),
// while walk errors abort and are returned.
func EstimateDirSize(root string) (int64, error) {
	root = filepath.Clean(root)
	rootInfo, err := os.Stat(root)
	if err != nil {
		return 0, err
	}
	if !rootInfo.IsDir() {
		return rootInfo.Size(), nil
	}
	var total int64
	walkErr := filepath.WalkDir(root, func(_ string, entry os.DirEntry, visitErr error) error {
		switch {
		case visitErr != nil:
			return visitErr
		case entry.IsDir():
			return nil
		}
		// Best-effort: files that cannot be stat'ed are simply not counted.
		if fi, statErr := entry.Info(); statErr == nil {
			total += fi.Size()
		}
		return nil
	})
	return total, walkErr
}

View file

@ -12,8 +12,6 @@ import (
)
// init forces the runtime native-libs toggle off for this build variant.
// Native hashing is only reachable when compiled with -tags native_libs,
// which swaps in a different file; this file is the non-native fallback.
func init() {
	// Even with CGO, native libs require explicit build tag
	UseNativeLibs = false
	log.Printf("[native] Native libraries disabled (build with -tags native_libs to enable)")
}

View file

@ -21,12 +21,28 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
set(CMAKE_C_FLAGS_RELEASE "-O3 -march=native -DNDEBUG -fomit-frame-pointer")
set(CMAKE_C_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer")
# Security hardening flags (always enabled)
set(SECURITY_FLAGS
-D_FORTIFY_SOURCE=2 # Buffer overflow protection
-fstack-protector-strong # Stack canaries
-Wformat-security # Format string warnings
-Werror=format-security # Format string errors
-fPIE # Position-independent code
)
# Add security flags to all build types
add_compile_options(${SECURITY_FLAGS})
# Linker security flags
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -pie")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now")
# Warnings
add_compile_options(-Wall -Wextra -Wpedantic)
if(ENABLE_ASAN)
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
add_link_options(-fsanitize=address)
add_compile_options(-fsanitize=address,undefined -fno-omit-frame-pointer)
add_link_options(-fsanitize=address,undefined)
endif()
if(ENABLE_TSAN)
@ -59,6 +75,10 @@ if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/tests)
add_executable(test_storage tests/test_storage.cpp)
target_link_libraries(test_storage queue_index)
add_test(NAME storage_smoke COMMAND test_storage)
add_executable(test_dataset_hash tests/test_dataset_hash.cpp)
target_link_libraries(test_dataset_hash dataset_hash)
add_test(NAME dataset_hash_smoke COMMAND test_dataset_hash)
endif()
# Combined target for building all libraries

View file

@ -5,19 +5,13 @@
#include <arm_neon.h>
static void transform_armv8(uint32_t* state, const uint8_t* data) {
// Load the 512-bit message block into 4 128-bit vectors
uint32x4_t w0 = vld1q_u32((const uint32_t*)data);
uint32x4_t w1 = vld1q_u32((const uint32_t*)(data + 16));
uint32x4_t w2 = vld1q_u32((const uint32_t*)(data + 32));
uint32x4_t w3 = vld1q_u32((const uint32_t*)(data + 48));
// Load message and reverse bytes within each 32-bit word (big-endian -> native)
uint32x4_t w0 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data)));
uint32x4_t w1 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 16)));
uint32x4_t w2 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 32)));
uint32x4_t w3 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 48)));
// Reverse byte order (SHA256 uses big-endian words)
w0 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w0)));
w1 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w1)));
w2 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w2)));
w3 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w3)));
// Load current hash state
// Load current hash state (native endianness)
uint32x4_t abcd = vld1q_u32(state);
uint32x4_t efgh = vld1q_u32(state + 4);
uint32x4_t abcd_orig = abcd;
@ -30,20 +24,24 @@ static void transform_armv8(uint32_t* state, const uint8_t* data) {
uint32x4_t k3 = vld1q_u32(&K[12]);
uint32x4_t tmp = vaddq_u32(w0, k0);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
uint32x4_t abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
tmp = vaddq_u32(w1, k1);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
tmp = vaddq_u32(w2, k2);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
tmp = vaddq_u32(w3, k3);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
// Rounds 16-63: Message schedule expansion + rounds
for (int i = 16; i < 64; i += 16) {
@ -52,32 +50,36 @@ static void transform_armv8(uint32_t* state, const uint8_t* data) {
w4 = vsha256su1q_u32(w4, w2, w3);
k0 = vld1q_u32(&K[i]);
tmp = vaddq_u32(w4, k0);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
// Schedule expansion for rounds i+4..i+7
uint32x4_t w5 = vsha256su0q_u32(w1, w2);
w5 = vsha256su1q_u32(w5, w3, w4);
k1 = vld1q_u32(&K[i + 4]);
tmp = vaddq_u32(w5, k1);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
// Schedule expansion for rounds i+8..i+11
uint32x4_t w6 = vsha256su0q_u32(w2, w3);
w6 = vsha256su1q_u32(w6, w4, w5);
k2 = vld1q_u32(&K[i + 8]);
tmp = vaddq_u32(w6, k2);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
// Schedule expansion for rounds i+12..i+15
uint32x4_t w7 = vsha256su0q_u32(w3, w4);
w7 = vsha256su1q_u32(w7, w5, w6);
k3 = vld1q_u32(&K[i + 12]);
tmp = vaddq_u32(w7, k3);
efgh = vsha256h2q_u32(efgh, abcd, tmp);
abcd = vsha256hq_u32(abcd, efgh, tmp);
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
abcd = abcd_new;
// Rotate working variables
w0 = w4; w1 = w5; w2 = w6; w3 = w7;

View file

@ -11,19 +11,8 @@ void sha256_init(Sha256State* hasher) {
memcpy(hasher->state, H0, sizeof(H0));
// Detect best transform implementation
hasher->transform_fn = detect_best_transform();
if (hasher->transform_fn == transform_generic) {
// Try platform-specific implementations
TransformFunc f = detect_armv8_transform();
if (f) {
hasher->transform_fn = f;
} else {
f = detect_x86_transform();
if (f) {
hasher->transform_fn = f;
}
}
}
TransformFunc f = detect_best_transform();
hasher->transform_fn = f ? f : transform_generic;
}
void sha256_update(Sha256State* hasher, const uint8_t* data, size_t len) {

View file

@ -73,6 +73,18 @@ int hash_file(const char* path, size_t buffer_size, char* out_hash) {
return 0;
}
// Hash a single file into a freshly malloc'd 65-byte buffer
// (64 hex characters + NUL terminator).
// Returns nullptr on allocation failure or if hashing fails; on success the
// caller owns the returned buffer and must release it with free().
char* hash_file_alloc(const char* path, size_t buffer_size) {
    char* digest = static_cast<char*>(malloc(65)); // 64 hex + null
    if (digest == nullptr) {
        return nullptr;
    }
    if (hash_file(path, buffer_size, digest) == 0) {
        return digest;
    }
    free(digest);
    return nullptr;
}
int hash_files_batch(
const char* const* paths,
uint32_t count,
@ -84,8 +96,8 @@ int hash_files_batch(
int all_success = 1;
for (uint32_t i = 0; i < count; ++i) {
if (hash_file(paths[i], buffer_size, out_hashes[i]) != 0) {
out_hashes[i][0] = '\0';
out_hashes[i] = hash_file_alloc(paths[i], buffer_size);
if (out_hashes[i] == nullptr) {
all_success = 0;
}
}

View file

@ -1,10 +1,15 @@
#include "parallel_hash.h"
#include "../io/file_hash.h"
#include "../crypto/sha256_hasher.h"
#include "../../common/include/thread_pool.h"
#include <dirent.h>
#include <sys/stat.h>
#include <string.h>
#include <stdlib.h>
#include <atomic>
#include <functional>
#include <thread>
#include <vector>
// Simple file collector - just flat directory for now
static int collect_files(const char* dir_path, char** out_paths, int max_files) {
@ -55,6 +60,25 @@ void parallel_hasher_cleanup(ParallelHasher* hasher) {
hasher->pool = nullptr;
}
// Batch hash task - processes a range of files
// All fields are read-only during execution except `success`, which workers
// may clear (never set) when a per-file hash fails.
struct BatchHashTask {
    const char** paths;         // shared array of file paths
    char** out_hashes;          // per-file output buffers (65 bytes: 64 hex + NUL)
    size_t buffer_size;         // read-buffer size forwarded to hash_file
    int start_idx;              // first index this batch processes (inclusive)
    int end_idx;                // one past the last index (exclusive)
    std::atomic<bool>* success; // shared flag, cleared on any failure
};
// Worker function for batch processing: hashes every file in the task's
// [start_idx, end_idx) range. A per-file failure clears the shared success
// flag but does not stop the remaining files from being hashed.
static void batch_hash_worker(BatchHashTask* task) {
    for (int idx = task->start_idx; idx < task->end_idx; ++idx) {
        const int rc = hash_file(task->paths[idx], task->buffer_size, task->out_hashes[idx]);
        if (rc != 0) {
            task->success->store(false);
        }
    }
}
int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_hash) {
if (!hasher || !path || !out_hash) return -1;
@ -70,7 +94,6 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_
sha256_init(&st);
uint8_t result[32];
sha256_finalize(&st, result);
// Convert to hex
static const char hex[] = "0123456789abcdef";
for (int i = 0; i < 32; i++) {
out_hash[i*2] = hex[(result[i] >> 4) & 0xf];
@ -80,15 +103,63 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_
return 0;
}
// Hash all files
char hashes[256][65];
// Convert path_ptrs to const char** for batch task
const char* path_array[256];
for (int i = 0; i < count; i++) {
if (hash_file(paths[i], hasher->buffer_size, hashes[i]) != 0) {
return -1;
}
path_array[i] = path_ptrs[i];
}
// Combine hashes
// Parallel hash all files using ThreadPool with batched tasks
char hashes[256][65];
std::atomic<bool> all_success{true};
std::atomic<int> completed_batches{0};
// Determine batch size - divide files among threads
uint32_t num_threads = ThreadPool::default_thread_count();
int batch_size = (count + num_threads - 1) / num_threads;
if (batch_size < 1) batch_size = 1;
int num_batches = (count + batch_size - 1) / batch_size;
// Allocate batch tasks
BatchHashTask* batch_tasks = new BatchHashTask[num_batches];
char* hash_ptrs[256];
for (int i = 0; i < count; i++) {
hash_ptrs[i] = hashes[i];
}
for (int b = 0; b < num_batches; b++) {
int start = b * batch_size;
int end = start + batch_size;
if (end > count) end = count;
batch_tasks[b].paths = path_array;
batch_tasks[b].out_hashes = hash_ptrs;
batch_tasks[b].buffer_size = hasher->buffer_size;
batch_tasks[b].start_idx = start;
batch_tasks[b].end_idx = end;
batch_tasks[b].success = &all_success;
}
// Enqueue batch tasks (one per thread, not one per file)
for (int b = 0; b < num_batches; b++) {
hasher->pool->enqueue([batch_tasks, b, &completed_batches]() {
batch_hash_worker(&batch_tasks[b]);
completed_batches.fetch_add(1);
});
}
// Wait for all batches to complete
while (completed_batches.load() < num_batches) {
std::this_thread::yield();
}
// Check for errors
if (!all_success.load()) {
delete[] batch_tasks;
return -1;
}
// Combine hashes deterministically (same order as paths)
Sha256State st;
sha256_init(&st);
for (int i = 0; i < count; i++) {
@ -105,6 +176,7 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_
}
out_hash[64] = '\0';
delete[] batch_tasks;
return 0;
}
@ -122,12 +194,55 @@ int parallel_hash_directory_batch(
int count = collect_files(path, out_paths, (int)max_results);
if (out_count) *out_count = (uint32_t)count;
// Hash each file
for (int i = 0; i < count; i++) {
if (hash_file(out_paths ? out_paths[i] : nullptr, hasher->buffer_size, out_hashes[i]) != 0) {
out_hashes[i][0] = '\0';
}
if (count == 0) {
return 0;
}
return 0;
// Convert out_paths to const char** for batch task
const char* path_array[256];
for (int i = 0; i < count; i++) {
path_array[i] = out_paths ? out_paths[i] : nullptr;
}
// Parallel hash all files using ThreadPool with batched tasks
std::atomic<bool> all_success{true};
std::atomic<int> completed_batches{0};
// Determine batch size
uint32_t num_threads = ThreadPool::default_thread_count();
int batch_size = (count + num_threads - 1) / num_threads;
if (batch_size < 1) batch_size = 1;
int num_batches = (count + batch_size - 1) / batch_size;
// Allocate batch tasks
BatchHashTask* batch_tasks = new BatchHashTask[num_batches];
for (int b = 0; b < num_batches; b++) {
int start = b * batch_size;
int end = start + batch_size;
if (end > count) end = count;
batch_tasks[b].paths = path_array;
batch_tasks[b].out_hashes = out_hashes;
batch_tasks[b].buffer_size = hasher->buffer_size;
batch_tasks[b].start_idx = start;
batch_tasks[b].end_idx = end;
batch_tasks[b].success = &all_success;
}
// Enqueue batch tasks
for (int b = 0; b < num_batches; b++) {
hasher->pool->enqueue([batch_tasks, b, &completed_batches]() {
batch_hash_worker(&batch_tasks[b]);
completed_batches.fetch_add(1);
});
}
// Wait for all batches to complete
while (completed_batches.load() < num_batches) {
std::this_thread::yield();
}
delete[] batch_tasks;
return all_success.load() ? 0 : -1;
}

View file

@ -0,0 +1,286 @@
// Simple test suite for dataset_hash library (no external dependencies)
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cassert>
#include <string>
#include <vector>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <iostream>
#include "dataset_hash/dataset_hash.h"
namespace fs = std::filesystem;
using namespace std::chrono;
// Simple test macros
// On failure each macro prints the failing expression and line to stderr,
// then makes the *enclosing function* return 1 (counted as one failed test
// by main). They may therefore only be used inside functions returning int,
// and any resources acquired before the assertion are leaked on failure.
#define TEST_ASSERT(cond) \
    do { \
        if (!(cond)) { \
            fprintf(stderr, "ASSERTION FAILED: %s at line %d\n", #cond, __LINE__); \
            return 1; \
        } \
    } while(0)
#define TEST_ASSERT_EQ(a, b) TEST_ASSERT((a) == (b))
#define TEST_ASSERT_NE(a, b) TEST_ASSERT((a) != (b))
#define TEST_ASSERT_STR_EQ(a, b) TEST_ASSERT(strcmp((a), (b)) == 0)
// Helper functions
// Create a uniquely named temporary directory for a test and return its path.
//
// BUG FIX: fs::create_directories does not expand the mkdtemp-style "XXXXXX"
// template, so every call previously returned the same literal
// ".../dataset_hash_test_XXXXXX" path — concurrent or repeated test runs
// would share (and later delete) each other's directory. mkdtemp(3) expands
// the template and atomically creates a unique directory.
fs::path create_temp_dir() {
    std::string templ = (fs::temp_directory_path() / "dataset_hash_test_XXXXXX").string();
    std::vector<char> buf(templ.begin(), templ.end());
    buf.push_back('\0');
    if (mkdtemp(buf.data()) == nullptr) {
        // Fall back to the old (non-unique) behavior rather than crash, so
        // the suite can still run single-instance on exotic platforms.
        fs::path fallback(templ);
        fs::create_directories(fallback);
        return fallback;
    }
    return fs::path(buf.data());
}
// Recursively delete a test directory created by create_temp_dir,
// including all files the test wrote into it.
void cleanup_temp_dir(const fs::path& dir) {
    fs::remove_all(dir);
}
// Write `content` to a file named `name` inside `dir`, creating it or
// truncating an existing one. The stream closes automatically when it
// goes out of scope (RAII), flushing the content.
void create_test_file(const fs::path& dir, const std::string& name, const std::string& content) {
    std::ofstream out(dir / name);
    out << content;
}
// Test 1: Context creation
// Verifies fh_init succeeds both with auto-detected threads (0) and an
// explicit count, and that each context can be cleaned up.
int test_context_creation() {
    printf("Testing context creation...\n");
    // Auto-detect threads
    fh_context_t* ctx = fh_init(0);
    TEST_ASSERT_NE(ctx, nullptr);
    fh_cleanup(ctx);
    // Specific thread count
    ctx = fh_init(4);
    TEST_ASSERT_NE(ctx, nullptr);
    fh_cleanup(ctx);
    printf(" PASSED\n");
    return 0;
}
// Test 2: SIMD detection
// Exercises the SIMD capability queries. Passes whether or not SIMD is
// available; only requires a non-empty implementation name.
// NOTE(review): impl_name is printed with %s *before* the nullptr assertion
// below — printf("%s", nullptr) is undefined behavior; consider asserting
// first.
int test_simd_detection() {
    printf("Testing SIMD detection...\n");
    int has_simd = fh_has_simd_sha256();
    const char* impl_name = fh_get_simd_impl_name();
    printf(" SIMD available: %s\n", has_simd ? "yes" : "no");
    printf(" Implementation: %s\n", impl_name);
    TEST_ASSERT_NE(impl_name, nullptr);
    TEST_ASSERT(strlen(impl_name) > 0);
    printf(" PASSED\n");
    return 0;
}
// Test 3: Hash single file
// Verifies fh_hash_file returns a 64-character hex digest and is
// deterministic for identical input.
// NOTE(review): an assertion failure returns early and leaks ctx and the
// temp directory — acceptable for a test binary, but worth knowing.
int test_hash_single_file() {
    printf("Testing single file hash...\n");
    fs::path temp = create_temp_dir();
    fh_context_t* ctx = fh_init(1);
    TEST_ASSERT_NE(ctx, nullptr);
    // Create test file
    create_test_file(temp, "test.txt", "Hello, World!");
    // Hash it
    char* hash = fh_hash_file(ctx, (temp / "test.txt").string().c_str());
    TEST_ASSERT_NE(hash, nullptr);
    // Verify hash format (64 hex characters + null)
    TEST_ASSERT_EQ(strlen(hash), 64);
    // Hash should be deterministic
    char* hash2 = fh_hash_file(ctx, (temp / "test.txt").string().c_str());
    TEST_ASSERT_NE(hash2, nullptr);
    TEST_ASSERT_STR_EQ(hash, hash2);
    fh_free_string(hash);
    fh_free_string(hash2);
    fh_cleanup(ctx);
    cleanup_temp_dir(temp);
    printf(" PASSED\n");
    return 0;
}
// Test 4: Hash empty file (known hash)
// Pins the digest of an empty file to the well-known SHA-256 of the empty
// input, catching padding/finalization bugs in any SIMD implementation.
int test_hash_empty_file() {
    printf("Testing empty file hash...\n");
    fs::path temp = create_temp_dir();
    fh_context_t* ctx = fh_init(1);
    TEST_ASSERT_NE(ctx, nullptr);
    // Create empty file
    create_test_file(temp, "empty.txt", "");
    char* hash = fh_hash_file(ctx, (temp / "empty.txt").string().c_str());
    TEST_ASSERT_NE(hash, nullptr);
    TEST_ASSERT_EQ(strlen(hash), 64);
    // Debug: print actual hash
    printf(" Empty file hash: %s\n", hash);
    // Known SHA-256 of empty string
    const char* expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
    printf(" Expected hash: %s\n", expected);
    TEST_ASSERT_STR_EQ(hash, expected);
    fh_free_string(hash);
    fh_cleanup(ctx);
    cleanup_temp_dir(temp);
    printf(" PASSED\n");
    return 0;
}
// Test 5: Hash directory
// Hashes a small tree (one root file, two files in a subdirectory) and
// checks the result is a well-formed, deterministic 64-char hex digest.
// The subdirectory is created before files are written into it.
int test_hash_directory() {
    printf("Testing directory hash...\n");
    fs::path temp = create_temp_dir();
    fh_context_t* ctx = fh_init(4);
    TEST_ASSERT_NE(ctx, nullptr);
    // Create directory structure
    create_test_file(temp, "root.txt", "root");
    fs::create_directories(temp / "subdir");
    create_test_file(temp, "subdir/file1.txt", "file1");
    create_test_file(temp, "subdir/file2.txt", "file2");
    // Hash directory
    char* hash = fh_hash_directory(ctx, temp.string().c_str());
    TEST_ASSERT_NE(hash, nullptr);
    TEST_ASSERT_EQ(strlen(hash), 64);
    // Hash should be deterministic
    char* hash2 = fh_hash_directory(ctx, temp.string().c_str());
    TEST_ASSERT_NE(hash2, nullptr);
    TEST_ASSERT_STR_EQ(hash, hash2);
    fh_free_string(hash);
    fh_free_string(hash2);
    fh_cleanup(ctx);
    cleanup_temp_dir(temp);
    printf(" PASSED\n");
    return 0;
}
// Test 6: Batch hash
// Hashes 10 files via fh_hash_batch and verifies every result is a
// well-formed 64-char digest owned by the caller (freed with
// fh_free_string).
int test_batch_hash() {
    printf("Testing batch hash...\n");
    fs::path temp = create_temp_dir();
    fh_context_t* ctx = fh_init(4);
    TEST_ASSERT_NE(ctx, nullptr);
    // Create test files
    const int num_files = 10;
    std::vector<std::string> paths;
    std::vector<const char*> c_paths;
    // BUG FIX: reserve up front. c_paths stores c_str() pointers taken while
    // `paths` is still growing; if a push_back reallocated and a path fit in
    // the small-string buffer, the stored pointer would dangle (the SSO
    // buffer lives inside the moved string object). Reserving removes the
    // reallocation entirely.
    paths.reserve(num_files);
    c_paths.reserve(num_files);
    for (int i = 0; i < num_files; i++) {
        std::string name = "file_" + std::to_string(i) + ".txt";
        create_test_file(temp, name, "Content " + std::to_string(i));
        paths.push_back((temp / name).string());
        c_paths.push_back(paths.back().c_str());
    }
    // Hash batch
    std::vector<char*> results(num_files, nullptr);
    int ret = fh_hash_batch(ctx, c_paths.data(), num_files, results.data());
    TEST_ASSERT_EQ(ret, 0);
    // Verify all hashes
    for (int i = 0; i < num_files; i++) {
        TEST_ASSERT_NE(results[i], nullptr);
        TEST_ASSERT_EQ(strlen(results[i]), 64);
        fh_free_string(results[i]);
    }
    fh_cleanup(ctx);
    cleanup_temp_dir(temp);
    printf(" PASSED\n");
    return 0;
}
// Test 7: Performance test
// Creates 1000 small files, hashes the directory, and reports timings.
// This is a smoke/performance check — it has no timing assertions.
int test_performance() {
    printf("Testing performance...\n");
    fs::path temp = create_temp_dir();
    fh_context_t* ctx = fh_init(4);
    TEST_ASSERT_NE(ctx, nullptr);
    // Create 1000 small files
    const int num_files = 1000;
    auto start = high_resolution_clock::now();
    for (int i = 0; i < num_files; i++) {
        create_test_file(temp, "perf_" + std::to_string(i) + ".txt", "content");
    }
    auto create_end = high_resolution_clock::now();
    // Hash all files
    char* hash = fh_hash_directory(ctx, temp.string().c_str());
    TEST_ASSERT_NE(hash, nullptr);
    auto hash_end = high_resolution_clock::now();
    auto create_time = duration_cast<milliseconds>(create_end - start);
    auto hash_time = duration_cast<milliseconds>(hash_end - create_end);
    // BUG FIX: milliseconds::rep is implementation-defined (may be long, not
    // long long), so passing count() straight to %lld is UB on some
    // platforms — cast explicitly.
    printf(" Created %d files in %lld ms\n", num_files, (long long)create_time.count());
    printf(" Hashed %d files in %lld ms\n", num_files, (long long)hash_time.count());
    // BUG FIX: guard against division by zero when hashing completes in
    // under one millisecond.
    if (hash_time.count() > 0) {
        printf(" Throughput: %.1f files/sec\n", num_files * 1000.0 / hash_time.count());
    } else {
        printf(" Throughput: >%.1f files/sec (completed in <1 ms)\n", num_files * 1000.0);
    }
    fh_free_string(hash);
    fh_cleanup(ctx);
    cleanup_temp_dir(temp);
    printf(" PASSED\n");
    return 0;
}
// Main test runner.
// Runs every test in declaration order via a function-pointer table; each
// test returns 0 on success or 1 on failure, and the failure count decides
// the process exit code.
int main() {
    printf("\n=== Dataset Hash Library Test Suite ===\n\n");
    int (*const tests[])() = {
        test_context_creation,
        test_simd_detection,
        test_hash_single_file,
        test_hash_empty_file,
        test_hash_directory,
        test_batch_hash,
        test_performance,
    };
    int failed = 0;
    for (auto test : tests) {
        failed += test();
    }
    printf("\n=== Test Results ===\n");
    if (failed == 0) {
        printf("All tests PASSED!\n");
        return 0;
    }
    printf("%d test(s) FAILED\n", failed);
    return 1;
}

View file

@ -0,0 +1,75 @@
package benchmarks
import (
"os"
"path/filepath"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
// BenchmarkDatasetSizeComparison finds the crossover point where native wins.
// Each size case runs the same synthetic dataset through the pure-Go
// parallel implementation and the worker (native-capable) implementation.
//
// Run with: go test -tags native_libs -bench=BenchmarkDatasetSize ./tests/benchmarks/
// NOTE(review): the original instructions also set FETCHML_NATIVE_LIBS=1,
// but this change appears to replace the env toggle with the native_libs
// build tag — confirm the tag alone is sufficient to exercise the native
// path.
func BenchmarkDatasetSizeComparison(b *testing.B) {
	sizes := []struct {
		name     string
		fileSize int
		numFiles int
		totalMB  int
	}{
		{"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB
		{"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB
		{"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB
		{"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB
		{"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB
	}
	for _, tc := range sizes {
		tc := tc // capture range variable
		b.Run(tc.name+"/GoParallel", func(b *testing.B) {
			// Fixture creation happens before ResetTimer so only the
			// hashing itself is measured.
			tmpDir := b.TempDir()
			createTestFiles(b, tmpDir, tc.numFiles, tc.fileSize)
			b.ResetTimer()
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, err := integrity.DirOverallSHA256HexParallel(tmpDir)
				if err != nil {
					b.Fatal(err)
				}
			}
		})
		b.Run(tc.name+"/Native", func(b *testing.B) {
			tmpDir := b.TempDir()
			createTestFiles(b, tmpDir, tc.numFiles, tc.fileSize)
			b.ResetTimer()
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, err := worker.DirOverallSHA256HexParallel(tmpDir)
				if err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
// createTestFiles populates dir with numFiles files of fileSize bytes each,
// spread across per-file subdirectories under dir/data. Each file's content
// is the same repeating 0..255 byte pattern.
//
// BUG FIX: the previous version derived the path only from i%26 with a fixed
// "chunk.bin" name, so runs with more than 26 files silently overwrote
// earlier ones and produced a smaller dataset than requested (e.g. the 5GB
// benchmark case actually wrote only 26 x 100MB). Each file now gets a
// unique two-character subdirectory (supports up to 260 files) without
// needing fmt/strconv.
func createTestFiles(b *testing.B, dir string, numFiles int, fileSize int) {
	data := make([]byte, fileSize)
	for i := range data {
		data[i] = byte(i % 256)
	}
	for i := 0; i < numFiles; i++ {
		sub := string([]byte{byte('a' + i%26), byte('0' + (i/26)%10)})
		path := filepath.Join(dir, "data", sub, "chunk.bin")
		if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil {
			b.Fatal(err)
		}
		if err := os.WriteFile(path, data, 0640); err != nil {
			b.Fatal(err)
		}
	}
}