- Add sendSyncRun method for run synchronization - Add sendRerunRequest method for queue rerun - Add sync_run (0x26) and rerun_request (0x27) opcodes - Fix protocol import path to relative path - Fix db.Stmt type alias usage in sync.zig
285 lines
9.9 KiB
Zig
285 lines
9.9 KiB
Zig
const std = @import("std");
|
|
const colors = @import("../utils/colors.zig");
|
|
const config = @import("../config.zig");
|
|
const db = @import("../db.zig");
|
|
const ws = @import("../net/ws/client.zig");
|
|
const crypto = @import("../utils/crypto.zig");
|
|
const mode = @import("../mode.zig");
|
|
const core = @import("../core.zig");
|
|
const manifest_lib = @import("../manifest.zig");
|
|
|
|
/// Entry point for `ml sync`: push locally recorded runs that are still
/// marked `synced = 0` in `ml_runs` to the server.
///
/// Arguments:
///   --help, -h   print usage and return immediately
///   --json       emit a JSON summary instead of coloured text
///   <run_id>     any argument not starting with "--" selects a single run
///                (when several are given, the last one wins)
///
/// Each selected run is sent with `syncRun`; only when that succeeds is the
/// row updated to `synced = 1`. A failure for one run is reported (text mode)
/// and the loop continues with the next run.
///
/// Returns `error.RequiresServer` when `mode.detect` reports offline mode;
/// configuration, database, allocation and connection errors are propagated.
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
    var flags = core.flags.CommonFlags{};
    var specific_run_id: ?[]const u8 = null;

    for (args) |arg| {
        if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
            return printUsage();
        } else if (std.mem.eql(u8, arg, "--json")) {
            flags.json = true;
        } else if (!std.mem.startsWith(u8, arg, "--")) {
            specific_run_id = arg;
        }
    }

    core.output.init(if (flags.json) .json else .text);

    const cfg = try config.Config.load(allocator);
    defer {
        // Copy into a `var` so `deinit` can be called on it (it presumably
        // takes a mutable receiver).
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const mode_result = try mode.detect(allocator, cfg);
    if (mode.isOffline(mode_result.mode)) {
        colors.printError("ml sync requires server connection\n", .{});
        return error.RequiresServer;
    }

    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);

    var database = try db.DB.init(allocator, db_path);
    defer database.close();

    var runs_to_sync: std.ArrayList(RunInfo) = .empty;
    defer {
        for (runs_to_sync.items) |*r| r.deinit(allocator);
        runs_to_sync.deinit(allocator);
    }

    if (specific_run_id) |run_id| {
        const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);
        try db.DB.bindText(stmt, 1, run_id);

        if (try db.DB.step(stmt)) {
            // Reserve the slot before building the RunInfo so its owned
            // strings cannot leak if growing the list fails.
            try runs_to_sync.ensureUnusedCapacity(allocator, 1);
            runs_to_sync.appendAssumeCapacity(try RunInfo.fromStmt(stmt, allocator));
        } else {
            colors.printWarning("Run {s} already synced or not found\n", .{run_id});
            return;
        }
    } else {
        const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);

        while (try db.DB.step(stmt)) {
            // Same leak-free pattern as above: capacity first, then build.
            try runs_to_sync.ensureUnusedCapacity(allocator, 1);
            runs_to_sync.appendAssumeCapacity(try RunInfo.fromStmt(stmt, allocator));
        }
    }

    if (runs_to_sync.items.len == 0) {
        if (flags.json) {
            // Keep JSON consumers fed with the same summary shape as below.
            std.debug.print("{{\"success\":true,\"synced\":0,\"total\":0}}\n", .{});
        } else {
            colors.printSuccess("All runs already synced!\n", .{});
        }
        return;
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    var success_count: usize = 0;
    for (runs_to_sync.items) |run_info| {
        // Run ids may be shorter than 8 bytes; clamp so the slice can never
        // go out of bounds.
        const short_id = run_info.run_id[0..@min(run_info.run_id.len, 8)];

        if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{short_id});

        syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| {
            if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ short_id, err });
            continue;
        };

        // Only reached when syncRun returned without error, i.e. the server
        // reply passed its sync_ack check.
        const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;";
        const update_stmt = try database.prepare(update_sql);
        defer db.DB.finalize(update_stmt);
        try db.DB.bindText(update_stmt, 1, run_info.run_id);
        _ = try db.DB.step(update_stmt);

        success_count += 1;
    }

    database.checkpointOnExit();

    if (flags.json) {
        std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len });
    } else {
        colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len });
    }
}
|
|
|
|
/// One `ml_runs` row selected for syncing. Every string field is an owned
/// heap copy and is released by `deinit`.
const RunInfo = struct {
    run_id: []const u8,
    experiment_id: []const u8,
    name: []const u8,
    status: []const u8,
    start_time: []const u8,
    /// `null` when the `end_time` column text is empty.
    end_time: ?[]const u8,

    /// Copy the current row of `stmt` into a new `RunInfo`.
    ///
    /// Expects the column order
    /// `run_id, experiment_id, name, status, start_time, end_time`.
    /// `stmt` must be non-null (it is unwrapped with `.?`).
    /// On error, every string duplicated so far is freed, so nothing leaks.
    fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo {
        const s = stmt.?;

        const run_id = try allocator.dupe(u8, db.DB.columnText(s, 0));
        errdefer allocator.free(run_id);

        const experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1));
        errdefer allocator.free(experiment_id);

        const name = try allocator.dupe(u8, db.DB.columnText(s, 2));
        errdefer allocator.free(name);

        const status = try allocator.dupe(u8, db.DB.columnText(s, 3));
        errdefer allocator.free(status);

        const start_time = try allocator.dupe(u8, db.DB.columnText(s, 4));
        errdefer allocator.free(start_time);

        // An empty end_time column (presumably NULL read back as "") means
        // the run has not finished.
        const end_text = db.DB.columnText(s, 5);
        const end_time: ?[]const u8 = if (end_text.len > 0) try allocator.dupe(u8, end_text) else null;

        return RunInfo{
            .run_id = run_id,
            .experiment_id = experiment_id,
            .name = name,
            .status = status,
            .start_time = start_time,
            .end_time = end_time,
        };
    }

    /// Free every owned string; `allocator` must be the one passed to `fromStmt`.
    fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void {
        allocator.free(self.run_id);
        allocator.free(self.experiment_id);
        allocator.free(self.name);
        allocator.free(self.status);
        allocator.free(self.start_time);
        if (self.end_time) |et| allocator.free(et);
    }
};
|
|
|
|
/// Upload one run, together with its metrics and params, and wait for the
/// server's reply.
///
/// Reads the run's rows from `ml_metrics` and `ml_params`, serialises them
/// with `run_info` into a single JSON object, and sends it with
/// `client.sendSyncRun`. It then reads one message back.
///
/// Returns `error.SyncRejected` when that reply does not contain the text
/// `sync_ack`. Database, allocation and socket errors are propagated.
fn syncRun(
    allocator: std.mem.Allocator,
    database: *db.DB,
    client: *ws.Client,
    run_info: RunInfo,
    api_key_hash: []const u8,
) !void {
    // Get metrics for this run.
    var metrics: std.ArrayList(Metric) = .empty;
    defer {
        for (metrics.items) |*m| m.deinit(allocator);
        metrics.deinit(allocator);
    }

    const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
    const metrics_stmt = try database.prepare(metrics_sql);
    defer db.DB.finalize(metrics_stmt);
    try db.DB.bindText(metrics_stmt, 1, run_info.run_id);

    while (try db.DB.step(metrics_stmt)) {
        // Reserve first so the duplicated key cannot leak if growing fails.
        try metrics.ensureUnusedCapacity(allocator, 1);
        metrics.appendAssumeCapacity(Metric{
            .key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)),
            .value = db.DB.columnDouble(metrics_stmt, 1),
            .step = db.DB.columnInt64(metrics_stmt, 2),
        });
    }

    // Get params for this run.
    var params: std.ArrayList(Param) = .empty;
    defer {
        for (params.items) |*p| p.deinit(allocator);
        params.deinit(allocator);
    }

    const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
    const params_stmt = try database.prepare(params_sql);
    defer db.DB.finalize(params_stmt);
    try db.DB.bindText(params_stmt, 1, run_info.run_id);

    while (try db.DB.step(params_stmt)) {
        try params.ensureUnusedCapacity(allocator, 1);
        const key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0));
        errdefer allocator.free(key); // if duplicating the value fails
        const value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1));
        params.appendAssumeCapacity(Param{ .key = key, .value = value });
    }

    // Build the sync JSON. Every string comes from the local DB, so it is
    // escaped rather than spliced in raw; otherwise a quote or backslash in a
    // run name, key or value would corrupt the payload.
    var sync_json: std.ArrayList(u8) = .empty;
    defer sync_json.deinit(allocator);
    const writer = sync_json.writer(allocator);

    try writer.writeAll("{\"run_id\":");
    try writeJsonString(writer, run_info.run_id);
    try writer.writeAll(",\"experiment_id\":");
    try writeJsonString(writer, run_info.experiment_id);
    try writer.writeAll(",\"name\":");
    try writeJsonString(writer, run_info.name);
    try writer.writeAll(",\"status\":");
    try writeJsonString(writer, run_info.status);
    try writer.writeAll(",\"start_time\":");
    try writeJsonString(writer, run_info.start_time);
    try writer.writeAll(",\"end_time\":");
    if (run_info.end_time) |et| {
        try writeJsonString(writer, et);
    } else {
        try writer.writeAll("null");
    }

    // Add metrics.
    try writer.writeAll(",\"metrics\":[");
    for (metrics.items, 0..) |m, i| {
        if (i > 0) try writer.writeByte(',');
        try writer.writeAll("{\"key\":");
        try writeJsonString(writer, m.key);
        try writer.writeAll(",\"value\":");
        // JSON has no NaN/Infinity literals; encode non-finite values as null.
        if (std.math.isFinite(m.value)) {
            try writer.print("{d}", .{m.value});
        } else {
            try writer.writeAll("null");
        }
        try writer.print(",\"step\":{d}}}", .{m.step});
    }

    // Add params.
    try writer.writeAll("],\"params\":[");
    for (params.items, 0..) |p, i| {
        if (i > 0) try writer.writeByte(',');
        try writer.writeAll("{\"key\":");
        try writeJsonString(writer, p.key);
        try writer.writeAll(",\"value\":");
        try writeJsonString(writer, p.value);
        try writer.writeByte('}');
    }
    try writer.writeAll("]}");

    // Send the sync_run message.
    try client.sendSyncRun(sync_json.items, api_key_hash);

    // Wait for the reply. NOTE(review): this is a substring check; a stricter
    // parse of the reply would depend on the server protocol (not visible here).
    const response = try client.receiveMessage(allocator);
    defer allocator.free(response);

    if (std.mem.indexOf(u8, response, "sync_ack") == null) {
        return error.SyncRejected;
    }
}

