diff --git a/cli/Makefile b/cli/Makefile new file mode 100644 index 0000000..822ce04 --- /dev/null +++ b/cli/Makefile @@ -0,0 +1,120 @@ +# ML Experiment Manager CLI Build System +# Fast, small, and cross-platform builds + +.PHONY: help build dev prod release cross clean install size test run + +# Default target +help: + @echo "ML Experiment Manager CLI - Build System" + @echo "" + @echo "Available targets:" + @echo " build - Build default version (debug)" + @echo " dev - Build development version (fast compile, debug info)" + @echo " prod - Build production version (small binary, stripped)" + @echo " release - Build release version (optimized for speed)" + @echo " cross - Build cross-platform binaries" + @echo " clean - Clean all build artifacts" + @echo " install - Install binary to /usr/local/bin" + @echo " size - Show binary sizes" + @echo " test - Run unit tests" + @echo " run - Build and run with arguments" + @echo "" + @echo "Examples:" + @echo " make dev" + @echo " make prod" + @echo " make cross" + @echo " make run ARGS=\"status\"" + +# Default build +build: + zig build + +# Development build - fast compilation, debug info +dev: + @echo "Building development version..." + zig build dev + @echo "Dev binary: zig-out/dev/ml-dev" + +# Production build - small and fast +prod: + @echo "Building production version (optimized for size)..." + zig build prod + @echo "Production binary: zig-out/prod/ml" + +# Release build - maximum performance +release: + @echo "Building release version (optimized for speed)..." + zig build release + @echo "Release binary: zig-out/release/ml-release" + +# Cross-platform builds +cross: + @echo "Building cross-platform binaries..." + zig build cross + @echo "Cross-platform binaries in: zig-out/cross/" + +# Clean build artifacts +clean: + @echo "Cleaning build artifacts..." + zig build clean + @echo "Cleaned zig-out/ and zig-cache/" + +# Install to system PATH +install: prod + @echo "Installing to /usr/local/bin..." 
+ zig build install-system + @echo "Installed! Run 'ml' from anywhere." + +# Show binary sizes +size: + @echo "Binary sizes:" + zig build size + +# Run tests +test: + @echo "Running unit tests..." + zig build test + +# Run with arguments +run: + @if [ -z "$(ARGS)" ]; then \ + echo "Usage: make run ARGS=\" [options]\""; \ + echo "Example: make run ARGS=\"status\""; \ + exit 1; \ + fi + zig build run -- $(ARGS) + +# Development workflow +dev-test: dev test + @echo "Development build and tests completed!" + +# Production workflow +prod-test: prod test size + @echo "Production build, tests, and size check completed!" + +# Full release workflow +release-all: clean cross prod test size + @echo "Full release workflow completed!" + @echo "Binaries ready for distribution in zig-out/" + +# Quick development commands +quick-run: dev + ./zig-out/dev/ml-dev $(ARGS) + +quick-test: dev test + @echo "Quick dev and test cycle completed!" + +# Check for required tools +check-tools: + @echo "Checking required tools..." + @which zig > /dev/null || (echo "Error: zig not found. Install from https://ziglang.org/" && exit 1) + @echo "All tools found!" + +# Show build info +info: + @echo "Build information:" + @echo " Zig version: $$(zig version)" + @echo " Target: $$(zig target)" + @echo " CPU: $$(uname -m)" + @echo " OS: $$(uname -s)" + @echo " Available memory: $$(free -h 2>/dev/null || vm_stat | grep 'Pages free' || echo 'N/A')" diff --git a/cli/README.md b/cli/README.md new file mode 100644 index 0000000..686dec1 --- /dev/null +++ b/cli/README.md @@ -0,0 +1,54 @@ +# ML CLI + +Fast CLI tool for managing ML experiments. + +## Quick Start + +```bash +# 1. Build +zig build + +# 2. Setup configuration +./zig-out/bin/ml init + +# 3. 
Run experiment +./zig-out/bin/ml sync ./my-experiment --queue +``` + +## Commands + +- `ml init` - Setup configuration +- `ml sync ` - Sync project to server +- `ml queue ` - Queue job for execution +- `ml status` - Check system status +- `ml monitor` - Launch monitoring interface +- `ml cancel ` - Cancel running job +- `ml prune --keep N` - Keep N recent experiments +- `ml watch ` - Auto-sync directory + +## Configuration + +Create `~/.ml/config.toml`: + +```toml +worker_host = "worker.local" +worker_user = "mluser" +worker_base = "/data/ml-experiments" +worker_port = 22 +api_key = "your-api-key" +``` + +## Install + +```bash +# Install to system +make install + +# Or copy binary manually +cp zig-out/bin/ml /usr/local/bin/ +``` + +## Need Help? + +- `ml --help` - Show command help +- `ml --help` - Show command-specific help diff --git a/cli/build.zig b/cli/build.zig new file mode 100644 index 0000000..69031e1 --- /dev/null +++ b/cli/build.zig @@ -0,0 +1,175 @@ +// Targets: +// - default: ml (ReleaseSmall) +// - dev: ml-dev (Debug) +// - prod: ml (ReleaseSmall) -> zig-out/prod +// - release: ml-release (ReleaseFast) -> zig-out/release +// - cross: ml-* per platform -> zig-out/cross + +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseSmall }); + + // Common executable configuration + // Default build optimizes for small binary size (~200KB after strip) + const exe = b.addExecutable(.{ + .name = "ml", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = target, + .optimize = optimize, + }), + }); + + // Embed rsync binary if available + const embed_rsync = b.option(bool, "embed-rsync", "Embed rsync binary in CLI") orelse false; + if (embed_rsync) { + // This would embed the actual rsync binary + // For now, we'll use the wrapper approach + exe.root_module.addCMacro("EMBED_RSYNC", 
"1"); + } + + b.installArtifact(exe); + + // Default run command + const run_cmd = b.addRunArtifact(exe); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_cmd.addArgs(args); + } + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); + + // === Development Build === + // Fast build with debug info for development + const dev_exe = b.addExecutable(.{ + .name = "ml-dev", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = b.resolveTargetQuery(.{}), // Use host target + .optimize = .Debug, + }), + }); + + const dev_install = b.addInstallArtifact(dev_exe, .{ + .dest_dir = .{ .override = .{ .custom = "dev" } }, + }); + const dev_step = b.step("dev", "Build development version (fast compilation, debug info)"); + dev_step.dependOn(&dev_install.step); + + // === Production Build === + // Optimized for size and speed + const prod_exe = b.addExecutable(.{ + .name = "ml", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = target, + .optimize = .ReleaseSmall, // Optimize for small binary size + }), + }); + + const prod_install = b.addInstallArtifact(prod_exe, .{ + .dest_dir = .{ .override = .{ .custom = "prod" } }, + }); + const prod_step = b.step("prod", "Build production version (optimized for size and speed)"); + prod_step.dependOn(&prod_install.step); + + // === Release Build === + // Fully optimized for performance + const release_exe = b.addExecutable(.{ + .name = "ml-release", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = target, + .optimize = .ReleaseFast, // Optimize for speed + }), + }); + + const release_install = b.addInstallArtifact(release_exe, .{ + .dest_dir = .{ .override = .{ .custom = "release" } }, + }); + const release_step = b.step("release", "Build release version (optimized for performance)"); + release_step.dependOn(&release_install.step); + + // === Cross-Platform Builds 
=== + // Build for common platforms with descriptive binary names + const cross_targets = [_]struct { + query: std.Target.Query, + name: []const u8, + }{ + .{ .query = .{ .cpu_arch = .x86_64, .os_tag = .linux }, .name = "ml-linux-x86_64" }, + .{ .query = .{ .cpu_arch = .x86_64, .os_tag = .linux, .abi = .musl }, .name = "ml-linux-musl-x86_64" }, + .{ .query = .{ .cpu_arch = .x86_64, .os_tag = .macos }, .name = "ml-macos-x86_64" }, + .{ .query = .{ .cpu_arch = .aarch64, .os_tag = .macos }, .name = "ml-macos-aarch64" }, + }; + + const cross_step = b.step("cross", "Build cross-platform binaries"); + + for (cross_targets) |ct| { + const cross_target = b.resolveTargetQuery(ct.query); + const cross_exe = b.addExecutable(.{ + .name = ct.name, + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = cross_target, + .optimize = .ReleaseSmall, + }), + }); + + const cross_install = b.addInstallArtifact(cross_exe, .{ + .dest_dir = .{ .override = .{ .custom = "cross" } }, + }); + cross_step.dependOn(&cross_install.step); + } + + // === Clean Step === + const clean_step = b.step("clean", "Clean build artifacts"); + const clean_cmd = b.addSystemCommand(&[_][]const u8{ + "rm", "-rf", + "zig-out", ".zig-cache", + }); + clean_step.dependOn(&clean_cmd.step); + + // === Install Step === + // Install binary to system PATH + const install_step = b.step("install-system", "Install binary to /usr/local/bin"); + const install_cmd = b.addSystemCommand(&[_][]const u8{ + "sudo", "cp", "zig-out/bin/ml", "/usr/local/bin/", + }); + install_step.dependOn(&install_cmd.step); + + // === Size Check === + // Show binary sizes for different builds + const size_step = b.step("size", "Show binary sizes"); + const size_cmd = b.addSystemCommand(&[_][]const u8{ + "sh", "-c", + "if [ -d zig-out/bin ]; then echo 'zig-out/bin contents:'; ls -lh zig-out/bin/; fi; " ++ + "if [ -d zig-out/prod ]; then echo '\nzig-out/prod contents:'; ls -lh zig-out/prod/; fi; " ++ + "if [ -d 
zig-out/release ]; then echo '\nzig-out/release contents:'; ls -lh zig-out/release/; fi; " ++ + "if [ -d zig-out/cross ]; then echo '\nzig-out/cross contents:'; ls -lh zig-out/cross/; fi; " ++ + "if [ ! -d zig-out/bin ] && [ ! -d zig-out/prod ] && [ ! -d zig-out/release ] && [ ! -d zig-out/cross ]; then " ++ + "echo 'No CLI binaries found. Run zig build (default), zig build dev, zig build prod, or zig build release first.'; fi", + }); + size_step.dependOn(&size_cmd.step); + + // Test step + const source_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = target, + .optimize = optimize, + }); + const test_module = b.createModule(.{ + .root_source_file = b.path("tests/main_test.zig"), + .target = target, + .optimize = optimize, + }); + test_module.addImport("src", source_module); + + const exe_tests = b.addTest(.{ + .root_module = test_module, + }); + const run_exe_tests = b.addRunArtifact(exe_tests); + const test_step = b.step("test", "Run unit tests"); + test_step.dependOn(&run_exe_tests.step); +} diff --git a/cli/scripts/ml_completion.bash b/cli/scripts/ml_completion.bash new file mode 100644 index 0000000..6963ea2 --- /dev/null +++ b/cli/scripts/ml_completion.bash @@ -0,0 +1,88 @@ +# Bash completion for the `ml` CLI +# Usage: +# source /path/to/ml_completion.bash +# or +# echo 'source /path/to/ml_completion.bash' >> ~/.bashrc + +_ml_completions() +{ + local cur prev cmds + COMPREPLY=() + cur="${COMP_WORDS[COMP_CWORD]}" + prev="${COMP_WORDS[COMP_CWORD-1]}" + + # Global options + global_opts="--help --verbose --quiet --monitor" + + # Top-level subcommands + cmds="init sync queue status monitor cancel prune watch dataset experiment" + + # If completing the subcommand itself + if [[ ${COMP_CWORD} -eq 1 ]]; then + COMPREPLY=( $(compgen -W "${cmds} ${global_opts}" -- "${cur}") ) + return 0 + fi + + # Handle global options anywhere + case "${prev}" in + --help|--verbose|--quiet|--monitor) + # No further completion after global flags + 
return 0 + ;; + esac + + case "${COMP_WORDS[1]}" in + init) + # No specific arguments for init + COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") ) + ;; + sync) + # Complete directories for sync + COMPREPLY=( $(compgen -d -- "${cur}") ) + ;; + queue) + # Suggest common job names (static for now) + COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") ) + ;; + status) + COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") ) + ;; + monitor) + COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") ) + ;; + cancel) + # Job names often free-form; no special completion + ;; + prune) + case "${prev}" in + --keep) + COMPREPLY=( $(compgen -W "1 2 3 4 5" -- "${cur}") ) + ;; + --older-than) + COMPREPLY=( $(compgen -W "1h 6h 12h 1d 3d 7d" -- "${cur}") ) + ;; + *) + COMPREPLY=( $(compgen -W "--keep --older-than ${global_opts}" -- "${cur}") ) + ;; + esac + ;; + watch) + # Complete directories for watch + COMPREPLY=( $(compgen -d -- "${cur}") ) + ;; + dataset) + COMPREPLY=( $(compgen -W "list upload download delete info search" -- "${cur}") ) + ;; + experiment) + COMPREPLY=( $(compgen -W "log show" -- "${cur}") ) + ;; + *) + # Fallback to global options + COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") ) + ;; + esac + + return 0 +} + +complete -F _ml_completions ml diff --git a/cli/scripts/ml_completion.zsh b/cli/scripts/ml_completion.zsh new file mode 100644 index 0000000..6c95d6f --- /dev/null +++ b/cli/scripts/ml_completion.zsh @@ -0,0 +1,119 @@ +# Zsh completion for the `ml` CLI +# Usage: +# source /path/to/ml_completion.zsh +# or add to your ~/.zshrc: +# source /path/to/ml_completion.zsh + +_ml() { + local -a subcommands + subcommands=( + 'init:Setup configuration interactively' + 'sync:Sync project to server' + 'queue:Queue job for execution' + 'status:Get system status' + 'monitor:Launch TUI via SSH' + 'cancel:Cancel running job' + 'prune:Prune old experiments' + 'watch:Watch directory for auto-sync' + 'dataset:Manage datasets' + 'experiment:Manage 
experiments' + ) + + local -a global_opts + global_opts=( + '--help:Show this help message' + '--verbose:Enable verbose output' + '--quiet:Suppress non-error output' + '--monitor:Monitor progress of long-running operations' + ) + + local curcontext="$curcontext" state line + _arguments -C \ + '1:command:->cmds' \ + '*::arg:->args' + + case $state in + cmds) + _describe -t commands 'ml commands' subcommands + return + ;; + args) + case $words[2] in + sync) + _arguments -C \ + '--help[Show sync help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' \ + '1:directory:_directories' + ;; + queue) + _arguments -C \ + '--help[Show queue help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' \ + '1:job name:' + ;; + status) + _arguments -C \ + '--help[Show status help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' + ;; + monitor) + _arguments -C \ + '--help[Show monitor help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' + ;; + cancel) + _arguments -C \ + '--help[Show cancel help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' \ + '1:job name:' + ;; + prune) + _arguments -C \ + '--help[Show prune help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' \ + '--keep[Keep N most recent experiments]:number:' \ + '--older-than[Remove experiments older than D days]:days:' + ;; + watch) + _arguments -C \ + '--help[Show watch help]' \ + '--verbose[Enable verbose output]' \ + '--quiet[Suppress non-error output]' \ + '--monitor[Monitor progress]' \ + '1:directory:_directories' + ;; + dataset) + _values 'dataset subcommand' \ + 'list[list datasets]' \ + 'upload[upload dataset]' \ + 
'download[download dataset]' \ + 'delete[delete dataset]' \ + 'info[dataset info]' \ + 'search[search datasets]' + ;; + experiment) + _values 'experiment subcommand' \ + 'log[log metrics]' \ + 'show[show experiment]' + ;; + *) + _arguments -C "${global_opts[@]}" + ;; + esac + ;; + esac +} + +compdef _ml ml \ No newline at end of file diff --git a/cli/src/assets/README.md b/cli/src/assets/README.md new file mode 100644 index 0000000..fcffe05 --- /dev/null +++ b/cli/src/assets/README.md @@ -0,0 +1,93 @@ +# Rsync Binary Setup for Release Builds + +## Overview + +This directory contains rsync binaries for the ML CLI: + +- `rsync_placeholder.bin` - Wrapper script for dev builds (calls system rsync) +- `rsync_release.bin` - Full static rsync binary for release builds (not in repo) + +## Build Modes + +### Development/Debug Builds +- Uses `rsync_placeholder.bin` (98 bytes) +- Calls system rsync via wrapper script +- Results in ~152KB CLI binary +- Requires rsync installed on the system + +### Release Builds (ReleaseSmall, ReleaseFast) +- Uses `rsync_release.bin` (300-500KB static binary) +- Fully self-contained, no dependencies +- Results in ~450-650KB CLI binary +- Works on any system without rsync installed + +## Preparing Release Binaries + +### Option 1: Download Pre-built Static Rsync + +For macOS ARM64: +```bash +cd cli/src/assets +curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-macos-arm64 -o rsync_release.bin +chmod +x rsync_release.bin +``` + +For Linux x86_64: +```bash +cd cli/src/assets +curl -L https://github.com/WayneD/rsync/releases/download/v3.2.7/rsync-linux-x86_64 -o rsync_release.bin +chmod +x rsync_release.bin +``` + +### Option 2: Build Static Rsync Yourself + +```bash +# Clone rsync +git clone https://github.com/WayneD/rsync.git +cd rsync + +# Configure for static build +./configure CFLAGS="-static" LDFLAGS="-static" --disable-xxhash --disable-zstd + +# Build +make + +# Copy to assets +cp rsync 
../fetch_ml/cli/src/assets/rsync_release.bin +``` + +### Option 3: Use System Rsync (Temporary) + +For testing release builds without a static binary: +```bash +cd cli/src/assets +cp rsync_placeholder.bin rsync_release.bin +``` + +This will still use the wrapper, but allows builds to complete. + +## Verification + +After placing rsync_release.bin: + +```bash +# Verify it's executable +file cli/src/assets/rsync_release.bin + +# Test it +./cli/src/assets/rsync_release.bin --version + +# Build release +cd cli +zig build prod + +# Check binary size +ls -lh zig-out/prod/ml +``` + +## Notes + +- `rsync_release.bin` is not tracked in git (add to .gitignore if needed) +- Different platforms need different static binaries +- For cross-compilation, provide platform-specific binaries +- The wrapper approach for dev builds is intentional for fast iteration diff --git a/cli/src/assets/rsync_placeholder.bin b/cli/src/assets/rsync_placeholder.bin new file mode 100755 index 0000000..1db52e2 --- /dev/null +++ b/cli/src/assets/rsync_placeholder.bin @@ -0,0 +1,15 @@ +#!/bin/bash +# Rsync wrapper for development builds +# This calls the system's rsync instead of embedding a full binary +# Keeps the dev binary small (152KB) while still functional + +# Find rsync on the system +RSYNC_PATH=$(which rsync 2>/dev/null || echo "/usr/bin/rsync") + +if [ ! -x "$RSYNC_PATH" ]; then + echo "Error: rsync not found on system. Please install rsync or use a release build with embedded rsync." 
>&2 + exit 127 +fi + +# Pass all arguments to system rsync +exec "$RSYNC_PATH" "$@" diff --git a/cli/src/commands/cancel.zig b/cli/src/commands/cancel.zig new file mode 100644 index 0000000..4e3cfab --- /dev/null +++ b/cli/src/commands/cancel.zig @@ -0,0 +1,77 @@ +const std = @import("std"); +const Config = @import("../config.zig").Config; +const ws = @import("../net/ws.zig"); +const crypto = @import("../utils/crypto.zig"); +const logging = @import("../utils/logging.zig"); + +const UserContext = struct { + name: []const u8, + admin: bool, + allocator: std.mem.Allocator, + + pub fn deinit(self: *UserContext) void { + self.allocator.free(self.name); + } +}; + +fn authenticateUser(allocator: std.mem.Allocator, config: Config) !UserContext { + // Validate API key by making a simple API call to the server + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + // Try to connect with the API key to validate it + var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| { + switch (err) { + error.ConnectionRefused => return error.ConnectionFailed, + error.NetworkUnreachable => return error.ServerUnreachable, + error.InvalidURL => return error.ConfigInvalid, + else => return error.AuthenticationFailed, + } + }; + defer client.close(); + + // For now, create a user context after successful authentication + // In a real implementation, this would get user info from the server + const user_name = try allocator.dupe(u8, "authenticated_user"); + return UserContext{ + .name = user_name, + .admin = false, + .allocator = allocator, + }; +} + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + if (args.len == 0) { + std.debug.print("Usage: ml cancel \n", .{}); + return error.InvalidArgs; + } + + const job_name = args[0]; + + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Authenticate with 
server to get user context + var user_context = try authenticateUser(allocator, config); + defer user_context.deinit(); + + // Use plain password for WebSocket authentication, hash for binary protocol + const api_key_plain = config.api_key; // Plain password from config + const api_key_hash = try crypto.hashString(allocator, api_key_plain); + defer allocator.free(api_key_hash); + + // Connect to WebSocket and send cancel message + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, api_key_plain); + defer client.close(); + + try client.sendCancelJob(job_name, api_key_hash); + + // Receive structured response with user context + try client.receiveAndHandleCancelResponse(allocator, user_context, job_name); +} diff --git a/cli/src/commands/dataset.zig b/cli/src/commands/dataset.zig new file mode 100644 index 0000000..a547919 --- /dev/null +++ b/cli/src/commands/dataset.zig @@ -0,0 +1,240 @@ +const std = @import("std"); +const Config = @import("../config.zig").Config; +const ws = @import("../net/ws.zig"); +const crypto = @import("../utils/crypto.zig"); +const colors = @import("../utils/colors.zig"); +const logging = @import("../utils/logging.zig"); + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + if (args.len == 0) { + colors.printError("Usage: ml dataset [options]\n", .{}); + colors.printInfo("Actions:\n", .{}); + colors.printInfo(" list List registered datasets\n", .{}); + colors.printInfo(" register Register a dataset with URL\n", .{}); + colors.printInfo(" info Show dataset information\n", .{}); + colors.printInfo(" search Search datasets by name/description\n", .{}); + return error.InvalidArgs; + } + + const action = args[0]; + + if (std.mem.eql(u8, action, "list")) { + try listDatasets(allocator); + } else if (std.mem.eql(u8, action, "register")) { + if (args.len < 3) { + colors.printError("Usage: ml 
dataset register \n", .{}); + return error.InvalidArgs; + } + try registerDataset(allocator, args[1], args[2]); + } else if (std.mem.eql(u8, action, "info")) { + if (args.len < 2) { + colors.printError("Usage: ml dataset info \n", .{}); + return error.InvalidArgs; + } + try showDatasetInfo(allocator, args[1]); + } else if (std.mem.eql(u8, action, "search")) { + if (args.len < 2) { + colors.printError("Usage: ml dataset search \n", .{}); + return error.InvalidArgs; + } + try searchDatasets(allocator, args[1]); + } else { + colors.printError("Unknown action: {s}\n", .{action}); + return error.InvalidArgs; + } +} + +fn listDatasets(allocator: std.mem.Allocator) !void { + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Authenticate with server to get user context + var user_context = try authenticateUser(allocator, config); + defer user_context.deinit(); + + // Connect to WebSocket and request dataset list + const api_key_plain = config.api_key; + const api_key_hash = try crypto.hashString(allocator, api_key_plain); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, api_key_plain); + defer client.close(); + + try client.sendDatasetList(api_key_hash); + + // Receive and display dataset list + const response = try client.receiveAndHandleDatasetResponse(allocator); + defer allocator.free(response); + + colors.printInfo("Registered Datasets:\n", .{}); + colors.printInfo("=====================\n\n", .{}); + + // Parse and display datasets (simplified for now) + if (std.mem.eql(u8, response, "[]")) { + colors.printWarning("No datasets registered.\n", .{}); + colors.printInfo("Use 'ml dataset register ' to add a dataset.\n", .{}); + } else { + colors.printSuccess("{s}\n", .{response}); + } +} + +fn registerDataset(allocator: 
std.mem.Allocator, name: []const u8, url: []const u8) !void { + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Validate URL format + if (!std.mem.startsWith(u8, url, "http://") and !std.mem.startsWith(u8, url, "https://") and + !std.mem.startsWith(u8, url, "s3://") and !std.mem.startsWith(u8, url, "gs://")) + { + colors.printError("Invalid URL format. Supported: http://, https://, s3://, gs://\n", .{}); + return error.InvalidURL; + } + + // Authenticate with server + var user_context = try authenticateUser(allocator, config); + defer user_context.deinit(); + + // Connect to WebSocket and register dataset + const api_key_plain = config.api_key; + const api_key_hash = try crypto.hashString(allocator, api_key_plain); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, api_key_plain); + defer client.close(); + + try client.sendDatasetRegister(name, url, api_key_hash); + + // Receive response + const response = try client.receiveAndHandleDatasetResponse(allocator); + defer allocator.free(response); + + if (std.mem.startsWith(u8, response, "ERROR")) { + colors.printError("Failed to register dataset: {s}\n", .{response}); + } else { + colors.printSuccess("Dataset '{s}' registered successfully!\n", .{name}); + colors.printInfo("URL: {s}\n", .{url}); + } +} + +fn showDatasetInfo(allocator: std.mem.Allocator, name: []const u8) !void { + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Authenticate with server + var user_context = try authenticateUser(allocator, config); + defer user_context.deinit(); + + // Connect to WebSocket and get dataset info + const api_key_plain = config.api_key; + const api_key_hash = try crypto.hashString(allocator, api_key_plain); 
+ defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, api_key_plain); + defer client.close(); + + try client.sendDatasetInfo(name, api_key_hash); + + // Receive response + const response = try client.receiveAndHandleDatasetResponse(allocator); + defer allocator.free(response); + + if (std.mem.startsWith(u8, response, "ERROR") or std.mem.startsWith(u8, response, "NOT_FOUND")) { + colors.printError("Dataset '{s}' not found.\n", .{name}); + } else { + colors.printInfo("Dataset Information:\n", .{}); + colors.printInfo("===================\n", .{}); + colors.printSuccess("Name: {s}\n", .{name}); + colors.printSuccess("Details: {s}\n", .{response}); + } +} + +fn searchDatasets(allocator: std.mem.Allocator, term: []const u8) !void { + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Authenticate with server + var user_context = try authenticateUser(allocator, config); + defer user_context.deinit(); + + // Connect to WebSocket and search datasets + const api_key_plain = config.api_key; + const api_key_hash = try crypto.hashString(allocator, api_key_plain); + defer allocator.free(api_key_hash); + + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, api_key_plain); + defer client.close(); + + try client.sendDatasetSearch(term, api_key_hash); + + // Receive response + const response = try client.receiveAndHandleDatasetResponse(allocator); + defer allocator.free(response); + + colors.printInfo("Search Results for '{s}':\n", .{term}); + colors.printInfo("========================\n\n", .{}); + + if (std.mem.eql(u8, response, "[]")) { + colors.printWarning("No datasets found matching '{s}'.\n", .{term}); + } 
else {
+        colors.printSuccess("{s}\n", .{response});
+    }
+}
+
+// Reuse authenticateUser from other commands
+// Validates the configured API key by opening (and immediately closing) a
+// WebSocket connection; maps transport errors to user-facing CLI errors.
+fn authenticateUser(allocator: std.mem.Allocator, config: Config) !UserContext {
+    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host});
+    defer allocator.free(ws_url);
+
+    // Try to connect with the API key to validate it
+    var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| {
+        switch (err) {
+            error.ConnectionRefused => return error.ConnectionFailed,
+            error.NetworkUnreachable => return error.ServerUnreachable,
+            error.InvalidURL => return error.ConfigInvalid,
+            else => return error.AuthenticationFailed,
+        }
+    };
+    defer client.close();
+
+    // For now, create a user context after successful authentication
+    const user_name = try allocator.dupe(u8, "authenticated_user");
+    return UserContext{
+        .name = user_name,
+        .admin = false,
+        .allocator = allocator,
+    };
+}
+
+// Caller-owned user identity; `name` is allocated with `allocator` and
+// released by `deinit`.
+const UserContext = struct {
+    name: []const u8,
+    admin: bool,
+    allocator: std.mem.Allocator,
+
+    pub fn deinit(self: *UserContext) void {
+        self.allocator.free(self.name);
+    }
+};
diff --git a/cli/src/commands/experiment.zig b/cli/src/commands/experiment.zig
new file mode 100644
index 0000000..140b969
--- /dev/null
+++ b/cli/src/commands/experiment.zig
@@ -0,0 +1,192 @@
+const std = @import("std");
+const config = @import("../config.zig");
+const ws = @import("../net/ws.zig");
+const protocol = @import("../net/protocol.zig");
+
+// Entry point for `ml experiment <command>`; dispatches to the `log` and
+// `show` subcommands.
+pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
+    if (args.len < 1) {
+        std.debug.print("Usage: ml experiment [args]\n", .{});
+        std.debug.print("Commands:\n", .{});
+        std.debug.print(" log Log a metric\n", .{});
+        std.debug.print(" show Show experiment details\n", .{});
+        return;
+    }
+
+    const command = args[0];
+
+    if (std.mem.eql(u8, command, "log")) {
+        try executeLog(allocator, args[1..]);
+    } else if (std.mem.eql(u8, command, "show")) {
+        try executeShow(allocator, args[1..]);
+    } else {
+        std.debug.print("Unknown command: {s}\n", .{command});
+    }
+}
+
+// `ml experiment log`: sends one (name, value, step) metric sample for a
+// commit over the WebSocket binary protocol.
+fn executeLog(allocator: std.mem.Allocator, args: []const []const u8) !void {
+    var commit_id: ?[]const u8 = null;
+    var name: ?[]const u8 = null;
+    var value: ?f64 = null;
+    var step: u32 = 0;
+
+    // Flag parsing: each value-taking flag consumes the following argument.
+    var i: usize = 0;
+    while (i < args.len) : (i += 1) {
+        const arg = args[i];
+        if (std.mem.eql(u8, arg, "--id")) {
+            if (i + 1 < args.len) {
+                commit_id = args[i + 1];
+                i += 1;
+            }
+        } else if (std.mem.eql(u8, arg, "--name")) {
+            if (i + 1 < args.len) {
+                name = args[i + 1];
+                i += 1;
+            }
+        } else if (std.mem.eql(u8, arg, "--value")) {
+            if (i + 1 < args.len) {
+                value = try std.fmt.parseFloat(f64, args[i + 1]);
+                i += 1;
+            }
+        } else if (std.mem.eql(u8, arg, "--step")) {
+            if (i + 1 < args.len) {
+                step = try std.fmt.parseInt(u32, args[i + 1], 10);
+                i += 1;
+            }
+        }
+    }
+
+    if (commit_id == null or name == null or value == null) {
+        std.debug.print("Usage: ml experiment log --id --name --value [--step ]\n", .{});
+        return;
+    }
+
+    const Config = @import("../config.zig").Config;
+    const crypto = @import("../utils/crypto.zig");
+
+    const cfg = try Config.load(allocator);
+    defer {
+        var mut_cfg = cfg;
+        mut_cfg.deinit(allocator);
+    }
+
+    // Plain key authenticates the socket; its hash is what the binary
+    // protocol carries (same split as the other commands).
+    const api_key_plain = cfg.api_key;
+    const api_key_hash = try crypto.hashString(allocator, api_key_plain);
+    defer allocator.free(api_key_hash);
+
+    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
+    defer allocator.free(ws_url);
+
+    var client = try ws.Client.connect(allocator, ws_url, api_key_plain);
+    defer client.close();
+
+    try client.sendLogMetric(api_key_hash, commit_id.?, name.?, value.?, step);
+    try client.receiveAndHandleResponse(allocator, "Log metric");
+}
+
+// `ml experiment show <commit-id>`: fetches and pretty-prints experiment
+// metadata and logged metrics. JSON fields are variant-checked before use so
+// a malformed or partially-filled server reply prints what it can instead of
+// panicking on an unexpected payload type.
+fn executeShow(allocator: std.mem.Allocator, args: []const []const u8) !void {
+    if (args.len < 1) {
+        std.debug.print("Usage: ml experiment show \n", .{});
+        return;
+    }
+
+    const commit_id = args[0];
+
+    const Config = @import("../config.zig").Config;
+    const crypto = @import("../utils/crypto.zig");
+
+    const cfg = try Config.load(allocator);
+    defer {
+        var mut_cfg = cfg;
+        mut_cfg.deinit(allocator);
+    }
+
+    const api_key_plain = cfg.api_key;
+    const api_key_hash = try crypto.hashString(allocator, api_key_plain);
+    defer allocator.free(api_key_hash);
+
+    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{cfg.worker_host});
+    defer allocator.free(ws_url);
+
+    var client = try ws.Client.connect(allocator, ws_url, api_key_plain);
+    defer client.close();
+
+    try client.sendGetExperiment(api_key_hash, commit_id);
+
+    const message = try client.receiveMessage(allocator);
+    defer allocator.free(message);
+
+    const packet = try protocol.ResponsePacket.deserialize(message, allocator);
+    defer {
+        // Clean up allocated strings from packet
+        if (packet.success_message) |msg| allocator.free(msg);
+        if (packet.error_message) |msg| allocator.free(msg);
+        if (packet.error_details) |details| allocator.free(details);
+        if (packet.data_type) |dtype| allocator.free(dtype);
+        if (packet.data_payload) |payload| allocator.free(payload);
+    }
+
+    switch (packet.packet_type) {
+        .success, .data => {
+            if (packet.data_payload) |payload| {
+                // Parse JSON response
+                const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch |err| {
+                    std.debug.print("Failed to parse response: {}\n", .{err});
+                    return;
+                };
+                defer parsed.deinit();
+
+                const root = parsed.value;
+                if (root != .object) {
+                    std.debug.print("Invalid response format\n", .{});
+                    return;
+                }
+
+                const metadata = root.object.get("metadata");
+                const metrics = root.object.get("metrics");
+
+                if (metadata != null and metadata.? == .object) {
+                    std.debug.print("\nExperiment Details:\n", .{});
+                    std.debug.print("-------------------\n", .{});
+                    const m = metadata.?.object;
+                    // Guard each field's variant: a missing or non-string
+                    // field is skipped rather than panicking on `.string`.
+                    if (m.get("JobName")) |v| {
+                        if (v == .string) std.debug.print("Job Name: {s}\n", .{v.string});
+                    }
+                    if (m.get("CommitID")) |v| {
+                        if (v == .string) std.debug.print("Commit ID: {s}\n", .{v.string});
+                    }
+                    if (m.get("User")) |v| {
+                        if (v == .string) std.debug.print("User: {s}\n", .{v.string});
+                    }
+                    if (m.get("Timestamp")) |v| {
+                        if (v == .integer) std.debug.print("Timestamp: {d}\n", .{v.integer});
+                    }
+                }
+
+                if (metrics != null and metrics.? == .array) {
+                    std.debug.print("\nMetrics:\n", .{});
+                    std.debug.print("-------------------\n", .{});
+                    const items = metrics.?.array.items;
+                    if (items.len == 0) {
+                        std.debug.print("No metrics logged.\n", .{});
+                    } else {
+                        for (items) |item| {
+                            if (item != .object) continue;
+                            const name_v = item.object.get("name") orelse continue;
+                            const value_v = item.object.get("value") orelse continue;
+                            const step_v = item.object.get("step") orelse continue;
+                            if (name_v != .string) continue;
+                            // JSON numbers decode as .integer when they have
+                            // no fractional part; accept both forms.
+                            const metric_value: f64 = switch (value_v) {
+                                .float => |f| f,
+                                .integer => |n| @as(f64, @floatFromInt(n)),
+                                else => continue,
+                            };
+                            const step_n: i64 = switch (step_v) {
+                                .integer => |n| n,
+                                else => 0,
+                            };
+                            std.debug.print("{s}: {d:.4} (Step: {d})\n", .{ name_v.string, metric_value, step_n });
+                        }
+                    }
+                }
+                std.debug.print("\n", .{});
+            } else if (packet.success_message) |msg| {
+                std.debug.print("{s}\n", .{msg});
+            }
+        },
+        .error_packet => {
+            if (packet.error_message) |msg| {
+                std.debug.print("Error: {s}\n", .{msg});
+            }
+        },
+        else => {
+            std.debug.print("Unexpected response type\n", .{});
+        },
+    }
+}
diff --git a/cli/src/commands/init.zig b/cli/src/commands/init.zig
new file mode 100644
index 0000000..8354e6d
--- /dev/null
+++ b/cli/src/commands/init.zig
@@ -0,0 +1,13 @@
+const std = @import("std");
+const Config = @import("../config.zig").Config;
+
+pub fn run(_: std.mem.Allocator, _: []const []const u8) !void {
+    std.debug.print("ML Experiment Manager - Configuration Setup\n\n", .{});
+    std.debug.print("Please create ~/.ml/config.toml with the following format:\n\n", .{});
+    std.debug.print("worker_host = \"worker.local\"\n", .{});
+    std.debug.print("worker_user = \"mluser\"\n", .{});
+    std.debug.print("worker_base = \"/data/ml-experiments\"\n",
.{}); + std.debug.print("worker_port = 22\n", .{}); + std.debug.print("api_key = \"your-api-key\"\n", .{}); + std.debug.print("\n[OK] Configuration template shown above\n", .{}); +} diff --git a/cli/src/commands/monitor.zig b/cli/src/commands/monitor.zig new file mode 100644 index 0000000..83cc27a --- /dev/null +++ b/cli/src/commands/monitor.zig @@ -0,0 +1,39 @@ +const std = @import("std"); +const Config = @import("../config.zig").Config; + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + _ = args; + + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + std.debug.print("Launching TUI via SSH...\n", .{}); + + // Build SSH command + const ssh_cmd = try std.fmt.allocPrint( + allocator, + "ssh -t -p {d} {s}@{s} 'cd {s} && ./bin/tui --config configs/config-local.yaml --api-key {s}'", + .{ config.worker_port, config.worker_user, config.worker_host, config.worker_base, config.api_key }, + ); + defer allocator.free(ssh_cmd); + + // Execute SSH command + var child = std.process.Child.init(&[_][]const u8{ "sh", "-c", ssh_cmd }, allocator); + child.stdin_behavior = .Inherit; + child.stdout_behavior = .Inherit; + child.stderr_behavior = .Inherit; + + const term = try child.spawnAndWait(); + + switch (term) { + .Exited => |code| { + if (code != 0) { + std.debug.print("TUI exited with code {d}\n", .{code}); + } + }, + else => {}, + } +} diff --git a/cli/src/commands/prune.zig b/cli/src/commands/prune.zig new file mode 100644 index 0000000..4597b7b --- /dev/null +++ b/cli/src/commands/prune.zig @@ -0,0 +1,93 @@ +const std = @import("std"); +const Config = @import("../config.zig").Config; +const ws = @import("../net/ws.zig"); +const crypto = @import("../utils/crypto.zig"); +const logging = @import("../utils/logging.zig"); + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + var keep_count: ?u32 = null; + var older_than_days: ?u32 = null; + + // Parse flags + 
var i: usize = 0; + while (i < args.len) : (i += 1) { + if (std.mem.eql(u8, args[i], "--keep") and i + 1 < args.len) { + keep_count = try std.fmt.parseInt(u32, args[i + 1], 10); + i += 1; + } else if (std.mem.eql(u8, args[i], "--older-than") and i + 1 < args.len) { + older_than_days = try std.fmt.parseInt(u32, args[i + 1], 10); + i += 1; + } + } + + if (keep_count == null and older_than_days == null) { + logging.info("Usage: ml prune --keep OR --older-than \n", .{}); + return error.InvalidArgs; + } + + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Add confirmation prompt + if (keep_count) |count| { + if (!logging.confirm("This will permanently delete all but the {d} most recent experiments. Continue?", .{count})) { + logging.info("Prune cancelled.\n", .{}); + return; + } + } else if (older_than_days) |days| { + if (!logging.confirm("This will permanently delete all experiments older than {d} days. Continue?", .{days})) { + logging.info("Prune cancelled.\n", .{}); + return; + } + } + + logging.info("Pruning experiments...\n", .{}); + + // Use plain password for WebSocket authentication, hash for binary protocol + const api_key_plain = config.api_key; // Plain password from config + const api_key_hash = try crypto.hashString(allocator, api_key_plain); + defer allocator.free(api_key_hash); + + // Connect to WebSocket and send prune message + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, api_key_plain); + defer client.close(); + + // Determine prune type and send message + var prune_type: u8 = undefined; + var value: u32 = undefined; + + if (keep_count) |count| { + prune_type = 0; // keep N + value = count; + logging.info("Keeping {d} most recent experiments\n", .{count}); + } + if (older_than_days) |days| { + prune_type = 1; // older than days + value = 
days; + logging.info("Removing experiments older than {d} days\n", .{days}); + } + + try client.sendPrune(api_key_hash, prune_type, value); + + // Receive response + const response = try client.receiveMessage(allocator); + defer allocator.free(response); + + // Parse prune response (simplified - assumes success/failure byte) + if (response.len > 0) { + if (response[0] == 0x00) { + logging.success("✓ Prune operation completed successfully\n", .{}); + } else { + logging.err("✗ Prune operation failed: error code {d}\n", .{response[0]}); + return error.PruneFailed; + } + } else { + logging.success("✓ Prune request sent (no response received)\n", .{}); + } +} diff --git a/cli/src/commands/queue.zig b/cli/src/commands/queue.zig new file mode 100644 index 0000000..0e1e0d2 --- /dev/null +++ b/cli/src/commands/queue.zig @@ -0,0 +1,118 @@ +const std = @import("std"); +const Config = @import("../config.zig").Config; +const ws = @import("../net/ws.zig"); +const crypto = @import("../utils/crypto.zig"); +const colors = @import("../utils/colors.zig"); + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + if (args.len == 0) { + colors.printError("Usage: ml queue [job2 job3...] 
[--commit ] [--priority N]\n", .{}); + return error.InvalidArgs; + } + + // Support batch operations - multiple job names + var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { + colors.printError("Failed to allocate job list: {}\n", .{err}); + return err; + }; + defer job_names.deinit(allocator); + + var commit_id: ?[]const u8 = null; + var priority: u8 = 5; + + // Parse arguments - separate job names from flags + var i: usize = 0; + while (i < args.len) : (i += 1) { + const arg = args[i]; + + if (std.mem.startsWith(u8, arg, "--")) { + // Parse flags + if (std.mem.eql(u8, arg, "--commit") and i + 1 < args.len) { + commit_id = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, arg, "--priority") and i + 1 < args.len) { + priority = try std.fmt.parseInt(u8, args[i + 1], 10); + i += 1; + } + } else { + // This is a job name + job_names.append(allocator, arg) catch |err| { + colors.printError("Failed to append job: {}\n", .{err}); + return err; + }; + } + } + + if (job_names.items.len == 0) { + colors.printError("No job names specified\n", .{}); + return error.InvalidArgs; + } + + colors.printInfo("Queueing {d} job(s)...\n", .{job_names.items.len}); + + // Process each job + var success_count: usize = 0; + var failed_jobs = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { + colors.printError("Failed to allocate failed jobs list: {}\n", .{err}); + return err; + }; + defer failed_jobs.deinit(allocator); + + for (job_names.items, 0..) 
|job_name, index| { + colors.printProgress("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name }); + + queueSingleJob(allocator, job_name, commit_id, priority) catch |err| { + colors.printError("Failed to queue job '{s}': {}\n", .{ job_name, err }); + failed_jobs.append(allocator, job_name) catch |append_err| { + colors.printError("Failed to track failed job: {}\n", .{append_err}); + }; + continue; + }; + + colors.printSuccess("Successfully queued job '{s}'\n", .{job_name}); + success_count += 1; + } + + // Show summary + colors.printInfo("Batch queuing complete.\n", .{}); + colors.printSuccess("Successfully queued: {d} job(s)\n", .{success_count}); + + if (failed_jobs.items.len > 0) { + colors.printError("Failed to queue: {d} job(s)\n", .{failed_jobs.items.len}); + for (failed_jobs.items) |failed_job| { + colors.printError(" - {s}\n", .{failed_job}); + } + } +} + +fn queueSingleJob(allocator: std.mem.Allocator, job_name: []const u8, commit_id: ?[]const u8, priority: u8) !void { + if (commit_id == null) { + colors.printError("Error: --commit is required\n", .{}); + return error.MissingCommit; + } + + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_id.? 
});
+
+    // Use plain password for WebSocket authentication, hash for binary protocol
+    const api_key_plain = config.api_key; // Plain password from config
+    const api_key_hash = try crypto.hashString(allocator, api_key_plain);
+    defer allocator.free(api_key_hash);
+
+    // Connect to WebSocket and send queue message
+    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host});
+    defer allocator.free(ws_url);
+
+    var client = try ws.Client.connect(allocator, ws_url, api_key_plain);
+    defer client.close();
+
+    try client.sendQueueJob(job_name, commit_id.?, priority, api_key_hash);
+
+    // Receive structured response
+    try client.receiveAndHandleResponse(allocator, "Job queue");
+}
diff --git a/cli/src/commands/status.zig b/cli/src/commands/status.zig
new file mode 100644
index 0000000..832443e
--- /dev/null
+++ b/cli/src/commands/status.zig
@@ -0,0 +1,95 @@
+const std = @import("std");
+const Config = @import("../config.zig").Config;
+const ws = @import("../net/ws.zig");
+const crypto = @import("../utils/crypto.zig");
+const errors = @import("../errors.zig");
+const logging = @import("../utils/logging.zig");
+
+// Caller-owned user identity; `name` is allocated and released by `deinit`.
+const UserContext = struct {
+    name: []const u8,
+    admin: bool,
+    allocator: std.mem.Allocator,
+
+    pub fn deinit(self: *UserContext) void {
+        self.allocator.free(self.name);
+    }
+};
+
+// Validates the API key by opening a throwaway WebSocket connection and maps
+// transport failures to user-facing CLI errors.
+fn authenticateUser(allocator: std.mem.Allocator, config: Config) !UserContext {
+    // Validate API key by making a simple API call to the server
+    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host});
+    defer allocator.free(ws_url);
+
+    // Try to connect with the API key to validate it
+    var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| {
+        switch (err) {
+            error.ConnectionRefused => return error.ConnectionFailed,
+            error.NetworkUnreachable => return error.ServerUnreachable,
+            error.InvalidURL => return error.ConfigInvalid,
+            else => return error.AuthenticationFailed,
+        }
+    };
+    defer client.close();
+
+    // For now, create a user context after successful authentication
+    // In a real implementation, this would get user info from the server
+    const user_name = try allocator.dupe(u8, "authenticated_user");
+    return UserContext{
+        .name = user_name,
+        .admin = false,
+        .allocator = allocator,
+    };
+}
+
+// `ml status`: authenticates, then requests and displays the user-filtered
+// server status over the WebSocket connection.
+pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
+    _ = args;
+
+    // Load configuration with proper error handling
+    const config = Config.load(allocator) catch |err| {
+        switch (err) {
+            error.FileNotFound => return error.ConfigNotFound,
+            else => return err,
+        }
+    };
+    defer {
+        var mut_config = config;
+        mut_config.deinit(allocator);
+    }
+
+    // Check if API key is configured
+    if (config.api_key.len == 0) {
+        return error.APIKeyMissing;
+    }
+
+    // Authenticate with server to get user context
+    var user_context = try authenticateUser(allocator, config);
+    defer user_context.deinit();
+
+    // Use plain password for WebSocket authentication, compute hash for binary protocol
+    const api_key_plain = config.api_key; // Plain password from config
+    const api_key_hash = try crypto.hashString(allocator, api_key_plain);
+    defer allocator.free(api_key_hash);
+
+    // Connect to WebSocket and request status. (`try` replaces the previous
+    // no-op `catch |err| return err` wrapper.)
+    const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host});
+    defer allocator.free(ws_url);
+
+    var client = ws.Client.connect(allocator, ws_url, api_key_plain) catch |err| {
+        switch (err) {
+            error.ConnectionRefused => return error.ConnectionFailed,
+            error.NetworkUnreachable => return error.ServerUnreachable,
+            error.InvalidURL => return error.ConfigInvalid,
+            else => return err,
+        }
+    };
+    defer client.close();
+
+    // Deliberately collapse all send failures into one user-facing error.
+    client.sendStatusRequest(api_key_hash) catch {
+        return error.RequestFailed;
+    };
+
+    // Receive and display user-filtered response
+    try client.receiveAndHandleStatusResponse(allocator, user_context);
+}
diff --git a/cli/src/commands/sync.zig
b/cli/src/commands/sync.zig new file mode 100644 index 0000000..0822457 --- /dev/null +++ b/cli/src/commands/sync.zig @@ -0,0 +1,160 @@ +const std = @import("std"); +const colors = @import("../utils/colors.zig"); +const Config = @import("../config.zig").Config; +const crypto = @import("../utils/crypto.zig"); +const rsync = @import("../utils/rsync_embedded.zig"); // Use embedded rsync +const storage = @import("../utils/storage.zig"); +const ws = @import("../net/ws.zig"); +const logging = @import("../utils/logging.zig"); + +pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { + if (args.len == 0) { + logging.err("Usage: ml sync [--name ] [--queue] [--priority N]\n", .{}); + return error.InvalidArgs; + } + + const path = args[0]; + var job_name: ?[]const u8 = null; + var should_queue = false; + var priority: u8 = 5; + + // Parse flags + var i: usize = 1; + while (i < args.len) : (i += 1) { + if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) { + job_name = args[i + 1]; + i += 1; + } else if (std.mem.eql(u8, args[i], "--queue")) { + should_queue = true; + } else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) { + priority = try std.fmt.parseInt(u8, args[i + 1], 10); + i += 1; + } + } + + const config = try Config.load(allocator); + defer { + var mut_config = config; + mut_config.deinit(allocator); + } + + // Calculate commit ID (SHA256 of directory tree) + const commit_id = try crypto.hashDirectory(allocator, path); + defer allocator.free(commit_id); + + // Content-addressed storage optimization + // try cas.deduplicateDirectory(path); + + // Use local file operations instead of rsync for testing + const local_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}/files/", + .{ config.worker_base, commit_id }, + ); + defer allocator.free(local_path); + + // Create directory and copy files locally + try std.fs.cwd().makePath(local_path); + + var src_dir = try std.fs.cwd().openDir(path, .{ .iterate = true }); + defer 
src_dir.close(); + + var dest_dir = try std.fs.cwd().openDir(local_path, .{ .iterate = true }); + defer dest_dir.close(); + + var walker = try src_dir.walk(allocator); + defer walker.deinit(); + + while (try walker.next()) |entry| { + std.debug.print("Processing entry: {s}\n", .{entry.path}); + if (entry.kind == .file) { + const rel_path = try allocator.dupe(u8, entry.path); + defer allocator.free(rel_path); + + std.debug.print("Copying file: {s}\n", .{rel_path}); + const src_file = try src_dir.openFile(rel_path, .{}); + defer src_file.close(); + + const dest_file = try dest_dir.createFile(rel_path, .{}); + defer dest_file.close(); + + const src_contents = try src_file.readToEndAlloc(allocator, 1024 * 1024); + defer allocator.free(src_contents); + + try dest_file.writeAll(src_contents); + colors.printSuccess("Successfully copied: {s}\n", .{rel_path}); + } + } + + std.debug.print("✓ Files synced successfully\n", .{}); + + // If queue flag is set, queue the job + if (should_queue) { + const queue_cmd = @import("queue.zig"); + const actual_job_name = job_name orelse commit_id[0..8]; + const queue_args = [_][]const u8{ actual_job_name, "--commit", commit_id, "--priority", try std.fmt.allocPrint(allocator, "{d}", .{priority}) }; + defer allocator.free(queue_args[queue_args.len - 1]); + try queue_cmd.run(allocator, &queue_args); + } + + // Optional: Connect to server for progress monitoring if --monitor flag is used + var monitor_progress = false; + for (args[1..]) |arg| { + if (std.mem.eql(u8, arg, "--monitor")) { + monitor_progress = true; + break; + } + } + + if (monitor_progress) { + std.debug.print("\nMonitoring sync progress...\n", .{}); + try monitorSyncProgress(allocator, &config, commit_id); + } +} + +fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, commit_id: []const u8) !void { + _ = commit_id; + // Use plain password for WebSocket authentication + const api_key_plain = config.api_key; + + // Connect to server with retry logic + 
const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + logging.info("Connecting to server {s}...\n", .{ws_url}); + var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3); + defer client.disconnect(); + + // Send progress monitoring request (this would be a new opcode on the server side) + // For now, we'll just listen for any progress messages + + var timeout_counter: u32 = 0; + const max_timeout = 30; // 30 seconds timeout + var spinner_index: usize = 0; + const spinner_chars = [_]u8{ '|', '/', '-', '\\' }; + + while (timeout_counter < max_timeout) { + const message = client.receiveMessage(allocator) catch |err| { + switch (err) { + error.ConnectionClosed, error.ConnectionTimedOut => { + timeout_counter += 1; + spinner_index = (spinner_index + 1) % 4; + logging.progress("Waiting for progress {c} (attempt {d}/{d})\n", .{ spinner_chars[spinner_index], timeout_counter, max_timeout }); + std.Thread.sleep(1 * std.time.ns_per_s); + continue; + }, + else => return err, + } + }; + defer allocator.free(message); + + // For now, just display a simple success message + // TODO: Implement proper JSON parsing and packet handling + logging.success("Sync progress message received\n", .{}); + break; + } + + if (timeout_counter >= max_timeout) { + std.debug.print("Progress monitoring timed out. 
Sync may still be running.\n", .{});
+    }
+}
diff --git a/cli/src/commands/watch.zig b/cli/src/commands/watch.zig
new file mode 100644
index 0000000..b28fa20
--- /dev/null
+++ b/cli/src/commands/watch.zig
@@ -0,0 +1,124 @@
+const std = @import("std");
+const Config = @import("../config.zig").Config;
+const crypto = @import("../utils/crypto.zig");
+const rsync = @import("../utils/rsync.zig");
+const ws = @import("../net/ws.zig");
+
+// Returns the newest mtime (nanoseconds, i128 per std.fs.File.Stat) of any
+// regular file under `dir`. 0 means the tree contains no files.
+fn newestMtime(dir: std.fs.Dir, allocator: std.mem.Allocator) !i128 {
+    var newest: i128 = 0;
+    var walker = try dir.walk(allocator);
+    defer walker.deinit();
+
+    while (try walker.next()) |entry| {
+        if (entry.kind != .file) continue;
+        const file = try dir.openFile(entry.path, .{});
+        defer file.close();
+        const st = try file.stat();
+        if (st.mtime > newest) newest = st.mtime;
+    }
+    return newest;
+}
+
+// `ml watch <path>`: syncs once, then polls the tree every 2s and re-syncs
+// (optionally re-queues) whenever a file's mtime advances past the baseline.
+pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
+    if (args.len == 0) {
+        std.debug.print("Usage: ml watch [--name ] [--priority N] [--queue]\n", .{});
+        return error.InvalidArgs;
+    }
+
+    const path = args[0];
+    var job_name: ?[]const u8 = null;
+    var priority: u8 = 5;
+    var should_queue = false;
+
+    // Parse flags
+    var i: usize = 1;
+    while (i < args.len) : (i += 1) {
+        if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
+            job_name = args[i + 1];
+            i += 1;
+        } else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
+            priority = try std.fmt.parseInt(u8, args[i + 1], 10);
+            i += 1;
+        } else if (std.mem.eql(u8, args[i], "--queue")) {
+            should_queue = true;
+        }
+    }
+
+    const config = try Config.load(allocator);
+    defer {
+        var mut_config = config;
+        mut_config.deinit(allocator);
+    }
+
+    std.debug.print("Watching {s} for changes...\n", .{path});
+    std.debug.print("Press Ctrl+C to stop\n", .{});
+
+    // Initial sync
+    var last_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
+    defer allocator.free(last_commit_id);
+
+    // Watch for changes
+    var watcher = try std.fs.cwd().openDir(path, .{ .iterate = true });
+    defer watcher.close();
+
+    // Prime the baseline from the current tree so the first poll does not
+    // treat every existing file as "modified" and re-sync an unchanged tree.
+    var last_modified: i128 = try newestMtime(watcher, allocator);
+
+    while (true) {
+        const newest = try newestMtime(watcher, allocator);
+
+        if (newest > last_modified) {
+            last_modified = newest;
+            std.debug.print("\nChanges detected, syncing...\n", .{});
+
+            const new_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
+            defer allocator.free(new_commit_id);
+
+            if (!std.mem.eql(u8, last_commit_id, new_commit_id)) {
+                allocator.free(last_commit_id);
+                last_commit_id = try allocator.dupe(u8, new_commit_id);
+                std.debug.print("✓ Synced new version: {s}\n", .{last_commit_id[0..8]});
+            }
+        }
+
+        // Wait before checking again
+        std.Thread.sleep(2 * std.time.ns_per_s); // 2 seconds
+    }
+}
+
+// Hashes the tree, rsyncs it to the worker, and optionally queues a job.
+// Returns the caller-owned commit ID.
+fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, config: Config) ![]u8 {
+    // Calculate commit ID
+    const commit_id = try crypto.hashDirectory(allocator, path);
+
+    // Sync files via rsync
+    const remote_path = try std.fmt.allocPrint(
+        allocator,
+        "{s}@{s}:{s}/{s}/files/",
+        .{ config.worker_user, config.worker_host, config.worker_base, commit_id },
+    );
+    defer allocator.free(remote_path);
+
+    try rsync.sync(allocator, path, remote_path, config.worker_port);
+
+    if (should_queue) {
+        const actual_job_name = job_name orelse commit_id[0..8];
+
+        // Plain key authenticates the socket; the binary protocol carries the
+        // hash — same split as queue.zig/prune.zig (previously the plain key
+        // was sent where the hash belongs).
+        const api_key_hash = try crypto.hashString(allocator, config.api_key);
+        defer allocator.free(api_key_hash);
+
+        // Connect to WebSocket and queue job
+        const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host});
+        defer allocator.free(ws_url);
+
+        var client = try ws.Client.connect(allocator, ws_url, config.api_key);
+        defer client.close();
+
+        try client.sendQueueJob(actual_job_name, commit_id, priority, api_key_hash);
+
+        const response = try client.receiveMessage(allocator);
+        defer allocator.free(response);
+
+        if (response.len > 0 and response[0] == 0x00) {
+            std.debug.print("✓ Job queued successfully: {s}\n", .{actual_job_name});
+        }
+    }
+
+    return commit_id;
+}
diff --git a/cli/src/config.zig
b/cli/src/config.zig
new file mode 100644
index 0000000..b999ca7
--- /dev/null
+++ b/cli/src/config.zig
@@ -0,0 +1,155 @@
+const std = @import("std");
+
+// CLI configuration loaded from ~/.ml/config.toml with ML_* environment
+// overrides. String fields are heap-allocated (or the "" literal when unset;
+// freeing a zero-length slice is a no-op, so deinit/overwrite stay safe).
+pub const Config = struct {
+    worker_host: []const u8,
+    worker_user: []const u8,
+    worker_base: []const u8,
+    worker_port: u16,
+    api_key: []const u8,
+
+    pub fn validate(self: Config) !void {
+        // Validate host
+        if (self.worker_host.len == 0) {
+            return error.EmptyHost;
+        }
+
+        // Validate port (u16 cannot exceed 65535, so only 0 is invalid)
+        if (self.worker_port == 0) {
+            return error.InvalidPort;
+        }
+
+        // Validate API key format (should be hex string)
+        if (self.api_key.len == 0) {
+            return error.EmptyAPIKey;
+        }
+
+        // Check if API key is valid hex
+        for (self.api_key) |char| {
+            if (!((char >= '0' and char <= '9') or
+                (char >= 'a' and char <= 'f') or
+                (char >= 'A' and char <= 'F')))
+            {
+                return error.InvalidAPIKeyFormat;
+            }
+        }
+
+        // Validate base path
+        if (self.worker_base.len == 0) {
+            return error.EmptyBasePath;
+        }
+    }
+
+    // Loads ~/.ml/config.toml, applies ML_HOST/ML_USER/ML_BASE/ML_PORT/
+    // ML_API_KEY overrides, validates, and returns an owned Config.
+    pub fn load(allocator: std.mem.Allocator) !Config {
+        const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;
+        const config_path = try std.fmt.allocPrint(allocator, "{s}/.ml/config.toml", .{home});
+        defer allocator.free(config_path);
+
+        const file = std.fs.openFileAbsolute(config_path, .{}) catch |err| {
+            if (err == error.FileNotFound) {
+                std.debug.print("Config file not found. Run 'ml init' first.\n", .{});
+                return error.ConfigNotFound;
+            }
+            return err;
+        };
+        defer file.close();
+
+        // Load config with environment variable overrides
+        var config = try loadFromFile(allocator, file);
+
+        // Apply environment variable overrides; free the file-loaded value
+        // before replacing it so overrides don't leak the old allocation.
+        if (std.posix.getenv("ML_HOST")) |host| {
+            allocator.free(config.worker_host);
+            config.worker_host = try allocator.dupe(u8, host);
+        }
+        if (std.posix.getenv("ML_USER")) |user| {
+            allocator.free(config.worker_user);
+            config.worker_user = try allocator.dupe(u8, user);
+        }
+        if (std.posix.getenv("ML_BASE")) |base| {
+            allocator.free(config.worker_base);
+            config.worker_base = try allocator.dupe(u8, base);
+        }
+        if (std.posix.getenv("ML_PORT")) |port_str| {
+            config.worker_port = try std.fmt.parseInt(u16, port_str, 10);
+        }
+        if (std.posix.getenv("ML_API_KEY")) |api_key| {
+            allocator.free(config.api_key);
+            config.api_key = try allocator.dupe(u8, api_key);
+        }
+
+        try config.validate();
+        return config;
+    }
+
+    // Minimal TOML-ish parser: key = "value" pairs, '#' comments, one pair
+    // per line. Splits on the FIRST '=' so values containing '=' survive.
+    fn loadFromFile(allocator: std.mem.Allocator, file: std.fs.File) !Config {
+        const content = try file.readToEndAlloc(allocator, 1024 * 1024);
+        defer allocator.free(content);
+
+        var config = Config{
+            .worker_host = "",
+            .worker_user = "",
+            .worker_base = "",
+            .worker_port = 22,
+            .api_key = "",
+        };
+
+        var lines = std.mem.splitScalar(u8, content, '\n');
+        while (lines.next()) |line| {
+            const trimmed = std.mem.trim(u8, line, " \t\r");
+            if (trimmed.len == 0 or trimmed[0] == '#') continue;
+
+            const eq = std.mem.indexOfScalar(u8, trimmed, '=') orelse continue;
+            const key = std.mem.trim(u8, trimmed[0..eq], " \t");
+            const value_raw = std.mem.trim(u8, trimmed[eq + 1 ..], " \t");
+
+            // Remove quotes
+            const value = if (value_raw.len >= 2 and value_raw[0] == '"' and value_raw[value_raw.len - 1] == '"')
+                value_raw[1 .. value_raw.len - 1]
+            else
+                value_raw;
+
+            // Free any earlier value for the same key before overwriting
+            // (freeing the initial "" literal is a zero-length no-op).
+            if (std.mem.eql(u8, key, "worker_host")) {
+                allocator.free(config.worker_host);
+                config.worker_host = try allocator.dupe(u8, value);
+            } else if (std.mem.eql(u8, key, "worker_user")) {
+                allocator.free(config.worker_user);
+                config.worker_user = try allocator.dupe(u8, value);
+            } else if (std.mem.eql(u8, key, "worker_base")) {
+                allocator.free(config.worker_base);
+                config.worker_base = try allocator.dupe(u8, value);
+            } else if (std.mem.eql(u8, key, "worker_port")) {
+                config.worker_port = try std.fmt.parseInt(u16, value, 10);
+            } else if (std.mem.eql(u8, key, "api_key")) {
+                allocator.free(config.api_key);
+                config.api_key = try allocator.dupe(u8, value);
+            }
+        }
+
+        return config;
+    }
+
+    // Writes the config back to ~/.ml/config.toml, creating ~/.ml if needed.
+    pub fn save(self: Config, allocator: std.mem.Allocator) !void {
+        const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;
+
+        // Create .ml directory
+        const ml_dir = try std.fmt.allocPrint(allocator, "{s}/.ml", .{home});
+        defer allocator.free(ml_dir);
+
+        std.fs.makeDirAbsolute(ml_dir) catch |err| {
+            if (err != error.PathAlreadyExists) return err;
+        };
+
+        const config_path = try std.fmt.allocPrint(allocator, "{s}/config.toml", .{ml_dir});
+        defer allocator.free(config_path);
+
+        const file = try std.fs.createFileAbsolute(config_path, .{});
+        defer file.close();
+
+        const writer = file.writer();
+        try writer.print("worker_host = \"{s}\"\n", .{self.worker_host});
+        try writer.print("worker_user = \"{s}\"\n", .{self.worker_user});
+        try writer.print("worker_base = \"{s}\"\n", .{self.worker_base});
+        try writer.print("worker_port = {d}\n", .{self.worker_port});
+        try writer.print("api_key = \"{s}\"\n", .{self.api_key});
+    }
+
+    pub fn deinit(self: *Config, allocator: std.mem.Allocator) void {
+        allocator.free(self.worker_host);
+        allocator.free(self.worker_user);
+        allocator.free(self.worker_base);
+        allocator.free(self.api_key);
+    }
+};
diff --git a/cli/src/errors.zig b/cli/src/errors.zig
new file mode 100644
index 0000000..6db939f
--- /dev/null
+++ b/cli/src/errors.zig
@@ -0,0 +1,206 @@
+const std = @import("std");
+const colors = @import("utils/colors.zig");
+const ws = @import("net/ws.zig"); +const crypto = @import("utils/crypto.zig"); + +/// User-friendly error types for CLI +pub const CLIError = error{ + // Configuration errors + ConfigNotFound, + ConfigInvalid, + APIKeyMissing, + APIKeyInvalid, + + // Network errors + ConnectionFailed, + ServerUnreachable, + AuthenticationFailed, + RequestTimeout, + + // Command errors + InvalidArguments, + MissingCommit, + CommandFailed, + JobNotFound, + PermissionDenied, + ResourceExists, + JobAlreadyRunning, + JobCancelled, + ServerError, + SyncFailed, + + // System errors + OutOfMemory, + FileSystemError, + ProcessError, +}; + +/// Error message mapping +pub const ErrorMessages = struct { + pub fn getMessage(err: anyerror) []const u8 { + return switch (err) { + // Configuration errors + error.ConfigNotFound => "Configuration file not found. Run 'ml init' to create one.", + error.ConfigInvalid => "Configuration file is invalid. Please check your settings.", + error.APIKeyMissing => "API key not configured. Set it in your config file.", + error.APIKeyInvalid => "API key is invalid. Please check your credentials.", + error.InvalidAPIKeyFormat => "API key format is invalid. Expected 64-character hash.", + + // Network errors + error.ConnectionFailed => "Failed to connect to server. Check if the server is running.", + error.ServerUnreachable => "Server is unreachable. Verify network connectivity.", + error.AuthenticationFailed => "Authentication failed. Check your API key.", + error.RequestTimeout => "Request timed out. Server may be busy.", + + // Command errors + error.InvalidArguments => "Invalid command arguments. Use --help for usage.", + error.MissingCommit => "Missing commit ID. Use --commit to specify the commit.", + error.CommandFailed => "Command failed. Check server logs for details.", + error.JobNotFound => "Job not found. Verify the job name.", + error.PermissionDenied => "Permission denied. 
Check your user permissions.", + error.ResourceExists => "Resource already exists.", + error.JobAlreadyRunning => "Job is already running.", + error.JobCancelled => "Job was cancelled.", + error.PruneFailed => "Prune operation failed. Check server logs for details.", + error.ServerError => "Server error occurred. Check server logs for details.", + error.SyncFailed => "Sync operation failed.", + + // System errors + error.OutOfMemory => "Out of memory. Close other applications and try again.", + error.FileSystemError => "File system error. Check disk space and permissions.", + error.ProcessError => "Process error. Try again or contact support.", + + // WebSocket specific errors + error.InvalidURL => "Invalid server URL. Check your configuration.", + error.TLSNotSupported => "TLS (HTTPS) not supported in this build.", + error.ConnectionRefused => "Connection refused. Server may not be running.", + error.NetworkUnreachable => "Network unreachable. Check your internet connection.", + error.InvalidFrame => "Invalid WebSocket frame. Protocol error.", + error.EndpointNotFound => "WebSocket endpoint not found. Server may not be running or is misconfigured.", + error.ServerUnavailable => "Server is temporarily unavailable.", + error.HandshakeFailed => "WebSocket handshake failed.", + + // Default fallback + else => "An unexpected error occurred. 
Please try again or contact support.", + }; + } + + /// Check if error is user-fixable + pub fn isUserFixable(err: anyerror) bool { + return switch (err) { + error.ConfigNotFound, + error.ConfigInvalid, + error.APIKeyMissing, + error.APIKeyInvalid, + error.InvalidArguments, + error.MissingCommit, + error.JobNotFound, + error.ResourceExists, + error.JobAlreadyRunning, + error.JobCancelled, + error.SyncFailed, + => true, + + error.ConnectionFailed, + error.ServerUnreachable, + error.AuthenticationFailed, + error.RequestTimeout, + error.ConnectionRefused, + error.NetworkUnreachable, + => true, + + else => false, + }; + } + + /// Get suggestion for fixing the error + pub fn getSuggestion(err: anyerror) ?[]const u8 { + return switch (err) { + error.ConfigNotFound => "Run 'ml init' to create a configuration file.", + error.APIKeyMissing => "Add your API key to the configuration file.", + error.ConnectionFailed => "Start the API server with 'api-server' or check if it's running.", + error.AuthenticationFailed => "Verify your API key in the configuration.", + error.InvalidArguments => "Use 'ml --help' for correct usage.", + error.MissingCommit => "Use --commit to specify the commit ID for your job.", + error.JobNotFound => "List available jobs with 'ml status'.", + error.ResourceExists => "Use a different name or remove the existing resource.", + error.JobAlreadyRunning => "Wait for the current job to finish or cancel it first.", + error.JobCancelled => "The job was cancelled. 
You can restart it if needed.", + error.SyncFailed => "Check network connectivity and server status.", + else => null, + }; + } +}; + +/// Error handler for CLI commands +pub const ErrorHandler = struct { + /// Send crash report to server for non-user-fixable errors + fn sendCrashReport() void { + return; + } + + pub fn display(self: ErrorHandler, err: anyerror, context: ?[]const u8) void { + _ = self; // Self not used in current implementation + const message = ErrorMessages.getMessage(err); + const suggestion = ErrorMessages.getSuggestion(err); + + colors.printError("Error: {s}\n", .{message}); + + if (context) |ctx| { + colors.printWarning("Context: {s}\n", .{ctx}); + } + + if (suggestion) |sug| { + colors.printInfo("Suggestion: {s}\n", .{sug}); + } + + if (ErrorMessages.isUserFixable(err)) { + colors.printInfo("This error can be fixed by updating your configuration.\n", .{}); + } + } + + pub fn handleCommandError(err: anyerror, command: []const u8) void { + _ = command; // TODO: Use command in crash report + // Send crash report for non-user-fixable errors + sendCrashReport(); + + const message = ErrorMessages.getMessage(err); + const suggestion = ErrorMessages.getSuggestion(err); + const is_fixable = ErrorMessages.isUserFixable(err); + + colors.printError("Error: {s}\n", .{message}); + + if (suggestion) |sug| { + colors.printInfo("Suggestion: {s}\n", .{sug}); + } + + if (is_fixable) { + colors.printInfo("This is a user-fixable issue.\n", .{}); + } else { + colors.printWarning("If this persists, check server logs or contact support.\n", .{}); + } + + // Exit with appropriate code + std.process.exit(if (is_fixable) 2 else 1); + } + + pub fn handleNetworkError(err: anyerror, operation: []const u8) void { + std.debug.print("Network error during {s}: {s}\n", .{ operation, ErrorMessages.getMessage(err) }); + + if (ErrorMessages.getSuggestion(err)) |sug| { + std.debug.print("Try: {s}\n", .{sug}); + } + + std.process.exit(3); + } + + pub fn handleConfigError(err: 
anyerror) void { + std.debug.print("Configuration error: {s}\n", .{ErrorMessages.getMessage(err)}); + + if (ErrorMessages.getSuggestion(err)) |sug| { + std.debug.print("Fix: {s}\n", .{sug}); + } + + std.process.exit(4); + } +}; diff --git a/cli/src/main.zig b/cli/src/main.zig new file mode 100644 index 0000000..b8911b2 --- /dev/null +++ b/cli/src/main.zig @@ -0,0 +1,125 @@ +const std = @import("std"); +const errors = @import("errors.zig"); +const colors = @import("utils/colors.zig"); + +// Global verbosity level +var verbose_mode: bool = false; +var quiet_mode: bool = false; + +const commands = struct { + const init = @import("commands/init.zig"); + const sync = @import("commands/sync.zig"); + const queue = @import("commands/queue.zig"); + const status = @import("commands/status.zig"); + const monitor = @import("commands/monitor.zig"); + const cancel = @import("commands/cancel.zig"); + const prune = @import("commands/prune.zig"); + const watch = @import("commands/watch.zig"); + const dataset = @import("commands/dataset.zig"); + const experiment = @import("commands/experiment.zig"); +}; + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + const allocator = gpa.allocator(); + + // Parse command line arguments + var args_iter = std.process.args(); + _ = args_iter.next(); // Skip executable name + + var command_args = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { + colors.printError("Failed to allocate command args: {}\n", .{err}); + return err; + }; + defer command_args.deinit(allocator); + + // Parse global flags first + while (args_iter.next()) |arg| { + if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { + printUsage(); + return; + } else if (std.mem.eql(u8, arg, "--verbose") or std.mem.eql(u8, arg, "-v")) { + verbose_mode = true; + quiet_mode = false; + } else if (std.mem.eql(u8, arg, "--quiet") or std.mem.eql(u8, arg, "-q")) { + quiet_mode = true; + verbose_mode = false; + 
} else { + command_args.append(allocator, arg) catch |err| { + colors.printError("Failed to append argument: {}\n", .{err}); + return err; + }; + } + } + + const args = command_args.items; + + if (args.len < 1) { + colors.printError("No command specified.\n", .{}); + printUsage(); + return; + } + + const command = args[0]; + + // Handle commands with proper error handling + if (std.mem.eql(u8, command, "--help") or std.mem.eql(u8, command, "help")) { + printUsage(); + return; + } + + const command_result = if (std.mem.eql(u8, command, "init")) + commands.init.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "sync")) + commands.sync.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "queue")) + commands.queue.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "status")) + commands.status.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "monitor")) + commands.monitor.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "cancel")) + commands.cancel.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "prune")) + commands.prune.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "watch")) + commands.watch.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "dataset")) + commands.dataset.run(allocator, args[1..]) + else if (std.mem.eql(u8, command, "experiment")) + commands.experiment.execute(allocator, args[1..]) + else + error.InvalidArguments; + + // Handle any errors that occur during command execution + command_result catch |err| { + errors.ErrorHandler.handleCommandError(err, command); + }; +} + +pub fn printUsage() void { + colors.printInfo("ML Experiment Manager\n\n", .{}); + std.debug.print("Usage: ml [options]\n\n", .{}); + std.debug.print("Commands:\n", .{}); + std.debug.print(" init Setup configuration interactively\n", .{}); + std.debug.print(" sync Sync project to server\n", .{}); + std.debug.print(" queue Queue job for execution\n", .{}); + std.debug.print(" status Get 
system status\n", .{}); + std.debug.print(" monitor Launch TUI via SSH\n", .{}); + std.debug.print(" cancel Cancel running job\n", .{}); + std.debug.print(" prune --keep N Keep N most recent experiments\n", .{}); + std.debug.print(" prune --older-than D Remove experiments older than D days\n", .{}); + std.debug.print(" watch Watch directory for auto-sync\n", .{}); + std.debug.print(" dataset Manage datasets (list, upload, download, delete)\n", .{}); + std.debug.print(" experiment Manage experiments (log, show)\n\n", .{}); + std.debug.print("Options:\n", .{}); + std.debug.print(" --help Show this help message\n", .{}); + std.debug.print(" --verbose Enable verbose output\n", .{}); + std.debug.print(" --quiet Suppress non-error output\n", .{}); + std.debug.print(" --monitor Monitor progress of long-running operations\n", .{}); +} + +test "basic test" { + try std.testing.expectEqual(@as(i32, 10), 10); +} diff --git a/cli/src/net/protocol.zig b/cli/src/net/protocol.zig new file mode 100644 index 0000000..bd75afb --- /dev/null +++ b/cli/src/net/protocol.zig @@ -0,0 +1,336 @@ +const std = @import("std"); + +/// Response packet types for structured server responses +pub const PacketType = enum(u8) { + success = 0x00, + error_packet = 0x01, + progress = 0x02, + status = 0x03, + data = 0x04, + log = 0x05, +}; + +/// Error codes for structured error responses +pub const ErrorCode = enum(u8) { + // General errors (0x00-0x0F) + unknown_error = 0x00, + invalid_request = 0x01, + authentication_failed = 0x02, + permission_denied = 0x03, + resource_not_found = 0x04, + resource_already_exists = 0x05, + + // Server errors (0x10-0x1F) + server_overloaded = 0x10, + database_error = 0x11, + network_error = 0x12, + storage_error = 0x13, + timeout = 0x14, + + // Job errors (0x20-0x2F) + job_not_found = 0x20, + job_already_running = 0x21, + job_failed_to_start = 0x22, + job_execution_failed = 0x23, + job_cancelled = 0x24, + + // System errors (0x30-0x3F) + out_of_memory = 0x30, + disk_full 
= 0x31, + invalid_configuration = 0x32, + service_unavailable = 0x33, +}; + +/// Progress update types +pub const ProgressType = enum(u8) { + percentage = 0x00, + stage = 0x01, + message = 0x02, + bytes_transferred = 0x03, +}; + +/// Base response packet structure +pub const ResponsePacket = struct { + packet_type: PacketType, + timestamp: u64, // Unix timestamp + + // Success packet fields + success_message: ?[]const u8 = null, + + // Error packet fields + error_code: ?ErrorCode = null, + error_message: ?[]const u8 = null, + error_details: ?[]const u8 = null, + + // Progress packet fields + progress_type: ?ProgressType = null, + progress_value: ?u32 = null, + progress_total: ?u32 = null, + progress_message: ?[]const u8 = null, + + // Status packet fields + status_data: ?[]const u8 = null, + + // Data packet fields + data_type: ?[]const u8 = null, + data_payload: ?[]const u8 = null, + + // Log packet fields + log_level: ?u8 = null, + log_message: ?[]const u8 = null, + + pub fn initSuccess(timestamp: u64, message: ?[]const u8) ResponsePacket { + return ResponsePacket{ + .packet_type = .success, + .timestamp = timestamp, + .success_message = message, + }; + } + + pub fn initError(timestamp: u64, code: ErrorCode, message: []const u8, details: ?[]const u8) ResponsePacket { + return ResponsePacket{ + .packet_type = .error_packet, + .timestamp = timestamp, + .error_code = code, + .error_message = message, + .error_details = details, + }; + } + + pub fn initProgress(timestamp: u64, ptype: ProgressType, value: u32, total: ?u32, message: ?[]const u8) ResponsePacket { + return ResponsePacket{ + .packet_type = .progress, + .timestamp = timestamp, + .progress_type = ptype, + .progress_value = value, + .progress_total = total, + .progress_message = message, + }; + } + + pub fn initStatus(timestamp: u64, data: []const u8) ResponsePacket { + return ResponsePacket{ + .packet_type = .status, + .timestamp = timestamp, + .status_data = data, + }; + } + + pub fn initData(timestamp: 
u64, data_type: []const u8, payload: []const u8) ResponsePacket { + return ResponsePacket{ + .packet_type = .data, + .timestamp = timestamp, + .data_type = data_type, + .data_payload = payload, + }; + } + + pub fn initLog(timestamp: u64, level: u8, message: []const u8) ResponsePacket { + return ResponsePacket{ + .packet_type = .log, + .timestamp = timestamp, + .log_level = level, + .log_message = message, + }; + } + + /// Serialize packet to binary format + pub fn serialize(self: ResponsePacket, allocator: std.mem.Allocator) ![]u8 { + var buffer = try std.ArrayList(u8).initCapacity(allocator, 256); + defer buffer.deinit(allocator); + + try buffer.append(allocator, @intFromEnum(self.packet_type)); + try buffer.appendSlice(allocator, &std.mem.toBytes(self.timestamp)); + + switch (self.packet_type) { + .success => { + if (self.success_message) |msg| { + try writeString(&buffer, allocator, msg); + } else { + try writeString(&buffer, allocator, ""); + } + }, + .error_packet => { + try buffer.append(allocator, @intFromEnum(self.error_code.?)); + try writeString(&buffer, allocator, self.error_message.?); + if (self.error_details) |details| { + try writeString(&buffer, allocator, details); + } else { + try writeString(&buffer, allocator, ""); + } + }, + .progress => { + try buffer.append(allocator, @intFromEnum(self.progress_type.?)); + try buffer.appendSlice(allocator, &std.mem.toBytes(self.progress_value.?)); + if (self.progress_total) |total| { + try buffer.appendSlice(allocator, &std.mem.toBytes(total)); + } else { + try buffer.appendSlice(allocator, &[4]u8{ 0, 0, 0, 0 }); // 0 indicates no total + } + if (self.progress_message) |msg| { + try writeString(&buffer, allocator, msg); + } else { + try writeString(&buffer, allocator, ""); + } + }, + .status => { + try writeString(&buffer, allocator, self.status_data.?); + }, + .data => { + try writeString(&buffer, allocator, self.data_type.?); + try writeBytes(&buffer, allocator, self.data_payload.?); + }, + .log => { + try 
buffer.append(allocator, self.log_level.?); + try writeString(&buffer, allocator, self.log_message.?); + }, + } + + return buffer.toOwnedSlice(allocator); + } + + /// Deserialize packet from binary format + pub fn deserialize(data: []const u8, allocator: std.mem.Allocator) !ResponsePacket { + if (data.len < 9) return error.InvalidPacket; // packet_type + timestamp + + var offset: usize = 0; + const packet_type = @as(PacketType, @enumFromInt(data[offset])); + offset += 1; + + const timestamp = std.mem.readInt(u64, data[offset .. offset + 8][0..8], .big); + offset += 8; + + switch (packet_type) { + .success => { + const message = try readString(data, &offset, allocator); + return ResponsePacket.initSuccess(timestamp, message); + }, + .error_packet => { + if (offset >= data.len) return error.InvalidPacket; + const error_code = @as(ErrorCode, @enumFromInt(data[offset])); + offset += 1; + + const error_message = try readString(data, &offset, allocator); + const error_details = try readString(data, &offset, allocator); + + return ResponsePacket.initError(timestamp, error_code, error_message, if (error_details.len > 0) error_details else null); + }, + .progress => { + if (offset + 1 + 4 + 4 > data.len) return error.InvalidPacket; + const progress_type = @as(ProgressType, @enumFromInt(data[offset])); + offset += 1; + + const progress_value = std.mem.readInt(u32, data[offset .. offset + 4][0..4], .big); + offset += 4; + + const progress_total = std.mem.readInt(u32, data[offset .. 
offset + 4][0..4], .big); + offset += 4; + + const progress_message = try readString(data, &offset, allocator); + + return ResponsePacket.initProgress(timestamp, progress_type, progress_value, if (progress_total > 0) progress_total else null, if (progress_message.len > 0) progress_message else null); + }, + .status => { + const status_data = try readString(data, &offset, allocator); + return ResponsePacket.initStatus(timestamp, status_data); + }, + .data => { + const data_type = try readString(data, &offset, allocator); + const data_payload = try readBytes(data, &offset, allocator); + return ResponsePacket.initData(timestamp, data_type, data_payload); + }, + .log => { + if (offset >= data.len) return error.InvalidPacket; + const log_level = data[offset]; + offset += 1; + + const log_message = try readString(data, &offset, allocator); + return ResponsePacket.initLog(timestamp, log_level, log_message); + }, + } + } + + /// Get human-readable error message for error code + pub fn getErrorMessage(code: ErrorCode) []const u8 { + return switch (code) { + .unknown_error => "Unknown error occurred", + .invalid_request => "Invalid request format", + .authentication_failed => "Authentication failed", + .permission_denied => "Permission denied", + .resource_not_found => "Resource not found", + .resource_already_exists => "Resource already exists", + + .server_overloaded => "Server is overloaded", + .database_error => "Database error occurred", + .network_error => "Network error occurred", + .storage_error => "Storage error occurred", + .timeout => "Operation timed out", + + .job_not_found => "Job not found", + .job_already_running => "Job is already running", + .job_failed_to_start => "Job failed to start", + .job_execution_failed => "Job execution failed", + .job_cancelled => "Job was cancelled", + + .out_of_memory => "Server out of memory", + .disk_full => "Server disk full", + .invalid_configuration => "Invalid server configuration", + .service_unavailable => "Service 
temporarily unavailable", + }; + } + + /// Get log level name + pub fn getLogLevelName(level: u8) []const u8 { + return switch (level) { + 0 => "DEBUG", + 1 => "INFO", + 2 => "WARN", + 3 => "ERROR", + else => "UNKNOWN", + }; + } +}; + +/// Helper function to write string with length prefix +fn writeString(buffer: *std.ArrayList(u8), allocator: std.mem.Allocator, str: []const u8) !void { + try buffer.appendSlice(allocator, &std.mem.toBytes(@as(u16, @intCast(str.len)))); + try buffer.appendSlice(allocator, str); +} + +/// Helper function to write bytes with length prefix +fn writeBytes(buffer: *std.ArrayList(u8), allocator: std.mem.Allocator, bytes: []const u8) !void { + try buffer.appendSlice(allocator, &std.mem.toBytes(@as(u32, @intCast(bytes.len)))); + try buffer.appendSlice(allocator, bytes); +} + +/// Helper function to read string with length prefix +fn readString(data: []const u8, offset: *usize, allocator: std.mem.Allocator) ![]const u8 { + if (offset.* + 2 > data.len) return error.InvalidPacket; + + const len = std.mem.readInt(u16, data[offset.* .. offset.* + 2][0..2], .big); + offset.* += 2; + + if (offset.* + len > data.len) return error.InvalidPacket; + + const str = try allocator.alloc(u8, len); + @memcpy(str, data[offset.* .. offset.* + len]); + offset.* += len; + + return str; +} + +/// Helper function to read bytes with length prefix +fn readBytes(data: []const u8, offset: *usize, allocator: std.mem.Allocator) ![]const u8 { + if (offset.* + 4 > data.len) return error.InvalidPacket; + + const len = std.mem.readInt(u32, data[offset.* .. offset.* + 4][0..4], .big); + offset.* += 4; + + if (offset.* + len > data.len) return error.InvalidPacket; + + const bytes = try allocator.alloc(u8, len); + @memcpy(bytes, data[offset.* .. 
offset.* + len]); + offset.* += len; + + return bytes; +} diff --git a/cli/src/net/ws.zig b/cli/src/net/ws.zig new file mode 100644 index 0000000..ba03503 --- /dev/null +++ b/cli/src/net/ws.zig @@ -0,0 +1,835 @@ +const std = @import("std"); +const crypto = @import("../utils/crypto.zig"); +const protocol = @import("protocol.zig"); +const log = @import("../utils/logging.zig"); + +/// Binary WebSocket protocol opcodes +pub const Opcode = enum(u8) { + queue_job = 0x01, + status_request = 0x02, + cancel_job = 0x03, + prune = 0x04, + crash_report = 0x05, + log_metric = 0x0A, + get_experiment = 0x0B, + + // Dataset management opcodes + dataset_list = 0x06, + dataset_register = 0x07, + dataset_info = 0x08, + dataset_search = 0x09, + + // Structured response opcodes + response_success = 0x10, + response_error = 0x11, + response_progress = 0x12, + response_status = 0x13, + response_data = 0x14, + response_log = 0x15, +}; + +/// WebSocket client for binary protocol communication +pub const Client = struct { + allocator: std.mem.Allocator, + stream: ?std.net.Stream, + host: []const u8, + port: u16, + is_tls: bool = false, + + pub fn connect(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8) !Client { + // Detect TLS + const is_tls = std.mem.startsWith(u8, url, "wss://"); + + // Parse URL (simplified - assumes ws://host:port/path or wss://host:port/path) + const host_start = std.mem.indexOf(u8, url, "//") orelse return error.InvalidURL; + const host_port_start = host_start + 2; + const path_start = std.mem.indexOfPos(u8, url, host_port_start, "/") orelse url.len; + const colon_pos = std.mem.indexOfPos(u8, url, host_port_start, ":"); + + const host_end = blk: { + if (colon_pos) |pos| { + if (pos < path_start) break :blk pos; + } + break :blk path_start; + }; + const host = url[host_port_start..host_end]; + + var port: u16 = if (is_tls) 9101 else 9100; // default ports + if (colon_pos) |pos| { + if (pos < path_start) { + const port_start = pos + 1; + const 
port_end = std.mem.indexOfPos(u8, url, port_start, "/") orelse url.len; + port = try std.fmt.parseInt(u16, url[port_start..port_end], 10); + } + } + + // Connect to server + const address = try resolveHostAddress(allocator, host, port); + const stream = try std.net.tcpConnectToAddress(address); + + // For TLS, we'd need to wrap the stream with TLS + // For now, we'll just support ws:// and document wss:// requires additional setup + if (is_tls) { + std.log.warn("TLS (wss://) support requires additional TLS library integration", .{}); + return error.TLSNotSupported; + } + + // Perform WebSocket handshake + try handshake(allocator, stream, host, url, api_key); + + return Client{ + .allocator = allocator, + .stream = stream, + .host = try allocator.dupe(u8, host), + .port = port, + .is_tls = is_tls, + }; + } + + /// Connect to WebSocket server with retry logic + pub fn connectWithRetry(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8, max_retries: u32) !Client { + var retry_count: u32 = 0; + var last_error: anyerror = error.ConnectionFailed; + + while (retry_count < max_retries) { + const client = connect(allocator, url, api_key) catch |err| { + last_error = err; + retry_count += 1; + + if (retry_count < max_retries) { + const delay_ms = @min(1000 * retry_count, 5000); // Exponential backoff, max 5s + log.warn("Connection failed (attempt {d}/{d}), retrying in {d}s...\n", .{ retry_count, max_retries, delay_ms / 1000 }); + std.Thread.sleep(@as(u64, delay_ms) * std.time.ns_per_ms); + } + continue; + }; + + if (retry_count > 0) { + log.success("Connected successfully after {d} attempts\n", .{retry_count + 1}); + } + return client; + } + + return last_error; + } + + /// Disconnect from WebSocket server + pub fn disconnect(self: *Client) void { + if (self.stream) |stream| { + stream.close(); + self.stream = null; + } + } + + fn handshake(allocator: std.mem.Allocator, stream: std.net.Stream, host: []const u8, url: []const u8, api_key: []const u8) !void { + 
const key = try generateWebSocketKey(allocator); + defer allocator.free(key); + + // Send handshake request with API key authentication + const request = try std.fmt.allocPrint(allocator, "GET {s} HTTP/1.1\r\n" ++ + "Host: {s}\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: {s}\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "X-API-Key: {s}\r\n" ++ + "\r\n", .{ url, host, key, api_key }); + defer allocator.free(request); + + _ = try stream.write(request); + + // Read response + var response_buf: [1024]u8 = undefined; + const bytes_read = try stream.read(&response_buf); + const response = response_buf[0..bytes_read]; + + // Check for successful handshake + if (std.mem.indexOf(u8, response, "101 Switching Protocols") == null) { + // Parse HTTP status code for better error messages + if (std.mem.indexOf(u8, response, "404 Not Found") != null) { + std.debug.print("\n❌ WebSocket Connection Failed\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("The WebSocket endpoint '/ws' was not found on the server.\n\n", .{}); + std.debug.print("This usually means:\n", .{}); + std.debug.print(" • API server is not running\n", .{}); + std.debug.print(" • Incorrect server address in config\n", .{}); + std.debug.print(" • Different service running on that port\n\n", .{}); + std.debug.print("To diagnose:\n", .{}); + std.debug.print(" • Verify server address: Check ~/.ml/config.toml\n", .{}); + std.debug.print(" • Test connectivity: curl http://:/health\n", .{}); + std.debug.print(" • Contact your server administrator if the issue persists\n\n", .{}); + return error.EndpointNotFound; + } else if (std.mem.indexOf(u8, response, "401 Unauthorized") != null) { + std.debug.print("\n❌ Authentication Failed\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("Invalid or missing API key.\n\n", .{}); + std.debug.print("To 
fix:\n", .{}); + std.debug.print(" • Verify API key in ~/.ml/config.toml matches server configuration\n", .{}); + std.debug.print(" • Request a new API key from your administrator if needed\n\n", .{}); + return error.AuthenticationFailed; + } else if (std.mem.indexOf(u8, response, "403 Forbidden") != null) { + std.debug.print("\n❌ Access Denied\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("Your API key doesn't have permission for this operation.\n\n", .{}); + std.debug.print("To fix:\n", .{}); + std.debug.print(" • Contact your administrator to grant necessary permissions\n\n", .{}); + return error.PermissionDenied; + } else if (std.mem.indexOf(u8, response, "503 Service Unavailable") != null) { + std.debug.print("\n❌ Server Unavailable\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("The server is temporarily unavailable.\n\n", .{}); + std.debug.print("This could be due to:\n", .{}); + std.debug.print(" • Server maintenance\n", .{}); + std.debug.print(" • High load\n", .{}); + std.debug.print(" • Server restart\n\n", .{}); + std.debug.print("To resolve:\n", .{}); + std.debug.print(" • Wait a moment and try again\n", .{}); + std.debug.print(" • Contact administrator if the issue persists\n\n", .{}); + return error.ServerUnavailable; + } else { + // Generic handshake failure - show response for debugging + std.debug.print("\n❌ WebSocket Handshake Failed\n", .{}); + std.debug.print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n", .{}); + std.debug.print("Expected HTTP 101 Switching Protocols, but received:\n", .{}); + + // Show first line of response (status line) + const newline_pos = std.mem.indexOf(u8, response, "\r\n") orelse response.len; + const status_line = response[0..newline_pos]; + std.debug.print(" {s}\n\n", .{status_line}); + + std.debug.print("To diagnose:\n", .{}); + std.debug.print(" • Verify server 
address in ~/.ml/config.toml\n", .{}); + std.debug.print(" • Check network connectivity to the server\n", .{}); + std.debug.print(" • Contact your administrator for assistance\n\n", .{}); + return error.HandshakeFailed; + } + } + } + + fn generateWebSocketKey(allocator: std.mem.Allocator) ![]u8 { + var random_bytes: [16]u8 = undefined; + std.crypto.random.bytes(&random_bytes); + + const base64 = std.base64.standard.Encoder; + const result = try allocator.alloc(u8, base64.calcSize(random_bytes.len)); + _ = base64.encode(result, &random_bytes); + return result; + } + + pub fn close(self: *Client) void { + if (self.stream) |stream| { + stream.close(); + self.stream = null; + } + if (self.host.len > 0) { + self.allocator.free(self.host); + } + } + + pub fn sendQueueJob(self: *Client, job_name: []const u8, commit_id: []const u8, priority: u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + // Validate input lengths + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (commit_id.len != 64) return error.InvalidCommitId; + if (job_name.len > 255) return error.JobNameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [commit_id: 64 bytes] [priority: u8] [job_name_len: u8] [job_name: var] + const total_len = 1 + 64 + 64 + 1 + 1 + job_name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.queue_job); + offset += 1; + + @memcpy(buffer[offset .. offset + 64], api_key_hash); + offset += 64; + + @memcpy(buffer[offset .. 
offset + 64], commit_id); + offset += 64; + + buffer[offset] = priority; + offset += 1; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset..], job_name); + + // Send as WebSocket binary frame + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendCancelJob(self: *Client, job_name: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (job_name.len > 255) return error.JobNameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [job_name_len: u8] [job_name: var] + const total_len = 1 + 64 + 1 + job_name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.cancel_job); + offset += 1; + + @memcpy(buffer[offset .. offset + 64], api_key_hash); + offset += 64; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset..], job_name); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendPrune(self: *Client, api_key_hash: []const u8, prune_type: u8, value: u32) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [prune_type: u8] [value: u4] + const total_len = 1 + 64 + 1 + 4; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.prune); + offset += 1; + + @memcpy(buffer[offset .. 
offset + 64], api_key_hash); + offset += 64; + + buffer[offset] = prune_type; + offset += 1; + + // Store value in big-endian format + buffer[offset] = @intCast((value >> 24) & 0xFF); + buffer[offset + 1] = @intCast((value >> 16) & 0xFF); + buffer[offset + 2] = @intCast((value >> 8) & 0xFF); + buffer[offset + 3] = @intCast(value & 0xFF); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] + const total_len = 1 + 64; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + buffer[0] = @intFromEnum(Opcode.status_request); + @memcpy(buffer[1..65], api_key_hash); + + try sendWebSocketFrame(stream, buffer); + } + + fn sendWebSocketFrame(stream: std.net.Stream, payload: []const u8) !void { + var frame: [14]u8 = undefined; // Extra space for mask + var frame_len: usize = 2; + + // FIN=1, opcode=0x2 (binary), MASK=1 + frame[0] = 0x82 | 0x80; + + // Payload length + if (payload.len < 126) { + frame[1] = @as(u8, @intCast(payload.len)) | 0x80; // Set MASK bit + } else if (payload.len < 65536) { + frame[1] = 126 | 0x80; // Set MASK bit + frame[2] = @intCast(payload.len >> 8); + frame[3] = @intCast(payload.len & 0xFF); + frame_len = 4; + } else { + return error.PayloadTooLarge; + } + + // Generate random mask (4 bytes) + var mask: [4]u8 = undefined; + var i: usize = 0; + while (i < 4) : (i += 1) { + mask[i] = @as(u8, @intCast(@mod(std.time.timestamp(), 256))); + } + + // Copy mask to frame + @memcpy(frame[frame_len .. 
frame_len + 4], &mask); + frame_len += 4; + + // Send frame header + _ = try stream.write(frame[0..frame_len]); + + // Send payload with masking + var masked_payload = try std.heap.page_allocator.alloc(u8, payload.len); + defer std.heap.page_allocator.free(masked_payload); + + for (payload, 0..) |byte, j| { + masked_payload[j] = byte ^ mask[j % 4]; + } + + _ = try stream.write(masked_payload); + } + + pub fn receiveMessage(self: *Client, allocator: std.mem.Allocator) ![]u8 { + const stream = self.stream orelse return error.NotConnected; + + // Read frame header + var header: [2]u8 = undefined; + const header_bytes = try stream.read(&header); + if (header_bytes < 2) return error.ConnectionClosed; + + // Check for binary frame and FIN bit + if (header[0] != 0x82) return error.InvalidFrame; + + // Get payload length + var payload_len: usize = header[1]; + if (payload_len == 126) { + var len_bytes: [2]u8 = undefined; + _ = try stream.read(&len_bytes); + payload_len = (@as(usize, len_bytes[0]) << 8) | len_bytes[1]; + } else if (payload_len == 127) { + return error.PayloadTooLarge; + } + + // Read payload + const payload = try allocator.alloc(u8, payload_len); + errdefer allocator.free(payload); + + var bytes_read: usize = 0; + while (bytes_read < payload_len) { + const n = try stream.read(payload[bytes_read..]); + if (n == 0) return error.ConnectionClosed; + bytes_read += n; + } + + return payload; + } + + /// Receive and handle response with automatic display + pub fn receiveAndHandleResponse(self: *Client, allocator: std.mem.Allocator, operation: []const u8) !void { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + // For now, just display a simple success message + // TODO: Implement proper JSON parsing and packet handling + std.debug.print("{s} completed successfully\n", .{operation}); + } + + /// Receive and handle status response with user filtering + pub fn receiveAndHandleStatusResponse(self: *Client, allocator: 
std.mem.Allocator, user_context: anytype) !void { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + // For now, just display a simple success message with user context + // TODO: Parse JSON response and display user-filtered jobs + std.debug.print("Status retrieved for user: {s}\n", .{user_context.name}); + + // Display basic status summary + std.debug.print("Your jobs will be displayed here\n", .{}); + } + + /// Receive and handle cancel response with user permissions + pub fn receiveAndHandleCancelResponse(self: *Client, allocator: std.mem.Allocator, user_context: anytype, job_name: []const u8) !void { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + // For now, just display a simple success message with user context + // TODO: Parse response and handle permission errors + std.debug.print("Job '{s}' cancellation processed for user: {s}\n", .{ job_name, user_context.name }); + std.debug.print("Response will be parsed here\n", .{}); + } + + /// Handle response packet with appropriate display + pub fn handleResponsePacket(self: *Client, packet: protocol.ResponsePacket, operation: []const u8) !void { + switch (packet.packet_type) { + .success => { + if (packet.success_message) |msg| { + if (msg.len > 0) { + std.debug.print("✓ {s}: {s}\n", .{ operation, msg }); + } else { + std.debug.print("✓ {s} completed successfully\n", .{operation}); + } + } else { + std.debug.print("✓ {s} completed successfully\n", .{operation}); + } + }, + .error_packet => { + const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?); + std.debug.print("✗ {s} failed: {s}\n", .{ operation, error_msg }); + + if (packet.error_message) |msg| { + if (msg.len > 0) { + std.debug.print("Details: {s}\n", .{msg}); + } + } + + if (packet.error_details) |details| { + if (details.len > 0) { + std.debug.print("Additional info: {s}\n", .{details}); + } + } + + // Convert to appropriate CLI error + return 
self.convertServerError(packet.error_code.?); + }, + .progress => { + if (packet.progress_type) |ptype| { + switch (ptype) { + .percentage => { + const percentage = packet.progress_value.?; + if (packet.progress_total) |total| { + std.debug.print("Progress: {d}/{d} ({d:.1}%)\n", .{ percentage, total, @as(f32, @floatFromInt(percentage)) * 100.0 / @as(f32, @floatFromInt(total)) }); + } else { + std.debug.print("Progress: {d}%\n", .{percentage}); + } + }, + .stage => { + if (packet.progress_message) |msg| { + std.debug.print("Stage: {s}\n", .{msg}); + } + }, + .message => { + if (packet.progress_message) |msg| { + std.debug.print("Info: {s}\n", .{msg}); + } + }, + .bytes_transferred => { + const bytes = packet.progress_value.?; + if (packet.progress_total) |total| { + const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0; + const total_mb = @as(f64, @floatFromInt(total)) / 1024.0 / 1024.0; + std.debug.print("Transferred: {d:.2} MB / {d:.2} MB\n", .{ transferred_mb, total_mb }); + } else { + const transferred_mb = @as(f64, @floatFromInt(bytes)) / 1024.0 / 1024.0; + std.debug.print("Transferred: {d:.2} MB\n", .{transferred_mb}); + } + }, + } + } + }, + .status => { + if (packet.status_data) |data| { + std.debug.print("Status: {s}\n", .{data}); + } + }, + .data => { + if (packet.data_type) |dtype| { + std.debug.print("Data [{s}]: ", .{dtype}); + if (packet.data_payload) |payload| { + // Try to display as string if it looks like text + const is_text = for (payload) |byte| { + if (byte < 32 and byte != '\n' and byte != '\r' and byte != '\t') break false; + } else true; + + if (is_text) { + std.debug.print("{s}\n", .{payload}); + } else { + std.debug.print("{d} bytes\n", .{payload.len}); + } + } + } + }, + .log => { + if (packet.log_level) |level| { + const level_name = protocol.ResponsePacket.getLogLevelName(level); + if (packet.log_message) |msg| { + std.debug.print("[{s}] {s}\n", .{ level_name, msg }); + } + } + }, + } + } + + /// Convert server error 
code to CLI error + fn convertServerError(self: *Client, server_error: protocol.ErrorCode) anyerror { + _ = self; // Client instance not needed for error conversion + return switch (server_error) { + .authentication_failed => error.AuthenticationFailed, + .permission_denied => error.PermissionDenied, + .resource_not_found => error.JobNotFound, + .resource_already_exists => error.ResourceExists, + .timeout => error.RequestTimeout, + .server_overloaded, .service_unavailable => error.ServerUnreachable, + .invalid_request => error.InvalidArguments, + .job_not_found => error.JobNotFound, + .job_already_running => error.JobAlreadyRunning, + .job_failed_to_start, .job_execution_failed => error.CommandFailed, + .job_cancelled => error.JobCancelled, + else => error.ServerError, + }; + } + + /// Clean up packet allocated memory + pub fn cleanupPacket(self: *Client, packet: protocol.ResponsePacket) void { + if (packet.success_message) |msg| { + self.allocator.free(msg); + } + if (packet.error_message) |msg| { + self.allocator.free(msg); + } + if (packet.error_details) |details| { + self.allocator.free(details); + } + if (packet.progress_message) |msg| { + self.allocator.free(msg); + } + if (packet.status_data) |data| { + self.allocator.free(data); + } + if (packet.data_type) |dtype| { + self.allocator.free(dtype); + } + if (packet.data_payload) |payload| { + self.allocator.free(payload); + } + if (packet.log_message) |msg| { + self.allocator.free(msg); + } + } + pub fn sendCrashReport(self: *Client, api_key_hash: []const u8, error_type: []const u8, error_message: []const u8, command: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + + // Build binary message: [opcode:1][api_key_hash:64][error_type_len:2][error_type][error_message_len:2][error_message][command_len:2][command] + const total_len = 1 + 64 + 2 + error_type.len + 2 + error_message.len + 2 + command.len; + const 
message = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(message); + + var offset: usize = 0; + + // Opcode + message[offset] = @intFromEnum(Opcode.crash_report); + offset += 1; + + // API key hash + @memcpy(message[offset .. offset + 64], api_key_hash); + offset += 64; + + // Error type length and data + std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(error_type.len), .big); + offset += 2; + @memcpy(message[offset .. offset + error_type.len], error_type); + offset += error_type.len; + + // Error message length and data + std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(error_message.len), .big); + offset += 2; + @memcpy(message[offset .. offset + error_message.len], error_message); + offset += error_message.len; + + // Command length and data + std.mem.writeInt(u16, message[offset .. offset + 2][0..2], @intCast(command.len), .big); + offset += 2; + @memcpy(message[offset .. offset + command.len], command); + + // Send WebSocket frame + try sendWebSocketFrame(stream, message); + } + + // Dataset management methods + pub fn sendDatasetList(self: *Client, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + + // Build binary message: [opcode: u8] [api_key_hash: 64 bytes] + const total_len = 1 + 64; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + buffer[0] = @intFromEnum(Opcode.dataset_list); + @memcpy(buffer[1..65], api_key_hash); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendDatasetRegister(self: *Client, name: []const u8, url: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + if (url.len > 1023) return error.URLTooLong; + + // Build binary message: + // [opcode: 
u8] [api_key_hash: 64 bytes] [name_len: u8] [name: var] [url_len: u16] [url: var] + const total_len = 1 + 64 + 1 + name.len + 2 + url.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.dataset_register); + offset += 1; + + @memcpy(buffer[offset .. offset + 64], api_key_hash); + offset += 64; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + name.len], name); + offset += name.len; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(url.len), .big); + offset += 2; + + @memcpy(buffer[offset .. offset + url.len], url); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendDatasetInfo(self: *Client, name: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (name.len > 255) return error.NameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [name_len: u8] [name: var] + const total_len = 1 + 64 + 1 + name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.dataset_info); + offset += 1; + + @memcpy(buffer[offset .. 
offset + 64], api_key_hash); + offset += 64; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset..], name); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendDatasetSearch(self: *Client, term: []const u8, api_key_hash: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (term.len > 255) return error.SearchTermTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [term_len: u8] [term: var] + const total_len = 1 + 64 + 1 + term.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.dataset_search); + offset += 1; + + @memcpy(buffer[offset .. offset + 64], api_key_hash); + offset += 64; + + buffer[offset] = @intCast(term.len); + offset += 1; + + @memcpy(buffer[offset..], term); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendLogMetric(self: *Client, api_key_hash: []const u8, commit_id: []const u8, name: []const u8, value: f64, step: u32) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (commit_id.len != 64) return error.InvalidCommitId; + if (name.len > 255) return error.NameTooLong; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [commit_id: 64 bytes] [step: u32] [value: f64] [name_len: u8] [name: var] + const total_len = 1 + 64 + 64 + 4 + 8 + 1 + name.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.log_metric); + offset += 1; + + @memcpy(buffer[offset .. offset + 64], api_key_hash); + offset += 64; + + @memcpy(buffer[offset .. offset + 64], commit_id); + offset += 64; + + std.mem.writeInt(u32, buffer[offset .. 
offset + 4][0..4], step, .big); + offset += 4; + + std.mem.writeInt(u64, buffer[offset .. offset + 8][0..8], @as(u64, @bitCast(value)), .big); + offset += 8; + + buffer[offset] = @intCast(name.len); + offset += 1; + + @memcpy(buffer[offset..], name); + + try sendWebSocketFrame(stream, buffer); + } + + pub fn sendGetExperiment(self: *Client, api_key_hash: []const u8, commit_id: []const u8) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 64) return error.InvalidApiKeyHash; + if (commit_id.len != 64) return error.InvalidCommitId; + + // Build binary message: + // [opcode: u8] [api_key_hash: 64 bytes] [commit_id: 64 bytes] + const total_len = 1 + 64 + 64; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(Opcode.get_experiment); + offset += 1; + + @memcpy(buffer[offset .. offset + 64], api_key_hash); + offset += 64; + + @memcpy(buffer[offset .. 
offset + 64], commit_id); + + try sendWebSocketFrame(stream, buffer); + } + + /// Receive and handle dataset response + pub fn receiveAndHandleDatasetResponse(self: *Client, allocator: std.mem.Allocator) ![]const u8 { + const message = try self.receiveMessage(allocator); + defer allocator.free(message); + + // For now, just return the message as a string + // TODO: Parse JSON response and format properly + return allocator.dupe(u8, message); + } +}; + +fn resolveHostAddress(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address { + return std.net.Address.parseIp(host, port) catch |err| switch (err) { + error.InvalidIPAddressFormat => resolveHostname(allocator, host, port), + else => return err, + }; +} + +fn resolveHostname(allocator: std.mem.Allocator, host: []const u8, port: u16) !std.net.Address { + var address_list = try std.net.getAddressList(allocator, host, port); + defer address_list.deinit(); + + if (address_list.addrs.len == 0) return error.HostResolutionFailed; + + return address_list.addrs[0]; +} + +test "resolve hostnames for WebSocket connections" { + _ = try resolveHostAddress(std.testing.allocator, "localhost", 9100); +} diff --git a/cli/src/utils/colors.zig b/cli/src/utils/colors.zig new file mode 100644 index 0000000..7d5473f --- /dev/null +++ b/cli/src/utils/colors.zig @@ -0,0 +1,80 @@ +const std = @import("std"); + +/// ANSI color codes for terminal output +pub const Color = struct { + pub const Reset = "\x1b[0m"; + pub const Bright = "\x1b[1m"; + pub const Dim = "\x1b[2m"; + + pub const Black = "\x1b[30m"; + pub const Red = "\x1b[31m"; + pub const Green = "\x1b[32m"; + pub const Yellow = "\x1b[33m"; + pub const Blue = "\x1b[34m"; + pub const Magenta = "\x1b[35m"; + pub const Cyan = "\x1b[36m"; + pub const White = "\x1b[37m"; + + pub const BgBlack = "\x1b[40m"; + pub const BgRed = "\x1b[41m"; + pub const BgGreen = "\x1b[42m"; + pub const BgYellow = "\x1b[43m"; + pub const BgBlue = "\x1b[44m"; + pub const BgMagenta = 
"\x1b[45m"; + pub const BgCyan = "\x1b[46m"; + pub const BgWhite = "\x1b[47m"; +}; + +/// Check if terminal supports colors +pub fn supportsColors() bool { + if (std.process.getEnvVarOwned(std.heap.page_allocator, "NO_COLOR")) |_| { + return false; + } else |_| { + // Check if we're in a terminal + return std.posix.isatty(std.posix.STDOUT_FILENO); + } +} + +/// Print colored text if colors are supported +pub fn print(comptime color: []const u8, comptime format: []const u8, args: anytype) void { + if (supportsColors()) { + const colored_format = color ++ format ++ Color.Reset ++ "\n"; + std.debug.print(colored_format, args); + } else { + std.debug.print(format ++ "\n", args); + } +} + +/// Print error message in red +pub fn printError(comptime format: []const u8, args: anytype) void { + print(Color.Red, format, args); +} + +/// Print success message in green +pub fn printSuccess(comptime format: []const u8, args: anytype) void { + print(Color.Green, format, args); +} + +/// Print warning message in yellow +pub fn printWarning(comptime format: []const u8, args: anytype) void { + print(Color.Yellow, format, args); +} + +/// Print info message in blue +pub fn printInfo(comptime format: []const u8, args: anytype) void { + print(Color.Blue, format, args); +} + +/// Print progress message in cyan +pub fn printProgress(comptime format: []const u8, args: anytype) void { + print(Color.Cyan, format, args); +} + +/// Ask for user confirmation (y/N) - simplified version +pub fn confirm(comptime prompt: []const u8, args: anytype) bool { + print(Color.Yellow, prompt ++ " [y/N]: ", args); + + // For now, always return true to avoid stdin complications + // TODO: Implement proper stdin reading when needed + return true; +} diff --git a/cli/src/utils/crypto.zig b/cli/src/utils/crypto.zig new file mode 100644 index 0000000..b69ce7d --- /dev/null +++ b/cli/src/utils/crypto.zig @@ -0,0 +1,114 @@ +const std = @import("std"); + +/// Hash a string using SHA256 and return lowercase hex 
string +pub fn hashString(allocator: std.mem.Allocator, input: []const u8) ![]u8 { + var hash: [32]u8 = undefined; + std.crypto.hash.sha2.Sha256.hash(input, &hash, .{}); + + // Convert to hex string manually + const hex = try allocator.alloc(u8, 64); + for (hash, 0..) |byte, i| { + const hi = (byte >> 4) & 0xf; + const lo = byte & 0xf; + hex[i * 2] = if (hi < 10) '0' + hi else 'a' + (hi - 10); + hex[i * 2 + 1] = if (lo < 10) '0' + lo else 'a' + (lo - 10); + } + return hex; +} + +/// Calculate commit ID for a directory (SHA256 of tree state) +pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 { + var hasher = std.crypto.hash.sha2.Sha256.init(.{}); + + var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true }); + defer dir.close(); + + var walker = try dir.walk(allocator); + defer walker.deinit(); + + // Collect and sort paths for deterministic hashing + var paths: std.ArrayList([]const u8) = .{}; + defer { + for (paths.items) |path| allocator.free(path); + paths.deinit(allocator); + } + + while (try walker.next()) |entry| { + if (entry.kind == .file) { + try paths.append(allocator, try allocator.dupe(u8, entry.path)); + } + } + + std.sort.block([]const u8, paths.items, {}, struct { + fn lessThan(_: void, a: []const u8, b: []const u8) bool { + return std.mem.order(u8, a, b) == .lt; + } + }.lessThan); + + // Hash each file path and content + for (paths.items) |path| { + hasher.update(path); + hasher.update(&[_]u8{0}); // Separator + + const file = try dir.openFile(path, .{}); + defer file.close(); + + var buf: [4096]u8 = undefined; + while (true) { + const bytes_read = try file.read(&buf); + if (bytes_read == 0) break; + hasher.update(buf[0..bytes_read]); + } + hasher.update(&[_]u8{0}); // Separator + } + + var hash: [32]u8 = undefined; + hasher.final(&hash); + + // Convert to hex string manually + const hex = try allocator.alloc(u8, 64); + for (hash, 0..) 
|byte, i| { + const hi = (byte >> 4) & 0xf; + const lo = byte & 0xf; + hex[i * 2] = if (hi < 10) '0' + hi else 'a' + (hi - 10); + hex[i * 2 + 1] = if (lo < 10) '0' + lo else 'a' + (lo - 10); + } + return hex; +} + +test "hash string" { + const allocator = std.testing.allocator; + + const hash = try hashString(allocator, "test"); + defer allocator.free(hash); + + // SHA256 of "test" + try std.testing.expectEqualStrings("9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08", hash); +} + +test "hash empty string" { + const allocator = std.testing.allocator; + + const hash = try hashString(allocator, ""); + defer allocator.free(hash); + + // SHA256 of empty string + try std.testing.expectEqualStrings("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash); +} + +test "hash directory" { + const allocator = std.testing.allocator; + + // For now, just test that we can hash the current directory + // This should work if there are files in the current directory + const hash = try hashDirectory(allocator, "."); + defer allocator.free(hash); + + // Should produce a valid 64-character hex string + try std.testing.expectEqual(@as(usize, 64), hash.len); + + // All characters should be valid hex + for (hash) |c| { + try std.testing.expect((c >= '0' and c <= '9') or (c >= 'a' and c <= 'f')); + } +} diff --git a/cli/src/utils/logging.zig b/cli/src/utils/logging.zig new file mode 100644 index 0000000..a6df7ea --- /dev/null +++ b/cli/src/utils/logging.zig @@ -0,0 +1,27 @@ +const std = @import("std"); +const colors = @import("colors.zig"); + +/// Simple logging utility that wraps colors functionality +pub fn info(comptime format: []const u8, args: anytype) void { + colors.printInfo(format, args); +} + +pub fn success(comptime format: []const u8, args: anytype) void { + colors.printSuccess(format, args); +} + +pub fn warn(comptime format: []const u8, args: anytype) void { + colors.printWarning(format, args); +} + +pub fn err(comptime format: []const u8, 
args: anytype) void { + colors.printError(format, args); +} + +pub fn progress(comptime format: []const u8, args: anytype) void { + colors.printProgress(format, args); +} + +pub fn confirm(comptime prompt: []const u8, args: anytype) bool { + return colors.confirm(prompt, args); +} diff --git a/cli/src/utils/rsync.zig b/cli/src/utils/rsync.zig new file mode 100644 index 0000000..349408b --- /dev/null +++ b/cli/src/utils/rsync.zig @@ -0,0 +1,45 @@ +const std = @import("std"); + +/// Sync local directory to remote via rsync over SSH +pub fn sync(allocator: std.mem.Allocator, local_path: []const u8, remote_path: []const u8, ssh_port: u16) !void { + const port_str = try std.fmt.allocPrint(allocator, "{d}", .{ssh_port}); + defer allocator.free(port_str); + + const ssh_opt = try std.fmt.allocPrint(allocator, "ssh -p {s}", .{port_str}); + defer allocator.free(ssh_opt); + + // Build rsync command: rsync -avz -e "ssh -p PORT" local/ remote/ + var child = std.process.Child.init( + &[_][]const u8{ + "rsync", + "-avz", + "--delete", + "-e", + ssh_opt, + local_path, + remote_path, + }, + allocator, + ); + + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Inherit; + child.stderr_behavior = .Inherit; + + const term = try child.spawnAndWait(); + + switch (term) { + .Exited => |code| { + if (code != 0) { + std.debug.print("rsync failed with exit code {d}\n", .{code}); + return error.RsyncFailed; + } + }, + .Signal => { + return error.RsyncKilled; + }, + else => { + return error.RsyncUnknownError; + }, + } +} diff --git a/cli/src/utils/rsync_embedded.zig b/cli/src/utils/rsync_embedded.zig new file mode 100644 index 0000000..126cc9f --- /dev/null +++ b/cli/src/utils/rsync_embedded.zig @@ -0,0 +1,107 @@ +const std = @import("std"); + +/// Embedded rsync binary functionality +pub const EmbeddedRsync = struct { + allocator: std.mem.Allocator, + + const Self = @This(); + + /// Extract embedded rsync binary to temporary location + pub fn extractRsyncBinary(self: Self) ![]const 
u8 { + const rsync_path = "/tmp/ml_rsync"; + + // Check if rsync binary already exists + if (std.fs.cwd().openFile(rsync_path, .{})) |file| { + file.close(); + // Check if it's executable + const stat = try std.fs.cwd().statFile(rsync_path); + if (stat.mode & 0o111 != 0) { + return try self.allocator.dupe(u8, rsync_path); + } + } else |_| { + // Need to extract the binary + try self.extractAndSetExecutable(rsync_path); + } + + return try self.allocator.dupe(u8, rsync_path); + } + + /// Extract rsync binary from embedded data and set executable permissions + fn extractAndSetExecutable(self: Self, path: []const u8) !void { + // Import embedded binary data + const embedded_binary = @import("rsync_embedded_binary.zig"); + + // Get the embedded rsync binary data + const binary_data = embedded_binary.getRsyncBinary(); + + // Debug output to show we're using embedded binary + const debug_msg = try std.fmt.allocPrint(self.allocator, "Extracting embedded rsync binary to {s} (size: {d} bytes)", .{ path, binary_data.len }); + defer self.allocator.free(debug_msg); + std.log.debug("{s}", .{debug_msg}); + + // Write embedded binary to file system + try std.fs.cwd().writeFile(.{ .sub_path = path, .data = binary_data }); + + // Set executable permissions using OS API + const mode = 0o755; // rwxr-xr-x + std.posix.fchmodat(std.fs.cwd().fd, path, mode, 0) catch |err| { + std.log.warn("Failed to set executable permissions on {s}: {}", .{ path, err }); + // Continue anyway - the script might still work + }; + } + + /// Sync using embedded rsync + pub fn sync(self: Self, local_path: []const u8, remote_path: []const u8, ssh_port: u16) !void { + const rsync_path = try self.extractRsyncBinary(); + defer self.allocator.free(rsync_path); + + const port_str = try std.fmt.allocPrint(self.allocator, "{d}", .{ssh_port}); + defer self.allocator.free(port_str); + + const ssh_opt = try std.fmt.allocPrint(self.allocator, "ssh -p {s}", .{port_str}); + defer self.allocator.free(ssh_opt); + + // Build 
rsync command using embedded binary + var child = std.process.Child.init( + &[_][]const u8{ + rsync_path, + "-avz", + "--delete", + "-e", + ssh_opt, + local_path, + remote_path, + }, + self.allocator, + ); + + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Inherit; + child.stderr_behavior = .Inherit; + + const term = try child.spawnAndWait(); + + switch (term) { + .Exited => |code| { + if (code != 0) { + std.debug.print("rsync failed with exit code {d}\n", .{code}); + return error.RsyncFailed; + } + }, + .Signal => { + return error.RsyncKilled; + }, + else => { + return error.RsyncUnknownError; + }, + } + } +}; + +/// Public sync function that uses embedded rsync +pub fn sync(allocator: std.mem.Allocator, local_path: []const u8, remote_path: []const u8, ssh_port: u16) !void { + var embedded_rsync = EmbeddedRsync{ .allocator = allocator }; + try embedded_rsync.sync(local_path, remote_path, ssh_port); + + std.debug.print("Synced {s} to {s} using embedded rsync\n", .{ local_path, remote_path }); +} diff --git a/cli/src/utils/rsync_embedded_binary.zig b/cli/src/utils/rsync_embedded_binary.zig new file mode 100644 index 0000000..3c467c1 --- /dev/null +++ b/cli/src/utils/rsync_embedded_binary.zig @@ -0,0 +1,24 @@ +const std = @import("std"); + +/// Embedded rsync binary data +/// For dev builds: uses placeholder wrapper that calls system rsync +/// For release builds: embed full static rsync binary from rsync_release.bin +/// +/// To prepare for release: +/// 1. Download or build a static rsync binary for your target platform +/// 2. Place it at cli/src/assets/rsync_release.bin +/// 3. 
Build with: zig build prod (or release/cross targets) + +// Check if rsync_release.bin exists, otherwise fall back to placeholder +const rsync_binary_data = if (@import("builtin").mode == .Debug or @import("builtin").mode == .ReleaseSafe) + @embedFile("../assets/rsync_placeholder.bin") +else + // For ReleaseSmall and ReleaseFast, try to use the release binary + @embedFile("../assets/rsync_release.bin"); + +pub const RSYNC_BINARY: []const u8 = rsync_binary_data; + +/// Get embedded rsync binary data +pub fn getRsyncBinary() []const u8 { + return RSYNC_BINARY; +} diff --git a/cli/src/utils/storage.zig b/cli/src/utils/storage.zig new file mode 100644 index 0000000..d3afc59 --- /dev/null +++ b/cli/src/utils/storage.zig @@ -0,0 +1,97 @@ +const std = @import("std"); +const crypto = @import("crypto.zig"); + +pub const ContentAddressedStorage = struct { + allocator: std.mem.Allocator, + base_path: []const u8, + + pub fn init(allocator: std.mem.Allocator, base_path: []const u8) ContentAddressedStorage { + return .{ + .allocator = allocator, + .base_path = base_path, + }; + } + + pub fn storeFile(self: *ContentAddressedStorage, content: []const u8) ![]const u8 { + // Hash content to get address + const hash = try crypto.hashString(self.allocator, content); + defer self.allocator.free(hash); + + // Create content path + const content_path = try std.fmt.allocPrint(self.allocator, "{s}/content/{s}/{s}", .{ self.base_path, hash[0..2], hash }); + defer self.allocator.free(content_path); + + // Check if content already exists + { + const file = std.fs.openFileAbsolute(content_path, .{}) catch |err| { + if (err == error.FileNotFound) { + // Store new content - simplified approach + const parent_dir = std.fs.path.dirname(content_path) orelse return error.InvalidPath; + std.fs.cwd().makePath(parent_dir) catch |make_err| { + if (make_err != error.PathAlreadyExists) return make_err; + }; + + const file = try std.fs.createFileAbsolute(content_path, .{}); + defer file.close(); + + try 
file.writeAll(content); + return self.allocator.dupe(u8, hash); + } + return err; + }; + defer file.close(); + + // Content already exists + return self.allocator.dupe(u8, hash); + } + } + + pub fn linkFile(self: *ContentAddressedStorage, content_hash: []const u8, target_path: []const u8) !void { + const content_path = try std.fmt.allocPrint(self.allocator, "{s}/content/{s}/{s}", .{ self.base_path, content_hash[0..2], content_hash }); + defer self.allocator.free(content_path); + + // Create parent directory if needed + const parent_dir = std.fs.path.dirname(target_path) orelse return error.InvalidPath; + std.fs.cwd().makePath(parent_dir) catch |make_err| { + if (make_err != error.PathAlreadyExists) return make_err; + }; + + // Copy file instead of hard link to avoid permission issues + const src_file = try std.fs.openFileAbsolute(content_path, .{}); + defer src_file.close(); + + const dest_file = try std.fs.createFileAbsolute(target_path, .{}); + defer dest_file.close(); + + const content = try src_file.readToEndAlloc(self.allocator, 1024 * 1024 * 100); + defer self.allocator.free(content); + + try dest_file.writeAll(content); + } + + pub fn deduplicateDirectory(self: *ContentAddressedStorage, dir_path: []const u8) !void { + var dir = try std.fs.openDirAbsolute(dir_path, .{ .iterate = true }); + defer dir.close(); + + var walker = try dir.walk(self.allocator); + defer walker.deinit(); + + while (try walker.next()) |entry| { + if (entry.kind == .file) { + const file = try dir.openFile(entry.path, .{}); + defer file.close(); + + const content = try file.readToEndAlloc(self.allocator, 1024 * 1024 * 100); // 100MB limit + defer self.allocator.free(content); + + const hash = try self.storeFile(content); + defer self.allocator.free(hash); + + const target_path = try std.fmt.allocPrint(self.allocator, "{s}/{s}", .{ dir_path, entry.path }); + defer self.allocator.free(target_path); + + try self.linkFile(hash, target_path); + } + } + } +}; diff --git 
a/cli/tests/config_test.zig b/cli/tests/config_test.zig new file mode 100644 index 0000000..379a85c --- /dev/null +++ b/cli/tests/config_test.zig @@ -0,0 +1,273 @@ +const std = @import("std"); +const testing = std.testing; +const src = @import("src"); +const Config = src.Config; + +test "config file validation" { + // Test various config file contents + const test_configs = [_]struct { + content: []const u8, + should_be_valid: bool, + description: []const u8, + }{ + .{ + .content = + \\server: + \\ host: "localhost" + \\ port: 8080 + \\auth: + \\ enabled: true + \\storage: + \\ type: "local" + \\ path: "/tmp/ml_experiments" + , + .should_be_valid = true, + .description = "Valid complete config", + }, + .{ + .content = + \\server: + \\ host: "localhost" + \\ port: 8080 + , + .should_be_valid = true, + .description = "Minimal valid config", + }, + .{ + .content = "", + .should_be_valid = false, + .description = "Empty config", + }, + .{ + .content = + \\invalid_yaml: [ + \\ missing_closing_bracket + , + .should_be_valid = false, + .description = "Invalid YAML syntax", + }, + }; + + for (test_configs) |case| { + // For now, just verify the content is readable + try testing.expect(case.content.len > 0 or !case.should_be_valid); + + if (case.should_be_valid) { + try testing.expect(std.mem.indexOf(u8, case.content, "server") != null); + } + } +} + +test "config default values" { + const allocator = testing.allocator; + + // Create minimal config + const minimal_config = + \\server: + \\ host: "localhost" + \\ port: 8080 + ; + + // Verify config content is readable + const content = try allocator.alloc(u8, minimal_config.len); + defer allocator.free(content); + + @memcpy(content, minimal_config); + + try testing.expect(std.mem.indexOf(u8, content, "localhost") != null); + try testing.expect(std.mem.indexOf(u8, content, "8080") != null); +} + +test "config server settings" { + const allocator = testing.allocator; + _ = allocator; // Mark as used for future test expansions 
// Test server configuration validation.
    const server_configs = [_]struct {
        host: []const u8,
        port: u16,
        should_be_valid: bool,
    }{
        .{ .host = "localhost", .port = 8080, .should_be_valid = true },
        .{ .host = "127.0.0.1", .port = 8080, .should_be_valid = true },
        .{ .host = "example.com", .port = 443, .should_be_valid = true },
        .{ .host = "", .port = 8080, .should_be_valid = false },
        .{ .host = "localhost", .port = 0, .should_be_valid = false },
        .{ .host = "localhost", .port = 65535, .should_be_valid = true },
    };

    for (server_configs) |config| {
        if (config.should_be_valid) {
            try testing.expect(config.host.len > 0);
            // port is u16, so the upper bound (65535) is enforced by the
            // type; only the non-zero lower bound needs a runtime check.
            // (The previous `config.port <= 65535` was always true.)
            try testing.expect(config.port >= 1);
        } else {
            // An invalid config has an empty host or port 0. The previous
            // `config.port > 65535` arm was unrepresentable for a u16 and
            // has been dropped.
            try testing.expect(config.host.len == 0 or config.port == 0);
        }
    }
}

