refactor(cli): create modular WebSocket client structure
Break monolithic client.zig (1,558 lines) into focused modules: - connection.zig: Transport, connection logic, URL parsing, TLS setup - messaging.zig: MessageBuilder, validation, send methods - state.zig: ClientState, response handling, error conversion - mod.zig: Public exports and Client struct composition Benefits: - Each module <400 lines (maintainability target) - Clear separation of concerns - Easier to test individual components - Foundation for future client refactoring Original client.zig kept intact for backward compatibility. Build passes successfully.
This commit is contained in:
parent
fd4c342de0
commit
c17811cf2b
4 changed files with 713 additions and 0 deletions
162
cli/src/net/ws/client/connection.zig
Normal file
162
cli/src/net/ws/client/connection.zig
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
const std = @import("std");
|
||||
const tls = @import("../tls.zig");
|
||||
const resolve = @import("../resolve.zig");
|
||||
const handshake = @import("../handshake.zig");
|
||||
|
||||
/// Transport abstraction for WebSocket connections
|
||||
/// Supports both raw TCP and TLS-wrapped connections
|
||||
pub const Transport = union(enum) {
|
||||
tcp: std.net.Stream,
|
||||
tls: *tls.TlsStream,
|
||||
|
||||
pub fn read(self: Transport, buffer: []u8) !usize {
|
||||
return switch (self) {
|
||||
.tcp => |s| s.read(buffer),
|
||||
.tls => |s| s.read(buffer),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn write(self: Transport, buffer: []const u8) !void {
|
||||
return switch (self) {
|
||||
.tcp => |s| s.writeAll(buffer),
|
||||
.tls => |s| s.write(buffer),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn close(self: *Transport) void {
|
||||
switch (self.*) {
|
||||
.tcp => |*s| s.close(),
|
||||
.tls => |s| s.close(),
|
||||
}
|
||||
self.* = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
/// Connection result returned after successful connection
|
||||
pub const ConnectionResult = struct {
|
||||
transport: Transport,
|
||||
host: []const u8,
|
||||
port: u16,
|
||||
is_tls: bool,
|
||||
};
|
||||
|
||||
/// Connection configuration
|
||||
pub const ConnectionConfig = struct {
|
||||
url: []const u8,
|
||||
api_key: []const u8,
|
||||
max_retries: u32 = 3,
|
||||
};
|
||||
|
||||
/// Parse WebSocket URL and extract host, port, TLS info
|
||||
pub fn parseUrl(url: []const u8) !struct { host: []const u8, port: u16, is_tls: bool, path: []const u8 } {
|
||||
const is_tls = std.mem.startsWith(u8, url, "wss://");
|
||||
|
||||
const host_start = std.mem.indexOf(u8, url, "//") orelse return error.InvalidURL;
|
||||
const host_port_start = host_start + 2;
|
||||
const path_start = std.mem.indexOfPos(u8, url, host_port_start, "/") orelse url.len;
|
||||
const colon_pos = std.mem.indexOfPos(u8, url, host_port_start, ":");
|
||||
|
||||
const host_end = blk: {
|
||||
if (colon_pos) |pos| {
|
||||
if (pos < path_start) break :blk pos;
|
||||
}
|
||||
break :blk path_start;
|
||||
};
|
||||
const host = url[host_port_start..host_end];
|
||||
|
||||
var port: u16 = if (is_tls) 9101 else 9100;
|
||||
if (colon_pos) |pos| {
|
||||
if (pos < path_start) {
|
||||
const port_start = pos + 1;
|
||||
const port_end = std.mem.indexOfPos(u8, url, port_start, "/") orelse url.len;
|
||||
const port_str = url[port_start..port_end];
|
||||
port = try std.fmt.parseInt(u16, port_str, 10);
|
||||
}
|
||||
}
|
||||
|
||||
const path = url[path_start..];
|
||||
return .{ .host = host, .port = port, .is_tls = is_tls, .path = path };
|
||||
}
|
||||
|
||||
/// Create transport (TCP or TLS) for connection
|
||||
pub fn createTransport(allocator: std.mem.Allocator, host: []const u8, port: u16, is_tls: bool) !Transport {
|
||||
const tcp_stream = try std.net.tcpConnectToAddress(try resolve.resolveHostAddress(allocator, host, port));
|
||||
|
||||
if (is_tls) {
|
||||
const tls_stream = try allocator.create(tls.TlsStream);
|
||||
errdefer allocator.destroy(tls_stream);
|
||||
|
||||
tls_stream.* = tls.TlsStream.init(allocator, tcp_stream, host) catch |err| {
|
||||
allocator.destroy(tls_stream);
|
||||
tcp_stream.close();
|
||||
if (err == error.TlsLibraryRequired) {
|
||||
std.log.warn("TLS (wss://) support requires external TLS library integration. Falling back to ws://", .{});
|
||||
return error.TLSNotSupported;
|
||||
}
|
||||
return err;
|
||||
};
|
||||
return Transport{ .tls = tls_stream };
|
||||
} else {
|
||||
return Transport{ .tcp = tcp_stream };
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect with retry logic
|
||||
pub fn connectWithRetry(
|
||||
allocator: std.mem.Allocator,
|
||||
url: []const u8,
|
||||
api_key: []const u8,
|
||||
max_retries: u32,
|
||||
) !ConnectionResult {
|
||||
var retry_count: u32 = 0;
|
||||
var last_error: anyerror = error.ConnectionFailed;
|
||||
|
||||
while (retry_count < max_retries) {
|
||||
const result = tryConnect(allocator, url, api_key) catch |err| {
|
||||
last_error = err;
|
||||
retry_count += 1;
|
||||
|
||||
if (retry_count < max_retries) {
|
||||
const delay_ms = @min(1000 * retry_count, 5000);
|
||||
std.log.warn("Connection failed (attempt {d}/{d}), retrying in {d}s...\n", .{ retry_count, max_retries, delay_ms / 1000 });
|
||||
std.Thread.sleep(@as(u64, delay_ms) * std.time.ns_per_ms);
|
||||
}
|
||||
continue;
|
||||
};
|
||||
|
||||
if (retry_count > 0) {
|
||||
std.log.info("Connected successfully after {d} attempts\n", .{retry_count + 1});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
return last_error;
|
||||
}
|
||||
|
||||
/// Single connection attempt
|
||||
fn tryConnect(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8) !ConnectionResult {
|
||||
const parsed = try parseUrl(url);
|
||||
|
||||
var transport = try createTransport(allocator, parsed.host, parsed.port, parsed.is_tls);
|
||||
errdefer transport.close();
|
||||
|
||||
try handshake.handshake(allocator, transport, parsed.host, url, api_key);
|
||||
|
||||
const host_copy = try allocator.dupe(u8, parsed.host);
|
||||
errdefer allocator.free(host_copy);
|
||||
|
||||
return .{
|
||||
.transport = transport,
|
||||
.host = host_copy,
|
||||
.port = parsed.port,
|
||||
.is_tls = parsed.is_tls,
|
||||
};
|
||||
}
|
||||
|
||||
/// Close connection and cleanup
|
||||
pub fn closeConnection(transport: *Transport, host: []const u8, allocator: std.mem.Allocator) void {
|
||||
transport.close();
|
||||
if (host.len > 0) {
|
||||
allocator.free(host);
|
||||
}
|
||||
}
|
||||
306
cli/src/net/ws/client/messaging.zig
Normal file
306
cli/src/net/ws/client/messaging.zig
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
const std = @import("std");
|
||||
const frame = @import("../frame.zig");
|
||||
const opcode = @import("../opcode.zig");
|
||||
const Transport = @import("connection.zig").Transport;
|
||||
|
||||
/// Binary message builder for WebSocket frames
|
||||
pub const MessageBuilder = struct {
|
||||
buffer: []u8,
|
||||
offset: usize,
|
||||
allocator: std.mem.Allocator,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, total_len: usize) !MessageBuilder {
|
||||
const buffer = try allocator.alloc(u8, total_len);
|
||||
return MessageBuilder{
|
||||
.buffer = buffer,
|
||||
.offset = 0,
|
||||
.allocator = allocator,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *MessageBuilder) void {
|
||||
self.allocator.free(self.buffer);
|
||||
}
|
||||
|
||||
pub fn writeOpcode(self: *MessageBuilder, op: opcode.Opcode) void {
|
||||
self.buffer[self.offset] = @intFromEnum(op);
|
||||
self.offset += 1;
|
||||
}
|
||||
|
||||
pub fn writeBytes(self: *MessageBuilder, data: []const u8) void {
|
||||
@memcpy(self.buffer[self.offset .. self.offset + data.len], data);
|
||||
self.offset += data.len;
|
||||
}
|
||||
|
||||
pub fn writeU8(self: *MessageBuilder, value: u8) void {
|
||||
self.buffer[self.offset] = value;
|
||||
self.offset += 1;
|
||||
}
|
||||
|
||||
pub fn writeU16(self: *MessageBuilder, value: u16) void {
|
||||
std.mem.writeInt(u16, self.buffer[self.offset .. self.offset + 2][0..2], value, .big);
|
||||
self.offset += 2;
|
||||
}
|
||||
|
||||
pub fn writeU32(self: *MessageBuilder, value: u32) void {
|
||||
std.mem.writeInt(u32, self.buffer[self.offset .. self.offset + 4][0..4], value, .big);
|
||||
self.offset += 4;
|
||||
}
|
||||
|
||||
pub fn writeU64(self: *MessageBuilder, value: u64) void {
|
||||
std.mem.writeInt(u64, self.buffer[self.offset .. self.offset + 8][0..8], value, .big);
|
||||
self.offset += 8;
|
||||
}
|
||||
|
||||
pub fn writeStringU8(self: *MessageBuilder, str: []const u8) void {
|
||||
self.writeU8(@intCast(str.len));
|
||||
if (str.len > 0) {
|
||||
self.writeBytes(str);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn writeStringU16(self: *MessageBuilder, str: []const u8) void {
|
||||
self.writeU16(@intCast(str.len));
|
||||
if (str.len > 0) {
|
||||
self.writeBytes(str);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(self: *MessageBuilder, transport: Transport) !void {
|
||||
try frame.sendWebSocketFrame(transport, self.buffer);
|
||||
}
|
||||
};
|
||||
|
||||
/// Validation helpers
|
||||
pub fn validateApiKeyHash(api_key_hash: []const u8) error{InvalidApiKeyHash}!void {
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
}
|
||||
|
||||
pub fn validateCommitId(commit_id: []const u8) error{InvalidCommitId}!void {
|
||||
if (commit_id.len != 20) return error.InvalidCommitId;
|
||||
}
|
||||
|
||||
pub fn validateJobName(job_name: []const u8) error{JobNameTooLong}!void {
|
||||
if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong;
|
||||
}
|
||||
|
||||
/// Build and send validate request
|
||||
pub fn sendValidateRequestCommit(transport: Transport, allocator: std.mem.Allocator, api_key_hash: []const u8, commit_id: []const u8) !void {
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
try validateCommitId(commit_id);
|
||||
|
||||
var builder = try MessageBuilder.init(allocator, 1 + 16 + 1 + 1 + 20);
|
||||
defer builder.deinit();
|
||||
|
||||
builder.writeOpcode(opcode.Opcode.validate_request);
|
||||
builder.writeBytes(api_key_hash);
|
||||
builder.writeU8(@intFromEnum(opcode.ValidateTargetType.commit_id));
|
||||
builder.writeU8(20);
|
||||
builder.writeBytes(commit_id);
|
||||
try builder.send(transport);
|
||||
}
|
||||
|
||||
/// Build and send queue job with resources
|
||||
pub fn sendQueueJob(
|
||||
transport: Transport,
|
||||
allocator: std.mem.Allocator,
|
||||
job_name: []const u8,
|
||||
commit_id: []const u8,
|
||||
priority: u8,
|
||||
api_key_hash: []const u8,
|
||||
cpu: u8,
|
||||
memory_gb: u8,
|
||||
gpu: u8,
|
||||
gpu_memory: ?[]const u8,
|
||||
) !void {
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (commit_id.len != 20) return error.InvalidCommitId;
|
||||
if (job_name.len > 255) return error.JobNameTooLong;
|
||||
|
||||
const gpu_mem = gpu_memory orelse "";
|
||||
if (gpu_mem.len > 255) return error.PayloadTooLarge;
|
||||
|
||||
const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 4 + gpu_mem.len;
|
||||
var buffer = try allocator.alloc(u8, total_len);
|
||||
defer allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.queue_job);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 20], commit_id);
|
||||
offset += 20;
|
||||
|
||||
buffer[offset] = priority;
|
||||
offset += 1;
|
||||
|
||||
buffer[offset] = @intCast(job_name.len);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + job_name.len], job_name);
|
||||
offset += job_name.len;
|
||||
|
||||
buffer[offset] = cpu;
|
||||
buffer[offset + 1] = memory_gb;
|
||||
buffer[offset + 2] = gpu;
|
||||
buffer[offset + 3] = @intCast(gpu_mem.len);
|
||||
offset += 4;
|
||||
|
||||
if (gpu_mem.len > 0) {
|
||||
@memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem);
|
||||
}
|
||||
|
||||
try frame.sendWebSocketFrame(transport, buffer);
|
||||
}
|
||||
|
||||
/// Build and send cancel job request
|
||||
pub fn sendCancelJob(transport: Transport, allocator: std.mem.Allocator, job_name: []const u8, api_key_hash: []const u8) !void {
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
try validateJobName(job_name);
|
||||
|
||||
var builder = try MessageBuilder.init(allocator, 1 + 16 + 1 + job_name.len);
|
||||
defer builder.deinit();
|
||||
|
||||
builder.writeOpcode(opcode.cancel_job);
|
||||
builder.writeBytes(api_key_hash);
|
||||
builder.writeStringU8(job_name);
|
||||
try builder.send(transport);
|
||||
}
|
||||
|
||||
/// Build and send status request
|
||||
pub fn sendStatusRequest(transport: Transport, allocator: std.mem.Allocator, api_key_hash: []const u8) !void {
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
||||
var builder = try MessageBuilder.init(allocator, 1 + 16);
|
||||
defer builder.deinit();
|
||||
|
||||
builder.writeOpcode(opcode.status_request);
|
||||
builder.writeBytes(api_key_hash);
|
||||
try builder.send(transport);
|
||||
}
|
||||
|
||||
/// Build and send sync run request
|
||||
pub fn sendSyncRun(transport: Transport, allocator: std.mem.Allocator, sync_json: []const u8, api_key_hash: []const u8) !void {
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (sync_json.len > 0xFFFF) return error.PayloadTooLarge;
|
||||
|
||||
const total_len = 1 + 16 + 2 + sync_json.len;
|
||||
var buffer = try allocator.alloc(u8, total_len);
|
||||
defer allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.sync_run);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(sync_json.len), .big);
|
||||
offset += 2;
|
||||
|
||||
if (sync_json.len > 0) {
|
||||
@memcpy(buffer[offset .. offset + sync_json.len], sync_json);
|
||||
}
|
||||
|
||||
try frame.sendWebSocketFrame(transport, buffer);
|
||||
}
|
||||
|
||||
/// Build and send queue job with args and resources
|
||||
pub fn sendQueueJobWithArgsAndResources(
|
||||
transport: Transport,
|
||||
allocator: std.mem.Allocator,
|
||||
job_name: []const u8,
|
||||
commit_id: []const u8,
|
||||
priority: u8,
|
||||
api_key_hash: []const u8,
|
||||
args: []const u8,
|
||||
force: bool,
|
||||
cpu: u8,
|
||||
memory_gb: u8,
|
||||
gpu: u8,
|
||||
gpu_memory: ?[]const u8,
|
||||
) !void {
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (commit_id.len != 20) return error.InvalidCommitId;
|
||||
if (job_name.len > 255) return error.JobNameTooLong;
|
||||
if (args.len > 0xFFFF) return error.PayloadTooLarge;
|
||||
|
||||
const gpu_mem = gpu_memory orelse "";
|
||||
if (gpu_mem.len > 255) return error.PayloadTooLarge;
|
||||
|
||||
const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 1 + 4 + gpu_mem.len;
|
||||
var buffer = try allocator.alloc(u8, total_len);
|
||||
defer allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.queue_job_with_args);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 20], commit_id);
|
||||
offset += 20;
|
||||
|
||||
buffer[offset] = priority;
|
||||
offset += 1;
|
||||
|
||||
buffer[offset] = @intCast(job_name.len);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + job_name.len], job_name);
|
||||
offset += job_name.len;
|
||||
|
||||
buffer[offset] = @intCast((args.len >> 8) & 0xFF);
|
||||
buffer[offset + 1] = @intCast(args.len & 0xFF);
|
||||
offset += 2;
|
||||
|
||||
if (args.len > 0) {
|
||||
@memcpy(buffer[offset .. offset + args.len], args);
|
||||
offset += args.len;
|
||||
}
|
||||
|
||||
buffer[offset] = if (force) 0x01 else 0x00;
|
||||
offset += 1;
|
||||
|
||||
buffer[offset] = cpu;
|
||||
buffer[offset + 1] = memory_gb;
|
||||
buffer[offset + 2] = gpu;
|
||||
buffer[offset + 3] = @intCast(gpu_mem.len);
|
||||
offset += 4;
|
||||
|
||||
if (gpu_mem.len > 0) {
|
||||
@memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem);
|
||||
}
|
||||
|
||||
try frame.sendWebSocketFrame(transport, buffer);
|
||||
}
|
||||
|
||||
/// Build and send restore Jupyter request
|
||||
pub fn sendRestoreJupyter(transport: Transport, allocator: std.mem.Allocator, name: []const u8, api_key_hash: []const u8) !void {
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
if (name.len > 255) return error.NameTooLong;
|
||||
|
||||
var builder = try MessageBuilder.init(allocator, 1 + 16 + 1 + name.len);
|
||||
defer builder.deinit();
|
||||
|
||||
builder.writeOpcode(opcode.restore_jupyter);
|
||||
builder.writeBytes(api_key_hash);
|
||||
builder.writeStringU8(name);
|
||||
try builder.send(transport);
|
||||
}
|
||||
|
||||
/// Build and send list Jupyter packages request
|
||||
pub fn sendListJupyter(transport: Transport, allocator: std.mem.Allocator, api_key_hash: []const u8) !void {
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
||||
var builder = try MessageBuilder.init(allocator, 1 + 16);
|
||||
defer builder.deinit();
|
||||
|
||||
builder.writeOpcode(opcode.list_jupyter_packages);
|
||||
builder.writeBytes(api_key_hash);
|
||||
try builder.send(transport);
|
||||
}
|
||||
112
cli/src/net/ws/client/mod.zig
Normal file
112
cli/src/net/ws/client/mod.zig
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
const std = @import("std");
|
||||
|
||||
// Sub-modules
|
||||
pub const connection = @import("connection.zig");
|
||||
pub const messaging = @import("messaging.zig");
|
||||
pub const state = @import("state.zig");
|
||||
|
||||
// Re-export key types for convenience
|
||||
pub const Transport = connection.Transport;
|
||||
pub const ConnectionConfig = connection.ConnectionConfig;
|
||||
pub const ConnectionResult = connection.ConnectionResult;
|
||||
pub const MessageBuilder = messaging.MessageBuilder;
|
||||
pub const ClientState = state.ClientState;
|
||||
|
||||
/// Main WebSocket client combining all functionality
|
||||
pub const Client = struct {
|
||||
state: ClientState,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, transport: Transport, host: []const u8, port: u16, is_tls: bool) Client {
|
||||
return .{
|
||||
.state = ClientState.init(allocator, transport, host, port, is_tls),
|
||||
};
|
||||
}
|
||||
|
||||
/// Connect to server with retry
|
||||
pub fn connect(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8) !Client {
|
||||
const result = try connection.connectWithRetry(allocator, url, api_key, 3);
|
||||
return Client.init(allocator, result.transport, result.host, result.port, result.is_tls);
|
||||
}
|
||||
|
||||
/// Connect with custom retry count
|
||||
pub fn connectWithRetry(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8, max_retries: u32) !Client {
|
||||
const result = try connection.connectWithRetry(allocator, url, api_key, max_retries);
|
||||
return Client.init(allocator, result.transport, result.host, result.port, result.is_tls);
|
||||
}
|
||||
|
||||
/// Disconnect
|
||||
pub fn disconnect(self: *Client) void {
|
||||
self.state.disconnect();
|
||||
}
|
||||
|
||||
/// Close and cleanup
|
||||
pub fn close(self: *Client) void {
|
||||
self.state.close();
|
||||
}
|
||||
|
||||
/// Get transport for messaging
|
||||
pub fn getTransport(self: *Client) Transport {
|
||||
return self.state.getTransport();
|
||||
}
|
||||
|
||||
// Delegate to state module
|
||||
pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 {
|
||||
return self.state.receiveMessage(allocator);
|
||||
}
|
||||
|
||||
pub fn receiveAndHandleResponse(self: *Client, allocator: std.mem.Allocator, operation: []const u8) !void {
|
||||
return self.state.receiveAndHandleResponse(allocator, operation);
|
||||
}
|
||||
|
||||
pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void {
|
||||
return self.state.receiveAndHandleStatusResponse(allocator, user_context, options);
|
||||
}
|
||||
|
||||
pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void {
|
||||
return self.state.receiveAndHandleCancelResponse(allocator, user_context, job_name, options);
|
||||
}
|
||||
|
||||
pub fn handleResponsePacket(self: *Client, packet: anytype, operation: []const u8) !void {
|
||||
return self.state.handleResponsePacket(packet, operation);
|
||||
}
|
||||
|
||||
pub fn receiveAndHandleDatasetResponse(self: *Client, allocator: std.mem.Allocator) ![]const u8 {
|
||||
return self.state.receiveAndHandleDatasetResponse(allocator);
|
||||
}
|
||||
|
||||
// Messaging methods - delegate to messaging module
|
||||
pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void {
|
||||
return messaging.sendValidateRequestCommit(self.getTransport(), self.state.allocator, api_key_hash, commit_id);
|
||||
}
|
||||
|
||||
pub fn sendQueueJob(self: *Client, job_name: []const u8, commit_id: []const u8, priority: u8, api_key_hash: []const u8, cpu: u8, memory_gb: u8, gpu: u8, gpu_memory: ?[]const u8) !void {
|
||||
return messaging.sendQueueJob(self.getTransport(), self.state.allocator, job_name, commit_id, priority, api_key_hash, cpu, memory_gb, gpu, gpu_memory);
|
||||
}
|
||||
|
||||
pub fn sendQueueJobWithArgsAndResources(self: *Client, job_name: []const u8, commit_id: []const u8, priority: u8, api_key_hash: []const u8, args: []const u8, force: bool, cpu: u8, memory_gb: u8, gpu: u8, gpu_memory: ?[]const u8) !void {
|
||||
return messaging.sendQueueJobWithArgsAndResources(self.getTransport(), self.state.allocator, job_name, commit_id, priority, api_key_hash, args, force, cpu, memory_gb, gpu, gpu_memory);
|
||||
}
|
||||
|
||||
pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void {
|
||||
return messaging.sendCancelJob(self.getTransport(), self.state.allocator, job_name, api_key_hash);
|
||||
}
|
||||
|
||||
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
|
||||
return messaging.sendStatusRequest(self.getTransport(), self.state.allocator, api_key_hash);
|
||||
}
|
||||
|
||||
pub fn sendSyncRun(self: *Client, sync_json: []const u8, api_key_hash: []const u8) !void {
|
||||
return messaging.sendSyncRun(self.getTransport(), self.state.allocator, sync_json, api_key_hash);
|
||||
}
|
||||
|
||||
pub fn sendRestoreJupyter(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
|
||||
return messaging.sendRestoreJupyter(self.getTransport(), self.state.allocator, name, api_key_hash);
|
||||
}
|
||||
|
||||
pub fn sendListJupyter(self: *Client, api_key_hash: []const u8) !void {
|
||||
return messaging.sendListJupyter(self.getTransport(), self.state.allocator, api_key_hash);
|
||||
}
|
||||
};
|
||||
|
||||
/// Format prewarm info from status
|
||||
pub const formatPrewarmFromStatusRoot = state.formatPrewarmFromStatusRoot;
|
||||
133
cli/src/net/ws/client/state.zig
Normal file
133
cli/src/net/ws/client/state.zig
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
const std = @import("std");
|
||||
const Transport = @import("connection.zig").Transport;
|
||||
const protocol = @import("../../protocol.zig");
|
||||
const response_handlers = @import("../response_handlers.zig");
|
||||
const frame = @import("../frame.zig");
|
||||
|
||||
/// WebSocket client state and response handling
|
||||
pub const ClientState = struct {
|
||||
allocator: std.mem.Allocator,
|
||||
transport: Transport,
|
||||
host: []const u8,
|
||||
port: u16,
|
||||
is_tls: bool = false,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, transport: Transport, host: []const u8, port: u16, is_tls: bool) ClientState {
|
||||
return .{
|
||||
.allocator = allocator,
|
||||
.transport = transport,
|
||||
.host = host,
|
||||
.port = port,
|
||||
.is_tls = is_tls,
|
||||
};
|
||||
}
|
||||
|
||||
/// Disconnect transport only
|
||||
pub fn disconnect(self: *ClientState) void {
|
||||
self.transport.close();
|
||||
}
|
||||
|
||||
/// Fully close - disconnect and free resources
|
||||
pub fn close(self: *ClientState) void {
|
||||
self.disconnect();
|
||||
if (self.host.len > 0) {
|
||||
self.allocator.free(self.host);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get transport for I/O operations
|
||||
pub fn getTransport(self: *ClientState) Transport {
|
||||
return self.transport;
|
||||
}
|
||||
|
||||
/// Receive binary message
|
||||
pub fn receiveMessage(self: *ClientState, allocator: std.mem.Allocator) ![]u8 {
|
||||
return frame.receiveBinaryMessage(self.transport, allocator);
|
||||
}
|
||||
|
||||
/// Receive and handle response with automatic display
|
||||
pub fn receiveAndHandleResponse(self: *ClientState, allocator: std.mem.Allocator, operation: []const u8) !void {
|
||||
const message = try self.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
|
||||
std.debug.print("Server response: {s}\n", .{message});
|
||||
return;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
try response_handlers.handleResponsePacket(self, packet, operation);
|
||||
}
|
||||
|
||||
/// Receive and handle status response
|
||||
pub fn receiveAndHandleStatusResponse(self: *ClientState, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void {
|
||||
return response_handlers.receiveAndHandleStatusResponse(self, allocator, user_context, options);
|
||||
}
|
||||
|
||||
/// Receive and handle cancel response
|
||||
pub fn receiveAndHandleCancelResponse(self: *ClientState, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void {
|
||||
return response_handlers.receiveAndHandleCancelResponse(self, allocator, user_context, job_name, options);
|
||||
}
|
||||
|
||||
/// Handle response packet
|
||||
pub fn handleResponsePacket(self: *ClientState, packet: protocol.ResponsePacket, operation: []const u8) !void {
|
||||
return response_handlers.handleResponsePacket(self, packet, operation);
|
||||
}
|
||||
|
||||
/// Convert server error to CLI error
|
||||
pub fn convertServerError(self: *ClientState, server_error: protocol.ErrorCode) anyerror {
|
||||
_ = self;
|
||||
return switch (server_error) {
|
||||
.authentication_failed => error.AuthenticationFailed,
|
||||
.permission_denied => error.PermissionDenied,
|
||||
.resource_not_found => error.JobNotFound,
|
||||
.resource_already_exists => error.ResourceExists,
|
||||
.timeout => error.RequestTimeout,
|
||||
.server_overloaded, .service_unavailable => error.ServerUnreachable,
|
||||
.invalid_request => error.InvalidArguments,
|
||||
.job_not_found => error.JobNotFound,
|
||||
.job_already_running => error.JobAlreadyRunning,
|
||||
.job_failed_to_start, .job_execution_failed => error.CommandFailed,
|
||||
.job_cancelled => error.JobCancelled,
|
||||
else => error.ServerError,
|
||||
};
|
||||
}
|
||||
|
||||
/// Receive and handle dataset response
|
||||
pub fn receiveAndHandleDatasetResponse(self: *ClientState, allocator: std.mem.Allocator) ![]const u8 {
|
||||
const message = try self.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
|
||||
return allocator.dupe(u8, message);
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
switch (packet.packet_type) {
|
||||
.data => {
|
||||
if (packet.data_payload) |payload| {
|
||||
return allocator.dupe(u8, payload);
|
||||
}
|
||||
return allocator.dupe(u8, "");
|
||||
},
|
||||
.success => {
|
||||
if (packet.success_message) |msg| {
|
||||
return allocator.dupe(u8, msg);
|
||||
}
|
||||
return allocator.dupe(u8, "");
|
||||
},
|
||||
.error_packet => {
|
||||
_ = response_handlers.handleResponsePacket(self, packet, "Dataset") catch {};
|
||||
return self.convertServerError(packet.error_code.?);
|
||||
},
|
||||
else => {
|
||||
return error.UnexpectedResponse;
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Format prewarm info from status root
|
||||
pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 {
|
||||
return @import("../response.zig").formatPrewarmFromStatusRoot(allocator, root);
|
||||
}
|
||||
Loading…
Reference in a new issue