fix(cli): WebSocket protocol and sync command fixes

- Add sendSyncRun method for run synchronization
- Add sendRerunRequest method for queue rerun
- Add sync_run (0x26) and rerun_request (0x27) opcodes
- Fix protocol import path to relative path
- Fix db.Stmt type alias usage in sync.zig
This commit is contained in:
Jeremie Fraeys 2026-02-21 17:59:14 -05:00
parent ccd1dd7a4d
commit 382c67edfc
No known key found for this signature in database
3 changed files with 85 additions and 20 deletions

View file

@ -42,10 +42,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var database = try db.DB.init(allocator, db_path); var database = try db.DB.init(allocator, db_path);
defer database.close(); defer database.close();
var runs_to_sync = std.ArrayList(RunInfo).init(allocator); var runs_to_sync: std.ArrayList(RunInfo) = .empty;
defer { defer {
for (runs_to_sync.items) |*r| r.deinit(allocator); for (runs_to_sync.items) |*r| r.deinit(allocator);
runs_to_sync.deinit(); runs_to_sync.deinit(allocator);
} }
if (specific_run_id) |run_id| { if (specific_run_id) |run_id| {
@ -54,7 +54,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
defer db.DB.finalize(stmt); defer db.DB.finalize(stmt);
try db.DB.bindText(stmt, 1, run_id); try db.DB.bindText(stmt, 1, run_id);
if (try db.DB.step(stmt)) { if (try db.DB.step(stmt)) {
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator)); try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator));
} else { } else {
colors.printWarning("Run {s} already synced or not found\n", .{run_id}); colors.printWarning("Run {s} already synced or not found\n", .{run_id});
return; return;
@ -64,7 +64,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
const stmt = try database.prepare(sql); const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt); defer db.DB.finalize(stmt);
while (try db.DB.step(stmt)) { while (try db.DB.step(stmt)) {
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator)); try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator));
} }
} }
@ -114,14 +114,15 @@ const RunInfo = struct {
start_time: []const u8, start_time: []const u8,
end_time: ?[]const u8, end_time: ?[]const u8,
fn fromStmt(stmt: *anyopaque, allocator: std.mem.Allocator) !RunInfo { fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo {
const s = stmt.?;
return RunInfo{ return RunInfo{
.run_id = try allocator.dupe(u8, db.DB.columnText(stmt, 0)), .run_id = try allocator.dupe(u8, db.DB.columnText(s, 0)),
.experiment_id = try allocator.dupe(u8, db.DB.columnText(stmt, 1)), .experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1)),
.name = try allocator.dupe(u8, db.DB.columnText(stmt, 2)), .name = try allocator.dupe(u8, db.DB.columnText(s, 2)),
.status = try allocator.dupe(u8, db.DB.columnText(stmt, 3)), .status = try allocator.dupe(u8, db.DB.columnText(s, 3)),
.start_time = try allocator.dupe(u8, db.DB.columnText(stmt, 4)), .start_time = try allocator.dupe(u8, db.DB.columnText(s, 4)),
.end_time = if (db.DB.columnText(stmt, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(stmt, 5)) else null, .end_time = if (db.DB.columnText(s, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(s, 5)) else null,
}; };
} }
@ -143,10 +144,10 @@ fn syncRun(
api_key_hash: []const u8, api_key_hash: []const u8,
) !void { ) !void {
// Get metrics for this run // Get metrics for this run
var metrics = std.ArrayList(Metric).init(allocator); var metrics: std.ArrayList(Metric) = .empty;
defer { defer {
for (metrics.items) |*m| m.deinit(allocator); for (metrics.items) |*m| m.deinit(allocator);
metrics.deinit(); metrics.deinit(allocator);
} }
const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;"; const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
@ -155,7 +156,7 @@ fn syncRun(
try db.DB.bindText(metrics_stmt, 1, run_info.run_id); try db.DB.bindText(metrics_stmt, 1, run_info.run_id);
while (try db.DB.step(metrics_stmt)) { while (try db.DB.step(metrics_stmt)) {
try metrics.append(Metric{ try metrics.append(allocator, Metric{
.key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)), .key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)),
.value = db.DB.columnDouble(metrics_stmt, 1), .value = db.DB.columnDouble(metrics_stmt, 1),
.step = db.DB.columnInt64(metrics_stmt, 2), .step = db.DB.columnInt64(metrics_stmt, 2),
@ -163,10 +164,10 @@ fn syncRun(
} }
// Get params for this run // Get params for this run
var params = std.ArrayList(Param).init(allocator); var params: std.ArrayList(Param) = .empty;
defer { defer {
for (params.items) |*p| p.deinit(allocator); for (params.items) |*p| p.deinit(allocator);
params.deinit(); params.deinit(allocator);
} }
const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;"; const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
@ -175,15 +176,15 @@ fn syncRun(
try db.DB.bindText(params_stmt, 1, run_info.run_id); try db.DB.bindText(params_stmt, 1, run_info.run_id);
while (try db.DB.step(params_stmt)) { while (try db.DB.step(params_stmt)) {
try params.append(Param{ try params.append(allocator, Param{
.key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)), .key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)),
.value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)), .value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)),
}); });
} }
// Build sync JSON // Build sync JSON
var sync_json = std.ArrayList(u8).init(allocator); var sync_json: std.ArrayList(u8) = .empty;
defer sync_json.deinit(); defer sync_json.deinit(allocator);
const writer = sync_json.writer(allocator); const writer = sync_json.writer(allocator);
try writer.writeAll("{"); try writer.writeAll("{");

