diff --git a/cli/Makefile b/cli/Makefile index 297273e..29daf3f 100644 --- a/cli/Makefile +++ b/cli/Makefile @@ -4,21 +4,36 @@ ZIG ?= zig BUILD_DIR ?= zig-out/bin BINARY := $(BUILD_DIR)/ml -.PHONY: all prod dev install clean help +.PHONY: all prod dev test build-rsync install clean help + +RSYNC_VERSION ?= 3.3.0 +RSYNC_SRC_BASE ?= https://download.samba.org/pub/rsync/src +RSYNC_TARBALL ?= rsync-$(RSYNC_VERSION).tar.gz +RSYNC_TARBALL_SHA256 ?= all: $(BINARY) $(BUILD_DIR): mkdir -p $(BUILD_DIR) -$(BINARY): src/main.zig | $(BUILD_DIR) - $(ZIG) build-exe -OReleaseSmall -fstrip -femit-bin=$(BINARY) src/main.zig +$(BINARY): | $(BUILD_DIR) + $(ZIG) build --release=small prod: src/main.zig | $(BUILD_DIR) - $(ZIG) build-exe -OReleaseSmall -fstrip -femit-bin=$(BUILD_DIR)/ml src/main.zig + $(ZIG) build --release=small dev: src/main.zig | $(BUILD_DIR) - $(ZIG) build-exe -OReleaseFast -femit-bin=$(BUILD_DIR)/ml src/main.zig + $(ZIG) build --release=fast + +test: + $(ZIG) build test + +build-rsync: + @RSYNC_VERSION="$(RSYNC_VERSION)" \ + RSYNC_SRC_BASE="$(RSYNC_SRC_BASE)" \ + RSYNC_TARBALL="$(RSYNC_TARBALL)" \ + RSYNC_TARBALL_SHA256="$(RSYNC_TARBALL_SHA256)" \ + bash "$(CURDIR)/scripts/build_rsync.sh" install: $(BINARY) install -d $(DESTDIR)/usr/local/bin @@ -32,5 +47,7 @@ help: @echo " all - build release-small binary (default)" @echo " prod - build production binary with ReleaseSmall" @echo " dev - build development binary with ReleaseFast" + @echo " test - run Zig unit tests" + @echo " build-rsync - build pinned rsync from official source into src/assets (RSYNC_VERSION=... override)" @echo " install - copy binary into /usr/local/bin" @echo " clean - remove build artifacts" \ No newline at end of file diff --git a/cli/README.md b/cli/README.md index 5eac0dc..7cd2ef5 100644 --- a/cli/README.md +++ b/cli/README.md @@ -19,10 +19,12 @@ zig build - `ml init` - Setup configuration - `ml sync ` - Sync project to server -- `ml queue [job2 ...] 
[--commit ] [--priority N]` - Queue one or more jobs +- `ml queue [job2 ...] [--commit ] [--priority N] [--note ]` - Queue one or more jobs - `ml status` - Check system/queue status for your API key - `ml validate [--json] [--task ]` - Validate provenance + integrity for a commit or task (includes `run_manifest.json` consistency checks when validating by task) - `ml info [--json] [--base ]` - Show run info from `run_manifest.json` (by path or by scanning `finished/failed/running/pending`) +- `ml annotate --note [--author ] [--base ] [--json]` - Append a human annotation to `run_manifest.json` +- `ml narrative set [--hypothesis ] [--context ] [--intent ] [--expected-outcome ] [--parent-run ] [--experiment-group ] [--tags ] [--base ] [--json]` - Patch the `narrative` field in `run_manifest.json` - `ml monitor` - Launch monitoring interface (TUI) - `ml cancel ` - Cancel a running/queued job you own - `ml prune --keep N` - Keep N recent experiments @@ -31,6 +33,7 @@ zig build Notes: +- `--json` mode is designed to be pipe-friendly: machine-readable JSON is emitted to stdout, while user-facing messages/errors go to stderr. - When running `ml validate --task `, the server will try to locate the job's `run_manifest.json` under the configured base path (pending/running/finished/failed) and cross-check key fields (task id, commit id, deps, snapshot). - For tasks in `running`, `completed`, or `failed` state, a missing `run_manifest.json` is treated as a validation failure. For `queued` tasks, it is treated as a warning (the job may not have started yet). @@ -43,6 +46,9 @@ Notes: Queues a job named `my-job`. If `--commit` is omitted, the CLI generates a random commit ID and records `(job_name, commit_id)` in `~/.ml/history.log` so you don't have to remember hashes. +- `ml queue my-job --note "baseline run; lr=1e-3"` + Adds a human-readable note to the run; it will be persisted into the run's `run_manifest.json` (under `metadata.note`). 
+ - `ml experiment list` Shows recent experiments from history with alias (job name) and commit ID. diff --git a/cli/build.zig b/cli/build.zig index 833b299..201c78c 100644 --- a/cli/build.zig +++ b/cli/build.zig @@ -5,9 +5,63 @@ pub fn build(b: *std.Build) void { // Standard target options const target = b.standardTargetOptions(.{}); + const test_filter = b.option([]const u8, "test-filter", "Filter unit tests by name"); + _ = test_filter; + // Optimized release mode for size const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseSmall }); + const options = b.addOptions(); + + const arch = target.result.cpu.arch; + const os_tag = target.result.os.tag; + + const arch_str: []const u8 = switch (arch) { + .x86_64 => "x86_64", + .aarch64 => "arm64", + else => "unknown", + }; + const os_str: []const u8 = switch (os_tag) { + .linux => "linux", + .macos => "darwin", + .windows => "windows", + else => "unknown", + }; + + const candidate_specific = b.fmt("src/assets/rsync_release_{s}_{s}.bin", .{ os_str, arch_str }); + const candidate_default = "src/assets/rsync_release.bin"; + + var selected_candidate: []const u8 = ""; + var has_rsync_release = false; + + // Prefer a platform-specific asset if available. + if (std.fs.cwd().openFile(candidate_specific, .{}) catch null) |f| { + f.close(); + selected_candidate = candidate_specific; + has_rsync_release = true; + } else if (std.fs.cwd().openFile(candidate_default, .{}) catch null) |f| { + f.close(); + selected_candidate = candidate_default; + has_rsync_release = true; + } + + if ((optimize == .ReleaseSmall or optimize == .ReleaseFast) and !has_rsync_release) { + std.debug.panic( + "Release build requires an embedded rsync binary asset. Provide one of: '{s}' or '{s}'", + .{ candidate_specific, candidate_default }, + ); + } + + // rsync_embedded_binary.zig calls @embedFile() from cli/src/utils, so the embed path + // must be relative to that directory. 
+ const selected_embed_path = if (has_rsync_release) + b.fmt("../assets/{s}", .{std.fs.path.basename(selected_candidate)}) + else + ""; + + options.addOption(bool, "has_rsync_release", has_rsync_release); + options.addOption([]const u8, "rsync_release_path", selected_embed_path); + // CLI executable const exe = b.addExecutable(.{ .name = "ml", @@ -18,6 +72,10 @@ pub fn build(b: *std.Build) void { }), }); + exe.root_module.strip = true; + + exe.root_module.addOptions("build_options", options); + // Install the executable to zig-out/bin b.installArtifact(exe); @@ -44,6 +102,7 @@ pub fn build(b: *std.Build) void { .optimize = .Debug, }), }); + main_tests.root_module.addOptions("build_options", options); const run_main_tests = b.addRunArtifact(main_tests); test_step.dependOn(&run_main_tests.step); @@ -75,6 +134,8 @@ pub fn build(b: *std.Build) void { .optimize = .Debug, }); + test_module.addOptions("build_options", options); + // Make src module available to tests as "src" test_module.addImport("src", src_module); diff --git a/cli/scripts/build_rsync.sh b/cli/scripts/build_rsync.sh new file mode 100644 index 0000000..e755a14 --- /dev/null +++ b/cli/scripts/build_rsync.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +set -euo pipefail + +RSYNC_VERSION="${RSYNC_VERSION:-3.3.0}" +RSYNC_SRC_BASE="${RSYNC_SRC_BASE:-https://download.samba.org/pub/rsync/src}" +RSYNC_TARBALL="${RSYNC_TARBALL:-rsync-${RSYNC_VERSION}.tar.gz}" +RSYNC_TARBALL_SHA256="${RSYNC_TARBALL_SHA256:-}" + +os="$(uname -s | tr '[:upper:]' '[:lower:]')" +arch="$(uname -m)" +if [[ "${arch}" == "aarch64" || "${arch}" == "arm64" ]]; then arch="arm64"; fi +if [[ "${arch}" == "x86_64" ]]; then arch="x86_64"; fi + +if [[ "${os}" != "linux" ]]; then + echo "build-rsync: supported on linux only (for reproducible official builds). Use system rsync on ${os} or build on a native runner." >&2 + exit 2 +fi + +repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)"
+out="${repo_root}/src/assets/rsync_release_linux_${arch}.bin"
+
+tmp="$(mktemp -d)"
+cleanup() { rm -rf "${tmp}"; }
+trap cleanup EXIT
+
+url="${RSYNC_SRC_BASE}/${RSYNC_TARBALL}"
+sig_url_asc="${url}.asc"
+sig_url_sig="${url}.sig"
+
+echo "fetching ${url}"
+curl -fsSL "${url}" -o "${tmp}/rsync.tar.gz"
+
+verified=0
+if command -v gpg >/dev/null 2>&1; then
+  sig_file=""
+  sig_url=""
+  if curl -fsSL "${sig_url_asc}" -o "${tmp}/rsync.tar.gz.asc"; then
+    sig_file="${tmp}/rsync.tar.gz.asc"
+    sig_url="${sig_url_asc}"
+  elif curl -fsSL "${sig_url_sig}" -o "${tmp}/rsync.tar.gz.sig"; then
+    sig_file="${tmp}/rsync.tar.gz.sig"
+    sig_url="${sig_url_sig}"
+  fi
+
+  if [[ -n "${sig_file}" ]]; then
+    echo "verifying signature ${sig_url}"
+    if gpg --batch --verify "${sig_file}" "${tmp}/rsync.tar.gz"; then
+      verified=1
+    else
+      echo "build-rsync: gpg signature check failed (often because the public key is not in your keyring)." >&2
+    fi
+  fi
+fi
+
+if [[ "${verified}" -ne 1 ]]; then
+  if [[ -n "${RSYNC_TARBALL_SHA256}" ]]; then
+    echo "verifying sha256 for ${url}"
+    if command -v sha256sum >/dev/null 2>&1; then
+      echo "${RSYNC_TARBALL_SHA256}  ${tmp}/rsync.tar.gz" | sha256sum -c -
+    elif command -v shasum >/dev/null 2>&1; then
+      echo "${RSYNC_TARBALL_SHA256}  ${tmp}/rsync.tar.gz" | shasum -a 256 -c -
+    else
+      echo "build-rsync: need sha256sum or shasum for checksum verification" >&2
+      exit 2
+    fi
+  else
+    echo "build-rsync: could not verify ${url} (no usable gpg signature, and RSYNC_TARBALL_SHA256 is empty)." >&2
+    echo "Set RSYNC_TARBALL_SHA256= or install gpg with a trusted key for the rsync signing identity."
>&2 + exit 2 + fi +fi + +tar -C "${tmp}" -xzf "${tmp}/rsync.tar.gz" +extract_dir="$(tar -tzf "${tmp}/rsync.tar.gz" | head -n 1 | cut -d/ -f1)" +cd "${tmp}/${extract_dir}" + +CC=musl-gcc CFLAGS="-O2" LDFLAGS="-static" ./configure --disable-xxhash --disable-zstd --disable-lz4 +make -j"$(getconf _NPROCESSORS_ONLN 2>/dev/null || echo 2)" + +mkdir -p "$(dirname "${out}")" +cp rsync "${out}" +chmod +x "${out}" +echo "built ${out}" diff --git a/cli/scripts/ml_completion.bash b/cli/scripts/ml_completion.bash index 6963ea2..9e7fa16 100644 --- a/cli/scripts/ml_completion.bash +++ b/cli/scripts/ml_completion.bash @@ -15,7 +15,7 @@ _ml_completions() global_opts="--help --verbose --quiet --monitor" # Top-level subcommands - cmds="init sync queue status monitor cancel prune watch dataset experiment" + cmds="init sync queue requeue status monitor cancel prune watch dataset experiment" # If completing the subcommand itself if [[ ${COMP_CWORD} -eq 1 ]]; then @@ -41,8 +41,51 @@ _ml_completions() COMPREPLY=( $(compgen -d -- "${cur}") ) ;; queue) - # Suggest common job names (static for now) - COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") ) + queue_opts="--commit --priority --cpu --memory --gpu --gpu-memory --snapshot-id --snapshot-sha256 --args -- ${global_opts}" + case "${prev}" in + --priority) + COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") ) + ;; + --cpu|--memory|--gpu) + COMPREPLY=( $(compgen -W "0 1 2 4 8 16 32" -- "${cur}") ) + ;; + --gpu-memory) + COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") ) + ;; + --commit|--snapshot-id|--snapshot-sha256|--args) + # Free-form; no special completion + ;; + *) + if [[ "${cur}" == --* ]]; then + COMPREPLY=( $(compgen -W "${queue_opts}" -- "${cur}") ) + else + # Suggest common job names (static for now) + COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") ) + fi + ;; + esac + ;; + requeue) + requeue_opts="--name --priority --cpu --memory --gpu --gpu-memory --args -- ${global_opts}" + case 
"${prev}" in + --priority) + COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") ) + ;; + --cpu|--memory|--gpu) + COMPREPLY=( $(compgen -W "0 1 2 4 8 16 32" -- "${cur}") ) + ;; + --gpu-memory) + COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") ) + ;; + --name|--args) + # Free-form; no special completion + ;; + *) + if [[ "${cur}" == --* ]]; then + COMPREPLY=( $(compgen -W "${requeue_opts}" -- "${cur}") ) + fi + ;; + esac ;; status) COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") ) diff --git a/cli/scripts/ml_completion.zsh b/cli/scripts/ml_completion.zsh index 6c95d6f..0f0058e 100644 --- a/cli/scripts/ml_completion.zsh +++ b/cli/scripts/ml_completion.zsh @@ -10,6 +10,7 @@ _ml() { 'init:Setup configuration interactively' 'sync:Sync project to server' 'queue:Queue job for execution' + 'requeue:Re-submit a previous run/commit' 'status:Get system status' 'monitor:Launch TUI via SSH' 'cancel:Cancel running job' @@ -53,7 +54,33 @@ _ml() { '--verbose[Enable verbose output]' \ '--quiet[Suppress non-error output]' \ '--monitor[Monitor progress]' \ - '1:job name:' + '--commit[Commit id (40-hex) or unique prefix (>=7)]:commit id:' \ + '--priority[Priority (0-255)]:priority:' \ + '--cpu[CPU cores]:cpu:' \ + '--memory[Memory (GB)]:memory:' \ + '--gpu[GPU count]:gpu:' \ + '--gpu-memory[GPU memory]:gpu memory:' \ + '--snapshot-id[Snapshot id]:snapshot id:' \ + '--snapshot-sha256[Snapshot sha256]:snapshot sha256:' \ + '--args[Runner args string]:args:' \ + '1:job name:' \ + '*:args separator:(--)' + ;; + requeue) + _arguments -C \ + '--help[Show requeue help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' \ + '--name[Override job name]:job name:' \ + '--priority[Priority (0-255)]:priority:' \ + '--cpu[CPU cores]:cpu:' \ + '--memory[Memory (GB)]:memory:' \ + '--gpu[GPU count]:gpu:' \ + '--gpu-memory[GPU memory]:gpu memory:' \ + '--args[Runner args string]:args:' \ + 
'1:commit_id|run_id|task_id|path:' \ + '*:args separator:(--)' ;; status) _arguments -C \ diff --git a/cli/src.zig b/cli/src.zig index 3148a36..023d7c2 100644 --- a/cli/src.zig +++ b/cli/src.zig @@ -3,4 +3,5 @@ pub const commands = @import("src/commands.zig"); pub const net = @import("src/net.zig"); pub const utils = @import("src/utils.zig"); pub const config = @import("src/config.zig"); +pub const Config = @import("src/config.zig").Config; pub const errors = @import("src/errors.zig"); diff --git a/cli/src/assets/README.md b/cli/src/assets/README.md index fcffe05..432ce6a 100644 --- a/cli/src/assets/README.md +++ b/cli/src/assets/README.md @@ -5,7 +5,7 @@ This directory contains rsync binaries for the ML CLI: - `rsync_placeholder.bin` - Wrapper script for dev builds (calls system rsync) -- `rsync_release.bin` - Full static rsync binary for release builds (not in repo) +- `rsync_release__.bin` - Static rsync binary for release builds (not in repo) ## Build Modes @@ -16,44 +16,35 @@ This directory contains rsync binaries for the ML CLI: - Requires rsync installed on the system ### Release Builds (ReleaseSmall, ReleaseFast) -- Uses `rsync_release.bin` (300-500KB static binary) +- Uses `rsync_release__.bin` (static binary) - Fully self-contained, no dependencies - Results in ~450-650KB CLI binary - Works on any system without rsync installed ## Preparing Release Binaries -### Option 1: Download Pre-built Static Rsync +### Option 1: Build from Official Rsync Source (recommended) -For macOS ARM64: +On Linux: ```bash -cd cli/src/assets -curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-macos-arm64 -o rsync_release.bin -chmod +x rsync_release.bin +cd cli +make build-rsync RSYNC_VERSION=3.3.0 ``` -For Linux x86_64: -```bash -cd cli/src/assets -curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-linux-x86_64 -o rsync_release.bin -chmod +x rsync_release.bin -``` - -### Option 2: Build Static Rsync Yourself +### Option 2: Build Rsync 
Yourself ```bash -# Clone rsync -git clone https://github.com/WayneD/rsync.git -cd rsync +# Download official source +curl -fsSL https://download.samba.org/pub/rsync/src/rsync-3.3.0.tar.gz -o rsync.tar.gz +tar -xzf rsync.tar.gz +cd rsync-3.3.0 -# Configure for static build -./configure CFLAGS="-static" LDFLAGS="-static" --disable-xxhash --disable-zstd - -# Build +# Configure & build +./configure --disable-xxhash --disable-zstd --disable-lz4 make -# Copy to assets -cp rsync ../fetch_ml/cli/src/assets/rsync_release.bin +# Copy to assets (example) +cp rsync ../fetch_ml/cli/src/assets/rsync_release_linux_x86_64.bin ``` ### Option 3: Use System Rsync (Temporary) @@ -61,21 +52,22 @@ cp rsync ../fetch_ml/cli/src/assets/rsync_release.bin For testing release builds without a static binary: ```bash cd cli/src/assets -cp rsync_placeholder.bin rsync_release.bin +cp rsync_placeholder.bin rsync_release_linux_x86_64.bin ``` This will still use the wrapper, but allows builds to complete. ## Verification -After placing rsync_release.bin: +After placing the appropriate `rsync_release__.bin`: ```bash -# Verify it's executable -file cli/src/assets/rsync_release.bin -# Test it -./cli/src/assets/rsync_release.bin --version +# Verify it's executable (example) +file cli/src/assets/rsync_release_linux_x86_64.bin + +# Test it (example) +./cli/src/assets/rsync_release_linux_x86_64.bin --version # Build release cd cli diff --git a/cli/src/commands.zig b/cli/src/commands.zig index e543c34..e5c1e16 100644 --- a/cli/src/commands.zig +++ b/cli/src/commands.zig @@ -1,14 +1,16 @@ -// Commands module - exports all command modules -pub const queue = @import("commands/queue.zig"); -pub const sync = @import("commands/sync.zig"); -pub const status = @import("commands/status.zig"); -pub const dataset = @import("commands/dataset.zig"); -pub const jupyter = @import("commands/jupyter.zig"); -pub const init = @import("commands/init.zig"); -pub const info = @import("commands/info.zig"); -pub const monitor = 
@import("commands/monitor.zig"); +pub const annotate = @import("commands/annotate.zig"); pub const cancel = @import("commands/cancel.zig"); -pub const prune = @import("commands/prune.zig"); -pub const watch = @import("commands/watch.zig"); +pub const dataset = @import("commands/dataset.zig"); pub const experiment = @import("commands/experiment.zig"); +pub const info = @import("commands/info.zig"); +pub const init = @import("commands/init.zig"); +pub const jupyter = @import("commands/jupyter.zig"); +pub const monitor = @import("commands/monitor.zig"); +pub const narrative = @import("commands/narrative.zig"); +pub const prune = @import("commands/prune.zig"); +pub const queue = @import("commands/queue.zig"); +pub const requeue = @import("commands/requeue.zig"); +pub const status = @import("commands/status.zig"); +pub const sync = @import("commands/sync.zig"); pub const validate = @import("commands/validate.zig"); +pub const watch = @import("commands/watch.zig"); diff --git a/cli/src/commands/annotate.zig b/cli/src/commands/annotate.zig new file mode 100644 index 0000000..67158de --- /dev/null +++ b/cli/src/commands/annotate.zig @@ -0,0 +1,282 @@ +const std = @import("std"); +const colors = @import("../utils/colors.zig"); +const Config = @import("../config.zig").Config; +const crypto = @import("../utils/crypto.zig"); +const io = @import("../utils/io.zig"); +const ws = @import("../net/ws/client.zig"); + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + if (args.len == 0) { + try printUsage(); + return error.InvalidArgs; + } + + if (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h")) { + try printUsage(); + return; + } + + const target = args[0]; + + var author: []const u8 = ""; + var note: ?[]const u8 = null; + var base_override: ?[]const u8 = null; + var json_mode: bool = false; + + var i: usize = 1; + while (i < args.len) : (i += 1) { + const a = args[i]; + if (std.mem.eql(u8, a, "--author")) { + if (i + 1 >= args.len) { + 
colors.printError("Missing value for --author\n", .{}); + return error.InvalidArgs; + } + author = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--note")) { + if (i + 1 >= args.len) { + colors.printError("Missing value for --note\n", .{}); + return error.InvalidArgs; + } + note = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--base")) { + if (i + 1 >= args.len) { + colors.printError("Missing value for --base\n", .{}); + return error.InvalidArgs; + } + base_override = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--json")) { + json_mode = true; + } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) { + try printUsage(); + return; + } else if (std.mem.startsWith(u8, a, "--")) { + colors.printError("Unknown option: {s}\n", .{a}); + return error.InvalidArgs; + } else { + colors.printError("Unexpected argument: {s}\n", .{a}); + return error.InvalidArgs; + } + } + + if (note == null or std.mem.trim(u8, note.?, " \t\r\n").len == 0) { + colors.printError("--note is required\n", .{}); + try printUsage(); + return error.InvalidArgs; + } + + const cfg = try Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + const resolved_base = base_override orelse cfg.worker_base; + + const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| { + if (err == error.FileNotFound) { + colors.printError( + "Could not locate run_manifest.json for '{s}'. 
Provide a path, or use --base to scan finished/failed/running/pending.\n", + .{target}, + ); + } + return err; + }; + defer allocator.free(manifest_path); + + const job_name = try readJobNameFromManifest(allocator, manifest_path); + defer allocator.free(job_name); + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + try client.sendAnnotateRun(job_name, author, note.?, api_key_hash); + + if (json_mode) { + const msg = try client.receiveMessage(allocator); + defer allocator.free(msg); + + const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch { + var out = io.stdoutWriter(); + try out.print("{s}\n", .{msg}); + return error.InvalidPacket; + }; + defer { + if (packet.success_message) |m| allocator.free(m); + if (packet.error_message) |m| allocator.free(m); + if (packet.error_details) |m| allocator.free(m); + } + + const Result = struct { + ok: bool, + job_name: []const u8, + message: []const u8, + error_code: ?u8 = null, + error_message: ?[]const u8 = null, + details: ?[]const u8 = null, + }; + + var out = io.stdoutWriter(); + if (packet.packet_type == .error_packet) { + const res = Result{ + .ok = false, + .job_name = job_name, + .message = "", + .error_code = @intFromEnum(packet.error_code.?), + .error_message = packet.error_message orelse "", + .details = packet.error_details orelse "", + }; + try out.print("{f}\n", .{std.json.fmt(res, .{})}); + return error.CommandFailed; + } + + const res = Result{ + .ok = true, + .job_name = job_name, + .message = packet.success_message orelse "", + }; + try out.print("{f}\n", .{std.json.fmt(res, .{})}); + return; + } + + try client.receiveAndHandleResponse(allocator, "Annotate"); + + 
colors.printSuccess("Annotation added\n", .{}); + colors.printInfo("Job: {s}\n", .{job_name}); +} + +fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 { + const data = try readFileAlloc(allocator, manifest_path); + defer allocator.free(data); + + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{}); + defer parsed.deinit(); + + if (parsed.value != .object) return error.InvalidManifest; + const root = parsed.value.object; + + const job_name = jsonGetString(root, "job_name") orelse ""; + if (std.mem.trim(u8, job_name, " \t\r\n").len == 0) { + return error.InvalidManifest; + } + return allocator.dupe(u8, job_name); +} + +fn resolveManifestPathWithBase( + allocator: std.mem.Allocator, + input: []const u8, + base: []const u8, +) ![]u8 { + var cwd = std.fs.cwd(); + + if (std.fs.path.isAbsolute(input)) { + if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| { + var mutable_dir = dir; + defer mutable_dir.close(); + return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" }); + } + if (std.fs.openFileAbsolute(input, .{}) catch null) |file| { + var mutable_file = file; + defer mutable_file.close(); + return allocator.dupe(u8, input); + } + return resolveManifestPathById(allocator, input, base); + } + + const stat = cwd.statFile(input) catch |err| { + if (err == error.FileNotFound) { + return resolveManifestPathById(allocator, input, base); + } + return err; + }; + + if (stat.kind == .directory) { + return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" }); + } + + return allocator.dupe(u8, input); +} + +fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 { + if (std.mem.trim(u8, id, " \t\r\n").len == 0) { + return error.FileNotFound; + } + if (base_path.len == 0) { + return error.FileNotFound; + } + + const roots = [_][]const u8{ "finished", "failed", "running", "pending" }; + for (roots) |root| { + const 
root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root }); + defer allocator.free(root_path); + + var dir = if (std.fs.path.isAbsolute(root_path)) + (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue) + else + (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue); + defer dir.close(); + + var it = dir.iterate(); + while (try it.next()) |entry| { + if (entry.kind != .directory) continue; + + const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name }); + defer allocator.free(run_dir); + const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" }); + defer allocator.free(manifest_path); + + const file = if (std.fs.path.isAbsolute(manifest_path)) + (std.fs.openFileAbsolute(manifest_path, .{}) catch continue) + else + (std.fs.cwd().openFile(manifest_path, .{}) catch continue); + defer file.close(); + + const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue; + defer allocator.free(data); + + const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue; + defer parsed.deinit(); + if (parsed.value != .object) continue; + + const obj = parsed.value.object; + const run_id = jsonGetString(obj, "run_id") orelse ""; + const task_id = jsonGetString(obj, "task_id") orelse ""; + if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) { + return allocator.dupe(u8, manifest_path); + } + } + } + + return error.FileNotFound; +} + +fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 { + var file = if (std.fs.path.isAbsolute(path)) + try std.fs.openFileAbsolute(path, .{}) + else + try std.fs.cwd().openFile(path, .{}); + defer file.close(); + + return file.readToEndAlloc(allocator, 1024 * 1024); +} + +fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { + const v = obj.get(key) orelse return null; + if (v != .string) return null; + return v.string; +} + +fn 
printUsage() !void {
+    colors.printInfo("Usage: ml annotate --note [--author ] [--base ] [--json]\n", .{});
+    colors.printInfo("\nExamples:\n", .{});
+    colors.printInfo("  ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
+    colors.printInfo("  ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
+}
diff --git a/cli/src/commands/cancel.zig b/cli/src/commands/cancel.zig
index addc955..90ceebc 100644
--- a/cli/src/commands/cancel.zig
+++ b/cli/src/commands/cancel.zig
@@ -1,6 +1,6 @@
 const std = @import("std");
 const Config = @import("../config.zig").Config;
-const ws = @import("../net/ws.zig");
+const ws = @import("../net/ws/client.zig");
 const crypto = @import("../utils/crypto.zig");
 const logging = @import("../utils/logging.zig");
 const colors = @import("../utils/colors.zig");
diff --git a/cli/src/commands/dataset.zig b/cli/src/commands/dataset.zig
index 2834009..73ca487 100644
--- a/cli/src/commands/dataset.zig
+++ b/cli/src/commands/dataset.zig
@@ -1,6 +1,6 @@
 const std = @import("std");
 const Config = @import("../config.zig").Config;
-const ws = @import("../net/ws.zig");
+const ws = @import("../net/ws/client.zig");
 const colors = @import("../utils/colors.zig");
 const logging = @import("../utils/logging.zig");
 const crypto = @import("../utils/crypto.zig");
@@ -18,7 +18,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
     }
 
     var options = DatasetOptions{};
-    // Parse global flags: --dry-run, --validate, --json
 
     var positional = std.ArrayList([]const u8).initCapacity(allocator, args.len) catch |err| {
         return err;
@@ -44,36 +43,43 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
         }
     }
 
     if (positional.items.len == 0) {
         printUsage();
         return error.InvalidArgs;
     }
     const action = positional.items[0];
