feat: add manifest signing and native hashing support
- Integrate RunManifest.Validate with the existing Validator
- Add manifest Sign() and Verify() methods
- Add native C++ hashing libraries (dataset_hash, queue_index)
- Add a native bridge for Go/C++ integration
- Add commit deduplication support in the queue
This commit is contained in:
parent
a3f9bf8731
commit
37aad7ae87
13 changed files with 775 additions and 74 deletions
|
|
@ -7,6 +7,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/crypto"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
)
|
||||
|
||||
|
|
@ -121,6 +122,11 @@ type RunManifest struct {
|
|||
GPUDevices []string `json:"gpu_devices,omitempty"`
|
||||
WorkerHost string `json:"worker_host,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
|
||||
// Signature fields for tamper detection
|
||||
Signature string `json:"signature,omitempty"`
|
||||
SignerKeyID string `json:"signer_key_id,omitempty"`
|
||||
SigAlg string `json:"sig_alg,omitempty"`
|
||||
}
|
||||
|
||||
func NewRunManifest(runID, taskID, jobName string, createdAt time.Time) *RunManifest {
|
||||
|
|
@ -234,3 +240,60 @@ func (m *RunManifest) ApplyNarrativePatch(p NarrativePatch) {
|
|||
m.Narrative.Tags = clean
|
||||
}
|
||||
}
|
||||
|
||||
// Sign signs the manifest using the provided signer
|
||||
func (m *RunManifest) Sign(signer *crypto.ManifestSigner) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("cannot sign nil manifest")
|
||||
}
|
||||
|
||||
result, err := signer.SignManifest(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign manifest: %w", err)
|
||||
}
|
||||
|
||||
m.Signature = result.Signature
|
||||
m.SignerKeyID = result.KeyID
|
||||
m.SigAlg = result.Algorithm
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify verifies the manifest signature using the provided public key
|
||||
func (m *RunManifest) Verify(publicKey []byte) (bool, error) {
|
||||
if m == nil {
|
||||
return false, fmt.Errorf("cannot verify nil manifest")
|
||||
}
|
||||
|
||||
if m.Signature == "" {
|
||||
return false, fmt.Errorf("manifest has no signature")
|
||||
}
|
||||
|
||||
// Build signing result from manifest fields
|
||||
result := &crypto.SigningResult{
|
||||
Signature: m.Signature,
|
||||
KeyID: m.SignerKeyID,
|
||||
Algorithm: m.SigAlg,
|
||||
}
|
||||
|
||||
// Call crypto package to verify
|
||||
return crypto.VerifyManifest(m, result, publicKey)
|
||||
}
|
||||
|
||||
// IsSigned returns true if the manifest has a signature
|
||||
func (m *RunManifest) IsSigned() bool {
|
||||
return m != nil && m.Signature != ""
|
||||
}
|
||||
|
||||
// Validate checks manifest completeness using the standard Validator.
|
||||
// This delegates to Validator.Validate() for consistency.
|
||||
func (m *RunManifest) Validate() error {
|
||||
v := NewValidator()
|
||||
return v.Validate(m)
|
||||
}
|
||||
|
||||
// ValidateStrict performs strict validation including optional provenance fields.
|
||||
// This delegates to Validator.ValidateStrict() for consistency.
|
||||
func (m *RunManifest) ValidateStrict() error {
|
||||
v := NewValidator()
|
||||
return v.ValidateStrict(m)
|
||||
}
|
||||
|
|
|
|||
78
internal/queue/dedup.go
Normal file
78
internal/queue/dedup.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrAlreadyQueued is returned when a job with the same commit was recently queued
|
||||
var ErrAlreadyQueued = fmt.Errorf("job already queued with this commit")
|
||||
|
||||
// CommitDedup tracks recently queued commits to prevent duplicate submissions
|
||||
type CommitDedup struct {
|
||||
mu sync.RWMutex
|
||||
commits map[string]time.Time // key: "job_name:commit_id" -> queued_at
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewCommitDedup creates a new commit deduplication tracker
|
||||
func NewCommitDedup(ttl time.Duration) *CommitDedup {
|
||||
if ttl <= 0 {
|
||||
ttl = 1 * time.Hour // Default 1 hour TTL
|
||||
}
|
||||
return &CommitDedup{
|
||||
commits: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// IsDuplicate checks if a job+commit combination was recently queued
|
||||
func (d *CommitDedup) IsDuplicate(jobName, commitID string) bool {
|
||||
key := d.key(jobName, commitID)
|
||||
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
if t, ok := d.commits[key]; ok {
|
||||
if time.Since(t) < d.ttl {
|
||||
return true // Still within TTL, consider duplicate
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarkQueued records that a job+commit combination was just queued
|
||||
func (d *CommitDedup) MarkQueued(jobName, commitID string) {
|
||||
key := d.key(jobName, commitID)
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.commits[key] = time.Now()
|
||||
}
|
||||
|
||||
// Cleanup removes expired entries (call periodically, e.g., every 5 minutes)
|
||||
func (d *CommitDedup) Cleanup() {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, t := range d.commits {
|
||||
if now.Sub(t) > d.ttl {
|
||||
delete(d.commits, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// key generates a unique key for job+commit combination
|
||||
func (d *CommitDedup) key(jobName, commitID string) string {
|
||||
return jobName + ":" + commitID
|
||||
}
|
||||
|
||||
// Size returns the number of tracked commits (for metrics/debugging)
|
||||
func (d *CommitDedup) Size() int {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return len(d.commits)
|
||||
}
|
||||
|
|
@ -26,6 +26,7 @@ type TaskQueue struct {
|
|||
metricsCh chan metricEvent
|
||||
metricsDone chan struct{}
|
||||
flushEvery time.Duration
|
||||
dedup *CommitDedup // Tracks recently queued commits
|
||||
}
|
||||
|
||||
type metricEvent struct {
|
||||
|
|
@ -103,10 +104,12 @@ func NewTaskQueue(cfg Config) (*TaskQueue, error) {
|
|||
metricsCh: make(chan metricEvent, 256),
|
||||
metricsDone: make(chan struct{}),
|
||||
flushEvery: flushEvery,
|
||||
dedup: NewCommitDedup(1 * time.Hour), // 1 hour default TTL for commit dedup
|
||||
}
|
||||
|
||||
go tq.runMetricsBuffer()
|
||||
go tq.runLeaseReclamation() // Start lease reclamation background job
|
||||
go tq.runDedupCleanup() // Start dedup cleanup background job
|
||||
|
||||
return tq, nil
|
||||
}
|
||||
|
|
@ -732,7 +735,38 @@ func (tq *TaskQueue) WaitForNextTask(ctx context.Context, timeout time.Duration)
|
|||
return tq.GetTask(member)
|
||||
}
|
||||
|
||||
// runMetricsBuffer buffers and flushes metrics
|
||||
// runDedupCleanup periodically cleans up expired dedup entries every 5 minutes
|
||||
func (tq *TaskQueue) runDedupCleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tq.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
tq.dedup.Cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddTaskDedup adds a task with commit deduplication check
|
||||
// Returns ErrAlreadyQueued if the same job+commit was recently queued
|
||||
func (tq *TaskQueue) AddTaskDedup(task *Task, commitID string) error {
|
||||
if commitID != "" && tq.dedup.IsDuplicate(task.JobName, commitID) {
|
||||
return ErrAlreadyQueued
|
||||
}
|
||||
|
||||
if err := tq.AddTask(task); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark as queued on success
|
||||
if commitID != "" {
|
||||
tq.dedup.MarkQueued(task.JobName, commitID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (tq *TaskQueue) runMetricsBuffer() {
|
||||
defer close(tq.metricsDone)
|
||||
ticker := time.NewTicker(tq.flushEvery)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,15 @@
|
|||
package worker
|
||||
|
||||
import "github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
import (
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
)
|
||||
|
||||
// UseNativeLibs controls whether to use C++ implementations.
|
||||
// Set FETCHML_NATIVE_LIBS=1 to enable native libraries.
|
||||
// This is defined here so it's available regardless of build tags.
|
||||
var UseNativeLibs = false
|
||||
|
||||
// dirOverallSHA256Hex selects implementation based on toggle.
|
||||
// This file has no CGo imports so it compiles even when CGO is disabled.
|
||||
// The actual implementations are in native_bridge.go (native) and integrity package (Go).
|
||||
// dirOverallSHA256Hex uses native implementation when compiled with -tags native_libs.
|
||||
// The build tag selects between native_bridge.go (stub) and native_bridge_libs.go (real).
|
||||
// No runtime configuration needed - build determines behavior.
|
||||
func dirOverallSHA256Hex(root string) (string, error) {
|
||||
if !UseNativeLibs {
|
||||
return integrity.DirOverallSHA256Hex(root)
|
||||
}
|
||||
// native_bridge_libs.go provides this when built with -tags native_libs
|
||||
// native_bridge.go provides stub that falls back to Go
|
||||
return dirOverallSHA256HexNative(root)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,8 +104,9 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
|
|||
return "", fmt.Errorf("not a directory")
|
||||
}
|
||||
|
||||
// Collect all files
|
||||
// Collect all files with size info
|
||||
var files []string
|
||||
var totalSize int64
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
|
|
@ -118,6 +119,11 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
|
|||
return err
|
||||
}
|
||||
files = append(files, rel)
|
||||
|
||||
// Track total size for optimization decisions
|
||||
if info, err := d.Info(); err == nil {
|
||||
totalSize += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -183,3 +189,30 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
|
|||
}
|
||||
return fmt.Sprintf("%x", overall.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// EstimateDirSize returns total size of directory contents in bytes
|
||||
func EstimateDirSize(root string) (int64, error) {
|
||||
root = filepath.Clean(root)
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return info.Size(), nil
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if info, err := d.Info(); err == nil {
|
||||
totalSize += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return totalSize, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,8 +12,6 @@ import (
|
|||
)
|
||||
|
||||
func init() {
|
||||
// Even with CGO, native libs require explicit build tag
|
||||
UseNativeLibs = false
|
||||
log.Printf("[native] Native libraries disabled (build with -tags native_libs to enable)")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,12 +21,28 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
|
|||
set(CMAKE_C_FLAGS_RELEASE "-O3 -march=native -DNDEBUG -fomit-frame-pointer")
|
||||
set(CMAKE_C_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer")
|
||||
|
||||
# Security hardening flags (always enabled)
|
||||
set(SECURITY_FLAGS
|
||||
-D_FORTIFY_SOURCE=2 # Buffer overflow protection
|
||||
-fstack-protector-strong # Stack canaries
|
||||
-Wformat-security # Format string warnings
|
||||
-Werror=format-security # Format string errors
|
||||
-fPIE # Position-independent code
|
||||
)
|
||||
|
||||
# Add security flags to all build types
|
||||
add_compile_options(${SECURITY_FLAGS})
|
||||
|
||||
# Linker security flags
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -pie")
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now")
|
||||
|
||||
# Warnings
|
||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||
|
||||
if(ENABLE_ASAN)
|
||||
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
|
||||
add_link_options(-fsanitize=address)
|
||||
add_compile_options(-fsanitize=address,undefined -fno-omit-frame-pointer)
|
||||
add_link_options(-fsanitize=address,undefined)
|
||||
endif()
|
||||
|
||||
if(ENABLE_TSAN)
|
||||
|
|
@ -59,6 +75,10 @@ if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/tests)
|
|||
add_executable(test_storage tests/test_storage.cpp)
|
||||
target_link_libraries(test_storage queue_index)
|
||||
add_test(NAME storage_smoke COMMAND test_storage)
|
||||
|
||||
add_executable(test_dataset_hash tests/test_dataset_hash.cpp)
|
||||
target_link_libraries(test_dataset_hash dataset_hash)
|
||||
add_test(NAME dataset_hash_smoke COMMAND test_dataset_hash)
|
||||
endif()
|
||||
|
||||
# Combined target for building all libraries
|
||||
|
|
|
|||
|
|
@ -5,19 +5,13 @@
|
|||
#include <arm_neon.h>
|
||||
|
||||
static void transform_armv8(uint32_t* state, const uint8_t* data) {
|
||||
// Load the 512-bit message block into 4 128-bit vectors
|
||||
uint32x4_t w0 = vld1q_u32((const uint32_t*)data);
|
||||
uint32x4_t w1 = vld1q_u32((const uint32_t*)(data + 16));
|
||||
uint32x4_t w2 = vld1q_u32((const uint32_t*)(data + 32));
|
||||
uint32x4_t w3 = vld1q_u32((const uint32_t*)(data + 48));
|
||||
// Load message and reverse bytes within each 32-bit word (big-endian -> native)
|
||||
uint32x4_t w0 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data)));
|
||||
uint32x4_t w1 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 16)));
|
||||
uint32x4_t w2 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 32)));
|
||||
uint32x4_t w3 = vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(data + 48)));
|
||||
|
||||
// Reverse byte order (SHA256 uses big-endian words)
|
||||
w0 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w0)));
|
||||
w1 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w1)));
|
||||
w2 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w2)));
|
||||
w3 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(w3)));
|
||||
|
||||
// Load current hash state
|
||||
// Load current hash state (native endianness)
|
||||
uint32x4_t abcd = vld1q_u32(state);
|
||||
uint32x4_t efgh = vld1q_u32(state + 4);
|
||||
uint32x4_t abcd_orig = abcd;
|
||||
|
|
@ -30,20 +24,24 @@ static void transform_armv8(uint32_t* state, const uint8_t* data) {
|
|||
uint32x4_t k3 = vld1q_u32(&K[12]);
|
||||
|
||||
uint32x4_t tmp = vaddq_u32(w0, k0);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
uint32x4_t abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
tmp = vaddq_u32(w1, k1);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
tmp = vaddq_u32(w2, k2);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
tmp = vaddq_u32(w3, k3);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
// Rounds 16-63: Message schedule expansion + rounds
|
||||
for (int i = 16; i < 64; i += 16) {
|
||||
|
|
@ -52,32 +50,36 @@ static void transform_armv8(uint32_t* state, const uint8_t* data) {
|
|||
w4 = vsha256su1q_u32(w4, w2, w3);
|
||||
k0 = vld1q_u32(&K[i]);
|
||||
tmp = vaddq_u32(w4, k0);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
// Schedule expansion for rounds i+4..i+7
|
||||
uint32x4_t w5 = vsha256su0q_u32(w1, w2);
|
||||
w5 = vsha256su1q_u32(w5, w3, w4);
|
||||
k1 = vld1q_u32(&K[i + 4]);
|
||||
tmp = vaddq_u32(w5, k1);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
// Schedule expansion for rounds i+8..i+11
|
||||
uint32x4_t w6 = vsha256su0q_u32(w2, w3);
|
||||
w6 = vsha256su1q_u32(w6, w4, w5);
|
||||
k2 = vld1q_u32(&K[i + 8]);
|
||||
tmp = vaddq_u32(w6, k2);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
// Schedule expansion for rounds i+12..i+15
|
||||
uint32x4_t w7 = vsha256su0q_u32(w3, w4);
|
||||
w7 = vsha256su1q_u32(w7, w5, w6);
|
||||
k3 = vld1q_u32(&K[i + 12]);
|
||||
tmp = vaddq_u32(w7, k3);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp);
|
||||
abcd = vsha256hq_u32(abcd, efgh, tmp);
|
||||
abcd_new = vsha256hq_u32(abcd, efgh, tmp);
|
||||
efgh = vsha256h2q_u32(efgh, abcd, tmp); // Use ORIGINAL abcd
|
||||
abcd = abcd_new;
|
||||
|
||||
// Rotate working variables
|
||||
w0 = w4; w1 = w5; w2 = w6; w3 = w7;
|
||||
|
|
|
|||
|
|
@ -11,19 +11,8 @@ void sha256_init(Sha256State* hasher) {
|
|||
memcpy(hasher->state, H0, sizeof(H0));
|
||||
|
||||
// Detect best transform implementation
|
||||
hasher->transform_fn = detect_best_transform();
|
||||
if (hasher->transform_fn == transform_generic) {
|
||||
// Try platform-specific implementations
|
||||
TransformFunc f = detect_armv8_transform();
|
||||
if (f) {
|
||||
hasher->transform_fn = f;
|
||||
} else {
|
||||
f = detect_x86_transform();
|
||||
if (f) {
|
||||
hasher->transform_fn = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
TransformFunc f = detect_best_transform();
|
||||
hasher->transform_fn = f ? f : transform_generic;
|
||||
}
|
||||
|
||||
void sha256_update(Sha256State* hasher, const uint8_t* data, size_t len) {
|
||||
|
|
|
|||
|
|
@ -73,6 +73,18 @@ int hash_file(const char* path, size_t buffer_size, char* out_hash) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// Hash a single file, allocating result buffer
|
||||
char* hash_file_alloc(const char* path, size_t buffer_size) {
|
||||
char* out_hash = (char*)malloc(65); // 64 hex + null
|
||||
if (!out_hash) return nullptr;
|
||||
|
||||
if (hash_file(path, buffer_size, out_hash) != 0) {
|
||||
free(out_hash);
|
||||
return nullptr;
|
||||
}
|
||||
return out_hash;
|
||||
}
|
||||
|
||||
int hash_files_batch(
|
||||
const char* const* paths,
|
||||
uint32_t count,
|
||||
|
|
@ -84,8 +96,8 @@ int hash_files_batch(
|
|||
int all_success = 1;
|
||||
|
||||
for (uint32_t i = 0; i < count; ++i) {
|
||||
if (hash_file(paths[i], buffer_size, out_hashes[i]) != 0) {
|
||||
out_hashes[i][0] = '\0';
|
||||
out_hashes[i] = hash_file_alloc(paths[i], buffer_size);
|
||||
if (out_hashes[i] == nullptr) {
|
||||
all_success = 0;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
#include "parallel_hash.h"
|
||||
#include "../io/file_hash.h"
|
||||
#include "../crypto/sha256_hasher.h"
|
||||
#include "../../common/include/thread_pool.h"
|
||||
#include <dirent.h>
|
||||
#include <sys/stat.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// Simple file collector - just flat directory for now
|
||||
static int collect_files(const char* dir_path, char** out_paths, int max_files) {
|
||||
|
|
@ -55,6 +60,25 @@ void parallel_hasher_cleanup(ParallelHasher* hasher) {
|
|||
hasher->pool = nullptr;
|
||||
}
|
||||
|
||||
// Batch hash task - processes a range of files
|
||||
struct BatchHashTask {
|
||||
const char** paths;
|
||||
char** out_hashes;
|
||||
size_t buffer_size;
|
||||
int start_idx;
|
||||
int end_idx;
|
||||
std::atomic<bool>* success;
|
||||
};
|
||||
|
||||
// Worker function for batch processing
|
||||
static void batch_hash_worker(BatchHashTask* task) {
|
||||
for (int i = task->start_idx; i < task->end_idx; i++) {
|
||||
if (hash_file(task->paths[i], task->buffer_size, task->out_hashes[i]) != 0) {
|
||||
task->success->store(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_hash) {
|
||||
if (!hasher || !path || !out_hash) return -1;
|
||||
|
||||
|
|
@ -70,7 +94,6 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_
|
|||
sha256_init(&st);
|
||||
uint8_t result[32];
|
||||
sha256_finalize(&st, result);
|
||||
// Convert to hex
|
||||
static const char hex[] = "0123456789abcdef";
|
||||
for (int i = 0; i < 32; i++) {
|
||||
out_hash[i*2] = hex[(result[i] >> 4) & 0xf];
|
||||
|
|
@ -80,15 +103,63 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_
|
|||
return 0;
|
||||
}
|
||||
|
||||
// Hash all files
|
||||
char hashes[256][65];
|
||||
// Convert path_ptrs to const char** for batch task
|
||||
const char* path_array[256];
|
||||
for (int i = 0; i < count; i++) {
|
||||
if (hash_file(paths[i], hasher->buffer_size, hashes[i]) != 0) {
|
||||
return -1;
|
||||
}
|
||||
path_array[i] = path_ptrs[i];
|
||||
}
|
||||
|
||||
// Combine hashes
|
||||
// Parallel hash all files using ThreadPool with batched tasks
|
||||
char hashes[256][65];
|
||||
std::atomic<bool> all_success{true};
|
||||
std::atomic<int> completed_batches{0};
|
||||
|
||||
// Determine batch size - divide files among threads
|
||||
uint32_t num_threads = ThreadPool::default_thread_count();
|
||||
int batch_size = (count + num_threads - 1) / num_threads;
|
||||
if (batch_size < 1) batch_size = 1;
|
||||
int num_batches = (count + batch_size - 1) / batch_size;
|
||||
|
||||
// Allocate batch tasks
|
||||
BatchHashTask* batch_tasks = new BatchHashTask[num_batches];
|
||||
char* hash_ptrs[256];
|
||||
for (int i = 0; i < count; i++) {
|
||||
hash_ptrs[i] = hashes[i];
|
||||
}
|
||||
|
||||
for (int b = 0; b < num_batches; b++) {
|
||||
int start = b * batch_size;
|
||||
int end = start + batch_size;
|
||||
if (end > count) end = count;
|
||||
|
||||
batch_tasks[b].paths = path_array;
|
||||
batch_tasks[b].out_hashes = hash_ptrs;
|
||||
batch_tasks[b].buffer_size = hasher->buffer_size;
|
||||
batch_tasks[b].start_idx = start;
|
||||
batch_tasks[b].end_idx = end;
|
||||
batch_tasks[b].success = &all_success;
|
||||
}
|
||||
|
||||
// Enqueue batch tasks (one per thread, not one per file)
|
||||
for (int b = 0; b < num_batches; b++) {
|
||||
hasher->pool->enqueue([batch_tasks, b, &completed_batches]() {
|
||||
batch_hash_worker(&batch_tasks[b]);
|
||||
completed_batches.fetch_add(1);
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all batches to complete
|
||||
while (completed_batches.load() < num_batches) {
|
||||
std::this_thread::yield();
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
if (!all_success.load()) {
|
||||
delete[] batch_tasks;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Combine hashes deterministically (same order as paths)
|
||||
Sha256State st;
|
||||
sha256_init(&st);
|
||||
for (int i = 0; i < count; i++) {
|
||||
|
|
@ -105,6 +176,7 @@ int parallel_hash_directory(ParallelHasher* hasher, const char* path, char* out_
|
|||
}
|
||||
out_hash[64] = '\0';
|
||||
|
||||
delete[] batch_tasks;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -122,12 +194,55 @@ int parallel_hash_directory_batch(
|
|||
int count = collect_files(path, out_paths, (int)max_results);
|
||||
if (out_count) *out_count = (uint32_t)count;
|
||||
|
||||
// Hash each file
|
||||
for (int i = 0; i < count; i++) {
|
||||
if (hash_file(out_paths ? out_paths[i] : nullptr, hasher->buffer_size, out_hashes[i]) != 0) {
|
||||
out_hashes[i][0] = '\0';
|
||||
}
|
||||
if (count == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 0;
|
||||
// Convert out_paths to const char** for batch task
|
||||
const char* path_array[256];
|
||||
for (int i = 0; i < count; i++) {
|
||||
path_array[i] = out_paths ? out_paths[i] : nullptr;
|
||||
}
|
||||
|
||||
// Parallel hash all files using ThreadPool with batched tasks
|
||||
std::atomic<bool> all_success{true};
|
||||
std::atomic<int> completed_batches{0};
|
||||
|
||||
// Determine batch size
|
||||
uint32_t num_threads = ThreadPool::default_thread_count();
|
||||
int batch_size = (count + num_threads - 1) / num_threads;
|
||||
if (batch_size < 1) batch_size = 1;
|
||||
int num_batches = (count + batch_size - 1) / batch_size;
|
||||
|
||||
// Allocate batch tasks
|
||||
BatchHashTask* batch_tasks = new BatchHashTask[num_batches];
|
||||
|
||||
for (int b = 0; b < num_batches; b++) {
|
||||
int start = b * batch_size;
|
||||
int end = start + batch_size;
|
||||
if (end > count) end = count;
|
||||
|
||||
batch_tasks[b].paths = path_array;
|
||||
batch_tasks[b].out_hashes = out_hashes;
|
||||
batch_tasks[b].buffer_size = hasher->buffer_size;
|
||||
batch_tasks[b].start_idx = start;
|
||||
batch_tasks[b].end_idx = end;
|
||||
batch_tasks[b].success = &all_success;
|
||||
}
|
||||
|
||||
// Enqueue batch tasks
|
||||
for (int b = 0; b < num_batches; b++) {
|
||||
hasher->pool->enqueue([batch_tasks, b, &completed_batches]() {
|
||||
batch_hash_worker(&batch_tasks[b]);
|
||||
completed_batches.fetch_add(1);
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all batches to complete
|
||||
while (completed_batches.load() < num_batches) {
|
||||
std::this_thread::yield();
|
||||
}
|
||||
|
||||
delete[] batch_tasks;
|
||||
return all_success.load() ? 0 : -1;
|
||||
}
|
||||
|
|
|
|||
286
native/tests/test_dataset_hash.cpp
Normal file
286
native/tests/test_dataset_hash.cpp
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
// Simple test suite for dataset_hash library (no external dependencies)
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "dataset_hash/dataset_hash.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
using namespace std::chrono;
|
||||
|
||||
// Simple test macros
|
||||
#define TEST_ASSERT(cond) \
|
||||
do { \
|
||||
if (!(cond)) { \
|
||||
fprintf(stderr, "ASSERTION FAILED: %s at line %d\n", #cond, __LINE__); \
|
||||
return 1; \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define TEST_ASSERT_EQ(a, b) TEST_ASSERT((a) == (b))
|
||||
#define TEST_ASSERT_NE(a, b) TEST_ASSERT((a) != (b))
|
||||
#define TEST_ASSERT_STR_EQ(a, b) TEST_ASSERT(strcmp((a), (b)) == 0)
|
||||
|
||||
// Helper functions
|
||||
// create_temp_dir creates a fresh, uniquely named scratch directory under
// the system temp path and returns it.
//
// Fix: the previous version passed the "XXXXXX" template verbatim to
// fs::create_directories, so the placeholder was never substituted and every
// run shared the literal "dataset_hash_test_XXXXXX" directory. mkdtemp()
// actually randomizes the template, isolating concurrent/repeated runs.
fs::path create_temp_dir() {
    std::string templ = (fs::temp_directory_path() / "dataset_hash_test_XXXXXX").string();
    if (mkdtemp(templ.data()) == nullptr) {
        perror("mkdtemp");
        exit(1); // tests cannot proceed without a scratch directory
    }
    return fs::path(templ);
}
|
||||
|
||||
// cleanup_temp_dir recursively removes a scratch directory previously
// returned by create_temp_dir.
void cleanup_temp_dir(const fs::path& dir) { fs::remove_all(dir); }
||||
|
||||
// create_test_file writes `content` to dir/name, creating or truncating the
// file. The stream's destructor closes it on scope exit.
void create_test_file(const fs::path& dir, const std::string& name, const std::string& content) {
    std::ofstream out(dir / name);
    out << content;
}
|
||||
|
||||
// Test 1: Context creation
|
||||
int test_context_creation() {
|
||||
printf("Testing context creation...\n");
|
||||
|
||||
// Auto-detect threads
|
||||
fh_context_t* ctx = fh_init(0);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
fh_cleanup(ctx);
|
||||
|
||||
// Specific thread count
|
||||
ctx = fh_init(4);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
fh_cleanup(ctx);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Test 2: SIMD detection
|
||||
int test_simd_detection() {
|
||||
printf("Testing SIMD detection...\n");
|
||||
|
||||
int has_simd = fh_has_simd_sha256();
|
||||
const char* impl_name = fh_get_simd_impl_name();
|
||||
|
||||
printf(" SIMD available: %s\n", has_simd ? "yes" : "no");
|
||||
printf(" Implementation: %s\n", impl_name);
|
||||
|
||||
TEST_ASSERT_NE(impl_name, nullptr);
|
||||
TEST_ASSERT(strlen(impl_name) > 0);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Test 3: Hash single file
|
||||
int test_hash_single_file() {
|
||||
printf("Testing single file hash...\n");
|
||||
|
||||
fs::path temp = create_temp_dir();
|
||||
|
||||
fh_context_t* ctx = fh_init(1);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
|
||||
// Create test file
|
||||
create_test_file(temp, "test.txt", "Hello, World!");
|
||||
|
||||
// Hash it
|
||||
char* hash = fh_hash_file(ctx, (temp / "test.txt").string().c_str());
|
||||
TEST_ASSERT_NE(hash, nullptr);
|
||||
|
||||
// Verify hash format (64 hex characters + null)
|
||||
TEST_ASSERT_EQ(strlen(hash), 64);
|
||||
|
||||
// Hash should be deterministic
|
||||
char* hash2 = fh_hash_file(ctx, (temp / "test.txt").string().c_str());
|
||||
TEST_ASSERT_NE(hash2, nullptr);
|
||||
TEST_ASSERT_STR_EQ(hash, hash2);
|
||||
|
||||
fh_free_string(hash);
|
||||
fh_free_string(hash2);
|
||||
fh_cleanup(ctx);
|
||||
cleanup_temp_dir(temp);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Test 4: Hash empty file (known hash)
|
||||
int test_hash_empty_file() {
|
||||
printf("Testing empty file hash...\n");
|
||||
|
||||
fs::path temp = create_temp_dir();
|
||||
|
||||
fh_context_t* ctx = fh_init(1);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
|
||||
// Create empty file
|
||||
create_test_file(temp, "empty.txt", "");
|
||||
|
||||
char* hash = fh_hash_file(ctx, (temp / "empty.txt").string().c_str());
|
||||
TEST_ASSERT_NE(hash, nullptr);
|
||||
TEST_ASSERT_EQ(strlen(hash), 64);
|
||||
|
||||
// Debug: print actual hash
|
||||
printf(" Empty file hash: %s\n", hash);
|
||||
|
||||
// Known SHA-256 of empty string
|
||||
const char* expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
|
||||
printf(" Expected hash: %s\n", expected);
|
||||
TEST_ASSERT_STR_EQ(hash, expected);
|
||||
|
||||
fh_free_string(hash);
|
||||
fh_cleanup(ctx);
|
||||
cleanup_temp_dir(temp);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Test 5: Hash directory
|
||||
int test_hash_directory() {
|
||||
printf("Testing directory hash...\n");
|
||||
|
||||
fs::path temp = create_temp_dir();
|
||||
|
||||
fh_context_t* ctx = fh_init(4);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
|
||||
// Create directory structure
|
||||
create_test_file(temp, "root.txt", "root");
|
||||
fs::create_directories(temp / "subdir");
|
||||
create_test_file(temp, "subdir/file1.txt", "file1");
|
||||
create_test_file(temp, "subdir/file2.txt", "file2");
|
||||
|
||||
// Hash directory
|
||||
char* hash = fh_hash_directory(ctx, temp.string().c_str());
|
||||
TEST_ASSERT_NE(hash, nullptr);
|
||||
TEST_ASSERT_EQ(strlen(hash), 64);
|
||||
|
||||
// Hash should be deterministic
|
||||
char* hash2 = fh_hash_directory(ctx, temp.string().c_str());
|
||||
TEST_ASSERT_NE(hash2, nullptr);
|
||||
TEST_ASSERT_STR_EQ(hash, hash2);
|
||||
|
||||
fh_free_string(hash);
|
||||
fh_free_string(hash2);
|
||||
fh_cleanup(ctx);
|
||||
cleanup_temp_dir(temp);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Test 6: Batch hash
|
||||
int test_batch_hash() {
|
||||
printf("Testing batch hash...\n");
|
||||
|
||||
fs::path temp = create_temp_dir();
|
||||
|
||||
fh_context_t* ctx = fh_init(4);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
|
||||
// Create test files
|
||||
const int num_files = 10;
|
||||
std::vector<std::string> paths;
|
||||
std::vector<const char*> c_paths;
|
||||
|
||||
for (int i = 0; i < num_files; i++) {
|
||||
std::string name = "file_" + std::to_string(i) + ".txt";
|
||||
create_test_file(temp, name, "Content " + std::to_string(i));
|
||||
paths.push_back((temp / name).string());
|
||||
c_paths.push_back(paths.back().c_str());
|
||||
}
|
||||
|
||||
// Hash batch
|
||||
std::vector<char*> results(num_files, nullptr);
|
||||
int ret = fh_hash_batch(ctx, c_paths.data(), num_files, results.data());
|
||||
TEST_ASSERT_EQ(ret, 0);
|
||||
|
||||
// Verify all hashes
|
||||
for (int i = 0; i < num_files; i++) {
|
||||
TEST_ASSERT_NE(results[i], nullptr);
|
||||
TEST_ASSERT_EQ(strlen(results[i]), 64);
|
||||
fh_free_string(results[i]);
|
||||
}
|
||||
|
||||
fh_cleanup(ctx);
|
||||
cleanup_temp_dir(temp);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Test 7: Performance test
|
||||
int test_performance() {
|
||||
printf("Testing performance...\n");
|
||||
|
||||
fs::path temp = create_temp_dir();
|
||||
|
||||
fh_context_t* ctx = fh_init(4);
|
||||
TEST_ASSERT_NE(ctx, nullptr);
|
||||
|
||||
// Create 1000 small files
|
||||
const int num_files = 1000;
|
||||
auto start = high_resolution_clock::now();
|
||||
|
||||
for (int i = 0; i < num_files; i++) {
|
||||
create_test_file(temp, "perf_" + std::to_string(i) + ".txt", "content");
|
||||
}
|
||||
|
||||
auto create_end = high_resolution_clock::now();
|
||||
|
||||
// Hash all files
|
||||
char* hash = fh_hash_directory(ctx, temp.string().c_str());
|
||||
TEST_ASSERT_NE(hash, nullptr);
|
||||
|
||||
auto hash_end = high_resolution_clock::now();
|
||||
|
||||
auto create_time = duration_cast<milliseconds>(create_end - start);
|
||||
auto hash_time = duration_cast<milliseconds>(hash_end - create_end);
|
||||
|
||||
printf(" Created %d files in %lld ms\n", num_files, create_time.count());
|
||||
printf(" Hashed %d files in %lld ms\n", num_files, hash_time.count());
|
||||
printf(" Throughput: %.1f files/sec\n", num_files * 1000.0 / hash_time.count());
|
||||
|
||||
fh_free_string(hash);
|
||||
fh_cleanup(ctx);
|
||||
cleanup_temp_dir(temp);
|
||||
|
||||
printf(" PASSED\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Main test runner
|
||||
int main() {
|
||||
printf("\n=== Dataset Hash Library Test Suite ===\n\n");
|
||||
|
||||
int failed = 0;
|
||||
|
||||
failed += test_context_creation();
|
||||
failed += test_simd_detection();
|
||||
failed += test_hash_single_file();
|
||||
failed += test_hash_empty_file();
|
||||
failed += test_hash_directory();
|
||||
failed += test_batch_hash();
|
||||
failed += test_performance();
|
||||
|
||||
printf("\n=== Test Results ===\n");
|
||||
if (failed == 0) {
|
||||
printf("All tests PASSED!\n");
|
||||
return 0;
|
||||
} else {
|
||||
printf("%d test(s) FAILED\n", failed);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
75
tests/benchmarks/dataset_size_comparison_test.go
Normal file
75
tests/benchmarks/dataset_size_comparison_test.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package benchmarks
|
||||
|
||||
import (
	"os"
	"path/filepath"
	"strconv"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker"
	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)
|
||||
|
||||
// BenchmarkDatasetSizeComparison finds the crossover point where native wins
|
||||
// Run with: FETCHML_NATIVE_LIBS=1 go test -tags native_libs -bench=BenchmarkDatasetSize ./tests/benchmarks/
|
||||
func BenchmarkDatasetSizeComparison(b *testing.B) {
|
||||
sizes := []struct {
|
||||
name string
|
||||
fileSize int
|
||||
numFiles int
|
||||
totalMB int
|
||||
}{
|
||||
{"100MB", 10 * 1024 * 1024, 10, 100}, // 10 x 10MB = 100MB
|
||||
{"500MB", 50 * 1024 * 1024, 10, 500}, // 10 x 50MB = 500MB
|
||||
{"1GB", 100 * 1024 * 1024, 10, 1000}, // 10 x 100MB = 1GB
|
||||
{"2GB", 100 * 1024 * 1024, 20, 2000}, // 20 x 100MB = 2GB
|
||||
{"5GB", 100 * 1024 * 1024, 50, 5000}, // 50 x 100MB = 5GB
|
||||
}
|
||||
|
||||
for _, tc := range sizes {
|
||||
tc := tc // capture range variable
|
||||
b.Run(tc.name+"/GoParallel", func(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
createTestFiles(b, tmpDir, tc.numFiles, tc.fileSize)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := integrity.DirOverallSHA256HexParallel(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(tc.name+"/Native", func(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
createTestFiles(b, tmpDir, tc.numFiles, tc.fileSize)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := worker.DirOverallSHA256HexParallel(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestFiles(b *testing.B, dir string, numFiles int, fileSize int) {
|
||||
data := make([]byte, fileSize)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
for i := 0; i < numFiles; i++ {
|
||||
path := filepath.Join(dir, "data", string(rune('a'+i%26)), "chunk.bin")
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0640); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue