From 001b7cf2508648c0ffb7f055fb1c372715ff805e Mon Sep 17 00:00:00 2001 From: Peter Nied Date: Thu, 28 Sep 2023 23:25:31 +0000 Subject: [PATCH] Very basic header validator Signed-off-by: Peter Nied --- .../opensearch/http/netty4/Netty4Http2IT.java | 32 +++++ .../http/netty4/Netty4Authorizer.java | 56 +++++++++ .../http/netty4/Netty4HttpRequestHandler.java | 6 +- .../netty4/Netty4HttpServerTransport.java | 36 +++--- .../http/nio/HttpReadWriteHandler.java | 3 +- .../http/nio/HttpReadWriteHandlerTests.java | 13 +-- .../http/AbstractHttpServerTransport.java | 109 +++--------------- .../rest/DelegatingRestHandler.java | 74 ------------ .../org/opensearch/rest/RestController.java | 2 +- .../java/org/opensearch/rest/RestHandler.java | 66 ++++++++++- .../opensearch/rest/RestHandlerContext.java | 44 ------- .../java/org/opensearch/rest/RestRequest.java | 61 +--------- .../AbstractHttpServerTransportTests.java | 29 +---- .../rest/DelegatingRestHandlerTests.java | 58 ---------- 14 files changed, 195 insertions(+), 394 deletions(-) create mode 100644 modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java delete mode 100644 server/src/main/java/org/opensearch/rest/DelegatingRestHandler.java delete mode 100644 server/src/main/java/org/opensearch/rest/RestHandlerContext.java delete mode 100644 server/src/test/java/org/opensearch/rest/DelegatingRestHandlerTests.java diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java index eba2c5ce1e094..44a186393adae 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java @@ -16,18 +16,26 @@ import org.opensearch.test.OpenSearchIntegTestCase.Scope; import java.util.Collection; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.stream.IntStream; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; import io.netty.util.ReferenceCounted; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasSize; +import io.netty.handler.codec.http2.HttpConversionUtil; +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; + @ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1) public class Netty4Http2IT extends OpenSearchNetty4IntegTestCase { @@ -56,6 +64,30 @@ public void testThatNettyHttpServerSupportsHttp2GetUpgrades() throws Exception { } } + + public void testThatNettyHttpServerHandlesAuthenticateCheck() throws Exception { + HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); + TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); + TransportAddress transportAddress = randomFrom(boundAddresses); + + final FullHttpRequest unauthorizedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + unauthorizedRequest.headers().add("blockme", "Not Allowed"); + unauthorizedRequest.headers().add(HOST, "localhost"); + unauthorizedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + + + final List responses = new ArrayList<>(); + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2() ) { + try { + FullHttpResponse unauthorizedResponse = nettyHttpClient.send(transportAddress.address(), unauthorizedRequest); + responses.add(unauthorizedResponse); + assertThat(unauthorizedResponse.status().code(), equalTo(401)); + } finally { + responses.forEach(ReferenceCounted::release); + } + } + } + public void testThatNettyHttpServerSupportsHttp2PostUpgrades() throws Exception { final List> requests = List.of(Tuple.tuple("/_search", "{\"query\":{ \"match_all\":{}}}")); diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java new file mode 100644 index 0000000000000..07c0035f5142b --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.http.netty4; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelFutureListener; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.util.ReferenceCountUtil; + +@ChannelHandler.Sharable +public class Netty4Authorizer extends ChannelInboundHandlerAdapter { + + final static Logger log = LogManager.getLogger(Netty4Authorizer.class); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (!(msg instanceof HttpRequest)) { + ctx.fireChannelRead(msg); + } + + HttpRequest request = (HttpRequest) msg; + if (!isAuthenticated(request)) { + final FullHttpResponse response = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.UNAUTHORIZED); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + ReferenceCountUtil.release(msg); + } else { + // Lets the request pass to the next channel handler + ctx.fireChannelRead(msg); + } + } + + private boolean isAuthenticated(HttpRequest request) { + log.info("Checking if request is authenticated:\n" + request); + + final boolean shouldBlock = request.headers().contains("blockme"); + + return !shouldBlock; + } +} diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java index cc331f1b4fd2d..37a6f131b4468 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java @@ -35,7 +35,6 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.http.HttpPipelinedRequest; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestResponse; import io.netty.channel.ChannelHandler; @@ -54,12 +53,9 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler HTTP_CHANNEL_KEY = AttributeKey.newInstance("opensearch-http-channel"); + protected static final AttributeKey HTTP_CHANNEL_KEY = AttributeKey.newInstance("opensearch-http-channel"); protected static final AttributeKey HTTP_SERVER_CHANNEL_KEY = AttributeKey.newInstance( "opensearch-http-server-channel" ); - public static final AttributeKey EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response"); - public static final AttributeKey CONTEXT_TO_RESTORE = AttributeKey.newInstance( - "opensearch-http-request-thread-context" - ); - protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; @@ -427,10 +420,14 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws E final ChannelPipeline pipeline = ctx.pipeline(); pipeline.addAfter(ctx.name(), "handler", getRequestHandler()); pipeline.replace(this, "header_verifier", transport.createHeaderVerifier()); - pipeline.addAfter("header_verifier", "decompress", transport.createDecompressor()); - pipeline.addAfter("decompress", "aggregator", aggregator); + pipeline.addAfter("header_verifier", "decoder_compress", new HttpContentDecompressor()); + pipeline.addAfter("decoder_compress", "aggregator", aggregator); if (handlingSettings.isCompression()) { - pipeline.addAfter("aggregator", "compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + pipeline.addAfter( + "aggregator", + "encoder_compress", + new HttpContentCompressor(handlingSettings.getCompressionLevel()) + ); } pipeline.addBefore("handler", "request_creator", requestCreator); pipeline.addBefore("handler", "response_creator", responseCreator); @@ -450,13 +447,13 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) { decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR); pipeline.addLast("decoder", decoder); pipeline.addLast("header_verifier", transport.createHeaderVerifier()); - pipeline.addLast("decompress", transport.createDecompressor()); + pipeline.addLast("decoder_compress", new HttpContentDecompressor()); pipeline.addLast("encoder", new HttpResponseEncoder()); final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); pipeline.addLast("aggregator", aggregator); if (handlingSettings.isCompression()) { - pipeline.addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + pipeline.addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } pipeline.addLast("request_creator", requestCreator); pipeline.addLast("response_creator", responseCreator); @@ -491,16 +488,18 @@ protected void initChannel(Channel childChannel) throws Exception { final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); + childChannel.pipeline() .addLast(new LoggingHandler(LogLevel.DEBUG)) .addLast(new Http2StreamFrameToHttpObjectCodec(true)) .addLast("byte_buf_sizer", byteBufSizer) .addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)) .addLast("header_verifier", transport.createHeaderVerifier()) - .addLast("decompress", transport.createDecompressor()); + .addLast("decoder_decompress", new HttpContentDecompressor()); if (handlingSettings.isCompression()) { - childChannel.pipeline().addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + childChannel.pipeline() + .addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } childChannel.pipeline() @@ -535,12 +534,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } - protected HttpContentDecompressor createDecompressor() { - return new HttpContentDecompressor(); - } - protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return new Netty4Authorizer(); // pass-through - return new ChannelInboundHandlerAdapter(); +// return new ChannelInboundHandlerAdapter(); } } diff --git a/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java index 7b04f93d5c3be..d44515f3dc727 100644 --- a/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/opensearch/http/nio/HttpReadWriteHandler.java @@ -43,7 +43,6 @@ import org.opensearch.nio.SocketChannelContext; import org.opensearch.nio.TaskScheduler; import org.opensearch.nio.WriteOperation; -import org.opensearch.rest.RestHandlerContext; import java.io.IOException; import java.util.ArrayList; @@ -173,7 +172,7 @@ private void handleRequest(Object msg) { final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg; boolean success = false; try { - transport.incomingRequest(pipelinedRequest, nioHttpChannel, RestHandlerContext.EMPTY); + transport.incomingRequest(pipelinedRequest, nioHttpChannel); success = true; } finally { if (success == false) { diff --git a/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java index 1172472a3c6b1..a3f7a7822cd40 100644 --- a/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/opensearch/http/nio/HttpReadWriteHandlerTests.java @@ -48,7 +48,6 @@ import org.opensearch.nio.InboundChannelBuffer; import org.opensearch.nio.SocketChannelContext; import org.opensearch.nio.TaskScheduler; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; import org.junit.Before; @@ -102,7 +101,7 @@ public void setMocks() { doAnswer(invocation -> { ((HttpRequest) invocation.getArguments()[0]).releaseAndCopy(); return null; - }).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class), any(RestHandlerContext.class)); + }).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class)); Settings settings = Settings.builder().put(SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(1024)).build(); HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); channel = mock(NioHttpChannel.class); @@ -123,12 +122,12 @@ public void testSuccessfulDecodeHttpRequest() throws IOException { try { handler.consumeReads(toChannelBuffer(slicedBuf)); - verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class), any(RestHandlerContext.class)); + verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); handler.consumeReads(toChannelBuffer(slicedBuf2)); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class)); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); HttpRequest nioHttpRequest = requestCaptor.getValue(); assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); @@ -154,7 +153,7 @@ public void testDecodeHttpRequestError() throws IOException { handler.consumeReads(toChannelBuffer(buf)); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class)); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); assertNotNull(requestCaptor.getValue().getInboundException()); assertTrue(requestCaptor.getValue().getInboundException() instanceof IllegalArgumentException); @@ -175,7 +174,7 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t } finally { buf.release(); } - verify(transport, times(0)).incomingRequest(any(), any(), any(RestHandlerContext.class)); + verify(transport, times(0)).incomingRequest(any(), any()); List flushOperations = handler.pollFlushOperations(); assertFalse(flushOperations.isEmpty()); @@ -281,7 +280,7 @@ private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOEx } ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpPipelinedRequest.class); - verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class), any(RestHandlerContext.class)); + verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class)); HttpRequest httpRequest = requestCaptor.getValue(); assertNotNull(httpRequest); diff --git a/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java index 61c2fa01c9c06..ed44102d0abe4 100644 --- a/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java @@ -53,7 +53,6 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; @@ -100,7 +99,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo protected final ThreadPool threadPool; protected final Dispatcher dispatcher; protected final CorsHandler corsHandler; - protected final NamedXContentRegistry xContentRegistry; + private final NamedXContentRegistry xContentRegistry; protected final PortsRange port; protected final ByteSizeValue maxContentLength; @@ -113,7 +112,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo private final Set httpServerChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); private final HttpTracer httpTracer; - protected final Tracer tracer; + private final Tracer tracer; protected AbstractHttpServerTransport( Settings settings, @@ -360,29 +359,20 @@ protected void serverAcceptedChannel(HttpChannel httpChannel) { * * @param httpRequest that is incoming * @param httpChannel that received the http request - * @param requestContext context carried over to the request handler from earlier stages in the request pipeline */ - public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final RestHandlerContext requestContext) { + public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) { final Span span = tracer.startSpan(SpanBuilder.from(httpRequest), httpRequest.getHeaders()); try (final SpanScope httpRequestSpanScope = tracer.withSpanInScope(span)) { HttpChannel traceableHttpChannel = TraceableHttpChannel.create(httpChannel, span, tracer); - handleIncomingRequest(httpRequest, traceableHttpChannel, requestContext, httpRequest.getInboundException()); + handleIncomingRequest(httpRequest, traceableHttpChannel, httpRequest.getInboundException()); } } // Visible for testing - protected void dispatchRequest( - final RestRequest restRequest, - final RestChannel channel, - final Throwable badRequestCause, - final ThreadContext.StoredContext storedContext - ) { + void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) { RestChannel traceableRestChannel = channel; final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - if (storedContext != null) { - storedContext.restore(); - } final Span span = tracer.startSpan(SpanBuilder.from(restRequest)); try (final SpanScope spanScope = tracer.withSpanInScope(span)) { if (channel != null) { @@ -398,12 +388,7 @@ protected void dispatchRequest( } - private void handleIncomingRequest( - final HttpRequest httpRequest, - final HttpChannel httpChannel, - final RestHandlerContext requestContext, - final Exception exception - ) { + private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { if (exception == null) { HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest); if (earlyResponse != null) { @@ -426,13 +411,13 @@ private void handleIncomingRequest( { RestRequest innerRestRequest; try { - innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel, true); + innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel); } catch (final RestRequest.ContentTypeHeaderException e) { badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause, true); + innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause); } catch (final RestRequest.BadParameterException e) { badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel, true); + innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); } restRequest = innerRestRequest; } @@ -477,84 +462,16 @@ private void handleIncomingRequest( channel = innerChannel; } - if (requestContext.hasEarlyResponse()) { - channel.sendResponse(requestContext.getEarlyResponse()); - return; - } - - dispatchRequest(restRequest, channel, badRequestCause, requestContext.getContextToRestore()); - } - - public static RestRequest createRestRequest( - final NamedXContentRegistry xContentRegistry, - final HttpRequest httpRequest, - final HttpChannel httpChannel - ) { - // TODO Figure out how to only generate one request ID for each request in the pipeline. - Exception badRequestCause = httpRequest.getInboundException(); - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final RestRequest restRequest; - { - RestRequest innerRestRequest; - try { - innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel, false); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = requestWithoutContentTypeHeader(xContentRegistry, httpRequest, httpChannel, badRequestCause, false); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel, false); - } - restRequest = innerRestRequest; - } - return restRequest; - } - - private static RestRequest requestWithoutContentTypeHeader( - NamedXContentRegistry xContentRegistry, - HttpRequest httpRequest, - HttpChannel httpChannel, - Exception badRequestCause, - boolean shouldGenerateRequestId - ) { - HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); - try { - return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel, shouldGenerateRequestId); - } catch (final RestRequest.BadParameterException e) { - badRequestCause.addSuppressed(e); - return RestRequest.requestWithoutParameters( - xContentRegistry, - httpRequestWithoutContentType, - httpChannel, - shouldGenerateRequestId - ); - } + dispatchRequest(restRequest, channel, badRequestCause); } - private RestRequest requestWithoutContentTypeHeader( - HttpRequest httpRequest, - HttpChannel httpChannel, - Exception badRequestCause, - boolean shouldGenerateRequestId - ) { + private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) { HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); try { - return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel, shouldGenerateRequestId); + return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel); } catch (final RestRequest.BadParameterException e) { badRequestCause.addSuppressed(e); - return RestRequest.requestWithoutParameters( - xContentRegistry, - httpRequestWithoutContentType, - httpChannel, - shouldGenerateRequestId - ); + return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel); } } diff --git a/server/src/main/java/org/opensearch/rest/DelegatingRestHandler.java b/server/src/main/java/org/opensearch/rest/DelegatingRestHandler.java deleted file mode 100644 index 928ff94d8c1d3..0000000000000 --- a/server/src/main/java/org/opensearch/rest/DelegatingRestHandler.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.rest; - -import org.opensearch.client.node.NodeClient; - -import java.util.List; -import java.util.Objects; - -/** - * Delegating RestHandler that delegates all implementations to original handler - * - * @opensearch.api - */ -public class DelegatingRestHandler implements RestHandler { - - protected final RestHandler delegate; - - public DelegatingRestHandler(RestHandler delegate) { - Objects.requireNonNull(delegate, "RestHandler delegate can not be null"); - this.delegate = delegate; - } - - @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { - delegate.handleRequest(request, channel, client); - } - - @Override - public boolean canTripCircuitBreaker() { - return delegate.canTripCircuitBreaker(); - } - - @Override - public boolean supportsContentStream() { - return delegate.supportsContentStream(); - } - - @Override - public boolean allowsUnsafeBuffers() { - return delegate.allowsUnsafeBuffers(); - } - - @Override - public List routes() { - return delegate.routes(); - } - - @Override - public List deprecatedRoutes() { - return delegate.deprecatedRoutes(); - } - - @Override - public List replacedRoutes() { - return delegate.replacedRoutes(); - } - - @Override - public boolean allowSystemIndexAccessByDefault() { - return delegate.allowSystemIndexAccessByDefault(); - } - - @Override - public String toString() { - return delegate.toString(); - } -} diff --git a/server/src/main/java/org/opensearch/rest/RestController.java b/server/src/main/java/org/opensearch/rest/RestController.java index 4929f2a147dae..ac30f999d0da7 100644 --- a/server/src/main/java/org/opensearch/rest/RestController.java +++ b/server/src/main/java/org/opensearch/rest/RestController.java @@ -131,7 +131,7 @@ public RestController( this.headersToCopy = headersToCopy; this.usageService = usageService; if (handlerWrapper == null) { - handlerWrapper = (delegate) -> new DelegatingRestHandler(delegate); + handlerWrapper = h -> h; // passthrough if no wrapper set } this.handlerWrapper = handlerWrapper; this.client = client; diff --git a/server/src/main/java/org/opensearch/rest/RestHandler.java b/server/src/main/java/org/opensearch/rest/RestHandler.java index edb1cb341d2d8..7832649e8ad32 100644 --- a/server/src/main/java/org/opensearch/rest/RestHandler.java +++ b/server/src/main/java/org/opensearch/rest/RestHandler.java @@ -44,8 +44,6 @@ /** * Handler for REST requests * - * If new methods are added to this interface they must also be added to {@link DelegatingRestHandler} - * * @opensearch.api */ @FunctionalInterface @@ -110,13 +108,75 @@ default List replacedRoutes() { } /** - * Controls whether requests handled by this class are allowed to access system indices by default. + * Controls whether requests handled by this class are allowed to to access system indices by default. * @return {@code true} if requests handled by this class should be allowed to access system indices. */ default boolean allowSystemIndexAccessByDefault() { return false; } + static RestHandler wrapper(RestHandler delegate) { + return new Wrapper(delegate); + } + + /** + * Wrapper for a handler. + * + * @opensearch.internal + */ + class Wrapper implements RestHandler { + private final RestHandler delegate; + + public Wrapper(RestHandler delegate) { + this.delegate = Objects.requireNonNull(delegate, "RestHandler delegate can not be null"); + } + + @Override + public String toString() { + return delegate.toString(); + } + + @Override + public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + delegate.handleRequest(request, channel, client); + } + + @Override + public boolean canTripCircuitBreaker() { + return delegate.canTripCircuitBreaker(); + } + + @Override + public boolean supportsContentStream() { + return delegate.supportsContentStream(); + } + + @Override + public boolean allowsUnsafeBuffers() { + return delegate.allowsUnsafeBuffers(); + } + + @Override + public List routes() { + return delegate.routes(); + } + + @Override + public List deprecatedRoutes() { + return delegate.deprecatedRoutes(); + } + + @Override + public List replacedRoutes() { + return delegate.replacedRoutes(); + } + + @Override + public boolean allowSystemIndexAccessByDefault() { + return delegate.allowSystemIndexAccessByDefault(); + } + } + /** * Route for the request. * diff --git a/server/src/main/java/org/opensearch/rest/RestHandlerContext.java b/server/src/main/java/org/opensearch/rest/RestHandlerContext.java deleted file mode 100644 index 297a44c705e2e..0000000000000 --- a/server/src/main/java/org/opensearch/rest/RestHandlerContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.rest; - -import org.opensearch.common.util.concurrent.ThreadContext; - -/** - * Holder for information that is shared between stages of the request pipeline - */ -public class RestHandlerContext { - private RestResponse earlyResponse; - private ThreadContext.StoredContext contextToRestore; - - public static RestHandlerContext EMPTY = new RestHandlerContext(); - - private RestHandlerContext() {} - - public RestHandlerContext(final RestResponse earlyResponse, ThreadContext.StoredContext contextToRestore) { - this.earlyResponse = earlyResponse; - this.contextToRestore = contextToRestore; - } - - public boolean hasEarlyResponse() { - return this.earlyResponse != null; - } - - public boolean hasContextToRestore() { - return this.contextToRestore != null; - } - - public RestResponse getEarlyResponse() { - return this.earlyResponse; - } - - public ThreadContext.StoredContext getContextToRestore() { - return contextToRestore; - } -} diff --git a/server/src/main/java/org/opensearch/rest/RestRequest.java b/server/src/main/java/org/opensearch/rest/RestRequest.java index 275341cd960b8..f64774686c89d 100644 --- a/server/src/main/java/org/opensearch/rest/RestRequest.java +++ b/server/src/main/java/org/opensearch/rest/RestRequest.java @@ -76,6 +76,7 @@ public class RestRequest implements ToXContent.Params { // tchar pattern as defined by RFC7230 section 3.2.6 private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+"); + private static final AtomicLong requestIdGenerator = new AtomicLong(); private final NamedXContentRegistry xContentRegistry; @@ -151,7 +152,7 @@ protected RestRequest(RestRequest restRequest) { * with an unpooled copy. This is supposed to be used before passing requests to {@link RestHandler} instances that can not safely * handle http requests that use pooled buffers as determined by {@link RestHandler#allowsUnsafeBuffers()}. */ - protected void ensureSafeBuffers() { + void ensureSafeBuffers() { httpRequest = httpRequest.releaseAndCopy(); } @@ -179,36 +180,6 @@ public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRe ); } - /** - * Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be - * decoded - * - * @param xContentRegistry the content registry - * @param httpRequest the http request - * @param httpChannel the http channel - * @param shouldGenerateRequestId should generate a new request id - * @throws BadParameterException if the parameters can not be decoded - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - public static RestRequest request( - NamedXContentRegistry xContentRegistry, - HttpRequest httpRequest, - HttpChannel httpChannel, - boolean shouldGenerateRequestId - ) { - Map params = params(httpRequest.uri()); - String path = path(httpRequest.uri()); - return new RestRequest( - xContentRegistry, - params, - path, - httpRequest.getHeaders(), - httpRequest, - httpChannel, - shouldGenerateRequestId ? requestIdGenerator.incrementAndGet() : -1 - ); - } - private static Map params(final String uri) { final Map params = new HashMap<>(); int index = uri.indexOf('?'); @@ -257,34 +228,6 @@ public static RestRequest requestWithoutParameters( ); } - /** - * Creates a new REST request. The path is not decoded so this constructor will not throw a - * {@link BadParameterException}. - * - * @param xContentRegistry the content registry - * @param httpRequest the http request - * @param httpChannel the http channel - * @param shouldGenerateRequestId should generate new request id - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - public static RestRequest requestWithoutParameters( - NamedXContentRegistry xContentRegistry, - HttpRequest httpRequest, - HttpChannel httpChannel, - boolean shouldGenerateRequestId - ) { - Map params = Collections.emptyMap(); - return new RestRequest( - xContentRegistry, - params, - httpRequest.uri(), - httpRequest.getHeaders(), - httpRequest, - httpChannel, - shouldGenerateRequestId ? requestIdGenerator.incrementAndGet() : -1 - ); - } - /** * The method used. * diff --git a/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java index eaa199664fceb..c34f13041cb11 100644 --- a/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/opensearch/http/AbstractHttpServerTransportTests.java @@ -49,7 +49,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; import org.opensearch.tasks.Task; @@ -65,7 +64,6 @@ import java.net.InetSocketAddress; import java.net.UnknownHostException; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -151,21 +149,6 @@ public void testHttpPublishPort() throws Exception { } } - public void testCreateRestRequestDoesNotGenerateRequestID() { - FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent( - new BytesArray("bar".getBytes(StandardCharsets.UTF_8)), - null - ).withPath("/foo").withHeaders(Collections.singletonMap("Content-Type", Collections.singletonList("text/plain"))).build(); - - RestRequest request = AbstractHttpServerTransport.createRestRequest( - xContentRegistry(), - fakeRestRequest.getHttpRequest(), - fakeRestRequest.getHttpChannel() - ); - - assertEquals("request should not generate id", -1, request.getRequestId()); - } - public void testDispatchDoesNotModifyThreadContext() { final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { @@ -217,11 +200,11 @@ public HttpStats stats() { } ) { - transport.dispatchRequest(null, null, null, null); + transport.dispatchRequest(null, null, null); assertNull(threadPool.getThreadContext().getHeader("foo")); assertNull(threadPool.getThreadContext().getTransient("bar")); - transport.dispatchRequest(null, null, new Exception(), null); + transport.dispatchRequest(null, null, new Exception()); assertNull(threadPool.getThreadContext().getHeader("foo_bad")); assertNull(threadPool.getThreadContext().getTransient("bar_bad")); } @@ -338,7 +321,7 @@ public HttpStats stats() { .withInboundException(inboundException) .build(); - transport.incomingRequest(fakeRestRequest.getHttpRequest(), fakeRestRequest.getHttpChannel(), RestHandlerContext.EMPTY); + transport.incomingRequest(fakeRestRequest.getHttpRequest(), fakeRestRequest.getHttpChannel()); final Exception inboundExceptionExcludedPath; if (randomBoolean()) { @@ -355,11 +338,7 @@ public HttpStats stats() { .withInboundException(inboundExceptionExcludedPath) .build(); - transport.incomingRequest( - fakeRestRequestExcludedPath.getHttpRequest(), - fakeRestRequestExcludedPath.getHttpChannel(), - RestHandlerContext.EMPTY - ); + transport.incomingRequest(fakeRestRequestExcludedPath.getHttpRequest(), fakeRestRequestExcludedPath.getHttpChannel()); appender.assertAllExpectationsMatched(); } } diff --git a/server/src/test/java/org/opensearch/rest/DelegatingRestHandlerTests.java b/server/src/test/java/org/opensearch/rest/DelegatingRestHandlerTests.java deleted file mode 100644 index ca802a7784ca0..0000000000000 --- a/server/src/test/java/org/opensearch/rest/DelegatingRestHandlerTests.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.rest; - -import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.test.OpenSearchTestCase; - -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - -public class DelegatingRestHandlerTests extends OpenSearchTestCase { - public void testDelegatingRestHandlerShouldActAsOriginal() throws Exception { - RestHandler rh = new RestHandler() { - @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { - new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY); - } - }; - RestHandler handlerSpy = spy(rh); - DelegatingRestHandler drh = new DelegatingRestHandler(handlerSpy); - - List overridableMethods = Arrays.stream(RestHandler.class.getMethods()) - .filter( - m -> !(Modifier.isPrivate(m.getModifiers()) || Modifier.isStatic(m.getModifiers()) || Modifier.isFinal(m.getModifiers())) - ) - .collect(Collectors.toList()); - - for (Method method : overridableMethods) { - int argCount = method.getParameterCount(); - Object[] args = new Object[argCount]; - for (int i = 0; i < argCount; i++) { - args[i] = any(); - } - if (args.length > 0) { - method.invoke(drh, args); - } else { - method.invoke(drh); - } - method.invoke(verify(handlerSpy, times(1)), args); - } - } -}