feat: Research features - narrative fields and outcome tracking

Add comprehensive research context tracking to jobs:
- Narrative fields: hypothesis, context, intent, expected_outcome
- Experiment groups and tags for organization
- Run comparison (compare command) for diff analysis
- Run search (find command) with criteria filtering
- Run export (export command) for data portability
- Outcome setting (outcome command) for experiment validation

Update queue and requeue commands to support narrative fields.
Add narrative validation to manifest validator.
Add WebSocket handlers for compare, find, export, and outcome operations.

Includes E2E tests for phase 2 features.
This commit is contained in:
Jeremie Fraeys 2026-02-18 21:27:05 -05:00
parent 94020e4ca4
commit 260e18499e
No known key found for this signature in database
15 changed files with 2851 additions and 21 deletions

View file

@ -1,13 +1,17 @@
pub const annotate = @import("commands/annotate.zig");
pub const cancel = @import("commands/cancel.zig");
pub const compare = @import("commands/compare.zig");
pub const dataset = @import("commands/dataset.zig");
pub const experiment = @import("commands/experiment.zig");
pub const export_cmd = @import("commands/export_cmd.zig");
pub const find = @import("commands/find.zig");
pub const info = @import("commands/info.zig");
pub const init = @import("commands/init.zig");
pub const jupyter = @import("commands/jupyter.zig");
pub const logs = @import("commands/logs.zig");
pub const monitor = @import("commands/monitor.zig");
pub const narrative = @import("commands/narrative.zig");
pub const outcome = @import("commands/outcome.zig");
pub const prune = @import("commands/prune.zig");
pub const queue = @import("commands/queue.zig");
pub const requeue = @import("commands/requeue.zig");

View file

