diff --git a/src/network/protocol/lib.zig b/src/network/protocol/lib.zig index 67e8733..25a84f4 100644 --- a/src/network/protocol/lib.zig +++ b/src/network/protocol/lib.zig @@ -1,6 +1,6 @@ pub const messages = @import("./messages/lib.zig"); pub const NetworkAddress = @import("NetworkAddress.zig"); - +pub const BlockHeader = @import("../../types/BlockHeader.zig").BlockHeader; /// Network services pub const ServiceFlags = struct { pub const NODE_NETWORK: u64 = 0x1; diff --git a/src/network/protocol/messages/getblocks.zig b/src/network/protocol/messages/getblocks.zig index c7eb170..9e85ad2 100644 --- a/src/network/protocol/messages/getblocks.zig +++ b/src/network/protocol/messages/getblocks.zig @@ -152,10 +152,10 @@ test "ok_full_flow_GetBlocksMessage" { .header_hashes = try allocator.alloc([32]u8, 2), .stop_hash = [_]u8{0} ** 32, }; - defer allocator.free(gb.header_hashes); // Fill in the header_hashes + for (gb.header_hashes) |*hash| { for (hash) |*byte| { byte.* = 0xab; diff --git a/src/network/protocol/messages/lib.zig b/src/network/protocol/messages/lib.zig index facc74a..6397712 100644 --- a/src/network/protocol/messages/lib.zig +++ b/src/network/protocol/messages/lib.zig @@ -6,6 +6,7 @@ pub const GetaddrMessage = @import("getaddr.zig").GetaddrMessage; pub const GetblocksMessage = @import("getblocks.zig").GetblocksMessage; pub const PingMessage = @import("ping.zig").PingMessage; pub const PongMessage = @import("pong.zig").PongMessage; +pub const MerkleBlockMessage = @import("merkleblock.zig").MerkleBlockMessage; pub const FeeFilterMessage = @import("feefilter.zig").FeeFilterMessage; pub const SendCmpctMessage = @import("sendcmpct.zig").SendCmpctMessage; pub const FilterClearMessage = @import("filterclear.zig").FilterClearMessage; @@ -18,6 +19,7 @@ pub const MessageTypes = enum { getblocks, ping, pong, + merkleblock, sendcmpct, feefilter, filterclear, @@ -31,6 +33,7 @@ pub const Message = union(MessageTypes) { getblocks: GetblocksMessage, ping: PingMessage, pong: PongMessage, + merkleblock: MerkleBlockMessage, sendcmpct: SendCmpctMessage, feefilter: FeeFilterMessage, filterclear: FilterClearMessage, @@ -44,6 +47,7 @@ pub const Message = union(MessageTypes) { .getblocks => |m| @TypeOf(m).name(), .ping => |m| @TypeOf(m).name(), .pong => |m| @TypeOf(m).name(), + .merkleblock => |m| @TypeOf(m).name(), .sendcmpct => |m| @TypeOf(m).name(), .feefilter => |m| @TypeOf(m).name(), .filterclear => |m| @TypeOf(m).name(), @@ -59,6 +63,7 @@ pub const Message = union(MessageTypes) { .getblocks => |m| m.deinit(allocator), .ping => {}, .pong => {}, + .merkleblock => |m| m.deinit(allocator), .sendcmpct => {}, .feefilter => {}, .filterclear => {}, @@ -74,6 +79,7 @@ pub const Message = union(MessageTypes) { .getblocks => |m| m.checksum(), .ping => |m| m.checksum(), .pong => |m| m.checksum(), + .merkleblock => |m| m.checksum(), .sendcmpct => |m| m.checksum(), .feefilter => |m| m.checksum(), .filterclear => |m| m.checksum(), @@ -89,6 +95,7 @@ pub const Message = union(MessageTypes) { .getblocks => |m| m.hintSerializedLen(), .ping => |m| m.hintSerializedLen(), .pong => |m| m.hintSerializedLen(), + .merkleblock => |m| m.hintSerializedLen(), .sendcmpct => |m| m.hintSerializedLen(), .feefilter => |m| m.hintSerializedLen(), .filterclear => |m| m.hintSerializedLen(), diff --git a/src/network/protocol/messages/merkleblock.zig b/src/network/protocol/messages/merkleblock.zig new file mode 100644 index 0000000..cb452c8 --- /dev/null +++ b/src/network/protocol/messages/merkleblock.zig @@ -0,0 +1,166 @@ +const std = @import("std"); +const protocol = @import("../lib.zig"); + +const Sha256 = std.crypto.hash.sha2.Sha256; +const BlockHeader = @import("../../../types/BlockHeader.zig"); +const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; + +/// MerkleBlockMessage represents the "MerkleBlock" message +/// +/// https://developer.bitcoin.org/reference/p2p_networking.html#merkleblock +pub const MerkleBlockMessage = struct { + block_header: BlockHeader, + transaction_count: u32, + hashes: [][32]u8, + flags: []u8, + + const Self = @This(); + + pub fn name() *const [12]u8 { + return protocol.CommandNames.MERKLEBLOCK; + } + + /// Returns the message checksum + pub fn checksum(self: *const Self) [4]u8 { + var digest: [32]u8 = undefined; + var hasher = Sha256.init(.{}); + self.serializeToWriter(hasher.writer()) catch unreachable; + hasher.final(&digest); + + Sha256.hash(&digest, &digest, .{}); + + return digest[0..4].*; + } + + /// Free the allocated memory + pub fn deinit(self: *const Self, allocator: std.mem.Allocator) void { + allocator.free(self.flags); + allocator.free(self.hashes); + } + + /// Serialize the message as bytes and write them to the Writer. + pub fn serializeToWriter(self: *const Self, w: anytype) !void { + try self.block_header.serializeToWriter(w); + try w.writeInt(u32, self.transaction_count, .little); + const hash_count = CompactSizeUint.new(self.hashes.len); + try hash_count.encodeToWriter(w); + + for (self.hashes) |*hash| { + try w.writeAll(hash); + } + const flag_bytes = CompactSizeUint.new(self.flags.len); + + try flag_bytes.encodeToWriter(w); + try w.writeAll(self.flags); + } + /// Serialize a message as bytes and write them to the buffer. + /// + /// buffer.len must be >= than self.hintSerializedLen() + pub fn serializeToSlice(self: *const Self, buffer: []u8) !void { + var fbs = std.io.fixedBufferStream(buffer); + try self.serializeToWriter(fbs.writer()); + } + + /// Serialize a message as bytes and return them. + pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 { + const serialized_len = self.hintSerializedLen(); + const ret = try allocator.alloc(u8, serialized_len); + errdefer allocator.free(ret); + + try self.serializeToSlice(ret); + + return ret; + } + + /// Returns the hint of the serialized length of the message. + pub fn hintSerializedLen(self: *const Self) usize { + // 80 bytes for the block header, 4 bytes for the transaction count + const fixed_length = 84; + const hash_count_len: usize = CompactSizeUint.new(self.hashes.len).hint_encoded_len(); + const compact_hashes_len = 32 * self.hashes.len; + const flag_bytes_len: usize = CompactSizeUint.new(self.flags.len).hint_encoded_len(); + const flags_len = self.flags.len; + const variable_length = hash_count_len + compact_hashes_len + flag_bytes_len + flags_len; + return fixed_length + variable_length; + } + + pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self { + var merkle_block_message: Self = undefined; + merkle_block_message.block_header = try BlockHeader.deserializeReader(r); + merkle_block_message.transaction_count = try r.readInt(u32, .little); + + // Read CompactSize hash_count + const hash_count = try CompactSizeUint.decodeReader(r); + merkle_block_message.hashes = try allocator.alloc([32]u8, hash_count.value()); + errdefer allocator.free(merkle_block_message.hashes); + + for (merkle_block_message.hashes) |*hash| { + try r.readNoEof(hash); + } + + // Read CompactSize flags_count + const flags_count = try CompactSizeUint.decodeReader(r); + merkle_block_message.flags = try allocator.alloc(u8, flags_count.value()); + errdefer allocator.free(merkle_block_message.flags); + + try r.readNoEof(merkle_block_message.flags); + return merkle_block_message; + } + + /// Deserialize bytes into a `MerkleBlockMessage` + pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self { + var fbs = std.io.fixedBufferStream(bytes); + return try Self.deserializeReader(allocator, fbs.reader()); + } + + pub fn new(block_header: BlockHeader, transaction_count: u32, hashes: [][32]u8, flags: []u8) Self { + return .{ + .block_header = block_header, + .transaction_count = transaction_count, + .hashes = hashes, + .flags = flags, + }; + } +}; + +test "MerkleBlockMessage serialization and deserialization" { + const test_allocator = std.testing.allocator; + + const block_header = BlockHeader{ + .version = 1, + .prev_block = [_]u8{0} ** 32, + .merkle_root = [_]u8{1} ** 32, + .timestamp = 1234567890, + .bits = 0x1d00ffff, + .nonce = 987654321, + }; + const hashes = try test_allocator.alloc([32]u8, 3); + + const flags = try test_allocator.alloc(u8, 1); + const transaction_count = 1; + const msg = MerkleBlockMessage.new(block_header, transaction_count, hashes, flags); + + defer msg.deinit(test_allocator); + + // Fill in the header_hashes + for (msg.hashes) |*hash| { + for (hash) |*byte| { + byte.* = 0xab; + } + } + + flags[0] = 0x1; + + const serialized = try msg.serialize(test_allocator); + defer test_allocator.free(serialized); + + const deserialized = try MerkleBlockMessage.deserializeSlice(test_allocator, serialized); + defer deserialized.deinit(test_allocator); + + try std.testing.expectEqual(msg.block_header, deserialized.block_header); + try std.testing.expectEqual(msg.transaction_count, deserialized.transaction_count); + try std.testing.expectEqualSlices([32]u8, msg.hashes, deserialized.hashes); + try std.testing.expectEqualSlices(u8, msg.flags, deserialized.flags); + + try std.testing.expectEqual(msg.hintSerializedLen(), 84 + 1 + 32 * 3 + 1 + 1); +} diff --git a/src/network/wire/lib.zig b/src/network/wire/lib.zig index 8640c97..b908c18 100644 --- a/src/network/wire/lib.zig +++ b/src/network/wire/lib.zig @@ -18,6 +18,7 @@ pub const Error = error{ MessageTooLarge, }; +const BlockHeader = @import("../../types/BlockHeader.zig"); /// Return the checksum of a slice /// /// Use it on serialized messages to compute the header's value @@ -122,6 +123,8 @@ pub fn receiveMessage( protocol.messages.Message{ .ping = try protocol.messages.PingMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.PongMessage.name())) protocol.messages.Message{ .pong = try protocol.messages.PongMessage.deserializeReader(allocator, r) } + else if (std.mem.eql(u8, &command, protocol.messages.MerkleBlockMessage.name())) + protocol.messages.Message{ .merkleblock = try protocol.messages.MerkleBlockMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.SendCmpctMessage.name())) protocol.messages.Message{ .sendcmpct = try protocol.messages.SendCmpctMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.FilterClearMessage.name())) @@ -312,6 +315,58 @@ test "ok_send_ping_message" { } } +test "ok_send_merkleblock_message" { + const Config = @import("../../config/config.zig").Config; + const ArrayList = std.ArrayList; + const test_allocator = std.testing.allocator; + const MerkleBlockMessage = protocol.messages.MerkleBlockMessage; + + var list: std.ArrayListAligned(u8, null) = ArrayList(u8).init(test_allocator); + defer list.deinit(); + + const block_header = BlockHeader{ + .version = 1, + .prev_block = [_]u8{0} ** 32, + .merkle_root = [_]u8{1} ** 32, + .timestamp = 1234567890, + .bits = 0x1d00ffff, + .nonce = 987654321, + }; + const hashes = try test_allocator.alloc([32]u8, 3); + + const flags = try test_allocator.alloc(u8, 1); + const transaction_count = 1; + const message = MerkleBlockMessage.new(block_header, transaction_count, hashes, flags); + + defer message.deinit(test_allocator); + // Fill in the header_hashes + for (message.hashes) |*hash| { + for (hash) |*byte| { + byte.* = 0xab; + } + } + flags[0] = 0x1; + + const serialized = try message.serialize(test_allocator); + defer test_allocator.free(serialized); + + const deserialized = try MerkleBlockMessage.deserializeSlice(test_allocator, serialized); + defer deserialized.deinit(test_allocator); + + const received_message = try write_and_read_message(test_allocator, &list, Config.BitcoinNetworkId.MAINNET, Config.PROTOCOL_VERSION, message) orelse unreachable; + defer received_message.deinit(test_allocator); + + switch (received_message) { + .merkleblock => {}, + else => unreachable, + } + + try std.testing.expectEqual(received_message.hintSerializedLen(), 183); + try std.testing.expectEqualSlices(u8, received_message.merkleblock.flags, flags); + try std.testing.expectEqual(received_message.merkleblock.transaction_count, transaction_count); + try std.testing.expectEqualSlices([32]u8, received_message.merkleblock.hashes, hashes); +} + test "ok_send_pong_message" { const Config = @import("../../config/config.zig").Config; const ArrayList = std.ArrayList; diff --git a/src/node/ibd.zig b/src/node/ibd.zig index 06eb1a6..d1b7d9d 100644 --- a/src/node/ibd.zig +++ b/src/node/ibd.zig @@ -8,6 +8,6 @@ pub const IBD = struct { .p2p = p2p, }; } - + pub fn start(_: *IBD) !void {} }; diff --git a/src/types/BlockHeader.zig b/src/types/BlockHeader.zig new file mode 100644 index 0000000..aaf58ce --- /dev/null +++ b/src/types/BlockHeader.zig @@ -0,0 +1,40 @@ +const std = @import("std"); + +version: i32, +prev_block: [32]u8, +merkle_root: [32]u8, +timestamp: u32, +bits: u32, +nonce: u32, + +const Self = @This(); + +pub fn serializeToWriter(self: *const Self, writer: anytype) !void { + comptime { + if (!std.meta.hasFn(@TypeOf(writer), "writeInt")) @compileError("Expects r to have fn 'writeInt'."); + if (!std.meta.hasFn(@TypeOf(writer), "writeAll")) @compileError("Expects r to have fn 'writeAll'."); + } + + try writer.writeInt(i32, self.version, .little); + try writer.writeAll(std.mem.asBytes(&self.prev_block)); + try writer.writeAll(std.mem.asBytes(&self.merkle_root)); + try writer.writeInt(u32, self.timestamp, .little); + try writer.writeInt(u32, self.bits, .little); + try writer.writeInt(u32, self.nonce, .little); +} + +pub fn deserializeReader(r: anytype) !Self { + var bh: Self = undefined; + bh.version = try r.readInt(i32, .little); + try r.readNoEof(&bh.prev_block); + try r.readNoEof(&bh.merkle_root); + bh.timestamp = try r.readInt(u32, .little); + bh.bits = try r.readInt(u32, .little); + bh.nonce = try r.readInt(u32, .little); + + return bh; +} + +pub fn serializedLen() usize { + return 80; +} diff --git a/src/types/block.zig b/src/types/block.zig index 57b8ade..5ae383a 100644 --- a/src/types/block.zig +++ b/src/types/block.zig @@ -11,3 +11,4 @@ pub const Block = struct { return ret; } }; +