fetch_ml/cli/src/commands/experiment.zig
Jeremie Fraeys 68062831b0
refactor(cli): remove redundant doc comments from command files
Removed duplicate help text from doc comments:
- log.zig: Removed usage examples (in printUsage)
- annotate.zig: Removed usage examples (in printUsage)
- experiment.zig: Removed usage examples (in printUsage)

Rationale: printUsage() already contains detailed help text.
Doc comments should not duplicate this information.

All tests pass.
2026-03-05 11:06:28 -05:00

362 lines
13 KiB
Zig

const std = @import("std");
const config = @import("../config.zig");
const db = @import("../db.zig");
const core = @import("../core.zig");
const mode = @import("../mode.zig");
const uuid = @import("../utils/uuid.zig");
const crypto = @import("../utils/crypto.zig");
const ws = @import("../net/ws/client.zig");
/// One row of the local `ml_experiments` table as collected by the `list`
/// subcommand. Every string field must be an owned allocation made with the
/// same allocator that is later passed to `deinit`.
const ExperimentInfo = struct {
    id: []const u8,
    name: []const u8,
    description: []const u8,
    created_at: []const u8,
    status: []const u8,
    synced: bool,

    /// Frees every owned string field (in declaration order). The fields are
    /// left dangling afterwards; the value must not be used again.
    fn deinit(self: *ExperimentInfo, allocator: std.mem.Allocator) void {
        const owned = [_][]const u8{ self.id, self.name, self.description, self.created_at, self.status };
        for (owned) |text| allocator.free(text);
    }
};
/// Entry point for `ml experiment`.
///
/// Parses the common flags (`--json`, `--help`, ...), switches the output mode
/// accordingly, and dispatches the first positional argument to the `create`,
/// `list`, or `show` handler with the remaining positionals. Usage is printed
/// for `--help`, for a missing subcommand, and (after an error message) for an
/// unrecognised subcommand.
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
    var flags = core.flags.CommonFlags{};
    var positional = try core.flags.parseCommon(allocator, args, &flags);
    defer positional.deinit(allocator);

    core.output.setMode(if (flags.json) .json else .text);

    if (flags.help) return printUsage();
    if (positional.items.len == 0) return printUsage();

    const Subcommand = enum { create, list, show };
    // At least one positional exists here, so `[1..]` is always in bounds
    // (it is empty when only the subcommand was given).
    const rest = positional.items[1..];
    const picked = std.meta.stringToEnum(Subcommand, positional.items[0]) orelse {
        core.output.err("Unknown subcommand");
        return printUsage();
    };

    switch (picked) {
        .create => try createExperiment(allocator, rest, flags.json),
        .list => try listExperiments(allocator, rest, flags.json),
        .show => try showExperiment(allocator, rest, flags.json),
    }
}
/// Handles `ml experiment create --name <name> [--description <desc>]`.
///
/// Offline mode: inserts a row into the local `ml_experiments` table with a
/// freshly generated id, status `'active'` and `synced = 0`.
/// Server mode: sends a create request over the WebSocket client and treats a
/// reply containing `experiment_id` as success.
/// On success, in both modes, the experiment name is written to the saved
/// config as the current experiment.
///
/// Unknown arguments are ignored, as is a `--name`/`--description` flag with
/// no following value.
/// Errors: `error.MissingArgument` when no name is given, `error.ServerError`
/// when the server reply lacks `experiment_id`, plus whatever the config,
/// database and network layers return.
fn createExperiment(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
    var name: ?[]const u8 = null;
    var description: ?[]const u8 = null;
    var i: usize = 0;
    while (i < args.len) : (i += 1) {
        if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
            name = args[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, args[i], "--description") and i + 1 < args.len) {
            description = args[i + 1];
            i += 1;
        }
    }
    const exp_name = name orelse {
        core.output.err("--name is required");
        return error.MissingArgument;
    };
    const exp_description = description orelse "";

    const cfg = try config.Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const mode_result = try mode.detect(allocator, cfg);
    if (mode.isOffline(mode_result.mode)) {
        // Local mode: create in SQLite.
        const db_path = try cfg.getDBPath(allocator);
        defer allocator.free(db_path);
        var database = try db.DB.init(allocator, db_path);
        defer database.close();

        // Insert experiment with synced=0 (not synced to server yet).
        const sql = "INSERT INTO ml_experiments (experiment_id, name, description, status, synced) VALUES (?, ?, ?, 'active', 0);";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);

        const exp_id = try generateExperimentID(allocator);
        defer allocator.free(exp_id);
        try db.DB.bindText(stmt, 1, exp_id);
        try db.DB.bindText(stmt, 2, exp_name);
        try db.DB.bindText(stmt, 3, exp_description);
        _ = try db.DB.step(stmt);

        try saveCurrentExperimentName(allocator, cfg, exp_name);
        database.checkpointOnExit();

        if (json) {
            std.debug.print("{{\"success\":true,\"experiment_id\":\"{s}\",\"name\":\"{s}\"}}\n", .{ exp_id, exp_name });
        } else {
            std.debug.print("Created experiment: {s} ({s})\n", .{ exp_name, exp_id[0..8] });
        }
    } else {
        // Server mode: send to server via WebSocket.
        const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
        defer allocator.free(api_key_hash);
        const ws_url = try cfg.getWebSocketUrl(allocator);
        defer allocator.free(ws_url);
        var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
        defer client.close();

        try client.sendCreateExperiment(api_key_hash, exp_name, exp_description);
        const response = try client.receiveMessage(allocator);
        defer allocator.free(response);

        // Success is detected by substring only; the reply is not parsed as JSON.
        if (std.mem.indexOf(u8, response, "experiment_id") == null) {
            std.debug.print("Failed to create experiment on server: {s}\n", .{response});
            return error.ServerError;
        }

        // Also update the local config.
        try saveCurrentExperimentName(allocator, cfg, exp_name);
        if (json) {
            std.debug.print("{{\"success\":true,\"name\":\"{s}\",\"source\":\"server\"}}\n", .{exp_name});
        } else {
            std.debug.print("Created experiment on server: {s}\n", .{exp_name});
        }
    }
}

