Skip to content

Commit

Permalink
Don't mix packets from mutliple inbound conns
Browse files Browse the repository at this point in the history
  • Loading branch information
iduartgomez committed Sep 14, 2024
1 parent d8d6efe commit 65ef357
Showing 1 changed file with 43 additions and 61 deletions.
104 changes: 43 additions & 61 deletions crates/core/src/transport/connection_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,12 @@ impl<S: Socket> UdpPacketsListener<S> {
async fn listen(mut self) -> Result<(), TransportError> {
let mut buf = [0u8; MAX_PACKET_SIZE];
let mut ongoing_connections: BTreeMap<SocketAddr, OngoingConnection> = BTreeMap::new();
let mut ongoing_gw_connections: BTreeMap<
SocketAddr,
mpsc::Sender<PacketData<UnknownEncryption>>,
> = BTreeMap::new();
let mut connection_tasks = FuturesUnordered::new();
let mut gw_connection_tasks = FuturesUnordered::new();
let (gw_outbound_tx, mut gw_inbound_rx) = tokio::sync::mpsc::channel(100);
loop {
tokio::select! {
// Handling of inbound packets
Expand All @@ -267,6 +270,12 @@ impl<S: Socket> UdpPacketsListener<S> {
continue;
}

if let Some(inbound_packet_sender) = ongoing_gw_connections.remove(&remote_addr){
let _ = inbound_packet_sender.send(packet_data).await;
ongoing_gw_connections.insert(remote_addr, inbound_packet_sender);
continue;
}

if let Some((packets_sender, open_connection)) = ongoing_connections.remove(&remote_addr) {
if packets_sender.send(packet_data).await.is_err() {
// it can happen that the connection is established but the channel is closed because the task completed
Expand All @@ -282,11 +291,13 @@ impl<S: Socket> UdpPacketsListener<S> {
continue;
}
let packet_data = PacketData::from_buf(&buf[..size]);
let gw_ongoing_connection = self.gateway_connection(packet_data, remote_addr, gw_outbound_tx.clone())
.instrument(tracing::span!(tracing::Level::DEBUG, "gateway_connection"));
let task = tokio::spawn(gw_ongoing_connection.map_err(move |error| {
let (gw_ongoing_connection, packets_sender) = self.gateway_connection(packet_data, remote_addr);
let task = tokio::spawn(gw_ongoing_connection
.instrument(tracing::span!(tracing::Level::DEBUG, "gateway_connection"))
.map_err(move |error| {
(error, remote_addr)
}));
ongoing_gw_connections.insert(remote_addr, packets_sender);
gw_connection_tasks.push(task);
}
Err(e) => {
Expand All @@ -296,24 +307,14 @@ impl<S: Socket> UdpPacketsListener<S> {
}
}
},
req = gw_inbound_rx.recv() => {
let Some(GatewayMessage { remote_addr, packet, resp_tx }) = req else {
unreachable!();
};

if let Some(remote) = self.remote_connections.remove(&remote_addr) {
let _ = remote.inbound_packet_sender.send(packet).await;
self.remote_connections.insert(remote_addr, remote);
let _ = resp_tx.send(true);
}
}
gw_connection_handshake = gw_connection_tasks.next(), if !gw_connection_tasks.is_empty() => {
let Some(res): GwOngoingConnectionResult = gw_connection_handshake else {
unreachable!();
};
match res.expect("task shouldn't panic") {
Ok((outbound_remote_conn, inbound_remote_connection, outbound_ack_packet)) => {
let remote_addr = outbound_remote_conn.remote_addr;
ongoing_gw_connections.remove(&remote_addr);
let sent_tracker = outbound_remote_conn.sent_tracker.clone();

self.remote_connections.insert(remote_addr, inbound_remote_connection);
Expand All @@ -333,6 +334,7 @@ impl<S: Socket> UdpPacketsListener<S> {
}
Err((error, remote_addr)) => {
tracing::error!(%error, ?remote_addr, "Failed to establish gateway connection");
ongoing_gw_connections.remove(&remote_addr);
ongoing_connections.remove(&remote_addr);
}
}
Expand Down Expand Up @@ -386,23 +388,26 @@ impl<S: Socket> UdpPacketsListener<S> {
&mut self,
remote_intro_packet: PacketData<UnknownEncryption>,
remote_addr: SocketAddr,
outbound_tx: mpsc::Sender<GatewayMessage>,
) -> impl Future<
Output = Result<
(
RemoteConnection,
InboundRemoteConnection,
PacketData<SymmetricAES>,
),
TransportError,
>,
> + Send
+ 'static {
) -> (
impl Future<
Output = Result<
(
RemoteConnection,
InboundRemoteConnection,
PacketData<SymmetricAES>,
),
TransportError,
>,
> + Send
+ 'static,
mpsc::Sender<PacketData<UnknownEncryption>>,
) {
let secret = self.this_peer_keypair.secret.clone();
let outbound_packets = self.outbound_packets.clone();
let socket_listener = self.socket_listener.clone();

async move {
let (inbound_from_remote, mut next_inbound) =
mpsc::channel::<PacketData<UnknownEncryption>>(1);
let f = async move {
let decrypted_intro_packet =
secret.decrypt(remote_intro_packet.data()).map_err(|err| {
tracing::debug!(%remote_addr, %err, "Failed to decrypt intro packet");
Expand Down Expand Up @@ -436,7 +441,6 @@ impl<S: Socket> UdpPacketsListener<S> {
let outbound_ack_packet =
SymmetricMessage::ack_ok(&outbound_key, inbound_key_bytes, remote_addr)?;

let mut buf = [0u8; MAX_PACKET_SIZE];
let mut waiting_time = INITIAL_INTERVAL;
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 30;
Expand All @@ -448,43 +452,21 @@ impl<S: Socket> UdpPacketsListener<S> {
.map_err(|_| TransportError::ChannelClosed)?;

// wait until the remote sends the ack packet
let timeout =
tokio::time::timeout(waiting_time, socket_listener.recv_from(&mut buf));
let timeout = tokio::time::timeout(waiting_time, next_inbound.recv());
match timeout.await {
Ok(Ok((size, remote))) => {
let packet: PacketData<UnknownEncryption> =
PacketData::from_buf(&buf[..size]);

let mut should_continue = false;

if remote != remote_addr {
let (tx, rx) = tokio::sync::oneshot::channel();
outbound_tx
.send(GatewayMessage {
remote_addr,
packet: packet.clone(),
resp_tx: tx,
})
.await
.map_err(|_| TransportError::ChannelClosed)?;

should_continue =
rx.await.map_err(|_| TransportError::ChannelClosed)?;
}

if should_continue {
continue;
}

Ok(Some(packet)) => {
let _ = packet.try_decrypt_sym(&inbound_key).map_err(|_| {
tracing::debug!(%remote_addr, "Failed to decrypt packet with inbound key");
TransportError::ConnectionEstablishmentFailure {
cause: "invalid symmetric key".into(),
}
})?;
}
Ok(Err(_)) => {
return Err(TransportError::ChannelClosed);
Ok(None) => {
tracing::debug!(%remote_addr, "connection timed out");
return Err(TransportError::ConnectionEstablishmentFailure {
cause: "connection close".into(),
});
}
Err(_) => {
attempts += 1;
Expand Down Expand Up @@ -525,7 +507,8 @@ impl<S: Socket> UdpPacketsListener<S> {

tracing::debug!("returning connection at gw");
Ok((remote_conn, inbound_conn, outbound_ack_packet))
}
};
(f, inbound_from_remote)
}

// TODO: this value should be set given exponential backoff and max timeout
Expand All @@ -535,7 +518,6 @@ impl<S: Socket> UdpPacketsListener<S> {
#[cfg(test)]
const NAT_TRAVERSAL_MAX_ATTEMPTS: usize = 10;

// #[tracing::instrument(level = "debug", fields(peer = %self.this_peer_keypair.public), skip_all)]
fn traverse_nat(
&mut self,
remote_addr: SocketAddr,
Expand Down

0 comments on commit 65ef357

Please sign in to comment.