View file

@ -2,7 +2,7 @@ const std = @import("std");
const crypto = @import("crypto"); const crypto = @import("crypto");
const io = @import("io"); const io = @import("io");
const log = @import("log"); const log = @import("log");
const protocol = @import("protocol"); const protocol = @import("../protocol.zig");
const resolve = @import("resolve.zig"); const resolve = @import("resolve.zig");
const handshake = @import("handshake.zig"); const handshake = @import("handshake.zig");
const frame = @import("frame.zig"); const frame = @import("frame.zig");
@ -924,6 +924,62 @@ pub const Client = struct {
try frame.sendWebSocketFrame(stream, buffer); try frame.sendWebSocketFrame(stream, buffer);
} }
pub fn sendSyncRun(self: *Client, sync_json: []const u8, api_key_hash: []const u8) !void {
const stream = self.stream orelse return error.NotConnected;
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
if (sync_json.len > 0xFFFF) return error.PayloadTooLarge;
// Build binary message:
// [opcode: u8] [api_key_hash: 16 bytes] [json_len: u16] [json: var]
const total_len = 1 + 16 + 2 + sync_json.len;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
var offset: usize = 0;
buffer[offset] = @intFromEnum(opcode.sync_run);
offset += 1;
@memcpy(buffer[offset .. offset + 16], api_key_hash);
offset += 16;
std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(sync_json.len), .big);
offset += 2;
if (sync_json.len > 0) {
@memcpy(buffer[offset .. offset + sync_json.len], sync_json);
}
try frame.sendWebSocketFrame(stream, buffer);
}
pub fn sendRerunRequest(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void {
const stream = self.stream orelse return error.NotConnected;
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
if (run_id.len > 255) return error.PayloadTooLarge;
// Build binary message:
// [opcode: u8] [api_key_hash: 16 bytes] [run_id_len: u8] [run_id: var]
const total_len = 1 + 16 + 1 + run_id.len;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
var offset: usize = 0;
buffer[offset] = @intFromEnum(opcode.rerun_request);
offset += 1;
@memcpy(buffer[offset .. offset + 16], api_key_hash);
offset += 16;
buffer[offset] = @intCast(run_id.len);
offset += 1;
@memcpy(buffer[offset .. offset + run_id.len], run_id);
try frame.sendWebSocketFrame(stream, buffer);
}
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
const stream = try self.getStream(); const stream = try self.getStream();
try validateApiKeyHash(api_key_hash); try validateApiKeyHash(api_key_hash);

View file

@ -38,6 +38,12 @@ pub const Opcode = enum(u8) {
dataset_info = 0x08, dataset_info = 0x08,
dataset_search = 0x09, dataset_search = 0x09,
// Sync opcode
sync_run = 0x26,
// Rerun opcode
rerun_request = 0x27,
// Structured response opcodes // Structured response opcodes
response_success = 0x10, response_success = 0x10,
response_error = 0x11, response_error = 0x11,
@ -83,6 +89,8 @@ pub const dataset_list = Opcode.dataset_list;
pub const dataset_register = Opcode.dataset_register; pub const dataset_register = Opcode.dataset_register;
pub const dataset_info = Opcode.dataset_info; pub const dataset_info = Opcode.dataset_info;
pub const dataset_search = Opcode.dataset_search; pub const dataset_search = Opcode.dataset_search;
pub const sync_run = Opcode.sync_run;
pub const rerun_request = Opcode.rerun_request;
pub const response_success = Opcode.response_success; pub const response_success = Opcode.response_success;
pub const response_error = Opcode.response_error; pub const response_error = Opcode.response_error;
pub const response_progress = Opcode.response_progress; pub const response_progress = Opcode.response_progress;