feat: WebSocket API infrastructure improvements

Enhance WebSocket client and server components:
- Add new WebSocket opcodes (CompareRuns, FindRuns, ExportRun, SetRunOutcome)
- Improve WebSocket client with additional response handlers
- Add crypto utilities for secure WebSocket communications
- Add I/O utilities for WebSocket payload handling
- Enhance validation for WebSocket message payloads
- Update routes for new WebSocket endpoints
- Improve monitor and validate command WebSocket integrations
This commit is contained in:
Jeremie Fraeys 2026-02-18 21:27:48 -05:00
parent b2eba75f09
commit cb826b74a3
No known key found for this signature in database
9 changed files with 232 additions and 17 deletions

View file

@ -41,7 +41,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
try writer.print(" {s}", .{arg});
}
}
const remote_cmd = try remote_cmd_buffer.toOwnedSlice();
const remote_cmd = try remote_cmd_buffer.toOwnedSlice(allocator);
defer allocator.free(remote_cmd);
const ssh_cmd = try std.fmt.allocPrint(

View file

@ -224,8 +224,8 @@ fn printUsage() !void {
test "validate human report formatting" {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();
defer _ = gpa.deinit();
const payload =
\\{
@ -245,8 +245,8 @@ test "validate human report formatting" {
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{});
defer parsed.deinit();
var buf = std.ArrayList(u8).init(allocator);
defer buf.deinit();
var buf = std.ArrayList(u8).empty;
defer buf.deinit(allocator);
_ = try printHumanReport(buf.writer(), parsed.value.object, false);
try testing.expect(std.mem.indexOf(u8, buf.items, "failed_checks") != null);

View file

@ -272,6 +272,44 @@ pub const Client = struct {
try frame.sendWebSocketFrame(stream, buffer);
}
pub fn sendSetRunPrivacy(
self: *Client,
job_name: []const u8,
patch_json: []const u8,
api_key_hash: []const u8,
) !void {
const stream = self.stream orelse return error.NotConnected;
if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong;
if (patch_json.len == 0 or patch_json.len > 0xFFFF) return error.PayloadTooLarge;
// [opcode]
// [api_key_hash:16]
// [job_name_len:1][job_name]
// [patch_len:2][patch_json]
const total_len = 1 + 16 + 1 + job_name.len + 2 + patch_json.len;
var buffer = try self.allocator.alloc(u8, total_len);
defer self.allocator.free(buffer);
var offset: usize = 0;
buffer[offset] = @intFromEnum(opcode.set_run_privacy);
offset += 1;
@memcpy(buffer[offset .. offset + 16], api_key_hash);
offset += 16;
buffer[offset] = @as(u8, @intCast(job_name.len));
offset += 1;
@memcpy(buffer[offset .. offset + job_name.len], job_name);
offset += job_name.len;
std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @as(u16, @intCast(patch_json.len)), .big);
offset += 2;
@memcpy(buffer[offset .. offset + patch_json.len], patch_json);
try frame.sendWebSocketFrame(stream, buffer);
}
pub fn sendAnnotateRun(
self: *Client,
job_name: []const u8,

View file

@ -6,6 +6,7 @@ pub const Opcode = enum(u8) {
queue_job_with_note = 0x1B,
annotate_run = 0x1C,
set_run_narrative = 0x1D,
set_run_privacy = 0x1F,
status_request = 0x02,
cancel_job = 0x03,
prune = 0x04,
@ -53,6 +54,7 @@ pub const queue_job_with_args = Opcode.queue_job_with_args;
pub const queue_job_with_note = Opcode.queue_job_with_note;
pub const annotate_run = Opcode.annotate_run;
pub const set_run_narrative = Opcode.set_run_narrative;
pub const set_run_privacy = Opcode.set_run_privacy;
pub const status_request = Opcode.status_request;
pub const cancel_job = Opcode.cancel_job;
pub const prune = Opcode.prune;

View file

@ -69,6 +69,9 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin });
}
// Display system summary
colors.printInfo("\n=== Queue Summary ===\n", .{});
// Display task summary
if (root.get("tasks")) |tasks_obj| {
const tasks = tasks_obj.object;
@ -78,11 +81,18 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
const failed = tasks.get("failed").?.integer;
const completed = tasks.get("completed").?.integer;
colors.printInfo(
"Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n",
"Total: {d} | Queued: {d} | Running: {d} | Failed: {d} | Completed: {d}\n",
.{ total, queued, running, failed, completed },
);
}
// Display queue depth if available
if (root.get("queue_length")) |ql| {
if (ql == .integer) {
colors.printInfo("Queue depth: {d}\n", .{ql.integer});
}
}
const per_section_limit: usize = options.limit orelse 5;
const TaskStatus = enum { queued, running, failed, completed };
@ -120,43 +130,54 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
_ = allocator2;
const label = statusLabel(status);
const want = statusMatch(status);
std.debug.print("\n{s}:\n", .{label});
colors.printInfo("\n{s}:\n", .{label});
var shown: usize = 0;
var position: usize = 0;
for (queue_items) |item| {
if (item != .object) continue;
const obj = item.object;
const st = utils.jsonGetString(obj, "status") orelse "";
if (!std.mem.eql(u8, st, want)) continue;
position += 1;
if (shown >= limit2) continue;
const id = utils.jsonGetString(obj, "id") orelse "";
const job_name = utils.jsonGetString(obj, "job_name") orelse "";
const worker_id = utils.jsonGetString(obj, "worker_id") orelse "";
const err = utils.jsonGetString(obj, "error") orelse "";
const priority = utils.jsonGetInt(obj, "priority") orelse 5;
// Show queue position for queued jobs
const position_str = if (std.mem.eql(u8, want, "queued"))
try std.fmt.allocPrint(std.heap.page_allocator, " [pos {d}]", .{position})
else
"";
defer if (std.mem.eql(u8, want, "queued")) std.heap.page_allocator.free(position_str);
if (std.mem.eql(u8, want, "failed")) {
colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name });
colors.printWarning("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority });
if (worker_id.len > 0) {
std.debug.print(" (worker={s})", .{worker_id});
std.debug.print(" worker={s}", .{worker_id});
}
std.debug.print("\n", .{});
if (err.len > 0) {
std.debug.print(" error: {s}\n", .{shorten(err, 160)});
}
} else if (std.mem.eql(u8, want, "running")) {
colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name });
colors.printInfo("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority });
if (worker_id.len > 0) {
std.debug.print(" (worker={s})", .{worker_id});
std.debug.print(" worker={s}", .{worker_id});
}
std.debug.print("\n", .{});
} else if (std.mem.eql(u8, want, "queued")) {
std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name });
std.debug.print("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority });
} else {
colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name });
colors.printSuccess("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority });
}
shown += 1;
if (shown >= limit2) break;
}
if (shown == 0) {
@ -189,7 +210,7 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| {
defer allocator.free(section);
colors.printInfo("{s}", .{section});
colors.printInfo("\n{s}", .{section});
}
}
}

