diff --git a/cli/src/commands/experiment.zig b/cli/src/commands/experiment.zig index b667ab9..09930e9 100644 --- a/cli/src/commands/experiment.zig +++ b/cli/src/commands/experiment.zig @@ -6,6 +6,7 @@ const history = @import("../utils/history.zig"); const colors = @import("../utils/colors.zig"); const cancel_cmd = @import("cancel.zig"); const crypto = @import("../utils/crypto.zig"); +const db = @import("../db.zig"); fn jsonError(command: []const u8, message: []const u8) void { std.debug.print( @@ -55,12 +56,34 @@ pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void { if (std.mem.eql(u8, command, "init")) { try executeInit(allocator, command_args.items[1..], &options); + } else if (std.mem.eql(u8, command, "create")) { + try executeCreate(allocator, command_args.items[1..], &options); } else if (std.mem.eql(u8, command, "log")) { - try executeLog(allocator, command_args.items[1..], &options); + // Route to local or server mode based on config + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + if (cfg.isLocalMode()) { + try executeLogLocal(allocator, command_args.items[1..], &options); + } else { + try executeLog(allocator, command_args.items[1..], &options); + } } else if (std.mem.eql(u8, command, "show")) { try executeShow(allocator, command_args.items[1..], &options); } else if (std.mem.eql(u8, command, "list")) { - try executeList(allocator, &options); + // Route to local or server mode based on config + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + if (cfg.isLocalMode()) { + try executeListLocal(allocator, &options); + } else { + try executeList(allocator, &options); + } } else if (std.mem.eql(u8, command, "delete")) { if (command_args.items.len < 2) { if (options.json) { @@ -142,17 +165,232 @@ fn printUsage() !void { colors.printInfo(" --json Output structured JSON\n", .{}); colors.printInfo(" --help, -h Show this help message\n", .{}); colors.printInfo("\nCommands:\n", .{}); - colors.printInfo(" init Initialize a new experiment\n", .{}); - colors.printInfo(" log Log a metric for an experiment\n", .{}); + colors.printInfo(" init Initialize a new experiment (server mode)\n", .{}); + colors.printInfo(" create --name Create experiment in local mode\n", .{}); + colors.printInfo(" log Log a metric (auto-detects mode)\n", .{}); colors.printInfo(" show Show experiment details\n", .{}); - colors.printInfo(" list List recent experiments\n", .{}); + colors.printInfo(" list List experiments (auto-detects mode)\n", .{}); colors.printInfo(" delete Cancel/delete an experiment\n", .{}); colors.printInfo("\nExamples:\n", .{}); - colors.printInfo(" ml experiment init --name \"my-experiment\" --description \"Test experiment\"\n", .{}); + colors.printInfo(" ml experiment create --name \"my-experiment\"\n", .{}); colors.printInfo(" ml experiment show abc123 --json\n", .{}); colors.printInfo(" ml experiment list --json\n", .{}); } +// Local mode implementations +fn executeCreate(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void { + var name: ?[]const u8 = null; + var artifact_path: ?[]const u8 = null; + + var i: usize = 0; + while (i < args.len) : (i += 1) { + const arg = args[i]; + if (std.mem.eql(u8, arg, "--name")) { + if (i + 1 < args.len) { + name = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--artifact-path")) { + if (i + 1 < args.len) { + artifact_path = args[i + 1]; + i += 1; + } + } + } + + if (name == null) { + if (options.json) { + jsonError("experiment.create", "--name is required"); + } else { + colors.printError("Error: --name is required\n", .{}); + } + return error.MissingArgument; + } + + // Load config + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + if (!cfg.isLocalMode()) { + if (options.json) { + jsonError("experiment.create", "create only works in local mode (sqlite://)"); + } else { + colors.printError("Error: experiment create only works in local mode (sqlite://)\n", .{}); + } + return error.NotLocalMode; + } + + // Get DB path + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + // Initialize DB + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + // Generate experiment ID + const exp_id = try db.generateUUID(allocator); + defer allocator.free(exp_id); + + // Insert experiment + const sql = "INSERT INTO ml_experiments (experiment_id, name, artifact_path) VALUES (?, ?, ?);"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + + try db.DB.bindText(stmt, 1, exp_id); + try db.DB.bindText(stmt, 2, name.?); + try db.DB.bindText(stmt, 3, artifact_path orelse ""); + + _ = try db.DB.step(stmt); + database.checkpointOnExit(); + + if (options.json) { + std.debug.print("{{\"success\":true,\"command\":\"experiment.create\",\"data\":{{\"experiment_id\":\"{s}\",\"name\":\"{s}\"}}}}\n", .{ exp_id, name.? }); + } else { + colors.printSuccess("✓ Created experiment: {s}\n", .{name.?}); + colors.printInfo(" experiment_id: {s}\n", .{exp_id}); + } +} + +fn executeLogLocal(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void { + var run_id: ?[]const u8 = null; + var name: ?[]const u8 = null; + var value: ?f64 = null; + var step: i64 = 0; + + var i: usize = 0; + while (i < args.len) : (i += 1) { + const arg = args[i]; + if (std.mem.eql(u8, arg, "--run")) { + if (i + 1 < args.len) { + run_id = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--name")) { + if (i + 1 < args.len) { + name = args[i + 1]; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--value")) { + if (i + 1 < args.len) { + value = std.fmt.parseFloat(f64, args[i + 1]) catch null; + i += 1; + } + } else if (std.mem.eql(u8, arg, "--step")) { + if (i + 1 < args.len) { + step = std.fmt.parseInt(i64, args[i + 1], 10) catch 0; + i += 1; + } + } + } + + if (run_id == null or name == null or value == null) { + if (options.json) { + jsonError("experiment.log", "Usage: ml experiment log --run --name --value [--step ]"); + } else { + colors.printError("Usage: ml experiment log --run --name --value [--step ]\n", .{}); + } + return; + } + + // Load config + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Get DB path + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + // Initialize DB + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + // Insert metric + const sql = "INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?);"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + + try db.DB.bindText(stmt, 1, run_id.?); + try db.DB.bindText(stmt, 2, name.?); + try db.DB.bindDouble(stmt, 3, value.?); + try db.DB.bindInt64(stmt, 4, step); + + _ = try db.DB.step(stmt); + + if (options.json) { + std.debug.print("{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"run_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}}}}}}\n", .{ run_id.?, name.?, value.?, step }); + } else { + colors.printSuccess("✓ Logged metric: {s} = {d:.4} (step {d})\n", .{ name.?, value.?, step }); + } +} + +fn executeListLocal(allocator: std.mem.Allocator, options: *const ExperimentOptions) !void { + // Load config + const cfg = try config.Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Get DB path + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + // Initialize DB + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + // Query experiments + const sql = "SELECT experiment_id, name, lifecycle, created_at FROM ml_experiments ORDER BY created_at DESC;"; + const stmt = try database.prepare(sql); + defer db.DB.finalize(stmt); + + var experiments = std.ArrayList(struct { id: []const u8, name: []const u8, lifecycle: []const u8, created: []const u8 }).init(allocator); + defer { + for (experiments.items) |exp| { + allocator.free(exp.id); + allocator.free(exp.name); + allocator.free(exp.lifecycle); + allocator.free(exp.created); + } + experiments.deinit(); + } + + while (try db.DB.step(stmt)) { + const id = try allocator.dupe(u8, db.DB.columnText(stmt, 0)); + const name = try allocator.dupe(u8, db.DB.columnText(stmt, 1)); + const lifecycle = try allocator.dupe(u8, db.DB.columnText(stmt, 2)); + const created = try allocator.dupe(u8, db.DB.columnText(stmt, 3)); + try experiments.append(.{ .id = id, .name = name, .lifecycle = lifecycle, .created = created }); + } + + if (options.json) { + std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[", .{}); + for (experiments.items, 0..) |exp, idx| { + if (idx > 0) std.debug.print(",", .{}); + std.debug.print("{{\"experiment_id\":\"{s}\",\"name\":\"{s}\",\"lifecycle\":\"{s}\",\"created_at\":\"{s}\"}}", .{ exp.id, exp.name, exp.lifecycle, exp.created }); + } + std.debug.print("],\"total\":{d}}}}}\n", .{experiments.items.len}); + } else { + if (experiments.items.len == 0) { + colors.printWarning("No experiments found. Create one with: ml experiment create --name \n", .{}); + } else { + colors.printInfo("\nExperiments:\n", .{}); + colors.printInfo("{s:-<60}\n", .{""}); + for (experiments.items) |exp| { + std.debug.print("{s} | {s} | {s} | {s}\n", .{ exp.id, exp.name, exp.lifecycle, exp.created }); + } + std.debug.print("\nTotal: {d} experiments\n", .{experiments.items.len}); + } + } +} + fn executeLog(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void { var commit_id: ?[]const u8 = null; var name: ?[]const u8 = null; diff --git a/cli/src/commands/init.zig b/cli/src/commands/init.zig index 4d6054f..f9d8877 100644 --- a/cli/src/commands/init.zig +++ b/cli/src/commands/init.zig @@ -1,23 +1,115 @@ const std = @import("std"); const Config = @import("../config.zig").Config; +const db = @import("../db.zig"); -pub fn run(_: std.mem.Allocator, args: []const []const u8) !void { +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { if (args.len > 0 and (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h"))) { printUsage(); return; } - std.debug.print("ML Experiment Manager - Configuration Setup\n\n", .{}); - std.debug.print("Please create ~/.ml/config.toml with the following format:\n\n", .{}); - std.debug.print("worker_host = \"worker.local\"\n", .{}); - std.debug.print("worker_user = \"mluser\"\n", .{}); - std.debug.print("worker_base = \"/data/ml-experiments\"\n", .{}); - std.debug.print("worker_port = 22\n", .{}); - std.debug.print("api_key = \"your-api-key\"\n", .{}); - std.debug.print("\n[OK] Configuration template shown above\n", .{}); + // Parse optional CLI flags + var cli_tracking_uri: ?[]const u8 = null; + var cli_artifact_path: ?[]const u8 = null; + var cli_sync_uri: ?[]const u8 = null; + + var i: usize = 0; + while (i < args.len) : (i += 1) { + if (std.mem.eql(u8, args[i], "--tracking-uri") and i + 1 < args.len) { + cli_tracking_uri = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, args[i], "--artifact-path") and i + 1 < args.len) { + cli_artifact_path = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, args[i], "--sync-uri") and i + 1 < args.len) { + cli_sync_uri = args[i + 1]; + i += 1; + } + } + + // Load config with CLI overrides + var cfg = try Config.loadWithOverrides(allocator, cli_tracking_uri, cli_artifact_path, cli_sync_uri); + defer cfg.deinit(allocator); + + // Print resolved config + std.debug.print("Resolved config:\n", .{}); + std.debug.print(" tracking_uri = {s}", .{cfg.tracking_uri}); + + // Indicate if using default + if (cli_tracking_uri == null and std.mem.eql(u8, cfg.tracking_uri, "sqlite://./fetch_ml.db")) { + std.debug.print(" (default)\n", .{}); + } else { + std.debug.print("\n", .{}); + } + + std.debug.print(" artifact_path = {s}", .{cfg.artifact_path}); + if (cli_artifact_path == null and std.mem.eql(u8, cfg.artifact_path, "./experiments/")) { + std.debug.print(" (default)\n", .{}); + } else { + std.debug.print("\n", .{}); + } + + if (cfg.sync_uri.len > 0) { + std.debug.print(" sync_uri = {s}\n", .{cfg.sync_uri}); + } else { + std.debug.print(" sync_uri = (not set)\n", .{}); + } + std.debug.print("\n", .{}); + + // Only initialize SQLite DB in local mode + if (!cfg.isLocalMode()) { + std.debug.print("Runner mode detected (wss://). No local database needed.\n", .{}); + std.debug.print("Server: {s}:{d}\n", .{ cfg.worker_host, cfg.worker_port }); + return; + } + + // Get DB path from tracking URI + const db_path = try cfg.getDBPath(allocator); + defer allocator.free(db_path); + + // Check if DB already exists + const db_exists = blk: { + std.fs.accessAbsolute(db_path, .{}) catch |err| { + if (err == error.FileNotFound) break :blk false; + }; + break :blk true; + }; + + if (db_exists) { + std.debug.print("✓ Database already exists: {s}\n", .{db_path}); + } else { + // Create parent directories if needed + if (std.fs.path.dirname(db_path)) |dir| { + std.fs.makeDirAbsolute(dir) catch |err| { + if (err != error.PathAlreadyExists) { + std.log.err("Failed to create directory {s}: {}", .{ dir, err }); + return error.MkdirFailed; + } + }; + } + + // Initialize database (creates schema) + var database = try db.DB.init(allocator, db_path); + defer database.close(); + defer database.checkpointOnExit(); + + std.debug.print("✓ Created database: {s}\n", .{db_path}); + } + + // Verify schema by connecting + var database = try db.DB.init(allocator, db_path); + defer database.close(); + + std.debug.print("✓ Schema applied (WAL mode enabled)\n", .{}); + std.debug.print("✓ Ready for experiment tracking\n", .{}); } fn printUsage() void { - std.debug.print("Usage: ml init\n\n", .{}); - std.debug.print("Shows a template for ~/.ml/config.toml\n", .{}); + std.debug.print("Usage: ml init [OPTIONS]\n\n", .{}); + std.debug.print("Initialize local experiment tracking database\n\n", .{}); + std.debug.print("Options:\n", .{}); + std.debug.print(" --tracking-uri URI SQLite database path (e.g., sqlite://./fetch_ml.db)\n", .{}); + std.debug.print(" --artifact-path PATH Artifacts directory (default: ./experiments/)\n", .{}); + std.debug.print(" --sync-uri URI Server to sync with (e.g., wss://ml.company.com/ws)\n", .{}); + std.debug.print(" -h, --help Show this help\n", .{}); } diff --git a/cli/src/commands/sync.zig b/cli/src/commands/sync.zig index c3fb539..0ad4d87 100644 --- a/cli/src/commands/sync.zig +++ b/cli/src/commands/sync.zig @@ -7,6 +7,7 @@ const ws = @import("../net/ws/client.zig"); const logging = @import("../utils/logging.zig"); const json = @import("../utils/json.zig"); const native_hash = @import("../utils/native_hash.zig"); +const ProgressBar = @import("../ui/progress.zig").ProgressBar; pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { if (args.len == 0) { @@ -174,21 +175,17 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3); defer client.disconnect(); - // Send progress monitoring request (this would be a new opcode on the server side) - // For now, we'll just listen for any progress messages + // Initialize progress bar (will be updated as we receive progress messages) + var progress = ProgressBar.init(100, "Syncing files"); var timeout_counter: u32 = 0; const max_timeout = 30; // 30 seconds timeout - var spinner_index: usize = 0; - const spinner_chars = [_]u8{ '|', '/', '-', '\\' }; while (timeout_counter < max_timeout) { const message = client.receiveMessage(allocator) catch |err| { switch (err) { error.ConnectionClosed, error.ConnectionTimedOut => { timeout_counter += 1; - spinner_index = (spinner_index + 1) % 4; - logging.progress("Waiting for progress {c} (attempt {d}/{d})\n", .{ spinner_chars[spinner_index], timeout_counter, max_timeout }); std.Thread.sleep(1 * std.time.ns_per_s); continue; }, @@ -207,19 +204,21 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm if (parsed.value == .object) { const root = parsed.value.object; const status = json.getString(root, "status") orelse "unknown"; - const progress = json.getInt(root, "progress") orelse 0; - const total = json.getInt(root, "total") orelse 0; + const current = json.getInt(root, "progress") orelse 0; + const total = json.getInt(root, "total") orelse 100; if (std.mem.eql(u8, status, "complete")) { - logging.success("Sync complete!\n", .{}); + progress.finish(); + colors.printSuccess("Sync complete!\n", .{}); break; } else if (std.mem.eql(u8, status, "error")) { const error_msg = json.getString(root, "error") orelse "Unknown error"; - logging.err("Sync failed: {s}\n", .{error_msg}); + colors.printError("Sync failed: {s}\n", .{error_msg}); return error.SyncFailed; } else { - const pct = if (total > 0) @divTrunc(progress * 100, total) else 0; - logging.progress("Sync: {s} ({d}/{d} files, {d}%)\n", .{ status, progress, total, pct }); + // Update progress bar + progress.total = @intCast(total); + progress.update(@intCast(current)); } } else { logging.success("Sync progress: {s}\n", .{message}); diff --git a/cli/src/main.zig b/cli/src/main.zig index 981d1c1..50df8fe 100644 --- a/cli/src/main.zig +++ b/cli/src/main.zig @@ -35,7 +35,7 @@ pub fn main() !void { try @import("commands/jupyter.zig").run(allocator, args[2..]); }, 'i' => if (std.mem.eql(u8, command, "init")) { - colors.printInfo("Setup configuration interactively\n", .{}); + try @import("commands/init.zig").run(allocator, args[2..]); } else if (std.mem.eql(u8, command, "info")) { try @import("commands/info.zig").run(allocator, args[2..]); } else handleUnknownCommand(command), @@ -61,7 +61,9 @@ pub fn main() !void { } else handleUnknownCommand(command), 'r' => if (std.mem.eql(u8, command, "requeue")) { try @import("commands/requeue.zig").run(allocator, args[2..]); - }, + } else if (std.mem.eql(u8, command, "run")) { + try @import("commands/run.zig").execute(allocator, args[2..]); + } else handleUnknownCommand(command), 'q' => if (std.mem.eql(u8, command, "queue")) { try @import("commands/queue.zig").run(allocator, args[2..]); }, @@ -86,7 +88,7 @@ pub fn main() !void { }, 'l' => if (std.mem.eql(u8, command, "logs")) { try @import("commands/logs.zig").run(allocator, args[2..]); - }, + } else handleUnknownCommand(command), else => { colors.printError("Unknown command: {s}\n", .{args[1]}); printUsage(); @@ -101,7 +103,11 @@ fn printUsage() void { std.debug.print("Usage: ml [options]\n\n", .{}); std.debug.print("Commands:\n", .{}); std.debug.print(" jupyter Jupyter workspace management\n", .{}); - std.debug.print(" init Setup configuration interactively\n", .{}); + std.debug.print(" init Initialize local experiment tracking database\n", .{}); + std.debug.print(" experiment Manage experiments (auto-detects local/server mode)\n", .{}); + std.debug.print(" - create, list, log, show, delete\n", .{}); + std.debug.print(" run Manage runs (auto-detects local/server mode)\n", .{}); + std.debug.print(" - start, finish, fail, list\n", .{}); std.debug.print(" annotate Add an annotation to run_manifest.json (--note \"...\")\n", .{}); std.debug.print(" compare Compare two runs (show differences)\n", .{}); std.debug.print(" export Export experiment bundle (--anonymize for safe sharing)\n", .{});