const std = @import("std");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const core = @import("../core.zig");

/// Search filters and output switches for the `find` command.
/// All string fields borrow from the argument vector; nothing here is owned.
pub const FindOptions = struct {
    json: bool = false,
    csv: bool = false,
    limit: usize = 20,
    tag: ?[]const u8 = null,
    outcome: ?[]const u8 = null,
    dataset: ?[]const u8 = null,
    experiment_group: ?[]const u8 = null,
    author: ?[]const u8 = null,
    after: ?[]const u8 = null,
    before: ?[]const u8 = null,
    query: ?[]const u8 = null,
};

/// Entry point for `ml find`.
///
/// Parses `argv`, loads the configuration, sends the search request over the
/// WebSocket client and prints the response as JSON, CSV or a tab-separated
/// human listing.
///
/// Errors: `error.InvalidArgs` for an unknown option or a value flag without
/// a value, the `std.fmt.parseInt` errors for a bad `--limit`,
/// `error.InvalidResponse` when the reply is not valid JSON, plus whatever
/// the config / crypto / network helpers return.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) return printUsage();
    if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
        return printUsage();
    }

    const options = parseArgs(argv) catch |err| {
        // parseInt failures propagate without a message, as before.
        if (err == error.InvalidArgs) core.output.err("Unknown option");
        return err;
    };

    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    std.debug.print("Searching experiments...\n", .{});

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    const search_json = try buildSearchJson(allocator, &options);
    defer allocator.free(search_json);

    // Send search request - we'll use the dataset search opcode as a placeholder
    // In production, this would have a dedicated search endpoint
    try client.sendDatasetSearch(search_json, api_key_hash);

    const msg = try client.receiveMessage(allocator);
    defer allocator.free(msg);

    const parsed = std.json.parseFromSlice(std.json.Value, allocator, msg, .{}) catch {
        if (options.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\":\"invalid_response\"}}\n", .{});
        } else {
            std.debug.print("Failed to parse search results\n", .{});
        }
        return error.InvalidResponse;
    };
    defer parsed.deinit();

    const root = parsed.value;
    // --json wins over --csv when both are given.
    if (options.json) {
        try io.stdoutWriteJson(root);
    } else if (options.csv) {
        try outputCsvResults(allocator, root, &options);
    } else {
        try outputHumanResults(root, &options);
    }
}

/// Turns the command-line arguments into a `FindOptions`.
///
/// A first argument that does not start with `--` is taken as the free-text
/// query. `--group` and `--experiment-group` are accepted as synonyms.
/// Returns `error.InvalidArgs` for an unknown option or a value flag that is
/// the last argument; a non-numeric `--limit` yields the `parseInt` error.
fn parseArgs(argv: []const []const u8) !FindOptions {
    var options = FindOptions{};
    var i: usize = 0;

    if (argv.len > 0 and !std.mem.startsWith(u8, argv[0], "--")) {
        options.query = argv[0];
        i = 1;
    }

    while (i < argv.len) : (i += 1) {
        const arg = argv[i];

        // Boolean switches take no value.
        if (std.mem.eql(u8, arg, "--json")) {
            options.json = true;
            continue;
        }
        if (std.mem.eql(u8, arg, "--csv")) {
            options.csv = true;
            continue;
        }

        // Everything else needs a value in the next slot.
        if (i + 1 >= argv.len) return error.InvalidArgs;
        const value = argv[i + 1];

        if (std.mem.eql(u8, arg, "--limit")) {
            options.limit = try std.fmt.parseInt(usize, value, 10);
        } else if (std.mem.eql(u8, arg, "--tag")) {
            options.tag = value;
        } else if (std.mem.eql(u8, arg, "--outcome")) {
            options.outcome = value;
        } else if (std.mem.eql(u8, arg, "--dataset")) {
            options.dataset = value;
        } else if (std.mem.eql(u8, arg, "--group") or
            std.mem.eql(u8, arg, "--experiment-group"))
        {
            options.experiment_group = value;
        } else if (std.mem.eql(u8, arg, "--author")) {
            options.author = value;
        } else if (std.mem.eql(u8, arg, "--after")) {
            options.after = value;
        } else if (std.mem.eql(u8, arg, "--before")) {
            options.before = value;
        } else {
            return error.InvalidArgs;
        }
        i += 1; // consume the value
    }

    return options;
}

