diff --git a/src/network/protocol/messages/getblocks.zig b/src/network/protocol/messages/getblocks.zig index 4366b82..a73e501 100644 --- a/src/network/protocol/messages/getblocks.zig +++ b/src/network/protocol/messages/getblocks.zig @@ -8,17 +8,12 @@ const Endian = std.builtin.Endian; const Sha256 = std.crypto.hash.sha2.Sha256; const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; -const MAX_SIZE: usize = 0x02000000; // 32 MB /// GetblocksMessage represents the "getblocks" message /// /// https://developer.bitcoin.org/reference/p2p_networking.html#getblocks pub const GetblocksMessage = struct { - pub const Error = error{ - MessageTooLarge, - }; - version: i32, hash_count: u64, header_hashes: [] [32]u8, @@ -73,40 +68,39 @@ pub const GetblocksMessage = struct { return ret; } + /// Serialize a message as bytes and write them to the buffer. + /// + /// buffer.len must be >= than self.hintSerializedLen() + pub fn serializeToSlice(self: *const GetblocksMessage, buffer: []u8) !void { + var fbs = std.io.fixedBufferStream(buffer); + const writer = fbs.writer(); + try self.serializeToWriter(writer); + } + pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !GetblocksMessage { comptime { if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects r to have fn 'readInt'."); if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects r to have fn 'readNoEof'."); } - const buffer: []const u8 = try r.readAllAlloc(std.heap.page_allocator, MAX_SIZE + 1 ); - - // check message size - const msg_size = buffer.len; - if (msg_size > MAX_SIZE) { - return error.MessageTooLarge; - } - var gb: GetblocksMessage = undefined; - var fbs = std.io.fixedBufferStream(buffer); - var reader = fbs.reader(); - gb.version = try reader.readInt(i32, .little); + gb.version = try r.readInt(i32, .little); // Read CompactSize hash_count - const compact_hash_count = try CompactSizeUint.decodeReader(reader); + const compact_hash_count = try CompactSizeUint.decodeReader(r); gb.hash_count = compact_hash_count.value(); // Allocate space for header_hashes based on hash_count const header_hashes = try allocator.alloc([32]u8, gb.hash_count); for (header_hashes) |*hash| { - try reader.readNoEof(hash); + try r.readNoEof(hash); } gb.header_hashes = header_hashes; // Read the stop_hash (32 bytes) - try reader.readNoEof(&gb.stop_hash); + try r.readNoEof(&gb.stop_hash); return gb; } @@ -134,10 +128,12 @@ pub const GetblocksMessage = struct { return false; } + var i: usize = 0; for (self.header_hashes) |*hash| { - if (!std.mem.eql(u8, &hash.*, &other.header_hashes[0])) { + if (!std.mem.eql(u8, &hash.*, &other.header_hashes[i])) { return false; } + i += 1; } if (!std.mem.eql(u8, &self.stop_hash, &other.stop_hash)) { @@ -146,16 +142,6 @@ pub const GetblocksMessage = struct { return true; } - - /// Serialize a message as bytes and write them to the buffer. - /// - /// buffer.len must be >= than self.hintSerializedLen() - pub fn serializeToSlice(self: *const GetblocksMessage, buffer: []u8) !void { - var fbs = std.io.fixedBufferStream(buffer); - const writer = fbs.writer(); - try self.serializeToWriter(writer); - } - }; // TESTS diff --git a/src/network/wire/lib.zig b/src/network/wire/lib.zig index 5faed44..cd278a3 100644 --- a/src/network/wire/lib.zig +++ b/src/network/wire/lib.zig @@ -15,6 +15,11 @@ const protocol = @import("../protocol/lib.zig"); const Stream = std.net.Stream; const io = std.io; const Sha256 = std.crypto.hash.sha2.Sha256; +const MAX_SIZE: usize = 0x02000000; // 32 MB + +pub const Error = error{ + MessageTooLarge, +}; /// Return the checksum of a slice /// @@ -49,9 +54,15 @@ pub fn sendMessage(allocator: std.mem.Allocator, w: anytype, protocol_version: i defer allocator.free(payload); const checksum = computePayloadChecksum(payload); - // No payload will be longer than u32.MAX const payload_len: u32 = @intCast(payload.len); + // Calculate total message size + const total_message_size = network_id.len + command.len + @sizeOf(u32) + @sizeOf(u32) + payload_len; + + if (total_message_size > MAX_SIZE) { + return Error.MessageTooLarge; + } + try w.writeAll(&network_id); try w.writeAll(command); try w.writeAll(std.mem.asBytes(&payload_len));