diff --git a/cli/src/commands/cancel.zig b/cli/src/commands/cancel.zig index 90ceebc..e24aeff 100644 --- a/cli/src/commands/cancel.zig +++ b/cli/src/commands/cancel.zig @@ -4,48 +4,13 @@ const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); const logging = @import("../utils/logging.zig"); const colors = @import("../utils/colors.zig"); +const auth = @import("../utils/auth.zig"); pub const CancelOptions = struct { force: bool = false, json: bool = false, }; -const UserContext = struct { - name: []const u8, - admin: bool, - allocator: std.mem.Allocator, - - pub fn deinit(self: *UserContext) void { - self.allocator.free(self.name); - } -}; - -fn authenticateUser(allocator: std.mem.Allocator, config: Config) !UserContext { - // Validate API key by making a simple API call to the server - const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); - defer allocator.free(ws_url); - - // Try to connect with the API key to validate it - var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| { - switch (err) { - error.ConnectionRefused => return error.ConnectionFailed, - error.NetworkUnreachable => return error.ServerUnreachable, - error.InvalidURL => return error.ConfigInvalid, - else => return error.AuthenticationFailed, - } - }; - defer client.close(); - - // For now, create a user context after successful authentication - // In a real implementation, this would get user info from the server - const user_name = try allocator.dupe(u8, "authenticated_user"); - return UserContext{ - .name = user_name, - .admin = false, - .allocator = allocator, - }; -} - pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var options = CancelOptions{}; var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| { @@ -89,14 +54,14 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } // Authenticate with server to get user context - var user_context = try authenticateUser(allocator, config); + var user_context = try auth.authenticateUser(allocator, config); defer user_context.deinit(); const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); defer allocator.free(api_key_hash); // Connect to WebSocket and send cancel messages - const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + const ws_url = try config.getWebSocketUrl(allocator); defer allocator.free(ws_url); var client = try ws.Client.connect(allocator, ws_url, config.api_key); @@ -139,7 +104,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } } -fn cancelSingleJob(allocator: std.mem.Allocator, client: *ws.Client, user_context: UserContext, job_name: []const u8, options: CancelOptions, api_key_hash: []const u8) !void { +fn cancelSingleJob(allocator: std.mem.Allocator, client: *ws.Client, user_context: auth.UserContext, job_name: []const u8, options: CancelOptions, api_key_hash: []const u8) !void { try client.sendCancelJob(job_name, api_key_hash); // Receive structured response with user context diff --git a/cli/src/commands/status.zig b/cli/src/commands/status.zig index 7168bff..68b3256 100644 --- a/cli/src/commands/status.zig +++ b/cli/src/commands/status.zig @@ -3,6 +3,7 @@ const Config = @import("../config.zig").Config; const ws = @import("../net/ws/client.zig"); const crypto = @import("../utils/crypto.zig"); const colors = @import("../utils/colors.zig"); +const auth = @import("../utils/auth.zig"); pub const StatusOptions = struct { json: bool = false, @@ -11,16 +12,6 @@ pub const StatusOptions = struct { watch_interval: u32 = 5, }; -const UserContext = struct { - name: []const u8, - admin: bool, - allocator: std.mem.Allocator, - - pub fn deinit(self: *UserContext) void { - self.allocator.free(self.name); - } -}; - pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var options = StatusOptions{}; @@ -52,7 +43,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { return error.APIKeyMissing; } - var user_context = UserContext{ + var user_context = auth.UserContext{ .name = try allocator.dupe(u8, "default"), .admin = true, .allocator = allocator, @@ -66,11 +57,11 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } } -fn runSingleStatus(allocator: std.mem.Allocator, config: Config, user_context: UserContext, options: StatusOptions) !void { +fn runSingleStatus(allocator: std.mem.Allocator, config: Config, user_context: auth.UserContext, options: StatusOptions) !void { const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); defer allocator.free(api_key_hash); - const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + const ws_url = try config.getWebSocketUrl(allocator); defer allocator.free(ws_url); var client = try ws.Client.connect(allocator, ws_url, config.api_key); @@ -80,7 +71,7 @@ fn runSingleStatus(allocator: std.mem.Allocator, config: Config, user_context: U try client.receiveAndHandleStatusResponse(allocator, user_context, options); } -fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: UserContext, options: StatusOptions) !void { +fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: auth.UserContext, options: StatusOptions) !void { colors.printInfo("Starting watch mode (interval: {d}s). Press Ctrl+C to stop.\n", .{options.watch_interval}); while (true) { diff --git a/cli/src/config.zig b/cli/src/config.zig index e47f23f..104ad5c 100644 --- a/cli/src/config.zig +++ b/cli/src/config.zig @@ -203,4 +203,12 @@ pub const Config = struct { allocator.free(gpu_mem); } } + + /// Get WebSocket URL for connecting to the server + pub fn getWebSocketUrl(self: Config, allocator: std.mem.Allocator) ![]u8 { + const protocol = if (self.worker_port == 443) "wss" else "ws"; + return std.fmt.allocPrint(allocator, "{s}://{s}:{d}/ws", .{ + protocol, self.worker_host, self.worker_port, + }); + } }; diff --git a/cli/src/utils.zig b/cli/src/utils.zig index 25b4e70..b1f6d38 100644 --- a/cli/src/utils.zig +++ b/cli/src/utils.zig @@ -1,6 +1,8 @@ // Utilities module - exports all utility modules +pub const auth = @import("utils/auth.zig"); pub const colors = @import("utils/colors.zig"); pub const crypto = @import("utils/crypto.zig"); +pub const flags = @import("utils/flags.zig"); pub const history = @import("utils/history.zig"); pub const io = @import("utils/io.zig"); pub const logging = @import("utils/logging.zig"); diff --git a/cli/src/utils/auth.zig b/cli/src/utils/auth.zig new file mode 100644 index 0000000..060d171 --- /dev/null +++ b/cli/src/utils/auth.zig @@ -0,0 +1,69 @@ +const std = @import("std"); +const Config = @import("../config.zig").Config; +const ws = @import("../net/ws/client.zig"); +const colors = @import("colors.zig"); + +/// UserContext represents an authenticated user session +pub const UserContext = struct { + name: []const u8, + admin: bool, + allocator: std.mem.Allocator, + + pub fn deinit(self: *UserContext) void { + self.allocator.free(self.name); + } +}; + +/// Authenticate user with the server using API key +pub fn authenticateUser(allocator: std.mem.Allocator, config: Config) !UserContext { + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + // Try to connect with the API key to validate it + var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| { + switch (err) { + error.ConnectionRefused => return error.ConnectionFailed, + error.NetworkUnreachable => return error.ServerUnreachable, + error.InvalidURL => return error.ConfigInvalid, + else => return error.AuthenticationFailed, + } + }; + defer client.close(); + + // For now, create a user context after successful authentication + // In a real implementation, this would get user info from the server + const user_name = try allocator.dupe(u8, "authenticated_user"); + return UserContext{ + .name = user_name, + .admin = false, + .allocator = allocator, + }; +} + +/// Authenticate user and return context with connection +pub fn authenticateWithConnection( + allocator: std.mem.Allocator, + config: Config, +) !struct { UserContext, ws.Client } { + const ws_url = try std.fmt.allocPrint(allocator, "ws://{s}:9101/ws", .{config.worker_host}); + defer allocator.free(ws_url); + + var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| { + switch (err) { + error.ConnectionRefused => return error.ConnectionFailed, + error.NetworkUnreachable => return error.ServerUnreachable, + error.InvalidURL => return error.ConfigInvalid, + else => return error.AuthenticationFailed, + } + }; + errdefer client.close(); + + const user_name = try allocator.dupe(u8, "authenticated_user"); + const user_ctx = UserContext{ + .name = user_name, + .admin = false, + .allocator = allocator, + }; + + return .{ user_ctx, client }; +} diff --git a/cli/src/utils/flags.zig b/cli/src/utils/flags.zig new file mode 100644 index 0000000..30b4032 --- /dev/null +++ b/cli/src/utils/flags.zig @@ -0,0 +1,35 @@ +/// Common command-line flags shared across commands +pub const CommonFlags = struct { + json: bool = false, + dry_run: bool = false, + verbose: bool = false, + force: bool = false, + validate: bool = false, +}; + +/// Parse common flags from arguments +pub fn parseCommonFlags(args: []const []const u8, flags: *CommonFlags) struct { consumed: usize, had_help: bool } { + var i: usize = 0; + var had_help = false; + while (i < args.len) : (i += 1) { + const arg = args[i]; + if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { + had_help = true; + } else if (std.mem.eql(u8, arg, "--json")) { + flags.json = true; + } else if (std.mem.eql(u8, arg, "--dry-run")) { + flags.dry_run = true; + } else if (std.mem.eql(u8, arg, "--verbose") or std.mem.eql(u8, arg, "-v")) { + flags.verbose = true; + } else if (std.mem.eql(u8, arg, "--force") or std.mem.eql(u8, arg, "-f")) { + flags.force = true; + } else if (std.mem.eql(u8, arg, "--validate")) { + flags.validate = true; + } else { + break; // Stop at first non-flag argument + } + } + return .{ .consumed = i, .had_help = had_help }; +} + +const std = @import("std");