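//! fetch_ml/cli/src/commands/queue/submit.zig
//!
//! `queue submit` command: sends job submission payloads to the server over a
//! WebSocket client and prints the results.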

const std = @import("std");
const ws = @import("../../net/ws/client.zig");
const protocol = @import("../../net/protocol.zig");
const crypto = @import("../../utils/crypto.zig");
const Config = @import("../../config.zig").Config;
const core = @import("../../core.zig");
const history = @import("../../utils/history.zig");
/// Options for one `queue submit` invocation.
///
/// Every entry in `job_names` becomes its own submission; all other fields
/// are shared by every job in the batch.
pub const SubmitConfig = struct {
    /// Names of the jobs to submit, in submission order.
    job_names: []const []const u8,
    /// Optional commit identifier attached to each job.
    commit_id: ?[]const u8,
    /// Priority value sent with each job.
    priority: u8,
    /// Optional snapshot identifier attached to each job.
    snapshot_id: ?[]const u8,
    /// Optional snapshot digest (presumably a SHA-256 hex string — confirm).
    snapshot_sha256: ?[]const u8,
    /// Optional replacement for the job's arguments.
    args_override: ?[]const u8,
    /// Optional note attached to each job.
    note_override: ?[]const u8,
    /// CPU request sent as `resources.cpu` (units not established here — confirm).
    cpu: u8,
    /// Memory request sent as `resources.memory` (units not established here — confirm).
    memory: u8,
    /// GPU request sent as `resources.gpu` (presumably a count — confirm).
    gpu: u8,
    /// Optional GPU memory requirement, passed through verbatim as a string.
    gpu_memory: ?[]const u8,
    /// When set, only report what would be submitted; contact no server.
    dry_run: bool,
    /// NOTE(review): meaning of `force` is not visible in this file.
    force: bool,
    /// Extra arguments intended for the job runner.
    runner_args: []const []const u8,

    /// Number of jobs this configuration submits: exactly one per name.
    pub fn estimateTotalJobs(cfg: SubmitConfig) usize {
        const names = cfg.job_names;
        return names.len;
    }
};
/// Aggregate outcome of a submission run.
///
/// `errors` owns every message it holds. Release them with `deinit`, passing
/// the same allocator that allocated the messages.
pub const SubmitResult = struct {
    /// Cleared when any submission step fails.
    success: bool,
    /// Number of jobs submitted (or, in dry-run mode, that would be).
    job_count: usize,
    /// Owned error messages, in the order they were recorded.
    errors: std.ArrayList([]const u8),

    /// Create an empty, successful result. Allocates nothing.
    pub fn init() SubmitResult {
        return .{
            .success = true,
            .job_count = 0,
            .errors = .empty,
        };
    }

    /// Record `msg` as an error and mark the result as failed.
    ///
    /// Always takes ownership of `msg`, which must have been allocated with
    /// `allocator`. On success it is stored in `errors`; if appending fails it is
    /// freed here, so the caller never has to clean it up.
    pub fn addError(
        self: *SubmitResult,
        allocator: std.mem.Allocator,
        msg: []const u8,
    ) std.mem.Allocator.Error!void {
        errdefer allocator.free(msg);
        try self.errors.append(allocator, msg);
        self.success = false;
    }

    /// Free every stored error message and the list that holds them.
    pub fn deinit(self: *SubmitResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        self.errors.deinit(allocator);
    }
};
/// Submit every job named in `submit_config.job_names` to the server.
///
/// Dry run: nothing is sent. The job list is printed to stderr (as a single
/// JSON object when `json` is true) and `job_count` is set to the number of
/// names.
///
/// Otherwise a WebSocket connection is opened and one payload is sent per job.
/// A failed connection or a failed individual submission does not make this
/// function return an error. Instead, an owned message is recorded in
/// `result.errors` and `result.success` is cleared. Errors that do propagate
/// come from building the WebSocket URL, hashing the API key, or allocation.
///
/// When every job succeeded and a `commit_id` is set, each job is recorded in
/// local history on a best-effort basis.
///
/// The caller owns the returned result and must call `deinit(allocator)` on it.
pub fn submitJobs(
    allocator: std.mem.Allocator,
    config: Config,
    submit_config: SubmitConfig,
    json: bool,
) !SubmitResult {
    // `SubmitResult.init` takes no arguments; the allocator is only needed by `deinit`.
    var result = SubmitResult.init();
    errdefer result.deinit(allocator);

    if (submit_config.dry_run) {
        printDryRun(submit_config.job_names, json);
        result.job_count = submit_config.job_names.len;
        return result;
    }

    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
    defer allocator.free(api_key_hash);

    var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| {
        try recordSubmitError(allocator, &result, "Failed to connect: {s}", .{@errorName(err)});
        return result;
    };
    defer client.close();

    // Keep going after a failed job so one bad submission doesn't block the rest.
    for (submit_config.job_names) |job_name| {
        submitSingleJob(
            allocator,
            &client,
            api_key_hash,
            job_name,
            submit_config,
            &result,
        ) catch |err| {
            try recordSubmitError(allocator, &result, "Failed to submit {s}: {s}", .{ job_name, @errorName(err) });
        };
    }

    // History is a convenience record. A failure to write it must not turn a
    // successful submission into an error, so it is reported as a warning. The
    // warning is printed in human-readable mode only, so JSON output stays a
    // single parseable document.
    if (result.success and result.job_count > 0) {
        if (submit_config.commit_id) |commit_id| {
            for (submit_config.job_names) |job_name| {
                history.saveEntry(allocator, job_name, commit_id) catch |err| {
                    if (!json) {
                        std.debug.print("Warning: failed to save history for {s}: {s}\n", .{ job_name, @errorName(err) });
                    }
                };
            }
        }
    }

    return result;
}

/// Format an error message, store it in `result.errors` (which then owns it),
/// and mark the submission as failed. The message is freed if storing it fails.
fn recordSubmitError(
    allocator: std.mem.Allocator,
    result: *SubmitResult,
    comptime fmt: []const u8,
    args: anytype,
) std.mem.Allocator.Error!void {
    const msg = try std.fmt.allocPrint(allocator, fmt, args);
    errdefer allocator.free(msg);
    try result.errors.append(allocator, msg);
    result.success = false;
}

/// Print to stderr the jobs a dry run would submit, as one JSON object when
/// `json` is true, otherwise as a human-readable list.
fn printDryRun(job_names: []const []const u8, json: bool) void {
    if (json) {
        std.debug.print("{{\"success\":true,\"command\":\"queue.submit\",\"dry_run\":true,\"jobs\":[", .{});
        for (job_names, 0..) |name, i| {
            if (i > 0) std.debug.print(",", .{});
            printJobNameJson(name);
        }
        // Exactly one object was opened above, so emit exactly one closing brace.
        std.debug.print("],\"total\":{d}}}\n", .{job_names.len});
    } else {
        std.debug.print("[DRY RUN] Would submit {d} jobs:\n", .{job_names.len});
        for (job_names) |name| {
            std.debug.print(" - {s}\n", .{name});
        }
    }
}

/// Print `s` to stderr as a quoted JSON string. Quotes, backslashes and
/// control characters are escaped (RFC 8259 §7) so a user-supplied job name
/// cannot break the surrounding JSON. Other bytes, including UTF-8, pass through.
fn printJobNameJson(s: []const u8) void {
    const hex = "0123456789abcdef";
    std.debug.print("\"", .{});
    for (s) |c| {
        switch (c) {
            '"' => std.debug.print("\\\"", .{}),
            '\\' => std.debug.print("\\\\", .{}),
            '\n' => std.debug.print("\\n", .{}),
            '\r' => std.debug.print("\\r", .{}),
            '\t' => std.debug.print("\\t", .{}),
            0x00...0x08, 0x0b, 0x0c, 0x0e...0x1f => std.debug.print("\\u00{c}{c}", .{ hex[c >> 4], hex[c & 0x0f] }),
            else => std.debug.print("{c}", .{c}),
        }
    }
    std.debug.print("\"", .{});
}
/// Build the JSON submission payload for `job_name`, send it over `client`,
/// and increment `result.job_count` once the send succeeds.
///
/// The payload has this shape, with optional fields appearing only when set:
/// `{"job_name":S,"priority":N,"resources":{"cpu":N,"memory":N,"gpu":N[,"gpu_memory":S]}
///  [,"commit_id":S][,"snapshot_id":S][,"note":S]}`
/// Every string value is JSON-escaped so user-supplied text (job names,
/// notes, IDs) cannot break out of its field or inject extra fields.
///
/// Errors from allocation or `client.sendMessage` propagate to the caller.
///
/// NOTE(review): `args_override`, `snapshot_sha256`, `runner_args` and `force`
/// are not included in the payload. Confirm whether the server expects them.
fn submitSingleJob(
    allocator: std.mem.Allocator,
    client: *ws.Client,
    _: []const u8, // API-key hash; not used in the payload
    job_name: []const u8,
    submit_config: SubmitConfig,
    result: *SubmitResult,
) !void {
    // `std.ArrayList` is the unmanaged list here, as `SubmitResult.errors` shows:
    // start from `.empty` and pass the allocator to every call.
    var payload: std.ArrayList(u8) = .empty;
    defer payload.deinit(allocator);

    try payload.appendSlice(allocator, "{\"job_name\":");
    try appendJsonString(&payload, allocator, job_name);

    const numeric = try std.fmt.allocPrint(
        allocator,
        ",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory\":{d},\"gpu\":{d}",
        .{ submit_config.priority, submit_config.cpu, submit_config.memory, submit_config.gpu },
    );
    defer allocator.free(numeric);
    try payload.appendSlice(allocator, numeric);

    if (submit_config.gpu_memory) |gm| try appendJsonField(&payload, allocator, "gpu_memory", gm);
    try payload.append(allocator, '}'); // closes "resources"

    if (submit_config.commit_id) |cid| try appendJsonField(&payload, allocator, "commit_id", cid);
    if (submit_config.snapshot_id) |sid| try appendJsonField(&payload, allocator, "snapshot_id", sid);
    if (submit_config.note_override) |note| try appendJsonField(&payload, allocator, "note", note);
    try payload.append(allocator, '}'); // closes the payload object

    try client.sendMessage(payload.items);
    result.job_count += 1;
}

/// Append `,"<key>":<value>` to `out`, with `value` JSON-escaped.
/// `key` is a compile-time constant and is assumed to need no escaping.
fn appendJsonField(
    out: *std.ArrayList(u8),
    allocator: std.mem.Allocator,
    comptime key: []const u8,
    value: []const u8,
) std.mem.Allocator.Error!void {
    try out.appendSlice(allocator, ",\"" ++ key ++ "\":");
    try appendJsonString(out, allocator, value);
}

/// Append `s` to `out` as a quoted JSON string literal, escaping quotes,
/// backslashes and control characters (RFC 8259 §7). Other bytes, including
/// UTF-8 sequences, are copied unchanged.
fn appendJsonString(
    out: *std.ArrayList(u8),
    allocator: std.mem.Allocator,
    s: []const u8,
) std.mem.Allocator.Error!void {
    const hex = "0123456789abcdef";
    try out.append(allocator, '"');
    for (s) |c| {
        switch (c) {
            '"' => try out.appendSlice(allocator, "\\\""),
            '\\' => try out.appendSlice(allocator, "\\\\"),
            '\n' => try out.appendSlice(allocator, "\\n"),
            '\r' => try out.appendSlice(allocator, "\\r"),
            '\t' => try out.appendSlice(allocator, "\\t"),
            0x00...0x08, 0x0b, 0x0c, 0x0e...0x1f => try out.appendSlice(
                allocator,
                &[_]u8{ '\\', 'u', '0', '0', hex[c >> 4], hex[c & 0x0f] },
            ),
            else => try out.append(allocator, c),
        }
    }
    try out.append(allocator, '"');
}
/// Print the outcome of a submission to stderr.
///
/// With `json`, prints a single object of the form
/// `{"success":B,"command":"queue.submit","data":{"submitted":N,"errors":[...]}}`.
/// `"errors"` is present only when at least one error was recorded, and each
/// message is JSON-escaped. Without `json`, prints a one-line summary and,
/// when the submission failed, one line per recorded error.
pub fn printResults(result: SubmitResult, json: bool) void {
    if (json) {
        const status = if (result.success) "true" else "false";
        std.debug.print("{{\"success\":{s},\"command\":\"queue.submit\",\"data\":{{\"submitted\":{d}", .{ status, result.job_count });
        if (result.errors.items.len > 0) {
            std.debug.print(",\"errors\":[", .{});
            for (result.errors.items, 0..) |err, i| {
                if (i > 0) std.debug.print(",", .{});
                // Messages may embed user-supplied job names; escape them so the
                // document stays valid JSON.
                printErrorMessageJson(err);
            }
            std.debug.print("]", .{});
        }
        // Close the "data" object and the outer object.
        std.debug.print("}}}}\n", .{});
        return;
    }

    if (result.success) {
        std.debug.print("Successfully submitted {d} jobs\n", .{result.job_count});
    } else {
        std.debug.print("Failed to submit jobs ({d} errors)\n", .{result.errors.items.len});
        for (result.errors.items) |err| {
            std.debug.print(" Error: {s}\n", .{err});
        }
    }
}

/// Print `s` to stderr as a quoted JSON string, escaping quotes, backslashes
/// and control characters (RFC 8259 §7). Other bytes, including UTF-8, pass through.
fn printErrorMessageJson(s: []const u8) void {
    const hex = "0123456789abcdef";
    std.debug.print("\"", .{});
    for (s) |c| {
        switch (c) {
            '"' => std.debug.print("\\\"", .{}),
            '\\' => std.debug.print("\\\\", .{}),
            '\n' => std.debug.print("\\n", .{}),
            '\r' => std.debug.print("\\r", .{}),
            '\t' => std.debug.print("\\t", .{}),
            0x00...0x08, 0x0b, 0x0c, 0x0e...0x1f => std.debug.print("\\u00{c}{c}", .{ hex[c >> 4], hex[c & 0x0f] }),
            else => std.debug.print("{c}", .{c}),
        }
    }
    std.debug.print("\"", .{});
}