const std = @import("std");
const tls = @import("../tls.zig");
const resolve = @import("../resolve.zig");
const handshake = @import("../handshake.zig");

/// Transport abstraction for WebSocket connections.
/// Supports both raw TCP and TLS-wrapped connections.
pub const Transport = union(enum) {
    tcp: std.net.Stream,
    /// Heap-allocated by `createTransport` (via `allocator.create`).
    tls: *tls.TlsStream,

    /// Read into `buffer`, returning the number of bytes read.
    pub fn read(self: Transport, buffer: []u8) !usize {
        return switch (self) {
            .tcp => |s| s.read(buffer),
            .tls => |s| s.read(buffer),
        };
    }

    /// Write `buffer` to the stream. The TCP path uses `writeAll`;
    /// the TLS path delegates to `TlsStream.write`.
    pub fn write(self: Transport, buffer: []const u8) !void {
        return switch (self) {
            .tcp => |s| s.writeAll(buffer),
            .tls => |s| s.write(buffer),
        };
    }

    /// Close the underlying stream and invalidate `self`.
    /// NOTE(review): for `.tls` this does not `destroy` the `TlsStream` that
    /// `createTransport` heap-allocated; confirm whether `TlsStream.close`
    /// releases that memory, otherwise it leaks.
    pub fn close(self: *Transport) void {
        switch (self.*) {
            .tcp => |*s| s.close(),
            .tls => |s| s.close(),
        }
        self.* = undefined;
    }
};

/// Connection result returned after a successful connection.
pub const ConnectionResult = struct {
    transport: Transport,
    /// Copy of the host name allocated with the connecting allocator;
    /// release it (and the transport) with `closeConnection`.
    host: []const u8,
    port: u16,
    is_tls: bool,
};

/// Connection configuration.
pub const ConnectionConfig = struct {
    url: []const u8,
    api_key: []const u8,
    max_retries: u32 = 3,
};

/// Parse a WebSocket URL of the form `scheme://host[:port][/path]`.
///
/// `is_tls` is true only for the `wss://` scheme. When no port is given it
/// defaults to 9101 for TLS and 9100 otherwise. `host` and `path` are slices
/// borrowed from `url` (no allocation). `path` is empty when the URL has no `/`
/// after the host.
///
/// Errors: `error.InvalidURL` when `//` is missing or the host is empty;
/// `std.fmt.parseInt` errors (`Overflow`, `InvalidCharacter`) for a bad port.
pub fn parseUrl(url: []const u8) !struct { host: []const u8, port: u16, is_tls: bool, path: []const u8 } {
    const is_tls = std.mem.startsWith(u8, url, "wss://");

    const scheme_sep = std.mem.indexOf(u8, url, "//") orelse return error.InvalidURL;
    const authority_start = scheme_sep + 2;

    // The path (if any) begins at the first '/' after the authority.
    const path_start = std.mem.indexOfScalarPos(u8, url, authority_start, '/') orelse url.len;
    const authority = url[authority_start..path_start];

    // Only a ':' inside the authority separates host and port; a ':' in the path is ignored.
    const colon = std.mem.indexOfScalar(u8, authority, ':');
    const host = if (colon) |c| authority[0..c] else authority;
    if (host.len == 0) return error.InvalidURL;

    const port: u16 = if (colon) |c|
        try std.fmt.parseInt(u16, authority[c + 1 ..], 10)
    else if (is_tls) 9101 else 9100;

    return .{ .host = host, .port = port, .is_tls = is_tls, .path = url[path_start..] };
}