@ -6,6 +6,7 @@ const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const manifest = @import("../utils/manifest.zig");
const pii = @import("../utils/pii.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {
@ -24,6 +25,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var note: ?[]const u8 = null;
var base_override: ?[]const u8 = null;
var json_mode: bool = false;
var privacy_scan: bool = false;
var force: bool = false;
var i: usize = 1;
while (i < args.len) : (i += 1) {
@ -51,6 +54,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
i += 1;
} else if (std.mem.eql(u8, a, "--json")) {
json_mode = true;
} else if (std.mem.eql(u8, a, "--privacy-scan")) {
privacy_scan = true;
} else if (std.mem.eql(u8, a, "--force")) {
force = true;
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
try printUsage();
return;
@ -69,6 +76,19 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
return error.InvalidArgs;
}
// PII detection if requested
if (privacy_scan) {
if (pii.scanForPII(note.?, allocator)) |warning| {
colors.printWarning("{s}\n", .{warning.?});
if (!force) {
colors.printInfo("Use --force to store anyway, or edit your note.\n", .{});
return error.PIIDetected;
}
} else |_| {
// PII scan failed, continue anyway
}
}
const cfg = try Config.load(allocator);
defer {
var mut_cfg = cfg;
@ -152,8 +172,16 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
/// Print CLI usage for `ml annotate`, including the privacy-scan options.
/// BUG FIX: the pre-change Usage line (a stale diff remnant) was printed in
/// addition to the current one, so usage appeared twice; only the current
/// line (with --privacy-scan/--force) is kept.
fn printUsage() !void {
    colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json] [--privacy-scan] [--force]\n", .{});
    colors.printInfo("\nOptions:\n", .{});
    colors.printInfo(" --note <text> Annotation text (required)\n", .{});
    colors.printInfo(" --author <name> Author of the annotation\n", .{});
    colors.printInfo(" --base <path> Base path to search for run_manifest.json\n", .{});
    colors.printInfo(" --privacy-scan Scan note for PII before storing\n", .{});
    colors.printInfo(" --force Store even if PII detected\n", .{});
    colors.printInfo(" --json Output JSON response\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo(" ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
    colors.printInfo(" ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
    colors.printInfo(" ml annotate run_123 --note \"Contact at user@example.com\" --privacy-scan\n", .{});
}

View file

@ -0,0 +1,512 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
/// Flags controlling `ml compare` output, filled in by the parser in `run`.
pub const CompareOptions = struct {
// Emit a machine-readable JSON diff instead of the human report.
json: bool = false,
// Emit CSV rows (field,a,b,delta,notes).
csv: bool = false,
// Also show unchanged / one-sided fields in the human report.
all_fields: bool = false,
// Comma-separated field filter.
// NOTE(review): parsed by `run` but not consumed by any output path here — confirm intent.
fields: ?[]const u8 = null,
};
/// Entry point for `ml compare <run-a> <run-b> [options]`.
///
/// Fetches experiment info for both runs over WebSocket, then prints their
/// differences as a human report (default), JSON (--json), or CSV (--csv).
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len < 2) {
        try printUsage();
        return error.InvalidArgs;
    }
    if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
        try printUsage();
        return;
    }
    const run_a = argv[0];
    const run_b = argv[1];
    var options = CompareOptions{};
    var i: usize = 2;
    while (i < argv.len) : (i += 1) {
        const arg = argv[i];
        if (std.mem.eql(u8, arg, "--json")) {
            options.json = true;
        } else if (std.mem.eql(u8, arg, "--csv")) {
            options.csv = true;
        } else if (std.mem.eql(u8, arg, "--all")) {
            options.all_fields = true;
        } else if (std.mem.eql(u8, arg, "--fields") and i + 1 < argv.len) {
            options.fields = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
            try printUsage();
            return;
        } else {
            colors.printError("Unknown option: {s}\n", .{arg});
            return error.InvalidArgs;
        }
    }
    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    // Fetch both runs on separate connections.
    // NOTE(review): these progress lines also precede --json/--csv output;
    // confirm downstream consumers tolerate them.
    colors.printInfo("Fetching run {s}...\n", .{run_a});
    var client_a = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client_a.close();
    // Try to get experiment info for run A
    try client_a.sendGetExperiment(run_a, api_key_hash);
    const msg_a = try client_a.receiveMessage(allocator);
    defer allocator.free(msg_a);
    colors.printInfo("Fetching run {s}...\n", .{run_b});
    var client_b = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client_b.close();
    try client_b.sendGetExperiment(run_b, api_key_hash);
    const msg_b = try client_b.receiveMessage(allocator);
    defer allocator.free(msg_b);
    // Parse responses
    const parsed_a = std.json.parseFromSlice(std.json.Value, allocator, msg_a, .{}) catch {
        colors.printError("Failed to parse response for {s}\n", .{run_a});
        return error.InvalidResponse;
    };
    defer parsed_a.deinit();
    const parsed_b = std.json.parseFromSlice(std.json.Value, allocator, msg_b, .{}) catch {
        colors.printError("Failed to parse response for {s}\n", .{run_b});
        return error.InvalidResponse;
    };
    defer parsed_b.deinit();
    const root_a = parsed_a.value.object;
    const root_b = parsed_b.value.object;
    // Surface server-side errors before attempting any comparison.
    if (root_a.get("error")) |err_a| {
        colors.printError("Error fetching {s}: {s}\n", .{ run_a, err_a.string });
        return error.ServerError;
    }
    if (root_b.get("error")) |err_b| {
        colors.printError("Error fetching {s}: {s}\n", .{ run_b, err_b.string });
        return error.ServerError;
    }
    if (options.json) {
        try outputJsonComparison(allocator, root_a, root_b, run_a, run_b);
    } else if (options.csv) {
        // BUG FIX: --csv was parsed but never dispatched, leaving
        // outputCsvComparison dead code and --csv silently producing the
        // human-readable report instead.
        try outputCsvComparison(allocator, root_a, root_b, run_a, run_b);
    } else {
        try outputHumanComparison(root_a, root_b, run_a, run_b, options);
    }
}
/// Print a human-readable, sectioned report of how two runs differ.
/// The wording and ordering of the output is stable; only sections with
/// data on at least one side are printed.
fn outputHumanComparison(
    root_a: std.json.ObjectMap,
    root_b: std.json.ObjectMap,
    run_a: []const u8,
    run_b: []const u8,
    options: CompareOptions,
) !void {
    colors.printInfo("\n=== Comparison: {s} vs {s} ===\n\n", .{ run_a, run_b });

    // Job name: highlight when the two runs are not even the same job.
    const name_a = jsonGetString(root_a, "job_name") orelse "unknown";
    const name_b = jsonGetString(root_b, "job_name") orelse "unknown";
    if (std.mem.eql(u8, name_a, name_b)) {
        colors.printInfo("Job Name: {s}\n", .{name_a});
    } else {
        colors.printWarning("Job names differ:\n", .{});
        colors.printInfo(" {s}: {s}\n", .{ run_a, name_a });
        colors.printInfo(" {s}: {s}\n", .{ run_b, name_b });
    }

    // Experiment group: section is skipped when neither side has one.
    const grp_a = jsonGetString(root_a, "experiment_group") orelse "";
    const grp_b = jsonGetString(root_b, "experiment_group") orelse "";
    if (grp_a.len != 0 or grp_b.len != 0) {
        colors.printInfo("\nExperiment Group:\n", .{});
        if (std.mem.eql(u8, grp_a, grp_b)) {
            colors.printInfo(" Both: {s}\n", .{grp_a});
        } else {
            colors.printInfo(" {s}: {s}\n", .{ run_a, grp_a });
            colors.printInfo(" {s}: {s}\n", .{ run_b, grp_b });
        }
    }

    // Narrative fields: only compared in detail when both are objects.
    const nar_a = root_a.get("narrative");
    const nar_b = root_b.get("narrative");
    if (nar_a != null or nar_b != null) {
        colors.printInfo("\n--- Narrative ---\n", .{});
        if (nar_a == null) {
            colors.printInfo(" {s} has narrative, {s} does not\n", .{ run_b, run_a });
        } else if (nar_b == null) {
            colors.printInfo(" {s} has narrative, {s} does not\n", .{ run_a, run_b });
        } else if (nar_a.? == .object and nar_b.? == .object) {
            try compareNarrativeFields(nar_a.?.object, nar_b.?.object, run_a, run_b);
        }
    }

    // Metadata: only comparable when both sides carry a metadata object.
    if (root_a.get("metadata")) |ma| {
        if (root_b.get("metadata")) |mb| {
            if (ma == .object and mb == .object) {
                colors.printInfo("\n--- Metadata Differences ---\n", .{});
                try compareMetadata(ma.object, mb.object, run_a, run_b, options.all_fields);
            }
        }
    }

    // Metrics: same gating as metadata.
    if (root_a.get("metrics")) |ma| {
        if (root_b.get("metrics")) |mb| {
            if (ma == .object and mb == .object) {
                colors.printInfo("\n--- Metrics ---\n", .{});
                try compareMetrics(ma.object, mb.object, run_a, run_b);
            }
        }
    }

    // Outcome status, same shape as the experiment-group section.
    const out_a = jsonGetString(root_a, "outcome") orelse "";
    const out_b = jsonGetString(root_b, "outcome") orelse "";
    if (out_a.len != 0 or out_b.len != 0) {
        colors.printInfo("\n--- Outcome ---\n", .{});
        if (std.mem.eql(u8, out_a, out_b)) {
            colors.printInfo(" Both: {s}\n", .{out_a});
        } else {
            colors.printInfo(" {s}: {s}\n", .{ run_a, out_a });
            colors.printInfo(" {s}: {s}\n", .{ run_b, out_b });
        }
    }
    colors.printInfo("\n", .{});
}
/// Emit a machine-readable diff of the two runs to stdout as one JSON object:
/// {"run_a":"...","run_b":"...","differences":{...}}.
fn outputJsonComparison(
    allocator: std.mem.Allocator,
    root_a: std.json.ObjectMap,
    root_b: std.json.ObjectMap,
    run_a: []const u8,
    run_b: []const u8,
) !void {
    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(allocator);
    const writer = buf.writer(allocator);
    // NOTE(review): run ids and field values are emitted verbatim; if they can
    // ever contain quotes or backslashes they need JSON string escaping.
    try writer.writeAll("{\"run_a\":\"");
    try writer.writeAll(run_a);
    try writer.writeAll("\",\"run_b\":\"");
    try writer.writeAll(run_b);
    try writer.writeAll("\",\"differences\":");
    try writer.writeAll("{");
    var first = true;
    // Job names
    const job_name_a = jsonGetString(root_a, "job_name") orelse "";
    const job_name_b = jsonGetString(root_b, "job_name") orelse "";
    if (!std.mem.eql(u8, job_name_a, job_name_b)) {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"job_name\":{\"a\":\"");
        try writer.writeAll(job_name_a);
        try writer.writeAll("\",\"b\":\"");
        try writer.writeAll(job_name_b);
        try writer.writeAll("\"}");
    }
    // Experiment group
    const group_a = jsonGetString(root_a, "experiment_group") orelse "";
    const group_b = jsonGetString(root_b, "experiment_group") orelse "";
    if (!std.mem.eql(u8, group_a, group_b)) {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"experiment_group\":{\"a\":\"");
        try writer.writeAll(group_a);
        try writer.writeAll("\",\"b\":\"");
        try writer.writeAll(group_b);
        try writer.writeAll("\"}");
    }
    // Outcomes
    const outcome_a = jsonGetString(root_a, "outcome") orelse "";
    const outcome_b = jsonGetString(root_b, "outcome") orelse "";
    if (!std.mem.eql(u8, outcome_a, outcome_b)) {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"outcome\":{\"a\":\"");
        try writer.writeAll(outcome_a);
        try writer.writeAll("\",\"b\":\"");
        try writer.writeAll(outcome_b);
        try writer.writeAll("\"}");
    }
    // BUG FIX: previously wrote "}}" followed by "}\n" — three closing braces
    // for two opening ones, producing invalid JSON on every invocation.
    try writer.writeAll("}}\n");
    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll(buf.items);
}
/// For each known narrative field, print both values when they differ, or note
/// which run is the only one carrying the field. Identical values are silent.
fn compareNarrativeFields(
    na: std.json.ObjectMap,
    nb: std.json.ObjectMap,
    run_a: []const u8,
    run_b: []const u8,
) !void {
    const narrative_keys = [_][]const u8{ "hypothesis", "context", "intent", "expected_outcome" };
    for (narrative_keys) |key| {
        const lhs = jsonGetString(na, key);
        const rhs = jsonGetString(nb, key);
        if (lhs == null and rhs == null) continue;
        if (rhs == null) {
            colors.printInfo(" {s}: only in {s}\n", .{ key, run_a });
        } else if (lhs == null) {
            colors.printInfo(" {s}: only in {s}\n", .{ key, run_b });
        } else if (!std.mem.eql(u8, lhs.?, rhs.?)) {
            colors.printInfo(" {s}:\n", .{key});
            colors.printInfo(" {s}: {s}\n", .{ run_a, lhs.? });
            colors.printInfo(" {s}: {s}\n", .{ run_b, rhs.? });
        }
    }
}
/// Compare a fixed set of well-known metadata keys across the two runs.
/// With `show_all`, unchanged and one-sided keys are reported too; otherwise
/// only differing values print, plus a placeholder line when nothing differs.
fn compareMetadata(
    ma: std.json.ObjectMap,
    mb: std.json.ObjectMap,
    run_a: []const u8,
    run_b: []const u8,
    show_all: bool,
) !void {
    const tracked = [_][]const u8{ "batch_size", "learning_rate", "epochs", "model", "dataset" };
    var any_diff = false;
    for (tracked) |key| {
        const va = ma.get(key);
        const vb = mb.get(key);
        if (va == null and vb == null) continue;
        if (vb == null) {
            if (show_all) colors.printInfo(" {s}: only in {s}\n", .{ key, run_a });
            continue;
        }
        if (va == null) {
            if (show_all) colors.printInfo(" {s}: only in {s}\n", .{ key, run_b });
            continue;
        }
        const lhs = jsonValueToString(va.?);
        const rhs = jsonValueToString(vb.?);
        if (!std.mem.eql(u8, lhs, rhs)) {
            any_diff = true;
            colors.printInfo(" {s}: {s} → {s}\n", .{ key, lhs, rhs });
        } else if (show_all) {
            colors.printInfo(" {s}: {s} (same)\n", .{ key, lhs });
        }
    }
    if (!any_diff and !show_all) {
        colors.printInfo(" (no significant differences in common metadata)\n", .{});
    }
}
/// Print side-by-side metric values with absolute and relative change.
/// `run_a`/`run_b` are accepted for signature symmetry with the other compare
/// helpers but are not used in the output.
fn compareMetrics(
    ma: std.json.ObjectMap,
    mb: std.json.ObjectMap,
    run_a: []const u8,
    run_b: []const u8,
) !void {
    _ = run_a;
    _ = run_b;
    // Common metrics to compare; a metric is shown only when both runs have it.
    const metrics = [_][]const u8{ "accuracy", "loss", "f1_score", "precision", "recall", "training_time", "validation_loss" };
    for (metrics) |metric| {
        if (ma.get(metric)) |va| {
            if (mb.get(metric)) |vb| {
                const val_a = jsonValueToFloat(va);
                const val_b = jsonValueToFloat(vb);
                const diff = val_b - val_a;
                const percent = if (val_a != 0) (diff / val_a) * 100 else 0;
                // BUG FIX: both direction markers were empty string literals
                // (apparently mangled unicode), so increases and decreases were
                // indistinguishable next to @abs(diff). Restore explicit arrows.
                const arrow = if (diff > 0) "↑" else if (diff < 0) "↓" else "=";
                colors.printInfo(" {s}: {d:.4} → {d:.4} ({s}{d:.4}, {d:.1}%)\n", .{
                    metric, val_a, val_b, arrow, @abs(diff), percent,
                });
            }
        }
    }
}
/// Emit a CSV diff of the two runs to stdout.
/// Columns: field,<run_a>,<run_b>,delta,notes.
/// NOTE(review): this reads metadata keys (batch_size, ...) and metrics from
/// the TOP LEVEL of each run object, while outputHumanComparison reads them
/// from nested "metadata"/"metrics" objects — confirm which matches the
/// server schema.
fn outputCsvComparison(
allocator: std.mem.Allocator,
root_a: std.json.ObjectMap,
root_b: std.json.ObjectMap,
run_a: []const u8,
run_b: []const u8,
) !void {
var buf = std.ArrayList(u8).empty;
defer buf.deinit(allocator);
const writer = buf.writer(allocator);
// Header with actual run IDs as column names
try writer.print("field,{s},{s},delta,notes\n", .{ run_a, run_b });
// Job names
const job_name_a = jsonGetString(root_a, "job_name") orelse "";
const job_name_b = jsonGetString(root_b, "job_name") orelse "";
const job_same = std.mem.eql(u8, job_name_a, job_name_b);
try writer.print("job_name,\"{s}\",\"{s}\",{s},\"{s}\"\n", .{
job_name_a, job_name_b,
if (job_same) "same" else "changed", if (job_same) "" else "different job names",
});
// Outcomes
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
const outcome_same = std.mem.eql(u8, outcome_a, outcome_b);
try writer.print("outcome,{s},{s},{s},\"{s}\"\n", .{
outcome_a, outcome_b,
if (outcome_same) "same" else "changed", if (outcome_same) "" else "different outcomes",
});
// Experiment group
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
const group_same = std.mem.eql(u8, group_a, group_b);
try writer.print("experiment_group,\"{s}\",\"{s}\",{s},\"{s}\"\n", .{
group_a, group_b,
if (group_same) "same" else "changed", if (group_same) "" else "different groups",
});
// Metadata fields with delta calculation for numeric values
const keys = [_][]const u8{ "batch_size", "learning_rate", "epochs", "model", "dataset" };
for (keys) |key| {
if (root_a.get(key)) |va| {
if (root_b.get(key)) |vb| {
const str_a = jsonValueToString(va);
const str_b = jsonValueToString(vb);
const same = std.mem.eql(u8, str_a, str_b);
// Try to calculate delta for numeric values
// `delta` is heap-allocated ONLY when the values differ and at least one
// side converts to a nonzero float; otherwise it is a string literal.
const delta = if (!same) blk: {
const f_a = jsonValueToFloat(va);
const f_b = jsonValueToFloat(vb);
if (f_a != 0 or f_b != 0) {
break :blk try std.fmt.allocPrint(allocator, "{d:.4}", .{f_b - f_a});
}
break :blk "changed";
} else "0";
// This free condition must mirror the allocPrint condition above exactly,
// so literals ("changed"/"0") are never passed to free. The defer runs at
// the end of this inner block, i.e. once per matched key.
defer if (!same and (jsonValueToFloat(va) != 0 or jsonValueToFloat(vb) != 0)) allocator.free(delta);
try writer.print("{s},{s},{s},{s},\"{s}\"\n", .{
key, str_a, str_b, delta,
if (same) "same" else "changed",
});
}
}
}
// Metrics with delta calculation
const metrics = [_][]const u8{ "accuracy", "loss", "f1_score", "precision", "recall", "training_time" };
for (metrics) |metric| {
if (root_a.get(metric)) |va| {
if (root_b.get(metric)) |vb| {
const val_a = jsonValueToFloat(va);
const val_b = jsonValueToFloat(vb);
const diff = val_b - val_a;
const percent = if (val_a != 0) (diff / val_a) * 100 else 0;
// For loss-like metrics a decrease is an improvement; for the rest an
// increase is.
const notes = if (std.mem.eql(u8, metric, "loss") or std.mem.eql(u8, metric, "training_time"))
if (diff < 0) "improved" else if (diff > 0) "degraded" else "same"
else if (diff > 0) "improved" else if (diff < 0) "degraded" else "same";
try writer.print("{s},{d:.4},{d:.4},{d:.4},\"{d:.1}% - {s}\"\n", .{
metric, val_a, val_b, diff, percent, notes,
});
}
}
}
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
try stdout_file.writeAll(buf.items);
}
/// Return the string payload of `obj[key]`, or null when the key is missing
/// or holds a non-string value.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    if (obj.get(key)) |v| {
        return switch (v) {
            .string => |s| s,
            else => null,
        };
    }
    return null;
}
/// Render a JSON value as display text: strings and bools keep their value,
/// numbers collapse to the placeholder "number", everything else (null,
/// arrays, objects, number strings) becomes "complex".
fn jsonValueToString(v: std.json.Value) []const u8 {
    if (v == .string) return v.string;
    if (v == .bool) return if (v.bool) "true" else "false";
    if (v == .integer or v == .float) return "number";
    return "complex";
}
/// Coerce a JSON number to f64; any non-numeric value maps to 0.
fn jsonValueToFloat(v: std.json.Value) f64 {
    switch (v) {
        .float => |f| return f,
        .integer => |i| return @floatFromInt(i),
        else => return 0,
    }
}
/// Print CLI usage for `ml compare`.
fn printUsage() !void {
    colors.printInfo("Usage: ml compare <run-a> <run-b> [options]\n", .{});
    colors.printInfo("\nCompare two runs and show differences in:\n", .{});
    colors.printInfo(" - Job metadata (batch_size, learning_rate, etc.)\n", .{});
    colors.printInfo(" - Narrative fields (hypothesis, context, intent)\n", .{});
    colors.printInfo(" - Metrics (accuracy, loss, training_time)\n", .{});
    colors.printInfo(" - Outcome status\n", .{});
    colors.printInfo("\nOptions:\n", .{});
    colors.printInfo(" --json Output as JSON\n", .{});
    // BUG FIX: --csv is accepted by the option parser but was missing from
    // this help text.
    colors.printInfo(" --csv Output as CSV\n", .{});
    colors.printInfo(" --all Show all fields (including unchanged)\n", .{});
    colors.printInfo(" --fields <csv> Compare only specific fields\n", .{});
    colors.printInfo(" --help, -h Show this help\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo(" ml compare run_abc run_def\n", .{});
    colors.printInfo(" ml compare run_abc run_def --json\n", .{});
    colors.printInfo(" ml compare run_abc run_def --all\n", .{});
}

View file

@ -9,6 +9,7 @@ const DatasetOptions = struct {
dry_run: bool = false,
validate: bool = false,
json: bool = false,
csv: bool = false,
};
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
@ -34,6 +35,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
options.validate = true;
} else if (std.mem.eql(u8, arg, "--json")) {
options.json = true;
} else if (std.mem.eql(u8, arg, "--csv")) {
options.csv = true;
} else if (std.mem.startsWith(u8, arg, "--")) {
colors.printError("Unknown option: {s}\n", .{arg});
printUsage();
@ -63,6 +66,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
} else if (std.mem.eql(u8, action, "search")) {
try searchDatasets(allocator, positional.items[1], &options);
return error.InvalidArgs;
} else if (std.mem.eql(u8, action, "verify")) {
try verifyDataset(allocator, positional.items[1], &options);
return;
}
},
3 => {
@ -72,7 +78,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
},
else => {
colors.printError("Unknoen action: {s}\n", .{action});
colors.printError("Unknown action: {s}\n", .{action});
printUsage();
return error.InvalidArgs;
},
@ -86,6 +92,7 @@ fn printUsage() void {
colors.printInfo(" register <name> <url> Register a dataset with URL\n", .{});
colors.printInfo(" info <name> Show dataset information\n", .{});
colors.printInfo(" search <term> Search datasets by name/description\n", .{});
colors.printInfo(" verify <path|id> Verify dataset integrity\n", .{});
colors.printInfo("\nOptions:\n", .{});
colors.printInfo(" --dry-run Show what would be requested\n", .{});
colors.printInfo(" --validate Validate inputs only (no request)\n", .{});
@ -362,6 +369,68 @@ fn searchDatasets(allocator: std.mem.Allocator, term: []const u8, options: *cons
}
}
/// Verify a dataset directory: confirm it exists and report file count and
/// total size in the requested format (json/csv/human).
/// NOTE: integrity hashing is not implemented yet; in production this would
/// compute SHA256 hashes — today it is an existence/size check only.
fn verifyDataset(allocator: std.mem.Allocator, target: []const u8, options: *const DatasetOptions) !void {
    colors.printInfo("Verifying dataset: {s}\n", .{target});
    // BUG FIX: the old code joined relative targets onto "./" and then called
    // std.fs.openDirAbsolute, which requires an absolute path and therefore
    // failed on every relative target. cwd().openDir accepts both absolute
    // and relative paths.
    var dir = std.fs.cwd().openDir(target, .{ .iterate = true }) catch {
        colors.printError("Dataset not found: {s}\n", .{target});
        return error.FileNotFound;
    };
    defer dir.close();
    var file_count: usize = 0;
    var total_size: u64 = 0;
    var walker = try dir.walk(allocator);
    defer walker.deinit();
    while (try walker.next()) |entry| {
        if (entry.kind != .file) continue;
        file_count += 1;
        // Stat relative to the open directory handle; skip unreadable entries
        // rather than aborting the whole verification.
        const stat = dir.statFile(entry.path) catch continue;
        total_size += stat.size;
    }
    if (options.json) {
        const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
        var buffer: [4096]u8 = undefined;
        // BUG FIX: was `catch unreachable` — an overly long target would have
        // been undefined behavior in release builds; propagate instead.
        const formatted = try std.fmt.bufPrint(&buffer, "{{\"path\":\"{s}\",\"files\":{d},\"size\":{d},\"ok\":true}}\n", .{
            target, file_count, total_size,
        });
        try stdout_file.writeAll(formatted);
    } else if (options.csv) {
        const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
        try stdout_file.writeAll("metric,value\n");
        var buf: [256]u8 = undefined;
        const line1 = try std.fmt.bufPrint(&buf, "path,{s}\n", .{target});
        try stdout_file.writeAll(line1);
        const line2 = try std.fmt.bufPrint(&buf, "files,{d}\n", .{file_count});
        try stdout_file.writeAll(line2);
        const line3 = try std.fmt.bufPrint(&buf, "size_bytes,{d}\n", .{total_size});
        try stdout_file.writeAll(line3);
        const line4 = try std.fmt.bufPrint(&buf, "size_mb,{d:.2}\n", .{@as(f64, @floatFromInt(total_size)) / (1024 * 1024)});
        try stdout_file.writeAll(line4);
    } else {
        colors.printSuccess("✓ Dataset verified\n", .{});
        colors.printInfo(" Path: {s}\n", .{target});
        colors.printInfo(" Files: {d}\n", .{file_count});
        colors.printInfo(" Size: {d:.2} MB\n", .{@as(f64, @floatFromInt(total_size)) / (1024 * 1024)});
    }
}
fn writeJSONString(writer: anytype, s: []const u8) !void {
try writer.writeByte('"');
for (s) |c| {

View file

@ -0,0 +1,336 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const manifest = @import("../utils/manifest.zig");
/// Options for `ml export`, filled in by the argument parser in `run`.
pub const ExportOptions = struct {
// Redact identifying information before exporting.
anonymize: bool = false,
anonymize_level: []const u8 = "metadata-only", // metadata-only, full
// Output bundle file path; null writes the manifest to stdout.
bundle: ?[]const u8 = null,
// Overrides the configured worker base path when locating the run.
base_override: ?[]const u8 = null,
// Emit a JSON status line instead of human-readable output (bundle mode only).
json: bool = false,
};
/// Entry point for `ml export <run-id|path> [options]`.
///
/// Locates the run's manifest, optionally anonymizes it, then either writes
/// it to a bundle file (--bundle) or prints it to stdout. Errors from
/// locating/reading/parsing the manifest are reported and propagated.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
if (argv.len == 0) {
try printUsage();
return error.InvalidArgs;
}
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
try printUsage();
return;
}
const target = argv[0];
var options = ExportOptions{};
var i: usize = 1;
// Flag parsing; flags taking a value consume the next argv slot.
while (i < argv.len) : (i += 1) {
const arg = argv[i];
if (std.mem.eql(u8, arg, "--anonymize")) {
options.anonymize = true;
} else if (std.mem.eql(u8, arg, "--anonymize-level") and i + 1 < argv.len) {
options.anonymize_level = argv[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--bundle") and i + 1 < argv.len) {
options.bundle = argv[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--base") and i + 1 < argv.len) {
options.base_override = argv[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--json")) {
options.json = true;
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
try printUsage();
return;
} else {
colors.printError("Unknown option: {s}\n", .{arg});
return error.InvalidArgs;
}
}
// Validate anonymize level
if (!std.mem.eql(u8, options.anonymize_level, "metadata-only") and
!std.mem.eql(u8, options.anonymize_level, "full"))
{
colors.printError("Invalid anonymize level: {s}. Use 'metadata-only' or 'full'\n", .{options.anonymize_level});
return error.InvalidArgs;
}
const cfg = try Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
const resolved_base = options.base_override orelse cfg.worker_base;
const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
if (err == error.FileNotFound) {
colors.printError(
"Could not locate run_manifest.json for '{s}'.\n",
.{target},
);
}
return err;
};
defer allocator.free(manifest_path);
// Read the manifest
const manifest_content = manifest.readFileAlloc(allocator, manifest_path) catch |err| {
colors.printError("Failed to read manifest: {}\n", .{err});
return err;
};
defer allocator.free(manifest_content);
// Parse the manifest
const parsed = std.json.parseFromSlice(std.json.Value, allocator, manifest_content, .{}) catch |err| {
colors.printError("Failed to parse manifest: {}\n", .{err});
return err;
};
defer parsed.deinit();
// Anonymize if requested.
// Ownership: when anonymizing, `final_content` is a fresh allocation freed by
// the defer below; otherwise it aliases `manifest_content`, which is freed by
// its own defer — the owned flag prevents a double free.
var final_content: []u8 = undefined;
var final_content_owned = false;
if (options.anonymize) {
final_content = try anonymizeManifest(allocator, parsed.value, options.anonymize_level);
final_content_owned = true;
} else {
final_content = manifest_content;
}
defer if (final_content_owned) allocator.free(final_content);
// Output or bundle
if (options.bundle) |bundle_path| {
// Create a simple tar-like bundle (just the manifest for now)
// In production, this would include code, configs, etc.
var bundle_file = try std.fs.cwd().createFile(bundle_path, .{});
defer bundle_file.close();
try bundle_file.writeAll(final_content);
if (options.json) {
var stdout_writer = io.stdoutWriter();
try stdout_writer.print("{{\"success\":true,\"bundle\":\"{s}\",\"anonymized\":{}}}\n", .{
bundle_path,
options.anonymize,
});
} else {
colors.printSuccess("✓ Exported to {s}\n", .{bundle_path});
if (options.anonymize) {
colors.printInfo(" Anonymization level: {s}\n", .{options.anonymize_level});
colors.printInfo(" Paths redacted, IPs removed, usernames anonymized\n", .{});
}
}
} else {
// Output to stdout
var stdout_writer = io.stdoutWriter();
try stdout_writer.print("{s}\n", .{final_content});
}
}
/// Produce an anonymized JSON serialization of `root`. Caller owns the
/// returned slice.
///
/// "metadata-only" redacts metadata.dataset_path plus system hostname /
/// ip_address / username; "full" additionally redacts other path fields and
/// strips logs, log_path, and annotations.
fn anonymizeManifest(
    allocator: std.mem.Allocator,
    root: std.json.Value,
    level: []const u8,
) ![]u8 {
    // Clone the value by stringifying and re-parsing so we can modify it
    // without mutating the caller's tree.
    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(allocator);
    try writeJSONValue(buf.writer(allocator), root);
    const json_str = try buf.toOwnedSlice(allocator);
    defer allocator.free(json_str);
    var parsed_clone = try std.json.parseFromSlice(std.json.Value, allocator, json_str, .{});
    defer parsed_clone.deinit();
    var cloned = parsed_clone.value;
    if (cloned != .object) {
        // For non-objects, just re-serialize and return
        var out_buf = std.ArrayList(u8).empty;
        defer out_buf.deinit(allocator);
        try writeJSONValue(out_buf.writer(allocator), cloned);
        return out_buf.toOwnedSlice(allocator);
    }
    const obj = &cloned.object;
    // BUG FIX: anonymized path strings were freed by scope-local `defer`s
    // while still referenced from the JSON tree, so the final serialization
    // read freed memory. Keep every replacement string alive until after
    // serialization; this defer runs last (LIFO) on function exit.
    var owned_strings: std.ArrayList([]const u8) = .empty;
    defer {
        for (owned_strings.items) |s| allocator.free(s);
        owned_strings.deinit(allocator);
    }
    // Anonymize metadata fields
    if (obj.get("metadata")) |meta| {
        if (meta == .object) {
            // NOTE(review): `meta` is a by-value copy of the Value; `put` on
            // existing keys replaces in the shared backing store without
            // growing, which is what this relies on — confirm keys always
            // pre-exist before putting new ones here.
            var meta_obj = meta.object;
            // Path anonymization: /nas/private/user/data -> /datasets/data
            if (meta_obj.get("dataset_path")) |dp| {
                if (dp == .string) {
                    const anon_path = try anonymizePath(allocator, dp.string);
                    errdefer allocator.free(anon_path);
                    try owned_strings.append(allocator, anon_path);
                    try meta_obj.put("dataset_path", std.json.Value{ .string = anon_path });
                }
            }
            // Anonymize other paths if full level
            if (std.mem.eql(u8, level, "full")) {
                const path_fields = [_][]const u8{ "code_path", "output_path", "checkpoint_path" };
                for (path_fields) |field| {
                    if (meta_obj.get(field)) |p| {
                        if (p == .string) {
                            const anon_path = try anonymizePath(allocator, p.string);
                            errdefer allocator.free(anon_path);
                            try owned_strings.append(allocator, anon_path);
                            try meta_obj.put(field, std.json.Value{ .string = anon_path });
                        }
                    }
                }
            }
        }
    }
    // Anonymize system info
    if (obj.get("system")) |sys| {
        if (sys == .object) {
            var sys_obj = sys.object;
            // Hostname: e.g. gpu-server-01.internal -> worker-A
            if (sys_obj.get("hostname")) |h| {
                if (h == .string) {
                    try sys_obj.put("hostname", std.json.Value{ .string = "worker-A" });
                }
            }
            // IP addresses -> [REDACTED]
            if (sys_obj.get("ip_address")) |_| {
                try sys_obj.put("ip_address", std.json.Value{ .string = "[REDACTED]" });
            }
            // Username -> [REDACTED]
            if (sys_obj.get("username")) |_| {
                try sys_obj.put("username", std.json.Value{ .string = "[REDACTED]" });
            }
        }
    }
    // Drop log references and annotations entirely at full level (they may
    // contain PII).
    if (std.mem.eql(u8, level, "full")) {
        _ = obj.swapRemove("logs");
        _ = obj.swapRemove("log_path");
        _ = obj.swapRemove("annotations");
    }
    // Serialize back to JSON; the replacement strings above are still alive
    // here and are only freed by the owned_strings defer afterwards.
    var out_buf = std.ArrayList(u8).empty;
    defer out_buf.deinit(allocator);
    try writeJSONValue(out_buf.writer(allocator), cloned);
    return out_buf.toOwnedSlice(allocator);
}
/// Map a concrete filesystem path onto a generic, non-identifying prefix
/// while keeping the final path component.
/// Examples: /home/user/project/data -> /datasets/data,
///           /nas/private/lab/src -> /code/src.
/// Caller owns the returned slice.
fn anonymizePath(allocator: std.mem.Allocator, path: []const u8) ![]const u8 {
    // No separator at all: nothing identifying to strip, return a copy.
    const sep_idx = std.mem.lastIndexOfScalar(u8, path, '/') orelse
        return allocator.dupe(u8, path);
    const basename = path[sep_idx + 1 ..];
    // Pick a generic prefix from keywords found anywhere in the original
    // path. Order matters: "data" wins over model/code hints.
    var prefix: []const u8 = "/workspace";
    if (std.mem.indexOf(u8, path, "data") != null) {
        prefix = "/datasets";
    } else if (std.mem.indexOf(u8, path, "model") != null or std.mem.indexOf(u8, path, "checkpoint") != null) {
        prefix = "/models";
    } else if (std.mem.indexOf(u8, path, "code") != null or std.mem.indexOf(u8, path, "src") != null) {
        prefix = "/code";
    }
    return std.fs.path.join(allocator, &[_][]const u8{ prefix, basename });
}
/// Print CLI usage for `ml export`, including the anonymization rules.
fn printUsage() !void {
colors.printInfo("Usage: ml export <run-id|path> [options]\n", .{});
colors.printInfo("\nExport experiment for sharing or archiving:\n", .{});
colors.printInfo(" --bundle <path> Create tarball at path\n", .{});
colors.printInfo(" --anonymize Enable anonymization\n", .{});
colors.printInfo(" --anonymize-level <lvl> 'metadata-only' or 'full'\n", .{});
colors.printInfo(" --base <path> Base path to find run\n", .{});
colors.printInfo(" --json Output JSON response\n", .{});
colors.printInfo("\nAnonymization rules:\n", .{});
colors.printInfo(" - Paths: /nas/private/... → /datasets/...\n", .{});
colors.printInfo(" - Hostnames: gpu-server-01 → worker-A\n", .{});
colors.printInfo(" - IPs: 192.168.1.100 → [REDACTED]\n", .{});
colors.printInfo(" - Usernames: user@lab.edu → [REDACTED]\n", .{});
colors.printInfo(" - Full level: Also removes logs and annotations\n", .{});
colors.printInfo("\nExamples:\n", .{});
colors.printInfo(" ml export run_abc --bundle run_abc.tar.gz\n", .{});
colors.printInfo(" ml export run_abc --bundle run_abc.tar.gz --anonymize\n", .{});
colors.printInfo(" ml export run_abc --anonymize --anonymize-level full\n", .{});
}
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
    // Serializes a std.json.Value tree to `writer` as compact JSON.
    // Fix: object keys are now escaped with writeJSONString, matching string
    // values. Previously keys were interpolated raw, which produced invalid
    // JSON whenever a key contained quotes, backslashes, or control bytes.
    switch (v) {
        .null => try writer.writeAll("null"),
        .bool => |b| try writer.print("{}", .{b}),
        .integer => |i| try writer.print("{d}", .{i}),
        .float => |f| try writer.print("{d}", .{f}),
        .string => |s| try writeJSONString(writer, s),
        .array => |arr| {
            try writer.writeAll("[");
            for (arr.items, 0..) |item, idx| {
                if (idx > 0) try writer.writeAll(",");
                try writeJSONValue(writer, item);
            }
            try writer.writeAll("]");
        },
        .object => |obj| {
            try writer.writeAll("{");
            var first = true;
            var it = obj.iterator();
            while (it.next()) |entry| {
                if (!first) try writer.writeAll(",");
                first = false;
                // Escape the key exactly like any other JSON string.
                try writeJSONString(writer, entry.key_ptr.*);
                try writer.writeAll(":");
                try writeJSONValue(writer, entry.value_ptr.*);
            }
            try writer.writeAll("}");
        },
        // Number kept verbatim as its parsed text (preserves precision).
        .number_string => |s| try writer.print("{s}", .{s}),
    }
}
fn writeJSONString(writer: anytype, s: []const u8) !void {
    // Writes `s` as a double-quoted JSON string. Quote, backslash, and the
    // common whitespace characters get two-character escapes; any other
    // control byte (< 0x20) becomes a \u00XX escape per RFC 8259.
    const hex_chars = "0123456789abcdef";
    try writer.writeAll("\"");
    for (s) |byte| {
        switch (byte) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            // Remaining control bytes (excluding \t \n \r handled above).
            0x00...0x08, 0x0B, 0x0C, 0x0E...0x1F => {
                const esc = [_]u8{ '\\', 'u', '0', '0', hex_chars[byte >> 4], hex_chars[byte & 0x0F] };
                try writer.writeAll(&esc);
            },
            else => try writer.writeAll(&[_]u8{byte}),
        }
    }
    try writer.writeAll("\"");
}
fn hexDigit(v: u8) u8 {
    // Maps a nibble value (0-15) to its lowercase hex character.
    return switch (v) {
        0...9 => '0' + v,
        else => 'a' + (v - 10),
    };
}

