diff --git a/neqo-transport/src/cc/classic_cc.rs b/neqo-transport/src/cc/classic_cc.rs index 81d2abf270..6914e91f67 100644 --- a/neqo-transport/src/cc/classic_cc.rs +++ b/neqo-transport/src/cc/classic_cc.rs @@ -298,6 +298,14 @@ impl CongestionControl for ClassicCongestionControl { congestion || persistent_congestion } + /// Report received ECN CE mark(s) to the congestion controller as a + /// congestion event. + /// + /// See . + fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool { + self.on_congestion_event(largest_acked_pkt) + } + fn discard(&mut self, pkt: &SentPacket) { if pkt.cc_outstanding() { assert!(self.bytes_in_flight >= pkt.size); @@ -488,8 +496,8 @@ impl ClassicCongestionControl { /// Handle a congestion event. /// Returns true if this was a true congestion event. fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool { - // Start a new congestion event if lost packet was sent after the start - // of the previous congestion recovery period. + // Start a new congestion event if lost or ECN CE marked packet was sent + // after the start of the previous congestion recovery period. if !self.after_recovery_start(last_packet) { return false; } @@ -1189,4 +1197,26 @@ mod tests { last_acked_bytes = cc.acked_bytes; } } + + #[test] + fn ecn_ce() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let p_ce = SentPacket::new( + PacketType::Short, + 1, + IpTosEcn::default(), + now(), + true, + Vec::new(), + MAX_DATAGRAM_SIZE, + ); + cc.on_packet_sent(&p_ce); + cwnd_is_default(&cc); + assert_eq!(cc.state, State::SlowStart); + + // Signal congestion (ECN CE) and thus change state to recovery start. + cc.on_ecn_ce_received(&p_ce); + cwnd_is_halved(&cc); + assert_eq!(cc.state, State::RecoveryStart); + } } diff --git a/neqo-transport/src/cc/mod.rs b/neqo-transport/src/cc/mod.rs index 486d15e67e..2adffbc0c4 100644 --- a/neqo-transport/src/cc/mod.rs +++ b/neqo-transport/src/cc/mod.rs @@ -53,6 +53,9 @@ pub trait CongestionControl: Display + Debug { lost_packets: &[SentPacket], ) -> bool; + /// Returns true if the congestion window was reduced. + fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool; + #[must_use] fn recovery_packet(&self) -> bool; diff --git a/neqo-transport/src/connection/tests/cc.rs b/neqo-transport/src/connection/tests/cc.rs index b708bc421d..f21f4e184f 100644 --- a/neqo-transport/src/connection/tests/cc.rs +++ b/neqo-transport/src/connection/tests/cc.rs @@ -6,7 +6,7 @@ use std::{mem, time::Duration}; -use neqo_common::{qdebug, qinfo, Datagram}; +use neqo_common::{qdebug, qinfo, Datagram, IpTosEcn}; use super::{ super::Output, ack_bytes, assert_full_cwnd, connect_rtt_idle, cwnd, cwnd_avail, cwnd_packets, @@ -36,9 +36,13 @@ fn cc_slow_start() { assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT); } -#[test] -/// Verify that CC moves to cong avoidance when a packet is marked lost. -fn cc_slow_start_to_cong_avoidance_recovery_period() { +#[derive(PartialEq, Eq, Clone, Copy)] +enum CongestionSignal { + PacketLoss, + EcnCe, +} + +fn cc_slow_start_to_cong_avoidance_recovery_period(congestion_signal: CongestionSignal) { let mut client = default_client(); let mut server = default_server(); let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); @@ -78,9 +82,17 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() { assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2); let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap(); - // Server: Receive and generate ack again, but drop first packet + // Server: Receive and generate ack again, but this time add congestion + // signal first. now += DEFAULT_RTT / 2; - c_tx_dgrams.remove(0); + match congestion_signal { + CongestionSignal::PacketLoss => { + c_tx_dgrams.remove(0); + } + CongestionSignal::EcnCe => { + c_tx_dgrams.last_mut().unwrap().set_tos(IpTosEcn::Ce.into()); + } + } let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now); assert_eq!( server.stats().frame_tx.largest_acknowledged, @@ -97,6 +109,18 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() { assert!(cwnd(&client) < cwnd_before_cong); } +#[test] +/// Verify that CC moves to cong avoidance when a packet is marked lost. +fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_packet_loss() { + cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::PacketLoss); +} + +/// Verify that CC moves to cong avoidance when ACK is marked with ECN CE. +#[test] +fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_ecn_ce() { + cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::EcnCe); +} + #[test] /// Verify that CC stays in recovery period when packet sent before start of /// recovery period is acked. diff --git a/neqo-transport/src/ecn.rs b/neqo-transport/src/ecn.rs index 31334cb52d..20eb4da003 100644 --- a/neqo-transport/src/ecn.rs +++ b/neqo-transport/src/ecn.rs @@ -122,15 +122,36 @@ impl EcnInfo { } } + /// Process ECN counts from an ACK frame. + /// + /// Returns whether ECN counts contain new valid ECN CE marks. + pub fn on_packets_acked( + &mut self, + acked_packets: &[SentPacket], + ack_ecn: Option, + ) -> bool { + let prev_baseline = self.baseline; + + self.validate_ack_ecn_and_update(acked_packets, ack_ecn); + + matches!(self.state, EcnValidationState::Capable) + && (self.baseline - prev_baseline)[IpTosEcn::Ce] > 0 + } + /// After the ECN validation test has ended, check if the path is ECN capable. - pub fn validate_ack_ecn(&mut self, acked_packets: &[SentPacket], ack_ecn: Option) { + pub fn validate_ack_ecn_and_update( + &mut self, + acked_packets: &[SentPacket], + ack_ecn: Option, + ) { // RFC 9000, Appendix A.4: // // > From the "unknown" state, successful validation of the ECN counts in an ACK frame // > (see Section 13.4.2.1) causes the ECN state for the path to become "capable", unless // > no marked packet has been acknowledged. - if self.state != EcnValidationState::Unknown { - return; + match self.state { + EcnValidationState::Testing { .. } | EcnValidationState::Failed => return, + EcnValidationState::Unknown | EcnValidationState::Capable => {} } // RFC 9000, Section 13.4.2.1: diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index dc8834fdf4..0e4c82b1ca 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -977,8 +977,18 @@ impl Path { now: Instant, ) { debug_assert!(self.is_primary()); + + let ecn_ce_received = self.ecn_info.on_packets_acked(acked_pkts, ack_ecn); + if ecn_ce_received { + let cwnd_reduced = self + .sender + .on_ecn_ce_received(acked_pkts.first().expect("must be there")); + if cwnd_reduced { + self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu()); + } + } + self.sender.on_packets_acked(acked_pkts, &self.rtt, now); - self.ecn_info.validate_ack_ecn(acked_pkts, ack_ecn); } /// Record packets as lost with the sender. diff --git a/neqo-transport/src/recovery.rs b/neqo-transport/src/recovery.rs index 0a40f21ecd..22a635d9f3 100644 --- a/neqo-transport/src/recovery.rs +++ b/neqo-transport/src/recovery.rs @@ -695,10 +695,10 @@ impl LossRecovery { let (acked_packets, any_ack_eliciting) = space.remove_acked(acked_ranges, &mut self.stats.borrow_mut()); - if acked_packets.is_empty() { + let Some(largest_acked_pkt) = acked_packets.first() else { // No new information. return (Vec::new(), Vec::new()); - } + }; // Track largest PN acked per space let prev_largest_acked = space.largest_acked_sent_time; @@ -707,7 +707,6 @@ impl LossRecovery { // If the largest acknowledged is newly acked and any newly acked // packet was ack-eliciting, update the RTT. (-recovery 5.1) - let largest_acked_pkt = acked_packets.first().expect("must be there"); space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent); if any_ack_eliciting && largest_acked_pkt.on_primary_path() { self.rtt_sample( diff --git a/neqo-transport/src/sender.rs b/neqo-transport/src/sender.rs index 3a54851533..abb14d0a25 100644 --- a/neqo-transport/src/sender.rs +++ b/neqo-transport/src/sender.rs @@ -97,6 +97,11 @@ impl PacketSender { ) } + /// Called when ECN CE mark received. Returns true if the congestion window was reduced. + pub fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool { + self.cc.on_ecn_ce_received(largest_acked_pkt) + } + pub fn discard(&mut self, pkt: &SentPacket) { self.cc.discard(pkt); }