Skip to content

Commit

Permalink
Allocate persistent connection IDs immediately when generated
Browse files Browse the repository at this point in the history
Guards against future bugs where multiple (e.g. concurrent) calls to
`new_cid` might otherwise lead to an undetected collision.
  • Loading branch information
Ralith committed Jun 3, 2023
1 parent c4367ae commit 26b2c37
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{hash_map, HashMap},
convert::TryFrom,
fmt, iter,
net::{IpAddr, SocketAddr},
Expand Down Expand Up @@ -324,7 +324,8 @@ impl Endpoint {
let remote_id = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
trace!(initial_dcid = %remote_id);

let loc_cid = self.new_cid();
let ch = ConnectionHandle(self.connections.vacant_key());
let loc_cid = self.new_cid(ch);
let params = TransportParameters::new(
&config.transport,
&self.config,
Expand All @@ -336,7 +337,7 @@ impl Endpoint {
.crypto
.start_session(config.version, server_name, &params)?;

let (ch, conn) = self.add_connection(
let conn = self.add_connection(
config.version,
remote_id,
loc_cid,
Expand All @@ -349,6 +350,7 @@ impl Endpoint {
tls,
None,
config.transport,
ch,
);
Ok((ch, conn))
}
Expand All @@ -361,8 +363,7 @@ impl Endpoint {
) -> ConnectionEvent {
let mut ids = vec![];
for _ in 0..num {
let id = self.new_cid();
self.index.insert_cid(id, ch);
let id = self.new_cid(ch);
let meta = &mut self.connections[ch];
meta.cids_issued += 1;
let sequence = meta.cids_issued;
Expand All @@ -376,10 +377,12 @@ impl Endpoint {
ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
}

fn new_cid(&mut self) -> ConnectionId {
/// Generate a connection ID for `ch`
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
loop {
let cid = self.local_cid_generator.generate_cid();
if !self.index.connection_ids.contains_key(&cid) {
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
e.insert(ch);
break cid;
}
assert!(self.local_cid_generator.cid_len() > 0);
Expand Down Expand Up @@ -423,8 +426,7 @@ impl Endpoint {
return None;
}

let loc_cid = self.new_cid();
let server_config = self.server_config.as_ref().unwrap();
let server_config = self.server_config.as_ref().unwrap().clone();

if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full()
{
Expand All @@ -434,7 +436,6 @@ impl Endpoint {
addresses,
crypto,
&src_cid,
&loc_cid,
TransportError::CONNECTION_REFUSED(""),
)));
}
Expand All @@ -451,7 +452,6 @@ impl Endpoint {
addresses,
crypto,
&src_cid,
&loc_cid,
TransportError::PROTOCOL_VIOLATION("invalid destination CID length"),
)));
}
Expand All @@ -461,6 +461,12 @@ impl Endpoint {
// First Initial
let mut random_bytes = vec![0u8; RetryToken::RANDOM_BYTES_LEN];
self.rng.fill_bytes(&mut random_bytes);
// The peer will use this as the DCID of its following Initials. Initial DCIDs are
// looked up separately from Handshake/Data DCIDs, so there is no risk of collision
// with established connections. In the unlikely event that a collision occurs
// between two connections in the initial phase, both will fail fast and may be
// retried by the application layer.
let loc_cid = self.local_cid_generator.generate_cid();

let token = RetryToken {
orig_dst_cid: dst_cid,
Expand Down Expand Up @@ -508,7 +514,6 @@ impl Endpoint {
addresses,
crypto,
&src_cid,
&loc_cid,
TransportError::INVALID_TOKEN(""),
)));
}
Expand All @@ -517,7 +522,8 @@ impl Endpoint {
(None, dst_cid)
};

let server_config = server_config.clone();
let ch = ConnectionHandle(self.connections.vacant_key());
let loc_cid = self.new_cid(ch);
let mut params = TransportParameters::new(
&server_config.transport,
&self.config,
Expand All @@ -531,7 +537,7 @@ impl Endpoint {

let tls = server_config.crypto.clone().start_session(version, &params);
let transport_config = server_config.transport.clone();
let (ch, mut conn) = self.add_connection(
let mut conn = self.add_connection(
version,
dst_cid,
loc_cid,
Expand All @@ -541,6 +547,7 @@ impl Endpoint {
tls,
Some(server_config),
transport_config,
ch,
);
if dst_cid.len() != 0 {
self.index.insert_initial(dst_cid, ch);
Expand All @@ -554,9 +561,9 @@ impl Endpoint {
debug!("handshake failed: {}", e);
self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
if let ConnectionError::TransportError(e) = e {
Some(DatagramEvent::Response(self.initial_close(
version, addresses, crypto, &src_cid, &loc_cid, e,
)))
Some(DatagramEvent::Response(
self.initial_close(version, addresses, crypto, &src_cid, e),
))
} else {
None
}
Expand All @@ -575,7 +582,8 @@ impl Endpoint {
tls: Box<dyn crypto::Session>,
server_config: Option<Arc<ServerConfig>>,
transport_config: Arc<TransportConfig>,
) -> (ConnectionHandle, Connection) {
ch: ConnectionHandle,
) -> Connection {
let conn = Connection::new(
self.config.clone(),
server_config,
Expand All @@ -599,11 +607,11 @@ impl Endpoint {
addresses,
reset_token: None,
});
debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");

let ch = ConnectionHandle(id);
self.index.insert_conn(addresses, loc_cid, ch);

(ch, conn)
conn
}

fn initial_close(
Expand All @@ -612,13 +620,16 @@ impl Endpoint {
addresses: FourTuple,
crypto: &Keys,
remote_id: &ConnectionId,
local_id: &ConnectionId,
reason: TransportError,
) -> Transmit {
// We don't need to worry about CID collisions in initial closes because the peer
// shouldn't respond, and if it does, and the CID collides, we'll just drop the
// unexpected response.
let local_id = self.local_cid_generator.generate_cid();
let number = PacketNumber::U8(0);
let header = Header::Initial {
dst_cid: *remote_id,
src_cid: *local_id,
src_cid: local_id,
number,
token: Bytes::new(),
version,
Expand Down Expand Up @@ -736,11 +747,6 @@ impl ConnectionIndex {
}
}

/// Add a new CID to an existing connection
fn insert_cid(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
self.connection_ids.insert(dst_cid, connection);
}

/// Discard a connection ID
fn retire(&mut self, dst_cid: &ConnectionId) {
self.connection_ids.remove(dst_cid);
Expand Down

0 comments on commit 26b2c37

Please sign in to comment.