feat(cli): unified commands and local mode support

- Update experiment.zig with unified commands (local + server modes)
- Add init.zig for local project initialization
- Update sync.zig for project synchronization
- Update main.zig to route new local mode commands (experiment, run, log)
- Support automatic mode detection from config (sqlite:// vs wss://)
This commit is contained in:
Jeremie Fraeys 2026-02-20 15:51:04 -05:00
parent 2c596038b5
commit 7ce0fd251e
No known key found for this signature in database
4 changed files with 368 additions and 33 deletions

View file

@ -6,6 +6,7 @@ const history = @import("../utils/history.zig");
const colors = @import("../utils/colors.zig");
const cancel_cmd = @import("cancel.zig");
const crypto = @import("../utils/crypto.zig");
const db = @import("../db.zig");
fn jsonError(command: []const u8, message: []const u8) void {
std.debug.print(
@ -55,12 +56,34 @@ pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (std.mem.eql(u8, command, "init")) {
try executeInit(allocator, command_args.items[1..], &options);
} else if (std.mem.eql(u8, command, "create")) {
try executeCreate(allocator, command_args.items[1..], &options);
} else if (std.mem.eql(u8, command, "log")) {
try executeLog(allocator, command_args.items[1..], &options);
// Route to local or server mode based on config
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
if (cfg.isLocalMode()) {
try executeLogLocal(allocator, command_args.items[1..], &options);
} else {
try executeLog(allocator, command_args.items[1..], &options);
}
} else if (std.mem.eql(u8, command, "show")) {
try executeShow(allocator, command_args.items[1..], &options);
} else if (std.mem.eql(u8, command, "list")) {
try executeList(allocator, &options);
// Route to local or server mode based on config
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
if (cfg.isLocalMode()) {
try executeListLocal(allocator, &options);
} else {
try executeList(allocator, &options);
}
} else if (std.mem.eql(u8, command, "delete")) {
if (command_args.items.len < 2) {
if (options.json) {
@ -142,17 +165,232 @@ fn printUsage() !void {
colors.printInfo(" --json Output structured JSON\n", .{});
colors.printInfo(" --help, -h Show this help message\n", .{});
colors.printInfo("\nCommands:\n", .{});
colors.printInfo(" init Initialize a new experiment\n", .{});
colors.printInfo(" log Log a metric for an experiment\n", .{});
colors.printInfo(" init Initialize a new experiment (server mode)\n", .{});
colors.printInfo(" create --name <name> Create experiment in local mode\n", .{});
colors.printInfo(" log Log a metric (auto-detects mode)\n", .{});
colors.printInfo(" show <commit_id> Show experiment details\n", .{});
colors.printInfo(" list List recent experiments\n", .{});
colors.printInfo(" list List experiments (auto-detects mode)\n", .{});
colors.printInfo(" delete <alias|commit> Cancel/delete an experiment\n", .{});
colors.printInfo("\nExamples:\n", .{});
colors.printInfo(" ml experiment init --name \"my-experiment\" --description \"Test experiment\"\n", .{});
colors.printInfo(" ml experiment create --name \"my-experiment\"\n", .{});
colors.printInfo(" ml experiment show abc123 --json\n", .{});
colors.printInfo(" ml experiment list --json\n", .{});
}
// Local mode implementations
fn executeCreate(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void {
var name: ?[]const u8 = null;
var artifact_path: ?[]const u8 = null;
var i: usize = 0;
while (i < args.len) : (i += 1) {
const arg = args[i];
if (std.mem.eql(u8, arg, "--name")) {
if (i + 1 < args.len) {
name = args[i + 1];
i += 1;
}
} else if (std.mem.eql(u8, arg, "--artifact-path")) {
if (i + 1 < args.len) {
artifact_path = args[i + 1];
i += 1;
}
}
}
if (name == null) {
if (options.json) {
jsonError("experiment.create", "--name is required");
} else {
colors.printError("Error: --name is required\n", .{});
}
return error.MissingArgument;
}
// Load config
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
if (!cfg.isLocalMode()) {
if (options.json) {
jsonError("experiment.create", "create only works in local mode (sqlite://)");
} else {
colors.printError("Error: experiment create only works in local mode (sqlite://)\n", .{});
}
return error.NotLocalMode;
}
// Get DB path
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
// Initialize DB
var database = try db.DB.init(allocator, db_path);
defer database.close();
// Generate experiment ID
const exp_id = try db.generateUUID(allocator);
defer allocator.free(exp_id);
// Insert experiment
const sql = "INSERT INTO ml_experiments (experiment_id, name, artifact_path) VALUES (?, ?, ?);";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
try db.DB.bindText(stmt, 1, exp_id);
try db.DB.bindText(stmt, 2, name.?);
try db.DB.bindText(stmt, 3, artifact_path orelse "");
_ = try db.DB.step(stmt);
database.checkpointOnExit();
if (options.json) {
std.debug.print("{{\"success\":true,\"command\":\"experiment.create\",\"data\":{{\"experiment_id\":\"{s}\",\"name\":\"{s}\"}}}}\n", .{ exp_id, name.? });
} else {
colors.printSuccess("✓ Created experiment: {s}\n", .{name.?});
colors.printInfo(" experiment_id: {s}\n", .{exp_id});
}
}
fn executeLogLocal(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void {
var run_id: ?[]const u8 = null;
var name: ?[]const u8 = null;
var value: ?f64 = null;
var step: i64 = 0;
var i: usize = 0;
while (i < args.len) : (i += 1) {
const arg = args[i];
if (std.mem.eql(u8, arg, "--run")) {
if (i + 1 < args.len) {
run_id = args[i + 1];
i += 1;
}
} else if (std.mem.eql(u8, arg, "--name")) {
if (i + 1 < args.len) {
name = args[i + 1];
i += 1;
}
} else if (std.mem.eql(u8, arg, "--value")) {
if (i + 1 < args.len) {
value = std.fmt.parseFloat(f64, args[i + 1]) catch null;
i += 1;
}
} else if (std.mem.eql(u8, arg, "--step")) {
if (i + 1 < args.len) {
step = std.fmt.parseInt(i64, args[i + 1], 10) catch 0;
i += 1;
}
}
}
if (run_id == null or name == null or value == null) {
if (options.json) {
jsonError("experiment.log", "Usage: ml experiment log --run <id> --name <name> --value <value> [--step <step>]");
} else {
colors.printError("Usage: ml experiment log --run <id> --name <name> --value <value> [--step <step>]\n", .{});
}
return;
}
// Load config
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Get DB path
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
// Initialize DB
var database = try db.DB.init(allocator, db_path);
defer database.close();
// Insert metric
const sql = "INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?);";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
try db.DB.bindText(stmt, 1, run_id.?);
try db.DB.bindText(stmt, 2, name.?);
try db.DB.bindDouble(stmt, 3, value.?);
try db.DB.bindInt64(stmt, 4, step);
_ = try db.DB.step(stmt);
if (options.json) {
std.debug.print("{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"run_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}}}}}}\n", .{ run_id.?, name.?, value.?, step });
} else {
colors.printSuccess("✓ Logged metric: {s} = {d:.4} (step {d})\n", .{ name.?, value.?, step });
}
}
fn executeListLocal(allocator: std.mem.Allocator, options: *const ExperimentOptions) !void {
// Load config
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Get DB path
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
// Initialize DB
var database = try db.DB.init(allocator, db_path);
defer database.close();
// Query experiments
const sql = "SELECT experiment_id, name, lifecycle, created_at FROM ml_experiments ORDER BY created_at DESC;";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
var experiments = std.ArrayList(struct { id: []const u8, name: []const u8, lifecycle: []const u8, created: []const u8 }).init(allocator);
defer {
for (experiments.items) |exp| {
allocator.free(exp.id);
allocator.free(exp.name);
allocator.free(exp.lifecycle);
allocator.free(exp.created);
}
experiments.deinit();
}
while (try db.DB.step(stmt)) {
const id = try allocator.dupe(u8, db.DB.columnText(stmt, 0));
const name = try allocator.dupe(u8, db.DB.columnText(stmt, 1));
const lifecycle = try allocator.dupe(u8, db.DB.columnText(stmt, 2));
const created = try allocator.dupe(u8, db.DB.columnText(stmt, 3));
try experiments.append(.{ .id = id, .name = name, .lifecycle = lifecycle, .created = created });
}
if (options.json) {
std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[", .{});
for (experiments.items, 0..) |exp, idx| {
if (idx > 0) std.debug.print(",", .{});
std.debug.print("{{\"experiment_id\":\"{s}\",\"name\":\"{s}\",\"lifecycle\":\"{s}\",\"created_at\":\"{s}\"}}", .{ exp.id, exp.name, exp.lifecycle, exp.created });
}
std.debug.print("],\"total\":{d}}}}}\n", .{experiments.items.len});
} else {
if (experiments.items.len == 0) {
colors.printWarning("No experiments found. Create one with: ml experiment create --name <name>\n", .{});
} else {
colors.printInfo("\nExperiments:\n", .{});
colors.printInfo("{s:-<60}\n", .{""});
for (experiments.items) |exp| {
std.debug.print("{s} | {s} | {s} | {s}\n", .{ exp.id, exp.name, exp.lifecycle, exp.created });
}
std.debug.print("\nTotal: {d} experiments\n", .{experiments.items.len});
}
}
}
fn executeLog(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void {
var commit_id: ?[]const u8 = null;
var name: ?[]const u8 = null;

View file

@ -1,23 +1,115 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const db = @import("../db.zig");
pub fn run(_: std.mem.Allocator, args: []const []const u8) !void {
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len > 0 and (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h"))) {
printUsage();
return;
}
std.debug.print("ML Experiment Manager - Configuration Setup\n\n", .{});
std.debug.print("Please create ~/.ml/config.toml with the following format:\n\n", .{});
std.debug.print("worker_host = \"worker.local\"\n", .{});
std.debug.print("worker_user = \"mluser\"\n", .{});
std.debug.print("worker_base = \"/data/ml-experiments\"\n", .{});
std.debug.print("worker_port = 22\n", .{});
std.debug.print("api_key = \"your-api-key\"\n", .{});
std.debug.print("\n[OK] Configuration template shown above\n", .{});
// Parse optional CLI flags
var cli_tracking_uri: ?[]const u8 = null;
var cli_artifact_path: ?[]const u8 = null;
var cli_sync_uri: ?[]const u8 = null;
var i: usize = 0;
while (i < args.len) : (i += 1) {
if (std.mem.eql(u8, args[i], "--tracking-uri") and i + 1 < args.len) {
cli_tracking_uri = args[i + 1];
i += 1;
} else if (std.mem.eql(u8, args[i], "--artifact-path") and i + 1 < args.len) {
cli_artifact_path = args[i + 1];
i += 1;
} else if (std.mem.eql(u8, args[i], "--sync-uri") and i + 1 < args.len) {
cli_sync_uri = args[i + 1];
i += 1;
}
}
// Load config with CLI overrides
var cfg = try Config.loadWithOverrides(allocator, cli_tracking_uri, cli_artifact_path, cli_sync_uri);
defer cfg.deinit(allocator);
// Print resolved config
std.debug.print("Resolved config:\n", .{});
std.debug.print(" tracking_uri = {s}", .{cfg.tracking_uri});
// Indicate if using default
if (cli_tracking_uri == null and std.mem.eql(u8, cfg.tracking_uri, "sqlite://./fetch_ml.db")) {
std.debug.print(" (default)\n", .{});
} else {
std.debug.print("\n", .{});
}
std.debug.print(" artifact_path = {s}", .{cfg.artifact_path});
if (cli_artifact_path == null and std.mem.eql(u8, cfg.artifact_path, "./experiments/")) {
std.debug.print(" (default)\n", .{});
} else {
std.debug.print("\n", .{});
}
if (cfg.sync_uri.len > 0) {
std.debug.print(" sync_uri = {s}\n", .{cfg.sync_uri});
} else {
std.debug.print(" sync_uri = (not set)\n", .{});
}
std.debug.print("\n", .{});
// Only initialize SQLite DB in local mode
if (!cfg.isLocalMode()) {
std.debug.print("Runner mode detected (wss://). No local database needed.\n", .{});
std.debug.print("Server: {s}:{d}\n", .{ cfg.worker_host, cfg.worker_port });
return;
}
// Get DB path from tracking URI
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
// Check if DB already exists
const db_exists = blk: {
std.fs.accessAbsolute(db_path, .{}) catch |err| {
if (err == error.FileNotFound) break :blk false;
};
break :blk true;
};
if (db_exists) {
std.debug.print("✓ Database already exists: {s}\n", .{db_path});
} else {
// Create parent directories if needed
if (std.fs.path.dirname(db_path)) |dir| {
std.fs.makeDirAbsolute(dir) catch |err| {
if (err != error.PathAlreadyExists) {
std.log.err("Failed to create directory {s}: {}", .{ dir, err });
return error.MkdirFailed;
}
};
}
// Initialize database (creates schema)
var database = try db.DB.init(allocator, db_path);
defer database.close();
defer database.checkpointOnExit();
std.debug.print("✓ Created database: {s}\n", .{db_path});
}
// Verify schema by connecting
var database = try db.DB.init(allocator, db_path);
defer database.close();
std.debug.print("✓ Schema applied (WAL mode enabled)\n", .{});
std.debug.print("✓ Ready for experiment tracking\n", .{});
}
fn printUsage() void {
std.debug.print("Usage: ml init\n\n", .{});
std.debug.print("Shows a template for ~/.ml/config.toml\n", .{});
std.debug.print("Usage: ml init [OPTIONS]\n\n", .{});
std.debug.print("Initialize local experiment tracking database\n\n", .{});
std.debug.print("Options:\n", .{});
std.debug.print(" --tracking-uri URI SQLite database path (e.g., sqlite://./fetch_ml.db)\n", .{});
std.debug.print(" --artifact-path PATH Artifacts directory (default: ./experiments/)\n", .{});
std.debug.print(" --sync-uri URI Server to sync with (e.g., wss://ml.company.com/ws)\n", .{});
std.debug.print(" -h, --help Show this help\n", .{});
}

View file

@ -7,6 +7,7 @@ const ws = @import("../net/ws/client.zig");
const logging = @import("../utils/logging.zig");
const json = @import("../utils/json.zig");
const native_hash = @import("../utils/native_hash.zig");
const ProgressBar = @import("../ui/progress.zig").ProgressBar;
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {
@ -174,21 +175,17 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm
var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3);
defer client.disconnect();
// Send progress monitoring request (this would be a new opcode on the server side)
// For now, we'll just listen for any progress messages
// Initialize progress bar (will be updated as we receive progress messages)
var progress = ProgressBar.init(100, "Syncing files");
var timeout_counter: u32 = 0;
const max_timeout = 30; // 30 seconds timeout
var spinner_index: usize = 0;
const spinner_chars = [_]u8{ '|', '/', '-', '\\' };
while (timeout_counter < max_timeout) {
const message = client.receiveMessage(allocator) catch |err| {
switch (err) {
error.ConnectionClosed, error.ConnectionTimedOut => {
timeout_counter += 1;
spinner_index = (spinner_index + 1) % 4;
logging.progress("Waiting for progress {c} (attempt {d}/{d})\n", .{ spinner_chars[spinner_index], timeout_counter, max_timeout });
std.Thread.sleep(1 * std.time.ns_per_s);
continue;
},
@ -207,19 +204,21 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm
if (parsed.value == .object) {
const root = parsed.value.object;
const status = json.getString(root, "status") orelse "unknown";
const progress = json.getInt(root, "progress") orelse 0;
const total = json.getInt(root, "total") orelse 0;
const current = json.getInt(root, "progress") orelse 0;
const total = json.getInt(root, "total") orelse 100;
if (std.mem.eql(u8, status, "complete")) {
logging.success("Sync complete!\n", .{});
progress.finish();
colors.printSuccess("Sync complete!\n", .{});
break;
} else if (std.mem.eql(u8, status, "error")) {
const error_msg = json.getString(root, "error") orelse "Unknown error";
logging.err("Sync failed: {s}\n", .{error_msg});
colors.printError("Sync failed: {s}\n", .{error_msg});
return error.SyncFailed;
} else {
const pct = if (total > 0) @divTrunc(progress * 100, total) else 0;
logging.progress("Sync: {s} ({d}/{d} files, {d}%)\n", .{ status, progress, total, pct });
// Update progress bar
progress.total = @intCast(total);
progress.update(@intCast(current));
}
} else {
logging.success("Sync progress: {s}\n", .{message});

View file

@ -35,7 +35,7 @@ pub fn main() !void {
try @import("commands/jupyter.zig").run(allocator, args[2..]);
},
'i' => if (std.mem.eql(u8, command, "init")) {
colors.printInfo("Setup configuration interactively\n", .{});
try @import("commands/init.zig").run(allocator, args[2..]);
} else if (std.mem.eql(u8, command, "info")) {
try @import("commands/info.zig").run(allocator, args[2..]);
} else handleUnknownCommand(command),
@ -61,7 +61,9 @@ pub fn main() !void {
} else handleUnknownCommand(command),
'r' => if (std.mem.eql(u8, command, "requeue")) {
try @import("commands/requeue.zig").run(allocator, args[2..]);
},
} else if (std.mem.eql(u8, command, "run")) {
try @import("commands/run.zig").execute(allocator, args[2..]);
} else handleUnknownCommand(command),
'q' => if (std.mem.eql(u8, command, "queue")) {
try @import("commands/queue.zig").run(allocator, args[2..]);
},
@ -86,7 +88,7 @@ pub fn main() !void {
},
'l' => if (std.mem.eql(u8, command, "logs")) {
try @import("commands/logs.zig").run(allocator, args[2..]);
},
} else handleUnknownCommand(command),
else => {
colors.printError("Unknown command: {s}\n", .{args[1]});
printUsage();
@ -101,7 +103,11 @@ fn printUsage() void {
std.debug.print("Usage: ml <command> [options]\n\n", .{});
std.debug.print("Commands:\n", .{});
std.debug.print(" jupyter Jupyter workspace management\n", .{});
std.debug.print(" init Setup configuration interactively\n", .{});
std.debug.print(" init Initialize local experiment tracking database\n", .{});
std.debug.print(" experiment Manage experiments (auto-detects local/server mode)\n", .{});
std.debug.print(" - create, list, log, show, delete\n", .{});
std.debug.print(" run Manage runs (auto-detects local/server mode)\n", .{});
std.debug.print(" - start, finish, fail, list\n", .{});
std.debug.print(" annotate <path|id> Add an annotation to run_manifest.json (--note \"...\")\n", .{});
std.debug.print(" compare <a> <b> Compare two runs (show differences)\n", .{});
std.debug.print(" export <id> Export experiment bundle (--anonymize for safe sharing)\n", .{});