feat(cli): enhance Zig CLI with new commands and improved networking

- Add new commands: annotate, narrative, requeue
- Refactor WebSocket client into modular components (net/ws/)
- Add rsync embedded binary support
- Improve error handling and response packet processing
- Update build.zig and completions
This commit is contained in:
Jeremie Fraeys 2026-02-12 12:05:10 -05:00
parent df5d872021
commit 8e3fa94322
No known key found for this signature in database
48 changed files with 4182 additions and 1958 deletions

View file

@ -4,21 +4,36 @@ ZIG ?= zig
BUILD_DIR ?= zig-out/bin
BINARY := $(BUILD_DIR)/ml
.PHONY: all prod dev install clean help
.PHONY: all prod dev test build-rsync install clean help
RSYNC_VERSION ?= 3.3.0
RSYNC_SRC_BASE ?= https://download.samba.org/pub/rsync/src
RSYNC_TARBALL ?= rsync-$(RSYNC_VERSION).tar.gz
RSYNC_TARBALL_SHA256 ?=
all: $(BINARY)
$(BUILD_DIR):
mkdir -p $(BUILD_DIR)
$(BINARY): src/main.zig | $(BUILD_DIR)
$(ZIG) build-exe -OReleaseSmall -fstrip -femit-bin=$(BINARY) src/main.zig
$(BINARY): | $(BUILD_DIR)
$(ZIG) build --release=small
prod: src/main.zig | $(BUILD_DIR)
$(ZIG) build-exe -OReleaseSmall -fstrip -femit-bin=$(BUILD_DIR)/ml src/main.zig
$(ZIG) build --release=small
dev: src/main.zig | $(BUILD_DIR)
$(ZIG) build-exe -OReleaseFast -femit-bin=$(BUILD_DIR)/ml src/main.zig
$(ZIG) build --release=fast
test:
$(ZIG) build test
build-rsync:
@RSYNC_VERSION="$(RSYNC_VERSION)" \
RSYNC_SRC_BASE="$(RSYNC_SRC_BASE)" \
RSYNC_TARBALL="$(RSYNC_TARBALL)" \
RSYNC_TARBALL_SHA256="$(RSYNC_TARBALL_SHA256)" \
bash "$(CURDIR)/scripts/build_rsync.sh"
install: $(BINARY)
install -d $(DESTDIR)/usr/local/bin
@ -32,5 +47,7 @@ help:
@echo " all - build release-small binary (default)"
@echo " prod - build production binary with ReleaseSmall"
@echo " dev - build development binary with ReleaseFast"
@echo " test - run Zig unit tests"
@echo " build-rsync - build pinned rsync from official source into src/assets (RSYNC_VERSION=... override)"
@echo " install - copy binary into /usr/local/bin"
@echo " clean - remove build artifacts"

View file

@ -19,10 +19,12 @@ zig build
- `ml init` - Setup configuration
- `ml sync <path>` - Sync project to server
- `ml queue <job1> [job2 ...] [--commit <id>] [--priority N]` - Queue one or more jobs
- `ml queue <job1> [job2 ...] [--commit <id>] [--priority N] [--note <text>]` - Queue one or more jobs
- `ml status` - Check system/queue status for your API key
- `ml validate <commit_id> [--json] [--task <task_id>]` - Validate provenance + integrity for a commit or task (includes `run_manifest.json` consistency checks when validating by task)
- `ml info <path|id> [--json] [--base <path>]` - Show run info from `run_manifest.json` (by path or by scanning `finished/failed/running/pending`)
- `ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]` - Append a human annotation to `run_manifest.json`
- `ml narrative set <path|run_id|task_id> [--hypothesis <text>] [--context <text>] [--intent <text>] [--expected-outcome <text>] [--parent-run <id>] [--experiment-group <text>] [--tags <csv>] [--base <path>] [--json]` - Patch the `narrative` field in `run_manifest.json`
- `ml monitor` - Launch monitoring interface (TUI)
- `ml cancel <job>` - Cancel a running/queued job you own
- `ml prune --keep N` - Keep N recent experiments
@ -31,6 +33,7 @@ zig build
Notes:
- `--json` mode is designed to be pipe-friendly: machine-readable JSON is emitted to stdout, while user-facing messages/errors go to stderr.
- When running `ml validate --task <task_id>`, the server will try to locate the job's `run_manifest.json` under the configured base path (pending/running/finished/failed) and cross-check key fields (task id, commit id, deps, snapshot).
- For tasks in `running`, `completed`, or `failed` state, a missing `run_manifest.json` is treated as a validation failure. For `queued` tasks, it is treated as a warning (the job may not have started yet).
@ -43,6 +46,9 @@ Notes:
Queues a job named `my-job`. If `--commit` is omitted, the CLI generates a random commit ID
and records `(job_name, commit_id)` in `~/.ml/history.log` so you don't have to remember hashes.
- `ml queue my-job --note "baseline run; lr=1e-3"`
Adds a human-readable note to the run; it will be persisted into the run's `run_manifest.json` (under `metadata.note`).
- `ml experiment list`
Shows recent experiments from history with alias (job name) and commit ID.

View file

@ -5,9 +5,63 @@ pub fn build(b: *std.Build) void {
// Standard target options
const target = b.standardTargetOptions(.{});
const test_filter = b.option([]const u8, "test-filter", "Filter unit tests by name");
_ = test_filter;
// Optimized release mode for size
const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseSmall });
const options = b.addOptions();
const arch = target.result.cpu.arch;
const os_tag = target.result.os.tag;
const arch_str: []const u8 = switch (arch) {
.x86_64 => "x86_64",
.aarch64 => "arm64",
else => "unknown",
};
const os_str: []const u8 = switch (os_tag) {
.linux => "linux",
.macos => "darwin",
.windows => "windows",
else => "unknown",
};
const candidate_specific = b.fmt("src/assets/rsync_release_{s}_{s}.bin", .{ os_str, arch_str });
const candidate_default = "src/assets/rsync_release.bin";
var selected_candidate: []const u8 = "";
var has_rsync_release = false;
// Prefer a platform-specific asset if available.
if (std.fs.cwd().openFile(candidate_specific, .{}) catch null) |f| {
f.close();
selected_candidate = candidate_specific;
has_rsync_release = true;
} else if (std.fs.cwd().openFile(candidate_default, .{}) catch null) |f| {
f.close();
selected_candidate = candidate_default;
has_rsync_release = true;
}
if ((optimize == .ReleaseSmall or optimize == .ReleaseFast) and !has_rsync_release) {
std.debug.panic(
"Release build requires an embedded rsync binary asset. Provide one of: '{s}' or '{s}'",
.{ candidate_specific, candidate_default },
);
}
// rsync_embedded_binary.zig calls @embedFile() from cli/src/utils, so the embed path
// must be relative to that directory.
const selected_embed_path = if (has_rsync_release)
b.fmt("../assets/{s}", .{std.fs.path.basename(selected_candidate)})
else
"";
options.addOption(bool, "has_rsync_release", has_rsync_release);
options.addOption([]const u8, "rsync_release_path", selected_embed_path);
// CLI executable
const exe = b.addExecutable(.{
.name = "ml",
@ -18,6 +72,10 @@ pub fn build(b: *std.Build) void {
}),
});
exe.root_module.strip = true;
exe.root_module.addOptions("build_options", options);
// Install the executable to zig-out/bin
b.installArtifact(exe);
@ -44,6 +102,7 @@ pub fn build(b: *std.Build) void {
.optimize = .Debug,
}),
});
main_tests.root_module.addOptions("build_options", options);
const run_main_tests = b.addRunArtifact(main_tests);
test_step.dependOn(&run_main_tests.step);
@ -75,6 +134,8 @@ pub fn build(b: *std.Build) void {
.optimize = .Debug,
});
test_module.addOptions("build_options", options);
// Make src module available to tests as "src"
test_module.addImport("src", src_module);

View file

@ -0,0 +1,83 @@
#!/usr/bin/env bash
# build_rsync.sh — build rsync from the official source tarball and install a
# static binary into src/assets/ as rsync_release_linux_<arch>.bin for
# embedding in CLI release builds.
#
# The tarball is verified with a detached gpg signature when gpg is installed
# and a signature is published; otherwise RSYNC_TARBALL_SHA256 must be set.
# The script refuses to use an unverified tarball.
set -euo pipefail

RSYNC_VERSION="${RSYNC_VERSION:-3.3.0}"
RSYNC_SRC_BASE="${RSYNC_SRC_BASE:-https://download.samba.org/pub/rsync/src}"
RSYNC_TARBALL="${RSYNC_TARBALL:-rsync-${RSYNC_VERSION}.tar.gz}"
RSYNC_TARBALL_SHA256="${RSYNC_TARBALL_SHA256:-}"

os="$(uname -s | tr '[:upper:]' '[:lower:]')"
arch="$(uname -m)"
# Normalize the architecture to the names the build system looks for
# (rsync_release_<os>_<arch>.bin). Reject anything we do not produce assets
# for, instead of silently emitting a file no build will ever pick up.
case "${arch}" in
  aarch64 | arm64) arch="arm64" ;;
  x86_64) arch="x86_64" ;;
  *)
    echo "build-rsync: unsupported architecture '${arch}' (expected x86_64 or aarch64/arm64)" >&2
    exit 2
    ;;
esac

if [[ "${os}" != "linux" ]]; then
  echo "build-rsync: supported on linux only (for reproducible official builds). Use system rsync on ${os} or build on a native runner." >&2
  exit 2
fi

# Fail early with a clear message instead of a confusing ./configure error.
if ! command -v musl-gcc >/dev/null 2>&1; then
  echo "build-rsync: musl-gcc is required for a static build (install musl / musl-tools)." >&2
  exit 2
fi

repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
out="${repo_root}/src/assets/rsync_release_linux_${arch}.bin"

tmp="$(mktemp -d)"
cleanup() { rm -rf "${tmp}"; }
trap cleanup EXIT

url="${RSYNC_SRC_BASE}/${RSYNC_TARBALL}"
sig_url_asc="${url}.asc"
sig_url_sig="${url}.sig"

echo "fetching ${url}"
curl -fsSL "${url}" -o "${tmp}/rsync.tar.gz"

# Prefer gpg signature verification; fall back to a pinned sha256 below.
verified=0
if command -v gpg >/dev/null 2>&1; then
  sig_file=""
  sig_url=""
  if curl -fsSL "${sig_url_asc}" -o "${tmp}/rsync.tar.gz.asc"; then
    sig_file="${tmp}/rsync.tar.gz.asc"
    sig_url="${sig_url_asc}"
  elif curl -fsSL "${sig_url_sig}" -o "${tmp}/rsync.tar.gz.sig"; then
    sig_file="${tmp}/rsync.tar.gz.sig"
    sig_url="${sig_url_sig}"
  fi
  if [[ -n "${sig_file}" ]]; then
    echo "verifying signature ${sig_url}"
    if gpg --batch --verify "${sig_file}" "${tmp}/rsync.tar.gz"; then
      verified=1
    else
      echo "build-rsync: gpg signature check failed (often because the public key is not in your keyring)." >&2
    fi
  fi
fi

if [[ "${verified}" -ne 1 ]]; then
  if [[ -n "${RSYNC_TARBALL_SHA256}" ]]; then
    echo "verifying sha256 for ${url}"
    # NOTE: the checksum-file format requires TWO spaces between the hash and
    # the file name; a single space makes sha256sum/shasum reject the line.
    if command -v sha256sum >/dev/null 2>&1; then
      echo "${RSYNC_TARBALL_SHA256}  ${tmp}/rsync.tar.gz" | sha256sum -c -
    elif command -v shasum >/dev/null 2>&1; then
      echo "${RSYNC_TARBALL_SHA256}  ${tmp}/rsync.tar.gz" | shasum -a 256 -c -
    else
      echo "build-rsync: need sha256sum or shasum for checksum verification" >&2
      exit 2
    fi
  else
    echo "build-rsync: could not verify ${url} (no usable gpg signature, and RSYNC_TARBALL_SHA256 is empty)." >&2
    echo "Set RSYNC_TARBALL_SHA256=<expected sha256> or install gpg with a trusted key for the rsync signing identity." >&2
    exit 2
  fi
fi

tar -C "${tmp}" -xzf "${tmp}/rsync.tar.gz"
extract_dir="$(tar -tzf "${tmp}/rsync.tar.gz" | head -n 1 | cut -d/ -f1)"
cd "${tmp}/${extract_dir}"

# Static build via musl; optional compressors disabled to keep the binary
# small and free of shared-library dependencies.
CC=musl-gcc CFLAGS="-O2" LDFLAGS="-static" ./configure --disable-xxhash --disable-zstd --disable-lz4
make -j"$(getconf _NPROCESSORS_ONLN 2>/dev/null || echo 2)"

mkdir -p "$(dirname "${out}")"
cp rsync "${out}"
chmod +x "${out}"
echo "built ${out}"

View file

@ -15,7 +15,7 @@ _ml_completions()
global_opts="--help --verbose --quiet --monitor"
# Top-level subcommands
cmds="init sync queue status monitor cancel prune watch dataset experiment"
cmds="init sync queue requeue status monitor cancel prune watch dataset experiment"
# If completing the subcommand itself
if [[ ${COMP_CWORD} -eq 1 ]]; then
@ -41,8 +41,51 @@ _ml_completions()
COMPREPLY=( $(compgen -d -- "${cur}") )
;;
queue)
# Suggest common job names (static for now)
COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") )
queue_opts="--commit --priority --cpu --memory --gpu --gpu-memory --snapshot-id --snapshot-sha256 --args -- ${global_opts}"
case "${prev}" in
--priority)
COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") )
;;
--cpu|--memory|--gpu)
COMPREPLY=( $(compgen -W "0 1 2 4 8 16 32" -- "${cur}") )
;;
--gpu-memory)
COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") )
;;
--commit|--snapshot-id|--snapshot-sha256|--args)
# Free-form; no special completion
;;
*)
if [[ "${cur}" == --* ]]; then
COMPREPLY=( $(compgen -W "${queue_opts}" -- "${cur}") )
else
# Suggest common job names (static for now)
COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") )
fi
;;
esac
;;
requeue)
requeue_opts="--name --priority --cpu --memory --gpu --gpu-memory --args -- ${global_opts}"
case "${prev}" in
--priority)
COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") )
;;
--cpu|--memory|--gpu)
COMPREPLY=( $(compgen -W "0 1 2 4 8 16 32" -- "${cur}") )
;;
--gpu-memory)
COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") )
;;
--name|--args)
# Free-form; no special completion
;;
*)
if [[ "${cur}" == --* ]]; then
COMPREPLY=( $(compgen -W "${requeue_opts}" -- "${cur}") )
fi
;;
esac
;;
status)
COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") )

View file

@ -10,6 +10,7 @@ _ml() {
'init:Setup configuration interactively'
'sync:Sync project to server'
'queue:Queue job for execution'
'requeue:Re-submit a previous run/commit'
'status:Get system status'
'monitor:Launch TUI via SSH'
'cancel:Cancel running job'
@ -53,7 +54,33 @@ _ml() {
'--verbose[Enable verbose output]' \
'--quiet[Suppress non-error output]' \
'--monitor[Monitor progress]' \
'1:job name:'
'--commit[Commit id (40-hex) or unique prefix (>=7)]:commit id:' \
'--priority[Priority (0-255)]:priority:' \
'--cpu[CPU cores]:cpu:' \
'--memory[Memory (GB)]:memory:' \
'--gpu[GPU count]:gpu:' \
'--gpu-memory[GPU memory]:gpu memory:' \
'--snapshot-id[Snapshot id]:snapshot id:' \
'--snapshot-sha256[Snapshot sha256]:snapshot sha256:' \
'--args[Runner args string]:args:' \
'1:job name:' \
'*:args separator:(--)'
;;
requeue)
_arguments -C \
'--help[Show requeue help]' \
'--verbose[Enable verbose output]' \
'--quiet[Suppress non-error output]' \
'--monitor[Monitor progress]' \
'--name[Override job name]:job name:' \
'--priority[Priority (0-255)]:priority:' \
'--cpu[CPU cores]:cpu:' \
'--memory[Memory (GB)]:memory:' \
'--gpu[GPU count]:gpu:' \
'--gpu-memory[GPU memory]:gpu memory:' \
'--args[Runner args string]:args:' \
'1:commit_id|run_id|task_id|path:' \
'*:args separator:(--)'
;;
status)
_arguments -C \

View file

@ -3,4 +3,5 @@ pub const commands = @import("src/commands.zig");
pub const net = @import("src/net.zig");
pub const utils = @import("src/utils.zig");
pub const config = @import("src/config.zig");
pub const Config = @import("src/config.zig").Config;
pub const errors = @import("src/errors.zig");

View file

