From 6d6fb1ab43be75fd8003d0a22e4ef7d166031f83 Mon Sep 17 00:00:00 2001 From: Samuel Williams Date: Mon, 23 Sep 2024 12:26:01 +1200 Subject: [PATCH] Wait for input to finish even when streaming, if necessary. --- lib/async/http/protocol/http1/finishable.rb | 8 ++++++-- lib/async/http/protocol/http1/server.rb | 15 +++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/async/http/protocol/http1/finishable.rb b/lib/async/http/protocol/http1/finishable.rb index 9b2195c..5c972e5 100644 --- a/lib/async/http/protocol/http1/finishable.rb +++ b/lib/async/http/protocol/http1/finishable.rb @@ -40,11 +40,15 @@ def close(error = nil) super end - def wait + def wait(persistent = true) if @reading @closed.wait - else + elsif persistent + # If the connection can be reused, let's gracefully discard the body: self.discard + else + # Else, we don't care about the body, so we can close it immediately: + self.close end end diff --git a/lib/async/http/protocol/http1/server.rb b/lib/async/http/protocol/http1/server.rb index d75337c..e4379dd 100644 --- a/lib/async/http/protocol/http1/server.rb +++ b/lib/async/http/protocol/http1/server.rb @@ -96,8 +96,8 @@ def each(task: Task.current) request = nil response = nil - # We must return here as no further request processing can be done: - return body.call(stream) + # In the case of streaming, `finishable` should wrap a `Remainder` body, which we can safely discard later on. + body.call(stream) elsif response.status == 101 # This code path is to support legacy behavior where the response status is set to 101, but the protocol is not upgraded. This may not be a valid use case, but it is supported for compatibility. We expect the response headers to contain the `upgrade` header. write_response(@version, response.status, response.headers) @@ -108,8 +108,7 @@ def each(task: Task.current) request = nil response = nil - # We must return here as no further request processing can be done: - return body&.call(stream) + body&.call(stream) else write_response(@version, response.status, response.headers) @@ -143,8 +142,12 @@ def each(task: Task.current) request&.finish end - # Discard or wait for the input body to be consumed: - finishable&.wait + if finishable + finishable.wait(@persistent) + else + # Do not remove this line or you will unleash the gods of concurrency hell. + task.yield + end rescue => error raise ensure