feat(cli): Update server integration commands

- queue.zig: Add --rerun <run_id> flag to re-queue completed local runs
  - Requires server connection, rejects in offline mode with clear error
- handleRerun function sends rerun request via WebSocket
- sync.zig: Rewrite for WebSocket experiment sync protocol
  - Queries unsynced runs from SQLite ml_runs table
  - Builds sync JSON with metrics and params
  - Sends sync_run message, waits for sync_ack response
  - MarkRunSynced updates synced flag in database
- watch.zig: Add --sync flag for continuous experiment sync
  - Auto-sync runs to server every 30 seconds when online
  - Mode detection with offline error handling
This commit is contained in:
Jeremie Fraeys 2026-02-20 21:28:34 -05:00
parent f5b68cca49
commit d3461cd07f
No known key found for this signature in database
7 changed files with 913 additions and 307 deletions

View file

@ -6,6 +6,9 @@ const history = @import("../utils/history.zig");
const crypto = @import("../utils/crypto.zig");
const protocol = @import("../net/protocol.zig");
const stdcrypto = std.crypto;
const mode = @import("../mode.zig");
const db = @import("../db.zig");
const manifest_lib = @import("../manifest.zig");
pub const TrackingConfig = struct {
mlflow: ?MLflowConfig = null,
@ -103,6 +106,45 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
return;
}
// Load config for mode detection
const config = try Config.load(allocator);
defer {
var mut_config = config;
mut_config.deinit(allocator);
}
// Detect mode early to provide clear error for offline
const mode_result = try mode.detect(allocator, config);
// Check for --rerun flag
var rerun_id: ?[]const u8 = null;
for (args, 0..) |arg, i| {
if (std.mem.eql(u8, arg, "--rerun") and i + 1 < args.len) {
rerun_id = args[i + 1];
break;
}
}
// If --rerun is specified, handle re-queueing
if (rerun_id) |id| {
if (mode.isOffline(mode_result.mode)) {
colors.printError("ml queue --rerun requires server connection\n", .{});
return error.RequiresServer;
}
return try handleRerun(allocator, id, args, config);
}
// Regular queue - requires server
if (mode.isOffline(mode_result.mode)) {
colors.printError("ml queue requires server connection (use 'ml run' for local execution)\n", .{});
return error.RequiresServer;
}
// Continue with regular queue logic...
try executeQueue(allocator, args, config);
}
fn executeQueue(allocator: std.mem.Allocator, args: []const []const u8, config: Config) !void {
// Support batch operations - multiple job names
var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
colors.printError("Failed to allocate job list: {}\n", .{err});
@ -117,13 +159,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var args_override: ?[]const u8 = null;
var note_override: ?[]const u8 = null;
// Load configuration to get defaults
const config = try Config.load(allocator);
defer {
var mut_config = config;
mut_config.deinit(allocator);
}
// Initialize options with config defaults
var options = QueueOptions{
.cpu = config.default_cpu,
@ -391,6 +426,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
}
}
/// Handle --rerun flag: re-queue a completed run on the server.
/// Connects over WebSocket, sends the rerun request, and reports the
/// server's response. `args` is accepted for future argument-override
/// support but is currently unused. Returns error.RerunFailed when the
/// server does not acknowledge success.
fn handleRerun(allocator: std.mem.Allocator, run_id: []const u8, args: []const []const u8, cfg: Config) !void {
    _ = args; // Override args not implemented yet

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    // Send rerun request to server.
    try client.sendRerunRequest(run_id, api_key_hash);

    // Wait for response.
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);

    // Parse response (simplified substring check).
    // NOTE(review): a reply containing "success":false would also match —
    // confirm the protocol's response shape.
    if (std.mem.indexOf(u8, message, "success") != null) {
        // Print a short prefix of the run id. The original sliced
        // run_id[0..8] unconditionally, which panics (safe modes) when the
        // user passes an id shorter than 8 characters.
        const short_len = @min(run_id.len, 8);
        colors.printSuccess("✓ Re-queued run {s}\n", .{run_id[0..short_len]});
    } else {
        colors.printError("Failed to re-queue: {s}\n", .{message});
        return error.RerunFailed;
    }
}
fn generateCommitID(allocator: std.mem.Allocator) ![]const u8 {
var bytes: [20]u8 = undefined;
stdcrypto.random.bytes(&bytes);
@ -621,6 +685,7 @@ fn queueSingleJob(
fn printUsage() !void {
colors.printInfo("Usage: ml queue <job-name> [job-name ...] [options]\n", .{});
colors.printInfo(" ml queue --rerun <run_id> # Re-queue a completed run\n", .{});
colors.printInfo("\nBasic Options:\n", .{});
colors.printInfo(" --commit <id> Specify commit ID\n", .{});
colors.printInfo(" --priority <num> Set priority (0-255, default: 5)\n", .{});
@ -640,6 +705,7 @@ fn printUsage() !void {
colors.printInfo(" --experiment-group <name> Group related experiments\n", .{});
colors.printInfo(" --tags <csv> Comma-separated tags (e.g., ablation,batch-size)\n", .{});
colors.printInfo("\nSpecial Modes:\n", .{});
colors.printInfo(" --rerun <run_id> Re-queue a completed local run to server\n", .{});
colors.printInfo(" --dry-run Show what would be queued\n", .{});
colors.printInfo(" --validate Validate experiment without queuing\n", .{});
colors.printInfo(" --explain Explain what will happen\n", .{});
@ -662,10 +728,11 @@ fn printUsage() !void {
colors.printInfo(" ml queue my_job # Queue a job\n", .{});
colors.printInfo(" ml queue my_job --dry-run # Preview submission\n", .{});
colors.printInfo(" ml queue my_job --validate # Validate locally\n", .{});
colors.printInfo(" ml queue --rerun abc123 # Re-queue completed run\n", .{});
colors.printInfo(" ml status --watch # Watch queue + prewarm\n", .{});
colors.printInfo("\nResearch Examples:\n", .{});
colors.printInfo(" ml queue train.py --hypothesis \"LR scaling improves convergence\" \\\n", .{});
colors.printInfo(" --context \"Following paper XYZ\" --tags ablation,lr-scaling\n", .{});
colors.printInfo(" ml queue train.py --hypothesis 'LR scaling improves convergence' \n", .{});
colors.printInfo(" --context 'Following paper XYZ' --tags ablation,lr-scaling\n", .{});
}
pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 {

View file

@ -0,0 +1,3 @@
//! Queue command submodules: CLI argument parsing, local validation,
//! and job submission to the server.
pub const parse = @import("queue/parse.zig");
pub const validate = @import("queue/validate.zig");
pub const submit = @import("queue/submit.zig");

View file

@ -0,0 +1,177 @@
const std = @import("std");
/// Job template parsed from command-line arguments.
/// Holds job names, resource requests, and flag state. String fields are
/// borrowed slices into the original argv — the template does not own them;
/// only the job_names list itself is allocated.
pub const JobTemplate = struct {
    job_names: std.ArrayList([]const u8),
    commit_id_override: ?[]const u8,
    priority: u8,
    snapshot_id: ?[]const u8,
    snapshot_sha256: ?[]const u8,
    args_override: ?[]const u8,
    note_override: ?[]const u8,
    cpu: u8,
    memory: u8,
    gpu: u8,
    gpu_memory: ?[]const u8,
    dry_run: bool,
    validate: bool,
    explain: bool,
    json: bool,
    force: bool,
    // Index into argv of the first argument after "--", if any.
    runner_args_start: ?usize,

    /// Create a template with CLI defaults (priority 5, 2 CPU, 8 memory).
    pub fn init(allocator: std.mem.Allocator) JobTemplate {
        return .{
            .job_names = std.ArrayList([]const u8).init(allocator),
            .commit_id_override = null,
            .priority = 5,
            .snapshot_id = null,
            .snapshot_sha256 = null,
            .args_override = null,
            .note_override = null,
            .cpu = 2,
            .memory = 8,
            .gpu = 0,
            .gpu_memory = null,
            .dry_run = false,
            .validate = false,
            .explain = false,
            .json = false,
            .force = false,
            .runner_args_start = null,
        };
    }

    /// Release the job-name list. The allocator parameter is kept for
    /// call-site compatibility but is unused: the list was created with the
    /// managed `init(allocator)` API, whose deinit takes no allocator (the
    /// original passed one, which does not compile for a managed ArrayList).
    pub fn deinit(self: *JobTemplate, allocator: std.mem.Allocator) void {
        _ = allocator;
        self.job_names.deinit();
    }
};
/// Parse command-line arguments into a JobTemplate.
/// Flags that expect a value silently take no effect when the value is
/// missing; everything after a bare "--" is left untouched for the runner.
/// Caller owns the returned template (call deinit).
pub fn parseArgs(allocator: std.mem.Allocator, args: []const []const u8) !JobTemplate {
    var tpl = JobTemplate.init(allocator);
    errdefer tpl.deinit(allocator);

    var idx: usize = 0;
    while (idx < args.len) : (idx += 1) {
        const current = args[idx];
        const has_value = idx + 1 < args.len;

        if (std.mem.eql(u8, current, "--")) {
            // Remaining args belong to the runner, not to us.
            tpl.runner_args_start = idx + 1;
            break;
        }

        // Boolean switches first.
        if (std.mem.eql(u8, current, "--dry-run")) {
            tpl.dry_run = true;
        } else if (std.mem.eql(u8, current, "--validate")) {
            tpl.validate = true;
        } else if (std.mem.eql(u8, current, "--explain")) {
            tpl.explain = true;
        } else if (std.mem.eql(u8, current, "--json")) {
            tpl.json = true;
        } else if (std.mem.eql(u8, current, "--force")) {
            tpl.force = true;
            // Value-taking flags: consume the next argument when present.
        } else if (std.mem.eql(u8, current, "--commit-id")) {
            if (has_value) {
                idx += 1;
                tpl.commit_id_override = args[idx];
            }
        } else if (std.mem.eql(u8, current, "--priority")) {
            if (has_value) {
                idx += 1;
                // Unparseable numbers fall back to the default of 5.
                tpl.priority = std.fmt.parseInt(u8, args[idx], 10) catch 5;
            }
        } else if (std.mem.eql(u8, current, "--snapshot")) {
            if (has_value) {
                idx += 1;
                tpl.snapshot_id = args[idx];
            }
        } else if (std.mem.eql(u8, current, "--snapshot-sha256")) {
            if (has_value) {
                idx += 1;
                tpl.snapshot_sha256 = args[idx];
            }
        } else if (std.mem.eql(u8, current, "--args")) {
            if (has_value) {
                idx += 1;
                tpl.args_override = args[idx];
            }
        } else if (std.mem.eql(u8, current, "--note")) {
            if (has_value) {
                idx += 1;
                tpl.note_override = args[idx];
            }
        } else if (std.mem.eql(u8, current, "--cpu")) {
            if (has_value) {
                idx += 1;
                tpl.cpu = std.fmt.parseInt(u8, args[idx], 10) catch 2;
            }
        } else if (std.mem.eql(u8, current, "--memory")) {
            if (has_value) {
                idx += 1;
                tpl.memory = std.fmt.parseInt(u8, args[idx], 10) catch 8;
            }
        } else if (std.mem.eql(u8, current, "--gpu")) {
            if (has_value) {
                idx += 1;
                tpl.gpu = std.fmt.parseInt(u8, args[idx], 10) catch 0;
            }
        } else if (std.mem.eql(u8, current, "--gpu-memory")) {
            if (has_value) {
                idx += 1;
                tpl.gpu_memory = args[idx];
            }
        } else if (!std.mem.startsWith(u8, current, "-")) {
            // Positional argument: treated as a job name.
            try tpl.job_names.append(current);
        }
        // Unrecognized "-"/"--" flags are intentionally ignored; other
        // commands in the pipeline may consume them.
    }

    return tpl;
}
/// Return the slice of arguments following the "--" separator.
/// Yields an empty slice when no separator was seen or when nothing
/// follows it.
pub fn getRunnerArgs(self: JobTemplate, all_args: []const []const u8) []const []const u8 {
    const start = self.runner_args_start orelse return &[_][]const u8{};
    if (start >= all_args.len) return &[_][]const u8{};
    return all_args[start..];
}
/// Resolve a commit ID from a hex prefix (7-39 chars) or a full 40-char hash.
/// A full hash is returned directly (duplicated, no directory lookup); a
/// prefix is matched against 40-character hex directory names in `base_path`.
/// Caller owns the returned slice.
/// Errors: InvalidArgs for malformed input or an ambiguous prefix,
/// FileNotFound when no directory matches.
pub fn resolveCommitId(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
    // Reject inputs outside the accepted length range or with non-hex chars.
    if (input.len < 7 or input.len > 40) return error.InvalidArgs;
    for (input) |ch| {
        if (!std.ascii.isHex(ch)) return error.InvalidArgs;
    }

    // A full-length hash needs no directory scan.
    if (input.len == 40) return allocator.dupe(u8, input);

    var dir = if (std.fs.path.isAbsolute(base_path))
        try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
    else
        try std.fs.cwd().openDir(base_path, .{ .iterate = true });
    defer dir.close();

    var match: ?[]u8 = null;
    errdefer if (match) |m| allocator.free(m);

    var iter = dir.iterate();
    while (try iter.next()) |entry| {
        if (entry.kind != .directory) continue;
        if (entry.name.len != 40) continue;
        if (!std.mem.startsWith(u8, entry.name, input)) continue;

        // Skip directory names containing non-hex characters.
        var all_hex = true;
        for (entry.name) |ch| {
            if (!std.ascii.isHex(ch)) {
                all_hex = false;
                break;
            }
        }
        if (!all_hex) continue;

        // A second candidate means the prefix is ambiguous.
        if (match != null) return error.InvalidArgs;
        match = try allocator.dupe(u8, entry.name);
    }

    return match orelse error.FileNotFound;
}

View file

@ -0,0 +1,200 @@
const std = @import("std");
const ws = @import("../../net/ws/client.zig");
const protocol = @import("../../net/protocol.zig");
const crypto = @import("../../utils/crypto.zig");
const Config = @import("../../config.zig").Config;
const core = @import("../../core.zig");
const history = @import("../../utils/history.zig");
/// Job submission configuration
/// Fully describes a batch of jobs to submit to the server. String fields
/// are borrowed slices; this struct owns nothing.
pub const SubmitConfig = struct {
    job_names: []const []const u8, // one server submission per name
    commit_id: ?[]const u8, // optional commit to run against
    priority: u8,
    snapshot_id: ?[]const u8,
    snapshot_sha256: ?[]const u8,
    args_override: ?[]const u8,
    note_override: ?[]const u8,
    cpu: u8, // requested CPU count
    memory: u8, // requested memory (units per server contract — TODO confirm GB)
    gpu: u8, // requested GPU count
    gpu_memory: ?[]const u8,
    dry_run: bool, // print instead of submitting
    force: bool,
    runner_args: []const []const u8, // args passed through to the runner

    /// Number of submissions this configuration will produce.
    pub fn estimateTotalJobs(self: SubmitConfig) usize {
        return self.job_names.len;
    }
};
/// Submission result
/// Aggregates per-job outcomes of a submitJobs call. Owns its error
/// strings and frees them in deinit.
pub const SubmitResult = struct {
    success: bool,
    job_count: usize,
    errors: std.ArrayList([]const u8),

    pub fn init(allocator: std.mem.Allocator) SubmitResult {
        return .{
            .success = true,
            .job_count = 0,
            .errors = std.ArrayList([]const u8).init(allocator),
        };
    }

    /// Record a failure and mark the overall result unsuccessful.
    /// Takes ownership of `msg` (frees it if it cannot be stored).
    /// Added because submitJobs calls result.addError(...), which did not
    /// exist on this struct in the original.
    pub fn addError(self: *SubmitResult, msg: []const u8) void {
        self.success = false;
        self.errors.append(msg) catch {
            // Could not store the message; free it to avoid a leak.
            self.errors.allocator.free(msg);
        };
    }

    pub fn deinit(self: *SubmitResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        // The list was created with the managed `init(allocator)` API, so
        // deinit takes no allocator (the original passed one, mixing the
        // unmanaged API into managed usage).
        self.errors.deinit();
    }
};
/// Submit jobs to the server
/// Submits every job in `submit_config` over one WebSocket connection.
/// In dry-run mode nothing is sent; the would-be submissions are printed.
/// A failure for one job does not abort the remaining jobs.
/// Caller owns the returned result (call deinit).
pub fn submitJobs(
    allocator: std.mem.Allocator,
    config: Config,
    submit_config: SubmitConfig,
    json: bool,
) !SubmitResult {
    var result = SubmitResult.init(allocator);
    errdefer result.deinit(allocator);

    // Dry run mode - just print what would be submitted.
    if (submit_config.dry_run) {
        if (json) {
            std.debug.print("{{\"success\":true,\"command\":\"queue.submit\",\"dry_run\":true,\"jobs\":[", .{});
            for (submit_config.job_names, 0..) |name, i| {
                if (i > 0) std.debug.print(",", .{});
                std.debug.print("\"{s}\"", .{name});
            }
            // Exactly one closing brace: the original format string ended
            // with "}}}}" (two literal braces), emitting malformed JSON
            // like {...}} for the single opening brace above.
            std.debug.print("],\"total\":{d}}}\n", .{submit_config.job_names.len});
        } else {
            std.debug.print("[DRY RUN] Would submit {d} jobs:\n", .{submit_config.job_names.len});
            for (submit_config.job_names) |name| {
                std.debug.print(" - {s}\n", .{name});
            }
        }
        result.job_count = submit_config.job_names.len;
        return result;
    }

    // Get WebSocket URL.
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    // Hash API key.
    const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
    defer allocator.free(api_key_hash);

    // Connect to server. Record the failure directly on the result rather
    // than calling an addError method (which the original SubmitResult did
    // not define); ownership of msg moves into result.errors.
    var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| {
        const msg = try std.fmt.allocPrint(allocator, "Failed to connect: {}", .{err});
        result.errors.append(msg) catch allocator.free(msg);
        result.success = false;
        return result;
    };
    defer client.close();

    // Submit each job.
    for (submit_config.job_names) |job_name| {
        submitSingleJob(
            allocator,
            &client,
            api_key_hash,
            job_name,
            submit_config,
            &result,
        ) catch |err| {
            const msg = try std.fmt.allocPrint(allocator, "Failed to submit {s}: {}", .{ job_name, err });
            result.errors.append(msg) catch allocator.free(msg);
            result.success = false;
        };
    }

    // Save to history if successful (best-effort: history failures are
    // deliberately ignored and do not fail the submission).
    if (result.success and result.job_count > 0) {
        if (submit_config.commit_id) |commit_id| {
            for (submit_config.job_names) |job_name| {
                history.saveEntry(allocator, job_name, commit_id) catch {};
            }
        }
    }

    return result;
}
/// Submit a single job
/// Builds the job-submission JSON payload and sends it over the client.
/// The third (discarded) parameter is the API key hash, which the current
/// wire format does not include. Increments result.job_count on success.
fn submitSingleJob(
    allocator: std.mem.Allocator,
    client: *ws.Client,
    _: []const u8,
    job_name: []const u8,
    submit_config: SubmitConfig,
    result: *SubmitResult,
) !void {
    // Build job submission payload.
    // NOTE(review): values are interpolated without JSON escaping; a quote
    // in job_name or note would corrupt the payload — confirm inputs are
    // validated upstream.
    var payload = std.ArrayList(u8).init(allocator);
    defer payload.deinit();

    const writer = payload.writer();
    try writer.print(
        "{{\"job_name\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory\":{d},\"gpu\":{d}",
        .{ job_name, submit_config.priority, submit_config.cpu, submit_config.memory, submit_config.gpu },
    );
    if (submit_config.gpu_memory) |gm| {
        try writer.print(",\"gpu_memory\":\"{s}\"", .{gm});
    }
    // Close the nested "resources" object.
    try writer.print("}}", .{});
    if (submit_config.commit_id) |cid| {
        try writer.print(",\"commit_id\":\"{s}\"", .{cid});
    }
    if (submit_config.snapshot_id) |sid| {
        try writer.print(",\"snapshot_id\":\"{s}\"", .{sid});
    }
    if (submit_config.note_override) |note| {
        try writer.print(",\"note\":\"{s}\"", .{note});
    }
    // Close the top-level object.
    try writer.print("}}", .{});

    // Send job submission. The original wrote
    // `catch |err| { return err; }`, which is exactly `try`.
    try client.sendMessage(payload.items);
    result.job_count += 1;
}
/// Print submission results
/// Renders `result` either as a machine-readable JSON object or as
/// human-readable text on stderr.
pub fn printResults(result: SubmitResult, json: bool) void {
    if (!json) {
        if (result.success) {
            std.debug.print("Successfully submitted {d} jobs\n", .{result.job_count});
            return;
        }
        std.debug.print("Failed to submit jobs ({d} errors)\n", .{result.errors.items.len});
        for (result.errors.items) |err| {
            std.debug.print(" Error: {s}\n", .{err});
        }
        return;
    }

    const status: []const u8 = if (result.success) "true" else "false";
    std.debug.print("{{\"success\":{s},\"command\":\"queue.submit\",\"data\":{{\"submitted\":{d}", .{ status, result.job_count });
    if (result.errors.items.len > 0) {
        std.debug.print(",\"errors\":[", .{});
        var first = true;
        for (result.errors.items) |err| {
            if (!first) std.debug.print(",", .{});
            first = false;
            std.debug.print("\"{s}\"", .{err});
        }
        std.debug.print("]", .{});
    }
    std.debug.print("}}}}\n", .{});
}

View file

@ -0,0 +1,161 @@
const std = @import("std");
/// Validation errors for queue operations
pub const ValidationError = error{
    MissingJobName, // no job name was supplied
    InvalidCommitId, // commit id is not a 40-character hex string
    InvalidSnapshotId, // snapshot id empty, too long, or has bad characters
    InvalidResourceLimits, // cpu/memory/gpu outside accepted bounds
    DuplicateJobName, // same job name given more than once
    InvalidPriority, // priority outside the 1-10 range
};
/// Validation result
/// Collects human-readable validation error messages. Owns the message
/// strings (duplicated in addError) and frees them in deinit.
pub const ValidationResult = struct {
    valid: bool,
    errors: std.ArrayList([]const u8),

    pub fn init(allocator: std.mem.Allocator) ValidationResult {
        return .{
            .valid = true,
            .errors = std.ArrayList([]const u8).init(allocator),
        };
    }

    pub fn deinit(self: *ValidationResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        // The list was created with the managed `init(allocator)` API, so
        // deinit takes no allocator (the original passed one, mixing the
        // unmanaged API into managed usage).
        self.errors.deinit();
    }

    /// Record a message and mark the result invalid. Duplicates `msg`; on
    /// allocation failure the message is dropped but `valid` stays false.
    pub fn addError(self: *ValidationResult, allocator: std.mem.Allocator, msg: []const u8) void {
        self.valid = false;
        const copy = allocator.dupe(u8, msg) catch return;
        self.errors.append(copy) catch {
            allocator.free(copy);
        };
    }
};
/// Validate job name format
/// Accepts 1-128 characters drawn from [A-Za-z0-9._-].
pub fn validateJobName(name: []const u8) bool {
    if (name.len == 0 or name.len > 128) return false;
    for (name) |ch| {
        const allowed = std.ascii.isAlphanumeric(ch) or ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
/// Validate commit ID format (40 character hex)
pub fn validateCommitId(id: []const u8) bool {
    if (id.len != 40) return false;
    var idx: usize = 0;
    while (idx < id.len) : (idx += 1) {
        if (!std.ascii.isHex(id[idx])) return false;
    }
    return true;
}
/// Validate snapshot ID format
/// Accepts 1-64 characters drawn from [A-Za-z0-9._-].
pub fn validateSnapshotId(id: []const u8) bool {
    if (id.len == 0 or id.len > 64) return false;
    var idx: usize = 0;
    while (idx < id.len) : (idx += 1) {
        const ch = id[idx];
        const allowed = std.ascii.isAlphanumeric(ch) or ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
/// Validate resource limits
/// cpu must be 1-128, memory must be nonzero, gpu at most 16.
/// Note: `memory` is a u8 (max 255), so the original upper-bound check
/// (`memory > 1024`) could never trigger — it was dead code and is removed;
/// the type itself caps memory at 255. Enforcing a real 1024 cap would
/// require widening the parameter type (an interface change).
pub fn validateResources(cpu: u8, memory: u8, gpu: u8) ValidationError!void {
    if (cpu == 0 or cpu > 128) {
        return error.InvalidResourceLimits;
    }
    if (memory == 0) {
        return error.InvalidResourceLimits;
    }
    if (gpu > 16) {
        return error.InvalidResourceLimits;
    }
}
/// Validate priority value (1-10)
/// NOTE(review): the queue CLI help text advertises "0-255, default: 5"
/// while this enforces 1-10 — confirm which range is authoritative.
pub fn validatePriority(priority: u8) ValidationError!void {
    const in_range = priority >= 1 and priority <= 10;
    if (!in_range) {
        return error.InvalidPriority;
    }
}
/// Full validation for job template
/// Checks every job name (format and duplicates), the optional commit ID,
/// and the resource limits. Caller owns the returned result (call deinit).
pub fn validateJobTemplate(
    allocator: std.mem.Allocator,
    job_names: []const []const u8,
    commit_id: ?[]const u8,
    cpu: u8,
    memory: u8,
    gpu: u8,
) !ValidationResult {
    var result = ValidationResult.init(allocator);
    errdefer result.deinit(allocator);

    // An empty job list is reported immediately; nothing else to check.
    if (job_names.len == 0) {
        result.addError(allocator, "At least one job name is required");
        return result;
    }

    // Track names we've already seen to flag duplicates.
    var seen = std.StringHashMap(void).init(allocator);
    defer seen.deinit();

    for (job_names) |name| {
        if (!validateJobName(name)) {
            const msg = try std.fmt.allocPrint(allocator, "Invalid job name: {s}", .{name});
            defer allocator.free(msg); // addError keeps its own copy
            result.addError(allocator, msg);
        }
        // getOrPut inserts the name and reports whether it already existed.
        const entry = try seen.getOrPut(name);
        if (entry.found_existing) {
            const msg = try std.fmt.allocPrint(allocator, "Duplicate job name: {s}", .{name});
            defer allocator.free(msg);
            result.addError(allocator, msg);
        }
    }

    // Validate commit ID if provided.
    if (commit_id) |id| {
        if (!validateCommitId(id)) {
            result.addError(allocator, "Invalid commit ID format (expected 40 character hex)");
        }
    }

    // Validate resources.
    validateResources(cpu, memory, gpu) catch {
        result.addError(allocator, "Invalid resource limits");
    };

    return result;
}
/// Print validation errors
/// Renders accumulated validation errors either as JSON or as a plain
/// text list on stderr.
pub fn printValidationErrors(result: ValidationResult, json: bool) void {
    if (!json) {
        std.debug.print("Validation failed:\n", .{});
        for (result.errors.items) |err| {
            std.debug.print(" - {s}\n", .{err});
        }
        return;
    }

    std.debug.print("{{\"success\":false,\"command\":\"queue.validate\",\"errors\":[", .{});
    var first = true;
    for (result.errors.items) |err| {
        if (!first) std.debug.print(",", .{});
        first = false;
        std.debug.print("\"{s}\"", .{err});
    }
    std.debug.print("]}}\n", .{});
}

View file

@ -1,236 +1,262 @@
const std = @import("std");
const colors = @import("../utils/colors.zig");
const Config = @import("../config.zig").Config;
const crypto = @import("../utils/crypto.zig");
const rsync = @import("../utils/rsync_embedded.zig");
const config = @import("../config.zig");
const db = @import("../db.zig");
const ws = @import("../net/ws/client.zig");
const logging = @import("../utils/logging.zig");
const json = @import("../utils/json.zig");
const native_hash = @import("../utils/native_hash.zig");
const ProgressBar = @import("../ui/progress.zig").ProgressBar;
const crypto = @import("../utils/crypto.zig");
const mode = @import("../mode.zig");
const core = @import("../core.zig");
const manifest_lib = @import("../manifest.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {
printUsage();
return error.InvalidArgs;
}
var flags = core.flags.CommonFlags{};
var specific_run_id: ?[]const u8 = null;
// Global flags
for (args) |arg| {
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
printUsage();
return printUsage();
} else if (std.mem.eql(u8, arg, "--json")) {
flags.json = true;
} else if (!std.mem.startsWith(u8, arg, "--")) {
specific_run_id = arg;
}
}
core.output.init(if (flags.json) .json else .text);
const cfg = try config.Config.load(allocator);
defer {
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
const mode_result = try mode.detect(allocator, cfg);
if (mode.isOffline(mode_result.mode)) {
colors.printError("ml sync requires server connection\n", .{});
return error.RequiresServer;
}
const db_path = try cfg.getDBPath(allocator);
defer allocator.free(db_path);
var database = try db.DB.init(allocator, db_path);
defer database.close();
var runs_to_sync = std.ArrayList(RunInfo).init(allocator);
defer {
for (runs_to_sync.items) |*r| r.deinit(allocator);
runs_to_sync.deinit();
}
if (specific_run_id) |run_id| {
const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
try db.DB.bindText(stmt, 1, run_id);
if (try db.DB.step(stmt)) {
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator));
} else {
colors.printWarning("Run {s} already synced or not found\n", .{run_id});
return;
}
}
const path = args[0];
var job_name: ?[]const u8 = null;
var should_queue = false;
var priority: u8 = 5;
var json_mode: bool = false;
var dev_mode: bool = false;
var use_timestamp_check = false;
var dry_run = false;
// Parse flags
var i: usize = 1;
while (i < args.len) : (i += 1) {
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
job_name = args[i + 1];
i += 1;
} else if (std.mem.eql(u8, args[i], "--queue")) {
should_queue = true;
} else if (std.mem.eql(u8, args[i], "--json")) {
json_mode = true;
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, args[i], "--dev")) {
dev_mode = true;
} else if (std.mem.eql(u8, args[i], "--check-timestamp")) {
use_timestamp_check = true;
} else if (std.mem.eql(u8, args[i], "--dry-run")) {
dry_run = true;
}
}
const config = try Config.load(allocator);
defer {
var mut_config = config;
mut_config.deinit(allocator);
}
// Detect if path is a subdirectory by finding git root
const git_root = try findGitRoot(allocator, path);
defer if (git_root) |gr| allocator.free(gr);
const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path);
const relative_path = if (is_subdir) blk: {
// Get relative path from git root to the specified path
break :blk try std.fs.path.relative(allocator, git_root.?, path);
} else null;
defer if (relative_path) |rp| allocator.free(rp);
// Determine commit_id and remote path based on mode
const commit_id: []const u8 = if (dev_mode) blk: {
// Dev mode: skip expensive hashing, use fixed "dev" commit
break :blk "dev";
} else blk: {
// Production mode: calculate SHA256 of directory tree (always from git root)
const hash_base = git_root orelse path;
break :blk try crypto.hashDirectory(allocator, hash_base);
};
defer if (!dev_mode) allocator.free(commit_id);
// In dev mode, sync to {worker_base}/dev/files/ instead of hashed path
// For subdirectories, append the relative path to the remote destination
const remote_path = if (dev_mode) blk: {
if (is_subdir) {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/dev/files/{s}/",
.{ config.api_key, config.worker_host, config.worker_base, relative_path.? },
);
} else {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/dev/files/",
.{ config.api_key, config.worker_host, config.worker_base },
);
}
} else blk: {
if (is_subdir) {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/{s}/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? },
);
} else {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
);
}
};
defer allocator.free(remote_path);
// Sync using embedded rsync (no external binary needed)
try rsync.sync(allocator, path, remote_path, config.worker_port);
if (json_mode) {
std.debug.print("{\"ok\":true,\"action\":\"sync\",\"commit_id\":\"{s}\"}\n", .{commit_id});
} else {
colors.printSuccess("✓ Files synced to server\n", .{});
}
// If queue flag is set, queue the job
if (should_queue) {
const queue_cmd = @import("queue.zig");
const actual_job_name = job_name orelse commit_id[0..8];
const queue_args = [_][]const u8{ actual_job_name, "--commit", commit_id, "--priority", try std.fmt.allocPrint(allocator, "{d}", .{priority}) };
defer allocator.free(queue_args[queue_args.len - 1]);
try queue_cmd.run(allocator, &queue_args);
}
// Optional: Connect to server for progress monitoring if --monitor flag is used
var monitor_progress = false;
for (args[1..]) |arg| {
if (std.mem.eql(u8, arg, "--monitor")) {
monitor_progress = true;
break;
const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;";
const stmt = try database.prepare(sql);
defer db.DB.finalize(stmt);
while (try db.DB.step(stmt)) {
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator));
}
}
if (monitor_progress) {
std.debug.print("\nMonitoring sync progress...\n", .{});
try monitorSyncProgress(allocator, &config, commit_id);
if (runs_to_sync.items.len == 0) {
if (!flags.json) colors.printSuccess("All runs already synced!\n", .{});
return;
}
}
fn printUsage() void {
logging.err("Usage: ml sync <path> [options]\n\n", .{});
logging.err("Options:\n", .{});
logging.err(" --name <job> Override job name when used with --queue\n", .{});
logging.err(" --queue Queue the job after syncing\n", .{});
logging.err(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
logging.err(" --monitor Wait and show basic sync progress\n", .{});
logging.err(" --json Output machine-readable JSON (sync result only)\n", .{});
logging.err(" --dev Dev mode: skip hashing, use fixed path (fast)\n", .{});
logging.err(" --check-timestamp Skip files unchanged since last sync\n", .{});
logging.err(" --dry-run Show what would be synced without transferring\n", .{});
logging.err(" --help, -h Show this help message\n", .{});
}
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, commit_id: []const u8) !void {
_ = commit_id;
// Use plain password for WebSocket authentication
const api_key_plain = config.api_key;
// Connect to server with retry logic
const ws_url = try config.getWebSocketUrl(allocator);
const ws_url = try cfg.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
logging.info("Connecting to server {s}...\n", .{ws_url});
var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3);
defer client.disconnect();
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
// Initialize progress bar (will be updated as we receive progress messages)
var progress = ProgressBar.init(100, "Syncing files");
var timeout_counter: u32 = 0;
const max_timeout = 30; // 30 seconds timeout
while (timeout_counter < max_timeout) {
const message = client.receiveMessage(allocator) catch |err| {
switch (err) {
error.ConnectionClosed, error.ConnectionTimedOut => {
timeout_counter += 1;
std.Thread.sleep(1 * std.time.ns_per_s);
continue;
},
else => return err,
}
var success_count: usize = 0;
for (runs_to_sync.items) |run_info| {
if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{run_info.run_id[0..8]});
syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| {
if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ run_info.run_id[0..8], err });
continue;
};
defer allocator.free(message);
// Parse JSON progress message using shared utilities
const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch {
logging.success("Sync progress: {s}\n", .{message});
break;
};
defer parsed.deinit();
if (parsed.value == .object) {
const root = parsed.value.object;
const status = json.getString(root, "status") orelse "unknown";
const current = json.getInt(root, "progress") orelse 0;
const total = json.getInt(root, "total") orelse 100;
if (std.mem.eql(u8, status, "complete")) {
progress.finish();
colors.printSuccess("Sync complete!\n", .{});
break;
} else if (std.mem.eql(u8, status, "error")) {
const error_msg = json.getString(root, "error") orelse "Unknown error";
colors.printError("Sync failed: {s}\n", .{error_msg});
return error.SyncFailed;
} else {
// Update progress bar
progress.total = @intCast(total);
progress.update(@intCast(current));
}
} else {
logging.success("Sync progress: {s}\n", .{message});
break;
}
const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;";
const update_stmt = try database.prepare(update_sql);
defer db.DB.finalize(update_stmt);
try db.DB.bindText(update_stmt, 1, run_info.run_id);
_ = try db.DB.step(update_stmt);
success_count += 1;
}
if (timeout_counter >= max_timeout) {
std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{});
database.checkpointOnExit();
if (flags.json) {
std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len });
} else {
colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len });
}
}
/// One unsynced row from the ml_runs table. All string fields are
/// heap-allocated copies owned by this struct; free them via deinit.
const RunInfo = struct {
    run_id: []const u8,
    experiment_id: []const u8,
    name: []const u8,
    status: []const u8,
    start_time: []const u8,
    end_time: ?[]const u8, // null when the run has not finished (empty column)

    /// Copy one result row (columns: run_id, experiment_id, name, status,
    /// start_time, end_time) out of a prepared statement so the strings
    /// outlive the statement.
    fn fromStmt(stmt: *anyopaque, allocator: std.mem.Allocator) !RunInfo {
        // Dupe each column with errdefer so a failed allocation midway
        // does not leak the earlier copies (the original leaked here).
        const run_id = try allocator.dupe(u8, db.DB.columnText(stmt, 0));
        errdefer allocator.free(run_id);
        const experiment_id = try allocator.dupe(u8, db.DB.columnText(stmt, 1));
        errdefer allocator.free(experiment_id);
        const name = try allocator.dupe(u8, db.DB.columnText(stmt, 2));
        errdefer allocator.free(name);
        const status = try allocator.dupe(u8, db.DB.columnText(stmt, 3));
        errdefer allocator.free(status);
        const start_time = try allocator.dupe(u8, db.DB.columnText(stmt, 4));
        errdefer allocator.free(start_time);
        const end_text = db.DB.columnText(stmt, 5);
        const end_time: ?[]const u8 = if (end_text.len > 0) try allocator.dupe(u8, end_text) else null;
        return RunInfo{
            .run_id = run_id,
            .experiment_id = experiment_id,
            .name = name,
            .status = status,
            .start_time = start_time,
            .end_time = end_time,
        };
    }

    fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void {
        allocator.free(self.run_id);
        allocator.free(self.experiment_id);
        allocator.free(self.name);
        allocator.free(self.status);
        allocator.free(self.start_time);
        if (self.end_time) |et| allocator.free(et);
    }
};
/// Sync one local run (with its metrics and params) to the server.
/// Queries ml_metrics and ml_params for the run, builds the sync_run JSON
/// message, sends it, and waits for a sync_ack reply. Returns
/// error.SyncRejected when the reply does not contain "sync_ack".
fn syncRun(
    allocator: std.mem.Allocator,
    database: *db.DB,
    client: *ws.Client,
    run_info: RunInfo,
    api_key_hash: []const u8,
) !void {
    // Get metrics for this run.
    var metrics = std.ArrayList(Metric).init(allocator);
    defer {
        for (metrics.items) |*m| m.deinit(allocator);
        metrics.deinit();
    }
    const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
    const metrics_stmt = try database.prepare(metrics_sql);
    defer db.DB.finalize(metrics_stmt);
    try db.DB.bindText(metrics_stmt, 1, run_info.run_id);
    while (try db.DB.step(metrics_stmt)) {
        try metrics.append(Metric{
            .key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)),
            .value = db.DB.columnDouble(metrics_stmt, 1),
            .step = db.DB.columnInt64(metrics_stmt, 2),
        });
    }
    // Get params for this run.
    var params = std.ArrayList(Param).init(allocator);
    defer {
        for (params.items) |*p| p.deinit(allocator);
        params.deinit();
    }
    const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
    const params_stmt = try database.prepare(params_sql);
    defer db.DB.finalize(params_stmt);
    try db.DB.bindText(params_stmt, 1, run_info.run_id);
    while (try db.DB.step(params_stmt)) {
        try params.append(Param{
            .key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)),
            .value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)),
        });
    }
    // Build sync JSON.
    // NOTE(review): fields are interpolated without JSON escaping; a quote
    // in a run name or param value would corrupt the message — confirm
    // values are sanitized upstream.
    var sync_json = std.ArrayList(u8).init(allocator);
    defer sync_json.deinit();
    // Managed ArrayList writer takes no arguments; the original passed the
    // allocator here (unmanaged API) while using managed init/deinit.
    const writer = sync_json.writer();
    try writer.writeAll("{");
    try writer.print("\"run_id\":\"{s}\",", .{run_info.run_id});
    try writer.print("\"experiment_id\":\"{s}\",", .{run_info.experiment_id});
    try writer.print("\"name\":\"{s}\",", .{run_info.name});
    try writer.print("\"status\":\"{s}\",", .{run_info.status});
    try writer.print("\"start_time\":\"{s}\",", .{run_info.start_time});
    if (run_info.end_time) |et| {
        try writer.print("\"end_time\":\"{s}\",", .{et});
    } else {
        try writer.writeAll("\"end_time\":null,");
    }
    // Add metrics.
    try writer.writeAll("\"metrics\":[");
    for (metrics.items, 0..) |m, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.print("{{\"key\":\"{s}\",\"value\":{d},\"step\":{d}}}", .{ m.key, m.value, m.step });
    }
    try writer.writeAll("],");
    // Add params.
    try writer.writeAll("\"params\":[");
    for (params.items, 0..) |p, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.print("{{\"key\":\"{s}\",\"value\":\"{s}\"}}", .{ p.key, p.value });
    }
    try writer.writeAll("]}");
    // Send sync_run message.
    try client.sendSyncRun(sync_json.items, api_key_hash);
    // Wait for sync_ack (simple substring check on the reply).
    const response = try client.receiveMessage(allocator);
    defer allocator.free(response);
    if (std.mem.indexOf(u8, response, "sync_ack") == null) {
        return error.SyncRejected;
    }
}
/// One metric sample (key/value/step) read from the ml_metrics table.
/// `key` is an owned heap copy; `value` and `step` are plain scalars.
const Metric = struct {
    key: []const u8,
    value: f64,
    step: i64,
    /// Frees the owned key; scalar fields need no cleanup.
    fn deinit(self: *Metric, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
    }
};
/// One parameter (key/value pair) read from the ml_params table.
/// Both fields are owned heap copies; release with deinit().
const Param = struct {
    key: []const u8,
    value: []const u8,
    /// Frees both owned strings. `allocator` must match the one used to dupe them.
    fn deinit(self: *Param, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
        allocator.free(self.value);
    }
};
/// Prints the `ml sync` help text to stderr.
fn printUsage() void {
    // Single comptime-concatenated literal instead of one print per line;
    // the emitted bytes are identical.
    const usage =
        "Usage: ml sync [run_id] [options]\n\n" ++
        "Push local experiment runs to the server.\n\n" ++
        "Options:\n" ++
        " --json Output structured JSON\n" ++
        " --help, -h Show this help message\n\n" ++
        "Examples:\n" ++
        " ml sync # Sync all unsynced runs\n" ++
        " ml sync abc123 # Sync specific run\n";
    std.debug.print(usage, .{});
}
/// Find the git root directory by walking up from the given path
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
var buf: [std.fs.max_path_bytes]u8 = undefined;

