Skip to content

Commit

Permalink
feat(p2p/messages): add getblocks message (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
ybensacq authored Sep 25, 2024
1 parent e6c8425 commit 4d9a459
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 1 deletion.
181 changes: 181 additions & 0 deletions src/network/protocol/messages/getblocks.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
const std = @import("std");
const native_endian = @import("builtin").target.cpu.arch.endian();
const protocol = @import("../lib.zig");

const ServiceFlags = protocol.ServiceFlags;

const Endian = std.builtin.Endian;
const Sha256 = std.crypto.hash.sha2.Sha256;

const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;

/// GetblocksMessage represents the "getblocks" message
///
/// https://developer.bitcoin.org/reference/p2p_networking.html#getblocks
pub const GetblocksMessage = struct {

version: i32,
header_hashes: [] [32]u8,
stop_hash: [32]u8,

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.GETBLOCKS ++ [_]u8{0} ** 5;
}

/// Returns the message checksum
///
/// Computed as `Sha256(Sha256(self.serialize()))[0..4]`
pub fn checksum(self: *const GetblocksMessage) [4]u8 {
var digest: [32]u8 = undefined;
var hasher = Sha256.init(.{});
const writer = hasher.writer();
self.serializeToWriter(writer) catch unreachable; // Sha256.write is infaible
hasher.final(&digest);

Sha256.hash(&digest, &digest, .{});

return digest[0..4].*;
}

/// Free the `header_hashes`
pub fn deinit(self: GetblocksMessage, allocator: std.mem.Allocator) void {
allocator.free(self.header_hashes);
}

/// Serialize the message as bytes and write them to the Writer.
///
/// `w` should be a valid `Writer`.
pub fn serializeToWriter(self: *const GetblocksMessage, w: anytype) !void {
comptime {
if (!std.meta.hasFn(@TypeOf(w), "writeInt")) @compileError("Expects r to have fn 'writeInt'.");
if (!std.meta.hasFn(@TypeOf(w), "writeAll")) @compileError("Expects r to have fn 'writeAll'.");
}

try w.writeInt(i32, self.version, .little);
const compact_hash_count = CompactSizeUint.new(self.header_hashes.len);
try compact_hash_count.encodeToWriter(w);
for (self.header_hashes) |header_hash| {
try w.writeAll(&header_hash);
}
try w.writeAll(&self.stop_hash);
}

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const GetblocksMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const GetblocksMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !GetblocksMessage {
comptime {
if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects r to have fn 'readInt'.");
if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects r to have fn 'readNoEof'.");
}

var gb: GetblocksMessage = undefined;

gb.version = try r.readInt(i32, .little);

// Read CompactSize hash_count
const compact_hash_count = try CompactSizeUint.decodeReader(r);

// Allocate space for header_hashes based on hash_count
const header_hashes = try allocator.alloc([32]u8, compact_hash_count.value());

for (header_hashes) |*hash| {
try r.readNoEof(hash);
}
gb.header_hashes = header_hashes;

// Read the stop_hash (32 bytes)
try r.readNoEof(&gb.stop_hash);
return gb;
}

/// Deserialize bytes into a `GetblocksMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetblocksMessage {
var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();

return try GetblocksMessage.deserializeReader(allocator, reader);
}

pub fn hintSerializedLen(self: *const GetblocksMessage) usize {
const fixed_length = 4 + 32; // version (4 bytes) + stop_hash (32 bytes)
const compact_hash_count_len = CompactSizeUint.new(self.header_hashes.len).hint_encoded_len();
const header_hashes_len = self.header_hashes.len * 32; // hash (32 bytes)
return fixed_length + compact_hash_count_len + header_hashes_len;
}

pub fn eql(self: *const GetblocksMessage, other: *const GetblocksMessage) bool {
if (self.version != other.version or self.header_hashes.len != other.header_hashes.len) {
return false;
}

if (self.header_hashes.len != other.header_hashes.len) {
return false;
}

for (0..self.header_hashes.len) |i| {
if (!std.mem.eql(u8, self.header_hashes[i][0..], other.header_hashes[i][0..])) {
return false;
}
}

if (!std.mem.eql(u8, &self.stop_hash, &other.stop_hash)) {
return false;
}

return true;
}
};

// TESTS

test "ok_full_flow_GetBlocksMessage" {
const allocator = std.testing.allocator;

// With some header_hashes
{

const gb = GetblocksMessage{
.version = 42,
.header_hashes = try allocator.alloc([32]u8, 2),
.stop_hash = [_]u8{0} ** 32,
};

defer allocator.free(gb.header_hashes);

// Fill in the header_hashes
for (gb.header_hashes) |*hash| {
for (hash) |*byte| {
byte.* = 0xab;
}
}

const payload = try gb.serialize(allocator);
defer allocator.free(payload);

const deserialized_gb = try GetblocksMessage.deserializeSlice(allocator, payload);

try std.testing.expect(gb.eql(&deserialized_gb));
defer allocator.free(deserialized_gb.header_hashes);

}

}
6 changes: 6 additions & 0 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ pub const VersionMessage = @import("version.zig").VersionMessage;
pub const VerackMessage = @import("verack.zig").VerackMessage;
pub const MempoolMessage = @import("mempool.zig").MempoolMessage;
pub const GetaddrMessage = @import("getaddr.zig").GetaddrMessage;
pub const GetblocksMessage = @import("getblocks.zig").GetblocksMessage;
pub const PingMessage = @import("ping.zig").PingMessage;