@ -5,7 +5,7 @@
This directory contains rsync binaries for the ML CLI:
- `rsync_placeholder.bin` - Wrapper script for dev builds (calls system rsync)
- `rsync_release.bin` - Full static rsync binary for release builds (not in repo)
- `rsync_release_<os>_<arch>.bin` - Static rsync binary for release builds (not in repo)
## Build Modes
@ -16,44 +16,35 @@ This directory contains rsync binaries for the ML CLI:
- Requires rsync installed on the system
### Release Builds (ReleaseSmall, ReleaseFast)
- Uses `rsync_release.bin` (300-500KB static binary)
- Uses `rsync_release_<os>_<arch>.bin` (static binary)
- Fully self-contained, no dependencies
- Results in a CLI binary of roughly 450–650 KB
- Works on any system without rsync installed
## Preparing Release Binaries
### Option 1: Download Pre-built Static Rsync
### Option 1: Build from Official Rsync Source (recommended)
For macOS ARM64:
On Linux:
```bash
cd cli/src/assets
curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-macos-arm64 -o rsync_release.bin
chmod +x rsync_release.bin
cd cli
make build-rsync RSYNC_VERSION=3.3.0
```
For Linux x86_64:
```bash
cd cli/src/assets
curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-linux-x86_64 -o rsync_release.bin
chmod +x rsync_release.bin
```
### Option 2: Build Static Rsync Yourself
### Option 2: Build Rsync Yourself
```bash
# Clone rsync
git clone https://github.com/WayneD/rsync.git
cd rsync
# Download official source
curl -fsSL https://download.samba.org/pub/rsync/src/rsync-3.3.0.tar.gz -o rsync.tar.gz
tar -xzf rsync.tar.gz
cd rsync-3.3.0
# Configure for static build
./configure CFLAGS="-static" LDFLAGS="-static" --disable-xxhash --disable-zstd
# Build
# Configure & build
./configure --disable-xxhash --disable-zstd --disable-lz4
make
# Copy to assets
cp rsync ../fetch_ml/cli/src/assets/rsync_release.bin
# Copy to assets (example)
cp rsync ../fetch_ml/cli/src/assets/rsync_release_linux_x86_64.bin
```
### Option 3: Use System Rsync (Temporary)
@ -61,21 +52,22 @@ cp rsync ../fetch_ml/cli/src/assets/rsync_release.bin
For testing release builds without a static binary:
```bash
cd cli/src/assets
cp rsync_placeholder.bin rsync_release.bin
cp rsync_placeholder.bin rsync_release_linux_x86_64.bin
```
This will still use the wrapper, but allows builds to complete.
## Verification
After placing rsync_release.bin:
After placing the appropriate `rsync_release_<os>_<arch>.bin`:
```bash
# Verify it's executable
file cli/src/assets/rsync_release.bin
# Test it
./cli/src/assets/rsync_release.bin --version
# Verify it's executable (example)
file cli/src/assets/rsync_release_linux_x86_64.bin
# Test it (example)
./cli/src/assets/rsync_release_linux_x86_64.bin --version
# Build release
cd cli

View file

@ -1,14 +1,16 @@
// Commands module - exports all command modules
pub const queue = @import("commands/queue.zig");
pub const sync = @import("commands/sync.zig");
pub const status = @import("commands/status.zig");
pub const dataset = @import("commands/dataset.zig");
pub const jupyter = @import("commands/jupyter.zig");
pub const init = @import("commands/init.zig");
pub const info = @import("commands/info.zig");
pub const monitor = @import("commands/monitor.zig");
pub const annotate = @import("commands/annotate.zig");
pub const cancel = @import("commands/cancel.zig");
pub const prune = @import("commands/prune.zig");
pub const watch = @import("commands/watch.zig");
pub const dataset = @import("commands/dataset.zig");
pub const experiment = @import("commands/experiment.zig");
pub const info = @import("commands/info.zig");
pub const init = @import("commands/init.zig");
pub const jupyter = @import("commands/jupyter.zig");
pub const monitor = @import("commands/monitor.zig");
pub const narrative = @import("commands/narrative.zig");
pub const prune = @import("commands/prune.zig");
pub const queue = @import("commands/queue.zig");
pub const requeue = @import("commands/requeue.zig");
pub const status = @import("commands/status.zig");
pub const sync = @import("commands/sync.zig");
pub const validate = @import("commands/validate.zig");
pub const watch = @import("commands/watch.zig");

View file

@ -0,0 +1,282 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
/// `ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]`
///
/// Resolves the target to a `run_manifest.json`, reads the job name from it,
/// then sends an annotate request to the worker over its WebSocket endpoint.
/// In `--json` mode the server response is re-emitted as a single JSON object
/// on stdout; otherwise human-readable status goes through `colors`.
///
/// Errors: `error.InvalidArgs` on bad flags or a missing/blank `--note`;
/// `error.InvalidPacket` when the server reply cannot be deserialized;
/// `error.CommandFailed` when the server returns an error packet (json mode);
/// plus whatever config loading, manifest resolution, or networking raise.
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
    if (args.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    if (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h")) {
        try printUsage();
        return;
    }
    // First positional argument is the annotate target; everything after it
    // must be a recognized flag (extra positionals are rejected below).
    const target = args[0];
    var author: []const u8 = "";
    var note: ?[]const u8 = null;
    var base_override: ?[]const u8 = null;
    var json_mode: bool = false;
    var i: usize = 1;
    while (i < args.len) : (i += 1) {
        const a = args[i];
        if (std.mem.eql(u8, a, "--author")) {
            if (i + 1 >= args.len) {
                colors.printError("Missing value for --author\n", .{});
                return error.InvalidArgs;
            }
            author = args[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--note")) {
            if (i + 1 >= args.len) {
                colors.printError("Missing value for --note\n", .{});
                return error.InvalidArgs;
            }
            note = args[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--base")) {
            if (i + 1 >= args.len) {
                colors.printError("Missing value for --base\n", .{});
                return error.InvalidArgs;
            }
            base_override = args[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--json")) {
            json_mode = true;
        } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
            try printUsage();
            return;
        } else if (std.mem.startsWith(u8, a, "--")) {
            colors.printError("Unknown option: {s}\n", .{a});
            return error.InvalidArgs;
        } else {
            colors.printError("Unexpected argument: {s}\n", .{a});
            return error.InvalidArgs;
        }
    }
    // --note is mandatory and must not be blank/whitespace-only.
    if (note == null or std.mem.trim(u8, note.?, " \t\r\n").len == 0) {
        colors.printError("--note is required\n", .{});
        try printUsage();
        return error.InvalidArgs;
    }
    const cfg = try Config.load(allocator);
    defer {
        // Config.load returns a const value; copy so deinit can take it mutably.
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }
    // --base overrides the configured worker base path used for id lookups.
    const resolved_base = base_override orelse cfg.worker_base;
    const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| {
        if (err == error.FileNotFound) {
            colors.printError(
                "Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
                .{target},
            );
        }
        return err;
    };
    defer allocator.free(manifest_path);
    const job_name = try readJobNameFromManifest(allocator, manifest_path);
    defer allocator.free(job_name);
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    // NOTE(review): worker WebSocket port 9101 is hard-coded here — confirm it
    // matches the other commands / make it configurable if it ever changes.
    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();
    try client.sendAnnotateRun(job_name, author, note.?, api_key_hash);
    if (json_mode) {
        const msg = try client.receiveMessage(allocator);
        defer allocator.free(msg);
        // If the reply is not a valid ResponsePacket, dump the raw payload so
        // the user still sees what the server said, then fail.
        const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch {
            var out = io.stdoutWriter();
            try out.print("{s}\n", .{msg});
            return error.InvalidPacket;
        };
        defer {
            if (packet.success_message) |m| allocator.free(m);
            if (packet.error_message) |m| allocator.free(m);
            if (packet.error_details) |m| allocator.free(m);
        }
        // Shape of the machine-readable result emitted on stdout in --json mode.
        const Result = struct {
            ok: bool,
            job_name: []const u8,
            message: []const u8,
            error_code: ?u8 = null,
            error_message: ?[]const u8 = null,
            details: ?[]const u8 = null,
        };
        var out = io.stdoutWriter();
        if (packet.packet_type == .error_packet) {
            const res = Result{
                .ok = false,
                .job_name = job_name,
                .message = "",
                .error_code = @intFromEnum(packet.error_code.?),
                .error_message = packet.error_message orelse "",
                .details = packet.error_details orelse "",
            };
            try out.print("{f}\n", .{std.json.fmt(res, .{})});
            return error.CommandFailed;
        }
        const res = Result{
            .ok = true,
            .job_name = job_name,
            .message = packet.success_message orelse "",
        };
        try out.print("{f}\n", .{std.json.fmt(res, .{})});
        return;
    }
    // Non-JSON path: let the shared handler print the server response, then
    // add a human-friendly confirmation.
    try client.receiveAndHandleResponse(allocator, "Annotate");
    colors.printSuccess("Annotation added\n", .{});
    colors.printInfo("Job: {s}\n", .{job_name});
}
/// Extract `job_name` from the run manifest at `manifest_path`.
/// Returns an allocator-owned copy of the (untrimmed) name; fails with
/// `error.InvalidManifest` when the file is not a JSON object or the name is
/// missing/blank.
fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 {
    const raw = try readFileAlloc(allocator, manifest_path);
    defer allocator.free(raw);
    const doc = try std.json.parseFromSlice(std.json.Value, allocator, raw, .{});
    defer doc.deinit();
    const name = switch (doc.value) {
        .object => |obj| jsonGetString(obj, "job_name") orelse "",
        else => return error.InvalidManifest,
    };
    if (std.mem.trim(u8, name, " \t\r\n").len == 0) return error.InvalidManifest;
    return allocator.dupe(u8, name);
}
/// Resolve `input` — a directory, a manifest file path, or a run/task id —
/// to the path of a `run_manifest.json`. Directories get "run_manifest.json"
/// appended; existing files are returned as-is; anything else falls back to
/// an id scan under `base`. Caller owns the returned path.
fn resolveManifestPathWithBase(
    allocator: std.mem.Allocator,
    input: []const u8,
    base: []const u8,
) ![]u8 {
    if (std.fs.path.isAbsolute(input)) {
        // Absolute directory -> append the manifest file name.
        if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| {
            var opened_dir = dir;
            opened_dir.close();
            return std.fs.path.join(allocator, &.{ input, "run_manifest.json" });
        }
        // Absolute regular file -> assume it is the manifest itself.
        if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
            var opened_file = file;
            opened_file.close();
            return allocator.dupe(u8, input);
        }
        // Neither exists: treat the input as an id and scan the base path.
        return resolveManifestPathById(allocator, input, base);
    }
    // Relative path: stat it; a missing entry means "probably an id".
    const info = std.fs.cwd().statFile(input) catch |err| switch (err) {
        error.FileNotFound => return resolveManifestPathById(allocator, input, base),
        else => return err,
    };
    return if (info.kind == .directory)
        std.fs.path.join(allocator, &.{ input, "run_manifest.json" })
    else
        allocator.dupe(u8, input);
}
/// Scan `<base_path>/{finished,failed,running,pending}/*/run_manifest.json`
/// for a manifest whose `run_id` or `task_id` equals `id`, returning an
/// allocator-owned copy of the first matching manifest path.
///
/// Unreadable roots/files and unparsable manifests are skipped silently
/// (`catch continue`); a blank `id`, empty `base_path`, or no match all yield
/// `error.FileNotFound`.
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
    if (std.mem.trim(u8, id, " \t\r\n").len == 0) {
        return error.FileNotFound;
    }
    if (base_path.len == 0) {
        return error.FileNotFound;
    }
    // Search order doubles as priority: completed runs are found before
    // in-flight or queued ones.
    const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
    for (roots) |root| {
        const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
        defer allocator.free(root_path);
        // A missing/unreadable root directory just moves on to the next one.
        var dir = if (std.fs.path.isAbsolute(root_path))
            (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
        else
            (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
        defer dir.close();
        var it = dir.iterate();
        while (try it.next()) |entry| {
            // Each run lives in its own subdirectory; skip loose files.
            if (entry.kind != .directory) continue;
            const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name });
            defer allocator.free(run_dir);
            const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" });
            defer allocator.free(manifest_path);
            const file = if (std.fs.path.isAbsolute(manifest_path))
                (std.fs.openFileAbsolute(manifest_path, .{}) catch continue)
            else
                (std.fs.cwd().openFile(manifest_path, .{}) catch continue);
            defer file.close();
            // Manifests are small; 1 MiB cap guards against pathological files.
            const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue;
            defer allocator.free(data);
            const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue;
            defer parsed.deinit();
            if (parsed.value != .object) continue;
            const obj = parsed.value.object;
            const run_id = jsonGetString(obj, "run_id") orelse "";
            const task_id = jsonGetString(obj, "task_id") orelse "";
            // Match on either identifier; first hit wins.
            if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
                return allocator.dupe(u8, manifest_path);
            }
        }
    }
    return error.FileNotFound;
}
/// Read the entire file at `path` (absolute or cwd-relative) into an
/// allocator-owned buffer, capped at 1 MiB. Caller frees the result.
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
    const max_bytes: usize = 1024 * 1024;
    var f = blk: {
        if (std.fs.path.isAbsolute(path)) break :blk try std.fs.openFileAbsolute(path, .{});
        break :blk try std.fs.cwd().openFile(path, .{});
    };
    defer f.close();
    return f.readToEndAlloc(allocator, max_bytes);
}
/// Return the string stored under `key`, or null when the key is absent or
/// holds a non-string JSON value. The slice aliases the object's storage.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
/// Print the annotate command's usage line and examples via the colors
/// helpers (user-facing output, not machine-readable).
fn printUsage() !void {
    colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo(" ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
    colors.printInfo(" ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
}

View file

@ -1,6 +1,6 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const crypto = @import("../utils/crypto.zig");
const logging = @import("../utils/logging.zig");
const colors = @import("../utils/colors.zig");

View file

@ -1,6 +1,6 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const colors = @import("../utils/colors.zig");
const logging = @import("../utils/logging.zig");
const crypto = @import("../utils/crypto.zig");
@ -18,7 +18,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
var options = DatasetOptions{};
// Parse global flags: --dry-run, --validate, --json
var positional = std.ArrayList([]const u8).initCapacity(allocator, args.len) catch |err| {
return err;
@ -44,36 +43,39 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
}
if (positional.items.len == 0) {
printUsage();
return error.InvalidArgs;
}
const action = positional.items[0];
if (std.mem.eql(u8, action, "list")) {
try listDatasets(allocator, &options);
} else if (std.mem.eql(u8, action, "register")) {
if (positional.items.len < 3) {
colors.printError("Usage: ml dataset register <name> <url>\n", .{});
switch (positional.items.len) {
0 => {
printUsage();
return error.InvalidArgs;
}
try registerDataset(allocator, positional.items[1], positional.items[2], &options);
} else if (std.mem.eql(u8, action, "info")) {
if (positional.items.len < 2) {
colors.printError("Usage: ml dataset info <name>\n", .{});
},
1 => {
if (std.mem.eql(u8, action, "list")) {
try listDatasets(allocator, &options);
return error.InvalidArgs;
}
},
2 => {
if (std.mem.eql(u8, action, "info")) {
try showDatasetInfo(allocator, positional.items[1], &options);
return;
} else if (std.mem.eql(u8, action, "search")) {
try searchDatasets(allocator, positional.items[1], &options);
return error.InvalidArgs;
}
},
3 => {
if (std.mem.eql(u8, action, "register")) {
try registerDataset(allocator, positional.items[1], positional.items[2], &options);
return error.InvalidArgs;
}
},
else => {
colors.printError("Unknoen action: {s}\n", .{action});
printUsage();
return error.InvalidArgs;
}
try showDatasetInfo(allocator, positional.items[1], &options);
} else if (std.mem.eql(u8, action, "search")) {
if (positional.items.len < 2) {
colors.printError("Usage: ml dataset search <term>\n", .{});
return error.InvalidArgs;
}
try searchDatasets(allocator, positional.items[1], &options);
} else {
colors.printError("Unknown action: {s}\n", .{action});
printUsage();
return error.InvalidArgs;
},
}
}

View file

@ -1,6 +1,6 @@
const std = @import("std");
const config = @import("../config.zig");
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const history = @import("../utils/history.zig");
const colors = @import("../utils/colors.zig");

View file

@ -1,6 +1,7 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const io = @import("../utils/io.zig");
pub const Options = struct {
json: bool = false,
@ -61,7 +62,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
defer allocator.free(data);
if (opts.json) {
std.debug.print("{s}\n", .{data});
var out = io.stdoutWriter();
try out.print("{s}\n", .{data});
return;
}

View file

@ -1,6 +1,6 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const protocol = @import("../net/protocol.zig");
const crypto = @import("../utils/crypto.zig");
const Config = @import("../config.zig").Config;
@ -87,7 +87,9 @@ fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void
.error_packet => {
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
colors.printError("Failed to restore workspace: {s}\n", .{error_msg});
if (packet.error_message) |msg| {
if (packet.error_details) |details| {
colors.printError("Details: {s}\n", .{details});
} else if (packet.error_message) |msg| {
colors.printError("Details: {s}\n", .{msg});
}
},
@ -138,8 +140,6 @@ pub fn isValidTopLevelAction(action: []const u8) bool {
std.mem.eql(u8, action, "list") or
std.mem.eql(u8, action, "remove") or
std.mem.eql(u8, action, "restore") or
std.mem.eql(u8, action, "workspace") or
std.mem.eql(u8, action, "experiment") or
std.mem.eql(u8, action, "package");
}
@ -149,18 +149,18 @@ pub fn defaultWorkspacePath(allocator: std.mem.Allocator, name: []const u8) ![]u
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len < 1) {
printUsage();
printUsagePackage();
return;
}
for (args) |arg| {
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
printUsage();
printUsagePackage();
return;
}
if (std.mem.eql(u8, arg, "--json")) {
colors.printError("jupyter does not support --json\n", .{});
printUsage();
printUsagePackage();
return error.InvalidArgs;
}
}
@ -181,10 +181,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
try removeJupyter(allocator, args[1..]);
} else if (std.mem.eql(u8, action, "restore")) {
try restoreJupyter(allocator, args[1..]);
} else if (std.mem.eql(u8, action, "workspace")) {
try workspaceCommands(args[1..]);
} else if (std.mem.eql(u8, action, "experiment")) {
try experimentCommands(args[1..]);
} else if (std.mem.eql(u8, action, "package")) {
try packageCommands(args[1..]);
} else {
@ -192,12 +188,11 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
}
fn printUsage() void {
colors.printError("Usage: ml jupyter <action> [options]\n", .{});
colors.printInfo("\nActions:\n", .{});
colors.printInfo(" create|start|stop|status|list|remove|restore\n", .{});
colors.printInfo(" workspace|experiment|package\n", .{});
colors.printInfo("\nOptions:\n", .{});
fn printUsagePackage() void {
colors.printError("Usage: ml jupyter package <action> [options]\n", .{});
colors.printInfo("Actions:\n", .{});
colors.printInfo(" list\n", .{});
colors.printInfo("Options:\n", .{});
colors.printInfo(" --help, -h Show this help message\n", .{});
}
@ -340,7 +335,9 @@ fn startJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
.error_packet => {
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
colors.printError("Failed to start service: {s}\n", .{error_msg});
if (packet.error_message) |msg| {
if (packet.error_details) |details| {
colors.printError("Details: {s}\n", .{details});
} else if (packet.error_message) |msg| {
colors.printError("Details: {s}\n", .{msg});
}
},
@ -416,7 +413,9 @@ fn stopJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
.error_packet => {
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
colors.printError("Failed to stop service: {s}\n", .{error_msg});
if (packet.error_message) |msg| {
if (packet.error_details) |details| {
colors.printError("Details: {s}\n", .{details});
} else if (packet.error_message) |msg| {
colors.printError("Details: {s}\n", .{msg});
}
},
@ -533,7 +532,9 @@ fn removeJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
.error_packet => {
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
colors.printError("Failed to remove service: {s}\n", .{error_msg});
if (packet.error_message) |msg| {
if (packet.error_details) |details| {
colors.printError("Details: {s}\n", .{details});
} else if (packet.error_message) |msg| {
colors.printError("Details: {s}\n", .{msg});
}
},
@ -662,7 +663,9 @@ fn listServices(allocator: std.mem.Allocator) !void {
.error_packet => {
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
colors.printError("Failed to list services: {s}\n", .{error_msg});
if (packet.error_message) |msg| {
if (packet.error_details) |details| {
colors.printError("Details: {s}\n", .{details});
} else if (packet.error_message) |msg| {
colors.printError("Details: {s}\n", .{msg});
}
},
@ -776,86 +779,133 @@ fn experimentCommands(args: []const []const u8) !void {
fn packageCommands(args: []const []const u8) !void {
if (args.len < 1) {
colors.printError("Usage: ml jupyter package <install|list|pending|approve|reject>\n", .{});
colors.printError("Usage: ml jupyter package <list>\n", .{});
return;
}
const subcommand = args[0];
if (std.mem.eql(u8, subcommand, "install")) {
if (std.mem.eql(u8, subcommand, "list")) {
if (args.len < 2) {
colors.printError("Usage: ml jupyter package install --package <name> [--channel <channel>] [--version <version>]\n", .{});
colors.printError("Usage: ml jupyter package list <service-name>\n", .{});
return;
}
// Parse package name from args
var package_name: []const u8 = "";
var channel: []const u8 = "conda-forge";
var version: []const u8 = "latest";
var i: usize = 0;
while (i < args.len) {
if (std.mem.eql(u8, args[i], "--package") and i + 1 < args.len) {
package_name = args[i + 1];
i += 2;
} else if (std.mem.eql(u8, args[i], "--channel") and i + 1 < args.len) {
channel = args[i + 1];
i += 2;
} else if (std.mem.eql(u8, args[i], "--version") and i + 1 < args.len) {
version = args[i + 1];
i += 2;
} else {
i += 1;
}
var service_name: []const u8 = "";
if (std.mem.eql(u8, args[1], "--name") and args.len >= 3) {
service_name = args[2];
} else {
service_name = args[1];
}
if (package_name.len == 0) {
colors.printError("Package name is required\n", .{});
if (service_name.len == 0) {
colors.printError("Service name is required\n", .{});
return;
}
// Security validations
if (!validatePackageName(package_name)) {
colors.printError("Invalid package name: {s}. Only alphanumeric characters, underscores, hyphens, and dots are allowed.\n", .{package_name});
return;
const allocator = std.heap.page_allocator;
const config = try Config.load(allocator);
defer {
var mut_config = config;
mut_config.deinit(allocator);
}
if (isPackageBlocked(package_name)) {
colors.printError("Package '{s}' is blocked by security policy for security reasons.\n", .{package_name});
colors.printInfo("Blocked packages typically include network libraries that could be used for unauthorized data access.\n", .{});
const protocol_str = if (config.worker_port == 443) "wss" else "ws";
const url = try std.fmt.allocPrint(allocator, "{s}://{s}:{d}/ws", .{
protocol_str,
config.worker_host,
config.worker_port,
});
defer allocator.free(url);
var client = ws.Client.connect(allocator, url, config.api_key) catch |err| {
colors.printError("Failed to connect to server: {}\n", .{err});
return;
};
defer client.close();
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
defer allocator.free(api_key_hash);
client.sendListJupyterPackages(service_name, api_key_hash) catch |err| {
colors.printError("Failed to send list packages command: {}\n", .{err});
return;
};
const response = client.receiveMessage(allocator) catch |err| {
colors.printError("Failed to receive response: {}\n", .{err});
return;
};
defer allocator.free(response);
const packet = protocol.ResponsePacket.deserialize(response, allocator) catch |err| {
colors.printError("Failed to parse response: {}\n", .{err});
return;
};
defer {
if (packet.data_type) |dtype| allocator.free(dtype);
if (packet.data_payload) |payload| allocator.free(payload);
if (packet.error_message) |msg| allocator.free(msg);
if (packet.error_details) |details| allocator.free(details);
}
if (!validateChannel(channel)) {
colors.printError("Channel '{s}' is not trusted. Allowed channels: conda-forge, defaults, pytorch, nvidia\n", .{channel});
return;
}
switch (packet.packet_type) {
.data => {
colors.printInfo("Installed packages for {s}:\n", .{service_name});
if (packet.data_payload) |payload| {
const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch {
std.debug.print("{s}\n", .{payload});
return;
};
defer parsed.deinit();
colors.printInfo("Requesting package installation...\n", .{});
colors.printInfo("Package: {s}\n", .{package_name});
colors.printInfo("Version: {s}\n", .{version});
colors.printInfo("Channel: {s}\n", .{channel});
colors.printInfo("Security: Package validated against security policies\n", .{});
colors.printSuccess("Package request created successfully!\n", .{});
colors.printInfo("Note: Package requires approval from administrator before installation.\n", .{});
} else if (std.mem.eql(u8, subcommand, "list")) {
colors.printInfo("Installed packages in workspace: ./workspace\n", .{});
colors.printInfo("Package Name Version Channel Installed By\n", .{});
colors.printInfo("------------ ------- ------- ------------\n", .{});
colors.printInfo("numpy 1.21.0 conda-forge user1\n", .{});
colors.printInfo("pandas 1.3.0 conda-forge user1\n", .{});
} else if (std.mem.eql(u8, subcommand, "pending")) {
colors.printInfo("Pending package requests for workspace: ./workspace\n", .{});
colors.printInfo("Package Name Version Channel Requested By Time\n", .{});
colors.printInfo("------------ ------- ------- ------------ ----\n", .{});
colors.printInfo("torch 1.9.0 pytorch user3 2023-12-06 10:30\n", .{});
} else if (std.mem.eql(u8, subcommand, "approve")) {
colors.printInfo("Approving package request: torch\n", .{});
colors.printSuccess("Package request approved!\n", .{});
} else if (std.mem.eql(u8, subcommand, "reject")) {
colors.printInfo("Rejecting package request: suspicious-package\n", .{});
colors.printInfo("Reason: Security policy violation\n", .{});
colors.printSuccess("Package request rejected!\n", .{});
if (parsed.value != .array) {
std.debug.print("{s}\n", .{payload});
return;
}
const pkgs = parsed.value.array;
if (pkgs.items.len == 0) {
colors.printInfo("No packages found.\n", .{});
return;
}
colors.printInfo("NAME VERSION SOURCE\n", .{});
colors.printInfo("---- ------- ------\n", .{});
for (pkgs.items) |item| {
if (item != .object) continue;
const obj = item.object;
var name: []const u8 = "";
if (obj.get("name")) |v| {
if (v == .string) name = v.string;
}
var version: []const u8 = "";
if (obj.get("version")) |v| {
if (v == .string) version = v.string;
}
var source: []const u8 = "";
if (obj.get("source")) |v| {
if (v == .string) source = v.string;
}
std.debug.print("{s: <30} {s: <22} {s}\n", .{ name, version, source });
}
}
},
.error_packet => {
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
colors.printError("Failed to list packages: {s}\n", .{error_msg});
if (packet.error_details) |details| {
colors.printError("Details: {s}\n", .{details});
} else if (packet.error_message) |msg| {
colors.printError("Details: {s}\n", .{msg});
}
},
else => {
colors.printError("Unexpected response type\n", .{});
},
}
} else {
colors.printError("Invalid package command: {s}\n", .{subcommand});
}

View file

@ -0,0 +1,370 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
const ws = @import("../net/ws/client.zig");
/// Entry point for `ml narrative set <target> [fields]`.
///
/// Parses narrative flags, resolves `target` (a path, run_id, or task_id) to a
/// run_manifest.json, reads the job name from it, builds a JSON patch of the
/// provided fields, and sends it to the worker over WebSocket. With `--json`
/// the raw response packet is re-emitted as a machine-readable JSON result.
///
/// Errors: `error.InvalidArgs` for usage problems, `error.CommandFailed` when
/// the worker returns an error packet in --json mode, plus I/O/network errors.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    const sub = argv[0];
    if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) {
        try printUsage();
        return;
    }
    // Only the `set` subcommand exists today.
    if (!std.mem.eql(u8, sub, "set")) {
        colors.printError("Unknown subcommand: {s}\n", .{sub});
        try printUsage();
        return error.InvalidArgs;
    }
    if (argv.len < 2) {
        try printUsage();
        return error.InvalidArgs;
    }
    // The run to annotate: a filesystem path, a run_id, or a task_id.
    const target = argv[1];
    // Optional narrative fields; all are borrowed slices into argv.
    var hypothesis: ?[]const u8 = null;
    var context: ?[]const u8 = null;
    var intent: ?[]const u8 = null;
    var expected_outcome: ?[]const u8 = null;
    var parent_run: ?[]const u8 = null;
    var experiment_group: ?[]const u8 = null;
    var tags_csv: ?[]const u8 = null;
    var base_override: ?[]const u8 = null;
    var json_mode: bool = false;
    // Flags start after `set <target>`; each value-taking flag consumes the
    // following argv entry (hence the extra `i += 1`).
    var i: usize = 2;
    while (i < argv.len) : (i += 1) {
        const a = argv[i];
        if (std.mem.eql(u8, a, "--hypothesis")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            hypothesis = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--context")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            context = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--intent")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            intent = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--expected-outcome")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            expected_outcome = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--parent-run")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            parent_run = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--experiment-group")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            experiment_group = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--tags")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            tags_csv = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--base")) {
            if (i + 1 >= argv.len) return error.InvalidArgs;
            base_override = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--json")) {
            json_mode = true;
        } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
            try printUsage();
            return;
        } else {
            colors.printError("Unknown option: {s}\n", .{a});
            return error.InvalidArgs;
        }
    }
    // Require at least one narrative field so we never send an empty patch.
    if (hypothesis == null and context == null and intent == null and expected_outcome == null and parent_run == null and experiment_group == null and tags_csv == null) {
        colors.printError("No narrative fields provided.\n", .{});
        return error.InvalidArgs;
    }
    const cfg = try Config.load(allocator);
    defer {
        // deinit needs a mutable receiver; cfg itself is const.
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }
    // --base overrides the configured worker base path for id scanning.
    const resolved_base = base_override orelse cfg.worker_base;
    const manifest_path = resolveManifestPathWithBase(allocator, target, resolved_base) catch |err| {
        if (err == error.FileNotFound) {
            colors.printError(
                "Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
                .{target},
            );
        }
        return err;
    };
    defer allocator.free(manifest_path);
    const job_name = try readJobNameFromManifest(allocator, manifest_path);
    defer allocator.free(job_name);
    // Serialize the provided fields (and split tags) into a JSON patch body.
    const patch_json = try buildPatchJSON(
        allocator,
        hypothesis,
        context,
        intent,
        expected_outcome,
        parent_run,
        experiment_group,
        tags_csv,
    );
    defer allocator.free(patch_json);
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    // NOTE(review): endpoint is hard-coded to ws://<host>:9101/ws; other
    // commands derive ws/wss from cfg.worker_port — confirm this is intended.
    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();
    try client.sendSetRunNarrative(job_name, patch_json, api_key_hash);
    if (json_mode) {
        // --json: surface the raw worker response as a structured JSON object.
        const msg = try client.receiveMessage(allocator);
        defer allocator.free(msg);
        const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch {
            // Undecodable packet: dump the raw payload so nothing is lost.
            var out = io.stdoutWriter();
            try out.print("{s}\n", .{msg});
            return error.InvalidPacket;
        };
        defer {
            // Packet string fields are allocator-owned copies; free them all.
            if (packet.success_message) |m| allocator.free(m);
            if (packet.error_message) |m| allocator.free(m);
            if (packet.error_details) |m| allocator.free(m);
        }
        // Shape of the JSON result printed to stdout.
        const Result = struct {
            ok: bool,
            job_name: []const u8,
            message: []const u8,
            error_code: ?u8 = null,
            error_message: ?[]const u8 = null,
            details: ?[]const u8 = null,
        };
        var out = io.stdoutWriter();
        if (packet.packet_type == .error_packet) {
            const res = Result{
                .ok = false,
                .job_name = job_name,
                .message = "",
                .error_code = @intFromEnum(packet.error_code.?),
                .error_message = packet.error_message orelse "",
                .details = packet.error_details orelse "",
            };
            try out.print("{f}\n", .{std.json.fmt(res, .{})});
            return error.CommandFailed;
        }
        const res = Result{
            .ok = true,
            .job_name = job_name,
            .message = packet.success_message orelse "",
        };
        try out.print("{f}\n", .{std.json.fmt(res, .{})});
        return;
    }
    // Human-readable path: shared helper prints success/error packets.
    try client.receiveAndHandleResponse(allocator, "Narrative");
    colors.printSuccess("Narrative updated\n", .{});
    colors.printInfo("Job: {s}\n", .{job_name});
}
/// Serialize the provided narrative fields into an allocator-owned JSON
/// object. `tags_csv` is split on commas; entries are whitespace-trimmed and
/// empty entries dropped. Fields left null are serialized as JSON nulls by
/// std.json; caller frees the returned slice.
fn buildPatchJSON(
    allocator: std.mem.Allocator,
    hypothesis: ?[]const u8,
    context: ?[]const u8,
    intent: ?[]const u8,
    expected_outcome: ?[]const u8,
    parent_run: ?[]const u8,
    experiment_group: ?[]const u8,
    tags_csv: ?[]const u8,
) ![]u8 {
    // Gather trimmed, non-empty tag slices (borrowed from tags_csv).
    var tags = try std.ArrayList([]const u8).initCapacity(allocator, 8);
    defer tags.deinit(allocator);
    if (tags_csv) |csv| {
        var parts = std.mem.splitScalar(u8, csv, ',');
        while (parts.next()) |raw| {
            const tag = std.mem.trim(u8, raw, " \t\r\n");
            if (tag.len != 0) try tags.append(allocator, tag);
        }
    }
    // Mirror of the wire patch shape; std.json.fmt handles escaping.
    const Patch = struct {
        hypothesis: ?[]const u8 = null,
        context: ?[]const u8 = null,
        intent: ?[]const u8 = null,
        expected_outcome: ?[]const u8 = null,
        parent_run: ?[]const u8 = null,
        experiment_group: ?[]const u8 = null,
        tags: ?[]const []const u8 = null,
    };
    var buf = try std.ArrayList(u8).initCapacity(allocator, 256);
    defer buf.deinit(allocator);
    try buf.writer(allocator).print("{f}", .{std.json.fmt(Patch{
        .hypothesis = hypothesis,
        .context = context,
        .intent = intent,
        .expected_outcome = expected_outcome,
        .parent_run = parent_run,
        .experiment_group = experiment_group,
        .tags = if (tags.items.len > 0) tags.items else null,
    }, .{})});
    return buf.toOwnedSlice(allocator);
}
/// Read `job_name` out of the run manifest at `manifest_path`.
/// The manifest must be a JSON object with a non-blank `job_name` string;
/// otherwise `error.InvalidManifest`. Caller frees the returned copy.
fn readJobNameFromManifest(allocator: std.mem.Allocator, manifest_path: []const u8) ![]u8 {
    const raw = try readFileAlloc(allocator, manifest_path);
    defer allocator.free(raw);
    const doc = try std.json.parseFromSlice(std.json.Value, allocator, raw, .{});
    defer doc.deinit();
    if (doc.value != .object) return error.InvalidManifest;
    const name = jsonGetString(doc.value.object, "job_name") orelse "";
    // A job name consisting only of whitespace is as useless as a missing one.
    if (std.mem.trim(u8, name, " \t\r\n").len == 0) return error.InvalidManifest;
    return allocator.dupe(u8, name);
}
/// Resolve `input` to a run_manifest.json path.
///
/// - existing directory  -> "<input>/run_manifest.json"
/// - existing file       -> duplicate of `input`
/// - neither             -> treat `input` as a run_id/task_id and scan `base`
///   (see resolveManifestPathById).
///
/// Caller frees the returned path.
///
/// Fix: the original bound `var cwd` and `var mutable_file` without ever
/// mutating them; modern Zig rejects never-mutated `var` locals
/// ("local variable is never mutated"). Both are now const/inlined.
/// (`Dir.close` takes `*Dir`, so the mutable dir copy is still required.)
fn resolveManifestPathWithBase(allocator: std.mem.Allocator, input: []const u8, base: []const u8) ![]u8 {
    if (std.fs.path.isAbsolute(input)) {
        // Absolute directory: the manifest is expected directly inside it.
        if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| {
            var mutable_dir = dir; // Dir.close requires a mutable pointer.
            defer mutable_dir.close();
            return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
        }
        // Absolute file: the path itself is taken to be the manifest.
        if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
            defer file.close();
            return allocator.dupe(u8, input);
        }
        // Not openable as dir or file: fall back to id-based lookup.
        return resolveManifestPathById(allocator, input, base);
    }
    // Relative path: a missing entry likely means `input` is a run/task id.
    const stat = std.fs.cwd().statFile(input) catch |err| {
        if (err == error.FileNotFound) {
            return resolveManifestPathById(allocator, input, base);
        }
        return err;
    };
    if (stat.kind == .directory) {
        return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
    }
    return allocator.dupe(u8, input);
}
/// Find the run_manifest.json whose `run_id` or `task_id` equals `id` by
/// scanning <base_path>/{finished,failed,running,pending}/<run-dir>/.
/// Returns an allocator-owned path, or error.FileNotFound when `id` is blank,
/// `base_path` is empty, or no manifest matches. Unreadable roots, entries,
/// and malformed manifests are silently skipped (best-effort scan).
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
    if (std.mem.trim(u8, id, " \t\r\n").len == 0) {
        return error.FileNotFound;
    }
    if (base_path.len == 0) {
        return error.FileNotFound;
    }
    // Queue state directories searched, in priority order.
    const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
    for (roots) |root| {
        const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
        defer allocator.free(root_path);
        // `catch continue` skips to the next root when this one can't be opened.
        var dir = if (std.fs.path.isAbsolute(root_path))
            (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
        else
            (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
        defer dir.close();
        var it = dir.iterate();
        while (try it.next()) |entry| {
            if (entry.kind != .directory) continue;
            const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name });
            defer allocator.free(run_dir);
            const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" });
            defer allocator.free(manifest_path);
            // Any failure to open/read/parse this candidate skips it.
            const file = if (std.fs.path.isAbsolute(manifest_path))
                (std.fs.openFileAbsolute(manifest_path, .{}) catch continue)
            else
                (std.fs.cwd().openFile(manifest_path, .{}) catch continue);
            defer file.close();
            // Manifests larger than 1 MiB are treated as unreadable.
            const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue;
            defer allocator.free(data);
            const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue;
            defer parsed.deinit();
            if (parsed.value != .object) continue;
            const obj = parsed.value.object;
            const run_id = jsonGetString(obj, "run_id") orelse "";
            const task_id = jsonGetString(obj, "task_id") orelse "";
            // First match wins; dupe before manifest_path's defer frees it.
            if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
                return allocator.dupe(u8, manifest_path);
            }
        }
    }
    return error.FileNotFound;
}
/// Read an entire file (absolute or cwd-relative path) into an
/// allocator-owned buffer, capped at 1 MiB. Caller frees the result.
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
    const limit: usize = 1024 * 1024; // same 1 MiB cap for both branches
    if (std.fs.path.isAbsolute(path)) {
        const file = try std.fs.openFileAbsolute(path, .{});
        defer file.close();
        return file.readToEndAlloc(allocator, limit);
    }
    const file = try std.fs.cwd().openFile(path, .{});
    defer file.close();
    return file.readToEndAlloc(allocator, limit);
}
/// Return the string value stored under `key`, or null when the key is
/// absent or holds a non-string JSON value. The slice borrows from `obj`.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
/// Print usage/help text for `ml narrative` to the info channel.
/// Returns !void to match the `try printUsage()` call sites, although the
/// colors helpers themselves do not error here.
fn printUsage() !void {
    colors.printInfo("Usage: ml narrative set <path|run_id|task_id> [fields]\n", .{});
    colors.printInfo("\nFields:\n", .{});
    colors.printInfo("  --hypothesis \"...\"\n", .{});
    colors.printInfo("  --context \"...\"\n", .{});
    colors.printInfo("  --intent \"...\"\n", .{});
    colors.printInfo("  --expected-outcome \"...\"\n", .{});
    colors.printInfo("  --parent-run <id>\n", .{});
    colors.printInfo("  --experiment-group <name>\n", .{});
    colors.printInfo("  --tags a,b,c\n", .{});
    colors.printInfo("  --base <path>\n", .{});
    colors.printInfo("  --json\n", .{});
}

View file

@ -1,6 +1,6 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const crypto = @import("../utils/crypto.zig");
const logging = @import("../utils/logging.zig");

View file

@ -1,6 +1,6 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const colors = @import("../utils/colors.zig");
const history = @import("../utils/history.zig");
const crypto = @import("../utils/crypto.zig");
@ -42,6 +42,43 @@ pub const QueueOptions = struct {
gpu_memory: ?[]const u8 = null,
};
/// Resolve a commit id given either a full 40-char hex string or a unique
/// hex prefix (>= 7 chars), by matching 40-char hex directory names under
/// `base_path`. Returns an allocator-owned full hex id.
///
/// Errors: error.InvalidArgs for non-hex/too-short/too-long input or an
/// ambiguous prefix (two or more matches); error.FileNotFound when no
/// directory matches. Directory open errors propagate.
fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
    // Accept only hex prefixes between 7 and 40 characters.
    if (input.len < 7 or input.len > 40) return error.InvalidArgs;
    for (input) |c| {
        if (!std.ascii.isHex(c)) return error.InvalidArgs;
    }
    // A full-length id needs no directory scan.
    if (input.len == 40) {
        return allocator.dupe(u8, input);
    }
    var dir = if (std.fs.path.isAbsolute(base_path))
        try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
    else
        try std.fs.cwd().openDir(base_path, .{ .iterate = true });
    defer dir.close();
    var it = dir.iterate();
    var found: ?[]u8 = null;
    // On any error return below, release the candidate we already duped.
    errdefer if (found) |s| allocator.free(s);
    while (try it.next()) |entry| {
        if (entry.kind != .directory) continue;
        const name = entry.name;
        if (name.len != 40) continue;
        if (!std.mem.startsWith(u8, name, input)) continue;
        // Loop-else runs only when no `break` fired, i.e. all chars are hex.
        for (name) |c| {
            if (!std.ascii.isHex(c)) break;
        } else {
            // A second match makes the prefix ambiguous.
            if (found != null) return error.InvalidArgs;
            found = try allocator.dupe(u8, name);
        }
    }
    if (found) |s| return s;
    return error.FileNotFound;
}
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {
try printUsage();
@ -64,6 +101,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var priority: u8 = 5;
var snapshot_id: ?[]const u8 = null;
var snapshot_sha256: ?[]const u8 = null;
var args_override: ?[]const u8 = null;
var note_override: ?[]const u8 = null;
// Load configuration to get defaults
const config = try Config.load(allocator);
@ -88,10 +127,33 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var tracking = TrackingConfig{};
var has_tracking = false;
// Support passing runner args after "--".
var sep_index: ?usize = null;
for (args, 0..) |a, idx| {
if (std.mem.eql(u8, a, "--")) {
sep_index = idx;
break;
}
}
const pre = args[0..(sep_index orelse args.len)];
const post = if (sep_index) |si| args[(si + 1)..] else args[0..0];
var args_joined: []const u8 = "";
if (post.len > 0) {
var buf: std.ArrayList(u8) = .{};
defer buf.deinit(allocator);
for (post, 0..) |a, j| {
if (j > 0) try buf.append(allocator, ' ');
try buf.appendSlice(allocator, a);
}
args_joined = try buf.toOwnedSlice(allocator);
}
defer if (post.len > 0) allocator.free(args_joined);
// Parse arguments - separate job names from flags
var i: usize = 0;
while (i < args.len) : (i += 1) {
const arg = args[i];
while (i < pre.len) : (i += 1) {
const arg = pre[i];
if (std.mem.startsWith(u8, arg, "--") or std.mem.eql(u8, arg, "-h")) {
// Parse flags
@ -99,15 +161,21 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
try printUsage();
return;
}
if (std.mem.eql(u8, arg, "--commit") and i + 1 < args.len) {
if (std.mem.eql(u8, arg, "--commit") and i + 1 < pre.len) {
if (commit_id_override != null) {
allocator.free(commit_id_override.?);
}
const commit_hex = args[i + 1];
if (commit_hex.len != 40) {
colors.printError("Invalid commit id: expected 40-char hex string\n", .{});
const commit_in = pre[i + 1];
const commit_hex = resolveCommitHexOrPrefix(allocator, config.worker_base, commit_in) catch |err| {
if (err == error.FileNotFound) {
colors.printError("No commit matches prefix: {s}\n", .{commit_in});
return error.InvalidArgs;
}
colors.printError("Invalid commit id\n", .{});
return error.InvalidArgs;
}
};
defer allocator.free(commit_hex);
const commit_bytes = crypto.decodeHex(allocator, commit_hex) catch {
colors.printError("Invalid commit id: must be hex\n", .{});
return error.InvalidArgs;
@ -119,35 +187,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
commit_id_override = commit_bytes;
i += 1;
} else if (std.mem.eql(u8, arg, "--priority") and i + 1 < args.len) {
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
} else if (std.mem.eql(u8, arg, "--priority") and i + 1 < pre.len) {
priority = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, arg, "--mlflow")) {
tracking.mlflow = TrackingConfig.MLflowConfig{};
has_tracking = true;
} else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < args.len) {
} else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < pre.len) {
tracking.mlflow = TrackingConfig.MLflowConfig{
.mode = "remote",
.tracking_uri = args[i + 1],
.tracking_uri = pre[i + 1],
};
has_tracking = true;
i += 1;
} else if (std.mem.eql(u8, arg, "--tensorboard")) {
tracking.tensorboard = TrackingConfig.TensorBoardConfig{};
has_tracking = true;
} else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < args.len) {
} else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < pre.len) {
if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
tracking.wandb.?.api_key = args[i + 1];
tracking.wandb.?.api_key = pre[i + 1];
has_tracking = true;
i += 1;
} else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < args.len) {
} else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < pre.len) {
if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
tracking.wandb.?.project = args[i + 1];
tracking.wandb.?.project = pre[i + 1];
has_tracking = true;
i += 1;
} else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < args.len) {
} else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < pre.len) {
if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
tracking.wandb.?.entity = args[i + 1];
tracking.wandb.?.entity = pre[i + 1];
has_tracking = true;
i += 1;
} else if (std.mem.eql(u8, arg, "--dry-run")) {
@ -158,23 +226,29 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
options.explain = true;
} else if (std.mem.eql(u8, arg, "--json")) {
options.json = true;
} else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < args.len) {
options.cpu = try std.fmt.parseInt(u8, args[i + 1], 10);
} else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < pre.len) {
options.cpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, arg, "--memory") and i + 1 < args.len) {
options.memory = try std.fmt.parseInt(u8, args[i + 1], 10);
} else if (std.mem.eql(u8, arg, "--memory") and i + 1 < pre.len) {
options.memory = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < args.len) {
options.gpu = try std.fmt.parseInt(u8, args[i + 1], 10);
} else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < pre.len) {
options.gpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < args.len) {
options.gpu_memory = args[i + 1];
} else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < pre.len) {
options.gpu_memory = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < args.len) {
snapshot_id = args[i + 1];
} else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < pre.len) {
snapshot_id = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < args.len) {
snapshot_sha256 = args[i + 1];
} else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < pre.len) {
snapshot_sha256 = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--args") and i + 1 < pre.len) {
args_override = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) {
note_override = pre[i + 1];
i += 1;
}
} else {
@ -227,10 +301,25 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
defer if (commit_id_override) |cid| allocator.free(cid);
const args_str: []const u8 = if (args_override) |a| a else args_joined;
const note_str: []const u8 = if (note_override) |n| n else "";
for (job_names.items, 0..) |job_name, index| {
colors.printInfo("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name });
queueSingleJob(allocator, job_name, commit_id_override, priority, tracking_json, &options, snapshot_id, snapshot_sha256, print_next_steps) catch |err| {
queueSingleJob(
allocator,
job_name,
commit_id_override,
priority,
tracking_json,
&options,
snapshot_id,
snapshot_sha256,
args_str,
note_str,
print_next_steps,
) catch |err| {
colors.printError("Failed to queue job '{s}': {}\n", .{ job_name, err });
failed_jobs.append(allocator, job_name) catch |append_err| {
colors.printError("Failed to track failed job: {}\n", .{append_err});
@ -274,6 +363,8 @@ fn queueSingleJob(
options: *const QueueOptions,
snapshot_id: ?[]const u8,
snapshot_sha256: ?[]const u8,
args_str: []const u8,
note_str: []const u8,
print_next_steps: bool,
) !void {
const commit_id = blk: {
@ -324,6 +415,33 @@ fn queueSingleJob(
options.gpu,
options.gpu_memory,
);
} else if (note_str.len > 0 or args_str.len > 0) {
if (note_str.len > 0) {
try client.sendQueueJobWithArgsNoteAndResources(
job_name,
commit_id,
priority,
api_key_hash,
args_str,
note_str,
options.cpu,
options.memory,
options.gpu,
options.gpu_memory,
);
} else {
try client.sendQueueJobWithArgsAndResources(
job_name,
commit_id,
priority,
api_key_hash,
args_str,
options.cpu,
options.memory,
options.gpu,
options.gpu_memory,
);
}
} else if (snapshot_id) |sid| {
try client.sendQueueJobWithSnapshotAndResources(
job_name,
@ -374,6 +492,9 @@ fn printUsage() !void {
colors.printInfo(" --memory <gb> Memory in GB (default: 8)\n", .{});
colors.printInfo(" --gpu <count> GPU count (default: 0)\n", .{});
colors.printInfo(" --gpu-memory <gb> GPU memory budget (default: auto)\n", .{});
colors.printInfo(" --args <string> Extra runner args (sent to worker as task.Args)\n", .{});
colors.printInfo(" --note <string> Human notes (stored in run manifest as metadata.note)\n", .{});
colors.printInfo(" -- <args...> Extra runner args (alternative to --args)\n", .{});
colors.printInfo("\nSpecial Modes:\n", .{});
colors.printInfo(" --dry-run Show what would be submitted\n", .{});
colors.printInfo(" --validate Validate experiment without submitting\n", .{});

View file

@ -0,0 +1,345 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const ws = @import("../net/ws/client.zig");
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
if (argv.len == 0) {
try printUsage();
return error.InvalidArgs;
}
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
try printUsage();
return;
}
const target = argv[0];
// Split args at "--".
var sep_index: ?usize = null;
for (argv, 0..) |a, i| {
if (std.mem.eql(u8, a, "--")) {
sep_index = i;
break;
}
}
const pre = argv[1..(sep_index orelse argv.len)];
const post = if (sep_index) |i| argv[(i + 1)..] else argv[0..0];
const cfg = try Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
// Defaults
var job_name_override: ?[]const u8 = null;
var priority: u8 = cfg.default_priority;
var cpu: u8 = cfg.default_cpu;
var memory: u8 = cfg.default_memory;
var gpu: u8 = cfg.default_gpu;
var gpu_memory: ?[]const u8 = cfg.default_gpu_memory;
var args_override: ?[]const u8 = null;
var note_override: ?[]const u8 = null;
var i: usize = 0;
while (i < pre.len) : (i += 1) {
const a = pre[i];
if (std.mem.eql(u8, a, "--name") and i + 1 < pre.len) {
job_name_override = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, a, "--priority") and i + 1 < pre.len) {
priority = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, a, "--cpu") and i + 1 < pre.len) {
cpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, a, "--memory") and i + 1 < pre.len) {
memory = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, a, "--gpu") and i + 1 < pre.len) {
gpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, a, "--gpu-memory") and i + 1 < pre.len) {
gpu_memory = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, a, "--args") and i + 1 < pre.len) {
args_override = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, a, "--note") and i + 1 < pre.len) {
note_override = pre[i + 1];
i += 1;
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
try printUsage();
return;
} else {
colors.printError("Unknown option: {s}\n", .{a});
return error.InvalidArgs;
}
}
var args_joined: []const u8 = "";
if (post.len > 0) {
var buf: std.ArrayList(u8) = .{};
defer buf.deinit(allocator);
for (post, 0..) |a, idx| {
if (idx > 0) try buf.append(allocator, ' ');
try buf.appendSlice(allocator, a);
}
args_joined = try buf.toOwnedSlice(allocator);
}
defer if (post.len > 0) allocator.free(args_joined);
const args_final: []const u8 = if (args_override) |a| a else args_joined;
const note_final: []const u8 = if (note_override) |n| n else "";
// Target can be:
// - commit_id (40-hex) or commit_id prefix (>=7 hex) resolvable under worker_base
// - run_id/task_id/path (resolved to run_manifest.json to read commit_id)
var commit_hex: []const u8 = "";
var commit_hex_owned: ?[]u8 = null;
defer if (commit_hex_owned) |s| allocator.free(s);
var commit_bytes: []u8 = &[_]u8{};
var commit_bytes_allocated = false;
defer if (commit_bytes_allocated) allocator.free(commit_bytes);
if (target.len >= 7 and target.len <= 40 and isHexLowerOrUpper(target)) {
if (target.len == 40) {
commit_hex = target;
} else {
commit_hex_owned = try resolveCommitPrefix(allocator, cfg.worker_base, target);
commit_hex = commit_hex_owned.?;
}
const decoded = crypto.decodeHex(allocator, commit_hex) catch {
commit_hex = "";
commit_hex_owned = null;
return error.InvalidCommitId;
};
if (decoded.len != 20) {
allocator.free(decoded);
commit_hex = "";
commit_hex_owned = null;
} else {
commit_bytes = decoded;
commit_bytes_allocated = true;
}
}
var job_name = blk: {
if (job_name_override) |n| break :blk n;
break :blk "requeue";
};
if (commit_hex.len == 0) {
const manifest_path = try resolveManifestPath(allocator, target, cfg.worker_base);
defer allocator.free(manifest_path);
const data = try readFileAlloc(allocator, manifest_path);
defer allocator.free(data);
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
defer parsed.deinit();
if (parsed.value != .object) return error.InvalidManifest;
const root = parsed.value.object;
commit_hex = jsonGetString(root, "commit_id") orelse "";
if (commit_hex.len != 40) {
colors.printError("run manifest missing commit_id\n", .{});
return error.InvalidManifest;
}
if (job_name_override == null) {
const j = jsonGetString(root, "job_name") orelse "";
if (j.len > 0) job_name = j;
}
const b = try crypto.decodeHex(allocator, commit_hex);
if (b.len != 20) {
allocator.free(b);
return error.InvalidCommitId;
}
commit_bytes = b;
commit_bytes_allocated = true;
}
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
defer allocator.free(ws_url);
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
if (note_final.len > 0) {
try client.sendQueueJobWithArgsNoteAndResources(
job_name,
commit_bytes,
priority,
api_key_hash,
args_final,
note_final,
cpu,
memory,
gpu,
gpu_memory,
);
} else {
try client.sendQueueJobWithArgsAndResources(
job_name,
commit_bytes,
priority,
api_key_hash,
args_final,
cpu,
memory,
gpu,
gpu_memory,
);
}
try client.receiveAndHandleResponse(allocator, "Requeue");
colors.printSuccess("Queued requeue\n", .{});
colors.printInfo("Job: {s}\n", .{job_name});
colors.printInfo("Commit: {s}\n", .{commit_hex});
}
/// Reports whether every byte of `s` is an ASCII hex digit (0-9, a-f, A-F).
/// Note: vacuously true for the empty slice.
fn isHexLowerOrUpper(s: []const u8) bool {
    var idx: usize = 0;
    while (idx < s.len) : (idx += 1) {
        if (!std.ascii.isHex(s[idx])) return false;
    }
    return true;
}
/// Scan `base_path` for a directory whose 40-char hex name starts with
/// `prefix`. Returns an allocator-owned copy of the unique match.
/// Errors: error.InvalidCommitId when two directories share the prefix,
/// error.FileNotFound when none matches; directory-open errors propagate.
fn resolveCommitPrefix(allocator: std.mem.Allocator, base_path: []const u8, prefix: []const u8) ![]u8 {
    // Open relative paths against the CWD so both path forms work.
    var dir = if (std.fs.path.isAbsolute(base_path))
        try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
    else
        try std.fs.cwd().openDir(base_path, .{ .iterate = true });
    defer dir.close();
    var it = dir.iterate();
    var found: ?[]u8 = null;
    // Free the duplicated candidate if we later bail out with an error.
    errdefer if (found) |s| allocator.free(s);
    while (try it.next()) |entry| {
        if (entry.kind != .directory) continue;
        const name = entry.name;
        // Candidates are exactly 40 hex chars (a full commit id).
        if (name.len != 40) continue;
        if (!std.mem.startsWith(u8, name, prefix)) continue;
        if (!isHexLowerOrUpper(name)) continue;
        if (found != null) {
            // A second match means the prefix is ambiguous: refuse to guess.
            colors.printError("Ambiguous commit prefix: {s}\n", .{prefix});
            return error.InvalidCommitId;
        }
        found = try allocator.dupe(u8, name);
    }
    if (found) |s| return s;
    colors.printError("No commit matches prefix: {s}\n", .{prefix});
    return error.FileNotFound;
}
/// Resolve `input` to a run_manifest.json path (allocator-owned).
/// Accepted forms, tried in order:
///   - existing directory: "run_manifest.json" is appended;
///   - existing file: used verbatim (assumed to be the manifest);
///   - anything else: treated as a run_id/task_id and searched under
///     `base_path` via resolveManifestPathById.
fn resolveManifestPath(allocator: std.mem.Allocator, input: []const u8, base_path: []const u8) ![]u8 {
    var cwd = std.fs.cwd();
    if (std.fs.path.isAbsolute(input)) {
        // Absolute directory: assume the manifest lives directly inside it.
        if (std.fs.openDirAbsolute(input, .{}) catch null) |dir| {
            var mutable_dir = dir;
            defer mutable_dir.close();
            return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
        }
        // Absolute file: treat the path itself as the manifest.
        if (std.fs.openFileAbsolute(input, .{}) catch null) |file| {
            var mutable_file = file;
            defer mutable_file.close();
            return allocator.dupe(u8, input);
        }
        // Not an existing filesystem entry: fall back to id lookup.
        return resolveManifestPathById(allocator, input, base_path);
    }
    const stat = cwd.statFile(input) catch |err| {
        if (err == error.FileNotFound) {
            // Relative non-path input: try run_id/task_id lookup instead.
            return resolveManifestPathById(allocator, input, base_path);
        }
        return err;
    };
    if (stat.kind == .directory) {
        return std.fs.path.join(allocator, &[_][]const u8{ input, "run_manifest.json" });
    }
    return allocator.dupe(u8, input);
}
/// Search the run-state directories under `base_path` for a run whose
/// manifest's "run_id" or "task_id" equals `id`. Returns the allocator-owned
/// manifest path of the first match, or error.FileNotFound.
/// Unreadable/unparsable directories and manifests are skipped silently
/// (best-effort scan), so a corrupt entry never aborts the search.
fn resolveManifestPathById(allocator: std.mem.Allocator, id: []const u8, base_path: []const u8) ![]u8 {
    // Search order doubles as precedence when an id exists in several states.
    const roots = [_][]const u8{ "finished", "failed", "running", "pending" };
    for (roots) |root| {
        const root_path = try std.fs.path.join(allocator, &[_][]const u8{ base_path, root });
        defer allocator.free(root_path);
        var dir = if (std.fs.path.isAbsolute(root_path))
            (std.fs.openDirAbsolute(root_path, .{ .iterate = true }) catch continue)
        else
            (std.fs.cwd().openDir(root_path, .{ .iterate = true }) catch continue);
        defer dir.close();
        var it = dir.iterate();
        while (try it.next()) |entry| {
            if (entry.kind != .directory) continue;
            const run_dir = try std.fs.path.join(allocator, &[_][]const u8{ root_path, entry.name });
            defer allocator.free(run_dir);
            const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{ run_dir, "run_manifest.json" });
            defer allocator.free(manifest_path);
            const file = if (std.fs.path.isAbsolute(manifest_path))
                (std.fs.openFileAbsolute(manifest_path, .{}) catch continue)
            else
                (std.fs.cwd().openFile(manifest_path, .{}) catch continue);
            defer file.close();
            // 1 MiB cap mirrors readFileAlloc; larger manifests are skipped.
            const data = file.readToEndAlloc(allocator, 1024 * 1024) catch continue;
            defer allocator.free(data);
            const parsed = std.json.parseFromSlice(std.json.Value, allocator, data, .{}) catch continue;
            defer parsed.deinit();
            if (parsed.value != .object) continue;
            const obj = parsed.value.object;
            const run_id = jsonGetString(obj, "run_id") orelse "";
            const task_id = jsonGetString(obj, "task_id") orelse "";
            if (std.mem.eql(u8, run_id, id) or std.mem.eql(u8, task_id, id)) {
                return allocator.dupe(u8, manifest_path);
            }
        }
    }
    return error.FileNotFound;
}
/// Read the whole file at `path` (absolute or CWD-relative) into an
/// allocator-owned buffer. Files larger than 1 MiB fail the read.
fn readFileAlloc(allocator: std.mem.Allocator, path: []const u8) ![]u8 {
    const file = if (std.fs.path.isAbsolute(path))
        try std.fs.openFileAbsolute(path, .{})
    else
        try std.fs.cwd().openFile(path, .{});
    defer file.close();
    return file.readToEndAlloc(allocator, 1024 * 1024);
}
/// Fetch `key` from `obj` and return its string payload, or null when the
/// key is absent or maps to a non-string JSON value.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
/// Print usage help for the `requeue` subcommand via the shared color
/// helpers. Declared `!void` to match the call sites' `try` usage.
fn printUsage() !void {
    colors.printInfo("Usage:\n", .{});
    colors.printInfo("  ml requeue <commit_id|run_id|task_id|path> [--name <job>] [--priority <n>] [--cpu <n>] [--memory <gb>] [--gpu <n>] [--gpu-memory <gb>] [--args <string>] [--note <string>] -- <args...>\n", .{});
}

View file

@ -1,7 +1,7 @@
const std = @import("std");
const c = @cImport(@cInclude("time.h"));
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const crypto = @import("../utils/crypto.zig");
const errors = @import("../errors.zig");
const logging = @import("../utils/logging.zig");

View file

@ -4,7 +4,7 @@ const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const rsync = @import("../utils/rsync_embedded.zig"); // Use embedded rsync
const storage = @import("../utils/storage.zig");
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const logging = @import("../utils/logging.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {

View file

@ -1,9 +1,10 @@
const std = @import("std");
const testing = std.testing;
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws.zig");
const ws = @import("../net/ws/client.zig");
const colors = @import("../utils/colors.zig");
const crypto = @import("../utils/crypto.zig");
const io = @import("../utils/io.zig");
pub const Options = struct {
json: bool = false,
@ -78,7 +79,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
defer allocator.free(msg);
const packet = @import("../net/protocol.zig").ResponsePacket.deserialize(msg, allocator) catch {
std.debug.print("{s}\n", .{msg});
if (opts.json) {
var out = io.stdoutWriter();
try out.print("{s}\n", .{msg});
} else {
std.debug.print("{s}\n", .{msg});
}
return error.InvalidPacket;
};
defer {
@ -101,7 +107,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
const payload = packet.data_payload.?;
if (opts.json) {
std.debug.print("{s}\n", .{payload});
var out = io.stdoutWriter();
try out.print("{s}\n", .{payload});
} else {
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{});
defer parsed.deinit();

View file

@ -1,8 +1,8 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const rsync = @import("../utils/rsync.zig");
const ws = @import("../net/ws.zig");
const rsync = @import("../utils/rsync_embedded.zig");
const ws = @import("../net/ws/client.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {

View file

@ -1,6 +1,5 @@
const std = @import("std");
const colors = @import("utils/colors.zig");
const ws = @import("net/ws.zig");
const crypto = @import("utils/crypto.zig");
/// User-friendly error types for CLI

View file

@ -6,6 +6,7 @@ const Command = enum {
jupyter,
init,
sync,
requeue,
queue,
status,
monitor,
@ -78,6 +79,14 @@ pub fn main() !void {
command_found = true;
try @import("commands/info.zig").run(allocator, args[2..]);
},
'a' => if (std.mem.eql(u8, command, "annotate")) {
command_found = true;
try @import("commands/annotate.zig").run(allocator, args[2..]);
},
'n' => if (std.mem.eql(u8, command, "narrative")) {
command_found = true;
try @import("commands/narrative.zig").run(allocator, args[2..]);
},
's' => if (std.mem.eql(u8, command, "sync")) {
command_found = true;
if (args.len < 3) {
@ -89,6 +98,10 @@ pub fn main() !void {
command_found = true;
try @import("commands/status.zig").run(allocator, args[2..]);
},
'r' => if (std.mem.eql(u8, command, "requeue")) {
command_found = true;
try @import("commands/requeue.zig").run(allocator, args[2..]);
},
'q' => if (std.mem.eql(u8, command, "queue")) {
command_found = true;
try @import("commands/queue.zig").run(allocator, args[2..]);
@ -127,8 +140,11 @@ fn printUsage() void {
std.debug.print("Commands:\n", .{});
std.debug.print(" jupyter Jupyter workspace management\n", .{});
std.debug.print(" init Setup configuration interactively\n", .{});
std.debug.print(" annotate <path|id> Add an annotation to run_manifest.json (--note \"...\")\n", .{});
std.debug.print(" narrative set <path|id> Set run narrative fields (hypothesis/context/...)\n", .{});
std.debug.print(" info <path|id> Show run info from run_manifest.json (optionally --base <path>)\n", .{});
std.debug.print(" sync <path> Sync project to server\n", .{});
std.debug.print(" requeue <id> Re-submit from run_id/task_id/path (supports -- <args>)\n", .{});
std.debug.print(" queue (q) <job> Queue job for execution\n", .{});
std.debug.print(" status Get system status\n", .{});
std.debug.print(" monitor Launch TUI via SSH\n", .{});
@ -143,4 +159,7 @@ fn printUsage() void {
test {
_ = @import("commands/info.zig");
_ = @import("commands/requeue.zig");
_ = @import("commands/annotate.zig");
_ = @import("commands/narrative.zig");
}

View file

@ -1,3 +1,5 @@
// Network module - exports all network modules
pub const protocol = @import("net/protocol.zig");
pub const ws = @import("net/ws.zig");
pub const ws_client = @import("net/ws/client.zig");
pub const ws_opcodes = @import("net/ws/opcodes.zig");

File diff suppressed because it is too large Load diff

1256
cli/src/net/ws/client.zig Normal file

File diff suppressed because it is too large Load diff

8
cli/src/net/ws/deps.zig Normal file
View file

@ -0,0 +1,8 @@
pub const std = @import("std");
pub const colors = @import("../../utils/colors.zig");
pub const crypto = @import("../../utils/crypto.zig");
pub const io = @import("../../utils/io.zig");
pub const log = @import("../../utils/logging.zig");
pub const protocol = @import("../protocol.zig");

71
cli/src/net/ws/frame.zig Normal file
View file

@ -0,0 +1,71 @@
const std = @import("std");
/// Send `payload` as a single masked binary WebSocket frame (FIN=1,
/// opcode 0x2). Client-to-server frames MUST be masked per RFC 6455 §5.3.
/// Payloads of 64 KiB or more are rejected (no 64-bit length support).
pub fn sendWebSocketFrame(stream: std.net.Stream, payload: []const u8) !void {
    var frame: [14]u8 = undefined;
    var frame_len: usize = 2;
    // FIN=1, opcode=0x2 (binary); the MASK bit lives in byte 1.
    frame[0] = 0x82;
    // Payload length (7-bit or 16-bit extended form), with MASK bit set.
    if (payload.len < 126) {
        frame[1] = @as(u8, @intCast(payload.len)) | 0x80;
    } else if (payload.len < 65536) {
        frame[1] = 126 | 0x80;
        frame[2] = @intCast(payload.len >> 8);
        frame[3] = @intCast(payload.len & 0xFF);
        frame_len = 4;
    } else {
        return error.PayloadTooLarge;
    }
    // Fix: the masking key must be unpredictable (RFC 6455 §5.3). The old
    // timestamp-derived mask repeated one predictable byte four times.
    var mask: [4]u8 = undefined;
    std.crypto.random.bytes(&mask);
    @memcpy(frame[frame_len .. frame_len + 4], &mask);
    frame_len += 4;
    // Fix: `write` may perform a short write; `writeAll` loops until done,
    // otherwise the peer would see a truncated/corrupted frame.
    try stream.writeAll(frame[0..frame_len]);
    const masked_payload = try std.heap.page_allocator.alloc(u8, payload.len);
    defer std.heap.page_allocator.free(masked_payload);
    // XOR each payload byte with the mask, cycling every 4 bytes.
    for (payload, 0..) |byte, j| {
        masked_payload[j] = byte ^ mask[j % 4];
    }
    try stream.writeAll(masked_payload);
}
/// Read one complete WebSocket frame from `stream` and return its payload
/// (caller owns the slice). Only final, unmasked binary frames (first
/// header byte 0x82) are accepted; 64-bit payload lengths are rejected,
/// mirroring sendWebSocketFrame's limit.
pub fn receiveBinaryMessage(stream: std.net.Stream, allocator: std.mem.Allocator) ![]u8 {
    var header: [2]u8 = undefined;
    // Fix: a single read() may legally return fewer than 2 bytes; loop
    // until the header is complete instead of failing on a short read.
    try readExact(stream, &header);
    if (header[0] != 0x82) return error.InvalidFrame;
    // Server-to-client frames must not be masked (RFC 6455 §5.1); the mask
    // bit previously leaked into the length byte below.
    if (header[1] & 0x80 != 0) return error.InvalidFrame;
    var payload_len: usize = header[1] & 0x7F;
    if (payload_len == 126) {
        var len_bytes: [2]u8 = undefined;
        // Fix: the 16-bit extended length was read with one unchecked read().
        try readExact(stream, &len_bytes);
        payload_len = (@as(usize, len_bytes[0]) << 8) | len_bytes[1];
    } else if (payload_len == 127) {
        return error.PayloadTooLarge;
    }
    const payload = try allocator.alloc(u8, payload_len);
    errdefer allocator.free(payload);
    try readExact(stream, payload);
    return payload;
}

/// Fill `buf` completely from `stream`, or fail with error.ConnectionClosed
/// if the peer closes before enough bytes arrive.
fn readExact(stream: std.net.Stream, buf: []u8) !void {
    var off: usize = 0;
    while (off < buf.len) {
        const n = try stream.read(buf[off..]);
        if (n == 0) return error.ConnectionClosed;
        off += n;
    }
}

View file

@ -0,0 +1,114 @@
const std = @import("std");
/// Build a Sec-WebSocket-Key value: 16 cryptographically random bytes,
/// base64-encoded (RFC 6455 §4.1). Caller owns the returned slice.
fn generateWebSocketKey(allocator: std.mem.Allocator) ![]u8 {
    var raw: [16]u8 = undefined;
    std.crypto.random.bytes(&raw);
    const encoder = std.base64.standard.Encoder;
    const out = try allocator.alloc(u8, encoder.calcSize(raw.len));
    _ = encoder.encode(out, &raw);
    return out;
}
/// Perform the client half of the WebSocket opening handshake over an
/// already-connected `stream`: send an HTTP/1.1 Upgrade request (including
/// a custom X-API-Key header for authentication) and require a
/// "101 Switching Protocols" response. Well-known HTTP failure statuses are
/// mapped to targeted user-facing diagnostics and distinct error values.
pub fn handshake(
    allocator: std.mem.Allocator,
    stream: std.net.Stream,
    host: []const u8,
    url: []const u8,
    api_key: []const u8,
) !void {
    const key = try generateWebSocketKey(allocator);
    defer allocator.free(key);
    const request = try std.fmt.allocPrint(
        allocator,
        "GET {s} HTTP/1.1\r\n" ++
            "Host: {s}\r\n" ++
            "Upgrade: websocket\r\n" ++
            "Connection: Upgrade\r\n" ++
            "Sec-WebSocket-Key: {s}\r\n" ++
            "Sec-WebSocket-Version: 13\r\n" ++
            "X-API-Key: {s}\r\n" ++
            "\r\n",
        .{ url, host, key, api_key },
    );
    defer allocator.free(request);
    // NOTE(review): write() may perform a short write; consider writeAll —
    // confirm against std.net.Stream semantics.
    _ = try stream.write(request);
    var response_buf: [4096]u8 = undefined;
    var bytes_read: usize = 0;
    var header_complete = false;
    // Read until the blank line terminating the HTTP response headers.
    // NOTE(review): any bytes after "\r\n\r\n" (e.g. an early WS frame) land
    // in this buffer and are discarded — confirm the server never pipelines.
    while (!header_complete and bytes_read < response_buf.len - 1) {
        const chunk_bytes = try stream.read(response_buf[bytes_read..]);
        if (chunk_bytes == 0) break;
        bytes_read += chunk_bytes;
        if (std.mem.indexOf(u8, response_buf[0..bytes_read], "\r\n\r\n") != null) {
            header_complete = true;
        }
    }
    const response = response_buf[0..bytes_read];
    if (std.mem.indexOf(u8, response, "101 Switching Protocols") == null) {
        // Map common HTTP statuses to actionable guidance for the user.
        if (std.mem.indexOf(u8, response, "404 Not Found") != null) {
            std.debug.print("\n❌ WebSocket Connection Failed\n", .{});
            std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
            std.debug.print("The WebSocket endpoint '/ws' was not found on the server.\n\n", .{});
            std.debug.print("This usually means:\n", .{});
            std.debug.print("  • API server is not running\n", .{});
            std.debug.print("  • Incorrect server address in config\n", .{});
            std.debug.print("  • Different service running on that port\n\n", .{});
            std.debug.print("To diagnose:\n", .{});
            std.debug.print("  • Verify server address: Check ~/.ml/config.toml\n", .{});
            std.debug.print("  • Test connectivity: curl http://<server>:<port>/health\n", .{});
            std.debug.print("  • Contact your server administrator if the issue persists\n\n", .{});
            return error.EndpointNotFound;
        } else if (std.mem.indexOf(u8, response, "401 Unauthorized") != null) {
            std.debug.print("\n❌ Authentication Failed\n", .{});
            std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
            std.debug.print("Invalid or missing API key.\n\n", .{});
            std.debug.print("To fix:\n", .{});
            std.debug.print("  • Verify API key in ~/.ml/config.toml matches server configuration\n", .{});
            std.debug.print("  • Request a new API key from your administrator if needed\n\n", .{});
            return error.AuthenticationFailed;
        } else if (std.mem.indexOf(u8, response, "403 Forbidden") != null) {
            std.debug.print("\n❌ Access Denied\n", .{});
            std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
            std.debug.print("Your API key doesn't have permission for this operation.\n\n", .{});
            std.debug.print("To fix:\n", .{});
            std.debug.print("  • Contact your administrator to grant necessary permissions\n\n", .{});
            return error.PermissionDenied;
        } else if (std.mem.indexOf(u8, response, "503 Service Unavailable") != null) {
            std.debug.print("\n❌ Server Unavailable\n", .{});
            std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
            std.debug.print("The server is temporarily unavailable.\n\n", .{});
            std.debug.print("This could be due to:\n", .{});
            std.debug.print("  • Server maintenance\n", .{});
            std.debug.print("  • High load\n", .{});
            std.debug.print("  • Server restart\n\n", .{});
            std.debug.print("To resolve:\n", .{});
            std.debug.print("  • Wait a moment and try again\n", .{});
            std.debug.print("  • Contact administrator if the issue persists\n\n", .{});
            return error.ServerUnavailable;
        } else {
            // Unknown failure: echo the HTTP status line to aid diagnosis.
            std.debug.print("\n❌ WebSocket Handshake Failed\n", .{});
            std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{});
            std.debug.print("Expected HTTP 101 Switching Protocols, but received:\n", .{});
            const newline_pos = std.mem.indexOf(u8, response, "\r\n") orelse response.len;
            const status_line = response[0..newline_pos];
            std.debug.print("  {s}\n\n", .{status_line});
            std.debug.print("To diagnose:\n", .{});
            std.debug.print("  • Verify server address in ~/.ml/config.toml\n", .{});
            std.debug.print("  • Check network connectivity to the server\n", .{});
            std.debug.print("  • Contact your administrator for assistance\n\n", .{});
            return error.HandshakeFailed;
        }
    }
    // 10 ms pause after a successful upgrade — presumably gives the server
    // time to finish registering the connection; TODO confirm it is needed.
    std.posix.nanosleep(0, 10 * std.time.ns_per_ms);
}

73
cli/src/net/ws/opcode.zig Normal file
View file

@ -0,0 +1,73 @@
/// Wire opcodes shared with the server over the WebSocket protocol.
/// The numeric values are part of the on-the-wire contract: never renumber
/// existing entries, only append new ones with fresh values.
pub const Opcode = enum(u8) {
    // Job submission and run-metadata requests
    queue_job = 0x01,
    queue_job_with_tracking = 0x0C,
    queue_job_with_snapshot = 0x17,
    queue_job_with_args = 0x1A,
    queue_job_with_note = 0x1B,
    annotate_run = 0x1C,
    set_run_narrative = 0x1D,
    // General control requests
    status_request = 0x02,
    cancel_job = 0x03,
    prune = 0x04,
    crash_report = 0x05,
    log_metric = 0x0A,
    get_experiment = 0x0B,
    // Jupyter workspace lifecycle
    start_jupyter = 0x0D,
    stop_jupyter = 0x0E,
    remove_jupyter = 0x18,
    restore_jupyter = 0x19,
    list_jupyter = 0x0F,
    list_jupyter_packages = 0x1E,
    validate_request = 0x16,
    // Dataset management opcodes
    dataset_list = 0x06,
    dataset_register = 0x07,
    dataset_info = 0x08,
    dataset_search = 0x09,
    // Structured response opcodes
    response_success = 0x10,
    response_error = 0x11,
    response_progress = 0x12,
    response_status = 0x13,
    response_data = 0x14,
    response_log = 0x15,
};
/// Discriminates what kind of identifier a `validate_request` refers to.
/// Values are part of the wire protocol; do not renumber.
pub const ValidateTargetType = enum(u8) {
    commit_id = 0,
    task_id = 1,
};
// Flat aliases so call sites can write `opcodes.queue_job` instead of
// `opcodes.Opcode.queue_job`. Keep this list in sync with the enum above.
pub const queue_job = Opcode.queue_job;
pub const queue_job_with_tracking = Opcode.queue_job_with_tracking;
pub const queue_job_with_snapshot = Opcode.queue_job_with_snapshot;
pub const queue_job_with_args = Opcode.queue_job_with_args;
pub const queue_job_with_note = Opcode.queue_job_with_note;
pub const annotate_run = Opcode.annotate_run;
pub const set_run_narrative = Opcode.set_run_narrative;
pub const status_request = Opcode.status_request;
pub const cancel_job = Opcode.cancel_job;
pub const prune = Opcode.prune;
pub const crash_report = Opcode.crash_report;
pub const log_metric = Opcode.log_metric;
pub const get_experiment = Opcode.get_experiment;
pub const start_jupyter = Opcode.start_jupyter;
pub const stop_jupyter = Opcode.stop_jupyter;
pub const remove_jupyter = Opcode.remove_jupyter;
pub const restore_jupyter = Opcode.restore_jupyter;
pub const list_jupyter = Opcode.list_jupyter;
pub const list_jupyter_packages = Opcode.list_jupyter_packages;
pub const validate_request = Opcode.validate_request;
pub const dataset_list = Opcode.dataset_list;
pub const dataset_register = Opcode.dataset_register;
pub const dataset_info = Opcode.dataset_info;
pub const dataset_search = Opcode.dataset_search;
pub const response_success = Opcode.response_success;
pub const response_error = Opcode.response_error;
pub const response_progress = Opcode.response_progress;
pub const response_status = Opcode.response_status;
pub const response_data = Opcode.response_data;
pub const response_log = Opcode.response_log;

View file

@ -0,0 +1,2 @@
// Aggregate export point: re-expose the opcode definitions so callers can
// reach them through this module without importing opcode.zig directly.
const opcode_mod = @import("opcode.zig");

pub const Opcode = opcode_mod.Opcode;
pub const ValidateTargetType = opcode_mod.ValidateTargetType;

View file

@ -0,0 +1,21 @@
const std = @import("std");
/// Resolve `host` to a socket address: literal IPv4/IPv6 strings are parsed
/// directly, anything else goes through DNS via resolveHostname.
pub fn resolveHostAddress(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address {
    if (std.net.Address.parseIp(host, port)) |addr| {
        return addr;
    } else |err| {
        // Not an IP literal: treat as a hostname and do a lookup.
        if (err == error.InvalidIPAddressFormat) {
            return resolveHostname(allocator, host, port);
        }
        return err;
    }
}

/// Look `host` up via the system resolver and return the first address.
fn resolveHostname(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address {
    var address_list = try std.net.getAddressList(allocator, host, port);
    defer address_list.deinit();
    // The returned Address is a value copy, so freeing the list is safe.
    if (address_list.addrs.len == 0) return error.HostResolutionFailed;
    return address_list.addrs[0];
}
// Smoke test covering the DNS-fallback path.
// NOTE(review): requires a working "localhost" resolver — may fail offline.
test "resolve hostnames for WebSocket connections" {
    _ = try resolveHostAddress(std.testing.allocator, "localhost", 9100);
}

View file

@ -0,0 +1,74 @@
const std = @import("std");
/// Return the string stored under `key` in `obj`, or null when the key is
/// absent or holds a non-string JSON value.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
/// Return the integer stored under `key` in `obj`, or null when the key is
/// absent or holds a non-integer JSON value.
fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .integer => |n| n,
        else => null,
    };
}
/// Render the "prewarm" array from a status JSON root as a human-readable
/// section. Returns null when the key is absent, not an array, or empty;
/// otherwise an allocator-owned string the caller must free.
pub fn formatPrewarmFromStatusRoot(allocator: std.mem.Allocator, root: std.json.ObjectMap) !?[]u8 {
    const prewarm_val_opt = root.get("prewarm");
    if (prewarm_val_opt == null) {
        return null;
    }
    const prewarm_val = prewarm_val_opt.?;
    if (prewarm_val != .array) {
        return null;
    }
    const items = prewarm_val.array.items;
    if (items.len == 0) {
        return null;
    }
    var out = std.ArrayList(u8){};
    errdefer out.deinit(allocator);
    const writer = out.writer(allocator);
    try writer.writeAll("Prewarm:\n");
    for (items) |item| {
        // Malformed (non-object) entries are skipped silently.
        if (item != .object) {
            continue;
        }
        const obj = item.object;
        // Missing string fields render as empty; missing counters as 0.
        const worker_id = jsonGetString(obj, "worker_id") orelse "";
        const task_id = jsonGetString(obj, "task_id") orelse "";
        const phase = jsonGetString(obj, "phase") orelse "";
        const started_at = jsonGetString(obj, "started_at") orelse "";
        const dataset_count = jsonGetInt(obj, "dataset_count") orelse 0;
        const snapshot_id = jsonGetString(obj, "snapshot_id") orelse "";
        const env_image = jsonGetString(obj, "env_image") orelse "";
        const env_hit = jsonGetInt(obj, "env_hit") orelse 0;
        const env_miss = jsonGetInt(obj, "env_miss") orelse 0;
        const env_built = jsonGetInt(obj, "env_built") orelse 0;
        try writer.print(
            " worker={s} task={s} phase={s} datasets={d} snapshot={s} env={s} env_hit={d} env_miss={d} env_built={d} started={s}\n",
            .{ worker_id, task_id, phase, dataset_count, snapshot_id, env_image, env_hit, env_miss, env_built, started_at },
        );
    }
    const owned = try out.toOwnedSlice(allocator);
    return owned;
}

View file

@ -0,0 +1,449 @@
const deps = @import("deps.zig");
const std = deps.std;
const io = deps.io;
const protocol = deps.protocol;
const colors = deps.colors;
const Client = @import("client.zig").Client;
const utils = @import("utils.zig");
/// Receive and handle status response with user filtering.
/// Receives one message from the server; if it looks like JSON, prints it
/// raw (`options.json == true`) or as formatted user/task/prewarm sections,
/// capped per-status at `options.limit` (default 5). Non-JSON responses are
/// sanitized to printable ASCII and matched against known error strings.
pub fn receiveAndHandleStatusResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, options: anytype) !void {
    _ = user_context; // TODO: Use for filtering
    const message = try self.receiveMessage(allocator);
    defer allocator.free(message);
    // Fix: an empty frame would make the `message[0]` probe below an
    // out-of-bounds read.
    if (message.len == 0) {
        if (options.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\": \"empty_response\"}}\n", .{});
        } else {
            std.debug.print("Server response: [empty]\n", .{});
        }
        return;
    }
    // Check if message is JSON (or contains JSON) or plain text
    if (message[0] == '{') {
        // Parse JSON response
        const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{});
        defer parsed.deinit();
        // NOTE(review): assumes the top-level JSON value is an object —
        // confirm the server never sends a bare array/scalar here.
        const root = parsed.value.object;
        if (options.json) {
            // Output raw JSON
            var out = io.stdoutWriter();
            try out.print("{s}\n", .{message});
        } else {
            // Display user info
            if (root.get("user")) |user_obj| {
                const user = user_obj.object;
                // NOTE(review): `.?` asserts "name"/"admin" always exist in
                // the user object — confirm against the server schema.
                const name = user.get("name").?.string;
                const admin = user.get("admin").?.bool;
                colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin });
            }
            // Display task summary
            if (root.get("tasks")) |tasks_obj| {
                const tasks = tasks_obj.object;
                const total = tasks.get("total").?.integer;
                const queued = tasks.get("queued").?.integer;
                const running = tasks.get("running").?.integer;
                const failed = tasks.get("failed").?.integer;
                const completed = tasks.get("completed").?.integer;
                colors.printInfo(
                    "Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n",
                    .{ total, queued, running, failed, completed },
                );
            }
            // Cap each per-status section at --limit entries (default 5).
            const per_section_limit: usize = options.limit orelse 5;
            const TaskStatus = enum { queued, running, failed, completed };
            const TaskPrinter = struct {
                // Human-facing section heading for a status.
                fn statusLabel(s: TaskStatus) []const u8 {
                    return switch (s) {
                        .queued => "Queued",
                        .running => "Running",
                        .failed => "Failed",
                        .completed => "Completed",
                    };
                }
                // JSON "status" field value to match against.
                fn statusMatch(s: TaskStatus) []const u8 {
                    return switch (s) {
                        .queued => "queued",
                        .running => "running",
                        .failed => "failed",
                        .completed => "completed",
                    };
                }
                // Truncate (no ellipsis) to at most `max_len` bytes.
                fn shorten(s: []const u8, max_len: usize) []const u8 {
                    if (s.len <= max_len) return s;
                    return s[0..max_len];
                }
                // Print up to `limit2` queue entries matching `status`, then
                // a "... and N more" footer when entries were elided.
                fn printSection(
                    allocator2: std.mem.Allocator,
                    queue_items: []const std.json.Value,
                    status: TaskStatus,
                    limit2: usize,
                ) !void {
                    _ = allocator2;
                    const label = statusLabel(status);
                    const want = statusMatch(status);
                    std.debug.print("\n{s}:\n", .{label});
                    var shown: usize = 0;
                    for (queue_items) |item| {
                        if (item != .object) continue;
                        const obj = item.object;
                        const st = utils.jsonGetString(obj, "status") orelse "";
                        if (!std.mem.eql(u8, st, want)) continue;
                        const id = utils.jsonGetString(obj, "id") orelse "";
                        const job_name = utils.jsonGetString(obj, "job_name") orelse "";
                        const worker_id = utils.jsonGetString(obj, "worker_id") orelse "";
                        const err = utils.jsonGetString(obj, "error") orelse "";
                        if (std.mem.eql(u8, want, "failed")) {
                            colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name });
                            if (worker_id.len > 0) {
                                std.debug.print(" (worker={s})", .{worker_id});
                            }
                            std.debug.print("\n", .{});
                            if (err.len > 0) {
                                std.debug.print(" error: {s}\n", .{shorten(err, 160)});
                            }
                        } else if (std.mem.eql(u8, want, "running")) {
                            colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name });
                            if (worker_id.len > 0) {
                                std.debug.print(" (worker={s})", .{worker_id});
                            }
                            std.debug.print("\n", .{});
                        } else if (std.mem.eql(u8, want, "queued")) {
                            std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name });
                        } else {
                            colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name });
                        }
                        shown += 1;
                        if (shown >= limit2) break;
                    }
                    if (shown == 0) {
                        std.debug.print(" (none)\n", .{});
                    } else {
                        // Indicate there may be more.
                        var total_for_status: usize = 0;
                        for (queue_items) |item| {
                            if (item != .object) continue;
                            const obj = item.object;
                            const st = utils.jsonGetString(obj, "status") orelse "";
                            if (std.mem.eql(u8, st, want)) total_for_status += 1;
                        }
                        if (total_for_status > shown) {
                            std.debug.print(" ... and {d} more\n", .{total_for_status - shown});
                        }
                    }
                }
            };
            if (root.get("queue")) |queue_val| {
                if (queue_val == .array) {
                    const items = queue_val.array.items;
                    try TaskPrinter.printSection(allocator, items, .queued, per_section_limit);
                    try TaskPrinter.printSection(allocator, items, .running, per_section_limit);
                    try TaskPrinter.printSection(allocator, items, .failed, per_section_limit);
                    try TaskPrinter.printSection(allocator, items, .completed, per_section_limit);
                }
            }
            // Append the prewarm section when the status payload carries one.
            if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| {
                defer allocator.free(section);
                colors.printInfo("{s}", .{section});
            }
        }
    } else {
        // Handle plain text response - filter out non-printable characters
        var clean_msg = allocator.alloc(u8, message.len) catch {
            // Allocation failed: report what we can without the buffer.
            if (options.json) {
                var out = io.stdoutWriter();
                try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
            } else {
                std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
            }
            return;
        };
        defer allocator.free(clean_msg);
        var clean_len: usize = 0;
        for (message) |byte| {
            // Skip WebSocket frame header bytes and non-printable chars
            if (byte >= 32 and byte <= 126) { // printable ASCII only
                clean_msg[clean_len] = byte;
                clean_len += 1;
            }
        }
        // Look for common error messages in the cleaned data
        if (clean_len > 0) {
            const cleaned = clean_msg[0..clean_len];
            if (options.json) {
                if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
                    var out = io.stdoutWriter();
                    try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{});
                } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
                    var out = io.stdoutWriter();
                    try out.print("{{\"error\": \"authentication_failed\"}}\n", .{});
                } else {
                    var out = io.stdoutWriter();
                    try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned});
                }
            } else {
                if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
                    std.debug.print("Insufficient permissions to view jobs\n", .{});
                } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
                    std.debug.print("Authentication failed\n", .{});
                } else {
                    std.debug.print("Server response: {s}\n", .{cleaned});
                }
            }
        } else {
            // Nothing printable survived: report as opaque binary data.
            if (options.json) {
                var out = io.stdoutWriter();
                try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
            } else {
                std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
            }
        }
        return;
    }
}
/// Receive and handle cancel response with user permissions.
///
/// The server may reply with either a JSON object or plain text. JSON replies
/// are parsed and rendered (raw when `options.json` is set, friendly text
/// otherwise); plain-text replies are stripped down to printable ASCII before
/// being pattern-matched against known error strings.
pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8, options: anytype) !void {
    const message = try self.receiveMessage(allocator);
    defer allocator.free(message);
    // Fix: an empty payload previously hit `message[0]` out of bounds.
    // Report it the same way as undecodable binary data.
    if (message.len == 0) {
        if (options.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
        } else {
            std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
        }
        return;
    }
    // Check if message is JSON or plain text
    if (message[0] == '{') {
        // Parse JSON response. NOTE(review): a malformed '{'-prefixed payload
        // propagates the parse error instead of falling back to the text path
        // — confirm that is the intended contract.
        const parsed = try std.json.parseFromSlice(std.json.Value, allocator, message, .{});
        defer parsed.deinit();
        const root = parsed.value.object;
        if (options.json) {
            // Output raw JSON untouched
            var out = io.stdoutWriter();
            try out.print("{s}\n", .{message});
        } else {
            // Display user-friendly output
            if (root.get("success")) |success_val| {
                if (success_val.bool) {
                    colors.printSuccess("Job '{s}' canceled successfully\n", .{job_name});
                } else {
                    colors.printError("Failed to cancel job '{s}'\n", .{job_name});
                    if (root.get("error")) |error_val| {
                        // assumes "error" is a JSON string — TODO confirm server schema
                        colors.printError("Error: {s}\n", .{error_val.string});
                    }
                }
            } else {
                colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name });
            }
        }
    } else {
        // Handle plain text response - filter out non-printable characters
        var clean_msg = allocator.alloc(u8, message.len) catch {
            if (options.json) {
                var out = io.stdoutWriter();
                try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
            } else {
                std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
            }
            return;
        };
        defer allocator.free(clean_msg);
        var clean_len: usize = 0;
        for (message) |byte| {
            // Skip WebSocket frame header bytes and non-printable chars
            if (byte >= 32 and byte <= 126) { // printable ASCII only
                clean_msg[clean_len] = byte;
                clean_len += 1;
            }
        }
        if (clean_len > 0) {
            const cleaned = clean_msg[0..clean_len];
            if (options.json) {
                if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
                    var out = io.stdoutWriter();
                    try out.print("{{\"error\": \"insufficient_permissions\"}}\n", .{});
                } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
                    var out = io.stdoutWriter();
                    try out.print("{{\"error\": \"authentication_failed\"}}\n", .{});
                } else {
                    var out = io.stdoutWriter();
                    try out.print("{{\"response\": \"{s}\"}}\n", .{cleaned});
                }
            } else {
                if (std.mem.indexOf(u8, cleaned, "Insufficient permissions") != null) {
                    std.debug.print("Insufficient permissions to cancel job\n", .{});
                } else if (std.mem.indexOf(u8, cleaned, "Authentication failed") != null) {
                    std.debug.print("Authentication failed\n", .{});
                } else {
                    colors.printInfo("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name });
                    colors.printInfo("Response: {s}\n", .{cleaned});
                }
            }
        } else {
            if (options.json) {
                var out = io.stdoutWriter();
                try out.print("{{\"error\": \"binary_data\", \"bytes\": {d}}}\n", .{message.len});
            } else {
                std.debug.print("Server response: [binary data - {d} bytes]\n", .{message.len});
            }
        }
        return;
    }
}
/// Handle response packet with appropriate display.
/// Dispatches on packet.packet_type and prints to stderr via std.debug.print.
/// Only .error_packet produces a non-void result: the server error code is
/// converted into a CLI error and returned.
/// NOTE(review): the `.?` unwraps below assume the server always populates the
/// type-specific fields (error_code, progress_value, ...) for that packet type
/// — confirm against the protocol encoder.
pub fn handleResponsePacket(self: *Client, packet: protocol.ResponsePacket, operation: []const u8) !void {
    switch (packet.packet_type) {
        .success => {
            // Use the server-provided message when non-empty, otherwise a generic line.
            if (packet.success_message) |msg| {
                if (msg.len > 0) {
                    std.debug.print("✓ {s}: {s}\n", .{ operation, msg });
                } else {
                    std.debug.print("✓ {s} completed successfully\n", .{operation});
                }
            } else {
                std.debug.print("✓ {s} completed successfully\n", .{operation});
            }
        },
        .error_packet => {
            // Human-readable text for the numeric code, then optional detail lines.
            const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
            std.debug.print("✗ {s} failed: {s}\n", .{ operation, error_msg });
            if (packet.error_message) |msg| {
                if (msg.len > 0) {
                    std.debug.print("Details: {s}\n", .{msg});
                }
            }
            if (packet.error_details) |details| {
                if (details.len > 0) {
                    std.debug.print("Additional info: {s}\n", .{details});
                }
            }
            // Convert to appropriate CLI error
            return convertServerError(self, packet.error_code.?);
        },
        .progress => {
            // Progress packets without a progress_type are silently ignored.
            if (packet.progress_type) |ptype| {
                switch (ptype) {
                    .percentage => {
                        // With a known total, show value/total plus a derived percentage.
                        const percentage = packet.progress_value.?;
                        if (packet.progress_total) |total| {
                            std.debug.print("Progress: {d}/{d} ({d:.1}%)\n", .{ percentage, total, @as(f32, @floatFromInt(percentage)) * 100.0 / @as(f32, @floatFromInt(total)) });
                        } else {
                            std.debug.print("Progress: {d}%\n", .{percentage});
                        }
                    },
                    .stage => {
                        if (packet.progress_message) |msg| {
                            std.debug.print("Stage: {s}\n", .{msg});
                        }
                    },
                    .message => {
                        if (packet.progress_message) |msg| {
                            std.debug.print("Info: {s}\n", .{msg});
                        }
                    },
                    .bytes_transferred => {
                        // Byte counts rendered with a 1024*1024 divisor (MiB), labeled "MB".
                        const bytes = packet.progress_value.?;
                        if (packet.progress_total) |total| {
                            const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0;
                            const total_mb = @as(f64, @floatFromInt(total)) / 1024.0 / 1024.0;
                            std.debug.print("Transferred: {d:.2} MB / {d:.2} MB\n", .{ transferred_mb, total_mb });
                        } else {
                            const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0;
                            std.debug.print("Transferred: {d:.2} MB\n", .{transferred_mb});
                        }
                    },
                }
            }
        },
        .status => {
            if (packet.status_data) |data| {
                std.debug.print("Status: {s}\n", .{data});
            }
        },
        .data => {
            if (packet.data_type) |dtype| {
                std.debug.print("Data [{s}]: ", .{dtype});
                if (packet.data_payload) |payload| {
                    // Heuristic: payload counts as text when every byte is printable
                    // or common whitespace; otherwise only the byte count is shown.
                    const is_text = for (payload) |byte| {
                        if (byte < 32 and byte != '\n' and byte != '\r' and byte != '\t') break false;
                    } else true;
                    if (is_text) {
                        std.debug.print("{s}\n", .{payload});
                    } else {
                        std.debug.print("{d} bytes\n", .{payload.len});
                    }
                }
            }
        },
        .log => {
            // A log line is printed only when both a level and a message are present.
            if (packet.log_level) |level| {
                const level_name = protocol.ResponsePacket.getLogLevelName(level);
                if (packet.log_message) |msg| {
                    std.debug.print("[{s}] {s}\n", .{ level_name, msg });
                }
            }
        },
    }
}
/// Map a server-side protocol error code onto the CLI's error set.
/// Unrecognized codes collapse into the generic error.ServerError.
fn convertServerError(self: *Client, server_error: protocol.ErrorCode) anyerror {
    // `self` is unused; kept so the call shape matches the other handlers.
    _ = self;
    switch (server_error) {
        .invalid_request => return error.InvalidArguments,
        .authentication_failed => return error.AuthenticationFailed,
        .permission_denied => return error.PermissionDenied,
        .resource_not_found, .job_not_found => return error.JobNotFound,
        .resource_already_exists => return error.ResourceExists,
        .timeout => return error.RequestTimeout,
        .server_overloaded, .service_unavailable => return error.ServerUnreachable,
        .job_already_running => return error.JobAlreadyRunning,
        .job_failed_to_start, .job_execution_failed => return error.CommandFailed,
        .job_cancelled => return error.JobCancelled,
        else => return error.ServerError,
    }
}
/// Release every heap-allocated field carried by `packet`.
/// Safe to call on a packet of any type: absent optionals are skipped.
pub fn cleanupPacket(self: *Client, packet: protocol.ResponsePacket) void {
    const a = self.allocator;
    if (packet.success_message) |s| a.free(s);
    if (packet.error_message) |s| a.free(s);
    if (packet.error_details) |s| a.free(s);
    if (packet.progress_message) |s| a.free(s);
    if (packet.status_data) |s| a.free(s);
    if (packet.data_type) |s| a.free(s);
    if (packet.data_payload) |s| a.free(s);
    if (packet.log_message) |s| a.free(s);
}

25
cli/src/net/ws/utils.zig Normal file
View file

@ -0,0 +1,25 @@
const std = @import("std");
/// Return the string stored under `key`, or null when the key is absent
/// or its value is not a JSON string.
pub fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
/// Return the integer stored under `key`, or null when the key is absent
/// or its value is not a JSON integer.
pub fn jsonGetInt(obj: std.json.ObjectMap, key: []const u8) ?i64 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .integer => |n| n,
        else => null,
    };
}

View file

@ -1,8 +1,10 @@
// Utils module - exports all utility modules
// Utilities module - exports all utility modules
pub const colors = @import("utils/colors.zig");
pub const crypto = @import("utils/crypto.zig");
pub const history = @import("utils/history.zig");
pub const io = @import("utils/io.zig");
pub const logging = @import("utils/logging.zig");
pub const rsync = @import("utils/rsync.zig");
pub const rsync_embedded = @import("utils/rsync_embedded.zig");
pub const rsync_embedded_binary = @import("utils/rsync_embedded_binary.zig");
pub const storage = @import("utils/storage.zig");

View file

@ -18,6 +18,7 @@ fn hexNibble(c: u8) ?u8 {
pub fn decodeHex(allocator: std.mem.Allocator, hex: []const u8) ![]u8 {
if ((hex.len % 2) != 0) return error.InvalidHex;
const out = try allocator.alloc(u8, hex.len / 2);
errdefer allocator.free(out);
var i: usize = 0;
while (i < out.len) : (i += 1) {
const hi = hexNibble(hex[i * 2]) orelse return error.InvalidHex;

59
cli/src/utils/io.zig Normal file
View file

@ -0,0 +1,59 @@
const std = @import("std");
/// Write the whole of `data` to `fd`, retrying on short writes.
/// Any OS error — or a zero-byte write, which would loop forever —
/// surfaces as error.WriteFailed.
fn writeAllFd(fd: std.posix.fd_t, data: []const u8) std.Io.Writer.Error!void {
    var remaining = data;
    while (remaining.len > 0) {
        const n = std.posix.write(fd, remaining) catch return error.WriteFailed;
        if (n == 0) return error.WriteFailed;
        remaining = remaining[n..];
    }
}
/// Shared drain implementation for the fd-backed writers: write every chunk
/// to `fd` once, then write the final chunk `splat` more times, returning the
/// total number of bytes written. Factored out because drainStdout and
/// drainStderr were byte-for-byte duplicates apart from the target fd.
/// NOTE(review): the last chunk is emitted 1 + splat times here (once in the
/// loop, splat repeats after) — confirm this matches std.Io.Writer's intended
/// splat semantics.
fn drainToFd(fd: std.posix.fd_t, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
    if (data.len == 0) return 0;
    var written: usize = 0;
    for (data) |chunk| {
        try writeAllFd(fd, chunk);
        written += chunk.len;
    }
    if (splat > 0) {
        const last = data[data.len - 1];
        var i: usize = 0;
        while (i < splat) : (i += 1) {
            try writeAllFd(fd, last);
            written += last.len;
        }
    }
    return written;
}

/// std.Io.Writer drain callback targeting STDOUT_FILENO.
fn drainStdout(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
    _ = w;
    return drainToFd(std.posix.STDOUT_FILENO, data, splat);
}

/// std.Io.Writer drain callback targeting STDERR_FILENO.
fn drainStderr(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
    _ = w;
    return drainToFd(std.posix.STDERR_FILENO, data, splat);
}
// VTables wiring the generic std.Io.Writer interface to the fd-backed drains.
const stdout_vtable = std.Io.Writer.VTable{ .drain = drainStdout };
const stderr_vtable = std.Io.Writer.VTable{ .drain = drainStderr };

/// Unbuffered std.Io.Writer over stdout (zero-length buffer, so every
/// print goes straight through the drain callback).
pub fn stdoutWriter() std.Io.Writer {
    return .{ .vtable = &stdout_vtable, .buffer = &[_]u8{}, .end = 0 };
}

/// Unbuffered std.Io.Writer over stderr; same shape as stdoutWriter.
pub fn stderrWriter() std.Io.Writer {
    return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 };
}

View file

@ -6,24 +6,127 @@ pub const EmbeddedRsync = struct {
const Self = @This();
/// Extract embedded rsync binary to temporary location
pub fn extractRsyncBinary(self: Self) ![]const u8 {
const rsync_path = "/tmp/ml_rsync";
/// Encode `bytes` as lowercase hex into `out`.
/// `out` must be at least 2 * bytes.len long; extra space is left untouched.
fn encodeHexLower(out: []u8, bytes: []const u8) void {
    const alphabet = "0123456789abcdef";
    for (bytes, 0..) |b, idx| {
        out[idx * 2] = alphabet[b >> 4];
        out[idx * 2 + 1] = alphabet[b & 0x0f];
    }
}
// Check if rsync binary already exists
if (std.fs.cwd().openFile(rsync_path, .{})) |file| {
file.close();
// Check if it's executable
const stat = try std.fs.cwd().statFile(rsync_path);
if (stat.mode & 0o111 != 0) {
return try self.allocator.dupe(u8, rsync_path);
}
} else |_| {
// Need to extract the binary
try self.extractAndSetExecutable(rsync_path);
fn cacheBaseDir(self: Self) ![]u8 {
if (std.posix.getenv("XDG_CACHE_HOME")) |z| {
return self.allocator.dupe(u8, std.mem.sliceTo(z, 0));
}
return try self.allocator.dupe(u8, rsync_path);
const home = if (std.posix.getenv("HOME")) |z| std.mem.sliceTo(z, 0) else return error.MissingHome;
const target = @import("builtin").target;
if (target.os.tag == .macos) {
return std.fmt.allocPrint(self.allocator, "{s}/Library/Caches", .{home});
}
return std.fmt.allocPrint(self.allocator, "{s}/.cache", .{home});
}
/// Honor the ML_RSYNC_PATH environment override.
/// Returns null when the variable is unset or empty; error.InvalidRsyncPath
/// when the target cannot be opened or lacks an execute bit; otherwise a
/// caller-owned copy of the path.
fn resolveRsyncPathOverride(self: Self) !?[]u8 {
    const raw = std.posix.getenv("ML_RSYNC_PATH") orelse return null;
    const path = std.mem.sliceTo(raw, 0);
    if (path.len == 0) return null;
    // Basic validation: the override must exist and be executable.
    const file = std.fs.openFileAbsolute(path, .{}) catch return error.InvalidRsyncPath;
    defer file.close();
    const info = try file.stat();
    if (info.mode & 0o111 == 0) return error.InvalidRsyncPath;
    return try self.allocator.dupe(u8, path);
}
/// Scan the colon-separated PATH for an executable named "rsync".
/// Returns a caller-owned path for the first executable hit, or null when
/// PATH is unset or contains no usable candidate.
fn resolveSystemRsyncPath(self: Self) !?[]u8 {
    const path_z = std.posix.getenv("PATH") orelse return null;
    const path = std.mem.sliceTo(path_z, 0);
    var it = std.mem.splitScalar(u8, path, ':');
    while (it.next()) |dir| {
        if (dir.len == 0) continue;
        const candidate = try std.fmt.allocPrint(self.allocator, "{s}/{s}", .{ dir, "rsync" });
        // errdefer only covers error returns; every `continue` below must
        // free `candidate` by hand, which is why the frees are explicit.
        errdefer self.allocator.free(candidate);
        // NOTE(review): openFileAbsolute assumes PATH entries are absolute; a
        // relative entry (e.g. ".") would be skipped via the catch — confirm intended.
        const file = std.fs.openFileAbsolute(candidate, .{}) catch {
            self.allocator.free(candidate);
            continue;
        };
        defer file.close();
        const stat = file.stat() catch {
            self.allocator.free(candidate);
            continue;
        };
        // Require at least one execute bit (user, group, or other).
        if (stat.mode & 0o111 == 0) {
            self.allocator.free(candidate);
            continue;
        }
        return @as(?[]u8, candidate);
    }
    return null;
}
/// Extract embedded rsync binary to temporary location.
///
/// Resolution order: the ML_RSYNC_PATH override, then any rsync found on
/// PATH, and only then a cached extraction of the embedded payload. The cache
/// file name embeds a prefix of the payload's SHA-256 digest, so a new
/// embedded binary never collides with a stale extraction.
/// Returns a caller-owned path string.
pub fn extractRsyncBinary(self: Self) ![]const u8 {
    if (try self.resolveRsyncPathOverride()) |p| {
        return p;
    }
    if (try self.resolveSystemRsyncPath()) |p| {
        return p;
    }
    const embedded_binary = @import("rsync_embedded_binary.zig");
    const digest = embedded_binary.rsyncBinarySha256();
    var digest_hex: [64]u8 = undefined;
    encodeHexLower(digest_hex[0..], digest[0..]);
    const cache_base = try self.cacheBaseDir();
    defer self.allocator.free(cache_base);
    const cache_dir = try std.fmt.allocPrint(self.allocator, "{s}/fetchml/rsync", .{cache_base});
    defer self.allocator.free(cache_dir);
    // An already-existing cache directory is fine; any other error aborts.
    std.fs.cwd().makePath(cache_dir) catch |err| switch (err) {
        error.PathAlreadyExists => {},
        else => return err,
    };
    // First 16 hex chars of the digest are enough to key the cache entry.
    const rsync_path = try std.fmt.allocPrint(self.allocator, "{s}/ml_rsync_{s}", .{ cache_dir, digest_hex[0..16] });
    errdefer self.allocator.free(rsync_path);
    // If file exists and matches the embedded digest, reuse it.
    if (std.fs.openFileAbsolute(rsync_path, .{})) |file| {
        defer file.close();
        const stat = try file.stat();
        if (stat.mode & 0o111 != 0) {
            // Re-hash the on-disk copy (read capped at 4 MiB) to detect
            // truncation or corruption before trusting the cache.
            const data = try file.readToEndAlloc(self.allocator, 1024 * 1024 * 4);
            defer self.allocator.free(data);
            var on_disk: [32]u8 = undefined;
            std.crypto.hash.sha2.Sha256.hash(data, &on_disk, .{});
            if (std.mem.eql(u8, on_disk[0..], digest[0..])) {
                return rsync_path;
            }
        }
    } else |_| {
        // Not present; will extract.
    }
    try self.extractAndSetExecutable(rsync_path);
    return rsync_path;
}
/// Extract rsync binary from embedded data and set executable permissions
@ -39,14 +142,15 @@ pub const EmbeddedRsync = struct {
defer self.allocator.free(debug_msg);
std.log.debug("{s}", .{debug_msg});
// Write embedded binary to file system
try std.fs.cwd().writeFile(.{ .sub_path = path, .data = binary_data });
// Write embedded binary to file system (absolute path)
var file = try std.fs.createFileAbsolute(path, .{ .truncate = true, .read = false, .mode = 0o755 });
defer file.close();
try file.writeAll(binary_data);
// Set executable permissions using OS API
// Ensure executable permissions
const mode = 0o755; // rwxr-xr-x
std.posix.fchmodat(std.fs.cwd().fd, path, mode, 0) catch |err| {
std.posix.fchmod(file.handle, mode) catch |err| {
std.log.warn("Failed to set executable permissions on {s}: {}", .{ path, err });
// Continue anyway - the script might still work
};
}
@ -79,7 +183,13 @@ pub const EmbeddedRsync = struct {
child.stdout_behavior = .Inherit;
child.stderr_behavior = .Inherit;
const term = try child.spawnAndWait();
const term = child.spawnAndWait() catch |err| {
std.log.err(
"Failed to execute rsync at '{s}': {}. If your environment blocks executing extracted binaries (e.g. noexec), set ML_RSYNC_PATH to a system rsync.",
.{ rsync_path, err },
);
return err;
};
switch (term) {
.Exited => |code| {

View file

@ -1,4 +1,119 @@
const std = @import("std");
const build_options = @import("build_options");
/// True when `data` begins with a "#!" shebang, i.e. a script rather than a
/// native executable image.
fn isScript(data: []const u8) bool {
    return std.mem.startsWith(u8, data, "#!");
}
/// Bounds-checked big-endian u32 load at `offset`; null when fewer than
/// 4 bytes remain. Uses std.mem.readInt instead of hand-rolled shifts, and
/// checks bounds without computing `offset + 4` (which could overflow for a
/// pathological offset).
fn readU32BE(data: []const u8, offset: usize) ?u32 {
    if (data.len < 4 or offset > data.len - 4) return null;
    return std.mem.readInt(u32, data[offset..][0..4], .big);
}
/// Bounds-checked little-endian u32 load at `offset`; null when fewer than
/// 4 bytes remain. Stdlib readInt replaces manual byte assembly; the bounds
/// check avoids `offset + 4` overflow.
fn readU32LE(data: []const u8, offset: usize) ?u32 {
    if (data.len < 4 or offset > data.len - 4) return null;
    return std.mem.readInt(u32, data[offset..][0..4], .little);
}
/// Bounds-checked little-endian u16 load at `offset`; null when fewer than
/// 2 bytes remain. Stdlib readInt replaces manual byte assembly; the bounds
/// check avoids `offset + 2` overflow.
fn readU16LE(data: []const u8, offset: usize) ?u16 {
    if (data.len < 2 or offset > data.len - 2) return null;
    return std.mem.readInt(u16, data[offset..][0..2], .little);
}
/// Bounds-checked big-endian u16 load at `offset`; null when fewer than
/// 2 bytes remain. Stdlib readInt replaces manual byte assembly; the bounds
/// check avoids `offset + 2` overflow.
fn readU16BE(data: []const u8, offset: usize) ?u16 {
    if (data.len < 2 or offset > data.len - 2) return null;
    return std.mem.readInt(u16, data[offset..][0..2], .big);
}
/// Mach-O CPU_TYPE_* value matching `arch`; null for architectures this
/// detector does not handle.
fn machCpuTypeForArch(arch: std.Target.Cpu.Arch) ?u32 {
    if (arch == .x86_64) return 0x01000007; // CPU_TYPE_X86_64
    if (arch == .aarch64) return 0x0100000c; // CPU_TYPE_ARM64
    return null;
}
/// ELF e_machine value matching `arch`; null for architectures this
/// detector does not handle.
fn elfMachineForArch(arch: std.Target.Cpu.Arch) ?u16 {
    if (arch == .x86_64) return 62; // EM_X86_64
    if (arch == .aarch64) return 183; // EM_AARCH64
    return null;
}
/// True when `data` is a thin 64-bit Mach-O whose cputype equals
/// `want_cputype`. Both storage byte orders of the magic are accepted, and
/// the cputype field (offset 4) is read in the matching order.
fn machOHasCpuType(data: []const u8, want_cputype: u32) bool {
    const MH_MAGIC_64: u32 = 0xfeedfacf;
    const MH_CIGAM_64: u32 = 0xcffaedfe;
    const le = readU32LE(data, 0) orelse return false;
    const be = readU32BE(data, 0) orelse return false;
    const cputype: u32 = blk: {
        if (le == MH_MAGIC_64) {
            break :blk readU32LE(data, 4) orelse return false;
        }
        if (be == MH_MAGIC_64 or le == MH_CIGAM_64) {
            break :blk readU32BE(data, 4) orelse return false;
        }
        return false;
    };
    return cputype == want_cputype;
}
/// True when `data` is a fat (universal) Mach-O containing an arch slice
/// whose cputype equals `want_cputype`. Fat headers are always big-endian:
/// magic at 0, nfat_arch at 4, then 20-byte fat_arch records from offset 8.
fn fatMachOHasCpuType(data: []const u8, want_cputype: u32) bool {
    const FAT_MAGIC: u32 = 0xcafebabe;
    if ((readU32BE(data, 0) orelse return false) != FAT_MAGIC) return false;
    const arch_count = readU32BE(data, 4) orelse return false;
    var offset: usize = 8;
    var seen: u32 = 0;
    while (seen < arch_count) : (seen += 1) {
        // readU32BE bounds-checks the record read itself.
        const cputype = readU32BE(data, offset) orelse return false;
        if (cputype == want_cputype) return true;
        offset += 20;
        if (offset > data.len) return false;
    }
    return false;
}
/// Minimal ELF probe: verify the magic bytes and compare e_machine
/// (offset 18) against `want_machine`.
fn elfHasMachine(data: []const u8, want_machine: u16) bool {
    if (data.len < 20) return false;
    const has_magic = data[0] == 0x7f and data[1] == 'E' and data[2] == 'L' and data[3] == 'F';
    if (!has_magic) return false;
    // EI_DATA byte selects the encoding: 1 = little-endian, 2 = big-endian.
    const machine: u16 = switch (data[5]) {
        1 => readU16LE(data, 18) orelse return false,
        2 => readU16BE(data, 18) orelse return false,
        else => return false,
    };
    return machine == want_machine;
}
/// Heuristic: does `data` look runnable on the build target?
/// Shebang scripts always pass; otherwise the fat/thin Mach-O and ELF
/// headers are checked against the target CPU architecture.
fn isNativeForTarget(data: []const u8) bool {
    if (isScript(data)) return true;
    const arch = @import("builtin").target.cpu.arch;
    if (machCpuTypeForArch(arch)) |cputype| {
        if (fatMachOHasCpuType(data, cputype) or machOHasCpuType(data, cputype)) return true;
    }
    if (elfMachineForArch(arch)) |machine| {
        if (elfHasMachine(data, machine)) return true;
    }
    return false;
}
/// Embedded rsync binary data
/// For dev builds: uses placeholder wrapper that calls system rsync
@ -8,16 +123,32 @@ const std = @import("std");
/// 1. Download or build a static rsync binary for your target platform
/// 2. Place it at cli/src/assets/rsync_release.bin
/// 3. Build with: zig build prod (or release/cross targets)
const placeholder_data = @embedFile("../assets/rsync_placeholder.bin");
// Check if rsync_release.bin exists, otherwise fall back to placeholder
const rsync_binary_data = if (@import("builtin").mode == .Debug or @import("builtin").mode == .ReleaseSafe)
@embedFile("../assets/rsync_placeholder.bin")
const release_data = if (build_options.has_rsync_release)
@embedFile(build_options.rsync_release_path)
else
// For ReleaseSmall and ReleaseFast, try to use the release binary
@embedFile("../assets/rsync_release.bin");
placeholder_data;
// Prefer placeholder in Debug/ReleaseSafe. In Release*, only use rsync_release.bin if it
// appears compatible with the target architecture; otherwise fall back to placeholder.
const use_release = (@import("builtin").mode != .Debug and @import("builtin").mode != .ReleaseSafe) and
build_options.has_rsync_release and
!isScript(release_data) and
isNativeForTarget(release_data);
const rsync_binary_data = if (use_release) release_data else placeholder_data;
pub const RSYNC_BINARY: []const u8 = rsync_binary_data;
pub const USING_RELEASE_BINARY: bool = use_release;
/// SHA-256 digest of the rsync payload that will be extracted
/// (RSYNC_BINARY — either the release binary or the placeholder).
/// Callers use it to key and verify the on-disk extraction cache.
pub fn rsyncBinarySha256() [32]u8 {
    var digest: [32]u8 = undefined;
    std.crypto.hash.sha2.Sha256.hash(RSYNC_BINARY, &digest, .{});
    return digest;
}
/// Get embedded rsync binary data
pub fn getRsyncBinary() []const u8 {
return RSYNC_BINARY;

View file

@ -1,16 +1,17 @@
const std = @import("std");
const testing = std.testing;
const src = @import("src");
const commands = src.commands;
test "jupyter top-level action includes create" {
try testing.expect(src.commands.jupyter.isValidTopLevelAction("create"));
try testing.expect(src.commands.jupyter.isValidTopLevelAction("start"));
try testing.expect(!src.commands.jupyter.isValidTopLevelAction("bogus"));
try testing.expect(commands.jupyter.isValidTopLevelAction("create"));
try testing.expect(commands.jupyter.isValidTopLevelAction("start"));
try testing.expect(!commands.jupyter.isValidTopLevelAction("bogus"));
}
test "jupyter defaultWorkspacePath prefixes ./" {
const allocator = testing.allocator;
const p = try src.commands.jupyter.defaultWorkspacePath(allocator, "my-workspace");
const p = try commands.jupyter.defaultWorkspacePath(allocator, "my-workspace");
defer allocator.free(p);
try testing.expectEqualStrings("./my-workspace", p);

View file

@ -6,6 +6,7 @@ test "CLI basic functionality" {
// Test that CLI module can be imported
const allocator = testing.allocator;
_ = allocator;
_ = src;
// Test basic string operations used in CLI
const test_str = "ml sync";

View file

@ -1,6 +1,5 @@
const std = @import("std");
const testing = std.testing;
const src = @import("src");
test "queue command argument parsing" {
// Test various queue command argument combinations

View file

@ -2,7 +2,6 @@ const std = @import("std");
const testing = std.testing;
const src = @import("src");
const protocol = src.net.protocol;
fn roundTrip(allocator: std.mem.Allocator, packet: protocol.ResponsePacket) !protocol.ResponsePacket {

View file

@ -1,30 +1,121 @@
const std = @import("std");
const testing = std.testing;
// Simple mock rsync for testing
const MockRsyncEmbedded = struct {
const EmbeddedRsync = struct {
allocator: std.mem.Allocator,
const c = @cImport({
@cInclude("stdlib.h");
});
fn extractRsyncBinary(self: EmbeddedRsync) ![]const u8 {
// Simple mock - return a dummy path
return try std.fmt.allocPrint(self.allocator, "/tmp/mock_rsync", .{});
}
};
};
const src = @import("src");
const utils = src.utils;
const rsync_embedded = MockRsyncEmbedded;
fn isScript(data: []const u8) bool {
return data.len >= 2 and data[0] == '#' and data[1] == '!';
}
test "embedded rsync binary creation" {
const allocator = testing.allocator;
var embedded_rsync = rsync_embedded.EmbeddedRsync{ .allocator = allocator };
// Ensure the override doesn't influence this test.
_ = c.unsetenv("ML_RSYNC_PATH");
// Force embedded fallback by clearing PATH for this process so system rsync isn't found.
const old_path_z = std.posix.getenv("PATH");
if (c.setenv("PATH", "", 1) != 0) return error.Unexpected;
defer {
if (old_path_z) |z| {
_ = c.setenv("PATH", z, 1);
} else {
_ = c.unsetenv("PATH");
}
}
var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator };
// Test binary extraction
const rsync_path = try embedded_rsync.extractRsyncBinary();
defer allocator.free(rsync_path);
// Verify the path was created
try testing.expect(rsync_path.len > 0);
try testing.expect(std.mem.startsWith(u8, rsync_path, "/tmp/"));
// File exists and is executable
const file = try std.fs.openFileAbsolute(rsync_path, .{});
defer file.close();
const stat = try file.stat();
try testing.expect(stat.mode & 0o111 != 0);
// Contents match embedded payload
const data = try file.readToEndAlloc(allocator, 1024 * 1024 * 4);
defer allocator.free(data);
var digest: [32]u8 = undefined;
std.crypto.hash.sha2.Sha256.hash(data, &digest, .{});
const embedded_digest = utils.rsync_embedded_binary.rsyncBinarySha256();
try testing.expect(std.mem.eql(u8, &digest, &embedded_digest));
// Sanity: if we're not using the release binary, it should be the placeholder script.
if (!utils.rsync_embedded_binary.USING_RELEASE_BINARY) {
try testing.expect(isScript(data));
}
}
test "embedded rsync honors ML_RSYNC_PATH override" {
    const allocator = testing.allocator;
    // Use a known executable. We don't execute it; we only verify the override is returned.
    const override_path = "/bin/sh";
    // Probe that the file exists, and close the handle immediately — the
    // original `openFileAbsolute(...) != error.FileNotFound` comparison leaked
    // the opened file descriptor for the rest of the test run.
    const probe = try std.fs.openFileAbsolute(override_path, .{});
    probe.close();
    if (c.setenv("ML_RSYNC_PATH", override_path, 1) != 0) return error.Unexpected;
    defer _ = c.unsetenv("ML_RSYNC_PATH");
    var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator };
    const rsync_path = try embedded_rsync.extractRsyncBinary();
    defer allocator.free(rsync_path);
    try testing.expectEqualStrings(override_path, rsync_path);
}
test "embedded rsync prefers system rsync when available" {
    const allocator = testing.allocator;
    // Make sure an explicit override cannot mask PATH resolution.
    _ = c.unsetenv("ML_RSYNC_PATH");
    // Find rsync on PATH (simple scan). If not present, skip.
    const path_z = std.posix.getenv("PATH") orelse return;
    const path = std.mem.sliceTo(path_z, 0);
    var sys_rsync: ?[]u8 = null;
    defer if (sys_rsync) |p| allocator.free(p);
    var it = std.mem.splitScalar(u8, path, ':');
    while (it.next()) |dir| {
        if (dir.len == 0) continue;
        const candidate = std.fmt.allocPrint(allocator, "{s}/{s}", .{ dir, "rsync" }) catch return;
        errdefer allocator.free(candidate);
        // Each bail-out below frees `candidate` by hand; errdefer only covers
        // error returns from this scope, not `continue`.
        const file = std.fs.openFileAbsolute(candidate, .{}) catch {
            allocator.free(candidate);
            continue;
        };
        defer file.close();
        const st = file.stat() catch {
            allocator.free(candidate);
            continue;
        };
        // Require at least one execute bit, mirroring resolveSystemRsyncPath.
        if (st.mode & 0o111 == 0) {
            allocator.free(candidate);
            continue;
        }
        sys_rsync = candidate;
        break;
    }
    if (sys_rsync == null) return;
    // The extractor should return exactly the first executable PATH hit.
    var embedded_rsync = utils.rsync_embedded.EmbeddedRsync{ .allocator = allocator };
    const rsync_path = try embedded_rsync.extractRsyncBinary();
    defer allocator.free(rsync_path);
    try testing.expectEqualStrings(sys_rsync.?, rsync_path);
}

View file

@ -3,7 +3,7 @@ const testing = std.testing;
const src = @import("src");
const ws = src.net.ws;
const ws = src.net.ws_client;
test "status prewarm formatting - single entry" {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};