diff --git a/src/network/protocol/messages/filterload.zig b/src/network/protocol/messages/filterload.zig new file mode 100644 index 0000000..afa51db --- /dev/null +++ b/src/network/protocol/messages/filterload.zig @@ -0,0 +1,122 @@ +const std = @import("std"); +const protocol = @import("../lib.zig"); +const genericChecksum = @import("lib.zig").genericChecksum; + +const Sha256 = std.crypto.hash.sha2.Sha256; +const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; + +/// FilterLoadMessage represents the "filterload" message +/// +/// https://developer.bitcoin.org/reference/p2p_networking.html#filterload +pub const FilterLoadMessage = struct { + filter: []const u8, + hash_func: u32, + tweak: u32, + flags: u8, + + const Self = @This(); + + pub fn name() *const [12]u8 { + return protocol.CommandNames.FILTERLOAD ++ [_]u8{0} ** 2; + } + + /// Returns the message checksum + pub fn checksum(self: *const Self) [4]u8 { + return genericChecksum(self); + } + + /// Serialize the message as bytes and write them to the Writer. + pub fn serializeToWriter(self: *const Self, w: anytype) !void { + comptime { + if (!std.meta.hasFn(@TypeOf(w), "writeInt")) @compileError("Expects w to have fn 'writeInt'."); + if (!std.meta.hasFn(@TypeOf(w), "writeAll")) @compileError("Expects w to have fn 'writeAll'."); + } + + const compact_filter_len = CompactSizeUint.new(self.filter.len); + try compact_filter_len.encodeToWriter(w); + + try w.writeAll(self.filter); + try w.writeInt(u32, self.hash_func, .little); + try w.writeInt(u32, self.tweak, .little); + try w.writeInt(u8, self.flags, .little); + } + + /// Serialize a message as bytes and return them. + pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 { + const serialized_len = self.hintSerializedLen(); + + const ret = try allocator.alloc(u8, serialized_len); + errdefer allocator.free(ret); + + try self.serializeToSlice(ret); + + return ret; + } + + /// Serialize a message as bytes and write them to the buffer. + pub fn serializeToSlice(self: *const Self, buffer: []u8) !void { + var fbs = std.io.fixedBufferStream(buffer); + try self.serializeToWriter(fbs.writer()); + } + + pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self { + comptime { + if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects r to have fn 'readInt'."); + if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects r to have fn 'readNoEof'."); + } + + const filter_len = (try CompactSizeUint.decodeReader(r)).value(); + const filter = try allocator.alloc(u8, filter_len); + errdefer allocator.free(filter); + try r.readNoEof(filter); + + const hash_func = try r.readInt(u32, .little); + const tweak = try r.readInt(u32, .little); + const flags = try r.readInt(u8, .little); + + return Self{ + .filter = filter, + .hash_func = hash_func, + .tweak = tweak, + .flags = flags, + }; + } + + pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self { + var fbs = std.io.fixedBufferStream(bytes); + return try Self.deserializeReader(allocator, fbs.reader()); + } + + pub fn hintSerializedLen(self: *const Self) usize { + const fixed_length = 4 + 4 + 1; // hash_func (4 bytes) + tweak (4 bytes) + flags (1 byte) + const compact_filter_len = CompactSizeUint.new(self.filter.len).hint_encoded_len(); + return compact_filter_len + self.filter.len + fixed_length; + } + + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + allocator.free(self.filter); + } +}; + +test "ok_fullflow_filterload_message" { + const allocator = std.testing.allocator; + + const filter = "this is a test filter"; + var fl = FilterLoadMessage{ + .filter = filter, + .hash_func = 0xdeadbeef, + .tweak = 0xfeedface, + .flags = 0x02, + }; + + const payload = try fl.serialize(allocator); + defer allocator.free(payload); + + var deserialized_fl = try FilterLoadMessage.deserializeSlice(allocator, payload); + defer deserialized_fl.deinit(allocator); + + try std.testing.expectEqualSlices(u8, filter, deserialized_fl.filter); + try std.testing.expect(fl.hash_func == deserialized_fl.hash_func); + try std.testing.expect(fl.tweak == deserialized_fl.tweak); + try std.testing.expect(fl.flags == deserialized_fl.flags); +} diff --git a/src/network/protocol/messages/lib.zig b/src/network/protocol/messages/lib.zig index 9fd7b56..5a27d62 100644 --- a/src/network/protocol/messages/lib.zig +++ b/src/network/protocol/messages/lib.zig @@ -16,6 +16,7 @@ pub const FilterAddMessage = @import("filteradd.zig").FilterAddMessage; const Sha256 = std.crypto.hash.sha2.Sha256; pub const NotFoundMessage = @import("notfound.zig").NotFoundMessage; pub const SendHeadersMessage = @import("sendheaders.zig").SendHeadersMessage; +pub const FilterLoadMessage = @import("filterload.zig").FilterLoadMessage; pub const InventoryVector = struct { type: u32, @@ -63,6 +64,7 @@ pub const MessageTypes = enum { filteradd, notfound, sendheaders, + filterload, }; pub const Message = union(MessageTypes) { @@ -81,6 +83,7 @@ pub const Message = union(MessageTypes) { filteradd: FilterAddMessage, notfound: NotFoundMessage, sendheaders: SendHeadersMessage, + filterload: FilterLoadMessage, pub fn name(self: Message) *const [12]u8 { return switch (self) { @@ -99,6 +102,7 @@ pub const Message = union(MessageTypes) { .filteradd => |m| @TypeOf(m).name(), .notfound => |m| @TypeOf(m).name(), .sendheaders => |m| @TypeOf(m).name(), + .filterload => |m| @TypeOf(m).name(), }; } @@ -119,6 +123,7 @@ pub const Message = union(MessageTypes) { .filteradd => |*m| m.deinit(allocator), .notfound => {}, .sendheaders => {}, + .filterload => {}, } } @@ -139,6 +144,7 @@ pub const Message = union(MessageTypes) { .filteradd => |*m| m.checksum(), .notfound => |*m| m.checksum(), .sendheaders => |*m| m.checksum(), + .filterload => |*m| m.checksum(), }; } @@ -159,6 +165,7 @@ pub const Message = union(MessageTypes) { .filteradd => |*m| m.hintSerializedLen(), .notfound => |m| m.hintSerializedLen(), .sendheaders => |m| m.hintSerializedLen(), + .filterload => |*m| m.hintSerializedLen(), }; } }; diff --git a/src/network/wire/lib.zig b/src/network/wire/lib.zig index 4bf9595..bf0d022 100644 --- a/src/network/wire/lib.zig +++ b/src/network/wire/lib.zig @@ -139,6 +139,8 @@ pub fn receiveMessage( protocol.messages.Message{ .feefilter = try protocol.messages.FeeFilterMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.SendHeadersMessage.name())) protocol.messages.Message{ .sendheaders = try protocol.messages.SendHeadersMessage.deserializeReader(allocator, r) } + else if (std.mem.eql(u8, &command, protocol.messages.FilterLoadMessage.name())) + protocol.messages.Message{ .filterload = try protocol.messages.FilterLoadMessage.deserializeReader(allocator, r) } else { try r.skipBytes(payload_len, .{}); // Purge the wire return error.UnknownMessage;