-    if (std.mem.eql(u8, action, "list")) {
-        try listDatasets(allocator, &options);
-    } else if (std.mem.eql(u8, action, "register")) {
-        if (positional.items.len < 3) {
-
colors.printError("Usage: ml dataset register \n", .{});
+    switch (positional.items.len) {
+        0 => {
+            printUsage();
             return error.InvalidArgs;
-        }
-        try registerDataset(allocator, positional.items[1], positional.items[2], &options);
-    } else if (std.mem.eql(u8, action, "info")) {
-        if (positional.items.len < 2) {
-            colors.printError("Usage: ml dataset info \n", .{});
+        },
+        1 => {
+            if (!std.mem.eql(u8, action, "list")) {
+                colors.printError("Unknown action: {s}\n", .{action});
+                return error.InvalidArgs;
+            } else try listDatasets(allocator, &options);
+        },
+        2 => {
+            if (std.mem.eql(u8, action, "info")) {
+                try showDatasetInfo(allocator, positional.items[1], &options);
+                return;
+            } else if (!std.mem.eql(u8, action, "search")) {
+                colors.printError("Unknown action: {s}\n", .{action});
+                return error.InvalidArgs;
+            } else try searchDatasets(allocator, positional.items[1], &options);
+        },
+        3 => {
+            if (!std.mem.eql(u8, action, "register")) {
+                colors.printError("Unknown action: {s}\n", .{action});
+                return error.InvalidArgs;
+            } else try registerDataset(allocator, positional.items[1], positional.items[2], &options);
+        },
+        else => {
+            colors.printError("Unknown action: {s}\n", .{action});
+            printUsage();
             return error.InvalidArgs;
-        }
-        try showDatasetInfo(allocator, positional.items[1], &options);
-    } else if (std.mem.eql(u8, action, "search")) {
-        if (positional.items.len < 2) {
-            colors.printError("Usage: ml dataset search \n", .{});
-            return error.InvalidArgs;
-        }
-        try searchDatasets(allocator, positional.items[1], &options);
-    } else {
-        colors.printError("Unknown action: {s}\n", .{action});
-        printUsage();
-        return error.InvalidArgs;
+        },
     }
 }
diff --git a/cli/src/commands/experiment.zig b/cli/src/commands/experiment.zig
index 2ca45d6..b0c275f 100644
--- a/cli/src/commands/experiment.zig
+++ b/cli/src/commands/experiment.zig
@@ -1,6 +1,6 @@
 const std = @import("std");
 const config = @import("../config.zig");
-const ws = @import("../net/ws.zig");
+const ws = @import("../net/ws/client.zig");
 const protocol = @import("../net/protocol.zig");
 const history = @import("../utils/history.zig");
 const colors = @import("../utils/colors.zig");
