refactor(cli): improve hash utilities and command structure
## Changes

- Refactor hash.zig utilities for better performance and maintainability
- Clean up command structure in run.zig for clarity
- Simplify main.zig entry point organization
This commit is contained in:
parent
eb88d403a1
commit
8ee98eaf7f
3 changed files with 69 additions and 52 deletions
|
|
@@ -3,6 +3,7 @@ const core = @import("../core.zig");
|
|||
const config = @import("../config.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const common = @import("common.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
|
||||
const remote = @import("executor/remote.zig");
|
||||
const local = @import("executor/local.zig");
|
||||
|
|
@@ -292,23 +293,23 @@ fn handleRerun(
|
|||
// Inherit config if requested (resources, priority, dataset config)
|
||||
if (options.inherit_config or options.inherit_all) {
|
||||
if (options.cpu == 1) { // Only if not explicitly set
|
||||
if (json.getInt(u8, obj, "cpu")) |cpu| {
|
||||
options.cpu = cpu;
|
||||
if (io.jsonGetInt(obj, "cpu")) |cpu| {
|
||||
options.cpu = @intCast(cpu);
|
||||
}
|
||||
}
|
||||
if (options.memory == 4) { // Only if not explicitly set
|
||||
if (json.getInt(u8, obj, "memory")) |mem| {
|
||||
options.memory = mem;
|
||||
if (io.jsonGetInt(obj, "memory")) |mem| {
|
||||
options.memory = @intCast(mem);
|
||||
}
|
||||
}
|
||||
if (options.gpu == 0) { // Only if not explicitly set
|
||||
if (json.getInt(u8, obj, "gpu")) |gpu| {
|
||||
options.gpu = gpu;
|
||||
if (io.jsonGetInt(obj, "gpu")) |gpu| {
|
||||
options.gpu = @intCast(gpu);
|
||||
}
|
||||
}
|
||||
if (options.priority == 5) { // Only if not explicitly set
|
||||
if (json.getInt(u8, obj, "priority")) |prio| {
|
||||
options.priority = prio;
|
||||
if (io.jsonGetInt(obj, "priority")) |prio| {
|
||||
options.priority = @intCast(prio);
|
||||
}
|
||||
}
|
||||
if (options.gpu_memory == null) {
|
||||
|
|
|
|||
|
|
@@ -16,14 +16,14 @@ pub fn main() !void {
|
|||
return;
|
||||
};
|
||||
defer std.process.argsFree(allocator, args);
|
||||
const command = args[1];
|
||||
|
||||
if (args.len < 2) {
|
||||
// Handle help flags as valid commands
|
||||
if (std.mem.eql(u8, command, "--help") or std.mem.eql(u8, command, "-h") or args.len < 2) {
|
||||
printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
const command = args[1];
|
||||
|
||||
// Fast dispatch using switch on first character
|
||||
switch (command[0]) {
|
||||
'j' => if (std.mem.eql(u8, command, "jupyter")) {
|
||||
|
|
|
|||
|
|
@@ -129,48 +129,55 @@ pub const DatasetHash = struct {
|
|||
|
||||
/// Hash entire directory (deterministic, reproducible)
|
||||
/// Algorithm:
|
||||
/// 1. Collect all regular files recursively
|
||||
/// 2. Sort paths lexicographically
|
||||
/// 3. Hash each file
|
||||
/// 1. Collect all regular files recursively (stores full path + relative path)
|
||||
/// 2. Sort by relative paths lexicographically
|
||||
/// 3. Hash each file using full path
|
||||
/// 4. Combine hashes: SHA256(hash1 + hash2 + ...)
|
||||
pub fn hashDirectory(self: *Self, dir_path: []const u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError || std.fs.File.ReadError || error{ InvalidCharacter, InvalidLength })!([64]u8) {
|
||||
// Security: validate directory path
|
||||
try validatePath(dir_path);
|
||||
|
||||
// Collect all files
|
||||
var paths = try std.ArrayList([]const u8).initCapacity(self.allocator, 256);
|
||||
// Collect all files (stores both full and relative paths)
|
||||
var file_entries = try std.ArrayList(FileEntry).initCapacity(self.allocator, 256);
|
||||
defer {
|
||||
for (paths.items) |p| self.allocator.free(p);
|
||||
paths.deinit(self.allocator);
|
||||
for (file_entries.items) |entry| {
|
||||
self.allocator.free(entry.full_path);
|
||||
self.allocator.free(entry.rel_path);
|
||||
}
|
||||
file_entries.deinit(self.allocator);
|
||||
}
|
||||
|
||||
try self.collectFiles(dir_path, &paths, 0);
|
||||
try self.collectFiles(dir_path, dir_path, &file_entries, 0);
|
||||
|
||||
if (paths.items.len == 0) {
|
||||
return error.EmptyDirectory;
|
||||
if (file_entries.items.len == 0) {
|
||||
// Empty directory: return SHA256 of empty string (same as Go/Native)
|
||||
var empty_hasher = crypto.hash.sha2.Sha256.init(.{});
|
||||
var empty_hash_bytes: [32]u8 = undefined;
|
||||
empty_hasher.final(&empty_hash_bytes);
|
||||
var empty_result: [64]u8 = undefined;
|
||||
bytesToHex(&empty_hash_bytes, &empty_result);
|
||||
return empty_result;
|
||||
}
|
||||
|
||||
// Sort lexicographically for reproducibility
|
||||
std.mem.sort([]const u8, paths.items, {}, struct {
|
||||
fn lessThan(_: void, a: []const u8, b: []const u8) bool {
|
||||
return std.mem.lessThan(u8, a, b);
|
||||
// Sort by relative path lexicographically for reproducibility
|
||||
std.mem.sort(FileEntry, file_entries.items, {}, struct {
|
||||
fn lessThan(_: void, a: FileEntry, b: FileEntry) bool {
|
||||
return std.mem.lessThan(u8, a.rel_path, b.rel_path);
|
||||
}
|
||||
}.lessThan);
|
||||
|
||||
// Hash all files and combine
|
||||
// Hash all files and combine (using full_path for actual file access)
|
||||
var combined_hasher = crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
if (self.thread_pool) |pool| {
|
||||
// Parallel hashing with thread pool
|
||||
try self.hashFilesParallel(pool, paths.items, &combined_hasher);
|
||||
try self.hashFilesParallelEntries(pool, file_entries.items, &combined_hasher);
|
||||
} else {
|
||||
// Sequential hashing
|
||||
for (paths.items) |path| {
|
||||
const file_hash = try self.hashFile(path);
|
||||
// Convert hex string to bytes and update combined hasher
|
||||
var hash_bytes: [32]u8 = undefined;
|
||||
_ = try std.fmt.hexToBytes(&hash_bytes, &file_hash);
|
||||
combined_hasher.update(&hash_bytes);
|
||||
// Sequential hashing - use hex string directly (not raw bytes) to match Go/Native
|
||||
for (file_entries.items) |entry| {
|
||||
const file_hash = try self.hashFile(entry.full_path);
|
||||
// Hash the hex string representation (64 chars) to match Go/Native
|
||||
combined_hasher.update(&file_hash);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@@ -182,8 +189,15 @@ pub const DatasetHash = struct {
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Entry for a file with both full and relative paths
|
||||
const FileEntry = struct {
|
||||
full_path: []const u8,
|
||||
rel_path: []const u8,
|
||||
};
|
||||
|
||||
/// Collect all regular files recursively
|
||||
fn collectFiles(self: *Self, dir_path: []const u8, paths: *std.ArrayList([]const u8), depth: u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError)!void {
|
||||
/// Stores both full path (for hashing) and relative path (for sorting)
|
||||
fn collectFiles(self: *Self, dir_path: []const u8, base_dir: []const u8, entries: *std.ArrayList(FileEntry), depth: u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError)!void {
|
||||
// Security: max depth to prevent infinite recursion on cycles
|
||||
if (depth > 32) return error.MaxDepthExceeded;
|
||||
|
||||
|
|
@@ -204,7 +218,6 @@ pub const DatasetHash = struct {
|
|||
switch (entry.kind) {
|
||||
.file => {
|
||||
// Security: validate it's a regular file (not symlink)
|
||||
// Try to open without following symlinks
|
||||
const file = std.fs.cwd().openFile(full_path, .{ .mode = .read_only }) catch |err| {
|
||||
self.allocator.free(full_path);
|
||||
if (err == error.AccessDenied) continue;
|
||||
|
|
@@ -220,13 +233,15 @@ pub const DatasetHash = struct {
|
|||
|
||||
// Security: only regular files (S_ISREG)
|
||||
if (stat.kind == .file) {
|
||||
try paths.append(self.allocator, full_path);
|
||||
// Compute relative path from base_dir for sorting
|
||||
const rel_path = try std.fs.path.relative(self.allocator, base_dir, full_path);
|
||||
try entries.append(self.allocator, .{ .full_path = full_path, .rel_path = rel_path });
|
||||
} else {
|
||||
self.allocator.free(full_path);
|
||||
}
|
||||
},
|
||||
.directory => {
|
||||
try self.collectFiles(full_path, paths, depth + 1);
|
||||
try self.collectFiles(full_path, base_dir, entries, depth + 1);
|
||||
self.allocator.free(full_path);
|
||||
},
|
||||
else => {
|
||||
|
|
@@ -236,37 +251,37 @@ pub const DatasetHash = struct {
|
|||
}
|
||||
}
|
||||
|
||||
/// Hash files in parallel using thread pool
|
||||
fn hashFilesParallel(self: *Self, pool: *std.Thread.Pool, paths: [][]const u8, combined_hasher: *crypto.hash.sha2.Sha256) !void {
|
||||
const num_files = paths.len;
|
||||
/// Hash files in parallel using thread pool (using FileEntry with full_path)
|
||||
fn hashFilesParallelEntries(self: *Self, pool: *std.Thread.Pool, entries: []FileEntry, combined_hasher: *crypto.hash.sha2.Sha256) !void {
|
||||
const num_files = entries.len;
|
||||
|
||||
// Allocate space for all hash results
|
||||
const hashes = try self.allocator.alloc([32]u8, num_files);
|
||||
defer self.allocator.free(hashes);
|
||||
// Allocate space for all hex hash results (64 chars each)
|
||||
const hex_hashes = try self.allocator.alloc([64]u8, num_files);
|
||||
defer self.allocator.free(hex_hashes);
|
||||
|
||||
// Create a WaitGroup for synchronization
|
||||
var wg = std.Thread.WaitGroup{};
|
||||
wg.reset();
|
||||
|
||||
// Submit jobs to thread pool
|
||||
for (paths, 0..) |path, i| {
|
||||
// Submit jobs to thread pool (use full_path for hashing)
|
||||
for (entries, 0..) |entry, i| {
|
||||
pool.spawnWg(&wg, struct {
|
||||
fn run(self_ptr: *Self, file_path: []const u8, hash_out: *[32]u8) void {
|
||||
fn run(self_ptr: *Self, file_path: []const u8, hash_out: *[64]u8) void {
|
||||
const hex_hash = self_ptr.hashFile(file_path) catch |err| {
|
||||
std.log.warn("Failed to hash {s}: {}", .{ file_path, err });
|
||||
return;
|
||||
};
|
||||
_ = std.fmt.hexToBytes(hash_out, &hex_hash) catch return;
|
||||
@memcpy(hash_out, &hex_hash);
|
||||
}
|
||||
}.run, .{ self, path, &hashes[i] });
|
||||
}.run, .{ self, entry.full_path, &hex_hashes[i] });
|
||||
}
|
||||
|
||||
// Wait for all jobs to complete
|
||||
pool.waitAndWork(&wg);
|
||||
|
||||
// Combine all hashes in order
|
||||
for (hashes) |hash| {
|
||||
combined_hasher.update(&hash);
|
||||
// Combine all hashes in order (use hex string directly)
|
||||
for (hex_hashes) |hex_hash| {
|
||||
combined_hasher.update(&hex_hash);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@@ -315,6 +330,7 @@ pub const HashError = error{
|
|||
AccessDenied,
|
||||
OutOfMemory,
|
||||
FileTooLarge,
|
||||
CurrentWorkingDirectoryUnlinked,
|
||||
};
|
||||
|
||||
/// Convenience: hash directory to hex string
|
||||
|
|
|
|||
Loading…
Reference in a new issue