/// Write `s` to `writer` as a double-quoted JSON string literal.
///
/// Quotes, backslashes and ASCII control characters are escaped as RFC 8259
/// requires. Bytes >= 0x80 are written unchanged, which assumes `s` is valid
/// UTF-8. `writer` must provide `writeByte`, `writeAll` and `print`.
fn writeJsonString(writer: anytype, s: []const u8) !void {
    try writer.writeByte('"');
    for (s) |c| {
        switch (c) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            0x00...0x08, 0x0B, 0x0C, 0x0E...0x1F => try writer.print("\\u{x:0>4}", .{c}),
            else => try writer.writeByte(c),
        }
    }
    try writer.writeByte('"');
}
|
|
|
|
/// One metric sample belonging to a run (a row read from `ml_metrics`).
const Metric = struct {
    /// Metric name; an owned copy that `deinit` releases.
    key: []const u8,
    /// Recorded metric value.
    value: f64,
    /// Step at which the value was recorded.
    step: i64,

    /// Release the owned `key`; the numeric fields hold no memory.
    fn deinit(self: *Metric, allocator: std.mem.Allocator) void {
        const owned_key: []const u8 = self.key;
        allocator.free(owned_key);
    }
};
|
|
|
|
/// A single key/value parameter recorded for a run (a row read from `ml_params`).
const Param = struct {
    /// Parameter name; an owned copy that `deinit` releases.
    key: []const u8,
    /// Parameter value as text; an owned copy that `deinit` releases.
    value: []const u8,

    /// Free both owned strings with the allocator that produced them.
    fn deinit(self: *Param, allocator: std.mem.Allocator) void {
        const owned = [_][]const u8{ self.key, self.value };
        for (owned) |buf| allocator.free(buf);
    }
};
|
|
|
|
/// Print the `ml sync` help text to stderr via `std.debug.print`.
fn printUsage() void {
    // A single multiline literal; the trailing empty `\\` line supplies the
    // final newline.
    const usage_text =
        \\Usage: ml sync [run_id] [options]
        \\
        \\Push local experiment runs to the server.
        \\
        \\Options:
        \\ --json Output structured JSON
        \\ --help, -h Show this help message
        \\
        \\Examples:
        \\ ml sync # Sync all unsynced runs
        \\ ml sync abc123 # Sync specific run
        \\
    ;
    std.debug.print("{s}", .{usage_text});
}
|
|
|
|
/// Find the git root directory by walking up from the given path.
///
/// `start_path` is first canonicalised with `std.fs.realpath`. Each directory
/// from there up to the filesystem root is then checked for a `.git` entry.
///
/// Returns a newly allocated copy of the first directory that contains one;
/// the caller owns it and must free it with `allocator`. Returns `null` when
/// the root is reached without a match.
///
/// Errors from `realpath` and from allocation are propagated. Any error from
/// the `.git` access check is treated as "not present here", and the walk
/// continues upward.
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const path = try std.fs.realpath(start_path, &buf);

    // `realpath` yields `[]u8` but `dirname` yields `[]const u8`; typing the
    // cursor as a const slice lets it be re-pointed at the parent below.
    var current: []const u8 = path;
    while (true) {
        const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" });
        defer allocator.free(git_path);

        const has_git = blk: {
            std.fs.accessAbsolute(git_path, .{}) catch break :blk false;
            break :blk true;
        };
        if (has_git) return try allocator.dupe(u8, current);

        // `dirname` returns null once there is no parent left (the root).
        const parent = std.fs.path.dirname(current) orelse return null;
        if (std.mem.eql(u8, parent, current)) return null;
        current = parent;
    }
}
|