From cf8115c67019559aadb35ec36804b99974318cf6 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 5 Mar 2026 12:07:41 -0500 Subject: [PATCH] feat(cli): standardize connection handling across commands Add isConnected() method to common.ConnectionContext to check WebSocket client connection state. Migrate all server-connected commands to use the standardized ConnectionContext pattern: - jupyter/lifecycle.zig: Replace local ConnectionCtx with common.ConnectionContext - status.zig: Use ConnectionContext, remove manual connection boilerplate, add connection status indicators (connecting/connected) - cancel.zig: Use ConnectionContext for server cancel operations - dataset.zig: Use ConnectionContext for list/register/info/search operations - exec/remote.zig: Use ConnectionContext for remote job execution Benefits: - Eliminates ~160 lines of duplicated connection boilerplate - Consistent error handling and cleanup across commands - Single point of change for connection logic - Adds runtime connection state visibility to status command --- cli/src/commands/cancel.zig | 17 ++-- cli/src/commands/common.zig | 15 ++- cli/src/commands/dataset.zig | 125 +++++++++---------------- cli/src/commands/exec/remote.zig | 19 ++-- cli/src/commands/jupyter/lifecycle.zig | 45 ++------- cli/src/commands/status.zig | 32 ++++--- 6 files changed, 93 insertions(+), 160 deletions(-) diff --git a/cli/src/commands/cancel.zig b/cli/src/commands/cancel.zig index 6071d8b..ac238d2 100644 --- a/cli/src/commands/cancel.zig +++ b/cli/src/commands/cancel.zig @@ -6,6 +6,7 @@ const crypto = @import("../utils/crypto.zig"); const core = @import("../core.zig"); const mode = @import("../mode.zig"); const manifest_lib = @import("../manifest.zig"); +const common = @import("common.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var flags = core.flags.CommonFlags{}; @@ -176,20 +177,16 @@ fn isProcessRunning(pid: i32) bool { fn cancelServer(allocator: std.mem.Allocator, job_name: []const u8, force: bool, json: bool, cfg: config.Config) !void { _ = force; _ = json; + _ = cfg; - const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); - defer allocator.free(api_key_hash); + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + try ctx.connect(); - const ws_url = try cfg.getWebSocketUrl(allocator); - defer allocator.free(ws_url); - - var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); - defer client.close(); - - try client.sendCancelJob(job_name, api_key_hash); + try ctx.client.sendCancelJob(job_name, ctx.api_key_hash); // Wait for acknowledgment - const message = try client.receiveMessage(allocator); + const message = try ctx.client.receiveMessage(allocator); defer allocator.free(message); // Parse response (simplified) diff --git a/cli/src/commands/common.zig b/cli/src/commands/common.zig index 76fa275..35a5fcc 100644 --- a/cli/src/commands/common.zig +++ b/cli/src/commands/common.zig @@ -7,7 +7,7 @@ const crypto = @import("../utils/crypto.zig"); pub const ConnectionContext = struct { allocator: std.mem.Allocator, config: Config, - client: *ws.Client, + client: ws.Client, api_key_hash: []const u8, ws_url: []const u8, @@ -21,21 +21,22 @@ pub const ConnectionContext = struct { const ws_url = try config.getWebSocketUrl(allocator); errdefer allocator.free(ws_url); - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - errdefer client.close(); - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); errdefer allocator.free(api_key_hash); return ConnectionContext{ .allocator = allocator, .config = config, - .client = &client, + .client = undefined, .api_key_hash = api_key_hash, .ws_url = ws_url, }; } + pub fn connect(self: *ConnectionContext) !void { + self.client = try ws.Client.connect(self.allocator, self.ws_url, self.config.api_key); + } + pub fn deinit(self: *ConnectionContext) void { self.client.close(); self.allocator.free(self.api_key_hash); @@ -43,6 +44,10 @@ pub const ConnectionContext = struct { var mut_config = self.config; mut_config.deinit(self.allocator); } + + pub fn isConnected(self: *const ConnectionContext) bool { + return self.client.connected; + } }; /// Execute operation with standard config + WebSocket setup diff --git a/cli/src/commands/dataset.zig b/cli/src/commands/dataset.zig index d6e57de..0d7bb03 100644 --- a/cli/src/commands/dataset.zig +++ b/cli/src/commands/dataset.zig @@ -3,6 +3,7 @@ const Config = @import("../config.zig").Config; const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); const core = @import("../core.zig"); +const common = @import("common.zig"); const DatasetOptions = struct { dry_run: bool = false, @@ -108,16 +109,6 @@ fn printUsage() void { } fn listDatasets(allocator: std.mem.Allocator, options: *const DatasetOptions) !void { - const config = try Config.load(allocator); - defer { - var mut_config = config; - mut_config.deinit(allocator); - } - - // Connect to WebSocket and request dataset list - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - defer allocator.free(api_key_hash); - if (options.validate) { if (options.json) { const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; @@ -130,12 +121,6 @@ fn listDatasets(allocator: std.mem.Allocator, options: *const DatasetOptions) !v return; } - const ws_url = try config.getWebSocketUrl(allocator); - defer allocator.free(ws_url); - - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - defer client.close(); - if (options.dry_run) { if (options.json) { const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; @@ -148,10 +133,14 @@ fn listDatasets(allocator: std.mem.Allocator, options: *const DatasetOptions) !v return; } - try client.sendDatasetList(api_key_hash); + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + try ctx.connect(); + + try ctx.client.sendDatasetList(ctx.api_key_hash); // Receive and display dataset list - const response = try client.receiveAndHandleDatasetResponse(allocator); + const response = try ctx.client.receiveAndHandleDatasetResponse(allocator); defer allocator.free(response); if (options.json) { @@ -173,26 +162,20 @@ fn listDatasets(allocator: std.mem.Allocator, options: *const DatasetOptions) !v } fn registerDataset(allocator: std.mem.Allocator, name: []const u8, url: []const u8, options: *const DatasetOptions) !void { - const config = try Config.load(allocator); - defer { - var mut_config = config; - mut_config.deinit(allocator); + // Validate URL format (always check) + if (!std.mem.startsWith(u8, url, "http://") and !std.mem.startsWith(u8, url, "https://") and + !std.mem.startsWith(u8, url, "s3://") and !std.mem.startsWith(u8, url, "gs://")) + { + if (!options.validate) { + std.debug.print("Invalid URL format. Supported: http://, https://, s3://, gs://\n", .{}); + } + return error.InvalidURL; } - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - defer allocator.free(api_key_hash); - if (options.validate) { if (name.len == 0 or name.len > 255) return error.InvalidArgs; if (url.len == 0 or url.len > 1023) return error.InvalidURL; - // Validate URL format - if (!std.mem.startsWith(u8, url, "http://") and !std.mem.startsWith(u8, url, "https://") and - !std.mem.startsWith(u8, url, "s3://") and !std.mem.startsWith(u8, url, "gs://")) - { - return error.InvalidURL; - } - if (options.json) { const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; var buffer: [4096]u8 = undefined; @@ -207,16 +190,6 @@ fn registerDataset(allocator: std.mem.Allocator, name: []const u8, url: []const return; } - // Validate URL format - if (!std.mem.startsWith(u8, url, "http://") and !std.mem.startsWith(u8, url, "https://") and - !std.mem.startsWith(u8, url, "s3://") and !std.mem.startsWith(u8, url, "gs://")) - { - std.debug.print("Invalid URL format. Supported: http://, https://, s3://, gs://\n", .{}); - return error.InvalidURL; - } - - // Connect to WebSocket and register dataset - if (options.dry_run) { if (options.json) { const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; @@ -232,16 +205,15 @@ fn registerDataset(allocator: std.mem.Allocator, name: []const u8, url: []const return; } - const ws_url = try config.getWebSocketUrl(allocator); - defer allocator.free(ws_url); + // Connect to WebSocket and register dataset + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + try ctx.connect(); - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - defer client.close(); - - try client.sendDatasetRegister(name, url, api_key_hash); + try ctx.client.sendDatasetRegister(name, url, ctx.api_key_hash); // Receive response - const response = try client.receiveAndHandleDatasetResponse(allocator); + const response = try ctx.client.receiveAndHandleDatasetResponse(allocator); defer allocator.free(response); if (options.json) { @@ -263,15 +235,6 @@ fn registerDataset(allocator: std.mem.Allocator, name: []const u8, url: []const } fn showDatasetInfo(allocator: std.mem.Allocator, name: []const u8, options: *const DatasetOptions) !void { - const config = try Config.load(allocator); - defer { - var mut_config = config; - mut_config.deinit(allocator); - } - - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - defer allocator.free(api_key_hash); - if (options.validate) { if (name.len == 0 or name.len > 255) return error.InvalidArgs; if (options.json) { @@ -288,8 +251,6 @@ fn showDatasetInfo(allocator: std.mem.Allocator, name: []const u8, options: *con return; } - // Connect to WebSocket and get dataset info - if (options.dry_run) { if (options.json) { const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; @@ -305,16 +266,15 @@ fn showDatasetInfo(allocator: std.mem.Allocator, name: []const u8, options: *con return; } - const ws_url = try config.getWebSocketUrl(allocator); - defer allocator.free(ws_url); + // Connect to WebSocket and get dataset info + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + try ctx.connect(); - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - defer client.close(); - - try client.sendDatasetInfo(name, api_key_hash); + try ctx.client.sendDatasetInfo(name, ctx.api_key_hash); // Receive response - const response = try client.receiveAndHandleDatasetResponse(allocator); + const response = try ctx.client.receiveAndHandleDatasetResponse(allocator); defer allocator.free(response); if (options.json) { @@ -333,15 +293,6 @@ fn showDatasetInfo(allocator: std.mem.Allocator, name: []const u8, options: *con } fn searchDatasets(allocator: std.mem.Allocator, term: []const u8, options: *const DatasetOptions) !void { - const config = try Config.load(allocator); - defer { - var mut_config = config; - mut_config.deinit(allocator); - } - - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - defer allocator.free(api_key_hash); - if (options.validate) { if (term.len == 0 or term.len > 255) return error.InvalidArgs; if (options.json) { @@ -358,16 +309,26 @@ fn searchDatasets(allocator: std.mem.Allocator, term: []const u8, options: *cons return; } - const ws_url = try config.getWebSocketUrl(allocator); - defer allocator.free(ws_url); + if (options.dry_run) { + if (options.json) { + const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; + var buffer: [4096]u8 = undefined; + const formatted = std.fmt.bufPrint(&buffer, "{{\"dry_run\":true,\"action\":\"search\",\"term\":\"{s}\"}}\n", .{term}) catch unreachable; + try stdout_file.writeAll(formatted); + } else { + std.debug.print("Dry run: would search datasets for '{s}'\n", .{term}); + } + return; + } - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - defer client.close(); + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + try ctx.connect(); - try client.sendDatasetSearch(term, api_key_hash); + try ctx.client.sendDatasetSearch(term, ctx.api_key_hash); // Receive response - const response = try client.receiveAndHandleDatasetResponse(allocator); + const response = try ctx.client.receiveAndHandleDatasetResponse(allocator); defer allocator.free(response); if (options.json) { diff --git a/cli/src/commands/exec/remote.zig b/cli/src/commands/exec/remote.zig index 0ac580f..65517c7 100644 --- a/cli/src/commands/exec/remote.zig +++ b/cli/src/commands/exec/remote.zig @@ -4,6 +4,7 @@ const ws = @import("../../net/ws/client.zig"); const crypto = @import("../../utils/crypto.zig"); const protocol = @import("../../net/protocol.zig"); const history = @import("../../utils/history.zig"); +const common = @import("../common.zig"); /// Execute job on remote server pub fn execute( @@ -14,17 +15,13 @@ pub fn execute( args_str: []const u8, cfg: config.Config, ) !void { + _ = cfg; // Use queue command logic for remote execution std.log.info("Queueing job on remote server: {s}", .{job_name}); - const ws_url = try cfg.getWebSocketUrl(allocator); - defer allocator.free(ws_url); - - var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); - defer client.close(); - - const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); - defer allocator.free(api_key_hash); + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + try ctx.connect(); // Generate commit ID var commit_bytes: [20]u8 = undefined; @@ -35,11 +32,11 @@ pub fn execute( defer if (narrative_json) |j| allocator.free(j); // Send queue request - try client.sendQueueJobWithArgsAndResources( + try ctx.client.sendQueueJobWithArgsAndResources( job_name, &commit_bytes, priority, - api_key_hash, + ctx.api_key_hash, args_str, false, // force options.cpu, @@ -49,7 +46,7 @@ pub fn execute( ); // Receive response - const message = try client.receiveMessage(allocator); + const message = try ctx.client.receiveMessage(allocator); defer allocator.free(message); const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { diff --git a/cli/src/commands/jupyter/lifecycle.zig b/cli/src/commands/jupyter/lifecycle.zig index b07c1fc..e52c1a9 100644 --- a/cli/src/commands/jupyter/lifecycle.zig +++ b/cli/src/commands/jupyter/lifecycle.zig @@ -4,47 +4,10 @@ const ws = @import("../../net/ws/client.zig"); const crypto = @import("../../utils/crypto.zig"); const protocol = @import("../../net/protocol.zig"); const validation = @import("validation.zig"); +const common = @import("../common.zig"); /// Context holding connection resources for cleanup -const ConnectionCtx = struct { - config: Config, - client: ws.Client, - api_key_hash: []const u8, - ws_url: []const u8, - allocator: std.mem.Allocator, - - fn init(allocator: std.mem.Allocator) !ConnectionCtx { - const config = try Config.load(allocator); - errdefer { - var mut = config; - mut.deinit(allocator); - } - - const ws_url = try config.getWebSocketUrl(allocator); - errdefer allocator.free(ws_url); - - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - errdefer client.close(); - - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - - return ConnectionCtx{ - .config = config, - .client = client, - .api_key_hash = api_key_hash, - .ws_url = ws_url, - .allocator = allocator, - }; - } - - fn deinit(self: *ConnectionCtx) void { - self.allocator.free(self.api_key_hash); - self.allocator.free(self.ws_url); - self.client.close(); - var mut = self.config; - mut.deinit(self.allocator); - } -}; +const ConnectionCtx = common.ConnectionContext; /// Create a new Jupyter workspace and start it pub fn createJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void { @@ -126,6 +89,7 @@ pub fn startJupyter(allocator: std.mem.Allocator, args: []const []const u8) !voi var ctx = try ConnectionCtx.init(allocator); defer ctx.deinit(); + try ctx.connect(); std.debug.print("Starting Jupyter service '{s}'...\n", .{name}); @@ -178,6 +142,7 @@ pub fn stopJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void var ctx = try ConnectionCtx.init(allocator); defer ctx.deinit(); + try ctx.connect(); std.debug.print("Stopping service {s}...\n", .{service_id}); @@ -264,6 +229,7 @@ pub fn removeJupyter(allocator: std.mem.Allocator, args: []const []const u8) !vo var ctx = try ConnectionCtx.init(allocator); defer ctx.deinit(); + try ctx.connect(); if (purge) { std.debug.print("Permanently deleting service {s}...\n", .{service_id}); @@ -318,6 +284,7 @@ pub fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8, js var ctx = try ConnectionCtx.init(allocator); defer ctx.deinit(); + try ctx.connect(); std.debug.print("Restoring workspace {s}...", .{name}); diff --git a/cli/src/commands/status.zig b/cli/src/commands/status.zig index 7baa66c..646f5d8 100644 --- a/cli/src/commands/status.zig +++ b/cli/src/commands/status.zig @@ -5,6 +5,7 @@ const crypto = @import("../utils/crypto.zig"); const io = @import("../utils/io.zig"); const auth = @import("../utils/auth.zig"); const core = @import("../core.zig"); +const common = @import("common.zig"); pub const StatusOptions = struct { json: bool = false, @@ -56,29 +57,34 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { defer user_context.deinit(); if (options.watch) { - try runWatchMode(allocator, config, user_context, options); + try runWatchMode(allocator, user_context, options); } else if (options.tui) { try runTuiMode(allocator, config, args); } else { - try runSingleStatus(allocator, config, user_context, options); + try runSingleStatus(allocator, user_context, options); } } -fn runSingleStatus(allocator: std.mem.Allocator, config: Config, user_context: auth.UserContext, options: StatusOptions) !void { - const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); - defer allocator.free(api_key_hash); +fn runSingleStatus(allocator: std.mem.Allocator, user_context: auth.UserContext, options: StatusOptions) !void { + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); - const ws_url = try config.getWebSocketUrl(allocator); - defer allocator.free(ws_url); + // Show connection status before attempting connection + if (!options.json) { + std.debug.print("Remote: connecting...\n", .{}); + } - var client = try ws.Client.connect(allocator, ws_url, config.api_key); - defer client.close(); + try ctx.connect(); - try client.sendStatusRequest(api_key_hash); - try client.receiveAndHandleStatusResponse(allocator, user_context, options); + if (!options.json) { + std.debug.print("Remote: connected\n", .{}); + } + + try ctx.client.sendStatusRequest(ctx.api_key_hash); + try ctx.client.receiveAndHandleStatusResponse(allocator, user_context, options); } -fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: auth.UserContext, options: StatusOptions) !void { +fn runWatchMode(allocator: std.mem.Allocator, user_context: auth.UserContext, options: StatusOptions) !void { std.debug.print("Starting watch mode (interval: {d}s). Press Ctrl+C to stop.\n", .{options.watch_interval}); while (true) { @@ -86,7 +92,7 @@ fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: auth std.debug.print("\n=== FetchML Status - {s} ===", .{user_context.name}); } - try runSingleStatus(allocator, config, user_context, options); + try runSingleStatus(allocator, user_context, options); if (!options.json) { std.debug.print("Next update in {d} seconds...\n", .{options.watch_interval});