Skip to content

Commit

Permalink
feat(p2p/messages): add getblocktxn (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
bloomingpeach authored Oct 7, 2024
1 parent 7a02c47 commit 828cbb5
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
124 changes: 124 additions & 0 deletions src/network/protocol/messages/getblocktxn.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
const std = @import("std");
const protocol = @import("../lib.zig");

const BlockHeader = @import("../../../types/block_header.zig");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const genericChecksum = @import("lib.zig").genericChecksum;
const genericSerialize = @import("lib.zig").genericSerialize;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;
/// GetBlockTxnMessage represents the "GetBlockTxn" message
///
/// https://developer.bitcoin.org/reference/p2p_networking.html#getblocktxn
pub const GetBlockTxnMessage = struct {
block_hash: [32]u8,
indexes: []u64,

const Self = @This();

pub fn name() *const [12]u8 {
return protocol.CommandNames.GETBLOCKTXN ++ [_]u8{0};
}

/// Returns the message checksum
pub fn checksum(self: *const Self) [4]u8 {
return genericChecksum(self);
}

/// Free the allocated memory
pub fn deinit(self: *const Self, allocator: std.mem.Allocator) void {
allocator.free(self.indexes);
}

/// Serialize the message as bytes and write them to the Writer.
pub fn serializeToWriter(self: *const Self, w: anytype) !void {
comptime {
if (!std.meta.hasFn(@TypeOf(w), "writeAll")) @compileError("Expects r to have fn 'writeAll");
}
try w.writeAll(&self.block_hash);
const indexes_count = CompactSizeUint.new(self.indexes.len);
try indexes_count.encodeToWriter(w);
for (self.indexes) |*index| {
const compact_index = CompactSizeUint.new(index.*);
try compact_index.encodeToWriter(w);
}
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const Self, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
try self.serializeToWriter(fbs.writer());
}

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 {
return try genericSerialize(self, allocator);
}

/// Returns the hint of the serialized length of the message.
pub fn hintSerializedLen(self: *const Self) usize {
// 32 bytes for the block hash
const fixed_length = 32;

const indexes_count_length: usize = CompactSizeUint.new(self.indexes.len).hint_encoded_len();

var compact_indexes_length: usize = 0;
for (self.indexes) |index| {
compact_indexes_length += CompactSizeUint.new(index).hint_encoded_len();
}

const variable_length = indexes_count_length + compact_indexes_length;

return fixed_length + variable_length;
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self {
var blockhash: [32]u8 = undefined;
try r.readNoEof(&blockhash);

const indexes_count = try CompactSizeUint.decodeReader(r);
const indexes = try allocator.alloc(u64, indexes_count.value());
errdefer allocator.free(indexes);

for (indexes) |*index| {
const compact_index = try CompactSizeUint.decodeReader(r);
index.* = compact_index.value();
}

return new(blockhash, indexes);
}

/// Deserialize bytes into a `GetBlockTxnMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
return try genericDeserializeSlice(GetBlockTxnMessage, allocator, bytes);
}

pub fn new(block_hash: [32]u8, indexes: []u64) Self {
return .{
.block_hash = block_hash,
.indexes = indexes,
};
}
};

test "GetBlockTxnMessage serialization and deserialization" {
const test_allocator = std.testing.allocator;

const block_hash: [32]u8 = [_]u8{0} ** 32;
const indexes = try test_allocator.alloc(u64, 1);
indexes[0] = 123;
const msg = GetBlockTxnMessage.new(block_hash, indexes);

defer msg.deinit(test_allocator);

const serialized = try msg.serialize(test_allocator);
defer test_allocator.free(serialized);

const deserialized = try GetBlockTxnMessage.deserializeSlice(test_allocator, serialized);
defer deserialized.deinit(test_allocator);

try std.testing.expectEqual(msg.block_hash, deserialized.block_hash);
try std.testing.expectEqual(msg.indexes[0], msg.indexes[0]);
try std.testing.expectEqual(msg.hintSerializedLen(), 32 + 1 + 1);
}
9 changes: 9 additions & 0 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const Sha256 = std.crypto.hash.sha2.Sha256;
pub const NotFoundMessage = @import("notfound.zig").NotFoundMessage;
pub const SendHeadersMessage = @import("sendheaders.zig").SendHeadersMessage;
pub const FilterLoadMessage = @import("filterload.zig").FilterLoadMessage;
pub const GetBlockTxnMessage = @import("getblocktxn.zig").GetBlockTxnMessage;
pub const HeadersMessage = @import("headers.zig").HeadersMessage;
pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage;

Expand All @@ -40,6 +41,7 @@ pub const MessageTypes = enum {
notfound,
sendheaders,
filterload,
getblocktxn,
getdata,
headers,
cmpctblock,
Expand All @@ -64,6 +66,7 @@ pub const Message = union(MessageTypes) {
notfound: NotFoundMessage,
sendheaders: SendHeadersMessage,
filterload: FilterLoadMessage,
getblocktxn: GetBlockTxnMessage,
getdata: GetdataMessage,
headers: HeadersMessage,
cmpctblock: CmpctBlockMessage,
Expand All @@ -87,6 +90,7 @@ pub const Message = union(MessageTypes) {
.notfound => |m| @TypeOf(m).name(),
.sendheaders => |m| @TypeOf(m).name(),
.filterload => |m| @TypeOf(m).name(),
.getblocktxn => |m| @TypeOf(m).name(),
.getdata => |m| @TypeOf(m).name(),
.headers => |m| @TypeOf(m).name(),
.cmpctblock => |m| @TypeOf(m).name(),
Expand All @@ -105,6 +109,9 @@ pub const Message = union(MessageTypes) {
.filteradd => |*m| m.deinit(allocator),
.getdata => |*m| m.deinit(allocator),
.cmpctblock => |*m| m.deinit(allocator),
.sendheaders => {},
.filterload => {},
.getblocktxn => |*m| m.deinit(allocator),
.headers => |*m| m.deinit(allocator),
else => {}
}
Expand All @@ -128,6 +135,7 @@ pub const Message = union(MessageTypes) {
.notfound => |*m| m.checksum(),
.sendheaders => |*m| m.checksum(),
.filterload => |*m| m.checksum(),
.getblocktxn => |*m| m.checksum(),
.addr => |*m| m.checksum(),
.getdata => |*m| m.checksum(),
.headers => |*m| m.checksum(),
Expand All @@ -153,6 +161,7 @@ pub const Message = union(MessageTypes) {
.notfound => |m| m.hintSerializedLen(),
.sendheaders => |m| m.hintSerializedLen(),
.filterload => |*m| m.hintSerializedLen(),
.getblocktxn => |*m| m.hintSerializedLen(),
.addr => |*m| m.hintSerializedLen(),
.getdata => |m| m.hintSerializedLen(),
.headers => |*m| m.hintSerializedLen(),
Expand Down
9 changes: 9 additions & 0 deletions src/network/protocol/messages/merkleblock.zig
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ pub const MerkleBlockMessage = struct {

/// Serialize the message as bytes and write them to the Writer.
pub fn serializeToWriter(self: *const Self, w: anytype) !void {
comptime {
if (!std.meta.hasFn(@TypeOf(w), "writeInt")) @compileError("Expects w to have fn 'writeInt'.");
if (!std.meta.hasFn(@TypeOf(w), "writeAll")) @compileError("Expects w to have fn 'writeAll'.");
}
try self.block_header.serializeToWriter(w);
try w.writeInt(u32, self.transaction_count, .little);
const hash_count = CompactSizeUint.new(self.hashes.len);
Expand Down Expand Up @@ -68,6 +72,11 @@ pub const MerkleBlockMessage = struct {
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self {
comptime {
if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects r to have fn 'readInt'.");
if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects r to have fn 'readNoEof'.");
}

var merkle_block_message: Self = undefined;
merkle_block_message.block_header = try BlockHeader.deserializeReader(r);
merkle_block_message.transaction_count = try r.readInt(u32, .little);
Expand Down
35 changes: 35 additions & 0 deletions src/network/wire/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub const Error = error{
};

const BlockHeader = @import("../../types/block_header.zig");

/// Return the checksum of a slice
///
/// Use it on serialized messages to compute the header's value
Expand Down Expand Up @@ -143,6 +144,8 @@ pub fn receiveMessage(
protocol.messages.Message{ .sendheaders = try protocol.messages.SendHeadersMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.FilterLoadMessage.name()))
protocol.messages.Message{ .filterload = try protocol.messages.FilterLoadMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.GetBlockTxnMessage.name()))
protocol.messages.Message{ .getblocktxn = try protocol.messages.GetBlockTxnMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.GetdataMessage.name()))
protocol.messages.Message{ .getdata = try protocol.messages.GetdataMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.CmpctBlockMessage.name()))
Expand Down Expand Up @@ -519,6 +522,38 @@ test "ok_send_sendheaders_message" {
}
}

test "ok_send_getblocktxn_message" {
const Config = @import("../../config/config.zig").Config;
const ArrayList = std.ArrayList;
const test_allocator = std.testing.allocator;
const GetBlockTxnMessage = protocol.messages.GetBlockTxnMessage;

var list: std.ArrayListAligned(u8, null) = ArrayList(u8).init(test_allocator);
defer list.deinit();

const block_hash = [_]u8{1} ** 32;
const indexes = try test_allocator.alloc(u64, 1);
indexes[0] = 1;
const message = GetBlockTxnMessage.new(block_hash, indexes);
defer message.deinit(test_allocator);
var received_message = try write_and_read_message(
test_allocator,
&list,
Config.BitcoinNetworkId.MAINNET,
Config.PROTOCOL_VERSION,
message,
) orelse unreachable;
defer received_message.deinit(test_allocator);

switch (received_message) {
.getblocktxn => {
try std.testing.expectEqual(message.block_hash, received_message.getblocktxn.block_hash);
try std.testing.expectEqual(indexes[0], received_message.getblocktxn.indexes[0]);
},
else => unreachable,
}
}

test "ko_receive_invalid_payload_length" {
const Config = @import("../../config/config.zig").Config;
const ArrayList = std.ArrayList;
Expand Down

0 comments on commit 828cbb5

Please sign in to comment.