From d3461cd07fbf4a883cba8616f03ed6b7908abec2 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Fri, 20 Feb 2026 21:28:34 -0500 Subject: [PATCH] feat(cli): Update server integration commands - queue.zig: Add --rerun flag to re-queue completed local runs - Requires server connection, rejects in offline mode with clear error - HandleRerun function sends rerun request via WebSocket - sync.zig: Rewrite for WebSocket experiment sync protocol - Queries unsynced runs from SQLite ml_runs table - Builds sync JSON with metrics and params - Sends sync_run message, waits for sync_ack response - MarkRunSynced updates synced flag in database - watch.zig: Add --sync flag for continuous experiment sync - Auto-sync runs to server every 30 seconds when online - Mode detection with offline error handling --- cli/src/commands/queue.zig | 85 +++++- cli/src/commands/queue/index.zig | 3 + cli/src/commands/queue/parse.zig | 177 +++++++++++ cli/src/commands/queue/submit.zig | 200 +++++++++++++ cli/src/commands/queue/validate.zig | 161 ++++++++++ cli/src/commands/sync.zig | 440 +++++++++++++++------------- cli/src/commands/watch.zig | 154 ++++------ 7 files changed, 913 insertions(+), 307 deletions(-) create mode 100644 cli/src/commands/queue/index.zig create mode 100644 cli/src/commands/queue/parse.zig create mode 100644 cli/src/commands/queue/submit.zig create mode 100644 cli/src/commands/queue/validate.zig diff --git a/cli/src/commands/queue.zig b/cli/src/commands/queue.zig index b19a6b7..ff03943 100644 --- a/cli/src/commands/queue.zig +++ b/cli/src/commands/queue.zig @@ -6,6 +6,9 @@ const history = @import("../utils/history.zig"); const crypto = @import("../utils/crypto.zig"); const protocol = @import("../net/protocol.zig"); const stdcrypto = std.crypto; +const mode = @import("../mode.zig"); +const db = @import("../db.zig"); +const manifest_lib = @import("../manifest.zig"); pub const TrackingConfig = struct { mlflow: ?MLflowConfig = null, @@ -103,6 +106,45 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { return; } + // Load config for mode detection + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Detect mode early to provide clear error for offline + const mode_result = try mode.detect(allocator, config); + + // Check for --rerun flag + var rerun_id: ?[]const u8 = null; + for (args, 0..) |arg, i| { + if (std.mem.eql(u8, arg, "--rerun") and i + 1 < args.len) { + rerun_id = args[i + 1]; + break; + } + } + + // If --rerun is specified, handle re-queueing + if (rerun_id) |id| { + if (mode.isOffline(mode_result.mode)) { + colors.printError("ml queue --rerun requires server connection\n", .{}); + return error.RequiresServer; + } + return try handleRerun(allocator, id, args, config); + } + + // Regular queue - requires server + if (mode.isOffline(mode_result.mode)) { + colors.printError("ml queue requires server connection (use 'ml run' for local execution)\n", .{}); + return error.RequiresServer; + } + + // Continue with regular queue logic... + try executeQueue(allocator, args, config); +} + +fn executeQueue(allocator: std.mem.Allocator, args: []const []const u8, config: Config) !void { // Support batch operations - multiple job names var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { colors.printError("Failed to allocate job list: {}\n", .{err}); @@ -117,13 +159,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var args_override: ?[]const u8 = null; var note_override: ?[]const u8 = null; - // Load configuration to get defaults - const config = try Config.load(allocator); - defer { - var mut_config = config; - mut_config.deinit(allocator); - } - // Initialize options with config defaults var options = QueueOptions{ .cpu = config.default_cpu, @@ -391,6 +426,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } } +/// Handle --rerun flag: re-queue a completed run +fn handleRerun(allocator: std.mem.Allocator, run_id: []const u8, args: []const []const u8, cfg: Config) !void { + _ = args; // Override args not implemented yet + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try cfg.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + // Send rerun request to server + try client.sendRerunRequest(run_id, api_key_hash); + + // Wait for response + const message = try client.receiveMessage(allocator); + defer allocator.free(message); + + // Parse response (simplified) + if (std.mem.indexOf(u8, message, "success") != null) { + colors.printSuccess("✓ Re-queued run {s}\n", .{run_id[0..8]}); + } else { + colors.printError("Failed to re-queue: {s}\n", .{message}); + return error.RerunFailed; + } +} + fn generateCommitID(allocator: std.mem.Allocator) ![]const u8 { var bytes: [20]u8 = undefined; stdcrypto.random.bytes(&bytes); @@ -621,6 +685,7 @@ fn queueSingleJob( fn printUsage() !void { colors.printInfo("Usage: ml queue [job-name ...] [options]\n", .{}); + colors.printInfo(" ml queue --rerun # Re-queue a completed run\n", .{}); colors.printInfo("\nBasic Options:\n", .{}); colors.printInfo(" --commit Specify commit ID\n", .{}); colors.printInfo(" --priority Set priority (0-255, default: 5)\n", .{}); @@ -640,6 +705,7 @@ fn printUsage() !void { colors.printInfo(" --experiment-group Group related experiments\n", .{}); colors.printInfo(" --tags Comma-separated tags (e.g., ablation,batch-size)\n", .{}); colors.printInfo("\nSpecial Modes:\n", .{}); + colors.printInfo(" --rerun Re-queue a completed local run to server\n", .{}); colors.printInfo(" --dry-run Show what would be queued\n", .{}); colors.printInfo(" --validate Validate experiment without queuing\n", .{}); colors.printInfo(" --explain Explain what will happen\n", .{}); @@ -662,10 +728,11 @@ fn printUsage() !void { colors.printInfo(" ml queue my_job # Queue a job\n", .{}); colors.printInfo(" ml queue my_job --dry-run # Preview submission\n", .{}); colors.printInfo(" ml queue my_job --validate # Validate locally\n", .{}); + colors.printInfo(" ml queue --rerun abc123 # Re-queue completed run\n", .{}); colors.printInfo(" ml status --watch # Watch queue + prewarm\n", .{}); colors.printInfo("\nResearch Examples:\n", .{}); - colors.printInfo(" ml queue train.py --hypothesis \"LR scaling improves convergence\" \\\n", .{}); - colors.printInfo(" --context \"Following paper XYZ\" --tags ablation,lr-scaling\n", .{}); + colors.printInfo(" ml queue train.py --hypothesis 'LR scaling improves convergence' \n", .{}); + colors.printInfo(" --context 'Following paper XYZ' --tags ablation,lr-scaling\n", .{}); } pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 { diff --git a/cli/src/commands/queue/index.zig b/cli/src/commands/queue/index.zig new file mode 100644 index 0000000..2903c81 --- /dev/null +++ b/cli/src/commands/queue/index.zig @@ -0,0 +1,3 @@ +pub const parse = @import("queue/parse.zig"); +pub const validate = @import("queue/validate.zig"); +pub const submit = @import("queue/submit.zig"); diff --git a/cli/src/commands/queue/parse.zig b/cli/src/commands/queue/parse.zig new file mode 100644 index 0000000..f8a0f1b --- /dev/null +++ b/cli/src/commands/queue/parse.zig @@ -0,0 +1,177 @@ +const std = @import("std"); + +/// Parse job template from command line arguments +pub const JobTemplate = struct { + job_names: std.ArrayList([]const u8), + commit_id_override: ?[]const u8, + priority: u8, + snapshot_id: ?[]const u8, + snapshot_sha256: ?[]const u8, + args_override: ?[]const u8, + note_override: ?[]const u8, + cpu: u8, + memory: u8, + gpu: u8, + gpu_memory: ?[]const u8, + dry_run: bool, + validate: bool, + explain: bool, + json: bool, + force: bool, + runner_args_start: ?usize, + + pub fn init(allocator: std.mem.Allocator) JobTemplate { + return .{ + .job_names = std.ArrayList([]const u8).init(allocator), + .commit_id_override = null, + .priority = 5, + .snapshot_id = null, + .snapshot_sha256 = null, + .args_override = null, + .note_override = null, + .cpu = 2, + .memory = 8, + .gpu = 0, + .gpu_memory = null, + .dry_run = false, + .validate = false, + .explain = false, + .json = false, + .force = false, + .runner_args_start = null, + }; + } + + pub fn deinit(self: *JobTemplate, allocator: std.mem.Allocator) void { + self.job_names.deinit(allocator); + } +}; + +/// Parse command arguments into a job template +pub fn parseArgs(allocator: std.mem.Allocator, args: []const []const u8) !JobTemplate { + var template = JobTemplate.init(allocator); + errdefer template.deinit(allocator); + + var i: usize = 0; + while (i < args.len) : (i += 1) { + const arg = args[i]; + + if (std.mem.eql(u8, arg, "--")) { + template.runner_args_start = i + 1; + break; + } else if (std.mem.eql(u8, arg, "--commit-id")) { + if (i + 1 < args.len) { + template.commit_id_override = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--priority")) { + if (i + 1 < args.len) { + template.priority = std.fmt.parseInt(u8, args[i + 1], 10) catch 5; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--snapshot")) { + if (i + 1 < args.len) { + template.snapshot_id = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--snapshot-sha256")) { + if (i + 1 < args.len) { + template.snapshot_sha256 = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--args")) { + if (i + 1 < args.len) { + template.args_override = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--note")) { + if (i + 1 < args.len) { + template.note_override = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--cpu")) { + if (i + 1 < args.len) { + template.cpu = std.fmt.parseInt(u8, args[i + 1], 10) catch 2; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--memory")) { + if (i + 1 < args.len) { + template.memory = std.fmt.parseInt(u8, args[i + 1], 10) catch 8; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--gpu")) { + if (i + 1 < args.len) { + template.gpu = std.fmt.parseInt(u8, args[i + 1], 10) catch 0; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--gpu-memory")) { + if (i + 1 < args.len) { + template.gpu_memory = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--dry-run")) { + template.dry_run = true; + } else if (std.mem.eql(u8, arg, "--validate")) { + template.validate = true; + } else if (std.mem.eql(u8, arg, "--explain")) { + template.explain = true; + } else if (std.mem.eql(u8, arg, "--json")) { + template.json = true; + } else if (std.mem.eql(u8, arg, "--force")) { + template.force = true; + } else if (!std.mem.startsWith(u8, arg, "-")) { + // Positional argument - job name + try template.job_names.append(arg); + } + } + + return template; +} + +/// Get runner args from the parsed template +pub fn getRunnerArgs(self: JobTemplate, all_args: []const []const u8) []const []const u8 { + if (self.runner_args_start) |start| { + if (start < all_args.len) { + return all_args[start..]; + } + } + return &[_][]const u8{}; +} + +/// Resolve commit ID from prefix or full hash +pub fn resolveCommitId(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 { + if (input.len < 7 or input.len > 40) return error.InvalidArgs; + for (input) |c| { + if (!std.ascii.isHex(c)) return error.InvalidArgs; + } + + if (input.len == 40) { + return allocator.dupe(u8, input); + } + + var dir = if (std.fs.path.isAbsolute(base_path)) + try std.fs.openDirAbsolute(base_path, .{ .iterate = true }) + else + try std.fs.cwd().openDir(base_path, .{ .iterate = true }); + defer dir.close(); + + var it = dir.iterate(); + var found: ?[]u8 = null; + errdefer if (found) |s| allocator.free(s); + + while (try it.next()) |entry| { + if (entry.kind != .directory) continue; + const name = entry.name; + if (name.len != 40) continue; + if (!std.mem.startsWith(u8, name, input)) continue; + for (name) |c| { + if (!std.ascii.isHex(c)) break; + } else { + if (found != null) return error.InvalidArgs; + found = try allocator.dupe(u8, name); + } + } + + if (found) |s| return s; + return error.FileNotFound; +} diff --git a/cli/src/commands/queue/submit.zig b/cli/src/commands/queue/submit.zig new file mode 100644 index 0000000..633e2bd --- /dev/null +++ b/cli/src/commands/queue/submit.zig @@ -0,0 +1,200 @@ +const std = @import("std"); +const ws = @import("../../net/ws/client.zig"); +const protocol = @import("../../net/protocol.zig"); +const crypto = @import("../../utils/crypto.zig"); +const Config = @import("../../config.zig").Config; +const core = @import("../../core.zig"); +const history = @import("../../utils/history.zig"); + +/// Job submission configuration +pub const SubmitConfig = struct { + job_names: []const []const u8, + commit_id: ?[]const u8, + priority: u8, + snapshot_id: ?[]const u8, + snapshot_sha256: ?[]const u8, + args_override: ?[]const u8, + note_override: ?[]const u8, + cpu: u8, + memory: u8, + gpu: u8, + gpu_memory: ?[]const u8, + dry_run: bool, + force: bool, + runner_args: []const []const u8, + + pub fn estimateTotalJobs(self: SubmitConfig) usize { + return self.job_names.len; + } +}; + +/// Submission result +pub const SubmitResult = struct { + success: bool, + job_count: usize, + errors: std.ArrayList([]const u8), + + pub fn init(allocator: std.mem.Allocator) SubmitResult { + return .{ + .success = true, + .job_count = 0, + .errors = std.ArrayList([]const u8).init(allocator), + }; + } + + pub fn deinit(self: *SubmitResult, allocator: std.mem.Allocator) void { + for (self.errors.items) |err| { + allocator.free(err); + } + self.errors.deinit(allocator); + } +}; + +/// Submit jobs to the server +pub fn submitJobs( + allocator: std.mem.Allocator, + config: Config, + submit_config: SubmitConfig, + json: bool, +) !SubmitResult { + var result = SubmitResult.init(allocator); + errdefer result.deinit(allocator); + + // Dry run mode - just print what would be submitted + if (submit_config.dry_run) { + if (json) { + std.debug.print("{{\"success\":true,\"command\":\"queue.submit\",\"dry_run\":true,\"jobs\":[", .{}); + for (submit_config.job_names, 0..) |name, i| { + if (i > 0) std.debug.print(",", .{}); + std.debug.print("\"{s}\"", .{name}); + } + std.debug.print("],\"total\":{d}}}}}\n", .{submit_config.job_names.len}); + } else { + std.debug.print("[DRY RUN] Would submit {d} jobs:\n", .{submit_config.job_names.len}); + for (submit_config.job_names) |name| { + std.debug.print(" - {s}\n", .{name}); + } + } + result.job_count = submit_config.job_names.len; + return result; + } + + // Get WebSocket URL + const ws_url = try config.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + // Hash API key + const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); + defer allocator.free(api_key_hash); + + // Connect to server + var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| { + const msg = try std.fmt.allocPrint(allocator, "Failed to connect: {}", .{err}); + result.addError(msg); + result.success = false; + return result; + }; + defer client.close(); + + // Submit each job + for (submit_config.job_names) |job_name| { + submitSingleJob( + allocator, + &client, + api_key_hash, + job_name, + submit_config, + &result, + ) catch |err| { + const msg = try std.fmt.allocPrint(allocator, "Failed to submit {s}: {}", .{ job_name, err }); + result.addError(msg); + result.success = false; + }; + } + + // Save to history if successful + if (result.success and result.job_count > 0) { + if (submit_config.commit_id) |commit_id| { + for (submit_config.job_names) |job_name| { + history.saveEntry(allocator, job_name, commit_id) catch {}; + } + } + } + + return result; +} + +/// Submit a single job +fn submitSingleJob( + allocator: std.mem.Allocator, + client: *ws.Client, + _: []const u8, + job_name: []const u8, + submit_config: SubmitConfig, + result: *SubmitResult, +) !void { + // Build job submission payload + var payload = std.ArrayList(u8).init(allocator); + defer payload.deinit(); + + const writer = payload.writer(); + try writer.print( + "{{\"job_name\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory\":{d},\"gpu\":{d}", + .{ job_name, submit_config.priority, submit_config.cpu, submit_config.memory, submit_config.gpu }, + ); + + if (submit_config.gpu_memory) |gm| { + try writer.print(",\"gpu_memory\":\"{s}\"", .{gm}); + } + + try writer.print("}}", .{}); + + if (submit_config.commit_id) |cid| { + try writer.print(",\"commit_id\":\"{s}\"", .{cid}); + } + + if (submit_config.snapshot_id) |sid| { + try writer.print(",\"snapshot_id\":\"{s}\"", .{sid}); + } + + if (submit_config.note_override) |note| { + try writer.print(",\"note\":\"{s}\"", .{note}); + } + + try writer.print("}}", .{}); + + // Send job submission + client.sendMessage(payload.items) catch |err| { + return err; + }; + + result.job_count += 1; +} + +/// Print submission results +pub fn printResults(result: SubmitResult, json: bool) void { + if (json) { + const status = if (result.success) "true" else "false"; + std.debug.print("{{\"success\":{s},\"command\":\"queue.submit\",\"data\":{{\"submitted\":{d}", .{ status, result.job_count }); + + if (result.errors.items.len > 0) { + std.debug.print(",\"errors\":[", .{}); + for (result.errors.items, 0..) |err, i| { + if (i > 0) std.debug.print(",", .{}); + std.debug.print("\"{s}\"", .{err}); + } + std.debug.print("]", .{}); + } + + std.debug.print("}}}}\n", .{}); + } else { + if (result.success) { + std.debug.print("Successfully submitted {d} jobs\n", .{result.job_count}); + } else { + std.debug.print("Failed to submit jobs ({d} errors)\n", .{result.errors.items.len}); + for (result.errors.items) |err| { + std.debug.print(" Error: {s}\n", .{err}); + } + } + } +} diff --git a/cli/src/commands/queue/validate.zig b/cli/src/commands/queue/validate.zig new file mode 100644 index 0000000..707233c --- /dev/null +++ b/cli/src/commands/queue/validate.zig @@ -0,0 +1,161 @@ +const std = @import("std"); + +/// Validation errors for queue operations +pub const ValidationError = error{ + MissingJobName, + InvalidCommitId, + InvalidSnapshotId, + InvalidResourceLimits, + DuplicateJobName, + InvalidPriority, +}; + +/// Validation result +pub const ValidationResult = struct { + valid: bool, + errors: std.ArrayList([]const u8), + + pub fn init(allocator: std.mem.Allocator) ValidationResult { + return .{ + .valid = true, + .errors = std.ArrayList([]const u8).init(allocator), + }; + } + + pub fn deinit(self: *ValidationResult, allocator: std.mem.Allocator) void { + for (self.errors.items) |err| { + allocator.free(err); + } + self.errors.deinit(allocator); + } + + pub fn addError(self: *ValidationResult, allocator: std.mem.Allocator, msg: []const u8) void { + self.valid = false; + const copy = allocator.dupe(u8, msg) catch return; + self.errors.append(copy) catch { + allocator.free(copy); + }; + } +}; + +/// Validate job name format +pub fn validateJobName(name: []const u8) bool { + if (name.len == 0 or name.len > 128) return false; + + for (name) |c| { + if (!std.ascii.isAlphanumeric(c) and c != '_' and c != '-' and c != '.') { + return false; + } + } + return true; +} + +/// Validate commit ID format (40 character hex) +pub fn validateCommitId(id: []const u8) bool { + if (id.len != 40) return false; + for (id) |c| { + if (!std.ascii.isHex(c)) return false; + } + return true; +} + +/// Validate snapshot ID format +pub fn validateSnapshotId(id: []const u8) bool { + if (id.len == 0 or id.len > 64) return false; + for (id) |c| { + if (!std.ascii.isAlphanumeric(c) and c != '_' and c != '-' and c != '.') { + return false; + } + } + return true; +} + +/// Validate resource limits +pub fn validateResources(cpu: u8, memory: u8, gpu: u8) ValidationError!void { + if (cpu == 0 or cpu > 128) { + return error.InvalidResourceLimits; + } + if (memory == 0 or memory > 1024) { + return error.InvalidResourceLimits; + } + if (gpu > 16) { + return error.InvalidResourceLimits; + } +} + +/// Validate priority value (1-10) +pub fn validatePriority(priority: u8) ValidationError!void { + if (priority < 1 or priority > 10) { + return error.InvalidPriority; + } +} + +/// Full validation for job template +pub fn validateJobTemplate( + allocator: std.mem.Allocator, + job_names: []const []const u8, + commit_id: ?[]const u8, + cpu: u8, + memory: u8, + gpu: u8, +) !ValidationResult { + var result = ValidationResult.init(allocator); + errdefer result.deinit(allocator); + + // Check job names + if (job_names.len == 0) { + result.addError(allocator, "At least one job name is required"); + return result; + } + + // Check for duplicates + var seen = std.StringHashMap(void).init(allocator); + defer seen.deinit(); + + for (job_names) |name| { + if (!validateJobName(name)) { + const msg = try std.fmt.allocPrint(allocator, "Invalid job name: {s}", .{name}); + result.addError(allocator, msg); + allocator.free(msg); + } + + if (seen.contains(name)) { + const msg = try std.fmt.allocPrint(allocator, "Duplicate job name: {s}", .{name}); + result.addError(allocator, msg); + allocator.free(msg); + } else { + try seen.put(name, {}); + } + } + + // Validate commit ID if provided + if (commit_id) |id| { + if (!validateCommitId(id)) { + result.addError(allocator, "Invalid commit ID format (expected 40 character hex)"); + } + } + + // Validate resources + validateResources(cpu, memory, gpu) catch { + result.addError(allocator, "Invalid resource limits"); + }; + + return result; +} + +/// Print validation errors +pub fn printValidationErrors(result: ValidationResult, json: bool) void { + if (json) { + std.debug.print("{{\"success\":false,\"command\":\"queue.validate\",\"errors\":[", .{}); + for (result.errors.items, 0..) |err, i| { + if (i > 0) std.debug.print(",", .{}); + std.debug.print("\"{s}\"", .{err}); + } + std.debug.print("]}}\n", .{}); + } else { + std.debug.print("Validation failed:\n", .{}); + for (result.errors.items) |err| { + std.debug.print(" - {s}\n", .{err}); + } + } +} diff --git a/cli/src/commands/sync.zig b/cli/src/commands/sync.zig index 0ad4d87..6ef9737 100644 --- a/cli/src/commands/sync.zig +++ b/cli/src/commands/sync.zig @@ -1,236 +1,262 @@ const std = @import("std"); const colors = @import("../utils/colors.zig"); -const Config = @import("../config.zig").Config; -const crypto = @import("../utils/crypto.zig"); -const rsync = @import("../utils/rsync_embedded.zig"); +const config = @import("../config.zig"); +const db = @import("../db.zig"); const ws = @import("../net/ws/client.zig"); -const logging = @import("../utils/logging.zig"); -const json = @import("../utils/json.zig"); -const native_hash = @import("../utils/native_hash.zig"); -const ProgressBar = @import("../ui/progress.zig").ProgressBar; +const crypto = @import("../utils/crypto.zig"); +const mode = @import("../mode.zig"); +const core = @import("../core.zig"); +const manifest_lib = @import("../manifest.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { - if (args.len == 0) { - printUsage(); - return error.InvalidArgs; - } + var flags = core.flags.CommonFlags{}; + var specific_run_id: ?[]const u8 = null; - // Global flags for (args) |arg| { if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { - printUsage(); + return printUsage(); + } else if (std.mem.eql(u8, arg, "--json")) { + flags.json = true; + } else if (!std.mem.startsWith(u8, arg, "--")) { + specific_run_id = arg; + } + } + + core.output.init(if (flags.json) .json else .text); + + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + const mode_result = try mode.detect(allocator, cfg); + if (mode.isOffline(mode_result.mode)) { + colors.printError("ml sync requires server connection\n", .{}); + return error.RequiresServer; + } + + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + var runs_to_sync = std.ArrayList(RunInfo).init(allocator); + defer { + for (runs_to_sync.items) |*r| r.deinit(allocator); + runs_to_sync.deinit(); + } + + if (specific_run_id) |run_id| { + const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + try db.DB.bindText(stmt, 1, run_id); + if (try db.DB.step(stmt)) { + try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator)); + } else { + colors.printWarning("Run {s} already synced or not found\n", .{run_id}); return; } - } - - const path = args[0]; - var job_name: ?[]const u8 = null; - var should_queue = false; - var priority: u8 = 5; - var json_mode: bool = false; - var dev_mode: bool = false; - var use_timestamp_check = false; - var dry_run = false; - - // Parse flags - var i: usize = 1; - while (i < args.len) : (i += 1) { - if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) { - job_name = args[i + 1]; - i += 1; - } else if (std.mem.eql(u8, args[i], "--queue")) { - should_queue = true; - } else if (std.mem.eql(u8, args[i], "--json")) { - json_mode = true; - } else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) { - priority = try std.fmt.parseInt(u8, args[i + 1], 10); - i += 1; - } else if (std.mem.eql(u8, args[i], "--dev")) { - dev_mode = true; - } else if (std.mem.eql(u8, args[i], "--check-timestamp")) { - use_timestamp_check = true; - } else if (std.mem.eql(u8, args[i], "--dry-run")) { - dry_run = true; - } - } - - const config = try Config.load(allocator); - defer { - var mut_config = config; - mut_config.deinit(allocator); - } - - // Detect if path is a subdirectory by finding git root - const git_root = try findGitRoot(allocator, path); - defer if (git_root) |gr| allocator.free(gr); - - const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path); - const relative_path = if (is_subdir) blk: { - // Get relative path from git root to the specified path - break :blk try std.fs.path.relative(allocator, git_root.?, path); - } else null; - defer if (relative_path) |rp| allocator.free(rp); - - // Determine commit_id and remote path based on mode - const commit_id: []const u8 = if (dev_mode) blk: { - // Dev mode: skip expensive hashing, use fixed "dev" commit - break :blk "dev"; - } else blk: { - // Production mode: calculate SHA256 of directory tree (always from git root) - const hash_base = git_root orelse path; - break :blk try crypto.hashDirectory(allocator, hash_base); - }; - defer if (!dev_mode) allocator.free(commit_id); - - // In dev mode, sync to {worker_base}/dev/files/ instead of hashed path - // For subdirectories, append the relative path to the remote destination - const remote_path = if (dev_mode) blk: { - if (is_subdir) { - break :blk try std.fmt.allocPrint( - allocator, - "{s}@{s}:{s}/dev/files/{s}/", - .{ config.api_key, config.worker_host, config.worker_base, relative_path.? }, - ); - } else { - break :blk try std.fmt.allocPrint( - allocator, - "{s}@{s}:{s}/dev/files/", - .{ config.api_key, config.worker_host, config.worker_base }, - ); - } - } else blk: { - if (is_subdir) { - break :blk try std.fmt.allocPrint( - allocator, - "{s}@{s}:{s}/{s}/files/{s}/", - .{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? }, - ); - } else { - break :blk try std.fmt.allocPrint( - allocator, - "{s}@{s}:{s}/{s}/files/", - .{ config.api_key, config.worker_host, config.worker_base, commit_id }, - ); - } - }; - defer allocator.free(remote_path); - - // Sync using embedded rsync (no external binary needed) - try rsync.sync(allocator, path, remote_path, config.worker_port); - - if (json_mode) { - std.debug.print("{\"ok\":true,\"action\":\"sync\",\"commit_id\":\"{s}\"}\n", .{commit_id}); } else { - colors.printSuccess("✓ Files synced to server\n", .{}); - } - - // If queue flag is set, queue the job - if (should_queue) { - const queue_cmd = @import("queue.zig"); - const actual_job_name = job_name orelse commit_id[0..8]; - const queue_args = [_][]const u8{ actual_job_name, "--commit", commit_id, "--priority", try std.fmt.allocPrint(allocator, "{d}", .{priority}) }; - defer allocator.free(queue_args[queue_args.len - 1]); - try queue_cmd.run(allocator, &queue_args); - } - - // Optional: Connect to server for progress monitoring if --monitor flag is used - var monitor_progress = false; - for (args[1..]) |arg| { - if (std.mem.eql(u8, arg, "--monitor")) { - monitor_progress = true; - break; + const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + while (try db.DB.step(stmt)) { + try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator)); } } - if (monitor_progress) { - std.debug.print("\nMonitoring sync progress...\n", .{}); - try monitorSyncProgress(allocator, &config, commit_id); + if (runs_to_sync.items.len == 0) { + if (!flags.json) colors.printSuccess("All runs already synced!\n", .{}); + return; } -} -fn printUsage() void { - logging.err("Usage: ml sync [options]\n\n", .{}); - logging.err("Options:\n", .{}); - logging.err(" --name Override job name when used with --queue\n", .{}); - logging.err(" --queue Queue the job after syncing\n", .{}); - logging.err(" --priority Priority to use when queueing (default: 5)\n", .{}); - logging.err(" --monitor Wait and show basic sync progress\n", .{}); - logging.err(" --json Output machine-readable JSON (sync result only)\n", .{}); - logging.err(" --dev Dev mode: skip hashing, use fixed path (fast)\n", .{}); - logging.err(" --check-timestamp Skip files unchanged since last sync\n", .{}); - logging.err(" --dry-run Show what would be synced without transferring\n", .{}); - logging.err(" --help, -h Show this help message\n", .{}); -} + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); -fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, commit_id: []const u8) !void { - _ = commit_id; - // Use plain password for WebSocket authentication - const api_key_plain = config.api_key; - - // Connect to server with retry logic - const ws_url = try config.getWebSocketUrl(allocator); + const ws_url = try cfg.getWebSocketUrl(allocator); defer allocator.free(ws_url); - logging.info("Connecting to server {s}...\n", .{ws_url}); - var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3); - defer client.disconnect(); + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); - // Initialize progress bar (will be updated as we receive progress messages) - var progress = ProgressBar.init(100, "Syncing files"); - - var timeout_counter: u32 = 0; - const max_timeout = 30; // 30 seconds timeout - - while (timeout_counter < max_timeout) { - const message = client.receiveMessage(allocator) catch |err| { - switch (err) { - error.ConnectionClosed, error.ConnectionTimedOut => { - timeout_counter += 1; - std.Thread.sleep(1 * std.time.ns_per_s); - continue; - }, - else => return err, - } + var success_count: usize = 0; + for (runs_to_sync.items) |run_info| { + if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{run_info.run_id[0..8]}); + syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| { + if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ run_info.run_id[0..8], err }); + continue; }; - defer allocator.free(message); - - // Parse JSON progress message using shared utilities - const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch { - logging.success("Sync progress: {s}\n", .{message}); - break; - }; - defer parsed.deinit(); - - if (parsed.value == .object) { - const root = parsed.value.object; - const status = json.getString(root, "status") orelse "unknown"; - const current = json.getInt(root, "progress") orelse 0; - const total = json.getInt(root, "total") orelse 100; - - if (std.mem.eql(u8, status, "complete")) { - progress.finish(); - colors.printSuccess("Sync complete!\n", .{}); - break; - } else if (std.mem.eql(u8, status, "error")) { - const error_msg = json.getString(root, "error") orelse "Unknown error"; - colors.printError("Sync failed: {s}\n", .{error_msg}); - return error.SyncFailed; - } else { - // Update progress bar - progress.total = @intCast(total); - progress.update(@intCast(current)); - } - } else { - logging.success("Sync progress: {s}\n", .{message}); - break; - } + const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;"; + const update_stmt = try database.prepare(update_sql); + defer db.DB.finalize(update_stmt); + try db.DB.bindText(update_stmt, 1, run_info.run_id); + _ = try db.DB.step(update_stmt); + success_count += 1; } - if (timeout_counter >= max_timeout) { - std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{}); + database.checkpointOnExit(); + + if (flags.json) { + std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len }); + } else { + colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len }); } } +const RunInfo = struct { + run_id: []const u8, + experiment_id: []const u8, + name: []const u8, + status: []const u8, + start_time: []const u8, + end_time: ?[]const u8, + + fn fromStmt(stmt: *anyopaque, allocator: std.mem.Allocator) !RunInfo { + return RunInfo{ + .run_id = try allocator.dupe(u8, db.DB.columnText(stmt, 0)), + .experiment_id = try allocator.dupe(u8, db.DB.columnText(stmt, 1)), + .name = try allocator.dupe(u8, db.DB.columnText(stmt, 2)), + .status = try allocator.dupe(u8, db.DB.columnText(stmt, 3)), + .start_time = try allocator.dupe(u8, db.DB.columnText(stmt, 4)), + .end_time = if (db.DB.columnText(stmt, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(stmt, 5)) else null, + }; + } + + fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void { + allocator.free(self.run_id); + allocator.free(self.experiment_id); + allocator.free(self.name); + allocator.free(self.status); + allocator.free(self.start_time); + if (self.end_time) |et| allocator.free(et); + } +}; + +fn syncRun( + allocator: std.mem.Allocator, + database: *db.DB, + client: *ws.Client, + run_info: RunInfo, + api_key_hash: []const u8, +) !void { + // Get metrics for this run + var metrics = std.ArrayList(Metric).init(allocator); + defer { + for (metrics.items) |*m| m.deinit(allocator); + metrics.deinit(); + } + + const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;"; + const metrics_stmt = try database.prepare(metrics_sql); + defer db.DB.finalize(metrics_stmt); + try db.DB.bindText(metrics_stmt, 1, run_info.run_id); + + while (try db.DB.step(metrics_stmt)) { + try metrics.append(Metric{ + .key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)), + .value = db.DB.columnDouble(metrics_stmt, 1), + .step = db.DB.columnInt64(metrics_stmt, 2), + }); + } + + // Get params for this run + var params = std.ArrayList(Param).init(allocator); + defer { + for (params.items) |*p| p.deinit(allocator); + params.deinit(); + } + + const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;"; + const params_stmt = try database.prepare(params_sql); + defer db.DB.finalize(params_stmt); + try db.DB.bindText(params_stmt, 1, run_info.run_id); + + while (try db.DB.step(params_stmt)) { + try params.append(Param{ + .key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)), + .value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)), + }); + } + + // Build sync JSON + var sync_json = std.ArrayList(u8).init(allocator); + defer sync_json.deinit(); + const writer = sync_json.writer(allocator); + + try writer.writeAll("{"); + try writer.print("\"run_id\":\"{s}\",", .{run_info.run_id}); + try writer.print("\"experiment_id\":\"{s}\",", .{run_info.experiment_id}); + try writer.print("\"name\":\"{s}\",", .{run_info.name}); + try writer.print("\"status\":\"{s}\",", .{run_info.status}); + try writer.print("\"start_time\":\"{s}\",", .{run_info.start_time}); + if (run_info.end_time) |et| { + try writer.print("\"end_time\":\"{s}\",", .{et}); + } else { + try writer.writeAll("\"end_time\":null,"); + } + + // Add metrics + try writer.writeAll("\"metrics\":["); + for (metrics.items, 0..) |m, i| { + if (i > 0) try writer.writeAll(","); + try writer.print("{{\"key\":\"{s}\",\"value\":{d},\"step\":{d}}}", .{ m.key, m.value, m.step }); + } + try writer.writeAll("],"); + + // Add params + try writer.writeAll("\"params\":["); + for (params.items, 0..) |p, i| { + if (i > 0) try writer.writeAll(","); + try writer.print("{{\"key\":\"{s}\",\"value\":\"{s}\"}}", .{ p.key, p.value }); + } + try writer.writeAll("]}"); + + // Send sync_run message + try client.sendSyncRun(sync_json.items, api_key_hash); + + // Wait for sync_ack + const response = try client.receiveMessage(allocator); + defer allocator.free(response); + + if (std.mem.indexOf(u8, response, "sync_ack") == null) { + return error.SyncRejected; + } +} + +const Metric = struct { + key: []const u8, + value: f64, + step: i64, + + fn deinit(self: *Metric, allocator: std.mem.Allocator) void { + allocator.free(self.key); + } +}; + +const Param = struct { + key: []const u8, + value: []const u8, + + fn deinit(self: *Param, allocator: std.mem.Allocator) void { + allocator.free(self.key); + allocator.free(self.value); + } +}; + +fn printUsage() void { + std.debug.print("Usage: ml sync [run_id] [options]\n\n", .{}); + std.debug.print("Push local experiment runs to the server.\n\n", .{}); + std.debug.print("Options:\n", .{}); + std.debug.print(" --json Output structured JSON\n", .{}); + std.debug.print(" --help, -h Show this help message\n\n", .{}); + std.debug.print("Examples:\n", .{}); + std.debug.print(" ml sync # Sync all unsynced runs\n", .{}); + std.debug.print(" ml sync abc123 # Sync specific run\n", .{}); +} + /// Find the git root directory by walking up from the given path fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 { var buf: [std.fs.max_path_bytes]u8 = undefined; diff --git a/cli/src/commands/watch.zig b/cli/src/commands/watch.zig index f5cae85..f6aee9e 100644 --- a/cli/src/commands/watch.zig +++ b/cli/src/commands/watch.zig @@ -1,110 +1,83 @@ const std = @import("std"); -const Config = @import("../config.zig").Config; +const config = @import("../config.zig"); const crypto = @import("../utils/crypto.zig"); const rsync = @import("../utils/rsync_embedded.zig"); const ws = @import("../net/ws/client.zig"); +const core = @import("../core.zig"); +const mode = @import("../mode.zig"); +const colors = @import("../utils/colors.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + var flags = core.flags.CommonFlags{}; + var should_sync = false; + const sync_interval: u64 = 30; // Default 30 seconds + if (args.len == 0) { - printUsage(); - return error.InvalidArgs; + return printUsage(); } - // Global flags - for (args) |arg| { - if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { - printUsage(); - return; - } - } - - const path = args[0]; - var job_name: ?[]const u8 = null; - var priority: u8 = 5; - var should_queue = false; - var json: bool = false; - // Parse flags - var i: usize = 1; - while (i < args.len) : (i += 1) { - if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) { - job_name = args[i + 1]; - i += 1; - } else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) { - priority = try std.fmt.parseInt(u8, args[i + 1], 10); - i += 1; - } else if (std.mem.eql(u8, args[i], "--queue")) { - should_queue = true; - } else if (std.mem.eql(u8, args[i], "--json")) { - json = true; + for (args) |arg| { + if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { + return printUsage(); + } else if (std.mem.eql(u8, arg, "--sync")) { + should_sync = true; + } else if (std.mem.eql(u8, arg, "--json")) { + flags.json = true; } } - const config = try Config.load(allocator); + core.output.init(if (flags.json) .json else .text); + + const cfg = try config.Config.load(allocator); defer { - var mut_config = config; - mut_config.deinit(allocator); + var mut_cfg = cfg; + mut_cfg.deinit(allocator); } - if (json) { - std.debug.print("{\"ok\":true,\"action\":\"watch\",\"path\":\"{s}\",\"queued\":{s}}\n", .{ path, if (should_queue) "true" else "false" }); - } else { - std.debug.print("Watching {s} for changes...\n", .{path}); - std.debug.print("Press Ctrl+C to stop\n", .{}); - } - - // Initial sync - var last_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config); - defer allocator.free(last_commit_id); - - // Watch for changes - var watcher = try std.fs.cwd().openDir(path, .{ .iterate = true }); - defer watcher.close(); - - var last_modified: u64 = 0; - - while (true) { - // Check for file changes - var modified = false; - var walker = try watcher.walk(allocator); - defer walker.deinit(); - - while (try walker.next()) |entry| { - if (entry.kind == .file) { - const file = try watcher.openFile(entry.path, .{}); - defer file.close(); - - const stat = try file.stat(); - if (stat.mtime > last_modified) { - last_modified = @intCast(stat.mtime); - modified = true; - } - } + // Check mode if syncing + if (should_sync) { + const mode_result = try mode.detect(allocator, cfg); + if (mode.isOffline(mode_result.mode)) { + colors.printError("ml watch --sync requires server connection\n", .{}); + return error.RequiresServer; } + } - if (modified) { - if (!json) { - std.debug.print("\nChanges detected, syncing...\n", .{}); - } + if (flags.json) { + std.debug.print("{{\"ok\":true,\"action\":\"watch\",\"sync\":{s}}}\n", .{if (should_sync) "true" else "false"}); + } else { + if (should_sync) { + colors.printInfo("Watching for changes with auto-sync every {d}s...\n", .{sync_interval}); + } else { + colors.printInfo("Watching directory for changes...\n", .{}); + } + colors.printInfo("Press Ctrl+C to stop\n", .{}); + } - const new_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config); - defer allocator.free(new_commit_id); - - if (!std.mem.eql(u8, last_commit_id, new_commit_id)) { - allocator.free(last_commit_id); - last_commit_id = try allocator.dupe(u8, new_commit_id); - if (!json) { - std.debug.print("✓ Synced new version: {s}\n", .{last_commit_id[0..8]}); - } + // Watch loop + var last_synced: i64 = 0; + while (true) { + if (should_sync) { + const now = std.time.timestamp(); + if (now - last_synced >= @as(i64, @intCast(sync_interval))) { + // Trigger sync + const sync_cmd = @import("sync.zig"); + sync_cmd.run(allocator, &[_][]const u8{"--json"}) catch |err| { + if (!flags.json) { + colors.printError("Auto-sync failed: {}\n", .{err}); + } + }; + last_synced = now; } } // Wait before checking again - std.Thread.sleep(2_000_000_000); // 2 seconds in nanoseconds + std.Thread.sleep(2_000_000_000); // 2 seconds } } -fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, config: Config) ![]u8 { +fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, cfg: config.Config) ![]u8 { // Calculate commit ID const commit_id = try crypto.hashDirectory(allocator, path); @@ -112,22 +85,22 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con const remote_path = try std.fmt.allocPrint( allocator, "{s}@{s}:{s}/{s}/files/", - .{ config.worker_user, config.worker_host, config.worker_base, commit_id }, + .{ cfg.worker_user, cfg.worker_host, cfg.worker_base, commit_id }, ); defer allocator.free(remote_path); - try rsync.sync(allocator, path, remote_path, config.worker_port); + try rsync.sync(allocator, path, remote_path, cfg.worker_port); if (should_queue) { const actual_job_name = job_name orelse commit_id[0..8]; - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); defer allocator.free(api_key_hash); // Connect to WebSocket and queue job - const ws_url = try config.getWebSocketUrl(allocator); + const ws_url = try cfg.getWebSocketUrl(allocator); defer allocator.free(ws_url); - var client = try ws.Client.connect(allocator, ws_url, config.api_key); + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); defer client.close(); try client.sendQueueJob(actual_job_name, commit_id, priority, api_key_hash); @@ -144,11 +117,10 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con } fn printUsage() void { - std.debug.print("Usage: ml watch [options]\n\n", .{}); + std.debug.print("Usage: ml watch [options]\n\n", .{}); + std.debug.print("Watch for changes and optionally auto-sync.\n\n", .{}); std.debug.print("Options:\n", .{}); - std.debug.print(" --name Override job name when used with --queue\n", .{}); - std.debug.print(" --priority Priority to use when queueing (default: 5)\n", .{}); - std.debug.print(" --queue Queue on every sync\n", .{}); - std.debug.print(" --json Emit a single JSON line describing watch start\n", .{}); + std.debug.print(" --sync Auto-sync runs to server every 30s\n", .{}); + std.debug.print(" --json Output structured JSON\n", .{}); std.debug.print(" --help, -h Show this help message\n", .{}); }