refactor(cli): reduce code duplication in WebSocket client
Add MessageBuilder struct and validation helpers to reduce duplication: - MessageBuilder: Helper for constructing binary WebSocket messages - writeOpcode, writeBytes, writeU8/16/32/64 - writeStringU8, writeStringU16 for length-prefixed strings - Validation helpers: validateApiKeyHash, validateCommitId, validateJobName - getStream(): Extracted common stream check pattern Refactored 4 representative send methods to use new helpers: - sendValidateRequestCommit, sendListJupyterPackages - sendCancelJob, sendStatusRequest Consolidated disconnect/close: close() now calls disconnect() Updated response handlers to use packet.deinit(): - receiveAndHandleResponse - receiveAndHandleDatasetResponse Reduces ~100 lines of boilerplate duplication. Build verified: zig build --release=fast
This commit is contained in:
parent
14eba436bf
commit
cd2908181c
1 changed files with 124 additions and 95 deletions
|
|
@ -12,6 +12,74 @@ const response_handlers = @import("response_handlers.zig");
|
|||
const opcode = @import("opcode.zig");
|
||||
const utils = @import("utils.zig");
|
||||
|
||||
/// Helper for building WebSocket binary messages
|
||||
const MessageBuilder = struct {
|
||||
buffer: []u8,
|
||||
offset: usize,
|
||||
allocator: std.mem.Allocator,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, total_len: usize) !MessageBuilder {
|
||||
const buffer = try allocator.alloc(u8, total_len);
|
||||
return MessageBuilder{
|
||||
.buffer = buffer,
|
||||
.offset = 0,
|
||||
.allocator = allocator,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *MessageBuilder) void {
|
||||
self.allocator.free(self.buffer);
|
||||
}
|
||||
|
||||
pub fn writeOpcode(self: *MessageBuilder, op: opcode.Opcode) void {
|
||||
self.buffer[self.offset] = @intFromEnum(op);
|
||||
self.offset += 1;
|
||||
}
|
||||
|
||||
pub fn writeBytes(self: *MessageBuilder, data: []const u8) void {
|
||||
@memcpy(self.buffer[self.offset .. self.offset + data.len], data);
|
||||
self.offset += data.len;
|
||||
}
|
||||
|
||||
pub fn writeU8(self: *MessageBuilder, value: u8) void {
|
||||
self.buffer[self.offset] = value;
|
||||
self.offset += 1;
|
||||
}
|
||||
|
||||
pub fn writeU16(self: *MessageBuilder, value: u16) void {
|
||||
std.mem.writeInt(u16, self.buffer[self.offset .. self.offset + 2][0..2], value, .big);
|
||||
self.offset += 2;
|
||||
}
|
||||
|
||||
pub fn writeU32(self: *MessageBuilder, value: u32) void {
|
||||
std.mem.writeInt(u32, self.buffer[self.offset .. self.offset + 4][0..4], value, .big);
|
||||
self.offset += 4;
|
||||
}
|
||||
|
||||
pub fn writeU64(self: *MessageBuilder, value: u64) void {
|
||||
std.mem.writeInt(u64, self.buffer[self.offset .. self.offset + 8][0..8], value, .big);
|
||||
self.offset += 8;
|
||||
}
|
||||
|
||||
pub fn writeStringU8(self: *MessageBuilder, str: []const u8) void {
|
||||
self.writeU8(@intCast(str.len));
|
||||
if (str.len > 0) {
|
||||
self.writeBytes(str);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn writeStringU16(self: *MessageBuilder, str: []const u8) void {
|
||||
self.writeU16(@intCast(str.len));
|
||||
if (str.len > 0) {
|
||||
self.writeBytes(str);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(self: *MessageBuilder, stream: std.net.Stream) !void {
|
||||
try frame.sendWebSocketFrame(stream, self.buffer);
|
||||
}
|
||||
};
|
||||
|
||||
/// WebSocket client for binary protocol communication
|
||||
pub const Client = struct {
|
||||
allocator: std.mem.Allocator,
|
||||
|
|
@ -103,7 +171,7 @@ pub const Client = struct {
|
|||
return last_error;
|
||||
}
|
||||
|
||||
/// Disconnect from WebSocket server
|
||||
/// Disconnect from WebSocket server (closes stream only)
|
||||
pub fn disconnect(self: *Client) void {
|
||||
if (self.stream) |stream| {
|
||||
stream.close();
|
||||
|
|
@ -111,62 +179,59 @@ pub const Client = struct {
|
|||
}
|
||||
}
|
||||
|
||||
/// Fully close client - disconnects stream and frees host memory
|
||||
pub fn close(self: *Client) void {
|
||||
if (self.stream) |stream| {
|
||||
stream.close();
|
||||
self.stream = null;
|
||||
}
|
||||
self.disconnect();
|
||||
if (self.host.len > 0) {
|
||||
self.allocator.free(self.host);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
// Validation helpers
|
||||
fn validateApiKeyHash(api_key_hash: []const u8) error{InvalidApiKeyHash}!void {
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
}
|
||||
|
||||
fn validateCommitId(commit_id: []const u8) error{InvalidCommitId}!void {
|
||||
if (commit_id.len != 20) return error.InvalidCommitId;
|
||||
}
|
||||
|
||||
const total_len = 1 + 16 + 1 + 1 + 20;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
fn validateJobName(job_name: []const u8) error{JobNameTooLong}!void {
|
||||
if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong;
|
||||
}
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.Opcode.validate_request);
|
||||
offset += 1;
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
buffer[offset] = @intFromEnum(opcode.ValidateTargetType.commit_id);
|
||||
offset += 1;
|
||||
buffer[offset] = 20;
|
||||
offset += 1;
|
||||
@memcpy(buffer[offset .. offset + 20], commit_id);
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
fn getStream(self: *Client) error{NotConnected}!std.net.Stream {
|
||||
return self.stream orelse error.NotConnected;
|
||||
}
|
||||
|
||||
pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void {
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
try validateCommitId(commit_id);
|
||||
|
||||
var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + 1 + 20);
|
||||
defer builder.deinit();
|
||||
|
||||
builder.writeOpcode(opcode.Opcode.validate_request);
|
||||
builder.writeBytes(api_key_hash);
|
||||
builder.writeU8(@intFromEnum(opcode.ValidateTargetType.commit_id));
|
||||
builder.writeU8(20);
|
||||
builder.writeBytes(commit_id);
|
||||
try builder.send(stream);
|
||||
}
|
||||
|
||||
pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
if (name.len > 255) return error.NameTooLong;
|
||||
|
||||
// Build binary message: [opcode:1][api_key_hash:16][name_len:1][name:var]
|
||||
const total_len = 1 + 16 + 1 + name.len;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + name.len);
|
||||
defer builder.deinit();
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.list_jupyter_packages);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
buffer[offset] = @intCast(name.len);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + name.len], name);
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
builder.writeOpcode(opcode.list_jupyter_packages);
|
||||
builder.writeBytes(api_key_hash);
|
||||
builder.writeStringU8(name);
|
||||
try builder.send(stream);
|
||||
}
|
||||
|
||||
pub fn sendSetRunNarrative(
|
||||
|
|
@ -749,30 +814,17 @@ pub const Client = struct {
|
|||
}
|
||||
|
||||
pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
try validateJobName(job_name);
|
||||
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (job_name.len > 255) return error.JobNameTooLong;
|
||||
var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + job_name.len);
|
||||
defer builder.deinit();
|
||||
|
||||
// Build binary message:
|
||||
// [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var]
|
||||
const total_len = 1 + 16 + 1 + job_name.len;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.cancel_job);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
buffer[offset] = @intCast(job_name.len);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset..], job_name);
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
builder.writeOpcode(opcode.cancel_job);
|
||||
builder.writeBytes(api_key_hash);
|
||||
builder.writeStringU8(job_name);
|
||||
try builder.send(stream);
|
||||
}
|
||||
|
||||
pub fn sendPrune(self: *Client, api_key_hash: []const u8, prune_type: u8, value: u32) !void {
|
||||
|
|
@ -806,20 +858,15 @@ pub const Client = struct {
|
|||
}
|
||||
|
||||
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
var builder = try MessageBuilder.init(self.allocator, 1 + 16);
|
||||
defer builder.deinit();
|
||||
|
||||
// Build binary message:
|
||||
// [opcode: u8] [api_key_hash: 16 bytes]
|
||||
const total_len = 1 + 16;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
|
||||
buffer[0] = @intFromEnum(opcode.status_request);
|
||||
@memcpy(buffer[1..17], api_key_hash);
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
builder.writeOpcode(opcode.status_request);
|
||||
builder.writeBytes(api_key_hash);
|
||||
try builder.send(stream);
|
||||
}
|
||||
|
||||
pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 {
|
||||
|
|
@ -838,16 +885,7 @@ pub const Client = struct {
|
|||
std.debug.print("Server response: {s}\n", .{message});
|
||||
return;
|
||||
};
|
||||
defer {
|
||||
if (packet.success_message) |msg| allocator.free(msg);
|
||||
if (packet.error_message) |msg| allocator.free(msg);
|
||||
if (packet.error_details) |details| allocator.free(details);
|
||||
if (packet.data_type) |dtype| allocator.free(dtype);
|
||||
if (packet.data_payload) |payload| allocator.free(payload);
|
||||
if (packet.progress_message) |pmsg| allocator.free(pmsg);
|
||||
if (packet.status_data) |sdata| allocator.free(sdata);
|
||||
if (packet.log_message) |lmsg| allocator.free(lmsg);
|
||||
}
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
try response_handlers.handleResponsePacket(self, packet, operation);
|
||||
}
|
||||
|
|
@ -1315,16 +1353,7 @@ pub const Client = struct {
|
|||
// Fallback: treat as plain response.
|
||||
return allocator.dupe(u8, message);
|
||||
};
|
||||
defer {
|
||||
if (packet.success_message) |msg| allocator.free(msg);
|
||||
if (packet.error_message) |msg| allocator.free(msg);
|
||||
if (packet.error_details) |details| allocator.free(details);
|
||||
if (packet.data_type) |dtype| allocator.free(dtype);
|
||||
if (packet.data_payload) |payload| allocator.free(payload);
|
||||
if (packet.progress_message) |pmsg| allocator.free(pmsg);
|
||||
if (packet.status_data) |sdata| allocator.free(sdata);
|
||||
if (packet.log_message) |lmsg| allocator.free(lmsg);
|
||||
}
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
switch (packet.packet_type) {
|
||||
.data => {
|
||||
|
|
|
|||
Loading…
Reference in a new issue