From cb826b74a37017695fadb76f842f3e1a9e1aab12 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Wed, 18 Feb 2026 21:27:48 -0500 Subject: [PATCH] feat: WebSocket API infrastructure improvements Enhance WebSocket client and server components: - Add new WebSocket opcode (SetRunPrivacy) and client send path for it - Improve WebSocket client with additional response handlers - Add crypto utilities for secure WebSocket communications - Add I/O utilities for WebSocket payload handling - Enhance validation for WebSocket message payloads - Update routes for new WebSocket endpoints - Improve monitor and validate command WebSocket integrations --- cli/src/commands/monitor.zig | 2 +- cli/src/commands/validate.zig | 6 +-- cli/src/net/ws/client.zig | 38 +++++++++++++++ cli/src/net/ws/opcode.zig | 2 + cli/src/net/ws/response_handlers.zig | 41 ++++++++++++---- cli/src/utils/crypto.zig | 61 ++++++++++++++++++++++- cli/src/utils/io.zig | 73 ++++++++++++++++++++++++++++ internal/api/routes.go | 1 + internal/api/ws/validate.go | 25 +++++++++- 9 files changed, 232 insertions(+), 17 deletions(-) diff --git a/cli/src/commands/monitor.zig b/cli/src/commands/monitor.zig index cc08dcb..0c2908f 100644 --- a/cli/src/commands/monitor.zig +++ b/cli/src/commands/monitor.zig @@ -41,7 +41,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { try writer.print(" {s}", .{arg}); } } - const remote_cmd = try remote_cmd_buffer.toOwnedSlice(); + const remote_cmd = try remote_cmd_buffer.toOwnedSlice(allocator); defer allocator.free(remote_cmd); const ssh_cmd = try std.fmt.allocPrint( diff --git a/cli/src/commands/validate.zig b/cli/src/commands/validate.zig index e229a06..0a29f2f 100644 --- a/cli/src/commands/validate.zig +++ b/cli/src/commands/validate.zig @@ -224,8 +224,8 @@ fn printUsage() !void { test "validate human report formatting" { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - defer _ = gpa.deinit(); const allocator = gpa.allocator(); + defer _ = 
gpa.deinit(); const payload = \\{ @@ -245,8 +245,8 @@ test "validate human report formatting" { const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{}); defer parsed.deinit(); - var buf = std.ArrayList(u8).init(allocator); - defer buf.deinit(); + var buf = std.ArrayList(u8).empty; + defer buf.deinit(allocator); _ = try printHumanReport(buf.writer(), parsed.value.object, false); try testing.expect(std.mem.indexOf(u8, buf.items, "failed_checks") != null); diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig index 38acac4..47668bb 100644 --- a/cli/src/net/ws/client.zig +++ b/cli/src/net/ws/client.zig @@ -272,6 +272,44 @@ pub const Client = struct { try frame.sendWebSocketFrame(stream, buffer); } + pub fn sendSetRunPrivacy( + self: *Client, + job_name: []const u8, + patch_json: []const u8, + api_key_hash: []const u8, + ) !void { + const stream = self.stream orelse return error.NotConnected; + + if (api_key_hash.len != 16) return error.InvalidApiKeyHash; + if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong; + if (patch_json.len == 0 or patch_json.len > 0xFFFF) return error.PayloadTooLarge; + + // [opcode] + // [api_key_hash:16] + // [job_name_len:1][job_name] + // [patch_len:2][patch_json] + const total_len = 1 + 16 + 1 + job_name.len + 2 + patch_json.len; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.set_run_privacy); + offset += 1; + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @as(u8, @intCast(job_name.len)); + offset += 1; + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @as(u16, @intCast(patch_json.len)), .big); + offset += 2; + @memcpy(buffer[offset .. 
offset + patch_json.len], patch_json); + + try frame.sendWebSocketFrame(stream, buffer); + } + pub fn sendAnnotateRun( self: *Client, job_name: []const u8, diff --git a/cli/src/net/ws/opcode.zig b/cli/src/net/ws/opcode.zig index 4f6e142..c688faf 100644 --- a/cli/src/net/ws/opcode.zig +++ b/cli/src/net/ws/opcode.zig @@ -6,6 +6,7 @@ pub const Opcode = enum(u8) { queue_job_with_note = 0x1B, annotate_run = 0x1C, set_run_narrative = 0x1D, + set_run_privacy = 0x1F, status_request = 0x02, cancel_job = 0x03, prune = 0x04, @@ -53,6 +54,7 @@ pub const queue_job_with_args = Opcode.queue_job_with_args; pub const queue_job_with_note = Opcode.queue_job_with_note; pub const annotate_run = Opcode.annotate_run; pub const set_run_narrative = Opcode.set_run_narrative; +pub const set_run_privacy = Opcode.set_run_privacy; pub const status_request = Opcode.status_request; pub const cancel_job = Opcode.cancel_job; pub const prune = Opcode.prune; diff --git a/cli/src/net/ws/response_handlers.zig b/cli/src/net/ws/response_handlers.zig index 514da66..134cc6f 100644 --- a/cli/src/net/ws/response_handlers.zig +++ b/cli/src/net/ws/response_handlers.zig @@ -69,6 +69,9 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8 colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin }); } + // Display system summary + colors.printInfo("\n=== Queue Summary ===\n", .{}); + // Display task summary if (root.get("tasks")) |tasks_obj| { const tasks = tasks_obj.object; @@ -78,11 +81,18 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8 const failed = tasks.get("failed").?.integer; const completed = tasks.get("completed").?.integer; colors.printInfo( - "Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n", + "Total: {d} | Queued: {d} | Running: {d} | Failed: {d} | Completed: {d}\n", .{ total, queued, running, failed, completed }, ); } + // Display queue depth if available + if 
(root.get("queue_length")) |ql| { + if (ql == .integer) { + colors.printInfo("Queue depth: {d}\n", .{ql.integer}); + } + } + const per_section_limit: usize = options.limit orelse 5; const TaskStatus = enum { queued, running, failed, completed }; @@ -120,43 +130,54 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8 _ = allocator2; const label = statusLabel(status); const want = statusMatch(status); - std.debug.print("\n{s}:\n", .{label}); + colors.printInfo("\n{s}:\n", .{label}); var shown: usize = 0; + var position: usize = 0; for (queue_items) |item| { if (item != .object) continue; const obj = item.object; const st = utils.jsonGetString(obj, "status") orelse ""; if (!std.mem.eql(u8, st, want)) continue; + position += 1; + if (shown >= limit2) continue; + const id = utils.jsonGetString(obj, "id") orelse ""; const job_name = utils.jsonGetString(obj, "job_name") orelse ""; const worker_id = utils.jsonGetString(obj, "worker_id") orelse ""; const err = utils.jsonGetString(obj, "error") orelse ""; + const priority = utils.jsonGetInt(obj, "priority") orelse 5; + + // Show queue position for queued jobs + const position_str = if (std.mem.eql(u8, want, "queued")) + try std.fmt.allocPrint(std.heap.page_allocator, " [pos {d}]", .{position}) + else + ""; + defer if (std.mem.eql(u8, want, "queued")) std.heap.page_allocator.free(position_str); if (std.mem.eql(u8, want, "failed")) { - colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name }); + colors.printWarning("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority }); if (worker_id.len > 0) { - std.debug.print(" (worker={s})", .{worker_id}); + std.debug.print(" worker={s}", .{worker_id}); } std.debug.print("\n", .{}); if (err.len > 0) { std.debug.print(" error: {s}\n", .{shorten(err, 160)}); } } else if (std.mem.eql(u8, want, "running")) { - colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name }); + colors.printInfo("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), 
job_name, position_str, priority }); if (worker_id.len > 0) { - std.debug.print(" (worker={s})", .{worker_id}); + std.debug.print(" worker={s}", .{worker_id}); } std.debug.print("\n", .{}); } else if (std.mem.eql(u8, want, "queued")) { - std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name }); + std.debug.print("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority }); } else { - colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name }); + colors.printSuccess("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority }); } shown += 1; - if (shown >= limit2) break; } if (shown == 0) { @@ -189,7 +210,7 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8 if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| { defer allocator.free(section); - colors.printInfo("{s}", .{section}); + colors.printInfo("\n{s}", .{section}); } } } diff --git a/cli/src/utils/crypto.zig b/cli/src/utils/crypto.zig index 7b5b959..cc1ed59 100644 --- a/cli/src/utils/crypto.zig +++ b/cli/src/utils/crypto.zig @@ -46,6 +46,65 @@ pub fn hashApiKey(allocator: std.mem.Allocator, api_key: []const u8) ![]u8 { return result; } +/// Calculate SHA256 hash of a file +pub fn hashFile(allocator: std.mem.Allocator, file_path: []const u8) ![]u8 { + var hasher = std.crypto.hash.sha2.Sha256.init(.{}); + + const file = try std.fs.cwd().openFile(file_path, .{}); + defer file.close(); + + var buf: [4096]u8 = undefined; + while (true) { + const bytes_read = try file.read(&buf); + if (bytes_read == 0) break; + hasher.update(buf[0..bytes_read]); + } + + var hash: [32]u8 = undefined; + hasher.final(&hash); + return encodeHexLower(allocator, &hash); +} + +/// Calculate combined hash of multiple files (sorted by path) +pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths: []const []const u8) ![]u8 { + var hasher = std.crypto.hash.sha2.Sha256.init(.{}); + + // Copy and sort paths for 
deterministic hashing + var sorted_paths = std.ArrayList([]const u8).initCapacity(allocator, file_paths.len) catch |err| { + return err; + }; + defer sorted_paths.deinit(allocator); + + for (file_paths) |path| { + try sorted_paths.append(allocator, path); + } + + std.sort.block([]const u8, sorted_paths.items, {}, struct { + fn lessThan(_: void, a: []const u8, b: []const u8) bool { + return std.mem.order(u8, a, b) == .lt; + } + }.lessThan); + + // Hash each file + for (sorted_paths.items) |path| { + hasher.update(path); + hasher.update(&[_]u8{0}); // Separator + + const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, path }); + defer allocator.free(full_path); + + const file_hash = try hashFile(allocator, full_path); + defer allocator.free(file_hash); + + hasher.update(file_hash); + hasher.update(&[_]u8{0}); // Separator + } + + var hash: [32]u8 = undefined; + hasher.final(&hash); + return encodeHexLower(allocator, &hash); +} + /// Calculate commit ID for a directory (SHA256 of tree state) pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 { var hasher = std.crypto.hash.sha2.Sha256.init(.{}); @@ -54,7 +113,7 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 { defer dir.close(); var walker = try dir.walk(allocator); - defer walker.deinit(); + defer walker.deinit(allocator); // Collect and sort paths for deterministic hashing var paths: std.ArrayList([]const u8) = .{}; diff --git a/cli/src/utils/io.zig b/cli/src/utils/io.zig index b79d99e..dd3d75c 100644 --- a/cli/src/utils/io.zig +++ b/cli/src/utils/io.zig @@ -57,3 +57,76 @@ pub fn stdoutWriter() std.Io.Writer { pub fn stderrWriter() std.Io.Writer { return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 }; } + +/// Write a JSON value to stdout +pub fn stdoutWriteJson(value: std.json.Value) !void { + var buf = std.ArrayList(u8).empty; + defer buf.deinit(std.heap.page_allocator); + try 
writeJSONValue(buf.writer(std.heap.page_allocator), value); + var stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO }; + try stdout_file.writeAll(buf.items); + try stdout_file.writeAll("\n"); +} + +fn writeJSONValue(writer: anytype, v: std.json.Value) !void { + switch (v) { + .null => try writer.writeAll("null"), + .bool => |b| try writer.print("{}", .{b}), + .integer => |i| try writer.print("{d}", .{i}), + .float => |f| try writer.print("{d}", .{f}), + .string => |s| try writeJSONString(writer, s), + .array => |arr| { + try writer.writeAll("["); + for (arr.items, 0..) |item, idx| { + if (idx > 0) try writer.writeAll(","); + try writeJSONValue(writer, item); + } + try writer.writeAll("]"); + }, + .object => |obj| { + try writer.writeAll("{"); + var first = true; + var it = obj.iterator(); + while (it.next()) |entry| { + if (!first) try writer.writeAll(","); + first = false; + try writer.print("\"{s}\":", .{entry.key_ptr.*}); + try writeJSONValue(writer, entry.value_ptr.*); + } + try writer.writeAll("}"); + }, + .number_string => |s| try writer.print("{s}", .{s}), + } +} + +fn writeJSONString(writer: anytype, s: []const u8) !void { + try writer.writeAll("\""); + for (s) |c| { + switch (c) { + '"' => try writer.writeAll("\\\""), + '\\' => try writer.writeAll("\\\\"), + '\n' => try writer.writeAll("\\n"), + '\r' => try writer.writeAll("\\r"), + '\t' => try writer.writeAll("\\t"), + else => { + if (c < 0x20) { + var buf: [6]u8 = undefined; + buf[0] = '\\'; + buf[1] = 'u'; + buf[2] = '0'; + buf[3] = '0'; + buf[4] = hexDigit(@intCast((c >> 4) & 0x0F)); + buf[5] = hexDigit(@intCast(c & 0x0F)); + try writer.writeAll(&buf); + } else { + try writer.writeAll(&[_]u8{c}); + } + }, + } + } + try writer.writeAll("\""); +} + +fn hexDigit(v: u8) u8 { + return if (v < 10) ('0' + v) else ('a' + (v - 10)); +} diff --git a/internal/api/routes.go b/internal/api/routes.go index 4953289..ac96eb2 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -61,6 +61,7 @@ 
func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) { s.taskQueue, s.db, s.config.BuildAuthConfig(), + nil, // privacyEnforcer - not enabled for now ) // Create jupyter handler diff --git a/internal/api/ws/validate.go b/internal/api/ws/validate.go index 261815a..7dddc95 100644 --- a/internal/api/ws/validate.go +++ b/internal/api/ws/validate.go @@ -72,6 +72,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er rmCommitCheck := map[string]interface{}{"ok": true} rmLocCheck := map[string]interface{}{"ok": true} rmLifecycle := map[string]interface{}{"ok": true} + var narrativeWarnings, outcomeWarnings []string // Determine expected location based on task status expectedLocation := "running" @@ -124,6 +125,24 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er rmLifecycle["ok"] = false ok = false } + + // Validate narrative if present + if rm.Narrative != nil { + nv := manifest.ValidateNarrative(rm.Narrative) + if len(nv.Errors) > 0 { + ok = false + } + narrativeWarnings = nv.Warnings + } + + // Validate outcome if present + if rm.Outcome != nil { + ov := manifest.ValidateOutcome(rm.Outcome) + if len(ov.Errors) > 0 { + ok = false + } + outcomeWarnings = ov.Warnings + } } checks["run_manifest"] = rmCheck @@ -159,8 +178,10 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er checks["snapshot"] = snapCheck report := map[string]interface{}{ - "ok": ok, - "checks": checks, + "ok": ok, + "checks": checks, + "narrative_warnings": narrativeWarnings, + "outcome_warnings": outcomeWarnings, } payloadBytes, _ := json.Marshal(report) return h.sendDataPacket(conn, "validate", payloadBytes)