const std = @import("std");
const colors = @import("../utils/colors.zig");
const config = @import("../config.zig");
const db = @import("../db.zig");
const ws = @import("../net/ws/client.zig");
const crypto = @import("../utils/crypto.zig");
const mode = @import("../mode.zig");
const core = @import("../core.zig");
const manifest_lib = @import("../manifest.zig");

/// Entry point for `ml sync`: push local runs whose `synced` column is 0 to the
/// server over a WebSocket connection, then mark each accepted run as synced.
///
/// Arguments:
///   - a positional (non `--`) argument selects a single run id to sync;
///   - `--json` switches output to a one-line JSON summary;
///   - `--help` / `-h` prints usage and returns.
///
/// Returns `error.RequiresServer` when the detected mode is offline. Errors from
/// loading config, opening the database, connecting, or marking a run as synced
/// propagate. A failure while syncing an individual run is reported (text mode
/// only) and that run is skipped; it stays unsynced locally.
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
    var flags = core.flags.CommonFlags{};
    var specific_run_id: ?[]const u8 = null;

    for (args) |arg| {
        if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
            return printUsage();
        } else if (std.mem.eql(u8, arg, "--json")) {
            flags.json = true;
        } else if (!std.mem.startsWith(u8, arg, "--")) {
            specific_run_id = arg;
        }
    }

    core.output.init(if (flags.json) .json else .text);

    const cfg = try config.Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const mode_result = try mode.detect(allocator, cfg);
    if (mode.isOffline(mode_result.mode)) {
        colors.printError("ml sync requires server connection\n", .{});
        return error.RequiresServer;
    }

    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);

    var database = try db.DB.init(allocator, db_path);
    defer database.close();

    var runs_to_sync: std.ArrayList(RunInfo) = .empty;
    defer {
        for (runs_to_sync.items) |*r| r.deinit(allocator);
        runs_to_sync.deinit(allocator);
    }

    if (specific_run_id) |run_id| {
        const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);
        try db.DB.bindText(stmt, 1, run_id);
        if (try db.DB.step(stmt)) {
            try appendRun(allocator, &runs_to_sync, stmt);
        } else {
            colors.printWarning("Run {s} already synced or not found\n", .{run_id});
            return;
        }
    } else {
        const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;";
        const stmt = try database.prepare(sql);
        defer db.DB.finalize(stmt);
        while (try db.DB.step(stmt)) {
            try appendRun(allocator, &runs_to_sync, stmt);
        }
    }

    if (runs_to_sync.items.len == 0) {
        if (flags.json) {
            // Keep the JSON contract: always emit a summary object, even when empty.
            std.debug.print("{{\"success\":true,\"synced\":0,\"total\":0}}\n", .{});
        } else {
            colors.printSuccess("All runs already synced!\n", .{});
        }
        return;
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    var success_count: usize = 0;
    for (runs_to_sync.items) |run_info| {
        const short_id = shortId(run_info.run_id);
        if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{short_id});

        syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| {
            if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ short_id, err });
            continue;
        };

        try markSynced(&database, run_info.run_id);
        success_count += 1;
    }

    database.checkpointOnExit();

    if (flags.json) {
        std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len });
    } else {
        colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len });
    }
}

/// Return at most the first 8 bytes of `run_id` for compact log output.
/// Slicing `[0..8]` unconditionally would panic on ids shorter than 8 bytes.
fn shortId(run_id: []const u8) []const u8 {
    return run_id[0..@min(run_id.len, 8)];
}

/// Copy the current row of `stmt` into a `RunInfo` and append it to `list`.
/// If the append fails, the row's freshly allocated strings are released.
fn appendRun(allocator: std.mem.Allocator, list: *std.ArrayList(RunInfo), stmt: db.Stmt) !void {
    var info = try RunInfo.fromStmt(stmt, allocator);
    errdefer info.deinit(allocator);
    try list.append(allocator, info);
}

/// Set `synced = 1` for `run_id` in `ml_runs`.
fn markSynced(database: *db.DB, run_id: []const u8) !void {
    const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;";
    const update_stmt = try database.prepare(update_sql);
    defer db.DB.finalize(update_stmt);
    try db.DB.bindText(update_stmt, 1, run_id);
    _ = try db.DB.step(update_stmt);
}

