fix(cli): WebSocket protocol and sync command fixes
- Add sendSyncRun method for run synchronization - Add sendRerunRequest method for queue rerun - Add sync_run (0x26) and rerun_request (0x27) opcodes - Fix protocol import path to relative path - Fix db.Stmt type alias usage in sync.zig
This commit is contained in:
parent
ccd1dd7a4d
commit
382c67edfc
3 changed files with 85 additions and 20 deletions
|
|
@ -42,10 +42,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
var runs_to_sync = std.ArrayList(RunInfo).init(allocator);
|
||||
var runs_to_sync: std.ArrayList(RunInfo) = .empty;
|
||||
defer {
|
||||
for (runs_to_sync.items) |*r| r.deinit(allocator);
|
||||
runs_to_sync.deinit();
|
||||
runs_to_sync.deinit(allocator);
|
||||
}
|
||||
|
||||
if (specific_run_id) |run_id| {
|
||||
|
|
@ -54,7 +54,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
defer db.DB.finalize(stmt);
|
||||
try db.DB.bindText(stmt, 1, run_id);
|
||||
if (try db.DB.step(stmt)) {
|
||||
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator));
|
||||
try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator));
|
||||
} else {
|
||||
colors.printWarning("Run {s} already synced or not found\n", .{run_id});
|
||||
return;
|
||||
|
|
@ -64,7 +64,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
while (try db.DB.step(stmt)) {
|
||||
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator));
|
||||
try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -114,14 +114,15 @@ const RunInfo = struct {
|
|||
start_time: []const u8,
|
||||
end_time: ?[]const u8,
|
||||
|
||||
fn fromStmt(stmt: *anyopaque, allocator: std.mem.Allocator) !RunInfo {
|
||||
fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo {
|
||||
const s = stmt.?;
|
||||
return RunInfo{
|
||||
.run_id = try allocator.dupe(u8, db.DB.columnText(stmt, 0)),
|
||||
.experiment_id = try allocator.dupe(u8, db.DB.columnText(stmt, 1)),
|
||||
.name = try allocator.dupe(u8, db.DB.columnText(stmt, 2)),
|
||||
.status = try allocator.dupe(u8, db.DB.columnText(stmt, 3)),
|
||||
.start_time = try allocator.dupe(u8, db.DB.columnText(stmt, 4)),
|
||||
.end_time = if (db.DB.columnText(stmt, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(stmt, 5)) else null,
|
||||
.run_id = try allocator.dupe(u8, db.DB.columnText(s, 0)),
|
||||
.experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1)),
|
||||
.name = try allocator.dupe(u8, db.DB.columnText(s, 2)),
|
||||
.status = try allocator.dupe(u8, db.DB.columnText(s, 3)),
|
||||
.start_time = try allocator.dupe(u8, db.DB.columnText(s, 4)),
|
||||
.end_time = if (db.DB.columnText(s, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(s, 5)) else null,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -143,10 +144,10 @@ fn syncRun(
|
|||
api_key_hash: []const u8,
|
||||
) !void {
|
||||
// Get metrics for this run
|
||||
var metrics = std.ArrayList(Metric).init(allocator);
|
||||
var metrics: std.ArrayList(Metric) = .empty;
|
||||
defer {
|
||||
for (metrics.items) |*m| m.deinit(allocator);
|
||||
metrics.deinit();
|
||||
metrics.deinit(allocator);
|
||||
}
|
||||
|
||||
const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
|
||||
|
|
@ -155,7 +156,7 @@ fn syncRun(
|
|||
try db.DB.bindText(metrics_stmt, 1, run_info.run_id);
|
||||
|
||||
while (try db.DB.step(metrics_stmt)) {
|
||||
try metrics.append(Metric{
|
||||
try metrics.append(allocator, Metric{
|
||||
.key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)),
|
||||
.value = db.DB.columnDouble(metrics_stmt, 1),
|
||||
.step = db.DB.columnInt64(metrics_stmt, 2),
|
||||
|
|
@ -163,10 +164,10 @@ fn syncRun(
|
|||
}
|
||||
|
||||
// Get params for this run
|
||||
var params = std.ArrayList(Param).init(allocator);
|
||||
var params: std.ArrayList(Param) = .empty;
|
||||
defer {
|
||||
for (params.items) |*p| p.deinit(allocator);
|
||||
params.deinit();
|
||||
params.deinit(allocator);
|
||||
}
|
||||
|
||||
const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
|
||||
|
|
@ -175,15 +176,15 @@ fn syncRun(
|
|||
try db.DB.bindText(params_stmt, 1, run_info.run_id);
|
||||
|
||||
while (try db.DB.step(params_stmt)) {
|
||||
try params.append(Param{
|
||||
try params.append(allocator, Param{
|
||||
.key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)),
|
||||
.value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)),
|
||||
});
|
||||
}
|
||||
|
||||
// Build sync JSON
|
||||
var sync_json = std.ArrayList(u8).init(allocator);
|
||||
defer sync_json.deinit();
|
||||
var sync_json: std.ArrayList(u8) = .empty;
|
||||
defer sync_json.deinit(allocator);
|
||||
const writer = sync_json.writer(allocator);
|
||||
|
||||
try writer.writeAll("{");
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ const std = @import("std");
|
|||
const crypto = @import("crypto");
|
||||
const io = @import("io");
|
||||
const log = @import("log");
|
||||
const protocol = @import("protocol");
|
||||
const protocol = @import("../protocol.zig");
|
||||
const resolve = @import("resolve.zig");
|
||||
const handshake = @import("handshake.zig");
|
||||
const frame = @import("frame.zig");
|
||||
|
|
@ -924,6 +924,62 @@ pub const Client = struct {
|
|||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
pub fn sendSyncRun(self: *Client, sync_json: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (sync_json.len > 0xFFFF) return error.PayloadTooLarge;
|
||||
|
||||
// Build binary message:
|
||||
// [opcode: u8] [api_key_hash: 16 bytes] [json_len: u16] [json: var]
|
||||
const total_len = 1 + 16 + 2 + sync_json.len;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.sync_run);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(sync_json.len), .big);
|
||||
offset += 2;
|
||||
|
||||
if (sync_json.len > 0) {
|
||||
@memcpy(buffer[offset .. offset + sync_json.len], sync_json);
|
||||
}
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
pub fn sendRerunRequest(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (run_id.len > 255) return error.PayloadTooLarge;
|
||||
|
||||
// Build binary message:
|
||||
// [opcode: u8] [api_key_hash: 16 bytes] [run_id_len: u8] [run_id: var]
|
||||
const total_len = 1 + 16 + 1 + run_id.len;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.rerun_request);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
buffer[offset] = @intCast(run_id.len);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + run_id.len], run_id);
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
|
|
|||
|
|
@ -38,6 +38,12 @@ pub const Opcode = enum(u8) {
|
|||
dataset_info = 0x08,
|
||||
dataset_search = 0x09,
|
||||
|
||||
// Sync opcode
|
||||
sync_run = 0x26,
|
||||
|
||||
// Rerun opcode
|
||||
rerun_request = 0x27,
|
||||
|
||||
// Structured response opcodes
|
||||
response_success = 0x10,
|
||||
response_error = 0x11,
|
||||
|
|
@ -83,6 +89,8 @@ pub const dataset_list = Opcode.dataset_list;
|
|||
pub const dataset_register = Opcode.dataset_register;
|
||||
pub const dataset_info = Opcode.dataset_info;
|
||||
pub const dataset_search = Opcode.dataset_search;
|
||||
pub const sync_run = Opcode.sync_run;
|
||||
pub const rerun_request = Opcode.rerun_request;
|
||||
pub const response_success = Opcode.response_success;
|
||||
pub const response_error = Opcode.response_error;
|
||||
pub const response_progress = Opcode.response_progress;
|
||||
|
|
|
|||
Loading…
Reference in a new issue