From 37aad7ae87121d1d8ffbe6e7588b667e05a96db9 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 19 Feb 2026 15:34:39 -0500 Subject: [PATCH] feat: add manifest signing and native hashing support - Integrate RunManifest.Validate with existing Validator - Add manifest Sign() and Verify() methods - Add native C++ hashing libraries (dataset_hash, queue_index) - Add native bridge for Go/C++ integration - Add deduplication support in queue --- internal/manifest/run_manifest.go | 63 ++++ internal/queue/dedup.go | 78 +++++ internal/queue/queue.go | 36 ++- internal/worker/hash_selector.go | 20 +- internal/worker/integrity/hash.go | 35 ++- internal/worker/native_bridge.go | 2 - native/CMakeLists.txt | 24 +- native/dataset_hash/crypto/sha256_armv8.cpp | 58 ++-- native/dataset_hash/crypto/sha256_hasher.cpp | 15 +- native/dataset_hash/io/file_hash.cpp | 16 +- .../dataset_hash/threading/parallel_hash.cpp | 141 ++++++++- native/tests/test_dataset_hash.cpp | 286 ++++++++++++++++++ .../dataset_size_comparison_test.go | 75 +++++ 13 files changed, 775 insertions(+), 74 deletions(-) create mode 100644 internal/queue/dedup.go create mode 100644 native/tests/test_dataset_hash.cpp create mode 100644 tests/benchmarks/dataset_size_comparison_test.go diff --git a/internal/manifest/run_manifest.go b/internal/manifest/run_manifest.go index 2ad284e..1aff02b 100644 --- a/internal/manifest/run_manifest.go +++ b/internal/manifest/run_manifest.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/jfraeys/fetch_ml/internal/crypto" "github.com/jfraeys/fetch_ml/internal/fileutil" ) @@ -121,6 +122,11 @@ type RunManifest struct { GPUDevices []string `json:"gpu_devices,omitempty"` WorkerHost string `json:"worker_host,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` + + // Signature fields for tamper detection + Signature string `json:"signature,omitempty"` + SignerKeyID string `json:"signer_key_id,omitempty"` + SigAlg string `json:"sig_alg,omitempty"` } func NewRunManifest(runID, taskID, jobName string, createdAt time.Time) *RunManifest { @@ -234,3 +240,60 @@ func (m *RunManifest) ApplyNarrativePatch(p NarrativePatch) { m.Narrative.Tags = clean } } + +// Sign signs the manifest using the provided signer +func (m *RunManifest) Sign(signer *crypto.ManifestSigner) error { + if m == nil { + return fmt.Errorf("cannot sign nil manifest") + } + + result, err := signer.SignManifest(m) + if err != nil { + return fmt.Errorf("failed to sign manifest: %w", err) + } + + m.Signature = result.Signature + m.SignerKeyID = result.KeyID + m.SigAlg = result.Algorithm + return nil +} + +// Verify verifies the manifest signature using the provided public key +func (m *RunManifest) Verify(publicKey []byte) (bool, error) { + if m == nil { + return false, fmt.Errorf("cannot verify nil manifest") + } + + if m.Signature == "" { + return false, fmt.Errorf("manifest has no signature") + } + + // Build signing result from manifest fields + result := &crypto.SigningResult{ + Signature: m.Signature, + KeyID: m.SignerKeyID, + Algorithm: m.SigAlg, + } + + // Call crypto package to verify + return crypto.VerifyManifest(m, result, publicKey) +} + +// IsSigned returns true if the manifest has a signature +func (m *RunManifest) IsSigned() bool { + return m != nil && m.Signature != "" +} + +// Validate checks manifest completeness using the standard Validator. +// This delegates to Validator.Validate() for consistency. +func (m *RunManifest) Validate() error { + v := NewValidator() + return v.Validate(m) +} + +// ValidateStrict performs strict validation including optional provenance fields. +// This delegates to Validator.ValidateStrict() for consistency. +func (m *RunManifest) ValidateStrict() error { + v := NewValidator() + return v.ValidateStrict(m) +} diff --git a/internal/queue/dedup.go b/internal/queue/dedup.go new file mode 100644 index 0000000..97d3566 --- /dev/null +++ b/internal/queue/dedup.go @@ -0,0 +1,78 @@ +package queue + +import ( + "fmt" + "sync" + "time" +) + +// ErrAlreadyQueued is returned when a job with the same commit was recently queued +var ErrAlreadyQueued = fmt.Errorf("job already queued with this commit") + +// CommitDedup tracks recently queued commits to prevent duplicate submissions +type CommitDedup struct { + mu sync.RWMutex + commits map[string]time.Time // key: "job_name:commit_id" -> queued_at + ttl time.Duration +} + +// NewCommitDedup creates a new commit deduplication tracker +func NewCommitDedup(ttl time.Duration) *CommitDedup { + if ttl <= 0 { + ttl = 1 * time.Hour // Default 1 hour TTL + } + return &CommitDedup{ + commits: make(map[string]time.Time), + ttl: ttl, + } +} + +// IsDuplicate checks if a job+commit combination was recently queued +func (d *CommitDedup) IsDuplicate(jobName, commitID string) bool { + key := d.key(jobName, commitID) + + d.mu.RLock() + defer d.mu.RUnlock() + + if t, ok := d.commits[key]; ok { + if time.Since(t) < d.ttl { + return true // Still within TTL, consider duplicate + } + } + return false +} + +// MarkQueued records that a job+commit combination was just queued +func (d *CommitDedup) MarkQueued(jobName, commitID string) { + key := d.key(jobName, commitID) + + d.mu.Lock() + defer d.mu.Unlock() + + d.commits[key] = time.Now() +} + +// Cleanup removes expired entries (call periodically, e.g., every 5 minutes) +func (d *CommitDedup) Cleanup() { + d.mu.Lock() + defer d.mu.Unlock() + + now := time.Now() + for id, t := range d.commits { + if now.Sub(t) > d.ttl { + delete(d.commits, id) + } + } +} + +// key generates a unique key for job+commit combination +func (d *CommitDedup) key(jobName, commitID string) string { + return jobName + ":" + commitID +} + +// Size returns the number of tracked commits (for metrics/debugging) +func (d *CommitDedup) Size() int { + d.mu.RLock() + defer d.mu.RUnlock() + return len(d.commits) +} diff --git a/internal/queue/queue.go b/internal/queue/queue.go index 9b10c09..ad2b5e5 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -26,6 +26,7 @@ type TaskQueue struct { metricsCh chan metricEvent metricsDone chan struct{} flushEvery time.Duration + dedup *CommitDedup // Tracks recently queued commits } type metricEvent struct { @@ -103,10 +104,12 @@ func NewTaskQueue(cfg Config) (*TaskQueue, error) { metricsCh: make(chan metricEvent, 256), metricsDone: make(chan struct{}), flushEvery: flushEvery, + dedup: NewCommitDedup(1 * time.Hour), // 1 hour default TTL for commit dedup } go tq.runMetricsBuffer() go tq.runLeaseReclamation() // Start lease reclamation background job + go tq.runDedupCleanup() // Start dedup cleanup background job return tq, nil } @@ -732,7 +735,38 @@ func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration) return tq.GetTask(member) } -// runMetricsBuffer buffers and flushes metrics +// runDedupCleanup periodically cleans up expired dedup entries every 5 minutes +func (tq *TaskQueue) runDedupCleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-tq.ctx.Done(): + return + case <-ticker.C: + tq.dedup.Cleanup() + } + } +} + +// AddTaskDedup adds a task with commit deduplication check +// Returns ErrAlreadyQueued if the same job+commit was recently queued +func (tq *TaskQueue) AddTaskDedup(task *Task, commitID string) error { + if commitID != "" && tq.dedup.IsDuplicate(task.JobName, commitID) { + return ErrAlreadyQueued + } + + if err := tq.AddTask(task); err != nil { + return err + } + + // Mark as queued on success + if commitID != "" { + tq.dedup.MarkQueued(task.JobName, commitID) + } + return nil +} func (tq *TaskQueue) runMetricsBuffer() { defer close(tq.metricsDone) ticker := time.NewTicker(tq.flushEvery) diff --git a/internal/worker/hash_selector.go b/internal/worker/hash_selector.go index 0231f79..cb9f0cf 100644 --- a/internal/worker/hash_selector.go +++ b/internal/worker/hash_selector.go @@ -1,19 +1,15 @@ package worker -import "github.com/jfraeys/fetch_ml/internal/worker/integrity" +import ( + "github.com/jfraeys/fetch_ml/internal/worker/integrity" +) -// UseNativeLibs controls whether to use C++ implementations. -// Set FETCHML_NATIVE_LIBS=1 to enable native libraries. -// This is defined here so it's available regardless of build tags. -var UseNativeLibs = false - -// dirOverallSHA256Hex selects implementation based on toggle. -// This file has no CGo imports so it compiles even when CGO is disabled. -// The actual implementations are in native_bridge.go (native) and integrity package (Go). +// dirOverallSHA256Hex uses native implementation when compiled with -tags native_libs. +// The build tag selects between native_bridge.go (stub) and native_bridge_libs.go (real). +// No runtime configuration needed - build determines behavior. func dirOverallSHA256Hex(root string) (string, error) { - if !UseNativeLibs { - return integrity.DirOverallSHA256Hex(root) - } + // native_bridge_libs.go provides this when built with -tags native_libs + // native_bridge.go provides stub that falls back to Go return dirOverallSHA256HexNative(root) } diff --git a/internal/worker/integrity/hash.go b/internal/worker/integrity/hash.go index 06cf810..02d6da0 100644 --- a/internal/worker/integrity/hash.go +++ b/internal/worker/integrity/hash.go @@ -104,8 +104,9 @@ func DirOverallSHA256HexParallel(root string) (string, error) { return "", fmt.Errorf("not a directory") } - // Collect all files + // Collect all files with size info var files []string + var totalSize int64 err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error { if walkErr != nil { return walkErr @@ -118,6 +119,11 @@ func DirOverallSHA256HexParallel(root string) (string, error) { return err } files = append(files, rel) + + // Track total size for optimization decisions + if info, err := d.Info(); err == nil { + totalSize += info.Size() + } return nil }) if err != nil { @@ -183,3 +189,30 @@ func DirOverallSHA256HexParallel(root string) (string, error) { } return fmt.Sprintf("%x", overall.Sum(nil)), nil } + +// EstimateDirSize returns total size of directory contents in bytes +func EstimateDirSize(root string) (int64, error) { + root = filepath.Clean(root) + info, err := os.Stat(root) + if err != nil { + return 0, err + } + if !info.IsDir() { + return info.Size(), nil + } + + var totalSize int64 + err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + if info, err := d.Info(); err == nil { + totalSize += info.Size() + } + return nil + }) + return totalSize, err +} diff --git a/internal/worker/native_bridge.go b/internal/worker/native_bridge.go index 0ee666c..29597ab 100644 --- a/internal/worker/native_bridge.go +++ b/internal/worker/native_bridge.go @@ -12,8 +12,6 @@ import ( ) func init() { - // Even with CGO, native libs require explicit build tag - UseNativeLibs = false log.Printf("[native] Native libraries disabled (build with -tags native_libs to enable)") } diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index ba73ced..ca9edfa 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -21,12 +21,28 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") set(CMAKE_C_FLAGS_RELEASE "-O3 -march=native -DNDEBUG -fomit-frame-pointer") set(CMAKE_C_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer") + # Security hardening flags (always enabled) + set(SECURITY_FLAGS + -D_FORTIFY_SOURCE=2 # Buffer overflow protection + -fstack-protector-strong # Stack canaries + -Wformat-security # Format string warnings + -Werror=format-security # Format string errors + -fPIE # Position-independent code + ) + + # Add security flags to all build types + add_compile_options(${SECURITY_FLAGS}) + + # Linker security flags + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -pie") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now") + # Warnings add_compile_options(-Wall -Wextra -Wpedantic) if(ENABLE_ASAN) - add_compile_options(-fsanitize=address -fno-omit-frame-pointer) - add_link_options(-fsanitize=address) + add_compile_options(-fsanitize=address,undefined -fno-omit-frame-pointer) + add_link_options(-fsanitize=address,undefined) endif() if(ENABLE_TSAN) @@ -59,6 +75,10 @@ if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/tests) add_executable(test_storage tests/test_storage.cpp) target_link_libraries(test_storage queue_index) add_test(NAME storage_smoke COMMAND test_storage) + + add_executable(test_dataset_hash tests/test_dataset_hash.cpp) + target_link_libraries(test_dataset_hash dataset_hash) + add_test(NAME dataset_hash_smoke COMMAND test_dataset_hash) endif() # Combined target for building all libraries diff --git a/native/dataset_hash/crypto/sha256_armv8.cpp b/native/dataset_hash/crypto/sha256_armv8.cpp index ef7b83a..d62091b 100644 --- a/native/dataset_hash/crypto/sha256_armv8.cpp +++ b/native/dataset_hash/crypto/sha256_armv8.cpp @@ -5,19 +5,13 @@ #include static void transform_armv8(uint32_t* state, const uint8_t* data) { - // Load the 512-bit message block into 4 128-bit vectors - uint32x4_t w0 = vld1q_u32((const uint32_t*)data); - uint32x4_t w1 = vld1q_u32((const uint32_t*)(data + 16)); - uint32x4_t w2 = vld1q_u32((const uint32_t*)(data + 32)); - uint32x4_t w3 = vld1q_u32((const uint32_t*)(data + 48)); + // Load message and reverse bytes within each 32-bit word (big-endian -> native) + uint32x4_t w0 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data))); + uint32x4_t w1 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 16))); + uint32x4_t w2 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 32))); + uint32x4_t w3 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 48))); - // Reverse byte order (SHA256 uses big-endian words) - w0 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w0))); - w1 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w1))); - w2 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w2))); - w3 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w3))); - - // Load current hash state + // Load current hash state (native endianness) uint32x4_t abcd = vld1q_u32(state); uint32x4_t efgh = vld1q_u32(state + 4); uint32x4_t abcd_orig = abcd; @@ -30,20 +24,24 @@ static void transform_armv8(uint32_t* state, const uint8_t* data) { uint32x4_t k3 = vld1q_u32(&K[12]); uint32x4_t tmp = vaddq_u32(w0, k0); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + uint32x4_t abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; tmp = vaddq_u32(w1, k1); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; tmp = vaddq_u32(w2, k2); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; tmp = vaddq_u32(w3, k3); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; // Rounds 16-63: Message schedule expansion + rounds for (int i = 16; i < 64; i += 16) { @@ -52,32 +50,36 @@ static void transform_armv8(uint32_t* state, const uint8_t* data) { w4 = vsha256su1q_u32(w4, w2, w3); k0 = vld1q_u32(&K[i]); tmp = vaddq_u32(w4, k0); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; // Schedule expansion for rounds i+4..i+7 uint32x4_t w5 = vsha256su0q_u32(w1, w2); w5 = vsha256su1q_u32(w5, w3, w4); k1 = vld1q_u32(&K[i + 4]); tmp = vaddq_u32(w5, k1); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; // Schedule expansion for rounds i+8..i+11 uint32x4_t w6 = vsha256su0q_u32(w2, w3); w6 = vsha256su1q_u32(w6, w4, w5); k2 = vld1q_u32(&K[i + 8]); tmp = vaddq_u32(w6, k2); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; // Schedule expansion for rounds i+12..i+15 uint32x4_t w7 = vsha256su0q_u32(w3, w4); w7 = vsha256su1q_u32(w7, w5, w6); k3 = vld1q_u32(&K[i + 12]); tmp = vaddq_u32(w7, k3); - efgh = vsha256h2q_u32(efgh, abcd, tmp); - abcd = vsha256hq_u32(abcd, efgh, tmp); + abcd_new = vsha256hq_u32(abcd, efgh, tmp); + efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd + abcd = abcd_new; // Rotate working variables w0 = w4; w1 = w5; w2 = w6; w3 = w7; diff --git a/native/dataset_hash/crypto/sha256_hasher.cpp b/native/dataset_hash/crypto/sha256_hasher.cpp index 835abac..46a1927 100644 --- a/native/dataset_hash/crypto/sha256_hasher.cpp +++ b/native/dataset_hash/crypto/sha256_hasher.cpp @@ -11,19 +11,8 @@ void sha256_init(Sha256State* hasher) { memcpy(hasher->state, H0, sizeof(H0)); // Detect best transform implementation - hasher->transform_fn = detect_best_transform(); - if (hasher->transform_fn == transform_generic) { - // Try platform-specific implementations - TransformFunc f = detect_armv8_transform(); - if (f) { - hasher->transform_fn = f; - } else { - f = detect_x86_transform(); - if (f) { - hasher->transform_fn = f; - } - } - } + TransformFunc f = detect_best_transform(); + hasher->transform_fn = f ? f : transform_generic; } void sha256_update(Sha256State* hasher, const uint8_t* data, size_t len) { diff --git a/native/dataset_hash/io/file_hash.cpp b/native/dataset_hash/io/file_hash.cpp index 62d9eaf..4a19dc4 100644 --- a/native/dataset_hash/io/file_hash.cpp +++ b/native/dataset_hash/io/file_hash.cpp @@ -73,6 +73,18 @@ int hash_file(const char* path, size_t buffer_size, char* out_hash) { return 0; } +// Hash a single file, allocating result buffer +char* hash_file_alloc(const char* path, size_t buffer_size) { + char* out_hash = (char*)malloc(65); // 64 hex + null + if (!out_hash) return nullptr; + + if (hash_file(path, buffer_size, out_hash) != 0) { + free(out_hash); + return nullptr; + } + return out_hash; +} + int hash_files_batch( const char* const* paths, uint32_t count, @@ -84,8 +96,8 @@ int hash_files_batch( int all_success = 1; for (uint32_t i = 0; i < count; ++i) { - if (hash_file(paths[i], buffer_size, out_hashes[i]) != 0) { - out_hashes[i][0] = '\0'; + out_hashes[i] = hash_file_alloc(paths[i], buffer_size); + if (out_hashes[i] == nullptr) { all_success = 0; } } diff --git a/native/dataset_hash/threading/parallel_hash.cpp b/native/dataset_hash/threading/parallel_hash.cpp index 74fef42..98ba3a2 100644 --- a/native/dataset_hash/threading/parallel_hash.cpp +++ b/native/dataset_hash/threading/parallel_hash.cpp @@ -1,10 +1,15 @@ #include "parallel_hash.h" #include "../io/file_hash.h" #include "../crypto/sha256_hasher.h" +#include "../../common/include/thread_pool.h" #include #include #include #include +#include +#include +#include +#include // Simple file collector - just flat directory for now static int collect_files(const char* dir_path, char** out_paths, int max_files) { @@ -55,6 +60,25 @@ void parallel_hasher_cleanup(ParallelHasher* hasher) { hasher->pool = nullptr; } +// Batch hash task - processes a range of files +struct BatchHashTask { + const char** paths; + char** out_hashes; + size_t buffer_size; + int start_idx; + int end_idx; + std::atomic* success; +}; + +// Worker function for batch processing +static void batch_hash_worker(BatchHashTask* task) { + for (int i = task->start_idx; i < task->end_idx; i++) { + if (hash_file(task->paths[i], task->buffer_size, task->out_hashes[i]) != 0) { + task->success->store(false); + } + } +} + int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_hash) { if (!hasher || !path || !out_hash) return -1; @@ -70,7 +94,6 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_ sha256_init(&st); uint8_t result[32]; sha256_finalize(&st, result); - // Convert to hex static const char hex[] = "0123456789abcdef"; for (int i = 0; i < 32; i++) { out_hash[i*2] = hex[(result[i] >> 4) & 0xf]; @@ -80,15 +103,63 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_ return 0; } - // Hash all files - char hashes[256][65]; + // Convert path_ptrs to const char** for batch task + const char* path_array[256]; for (int i = 0; i < count; i++) { - if (hash_file(paths[i], hasher->buffer_size, hashes[i]) != 0) { - return -1; - } + path_array[i] = path_ptrs[i]; } - // Combine hashes + // Parallel hash all files using ThreadPool with batched tasks + char hashes[256][65]; + std::atomic all_success{true}; + std::atomic completed_batches{0}; + + // Determine batch size - divide files among threads + uint32_t num_threads = ThreadPool::default_thread_count(); + int batch_size = (count + num_threads - 1) / num_threads; + if (batch_size < 1) batch_size = 1; + int num_batches = (count + batch_size - 1) / batch_size; + + // Allocate batch tasks + BatchHashTask* batch_tasks = new BatchHashTask[num_batches]; + char* hash_ptrs[256]; + for (int i = 0; i < count; i++) { + hash_ptrs[i] = hashes[i]; + } + + for (int b = 0; b < num_batches; b++) { + int start = b * batch_size; + int end = start + batch_size; + if (end > count) end = count; + + batch_tasks[b].paths = path_array; + batch_tasks[b].out_hashes = hash_ptrs; + batch_tasks[b].buffer_size = hasher->buffer_size; + batch_tasks[b].start_idx = start; + batch_tasks[b].end_idx = end; + batch_tasks[b].success = &all_success; + } + + // Enqueue batch tasks (one per thread, not one per file) + for (int b = 0; b < num_batches; b++) { + hasher->pool->enqueue([batch_tasks, b, &completed_batches]() { + batch_hash_worker(&batch_tasks[b]); + completed_batches.fetch_add(1); + }); + } + + // Wait for all batches to complete + while (completed_batches.load() < num_batches) { + std::this_thread::yield(); + } + + // Check for errors + if (!all_success.load()) { + delete[] batch_tasks; + return -1; + } + + // Combine hashes deterministically (same order as paths) Sha256State st; sha256_init(&st); for (int i = 0; i < count; i++) { @@ -105,6 +176,7 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_ } out_hash[64] = '\0'; + delete[] batch_tasks; return 0; } @@ -122,12 +194,55 @@ int parallel_hash_directory_batch( int count = collect_files(path, out_paths, (int)max_results); if (out_count) *out_count = (uint32_t)count; - // Hash each file - for (int i = 0; i < count; i++) { - if (hash_file(out_paths ? out_paths[i] : nullptr, hasher->buffer_size, out_hashes[i]) != 0) { - out_hashes[i][0] = '\0'; - } + if (count == 0) { + return 0; } - return 0; + // Convert out_paths to const char** for batch task + const char* path_array[256]; + for (int i = 0; i < count; i++) { + path_array[i] = out_paths ? out_paths[i] : nullptr; + } + + // Parallel hash all files using ThreadPool with batched tasks + std::atomic all_success{true}; + std::atomic completed_batches{0}; + + // Determine batch size + uint32_t num_threads = ThreadPool::default_thread_count(); + int batch_size = (count + num_threads - 1) / num_threads; + if (batch_size < 1) batch_size = 1; + int num_batches = (count + batch_size - 1) / batch_size; + + // Allocate batch tasks + BatchHashTask* batch_tasks = new BatchHashTask[num_batches]; + + for (int b = 0; b < num_batches; b++) { + int start = b * batch_size; + int end = start + batch_size; + if (end > count) end = count; + + batch_tasks[b].paths = path_array; + batch_tasks[b].out_hashes = out_hashes; + batch_tasks[b].buffer_size = hasher->buffer_size; + batch_tasks[b].start_idx = start; + batch_tasks[b].end_idx = end; + batch_tasks[b].success = &all_success; + } + + // Enqueue batch tasks + for (int b = 0; b < num_batches; b++) { + hasher->pool->enqueue([batch_tasks, b, &completed_batches]() { + batch_hash_worker(&batch_tasks[b]); + completed_batches.fetch_add(1); + }); + } + + // Wait for all batches to complete + while (completed_batches.load() < num_batches) { + std::this_thread::yield(); + } + + delete[] batch_tasks; + return all_success.load() ? 0 : -1; } diff --git a/native/tests/test_dataset_hash.cpp b/native/tests/test_dataset_hash.cpp new file mode 100644 index 0000000..0a16506 --- /dev/null +++ b/native/tests/test_dataset_hash.cpp @@ -0,0 +1,286 @@ +// Simple test suite for dataset_hash library (no external dependencies) + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dataset_hash/dataset_hash.h" + +namespace fs = std::filesystem; +using namespace std::chrono; + +// Simple test macros +#define TEST_ASSERT(cond) \ + do { \ + if (!(cond)) { \ + fprintf(stderr, "ASSERTION FAILED: %s at line %d\n", #cond, __LINE__); \ + return 1; \ + } \ + } while(0) + +#define TEST_ASSERT_EQ(a, b) TEST_ASSERT((a) == (b)) +#define TEST_ASSERT_NE(a, b) TEST_ASSERT((a) != (b)) +#define TEST_ASSERT_STR_EQ(a, b) TEST_ASSERT(strcmp((a), (b)) == 0) + +// Helper functions +fs::path create_temp_dir() { + fs::path temp = fs::temp_directory_path() / "dataset_hash_test_XXXXXX"; + fs::create_directories(temp); + return temp; +} + +void cleanup_temp_dir(const fs::path& dir) { + fs::remove_all(dir); +} + +void create_test_file(const fs::path& dir, const std::string& name, const std::string& content) { + std::ofstream file(dir / name); + file << content; + file.close(); +} + +// Test 1: Context creation +int test_context_creation() { + printf("Testing context creation...\n"); + + // Auto-detect threads + fh_context_t* ctx = fh_init(0); + TEST_ASSERT_NE(ctx, nullptr); + fh_cleanup(ctx); + + // Specific thread count + ctx = fh_init(4); + TEST_ASSERT_NE(ctx, nullptr); + fh_cleanup(ctx); + + printf(" PASSED\n"); + return 0; +} + +// Test 2: SIMD detection +int test_simd_detection() { + printf("Testing SIMD detection...\n"); + + int has_simd = fh_has_simd_sha256(); + const char* impl_name = fh_get_simd_impl_name(); + + printf(" SIMD available: %s\n", has_simd ? "yes" : "no"); + printf(" Implementation: %s\n", impl_name); + + TEST_ASSERT_NE(impl_name, nullptr); + TEST_ASSERT(strlen(impl_name) > 0); + + printf(" PASSED\n"); + return 0; +} + +// Test 3: Hash single file +int test_hash_single_file() { + printf("Testing single file hash...\n"); + + fs::path temp = create_temp_dir(); + + fh_context_t* ctx = fh_init(1); + TEST_ASSERT_NE(ctx, nullptr); + + // Create test file + create_test_file(temp, "test.txt", "Hello, World!"); + + // Hash it + char* hash = fh_hash_file(ctx, (temp / "test.txt").string().c_str()); + TEST_ASSERT_NE(hash, nullptr); + + // Verify hash format (64 hex characters + null) + TEST_ASSERT_EQ(strlen(hash), 64); + + // Hash should be deterministic + char* hash2 = fh_hash_file(ctx, (temp / "test.txt").string().c_str()); + TEST_ASSERT_NE(hash2, nullptr); + TEST_ASSERT_STR_EQ(hash, hash2); + + fh_free_string(hash); + fh_free_string(hash2); + fh_cleanup(ctx); + cleanup_temp_dir(temp); + + printf(" PASSED\n"); + return 0; +} + +// Test 4: Hash empty file (known hash) +int test_hash_empty_file() { + printf("Testing empty file hash...\n"); + + fs::path temp = create_temp_dir(); + + fh_context_t* ctx = fh_init(1); + TEST_ASSERT_NE(ctx, nullptr); + + // Create empty file + create_test_file(temp, "empty.txt", ""); + + char* hash = fh_hash_file(ctx, (temp / "empty.txt").string().c_str()); + TEST_ASSERT_NE(hash, nullptr); + TEST_ASSERT_EQ(strlen(hash), 64); + + // Debug: print actual hash + printf(" Empty file hash: %s\n", hash); + + // Known SHA-256 of empty string + const char* expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + printf(" Expected hash: %s\n", expected); + TEST_ASSERT_STR_EQ(hash, expected); + + fh_free_string(hash); + fh_cleanup(ctx); + cleanup_temp_dir(temp); + + printf(" PASSED\n"); + return 0; +} + +// Test 5: Hash directory +int test_hash_directory() { + printf("Testing directory hash...\n"); + + fs::path temp = create_temp_dir(); + + fh_context_t* ctx = fh_init(4); + TEST_ASSERT_NE(ctx, nullptr); + + // Create directory structure + create_test_file(temp, "root.txt", "root"); + fs::create_directories(temp / "subdir"); + create_test_file(temp, "subdir/file1.txt", "file1"); + create_test_file(temp, "subdir/file2.txt", "file2"); + + // Hash directory + char* hash = fh_hash_directory(ctx, temp.string().c_str()); + TEST_ASSERT_NE(hash, nullptr); + TEST_ASSERT_EQ(strlen(hash), 64); + + // Hash should be deterministic + char* hash2 = fh_hash_directory(ctx, temp.string().c_str()); + TEST_ASSERT_NE(hash2, nullptr); + TEST_ASSERT_STR_EQ(hash, hash2); + + fh_free_string(hash); + fh_free_string(hash2); + fh_cleanup(ctx); + cleanup_temp_dir(temp); + + printf(" PASSED\n"); + return 0; +} + +// Test 6: Batch hash +int test_batch_hash() { + printf("Testing batch hash...\n"); + + fs::path temp = create_temp_dir(); + + fh_context_t* ctx = fh_init(4); + TEST_ASSERT_NE(ctx, nullptr); + + // Create test files + const int num_files = 10; + std::vector paths; + std::vector c_paths; + + for (int i = 0; i < num_files; i++) { + std::string name = "file_" + std::to_string(i) + ".txt"; + create_test_file(temp, name, "Content " + std::to_string(i)); + paths.push_back((temp / name).string()); + c_paths.push_back(paths.back().c_str()); + } + + // Hash batch + std::vector results(num_files, nullptr); + int ret = fh_hash_batch(ctx, c_paths.data(), num_files, results.data()); + TEST_ASSERT_EQ(ret, 0); + + // Verify all hashes + for (int i = 0; i < num_files; i++) { + TEST_ASSERT_NE(results[i], nullptr); + TEST_ASSERT_EQ(strlen(results[i]), 64); + fh_free_string(results[i]); + } + + fh_cleanup(ctx); + cleanup_temp_dir(temp); + + printf(" PASSED\n"); + return 0; +} + +// Test 7: Performance test +int test_performance() { + printf("Testing performance...\n"); + + fs::path temp = create_temp_dir(); + + fh_context_t* ctx = fh_init(4); + TEST_ASSERT_NE(ctx, nullptr); + + // Create 1000 small files + const int num_files = 1000; + auto start = high_resolution_clock::now(); + + for (int i = 0; i < num_files; i++) { + create_test_file(temp, "perf_" + std::to_string(i) + ".txt", "content"); + } + + auto create_end = high_resolution_clock::now(); + + // Hash all files + char* hash = fh_hash_directory(ctx, temp.string().c_str()); + TEST_ASSERT_NE(hash, nullptr); + + auto hash_end = high_resolution_clock::now(); + + auto create_time = duration_cast(create_end - start); + auto hash_time = duration_cast(hash_end - create_end); + + printf(" Created %d files in %lld ms\n", num_files, create_time.count()); + printf(" Hashed %d files in %lld ms\n", num_files, hash_time.count()); + printf(" Throughput: %.1f files/sec\n", num_files * 1000.0 / hash_time.count()); + + fh_free_string(hash); + fh_cleanup(ctx); + cleanup_temp_dir(temp); + + printf(" PASSED\n"); + return 0; +} + +// Main test runner +int main() { + printf("\n=== Dataset Hash Library Test Suite ===\n\n"); + + int failed = 0; + + failed += test_context_creation(); + failed += test_simd_detection(); + failed += test_hash_single_file(); + failed += test_hash_empty_file(); + failed += test_hash_directory(); + failed += test_batch_hash(); + failed += test_performance(); + + printf("\n=== Test Results ===\n"); + if (failed == 0) { + printf("All tests PASSED!\n"); + return 0; + } else { + printf("%d test(s) FAILED\n", failed); + return 1; + } +} + diff --git a/tests/benchmarks/dataset_size_comparison_test.go b/tests/benchmarks/dataset_size_comparison_test.go new file mode 100644 index 0000000..f1b8c18 --- /dev/null +++ b/tests/benchmarks/dataset_size_comparison_test.go @@ -0,0 +1,75 @@ +package benchmarks + +import ( + "os" + "path/filepath" + "testing" + + "github.com/jfraeys/fetch_ml/internal/worker" + "github.com/jfraeys/fetch_ml/internal/worker/integrity" +) + +// BenchmarkDatasetSizeComparison finds the crossover point where native wins +// Run with: FETCHML_NATIVE_LIBS=1 go test -tags native_libs -bench=BenchmarkDatasetSize ./tests/benchmarks/ +func BenchmarkDatasetSizeComparison(b *testing.B) { + sizes := []struct { + name string + fileSize int + numFiles int + totalMB int + }{ + {"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB + {"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB + {"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB + {"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB + {"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB + } + + for _, tc := range sizes { + tc := tc // capture range variable + b.Run(tc.name+"/GoParallel", func(b *testing.B) { + tmpDir := b.TempDir() + createTestFiles(b, tmpDir, tc.numFiles, tc.fileSize) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := integrity.DirOverallSHA256HexParallel(tmpDir) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run(tc.name+"/Native", func(b *testing.B) { + tmpDir := b.TempDir() + createTestFiles(b, tmpDir, tc.numFiles, tc.fileSize) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := worker.DirOverallSHA256HexParallel(tmpDir) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func createTestFiles(b *testing.B, dir string, numFiles int, fileSize int) { + data := make([]byte, fileSize) + for i := range data { + data[i] = byte(i % 256) + } + + for i := 0; i < numFiles; i++ { + path := filepath.Join(dir, "data", string(rune('a'+i%26)), "chunk.bin") + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { + b.Fatal(err) + } + if err := os.WriteFile(path, data, 0640); err != nil { + b.Fatal(err) + } + } +}