diff --git a/cli/src/commands/cancel.zig b/cli/src/commands/cancel.zig index e24aeff..1217a89 100644 --- a/cli/src/commands/cancel.zig +++ b/cli/src/commands/cancel.zig @@ -1,126 +1,212 @@ const std = @import("std"); -const Config = @import("../config.zig").Config; +const config = @import("../config.zig"); +const db = @import("../db.zig"); const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); -const logging = @import("../utils/logging.zig"); const colors = @import("../utils/colors.zig"); -const auth = @import("../utils/auth.zig"); - -pub const CancelOptions = struct { - force: bool = false, - json: bool = false, -}; +const core = @import("../core.zig"); +const mode = @import("../mode.zig"); +const manifest_lib = @import("../manifest.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { - var options = CancelOptions{}; - var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { - colors.printError("Failed to allocate job list: {}\n", .{err}); - return err; - }; - defer job_names.deinit(allocator); + var flags = core.flags.CommonFlags{}; + var force = false; + var targets = std.ArrayList([]const u8).init(allocator); + defer targets.deinit(); - // Parse arguments for flags and job names + // Parse arguments var i: usize = 0; while (i < args.len) : (i += 1) { const arg = args[i]; - if (std.mem.eql(u8, arg, "--force")) { - options.force = true; + force = true; } else if (std.mem.eql(u8, arg, "--json")) { - options.json = true; - } else if (std.mem.startsWith(u8, arg, "--help")) { - try printUsage(); - return; + flags.json = true; + } else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { + return printUsage(); } else if (std.mem.startsWith(u8, arg, "--")) { - colors.printError("Unknown option: {s}\n", .{arg}); - try printUsage(); + core.output.errorMsg("cancel", "Unknown option"); return error.InvalidArgs; } else { - // This is a job name - try job_names.append(allocator, arg); + try targets.append(arg); } } - if (job_names.items.len == 0) { - colors.printError("No job names specified\n", .{}); - try printUsage(); + core.output.init(if (flags.json) .json else .text); + + if (targets.items.len == 0) { + core.output.errorMsg("cancel", "No run_id specified"); return error.InvalidArgs; } - const config = try Config.load(allocator); + const cfg = try config.Config.load(allocator); defer { - var mut_config = config; - mut_config.deinit(allocator); + var mut_cfg = cfg; + mut_cfg.deinit(allocator); } - // Authenticate with server to get user context - var user_context = try auth.authenticateUser(allocator, config); - defer user_context.deinit(); + // Detect mode + const mode_result = try mode.detect(allocator, cfg); + if (mode_result.warning) |w| { + std.log.warn("{s}", .{w}); + } - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - defer allocator.free(api_key_hash); - - // Connect to WebSocket and send cancel messages - const ws_url = try config.getWebSocketUrl(allocator); - defer allocator.free(ws_url); - - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - defer client.close(); - - // Process each job var success_count: usize = 0; - var failed_jobs = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { - colors.printError("Failed to allocate failed jobs list: {}\n", .{err}); - return err; - }; - defer failed_jobs.deinit(allocator); + var failed_count: usize = 0; - for (job_names.items, 0..) |job_name, index| { - if (!options.json) { - colors.printInfo("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name }); - } - - cancelSingleJob(allocator, &client, user_context, job_name, options, api_key_hash) catch |err| { - colors.printError("Failed to cancel job '{s}': {}\n", .{ job_name, err }); - failed_jobs.append(allocator, job_name) catch |append_err| { - colors.printError("Failed to track failed job: {}\n", .{append_err}); + for (targets.items) |target| { + if (mode.isOffline(mode_result.mode)) { + // Local mode: kill by PID + cancelLocal(allocator, target, force, flags.json) catch |err| { + if (!flags.json) { + colors.printError("Failed to cancel '{s}': {}\n", .{ target, err }); + } + failed_count += 1; + continue; }; - continue; - }; - + } else { + // Online mode: cancel on server + cancelServer(allocator, target, force, flags.json, cfg) catch |err| { + if (!flags.json) { + colors.printError("Failed to cancel '{s}': {}\n", .{ target, err }); + } + failed_count += 1; + continue; + }; + } success_count += 1; } - // Show summary - if (!options.json) { - colors.printInfo("\nCancel Summary:\n", .{}); - colors.printSuccess("Successfully canceled {d} job(s)\n", .{success_count}); - if (failed_jobs.items.len > 0) { - colors.printError("Failed to cancel {d} job(s):\n", .{failed_jobs.items.len}); - for (failed_jobs.items) |failed_job| { - colors.printError(" - {s}\n", .{failed_job}); - } + if (flags.json) { + std.debug.print("{{\"success\":true,\"canceled\":{d},\"failed\":{d}}}\n", .{ success_count, failed_count }); + } else { + colors.printSuccess("Canceled {d} run(s)\n", .{success_count}); + if (failed_count > 0) { + colors.printError("Failed to cancel {d} run(s)\n", .{failed_count}); } } } -fn cancelSingleJob(allocator: std.mem.Allocator, client: *ws.Client, user_context: auth.UserContext, job_name: []const u8, options: CancelOptions, api_key_hash: []const u8) !void { +/// Cancel local run by PID +fn cancelLocal(allocator: std.mem.Allocator, run_id: []const u8, force: bool, json: bool) !void { + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Get DB path + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + // Look up PID + const sql = "SELECT pid FROM ml_runs WHERE run_id = ? AND status = 'RUNNING';"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + try db.DB.bindText(stmt, 1, run_id); + + const has_row = try db.DB.step(stmt); + if (!has_row) { + return error.RunNotFoundOrNotRunning; + } + + const pid = db.DB.columnInt64(stmt, 0); + if (pid == 0) { + return error.NoPIDAvailable; + } + + // Send SIGTERM first + std.posix.kill(@intCast(pid), std.posix.SIG.TERM) catch |err| { + if (err == error.ProcessNotFound) { + // Process already gone + } else { + return err; + } + }; + + if (!force) { + // Wait 5 seconds for graceful termination + std.time.sleep(5 * std.time.ns_per_s); + } + + // Check if still running, send SIGKILL if needed + if (force or isProcessRunning(@intCast(pid))) { + std.posix.kill(@intCast(pid), std.posix.SIG.KILL) catch |err| { + if (err != error.ProcessNotFound) { + return err; + } + }; + } + + // Update run status + const update_sql = "UPDATE ml_runs SET status = 'CANCELLED', pid = NULL WHERE run_id = ?;"; + const update_stmt = try database.prepare(update_sql); + defer db.DB.finalize(update_stmt); + try db.DB.bindText(update_stmt, 1, run_id); + _ = try db.DB.step(update_stmt); + + // Update manifest + const artifact_path = try std.fs.path.join(allocator, &[_][]const u8{ + cfg.artifact_path, + if (cfg.experiment) |exp| exp.name else "default", + run_id, + "run_manifest.json", + }); + defer allocator.free(artifact_path); + + manifest_lib.updateManifestStatus(artifact_path, "CANCELLED", null, allocator) catch {}; + + // Checkpoint + database.checkpointOnExit(); + + if (!json) { + colors.printSuccess("✓ Canceled run {s}\n", .{run_id[0..8]}); + } +} + +/// Check if process is still running +fn isProcessRunning(pid: i32) bool { + const result = std.posix.kill(pid, 0); + return result == error.PermissionDenied or result == {}; +} + +/// Cancel server job +fn cancelServer(allocator: std.mem.Allocator, job_name: []const u8, force: bool, json: bool, cfg: config.Config) !void { + _ = force; + _ = json; + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try cfg.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + try client.sendCancelJob(job_name, api_key_hash); - // Receive structured response with user context - try client.receiveAndHandleCancelResponse(allocator, user_context, job_name, options); + // Wait for acknowledgment + const message = try client.receiveMessage(allocator); + defer allocator.free(message); + + // Parse response (simplified) + if (std.mem.indexOf(u8, message, "error") != null) { + return error.ServerCancelFailed; + } } fn printUsage() !void { - colors.printInfo("Usage: ml cancel [options] [ ...]\n", .{}); - colors.printInfo("\nOptions:\n", .{}); - colors.printInfo(" --force Force cancel even if job is running\n", .{}); + colors.printInfo("Usage: ml cancel [options] [ ...]\n", .{}); + colors.printInfo("\nCancel a local run (kill process) or server job.\n\n", .{}); + colors.printInfo("Options:\n", .{}); + colors.printInfo(" --force Force cancel (SIGKILL immediately)\n", .{}); colors.printInfo(" --json Output structured JSON\n", .{}); - colors.printInfo(" --help Show this help message\n", .{}); + colors.printInfo(" --help, -h Show this help message\n", .{}); colors.printInfo("\nExamples:\n", .{}); - colors.printInfo(" ml cancel job1 # Cancel single job\n", .{}); - colors.printInfo(" ml cancel job1 job2 job3 # Cancel multiple jobs\n", .{}); - colors.printInfo(" ml cancel --force job1 # Force cancel running job\n", .{}); - colors.printInfo(" ml cancel --json job1 # Cancel job with JSON output\n", .{}); - colors.printInfo(" ml cancel --force --json job1 job2 # Force cancel with JSON output\n", .{}); + colors.printInfo(" ml cancel abc123 # Cancel local run by run_id\n", .{}); + colors.printInfo(" ml cancel --force abc123 # Force cancel\n", .{}); } diff --git a/cli/src/commands/log.zig b/cli/src/commands/log.zig new file mode 100644 index 0000000..4b7a0c1 --- /dev/null +++ b/cli/src/commands/log.zig @@ -0,0 +1,192 @@ +const std = @import("std"); +const config = @import("../config.zig"); +const core = @import("../core.zig"); +const colors = @import("../utils/colors.zig"); +const manifest_lib = @import("../manifest.zig"); +const mode = @import("../mode.zig"); +const ws = @import("../net/ws/client.zig"); +const protocol = @import("../net/protocol.zig"); +const crypto = @import("../utils/crypto.zig"); + +/// Logs command - fetch or stream run logs +/// Usage: +/// ml logs # Fetch logs from local file or server +/// ml logs --follow # Stream logs from server +pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void { + var flags = core.flags.CommonFlags{}; + var command_args = try core.flags.parseCommon(allocator, args, &flags); + defer command_args.deinit(allocator); + + core.output.init(if (flags.json) .json else .text); + + if (flags.help) { + return printUsage(); + } + + if (command_args.items.len < 1) { + std.log.err("Usage: ml logs [--follow]", .{}); + return error.MissingArgument; + } + + const target = command_args.items[0]; + const follow = core.flags.parseBoolFlag(command_args.items, "follow"); + + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Detect mode + const mode_result = try mode.detect(allocator, cfg); + if (mode_result.warning) |w| { + std.log.warn("{s}", .{w}); + } + + if (mode.isOffline(mode_result.mode)) { + // Local mode: read from output.log file + return try fetchLocalLogs(allocator, target, &cfg, flags.json); + } else { + // Online mode: fetch or stream from server + if (follow) { + return try streamServerLogs(allocator, target, cfg); + } else { + return try fetchServerLogs(allocator, target, cfg); + } + } +} + +fn fetchLocalLogs(allocator: std.mem.Allocator, target: []const u8, cfg: *const config.Config, json: bool) !void { + // Resolve manifest path + const manifest_path = manifest_lib.resolveManifestPath(target, cfg.artifact_path, allocator) catch |err| { + if (err == error.ManifestNotFound) { + std.log.err("Run not found: {s}", .{target}); + return error.RunNotFound; + } + return err; + }; + defer allocator.free(manifest_path); + + // Read manifest to get artifact path + const manifest = try manifest_lib.readManifest(manifest_path, allocator); + defer manifest.deinit(allocator); + + // Build output.log path + const output_path = try std.fs.path.join(allocator, &[_][]const u8{ + manifest.artifact_path, + "output.log", + }); + defer allocator.free(output_path); + + // Read output.log + const content = std.fs.cwd().readFileAlloc(allocator, output_path, 10 * 1024 * 1024) catch |err| { + if (err == error.FileNotFound) { + std.log.err("No logs found for run: {s}", .{target}); + return error.LogsNotFound; + } + return err; + }; + defer allocator.free(content); + + if (json) { + // Escape content for JSON + var escaped = std.ArrayList(u8).init(allocator); + defer escaped.deinit(); + const writer = escaped.writer(allocator); + + for (content) |c| { + switch (c) { + '\\' => try writer.writeAll("\\\\"), + '"' => try writer.writeAll("\\\""), + '\n' => try writer.writeAll("\\n"), + '\r' => try writer.writeAll("\\r"), + '\t' => try writer.writeAll("\\t"), + else => { + if (c >= 0x20 and c < 0x7f) { + try writer.writeByte(c); + } else { + try writer.print("\\u{x:0>4}", .{c}); + } + }, + } + } + + std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"logs\":\"{s}\"}}\n", .{ + manifest.run_id, + escaped.items, + }); + } else { + std.debug.print("{s}\n", .{content}); + } +} + +fn fetchServerLogs(allocator: std.mem.Allocator, target: []const u8, cfg: config.Config) !void { + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try cfg.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + try client.sendFetchLogs(target, api_key_hash); + + const message = try client.receiveMessage(allocator); + defer allocator.free(message); + + std.debug.print("{s}\n", .{message}); +} + +fn streamServerLogs(allocator: std.mem.Allocator, target: []const u8, cfg: config.Config) !void { + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try cfg.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + colors.printInfo("Streaming logs for: {s}\n", .{target}); + + try client.sendStreamLogs(target, api_key_hash); + + // Stream loop + while (true) { + const message = try client.receiveMessage(allocator); + defer allocator.free(message); + + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + std.debug.print("{s}\n", .{message}); + continue; + }; + defer packet.deinit(allocator); + + switch (packet.packet_type) { + .data => { + if (packet.data_payload) |payload| { + std.debug.print("{s}\n", .{payload}); + } + }, + .error_packet => { + const err_msg = packet.error_message orelse "Stream error"; + colors.printError("Error: {s}\n", .{err_msg}); + return error.ServerError; + }, + else => {}, + } + } +} + +fn printUsage() !void { + std.debug.print("Usage: ml logs [options]\n\n", .{}); + std.debug.print("Fetch or stream run logs.\n\n", .{}); + std.debug.print("Options:\n", .{}); + std.debug.print(" --follow, -f Stream logs from server (online mode)\n", .{}); + std.debug.print(" --help, -h Show this help message\n", .{}); + std.debug.print(" --json Output structured JSON\n\n", .{}); + std.debug.print("Examples:\n", .{}); + std.debug.print(" ml logs abc123 # Fetch logs (local or server)\n", .{}); + std.debug.print(" ml logs abc123 --follow # Stream logs from server\n", .{}); +} diff --git a/cli/src/commands/note.zig b/cli/src/commands/note.zig new file mode 100644 index 0000000..a52e7cc --- /dev/null +++ b/cli/src/commands/note.zig @@ -0,0 +1,143 @@ +const std = @import("std"); +const config = @import("../config.zig"); +const db = @import("../db.zig"); +const core = @import("../core.zig"); +const colors = @import("../utils/colors.zig"); +const manifest_lib = @import("../manifest.zig"); + +/// Note command - unified metadata annotation +/// Usage: +/// ml note --text "Try lr=3e-4 next" +/// ml note --hypothesis "LR scaling helps" +/// ml note --outcome validates --confidence 0.9 +/// ml note --privacy private +pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void { + var flags = core.flags.CommonFlags{}; + var command_args = try core.flags.parseCommon(allocator, args, &flags); + defer command_args.deinit(allocator); + + core.output.init(if (flags.json) .json else .text); + + if (flags.help) { + return printUsage(); + } + + if (command_args.items.len < 1) { + std.log.err("Usage: ml note [options]", .{}); + return error.MissingArgument; + } + + const run_id = command_args.items[0]; + + // Parse metadata options + const text = core.flags.parseKVFlag(command_args.items, "text"); + const hypothesis = core.flags.parseKVFlag(command_args.items, "hypothesis"); + const outcome = core.flags.parseKVFlag(command_args.items, "outcome"); + const confidence = core.flags.parseKVFlag(command_args.items, "confidence"); + const privacy = core.flags.parseKVFlag(command_args.items, "privacy"); + const author = core.flags.parseKVFlag(command_args.items, "author"); + + // Check that at least one option is provided + if (text == null and hypothesis == null and outcome == null and privacy == null) { + std.log.err("No metadata provided. Use --text, --hypothesis, --outcome, or --privacy", .{}); + return error.MissingMetadata; + } + + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Get DB path + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + // Verify run exists + const check_sql = "SELECT 1 FROM ml_runs WHERE run_id = ?;"; + const check_stmt = try database.prepare(check_sql); + defer db.DB.finalize(check_stmt); + try db.DB.bindText(check_stmt, 1, run_id); + const has_row = try db.DB.step(check_stmt); + if (!has_row) { + std.log.err("Run not found: {s}", .{run_id}); + return error.RunNotFound; + } + + // Add text note as a tag + if (text) |t| { + try addTag(allocator, &database, run_id, "note", t, author); + } + + // Add hypothesis + if (hypothesis) |h| { + try addTag(allocator, &database, run_id, "hypothesis", h, author); + } + + // Add outcome + if (outcome) |o| { + try addTag(allocator, &database, run_id, "outcome", o, author); + if (confidence) |c| { + try addTag(allocator, &database, run_id, "confidence", c, author); + } + } + + // Add privacy level + if (privacy) |p| { + try addTag(allocator, &database, run_id, "privacy", p, author); + } + + // Checkpoint WAL + database.checkpointOnExit(); + + if (flags.json) { + std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"action\":\"note_added\"}}\n", .{run_id}); + } else { + colors.printSuccess("✓ Added note to run {s}\n", .{run_id[0..8]}); + } +} + +fn addTag( + allocator: std.mem.Allocator, + database: *db.DB, + run_id: []const u8, + key: []const u8, + value: []const u8, + author: ?[]const u8, +) !void { + const full_value = if (author) |a| + try std.fmt.allocPrint(allocator, "{s} (by {s})", .{ value, a }) + else + try allocator.dupe(u8, value); + defer allocator.free(full_value); + + const sql = "INSERT INTO ml_tags (run_id, key, value) VALUES (?, ?, ?);"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + + try db.DB.bindText(stmt, 1, run_id); + try db.DB.bindText(stmt, 2, key); + try db.DB.bindText(stmt, 3, full_value); + _ = try db.DB.step(stmt); +} + +fn printUsage() !void { + std.debug.print("Usage: ml note [options]\n\n", .{}); + std.debug.print("Add metadata notes to a run.\n\n", .{}); + std.debug.print("Options:\n", .{}); + std.debug.print(" --text Free-form annotation\n", .{}); + std.debug.print(" --hypothesis Research hypothesis\n", .{}); + std.debug.print(" --outcome Outcome: validates/refutes/inconclusive\n", .{}); + std.debug.print(" --confidence <0-1> Confidence in outcome\n", .{}); + std.debug.print(" --privacy Privacy: private/team/public\n", .{}); + std.debug.print(" --author Author of the note\n", .{}); + std.debug.print(" --help, -h Show this help\n", .{}); + std.debug.print(" --json Output structured JSON\n\n", .{}); + std.debug.print("Examples:\n", .{}); + std.debug.print(" ml note abc123 --text \"Try lr=3e-4 next\"\n", .{}); + std.debug.print(" ml note abc123 --hypothesis \"LR scaling helps\"\n", .{}); + std.debug.print(" ml note abc123 --outcome validates --confidence 0.9\n", .{}); +}