Skip to content

Commit

Permalink
[Backport 1.3] Allow customization of netty channel handles before an…
Browse files Browse the repository at this point in the history
…d during decompression (opensearch-project#10261) (opensearch-project#11086)

* [Backport 1.3] Allow customization of netty channel handles before and during decompression (opensearch-project#10261)

Signed-off-by: Peter Nied <[email protected]>

* Fix test cases issues by switching to marked instead of blocked workflow

Signed-off-by: Peter Nied <[email protected]>

* Fix spotless issues

Signed-off-by: Peter Nied <[email protected]>

---------

Signed-off-by: Peter Nied <[email protected]>
Co-authored-by: Craig Perkins <[email protected]>
  • Loading branch information
peternied and cwperks authored Nov 9, 2023
1 parent 77adb21 commit 2e9bfca
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased 1.3.x]

### Added
- Improve compressed request handling ([#10261](https://github.com/opensearch-project/OpenSearch/pull/10261))

### Dependencies
- Bump asm from 9.5 to 9.6 ([#10302](https://github.com/opensearch-project/OpenSearch/pull/10302))
- Bump netty from 4.1.97.Final to 4.1.99.Final ([#10303](https://github.com/opensearch-project/OpenSearch/pull/10303))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.http.netty4;

import org.opensearch.OpenSearchNetty4IntegTestCase;
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope;
import org.opensearch.test.OpenSearchIntegTestCase.Scope;
import org.opensearch.transport.Netty4MarkedMessagePlugin;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import io.netty.buffer.ByteBufUtil;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCounted;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertThat;

@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1)
public class Netty4HeaderVerifierIT extends OpenSearchNetty4IntegTestCase {

@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(Netty4MarkedMessagePlugin.class);
}

public void testThatNettyHttpServerRequestMarksMessageWithHeaderVerifier() throws Exception {
HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class);
TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses();
TransportAddress transportAddress = randomFrom(boundAddresses);

final FullHttpRequest markedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final String expectedMarkedHeaderValue = "Mark with" + randomAlphaOfLength(10);
markedRequest.headers().add("marked-message", expectedMarkedHeaderValue);

final List<FullHttpResponse> responses = new ArrayList<>();
try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) {
try {
final FullHttpResponse markedResponse = nettyHttpClient.send(transportAddress.address(), markedRequest);
responses.add(markedResponse);
final String rootResponseContent = new String(ByteBufUtil.getBytes(markedResponse.content()), StandardCharsets.UTF_8);
assertThat(rootResponseContent, containsString("opensearch"));
assertThat(markedResponse.status().code(), equalTo(200));

assertThat(Netty4MarkedMessagePlugin.MESSAGE.get().headers().get("marked-message"), equalTo(expectedMarkedHeaderValue));
} finally {
responses.forEach(ReferenceCounted::release);
}
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport;

import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.indices.breaker.CircuitBreakerService;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.http.netty4.Netty4HttpServerTransport;
import org.opensearch.threadpool.ThreadPool;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.util.ReferenceCountUtil;

public class Netty4MarkedMessagePlugin extends Netty4Plugin {

public static final AtomicReference<HttpMessage> MESSAGE = new AtomicReference<>();

public class Netty4BlockingHttpServerTransport extends Netty4HttpServerTransport {

public Netty4BlockingHttpServerTransport(
Settings settings,
NetworkService networkService,
BigArrays bigArrays,
ThreadPool threadPool,
NamedXContentRegistry xContentRegistry,
Dispatcher dispatcher,
ClusterSettings clusterSettings,
SharedGroupFactory sharedGroupFactory
) {
super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher, clusterSettings, sharedGroupFactory);
}

@Override
protected ChannelInboundHandlerAdapter createHeaderVerifier() {
return new ExampleBlockingNetty4HeaderVerifier();
}
}

@Override
public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
Settings settings,
ThreadPool threadPool,
BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService,
NamedXContentRegistry xContentRegistry,
NetworkService networkService,
HttpServerTransport.Dispatcher dispatcher,
ClusterSettings clusterSettings
) {
return Collections.singletonMap(
NETTY_HTTP_TRANSPORT_NAME,
() -> new Netty4BlockingHttpServerTransport(
settings,
networkService,
bigArrays,
threadPool,
xContentRegistry,
dispatcher,
clusterSettings,
getSharedGroupFactory(settings)
)
);
}

/** POC for how an external header verifier would be implemented */
@Sharable
public class ExampleBlockingNetty4HeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {

@Override
public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception {
ReferenceCountUtil.retain(msg);
if (isMarked(msg)) {
MESSAGE.compareAndSet(null, msg);
}

// Lets the request pass to the next channel handler
ctx.fireChannelRead(msg);
}

private boolean isMarked(HttpRequest request) {
return request.headers().contains("marked-message");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, handlingSettings);
}

static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
public static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
static final AttributeKey<Netty4HttpServerChannel> HTTP_SERVER_CHANNEL_KEY = AttributeKey.newInstance("es-http-server-channel");

protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
Expand Down Expand Up @@ -348,7 +348,8 @@ protected void initChannel(Channel ch) throws Exception {
);
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
ch.pipeline().addLast("decoder", decoder);
ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor());
ch.pipeline().addLast("header_verifier", transport.createHeaderVerifier());
ch.pipeline().addLast("decoder_compress", transport.createDecompressor());
ch.pipeline().addLast("encoder", new HttpResponseEncoder());
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
Expand Down Expand Up @@ -390,4 +391,21 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}
}

/**
* Extension point that allows a NetworkPlugin to extend the netty pipeline and inspect headers after request decoding
*/
protected ChannelInboundHandlerAdapter createHeaderVerifier() {
// pass-through
return new ChannelInboundHandlerAdapter();
}

/**
* Extension point that allows a NetworkPlugin to override the default netty HttpContentDecompressor and supply a custom decompressor.
*
* Used in instances to conditionally decompress depending on the outcome from header verification
*/
protected ChannelInboundHandlerAdapter createDecompressor() {
return new HttpContentDecompressor();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
);
}

private SharedGroupFactory getSharedGroupFactory(Settings settings) {
SharedGroupFactory getSharedGroupFactory(Settings settings) {
SharedGroupFactory groupFactory = this.groupFactory.get();
if (groupFactory != null) {
assert groupFactory.getSettings().equals(settings) : "Different settings than originally provided";
Expand Down

0 comments on commit 2e9bfca

Please sign in to comment.