From 382c67edfc7c85ce823efc84d64f686d6dde1221 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Sat, 21 Feb 2026 17:59:14 -0500 Subject: [PATCH] fix(cli): WebSocket protocol and sync command fixes - Add sendSyncRun method for run synchronization - Add sendRerunRequest method for queue rerun - Add sync_run (0x26) and rerun_request (0x27) opcodes - Fix protocol import path to relative path - Fix db.Stmt type alias usage in sync.zig --- cli/src/commands/sync.zig | 39 +++++++++++++------------- cli/src/net/ws/client.zig | 58 ++++++++++++++++++++++++++++++++++++++- cli/src/net/ws/opcode.zig | 8 ++++++ 3 files changed, 85 insertions(+), 20 deletions(-) diff --git a/cli/src/commands/sync.zig b/cli/src/commands/sync.zig index 6ef9737..79119af 100644 --- a/cli/src/commands/sync.zig +++ b/cli/src/commands/sync.zig @@ -42,10 +42,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var database = try db.DB.init(allocator, db_path); defer database.close(); - var runs_to_sync = std.ArrayList(RunInfo).init(allocator); + var runs_to_sync: std.ArrayList(RunInfo) = .empty; defer { for (runs_to_sync.items) |*r| r.deinit(allocator); - runs_to_sync.deinit(); + runs_to_sync.deinit(allocator); } if (specific_run_id) |run_id| { @@ -54,7 +54,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { defer db.DB.finalize(stmt); try db.DB.bindText(stmt, 1, run_id); if (try db.DB.step(stmt)) { - try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator)); + try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator)); } else { colors.printWarning("Run {s} already synced or not found\n", .{run_id}); return; @@ -64,7 +64,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { const stmt = try database.prepare(sql); defer db.DB.finalize(stmt); while (try db.DB.step(stmt)) { - try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator)); + try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator)); } } @@ -114,14 +114,15 @@ const RunInfo = struct { start_time: []const u8, end_time: ?[]const u8, - fn fromStmt(stmt: *anyopaque, allocator: std.mem.Allocator) !RunInfo { + fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo { + const s = stmt.?; return RunInfo{ - .run_id = try allocator.dupe(u8, db.DB.columnText(stmt, 0)), - .experiment_id = try allocator.dupe(u8, db.DB.columnText(stmt, 1)), - .name = try allocator.dupe(u8, db.DB.columnText(stmt, 2)), - .status = try allocator.dupe(u8, db.DB.columnText(stmt, 3)), - .start_time = try allocator.dupe(u8, db.DB.columnText(stmt, 4)), - .end_time = if (db.DB.columnText(stmt, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(stmt, 5)) else null, + .run_id = try allocator.dupe(u8, db.DB.columnText(s, 0)), + .experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1)), + .name = try allocator.dupe(u8, db.DB.columnText(s, 2)), + .status = try allocator.dupe(u8, db.DB.columnText(s, 3)), + .start_time = try allocator.dupe(u8, db.DB.columnText(s, 4)), + .end_time = if (db.DB.columnText(s, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(s, 5)) else null, }; } @@ -143,10 +144,10 @@ fn syncRun( api_key_hash: []const u8, ) !void { // Get metrics for this run - var metrics = std.ArrayList(Metric).init(allocator); + var metrics: std.ArrayList(Metric) = .empty; defer { for (metrics.items) |*m| m.deinit(allocator); - metrics.deinit(); + metrics.deinit(allocator); } const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;"; @@ -155,7 +156,7 @@ fn syncRun( try db.DB.bindText(metrics_stmt, 1, run_info.run_id); while (try db.DB.step(metrics_stmt)) { - try metrics.append(Metric{ + try metrics.append(allocator, Metric{ .key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)), .value = db.DB.columnDouble(metrics_stmt, 1), .step = db.DB.columnInt64(metrics_stmt, 2), @@ -163,10 +164,10 @@ fn syncRun( } // Get params for this run - var params = std.ArrayList(Param).init(allocator); + var params: std.ArrayList(Param) = .empty; defer { for (params.items) |*p| p.deinit(allocator); - params.deinit(); + params.deinit(allocator); } const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;"; @@ -175,15 +176,15 @@ fn syncRun( try db.DB.bindText(params_stmt, 1, run_info.run_id); while (try db.DB.step(params_stmt)) { - try params.append(Param{ + try params.append(allocator, Param{ .key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)), .value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)), }); } // Build sync JSON - var sync_json = std.ArrayList(u8).init(allocator); - defer sync_json.deinit(); + var sync_json: std.ArrayList(u8) = .empty; + defer sync_json.deinit(allocator); const writer = sync_json.writer(allocator); try writer.writeAll("{"); diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig index 56a1807..37a198e 100644 --- a/cli/src/net/ws/client.zig +++ b/cli/src/net/ws/client.zig @@ -2,7 +2,7 @@ const std = @import("std"); const crypto = @import("crypto"); const io = @import("io"); const log = @import("log"); -const protocol = @import("protocol"); +const protocol = @import("../protocol.zig"); const resolve = @import("resolve.zig"); const handshake = @import("handshake.zig"); const frame = @import("frame.zig"); @@ -924,6 +924,62 @@ pub const Client = struct { try frame.sendWebSocketFrame(stream, buffer); } + pub fn sendSyncRun(self: *Client, sync_json: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (sync_json.len > 0xFFFF) return error.PayloadTooLarge; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [json_len: u16] [json: var] + const total_len = 1 + 16 + 2 + sync_json.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.sync_run); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(sync_json.len), .big); + offset += 2; + + if (sync_json.len > 0) { + @memcpy(buffer[offset .. offset + sync_json.len], sync_json); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendRerunRequest(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (run_id.len > 255) return error.PayloadTooLarge; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [run_id_len: u8] [run_id: var] + const total_len = 1 + 16 + 1 + run_id.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.rerun_request); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(run_id.len); + offset += 1; + + @memcpy(buffer[offset .. offset + run_id.len], run_id); + + try frame.sendWebSocketFrame(stream, buffer); + } + pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { const stream = try self.getStream(); try validateApiKeyHash(api_key_hash); diff --git a/cli/src/net/ws/opcode.zig b/cli/src/net/ws/opcode.zig index 06c917b..328ac4a 100644 --- a/cli/src/net/ws/opcode.zig +++ b/cli/src/net/ws/opcode.zig @@ -38,6 +38,12 @@ pub const Opcode = enum(u8) { dataset_info = 0x08, dataset_search = 0x09, + // Sync opcode + sync_run = 0x26, + + // Rerun opcode + rerun_request = 0x27, + // Structured response opcodes response_success = 0x10, response_error = 0x11, @@ -83,6 +89,8 @@ pub const dataset_list = Opcode.dataset_list; pub const dataset_register = Opcode.dataset_register; pub const dataset_info = Opcode.dataset_info; pub const dataset_search = Opcode.dataset_search; +pub const sync_run = Opcode.sync_run; +pub const rerun_request = Opcode.rerun_request; pub const response_success = Opcode.response_success; pub const response_error = Opcode.response_error; pub const response_progress = Opcode.response_progress;