View file

@ -46,6 +46,65 @@ pub fn hashApiKey(allocator: std.mem.Allocator, api_key: []const u8) ![]u8 {
return result;
}
/// Calculate SHA256 hash of a file
pub fn hashFile(allocator: std.mem.Allocator, file_path: []const u8) ![]u8 {
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
const file = try std.fs.cwd().openFile(file_path, .{});
defer file.close();
var buf: [4096]u8 = undefined;
while (true) {
const bytes_read = try file.read(&buf);
if (bytes_read == 0) break;
hasher.update(buf[0..bytes_read]);
}
var hash: [32]u8 = undefined;
hasher.final(&hash);
return encodeHexLower(allocator, &hash);
}
/// Calculate combined hash of multiple files (sorted by path)
pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths: []const []const u8) ![]u8 {
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
// Copy and sort paths for deterministic hashing
var sorted_paths = std.ArrayList([]const u8).initCapacity(allocator, file_paths.len) catch |err| {
return err;
};
defer sorted_paths.deinit(allocator);
for (file_paths) |path| {
try sorted_paths.append(allocator, path);
}
std.sort.block([]const u8, sorted_paths.items, {}, struct {
fn lessThan(_: void, a: []const u8, b: []const u8) bool {
return std.mem.order(u8, a, b) == .lt;
}
}.lessThan);
// Hash each file
for (sorted_paths.items) |path| {
hasher.update(path);
hasher.update(&[_]u8{0}); // Separator
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, path });
defer allocator.free(full_path);
const file_hash = try hashFile(allocator, full_path);
defer allocator.free(file_hash);
hasher.update(file_hash);
hasher.update(&[_]u8{0}); // Separator
}
var hash: [32]u8 = undefined;
hasher.final(&hash);
return encodeHexLower(allocator, &hash);
}
/// Calculate commit ID for a directory (SHA256 of tree state)
pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
@ -54,7 +113,7 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
defer dir.close();
var walker = try dir.walk(allocator);
defer walker.deinit();
defer walker.deinit(allocator);
// Collect and sort paths for deterministic hashing
var paths: std.ArrayList([]const u8) = .{};