/// One row of `ml_runs`, with every string owned by the allocator passed to
/// `fromStmt`. Release with `deinit` using the same allocator.
const RunInfo = struct {
    run_id: []const u8,
    experiment_id: []const u8,
    name: []const u8,
    status: []const u8,
    start_time: []const u8,
    end_time: ?[]const u8,

    /// Copy the current row of `stmt` (column order: run_id, experiment_id,
    /// name, status, start_time, end_time) into owned strings. An empty
    /// end_time column becomes `null`. On allocation failure, any strings
    /// already copied are freed before the error is returned.
    fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo {
        const s = stmt.?;

        const run_id = try allocator.dupe(u8, db.DB.columnText(s, 0));
        errdefer allocator.free(run_id);
        const experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1));
        errdefer allocator.free(experiment_id);
        const name = try allocator.dupe(u8, db.DB.columnText(s, 2));
        errdefer allocator.free(name);
        const status = try allocator.dupe(u8, db.DB.columnText(s, 3));
        errdefer allocator.free(status);
        const start_time = try allocator.dupe(u8, db.DB.columnText(s, 4));
        errdefer allocator.free(start_time);

        const end_text = db.DB.columnText(s, 5);
        const end_time: ?[]const u8 = if (end_text.len > 0) try allocator.dupe(u8, end_text) else null;

        return RunInfo{
            .run_id = run_id,
            .experiment_id = experiment_id,
            .name = name,
            .status = status,
            .start_time = start_time,
            .end_time = end_time,
        };
    }

    /// Free every string owned by this record.
    fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void {
        allocator.free(self.run_id);
        allocator.free(self.experiment_id);
        allocator.free(self.name);
        allocator.free(self.status);
        allocator.free(self.start_time);
        if (self.end_time) |et| allocator.free(et);
    }
};

/// Upload one run together with its metrics and params, then wait for the
/// server's reply.
///
/// Returns `error.SyncRejected` when the reply does not contain `sync_ack`.
/// Database, allocation, and WebSocket errors propagate.
fn syncRun(
    allocator: std.mem.Allocator,
    database: *db.DB,
    client: *ws.Client,
    run_info: RunInfo,
    api_key_hash: []const u8,
) !void {
    // Get metrics for this run
    var metrics: std.ArrayList(Metric) = .empty;
    defer {
        for (metrics.items) |*m| m.deinit(allocator);
        metrics.deinit(allocator);
    }

    const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
    const metrics_stmt = try database.prepare(metrics_sql);
    defer db.DB.finalize(metrics_stmt);
    try db.DB.bindText(metrics_stmt, 1, run_info.run_id);
    while (try db.DB.step(metrics_stmt)) {
        const key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0));
        // Not yet owned by `metrics`; free it ourselves if the append fails.
        errdefer allocator.free(key);
        try metrics.append(allocator, Metric{
            .key = key,
            .value = db.DB.columnDouble(metrics_stmt, 1),
            .step = db.DB.columnInt64(metrics_stmt, 2),
        });
    }

    // Get params for this run
    var params: std.ArrayList(Param) = .empty;
    defer {
        for (params.items) |*p| p.deinit(allocator);
        params.deinit(allocator);
    }

    const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
    const params_stmt = try database.prepare(params_sql);
    defer db.DB.finalize(params_stmt);
    try db.DB.bindText(params_stmt, 1, run_info.run_id);
    while (try db.DB.step(params_stmt)) {
        const key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0));
        errdefer allocator.free(key);
        const value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1));
        errdefer allocator.free(value);
        try params.append(allocator, Param{ .key = key, .value = value });
    }

    // Build sync JSON
    var sync_json: std.ArrayList(u8) = .empty;
    defer sync_json.deinit(allocator);
    try writeSyncJson(sync_json.writer(allocator), run_info, metrics.items, params.items);

    // Send sync_run message
    try client.sendSyncRun(sync_json.items, api_key_hash);

    // Wait for sync_ack
    const response = try client.receiveMessage(allocator);
    defer allocator.free(response);

    if (std.mem.indexOf(u8, response, "sync_ack") == null) {
        return error.SyncRejected;
    }
}

