feat(cli,server): unify info command with remote/local support

Enhance ml info to query server when connected, falling back to local
manifests when offline. Unifies behavior with other commands like run,
exec, and cancel.

CLI changes:
- Add --local and --remote flags for explicit control
- Auto-detect connection state via mode.detect()
- queryRemoteRun(): Query server via WebSocket for run details
- queryLocalRun(): Read local run_manifest.json
- displayRunInfo(): Shared display logic for both sources
- Add connection status indicators (Remote: connecting.../connected)

WebSocket protocol:
- Add query_run_info opcode (0x28) to cli and server
- Add sendQueryRunInfo() method to ws/client.zig
- Protocol: [opcode:1][api_key_hash:16][run_id_len:1][run_id:var]

Server changes:
- Add handleQueryRunInfo() handler to ws/handler.go
- Returns run_id, job_name, user, timestamp, overall_sha, files_count
- Checks PermJobsRead permission
- Looks up run in experiment manager

Usage:
  ml info abc123              # Auto: tries remote, falls back to local
  ml info abc123 --local      # Force local manifest lookup
  ml info abc123 --remote     # Force remote query (fails if offline)
This commit is contained in:
Jeremie Fraeys 2026-03-05 12:07:00 -05:00
parent 68062831b0
commit c6a224d5fc
No known key found for this signature in database
4 changed files with 197 additions and 6 deletions

View file

