From 0ea5e9385fe353a19dad65386cce255bd4ee04e5 Mon Sep 17 00:00:00 2001 From: Drew Nutter Date: Thu, 19 Dec 2024 08:26:10 -0300 Subject: [PATCH] feat: add Window and SharedPointerWindow data structures --- src/common/lru.zig | 28 +----- src/sync/lib.zig | 4 + src/sync/shared_memory.zig | 148 +++++++++++++++++++++++++++++ src/utils/collections.zig | 188 +++++++++++++++++++++++++++++++++++++ 4 files changed, 345 insertions(+), 23 deletions(-) create mode 100644 src/sync/shared_memory.zig diff --git a/src/common/lru.zig b/src/common/lru.zig index 88b550d44..f2307afb0 100644 --- a/src/common/lru.zig +++ b/src/common/lru.zig @@ -1,9 +1,13 @@ const std = @import("std"); +const sig = @import("../sig.zig"); + const Allocator = std.mem.Allocator; const TailQueue = std.TailQueue; const testing = std.testing; const Mutex = std.Thread.Mutex; +const normalizeDeinitFunction = sig.sync.normalizeDeinitFunction; + pub const Kind = enum { locking, non_locking, @@ -30,29 +34,7 @@ pub fn LruCacheCustom( comptime DeinitContext: type, comptime deinitFn_: anytype, ) type { - const deinitFn = switch (@TypeOf(deinitFn_)) { - fn (*V, DeinitContext) void => deinitFn_, - - fn (V, DeinitContext) void => struct { - fn f(v: *V, ctx: DeinitContext) void { - deinitFn_(v.*, ctx); - } - }.f, - - fn (V) void => struct { - fn f(v: *V, _: DeinitContext) void { - V.deinit(v.*); - } - }.f, - - fn (*V) void => struct { - fn f(v: *V, _: DeinitContext) void { - V.deinit(v); - } - }.f, - - else => @compileError("unsupported deinit function type"), - }; + const deinitFn = normalizeDeinitFunction(V, DeinitContext, deinitFn_); return struct { mux: if (kind == .locking) Mutex else void, allocator: Allocator, diff --git a/src/sync/lib.zig b/src/sync/lib.zig index 8463a8632..ca0b54eb9 100644 --- a/src/sync/lib.zig +++ b/src/sync/lib.zig @@ -3,6 +3,7 @@ pub const ref = @import("ref.zig"); pub const mux = @import("mux.zig"); pub const once_cell = @import("once_cell.zig"); pub const reference_counter = @import("reference_counter.zig"); +pub const shared_memory = @import("shared_memory.zig"); pub const thread_pool = @import("thread_pool.zig"); pub const exit = @import("exit.zig"); @@ -13,6 +14,9 @@ pub const RwMux = mux.RwMux; pub const OnceCell = once_cell.OnceCell; pub const ReferenceCounter = reference_counter.ReferenceCounter; pub const RcSlice = reference_counter.RcSlice; +pub const SharedPointerWindow = shared_memory.SharedPointerWindow; pub const ThreadPool = thread_pool.ThreadPool; pub const ExitCondition = exit.ExitCondition; + +pub const normalizeDeinitFunction = shared_memory.normalizeDeinitFunction; diff --git a/src/sync/shared_memory.zig b/src/sync/shared_memory.zig new file mode 100644 index 000000000..ca1d1b319 --- /dev/null +++ b/src/sync/shared_memory.zig @@ -0,0 +1,148 @@ +const std = @import("std"); +const sig = @import("../sig.zig"); + +const Allocator = std.mem.Allocator; + +/// Thread safe Window that stores a single copy of data that is shared with +/// readers as a pointer to the underlying data inside the Window. +/// +/// - this struct owns the data and is responsible for freeing it +/// - the lifetime of returned pointer exceeds every read operation of that pointer, +/// even if another thread evicts it from the Window, as long as `release` is used properly. +pub fn SharedPointerWindow( + T: type, + deinitItem_: anytype, + DeinitContext: type, +) type { + const Window = sig.utils.collections.Window; + const Rc = sig.sync.Rc; + const deinitItem = normalizeDeinitFunction(T, DeinitContext, deinitItem_); + + return struct { + allocator: Allocator, + window: Window(Rc(T)), + center: std.atomic.Value(usize), + lock: std.Thread.RwLock = .{}, + deinit_context: DeinitContext, + discard_buf: []?Rc(T), + + const Self = @This(); + + pub fn init( + allocator: Allocator, + len: usize, + start: usize, + deinit_context: DeinitContext, + ) !Self { + const discard_buf = try allocator.alloc(?Rc(T), len); + return .{ + .allocator = allocator, + .window = try Window(Rc(T)).init(allocator, len, start), + .deinit_context = deinit_context, + .center = std.atomic.Value(usize).init(start), + .discard_buf = discard_buf, + }; + } + + pub fn deinit(self: Self) void { + for (self.window.state) |maybe_item| if (maybe_item) |item| { + self.releaseItem(item); + }; + self.window.deinit(); + } + + pub fn put(self: *Self, index: usize, value: T) !void { + const ptr = try Rc(T).create(self.allocator); + ptr.payload().* = value; + + const item_to_release = blk: { + self.lock.lock(); + defer self.lock.unlock(); + break :blk self.window.put(index, ptr) catch null; + }; + + if (item_to_release) |old| { + self.releaseItem(old); + } + } + + /// call `release` when you're done with the pointer + pub fn get(self: *Self, index: usize) ?*const T { + self.lock.lockShared(); + defer self.lock.lockShared(); + + if (self.window.get(index)) |element| { + return element.acquire().payload(); + } else { + return null; + } + } + + /// call `release` when you're done with the pointer + pub fn contains(self: *Self, index: usize) bool { + self.lock.lockShared(); + defer self.lock.lockShared(); + + return self.window.contains(index); + } + + pub fn realign(self: *Self, new_center: usize) void { + if (new_center == self.center.load(.monotonic)) return; + + const items_to_release = blk: { + self.lock.lock(); + defer self.lock.lock(); + + self.center.store(new_center, .monotonic); + break :blk self.window.realignGet(new_center, self.discard_buf); + }; + + for (items_to_release) |maybe_item| { + if (maybe_item) |item| { + self.releaseItem(item); + } + } + } + + pub fn release(self: *Self, ptr: *const T) void { + self.releaseItem(Rc(T).fromPayload(ptr)); + } + + fn releaseItem(self: *const Self, item: Rc(T)) void { + if (item.release()) |bytes_to_free| { + deinitItem(item.payload(), self.deinit_context); + self.allocator.free(bytes_to_free); + } + } + }; +} + +pub fn normalizeDeinitFunction( + V: type, + DeinitContext: type, + deinitFn: anytype, +) fn (*V, DeinitContext) void { + return switch (@TypeOf(deinitFn)) { + fn (*V, DeinitContext) void => deinitFn, + + fn (V, DeinitContext) void => struct { + fn f(v: *V, ctx: DeinitContext) void { + deinitFn(v.*, ctx); + } + }.f, + + fn (V) void => struct { + fn f(v: *V, _: DeinitContext) void { + V.deinit(v.*); + } + }.f, + + fn (*V) void => struct { + fn f(v: *V, _: DeinitContext) void { + V.deinit(v); + } + }.f, + + else => @compileError("unsupported deinit function type"), + }; +} diff --git a/src/utils/collections.zig b/src/utils/collections.zig index 0d09de6b2..853577100 100644 --- a/src/utils/collections.zig +++ b/src/utils/collections.zig @@ -486,6 +486,98 @@ pub fn orderSlices( return if (a.len == b.len) .eq else if (a.len > b.len) .gt else .lt; } +pub fn Window(T: type) type { + return struct { + state: []?T, + center: usize, + offset: usize, + + const Self = @This(); + + pub fn init(allocator: Allocator, len: usize, start: usize) !Self { + const state = try allocator.alloc(?T, len); + @memset(state, null); + return .{ + .state = state, + .center = start, + .offset = len - (start % len), + }; + } + + pub fn deinit(self: Self, allocator: Allocator) void { + allocator.free(self.state); + } + + pub fn put(self: *Self, index: usize, item: T) error{OutOfBounds}!?T { + if (!self.isInRange(index)) { + return error.OutOfBounds; + } + const ptr = self.getAssumed(index); + const old = ptr.*; + ptr.* = item; + return old; + } + + pub fn get(self: *Self, index: usize) ?T { + return if (self.isInRange(index)) self.getAssumed(index).* else null; + } + + pub fn contains(self: *Self, index: usize) bool { + return self.isInRange(index) and self.getAssumed(index).* != null; + } + + pub fn realignGet(self: *Self, new_center: usize, deletion_buf: []?T) []?T { + return self.realignImpl(new_center, deletion_buf).?; + } + + pub fn realign(self: *Self, new_center: usize) void { + _ = self.realignImpl(new_center, null); + } + + fn realignImpl(self: *Self, new_center: usize, optional_deletion_buf: ?[]?T) ?[]?T { + var return_buf: ?[]?T = null; + if (self.center < new_center) { + const num_to_delete = @min(new_center - self.center, self.state.len); + const low = self.lowest(); + return_buf = self.deleteRange(low, low + num_to_delete, optional_deletion_buf); + } else if (self.center > new_center) { + const num_to_delete = @min(self.center - new_center, self.state.len); + const top = self.highest() + 1; + return_buf = self.deleteRange(top - num_to_delete, top, optional_deletion_buf); + } + self.center = new_center; + return return_buf; + } + + fn isInRange(self: *const Self, index: usize) bool { + return index <= self.highest() and index >= self.lowest(); + } + + fn highest(self: *const Self) usize { + return self.center + self.state.len / 2 - (self.state.len + 1) % 2; + } + + fn lowest(self: *const Self) usize { + return self.center - self.state.len / 2; + } + + fn getAssumed(self: *Self, index: usize) *?T { + return &self.state[(index + self.offset) % self.state.len]; + } + + fn deleteRange(self: *Self, start: usize, end: usize, optional_deletion_buf: ?[]?T) ?[]?T { + for (start..end, 0..) |in_index, out_index| { + const item = self.getAssumed(in_index); + if (optional_deletion_buf) |deletion_buf| { + deletion_buf[out_index] = item.*; + } + item.* = null; + } + return if (optional_deletion_buf) |buf| buf[0 .. end - start] else null; + } + }; +} + const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; const expectEqualSlices = std.testing.expectEqualSlices; @@ -655,3 +747,99 @@ test "binarySearch slice of slices" { binarySearch([]const u8, &slices, &.{ 0, 0, 21 }, .any, order), ); } + +test "Window starts empty" { + var mgr = try Window(u64).init(std.testing.allocator, 5, 7); + defer mgr.deinit(std.testing.allocator); + for (0..20) |i| { + try std.testing.expect(null == mgr.get(i)); + } +} + +test "Window populates and repopulates (odd)" { + var mgr = try Window(u64).init(std.testing.allocator, 5, 7); + defer mgr.deinit(std.testing.allocator); + for (0..20) |i| { + const result = mgr.put(i, i * 10); + if (i < 5 or i > 9) { + try std.testing.expectError(error.OutOfBounds, result); + } else { + try std.testing.expectEqual(null, try result); + } + } + for (0..20) |i| { + const result = mgr.put(i, i * 100); + if (i < 5 or i > 9) { + try std.testing.expectError(error.OutOfBounds, result); + } else { + try std.testing.expectEqual(i * 10, try result); + } + } + for (0..20) |i| { + const result = mgr.get(i); + if (i < 5 or i > 9) { + try std.testing.expectEqual(null, result); + } else { + try std.testing.expectEqual(i * 100, result); + } + } +} + +test "Window populates (even)" { + var mgr = try Window(u64).init(std.testing.allocator, 4, 7); + defer mgr.deinit(std.testing.allocator); + for (0..20) |i| { + const result = mgr.put(i, i * 10); + if (i < 5 or i > 8) { + try std.testing.expectError(error.OutOfBounds, result); + } else { + try std.testing.expectEqual(null, try result); + } + } + for (0..20) |i| { + const result = mgr.get(i); + if (i < 5 or i > 8) { + try std.testing.expectEqual(null, result); + } else { + try std.testing.expectEqual(i * 10, result); + } + } +} + +test "Window realigns" { + var mgr = try Window(u64).init(std.testing.allocator, 4, 7); + defer mgr.deinit(std.testing.allocator); + for (5..9) |i| { + _ = try mgr.put(i, i * 10); + } + var deletion_buf: [4]?u64 = undefined; + + const deletion = mgr.realignGet(8, deletion_buf[0..]); + try std.testing.expectEqual(1, deletion.len); + try std.testing.expectEqual(50, deletion[0]); + + const deletion2 = mgr.realignGet(6, deletion_buf[0..]); + try std.testing.expectEqual(2, deletion2.len); + try std.testing.expectEqual(80, deletion2[0]); + try std.testing.expectEqual(null, deletion2[1]); + + for (0..20) |i| { + const result = mgr.get(i); + if (i < 6 or i > 7) { + try std.testing.expectEqual(null, result); + } else { + try std.testing.expectEqual(i * 10, result); + } + } + + const deletion3 = mgr.realignGet(20, deletion_buf[0..]); + try std.testing.expectEqual(4, deletion3.len); + try std.testing.expectEqual(null, deletion3[0]); + try std.testing.expectEqual(null, deletion3[1]); + try std.testing.expectEqual(60, deletion3[2]); + try std.testing.expectEqual(70, deletion3[3]); + + for (0..40) |i| { + try std.testing.expectEqual(null, mgr.get(i)); + } +}