diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 0ccf854c91..78a93ffab0 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -2422,13 +2422,15 @@ impl Connection { self.loss_recovery.on_packet_sent(path, sent); } - if *space == PacketNumberSpace::Handshake - && self.role == Role::Server - && self.state == State::Confirmed - { - // We could discard handshake keys in set_state, - // but wait until after sending an ACK. - self.discard_keys(PacketNumberSpace::Handshake, now); + if *space == PacketNumberSpace::Handshake { + if self.role == Role::Client { + // We're sending a Handshake packet, so we can discard Initial keys. + self.discard_keys(PacketNumberSpace::Initial, now); + } else if self.role == Role::Server && self.state == State::Confirmed { + // We could discard handshake keys in set_state, + // but wait until after sending an ACK. + self.discard_keys(PacketNumberSpace::Handshake, now); + } } } @@ -2779,11 +2781,6 @@ impl Connection { self.set_initial_limits(); } if self.crypto.install_keys(self.role)? { - if self.role == Role::Client { - // We won't acknowledge Initial packets as a result of this, but the - // server can rely on implicit acknowledgment. - self.discard_keys(PacketNumberSpace::Initial, now); - } self.saved_datagrams.make_available(CryptoSpace::Handshake); } } diff --git a/neqo-transport/src/connection/tests/handshake.rs b/neqo-transport/src/connection/tests/handshake.rs index b70b024c79..b9852bfc6f 100644 --- a/neqo-transport/src/connection/tests/handshake.rs +++ b/neqo-transport/src/connection/tests/handshake.rs @@ -591,8 +591,9 @@ fn reorder_1rtt() { now += RTT / 2; let s2 = server.process(c2.as_ref(), now).dgram(); // The server has now received those packets, and saved them. - // The two additional are a Handshake and a 1-RTT (w/ NEW_CONNECTION_ID). - assert_eq!(server.stats().packets_rx, PACKETS * 2 + 4); + // The two additional are an Initial w/ACK, a Handshake w/ACK and a 1-RTT (w/ + // NEW_CONNECTION_ID). + assert_eq!(server.stats().packets_rx, PACKETS * 2 + 5); assert_eq!(server.stats().saved_datagrams, PACKETS); assert_eq!(server.stats().dropped_rx, 1); assert_eq!(*server.state(), State::Confirmed); @@ -802,9 +803,9 @@ fn anti_amplification() { let ack = client.process(Some(&s_init3), now).dgram().unwrap(); assert!(!maybe_authenticate(&mut client)); // No need yet. - // The client sends a padded datagram, with just ACK for Handshake. - assert_eq!(client.stats().frame_tx.ack, ack_count + 1); - assert_eq!(client.stats().frame_tx.all(), frame_count + 1); + // The client sends a padded datagram, with just ACKs for Initial and Handshake. + assert_eq!(client.stats().frame_tx.ack, ack_count + 2); + assert_eq!(client.stats().frame_tx.all(), frame_count + 2); assert_ne!(ack.len(), client.plpmtu()); // Not padded (it includes Handshake). now += DEFAULT_RTT / 2; @@ -1051,25 +1052,24 @@ fn only_server_initial() { let (initial, handshake) = split_datagram(&server_dgram1.unwrap()); assert!(handshake.is_some()); - // The client will not acknowledge the Initial as it discards keys. - // It sends a Handshake probe instead, containing just a PING frame. - assert_eq!(client.stats().frame_tx.ping, 0); + // The client sends an Initial ACK. + assert_eq!(client.stats().frame_tx.ack, 0); let probe = client.process(Some(&initial), now).dgram(); - assertions::assert_handshake(&probe.unwrap()); + assertions::assert_initial(&probe.unwrap(), false); assert_eq!(client.stats().dropped_rx, 0); - assert_eq!(client.stats().frame_tx.ping, 1); + assert_eq!(client.stats().frame_tx.ack, 1); let (initial, handshake) = split_datagram(&server_dgram2.unwrap()); assert!(handshake.is_some()); - // The same happens after a PTO, even though the client will discard the Initial packet. + // The same happens after a PTO. now += AT_LEAST_PTO; - assert_eq!(client.stats().frame_tx.ping, 1); + assert_eq!(client.stats().frame_tx.ack, 1); let discarded = client.stats().dropped_rx; let probe = client.process(Some(&initial), now).dgram(); - assertions::assert_handshake(&probe.unwrap()); - assert_eq!(client.stats().frame_tx.ping, 2); - assert_eq!(client.stats().dropped_rx, discarded + 1); + assertions::assert_initial(&probe.unwrap(), false); + assert_eq!(client.stats().frame_tx.ack, 2); + assert_eq!(client.stats().dropped_rx, discarded); // Pass the Handshake packet and complete the handshake. client.process_input(&handshake.unwrap(), now); @@ -1136,7 +1136,9 @@ fn implicit_rtt_server() { let dgram = server.process(dgram.as_ref(), now).dgram(); now += RTT / 2; let dgram = client.process(dgram.as_ref(), now).dgram(); - assertions::assert_handshake(dgram.as_ref().unwrap()); + let (initial, handshake) = split_datagram(dgram.as_ref().unwrap()); + assertions::assert_initial(&initial, false); + assertions::assert_handshake(handshake.as_ref().unwrap()); now += RTT / 2; server.process_input(&dgram.unwrap(), now); diff --git a/neqo-transport/src/connection/tests/recovery.rs b/neqo-transport/src/connection/tests/recovery.rs index a97df1ca64..bc90fcee82 100644 --- a/neqo-transport/src/connection/tests/recovery.rs +++ b/neqo-transport/src/connection/tests/recovery.rs @@ -224,7 +224,9 @@ fn pto_handshake_complete() { now += HALF_RTT; let pkt = client.process(pkt.as_ref(), now).dgram(); - assert_handshake(pkt.as_ref().unwrap()); + let (initial, handshake) = split_datagram(&pkt.clone().unwrap()); + assert_initial(&initial, false); + assert_handshake(handshake.as_ref().unwrap()); let cb = client.process(None, now).callback(); // The client now has a single RTT estimate (20ms), so