/// Persists `exp_name` as the current experiment in the saved config.
///
/// Works on a shallow copy of `cfg`, so the caller's deferred `deinit` still
/// frees exactly what `Config.load` allocated. `exp_name` is borrowed rather
/// than duplicated: it only needs to stay alive until `save` returns. The
/// earlier `allocator.dupe` here was never freed, because the copy holding it
/// is not the one that gets `deinit`ed.
fn saveCurrentExperimentName(allocator: std.mem.Allocator, cfg: config.Config, exp_name: []const u8) !void {
    var updated = cfg;
    if (updated.experiment == null) {
        updated.experiment = config.ExperimentConfig{
            .name = "",
            .entrypoint = "",
        };
    }
    updated.experiment.?.name = exp_name;
    try updated.save(allocator);
}
/// Handles `ml experiment list`. Extra positional arguments are ignored.
///
/// Offline mode: reads every row of the local `ml_experiments` table (newest
/// `created_at` first). It prints either a JSON array, or one tab-separated
/// line per experiment containing an `S`/`U` sync flag, the first 8 bytes of
/// the id, the name and the status.
/// Server mode: asks the server over the WebSocket client and prints its raw
/// reply unchanged (the same in text and JSON mode).
fn listExperiments(allocator: std.mem.Allocator, _: []const []const u8, json: bool) !void {
    const cfg = try config.Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const mode_result = try mode.detect(allocator, cfg);
    if (mode.isOffline(mode_result.mode)) {
        // Local mode: list from SQLite.
        const db_path = try cfg.getDBPath(allocator);
        defer allocator.free(db_path);
        var database = try db.DB.init(allocator, db_path);
        defer database.close();

        const sql = "SELECT experiment_id, name, description, created_at, status, synced FROM ml_experiments ORDER BY created_at DESC;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);

        var experiments = try std.ArrayList(ExperimentInfo).initCapacity(allocator, 16);
        defer {
            for (experiments.items) |*e| e.deinit(allocator);
            experiments.deinit(allocator);
        }
        while (try db.DB.step(stmt)) {
            // Reserve the slot before copying the row so nothing can fail
            // between allocating the row's strings and the list owning them.
            try experiments.ensureUnusedCapacity(allocator, 1);
            experiments.appendAssumeCapacity(try readExperimentRow(allocator, stmt));
        }

        if (json) {
            std.debug.print("[", .{});
            for (experiments.items, 0..) |e, i| {
                if (i > 0) std.debug.print(",", .{});
                std.debug.print("{{\"id\":\"{s}\",\"name\":\"{s}\",\"status\":\"{s}\",\"description\":\"{s}\",\"synced\":{s}}}", .{ e.id, e.name, e.status, e.description, if (e.synced) "true" else "false" });
            }
            std.debug.print("]\n", .{});
        } else if (experiments.items.len == 0) {
            std.debug.print("No experiments found.\n", .{});
        } else {
            for (experiments.items) |e| {
                const sync_indicator = if (e.synced) "S" else "U";
                // Clamp: ids shorter than 8 bytes (e.g. rows not created by
                // this CLI) would otherwise slice out of bounds.
                const short_id = e.id[0..@min(e.id.len, 8)];
                std.debug.print("{s}\t{s}\t{s}\t{s}\n", .{ sync_indicator, short_id, e.name, e.status });
            }
        }
    } else {
        // Server mode: query the server via WebSocket.
        const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
        defer allocator.free(api_key_hash);
        const ws_url = try cfg.getWebSocketUrl(allocator);
        defer allocator.free(ws_url);
        var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
        defer client.close();

        try client.sendListExperiments(api_key_hash);
        const response = try client.receiveMessage(allocator);
        defer allocator.free(response);

        // For now the raw reply is shown as-is, regardless of `json`.
        std.debug.print("{s}\n", .{response});
    }
}