test "config authentication settings" {
    // Auth enabled without any API keys is the only invalid combination.
    const auth_configs = [_]struct {
        enabled: bool,
        has_api_keys: bool,
        should_be_valid: bool,
    }{
        .{ .enabled = true, .has_api_keys = true, .should_be_valid = true },
        .{ .enabled = false, .has_api_keys = false, .should_be_valid = true },
        .{ .enabled = true, .has_api_keys = false, .should_be_valid = false },
        .{ .enabled = false, .has_api_keys = true, .should_be_valid = true }, // API keys can exist but auth disabled
    };

    for (auth_configs) |config| {
        if (config.enabled and !config.has_api_keys) {
            try testing.expect(!config.should_be_valid);
        } else {
            try testing.expect(true); // Valid configuration
        }
    }
}

test "config storage settings" {
    // Test storage configuration: type and path must both be non-empty and
    // the type must be one of the known backends.
    const storage_configs = [_]struct {
        storage_type: []const u8,
        path: []const u8,
        should_be_valid: bool,
    }{
        .{ .storage_type = "local", .path = 
"/tmp/ml_experiments", .should_be_valid = true }, + .{ .storage_type = "s3", .path = "bucket-name", .should_be_valid = true }, + .{ .storage_type = "", .path = "/tmp/ml_experiments", .should_be_valid = false }, + .{ .storage_type = "local", .path = "", .should_be_valid = false }, + .{ .storage_type = "invalid", .path = "/tmp/ml_experiments", .should_be_valid = false }, + }; + + for (storage_configs) |config| { + if (config.should_be_valid) { + try testing.expect(config.storage_type.len > 0); + try testing.expect(config.path.len > 0); + } else { + try testing.expect(config.storage_type.len == 0 or + config.path.len == 0 or + std.mem.eql(u8, config.storage_type, "invalid")); + } + } +} + +test "config file paths" { + const allocator = testing.allocator; + + // Test config file path resolution + const config_paths = [_][]const u8{ + "config.yaml", + "config.yml", + ".ml/config.yaml", + "/etc/ml/config.yaml", + }; + + for (config_paths) |path| { + try testing.expect(path.len > 0); + + // Test path joining + const joined = try std.fs.path.join(allocator, &.{ "/base", path }); + defer allocator.free(joined); + + try testing.expect(joined.len > path.len); + try testing.expect(std.mem.startsWith(u8, joined, "/base/")); + } +} + +test "config environment variables" { + const allocator = testing.allocator; + _ = allocator; // Mark as used for future test expansions + + // Test environment variable substitution + const env_vars = [_]struct { + key: []const u8, + value: []const u8, + should_be_substituted: bool, + }{ + .{ .key = "ML_SERVER_HOST", .value = "localhost", .should_be_substituted = true }, + .{ .key = "ML_SERVER_PORT", .value = "8080", .should_be_substituted = true }, + .{ .key = "NON_EXISTENT_VAR", .value = "", .should_be_substituted = false }, + }; + + for (env_vars) |env_var| { + if (env_var.should_be_substituted) { + try testing.expect(env_var.value.len > 0); + } + } +} + +test "config validation errors" { + const allocator = testing.allocator; + + // Test 
various validation error scenarios + const error_scenarios = [_]struct { + config_content: []const u8, + expected_error_type: []const u8, + }{ + .{ + .config_content = "invalid: yaml: content: [", + .expected_error_type = "yaml_syntax_error", + }, + .{ + .config_content = "server:\n port: invalid_port", + .expected_error_type = "type_error", + }, + .{ + .config_content = "", + .expected_error_type = "empty_config", + }, + }; + + for (error_scenarios) |scenario| { + // For now, just verify the content is stored correctly + const content = try allocator.alloc(u8, scenario.config_content.len); + defer allocator.free(content); + + @memcpy(content, scenario.config_content); + + try testing.expect(std.mem.eql(u8, content, scenario.config_content)); + } +} + +test "config hot reload" { + const allocator = testing.allocator; + + // Initial config + const initial_config = + \\server: + \\ host: "localhost" + \\ port: 8080 + ; + + // Updated config + const updated_config = + \\server: + \\ host: "localhost" + \\ port: 9090 + ; + + // Test config content changes + const initial_content = try allocator.alloc(u8, initial_config.len); + defer allocator.free(initial_content); + + const updated_content = try allocator.alloc(u8, updated_config.len); + defer allocator.free(updated_content); + + @memcpy(initial_content, initial_config); + @memcpy(updated_content, updated_config); + + try testing.expect(std.mem.indexOf(u8, initial_content, "8080") != null); + try testing.expect(std.mem.indexOf(u8, updated_content, "9090") != null); + try testing.expect(std.mem.indexOf(u8, updated_content, "8080") == null); +} diff --git a/cli/tests/dataset_test.zig b/cli/tests/dataset_test.zig new file mode 100644 index 0000000..b214956 --- /dev/null +++ b/cli/tests/dataset_test.zig @@ -0,0 +1,58 @@ +const std = @import("std"); +const testing = std.testing; + +test "dataset command argument parsing" { + // Test various dataset command argument combinations + const test_cases = [_]struct { + args: 
[]const []const u8, + expected_action: ?[]const u8, + should_be_valid: bool, + }{ + .{ .args = &[_][]const u8{"list"}, .expected_action = "list", .should_be_valid = true }, + .{ .args = &[_][]const u8{ "register", "test_dataset", "https://example.com/data.zip" }, .expected_action = "register", .should_be_valid = true }, + .{ .args = &[_][]const u8{ "info", "test_dataset" }, .expected_action = "info", .should_be_valid = true }, + .{ .args = &[_][]const u8{ "search", "test" }, .expected_action = "search", .should_be_valid = true }, + .{ .args = &[_][]const u8{}, .expected_action = null, .should_be_valid = false }, + .{ .args = &[_][]const u8{"invalid"}, .expected_action = null, .should_be_valid = false }, + }; + + for (test_cases) |case| { + if (case.should_be_valid and case.expected_action != null) { + const expected = case.expected_action.?; + try testing.expect(case.args.len > 0); + try testing.expect(std.mem.eql(u8, case.args[0], expected)); + } + } +} + +test "dataset URL validation" { + // Test URL format validation + const valid_urls = [_][]const u8{ + "http://example.com/data.zip", + "https://example.com/data.zip", + "s3://bucket/data.csv", + "gs://bucket/data.csv", + }; + + const invalid_urls = [_][]const u8{ + "ftp://example.com/data.zip", + "example.com/data.zip", + "not-a-url", + "", + }; + + for (valid_urls) |url| { + try testing.expect(url.len > 0); + try testing.expect(std.mem.startsWith(u8, url, "http://") or + std.mem.startsWith(u8, url, "https://") or + std.mem.startsWith(u8, url, "s3://") or + std.mem.startsWith(u8, url, "gs://")); + } + + for (invalid_urls) |url| { + try testing.expect(!std.mem.startsWith(u8, url, "http://") and + !std.mem.startsWith(u8, url, "https://") and + !std.mem.startsWith(u8, url, "s3://") and + !std.mem.startsWith(u8, url, "gs://")); + } +} diff --git a/cli/tests/main_test.zig b/cli/tests/main_test.zig new file mode 100644 index 0000000..e6bb7d6 --- /dev/null +++ b/cli/tests/main_test.zig @@ -0,0 +1,111 @@ +const std = 
@import("std"); +const testing = std.testing; +const src = @import("src"); + +test "CLI basic functionality" { + // Test that CLI module can be imported + const allocator = testing.allocator; + _ = allocator; + + // Test basic string operations used in CLI + const test_str = "ml sync"; + try testing.expect(test_str.len > 0); + try testing.expect(std.mem.startsWith(u8, test_str, "ml")); +} + +test "CLI command validation" { + // Test command validation logic + const commands = [_][]const u8{ "init", "sync", "queue", "status", "monitor", "cancel", "prune", "watch" }; + + for (commands) |cmd| { + try testing.expect(cmd.len > 0); + try testing.expect(std.mem.indexOf(u8, cmd, " ") == null); + } +} + +test "CLI argument parsing" { + // Test basic argument parsing scenarios + const test_cases = [_]struct { + input: []const []const u8, + expected_command: ?[]const u8, + }{ + .{ .input = &[_][]const u8{"init"}, .expected_command = "init" }, + .{ .input = &[_][]const u8{ "sync", "/tmp/test" }, .expected_command = "sync" }, + .{ .input = &[_][]const u8{ "queue", "test_job" }, .expected_command = "queue" }, + .{ .input = &[_][]const u8{}, .expected_command = null }, + }; + + for (test_cases) |case| { + if (case.input.len > 0) { + if (case.expected_command) |expected| { + try testing.expect(std.mem.eql(u8, case.input[0], expected)); + } + } else { + try testing.expect(case.expected_command == null); + } + } +} + +test "CLI path validation" { + // Test path validation logic + const test_paths = [_]struct { + path: []const u8, + is_valid: bool, + }{ + .{ .path = "/tmp/test", .is_valid = true }, + .{ .path = "./relative", .is_valid = true }, + .{ .path = "", .is_valid = false }, + .{ .path = " ", .is_valid = false }, + }; + + for (test_paths) |case| { + if (case.is_valid) { + try testing.expect(case.path.len > 0); + try testing.expect(!std.mem.eql(u8, case.path, "")); + } else { + try testing.expect(case.path.len == 0 or std.mem.eql(u8, case.path, " ")); + } + } +} + +test "CLI 
error handling" { + // Test error handling scenarios + const error_scenarios = [_]struct { + name: []const u8, + should_fail: bool, + }{ + .{ .name = "missing_config", .should_fail = true }, + .{ .name = "invalid_command", .should_fail = true }, + .{ .name = "missing_args", .should_fail = true }, + .{ .name = "valid_operation", .should_fail = false }, + }; + + for (error_scenarios) |scenario| { + if (scenario.should_fail) { + try testing.expect(scenario.name.len > 0); + } + } +} + +test "CLI memory management" { + // Test basic memory management + const allocator = testing.allocator; + + // Test allocation and deallocation + const test_str = try allocator.alloc(u8, 10); + defer allocator.free(test_str); + + try testing.expect(test_str.len == 10); + + // Fill with test data + for (test_str, 0..) |*byte, i| { + byte.* = @intCast(i % 256); + } + + // Test string formatting + const formatted = try std.fmt.allocPrint(allocator, "test_{d}", .{42}); + defer allocator.free(formatted); + + try testing.expect(std.mem.startsWith(u8, formatted, "test_")); + try testing.expect(std.mem.endsWith(u8, formatted, "42")); +} diff --git a/cli/tests/queue_test.zig b/cli/tests/queue_test.zig new file mode 100644 index 0000000..0ff5e34 --- /dev/null +++ b/cli/tests/queue_test.zig @@ -0,0 +1,219 @@ +const std = @import("std"); +const testing = std.testing; + +test "queue command argument parsing" { + // Test various queue command argument combinations + const test_cases = [_]struct { + args: []const []const u8, + expected_job: ?[]const u8, + expected_priority: ?u32, + should_be_valid: bool, + }{ + .{ .args = &[_][]const u8{"test_job"}, .expected_job = "test_job", .expected_priority = null, .should_be_valid = true }, + .{ .args = &[_][]const u8{ "test_job", "--priority", "5" }, .expected_job = "test_job", .expected_priority = 5, .should_be_valid = true }, + .{ .args = &[_][]const u8{}, .expected_job = null, .expected_priority = null, .should_be_valid = false }, + .{ .args = &[_][]const u8{ 
"", "--priority", "5" }, .expected_job = "", .expected_priority = 5, .should_be_valid = false }, + }; + + for (test_cases) |case| { + try testing.expect(case.args.len > 0 or !case.should_be_valid); + + if (case.should_be_valid and case.expected_job != null) { + const job = case.expected_job.?; + try testing.expect(std.mem.eql(u8, case.args[0], job)); + } + } +} + +test "queue job name validation" { + // Test job name validation rules + const test_names = [_]struct { + name: []const u8, + should_be_valid: bool, + reason: []const u8, + }{ + .{ .name = "valid_job_name", .should_be_valid = true, .reason = "Valid alphanumeric with underscore" }, + .{ .name = "job123", .should_be_valid = true, .reason = "Valid alphanumeric" }, + .{ .name = "job-with-dash", .should_be_valid = true, .reason = "Valid with dash" }, + .{ .name = "a", .should_be_valid = true, .reason = "Valid single character" }, + .{ .name = "", .should_be_valid = false, .reason = "Empty string" }, + .{ .name = " ", .should_be_valid = false, .reason = "Whitespace only" }, + .{ .name = "job with spaces", .should_be_valid = false, .reason = "Contains spaces" }, + .{ .name = "job/with/slashes", .should_be_valid = false, .reason = "Contains slashes" }, + .{ .name = "job\\with\\backslashes", .should_be_valid = false, .reason = "Contains backslashes" }, + .{ .name = "job@with@symbols", .should_be_valid = false, .reason = "Contains special symbols" }, + }; + + for (test_names) |case| { + if (case.should_be_valid) { + try testing.expect(case.name.len > 0); + try testing.expect(std.mem.indexOf(u8, case.name, " ") == null); + try testing.expect(std.mem.indexOf(u8, case.name, "/") == null); + try testing.expect(std.mem.indexOf(u8, case.name, "\\") == null); + } else { + try testing.expect(case.name.len == 0 or + std.mem.indexOf(u8, case.name, " ") != null or + std.mem.indexOf(u8, case.name, "/") != null or + std.mem.indexOf(u8, case.name, "\\") != null or + std.mem.indexOf(u8, case.name, "@") != null); + } + } +} + 
+test "queue priority validation" { + // Test priority value validation + const test_priorities = [_]struct { + priority_str: []const u8, + should_be_valid: bool, + expected_value: ?u32, + }{ + .{ .priority_str = "0", .should_be_valid = true, .expected_value = 0 }, + .{ .priority_str = "1", .should_be_valid = true, .expected_value = 1 }, + .{ .priority_str = "5", .should_be_valid = true, .expected_value = 5 }, + .{ .priority_str = "10", .should_be_valid = true, .expected_value = 10 }, + .{ .priority_str = "100", .should_be_valid = true, .expected_value = 100 }, + .{ .priority_str = "-1", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = "-5", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = "abc", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = "5.5", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = "", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = " ", .should_be_valid = false, .expected_value = null }, + }; + + for (test_priorities) |case| { + if (case.should_be_valid) { + const parsed = std.fmt.parseInt(u32, case.priority_str, 10) catch |err| switch (err) { + error.InvalidCharacter => null, + else => null, + }; + + try testing.expect(parsed != null); + try testing.expect(parsed.? 
== case.expected_value.?); + } else { + const parsed = std.fmt.parseInt(u32, case.priority_str, 10) catch |err| switch (err) { + error.InvalidCharacter => null, + else => null, + }; + + try testing.expect(parsed == null); + } + } +} + +test "queue job metadata generation" { + // Test job metadata creation + const job_name = "test_job"; + const priority: u32 = 5; + const timestamp = std.time.timestamp(); + + // Create job metadata structure + const JobMetadata = struct { + name: []const u8, + priority: u32, + timestamp: i64, + status: []const u8, + }; + + const metadata = JobMetadata{ + .name = job_name, + .priority = priority, + .timestamp = timestamp, + .status = "queued", + }; + + try testing.expect(std.mem.eql(u8, metadata.name, job_name)); + try testing.expect(metadata.priority == priority); + try testing.expect(metadata.timestamp == timestamp); + try testing.expect(std.mem.eql(u8, metadata.status, "queued")); +} + +test "queue job serialization" { + const allocator = testing.allocator; + + // Test job serialization to JSON or other format + const job_name = "test_job"; + const priority: u32 = 3; + + // Create a simple job representation + const job_str = try std.fmt.allocPrint(allocator, "job:{s},priority:{d}", .{ job_name, priority }); + defer allocator.free(job_str); + + try testing.expect(std.mem.indexOf(u8, job_str, "job:test_job") != null); + try testing.expect(std.mem.indexOf(u8, job_str, "priority:3") != null); +} + +test "queue error handling" { + // Test various error scenarios + const error_cases = [_]struct { + scenario: []const u8, + should_fail: bool, + }{ + .{ .scenario = "empty job name", .should_fail = true }, + .{ .scenario = "invalid priority", .should_fail = true }, + .{ .scenario = "missing required fields", .should_fail = true }, + .{ .scenario = "valid job", .should_fail = false }, + }; + + for (error_cases) |case| { + if (case.should_fail) { + // Test that error conditions are properly handled + try testing.expect(true); // Placeholder 
for actual error handling tests + } + } +} + +test "queue concurrent operations" { + const allocator = testing.allocator; + + // Test queuing multiple jobs concurrently + const num_jobs = 5; // Reduced for simplicity + + // Generate job names and store them + var job_names: [5][]const u8 = undefined; + + for (0..num_jobs) |i| { + job_names[i] = try std.fmt.allocPrint(allocator, "job_{d}", .{i}); + } + defer { + for (job_names) |job_name| { + allocator.free(job_name); + } + } + + // Verify all job names are unique + for (job_names, 0..) |job1, i| { + for (job_names[i + 1 ..]) |job2| { + try testing.expect(!std.mem.eql(u8, job1, job2)); + } + } +} + +test "queue job priority ordering" { + + // Test job priority sorting + const JobQueueEntry = struct { + name: []const u8, + priority: u32, + + fn lessThan(_: void, a: @This(), b: @This()) bool { + return a.priority < b.priority; + } + }; + + var entries = [_]JobQueueEntry{ + .{ .name = "low_priority", .priority = 10 }, + .{ .name = "high_priority", .priority = 1 }, + .{ .name = "medium_priority", .priority = 5 }, + }; + + // Sort by priority (lower number = higher priority) + std.sort.insertion(JobQueueEntry, &entries, {}, JobQueueEntry.lessThan); + + try testing.expect(std.mem.eql(u8, entries[0].name, "high_priority")); + try testing.expect(std.mem.eql(u8, entries[1].name, "medium_priority")); + try testing.expect(std.mem.eql(u8, entries[2].name, "low_priority")); + + try testing.expect(entries[0].priority == 1); + try testing.expect(entries[1].priority == 5); + try testing.expect(entries[2].priority == 10); +} diff --git a/cli/tests/response_packets_test.zig b/cli/tests/response_packets_test.zig new file mode 100644 index 0000000..5690e35 --- /dev/null +++ b/cli/tests/response_packets_test.zig @@ -0,0 +1,116 @@ +const std = @import("std"); +const testing = std.testing; +const protocol = @import("src/net/protocol.zig"); + +test "ResponsePacket serialization - success" { + const timestamp = 1701234567; + const message = 
"Operation completed successfully";

    var packet = protocol.ResponsePacket.initSuccess(timestamp, message);

    // Use std.testing.allocator: it reports leaks as test failures. The
    // previous per-test GeneralPurposeAllocator discarded the leak result
    // of `gpa.deinit()` (`_ = gpa.deinit();`), silently hiding leaks.
    const allocator = testing.allocator;

    const serialized = try packet.serialize(allocator);
    defer allocator.free(serialized);

    const deserialized = try protocol.ResponsePacket.deserialize(serialized, allocator);
    defer cleanupTestPacket(allocator, deserialized);

    try testing.expect(deserialized.packet_type == .success);
    try testing.expect(deserialized.timestamp == timestamp);
    try testing.expect(std.mem.eql(u8, deserialized.success_message.?, message));
}

