Skip to content

Commit

Permalink
feat: (p2p/messages): add cmpctblock message (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
supreme2580 authored Oct 3, 2024
1 parent 292db95 commit 002228a
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 0 deletions.
225 changes: 225 additions & 0 deletions src/network/protocol/messages/cmpctblock.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const Transaction = @import("../../../types/transaction.zig");

const Sha256 = std.crypto.hash.sha2.Sha256;
const BlockHeader = @import("../../../types/block_header.zig");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const genericChecksum = @import("lib.zig").genericChecksum;

pub const CmpctBlockMessage = struct {
header: BlockHeader,
nonce: u64,
short_ids: []u64,
prefilled_txns: []PrefilledTransaction,

const Self = @This();

pub const PrefilledTransaction = struct {
index: usize,
tx: Transaction,
};

pub fn name() *const [12]u8 {
return protocol.CommandNames.CMPCTBLOCK;
}

pub fn checksum(self: *const Self) [4]u8 {
return genericChecksum(self);
}

pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
allocator.free(self.short_ids);
for (self.prefilled_txns) |*txn| {
txn.tx.deinit();
}
allocator.free(self.prefilled_txns);
}

pub fn serializeToWriter(self: *const Self, w: anytype) !void {
comptime {
if (!@hasDecl(@TypeOf(w), "writeInt")) {
@compileError("Writer must have a writeInt method");
}
}

try self.header.serializeToWriter(w);
try w.writeInt(u64, self.nonce, .little);

const short_ids_count = CompactSizeUint.new(self.short_ids.len);
try short_ids_count.encodeToWriter(w);
for (self.short_ids) |id| {
try w.writeInt(u64, id, .little);
}

const prefilled_txns_count = CompactSizeUint.new(self.prefilled_txns.len);
try prefilled_txns_count.encodeToWriter(w);

for (self.prefilled_txns) |txn| {
try CompactSizeUint.new(txn.index).encodeToWriter(w);
try txn.tx.serializeToWriter(w);
}
}

pub fn serializeToSlice(self: *const Self, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
try self.serializeToWriter(fbs.writer());
}

pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();
if (serialized_len == 0) return &.{};
const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self {
comptime {
if (!@hasDecl(@TypeOf(r), "readInt")) {
@compileError("Reader must have a readInt method");
}
}

const header = try BlockHeader.deserializeReader(r);
const nonce = try r.readInt(u64, .little);

const short_ids_count = try CompactSizeUint.decodeReader(r);
const short_ids = try allocator.alloc(u64, short_ids_count.value());
errdefer allocator.free(short_ids);

for (short_ids) |*id| {
id.* = try r.readInt(u64, .little);
}

const prefilled_txns_count = try CompactSizeUint.decodeReader(r);
const prefilled_txns = try allocator.alloc(PrefilledTransaction, prefilled_txns_count.value());
errdefer allocator.free(prefilled_txns);

for (prefilled_txns) |*txn| {
const index = try CompactSizeUint.decodeReader(r);
const tx = try Transaction.deserializeReader(allocator, r);

txn.* = PrefilledTransaction{
.index = index.value(),
.tx = tx,
};
}

return Self{
.header = header,
.nonce = nonce,
.short_ids = short_ids,
.prefilled_txns = prefilled_txns,
};
}

pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
return try Self.deserializeReader(allocator, fbs.reader());
}

pub fn hintSerializedLen(self: *const Self) usize {
var len: usize = 80 + 8; // BlockHeader + nonce
len += CompactSizeUint.new(self.short_ids.len).hint_encoded_len();
len += self.short_ids.len * 8;
len += CompactSizeUint.new(self.prefilled_txns.len).hint_encoded_len();
for (self.prefilled_txns) |txn| {
len += CompactSizeUint.new(txn.index).hint_encoded_len();
len += txn.tx.hintEncodedLen();
}
return len;
}

