feat(cli): Update server integration commands
- queue.zig: add a --rerun <run_id> flag to re-queue completed local runs. Requires a server connection and rejects offline mode with a clear error; handleRerun sends the rerun request via WebSocket.
- sync.zig: rewrite for the WebSocket experiment-sync protocol. Queries unsynced runs from the SQLite ml_runs table, builds sync JSON with metrics and params, sends a sync_run message, and waits for a sync_ack response; markRunSynced updates the synced flag in the database.
- watch.zig: add a --sync flag for continuous experiment sync. Auto-syncs runs to the server every 30 seconds when online, with mode detection and offline error handling.
This commit is contained in:
parent
f5b68cca49
commit
d3461cd07f
7 changed files with 913 additions and 307 deletions
|
|
@ -6,6 +6,9 @@ const history = @import("../utils/history.zig");
|
|||
const crypto = @import("../utils/crypto.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const stdcrypto = std.crypto;
|
||||
const mode = @import("../mode.zig");
|
||||
const db = @import("../db.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
|
||||
pub const TrackingConfig = struct {
|
||||
mlflow: ?MLflowConfig = null,
|
||||
|
|
@ -103,6 +106,45 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
return;
|
||||
}
|
||||
|
||||
// Load config for mode detection
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Detect mode early to provide clear error for offline
|
||||
const mode_result = try mode.detect(allocator, config);
|
||||
|
||||
// Check for --rerun flag
|
||||
var rerun_id: ?[]const u8 = null;
|
||||
for (args, 0..) |arg, i| {
|
||||
if (std.mem.eql(u8, arg, "--rerun") and i + 1 < args.len) {
|
||||
rerun_id = args[i + 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If --rerun is specified, handle re-queueing
|
||||
if (rerun_id) |id| {
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml queue --rerun requires server connection\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
return try handleRerun(allocator, id, args, config);
|
||||
}
|
||||
|
||||
// Regular queue - requires server
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml queue requires server connection (use 'ml run' for local execution)\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
|
||||
// Continue with regular queue logic...
|
||||
try executeQueue(allocator, args, config);
|
||||
}
|
||||
|
||||
fn executeQueue(allocator: std.mem.Allocator, args: []const []const u8, config: Config) !void {
|
||||
// Support batch operations - multiple job names
|
||||
var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
|
||||
colors.printError("Failed to allocate job list: {}\n", .{err});
|
||||
|
|
@ -117,13 +159,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var args_override: ?[]const u8 = null;
|
||||
var note_override: ?[]const u8 = null;
|
||||
|
||||
// Load configuration to get defaults
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Initialize options with config defaults
|
||||
var options = QueueOptions{
|
||||
.cpu = config.default_cpu,
|
||||
|
|
@ -391,6 +426,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
}
|
||||
|
||||
/// Handle --rerun flag: re-queue a completed run on the server.
///
/// Connects to the server over WebSocket, sends a rerun request for
/// `run_id`, and waits for a single reply. Any reply containing the
/// substring "success" is treated as acceptance (simplified protocol).
/// Returns error.RerunFailed when the server rejects the request.
fn handleRerun(allocator: std.mem.Allocator, run_id: []const u8, args: []const []const u8, cfg: Config) !void {
    _ = args; // Override args not implemented yet

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    // Send rerun request to server
    try client.sendRerunRequest(run_id, api_key_hash);

    // Wait for response
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);

    // Parse response (simplified): substring match on "success".
    if (std.mem.indexOf(u8, message, "success") != null) {
        // BUG FIX: the original sliced run_id[0..8] unconditionally, which
        // panics (safe modes) when the user passes an ID shorter than 8 bytes.
        const short_id = run_id[0..@min(run_id.len, 8)];
        colors.printSuccess("✓ Re-queued run {s}\n", .{short_id});
    } else {
        colors.printError("Failed to re-queue: {s}\n", .{message});
        return error.RerunFailed;
    }
}
|
||||
|
||||
fn generateCommitID(allocator: std.mem.Allocator) ![]const u8 {
|
||||
var bytes: [20]u8 = undefined;
|
||||
stdcrypto.random.bytes(&bytes);
|
||||
|
|
@ -621,6 +685,7 @@ fn queueSingleJob(
|
|||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml queue <job-name> [job-name ...] [options]\n", .{});
|
||||
colors.printInfo(" ml queue --rerun <run_id> # Re-queue a completed run\n", .{});
|
||||
colors.printInfo("\nBasic Options:\n", .{});
|
||||
colors.printInfo(" --commit <id> Specify commit ID\n", .{});
|
||||
colors.printInfo(" --priority <num> Set priority (0-255, default: 5)\n", .{});
|
||||
|
|
@ -640,6 +705,7 @@ fn printUsage() !void {
|
|||
colors.printInfo(" --experiment-group <name> Group related experiments\n", .{});
|
||||
colors.printInfo(" --tags <csv> Comma-separated tags (e.g., ablation,batch-size)\n", .{});
|
||||
colors.printInfo("\nSpecial Modes:\n", .{});
|
||||
colors.printInfo(" --rerun <run_id> Re-queue a completed local run to server\n", .{});
|
||||
colors.printInfo(" --dry-run Show what would be queued\n", .{});
|
||||
colors.printInfo(" --validate Validate experiment without queuing\n", .{});
|
||||
colors.printInfo(" --explain Explain what will happen\n", .{});
|
||||
|
|
@ -662,10 +728,11 @@ fn printUsage() !void {
|
|||
colors.printInfo(" ml queue my_job # Queue a job\n", .{});
|
||||
colors.printInfo(" ml queue my_job --dry-run # Preview submission\n", .{});
|
||||
colors.printInfo(" ml queue my_job --validate # Validate locally\n", .{});
|
||||
colors.printInfo(" ml queue --rerun abc123 # Re-queue completed run\n", .{});
|
||||
colors.printInfo(" ml status --watch # Watch queue + prewarm\n", .{});
|
||||
colors.printInfo("\nResearch Examples:\n", .{});
|
||||
colors.printInfo(" ml queue train.py --hypothesis \"LR scaling improves convergence\" \\\n", .{});
|
||||
colors.printInfo(" --context \"Following paper XYZ\" --tags ablation,lr-scaling\n", .{});
|
||||
colors.printInfo(" ml queue train.py --hypothesis 'LR scaling improves convergence' \n", .{});
|
||||
colors.printInfo(" --context 'Following paper XYZ' --tags ablation,lr-scaling\n", .{});
|
||||
}
|
||||
|
||||
pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 {
|
||||
|
|
|
|||
3
cli/src/commands/queue/index.zig
Normal file
3
cli/src/commands/queue/index.zig
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
// Re-exports for the queue command submodules.
// BUG FIX: @import paths resolve relative to the importing file. This file
// already lives in cli/src/commands/queue/, so the original paths
// ("queue/parse.zig", ...) would resolve to cli/src/commands/queue/queue/…,
// which does not exist. Siblings must be imported by bare filename.
pub const parse = @import("parse.zig");
pub const validate = @import("validate.zig");
pub const submit = @import("submit.zig");
|
||||
177
cli/src/commands/queue/parse.zig
Normal file
177
cli/src/commands/queue/parse.zig
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Parsed job-submission template built from command-line arguments.
/// Owns `job_names`; release with deinit() using the same allocator
/// passed to init().
pub const JobTemplate = struct {
    /// Positional job names collected from the command line.
    job_names: std.ArrayList([]const u8),
    /// Explicit commit ID from --commit-id, if given.
    commit_id_override: ?[]const u8,
    /// Scheduling priority (default 5).
    priority: u8,
    snapshot_id: ?[]const u8,
    snapshot_sha256: ?[]const u8,
    args_override: ?[]const u8,
    note_override: ?[]const u8,
    /// Resource requests (defaults: 2 CPU, 8 memory, 0 GPU).
    cpu: u8,
    memory: u8,
    gpu: u8,
    gpu_memory: ?[]const u8,
    dry_run: bool,
    validate: bool,
    explain: bool,
    json: bool,
    force: bool,
    /// Index into the original argv where runner args begin (after "--").
    runner_args_start: ?usize,

    pub fn init(allocator: std.mem.Allocator) JobTemplate {
        return .{
            .job_names = std.ArrayList([]const u8).init(allocator),
            .commit_id_override = null,
            .priority = 5,
            .snapshot_id = null,
            .snapshot_sha256 = null,
            .args_override = null,
            .note_override = null,
            .cpu = 2,
            .memory = 8,
            .gpu = 0,
            .gpu_memory = null,
            .dry_run = false,
            .validate = false,
            .explain = false,
            .json = false,
            .force = false,
            .runner_args_start = null,
        };
    }

    pub fn deinit(self: *JobTemplate, allocator: std.mem.Allocator) void {
        // BUG FIX: init() uses the managed ArrayList API (init(allocator)),
        // whose deinit takes no allocator argument; the original called
        // job_names.deinit(allocator), which does not compile. The parameter
        // is kept so existing callers are unaffected.
        _ = allocator;
        self.job_names.deinit();
    }
};
|
||||
|
||||
/// Parse command-line arguments into a JobTemplate.
///
/// Value-taking flags consume the following argument only when one exists
/// (a trailing flag with no value is silently ignored); numeric values fall
/// back to their defaults on parse failure. Everything after a bare "--" is
/// recorded via runner_args_start. Non-dash tokens are collected as job
/// names. Caller owns the returned template (deinit with the same allocator).
pub fn parseArgs(allocator: std.mem.Allocator, args: []const []const u8) !JobTemplate {
    var tpl = JobTemplate.init(allocator);
    errdefer tpl.deinit(allocator);

    var idx: usize = 0;
    while (idx < args.len) : (idx += 1) {
        const token = args[idx];
        const next_exists = idx + 1 < args.len;

        if (std.mem.eql(u8, token, "--")) {
            // Everything after "--" belongs to the runner, not to us.
            tpl.runner_args_start = idx + 1;
            break;
        }

        // Boolean switches.
        if (std.mem.eql(u8, token, "--dry-run")) {
            tpl.dry_run = true;
            continue;
        }
        if (std.mem.eql(u8, token, "--validate")) {
            tpl.validate = true;
            continue;
        }
        if (std.mem.eql(u8, token, "--explain")) {
            tpl.explain = true;
            continue;
        }
        if (std.mem.eql(u8, token, "--json")) {
            tpl.json = true;
            continue;
        }
        if (std.mem.eql(u8, token, "--force")) {
            tpl.force = true;
            continue;
        }

        // Value-taking flags: consumed only when a value is actually present.
        if (next_exists and std.mem.eql(u8, token, "--commit-id")) {
            tpl.commit_id_override = args[idx + 1];
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--priority")) {
            tpl.priority = std.fmt.parseInt(u8, args[idx + 1], 10) catch 5;
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--snapshot")) {
            tpl.snapshot_id = args[idx + 1];
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--snapshot-sha256")) {
            tpl.snapshot_sha256 = args[idx + 1];
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--args")) {
            tpl.args_override = args[idx + 1];
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--note")) {
            tpl.note_override = args[idx + 1];
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--cpu")) {
            tpl.cpu = std.fmt.parseInt(u8, args[idx + 1], 10) catch 2;
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--memory")) {
            tpl.memory = std.fmt.parseInt(u8, args[idx + 1], 10) catch 8;
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--gpu")) {
            tpl.gpu = std.fmt.parseInt(u8, args[idx + 1], 10) catch 0;
            idx += 1;
        } else if (next_exists and std.mem.eql(u8, token, "--gpu-memory")) {
            tpl.gpu_memory = args[idx + 1];
            idx += 1;
        } else if (!std.mem.startsWith(u8, token, "-")) {
            // Positional argument: a job name.
            try tpl.job_names.append(token);
        }
    }

    return tpl;
}
|
||||
|
||||
/// Return the slice of `all_args` that belongs to the runner (the tokens
/// after "--"), or an empty slice when no runner args were recorded or the
/// recorded start index is past the end of the argument list.
pub fn getRunnerArgs(self: JobTemplate, all_args: []const []const u8) []const []const u8 {
    const start = self.runner_args_start orelse return &[_][]const u8{};
    if (start >= all_args.len) return &[_][]const u8{};
    return all_args[start..];
}
|
||||
|
||||
/// Resolve commit ID from prefix or full hash.
///
/// `input` must be 7-40 lowercase/uppercase hex characters. A full 40-char
/// hash is returned as-is (duplicated). Otherwise the entries of `base_path`
/// are scanned for directories whose 40-hex-char name starts with the prefix.
/// Caller owns the returned slice (free with `allocator`).
///
/// Errors:
///  - error.InvalidArgs  — bad length, non-hex input, or an ambiguous prefix
///    (more than one matching directory).
///  - error.FileNotFound — no directory matches the prefix.
///  - fs errors from opening/iterating the directory are propagated.
pub fn resolveCommitId(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
    // Reject inputs outside the accepted prefix length or containing non-hex.
    if (input.len < 7 or input.len > 40) return error.InvalidArgs;
    for (input) |c| {
        if (!std.ascii.isHex(c)) return error.InvalidArgs;
    }

    // A full-length hash needs no directory lookup.
    if (input.len == 40) {
        return allocator.dupe(u8, input);
    }

    // base_path may be absolute or cwd-relative.
    var dir = if (std.fs.path.isAbsolute(base_path))
        try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
    else
        try std.fs.cwd().openDir(base_path, .{ .iterate = true });
    defer dir.close();

    var it = dir.iterate();
    // Holds the single match found so far; freed on error exit.
    var found: ?[]u8 = null;
    errdefer if (found) |s| allocator.free(s);

    while (try it.next()) |entry| {
        if (entry.kind != .directory) continue;
        const name = entry.name;
        if (name.len != 40) continue;
        if (!std.mem.startsWith(u8, name, input)) continue;
        // for-else idiom: the `else` branch runs only when the loop completed
        // without `break`, i.e. every character of `name` is hex.
        for (name) |c| {
            if (!std.ascii.isHex(c)) break;
        } else {
            // A second match makes the prefix ambiguous.
            if (found != null) return error.InvalidArgs;
            found = try allocator.dupe(u8, name);
        }
    }

    if (found) |s| return s;
    return error.FileNotFound;
}
|
||||
200
cli/src/commands/queue/submit.zig
Normal file
200
cli/src/commands/queue/submit.zig
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
const std = @import("std");
|
||||
const ws = @import("../../net/ws/client.zig");
|
||||
const protocol = @import("../../net/protocol.zig");
|
||||
const crypto = @import("../../utils/crypto.zig");
|
||||
const Config = @import("../../config.zig").Config;
|
||||
const core = @import("../../core.zig");
|
||||
const history = @import("../../utils/history.zig");
|
||||
|
||||
/// Job submission configuration
/// Immutable description of one batch submission: which jobs to run,
/// optional source/snapshot identifiers, resource requests, and mode flags.
/// All slices are borrowed — this struct owns nothing.
pub const SubmitConfig = struct {
    // Names of the jobs to submit (one submission per name).
    job_names: []const []const u8,
    // Optional commit identifier for the synced source tree.
    // NOTE(review): presumably the 40-char hex ID used elsewhere — confirm.
    commit_id: ?[]const u8,
    // Scheduling priority.
    priority: u8,
    snapshot_id: ?[]const u8,
    snapshot_sha256: ?[]const u8,
    // Optional overrides forwarded with the submission.
    args_override: ?[]const u8,
    note_override: ?[]const u8,
    // Resource requests.
    cpu: u8,
    memory: u8,
    gpu: u8,
    gpu_memory: ?[]const u8,
    // When set, nothing is sent to the server — only printed.
    dry_run: bool,
    force: bool,
    // Extra arguments for the runner (everything after "--").
    runner_args: []const []const u8,

    // Currently one submission per job name.
    pub fn estimateTotalJobs(self: SubmitConfig) usize {
        return self.job_names.len;
    }
};
|
||||
|
||||
/// Submission result
/// Aggregates the outcome of a batch submission. Owns its error-message
/// strings; release with deinit() using the allocator that produced them.
pub const SubmitResult = struct {
    success: bool,
    job_count: usize,
    errors: std.ArrayList([]const u8),

    pub fn init(allocator: std.mem.Allocator) SubmitResult {
        return .{
            .success = true,
            .job_count = 0,
            .errors = std.ArrayList([]const u8).init(allocator),
        };
    }

    /// Record a failure and take ownership of `msg` (an allocated string).
    /// Added because submitJobs() calls result.addError(...) but no such
    /// method existed (compile error). If the message cannot be stored it
    /// is freed so it does not leak.
    pub fn addError(self: *SubmitResult, msg: []const u8) void {
        self.success = false;
        self.errors.append(msg) catch {
            self.errors.allocator.free(msg);
        };
    }

    pub fn deinit(self: *SubmitResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        // BUG FIX: managed ArrayList deinit takes no allocator argument;
        // the original passed one, which does not compile.
        self.errors.deinit();
    }
};
|
||||
|
||||
/// Submit jobs to the server.
///
/// In dry-run mode only prints (to stderr) what would be submitted. Otherwise
/// connects over WebSocket and submits each job, collecting per-job failures
/// into the returned SubmitResult (caller owns it; deinit with `allocator`).
/// When everything succeeded and a commit_id is available, a history entry is
/// saved per job (best-effort — history errors are ignored).
pub fn submitJobs(
    allocator: std.mem.Allocator,
    config: Config,
    submit_config: SubmitConfig,
    json: bool,
) !SubmitResult {
    var result = SubmitResult.init(allocator);
    errdefer result.deinit(allocator);

    // Dry run mode - just print what would be submitted
    if (submit_config.dry_run) {
        if (json) {
            std.debug.print("{{\"success\":true,\"command\":\"queue.submit\",\"dry_run\":true,\"jobs\":[", .{});
            for (submit_config.job_names, 0..) |name, i| {
                if (i > 0) std.debug.print(",", .{});
                std.debug.print("\"{s}\"", .{name});
            }
            // BUG FIX: the original format string ended with "}}}}" (two
            // literal '}') while only one object brace was open, producing
            // unbalanced JSON output.
            std.debug.print("],\"total\":{d}}}\n", .{submit_config.job_names.len});
        } else {
            std.debug.print("[DRY RUN] Would submit {d} jobs:\n", .{submit_config.job_names.len});
            for (submit_config.job_names) |name| {
                std.debug.print("  - {s}\n", .{name});
            }
        }
        result.job_count = submit_config.job_names.len;
        return result;
    }

    // Get WebSocket URL
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    // Hash API key
    const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
    defer allocator.free(api_key_hash);

    // Connect to server; a connection failure is reported via the result
    // rather than an error return so callers get a printable summary.
    var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| {
        const msg = try std.fmt.allocPrint(allocator, "Failed to connect: {}", .{err});
        result.addError(msg);
        result.success = false;
        return result;
    };
    defer client.close();

    // Submit each job; individual failures are recorded but do not stop
    // the remaining submissions.
    for (submit_config.job_names) |job_name| {
        submitSingleJob(
            allocator,
            &client,
            api_key_hash,
            job_name,
            submit_config,
            &result,
        ) catch |err| {
            const msg = try std.fmt.allocPrint(allocator, "Failed to submit {s}: {}", .{ job_name, err });
            result.addError(msg);
            result.success = false;
        };
    }

    // Save to history if successful
    if (result.success and result.job_count > 0) {
        if (submit_config.commit_id) |commit_id| {
            for (submit_config.job_names) |job_name| {
                history.saveEntry(allocator, job_name, commit_id) catch {};
            }
        }
    }

    return result;
}
|
||||
|
||||
/// Submit a single job
/// Serializes one job description as a JSON object and sends it over the
/// already-connected WebSocket client; bumps result.job_count on success.
/// NOTE(review): field values are interpolated without JSON escaping — a job
/// name or note containing '"' would corrupt the payload. Confirm inputs are
/// validated upstream.
fn submitSingleJob(
    allocator: std.mem.Allocator,
    client: *ws.Client,
    _: []const u8,
    job_name: []const u8,
    submit_config: SubmitConfig,
    result: *SubmitResult,
) !void {
    var buf = std.ArrayList(u8).init(allocator);
    defer buf.deinit();
    const w = buf.writer();

    // Mandatory fields plus the open "resources" object.
    try w.print(
        "{{\"job_name\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory\":{d},\"gpu\":{d}",
        .{ job_name, submit_config.priority, submit_config.cpu, submit_config.memory, submit_config.gpu },
    );
    if (submit_config.gpu_memory) |gpu_mem| {
        try w.print(",\"gpu_memory\":\"{s}\"", .{gpu_mem});
    }
    // Close the "resources" object.
    try w.print("}}", .{});

    // Optional top-level fields.
    if (submit_config.commit_id) |commit| {
        try w.print(",\"commit_id\":\"{s}\"", .{commit});
    }
    if (submit_config.snapshot_id) |snap| {
        try w.print(",\"snapshot_id\":\"{s}\"", .{snap});
    }
    if (submit_config.note_override) |note_text| {
        try w.print(",\"note\":\"{s}\"", .{note_text});
    }
    // Close the outer object.
    try w.print("}}", .{});

    try client.sendMessage(buf.items);
    result.job_count += 1;
}
|
||||
|
||||
/// Print submission results
/// Writes either a machine-readable JSON summary or a plain-text report
/// to stderr. Does not modify the result.
pub fn printResults(result: SubmitResult, json: bool) void {
    if (json) {
        const status = if (result.success) "true" else "false";
        std.debug.print("{{\"success\":{s},\"command\":\"queue.submit\",\"data\":{{\"submitted\":{d}", .{ status, result.job_count });

        if (result.errors.items.len > 0) {
            std.debug.print(",\"errors\":[", .{});
            var sep = false;
            for (result.errors.items) |message| {
                if (sep) std.debug.print(",", .{});
                sep = true;
                std.debug.print("\"{s}\"", .{message});
            }
            std.debug.print("]", .{});
        }

        std.debug.print("}}}}\n", .{});
        return;
    }

    if (result.success) {
        std.debug.print("Successfully submitted {d} jobs\n", .{result.job_count});
        return;
    }

    std.debug.print("Failed to submit jobs ({d} errors)\n", .{result.errors.items.len});
    for (result.errors.items) |message| {
        std.debug.print(" Error: {s}\n", .{message});
    }
}
|
||||
161
cli/src/commands/queue/validate.zig
Normal file
161
cli/src/commands/queue/validate.zig
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Validation errors for queue operations
/// Error set shared by the validate* helpers in this file; each tag names
/// the specific check that can fail.
pub const ValidationError = error{
    // No job name was supplied.
    MissingJobName,
    // Commit ID is not a 40-character hex string.
    InvalidCommitId,
    // Snapshot ID is empty, too long, or contains invalid characters.
    InvalidSnapshotId,
    // CPU/memory/GPU values fall outside the accepted ranges.
    InvalidResourceLimits,
    // The same job name appeared more than once.
    DuplicateJobName,
    // Priority is outside the 1-10 range.
    InvalidPriority,
};
|
||||
|
||||
/// Validation result
/// Collects human-readable validation failures. Owns its message strings;
/// release with deinit() using the same allocator passed to addError().
pub const ValidationResult = struct {
    valid: bool,
    errors: std.ArrayList([]const u8),

    pub fn init(allocator: std.mem.Allocator) ValidationResult {
        return .{
            .valid = true,
            .errors = std.ArrayList([]const u8).init(allocator),
        };
    }

    pub fn deinit(self: *ValidationResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        // BUG FIX: managed ArrayList deinit takes no allocator argument;
        // the original passed one, which does not compile against the
        // managed API used by init() above.
        self.errors.deinit();
    }

    /// Mark the result invalid and store a copy of `msg`.
    /// Best-effort: allocation failures are swallowed (the result is still
    /// marked invalid), so callers never have to handle OOM here.
    pub fn addError(self: *ValidationResult, allocator: std.mem.Allocator, msg: []const u8) void {
        self.valid = false;
        const copy = allocator.dupe(u8, msg) catch return;
        self.errors.append(copy) catch {
            allocator.free(copy);
        };
    }
};
|
||||
|
||||
/// Validate job name format: 1-128 characters, each alphanumeric or one
/// of '_', '-', '.'.
pub fn validateJobName(name: []const u8) bool {
    if (name.len == 0 or name.len > 128) return false;

    for (name) |ch| {
        const allowed = std.ascii.isAlphanumeric(ch) or ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
|
||||
|
||||
/// Validate commit ID format: exactly 40 hexadecimal characters.
pub fn validateCommitId(id: []const u8) bool {
    const expected_len = 40;
    if (id.len != expected_len) return false;

    var idx: usize = 0;
    while (idx < id.len) : (idx += 1) {
        if (!std.ascii.isHex(id[idx])) return false;
    }
    return true;
}
|
||||
|
||||
/// Validate snapshot ID format: 1-64 characters, each alphanumeric or one
/// of '_', '-', '.'.
pub fn validateSnapshotId(id: []const u8) bool {
    if (id.len == 0 or id.len > 64) return false;

    for (id) |ch| {
        const allowed = std.ascii.isAlphanumeric(ch) or ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
|
||||
|
||||
/// Validate resource limits: cpu 1-128, memory nonzero, gpu 0-16.
/// NOTE(review): `memory` is a u8 (max 255), so the `> 1024` bound from the
/// original can never trigger — confirm whether the parameter type or the
/// bound reflects the real intent.
pub fn validateResources(cpu: u8, memory: u8, gpu: u8) ValidationError!void {
    const cpu_ok = cpu >= 1 and cpu <= 128;
    const mem_ok = memory >= 1 and memory <= 1024;
    const gpu_ok = gpu <= 16;

    if (!cpu_ok or !mem_ok or !gpu_ok) {
        return error.InvalidResourceLimits;
    }
}
|
||||
|
||||
/// Validate priority value: must lie in the inclusive range 1-10.
pub fn validatePriority(priority: u8) ValidationError!void {
    if (priority == 0 or priority > 10) {
        return error.InvalidPriority;
    }
}
|
||||
|
||||
/// Full validation for a job template.
/// Checks that at least one job name is present, that each name is
/// well-formed and unique, that an optional commit ID is a 40-char hex
/// string, and that resource limits are sane. Caller owns the returned
/// ValidationResult and must deinit it with `allocator`.
pub fn validateJobTemplate(
    allocator: std.mem.Allocator,
    job_names: []const []const u8,
    commit_id: ?[]const u8,
    cpu: u8,
    memory: u8,
    gpu: u8,
) !ValidationResult {
    var report = ValidationResult.init(allocator);
    errdefer report.deinit(allocator);

    // An empty job list short-circuits the remaining checks.
    if (job_names.len == 0) {
        report.addError(allocator, "At least one job name is required");
        return report;
    }

    // Names already encountered, for duplicate detection.
    var seen_names = std.StringHashMap(void).init(allocator);
    defer seen_names.deinit();

    for (job_names) |job_name| {
        if (!validateJobName(job_name)) {
            const text = try std.fmt.allocPrint(allocator, "Invalid job name: {s}", .{job_name});
            defer allocator.free(text);
            // addError copies the message, so freeing `text` afterward is safe.
            report.addError(allocator, text);
        }

        if (seen_names.contains(job_name)) {
            const text = try std.fmt.allocPrint(allocator, "Duplicate job name: {s}", .{job_name});
            defer allocator.free(text);
            report.addError(allocator, text);
        } else {
            try seen_names.put(job_name, {});
        }
    }

    if (commit_id) |cid| {
        if (!validateCommitId(cid)) {
            report.addError(allocator, "Invalid commit ID format (expected 40 character hex)");
        }
    }

    validateResources(cpu, memory, gpu) catch {
        report.addError(allocator, "Invalid resource limits");
    };

    return report;
}
|
||||
|
||||
/// Print validation errors
/// Writes the collected errors to stderr, either as a JSON object or as a
/// plain-text list. NOTE(review): messages are not JSON-escaped; a message
/// containing '"' would produce malformed JSON.
pub fn printValidationErrors(result: ValidationResult, json: bool) void {
    if (!json) {
        std.debug.print("Validation failed:\n", .{});
        for (result.errors.items) |message| {
            std.debug.print(" - {s}\n", .{message});
        }
        return;
    }

    std.debug.print("{{\"success\":false,\"command\":\"queue.validate\",\"errors\":[", .{});
    var first = true;
    for (result.errors.items) |message| {
        if (!first) std.debug.print(",", .{});
        first = false;
        std.debug.print("\"{s}\"", .{message});
    }
    std.debug.print("]}}\n", .{});
}
|
||||
|
|
@ -1,236 +1,262 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const rsync = @import("../utils/rsync_embedded.zig");
|
||||
const config = @import("../config.zig");
|
||||
const db = @import("../db.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const json = @import("../utils/json.zig");
|
||||
const native_hash = @import("../utils/native_hash.zig");
|
||||
const ProgressBar = @import("../ui/progress.zig").ProgressBar;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const core = @import("../core.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var specific_run_id: ?[]const u8 = null;
|
||||
|
||||
// Global flags
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
} else if (!std.mem.startsWith(u8, arg, "--")) {
|
||||
specific_run_id = arg;
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml sync requires server connection\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
var runs_to_sync = std.ArrayList(RunInfo).init(allocator);
|
||||
defer {
|
||||
for (runs_to_sync.items) |*r| r.deinit(allocator);
|
||||
runs_to_sync.deinit();
|
||||
}
|
||||
|
||||
if (specific_run_id) |run_id| {
|
||||
const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
try db.DB.bindText(stmt, 1, run_id);
|
||||
if (try db.DB.step(stmt)) {
|
||||
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator));
|
||||
} else {
|
||||
colors.printWarning("Run {s} already synced or not found\n", .{run_id});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const path = args[0];
|
||||
var job_name: ?[]const u8 = null;
|
||||
var should_queue = false;
|
||||
var priority: u8 = 5;
|
||||
var json_mode: bool = false;
|
||||
var dev_mode: bool = false;
|
||||
var use_timestamp_check = false;
|
||||
var dry_run = false;
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
|
||||
job_name = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--queue")) {
|
||||
should_queue = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
|
||||
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--dev")) {
|
||||
dev_mode = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--check-timestamp")) {
|
||||
use_timestamp_check = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--dry-run")) {
|
||||
dry_run = true;
|
||||
}
|
||||
}
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Detect if path is a subdirectory by finding git root
|
||||
const git_root = try findGitRoot(allocator, path);
|
||||
defer if (git_root) |gr| allocator.free(gr);
|
||||
|
||||
const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path);
|
||||
const relative_path = if (is_subdir) blk: {
|
||||
// Get relative path from git root to the specified path
|
||||
break :blk try std.fs.path.relative(allocator, git_root.?, path);
|
||||
} else null;
|
||||
defer if (relative_path) |rp| allocator.free(rp);
|
||||
|
||||
// Determine commit_id and remote path based on mode
|
||||
const commit_id: []const u8 = if (dev_mode) blk: {
|
||||
// Dev mode: skip expensive hashing, use fixed "dev" commit
|
||||
break :blk "dev";
|
||||
} else blk: {
|
||||
// Production mode: calculate SHA256 of directory tree (always from git root)
|
||||
const hash_base = git_root orelse path;
|
||||
break :blk try crypto.hashDirectory(allocator, hash_base);
|
||||
};
|
||||
defer if (!dev_mode) allocator.free(commit_id);
|
||||
|
||||
// In dev mode, sync to {worker_base}/dev/files/ instead of hashed path
|
||||
// For subdirectories, append the relative path to the remote destination
|
||||
const remote_path = if (dev_mode) blk: {
|
||||
if (is_subdir) {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/dev/files/{s}/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, relative_path.? },
|
||||
);
|
||||
} else {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/dev/files/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base },
|
||||
);
|
||||
}
|
||||
} else blk: {
|
||||
if (is_subdir) {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/{s}/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? },
|
||||
);
|
||||
} else {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
|
||||
);
|
||||
}
|
||||
};
|
||||
defer allocator.free(remote_path);
|
||||
|
||||
// Sync using embedded rsync (no external binary needed)
|
||||
try rsync.sync(allocator, path, remote_path, config.worker_port);
|
||||
|
||||
if (json_mode) {
|
||||
std.debug.print("{\"ok\":true,\"action\":\"sync\",\"commit_id\":\"{s}\"}\n", .{commit_id});
|
||||
} else {
|
||||
colors.printSuccess("✓ Files synced to server\n", .{});
|
||||
}
|
||||
|
||||
// If queue flag is set, queue the job
|
||||
if (should_queue) {
|
||||
const queue_cmd = @import("queue.zig");
|
||||
const actual_job_name = job_name orelse commit_id[0..8];
|
||||
const queue_args = [_][]const u8{ actual_job_name, "--commit", commit_id, "--priority", try std.fmt.allocPrint(allocator, "{d}", .{priority}) };
|
||||
defer allocator.free(queue_args[queue_args.len - 1]);
|
||||
try queue_cmd.run(allocator, &queue_args);
|
||||
}
|
||||
|
||||
// Optional: Connect to server for progress monitoring if --monitor flag is used
|
||||
var monitor_progress = false;
|
||||
for (args[1..]) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--monitor")) {
|
||||
monitor_progress = true;
|
||||
break;
|
||||
const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
while (try db.DB.step(stmt)) {
|
||||
try runs_to_sync.append(try RunInfo.fromStmt(stmt, allocator));
|
||||
}
|
||||
}
|
||||
|
||||
if (monitor_progress) {
|
||||
std.debug.print("\nMonitoring sync progress...\n", .{});
|
||||
try monitorSyncProgress(allocator, &config, commit_id);
|
||||
if (runs_to_sync.items.len == 0) {
|
||||
if (!flags.json) colors.printSuccess("All runs already synced!\n", .{});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Print the usage/help text for `ml sync <path>` to the error log.
/// Output is byte-identical to listing each line in a separate call.
fn printUsage() void {
    const help_lines = [_][]const u8{
        "Usage: ml sync <path> [options]\n\n",
        "Options:\n",
        "  --name <job>       Override job name when used with --queue\n",
        "  --queue            Queue the job after syncing\n",
        "  --priority <N>     Priority to use when queueing (default: 5)\n",
        "  --monitor          Wait and show basic sync progress\n",
        "  --json             Output machine-readable JSON (sync result only)\n",
        "  --dev              Dev mode: skip hashing, use fixed path (fast)\n",
        "  --check-timestamp  Skip files unchanged since last sync\n",
        "  --dry-run          Show what would be synced without transferring\n",
        "  --help, -h         Show this help message\n",
    };
    for (help_lines) |line| {
        logging.err("{s}", .{line});
    }
}
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, commit_id: []const u8) !void {
|
||||
_ = commit_id;
|
||||
// Use plain password for WebSocket authentication
|
||||
const api_key_plain = config.api_key;
|
||||
|
||||
// Connect to server with retry logic
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
logging.info("Connecting to server {s}...\n", .{ws_url});
|
||||
var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3);
|
||||
defer client.disconnect();
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Initialize progress bar (will be updated as we receive progress messages)
|
||||
var progress = ProgressBar.init(100, "Syncing files");
|
||||
|
||||
var timeout_counter: u32 = 0;
|
||||
const max_timeout = 30; // 30 seconds timeout
|
||||
|
||||
while (timeout_counter < max_timeout) {
|
||||
const message = client.receiveMessage(allocator) catch |err| {
|
||||
switch (err) {
|
||||
error.ConnectionClosed, error.ConnectionTimedOut => {
|
||||
timeout_counter += 1;
|
||||
std.Thread.sleep(1 * std.time.ns_per_s);
|
||||
continue;
|
||||
},
|
||||
else => return err,
|
||||
}
|
||||
var success_count: usize = 0;
|
||||
for (runs_to_sync.items) |run_info| {
|
||||
if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{run_info.run_id[0..8]});
|
||||
syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| {
|
||||
if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ run_info.run_id[0..8], err });
|
||||
continue;
|
||||
};
|
||||
defer allocator.free(message);
|
||||
|
||||
// Parse JSON progress message using shared utilities
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch {
|
||||
logging.success("Sync progress: {s}\n", .{message});
|
||||
break;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
if (parsed.value == .object) {
|
||||
const root = parsed.value.object;
|
||||
const status = json.getString(root, "status") orelse "unknown";
|
||||
const current = json.getInt(root, "progress") orelse 0;
|
||||
const total = json.getInt(root, "total") orelse 100;
|
||||
|
||||
if (std.mem.eql(u8, status, "complete")) {
|
||||
progress.finish();
|
||||
colors.printSuccess("Sync complete!\n", .{});
|
||||
break;
|
||||
} else if (std.mem.eql(u8, status, "error")) {
|
||||
const error_msg = json.getString(root, "error") orelse "Unknown error";
|
||||
colors.printError("Sync failed: {s}\n", .{error_msg});
|
||||
return error.SyncFailed;
|
||||
} else {
|
||||
// Update progress bar
|
||||
progress.total = @intCast(total);
|
||||
progress.update(@intCast(current));
|
||||
}
|
||||
} else {
|
||||
logging.success("Sync progress: {s}\n", .{message});
|
||||
break;
|
||||
}
|
||||
const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;";
|
||||
const update_stmt = try database.prepare(update_sql);
|
||||
defer db.DB.finalize(update_stmt);
|
||||
try db.DB.bindText(update_stmt, 1, run_info.run_id);
|
||||
_ = try db.DB.step(update_stmt);
|
||||
success_count += 1;
|
||||
}
|
||||
|
||||
if (timeout_counter >= max_timeout) {
|
||||
std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{});
|
||||
database.checkpointOnExit();
|
||||
|
||||
if (flags.json) {
|
||||
std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len });
|
||||
} else {
|
||||
colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len });
|
||||
}
|
||||
}
|
||||
|
||||
/// A single row from the local `ml_runs` table that still needs syncing.
/// All string fields are heap copies owned by the struct; release them with
/// `deinit()`.
const RunInfo = struct {
    run_id: []const u8,
    experiment_id: []const u8,
    name: []const u8,
    status: []const u8,
    start_time: []const u8,
    /// Null when the run has not finished (stored as an empty column).
    end_time: ?[]const u8,

    /// Build a RunInfo from the current row of a prepared statement,
    /// duplicating every column so the struct outlives the statement.
    /// Fix: `errdefer` each copy so a failed later allocation no longer
    /// leaks the earlier ones on the error path.
    fn fromStmt(stmt: *anyopaque, allocator: std.mem.Allocator) !RunInfo {
        const run_id = try allocator.dupe(u8, db.DB.columnText(stmt, 0));
        errdefer allocator.free(run_id);
        const experiment_id = try allocator.dupe(u8, db.DB.columnText(stmt, 1));
        errdefer allocator.free(experiment_id);
        const name = try allocator.dupe(u8, db.DB.columnText(stmt, 2));
        errdefer allocator.free(name);
        const status = try allocator.dupe(u8, db.DB.columnText(stmt, 3));
        errdefer allocator.free(status);
        const start_time = try allocator.dupe(u8, db.DB.columnText(stmt, 4));
        errdefer allocator.free(start_time);
        // Empty text in column 5 means "no end time yet".
        const end_raw = db.DB.columnText(stmt, 5);
        const end_time: ?[]const u8 = if (end_raw.len > 0)
            try allocator.dupe(u8, end_raw)
        else
            null;
        return RunInfo{
            .run_id = run_id,
            .experiment_id = experiment_id,
            .name = name,
            .status = status,
            .start_time = start_time,
            .end_time = end_time,
        };
    }

    /// Release every string owned by this RunInfo.
    fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void {
        allocator.free(self.run_id);
        allocator.free(self.experiment_id);
        allocator.free(self.name);
        allocator.free(self.status);
        allocator.free(self.start_time);
        if (self.end_time) |et| allocator.free(et);
    }
};
|
||||
|
||||
/// Push one run (with its metrics and params) to the server over the
/// WebSocket client and wait for the server's `sync_ack` response.
///
/// Reads `ml_metrics` and `ml_params` rows for `run_info.run_id` from the
/// local database, serializes everything into a single JSON object, sends
/// it as a `sync_run` message, and returns error.SyncRejected when the
/// reply does not contain "sync_ack".
fn syncRun(
    allocator: std.mem.Allocator,
    database: *db.DB,
    client: *ws.Client,
    run_info: RunInfo,
    api_key_hash: []const u8,
) !void {
    // Collect metrics for this run from the local SQLite store.
    var metrics = std.ArrayList(Metric).init(allocator);
    defer {
        for (metrics.items) |*m| m.deinit(allocator);
        metrics.deinit();
    }

    const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
    const metrics_stmt = try database.prepare(metrics_sql);
    defer db.DB.finalize(metrics_stmt);
    try db.DB.bindText(metrics_stmt, 1, run_info.run_id);

    while (try db.DB.step(metrics_stmt)) {
        // Fix: errdefer the key copy so it is not leaked if append() fails
        // before the metric is owned by the list.
        const metric_key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0));
        errdefer allocator.free(metric_key);
        try metrics.append(Metric{
            .key = metric_key,
            .value = db.DB.columnDouble(metrics_stmt, 1),
            .step = db.DB.columnInt64(metrics_stmt, 2),
        });
    }

    // Collect params for this run.
    var params = std.ArrayList(Param).init(allocator);
    defer {
        for (params.items) |*p| p.deinit(allocator);
        params.deinit();
    }

    const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
    const params_stmt = try database.prepare(params_sql);
    defer db.DB.finalize(params_stmt);
    try db.DB.bindText(params_stmt, 1, run_info.run_id);

    while (try db.DB.step(params_stmt)) {
        // Same leak guard as for metrics: copies are errdefer'd until the
        // list takes ownership.
        const param_key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0));
        errdefer allocator.free(param_key);
        const param_value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1));
        errdefer allocator.free(param_value);
        try params.append(Param{ .key = param_key, .value = param_value });
    }

    // Build the sync payload as one JSON object.
    // NOTE(review): string fields are interpolated verbatim; a name or param
    // value containing '"' or '\' would produce invalid JSON. Consider
    // std.json.stringify for proper escaping — confirm inputs are sanitized.
    var sync_json = std.ArrayList(u8).init(allocator);
    defer sync_json.deinit();
    // Fix: a managed ArrayList's writer() takes no allocator argument — the
    // list already carries the allocator from init(). Passing one here does
    // not compile against the managed API used throughout this function.
    const writer = sync_json.writer();

    try writer.writeAll("{");
    try writer.print("\"run_id\":\"{s}\",", .{run_info.run_id});
    try writer.print("\"experiment_id\":\"{s}\",", .{run_info.experiment_id});
    try writer.print("\"name\":\"{s}\",", .{run_info.name});
    try writer.print("\"status\":\"{s}\",", .{run_info.status});
    try writer.print("\"start_time\":\"{s}\",", .{run_info.start_time});
    if (run_info.end_time) |et| {
        try writer.print("\"end_time\":\"{s}\",", .{et});
    } else {
        try writer.writeAll("\"end_time\":null,");
    }

    // Metrics array.
    try writer.writeAll("\"metrics\":[");
    for (metrics.items, 0..) |m, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.print("{{\"key\":\"{s}\",\"value\":{d},\"step\":{d}}}", .{ m.key, m.value, m.step });
    }
    try writer.writeAll("],");

    // Params array.
    try writer.writeAll("\"params\":[");
    for (params.items, 0..) |p, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.print("{{\"key\":\"{s}\",\"value\":\"{s}\"}}", .{ p.key, p.value });
    }
    try writer.writeAll("]}");

    // Send the run and wait for the server's acknowledgement.
    try client.sendSyncRun(sync_json.items, api_key_hash);

    const response = try client.receiveMessage(allocator);
    defer allocator.free(response);

    // NOTE(review): substring match is fragile — any payload that merely
    // mentions "sync_ack" is accepted. Parsing the response JSON and checking
    // its message type would be stricter; verify the server's reply format.
    if (std.mem.indexOf(u8, response, "sync_ack") == null) {
        return error.SyncRejected;
    }
}
|
||||
|
||||
/// One metric sample loaded from the `ml_metrics` table.
/// `key` is a heap copy owned by this struct; `value` and `step` come
/// straight from the row's numeric columns.
const Metric = struct {
    key: []const u8,
    value: f64,
    step: i64,

    /// Free the heap-allocated key; the numeric fields need no cleanup.
    fn deinit(self: *Metric, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
    }
};
|
||||
|
||||
/// One parameter key/value pair loaded from the `ml_params` table.
/// Both strings are heap copies owned by this struct.
const Param = struct {
    key: []const u8,
    value: []const u8,

    /// Free both heap-allocated strings.
    fn deinit(self: *Param, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
        allocator.free(self.value);
    }
};
|
||||
|
||||
/// Print the usage/help text for `ml sync [run_id]` to stderr.
/// Output is byte-identical to printing each line in a separate call.
fn printUsage() void {
    const help_lines = [_][]const u8{
        "Usage: ml sync [run_id] [options]\n\n",
        "Push local experiment runs to the server.\n\n",
        "Options:\n",
        "  --json       Output structured JSON\n",
        "  --help, -h   Show this help message\n\n",
        "Examples:\n",
        "  ml sync              # Sync all unsynced runs\n",
        "  ml sync abc123       # Sync specific run\n",
    };
    for (help_lines) |line| {
        std.debug.print("{s}", .{line});
    }
}
|
||||
|
||||
/// Find the git root directory by walking up from the given path
|
||||
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
|
||||
var buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
|
|
|
|||
|
|
@ -1,110 +1,83 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const config = @import("../config.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const rsync = @import("../utils/rsync_embedded.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const core = @import("../core.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var should_sync = false;
|
||||
const sync_interval: u64 = 30; // Default 30 seconds
|
||||
|
||||
if (args.len == 0) {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
// Global flags
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const path = args[0];
|
||||
var job_name: ?[]const u8 = null;
|
||||
var priority: u8 = 5;
|
||||
var should_queue = false;
|
||||
var json: bool = false;
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
|
||||
job_name = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
|
||||
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--queue")) {
|
||||
should_queue = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--json")) {
|
||||
json = true;
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, arg, "--sync")) {
|
||||
should_sync = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
}
|
||||
}
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
if (json) {
|
||||
std.debug.print("{\"ok\":true,\"action\":\"watch\",\"path\":\"{s}\",\"queued\":{s}}\n", .{ path, if (should_queue) "true" else "false" });
|
||||
} else {
|
||||
std.debug.print("Watching {s} for changes...\n", .{path});
|
||||
std.debug.print("Press Ctrl+C to stop\n", .{});
|
||||
}
|
||||
|
||||
// Initial sync
|
||||
var last_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
|
||||
defer allocator.free(last_commit_id);
|
||||
|
||||
// Watch for changes
|
||||
var watcher = try std.fs.cwd().openDir(path, .{ .iterate = true });
|
||||
defer watcher.close();
|
||||
|
||||
var last_modified: u64 = 0;
|
||||
|
||||
while (true) {
|
||||
// Check for file changes
|
||||
var modified = false;
|
||||
var walker = try watcher.walk(allocator);
|
||||
defer walker.deinit();
|
||||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind == .file) {
|
||||
const file = try watcher.openFile(entry.path, .{});
|
||||
defer file.close();
|
||||
|
||||
const stat = try file.stat();
|
||||
if (stat.mtime > last_modified) {
|
||||
last_modified = @intCast(stat.mtime);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
// Check mode if syncing
|
||||
if (should_sync) {
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml watch --sync requires server connection\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
if (!json) {
|
||||
std.debug.print("\nChanges detected, syncing...\n", .{});
|
||||
}
|
||||
if (flags.json) {
|
||||
std.debug.print("{{\"ok\":true,\"action\":\"watch\",\"sync\":{s}}}\n", .{if (should_sync) "true" else "false"});
|
||||
} else {
|
||||
if (should_sync) {
|
||||
colors.printInfo("Watching for changes with auto-sync every {d}s...\n", .{sync_interval});
|
||||
} else {
|
||||
colors.printInfo("Watching directory for changes...\n", .{});
|
||||
}
|
||||
colors.printInfo("Press Ctrl+C to stop\n", .{});
|
||||
}
|
||||
|
||||
const new_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
|
||||
defer allocator.free(new_commit_id);
|
||||
|
||||
if (!std.mem.eql(u8, last_commit_id, new_commit_id)) {
|
||||
allocator.free(last_commit_id);
|
||||
last_commit_id = try allocator.dupe(u8, new_commit_id);
|
||||
if (!json) {
|
||||
std.debug.print("✓ Synced new version: {s}\n", .{last_commit_id[0..8]});
|
||||
}
|
||||
// Watch loop
|
||||
var last_synced: i64 = 0;
|
||||
while (true) {
|
||||
if (should_sync) {
|
||||
const now = std.time.timestamp();
|
||||
if (now - last_synced >= @as(i64, @intCast(sync_interval))) {
|
||||
// Trigger sync
|
||||
const sync_cmd = @import("sync.zig");
|
||||
sync_cmd.run(allocator, &[_][]const u8{"--json"}) catch |err| {
|
||||
if (!flags.json) {
|
||||
colors.printError("Auto-sync failed: {}\n", .{err});
|
||||
}
|
||||
};
|
||||
last_synced = now;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait before checking again
|
||||
std.Thread.sleep(2_000_000_000); // 2 seconds in nanoseconds
|
||||
std.Thread.sleep(2_000_000_000); // 2 seconds
|
||||
}
|
||||
}
|
||||
|
||||
fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, config: Config) ![]u8 {
|
||||
fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, cfg: config.Config) ![]u8 {
|
||||
// Calculate commit ID
|
||||
const commit_id = try crypto.hashDirectory(allocator, path);
|
||||
|
||||
|
|
@ -112,22 +85,22 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con
|
|||
const remote_path = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/",
|
||||
.{ config.worker_user, config.worker_host, config.worker_base, commit_id },
|
||||
.{ cfg.worker_user, cfg.worker_host, cfg.worker_base, commit_id },
|
||||
);
|
||||
defer allocator.free(remote_path);
|
||||
|
||||
try rsync.sync(allocator, path, remote_path, config.worker_port);
|
||||
try rsync.sync(allocator, path, remote_path, cfg.worker_port);
|
||||
|
||||
if (should_queue) {
|
||||
const actual_job_name = job_name orelse commit_id[0..8];
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
// Connect to WebSocket and queue job
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendQueueJob(actual_job_name, commit_id, priority, api_key_hash);
|
||||
|
|
@ -144,11 +117,10 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con
|
|||
}
|
||||
|
||||
fn printUsage() void {
|
||||
std.debug.print("Usage: ml watch <path> [options]\n\n", .{});
|
||||
std.debug.print("Usage: ml watch [options]\n\n", .{});
|
||||
std.debug.print("Watch for changes and optionally auto-sync.\n\n", .{});
|
||||
std.debug.print("Options:\n", .{});
|
||||
std.debug.print(" --name <job> Override job name when used with --queue\n", .{});
|
||||
std.debug.print(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
|
||||
std.debug.print(" --queue Queue on every sync\n", .{});
|
||||
std.debug.print(" --json Emit a single JSON line describing watch start\n", .{});
|
||||
std.debug.print(" --sync Auto-sync runs to server every 30s\n", .{});
|
||||
std.debug.print(" --json Output structured JSON\n", .{});
|
||||
std.debug.print(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue