Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ybensacq committed Sep 23, 2024
1 parent 8791393 commit eef9717
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 31 deletions.
46 changes: 16 additions & 30 deletions src/network/protocol/messages/getblocks.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@ const Endian = std.builtin.Endian;
const Sha256 = std.crypto.hash.sha2.Sha256;

const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const MAX_SIZE: usize = 0x02000000; // 32 MB

/// GetblocksMessage represents the "getblocks" message
///
/// https://developer.bitcoin.org/reference/p2p_networking.html#getblocks
pub const GetblocksMessage = struct {

pub const Error = error{
MessageTooLarge,
};

version: i32,
hash_count: u64,
header_hashes: [] [32]u8,
Expand Down Expand Up @@ -73,40 +68,39 @@ pub const GetblocksMessage = struct {
return ret;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const GetblocksMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !GetblocksMessage {
comptime {
if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects r to have fn 'readInt'.");
if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects r to have fn 'readNoEof'.");
}

const buffer: []const u8 = try r.readAllAlloc(std.heap.page_allocator, MAX_SIZE + 1 );

// check message size
const msg_size = buffer.len;
if (msg_size > MAX_SIZE) {
return error.MessageTooLarge;
}

var gb: GetblocksMessage = undefined;
var fbs = std.io.fixedBufferStream(buffer);
var reader = fbs.reader();

gb.version = try reader.readInt(i32, .little);
gb.version = try r.readInt(i32, .little);

// Read CompactSize hash_count
const compact_hash_count = try CompactSizeUint.decodeReader(reader);
const compact_hash_count = try CompactSizeUint.decodeReader(r);
gb.hash_count = compact_hash_count.value();

// Allocate space for header_hashes based on hash_count
const header_hashes = try allocator.alloc([32]u8, gb.hash_count);

for (header_hashes) |*hash| {
try reader.readNoEof(hash);
try r.readNoEof(hash);
}
gb.header_hashes = header_hashes;

// Read the stop_hash (32 bytes)
try reader.readNoEof(&gb.stop_hash);
try r.readNoEof(&gb.stop_hash);
return gb;
}

Expand Down Expand Up @@ -134,10 +128,12 @@ pub const GetblocksMessage = struct {
return false;
}

var i: usize = 0;
for (self.header_hashes) |*hash| {
if (!std.mem.eql(u8, &hash.*, &other.header_hashes[0])) {
if (!std.mem.eql(u8, &hash.*, &other.header_hashes[i])) {
return false;
}
i += 1;
}

if (!std.mem.eql(u8, &self.stop_hash, &other.stop_hash)) {
Expand All @@ -146,16 +142,6 @@ pub const GetblocksMessage = struct {

return true;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const GetblocksMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
}

};

// TESTS
Expand Down
13 changes: 12 additions & 1 deletion src/network/wire/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ const protocol = @import("../protocol/lib.zig");
const Stream = std.net.Stream;
const io = std.io;
const Sha256 = std.crypto.hash.sha2.Sha256;
const MAX_SIZE: usize = 0x02000000; // 32 MB

pub const Error = error{
MessageTooLarge,
};

/// Return the checksum of a slice
///
Expand Down Expand Up @@ -49,9 +54,15 @@ pub fn sendMessage(allocator: std.mem.Allocator, w: anytype, protocol_version: i
defer allocator.free(payload);
const checksum = computePayloadChecksum(payload);

// No payload will be longer than u32.MAX
const payload_len: u32 = @intCast(payload.len);

// Calculate total message size
const total_message_size = network_id.len + command.len + @sizeOf(u32) + @sizeOf(u32) + payload_len;

if (total_message_size > MAX_SIZE) {
return Error.MessageTooLarge;
}

try w.writeAll(&network_id);
try w.writeAll(command);
try w.writeAll(std.mem.asBytes(&payload_len));
Expand Down

0 comments on commit eef9717

Please sign in to comment.