pub const MessageTypes = enum {
Version,
Verack,
Mempool,
Getaddr,
Getblocks,
Ping,
};

Expand All @@ -18,6 +20,7 @@ pub const Message = union(MessageTypes) {
Verack: VerackMessage,
Mempool: MempoolMessage,
Getaddr: GetaddrMessage,
Getblocks: GetblocksMessage,
Ping: PingMessage,

pub fn deinit(self: Message, allocator: std.mem.Allocator) void {
Expand All @@ -26,6 +29,7 @@ pub const Message = union(MessageTypes) {
.Verack => {},
.Mempool => {},
.Getaddr => {},
.Getblocks => |m| m.deinit(allocator),
.Ping => {},
}
}
Expand All @@ -35,6 +39,7 @@ pub const Message = union(MessageTypes) {
.Verack => |m| m.checksum(),
.Mempool => |m| m.checksum(),
.Getaddr => |m| m.checksum(),
.Getblocks => |m| m.checksum(),
.Ping => |m| m.checksum(),
};
}
Expand All @@ -45,6 +50,7 @@ pub const Message = union(MessageTypes) {
.Verack => |m| m.hintSerializedLen(),
.Mempool => |m| m.hintSerializedLen(),
.Getaddr => |m| m.hintSerializedLen(),
.Getblocks => |m| m.hintSerializedLen(),
.Ping => |m| m.hintSerializedLen(),
};
}
Expand Down
57 changes: 56 additions & 1 deletion src/network/wire/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ const protocol = @import("../protocol/lib.zig");
const Stream = std.net.Stream;
const io = std.io;
const Sha256 = std.crypto.hash.sha2.Sha256;
const MAX_SIZE: usize = 0x02000000; // 32 MB

pub const Error = error{
MessageTooLarge,
};

/// Return the checksum of a slice
///
Expand Down Expand Up @@ -49,9 +54,16 @@ pub fn sendMessage(allocator: std.mem.Allocator, w: anytype, protocol_version: i
defer allocator.free(payload);
const checksum = computePayloadChecksum(payload);

// No payload will be longer than u32.MAX
const payload_len: u32 = @intCast(payload.len);

// Calculate total message size
const precomputed_total_size = 24; // network (4 bytes) + command (12 bytes) + payload size (4 bytes) + checksum (4 bytes)
const total_message_size = precomputed_total_size + payload_len;

if (total_message_size > MAX_SIZE) {
return Error.MessageTooLarge;
}

try w.writeAll(&network_id);
try w.writeAll(command);
try w.writeAll(std.mem.asBytes(&payload_len));
Expand Down Expand Up @@ -84,6 +96,8 @@ pub fn receiveMessage(allocator: std.mem.Allocator, r: anytype) !protocol.messag
protocol.messages.Message{ .Mempool = try protocol.messages.MempoolMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.GetaddrMessage.name()))
protocol.messages.Message{ .Getaddr = try protocol.messages.GetaddrMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.GetblocksMessage.name()))
protocol.messages.Message{ .Getblocks = try protocol.messages.GetblocksMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.PingMessage.name()))
protocol.messages.Message{ .Ping = try protocol.messages.PingMessage.deserializeReader(allocator, r) }
else
Expand Down Expand Up @@ -193,6 +207,46 @@ test "ok_send_mempool_message" {
}
}


test "ok_send_getblocks_message" {
const Config = @import("../../config/config.zig").Config;

const ArrayList = std.ArrayList;
const test_allocator = std.testing.allocator;
const GetblocksMessage = protocol.messages.GetblocksMessage;

var list: std.ArrayListAligned(u8, null) = ArrayList(u8).init(test_allocator);
defer list.deinit();

const message = GetblocksMessage{
.version = 42,
.header_hashes = try test_allocator.alloc([32]u8, 2),
.stop_hash = [_]u8{0} ** 32,
};

defer test_allocator.free(message.header_hashes);

// Fill in the header_hashes
for (message.header_hashes) |*hash| {
for (hash) |*byte| {
byte.* = 0xab;
}
}

const writer = list.writer();
try sendMessage(test_allocator, writer, Config.PROTOCOL_VERSION, Config.BitcoinNetworkId.MAINNET, message);
var fbs: std.io.FixedBufferStream([]u8) = std.io.fixedBufferStream(list.items);
const reader = fbs.reader();

const received_message = try receiveMessage(test_allocator, reader);
defer received_message.deinit(test_allocator);

switch (received_message) {
.Getblocks => |rm| try std.testing.expect(message.eql(&rm)),
else => unreachable,
}
}

test "ok_send_ping_message" {
const Config = @import("../../config/config.zig").Config;
const ArrayList = std.ArrayList;
Expand All @@ -218,6 +272,7 @@ test "ok_send_ping_message" {
}
}


test "ko_receive_invalid_payload_length" {
const Config = @import("../../config/config.zig").Config;
const ArrayList = std.ArrayList;
Expand Down

0 comments on commit 4d9a459

Please sign in to comment.