/// Resolve `host:port`, open a TCP connection and, when `is_tls` is set, wrap it
/// in a heap-allocated `tls.TlsStream`.
///
/// On any error every resource acquired so far (socket and TLS allocation) is
/// released exactly once. A `TlsStream.init` failure of `error.TlsLibraryRequired`
/// is logged and reported as `error.TLSNotSupported`.
pub fn createTransport(allocator: std.mem.Allocator, host: []const u8, port: u16, is_tls: bool) !Transport {
    const address = try resolve.resolveHostAddress(allocator, host, port);
    const tcp_stream = try std.net.tcpConnectToAddress(address);
    if (!is_tls) return .{ .tcp = tcp_stream };

    // Cleanup is handled solely by these errdefers (LIFO: free the TLS
    // allocation, then close the socket). Do not also clean up manually in the
    // `catch` below, or the allocation is freed twice.
    errdefer tcp_stream.close();
    const tls_stream = try allocator.create(tls.TlsStream);
    errdefer allocator.destroy(tls_stream);

    tls_stream.* = tls.TlsStream.init(allocator, tcp_stream, host) catch |err| {
        if (err == error.TlsLibraryRequired) {
            // NOTE(review): no fallback to ws:// happens in this file; the
            // caller receives error.TLSNotSupported.
            std.log.warn("TLS (wss://) support requires external TLS library integration. Falling back to ws://", .{});
            return error.TLSNotSupported;
        }
        return err;
    };
    return .{ .tls = tls_stream };
}

/// Connect to `url` and perform the WebSocket handshake, making up to
/// `max_retries` attempts in total with a linear backoff (1s, 2s, ... capped at 5s).
///
/// A malformed URL, and `error.TLSNotSupported`, are returned immediately
/// because retrying cannot fix them. If every attempt fails, the error from the
/// last attempt is returned. If `max_retries` is 0, `error.ConnectionFailed` is
/// returned.
pub fn connectWithRetry(
    allocator: std.mem.Allocator,
    url: []const u8,
    api_key: []const u8,
    max_retries: u32,
) !ConnectionResult {
    // Fail fast on a URL that can never parse instead of sleeping between retries.
    _ = try parseUrl(url);

    var retry_count: u32 = 0;
    var last_error: anyerror = error.ConnectionFailed;
    while (retry_count < max_retries) {
        const result = tryConnect(allocator, url, api_key) catch |err| {
            // Missing TLS support is a configuration problem, not a transient failure.
            if (err == error.TLSNotSupported) return err;

            last_error = err;
            retry_count += 1;
            if (retry_count < max_retries) {
                // Saturating multiply: no overflow trap for very large `max_retries`.
                const delay_ms = @min(1000 *| retry_count, 5000);
                std.log.warn("Connection failed (attempt {d}/{d}), retrying in {d}s...", .{ retry_count, max_retries, delay_ms / 1000 });
                std.Thread.sleep(@as(u64, delay_ms) * std.time.ns_per_ms);
            }
            continue;
        };
        if (retry_count > 0) {
            std.log.info("Connected successfully after {d} attempts", .{retry_count + 1});
        }
        return result;
    }
    return last_error;
}

/// Make a single connection attempt: parse `url`, open the transport, and run
/// the WebSocket handshake. On success the returned `host` is a copy allocated
/// with `allocator`. On failure the transport is closed.
fn tryConnect(allocator: std.mem.Allocator, url: []const u8, api_key: []const u8) !ConnectionResult {
    const parsed = try parseUrl(url);

    var transport = try createTransport(allocator, parsed.host, parsed.port, parsed.is_tls);
    errdefer transport.close();

    try handshake.handshake(allocator, transport, parsed.host, url, api_key);

    // `parsed.host` borrows from `url`; copy it so the result does not depend on `url`'s lifetime.
    const host_copy = try allocator.dupe(u8, parsed.host);
    errdefer allocator.free(host_copy);

    return .{
        .transport = transport,
        .host = host_copy,
        .port = parsed.port,
        .is_tls = parsed.is_tls,
    };
}

/// Close `transport` and free `host`, which must have been allocated with
/// `allocator` (for example `ConnectionResult.host`). An empty `host` is not freed.
pub fn closeConnection(transport: *Transport, host: []const u8, allocator: std.mem.Allocator) void {
    transport.close();
    if (host.len > 0) {
        allocator.free(host);
    }
}