diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 626e61a3..997b0fa4 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -207,6 +207,9 @@ impl Send { // Transition the state to reset no matter what. stream.state.set_reset(stream_id, reason, initiator); + // Notify the recv task if it's waiting, because it'll + // want to hear about the reset. + stream.notify_recv(); // If closed AND the send queue is flushed, then the stream cannot be // reset explicitly, either. Implicit resets can still be queued. diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 9cc2f91e..50be06e6 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1,10 +1,10 @@ -use futures::future::{join, ready, select, Either}; +use futures::future::{join, join_all, ready, select, Either}; use futures::stream::FuturesUnordered; use futures::StreamExt; use h2_support::prelude::*; -use std::io; use std::pin::Pin; use std::task::Context; +use std::{io, panic}; #[tokio::test] async fn handshake() { @@ -849,7 +849,7 @@ async fn recv_too_big_headers() { }; let client = async move { - let (mut client, mut conn) = client::Builder::new() + let (mut client, conn) = client::Builder::new() .max_header_list_size(10) .handshake::<_, Bytes>(io) .await @@ -861,10 +861,12 @@ async fn recv_too_big_headers() { .unwrap(); let req1 = client.send_request(request, true); - let req1 = async move { + // Spawn tasks to ensure that the error wakes up tasks that are blocked + // waiting for a response. + let req1 = tokio::spawn(async move { let err = req1.expect("send_request").0.await.expect_err("response1"); assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); - }; + }); let request = Request::builder() .uri("https://http2.akamai.com/") @@ -872,14 +874,21 @@ async fn recv_too_big_headers() { .unwrap(); let req2 = client.send_request(request, true); - let req2 = async move { + let req2 = tokio::spawn(async move { let err = req2.expect("send_request").0.await.expect_err("response2"); assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); - }; + }); - conn.drive(join(req1, req2)).await; - conn.await.expect("client"); + let conn = tokio::spawn(async move { + conn.await.expect("client"); + }); + for err in join_all([req1, req2, conn]).await { + if let Some(err) = err.err().and_then(|err| err.try_into_panic().ok()) { + std::panic::resume_unwind(err); + } + } }; + join(srv, client).await; }