Replace space-padding with consistent tab (\t) alignment in all printUsage() functions.

Add ligature-friendly ASCII symbols:
- => for results/outcomes (renders as ⇒ with ligatures)
- ~> for modifications/changes (renders as ~> with ligatures)
- -> for state transitions (renders as → with ligatures)
- [OK] / [FAIL] for status indicators

All symbols use ASCII 32-126 for xargs-safe, copy-pasteable output.
470 lines
15 KiB
Zig
470 lines
15 KiB
Zig
const std = @import("std");
|
|
const Config = @import("../config.zig").Config;
|
|
const crypto = @import("../utils/crypto.zig");
|
|
const io = @import("../utils/io.zig");
|
|
const ws = @import("../net/ws/client.zig");
|
|
const protocol = @import("../net/protocol.zig");
|
|
const core = @import("../core.zig");
|
|
|
|
/// Output-format switches and search filters for the `find` command.
///
/// Every filter is optional; `null` means the filter is not applied.
/// The struct declares no cleanup routine: string fields are plain slices
/// whose backing memory is managed by whoever filled them in.
pub const FindOptions = struct {
    // --- Output format -----------------------------------------------------
    /// Emit the raw response as JSON.
    json: bool = false,
    /// Emit results as CSV.
    csv: bool = false,

    // --- Result size -------------------------------------------------------
    /// Maximum number of results to request.
    limit: usize = 20,

    // --- Filters -----------------------------------------------------------
    tag: ?[]const u8 = null,
    outcome: ?[]const u8 = null,
    dataset: ?[]const u8 = null,
    experiment_group: ?[]const u8 = null,
    author: ?[]const u8 = null,
    /// Lower bound of the time range.
    after: ?[]const u8 = null,
    /// Upper bound of the time range.
    before: ?[]const u8 = null,
    /// Free-text query.
    query: ?[]const u8 = null,
};
|
|
|
|
/// Entry point of the `find` command.
///
/// Parses `argv` (the arguments following the subcommand name), sends a
/// search request over the WebSocket client and prints the response as JSON,
/// CSV or tab-separated text.
///
/// Argument handling:
/// - no arguments, or `--help` / `-h` as the first argument: print usage.
/// - a first argument that does not start with `--` is the free-text query.
/// - `--json`, `--csv`: output format switches.
/// - `--limit`, `--tag`, `--outcome`, `--dataset`, `--group`
///   (alias `--experiment-group`), `--author`, `--after`, `--before`: each
///   consumes the following argument as its value.
///
/// Errors: `error.InvalidArgs` for an unrecognised option (this includes a
/// value option given as the very last argument), `error.InvalidResponse`
/// when the reply is not valid JSON, plus anything propagated from integer
/// parsing, configuration loading, the connection and output writers.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) {
        return printUsage();
    }

    if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
        return printUsage();
    }

    var flags = core.flags.CommonFlags{};
    // Filters are collected directly into the struct handed to the JSON builder.
    var search_options = FindOptions{};

    // First argument might be a query string or a flag.
    var i: usize = 0;
    if (!std.mem.startsWith(u8, argv[0], "--")) {
        search_options.query = argv[0];
        i = 1;
    }

    while (i < argv.len) : (i += 1) {
        const arg = argv[i];
        // Value options are only recognised when a value actually follows.
        const has_value = i + 1 < argv.len;

        if (std.mem.eql(u8, arg, "--json")) {
            flags.json = true;
        } else if (std.mem.eql(u8, arg, "--csv")) {
            search_options.csv = true;
        } else if (has_value and std.mem.eql(u8, arg, "--limit")) {
            i += 1;
            search_options.limit = try std.fmt.parseInt(usize, argv[i], 10);
        } else if (has_value and std.mem.eql(u8, arg, "--tag")) {
            i += 1;
            search_options.tag = argv[i];
        } else if (has_value and std.mem.eql(u8, arg, "--outcome")) {
            i += 1;
            search_options.outcome = argv[i];
        } else if (has_value and std.mem.eql(u8, arg, "--dataset")) {
            i += 1;
            search_options.dataset = argv[i];
        } else if (has_value and
            (std.mem.eql(u8, arg, "--group") or std.mem.eql(u8, arg, "--experiment-group")))
        {
            // Both spellings are accepted: the usage text has advertised
            // `--experiment-group` while only `--group` used to be parsed.
            i += 1;
            search_options.experiment_group = argv[i];
        } else if (has_value and std.mem.eql(u8, arg, "--author")) {
            i += 1;
            search_options.author = argv[i];
        } else if (has_value and std.mem.eql(u8, arg, "--after")) {
            i += 1;
            search_options.after = argv[i];
        } else if (has_value and std.mem.eql(u8, arg, "--before")) {
            i += 1;
            search_options.before = argv[i];
        } else {
            core.output.err("Unknown option");
            return error.InvalidArgs;
        }
    }
    search_options.json = flags.json;

    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    std.debug.print("Searching experiments...\n", .{});

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    const search_json = try buildSearchJson(allocator, &search_options);
    defer allocator.free(search_json);

    // Send search request - we'll use the dataset search opcode as a placeholder.
    // In production, this would have a dedicated search endpoint.
    try client.sendDatasetSearch(search_json, api_key_hash);

    const msg = try client.receiveMessage(allocator);
    defer allocator.free(msg);

    // Parse response; in JSON mode the failure is reported as a JSON object
    // on stdout, otherwise as a plain message on stderr.
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, msg, .{}) catch {
        if (flags.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\":\"invalid_response\"}}\n", .{});
        } else {
            std.debug.print("Failed to parse search results\n", .{});
        }
        return error.InvalidResponse;
    };
    defer parsed.deinit();

    const root = parsed.value;

    if (flags.json) {
        try io.stdoutWriteJson(root);
    } else if (search_options.csv) {
        try outputCsvResults(allocator, root, &search_options);
    } else {
        try outputHumanResults(root, &search_options);
    }
}
|
|
|
|
/// Serialises the search filters in `options` into a JSON object.
///
/// Only filters that are set are emitted, in the fixed order query, tag,
/// outcome, dataset, experiment_group, author, after, before; `"limit"` is
/// always written last. Returns an allocated slice; the caller frees it
/// with `allocator`.
fn buildSearchJson(allocator: std.mem.Allocator, options: *const FindOptions) ![]u8 {
    var json_bytes = std.ArrayList(u8).empty;
    defer json_bytes.deinit(allocator);

    const out = json_bytes.writer(allocator);

    // One row per optional string filter: the literal key prefix and its value.
    const Filter = struct { prefix: []const u8, value: ?[]const u8 };
    const filters = [_]Filter{
        .{ .prefix = "\"query\":", .value = options.query },
        .{ .prefix = "\"tag\":", .value = options.tag },
        .{ .prefix = "\"outcome\":", .value = options.outcome },
        .{ .prefix = "\"dataset\":", .value = options.dataset },
        .{ .prefix = "\"experiment_group\":", .value = options.experiment_group },
        .{ .prefix = "\"author\":", .value = options.author },
        .{ .prefix = "\"after\":", .value = options.after },
        .{ .prefix = "\"before\":", .value = options.before },
    };

    try out.writeAll("{");

    var wrote_member = false;
    for (filters) |filter| {
        const text = filter.value orelse continue;
        if (wrote_member) try out.writeAll(",");
        wrote_member = true;
        try out.writeAll(filter.prefix);
        try writeJSONString(out, text);
    }

    if (wrote_member) try out.writeAll(",");
    try out.print("\"limit\":{d}", .{options.limit});

    try out.writeAll("}");

    return json_bytes.toOwnedSlice(allocator);
}
|
|
|
|
/// Prints search results as one tab-separated line per run via
/// `std.debug.print`.
///
/// `root` is expected to be a JSON object holding either an `"error"` member
/// or an array under `"results"`, `"experiments"` or `"runs"` (checked in
/// that order). Malformed or empty responses produce a short message and a
/// normal return. `_options` is currently unused.
fn outputHumanResults(root: std.json.Value, _options: *const FindOptions) !void {
    _ = _options;

    const response = switch (root) {
        .object => |map| map,
        else => {
            std.debug.print("Invalid response format\n", .{});
            return;
        },
    };

    // An "error" member ends processing; only string errors are printed.
    if (response.get("error")) |failure| {
        switch (failure) {
            .string => |message| std.debug.print("Search error: {s}\n", .{message}),
            else => {},
        }
        return;
    }

    const listing = response.get("results") orelse
        response.get("experiments") orelse
        response.get("runs") orelse
        {
            std.debug.print("No results found\n", .{});
            return;
        };

    const entries = switch (listing) {
        .array => |list| list.items,
        else => {
            std.debug.print("Invalid results format\n", .{});
            return;
        },
    };

    if (entries.len == 0) {
        std.debug.print("No experiments found matching your criteria\n", .{});
        return;
    }

    for (entries) |entry| {
        const run_obj = switch (entry) {
            .object => |map| map,
            else => continue,
        };

        const full_id = jsonGetString(run_obj, "id") orelse
            jsonGetString(run_obj, "run_id") orelse
            "unknown";
        // Show at most the first 8 bytes of the identifier.
        const short_id = full_id[0..@min(full_id.len, 8)];

        const job_name = jsonGetString(run_obj, "job_name") orelse "";
        const outcome = jsonGetString(run_obj, "outcome") orelse "-";
        const status = jsonGetString(run_obj, "status") orelse "unknown";

        // Group/tags column: "group/tags" when both are strings, otherwise
        // whichever one is present. If the combined text does not fit the
        // scratch buffer, the group alone is shown.
        var scratch: [100]u8 = undefined;
        const group_name = jsonGetString(run_obj, "experiment_group");
        const tag_text = jsonGetString(run_obj, "tags");
        const group_tags: []const u8 = column: {
            const g = group_name orelse break :column (tag_text orelse "");
            const t = tag_text orelse break :column g;
            break :column std.fmt.bufPrint(&scratch, "{s}/{s}", .{ g, t }) catch g;
        };

        // TSV output: id => outcome | status | job_name | group_tags
        std.debug.print("{s} => {s}\t{s}\t{s}\t{s}\n", .{
            short_id, outcome, status, job_name, group_tags,
        });
    }
}
|
|
|
|
/// Writes search results to stdout as CSV with the header
/// `id,job_name,outcome,status,experiment_group,tags,hypothesis`.
///
/// `root` is expected to be a JSON object holding either an `"error"` member
/// or an array under `"results"`, `"experiments"` or `"runs"` (checked in
/// that order). A non-object root or a string `"error"` is reported via
/// `std.debug.print`; a missing or non-array result list returns silently
/// without writing anything. `options` is currently unused.
///
/// Every column is passed through `escapeCsv`, and each row is formatted into
/// a heap buffer sized to fit, so arbitrarily long field values (for example
/// a long hypothesis) no longer fail with `error.NoSpaceLeft`.
///
/// Errors: allocation failures and stdout write errors are propagated.
fn outputCsvResults(allocator: std.mem.Allocator, root: std.json.Value, options: *const FindOptions) !void {
    _ = options;

    if (root != .object) {
        std.debug.print("Invalid response format\n", .{});
        return;
    }

    const obj = root.object;

    // Check for error
    if (obj.get("error")) |err| {
        if (err == .string) {
            std.debug.print("Search error: {s}\n", .{err.string});
        }
        return;
    }

    const results = obj.get("results") orelse
        obj.get("experiments") orelse
        obj.get("runs") orelse
        return;
    if (results != .array) return;

    const items = results.array.items;

    // CSV Header
    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll("id,job_name,outcome,status,experiment_group,tags,hypothesis\n");

    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
        const job_name = jsonGetString(run_obj, "job_name") orelse "";
        const outcome = jsonGetString(run_obj, "outcome") orelse "";
        const status = jsonGetString(run_obj, "status") orelse "";
        const group = jsonGetString(run_obj, "experiment_group") orelse "";
        const tags = jsonGetString(run_obj, "tags") orelse "";

        // The hypothesis lives inside the nested "narrative" object.
        const hypothesis: []const u8 = blk: {
            const narrative = run_obj.get("narrative") orelse break :blk "";
            if (narrative != .object) break :blk "";
            break :blk jsonGetString(narrative.object, "hypothesis") orelse "";
        };

        // Escape every column: any of them may contain commas, quotes or newlines.
        const safe_id = try escapeCsv(allocator, id);
        defer allocator.free(safe_id);
        const safe_job = try escapeCsv(allocator, job_name);
        defer allocator.free(safe_job);
        const safe_outcome = try escapeCsv(allocator, outcome);
        defer allocator.free(safe_outcome);
        const safe_status = try escapeCsv(allocator, status);
        defer allocator.free(safe_status);
        const safe_group = try escapeCsv(allocator, group);
        defer allocator.free(safe_group);
        const safe_tags = try escapeCsv(allocator, tags);
        defer allocator.free(safe_tags);
        const safe_hypo = try escapeCsv(allocator, hypothesis);
        defer allocator.free(safe_hypo);

        // Heap-format the row so its length is not capped by a fixed buffer.
        const line = try std.fmt.allocPrint(allocator, "{s},{s},{s},{s},{s},{s},{s}\n", .{
            safe_id, safe_job, safe_outcome, safe_status, safe_group, safe_tags, safe_hypo,
        });
        defer allocator.free(line);

        try stdout_file.writeAll(line);
    }
}
|
|
|
|
/// Returns a CSV-safe copy of `s`.
///
/// A field containing a comma, double quote, line feed or carriage return is
/// wrapped in double quotes with every embedded quote doubled; any other
/// field is copied unchanged. The result is always freshly allocated, and
/// the caller frees it with `allocator`.
fn escapeCsv(allocator: std.mem.Allocator, s: []const u8) ![]u8 {
    if (std.mem.indexOfAny(u8, s, ",\"\n\r") == null) {
        return allocator.dupe(u8, s);
    }

    // Exact size: original bytes, one extra byte per embedded quote, and the
    // two surrounding quotes.
    const embedded_quotes = std.mem.count(u8, s, "\"");
    const quoted = try allocator.alloc(u8, s.len + embedded_quotes + 2);

    var pos: usize = 0;
    quoted[pos] = '"';
    pos += 1;
    for (s) |byte| {
        if (byte == '"') {
            quoted[pos] = '"';
            pos += 1;
        }
        quoted[pos] = byte;
        pos += 1;
    }
    quoted[pos] = '"';

    return quoted;
}
|
|
|
|
/// Looks up `key` in `obj` and returns its value when that value is a JSON
/// string; returns `null` when the key is absent or holds any other type.
/// The returned slice is the one stored in `obj`, not a copy.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |text| text,
        else => null,
    };
}
|
|
|
|
/// Writes `s` to `writer` as a double-quoted JSON string literal.
///
/// `"` and `\` are backslash-escaped; line feed, carriage return and tab use
/// their short forms (`\n`, `\r`, `\t`); any other byte below 0x20 is written
/// as `\u00XX` with lowercase hex digits. All remaining bytes are passed
/// through unchanged. `writer` only needs a `writeAll` method.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    try writer.writeAll("\"");

    for (s) |byte| {
        // Two-character escape for this byte, if it has one.
        const short_escape: ?[]const u8 = switch (byte) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            else => null,
        };

        if (short_escape) |sequence| {
            try writer.writeAll(sequence);
        } else if (byte < 0x20) {
            // Remaining control characters: \u00 followed by two hex digits.
            const sequence = [6]u8{
                '\\',
                'u',
                '0',
                '0',
                hexDigit(byte >> 4),
                hexDigit(byte & 0x0F),
            };
            try writer.writeAll(&sequence);
        } else {
            try writer.writeAll(&[_]u8{byte});
        }
    }

    try writer.writeAll("\"");
}
|
|
|
|
/// Maps a value to its lowercase hexadecimal digit character:
/// 0-9 become '0'-'9' and 10-15 become 'a'-'f'.
/// Callers in this file pass a 4-bit value.
fn hexDigit(v: u8) u8 {
    // Offset that turns the value into its ASCII digit: '0' for decimal
    // digits, 'a' shifted down by ten for the letter digits.
    const ascii_offset: u8 = if (v < 10) '0' else 'a' - 10;
    return ascii_offset + v;
}
|
|
|
|
/// Prints the help text for the `find` command via `std.debug.print`.
///
/// Columns are aligned with tab characters. The experiment-group filter is
/// documented as `--group`, the spelling the argument parser in this file
/// accepts.
fn printUsage() !void {
    const usage_text =
        "Usage: ml find [query] [options]\n" ++
        "\nSearch experiments by:\n" ++
        "\tQuery (free text):\tml find \"hypothesis: warmup\"\n" ++
        "\tTags:\t\t\tml find --tag ablation\n" ++
        "\tOutcome:\t\tml find --outcome validates\n" ++
        "\tDataset:\t\tml find --dataset imagenet\n" ++
        "\tExperiment group:\tml find --group lr-scaling\n" ++
        "\tAuthor:\t\t\tml find --author user@lab.edu\n" ++
        "\tTime range:\t\tml find --after 2024-01-01 --before 2024-03-01\n" ++
        "\nOptions:\n" ++
        "\t--limit <n>\tMax results (default: 20)\n" ++
        "\t--json\t\tOutput as JSON\n" ++
        "\t--csv\t\tOutput as CSV\n" ++
        "\t--help, -h\tShow this help\n" ++
        "\nExamples:\n" ++
        "\tml find --tag ablation --outcome validates\n" ++
        "\tml find --group batch-scaling --json\n" ++
        "\tml find \"learning rate\" --after 2024-01-01\n";

    std.debug.print("{s}", .{usage_text});
}
|