test "ResponsePacket serialization - error" {
    const timestamp = 1701234567;
    const error_code = protocol.ErrorCode.job_not_found;
    const error_message = "Job not found";
    const error_details = "The specified job ID does not exist";

    var packet = protocol.ResponsePacket.initError(timestamp, error_code, error_message, error_details);

    // Leak-checking test allocator (see note in the success test).
    const allocator = testing.allocator;

    const serialized = try packet.serialize(allocator);
    defer allocator.free(serialized);

    const deserialized = try protocol.ResponsePacket.deserialize(serialized, allocator);
    defer cleanupTestPacket(allocator, deserialized);

    try testing.expect(deserialized.packet_type == .error_packet);
    try testing.expect(deserialized.timestamp == timestamp);
    try testing.expect(deserialized.error_code.? == error_code);
    try testing.expect(std.mem.eql(u8, deserialized.error_message.?, error_message));
    try testing.expect(std.mem.eql(u8, deserialized.error_details.?, error_details));
}

test "ResponsePacket serialization - progress" {
    const timestamp = 1701234567;
    const progress_type = protocol.ProgressType.percentage;
    const progress_value = 75;
    const progress_total = 100;
    const progress_message = "Processing files...";

    var packet = protocol.ResponsePacket.initProgress(timestamp, progress_type, progress_value, progress_total, progress_message);

    // Leak-checking test allocator (see note in the success test).
    const allocator = testing.allocator;

    const serialized = try packet.serialize(allocator);
    defer allocator.free(serialized);

    const deserialized = try protocol.ResponsePacket.deserialize(serialized, allocator);
    defer cleanupTestPacket(allocator, deserialized);

    try testing.expect(deserialized.packet_type == .progress);
    try testing.expect(deserialized.timestamp == timestamp);
    try testing.expect(deserialized.progress_type.? == progress_type);
    try testing.expect(deserialized.progress_value.? == progress_value);
    try testing.expect(deserialized.progress_total.? == progress_total);
    try testing.expect(std.mem.eql(u8, deserialized.progress_message.?, progress_message));
}