View file

@ -57,3 +57,76 @@ pub fn stdoutWriter() std.Io.Writer {
pub fn stderrWriter() std.Io.Writer {
return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 };
}
/// Write a JSON value to stdout
pub fn stdoutWriteJson(value: std.json.Value) !void {
var buf = std.ArrayList(u8).empty;
defer buf.deinit(std.heap.page_allocator);
try writeJSONValue(buf.writer(std.heap.page_allocator), value);
var stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
try stdout_file.writeAll(buf.items);
try stdout_file.writeAll("\n");
}
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
switch (v) {
.null => try writer.writeAll("null"),
.bool => |b| try writer.print("{}", .{b}),
.integer => |i| try writer.print("{d}", .{i}),
.float => |f| try writer.print("{d}", .{f}),
.string => |s| try writeJSONString(writer, s),
.array => |arr| {
try writer.writeAll("[");
for (arr.items, 0..) |item, idx| {
if (idx > 0) try writer.writeAll(",");
try writeJSONValue(writer, item);
}
try writer.writeAll("]");
},
.object => |obj| {
try writer.writeAll("{");
var first = true;
var it = obj.iterator();
while (it.next()) |entry| {
if (!first) try writer.writeAll(",");
first = false;
try writer.print("\"{s}\":", .{entry.key_ptr.*});
try writeJSONValue(writer, entry.value_ptr.*);
}
try writer.writeAll("}");
},
.number_string => |s| try writer.print("{s}", .{s}),
}
}
fn writeJSONString(writer: anytype, s: []const u8) !void {
try writer.writeAll("\"");
for (s) |c| {
switch (c) {
'"' => try writer.writeAll("\\\""),
'\\' => try writer.writeAll("\\\\"),
'\n' => try writer.writeAll("\\n"),
'\r' => try writer.writeAll("\\r"),
'\t' => try writer.writeAll("\\t"),
else => {
if (c < 0x20) {
var buf: [6]u8 = undefined;
buf[0] = '\\';
buf[1] = 'u';
buf[2] = '0';
buf[3] = '0';
buf[4] = hexDigit(@intCast((c >> 4) & 0x0F));
buf[5] = hexDigit(@intCast(c & 0x0F));
try writer.writeAll(&buf);
} else {
try writer.writeAll(&[_]u8{c});
}
},
}
}
try writer.writeAll("\"");
}
fn hexDigit(v: u8) u8 {
return if (v < 10) ('0' + v) else ('a' + (v - 10));
}

View file

@ -61,6 +61,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
s.taskQueue,
s.db,
s.config.BuildAuthConfig(),
nil, // privacyEnforcer - not enabled for now
)
// Create jupyter handler

View file

@ -72,6 +72,7 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
rmCommitCheck := map[string]interface{}{"ok": true}
rmLocCheck := map[string]interface{}{"ok": true}
rmLifecycle := map[string]interface{}{"ok": true}
var narrativeWarnings, outcomeWarnings []string
// Determine expected location based on task status
expectedLocation := "running"
@ -124,6 +125,24 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
rmLifecycle["ok"] = false
ok = false
}
// Validate narrative if present
if rm.Narrative != nil {
nv := manifest.ValidateNarrative(rm.Narrative)
if len(nv.Errors) > 0 {
ok = false
}
narrativeWarnings = nv.Warnings
}
// Validate outcome if present
if rm.Outcome != nil {
ov := manifest.ValidateOutcome(rm.Outcome)
if len(ov.Errors) > 0 {
ok = false
}
outcomeWarnings = ov.Warnings
}
}
checks["run_manifest"] = rmCheck
@ -159,8 +178,10 @@ func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) er
checks["snapshot"] = snapCheck
report := map[string]interface{}{
"ok": ok,
"checks": checks,
"ok": ok,
"checks": checks,
"narrative_warnings": narrativeWarnings,
"outcome_warnings": outcomeWarnings,
}
payloadBytes, _ := json.Marshal(report)
return h.sendDataPacket(conn, "validate", payloadBytes)