fetch_ml/cli/src/commands/queue.zig
Jeremie Fraeys 7a6d454174
refactor(cli): modularize queue.zig structure
Move configuration types to queue/mod.zig:
- TrackingConfig with MLflow, TensorBoard, Wandb sub-configs
- QueueOptions with all queue-related options

queue.zig now re-exports from queue/mod.zig for backward compatibility.
Build passes successfully.
2026-03-04 21:00:23 -05:00

1239 lines
51 KiB
Zig

const std = @import("std");
const Config = @import("../config.zig").Config;
const ws = @import("../net/ws/client.zig");
const history = @import("../utils/history.zig");
const crypto = @import("../utils/crypto.zig");
const protocol = @import("../net/protocol.zig");
const stdcrypto = std.crypto;
const mode = @import("../mode.zig");
const db = @import("../db.zig");
const manifest_lib = @import("../manifest.zig");
// Use modular queue structure
const queue_mod = @import("queue/mod.zig");
// Re-export for backward compatibility
pub const TrackingConfig = queue_mod.TrackingConfig;
pub const QueueOptions = queue_mod.QueueOptions;
pub const parse = queue_mod.parse;
pub const validate = queue_mod.validate;
pub const submit = queue_mod.submit;
/// Resolve a commit id or unique hex prefix to a full 40-char hex id.
///
/// A full 40-char input is validated and duplicated without touching the
/// filesystem. Shorter prefixes (7-39 chars) are matched against the
/// 40-char hex directory names under `base_path`; an ambiguous prefix is
/// an error. Caller owns the returned slice.
///
/// Errors: `error.InvalidArgs` on bad length, non-hex input, or an
/// ambiguous prefix; `error.FileNotFound` when nothing matches.
fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
    // Reject prefixes that are too short, too long, or not pure hex.
    if (input.len < 7 or input.len > 40) return error.InvalidArgs;
    for (input) |ch| {
        if (!std.ascii.isHex(ch)) return error.InvalidArgs;
    }
    // A full-length id needs no directory scan.
    if (input.len == 40) return allocator.dupe(u8, input);

    var commits_dir = blk: {
        if (std.fs.path.isAbsolute(base_path)) {
            break :blk try std.fs.openDirAbsolute(base_path, .{ .iterate = true });
        }
        break :blk try std.fs.cwd().openDir(base_path, .{ .iterate = true });
    };
    defer commits_dir.close();

    var match: ?[]u8 = null;
    errdefer if (match) |m| allocator.free(m);
    var walker = commits_dir.iterate();
    while (try walker.next()) |entry| {
        // Candidates are directories whose names are full 40-char hex ids.
        if (entry.kind != .directory) continue;
        if (entry.name.len != 40) continue;
        if (!std.mem.startsWith(u8, entry.name, input)) continue;
        const all_hex = for (entry.name) |ch| {
            if (!std.ascii.isHex(ch)) break false;
        } else true;
        if (!all_hex) continue;
        // A second match means the prefix is ambiguous.
        if (match != null) return error.InvalidArgs;
        match = try allocator.dupe(u8, entry.name);
    }
    return match orelse error.FileNotFound;
}
/// Entry point for `ml queue`.
///
/// Dispatches to the help text, the `--rerun` path, or the regular queue
/// path. Every server-backed path fails fast with `error.RequiresServer`
/// when mode detection reports the CLI is offline.
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
    // No arguments at all is a usage error; an explicit help flag is not.
    if (args.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    const first = args[0];
    if (std.mem.eql(u8, first, "-h") or std.mem.eql(u8, first, "--help")) {
        try printUsage();
        return;
    }
    // Load config for mode detection.
    const config = try Config.load(allocator);
    defer {
        var owned = config;
        owned.deinit(allocator);
    }
    // Detect mode early to provide a clear error for offline users.
    const mode_result = try mode.detect(allocator, config);
    const offline = mode.isOffline(mode_result.mode);
    // Scan for "--rerun <id>" anywhere in the argument list; the flag's
    // value is the argument immediately following it.
    var rerun_id: ?[]const u8 = null;
    var idx: usize = 0;
    while (idx + 1 < args.len) : (idx += 1) {
        if (std.mem.eql(u8, args[idx], "--rerun")) {
            rerun_id = args[idx + 1];
            break;
        }
    }
    // --rerun re-queues a completed run and requires a server.
    if (rerun_id) |id| {
        if (offline) {
            std.debug.print("ml queue --rerun requires server connection\n", .{});
            return error.RequiresServer;
        }
        return try handleRerun(allocator, id, args, config);
    }
    // Regular queue - requires server.
    if (offline) {
        std.debug.print("ml queue requires server connection (use 'ml run' for local execution)\n", .{});
        return error.RequiresServer;
    }
    try executeQueue(allocator, args, config);
}
/// Parse queue flags and positional job names, then queue every job.
///
/// Arguments after a literal "--" are joined into a single string and
/// forwarded to the runner as its args. The special modes --explain,
/// --validate and --dry-run short-circuit and only use the first job
/// name. Flags that take a value consume the following argument.
fn executeQueue(allocator: std.mem.Allocator, args: []const []const u8, config: Config) !void {
    // Support batch operations - multiple job names.
    var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
        std.debug.print("Failed to allocate job list: {}\n", .{err});
        return err;
    };
    defer job_names.deinit(allocator);
    var commit_id_override: ?[]const u8 = null;
    // FIX: --commit heap-allocates the decoded 20-byte commit id below and
    // it was never freed, leaking on every exit path (including the early
    // returns for --explain/--validate/--dry-run). queueSingleJob only
    // frees generated ids (commit_override == null), so freeing here
    // cannot double-free.
    defer if (commit_id_override) |cid| allocator.free(cid);
    var priority: u8 = 5;
    var snapshot_id: ?[]const u8 = null;
    var snapshot_sha256: ?[]const u8 = null;
    var args_override: ?[]const u8 = null;
    var note_override: ?[]const u8 = null;
    // Initialize options with config defaults.
    var options = QueueOptions{
        .cpu = config.default_cpu,
        .memory = config.default_memory,
        .gpu = config.default_gpu,
        .gpu_memory = config.default_gpu_memory,
        .dry_run = config.default_dry_run,
        .validate = config.default_validate,
        .json = config.default_json,
        .secrets = try std.ArrayList([]const u8).initCapacity(allocator, 4),
    };
    defer options.secrets.deinit(allocator);
    priority = config.default_priority;
    // Tracking configuration (collected here; serialized as "" below for now).
    var tracking = TrackingConfig{};
    var has_tracking = false;
    // Support passing runner args after "--".
    var sep_index: ?usize = null;
    for (args, 0..) |a, idx| {
        if (std.mem.eql(u8, a, "--")) {
            sep_index = idx;
            break;
        }
    }
    const pre = args[0..(sep_index orelse args.len)];
    const post = if (sep_index) |si| args[(si + 1)..] else args[0..0];
    var args_joined: []const u8 = "";
    if (post.len > 0) {
        var buf: std.ArrayList(u8) = .{};
        defer buf.deinit(allocator);
        for (post, 0..) |a, j| {
            if (j > 0) try buf.append(allocator, ' ');
            try buf.appendSlice(allocator, a);
        }
        args_joined = try buf.toOwnedSlice(allocator);
    }
    defer if (post.len > 0) allocator.free(args_joined);
    // Parse arguments - separate job names from flags.
    // NOTE(review): unrecognized "--" flags are silently ignored here;
    // confirm that is the intended UX.
    var i: usize = 0;
    while (i < pre.len) : (i += 1) {
        const arg = pre[i];
        if (std.mem.startsWith(u8, arg, "--") or std.mem.eql(u8, arg, "-h")) {
            // Parse flags.
            if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
                try printUsage();
                return;
            }
            if (std.mem.eql(u8, arg, "--commit") and i + 1 < pre.len) {
                // A repeated --commit replaces the previous decoded value.
                // FIX: null the stale pointer immediately so the cleanup
                // defer above cannot double-free if one of the error
                // returns below fires before reassignment.
                if (commit_id_override != null) {
                    allocator.free(commit_id_override.?);
                    commit_id_override = null;
                }
                const commit_in = pre[i + 1];
                const commit_hex = resolveCommitHexOrPrefix(allocator, config.worker_base, commit_in) catch |err| {
                    if (err == error.FileNotFound) {
                        std.debug.print("No commit matches prefix: {s}\n", .{commit_in});
                        return error.InvalidArgs;
                    }
                    std.debug.print("Invalid commit id\n", .{});
                    return error.InvalidArgs;
                };
                defer allocator.free(commit_hex);
                const commit_bytes = crypto.decodeHex(allocator, commit_hex) catch {
                    std.debug.print("Invalid commit id: must be hex\n", .{});
                    return error.InvalidArgs;
                };
                if (commit_bytes.len != 20) {
                    allocator.free(commit_bytes);
                    std.debug.print("Invalid commit id: expected 20 bytes\n", .{});
                    return error.InvalidArgs;
                }
                commit_id_override = commit_bytes;
                i += 1;
            } else if (std.mem.eql(u8, arg, "--priority") and i + 1 < pre.len) {
                priority = try std.fmt.parseInt(u8, pre[i + 1], 10);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--mlflow")) {
                tracking.mlflow = TrackingConfig.MLflowConfig{};
                has_tracking = true;
            } else if (std.mem.eql(u8, arg, "--mlflow-uri") and i + 1 < pre.len) {
                tracking.mlflow = TrackingConfig.MLflowConfig{
                    .mode = "remote",
                    .tracking_uri = pre[i + 1],
                };
                has_tracking = true;
                i += 1;
            } else if (std.mem.eql(u8, arg, "--tensorboard")) {
                tracking.tensorboard = TrackingConfig.TensorBoardConfig{};
                has_tracking = true;
            } else if (std.mem.eql(u8, arg, "--wandb-key") and i + 1 < pre.len) {
                if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
                tracking.wandb.?.api_key = pre[i + 1];
                has_tracking = true;
                i += 1;
            } else if (std.mem.eql(u8, arg, "--wandb-project") and i + 1 < pre.len) {
                if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
                tracking.wandb.?.project = pre[i + 1];
                has_tracking = true;
                i += 1;
            } else if (std.mem.eql(u8, arg, "--wandb-entity") and i + 1 < pre.len) {
                if (tracking.wandb == null) tracking.wandb = TrackingConfig.WandbConfig{};
                tracking.wandb.?.entity = pre[i + 1];
                has_tracking = true;
                i += 1;
            } else if (std.mem.eql(u8, arg, "--dry-run")) {
                options.dry_run = true;
            } else if (std.mem.eql(u8, arg, "--validate")) {
                options.validate = true;
            } else if (std.mem.eql(u8, arg, "--explain")) {
                options.explain = true;
            } else if (std.mem.eql(u8, arg, "--json")) {
                options.json = true;
            } else if (std.mem.eql(u8, arg, "--force")) {
                options.force = true;
            } else if (std.mem.eql(u8, arg, "--cpu") and i + 1 < pre.len) {
                options.cpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--memory") and i + 1 < pre.len) {
                options.memory = try std.fmt.parseInt(u8, pre[i + 1], 10);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--gpu") and i + 1 < pre.len) {
                options.gpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--gpu-memory") and i + 1 < pre.len) {
                options.gpu_memory = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--snapshot-id") and i + 1 < pre.len) {
                snapshot_id = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--snapshot-sha256") and i + 1 < pre.len) {
                snapshot_sha256 = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--args") and i + 1 < pre.len) {
                args_override = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) {
                note_override = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--hypothesis") and i + 1 < pre.len) {
                options.hypothesis = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--context") and i + 1 < pre.len) {
                options.context = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--intent") and i + 1 < pre.len) {
                options.intent = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--expected-outcome") and i + 1 < pre.len) {
                options.expected_outcome = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--experiment-group") and i + 1 < pre.len) {
                options.experiment_group = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--tags") and i + 1 < pre.len) {
                options.tags = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--network") and i + 1 < pre.len) {
                options.network_mode = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--read-only")) {
                options.read_only = true;
            } else if (std.mem.eql(u8, arg, "--secret") and i + 1 < pre.len) {
                try options.secrets.append(allocator, pre[i + 1]);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--reservation") and i + 1 < pre.len) {
                options.reservation_id = pre[i + 1];
                i += 1;
            } else if (std.mem.eql(u8, arg, "--gang-size") and i + 1 < pre.len) {
                options.gang_size = try std.fmt.parseInt(u32, pre[i + 1], 10);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--max-wait") and i + 1 < pre.len) {
                options.max_wait_time = try std.fmt.parseInt(u32, pre[i + 1], 10);
                i += 1;
            } else if (std.mem.eql(u8, arg, "--preemptible")) {
                options.preemptible = true;
            } else if (std.mem.eql(u8, arg, "--worker") and i + 1 < pre.len) {
                options.preferred_worker = pre[i + 1];
                i += 1;
            }
        } else {
            // This is a job name.
            job_names.append(allocator, arg) catch |err| {
                std.debug.print("Failed to append job: {}\n", .{err});
                return err;
            };
        }
    }
    if (job_names.items.len == 0) {
        std.debug.print("No job names specified\n", .{});
        return error.InvalidArgs;
    }
    // Per-job next-step hints only make sense for a single non-JSON job.
    const print_next_steps = (!options.json) and (job_names.items.len == 1);
    // Handle special modes - each one returns without queueing anything.
    if (options.explain) {
        try explainJob(allocator, job_names.items[0], commit_id_override, priority, &options);
        return;
    }
    if (options.validate) {
        try validateJob(allocator, job_names.items[0], commit_id_override, &options);
        return;
    }
    if (options.dry_run) {
        try dryRunJob(allocator, job_names.items[0], commit_id_override, priority, &options);
        return;
    }
    std.debug.print("Queueing {d} job(s)...\n", .{job_names.items.len});
    // Generate tracking JSON if needed (simplified for now).
    const tracking_json: []const u8 = "";
    // Process each job; a failure on one job does not abort the batch.
    var success_count: usize = 0;
    var failed_jobs = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
        std.debug.print("Failed to allocate failed jobs list: {}\n", .{err});
        return err;
    };
    defer failed_jobs.deinit(allocator);
    const args_str: []const u8 = if (args_override) |a| a else args_joined;
    const note_str: []const u8 = if (note_override) |n| n else "";
    for (job_names.items, 0..) |job_name, index| {
        std.debug.print("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name });
        queueSingleJob(
            allocator,
            job_name,
            commit_id_override,
            priority,
            tracking_json,
            &options,
            snapshot_id,
            snapshot_sha256,
            args_str,
            note_str,
            print_next_steps,
        ) catch |err| {
            std.debug.print("Failed to queue job '{s}': {}\n", .{ job_name, err });
            failed_jobs.append(allocator, job_name) catch |append_err| {
                std.debug.print("Failed to track failed job: {}\n", .{append_err});
            };
            continue;
        };
        std.debug.print("Successfully queued job '{s}'\n", .{job_name});
        success_count += 1;
    }
    // Show summary.
    std.debug.print("Batch queuing complete.\n", .{});
    std.debug.print("Successfully queued: {d} job(s)\n", .{success_count});
    if (failed_jobs.items.len > 0) {
        std.debug.print("Failed to queue: {d} job(s)\n", .{failed_jobs.items.len});
        for (failed_jobs.items) |failed_job| {
            std.debug.print(" - {s}\n", .{failed_job});
        }
    }
    if (!options.json and success_count > 0 and job_names.items.len > 1) {
        std.debug.print("\nNext steps:\n", .{});
        std.debug.print(" ml status --watch\n", .{});
    }
}
/// Handle --rerun flag: re-queue a completed run.
///
/// Sends the run id to the server over WebSocket and reports the
/// server's verdict. Returns `error.RerunFailed` when the response does
/// not contain "success".
fn handleRerun(allocator: std.mem.Allocator, run_id: []const u8, args: []const []const u8, cfg: Config) !void {
    _ = args; // Override args not implemented yet
    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);
    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();
    // Send rerun request to server
    try client.sendRerunRequest(run_id, api_key_hash);
    // Wait for response
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);
    // Parse response (simplified): any payload containing "success" is
    // treated as a positive acknowledgement.
    if (std.mem.indexOf(u8, message, "success") != null) {
        // FIX: run_id is user-supplied and may be shorter than 8 bytes;
        // slicing [0..8] unconditionally would panic out of bounds.
        const short_id = run_id[0..@min(run_id.len, 8)];
        std.debug.print("Re-queued run {s}\n", .{short_id});
    } else {
        std.debug.print("Failed to re-queue: {s}\n", .{message});
        return error.RerunFailed;
    }
}
/// Generate a random 20-byte commit id (the size of a git SHA-1 hash).
/// Caller owns the returned slice.
fn generateCommitID(allocator: std.mem.Allocator) ![]const u8 {
    var raw: [20]u8 = undefined;
    std.crypto.random.bytes(&raw);
    return allocator.dupe(u8, &raw);
}
/// Queue one job on the server over WebSocket, picking the protocol
/// message variant that matches the metadata supplied (tracking JSON,
/// args/note, snapshot, or plain resources), then handle the server's
/// response, including structured packets and duplicate detection.
///
/// Ownership: `commit_override` is borrowed from the caller; when it is
/// null a random commit id is generated and freed locally. All other
/// slice parameters are borrowed.
fn queueSingleJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_override: ?[]const u8,
    priority: u8,
    tracking_json: []const u8,
    options: *const QueueOptions,
    snapshot_id: ?[]const u8,
    snapshot_sha256: ?[]const u8,
    args_str: []const u8,
    note_str: []const u8,
    print_next_steps: bool,
) !void {
    // Use the caller's commit id when given, otherwise mint a random one.
    const commit_id = blk: {
        if (commit_override) |cid| break :blk cid;
        const generated = try generateCommitID(allocator);
        break :blk generated;
    };
    // Only the generated id is owned here; an override belongs to the caller.
    defer if (commit_override == null) allocator.free(commit_id);
    // Build narrative JSON if any narrative fields are set.
    // NOTE(review): errors from buildNarrativeJson are collapsed to null,
    // silently dropping the narrative - confirm this is intentional.
    const narrative_json = buildNarrativeJson(allocator, options) catch null;
    defer if (narrative_json) |j| allocator.free(j);
    const config = try Config.load(allocator);
    defer {
        var mut_config = config;
        mut_config.deinit(allocator);
    }
    const commit_hex = try crypto.encodeHexLower(allocator, commit_id);
    defer allocator.free(commit_hex);
    const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
    defer allocator.free(api_key_hash);
    // Check for existing job with same commit (incremental queue);
    // --force skips this check entirely.
    if (!options.force) {
        const existing = try checkExistingJob(allocator, job_name, commit_id, api_key_hash, config);
        if (existing) |ex| {
            defer allocator.free(ex);
            // Server already has this job - handle duplicate response
            try handleDuplicateResponse(allocator, ex, job_name, commit_hex, options);
            return;
        }
    }
    std.debug.print("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
    // Connect to WebSocket and send queue message
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, config.api_key);
    defer client.close();
    // Snapshot id and sha256 are only meaningful as a pair.
    if ((snapshot_id != null) != (snapshot_sha256 != null)) {
        std.debug.print("Both --snapshot-id and --snapshot-sha256 must be set\n", .{});
        return error.InvalidArgs;
    }
    if (snapshot_id != null and tracking_json.len > 0) {
        std.debug.print("Snapshot queueing is not supported with tracking yet\n", .{});
        return error.InvalidArgs;
    }
    // Build combined metadata JSON with tracking and/or narrative.
    // Four cases: both, tracking only, narrative only, neither ("").
    const combined_json = blk: {
        if (tracking_json.len > 0 and narrative_json != null) {
            // Merge tracking and narrative
            var buf = try std.ArrayList(u8).initCapacity(allocator, 256);
            defer buf.deinit(allocator);
            const writer = buf.writer(allocator);
            try writer.writeAll("{");
            try writer.writeAll(tracking_json[1 .. tracking_json.len - 1]); // Remove outer braces
            try writer.writeAll(",");
            try writer.writeAll("\"narrative\":");
            try writer.writeAll(narrative_json.?);
            try writer.writeAll("}");
            break :blk try buf.toOwnedSlice(allocator);
        } else if (tracking_json.len > 0) {
            break :blk try allocator.dupe(u8, tracking_json);
        } else if (narrative_json) |nj| {
            var buf = try std.ArrayList(u8).initCapacity(allocator, 256);
            defer buf.deinit(allocator);
            const writer = buf.writer(allocator);
            try writer.writeAll("{\"narrative\":");
            try writer.writeAll(nj);
            try writer.writeAll("}");
            break :blk try buf.toOwnedSlice(allocator);
        } else {
            break :blk "";
        }
    };
    // Free only the allocated variants; the "" literal case has len 0.
    defer if (combined_json.len > 0 and combined_json.ptr != tracking_json.ptr) allocator.free(combined_json);
    // Select the protocol message variant based on which metadata exists;
    // precedence: combined metadata > args/note > snapshot > plain.
    if (combined_json.len > 0) {
        try client.sendQueueJobWithTrackingAndResources(
            job_name,
            commit_id,
            priority,
            api_key_hash,
            combined_json,
            options.cpu,
            options.memory,
            options.gpu,
            options.gpu_memory,
        );
    } else if (note_str.len > 0 or args_str.len > 0) {
        if (note_str.len > 0) {
            try client.sendQueueJobWithArgsNoteAndResources(
                job_name,
                commit_id,
                priority,
                api_key_hash,
                args_str,
                note_str,
                options.force,
                options.cpu,
                options.memory,
                options.gpu,
                options.gpu_memory,
            );
        } else {
            try client.sendQueueJobWithArgsAndResources(
                job_name,
                commit_id,
                priority,
                api_key_hash,
                args_str,
                options.force,
                options.cpu,
                options.memory,
                options.gpu,
                options.gpu_memory,
            );
        }
    } else if (snapshot_id) |sid| {
        try client.sendQueueJobWithSnapshotAndResources(
            job_name,
            commit_id,
            priority,
            api_key_hash,
            sid,
            snapshot_sha256.?,
            options.cpu,
            options.memory,
            options.gpu,
            options.gpu_memory,
        );
    } else {
        try client.sendQueueJobWithResources(
            job_name,
            commit_id,
            priority,
            api_key_hash,
            options.cpu,
            options.memory,
            options.gpu,
            options.gpu_memory,
        );
    }
    // Receive and handle response with duplicate detection
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);
    // Try to parse as structured packet first
    const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
        // Fallback: handle as plain text/JSON
        if (message.len > 0 and message[0] == '{') {
            try handleDuplicateResponse(allocator, message, job_name, commit_hex, options);
        } else {
            std.debug.print("Server response: {s}\n", .{message});
        }
        return;
    };
    defer packet.deinit(allocator);
    switch (packet.packet_type) {
        .success => {
            // History recording is best-effort: warn but do not fail the queue.
            history.record(allocator, job_name, commit_hex) catch |err| {
                std.debug.print("Warning: failed to record job in history ({})\n", .{err});
            };
            if (options.json) {
                std.debug.print("{{\"success\":true,\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"status\":\"queued\"}}\n", .{ job_name, commit_hex });
            } else {
                std.debug.print("Job queued: {s}\n", .{job_name});
                if (print_next_steps) {
                    const next_steps = try formatNextSteps(allocator, job_name, commit_hex);
                    defer allocator.free(next_steps);
                    std.debug.print("{s}\n", .{next_steps});
                }
            }
        },
        .error_packet => {
            const err_msg = packet.error_message orelse "Unknown error";
            if (options.json) {
                std.debug.print("{{\"success\":false,\"error\":\"{s}\"}}\n", .{err_msg});
            } else {
                std.debug.print("Error: {s}\n", .{err_msg});
            }
            return error.ServerError;
        },
        else => {
            // Unknown packet types are delegated to the client's generic
            // handler, then treated as success for history/next-steps.
            try client.handleResponsePacket(packet, "Job queue");
            history.record(allocator, job_name, commit_hex) catch |err| {
                std.debug.print("Warning: failed to record job in history ({})\n", .{err});
            };
            if (print_next_steps) {
                const next_steps = try formatNextSteps(allocator, job_name, commit_hex);
                defer allocator.free(next_steps);
                std.debug.print("{s}\n", .{next_steps});
            }
        },
    }
}
/// Print the `ml queue` help text to stderr.
///
/// FIX: the old text listed --dry-run and --json twice (Options and
/// Special Modes, verbatim duplicates) and documented "--explain
/// <reason>" even though the parser treats --explain as a plain flag
/// with no argument. It also omitted --hypothesis, --args, --note and
/// the resource flags, all of which the parser accepts.
fn printUsage() !void {
    std.debug.print("Usage: ml queue [options] <job_name> [job_name2 ...]\n\n", .{});
    std.debug.print("Options:\n", .{});
    std.debug.print("\t--priority <1-10>\tJob priority (default: 5)\n", .{});
    std.debug.print("\t--commit <hex>\t\tSpecific commit to run\n", .{});
    std.debug.print("\t--snapshot-id <id>\tSnapshot ID to use\n", .{});
    std.debug.print("\t--snapshot-sha256 <sha>\tSnapshot SHA256 to use\n", .{});
    std.debug.print("\t--args <str>\t\tOverride runner arguments\n", .{});
    std.debug.print("\t--note <text>\t\tAttach a note to the job\n", .{});
    std.debug.print("\t--help, -h\t\tShow this help message\n", .{});
    std.debug.print("\t--hypothesis <text>\tHypothesis this experiment tests\n", .{});
    std.debug.print("\t--context <text>\tBackground context for this experiment\n", .{});
    std.debug.print("\t--intent <text>\t\tWhat you're trying to accomplish\n", .{});
    std.debug.print("\t--expected-outcome <text>\tWhat you expect to happen\n", .{});
    std.debug.print("\t--experiment-group <name>\tGroup related experiments\n", .{});
    std.debug.print("\t--tags <csv>\t\tComma-separated tags (e.g., ablation,batch-size)\n", .{});
    std.debug.print("\nResources:\n", .{});
    std.debug.print("\t--cpu <n>\t\tCPU cores to request\n", .{});
    std.debug.print("\t--memory <gb>\t\tMemory (GB) to request\n", .{});
    std.debug.print("\t--gpu <n>\t\tGPU devices to request\n", .{});
    std.debug.print("\t--gpu-memory <spec>\tGPU memory to request\n", .{});
    std.debug.print("\nSpecial Modes:\n", .{});
    std.debug.print("\t--rerun <run_id>\tRe-queue a completed local run to server\n", .{});
    std.debug.print("\t--dry-run\t\tShow what would be queued\n", .{});
    std.debug.print("\t--validate\t\tValidate experiment without queuing\n", .{});
    std.debug.print("\t--explain\t\tExplain what will happen\n", .{});
    std.debug.print("\t--json\t\t\tOutput structured JSON\n", .{});
    std.debug.print("\t--force\t\t\tQueue even if duplicate exists\n", .{});
    std.debug.print("\nTracking:\n", .{});
    std.debug.print("\t--mlflow\t\tEnable MLflow (sidecar)\n", .{});
    std.debug.print("\t--mlflow-uri <uri>\tEnable MLflow (remote)\n", .{});
    std.debug.print("\t--tensorboard\t\tEnable TensorBoard\n", .{});
    std.debug.print("\t--wandb-key <key>\tEnable Wandb with API key\n", .{});
    std.debug.print("\t--wandb-project <prj>\tSet Wandb project\n", .{});
    std.debug.print("\t--wandb-entity <ent>\tSet Wandb entity\n", .{});
    std.debug.print("\nSandboxing:\n", .{});
    std.debug.print("\t--network <mode>\tNetwork mode: none, bridge, slirp4netns\n", .{});
    std.debug.print("\t--read-only\t\tMount root filesystem as read-only\n", .{});
    std.debug.print("\t--secret <name>\t\tInject secret as env var (can repeat)\n", .{});
    std.debug.print("\nScheduler Options:\n", .{});
    std.debug.print("\t--reservation <id>\tUse existing GPU reservation\n", .{});
    std.debug.print("\t--gang-size <n>\t\tRequest gang scheduling for multi-node jobs\n", .{});
    std.debug.print("\t--max-wait <min>\tMaximum wait time before failing\n", .{});
    std.debug.print("\t--preemptible\t\tAllow job to be preempted\n", .{});
    std.debug.print("\t--worker <id>\t\tPrefer specific worker\n", .{});
    std.debug.print("\nExamples:\n", .{});
    std.debug.print("\tml queue my_job\t\t\t # Queue a job\n", .{});
    std.debug.print("\tml queue my_job --dry-run\t # Preview submission\n", .{});
    std.debug.print("\tml queue my_job --validate\t # Validate locally\n", .{});
    std.debug.print("\tml queue --rerun abc123\t # Re-queue completed run\n", .{});
    std.debug.print("\tml status --watch\t\t # Watch queue + prewarm\n", .{});
    std.debug.print("\nResearch Examples:\n", .{});
    std.debug.print("\tml queue train.py --hypothesis 'LR scaling improves convergence'\n", .{});
    std.debug.print("\t\t--context 'Following paper XYZ' --tags ablation,lr-scaling\n", .{});
}
/// Build the "Next steps" hint shown after a successful queue.
/// Caller owns the returned slice.
pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 {
    return std.fmt.allocPrint(
        allocator,
        "Next steps:\n\tml status --watch\n\tml cancel {s}\n\tml validate {s}\n",
        .{ job_name, commit_hex },
    );
}
/// Explain what queueing `job_name` would do, without contacting the
/// server: prints job name, commit, priority, requested resources, and
/// any research-narrative fields, either as a single JSON object
/// (options.json) or as human-readable text on stderr.
fn explainJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_override: ?[]const u8,
    priority: u8,
    options: *const QueueOptions,
) !void {
    // With no override, show a placeholder; otherwise hex-encode the
    // override and keep ownership of the encoding for cleanup.
    var commit_display: []const u8 = "current-git-head";
    var commit_display_owned: ?[]u8 = null;
    defer if (commit_display_owned) |b| allocator.free(b);
    if (commit_override) |cid| {
        const enc = try crypto.encodeHexLower(allocator, cid);
        commit_display_owned = enc;
        commit_display = enc;
    }
    // Build narrative JSON for display.
    // NOTE(review): buildNarrativeJson errors collapse to null here.
    const narrative_json = buildNarrativeJson(allocator, options) catch null;
    defer if (narrative_json) |j| allocator.free(j);
    if (options.json) {
        // Stream the JSON object to stdout in pieces: the fixed prefix
        // opens two braces (outer object + "resources"); gpu_memory is
        // written as a JSON string or null; the tail closes both.
        const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
        var buffer: [4096]u8 = undefined;
        const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"explain\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch |err| {
            std.log.err("Failed to format output: {}", .{err});
            return error.FormatError;
        };
        try stdout_file.writeAll(formatted);
        try writeJSONNullableString(&stdout_file, options.gpu_memory);
        if (narrative_json) |nj| {
            // Close "resources", append "narrative", close the outer object.
            try stdout_file.writeAll("},\"narrative\":");
            try stdout_file.writeAll(nj);
            try stdout_file.writeAll("}\n");
        } else {
            // Close "resources" and the outer object.
            try stdout_file.writeAll("}}\n");
        }
        return;
    } else {
        std.debug.print("Job Explanation:\n", .{});
        std.debug.print("\tJob Name: {s}\n", .{job_name});
        std.debug.print("\tCommit ID: {s}\n", .{commit_display});
        std.debug.print("\tPriority: {d}\n", .{priority});
        std.debug.print("\tResources Requested:\n", .{});
        std.debug.print("\t\tCPU: {d} cores\n", .{options.cpu});
        std.debug.print("\t\tMemory: {d} GB\n", .{options.memory});
        std.debug.print("\t\tGPU: {d} device(s)\n", .{options.gpu});
        std.debug.print("\t\tGPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
        // Display narrative if provided
        if (narrative_json != null) {
            std.debug.print("\n\tResearch Narrative:\n", .{});
            if (options.hypothesis) |h| {
                std.debug.print("\t\tHypothesis: {s}\n", .{h});
            }
            if (options.context) |c| {
                std.debug.print("\t\tContext: {s}\n", .{c});
            }
            if (options.intent) |i| {
                std.debug.print("\t\tIntent: {s}\n", .{i});
            }
            if (options.expected_outcome) |eo| {
                std.debug.print("\t\tExpected Outcome: {s}\n", .{eo});
            }
            if (options.experiment_group) |eg| {
                std.debug.print("\t\tExperiment Group: {s}\n", .{eg});
            }
            if (options.tags) |t| {
                std.debug.print("\t\tTags: {s}\n", .{t});
            }
        }
        std.debug.print("\n Action: Job would be queued for execution\n", .{});
    }
}
/// Locally validate a job before queueing: checks that train.py and
/// requirements.txt exist in the current directory. No server contact.
fn validateJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_override: ?[]const u8,
    options: *const QueueOptions,
) !void {
    // With no override, show a placeholder; otherwise hex-encode the
    // override and keep ownership of the encoding for cleanup.
    var commit_display: []const u8 = "current-git-head";
    var commit_display_owned: ?[]u8 = null;
    defer if (commit_display_owned) |b| allocator.free(b);
    if (commit_override) |cid| {
        const enc = try crypto.encodeHexLower(allocator, cid);
        commit_display_owned = enc;
        commit_display = enc;
    }
    // Basic local validation - simplified without JSON ObjectMap for now.
    // Check if current directory has required files.
    const train_script_exists = if (std.fs.cwd().access("train.py", .{})) true else |err| switch (err) {
        error.FileNotFound => false,
        else => false, // Treat other errors as file doesn't exist
    };
    const requirements_exists = if (std.fs.cwd().access("requirements.txt", .{})) true else |err| switch (err) {
        error.FileNotFound => false,
        else => false, // Treat other errors as file doesn't exist
    };
    const overall_valid = train_script_exists and requirements_exists;
    if (options.json) {
        const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
        var buffer: [4096]u8 = undefined;
        // FIX: bufPrint can genuinely fail (e.g. NoSpaceLeft with a long
        // job name); the old `catch unreachable` made that undefined
        // behavior in release builds. Handle it like explainJob/dryRunJob do.
        const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"validate\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"checks\":{{\"train_py\":{s},\"requirements_txt\":{s}}},\"ok\":{s}}}\n", .{ job_name, commit_display, if (train_script_exists) "true" else "false", if (requirements_exists) "true" else "false", if (overall_valid) "true" else "false" }) catch |err| {
            std.log.err("Failed to format output: {}", .{err});
            return error.FormatError;
        };
        try stdout_file.writeAll(formatted);
        return;
    } else {
        std.debug.print("Validation Results:\n", .{});
        std.debug.print("\tJob Name: {s}\n", .{job_name});
        std.debug.print("\tCommit ID: {s}\n", .{commit_display});
        std.debug.print("\tRequired Files:\n", .{});
        const train_status = if (train_script_exists) "yes" else "no";
        const req_status = if (requirements_exists) "yes" else "no";
        std.debug.print("\t\ttrain.py {s}\n", .{train_status});
        std.debug.print("\t\trequirements.txt {s}\n", .{req_status});
        if (overall_valid) {
            std.debug.print("\tValidation passed - job is ready to queue\n", .{});
        } else {
            std.debug.print("\tValidation failed - missing required files\n", .{});
        }
    }
}
/// Preview what queueing `job_name` would submit, without contacting the
/// server. Prints a single JSON object (options.json) or human-readable
/// text on stderr.
fn dryRunJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_override: ?[]const u8,
    priority: u8,
    options: *const QueueOptions,
) !void {
    // With no override, show a placeholder; otherwise hex-encode the
    // override and keep ownership of the encoding for cleanup.
    var commit_display: []const u8 = "current-git-head";
    var commit_display_owned: ?[]u8 = null;
    defer if (commit_display_owned) |b| allocator.free(b);
    if (commit_override) |cid| {
        const enc = try crypto.encodeHexLower(allocator, cid);
        commit_display_owned = enc;
        commit_display = enc;
    }
    // Build narrative JSON for display
    const narrative_json = buildNarrativeJson(allocator, options) catch null;
    defer if (narrative_json) |j| allocator.free(j);
    if (options.json) {
        const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
        var buffer: [4096]u8 = undefined;
        // The fixed prefix opens exactly two braces: the outer object and
        // the "resources" object.
        const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"dry_run\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch |err| {
            std.log.err("Failed to format output: {}", .{err});
            return error.FormatError;
        };
        try stdout_file.writeAll(formatted);
        try writeJSONNullableString(&stdout_file, options.gpu_memory);
        // FIX: both tails previously ended in "}}" after the "resources"
        // object was already closed, emitting three closing braces for two
        // open ones and producing invalid JSON (explainJob has the correct
        // shape). Each tail now closes exactly the remaining outer object.
        if (narrative_json) |nj| {
            try stdout_file.writeAll("},\"narrative\":");
            try stdout_file.writeAll(nj);
            try stdout_file.writeAll(",\"would_queue\":true}\n");
        } else {
            try stdout_file.writeAll("},\"would_queue\":true}\n");
        }
        return;
    } else {
        std.debug.print("Dry Run - Job Queue Preview:\n", .{});
        std.debug.print("\tJob Name: {s}\n", .{job_name});
        std.debug.print("\tCommit ID: {s}\n", .{commit_display});
        std.debug.print("\tPriority: {d}\n", .{priority});
        std.debug.print("\tResources Requested:\n", .{});
        std.debug.print("\t\tCPU: {d} cores\n", .{options.cpu});
        std.debug.print("\t\tMemory: {d} GB\n", .{options.memory});
        std.debug.print("\t\tGPU: {d} device(s)\n", .{options.gpu});
        std.debug.print("\t\tGPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
        // Display narrative if provided
        if (narrative_json != null) {
            std.debug.print("\n\tResearch Narrative:\n", .{});
            if (options.hypothesis) |h| {
                std.debug.print("\t\tHypothesis: {s}\n", .{h});
            }
            if (options.context) |c| {
                std.debug.print("\t\tContext: {s}\n", .{c});
            }
            if (options.intent) |i| {
                std.debug.print("\t\tIntent: {s}\n", .{i});
            }
            if (options.expected_outcome) |eo| {
                std.debug.print("\t\tExpected Outcome: {s}\n", .{eo});
            }
            if (options.experiment_group) |eg| {
                std.debug.print("\t\tExperiment Group: {s}\n", .{eg});
            }
            if (options.tags) |t| {
                std.debug.print("\t\tTags: {s}\n", .{t});
            }
        }
        std.debug.print("\n\tAction: Would queue job\n", .{});
        std.debug.print("\tEstimated queue time: 2-5 minutes\n", .{});
        std.debug.print("\tDry run completed - no job was actually queued\n", .{});
    }
}
/// Write either a JSON string literal or the JSON `null` keyword.
fn writeJSONNullableString(writer: anytype, s: ?[]const u8) !void {
    const value = s orelse {
        try writer.writeAll("null");
        return;
    };
    try writeJSONString(writer, value);
}

