feat(cli): enhance Zig CLI with new commands and improved networking
- Add new commands: annotate, narrative, requeue - Refactor WebSocket client into modular components (net/ws/) - Add rsync embedded binary support - Improve error handling and response packet processing - Update build.zig and completions
This commit is contained in:
parent
df5d872021
commit
8e3fa94322
48 changed files with 4182 additions and 1958 deletions
27
cli/Makefile
27
cli/Makefile
|
|
@ -4,21 +4,36 @@ ZIG ?= zig
|
|||
BUILD_DIR ?= zig-out/bin
|
||||
BINARY := $(BUILD_DIR)/ml
|
||||
|
||||
.PHONY: all prod dev install clean help
|
||||
.PHONY: all prod dev test build-rsync install clean help
|
||||
|
||||
RSYNC_VERSION ?= 3.3.0
|
||||
RSYNC_SRC_BASE ?= https://download.samba.org/pub/rsync/src
|
||||
RSYNC_TARBALL ?= rsync-$(RSYNC_VERSION).tar.gz
|
||||
RSYNC_TARBALL_SHA256 ?=
|
||||
|
||||
all: $(BINARY)
|
||||
|
||||
$(BUILD_DIR):
|
||||
mkdir -p $(BUILD_DIR)
|
||||
|
||||
$(BINARY): src/main.zig | $(BUILD_DIR)
|
||||
$(ZIG) build-exe -OReleaseSmall -fstrip -femit-bin=$(BINARY) src/main.zig
|
||||
$(BINARY): | $(BUILD_DIR)
|
||||
$(ZIG) build --release=small
|
||||
|
||||
prod: src/main.zig | $(BUILD_DIR)
|
||||
$(ZIG) build-exe -OReleaseSmall -fstrip -femit-bin=$(BUILD_DIR)/ml src/main.zig
|
||||
$(ZIG) build --release=small
|
||||
|
||||
dev: src/main.zig | $(BUILD_DIR)
|
||||
$(ZIG) build-exe -OReleaseFast -femit-bin=$(BUILD_DIR)/ml src/main.zig
|
||||
$(ZIG) build --release=fast
|
||||
|
||||
test:
|
||||
$(ZIG) build test
|
||||
|
||||
build-rsync:
|
||||
@RSYNC_VERSION="$(RSYNC_VERSION)" \
|
||||
RSYNC_SRC_BASE="$(RSYNC_SRC_BASE)" \
|
||||
RSYNC_TARBALL="$(RSYNC_TARBALL)" \
|
||||
RSYNC_TARBALL_SHA256="$(RSYNC_TARBALL_SHA256)" \
|
||||
bash "$(CURDIR)/scripts/build_rsync.sh"
|
||||
|
||||
install: $(BINARY)
|
||||
install -d $(DESTDIR)/usr/local/bin
|
||||
|
|
@ -32,5 +47,7 @@ help:
|
|||
@echo " all - build release-small binary (default)"
|
||||
@echo " prod - build production binary with ReleaseSmall"
|
||||
@echo " dev - build development binary with ReleaseFast"
|
||||
@echo " test - run Zig unit tests"
|
||||
@echo " build-rsync - build pinned rsync from official source into src/assets (RSYNC_VERSION=... override)"
|
||||
@echo " install - copy binary into /usr/local/bin"
|
||||
@echo " clean - remove build artifacts"
|
||||
|
|
@ -19,10 +19,12 @@ zig build
|
|||
|
||||
- `ml init` - Setup configuration
|
||||
- `ml sync <path>` - Sync project to server
|
||||
- `ml queue <job1> [job2 ...] [--commit <id>] [--priority N]` - Queue one or more jobs
|
||||
- `ml queue <job1> [job2 ...] [--commit <id>] [--priority N] [--note <text>]` - Queue one or more jobs
|
||||
- `ml status` - Check system/queue status for your API key
|
||||
- `ml validate <commit_id> [--json] [--task <task_id>]` - Validate provenance + integrity for a commit or task (includes `run_manifest.json` consistency checks when validating by task)
|
||||
- `ml info <path|id> [--json] [--base <path>]` - Show run info from `run_manifest.json` (by path or by scanning `finished/failed/running/pending`)
|
||||
- `ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]` - Append a human annotation to `run_manifest.json`
|
||||
- `ml narrative set <path|run_id|task_id> [--hypothesis <text>] [--context <text>] [--intent <text>] [--expected-outcome <text>] [--parent-run <id>] [--experiment-group <text>] [--tags <csv>] [--base <path>] [--json]` - Patch the `narrative` field in `run_manifest.json`
|
||||
- `ml monitor` - Launch monitoring interface (TUI)
|
||||
- `ml cancel <job>` - Cancel a running/queued job you own
|
||||
- `ml prune --keep N` - Keep N recent experiments
|
||||
|
|
@ -31,6 +33,7 @@ zig build
|
|||
|
||||
Notes:
|
||||
|
||||
- `--json` mode is designed to be pipe-friendly: machine-readable JSON is emitted to stdout, while user-facing messages/errors go to stderr.
|
||||
- When running `ml validate --task <task_id>`, the server will try to locate the job's `run_manifest.json` under the configured base path (pending/running/finished/failed) and cross-check key fields (task id, commit id, deps, snapshot).
|
||||
- For tasks in `running`, `completed`, or `failed` state, a missing `run_manifest.json` is treated as a validation failure. For `queued` tasks, it is treated as a warning (the job may not have started yet).
|
||||
|
||||
|
|
@ -43,6 +46,9 @@ Notes:
|
|||
Queues a job named `my-job`. If `--commit` is omitted, the CLI generates a random commit ID
|
||||
and records `(job_name, commit_id)` in `~/.ml/history.log` so you don't have to remember hashes.
|
||||
|
||||
- `ml queue my-job --note "baseline run; lr=1e-3"`
|
||||
Adds a human-readable note to the run; it will be persisted into the run's `run_manifest.json` (under `metadata.note`).
|
||||
|
||||
- `ml experiment list`
|
||||
Shows recent experiments from history with alias (job name) and commit ID.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,63 @@ pub fn build(b: *std.Build) void {
|
|||
// Standard target options
|
||||
const target = b.standardTargetOptions(.{});
|
||||
|
||||
const test_filter = b.option([]const u8, "test-filter", "Filter unit tests by name");
|
||||
_ = test_filter;
|
||||
|
||||
// Optimized release mode for size
|
||||
const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseSmall });
|
||||
|
||||
const options = b.addOptions();
|
||||
|
||||
const arch = target.result.cpu.arch;
|
||||
const os_tag = target.result.os.tag;
|
||||
|
||||
const arch_str: []const u8 = switch (arch) {
|
||||
.x86_64 => "x86_64",
|
||||
.aarch64 => "arm64",
|
||||
else => "unknown",
|
||||
};
|
||||
const os_str: []const u8 = switch (os_tag) {
|
||||
.linux => "linux",
|
||||
.macos => "darwin",
|
||||
.windows => "windows",
|
||||
else => "unknown",
|
||||
};
|
||||
|
||||
const candidate_specific = b.fmt("src/assets/rsync_release_{s}_{s}.bin", .{ os_str, arch_str });
|
||||
const candidate_default = "src/assets/rsync_release.bin";
|
||||
|
||||
var selected_candidate: []const u8 = "";
|
||||
var has_rsync_release = false;
|
||||
|
||||
// Prefer a platform-specific asset if available.
|
||||
if (std.fs.cwd().openFile(candidate_specific, .{}) catch null) |f| {
|
||||
f.close();
|
||||
selected_candidate = candidate_specific;
|
||||
has_rsync_release = true;
|
||||
} else if (std.fs.cwd().openFile(candidate_default, .{}) catch null) |f| {
|
||||
f.close();
|
||||
selected_candidate = candidate_default;
|
||||
has_rsync_release = true;
|
||||
}
|
||||
|
||||
if ((optimize == .ReleaseSmall or optimize == .ReleaseFast) and !has_rsync_release) {
|
||||
std.debug.panic(
|
||||
"Release build requires an embedded rsync binary asset. Provide one of: '{s}' or '{s}'",
|
||||
.{ candidate_specific, candidate_default },
|
||||
);
|
||||
}
|
||||
|
||||
// rsync_embedded_binary.zig calls @embedFile() from cli/src/utils, so the embed path
|
||||
// must be relative to that directory.
|
||||
const selected_embed_path = if (has_rsync_release)
|
||||
b.fmt("../assets/{s}", .{std.fs.path.basename(selected_candidate)})
|
||||
else
|
||||
"";
|
||||
|
||||
options.addOption(bool, "has_rsync_release", has_rsync_release);
|
||||
options.addOption([]const u8, "rsync_release_path", selected_embed_path);
|
||||
|
||||
// CLI executable
|
||||
const exe = b.addExecutable(.{
|
||||
.name = "ml",
|
||||
|
|
@ -18,6 +72,10 @@ pub fn build(b: *std.Build) void {
|
|||
}),
|
||||
});
|
||||
|
||||
exe.root_module.strip = true;
|
||||
|
||||
exe.root_module.addOptions("build_options", options);
|
||||
|
||||
// Install the executable to zig-out/bin
|
||||
b.installArtifact(exe);
|
||||
|
||||
|
|
@ -44,6 +102,7 @@ pub fn build(b: *std.Build) void {
|
|||
.optimize = .Debug,
|
||||
}),
|
||||
});
|
||||
main_tests.root_module.addOptions("build_options", options);
|
||||
const run_main_tests = b.addRunArtifact(main_tests);
|
||||
test_step.dependOn(&run_main_tests.step);
|
||||
|
||||
|
|
@ -75,6 +134,8 @@ pub fn build(b: *std.Build) void {
|
|||
.optimize = .Debug,
|
||||
});
|
||||
|
||||
test_module.addOptions("build_options", options);
|
||||
|
||||
// Make src module available to tests as "src"
|
||||
test_module.addImport("src", src_module);
|
||||
|
||||
|
|
|
|||
83
cli/scripts/build_rsync.sh
Normal file
83
cli/scripts/build_rsync.sh
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
RSYNC_VERSION="${RSYNC_VERSION:-3.3.0}"
|
||||
RSYNC_SRC_BASE="${RSYNC_SRC_BASE:-https://download.samba.org/pub/rsync/src}"
|
||||
RSYNC_TARBALL="${RSYNC_TARBALL:-rsync-${RSYNC_VERSION}.tar.gz}"
|
||||
RSYNC_TARBALL_SHA256="${RSYNC_TARBALL_SHA256:-}"
|
||||
|
||||
os="$(uname -s | tr '[:upper:]' '[:lower:]')"
|
||||
arch="$(uname -m)"
|
||||
if [[ "${arch}" == "aarch64" || "${arch}" == "arm64" ]]; then arch="arm64"; fi
|
||||
if [[ "${arch}" == "x86_64" ]]; then arch="x86_64"; fi
|
||||
|
||||
if [[ "${os}" != "linux" ]]; then
|
||||
echo "build-rsync: supported on linux only (for reproducible official builds). Use system rsync on ${os} or build on a native runner." >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
out="${repo_root}/src/assets/rsync_release_linux_${arch}.bin"
|
||||
|
||||
tmp="$(mktemp -d)"
|
||||
cleanup() { rm -rf "${tmp}"; }
|
||||
trap cleanup EXIT
|
||||
|
||||
url="${RSYNC_SRC_BASE}/${RSYNC_TARBALL}"
|
||||
sig_url_asc="${url}.asc"
|
||||
sig_url_sig="${url}.sig"
|
||||
|
||||
echo "fetching ${url}"
|
||||
curl -fsSL "${url}" -o "${tmp}/rsync.tar.gz"
|
||||
|
||||
verified=0
|
||||
if command -v gpg >/dev/null 2>&1; then
|
||||
sig_file=""
|
||||
sig_url=""
|
||||
if curl -fsSL "${sig_url_asc}" -o "${tmp}/rsync.tar.gz.asc"; then
|
||||
sig_file="${tmp}/rsync.tar.gz.asc"
|
||||
sig_url="${sig_url_asc}"
|
||||
elif curl -fsSL "${sig_url_sig}" -o "${tmp}/rsync.tar.gz.sig"; then
|
||||
sig_file="${tmp}/rsync.tar.gz.sig"
|
||||
sig_url="${sig_url_sig}"
|
||||
fi
|
||||
|
||||
if [[ -n "${sig_file}" ]]; then
|
||||
echo "verifying signature ${sig_url}"
|
||||
if gpg --batch --verify "${sig_file}" "${tmp}/rsync.tar.gz"; then
|
||||
verified=1
|
||||
else
|
||||
echo "build-rsync: gpg signature check failed (often because the public key is not in your keyring)." >&2
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "${verified}" -ne 1 ]]; then
|
||||
if [[ -n "${RSYNC_TARBALL_SHA256}" ]]; then
|
||||
echo "verifying sha256 for ${url}"
|
||||
if command -v sha256sum >/dev/null 2>&1; then
|
||||
echo "${RSYNC_TARBALL_SHA256} ${tmp}/rsync.tar.gz" | sha256sum -c -
|
||||
elif command -v shasum >/dev/null 2>&1; then
|
||||
echo "${RSYNC_TARBALL_SHA256} ${tmp}/rsync.tar.gz" | shasum -a 256 -c -
|
||||
else
|
||||
echo "build-rsync: need sha256sum or shasum for checksum verification" >&2
|
||||
exit 2
|
||||
fi
|
||||
else
|
||||
echo "build-rsync: could not verify ${url} (no usable gpg signature, and RSYNC_TARBALL_SHA256 is empty)." >&2
|
||||
echo "Set RSYNC_TARBALL_SHA256=<expected sha256> or install gpg with a trusted key for the rsync signing identity." >&2
|
||||
exit 2
|
||||
fi
|
||||
fi
|
||||
|
||||
tar -C "${tmp}" -xzf "${tmp}/rsync.tar.gz"
|
||||
extract_dir="$(tar -tzf "${tmp}/rsync.tar.gz" | head -n 1 | cut -d/ -f1)"
|
||||
cd "${tmp}/${extract_dir}"
|
||||
|
||||
CC=musl-gcc CFLAGS="-O2" LDFLAGS="-static" ./configure --disable-xxhash --disable-zstd --disable-lz4
|
||||
make -j"$(getconf _NPROCESSORS_ONLN 2>/dev/null || echo 2)"
|
||||
|
||||
mkdir -p "$(dirname "${out}")"
|
||||
cp rsync "${out}"
|
||||
chmod +x "${out}"
|
||||
echo "built ${out}"
|
||||
|
|
@ -15,7 +15,7 @@ _ml_completions()
|
|||
global_opts="--help --verbose --quiet --monitor"
|
||||
|
||||
# Top-level subcommands
|
||||
cmds="init sync queue status monitor cancel prune watch dataset experiment"
|
||||
cmds="init sync queue requeue status monitor cancel prune watch dataset experiment"
|
||||
|
||||
# If completing the subcommand itself
|
||||
if [[ ${COMP_CWORD} -eq 1 ]]; then
|
||||
|
|
@ -41,8 +41,51 @@ _ml_completions()
|
|||
COMPREPLY=( $(compgen -d -- "${cur}") )
|
||||
;;
|
||||
queue)
|
||||
# Suggest common job names (static for now)
|
||||
COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") )
|
||||
queue_opts="--commit --priority --cpu --memory --gpu --gpu-memory --snapshot-id --snapshot-sha256 --args -- ${global_opts}"
|
||||
case "${prev}" in
|
||||
--priority)
|
||||
COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") )
|
||||
;;
|
||||
--cpu|--memory|--gpu)
|
||||
COMPREPLY=( $(compgen -W "0 1 2 4 8 16 32" -- "${cur}") )
|
||||
;;
|
||||
--gpu-memory)
|
||||
COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") )
|
||||
;;
|
||||
--commit|--snapshot-id|--snapshot-sha256|--args)
|
||||
# Free-form; no special completion
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${queue_opts}" -- "${cur}") )
|
||||
else
|
||||
# Suggest common job names (static for now)
|
||||
COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
requeue)
|
||||
requeue_opts="--name --priority --cpu --memory --gpu --gpu-memory --args -- ${global_opts}"
|
||||
case "${prev}" in
|
||||
--priority)
|
||||
COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") )
|
||||
;;
|
||||
--cpu|--memory|--gpu)
|
||||
COMPREPLY=( $(compgen -W "0 1 2 4 8 16 32" -- "${cur}") )
|
||||
;;
|
||||
--gpu-memory)
|
||||
COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") )
|
||||
;;
|
||||
--name|--args)
|
||||
# Free-form; no special completion
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${requeue_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
status)
|
||||
COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") )
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ _ml() {
|
|||
'init:Setup configuration interactively'
|
||||
'sync:Sync project to server'
|
||||
'queue:Queue job for execution'
|
||||
'requeue:Re-submit a previous run/commit'
|
||||
'status:Get system status'
|
||||
'monitor:Launch TUI via SSH'
|
||||
'cancel:Cancel running job'
|
||||
|
|
@ -53,7 +54,33 @@ _ml() {
|
|||
'--verbose[Enable verbose output]' \
|
||||
'--quiet[Suppress non-error output]' \
|
||||
'--monitor[Monitor progress]' \
|
||||
'1:job name:'
|
||||
'--commit[Commit id (40-hex) or unique prefix (>=7)]:commit id:' \
|
||||
'--priority[Priority (0-255)]:priority:' \
|
||||
'--cpu[CPU cores]:cpu:' \
|
||||
'--memory[Memory (GB)]:memory:' \
|
||||
'--gpu[GPU count]:gpu:' \
|
||||
'--gpu-memory[GPU memory]:gpu memory:' \
|
||||
'--snapshot-id[Snapshot id]:snapshot id:' \
|
||||
'--snapshot-sha256[Snapshot sha256]:snapshot sha256:' \
|
||||
'--args[Runner args string]:args:' \
|
||||
'1:job name:' \
|
||||
'*:args separator:(--)'
|
||||
;;
|
||||
requeue)
|
||||
_arguments -C \
|
||||
'--help[Show requeue help]' \
|
||||
'--verbose[Enable verbose output]' \
|
||||
'--quiet[Suppress non-error output]' \
|
||||
'--monitor[Monitor progress]' \
|
||||
'--name[Override job name]:job name:' \
|
||||
'--priority[Priority (0-255)]:priority:' \
|
||||
'--cpu[CPU cores]:cpu:' \
|
||||
'--memory[Memory (GB)]:memory:' \
|
||||
'--gpu[GPU count]:gpu:' \
|
||||
'--gpu-memory[GPU memory]:gpu memory:' \
|
||||
'--args[Runner args string]:args:' \
|
||||
'1:commit_id|run_id|task_id|path:' \
|
||||
'*:args separator:(--)'
|
||||
;;
|
||||
status)
|
||||
_arguments -C \
|
||||
|
|
|
|||
|
|
@ -3,4 +3,5 @@ pub const commands = @import("src/commands.zig");
|
|||
pub const net = @import("src/net.zig");
|
||||
pub const utils = @import("src/utils.zig");
|
||||
pub const config = @import("src/config.zig");
|
||||
pub const Config = @import("src/config.zig").Config;
|
||||
pub const errors = @import("src/errors.zig");
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
This directory contains rsync binaries for the ML CLI:
|
||||
|
||||
- `rsync_placeholder.bin` - Wrapper script for dev builds (calls system rsync)
|
||||
- `rsync_release.bin` - Full static rsync binary for release builds (not in repo)
|
||||
- `rsync_release_<os>_<arch>.bin` - Static rsync binary for release builds (not in repo)
|
||||
|
||||
## Build Modes
|
||||
|
||||
|
|
@ -16,44 +16,35 @@ This directory contains rsync binaries for the ML CLI:
|
|||
- Requires rsync installed on the system
|
||||
|
||||
### Release Builds (ReleaseSmall, ReleaseFast)
|
||||
- Uses `rsync_release.bin` (300-500KB static binary)
|
||||
- Uses `rsync_release_<os>_<arch>.bin` (static binary)
|
||||
- Fully self-contained, no dependencies
|
||||
- Results in ~450-650KB CLI binary
|
||||
- Works on any system without rsync installed
|
||||
|
||||
## Preparing Release Binaries
|
||||
|
||||
### Option 1: Download Pre-built Static Rsync
|
||||
### Option 1: Build from Official Rsync Source (recommended)
|
||||
|
||||
For macOS ARM64:
|
||||
On Linux:
|
||||
```bash
|
||||
cd cli/src/assets
|
||||
curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-macos-arm64 -o rsync_release.bin
|
||||
chmod +x rsync_release.bin
|
||||
cd cli
|
||||
make build-rsync RSYNC_VERSION=3.3.0
|
||||
```
|
||||
|
||||
For Linux x86_64:
|
||||
```bash
|
||||
cd cli/src/assets
|
||||
curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-linux-x86_64 -o rsync_release.bin
|
||||
chmod +x rsync_release.bin
|
||||
```
|
||||
|
||||
### Option 2: Build Static Rsync Yourself
|
||||
### Option 2: Build Rsync Yourself
|
||||
|
||||
```bash
|
||||
# Clone rsync
|
||||
git clone https://github.com/WayneD/rsync.git
|
||||
cd rsync
|
||||
# Download official source
|
||||
curl -fsSL https://download.samba.org/pub/rsync/src/rsync-3.3.0.tar.gz -o rsync.tar.gz
|
||||
tar -xzf rsync.tar.gz
|
||||
cd rsync-3.3.0
|
||||
|
||||
# Configure for static build
|
||||
./configure CFLAGS="-static" LDFLAGS="-static" --disable-xxhash --disable-zstd
|
||||
|
||||
# Build
|
||||
# Configure & build
|
||||
./configure --disable-xxhash --disable-zstd --disable-lz4
|
||||
make
|
||||
|
||||
# Copy to assets
|
||||
cp rsync ../fetch_ml/cli/src/assets/rsync_release.bin
|
||||
# Copy to assets (example)
|
||||
cp rsync ../fetch_ml/cli/src/assets/rsync_release_linux_x86_64.bin
|
||||
```
|
||||
|
||||
### Option 3: Use System Rsync (Temporary)
|
||||
|
|
@ -61,21 +52,22 @@ cp rsync ../fetch_ml/cli/src/assets/rsync_release.bin
|
|||
For testing release builds without a static binary:
|
||||
```bash
|
||||
cd cli/src/assets
|
||||
cp rsync_placeholder.bin rsync_release.bin
|
||||
cp rsync_placeholder.bin rsync_release_linux_x86_64.bin
|
||||
```
|
||||
|
||||
This will still use the wrapper, but allows builds to complete.
|
||||
|
||||
## Verification
|
||||
|
||||
After placing rsync_release.bin:
|
||||
After placing the appropriate `rsync_release_<os>_<arch>.bin`:
|
||||
|
||||
```bash
|
||||
# Verify it's executable
|
||||
file cli/src/assets/rsync_release.bin
|
||||
|
||||
# Test it
|
||||
./cli/src/assets/rsync_release.bin --version
|
||||
# Verify it's executable (example)
|
||||
file cli/src/assets/rsync_release_linux_x86_64.bin
|
||||
|
||||
# Test it (example)
|
||||
./cli/src/assets/rsync_release_linux_x86_64.bin --version
|
||||
|
||||
# Build release
|
||||
cd cli
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
// Commands module - exports all command modules
|
||||
pub const queue = @import("commands/queue.zig");
|
||||
pub const sync = @import("commands/sync.zig");
|
||||
pub const status = @import("commands/status.zig");
|
||||
pub const dataset = @import("commands/dataset.zig");
|
||||
pub const jupyter = @import("commands/jupyter.zig");
|
||||
pub const init = @import("commands/init.zig");
|
||||
pub const info = @import("commands/info.zig");
|
||||
pub const monitor = @import("commands/monitor.zig");
|
||||
pub const annotate = @import("commands/annotate.zig");
|
||||
pub const cancel = @import("commands/cancel.zig");
|
||||
pub const prune = @import("commands/prune.zig");
|
||||
pub const watch = @import("commands/watch.zig");
|
||||
pub const dataset = @import("commands/dataset.zig");
|
||||
pub const experiment = @import("commands/experiment.zig");
|
||||
pub const info = @import("commands/info.zig");
|
||||
pub const init = @import("commands/init.zig");
|
||||
pub const jupyter = @import("commands/jupyter.zig");
|
||||
pub const monitor = @import("commands/monitor.zig");
|
||||
pub const narrative = @import("commands/narrative.zig");
|
||||
pub const prune = @import("commands/prune.zig");
|
||||
pub const queue = @import("commands/queue.zig");
|
||||
pub const requeue = @import("commands/requeue.zig");
|
||||
pub const status = @import("commands/status.zig");
|
||||
pub const sync = @import("commands/sync.zig");
|
||||
pub const validate = @import("commands/validate.zig");
|
||||
pub const watch = @import("commands/watch.zig");
|
||||
|
|
|
|||
282
cli/src/commands/annotate.zig
Normal file
282
cli/src/commands/annotate.zig
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
const target = args[0];
|
||||
|
||||
var author: []const u8 = "";
|
||||
var note: ?[]const u8 = null;
|
||||
var base_override: ?[]const u8 = null;
|
||||
var json_mode: bool = false;
|
||||
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const a = args[i];
|
||||
if (std.mem.eql(u8, a, "--author")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --author\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
author = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--note")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --note\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
note = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--base")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --base\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
base_override = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else if (std.mem.startsWith(u8, a, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
} else {
|
||||
colors.printError("Unexpected argument: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
if (note == null or std.mem.trim(u8, note.?, " \t\r\n").len == 0) {
|
||||
colors.printError("--note is required\n", .{});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = base_override orelse cfg.worker_base;
|
||||
|
||||
const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const job_name = try readJobNameFromManifest(allocator, manifest_path);
|
||||
defer allocator.free(job_name);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendAnnotateRun(job_name, author, note.?, api_key_hash);
|
||||
|
||||
if (json_mode) {
|
||||
const msg = try client.receiveMessage(allocator);
|
||||
defer allocator.free(msg);
|
||||
|
||||
const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
return error.InvalidPacket;
|
||||
};
|
||||
defer {
|
||||
if (packet.success_message) |m| allocator.free(m);
|
||||
if (packet.error_message) |m| allocator.free(m);
|
||||
if (packet.error_details) |m| allocator.free(m);
|
||||
}
|
||||
|
||||
const Result = struct {
|
||||
ok: bool,
|
||||
job_name: []const u8,
|
||||
message: []const u8,
|
||||
error_code: ?u8 = null,
|
||||
error_message: ?[]const u8 = null,
|
||||
details: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
var out = io.stdoutWriter();
|
||||
if (packet.packet_type == .error_packet) {
|
||||
const res = Result{
|
||||
.ok = false,
|
||||
.job_name = job_name,
|
||||
.message = "",
|
||||
.error_code = @intFromEnum(packet.error_code.?),
|
||||
.error_message = packet.error_message orelse "",
|
||||
.details = packet.error_details orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return error.CommandFailed;
|
||||
}
|
||||
|
||||
const res = Result{
|
||||
.ok = true,
|
||||
.job_name = job_name,
|
||||
.message = packet.success_message orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return;
|
||||
}
|
||||
|
||||
try client.receiveAndHandleResponse(allocator, "Annotate");
|
||||
|
||||
colors.printSuccess("Annotation added\n", .{});
|
||||
colors.printInfo("Job: {s}\n", .{job_name});
|
||||
}
|
||||
|
||||
fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 {
|
||||
const data = try readFileAlloc(allocator, manifest_path);
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
if (parsed.value != .object) return error.InvalidManifest;
|
||||
const root = parsed.value.object;
|
||||
|
||||
const job_name = jsonGetString(root, "job_name") orelse "";
|
||||
if (std.mem.trim(u8, job_name, " \t\r\n").len == 0) {
|
||||
return error.InvalidManifest;
|
||||
}
|
||||
return allocator.dupe(u8, job_name);
|
||||
}
|
||||
|
||||
fn resolveManifestPathWithBase(
|
||||
allocator: std.mem.Allocator,
|
||||
input: []const u8,
|
||||
base: []const u8,
|
||||
) ![]u8 {
|
||||
var cwd = std.fs.cwd();
|
||||
|
||||
if (std.fs.path.isAbsolute(input)) {
|
||||
if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| {
|
||||
var mutable_dir = dir;
|
||||
defer mutable_dir.close();
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
|
||||
}
|
||||
if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
|
||||
var mutable_file = file;
|
||||
defer mutable_file.close();
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
return resolveManifestPathById(allocator, input, base);
|
||||
}
|
||||
|
||||
const stat = cwd.statFile(input) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
return resolveManifestPathById(allocator, input, base);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
|
||||
if (stat.kind == .directory) {
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
|
||||
}
|
||||
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
|
||||
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
|
||||
if (std.mem.trim(u8, id, " \t\r\n").len == 0) {
|
||||
return error.FileNotFound;
|
||||
}
|
||||
if (base_path.len == 0) {
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
|
||||
for (roots) |root| {
|
||||
const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
|
||||
defer allocator.free(root_path);
|
||||
|
||||
var dir = if (std.fs.path.isAbsolute(root_path))
|
||||
(std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
|
||||
else
|
||||
(std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
while (try it.next()) |entry| {
|
||||
if (entry.kind != .directory) continue;
|
||||
|
||||
const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name });
|
||||
defer allocator.free(run_dir);
|
||||
const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" });
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const file = if (std.fs.path.isAbsolute(manifest_path))
|
||||
(std.fs.openFileAbsolute(manifest_path, .{}) catch continue)
|
||||
else
|
||||
(std.fs.cwd().openFile(manifest_path, .{}) catch continue);
|
||||
defer file.close();
|
||||
|
||||
const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue;
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue;
|
||||
defer parsed.deinit();
|
||||
if (parsed.value != .object) continue;
|
||||
|
||||
const obj = parsed.value.object;
|
||||
const run_id = jsonGetString(obj, "run_id") orelse "";
|
||||
const task_id = jsonGetString(obj, "task_id") orelse "";
|
||||
if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
|
||||
return allocator.dupe(u8, manifest_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
|
||||
var file = if (std.fs.path.isAbsolute(path))
|
||||
try std.fs.openFileAbsolute(path, .{})
|
||||
else
|
||||
try std.fs.cwd().openFile(path, .{});
|
||||
defer file.close();
|
||||
|
||||
return file.readToEndAlloc(allocator, 1024 * 1024);
|
||||
}
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v = obj.get(key) orelse return null;
|
||||
if (v != .string) return null;
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
|
||||
colors.printInfo(" ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
|
|
@ -18,7 +18,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
|
||||
var options = DatasetOptions{};
|
||||
|
||||
// Parse global flags: --dry-run, --validate, --json
|
||||
var positional = std.ArrayList([]const u8).initCapacity(allocator, args.len) catch |err| {
|
||||
return err;
|
||||
|
|
@ -44,36 +43,39 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
}
|
||||
|
||||
if (positional.items.len == 0) {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
const action = positional.items[0];
|
||||
|
||||
if (std.mem.eql(u8, action, "list")) {
|
||||
try listDatasets(allocator, &options);
|
||||
} else if (std.mem.eql(u8, action, "register")) {
|
||||
if (positional.items.len < 3) {
|
||||
colors.printError("Usage: ml dataset register <name> <url>\n", .{});
|
||||
switch (positional.items.len) {
|
||||
0 => {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
try registerDataset(allocator, positional.items[1], positional.items[2], &options);
|
||||
} else if (std.mem.eql(u8, action, "info")) {
|
||||
if (positional.items.len < 2) {
|
||||
colors.printError("Usage: ml dataset info <name>\n", .{});
|
||||
},
|
||||
1 => {
|
||||
if (std.mem.eql(u8, action, "list")) {
|
||||
try listDatasets(allocator, &options);
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
},
|
||||
2 => {
|
||||
if (std.mem.eql(u8, action, "info")) {
|
||||
try showDatasetInfo(allocator, positional.items[1], &options);
|
||||
return;
|
||||
} else if (std.mem.eql(u8, action, "search")) {
|
||||
try searchDatasets(allocator, positional.items[1], &options);
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
},
|
||||
3 => {
|
||||
if (std.mem.eql(u8, action, "register")) {
|
||||
try registerDataset(allocator, positional.items[1], positional.items[2], &options);
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
},
|
||||
else => {
|
||||
colors.printError("Unknoen action: {s}\n", .{action});
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
try showDatasetInfo(allocator, positional.items[1], &options);
|
||||
} else if (std.mem.eql(u8, action, "search")) {
|
||||
if (positional.items.len < 2) {
|
||||
colors.printError("Usage: ml dataset search <term>\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
try searchDatasets(allocator, positional.items[1], &options);
|
||||
} else {
|
||||
colors.printError("Unknown action: {s}\n", .{action});
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const std = @import("std");
|
||||
const config = @import("../config.zig");
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const history = @import("../utils/history.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const io = @import("../utils/io.zig");
|
||||
|
||||
pub const Options = struct {
|
||||
json: bool = false,
|
||||
|
|
@ -61,7 +62,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
defer allocator.free(data);
|
||||
|
||||
if (opts.json) {
|
||||
std.debug.print("{s}\n", .{data});
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{data});
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
|
|
@ -87,7 +87,9 @@ fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void
|
|||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to restore workspace: {s}\n", .{error_msg});
|
||||
if (packet.error_message) |msg| {
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
|
|
@ -138,8 +140,6 @@ pub fn isValidTopLevelAction(action: []const u8) bool {
|
|||
std.mem.eql(u8, action, "list") or
|
||||
std.mem.eql(u8, action, "remove") or
|
||||
std.mem.eql(u8, action, "restore") or
|
||||
std.mem.eql(u8, action, "workspace") or
|
||||
std.mem.eql(u8, action, "experiment") or
|
||||
std.mem.eql(u8, action, "package");
|
||||
}
|
||||
|
||||
|
|
@ -149,18 +149,18 @@ pub fn defaultWorkspacePath(allocator: std.mem.Allocator, name: []const u8) ![]u
|
|||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len < 1) {
|
||||
printUsage();
|
||||
printUsagePackage();
|
||||
return;
|
||||
}
|
||||
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
printUsagePackage();
|
||||
return;
|
||||
}
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
colors.printError("jupyter does not support --json\n", .{});
|
||||
printUsage();
|
||||
printUsagePackage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
|
@ -181,10 +181,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
try removeJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "restore")) {
|
||||
try restoreJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "workspace")) {
|
||||
try workspaceCommands(args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "experiment")) {
|
||||
try experimentCommands(args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "package")) {
|
||||
try packageCommands(args[1..]);
|
||||
} else {
|
||||
|
|
@ -192,12 +188,11 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
}
|
||||
|
||||
fn printUsage() void {
|
||||
colors.printError("Usage: ml jupyter <action> [options]\n", .{});
|
||||
colors.printInfo("\nActions:\n", .{});
|
||||
colors.printInfo(" create|start|stop|status|list|remove|restore\n", .{});
|
||||
colors.printInfo(" workspace|experiment|package\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
fn printUsagePackage() void {
|
||||
colors.printError("Usage: ml jupyter package <action> [options]\n", .{});
|
||||
colors.printInfo("Actions:\n", .{});
|
||||
colors.printInfo(" list\n", .{});
|
||||
colors.printInfo("Options:\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
|
||||
|
|
@ -340,7 +335,9 @@ fn startJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to start service: {s}\n", .{error_msg});
|
||||
if (packet.error_message) |msg| {
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
|
|
@ -416,7 +413,9 @@ fn stopJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to stop service: {s}\n", .{error_msg});
|
||||
if (packet.error_message) |msg| {
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
|
|
@ -533,7 +532,9 @@ fn removeJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to remove service: {s}\n", .{error_msg});
|
||||
if (packet.error_message) |msg| {
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
|
|
@ -662,7 +663,9 @@ fn listServices(allocator: std.mem.Allocator) !void {
|
|||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to list services: {s}\n", .{error_msg});
|
||||
if (packet.error_message) |msg| {
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
|
|
@ -776,86 +779,133 @@ fn experimentCommands(args: []const []const u8) !void {
|
|||
|
||||
fn packageCommands(args: []const []const u8) !void {
|
||||
if (args.len < 1) {
|
||||
colors.printError("Usage: ml jupyter package <install|list|pending|approve|reject>\n", .{});
|
||||
colors.printError("Usage: ml jupyter package <list>\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const subcommand = args[0];
|
||||
|
||||
if (std.mem.eql(u8, subcommand, "install")) {
|
||||
if (std.mem.eql(u8, subcommand, "list")) {
|
||||
if (args.len < 2) {
|
||||
colors.printError("Usage: ml jupyter package install --package <name> [--channel <channel>] [--version <version>]\n", .{});
|
||||
colors.printError("Usage: ml jupyter package list <service-name>\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse package name from args
|
||||
var package_name: []const u8 = "";
|
||||
var channel: []const u8 = "conda-forge";
|
||||
var version: []const u8 = "latest";
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) {
|
||||
if (std.mem.eql(u8, args[i], "--package") and i + 1 < args.len) {
|
||||
package_name = args[i + 1];
|
||||
i += 2;
|
||||
} else if (std.mem.eql(u8, args[i], "--channel") and i + 1 < args.len) {
|
||||
channel = args[i + 1];
|
||||
i += 2;
|
||||
} else if (std.mem.eql(u8, args[i], "--version") and i + 1 < args.len) {
|
||||
version = args[i + 1];
|
||||
i += 2;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
var service_name: []const u8 = "";
|
||||
if (std.mem.eql(u8, args[1], "--name") and args.len >= 3) {
|
||||
service_name = args[2];
|
||||
} else {
|
||||
service_name = args[1];
|
||||
}
|
||||
|
||||
if (package_name.len == 0) {
|
||||
colors.printError("Package name is required\n", .{});
|
||||
if (service_name.len == 0) {
|
||||
colors.printError("Service name is required\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
// Security validations
|
||||
if (!validatePackageName(package_name)) {
|
||||
colors.printError("Invalid package name: {s}. Only alphanumeric characters, underscores, hyphens, and dots are allowed.\n", .{package_name});
|
||||
return;
|
||||
const allocator = std.heap.page_allocator;
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
if (isPackageBlocked(package_name)) {
|
||||
colors.printError("Package '{s}' is blocked by security policy for security reasons.\n", .{package_name});
|
||||
colors.printInfo("Blocked packages typically include network libraries that could be used for unauthorized data access.\n", .{});
|
||||
const protocol_str = if (config.worker_port == 443) "wss" else "ws";
|
||||
const url = try std.fmt.allocPrint(allocator, "{s}://{s}:{d}/ws", .{
|
||||
protocol_str,
|
||||
config.worker_host,
|
||||
config.worker_port,
|
||||
});
|
||||
defer allocator.free(url);
|
||||
|
||||
var client = ws.Client.connect(allocator, url, config.api_key) catch |err| {
|
||||
colors.printError("Failed to connect to server: {}\n", .{err});
|
||||
return;
|
||||
};
|
||||
defer client.close();
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
client.sendListJupyterPackages(service_name, api_key_hash) catch |err| {
|
||||
colors.printError("Failed to send list packages command: {}\n", .{err});
|
||||
return;
|
||||
};
|
||||
|
||||
const response = client.receiveMessage(allocator) catch |err| {
|
||||
colors.printError("Failed to receive response: {}\n", .{err});
|
||||
return;
|
||||
};
|
||||
defer allocator.free(response);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(response, allocator) catch |err| {
|
||||
colors.printError("Failed to parse response: {}\n", .{err});
|
||||
return;
|
||||
};
|
||||
defer {
|
||||
if (packet.data_type) |dtype| allocator.free(dtype);
|
||||
if (packet.data_payload) |payload| allocator.free(payload);
|
||||
if (packet.error_message) |msg| allocator.free(msg);
|
||||
if (packet.error_details) |details| allocator.free(details);
|
||||
}
|
||||
|
||||
if (!validateChannel(channel)) {
|
||||
colors.printError("Channel '{s}' is not trusted. Allowed channels: conda-forge, defaults, pytorch, nvidia\n", .{channel});
|
||||
return;
|
||||
}
|
||||
switch (packet.packet_type) {
|
||||
.data => {
|
||||
colors.printInfo("Installed packages for {s}:\n", .{service_name});
|
||||
if (packet.data_payload) |payload| {
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch {
|
||||
std.debug.print("{s}\n", .{payload});
|
||||
return;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
colors.printInfo("Requesting package installation...\n", .{});
|
||||
colors.printInfo("Package: {s}\n", .{package_name});
|
||||
colors.printInfo("Version: {s}\n", .{version});
|
||||
colors.printInfo("Channel: {s}\n", .{channel});
|
||||
colors.printInfo("Security: Package validated against security policies\n", .{});
|
||||
colors.printSuccess("Package request created successfully!\n", .{});
|
||||
colors.printInfo("Note: Package requires approval from administrator before installation.\n", .{});
|
||||
} else if (std.mem.eql(u8, subcommand, "list")) {
|
||||
colors.printInfo("Installed packages in workspace: ./workspace\n", .{});
|
||||
colors.printInfo("Package Name Version Channel Installed By\n", .{});
|
||||
colors.printInfo("------------ ------- ------- ------------\n", .{});
|
||||
colors.printInfo("numpy 1.21.0 conda-forge user1\n", .{});
|
||||
colors.printInfo("pandas 1.3.0 conda-forge user1\n", .{});
|
||||
} else if (std.mem.eql(u8, subcommand, "pending")) {
|
||||
colors.printInfo("Pending package requests for workspace: ./workspace\n", .{});
|
||||
colors.printInfo("Package Name Version Channel Requested By Time\n", .{});
|
||||
colors.printInfo("------------ ------- ------- ------------ ----\n", .{});
|
||||
colors.printInfo("torch 1.9.0 pytorch user3 2023-12-06 10:30\n", .{});
|
||||
} else if (std.mem.eql(u8, subcommand, "approve")) {
|
||||
colors.printInfo("Approving package request: torch\n", .{});
|
||||
colors.printSuccess("Package request approved!\n", .{});
|
||||
} else if (std.mem.eql(u8, subcommand, "reject")) {
|
||||
colors.printInfo("Rejecting package request: suspicious-package\n", .{});
|
||||
colors.printInfo("Reason: Security policy violation\n", .{});
|
||||
colors.printSuccess("Package request rejected!\n", .{});
|
||||
if (parsed.value != .array) {
|
||||
std.debug.print("{s}\n", .{payload});
|
||||
return;
|
||||
}
|
||||
|
||||
const pkgs = parsed.value.array;
|
||||
if (pkgs.items.len == 0) {
|
||||
colors.printInfo("No packages found.\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
colors.printInfo("NAME VERSION SOURCE\n", .{});
|
||||
colors.printInfo("---- ------- ------\n", .{});
|
||||
|
||||
for (pkgs.items) |item| {
|
||||
if (item != .object) continue;
|
||||
const obj = item.object;
|
||||
|
||||
var name: []const u8 = "";
|
||||
if (obj.get("name")) |v| {
|
||||
if (v == .string) name = v.string;
|
||||
}
|
||||
var version: []const u8 = "";
|
||||
if (obj.get("version")) |v| {
|
||||
if (v == .string) version = v.string;
|
||||
}
|
||||
var source: []const u8 = "";
|
||||
if (obj.get("source")) |v| {
|
||||
if (v == .string) source = v.string;
|
||||
}
|
||||
|
||||
std.debug.print("{s: <30} {s: <22} {s}\n", .{ name, version, source });
|
||||
}
|
||||
}
|
||||
},
|
||||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to list packages: {s}\n", .{error_msg});
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
else => {
|
||||
colors.printError("Unexpected response type\n", .{});
|
||||
},
|
||||
}
|
||||
} else {
|
||||
colors.printError("Invalid package command: {s}\n", .{subcommand});
|
||||
}
|
||||
|
|
|
|||
370
cli/src/commands/narrative.zig
Normal file
370
cli/src/commands/narrative.zig
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const sub = argv[0];
|
||||
if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!std.mem.eql(u8, sub, "set")) {
|
||||
colors.printError("Unknown subcommand: {s}\n", .{sub});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (argv.len < 2) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const target = argv[1];
|
||||
|
||||
var hypothesis: ?[]const u8 = null;
|
||||
var context: ?[]const u8 = null;
|
||||
var intent: ?[]const u8 = null;
|
||||
var expected_outcome: ?[]const u8 = null;
|
||||
var parent_run: ?[]const u8 = null;
|
||||
var experiment_group: ?[]const u8 = null;
|
||||
var tags_csv: ?[]const u8 = null;
|
||||
var base_override: ?[]const u8 = null;
|
||||
var json_mode: bool = false;
|
||||
|
||||
var i: usize = 2;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const a = argv[i];
|
||||
if (std.mem.eql(u8, a, "--hypothesis")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
hypothesis = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--context")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
context = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--intent")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
intent = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--expected-outcome")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
expected_outcome = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--parent-run")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
parent_run = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--experiment-group")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
experiment_group = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--tags")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
tags_csv = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--base")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
base_override = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
if (hypothesis == null and context == null and intent == null and expected_outcome == null and parent_run == null and experiment_group == null and tags_csv == null) {
|
||||
colors.printError("No narrative fields provided.\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = base_override orelse cfg.worker_base;
|
||||
const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const job_name = try readJobNameFromManifest(allocator, manifest_path);
|
||||
defer allocator.free(job_name);
|
||||
|
||||
const patch_json = try buildPatchJSON(
|
||||
allocator,
|
||||
hypothesis,
|
||||
context,
|
||||
intent,
|
||||
expected_outcome,
|
||||
parent_run,
|
||||
experiment_group,
|
||||
tags_csv,
|
||||
);
|
||||
defer allocator.free(patch_json);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendSetRunNarrative(job_name, patch_json, api_key_hash);
|
||||
|
||||
if (json_mode) {
|
||||
const msg = try client.receiveMessage(allocator);
|
||||
defer allocator.free(msg);
|
||||
|
||||
const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
return error.InvalidPacket;
|
||||
};
|
||||
defer {
|
||||
if (packet.success_message) |m| allocator.free(m);
|
||||
if (packet.error_message) |m| allocator.free(m);
|
||||
if (packet.error_details) |m| allocator.free(m);
|
||||
}
|
||||
|
||||
const Result = struct {
|
||||
ok: bool,
|
||||
job_name: []const u8,
|
||||
message: []const u8,
|
||||
error_code: ?u8 = null,
|
||||
error_message: ?[]const u8 = null,
|
||||
details: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
var out = io.stdoutWriter();
|
||||
if (packet.packet_type == .error_packet) {
|
||||
const res = Result{
|
||||
.ok = false,
|
||||
.job_name = job_name,
|
||||
.message = "",
|
||||
.error_code = @intFromEnum(packet.error_code.?),
|
||||
.error_message = packet.error_message orelse "",
|
||||
.details = packet.error_details orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return error.CommandFailed;
|
||||
}
|
||||
|
||||
const res = Result{
|
||||
.ok = true,
|
||||
.job_name = job_name,
|
||||
.message = packet.success_message orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return;
|
||||
}
|
||||
|
||||
try client.receiveAndHandleResponse(allocator, "Narrative");
|
||||
|
||||
colors.printSuccess("Narrative updated\n", .{});
|
||||
colors.printInfo("Job: {s}\n", .{job_name});
|
||||
}
|
||||
|
||||
fn buildPatchJSON(
|
||||
allocator: std.mem.Allocator,
|
||||
hypothesis: ?[]const u8,
|
||||
context: ?[]const u8,
|
||||
intent: ?[]const u8,
|
||||
expected_outcome: ?[]const u8,
|
||||
parent_run: ?[]const u8,
|
||||
experiment_group: ?[]const u8,
|
||||
tags_csv: ?[]const u8,
|
||||
) ![]u8 {
|
||||
var out = std.ArrayList(u8).initCapacity(allocator, 256) catch return error.OutOfMemory;
|
||||
defer out.deinit(allocator);
|
||||
|
||||
var tags_list = std.ArrayList([]const u8).initCapacity(allocator, 8) catch return error.OutOfMemory;
|
||||
defer tags_list.deinit(allocator);
|
||||
|
||||
if (tags_csv) |csv| {
|
||||
var it = std.mem.splitScalar(u8, csv, ',');
|
||||
while (it.next()) |part| {
|
||||
const trimmed = std.mem.trim(u8, part, " \t\r\n");
|
||||
if (trimmed.len == 0) continue;
|
||||
try tags_list.append(allocator, trimmed);
|
||||
}
|
||||
}
|
||||
|
||||
const Patch = struct {
|
||||
hypothesis: ?[]const u8 = null,
|
||||
context: ?[]const u8 = null,
|
||||
intent: ?[]const u8 = null,
|
||||
expected_outcome: ?[]const u8 = null,
|
||||
parent_run: ?[]const u8 = null,
|
||||
experiment_group: ?[]const u8 = null,
|
||||
tags: ?[]const []const u8 = null,
|
||||
};
|
||||
|
||||
const patch = Patch{
|
||||
.hypothesis = hypothesis,
|
||||
.context = context,
|
||||
.intent = intent,
|
||||
.expected_outcome = expected_outcome,
|
||||
.parent_run = parent_run,
|
||||
.experiment_group = experiment_group,
|
||||
.tags = if (tags_list.items.len > 0) tags_list.items else null,
|
||||
};
|
||||
|
||||
const writer = out.writer(allocator);
|
||||
try writer.print("{f}", .{std.json.fmt(patch, .{})});
|
||||
return out.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 {
|
||||
const data = try readFileAlloc(allocator, manifest_path);
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
if (parsed.value != .object) return error.InvalidManifest;
|
||||
const root = parsed.value.object;
|
||||
|
||||
const job_name = jsonGetString(root, "job_name") orelse "";
|
||||
if (std.mem.trim(u8, job_name, " \t\r\n").len == 0) {
|
||||
return error.InvalidManifest;
|
||||
}
|
||||
return allocator.dupe(u8, job_name);
|
||||
}
|
||||
|
||||
fn resolveManifestPathWithBase(allocator: std.mem.Allocator, input: []const u8, base: []const u8) ![]u8 {
|
||||
var cwd = std.fs.cwd();
|
||||
|
||||
if (std.fs.path.isAbsolute(input)) {
|
||||
if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| {
|
||||
var mutable_dir = dir;
|
||||
defer mutable_dir.close();
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
|
||||
}
|
||||
if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
|
||||
var mutable_file = file;
|
||||
defer mutable_file.close();
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
return resolveManifestPathById(allocator, input, base);
|
||||
}
|
||||
|
||||
const stat = cwd.statFile(input) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
return resolveManifestPathById(allocator, input, base);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
|
||||
if (stat.kind == .directory) {
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
|
||||
}
|
||||
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
|
||||
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
|
||||
if (std.mem.trim(u8, id, " \t\r\n").len == 0) {
|
||||
return error.FileNotFound;
|
||||
}
|
||||
if (base_path.len == 0) {
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
|
||||
for (roots) |root| {
|
||||
const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
|
||||
defer allocator.free(root_path);
|
||||
|
||||
var dir = if (std.fs.path.isAbsolute(root_path))
|
||||
(std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
|
||||
else
|
||||
(std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
while (try it.next()) |entry| {
|
||||
if (entry.kind != .directory) continue;
|
||||
|
||||
const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name });
|
||||
defer allocator.free(run_dir);
|
||||
const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" });
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const file = if (std.fs.path.isAbsolute(manifest_path))
|
||||
(std.fs.openFileAbsolute(manifest_path, .{}) catch continue)
|
||||
else
|
||||
(std.fs.cwd().openFile(manifest_path, .{}) catch continue);
|
||||
defer file.close();
|
||||
|
||||
const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue;
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue;
|
||||
defer parsed.deinit();
|
||||
if (parsed.value != .object) continue;
|
||||
|
||||
const obj = parsed.value.object;
|
||||
const run_id = jsonGetString(obj, "run_id") orelse "";
|
||||
const task_id = jsonGetString(obj, "task_id") orelse "";
|
||||
if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
|
||||
return allocator.dupe(u8, manifest_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
|
||||
var file = if (std.fs.path.isAbsolute(path))
|
||||
try std.fs.openFileAbsolute(path, .{})
|
||||
else
|
||||
try std.fs.cwd().openFile(path, .{});
|
||||
defer file.close();
|
||||
|
||||
return file.readToEndAlloc(allocator, 1024 * 1024);
|
||||
}
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v = obj.get(key) orelse return null;
|
||||
if (v != .string) return null;
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml narrative set <path|run_id|task_id> [fields]\n", .{});
|
||||
colors.printInfo("\nFields:\n", .{});
|
||||
colors.printInfo(" --hypothesis \"...\"\n", .{});
|
||||
colors.printInfo(" --context \"...\"\n", .{});
|
||||
colors.printInfo(" --intent \"...\"\n", .{});
|
||||
colors.printInfo(" --expected-outcome \"...\"\n", .{});
|
||||
colors.printInfo(" --parent-run <id>\n", .{});
|
||||
colors.printInfo(" --experiment-group <name>\n", .{});
|
||||
colors.printInfo(" --tags a,b,c\n", .{});
|
||||
colors.printInfo(" --base <path>\n", .{});
|
||||
colors.printInfo(" --json\n", .{});
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const history = @import("../utils/history.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
|
|
@ -42,6 +42,43 @@ pub const QueueOptions = struct {
|
|||
gpu_memory: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
|
||||
if (input.len < 7 or input.len > 40) return error.InvalidArgs;
|
||||
for (input) |c| {
|
||||
if (!std.ascii.isHex(c)) return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (input.len == 40) {
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
|
||||
var dir = if (std.fs.path.isAbsolute(base_path))
|
||||
try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
|
||||
else
|
||||
try std.fs.cwd().openDir(base_path, .{ .iterate = true });
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
var found: ?[]u8 = null;
|
||||
errdefer if (found) |s| allocator.free(s);
|
||||
|
||||
while (try it.next()) |entry| {
|
||||
if (entry.kind != .directory) continue;
|
||||
const name = entry.name;
|
||||
if (name.len != 40) continue;
|
||||
if (!std.mem.startsWith(u8, name, input)) continue;
|
||||
for (name) |c| {
|
||||
if (!std.ascii.isHex(c)) break;
|
||||
} else {
|
||||
if (found != null) return error.InvalidArgs;
|
||||
found = try allocator.dupe(u8, name);
|
||||
}
|
||||
}
|
||||
|
||||
if (found) |s| return s;
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
try printUsage();
|
||||
|
|
@ -64,6 +101,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var priority: u8 = 5;
|
||||
var snapshot_id: ?[]const u8 = null;
|
||||
var snapshot_sha256: ?[]const u8 = null;
|
||||
var args_override: ?[]const u8 = null;
|
||||
var note_override: ?[]const u8 = null;
|
||||
|
||||
// Load configuration to get defaults
|
||||
const config = try Config.load(allocator);
|
||||
|
|
@ -88,10 +127,33 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var tracking = TrackingConfig{};
|
||||
var has_tracking = false;
|
||||
|
||||
// Support passing runner args after "--".
|
||||
var sep_index: ?usize = null;
|
||||
for (args, 0..) |a, idx| {
|
||||
if (std.mem.eql(u8, a, "--")) {
|
||||
sep_index = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const pre = args[0..(sep_index orelse args.len)];
|
||||
const post = if (sep_index) |si| args[(si + 1)..] else args[0..0];
|
||||
|
||||
var args_joined: []const u8 = "";
|
||||
if (post.len > 0) {
|
||||
var buf: std.ArrayList(u8) = .{};
|
||||
defer buf.deinit(allocator);
|
||||
for (post, 0..) |a, j| {
|
||||
if (j > 0) try buf.append(allocator, ' ');
|
||||
try buf.appendSlice(allocator, a);
|
||||
}
|
||||
args_joined = try buf.toOwnedSlice(allocator);
|
||||
}
|
||||
defer if (post.len > 0) allocator.free(args_joined);
|
||||
|
||||
// Parse arguments - separate job names from flags
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
while (i < pre.len) : (i += 1) {
|
||||
const arg = pre[i];
|
||||
|
||||
if (std.mem.startsWith(u8, arg, "--") or std.mem.eql(u8, arg, "-h")) {
|
||||
// Parse flags
|
||||
|
|
@ -99,15 +161,21 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
try printUsage();
|
||||
return;
|
||||
}
|
||||
if (std.mem.eql(u8, arg, "--commit") and i + 1 < args.len) {
|
||||
if (std.mem.eql(u8, arg, "--commit") and i + 1 < pre.len) {
|
||||
if (commit_id_override != null) {
|
||||
allocator.free(commit_id_override.?);
|
||||
}
|
||||
const commit_hex = args[i + 1];
|
||||
if (commit_hex.len != 40) {
|
||||
colors.printError("Invalid commit id: expected 40-char hex string\n", .{});
|
||||
const commit_in = pre[i + 1];
|
||||
const commit_hex = resolveCommitHexOrPrefix(allocator, config.worker_base, commit_in) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError("No commit matches prefix: {s}\n", .{commit_in});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
colors.printError("Invalid commit id\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
};
|
||||
defer allocator.free(commit_hex);
|
||||
|
||||
const commit_bytes = crypto.decodeHex(allocator, commit_hex) catch {
|
||||
colors.printError("Invalid commit id: must be hex\n", .{});
|
||||
return error.InvalidArgs;
|
||||
|
|
@ -119,35 +187,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
commit_id_override = commit_bytes;
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--priority") and i + 1 < args.len) {
|
||||
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
} else if (std.mem.eql(u8, arg, "--priority") and i + 1 < pre.len) {
|
||||
priority = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--mlflow")) {
|
||||
tracking.mlflow = TrackingConfig.MLflowConfig{};
|
||||
has_tracking = true;
|
||||
} else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < args.len) {
|
||||
} else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < pre.len) {
|
||||
tracking.mlflow = TrackingConfig.MLflowConfig{
|
||||
.mode = "remote",
|
||||
.tracking_uri = args[i + 1],
|
||||
.tracking_uri = pre[i + 1],
|
||||
};
|
||||
has_tracking = true;
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--tensorboard")) {
|
||||
tracking.tensorboard = TrackingConfig.TensorBoardConfig{};
|
||||
has_tracking = true;
|
||||
} else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < args.len) {
|
||||
} else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < pre.len) {
|
||||
if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
|
||||
tracking.wandb.?.api_key = args[i + 1];
|
||||
tracking.wandb.?.api_key = pre[i + 1];
|
||||
has_tracking = true;
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < args.len) {
|
||||
} else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < pre.len) {
|
||||
if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
|
||||
tracking.wandb.?.project = args[i + 1];
|
||||
tracking.wandb.?.project = pre[i + 1];
|
||||
has_tracking = true;
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < args.len) {
|
||||
} else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < pre.len) {
|
||||
if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
|
||||
tracking.wandb.?.entity = args[i + 1];
|
||||
tracking.wandb.?.entity = pre[i + 1];
|
||||
has_tracking = true;
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--dry-run")) {
|
||||
|
|
@ -158,23 +226,29 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
options.explain = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < args.len) {
|
||||
options.cpu = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
} else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < pre.len) {
|
||||
options.cpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--memory") and i + 1 < args.len) {
|
||||
options.memory = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
} else if (std.mem.eql(u8, arg, "--memory") and i + 1 < pre.len) {
|
||||
options.memory = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < args.len) {
|
||||
options.gpu = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
} else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < pre.len) {
|
||||
options.gpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < args.len) {
|
||||
options.gpu_memory = args[i + 1];
|
||||
} else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < pre.len) {
|
||||
options.gpu_memory = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < args.len) {
|
||||
snapshot_id = args[i + 1];
|
||||
} else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < pre.len) {
|
||||
snapshot_id = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < args.len) {
|
||||
snapshot_sha256 = args[i + 1];
|
||||
} else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < pre.len) {
|
||||
snapshot_sha256 = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--args") and i + 1 < pre.len) {
|
||||
args_override = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) {
|
||||
note_override = pre[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -227,10 +301,25 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
|
||||
defer if (commit_id_override) |cid| allocator.free(cid);
|
||||
|
||||
const args_str: []const u8 = if (args_override) |a| a else args_joined;
|
||||
const note_str: []const u8 = if (note_override) |n| n else "";
|
||||
|
||||
for (job_names.items, 0..) |job_name, index| {
|
||||
colors.printInfo("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name });
|
||||
|
||||
queueSingleJob(allocator, job_name, commit_id_override, priority, tracking_json, &options, snapshot_id, snapshot_sha256, print_next_steps) catch |err| {
|
||||
queueSingleJob(
|
||||
allocator,
|
||||
job_name,
|
||||
commit_id_override,
|
||||
priority,
|
||||
tracking_json,
|
||||
&options,
|
||||
snapshot_id,
|
||||
snapshot_sha256,
|
||||
args_str,
|
||||
note_str,
|
||||
print_next_steps,
|
||||
) catch |err| {
|
||||
colors.printError("Failed to queue job '{s}': {}\n", .{ job_name, err });
|
||||
failed_jobs.append(allocator, job_name) catch |append_err| {
|
||||
colors.printError("Failed to track failed job: {}\n", .{append_err});
|
||||
|
|
@ -274,6 +363,8 @@ fn queueSingleJob(
|
|||
options: *const QueueOptions,
|
||||
snapshot_id: ?[]const u8,
|
||||
snapshot_sha256: ?[]const u8,
|
||||
args_str: []const u8,
|
||||
note_str: []const u8,
|
||||
print_next_steps: bool,
|
||||
) !void {
|
||||
const commit_id = blk: {
|
||||
|
|
@ -324,6 +415,33 @@ fn queueSingleJob(
|
|||
options.gpu,
|
||||
options.gpu_memory,
|
||||
);
|
||||
} else if (note_str.len > 0 or args_str.len > 0) {
|
||||
if (note_str.len > 0) {
|
||||
try client.sendQueueJobWithArgsNoteAndResources(
|
||||
job_name,
|
||||
commit_id,
|
||||
priority,
|
||||
api_key_hash,
|
||||
args_str,
|
||||
note_str,
|
||||
options.cpu,
|
||||
options.memory,
|
||||
options.gpu,
|
||||
options.gpu_memory,
|
||||
);
|
||||
} else {
|
||||
try client.sendQueueJobWithArgsAndResources(
|
||||
job_name,
|
||||
commit_id,
|
||||
priority,
|
||||
api_key_hash,
|
||||
args_str,
|
||||
options.cpu,
|
||||
options.memory,
|
||||
options.gpu,
|
||||
options.gpu_memory,
|
||||
);
|
||||
}
|
||||
} else if (snapshot_id) |sid| {
|
||||
try client.sendQueueJobWithSnapshotAndResources(
|
||||
job_name,
|
||||
|
|
@ -374,6 +492,9 @@ fn printUsage() !void {
|
|||
colors.printInfo(" --memory <gb> Memory in GB (default: 8)\n", .{});
|
||||
colors.printInfo(" --gpu <count> GPU count (default: 0)\n", .{});
|
||||
colors.printInfo(" --gpu-memory <gb> GPU memory budget (default: auto)\n", .{});
|
||||
colors.printInfo(" --args <string> Extra runner args (sent to worker as task.Args)\n", .{});
|
||||
colors.printInfo(" --note <string> Human notes (stored in run manifest as metadata.note)\n", .{});
|
||||
colors.printInfo(" -- <args...> Extra runner args (alternative to --args)\n", .{});
|
||||
colors.printInfo("\nSpecial Modes:\n", .{});
|
||||
colors.printInfo(" --dry-run Show what would be submitted\n", .{});
|
||||
colors.printInfo(" --validate Validate experiment without submitting\n", .{});
|
||||
|
|
|
|||
345
cli/src/commands/requeue.zig
Normal file
345
cli/src/commands/requeue.zig
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
const target = argv[0];
|
||||
|
||||
// Split args at "--".
|
||||
var sep_index: ?usize = null;
|
||||
for (argv, 0..) |a, i| {
|
||||
if (std.mem.eql(u8, a, "--")) {
|
||||
sep_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const pre = argv[1..(sep_index orelse argv.len)];
|
||||
const post = if (sep_index) |i| argv[(i + 1)..] else argv[0..0];
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
// Defaults
|
||||
var job_name_override: ?[]const u8 = null;
|
||||
var priority: u8 = cfg.default_priority;
|
||||
var cpu: u8 = cfg.default_cpu;
|
||||
var memory: u8 = cfg.default_memory;
|
||||
var gpu: u8 = cfg.default_gpu;
|
||||
var gpu_memory: ?[]const u8 = cfg.default_gpu_memory;
|
||||
var args_override: ?[]const u8 = null;
|
||||
var note_override: ?[]const u8 = null;
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < pre.len) : (i += 1) {
|
||||
const a = pre[i];
|
||||
if (std.mem.eql(u8, a, "--name") and i + 1 < pre.len) {
|
||||
job_name_override = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--priority") and i + 1 < pre.len) {
|
||||
priority = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--cpu") and i + 1 < pre.len) {
|
||||
cpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--memory") and i + 1 < pre.len) {
|
||||
memory = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--gpu") and i + 1 < pre.len) {
|
||||
gpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--gpu-memory") and i + 1 < pre.len) {
|
||||
gpu_memory = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--args") and i + 1 < pre.len) {
|
||||
args_override = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--note") and i + 1 < pre.len) {
|
||||
note_override = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
var args_joined: []const u8 = "";
|
||||
if (post.len > 0) {
|
||||
var buf: std.ArrayList(u8) = .{};
|
||||
defer buf.deinit(allocator);
|
||||
for (post, 0..) |a, idx| {
|
||||
if (idx > 0) try buf.append(allocator, ' ');
|
||||
try buf.appendSlice(allocator, a);
|
||||
}
|
||||
args_joined = try buf.toOwnedSlice(allocator);
|
||||
}
|
||||
defer if (post.len > 0) allocator.free(args_joined);
|
||||
|
||||
const args_final: []const u8 = if (args_override) |a| a else args_joined;
|
||||
const note_final: []const u8 = if (note_override) |n| n else "";
|
||||
|
||||
// Target can be:
|
||||
// - commit_id (40-hex) or commit_id prefix (>=7 hex) resolvable under worker_base
|
||||
// - run_id/task_id/path (resolved to run_manifest.json to read commit_id)
|
||||
var commit_hex: []const u8 = "";
|
||||
var commit_hex_owned: ?[]u8 = null;
|
||||
defer if (commit_hex_owned) |s| allocator.free(s);
|
||||
|
||||
var commit_bytes: []u8 = &[_]u8{};
|
||||
var commit_bytes_allocated = false;
|
||||
defer if (commit_bytes_allocated) allocator.free(commit_bytes);
|
||||
|
||||
if (target.len >= 7 and target.len <= 40 and isHexLowerOrUpper(target)) {
|
||||
if (target.len == 40) {
|
||||
commit_hex = target;
|
||||
} else {
|
||||
commit_hex_owned = try resolveCommitPrefix(allocator, cfg.worker_base, target);
|
||||
commit_hex = commit_hex_owned.?;
|
||||
}
|
||||
|
||||
const decoded = crypto.decodeHex(allocator, commit_hex) catch {
|
||||
commit_hex = "";
|
||||
commit_hex_owned = null;
|
||||
return error.InvalidCommitId;
|
||||
};
|
||||
if (decoded.len != 20) {
|
||||
allocator.free(decoded);
|
||||
commit_hex = "";
|
||||
commit_hex_owned = null;
|
||||
} else {
|
||||
commit_bytes = decoded;
|
||||
commit_bytes_allocated = true;
|
||||
}
|
||||
}
|
||||
|
||||
var job_name = blk: {
|
||||
if (job_name_override) |n| break :blk n;
|
||||
break :blk "requeue";
|
||||
};
|
||||
|
||||
if (commit_hex.len == 0) {
|
||||
const manifest_path = try resolveManifestPath(allocator, target, cfg.worker_base);
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const data = try readFileAlloc(allocator, manifest_path);
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
if (parsed.value != .object) return error.InvalidManifest;
|
||||
const root = parsed.value.object;
|
||||
|
||||
commit_hex = jsonGetString(root, "commit_id") orelse "";
|
||||
if (commit_hex.len != 40) {
|
||||
colors.printError("run manifest missing commit_id\n", .{});
|
||||
return error.InvalidManifest;
|
||||
}
|
||||
|
||||
if (job_name_override == null) {
|
||||
const j = jsonGetString(root, "job_name") orelse "";
|
||||
if (j.len > 0) job_name = j;
|
||||
}
|
||||
|
||||
const b = try crypto.decodeHex(allocator, commit_hex);
|
||||
if (b.len != 20) {
|
||||
allocator.free(b);
|
||||
return error.InvalidCommitId;
|
||||
}
|
||||
commit_bytes = b;
|
||||
commit_bytes_allocated = true;
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
if (note_final.len > 0) {
|
||||
try client.sendQueueJobWithArgsNoteAndResources(
|
||||
job_name,
|
||||
commit_bytes,
|
||||
priority,
|
||||
api_key_hash,
|
||||
args_final,
|
||||
note_final,
|
||||
cpu,
|
||||
memory,
|
||||
gpu,
|
||||
gpu_memory,
|
||||
);
|
||||
} else {
|
||||
try client.sendQueueJobWithArgsAndResources(
|
||||
job_name,
|
||||
commit_bytes,
|
||||
priority,
|
||||
api_key_hash,
|
||||
args_final,
|
||||
cpu,
|
||||
memory,
|
||||
gpu,
|
||||
gpu_memory,
|
||||
);
|
||||
}
|
||||
|
||||
try client.receiveAndHandleResponse(allocator, "Requeue");
|
||||
|
||||
colors.printSuccess("Queued requeue\n", .{});
|
||||
colors.printInfo("Job: {s}\n", .{job_name});
|
||||
colors.printInfo("Commit: {s}\n", .{commit_hex});
|
||||
}
|
||||
|
||||
fn isHexLowerOrUpper(s: []const u8) bool {
|
||||
for (s) |c| {
|
||||
if (!std.ascii.isHex(c)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
fn resolveCommitPrefix(allocator: std.mem.Allocator, base_path: []const u8, prefix: []const u8) ![]u8 {
|
||||
var dir = if (std.fs.path.isAbsolute(base_path))
|
||||
try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
|
||||
else
|
||||
try std.fs.cwd().openDir(base_path, .{ .iterate = true });
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
var found: ?[]u8 = null;
|
||||
errdefer if (found) |s| allocator.free(s);
|
||||
|
||||
while (try it.next()) |entry| {
|
||||
if (entry.kind != .directory) continue;
|
||||
const name = entry.name;
|
||||
if (name.len != 40) continue;
|
||||
if (!std.mem.startsWith(u8, name, prefix)) continue;
|
||||
if (!isHexLowerOrUpper(name)) continue;
|
||||
|
||||
if (found != null) {
|
||||
colors.printError("Ambiguous commit prefix: {s}\n", .{prefix});
|
||||
return error.InvalidCommitId;
|
||||
}
|
||||
found = try allocator.dupe(u8, name);
|
||||
}
|
||||
|
||||
if (found) |s| return s;
|
||||
colors.printError("No commit matches prefix: {s}\n", .{prefix});
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
fn resolveManifestPath(allocator: std.mem.Allocator, input: []const u8, base_path: []const u8) ![]u8 {
|
||||
var cwd = std.fs.cwd();
|
||||
|
||||
if (std.fs.path.isAbsolute(input)) {
|
||||
if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| {
|
||||
var mutable_dir = dir;
|
||||
defer mutable_dir.close();
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
|
||||
}
|
||||
if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
|
||||
var mutable_file = file;
|
||||
defer mutable_file.close();
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
return resolveManifestPathById(allocator, input, base_path);
|
||||
}
|
||||
|
||||
const stat = cwd.statFile(input) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
return resolveManifestPathById(allocator, input, base_path);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
|
||||
if (stat.kind == .directory) {
|
||||
return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
|
||||
}
|
||||
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
|
||||
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
|
||||
const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
|
||||
for (roots) |root| {
|
||||
const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
|
||||
defer allocator.free(root_path);
|
||||
|
||||
var dir = if (std.fs.path.isAbsolute(root_path))
|
||||
(std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
|
||||
else
|
||||
(std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
while (try it.next()) |entry| {
|
||||
if (entry.kind != .directory) continue;
|
||||
|
||||
const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name });
|
||||
defer allocator.free(run_dir);
|
||||
const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" });
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const file = if (std.fs.path.isAbsolute(manifest_path))
|
||||
(std.fs.openFileAbsolute(manifest_path, .{}) catch continue)
|
||||
else
|
||||
(std.fs.cwd().openFile(manifest_path, .{}) catch continue);
|
||||
defer file.close();
|
||||
|
||||
const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue;
|
||||
defer allocator.free(data);
|
||||
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue;
|
||||
defer parsed.deinit();
|
||||
if (parsed.value != .object) continue;
|
||||
|
||||
const obj = parsed.value.object;
|
||||
const run_id = jsonGetString(obj, "run_id") orelse "";
|
||||
const task_id = jsonGetString(obj, "task_id") orelse "";
|
||||
if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
|
||||
return allocator.dupe(u8, manifest_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return error.FileNotFound;
|
||||
}
|
||||
|
||||
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
|
||||
var file = if (std.fs.path.isAbsolute(path))
|
||||
try std.fs.openFileAbsolute(path, .{})
|
||||
else
|
||||
try std.fs.cwd().openFile(path, .{});
|
||||
defer file.close();
|
||||
|
||||
return try file.readToEndAlloc(allocator, 1024 * 1024);
|
||||
}
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v = obj.get(key) orelse return null;
|
||||
if (v != .string) return null;
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage:\n", .{});
|
||||
colors.printInfo(" ml requeue <commit_id|run_id|task_id|path> [--name <job>] [--priority <n>] [--cpu <n>] [--memory <gb>] [--gpu <n>] [--gpu-memory <gb>] [--args <string>] [--note <string>] -- <args...>\n", .{});
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
const std = @import("std");
|
||||
const c = @cImport(@cInclude("time.h"));
|
||||
const Config = @import("../config.zig").Config;
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const errors = @import("../errors.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ const Config = @import("../config.zig").Config;
|
|||
const crypto = @import("../utils/crypto.zig");
|
||||
const rsync = @import("../utils/rsync_embedded.zig"); // Use embedded rsync
|
||||
const storage = @import("../utils/storage.zig");
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
const Config = @import("../config.zig").Config;
|
||||
const ws = @import("../net/ws.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
|
||||
pub const Options = struct {
|
||||
json: bool = false,
|
||||
|
|
@ -78,7 +79,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
defer allocator.free(msg);
|
||||
|
||||
const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch {
|
||||
std.debug.print("{s}\n", .{msg});
|
||||
if (opts.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
} else {
|
||||
std.debug.print("{s}\n", .{msg});
|
||||
}
|
||||
return error.InvalidPacket;
|
||||
};
|
||||
defer {
|
||||
|
|
@ -101,7 +107,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
|
||||
const payload = packet.data_payload.?;
|
||||
if (opts.json) {
|
||||
std.debug.print("{s}\n", .{payload});
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{payload});
|
||||
} else {
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{});
|
||||
defer parsed.deinit();
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const rsync = @import("../utils/rsync.zig");
|
||||
const ws = @import("../net/ws.zig");
|
||||
const rsync = @import("../utils/rsync_embedded.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("utils/colors.zig");
|
||||
const ws = @import("net/ws.zig");
|
||||
const crypto = @import("utils/crypto.zig");
|
||||
|
||||
/// User-friendly error types for CLI
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const Command = enum {
|
|||
jupyter,
|
||||
init,
|
||||
sync,
|
||||
requeue,
|
||||
queue,
|
||||
status,
|
||||
monitor,
|
||||
|
|
@ -78,6 +79,14 @@ pub fn main() !void {
|
|||
command_found = true;
|
||||
try @import("commands/info.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'a' => if (std.mem.eql(u8, command, "annotate")) {
|
||||
command_found = true;
|
||||
try @import("commands/annotate.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'n' => if (std.mem.eql(u8, command, "narrative")) {
|
||||
command_found = true;
|
||||
try @import("commands/narrative.zig").run(allocator, args[2..]);
|
||||
},
|
||||
's' => if (std.mem.eql(u8, command, "sync")) {
|
||||
command_found = true;
|
||||
if (args.len < 3) {
|
||||
|
|
@ -89,6 +98,10 @@ pub fn main() !void {
|
|||
command_found = true;
|
||||
try @import("commands/status.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'r' => if (std.mem.eql(u8, command, "requeue")) {
|
||||
command_found = true;
|
||||
try @import("commands/requeue.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'q' => if (std.mem.eql(u8, command, "queue")) {
|
||||
command_found = true;
|
||||
try @import("commands/queue.zig").run(allocator, args[2..]);
|
||||
|
|
@ -127,8 +140,11 @@ fn printUsage() void {
|
|||
std.debug.print("Commands:\n", .{});
|
||||
std.debug.print(" jupyter Jupyter workspace management\n", .{});
|
||||
std.debug.print(" init Setup configuration interactively\n", .{});
|
||||
std.debug.print(" annotate <path|id> Add an annotation to run_manifest.json (--note \"...\")\n", .{});
|
||||
std.debug.print(" narrative set <path|id> Set run narrative fields (hypothesis/context/...)\n", .{});
|
||||
std.debug.print(" info <path|id> Show run info from run_manifest.json (optionally --base <path>)\n", .{});
|
||||
std.debug.print(" sync <path> Sync project to server\n", .{});
|
||||
std.debug.print(" requeue <id> Re-submit from run_id/task_id/path (supports -- <args>)\n", .{});
|
||||
std.debug.print(" queue (q) <job> Queue job for execution\n", .{});
|
||||
std.debug.print(" status Get system status\n", .{});
|
||||
std.debug.print(" monitor Launch TUI via SSH\n", .{});
|
||||
|
|
@ -143,4 +159,7 @@ fn printUsage() void {
|
|||
|
||||
test {
|
||||
_ = @import("commands/info.zig");
|
||||
_ = @import("commands/requeue.zig");
|
||||
_ = @import("commands/annotate.zig");
|
||||
_ = @import("commands/narrative.zig");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
// Network module - exports all network modules
|
||||
pub const protocol = @import("net/protocol.zig");
|
||||
pub const ws = @import("net/ws.zig");
|
||||
pub const ws_client = @import("net/ws/client.zig");
|
||||
pub const ws_opcodes = @import("net/ws/opcodes.zig");
|
||||
|
|
|
|||
1716
cli/src/net/ws.zig
1716
cli/src/net/ws.zig
File diff suppressed because it is too large
Load diff
1256
cli/src/net/ws/client.zig
Normal file
1256
cli/src/net/ws/client.zig
Normal file
File diff suppressed because it is too large
Load diff
8
cli/src/net/ws/deps.zig
Normal file
8
cli/src/net/ws/deps.zig
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
pub const std = @import("std");
|
||||
|
||||
pub const colors = @import("../../utils/colors.zig");
|
||||
pub const crypto = @import("../../utils/crypto.zig");
|
||||
pub const io = @import("../../utils/io.zig");
|
||||
pub const log = @import("../../utils/logging.zig");
|
||||
|
||||
pub const protocol = @import("../protocol.zig");
|
||||
71
cli/src/net/ws/frame.zig
Normal file
71
cli/src/net/ws/frame.zig
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
const std = @import("std");
|
||||
|
||||
pub fn sendWebSocketFrame(stream: std.net.Stream, payload: []const u8) !void {
|
||||
var frame: [14]u8 = undefined;
|
||||
var frame_len: usize = 2;
|
||||
|
||||
// FIN=1, opcode=0x2 (binary), MASK=1
|
||||
frame[0] = 0x82 | 0x80;
|
||||
|
||||
// Payload length
|
||||
if (payload.len < 126) {
|
||||
frame[1] = @as(u8, @intCast(payload.len)) | 0x80;
|
||||
} else if (payload.len < 65536) {
|
||||
frame[1] = 126 | 0x80;
|
||||
frame[2] = @intCast(payload.len >> 8);
|
||||
frame[3] = @intCast(payload.len & 0xFF);
|
||||
frame_len = 4;
|
||||
} else {
|
||||
return error.PayloadTooLarge;
|
||||
}
|
||||
|
||||
// Generate random mask (4 bytes)
|
||||
var mask: [4]u8 = undefined;
|
||||
var i: usize = 0;
|
||||
while (i < 4) : (i += 1) {
|
||||
mask[i] = @as(u8, @intCast(@mod(std.time.timestamp(), 256)));
|
||||
}
|
||||
|
||||
@memcpy(frame[frame_len .. frame_len + 4], &mask);
|
||||
frame_len += 4;
|
||||
|
||||
_ = try stream.write(frame[0..frame_len]);
|
||||
|
||||
var masked_payload = try std.heap.page_allocator.alloc(u8, payload.len);
|
||||
defer std.heap.page_allocator.free(masked_payload);
|
||||
|
||||
for (payload, 0..) |byte, j| {
|
||||
masked_payload[j] = byte ^ mask[j % 4];
|
||||
}
|
||||
|
||||
_ = try stream.write(masked_payload);
|
||||
}
|
||||
|
||||
pub fn receiveBinaryMessage(stream: std.net.Stream, allocator: std.mem.Allocator) ![]u8 {
|
||||
var header: [2]u8 = undefined;
|
||||
const header_bytes = try stream.read(&header);
|
||||
if (header_bytes < 2) return error.ConnectionClosed;
|
||||
|
||||
if (header[0] != 0x82) return error.InvalidFrame;
|
||||
|
||||
var payload_len: usize = header[1];
|
||||
if (payload_len == 126) {
|
||||
var len_bytes: [2]u8 = undefined;
|
||||
_ = try stream.read(&len_bytes);
|
||||
payload_len = (@as(usize, len_bytes[0]) << 8) | len_bytes[1];
|
||||
} else if (payload_len == 127) {
|
||||
return error.PayloadTooLarge;
|
||||
}
|
||||
|
||||
const payload = try allocator.alloc(u8, payload_len);
|
||||
errdefer allocator.free(payload);
|
||||
|
||||
var bytes_read: usize = 0;
|
||||
while (bytes_read < payload_len) {
|
||||
const n = try stream.read(payload[bytes_read..]);
|
||||
if (n == 0) return error.ConnectionClosed;
|
||||
bytes_read += n;
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
114
cli/src/net/ws/handshake.zig
Normal file
114
cli/src/net/ws/handshake.zig
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
const std = @import("std");
|
||||
|
||||
fn generateWebSocketKey(allocator: std.mem.Allocator) ![]u8 {
|
||||
var random_bytes: [16]u8 = undefined;
|
||||
std.crypto.random.bytes(&random_bytes);
|
||||
|
||||
const base64 = std.base64.standard.Encoder;
|
||||
const result = try allocator.alloc(u8, base64.calcSize(random_bytes.len));
|
||||
_ = base64.encode(result, &random_bytes);
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn handshake(
|
||||
allocator: std.mem.Allocator,
|
||||
stream: std.net.Stream,
|
||||
host: []const u8,
|
||||
url: []const u8,
|
||||
api_key: []const u8,
|
||||
) !void {
|
||||
const key = try generateWebSocketKey(allocator);
|
||||
defer allocator.free(key);
|
||||
|
||||
const request = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {s}\r\n" ++
|
||||
"Upgrade: websocket\r\n" ++
|
||||
"Connection: Upgrade\r\n" ++
|
||||
"Sec-WebSocket-Key: {s}\r\n" ++
|
||||
"Sec-WebSocket-Version: 13\r\n" ++
|
||||
"X-API-Key: {s}\r\n" ++
|
||||
"\r\n",
|
||||
.{ url, host, key, api_key },
|
||||
);
|
||||
defer allocator.free(request);
|
||||
|
||||
_ = try stream.write(request);
|
||||
|
||||
var response_buf: [4096]u8 = undefined;
|
||||
var bytes_read: usize = 0;
|
||||
var header_complete = false;
|
||||
|
||||
while (!header_complete and bytes_read < response_buf.len - 1) {
|
||||
const chunk_bytes = try stream.read(response_buf[bytes_read..]);
|
||||
if (chunk_bytes == 0) break;
|
||||
bytes_read += chunk_bytes;
|
||||
|
||||
if (std.mem.indexOf(u8, response_buf[0..bytes_read], "\r\n\r\n") != null) {
|
||||
header_complete = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = response_buf[0..bytes_read];
|
||||
|
||||
if (std.mem.indexOf(u8, response, "101 Switching Protocols") == null) {
|
||||
if (std.mem.indexOf(u8, response, "404 Not Found") != null) {
|
||||
std.debug.print("\n❌ WebSocket Connection Failed\n", .{});
|
||||
std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
|
||||
std.debug.print("The WebSocket endpoint '/ws' was not found on the server.\n\n", .{});
|
||||
std.debug.print("This usually means:\n", .{});
|
||||
std.debug.print(" • API server is not running\n", .{});
|
||||
std.debug.print(" • Incorrect server address in config\n", .{});
|
||||
std.debug.print(" • Different service running on that port\n\n", .{});
|
||||
std.debug.print("To diagnose:\n", .{});
|
||||
std.debug.print(" • Verify server address: Check ~/.ml/config.toml\n", .{});
|
||||
std.debug.print(" • Test connectivity: curl http://<server>:<port>/health\n", .{});
|
||||
std.debug.print(" • Contact your server administrator if the issue persists\n\n", .{});
|
||||
return error.EndpointNotFound;
|
||||
} else if (std.mem.indexOf(u8, response, "401 Unauthorized") != null) {
|
||||
std.debug.print("\n❌ Authentication Failed\n", .{});
|
||||
std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
|
||||
std.debug.print("Invalid or missing API key.\n\n", .{});
|
||||
std.debug.print("To fix:\n", .{});
|
||||
std.debug.print(" • Verify API key in ~/.ml/config.toml matches server configuration\n", .{});
|
||||
std.debug.print(" • Request a new API key from your administrator if needed\n\n", .{});
|
||||
return error.AuthenticationFailed;
|
||||
} else if (std.mem.indexOf(u8, response, "403 Forbidden") != null) {
|
||||
std.debug.print("\n❌ Access Denied\n", .{});
|
||||
std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
|
||||
std.debug.print("Your API key doesn't have permission for this operation.\n\n", .{});
|
||||
std.debug.print("To fix:\n", .{});
|
||||
std.debug.print(" • Contact your administrator to grant necessary permissions\n\n", .{});
|
||||
return error.PermissionDenied;
|
||||
} else if (std.mem.indexOf(u8, response, "503 Service Unavailable") != null) {
|
||||
std.debug.print("\n❌ Server Unavailable\n", .{});
|
||||
std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
|
||||
std.debug.print("The server is temporarily unavailable.\n\n", .{});
|
||||
std.debug.print("This could be due to:\n", .{});
|
||||
std.debug.print(" • Server maintenance\n", .{});
|
||||
std.debug.print(" • High load\n", .{});
|
||||
std.debug.print(" • Server restart\n\n", .{});
|
||||
std.debug.print("To resolve:\n", .{});
|
||||
std.debug.print(" • Wait a moment and try again\n", .{});
|
||||
std.debug.print(" • Contact administrator if the issue persists\n\n", .{});
|
||||
return error.ServerUnavailable;
|
||||
} else {
|
||||
std.debug.print("\n❌ WebSocket Handshake Failed\n", .{});
|
||||
std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
|
||||
std.debug.print("Expected HTTP 101 Switching Protocols, but received:\n", .{});
|
||||
|
||||
const newline_pos = std.mem.indexOf(u8, response, "\r\n") orelse response.len;
|
||||
const status_line = response[0..newline_pos];
|
||||
std.debug.print(" {s}\n\n", .{status_line});
|
||||
|
||||
std.debug.print("To diagnose:\n", .{});
|
||||
std.debug.print(" • Verify server address in ~/.ml/config.toml\n", .{});
|
||||
std.debug.print(" • Check network connectivity to the server\n", .{});
|
||||
std.debug.print(" • Contact your administrator for assistance\n\n", .{});
|
||||
return error.HandshakeFailed;
|
||||
}
|
||||
}
|
||||
|
||||
std.posix.nanosleep(0, 10 * std.time.ns_per_ms);
|
||||
}
|
||||
73
cli/src/net/ws/opcode.zig
Normal file
73
cli/src/net/ws/opcode.zig
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
pub const Opcode = enum(u8) {
|
||||
queue_job = 0x01,
|
||||
queue_job_with_tracking = 0x0C,
|
||||
queue_job_with_snapshot = 0x17,
|
||||
queue_job_with_args = 0x1A,
|
||||
queue_job_with_note = 0x1B,
|
||||
annotate_run = 0x1C,
|
||||
set_run_narrative = 0x1D,
|
||||
status_request = 0x02,
|
||||
cancel_job = 0x03,
|
||||
prune = 0x04,
|
||||
crash_report = 0x05,
|
||||
log_metric = 0x0A,
|
||||
get_experiment = 0x0B,
|
||||
start_jupyter = 0x0D,
|
||||
stop_jupyter = 0x0E,
|
||||
remove_jupyter = 0x18,
|
||||
restore_jupyter = 0x19,
|
||||
list_jupyter = 0x0F,
|
||||
list_jupyter_packages = 0x1E,
|
||||
|
||||
validate_request = 0x16,
|
||||
|
||||
// Dataset management opcodes
|
||||
dataset_list = 0x06,
|
||||
dataset_register = 0x07,
|
||||
dataset_info = 0x08,
|
||||
dataset_search = 0x09,
|
||||
|
||||
// Structured response opcodes
|
||||
response_success = 0x10,
|
||||
response_error = 0x11,
|
||||
response_progress = 0x12,
|
||||
response_status = 0x13,
|
||||
response_data = 0x14,
|
||||
response_log = 0x15,
|
||||
};
|
||||
|
||||
pub const ValidateTargetType = enum(u8) {
|
||||
commit_id = 0,
|
||||
task_id = 1,
|
||||
};
|
||||
|
||||
pub const queue_job = Opcode.queue_job;
|
||||
pub const queue_job_with_tracking = Opcode.queue_job_with_tracking;
|
||||
pub const queue_job_with_snapshot = Opcode.queue_job_with_snapshot;
|
||||
pub const queue_job_with_args = Opcode.queue_job_with_args;
|
||||
pub const queue_job_with_note = Opcode.queue_job_with_note;
|
||||
pub const annotate_run = Opcode.annotate_run;
|
||||
pub const set_run_narrative = Opcode.set_run_narrative;
|
||||
pub const status_request = Opcode.status_request;
|
||||
pub const cancel_job = Opcode.cancel_job;
|
||||
pub const prune = Opcode.prune;
|
||||
pub const crash_report = Opcode.crash_report;
|
||||
pub const log_metric = Opcode.log_metric;
|
||||
pub const get_experiment = Opcode.get_experiment;
|
||||
pub const start_jupyter = Opcode.start_jupyter;
|
||||
pub const stop_jupyter = Opcode.stop_jupyter;
|
||||
pub const remove_jupyter = Opcode.remove_jupyter;
|
||||
pub const restore_jupyter = Opcode.restore_jupyter;
|
||||
pub const list_jupyter = Opcode.list_jupyter;
|
||||
pub const list_jupyter_packages = Opcode.list_jupyter_packages;
|
||||
pub const validate_request = Opcode.validate_request;
|
||||
pub const dataset_list = Opcode.dataset_list;
|
||||
pub const dataset_register = Opcode.dataset_register;
|
||||
pub const dataset_info = Opcode.dataset_info;
|
||||
pub const dataset_search = Opcode.dataset_search;
|
||||
pub const response_success = Opcode.response_success;
|
||||
pub const response_error = Opcode.response_error;
|
||||
pub const response_progress = Opcode.response_progress;
|
||||
pub const response_status = Opcode.response_status;
|
||||
pub const response_data = Opcode.response_data;
|
||||
pub const response_log = Opcode.response_log;
|
||||
2
cli/src/net/ws/opcodes.zig
Normal file
2
cli/src/net/ws/opcodes.zig
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub const Opcode = @import("opcode.zig").Opcode;
|
||||
pub const ValidateTargetType = @import("opcode.zig").ValidateTargetType;
|
||||
21
cli/src/net/ws/resolve.zig
Normal file
21
cli/src/net/ws/resolve.zig
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
const std = @import("std");
|
||||
|
||||
pub fn resolveHostAddress(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address {
|
||||
return std.net.Address.parseIp(host, port) catch |err| switch (err) {
|
||||
error.InvalidIPAddressFormat => resolveHostname(allocator, host, port),
|
||||
else => return err,
|
||||
};
|
||||
}
|
||||
|
||||
fn resolveHostname(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address {
|
||||
var address_list = try std.net.getAddressList(allocator, host, port);
|
||||
defer address_list.deinit();
|
||||
|
||||
if (address_list.addrs.len == 0) return error.HostResolutionFailed;
|
||||
|
||||
return address_list.addrs[0];
|
||||
}
|
||||
|
||||
test "resolve hostnames for WebSocket connections" {
|
||||
_ = try resolveHostAddress(std.testing.allocator, "localhost", 9100);
|
||||
}
|
||||
74
cli/src/net/ws/response.zig
Normal file
74
cli/src/net/ws/response.zig
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
const std = @import("std");
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) {
|
||||
return null;
|
||||
}
|
||||
const v = v_opt.?;
|
||||
if (v != .string) {
|
||||
return null;
|
||||
}
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) {
|
||||
return null;
|
||||
}
|
||||
const v = v_opt.?;
|
||||
if (v != .integer) {
|
||||
return null;
|
||||
}
|
||||
return v.integer;
|
||||
}
|
||||
|
||||
pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 {
|
||||
const prewarm_val_opt = root.get("prewarm");
|
||||
if (prewarm_val_opt == null) {
|
||||
return null;
|
||||
}
|
||||
const prewarm_val = prewarm_val_opt.?;
|
||||
if (prewarm_val != .array) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const items = prewarm_val.array.items;
|
||||
if (items.len == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
var out = std.ArrayList(u8){};
|
||||
errdefer out.deinit(allocator);
|
||||
|
||||
const writer = out.writer(allocator);
|
||||
try writer.writeAll("Prewarm:\n");
|
||||
|
||||
for (items) |item| {
|
||||
if (item != .object) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const obj = item.object;
|
||||
|
||||
const worker_id = jsonGetString(obj, "worker_id") orelse "";
|
||||
const task_id = jsonGetString(obj, "task_id") orelse "";
|
||||
const phase = jsonGetString(obj, "phase") orelse "";
|
||||
const started_at = jsonGetString(obj, "started_at") orelse "";
|
||||
const dataset_count = jsonGetInt(obj, "dataset_count") orelse 0;
|
||||
const snapshot_id = jsonGetString(obj, "snapshot_id") orelse "";
|
||||
const env_image = jsonGetString(obj, "env_image") orelse "";
|
||||
const env_hit = jsonGetInt(obj, "env_hit") orelse 0;
|
||||
const env_miss = jsonGetInt(obj, "env_miss") orelse 0;
|
||||
const env_built = jsonGetInt(obj, "env_built") orelse 0;
|
||||
|
||||
try writer.print(
|
||||
" worker={s} task={s} phase={s} datasets={d} snapshot={s} env={s} env_hit={d} env_miss={d} env_built={d} started={s}\n",
|
||||
.{ worker_id, task_id, phase, dataset_count, snapshot_id, env_image, env_hit, env_miss, env_built, started_at },
|
||||
);
|
||||
}
|
||||
|
||||
const owned = try out.toOwnedSlice(allocator);
|
||||
return owned;
|
||||
}
|
||||
449
cli/src/net/ws/response_handlers.zig
Normal file
449
cli/src/net/ws/response_handlers.zig
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
const deps = @import("deps.zig");
|
||||
const std = deps.std;
|
||||
const io = deps.io;
|
||||
const protocol = deps.protocol;
|
||||
const colors = deps.colors;
|
||||
const Client = @import("client.zig").Client;
|
||||
const utils = @import("utils.zig");
|
||||
|
||||
/// Receive and handle status response with user filtering
|
||||
pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void {
|
||||
_ = user_context; // TODO: Use for filtering
|
||||
const message = try self.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
// Check if message is JSON (or contains JSON) or plain text
|
||||
if (message[0] == '{') {
|
||||
// Parse JSON response
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{});
|
||||
defer parsed.deinit();
|
||||
const root = parsed.value.object;
|
||||
|
||||
if (options.json) {
|
||||
// Output raw JSON
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{message});
|
||||
} else {
|
||||
// Display user info
|
||||
if (root.get("user")) |user_obj| {
|
||||
const user = user_obj.object;
|
||||
const name = user.get("name").?.string;
|
||||
const admin = user.get("admin").?.bool;
|
||||
colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin });
|
||||
}
|
||||
|
||||
// Display task summary
|
||||
if (root.get("tasks")) |tasks_obj| {
|
||||
const tasks = tasks_obj.object;
|
||||
const total = tasks.get("total").?.integer;
|
||||
const queued = tasks.get("queued").?.integer;
|
||||
const running = tasks.get("running").?.integer;
|
||||
const failed = tasks.get("failed").?.integer;
|
||||
const completed = tasks.get("completed").?.integer;
|
||||
colors.printInfo(
|
||||
"Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n",
|
||||
.{ total, queued, running, failed, completed },
|
||||
);
|
||||
}
|
||||
|
||||
const per_section_limit: usize = options.limit orelse 5;
|
||||
|
||||
const TaskStatus = enum { queued, running, failed, completed };
|
||||
|
||||
const TaskPrinter = struct {
|
||||
fn statusLabel(s: TaskStatus) []const u8 {
|
||||
return switch (s) {
|
||||
.queued => "Queued",
|
||||
.running => "Running",
|
||||
.failed => "Failed",
|
||||
.completed => "Completed",
|
||||
};
|
||||
}
|
||||
|
||||
fn statusMatch(s: TaskStatus) []const u8 {
|
||||
return switch (s) {
|
||||
.queued => "queued",
|
||||
.running => "running",
|
||||
.failed => "failed",
|
||||
.completed => "completed",
|
||||
};
|
||||
}
|
||||
|
||||
fn shorten(s: []const u8, max_len: usize) []const u8 {
|
||||
if (s.len <= max_len) return s;
|
||||
return s[0..max_len];
|
||||
}
|
||||
|
||||
fn printSection(
|
||||
allocator2: std.mem.Allocator,
|
||||
queue_items: []const std.json.Value,
|
||||
status: TaskStatus,
|
||||
limit2: usize,
|
||||
) !void {
|
||||
_ = allocator2;
|
||||
const label = statusLabel(status);
|
||||
const want = statusMatch(status);
|
||||
std.debug.print("\n{s}:\n", .{label});
|
||||
|
||||
var shown: usize = 0;
|
||||
for (queue_items) |item| {
|
||||
if (item != .object) continue;
|
||||
const obj = item.object;
|
||||
const st = utils.jsonGetString(obj, "status") orelse "";
|
||||
if (!std.mem.eql(u8, st, want)) continue;
|
||||
|
||||
const id = utils.jsonGetString(obj, "id") orelse "";
|
||||
const job_name = utils.jsonGetString(obj, "job_name") orelse "";
|
||||
const worker_id = utils.jsonGetString(obj, "worker_id") orelse "";
|
||||
const err = utils.jsonGetString(obj, "error") orelse "";
|
||||
|
||||
if (std.mem.eql(u8, want, "failed")) {
|
||||
colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name });
|
||||
if (worker_id.len > 0) {
|
||||
std.debug.print(" (worker={s})", .{worker_id});
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
if (err.len > 0) {
|
||||
std.debug.print(" error: {s}\n", .{shorten(err, 160)});
|
||||
}
|
||||
} else if (std.mem.eql(u8, want, "running")) {
|
||||
colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name });
|
||||
if (worker_id.len > 0) {
|
||||
std.debug.print(" (worker={s})", .{worker_id});
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
} else if (std.mem.eql(u8, want, "queued")) {
|
||||
std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name });
|
||||
} else {
|
||||
colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name });
|
||||
}
|
||||
|
||||
shown += 1;
|
||||
if (shown >= limit2) break;
|
||||
}
|
||||
|
||||
if (shown == 0) {
|
||||
std.debug.print(" (none)\n", .{});
|
||||
} else {
|
||||
// Indicate there may be more.
|
||||
var total_for_status: usize = 0;
|
||||
for (queue_items) |item| {
|
||||
if (item != .object) continue;
|
||||
const obj = item.object;
|
||||
const st = utils.jsonGetString(obj, "status") orelse "";
|
||||
if (std.mem.eql(u8, st, want)) total_for_status += 1;
|
||||
}
|
||||
if (total_for_status > shown) {
|
||||
std.debug.print(" ... and {d} more\n", .{total_for_status - shown});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (root.get("queue")) |queue_val| {
|
||||
if (queue_val == .array) {
|
||||
const items = queue_val.array.items;
|
||||
try TaskPrinter.printSection(allocator, items, .queued, per_section_limit);
|
||||
try TaskPrinter.printSection(allocator, items, .running, per_section_limit);
|
||||
try TaskPrinter.printSection(allocator, items, .failed, per_section_limit);
|
||||
try TaskPrinter.printSection(allocator, items, .completed, per_section_limit);
|
||||
}
|
||||
}
|
||||
|
||||
if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| {
|
||||
defer allocator.free(section);
|
||||
colors.printInfo("{s}", .{section});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle plain text response - filter out non-printable characters
|
||||
var clean_msg = allocator.alloc(u8, message.len) catch {
|
||||
if (options.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
|
||||
} else {
|
||||
std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
|
||||
}
|
||||
return;
|
||||
};
|
||||
defer allocator.free(clean_msg);
|
||||
|
||||
var clean_len: usize = 0;
|
||||
for (message) |byte| {
|
||||
// Skip WebSocket frame header bytes and non-printable chars
|
||||
if (byte >= 32 and byte <= 126) { // printable ASCII only
|
||||
clean_msg[clean_len] = byte;
|
||||
clean_len += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Look for common error messages in the cleaned data
|
||||
if (clean_len > 0) {
|
||||
const cleaned = clean_msg[0..clean_len];
|
||||
if (options.json) {
|
||||
if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{});
|
||||
} else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"authentication_failed\"}}\n", .{});
|
||||
} else {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned});
|
||||
}
|
||||
} else {
|
||||
if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
|
||||
std.debug.print("Insufficient permissions to view jobs\n", .{});
|
||||
} else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
|
||||
std.debug.print("Authentication failed\n", .{});
|
||||
} else {
|
||||
std.debug.print("Server response: {s}\n", .{cleaned});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (options.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
|
||||
} else {
|
||||
std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive and handle cancel response with user permissions
|
||||
pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void {
|
||||
const message = try self.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
// Check if message is JSON or plain text
|
||||
if (message[0] == '{') {
|
||||
// Parse JSON response
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{});
|
||||
defer parsed.deinit();
|
||||
const root = parsed.value.object;
|
||||
|
||||
if (options.json) {
|
||||
// Output raw JSON
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{message});
|
||||
} else {
|
||||
// Display user-friendly output
|
||||
if (root.get("success")) |success_val| {
|
||||
if (success_val.bool) {
|
||||
colors.printSuccess("Job '{s}' canceled successfully\n", .{job_name});
|
||||
} else {
|
||||
colors.printError("Failed to cancel job '{s}'\n", .{job_name});
|
||||
if (root.get("error")) |error_val| {
|
||||
colors.printError("Error: {s}\n", .{error_val.string});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name });
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle plain text response - filter out non-printable characters
|
||||
var clean_msg = allocator.alloc(u8, message.len) catch {
|
||||
if (options.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
|
||||
} else {
|
||||
std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
|
||||
}
|
||||
return;
|
||||
};
|
||||
defer allocator.free(clean_msg);
|
||||
|
||||
var clean_len: usize = 0;
|
||||
for (message) |byte| {
|
||||
// Skip WebSocket frame header bytes and non-printable chars
|
||||
if (byte >= 32 and byte <= 126) { // printable ASCII only
|
||||
clean_msg[clean_len] = byte;
|
||||
clean_len += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (clean_len > 0) {
|
||||
const cleaned = clean_msg[0..clean_len];
|
||||
if (options.json) {
|
||||
if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{});
|
||||
} else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"authentication_failed\"}}\n", .{});
|
||||
} else {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned});
|
||||
}
|
||||
} else {
|
||||
if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
|
||||
std.debug.print("Insufficient permissions to cancel job\n", .{});
|
||||
} else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
|
||||
std.debug.print("Authentication failed\n", .{});
|
||||
} else {
|
||||
colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name });
|
||||
colors.printInfo("Response: {s}\n", .{cleaned});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (options.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
|
||||
} else {
|
||||
std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle response packet with appropriate display
|
||||
pub fn handleResponsePacket(self: *Client, packet: protocol.ResponsePacket, operation: []const u8) !void {
|
||||
switch (packet.packet_type) {
|
||||
.success => {
|
||||
if (packet.success_message) |msg| {
|
||||
if (msg.len > 0) {
|
||||
std.debug.print("✓ {s}: {s}\n", .{ operation, msg });
|
||||
} else {
|
||||
std.debug.print("✓ {s} completed successfully\n", .{operation});
|
||||
}
|
||||
} else {
|
||||
std.debug.print("✓ {s} completed successfully\n", .{operation});
|
||||
}
|
||||
},
|
||||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
std.debug.print("✗ {s} failed: {s}\n", .{ operation, error_msg });
|
||||
|
||||
if (packet.error_message) |msg| {
|
||||
if (msg.len > 0) {
|
||||
std.debug.print("Details: {s}\n", .{msg});
|
||||
}
|
||||
}
|
||||
|
||||
if (packet.error_details) |details| {
|
||||
if (details.len > 0) {
|
||||
std.debug.print("Additional info: {s}\n", .{details});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to appropriate CLI error
|
||||
return convertServerError(self, packet.error_code.?);
|
||||
},
|
||||
.progress => {
|
||||
if (packet.progress_type) |ptype| {
|
||||
switch (ptype) {
|
||||
.percentage => {
|
||||
const percentage = packet.progress_value.?;
|
||||
if (packet.progress_total) |total| {
|
||||
std.debug.print("Progress: {d}/{d} ({d:.1}%)\n", .{ percentage, total, @as(f32, @floatFromInt(percentage)) * 100.0 / @as(f32, @floatFromInt(total)) });
|
||||
} else {
|
||||
std.debug.print("Progress: {d}%\n", .{percentage});
|
||||
}
|
||||
},
|
||||
.stage => {
|
||||
if (packet.progress_message) |msg| {
|
||||
std.debug.print("Stage: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
.message => {
|
||||
if (packet.progress_message) |msg| {
|
||||
std.debug.print("Info: {s}\n", .{msg});
|
||||
}
|
||||
},
|
||||
.bytes_transferred => {
|
||||
const bytes = packet.progress_value.?;
|
||||
if (packet.progress_total) |total| {
|
||||
const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0;
|
||||
const total_mb = @as(f64, @floatFromInt(total)) / 1024.0 / 1024.0;
|
||||
std.debug.print("Transferred: {d:.2} MB / {d:.2} MB\n", .{ transferred_mb, total_mb });
|
||||
} else {
|
||||
const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0;
|
||||
std.debug.print("Transferred: {d:.2} MB\n", .{transferred_mb});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
.status => {
|
||||
if (packet.status_data) |data| {
|
||||
std.debug.print("Status: {s}\n", .{data});
|
||||
}
|
||||
},
|
||||
.data => {
|
||||
if (packet.data_type) |dtype| {
|
||||
std.debug.print("Data [{s}]: ", .{dtype});
|
||||
if (packet.data_payload) |payload| {
|
||||
// Try to display as string if it looks like text
|
||||
const is_text = for (payload) |byte| {
|
||||
if (byte < 32 and byte != '\n' and byte != '\r' and byte != '\t') break false;
|
||||
} else true;
|
||||
|
||||
if (is_text) {
|
||||
std.debug.print("{s}\n", .{payload});
|
||||
} else {
|
||||
std.debug.print("{d} bytes\n", .{payload.len});
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
.log => {
|
||||
if (packet.log_level) |level| {
|
||||
const level_name = protocol.ResponsePacket.getLogLevelName(level);
|
||||
if (packet.log_message) |msg| {
|
||||
std.debug.print("[{s}] {s}\n", .{ level_name, msg });
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert server error code to CLI error
|
||||
fn convertServerError(self: *Client, server_error: protocol.ErrorCode) anyerror {
|
||||
_ = self;
|
||||
return switch (server_error) {
|
||||
.authentication_failed => error.AuthenticationFailed,
|
||||
.permission_denied => error.PermissionDenied,
|
||||
.resource_not_found => error.JobNotFound,
|
||||
.resource_already_exists => error.ResourceExists,
|
||||
.timeout => error.RequestTimeout,
|
||||
.server_overloaded, .service_unavailable => error.ServerUnreachable,
|
||||
.invalid_request => error.InvalidArguments,
|
||||
.job_not_found => error.JobNotFound,
|
||||
.job_already_running => error.JobAlreadyRunning,
|
||||
.job_failed_to_start, .job_execution_failed => error.CommandFailed,
|
||||
.job_cancelled => error.JobCancelled,
|
||||
else => error.ServerError,
|
||||
};
|
||||
}
|
||||
|
||||
/// Clean up packet allocated memory
|
||||
pub fn cleanupPacket(self: *Client, packet: protocol.ResponsePacket) void {
|
||||
if (packet.success_message) |msg| {
|
||||
self.allocator.free(msg);
|
||||
}
|
||||
if (packet.error_message) |msg| {
|
||||
self.allocator.free(msg);
|
||||
}
|
||||
if (packet.error_details) |details| {
|
||||
self.allocator.free(details);
|
||||
}
|
||||
if (packet.progress_message) |msg| {
|
||||
self.allocator.free(msg);
|
||||
}
|
||||
if (packet.status_data) |data| {
|
||||
self.allocator.free(data);
|
||||
}
|
||||
if (packet.data_type) |dtype| {
|
||||
self.allocator.free(dtype);
|
||||
}
|
||||
if (packet.data_payload) |payload| {
|
||||
self.allocator.free(payload);
|
||||
}
|
||||
if (packet.log_message) |msg| {
|
||||
self.allocator.free(msg);
|
||||
}
|
||||
}
|
||||
25
cli/src/net/ws/utils.zig
Normal file
25
cli/src/net/ws/utils.zig
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
const std = @import("std");
|
||||
|
||||
pub fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) {
|
||||
return null;
|
||||
}
|
||||
const v = v_opt.?;
|
||||
if (v != .string) {
|
||||
return null;
|
||||
}
|
||||
return v.string;
|
||||
}
|
||||
|
||||
pub fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) {
|
||||
return null;
|
||||
}
|
||||
const v = v_opt.?;
|
||||
if (v != .integer) {
|
||||
return null;
|
||||
}
|
||||
return v.integer;
|
||||
}
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
// Utils module - exports all utility modules
|
||||
// Utilities module - exports all utility modules
|
||||
pub const colors = @import("utils/colors.zig");
|
||||
pub const crypto = @import("utils/crypto.zig");
|
||||
pub const history = @import("utils/history.zig");
|
||||
pub const io = @import("utils/io.zig");
|
||||
pub const logging = @import("utils/logging.zig");
|
||||
pub const rsync = @import("utils/rsync.zig");
|
||||
pub const rsync_embedded = @import("utils/rsync_embedded.zig");
|
||||
pub const rsync_embedded_binary = @import("utils/rsync_embedded_binary.zig");
|
||||
pub const storage = @import("utils/storage.zig");
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ fn hexNibble(c: u8) ?u8 {
|
|||
pub fn decodeHex(allocator: std.mem.Allocator, hex: []const u8) ![]u8 {
|
||||
if ((hex.len % 2) != 0) return error.InvalidHex;
|
||||
const out = try allocator.alloc(u8, hex.len / 2);
|
||||
errdefer allocator.free(out);
|
||||
var i: usize = 0;
|
||||
while (i < out.len) : (i += 1) {
|
||||
const hi = hexNibble(hex[i * 2]) orelse return error.InvalidHex;
|
||||
|
|
|
|||
59
cli/src/utils/io.zig
Normal file
59
cli/src/utils/io.zig
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
const std = @import("std");
|
||||
|
||||
fn writeAllFd(fd: std.posix.fd_t, data: []const u8) std.Io.Writer.Error!void {
|
||||
var off: usize = 0;
|
||||
while (off < data.len) {
|
||||
const n = std.posix.write(fd, data[off..]) catch return error.WriteFailed;
|
||||
if (n == 0) return error.WriteFailed;
|
||||
off += n;
|
||||
}
|
||||
}
|
||||
|
||||
fn drainStdout(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
|
||||
_ = w;
|
||||
if (data.len == 0) return 0;
|
||||
var written: usize = 0;
|
||||
for (data) |chunk| {
|
||||
try writeAllFd(std.posix.STDOUT_FILENO, chunk);
|
||||
written += chunk.len;
|
||||
}
|
||||
if (splat > 0) {
|
||||
const last = data[data.len - 1];
|
||||
var i: usize = 0;
|
||||
while (i < splat) : (i += 1) {
|
||||
try writeAllFd(std.posix.STDOUT_FILENO, last);
|
||||
written += last.len;
|
||||
}
|
||||
}
|
||||
return written;
|
||||
}
|
||||
|
||||
fn drainStderr(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
|
||||
_ = w;
|
||||
if (data.len == 0) return 0;
|
||||
var written: usize = 0;
|
||||
for (data) |chunk| {
|
||||
try writeAllFd(std.posix.STDERR_FILENO, chunk);
|
||||
written += chunk.len;
|
||||
}
|
||||
if (splat > 0) {
|
||||
const last = data[data.len - 1];
|
||||
var i: usize = 0;
|
||||
while (i < splat) : (i += 1) {
|
||||
try writeAllFd(std.posix.STDERR_FILENO, last);
|
||||
written += last.len;
|
||||
}
|
||||
}
|
||||
return written;
|
||||
}
|
||||
|
||||
const stdout_vtable = std.Io.Writer.VTable{ .drain = drainStdout };
|
||||
const stderr_vtable = std.Io.Writer.VTable{ .drain = drainStderr };
|
||||
|
||||
pub fn stdoutWriter() std.Io.Writer {
|
||||
return .{ .vtable = &stdout_vtable, .buffer = &[_]u8{}, .end = 0 };
|
||||
}
|
||||
|
||||
pub fn stderrWriter() std.Io.Writer {
|
||||
return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 };
|
||||
}
|
||||
|
|
@ -6,24 +6,127 @@ pub const EmbeddedRsync = struct {
|
|||
|
||||
const Self = @This();
|
||||
|
||||
/// Extract embedded rsync binary to temporary location
|
||||
pub fn extractRsyncBinary(self: Self) ![]const u8 {
|
||||
const rsync_path = "/tmp/ml_rsync";
|
||||
fn encodeHexLower(out: []u8, bytes: []const u8) void {
|
||||
const hex = "0123456789abcdef";
|
||||
var i: usize = 0;
|
||||
while (i < bytes.len) : (i += 1) {
|
||||
const b = bytes[i];
|
||||
out[i * 2] = hex[(b >> 4) & 0x0f];
|
||||
out[i * 2 + 1] = hex[b & 0x0f];
|
||||
}
|
||||
}
|
||||
|
||||
// Check if rsync binary already exists
|
||||
if (std.fs.cwd().openFile(rsync_path, .{})) |file| {
|
||||
file.close();
|
||||
// Check if it's executable
|
||||
const stat = try std.fs.cwd().statFile(rsync_path);
|
||||
if (stat.mode & 0o111 != 0) {
|
||||
return try self.allocator.dupe(u8, rsync_path);
|
||||
}
|
||||
} else |_| {
|
||||
// Need to extract the binary
|
||||
try self.extractAndSetExecutable(rsync_path);
|
||||
fn cacheBaseDir(self: Self) ![]u8 {
|
||||
if (std.posix.getenv("XDG_CACHE_HOME")) |z| {
|
||||
return self.allocator.dupe(u8, std.mem.sliceTo(z, 0));
|
||||
}
|
||||
|
||||
return try self.allocator.dupe(u8, rsync_path);
|
||||
const home = if (std.posix.getenv("HOME")) |z| std.mem.sliceTo(z, 0) else return error.MissingHome;
|
||||
const target = @import("builtin").target;
|
||||
if (target.os.tag == .macos) {
|
||||
return std.fmt.allocPrint(self.allocator, "{s}/Library/Caches", .{home});
|
||||
}
|
||||
return std.fmt.allocPrint(self.allocator, "{s}/.cache", .{home});
|
||||
}
|
||||
|
||||
fn resolveRsyncPathOverride(self: Self) !?[]u8 {
|
||||
const z = std.posix.getenv("ML_RSYNC_PATH") orelse return null;
|
||||
const p = std.mem.sliceTo(z, 0);
|
||||
if (p.len == 0) return null;
|
||||
|
||||
// Basic validation: must exist and be executable.
|
||||
const file = std.fs.openFileAbsolute(p, .{}) catch return error.InvalidRsyncPath;
|
||||
defer file.close();
|
||||
const stat = try file.stat();
|
||||
if (stat.mode & 0o111 == 0) return error.InvalidRsyncPath;
|
||||
|
||||
const dup = try self.allocator.dupe(u8, p);
|
||||
return @as(?[]u8, dup);
|
||||
}
|
||||
|
||||
fn resolveSystemRsyncPath(self: Self) !?[]u8 {
|
||||
const path_z = std.posix.getenv("PATH") orelse return null;
|
||||
const path = std.mem.sliceTo(path_z, 0);
|
||||
|
||||
var it = std.mem.splitScalar(u8, path, ':');
|
||||
while (it.next()) |dir| {
|
||||
if (dir.len == 0) continue;
|
||||
const candidate = try std.fmt.allocPrint(self.allocator, "{s}/{s}", .{ dir, "rsync" });
|
||||
errdefer self.allocator.free(candidate);
|
||||
|
||||
const file = std.fs.openFileAbsolute(candidate, .{}) catch {
|
||||
self.allocator.free(candidate);
|
||||
continue;
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
const stat = file.stat() catch {
|
||||
self.allocator.free(candidate);
|
||||
continue;
|
||||
};
|
||||
|
||||
if (stat.mode & 0o111 == 0) {
|
||||
self.allocator.free(candidate);
|
||||
continue;
|
||||
}
|
||||
|
||||
return @as(?[]u8, candidate);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// Extract embedded rsync binary to temporary location
|
||||
pub fn extractRsyncBinary(self: Self) ![]const u8 {
|
||||
if (try self.resolveRsyncPathOverride()) |p| {
|
||||
return p;
|
||||
}
|
||||
|
||||
if (try self.resolveSystemRsyncPath()) |p| {
|
||||
return p;
|
||||
}
|
||||
|
||||
const embedded_binary = @import("rsync_embedded_binary.zig");
|
||||
const digest = embedded_binary.rsyncBinarySha256();
|
||||
|
||||
var digest_hex: [64]u8 = undefined;
|
||||
encodeHexLower(digest_hex[0..], digest[0..]);
|
||||
|
||||
const cache_base = try self.cacheBaseDir();
|
||||
defer self.allocator.free(cache_base);
|
||||
|
||||
const cache_dir = try std.fmt.allocPrint(self.allocator, "{s}/fetchml/rsync", .{cache_base});
|
||||
defer self.allocator.free(cache_dir);
|
||||
|
||||
std.fs.cwd().makePath(cache_dir) catch |err| switch (err) {
|
||||
error.PathAlreadyExists => {},
|
||||
else => return err,
|
||||
};
|
||||
|
||||
const rsync_path = try std.fmt.allocPrint(self.allocator, "{s}/ml_rsync_{s}", .{ cache_dir, digest_hex[0..16] });
|
||||
errdefer self.allocator.free(rsync_path);
|
||||
|
||||
// If file exists and matches the embedded digest, reuse it.
|
||||
if (std.fs.openFileAbsolute(rsync_path, .{})) |file| {
|
||||
defer file.close();
|
||||
|
||||
const stat = try file.stat();
|
||||
if (stat.mode & 0o111 != 0) {
|
||||
const data = try file.readToEndAlloc(self.allocator, 1024 * 1024 * 4);
|
||||
defer self.allocator.free(data);
|
||||
|
||||
var on_disk: [32]u8 = undefined;
|
||||
std.crypto.hash.sha2.Sha256.hash(data, &on_disk, .{});
|
||||
if (std.mem.eql(u8, on_disk[0..], digest[0..])) {
|
||||
return rsync_path;
|
||||
}
|
||||
}
|
||||
} else |_| {
|
||||
// Not present; will extract.
|
||||
}
|
||||
|
||||
try self.extractAndSetExecutable(rsync_path);
|
||||
return rsync_path;
|
||||
}
|
||||
|
||||
/// Extract rsync binary from embedded data and set executable permissions
|
||||
|
|
@ -39,14 +142,15 @@ pub const EmbeddedRsync = struct {
|
|||
defer self.allocator.free(debug_msg);
|
||||
std.log.debug("{s}", .{debug_msg});
|
||||
|
||||
// Write embedded binary to file system
|
||||
try std.fs.cwd().writeFile(.{ .sub_path = path, .data = binary_data });
|
||||
// Write embedded binary to file system (absolute path)
|
||||
var file = try std.fs.createFileAbsolute(path, .{ .truncate = true, .read = false, .mode = 0o755 });
|
||||
defer file.close();
|
||||
try file.writeAll(binary_data);
|
||||
|
||||
// Set executable permissions using OS API
|
||||
// Ensure executable permissions
|
||||
const mode = 0o755; // rwxr-xr-x
|
||||
std.posix.fchmodat(std.fs.cwd().fd, path, mode, 0) catch |err| {
|
||||
std.posix.fchmod(file.handle, mode) catch |err| {
|
||||
std.log.warn("Failed to set executable permissions on {s}: {}", .{ path, err });
|
||||
// Continue anyway - the script might still work
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -79,7 +183,13 @@ pub const EmbeddedRsync = struct {
|
|||
child.stdout_behavior = .Inherit;
|
||||
child.stderr_behavior = .Inherit;
|
||||
|
||||
const term = try child.spawnAndWait();
|
||||
const term = child.spawnAndWait() catch |err| {
|
||||
std.log.err(
|
||||
"Failed to execute rsync at '{s}': {}. If your environment blocks executing extracted binaries (e.g. noexec), set ML_RSYNC_PATH to a system rsync.",
|
||||
.{ rsync_path, err },
|
||||
);
|
||||
return err;
|
||||
};
|
||||
|
||||
switch (term) {
|
||||
.Exited => |code| {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,119 @@
|
|||
const std = @import("std");
|
||||
const build_options = @import("build_options");
|
||||
|
||||
fn isScript(data: []const u8) bool {
|
||||
return data.len >= 2 and data[0] == '#' and data[1] == '!';
|
||||
}
|
||||
|
||||
fn readU32BE(data: []const u8, offset: usize) ?u32 {
|
||||
if (data.len < offset + 4) return null;
|
||||
return (@as(u32, data[offset]) << 24) |
|
||||
(@as(u32, data[offset + 1]) << 16) |
|
||||
(@as(u32, data[offset + 2]) << 8) |
|
||||
(@as(u32, data[offset + 3]));
|
||||
}
|
||||
|
||||
fn readU32LE(data: []const u8, offset: usize) ?u32 {
|
||||
if (data.len < offset + 4) return null;
|
||||
return (@as(u32, data[offset])) |
|
||||
(@as(u32, data[offset + 1]) << 8) |
|
||||
(@as(u32, data[offset + 2]) << 16) |
|
||||
(@as(u32, data[offset + 3]) << 24);
|
||||
}
|
||||
|
||||
fn readU16LE(data: []const u8, offset: usize) ?u16 {
|
||||
if (data.len < offset + 2) return null;
|
||||
return (@as(u16, data[offset])) | (@as(u16, data[offset + 1]) << 8);
|
||||
}
|
||||
|
||||
fn readU16BE(data: []const u8, offset: usize) ?u16 {
|
||||
if (data.len < offset + 2) return null;
|
||||
return (@as(u16, data[offset]) << 8) | (@as(u16, data[offset + 1]));
|
||||
}
|
||||
|
||||
fn machCpuTypeForArch(arch: std.Target.Cpu.Arch) ?u32 {
|
||||
return switch (arch) {
|
||||
.x86_64 => 0x01000007,
|
||||
.aarch64 => 0x0100000c,
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
|
||||
fn elfMachineForArch(arch: std.Target.Cpu.Arch) ?u16 {
|
||||
return switch (arch) {
|
||||
.x86_64 => 62,
|
||||
.aarch64 => 183,
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
|
||||
fn machOHasCpuType(data: []const u8, want_cputype: u32) bool {
|
||||
// Mach-O header: u32 magic; i32 cputype at offset 4.
|
||||
// Accept both endiannesses (cputype will appear swapped if magic swapped).
|
||||
const magic_le = readU32LE(data, 0) orelse return false;
|
||||
const magic_be = readU32BE(data, 0) orelse return false;
|
||||
|
||||
const MH_MAGIC_64: u32 = 0xfeedfacf;
|
||||
const MH_CIGAM_64: u32 = 0xcffaedfe;
|
||||
|
||||
if (magic_le == MH_MAGIC_64) {
|
||||
const cputype = readU32LE(data, 4) orelse return false;
|
||||
return cputype == want_cputype;
|
||||
}
|
||||
if (magic_be == MH_MAGIC_64 or magic_le == MH_CIGAM_64) {
|
||||
const cputype = readU32BE(data, 4) orelse return false;
|
||||
return cputype == want_cputype;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
fn fatMachOHasCpuType(data: []const u8, want_cputype: u32) bool {
|
||||
// fat_header: magic (0xcafebabe big-endian), nfat_arch u32
|
||||
const FAT_MAGIC: u32 = 0xcafebabe;
|
||||
const magic = readU32BE(data, 0) orelse return false;
|
||||
if (magic != FAT_MAGIC) return false;
|
||||
|
||||
const nfat = readU32BE(data, 4) orelse return false;
|
||||
// fat_arch entries start at offset 8, 20 bytes each, big-endian
|
||||
var off: usize = 8;
|
||||
var i: u32 = 0;
|
||||
while (i < nfat) : (i += 1) {
|
||||
const cputype = readU32BE(data, off) orelse return false;
|
||||
if (cputype == want_cputype) return true;
|
||||
off += 20;
|
||||
if (off > data.len) return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
fn elfHasMachine(data: []const u8, want_machine: u16) bool {
|
||||
// Minimal ELF check: magic + e_machine at offset 18
|
||||
if (data.len < 20) return false;
|
||||
if (!(data[0] == 0x7f and data[1] == 'E' and data[2] == 'L' and data[3] == 'F')) return false;
|
||||
const ei_data = data[5]; // 1=little, 2=big
|
||||
return switch (ei_data) {
|
||||
1 => (readU16LE(data, 18) orelse return false) == want_machine,
|
||||
2 => (readU16BE(data, 18) orelse return false) == want_machine,
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
|
||||
fn isNativeForTarget(data: []const u8) bool {
|
||||
if (isScript(data)) return true;
|
||||
|
||||
const target = @import("builtin").target;
|
||||
const arch = target.cpu.arch;
|
||||
|
||||
if (machCpuTypeForArch(arch)) |want_cputype| {
|
||||
if (fatMachOHasCpuType(data, want_cputype)) return true;
|
||||
if (machOHasCpuType(data, want_cputype)) return true;
|
||||
}
|
||||
if (elfMachineForArch(arch)) |want_machine| {
|
||||
if (elfHasMachine(data, want_machine)) return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Embedded rsync binary data
|
||||
/// For dev builds: uses placeholder wrapper that calls system rsync
|
||||
|
|
@ -8,16 +123,32 @@ const std = @import("std");
|
|||
/// 1. Download or build a static rsync binary for your target platform
|
||||
/// 2. Place it at cli/src/assets/rsync_release.bin
|
||||
/// 3. Build with: zig build prod (or release/cross targets)
|
||||
const placeholder_data = @embedFile("../assets/rsync_placeholder.bin");
|
||||
|
||||
// Check if rsync_release.bin exists, otherwise fall back to placeholder
|
||||
const rsync_binary_data = if (@import("builtin").mode == .Debug or @import("builtin").mode == .ReleaseSafe)
|
||||
@embedFile("../assets/rsync_placeholder.bin")
|
||||
const release_data = if (build_options.has_rsync_release)
|
||||
@embedFile(build_options.rsync_release_path)
|
||||
else
|
||||
// For ReleaseSmall and ReleaseFast, try to use the release binary
|
||||
@embedFile("../assets/rsync_release.bin");
|
||||
placeholder_data;
|
||||
|
||||
// Prefer placeholder in Debug/ReleaseSafe. In Release*, only use rsync_release.bin if it
|
||||
// appears compatible with the target architecture; otherwise fall back to placeholder.
|
||||
const use_release = (@import("builtin").mode != .Debug and @import("builtin").mode != .ReleaseSafe) and
|
||||
build_options.has_rsync_release and
|
||||
!isScript(release_data) and
|
||||
isNativeForTarget(release_data);
|
||||
|
||||
const rsync_binary_data = if (use_release) release_data else placeholder_data;
|
||||
|
||||
pub const RSYNC_BINARY: []const u8 = rsync_binary_data;
|
||||
|
||||
pub const USING_RELEASE_BINARY: bool = use_release;
|
||||
|
||||
pub fn rsyncBinarySha256() [32]u8 {
|
||||
var digest: [32]u8 = undefined;
|
||||
std.crypto.hash.sha2.Sha256.hash(RSYNC_BINARY, &digest, .{});
|
||||
return digest;
|
||||
}
|
||||
|
||||
/// Get embedded rsync binary data
|
||||
pub fn getRsyncBinary() []const u8 {
|
||||
return RSYNC_BINARY;
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
const src = @import("src");
|
||||
const commands = src.commands;
|
||||
|
||||
test "jupyter top-level action includes create" {
|
||||
try testing.expect(src.commands.jupyter.isValidTopLevelAction("create"));
|
||||
try testing.expect(src.commands.jupyter.isValidTopLevelAction("start"));
|
||||
try testing.expect(!src.commands.jupyter.isValidTopLevelAction("bogus"));
|
||||
try testing.expect(commands.jupyter.isValidTopLevelAction("create"));
|
||||
try testing.expect(commands.jupyter.isValidTopLevelAction("start"));
|
||||
try testing.expect(!commands.jupyter.isValidTopLevelAction("bogus"));
|
||||
}
|
||||
|
||||
test "jupyter defaultWorkspacePath prefixes ./" {
|
||||
const allocator = testing.allocator;
|
||||
const p = try src.commands.jupyter.defaultWorkspacePath(allocator, "my-workspace");
|
||||
const p = try commands.jupyter.defaultWorkspacePath(allocator, "my-workspace");
|
||||
defer allocator.free(p);
|
||||
|
||||
try testing.expectEqualStrings("./my-workspace", p);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ test "CLI basic functionality" {
|
|||
// Test that CLI module can be imported
|
||||
const allocator = testing.allocator;
|
||||
_ = allocator;
|
||||
_ = src;
|
||||
|
||||
// Test basic string operations used in CLI
|
||||
const test_str = "ml sync";
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
const src = @import("src");
|
||||
|
||||
test "queue command argument parsing" {
|
||||
// Test various queue command argument combinations
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ const std = @import("std");
|
|||
const testing = std.testing;
|
||||
|
||||
const src = @import("src");
|
||||
|
||||
const protocol = src.net.protocol;
|
||||
|
||||
fn roundTrip(allocator: std.mem.Allocator, packet: protocol.ResponsePacket) !protocol.ResponsePacket {
|
||||
|
|
|
|||
|
|
@ -1,30 +1,121 @@
|
|||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
|
||||
// Simple mock rsync for testing
|
||||
const MockRsyncEmbedded = struct {
|
||||
const EmbeddedRsync = struct {
|
||||
allocator: std.mem.Allocator,
|
||||
const c = @cImport({
|
||||
@cInclude("stdlib.h");
|
||||
});
|
||||
|
||||
fn extractRsyncBinary(self: EmbeddedRsync) ![]const u8 {
|
||||
// Simple mock - return a dummy path
|
||||
return try std.fmt.allocPrint(self.allocator, "/tmp/mock_rsync", .{});
|
||||
}
|
||||
};
|
||||
};
|
||||
const src = @import("src");
|
||||
const utils = src.utils;
|
||||
|
||||
const rsync_embedded = MockRsyncEmbedded;
|
||||
fn isScript(data: []const u8) bool {
|
||||
return data.len >= 2 and data[0] == '#' and data[1] == '!';
|
||||
}
|
||||
|
||||
test "embedded rsync binary creation" {
|
||||
const allocator = testing.allocator;
|
||||
|
||||
var embedded_rsync = rsync_embedded.EmbeddedRsync{ .allocator = allocator };
|
||||
// Ensure the override doesn't influence this test.
|
||||
_ = c.unsetenv("ML_RSYNC_PATH");
|
||||
|
||||
// Force embedded fallback by clearing PATH for this process so system rsync isn't found.
|
||||
const old_path_z = std.posix.getenv("PATH");
|
||||
if (c.setenv("PATH", "", 1) != 0) return error.Unexpected;
|
||||
defer {
|
||||
if (old_path_z) |z| {
|
||||
_ = c.setenv("PATH", z, 1);
|
||||
} else {
|
||||
_ = c.unsetenv("PATH");
|
||||
}
|
||||
}
|
||||
|
||||
var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator };
|
||||
|
||||
// Test binary extraction
|
||||
const rsync_path = try embedded_rsync.extractRsyncBinary();
|
||||
defer allocator.free(rsync_path);
|
||||
|
||||
// Verify the path was created
|
||||
try testing.expect(rsync_path.len > 0);
|
||||
try testing.expect(std.mem.startsWith(u8, rsync_path, "/tmp/"));
|
||||
|
||||
// File exists and is executable
|
||||
const file = try std.fs.openFileAbsolute(rsync_path, .{});
|
||||
defer file.close();
|
||||
const stat = try file.stat();
|
||||
try testing.expect(stat.mode & 0o111 != 0);
|
||||
|
||||
// Contents match embedded payload
|
||||
const data = try file.readToEndAlloc(allocator, 1024 * 1024 * 4);
|
||||
defer allocator.free(data);
|
||||
|
||||
var digest: [32]u8 = undefined;
|
||||
std.crypto.hash.sha2.Sha256.hash(data, &digest, .{});
|
||||
const embedded_digest = utils.rsync_embedded_binary.rsyncBinarySha256();
|
||||
try testing.expect(std.mem.eql(u8, &digest, &embedded_digest));
|
||||
|
||||
// Sanity: if we're not using the release binary, it should be the placeholder script.
|
||||
if (!utils.rsync_embedded_binary.USING_RELEASE_BINARY) {
|
||||
try testing.expect(isScript(data));
|
||||
}
|
||||
}
|
||||
|
||||
test "embedded rsync honors ML_RSYNC_PATH override" {
|
||||
const allocator = testing.allocator;
|
||||
|
||||
// Use a known executable. We don't execute it; we only verify the override is returned.
|
||||
const override_path = "/bin/sh";
|
||||
try testing.expect(std.fs.openFileAbsolute(override_path, .{}) != error.FileNotFound);
|
||||
|
||||
if (c.setenv("ML_RSYNC_PATH", override_path, 1) != 0) return error.Unexpected;
|
||||
defer _ = c.unsetenv("ML_RSYNC_PATH");
|
||||
|
||||
var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator };
|
||||
const rsync_path = try embedded_rsync.extractRsyncBinary();
|
||||
defer allocator.free(rsync_path);
|
||||
|
||||
try testing.expectEqualStrings(override_path, rsync_path);
|
||||
}
|
||||
|
||||
test "embedded rsync prefers system rsync when available" {
|
||||
const allocator = testing.allocator;
|
||||
|
||||
_ = c.unsetenv("ML_RSYNC_PATH");
|
||||
|
||||
// Find rsync on PATH (simple scan). If not present, skip.
|
||||
const path_z = std.posix.getenv("PATH") orelse return;
|
||||
const path = std.mem.sliceTo(path_z, 0);
|
||||
|
||||
var sys_rsync: ?[]u8 = null;
|
||||
defer if (sys_rsync) |p| allocator.free(p);
|
||||
|
||||
var it = std.mem.splitScalar(u8, path, ':');
|
||||
while (it.next()) |dir| {
|
||||
if (dir.len == 0) continue;
|
||||
const candidate = std.fmt.allocPrint(allocator, "{s}/{s}", .{ dir, "rsync" }) catch return;
|
||||
errdefer allocator.free(candidate);
|
||||
|
||||
const file = std.fs.openFileAbsolute(candidate, .{}) catch {
|
||||
allocator.free(candidate);
|
||||
continue;
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
const st = file.stat() catch {
|
||||
allocator.free(candidate);
|
||||
continue;
|
||||
};
|
||||
if (st.mode & 0o111 == 0) {
|
||||
allocator.free(candidate);
|
||||
continue;
|
||||
}
|
||||
|
||||
sys_rsync = candidate;
|
||||
break;
|
||||
}
|
||||
|
||||
if (sys_rsync == null) return;
|
||||
|
||||
var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator };
|
||||
const rsync_path = try embedded_rsync.extractRsyncBinary();
|
||||
defer allocator.free(rsync_path);
|
||||
|
||||
try testing.expectEqualStrings(sys_rsync.?, rsync_path);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ const testing = std.testing;
|
|||
|
||||
const src = @import("src");
|
||||
|
||||
const ws = src.net.ws;
|
||||
const ws = src.net.ws_client;
|
||||
|
||||
test "status prewarm formatting - single entry" {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
|
|
|
|||
Loading…
Reference in a new issue