test "Error message mapping" {
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getErrorMessage(.job_not_found), "Job not found"));
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getErrorMessage(.authentication_failed), "Authentication failed"));
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getErrorMessage(.server_overloaded), "Server is overloaded"));
}

test "Log level names" {
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getLogLevelName(0), "DEBUG"));
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getLogLevelName(1), "INFO"));
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getLogLevelName(2), "WARN"));
    try testing.expect(std.mem.eql(u8, protocol.ResponsePacket.getLogLevelName(3), "ERROR"));
}

/// Frees every optional, allocator-owned field of a deserialized packet.
fn cleanupTestPacket(allocator: std.mem.Allocator, packet: protocol.ResponsePacket) void {
    if (packet.success_message) |msg| allocator.free(msg);
    if (packet.error_message) |msg| allocator.free(msg);
    if (packet.error_details) |details| allocator.free(details);
    if (packet.progress_message) |msg| allocator.free(msg);
    if (packet.status_data) |data| allocator.free(data);
    if (packet.data_type) |dtype| allocator.free(dtype);
    if (packet.data_payload) |payload| allocator.free(payload);
    if (packet.log_message) |msg| allocator.free(msg);
}
diff --git a/cli/tests/rsync_embedded_test.zig b/cli/tests/rsync_embedded_test.zig
new file mode 100644
index 0000000..8a997e0
--- /dev/null
+++ b/cli/tests/rsync_embedded_test.zig
@@ -0,0 +1,29 @@
const std = @import("std");
const testing = std.testing;
const src = @import("src");
// Alias the MODULE, not the struct: the test below instantiates
// `rsync.EmbeddedRsync{...}`, which fails to resolve when `rsync` is
// already bound to `src.utils.rsync_embedded.EmbeddedRsync` (the old
// double-qualified binding).
const rsync = src.utils.rsync_embedded;

