Skip to content

Commit

Permalink
Make gateways connection attempts parallel (#1050)
Browse files Browse the repository at this point in the history
  • Loading branch information
al8n authored Apr 22, 2024
1 parent c878a30 commit a6bc1c6
Showing 1 changed file with 195 additions and 102 deletions.
297 changes: 195 additions & 102 deletions crates/core/src/transport/connection_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot};
use tokio::task;

use super::packet_data::SymmetricAES;
use super::{
crypto::{TransportKeypair, TransportPublicKey},
packet_data::MAX_PACKET_SIZE,
Expand Down Expand Up @@ -50,6 +51,12 @@ struct OutboundMessage {
recv: mpsc::Receiver<SerializedMessage>,
}

struct GatewayMessage {
remote_addr: SocketAddr,
packet: PacketData<UnknownEncryption>,
resp_tx: oneshot::Sender<bool>,
}

pub(crate) struct ConnectionHandler {
send_queue: mpsc::Sender<(SocketAddr, ConnectionEvent)>,
new_connection_notifier: mpsc::Receiver<PeerConnection>,
Expand Down Expand Up @@ -190,6 +197,20 @@ type OngoingConnectionResult = Option<
>,
>;

type GwOngoingConnectionResult = Option<
Result<
Result<
(
RemoteConnection,
InboundRemoteConnection,
PacketData<SymmetricAES>,
),
(TransportError, SocketAddr),
>,
tokio::task::JoinError,
>,
>;

#[cfg(test)]
impl<T> Drop for UdpPacketsListener<T> {
fn drop(&mut self) {
Expand All @@ -201,7 +222,10 @@ impl<S: Socket> UdpPacketsListener<S> {
async fn listen(mut self) -> Result<(), TransportError> {
let mut buf = [0u8; MAX_PACKET_SIZE];
let mut ongoing_connections: BTreeMap<SocketAddr, OngoingConnection> = BTreeMap::new();
let mut gw_ongoing_connections: BTreeMap<SocketAddr, OngoingConnection> = BTreeMap::new();
let mut connection_tasks = FuturesUnordered::new();
let mut gw_connection_tasks = FuturesUnordered::new();
let (gw_outbound_tx, mut gw_inbound_rx) = tokio::sync::mpsc::channel(100);
loop {
tokio::select! {
// Handling of inbound packets
Expand Down Expand Up @@ -230,10 +254,11 @@ impl<S: Socket> UdpPacketsListener<S> {
continue;
}
let packet_data = PacketData::from_buf(&buf[..size]);
// FIXME: also parallelize this like we do with nat_traversal future
if let Err(error) = self.gateway_connection(packet_data, remote_addr).await {
tracing::error!(%error, ?remote_addr, "Failed to establish connection");
}
let gw_ongoing_connection = self.gateway_connection(packet_data, remote_addr, gw_outbound_tx.clone()).await;
let task = tokio::spawn(gw_ongoing_connection.map_err(move |error| {
(error, remote_addr)
}));
gw_connection_tasks.push(task);
}
Err(e) => {
// TODO: this should panic and be propagate to the main task or retry and eventually fail
Expand All @@ -242,6 +267,49 @@ impl<S: Socket> UdpPacketsListener<S> {
}
}
},
req = gw_inbound_rx.recv() => {
let Some(GatewayMessage { remote_addr, packet, resp_tx }) = req else {
unreachable!();
};

if let Some(remote) = self.remote_connections.remove(&remote_addr) {
let _ = remote.inbound_packet_sender.send(packet).await;
self.remote_connections.insert(remote_addr, remote);
let _ = resp_tx.send(true);
}
}
gw_connection_handshake = gw_connection_tasks.next(), if !gw_connection_tasks.is_empty() => {
let Some(res): GwOngoingConnectionResult = gw_connection_handshake else {
unreachable!();
};
match res.expect("task shouldn't panic") {
Ok((outbound_remote_conn, inbound_remote_connection, outbound_ack_packet)) => {
let remote_addr = outbound_remote_conn.remote_addr;
let sent_tracker = outbound_remote_conn.sent_tracker.clone();

self.remote_connections.insert(remote_addr, inbound_remote_connection);

if let Err(e) = self.new_connection_notifier
.send(PeerConnection::new(outbound_remote_conn))
.await
.map_err(|_| TransportError::ChannelClosed) {
tracing::error!(%remote_addr, %e, "gateway connection established but failed to notify new connection");
continue;
}

sent_tracker.lock().report_sent_packet(
SymmetricMessage::FIRST_PACKET_ID,
outbound_ack_packet.prepared_send(),
);
}
Err((error, remote_addr)) => {
tracing::error!(%error, ?remote_addr, "Failed to establish gateway connection");
if let Some((_, result_sender)) = gw_ongoing_connections.remove(&remote_addr) {
let _ = result_sender.send(Err(error));
}
}
}
}
connection_handshake = connection_tasks.next(), if !connection_tasks.is_empty() => {
let Some(res): OngoingConnectionResult = connection_handshake else {
unreachable!();
Expand Down Expand Up @@ -286,117 +354,142 @@ impl<S: Socket> UdpPacketsListener<S> {
&mut self,
remote_intro_packet: PacketData<UnknownEncryption>,
remote_addr: SocketAddr,
) -> Result<(), TransportError> {
outbound_tx: mpsc::Sender<GatewayMessage>,
) -> impl Future<
Output = Result<
(
RemoteConnection,
InboundRemoteConnection,
PacketData<SymmetricAES>,
),
TransportError,
>,
> + Send
+ 'static {
tracing::debug!(%remote_addr, "new connection to gateway");
let Ok(decrypted_intro_packet) = self
.this_peer_keypair
.secret
.decrypt(remote_intro_packet.data())
else {
tracing::debug!(%remote_addr, "failed to decrypt packet with private key");
return Ok(());
};
let protoc = &decrypted_intro_packet[..PROTOC_VERSION.len()];
let outbound_key_bytes =
&decrypted_intro_packet[PROTOC_VERSION.len()..PROTOC_VERSION.len() + 16];
let outbound_key = Aes128Gcm::new_from_slice(outbound_key_bytes).map_err(|_| {
TransportError::ConnectionEstablishmentFailure {
cause: "invalid symmetric key".into(),
let secret = self.this_peer_keypair.secret.clone();
let outbound_packets = self.outbound_packets.clone();
let socket_listener = self.socket_listener.clone();

async move {
let decrypted_intro_packet = secret.decrypt(remote_intro_packet.data())?;
let protoc = &decrypted_intro_packet[..PROTOC_VERSION.len()];
let outbound_key_bytes =
&decrypted_intro_packet[PROTOC_VERSION.len()..PROTOC_VERSION.len() + 16];
let outbound_key = Aes128Gcm::new_from_slice(outbound_key_bytes).map_err(|_| {
TransportError::ConnectionEstablishmentFailure {
cause: "invalid symmetric key".into(),
}
})?;
if protoc != PROTOC_VERSION {
let packet = SymmetricMessage::ack_error(&outbound_key)?;
outbound_packets
.send((remote_addr, packet.prepared_send()))
.await
.map_err(|_| TransportError::ChannelClosed)?;
return Err(TransportError::ConnectionEstablishmentFailure {
cause: format!(
"remote is using a different protocol version: {:?}",
String::from_utf8_lossy(protoc)
)
.into(),
});
}
})?;
if protoc != PROTOC_VERSION {
let packet = SymmetricMessage::ack_error(&outbound_key)?;
self.outbound_packets
.send((remote_addr, packet.prepared_send()))
.await
.map_err(|_| TransportError::ChannelClosed)?;
return Err(TransportError::ConnectionEstablishmentFailure {
cause: format!(
"remote is using a different protocol version: {:?}",
String::from_utf8_lossy(protoc)
)
.into(),
});
}

let inbound_key_bytes = rand::random::<[u8; 16]>();
let inbound_key = Aes128Gcm::new(&inbound_key_bytes.into());
let outbound_ack_packet =
SymmetricMessage::ack_ok(&outbound_key, inbound_key_bytes, remote_addr)?;
let inbound_key_bytes = rand::random::<[u8; 16]>();
let inbound_key = Aes128Gcm::new(&inbound_key_bytes.into());
let outbound_ack_packet =
SymmetricMessage::ack_ok(&outbound_key, inbound_key_bytes, remote_addr)?;

let mut buf = [0u8; MAX_PACKET_SIZE];
let mut waiting_time = INITIAL_INTERVAL;
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 20;

while attempts < MAX_ATTEMPTS {
outbound_packets
.send((remote_addr, outbound_ack_packet.clone().prepared_send()))
.await
.map_err(|_| TransportError::ChannelClosed)?;

// wait until the remote sends the ack packet
let timeout =
tokio::time::timeout(waiting_time, socket_listener.recv_from(&mut buf));
match timeout.await {
Ok(Ok((size, remote))) => {
let packet: PacketData<UnknownEncryption> =
PacketData::from_buf(&buf[..size]);

let mut should_continue = false;

if remote != remote_addr {
let (tx, rx) = tokio::sync::oneshot::channel();
outbound_tx
.send(GatewayMessage {
remote_addr,
packet: packet.clone(),
resp_tx: tx,
})
.await
.map_err(|_| TransportError::ChannelClosed)?;

should_continue =
rx.await.map_err(|_| TransportError::ChannelClosed)?;
}

let mut buf = [0u8; MAX_PACKET_SIZE];
let mut waiting_time = INITIAL_INTERVAL;
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 20;
while attempts < MAX_ATTEMPTS {
self.outbound_packets
.send((remote_addr, outbound_ack_packet.clone().prepared_send()))
.await
.map_err(|_| TransportError::ChannelClosed)?;

// wait until the remote sends the ack packet
let timeout =
tokio::time::timeout(waiting_time, self.socket_listener.recv_from(&mut buf));
match timeout.await {
Ok(Ok((size, remote))) => {
let packet = PacketData::from_buf(&buf[..size]);
if remote != remote_addr {
if let Some(remote) = self.remote_connections.remove(&remote_addr) {
let _ = remote.inbound_packet_sender.send(packet).await;
self.remote_connections.insert(remote_addr, remote);
if should_continue {
continue;
}

let _ = packet.try_decrypt_sym(&inbound_key).map_err(|_| {
tracing::debug!(%remote_addr, "Failed to decrypt packet with inbound key");
TransportError::ConnectionEstablishmentFailure {
cause: "invalid symmetric key".into(),
}
})?;
}
Ok(Err(_)) => {
return Err(TransportError::ChannelClosed);
}
Err(_) => {
attempts += 1;
waiting_time = std::cmp::min(
Duration::from_millis(
waiting_time.as_millis() as u64 * INTERVAL_INCREASE_FACTOR,
),
MAX_INTERVAL,
);
continue;
}
let _ = packet.try_decrypt_sym(&inbound_key).map_err(|_| {
tracing::debug!(%remote_addr, "Failed to decrypt packet with inbound key");
TransportError::ConnectionEstablishmentFailure {
cause: "invalid symmetric key".into(),
}
})?;
}
Ok(Err(_)) => {
return Err(TransportError::ChannelClosed);
}
Err(_) => {
attempts += 1;
waiting_time = std::cmp::min(
Duration::from_millis(
waiting_time.as_millis() as u64 * INTERVAL_INCREASE_FACTOR,
),
MAX_INTERVAL,
);
continue;
}
// we know the inbound is successfully connected now and can proceed
// ignoring this will force them to resend the packet but that is fine and simpler
break;
}
// we know the inbound is successfully connected now and can proceed
// ignoring this will force them to resend the packet but that is fine and simpler
break;
}

let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new()));
let peer_connection = PeerConnection::new(RemoteConnection {
outbound_packets: self.outbound_packets.clone(),
outbound_symmetric_key: outbound_key,
remote_addr,
sent_tracker: sent_tracker.clone(),
last_packet_id: Arc::new(AtomicU32::new(0)),
inbound_packet_recv: mpsc::channel(100).1,
inbound_symmetric_key: inbound_key,
inbound_symmetric_key_bytes: inbound_key_bytes,
my_address: None,
});
let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new()));

self.new_connection_notifier
.send(peer_connection)
.await
.map_err(|_| TransportError::ChannelClosed)?;
let (inbound_packet_tx, inbound_packet_rx) = mpsc::channel(100);
let remote_conn = RemoteConnection {
outbound_packets,
outbound_symmetric_key: outbound_key,
remote_addr,
sent_tracker: sent_tracker.clone(),
last_packet_id: Arc::new(AtomicU32::new(0)),
inbound_packet_recv: inbound_packet_rx,
inbound_symmetric_key: inbound_key,
inbound_symmetric_key_bytes: inbound_key_bytes,
my_address: None,
};

sent_tracker.lock().report_sent_packet(
SymmetricMessage::FIRST_PACKET_ID,
outbound_ack_packet.prepared_send(),
);
let inbound_conn = InboundRemoteConnection {
inbound_packet_sender: inbound_packet_tx,
inbound_intro_packet: None,
inbound_checked_times: 0,
};

Ok(())
Ok((remote_conn, inbound_conn, outbound_ack_packet))
}
}

// TODO: this value should be set given exponential backoff and max timeout
Expand Down

0 comments on commit a6bc1c6

Please sign in to comment.