feat(cli,server): unify info command with remote/local support
Enhance ml info to query server when connected, falling back to local manifests when offline. Unifies behavior with other commands like run, exec, and cancel. CLI changes: - Add --local and --remote flags for explicit control - Auto-detect connection state via mode.detect() - queryRemoteRun(): Query server via WebSocket for run details - queryLocalRun(): Read local run_manifest.json - displayRunInfo(): Shared display logic for both sources - Add connection status indicators (Remote: connecting.../connected) WebSocket protocol: - Add query_run_info opcode (0x28) to cli and server - Add sendQueryRunInfo() method to ws/client.zig - Protocol: [opcode:1][api_key_hash:16][run_id_len:1][run_id:var] Server changes: - Add handleQueryRunInfo() handler to ws/handler.go - Returns run_id, job_name, user, timestamp, overall_sha, files_count - Checks PermJobsRead permission - Looks up run in experiment manager Usage: ml info abc123 # Auto: tries remote, falls back to local ml info abc123 --local # Force local manifest lookup ml info abc123 --remote # Force remote query (fails if offline)
This commit is contained in:
parent
68062831b0
commit
c6a224d5fc
4 changed files with 197 additions and 6 deletions
|
|
@ -4,16 +4,22 @@ const io = @import("../utils/io.zig");
|
|||
const json = @import("../utils/json.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
const core = @import("../core.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const common = @import("common.zig");
|
||||
|
||||
/// Option set for the `ml info` command.
/// NOTE(review): not referenced by the visible parts of `run`, which tracks
/// the same switches in local variables — confirm whether this is still used.
pub const Options = struct {
|
||||
/// Presumably mirrors `--json` (machine-readable output) — TODO confirm.
json: bool = false,
|
||||
/// Presumably mirrors `--base <path>`; null when no base path is given — TODO confirm.
base: ?[]const u8 = null,
|
||||
/// Presumably mirrors `--local` (force local manifest lookup) — TODO confirm.
local: bool = false,
|
||||
/// Presumably mirrors `--remote` (force remote server query) — TODO confirm.
remote: bool = false,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var base: ?[]const u8 = null;
|
||||
var target_path: ?[]const u8 = null;
|
||||
var force_local = false;
|
||||
var force_remote = false;
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
|
|
@ -23,6 +29,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
} else if (std.mem.eql(u8, arg, "--base") and i + 1 < args.len) {
|
||||
base = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--local")) {
|
||||
force_local = true;
|
||||
} else if (std.mem.eql(u8, arg, "--remote")) {
|
||||
force_remote = true;
|
||||
} else if (std.mem.startsWith(u8, arg, "--help")) {
|
||||
return printUsage();
|
||||
} else if (std.mem.startsWith(u8, arg, "--")) {
|
||||
|
|
@ -40,7 +50,73 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
return printUsage();
|
||||
}
|
||||
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target_path.?, base) catch |err| {
|
||||
// Load config for mode detection
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
// Determine execution mode
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
const use_remote = if (force_local) false else if (force_remote) true else mode.isOnline(mode_result.mode);
|
||||
|
||||
if (use_remote) {
|
||||
// Try remote query first
|
||||
queryRemoteRun(allocator, target_path.?, flags.json) catch |err| {
|
||||
if (!flags.json) {
|
||||
std.debug.print("Remote query failed ({}), falling back to local...\n", .{err});
|
||||
}
|
||||
// Fall back to local
|
||||
try queryLocalRun(allocator, target_path.?, base, flags.json);
|
||||
};
|
||||
} else {
|
||||
// Local-only mode
|
||||
try queryLocalRun(allocator, target_path.?, base, flags.json);
|
||||
}
|
||||
}
|
||||
|
||||
/// Query the server for details about `run_id` and present the result.
///
/// Opens a connection through `common.ConnectionContext`, sends a
/// `query_run_info` request and parses the reply as a JSON object.
/// In JSON mode the raw server reply is written to stdout and no progress
/// lines are printed; otherwise connection status is reported and the
/// fields are rendered by `displayRunInfo` (with no manifest path).
///
/// Errors:
/// - `error.InvalidResponse` when the reply is not a JSON object.
/// - `error.RemoteQueryFailed` when the reply carries an "error" string.
/// - Any error from connecting, sending, receiving, parsing or writing.
fn queryRemoteRun(allocator: std.mem.Allocator, run_id: []const u8, json_mode: bool) !void {
    var ctx = try common.ConnectionContext.init(allocator);
    defer ctx.deinit();

    // Status lines are suppressed in JSON mode to keep the output machine-readable.
    if (!json_mode) {
        std.debug.print("Remote: connecting...\n", .{});
    }

    try ctx.connect();

    if (!json_mode) {
        std.debug.print("Remote: connected\n", .{});
    }

    try ctx.client.sendQueryRunInfo(run_id, ctx.api_key_hash);

    const response = try ctx.client.receiveMessage(allocator);
    defer allocator.free(response);

    // Parse response as JSON
    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, response, .{});
    defer parsed.deinit();

    if (parsed.value != .object) {
        return error.InvalidResponse;
    }
    const root = parsed.value.object;

    // A reply carrying an "error" string is treated as a failed query.
    if (json.getString(root, "error")) |err_msg| {
        if (!json_mode) {
            std.debug.print("Error: {s}\n", .{err_msg});
        }
        return error.RemoteQueryFailed;
    }

    if (json_mode) {
        // displayRunInfo prints nothing in JSON mode, so the reply has to be
        // emitted here; otherwise `--json` against a server yields no output.
        var out = io.stdoutWriter();
        try out.print("{s}\n", .{response});
        return;
    }

    // Display the run info
    try displayRunInfo(allocator, root, null, json_mode);
}
|
||||
|
||||
fn queryLocalRun(allocator: std.mem.Allocator, target: []const u8, base: ?[]const u8, json_mode: bool) !void {
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target, base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
core.output.err("Manifest not found");
|
||||
}
|
||||
|
|
@ -53,7 +129,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
allocator.free(data);
|
||||
}
|
||||
|
||||
if (flags.json) {
|
||||
if (json_mode) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{data});
|
||||
return;
|
||||
|
|
@ -67,7 +143,16 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
return error.InvalidManifest;
|
||||
}
|
||||
|
||||
const root = parsed.value.object;
|
||||
try displayRunInfo(allocator, parsed.value.object, manifest_path, false);
|
||||
}
|
||||
|
||||
fn displayRunInfo(allocator: std.mem.Allocator, root: std.json.ObjectMap, manifest_path: ?[]const u8, json_mode: bool) !void {
|
||||
_ = allocator;
|
||||
|
||||
if (json_mode) {
|
||||
// Already printed in queryRemoteRun
|
||||
return;
|
||||
}
|
||||
|
||||
const run_id = json.getString(root, "run_id") orelse "";
|
||||
const task_id = json.getString(root, "task_id") orelse "";
|
||||
|
|
@ -95,7 +180,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
const finalize_ms = json.getInt(root, "finalize_duration_ms") orelse 0;
|
||||
const total_ms = json.getInt(root, "total_duration_ms") orelse 0;
|
||||
|
||||
std.debug.print("run_manifest\t{s}\n", .{manifest_path});
|
||||
if (manifest_path) |path| {
|
||||
std.debug.print("run_manifest\t{s}\n", .{path});
|
||||
}
|
||||
|
||||
if (job_name.len > 0) std.debug.print("job_name\t{s}\n", .{job_name});
|
||||
if (run_id.len > 0) std.debug.print("run_id\t{s}\n", .{run_id});
|
||||
|
|
@ -139,7 +226,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
|
||||
/// Print the `ml info` usage synopsis and option summary via `std.debug.print`.
///
/// The synopsis is printed once, in its current form listing `--local` and
/// `--remote` alongside `--json` and `--base`.
fn printUsage() !void {
    std.debug.print("Usage:\n", .{});
    std.debug.print("\tml info <run_dir_or_manifest_path_or_id> [--json] [--base <path>] [--local] [--remote]\n", .{});
    std.debug.print("\nOptions:\n", .{});
    std.debug.print("\t--json\t\tOutput machine-readable JSON\n", .{});
    std.debug.print("\t--base <path>\tBase path for resolving run manifests\n", .{});
    std.debug.print("\t--local\t\tForce local manifest lookup\n", .{});
    std.debug.print("\t--remote\tForce remote server query (fails if offline)\n", .{});
}
|
||||
|
||||
test "resolveManifestPath uses run_manifest.json for directories" {
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ pub const Client = struct {
|
|||
host: []const u8,
|
||||
port: u16,
|
||||
is_tls: bool = false,
|
||||
connected: bool = false,
|
||||
|
||||
pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 {
|
||||
return response.formatPrewarmFromStatusRoot(allocator, root);
|
||||
|
|
@ -181,6 +182,7 @@ pub const Client = struct {
|
|||
.host = try allocator.dupe(u8, host),
|
||||
.port = port,
|
||||
.is_tls = is_tls,
|
||||
.connected = true,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -218,7 +220,10 @@ pub const Client = struct {
|
|||
|
||||
/// Fully close client - disconnects transport and frees host memory
|
||||
pub fn close(self: *Client) void {
|
||||
self.disconnect();
|
||||
if (self.connected) {
|
||||
self.disconnect();
|
||||
self.connected = false;
|
||||
}
|
||||
if (self.host.len > 0) {
|
||||
self.allocator.free(self.host);
|
||||
}
|
||||
|
|
@ -997,6 +1002,31 @@ pub const Client = struct {
|
|||
try frame.sendWebSocketFrame(self.transport, buffer);
|
||||
}
|
||||
|
||||
/// Send a `query_run_info` request for `run_id` over the client's transport.
///
/// Wire layout:
///   [opcode: u8][api_key_hash: 16 bytes][run_id_len: u8][run_id: run_id_len bytes]
///
/// Returns `error.InvalidApiKeyHash` unless `api_key_hash` is exactly 16 bytes
/// and `error.PayloadTooLarge` when `run_id` is longer than 255 bytes (its
/// length is encoded in a single byte). Allocation and frame-send errors are
/// propagated. The temporary message buffer is freed before returning.
pub fn sendQueryRunInfo(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void {
    const hash_len = 16;
    if (api_key_hash.len != hash_len) return error.InvalidApiKeyHash;
    if (run_id.len > 255) return error.PayloadTooLarge;

    // Fixed-size prefix: opcode byte, key hash, run-id length byte.
    const header_len = 1 + hash_len + 1;
    const message = try self.allocator.alloc(u8, header_len + run_id.len);
    defer self.allocator.free(message);

    message[0] = @intFromEnum(opcode.query_run_info);
    @memcpy(message[1 .. 1 + hash_len], api_key_hash);
    message[1 + hash_len] = @intCast(run_id.len);
    @memcpy(message[header_len..], run_id);

    try frame.sendWebSocketFrame(self.transport, message);
}
|
||||
|
||||
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,9 @@ pub const Opcode = enum(u8) {
|
|||
// Rerun opcode
|
||||
rerun_request = 0x27,
|
||||
|
||||
// Run info query opcode
|
||||
query_run_info = 0x28,
|
||||
|
||||
// Structured response opcodes
|
||||
response_success = 0x10,
|
||||
response_error = 0x11,
|
||||
|
|
@ -91,6 +94,7 @@ pub const dataset_info = Opcode.dataset_info;
|
|||
pub const dataset_search = Opcode.dataset_search;
|
||||
pub const sync_run = Opcode.sync_run;
|
||||
pub const rerun_request = Opcode.rerun_request;
|
||||
pub const query_run_info = Opcode.query_run_info;
|
||||
pub const response_success = Opcode.response_success;
|
||||
pub const response_error = Opcode.response_error;
|
||||
pub const response_progress = Opcode.response_progress;
|
||||
|
|
|
|||
|
|
@ -71,6 +71,10 @@ const (
|
|||
OpcodeGetLogs = 0x20
|
||||
OpcodeStreamLogs = 0x21
|
||||
|
||||
// Run query opcodes
|
||||
OpcodeQueryJob = 0x23
|
||||
OpcodeQueryRunInfo = 0x28
|
||||
|
||||
//
|
||||
OpcodeCompareRuns = 0x30
|
||||
OpcodeFindRuns = 0x31
|
||||
|
|
@ -333,6 +337,8 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
|||
return h.handleExportRun(conn, payload)
|
||||
case OpcodeSetRunOutcome:
|
||||
return h.handleSetRunOutcome(conn, payload)
|
||||
case OpcodeQueryRunInfo:
|
||||
return h.handleQueryRunInfo(conn, payload)
|
||||
default:
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
||||
}
|
||||
|
|
@ -843,6 +849,65 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
|
|||
})
|
||||
}
|
||||
|
||||
// handleQueryRunInfo handles run info queries from the CLI
|
||||
func (h *Handler) handleQueryRunInfo(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [api_key_hash:16][run_id_len:1][run_id:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "query run info payload too short", "")
|
||||
}
|
||||
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error(),
|
||||
)
|
||||
}
|
||||
if !h.RequirePermission(user, PermJobsRead) {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
runIDLen := int(payload[offset])
|
||||
offset++
|
||||
if runIDLen <= 0 || len(payload) < offset+runIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid run ID length", "")
|
||||
}
|
||||
runID := string(payload[offset : offset+runIDLen])
|
||||
|
||||
h.logger.Info("querying run info", "run_id", runID, "user", user.Name)
|
||||
|
||||
// Check if experiment/run exists
|
||||
if h.expManager == nil || !h.expManager.ExperimentExists(runID) {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run not found", runID)
|
||||
}
|
||||
|
||||
// Read metadata
|
||||
meta, err := h.expManager.ReadMetadata(runID)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to read experiment metadata", "run_id", runID, "error", err)
|
||||
meta = &experiment.Metadata{CommitID: runID}
|
||||
}
|
||||
|
||||
// Read manifest
|
||||
manifest, _ := h.expManager.ReadManifest(runID)
|
||||
|
||||
// Build response
|
||||
result := map[string]any{
|
||||
"run_id": runID,
|
||||
"job_name": meta.JobName,
|
||||
"user": meta.User,
|
||||
"timestamp": meta.Timestamp,
|
||||
"success": true,
|
||||
}
|
||||
|
||||
if manifest != nil {
|
||||
result["overall_sha"] = manifest.OverallSHA
|
||||
result["files_count"] = len(manifest.Files)
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, result)
|
||||
}
|
||||
|
||||
// BroadcastJobUpdate sends job status update to all connected TUI clients
|
||||
func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) {
|
||||
h.clientsMu.RLock()
|
||||
|
|
|
|||
Loading…
Reference in a new issue