View file

@ -1,110 +1,83 @@
const std = @import("std");
const Config = @import("../config.zig").Config;
const config = @import("../config.zig");
const crypto = @import("../utils/crypto.zig");
const rsync = @import("../utils/rsync_embedded.zig");
const ws = @import("../net/ws/client.zig");
const core = @import("../core.zig");
const mode = @import("../mode.zig");
const colors = @import("../utils/colors.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var flags = core.flags.CommonFlags{};
var should_sync = false;
const sync_interval: u64 = 30; // Default 30 seconds
if (args.len == 0) {
printUsage();
return error.InvalidArgs;
return printUsage();
}
// Global flags
for (args) |arg| {
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
printUsage();
return;
}
}
const path = args[0];
var job_name: ?[]const u8 = null;
var priority: u8 = 5;
var should_queue = false;
var json: bool = false;
// Parse flags
var i: usize = 1;
while (i < args.len) : (i += 1) {
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
job_name = args[i + 1];
i += 1;
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, args[i], "--queue")) {
should_queue = true;
} else if (std.mem.eql(u8, args[i], "--json")) {
json = true;
for (args) |arg| {
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
return printUsage();
} else if (std.mem.eql(u8, arg, "--sync")) {
should_sync = true;
} else if (std.mem.eql(u8, arg, "--json")) {
flags.json = true;
}
}
const config = try Config.load(allocator);
core.output.init(if (flags.json) .json else .text);
const cfg = try config.Config.load(allocator);
defer {
var mut_config = config;
mut_config.deinit(allocator);
var mut_cfg = cfg;
mut_cfg.deinit(allocator);
}
if (json) {
std.debug.print("{\"ok\":true,\"action\":\"watch\",\"path\":\"{s}\",\"queued\":{s}}\n", .{ path, if (should_queue) "true" else "false" });
} else {
std.debug.print("Watching {s} for changes...\n", .{path});
std.debug.print("Press Ctrl+C to stop\n", .{});
}
// Initial sync
var last_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
defer allocator.free(last_commit_id);
// Watch for changes
var watcher = try std.fs.cwd().openDir(path, .{ .iterate = true });
defer watcher.close();
var last_modified: u64 = 0;
while (true) {
// Check for file changes
var modified = false;
var walker = try watcher.walk(allocator);
defer walker.deinit();
while (try walker.next()) |entry| {
if (entry.kind == .file) {
const file = try watcher.openFile(entry.path, .{});
defer file.close();
const stat = try file.stat();
if (stat.mtime > last_modified) {
last_modified = @intCast(stat.mtime);
modified = true;
}
}
// Check mode if syncing
if (should_sync) {
const mode_result = try mode.detect(allocator, cfg);
if (mode.isOffline(mode_result.mode)) {
colors.printError("ml watch --sync requires server connection\n", .{});
return error.RequiresServer;
}
}
if (modified) {
if (!json) {
std.debug.print("\nChanges detected, syncing...\n", .{});
}
if (flags.json) {
std.debug.print("{{\"ok\":true,\"action\":\"watch\",\"sync\":{s}}}\n", .{if (should_sync) "true" else "false"});
} else {
if (should_sync) {
colors.printInfo("Watching for changes with auto-sync every {d}s...\n", .{sync_interval});
} else {
colors.printInfo("Watching directory for changes...\n", .{});
}
colors.printInfo("Press Ctrl+C to stop\n", .{});
}
const new_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
defer allocator.free(new_commit_id);
if (!std.mem.eql(u8, last_commit_id, new_commit_id)) {
allocator.free(last_commit_id);
last_commit_id = try allocator.dupe(u8, new_commit_id);
if (!json) {
std.debug.print("✓ Synced new version: {s}\n", .{last_commit_id[0..8]});
}
// Watch loop
var last_synced: i64 = 0;
while (true) {
if (should_sync) {
const now = std.time.timestamp();
if (now - last_synced >= @as(i64, @intCast(sync_interval))) {
// Trigger sync
const sync_cmd = @import("sync.zig");
sync_cmd.run(allocator, &[_][]const u8{"--json"}) catch |err| {
if (!flags.json) {
colors.printError("Auto-sync failed: {}\n", .{err});
}
};
last_synced = now;
}
}
// Wait before checking again
std.Thread.sleep(2_000_000_000); // 2 seconds in nanoseconds
std.Thread.sleep(2_000_000_000); // 2 seconds
}
}
fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, config: Config) ![]u8 {
fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, cfg: config.Config) ![]u8 {
// Calculate commit ID
const commit_id = try crypto.hashDirectory(allocator, path);
@ -112,22 +85,22 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con
const remote_path = try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/",
.{ config.worker_user, config.worker_host, config.worker_base, commit_id },
.{ cfg.worker_user, cfg.worker_host, cfg.worker_base, commit_id },
);
defer allocator.free(remote_path);
try rsync.sync(allocator, path, remote_path, config.worker_port);
try rsync.sync(allocator, path, remote_path, cfg.worker_port);
if (should_queue) {
const actual_job_name = job_name orelse commit_id[0..8];
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
defer allocator.free(api_key_hash);
// Connect to WebSocket and queue job
const ws_url = try config.getWebSocketUrl(allocator);
const ws_url = try cfg.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
defer client.close();
try client.sendQueueJob(actual_job_name, commit_id, priority, api_key_hash);
@ -144,11 +117,10 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con
}
fn printUsage() void {
std.debug.print("Usage: ml watch <path> [options]\n\n", .{});
std.debug.print("Usage: ml watch [options]\n\n", .{});
std.debug.print("Watch for changes and optionally auto-sync.\n\n", .{});
std.debug.print("Options:\n", .{});
std.debug.print(" --name <job> Override job name when used with --queue\n", .{});
std.debug.print(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
std.debug.print(" --queue Queue on every sync\n", .{});
std.debug.print(" --json Emit a single JSON line describing watch start\n", .{});
std.debug.print(" --sync Auto-sync runs to server every 30s\n", .{});
std.debug.print(" --json Output structured JSON\n", .{});
std.debug.print(" --help, -h Show this help message\n", .{});
}