//! `ml annotate`: attach a note (and optional author) to an existing run.
//!
//! The run is identified by a manifest file path, a run directory, or a
//! run_id/task_id that is searched for under the worker base directory.

const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");

/// Largest run_manifest.json (in bytes) this command will read.
const max_manifest_bytes: usize = 1024 * 1024;

/// File name of the per-run manifest inside a run directory.
const manifest_file_name = "run_manifest.json";

/// Characters ignored when deciding whether a user/manifest value is blank.
const whitespace = " \t\r\n";

/// Entry point for `ml annotate`.
///
/// `args[0]` is the run target: a manifest file, a run directory, or a
/// run_id/task_id to look up under `--base` (default: `cfg.worker_base`).
/// The remaining arguments are options; `--note` is mandatory and must not be
/// blank. The job name is read from the resolved manifest and sent, together
/// with the note, author, and hashed API key, over the WebSocket client.
///
/// Errors:
/// - `error.InvalidArgs` for missing/unknown/misplaced arguments.
/// - `error.InvalidPacket` in `--json` mode when the reply cannot be decoded
///   (the raw reply is printed to stdout first).
/// - `error.CommandFailed` in `--json` mode when the server replies with an
///   error packet (a JSON result with `"ok": false` is printed first).
/// - Anything returned by config loading, manifest resolution, or networking.
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
    if (args.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    if (isHelpFlag(args[0])) {
        try printUsage();
        return;
    }

    const target = args[0];
    // The target is positional and must come first; an option here means the
    // user forgot it (e.g. `ml annotate --note x <id>`).
    if (std.mem.startsWith(u8, target, "-")) {
        colors.printError("Expected a run path or id as the first argument, got option: {s}\n", .{target});
        try printUsage();
        return error.InvalidArgs;
    }

    var author: []const u8 = "";
    var note: ?[]const u8 = null;
    var base_override: ?[]const u8 = null;
    var json_mode: bool = false;

    var i: usize = 1;
    while (i < args.len) : (i += 1) {
        const a = args[i];
        if (std.mem.eql(u8, a, "--author")) {
            author = try optionValue(args, &i, "--author");
        } else if (std.mem.eql(u8, a, "--note")) {
            note = try optionValue(args, &i, "--note");
        } else if (std.mem.eql(u8, a, "--base")) {
            base_override = try optionValue(args, &i, "--base");
        } else if (std.mem.eql(u8, a, "--json")) {
            json_mode = true;
        } else if (isHelpFlag(a)) {
            try printUsage();
            return;
        } else if (std.mem.startsWith(u8, a, "--")) {
            colors.printError("Unknown option: {s}\n", .{a});
            return error.InvalidArgs;
        } else {
            colors.printError("Unexpected argument: {s}\n", .{a});
            return error.InvalidArgs;
        }
    }

    const note_text = note orelse "";
    if (std.mem.trim(u8, note_text, whitespace).len == 0) {
        colors.printError("--note is required\n", .{});
        try printUsage();
        return error.InvalidArgs;
    }

    var cfg = try Config.load(allocator);
    defer cfg.deinit(allocator);

    const resolved_base = base_override orelse cfg.worker_base;
    const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| {
        if (err == error.FileNotFound) {
            colors.printError(
                "Could not locate run_manifest.json for '{s}'. Provide a path, or use --base to scan finished/failed/running/pending.\n",
                .{target},
            );
        }
        return err;
    };
    defer allocator.free(manifest_path);

    const job_name = try readJobNameFromManifest(allocator, manifest_path);
    defer allocator.free(job_name);

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    try client.sendAnnotateRun(job_name, author, note_text, api_key_hash);

    if (json_mode) {
        const msg = try client.receiveMessage(allocator);
        defer allocator.free(msg);

        const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch {
            // Undecodable reply: surface the raw payload so callers can debug it.
            var out = io.stdoutWriter();
            try out.print("{s}\n", .{msg});
            return error.InvalidPacket;
        };
        defer {
            if (packet.success_message) |m| allocator.free(m);
            if (packet.error_message) |m| allocator.free(m);
            if (packet.error_details) |m| allocator.free(m);
        }

        // Shape of the single JSON object written to stdout in --json mode.
        const Result = struct {
            ok: bool,
            job_name: []const u8,
            message: []const u8,
            error_code: ?u8 = null,
            error_message: ?[]const u8 = null,
            details: ?[]const u8 = null,
        };

        var out = io.stdoutWriter();
        if (packet.packet_type == .error_packet) {
            const res = Result{
                .ok = false,
                .job_name = job_name,
                .message = "",
                // An error packet without a code is reported as null instead of
                // panicking on `.?`.
                .error_code = if (packet.error_code) |code| @intFromEnum(code) else null,
                .error_message = packet.error_message orelse "",
                .details = packet.error_details orelse "",
            };
            try out.print("{f}\n", .{std.json.fmt(res, .{})});
            return error.CommandFailed;
        }

        const res = Result{
            .ok = true,
            .job_name = job_name,
            .message = packet.success_message orelse "",
        };
        try out.print("{f}\n", .{std.json.fmt(res, .{})});
        return;
    }

    try client.receiveAndHandleResponse(allocator, "Annotate");
    colors.printSuccess("Annotation added\n", .{});
    colors.printInfo("Job: {s}\n", .{job_name});
}

