Skip to content

Commit

Permalink
Move new AttributeKeys to security plugin
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks committed Oct 4, 2023
1 parent aec3ad3 commit 01dfa89
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@
import java.util.Map;
import java.util.function.Supplier;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
Expand Down Expand Up @@ -98,18 +102,14 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
}

/** POC for how an external header verifier would be implemented */
public class ExampleBlockingNetty4HeaderVerifier extends ChannelInboundHandlerAdapter {
public class ExampleBlockingNetty4HeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof HttpRequest)) {
ctx.fireChannelRead(msg);
return;
}

HttpRequest request = (HttpRequest) msg;
if (!isAuthenticated(request)) {
final FullHttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.UNAUTHORIZED);
public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception {
ReferenceCountUtil.retain(msg);
if (isBlocked(msg)) {
ByteBuf buf = Unpooled.copiedBuffer("Hit header_verifier".getBytes());
final FullHttpResponse response = new DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.UNAUTHORIZED, buf);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
ReferenceCountUtil.release(msg);
} else {
Expand All @@ -118,10 +118,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
}

private boolean isAuthenticated(HttpRequest request) {
private boolean isBlocked(HttpRequest request) {
final boolean shouldBlock = request.headers().contains("blockme");

return !shouldBlock;
return shouldBlock;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@
import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope;
import org.opensearch.test.OpenSearchIntegTestCase.Scope;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import io.netty.buffer.ByteBufUtil;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http2.HttpConversionUtil;
import io.netty.util.ReferenceCounted;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;

Expand All @@ -43,7 +47,6 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(Netty4BlockingPlugin.class);
}

@AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/10260")
public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exception {
HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class);
TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses();
Expand All @@ -52,12 +55,15 @@ public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exc
final FullHttpRequest blockedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
blockedRequest.headers().add("blockme", "Not Allowed");
blockedRequest.headers().add(HOST, "localhost");
blockedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http");

final List<FullHttpResponse> responses = new ArrayList<>();
try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) {
try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2()) {
try {
FullHttpResponse blockedResponse = nettyHttpClient.send(transportAddress.address(), blockedRequest);
responses.add(blockedResponse);
String blockedResponseContent = new String(ByteBufUtil.getBytes(blockedResponse.content()), StandardCharsets.UTF_8);
assertThat(blockedResponseContent, containsString("Hit header_verifier"));
assertThat(blockedResponse.status().code(), equalTo(401));
} finally {
responses.forEach(ReferenceCounted::release);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
package org.opensearch.http.netty4;

import org.opensearch.ExceptionsHelper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.http.HttpPipelinedRequest;
import org.opensearch.rest.RestHandlerContext;
import org.opensearch.rest.RestResponse;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -54,14 +51,9 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelined
@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest httpRequest) {
final Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
final RestResponse earlyResponse = ctx.channel().attr(Netty4HttpServerTransport.EARLY_RESPONSE).get();
final ThreadContext.StoredContext contextToRestore = ctx.channel().attr(Netty4HttpServerTransport.CONTEXT_TO_RESTORE).get();
ctx.channel().attr(Netty4HttpServerTransport.CONTEXT_TO_RESTORE).set(null);
ctx.channel().attr(Netty4HttpServerTransport.EARLY_RESPONSE).set(null);
final RestHandlerContext requestContext = new RestHandlerContext(earlyResponse, contextToRestore);
boolean success = false;
try {
serverTransport.incomingRequest(httpRequest, channel, requestContext);
serverTransport.incomingRequest(httpRequest, channel);
success = true;
} finally {
if (success == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.common.util.net.NetUtils;
import org.opensearch.core.common.unit.ByteSizeUnit;
Expand All @@ -53,7 +52,6 @@
import org.opensearch.http.HttpHandlingSettings;
import org.opensearch.http.HttpReadTimeoutException;
import org.opensearch.http.HttpServerChannel;
import org.opensearch.rest.RestResponse;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.NettyAllocator;
Expand All @@ -80,6 +78,7 @@
import io.netty.channel.socket.nio.NioChannelOption;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
Expand Down Expand Up @@ -340,12 +339,6 @@ public ChannelHandler configureServerChannelHandler() {
"opensearch-http-server-channel"
);

public static final AttributeKey<RestResponse> EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response");
public static final AttributeKey<ThreadContext.StoredContext> CONTEXT_TO_RESTORE = AttributeKey.newInstance(
"opensearch-http-request-thread-context"
);
public static final AttributeKey<Boolean> SHOULD_DECOMPRESS = AttributeKey.newInstance("opensearch-http-should-decompress");

protected static class HttpChannelHandler extends ChannelInitializer<Channel> {

private final Netty4HttpServerTransport transport;
Expand Down Expand Up @@ -427,7 +420,7 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws E
final ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), "handler", getRequestHandler());
pipeline.replace(this, "header_verifier", transport.createHeaderVerifier());
pipeline.addAfter("header_verifier", "decoder_compress", new Netty4ConditionalDecompressor());
pipeline.addAfter("header_verifier", "decoder_compress", transport.createDecompressor());
pipeline.addAfter("decoder_compress", "aggregator", aggregator);
if (handlingSettings.isCompression()) {
pipeline.addAfter(
Expand All @@ -454,7 +447,7 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) {
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
pipeline.addLast("decoder", decoder);
pipeline.addLast("header_verifier", transport.createHeaderVerifier());
pipeline.addLast("decoder_compress", new Netty4ConditionalDecompressor());
pipeline.addLast("decoder_compress", transport.createDecompressor());
pipeline.addLast("encoder", new HttpResponseEncoder());
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
Expand Down Expand Up @@ -501,7 +494,7 @@ protected void initChannel(Channel childChannel) throws Exception {
.addLast("byte_buf_sizer", byteBufSizer)
.addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS))
.addLast("header_verifier", transport.createHeaderVerifier())
.addLast("decoder_decompress", new Netty4ConditionalDecompressor());
.addLast("decoder_decompress", transport.createDecompressor());

if (handlingSettings.isCompression()) {
childChannel.pipeline()
Expand Down Expand Up @@ -544,4 +537,8 @@ protected ChannelInboundHandlerAdapter createHeaderVerifier() {
// pass-through
return new ChannelInboundHandlerAdapter();
}

protected ChannelInboundHandlerAdapter createDecompressor() {
return new HttpContentDecompressor();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.opensearch.nio.SocketChannelContext;
import org.opensearch.nio.TaskScheduler;
import org.opensearch.nio.WriteOperation;
import org.opensearch.rest.RestHandlerContext;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -173,7 +172,7 @@ private void handleRequest(Object msg) {
final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg;
boolean success = false;
try {
transport.incomingRequest(pipelinedRequest, nioHttpChannel, RestHandlerContext.EMPTY);
transport.incomingRequest(pipelinedRequest, nioHttpChannel);
success = true;
} finally {
if (success == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import org.opensearch.nio.InboundChannelBuffer;
import org.opensearch.nio.SocketChannelContext;
import org.opensearch.nio.TaskScheduler;
import org.opensearch.rest.RestHandlerContext;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.Before;
Expand Down Expand Up @@ -102,7 +101,7 @@ public void setMocks() {
doAnswer(invocation -> {
((HttpRequest) invocation.getArguments()[0]).releaseAndCopy();
return null;
}).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class), any(RestHandlerContext.class));
}).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class));
Settings settings = Settings.builder().put(SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(1024)).build();
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
channel = mock(NioHttpChannel.class);
Expand All @@ -123,12 +122,12 @@ public void testSuccessfulDecodeHttpRequest() throws IOException {
try {
handler.consumeReads(toChannelBuffer(slicedBuf));

verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class), any(RestHandlerContext.class));
verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));

handler.consumeReads(toChannelBuffer(slicedBuf2));

ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class));
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));

HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
Expand All @@ -154,7 +153,7 @@ public void testDecodeHttpRequestError() throws IOException {
handler.consumeReads(toChannelBuffer(buf));

ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class));
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));

assertNotNull(requestCaptor.getValue().getInboundException());
assertTrue(requestCaptor.getValue().getInboundException() instanceof IllegalArgumentException);
Expand All @@ -175,7 +174,7 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t
} finally {
buf.release();
}
verify(transport, times(0)).incomingRequest(any(), any(), any(RestHandlerContext.class));
verify(transport, times(0)).incomingRequest(any(), any());

List<FlushOperation> flushOperations = handler.pollFlushOperations();
assertFalse(flushOperations.isEmpty());
Expand Down Expand Up @@ -281,7 +280,7 @@ private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOEx
}

ArgumentCaptor<HttpPipelinedRequest> requestCaptor = ArgumentCaptor.forClass(HttpPipelinedRequest.class);
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class), any(RestHandlerContext.class));
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));

HttpRequest httpRequest = requestCaptor.getValue();
assertNotNull(httpRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandlerContext;
import org.opensearch.rest.RestRequest;
import org.opensearch.telemetry.tracing.Span;
import org.opensearch.telemetry.tracing.SpanBuilder;
Expand Down Expand Up @@ -360,29 +359,20 @@ protected void serverAcceptedChannel(HttpChannel httpChannel) {
*
* @param httpRequest that is incoming
* @param httpChannel that received the http request
* @param requestContext context carried over to the request handler from earlier stages in the request pipeline
*/
public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final RestHandlerContext requestContext) {
public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) {
final Span span = tracer.startSpan(SpanBuilder.from(httpRequest), httpRequest.getHeaders());
try (final SpanScope httpRequestSpanScope = tracer.withSpanInScope(span)) {
HttpChannel traceableHttpChannel = TraceableHttpChannel.create(httpChannel, span, tracer);
handleIncomingRequest(httpRequest, traceableHttpChannel, requestContext, httpRequest.getInboundException());
handleIncomingRequest(httpRequest, traceableHttpChannel, httpRequest.getInboundException());
}
}

// Visible for testing
protected void dispatchRequest(
final RestRequest restRequest,
final RestChannel channel,
final Throwable badRequestCause,
final ThreadContext.StoredContext storedContext
) {
void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) {
RestChannel traceableRestChannel = channel;
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
if (storedContext != null) {
storedContext.restore();
}
final Span span = tracer.startSpan(SpanBuilder.from(restRequest));
try (final SpanScope spanScope = tracer.withSpanInScope(span)) {
if (channel != null) {
Expand All @@ -398,12 +388,7 @@ protected void dispatchRequest(

}

private void handleIncomingRequest(
final HttpRequest httpRequest,
final HttpChannel httpChannel,
final RestHandlerContext requestContext,
final Exception exception
) {
private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
if (exception == null) {
HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest);
if (earlyResponse != null) {
Expand Down Expand Up @@ -477,12 +462,7 @@ private void handleIncomingRequest(
channel = innerChannel;
}

if (requestContext.hasEarlyResponse()) {
channel.sendResponse(requestContext.getEarlyResponse());
return;
}

dispatchRequest(restRequest, channel, badRequestCause, requestContext.getContextToRestore());
dispatchRequest(restRequest, channel, badRequestCause);
}

public static RestRequest createRestRequest(
Expand Down
Loading

0 comments on commit 01dfa89

Please sign in to comment.