/// Serialize a run and its metrics/params as the `sync_run` JSON payload:
/// `{"run_id":..,"experiment_id":..,"name":..,"status":..,"start_time":..,
///   "end_time":..|null,"metrics":[{"key","value","step"}..],
///   "params":[{"key","value"}..]}`.
/// All string values are JSON-escaped.
fn writeSyncJson(writer: anytype, run_info: RunInfo, metrics: []const Metric, params: []const Param) !void {
    try writer.writeAll("{\"run_id\":");
    try writeJsonString(writer, run_info.run_id);
    try writer.writeAll(",\"experiment_id\":");
    try writeJsonString(writer, run_info.experiment_id);
    try writer.writeAll(",\"name\":");
    try writeJsonString(writer, run_info.name);
    try writer.writeAll(",\"status\":");
    try writeJsonString(writer, run_info.status);
    try writer.writeAll(",\"start_time\":");
    try writeJsonString(writer, run_info.start_time);
    try writer.writeAll(",\"end_time\":");
    if (run_info.end_time) |et| {
        try writeJsonString(writer, et);
    } else {
        try writer.writeAll("null");
    }

    // Add metrics
    try writer.writeAll(",\"metrics\":[");
    for (metrics, 0..) |m, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.writeAll("{\"key\":");
        try writeJsonString(writer, m.key);
        try writer.writeAll(",\"value\":");
        if (std.math.isFinite(m.value)) {
            try writer.print("{d}", .{m.value});
        } else {
            // JSON has no NaN/Infinity literals; `{d}` would emit `nan`/`inf`
            // and make the whole payload unparseable.
            try writer.writeAll("null");
        }
        try writer.print(",\"step\":{d}}}", .{m.step});
    }

    // Add params
    try writer.writeAll("],\"params\":[");
    for (params, 0..) |p, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.writeAll("{\"key\":");
        try writeJsonString(writer, p.key);
        try writer.writeAll(",\"value\":");
        try writeJsonString(writer, p.value);
        try writer.writeAll("}");
    }
    try writer.writeAll("]}");
}

/// Write `s` to `writer` as a quoted JSON string, escaping `"`, `\` and the
/// control characters U+0000..U+001F (RFC 8259 section 7). Bytes >= 0x80 are
/// copied through unchanged, so `s` is assumed to already be valid UTF-8.
fn writeJsonString(writer: anytype, s: []const u8) !void {
    try writer.writeByte('"');
    for (s) |c| {
        switch (c) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            0x00...0x08, 0x0B, 0x0C, 0x0E...0x1F => try writer.print("\\u{x:0>4}", .{c}),
            else => try writer.writeByte(c),
        }
    }
    try writer.writeByte('"');
}

/// One row of `ml_metrics`; `key` is allocator-owned.
const Metric = struct {
    key: []const u8,
    value: f64,
    step: i64,

    fn deinit(self: *Metric, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
    }
};

/// One row of `ml_params`; `key` and `value` are allocator-owned.
const Param = struct {
    key: []const u8,
    value: []const u8,

    fn deinit(self: *Param, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
        allocator.free(self.value);
    }
};

/// Print `ml sync` usage to stderr.
fn printUsage() void {
    std.debug.print("Usage: ml sync [run_id] [options]\n\n", .{});
    std.debug.print("Push local experiment runs to the server.\n\n", .{});
    std.debug.print("Options:\n", .{});
    std.debug.print(" --json Output structured JSON\n", .{});
    std.debug.print(" --help, -h Show this help message\n\n", .{});
    std.debug.print("Examples:\n", .{});
    std.debug.print(" ml sync # Sync all unsynced runs\n", .{});
    std.debug.print(" ml sync abc123 # Sync specific run\n", .{});
}

/// Find the git root directory by walking up from the given path.
///
/// Starting at the resolved absolute form of `start_path`, returns an
/// allocator-owned copy of the first directory containing a `.git` entry, or
/// `null` once the filesystem root is passed without finding one. Caller frees
/// the returned slice.
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    // `realpath` yields `[]u8`, but `dirname` yields `[]const u8`; annotate so
    // the parent can be assigned back into `current`.
    var current: []const u8 = try std.fs.realpath(start_path, &buf);

    while (true) {
        // Check if .git exists in current directory
        const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" });
        defer allocator.free(git_path);

        if (std.fs.accessAbsolute(git_path, .{})) {
            return try allocator.dupe(u8, current);
        } else |_| {
            // Any access failure (missing, permission denied, ...) is treated
            // as "not here"; keep walking up.
        }

        const parent = std.fs.path.dirname(current) orelse return null;
        if (std.mem.eql(u8, parent, current)) return null;
        current = parent;
    }
}