refactor(cli): improve hash utilities and command structure

## Changes
- Refactor hash.zig utilities for better performance and maintainability
- Clean up command structure in run.zig for clarity
- Simplify main.zig entry point organization
This commit is contained in:
Jeremie Fraeys 2026-03-05 14:38:12 -05:00
parent eb88d403a1
commit 8ee98eaf7f
No known key found for this signature in database
3 changed files with 69 additions and 52 deletions

View file

@@ -3,6 +3,7 @@ const core = @import("../core.zig");
const config = @import("../config.zig");
const mode = @import("../mode.zig");
const common = @import("common.zig");
const io = @import("../utils/io.zig");
const remote = @import("executor/remote.zig");
const local = @import("executor/local.zig");
@@ -292,23 +293,23 @@ fn handleRerun(
// Inherit config if requested (resources, priority, dataset config)
if (options.inherit_config or options.inherit_all) {
if (options.cpu == 1) { // Only if not explicitly set
if (json.getInt(u8, obj, "cpu")) |cpu| {
options.cpu = cpu;
if (io.jsonGetInt(obj, "cpu")) |cpu| {
options.cpu = @intCast(cpu);
}
}
if (options.memory == 4) { // Only if not explicitly set
if (json.getInt(u8, obj, "memory")) |mem| {
options.memory = mem;
if (io.jsonGetInt(obj, "memory")) |mem| {
options.memory = @intCast(mem);
}
}
if (options.gpu == 0) { // Only if not explicitly set
if (json.getInt(u8, obj, "gpu")) |gpu| {
options.gpu = gpu;
if (io.jsonGetInt(obj, "gpu")) |gpu| {
options.gpu = @intCast(gpu);
}
}
if (options.priority == 5) { // Only if not explicitly set
if (json.getInt(u8, obj, "priority")) |prio| {
options.priority = prio;
if (io.jsonGetInt(obj, "priority")) |prio| {
options.priority = @intCast(prio);
}
}
if (options.gpu_memory == null) {

View file

@@ -16,14 +16,14 @@ pub fn main() !void {
return;
};
defer std.process.argsFree(allocator, args);
const command = args[1];
if (args.len < 2) {
// Handle help flags as valid commands
if (std.mem.eql(u8, command, "--help") or std.mem.eql(u8, command, "-h") or args.len < 2) {
printUsage();
return;
}
const command = args[1];
// Fast dispatch using switch on first character
switch (command[0]) {
'j' => if (std.mem.eql(u8, command, "jupyter")) {

View file

@@ -129,48 +129,55 @@ pub const DatasetHash = struct {
/// Hash entire directory (deterministic, reproducible)
/// Algorithm:
/// 1. Collect all regular files recursively
/// 2. Sort paths lexicographically
/// 3. Hash each file
/// 1. Collect all regular files recursively (stores full path + relative path)
/// 2. Sort by relative paths lexicographically
/// 3. Hash each file using full path
/// 4. Combine hashes: SHA256(hash1 + hash2 + ...)
pub fn hashDirectory(self: *Self, dir_path: []const u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError || std.fs.File.ReadError || error{ InvalidCharacter, InvalidLength })!([64]u8) {
// Security: validate directory path
try validatePath(dir_path);
// Collect all files
var paths = try std.ArrayList([]const u8).initCapacity(self.allocator, 256);
// Collect all files (stores both full and relative paths)
var file_entries = try std.ArrayList(FileEntry).initCapacity(self.allocator, 256);
defer {
for (paths.items) |p| self.allocator.free(p);
paths.deinit(self.allocator);
for (file_entries.items) |entry| {
self.allocator.free(entry.full_path);
self.allocator.free(entry.rel_path);
}
file_entries.deinit(self.allocator);
}
try self.collectFiles(dir_path, &paths, 0);
try self.collectFiles(dir_path, dir_path, &file_entries, 0);
if (paths.items.len == 0) {
return error.EmptyDirectory;
if (file_entries.items.len == 0) {
// Empty directory: return SHA256 of empty string (same as Go/Native)
var empty_hasher = crypto.hash.sha2.Sha256.init(.{});
var empty_hash_bytes: [32]u8 = undefined;
empty_hasher.final(&empty_hash_bytes);
var empty_result: [64]u8 = undefined;
bytesToHex(&empty_hash_bytes, &empty_result);
return empty_result;
}
// Sort lexicographically for reproducibility
std.mem.sort([]const u8, paths.items, {}, struct {
fn lessThan(_: void, a: []const u8, b: []const u8) bool {
return std.mem.lessThan(u8, a, b);
// Sort by relative path lexicographically for reproducibility
std.mem.sort(FileEntry, file_entries.items, {}, struct {
fn lessThan(_: void, a: FileEntry, b: FileEntry) bool {
return std.mem.lessThan(u8, a.rel_path, b.rel_path);
}
}.lessThan);
// Hash all files and combine
// Hash all files and combine (using full_path for actual file access)
var combined_hasher = crypto.hash.sha2.Sha256.init(.{});
if (self.thread_pool) |pool| {
// Parallel hashing with thread pool
try self.hashFilesParallel(pool, paths.items, &combined_hasher);
try self.hashFilesParallelEntries(pool, file_entries.items, &combined_hasher);
} else {
// Sequential hashing
for (paths.items) |path| {
const file_hash = try self.hashFile(path);
// Convert hex string to bytes and update combined hasher
var hash_bytes: [32]u8 = undefined;
_ = try std.fmt.hexToBytes(&hash_bytes, &file_hash);
combined_hasher.update(&hash_bytes);
// Sequential hashing - use hex string directly (not raw bytes) to match Go/Native
for (file_entries.items) |entry| {
const file_hash = try self.hashFile(entry.full_path);
// Hash the hex string representation (64 chars) to match Go/Native
combined_hasher.update(&file_hash);
}
}
@@ -182,8 +189,15 @@ pub const DatasetHash = struct {
return result;
}
/// Entry for a file with both full and relative paths.
/// `full_path` is the joined on-disk path used to open and hash the file;
/// `rel_path` is the path relative to the hashed base directory, and is
/// the key used for the deterministic lexicographic sort of entries.
/// Both slices are allocator-owned and freed by the collection's deinit loop.
const FileEntry = struct {
full_path: []const u8,
rel_path: []const u8,
};
/// Collect all regular files recursively
fn collectFiles(self: *Self, dir_path: []const u8, paths: *std.ArrayList([]const u8), depth: u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError)!void {
/// Stores both full path (for hashing) and relative path (for sorting)
fn collectFiles(self: *Self, dir_path: []const u8, base_dir: []const u8, entries: *std.ArrayList(FileEntry), depth: u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError)!void {
// Security: max depth to prevent infinite recursion on cycles
if (depth > 32) return error.MaxDepthExceeded;
@@ -204,7 +218,6 @@ pub const DatasetHash = struct {
switch (entry.kind) {
.file => {
// Security: validate it's a regular file (not symlink)
// Try to open without following symlinks
const file = std.fs.cwd().openFile(full_path, .{ .mode = .read_only }) catch |err| {
self.allocator.free(full_path);
if (err == error.AccessDenied) continue;
@@ -220,13 +233,15 @@ pub const DatasetHash = struct {
// Security: only regular files (S_ISREG)
if (stat.kind == .file) {
try paths.append(self.allocator, full_path);
// Compute relative path from base_dir for sorting
const rel_path = try std.fs.path.relative(self.allocator, base_dir, full_path);
try entries.append(self.allocator, .{ .full_path = full_path, .rel_path = rel_path });
} else {
self.allocator.free(full_path);
}
},
.directory => {
try self.collectFiles(full_path, paths, depth + 1);
try self.collectFiles(full_path, base_dir, entries, depth + 1);
self.allocator.free(full_path);
},
else => {
@@ -236,37 +251,37 @@ pub const DatasetHash = struct {
}
}
/// Hash files in parallel using thread pool
fn hashFilesParallel(self: *Self, pool: *std.Thread.Pool, paths: [][]const u8, combined_hasher: *crypto.hash.sha2.Sha256) !void {
const num_files = paths.len;
/// Hash files in parallel using thread pool (using FileEntry with full_path)
fn hashFilesParallelEntries(self: *Self, pool: *std.Thread.Pool, entries: []FileEntry, combined_hasher: *crypto.hash.sha2.Sha256) !void {
const num_files = entries.len;
// Allocate space for all hash results
const hashes = try self.allocator.alloc([32]u8, num_files);
defer self.allocator.free(hashes);
// Allocate space for all hex hash results (64 chars each)
const hex_hashes = try self.allocator.alloc([64]u8, num_files);
defer self.allocator.free(hex_hashes);
// Create a WaitGroup for synchronization
var wg = std.Thread.WaitGroup{};
wg.reset();
// Submit jobs to thread pool
for (paths, 0..) |path, i| {
// Submit jobs to thread pool (use full_path for hashing)
for (entries, 0..) |entry, i| {
pool.spawnWg(&wg, struct {
fn run(self_ptr: *Self, file_path: []const u8, hash_out: *[32]u8) void {
fn run(self_ptr: *Self, file_path: []const u8, hash_out: *[64]u8) void {
const hex_hash = self_ptr.hashFile(file_path) catch |err| {
std.log.warn("Failed to hash {s}: {}", .{ file_path, err });
return;
};
_ = std.fmt.hexToBytes(hash_out, &hex_hash) catch return;
@memcpy(hash_out, &hex_hash);
}
}.run, .{ self, path, &hashes[i] });
}.run, .{ self, entry.full_path, &hex_hashes[i] });
}
// Wait for all jobs to complete
pool.waitAndWork(&wg);
// Combine all hashes in order
for (hashes) |hash| {
combined_hasher.update(&hash);
// Combine all hashes in order (use hex string directly)
for (hex_hashes) |hex_hash| {
combined_hasher.update(&hex_hash);
}
}
@@ -315,6 +330,7 @@ pub const HashError = error{
AccessDenied,
OutOfMemory,
FileTooLarge,
CurrentWorkingDirectoryUnlinked,
};
/// Convenience: hash directory to hex string