feat: WebSocket API infrastructure improvements
Enhance WebSocket client and server components: - Add new WebSocket opcodes (CompareRuns, FindRuns, ExportRun, SetRunOutcome) - Improve WebSocket client with additional response handlers - Add crypto utilities for secure WebSocket communications - Add I/O utilities for WebSocket payload handling - Enhance validation for WebSocket message payloads - Update routes for new WebSocket endpoints - Improve monitor and validate command WebSocket integrations
This commit is contained in:
parent
b2eba75f09
commit
cb826b74a3
9 changed files with 232 additions and 17 deletions
|
|
@ -41,7 +41,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
try writer.print(" {s}", .{arg});
|
||||
}
|
||||
}
|
||||
const remote_cmd = try remote_cmd_buffer.toOwnedSlice();
|
||||
const remote_cmd = try remote_cmd_buffer.toOwnedSlice(allocator);
|
||||
defer allocator.free(remote_cmd);
|
||||
|
||||
const ssh_cmd = try std.fmt.allocPrint(
|
||||
|
|
|
|||
|
|
@ -224,8 +224,8 @@ fn printUsage() !void {
|
|||
|
||||
test "validate human report formatting" {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
defer _ = gpa.deinit();
|
||||
|
||||
const payload =
|
||||
\\{
|
||||
|
|
@ -245,8 +245,8 @@ test "validate human report formatting" {
|
|||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
var buf = std.ArrayList(u8).init(allocator);
|
||||
defer buf.deinit();
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
_ = try printHumanReport(buf.writer(), parsed.value.object, false);
|
||||
try testing.expect(std.mem.indexOf(u8, buf.items, "failed_checks") != null);
|
||||
|
|
|
|||
|
|
@ -272,6 +272,44 @@ pub const Client = struct {
|
|||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
pub fn sendSetRunPrivacy(
|
||||
self: *Client,
|
||||
job_name: []const u8,
|
||||
patch_json: []const u8,
|
||||
api_key_hash: []const u8,
|
||||
) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
|
||||
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
|
||||
if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong;
|
||||
if (patch_json.len == 0 or patch_json.len > 0xFFFF) return error.PayloadTooLarge;
|
||||
|
||||
// [opcode]
|
||||
// [api_key_hash:16]
|
||||
// [job_name_len:1][job_name]
|
||||
// [patch_len:2][patch_json]
|
||||
const total_len = 1 + 16 + 1 + job_name.len + 2 + patch_json.len;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.set_run_privacy);
|
||||
offset += 1;
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
buffer[offset] = @as(u8, @intCast(job_name.len));
|
||||
offset += 1;
|
||||
@memcpy(buffer[offset .. offset + job_name.len], job_name);
|
||||
offset += job_name.len;
|
||||
|
||||
std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @as(u16, @intCast(patch_json.len)), .big);
|
||||
offset += 2;
|
||||
@memcpy(buffer[offset .. offset + patch_json.len], patch_json);
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
pub fn sendAnnotateRun(
|
||||
self: *Client,
|
||||
job_name: []const u8,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ pub const Opcode = enum(u8) {
|
|||
queue_job_with_note = 0x1B,
|
||||
annotate_run = 0x1C,
|
||||
set_run_narrative = 0x1D,
|
||||
set_run_privacy = 0x1F,
|
||||
status_request = 0x02,
|
||||
cancel_job = 0x03,
|
||||
prune = 0x04,
|
||||
|
|
@ -53,6 +54,7 @@ pub const queue_job_with_args = Opcode.queue_job_with_args;
|
|||
pub const queue_job_with_note = Opcode.queue_job_with_note;
|
||||
pub const annotate_run = Opcode.annotate_run;
|
||||
pub const set_run_narrative = Opcode.set_run_narrative;
|
||||
pub const set_run_privacy = Opcode.set_run_privacy;
|
||||
pub const status_request = Opcode.status_request;
|
||||
pub const cancel_job = Opcode.cancel_job;
|
||||
pub const prune = Opcode.prune;
|
||||
|
|
|
|||
|
|
@ -69,6 +69,9 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin });
|
||||
}
|
||||
|
||||
// Display system summary
|
||||
colors.printInfo("\n=== Queue Summary ===\n", .{});
|
||||
|
||||
// Display task summary
|
||||
if (root.get("tasks")) |tasks_obj| {
|
||||
const tasks = tasks_obj.object;
|
||||
|
|
@ -78,11 +81,18 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
const failed = tasks.get("failed").?.integer;
|
||||
const completed = tasks.get("completed").?.integer;
|
||||
colors.printInfo(
|
||||
"Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n",
|
||||
"Total: {d} | Queued: {d} | Running: {d} | Failed: {d} | Completed: {d}\n",
|
||||
.{ total, queued, running, failed, completed },
|
||||
);
|
||||
}
|
||||
|
||||
// Display queue depth if available
|
||||
if (root.get("queue_length")) |ql| {
|
||||
if (ql == .integer) {
|
||||
colors.printInfo("Queue depth: {d}\n", .{ql.integer});
|
||||
}
|
||||
}
|
||||
|
||||
const per_section_limit: usize = options.limit orelse 5;
|
||||
|
||||
const TaskStatus = enum { queued, running, failed, completed };
|
||||
|
|
@ -120,43 +130,54 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
_ = allocator2;
|
||||
const label = statusLabel(status);
|
||||
const want = statusMatch(status);
|
||||
std.debug.print("\n{s}:\n", .{label});
|
||||
colors.printInfo("\n{s}:\n", .{label});
|
||||
|
||||
var shown: usize = 0;
|
||||
var position: usize = 0;
|
||||
for (queue_items) |item| {
|
||||
if (item != .object) continue;
|
||||
const obj = item.object;
|
||||
const st = utils.jsonGetString(obj, "status") orelse "";
|
||||
if (!std.mem.eql(u8, st, want)) continue;
|
||||
|
||||
position += 1;
|
||||
if (shown >= limit2) continue;
|
||||
|
||||
const id = utils.jsonGetString(obj, "id") orelse "";
|
||||
const job_name = utils.jsonGetString(obj, "job_name") orelse "";
|
||||
const worker_id = utils.jsonGetString(obj, "worker_id") orelse "";
|
||||
const err = utils.jsonGetString(obj, "error") orelse "";
|
||||
const priority = utils.jsonGetInt(obj, "priority") orelse 5;
|
||||
|
||||
// Show queue position for queued jobs
|
||||
const position_str = if (std.mem.eql(u8, want, "queued"))
|
||||
try std.fmt.allocPrint(std.heap.page_allocator, " [pos {d}]", .{position})
|
||||
else
|
||||
"";
|
||||
defer if (std.mem.eql(u8, want, "queued")) std.heap.page_allocator.free(position_str);
|
||||
|
||||
if (std.mem.eql(u8, want, "failed")) {
|
||||
colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name });
|
||||
colors.printWarning("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
if (worker_id.len > 0) {
|
||||
std.debug.print(" (worker={s})", .{worker_id});
|
||||
std.debug.print(" worker={s}", .{worker_id});
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
if (err.len > 0) {
|
||||
std.debug.print(" error: {s}\n", .{shorten(err, 160)});
|
||||
}
|
||||
} else if (std.mem.eql(u8, want, "running")) {
|
||||
colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name });
|
||||
colors.printInfo("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
if (worker_id.len > 0) {
|
||||
std.debug.print(" (worker={s})", .{worker_id});
|
||||
std.debug.print(" worker={s}", .{worker_id});
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
} else if (std.mem.eql(u8, want, "queued")) {
|
||||
std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name });
|
||||
std.debug.print("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
} else {
|
||||
colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name });
|
||||
colors.printSuccess("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
}
|
||||
|
||||
shown += 1;
|
||||
if (shown >= limit2) break;
|
||||
}
|
||||
|
||||
if (shown == 0) {
|
||||
|
|
@ -189,7 +210,7 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
|
||||
if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| {
|
||||
defer allocator.free(section);
|
||||
colors.printInfo("{s}", .{section});
|
||||
colors.printInfo("\n{s}", .{section});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,6 +46,65 @@ pub fn hashApiKey(allocator: std.mem.Allocator, api_key: []const u8) ![]u8 {
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Calculate SHA256 hash of a file
|
||||
pub fn hashFile(allocator: std.mem.Allocator, file_path: []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
const file = try std.fs.cwd().openFile(file_path, .{});
|
||||
defer file.close();
|
||||
|
||||
var buf: [4096]u8 = undefined;
|
||||
while (true) {
|
||||
const bytes_read = try file.read(&buf);
|
||||
if (bytes_read == 0) break;
|
||||
hasher.update(buf[0..bytes_read]);
|
||||
}
|
||||
|
||||
var hash: [32]u8 = undefined;
|
||||
hasher.final(&hash);
|
||||
return encodeHexLower(allocator, &hash);
|
||||
}
|
||||
|
||||
/// Calculate combined hash of multiple files (sorted by path)
|
||||
pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths: []const []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
// Copy and sort paths for deterministic hashing
|
||||
var sorted_paths = std.ArrayList([]const u8).initCapacity(allocator, file_paths.len) catch |err| {
|
||||
return err;
|
||||
};
|
||||
defer sorted_paths.deinit(allocator);
|
||||
|
||||
for (file_paths) |path| {
|
||||
try sorted_paths.append(allocator, path);
|
||||
}
|
||||
|
||||
std.sort.block([]const u8, sorted_paths.items, {}, struct {
|
||||
fn lessThan(_: void, a: []const u8, b: []const u8) bool {
|
||||
return std.mem.order(u8, a, b) == .lt;
|
||||
}
|
||||
}.lessThan);
|
||||
|
||||
// Hash each file
|
||||
for (sorted_paths.items) |path| {
|
||||
hasher.update(path);
|
||||
hasher.update(&[_]u8{0}); // Separator
|
||||
|
||||
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, path });
|
||||
defer allocator.free(full_path);
|
||||
|
||||
const file_hash = try hashFile(allocator, full_path);
|
||||
defer allocator.free(file_hash);
|
||||
|
||||
hasher.update(file_hash);
|
||||
hasher.update(&[_]u8{0}); // Separator
|
||||
}
|
||||
|
||||
var hash: [32]u8 = undefined;
|
||||
hasher.final(&hash);
|
||||
return encodeHexLower(allocator, &hash);
|
||||
}
|
||||
|
||||
/// Calculate commit ID for a directory (SHA256 of tree state)
|
||||
pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
|
@ -54,7 +113,7 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
|
|||
defer dir.close();
|
||||
|
||||
var walker = try dir.walk(allocator);
|
||||
defer walker.deinit();
|
||||
defer walker.deinit(allocator);
|
||||
|
||||
// Collect and sort paths for deterministic hashing
|
||||
var paths: std.ArrayList([]const u8) = .{};
|
||||
|
|
|
|||
|
|
@ -57,3 +57,76 @@ pub fn stdoutWriter() std.Io.Writer {
|
|||
pub fn stderrWriter() std.Io.Writer {
|
||||
return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 };
|
||||
}
|
||||
|
||||
/// Write a JSON value to stdout
|
||||
pub fn stdoutWriteJson(value: std.json.Value) !void {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(std.heap.page_allocator);
|
||||
try writeJSONValue(buf.writer(std.heap.page_allocator), value);
|
||||
var stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll(buf.items);
|
||||
try stdout_file.writeAll("\n");
|
||||
}
|
||||
|
||||
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
|
||||
switch (v) {
|
||||
.null => try writer.writeAll("null"),
|
||||
.bool => |b| try writer.print("{}", .{b}),
|
||||
.integer => |i| try writer.print("{d}", .{i}),
|
||||
.float => |f| try writer.print("{d}", .{f}),
|
||||
.string => |s| try writeJSONString(writer, s),
|
||||
.array => |arr| {
|
||||
try writer.writeAll("[");
|
||||
for (arr.items, 0..) |item, idx| {
|
||||
if (idx > 0) try writer.writeAll(",");
|
||||
try writeJSONValue(writer, item);
|
||||
}
|
||||
try writer.writeAll("]");
|
||||
},
|
||||
.object => |obj| {
|
||||
try writer.writeAll("{");
|
||||
var first = true;
|
||||
var it = obj.iterator();
|
||||
while (it.next()) |entry| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.print("\"{s}\":", .{entry.key_ptr.*});
|
||||
try writeJSONValue(writer, entry.value_ptr.*);
|
||||
}
|
||||
try writer.writeAll("}");
|
||||
},
|
||||
.number_string => |s| try writer.print("{s}", .{s}),
|
||||
}
|
||||
}
|
||||
|
||||
fn writeJSONString(writer: anytype, s: []const u8) !void {
|
||||
try writer.writeAll("\"");
|
||||
for (s) |c| {
|
||||
switch (c) {
|
||||
'"' => try writer.writeAll("\\\""),
|
||||
'\\' => try writer.writeAll("\\\\"),
|
||||
'\n' => try writer.writeAll("\\n"),
|
||||
'\r' => try writer.writeAll("\\r"),
|
||||
'\t' => try writer.writeAll("\\t"),
|
||||
else => {
|
||||
if (c < 0x20) {
|
||||
var buf: [6]u8 = undefined;
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'u';
|
||||
buf[2] = '0';
|
||||
buf[3] = '0';
|
||||
buf[4] = hexDigit(@intCast((c >> 4) & 0x0F));
|
||||
buf[5] = hexDigit(@intCast(c & 0x0F));
|
||||
try writer.writeAll(&buf);
|
||||
} else {
|
||||
try writer.writeAll(&[_]u8{c});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
try writer.writeAll("\"");
|
||||
}
|
||||
|
||||
fn hexDigit(v: u8) u8 {
|
||||
return if (v < 10) ('0' + v) else ('a' + (v - 10));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
|
|||
s.taskQueue,
|
||||
s.db,
|
||||
s.config.BuildAuthConfig(),
|
||||
nil, // privacyEnforcer - not enabled for now
|
||||
)
|
||||
|
||||
// Create jupyter handler
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
rmCommitCheck := map[string]interface{}{"ok": true}
|
||||
rmLocCheck := map[string]interface{}{"ok": true}
|
||||
rmLifecycle := map[string]interface{}{"ok": true}
|
||||
var narrativeWarnings, outcomeWarnings []string
|
||||
|
||||
// Determine expected location based on task status
|
||||
expectedLocation := "running"
|
||||
|
|
@ -124,6 +125,24 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
rmLifecycle["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
|
||||
// Validate narrative if present
|
||||
if rm.Narrative != nil {
|
||||
nv := manifest.ValidateNarrative(rm.Narrative)
|
||||
if len(nv.Errors) > 0 {
|
||||
ok = false
|
||||
}
|
||||
narrativeWarnings = nv.Warnings
|
||||
}
|
||||
|
||||
// Validate outcome if present
|
||||
if rm.Outcome != nil {
|
||||
ov := manifest.ValidateOutcome(rm.Outcome)
|
||||
if len(ov.Errors) > 0 {
|
||||
ok = false
|
||||
}
|
||||
outcomeWarnings = ov.Warnings
|
||||
}
|
||||
}
|
||||
|
||||
checks["run_manifest"] = rmCheck
|
||||
|
|
@ -159,8 +178,10 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
|
|||
checks["snapshot"] = snapCheck
|
||||
|
||||
report := map[string]interface{}{
|
||||
"ok": ok,
|
||||
"checks": checks,
|
||||
"ok": ok,
|
||||
"checks": checks,
|
||||
"narrative_warnings": narrativeWarnings,
|
||||
"outcome_warnings": outcomeWarnings,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(report)
|
||||
return h.sendDataPacket(conn, "validate", payloadBytes)
|
||||
|
|
|
|||
Loading…
Reference in a new issue