@ -4,16 +4,22 @@ const io = @import("../utils/io.zig");
const json = @import("../utils/json.zig");
const manifest = @import("../utils/manifest.zig");
const core = @import("../core.zig");
const mode = @import("../mode.zig");
const common = @import("common.zig");
/// Option set for the `info` command.
/// NOTE(review): the visible part of `run` parses its flags into locals rather
/// than into this struct — confirm where `Options` is actually consumed.
pub const Options = struct {
/// Presumably mirrors `--json` (machine-readable output) — confirm against caller.
json: bool = false,
/// Presumably mirrors `--base <path>`; null when no base path is given — confirm.
base: ?[]const u8 = null,
/// Presumably mirrors `--local` (force local manifest lookup) — confirm.
local: bool = false,
/// Presumably mirrors `--remote` (force remote server query) — confirm.
remote: bool = false,
};
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var flags = core.flags.CommonFlags{};
var base: ?[]const u8 = null;
var target_path: ?[]const u8 = null;
var force_local = false;
var force_remote = false;
var i: usize = 0;
while (i < args.len) : (i += 1) {
@ -23,6 +29,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
} else if (std.mem.eql(u8, arg, "--base") and i + 1 < args.len) {
base = args[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--local")) {
force_local = true;
} else if (std.mem.eql(u8, arg, "--remote")) {
force_remote = true;
} else if (std.mem.startsWith(u8, arg, "--help")) {
return printUsage();
} else if (std.mem.startsWith(u8, arg, "--")) {
@ -40,7 +50,73 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
return printUsage();
}
const manifest_path = manifest.resolvePathWithBase(allocator, target_path.?, base) catch |err| {
// Load config for mode detection
const cfg = try Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Determine execution mode
const mode_result = try mode.detect(allocator, cfg);
const use_remote = if (force_local) false else if (force_remote) true else mode.isOnline(mode_result.mode);
if (use_remote) {
// Try remote query first
queryRemoteRun(allocator, target_path.?, flags.json) catch |err| {
if (!flags.json) {
std.debug.print("Remote query failed ({}), falling back to local...\n", .{err});
}
// Fall back to local
try queryLocalRun(allocator, target_path.?, base, flags.json);
};
} else {
// Local-only mode
try queryLocalRun(allocator, target_path.?, base, flags.json);
}
}
/// Query the server for details about `run_id` and report them.
///
/// Opens a connection through `common.ConnectionContext`, sends a
/// query-run-info request, and parses the reply as JSON.
///
/// Output:
/// - `json_mode == false`: connection progress and the run details are printed
///   via `std.debug.print` / `displayRunInfo`.
/// - `json_mode == true`: the raw server response is written to stdout followed
///   by a newline; no progress messages are printed.
///
/// Errors:
/// - `error.InvalidResponse` when the reply is valid JSON but not an object.
/// - `error.RemoteQueryFailed` when the reply object carries an "error" string.
/// - Any error from connecting, sending, receiving, JSON parsing, or writing.
///
/// The response buffer and parsed JSON tree are released before returning.
fn queryRemoteRun(allocator: std.mem.Allocator, run_id: []const u8, json_mode: bool) !void {
    var ctx = try common.ConnectionContext.init(allocator);
    defer ctx.deinit();

    if (!json_mode) {
        std.debug.print("Remote: connecting...\n", .{});
    }
    try ctx.connect();
    if (!json_mode) {
        std.debug.print("Remote: connected\n", .{});
    }

    try ctx.client.sendQueryRunInfo(run_id, ctx.api_key_hash);

    const response = try ctx.client.receiveMessage(allocator);
    defer allocator.free(response);

    // Parse before printing anything so that malformed or error replies are
    // reported as errors instead of being echoed to stdout.
    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, response, .{});
    defer parsed.deinit();

    if (parsed.value != .object) {
        return error.InvalidResponse;
    }
    const root = parsed.value.object;

    if (json.getString(root, "error")) |err_msg| {
        if (!json_mode) {
            std.debug.print("Error: {s}\n", .{err_msg});
        }
        return error.RemoteQueryFailed;
    }

    if (json_mode) {
        // displayRunInfo is a no-op in JSON mode and expects the payload to
        // have been emitted here, so write the raw response to stdout.
        var out = io.stdoutWriter();
        try out.print("{s}\n", .{response});
        return;
    }

    // Remote results have no local manifest path to show.
    try displayRunInfo(allocator, root, null, false);
}
fn queryLocalRun(allocator: std.mem.Allocator, target: []const u8, base: ?[]const u8, json_mode: bool) !void {
const manifest_path = manifest.resolvePathWithBase(allocator, target, base) catch |err| {
if (err == error.FileNotFound) {
core.output.err("Manifest not found");
}
@ -53,7 +129,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
allocator.free(data);
}
if (flags.json) {
if (json_mode) {
var out = io.stdoutWriter();
try out.print("{s}\n", .{data});
return;
@ -67,7 +143,16 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
return error.InvalidManifest;
}
const root = parsed.value.object;
try displayRunInfo(allocator, parsed.value.object, manifest_path, false);
}
fn displayRunInfo(allocator: std.mem.Allocator, root: std.json.ObjectMap, manifest_path: ?[]const u8, json_mode: bool) !void {
_ = allocator;
if (json_mode) {
// Already printed in queryRemoteRun
return;
}
const run_id = json.getString(root, "run_id") orelse "";
const task_id = json.getString(root, "task_id") orelse "";
@ -95,7 +180,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
const finalize_ms = json.getInt(root, "finalize_duration_ms") orelse 0;
const total_ms = json.getInt(root, "total_duration_ms") orelse 0;
std.debug.print("run_manifest\t{s}\n", .{manifest_path});
if (manifest_path) |path| {
std.debug.print("run_manifest\t{s}\n", .{path});
}
if (job_name.len > 0) std.debug.print("job_name\t{s}\n", .{job_name});
if (run_id.len > 0) std.debug.print("run_id\t{s}\n", .{run_id});
@ -139,7 +226,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
/// Print the usage line and option summary for `ml info` via `std.debug.print`.
/// NOTE(review): two usage lines are emitted back to back (one without and one
/// with `[--local] [--remote]`); this looks like the before/after of a diff —
/// confirm that only the second line is intended.
fn printUsage() !void {
std.debug.print("Usage:\n", .{});
std.debug.print("\tml info <run_dir_or_manifest_path_or_id> [--json] [--base <path>]\n", .{});
std.debug.print("\tml info <run_dir_or_manifest_path_or_id> [--json] [--base <path>] [--local] [--remote]\n", .{});
std.debug.print("\nOptions:\n", .{});
std.debug.print("\t--json\t\tOutput machine-readable JSON\n", .{});
std.debug.print("\t--base <path>\tBase path for resolving run manifests\n", .{});
std.debug.print("\t--local\t\tForce local manifest lookup\n", .{});
std.debug.print("\t--remote\tForce remote server query (fails if offline)\n", .{});
}
test "resolveManifestPath uses run_manifest.json for directories" {

View file

@ -114,6 +114,7 @@ pub const Client = struct {
host: []const u8,
port: u16,
is_tls: bool = false,
connected: bool = false,
pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 {
return response.formatPrewarmFromStatusRoot(allocator, root);
@ -181,6 +182,7 @@ pub const Client = struct {
.host = try allocator.dupe(u8, host),
.port = port,
.is_tls = is_tls,
.connected = true,
};
}
@ -218,7 +220,10 @@ pub const Client = struct {
/// Fully close client - disconnects transport and frees host memory
pub fn close(self: *Client) void {
self.disconnect();
if (self.connected) {
self.disconnect();
self.connected = false;
}
if (self.host.len > 0) {
self.allocator.free(self.host);
}
@ -997,6 +1002,31 @@ pub const Client = struct {
try frame.sendWebSocketFrame(self.transport, buffer);
}
/// Send a `query_run_info` request for `run_id` over the client's transport.
///
/// Wire layout of the frame payload:
///   [opcode: u8] [api_key_hash: 16 bytes] [run_id_len: u8] [run_id: run_id_len bytes]
///
/// Errors:
/// - `error.InvalidApiKeyHash` when `api_key_hash` is not exactly 16 bytes.
/// - `error.PayloadTooLarge` when `run_id` is longer than 255 bytes (the
///   length is encoded in a single byte).
/// - Allocation failure, or any error from `frame.sendWebSocketFrame`.
///
/// The temporary message buffer is freed before returning.
pub fn sendQueryRunInfo(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void {
    if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
    if (run_id.len > 255) return error.PayloadTooLarge;

    // Fixed-size prefix: opcode (1) + api key hash (16) + run id length (1).
    const hash_start = 1;
    const hash_end = hash_start + 16;
    const header_len = hash_end + 1;

    // `const`: the slice itself is never reassigned, only its contents are written.
    const buffer = try self.allocator.alloc(u8, header_len + run_id.len);
    defer self.allocator.free(buffer);

    buffer[0] = @intFromEnum(opcode.query_run_info);
    @memcpy(buffer[hash_start..hash_end], api_key_hash);
    // Safe narrowing: run_id.len was checked to be <= 255 above.
    buffer[hash_end] = @intCast(run_id.len);
    @memcpy(buffer[header_len..], run_id);

    try frame.sendWebSocketFrame(self.transport, buffer);
}
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
try validateApiKeyHash(api_key_hash);

View file

@ -44,6 +44,9 @@ pub const Opcode = enum(u8) {
// Rerun opcode
rerun_request = 0x27,
// Run info query opcode
query_run_info = 0x28,
// Structured response opcodes
response_success = 0x10,
response_error = 0x11,
@ -91,6 +94,7 @@ pub const dataset_info = Opcode.dataset_info;
pub const dataset_search = Opcode.dataset_search;
pub const sync_run = Opcode.sync_run;
pub const rerun_request = Opcode.rerun_request;
pub const query_run_info = Opcode.query_run_info;
pub const response_success = Opcode.response_success;
pub const response_error = Opcode.response_error;
pub const response_progress = Opcode.response_progress;

View file

@ -71,6 +71,10 @@ const (
OpcodeGetLogs = 0x20
OpcodeStreamLogs = 0x21
// Run query opcodes
OpcodeQueryJob = 0x23
OpcodeQueryRunInfo = 0x28
//
OpcodeCompareRuns = 0x30
OpcodeFindRuns = 0x31
@ -333,6 +337,8 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
return h.handleExportRun(conn, payload)
case OpcodeSetRunOutcome:
return h.handleSetRunOutcome(conn, payload)
case OpcodeQueryRunInfo:
return h.handleQueryRunInfo(conn, payload)
default:
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
}
@ -843,6 +849,65 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
})
}
// handleQueryRunInfo handles run info queries from the CLI.
//
// Expected payload layout: [api_key_hash:16][run_id_len:1][run_id:var].
//
// The caller is authenticated and must hold PermJobsRead. On success a packet
// containing run_id, job_name, user, timestamp and success=true is sent; when
// a manifest is available, overall_sha and files_count are included as well.
// Malformed payloads, authentication or permission failures, and unknown runs
// are answered with an error packet. The returned error is whatever the packet
// send returns.
func (h *Handler) handleQueryRunInfo(conn *websocket.Conn, payload []byte) error {
	// Need at least the 16-byte key hash plus the run ID length byte.
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "query run info payload too short", "")
	}

	user, err := h.Authenticate(payload)
	if err != nil {
		return h.sendErrorPacket(
			conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
		)
	}
	if !h.RequirePermission(user, PermJobsRead) {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
	}

	offset := 16
	runIDLen := int(payload[offset])
	offset++
	// Reject empty IDs and lengths that run past the end of the payload.
	if runIDLen <= 0 || len(payload) < offset+runIDLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
	}
	runID := string(payload[offset : offset+runIDLen])

	h.logger.Info("querying run info", "run_id", runID, "user", user.Name)

	// Check if experiment/run exists
	if h.expManager == nil || !h.expManager.ExperimentExists(runID) {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID)
	}

	// Metadata is best-effort: on failure, answer with a stub that carries only
	// the run ID rather than failing the whole query.
	meta, err := h.expManager.ReadMetadata(runID)
	if err != nil {
		h.logger.Warn("failed to read experiment metadata", "run_id", runID, "error", err)
		meta = &experiment.Metadata{CommitID: runID}
	}

	// The manifest is optional too, but a read failure is logged instead of
	// being silently dropped so missing overall_sha/files_count can be diagnosed.
	manifest, err := h.expManager.ReadManifest(runID)
	if err != nil {
		h.logger.Warn("failed to read experiment manifest", "run_id", runID, "error", err)
	}

	// Build response
	result := map[string]any{
		"run_id":    runID,
		"job_name":  meta.JobName,
		"user":      meta.User,
		"timestamp": meta.Timestamp,
		"success":   true,
	}
	if manifest != nil {
		result["overall_sha"] = manifest.OverallSHA
		result["files_count"] = len(manifest.Files)
	}

	return h.sendSuccessPacket(conn, result)
}
// BroadcastJobUpdate sends job status update to all connected TUI clients
func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) {
h.clientsMu.RLock()