refactor(cli): reduce code duplication in WebSocket client

Add MessageBuilder struct and validation helpers to reduce duplication:

- MessageBuilder: Helper for constructing binary WebSocket messages
  - writeOpcode, writeBytes, writeU8/16/32/64
  - writeStringU8, writeStringU16 for length-prefixed strings

- Validation helpers: validateApiKeyHash, validateCommitId, validateJobName
- getStream(): Extracted common stream check pattern

Refactored 4 representative send methods to use new helpers:
- sendValidateRequestCommit, sendListJupyterPackages
- sendCancelJob, sendStatusRequest

Consolidated disconnect/close: close() now calls disconnect()

Updated response handlers to use packet.deinit():
- receiveAndHandleResponse
- receiveAndHandleDatasetResponse

Reduces ~100 lines of boilerplate duplication.
Build verified: zig build --release=fast
This commit is contained in:
Jeremie Fraeys 2026-02-18 13:56:30 -05:00
parent 14eba436bf
commit cd2908181c
No known key found for this signature in database

View file

@ -12,6 +12,74 @@ const response_handlers = @import("response_handlers.zig");
const opcode = @import("opcode.zig");
const utils = @import("utils.zig");
/// Helper for building WebSocket binary messages
const MessageBuilder = struct {
buffer: []u8,
offset: usize,
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator, total_len: usize) !MessageBuilder {
const buffer = try allocator.alloc(u8, total_len);
return MessageBuilder{
.buffer = buffer,
.offset = 0,
.allocator = allocator,
};
}
pub fn deinit(self: *MessageBuilder) void {
self.allocator.free(self.buffer);
}
pub fn writeOpcode(self: *MessageBuilder, op: opcode.Opcode) void {
self.buffer[self.offset] = @intFromEnum(op);
self.offset += 1;
}
pub fn writeBytes(self: *MessageBuilder, data: []const u8) void {
@memcpy(self.buffer[self.offset .. self.offset + data.len], data);
self.offset += data.len;
}
pub fn writeU8(self: *MessageBuilder, value: u8) void {
self.buffer[self.offset] = value;
self.offset += 1;
}
pub fn writeU16(self: *MessageBuilder, value: u16) void {
std.mem.writeInt(u16, self.buffer[self.offset .. self.offset + 2][0..2], value, .big);
self.offset += 2;
}
pub fn writeU32(self: *MessageBuilder, value: u32) void {
std.mem.writeInt(u32, self.buffer[self.offset .. self.offset + 4][0..4], value, .big);
self.offset += 4;
}
pub fn writeU64(self: *MessageBuilder, value: u64) void {
std.mem.writeInt(u64, self.buffer[self.offset .. self.offset + 8][0..8], value, .big);
self.offset += 8;
}
pub fn writeStringU8(self: *MessageBuilder, str: []const u8) void {
self.writeU8(@intCast(str.len));
if (str.len > 0) {
self.writeBytes(str);
}
}
pub fn writeStringU16(self: *MessageBuilder, str: []const u8) void {
self.writeU16(@intCast(str.len));
if (str.len > 0) {
self.writeBytes(str);
}
}
pub fn send(self: *MessageBuilder, stream: std.net.Stream) !void {
try frame.sendWebSocketFrame(stream, self.buffer);
}
};
/// WebSocket client for binary protocol communication
pub const Client = struct {
allocator: std.mem.Allocator,
@ -103,7 +171,7 @@ pub const Client = struct {
return last_error;
}
/// Disconnect from WebSocket server
/// Disconnect from WebSocket server (closes stream only)
pub fn disconnect(self: *Client) void {
if (self.stream) |stream| {
stream.close();
@ -111,62 +179,59 @@ pub const Client = struct {
}
}
/// Fully close client - disconnects stream and frees host memory
pub fn close(self: *Client) void {
if (self.stream) |stream| {
stream.close();
self.stream = null;
}
self.disconnect();
if (self.host.len > 0) {
self.allocator.free(self.host);
}
}
pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void {
const stream = self.stream orelse return error.NotConnected;
// Validation helpers
fn validateApiKeyHash(api_key_hash: []const u8) error{InvalidApiKeyHash}!void {
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
}
fn validateCommitId(commit_id: []const u8) error{InvalidCommitId}!void {
if (commit_id.len != 20) return error.InvalidCommitId;
}
const total_len = 1 + 16 + 1 + 1 + 20;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
fn validateJobName(job_name: []const u8) error{JobNameTooLong}!void {
if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong;
}
var offset: usize = 0;
buffer[offset] = @intFromEnum(opcode.Opcode.validate_request);
offset += 1;
@memcpy(buffer[offset .. offset + 16], api_key_hash);
offset += 16;
buffer[offset] = @intFromEnum(opcode.ValidateTargetType.commit_id);
offset += 1;
buffer[offset] = 20;
offset += 1;
@memcpy(buffer[offset .. offset + 20], commit_id);
try frame.sendWebSocketFrame(stream, buffer);
fn getStream(self: *Client) error{NotConnected}!std.net.Stream {
return self.stream orelse error.NotConnected;
}
pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void {
const stream = try self.getStream();
try validateApiKeyHash(api_key_hash);
try validateCommitId(commit_id);
var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + 1 + 20);
defer builder.deinit();
builder.writeOpcode(opcode.Opcode.validate_request);
builder.writeBytes(api_key_hash);
builder.writeU8(@intFromEnum(opcode.ValidateTargetType.commit_id));
builder.writeU8(20);
builder.writeBytes(commit_id);
try builder.send(stream);
}
pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
const stream = self.stream orelse return error.NotConnected;
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
const stream = try self.getStream();
try validateApiKeyHash(api_key_hash);
if (name.len > 255) return error.NameTooLong;
// Build binary message: [opcode:1][api_key_hash:16][name_len:1][name:var]
const total_len = 1 + 16 + 1 + name.len;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + name.len);
defer builder.deinit();
var offset: usize = 0;
buffer[offset] = @intFromEnum(opcode.list_jupyter_packages);
offset += 1;
@memcpy(buffer[offset .. offset + 16], api_key_hash);
offset += 16;
buffer[offset] = @intCast(name.len);
offset += 1;
@memcpy(buffer[offset .. offset + name.len], name);
try frame.sendWebSocketFrame(stream, buffer);
builder.writeOpcode(opcode.list_jupyter_packages);
builder.writeBytes(api_key_hash);
builder.writeStringU8(name);
try builder.send(stream);
}
pub fn sendSetRunNarrative(
@ -749,30 +814,17 @@ pub const Client = struct {
}
pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void {
const stream = self.stream orelse return error.NotConnected;
const stream = try self.getStream();
try validateApiKeyHash(api_key_hash);
try validateJobName(job_name);
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
if (job_name.len > 255) return error.JobNameTooLong;
var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + job_name.len);
defer builder.deinit();
// Build binary message:
// [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var]
const total_len = 1 + 16 + 1 + job_name.len;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
var offset: usize = 0;
buffer[offset] = @intFromEnum(opcode.cancel_job);
offset += 1;
@memcpy(buffer[offset .. offset + 16], api_key_hash);
offset += 16;
buffer[offset] = @intCast(job_name.len);
offset += 1;
@memcpy(buffer[offset..], job_name);
try frame.sendWebSocketFrame(stream, buffer);
builder.writeOpcode(opcode.cancel_job);
builder.writeBytes(api_key_hash);
builder.writeStringU8(job_name);
try builder.send(stream);
}
pub fn sendPrune(self: *Client, api_key_hash: []const u8, prune_type: u8, value: u32) !void {
@ -806,20 +858,15 @@ pub const Client = struct {
}
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
const stream = self.stream orelse return error.NotConnected;
const stream = try self.getStream();
try validateApiKeyHash(api_key_hash);
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
var builder = try MessageBuilder.init(self.allocator, 1 + 16);
defer builder.deinit();
// Build binary message:
// [opcode: u8] [api_key_hash: 16 bytes]
const total_len = 1 + 16;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
buffer[0] = @intFromEnum(opcode.status_request);
@memcpy(buffer[1..17], api_key_hash);
try frame.sendWebSocketFrame(stream, buffer);
builder.writeOpcode(opcode.status_request);
builder.writeBytes(api_key_hash);
try builder.send(stream);
}
pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 {
@ -838,16 +885,7 @@ pub const Client = struct {
std.debug.print("Server response: {s}\n", .{message});
return;
};
defer {
if (packet.success_message) |msg| allocator.free(msg);
if (packet.error_message) |msg| allocator.free(msg);
if (packet.error_details) |details| allocator.free(details);
if (packet.data_type) |dtype| allocator.free(dtype);
if (packet.data_payload) |payload| allocator.free(payload);
if (packet.progress_message) |pmsg| allocator.free(pmsg);
if (packet.status_data) |sdata| allocator.free(sdata);
if (packet.log_message) |lmsg| allocator.free(lmsg);
}
defer packet.deinit(allocator);
try response_handlers.handleResponsePacket(self, packet, operation);
}
@ -1315,16 +1353,7 @@ pub const Client = struct {
// Fallback: treat as plain response.
return allocator.dupe(u8, message);
};
defer {
if (packet.success_message) |msg| allocator.free(msg);
if (packet.error_message) |msg| allocator.free(msg);
if (packet.error_details) |details| allocator.free(details);
if (packet.data_type) |dtype| allocator.free(dtype);
if (packet.data_payload) |payload| allocator.free(payload);
if (packet.progress_message) |pmsg| allocator.free(pmsg);
if (packet.status_data) |sdata| allocator.free(sdata);
if (packet.log_message) |lmsg| allocator.free(lmsg);
}
defer packet.deinit(allocator);
switch (packet.packet_type) {
.data => {