pub fn eql(self: *const Self, other: *const Self) bool {
if (self.header.version != other.header.version or
!std.mem.eql(u8, &self.header.prev_block, &other.header.prev_block) or
!std.mem.eql(u8, &self.header.merkle_root, &other.header.merkle_root) or
self.header.timestamp != other.header.timestamp or
self.header.nbits != other.header.nbits or
self.header.nonce != other.header.nonce or
self.nonce != other.nonce) return false;

if (self.short_ids.len != other.short_ids.len) return false;
for (self.short_ids, other.short_ids) |a, b| {
if (a != b) return false;
}
if (self.prefilled_txns.len != other.prefilled_txns.len) return false;
for (self.prefilled_txns, other.prefilled_txns) |a, b| {
if (a.index != b.index or !a.tx.eql(b.tx)) return false;
}
return true;
}
};

test "CmpctBlockMessage serialization and deserialization" {
const testing = std.testing;
const Hash = @import("../../../types/hash.zig");
const Script = @import("../../../types/script.zig");
const OutPoint = @import("../../../types/outpoint.zig");
const OpCode = @import("../../../script/opcodes/constant.zig").Opcode;

const test_allocator = testing.allocator;

// Create a sample BlockHeader
const header = BlockHeader{
.version = 1,
.prev_block = [_]u8{0} ** 32, // Zero-filled array of 32 bytes
.merkle_root = [_]u8{0} ** 32, // Zero-filled array of 32 bytes
.timestamp = 1631234567,
.nbits = 0x1d00ffff,
.nonce = 12345,
};

// Create sample short_ids
const short_ids = try test_allocator.alloc(u64, 2);
defer test_allocator.free(short_ids);
short_ids[0] = 123456789;
short_ids[1] = 987654321;

// Create a sample Transaction
var tx = try Transaction.init(test_allocator);
defer tx.deinit();
try tx.addInput(OutPoint{ .hash = Hash.newZeroed(), .index = 0 });
{
var script_pubkey = try Script.init(test_allocator);
defer script_pubkey.deinit();
try script_pubkey.push(&[_]u8{ OpCode.OP_DUP.toBytes(), OpCode.OP_HASH160.toBytes(), OpCode.OP_EQUALVERIFY.toBytes(), OpCode.OP_CHECKSIG.toBytes() });
try tx.addOutput(50000, script_pubkey);
}

// Create sample prefilled_txns
const prefilled_txns = try test_allocator.alloc(CmpctBlockMessage.PrefilledTransaction, 1);
defer test_allocator.free(prefilled_txns);
prefilled_txns[0] = .{
.index = 0,
.tx = tx,
};

// Create CmpctBlockMessage
const msg = CmpctBlockMessage{
.header = header,
.nonce = 9876543210,
.short_ids = short_ids,
.prefilled_txns = prefilled_txns,
};

// Test serialization
const serialized = try msg.serialize(test_allocator);
defer test_allocator.free(serialized);

// Test deserialization
var deserialized = try CmpctBlockMessage.deserializeSlice(test_allocator, serialized);
defer deserialized.deinit(test_allocator);

// Verify deserialized data
try std.testing.expect(msg.eql(&deserialized));

// Test hintSerializedLen
const hint_len = msg.hintSerializedLen();
try testing.expect(hint_len > 0);
try testing.expect(hint_len == serialized.len);
}
7 changes: 7 additions & 0 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub const InventoryVector = struct {
};
}
};
pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage;