/// Copies the current result row of an `ml_experiments` SELECT (columns in the
/// order experiment_id, name, description, created_at, status, synced) into an
/// owned `ExperimentInfo`. The text is duplicated because column text is
/// presumably only valid until the statement is stepped again (SQLite column
/// semantics). If any copy fails, the copies already made are freed.
fn readExperimentRow(allocator: std.mem.Allocator, stmt: anytype) !ExperimentInfo {
    const id = try allocator.dupe(u8, db.DB.columnText(stmt, 0));
    errdefer allocator.free(id);
    const name = try allocator.dupe(u8, db.DB.columnText(stmt, 1));
    errdefer allocator.free(name);
    const description = try allocator.dupe(u8, db.DB.columnText(stmt, 2));
    errdefer allocator.free(description);
    const created_at = try allocator.dupe(u8, db.DB.columnText(stmt, 3));
    errdefer allocator.free(created_at);
    const status = try allocator.dupe(u8, db.DB.columnText(stmt, 4));

    return .{
        .id = id,
        .name = name,
        .description = description,
        .created_at = created_at,
        .status = status,
        .synced = db.DB.columnInt64(stmt, 5) != 0,
    };
}
/// Handles `ml experiment show <experiment_id>`.
///
/// Offline mode: looks the id up in the local `ml_experiments` table. It
/// prints the experiment's fields plus the number of rows in `ml_runs` for it
/// and their latest `start_time`. In JSON mode `last_run` is a string, or JSON
/// `null` when no start time is available.
/// Server mode: requests the experiment over the WebSocket client and prints
/// the raw reply unchanged (the same in text and JSON mode).
/// Errors: `error.MissingArgument` without an id, `error.NotFound` when the id
/// is not in the local table, plus database/network errors.
fn showExperiment(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
    if (args.len == 0) {
        core.output.err("experiment_id required");
        return error.MissingArgument;
    }
    const exp_id = args[0];

    const cfg = try config.Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const mode_result = try mode.detect(allocator, cfg);
    if (mode.isOffline(mode_result.mode)) {
        // Local mode: show from SQLite.
        const db_path = try cfg.getDBPath(allocator);
        defer allocator.free(db_path);
        var database = try db.DB.init(allocator, db_path);
        defer database.close();

        // Experiment details. The column slices below are read from exp_stmt,
        // which is neither stepped again nor finalized before they are printed.
        const exp_sql = "SELECT experiment_id, name, description, created_at, status, synced FROM ml_experiments WHERE experiment_id = ?;";
        const exp_stmt = try database.prepare(exp_sql);
        defer db.DB.finalize(exp_stmt);
        try db.DB.bindText(exp_stmt, 1, exp_id);
        if (!try db.DB.step(exp_stmt)) {
            core.output.err("Experiment not found");
            return error.NotFound;
        }
        const name = db.DB.columnText(exp_stmt, 1);
        const description = db.DB.columnText(exp_stmt, 2);
        const created_at = db.DB.columnText(exp_stmt, 3);
        const status = db.DB.columnText(exp_stmt, 4);
        const synced = db.DB.columnInt64(exp_stmt, 5) != 0;

        // Run count and last run date.
        const runs_sql =
            "SELECT COUNT(*), MAX(start_time) FROM ml_runs WHERE experiment_id = ?;";
        const runs_stmt = try database.prepare(runs_sql);
        defer db.DB.finalize(runs_stmt);
        try db.DB.bindText(runs_stmt, 1, exp_id);
        var run_count: i64 = 0;
        var last_run: ?[]const u8 = null;
        if (try db.DB.step(runs_stmt)) {
            run_count = db.DB.columnInt64(runs_stmt, 0);
            // An empty text value is treated as "no runs".
            if (db.DB.columnText(runs_stmt, 1).len > 0) {
                last_run = try allocator.dupe(u8, db.DB.columnText(runs_stmt, 1));
            }
        }
        defer if (last_run) |lr| allocator.free(lr);

        if (json) {
            std.debug.print("{{\"experiment_id\":\"{s}\",\"name\":\"{s}\",\"description\":\"{s}\",\"status\":\"{s}\",\"created_at\":\"{s}\",\"synced\":{s},\"run_count\":{d},\"last_run\":", .{
                exp_id,
                name,
                description,
                status,
                created_at,
                if (synced) "true" else "false",
                run_count,
            });
            // Emit a real JSON null when there is no last run; quoting it
            // would produce the string "null" instead.
            if (last_run) |lr| {
                std.debug.print("\"{s}\"}}\n", .{lr});
            } else {
                std.debug.print("null}}\n", .{});
            }
        } else {
            std.debug.print("{s}\t{s}\t{s}\n", .{ name, exp_id, status });
            if (description.len > 0) {
                std.debug.print("Description\t{s}\n", .{description});
            }
            std.debug.print("Created\t{s}\n", .{created_at});
            std.debug.print("Synced\t{s}\n", .{if (synced) "yes" else "no"});
            std.debug.print("Runs\t{d}\n", .{run_count});
            if (last_run) |lr| {
                std.debug.print("LastRun\t{s}\n", .{lr});
            }
        }
    } else {
        // Server mode: query the server via WebSocket.
        const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
        defer allocator.free(api_key_hash);
        const ws_url = try cfg.getWebSocketUrl(allocator);
        defer allocator.free(ws_url);
        var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
        defer client.close();

        try client.sendGetExperimentByID(api_key_hash, exp_id);
        const response = try client.receiveMessage(allocator);
        defer allocator.free(response);

        // The raw reply is shown as-is, regardless of `json`.
        std.debug.print("{s}\n", .{response});
    }
}
/// Produces a new experiment id by delegating to `uuid.generateV4`.
/// The caller owns the returned slice and frees it with `allocator`.
fn generateExperimentID(allocator: std.mem.Allocator) ![]const u8 {
    const fresh_id = try uuid.generateV4(allocator);
    return fresh_id;
}
/// Writes the `ml experiment` help text to stderr via `std.debug.print`.
/// Nothing in here can fail; the `!void` return type lets callers write
/// `return printUsage();` from their own error-returning functions.
fn printUsage() !void {
    const usage_text =
        "Usage: ml experiment <subcommand> [options]\n\n" ++
        "Subcommands:\n" ++
        "\tcreate --name <name> [--description <desc>]\tCreate new experiment\n" ++
        "\tlist\t\t\t\t\t\tList experiments\n" ++
        "\tshow <experiment_id>\t\t\t\tShow experiment details\n" ++
        "\nOptions:\n" ++
        "\t--name <string>\t\tExperiment name (required for create)\n" ++
        "\t--description <string>\tExperiment description\n" ++
        "\t--help, -h\t\tShow this help\n" ++
        "\t--json\t\t\tOutput structured JSON\n\n" ++
        "Examples:\n" ++
        "\tml experiment create --name \"baseline-cnn\"\n" ++
        "\tml experiment list\n";
    std.debug.print("{s}", .{usage_text});
}