diff --git a/src/main/java/org/opensearch/security/filter/NettyAttribute.java b/src/main/java/org/opensearch/security/filter/NettyAttribute.java index 46bf0296cb..3a035a390b 100644 --- a/src/main/java/org/opensearch/security/filter/NettyAttribute.java +++ b/src/main/java/org/opensearch/security/filter/NettyAttribute.java @@ -12,7 +12,7 @@ import java.util.Optional; -import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.http.HttpChannel; import org.opensearch.rest.RestRequest; import io.netty.channel.Channel; @@ -25,11 +25,12 @@ public class NettyAttribute { * Gets an attribute value from the request context and clears it from that context */ public static Optional popFrom(final RestRequest request, final AttributeKey attribute) { - if (request.getHttpChannel() instanceof Netty4HttpChannel) { - Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel(); - return Optional.ofNullable(nettyChannel.attr(attribute).getAndSet(null)); + final HttpChannel httpChannel = request.getHttpChannel(); + if (httpChannel != null) { + return httpChannel.get("channel", Channel.class).map(channel -> channel.attr(attribute).getAndSet(null)); + } else { + return Optional.empty(); } - return Optional.empty(); } /** @@ -50,9 +51,9 @@ public static Optional peekFrom(final ChannelHandlerContext ctx, final At * Clears an attribute value from the channel handler context */ public static void clearAttribute(final RestRequest request, final AttributeKey attribute) { - if (request.getHttpChannel() instanceof Netty4HttpChannel) { - Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel(); - nettyChannel.attr(attribute).set(null); + final HttpChannel httpChannel = request.getHttpChannel(); + if (httpChannel != null) { + httpChannel.get("channel", Channel.class).ifPresent(channel -> channel.attr(attribute).set(null)); } } diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java index 80ede8b2c1..e86012f594 100644 --- a/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java @@ -17,7 +17,6 @@ import java.util.Optional; import javax.net.ssl.SSLEngine; -import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; @@ -41,21 +40,11 @@ public Map> getHeaders() { @Override public SSLEngine getSSLEngine() { - if (underlyingRequest == null - || underlyingRequest.getHttpChannel() == null - || !(underlyingRequest.getHttpChannel() instanceof Netty4HttpChannel)) { + if (underlyingRequest == null || underlyingRequest.getHttpChannel() == null) { return null; } - // We look for Ssl_handler called `ssl_http` in the outbound pipeline of Netty channel first, and if its not - // present we look for it in inbound channel. If its present in neither we return null, else we return the sslHandler. - final Netty4HttpChannel httpChannel = (Netty4HttpChannel) underlyingRequest.getHttpChannel(); - SslHandler sslhandler = (SslHandler) httpChannel.getNettyChannel().pipeline().get("ssl_http"); - if (sslhandler == null && httpChannel.inboundPipeline() != null) { - sslhandler = (SslHandler) httpChannel.inboundPipeline().get("ssl_http"); - } - - return sslhandler != null ? sslhandler.engine() : null; + return underlyingRequest.getHttpChannel().get("ssl_http", SslHandler.class).map(SslHandler::engine).orElse(null); } @Override diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 39312e29ad..078c822357 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -36,13 +36,9 @@ import org.opensearch.security.support.ConfigConstants; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TaskTransportChannel; -import org.opensearch.transport.TcpChannel; -import org.opensearch.transport.TcpTransportChannel; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestHandler; -import org.opensearch.transport.netty4.Netty4TcpChannel; import io.netty.handler.ssl.SslHandler; @@ -111,21 +107,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task } try { - - Netty4TcpChannel nettyChannel = null; - - if (channel instanceof TaskTransportChannel) { - final TransportChannel inner = ((TaskTransportChannel) channel).getChannel(); - nettyChannel = (Netty4TcpChannel) ((TcpTransportChannel) inner).getChannel(); - } else if (channel instanceof TcpTransportChannel) { - final TcpChannel inner = ((TcpTransportChannel) channel).getChannel(); - nettyChannel = (Netty4TcpChannel) inner; - } else { - throw new Exception("Invalid channel of type " + channel.getClass() + " (" + channel.getChannelType() + ")"); - } - - final SslHandler sslhandler = (SslHandler) nettyChannel.getNettyChannel().pipeline().get("ssl_server"); - + final SslHandler sslhandler = channel.get("ssl_server", SslHandler.class).orElse(null); if (sslhandler == null) { if (SSLConfig.isDualModeEnabled()) { log.info("Communication in dual mode. Skipping SSL handler check");