pub const MessageTypes = enum {
version,
Expand All @@ -67,6 +68,7 @@ pub const MessageTypes = enum {
sendheaders,
filterload,
headers,
cmpctblock,
};

pub const Message = union(MessageTypes) {
Expand All @@ -87,6 +89,7 @@ pub const Message = union(MessageTypes) {
sendheaders: SendHeadersMessage,
filterload: FilterLoadMessage,
headers: HeadersMessage,
cmpctblock: CmpctBlockMessage,

pub fn name(self: Message) *const [12]u8 {
return switch (self) {
Expand All @@ -107,6 +110,7 @@ pub const Message = union(MessageTypes) {
.sendheaders => |m| @TypeOf(m).name(),
.filterload => |m| @TypeOf(m).name(),
.headers => |m| @TypeOf(m).name(),
.cmpctblock => |m| @TypeOf(m).name(),
};
}

Expand All @@ -126,6 +130,7 @@ pub const Message = union(MessageTypes) {
.block => |*m| m.deinit(allocator),
.filteradd => |*m| m.deinit(allocator),
.notfound => {},
.cmpctblock => |*m| m.deinit(allocator),
.sendheaders => {},
.filterload => {},
.headers => |*m| m.deinit(allocator),
Expand All @@ -151,6 +156,7 @@ pub const Message = union(MessageTypes) {
.sendheaders => |*m| m.checksum(),
.filterload => |*m| m.checksum(),
.headers => |*m| m.checksum(),
.cmpctblock => |*m| m.checksum(),
};
}

Expand All @@ -173,6 +179,7 @@ pub const Message = union(MessageTypes) {
.sendheaders => |m| m.hintSerializedLen(),
.filterload => |*m| m.hintSerializedLen(),
.headers => |*m| m.hintSerializedLen(),
.cmpctblock => |*m| m.hintSerializedLen(),
};
}
};
Expand Down
72 changes: 72 additions & 0 deletions src/network/wire/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ pub fn receiveMessage(
protocol.messages.Message{ .sendheaders = try protocol.messages.SendHeadersMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.FilterLoadMessage.name()))
protocol.messages.Message{ .filterload = try protocol.messages.FilterLoadMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.CmpctBlockMessage.name()))
protocol.messages.Message{ .cmpctblock = try protocol.messages.CmpctBlockMessage.deserializeReader(allocator, r) }
else {
try r.skipBytes(payload_len, .{}); // Purge the wire
return error.UnknownMessage;
Expand Down Expand Up @@ -579,3 +581,73 @@ test "ok_send_sendcmpct_message" {
else => unreachable,
}
}

test "ok_send_cmpctblock_message" {
const Transaction = @import("../../types/transaction.zig");
const OutPoint = @import("../../types/outpoint.zig");
const OpCode = @import("../../script/opcodes/constant.zig").Opcode;
const Hash = @import("../../types/hash.zig");
const Script = @import("../../types/script.zig");
const CmpctBlockMessage = @import("../protocol/messages/cmpctblock.zig").CmpctBlockMessage;

const allocator = std.testing.allocator;

// Create a sample BlockHeader
const header = BlockHeader{
.version = 1,
.prev_block = [_]u8{0} ** 32, // Zero-filled array of 32 bytes
.merkle_root = [_]u8{0} ** 32, // Zero-filled array of 32 bytes
.timestamp = 1631234567,
.nbits = 0x1d00ffff,
.nonce = 12345,
};

// Create sample short_ids
const short_ids = try allocator.alloc(u64, 2);
defer allocator.free(short_ids);
short_ids[0] = 123456789;
short_ids[1] = 987654321;

// Create a sample Transaction
var tx = try Transaction.init(allocator);
defer tx.deinit();
try tx.addInput(OutPoint{ .hash = Hash.newZeroed(), .index = 0 });
{
var script_pubkey = try Script.init(allocator);
defer script_pubkey.deinit();
try script_pubkey.push(&[_]u8{ OpCode.OP_DUP.toBytes(), OpCode.OP_HASH160.toBytes(), OpCode.OP_EQUALVERIFY.toBytes(), OpCode.OP_CHECKSIG.toBytes() });
try tx.addOutput(50000, script_pubkey);
}

// Create sample prefilled_txns
const prefilled_txns = try allocator.alloc(CmpctBlockMessage.PrefilledTransaction, 1);
defer allocator.free(prefilled_txns);
prefilled_txns[0] = .{
.index = 0,
.tx = tx,
};

// Create CmpctBlockMessage
const msg = CmpctBlockMessage{
.header = header,
.nonce = 9876543210,
.short_ids = short_ids,
.prefilled_txns = prefilled_txns,
};

// Test serialization
const serialized = try msg.serialize(allocator);
defer allocator.free(serialized);

// Test deserialization
var deserialized = try CmpctBlockMessage.deserializeSlice(allocator, serialized);
defer deserialized.deinit(allocator);

// Verify deserialized data
try std.testing.expect(msg.eql(&deserialized));

// Test hintSerializedLen
const hint_len = msg.hintSerializedLen();
try std.testing.expect(hint_len > 0);
try std.testing.expect(hint_len == serialized.len);
}

0 comments on commit 002228a

Please sign in to comment.