fetch_ml/cli/src/commands/sync.zig
Jeremie Fraeys 382c67edfc
fix(cli): WebSocket protocol and sync command fixes
- Add sendSyncRun method for run synchronization
- Add sendRerunRequest method for queue rerun
- Add sync_run (0x26) and rerun_request (0x27) opcodes
- Fix protocol import path to relative path
- Fix db.Stmt type alias usage in sync.zig
2026-02-21 17:59:14 -05:00

285 lines
9.9 KiB
Zig

const std = @import("std");
const colors = @import("../utils/colors.zig");
const config = @import("../config.zig");
const db = @import("../db.zig");
const ws = @import("../net/ws/client.zig");
const crypto = @import("../utils/crypto.zig");
const mode = @import("../mode.zig");
const core = @import("../core.zig");
const manifest_lib = @import("../manifest.zig");
/// Entry point for `ml sync`: push locally recorded runs that are not yet marked
/// as synced (`ml_runs.synced = 0`) to the server.
///
/// Arguments: `--help`/`-h` prints usage and returns; `--json` switches output to
/// JSON; any argument not starting with `--` is taken as a specific run id (the
/// last one wins). Other `--` flags are ignored.
///
/// Each run is pushed with `syncRun`. A per-run failure is reported (text mode)
/// and skipped; only runs that sync successfully are updated to `synced = 1`.
/// Returns `error.RequiresServer` when mode detection reports offline mode; errors
/// from config loading, the database and the WebSocket connect propagate.
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
    var flags = core.flags.CommonFlags{};
    var specific_run_id: ?[]const u8 = null;
    for (args) |arg| {
        if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
            return printUsage();
        } else if (std.mem.eql(u8, arg, "--json")) {
            flags.json = true;
        } else if (!std.mem.startsWith(u8, arg, "--")) {
            specific_run_id = arg;
        }
    }
    core.output.init(if (flags.json) .json else .text);

    const cfg = try config.Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const mode_result = try mode.detect(allocator, cfg);
    if (mode.isOffline(mode_result.mode)) {
        colors.printError("ml sync requires server connection\n", .{});
        return error.RequiresServer;
    }

    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);
    var database = try db.DB.init(allocator, db_path);
    defer database.close();

    var runs_to_sync: std.ArrayList(RunInfo) = .empty;
    defer {
        for (runs_to_sync.items) |*r| r.deinit(allocator);
        runs_to_sync.deinit(allocator);
    }

    if (specific_run_id) |run_id| {
        const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);
        try db.DB.bindText(stmt, 1, run_id);
        if (try db.DB.step(stmt)) {
            var info = try RunInfo.fromStmt(stmt, allocator);
            // The list only takes ownership once append succeeds.
            errdefer info.deinit(allocator);
            try runs_to_sync.append(allocator, info);
        } else {
            colors.printWarning("Run {s} already synced or not found\n", .{run_id});
            return;
        }
    } else {
        const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);
        while (try db.DB.step(stmt)) {
            var info = try RunInfo.fromStmt(stmt, allocator);
            // The list only takes ownership once append succeeds.
            errdefer info.deinit(allocator);
            try runs_to_sync.append(allocator, info);
        }
    }

    if (runs_to_sync.items.len == 0) {
        if (flags.json) {
            // Keep --json output machine-readable even when there is nothing to do;
            // same shape as the summary printed after a sync.
            std.debug.print("{{\"success\":true,\"synced\":0,\"total\":0}}\n", .{});
        } else {
            colors.printSuccess("All runs already synced!\n", .{});
        }
        return;
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    var success_count: usize = 0;
    for (runs_to_sync.items) |run_info| {
        // Run ids may be shorter than 8 bytes (e.g. `ml sync abc123`); clamp the
        // display prefix instead of slicing out of bounds.
        const short_id = run_info.run_id[0..@min(run_info.run_id.len, 8)];
        if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{short_id});
        syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| {
            if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ short_id, err });
            continue;
        };
        // Only mark the run as synced after the server acknowledged it.
        const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;";
        const update_stmt = try database.prepare(update_sql);
        defer db.DB.finalize(update_stmt);
        try db.DB.bindText(update_stmt, 1, run_info.run_id);
        _ = try db.DB.step(update_stmt);
        success_count += 1;
    }

    database.checkpointOnExit();
    if (flags.json) {
        std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len });
    } else {
        colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len });
    }
}
/// Heap-owned copy of one row selected from `ml_runs`.
/// Every slice is allocated by `fromStmt` and released by `deinit` using the same allocator.
const RunInfo = struct {
    run_id: []const u8,
    experiment_id: []const u8,
    name: []const u8,
    status: []const u8,
    start_time: []const u8,
    /// `null` when column 5 came back as an empty string
    /// (presumably SQL NULL — confirm against `db.DB.columnText`).
    end_time: ?[]const u8,

    /// Copy the current row of `stmt` into a new `RunInfo`.
    ///
    /// Expects columns in the order run_id, experiment_id, name, status,
    /// start_time, end_time. The text is duplicated because the column slices are
    /// presumably only valid until the statement advances (TODO confirm).
    /// On allocation failure, strings already copied are freed before returning.
    fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo {
        const s = stmt.?;

        const run_id = try allocator.dupe(u8, db.DB.columnText(s, 0));
        errdefer allocator.free(run_id);
        const experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1));
        errdefer allocator.free(experiment_id);
        const name = try allocator.dupe(u8, db.DB.columnText(s, 2));
        errdefer allocator.free(name);
        const status = try allocator.dupe(u8, db.DB.columnText(s, 3));
        errdefer allocator.free(status);
        const start_time = try allocator.dupe(u8, db.DB.columnText(s, 4));
        errdefer allocator.free(start_time);

        const end_text = db.DB.columnText(s, 5);
        const end_time: ?[]const u8 = if (end_text.len > 0) try allocator.dupe(u8, end_text) else null;

        return .{
            .run_id = run_id,
            .experiment_id = experiment_id,
            .name = name,
            .status = status,
            .start_time = start_time,
            .end_time = end_time,
        };
    }

    /// Free every owned slice, then poison `self` so later use is caught in safe builds.
    fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void {
        allocator.free(self.run_id);
        allocator.free(self.experiment_id);
        allocator.free(self.name);
        allocator.free(self.status);
        allocator.free(self.start_time);
        if (self.end_time) |et| allocator.free(et);
        self.* = undefined;
    }
};
/// Push one run, plus its metrics and params from the local DB, to the server and
/// wait for the reply.
///
/// The payload is a single JSON object with keys run_id, experiment_id, name,
/// status, start_time, end_time (string or null), metrics (array of
/// {key,value,step}) and params (array of {key,value}). It is sent with
/// `client.sendSyncRun` together with `api_key_hash`.
///
/// Returns `error.SyncRejected` when the reply does not contain the text
/// `sync_ack`; database, allocation and client errors propagate. Everything
/// allocated here is freed before returning.
fn syncRun(
    allocator: std.mem.Allocator,
    database: *db.DB,
    client: *ws.Client,
    run_info: RunInfo,
    api_key_hash: []const u8,
) !void {
    // Metrics recorded for this run.
    var metrics: std.ArrayList(Metric) = .empty;
    defer {
        for (metrics.items) |*m| m.deinit(allocator);
        metrics.deinit(allocator);
    }
    const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
    const metrics_stmt = try database.prepare(metrics_sql);
    defer db.DB.finalize(metrics_stmt);
    try db.DB.bindText(metrics_stmt, 1, run_info.run_id);
    while (try db.DB.step(metrics_stmt)) {
        const key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0));
        // Until append succeeds the list does not own `key`.
        errdefer allocator.free(key);
        try metrics.append(allocator, .{
            .key = key,
            .value = db.DB.columnDouble(metrics_stmt, 1),
            .step = db.DB.columnInt64(metrics_stmt, 2),
        });
    }

    // Params recorded for this run.
    var params: std.ArrayList(Param) = .empty;
    defer {
        for (params.items) |*p| p.deinit(allocator);
        params.deinit(allocator);
    }
    const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
    const params_stmt = try database.prepare(params_sql);
    defer db.DB.finalize(params_stmt);
    try db.DB.bindText(params_stmt, 1, run_info.run_id);
    while (try db.DB.step(params_stmt)) {
        const key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0));
        errdefer allocator.free(key);
        const value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1));
        errdefer allocator.free(value);
        try params.append(allocator, .{ .key = key, .value = value });
    }

    // Build the sync payload. Every string goes through writeJsonString so quotes,
    // backslashes or control characters in names/params cannot break the JSON.
    var sync_json: std.ArrayList(u8) = .empty;
    defer sync_json.deinit(allocator);
    const writer = sync_json.writer(allocator);

    try writer.writeAll("{\"run_id\":");
    try writeJsonString(writer, run_info.run_id);
    try writer.writeAll(",\"experiment_id\":");
    try writeJsonString(writer, run_info.experiment_id);
    try writer.writeAll(",\"name\":");
    try writeJsonString(writer, run_info.name);
    try writer.writeAll(",\"status\":");
    try writeJsonString(writer, run_info.status);
    try writer.writeAll(",\"start_time\":");
    try writeJsonString(writer, run_info.start_time);
    try writer.writeAll(",\"end_time\":");
    if (run_info.end_time) |et| {
        try writeJsonString(writer, et);
    } else {
        try writer.writeAll("null");
    }

    try writer.writeAll(",\"metrics\":[");
    for (metrics.items, 0..) |m, i| {
        if (i > 0) try writer.writeByte(',');
        try writer.writeAll("{\"key\":");
        try writeJsonString(writer, m.key);
        if (std.math.isFinite(m.value)) {
            try writer.print(",\"value\":{d}", .{m.value});
        } else {
            // JSON has no NaN/Infinity literal and `{d}` would emit `nan`/`inf`.
            // NOTE(review): confirm the server accepts a null metric value.
            try writer.writeAll(",\"value\":null");
        }
        try writer.print(",\"step\":{d}}}", .{m.step});
    }

    try writer.writeAll("],\"params\":[");
    for (params.items, 0..) |p, i| {
        if (i > 0) try writer.writeByte(',');
        try writer.writeAll("{\"key\":");
        try writeJsonString(writer, p.key);
        try writer.writeAll(",\"value\":");
        try writeJsonString(writer, p.value);
        try writer.writeByte('}');
    }
    try writer.writeAll("]}");

    // Send sync_run message and wait for the acknowledgement.
    try client.sendSyncRun(sync_json.items, api_key_hash);
    const response = try client.receiveMessage(allocator);
    defer allocator.free(response);
    if (std.mem.indexOf(u8, response, "sync_ack") == null) {
        return error.SyncRejected;
    }
}

