fetch_ml/cli/src/commands/find.zig
Jeremie Fraeys a1988de8b1
style(cli): Standardize printUsage() formatting with tabs and ASCII symbols
Replace space-padding with consistent tab (\t) alignment in all printUsage() functions.

Add ligature-friendly ASCII symbols:

  - => for results/outcomes (renders as ⇒ with ligatures)

  - ~> for modifications/changes (renders as ~> with ligatures)

  - -> for state transitions (renders as → with ligatures)

  - [OK] / [FAIL] for status indicators

All symbols use ASCII 32-126 for xargs-safe, copy-pasteable output.
2026-02-23 14:09:49 -05:00

470 lines
15 KiB
Zig

const std = @import("std");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const core = @import("../core.zig");
/// Settings for one experiment search. Filled from the command line by `run`,
/// serialized into the request by `buildSearchJson`, and handed to the renderers.
pub const FindOptions = struct {
    /// Print the raw server response as JSON (checked before `csv` in `run`).
    json: bool = false,
    /// Print results as CSV instead of tab-separated rows.
    csv: bool = false,
    /// Maximum number of results requested from the server.
    limit: usize = 20,
    /// Optional tag filter.
    tag: ?[]const u8 = null,
    /// Optional outcome filter (the help text shows e.g. "validates").
    outcome: ?[]const u8 = null,
    /// Optional dataset-name filter.
    dataset: ?[]const u8 = null,
    /// Optional experiment-group filter.
    experiment_group: ?[]const u8 = null,
    /// Optional author filter.
    author: ?[]const u8 = null,
    /// Optional lower time bound, forwarded verbatim (help text uses YYYY-MM-DD).
    after: ?[]const u8 = null,
    /// Optional upper time bound, forwarded verbatim.
    before: ?[]const u8 = null,
    /// Optional free-text query taken from the first positional argument.
    query: ?[]const u8 = null,
};
/// Entry point for `ml find`: parse `argv`, send the search request over the
/// WebSocket client, and render the reply as JSON, CSV, or tab-separated rows.
///
/// A first argument that does not start with "--" is taken as the free-text query.
/// `--help`/`-h` anywhere prints usage and returns.
/// Errors: error.InvalidArgs for an unknown option or a value-taking option with no
/// value; std.fmt.parseInt errors for a malformed `--limit`; error.InvalidResponse
/// when the server reply is not valid JSON; plus whatever config loading, hashing,
/// and the WebSocket client return.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0 or isHelpFlag(argv[0])) return printUsage();

    var options = FindOptions{};
    var i: usize = 0;
    // The first argument might be a free-text query rather than a flag.
    if (!std.mem.startsWith(u8, argv[0], "--")) {
        options.query = argv[0];
        i = 1;
    }
    while (i < argv.len) : (i += 1) {
        const arg = argv[i];
        if (std.mem.eql(u8, arg, "--json")) {
            options.json = true;
        } else if (std.mem.eql(u8, arg, "--csv")) {
            options.csv = true;
        } else if (isHelpFlag(arg)) {
            return printUsage();
        } else if (std.mem.eql(u8, arg, "--limit")) {
            options.limit = try std.fmt.parseInt(usize, try takeOptionValue(argv, &i), 10);
        } else if (filterField(&options, arg)) |field| {
            field.* = try takeOptionValue(argv, &i);
        } else {
            core.output.err("Unknown option");
            return error.InvalidArgs;
        }
    }

    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    std.debug.print("Searching experiments...\n", .{});
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    const search_json = try buildSearchJson(allocator, &options);
    defer allocator.free(search_json);
    // There is no dedicated search opcode yet; the dataset-search message is used
    // as a placeholder transport for the request.
    try client.sendDatasetSearch(search_json, api_key_hash);

    const msg = try client.receiveMessage(allocator);
    defer allocator.free(msg);

    const parsed = std.json.parseFromSlice(std.json.Value, allocator, msg, .{}) catch {
        if (options.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\":\"invalid_response\"}}\n", .{});
        } else {
            std.debug.print("Failed to parse search results\n", .{});
        }
        return error.InvalidResponse;
    };
    defer parsed.deinit();

    // `--json` wins over `--csv` when both are given.
    if (options.json) {
        try io.stdoutWriteJson(parsed.value);
    } else if (options.csv) {
        try outputCsvResults(allocator, parsed.value, &options);
    } else {
        try outputHumanResults(parsed.value, &options);
    }
}