diff
--git a/cli/src/commands/info.zig b/cli/src/commands/info.zig index 03fba4e..126743b 100644 --- a/cli/src/commands/info.zig +++ b/cli/src/commands/info.zig @@ -1,6 +1,7 @@ const std = @import("std"); const colors = @import("../utils/colors.zig"); const Config = @import("../config.zig").Config; +const io = @import("../utils/io.zig"); pub const Options = struct { json: bool = false, @@ -61,7 +62,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { defer allocator.free(data); if (opts.json) { - std.debug.print("{s}\n", .{data}); + var out = io.stdoutWriter(); + try out.print("{s}\n", .{data}); return; } diff --git a/cli/src/commands/jupyter.zig b/cli/src/commands/jupyter.zig index 32bda53..c4509d6 100644 --- a/cli/src/commands/jupyter.zig +++ b/cli/src/commands/jupyter.zig @@ -1,6 +1,6 @@ const std = @import("std"); const colors = @import("../utils/colors.zig"); -const ws = @import("../net/ws.zig"); +const ws = @import("../net/ws/client.zig"); const protocol = @import("../net/protocol.zig"); const crypto = @import("../utils/crypto.zig"); const Config = @import("../config.zig").Config; @@ -87,7 +87,9 @@ fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void .error_packet => { const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); colors.printError("Failed to restore workspace: {s}\n", .{error_msg}); - if (packet.error_message) |msg| { + if (packet.error_details) |details| { + colors.printError("Details: {s}\n", .{details}); + } else if (packet.error_message) |msg| { colors.printError("Details: {s}\n", .{msg}); } }, @@ -138,8 +140,6 @@ pub fn isValidTopLevelAction(action: []const u8) bool { std.mem.eql(u8, action, "list") or std.mem.eql(u8, action, "remove") or std.mem.eql(u8, action, "restore") or - std.mem.eql(u8, action, "workspace") or - std.mem.eql(u8, action, "experiment") or std.mem.eql(u8, action, "package"); } @@ -149,18 +149,18 @@ pub fn defaultWorkspacePath(allocator: 
std.mem.Allocator, name: []const u8) ![]u pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { if (args.len < 1) { - printUsage(); + printUsagePackage(); return; } for (args) |arg| { if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { - printUsage(); + printUsagePackage(); return; } if (std.mem.eql(u8, arg, "--json")) { colors.printError("jupyter does not support --json\n", .{}); - printUsage(); + printUsagePackage(); return error.InvalidArgs; } } @@ -181,10 +181,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { try removeJupyter(allocator, args[1..]); } else if (std.mem.eql(u8, action, "restore")) { try restoreJupyter(allocator, args[1..]); - } else if (std.mem.eql(u8, action, "workspace")) { - try workspaceCommands(args[1..]); - } else if (std.mem.eql(u8, action, "experiment")) { - try experimentCommands(args[1..]); } else if (std.mem.eql(u8, action, "package")) { try packageCommands(args[1..]); } else { @@ -192,12 +188,11 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } } -fn printUsage() void { - colors.printError("Usage: ml jupyter [options]\n", .{}); - colors.printInfo("\nActions:\n", .{}); - colors.printInfo(" create|start|stop|status|list|remove|restore\n", .{}); - colors.printInfo(" workspace|experiment|package\n", .{}); - colors.printInfo("\nOptions:\n", .{}); +fn printUsagePackage() void { + colors.printError("Usage: ml jupyter package [options]\n", .{}); + colors.printInfo("Actions:\n", .{}); + colors.printInfo(" list\n", .{}); + colors.printInfo("Options:\n", .{}); colors.printInfo(" --help, -h Show this help message\n", .{}); } @@ -340,7 +335,9 @@ fn startJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void { .error_packet => { const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); colors.printError("Failed to start service: {s}\n", .{error_msg}); - if (packet.error_message) |msg| { + if (packet.error_details) 
|details| { + colors.printError("Details: {s}\n", .{details}); + } else if (packet.error_message) |msg| { colors.printError("Details: {s}\n", .{msg}); } }, @@ -416,7 +413,9 @@ fn stopJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void { .error_packet => { const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); colors.printError("Failed to stop service: {s}\n", .{error_msg}); - if (packet.error_message) |msg| { + if (packet.error_details) |details| { + colors.printError("Details: {s}\n", .{details}); + } else if (packet.error_message) |msg| { colors.printError("Details: {s}\n", .{msg}); } }, @@ -533,7 +532,9 @@ fn removeJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void { .error_packet => { const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); colors.printError("Failed to remove service: {s}\n", .{error_msg}); - if (packet.error_message) |msg| { + if (packet.error_details) |details| { + colors.printError("Details: {s}\n", .{details}); + } else if (packet.error_message) |msg| { colors.printError("Details: {s}\n", .{msg}); } }, @@ -662,7 +663,9 @@ fn listServices(allocator: std.mem.Allocator) !void { .error_packet => { const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); colors.printError("Failed to list services: {s}\n", .{error_msg}); - if (packet.error_message) |msg| { + if (packet.error_details) |details| { + colors.printError("Details: {s}\n", .{details}); + } else if (packet.error_message) |msg| { colors.printError("Details: {s}\n", .{msg}); } }, @@ -776,86 +779,133 @@ fn experimentCommands(args: []const []const u8) !void { fn packageCommands(args: []const []const u8) !void { if (args.len < 1) { - colors.printError("Usage: ml jupyter package \n", .{}); + colors.printError("Usage: ml jupyter package \n", .{}); return; } const subcommand = args[0]; - if (std.mem.eql(u8, subcommand, "install")) { + if (std.mem.eql(u8, subcommand, "list")) { if (args.len 
< 2) { - colors.printError("Usage: ml jupyter package install --package [--channel ] [--version ]\n", .{}); + colors.printError("Usage: ml jupyter package list \n", .{}); return; } - // Parse package name from args - var package_name: []const u8 = ""; - var channel: []const u8 = "conda-forge"; - var version: []const u8 = "latest"; - - var i: usize = 0; - while (i < args.len) { - if (std.mem.eql(u8, args[i], "--package") and i + 1 < args.len) { - package_name = args[i + 1]; - i += 2; - } else if (std.mem.eql(u8, args[i], "--channel") and i + 1 < args.len) { - channel = args[i + 1]; - i += 2; - } else if (std.mem.eql(u8, args[i], "--version") and i + 1 < args.len) { - version = args[i + 1]; - i += 2; - } else { - i += 1; - } + var service_name: []const u8 = ""; + if (std.mem.eql(u8, args[1], "--name") and args.len >= 3) { + service_name = args[2]; + } else { + service_name = args[1]; } - - if (package_name.len == 0) { - colors.printError("Package name is required\n", .{}); + if (service_name.len == 0) { + colors.printError("Service name is required\n", .{}); return; } - // Security validations - if (!validatePackageName(package_name)) { - colors.printError("Invalid package name: {s}. 
Only alphanumeric characters, underscores, hyphens, and dots are allowed.\n", .{package_name}); - return; + const allocator = std.heap.page_allocator; + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); } - if (isPackageBlocked(package_name)) { - colors.printError("Package '{s}' is blocked by security policy for security reasons.\n", .{package_name}); - colors.printInfo("Blocked packages typically include network libraries that could be used for unauthorized data access.\n", .{}); + const protocol_str = if (config.worker_port == 443) "wss" else "ws"; + const url = try std.fmt.allocPrint(allocator, "{s}://{s}:{d}/ws", .{ + protocol_str, + config.worker_host, + config.worker_port, + }); + defer allocator.free(url); + + var client = ws.Client.connect(allocator, url, config.api_key) catch |err| { + colors.printError("Failed to connect to server: {}\n", .{err}); return; + }; + defer client.close(); + + const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); + defer allocator.free(api_key_hash); + + client.sendListJupyterPackages(service_name, api_key_hash) catch |err| { + colors.printError("Failed to send list packages command: {}\n", .{err}); + return; + }; + + const response = client.receiveMessage(allocator) catch |err| { + colors.printError("Failed to receive response: {}\n", .{err}); + return; + }; + defer allocator.free(response); + + const packet = protocol.ResponsePacket.deserialize(response, allocator) catch |err| { + colors.printError("Failed to parse response: {}\n", .{err}); + return; + }; + defer { + if (packet.data_type) |dtype| allocator.free(dtype); + if (packet.data_payload) |payload| allocator.free(payload); + if (packet.error_message) |msg| allocator.free(msg); + if (packet.error_details) |details| allocator.free(details); } - if (!validateChannel(channel)) { - colors.printError("Channel '{s}' is not trusted. 
Allowed channels: conda-forge, defaults, pytorch, nvidia\n", .{channel}); - return; - } + switch (packet.packet_type) { + .data => { + colors.printInfo("Installed packages for {s}:\n", .{service_name}); + if (packet.data_payload) |payload| { + const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch { + std.debug.print("{s}\n", .{payload}); + return; + }; + defer parsed.deinit(); - colors.printInfo("Requesting package installation...\n", .{}); - colors.printInfo("Package: {s}\n", .{package_name}); - colors.printInfo("Version: {s}\n", .{version}); - colors.printInfo("Channel: {s}\n", .{channel}); - colors.printInfo("Security: Package validated against security policies\n", .{}); - colors.printSuccess("Package request created successfully!\n", .{}); - colors.printInfo("Note: Package requires approval from administrator before installation.\n", .{}); - } else if (std.mem.eql(u8, subcommand, "list")) { - colors.printInfo("Installed packages in workspace: ./workspace\n", .{}); - colors.printInfo("Package Name Version Channel Installed By\n", .{}); - colors.printInfo("------------ ------- ------- ------------\n", .{}); - colors.printInfo("numpy 1.21.0 conda-forge user1\n", .{}); - colors.printInfo("pandas 1.3.0 conda-forge user1\n", .{}); - } else if (std.mem.eql(u8, subcommand, "pending")) { - colors.printInfo("Pending package requests for workspace: ./workspace\n", .{}); - colors.printInfo("Package Name Version Channel Requested By Time\n", .{}); - colors.printInfo("------------ ------- ------- ------------ ----\n", .{}); - colors.printInfo("torch 1.9.0 pytorch user3 2023-12-06 10:30\n", .{}); - } else if (std.mem.eql(u8, subcommand, "approve")) { - colors.printInfo("Approving package request: torch\n", .{}); - colors.printSuccess("Package request approved!\n", .{}); - } else if (std.mem.eql(u8, subcommand, "reject")) { - colors.printInfo("Rejecting package request: suspicious-package\n", .{}); - colors.printInfo("Reason: Security policy 
violation\n", .{}); - colors.printSuccess("Package request rejected!\n", .{}); + if (parsed.value != .array) { + std.debug.print("{s}\n", .{payload}); + return; + } + + const pkgs = parsed.value.array; + if (pkgs.items.len == 0) { + colors.printInfo("No packages found.\n", .{}); + return; + } + + colors.printInfo("NAME VERSION SOURCE\n", .{}); + colors.printInfo("---- ------- ------\n", .{}); + + for (pkgs.items) |item| { + if (item != .object) continue; + const obj = item.object; + + var name: []const u8 = ""; + if (obj.get("name")) |v| { + if (v == .string) name = v.string; + } + var version: []const u8 = ""; + if (obj.get("version")) |v| { + if (v == .string) version = v.string; + } + var source: []const u8 = ""; + if (obj.get("source")) |v| { + if (v == .string) source = v.string; + } + + std.debug.print("{s: <30} {s: <22} {s}\n", .{ name, version, source }); + } + } + }, + .error_packet => { + const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); + colors.printError("Failed to list packages: {s}\n", .{error_msg}); + if (packet.error_details) |details| { + colors.printError("Details: {s}\n", .{details}); + } else if (packet.error_message) |msg| { + colors.printError("Details: {s}\n", .{msg}); + } + }, + else => { + colors.printError("Unexpected response type\n", .{}); + }, + } } else { colors.printError("Invalid package command: {s}\n", .{subcommand}); } diff --git a/cli/src/commands/narrative.zig b/cli/src/commands/narrative.zig new file mode 100644 index 0000000..e062a27 --- /dev/null +++ b/cli/src/commands/narrative.zig @@ -0,0 +1,370 @@ +const std = @import("std"); +const colors = @import("../utils/colors.zig"); +const Config = @import("../config.zig").Config; +const crypto = @import("../utils/crypto.zig"); +const io = @import("../utils/io.zig"); +const ws = @import("../net/ws/client.zig"); + +pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { + if (argv.len == 0) { + try printUsage(); + return 
error.InvalidArgs; + } + + const sub = argv[0]; + if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) { + try printUsage(); + return; + } + + if (!std.mem.eql(u8, sub, "set")) { + colors.printError("Unknown subcommand: {s}\n", .{sub}); + try printUsage(); + return error.InvalidArgs; + } + + if (argv.len < 2) { + try printUsage(); + return error.InvalidArgs; + } + + const target = argv[1]; + + var hypothesis: ?[]const u8 = null; + var context: ?[]const u8 = null; + var intent: ?[]const u8 = null; + var expected_outcome: ?[]const u8 = null; + var parent_run: ?[]const u8 = null; + var experiment_group: ?[]const u8 = null; + var tags_csv: ?[]const u8 = null; + var base_override: ?[]const u8 = null; + var json_mode: bool = false; + + var i: usize = 2; + while (i < argv.len) : (i += 1) { + const a = argv[i]; + if (std.mem.eql(u8, a, "--hypothesis")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + hypothesis = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--context")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + context = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--intent")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + intent = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--expected-outcome")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + expected_outcome = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--parent-run")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + parent_run = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--experiment-group")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + experiment_group = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--tags")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + tags_csv = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--base")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + base_override = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--json")) { + json_mode = 
true; + } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) { + try printUsage(); + return; + } else { + colors.printError("Unknown option: {s}\n", .{a}); + return error.InvalidArgs; + } + } + + if (hypothesis == null and context == null and intent == null and expected_outcome == null and parent_run == null and experiment_group == null and tags_csv == null) { + colors.printError("No narrative fields provided.\n", .{}); + return error.InvalidArgs; + } + + const cfg = try Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + const resolved_base = base_override orelse cfg.worker_base; + const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| { + if (err == error.FileNotFound) { + colors.printError( + "Could not locate run_manifest.json for '{s}'. Provide a path, or use --base to scan finished/failed/running/pending.\n", + .{target}, + ); + } + return err; + }; + defer allocator.free(manifest_path); + + const job_name = try readJobNameFromManifest(allocator, manifest_path); + defer allocator.free(job_name); + + const patch_json = try buildPatchJSON( + allocator, + hypothesis, + context, + intent, + expected_outcome, + parent_run, + experiment_group, + tags_csv, + ); + defer allocator.free(patch_json); + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + try client.sendSetRunNarrative(job_name, patch_json, api_key_hash); + + if (json_mode) { + const msg = try client.receiveMessage(allocator); + defer allocator.free(msg); + + const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch { + var out = io.stdoutWriter(); + try out.print("{s}\n", .{msg}); + return 
error.InvalidPacket; + }; + defer { + if (packet.success_message) |m| allocator.free(m); + if (packet.error_message) |m| allocator.free(m); + if (packet.error_details) |m| allocator.free(m); + } + + const Result = struct { + ok: bool, + job_name: []const u8, + message: []const u8, + error_code: ?u8 = null, + error_message: ?[]const u8 = null, + details: ?[]const u8 = null, + }; + + var out = io.stdoutWriter(); + if (packet.packet_type == .error_packet) { + const res = Result{ + .ok = false, + .job_name = job_name, + .message = "", + .error_code = @intFromEnum(packet.error_code.?), + .error_message = packet.error_message orelse "", + .details = packet.error_details orelse "", + }; + try out.print("{f}\n", .{std.json.fmt(res, .{})}); + return error.CommandFailed; + } + + const res = Result{ + .ok = true, + .job_name = job_name, + .message = packet.success_message orelse "", + }; + try out.print("{f}\n", .{std.json.fmt(res, .{})}); + return; + } + + try client.receiveAndHandleResponse(allocator, "Narrative"); + + colors.printSuccess("Narrative updated\n", .{}); + colors.printInfo("Job: {s}\n", .{job_name}); +} + +fn buildPatchJSON( + allocator: std.mem.Allocator, + hypothesis: ?[]const u8, + context: ?[]const u8, + intent: ?[]const u8, + expected_outcome: ?[]const u8, + parent_run: ?[]const u8, + experiment_group: ?[]const u8, + tags_csv: ?[]const u8, +) ![]u8 { + var out = std.ArrayList(u8).initCapacity(allocator, 256) catch return error.OutOfMemory; + defer out.deinit(allocator); + + var tags_list = std.ArrayList([]const u8).initCapacity(allocator, 8) catch return error.OutOfMemory; + defer tags_list.deinit(allocator); + + if (tags_csv) |csv| { + var it = std.mem.splitScalar(u8, csv, ','); + while (it.next()) |part| { + const trimmed = std.mem.trim(u8, part, " \t\r\n"); + if (trimmed.len == 0) continue; + try tags_list.append(allocator, trimmed); + } + } + + const Patch = struct { + hypothesis: ?[]const u8 = null, + context: ?[]const u8 = null, + intent: ?[]const u8 
= null, + expected_outcome: ?[]const u8 = null, + parent_run: ?[]const u8 = null, + experiment_group: ?[]const u8 = null, + tags: ?[]const []const u8 = null, + }; + + const patch = Patch{ + .hypothesis = hypothesis, + .context = context, + .intent = intent, + .expected_outcome = expected_outcome, + .parent_run = parent_run, + .experiment_group = experiment_group, + .tags = if (tags_list.items.len > 0) tags_list.items else null, + }; + + const writer = out.writer(allocator); + try writer.print("{f}", .{std.json.fmt(patch, .{})}); + return out.toOwnedSlice(allocator); +} + +fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 { + const data = try readFileAlloc(allocator, manifest_path); + defer allocator.free(data); + + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{}); + defer parsed.deinit(); + + if (parsed.value != .object) return error.InvalidManifest; + const root = parsed.value.object; + + const job_name = jsonGetString(root, "job_name") orelse ""; + if (std.mem.trim(u8, job_name, " \t\r\n").len == 0) { + return error.InvalidManifest; + } + return allocator.dupe(u8, job_name); +} + +fn resolveManifestPathWithBase(allocator: std.mem.Allocator, input: []const u8, base: []const u8) ![]u8 { + var cwd = std.fs.cwd(); + + if (std.fs.path.isAbsolute(input)) { + if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| { + var mutable_dir = dir; + defer mutable_dir.close(); + return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" }); + } + if (std.fs.openFileAbsolute(input, .{}) catch null) |file| { + var mutable_file = file; + defer mutable_file.close(); + return allocator.dupe(u8, input); + } + return resolveManifestPathById(allocator, input, base); + } + + const stat = cwd.statFile(input) catch |err| { + if (err == error.FileNotFound) { + return resolveManifestPathById(allocator, input, base); + } + return err; + }; + + if (stat.kind == .directory) { + return 
std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" }); + } + + return allocator.dupe(u8, input); +} + +fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 { + if (std.mem.trim(u8, id, " \t\r\n").len == 0) { + return error.FileNotFound; + } + if (base_path.len == 0) { + return error.FileNotFound; + } + + const roots = [_][]const u8{ "finished", "failed", "running", "pending" }; + for (roots) |root| { + const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root }); + defer allocator.free(root_path); + + var dir = if (std.fs.path.isAbsolute(root_path)) + (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue) + else + (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue); + defer dir.close(); + + var it = dir.iterate(); + while (try it.next()) |entry| { + if (entry.kind != .directory) continue; + + const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name }); + defer allocator.free(run_dir); + const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" }); + defer allocator.free(manifest_path); + + const file = if (std.fs.path.isAbsolute(manifest_path)) + (std.fs.openFileAbsolute(manifest_path, .{}) catch continue) + else + (std.fs.cwd().openFile(manifest_path, .{}) catch continue); + defer file.close(); + + const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue; + defer allocator.free(data); + + const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue; + defer parsed.deinit(); + if (parsed.value != .object) continue; + + const obj = parsed.value.object; + const run_id = jsonGetString(obj, "run_id") orelse ""; + const task_id = jsonGetString(obj, "task_id") orelse ""; + if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) { + return allocator.dupe(u8, manifest_path); + } + } + } + + return error.FileNotFound; +} + 
+fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 { + var file = if (std.fs.path.isAbsolute(path)) + try std.fs.openFileAbsolute(path, .{}) + else + try std.fs.cwd().openFile(path, .{}); + defer file.close(); + + return file.readToEndAlloc(allocator, 1024 * 1024); +} + +fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { + const v = obj.get(key) orelse return null; + if (v != .string) return null; + return v.string; +} + +fn printUsage() !void { + colors.printInfo("Usage: ml narrative set [fields]\n", .{}); + colors.printInfo("\nFields:\n", .{}); + colors.printInfo(" --hypothesis \"...\"\n", .{}); + colors.printInfo(" --context \"...\"\n", .{}); + colors.printInfo(" --intent \"...\"\n", .{}); + colors.printInfo(" --expected-outcome \"...\"\n", .{}); + colors.printInfo(" --parent-run \n", .{}); + colors.printInfo(" --experiment-group \n", .{}); + colors.printInfo(" --tags a,b,c\n", .{}); + colors.printInfo(" --base \n", .{}); + colors.printInfo(" --json\n", .{}); +} diff --git a/cli/src/commands/prune.zig b/cli/src/commands/prune.zig index 3ec4f5f..273120f 100644 --- a/cli/src/commands/prune.zig +++ b/cli/src/commands/prune.zig @@ -1,6 +1,6 @@ const std = @import("std"); const Config = @import("../config.zig").Config; -const ws = @import("../net/ws.zig"); +const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); const logging = @import("../utils/logging.zig"); diff --git a/cli/src/commands/queue.zig b/cli/src/commands/queue.zig index a688332..de312b9 100644 --- a/cli/src/commands/queue.zig +++ b/cli/src/commands/queue.zig @@ -1,6 +1,6 @@ const std = @import("std"); const Config = @import("../config.zig").Config; -const ws = @import("../net/ws.zig"); +const ws = @import("../net/ws/client.zig"); const colors = @import("../utils/colors.zig"); const history = @import("../utils/history.zig"); const crypto = @import("../utils/crypto.zig"); @@ -42,6 +42,43 @@ pub const QueueOptions = struct { 
gpu_memory: ?[]const u8 = null, }; +fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 { + if (input.len < 7 or input.len > 40) return error.InvalidArgs; + for (input) |c| { + if (!std.ascii.isHex(c)) return error.InvalidArgs; + } + + if (input.len == 40) { + return allocator.dupe(u8, input); + } + + var dir = if (std.fs.path.isAbsolute(base_path)) + try std.fs.openDirAbsolute(base_path, .{ .iterate = true }) + else + try std.fs.cwd().openDir(base_path, .{ .iterate = true }); + defer dir.close(); + + var it = dir.iterate(); + var found: ?[]u8 = null; + errdefer if (found) |s| allocator.free(s); + + while (try it.next()) |entry| { + if (entry.kind != .directory) continue; + const name = entry.name; + if (name.len != 40) continue; + if (!std.mem.startsWith(u8, name, input)) continue; + for (name) |c| { + if (!std.ascii.isHex(c)) break; + } else { + if (found != null) return error.InvalidArgs; + found = try allocator.dupe(u8, name); + } + } + + if (found) |s| return s; + return error.FileNotFound; +} + pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { if (args.len == 0) { try printUsage(); @@ -64,6 +101,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var priority: u8 = 5; var snapshot_id: ?[]const u8 = null; var snapshot_sha256: ?[]const u8 = null; + var args_override: ?[]const u8 = null; + var note_override: ?[]const u8 = null; // Load configuration to get defaults const config = try Config.load(allocator); @@ -88,10 +127,33 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var tracking = TrackingConfig{}; var has_tracking = false; + // Support passing runner args after "--". + var sep_index: ?usize = null; + for (args, 0..) |a, idx| { + if (std.mem.eql(u8, a, "--")) { + sep_index = idx; + break; + } + } + const pre = args[0..(sep_index orelse args.len)]; + const post = if (sep_index) |si| args[(si + 1)..] 
else args[0..0]; + + var args_joined: []const u8 = ""; + if (post.len > 0) { + var buf: std.ArrayList(u8) = .{}; + defer buf.deinit(allocator); + for (post, 0..) |a, j| { + if (j > 0) try buf.append(allocator, ' '); + try buf.appendSlice(allocator, a); + } + args_joined = try buf.toOwnedSlice(allocator); + } + defer if (post.len > 0) allocator.free(args_joined); + // Parse arguments - separate job names from flags var i: usize = 0; - while (i < args.len) : (i += 1) { - const arg = args[i]; + while (i < pre.len) : (i += 1) { + const arg = pre[i]; if (std.mem.startsWith(u8, arg, "--") or std.mem.eql(u8, arg, "-h")) { // Parse flags @@ -99,15 +161,21 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { try printUsage(); return; } - if (std.mem.eql(u8, arg, "--commit") and i + 1 < args.len) { + if (std.mem.eql(u8, arg, "--commit") and i + 1 < pre.len) { if (commit_id_override != null) { allocator.free(commit_id_override.?); } - const commit_hex = args[i + 1]; - if (commit_hex.len != 40) { - colors.printError("Invalid commit id: expected 40-char hex string\n", .{}); + const commit_in = pre[i + 1]; + const commit_hex = resolveCommitHexOrPrefix(allocator, config.worker_base, commit_in) catch |err| { + if (err == error.FileNotFound) { + colors.printError("No commit matches prefix: {s}\n", .{commit_in}); + return error.InvalidArgs; + } + colors.printError("Invalid commit id\n", .{}); return error.InvalidArgs; - } + }; + defer allocator.free(commit_hex); + const commit_bytes = crypto.decodeHex(allocator, commit_hex) catch { colors.printError("Invalid commit id: must be hex\n", .{}); return error.InvalidArgs; @@ -119,35 +187,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } commit_id_override = commit_bytes; i += 1; - } else if (std.mem.eql(u8, arg, "--priority") and i + 1 < args.len) { - priority = try std.fmt.parseInt(u8, args[i + 1], 10); + } else if (std.mem.eql(u8, arg, "--priority") and i + 1 < pre.len) { + 
priority = try std.fmt.parseInt(u8, pre[i + 1], 10); i += 1; } else if (std.mem.eql(u8, arg, "--mlflow")) { tracking.mlflow = TrackingConfig.MLflowConfig{}; has_tracking = true; - } else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < args.len) { + } else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < pre.len) { tracking.mlflow = TrackingConfig.MLflowConfig{ .mode = "remote", - .tracking_uri = args[i + 1], + .tracking_uri = pre[i + 1], }; has_tracking = true; i += 1; } else if (std.mem.eql(u8, arg, "--tensorboard")) { tracking.tensorboard = TrackingConfig.TensorBoardConfig{}; has_tracking = true; - } else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < args.len) { + } else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < pre.len) { if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{}; - tracking.wandb.?.api_key = args[i + 1]; + tracking.wandb.?.api_key = pre[i + 1]; has_tracking = true; i += 1; - } else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < args.len) { + } else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < pre.len) { if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{}; - tracking.wandb.?.project = args[i + 1]; + tracking.wandb.?.project = pre[i + 1]; has_tracking = true; i += 1; - } else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < args.len) { + } else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < pre.len) { if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{}; - tracking.wandb.?.entity = args[i + 1]; + tracking.wandb.?.entity = pre[i + 1]; has_tracking = true; i += 1; } else if (std.mem.eql(u8, arg, "--dry-run")) { @@ -158,23 +226,29 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { options.explain = true; } else if (std.mem.eql(u8, arg, "--json")) { options.json = true; - } else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < args.len) { - options.cpu = try std.fmt.parseInt(u8, args[i + 1], 10); + } else if 
(std.mem.eql(u8, arg, "--cpu") and i + 1 < pre.len) { + options.cpu = try std.fmt.parseInt(u8, pre[i + 1], 10); i += 1; - } else if (std.mem.eql(u8, arg, "--memory") and i + 1 < args.len) { - options.memory = try std.fmt.parseInt(u8, args[i + 1], 10); + } else if (std.mem.eql(u8, arg, "--memory") and i + 1 < pre.len) { + options.memory = try std.fmt.parseInt(u8, pre[i + 1], 10); i += 1; - } else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < args.len) { - options.gpu = try std.fmt.parseInt(u8, args[i + 1], 10); + } else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < pre.len) { + options.gpu = try std.fmt.parseInt(u8, pre[i + 1], 10); i += 1; - } else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < args.len) { - options.gpu_memory = args[i + 1]; + } else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < pre.len) { + options.gpu_memory = pre[i + 1]; i += 1; - } else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < args.len) { - snapshot_id = args[i + 1]; + } else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < pre.len) { + snapshot_id = pre[i + 1]; i += 1; - } else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < args.len) { - snapshot_sha256 = args[i + 1]; + } else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < pre.len) { + snapshot_sha256 = pre[i + 1]; + i += 1; + } else if (std.mem.eql(u8, arg, "--args") and i + 1 < pre.len) { + args_override = pre[i + 1]; + i += 1; + } else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) { + note_override = pre[i + 1]; i += 1; } } else { @@ -227,10 +301,25 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { defer if (commit_id_override) |cid| allocator.free(cid); + const args_str: []const u8 = if (args_override) |a| a else args_joined; + const note_str: []const u8 = if (note_override) |n| n else ""; + for (job_names.items, 0..) 
|job_name, index| { colors.printInfo("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name }); - queueSingleJob(allocator, job_name, commit_id_override, priority, tracking_json, &options, snapshot_id, snapshot_sha256, print_next_steps) catch |err| { + queueSingleJob( + allocator, + job_name, + commit_id_override, + priority, + tracking_json, + &options, + snapshot_id, + snapshot_sha256, + args_str, + note_str, + print_next_steps, + ) catch |err| { colors.printError("Failed to queue job '{s}': {}\n", .{ job_name, err }); failed_jobs.append(allocator, job_name) catch |append_err| { colors.printError("Failed to track failed job: {}\n", .{append_err}); @@ -274,6 +363,8 @@ fn queueSingleJob( options: *const QueueOptions, snapshot_id: ?[]const u8, snapshot_sha256: ?[]const u8, + args_str: []const u8, + note_str: []const u8, print_next_steps: bool, ) !void { const commit_id = blk: { @@ -324,6 +415,33 @@ fn queueSingleJob( options.gpu, options.gpu_memory, ); + } else if (note_str.len > 0 or args_str.len > 0) { + if (note_str.len > 0) { + try client.sendQueueJobWithArgsNoteAndResources( + job_name, + commit_id, + priority, + api_key_hash, + args_str, + note_str, + options.cpu, + options.memory, + options.gpu, + options.gpu_memory, + ); + } else { + try client.sendQueueJobWithArgsAndResources( + job_name, + commit_id, + priority, + api_key_hash, + args_str, + options.cpu, + options.memory, + options.gpu, + options.gpu_memory, + ); + } } else if (snapshot_id) |sid| { try client.sendQueueJobWithSnapshotAndResources( job_name, @@ -374,6 +492,9 @@ fn printUsage() !void { colors.printInfo(" --memory Memory in GB (default: 8)\n", .{}); colors.printInfo(" --gpu GPU count (default: 0)\n", .{}); colors.printInfo(" --gpu-memory GPU memory budget (default: auto)\n", .{}); + colors.printInfo(" --args Extra runner args (sent to worker as task.Args)\n", .{}); + colors.printInfo(" --note Human notes (stored in run manifest as metadata.note)\n", .{}); + 
colors.printInfo(" -- Extra runner args (alternative to --args)\n", .{}); colors.printInfo("\nSpecial Modes:\n", .{}); colors.printInfo(" --dry-run Show what would be submitted\n", .{}); colors.printInfo(" --validate Validate experiment without submitting\n", .{}); diff --git a/cli/src/commands/requeue.zig b/cli/src/commands/requeue.zig new file mode 100644 index 0000000..06e7963 --- /dev/null +++ b/cli/src/commands/requeue.zig @@ -0,0 +1,345 @@ +const std = @import("std"); +const colors = @import("../utils/colors.zig"); +const Config = @import("../config.zig").Config; +const crypto = @import("../utils/crypto.zig"); +const ws = @import("../net/ws/client.zig"); + +pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { + if (argv.len == 0) { + try printUsage(); + return error.InvalidArgs; + } + if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) { + try printUsage(); + return; + } + + const target = argv[0]; + + // Split args at "--". + var sep_index: ?usize = null; + for (argv, 0..) |a, i| { + if (std.mem.eql(u8, a, "--")) { + sep_index = i; + break; + } + } + const pre = argv[1..(sep_index orelse argv.len)]; + const post = if (sep_index) |i| argv[(i + 1)..] 
else argv[0..0]; + + const cfg = try Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + // Defaults + var job_name_override: ?[]const u8 = null; + var priority: u8 = cfg.default_priority; + var cpu: u8 = cfg.default_cpu; + var memory: u8 = cfg.default_memory; + var gpu: u8 = cfg.default_gpu; + var gpu_memory: ?[]const u8 = cfg.default_gpu_memory; + var args_override: ?[]const u8 = null; + var note_override: ?[]const u8 = null; + + var i: usize = 0; + while (i < pre.len) : (i += 1) { + const a = pre[i]; + if (std.mem.eql(u8, a, "--name") and i + 1 < pre.len) { + job_name_override = pre[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--priority") and i + 1 < pre.len) { + priority = try std.fmt.parseInt(u8, pre[i + 1], 10); + i += 1; + } else if (std.mem.eql(u8, a, "--cpu") and i + 1 < pre.len) { + cpu = try std.fmt.parseInt(u8, pre[i + 1], 10); + i += 1; + } else if (std.mem.eql(u8, a, "--memory") and i + 1 < pre.len) { + memory = try std.fmt.parseInt(u8, pre[i + 1], 10); + i += 1; + } else if (std.mem.eql(u8, a, "--gpu") and i + 1 < pre.len) { + gpu = try std.fmt.parseInt(u8, pre[i + 1], 10); + i += 1; + } else if (std.mem.eql(u8, a, "--gpu-memory") and i + 1 < pre.len) { + gpu_memory = pre[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--args") and i + 1 < pre.len) { + args_override = pre[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--note") and i + 1 < pre.len) { + note_override = pre[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) { + try printUsage(); + return; + } else { + colors.printError("Unknown option: {s}\n", .{a}); + return error.InvalidArgs; + } + } + + var args_joined: []const u8 = ""; + if (post.len > 0) { + var buf: std.ArrayList(u8) = .{}; + defer buf.deinit(allocator); + for (post, 0..) 
|a, idx| { + if (idx > 0) try buf.append(allocator, ' '); + try buf.appendSlice(allocator, a); + } + args_joined = try buf.toOwnedSlice(allocator); + } + defer if (post.len > 0) allocator.free(args_joined); + + const args_final: []const u8 = if (args_override) |a| a else args_joined; + const note_final: []const u8 = if (note_override) |n| n else ""; + + // Target can be: + // - commit_id (40-hex) or commit_id prefix (>=7 hex) resolvable under worker_base + // - run_id/task_id/path (resolved to run_manifest.json to read commit_id) + var commit_hex: []const u8 = ""; + var commit_hex_owned: ?[]u8 = null; + defer if (commit_hex_owned) |s| allocator.free(s); + + var commit_bytes: []u8 = &[_]u8{}; + var commit_bytes_allocated = false; + defer if (commit_bytes_allocated) allocator.free(commit_bytes); + + if (target.len >= 7 and target.len <= 40 and isHexLowerOrUpper(target)) { + if (target.len == 40) { + commit_hex = target; + } else { + commit_hex_owned = try resolveCommitPrefix(allocator, cfg.worker_base, target); + commit_hex = commit_hex_owned.?; + } + + const decoded = crypto.decodeHex(allocator, commit_hex) catch { + commit_hex = ""; + commit_hex_owned = null; + return error.InvalidCommitId; + }; + if (decoded.len != 20) { + allocator.free(decoded); + commit_hex = ""; + commit_hex_owned = null; + } else { + commit_bytes = decoded; + commit_bytes_allocated = true; + } + } + + var job_name = blk: { + if (job_name_override) |n| break :blk n; + break :blk "requeue"; + }; + + if (commit_hex.len == 0) { + const manifest_path = try resolveManifestPath(allocator, target, cfg.worker_base); + defer allocator.free(manifest_path); + + const data = try readFileAlloc(allocator, manifest_path); + defer allocator.free(data); + + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{}); + defer parsed.deinit(); + + if (parsed.value != .object) return error.InvalidManifest; + const root = parsed.value.object; + + commit_hex = jsonGetString(root, "commit_id") 
orelse ""; + if (commit_hex.len != 40) { + colors.printError("run manifest missing commit_id\n", .{}); + return error.InvalidManifest; + } + + if (job_name_override == null) { + const j = jsonGetString(root, "job_name") orelse ""; + if (j.len > 0) job_name = j; + } + + const b = try crypto.decodeHex(allocator, commit_hex); + if (b.len != 20) { + allocator.free(b); + return error.InvalidCommitId; + } + commit_bytes = b; + commit_bytes_allocated = true; + } + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + if (note_final.len > 0) { + try client.sendQueueJobWithArgsNoteAndResources( + job_name, + commit_bytes, + priority, + api_key_hash, + args_final, + note_final, + cpu, + memory, + gpu, + gpu_memory, + ); + } else { + try client.sendQueueJobWithArgsAndResources( + job_name, + commit_bytes, + priority, + api_key_hash, + args_final, + cpu, + memory, + gpu, + gpu_memory, + ); + } + + try client.receiveAndHandleResponse(allocator, "Requeue"); + + colors.printSuccess("Queued requeue\n", .{}); + colors.printInfo("Job: {s}\n", .{job_name}); + colors.printInfo("Commit: {s}\n", .{commit_hex}); +} + +fn isHexLowerOrUpper(s: []const u8) bool { + for (s) |c| { + if (!std.ascii.isHex(c)) return false; + } + return true; +} + +fn resolveCommitPrefix(allocator: std.mem.Allocator, base_path: []const u8, prefix: []const u8) ![]u8 { + var dir = if (std.fs.path.isAbsolute(base_path)) + try std.fs.openDirAbsolute(base_path, .{ .iterate = true }) + else + try std.fs.cwd().openDir(base_path, .{ .iterate = true }); + defer dir.close(); + + var it = dir.iterate(); + var found: ?[]u8 = null; + errdefer if (found) |s| allocator.free(s); + + while (try it.next()) |entry| { + if (entry.kind != .directory) 
continue; + const name = entry.name; + if (name.len != 40) continue; + if (!std.mem.startsWith(u8, name, prefix)) continue; + if (!isHexLowerOrUpper(name)) continue; + + if (found != null) { + colors.printError("Ambiguous commit prefix: {s}\n", .{prefix}); + return error.InvalidCommitId; + } + found = try allocator.dupe(u8, name); + } + + if (found) |s| return s; + colors.printError("No commit matches prefix: {s}\n", .{prefix}); + return error.FileNotFound; +} + +fn resolveManifestPath(allocator: std.mem.Allocator, input: []const u8, base_path: []const u8) ![]u8 { + var cwd = std.fs.cwd(); + + if (std.fs.path.isAbsolute(input)) { + if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| { + var mutable_dir = dir; + defer mutable_dir.close(); + return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" }); + } + if (std.fs.openFileAbsolute(input, .{}) catch null) |file| { + var mutable_file = file; + defer mutable_file.close(); + return allocator.dupe(u8, input); + } + return resolveManifestPathById(allocator, input, base_path); + } + + const stat = cwd.statFile(input) catch |err| { + if (err == error.FileNotFound) { + return resolveManifestPathById(allocator, input, base_path); + } + return err; + }; + + if (stat.kind == .directory) { + return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" }); + } + + return allocator.dupe(u8, input); +} + +fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 { + const roots = [_][]const u8{ "finished", "failed", "running", "pending" }; + for (roots) |root| { + const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root }); + defer allocator.free(root_path); + + var dir = if (std.fs.path.isAbsolute(root_path)) + (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue) + else + (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue); + defer dir.close(); + + var it = dir.iterate(); + while 
(try it.next()) |entry| { + if (entry.kind != .directory) continue; + + const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name }); + defer allocator.free(run_dir); + const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" }); + defer allocator.free(manifest_path); + + const file = if (std.fs.path.isAbsolute(manifest_path)) + (std.fs.openFileAbsolute(manifest_path, .{}) catch continue) + else + (std.fs.cwd().openFile(manifest_path, .{}) catch continue); + defer file.close(); + + const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue; + defer allocator.free(data); + + const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue; + defer parsed.deinit(); + if (parsed.value != .object) continue; + + const obj = parsed.value.object; + const run_id = jsonGetString(obj, "run_id") orelse ""; + const task_id = jsonGetString(obj, "task_id") orelse ""; + if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) { + return allocator.dupe(u8, manifest_path); + } + } + } + + return error.FileNotFound; +} + +fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 { + var file = if (std.fs.path.isAbsolute(path)) + try std.fs.openFileAbsolute(path, .{}) + else + try std.fs.cwd().openFile(path, .{}); + defer file.close(); + + return try file.readToEndAlloc(allocator, 1024 * 1024); +} + +fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { + const v = obj.get(key) orelse return null; + if (v != .string) return null; + return v.string; +} + +fn printUsage() !void { + colors.printInfo("Usage:\n", .{}); + colors.printInfo(" ml requeue [--name ] [--priority ] [--cpu ] [--memory ] [--gpu ] [--gpu-memory ] [--args ] [--note ] -- \n", .{}); +} diff --git a/cli/src/commands/status.zig b/cli/src/commands/status.zig index 4e93a60..b0481b8 100644 --- a/cli/src/commands/status.zig +++ b/cli/src/commands/status.zig @@ -1,7 +1,7 @@ 
const std = @import("std"); const c = @cImport(@cInclude("time.h")); const Config = @import("../config.zig").Config; -const ws = @import("../net/ws.zig"); +const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); const errors = @import("../errors.zig"); const logging = @import("../utils/logging.zig"); diff --git a/cli/src/commands/sync.zig b/cli/src/commands/sync.zig index 719d47c..cd8cfdd 100644 --- a/cli/src/commands/sync.zig +++ b/cli/src/commands/sync.zig @@ -4,7 +4,7 @@ const Config = @import("../config.zig").Config; const crypto = @import("../utils/crypto.zig"); const rsync = @import("../utils/rsync_embedded.zig"); // Use embedded rsync const storage = @import("../utils/storage.zig"); -const ws = @import("../net/ws.zig"); +const ws = @import("../net/ws/client.zig"); const logging = @import("../utils/logging.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { diff --git a/cli/src/commands/validate.zig b/cli/src/commands/validate.zig index 7d9a734..4d1d5b6 100644 --- a/cli/src/commands/validate.zig +++ b/cli/src/commands/validate.zig @@ -1,9 +1,10 @@ const std = @import("std"); const testing = std.testing; const Config = @import("../config.zig").Config; -const ws = @import("../net/ws.zig"); +const ws = @import("../net/ws/client.zig"); const colors = @import("../utils/colors.zig"); const crypto = @import("../utils/crypto.zig"); +const io = @import("../utils/io.zig"); pub const Options = struct { json: bool = false, @@ -78,7 +79,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { defer allocator.free(msg); const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch { - std.debug.print("{s}\n", .{msg}); + if (opts.json) { + var out = io.stdoutWriter(); + try out.print("{s}\n", .{msg}); + } else { + std.debug.print("{s}\n", .{msg}); + } return error.InvalidPacket; }; defer { @@ -101,7 +107,8 @@ pub fn run(allocator: std.mem.Allocator, 
args: []const []const u8) !void { const payload = packet.data_payload.?; if (opts.json) { - std.debug.print("{s}\n", .{payload}); + var out = io.stdoutWriter(); + try out.print("{s}\n", .{payload}); } else { const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{}); defer parsed.deinit(); diff --git a/cli/src/commands/watch.zig b/cli/src/commands/watch.zig index 46a3c61..422edb7 100644 --- a/cli/src/commands/watch.zig +++ b/cli/src/commands/watch.zig @@ -1,8 +1,8 @@ const std = @import("std"); const Config = @import("../config.zig").Config; const crypto = @import("../utils/crypto.zig"); -const rsync = @import("../utils/rsync.zig"); -const ws = @import("../net/ws.zig"); +const rsync = @import("../utils/rsync_embedded.zig"); +const ws = @import("../net/ws/client.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { if (args.len == 0) { diff --git a/cli/src/errors.zig b/cli/src/errors.zig index 6db939f..948d82a 100644 --- a/cli/src/errors.zig +++ b/cli/src/errors.zig @@ -1,6 +1,5 @@ const std = @import("std"); const colors = @import("utils/colors.zig"); -const ws = @import("net/ws.zig"); const crypto = @import("utils/crypto.zig"); /// User-friendly error types for CLI diff --git a/cli/src/main.zig b/cli/src/main.zig index 7bee768..bfe52d8 100644 --- a/cli/src/main.zig +++ b/cli/src/main.zig @@ -6,6 +6,7 @@ const Command = enum { jupyter, init, sync, + requeue, queue, status, monitor, @@ -78,6 +79,14 @@ pub fn main() !void { command_found = true; try @import("commands/info.zig").run(allocator, args[2..]); }, + 'a' => if (std.mem.eql(u8, command, "annotate")) { + command_found = true; + try @import("commands/annotate.zig").run(allocator, args[2..]); + }, + 'n' => if (std.mem.eql(u8, command, "narrative")) { + command_found = true; + try @import("commands/narrative.zig").run(allocator, args[2..]); + }, 's' => if (std.mem.eql(u8, command, "sync")) { command_found = true; if (args.len < 3) { @@ -89,6 +98,10 @@ pub fn 
main() !void { command_found = true; try @import("commands/status.zig").run(allocator, args[2..]); }, + 'r' => if (std.mem.eql(u8, command, "requeue")) { + command_found = true; + try @import("commands/requeue.zig").run(allocator, args[2..]); + }, 'q' => if (std.mem.eql(u8, command, "queue")) { command_found = true; try @import("commands/queue.zig").run(allocator, args[2..]); @@ -127,8 +140,11 @@ fn printUsage() void { std.debug.print("Commands:\n", .{}); std.debug.print(" jupyter Jupyter workspace management\n", .{}); std.debug.print(" init Setup configuration interactively\n", .{}); + std.debug.print(" annotate Add an annotation to run_manifest.json (--note \"...\")\n", .{}); + std.debug.print(" narrative set Set run narrative fields (hypothesis/context/...)\n", .{}); std.debug.print(" info Show run info from run_manifest.json (optionally --base )\n", .{}); std.debug.print(" sync Sync project to server\n", .{}); + std.debug.print(" requeue Re-submit from run_id/task_id/path (supports -- )\n", .{}); std.debug.print(" queue (q) Queue job for execution\n", .{}); std.debug.print(" status Get system status\n", .{}); std.debug.print(" monitor Launch TUI via SSH\n", .{}); @@ -143,4 +159,7 @@ fn printUsage() void { test { _ = @import("commands/info.zig"); + _ = @import("commands/requeue.zig"); + _ = @import("commands/annotate.zig"); + _ = @import("commands/narrative.zig"); } diff --git a/cli/src/net.zig b/cli/src/net.zig index 2a90e8b..37f6531 100644 --- a/cli/src/net.zig +++ b/cli/src/net.zig @@ -1,3 +1,5 @@ // Network module - exports all network modules pub const protocol = @import("net/protocol.zig"); pub const ws = @import("net/ws.zig"); +pub const ws_client = @import("net/ws/client.zig"); +pub const ws_opcodes = @import("net/ws/opcodes.zig"); diff --git a/cli/src/net/ws.zig b/cli/src/net/ws.zig index 5098d9b..c7a31f7 100644 --- a/cli/src/net/ws.zig +++ b/cli/src/net/ws.zig @@ -1,1705 +1,11 @@ -const std = @import("std"); -const crypto = 
@import("../utils/crypto.zig"); -const protocol = @import("protocol.zig"); -const log = @import("../utils/logging.zig"); - -/// Binary WebSocket protocol opcodes -pub const Opcode = enum(u8) { - queue_job = 0x01, - queue_job_with_tracking = 0x0C, - queue_job_with_snapshot = 0x17, - status_request = 0x02, - cancel_job = 0x03, - prune = 0x04, - crash_report = 0x05, - log_metric = 0x0A, - get_experiment = 0x0B, - start_jupyter = 0x0D, - stop_jupyter = 0x0E, - remove_jupyter = 0x18, - restore_jupyter = 0x19, - list_jupyter = 0x0F, - - validate_request = 0x16, - - // Dataset management opcodes - dataset_list = 0x06, - dataset_register = 0x07, - dataset_info = 0x08, - dataset_search = 0x09, - - // Structured response opcodes - response_success = 0x10, - response_error = 0x11, - response_progress = 0x12, - response_status = 0x13, - response_data = 0x14, - response_log = 0x15, -}; - -pub const ValidateTargetType = enum(u8) { - commit_id = 0, - task_id = 1, -}; - -/// WebSocket client for binary protocol communication -pub const Client = struct { - allocator: std.mem.Allocator, - stream: ?std.net.Stream, - host: []const u8, - port: u16, - is_tls: bool = false, - - pub fn connect(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8) !Client { - // Detect TLS - const is_tls = std.mem.startsWith(u8, url, "wss://"); - - // Parse URL (simplified - assumes ws://host:port/path or wss://host:port/path) - const host_start = std.mem.indexOf(u8, url, "//") orelse return error.InvalidURL; - const host_port_start = host_start + 2; - const path_start = std.mem.indexOfPos(u8, url, host_port_start, "/") orelse url.len; - const colon_pos = std.mem.indexOfPos(u8, url, host_port_start, ":"); - - const host_end = blk: { - if (colon_pos) |pos| { - if (pos < path_start) break :blk pos; - } - break :blk path_start; - }; - const host = url[host_port_start..host_end]; - - var port: u16 = if (is_tls) 9101 else 9100; // default ports - if (colon_pos) |pos| { - if (pos < path_start) { - 
const port_start = pos + 1; - const port_end = std.mem.indexOfPos(u8, url, port_start, "/") orelse url.len; - const port_str = url[port_start..port_end]; - port = try std.fmt.parseInt(u16, port_str, 10); - } - } - - // Connect to server - const address = try resolveHostAddress(allocator, host, port); - const stream = try std.net.tcpConnectToAddress(address); - - // For TLS, we'd need to wrap the stream with TLS - // For now, we'll just support ws:// and document wss:// requires additional setup - if (is_tls) { - // TODO(context): Implement native wss:// support by introducing a transport abstraction - // (raw TCP vs TLS client stream), performing TLS handshake + certificate verification, and updating - // handshake/frame read+write helpers to operate on the chosen transport. - std.log.warn("TLS (wss://) support requires additional TLS library integration", .{}); - return error.TLSNotSupported; - } - // Perform WebSocket handshake - try handshake(allocator, stream, host, url, api_key); - - return Client{ - .allocator = allocator, - .stream = stream, - .host = try allocator.dupe(u8, host), - .port = port, - .is_tls = is_tls, - }; - } - - /// Connect to WebSocket server with retry logic - pub fn connectWithRetry(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8, max_retries: u32) !Client { - var retry_count: u32 = 0; - var last_error: anyerror = error.ConnectionFailed; - - while (retry_count < max_retries) { - const client = connect(allocator, url, api_key) catch |err| { - last_error = err; - retry_count += 1; - - if (retry_count < max_retries) { - const delay_ms = @min(1000 * retry_count, 5000); // Exponential backoff, max 5s - log.warn("Connection failed (attempt {d}/{d}), retrying in {d}s...\n", .{ retry_count, max_retries, delay_ms / 1000 }); - std.Thread.sleep(@as(u64, delay_ms) * std.time.ns_per_ms); - } - continue; - }; - - if (retry_count > 0) { - log.success("Connected successfully after {d} attempts\n", .{retry_count + 1}); - } - return 
client; - } - - return last_error; - } - - /// Disconnect from WebSocket server - pub fn disconnect(self: *Client) void { - if (self.stream) |stream| { - stream.close(); - self.stream = null; - } - } - - fn handshake(allocator: std.mem.Allocator, stream: std.net.Stream, host: []const u8, url: []const u8, api_key: []const u8) !void { - const key = try generateWebSocketKey(allocator); - defer allocator.free(key); - - // API key is already hashed in config, send as-is - const request = try std.fmt.allocPrint(allocator, "GET {s} HTTP/1.1\r\n" ++ - "Host: {s}\r\n" ++ - "Upgrade: websocket\r\n" ++ - "Connection: Upgrade\r\n" ++ - "Sec-WebSocket-Key: {s}\r\n" ++ - "Sec-WebSocket-Version: 13\r\n" ++ - "X-API-Key: {s}\r\n" ++ - "\r\n", .{ url, host, key, api_key }); - defer allocator.free(request); - - _ = try stream.write(request); - - // Read response until complete - var response_buf: [4096]u8 = undefined; - var bytes_read: usize = 0; - var header_complete = false; - - while (!header_complete and bytes_read < response_buf.len - 1) { - const chunk_bytes = try stream.read(response_buf[bytes_read..]); - if (chunk_bytes == 0) break; - bytes_read += chunk_bytes; - - // Check if we have complete HTTP headers (\r\n\r\n) - if (std.mem.indexOf(u8, response_buf[0..bytes_read], "\r\n\r\n") != null) { - header_complete = true; - } - } - - const response = response_buf[0..bytes_read]; - - // Check for successful handshake - if (std.mem.indexOf(u8, response, "101 Switching Protocols") == null) { - // Parse HTTP status code for better error messages - if (std.mem.indexOf(u8, response, "404 Not Found") != null) { - std.debug.print("\n❌ WebSocket Connection Failed\n", .{}); - std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); - std.debug.print("The WebSocket endpoint '/ws' was not found on the server.\n\n", .{}); - std.debug.print("This usually means:\n", .{}); - std.debug.print(" • API server is not running\n", .{}); - std.debug.print(" • Incorrect 
server address in config\n", .{}); - std.debug.print(" • Different service running on that port\n\n", .{}); - std.debug.print("To diagnose:\n", .{}); - std.debug.print(" • Verify server address: Check ~/.ml/config.toml\n", .{}); - std.debug.print(" • Test connectivity: curl http://:/health\n", .{}); - std.debug.print(" • Contact your server administrator if the issue persists\n\n", .{}); - return error.EndpointNotFound; - } else if (std.mem.indexOf(u8, response, "401 Unauthorized") != null) { - std.debug.print("\n❌ Authentication Failed\n", .{}); - std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); - std.debug.print("Invalid or missing API key.\n\n", .{}); - std.debug.print("To fix:\n", .{}); - std.debug.print(" • Verify API key in ~/.ml/config.toml matches server configuration\n", .{}); - std.debug.print(" • Request a new API key from your administrator if needed\n\n", .{}); - return error.AuthenticationFailed; - } else if (std.mem.indexOf(u8, response, "403 Forbidden") != null) { - std.debug.print("\n❌ Access Denied\n", .{}); - std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); - std.debug.print("Your API key doesn't have permission for this operation.\n\n", .{}); - std.debug.print("To fix:\n", .{}); - std.debug.print(" • Contact your administrator to grant necessary permissions\n\n", .{}); - return error.PermissionDenied; - } else if (std.mem.indexOf(u8, response, "503 Service Unavailable") != null) { - std.debug.print("\n❌ Server Unavailable\n", .{}); - std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); - std.debug.print("The server is temporarily unavailable.\n\n", .{}); - std.debug.print("This could be due to:\n", .{}); - std.debug.print(" • Server maintenance\n", .{}); - std.debug.print(" • High load\n", .{}); - std.debug.print(" • Server restart\n\n", .{}); - std.debug.print("To resolve:\n", .{}); - std.debug.print(" • Wait a moment and try again\n", .{}); 
- std.debug.print(" • Contact administrator if the issue persists\n\n", .{}); - return error.ServerUnavailable; - } else { - // Generic handshake failure - show response for debugging - std.debug.print("\n❌ WebSocket Handshake Failed\n", .{}); - std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); - std.debug.print("Expected HTTP 101 Switching Protocols, but received:\n", .{}); - - // Show first line of response (status line) - const newline_pos = std.mem.indexOf(u8, response, "\r\n") orelse response.len; - const status_line = response[0..newline_pos]; - std.debug.print(" {s}\n\n", .{status_line}); - - std.debug.print("To diagnose:\n", .{}); - std.debug.print(" • Verify server address in ~/.ml/config.toml\n", .{}); - std.debug.print(" • Check network connectivity to the server\n", .{}); - std.debug.print(" • Contact your administrator for assistance\n\n", .{}); - return error.HandshakeFailed; - } - } - - // Add small delay to ensure server is ready for WebSocket frames - std.posix.nanosleep(0, 10 * std.time.ns_per_ms); - } - - fn generateWebSocketKey(allocator: std.mem.Allocator) ![]u8 { - var random_bytes: [16]u8 = undefined; - std.crypto.random.bytes(&random_bytes); - - const base64 = std.base64.standard.Encoder; - const result = try allocator.alloc(u8, base64.calcSize(random_bytes.len)); - _ = base64.encode(result, &random_bytes); - return result; - } - - pub fn close(self: *Client) void { - if (self.stream) |stream| { - stream.close(); - self.stream = null; - } - if (self.host.len > 0) { - self.allocator.free(self.host); - } - } - - pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - - const total_len = 1 + 16 + 1 + 1 + 20; - var buffer = try self.allocator.alloc(u8, total_len); - defer 
self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.validate_request); - offset += 1; - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - buffer[offset] = @intFromEnum(ValidateTargetType.commit_id); - offset += 1; - buffer[offset] = 20; - offset += 1; - @memcpy(buffer[offset .. offset + 20], commit_id); - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendQueueJobWithSnapshotAndResources( - self: *Client, - job_name: []const u8, - commit_id: []const u8, - priority: u8, - api_key_hash: []const u8, - snapshot_id: []const u8, - snapshot_sha256: []const u8, - cpu: u8, - memory_gb: u8, - gpu: u8, - gpu_memory: ?[]const u8, - ) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - if (job_name.len > 255) return error.JobNameTooLong; - if (snapshot_id.len == 0 or snapshot_id.len > 255) return error.PayloadTooLarge; - if (snapshot_sha256.len == 0 or snapshot_sha256.len > 255) return error.PayloadTooLarge; - - const gpu_mem = gpu_memory orelse ""; - if (gpu_mem.len > 255) return error.PayloadTooLarge; - - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 1 + snapshot_id.len + 1 + snapshot_sha256.len + 4 + gpu_mem.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.queue_job_with_snapshot); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. offset + 20], commit_id); - offset += 20; - - buffer[offset] = priority; - offset += 1; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - - @memcpy(buffer[offset .. offset + job_name.len], job_name); - offset += job_name.len; - - buffer[offset] = @intCast(snapshot_id.len); - offset += 1; - - @memcpy(buffer[offset .. 
offset + snapshot_id.len], snapshot_id); - offset += snapshot_id.len; - - buffer[offset] = @intCast(snapshot_sha256.len); - offset += 1; - - @memcpy(buffer[offset .. offset + snapshot_sha256.len], snapshot_sha256); - offset += snapshot_sha256.len; - - buffer[offset] = cpu; - buffer[offset + 1] = memory_gb; - buffer[offset + 2] = gpu; - buffer[offset + 3] = @intCast(gpu_mem.len); - offset += 4; - - if (gpu_mem.len > 0) { - @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); - } - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendValidateRequestTask(self: *Client, api_key_hash: []const u8, task_id: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (task_id.len == 0 or task_id.len > 255) return error.PayloadTooLarge; - - const total_len = 1 + 16 + 1 + 1 + task_id.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.validate_request); - offset += 1; - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - buffer[offset] = @intFromEnum(ValidateTargetType.task_id); - offset += 1; - buffer[offset] = @intCast(task_id.len); - offset += 1; - @memcpy(buffer[offset .. 
offset + task_id.len], task_id); - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendQueueJob(self: *Client, job_name: []const u8, commit_id: []const u8, priority: u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - // Validate input lengths - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - if (job_name.len > 255) return error.JobNameTooLong; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [commit_id: 20 bytes] [priority: u8] [job_name_len: u8] [job_name: var] - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.queue_job); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. offset + 20], commit_id); - offset += 20; - - buffer[offset] = priority; - offset += 1; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - - @memcpy(buffer[offset..], job_name); - - // Send as WebSocket binary frame - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendQueueJobWithResources( - self: *Client, - job_name: []const u8, - commit_id: []const u8, - priority: u8, - api_key_hash: []const u8, - cpu: u8, - memory_gb: u8, - gpu: u8, - gpu_memory: ?[]const u8, - ) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - if (job_name.len > 255) return error.JobNameTooLong; - - const gpu_mem = gpu_memory orelse ""; - if (gpu_mem.len > 255) return error.PayloadTooLarge; - - // Tail encoding: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 4 + gpu_mem.len; - var buffer = try self.allocator.alloc(u8, 
total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.queue_job); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. offset + 20], commit_id); - offset += 20; - - buffer[offset] = priority; - offset += 1; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - - @memcpy(buffer[offset .. offset + job_name.len], job_name); - offset += job_name.len; - - buffer[offset] = cpu; - buffer[offset + 1] = memory_gb; - buffer[offset + 2] = gpu; - buffer[offset + 3] = @intCast(gpu_mem.len); - offset += 4; - - if (gpu_mem.len > 0) { - @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); - } - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendQueueJobWithTracking( - self: *Client, - job_name: []const u8, - commit_id: []const u8, - priority: u8, - api_key_hash: []const u8, - tracking_json: []const u8, - ) !void { - const stream = self.stream orelse return error.NotConnected; - - // Validate input lengths - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - if (job_name.len > 255) return error.JobNameTooLong; - if (tracking_json.len > 0xFFFF) return error.PayloadTooLarge; - - // Build binary message: - // [opcode: u8] - // [api_key_hash: 16] - // [commit_id: 20] - // [priority: u8] - // [job_name_len: u8] - // [job_name: var] - // [tracking_json_len: u16] - // [tracking_json: var] - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + tracking_json.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.queue_job_with_tracking); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. 
offset + 20], commit_id); - offset += 20; - - buffer[offset] = priority; - offset += 1; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - - @memcpy(buffer[offset .. offset + job_name.len], job_name); - offset += job_name.len; - - // tracking_json length (big-endian) - buffer[offset] = @intCast((tracking_json.len >> 8) & 0xFF); - buffer[offset + 1] = @intCast(tracking_json.len & 0xFF); - offset += 2; - - if (tracking_json.len > 0) { - @memcpy(buffer[offset .. offset + tracking_json.len], tracking_json); - } - - // Single WebSocket frame for throughput - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendQueueJobWithTrackingAndResources( - self: *Client, - job_name: []const u8, - commit_id: []const u8, - priority: u8, - api_key_hash: []const u8, - tracking_json: []const u8, - cpu: u8, - memory_gb: u8, - gpu: u8, - gpu_memory: ?[]const u8, - ) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - if (job_name.len > 255) return error.JobNameTooLong; - if (tracking_json.len > 0xFFFF) return error.PayloadTooLarge; - - const gpu_mem = gpu_memory orelse ""; - if (gpu_mem.len > 255) return error.PayloadTooLarge; - - // [opcode] - // [api_key_hash] - // [commit_id] - // [priority] - // [job_name_len][job_name] - // [tracking_json_len:2][tracking_json] - // [cpu][memory_gb][gpu][gpu_mem_len][gpu_mem] - const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + tracking_json.len + 4 + gpu_mem.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.queue_job_with_tracking); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. 
offset + 20], commit_id); - offset += 20; - - buffer[offset] = priority; - offset += 1; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - @memcpy(buffer[offset .. offset + job_name.len], job_name); - offset += job_name.len; - - buffer[offset] = @intCast((tracking_json.len >> 8) & 0xFF); - buffer[offset + 1] = @intCast(tracking_json.len & 0xFF); - offset += 2; - - if (tracking_json.len > 0) { - @memcpy(buffer[offset .. offset + tracking_json.len], tracking_json); - offset += tracking_json.len; - } - - buffer[offset] = cpu; - buffer[offset + 1] = memory_gb; - buffer[offset + 2] = gpu; - buffer[offset + 3] = @intCast(gpu_mem.len); - offset += 4; - - if (gpu_mem.len > 0) { - @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); - } - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (job_name.len > 255) return error.JobNameTooLong; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] - const total_len = 1 + 16 + 1 + job_name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.cancel_job); - offset += 1; - - @memcpy(buffer[offset .. 
offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(job_name.len); - offset += 1; - - @memcpy(buffer[offset..], job_name); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendPrune(self: *Client, api_key_hash: []const u8, prune_type: u8, value: u32) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [prune_type: u8] [value: u4] - const total_len = 1 + 16 + 1 + 4; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.prune); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = prune_type; - offset += 1; - - // Store value in big-endian format - buffer[offset] = @intCast((value >> 24) & 0xFF); - buffer[offset + 1] = @intCast((value >> 16) & 0xFF); - buffer[offset + 2] = @intCast((value >> 8) & 0xFF); - buffer[offset + 3] = @intCast(value & 0xFF); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] - const total_len = 1 + 16; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - buffer[0] = @intFromEnum(Opcode.status_request); - @memcpy(buffer[1..17], api_key_hash); - - try sendWebSocketFrame(stream, buffer); - } - - fn sendWebSocketFrame(stream: std.net.Stream, payload: []const u8) !void { - var frame: [14]u8 = undefined; // Extra space for mask - var frame_len: usize = 2; - - // FIN=1, opcode=0x2 (binary), MASK=1 - frame[0] = 0x82 | 0x80; - - // Payload length - if (payload.len < 126) { - frame[1] = @as(u8, 
@intCast(payload.len)) | 0x80; // Set MASK bit - } else if (payload.len < 65536) { - frame[1] = 126 | 0x80; // Set MASK bit - frame[2] = @intCast(payload.len >> 8); - frame[3] = @intCast(payload.len & 0xFF); - frame_len = 4; - } else { - return error.PayloadTooLarge; - } - - // Generate random mask (4 bytes) - var mask: [4]u8 = undefined; - var i: usize = 0; - while (i < 4) : (i += 1) { - mask[i] = @as(u8, @intCast(@mod(std.time.timestamp(), 256))); - } - - // Copy mask to frame - @memcpy(frame[frame_len .. frame_len + 4], &mask); - frame_len += 4; - - // Send frame header - _ = try stream.write(frame[0..frame_len]); - - // Send payload with masking - var masked_payload = try std.heap.page_allocator.alloc(u8, payload.len); - defer std.heap.page_allocator.free(masked_payload); - - for (payload, 0..) |byte, j| { - masked_payload[j] = byte ^ mask[j % 4]; - } - - _ = try stream.write(masked_payload); - } - - pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 { - const stream = self.stream orelse return error.NotConnected; - - // Read frame header - var header: [2]u8 = undefined; - const header_bytes = try stream.read(&header); - if (header_bytes < 2) return error.ConnectionClosed; - - // Check for binary frame and FIN bit - if (header[0] != 0x82) return error.InvalidFrame; - - // Get payload length - var payload_len: usize = header[1]; - if (payload_len == 126) { - var len_bytes: [2]u8 = undefined; - _ = try stream.read(&len_bytes); - payload_len = (@as(usize, len_bytes[0]) << 8) | len_bytes[1]; - } else if (payload_len == 127) { - return error.PayloadTooLarge; - } - - // Read payload - const payload = try allocator.alloc(u8, payload_len); - errdefer allocator.free(payload); - - var bytes_read: usize = 0; - while (bytes_read < payload_len) { - const n = try stream.read(payload[bytes_read..]); - if (n == 0) return error.ConnectionClosed; - bytes_read += n; - } - - return payload; - } - - /// Receive and handle response with automatic display - pub 
fn receiveAndHandleResponse(self: *Client, allocator: std.mem.Allocator, operation: []const u8) !void { - const message = try self.receiveMessage(allocator); - defer allocator.free(message); - - const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { - // Fallback: treat as plain response. - std.debug.print("Server response: {s}\n", .{message}); - return; - }; - defer { - if (packet.success_message) |msg| allocator.free(msg); - if (packet.error_message) |msg| allocator.free(msg); - if (packet.error_details) |details| allocator.free(details); - if (packet.data_type) |dtype| allocator.free(dtype); - if (packet.data_payload) |payload| allocator.free(payload); - if (packet.progress_message) |pmsg| allocator.free(pmsg); - if (packet.status_data) |sdata| allocator.free(sdata); - if (packet.log_message) |lmsg| allocator.free(lmsg); - } - - try self.handleResponsePacket(packet, operation); - } - - fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { - const v_opt = obj.get(key); - if (v_opt == null) { - return null; - } - const v = v_opt.?; - if (v != .string) { - return null; - } - return v.string; - } - - fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 { - const v_opt = obj.get(key); - if (v_opt == null) { - return null; - } - const v = v_opt.?; - if (v != .integer) { - return null; - } - return v.integer; - } - - pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 { - const prewarm_val_opt = root.get("prewarm"); - if (prewarm_val_opt == null) { - return null; - } - const prewarm_val = prewarm_val_opt.?; - if (prewarm_val != .array) { - return null; - } - - const items = prewarm_val.array.items; - if (items.len == 0) { - return null; - } - - var out = std.ArrayList(u8){}; - errdefer out.deinit(allocator); - - const writer = out.writer(allocator); - try writer.writeAll("Prewarm:\n"); - - for (items) |item| { - if (item != .object) { - continue; - } - - const obj = 
item.object; - - const worker_id = jsonGetString(obj, "worker_id") orelse ""; - const task_id = jsonGetString(obj, "task_id") orelse ""; - const phase = jsonGetString(obj, "phase") orelse ""; - const started_at = jsonGetString(obj, "started_at") orelse ""; - const dataset_count = jsonGetInt(obj, "dataset_count") orelse 0; - const snapshot_id = jsonGetString(obj, "snapshot_id") orelse ""; - const env_image = jsonGetString(obj, "env_image") orelse ""; - const env_hit = jsonGetInt(obj, "env_hit") orelse 0; - const env_miss = jsonGetInt(obj, "env_miss") orelse 0; - const env_built = jsonGetInt(obj, "env_built") orelse 0; - - try writer.print( - " worker={s} task={s} phase={s} datasets={d} snapshot={s} env={s} env_hit={d} env_miss={d} env_built={d} started={s}\n", - .{ worker_id, task_id, phase, dataset_count, snapshot_id, env_image, env_hit, env_miss, env_built, started_at }, - ); - } - - const owned = try out.toOwnedSlice(allocator); - return owned; - } - - /// Receive and handle status response with user filtering - pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void { - _ = user_context; // TODO: Use for filtering - const message = try self.receiveMessage(allocator); - defer allocator.free(message); - - const json_start_opt = std.mem.indexOfScalar(u8, message, '{'); - - // Check if message is JSON (or contains JSON) or plain text - if (json_start_opt != null) { - const json_slice = message[json_start_opt.?..]; - // Parse JSON response - const parsed = try std.json.parseFromSlice(std.json.Value, allocator, json_slice, .{}); - defer parsed.deinit(); - const root = parsed.value.object; - - // Apply limit if specified - if (options.limit) |limit| { - // For now, just note the limit - actual implementation would truncate results - const colors = @import("../utils/colors.zig"); - colors.printInfo("Showing {d} results (limited)\n", .{limit}); - } - - if (options.json) { - // Output raw JSON - 
std.debug.print("{s}\n", .{json_slice}); - } else { - // Display user info - if (root.get("user")) |user_obj| { - const user = user_obj.object; - const name = user.get("name").?.string; - const admin = user.get("admin").?.bool; - const colors = @import("../utils/colors.zig"); - colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin }); - } - - // Display task summary - if (root.get("tasks")) |tasks_obj| { - const tasks = tasks_obj.object; - const total = tasks.get("total").?.integer; - const queued = tasks.get("queued").?.integer; - const running = tasks.get("running").?.integer; - const failed = tasks.get("failed").?.integer; - const completed = tasks.get("completed").?.integer; - const colors = @import("../utils/colors.zig"); - colors.printInfo( - "Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n", - .{ total, queued, running, failed, completed }, - ); - } - - const per_section_limit: usize = options.limit orelse 5; - - const TaskStatus = enum { queued, running, failed, completed }; - - const TaskPrinter = struct { - fn statusLabel(s: TaskStatus) []const u8 { - return switch (s) { - .queued => "Queued", - .running => "Running", - .failed => "Failed", - .completed => "Completed", - }; - } - - fn statusMatch(s: TaskStatus) []const u8 { - return switch (s) { - .queued => "queued", - .running => "running", - .failed => "failed", - .completed => "completed", - }; - } - - fn shorten(s: []const u8, max_len: usize) []const u8 { - if (s.len <= max_len) return s; - return s[0..max_len]; - } - - fn printSection( - allocator2: std.mem.Allocator, - queue_items: []const std.json.Value, - status: TaskStatus, - limit2: usize, - ) !void { - _ = allocator2; - const colors = @import("../utils/colors.zig"); - const label = statusLabel(status); - const want = statusMatch(status); - std.debug.print("\n{s}:\n", .{label}); - - var shown: usize = 0; - for (queue_items) |item| { - if (item != .object) continue; - const obj = item.object; 
- const st = jsonGetString(obj, "status") orelse ""; - if (!std.mem.eql(u8, st, want)) continue; - - const id = jsonGetString(obj, "id") orelse ""; - const job_name = jsonGetString(obj, "job_name") orelse ""; - const worker_id = jsonGetString(obj, "worker_id") orelse ""; - const err = jsonGetString(obj, "error") orelse ""; - - if (std.mem.eql(u8, want, "failed")) { - colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name }); - if (worker_id.len > 0) { - std.debug.print(" (worker={s})", .{worker_id}); - } - std.debug.print("\n", .{}); - if (err.len > 0) { - std.debug.print(" error: {s}\n", .{shorten(err, 160)}); - } - } else if (std.mem.eql(u8, want, "running")) { - colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name }); - if (worker_id.len > 0) { - std.debug.print(" (worker={s})", .{worker_id}); - } - std.debug.print("\n", .{}); - } else if (std.mem.eql(u8, want, "queued")) { - std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name }); - } else { - colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name }); - } - - shown += 1; - if (shown >= limit2) break; - } - - if (shown == 0) { - std.debug.print(" (none)\n", .{}); - } else { - // Indicate there may be more. - var total_for_status: usize = 0; - for (queue_items) |item| { - if (item != .object) continue; - const obj = item.object; - const st = jsonGetString(obj, "status") orelse ""; - if (std.mem.eql(u8, st, want)) total_for_status += 1; - } - if (total_for_status > shown) { - std.debug.print(" ... 
and {d} more\n", .{total_for_status - shown}); - } - } - } - }; - - if (root.get("queue")) |queue_val| { - if (queue_val == .array) { - const items = queue_val.array.items; - try TaskPrinter.printSection(allocator, items, .queued, per_section_limit); - try TaskPrinter.printSection(allocator, items, .running, per_section_limit); - try TaskPrinter.printSection(allocator, items, .failed, per_section_limit); - try TaskPrinter.printSection(allocator, items, .completed, per_section_limit); - } - } - - if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| { - defer allocator.free(section); - const colors = @import("../utils/colors.zig"); - colors.printInfo("{s}", .{section}); - } - } - } else { - // Handle plain text response - filter out non-printable characters - var clean_msg = allocator.alloc(u8, message.len) catch { - if (options.json) { - std.debug.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); - } else { - std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); - } - return; - }; - defer allocator.free(clean_msg); - - var clean_len: usize = 0; - for (message) |byte| { - // Skip WebSocket frame header bytes and non-printable chars - if (byte >= 32 and byte <= 126) { // printable ASCII only - clean_msg[clean_len] = byte; - clean_len += 1; - } - } - - // Look for common error messages in the cleaned data - if (clean_len > 0) { - const cleaned = clean_msg[0..clean_len]; - if (options.json) { - if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { - std.debug.print("{{\"error\": \"insufficient_permissions\"}}\n", .{}); - } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { - std.debug.print("{{\"error\": \"authentication_failed\"}}\n", .{}); - } else { - std.debug.print("{{\"response\": \"{s}\"}}\n", .{cleaned}); - } - } else { - if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { - std.debug.print("Insufficient permissions to view jobs\n", 
.{}); - } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { - std.debug.print("Authentication failed\n", .{}); - } else { - std.debug.print("Server response: {s}\n", .{cleaned}); - } - } - } else { - if (options.json) { - std.debug.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); - } else { - std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); - } - } - return; - } - } - - /// Receive and handle cancel response with user permissions - pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void { - const message = try self.receiveMessage(allocator); - defer allocator.free(message); - - // Check if message is JSON or plain text - if (message[0] == '{') { - // Parse JSON response - const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{}); - defer parsed.deinit(); - const root = parsed.value.object; - - if (options.json) { - // Output raw JSON - std.debug.print("{s}\n", .{message}); - } else { - // Display user-friendly output - if (root.get("success")) |success_val| { - if (success_val.bool) { - const colors = @import("../utils/colors.zig"); - colors.printSuccess("Job '{s}' canceled successfully\n", .{job_name}); - } else { - const colors = @import("../utils/colors.zig"); - colors.printError("Failed to cancel job '{s}'\n", .{job_name}); - if (root.get("error")) |error_val| { - colors.printError("Error: {s}\n", .{error_val.string}); - } - } - } else { - const colors = @import("../utils/colors.zig"); - colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name }); - } - } - } else { - // Handle plain text response - filter out non-printable characters - var clean_msg = allocator.alloc(u8, message.len) catch { - if (options.json) { - std.debug.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); - } else { - 
std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); - } - return; - }; - defer allocator.free(clean_msg); - - var clean_len: usize = 0; - for (message) |byte| { - // Skip WebSocket frame header bytes and non-printable chars - if (byte >= 32 and byte <= 126) { // printable ASCII only - clean_msg[clean_len] = byte; - clean_len += 1; - } - } - - // Look for common error messages in the cleaned data - if (clean_len > 0) { - const cleaned = clean_msg[0..clean_len]; - if (options.json) { - if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { - std.debug.print("{{\"error\": \"insufficient_permissions\"}}\n", .{}); - } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { - std.debug.print("{{\"error\": \"authentication_failed\"}}\n", .{}); - } else { - std.debug.print("{{\"response\": \"{s}\"}}\n", .{cleaned}); - } - } else { - if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { - std.debug.print("Insufficient permissions to cancel job\n", .{}); - } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { - std.debug.print("Authentication failed\n", .{}); - } else { - const colors = @import("../utils/colors.zig"); - colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name }); - colors.printInfo("Response: {s}\n", .{cleaned}); - } - } - } else { - if (options.json) { - std.debug.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); - } else { - std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); - } - } - return; - } - } - - /// Handle response packet with appropriate display - pub fn handleResponsePacket(self: *Client, packet: protocol.ResponsePacket, operation: []const u8) !void { - switch (packet.packet_type) { - .success => { - if (packet.success_message) |msg| { - if (msg.len > 0) { - std.debug.print("✓ {s}: {s}\n", .{ operation, msg }); - } else { - std.debug.print("✓ {s} 
completed successfully\n", .{operation}); - } - } else { - std.debug.print("✓ {s} completed successfully\n", .{operation}); - } - }, - .error_packet => { - const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); - std.debug.print("✗ {s} failed: {s}\n", .{ operation, error_msg }); - - if (packet.error_message) |msg| { - if (msg.len > 0) { - std.debug.print("Details: {s}\n", .{msg}); - } - } - - if (packet.error_details) |details| { - if (details.len > 0) { - std.debug.print("Additional info: {s}\n", .{details}); - } - } - - // Convert to appropriate CLI error - return self.convertServerError(packet.error_code.?); - }, - .progress => { - if (packet.progress_type) |ptype| { - switch (ptype) { - .percentage => { - const percentage = packet.progress_value.?; - if (packet.progress_total) |total| { - std.debug.print("Progress: {d}/{d} ({d:.1}%)\n", .{ percentage, total, @as(f32, @floatFromInt(percentage)) * 100.0 / @as(f32, @floatFromInt(total)) }); - } else { - std.debug.print("Progress: {d}%\n", .{percentage}); - } - }, - .stage => { - if (packet.progress_message) |msg| { - std.debug.print("Stage: {s}\n", .{msg}); - } - }, - .message => { - if (packet.progress_message) |msg| { - std.debug.print("Info: {s}\n", .{msg}); - } - }, - .bytes_transferred => { - const bytes = packet.progress_value.?; - if (packet.progress_total) |total| { - const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0; - const total_mb = @as(f64, @floatFromInt(total)) / 1024.0 / 1024.0; - std.debug.print("Transferred: {d:.2} MB / {d:.2} MB\n", .{ transferred_mb, total_mb }); - } else { - const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0; - std.debug.print("Transferred: {d:.2} MB\n", .{transferred_mb}); - } - }, - } - } - }, - .status => { - if (packet.status_data) |data| { - std.debug.print("Status: {s}\n", .{data}); - } - }, - .data => { - if (packet.data_type) |dtype| { - std.debug.print("Data [{s}]: ", .{dtype}); - if 
(packet.data_payload) |payload| { - // Try to display as string if it looks like text - const is_text = for (payload) |byte| { - if (byte < 32 and byte != '\n' and byte != '\r' and byte != '\t') break false; - } else true; - - if (is_text) { - std.debug.print("{s}\n", .{payload}); - } else { - std.debug.print("{d} bytes\n", .{payload.len}); - } - } - } - }, - .log => { - if (packet.log_level) |level| { - const level_name = protocol.ResponsePacket.getLogLevelName(level); - if (packet.log_message) |msg| { - std.debug.print("[{s}] {s}\n", .{ level_name, msg }); - } - } - }, - } - } - - /// Convert server error code to CLI error - fn convertServerError(self: *Client, server_error: protocol.ErrorCode) anyerror { - _ = self; // Client instance not needed for error conversion - return switch (server_error) { - .authentication_failed => error.AuthenticationFailed, - .permission_denied => error.PermissionDenied, - .resource_not_found => error.JobNotFound, - .resource_already_exists => error.ResourceExists, - .timeout => error.RequestTimeout, - .server_overloaded, .service_unavailable => error.ServerUnreachable, - .invalid_request => error.InvalidArguments, - .job_not_found => error.JobNotFound, - .job_already_running => error.JobAlreadyRunning, - .job_failed_to_start, .job_execution_failed => error.CommandFailed, - .job_cancelled => error.JobCancelled, - else => error.ServerError, - }; - } - - /// Clean up packet allocated memory - pub fn cleanupPacket(self: *Client, packet: protocol.ResponsePacket) void { - if (packet.success_message) |msg| { - self.allocator.free(msg); - } - if (packet.error_message) |msg| { - self.allocator.free(msg); - } - if (packet.error_details) |details| { - self.allocator.free(details); - } - if (packet.progress_message) |msg| { - self.allocator.free(msg); - } - if (packet.status_data) |data| { - self.allocator.free(data); - } - if (packet.data_type) |dtype| { - self.allocator.free(dtype); - } - if (packet.data_payload) |payload| { - 
self.allocator.free(payload); - } - if (packet.log_message) |msg| { - self.allocator.free(msg); - } - } - pub fn sendCrashReport(self: *Client, api_key_hash: []const u8, error_type: []const u8, error_message: []const u8, command: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - - // Build binary message: [opcode:1][api_key_hash:16][error_type_len:2][error_type][error_message_len:2][error_message][command_len:2][command] - const total_len = 1 + 16 + 2 + error_type.len + 2 + error_message.len + 2 + command.len; - const message = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(message); - - var offset: usize = 0; - - // Opcode - message[offset] = @intFromEnum(Opcode.crash_report); - offset += 1; - - // API key hash - @memcpy(message[offset .. offset + 16], api_key_hash); - offset += 16; - - // Error type length and data - std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(error_type.len), .big); - offset += 2; - @memcpy(message[offset .. offset + error_type.len], error_type); - offset += error_type.len; - - // Error message length and data - std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(error_message.len), .big); - offset += 2; - @memcpy(message[offset .. offset + error_message.len], error_message); - offset += error_message.len; - - // Command length and data - std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(command.len), .big); - offset += 2; - @memcpy(message[offset .. 
offset + command.len], command); - - // Send WebSocket frame - try sendWebSocketFrame(stream, message); - } - - // Dataset management methods - pub fn sendDatasetList(self: *Client, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - - // Build binary message: [opcode: u8] [api_key_hash: 16 bytes] - const total_len = 1 + 16; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - buffer[0] = @intFromEnum(Opcode.dataset_list); - @memcpy(buffer[1..17], api_key_hash); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendDatasetRegister(self: *Client, name: []const u8, url: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (name.len > 255) return error.NameTooLong; - if (url.len > 1023) return error.URLTooLong; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [name_len: u8] [name: var] [url_len: u16] [url: var] - const total_len = 1 + 16 + 1 + name.len + 2 + url.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.dataset_register); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(name.len); - offset += 1; - - @memcpy(buffer[offset .. offset + name.len], name); - offset += name.len; - - std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(url.len), .big); - offset += 2; - - @memcpy(buffer[offset .. 
offset + url.len], url); - - try sendWebSocketFrame(stream, buffer); - } - - // Jupyter management methods - pub fn sendStartJupyter(self: *Client, name: []const u8, workspace: []const u8, password: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (name.len > 255) return error.NameTooLong; - if (workspace.len > 65535) return error.WorkspacePathTooLong; - if (password.len > 255) return error.PasswordTooLong; - - // Build binary message: - // [opcode:1][api_key_hash:16][name_len:1][name:var][workspace_len:2][workspace:var][password_len:1][password:var] - const total_len = 1 + 16 + 1 + name.len + 2 + workspace.len + 1 + password.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.start_jupyter); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(name.len); - offset += 1; - @memcpy(buffer[offset .. offset + name.len], name); - offset += name.len; - - std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(workspace.len), .big); - offset += 2; - @memcpy(buffer[offset .. offset + workspace.len], workspace); - offset += workspace.len; - - buffer[offset] = @intCast(password.len); - offset += 1; - @memcpy(buffer[offset .. 
offset + password.len], password); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendStopJupyter(self: *Client, service_id: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (service_id.len > 255) return error.InvalidServiceId; - - // Build binary message: [opcode:1][api_key_hash:16][service_id_len:1][service_id:var] - const total_len = 1 + 16 + 1 + service_id.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.stop_jupyter); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(service_id.len); - offset += 1; - @memcpy(buffer[offset .. offset + service_id.len], service_id); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendRemoveJupyter(self: *Client, service_id: []const u8, api_key_hash: []const u8, purge: bool) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (service_id.len > 255) return error.InvalidServiceId; - - // Build binary message: [opcode:1][api_key_hash:16][service_id_len:1][service_id:var][purge:1] - const total_len = 1 + 16 + 1 + service_id.len + 1; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.remove_jupyter); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(service_id.len); - offset += 1; - @memcpy(buffer[offset .. 
offset + service_id.len], service_id); - offset += service_id.len; - - buffer[offset] = if (purge) 0x01 else 0x00; - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendRestoreJupyter(self: *Client, name: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (name.len > 255) return error.NameTooLong; - - // Build binary message: [opcode:1][api_key_hash:16][name_len:1][name:var] - const total_len = 1 + 16 + 1 + name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.restore_jupyter); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(name.len); - offset += 1; - @memcpy(buffer[offset .. offset + name.len], name); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendListJupyter(self: *Client, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - - // Build binary message: [opcode:1][api_key_hash:16] - const total_len = 1 + 16; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - buffer[0] = @intFromEnum(Opcode.list_jupyter); - @memcpy(buffer[1..17], api_key_hash); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendDatasetInfo(self: *Client, name: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (name.len > 255) return error.NameTooLong; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [name_len: u8] [name: var] - const total_len = 1 + 16 + 1 + name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer 
self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.dataset_info); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(name.len); - offset += 1; - - @memcpy(buffer[offset..], name); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendDatasetSearch(self: *Client, term: []const u8, api_key_hash: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - - // Build binary message: [opcode: u8] [api_key_hash: 16 bytes] [term_len: u8] [term: var] - const total_len = 1 + 16 + 1 + term.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.dataset_search); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - buffer[offset] = @intCast(term.len); - offset += 1; - - @memcpy(buffer[offset..], term); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendLogMetric(self: *Client, api_key_hash: []const u8, commit_id: []const u8, name: []const u8, value: f64, step: u32) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - if (name.len > 255) return error.NameTooLong; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [commit_id: 20 bytes] [step: u32] [value: f64] [name_len: u8] [name: var] - const total_len = 1 + 16 + 20 + 4 + 8 + 1 + name.len; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.log_metric); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. 
offset + 20], commit_id); - offset += 20; - - std.mem.writeInt(u32, buffer[offset .. offset + 4][0..4], step, .big); - offset += 4; - - std.mem.writeInt(u64, buffer[offset .. offset + 8][0..8], @as(u64, @bitCast(value)), .big); - offset += 8; - - buffer[offset] = @intCast(name.len); - offset += 1; - - @memcpy(buffer[offset..], name); - - try sendWebSocketFrame(stream, buffer); - } - - pub fn sendGetExperiment(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { - const stream = self.stream orelse return error.NotConnected; - - if (api_key_hash.len != 16) return error.InvalidApiKeyHash; - if (commit_id.len != 20) return error.InvalidCommitId; - - // Build binary message: - // [opcode: u8] [api_key_hash: 16 bytes] [commit_id: 20 bytes] - const total_len = 1 + 16 + 20; - var buffer = try self.allocator.alloc(u8, total_len); - defer self.allocator.free(buffer); - - var offset: usize = 0; - buffer[offset] = @intFromEnum(Opcode.get_experiment); - offset += 1; - - @memcpy(buffer[offset .. offset + 16], api_key_hash); - offset += 16; - - @memcpy(buffer[offset .. offset + 20], commit_id); - offset += 20; - - try sendWebSocketFrame(stream, buffer); - } - - /// Receive and handle dataset response - pub fn receiveAndHandleDatasetResponse(self: *Client, allocator: std.mem.Allocator) ![]const u8 { - const message = try self.receiveMessage(allocator); - defer allocator.free(message); - - const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { - // Fallback: treat as plain response. 
- return allocator.dupe(u8, message); - }; - defer { - if (packet.success_message) |msg| allocator.free(msg); - if (packet.error_message) |msg| allocator.free(msg); - if (packet.error_details) |details| allocator.free(details); - if (packet.data_type) |dtype| allocator.free(dtype); - if (packet.data_payload) |payload| allocator.free(payload); - if (packet.progress_message) |pmsg| allocator.free(pmsg); - if (packet.status_data) |sdata| allocator.free(sdata); - if (packet.log_message) |lmsg| allocator.free(lmsg); - } - - switch (packet.packet_type) { - .data => { - if (packet.data_payload) |payload| { - return allocator.dupe(u8, payload); - } - return allocator.dupe(u8, ""); - }, - .success => { - if (packet.success_message) |msg| { - return allocator.dupe(u8, msg); - } - return allocator.dupe(u8, ""); - }, - .error_packet => { - // Print details and raise appropriate CLI error. - _ = self.handleResponsePacket(packet, "Dataset") catch {}; - return self.convertServerError(packet.error_code.?); - }, - else => { - return allocator.dupe(u8, ""); - }, - } - } -}; - -fn resolveHostAddress(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address { - return std.net.Address.parseIp(host, port) catch |err| switch (err) { - error.InvalidIPAddressFormat => resolveHostname(allocator, host, port), - else => return err, - }; -} - -fn resolveHostname(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address { - var address_list = try std.net.getAddressList(allocator, host, port); - defer address_list.deinit(); - - if (address_list.addrs.len == 0) return error.HostResolutionFailed; - - return address_list.addrs[0]; -} - -test "resolve hostnames for WebSocket connections" { - _ = try resolveHostAddress(std.testing.allocator, "localhost", 9100); -} +// WebSocket module - exports websocket components +pub const client = @import("ws/client.zig"); +pub const frame = @import("ws/frame.zig"); +pub const handshake = @import("ws/handshake.zig"); +pub const 
resolve = @import("ws/resolve.zig"); +pub const response = @import("ws/response.zig"); +pub const response_handlers = @import("ws/response_handlers.zig"); +pub const opcode = @import("ws/opcode.zig"); +pub const utils = @import("ws/utils.zig"); + +pub const Client = client.Client; diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig new file mode 100644 index 0000000..2966d89 --- /dev/null +++ b/cli/src/net/ws/client.zig @@ -0,0 +1,1256 @@ +const deps = @import("deps.zig"); +const std = deps.std; +const crypto = deps.crypto; +const io = deps.io; +const log = deps.log; +const protocol = deps.protocol; +const resolve = @import("resolve.zig"); +const handshake = @import("handshake.zig"); +const frame = @import("frame.zig"); +const response = @import("response.zig"); +const response_handlers = @import("response_handlers.zig"); +const opcode = @import("opcode.zig"); +const utils = @import("utils.zig"); + +/// WebSocket client for binary protocol communication +pub const Client = struct { + allocator: std.mem.Allocator, + stream: ?std.net.Stream, + host: []const u8, + port: u16, + is_tls: bool = false, + + pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 { + return response.formatPrewarmFromStatusRoot(allocator, root); + } + + pub fn connect(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8) !Client { + // Detect TLS + const is_tls = std.mem.startsWith(u8, url, "wss://"); + + // Parse URL (simplified - assumes ws://host:port/path or wss://host:port/path) + const host_start = std.mem.indexOf(u8, url, "//") orelse return error.InvalidURL; + const host_port_start = host_start + 2; + const path_start = std.mem.indexOfPos(u8, url, host_port_start, "/") orelse url.len; + const colon_pos = std.mem.indexOfPos(u8, url, host_port_start, ":"); + + const host_end = blk: { + if (colon_pos) |pos| { + if (pos < path_start) break :blk pos; + } + break :blk path_start; + }; + const host = 
url[host_port_start..host_end]; + + var port: u16 = if (is_tls) 9101 else 9100; // default ports + if (colon_pos) |pos| { + if (pos < path_start) { + const port_start = pos + 1; + const port_end = std.mem.indexOfPos(u8, url, port_start, "/") orelse url.len; + const port_str = url[port_start..port_end]; + port = try std.fmt.parseInt(u16, port_str, 10); + } + } + + // Connect to server + const stream = try std.net.tcpConnectToAddress(try resolve.resolveHostAddress(allocator, host, port)); + + // For TLS, we'd need to wrap the stream with TLS + // For now, we'll just support ws:// and document wss:// requires additional setup + if (is_tls) { + // TODO(context): Implement native wss:// support by introducing a transport abstraction + // (raw TCP vs TLS client stream), performing TLS handshake + certificate verification, and updating + // handshake/frame read+write helpers to operate on the chosen transport. + std.log.warn("TLS (wss://) support requires additional TLS library integration", .{}); + return error.TLSNotSupported; + } + // Perform WebSocket handshake + try handshake.handshake(allocator, stream, host, url, api_key); + + return Client{ + .allocator = allocator, + .stream = stream, + .host = try allocator.dupe(u8, host), + .port = port, + .is_tls = is_tls, + }; + } + + /// Connect to WebSocket server with retry logic + pub fn connectWithRetry(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8, max_retries: u32) !Client { + var retry_count: u32 = 0; + var last_error: anyerror = error.ConnectionFailed; + + while (retry_count < max_retries) { + const client = connect(allocator, url, api_key) catch |err| { + last_error = err; + retry_count += 1; + + if (retry_count < max_retries) { + const delay_ms = @min(1000 * retry_count, 5000); // Exponential backoff, max 5s + log.warn("Connection failed (attempt {d}/{d}), retrying in {d}s...\n", .{ retry_count, max_retries, delay_ms / 1000 }); + std.Thread.sleep(@as(u64, delay_ms) * std.time.ns_per_ms); + } + 
continue; + }; + + if (retry_count > 0) { + log.success("Connected successfully after {d} attempts\n", .{retry_count + 1}); + } + return client; + } + + return last_error; + } + + /// Disconnect from WebSocket server + pub fn disconnect(self: *Client) void { + if (self.stream) |stream| { + stream.close(); + self.stream = null; + } + } + + pub fn close(self: *Client) void { + if (self.stream) |stream| { + stream.close(); + self.stream = null; + } + if (self.host.len > 0) { + self.allocator.free(self.host); + } + } + + pub fn sendValidateRequestCommit(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + + const total_len = 1 + 16 + 1 + 1 + 20; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.Opcode.validate_request); + offset += 1; + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + buffer[offset] = @intFromEnum(opcode.ValidateTargetType.commit_id); + offset += 1; + buffer[offset] = 20; + offset += 1; + @memcpy(buffer[offset .. offset + 20], commit_id); + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + + // Build binary message: [opcode:1][api_key_hash:16][name_len:1][name:var] + const total_len = 1 + 16 + 1 + name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.list_jupyter_packages); + offset += 1; + + @memcpy(buffer[offset .. 
offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + name.len], name); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendSetRunNarrative( + self: *Client, + job_name: []const u8, + patch_json: []const u8, + api_key_hash: []const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong; + if (patch_json.len == 0 or patch_json.len > 0xFFFF) return error.PayloadTooLarge; + + // [opcode] + // [api_key_hash:16] + // [job_name_len:1][job_name] + // [patch_len:2][patch_json] + const total_len = 1 + 16 + 1 + job_name.len + 2 + patch_json.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.set_run_narrative); + offset += 1; + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @as(u8, @intCast(job_name.len)); + offset += 1; + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @as(u16, @intCast(patch_json.len)), .big); + offset += 2; + @memcpy(buffer[offset .. 
offset + patch_json.len], patch_json); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendAnnotateRun( + self: *Client, + job_name: []const u8, + author: []const u8, + note: []const u8, + api_key_hash: []const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong; + if (author.len > 255) return error.PayloadTooLarge; + if (note.len == 0 or note.len > 0xFFFF) return error.PayloadTooLarge; + + // [opcode] + // [api_key_hash:16] + // [job_name_len:1][job_name] + // [author_len:1][author] + // [note_len:2][note] + const total_len = 1 + 16 + 1 + job_name.len + 1 + author.len + 2 + note.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.annotate_run); + offset += 1; + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @as(u8, @intCast(job_name.len)); + offset += 1; + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + buffer[offset] = @as(u8, @intCast(author.len)); + offset += 1; + if (author.len > 0) { + @memcpy(buffer[offset .. offset + author.len], author); + } + offset += author.len; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @as(u16, @intCast(note.len)), .big); + offset += 2; + @memcpy(buffer[offset .. 
offset + note.len], note); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJobWithArgsNoteAndResources( + self: *Client, + job_name: []const u8, + commit_id: []const u8, + priority: u8, + api_key_hash: []const u8, + args: []const u8, + note: []const u8, + cpu: u8, + memory_gb: u8, + gpu: u8, + gpu_memory: ?[]const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + if (args.len > 0xFFFF) return error.PayloadTooLarge; + if (note.len > 0xFFFF) return error.PayloadTooLarge; + + const gpu_mem = gpu_memory orelse ""; + if (gpu_mem.len > 255) return error.PayloadTooLarge; + + // [opcode] + // [api_key_hash] + // [commit_id] + // [priority] + // [job_name_len][job_name] + // [args_len:2][args] + // [note_len:2][note] + // [cpu][memory_gb][gpu][gpu_mem_len][gpu_mem] + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 2 + note.len + 4 + gpu_mem.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job_with_note); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + buffer[offset] = @intCast((args.len >> 8) & 0xFF); + buffer[offset + 1] = @intCast(args.len & 0xFF); + offset += 2; + + if (args.len > 0) { + @memcpy(buffer[offset .. 
offset + args.len], args); + offset += args.len; + } + + buffer[offset] = @intCast((note.len >> 8) & 0xFF); + buffer[offset + 1] = @intCast(note.len & 0xFF); + offset += 2; + + if (note.len > 0) { + @memcpy(buffer[offset .. offset + note.len], note); + offset += note.len; + } + + buffer[offset] = cpu; + buffer[offset + 1] = memory_gb; + buffer[offset + 2] = gpu; + buffer[offset + 3] = @intCast(gpu_mem.len); + offset += 4; + + if (gpu_mem.len > 0) { + @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJobWithArgsAndResources( + self: *Client, + job_name: []const u8, + commit_id: []const u8, + priority: u8, + api_key_hash: []const u8, + args: []const u8, + cpu: u8, + memory_gb: u8, + gpu: u8, + gpu_memory: ?[]const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + if (args.len > 0xFFFF) return error.PayloadTooLarge; + + const gpu_mem = gpu_memory orelse ""; + if (gpu_mem.len > 255) return error.PayloadTooLarge; + + // [opcode] + // [api_key_hash] + // [commit_id] + // [priority] + // [job_name_len][job_name] + // [args_len:2][args] + // [cpu][memory_gb][gpu][gpu_mem_len][gpu_mem] + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + args.len + 4 + gpu_mem.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job_with_args); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset .. 
offset + job_name.len], job_name); + offset += job_name.len; + + buffer[offset] = @intCast((args.len >> 8) & 0xFF); + buffer[offset + 1] = @intCast(args.len & 0xFF); + offset += 2; + + if (args.len > 0) { + @memcpy(buffer[offset .. offset + args.len], args); + offset += args.len; + } + + buffer[offset] = cpu; + buffer[offset + 1] = memory_gb; + buffer[offset + 2] = gpu; + buffer[offset + 3] = @intCast(gpu_mem.len); + offset += 4; + + if (gpu_mem.len > 0) { + @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJobWithSnapshotAndResources( + self: *Client, + job_name: []const u8, + commit_id: []const u8, + priority: u8, + api_key_hash: []const u8, + snapshot_id: []const u8, + snapshot_sha256: []const u8, + cpu: u8, + memory_gb: u8, + gpu: u8, + gpu_memory: ?[]const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + if (snapshot_id.len == 0 or snapshot_id.len > 255) return error.PayloadTooLarge; + if (snapshot_sha256.len == 0 or snapshot_sha256.len > 255) return error.PayloadTooLarge; + + const gpu_mem = gpu_memory orelse ""; + if (gpu_mem.len > 255) return error.PayloadTooLarge; + + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 1 + snapshot_id.len + 1 + snapshot_sha256.len + 4 + gpu_mem.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job_with_snapshot); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset .. 
offset + job_name.len], job_name); + offset += job_name.len; + + buffer[offset] = @intCast(snapshot_id.len); + offset += 1; + + @memcpy(buffer[offset .. offset + snapshot_id.len], snapshot_id); + offset += snapshot_id.len; + + buffer[offset] = @intCast(snapshot_sha256.len); + offset += 1; + + @memcpy(buffer[offset .. offset + snapshot_sha256.len], snapshot_sha256); + offset += snapshot_sha256.len; + + buffer[offset] = cpu; + buffer[offset + 1] = memory_gb; + buffer[offset + 2] = gpu; + buffer[offset + 3] = @intCast(gpu_mem.len); + offset += 4; + + if (gpu_mem.len > 0) { + @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendValidateRequestTask(self: *Client, api_key_hash: []const u8, task_id: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (task_id.len == 0 or task_id.len > 255) return error.PayloadTooLarge; + + const total_len = 1 + 16 + 1 + 1 + task_id.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.validate_request); + offset += 1; + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + buffer[offset] = @intFromEnum(opcode.ValidateTargetType.task_id); + offset += 1; + buffer[offset] = @intCast(task_id.len); + offset += 1; + @memcpy(buffer[offset .. 
offset + task_id.len], task_id); + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJob(self: *Client, job_name: []const u8, commit_id: []const u8, priority: u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + // Validate input lengths + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [commit_id: 20 bytes] [priority: u8] [job_name_len: u8] [job_name: var] + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset..], job_name); + + // Send as WebSocket binary frame + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJobWithResources( + self: *Client, + job_name: []const u8, + commit_id: []const u8, + priority: u8, + api_key_hash: []const u8, + cpu: u8, + memory_gb: u8, + gpu: u8, + gpu_memory: ?[]const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + + const gpu_mem = gpu_memory orelse ""; + if (gpu_mem.len > 255) return error.PayloadTooLarge; + + // Tail encoding: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 4 + gpu_mem.len; + var buffer = try 
self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + buffer[offset] = cpu; + buffer[offset + 1] = memory_gb; + buffer[offset + 2] = gpu; + buffer[offset + 3] = @intCast(gpu_mem.len); + offset += 4; + + if (gpu_mem.len > 0) { + @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJobWithTracking( + self: *Client, + job_name: []const u8, + commit_id: []const u8, + priority: u8, + api_key_hash: []const u8, + tracking_json: []const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + // Validate input lengths + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + if (tracking_json.len > 0xFFFF) return error.PayloadTooLarge; + + // Build binary message: + // [opcode: u8] + // [api_key_hash: 16] + // [commit_id: 20] + // [priority: u8] + // [job_name_len: u8] + // [job_name: var] + // [tracking_json_len: u16] + // [tracking_json: var] + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + tracking_json.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job_with_tracking); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. 
offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + // tracking_json length (big-endian) + buffer[offset] = @intCast((tracking_json.len >> 8) & 0xFF); + buffer[offset + 1] = @intCast(tracking_json.len & 0xFF); + offset += 2; + + if (tracking_json.len > 0) { + @memcpy(buffer[offset .. offset + tracking_json.len], tracking_json); + } + + // Single WebSocket frame for throughput + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendQueueJobWithTrackingAndResources( + self: *Client, + job_name: []const u8, + commit_id: []const u8, + priority: u8, + api_key_hash: []const u8, + tracking_json: []const u8, + cpu: u8, + memory_gb: u8, + gpu: u8, + gpu_memory: ?[]const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + if (tracking_json.len > 0xFFFF) return error.PayloadTooLarge; + + const gpu_mem = gpu_memory orelse ""; + if (gpu_mem.len > 255) return error.PayloadTooLarge; + + // [opcode] + // [api_key_hash] + // [commit_id] + // [priority] + // [job_name_len][job_name] + // [tracking_json_len:2][tracking_json] + // [cpu][memory_gb][gpu][gpu_mem_len][gpu_mem] + const total_len = 1 + 16 + 20 + 1 + 1 + job_name.len + 2 + tracking_json.len + 4 + gpu_mem.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.queue_job_with_tracking); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. 
offset + 20], commit_id); + offset += 20; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + buffer[offset] = @intCast((tracking_json.len >> 8) & 0xFF); + buffer[offset + 1] = @intCast(tracking_json.len & 0xFF); + offset += 2; + + if (tracking_json.len > 0) { + @memcpy(buffer[offset .. offset + tracking_json.len], tracking_json); + offset += tracking_json.len; + } + + buffer[offset] = cpu; + buffer[offset + 1] = memory_gb; + buffer[offset + 2] = gpu; + buffer[offset + 3] = @intCast(gpu_mem.len); + offset += 4; + + if (gpu_mem.len > 0) { + @memcpy(buffer[offset .. offset + gpu_mem.len], gpu_mem); + } + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (job_name.len > 255) return error.JobNameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] + const total_len = 1 + 16 + 1 + job_name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.cancel_job); + offset += 1; + + @memcpy(buffer[offset .. 
offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset..], job_name); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendPrune(self: *Client, api_key_hash: []const u8, prune_type: u8, value: u32) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [prune_type: u8] [value: u4] + const total_len = 1 + 16 + 1 + 4; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.prune); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = prune_type; + offset += 1; + + // Store value in big-endian format + buffer[offset] = @intCast((value >> 24) & 0xFF); + buffer[offset + 1] = @intCast((value >> 16) & 0xFF); + buffer[offset + 2] = @intCast((value >> 8) & 0xFF); + buffer[offset + 3] = @intCast(value & 0xFF); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] + const total_len = 1 + 16; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + buffer[0] = @intFromEnum(opcode.status_request); + @memcpy(buffer[1..17], api_key_hash); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 { + const stream = self.stream orelse return error.NotConnected; + + return frame.receiveBinaryMessage(stream, allocator); + } + + /// Receive and handle response with automatic display + pub fn 
receiveAndHandleResponse(self: *Client, allocator: std.mem.Allocator, operation: []const u8) !void { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + // Fallback: treat as plain response. + std.debug.print("Server response: {s}\n", .{message}); + return; + }; + defer { + if (packet.success_message) |msg| allocator.free(msg); + if (packet.error_message) |msg| allocator.free(msg); + if (packet.error_details) |details| allocator.free(details); + if (packet.data_type) |dtype| allocator.free(dtype); + if (packet.data_payload) |payload| allocator.free(payload); + if (packet.progress_message) |pmsg| allocator.free(pmsg); + if (packet.status_data) |sdata| allocator.free(sdata); + if (packet.log_message) |lmsg| allocator.free(lmsg); + } + + try response_handlers.handleResponsePacket(self, packet, operation); + } + + pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void { + return response_handlers.receiveAndHandleStatusResponse(self, allocator, user_context, options); + } + + pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void { + return response_handlers.receiveAndHandleCancelResponse(self, allocator, user_context, job_name, options); + } + + pub fn handleResponsePacket(self: *Client, packet: protocol.ResponsePacket, operation: []const u8) !void { + return response_handlers.handleResponsePacket(self, packet, operation); + } + + fn convertServerError(self: *Client, server_error: protocol.ErrorCode) anyerror { + _ = self; + return switch (server_error) { + .authentication_failed => error.AuthenticationFailed, + .permission_denied => error.PermissionDenied, + .resource_not_found => error.JobNotFound, + .resource_already_exists => error.ResourceExists, + .timeout => 
error.RequestTimeout, + .server_overloaded, .service_unavailable => error.ServerUnreachable, + .invalid_request => error.InvalidArguments, + .job_not_found => error.JobNotFound, + .job_already_running => error.JobAlreadyRunning, + .job_failed_to_start, .job_execution_failed => error.CommandFailed, + .job_cancelled => error.JobCancelled, + else => error.ServerError, + }; + } + + pub fn sendCrashReport(self: *Client, api_key_hash: []const u8, error_type: []const u8, error_message: []const u8, command: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + + // Build binary message: [opcode:1][api_key_hash:16][error_type_len:2][error_type][error_message_len:2][error_message][command_len:2][command] + const total_len = 1 + 16 + 2 + error_type.len + 2 + error_message.len + 2 + command.len; + const message = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(message); + + var offset: usize = 0; + + // opcode + message[offset] = @intFromEnum(opcode.crash_report); + offset += 1; + + // API key hash + @memcpy(message[offset .. offset + 16], api_key_hash); + offset += 16; + + // Error type length and data + std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(error_type.len), .big); + offset += 2; + @memcpy(message[offset .. offset + error_type.len], error_type); + offset += error_type.len; + + // Error message length and data + std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(error_message.len), .big); + offset += 2; + @memcpy(message[offset .. offset + error_message.len], error_message); + offset += error_message.len; + + // Command length and data + std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(command.len), .big); + offset += 2; + @memcpy(message[offset .. 
offset + command.len], command); + + // Send WebSocket frame + try frame.sendWebSocketFrame(stream, message); + } + + // Dataset management methods + pub fn sendDatasetList(self: *Client, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + + // Build binary message: [opcode: u8] [api_key_hash: 16 bytes] + const total_len = 1 + 16; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + buffer[0] = @intFromEnum(opcode.dataset_list); + @memcpy(buffer[1..17], api_key_hash); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendDatasetRegister(self: *Client, name: []const u8, url: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + if (url.len > 1023) return error.URLTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [name_len: u8] [name: var] [url_len: u16] [url: var] + const total_len = 1 + 16 + 1 + name.len + 2 + url.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.dataset_register); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + name.len], name); + offset += name.len; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(url.len), .big); + offset += 2; + + @memcpy(buffer[offset .. 
offset + url.len], url); + + try frame.sendWebSocketFrame(stream, buffer); + } + + // Jupyter management methods + pub fn sendStartJupyter(self: *Client, name: []const u8, workspace: []const u8, password: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + if (workspace.len > 65535) return error.WorkspacePathTooLong; + if (password.len > 255) return error.PasswordTooLong; + + // Build binary message: + // [opcode:1][api_key_hash:16][name_len:1][name:var][workspace_len:2][workspace:var][password_len:1][password:var] + const total_len = 1 + 16 + 1 + name.len + 2 + workspace.len + 1 + password.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.start_jupyter); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(name.len); + offset += 1; + @memcpy(buffer[offset .. offset + name.len], name); + offset += name.len; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(workspace.len), .big); + offset += 2; + @memcpy(buffer[offset .. offset + workspace.len], workspace); + offset += workspace.len; + + buffer[offset] = @intCast(password.len); + offset += 1; + @memcpy(buffer[offset .. 
offset + password.len], password); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendStopJupyter(self: *Client, service_id: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (service_id.len > 255) return error.InvalidServiceId; + + // Build binary message: [opcode:1][api_key_hash:16][service_id_len:1][service_id:var] + const total_len = 1 + 16 + 1 + service_id.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.stop_jupyter); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(service_id.len); + offset += 1; + @memcpy(buffer[offset .. offset + service_id.len], service_id); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendRemoveJupyter(self: *Client, service_id: []const u8, api_key_hash: []const u8, purge: bool) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (service_id.len > 255) return error.InvalidServiceId; + + // Build binary message: [opcode:1][api_key_hash:16][service_id_len:1][service_id:var][purge:1] + const total_len = 1 + 16 + 1 + service_id.len + 1; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.remove_jupyter); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(service_id.len); + offset += 1; + @memcpy(buffer[offset .. 
offset + service_id.len], service_id); + offset += service_id.len; + + buffer[offset] = if (purge) 0x01 else 0x00; + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendRestoreJupyter(self: *Client, name: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + + // Build binary message: [opcode:1][api_key_hash:16][name_len:1][name:var] + const total_len = 1 + 16 + 1 + name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.restore_jupyter); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(name.len); + offset += 1; + @memcpy(buffer[offset .. offset + name.len], name); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendListJupyter(self: *Client, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + + // Build binary message: [opcode:1][api_key_hash:16] + const total_len = 1 + 16; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + buffer[0] = @intFromEnum(opcode.list_jupyter); + @memcpy(buffer[1..17], api_key_hash); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendDatasetInfo(self: *Client, name: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [name_len: u8] [name: var] + const total_len = 1 + 16 + 1 + name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer 
self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.dataset_info); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset..], name); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendDatasetSearch(self: *Client, term: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + + // Build binary message: [opcode: u8] [api_key_hash: 16 bytes] [term_len: u8] [term: var] + const total_len = 1 + 16 + 1 + term.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.dataset_search); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(term.len); + offset += 1; + + @memcpy(buffer[offset..], term); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendLogMetric(self: *Client, api_key_hash: []const u8, commit_id: []const u8, name: []const u8, value: f64, step: u32) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + if (name.len > 255) return error.NameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [commit_id: 20 bytes] [step: u32] [value: f64] [name_len: u8] [name: var] + const total_len = 1 + 16 + 20 + 4 + 8 + 1 + name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.log_metric); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. 
offset + 20], commit_id); + offset += 20; + + std.mem.writeInt(u32, buffer[offset .. offset + 4][0..4], step, .big); + offset += 4; + + std.mem.writeInt(u64, buffer[offset .. offset + 8][0..8], @as(u64, @bitCast(value)), .big); + offset += 8; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset..], name); + + try frame.sendWebSocketFrame(stream, buffer); + } + + pub fn sendGetExperiment(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (commit_id.len != 20) return error.InvalidCommitId; + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [commit_id: 20 bytes] + const total_len = 1 + 16 + 20; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.get_experiment); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + @memcpy(buffer[offset .. offset + 20], commit_id); + offset += 20; + + try frame.sendWebSocketFrame(stream, buffer); + } + + /// Receive and handle dataset response + pub fn receiveAndHandleDatasetResponse(self: *Client, allocator: std.mem.Allocator) ![]const u8 { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + const packet = protocol.ResponsePacket.deserialize(message, allocator) catch { + // Fallback: treat as plain response. 
+ return allocator.dupe(u8, message); + }; + defer { + if (packet.success_message) |msg| allocator.free(msg); + if (packet.error_message) |msg| allocator.free(msg); + if (packet.error_details) |details| allocator.free(details); + if (packet.data_type) |dtype| allocator.free(dtype); + if (packet.data_payload) |payload| allocator.free(payload); + if (packet.progress_message) |pmsg| allocator.free(pmsg); + if (packet.status_data) |sdata| allocator.free(sdata); + if (packet.log_message) |lmsg| allocator.free(lmsg); + } + + switch (packet.packet_type) { + .data => { + if (packet.data_payload) |payload| { + return allocator.dupe(u8, payload); + } + return allocator.dupe(u8, ""); + }, + .success => { + if (packet.success_message) |msg| { + return allocator.dupe(u8, msg); + } + return allocator.dupe(u8, ""); + }, + .error_packet => { + // Print details and raise appropriate CLI error. + _ = response_handlers.handleResponsePacket(self, packet, "Dataset") catch {}; + return self.convertServerError(packet.error_code.?); + }, + else => { + // Unexpected packet type. 
+ return error.UnexpectedResponse; + }, + } + } +}; diff --git a/cli/src/net/ws/deps.zig b/cli/src/net/ws/deps.zig new file mode 100644 index 0000000..3c78f77 --- /dev/null +++ b/cli/src/net/ws/deps.zig @@ -0,0 +1,8 @@ +pub const std = @import("std"); + +pub const colors = @import("../../utils/colors.zig"); +pub const crypto = @import("../../utils/crypto.zig"); +pub const io = @import("../../utils/io.zig"); +pub const log = @import("../../utils/logging.zig"); + +pub const protocol = @import("../protocol.zig"); diff --git a/cli/src/net/ws/frame.zig b/cli/src/net/ws/frame.zig new file mode 100644 index 0000000..10fdf1f --- /dev/null +++ b/cli/src/net/ws/frame.zig @@ -0,0 +1,71 @@ +const std = @import("std"); + +pub fn sendWebSocketFrame(stream: std.net.Stream, payload: []const u8) !void { + var frame: [14]u8 = undefined; + var frame_len: usize = 2; + + // FIN=1, opcode=0x2 (binary), MASK=1 + frame[0] = 0x82 | 0x80; + + // Payload length + if (payload.len < 126) { + frame[1] = @as(u8, @intCast(payload.len)) | 0x80; + } else if (payload.len < 65536) { + frame[1] = 126 | 0x80; + frame[2] = @intCast(payload.len >> 8); + frame[3] = @intCast(payload.len & 0xFF); + frame_len = 4; + } else { + return error.PayloadTooLarge; + } + + // Generate random mask (4 bytes). + // RFC 6455 section 5.3 requires an unpredictable, per-frame masking key; + // the previous timestamp-derived value repeated one guessable byte in all + // four positions, so draw the mask from the OS CSPRNG instead. + var mask: [4]u8 = undefined; + std.crypto.random.bytes(&mask); + + @memcpy(frame[frame_len .. frame_len + 4], &mask); + frame_len += 4; + + _ = try stream.write(frame[0..frame_len]); + + var masked_payload = try std.heap.page_allocator.alloc(u8, payload.len); + defer std.heap.page_allocator.free(masked_payload); + + for (payload, 0..) 
|byte, j| { + masked_payload[j] = byte ^ mask[j % 4]; + } + + _ = try stream.write(masked_payload); +} + +pub fn receiveBinaryMessage(stream: std.net.Stream, allocator: std.mem.Allocator) ![]u8 { + var header: [2]u8 = undefined; + var header_read: usize = 0; // stream.read may legally return fewer bytes than requested + while (header_read < header.len) { const n = try stream.read(header[header_read..]); if (n == 0) return error.ConnectionClosed; header_read += n; } + + if (header[0] != 0x82) return error.InvalidFrame; + + var payload_len: usize = header[1]; + if (payload_len == 126) { + var len_bytes: [2]u8 = undefined; + var len_read: usize = 0; while (len_read < len_bytes.len) { const n = try stream.read(len_bytes[len_read..]); if (n == 0) return error.ConnectionClosed; len_read += n; } + payload_len = (@as(usize, len_bytes[0]) << 8) | len_bytes[1]; + } else if (payload_len == 127) { + return error.PayloadTooLarge; + } + + const payload = try allocator.alloc(u8, payload_len); + errdefer allocator.free(payload); + + var bytes_read: usize = 0; + while (bytes_read < payload_len) { + const n = try stream.read(payload[bytes_read..]); + if (n == 0) return error.ConnectionClosed; + bytes_read += n; + } + + return payload; +} diff --git a/cli/src/net/ws/handshake.zig b/cli/src/net/ws/handshake.zig new file mode 100644 index 0000000..48a7248 --- /dev/null +++ b/cli/src/net/ws/handshake.zig @@ -0,0 +1,114 @@ +const std = @import("std"); + +fn generateWebSocketKey(allocator: std.mem.Allocator) ![]u8 { + var random_bytes: [16]u8 = undefined; + std.crypto.random.bytes(&random_bytes); + + const base64 = std.base64.standard.Encoder; + const result = try allocator.alloc(u8, base64.calcSize(random_bytes.len)); + _ = base64.encode(result, &random_bytes); + return result; +} + +pub fn handshake( + allocator: std.mem.Allocator, + stream: std.net.Stream, + host: []const u8, + url: []const u8, + api_key: []const u8, +) !void { + const key = try generateWebSocketKey(allocator); + defer allocator.free(key); + + const request = try std.fmt.allocPrint( + allocator, + "GET {s} HTTP/1.1\r\n" ++ + "Host: {s}\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: {s}\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "X-API-Key: 
{s}\r\n" ++ + "\r\n", + .{ url, host, key, api_key }, + ); + defer allocator.free(request); + + _ = try stream.write(request); + + var response_buf: [4096]u8 = undefined; + var bytes_read: usize = 0; + var header_complete = false; + + while (!header_complete and bytes_read < response_buf.len - 1) { + const chunk_bytes = try stream.read(response_buf[bytes_read..]); + if (chunk_bytes == 0) break; + bytes_read += chunk_bytes; + + if (std.mem.indexOf(u8, response_buf[0..bytes_read], "\r\n\r\n") != null) { + header_complete = true; + } + } + + const response = response_buf[0..bytes_read]; + + if (std.mem.indexOf(u8, response, "101 Switching Protocols") == null) { + if (std.mem.indexOf(u8, response, "404 Not Found") != null) { + std.debug.print("\n❌ WebSocket Connection Failed\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("The WebSocket endpoint '/ws' was not found on the server.\n\n", .{}); + std.debug.print("This usually means:\n", .{}); + std.debug.print(" • API server is not running\n", .{}); + std.debug.print(" • Incorrect server address in config\n", .{}); + std.debug.print(" • Different service running on that port\n\n", .{}); + std.debug.print("To diagnose:\n", .{}); + std.debug.print(" • Verify server address: Check ~/.ml/config.toml\n", .{}); + std.debug.print(" • Test connectivity: curl http://:/health\n", .{}); + std.debug.print(" • Contact your server administrator if the issue persists\n\n", .{}); + return error.EndpointNotFound; + } else if (std.mem.indexOf(u8, response, "401 Unauthorized") != null) { + std.debug.print("\n❌ Authentication Failed\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("Invalid or missing API key.\n\n", .{}); + std.debug.print("To fix:\n", .{}); + std.debug.print(" • Verify API key in ~/.ml/config.toml matches server configuration\n", .{}); + std.debug.print(" • Request a new API key from your 
administrator if needed\n\n", .{}); + return error.AuthenticationFailed; + } else if (std.mem.indexOf(u8, response, "403 Forbidden") != null) { + std.debug.print("\n❌ Access Denied\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("Your API key doesn't have permission for this operation.\n\n", .{}); + std.debug.print("To fix:\n", .{}); + std.debug.print(" • Contact your administrator to grant necessary permissions\n\n", .{}); + return error.PermissionDenied; + } else if (std.mem.indexOf(u8, response, "503 Service Unavailable") != null) { + std.debug.print("\n❌ Server Unavailable\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("The server is temporarily unavailable.\n\n", .{}); + std.debug.print("This could be due to:\n", .{}); + std.debug.print(" • Server maintenance\n", .{}); + std.debug.print(" • High load\n", .{}); + std.debug.print(" • Server restart\n\n", .{}); + std.debug.print("To resolve:\n", .{}); + std.debug.print(" • Wait a moment and try again\n", .{}); + std.debug.print(" • Contact administrator if the issue persists\n\n", .{}); + return error.ServerUnavailable; + } else { + std.debug.print("\n❌ WebSocket Handshake Failed\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("Expected HTTP 101 Switching Protocols, but received:\n", .{}); + + const newline_pos = std.mem.indexOf(u8, response, "\r\n") orelse response.len; + const status_line = response[0..newline_pos]; + std.debug.print(" {s}\n\n", .{status_line}); + + std.debug.print("To diagnose:\n", .{}); + std.debug.print(" • Verify server address in ~/.ml/config.toml\n", .{}); + std.debug.print(" • Check network connectivity to the server\n", .{}); + std.debug.print(" • Contact your administrator for assistance\n\n", .{}); + return error.HandshakeFailed; + } + } + + std.posix.nanosleep(0, 10 * 
std.time.ns_per_ms); +} diff --git a/cli/src/net/ws/opcode.zig b/cli/src/net/ws/opcode.zig new file mode 100644 index 0000000..f679395 --- /dev/null +++ b/cli/src/net/ws/opcode.zig @@ -0,0 +1,73 @@ +pub const Opcode = enum(u8) { + queue_job = 0x01, + queue_job_with_tracking = 0x0C, + queue_job_with_snapshot = 0x17, + queue_job_with_args = 0x1A, + queue_job_with_note = 0x1B, + annotate_run = 0x1C, + set_run_narrative = 0x1D, + status_request = 0x02, + cancel_job = 0x03, + prune = 0x04, + crash_report = 0x05, + log_metric = 0x0A, + get_experiment = 0x0B, + start_jupyter = 0x0D, + stop_jupyter = 0x0E, + remove_jupyter = 0x18, + restore_jupyter = 0x19, + list_jupyter = 0x0F, + list_jupyter_packages = 0x1E, + + validate_request = 0x16, + + // Dataset management opcodes + dataset_list = 0x06, + dataset_register = 0x07, + dataset_info = 0x08, + dataset_search = 0x09, + + // Structured response opcodes + response_success = 0x10, + response_error = 0x11, + response_progress = 0x12, + response_status = 0x13, + response_data = 0x14, + response_log = 0x15, +}; + +pub const ValidateTargetType = enum(u8) { + commit_id = 0, + task_id = 1, +}; + +pub const queue_job = Opcode.queue_job; +pub const queue_job_with_tracking = Opcode.queue_job_with_tracking; +pub const queue_job_with_snapshot = Opcode.queue_job_with_snapshot; +pub const queue_job_with_args = Opcode.queue_job_with_args; +pub const queue_job_with_note = Opcode.queue_job_with_note; +pub const annotate_run = Opcode.annotate_run; +pub const set_run_narrative = Opcode.set_run_narrative; +pub const status_request = Opcode.status_request; +pub const cancel_job = Opcode.cancel_job; +pub const prune = Opcode.prune; +pub const crash_report = Opcode.crash_report; +pub const log_metric = Opcode.log_metric; +pub const get_experiment = Opcode.get_experiment; +pub const start_jupyter = Opcode.start_jupyter; +pub const stop_jupyter = Opcode.stop_jupyter; +pub const remove_jupyter = Opcode.remove_jupyter; +pub const restore_jupyter = 
Opcode.restore_jupyter; +pub const list_jupyter = Opcode.list_jupyter; +pub const list_jupyter_packages = Opcode.list_jupyter_packages; +pub const validate_request = Opcode.validate_request; +pub const dataset_list = Opcode.dataset_list; +pub const dataset_register = Opcode.dataset_register; +pub const dataset_info = Opcode.dataset_info; +pub const dataset_search = Opcode.dataset_search; +pub const response_success = Opcode.response_success; +pub const response_error = Opcode.response_error; +pub const response_progress = Opcode.response_progress; +pub const response_status = Opcode.response_status; +pub const response_data = Opcode.response_data; +pub const response_log = Opcode.response_log; diff --git a/cli/src/net/ws/opcodes.zig b/cli/src/net/ws/opcodes.zig new file mode 100644 index 0000000..6d42056 --- /dev/null +++ b/cli/src/net/ws/opcodes.zig @@ -0,0 +1,2 @@ +pub const Opcode = @import("opcode.zig").Opcode; +pub const ValidateTargetType = @import("opcode.zig").ValidateTargetType; diff --git a/cli/src/net/ws/resolve.zig b/cli/src/net/ws/resolve.zig new file mode 100644 index 0000000..1a71763 --- /dev/null +++ b/cli/src/net/ws/resolve.zig @@ -0,0 +1,21 @@ +const std = @import("std"); + +pub fn resolveHostAddress(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address { + return std.net.Address.parseIp(host, port) catch |err| switch (err) { + error.InvalidIPAddressFormat => resolveHostname(allocator, host, port), + else => return err, + }; +} + +fn resolveHostname(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address { + var address_list = try std.net.getAddressList(allocator, host, port); + defer address_list.deinit(); + + if (address_list.addrs.len == 0) return error.HostResolutionFailed; + + return address_list.addrs[0]; +} + +test "resolve hostnames for WebSocket connections" { + _ = try resolveHostAddress(std.testing.allocator, "localhost", 9100); +} diff --git a/cli/src/net/ws/response.zig 
b/cli/src/net/ws/response.zig new file mode 100644 index 0000000..b073d0c --- /dev/null +++ b/cli/src/net/ws/response.zig @@ -0,0 +1,74 @@ +const std = @import("std"); + +fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { + const v_opt = obj.get(key); + if (v_opt == null) { + return null; + } + const v = v_opt.?; + if (v != .string) { + return null; + } + return v.string; +} + +fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 { + const v_opt = obj.get(key); + if (v_opt == null) { + return null; + } + const v = v_opt.?; + if (v != .integer) { + return null; + } + return v.integer; +} + +pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 { + const prewarm_val_opt = root.get("prewarm"); + if (prewarm_val_opt == null) { + return null; + } + const prewarm_val = prewarm_val_opt.?; + if (prewarm_val != .array) { + return null; + } + + const items = prewarm_val.array.items; + if (items.len == 0) { + return null; + } + + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + const writer = out.writer(allocator); + try writer.writeAll("Prewarm:\n"); + + for (items) |item| { + if (item != .object) { + continue; + } + + const obj = item.object; + + const worker_id = jsonGetString(obj, "worker_id") orelse ""; + const task_id = jsonGetString(obj, "task_id") orelse ""; + const phase = jsonGetString(obj, "phase") orelse ""; + const started_at = jsonGetString(obj, "started_at") orelse ""; + const dataset_count = jsonGetInt(obj, "dataset_count") orelse 0; + const snapshot_id = jsonGetString(obj, "snapshot_id") orelse ""; + const env_image = jsonGetString(obj, "env_image") orelse ""; + const env_hit = jsonGetInt(obj, "env_hit") orelse 0; + const env_miss = jsonGetInt(obj, "env_miss") orelse 0; + const env_built = jsonGetInt(obj, "env_built") orelse 0; + + try writer.print( + " worker={s} task={s} phase={s} datasets={d} snapshot={s} env={s} env_hit={d} env_miss={d} env_built={d} 
started={s}\n", + .{ worker_id, task_id, phase, dataset_count, snapshot_id, env_image, env_hit, env_miss, env_built, started_at }, + ); + } + + const owned = try out.toOwnedSlice(allocator); + return owned; +} diff --git a/cli/src/net/ws/response_handlers.zig b/cli/src/net/ws/response_handlers.zig new file mode 100644 index 0000000..1c04390 --- /dev/null +++ b/cli/src/net/ws/response_handlers.zig @@ -0,0 +1,449 @@ +const deps = @import("deps.zig"); +const std = deps.std; +const io = deps.io; +const protocol = deps.protocol; +const colors = deps.colors; +const Client = @import("client.zig").Client; +const utils = @import("utils.zig"); + +/// Receive and handle status response with user filtering +pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void { + _ = user_context; // TODO: Use for filtering + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + // Check if message is JSON (or contains JSON) or plain text + if (message[0] == '{') { + // Parse JSON response + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{}); + defer parsed.deinit(); + const root = parsed.value.object; + + if (options.json) { + // Output raw JSON + var out = io.stdoutWriter(); + try out.print("{s}\n", .{message}); + } else { + // Display user info + if (root.get("user")) |user_obj| { + const user = user_obj.object; + const name = user.get("name").?.string; + const admin = user.get("admin").?.bool; + colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin }); + } + + // Display task summary + if (root.get("tasks")) |tasks_obj| { + const tasks = tasks_obj.object; + const total = tasks.get("total").?.integer; + const queued = tasks.get("queued").?.integer; + const running = tasks.get("running").?.integer; + const failed = tasks.get("failed").?.integer; + const completed = tasks.get("completed").?.integer; + colors.printInfo( + 
"Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n", + .{ total, queued, running, failed, completed }, + ); + } + + const per_section_limit: usize = options.limit orelse 5; + + const TaskStatus = enum { queued, running, failed, completed }; + + const TaskPrinter = struct { + fn statusLabel(s: TaskStatus) []const u8 { + return switch (s) { + .queued => "Queued", + .running => "Running", + .failed => "Failed", + .completed => "Completed", + }; + } + + fn statusMatch(s: TaskStatus) []const u8 { + return switch (s) { + .queued => "queued", + .running => "running", + .failed => "failed", + .completed => "completed", + }; + } + + fn shorten(s: []const u8, max_len: usize) []const u8 { + if (s.len <= max_len) return s; + return s[0..max_len]; + } + + fn printSection( + allocator2: std.mem.Allocator, + queue_items: []const std.json.Value, + status: TaskStatus, + limit2: usize, + ) !void { + _ = allocator2; + const label = statusLabel(status); + const want = statusMatch(status); + std.debug.print("\n{s}:\n", .{label}); + + var shown: usize = 0; + for (queue_items) |item| { + if (item != .object) continue; + const obj = item.object; + const st = utils.jsonGetString(obj, "status") orelse ""; + if (!std.mem.eql(u8, st, want)) continue; + + const id = utils.jsonGetString(obj, "id") orelse ""; + const job_name = utils.jsonGetString(obj, "job_name") orelse ""; + const worker_id = utils.jsonGetString(obj, "worker_id") orelse ""; + const err = utils.jsonGetString(obj, "error") orelse ""; + + if (std.mem.eql(u8, want, "failed")) { + colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name }); + if (worker_id.len > 0) { + std.debug.print(" (worker={s})", .{worker_id}); + } + std.debug.print("\n", .{}); + if (err.len > 0) { + std.debug.print(" error: {s}\n", .{shorten(err, 160)}); + } + } else if (std.mem.eql(u8, want, "running")) { + colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name }); + if (worker_id.len > 0) { + std.debug.print(" (worker={s})", 
.{worker_id}); + } + std.debug.print("\n", .{}); + } else if (std.mem.eql(u8, want, "queued")) { + std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name }); + } else { + colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name }); + } + + shown += 1; + if (shown >= limit2) break; + } + + if (shown == 0) { + std.debug.print(" (none)\n", .{}); + } else { + // Indicate there may be more. + var total_for_status: usize = 0; + for (queue_items) |item| { + if (item != .object) continue; + const obj = item.object; + const st = utils.jsonGetString(obj, "status") orelse ""; + if (std.mem.eql(u8, st, want)) total_for_status += 1; + } + if (total_for_status > shown) { + std.debug.print(" ... and {d} more\n", .{total_for_status - shown}); + } + } + } + }; + + if (root.get("queue")) |queue_val| { + if (queue_val == .array) { + const items = queue_val.array.items; + try TaskPrinter.printSection(allocator, items, .queued, per_section_limit); + try TaskPrinter.printSection(allocator, items, .running, per_section_limit); + try TaskPrinter.printSection(allocator, items, .failed, per_section_limit); + try TaskPrinter.printSection(allocator, items, .completed, per_section_limit); + } + } + + if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| { + defer allocator.free(section); + colors.printInfo("{s}", .{section}); + } + } + } else { + // Handle plain text response - filter out non-printable characters + var clean_msg = allocator.alloc(u8, message.len) catch { + if (options.json) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); + } else { + std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); + } + return; + }; + defer allocator.free(clean_msg); + + var clean_len: usize = 0; + for (message) |byte| { + // Skip WebSocket frame header bytes and non-printable chars + if (byte >= 32 and byte <= 126) { // printable ASCII only + clean_msg[clean_len] = byte; + 
clean_len += 1; + } + } + + // Look for common error messages in the cleaned data + if (clean_len > 0) { + const cleaned = clean_msg[0..clean_len]; + if (options.json) { + if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{}); + } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"authentication_failed\"}}\n", .{}); + } else { + var out = io.stdoutWriter(); + try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned}); + } + } else { + if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { + std.debug.print("Insufficient permissions to view jobs\n", .{}); + } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { + std.debug.print("Authentication failed\n", .{}); + } else { + std.debug.print("Server response: {s}\n", .{cleaned}); + } + } + } else { + if (options.json) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); + } else { + std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); + } + } + return; + } +} + +/// Receive and handle cancel response with user permissions +pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + // Check if message is JSON or plain text + if (message[0] == '{') { + // Parse JSON response + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{}); + defer parsed.deinit(); + const root = parsed.value.object; + + if (options.json) { + // Output raw JSON + var out = io.stdoutWriter(); + try out.print("{s}\n", .{message}); + } else { + // Display user-friendly output + if (root.get("success")) 
|success_val| { + if (success_val.bool) { + colors.printSuccess("Job '{s}' canceled successfully\n", .{job_name}); + } else { + colors.printError("Failed to cancel job '{s}'\n", .{job_name}); + if (root.get("error")) |error_val| { + colors.printError("Error: {s}\n", .{error_val.string}); + } + } + } else { + colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name }); + } + } + } else { + // Handle plain text response - filter out non-printable characters + var clean_msg = allocator.alloc(u8, message.len) catch { + if (options.json) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); + } else { + std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); + } + return; + }; + defer allocator.free(clean_msg); + + var clean_len: usize = 0; + for (message) |byte| { + // Skip WebSocket frame header bytes and non-printable chars + if (byte >= 32 and byte <= 126) { // printable ASCII only + clean_msg[clean_len] = byte; + clean_len += 1; + } + } + + if (clean_len > 0) { + const cleaned = clean_msg[0..clean_len]; + if (options.json) { + if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{}); + } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"authentication_failed\"}}\n", .{}); + } else { + var out = io.stdoutWriter(); + try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned}); + } + } else { + if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) { + std.debug.print("Insufficient permissions to cancel job\n", .{}); + } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) { + std.debug.print("Authentication failed\n", .{}); + } else { + colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ 
job_name, user_context.name }); + colors.printInfo("Response: {s}\n", .{cleaned}); + } + } + } else { + if (options.json) { + var out = io.stdoutWriter(); + try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len}); + } else { + std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len}); + } + } + return; + } +} + +/// Handle response packet with appropriate display +pub fn handleResponsePacket(self: *Client, packet: protocol.ResponsePacket, operation: []const u8) !void { + switch (packet.packet_type) { + .success => { + if (packet.success_message) |msg| { + if (msg.len > 0) { + std.debug.print("✓ {s}: {s}\n", .{ operation, msg }); + } else { + std.debug.print("✓ {s} completed successfully\n", .{operation}); + } + } else { + std.debug.print("✓ {s} completed successfully\n", .{operation}); + } + }, + .error_packet => { + const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); + std.debug.print("✗ {s} failed: {s}\n", .{ operation, error_msg }); + + if (packet.error_message) |msg| { + if (msg.len > 0) { + std.debug.print("Details: {s}\n", .{msg}); + } + } + + if (packet.error_details) |details| { + if (details.len > 0) { + std.debug.print("Additional info: {s}\n", .{details}); + } + } + + // Convert to appropriate CLI error + return convertServerError(self, packet.error_code.?); + }, + .progress => { + if (packet.progress_type) |ptype| { + switch (ptype) { + .percentage => { + const percentage = packet.progress_value.?; + if (packet.progress_total) |total| { + std.debug.print("Progress: {d}/{d} ({d:.1}%)\n", .{ percentage, total, @as(f32, @floatFromInt(percentage)) * 100.0 / @as(f32, @floatFromInt(total)) }); + } else { + std.debug.print("Progress: {d}%\n", .{percentage}); + } + }, + .stage => { + if (packet.progress_message) |msg| { + std.debug.print("Stage: {s}\n", .{msg}); + } + }, + .message => { + if (packet.progress_message) |msg| { + std.debug.print("Info: {s}\n", .{msg}); + } + }, + 
.bytes_transferred => { + const bytes = packet.progress_value.?; + if (packet.progress_total) |total| { + const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0; + const total_mb = @as(f64, @floatFromInt(total)) / 1024.0 / 1024.0; + std.debug.print("Transferred: {d:.2} MB / {d:.2} MB\n", .{ transferred_mb, total_mb }); + } else { + const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0; + std.debug.print("Transferred: {d:.2} MB\n", .{transferred_mb}); + } + }, + } + } + }, + .status => { + if (packet.status_data) |data| { + std.debug.print("Status: {s}\n", .{data}); + } + }, + .data => { + if (packet.data_type) |dtype| { + std.debug.print("Data [{s}]: ", .{dtype}); + if (packet.data_payload) |payload| { + // Try to display as string if it looks like text + const is_text = for (payload) |byte| { + if (byte < 32 and byte != '\n' and byte != '\r' and byte != '\t') break false; + } else true; + + if (is_text) { + std.debug.print("{s}\n", .{payload}); + } else { + std.debug.print("{d} bytes\n", .{payload.len}); + } + } + } + }, + .log => { + if (packet.log_level) |level| { + const level_name = protocol.ResponsePacket.getLogLevelName(level); + if (packet.log_message) |msg| { + std.debug.print("[{s}] {s}\n", .{ level_name, msg }); + } + } + }, + } +} + +/// Convert server error code to CLI error +fn convertServerError(self: *Client, server_error: protocol.ErrorCode) anyerror { + _ = self; + return switch (server_error) { + .authentication_failed => error.AuthenticationFailed, + .permission_denied => error.PermissionDenied, + .resource_not_found => error.JobNotFound, + .resource_already_exists => error.ResourceExists, + .timeout => error.RequestTimeout, + .server_overloaded, .service_unavailable => error.ServerUnreachable, + .invalid_request => error.InvalidArguments, + .job_not_found => error.JobNotFound, + .job_already_running => error.JobAlreadyRunning, + .job_failed_to_start, .job_execution_failed => error.CommandFailed, + 
.job_cancelled => error.JobCancelled, + else => error.ServerError, + }; +} + +/// Clean up packet allocated memory +pub fn cleanupPacket(self: *Client, packet: protocol.ResponsePacket) void { + if (packet.success_message) |msg| { + self.allocator.free(msg); + } + if (packet.error_message) |msg| { + self.allocator.free(msg); + } + if (packet.error_details) |details| { + self.allocator.free(details); + } + if (packet.progress_message) |msg| { + self.allocator.free(msg); + } + if (packet.status_data) |data| { + self.allocator.free(data); + } + if (packet.data_type) |dtype| { + self.allocator.free(dtype); + } + if (packet.data_payload) |payload| { + self.allocator.free(payload); + } + if (packet.log_message) |msg| { + self.allocator.free(msg); + } +} diff --git a/cli/src/net/ws/utils.zig b/cli/src/net/ws/utils.zig new file mode 100644 index 0000000..63d061b --- /dev/null +++ b/cli/src/net/ws/utils.zig @@ -0,0 +1,25 @@ +const std = @import("std"); + +pub fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 { + const v_opt = obj.get(key); + if (v_opt == null) { + return null; + } + const v = v_opt.?; + if (v != .string) { + return null; + } + return v.string; +} + +pub fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 { + const v_opt = obj.get(key); + if (v_opt == null) { + return null; + } + const v = v_opt.?; + if (v != .integer) { + return null; + } + return v.integer; +} diff --git a/cli/src/utils.zig b/cli/src/utils.zig index 17f2833..25b4e70 100644 --- a/cli/src/utils.zig +++ b/cli/src/utils.zig @@ -1,8 +1,10 @@ -// Utils module - exports all utility modules +// Utilities module - exports all utility modules pub const colors = @import("utils/colors.zig"); pub const crypto = @import("utils/crypto.zig"); pub const history = @import("utils/history.zig"); +pub const io = @import("utils/io.zig"); pub const logging = @import("utils/logging.zig"); pub const rsync = @import("utils/rsync.zig"); pub const rsync_embedded = 
@import("utils/rsync_embedded.zig"); +pub const rsync_embedded_binary = @import("utils/rsync_embedded_binary.zig"); pub const storage = @import("utils/storage.zig"); diff --git a/cli/src/utils/crypto.zig b/cli/src/utils/crypto.zig index ef0d378..7b5b959 100644 --- a/cli/src/utils/crypto.zig +++ b/cli/src/utils/crypto.zig @@ -18,6 +18,7 @@ fn hexNibble(c: u8) ?u8 { pub fn decodeHex(allocator: std.mem.Allocator, hex: []const u8) ![]u8 { if ((hex.len % 2) != 0) return error.InvalidHex; const out = try allocator.alloc(u8, hex.len / 2); + errdefer allocator.free(out); var i: usize = 0; while (i < out.len) : (i += 1) { const hi = hexNibble(hex[i * 2]) orelse return error.InvalidHex; diff --git a/cli/src/utils/io.zig b/cli/src/utils/io.zig new file mode 100644 index 0000000..b79d99e --- /dev/null +++ b/cli/src/utils/io.zig @@ -0,0 +1,59 @@ +const std = @import("std"); + +fn writeAllFd(fd: std.posix.fd_t, data: []const u8) std.Io.Writer.Error!void { + var off: usize = 0; + while (off < data.len) { + const n = std.posix.write(fd, data[off..]) catch return error.WriteFailed; + if (n == 0) return error.WriteFailed; + off += n; + } +} + +fn drainStdout(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize { + _ = w; + if (data.len == 0) return 0; + var written: usize = 0; + for (data) |chunk| { + try writeAllFd(std.posix.STDOUT_FILENO, chunk); + written += chunk.len; + } + if (splat > 0) { + const last = data[data.len - 1]; + var i: usize = 0; + while (i < splat) : (i += 1) { + try writeAllFd(std.posix.STDOUT_FILENO, last); + written += last.len; + } + } + return written; +} + +fn drainStderr(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize { + _ = w; + if (data.len == 0) return 0; + var written: usize = 0; + for (data) |chunk| { + try writeAllFd(std.posix.STDERR_FILENO, chunk); + written += chunk.len; + } + if (splat > 0) { + const last = data[data.len - 1]; + var i: usize = 0; + while (i < splat) : (i += 1) { 
+ try writeAllFd(std.posix.STDERR_FILENO, last); + written += last.len; + } + } + return written; +} + +const stdout_vtable = std.Io.Writer.VTable{ .drain = drainStdout }; +const stderr_vtable = std.Io.Writer.VTable{ .drain = drainStderr }; + +pub fn stdoutWriter() std.Io.Writer { + return .{ .vtable = &stdout_vtable, .buffer = &[_]u8{}, .end = 0 }; +} + +pub fn stderrWriter() std.Io.Writer { + return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 }; +} diff --git a/cli/src/utils/rsync_embedded.zig b/cli/src/utils/rsync_embedded.zig index 126cc9f..56d7bab 100644 --- a/cli/src/utils/rsync_embedded.zig +++ b/cli/src/utils/rsync_embedded.zig @@ -6,24 +6,127 @@ pub const EmbeddedRsync = struct { const Self = @This(); - /// Extract embedded rsync binary to temporary location - pub fn extractRsyncBinary(self: Self) ![]const u8 { - const rsync_path = "/tmp/ml_rsync"; + fn encodeHexLower(out: []u8, bytes: []const u8) void { + const hex = "0123456789abcdef"; + var i: usize = 0; + while (i < bytes.len) : (i += 1) { + const b = bytes[i]; + out[i * 2] = hex[(b >> 4) & 0x0f]; + out[i * 2 + 1] = hex[b & 0x0f]; + } + } - // Check if rsync binary already exists - if (std.fs.cwd().openFile(rsync_path, .{})) |file| { - file.close(); - // Check if it's executable - const stat = try std.fs.cwd().statFile(rsync_path); - if (stat.mode & 0o111 != 0) { - return try self.allocator.dupe(u8, rsync_path); - } - } else |_| { - // Need to extract the binary - try self.extractAndSetExecutable(rsync_path); + fn cacheBaseDir(self: Self) ![]u8 { + if (std.posix.getenv("XDG_CACHE_HOME")) |z| { + return self.allocator.dupe(u8, std.mem.sliceTo(z, 0)); } - return try self.allocator.dupe(u8, rsync_path); + const home = if (std.posix.getenv("HOME")) |z| std.mem.sliceTo(z, 0) else return error.MissingHome; + const target = @import("builtin").target; + if (target.os.tag == .macos) { + return std.fmt.allocPrint(self.allocator, "{s}/Library/Caches", .{home}); + } + return 
std.fmt.allocPrint(self.allocator, "{s}/.cache", .{home}); + } + + fn resolveRsyncPathOverride(self: Self) !?[]u8 { + const z = std.posix.getenv("ML_RSYNC_PATH") orelse return null; + const p = std.mem.sliceTo(z, 0); + if (p.len == 0) return null; + + // Basic validation: must exist and be executable. + const file = std.fs.openFileAbsolute(p, .{}) catch return error.InvalidRsyncPath; + defer file.close(); + const stat = try file.stat(); + if (stat.mode & 0o111 == 0) return error.InvalidRsyncPath; + + const dup = try self.allocator.dupe(u8, p); + return @as(?[]u8, dup); + } + + fn resolveSystemRsyncPath(self: Self) !?[]u8 { + const path_z = std.posix.getenv("PATH") orelse return null; + const path = std.mem.sliceTo(path_z, 0); + + var it = std.mem.splitScalar(u8, path, ':'); + while (it.next()) |dir| { + if (dir.len == 0) continue; + const candidate = try std.fmt.allocPrint(self.allocator, "{s}/{s}", .{ dir, "rsync" }); + errdefer self.allocator.free(candidate); + + const file = std.fs.openFileAbsolute(candidate, .{}) catch { + self.allocator.free(candidate); + continue; + }; + defer file.close(); + + const stat = file.stat() catch { + self.allocator.free(candidate); + continue; + }; + + if (stat.mode & 0o111 == 0) { + self.allocator.free(candidate); + continue; + } + + return @as(?[]u8, candidate); + } + + return null; + } + + /// Resolve an rsync binary: prefer the ML_RSYNC_PATH override, then a system rsync found on PATH, and only then extract the embedded binary into a per-user cache directory. Caller owns the returned path. + pub fn extractRsyncBinary(self: Self) ![]const u8 { + if (try self.resolveRsyncPathOverride()) |p| { + return p; + } + + if (try self.resolveSystemRsyncPath()) |p| { + return p; + } + + const embedded_binary = @import("rsync_embedded_binary.zig"); + const digest = embedded_binary.rsyncBinarySha256(); + + var digest_hex: [64]u8 = undefined; + encodeHexLower(digest_hex[0..], digest[0..]); + + const cache_base = try self.cacheBaseDir(); + defer self.allocator.free(cache_base); + + const cache_dir = try std.fmt.allocPrint(self.allocator, "{s}/fetchml/rsync", .{cache_base}); + defer 
self.allocator.free(cache_dir); + + std.fs.cwd().makePath(cache_dir) catch |err| switch (err) { + error.PathAlreadyExists => {}, + else => return err, + }; + + const rsync_path = try std.fmt.allocPrint(self.allocator, "{s}/ml_rsync_{s}", .{ cache_dir, digest_hex[0..16] }); + errdefer self.allocator.free(rsync_path); + + // If file exists and matches the embedded digest, reuse it. + if (std.fs.openFileAbsolute(rsync_path, .{})) |file| { + defer file.close(); + + const stat = try file.stat(); + if (stat.mode & 0o111 != 0) { + const data = try file.readToEndAlloc(self.allocator, 1024 * 1024 * 4); + defer self.allocator.free(data); + + var on_disk: [32]u8 = undefined; + std.crypto.hash.sha2.Sha256.hash(data, &on_disk, .{}); + if (std.mem.eql(u8, on_disk[0..], digest[0..])) { + return rsync_path; + } + } + } else |_| { + // Not present; will extract. + } + + try self.extractAndSetExecutable(rsync_path); + return rsync_path; } /// Extract rsync binary from embedded data and set executable permissions @@ -39,14 +142,15 @@ pub const EmbeddedRsync = struct { defer self.allocator.free(debug_msg); std.log.debug("{s}", .{debug_msg}); - // Write embedded binary to file system - try std.fs.cwd().writeFile(.{ .sub_path = path, .data = binary_data }); + // Write embedded binary to file system (absolute path) + var file = try std.fs.createFileAbsolute(path, .{ .truncate = true, .read = false, .mode = 0o755 }); + defer file.close(); + try file.writeAll(binary_data); - // Set executable permissions using OS API + // Ensure executable permissions const mode = 0o755; // rwxr-xr-x - std.posix.fchmodat(std.fs.cwd().fd, path, mode, 0) catch |err| { + std.posix.fchmod(file.handle, mode) catch |err| { std.log.warn("Failed to set executable permissions on {s}: {}", .{ path, err }); - // Continue anyway - the script might still work }; } @@ -79,7 +183,13 @@ pub const EmbeddedRsync = struct { child.stdout_behavior = .Inherit; child.stderr_behavior = .Inherit; - const term = try 
child.spawnAndWait(); + const term = child.spawnAndWait() catch |err| { + std.log.err( + "Failed to execute rsync at '{s}': {}. If your environment blocks executing extracted binaries (e.g. noexec), set ML_RSYNC_PATH to a system rsync.", + .{ rsync_path, err }, + ); + return err; + }; switch (term) { .Exited => |code| { diff --git a/cli/src/utils/rsync_embedded_binary.zig b/cli/src/utils/rsync_embedded_binary.zig index 3c467c1..6c7ef51 100644 --- a/cli/src/utils/rsync_embedded_binary.zig +++ b/cli/src/utils/rsync_embedded_binary.zig @@ -1,4 +1,119 @@ const std = @import("std"); +const build_options = @import("build_options"); + +fn isScript(data: []const u8) bool { + return data.len >= 2 and data[0] == '#' and data[1] == '!'; +} + +fn readU32BE(data: []const u8, offset: usize) ?u32 { + if (data.len < offset + 4) return null; + return (@as(u32, data[offset]) << 24) | + (@as(u32, data[offset + 1]) << 16) | + (@as(u32, data[offset + 2]) << 8) | + (@as(u32, data[offset + 3])); +} + +fn readU32LE(data: []const u8, offset: usize) ?u32 { + if (data.len < offset + 4) return null; + return (@as(u32, data[offset])) | + (@as(u32, data[offset + 1]) << 8) | + (@as(u32, data[offset + 2]) << 16) | + (@as(u32, data[offset + 3]) << 24); +} + +fn readU16LE(data: []const u8, offset: usize) ?u16 { + if (data.len < offset + 2) return null; + return (@as(u16, data[offset])) | (@as(u16, data[offset + 1]) << 8); +} + +fn readU16BE(data: []const u8, offset: usize) ?u16 { + if (data.len < offset + 2) return null; + return (@as(u16, data[offset]) << 8) | (@as(u16, data[offset + 1])); +} + +fn machCpuTypeForArch(arch: std.Target.Cpu.Arch) ?u32 { + return switch (arch) { + .x86_64 => 0x01000007, + .aarch64 => 0x0100000c, + else => null, + }; +} + +fn elfMachineForArch(arch: std.Target.Cpu.Arch) ?u16 { + return switch (arch) { + .x86_64 => 62, + .aarch64 => 183, + else => null, + }; +} + +fn machOHasCpuType(data: []const u8, want_cputype: u32) bool { + // Mach-O header: u32 magic; i32 cputype 
at offset 4. + // Accept both endiannesses (cputype will appear swapped if magic swapped). + const magic_le = readU32LE(data, 0) orelse return false; + const magic_be = readU32BE(data, 0) orelse return false; + + const MH_MAGIC_64: u32 = 0xfeedfacf; + const MH_CIGAM_64: u32 = 0xcffaedfe; + + if (magic_le == MH_MAGIC_64) { + const cputype = readU32LE(data, 4) orelse return false; + return cputype == want_cputype; + } + if (magic_be == MH_MAGIC_64 or magic_le == MH_CIGAM_64) { + const cputype = readU32BE(data, 4) orelse return false; + return cputype == want_cputype; + } + return false; +} + +fn fatMachOHasCpuType(data: []const u8, want_cputype: u32) bool { + // fat_header: magic (0xcafebabe big-endian), nfat_arch u32 + const FAT_MAGIC: u32 = 0xcafebabe; + const magic = readU32BE(data, 0) orelse return false; + if (magic != FAT_MAGIC) return false; + + const nfat = readU32BE(data, 4) orelse return false; + // fat_arch entries start at offset 8, 20 bytes each, big-endian + var off: usize = 8; + var i: u32 = 0; + while (i < nfat) : (i += 1) { + const cputype = readU32BE(data, off) orelse return false; + if (cputype == want_cputype) return true; + off += 20; + if (off > data.len) return false; + } + return false; +} + +fn elfHasMachine(data: []const u8, want_machine: u16) bool { + // Minimal ELF check: magic + e_machine at offset 18 + if (data.len < 20) return false; + if (!(data[0] == 0x7f and data[1] == 'E' and data[2] == 'L' and data[3] == 'F')) return false; + const ei_data = data[5]; // 1=little, 2=big + return switch (ei_data) { + 1 => (readU16LE(data, 18) orelse return false) == want_machine, + 2 => (readU16BE(data, 18) orelse return false) == want_machine, + else => false, + }; +} + +fn isNativeForTarget(data: []const u8) bool { + if (isScript(data)) return true; + + const target = @import("builtin").target; + const arch = target.cpu.arch; + + if (machCpuTypeForArch(arch)) |want_cputype| { + if (fatMachOHasCpuType(data, want_cputype)) return true; + if 
(machOHasCpuType(data, want_cputype)) return true; + } + if (elfMachineForArch(arch)) |want_machine| { + if (elfHasMachine(data, want_machine)) return true; + } + + return false; +} /// Embedded rsync binary data /// For dev builds: uses placeholder wrapper that calls system rsync @@ -8,16 +123,32 @@ const std = @import("std"); /// 1. Download or build a static rsync binary for your target platform /// 2. Place it at cli/src/assets/rsync_release.bin /// 3. Build with: zig build prod (or release/cross targets) +const placeholder_data = @embedFile("../assets/rsync_placeholder.bin"); -// Check if rsync_release.bin exists, otherwise fall back to placeholder -const rsync_binary_data = if (@import("builtin").mode == .Debug or @import("builtin").mode == .ReleaseSafe) - @embedFile("../assets/rsync_placeholder.bin") +const release_data = if (build_options.has_rsync_release) + @embedFile(build_options.rsync_release_path) else - // For ReleaseSmall and ReleaseFast, try to use the release binary - @embedFile("../assets/rsync_release.bin"); + placeholder_data; + +// Prefer placeholder in Debug/ReleaseSafe. In Release*, only use rsync_release.bin if it +// appears compatible with the target architecture; otherwise fall back to placeholder. 
+const use_release = (@import("builtin").mode != .Debug and @import("builtin").mode != .ReleaseSafe) and + build_options.has_rsync_release and + !isScript(release_data) and + isNativeForTarget(release_data); + +const rsync_binary_data = if (use_release) release_data else placeholder_data; pub const RSYNC_BINARY: []const u8 = rsync_binary_data; +pub const USING_RELEASE_BINARY: bool = use_release; + +pub fn rsyncBinarySha256() [32]u8 { + var digest: [32]u8 = undefined; + std.crypto.hash.sha2.Sha256.hash(RSYNC_BINARY, &digest, .{}); + return digest; +} + /// Get embedded rsync binary data pub fn getRsyncBinary() []const u8 { return RSYNC_BINARY; diff --git a/cli/tests/jupyter_test.zig b/cli/tests/jupyter_test.zig index 2dc3b45..4cb38c1 100644 --- a/cli/tests/jupyter_test.zig +++ b/cli/tests/jupyter_test.zig @@ -1,16 +1,17 @@ const std = @import("std"); const testing = std.testing; const src = @import("src"); +const commands = src.commands; test "jupyter top-level action includes create" { - try testing.expect(src.commands.jupyter.isValidTopLevelAction("create")); - try testing.expect(src.commands.jupyter.isValidTopLevelAction("start")); - try testing.expect(!src.commands.jupyter.isValidTopLevelAction("bogus")); + try testing.expect(commands.jupyter.isValidTopLevelAction("create")); + try testing.expect(commands.jupyter.isValidTopLevelAction("start")); + try testing.expect(!commands.jupyter.isValidTopLevelAction("bogus")); } test "jupyter defaultWorkspacePath prefixes ./" { const allocator = testing.allocator; - const p = try src.commands.jupyter.defaultWorkspacePath(allocator, "my-workspace"); + const p = try commands.jupyter.defaultWorkspacePath(allocator, "my-workspace"); defer allocator.free(p); try testing.expectEqualStrings("./my-workspace", p); diff --git a/cli/tests/main_test.zig b/cli/tests/main_test.zig index 850a38f..e17756a 100644 --- a/cli/tests/main_test.zig +++ b/cli/tests/main_test.zig @@ -6,6 +6,7 @@ test "CLI basic functionality" { // Test that CLI 
module can be imported const allocator = testing.allocator; _ = allocator; + _ = src; // Test basic string operations used in CLI const test_str = "ml sync"; diff --git a/cli/tests/queue_test.zig b/cli/tests/queue_test.zig index 05947c6..61f13c2 100644 --- a/cli/tests/queue_test.zig +++ b/cli/tests/queue_test.zig @@ -1,6 +1,5 @@ const std = @import("std"); const testing = std.testing; -const src = @import("src"); test "queue command argument parsing" { // Test various queue command argument combinations diff --git a/cli/tests/response_packets_test.zig b/cli/tests/response_packets_test.zig index e520764..27bce28 100644 --- a/cli/tests/response_packets_test.zig +++ b/cli/tests/response_packets_test.zig @@ -2,7 +2,6 @@ const std = @import("std"); const testing = std.testing; const src = @import("src"); - const protocol = src.net.protocol; fn roundTrip(allocator: std.mem.Allocator, packet: protocol.ResponsePacket) !protocol.ResponsePacket { diff --git a/cli/tests/rsync_embedded_test.zig b/cli/tests/rsync_embedded_test.zig index 28f2d43..a0e1074 100644 --- a/cli/tests/rsync_embedded_test.zig +++ b/cli/tests/rsync_embedded_test.zig @@ -1,30 +1,121 @@ const std = @import("std"); const testing = std.testing; -// Simple mock rsync for testing -const MockRsyncEmbedded = struct { - const EmbeddedRsync = struct { - allocator: std.mem.Allocator, +const c = @cImport({ + @cInclude("stdlib.h"); +}); - fn extractRsyncBinary(self: EmbeddedRsync) ![]const u8 { - // Simple mock - return a dummy path - return try std.fmt.allocPrint(self.allocator, "/tmp/mock_rsync", .{}); - } - }; -}; +const src = @import("src"); +const utils = src.utils; -const rsync_embedded = MockRsyncEmbedded; +fn isScript(data: []const u8) bool { + return data.len >= 2 and data[0] == '#' and data[1] == '!'; +} test "embedded rsync binary creation" { const allocator = testing.allocator; - var embedded_rsync = rsync_embedded.EmbeddedRsync{ .allocator = allocator }; + // Ensure the override doesn't influence this 
test. + _ = c.unsetenv("ML_RSYNC_PATH"); + + // Force embedded fallback by clearing PATH for this process so system rsync isn't found. + const old_path_z = std.posix.getenv("PATH"); + if (c.setenv("PATH", "", 1) != 0) return error.Unexpected; + defer { + if (old_path_z) |z| { + _ = c.setenv("PATH", z, 1); + } else { + _ = c.unsetenv("PATH"); + } + } + + var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator }; - // Test binary extraction const rsync_path = try embedded_rsync.extractRsyncBinary(); defer allocator.free(rsync_path); - // Verify the path was created try testing.expect(rsync_path.len > 0); - try testing.expect(std.mem.startsWith(u8, rsync_path, "/tmp/")); + + // File exists and is executable + const file = try std.fs.openFileAbsolute(rsync_path, .{}); + defer file.close(); + const stat = try file.stat(); + try testing.expect(stat.mode & 0o111 != 0); + + // Contents match embedded payload + const data = try file.readToEndAlloc(allocator, 1024 * 1024 * 4); + defer allocator.free(data); + + var digest: [32]u8 = undefined; + std.crypto.hash.sha2.Sha256.hash(data, &digest, .{}); + const embedded_digest = utils.rsync_embedded_binary.rsyncBinarySha256(); + try testing.expect(std.mem.eql(u8, &digest, &embedded_digest)); + + // Sanity: if we're not using the release binary, it should be the placeholder script. + if (!utils.rsync_embedded_binary.USING_RELEASE_BINARY) { + try testing.expect(isScript(data)); + } +} + +test "embedded rsync honors ML_RSYNC_PATH override" { + const allocator = testing.allocator; + + // Use a known executable. We don't execute it; we only verify the override is returned. 
+ const override_path = "/bin/sh"; + try testing.expect(std.fs.openFileAbsolute(override_path, .{}) != error.FileNotFound); + + if (c.setenv("ML_RSYNC_PATH", override_path, 1) != 0) return error.Unexpected; + defer _ = c.unsetenv("ML_RSYNC_PATH"); + + var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator }; + const rsync_path = try embedded_rsync.extractRsyncBinary(); + defer allocator.free(rsync_path); + + try testing.expectEqualStrings(override_path, rsync_path); +} + +test "embedded rsync prefers system rsync when available" { + const allocator = testing.allocator; + + _ = c.unsetenv("ML_RSYNC_PATH"); + + // Find rsync on PATH (simple scan). If not present, skip. + const path_z = std.posix.getenv("PATH") orelse return; + const path = std.mem.sliceTo(path_z, 0); + + var sys_rsync: ?[]u8 = null; + defer if (sys_rsync) |p| allocator.free(p); + + var it = std.mem.splitScalar(u8, path, ':'); + while (it.next()) |dir| { + if (dir.len == 0) continue; + const candidate = std.fmt.allocPrint(allocator, "{s}/{s}", .{ dir, "rsync" }) catch return; + errdefer allocator.free(candidate); + + const file = std.fs.openFileAbsolute(candidate, .{}) catch { + allocator.free(candidate); + continue; + }; + defer file.close(); + + const st = file.stat() catch { + allocator.free(candidate); + continue; + }; + if (st.mode & 0o111 == 0) { + allocator.free(candidate); + continue; + } + + sys_rsync = candidate; + break; + } + + if (sys_rsync == null) return; + + var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator }; + const rsync_path = try embedded_rsync.extractRsyncBinary(); + defer allocator.free(rsync_path); + + try testing.expectEqualStrings(sys_rsync.?, rsync_path); } diff --git a/cli/tests/status_prewarm_test.zig b/cli/tests/status_prewarm_test.zig index dda79a0..f4b4137 100644 --- a/cli/tests/status_prewarm_test.zig +++ b/cli/tests/status_prewarm_test.zig @@ -3,7 +3,7 @@ const testing = std.testing; const src = @import("src"); 
-const ws = src.net.ws; +const ws = src.net.ws_client; test "status prewarm formatting - single entry" { var gpa = std.heap.GeneralPurposeAllocator(.{}){};