Skip to content

Commit

Permalink
Packet size guard (#1055)
Browse files Browse the repository at this point in the history
  • Loading branch information
al8n authored Apr 23, 2024
1 parent c8c2a15 commit 02f873d
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 18 deletions.
73 changes: 72 additions & 1 deletion crates/core/src/transport/connection_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ mod test {
use tracing::info;

use super::*;
use crate::DynError;
use crate::{transport::packet_data::MAX_DATA_SIZE, DynError};

#[allow(clippy::type_complexity)]
static CHANNELS: OnceLock<
Expand Down Expand Up @@ -1333,6 +1333,77 @@ mod test {
.await
}

#[tokio::test]
async fn simulate_send_max_short_message() -> Result<(), DynError> {
let (peer_a_pub, mut peer_a, peer_a_addr) = set_peer_connection(Default::default()).await?;
let (peer_b_pub, mut peer_b, peer_b_addr) = set_peer_connection(Default::default()).await?;

let peer_b = tokio::spawn(async move {
let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr).await;
let mut conn = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??;
let data = vec![0u8; 1432];
let data = tokio::task::spawn_blocking(move || bincode::serialize(&data).unwrap())
.await
.unwrap();
conn.outbound_short_message(data).await?;
Ok::<_, DynError>(())
});

let peer_a = tokio::spawn(async move {
let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr).await;
let mut conn = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??;
let msg = conn.recv().await?;
assert!(msg.len() <= MAX_DATA_SIZE);
Ok::<_, DynError>(())
});

let (a, b) = tokio::try_join!(peer_a, peer_b)?;
a?;
b?;
Ok(())
}

#[test]
#[should_panic]
fn simulate_send_max_short_message_plus_1() {
tokio::runtime::Runtime::new()
.unwrap()
.block_on(async move {
let (peer_a_pub, mut peer_a, peer_a_addr) =
set_peer_connection(Default::default()).await?;
let (peer_b_pub, mut peer_b, peer_b_addr) =
set_peer_connection(Default::default()).await?;

let peer_b = tokio::spawn(async move {
let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr).await;
let mut conn =
tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??;
let data = vec![0u8; 1433];
let data =
tokio::task::spawn_blocking(move || bincode::serialize(&data).unwrap())
.await
.unwrap();
conn.outbound_short_message(data).await?;
Ok::<_, DynError>(())
});

let peer_a = tokio::spawn(async move {
let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr).await;
let mut conn =
tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??;
let msg = conn.recv().await?;
assert!(msg.len() <= MAX_DATA_SIZE);
Ok::<_, DynError>(())
});

let (a, b) = tokio::try_join!(peer_a, peer_b)?;
a?;
b?;
Result::<(), DynError>::Ok(())
})
.unwrap();
}

