From ea9f955caaa16e5b8b4efa47e3942c0cf2a1eafb Mon Sep 17 00:00:00 2001 From: Muki Kiboigo Date: Wed, 25 Sep 2024 13:19:23 -0700 Subject: [PATCH] continued reworking of sockets --- src/async/busy_loop.zig | 18 ++- src/async/completion.zig | 12 +- src/async/io_uring.zig | 328 +++++++++++++++++++++------------------ src/async/lib.zig | 15 +- src/core/server.zig | 81 +++++++--- src/core/socket.zig | 2 +- 6 files changed, 260 insertions(+), 196 deletions(-) diff --git a/src/async/busy_loop.zig b/src/async/busy_loop.zig index ec8aa1e..aac8fd8 100644 --- a/src/async/busy_loop.zig +++ b/src/async/busy_loop.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const builtin = @import("builtin"); const Completion = @import("completion.zig").Completion; const Async = @import("lib.zig").Async; @@ -67,19 +68,22 @@ pub const AsyncBusyLoop = struct { .accept => |inner| { const com_ptr = &self.completions[reaped]; - const res: i32 = blk: { + const res: std.posix.socket_t = blk: { const ad = std.posix.accept(inner.socket, null, null, 0) catch |e| { if (e == error.WouldBlock) { continue; } else { - break :blk -1; + switch (comptime builtin.target.os.tag) { + .windows => break :blk std.os.windows.ws2_32.INVALID_SOCKET, + else => break :blk -1, + } } }; - break :blk @intCast(ad); + break :blk ad; }; - com_ptr.result = @intCast(res); + com_ptr.result = .{ .socket = res }; com_ptr.context = inner.context; _ = list.swapRemove(i); i -|= 1; @@ -100,7 +104,7 @@ pub const AsyncBusyLoop = struct { break :blk @intCast(rd); }; - com_ptr.result = @intCast(len); + com_ptr.result = .{ .value = @intCast(len) }; com_ptr.context = inner.context; _ = list.swapRemove(i); i -|= 1; @@ -121,7 +125,7 @@ pub const AsyncBusyLoop = struct { break :blk @intCast(sd); }; - com_ptr.result = @intCast(len); + com_ptr.result = .{ .value = @intCast(len) }; com_ptr.context = inner.context; _ = list.swapRemove(i); i -|= 1; @@ -131,7 +135,7 @@ pub const AsyncBusyLoop = struct { .close => |inner| { const com_ptr = 
&self.completions[reaped]; std.posix.close(inner.socket); - com_ptr.result = 0; + com_ptr.result = .{ .value = 0 }; com_ptr.context = inner.context; _ = list.swapRemove(i); i -|= 1; diff --git a/src/async/completion.zig b/src/async/completion.zig index 58f4a03..4ec399d 100644 --- a/src/async/completion.zig +++ b/src/async/completion.zig @@ -1,11 +1,11 @@ const std = @import("std"); -const CompletionResult = union { - socket: std.posix.socket_t, - result: i32, -}; - pub const Completion = struct { + pub const Result = union { + socket: std.posix.socket_t, + value: i32, + }; + context: *anyopaque, - result: i32, + result: Result, }; diff --git a/src/async/io_uring.zig b/src/async/io_uring.zig index 64b3f5d..196724c 100644 --- a/src/async/io_uring.zig +++ b/src/async/io_uring.zig @@ -1,181 +1,201 @@ const std = @import("std"); const assert = std.debug.assert; const Completion = @import("completion.zig").Completion; - const Async = @import("lib.zig").Async; const AsyncError = @import("lib.zig").AsyncError; const AsyncOptions = @import("lib.zig").AsyncOptions; const log = std.log.scoped(.@"zzz/async/io_uring"); -pub const AsyncIoUring = struct { - const base_flags = std.os.linux.IORING_SETUP_COOP_TASKRUN | std.os.linux.IORING_SETUP_SINGLE_ISSUER; - runner: *anyopaque, - - pub fn init(allocator: std.mem.Allocator, options: AsyncOptions) !AsyncIoUring { - const uring = blk: { - if (options.in_thread) { - assert(options.root_async != null); - const parent_uring: *std.os.linux.IoUring = @ptrCast(@alignCast(options.root_async.?.runner)); - assert(parent_uring.fd >= 0); - - // Initialize using the WQ from the parent ring. 
- const flags: u32 = base_flags | std.os.linux.IORING_SETUP_ATTACH_WQ; - - var params = std.mem.zeroInit(std.os.linux.io_uring_params, .{ - .flags = flags, - .wq_fd = @as(u32, @intCast(parent_uring.fd)), - }); - - const uring = try allocator.create(std.os.linux.IoUring); - uring.* = try std.os.linux.IoUring.init_params( - std.math.ceilPowerOfTwoAssert(u16, options.size_connections_max), - &params, - ); - - break :blk uring; - } else { - // Initalize IO Uring - const uring = try allocator.create(std.os.linux.IoUring); - uring.* = try std.os.linux.IoUring.init( - std.math.ceilPowerOfTwoAssert(u16, options.size_connections_max), - base_flags, - ); - - break :blk uring; - } - }; - - return AsyncIoUring{ .runner = uring }; - } - - pub fn deinit(self: *Async, allocator: std.mem.Allocator) void { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - uring.deinit(); - allocator.destroy(uring); - } - - pub fn queue_accept( - self: *Async, - context: *anyopaque, - socket: std.posix.socket_t, - ) AsyncError!void { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - _ = uring.accept(@as(u64, @intFromPtr(context)), socket, null, null, 0) catch |e| switch (e) { - error.SubmissionQueueFull => return AsyncError.QueueFull, - else => unreachable, - }; - } - - pub fn queue_recv( - self: *Async, - context: *anyopaque, - socket: std.posix.socket_t, - buffer: []u8, - ) AsyncError!void { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - _ = uring.recv(@as(u64, @intFromPtr(context)), socket, .{ .buffer = buffer }, 0) catch |e| switch (e) { - error.SubmissionQueueFull => return AsyncError.QueueFull, - else => unreachable, - }; - } - - pub fn queue_send( - self: *Async, - context: *anyopaque, - socket: std.posix.socket_t, - buffer: []const u8, - ) AsyncError!void { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - _ = uring.send(@as(u64, @intFromPtr(context)), socket, buffer, 0) catch |e| switch 
(e) { - error.SubmissionQueueFull => return AsyncError.QueueFull, - else => unreachable, - }; - } - - pub fn queue_close( - self: *Async, - context: *anyopaque, - fd: std.posix.fd_t, - ) AsyncError!void { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - _ = uring.close(@as(u64, @intFromPtr(context)), fd) catch |e| switch (e) { - error.SubmissionQueueFull => return AsyncError.QueueFull, - else => unreachable, - }; - } - - pub fn submit(self: *Async) AsyncError!void { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - _ = uring.submit() catch |e| switch (e) { - // TODO: match error states. - else => unreachable, - }; - } - - pub fn reap(self: *Async) AsyncError![]Completion { - const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); - // NOTE: this can be dynamic and then we would just have to make a single call - // which would probably be better. - var cqes: [256]std.os.linux.io_uring_cqe = [_]std.os.linux.io_uring_cqe{undefined} ** 256; - var total_reaped: u64 = 0; - - const min_length = @min(cqes.len, self.completions.len); - { - // only the first one blocks waiting for an initial set of completions. - const count = uring.copy_cqes(cqes[0..min_length], 1) catch |e| switch (e) { - // TODO: match error states. +pub fn AsyncIoUring(comptime Provision: type) type { + return struct { + const Self = @This(); + const base_flags = std.os.linux.IORING_SETUP_COOP_TASKRUN | std.os.linux.IORING_SETUP_SINGLE_ISSUER; + runner: *anyopaque, + + pub fn init(allocator: std.mem.Allocator, options: AsyncOptions) !Self { + const uring = blk: { + if (options.in_thread) { + assert(options.root_async != null); + const parent_uring: *std.os.linux.IoUring = @ptrCast( + @alignCast(options.root_async.?.runner), + ); + assert(parent_uring.fd >= 0); + + // Initialize using the WQ from the parent ring. 
+ const flags: u32 = base_flags | std.os.linux.IORING_SETUP_ATTACH_WQ; + + var params = std.mem.zeroInit(std.os.linux.io_uring_params, .{ + .flags = flags, + .wq_fd = @as(u32, @intCast(parent_uring.fd)), + }); + + const uring = try allocator.create(std.os.linux.IoUring); + uring.* = try std.os.linux.IoUring.init_params( + std.math.ceilPowerOfTwoAssert(u16, options.size_connections_max), + &params, + ); + + break :blk uring; + } else { + // Initialize IO Uring + const uring = try allocator.create(std.os.linux.IoUring); + uring.* = try std.os.linux.IoUring.init( + std.math.ceilPowerOfTwoAssert(u16, options.size_connections_max), + base_flags, + ); + + break :blk uring; + } + }; + + return Self{ .runner = uring }; + } + + pub fn deinit(self: *Async, allocator: std.mem.Allocator) void { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + uring.deinit(); + allocator.destroy(uring); + } + + pub fn queue_accept( + self: *Async, + context: *anyopaque, + socket: std.posix.socket_t, + ) AsyncError!void { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + _ = uring.accept(@as(u64, @intFromPtr(context)), socket, null, null, 0) catch |e| switch (e) { + error.SubmissionQueueFull => return AsyncError.QueueFull, else => unreachable, }; + } - total_reaped += count; + pub fn queue_recv( + self: *Async, + context: *anyopaque, + socket: std.posix.socket_t, + buffer: []u8, + ) AsyncError!void { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + _ = uring.recv(@as(u64, @intFromPtr(context)), socket, .{ .buffer = buffer }, 0) catch |e| switch (e) { + error.SubmissionQueueFull => return AsyncError.QueueFull, + else => unreachable, + }; + } - // Copy over the first one. 
- for (0..total_reaped) |i| { - self.completions[i] = Completion{ - .result = cqes[i].res, - .context = @ptrFromInt(@as(usize, @intCast(cqes[i].user_data))), - }; - } + pub fn queue_send( + self: *Async, + context: *anyopaque, + socket: std.posix.socket_t, + buffer: []const u8, + ) AsyncError!void { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + _ = uring.send(@as(u64, @intFromPtr(context)), socket, buffer, 0) catch |e| switch (e) { + error.SubmissionQueueFull => return AsyncError.QueueFull, + else => unreachable, + }; } - while (total_reaped < self.completions.len) { - const start = total_reaped; - const remaining = self.completions.len - total_reaped; + pub fn queue_close( + self: *Async, + context: *anyopaque, + fd: std.posix.fd_t, + ) AsyncError!void { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + _ = uring.close(@as(u64, @intFromPtr(context)), fd) catch |e| switch (e) { + error.SubmissionQueueFull => return AsyncError.QueueFull, + else => unreachable, + }; + } - const count = uring.copy_cqes(cqes[0..remaining], 0) catch |e| switch (e) { + pub fn submit(self: *Async) AsyncError!void { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + _ = uring.submit() catch |e| switch (e) { // TODO: match error states. else => unreachable, }; + } + + pub fn reap(self: *Async) AsyncError![]Completion { + const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner)); + // NOTE: this can be dynamic and then we would just have to make a single call + // which would probably be better. + var cqes: [256]std.os.linux.io_uring_cqe = [_]std.os.linux.io_uring_cqe{undefined} ** 256; + var total_reaped: u64 = 0; + + const min_length = @min(cqes.len, self.completions.len); + { + // only the first one blocks waiting for an initial set of completions. + const count = uring.copy_cqes(cqes[0..min_length], 1) catch |e| switch (e) { + // TODO: match error states. 
+ else => unreachable, + }; + + total_reaped += count; + + // Copy over the first one. + for (0..total_reaped) |i| { + const provision: *Provision = @ptrFromInt(@as(usize, cqes[i].user_data)); + + const result: Completion.Result = if (provision.job == .accept) .{ + .socket = cqes[i].res, + } else .{ + .value = cqes[i].res, + }; - if (count == 0) { - return self.completions[0..total_reaped]; + self.completions[i] = Completion{ + .result = result, + .context = @ptrFromInt(@as(usize, @intCast(cqes[i].user_data))), + }; + } } - total_reaped += count; + while (total_reaped < self.completions.len) { + const start = total_reaped; + const remaining = self.completions.len - total_reaped; - for (start..total_reaped) |i| { - const cqe_index = i - start; - self.completions[i] = Completion{ - .result = cqes[cqe_index].res, - .context = @ptrFromInt(@as(usize, @intCast(cqes[cqe_index].user_data))), + const count = uring.copy_cqes(cqes[0..remaining], 0) catch |e| switch (e) { + // TODO: match error states. 
+ else => unreachable, }; + + if (count == 0) { + return self.completions[0..total_reaped]; + } + + total_reaped += count; + + for (start..total_reaped) |i| { + const cqe_index = i - start; + const provision: *Provision = @ptrFromInt(@as(usize, cqes[cqe_index].user_data)); + + const result: Completion.Result = if (provision.job == .accept) .{ + .socket = cqes[cqe_index].res, + } else .{ + .value = cqes[cqe_index].res, + }; + + self.completions[i] = Completion{ + .result = result, + .context = @ptrFromInt(@as(usize, @intCast(cqes[cqe_index].user_data))), + }; + } } + + return self.completions[0..total_reaped]; } - return self.completions[0..total_reaped]; - } - - pub fn to_async(self: *AsyncIoUring) Async { - return Async{ - .runner = self.runner, - ._deinit = deinit, - ._queue_accept = queue_accept, - ._queue_recv = queue_recv, - ._queue_send = queue_send, - ._queue_close = queue_close, - ._submit = submit, - ._reap = reap, - }; - } -}; + pub fn to_async(self: *Self) Async { + return Async{ + .runner = self.runner, + ._deinit = deinit, + ._queue_accept = queue_accept, + ._queue_recv = queue_recv, + ._queue_send = queue_send, + ._queue_close = queue_close, + ._submit = submit, + ._reap = reap, + }; + } + }; +} diff --git a/src/async/lib.zig b/src/async/lib.zig index d3fc1ca..4ace051 100644 --- a/src/async/lib.zig +++ b/src/async/lib.zig @@ -14,18 +14,18 @@ pub const AsyncType = union(enum) { /// `https://kernel.dk/io_uring.pdf` io_uring, /// Available on most targets. - /// Slowest by far. + /// Slowest. Workable for development. + /// Should rely on one of the faster backends for production. /// Relies on non-blocking sockets and busy loop polling. busy_loop, /// Available on all targets. - /// You have to provide all of the methods. 
custom: type, }; pub fn auto_async_match() AsyncType { - switch (builtin.os.tag) { + switch (comptime builtin.target.os.tag) { .linux => { - if (builtin.os.isAtLeast(.linux, .{ .major = 5, .minor = 1, .patch = 0 })) |geq| { + if (comptime builtin.target.os.isAtLeast(.linux, .{ .major = 5, .minor = 1, .patch = 0 })) |geq| { if (geq) { return AsyncType.io_uring; } else { @@ -36,6 +36,9 @@ pub fn auto_async_match() AsyncType { } }, .windows => return AsyncType.busy_loop, + .ios, .macos, .watchos, .tvos, .visionos => return AsyncType.busy_loop, + .kfreebsd, .freebsd, .openbsd, .netbsd, .dragonfly => return AsyncType.busy_loop, + .solaris, .illumos => return AsyncType.busy_loop, else => @compileError("Unsupported platform! Provide a custom Async backend."), } } @@ -45,8 +48,12 @@ pub const AsyncError = error{ }; pub const AsyncOptions = struct { + /// The root Async that this should inherit + /// parameters from. This is useful for io_uring. root_async: ?Async = null, + /// Is this Async instance spawning within a thread? in_thread: bool = false, + /// Maximum number of connections for this backend. 
size_connections_max: u16, }; diff --git a/src/core/server.zig b/src/core/server.zig index 3eb5df9..7085c56 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -4,6 +4,7 @@ const assert = std.debug.assert; const log = std.log.scoped(.@"zzz/server"); const Completion = @import("../async/completion.zig").Completion; +const CompletionResult = @import("../async/completion.zig").CompletionResult; const Async = @import("../async/lib.zig").Async; const auto_async_match = @import("../async/lib.zig").auto_async_match; const AsyncType = @import("../async/lib.zig").AsyncType; @@ -174,8 +175,7 @@ pub fn Server( if (self.socket) |socket| { switch (comptime Socket) { std.posix.socket_t => std.posix.close(socket), - std.os.windows.ws2_32.SOCKET => std.os.windows.closesocket(socket), - else => {}, + else => unreachable, } } @@ -197,7 +197,12 @@ pub fn Server( assert(port > 0); defer assert(self.socket != null); - const addr = try std.net.Address.resolveIp(host, port); + const addr = blk: { + switch (comptime builtin.os.tag) { + .windows => break :blk try std.net.Address.parseIp(host, port), + else => break :blk try std.net.Address.resolveIp(host, port), + } + }; const socket: std.posix.socket_t = blk: { const socket_flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK; @@ -254,7 +259,10 @@ pub fn Server( log.info("{d} - closing connection", .{provision.index}); std.posix.close(provision.socket); - provision.socket = -1; + switch (comptime builtin.target.os.tag) { + .windows => provision.socket = std.os.windows.ws2_32.INVALID_SOCKET, + else => provision.socket = -1, + } _ = provision.arena.reset(.{ .retain_with_limit = config.size_connection_arena_retain }); provision.data.clean(); provision.recv_buffer.clearRetainingCapacity(); @@ -341,30 +349,55 @@ pub fn Server( switch (p.job) { .accept => { accept_queued = false; - const socket: Socket = completion.result; + const socket: Socket = completion.result.socket; - if (socket < 0) { - 
log.err("socket accept failed", .{}); - continue :reap_loop; - } + const index = blk: { + switch (comptime builtin.target.os.tag) { + .windows => { + if (socket == std.os.windows.ws2_32.INVALID_SOCKET) { + log.err("socket accept failed", .{}); + continue :reap_loop; + } + + break :blk 0; + }, + else => { + if (socket < 0) { + log.err("socket accept failed", .{}); + continue :reap_loop; + } + + break :blk socket; + }, + } + }; // Borrow a provision from the pool otherwise close the socket. - const borrowed = provision_pool.borrow(@intCast(completion.result)) catch { + const borrowed = provision_pool.borrow(@intCast(index)) catch { log.warn("out of provision pool entries", .{}); std.posix.close(socket); continue :reap_loop; }; - switch (comptime Socket) { - std.posix.socket_t => { - try std.posix.setsockopt( + // Disable Nagle's. + try std.posix.setsockopt( + socket, + std.posix.IPPROTO.TCP, + std.posix.TCP.NODELAY, + &std.mem.toBytes(@as(c_int, 1)), + ); + + // Set non-blocking. + switch (comptime builtin.target.os.tag) { + .windows => { + var mode: u32 = 1; + _ = std.os.windows.ws2_32.ioctlsocket( socket, - std.posix.IPPROTO.TCP, - std.posix.TCP.NODELAY, - &std.mem.toBytes(@as(c_int, 1)), + std.os.windows.ws2_32.FIONBIO, + &mode, ); - - // Set this socket as non-blocking. + }, + else => { const current_flags = try std.posix.fcntl(socket, std.posix.F.GETFL, 0); var new_flags = @as( std.posix.O, @@ -374,7 +407,6 @@ pub fn Server( const arg: u32 = @bitCast(new_flags); _ = try std.posix.fcntl(socket, std.posix.F.SETFL, arg); }, - else => {}, } const provision = borrowed.item; @@ -417,7 +449,7 @@ pub fn Server( log.debug("{d} - recv triggered", .{p.index}); // If the socket is closed. 
- if (completion.result <= 0) { + if (completion.result.value <= 0) { if (comptime security == .tls) { const tls_ptr: *?TLS = &tls_pool[p.index]; clean_tls(tls_ptr); @@ -427,7 +459,7 @@ pub fn Server( continue :reap_loop; } - const read_count: u32 = @intCast(completion.result); + const read_count: u32 = @intCast(completion.result.value); inner.count += read_count; const pre_recv_buffer = p.buffer[0..read_count]; @@ -508,7 +540,7 @@ pub fn Server( .send => |*send_type| { log.debug("{d} - send triggered", .{p.index}); - const send_count = completion.result; + const send_count = completion.result.value; if (send_count <= 0) { if (comptime security == .tls) { @@ -640,6 +672,7 @@ pub fn Server( switch (comptime Socket) { std.posix.socket_t => try std.posix.listen(server_socket, self.config.size_backlog), + // TODO: Handle freestanding targets that use an u32 here. else => unreachable, } @@ -652,7 +685,7 @@ pub fn Server( switch (comptime async_type) { .io_uring => { - var uring = try AsyncIoUring.init( + var uring = try AsyncIoUring(Provision).init( self.allocator, options, ); @@ -734,7 +767,7 @@ pub fn Server( switch (comptime async_type) { .io_uring => { - var uring = AsyncIoUring.init( + var uring = AsyncIoUring(Provision).init( z_config.allocator, options, ) catch unreachable; diff --git a/src/core/socket.zig b/src/core/socket.zig index 16bfc52..afeac88 100644 --- a/src/core/socket.zig +++ b/src/core/socket.zig @@ -1,7 +1,7 @@ const std = @import("std"); const builtin = @import("builtin"); -pub const Socket = switch (builtin.os.tag) { +pub const Socket = switch (builtin.target.os.tag) { .freestanding => u32, else => std.posix.socket_t, };