/// Serializes the set filters plus `limit` into a JSON object.
/// Unset (null) filters are omitted. Caller owns the returned slice.
fn buildSearchJson(allocator: std.mem.Allocator, options: *const FindOptions) ![]u8 {
    var buf = std.ArrayList(u8).empty;
    errdefer buf.deinit(allocator);
    const writer = buf.writer(allocator);

    try writer.writeAll("{");
    var first = true;
    try writeOptionalMember(writer, &first, "query", options.query);
    try writeOptionalMember(writer, &first, "tag", options.tag);
    try writeOptionalMember(writer, &first, "outcome", options.outcome);
    try writeOptionalMember(writer, &first, "dataset", options.dataset);
    try writeOptionalMember(writer, &first, "experiment_group", options.experiment_group);
    try writeOptionalMember(writer, &first, "author", options.author);
    try writeOptionalMember(writer, &first, "after", options.after);
    try writeOptionalMember(writer, &first, "before", options.before);

    if (!first) try writer.writeAll(",");
    try writer.print("\"limit\":{d}", .{options.limit});
    try writer.writeAll("}");

    return buf.toOwnedSlice(allocator);
}

/// Writes `"key":"value"` (comma-prefixed unless it is the first member)
/// when `value` is set; does nothing for null.
fn writeOptionalMember(writer: anytype, first: *bool, key: []const u8, value: ?[]const u8) !void {
    const text = value orelse return;
    if (!first.*) try writer.writeAll(",");
    first.* = false;
    try writeJSONString(writer, key);
    try writer.writeAll(":");
    try writeJSONString(writer, text);
}

/// Returns the result list of a response object, which may be stored under
/// "results", "experiments" or "runs" (checked in that order).
fn findResults(obj: std.json.ObjectMap) ?std.json.Value {
    return obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
}

/// Prints one line per run to stderr:
/// `short_id => outcome<TAB>status<TAB>job_name<TAB>group[/tags]`.
/// Malformed responses and server-reported errors are reported and skipped.
fn outputHumanResults(root: std.json.Value, _options: *const FindOptions) !void {
    _ = _options;

    if (root != .object) {
        std.debug.print("Invalid response format\n", .{});
        return;
    }
    const obj = root.object;

    if (obj.get("error")) |err| {
        if (err == .string) {
            std.debug.print("Search error: {s}\n", .{err.string});
        }
        return;
    }

    const results = findResults(obj) orelse {
        std.debug.print("No results found\n", .{});
        return;
    };
    if (results != .array) {
        std.debug.print("Invalid results format\n", .{});
        return;
    }

    const items = results.array.items;
    if (items.len == 0) {
        std.debug.print("No experiments found matching your criteria\n", .{});
        return;
    }

    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        const id = jsonGetString(run_obj, "id") orelse
            jsonGetString(run_obj, "run_id") orelse "unknown";
        const short_id = if (id.len > 8) id[0..8] else id;
        const job_name = jsonGetString(run_obj, "job_name") orelse "";
        const outcome = jsonGetString(run_obj, "outcome") orelse "-";
        const status = jsonGetString(run_obj, "status") orelse "unknown";

        // group, tags, or "group/tags" when both are present. Printed as
        // three pieces so no fixed-size scratch buffer can truncate it.
        const group = jsonGetString(run_obj, "experiment_group");
        const tags = jsonGetString(run_obj, "tags");
        const separator: []const u8 = if (group != null and tags != null) "/" else "";

        std.debug.print("{s} => {s}\t{s}\t{s}\t{s}{s}{s}\n", .{
            short_id,
            outcome,
            status,
            job_name,
            group orelse "",
            separator,
            tags orelse "",
        });
    }
}

/// Writes the results to stdout as CSV with a header row.
/// Nothing is written when the response has no result array; server-reported
/// errors go to stderr. `allocator` is used for per-row scratch memory only.
fn outputCsvResults(allocator: std.mem.Allocator, root: std.json.Value, options: *const FindOptions) !void {
    _ = options;

    if (root != .object) {
        std.debug.print("Invalid response format\n", .{});
        return;
    }
    const obj = root.object;

    if (obj.get("error")) |err| {
        if (err == .string) {
            std.debug.print("Search error: {s}\n", .{err.string});
        }
        return;
    }

    const results = findResults(obj) orelse return;
    if (results != .array) return;

    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll("id,job_name,outcome,status,experiment_group,tags,hypothesis\n");

    for (results.array.items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        // The hypothesis lives one level down, under "narrative".
        const hypothesis: []const u8 = blk: {
            const narrative = run_obj.get("narrative") orelse break :blk "";
            if (narrative != .object) break :blk "";
            break :blk jsonGetString(narrative.object, "hypothesis") orelse "";
        };

        const fields = [_][]const u8{
            jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown",
            jsonGetString(run_obj, "job_name") orelse "",
            jsonGetString(run_obj, "outcome") orelse "",
            jsonGetString(run_obj, "status") orelse "",
            jsonGetString(run_obj, "experiment_group") orelse "",
            jsonGetString(run_obj, "tags") orelse "",
            hypothesis,
        };

        // Heap-built row: free text of any length must not abort the export.
        const row = try formatCsvRow(allocator, &fields);
        defer allocator.free(row);
        try stdout_file.writeAll(row);
    }
}