497
cli/src/commands/find.zig Normal file
View file

@ -0,0 +1,497 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
/// Search criteria and output options for `ml find`.
/// All string fields are borrowed slices into argv; null means "no filter".
pub const FindOptions = struct {
    json: bool = false, // emit the raw server response as JSON
    csv: bool = false, // emit results as CSV rows (--json takes precedence)
    limit: usize = 20, // maximum number of results requested
    tag: ?[]const u8 = null, // filter by tag
    outcome: ?[]const u8 = null, // filter by outcome status (e.g. "validates")
    dataset: ?[]const u8 = null, // filter by dataset name
    experiment_group: ?[]const u8 = null, // filter by experiment group
    author: ?[]const u8 = null, // filter by author identifier
    after: ?[]const u8 = null, // lower time bound; date string per usage examples (e.g. 2024-01-01)
    before: ?[]const u8 = null, // upper time bound; inclusivity is server-defined — TODO confirm
    query: ?[]const u8 = null, // free-text query
};
/// Entry point for `ml find`: parses search criteria from `argv`, sends a
/// search request to the server over WebSocket, and renders the response as
/// a human-readable table, JSON, or CSV.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
        try printUsage();
        return;
    }
    var options = FindOptions{};
    var query_str: ?[]const u8 = null;
    // First argument might be a query string or a flag
    // NOTE(review): only "--"-prefixed tokens are treated as flags, so a
    // leading single-dash token (e.g. "-x") is silently taken as the query —
    // confirm that is intended.
    var arg_idx: usize = 0;
    if (!std.mem.startsWith(u8, argv[0], "--")) {
        query_str = argv[0];
        arg_idx = 1;
    }
    var i: usize = arg_idx;
    while (i < argv.len) : (i += 1) {
        const arg = argv[i];
        if (std.mem.eql(u8, arg, "--json")) {
            options.json = true;
        } else if (std.mem.eql(u8, arg, "--csv")) {
            options.csv = true;
        } else if (std.mem.eql(u8, arg, "--limit") and i + 1 < argv.len) {
            // A non-numeric limit propagates the parse error to the caller.
            options.limit = try std.fmt.parseInt(usize, argv[i + 1], 10);
            i += 1;
        } else if (std.mem.eql(u8, arg, "--tag") and i + 1 < argv.len) {
            options.tag = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--outcome") and i + 1 < argv.len) {
            options.outcome = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--dataset") and i + 1 < argv.len) {
            options.dataset = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--experiment-group") and i + 1 < argv.len) {
            options.experiment_group = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--author") and i + 1 < argv.len) {
            options.author = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--after") and i + 1 < argv.len) {
            options.after = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--before") and i + 1 < argv.len) {
            options.before = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
            try printUsage();
            return;
        } else if (!std.mem.startsWith(u8, arg, "--")) {
            // Treat as query string if not already set
            if (query_str == null) {
                query_str = arg;
            } else {
                colors.printError("Unknown argument: {s}\n", .{arg});
                return error.InvalidArgs;
            }
        } else {
            colors.printError("Unknown option: {s}\n", .{arg});
            return error.InvalidArgs;
        }
    }
    options.query = query_str;
    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }
    // Only the hash of the API key is sent with the search request; the raw
    // key is passed to the WebSocket client for connection setup.
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    colors.printInfo("Searching experiments...\n", .{});
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();
    // Build search request JSON
    const search_json = try buildSearchJson(allocator, &options);
    defer allocator.free(search_json);
    // Send search request - we'll use the dataset search opcode as a placeholder
    // In production, this would have a dedicated search endpoint
    try client.sendDatasetSearch(search_json, api_key_hash);
    const msg = try client.receiveMessage(allocator);
    defer allocator.free(msg);
    // Parse response
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, msg, .{}) catch {
        if (options.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\":\"invalid_response\"}}\n", .{});
        } else {
            colors.printError("Failed to parse search results\n", .{});
        }
        return error.InvalidResponse;
    };
    defer parsed.deinit();
    const root = parsed.value;
    // Output precedence: --json wins over --csv when both are given.
    if (options.json) {
        try io.stdoutWriteJson(root);
    } else if (options.csv) {
        try outputCsvResults(allocator, root, &options);
    } else {
        try outputHumanResults(root, &options);
    }
}
fn buildSearchJson(allocator: std.mem.Allocator, options: *const FindOptions) ![]u8 {
    // Serializes the set search criteria as a JSON object. Optional fields
    // that are null are omitted; "limit" is always emitted last. Field order
    // matches the FindOptions declaration. Caller owns the returned slice.
    var out = std.ArrayList(u8).empty;
    defer out.deinit(allocator);
    const w = out.writer(allocator);

    const criteria = [_]struct { name: []const u8, value: ?[]const u8 }{
        .{ .name = "query", .value = options.query },
        .{ .name = "tag", .value = options.tag },
        .{ .name = "outcome", .value = options.outcome },
        .{ .name = "dataset", .value = options.dataset },
        .{ .name = "experiment_group", .value = options.experiment_group },
        .{ .name = "author", .value = options.author },
        .{ .name = "after", .value = options.after },
        .{ .name = "before", .value = options.before },
    };

    try w.writeAll("{");
    var wrote_any = false;
    for (criteria) |field| {
        const value = field.value orelse continue;
        if (wrote_any) try w.writeAll(",");
        wrote_any = true;
        try w.print("\"{s}\":", .{field.name});
        try writeJSONString(w, value);
    }
    if (wrote_any) try w.writeAll(",");
    try w.print("\"limit\":{d}", .{options.limit});
    try w.writeAll("}");
    return out.toOwnedSlice(allocator);
}
fn outputHumanResults(root: std.json.Value, options: *const FindOptions) !void {
    // Renders search results as a color-coded table (green=validates,
    // red=refutes, yellow=partial/inconclusive). When a free-text query was
    // given, the matching run's hypothesis is shown beneath its row.
    // Fix: the hypothesis ellipsis ("...") was previously appended even when
    // the text was not truncated; it is now only added past 50 characters.
    if (root != .object) {
        colors.printError("Invalid response format\n", .{});
        return;
    }
    const obj = root.object;
    // A server-reported error takes precedence over any results payload.
    if (obj.get("error")) |err| {
        if (err == .string) {
            colors.printError("Search error: {s}\n", .{err.string});
        }
        return;
    }
    // Accept any of the key names the server may use for the result list.
    const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
    if (results == null) {
        colors.printInfo("No results found\n", .{});
        return;
    }
    if (results.? != .array) {
        colors.printError("Invalid results format\n", .{});
        return;
    }
    const items = results.?.array.items;
    if (items.len == 0) {
        colors.printInfo("No experiments found matching your criteria\n", .{});
        return;
    }
    colors.printSuccess("Found {d} experiment(s)\n\n", .{items.len});
    // Print header
    colors.printInfo("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
        "ID", "Job Name", "Outcome", "Status", "Group/Tags",
    });
    colors.printInfo("{s}\n", .{"────────────────────────────────────────────────────────────────────────────────"});
    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;
        const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
        const short_id = if (id.len > 8) id[0..8] else id;
        const job_name = jsonGetString(run_obj, "job_name") orelse "unnamed";
        const job_display = if (job_name.len > 18) job_name[0..18] else job_name;
        const outcome = jsonGetString(run_obj, "outcome") orelse "-";
        const status = jsonGetString(run_obj, "status") orelse "unknown";
        // Build a "group/tags" summary, truncated to fit the column.
        var summary_buf: [30]u8 = undefined;
        const summary = blk: {
            const group = jsonGetString(run_obj, "experiment_group");
            const tags = run_obj.get("tags");
            if (group) |g| {
                if (tags) |t| {
                    if (t == .string) {
                        break :blk std.fmt.bufPrint(&summary_buf, "{s}/{s}", .{ g[0..@min(g.len, 10)], t.string[0..@min(t.string.len, 10)] }) catch g[0..@min(g.len, 15)];
                    }
                }
                break :blk g[0..@min(g.len, 20)];
            }
            break :blk "-";
        };
        // Color code by outcome
        if (std.mem.eql(u8, outcome, "validates")) {
            colors.printSuccess("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        } else if (std.mem.eql(u8, outcome, "refutes")) {
            colors.printError("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        } else if (std.mem.eql(u8, outcome, "partial") or std.mem.eql(u8, outcome, "inconclusive")) {
            colors.printWarning("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        } else {
            colors.printInfo("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        }
        // Show hypothesis if available and query matches
        if (options.query) |_| {
            if (run_obj.get("narrative")) |narr| {
                if (narr == .object) {
                    if (narr.object.get("hypothesis")) |h| {
                        if (h == .string and h.string.len > 0) {
                            const hypo = h.string;
                            // Append "..." only when the hypothesis was
                            // actually truncated to 50 characters.
                            if (hypo.len > 50) {
                                colors.printInfo(" ↳ {s}...\n", .{hypo[0..50]});
                            } else {
                                colors.printInfo(" ↳ {s}\n", .{hypo});
                            }
                        }
                    }
                }
            }
        }
    }
    colors.printInfo("\nUse 'ml info <id>' for details, 'ml compare <a> <b>' to compare runs\n", .{});
}
fn outputCsvResults(allocator: std.mem.Allocator, root: std.json.Value, options: *const FindOptions) !void {
    // Emits search results as CSV on stdout. Free-form fields (job name,
    // group, tags, hypothesis) are quoted/escaped via escapeCsv.
    // Fix: rows are now formatted with allocPrint; the previous fixed
    // 1024-byte bufPrint buffer failed with error.NoSpaceLeft on long rows
    // (e.g. a long hypothesis), aborting the whole export.
    _ = options;
    if (root != .object) {
        colors.printError("Invalid response format\n", .{});
        return;
    }
    const obj = root.object;
    // A server-reported error takes precedence over any results payload.
    if (obj.get("error")) |err| {
        if (err == .string) {
            colors.printError("Search error: {s}\n", .{err.string});
        }
        return;
    }
    const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
    if (results == null) {
        return;
    }
    if (results.? != .array) {
        return;
    }
    const items = results.?.array.items;
    // CSV Header
    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll("id,job_name,outcome,status,experiment_group,tags,hypothesis\n");
    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;
        const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
        const job_name = jsonGetString(run_obj, "job_name") orelse "";
        const outcome = jsonGetString(run_obj, "outcome") orelse "";
        const status = jsonGetString(run_obj, "status") orelse "";
        const group = jsonGetString(run_obj, "experiment_group") orelse "";
        // Get tags (only a plain string form is supported here)
        var tags: []const u8 = "";
        if (run_obj.get("tags")) |t| {
            if (t == .string) tags = t.string;
        }
        // Get hypothesis from the nested narrative object, if present
        var hypothesis: []const u8 = "";
        if (run_obj.get("narrative")) |narr| {
            if (narr == .object) {
                if (narr.object.get("hypothesis")) |h| {
                    if (h == .string) hypothesis = h.string;
                }
            }
        }
        // Escape fields that might contain commas or quotes.
        // NOTE(review): id/outcome/status are written unescaped, presumably
        // because they are server-generated identifiers — confirm.
        const safe_job = try escapeCsv(allocator, job_name);
        defer allocator.free(safe_job);
        const safe_group = try escapeCsv(allocator, group);
        defer allocator.free(safe_group);
        const safe_tags = try escapeCsv(allocator, tags);
        defer allocator.free(safe_tags);
        const safe_hypo = try escapeCsv(allocator, hypothesis);
        defer allocator.free(safe_hypo);
        const line = try std.fmt.allocPrint(allocator, "{s},{s},{s},{s},{s},{s},{s}\n", .{
            id, safe_job, outcome, status, safe_group, safe_tags, safe_hypo,
        });
        defer allocator.free(line);
        try stdout_file.writeAll(line);
    }
}
fn escapeCsv(allocator: std.mem.Allocator, s: []const u8) ![]u8 {
    // Quotes `s` for CSV output when it contains a comma, double quote, or
    // newline: the field is wrapped in double quotes and embedded quotes are
    // doubled (RFC 4180). Otherwise a plain copy is returned.
    // Caller owns the returned slice.
    const must_quote = std.mem.indexOfAny(u8, s, ",\"\n\r") != null;
    if (!must_quote) {
        return allocator.dupe(u8, s);
    }
    var out = try std.ArrayList(u8).initCapacity(allocator, s.len + 2);
    defer out.deinit(allocator);
    try out.append(allocator, '"');
    for (s) |ch| {
        if (ch == '"') {
            try out.appendSlice(allocator, "\"\"");
        } else {
            try out.append(allocator, ch);
        }
    }
    try out.append(allocator, '"');
    return out.toOwnedSlice(allocator);
}
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    // Returns the string payload for `key`, or null when the key is absent
    // or the value is not a JSON string.
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
fn writeJSONString(writer: anytype, s: []const u8) !void {
    // Writes `s` as a double-quoted JSON string. Quote, backslash, and the
    // common whitespace characters get two-character escapes; any other
    // control byte (< 0x20) becomes a \u00XX escape per RFC 8259.
    const hex_chars = "0123456789abcdef";
    try writer.writeAll("\"");
    for (s) |byte| {
        switch (byte) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            // Remaining control bytes (excluding \t \n \r handled above).
            0x00...0x08, 0x0B, 0x0C, 0x0E...0x1F => {
                const esc = [_]u8{ '\\', 'u', '0', '0', hex_chars[byte >> 4], hex_chars[byte & 0x0F] };
                try writer.writeAll(&esc);
            },
            else => try writer.writeAll(&[_]u8{byte}),
        }
    }
    try writer.writeAll("\"");
}
fn hexDigit(v: u8) u8 {
    // Maps a nibble value (0-15) to its lowercase hex character.
    return switch (v) {
        0...9 => '0' + v,
        else => 'a' + (v - 10),
    };
}
fn printUsage() !void {
    // Usage/help text for `ml find`. Each entry already carries its own
    // newline(s), so lines are printed verbatim via a single "{s}" format.
    const help = [_][]const u8{
        "Usage: ml find [query] [options]\n",
        "\nSearch experiments by:\n",
        "  Query (free text):    ml find \"hypothesis: warmup\"\n",
        "  Tags:                 ml find --tag ablation\n",
        "  Outcome:              ml find --outcome validates\n",
        "  Dataset:              ml find --dataset imagenet\n",
        "  Experiment group:     ml find --experiment-group lr-scaling\n",
        "  Author:               ml find --author user@lab.edu\n",
        "  Time range:           ml find --after 2024-01-01 --before 2024-03-01\n",
        "\nOptions:\n",
        "  --limit <n>           Max results (default: 20)\n",
        "  --json                Output as JSON\n",
        "  --csv                 Output as CSV\n",
        "  --help, -h            Show this help\n",
        "\nExamples:\n",
        "  ml find --tag ablation --outcome validates\n",
        "  ml find --experiment-group batch-scaling --json\n",
        "  ml find \"learning rate\" --after 2024-01-01\n",
    };
    for (help) |line| {
        colors.printInfo("{s}", .{line});
    }
}

View file

@ -0,0 +1,314 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const manifest = @import("../utils/manifest.zig");
/// Entry point for `ml outcome`: handles `outcome set <target>` by parsing
/// outcome fields from `argv`, locating the run's manifest locally to resolve
/// the job name, and sending a narrative patch to the server over WebSocket.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    const sub = argv[0];
    if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) {
        try printUsage();
        return;
    }
    // "set" is currently the only subcommand.
    if (!std.mem.eql(u8, sub, "set")) {
        colors.printError("Unknown subcommand: {s}\n", .{sub});
        try printUsage();
        return error.InvalidArgs;
    }
    if (argv.len < 2) {
        try printUsage();
        return error.InvalidArgs;
    }
    // Target may be a run id, job name, or a filesystem path (resolved below
    // via the manifest helper).
    const target = argv[1];
    var outcome_status: ?[]const u8 = null;
    var outcome_summary: ?[]const u8 = null;
    // Repeatable flags (--learning/--next-step/--surprise) accumulate.
    var learnings = std.ArrayList([]const u8).empty;
    defer learnings.deinit(allocator);
    var next_steps = std.ArrayList([]const u8).empty;
    defer next_steps.deinit(allocator);
    var validation_status: ?[]const u8 = null;
    var surprises = std.ArrayList([]const u8).empty;
    defer surprises.deinit(allocator);
    var base_override: ?[]const u8 = null;
    var json_mode: bool = false;
    var i: usize = 2;
    while (i < argv.len) : (i += 1) {
        const a = argv[i];
        if (std.mem.eql(u8, a, "--outcome") or std.mem.eql(u8, a, "--status")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            outcome_status = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--summary")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            outcome_summary = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--learning")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            try learnings.append(allocator, argv[i + 1]);
            i += 1;
        } else if (std.mem.eql(u8, a, "--next-step")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            try next_steps.append(allocator, argv[i + 1]);
            i += 1;
        } else if (std.mem.eql(u8, a, "--validation-status")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            validation_status = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--surprise")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            try surprises.append(allocator, argv[i + 1]);
            i += 1;
        } else if (std.mem.eql(u8, a, "--base")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            base_override = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--json")) {
            json_mode = true;
        } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
            try printUsage();
            return;
        } else {
            colors.printError("Unknown option: {s}\n", .{a});
            return error.InvalidArgs;
        }
    }
    // Require at least one outcome field; otherwise the patch would be empty.
    if (outcome_status == null and outcome_summary == null and learnings.items.len == 0 and
        next_steps.items.len == 0 and validation_status == null and surprises.items.len == 0)
    {
        colors.printError("No outcome fields provided.\n", .{});
        return error.InvalidArgs;
    }
    // Validate outcome status if provided
    if (outcome_status) |os| {
        // NOTE(review): "inconclusive-partial" is accepted here but is not
        // listed in the error message or in the help text — confirm whether
        // that value is intentional.
        const valid = std.mem.eql(u8, os, "validates") or
            std.mem.eql(u8, os, "refutes") or
            std.mem.eql(u8, os, "inconclusive") or
            std.mem.eql(u8, os, "partial") or
            std.mem.eql(u8, os, "inconclusive-partial");
        if (!valid) {
            colors.printError("Invalid outcome status: {s}. Must be one of: validates, refutes, inconclusive, partial\n", .{os});
            return error.InvalidArgs;
        }
    }
    // Validate validation status if provided
    if (validation_status) |vs| {
        // NOTE(review): the empty string is accepted as a valid status,
        // presumably to allow clearing the field — confirm.
        const valid = std.mem.eql(u8, vs, "validates") or
            std.mem.eql(u8, vs, "refutes") or
            std.mem.eql(u8, vs, "inconclusive") or
            std.mem.eql(u8, vs, "");
        if (!valid) {
            colors.printError("Invalid validation status: {s}. Must be one of: validates, refutes, inconclusive\n", .{vs});
            return error.InvalidArgs;
        }
    }
    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }
    // Resolve the run's manifest on disk; --base overrides the configured
    // worker base directory used for the search.
    const resolved_base = base_override orelse cfg.worker_base;
    const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
        if (err == error.FileNotFound) {
            colors.printError(
                "Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
                .{target},
            );
        }
        return err;
    };
    defer allocator.free(manifest_path);
    const job_name = try manifest.readJobNameFromManifest(allocator, manifest_path);
    defer allocator.free(job_name);
    // Serialize the provided fields into a {"narrative":{...}} patch.
    const patch_json = try buildOutcomePatchJSON(
        allocator,
        outcome_status,
        outcome_summary,
        learnings.items,
        next_steps.items,
        validation_status,
        surprises.items,
    );
    defer allocator.free(patch_json);
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();
    try client.sendSetRunNarrative(job_name, patch_json, api_key_hash);
    if (json_mode) {
        // Machine-readable output: echo a small JSON status object.
        const msg = try client.receiveMessage(allocator);
        defer allocator.free(msg);
        const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch {
            // Undecodable response: dump the raw payload so nothing is lost.
            var out = io.stdoutWriter();
            try out.print("{s}\n", .{msg});
            return error.InvalidPacket;
        };
        defer packet.deinit(allocator);
        if (packet.packet_type == .success) {
            var out = io.stdoutWriter();
            try out.print("{{\"success\":true,\"job_name\":\"{s}\"}}\n", .{job_name});
        } else if (packet.packet_type == .error_packet) {
            var out = io.stdoutWriter();
            try out.print("{{\"success\":false,\"error\":\"{s}\"}}\n", .{packet.error_message orelse "unknown"});
        } else {
            var out = io.stdoutWriter();
            try out.print("{{\"success\":true,\"job_name\":\"{s}\",\"response\":\"{s}\"}}\n", .{ job_name, packet.success_message orelse "ok" });
        }
    } else {
        try client.receiveAndHandleResponse(allocator, "Outcome set");
    }
}
fn printUsage() !void {
    // Usage/help text for `ml outcome`. Each entry already carries its own
    // newline(s), so lines are printed verbatim via a single "{s}" format.
    const help = [_][]const u8{
        "Usage: ml outcome set <run-id|job-name|path> [options]\n",
        "\nPost-Run Outcome Capture:\n",
        "  --outcome <status>        Outcome: validates|refutes|inconclusive|partial\n",
        "  --summary <text>          Summary of results\n",
        "  --learning <text>         A learning from this run (can repeat)\n",
        "  --next-step <text>        Suggested next step (can repeat)\n",
        "  --validation-status <st>  Did results validate hypothesis? validates|refutes|inconclusive\n",
        "  --surprise <text>         Unexpected finding (can repeat)\n",
        "\nOptions:\n",
        "  --base <path>             Base path to search for run_manifest.json\n",
        "  --json                    Output JSON response\n",
        "  --help, -h                Show this help\n",
        "\nExamples:\n",
        "  ml outcome set run_abc --outcome validates --summary \"Accuracy +0.02\"\n",
        "  ml outcome set run_abc --learning \"LR scaling worked\" --learning \"GPU util 95%\"\n",
        "  ml outcome set run_abc --outcome validates --next-step \"Try batch=96\"\n",
    };
    for (help) |line| {
        colors.printInfo("{s}", .{line});
    }
}
fn buildOutcomePatchJSON(
    allocator: std.mem.Allocator,
    outcome_status: ?[]const u8,
    outcome_summary: ?[]const u8,
    learnings: [][]const u8,
    next_steps: [][]const u8,
    validation_status: ?[]const u8,
    surprises: [][]const u8,
) ![]u8 {
    // Serializes the provided outcome fields into a {"narrative":{...}} JSON
    // patch; fields left null/empty are omitted. Caller owns the returned
    // slice.
    // Fix: `outcome` and `validation_status` are now escaped via
    // writeJSONString like every other string field here. They were
    // previously interpolated raw, which would emit invalid JSON if a value
    // containing quotes/backslashes ever reached this function (it relies on
    // caller-side validation otherwise).
    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(allocator);
    const writer = buf.writer(allocator);
    try writer.writeAll("{\"narrative\":");
    try writer.writeAll("{");
    var first = true;
    if (outcome_status) |os| {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"outcome\":");
        try writeJSONString(writer, os);
    }
    if (outcome_summary) |os| {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"outcome_summary\":");
        try writeJSONString(writer, os);
    }
    if (learnings.len > 0) {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"learnings\":[");
        for (learnings, 0..) |learning, idx| {
            if (idx > 0) try writer.writeAll(",");
            try writeJSONString(writer, learning);
        }
        try writer.writeAll("]");
    }
    if (next_steps.len > 0) {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"next_steps\":[");
        for (next_steps, 0..) |step, idx| {
            if (idx > 0) try writer.writeAll(",");
            try writeJSONString(writer, step);
        }
        try writer.writeAll("]");
    }
    if (validation_status) |vs| {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"validation_status\":");
        try writeJSONString(writer, vs);
    }
    if (surprises.len > 0) {
        if (!first) try writer.writeAll(",");
        first = false;
        try writer.writeAll("\"surprises\":[");
        for (surprises, 0..) |surprise, idx| {
            if (idx > 0) try writer.writeAll(",");
            try writeJSONString(writer, surprise);
        }
        try writer.writeAll("]");
    }
    try writer.writeAll("}}");
    return buf.toOwnedSlice(allocator);
}
fn writeJSONString(writer: anytype, s: []const u8) !void {
    // Writes `s` as a double-quoted JSON string. Quote, backslash, and the
    // common whitespace characters get two-character escapes; any other
    // control byte (< 0x20) becomes a \u00XX escape per RFC 8259.
    const hex_chars = "0123456789abcdef";
    try writer.writeAll("\"");
    for (s) |byte| {
        switch (byte) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            // Remaining control bytes (excluding \t \n \r handled above).
            0x00...0x08, 0x0B, 0x0C, 0x0E...0x1F => {
                const esc = [_]u8{ '\\', 'u', '0', '0', hex_chars[byte >> 4], hex_chars[byte & 0x0F] };
                try writer.writeAll(&esc);
            },
            else => try writer.writeAll(&[_]u8{byte}),
        }
    }
    try writer.writeAll("\"");
}
fn hexDigit(v: u8) u8 {
    // Maps a nibble value (0-15) to its lowercase hex character.
    return switch (v) {
        0...9 => '0' + v,
        else => 'a' + (v - 10),
    };
}

