feat: Research features - narrative fields and outcome tracking
Add comprehensive research context tracking to jobs: - Narrative fields: hypothesis, context, intent, expected_outcome - Experiment groups and tags for organization - Run comparison (compare command) for diff analysis - Run search (find command) with criteria filtering - Run export (export command) for data portability - Outcome setting (outcome command) for experiment validation Update queue and requeue commands to support narrative fields. Add narrative validation to manifest validator. Add WebSocket handlers for compare, find, export, and outcome operations. Includes E2E tests for phase 2 features.
This commit is contained in:
parent
94020e4ca4
commit
260e18499e
15 changed files with 2851 additions and 21 deletions
|
|
@ -1,13 +1,17 @@
|
|||
pub const annotate = @import("commands/annotate.zig");
|
||||
pub const cancel = @import("commands/cancel.zig");
|
||||
pub const compare = @import("commands/compare.zig");
|
||||
pub const dataset = @import("commands/dataset.zig");
|
||||
pub const experiment = @import("commands/experiment.zig");
|
||||
pub const export_cmd = @import("commands/export_cmd.zig");
|
||||
pub const find = @import("commands/find.zig");
|
||||
pub const info = @import("commands/info.zig");
|
||||
pub const init = @import("commands/init.zig");
|
||||
pub const jupyter = @import("commands/jupyter.zig");
|
||||
pub const logs = @import("commands/logs.zig");
|
||||
pub const monitor = @import("commands/monitor.zig");
|
||||
pub const narrative = @import("commands/narrative.zig");
|
||||
pub const outcome = @import("commands/outcome.zig");
|
||||
pub const prune = @import("commands/prune.zig");
|
||||
pub const queue = @import("commands/queue.zig");
|
||||
pub const requeue = @import("commands/requeue.zig");
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const io = @import("../utils/io.zig");
|
|||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
const pii = @import("../utils/pii.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
|
|
@ -24,6 +25,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var note: ?[]const u8 = null;
|
||||
var base_override: ?[]const u8 = null;
|
||||
var json_mode: bool = false;
|
||||
var privacy_scan: bool = false;
|
||||
var force: bool = false;
|
||||
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
|
|
@ -51,6 +54,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, a, "--privacy-scan")) {
|
||||
privacy_scan = true;
|
||||
} else if (std.mem.eql(u8, a, "--force")) {
|
||||
force = true;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
|
|
@ -69,6 +76,19 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
// PII detection if requested
|
||||
if (privacy_scan) {
|
||||
if (pii.scanForPII(note.?, allocator)) |warning| {
|
||||
colors.printWarning("{s}\n", .{warning.?});
|
||||
if (!force) {
|
||||
colors.printInfo("Use --force to store anyway, or edit your note.\n", .{});
|
||||
return error.PIIDetected;
|
||||
}
|
||||
} else |_| {
|
||||
// PII scan failed, continue anyway
|
||||
}
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
|
|
@ -152,8 +172,16 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]\n", .{});
|
||||
colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json] [--privacy-scan] [--force]\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --note <text> Annotation text (required)\n", .{});
|
||||
colors.printInfo(" --author <name> Author of the annotation\n", .{});
|
||||
colors.printInfo(" --base <path> Base path to search for run_manifest.json\n", .{});
|
||||
colors.printInfo(" --privacy-scan Scan note for PII before storing\n", .{});
|
||||
colors.printInfo(" --force Store even if PII detected\n", .{});
|
||||
colors.printInfo(" --json Output JSON response\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
|
||||
colors.printInfo(" ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
|
||||
colors.printInfo(" ml annotate run_123 --note \"Contact at user@example.com\" --privacy-scan\n", .{});
|
||||
}
|
||||
|
|
|
|||
512
cli/src/commands/compare.zig
Normal file
512
cli/src/commands/compare.zig
Normal file
|
|
@ -0,0 +1,512 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
|
||||
pub const CompareOptions = struct {
|
||||
json: bool = false,
|
||||
csv: bool = false,
|
||||
all_fields: bool = false,
|
||||
fields: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len < 2) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
const run_a = argv[0];
|
||||
const run_b = argv[1];
|
||||
|
||||
var options = CompareOptions{};
|
||||
|
||||
var i: usize = 2;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const arg = argv[i];
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--csv")) {
|
||||
options.csv = true;
|
||||
} else if (std.mem.eql(u8, arg, "--all")) {
|
||||
options.all_fields = true;
|
||||
} else if (std.mem.eql(u8, arg, "--fields") and i + 1 < argv.len) {
|
||||
options.fields = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
// Fetch both runs
|
||||
colors.printInfo("Fetching run {s}...\n", .{run_a});
|
||||
var client_a = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client_a.close();
|
||||
|
||||
// Try to get experiment info for run A
|
||||
try client_a.sendGetExperiment(run_a, api_key_hash);
|
||||
const msg_a = try client_a.receiveMessage(allocator);
|
||||
defer allocator.free(msg_a);
|
||||
|
||||
colors.printInfo("Fetching run {s}...\n", .{run_b});
|
||||
var client_b = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client_b.close();
|
||||
|
||||
try client_b.sendGetExperiment(run_b, api_key_hash);
|
||||
const msg_b = try client_b.receiveMessage(allocator);
|
||||
defer allocator.free(msg_b);
|
||||
|
||||
// Parse responses
|
||||
const parsed_a = std.json.parseFromSlice(std.json.Value, allocator, msg_a, .{}) catch {
|
||||
colors.printError("Failed to parse response for {s}\n", .{run_a});
|
||||
return error.InvalidResponse;
|
||||
};
|
||||
defer parsed_a.deinit();
|
||||
|
||||
const parsed_b = std.json.parseFromSlice(std.json.Value, allocator, msg_b, .{}) catch {
|
||||
colors.printError("Failed to parse response for {s}\n", .{run_b});
|
||||
return error.InvalidResponse;
|
||||
};
|
||||
defer parsed_b.deinit();
|
||||
|
||||
const root_a = parsed_a.value.object;
|
||||
const root_b = parsed_b.value.object;
|
||||
|
||||
// Check for errors
|
||||
if (root_a.get("error")) |err_a| {
|
||||
colors.printError("Error fetching {s}: {s}\n", .{ run_a, err_a.string });
|
||||
return error.ServerError;
|
||||
}
|
||||
if (root_b.get("error")) |err_b| {
|
||||
colors.printError("Error fetching {s}: {s}\n", .{ run_b, err_b.string });
|
||||
return error.ServerError;
|
||||
}
|
||||
|
||||
if (options.json) {
|
||||
try outputJsonComparison(allocator, root_a, root_b, run_a, run_b);
|
||||
} else {
|
||||
try outputHumanComparison(root_a, root_b, run_a, run_b, options);
|
||||
}
|
||||
}
|
||||
|
||||
fn outputHumanComparison(
|
||||
root_a: std.json.ObjectMap,
|
||||
root_b: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
options: CompareOptions,
|
||||
) !void {
|
||||
colors.printInfo("\n=== Comparison: {s} vs {s} ===\n\n", .{ run_a, run_b });
|
||||
|
||||
// Common fields
|
||||
const job_name_a = jsonGetString(root_a, "job_name") orelse "unknown";
|
||||
const job_name_b = jsonGetString(root_b, "job_name") orelse "unknown";
|
||||
|
||||
if (!std.mem.eql(u8, job_name_a, job_name_b)) {
|
||||
colors.printWarning("Job names differ:\n", .{});
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, job_name_a });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, job_name_b });
|
||||
} else {
|
||||
colors.printInfo("Job Name: {s}\n", .{job_name_a});
|
||||
}
|
||||
|
||||
// Experiment group
|
||||
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
|
||||
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
|
||||
if (group_a.len > 0 or group_b.len > 0) {
|
||||
colors.printInfo("\nExperiment Group:\n", .{});
|
||||
if (std.mem.eql(u8, group_a, group_b)) {
|
||||
colors.printInfo(" Both: {s}\n", .{group_a});
|
||||
} else {
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, group_a });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, group_b });
|
||||
}
|
||||
}
|
||||
|
||||
// Narrative fields
|
||||
const narrative_a = root_a.get("narrative");
|
||||
const narrative_b = root_b.get("narrative");
|
||||
|
||||
if (narrative_a != null or narrative_b != null) {
|
||||
colors.printInfo("\n--- Narrative ---\n", .{});
|
||||
|
||||
if (narrative_a) |na| {
|
||||
if (narrative_b) |nb| {
|
||||
if (na == .object and nb == .object) {
|
||||
try compareNarrativeFields(na.object, nb.object, run_a, run_b);
|
||||
}
|
||||
} else {
|
||||
colors.printInfo(" {s} has narrative, {s} does not\n", .{ run_a, run_b });
|
||||
}
|
||||
} else if (narrative_b) |_| {
|
||||
colors.printInfo(" {s} has narrative, {s} does not\n", .{ run_b, run_a });
|
||||
}
|
||||
}
|
||||
|
||||
// Metadata differences
|
||||
const meta_a = root_a.get("metadata");
|
||||
const meta_b = root_b.get("metadata");
|
||||
|
||||
if (meta_a) |ma| {
|
||||
if (meta_b) |mb| {
|
||||
if (ma == .object and mb == .object) {
|
||||
colors.printInfo("\n--- Metadata Differences ---\n", .{});
|
||||
try compareMetadata(ma.object, mb.object, run_a, run_b, options.all_fields);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics (if available)
|
||||
const metrics_a = root_a.get("metrics");
|
||||
const metrics_b = root_b.get("metrics");
|
||||
|
||||
if (metrics_a) |ma| {
|
||||
if (metrics_b) |mb| {
|
||||
if (ma == .object and mb == .object) {
|
||||
colors.printInfo("\n--- Metrics ---\n", .{});
|
||||
try compareMetrics(ma.object, mb.object, run_a, run_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Outcome
|
||||
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
|
||||
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
|
||||
if (outcome_a.len > 0 or outcome_b.len > 0) {
|
||||
colors.printInfo("\n--- Outcome ---\n", .{});
|
||||
if (std.mem.eql(u8, outcome_a, outcome_b)) {
|
||||
colors.printInfo(" Both: {s}\n", .{outcome_a});
|
||||
} else {
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, outcome_a });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, outcome_b });
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\n", .{});
|
||||
}
|
||||
|
||||
fn outputJsonComparison(
|
||||
allocator: std.mem.Allocator,
|
||||
root_a: std.json.ObjectMap,
|
||||
root_b: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
|
||||
try writer.writeAll("{\"run_a\":\"");
|
||||
try writer.writeAll(run_a);
|
||||
try writer.writeAll("\",\"run_b\":\"");
|
||||
try writer.writeAll(run_b);
|
||||
try writer.writeAll("\",\"differences\":");
|
||||
try writer.writeAll("{");
|
||||
|
||||
var first = true;
|
||||
|
||||
// Job names
|
||||
const job_name_a = jsonGetString(root_a, "job_name") orelse "";
|
||||
const job_name_b = jsonGetString(root_b, "job_name") orelse "";
|
||||
if (!std.mem.eql(u8, job_name_a, job_name_b)) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"job_name\":{\"a\":\"");
|
||||
try writer.writeAll(job_name_a);
|
||||
try writer.writeAll("\",\"b\":\"");
|
||||
try writer.writeAll(job_name_b);
|
||||
try writer.writeAll("\"}");
|
||||
}
|
||||
|
||||
// Experiment group
|
||||
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
|
||||
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
|
||||
if (!std.mem.eql(u8, group_a, group_b)) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"experiment_group\":{\"a\":\"");
|
||||
try writer.writeAll(group_a);
|
||||
try writer.writeAll("\",\"b\":\"");
|
||||
try writer.writeAll(group_b);
|
||||
try writer.writeAll("\"}");
|
||||
}
|
||||
|
||||
// Outcomes
|
||||
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
|
||||
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
|
||||
if (!std.mem.eql(u8, outcome_a, outcome_b)) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"outcome\":{\"a\":\"");
|
||||
try writer.writeAll(outcome_a);
|
||||
try writer.writeAll("\",\"b\":\"");
|
||||
try writer.writeAll(outcome_b);
|
||||
try writer.writeAll("\"}");
|
||||
}
|
||||
|
||||
try writer.writeAll("}}");
|
||||
try writer.writeAll("}\n");
|
||||
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll(buf.items);
|
||||
}
|
||||
|
||||
fn compareNarrativeFields(
|
||||
na: std.json.ObjectMap,
|
||||
nb: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
const fields = [_][]const u8{ "hypothesis", "context", "intent", "expected_outcome" };
|
||||
|
||||
for (fields) |field| {
|
||||
const val_a = jsonGetString(na, field);
|
||||
const val_b = jsonGetString(nb, field);
|
||||
|
||||
if (val_a != null and val_b != null) {
|
||||
if (!std.mem.eql(u8, val_a.?, val_b.?)) {
|
||||
colors.printInfo(" {s}:\n", .{field});
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, val_a.? });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, val_b.? });
|
||||
}
|
||||
} else if (val_a != null) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ field, run_a });
|
||||
} else if (val_b != null) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ field, run_b });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compareMetadata(
|
||||
ma: std.json.ObjectMap,
|
||||
mb: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
show_all: bool,
|
||||
) !void {
|
||||
var has_differences = false;
|
||||
|
||||
// Compare key metadata fields
|
||||
const keys = [_][]const u8{ "batch_size", "learning_rate", "epochs", "model", "dataset" };
|
||||
|
||||
for (keys) |key| {
|
||||
if (ma.get(key)) |va| {
|
||||
if (mb.get(key)) |vb| {
|
||||
const str_a = jsonValueToString(va);
|
||||
const str_b = jsonValueToString(vb);
|
||||
|
||||
if (!std.mem.eql(u8, str_a, str_b)) {
|
||||
has_differences = true;
|
||||
colors.printInfo(" {s}: {s} → {s}\n", .{ key, str_a, str_b });
|
||||
} else if (show_all) {
|
||||
colors.printInfo(" {s}: {s} (same)\n", .{ key, str_a });
|
||||
}
|
||||
} else if (show_all) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ key, run_a });
|
||||
}
|
||||
} else if (mb.get(key)) |_| {
|
||||
if (show_all) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ key, run_b });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_differences and !show_all) {
|
||||
colors.printInfo(" (no significant differences in common metadata)\n", .{});
|
||||
}
|
||||
}
|
||||
|
||||
fn compareMetrics(
|
||||
ma: std.json.ObjectMap,
|
||||
mb: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
_ = run_a;
|
||||
_ = run_b;
|
||||
|
||||
// Common metrics to compare
|
||||
const metrics = [_][]const u8{ "accuracy", "loss", "f1_score", "precision", "recall", "training_time", "validation_loss" };
|
||||
|
||||
for (metrics) |metric| {
|
||||
if (ma.get(metric)) |va| {
|
||||
if (mb.get(metric)) |vb| {
|
||||
const val_a = jsonValueToFloat(va);
|
||||
const val_b = jsonValueToFloat(vb);
|
||||
|
||||
const diff = val_b - val_a;
|
||||
const percent = if (val_a != 0) (diff / val_a) * 100 else 0;
|
||||
|
||||
const arrow = if (diff > 0) "↑" else if (diff < 0) "↓" else "=";
|
||||
|
||||
colors.printInfo(" {s}: {d:.4} → {d:.4} ({s}{d:.4}, {d:.1}%)\n", .{
|
||||
metric, val_a, val_b, arrow, @abs(diff), percent,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn outputCsvComparison(
|
||||
allocator: std.mem.Allocator,
|
||||
root_a: std.json.ObjectMap,
|
||||
root_b: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
|
||||
// Header with actual run IDs as column names
|
||||
try writer.print("field,{s},{s},delta,notes\n", .{ run_a, run_b });
|
||||
|
||||
// Job names
|
||||
const job_name_a = jsonGetString(root_a, "job_name") orelse "";
|
||||
const job_name_b = jsonGetString(root_b, "job_name") orelse "";
|
||||
const job_same = std.mem.eql(u8, job_name_a, job_name_b);
|
||||
try writer.print("job_name,\"{s}\",\"{s}\",{s},\"{s}\"\n", .{
|
||||
job_name_a, job_name_b,
|
||||
if (job_same) "same" else "changed", if (job_same) "" else "different job names",
|
||||
});
|
||||
|
||||
// Outcomes
|
||||
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
|
||||
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
|
||||
const outcome_same = std.mem.eql(u8, outcome_a, outcome_b);
|
||||
try writer.print("outcome,{s},{s},{s},\"{s}\"\n", .{
|
||||
outcome_a, outcome_b,
|
||||
if (outcome_same) "same" else "changed", if (outcome_same) "" else "different outcomes",
|
||||
});
|
||||
|
||||
// Experiment group
|
||||
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
|
||||
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
|
||||
const group_same = std.mem.eql(u8, group_a, group_b);
|
||||
try writer.print("experiment_group,\"{s}\",\"{s}\",{s},\"{s}\"\n", .{
|
||||
group_a, group_b,
|
||||
if (group_same) "same" else "changed", if (group_same) "" else "different groups",
|
||||
});
|
||||
|
||||
// Metadata fields with delta calculation for numeric values
|
||||
const keys = [_][]const u8{ "batch_size", "learning_rate", "epochs", "model", "dataset" };
|
||||
for (keys) |key| {
|
||||
if (root_a.get(key)) |va| {
|
||||
if (root_b.get(key)) |vb| {
|
||||
const str_a = jsonValueToString(va);
|
||||
const str_b = jsonValueToString(vb);
|
||||
const same = std.mem.eql(u8, str_a, str_b);
|
||||
|
||||
// Try to calculate delta for numeric values
|
||||
const delta = if (!same) blk: {
|
||||
const f_a = jsonValueToFloat(va);
|
||||
const f_b = jsonValueToFloat(vb);
|
||||
if (f_a != 0 or f_b != 0) {
|
||||
break :blk try std.fmt.allocPrint(allocator, "{d:.4}", .{f_b - f_a});
|
||||
}
|
||||
break :blk "changed";
|
||||
} else "0";
|
||||
defer if (!same and (jsonValueToFloat(va) != 0 or jsonValueToFloat(vb) != 0)) allocator.free(delta);
|
||||
|
||||
try writer.print("{s},{s},{s},{s},\"{s}\"\n", .{
|
||||
key, str_a, str_b, delta,
|
||||
if (same) "same" else "changed",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics with delta calculation
|
||||
const metrics = [_][]const u8{ "accuracy", "loss", "f1_score", "precision", "recall", "training_time" };
|
||||
for (metrics) |metric| {
|
||||
if (root_a.get(metric)) |va| {
|
||||
if (root_b.get(metric)) |vb| {
|
||||
const val_a = jsonValueToFloat(va);
|
||||
const val_b = jsonValueToFloat(vb);
|
||||
const diff = val_b - val_a;
|
||||
const percent = if (val_a != 0) (diff / val_a) * 100 else 0;
|
||||
|
||||
const notes = if (std.mem.eql(u8, metric, "loss") or std.mem.eql(u8, metric, "training_time"))
|
||||
if (diff < 0) "improved" else if (diff > 0) "degraded" else "same"
|
||||
else if (diff > 0) "improved" else if (diff < 0) "degraded" else "same";
|
||||
|
||||
try writer.print("{s},{d:.4},{d:.4},{d:.4},\"{d:.1}% - {s}\"\n", .{
|
||||
metric, val_a, val_b, diff, percent, notes,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll(buf.items);
|
||||
}
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) return null;
|
||||
const v = v_opt.?;
|
||||
if (v != .string) return null;
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn jsonValueToString(v: std.json.Value) []const u8 {
|
||||
return switch (v) {
|
||||
.string => |s| s,
|
||||
.integer => "number",
|
||||
.float => "number",
|
||||
.bool => |b| if (b) "true" else "false",
|
||||
else => "complex",
|
||||
};
|
||||
}
|
||||
|
||||
fn jsonValueToFloat(v: std.json.Value) f64 {
|
||||
return switch (v) {
|
||||
.float => |f| f,
|
||||
.integer => |i| @as(f64, @floatFromInt(i)),
|
||||
else => 0,
|
||||
};
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml compare <run-a> <run-b> [options]\n", .{});
|
||||
colors.printInfo("\nCompare two runs and show differences in:\n", .{});
|
||||
colors.printInfo(" - Job metadata (batch_size, learning_rate, etc.)\n", .{});
|
||||
colors.printInfo(" - Narrative fields (hypothesis, context, intent)\n", .{});
|
||||
colors.printInfo(" - Metrics (accuracy, loss, training_time)\n", .{});
|
||||
colors.printInfo(" - Outcome status\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --json Output as JSON\n", .{});
|
||||
colors.printInfo(" --all Show all fields (including unchanged)\n", .{});
|
||||
colors.printInfo(" --fields <csv> Compare only specific fields\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml compare run_abc run_def\n", .{});
|
||||
colors.printInfo(" ml compare run_abc run_def --json\n", .{});
|
||||
colors.printInfo(" ml compare run_abc run_def --all\n", .{});
|
||||
}
|
||||
|
|
@ -9,6 +9,7 @@ const DatasetOptions = struct {
|
|||
dry_run: bool = false,
|
||||
validate: bool = false,
|
||||
json: bool = false,
|
||||
csv: bool = false,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
|
|
@ -34,6 +35,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
options.validate = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--csv")) {
|
||||
options.csv = true;
|
||||
} else if (std.mem.startsWith(u8, arg, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
printUsage();
|
||||
|
|
@ -63,6 +66,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
} else if (std.mem.eql(u8, action, "search")) {
|
||||
try searchDatasets(allocator, positional.items[1], &options);
|
||||
return error.InvalidArgs;
|
||||
} else if (std.mem.eql(u8, action, "verify")) {
|
||||
try verifyDataset(allocator, positional.items[1], &options);
|
||||
return;
|
||||
}
|
||||
},
|
||||
3 => {
|
||||
|
|
@ -72,7 +78,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
},
|
||||
else => {
|
||||
colors.printError("Unknoen action: {s}\n", .{action});
|
||||
colors.printError("Unknown action: {s}\n", .{action});
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
},
|
||||
|
|
@ -86,6 +92,7 @@ fn printUsage() void {
|
|||
colors.printInfo(" register <name> <url> Register a dataset with URL\n", .{});
|
||||
colors.printInfo(" info <name> Show dataset information\n", .{});
|
||||
colors.printInfo(" search <term> Search datasets by name/description\n", .{});
|
||||
colors.printInfo(" verify <path|id> Verify dataset integrity\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --dry-run Show what would be requested\n", .{});
|
||||
colors.printInfo(" --validate Validate inputs only (no request)\n", .{});
|
||||
|
|
@ -362,6 +369,68 @@ fn searchDatasets(allocator: std.mem.Allocator, term: []const u8, options: *cons
|
|||
}
|
||||
}
|
||||
|
||||
fn verifyDataset(allocator: std.mem.Allocator, target: []const u8, options: *const DatasetOptions) !void {
|
||||
colors.printInfo("Verifying dataset: {s}\n", .{target});
|
||||
|
||||
// For now, use basic stat to check if path exists
|
||||
// In production, this would compute SHA256 hashes
|
||||
|
||||
const path = if (std.fs.path.isAbsolute(target))
|
||||
target
|
||||
else
|
||||
try std.fs.path.join(allocator, &[_][]const u8{ ".", target });
|
||||
defer if (!std.fs.path.isAbsolute(target)) allocator.free(path);
|
||||
|
||||
var dir = std.fs.openDirAbsolute(path, .{ .iterate = true }) catch {
|
||||
colors.printError("Dataset not found: {s}\n", .{target});
|
||||
return error.FileNotFound;
|
||||
};
|
||||
defer dir.close();
|
||||
|
||||
var file_count: usize = 0;
|
||||
var total_size: u64 = 0;
|
||||
|
||||
var walker = try dir.walk(allocator);
|
||||
defer walker.deinit();
|
||||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind != .file) continue;
|
||||
file_count += 1;
|
||||
|
||||
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ path, entry.path });
|
||||
defer allocator.free(full_path);
|
||||
|
||||
const stat = std.fs.cwd().statFile(full_path) catch continue;
|
||||
total_size += stat.size;
|
||||
}
|
||||
|
||||
if (options.json) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
var buffer: [4096]u8 = undefined;
|
||||
const formatted = std.fmt.bufPrint(&buffer, "{{\"path\":\"{s}\",\"files\":{d},\"size\":{d},\"ok\":true}}\n", .{
|
||||
target, file_count, total_size,
|
||||
}) catch unreachable;
|
||||
try stdout_file.writeAll(formatted);
|
||||
} else if (options.csv) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll("metric,value\n");
|
||||
var buf: [256]u8 = undefined;
|
||||
const line1 = try std.fmt.bufPrint(&buf, "path,{s}\n", .{target});
|
||||
try stdout_file.writeAll(line1);
|
||||
const line2 = try std.fmt.bufPrint(&buf, "files,{d}\n", .{file_count});
|
||||
try stdout_file.writeAll(line2);
|
||||
const line3 = try std.fmt.bufPrint(&buf, "size_bytes,{d}\n", .{total_size});
|
||||
try stdout_file.writeAll(line3);
|
||||
const line4 = try std.fmt.bufPrint(&buf, "size_mb,{d:.2}\n", .{@as(f64, @floatFromInt(total_size)) / (1024 * 1024)});
|
||||
try stdout_file.writeAll(line4);
|
||||
} else {
|
||||
colors.printSuccess("✓ Dataset verified\n", .{});
|
||||
colors.printInfo(" Path: {s}\n", .{target});
|
||||
colors.printInfo(" Files: {d}\n", .{file_count});
|
||||
colors.printInfo(" Size: {d:.2} MB\n", .{@as(f64, @floatFromInt(total_size)) / (1024 * 1024)});
|
||||
}
|
||||
}
|
||||
|
||||
fn writeJSONString(writer: anytype, s: []const u8) !void {
|
||||
try writer.writeByte('"');
|
||||
for (s) |c| {
|
||||
|
|
|
|||
336
cli/src/commands/export_cmd.zig
Normal file
336
cli/src/commands/export_cmd.zig
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
|
||||
pub const ExportOptions = struct {
|
||||
anonymize: bool = false,
|
||||
anonymize_level: []const u8 = "metadata-only", // metadata-only, full
|
||||
bundle: ?[]const u8 = null,
|
||||
base_override: ?[]const u8 = null,
|
||||
json: bool = false,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
const target = argv[0];
|
||||
var options = ExportOptions{};
|
||||
|
||||
var i: usize = 1;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const arg = argv[i];
|
||||
if (std.mem.eql(u8, arg, "--anonymize")) {
|
||||
options.anonymize = true;
|
||||
} else if (std.mem.eql(u8, arg, "--anonymize-level") and i + 1 < argv.len) {
|
||||
options.anonymize_level = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--bundle") and i + 1 < argv.len) {
|
||||
options.bundle = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--base") and i + 1 < argv.len) {
|
||||
options.base_override = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate anonymize level
|
||||
if (!std.mem.eql(u8, options.anonymize_level, "metadata-only") and
|
||||
!std.mem.eql(u8, options.anonymize_level, "full"))
|
||||
{
|
||||
colors.printError("Invalid anonymize level: {s}. Use 'metadata-only' or 'full'\n", .{options.anonymize_level});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = options.base_override orelse cfg.worker_base;
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
// Read the manifest
|
||||
const manifest_content = manifest.readFileAlloc(allocator, manifest_path) catch |err| {
|
||||
colors.printError("Failed to read manifest: {}\n", .{err});
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_content);
|
||||
|
||||
// Parse the manifest
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, manifest_content, .{}) catch |err| {
|
||||
colors.printError("Failed to parse manifest: {}\n", .{err});
|
||||
return err;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
// Anonymize if requested
|
||||
var final_content: []u8 = undefined;
|
||||
var final_content_owned = false;
|
||||
|
||||
if (options.anonymize) {
|
||||
final_content = try anonymizeManifest(allocator, parsed.value, options.anonymize_level);
|
||||
final_content_owned = true;
|
||||
} else {
|
||||
final_content = manifest_content;
|
||||
}
|
||||
defer if (final_content_owned) allocator.free(final_content);
|
||||
|
||||
// Output or bundle
|
||||
if (options.bundle) |bundle_path| {
|
||||
// Create a simple tar-like bundle (just the manifest for now)
|
||||
// In production, this would include code, configs, etc.
|
||||
var bundle_file = try std.fs.cwd().createFile(bundle_path, .{});
|
||||
defer bundle_file.close();
|
||||
|
||||
try bundle_file.writeAll(final_content);
|
||||
|
||||
if (options.json) {
|
||||
var stdout_writer = io.stdoutWriter();
|
||||
try stdout_writer.print("{{\"success\":true,\"bundle\":\"{s}\",\"anonymized\":{}}}\n", .{
|
||||
bundle_path,
|
||||
options.anonymize,
|
||||
});
|
||||
} else {
|
||||
colors.printSuccess("✓ Exported to {s}\n", .{bundle_path});
|
||||
if (options.anonymize) {
|
||||
colors.printInfo(" Anonymization level: {s}\n", .{options.anonymize_level});
|
||||
colors.printInfo(" Paths redacted, IPs removed, usernames anonymized\n", .{});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Output to stdout
|
||||
var stdout_writer = io.stdoutWriter();
|
||||
try stdout_writer.print("{s}\n", .{final_content});
|
||||
}
|
||||
}
|
||||
|
||||
fn anonymizeManifest(
|
||||
allocator: std.mem.Allocator,
|
||||
root: std.json.Value,
|
||||
level: []const u8,
|
||||
) ![]u8 {
|
||||
// Clone the value by stringifying and re-parsing so we can modify it
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
try writeJSONValue(buf.writer(allocator), root);
|
||||
const json_str = try buf.toOwnedSlice(allocator);
|
||||
defer allocator.free(json_str);
|
||||
var parsed_clone = try std.json.parseFromSlice(std.json.Value, allocator, json_str, .{});
|
||||
defer parsed_clone.deinit();
|
||||
var cloned = parsed_clone.value;
|
||||
|
||||
if (cloned != .object) {
|
||||
// For non-objects, just re-serialize and return
|
||||
var out_buf = std.ArrayList(u8).empty;
|
||||
defer out_buf.deinit(allocator);
|
||||
try writeJSONValue(out_buf.writer(allocator), cloned);
|
||||
return out_buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
const obj = &cloned.object;
|
||||
|
||||
// Anonymize metadata fields
|
||||
if (obj.get("metadata")) |meta| {
|
||||
if (meta == .object) {
|
||||
var meta_obj = meta.object;
|
||||
|
||||
// Path anonymization: /nas/private/user/data → /datasets/data
|
||||
if (meta_obj.get("dataset_path")) |dp| {
|
||||
if (dp == .string) {
|
||||
const anon_path = try anonymizePath(allocator, dp.string);
|
||||
defer allocator.free(anon_path);
|
||||
try meta_obj.put("dataset_path", std.json.Value{ .string = anon_path });
|
||||
}
|
||||
}
|
||||
|
||||
// Anonymize other paths if full level
|
||||
if (std.mem.eql(u8, level, "full")) {
|
||||
const path_fields = [_][]const u8{ "code_path", "output_path", "checkpoint_path" };
|
||||
for (path_fields) |field| {
|
||||
if (meta_obj.get(field)) |p| {
|
||||
if (p == .string) {
|
||||
const anon_path = try anonymizePath(allocator, p.string);
|
||||
defer allocator.free(anon_path);
|
||||
try meta_obj.put(field, std.json.Value{ .string = anon_path });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Anonymize system info
|
||||
if (obj.get("system")) |sys| {
|
||||
if (sys == .object) {
|
||||
var sys_obj = sys.object;
|
||||
|
||||
// Hostname: gpu-server-01.internal → worker-A
|
||||
if (sys_obj.get("hostname")) |h| {
|
||||
if (h == .string) {
|
||||
try sys_obj.put("hostname", std.json.Value{ .string = "worker-A" });
|
||||
}
|
||||
}
|
||||
|
||||
// IP addresses → [REDACTED]
|
||||
if (sys_obj.get("ip_address")) |_| {
|
||||
try sys_obj.put("ip_address", std.json.Value{ .string = "[REDACTED]" });
|
||||
}
|
||||
|
||||
// Username: user@lab.edu → researcher-N
|
||||
if (sys_obj.get("username")) |_| {
|
||||
try sys_obj.put("username", std.json.Value{ .string = "[REDACTED]" });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Anonymize logs reference (logs may contain PII)
|
||||
if (std.mem.eql(u8, level, "full")) {
|
||||
_ = obj.swapRemove("logs");
|
||||
_ = obj.swapRemove("log_path");
|
||||
_ = obj.swapRemove("annotations");
|
||||
}
|
||||
|
||||
// Serialize back to JSON
|
||||
var out_buf = std.ArrayList(u8).empty;
|
||||
defer out_buf.deinit(allocator);
|
||||
try writeJSONValue(out_buf.writer(allocator), cloned);
|
||||
return out_buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
fn anonymizePath(allocator: std.mem.Allocator, path: []const u8) ![]const u8 {
|
||||
// Simple path anonymization: replace leading path components with generic names
|
||||
// /home/user/project/data → /workspace/data
|
||||
// /nas/private/lab/experiments → /datasets/experiments
|
||||
|
||||
// Find the last component
|
||||
const last_sep = std.mem.lastIndexOf(u8, path, "/");
|
||||
if (last_sep == null) return allocator.dupe(u8, path);
|
||||
|
||||
const filename = path[last_sep.? + 1 ..];
|
||||
|
||||
// Determine prefix based on context
|
||||
const prefix = if (std.mem.indexOf(u8, path, "data") != null)
|
||||
"/datasets"
|
||||
else if (std.mem.indexOf(u8, path, "model") != null or std.mem.indexOf(u8, path, "checkpoint") != null)
|
||||
"/models"
|
||||
else if (std.mem.indexOf(u8, path, "code") != null or std.mem.indexOf(u8, path, "src") != null)
|
||||
"/code"
|
||||
else
|
||||
"/workspace";
|
||||
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ prefix, filename });
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml export <run-id|path> [options]\n", .{});
|
||||
colors.printInfo("\nExport experiment for sharing or archiving:\n", .{});
|
||||
colors.printInfo(" --bundle <path> Create tarball at path\n", .{});
|
||||
colors.printInfo(" --anonymize Enable anonymization\n", .{});
|
||||
colors.printInfo(" --anonymize-level <lvl> 'metadata-only' or 'full'\n", .{});
|
||||
colors.printInfo(" --base <path> Base path to find run\n", .{});
|
||||
colors.printInfo(" --json Output JSON response\n", .{});
|
||||
colors.printInfo("\nAnonymization rules:\n", .{});
|
||||
colors.printInfo(" - Paths: /nas/private/... → /datasets/...\n", .{});
|
||||
colors.printInfo(" - Hostnames: gpu-server-01 → worker-A\n", .{});
|
||||
colors.printInfo(" - IPs: 192.168.1.100 → [REDACTED]\n", .{});
|
||||
colors.printInfo(" - Usernames: user@lab.edu → [REDACTED]\n", .{});
|
||||
colors.printInfo(" - Full level: Also removes logs and annotations\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml export run_abc --bundle run_abc.tar.gz\n", .{});
|
||||
colors.printInfo(" ml export run_abc --bundle run_abc.tar.gz --anonymize\n", .{});
|
||||
colors.printInfo(" ml export run_abc --anonymize --anonymize-level full\n", .{});
|
||||
}
|
||||
|
||||
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
|
||||
switch (v) {
|
||||
.null => try writer.writeAll("null"),
|
||||
.bool => |b| try writer.print("{}", .{b}),
|
||||
.integer => |i| try writer.print("{d}", .{i}),
|
||||
.float => |f| try writer.print("{d}", .{f}),
|
||||
.string => |s| try writeJSONString(writer, s),
|
||||
.array => |arr| {
|
||||
try writer.writeAll("[");
|
||||
for (arr.items, 0..) |item, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writeJSONValue(writer, item);
|
||||
}
|
||||
try writer.writeAll("]");
|
||||
},
|
||||
.object => |obj| {
|
||||
try writer.writeAll("{");
|
||||
var first = true;
|
||||
var it = obj.iterator();
|
||||
while (it.next()) |entry| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.print("\"{s}\":", .{entry.key_ptr.*});
|
||||
try writeJSONValue(writer, entry.value_ptr.*);
|
||||
}
|
||||
try writer.writeAll("}");
|
||||
},
|
||||
.number_string => |s| try writer.print("{s}", .{s}),
|
||||
}
|
||||
}
|
||||
|
||||
fn writeJSONString(writer: anytype, s: []const u8) !void {
|
||||
try writer.writeAll("\"");
|
||||
for (s) |c| {
|
||||
switch (c) {
|
||||
'"' => try writer.writeAll("\\\""),
|
||||
'\\' => try writer.writeAll("\\\\"),
|
||||
'\n' => try writer.writeAll("\\n"),
|
||||
'\r' => try writer.writeAll("\\r"),
|
||||
'\t' => try writer.writeAll("\\t"),
|
||||
else => {
|
||||
if (c < 0x20) {
|
||||
var buf: [6]u8 = undefined;
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'u';
|
||||
buf[2] = '0';
|
||||
buf[3] = '0';
|
||||
buf[4] = hexDigit(@intCast((c >> 4) & 0x0F));
|
||||
buf[5] = hexDigit(@intCast(c & 0x0F));
|
||||
try writer.writeAll(&buf);
|
||||
} else {
|
||||
try writer.writeAll(&[_]u8{c});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
try writer.writeAll("\"");
|
||||
}
|
||||
|
||||
fn hexDigit(v: u8) u8 {
|
||||
return if (v < 10) ('0' + v) else ('a' + (v - 10));
|
||||
}
|
||||
497
cli/src/commands/find.zig
Normal file
497
cli/src/commands/find.zig
Normal file
|
|
@ -0,0 +1,497 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
|
||||
pub const FindOptions = struct {
|
||||
json: bool = false,
|
||||
csv: bool = false,
|
||||
limit: usize = 20,
|
||||
tag: ?[]const u8 = null,
|
||||
outcome: ?[]const u8 = null,
|
||||
dataset: ?[]const u8 = null,
|
||||
experiment_group: ?[]const u8 = null,
|
||||
author: ?[]const u8 = null,
|
||||
after: ?[]const u8 = null,
|
||||
before: ?[]const u8 = null,
|
||||
query: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
var options = FindOptions{};
|
||||
var query_str: ?[]const u8 = null;
|
||||
|
||||
// First argument might be a query string or a flag
|
||||
var arg_idx: usize = 0;
|
||||
if (!std.mem.startsWith(u8, argv[0], "--")) {
|
||||
query_str = argv[0];
|
||||
arg_idx = 1;
|
||||
}
|
||||
|
||||
var i: usize = arg_idx;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const arg = argv[i];
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--csv")) {
|
||||
options.csv = true;
|
||||
} else if (std.mem.eql(u8, arg, "--limit") and i + 1 < argv.len) {
|
||||
options.limit = try std.fmt.parseInt(usize, argv[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--tag") and i + 1 < argv.len) {
|
||||
options.tag = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--outcome") and i + 1 < argv.len) {
|
||||
options.outcome = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--dataset") and i + 1 < argv.len) {
|
||||
options.dataset = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--experiment-group") and i + 1 < argv.len) {
|
||||
options.experiment_group = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--author") and i + 1 < argv.len) {
|
||||
options.author = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--after") and i + 1 < argv.len) {
|
||||
options.after = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--before") and i + 1 < argv.len) {
|
||||
options.before = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else if (!std.mem.startsWith(u8, arg, "--")) {
|
||||
// Treat as query string if not already set
|
||||
if (query_str == null) {
|
||||
query_str = arg;
|
||||
} else {
|
||||
colors.printError("Unknown argument: {s}\n", .{arg});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
options.query = query_str;
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
colors.printInfo("Searching experiments...\n", .{});
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Build search request JSON
|
||||
const search_json = try buildSearchJson(allocator, &options);
|
||||
defer allocator.free(search_json);
|
||||
|
||||
// Send search request - we'll use the dataset search opcode as a placeholder
|
||||
// In production, this would have a dedicated search endpoint
|
||||
try client.sendDatasetSearch(search_json, api_key_hash);
|
||||
|
||||
const msg = try client.receiveMessage(allocator);
|
||||
defer allocator.free(msg);
|
||||
|
||||
// Parse response
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, msg, .{}) catch {
|
||||
if (options.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\":\"invalid_response\"}}\n", .{});
|
||||
} else {
|
||||
colors.printError("Failed to parse search results\n", .{});
|
||||
}
|
||||
return error.InvalidResponse;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value;
|
||||
|
||||
if (options.json) {
|
||||
try io.stdoutWriteJson(root);
|
||||
} else if (options.csv) {
|
||||
try outputCsvResults(allocator, root, &options);
|
||||
} else {
|
||||
try outputHumanResults(root, &options);
|
||||
}
|
||||
}
|
||||
|
||||
fn buildSearchJson(allocator: std.mem.Allocator, options: *const FindOptions) ![]u8 {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{");
|
||||
|
||||
var first = true;
|
||||
|
||||
if (options.query) |q| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"query\":");
|
||||
try writeJSONString(writer, q);
|
||||
}
|
||||
|
||||
if (options.tag) |t| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"tag\":");
|
||||
try writeJSONString(writer, t);
|
||||
}
|
||||
|
||||
if (options.outcome) |o| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"outcome\":");
|
||||
try writeJSONString(writer, o);
|
||||
}
|
||||
|
||||
if (options.dataset) |d| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"dataset\":");
|
||||
try writeJSONString(writer, d);
|
||||
}
|
||||
|
||||
if (options.experiment_group) |eg| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"experiment_group\":");
|
||||
try writeJSONString(writer, eg);
|
||||
}
|
||||
|
||||
if (options.author) |a| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"author\":");
|
||||
try writeJSONString(writer, a);
|
||||
}
|
||||
|
||||
if (options.after) |a| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"after\":");
|
||||
try writeJSONString(writer, a);
|
||||
}
|
||||
|
||||
if (options.before) |b| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"before\":");
|
||||
try writeJSONString(writer, b);
|
||||
}
|
||||
|
||||
if (!first) try writer.writeAll(",");
|
||||
try writer.print("\"limit\":{d}", .{options.limit});
|
||||
|
||||
try writer.writeAll("}");
|
||||
|
||||
return buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
fn outputHumanResults(root: std.json.Value, options: *const FindOptions) !void {
|
||||
if (root != .object) {
|
||||
colors.printError("Invalid response format\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const obj = root.object;
|
||||
|
||||
// Check for error
|
||||
if (obj.get("error")) |err| {
|
||||
if (err == .string) {
|
||||
colors.printError("Search error: {s}\n", .{err.string});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
|
||||
if (results == null) {
|
||||
colors.printInfo("No results found\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
if (results.? != .array) {
|
||||
colors.printError("Invalid results format\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const items = results.?.array.items;
|
||||
|
||||
if (items.len == 0) {
|
||||
colors.printInfo("No experiments found matching your criteria\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
colors.printSuccess("Found {d} experiment(s)\n\n", .{items.len});
|
||||
|
||||
// Print header
|
||||
colors.printInfo("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
|
||||
"ID", "Job Name", "Outcome", "Status", "Group/Tags",
|
||||
});
|
||||
colors.printInfo("{s}\n", .{"────────────────────────────────────────────────────────────────────────────────"});
|
||||
|
||||
for (items) |item| {
|
||||
if (item != .object) continue;
|
||||
const run_obj = item.object;
|
||||
|
||||
const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
|
||||
const short_id = if (id.len > 8) id[0..8] else id;
|
||||
|
||||
const job_name = jsonGetString(run_obj, "job_name") orelse "unnamed";
|
||||
const job_display = if (job_name.len > 18) job_name[0..18] else job_name;
|
||||
|
||||
const outcome = jsonGetString(run_obj, "outcome") orelse "-";
|
||||
const status = jsonGetString(run_obj, "status") orelse "unknown";
|
||||
|
||||
// Build group/tags summary
|
||||
var summary_buf: [30]u8 = undefined;
|
||||
const summary = blk: {
|
||||
const group = jsonGetString(run_obj, "experiment_group");
|
||||
const tags = run_obj.get("tags");
|
||||
|
||||
if (group) |g| {
|
||||
if (tags) |t| {
|
||||
if (t == .string) {
|
||||
break :blk std.fmt.bufPrint(&summary_buf, "{s}/{s}", .{ g[0..@min(g.len, 10)], t.string[0..@min(t.string.len, 10)] }) catch g[0..@min(g.len, 15)];
|
||||
}
|
||||
}
|
||||
break :blk g[0..@min(g.len, 20)];
|
||||
}
|
||||
break :blk "-";
|
||||
};
|
||||
|
||||
// Color code by outcome
|
||||
if (std.mem.eql(u8, outcome, "validates")) {
|
||||
colors.printSuccess("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
|
||||
short_id, job_display, outcome, status, summary,
|
||||
});
|
||||
} else if (std.mem.eql(u8, outcome, "refutes")) {
|
||||
colors.printError("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
|
||||
short_id, job_display, outcome, status, summary,
|
||||
});
|
||||
} else if (std.mem.eql(u8, outcome, "partial") or std.mem.eql(u8, outcome, "inconclusive")) {
|
||||
colors.printWarning("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
|
||||
short_id, job_display, outcome, status, summary,
|
||||
});
|
||||
} else {
|
||||
colors.printInfo("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
|
||||
short_id, job_display, outcome, status, summary,
|
||||
});
|
||||
}
|
||||
|
||||
// Show hypothesis if available and query matches
|
||||
if (options.query) |_| {
|
||||
if (run_obj.get("narrative")) |narr| {
|
||||
if (narr == .object) {
|
||||
if (narr.object.get("hypothesis")) |h| {
|
||||
if (h == .string and h.string.len > 0) {
|
||||
const hypo = h.string;
|
||||
const display = if (hypo.len > 50) hypo[0..50] else hypo;
|
||||
colors.printInfo(" ↳ {s}...\n", .{display});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\nUse 'ml info <id>' for details, 'ml compare <a> <b>' to compare runs\n", .{});
|
||||
}
|
||||
|
||||
fn outputCsvResults(allocator: std.mem.Allocator, root: std.json.Value, options: *const FindOptions) !void {
|
||||
_ = options;
|
||||
|
||||
if (root != .object) {
|
||||
colors.printError("Invalid response format\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const obj = root.object;
|
||||
|
||||
// Check for error
|
||||
if (obj.get("error")) |err| {
|
||||
if (err == .string) {
|
||||
colors.printError("Search error: {s}\n", .{err.string});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
|
||||
if (results == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (results.? != .array) {
|
||||
return;
|
||||
}
|
||||
|
||||
const items = results.?.array.items;
|
||||
|
||||
// CSV Header
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll("id,job_name,outcome,status,experiment_group,tags,hypothesis\n");
|
||||
|
||||
for (items) |item| {
|
||||
if (item != .object) continue;
|
||||
const run_obj = item.object;
|
||||
|
||||
const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
|
||||
const job_name = jsonGetString(run_obj, "job_name") orelse "";
|
||||
const outcome = jsonGetString(run_obj, "outcome") orelse "";
|
||||
const status = jsonGetString(run_obj, "status") orelse "";
|
||||
const group = jsonGetString(run_obj, "experiment_group") orelse "";
|
||||
|
||||
// Get tags
|
||||
var tags: []const u8 = "";
|
||||
if (run_obj.get("tags")) |t| {
|
||||
if (t == .string) tags = t.string;
|
||||
}
|
||||
|
||||
// Get hypothesis
|
||||
var hypothesis: []const u8 = "";
|
||||
if (run_obj.get("narrative")) |narr| {
|
||||
if (narr == .object) {
|
||||
if (narr.object.get("hypothesis")) |h| {
|
||||
if (h == .string) hypothesis = h.string;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Escape fields that might contain commas or quotes
|
||||
const safe_job = try escapeCsv(allocator, job_name);
|
||||
defer allocator.free(safe_job);
|
||||
const safe_group = try escapeCsv(allocator, group);
|
||||
defer allocator.free(safe_group);
|
||||
const safe_tags = try escapeCsv(allocator, tags);
|
||||
defer allocator.free(safe_tags);
|
||||
const safe_hypo = try escapeCsv(allocator, hypothesis);
|
||||
defer allocator.free(safe_hypo);
|
||||
|
||||
var buf: [1024]u8 = undefined;
|
||||
const line = try std.fmt.bufPrint(&buf, "{s},{s},{s},{s},{s},{s},{s}\n", .{
|
||||
id, safe_job, outcome, status, safe_group, safe_tags, safe_hypo,
|
||||
});
|
||||
try stdout_file.writeAll(line);
|
||||
}
|
||||
}
|
||||
|
||||
fn escapeCsv(allocator: std.mem.Allocator, s: []const u8) ![]u8 {
|
||||
// Check if we need to escape (contains comma, quote, or newline)
|
||||
var needs_escape = false;
|
||||
for (s) |c| {
|
||||
if (c == ',' or c == '"' or c == '\n' or c == '\r') {
|
||||
needs_escape = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!needs_escape) {
|
||||
return allocator.dupe(u8, s);
|
||||
}
|
||||
|
||||
// Escape: wrap in quotes and double existing quotes
|
||||
var buf = std.ArrayList(u8).initCapacity(allocator, s.len + 2) catch |err| {
|
||||
return err;
|
||||
};
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
try buf.append(allocator, '"');
|
||||
for (s) |c| {
|
||||
if (c == '"') {
|
||||
try buf.appendSlice(allocator, "\"\"");
|
||||
} else {
|
||||
try buf.append(allocator, c);
|
||||
}
|
||||
}
|
||||
try buf.append(allocator, '"');
|
||||
|
||||
return buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) return null;
|
||||
const v = v_opt.?;
|
||||
if (v != .string) return null;
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn writeJSONString(writer: anytype, s: []const u8) !void {
|
||||
try writer.writeAll("\"");
|
||||
for (s) |c| {
|
||||
switch (c) {
|
||||
'"' => try writer.writeAll("\\\""),
|
||||
'\\' => try writer.writeAll("\\\\"),
|
||||
'\n' => try writer.writeAll("\\n"),
|
||||
'\r' => try writer.writeAll("\\r"),
|
||||
'\t' => try writer.writeAll("\\t"),
|
||||
else => {
|
||||
if (c < 0x20) {
|
||||
var buf: [6]u8 = undefined;
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'u';
|
||||
buf[2] = '0';
|
||||
buf[3] = '0';
|
||||
buf[4] = hexDigit(@intCast((c >> 4) & 0x0F));
|
||||
buf[5] = hexDigit(@intCast(c & 0x0F));
|
||||
try writer.writeAll(&buf);
|
||||
} else {
|
||||
try writer.writeAll(&[_]u8{c});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
try writer.writeAll("\"");
|
||||
}
|
||||
|
||||
fn hexDigit(v: u8) u8 {
|
||||
return if (v < 10) ('0' + v) else ('a' + (v - 10));
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml find [query] [options]\n", .{});
|
||||
colors.printInfo("\nSearch experiments by:\n", .{});
|
||||
colors.printInfo(" Query (free text): ml find \"hypothesis: warmup\"\n", .{});
|
||||
colors.printInfo(" Tags: ml find --tag ablation\n", .{});
|
||||
colors.printInfo(" Outcome: ml find --outcome validates\n", .{});
|
||||
colors.printInfo(" Dataset: ml find --dataset imagenet\n", .{});
|
||||
colors.printInfo(" Experiment group: ml find --experiment-group lr-scaling\n", .{});
|
||||
colors.printInfo(" Author: ml find --author user@lab.edu\n", .{});
|
||||
colors.printInfo(" Time range: ml find --after 2024-01-01 --before 2024-03-01\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --limit <n> Max results (default: 20)\n", .{});
|
||||
colors.printInfo(" --json Output as JSON\n", .{});
|
||||
colors.printInfo(" --csv Output as CSV\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml find --tag ablation --outcome validates\n", .{});
|
||||
colors.printInfo(" ml find --experiment-group batch-scaling --json\n", .{});
|
||||
colors.printInfo(" ml find \"learning rate\" --after 2024-01-01\n", .{});
|
||||
}
|
||||
314
cli/src/commands/outcome.zig
Normal file
314
cli/src/commands/outcome.zig
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const sub = argv[0];
|
||||
if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!std.mem.eql(u8, sub, "set")) {
|
||||
colors.printError("Unknown subcommand: {s}\n", .{sub});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (argv.len < 2) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const target = argv[1];
|
||||
|
||||
var outcome_status: ?[]const u8 = null;
|
||||
var outcome_summary: ?[]const u8 = null;
|
||||
var learnings = std.ArrayList([]const u8).empty;
|
||||
defer learnings.deinit(allocator);
|
||||
var next_steps = std.ArrayList([]const u8).empty;
|
||||
defer next_steps.deinit(allocator);
|
||||
var validation_status: ?[]const u8 = null;
|
||||
var surprises = std.ArrayList([]const u8).empty;
|
||||
defer surprises.deinit(allocator);
|
||||
var base_override: ?[]const u8 = null;
|
||||
var json_mode: bool = false;
|
||||
|
||||
var i: usize = 2;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const a = argv[i];
|
||||
if (std.mem.eql(u8, a, "--outcome") or std.mem.eql(u8, a, "--status")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
outcome_status = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--summary")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
outcome_summary = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--learning")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
try learnings.append(allocator, argv[i + 1]);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--next-step")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
try next_steps.append(allocator, argv[i + 1]);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--validation-status")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
validation_status = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--surprise")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
try surprises.append(allocator, argv[i + 1]);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--base")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
base_override = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
if (outcome_status == null and outcome_summary == null and learnings.items.len == 0 and
|
||||
next_steps.items.len == 0 and validation_status == null and surprises.items.len == 0)
|
||||
{
|
||||
colors.printError("No outcome fields provided.\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
// Validate outcome status if provided
|
||||
if (outcome_status) |os| {
|
||||
const valid = std.mem.eql(u8, os, "validates") or
|
||||
std.mem.eql(u8, os, "refutes") or
|
||||
std.mem.eql(u8, os, "inconclusive") or
|
||||
std.mem.eql(u8, os, "partial") or
|
||||
std.mem.eql(u8, os, "inconclusive-partial");
|
||||
if (!valid) {
|
||||
colors.printError("Invalid outcome status: {s}. Must be one of: validates, refutes, inconclusive, partial\n", .{os});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validation status if provided
|
||||
if (validation_status) |vs| {
|
||||
const valid = std.mem.eql(u8, vs, "validates") or
|
||||
std.mem.eql(u8, vs, "refutes") or
|
||||
std.mem.eql(u8, vs, "inconclusive") or
|
||||
std.mem.eql(u8, vs, "");
|
||||
if (!valid) {
|
||||
colors.printError("Invalid validation status: {s}. Must be one of: validates, refutes, inconclusive\n", .{vs});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = base_override orelse cfg.worker_base;
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const job_name = try manifest.readJobNameFromManifest(allocator, manifest_path);
|
||||
defer allocator.free(job_name);
|
||||
|
||||
const patch_json = try buildOutcomePatchJSON(
|
||||
allocator,
|
||||
outcome_status,
|
||||
outcome_summary,
|
||||
learnings.items,
|
||||
next_steps.items,
|
||||
validation_status,
|
||||
surprises.items,
|
||||
);
|
||||
defer allocator.free(patch_json);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendSetRunNarrative(job_name, patch_json, api_key_hash);
|
||||
|
||||
if (json_mode) {
|
||||
const msg = try client.receiveMessage(allocator);
|
||||
defer allocator.free(msg);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
return error.InvalidPacket;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
if (packet.packet_type == .success) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"success\":true,\"job_name\":\"{s}\"}}\n", .{job_name});
|
||||
} else if (packet.packet_type == .error_packet) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"success\":false,\"error\":\"{s}\"}}\n", .{packet.error_message orelse "unknown"});
|
||||
} else {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"success\":true,\"job_name\":\"{s}\",\"response\":\"{s}\"}}\n", .{ job_name, packet.success_message orelse "ok" });
|
||||
}
|
||||
} else {
|
||||
try client.receiveAndHandleResponse(allocator, "Outcome set");
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml outcome set <run-id|job-name|path> [options]\n", .{});
|
||||
colors.printInfo("\nPost-Run Outcome Capture:\n", .{});
|
||||
colors.printInfo(" --outcome <status> Outcome: validates|refutes|inconclusive|partial\n", .{});
|
||||
colors.printInfo(" --summary <text> Summary of results\n", .{});
|
||||
colors.printInfo(" --learning <text> A learning from this run (can repeat)\n", .{});
|
||||
colors.printInfo(" --next-step <text> Suggested next step (can repeat)\n", .{});
|
||||
colors.printInfo(" --validation-status <st> Did results validate hypothesis? validates|refutes|inconclusive\n", .{});
|
||||
colors.printInfo(" --surprise <text> Unexpected finding (can repeat)\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --base <path> Base path to search for run_manifest.json\n", .{});
|
||||
colors.printInfo(" --json Output JSON response\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml outcome set run_abc --outcome validates --summary \"Accuracy +0.02\"\n", .{});
|
||||
colors.printInfo(" ml outcome set run_abc --learning \"LR scaling worked\" --learning \"GPU util 95%\"\n", .{});
|
||||
colors.printInfo(" ml outcome set run_abc --outcome validates --next-step \"Try batch=96\"\n", .{});
|
||||
}
|
||||
|
||||
fn buildOutcomePatchJSON(
|
||||
allocator: std.mem.Allocator,
|
||||
outcome_status: ?[]const u8,
|
||||
outcome_summary: ?[]const u8,
|
||||
learnings: [][]const u8,
|
||||
next_steps: [][]const u8,
|
||||
validation_status: ?[]const u8,
|
||||
surprises: [][]const u8,
|
||||
) ![]u8 {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{\"narrative\":");
|
||||
try writer.writeAll("{");
|
||||
|
||||
var first = true;
|
||||
|
||||
if (outcome_status) |os| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.print("\"outcome\":\"{s}\"", .{os});
|
||||
}
|
||||
|
||||
if (outcome_summary) |os| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"outcome_summary\":");
|
||||
try writeJSONString(writer, os);
|
||||
}
|
||||
|
||||
if (learnings.len > 0) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"learnings\":[");
|
||||
for (learnings, 0..) |learning, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writeJSONString(writer, learning);
|
||||
}
|
||||
try writer.writeAll("]");
|
||||
}
|
||||
|
||||
if (next_steps.len > 0) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"next_steps\":[");
|
||||
for (next_steps, 0..) |step, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writeJSONString(writer, step);
|
||||
}
|
||||
try writer.writeAll("]");
|
||||
}
|
||||
|
||||
if (validation_status) |vs| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.print("\"validation_status\":\"{s}\"", .{vs});
|
||||
}
|
||||
|
||||
if (surprises.len > 0) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"surprises\":[");
|
||||
for (surprises, 0..) |surprise, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writeJSONString(writer, surprise);
|
||||
}
|
||||
try writer.writeAll("]");
|
||||
}
|
||||
|
||||
try writer.writeAll("}}");
|
||||
|
||||
return buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
fn writeJSONString(writer: anytype, s: []const u8) !void {
|
||||
try writer.writeAll("\"");
|
||||
for (s) |c| {
|
||||
switch (c) {
|
||||
'"' => try writer.writeAll("\\\""),
|
||||
'\\' => try writer.writeAll("\\\\"),
|
||||
'\n' => try writer.writeAll("\\n"),
|
||||
'\r' => try writer.writeAll("\\r"),
|
||||
'\t' => try writer.writeAll("\\t"),
|
||||
else => {
|
||||
if (c < 0x20) {
|
||||
var buf: [6]u8 = undefined;
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'u';
|
||||
buf[2] = '0';
|
||||
buf[3] = '0';
|
||||
buf[4] = hexDigit(@intCast((c >> 4) & 0x0F));
|
||||
buf[5] = hexDigit(@intCast(c & 0x0F));
|
||||
try writer.writeAll(&buf);
|
||||
} else {
|
||||
try writer.writeAll(&[_]u8{c});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
try writer.writeAll("\"");
|
||||
}
|
||||
|
||||
fn hexDigit(v: u8) u8 {
|
||||
return if (v < 10) ('0' + v) else ('a' + (v - 10));
|
||||
}
|
||||
|
|
@ -42,6 +42,17 @@ pub const QueueOptions = struct {
|
|||
memory: u8 = 8,
|
||||
gpu: u8 = 0,
|
||||
gpu_memory: ?[]const u8 = null,
|
||||
// Narrative fields for research context
|
||||
hypothesis: ?[]const u8 = null,
|
||||
context: ?[]const u8 = null,
|
||||
intent: ?[]const u8 = null,
|
||||
expected_outcome: ?[]const u8 = null,
|
||||
experiment_group: ?[]const u8 = null,
|
||||
tags: ?[]const u8 = null,
|
||||
// Sandboxing options
|
||||
network_mode: ?[]const u8 = null,
|
||||
read_only: bool = false,
|
||||
secrets: std.ArrayList([]const u8),
|
||||
};
|
||||
|
||||
fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
|
||||
|
|
@ -122,7 +133,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
.dry_run = config.default_dry_run,
|
||||
.validate = config.default_validate,
|
||||
.json = config.default_json,
|
||||
.secrets = std.ArrayList([]const u8).empty,
|
||||
};
|
||||
defer options.secrets.deinit(allocator);
|
||||
priority = config.default_priority;
|
||||
|
||||
// Tracking configuration
|
||||
|
|
@ -254,6 +267,32 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
} else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) {
|
||||
note_override = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--hypothesis") and i + 1 < pre.len) {
|
||||
options.hypothesis = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--context") and i + 1 < pre.len) {
|
||||
options.context = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--intent") and i + 1 < pre.len) {
|
||||
options.intent = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--expected-outcome") and i + 1 < pre.len) {
|
||||
options.expected_outcome = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--experiment-group") and i + 1 < pre.len) {
|
||||
options.experiment_group = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--tags") and i + 1 < pre.len) {
|
||||
options.tags = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--network") and i + 1 < pre.len) {
|
||||
options.network_mode = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--read-only")) {
|
||||
options.read_only = true;
|
||||
} else if (std.mem.eql(u8, arg, "--secret") and i + 1 < pre.len) {
|
||||
try options.secrets.append(allocator, pre[i + 1]);
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
// This is a job name
|
||||
|
|
@ -378,6 +417,10 @@ fn queueSingleJob(
|
|||
};
|
||||
defer if (commit_override == null) allocator.free(commit_id);
|
||||
|
||||
// Build narrative JSON if any narrative fields are set
|
||||
const narrative_json = buildNarrativeJson(allocator, options) catch null;
|
||||
defer if (narrative_json) |j| allocator.free(j);
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
|
|
@ -407,13 +450,43 @@ fn queueSingleJob(
|
|||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (tracking_json.len > 0) {
|
||||
// Build combined metadata JSON with tracking and/or narrative
|
||||
const combined_json = blk: {
|
||||
if (tracking_json.len > 0 and narrative_json != null) {
|
||||
// Merge tracking and narrative
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{");
|
||||
try writer.writeAll(tracking_json[1 .. tracking_json.len - 1]); // Remove outer braces
|
||||
try writer.writeAll(",");
|
||||
try writer.writeAll("\"narrative\":");
|
||||
try writer.writeAll(narrative_json.?);
|
||||
try writer.writeAll("}");
|
||||
break :blk try buf.toOwnedSlice(allocator);
|
||||
} else if (tracking_json.len > 0) {
|
||||
break :blk try allocator.dupe(u8, tracking_json);
|
||||
} else if (narrative_json) |nj| {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{\"narrative\":");
|
||||
try writer.writeAll(nj);
|
||||
try writer.writeAll("}");
|
||||
break :blk try buf.toOwnedSlice(allocator);
|
||||
} else {
|
||||
break :blk "";
|
||||
}
|
||||
};
|
||||
defer if (combined_json.len > 0 and combined_json.ptr != tracking_json.ptr) allocator.free(combined_json);
|
||||
|
||||
if (combined_json.len > 0) {
|
||||
try client.sendQueueJobWithTrackingAndResources(
|
||||
job_name,
|
||||
commit_id,
|
||||
priority,
|
||||
api_key_hash,
|
||||
tracking_json,
|
||||
combined_json,
|
||||
options.cpu,
|
||||
options.memory,
|
||||
options.gpu,
|
||||
|
|
@ -547,6 +620,13 @@ fn printUsage() !void {
|
|||
colors.printInfo(" --args <string> Extra runner args (sent to worker as task.Args)\n", .{});
|
||||
colors.printInfo(" --note <string> Human notes (stored in run manifest as metadata.note)\n", .{});
|
||||
colors.printInfo(" -- <args...> Extra runner args (alternative to --args)\n", .{});
|
||||
colors.printInfo("\nResearch Narrative:\n", .{});
|
||||
colors.printInfo(" --hypothesis <text> Research hypothesis being tested\n", .{});
|
||||
colors.printInfo(" --context <text> Background context for this experiment\n", .{});
|
||||
colors.printInfo(" --intent <text> What you're trying to accomplish\n", .{});
|
||||
colors.printInfo(" --expected-outcome <text> What you expect to happen\n", .{});
|
||||
colors.printInfo(" --experiment-group <name> Group related experiments\n", .{});
|
||||
colors.printInfo(" --tags <csv> Comma-separated tags (e.g., ablation,batch-size)\n", .{});
|
||||
colors.printInfo("\nSpecial Modes:\n", .{});
|
||||
colors.printInfo(" --dry-run Show what would be queued\n", .{});
|
||||
colors.printInfo(" --validate Validate experiment without queuing\n", .{});
|
||||
|
|
@ -561,11 +641,19 @@ fn printUsage() !void {
|
|||
colors.printInfo(" --wandb-project <prj> Set Wandb project\n", .{});
|
||||
colors.printInfo(" --wandb-entity <ent> Set Wandb entity\n", .{});
|
||||
|
||||
colors.printInfo("\nSandboxing:\n", .{});
|
||||
colors.printInfo(" --network <mode> Network mode: none, bridge, slirp4netns\n", .{});
|
||||
colors.printInfo(" --read-only Mount root filesystem as read-only\n", .{});
|
||||
colors.printInfo(" --secret <name> Inject secret as env var (can repeat)\n", .{});
|
||||
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml queue my_job # Queue a job\n", .{});
|
||||
colors.printInfo(" ml queue my_job --dry-run # Preview submission\n", .{});
|
||||
colors.printInfo(" ml queue my_job --validate # Validate locally\n", .{});
|
||||
colors.printInfo(" ml status --watch # Watch queue + prewarm\n", .{});
|
||||
colors.printInfo("\nResearch Examples:\n", .{});
|
||||
colors.printInfo(" ml queue train.py --hypothesis \"LR scaling improves convergence\" \\\n", .{});
|
||||
colors.printInfo(" --context \"Following paper XYZ\" --tags ablation,lr-scaling\n", .{});
|
||||
}
|
||||
|
||||
pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 {
|
||||
|
|
@ -597,13 +685,23 @@ fn explainJob(
|
|||
commit_display = enc;
|
||||
}
|
||||
|
||||
// Build narrative JSON for display
|
||||
const narrative_json = buildNarrativeJson(allocator, options) catch null;
|
||||
defer if (narrative_json) |j| allocator.free(j);
|
||||
|
||||
if (options.json) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
var buffer: [4096]u8 = undefined;
|
||||
const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"explain\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable;
|
||||
try stdout_file.writeAll(formatted);
|
||||
try writeJSONNullableString(&stdout_file, options.gpu_memory);
|
||||
try stdout_file.writeAll("}}\n");
|
||||
if (narrative_json) |nj| {
|
||||
try stdout_file.writeAll("},\"narrative\":");
|
||||
try stdout_file.writeAll(nj);
|
||||
try stdout_file.writeAll("}\n");
|
||||
} else {
|
||||
try stdout_file.writeAll("}}\n");
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
colors.printInfo("Job Explanation:\n", .{});
|
||||
|
|
@ -616,7 +714,30 @@ fn explainJob(
|
|||
colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu});
|
||||
colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
|
||||
|
||||
colors.printInfo(" Action: Job would be queued for execution\n", .{});
|
||||
// Display narrative if provided
|
||||
if (narrative_json != null) {
|
||||
colors.printInfo("\n Research Narrative:\n", .{});
|
||||
if (options.hypothesis) |h| {
|
||||
colors.printInfo(" Hypothesis: {s}\n", .{h});
|
||||
}
|
||||
if (options.context) |c| {
|
||||
colors.printInfo(" Context: {s}\n", .{c});
|
||||
}
|
||||
if (options.intent) |i| {
|
||||
colors.printInfo(" Intent: {s}\n", .{i});
|
||||
}
|
||||
if (options.expected_outcome) |eo| {
|
||||
colors.printInfo(" Expected Outcome: {s}\n", .{eo});
|
||||
}
|
||||
if (options.experiment_group) |eg| {
|
||||
colors.printInfo(" Experiment Group: {s}\n", .{eg});
|
||||
}
|
||||
if (options.tags) |t| {
|
||||
colors.printInfo(" Tags: {s}\n", .{t});
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\n Action: Job would be queued for execution\n", .{});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -689,13 +810,23 @@ fn dryRunJob(
|
|||
commit_display = enc;
|
||||
}
|
||||
|
||||
// Build narrative JSON for display
|
||||
const narrative_json = buildNarrativeJson(allocator, options) catch null;
|
||||
defer if (narrative_json) |j| allocator.free(j);
|
||||
|
||||
if (options.json) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
var buffer: [4096]u8 = undefined;
|
||||
const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"dry_run\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable;
|
||||
try stdout_file.writeAll(formatted);
|
||||
try writeJSONNullableString(&stdout_file, options.gpu_memory);
|
||||
try stdout_file.writeAll("}},\"would_queue\":true}}\n");
|
||||
if (narrative_json) |nj| {
|
||||
try stdout_file.writeAll("},\"narrative\":");
|
||||
try stdout_file.writeAll(nj);
|
||||
try stdout_file.writeAll(",\"would_queue\":true}}\n");
|
||||
} else {
|
||||
try stdout_file.writeAll("},\"would_queue\":true}}\n");
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
colors.printInfo("Dry Run - Job Queue Preview:\n", .{});
|
||||
|
|
@ -708,7 +839,30 @@ fn dryRunJob(
|
|||
colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu});
|
||||
colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
|
||||
|
||||
colors.printInfo(" Action: Would queue job\n", .{});
|
||||
// Display narrative if provided
|
||||
if (narrative_json != null) {
|
||||
colors.printInfo("\n Research Narrative:\n", .{});
|
||||
if (options.hypothesis) |h| {
|
||||
colors.printInfo(" Hypothesis: {s}\n", .{h});
|
||||
}
|
||||
if (options.context) |c| {
|
||||
colors.printInfo(" Context: {s}\n", .{c});
|
||||
}
|
||||
if (options.intent) |i| {
|
||||
colors.printInfo(" Intent: {s}\n", .{i});
|
||||
}
|
||||
if (options.expected_outcome) |eo| {
|
||||
colors.printInfo(" Expected Outcome: {s}\n", .{eo});
|
||||
}
|
||||
if (options.experiment_group) |eg| {
|
||||
colors.printInfo(" Experiment Group: {s}\n", .{eg});
|
||||
}
|
||||
if (options.tags) |t| {
|
||||
colors.printInfo(" Tags: {s}\n", .{t});
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\n Action: Would queue job\n", .{});
|
||||
colors.printInfo(" Estimated queue time: 2-5 minutes\n", .{});
|
||||
colors.printSuccess(" ✓ Dry run completed - no job was actually queued\n", .{});
|
||||
}
|
||||
|
|
@ -923,3 +1077,71 @@ fn handleDuplicateResponse(
|
|||
fn hexDigit(v: u8) u8 {
|
||||
return if (v < 10) ('0' + v) else ('a' + (v - 10));
|
||||
}
|
||||
|
||||
// buildNarrativeJson creates a JSON object from narrative fields
|
||||
fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions) !?[]u8 {
|
||||
// Check if any narrative field is set
|
||||
if (options.hypothesis == null and
|
||||
options.context == null and
|
||||
options.intent == null and
|
||||
options.expected_outcome == null and
|
||||
options.experiment_group == null and
|
||||
options.tags == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{");
|
||||
|
||||
var first = true;
|
||||
|
||||
if (options.hypothesis) |h| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"hypothesis\":");
|
||||
try writeJSONString(writer, h);
|
||||
}
|
||||
|
||||
if (options.context) |c| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"context\":");
|
||||
try writeJSONString(writer, c);
|
||||
}
|
||||
|
||||
if (options.intent) |i| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"intent\":");
|
||||
try writeJSONString(writer, i);
|
||||
}
|
||||
|
||||
if (options.expected_outcome) |eo| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"expected_outcome\":");
|
||||
try writeJSONString(writer, eo);
|
||||
}
|
||||
|
||||
if (options.experiment_group) |eg| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"experiment_group\":");
|
||||
try writeJSONString(writer, eg);
|
||||
}
|
||||
|
||||
if (options.tags) |t| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"tags\":");
|
||||
try writeJSONString(writer, t);
|
||||
}
|
||||
|
||||
try writer.writeAll("}");
|
||||
|
||||
return try buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,6 +47,13 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
|||
var note_override: ?[]const u8 = null;
|
||||
var force: bool = false;
|
||||
|
||||
// New: Change tracking options
|
||||
var inherit_narrative: bool = false;
|
||||
var inherit_config: bool = false;
|
||||
var parent_link: bool = false;
|
||||
var overrides = std.ArrayList([2][]const u8).empty; // [key, value] pairs
|
||||
defer overrides.deinit(allocator);
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < pre.len) : (i += 1) {
|
||||
const a = pre[i];
|
||||
|
|
@ -76,6 +83,18 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
|||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--force")) {
|
||||
force = true;
|
||||
} else if (std.mem.eql(u8, a, "--inherit-narrative")) {
|
||||
inherit_narrative = true;
|
||||
} else if (std.mem.eql(u8, a, "--inherit-config")) {
|
||||
inherit_config = true;
|
||||
} else if (std.mem.eql(u8, a, "--parent")) {
|
||||
parent_link = true;
|
||||
} else if (std.mem.startsWith(u8, a, "--") and std.mem.indexOf(u8, a, "=") != null) {
|
||||
// Key=value override: --lr=0.002 or --batch-size=128
|
||||
const eq_idx = std.mem.indexOf(u8, a, "=").?;
|
||||
const key = a[2..eq_idx];
|
||||
const value = a[eq_idx + 1 ..];
|
||||
try overrides.append(allocator, [2][]const u8{ key, value });
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
|
|
@ -100,6 +119,11 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
|||
const args_final: []const u8 = if (args_override) |a| a else args_joined;
|
||||
const note_final: []const u8 = if (note_override) |n| n else "";
|
||||
|
||||
// Read original manifest for inheritance
|
||||
var original_narrative: ?std.json.ObjectMap = null;
|
||||
var original_config: ?std.json.ObjectMap = null;
|
||||
var parent_run_id: ?[]const u8 = null;
|
||||
|
||||
// Target can be:
|
||||
// - commit_id (40-hex) or commit_id prefix (>=7 hex) resolvable under worker_base
|
||||
// - run_id/task_id/path (resolved to run_manifest.json to read commit_id)
|
||||
|
|
@ -111,6 +135,42 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
|||
var commit_bytes_allocated = false;
|
||||
defer if (commit_bytes_allocated) allocator.free(commit_bytes);
|
||||
|
||||
// If we need to inherit narrative or config, or link parent, read the manifest first
|
||||
if (inherit_narrative or inherit_config or parent_link) {
|
||||
const manifest_path = try manifest.resolvePathWithBase(allocator, target, cfg.worker_base);
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const data = try manifest.readFileAlloc(allocator, manifest_path);
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
if (parsed.value == .object) {
|
||||
const root = parsed.value.object;
|
||||
|
||||
if (inherit_narrative) {
|
||||
if (root.get("narrative")) |narr| {
|
||||
if (narr == .object) {
|
||||
original_narrative = try narr.object.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (inherit_config) {
|
||||
if (root.get("metadata")) |meta| {
|
||||
if (meta == .object) {
|
||||
original_config = try meta.object.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (parent_link) {
|
||||
parent_run_id = json.getString(root, "run_id") orelse json.getString(root, "id");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (target.len >= 7 and target.len <= 40 and isHexLowerOrUpper(target)) {
|
||||
if (target.len == 40) {
|
||||
commit_hex = target;
|
||||
|
|
@ -172,6 +232,52 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
|||
commit_bytes_allocated = true;
|
||||
}
|
||||
|
||||
// Build tracking JSON with narrative/config inheritance and overrides
|
||||
const tracking_json = blk: {
|
||||
if (inherit_narrative or inherit_config or parent_link or overrides.items.len > 0) {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
const writer = buf.writer(allocator);
|
||||
|
||||
try writer.writeAll("{");
|
||||
var first = true;
|
||||
|
||||
// Add narrative if inherited
|
||||
if (original_narrative) |narr| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"narrative\":");
|
||||
try writeJSONValue(writer, std.json.Value{ .object = narr });
|
||||
}
|
||||
|
||||
// Add parent relationship
|
||||
if (parent_run_id) |pid| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"parent_run\":\"");
|
||||
try writer.writeAll(pid);
|
||||
try writer.writeAll("\"");
|
||||
}
|
||||
|
||||
// Add overrides
|
||||
if (overrides.items.len > 0) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"overrides\":{");
|
||||
for (overrides.items, 0..) |pair, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writer.print("\"{s}\":\"{s}\"", .{ pair[0], pair[1] });
|
||||
}
|
||||
try writer.writeAll("}");
|
||||
}
|
||||
|
||||
try writer.writeAll("}");
|
||||
break :blk try buf.toOwnedSlice(allocator);
|
||||
}
|
||||
break :blk "";
|
||||
};
|
||||
defer if (tracking_json.len > 0) allocator.free(tracking_json);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
|
|
@ -181,7 +287,20 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
|||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
if (note_final.len > 0) {
|
||||
// Send with tracking JSON if we have inheritance/overrides
|
||||
if (tracking_json.len > 0) {
|
||||
try client.sendQueueJobWithTrackingAndResources(
|
||||
job_name,
|
||||
commit_bytes,
|
||||
priority,
|
||||
api_key_hash,
|
||||
tracking_json,
|
||||
cpu,
|
||||
memory,
|
||||
gpu,
|
||||
gpu_memory,
|
||||
);
|
||||
} else if (note_final.len > 0) {
|
||||
try client.sendQueueJobWithArgsNoteAndResources(
|
||||
job_name,
|
||||
commit_bytes,
|
||||
|
|
@ -287,7 +406,83 @@ fn handleDuplicateResponse(
|
|||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage:\n", .{});
|
||||
colors.printInfo(" ml requeue <commit_id|run_id|task_id|path> [--name <job>] [--priority <n>] [--cpu <n>] [--memory <gb>] [--gpu <n>] [--gpu-memory <gb>] [--args <string>] [--note <string>] [--force] -- <args...>\n", .{});
|
||||
colors.printInfo(" ml requeue <commit_id|run_id|task_id|path> [options] -- <args...>\n\n", .{});
|
||||
colors.printInfo("Resource Options:\n", .{});
|
||||
colors.printInfo(" --name <job> Override job name\n", .{});
|
||||
colors.printInfo(" --priority <n> Set priority (0-255)\n", .{});
|
||||
colors.printInfo(" --cpu <n> CPU cores\n", .{});
|
||||
colors.printInfo(" --memory <gb> Memory in GB\n", .{});
|
||||
colors.printInfo(" --gpu <n> GPU count\n", .{});
|
||||
colors.printInfo(" --gpu-memory <gb> GPU memory budget\n", .{});
|
||||
colors.printInfo("\nInheritance Options:\n", .{});
|
||||
colors.printInfo(" --inherit-narrative Copy hypothesis/context/intent from parent\n", .{});
|
||||
colors.printInfo(" --inherit-config Copy metadata/config from parent\n", .{});
|
||||
colors.printInfo(" --parent Link as child run\n", .{});
|
||||
colors.printInfo("\nOverride Options:\n", .{});
|
||||
colors.printInfo(" --key=value Override specific config (e.g., --lr=0.002)\n", .{});
|
||||
colors.printInfo(" --args <string> Override runner args\n", .{});
|
||||
colors.printInfo(" --note <string> Add human note\n", .{});
|
||||
colors.printInfo("\nOther:\n", .{});
|
||||
colors.printInfo(" --force Requeue even if duplicate exists\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml requeue run_abc --lr=0.002 --batch-size=128\n", .{});
|
||||
colors.printInfo(" ml requeue run_abc --inherit-narrative --parent\n", .{});
|
||||
colors.printInfo(" ml requeue run_abc --lr=0.002 --inherit-narrative\n", .{});
|
||||
}
|
||||
|
||||
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
|
||||
switch (v) {
|
||||
.null => try writer.writeAll("null"),
|
||||
.bool => |b| try writer.print("{}", .{b}),
|
||||
.integer => |i| try writer.print("{d}", .{i}),
|
||||
.float => |f| try writer.print("{d}", .{f}),
|
||||
.string => |s| {
|
||||
try writer.writeAll("\"");
|
||||
try writeEscapedString(writer, s);
|
||||
try writer.writeAll("\"");
|
||||
},
|
||||
.array => |arr| {
|
||||
try writer.writeAll("[");
|
||||
for (arr.items, 0..) |item, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writeJSONValue(writer, item);
|
||||
}
|
||||
try writer.writeAll("]");
|
||||
},
|
||||
.object => |obj| {
|
||||
try writer.writeAll("{");
|
||||
var first = true;
|
||||
var it = obj.iterator();
|
||||
while (it.next()) |entry| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.print("\"{s}\":", .{entry.key_ptr.*});
|
||||
try writeJSONValue(writer, entry.value_ptr.*);
|
||||
}
|
||||
try writer.writeAll("}");
|
||||
},
|
||||
.number_string => |s| try writer.print("{s}", .{s}),
|
||||
}
|
||||
}
|
||||
|
||||
fn writeEscapedString(writer: anytype, s: []const u8) !void {
|
||||
for (s) |c| {
|
||||
switch (c) {
|
||||
'"' => try writer.writeAll("\\\""),
|
||||
'\\' => try writer.writeAll("\\\\"),
|
||||
'\n' => try writer.writeAll("\\n"),
|
||||
'\r' => try writer.writeAll("\\r"),
|
||||
'\t' => try writer.writeAll("\\t"),
|
||||
else => {
|
||||
if (c < 0x20) {
|
||||
try writer.print("\\u00{x:0>2}", .{c});
|
||||
} else {
|
||||
try writer.writeAll(&[_]u8{c});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn isHexLowerOrUpper(s: []const u8) bool {
|
||||
|
|
|
|||
|
|
@ -44,6 +44,11 @@ pub fn main() !void {
|
|||
},
|
||||
'n' => if (std.mem.eql(u8, command, "narrative")) {
|
||||
try @import("commands/narrative.zig").run(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "outcome")) {
|
||||
try @import("commands/outcome.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'p' => if (std.mem.eql(u8, command, "privacy")) {
|
||||
try @import("commands/privacy.zig").run(allocator, args[2..]);
|
||||
},
|
||||
's' => if (std.mem.eql(u8, command, "sync")) {
|
||||
if (args.len < 3) {
|
||||
|
|
@ -65,9 +70,16 @@ pub fn main() !void {
|
|||
},
|
||||
'e' => if (std.mem.eql(u8, command, "experiment")) {
|
||||
try @import("commands/experiment.zig").execute(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "export")) {
|
||||
try @import("commands/export_cmd.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'c' => if (std.mem.eql(u8, command, "cancel")) {
|
||||
try @import("commands/cancel.zig").run(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "compare")) {
|
||||
try @import("commands/compare.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'f' => if (std.mem.eql(u8, command, "find")) {
|
||||
try @import("commands/find.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'v' => if (std.mem.eql(u8, command, "validate")) {
|
||||
try @import("commands/validate.zig").run(allocator, args[2..]);
|
||||
|
|
@ -91,7 +103,12 @@ fn printUsage() void {
|
|||
std.debug.print(" jupyter Jupyter workspace management\n", .{});
|
||||
std.debug.print(" init Setup configuration interactively\n", .{});
|
||||
std.debug.print(" annotate <path|id> Add an annotation to run_manifest.json (--note \"...\")\n", .{});
|
||||
std.debug.print(" compare <a> <b> Compare two runs (show differences)\n", .{});
|
||||
std.debug.print(" export <id> Export experiment bundle (--anonymize for safe sharing)\n", .{});
|
||||
std.debug.print(" find [query] Search experiments by tags/outcome/dataset\n", .{});
|
||||
std.debug.print(" narrative set <path|id> Set run narrative fields (hypothesis/context/...)\n", .{});
|
||||
std.debug.print(" outcome set <path|id> Set post-run outcome (validates/refutes/inconclusive)\n", .{});
|
||||
std.debug.print(" privacy set <path|id> Set experiment privacy level (private/team/public)\n", .{});
|
||||
std.debug.print(" info <path|id> Show run info from run_manifest.json (optionally --base <path>)\n", .{});
|
||||
std.debug.print(" sync <path> Sync project to server\n", .{});
|
||||
std.debug.print(" requeue <id> Re-submit from run_id/task_id/path (supports -- <args>)\n", .{});
|
||||
|
|
@ -113,5 +130,10 @@ test {
|
|||
_ = @import("commands/requeue.zig");
|
||||
_ = @import("commands/annotate.zig");
|
||||
_ = @import("commands/narrative.zig");
|
||||
_ = @import("commands/outcome.zig");
|
||||
_ = @import("commands/privacy.zig");
|
||||
_ = @import("commands/compare.zig");
|
||||
_ = @import("commands/find.zig");
|
||||
_ = @import("commands/export_cmd.zig");
|
||||
_ = @import("commands/logs.zig");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
package jobs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/http"
|
||||
"os"
|
||||
|
|
@ -20,11 +21,14 @@ import (
|
|||
|
||||
// Handler provides job-related WebSocket handlers
|
||||
type Handler struct {
|
||||
expManager *experiment.Manager
|
||||
logger *logging.Logger
|
||||
queue queue.Backend
|
||||
db *storage.DB
|
||||
authConfig *auth.Config
|
||||
expManager *experiment.Manager
|
||||
logger *logging.Logger
|
||||
queue queue.Backend
|
||||
db *storage.DB
|
||||
authConfig *auth.Config
|
||||
privacyEnforcer interface { // NEW: Privacy enforcement interface
|
||||
CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error)
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new jobs handler
|
||||
|
|
@ -34,13 +38,17 @@ func NewHandler(
|
|||
queue queue.Backend,
|
||||
db *storage.DB,
|
||||
authConfig *auth.Config,
|
||||
privacyEnforcer interface { // NEW - can be nil
|
||||
CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error)
|
||||
},
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
expManager: expManager,
|
||||
logger: logger,
|
||||
queue: queue,
|
||||
db: db,
|
||||
authConfig: authConfig,
|
||||
expManager: expManager,
|
||||
logger: logger,
|
||||
queue: queue,
|
||||
db: db,
|
||||
authConfig: authConfig,
|
||||
privacyEnforcer: privacyEnforcer,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -212,10 +220,79 @@ func (h *Handler) HandleSetRunNarrative(conn *websocket.Conn, payload []byte, us
|
|||
|
||||
h.logger.Info("setting run narrative", "job", jobName, "bucket", bucket)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"job_name": jobName,
|
||||
"narrative": patch,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleSetRunPrivacy handles setting run privacy
|
||||
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_len:2][patch:var]
|
||||
func (h *Handler) HandleSetRunPrivacy(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run privacy payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
jobNameLen := int(payload[offset])
|
||||
offset += 1
|
||||
if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
jobName := string(payload[offset : offset+jobNameLen])
|
||||
offset += jobNameLen
|
||||
|
||||
patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if patchLen <= 0 || len(payload) < offset+patchLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid patch length", "")
|
||||
}
|
||||
patch := string(payload[offset : offset+patchLen])
|
||||
|
||||
if err := container.ValidateJobName(jobName); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
|
||||
}
|
||||
|
||||
base := strings.TrimSpace(h.expManager.BasePath())
|
||||
if base == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
|
||||
}
|
||||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
typedRoots := []struct {
|
||||
bucket string
|
||||
root string
|
||||
}{
|
||||
{bucket: "running", root: jobPaths.RunningPath()},
|
||||
{bucket: "pending", root: jobPaths.PendingPath()},
|
||||
{bucket: "finished", root: jobPaths.FinishedPath()},
|
||||
{bucket: "failed", root: jobPaths.FailedPath()},
|
||||
}
|
||||
|
||||
var manifestDir, bucket string
|
||||
for _, item := range typedRoots {
|
||||
dir := filepath.Join(item.root, jobName)
|
||||
if _, err := os.Stat(dir); err == nil {
|
||||
manifestDir = dir
|
||||
bucket = item.bucket
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if manifestDir == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
|
||||
}
|
||||
|
||||
// TODO: Check if user is owner before allowing privacy changes
|
||||
|
||||
h.logger.Info("setting run privacy", "job", jobName, "bucket", bucket, "user", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"job_name": jobName,
|
||||
"narrative": patch,
|
||||
"privacy": patch,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -67,6 +67,12 @@ const (
|
|||
// Logs opcodes
|
||||
OpcodeGetLogs = 0x20
|
||||
OpcodeStreamLogs = 0x21
|
||||
|
||||
//
|
||||
OpcodeCompareRuns = 0x30
|
||||
OpcodeFindRuns = 0x31
|
||||
OpcodeExportRun = 0x32
|
||||
OpcodeSetRunOutcome = 0x33
|
||||
)
|
||||
|
||||
// Error codes
|
||||
|
|
@ -277,6 +283,14 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
|||
return h.handleDatasetInfo(conn, payload)
|
||||
case OpcodeDatasetSearch:
|
||||
return h.handleDatasetSearch(conn, payload)
|
||||
case OpcodeCompareRuns:
|
||||
return h.handleCompareRuns(conn, payload)
|
||||
case OpcodeFindRuns:
|
||||
return h.handleFindRuns(conn, payload)
|
||||
case OpcodeExportRun:
|
||||
return h.handleExportRun(conn, payload)
|
||||
case OpcodeSetRunOutcome:
|
||||
return h.handleSetRunOutcome(conn, payload)
|
||||
default:
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
||||
}
|
||||
|
|
@ -538,3 +552,216 @@ func (h *Handler) RequirePermission(user *auth.User, permission string) bool {
|
|||
}
|
||||
return user.Admin || user.Permissions[permission]
|
||||
}
|
||||
|
||||
// handleCompareRuns compares two runs and returns differences
|
||||
func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [api_key_hash:16][run_a_len:1][run_a:var][run_b_len:1][run_b:var]
|
||||
if len(payload) < 16+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "compare runs payload too short", "")
|
||||
}
|
||||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
runALen := int(payload[offset])
|
||||
offset++
|
||||
if runALen <= 0 || len(payload) < offset+runALen+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run A length", "")
|
||||
}
|
||||
runA := string(payload[offset : offset+runALen])
|
||||
offset += runALen
|
||||
|
||||
runBLen := int(payload[offset])
|
||||
offset++
|
||||
if runBLen <= 0 || len(payload) < offset+runBLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run B length", "")
|
||||
}
|
||||
runB := string(payload[offset : offset+runBLen])
|
||||
|
||||
// Fetch both experiments
|
||||
metaA, errA := h.expManager.ReadMetadata(runA)
|
||||
metaB, errB := h.expManager.ReadMetadata(runB)
|
||||
|
||||
// Build comparison result
|
||||
result := map[string]any{
|
||||
"run_a": runA,
|
||||
"run_b": runB,
|
||||
"success": true,
|
||||
}
|
||||
|
||||
// Add metadata if available
|
||||
if errA == nil && errB == nil {
|
||||
result["job_name_match"] = metaA.JobName == metaB.JobName
|
||||
result["user_match"] = metaA.User == metaB.User
|
||||
result["timestamp_diff"] = metaB.Timestamp - metaA.Timestamp
|
||||
}
|
||||
|
||||
// Read manifests for comparison
|
||||
manifestA, _ := h.expManager.ReadManifest(runA)
|
||||
manifestB, _ := h.expManager.ReadManifest(runB)
|
||||
|
||||
if manifestA != nil && manifestB != nil {
|
||||
result["overall_sha_match"] = manifestA.OverallSHA == manifestB.OverallSHA
|
||||
result["files_count_a"] = len(manifestA.Files)
|
||||
result["files_count_b"] = len(manifestB.Files)
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, result)
|
||||
}
|
||||
|
||||
// handleFindRuns searches for runs based on criteria
|
||||
func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [api_key_hash:16][query_len:2][query:var]
|
||||
if len(payload) < 16+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "find runs payload too short", "")
|
||||
}
|
||||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
queryLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
||||
offset += 2
|
||||
if queryLen > 0 && len(payload) >= offset+int(queryLen) {
|
||||
// Parse query JSON
|
||||
queryData := payload[offset : offset+int(queryLen)]
|
||||
var query map[string]any
|
||||
if err := json.Unmarshal(queryData, &query); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query JSON", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("search query", "query", query, "user", user.Name)
|
||||
}
|
||||
|
||||
// For now, return placeholder results
|
||||
results := []map[string]any{
|
||||
{"id": "run_abc", "job_name": "train", "outcome": "validates"},
|
||||
{"id": "run_def", "job_name": "eval", "outcome": "partial"},
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"results": results,
|
||||
"count": len(results),
|
||||
})
|
||||
}
|
||||
|
||||
// handleExportRun exports a run with optional anonymization
|
||||
func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var][options_len:2][options:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "export run payload too short", "")
|
||||
}
|
||||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
runIDLen := int(payload[offset])
|
||||
offset++
|
||||
if runIDLen <= 0 || len(payload) < offset+runIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
||||
}
|
||||
runID := string(payload[offset : offset+runIDLen])
|
||||
offset += runIDLen
|
||||
|
||||
// Parse options if present
|
||||
var options map[string]any
|
||||
if len(payload) >= offset+2 {
|
||||
optsLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
||||
offset += 2
|
||||
if optsLen > 0 && len(payload) >= offset+int(optsLen) {
|
||||
json.Unmarshal(payload[offset:offset+int(optsLen)], &options)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if experiment exists
|
||||
if !h.expManager.ExperimentExists(runID) {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID)
|
||||
}
|
||||
|
||||
anonymize := false
|
||||
if options != nil {
|
||||
if v, ok := options["anonymize"].(bool); ok {
|
||||
anonymize = v
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("exporting run", "run_id", runID, "anonymize", anonymize, "user", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"run_id": runID,
|
||||
"message": "Export request received",
|
||||
"anonymize": anonymize,
|
||||
})
|
||||
}
|
||||
|
||||
// handleSetRunOutcome sets the outcome for a run
|
||||
func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var][outcome_data_len:2][outcome_data:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run outcome payload too short", "")
|
||||
}
|
||||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsUpdate) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
runIDLen := int(payload[offset])
|
||||
offset++
|
||||
if runIDLen <= 0 || len(payload) < offset+runIDLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
||||
}
|
||||
runID := string(payload[offset : offset+runIDLen])
|
||||
offset += runIDLen
|
||||
|
||||
// Parse outcome data
|
||||
outcomeLen := binary.BigEndian.Uint16(payload[offset : offset+2])
|
||||
offset += 2
|
||||
if outcomeLen == 0 || len(payload) < offset+int(outcomeLen) {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome data", "")
|
||||
}
|
||||
|
||||
var outcomeData map[string]any
|
||||
if err := json.Unmarshal(payload[offset:offset+int(outcomeLen)], &outcomeData); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome JSON", err.Error())
|
||||
}
|
||||
|
||||
// Validate outcome status
|
||||
validOutcomes := map[string]bool{"validates": true, "refutes": true, "inconclusive": true, "partial": true}
|
||||
outcome, ok := outcomeData["outcome"].(string)
|
||||
if !ok || !validOutcomes[outcome] {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome status", "must be: validates, refutes, inconclusive, or partial")
|
||||
}
|
||||
|
||||
h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]any{
|
||||
"success": true,
|
||||
"run_id": runID,
|
||||
"outcome": outcome,
|
||||
"message": "Outcome updated",
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,6 +59,15 @@ type NarrativePatch struct {
|
|||
Tags *[]string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// Outcome represents the documented result of a run.
|
||||
type Outcome struct {
|
||||
Status string `json:"status,omitempty"` // validated, invalidated, inconclusive, partial
|
||||
Summary string `json:"summary,omitempty"` // Brief description
|
||||
KeyLearnings []string `json:"key_learnings,omitempty"` // 3-5 bullet points max
|
||||
FollowUpRuns []string `json:"follow_up_runs,omitempty"` // References to related runs
|
||||
ArtifactsUsed []string `json:"artifacts_used,omitempty"` // e.g., ["model.pt", "metrics.json"]
|
||||
}
|
||||
|
||||
type ArtifactFile struct {
|
||||
Path string `json:"path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
|
|
@ -83,6 +92,7 @@ type RunManifest struct {
|
|||
|
||||
Annotations []Annotation `json:"annotations,omitempty"`
|
||||
Narrative *Narrative `json:"narrative,omitempty"`
|
||||
Outcome *Outcome `json:"outcome,omitempty"`
|
||||
Artifacts *Artifacts `json:"artifacts,omitempty"`
|
||||
|
||||
CommitID string `json:"commit_id,omitempty"`
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ package manifest
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/privacy"
|
||||
)
|
||||
|
||||
// ErrIncompleteManifest is returned when a required manifest field is missing.
|
||||
|
|
@ -157,3 +160,142 @@ func (v *Validator) validateField(m *RunManifest, field string) *ValidationError
|
|||
func IsValidationError(err error) bool {
|
||||
return errors.Is(err, ErrIncompleteManifest)
|
||||
}
|
||||
|
||||
// NarrativeValidation contains validation results.
|
||||
type NarrativeValidation struct {
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
PIIFindings []privacy.PIIFinding `json:"pii_findings,omitempty"`
|
||||
}
|
||||
|
||||
// OutcomeValidation contains validation results.
|
||||
type OutcomeValidation struct {
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// Valid outcome statuses.
|
||||
var ValidOutcomeStatuses = []string{
|
||||
"validated", "invalidated", "inconclusive", "partial", "",
|
||||
}
|
||||
|
||||
// isValidOutcomeStatus checks if status is valid.
|
||||
func isValidOutcomeStatus(status string) bool {
|
||||
for _, s := range ValidOutcomeStatuses {
|
||||
if s == status {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateNarrative validates a Narrative struct.
|
||||
func ValidateNarrative(n *Narrative) NarrativeValidation {
|
||||
result := NarrativeValidation{
|
||||
Warnings: make([]string, 0),
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
if n == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate hypothesis length
|
||||
if len(n.Hypothesis) > 5000 {
|
||||
result.Errors = append(result.Errors, "hypothesis exceeds 5000 characters")
|
||||
} else if len(n.Hypothesis) > 1000 {
|
||||
result.Warnings = append(result.Warnings, "hypothesis is very long (>1000 chars)")
|
||||
}
|
||||
|
||||
// Validate context length
|
||||
if len(n.Context) > 10000 {
|
||||
result.Errors = append(result.Errors, "context exceeds 10000 characters")
|
||||
}
|
||||
|
||||
// Validate tags count
|
||||
if len(n.Tags) > 50 {
|
||||
result.Errors = append(result.Errors, "too many tags (max 50)")
|
||||
} else if len(n.Tags) > 20 {
|
||||
result.Warnings = append(result.Warnings, "many tags (>20)")
|
||||
}
|
||||
|
||||
// Validate tag lengths
|
||||
for i, tag := range n.Tags {
|
||||
if len(tag) > 50 {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("tag %d exceeds 50 characters", i))
|
||||
}
|
||||
if strings.ContainsAny(tag, ",;|/\\") {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("tag %d contains special characters", i))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for PII in text fields
|
||||
fields := map[string]string{
|
||||
"hypothesis": n.Hypothesis,
|
||||
"context": n.Context,
|
||||
"intent": n.Intent,
|
||||
}
|
||||
|
||||
for fieldName, text := range fields {
|
||||
if findings := privacy.DetectPII(text); len(findings) > 0 {
|
||||
result.PIIFindings = append(result.PIIFindings, findings...)
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential PII detected in %s field", fieldName))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateOutcome validates an Outcome struct.
|
||||
func ValidateOutcome(o *Outcome) OutcomeValidation {
|
||||
result := OutcomeValidation{
|
||||
Warnings: make([]string, 0),
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
if o == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate status
|
||||
if !isValidOutcomeStatus(o.Status) {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("invalid status: %s (must be validated, invalidated, inconclusive, partial, or empty)", o.Status))
|
||||
}
|
||||
|
||||
// Validate summary length
|
||||
if len(o.Summary) > 1000 {
|
||||
result.Errors = append(result.Errors, "summary exceeds 1000 characters")
|
||||
} else if len(o.Summary) > 200 {
|
||||
result.Warnings = append(result.Warnings, "summary is long (>200 chars)")
|
||||
}
|
||||
|
||||
// Validate key learnings count
|
||||
if len(o.KeyLearnings) > 5 {
|
||||
result.Errors = append(result.Errors, "too many key learnings (max 5)")
|
||||
}
|
||||
|
||||
// Validate key learning lengths
|
||||
for i, learning := range o.KeyLearnings {
|
||||
if len(learning) > 500 {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("key learning %d exceeds 500 characters", i))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate follow-up runs references
|
||||
if len(o.FollowUpRuns) > 10 {
|
||||
result.Warnings = append(result.Warnings, "many follow-up runs (>10)")
|
||||
}
|
||||
|
||||
// Check for PII in text fields
|
||||
if findings := privacy.DetectPII(o.Summary); len(findings) > 0 {
|
||||
result.Warnings = append(result.Warnings, "potential PII detected in summary")
|
||||
}
|
||||
|
||||
for i, learning := range o.KeyLearnings {
|
||||
if findings := privacy.DetectPII(learning); len(findings) > 0 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential PII detected in key learning %d", i))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
|
|
|||
175
tests/e2e/phase2_features_test.go
Normal file
175
tests/e2e/phase2_features_test.go
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// runCLI runs the CLI with given arguments and returns output
|
||||
func runCLI(t *testing.T, cliPath string, args ...string) (string, error) {
|
||||
t.Helper()
|
||||
cmd := exec.Command(cliPath, args...)
|
||||
cmd.Dir = t.TempDir()
|
||||
output, err := cmd.CombinedOutput()
|
||||
return string(output), err
|
||||
}
|
||||
|
||||
// contains checks if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestCompareRunsE2E tests the ml compare command end-to-end
|
||||
func TestCompareRunsE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cliPath := e2eCLIPath(t)
|
||||
if _, err := os.Stat(cliPath); os.IsNotExist(err) {
|
||||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
t.Run("CompareUsage", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "compare", "--help")
|
||||
if !contains(output, "Usage") {
|
||||
t.Error("expected compare --help to show usage")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompareDummyRuns", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "compare", "run_abc", "run_def", "--json")
|
||||
t.Logf("Compare output: %s", output)
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(output), &result); err == nil {
|
||||
if _, hasA := result["run_a"]; hasA {
|
||||
t.Log("Compare returned structured response")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFindRunsE2E tests the ml find command end-to-end
|
||||
func TestFindRunsE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cliPath := e2eCLIPath(t)
|
||||
if _, err := os.Stat(cliPath); os.IsNotExist(err) {
|
||||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
t.Run("FindUsage", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "find", "--help")
|
||||
if !contains(output, "Usage") {
|
||||
t.Error("expected find --help to show usage")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FindByOutcome", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "find", "--outcome", "validates", "--json")
|
||||
t.Logf("Find output: %s", output)
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(output), &result); err == nil {
|
||||
t.Log("Find returned JSON response")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestExportRunE2E tests the ml export command end-to-end
|
||||
func TestExportRunE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cliPath := e2eCLIPath(t)
|
||||
if _, err := os.Stat(cliPath); os.IsNotExist(err) {
|
||||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
t.Run("ExportUsage", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "export", "--help")
|
||||
if !contains(output, "Usage") {
|
||||
t.Error("expected export --help to show usage")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRequeueWithChangesE2E tests the ml requeue command with changes
|
||||
func TestRequeueWithChangesE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cliPath := e2eCLIPath(t)
|
||||
if _, err := os.Stat(cliPath); os.IsNotExist(err) {
|
||||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
t.Run("RequeueUsage", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "requeue", "--help")
|
||||
if !contains(output, "Usage") {
|
||||
t.Error("expected requeue --help to show usage")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RequeueWithOverrides", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "requeue", "abc123", "--lr=0.002", "--json")
|
||||
t.Logf("Requeue output: %s", output)
|
||||
})
|
||||
}
|
||||
|
||||
// TestOutcomeSetE2E tests the ml outcome set command
|
||||
func TestOutcomeSetE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cliPath := e2eCLIPath(t)
|
||||
if _, err := os.Stat(cliPath); os.IsNotExist(err) {
|
||||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
t.Run("OutcomeSetUsage", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "outcome", "set", "--help")
|
||||
if !contains(output, "Usage") {
|
||||
t.Error("expected outcome set --help to show usage")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDatasetVerifyE2E tests the ml dataset verify command
|
||||
func TestDatasetVerifyE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cliPath := e2eCLIPath(t)
|
||||
if _, err := os.Stat(cliPath); os.IsNotExist(err) {
|
||||
t.Skip("CLI not built - run 'make build' first")
|
||||
}
|
||||
|
||||
t.Run("DatasetVerifyUsage", func(t *testing.T) {
|
||||
output, _ := runCLI(t, cliPath, "dataset", "verify", "--help")
|
||||
if !contains(output, "Usage") {
|
||||
t.Error("expected dataset verify --help to show usage")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DatasetVerifyTempDir", func(t *testing.T) {
|
||||
datasetDir := t.TempDir()
|
||||
for i := 0; i < 5; i++ {
|
||||
f := filepath.Join(datasetDir, "file.txt")
|
||||
os.WriteFile(f, []byte("test data"), 0644)
|
||||
}
|
||||
|
||||
output, _ := runCLI(t, cliPath, "dataset", "verify", datasetDir, "--json")
|
||||
t.Logf("Dataset verify output: %s", output)
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(output), &result); err == nil {
|
||||
if result["ok"] == true {
|
||||
t.Log("Dataset verify returned ok")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
Reference in a new issue