diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4BlockingPlugin.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4BlockingPlugin.java index 0acf3e50325c7..ae17209543e5e 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4BlockingPlugin.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4BlockingPlugin.java @@ -25,10 +25,14 @@ import java.util.Map; import java.util.function.Supplier; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; @@ -98,18 +102,14 @@ public Map> getHttpTransports( } /** POC for how an external header verifier would be implemented */ - public class ExampleBlockingNetty4HeaderVerifier extends ChannelInboundHandlerAdapter { + public class ExampleBlockingNetty4HeaderVerifier extends SimpleChannelInboundHandler { @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (!(msg instanceof HttpRequest)) { - ctx.fireChannelRead(msg); - return; - } - - HttpRequest request = (HttpRequest) msg; - if (!isAuthenticated(request)) { - final FullHttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.UNAUTHORIZED); + public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception { + ReferenceCountUtil.retain(msg); + if (isBlocked(msg)) { + ByteBuf buf = Unpooled.copiedBuffer("Hit header_verifier".getBytes()); + final FullHttpResponse response = new DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.UNAUTHORIZED, buf); ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); ReferenceCountUtil.release(msg); } else { @@ -118,10 +118,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } - private boolean isAuthenticated(HttpRequest request) { + private boolean isBlocked(HttpRequest request) { final boolean shouldBlock = request.headers().contains("blockme"); - return !shouldBlock; + return shouldBlock; } } } diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HeaderVerifierIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HeaderVerifierIT.java index 9ba93b8f91bf1..94e7a3240454b 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HeaderVerifierIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HeaderVerifierIT.java @@ -15,18 +15,22 @@ import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope; import org.opensearch.test.OpenSearchIntegTestCase.Scope; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; +import io.netty.buffer.ByteBufUtil; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http2.HttpConversionUtil; import io.netty.util.ReferenceCounted; +import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; import static io.netty.handler.codec.http.HttpHeaderNames.HOST; @@ -43,7 +47,6 @@ protected Collection> nodePlugins() { return Collections.singletonList(Netty4BlockingPlugin.class); } - @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/10260") public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exception { HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); @@ -52,12 +55,15 @@ public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exc final FullHttpRequest blockedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); blockedRequest.headers().add("blockme", "Not Allowed"); blockedRequest.headers().add(HOST, "localhost"); + blockedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); final List responses = new ArrayList<>(); - try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) { + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2()) { try { FullHttpResponse blockedResponse = nettyHttpClient.send(transportAddress.address(), blockedRequest); responses.add(blockedResponse); + String blockedResponseContent = new String(ByteBufUtil.getBytes(blockedResponse.content()), StandardCharsets.UTF_8); + assertThat(blockedResponseContent, containsString("Hit header_verifier")); assertThat(blockedResponse.status().code(), equalTo(401)); } finally { responses.forEach(ReferenceCounted::release); diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4ConditionalDecompressor.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4ConditionalDecompressor.java deleted file mode 100644 index 72815bae24ae6..0000000000000 --- a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4ConditionalDecompressor.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.http.netty4; - -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.HttpContentDecompressor; - -import static org.opensearch.http.netty4.Netty4HttpServerTransport.SHOULD_DECOMPRESS; - -public class Netty4ConditionalDecompressor extends HttpContentDecompressor { - @Override - protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { - if (Boolean.FALSE.equals(ctx.channel().attr(SHOULD_DECOMPRESS).get())) { - return super.newContentDecoder("identity"); - } - return super.newContentDecoder(contentEncoding); - } -} diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java index b26d849e62dce..1f7aaf17d2191 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java @@ -33,10 +33,7 @@ package org.opensearch.http.netty4; import org.opensearch.ExceptionsHelper; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.http.HttpPipelinedRequest; -import org.opensearch.rest.RestHandlerContext; -import org.opensearch.rest.RestResponse; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; @@ -54,14 +51,9 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response"); - public static final AttributeKey CONTEXT_TO_RESTORE = AttributeKey.newInstance( - "opensearch-http-request-thread-context" - ); - public static final AttributeKey SHOULD_DECOMPRESS = AttributeKey.newInstance("opensearch-http-should-decompress"); - protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; @@ -427,7 +420,7 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws E final ChannelPipeline pipeline = ctx.pipeline(); pipeline.addAfter(ctx.name(), "handler", getRequestHandler()); pipeline.replace(this, "header_verifier", transport.createHeaderVerifier()); - pipeline.addAfter("header_verifier", "decoder_compress", new Netty4ConditionalDecompressor()); + pipeline.addAfter("header_verifier", "decoder_compress", transport.createDecompressor()); pipeline.addAfter("decoder_compress", "aggregator", aggregator); if (handlingSettings.isCompression()) { pipeline.addAfter( @@ -454,7 +447,7 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) { decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR); pipeline.addLast("decoder", decoder); pipeline.addLast("header_verifier", transport.createHeaderVerifier()); - pipeline.addLast("decoder_compress", new Netty4ConditionalDecompressor()); + pipeline.addLast("decoder_compress", transport.createDecompressor()); pipeline.addLast("encoder", new HttpResponseEncoder()); final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); @@ -501,7 +494,7 @@ protected void initChannel(Channel childChannel) throws Exception { .addLast("byte_buf_sizer", byteBufSizer) .addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)) .addLast("header_verifier", transport.createHeaderVerifier()) - .addLast("decoder_decompress", new Netty4ConditionalDecompressor()); + .addLast("decoder_decompress", transport.createDecompressor()); if (handlingSettings.isCompression()) { childChannel.pipeline() @@ -544,4 +537,8 @@ protected ChannelInboundHandlerAdapter createHeaderVerifier() { // pass-through return new ChannelInboundHandlerAdapter(); } + + protected ChannelInboundHandlerAdapter createDecompressor() { + return new HttpContentDecompressor(); + } } diff --git a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/example/ExampleNetty4HeaderVerifier.java b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/example/ExampleNetty4HeaderVerifier.java deleted file mode 100644 index fb2d228cdd431..0000000000000 --- a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/example/ExampleNetty4HeaderVerifier.java +++ /dev/null @@ -1,9 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.http.netty4.example; diff --git a/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java index 7b04f93d5c3be..d44515f3dc727 100644 --- a/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java @@ -43,7 +43,6 @@ import org.opensearch.nio.SocketChannelContext; import org.opensearch.nio.TaskScheduler; import org.opensearch.nio.WriteOperation; -import org.opensearch.rest.RestHandlerContext; import java.io.IOException; import java.util.ArrayList; @@ -173,7 +172,7 @@ private void handleRequest(Object msg) { final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg; boolean success = false; try { - transport.incomingRequest(pipelinedRequest, nioHttpChannel, RestHandlerContext.EMPTY); + transport.incomingRequest(pipelinedRequest, nioHttpChannel); success = true; } finally { if (success == false) { diff --git a/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java index 1172472a3c6b1..a3f7a7822cd40 100644 --- a/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java @@ -48,7 +48,6 @@ import org.opensearch.nio.InboundChannelBuffer; import org.opensearch.nio.SocketChannelContext; import org.opensearch.nio.TaskScheduler; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; import org.junit.Before; @@ -102,7 +101,7 @@ public void setMocks() { doAnswer(invocation -> { ((HttpRequest) invocation.getArguments()[0]).releaseAndCopy(); return null; - }).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class), any(RestHandlerContext.class)); + }).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class)); Settings settings = Settings.builder().put(SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(1024)).build(); HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); channel = mock(NioHttpChannel.class); @@ -123,12 +122,12 @@ public void testSuccessfulDecodeHttpRequest() throws IOException { try { handler.consumeReads(toChannelBuffer(slicedBuf)); - verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class), any(RestHandlerContext.class)); + verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); handler.consumeReads(toChannelBuffer(slicedBuf2)); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class)); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); HttpRequest nioHttpRequest = requestCaptor.getValue(); assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); @@ -154,7 +153,7 @@ public void testDecodeHttpRequestError() throws IOException { handler.consumeReads(toChannelBuffer(buf)); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class)); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); assertNotNull(requestCaptor.getValue().getInboundException()); assertTrue(requestCaptor.getValue().getInboundException() instanceof IllegalArgumentException); @@ -175,7 +174,7 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t } finally { buf.release(); } - verify(transport, times(0)).incomingRequest(any(), any(), any(RestHandlerContext.class)); + verify(transport, times(0)).incomingRequest(any(), any()); List flushOperations = handler.pollFlushOperations(); assertFalse(flushOperations.isEmpty()); @@ -281,7 +280,7 @@ private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOEx } ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpPipelinedRequest.class); - verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class), any(RestHandlerContext.class)); + verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class)); HttpRequest httpRequest = requestCaptor.getValue(); assertNotNull(httpRequest); diff --git a/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java index d4c3a8c79c5db..bc64bd85cac04 100644 --- a/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java @@ -53,7 +53,6 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; @@ -360,29 +359,20 @@ protected void serverAcceptedChannel(HttpChannel httpChannel) { * * @param httpRequest that is incoming * @param httpChannel that received the http request - * @param requestContext context carried over to the request handler from earlier stages in the request pipeline */ - public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final RestHandlerContext requestContext) { + public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) { final Span span = tracer.startSpan(SpanBuilder.from(httpRequest), httpRequest.getHeaders()); try (final SpanScope httpRequestSpanScope = tracer.withSpanInScope(span)) { HttpChannel traceableHttpChannel = TraceableHttpChannel.create(httpChannel, span, tracer); - handleIncomingRequest(httpRequest, traceableHttpChannel, requestContext, httpRequest.getInboundException()); + handleIncomingRequest(httpRequest, traceableHttpChannel, httpRequest.getInboundException()); } } // Visible for testing - protected void dispatchRequest( - final RestRequest restRequest, - final RestChannel channel, - final Throwable badRequestCause, - final ThreadContext.StoredContext storedContext - ) { + void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) { RestChannel traceableRestChannel = channel; final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - if (storedContext != null) { - storedContext.restore(); - } final Span span = tracer.startSpan(SpanBuilder.from(restRequest)); try (final SpanScope spanScope = tracer.withSpanInScope(span)) { if (channel != null) { @@ -398,12 +388,7 @@ protected void dispatchRequest( } - private void handleIncomingRequest( - final HttpRequest httpRequest, - final HttpChannel httpChannel, - final RestHandlerContext requestContext, - final Exception exception - ) { + private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { if (exception == null) { HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest); if (earlyResponse != null) { @@ -477,12 +462,7 @@ private void handleIncomingRequest( channel = innerChannel; } - if (requestContext.hasEarlyResponse()) { - channel.sendResponse(requestContext.getEarlyResponse()); - return; - } - - dispatchRequest(restRequest, channel, badRequestCause, requestContext.getContextToRestore()); + dispatchRequest(restRequest, channel, badRequestCause); } public static RestRequest createRestRequest( diff --git a/server/src/main/java/org/opensearch/rest/RestHandlerContext.java b/server/src/main/java/org/opensearch/rest/RestHandlerContext.java deleted file mode 100644 index 297a44c705e2e..0000000000000 --- a/server/src/main/java/org/opensearch/rest/RestHandlerContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.rest; - -import org.opensearch.common.util.concurrent.ThreadContext; - -/** - * Holder for information that is shared between stages of the request pipeline - */ -public class RestHandlerContext { - private RestResponse earlyResponse; - private ThreadContext.StoredContext contextToRestore; - - public static RestHandlerContext EMPTY = new RestHandlerContext(); - - private RestHandlerContext() {} - - public RestHandlerContext(final RestResponse earlyResponse, ThreadContext.StoredContext contextToRestore) { - this.earlyResponse = earlyResponse; - this.contextToRestore = contextToRestore; - } - - public boolean hasEarlyResponse() { - return this.earlyResponse != null; - } - - public boolean hasContextToRestore() { - return this.contextToRestore != null; - } - - public RestResponse getEarlyResponse() { - return this.earlyResponse; - } - - public ThreadContext.StoredContext getContextToRestore() { - return contextToRestore; - } -} diff --git a/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java index eaa199664fceb..7dcea1c206ac3 100644 --- a/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java @@ -49,7 +49,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; import org.opensearch.tasks.Task; @@ -217,11 +216,11 @@ public HttpStats stats() { } ) { - transport.dispatchRequest(null, null, null, null); + transport.dispatchRequest(null, null, null); assertNull(threadPool.getThreadContext().getHeader("foo")); assertNull(threadPool.getThreadContext().getTransient("bar")); - transport.dispatchRequest(null, null, new Exception(), null); + transport.dispatchRequest(null, null, new Exception()); assertNull(threadPool.getThreadContext().getHeader("foo_bad")); assertNull(threadPool.getThreadContext().getTransient("bar_bad")); } @@ -338,7 +337,7 @@ public HttpStats stats() { .withInboundException(inboundException) .build(); - transport.incomingRequest(fakeRestRequest.getHttpRequest(), fakeRestRequest.getHttpChannel(), RestHandlerContext.EMPTY); + transport.incomingRequest(fakeRestRequest.getHttpRequest(), fakeRestRequest.getHttpChannel()); final Exception inboundExceptionExcludedPath; if (randomBoolean()) { @@ -355,11 +354,7 @@ public HttpStats stats() { .withInboundException(inboundExceptionExcludedPath) .build(); - transport.incomingRequest( - fakeRestRequestExcludedPath.getHttpRequest(), - fakeRestRequestExcludedPath.getHttpChannel(), - RestHandlerContext.EMPTY - ); + transport.incomingRequest(fakeRestRequestExcludedPath.getHttpRequest(), fakeRestRequestExcludedPath.getHttpChannel()); appender.assertAllExpectationsMatched(); } }