From 037bc20b212fae282e4eefc379ff6631f49db6cf Mon Sep 17 00:00:00 2001 From: Stephen Crawford <65832608+scrawfor99@users.noreply.github.com> Date: Thu, 18 Jan 2024 10:10:34 -0500 Subject: [PATCH] Improve code coverage for SSLNettyTransport class (#3953) ### Description [Describe what this change achieves] This change increases code coverage for the SecuritySSLNettyTransport class. In the middle of 12/23, a few unit tests were added to give coverage to different parts of the class. This change builds on these existing changes. ### Issues Resolved Box three of https://github.com/opensearch-project/security/issues/3137 Signed-off-by: Stephen Crawford --- .../transport/SecuritySSLNettyTransport.java | 15 +- .../SecuritySSLNettyTransportTests.java | 132 ++++++++++++++---- 2 files changed, 120 insertions(+), 27 deletions(-) diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java index 242c7c56ed..5be3424528 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java @@ -39,6 +39,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchSecurityException; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.network.NetworkService; @@ -103,6 +104,11 @@ public SecuritySSLNettyTransport( this.SSLConfig = SSLConfig; } + // This allows for testing log messages + Logger getLogger() { + return logger; + } + @Override public void onException(TcpChannel channel, Exception e) { @@ -113,8 +119,11 @@ public void onException(TcpChannel channel, Exception e) { } errorHandler.logError(cause, false); - logger.error("Exception during establishing a SSL connection: " + cause, cause); + getLogger().error("Exception during establishing a SSL connection: " + cause, cause); + if (channel == null || !channel.isOpen()) { + throw new OpenSearchSecurityException("The provided TCP channel is invalid.", e); + } super.onException(channel, e); } @@ -156,7 +165,7 @@ public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) th } errorHandler.logError(cause, false); - logger.error("Exception during establishing a SSL connection: " + cause, cause); + getLogger().error("Exception during establishing a SSL connection: " + cause, cause); super.exceptionCaught(ctx, cause); } @@ -291,7 +300,7 @@ public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) th } errorHandler.logError(cause, false); - logger.error("Exception during establishing a SSL connection: " + cause, cause); + getLogger().error("Exception during establishing a SSL connection: " + cause, cause); super.exceptionCaught(ctx, cause); } diff --git a/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java b/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java index 27705988d8..32e0f48fac 100644 --- a/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java +++ b/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java @@ -11,10 +11,14 @@ package org.opensearch.security.ssl.transport; -import org.junit.Assert; +import java.util.Collections; + +import org.apache.logging.log4j.Logger; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.opensearch.OpenSearchSecurityException; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.network.NetworkService; @@ -28,15 +32,27 @@ import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport.SSLServerChannelInitializer; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.FakeTcpChannel; import org.opensearch.transport.SharedGroupFactory; +import org.opensearch.transport.TcpChannel; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.DecoderException; import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class SecuritySSLNettyTransportTests { @@ -45,16 +61,12 @@ public class SecuritySSLNettyTransportTests { @Mock private ThreadPool threadPool; @Mock - private NetworkService networkService; - @Mock private PageCacheRecycler pageCacheRecycler; @Mock private NamedWriteableRegistry namedWriteableRegistry; @Mock private CircuitBreakerService circuitBreakerService; @Mock - private SharedGroupFactory sharedGroupFactory; - @Mock private Tracer trace; @Mock private SecurityKeyStore ossks; @@ -63,55 +75,127 @@ public class SecuritySSLNettyTransportTests { @Mock private DiscoveryNode discoveryNode; + // This initializes all the above mocks + @Rule + public MockitoRule rule = MockitoJUnit.rule(); + + private NetworkService networkService; + private SharedGroupFactory sharedGroupFactory; + private Logger mockLogger; private SSLConfig sslConfig; private SecuritySSLNettyTransport securitySSLNettyTransport; + Throwable testCause = new Throwable("Test Cause"); @Before public void setup() { - sslConfig = new SSLConfig(Settings.EMPTY); + networkService = new NetworkService(Collections.emptyList()); + sharedGroupFactory = new SharedGroupFactory(Settings.EMPTY); - securitySSLNettyTransport = new SecuritySSLNettyTransport( - Settings.EMPTY, - version, - threadPool, - networkService, - pageCacheRecycler, - namedWriteableRegistry, - circuitBreakerService, - ossks, - sslExceptionHandler, - sharedGroupFactory, - sslConfig, - trace + sslConfig = new SSLConfig(Settings.EMPTY); + mockLogger = mock(Logger.class); + + securitySSLNettyTransport = spy( + new SecuritySSLNettyTransport( + Settings.EMPTY, + version, + threadPool, + networkService, + pageCacheRecycler, + namedWriteableRegistry, + circuitBreakerService, + ossks, + sslExceptionHandler, + sharedGroupFactory, + sslConfig, + trace + ) ); } @Test public void OnException_withNullChannelShouldThrowException() { - NullPointerException exception = new NullPointerException("Test Exception"); + OpenSearchSecurityException exception = new OpenSearchSecurityException("The provided TCP channel is invalid"); + assertThrows(OpenSearchSecurityException.class, () -> securitySSLNettyTransport.onException(null, exception)); + } + + @Test + public void OnException_withClosedChannelShouldThrowException() { + + TcpChannel channel = new FakeTcpChannel(); + channel.close(); + OpenSearchSecurityException exception = new OpenSearchSecurityException("The provided TCP channel is invalid"); + assertThrows(OpenSearchSecurityException.class, () -> securitySSLNettyTransport.onException(channel, exception)); + } + + @Test + public void OnException_withNullExceptionShouldSucceed() { + + TcpChannel channel = new FakeTcpChannel(); + securitySSLNettyTransport.onException(channel, null); + verify(securitySSLNettyTransport, times(1)).onException(channel, null); + channel.close(); + } - Assert.assertThrows(NullPointerException.class, () -> securitySSLNettyTransport.onException(null, exception)); + @Test + public void OnException_withDecoderExceptionShouldGetCause() { + when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); + DecoderException exception = new DecoderException("Test Exception", testCause); + TcpChannel channel = new FakeTcpChannel(); + securitySSLNettyTransport.onException(channel, exception); + verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause()); } @Test public void getServerChannelInitializer_shouldReturnValidServerChannel() { ChannelHandler channelHandler = securitySSLNettyTransport.getServerChannelInitializer("test-server-channel"); - assertThat(channelHandler, is(notNullValue())); assertThat(channelHandler, is(instanceOf(SSLServerChannelInitializer.class))); } @Test public void getClientChannelInitializer_shouldReturnValidClientChannel() { - ChannelHandler channelHandler = securitySSLNettyTransport.getClientChannelInitializer(discoveryNode); - assertThat(channelHandler, is(notNullValue())); assertThat(channelHandler, is(instanceOf(SSLClientChannelInitializer.class))); } + @Test + public void exceptionWithServerChannelHandlerContext_nonNullDecoderExceptionShouldGetCause() throws Exception { + when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); + Throwable exception = new DecoderException("Test Exception", testCause); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + securitySSLNettyTransport.getServerChannelInitializer(discoveryNode.getName()).exceptionCaught(ctx, exception); + verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause()); + } + + @Test + public void exceptionWithServerChannelHandlerContext_nonNullCauseOnlyShouldNotGetCause() throws Exception { + when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); + Throwable exception = new OpenSearchSecurityException("Test Exception", testCause); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + securitySSLNettyTransport.getServerChannelInitializer(discoveryNode.getName()).exceptionCaught(ctx, exception); + verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception, exception); + } + + @Test + public void exceptionWithClientChannelHandlerContext_nonNullDecoderExceptionShouldGetCause() throws Exception { + when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); + Throwable exception = new DecoderException("Test Exception", testCause); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + securitySSLNettyTransport.getClientChannelInitializer(discoveryNode).exceptionCaught(ctx, exception); + verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause()); + } + + @Test + public void exceptionWithClientChannelHandlerContext_nonNullCauseOnlyShouldNotGetCause() throws Exception { + when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); + Throwable exception = new OpenSearchSecurityException("Test Exception", testCause); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + securitySSLNettyTransport.getClientChannelInitializer(discoveryNode).exceptionCaught(ctx, exception); + verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception, exception); + } }