/// True when `arg` is one of the help flags (`--help` or `-h`).
fn isHelpFlag(arg: []const u8) bool {
    return std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h");
}

/// Return the value that follows the option at `args[i.*]` and advance `i.*`
/// onto it. Prints an error and returns `error.InvalidArgs` if the option is
/// the last argument.
fn optionValue(args: []const []const u8, i: *usize, name: []const u8) ![]const u8 {
    if (i.* + 1 >= args.len) {
        colors.printError("Missing value for {s}\n", .{name});
        return error.InvalidArgs;
    }
    i.* += 1;
    return args[i.*];
}

/// Read the manifest at `manifest_path` and return a copy of its top-level
/// `"job_name"` string.
///
/// Returns `error.InvalidManifest` when the document is not a JSON object, or
/// `job_name` is missing, not a string, or blank. Caller owns the result.
fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 {
    const data = try readFileAlloc(allocator, manifest_path);
    defer allocator.free(data);

    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
    defer parsed.deinit();

    if (parsed.value != .object) return error.InvalidManifest;
    const job_name = jsonGetString(parsed.value.object, "job_name") orelse return error.InvalidManifest;
    if (std.mem.trim(u8, job_name, whitespace).len == 0) return error.InvalidManifest;

    // Copy out before `parsed.deinit()` frees the JSON tree that owns `job_name`.
    return allocator.dupe(u8, job_name);
}

/// Turn the user's target into a manifest path. Caller owns the result.
///
/// - An existing directory yields `<input>/run_manifest.json`.
/// - Any other existing path is returned as-is (assumed to be the manifest).
/// - A non-existent path is treated as a run_id/task_id and searched for under
///   `base` (see `resolveManifestPathById`).
fn resolveManifestPathWithBase(
    allocator: std.mem.Allocator,
    input: []const u8,
    base: []const u8,
) ![]u8 {
    if (std.fs.path.isAbsolute(input)) {
        // Probe by opening rather than stat-ing; any open failure falls through
        // to the id lookup.
        if (std.fs.openDirAbsolute(input, .{}) catch null) |opened| {
            var dir = opened;
            defer dir.close();
            return std.fs.path.join(allocator, &[_][]const u8{ input, manifest_file_name });
        }
        if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
            defer file.close();
            return allocator.dupe(u8, input);
        }
        return resolveManifestPathById(allocator, input, base);
    }

    const stat = std.fs.cwd().statFile(input) catch |err| {
        if (err == error.FileNotFound) return resolveManifestPathById(allocator, input, base);
        return err;
    };
    if (stat.kind == .directory) {
        return std.fs.path.join(allocator, &[_][]const u8{ input, manifest_file_name });
    }
    return allocator.dupe(u8, input);
}

/// Search `<base_path>/{finished,failed,running,pending}/*/run_manifest.json`
/// for a manifest whose `run_id` or `task_id` equals `id`, in that order, and
/// return the first match's path (caller owns it).
///
/// Returns `error.FileNotFound` for a blank `id`, an empty `base_path`, or no
/// match. Missing state directories and unreadable/unparsable manifests are
/// skipped; allocation failures and directory iteration errors propagate.
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
    if (std.mem.trim(u8, id, whitespace).len == 0) return error.FileNotFound;
    if (base_path.len == 0) return error.FileNotFound;

    const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
    for (roots) |root| {
        const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
        defer allocator.free(root_path);

        var dir = if (std.fs.path.isAbsolute(root_path))
            (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
        else
            (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
        defer dir.close();

        var it = dir.iterate();
        while (try it.next()) |entry| {
            if (entry.kind != .directory) continue;

            const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name, manifest_file_name });
            defer allocator.free(manifest_path);

            // A broken run directory should not abort the search, but running
            // out of memory must not be mistaken for "not found".
            const data = readFileAlloc(allocator, manifest_path) catch |err| switch (err) {
                error.OutOfMemory => return error.OutOfMemory,
                else => continue,
            };
            defer allocator.free(data);

            const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch |err| switch (err) {
                error.OutOfMemory => return error.OutOfMemory,
                else => continue,
            };
            defer parsed.deinit();

            if (parsed.value != .object) continue;
            const obj = parsed.value.object;
            const run_id = jsonGetString(obj, "run_id") orelse "";
            const task_id = jsonGetString(obj, "task_id") orelse "";
            if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
                return allocator.dupe(u8, manifest_path);
            }
        }
    }
    return error.FileNotFound;
}

/// Read the whole file at `path` (absolute or cwd-relative), up to
/// `max_manifest_bytes`. Caller owns the returned buffer.
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
    const file = if (std.fs.path.isAbsolute(path))
        try std.fs.openFileAbsolute(path, .{})
    else
        try std.fs.cwd().openFile(path, .{});
    defer file.close();
    return file.readToEndAlloc(allocator, max_manifest_bytes);
}

/// Return `obj[key]` if it exists and is a JSON string, otherwise null.
/// The slice borrows from the parsed JSON tree.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    return switch (obj.get(key) orelse return null) {
        .string => |s| s,
        else => null,
    };
}

/// Print command usage and examples.
fn printUsage() !void {
    colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo("  ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
    colors.printInfo("  ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
}