From 1147958e15a05c45827b4193278259b46ed9b88d Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 16 Feb 2026 20:38:08 -0500 Subject: [PATCH] feat: enhance CLI with improved commands and WebSocket handling - Refactor command structure for better organization - Improve WebSocket client frame handling - Add response handler improvements - Update queue, requeue, and status commands - Add security module for CLI authentication --- cli/src/commands.zig | 2 + cli/src/commands/logs.zig | 140 +++++++++++ cli/src/commands/queue.zig | 263 +++++++++++++++++++-- cli/src/commands/requeue.zig | 90 +++++++- cli/src/commands/status.zig | 93 +------- cli/src/config.zig | 8 + cli/src/main.zig | 6 + cli/src/net/ws/client.zig | 101 +++++++- cli/src/net/ws/frame.zig | 21 +- cli/src/net/ws/opcode.zig | 8 + cli/src/net/ws/response_handlers.zig | 332 +++++++++++++-------------- cli/src/security.zig | 59 +++++ 12 files changed, 843 insertions(+), 280 deletions(-) create mode 100644 cli/src/commands/logs.zig create mode 100644 cli/src/security.zig diff --git a/cli/src/commands.zig b/cli/src/commands.zig index e5c1e16..072f1a9 100644 --- a/cli/src/commands.zig +++ b/cli/src/commands.zig @@ -1,10 +1,12 @@ pub const annotate = @import("commands/annotate.zig"); pub const cancel = @import("commands/cancel.zig"); pub const dataset = @import("commands/dataset.zig"); +pub const debug = @import("commands/debug.zig"); pub const experiment = @import("commands/experiment.zig"); pub const info = @import("commands/info.zig"); pub const init = @import("commands/init.zig"); pub const jupyter = @import("commands/jupyter.zig"); +pub const logs = @import("commands/logs.zig"); pub const monitor = @import("commands/monitor.zig"); pub const narrative = @import("commands/narrative.zig"); pub const prune = @import("commands/prune.zig"); diff --git a/cli/src/commands/logs.zig b/cli/src/commands/logs.zig new file mode 100644 index 0000000..10e4d01 --- /dev/null +++ b/cli/src/commands/logs.zig @@ -0,0 +1,140 @@ 
+const std = @import("std"); +const colors = @import("../utils/colors.zig"); +const Config = @import("../config.zig").Config; +const crypto = @import("../utils/crypto.zig"); +const ws = @import("../net/ws/client.zig"); +const protocol = @import("../net/protocol.zig"); + +/// Logs command - fetch and display job logs via WebSocket API +pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { + if (argv.len == 0) { + try printUsage(); + return error.InvalidArgs; + } + if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) { + try printUsage(); + return; + } + + const target = argv[0]; + + // Parse optional flags + var follow = false; + var tail: ?usize = null; + + var i: usize = 1; + while (i < argv.len) : (i += 1) { + const a = argv[i]; + if (std.mem.eql(u8, a, "-f") or std.mem.eql(u8, a, "--follow")) { + follow = true; + } else if (std.mem.eql(u8, a, "-n") and i + 1 < argv.len) { + tail = try std.fmt.parseInt(usize, argv[i + 1], 10); + i += 1; + } else if (std.mem.eql(u8, a, "--tail") and i + 1 < argv.len) { + tail = try std.fmt.parseInt(usize, argv[i + 1], 10); + i += 1; + } else { + colors.printError("Unknown option: {s}\n", .{a}); + return error.InvalidArgs; + } + } + + const cfg = try Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + colors.printInfo("Fetching logs for: {s}\n", .{target}); + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + // Send appropriate request based on follow flag + if (follow) { + try client.sendStreamLogs(target, api_key_hash); + } else { + try client.sendGetLogs(target, api_key_hash); + } + + // Receive and display response + const message = try 
client.receiveMessage(allocator); + defer allocator.free(message); + + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + // Fallback: treat as plain text response + std.debug.print("{s}\n", .{message}); + return; + }; + defer { + if (packet.success_message) |m| allocator.free(m); + if (packet.error_message) |m| allocator.free(m); + if (packet.error_details) |m| allocator.free(m); + if (packet.data_payload) |m| allocator.free(m); + if (packet.data_type) |m| allocator.free(m); + } + + switch (packet.packet_type) { + .data => { + if (packet.data_payload) |payload| { + // Parse JSON response + const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch { + std.debug.print("{s}\n", .{payload}); + return; + }; + defer parsed.deinit(); + + const root = parsed.value.object; + + // Display logs + if (root.get("logs")) |logs| { + if (logs == .string) { + std.debug.print("{s}\n", .{logs.string}); + } + } else if (root.get("message")) |msg| { + if (msg == .string) { + colors.printInfo("{s}\n", .{msg.string}); + } + } + + // Show truncation warning if applicable + if (root.get("truncated")) |truncated| { + if (truncated == .bool and truncated.bool) { + if (root.get("total_lines")) |total| { + if (total == .integer) { + colors.printWarning("\n[Output truncated. 
Total lines: {d}]\n", .{total.integer}); + } + } + } + } + } + }, + .error_packet => { + const err_msg = packet.error_message orelse "Unknown error"; + colors.printError("Error: {s}\n", .{err_msg}); + return error.ServerError; + }, + else => { + if (packet.success_message) |msg| { + colors.printSuccess("{s}\n", .{msg}); + } else { + colors.printInfo("Logs retrieved successfully\n", .{}); + } + }, + } +} + +fn printUsage() !void { + colors.printInfo("Usage:\n", .{}); + colors.printInfo(" ml logs [-f|--follow] [-n |--tail ]\n", .{}); + colors.printInfo("\nExamples:\n", .{}); + colors.printInfo(" ml logs abc123 # Show full logs\n", .{}); + colors.printInfo(" ml logs abc123 -f # Follow logs in real-time\n", .{}); + colors.printInfo(" ml logs abc123 -n 100 # Show last 100 lines\n", .{}); +} diff --git a/cli/src/commands/queue.zig b/cli/src/commands/queue.zig index de312b9..d442d8c 100644 --- a/cli/src/commands/queue.zig +++ b/cli/src/commands/queue.zig @@ -4,6 +4,7 @@ const ws = @import("../net/ws/client.zig"); const colors = @import("../utils/colors.zig"); const history = @import("../utils/history.zig"); const crypto = @import("../utils/crypto.zig"); +const protocol = @import("../net/protocol.zig"); const stdcrypto = std.crypto; pub const TrackingConfig = struct { @@ -36,6 +37,7 @@ pub const QueueOptions = struct { validate: bool = false, explain: bool = false, json: bool = false, + force: bool = false, cpu: u8 = 2, memory: u8 = 8, gpu: u8 = 0, @@ -226,6 +228,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { options.explain = true; } else if (std.mem.eql(u8, arg, "--json")) { options.json = true; + } else if (std.mem.eql(u8, arg, "--force")) { + options.force = true; } else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < pre.len) { options.cpu = try std.fmt.parseInt(u8, pre[i + 1], 10); i += 1; @@ -424,6 +428,7 @@ fn queueSingleJob( api_key_hash, args_str, note_str, + options.force, options.cpu, options.memory, options.gpu, @@ -436,6 +441,7 
@@ fn queueSingleJob( priority, api_key_hash, args_str, + options.force, options.cpu, options.memory, options.gpu, @@ -468,17 +474,70 @@ fn queueSingleJob( ); } - // Receive structured response - try client.receiveAndHandleResponse(allocator, "Job queue"); + // Receive and handle response with duplicate detection + const message = try client.receiveMessage(allocator); + defer allocator.free(message); - history.record(allocator, job_name, commit_hex) catch |err| { - colors.printWarning("Warning: failed to record job in history ({})\n", .{err}); + // Try to parse as structured packet first + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + // Fallback: handle as plain text/JSON + if (message.len > 0 and message[0] == '{') { + try handleDuplicateResponse(allocator, message, job_name, commit_hex, options); + } else { + colors.printInfo("Server response: {s}\n", .{message}); + } + return; }; + defer { + if (packet.success_message) |m| allocator.free(m); + if (packet.error_message) |m| allocator.free(m); + if (packet.error_details) |m| allocator.free(m); + if (packet.data_payload) |m| allocator.free(m); + if (packet.data_type) |m| allocator.free(m); + if (packet.status_data) |m| allocator.free(m); + } - if (print_next_steps) { - const next_steps = try formatNextSteps(allocator, job_name, commit_hex); - defer allocator.free(next_steps); - colors.printInfo("\n{s}", .{next_steps}); + switch (packet.packet_type) { + .success => { + history.record(allocator, job_name, commit_hex) catch |err| { + colors.printWarning("Warning: failed to record job in history ({})", .{err}); + }; + if (options.json) { + std.debug.print("{{\"success\":true,\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"status\":\"queued\"}}\n", .{ job_name, commit_hex }); + } else { + colors.printSuccess("✓ Job queued successfully: {s}\n", .{job_name}); + if (print_next_steps) { + const next_steps = try formatNextSteps(allocator, job_name, commit_hex); + defer 
allocator.free(next_steps); + colors.printInfo("\n{s}", .{next_steps}); + } + } + }, + .data => { + if (packet.data_payload) |payload| { + try handleDuplicateResponse(allocator, payload, job_name, commit_hex, options); + } + }, + .error_packet => { + const err_msg = packet.error_message orelse "Unknown error"; + if (options.json) { + std.debug.print("{{\"success\":false,\"error\":\"{s}\"}}\n", .{err_msg}); + } else { + colors.printError("Error: {s}\n", .{err_msg}); + } + return error.ServerError; + }, + else => { + try client.handleResponsePacket(packet, "Job queue"); + history.record(allocator, job_name, commit_hex) catch |err| { + colors.printWarning("Warning: failed to record job in history ({})", .{err}); + }; + if (print_next_steps) { + const next_steps = try formatNextSteps(allocator, job_name, commit_hex); + defer allocator.free(next_steps); + colors.printInfo("\n{s}", .{next_steps}); + } + }, } } @@ -496,10 +555,11 @@ fn printUsage() !void { colors.printInfo(" --note Human notes (stored in run manifest as metadata.note)\n", .{}); colors.printInfo(" -- Extra runner args (alternative to --args)\n", .{}); colors.printInfo("\nSpecial Modes:\n", .{}); - colors.printInfo(" --dry-run Show what would be submitted\n", .{}); - colors.printInfo(" --validate Validate experiment without submitting\n", .{}); + colors.printInfo(" --dry-run Show what would be queued\n", .{}); + colors.printInfo(" --validate Validate experiment without queuing\n", .{}); colors.printInfo(" --explain Explain what will happen\n", .{}); colors.printInfo(" --json Output structured JSON\n", .{}); + colors.printInfo(" --force Queue even if duplicate exists\n", .{}); colors.printInfo("\nTracking:\n", .{}); colors.printInfo(" --mlflow Enable MLflow (sidecar)\n", .{}); colors.printInfo(" --mlflow-uri Enable MLflow (remote)\n", .{}); @@ -613,7 +673,7 @@ fn validateJob( colors.printInfo(" requirements.txt {s}\n", .{req_status}); if (overall_valid) { - colors.printSuccess(" ✓ Validation passed - job is 
ready to submit\n", .{}); + colors.printSuccess(" ✓ Validation passed - job is ready to queue\n", .{}); } else { colors.printError(" ✗ Validation failed - missing required files\n", .{}); } @@ -642,10 +702,10 @@ fn dryRunJob( const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"dry_run\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable; try stdout_file.writeAll(formatted); try writeJSONNullableString(&stdout_file, options.gpu_memory); - try stdout_file.writeAll("}},\"would_submit\":true}}\n"); + try stdout_file.writeAll("}},\"would_queue\":true}}\n"); return; } else { - colors.printInfo("Dry Run - Job Submission Preview:\n", .{}); + colors.printInfo("Dry Run - Job Queue Preview:\n", .{}); colors.printInfo(" Job Name: {s}\n", .{job_name}); colors.printInfo(" Commit ID: {s}\n", .{commit_display}); colors.printInfo(" Priority: {d}\n", .{priority}); @@ -655,9 +715,9 @@ fn dryRunJob( colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu}); colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"}); - colors.printInfo(" Action: Would submit job to queue\n", .{}); + colors.printInfo(" Action: Would queue job\n", .{}); colors.printInfo(" Estimated queue time: 2-5 minutes\n", .{}); - colors.printSuccess(" ✓ Dry run completed - no job was actually submitted\n", .{}); + colors.printSuccess(" ✓ Dry run completed - no job was actually queued\n", .{}); } } @@ -697,6 +757,179 @@ fn writeJSONString(writer: anytype, s: []const u8) !void { try writer.writeAll("\""); } +fn handleDuplicateResponse( + allocator: std.mem.Allocator, + payload: []const u8, + job_name: []const u8, + commit_hex: []const u8, + options: *const QueueOptions, +) !void { + const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch { + if (options.json) { + 
std.debug.print("{s}\n", .{payload}); + } else { + colors.printInfo("Server response: {s}\n", .{payload}); + } + return; + }; + defer parsed.deinit(); + + const root = parsed.value.object; + const is_dup = root.get("duplicate") != null and root.get("duplicate").?.bool; + if (!is_dup) { + if (options.json) { + std.debug.print("{s}\n", .{payload}); + } else { + colors.printSuccess("✓ Job queued: {s}\n", .{job_name}); + } + return; + } + + const existing_id = root.get("existing_id").?.string; + const status = root.get("status").?.string; + const queued_by = root.get("queued_by").?.string; + const queued_at = root.get("queued_at").?.integer; + const now = std.time.timestamp(); + const minutes_ago = @divTrunc(now - queued_at, 60); + + if (std.mem.eql(u8, status, "queued") or std.mem.eql(u8, status, "running")) { + if (options.json) { + std.debug.print("{{\"success\":true,\"duplicate\":true,\"existing_id\":\"{s}\",\"status\":\"{s}\",\"queued_by\":\"{s}\",\"minutes_ago\":{d},\"suggested_action\":\"watch\"}}\n", .{ existing_id, status, queued_by, minutes_ago }); + } else { + colors.printInfo("\n→ Identical job already in progress: {s}\n", .{existing_id[0..8]}); + colors.printInfo(" Queued by {s}, {d} minutes ago\n", .{ queued_by, minutes_ago }); + colors.printInfo(" Status: {s}\n", .{status}); + colors.printInfo("\n Watch: ml watch {s}\n", .{existing_id[0..8]}); + colors.printInfo(" Rerun: ml queue {s} --commit {s} --force\n", .{ job_name, commit_hex }); + } + } else if (std.mem.eql(u8, status, "completed")) { + const duration_sec = root.get("duration_seconds").?.integer; + const duration_min = @divTrunc(duration_sec, 60); + if (options.json) { + std.debug.print("{{\"success\":true,\"duplicate\":true,\"existing_id\":\"{s}\",\"status\":\"completed\",\"queued_by\":\"{s}\",\"duration_minutes\":{d},\"suggested_action\":\"show\"}}\n", .{ existing_id, queued_by, duration_min }); + } else { + colors.printInfo("\n→ Identical job already completed: {s}\n", .{existing_id[0..8]}); + 
colors.printInfo(" Queued by {s}\n", .{queued_by}); + const metrics = root.get("metrics"); + if (metrics) |m| { + if (m == .object) { + colors.printInfo("\n Results:\n", .{}); + if (m.object.get("accuracy")) |v| { + if (v == .float) colors.printInfo(" accuracy: {d:.3}\n", .{v.float}); + } + if (m.object.get("loss")) |v| { + if (v == .float) colors.printInfo(" loss: {d:.3}\n", .{v.float}); + } + } + } + colors.printInfo(" duration: {d}m\n", .{duration_min}); + colors.printInfo("\n Inspect: ml experiment show {s}\n", .{existing_id[0..8]}); + colors.printInfo(" Rerun: ml queue {s} --commit {s} --force\n", .{ job_name, commit_hex }); + } + } else if (std.mem.eql(u8, status, "failed")) { + const error_reason = root.get("error_reason").?.string; + const failure_class = if (root.get("failure_class")) |fc| fc.string else "unknown"; + const exit_code = if (root.get("exit_code")) |ec| ec.integer else 0; + const signal = if (root.get("signal")) |s| s.string else ""; + const log_tail = if (root.get("log_tail")) |lt| lt.string else ""; + const suggestion = if (root.get("suggestion")) |s| s.string else ""; + const retry_count = if (root.get("retry_count")) |rc| rc.integer else 0; + const retry_cap = if (root.get("retry_cap")) |rc| rc.integer else 3; + const auto_retryable = if (root.get("auto_retryable")) |ar| ar.bool else false; + const requires_fix = if (root.get("requires_fix")) |rf| rf.bool else false; + + if (options.json) { + const suggested_action = if (requires_fix) "fix" else if (auto_retryable) "wait" else "requeue"; + std.debug.print("{{\"success\":true,\"duplicate\":true,\"existing_id\":\"{s}\",\"status\":\"failed\",\"failure_class\":\"{s}\",\"exit_code\":{d},\"signal\":\"{s}\",\"error_reason\":\"{s}\",\"retry_count\":{d},\"retry_cap\":{d},\"auto_retryable\":{},\"requires_fix\":{},\"suggested_action\":\"{s}\"}}\n", .{ existing_id, failure_class, exit_code, signal, error_reason, retry_count, retry_cap, auto_retryable, requires_fix, suggested_action }); + } else { + // 
Print rich failure information based on FailureClass + colors.printWarning("\n→ FAILED {s} {s} failure\n", .{ existing_id[0..8], failure_class }); + + if (signal.len > 0) { + colors.printInfo(" Signal: {s} (exit code: {d})\n", .{ signal, exit_code }); + } else if (exit_code != 0) { + colors.printInfo(" Exit code: {d}\n", .{exit_code}); + } + + // Show log tail if available + if (log_tail.len > 0) { + // Truncate long log tails + const display_tail = if (log_tail.len > 160) log_tail[0..160] else log_tail; + colors.printInfo(" Log: {s}...\n", .{display_tail}); + } + + // Show retry history + if (retry_count > 0) { + if (auto_retryable and retry_count < retry_cap) { + colors.printInfo(" Retried: {d}/{d} — auto-retry in progress\n", .{ retry_count, retry_cap }); + } else { + colors.printInfo(" Retried: {d}/{d}\n", .{ retry_count, retry_cap }); + } + } + + // Class-specific guidance per design spec + if (std.mem.eql(u8, failure_class, "infrastructure")) { + colors.printInfo("\n Infrastructure failure (node died, preempted).\n", .{}); + if (auto_retryable and retry_count < retry_cap) { + colors.printSuccess(" → Auto-retrying transparently (attempt {d}/{d})\n", .{ retry_count + 1, retry_cap }); + } else if (retry_count >= retry_cap) { + colors.printError(" → Retry cap reached. 
Requires manual intervention.\n", .{}); + colors.printInfo(" Resubmit: ml requeue {s}\n", .{existing_id[0..8]}); + } + colors.printInfo(" Logs: ml logs {s}\n", .{existing_id[0..8]}); + } else if (std.mem.eql(u8, failure_class, "code")) { + // CRITICAL RULE: code failures never auto-retry + colors.printError("\n Code failure — auto-retry is blocked.\n", .{}); + colors.printWarning(" You must fix the code before resubmitting.\n", .{}); + colors.printInfo("\n Debug:\n", .{}); + colors.printInfo(" View logs: ml logs {s}\n", .{existing_id[0..8]}); + colors.printInfo(" Debug: ml debug {s}\n", .{existing_id[0..8]}); + colors.printInfo("\n After fix:\n", .{}); + colors.printInfo(" Requeue with same config:\n", .{}); + colors.printInfo(" ml requeue {s}\n", .{existing_id[0..8]}); + colors.printInfo(" Or with more resources:\n", .{}); + colors.printInfo(" ml requeue {s} --gpu-memory 16\n", .{existing_id[0..8]}); + } else if (std.mem.eql(u8, failure_class, "data")) { + // Data failures never auto-retry + colors.printError("\n Data failure — verification/checksum issue.\n", .{}); + colors.printWarning(" Auto-retry will fail again with same data.\n", .{}); + colors.printInfo("\n Check:\n", .{}); + colors.printInfo(" Dataset availability: ml dataset verify {s}\n", .{existing_id[0..8]}); + colors.printInfo(" View logs: ml logs {s}\n", .{existing_id[0..8]}); + colors.printInfo("\n After data issue resolved:\n", .{}); + colors.printInfo(" ml requeue {s}\n", .{existing_id[0..8]}); + } else if (std.mem.eql(u8, failure_class, "resource")) { + colors.printError("\n Resource failure — OOM or disk full.\n", .{}); + if (retry_count == 0 and auto_retryable) { + colors.printInfo(" → Will retry once with backoff (30s delay)\n", .{}); + } else if (retry_count >= 1) { + colors.printWarning(" → Retried once, failed again with same error.\n", .{}); + colors.printInfo("\n Suggestion: resubmit with more resources:\n", .{}); + colors.printInfo(" ml requeue {s} --gpu-memory 16\n", 
.{existing_id[0..8]}); + colors.printInfo(" ml requeue {s} --memory 32 --cpu 8\n", .{existing_id[0..8]}); + } + colors.printInfo("\n Check capacity: ml status\n", .{}); + colors.printInfo(" Logs: ml logs {s}\n", .{existing_id[0..8]}); + } else { + // Unknown failures + colors.printWarning("\n Unknown failure — classification unclear.\n", .{}); + colors.printInfo("\n Review full logs and decide:\n", .{}); + colors.printInfo(" ml logs {s}\n", .{existing_id[0..8]}); + colors.printInfo(" ml debug {s}\n", .{existing_id[0..8]}); + if (auto_retryable) { + colors.printInfo("\n Or retry:\n", .{}); + colors.printInfo(" ml requeue {s}\n", .{existing_id[0..8]}); + } + } + + // Always show the suggestion if available + if (suggestion.len > 0) { + colors.printInfo("\n {s}\n", .{suggestion}); + } + } + } +} + fn hexDigit(v: u8) u8 { return if (v < 10) ('0' + v) else ('a' + (v - 10)); } diff --git a/cli/src/commands/requeue.zig b/cli/src/commands/requeue.zig index 06e7963..f2aafc2 100644 --- a/cli/src/commands/requeue.zig +++ b/cli/src/commands/requeue.zig @@ -3,6 +3,7 @@ const colors = @import("../utils/colors.zig"); const Config = @import("../config.zig").Config; const crypto = @import("../utils/crypto.zig"); const ws = @import("../net/ws/client.zig"); +const protocol = @import("../net/protocol.zig"); pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { if (argv.len == 0) { @@ -42,6 +43,7 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { var gpu_memory: ?[]const u8 = cfg.default_gpu_memory; var args_override: ?[]const u8 = null; var note_override: ?[]const u8 = null; + var force: bool = false; var i: usize = 0; while (i < pre.len) : (i += 1) { @@ -70,6 +72,8 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { } else if (std.mem.eql(u8, a, "--note") and i + 1 < pre.len) { note_override = pre[i + 1]; i += 1; + } else if (std.mem.eql(u8, a, "--force")) { + force = true; } else if (std.mem.eql(u8, a, 
"--help") or std.mem.eql(u8, a, "-h")) { try printUsage(); return; @@ -183,6 +187,7 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { api_key_hash, args_final, note_final, + force, cpu, memory, gpu, @@ -195,6 +200,7 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { priority, api_key_hash, args_final, + force, cpu, memory, gpu, @@ -202,11 +208,85 @@ pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { ); } - try client.receiveAndHandleResponse(allocator, "Requeue"); + // Receive response with duplicate detection + const message = try client.receiveMessage(allocator); + defer allocator.free(message); - colors.printSuccess("Queued requeue\n", .{}); - colors.printInfo("Job: {s}\n", .{job_name}); - colors.printInfo("Commit: {s}\n", .{commit_hex}); + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + if (message.len > 0 and message[0] == '{') { + try handleDuplicateResponse(allocator, message, job_name, commit_hex); + } else { + colors.printInfo("Server response: {s}\n", .{message}); + } + return; + }; + defer { + if (packet.success_message) |m| allocator.free(m); + if (packet.error_message) |m| allocator.free(m); + if (packet.error_details) |m| allocator.free(m); + if (packet.data_payload) |m| allocator.free(m); + if (packet.data_type) |m| allocator.free(m); + } + + switch (packet.packet_type) { + .success => { + colors.printSuccess("Queued requeue\n", .{}); + colors.printInfo("Job: {s}\n", .{job_name}); + colors.printInfo("Commit: {s}\n", .{commit_hex}); + }, + .data => { + if (packet.data_payload) |payload| { + try handleDuplicateResponse(allocator, payload, job_name, commit_hex); + } + }, + .error_packet => { + const err_msg = packet.error_message orelse "Unknown error"; + colors.printError("Error: {s}\n", .{err_msg}); + return error.ServerError; + }, + else => { + try client.handleResponsePacket(packet, "Requeue"); + colors.printSuccess("Queued requeue\n", 
.{}); + colors.printInfo("Job: {s}\n", .{job_name}); + colors.printInfo("Commit: {s}\n", .{commit_hex}); + }, + } +} + +fn handleDuplicateResponse( + allocator: std.mem.Allocator, + payload: []const u8, + job_name: []const u8, + commit_hex: []const u8, +) !void { + const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch { + colors.printInfo("Server response: {s}\n", .{payload}); + return; + }; + defer parsed.deinit(); + + const root = parsed.value.object; + const is_dup = root.get("duplicate") != null and root.get("duplicate").?.bool; + if (!is_dup) { + colors.printSuccess("Queued requeue\n", .{}); + colors.printInfo("Job: {s}\n", .{job_name}); + colors.printInfo("Commit: {s}\n", .{commit_hex}); + return; + } + + const existing_id = root.get("existing_id").?.string; + const status = root.get("status").?.string; + + if (std.mem.eql(u8, status, "queued") or std.mem.eql(u8, status, "running")) { + colors.printInfo("\n→ Identical job already in progress: {s}\n", .{existing_id[0..8]}); + colors.printInfo("\n Watch: ml watch {s}\n", .{existing_id[0..8]}); + } else if (std.mem.eql(u8, status, "completed")) { + colors.printInfo("\n→ Identical job already completed: {s}\n", .{existing_id[0..8]}); + colors.printInfo("\n Inspect: ml experiment show {s}\n", .{existing_id[0..8]}); + colors.printInfo(" Rerun: ml requeue {s} --force\n", .{commit_hex}); + } else if (std.mem.eql(u8, status, "failed")) { + colors.printWarning("\n→ Identical job previously failed: {s}\n", .{existing_id[0..8]}); + } } fn isHexLowerOrUpper(s: []const u8) bool { @@ -341,5 +421,5 @@ fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { fn printUsage() !void { colors.printInfo("Usage:\n", .{}); - colors.printInfo(" ml requeue [--name ] [--priority ] [--cpu ] [--memory ] [--gpu ] [--gpu-memory ] [--args ] [--note ] -- \n", .{}); + colors.printInfo(" ml requeue [--name ] [--priority ] [--cpu ] [--memory ] [--gpu ] [--gpu-memory ] [--args ] [--note ] [--force] 
-- \n", .{}); } diff --git a/cli/src/commands/status.zig b/cli/src/commands/status.zig index b0481b8..7168bff 100644 --- a/cli/src/commands/status.zig +++ b/cli/src/commands/status.zig @@ -1,17 +1,14 @@ const std = @import("std"); -const c = @cImport(@cInclude("time.h")); const Config = @import("../config.zig").Config; const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); -const errors = @import("../errors.zig"); -const logging = @import("../utils/logging.zig"); const colors = @import("../utils/colors.zig"); pub const StatusOptions = struct { json: bool = false, watch: bool = false, limit: ?usize = null, - watch_interval: u32 = 5, // seconds + watch_interval: u32 = 5, }; const UserContext = struct { @@ -24,80 +21,42 @@ const UserContext = struct { } }; -fn authenticateUser(allocator: std.mem.Allocator, config: Config) !UserContext { - // Validate API key by making a simple API call to the server - const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); - defer allocator.free(ws_url); - - // Try to connect with the API key to validate it - var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| { - switch (err) { - error.ConnectionRefused => return error.ConnectionFailed, - error.NetworkUnreachable => return error.ServerUnreachable, - error.InvalidURL => return error.ConfigInvalid, - else => return error.AuthenticationFailed, - } - }; - defer client.close(); - - // For now, create a user context after successful authentication - // In a real implementation, this would get user info from the server - const user_name = try allocator.dupe(u8, "authenticated_user"); - return UserContext{ - .name = user_name, - .admin = false, - .allocator = allocator, - }; -} - pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var options = StatusOptions{}; - // Parse arguments for flags var i: usize = 0; while (i < args.len) : (i += 1) { const arg = args[i]; - if 
(std.mem.eql(u8, arg, "--json")) { options.json = true; } else if (std.mem.eql(u8, arg, "--watch")) { options.watch = true; } else if (std.mem.eql(u8, arg, "--limit") and i + 1 < args.len) { - const limit_str = args[i + 1]; - options.limit = try std.fmt.parseInt(usize, limit_str, 10); + options.limit = try std.fmt.parseInt(usize, args[i + 1], 10); i += 1; } else if (std.mem.startsWith(u8, arg, "--watch-interval=")) { - const interval_str = arg[16..]; - options.watch_interval = try std.fmt.parseInt(u32, interval_str, 10); - } else if (std.mem.startsWith(u8, arg, "--help")) { + options.watch_interval = try std.fmt.parseInt(u32, arg[17..], 10); + } else if (std.mem.eql(u8, arg, "--help")) { try printUsage(); return; - } else { - colors.printError("Unknown option: {s}\n", .{arg}); - try printUsage(); - return error.InvalidArgs; } } - // Load configuration with proper error handling - const config = Config.load(allocator) catch |err| { - switch (err) { - error.FileNotFound => return error.ConfigNotFound, - else => return err, - } - }; + const config = try Config.load(allocator); defer { var mut_config = config; mut_config.deinit(allocator); } - // Check if API key is configured if (config.api_key.len == 0) { return error.APIKeyMissing; } - // Authenticate with server to get user context - var user_context = try authenticateUser(allocator, config); + var user_context = UserContext{ + .name = try allocator.dupe(u8, "default"), + .admin = true, + .allocator = allocator, + }; defer user_context.deinit(); if (options.watch) { @@ -111,23 +70,13 @@ fn runSingleStatus(allocator: std.mem.Allocator, config: Config, user_context: U const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); defer allocator.free(api_key_hash); - // Connect to WebSocket and request status const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); defer allocator.free(ws_url); - var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| 
{ - switch (err) { - error.ConnectionRefused => return error.ConnectionFailed, - error.NetworkUnreachable => return error.ServerUnreachable, - error.InvalidURL => return error.ConfigInvalid, - else => return err, - } - }; + var client = try ws.Client.connect(allocator, ws_url, config.api_key); defer client.close(); try client.sendStatusRequest(api_key_hash); - - // Receive and display user-filtered response try client.receiveAndHandleStatusResponse(allocator, user_context, options); } @@ -135,7 +84,6 @@ fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: User colors.printInfo("Starting watch mode (interval: {d}s). Press Ctrl+C to stop.\n", .{options.watch_interval}); while (true) { - // Display header for better readability if (!options.json) { colors.printInfo("\n=== FetchML Status - {s} ===\n", .{user_context.name}); } @@ -146,18 +94,7 @@ fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: User colors.printInfo("Next update in {d} seconds...\n", .{options.watch_interval}); } - // Sleep for the specified interval using a simple busy wait for now - // TODO: Replace with proper sleep implementation when Zig 0.15 sleep API is stable - const start_time = std.time.nanoTimestamp(); - const target_time = start_time + (@as(i128, options.watch_interval) * std.time.ns_per_s); - - while (std.time.nanoTimestamp() < target_time) { - // Simple busy wait - check time every 10ms - const check_start = std.time.nanoTimestamp(); - while (std.time.nanoTimestamp() < check_start + (10 * std.time.ns_per_ms)) { - // Spin wait for 10ms - } - } + std.Thread.sleep(@as(u64, options.watch_interval) * std.time.ns_per_s); // widen to u64 first: u32 math overflows for intervals above 4s } } @@ -169,10 +106,4 @@ fn printUsage() !void { colors.printInfo(" --limit Limit number of results shown\n", .{}); colors.printInfo(" --watch-interval= Set watch interval in seconds (default: 5)\n", .{}); colors.printInfo(" --help Show this help message\n", .{}); - colors.printInfo("\nExamples:\n", .{}); - colors.printInfo(" ml status # 
Show current status\n", .{}); - colors.printInfo(" ml status --json # Show status as JSON\n", .{}); - colors.printInfo(" ml status --watch # Watch mode with default interval\n", .{}); - colors.printInfo(" ml status --watch --limit 10 # Watch mode with 10 results limit\n", .{}); - colors.printInfo(" ml status --watch-interval=2 # Watch mode with 2-second interval\n", .{}); } diff --git a/cli/src/config.zig b/cli/src/config.zig index daccd51..e47f23f 100644 --- a/cli/src/config.zig +++ b/cli/src/config.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const security = @import("security.zig"); pub const Config = struct { worker_host: []const u8, @@ -75,6 +76,13 @@ pub const Config = struct { config.api_key = try allocator.dupe(u8, api_key); } + // Try to get API key from keychain if not in config or env + if (config.api_key.len == 0) { + if (try security.SecureStorage.retrieveApiKey(allocator)) |keychain_key| { + config.api_key = keychain_key; + } + } + try config.validate(); return config; } diff --git a/cli/src/main.zig b/cli/src/main.zig index bfe52d8..6cb15c8 100644 --- a/cli/src/main.zig +++ b/cli/src/main.zig @@ -122,6 +122,10 @@ pub fn main() !void { command_found = true; try @import("commands/validate.zig").run(allocator, args[2..]); }, + 'l' => if (std.mem.eql(u8, command, "logs")) { + command_found = true; + try @import("commands/logs.zig").run(allocator, args[2..]); + }, else => {}, } @@ -148,6 +152,7 @@ fn printUsage() void { std.debug.print(" queue (q) Queue job for execution\n", .{}); std.debug.print(" status Get system status\n", .{}); std.debug.print(" monitor Launch TUI via SSH\n", .{}); + std.debug.print(" logs Fetch job logs (-f to follow, -n for tail)\n", .{}); std.debug.print(" cancel Cancel running job\n", .{}); std.debug.print(" prune Remove old experiments\n", .{}); std.debug.print(" watch Watch directory for auto-sync\n", .{}); @@ -162,4 +167,5 @@ test { _ = @import("commands/requeue.zig"); _ = @import("commands/annotate.zig"); _ = 
@import("commands/narrative.zig"); + _ = @import("commands/logs.zig"); } diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig index 2966d89..4269eaf 100644 --- a/cli/src/net/ws/client.zig +++ b/cli/src/net/ws/client.zig @@ -263,6 +263,7 @@ pub const Client = struct { api_key_hash: []const u8, args: []const u8, note: []const u8, + force: bool, cpu: u8, memory_gb: u8, gpu: u8, @@ -286,8 +287,9 @@ pub const Client = struct { // [job_name_len][job_name] // [args_len:2][args] // [note_len:2][note] + // [force:1] // [cpu][memory_gb][gpu][gpu_mem_len][gpu_mem] - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 2 + note.len + 4 + gpu_mem.len; + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 2 + note.len + 1 + 4 + gpu_mem.len; var buffer = try self.allocator.alloc(u8, total_len); defer self.allocator.free(buffer); @@ -328,6 +330,10 @@ pub const Client = struct { offset += note.len; } + // Force flag + buffer[offset] = if (force) 0x01 else 0x00; + offset += 1; + buffer[offset] = cpu; buffer[offset + 1] = memory_gb; buffer[offset + 2] = gpu; @@ -348,6 +354,7 @@ pub const Client = struct { priority: u8, api_key_hash: []const u8, args: []const u8, + force: bool, cpu: u8, memory_gb: u8, gpu: u8, @@ -369,8 +376,9 @@ pub const Client = struct { // [priority] // [job_name_len][job_name] // [args_len:2][args] + // [force:1] // [cpu][memory_gb][gpu][gpu_mem_len][gpu_mem] - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 4 + gpu_mem.len; + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 1 + 4 + gpu_mem.len; var buffer = try self.allocator.alloc(u8, total_len); defer self.allocator.free(buffer); @@ -402,6 +410,10 @@ pub const Client = struct { offset += args.len; } + // Force flag + buffer[offset] = if (force) 0x01 else 0x00; + offset += 1; + buffer[offset] = cpu; buffer[offset + 1] = memory_gb; buffer[offset + 2] = gpu; @@ -1209,6 +1221,91 @@ pub const Client = struct { try 
frame.sendWebSocketFrame(stream, buffer); } + // Logs and debug methods + pub fn sendGetLogs(self: *Client, target_id: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (target_id.len == 0 or target_id.len > 255) return error.InvalidTargetId; + + // Build binary message: [opcode:1][api_key_hash:16][target_id_len:1][target_id:var] + const total_len = 1 + 16 + 1 + target_id.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.get_logs); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(target_id.len); + offset += 1; + + @memcpy(buffer[offset .. offset + target_id.len], target_id); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendStreamLogs(self: *Client, target_id: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (target_id.len == 0 or target_id.len > 255) return error.InvalidTargetId; + + // Build binary message: [opcode:1][api_key_hash:16][target_id_len:1][target_id:var] + const total_len = 1 + 16 + 1 + target_id.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.stream_logs); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(target_id.len); + offset += 1; + + @memcpy(buffer[offset .. 
offset + target_id.len], target_id); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendAttachDebug(self: *Client, target_id: []const u8, debug_type: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (target_id.len == 0 or target_id.len > 255) return error.InvalidTargetId; + if (debug_type.len > 255) return error.InvalidDebugType; + + // Build binary message: [opcode:1][api_key_hash:16][target_id_len:1][target_id:var][debug_type:var] + const total_len = 1 + 16 + 1 + target_id.len + debug_type.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.attach_debug); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(target_id.len); + offset += 1; + + @memcpy(buffer[offset .. offset + target_id.len], target_id); + offset += target_id.len; + + if (debug_type.len > 0) { + @memcpy(buffer[offset .. 
offset + debug_type.len], debug_type); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + /// Receive and handle dataset response pub fn receiveAndHandleDatasetResponse(self: *Client, allocator: std.mem.Allocator) ![]const u8 { const message = try self.receiveMessage(allocator); diff --git a/cli/src/net/ws/frame.zig b/cli/src/net/ws/frame.zig index 10fdf1f..8613e03 100644 --- a/cli/src/net/ws/frame.zig +++ b/cli/src/net/ws/frame.zig @@ -46,9 +46,14 @@ pub fn receiveBinaryMessage(stream: std.net.Stream, allocator: std.mem.Allocator const header_bytes = try stream.read(&header); if (header_bytes < 2) return error.ConnectionClosed; - if (header[0] != 0x82) return error.InvalidFrame; + // Accept both binary (0x82) and text (0x81) frames + const opcode = header[0] & 0x0F; + if (opcode != 0x02 and opcode != 0x01) return error.InvalidFrame; + + const masked = (header[1] & 0x80) != 0; + var payload_len: usize = header[1] & 0x7F; + var mask_key: [4]u8 = undefined; - var payload_len: usize = header[1]; if (payload_len == 126) { var len_bytes: [2]u8 = undefined; _ = try stream.read(&len_bytes); @@ -57,6 +62,11 @@ pub fn receiveBinaryMessage(stream: std.net.Stream, allocator: std.mem.Allocator return error.PayloadTooLarge; } + // Read mask key if frame is masked + if (masked) { + _ = try stream.read(&mask_key); + } + const payload = try allocator.alloc(u8, payload_len); errdefer allocator.free(payload); @@ -67,5 +77,12 @@ pub fn receiveBinaryMessage(stream: std.net.Stream, allocator: std.mem.Allocator bytes_read += n; } + // Unmask payload if needed + if (masked) { + for (payload, 0..) 
|byte, i| { + payload[i] = byte ^ mask_key[i % 4]; + } + } + return payload; } diff --git a/cli/src/net/ws/opcode.zig b/cli/src/net/ws/opcode.zig index f679395..4f6e142 100644 --- a/cli/src/net/ws/opcode.zig +++ b/cli/src/net/ws/opcode.zig @@ -21,6 +21,11 @@ pub const Opcode = enum(u8) { validate_request = 0x16, + // Logs and debug opcodes + get_logs = 0x20, + stream_logs = 0x21, + attach_debug = 0x22, + // Dataset management opcodes dataset_list = 0x06, dataset_register = 0x07, @@ -61,6 +66,9 @@ pub const restore_jupyter = Opcode.restore_jupyter; pub const list_jupyter = Opcode.list_jupyter; pub const list_jupyter_packages = Opcode.list_jupyter_packages; pub const validate_request = Opcode.validate_request; +pub const get_logs = Opcode.get_logs; +pub const stream_logs = Opcode.stream_logs; +pub const attach_debug = Opcode.attach_debug; pub const dataset_list = Opcode.dataset_list; pub const dataset_register = Opcode.dataset_register; pub const dataset_info = Opcode.dataset_info; diff --git a/cli/src/net/ws/response_handlers.zig b/cli/src/net/ws/response_handlers.zig index 1c04390..514da66 100644 --- a/cli/src/net/ws/response_handlers.zig +++ b/cli/src/net/ws/response_handlers.zig @@ -8,207 +8,189 @@ const utils = @import("utils.zig"); /// Receive and handle status response with user filtering pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void { - _ = user_context; // TODO: Use for filtering + _ = user_context; const message = try self.receiveMessage(allocator); defer allocator.free(message); - // Check if message is JSON (or contains JSON) or plain text - if (message[0] == '{') { - // Parse JSON response - const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{}); - defer parsed.deinit(); - const root = parsed.value.object; - - if (options.json) { - // Output raw JSON - var out = io.stdoutWriter(); - try out.print("{s}\n", .{message}); + // Use binary protocol 
deserialization + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + // Fallback: try to find and parse JSON directly + if (std.mem.indexOf(u8, message, "{")) |json_start| { + const json_data = message[json_start..]; + try parseAndDisplayStatusJson(allocator, json_data, options); } else { - // Display user info - if (root.get("user")) |user_obj| { - const user = user_obj.object; - const name = user.get("name").?.string; - const admin = user.get("admin").?.bool; - colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin }); + std.debug.print("Server response: {s}\n", .{message}); + } + return; + }; + defer { + if (packet.status_data) |data| allocator.free(data); + if (packet.data_payload) |payload| allocator.free(payload); + if (packet.data_type) |dtype| allocator.free(dtype); + if (packet.success_message) |msg| allocator.free(msg); + if (packet.error_message) |msg| allocator.free(msg); + } + + // Handle status packet type (or data packet from server) + if (packet.packet_type == .status) { + if (packet.status_data) |json_data| { + try parseAndDisplayStatusJson(allocator, json_data, options); + } + } else if (packet.packet_type == .data) { + // Server sends status as data packet with JSON payload + if (packet.data_payload) |json_data| { + try parseAndDisplayStatusJson(allocator, json_data, options); + } + } else if (packet.packet_type == .error_packet) { + colors.printError("Error: {s}\n", .{packet.error_message orelse "Unknown error"}); + } else { + std.debug.print("Unexpected packet type: {s}\n", .{@tagName(packet.packet_type)}); + } +} + +fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8, options: anytype) !void { + const parsed = std.json.parseFromSlice(std.json.Value, allocator, json_data, .{}) catch { + std.debug.print("{s}\n", .{json_data}); + return; + }; + defer parsed.deinit(); + const root = parsed.value.object; + + if (options.json) { + // Output raw JSON + var out = 
io.stdoutWriter(); + try out.print("{s}\n", .{json_data}); + } else { + // Display user info + if (root.get("user")) |user_obj| { + const user = user_obj.object; + const name = user.get("name").?.string; + const admin = user.get("admin").?.bool; + colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin }); + } + + // Display task summary + if (root.get("tasks")) |tasks_obj| { + const tasks = tasks_obj.object; + const total = tasks.get("total").?.integer; + const queued = tasks.get("queued").?.integer; + const running = tasks.get("running").?.integer; + const failed = tasks.get("failed").?.integer; + const completed = tasks.get("completed").?.integer; + colors.printInfo( + "Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n", + .{ total, queued, running, failed, completed }, + ); + } + + const per_section_limit: usize = options.limit orelse 5; + + const TaskStatus = enum { queued, running, failed, completed }; + + const TaskPrinter = struct { + fn statusLabel(s: TaskStatus) []const u8 { + return switch (s) { + .queued => "Queued", + .running => "Running", + .failed => "Failed", + .completed => "Completed", + }; } - // Display task summary - if (root.get("tasks")) |tasks_obj| { - const tasks = tasks_obj.object; - const total = tasks.get("total").?.integer; - const queued = tasks.get("queued").?.integer; - const running = tasks.get("running").?.integer; - const failed = tasks.get("failed").?.integer; - const completed = tasks.get("completed").?.integer; - colors.printInfo( - "Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n", - .{ total, queued, running, failed, completed }, - ); + fn statusMatch(s: TaskStatus) []const u8 { + return switch (s) { + .queued => "queued", + .running => "running", + .failed => "failed", + .completed => "completed", + }; } - const per_section_limit: usize = options.limit orelse 5; + fn shorten(s: []const u8, max_len: usize) []const u8 { + if (s.len <= max_len) return 
s; + return s[0..max_len]; + } - const TaskStatus = enum { queued, running, failed, completed }; + fn printSection( + allocator2: std.mem.Allocator, + queue_items: []const std.json.Value, + status: TaskStatus, + limit2: usize, + ) !void { + _ = allocator2; + const label = statusLabel(status); + const want = statusMatch(status); + std.debug.print("\n{s}:\n", .{label}); - const TaskPrinter = struct { - fn statusLabel(s: TaskStatus) []const u8 { - return switch (s) { - .queued => "Queued", - .running => "Running", - .failed => "Failed", - .completed => "Completed", - }; + var shown: usize = 0; + for (queue_items) |item| { + if (item != .object) continue; + const obj = item.object; + const st = utils.jsonGetString(obj, "status") orelse ""; + if (!std.mem.eql(u8, st, want)) continue; + + const id = utils.jsonGetString(obj, "id") orelse ""; + const job_name = utils.jsonGetString(obj, "job_name") orelse ""; + const worker_id = utils.jsonGetString(obj, "worker_id") orelse ""; + const err = utils.jsonGetString(obj, "error") orelse ""; + + if (std.mem.eql(u8, want, "failed")) { + colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name }); + if (worker_id.len > 0) { + std.debug.print(" (worker={s})", .{worker_id}); + } + std.debug.print("\n", .{}); + if (err.len > 0) { + std.debug.print(" error: {s}\n", .{shorten(err, 160)}); + } + } else if (std.mem.eql(u8, want, "running")) { + colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name }); + if (worker_id.len > 0) { + std.debug.print(" (worker={s})", .{worker_id}); + } + std.debug.print("\n", .{}); + } else if (std.mem.eql(u8, want, "queued")) { + std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name }); + } else { + colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name }); + } + + shown += 1; + if (shown >= limit2) break; } - fn statusMatch(s: TaskStatus) []const u8 { - return switch (s) { - .queued => "queued", - .running => "running", - .failed => "failed", - .completed => "completed", - }; - } - - fn 
shorten(s: []const u8, max_len: usize) []const u8 { - if (s.len <= max_len) return s; - return s[0..max_len]; - } - - fn printSection( - allocator2: std.mem.Allocator, - queue_items: []const std.json.Value, - status: TaskStatus, - limit2: usize, - ) !void { - _ = allocator2; - const label = statusLabel(status); - const want = statusMatch(status); - std.debug.print("\n{s}:\n", .{label}); - - var shown: usize = 0; + if (shown == 0) { + std.debug.print(" (none)\n", .{}); + } else { + // Indicate there may be more. + var total_for_status: usize = 0; for (queue_items) |item| { if (item != .object) continue; const obj = item.object; const st = utils.jsonGetString(obj, "status") orelse ""; - if (!std.mem.eql(u8, st, want)) continue; - - const id = utils.jsonGetString(obj, "id") orelse ""; - const job_name = utils.jsonGetString(obj, "job_name") orelse ""; - const worker_id = utils.jsonGetString(obj, "worker_id") orelse ""; - const err = utils.jsonGetString(obj, "error") orelse ""; - - if (std.mem.eql(u8, want, "failed")) { - colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name }); - if (worker_id.len > 0) { - std.debug.print(" (worker={s})", .{worker_id}); - } - std.debug.print("\n", .{}); - if (err.len > 0) { - std.debug.print(" error: {s}\n", .{shorten(err, 160)}); - } - } else if (std.mem.eql(u8, want, "running")) { - colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name }); - if (worker_id.len > 0) { - std.debug.print(" (worker={s})", .{worker_id}); - } - std.debug.print("\n", .{}); - } else if (std.mem.eql(u8, want, "queued")) { - std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name }); - } else { - colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name }); - } - - shown += 1; - if (shown >= limit2) break; + if (std.mem.eql(u8, st, want)) total_for_status += 1; } - - if (shown == 0) { - std.debug.print(" (none)\n", .{}); - } else { - // Indicate there may be more. 
- var total_for_status: usize = 0; - for (queue_items) |item| { - if (item != .object) continue; - const obj = item.object; - const st = utils.jsonGetString(obj, "status") orelse ""; - if (std.mem.eql(u8, st, want)) total_for_status += 1; - } - if (total_for_status > shown) { - std.debug.print(" ... and {d} more\n", .{total_for_status - shown}); - } + if (total_for_status > shown) { + std.debug.print(" ... and {d} more\n", .{total_for_status - shown}); } } - }; - - if (root.get("queue")) |queue_val| { - if (queue_val == .array) { - const items = queue_val.array.items; - try TaskPrinter.printSection(allocator, items, .queued, per_section_limit); - try TaskPrinter.printSection(allocator, items, .running, per_section_limit); - try TaskPrinter.printSection(allocator, items, .failed, per_section_limit); - try TaskPrinter.printSection(allocator, items, .completed, per_section_limit); - } } - - if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| { - defer allocator.free(section); - colors.printInfo("{s}", .{section}); - } - } - } else { - // Handle plain text response - filter out non-printable characters - var clean_msg = allocator.alloc(u8, message.len) catch { - if (options.json) { - var out = io.stdoutWriter(); - try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); - } else { - std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); - } - return; }; - defer allocator.free(clean_msg); - var clean_len: usize = 0; - for (message) |byte| { - // Skip WebSocket frame header bytes and non-printable chars - if (byte >= 32 and byte <= 126) { // printable ASCII only - clean_msg[clean_len] = byte; - clean_len += 1; + if (root.get("queue")) |queue_val| { + if (queue_val == .array) { + const items = queue_val.array.items; + try TaskPrinter.printSection(allocator, items, .queued, per_section_limit); + try TaskPrinter.printSection(allocator, items, .running, per_section_limit); + try 
TaskPrinter.printSection(allocator, items, .failed, per_section_limit); + try TaskPrinter.printSection(allocator, items, .completed, per_section_limit); } } - // Look for common error messages in the cleaned data - if (clean_len > 0) { - const cleaned = clean_msg[0..clean_len]; - if (options.json) { - if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { - var out = io.stdoutWriter(); - try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{}); - } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { - var out = io.stdoutWriter(); - try out.print("{{\"error\": \"authentication_failed\"}}\n", .{}); - } else { - var out = io.stdoutWriter(); - try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned}); - } - } else { - if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { - std.debug.print("Insufficient permissions to view jobs\n", .{}); - } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { - std.debug.print("Authentication failed\n", .{}); - } else { - std.debug.print("Server response: {s}\n", .{cleaned}); - } - } - } else { - if (options.json) { - var out = io.stdoutWriter(); - try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); - } else { - std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); - } + if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| { + defer allocator.free(section); + colors.printInfo("{s}", .{section}); } - return; } } diff --git a/cli/src/security.zig b/cli/src/security.zig new file mode 100644 index 0000000..520061b --- /dev/null +++ b/cli/src/security.zig @@ -0,0 +1,59 @@ +const std = @import("std"); + +/// Secure credential storage using macOS Keychain +pub const SecureStorage = struct { + const ServiceName = "com.fetchml.cli"; + + /// Store API key in macOS Keychain + pub fn storeApiKey(api_key: []const u8) !void { + const result = std.process.Child.run(.{ + .allocator = 
std.heap.page_allocator, + .argv = &.{ + "security", "add-generic-password", + "-s", ServiceName, + "-a", "api_key", + "-w", api_key, + "-U", + }, + }) catch return error.KeychainError; + + if (result.term.Exited != 0) { + return error.KeychainError; + } + } + + /// Retrieve API key from macOS Keychain + pub fn retrieveApiKey(allocator: std.mem.Allocator) !?[]u8 { + const result = std.process.Child.run(.{ + .allocator = allocator, + .argv = &.{ + "security", "find-generic-password", + "-s", ServiceName, + "-a", "api_key", + "-w", + }, + }) catch return null; + + if (result.term.Exited != 0) { + return null; + } + + const stdout = result.stdout; + if (stdout.len > 0 and stdout[stdout.len - 1] == '\n') { + return try allocator.dupe(u8, stdout[0 .. stdout.len - 1]); + } + return try allocator.dupe(u8, stdout); + } + + /// Delete stored API key + pub fn deleteApiKey() void { + _ = std.process.Child.run(.{ + .allocator = std.heap.page_allocator, + .argv = &.{ + "security", "delete-generic-password", + "-s", ServiceName, + "-a", "api_key", + }, + }) catch {}; + } +};