/// True for either spelling of the help flag.
fn isHelpFlag(arg: []const u8) bool {
    return std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h");
}

/// Advance `idx` past the option at argv[idx.*] and return its value.
/// Reports and returns error.InvalidArgs when the option is the last argument.
fn takeOptionValue(argv: []const []const u8, idx: *usize) error{InvalidArgs}![]const u8 {
    if (idx.* + 1 >= argv.len) {
        core.output.err("Missing value for option");
        return error.InvalidArgs;
    }
    idx.* += 1;
    return argv[idx.*];
}

/// Map a string-valued filter flag to the FindOptions field it sets, or null if
/// `flag` is not a filter flag. Both "--group" and "--experiment-group" are accepted
/// for the experiment-group filter.
fn filterField(options: *FindOptions, flag: []const u8) ?*?[]const u8 {
    const eql = std.mem.eql;
    if (eql(u8, flag, "--tag")) return &options.tag;
    if (eql(u8, flag, "--outcome")) return &options.outcome;
    if (eql(u8, flag, "--dataset")) return &options.dataset;
    if (eql(u8, flag, "--group") or eql(u8, flag, "--experiment-group")) return &options.experiment_group;
    if (eql(u8, flag, "--author")) return &options.author;
    if (eql(u8, flag, "--after")) return &options.after;
    if (eql(u8, flag, "--before")) return &options.before;
    return null;
}
/// Serialize the search filters in `options` into a compact JSON object.
///
/// String filters are written in a fixed order (query, tag, outcome, dataset,
/// experiment_group, author, after, before), each only when set. "limit" is always
/// written last. The `json`/`csv` presentation flags are not sent.
/// Caller owns the returned slice.
fn buildSearchJson(allocator: std.mem.Allocator, options: *const FindOptions) ![]u8 {
    const Field = struct { key: []const u8, value: ?[]const u8 };
    const fields = [_]Field{
        .{ .key = "query", .value = options.query },
        .{ .key = "tag", .value = options.tag },
        .{ .key = "outcome", .value = options.outcome },
        .{ .key = "dataset", .value = options.dataset },
        .{ .key = "experiment_group", .value = options.experiment_group },
        .{ .key = "author", .value = options.author },
        .{ .key = "after", .value = options.after },
        .{ .key = "before", .value = options.before },
    };

    var out: std.ArrayList(u8) = .empty;
    defer out.deinit(allocator);
    const w = out.writer(allocator);

    try w.writeAll("{");
    var wrote_any = false;
    for (fields) |field| {
        const value = field.value orelse continue;
        if (wrote_any) try w.writeAll(",");
        wrote_any = true;
        try w.print("\"{s}\":", .{field.key});
        try writeJSONString(w, value);
    }
    // Separator before "limit" only when at least one filter preceded it.
    if (wrote_any) try w.writeAll(",");
    try w.print("\"limit\":{d}", .{options.limit});
    try w.writeAll("}");
    return out.toOwnedSlice(allocator);
}
/// Render search results as one tab-separated line per run on stdout:
///     <first 8 bytes of id> => <outcome>\t<status>\t<job_name>\t<group/tags>
///
/// Results are taken from the first of "results", "experiments", or "runs" in the
/// reply. The id comes from "id" or "run_id" (falling back to "unknown"). A missing
/// outcome prints "-" and a missing status prints "unknown". The last column is
/// "group/tags" when both are strings, otherwise whichever is present, or empty.
///
/// Rows go to stdout so they can be piped or redirected. Diagnostics (server error,
/// malformed or empty results) go to stderr. `_options` is currently unused.
fn outputHumanResults(root: std.json.Value, _options: *const FindOptions) !void {
    _ = _options;
    if (root != .object) {
        std.debug.print("Invalid response format\n", .{});
        return;
    }
    const obj = root.object;

    // A server-side error short-circuits rendering; only string errors are shown.
    if (obj.get("error")) |err| {
        if (err == .string) {
            std.debug.print("Search error: {s}\n", .{err.string});
        }
        return;
    }

    const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs") orelse {
        std.debug.print("No results found\n", .{});
        return;
    };
    if (results != .array) {
        std.debug.print("Invalid results format\n", .{});
        return;
    }
    const items = results.array.items;
    if (items.len == 0) {
        std.debug.print("No experiments found matching your criteria\n", .{});
        return;
    }

    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
        const short_id = id[0..@min(id.len, 8)];
        const job_name = jsonGetString(run_obj, "job_name") orelse "";
        const outcome = jsonGetString(run_obj, "outcome") orelse "-";
        const status = jsonGetString(run_obj, "status") orelse "unknown";
        const group = jsonGetString(run_obj, "experiment_group");
        const tags = jsonGetString(run_obj, "tags");
        // Written piecewise (no fixed-size buffer), so a long group/tags value
        // is never truncated.
        const group_tags_sep: []const u8 = if (group != null and tags != null) "/" else "";

        const pieces = [_][]const u8{
            short_id,      " => ",         outcome,        "\t",
            status,        "\t",           job_name,       "\t",
            group orelse "", group_tags_sep, tags orelse "", "\n",
        };
        for (pieces) |piece| try stdout_file.writeAll(piece);
    }
}
/// Render search results as CSV on stdout: a header row, then one row per run object.
///
/// Columns: id, job_name, outcome, status, experiment_group, tags, hypothesis. The
/// hypothesis is read from the run's "narrative" object. Every field is passed
/// through `escapeCsv`, so values containing commas, quotes, or line breaks are
/// quoted. Rows end with "\n".
///
/// Each row is assembled in a growable buffer, so arbitrarily long fields (e.g. a
/// long hypothesis) cannot overflow a fixed-size scratch buffer and abort the output.
/// Diagnostics go to stderr. `options` is currently unused.
fn outputCsvResults(allocator: std.mem.Allocator, root: std.json.Value, options: *const FindOptions) !void {
    _ = options;
    if (root != .object) {
        std.debug.print("Invalid response format\n", .{});
        return;
    }
    const obj = root.object;

    // A server-side error short-circuits rendering; only string errors are shown.
    if (obj.get("error")) |err| {
        if (err == .string) {
            std.debug.print("Search error: {s}\n", .{err.string});
        }
        return;
    }

    const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs") orelse return;
    if (results != .array) return;

    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll("id,job_name,outcome,status,experiment_group,tags,hypothesis\n");

    // One reusable row buffer for all rows.
    var row: std.ArrayList(u8) = .empty;
    defer row.deinit(allocator);

    for (results.array.items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        const hypothesis: []const u8 = blk: {
            const narrative = run_obj.get("narrative") orelse break :blk "";
            if (narrative != .object) break :blk "";
            break :blk jsonGetString(narrative.object, "hypothesis") orelse "";
        };
        const fields = [_][]const u8{
            jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown",
            jsonGetString(run_obj, "job_name") orelse "",
            jsonGetString(run_obj, "outcome") orelse "",
            jsonGetString(run_obj, "status") orelse "",
            jsonGetString(run_obj, "experiment_group") orelse "",
            jsonGetString(run_obj, "tags") orelse "",
            hypothesis,
        };

        row.clearRetainingCapacity();
        for (fields, 0..) |field, col| {
            if (col != 0) try row.append(allocator, ',');
            const escaped = try escapeCsv(allocator, field);
            defer allocator.free(escaped);
            try row.appendSlice(allocator, escaped);
        }
        try row.append(allocator, '\n');
        try stdout_file.writeAll(row.items);
    }
}
/// Return a CSV-safe copy of `s`.
///
/// A field containing a comma, double quote, CR, or LF is wrapped in double quotes,
/// with each embedded quote doubled. Any other field is returned as a plain copy.
/// Caller owns the returned slice (free with `allocator`).
fn escapeCsv(allocator: std.mem.Allocator, s: []const u8) ![]u8 {
    if (std.mem.indexOfAny(u8, s, ",\"\r\n") == null) {
        return allocator.dupe(u8, s);
    }

    // Size exactly: the body with quotes doubled, plus the two surrounding quotes.
    const body_len = std.mem.replacementSize(u8, s, "\"", "\"\"");
    const quoted = try allocator.alloc(u8, body_len + 2);
    quoted[0] = '"';
    _ = std.mem.replace(u8, s, "\"", "\"\"", quoted[1 .. quoted.len - 1]);
    quoted[quoted.len - 1] = '"';
    return quoted;
}
/// Look up `key` in a JSON object and return its value only if it is a string.
/// Returns null when the key is absent or its value has any other JSON type.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |text| text,
        else => null,
    };
}
/// Write `s` to `writer` as a double-quoted JSON string literal.
///
/// `"` and `\` are backslash-escaped. LF, CR, and TAB use the short forms \n, \r, \t.
/// Every other byte below 0x20 is written as a lowercase \u00XX escape. All other
/// bytes, including UTF-8 multi-byte sequences, are copied unchanged.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    var unicode_escape = [6]u8{ '\\', 'u', '0', '0', 0, 0 };

    try writer.writeAll("\"");
    // Bytes needing no escape are flushed as runs just before the next escape.
    var run_start: usize = 0;
    for (s, 0..) |byte, pos| {
        const replacement: []const u8 = switch (byte) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            0x00...0x08, 0x0b, 0x0c, 0x0e...0x1f => blk: {
                unicode_escape[4] = hexDigit(byte >> 4);
                unicode_escape[5] = hexDigit(byte & 0x0F);
                break :blk &unicode_escape;
            },
            else => continue,
        };
        try writer.writeAll(s[run_start..pos]);
        try writer.writeAll(replacement);
        run_start = pos + 1;
    }
    try writer.writeAll(s[run_start..]);
    try writer.writeAll("\"");
}
/// Convert a nibble value (0-15) to its lowercase ASCII hex digit.
fn hexDigit(v: u8) u8 {
    if (v >= 10) return 'a' + (v - 10);
    return '0' + v;
}
/// Print help text for `ml find` to stderr.
///
/// Never fails. The error-union return type lets error-returning callers write
/// `return printUsage();`. The experiment-group filter is documented as `--group`,
/// the spelling the argument parser accepts.
fn printUsage() !void {
    std.debug.print("Usage: ml find [query] [options]\n", .{});
    std.debug.print("\nSearch experiments by:\n", .{});
    std.debug.print("\tQuery (free text):\tml find \"hypothesis: warmup\"\n", .{});
    std.debug.print("\tTags:\t\t\tml find --tag ablation\n", .{});
    std.debug.print("\tOutcome:\t\tml find --outcome validates\n", .{});
    std.debug.print("\tDataset:\t\tml find --dataset imagenet\n", .{});
    std.debug.print("\tExperiment group:\tml find --group lr-scaling\n", .{});
    std.debug.print("\tAuthor:\t\t\tml find --author user@lab.edu\n", .{});
    std.debug.print("\tTime range:\t\tml find --after 2024-01-01 --before 2024-03-01\n", .{});
    std.debug.print("\nOptions:\n", .{});
    std.debug.print("\t--limit <n>\tMax results (default: 20)\n", .{});
    std.debug.print("\t--json\t\tOutput as JSON\n", .{});
    std.debug.print("\t--csv\t\tOutput as CSV\n", .{});
    std.debug.print("\t--help, -h\tShow this help\n", .{});
    std.debug.print("\nExamples:\n", .{});
    std.debug.print("\tml find --tag ablation --outcome validates\n", .{});
    std.debug.print("\tml find --group batch-scaling --json\n", .{});
    std.debug.print("\tml find \"learning rate\" --after 2024-01-01\n", .{});
}