From 24581188be70d006d045b33525a1fe20658c49f9 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 25 Sep 2023 14:40:32 +0100 Subject: [PATCH] Make ChunkedRestResponseBody extend Releasable (#99871) If an unchunked body implements `Releasable` then it is released once the response is sent. We have some need for the same behaviour with chunked bodies, except that in this case there's no need for an `instanceof` check since we control the body type directly. This commit makes `ChunkedRestResponseBody extend Releasable` and adds support for closing it when the response is sent. --- .../Netty4HttpPipeliningHandlerTests.java | 3 + .../http/DefaultRestChannel.java | 7 +- .../rest/ChunkedRestResponseBody.java | 49 ++++++++++++- .../rest/LoggingChunkedRestResponseBody.java | 5 ++ .../org/elasticsearch/rest/RestResponse.java | 7 +- .../http/DefaultRestChannelTests.java | 28 +++++++- .../rest/ChunkedRestResponseBodyTests.java | 70 ++++++++++++------- 7 files changed, 139 insertions(+), 30 deletions(-) diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index 1a011db433a49..895c07ec7a3f6 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -481,6 +481,9 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return "application/octet-stream"; } + + @Override + public void close() {} }; } diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java index b2d3afe30cc36..930b20b927bd8 100644 --- a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -117,6 +117,7 @@ public void sendResponse(RestResponse restResponse) { final HttpResponse httpResponse; if (isHeadRequest == false && restResponse.isChunked()) { ChunkedRestResponseBody chunkedContent = restResponse.chunkedContent(); + toClose.add(chunkedContent); if (httpLogger != null && httpLogger.isBodyTracerEnabled()) { final var loggerStream = httpLogger.openResponseBodyLoggingStream(request.getRequestId()); toClose.add(() -> { @@ -132,8 +133,10 @@ public void sendResponse(RestResponse restResponse) { httpResponse = httpRequest.createResponse(restResponse.status(), chunkedContent); } else { final BytesReference content = restResponse.content(); - if (content instanceof Releasable) { - toClose.add((Releasable) content); + if (content instanceof Releasable releasable) { + toClose.add(releasable); + } else if (restResponse.isChunked()) { + toClose.add(restResponse.chunkedContent()); } toClose.add(this::releaseOutputBuffer); diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java index 78e529eef2d98..fb73677e265f4 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java @@ -15,6 +15,8 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.Streams; @@ -32,7 +34,7 @@ * The body of a rest response that uses chunked HTTP encoding. Implementations are used to avoid materializing full responses on heap and * instead serialize only as much of the response as can be flushed to the network right away. */ -public interface ChunkedRestResponseBody { +public interface ChunkedRestResponseBody extends Releasable { /** * @return true once this response has been written fully. @@ -62,9 +64,29 @@ public interface ChunkedRestResponseBody { * @param params parameters to use for serialization * @param channel channel the response will be written to * @return chunked rest response body + * @deprecated Use {@link #fromXContent(ChunkedToXContent, ToXContent.Params, RestChannel, Releasable)} instead. */ + @Deprecated(forRemoval = true) static ChunkedRestResponseBody fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel) throws IOException { + return fromXContent(chunkedToXContent, params, channel, null); + } + + /** + * Create a chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}. + * + * @param chunkedToXContent chunked x-content instance to serialize + * @param params parameters to use for serialization + * @param channel channel the response will be written to + * @param releasable resource to release when the response is fully sent, or {@code null} if nothing to release + * @return chunked rest response body + */ + static ChunkedRestResponseBody fromXContent( + ChunkedToXContent chunkedToXContent, + ToXContent.Params params, + RestChannel channel, + @Nullable Releasable releasable + ) throws IOException { return new ChunkedRestResponseBody() { @@ -132,14 +154,34 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return builder.getResponseContentTypeString(); } + + @Override + public void close() { + Releasables.closeExpectNoException(releasable); + } }; } /** * Create a chunked response body to be written to a specific {@link RestChannel} from a stream of text chunks, each represented as a * consumer of a {@link Writer}. The last chunk that the iterator yields must write at least one byte. + * + * @deprecated Use {@link #fromTextChunks(String, Iterator, Releasable)} instead. */ + @Deprecated(forRemoval = true) static ChunkedRestResponseBody fromTextChunks(String contentType, Iterator> chunkIterator) { + return fromTextChunks(contentType, chunkIterator, null); + } + + /** + * Create a chunked response body to be written to a specific {@link RestChannel} from a stream of text chunks, each represented as a + * consumer of a {@link Writer}. The last chunk that the iterator yields must write at least one byte. + */ + static ChunkedRestResponseBody fromTextChunks( + String contentType, + Iterator> chunkIterator, + @Nullable Releasable releasable + ) { return new ChunkedRestResponseBody() { private RecyclerBytesStreamOutput currentOutput; private final Writer writer = new OutputStreamWriter(new OutputStream() { @@ -209,6 +251,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return contentType; } + + @Override + public void close() { + Releasables.closeExpectNoException(releasable); + } }; } } diff --git a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java index 0508828c70da1..00b56d0e05051 100644 --- a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java @@ -46,4 +46,9 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return inner.getResponseContentTypeString(); } + + @Override + public void close() { + inner.close(); + } } diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index 3a82a827e3726..1e86b7ddae367 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -16,6 +16,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContent; @@ -81,7 +82,11 @@ public RestResponse(RestStatus status, String responseMediaType, BytesReference public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBody content) { if (content.isDone()) { - return new RestResponse(restStatus, content.getResponseContentTypeString(), BytesArray.EMPTY); + return new RestResponse( + restStatus, + content.getResponseContentTypeString(), + new ReleasableBytesReference(BytesArray.EMPTY, content) + ); } else { return new RestResponse(restStatus, content.getResponseContentTypeString(), null, content); } diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java index 6e0f58d0cdb97..63cdc2c485197 100644 --- a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -57,6 +57,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; @@ -525,6 +526,7 @@ public void testHandleHeadRequest() { } { // chunked response + final var isClosed = new AtomicBoolean(); channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { @Override @@ -541,11 +543,28 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return RestResponse.TEXT_CONTENT_TYPE; } + + @Override + public void close() { + assertTrue(isClosed.compareAndSet(false, true)); + } })); - verify(httpChannel, times(2)).sendResponse(requestCaptor.capture(), any()); + @SuppressWarnings("unchecked") + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel, times(2)).sendResponse(requestCaptor.capture(), listenerCaptor.capture()); HttpResponse response = requestCaptor.getValue(); assertThat(response, instanceOf(TestHttpResponse.class)); assertThat(((TestHttpResponse) response).content().length(), equalTo(0)); + + ActionListener listener = listenerCaptor.getValue(); + assertFalse(isClosed.get()); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + assertTrue(isClosed.get()); } } @@ -703,6 +722,7 @@ public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody co ) ); + final var isClosed = new AtomicBoolean(); assertEquals( responseBody, ChunkedLoggingStreamTests.getDecodedLoggedBody( @@ -730,10 +750,16 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return RestResponse.TEXT_CONTENT_TYPE; } + + @Override + public void close() { + assertTrue(isClosed.compareAndSet(false, true)); + } })) ) ); + assertTrue(isClosed.get()); } private TestHttpResponse executeRequest(final Settings settings, final String host) { diff --git a/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java b/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java index 9842aff24dac1..485e2a3a3fdd7 100644 --- a/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java +++ b/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; public class ChunkedRestResponseBodyTests extends ESTestCase { @@ -50,40 +51,59 @@ public void testEncodesChunkedXContentCorrectly() throws IOException { } final var bytesDirect = BytesReference.bytes(builderDirect); - final var chunkedResponse = ChunkedRestResponseBody.fromXContent( - chunkedToXContent, - ToXContent.EMPTY_PARAMS, - new FakeRestChannel( - new FakeRestRequest.Builder(xContentRegistry()).withContent(BytesArray.EMPTY, randomXContent.type()).build(), - randomBoolean(), - 1 + final var isClosed = new AtomicBoolean(); + try ( + var chunkedResponse = ChunkedRestResponseBody.fromXContent( + chunkedToXContent, + ToXContent.EMPTY_PARAMS, + new FakeRestChannel( + new FakeRestRequest.Builder(xContentRegistry()).withContent(BytesArray.EMPTY, randomXContent.type()).build(), + randomBoolean(), + 1 + ), + () -> assertTrue(isClosed.compareAndSet(false, true)) ) - ); + ) { - final List refsGenerated = new ArrayList<>(); - while (chunkedResponse.isDone() == false) { - refsGenerated.add(chunkedResponse.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); - } + final List refsGenerated = new ArrayList<>(); + while (chunkedResponse.isDone() == false) { + refsGenerated.add(chunkedResponse.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); + } - assertEquals(bytesDirect, CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0]))); + assertEquals(bytesDirect, CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0]))); + assertFalse(isClosed.get()); + } + assertTrue(isClosed.get()); } public void testFromTextChunks() throws IOException { final var chunks = randomList(1000, () -> randomUnicodeOfLengthBetween(1, 100)); - final var body = ChunkedRestResponseBody.fromTextChunks("text/plain", Iterators.map(chunks.iterator(), s -> w -> w.write(s))); - - final List refsGenerated = new ArrayList<>(); - while (body.isDone() == false) { - refsGenerated.add(body.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); - } - final BytesReference chunkedBytes = CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0])); + final var isClosed = new AtomicBoolean(); + try ( + var body = ChunkedRestResponseBody.fromTextChunks( + "text/plain", + Iterators.map(chunks.iterator(), s -> w -> w.write(s)), + () -> assertTrue(isClosed.compareAndSet(false, true)) + ) + ) { + final List refsGenerated = new ArrayList<>(); + while (body.isDone() == false) { + refsGenerated.add(body.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); + } + final BytesReference chunkedBytes = CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0])); - try (var outputStream = new ByteArrayOutputStream(); var writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { - for (final var chunk : chunks) { - writer.write(chunk); + try ( + var outputStream = new ByteArrayOutputStream(); + var writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8) + ) { + for (final var chunk : chunks) { + writer.write(chunk); + } + writer.flush(); + assertEquals(new BytesArray(outputStream.toByteArray()), chunkedBytes); } - writer.flush(); - assertEquals(new BytesArray(outputStream.toByteArray()), chunkedBytes); + assertFalse(isClosed.get()); } + assertTrue(isClosed.get()); } }