diff --git a/cli/src/commands/run.zig b/cli/src/commands/run.zig index f4897de..7dd9457 100644 --- a/cli/src/commands/run.zig +++ b/cli/src/commands/run.zig @@ -3,6 +3,7 @@ const core = @import("../core.zig"); const config = @import("../config.zig"); const mode = @import("../mode.zig"); const common = @import("common.zig"); +const io = @import("../utils/io.zig"); const remote = @import("executor/remote.zig"); const local = @import("executor/local.zig"); @@ -292,23 +293,23 @@ fn handleRerun( // Inherit config if requested (resources, priority, dataset config) if (options.inherit_config or options.inherit_all) { if (options.cpu == 1) { // Only if not explicitly set - if (json.getInt(u8, obj, "cpu")) |cpu| { - options.cpu = cpu; + if (io.jsonGetInt(obj, "cpu")) |cpu| { + options.cpu = @intCast(cpu); } } if (options.memory == 4) { // Only if not explicitly set - if (json.getInt(u8, obj, "memory")) |mem| { - options.memory = mem; + if (io.jsonGetInt(obj, "memory")) |mem| { + options.memory = @intCast(mem); } } if (options.gpu == 0) { // Only if not explicitly set - if (json.getInt(u8, obj, "gpu")) |gpu| { - options.gpu = gpu; + if (io.jsonGetInt(obj, "gpu")) |gpu| { + options.gpu = @intCast(gpu); } } if (options.priority == 5) { // Only if not explicitly set - if (json.getInt(u8, obj, "priority")) |prio| { - options.priority = prio; + if (io.jsonGetInt(obj, "priority")) |prio| { + options.priority = @intCast(prio); } } if (options.gpu_memory == null) { diff --git a/cli/src/main.zig b/cli/src/main.zig index be80af5..3a4628c 100644 --- a/cli/src/main.zig +++ b/cli/src/main.zig @@ -16,14 +16,14 @@ pub fn main() !void { return; }; defer std.process.argsFree(allocator, args); + const command = if (args.len >= 2) args[1] else ""; - if (args.len < 2) { + // Handle help flags as valid commands + if (std.mem.eql(u8, command, "--help") or std.mem.eql(u8, command, "-h") or args.len < 2) { printUsage(); return; } - const command = args[1]; - // Fast dispatch using switch on first character switch
(command[0]) { 'j' => if (std.mem.eql(u8, command, "jupyter")) { diff --git a/cli/src/utils/hash.zig b/cli/src/utils/hash.zig index 6d12e27..3bcbfe1 100644 --- a/cli/src/utils/hash.zig +++ b/cli/src/utils/hash.zig @@ -129,48 +129,55 @@ pub const DatasetHash = struct { /// Hash entire directory (deterministic, reproducible) /// Algorithm: - /// 1. Collect all regular files recursively - /// 2. Sort paths lexicographically - /// 3. Hash each file + /// 1. Collect all regular files recursively (stores full path + relative path) + /// 2. Sort by relative paths lexicographically + /// 3. Hash each file using full path /// 4. Combine hashes: SHA256(hash1 + hash2 + ...) pub fn hashDirectory(self: *Self, dir_path: []const u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError || std.fs.File.ReadError || error{ InvalidCharacter, InvalidLength })!([64]u8) { // Security: validate directory path try validatePath(dir_path); - // Collect all files - var paths = try std.ArrayList([]const u8).initCapacity(self.allocator, 256); + // Collect all files (stores both full and relative paths) + var file_entries = try std.ArrayList(FileEntry).initCapacity(self.allocator, 256); defer { - for (paths.items) |p| self.allocator.free(p); - paths.deinit(self.allocator); + for (file_entries.items) |entry| { + self.allocator.free(entry.full_path); + self.allocator.free(entry.rel_path); + } + file_entries.deinit(self.allocator); } - try self.collectFiles(dir_path, &paths, 0); + try self.collectFiles(dir_path, dir_path, &file_entries, 0); - if (paths.items.len == 0) { - return error.EmptyDirectory; + if (file_entries.items.len == 0) { + // Empty directory: return SHA256 of empty string (same as Go/Native) + var empty_hasher = crypto.hash.sha2.Sha256.init(.{}); + var empty_hash_bytes: [32]u8 = undefined; + empty_hasher.final(&empty_hash_bytes); + var empty_result: [64]u8 = undefined; + bytesToHex(&empty_hash_bytes, &empty_result); + return empty_result; } - // Sort 
lexicographically for reproducibility - std.mem.sort([]const u8, paths.items, {}, struct { - fn lessThan(_: void, a: []const u8, b: []const u8) bool { - return std.mem.lessThan(u8, a, b); + // Sort by relative path lexicographically for reproducibility + std.mem.sort(FileEntry, file_entries.items, {}, struct { + fn lessThan(_: void, a: FileEntry, b: FileEntry) bool { + return std.mem.lessThan(u8, a.rel_path, b.rel_path); } }.lessThan); - // Hash all files and combine + // Hash all files and combine (using full_path for actual file access) var combined_hasher = crypto.hash.sha2.Sha256.init(.{}); if (self.thread_pool) |pool| { // Parallel hashing with thread pool - try self.hashFilesParallel(pool, paths.items, &combined_hasher); + try self.hashFilesParallelEntries(pool, file_entries.items, &combined_hasher); } else { - // Sequential hashing - for (paths.items) |path| { - const file_hash = try self.hashFile(path); - // Convert hex string to bytes and update combined hasher - var hash_bytes: [32]u8 = undefined; - _ = try std.fmt.hexToBytes(&hash_bytes, &file_hash); - combined_hasher.update(&hash_bytes); + // Sequential hashing - use hex string directly (not raw bytes) to match Go/Native + for (file_entries.items) |entry| { + const file_hash = try self.hashFile(entry.full_path); + // Hash the hex string representation (64 chars) to match Go/Native + combined_hasher.update(&file_hash); } } @@ -182,8 +189,15 @@ pub const DatasetHash = struct { return result; } + /// Entry for a file with both full and relative paths + const FileEntry = struct { + full_path: []const u8, + rel_path: []const u8, + }; + /// Collect all regular files recursively - fn collectFiles(self: *Self, dir_path: []const u8, paths: *std.ArrayList([]const u8), depth: u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError)!void { + /// Stores both full path (for hashing) and relative path (for sorting) + fn collectFiles(self: *Self, dir_path: []const u8, base_dir: []const 
u8, entries: *std.ArrayList(FileEntry), depth: u8) (HashError || std.fs.Dir.OpenError || std.fs.File.OpenError || std.fs.File.StatError)!void { // Security: max depth to prevent infinite recursion on cycles if (depth > 32) return error.MaxDepthExceeded; @@ -204,7 +218,6 @@ pub const DatasetHash = struct { switch (entry.kind) { .file => { // Security: validate it's a regular file (not symlink) - // Try to open without following symlinks const file = std.fs.cwd().openFile(full_path, .{ .mode = .read_only }) catch |err| { self.allocator.free(full_path); if (err == error.AccessDenied) continue; @@ -220,13 +233,15 @@ pub const DatasetHash = struct { // Security: only regular files (S_ISREG) if (stat.kind == .file) { - try paths.append(self.allocator, full_path); + // Compute relative path from base_dir for sorting + const rel_path = try std.fs.path.relative(self.allocator, base_dir, full_path); + try entries.append(self.allocator, .{ .full_path = full_path, .rel_path = rel_path }); } else { self.allocator.free(full_path); } }, .directory => { - try self.collectFiles(full_path, paths, depth + 1); + try self.collectFiles(full_path, base_dir, entries, depth + 1); self.allocator.free(full_path); }, else => { @@ -236,37 +251,37 @@ pub const DatasetHash = struct { } } - /// Hash files in parallel using thread pool - fn hashFilesParallel(self: *Self, pool: *std.Thread.Pool, paths: [][]const u8, combined_hasher: *crypto.hash.sha2.Sha256) !void { - const num_files = paths.len; + /// Hash files in parallel using thread pool (using FileEntry with full_path) + fn hashFilesParallelEntries(self: *Self, pool: *std.Thread.Pool, entries: []FileEntry, combined_hasher: *crypto.hash.sha2.Sha256) !void { + const num_files = entries.len; - // Allocate space for all hash results - const hashes = try self.allocator.alloc([32]u8, num_files); - defer self.allocator.free(hashes); + // Allocate space for all hex hash results (64 chars each) + const hex_hashes = try self.allocator.alloc([64]u8, 
num_files); for (hex_hashes) |*h| @memset(h, '0'); + defer self.allocator.free(hex_hashes); // Create a WaitGroup for synchronization var wg = std.Thread.WaitGroup{}; wg.reset(); - // Submit jobs to thread pool - for (paths, 0..) |path, i| { + // Submit jobs to thread pool (use full_path for hashing) + for (entries, 0..) |entry, i| { pool.spawnWg(&wg, struct { - fn run(self_ptr: *Self, file_path: []const u8, hash_out: *[32]u8) void { + fn run(self_ptr: *Self, file_path: []const u8, hash_out: *[64]u8) void { const hex_hash = self_ptr.hashFile(file_path) catch |err| { std.log.warn("Failed to hash {s}: {}", .{ file_path, err }); return; }; - _ = std.fmt.hexToBytes(hash_out, &hex_hash) catch return; + @memcpy(hash_out, &hex_hash); } - }.run, .{ self, path, &hashes[i] }); + }.run, .{ self, entry.full_path, &hex_hashes[i] }); } // Wait for all jobs to complete pool.waitAndWork(&wg); - // Combine all hashes in order - for (hashes) |hash| { - combined_hasher.update(&hash); + // Combine all hashes in order (use hex string directly) + for (hex_hashes) |hex_hash| { + combined_hasher.update(&hex_hash); } } @@ -315,6 +330,7 @@ pub const HashError = error{ AccessDenied, OutOfMemory, FileTooLarge, + CurrentWorkingDirectoryUnlinked, }; /// Convenience: hash directory to hex string