/// Joins `fields` into one newline-terminated CSV record, escaping every
/// field with `escapeCsv`. Caller owns the returned slice.
fn formatCsvRow(allocator: std.mem.Allocator, fields: []const []const u8) ![]u8 {
    var row = std.ArrayList(u8).empty;
    errdefer row.deinit(allocator);

    for (fields, 0..) |field, idx| {
        if (idx != 0) try row.append(allocator, ',');
        const escaped = try escapeCsv(allocator, field);
        defer allocator.free(escaped);
        try row.appendSlice(allocator, escaped);
    }
    try row.append(allocator, '\n');

    return row.toOwnedSlice(allocator);
}

/// Returns a CSV-safe copy of `s`: unchanged when it has no comma, quote,
/// CR or LF; otherwise wrapped in double quotes with embedded quotes doubled.
/// Caller owns the returned slice in both cases.
fn escapeCsv(allocator: std.mem.Allocator, s: []const u8) ![]u8 {
    if (std.mem.indexOfAny(u8, s, ",\"\n\r") == null) {
        return allocator.dupe(u8, s);
    }

    var buf = try std.ArrayList(u8).initCapacity(allocator, s.len + 2);
    errdefer buf.deinit(allocator);

    try buf.append(allocator, '"');
    for (s) |c| {
        if (c == '"') {
            try buf.appendSlice(allocator, "\"\"");
        } else {
            try buf.append(allocator, c);
        }
    }
    try buf.append(allocator, '"');

    return buf.toOwnedSlice(allocator);
}

/// Looks up `key` and returns its value only if it is a JSON string.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |text| text,
        else => null,
    };
}

/// Writes `s` as a quoted JSON string. Quote, backslash, \n, \r and \t get
/// short escapes; other bytes below 0x20 become `\u00XX`; everything else is
/// copied through byte-for-byte.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    try writer.writeAll("\"");
    for (s) |c| {
        switch (c) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            else => {
                if (c < 0x20) {
                    const escaped = [_]u8{ '\\', 'u', '0', '0', hexDigit(c >> 4), hexDigit(c & 0x0F) };
                    try writer.writeAll(&escaped);
                } else {
                    try writer.writeAll(&[_]u8{c});
                }
            },
        }
    }
    try writer.writeAll("\"");
}

/// Maps a nibble (0..15) to its lowercase hexadecimal ASCII digit.
fn hexDigit(v: u8) u8 {
    return if (v < 10) ('0' + v) else ('a' + (v - 10));
}

/// Prints the command help to stderr.
fn printUsage() !void {
    const usage_text =
        "Usage: ml find [query] [options]\n" ++
        "\nSearch experiments by:\n" ++
        "\tQuery (free text):\tml find \"hypothesis: warmup\"\n" ++
        "\tTags:\t\t\tml find --tag ablation\n" ++
        "\tOutcome:\t\tml find --outcome validates\n" ++
        "\tDataset:\t\tml find --dataset imagenet\n" ++
        "\tExperiment group:\tml find --experiment-group lr-scaling\n" ++
        "\tAuthor:\t\t\tml find --author user@lab.edu\n" ++
        "\tTime range:\t\tml find --after 2024-01-01 --before 2024-03-01\n" ++
        "\nOptions:\n" ++
        "\t--limit \tMax results (default: 20)\n" ++
        "\t--json\t\tOutput as JSON\n" ++
        "\t--csv\t\tOutput as CSV\n" ++
        "\t--help, -h\tShow this help\n" ++
        "\nExamples:\n" ++
        "\tml find --tag ablation --outcome validates\n" ++
        "\tml find --experiment-group batch-scaling --json\n" ++
        "\tml find \"learning rate\" --after 2024-01-01\n";
    std.debug.print("{s}", .{usage_text});
}