From cd2908181ce5b6e852bf349440b46a4c75422fb2 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Wed, 18 Feb 2026 13:56:30 -0500 Subject: [PATCH] refactor(cli): reduce code duplication in WebSocket client Add MessageBuilder struct and validation helpers to reduce duplication: - MessageBuilder: Helper for constructing binary WebSocket messages - writeOpcode, writeBytes, writeU8/16/32/64 - writeStringU8, writeStringU16 for length-prefixed strings - Validation helpers: validateApiKeyHash, validateCommitId, validateJobName - getStream(): Extracted common stream check pattern Refactored 4 representative send methods to use new helpers: - sendValidateRequestCommit, sendListJupyterPackages - sendCancelJob, sendStatusRequest Consolidated disconnect/close: close() now calls disconnect() Updated response handlers to use packet.deinit(): - receiveAndHandleResponse - receiveAndHandleDatasetResponse Reduces ~100 lines of boilerplate duplication. Build verified: zig build --release=fast --- cli/src/net/ws/client.zig | 219 +++++++++++++++++++++----------------- 1 file changed, 124 insertions(+), 95 deletions(-) diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig index 4269eaf..38acac4 100644 --- a/cli/src/net/ws/client.zig +++ b/cli/src/net/ws/client.zig @@ -12,6 +12,74 @@ const response_handlers = @import("response_handlers.zig"); const opcode = @import("opcode.zig"); const utils = @import("utils.zig"); +/// Helper for building WebSocket binary messages +const MessageBuilder = struct { + buffer: []u8, + offset: usize, + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator, total_len: usize) !MessageBuilder { + const buffer = try allocator.alloc(u8, total_len); + return MessageBuilder{ + .buffer = buffer, + .offset = 0, + .allocator = allocator, + }; + } + + pub fn deinit(self: *MessageBuilder) void { + self.allocator.free(self.buffer); + } + + pub fn writeOpcode(self: *MessageBuilder, op: opcode.Opcode) void { + self.buffer[self.offset] = @intFromEnum(op); + self.offset += 1; + } + + pub fn writeBytes(self: *MessageBuilder, data: []const u8) void { + @memcpy(self.buffer[self.offset .. self.offset + data.len], data); + self.offset += data.len; + } + + pub fn writeU8(self: *MessageBuilder, value: u8) void { + self.buffer[self.offset] = value; + self.offset += 1; + } + + pub fn writeU16(self: *MessageBuilder, value: u16) void { + std.mem.writeInt(u16, self.buffer[self.offset .. self.offset + 2][0..2], value, .big); + self.offset += 2; + } + + pub fn writeU32(self: *MessageBuilder, value: u32) void { + std.mem.writeInt(u32, self.buffer[self.offset .. self.offset + 4][0..4], value, .big); + self.offset += 4; + } + + pub fn writeU64(self: *MessageBuilder, value: u64) void { + std.mem.writeInt(u64, self.buffer[self.offset .. self.offset + 8][0..8], value, .big); + self.offset += 8; + } + + pub fn writeStringU8(self: *MessageBuilder, str: []const u8) void { + self.writeU8(@intCast(str.len)); + if (str.len > 0) { + self.writeBytes(str); + } + } + + pub fn writeStringU16(self: *MessageBuilder, str: []const u8) void { + self.writeU16(@intCast(str.len)); + if (str.len > 0) { + self.writeBytes(str); + } + } + + pub fn send(self: *MessageBuilder, stream: std.net.Stream) !void { + try frame.sendWebSocketFrame(stream, self.buffer); + } +}; + /// WebSocket client for binary protocol communication pub const Client = struct { allocator: std.mem.Allocator, @@ -103,7 +171,7 @@ pub const Client = struct { return last_error; } - /// Disconnect from WebSocket server + /// Disconnect from WebSocket server (closes stream only) pub fn disconnect(self: *Client) void { if (self.stream) |stream| { stream.close(); @@ -111,62 +179,59 @@ pub const Client = struct { } } + /// Fully close client - disconnects stream and frees host memory pub fn close(self: *Client) void { - if (self.stream) |stream| { - stream.close(); - self.stream = null; - } + self.disconnect(); if (self.host.len > 0) { self.allocator.free(self.host); } } - pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; + // Validation helpers + fn validateApiKeyHash(api_key_hash: []const u8) error{InvalidApiKeyHash}!void { if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + } + + fn validateCommitId(commit_id: []const u8) error{InvalidCommitId}!void { if (commit_id.len != 20) return error.InvalidCommitId; + } - const total_len = 1 + 16 + 1 + 1 + 20; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); + fn validateJobName(job_name: []const u8) error{JobNameTooLong}!void { + if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong; + } - var offset: usize = 0; - buffer[offset] = @intFromEnum(opcode.Opcode.validate_request); - offset += 1; - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - buffer[offset] = @intFromEnum(opcode.ValidateTargetType.commit_id); - offset += 1; - buffer[offset] = 20; - offset += 1; - @memcpy(buffer[offset .. offset + 20], commit_id); - try frame.sendWebSocketFrame(stream, buffer); + fn getStream(self: *Client) error{NotConnected}!std.net.Stream { + return self.stream orelse error.NotConnected; + } + + pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { + const stream = try self.getStream(); + try validateApiKeyHash(api_key_hash); + try validateCommitId(commit_id); + + var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + 1 + 20); + defer builder.deinit(); + + builder.writeOpcode(opcode.Opcode.validate_request); + builder.writeBytes(api_key_hash); + builder.writeU8(@intFromEnum(opcode.ValidateTargetType.commit_id)); + builder.writeU8(20); + builder.writeBytes(commit_id); + try builder.send(stream); } pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + const stream = try self.getStream(); + try validateApiKeyHash(api_key_hash); if (name.len > 255) return error.NameTooLong; - // Build binary message: [opcode:1][api_key_hash:16][name_len:1][name:var] - const total_len = 1 + 16 + 1 + name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); + var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + name.len); + defer builder.deinit(); - var offset: usize = 0; - buffer[offset] = @intFromEnum(opcode.list_jupyter_packages); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(name.len); - offset += 1; - - @memcpy(buffer[offset .. offset + name.len], name); - - try frame.sendWebSocketFrame(stream, buffer); + builder.writeOpcode(opcode.list_jupyter_packages); + builder.writeBytes(api_key_hash); + builder.writeStringU8(name); + try builder.send(stream); } pub fn sendSetRunNarrative( @@ -749,30 +814,17 @@ pub const Client = struct { } pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; + const stream = try self.getStream(); + try validateApiKeyHash(api_key_hash); + try validateJobName(job_name); - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (job_name.len > 255) return error.JobNameTooLong; + var builder = try MessageBuilder.init(self.allocator, 1 + 16 + 1 + job_name.len); + defer builder.deinit(); - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] - const total_len = 1 + 16 + 1 + job_name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(opcode.cancel_job); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - - @memcpy(buffer[offset..], job_name); - - try frame.sendWebSocketFrame(stream, buffer); + builder.writeOpcode(opcode.cancel_job); + builder.writeBytes(api_key_hash); + builder.writeStringU8(job_name); + try builder.send(stream); } pub fn sendPrune(self: *Client, api_key_hash: []const u8, prune_type: u8, value: u32) !void { @@ -806,20 +858,15 @@ pub const Client = struct { } pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; + const stream = try self.getStream(); + try validateApiKeyHash(api_key_hash); - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + var builder = try MessageBuilder.init(self.allocator, 1 + 16); + defer builder.deinit(); - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] - const total_len = 1 + 16; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - buffer[0] = @intFromEnum(opcode.status_request); - @memcpy(buffer[1..17], api_key_hash); - - try frame.sendWebSocketFrame(stream, buffer); + builder.writeOpcode(opcode.status_request); + builder.writeBytes(api_key_hash); + try builder.send(stream); } pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 { @@ -838,16 +885,7 @@ pub const Client = struct { std.debug.print("Server response: {s}\n", .{message}); return; }; - defer { - if (packet.success_message) |msg| allocator.free(msg); - if (packet.error_message) |msg| allocator.free(msg); - if (packet.error_details) |details| allocator.free(details); - if (packet.data_type) |dtype| allocator.free(dtype); - if (packet.data_payload) |payload| allocator.free(payload); - if (packet.progress_message) |pmsg| allocator.free(pmsg); - if (packet.status_data) |sdata| allocator.free(sdata); - if (packet.log_message) |lmsg| allocator.free(lmsg); - } + defer packet.deinit(allocator); try response_handlers.handleResponsePacket(self, packet, operation); } @@ -1315,16 +1353,7 @@ pub const Client = struct { // Fallback: treat as plain response. return allocator.dupe(u8, message); }; - defer { - if (packet.success_message) |msg| allocator.free(msg); - if (packet.error_message) |msg| allocator.free(msg); - if (packet.error_details) |details| allocator.free(details); - if (packet.data_type) |dtype| allocator.free(dtype); - if (packet.data_payload) |payload| allocator.free(payload); - if (packet.progress_message) |pmsg| allocator.free(pmsg); - if (packet.status_data) |sdata| allocator.free(sdata); - if (packet.log_message) |lmsg| allocator.free(lmsg); - } + defer packet.deinit(allocator); switch (packet.packet_type) { .data => {