/// Write `s` as a quoted JSON string literal, escaping double quotes,
/// backslashes, and control characters (named escapes for \n, \r, \t;
/// \u00XX for the rest below 0x20), per the JSON grammar in RFC 8259.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    try writer.writeAll("\"");
    for (s) |byte| {
        switch (byte) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            else => if (byte < 0x20) {
                // Remaining control characters get a \u00XX escape.
                const escaped = [6]u8{
                    '\\', 'u', '0', '0',
                    hexDigit(@intCast((byte >> 4) & 0x0F)),
                    hexDigit(@intCast(byte & 0x0F)),
                };
                try writer.writeAll(&escaped);
            } else {
                try writer.writeAll(&[_]u8{byte});
            },
        }
    }
    try writer.writeAll("\"");
}
/// Interpret the server's reply to a queue request and print a summary,
/// either as raw/constructed JSON (`options.json`) or as human-readable text.
///
/// Flow: unparseable payload -> echoed verbatim; parsed JSON without a truthy
/// `"duplicate"` flag -> treated as a successful queue; otherwise the existing
/// job's status drives one of three report branches (queued/running,
/// completed, failed) with status-specific next-step guidance.
///
/// NOTE(review): several server fields (`existing_id`, `status`, `queued_by`,
/// `queued_at`, and per-branch `duration_seconds` / `error_reason`) are
/// unwrapped with `.?` or direct union access and will panic if the server
/// omits them or sends a different JSON type — assumes the duplicate-response
/// contract guarantees them; confirm against the server implementation.
fn handleDuplicateResponse(
    allocator: std.mem.Allocator,
    payload: []const u8,
    job_name: []const u8,
    commit_hex: []const u8,
    options: *const QueueOptions,
) !void {
    // Payload that is not valid JSON: echo it so the user still sees the reply.
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch {
        if (options.json) {
            std.debug.print("{s}\n", .{payload});
        } else {
            std.debug.print("Server response: {s}\n", .{payload});
        }
        return;
    };
    defer parsed.deinit();
    // NOTE(review): assumes the JSON root is an object; `.object` on any other
    // variant panics — verify the server always replies with an object.
    const root = parsed.value.object;
    // NOTE(review): "duplicate" is looked up twice, and `.bool` panics if the
    // value is present but not a boolean.
    const is_dup = root.get("duplicate") != null and root.get("duplicate").?.bool;
    if (!is_dup) {
        // Not a duplicate: the job was accepted as new.
        if (options.json) {
            std.debug.print("{s}\n", .{payload});
        } else {
            std.debug.print("Job queued: {s}\n", .{job_name});
        }
        return;
    }
    // Duplicate confirmed — pull out the fields common to all status branches.
    const existing_id = root.get("existing_id").?.string;
    const status = root.get("status").?.string;
    const queued_by = root.get("queued_by").?.string;
    const queued_at = root.get("queued_at").?.integer;
    // Age of the existing job, truncated to whole minutes.
    const now = std.time.timestamp();
    const minutes_ago = @divTrunc(now - queued_at, 60);
    if (std.mem.eql(u8, status, "queued") or std.mem.eql(u8, status, "running")) {
        // In-flight duplicate: suggest watching it (or forcing a rerun).
        if (options.json) {
            // NOTE(review): string fields are interpolated into hand-built
            // JSON without escaping; a quote in e.g. queued_by would corrupt
            // this output.
            std.debug.print("{{\"success\":true,\"duplicate\":true,\"existing_id\":\"{s}\",\"status\":\"{s}\",\"queued_by\":\"{s}\",\"minutes_ago\":{d},\"suggested_action\":\"watch\"}}\n", .{ existing_id, status, queued_by, minutes_ago });
        } else {
            // existing_id[0..8] — short-id prefix used throughout the CLI hints.
            std.debug.print("\nIdentical job already in progress: {s}\n", .{existing_id[0..8]});
            std.debug.print("\tQueued by {s}, {d} minutes ago\n", .{ queued_by, minutes_ago });
            std.debug.print("\tStatus: {s}\n", .{status});
            std.debug.print("\n\tWatch: ml watch {s}\n", .{existing_id[0..8]});
            std.debug.print("\tRerun: ml queue {s} --commit {s} --force\n", .{ job_name, commit_hex });
        }
    } else if (std.mem.eql(u8, status, "completed")) {
        // Completed duplicate: show duration and (when present) key metrics.
        const duration_sec = root.get("duration_seconds").?.integer;
        const duration_min = @divTrunc(duration_sec, 60);
        if (options.json) {
            std.debug.print("{{\"success\":true,\"duplicate\":true,\"existing_id\":\"{s}\",\"status\":\"completed\",\"queued_by\":\"{s}\",\"duration_minutes\":{d},\"suggested_action\":\"show\"}}\n", .{ existing_id, queued_by, duration_min });
        } else {
            std.debug.print("\nIdentical job already completed: {s}\n", .{existing_id[0..8]});
            std.debug.print(" Queued by {s}\n", .{queued_by});
            // Only the accuracy/loss metrics are surfaced, and only when they
            // arrive as JSON floats; everything else is silently skipped.
            const metrics = root.get("metrics");
            if (metrics) |m| {
                if (m == .object) {
                    std.debug.print("\n Results:\n", .{});
                    if (m.object.get("accuracy")) |v| {
                        if (v == .float) std.debug.print(" accuracy: {d:.3}\n", .{v.float});
                    }
                    if (m.object.get("loss")) |v| {
                        if (v == .float) std.debug.print(" loss: {d:.3}\n", .{v.float});
                    }
                }
            }
            std.debug.print("\t\tduration: {d}m\n", .{duration_min});
            std.debug.print("\n\tInspect: ml experiment show {s}\n", .{existing_id[0..8]});
            std.debug.print("\tRerun: ml queue {s} --commit {s} --force\n", .{ job_name, commit_hex });
        }
    } else if (std.mem.eql(u8, status, "failed")) {
        // Failed duplicate: richest branch — classification, retry state, and
        // per-class guidance. Only error_reason is required; the rest default.
        const error_reason = root.get("error_reason").?.string;
        const failure_class = if (root.get("failure_class")) |fc| fc.string else "unknown";
        const exit_code = if (root.get("exit_code")) |ec| ec.integer else 0;
        const signal = if (root.get("signal")) |s| s.string else "";
        const log_tail = if (root.get("log_tail")) |lt| lt.string else "";
        const suggestion = if (root.get("suggestion")) |s| s.string else "";
        const retry_count = if (root.get("retry_count")) |rc| rc.integer else 0;
        const retry_cap = if (root.get("retry_cap")) |rc| rc.integer else 3;
        const auto_retryable = if (root.get("auto_retryable")) |ar| ar.bool else false;
        const requires_fix = if (root.get("requires_fix")) |rf| rf.bool else false;
        if (options.json) {
            // Machine-readable hint: fix > wait (auto-retry) > requeue.
            const suggested_action = if (requires_fix) "fix" else if (auto_retryable) "wait" else "requeue";
            // NOTE(review): error_reason/log-derived strings are not
            // JSON-escaped here either; embedded quotes or newlines would
            // produce invalid JSON.
            std.debug.print("{{\"success\":true,\"duplicate\":true,\"existing_id\":\"{s}\",\"status\":\"failed\",\"failure_class\":\"{s}\",\"exit_code\":{d},\"signal\":\"{s}\",\"error_reason\":\"{s}\",\"retry_count\":{d},\"retry_cap\":{d},\"auto_retryable\":{},\"requires_fix\":{},\"suggested_action\":\"{s}\"}}\n", .{ existing_id, failure_class, exit_code, signal, error_reason, retry_count, retry_cap, auto_retryable, requires_fix, suggested_action });
        } else {
            // Print rich failure information based on FailureClass
            std.debug.print("\nFAILED {s} {s} failure\n", .{ existing_id[0..8], failure_class });
            if (signal.len > 0) {
                std.debug.print("\tSignal: {s} (exit code: {d})\n", .{ signal, exit_code });
            } else if (exit_code != 0) {
                std.debug.print("\tExit code: {d}\n", .{exit_code});
            }
            // Show log tail if available
            if (log_tail.len > 0) {
                // Truncate long log tails (first 160 bytes).
                const display_tail = if (log_tail.len > 160) log_tail[0..160] else log_tail;
                std.debug.print("\tLog: {s}...\n", .{display_tail});
            }
            // Show retry history
            if (retry_count > 0) {
                if (auto_retryable and retry_count < retry_cap) {
                    std.debug.print("\tRetried: {d}/{d} — auto-retry in progress\n", .{ retry_count, retry_cap });
                } else {
                    std.debug.print("\tRetried: {d}/{d}\n", .{ retry_count, retry_cap });
                }
            }
            // Class-specific guidance per design spec
            if (std.mem.eql(u8, failure_class, "infrastructure")) {
                std.debug.print("\n\tInfrastructure failure (node died, preempted).\n", .{});
                if (auto_retryable and retry_count < retry_cap) {
                    std.debug.print("\t-> Auto-retrying transparently (attempt {d}/{d})\n", .{ retry_count + 1, retry_cap });
                } else if (retry_count >= retry_cap) {
                    std.debug.print("\t-> Retry cap reached. Requires manual intervention.\n", .{});
                    std.debug.print("\tResubmit: ml requeue {s}\n", .{existing_id[0..8]});
                }
                std.debug.print("\tLogs: ml logs {s}\n", .{existing_id[0..8]});
            } else if (std.mem.eql(u8, failure_class, "code")) {
                // CRITICAL RULE: code failures never auto-retry
                std.debug.print("\n\tCode failure — auto-retry is blocked.\n", .{});
                std.debug.print("\tYou must fix the code before resubmitting.\n", .{});
                std.debug.print("\t\tView logs: ml logs {s}\n", .{existing_id[0..8]});
                std.debug.print("\n\tAfter fix:\n", .{});
                std.debug.print("\t\tRequeue with same config:\n", .{});
                std.debug.print("\t\t\tml requeue {s}\n", .{existing_id[0..8]});
                std.debug.print("\t\tOr with more resources:\n", .{});
                std.debug.print("\t\t\tml requeue {s} --gpu-memory 16\n", .{existing_id[0..8]});
            } else if (std.mem.eql(u8, failure_class, "data")) {
                // Data failures never auto-retry
                std.debug.print("\n\tData failure — verification/checksum issue.\n", .{});
                std.debug.print("\tAuto-retry will fail again with same data.\n", .{});
                std.debug.print("\n\tCheck:\n", .{});
                std.debug.print("\t\tDataset availability: ml dataset verify {s}\n", .{existing_id[0..8]});
                std.debug.print("\t\tView logs: ml logs {s}\n", .{existing_id[0..8]});
                std.debug.print("\n\tAfter data issue resolved:\n", .{});
                std.debug.print("\t\t\tml requeue {s}\n", .{existing_id[0..8]});
            } else if (std.mem.eql(u8, failure_class, "resource")) {
                std.debug.print("\n\tResource failure — OOM or disk full.\n", .{});
                if (retry_count == 0 and auto_retryable) {
                    std.debug.print("\t-> Will retry once with backoff (30s delay)\n", .{});
                } else if (retry_count >= 1) {
                    std.debug.print("\t-> Retried once, failed again with same error.\n", .{});
                    std.debug.print("\n\tSuggestion: resubmit with more resources:\n", .{});
                    std.debug.print("\t\tml requeue {s} --gpu-memory 16\n", .{existing_id[0..8]});
                    std.debug.print("\t\tml requeue {s} --memory 32 --cpu 8\n", .{existing_id[0..8]});
                }
                std.debug.print("\n\tCheck capacity: ml status\n", .{});
                std.debug.print("\tLogs: ml logs {s}\n", .{existing_id[0..8]});
            } else {
                // Unknown failures: user must read the logs and decide.
                std.debug.print("\n\tUnknown failure — classification unclear.\n", .{});
                std.debug.print("\n\tReview full logs and decide:\n", .{});
                std.debug.print("\t\tml logs {s}\n", .{existing_id[0..8]});
                if (auto_retryable) {
                    std.debug.print("\n\tOr retry:\n", .{});
                    std.debug.print("\t\tml requeue {s}\n", .{existing_id[0..8]});
                }
            }
            // Always show the suggestion if available
            if (suggestion.len > 0) {
                std.debug.print("\n\t{s}\n", .{suggestion});
            }
        }
    }
}
/// Map a nibble value (0–15) to its lowercase hex character.
fn hexDigit(v: u8) u8 {
    if (v < 10) return '0' + v;
    return 'a' + (v - 10);
}
// buildNarrativeJson creates a JSON object from narrative fields.
//
// Returns null when no narrative field is set; otherwise returns an owned
// JSON object string ({"hypothesis":...,...}) that the caller must free
// with `allocator`. Fields appear in a fixed order, only when non-null.
fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions) !?[]u8 {
    // Key/value table, serialized in this order. Keeping the list in one
    // place replaces the six copy-pasted per-field blocks the original
    // needed for both the "any set?" check and the serialization.
    const fields = .{
        .{ "hypothesis", options.hypothesis },
        .{ "context", options.context },
        .{ "intent", options.intent },
        .{ "expected_outcome", options.expected_outcome },
        .{ "experiment_group", options.experiment_group },
        .{ "tags", options.tags },
    };

    // No narrative provided at all -> no JSON object.
    var any_set = false;
    inline for (fields) |f| {
        if (f[1] != null) any_set = true;
    }
    if (!any_set) return null;

    var buf = try std.ArrayList(u8).initCapacity(allocator, 256);
    defer buf.deinit(allocator);
    const writer = buf.writer(allocator);
    try writer.writeAll("{");
    var first = true;
    inline for (fields) |f| {
        if (f[1]) |val| {
            if (!first) try writer.writeAll(",");
            first = false;
            // Key names are plain identifiers, so writeJSONString emits them
            // verbatim in quotes; values get full JSON escaping.
            try writeJSONString(writer, f[0]);
            try writer.writeAll(":");
            try writeJSONString(writer, val);
        }
    }
    try writer.writeAll("}");
    // Ownership transfers to the caller; the deferred deinit then frees an
    // emptied list, which is a no-op.
    return try buf.toOwnedSlice(allocator);
}
/// Check if a job with the same commit_id already exists on the server.
/// Returns the server's full JSON response (owned by the caller, freed with
/// `allocator`) when a duplicate exists, or null otherwise.
/// Malformed or unexpected server replies are treated as "no duplicate
/// found" so queueing can proceed (best-effort check), matching how JSON
/// parse failures were already handled.
fn checkExistingJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_id: []const u8,
    api_key_hash: []const u8,
    config: Config,
) !?[]const u8 {
    // Connect to server and query for existing job.
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    var client = try ws.Client.connect(allocator, ws_url, config.api_key);
    defer client.close();
    // Ask whether this (job_name, commit_id) pair was already queued.
    try client.sendQueryJobByCommit(job_name, commit_id, api_key_hash);
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);
    // Parse response; a non-JSON reply counts as "no duplicate".
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch |err| {
        std.log.debug("Failed to parse check response: {}", .{err});
        return null;
    };
    defer parsed.deinit();
    // Fix: guard the root variant — previously `.object` would panic if the
    // server replied with a non-object JSON value (bare string, array, ...).
    if (parsed.value != .object) return null;
    const root = parsed.value.object;
    // Fix: only a boolean `"exists": true` counts as a confirmed duplicate;
    // previously `.bool` would panic on a non-boolean value.
    const exists = root.get("exists") orelse return null;
    if (exists != .bool or !exists.bool) return null;
    // Job exists — hand the caller its own copy of the full response
    // (`message` itself is freed by the defer above).
    return try allocator.dupe(u8, message);
}