/// Write `s` to `writer` as a quoted JSON string literal.
/// Escapes `"`, `\` and every control byte below 0x20 as RFC 8259 requires.
/// Other bytes are copied verbatim, so valid UTF-8 input stays valid UTF-8.
fn writeJsonString(writer: anytype, s: []const u8) !void {
    try writer.writeByte('"');
    for (s) |c| {
        switch (c) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            0x00...0x08, 0x0B, 0x0C, 0x0E...0x1F => try writer.print("\\u{x:0>4}", .{c}),
            else => try writer.writeByte(c),
        }
    }
    try writer.writeByte('"');
}
/// A metric sample read for a run: name, value and the step it was logged at.
/// Only `key` is heap-owned; release it with `deinit` and the allocator that produced it.
const Metric = struct {
    key: []const u8,
    value: f64,
    step: i64,

    /// Free the owned `key` slice.
    fn deinit(metric: *Metric, allocator: std.mem.Allocator) void {
        const owned_key = metric.key;
        allocator.free(owned_key);
    }
};
/// A key/value parameter read for a run; both slices are heap-owned.
const Param = struct {
    key: []const u8,
    value: []const u8,

    /// Free `key`, then `value`, with `allocator`.
    fn deinit(param: *Param, allocator: std.mem.Allocator) void {
        for ([_][]const u8{ param.key, param.value }) |owned| allocator.free(owned);
    }
};
/// Print `ml sync` help text to stderr via `std.debug.print`.
fn printUsage() void {
    std.debug.print(
        \\Usage: ml sync [run_id] [options]
        \\
        \\Push local experiment runs to the server.
        \\
        \\Options:
        \\ --json Output structured JSON
        \\ --help, -h Show this help message
        \\
        \\Examples:
        \\ ml sync # Sync all unsynced runs
        \\ ml sync abc123 # Sync specific run
        \\
    , .{});
}
/// Find the git root directory by walking up from `start_path`.
///
/// `start_path` is first canonicalized with `std.fs.realpath`. Then each directory
/// from there up to the filesystem root is probed for a `.git` entry with
/// `std.fs.accessAbsolute`; any access error counts as "absent".
/// Returns a copy of the first matching directory path, allocated with
/// `allocator` and owned by the caller, or `null` if the root is reached without
/// a match. Errors from `realpath`, path joining and the final copy propagate.
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
    var resolved_buf: [std.fs.max_path_bytes]u8 = undefined;
    var dir: []const u8 = try std.fs.realpath(start_path, &resolved_buf);

    while (true) {
        const candidate = try std.fs.path.join(allocator, &.{ dir, ".git" });
        const present = blk: {
            defer allocator.free(candidate);
            std.fs.accessAbsolute(candidate, .{}) catch break :blk false;
            break :blk true;
        };
        if (present) return try allocator.dupe(u8, dir);

        // dirname() is null at the root; the equality check is a second guard
        // against a path that is its own parent.
        const up = std.fs.path.dirname(dir) orelse return null;
        if (std.mem.eql(u8, up, dir)) return null;
        dir = up;
    }
}