View file

@ -42,6 +42,17 @@ pub const QueueOptions = struct {
memory: u8 = 8,
gpu: u8 = 0,
gpu_memory: ?[]const u8 = null,
// Narrative fields for research context
hypothesis: ?[]const u8 = null,
context: ?[]const u8 = null,
intent: ?[]const u8 = null,
expected_outcome: ?[]const u8 = null,
experiment_group: ?[]const u8 = null,
tags: ?[]const u8 = null,
// Sandboxing options
network_mode: ?[]const u8 = null,
read_only: bool = false,
secrets: std.ArrayList([]const u8),
};
fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
@ -122,7 +133,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
.dry_run = config.default_dry_run,
.validate = config.default_validate,
.json = config.default_json,
.secrets = std.ArrayList([]const u8).empty,
};
defer options.secrets.deinit(allocator);
priority = config.default_priority;
// Tracking configuration
@ -254,6 +267,32 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
} else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) {
note_override = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--hypothesis") and i + 1 < pre.len) {
options.hypothesis = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--context") and i + 1 < pre.len) {
options.context = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--intent") and i + 1 < pre.len) {
options.intent = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--expected-outcome") and i + 1 < pre.len) {
options.expected_outcome = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--experiment-group") and i + 1 < pre.len) {
options.experiment_group = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--tags") and i + 1 < pre.len) {
options.tags = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--network") and i + 1 < pre.len) {
options.network_mode = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--read-only")) {
options.read_only = true;
} else if (std.mem.eql(u8, arg, "--secret") and i + 1 < pre.len) {
try options.secrets.append(allocator, pre[i + 1]);
i += 1;
}
} else {
// This is a job name
@ -378,6 +417,10 @@ fn queueSingleJob(
};
defer if (commit_override == null) allocator.free(commit_id);
// Build narrative JSON if any narrative fields are set
const narrative_json = buildNarrativeJson(allocator, options) catch null;
defer if (narrative_json) |j| allocator.free(j);
const config = try Config.load(allocator);
defer {
var mut_config = config;
@ -407,13 +450,43 @@ fn queueSingleJob(
return error.InvalidArgs;
}
if (tracking_json.len > 0) {
// Build combined metadata JSON with tracking and/or narrative
const combined_json = blk: {
if (tracking_json.len > 0 and narrative_json != null) {
// Merge tracking and narrative
var buf = std.ArrayList(u8).empty;
defer buf.deinit(allocator);
const writer = buf.writer(allocator);
try writer.writeAll("{");
try writer.writeAll(tracking_json[1 .. tracking_json.len - 1]); // Remove outer braces
try writer.writeAll(",");
try writer.writeAll("\"narrative\":");
try writer.writeAll(narrative_json.?);
try writer.writeAll("}");
break :blk try buf.toOwnedSlice(allocator);
} else if (tracking_json.len > 0) {
break :blk try allocator.dupe(u8, tracking_json);
} else if (narrative_json) |nj| {
var buf = std.ArrayList(u8).empty;
defer buf.deinit(allocator);
const writer = buf.writer(allocator);
try writer.writeAll("{\"narrative\":");
try writer.writeAll(nj);
try writer.writeAll("}");
break :blk try buf.toOwnedSlice(allocator);
} else {
break :blk "";
}
};
defer if (combined_json.len > 0 and combined_json.ptr != tracking_json.ptr) allocator.free(combined_json);
if (combined_json.len > 0) {
try client.sendQueueJobWithTrackingAndResources(
job_name,
commit_id,
priority,
api_key_hash,
tracking_json,
combined_json,
options.cpu,
options.memory,
options.gpu,
@ -547,6 +620,13 @@ fn printUsage() !void {
colors.printInfo(" --args <string> Extra runner args (sent to worker as task.Args)\n", .{});
colors.printInfo(" --note <string> Human notes (stored in run manifest as metadata.note)\n", .{});
colors.printInfo(" -- <args...> Extra runner args (alternative to --args)\n", .{});
colors.printInfo("\nResearch Narrative:\n", .{});
colors.printInfo(" --hypothesis <text> Research hypothesis being tested\n", .{});
colors.printInfo(" --context <text> Background context for this experiment\n", .{});
colors.printInfo(" --intent <text> What you're trying to accomplish\n", .{});
colors.printInfo(" --expected-outcome <text> What you expect to happen\n", .{});
colors.printInfo(" --experiment-group <name> Group related experiments\n", .{});
colors.printInfo(" --tags <csv> Comma-separated tags (e.g., ablation,batch-size)\n", .{});
colors.printInfo("\nSpecial Modes:\n", .{});
colors.printInfo(" --dry-run Show what would be queued\n", .{});
colors.printInfo(" --validate Validate experiment without queuing\n", .{});
@ -561,11 +641,19 @@ fn printUsage() !void {
colors.printInfo(" --wandb-project <prj> Set Wandb project\n", .{});
colors.printInfo(" --wandb-entity <ent> Set Wandb entity\n", .{});
colors.printInfo("\nSandboxing:\n", .{});
colors.printInfo(" --network <mode> Network mode: none, bridge, slirp4netns\n", .{});
colors.printInfo(" --read-only Mount root filesystem as read-only\n", .{});
colors.printInfo(" --secret <name> Inject secret as env var (can repeat)\n", .{});
colors.printInfo("\nExamples:\n", .{});
colors.printInfo(" ml queue my_job # Queue a job\n", .{});
colors.printInfo(" ml queue my_job --dry-run # Preview submission\n", .{});
colors.printInfo(" ml queue my_job --validate # Validate locally\n", .{});
colors.printInfo(" ml status --watch # Watch queue + prewarm\n", .{});
colors.printInfo("\nResearch Examples:\n", .{});
colors.printInfo(" ml queue train.py --hypothesis \"LR scaling improves convergence\" \\\n", .{});
colors.printInfo(" --context \"Following paper XYZ\" --tags ablation,lr-scaling\n", .{});
}
pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 {
@ -597,13 +685,23 @@ fn explainJob(
commit_display = enc;
}
// Build narrative JSON for display
const narrative_json = buildNarrativeJson(allocator, options) catch null;
defer if (narrative_json) |j| allocator.free(j);
if (options.json) {
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
var buffer: [4096]u8 = undefined;
const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"explain\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable;
try stdout_file.writeAll(formatted);
try writeJSONNullableString(&stdout_file, options.gpu_memory);
try stdout_file.writeAll("}}\n");
if (narrative_json) |nj| {
try stdout_file.writeAll("},\"narrative\":");
try stdout_file.writeAll(nj);
try stdout_file.writeAll("}\n");
} else {
try stdout_file.writeAll("}}\n");
}
return;
} else {
colors.printInfo("Job Explanation:\n", .{});
@ -616,7 +714,30 @@ fn explainJob(
colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu});
colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
colors.printInfo(" Action: Job would be queued for execution\n", .{});
// Display narrative if provided
if (narrative_json != null) {
colors.printInfo("\n Research Narrative:\n", .{});
if (options.hypothesis) |h| {
colors.printInfo(" Hypothesis: {s}\n", .{h});
}
if (options.context) |c| {
colors.printInfo(" Context: {s}\n", .{c});
}
if (options.intent) |i| {
colors.printInfo(" Intent: {s}\n", .{i});
}
if (options.expected_outcome) |eo| {
colors.printInfo(" Expected Outcome: {s}\n", .{eo});
}
if (options.experiment_group) |eg| {
colors.printInfo(" Experiment Group: {s}\n", .{eg});
}
if (options.tags) |t| {
colors.printInfo(" Tags: {s}\n", .{t});
}
}
colors.printInfo("\n Action: Job would be queued for execution\n", .{});
}
}
@ -689,13 +810,23 @@ fn dryRunJob(
commit_display = enc;
}
// Build narrative JSON for display
const narrative_json = buildNarrativeJson(allocator, options) catch null;
defer if (narrative_json) |j| allocator.free(j);
if (options.json) {
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
var buffer: [4096]u8 = undefined;
const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"dry_run\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable;
try stdout_file.writeAll(formatted);
try writeJSONNullableString(&stdout_file, options.gpu_memory);
try stdout_file.writeAll("}},\"would_queue\":true}}\n");
if (narrative_json) |nj| {
try stdout_file.writeAll("},\"narrative\":");
try stdout_file.writeAll(nj);
try stdout_file.writeAll(",\"would_queue\":true}}\n");
} else {
try stdout_file.writeAll("},\"would_queue\":true}}\n");
}
return;
} else {
colors.printInfo("Dry Run - Job Queue Preview:\n", .{});
@ -708,7 +839,30 @@ fn dryRunJob(
colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu});
colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
colors.printInfo(" Action: Would queue job\n", .{});
// Display narrative if provided
if (narrative_json != null) {
colors.printInfo("\n Research Narrative:\n", .{});
if (options.hypothesis) |h| {
colors.printInfo(" Hypothesis: {s}\n", .{h});
}
if (options.context) |c| {
colors.printInfo(" Context: {s}\n", .{c});
}
if (options.intent) |i| {
colors.printInfo(" Intent: {s}\n", .{i});
}
if (options.expected_outcome) |eo| {
colors.printInfo(" Expected Outcome: {s}\n", .{eo});
}
if (options.experiment_group) |eg| {
colors.printInfo(" Experiment Group: {s}\n", .{eg});
}
if (options.tags) |t| {
colors.printInfo(" Tags: {s}\n", .{t});
}
}
colors.printInfo("\n Action: Would queue job\n", .{});
colors.printInfo(" Estimated queue time: 2-5 minutes\n", .{});
colors.printSuccess(" ✓ Dry run completed - no job was actually queued\n", .{});
}
@ -923,3 +1077,71 @@ fn handleDuplicateResponse(
/// Map a nibble value (0-15) to its lowercase ASCII hex digit.
/// Values above 15 are not checked and yield characters past 'f'.
fn hexDigit(v: u8) u8 {
    if (v < 10) return '0' + v;
    return 'a' + (v - 10);
}
// buildNarrativeJson creates a JSON object from the narrative fields of
// `options`. Returns null when no narrative field is set; otherwise the
// caller owns the returned slice and must free it with `allocator`.
// Refactored from six copy-pasted emit branches into one table-driven loop.
fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions) !?[]u8 {
    // Fixed JSON key -> optional field value table; emission order below
    // matches the original hand-written order.
    const fields = [_]struct {
        key: []const u8,
        value: ?[]const u8,
    }{
        .{ .key = "hypothesis", .value = options.hypothesis },
        .{ .key = "context", .value = options.context },
        .{ .key = "intent", .value = options.intent },
        .{ .key = "expected_outcome", .value = options.expected_outcome },
        .{ .key = "experiment_group", .value = options.experiment_group },
        .{ .key = "tags", .value = options.tags },
    };

    // No narrative field set -> no narrative object at all.
    var any_set = false;
    for (fields) |f| {
        if (f.value != null) any_set = true;
    }
    if (!any_set) return null;

    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(allocator);
    const writer = buf.writer(allocator);

    try writer.writeAll("{");
    var first = true;
    for (fields) |f| {
        const v = f.value orelse continue;
        if (!first) try writer.writeAll(",");
        first = false;
        // Keys are fixed identifiers above and need no escaping; values
        // go through the shared JSON string escaper.
        try writer.writeAll("\"");
        try writer.writeAll(f.key);
        try writer.writeAll("\":");
        try writeJSONString(writer, v);
    }
    try writer.writeAll("}");

    return try buf.toOwnedSlice(allocator);
}

View file

@ -47,6 +47,13 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
var note_override: ?[]const u8 = null;
var force: bool = false;
// New: Change tracking options
var inherit_narrative: bool = false;
var inherit_config: bool = false;
var parent_link: bool = false;
var overrides = std.ArrayList([2][]const u8).empty; // [key, value] pairs
defer overrides.deinit(allocator);
var i: usize = 0;
while (i < pre.len) : (i += 1) {
const a = pre[i];
@ -76,6 +83,18 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
i += 1;
} else if (std.mem.eql(u8, a, "--force")) {
force = true;
} else if (std.mem.eql(u8, a, "--inherit-narrative")) {
inherit_narrative = true;
} else if (std.mem.eql(u8, a, "--inherit-config")) {
inherit_config = true;
} else if (std.mem.eql(u8, a, "--parent")) {
parent_link = true;
} else if (std.mem.startsWith(u8, a, "--") and std.mem.indexOf(u8, a, "=") != null) {
// Key=value override: --lr=0.002 or --batch-size=128
const eq_idx = std.mem.indexOf(u8, a, "=").?;
const key = a[2..eq_idx];
const value = a[eq_idx + 1 ..];
try overrides.append(allocator, [2][]const u8{ key, value });
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
try printUsage();
return;
@ -100,6 +119,11 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
const args_final: []const u8 = if (args_override) |a| a else args_joined;
const note_final: []const u8 = if (note_override) |n| n else "";
// Read original manifest for inheritance
var original_narrative: ?std.json.ObjectMap = null;
var original_config: ?std.json.ObjectMap = null;
var parent_run_id: ?[]const u8 = null;
// Target can be:
// - commit_id (40-hex) or commit_id prefix (>=7 hex) resolvable under worker_base
// - run_id/task_id/path (resolved to run_manifest.json to read commit_id)
@ -111,6 +135,42 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
var commit_bytes_allocated = false;
defer if (commit_bytes_allocated) allocator.free(commit_bytes);
// If we need to inherit narrative or config, or link parent, read the manifest first
if (inherit_narrative or inherit_config or parent_link) {
const manifest_path = try manifest.resolvePathWithBase(allocator, target, cfg.worker_base);
defer allocator.free(manifest_path);
const data = try manifest.readFileAlloc(allocator, manifest_path);
defer allocator.free(data);
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
defer parsed.deinit();
if (parsed.value == .object) {
const root = parsed.value.object;
if (inherit_narrative) {
if (root.get("narrative")) |narr| {
if (narr == .object) {
original_narrative = try narr.object.clone();
}
}
}
if (inherit_config) {
if (root.get("metadata")) |meta| {
if (meta == .object) {
original_config = try meta.object.clone();
}
}
}
if (parent_link) {
parent_run_id = json.getString(root, "run_id") orelse json.getString(root, "id");
}
}
}
if (target.len >= 7 and target.len <= 40 and isHexLowerOrUpper(target)) {
if (target.len == 40) {
commit_hex = target;
@ -172,6 +232,52 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
commit_bytes_allocated = true;
}
// Build tracking JSON with narrative/config inheritance and overrides
const tracking_json = blk: {
if (inherit_narrative or inherit_config or parent_link or overrides.items.len > 0) {
var buf = std.ArrayList(u8).empty;
defer buf.deinit(allocator);
const writer = buf.writer(allocator);
try writer.writeAll("{");
var first = true;
// Add narrative if inherited
if (original_narrative) |narr| {
if (!first) try writer.writeAll(",");
first = false;
try writer.writeAll("\"narrative\":");
try writeJSONValue(writer, std.json.Value{ .object = narr });
}
// Add parent relationship
if (parent_run_id) |pid| {
if (!first) try writer.writeAll(",");
first = false;
try writer.writeAll("\"parent_run\":\"");
try writer.writeAll(pid);
try writer.writeAll("\"");
}
// Add overrides
if (overrides.items.len > 0) {
if (!first) try writer.writeAll(",");
first = false;
try writer.writeAll("\"overrides\":{");
for (overrides.items, 0..) |pair, idx| {
if (idx > 0) try writer.writeAll(",");
try writer.print("\"{s}\":\"{s}\"", .{ pair[0], pair[1] });
}
try writer.writeAll("}");
}
try writer.writeAll("}");
break :blk try buf.toOwnedSlice(allocator);
}
break :blk "";
};
defer if (tracking_json.len > 0) allocator.free(tracking_json);
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
@ -181,7 +287,20 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
if (note_final.len > 0) {
// Send with tracking JSON if we have inheritance/overrides
if (tracking_json.len > 0) {
try client.sendQueueJobWithTrackingAndResources(
job_name,
commit_bytes,
priority,
api_key_hash,
tracking_json,
cpu,
memory,
gpu,
gpu_memory,
);
} else if (note_final.len > 0) {
try client.sendQueueJobWithArgsNoteAndResources(
job_name,
commit_bytes,
@ -287,7 +406,83 @@ fn handleDuplicateResponse(
/// Print usage/help for `ml requeue` to the CLI info channel.
/// The first synopsis line lists every flag inline; the second is the short
/// form, followed by options grouped by category (resources, inheritance,
/// overrides) and examples.
fn printUsage() !void {
    colors.printInfo("Usage:\n", .{});
    colors.printInfo(" ml requeue <commit_id|run_id|task_id|path> [--name <job>] [--priority <n>] [--cpu <n>] [--memory <gb>] [--gpu <n>] [--gpu-memory <gb>] [--args <string>] [--note <string>] [--force] -- <args...>\n", .{});
    colors.printInfo(" ml requeue <commit_id|run_id|task_id|path> [options] -- <args...>\n\n", .{});
    colors.printInfo("Resource Options:\n", .{});
    colors.printInfo(" --name <job> Override job name\n", .{});
    colors.printInfo(" --priority <n> Set priority (0-255)\n", .{});
    colors.printInfo(" --cpu <n> CPU cores\n", .{});
    colors.printInfo(" --memory <gb> Memory in GB\n", .{});
    colors.printInfo(" --gpu <n> GPU count\n", .{});
    colors.printInfo(" --gpu-memory <gb> GPU memory budget\n", .{});
    colors.printInfo("\nInheritance Options:\n", .{});
    colors.printInfo(" --inherit-narrative Copy hypothesis/context/intent from parent\n", .{});
    colors.printInfo(" --inherit-config Copy metadata/config from parent\n", .{});
    colors.printInfo(" --parent Link as child run\n", .{});
    colors.printInfo("\nOverride Options:\n", .{});
    colors.printInfo(" --key=value Override specific config (e.g., --lr=0.002)\n", .{});
    colors.printInfo(" --args <string> Override runner args\n", .{});
    colors.printInfo(" --note <string> Add human note\n", .{});
    colors.printInfo("\nOther:\n", .{});
    colors.printInfo(" --force Requeue even if duplicate exists\n", .{});
    colors.printInfo(" --help, -h Show this help\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo(" ml requeue run_abc --lr=0.002 --batch-size=128\n", .{});
    colors.printInfo(" ml requeue run_abc --inherit-narrative --parent\n", .{});
    colors.printInfo(" ml requeue run_abc --lr=0.002 --inherit-narrative\n", .{});
}
// Serialize a std.json.Value back to compact JSON text.
// Fix: object keys are now escaped exactly like string values. Previously
// keys were emitted raw via print("{s}"), so a key containing a quote,
// backslash, or control character produced invalid JSON.
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
    switch (v) {
        .null => try writer.writeAll("null"),
        .bool => |b| try writer.print("{}", .{b}),
        .integer => |i| try writer.print("{d}", .{i}),
        .float => |f| try writer.print("{d}", .{f}),
        .string => |s| {
            try writer.writeAll("\"");
            try writeEscapedString(writer, s);
            try writer.writeAll("\"");
        },
        .array => |arr| {
            try writer.writeAll("[");
            for (arr.items, 0..) |item, idx| {
                if (idx > 0) try writer.writeAll(",");
                try writeJSONValue(writer, item);
            }
            try writer.writeAll("]");
        },
        .object => |obj| {
            try writer.writeAll("{");
            var first = true;
            var it = obj.iterator();
            while (it.next()) |entry| {
                if (!first) try writer.writeAll(",");
                first = false;
                try writer.writeAll("\"");
                try writeEscapedString(writer, entry.key_ptr.*);
                try writer.writeAll("\":");
                try writeJSONValue(writer, entry.value_ptr.*);
            }
            try writer.writeAll("}");
        },
        // Preserve the exact textual representation of number strings.
        .number_string => |s| try writer.print("{s}", .{s}),
    }
}
/// Write `s` to `writer` with JSON string escaping applied: quotes,
/// backslashes, and the common whitespace escapes get two-character forms,
/// remaining control bytes (< 0x20) become \u00XX, everything else passes
/// through unchanged. Does not emit the surrounding quotes.
fn writeEscapedString(writer: anytype, s: []const u8) !void {
    for (s) |ch| switch (ch) {
        '"' => try writer.writeAll("\\\""),
        '\\' => try writer.writeAll("\\\\"),
        '\n' => try writer.writeAll("\\n"),
        '\r' => try writer.writeAll("\\r"),
        '\t' => try writer.writeAll("\\t"),
        else => if (ch < 0x20)
            try writer.print("\\u00{x:0>2}", .{ch})
        else
            try writer.writeAll(&[_]u8{ch}),
    };
}
fn isHexLowerOrUpper(s: []const u8) bool {

View file

@ -44,6 +44,11 @@ pub fn main() !void {
},
        'n' => if (std.mem.eql(u8, command, "narrative")) {
            try @import("commands/narrative.zig").run(allocator, args[2..]);
        // FIXME(review): this branch looks unreachable — the sibling prongs
        // ('p', 's', 'e', 'c', 'f', 'v') suggest the switch dispatches on the
        // command's first character, and "outcome" starts with 'o', not 'n'.
        // If so, `ml outcome` never runs; move this to an 'o' prong. Confirm
        // against the switch header above this hunk.
        } else if (std.mem.eql(u8, command, "outcome")) {
            try @import("commands/outcome.zig").run(allocator, args[2..]);
        },
'p' => if (std.mem.eql(u8, command, "privacy")) {
try @import("commands/privacy.zig").run(allocator, args[2..]);
},
's' => if (std.mem.eql(u8, command, "sync")) {
if (args.len < 3) {
@ -65,9 +70,16 @@ pub fn main() !void {
},
'e' => if (std.mem.eql(u8, command, "experiment")) {
try @import("commands/experiment.zig").execute(allocator, args[2..]);
} else if (std.mem.eql(u8, command, "export")) {
try @import("commands/export_cmd.zig").run(allocator, args[2..]);
},
'c' => if (std.mem.eql(u8, command, "cancel")) {
try @import("commands/cancel.zig").run(allocator, args[2..]);
} else if (std.mem.eql(u8, command, "compare")) {
try @import("commands/compare.zig").run(allocator, args[2..]);
},
'f' => if (std.mem.eql(u8, command, "find")) {
try @import("commands/find.zig").run(allocator, args[2..]);
},
'v' => if (std.mem.eql(u8, command, "validate")) {
try @import("commands/validate.zig").run(allocator, args[2..]);
@ -91,7 +103,12 @@ fn printUsage() void {
std.debug.print(" jupyter Jupyter workspace management\n", .{});
std.debug.print(" init Setup configuration interactively\n", .{});
std.debug.print(" annotate <path|id> Add an annotation to run_manifest.json (--note \"...\")\n", .{});
std.debug.print(" compare <a> <b> Compare two runs (show differences)\n", .{});
std.debug.print(" export <id> Export experiment bundle (--anonymize for safe sharing)\n", .{});
std.debug.print(" find [query] Search experiments by tags/outcome/dataset\n", .{});
std.debug.print(" narrative set <path|id> Set run narrative fields (hypothesis/context/...)\n", .{});
std.debug.print(" outcome set <path|id> Set post-run outcome (validates/refutes/inconclusive)\n", .{});
std.debug.print(" privacy set <path|id> Set experiment privacy level (private/team/public)\n", .{});
std.debug.print(" info <path|id> Show run info from run_manifest.json (optionally --base <path>)\n", .{});
std.debug.print(" sync <path> Sync project to server\n", .{});
std.debug.print(" requeue <id> Re-submit from run_id/task_id/path (supports -- <args>)\n", .{});
@ -113,5 +130,10 @@ test {
_ = @import("commands/requeue.zig");
_ = @import("commands/annotate.zig");
_ = @import("commands/narrative.zig");
_ = @import("commands/outcome.zig");
_ = @import("commands/privacy.zig");
_ = @import("commands/compare.zig");
_ = @import("commands/find.zig");
_ = @import("commands/export_cmd.zig");
_ = @import("commands/logs.zig");
}

View file

@ -2,6 +2,7 @@
package jobs
import (
"context"
"encoding/binary"
"net/http"
"os"
@ -20,11 +21,14 @@ import (
// Handler provides job-related WebSocket handlers
type Handler struct {
expManager *experiment.Manager
logger *logging.Logger
queue queue.Backend
db *storage.DB
authConfig *auth.Config
expManager *experiment.Manager
logger *logging.Logger
queue queue.Backend
db *storage.DB
authConfig *auth.Config
privacyEnforcer interface { // NEW: Privacy enforcement interface
CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error)
}
}
// NewHandler creates a new jobs handler
@ -34,13 +38,17 @@ func NewHandler(
queue queue.Backend,
db *storage.DB,
authConfig *auth.Config,
privacyEnforcer interface { // NEW - can be nil
CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error)
},
) *Handler {
return &Handler{
expManager: expManager,
logger: logger,
queue: queue,
db: db,
authConfig: authConfig,
expManager: expManager,
logger: logger,
queue: queue,
db: db,
authConfig: authConfig,
privacyEnforcer: privacyEnforcer,
}
}
@ -212,10 +220,79 @@ func (h *Handler) HandleSetRunNarrative(conn *websocket.Conn, payload []byte, us
h.logger.Info("setting run narrative", "job", jobName, "bucket", bucket)
return h.sendSuccessPacket(conn, map[string]interface{}{
"success": true,
"job_name": jobName,
"narrative": patch,
})
}
// HandleSetRunPrivacy handles setting run privacy.
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_len:2][patch:var]
// The job directory is located across the four state buckets; the patch is
// echoed back under "privacy". Persisting the patch to the run manifest is
// not implemented here yet — the handler only validates and acknowledges.
func (h *Handler) HandleSetRunPrivacy(conn *websocket.Conn, payload []byte, user *auth.User) error {
	if len(payload) < 16+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run privacy payload too short", "")
	}
	offset := 16
	jobNameLen := int(payload[offset])
	offset += 1
	if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[offset : offset+jobNameLen])
	offset += jobNameLen
	patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if patchLen <= 0 || len(payload) < offset+patchLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid patch length", "")
	}
	patch := string(payload[offset : offset+patchLen])
	if err := container.ValidateJobName(jobName); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error())
	}
	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}
	jobPaths := storage.NewJobPaths(base)
	typedRoots := []struct {
		bucket string
		root   string
	}{
		{bucket: "running", root: jobPaths.RunningPath()},
		{bucket: "pending", root: jobPaths.PendingPath()},
		{bucket: "finished", root: jobPaths.FinishedPath()},
		{bucket: "failed", root: jobPaths.FailedPath()},
	}
	// Find which bucket currently holds the job's directory.
	var manifestDir, bucket string
	for _, item := range typedRoots {
		dir := filepath.Join(item.root, jobName)
		if _, err := os.Stat(dir); err == nil {
			manifestDir = dir
			bucket = item.bucket
			break
		}
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
	}
	// TODO: Check if user is owner before allowing privacy changes
	h.logger.Info("setting run privacy", "job", jobName, "bucket", bucket, "user", user.Name)
	// Fixed: the response previously also echoed the patch under a stray
	// "narrative" key (copy-paste from HandleSetRunNarrative).
	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success":  true,
		"job_name": jobName,
		"privacy":  patch,
	})
}

View file

@ -67,6 +67,12 @@ const (
// Logs opcodes
OpcodeGetLogs = 0x20
OpcodeStreamLogs = 0x21
//
OpcodeCompareRuns = 0x30
OpcodeFindRuns = 0x31
OpcodeExportRun = 0x32
OpcodeSetRunOutcome = 0x33
)
// Error codes
@ -277,6 +283,14 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
return h.handleDatasetInfo(conn, payload)
case OpcodeDatasetSearch:
return h.handleDatasetSearch(conn, payload)
case OpcodeCompareRuns:
return h.handleCompareRuns(conn, payload)
case OpcodeFindRuns:
return h.handleFindRuns(conn, payload)
case OpcodeExportRun:
return h.handleExportRun(conn, payload)
case OpcodeSetRunOutcome:
return h.handleSetRunOutcome(conn, payload)
default:
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
}
@ -538,3 +552,216 @@ func (h *Handler) RequirePermission(user *auth.User, permission string) bool {
}
return user.Admin || user.Permissions[permission]
}
// handleCompareRuns compares two runs and returns differences.
// Payload layout: [api_key_hash:16][run_a_len:1][run_a:var][run_b_len:1][run_b:var]
// Metadata and manifest lookups are best-effort: when either side cannot be
// read, the corresponding comparison fields are simply omitted from the
// response (the call still reports success).
func (h *Handler) handleCompareRuns(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "compare runs payload too short", "")
	}
	user, err := h.Authenticate(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
	}
	if !h.RequirePermission(user, PermJobsRead) {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
	}

	pos := 16
	aLen := int(payload[pos])
	pos++
	if aLen <= 0 || len(payload) < pos+aLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run A length", "")
	}
	runA := string(payload[pos : pos+aLen])
	pos += aLen
	bLen := int(payload[pos])
	pos++
	if bLen <= 0 || len(payload) < pos+bLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run B length", "")
	}
	runB := string(payload[pos : pos+bLen])

	result := map[string]any{
		"run_a":   runA,
		"run_b":   runB,
		"success": true,
	}

	// Lightweight metadata comparison, only when both reads succeed.
	metaA, errA := h.expManager.ReadMetadata(runA)
	metaB, errB := h.expManager.ReadMetadata(runB)
	if errA == nil && errB == nil {
		result["job_name_match"] = metaA.JobName == metaB.JobName
		result["user_match"] = metaA.User == metaB.User
		result["timestamp_diff"] = metaB.Timestamp - metaA.Timestamp
	}

	// Manifest comparison (content hash + file counts) when both readable.
	manifestA, _ := h.expManager.ReadManifest(runA)
	manifestB, _ := h.expManager.ReadManifest(runB)
	if manifestA != nil && manifestB != nil {
		result["overall_sha_match"] = manifestA.OverallSHA == manifestB.OverallSHA
		result["files_count_a"] = len(manifestA.Files)
		result["files_count_b"] = len(manifestB.Files)
	}

	return h.sendSuccessPacket(conn, result)
}
// handleFindRuns searches for runs based on criteria.
// Payload layout: [api_key_hash:16][query_len:2][query:var]
// The query is a JSON object of search criteria; it is currently only
// validated and logged — the result set below is a placeholder until the
// search backend lands.
func (h *Handler) handleFindRuns(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "find runs payload too short", "")
	}
	user, err := h.Authenticate(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
	}
	if !h.RequirePermission(user, PermJobsRead) {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
	}
	offset := 16
	queryLen := binary.BigEndian.Uint16(payload[offset : offset+2])
	offset += 2
	if queryLen > 0 {
		// Fixed: a declared query length larger than the payload was
		// previously ignored silently; now it is rejected like the other
		// handlers' length checks.
		if len(payload) < offset+int(queryLen) {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query length", "")
		}
		queryData := payload[offset : offset+int(queryLen)]
		var query map[string]any
		if err := json.Unmarshal(queryData, &query); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid query JSON", err.Error())
		}
		h.logger.Info("search query", "query", query, "user", user.Name)
	}
	// For now, return placeholder results
	results := []map[string]any{
		{"id": "run_abc", "job_name": "train", "outcome": "validates"},
		{"id": "run_def", "job_name": "eval", "outcome": "partial"},
	}
	return h.sendSuccessPacket(conn, map[string]any{
		"success": true,
		"results": results,
		"count":   len(results),
	})
}
// handleExportRun exports a run with optional anonymization.
// Payload layout: [api_key_hash:16][run_id_len:1][run_id:var][options_len:2][options:var]
// Options are a JSON object; only "anonymize" (bool) is read today. The
// response only acknowledges the request — the actual bundle export is not
// performed here yet.
func (h *Handler) handleExportRun(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "export run payload too short", "")
	}
	user, err := h.Authenticate(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
	}
	if !h.RequirePermission(user, PermJobsRead) {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
	}
	offset := 16
	runIDLen := int(payload[offset])
	offset++
	if runIDLen <= 0 || len(payload) < offset+runIDLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
	}
	runID := string(payload[offset : offset+runIDLen])
	offset += runIDLen
	// Parse options if present (the options section is optional).
	var options map[string]any
	if len(payload) >= offset+2 {
		optsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
		offset += 2
		if optsLen > 0 {
			if len(payload) < offset+optsLen {
				return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid options length", "")
			}
			// Fixed: the Unmarshal error was previously discarded, so a
			// malformed options blob silently behaved like "no options".
			if err := json.Unmarshal(payload[offset:offset+optsLen], &options); err != nil {
				return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid options JSON", err.Error())
			}
		}
	}
	// Check if experiment exists
	if !h.expManager.ExperimentExists(runID) {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID)
	}
	anonymize := false
	if options != nil {
		if v, ok := options["anonymize"].(bool); ok {
			anonymize = v
		}
	}
	h.logger.Info("exporting run", "run_id", runID, "anonymize", anonymize, "user", user.Name)
	return h.sendSuccessPacket(conn, map[string]any{
		"success":   true,
		"run_id":    runID,
		"message":   "Export request received",
		"anonymize": anonymize,
	})
}
// handleSetRunOutcome sets the outcome for a run.
// Payload layout: [api_key_hash:16][run_id_len:1][run_id:var][outcome_data_len:2][outcome_data:var]
// The outcome data is a JSON object whose "outcome" field must be one of the
// accepted statuses; the handler currently validates and acknowledges only.
// NOTE(review): the vocabulary accepted here (validates/refutes/
// inconclusive/partial) differs from the manifest validator's
// ValidOutcomeStatuses (validated/invalidated/inconclusive/partial) —
// confirm which set is canonical before outcomes are persisted.
func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run outcome payload too short", "")
	}
	user, err := h.Authenticate(payload)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error())
	}
	if !h.RequirePermission(user, PermJobsUpdate) {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
	}

	pos := 16
	idLen := int(payload[pos])
	pos++
	if idLen <= 0 || len(payload) < pos+idLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
	}
	runID := string(payload[pos : pos+idLen])
	pos += idLen

	dataLen := int(binary.BigEndian.Uint16(payload[pos : pos+2]))
	pos += 2
	if dataLen == 0 || len(payload) < pos+dataLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome data", "")
	}
	var outcomeData map[string]any
	if err := json.Unmarshal(payload[pos:pos+dataLen], &outcomeData); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome JSON", err.Error())
	}

	// The "outcome" field must be a string from the accepted set.
	outcome, isString := outcomeData["outcome"].(string)
	valid := false
	if isString {
		switch outcome {
		case "validates", "refutes", "inconclusive", "partial":
			valid = true
		}
	}
	if !valid {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid outcome status", "must be: validates, refutes, inconclusive, or partial")
	}

	h.logger.Info("setting run outcome", "run_id", runID, "outcome", outcome, "user", user.Name)
	return h.sendSuccessPacket(conn, map[string]any{
		"success": true,
		"run_id":  runID,
		"outcome": outcome,
		"message": "Outcome updated",
	})
}