#[tokio::test]
async fn simulate_send_streamed_message() -> Result<(), DynError> {
// crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE));
Expand Down
95 changes: 80 additions & 15 deletions crates/core/src/transport/peer_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ impl PeerConnection {
}

#[inline]
async fn outbound_short_message(&mut self, data: SerializedMessage) -> Result<()> {
pub(crate) async fn outbound_short_message(&mut self, data: SerializedMessage) -> Result<()> {
let receipts = self.received_tracker.get_receipts();
let packet_id = self
.remote_conn
Expand Down Expand Up @@ -365,24 +365,89 @@ async fn packet_sending(
payload: impl Into<SymmetricMessagePayload>,
sent_tracker: &parking_lot::Mutex<SentPacketTracker<InstantTimeSrc>>,
) -> Result<()> {
// FIXME: here ensure that `confirm_receipt` won't make the packet exceed the max data size
// if it does, split it to send multiple noop packets with the receipts

// tracing::trace!(packet_id, "sending packet");
let packet = SymmetricMessage::serialize_msg_to_packet_data(
match SymmetricMessage::try_serialize_msg_to_packet_data(
packet_id,
payload,
outbound_sym_key,
confirm_receipt,
)?;
outbound_packets
.send((remote_addr, packet.clone().prepared_send()))
.await
.map_err(|_| TransportError::ConnectionClosed)?;
sent_tracker
.lock()
.report_sent_packet(packet_id, packet.prepared_send());
Ok(())
)? {
either::Either::Left(packet) => {
outbound_packets
.send((remote_addr, packet.clone().prepared_send()))
.await
.map_err(|_| TransportError::ConnectionClosed)?;
sent_tracker
.lock()
.report_sent_packet(packet_id, packet.prepared_send());
Ok(())
}
either::Either::Right((payload, mut confirm_receipt)) => {
macro_rules! send {
($packets:ident) => {{
for packet in $packets {
outbound_packets
.send((remote_addr, packet.clone().prepared_send()))
.await
.map_err(|_| TransportError::ConnectionClosed)?;
sent_tracker
.lock()
.report_sent_packet(packet_id, packet.prepared_send());
}
}};
}

let max_num = SymmetricMessage::max_num_of_confirm_receipts_of_noop_message();
let packet = SymmetricMessage::serialize_msg_to_packet_data(
packet_id,
payload,
outbound_sym_key,
vec![],
)?;

if max_num > confirm_receipt.len() {
let packets = [
packet,
SymmetricMessage::serialize_msg_to_packet_data(
packet_id,
SymmetricMessagePayload::NoOp,
outbound_sym_key,
confirm_receipt,
)?,
];

send!(packets);
return Ok(());
}

let mut packets = Vec::with_capacity(8);
packets.push(packet);

while !confirm_receipt.is_empty() {
let len = confirm_receipt.len();

if len <= max_num {
packets.push(SymmetricMessage::serialize_msg_to_packet_data(
packet_id,
SymmetricMessagePayload::NoOp,
outbound_sym_key,
confirm_receipt,
)?);
break;
}

let receipts = confirm_receipt.split_off(max_num);
packets.push(SymmetricMessage::serialize_msg_to_packet_data(
packet_id,
SymmetricMessagePayload::NoOp,
outbound_sym_key,
receipts,
)?);
}

send!(packets);
Ok(())
}
}
}

#[cfg(test)]
Expand Down
72 changes: 70 additions & 2 deletions crates/core/src/transport/symmetric_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{borrow::Cow, net::SocketAddr, sync::OnceLock};

use crate::transport::packet_data::SymmetricAES;
use aes_gcm::Aes128Gcm;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

Expand Down Expand Up @@ -37,6 +38,24 @@ impl SymmetricMessage {
},
};

pub(crate) fn max_num_of_confirm_receipts_of_noop_message() -> usize {
static MAX_NUM_CONFIRM_RECEIPTS: Lazy<usize> = Lazy::new(|| {
// try to find the maximum number of confirm_receipts that can be serialized within the MAX_DATA_SIZE
let blank = SymmetricMessage {
packet_id: u32::MAX,
confirm_receipt: vec![],
payload: SymmetricMessagePayload::NoOp,
};
let overhead = bincode::serialized_size(&blank).unwrap();

let max_elems = (MAX_DATA_SIZE as u64 - overhead) / core::mem::size_of::<u32>() as u64;

max_elems as usize
});

*MAX_NUM_CONFIRM_RECEIPTS
}

pub fn ack_error(
outbound_sym_key: &Aes128Gcm,
) -> Result<PacketData<SymmetricAES>, bincode::Error> {
Expand Down Expand Up @@ -76,6 +95,34 @@ impl SymmetricMessage {
Ok(packet.encrypt_symmetric(outbound_sym_key))
}

#[allow(clippy::type_complexity)]
pub fn try_serialize_msg_to_packet_data(
packet_id: PacketId,
payload: impl Into<SymmetricMessagePayload>,
outbound_sym_key: &Aes128Gcm,
confirm_receipt: Vec<u32>,
) -> Result<
either::Either<PacketData<SymmetricAES>, (SymmetricMessagePayload, Vec<u32>)>,
bincode::Error,
> {
let msg = Self {
packet_id,
confirm_receipt,
payload: payload.into(),
};

let size = bincode::serialized_size(&msg)?;
if size <= MAX_DATA_SIZE as u64 {
let mut packet = [0u8; MAX_DATA_SIZE];
bincode::serialize_into(packet.as_mut_slice(), &msg)?;
let bytes = &packet[..size as usize];
let packet = PacketData::from_buf_plain(bytes);
Ok(either::Left(packet.encrypt_symmetric(outbound_sym_key)))
} else {
Ok(either::Right((msg.payload, msg.confirm_receipt)))
}
}

pub fn serialize_msg_to_packet_data(
packet_id: PacketId,
payload: impl Into<SymmetricMessagePayload>,
Expand All @@ -87,10 +134,18 @@ impl SymmetricMessage {
confirm_receipt,
payload: payload.into(),
};

message.to_packet_data(outbound_sym_key)
}

pub(crate) fn to_packet_data(
&self,
outbound_sym_key: &Aes128Gcm,
) -> Result<PacketData<SymmetricAES>, bincode::Error> {
let mut packet = [0u8; MAX_DATA_SIZE];
let size = bincode::serialized_size(&message)?;
let size = bincode::serialized_size(self)?;
debug_assert!(size <= MAX_DATA_SIZE as u64);
bincode::serialize_into(packet.as_mut_slice(), &message)?;
bincode::serialize_into(packet.as_mut_slice(), self)?;
let bytes = &packet[..size as usize];
let packet = PacketData::from_buf_plain(bytes);
Ok(packet.encrypt_symmetric(outbound_sym_key))
Expand Down Expand Up @@ -285,4 +340,17 @@ mod test {
));
Ok(())
}

#[test]
fn max_confirm_receipts_of_noop_message() {
let num = SymmetricMessage::max_num_of_confirm_receipts_of_noop_message();

let msg = SymmetricMessage {
packet_id: u32::MAX,
confirm_receipt: vec![u32::MAX; num],
payload: SymmetricMessagePayload::NoOp,
};
let size = bincode::serialized_size(&msg).unwrap();
assert!(size <= MAX_DATA_SIZE as u64);
}
}

0 comments on commit 02f873d

Please sign in to comment.