feat(cli): Add metadata commands and update cancel

- note.zig: New unified metadata annotation command
  - Supports --text, --hypothesis, --outcome, --confidence, --privacy, --author
  - Stores metadata as tags in SQLite ml_tags table
- log.zig: Simplified to unified logs command (fetch/stream only)
  - Removed metric/param/tag subcommands (now in run wrapper)
  - Supports --follow for live log streaming from server
- cancel.zig: Add local process termination support
  - Sends SIGTERM first, waits 5s, then SIGKILL if needed
  - Updates run status to CANCELLED in SQLite
  - Also supports server job cancellation via WebSocket
This commit is contained in:
Jeremie Fraeys 2026-02-20 21:28:23 -05:00
parent d0c68772ea
commit f5b68cca49
No known key found for this signature in database
3 changed files with 504 additions and 83 deletions

View file

@ -1,126 +1,212 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const config = @import("../config.zig");
const db = @import("../db.zig");
const ws = @import("../net/ws/client.zig");
const crypto = @import("../utils/crypto.zig");
const logging = @import("../utils/logging.zig");
const colors = @import("../utils/colors.zig");
const auth = @import("../utils/auth.zig");
pub const CancelOptions = struct {
force: bool = false,
json: bool = false,
};
const core = @import("../core.zig");
const mode = @import("../mode.zig");
const manifest_lib = @import("../manifest.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var options = CancelOptions{};
var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
colors.printError("Failed to allocate job list: {}\n", .{err});
return err;
};
defer job_names.deinit(allocator);
var flags = core.flags.CommonFlags{};
var force = false;
var targets = std.ArrayList([]const u8).init(allocator);
defer targets.deinit();
// Parse arguments for flags and job names
// Parse arguments
var i: usize = 0;
while (i < args.len) : (i += 1) {
const arg = args[i];
if (std.mem.eql(u8, arg, "--force")) {
options.force = true;
force = true;
} else if (std.mem.eql(u8, arg, "--json")) {
options.json = true;
} else if (std.mem.startsWith(u8, arg, "--help")) {
try printUsage();
return;
flags.json = true;
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
return printUsage();
} else if (std.mem.startsWith(u8, arg, "--")) {
colors.printError("Unknown option: {s}\n", .{arg});
try printUsage();
core.output.errorMsg("cancel", "Unknown option");
return error.InvalidArgs;
} else {
// This is a job name
try job_names.append(allocator, arg);
try targets.append(arg);
}
}
if (job_names.items.len == 0) {
colors.printError("No job names specified\n", .{});
try printUsage();
core.output.init(if (flags.json) .json else .text);
if (targets.items.len == 0) {
core.output.errorMsg("cancel", "No run_id specified");
return error.InvalidArgs;
}
const config = try Config.load(allocator);
const cfg = try config.Config.load(allocator);
defer {
var mut_config = config;
mut_config.deinit(allocator);
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Authenticate with server to get user context
var user_context = try auth.authenticateUser(allocator, config);
defer user_context.deinit();
// Detect mode
const mode_result = try mode.detect(allocator, cfg);
if (mode_result.warning) |w| {
std.log.warn("{s}", .{w});
}
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
defer allocator.free(api_key_hash);
// Connect to WebSocket and send cancel messages
const ws_url = try config.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
defer client.close();
// Process each job
var success_count: usize = 0;
var failed_jobs = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
colors.printError("Failed to allocate failed jobs list: {}\n", .{err});
return err;
};
defer failed_jobs.deinit(allocator);
var failed_count: usize = 0;
for (job_names.items, 0..) |job_name, index| {
if (!options.json) {
colors.printInfo("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name });
}
cancelSingleJob(allocator, &client, user_context, job_name, options, api_key_hash) catch |err| {
colors.printError("Failed to cancel job '{s}': {}\n", .{ job_name, err });
failed_jobs.append(allocator, job_name) catch |append_err| {
colors.printError("Failed to track failed job: {}\n", .{append_err});
for (targets.items) |target| {
if (mode.isOffline(mode_result.mode)) {
// Local mode: kill by PID
cancelLocal(allocator, target, force, flags.json) catch |err| {
if (!flags.json) {
colors.printError("Failed to cancel '{s}': {}\n", .{ target, err });
}
failed_count += 1;
continue;
};
continue;
};
} else {
// Online mode: cancel on server
cancelServer(allocator, target, force, flags.json, cfg) catch |err| {
if (!flags.json) {
colors.printError("Failed to cancel '{s}': {}\n", .{ target, err });
}
failed_count += 1;
continue;
};
}
success_count += 1;
}
// Show summary
if (!options.json) {
colors.printInfo("\nCancel Summary:\n", .{});
colors.printSuccess("Successfully canceled {d} job(s)\n", .{success_count});
if (failed_jobs.items.len > 0) {
colors.printError("Failed to cancel {d} job(s):\n", .{failed_jobs.items.len});
for (failed_jobs.items) |failed_job| {
colors.printError(" - {s}\n", .{failed_job});
}
if (flags.json) {
std.debug.print("{{\"success\":true,\"canceled\":{d},\"failed\":{d}}}\n", .{ success_count, failed_count });
} else {
colors.printSuccess("Canceled {d} run(s)\n", .{success_count});
if (failed_count > 0) {
colors.printError("Failed to cancel {d} run(s)\n", .{failed_count});
}
}
}
fn cancelSingleJob(allocator: std.mem.Allocator, client: *ws.Client, user_context: auth.UserContext, job_name: []const u8, options: CancelOptions, api_key_hash: []const u8) !void {
/// Cancel local run by PID
fn cancelLocal(allocator: std.mem.Allocator, run_id: []const u8, force: bool, json: bool) !void {
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Get DB path
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
var database = try db.DB.init(allocator, db_path);
defer database.close();
// Look up PID
const sql = "SELECT pid FROM ml_runs WHERE run_id = ? AND status = 'RUNNING';";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
try db.DB.bindText(stmt, 1, run_id);
const has_row = try db.DB.step(stmt);
if (!has_row) {
return error.RunNotFoundOrNotRunning;
}
const pid = db.DB.columnInt64(stmt, 0);
if (pid == 0) {
return error.NoPIDAvailable;
}
// Send SIGTERM first
std.posix.kill(@intCast(pid), std.posix.SIG.TERM) catch |err| {
if (err == error.ProcessNotFound) {
// Process already gone
} else {
return err;
}
};
if (!force) {
// Wait 5 seconds for graceful termination
std.time.sleep(5 * std.time.ns_per_s);
}
// Check if still running, send SIGKILL if needed
if (force or isProcessRunning(@intCast(pid))) {
std.posix.kill(@intCast(pid), std.posix.SIG.KILL) catch |err| {
if (err != error.ProcessNotFound) {
return err;
}
};
}
// Update run status
const update_sql = "UPDATE ml_runs SET status = 'CANCELLED', pid = NULL WHERE run_id = ?;";
const update_stmt = try database.prepare(update_sql);
defer db.DB.finalize(update_stmt);
try db.DB.bindText(update_stmt, 1, run_id);
_ = try db.DB.step(update_stmt);
// Update manifest
const artifact_path = try std.fs.path.join(allocator, &[_][]const u8{
cfg.artifact_path,
if (cfg.experiment) |exp| exp.name else "default",
run_id,
"run_manifest.json",
});
defer allocator.free(artifact_path);
manifest_lib.updateManifestStatus(artifact_path, "CANCELLED", null, allocator) catch {};
// Checkpoint
database.checkpointOnExit();
if (!json) {
colors.printSuccess("✓ Canceled run {s}\n", .{run_id[0..8]});
}
}
/// Check if process is still running
fn isProcessRunning(pid: i32) bool {
const result = std.posix.kill(pid, 0);
return result == error.PermissionDenied or result == {};
}
/// Cancel server job
fn cancelServer(allocator: std.mem.Allocator, job_name: []const u8, force: bool, json: bool, cfg: config.Config) !void {
_ = force;
_ = json;
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
const ws_url = try cfg.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
try client.sendCancelJob(job_name, api_key_hash);
// Receive structured response with user context
try client.receiveAndHandleCancelResponse(allocator, user_context, job_name, options);
// Wait for acknowledgment
const message = try client.receiveMessage(allocator);
defer allocator.free(message);
// Parse response (simplified)
if (std.mem.indexOf(u8, message, "error") != null) {
return error.ServerCancelFailed;
}
}
fn printUsage() !void {
colors.printInfo("Usage: ml cancel [options] <job-name> [<job-name> ...]\n", .{});
colors.printInfo("\nOptions:\n", .{});
colors.printInfo(" --force Force cancel even if job is running\n", .{});
colors.printInfo("Usage: ml cancel [options] <run-id> [<run-id> ...]\n", .{});
colors.printInfo("\nCancel a local run (kill process) or server job.\n\n", .{});
colors.printInfo("Options:\n", .{});
colors.printInfo(" --force Force cancel (SIGKILL immediately)\n", .{});
colors.printInfo(" --json Output structured JSON\n", .{});
colors.printInfo(" --help Show this help message\n", .{});
colors.printInfo(" --help, -h Show this help message\n", .{});
colors.printInfo("\nExamples:\n", .{});
colors.printInfo(" ml cancel job1 # Cancel single job\n", .{});
colors.printInfo(" ml cancel job1 job2 job3 # Cancel multiple jobs\n", .{});
colors.printInfo(" ml cancel --force job1 # Force cancel running job\n", .{});
colors.printInfo(" ml cancel --json job1 # Cancel job with JSON output\n", .{});
colors.printInfo(" ml cancel --force --json job1 job2 # Force cancel with JSON output\n", .{});
colors.printInfo(" ml cancel abc123 # Cancel local run by run_id\n", .{});
colors.printInfo(" ml cancel --force abc123 # Force cancel\n", .{});
}

192
cli/src/commands/log.zig Normal file
View file

@ -0,0 +1,192 @@
const std = @import("std");
const config = @import("../config.zig");
const core = @import("../core.zig");
const colors = @import("../utils/colors.zig");
const manifest_lib = @import("../manifest.zig");
const mode = @import("../mode.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const crypto = @import("../utils/crypto.zig");
/// Logs command - fetch or stream run logs
/// Usage:
/// ml logs <run_id> # Fetch logs from local file or server
/// ml logs <run_id> --follow # Stream logs from server
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
var flags = core.flags.CommonFlags{};
var command_args = try core.flags.parseCommon(allocator, args, &flags);
defer command_args.deinit(allocator);
core.output.init(if (flags.json) .json else .text);
if (flags.help) {
return printUsage();
}
if (command_args.items.len < 1) {
std.log.err("Usage: ml logs <run_id> [--follow]", .{});
return error.MissingArgument;
}
const target = command_args.items[0];
const follow = core.flags.parseBoolFlag(command_args.items, "follow");
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Detect mode
const mode_result = try mode.detect(allocator, cfg);
if (mode_result.warning) |w| {
std.log.warn("{s}", .{w});
}
if (mode.isOffline(mode_result.mode)) {
// Local mode: read from output.log file
return try fetchLocalLogs(allocator, target, &cfg, flags.json);
} else {
// Online mode: fetch or stream from server
if (follow) {
return try streamServerLogs(allocator, target, cfg);
} else {
return try fetchServerLogs(allocator, target, cfg);
}
}
}
fn fetchLocalLogs(allocator: std.mem.Allocator, target: []const u8, cfg: *const config.Config, json: bool) !void {
// Resolve manifest path
const manifest_path = manifest_lib.resolveManifestPath(target, cfg.artifact_path, allocator) catch |err| {
if (err == error.ManifestNotFound) {
std.log.err("Run not found: {s}", .{target});
return error.RunNotFound;
}
return err;
};
defer allocator.free(manifest_path);
// Read manifest to get artifact path
const manifest = try manifest_lib.readManifest(manifest_path, allocator);
defer manifest.deinit(allocator);
// Build output.log path
const output_path = try std.fs.path.join(allocator, &[_][]const u8{
manifest.artifact_path,
"output.log",
});
defer allocator.free(output_path);
// Read output.log
const content = std.fs.cwd().readFileAlloc(allocator, output_path, 10 * 1024 * 1024) catch |err| {
if (err == error.FileNotFound) {
std.log.err("No logs found for run: {s}", .{target});
return error.LogsNotFound;
}
return err;
};
defer allocator.free(content);
if (json) {
// Escape content for JSON
var escaped = std.ArrayList(u8).init(allocator);
defer escaped.deinit();
const writer = escaped.writer(allocator);
for (content) |c| {
switch (c) {
'\\' => try writer.writeAll("\\\\"),
'"' => try writer.writeAll("\\\""),
'\n' => try writer.writeAll("\\n"),
'\r' => try writer.writeAll("\\r"),
'\t' => try writer.writeAll("\\t"),
else => {
if (c >= 0x20 and c < 0x7f) {
try writer.writeByte(c);
} else {
try writer.print("\\u{x:0>4}", .{c});
}
},
}
}
std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"logs\":\"{s}\"}}\n", .{
manifest.run_id,
escaped.items,
});
} else {
std.debug.print("{s}\n", .{content});
}
}
fn fetchServerLogs(allocator: std.mem.Allocator, target: []const u8, cfg: config.Config) !void {
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
const ws_url = try cfg.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
try client.sendFetchLogs(target, api_key_hash);
const message = try client.receiveMessage(allocator);
defer allocator.free(message);
std.debug.print("{s}\n", .{message});
}
fn streamServerLogs(allocator: std.mem.Allocator, target: []const u8, cfg: config.Config) !void {
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
const ws_url = try cfg.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
colors.printInfo("Streaming logs for: {s}\n", .{target});
try client.sendStreamLogs(target, api_key_hash);
// Stream loop
while (true) {
const message = try client.receiveMessage(allocator);
defer allocator.free(message);
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
std.debug.print("{s}\n", .{message});
continue;
};
defer packet.deinit(allocator);
switch (packet.packet_type) {
.data => {
if (packet.data_payload) |payload| {
std.debug.print("{s}\n", .{payload});
}
},
.error_packet => {
const err_msg = packet.error_message orelse "Stream error";
colors.printError("Error: {s}\n", .{err_msg});
return error.ServerError;
},
else => {},
}
}
}
fn printUsage() !void {
std.debug.print("Usage: ml logs <run_id> [options]\n\n", .{});
std.debug.print("Fetch or stream run logs.\n\n", .{});
std.debug.print("Options:\n", .{});
std.debug.print(" --follow, -f Stream logs from server (online mode)\n", .{});
std.debug.print(" --help, -h Show this help message\n", .{});
std.debug.print(" --json Output structured JSON\n\n", .{});
std.debug.print("Examples:\n", .{});
std.debug.print(" ml logs abc123 # Fetch logs (local or server)\n", .{});
std.debug.print(" ml logs abc123 --follow # Stream logs from server\n", .{});
}

143
cli/src/commands/note.zig Normal file
View file

@ -0,0 +1,143 @@
const std = @import("std");
const config = @import("../config.zig");
const db = @import("../db.zig");
const core = @import("../core.zig");
const colors = @import("../utils/colors.zig");
const manifest_lib = @import("../manifest.zig");
/// Note command - unified metadata annotation
/// Usage:
/// ml note <run_id> --text "Try lr=3e-4 next"
/// ml note <run_id> --hypothesis "LR scaling helps"
/// ml note <run_id> --outcome validates --confidence 0.9
/// ml note <run_id> --privacy private
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
var flags = core.flags.CommonFlags{};
var command_args = try core.flags.parseCommon(allocator, args, &flags);
defer command_args.deinit(allocator);
core.output.init(if (flags.json) .json else .text);
if (flags.help) {
return printUsage();
}
if (command_args.items.len < 1) {
std.log.err("Usage: ml note <run_id> [options]", .{});
return error.MissingArgument;
}
const run_id = command_args.items[0];
// Parse metadata options
const text = core.flags.parseKVFlag(command_args.items, "text");
const hypothesis = core.flags.parseKVFlag(command_args.items, "hypothesis");
const outcome = core.flags.parseKVFlag(command_args.items, "outcome");
const confidence = core.flags.parseKVFlag(command_args.items, "confidence");
const privacy = core.flags.parseKVFlag(command_args.items, "privacy");
const author = core.flags.parseKVFlag(command_args.items, "author");
// Check that at least one option is provided
if (text == null and hypothesis == null and outcome == null and privacy == null) {
std.log.err("No metadata provided. Use --text, --hypothesis, --outcome, or --privacy", .{});
return error.MissingMetadata;
}
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Get DB path
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
var database = try db.DB.init(allocator, db_path);
defer database.close();
// Verify run exists
const check_sql = "SELECT 1 FROM ml_runs WHERE run_id = ?;";
const check_stmt = try database.prepare(check_sql);
defer db.DB.finalize(check_stmt);
try db.DB.bindText(check_stmt, 1, run_id);
const has_row = try db.DB.step(check_stmt);
if (!has_row) {
std.log.err("Run not found: {s}", .{run_id});
return error.RunNotFound;
}
// Add text note as a tag
if (text) |t| {
try addTag(allocator, &database, run_id, "note", t, author);
}
// Add hypothesis
if (hypothesis) |h| {
try addTag(allocator, &database, run_id, "hypothesis", h, author);
}
// Add outcome
if (outcome) |o| {
try addTag(allocator, &database, run_id, "outcome", o, author);
if (confidence) |c| {
try addTag(allocator, &database, run_id, "confidence", c, author);
}
}
// Add privacy level
if (privacy) |p| {
try addTag(allocator, &database, run_id, "privacy", p, author);
}
// Checkpoint WAL
database.checkpointOnExit();
if (flags.json) {
std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"action\":\"note_added\"}}\n", .{run_id});
} else {
colors.printSuccess("✓ Added note to run {s}\n", .{run_id[0..8]});
}
}
fn addTag(
allocator: std.mem.Allocator,
database: *db.DB,
run_id: []const u8,
key: []const u8,
value: []const u8,
author: ?[]const u8,
) !void {
const full_value = if (author) |a|
try std.fmt.allocPrint(allocator, "{s} (by {s})", .{ value, a })
else
try allocator.dupe(u8, value);
defer allocator.free(full_value);
const sql = "INSERT INTO ml_tags (run_id, key, value) VALUES (?, ?, ?);";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
try db.DB.bindText(stmt, 1, run_id);
try db.DB.bindText(stmt, 2, key);
try db.DB.bindText(stmt, 3, full_value);
_ = try db.DB.step(stmt);
}
fn printUsage() !void {
std.debug.print("Usage: ml note <run_id> [options]\n\n", .{});
std.debug.print("Add metadata notes to a run.\n\n", .{});
std.debug.print("Options:\n", .{});
std.debug.print(" --text <string> Free-form annotation\n", .{});
std.debug.print(" --hypothesis <string> Research hypothesis\n", .{});
std.debug.print(" --outcome <status> Outcome: validates/refutes/inconclusive\n", .{});
std.debug.print(" --confidence <0-1> Confidence in outcome\n", .{});
std.debug.print(" --privacy <level> Privacy: private/team/public\n", .{});
std.debug.print(" --author <name> Author of the note\n", .{});
std.debug.print(" --help, -h Show this help\n", .{});
std.debug.print(" --json Output structured JSON\n\n", .{});
std.debug.print("Examples:\n", .{});
std.debug.print(" ml note abc123 --text \"Try lr=3e-4 next\"\n", .{});
std.debug.print(" ml note abc123 --hypothesis \"LR scaling helps\"\n", .{});
std.debug.print(" ml note abc123 --outcome validates --confidence 0.9\n", .{});
}