View file

@ -59,6 +59,15 @@ type NarrativePatch struct {
Tags *[]string `json:"tags,omitempty"`
}
// Outcome represents the documented result of a run.
// NOTE(review): the status vocabulary documented here and enforced by
// ValidateOutcome (validated/invalidated/...) differs from the WebSocket
// outcome handler's accepted set (validates/refutes/...) — confirm which is
// canonical.
type Outcome struct {
	Status        string   `json:"status,omitempty"`         // validated, invalidated, inconclusive, partial
	Summary       string   `json:"summary,omitempty"`        // Brief description
	KeyLearnings  []string `json:"key_learnings,omitempty"`  // 3-5 bullet points max
	FollowUpRuns  []string `json:"follow_up_runs,omitempty"` // References to related runs
	ArtifactsUsed []string `json:"artifacts_used,omitempty"` // e.g., ["model.pt", "metrics.json"]
}
type ArtifactFile struct {
Path string `json:"path"`
SizeBytes int64 `json:"size_bytes"`
@ -83,6 +92,7 @@ type RunManifest struct {
Annotations []Annotation `json:"annotations,omitempty"`
Narrative *Narrative `json:"narrative,omitempty"`
Outcome *Outcome `json:"outcome,omitempty"`
Artifacts *Artifacts `json:"artifacts,omitempty"`
CommitID string `json:"commit_id,omitempty"`

View file

@ -3,6 +3,9 @@ package manifest
import (
"errors"
"fmt"
"strings"
"github.com/jfraeys/fetch_ml/internal/privacy"
)
// ErrIncompleteManifest is returned when a required manifest field is missing.
@ -157,3 +160,142 @@ func (v *Validator) validateField(m *RunManifest, field string) *ValidationError
func IsValidationError(err error) bool {
return errors.Is(err, ErrIncompleteManifest)
}
// NarrativeValidation contains validation results for a Narrative.
// Errors are hard failures, Warnings are advisory, and PIIFindings lists
// any potential PII detected in the free-text fields.
type NarrativeValidation struct {
	Warnings    []string             `json:"warnings,omitempty"`
	Errors      []string             `json:"errors,omitempty"`
	PIIFindings []privacy.PIIFinding `json:"pii_findings,omitempty"`
}
// OutcomeValidation contains validation results for an Outcome.
// Errors are hard failures; Warnings are advisory.
type OutcomeValidation struct {
	Warnings []string `json:"warnings,omitempty"`
	Errors   []string `json:"errors,omitempty"`
}
// Valid outcome statuses. The empty string is accepted so that an outcome
// without an explicit status still validates.
// NOTE(review): this vocabulary (validated/invalidated/...) differs from the
// WebSocket outcome handler's accepted set (validates/refutes/...) — confirm
// which is canonical.
var ValidOutcomeStatuses = []string{
	"validated", "invalidated", "inconclusive", "partial", "",
}

// isValidOutcomeStatus reports whether status appears in
// ValidOutcomeStatuses (linear scan over the small fixed list).
func isValidOutcomeStatus(status string) bool {
	for _, s := range ValidOutcomeStatuses {
		if s == status {
			return true
		}
	}
	return false
}
// ValidateNarrative validates a Narrative struct. A nil narrative is valid
// and yields empty results. Errors are hard failures; Warnings are advisory.
// Fixed: the PII scan now walks the text fields in a fixed order — the
// previous map iteration made the warning order nondeterministic between runs.
func ValidateNarrative(n *Narrative) NarrativeValidation {
	result := NarrativeValidation{
		Warnings: make([]string, 0),
		Errors:   make([]string, 0),
	}
	if n == nil {
		return result
	}
	// Hypothesis: hard cap at 5000 chars, advisory above 1000.
	if len(n.Hypothesis) > 5000 {
		result.Errors = append(result.Errors, "hypothesis exceeds 5000 characters")
	} else if len(n.Hypothesis) > 1000 {
		result.Warnings = append(result.Warnings, "hypothesis is very long (>1000 chars)")
	}
	// Context: hard cap at 10000 chars.
	if len(n.Context) > 10000 {
		result.Errors = append(result.Errors, "context exceeds 10000 characters")
	}
	// Tag count: hard cap at 50, advisory above 20.
	if len(n.Tags) > 50 {
		result.Errors = append(result.Errors, "too many tags (max 50)")
	} else if len(n.Tags) > 20 {
		result.Warnings = append(result.Warnings, "many tags (>20)")
	}
	// Per-tag limits: length cap plus a warning for characters that clash
	// with common tag-list separators.
	for i, tag := range n.Tags {
		if len(tag) > 50 {
			result.Errors = append(result.Errors, fmt.Sprintf("tag %d exceeds 50 characters", i))
		}
		if strings.ContainsAny(tag, ",;|/\\") {
			result.Warnings = append(result.Warnings, fmt.Sprintf("tag %d contains special characters", i))
		}
	}
	// PII scan of the free-text fields in a stable, fixed order.
	textFields := []struct {
		name string
		text string
	}{
		{"hypothesis", n.Hypothesis},
		{"context", n.Context},
		{"intent", n.Intent},
	}
	for _, f := range textFields {
		if findings := privacy.DetectPII(f.text); len(findings) > 0 {
			result.PIIFindings = append(result.PIIFindings, findings...)
			result.Warnings = append(result.Warnings, fmt.Sprintf("potential PII detected in %s field", f.name))
		}
	}
	return result
}
// ValidateOutcome validates an Outcome struct. A nil outcome is valid and
// yields empty results. Errors are hard failures; Warnings are advisory
// (length limits and PII heuristics).
func ValidateOutcome(o *Outcome) OutcomeValidation {
	out := OutcomeValidation{
		Warnings: make([]string, 0),
		Errors:   make([]string, 0),
	}
	if o == nil {
		return out
	}

	// Status must come from ValidOutcomeStatuses (empty is allowed).
	if !isValidOutcomeStatus(o.Status) {
		out.Errors = append(out.Errors, fmt.Sprintf("invalid status: %s (must be validated, invalidated, inconclusive, partial, or empty)", o.Status))
	}

	// Summary: hard cap at 1000 chars, advisory above 200.
	switch {
	case len(o.Summary) > 1000:
		out.Errors = append(out.Errors, "summary exceeds 1000 characters")
	case len(o.Summary) > 200:
		out.Warnings = append(out.Warnings, "summary is long (>200 chars)")
	}

	// Key learnings: count cap plus per-item length cap.
	if len(o.KeyLearnings) > 5 {
		out.Errors = append(out.Errors, "too many key learnings (max 5)")
	}
	for idx, item := range o.KeyLearnings {
		if len(item) > 500 {
			out.Errors = append(out.Errors, fmt.Sprintf("key learning %d exceeds 500 characters", idx))
		}
	}

	// Follow-up references: advisory only.
	if len(o.FollowUpRuns) > 10 {
		out.Warnings = append(out.Warnings, "many follow-up runs (>10)")
	}

	// PII heuristics on the free-text fields.
	if hits := privacy.DetectPII(o.Summary); len(hits) > 0 {
		out.Warnings = append(out.Warnings, "potential PII detected in summary")
	}
	for idx, item := range o.KeyLearnings {
		if hits := privacy.DetectPII(item); len(hits) > 0 {
			out.Warnings = append(out.Warnings, fmt.Sprintf("potential PII detected in key learning %d", idx))
		}
	}

	return out
}

View file

@ -0,0 +1,175 @@
package tests
import (
"encoding/json"
"os"
"os/exec"
"path/filepath"
"testing"
)
// runCLI invokes the CLI binary at cliPath with the given arguments from
// a fresh temporary working directory and returns its combined
// stdout+stderr along with the process error, if any.
func runCLI(t *testing.T, cliPath string, args ...string) (string, error) {
	t.Helper()
	proc := exec.Command(cliPath, args...)
	// Isolate every invocation in its own scratch directory so runs
	// cannot observe each other's files.
	proc.Dir = t.TempDir()
	combined, err := proc.CombinedOutput()
	return string(combined), err
}
// contains reports whether substr occurs anywhere within s. An empty
// substr is contained in every string, including the empty string.
// (Hand-rolled to keep this test package free of extra imports.)
func contains(s, substr string) bool {
	if len(substr) == 0 {
		return true
	}
	// Slide a window by repeatedly trimming the first byte of s.
	for rest := s; len(rest) >= len(substr); rest = rest[1:] {
		if rest[:len(substr)] == substr {
			return true
		}
	}
	return false
}
// TestCompareRunsE2E exercises the `ml compare` command end-to-end
// against the built CLI binary, skipping when the binary is absent.
func TestCompareRunsE2E(t *testing.T) {
	t.Parallel()
	bin := e2eCLIPath(t)
	if _, err := os.Stat(bin); os.IsNotExist(err) {
		t.Skip("CLI not built - run 'make build' first")
	}
	t.Run("CompareUsage", func(t *testing.T) {
		out, _ := runCLI(t, bin, "compare", "--help")
		if !contains(out, "Usage") {
			t.Error("expected compare --help to show usage")
		}
	})
	t.Run("CompareDummyRuns", func(t *testing.T) {
		out, _ := runCLI(t, bin, "compare", "run_abc", "run_def", "--json")
		t.Logf("Compare output: %s", out)
		var parsed map[string]any
		if json.Unmarshal([]byte(out), &parsed) != nil {
			return // non-JSON output; nothing further to inspect
		}
		if _, ok := parsed["run_a"]; ok {
			t.Log("Compare returned structured response")
		}
	})
}
// TestFindRunsE2E tests the ml find command end-to-end: its --help usage
// text and a JSON-mode search filtered by outcome status. Skips when the
// CLI binary has not been built.
func TestFindRunsE2E(t *testing.T) {
	t.Parallel()
	cliPath := e2eCLIPath(t)
	if _, err := os.Stat(cliPath); os.IsNotExist(err) {
		t.Skip("CLI not built - run 'make build' first")
	}
	t.Run("FindUsage", func(t *testing.T) {
		output, _ := runCLI(t, cliPath, "find", "--help")
		if !contains(output, "Usage") {
			t.Error("expected find --help to show usage")
		}
	})
	t.Run("FindByOutcome", func(t *testing.T) {
		// "validated" is a recognized outcome status; the original passed
		// the invalid value "validates", which can never match any run.
		output, _ := runCLI(t, cliPath, "find", "--outcome", "validated", "--json")
		t.Logf("Find output: %s", output)
		var result map[string]any
		if err := json.Unmarshal([]byte(output), &result); err == nil {
			t.Log("Find returned JSON response")
		}
	})
}
// TestExportRunE2E exercises the `ml export` command's usage output
// end-to-end, skipping when the CLI binary is absent.
func TestExportRunE2E(t *testing.T) {
	t.Parallel()
	bin := e2eCLIPath(t)
	if _, err := os.Stat(bin); os.IsNotExist(err) {
		t.Skip("CLI not built - run 'make build' first")
	}
	t.Run("ExportUsage", func(t *testing.T) {
		out, _ := runCLI(t, bin, "export", "--help")
		if !contains(out, "Usage") {
			t.Error("expected export --help to show usage")
		}
	})
}
// TestRequeueWithChangesE2E exercises the `ml requeue` command
// end-to-end: its usage text and a requeue with a parameter override.
// Skips when the CLI binary is absent.
func TestRequeueWithChangesE2E(t *testing.T) {
	t.Parallel()
	bin := e2eCLIPath(t)
	if _, err := os.Stat(bin); os.IsNotExist(err) {
		t.Skip("CLI not built - run 'make build' first")
	}
	t.Run("RequeueUsage", func(t *testing.T) {
		out, _ := runCLI(t, bin, "requeue", "--help")
		if !contains(out, "Usage") {
			t.Error("expected requeue --help to show usage")
		}
	})
	t.Run("RequeueWithOverrides", func(t *testing.T) {
		out, _ := runCLI(t, bin, "requeue", "abc123", "--lr=0.002", "--json")
		t.Logf("Requeue output: %s", out)
	})
}
// TestOutcomeSetE2E exercises the `ml outcome set` command's usage
// output end-to-end, skipping when the CLI binary is absent.
func TestOutcomeSetE2E(t *testing.T) {
	t.Parallel()
	bin := e2eCLIPath(t)
	if _, err := os.Stat(bin); os.IsNotExist(err) {
		t.Skip("CLI not built - run 'make build' first")
	}
	t.Run("OutcomeSetUsage", func(t *testing.T) {
		out, _ := runCLI(t, bin, "outcome", "set", "--help")
		if !contains(out, "Usage") {
			t.Error("expected outcome set --help to show usage")
		}
	})
}
// TestDatasetVerifyE2E tests the ml dataset verify command end-to-end:
// its usage text and a JSON-mode verification of a small fixture
// dataset. Skips when the CLI binary has not been built.
func TestDatasetVerifyE2E(t *testing.T) {
	t.Parallel()
	cliPath := e2eCLIPath(t)
	if _, err := os.Stat(cliPath); os.IsNotExist(err) {
		t.Skip("CLI not built - run 'make build' first")
	}
	t.Run("DatasetVerifyUsage", func(t *testing.T) {
		output, _ := runCLI(t, cliPath, "dataset", "verify", "--help")
		if !contains(output, "Usage") {
			t.Error("expected dataset verify --help to show usage")
		}
	})
	t.Run("DatasetVerifyTempDir", func(t *testing.T) {
		datasetDir := t.TempDir()
		// Write five DISTINCT files: the original joined the constant
		// name "file.txt" on every iteration, overwriting one file five
		// times and leaving a single-file dataset. Also fail fast if a
		// fixture write errors instead of silently ignoring it.
		for i := 0; i < 5; i++ {
			f := filepath.Join(datasetDir, "file"+string(rune('0'+i))+".txt")
			if err := os.WriteFile(f, []byte("test data"), 0644); err != nil {
				t.Fatalf("writing fixture %s: %v", f, err)
			}
		}
		output, _ := runCLI(t, cliPath, "dataset", "verify", datasetDir, "--json")
		t.Logf("Dataset verify output: %s", output)
		var result map[string]any
		if err := json.Unmarshal([]byte(output), &result); err == nil {
			if result["ok"] == true {
				t.Log("Dataset verify returned ok")
			}
		}
	})
}