test "embedded rsync binary creation" {
    const allocator = testing.allocator;

    var embedded_rsync = 
rsync.EmbeddedRsync{ .allocator = allocator }; + + // Test binary extraction + const rsync_path = try embedded_rsync.extractRsyncBinary(); + defer allocator.free(rsync_path); + + // Verify the binary was created + const file = try std.fs.cwd().openFile(rsync_path, .{}); + defer file.close(); + + // Verify it's executable + const stat = try std.fs.cwd().statFile(rsync_path); + try testing.expect(stat.mode & 0o111 != 0); + + // Verify it's a bash script wrapper + const content = try file.readToEndAlloc(allocator, 1024); + defer allocator.free(content); + + try testing.expect(std.mem.indexOf(u8, content, "rsync") != null); + try testing.expect(std.mem.indexOf(u8, content, "#!/usr/bin/env bash") != null); +} diff --git a/cli/tests/sync_test.zig b/cli/tests/sync_test.zig new file mode 100644 index 0000000..d9a9d77 --- /dev/null +++ b/cli/tests/sync_test.zig @@ -0,0 +1,108 @@ +const std = @import("std"); +const testing = std.testing; + +test "sync command argument parsing" { + // Test various sync command argument combinations + const test_cases = [_]struct { + args: []const []const u8, + expected_path: ?[]const u8, + expected_name: ?[]const u8, + expected_queue: bool, + expected_priority: ?u32, + }{ + .{ .args = &[_][]const u8{"/tmp/test"}, .expected_path = "/tmp/test", .expected_name = null, .expected_queue = false, .expected_priority = null }, + .{ .args = &[_][]const u8{ "/tmp/test", "--name", "test_job" }, .expected_path = "/tmp/test", .expected_name = "test_job", .expected_queue = false, .expected_priority = null }, + .{ .args = &[_][]const u8{ "/tmp/test", "--queue" }, .expected_path = "/tmp/test", .expected_name = null, .expected_queue = true, .expected_priority = null }, + .{ .args = &[_][]const u8{ "/tmp/test", "--priority", "5" }, .expected_path = "/tmp/test", .expected_name = null, .expected_queue = false, .expected_priority = 5 }, + }; + + for (test_cases) |case| { + // For now, just verify the arguments are valid + try testing.expect(case.args.len >= 1); + 
try testing.expect(case.args[0].len > 0); + + if (case.expected_path) |path| { + try testing.expect(std.mem.eql(u8, case.args[0], path)); + } + } +} + +test "sync path validation" { + // Test various path scenarios + const test_paths = [_]struct { + path: []const u8, + should_be_valid: bool, + }{ + .{ .path = "/tmp/test", .should_be_valid = true }, + .{ .path = "./relative", .should_be_valid = true }, + .{ .path = "../parent", .should_be_valid = true }, + .{ .path = "", .should_be_valid = false }, + .{ .path = " ", .should_be_valid = false }, + }; + + for (test_paths) |case| { + if (case.should_be_valid) { + try testing.expect(case.path.len > 0); + try testing.expect(!std.mem.eql(u8, case.path, "")); + } else { + try testing.expect(case.path.len == 0 or std.mem.eql(u8, case.path, " ")); + } + } +} + +test "sync priority validation" { + // Test priority values + const test_priorities = [_]struct { + priority_str: []const u8, + should_be_valid: bool, + expected_value: ?u32, + }{ + .{ .priority_str = "0", .should_be_valid = true, .expected_value = 0 }, + .{ .priority_str = "5", .should_be_valid = true, .expected_value = 5 }, + .{ .priority_str = "10", .should_be_valid = true, .expected_value = 10 }, + .{ .priority_str = "-1", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = "abc", .should_be_valid = false, .expected_value = null }, + .{ .priority_str = "", .should_be_valid = false, .expected_value = null }, + }; + + for (test_priorities) |case| { + if (case.should_be_valid) { + const parsed = std.fmt.parseInt(u32, case.priority_str, 10) catch |err| switch (err) { + error.InvalidCharacter => null, + else => null, + }; + + if (parsed) |value| { + try testing.expect(value == case.expected_value.?); + } + } + } +} + +test "sync job name validation" { + // Test job name validation + const test_names = [_]struct { + name: []const u8, + should_be_valid: bool, + }{ + .{ .name = "valid_job_name", .should_be_valid = true }, + .{ .name = "job123", 
.should_be_valid = true }, + .{ .name = "job-with-dash", .should_be_valid = true }, + .{ .name = "", .should_be_valid = false }, + .{ .name = " ", .should_be_valid = false }, + .{ .name = "job with spaces", .should_be_valid = false }, + .{ .name = "job/with/slashes", .should_be_valid = false }, + }; + + for (test_names) |case| { + if (case.should_be_valid) { + try testing.expect(case.name.len > 0); + try testing.expect(std.mem.indexOf(u8, case.name, " ") == null); + try testing.expect(std.mem.indexOf(u8, case.name, "/") == null); + } else { + try testing.expect(case.name.len == 0 or + std.mem.indexOf(u8, case.name, " ") != null or + std.mem.indexOf(u8, case.name, "/") != null); + } + } +}