From c6a224d5fcdc5bfa2b8109b24e9730210885059e Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 5 Mar 2026 12:07:00 -0500 Subject: [PATCH] feat(cli,server): unify info command with remote/local support Enhance ml info to query server when connected, falling back to local manifests when offline. Unifies behavior with other commands like run, exec, and cancel. CLI changes: - Add --local and --remote flags for explicit control - Auto-detect connection state via mode.detect() - queryRemoteRun(): Query server via WebSocket for run details - queryLocalRun(): Read local run_manifest.json - displayRunInfo(): Shared display logic for both sources - Add connection status indicators (Remote: connecting.../connected) WebSocket protocol: - Add query_run_info opcode (0x28) to cli and server - Add sendQueryRunInfo() method to ws/client.zig - Protocol: [opcode:1][api_key_hash:16][run_id_len:1][run_id:var] Server changes: - Add handleQueryRunInfo() handler to ws/handler.go - Returns run_id, job_name, user, timestamp, overall_sha, files_count - Checks PermJobsRead permission - Looks up run in experiment manager Usage: ml info abc123 # Auto: tries remote, falls back to local ml info abc123 --local # Force local manifest lookup ml info abc123 --remote # Force remote query (fails if offline) --- cli/src/commands/info.zig | 102 +++++++++++++++++++++++++++++++++++-- cli/src/net/ws/client.zig | 32 +++++++++++- cli/src/net/ws/opcode.zig | 4 ++ internal/api/ws/handler.go | 65 +++++++++++++++++++++++ 4 files changed, 197 insertions(+), 6 deletions(-) diff --git a/cli/src/commands/info.zig b/cli/src/commands/info.zig index 0b8fc33..2cd0ac0 100644 --- a/cli/src/commands/info.zig +++ b/cli/src/commands/info.zig @@ -4,16 +4,22 @@ const io = @import("../utils/io.zig"); const json = @import("../utils/json.zig"); const manifest = @import("../utils/manifest.zig"); const core = @import("../core.zig"); +const mode = @import("../mode.zig"); +const common = @import("common.zig"); pub const Options = struct { json: bool = false, base: ?[]const u8 = null, + local: bool = false, + remote: bool = false, }; pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var flags = core.flags.CommonFlags{}; var base: ?[]const u8 = null; var target_path: ?[]const u8 = null; + var force_local = false; + var force_remote = false; var i: usize = 0; while (i < args.len) : (i += 1) { @@ -23,6 +29,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } else if (std.mem.eql(u8, arg, "--base") and i + 1 < args.len) { base = args[i + 1]; i += 1; + } else if (std.mem.eql(u8, arg, "--local")) { + force_local = true; + } else if (std.mem.eql(u8, arg, "--remote")) { + force_remote = true; } else if (std.mem.startsWith(u8, arg, "--help")) { return printUsage(); } else if (std.mem.startsWith(u8, arg, "--")) { @@ -40,7 +50,73 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { return printUsage(); } - const manifest_path = manifest.resolvePathWithBase(allocator, target_path.?, base) catch |err| { + // Load config for mode detection + const cfg = try Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Determine execution mode + const mode_result = try mode.detect(allocator, cfg); + const use_remote = if (force_local) false else if (force_remote) true else mode.isOnline(mode_result.mode); + + if (use_remote) { + // Try remote query first + queryRemoteRun(allocator, target_path.?, flags.json) catch |err| { + if (!flags.json) { + std.debug.print("Remote query failed ({}), falling back to local...\n", .{err}); + } + // Fall back to local + try queryLocalRun(allocator, target_path.?, base, flags.json); + }; + } else { + // Local-only mode + try queryLocalRun(allocator, target_path.?, base, flags.json); + } +} + +fn queryRemoteRun(allocator: std.mem.Allocator, run_id: []const u8, json_mode: bool) !void { + var ctx = try common.ConnectionContext.init(allocator); + defer ctx.deinit(); + + if (!json_mode) { + std.debug.print("Remote: connecting...\n", .{}); + } + + try ctx.connect(); + + if (!json_mode) { + std.debug.print("Remote: connected\n", .{}); + } + + try ctx.client.sendQueryRunInfo(run_id, ctx.api_key_hash); + + const response = try ctx.client.receiveMessage(allocator); + defer allocator.free(response); + + // Parse response as JSON + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, response, .{}); + defer parsed.deinit(); + + if (parsed.value != .object) { + return error.InvalidResponse; + } + + // Check for error + if (json.getString(parsed.value.object, "error")) |err_msg| { + if (!json_mode) { + std.debug.print("Error: {s}\n", .{err_msg}); + } + return error.RemoteQueryFailed; + } + + // Display the run info + try displayRunInfo(allocator, parsed.value.object, null, json_mode); +} + +fn queryLocalRun(allocator: std.mem.Allocator, target: []const u8, base: ?[]const u8, json_mode: bool) !void { + const manifest_path = manifest.resolvePathWithBase(allocator, target, base) catch |err| { if (err == error.FileNotFound) { core.output.err("Manifest not found"); } @@ -53,7 +129,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { allocator.free(data); } - if (flags.json) { + if (json_mode) { var out = io.stdoutWriter(); try out.print("{s}\n", .{data}); return; @@ -67,7 +143,16 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { return error.InvalidManifest; } - const root = parsed.value.object; + try displayRunInfo(allocator, parsed.value.object, manifest_path, false); +} + +fn displayRunInfo(allocator: std.mem.Allocator, root: std.json.ObjectMap, manifest_path: ?[]const u8, json_mode: bool) !void { + _ = allocator; + + if (json_mode) { + // Already printed in queryRemoteRun + return; + } const run_id = json.getString(root, "run_id") orelse ""; const task_id = json.getString(root, "task_id") orelse ""; @@ -95,7 +180,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { const finalize_ms = json.getInt(root, "finalize_duration_ms") orelse 0; const total_ms = json.getInt(root, "total_duration_ms") orelse 0; - std.debug.print("run_manifest\t{s}\n", .{manifest_path}); + if (manifest_path) |path| { + std.debug.print("run_manifest\t{s}\n", .{path}); + } if (job_name.len > 0) std.debug.print("job_name\t{s}\n", .{job_name}); if (run_id.len > 0) std.debug.print("run_id\t{s}\n", .{run_id}); @@ -139,7 +226,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { fn printUsage() !void { std.debug.print("Usage:\n", .{}); - std.debug.print("\tml info [--json] [--base ]\n", .{}); + std.debug.print("\tml info [--json] [--base ] [--local] [--remote]\n", .{}); + std.debug.print("\nOptions:\n", .{}); + std.debug.print("\t--json\t\tOutput machine-readable JSON\n", .{}); + std.debug.print("\t--base \tBase path for resolving run manifests\n", .{}); + std.debug.print("\t--local\t\tForce local manifest lookup\n", .{}); + std.debug.print("\t--remote\tForce remote server query (fails if offline)\n", .{}); } test "resolveManifestPath uses run_manifest.json for directories" { diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig index 8aba856..95200fe 100644 --- a/cli/src/net/ws/client.zig +++ b/cli/src/net/ws/client.zig @@ -114,6 +114,7 @@ pub const Client = struct { host: []const u8, port: u16, is_tls: bool = false, + connected: bool = false, pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 { return response.formatPrewarmFromStatusRoot(allocator, root); @@ -181,6 +182,7 @@ pub const Client = struct { .host = try allocator.dupe(u8, host), .port = port, .is_tls = is_tls, + .connected = true, }; } @@ -218,7 +220,10 @@ pub const Client = struct { /// Fully close client - disconnects transport and frees host memory pub fn close(self: *Client) void { - self.disconnect(); + if (self.connected) { + self.disconnect(); + self.connected = false; + } if (self.host.len > 0) { self.allocator.free(self.host); } @@ -997,6 +1002,31 @@ pub const Client = struct { try frame.sendWebSocketFrame(self.transport, buffer); } + pub fn sendQueryRunInfo(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void { + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (run_id.len > 255) return error.PayloadTooLarge; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [run_id_len: u8] [run_id: var] + const total_len = 1 + 16 + 1 + run_id.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.query_run_info); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(run_id.len); + offset += 1; + + @memcpy(buffer[offset .. offset + run_id.len], run_id); + + try frame.sendWebSocketFrame(self.transport, buffer); + } + pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { try validateApiKeyHash(api_key_hash); diff --git a/cli/src/net/ws/opcode.zig b/cli/src/net/ws/opcode.zig index 328ac4a..3800062 100644 --- a/cli/src/net/ws/opcode.zig +++ b/cli/src/net/ws/opcode.zig @@ -44,6 +44,9 @@ pub const Opcode = enum(u8) { // Rerun opcode rerun_request = 0x27, + // Run info query opcode + query_run_info = 0x28, + // Structured response opcodes response_success = 0x10, response_error = 0x11, @@ -91,6 +94,7 @@ pub const dataset_info = Opcode.dataset_info; pub const dataset_search = Opcode.dataset_search; pub const sync_run = Opcode.sync_run; pub const rerun_request = Opcode.rerun_request; +pub const query_run_info = Opcode.query_run_info; pub const response_success = Opcode.response_success; pub const response_error = Opcode.response_error; pub const response_progress = Opcode.response_progress; diff --git a/internal/api/ws/handler.go b/internal/api/ws/handler.go index ad56f5d..3de15da 100644 --- a/internal/api/ws/handler.go +++ b/internal/api/ws/handler.go @@ -71,6 +71,10 @@ const ( OpcodeGetLogs = 0x20 OpcodeStreamLogs = 0x21 + // Run query opcodes + OpcodeQueryJob = 0x23 + OpcodeQueryRunInfo = 0x28 + // OpcodeCompareRuns = 0x30 OpcodeFindRuns = 0x31 @@ -333,6 +337,8 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error { return h.handleExportRun(conn, payload) case OpcodeSetRunOutcome: return h.handleSetRunOutcome(conn, payload) + case OpcodeQueryRunInfo: + return h.handleQueryRunInfo(conn, payload) default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode)) } @@ -843,6 +849,65 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro }) } +// handleQueryRunInfo handles run info queries from the CLI +func (h *Handler) handleQueryRunInfo(conn *websocket.Conn, payload []byte) error { + // Parse payload: [api_key_hash:16][run_id_len:1][run_id:var] + if len(payload) < 16+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "query run info payload too short", "") + } + + user, err := h.Authenticate(payload) + if err != nil { + return h.sendErrorPacket( + conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(), + ) + } + if !h.RequirePermission(user, PermJobsRead) { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") + } + + offset := 16 + runIDLen := int(payload[offset]) + offset++ + if runIDLen <= 0 || len(payload) < offset+runIDLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "") + } + runID := string(payload[offset : offset+runIDLen]) + + h.logger.Info("querying run info", "run_id", runID, "user", user.Name) + + // Check if experiment/run exists + if h.expManager == nil || !h.expManager.ExperimentExists(runID) { + return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID) + } + + // Read metadata + meta, err := h.expManager.ReadMetadata(runID) + if err != nil { + h.logger.Warn("failed to read experiment metadata", "run_id", runID, "error", err) + meta = &experiment.Metadata{CommitID: runID} + } + + // Read manifest + manifest, _ := h.expManager.ReadManifest(runID) + + // Build response + result := map[string]any{ + "run_id": runID, + "job_name": meta.JobName, + "user": meta.User, + "timestamp": meta.Timestamp, + "success": true, + } + + if manifest != nil { + result["overall_sha"] = manifest.OverallSHA + result["files_count"] = len(manifest.Files) + } + + return h.sendSuccessPacket(conn, result) +} + // BroadcastJobUpdate sends job status update to all connected TUI clients func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) { h.clientsMu.RLock()