diff --git a/server/src/main/java/org/opensearch/transport/InboundAggregator.java b/server/src/main/java/org/opensearch/transport/InboundAggregator.java index e894331f3b64e..f52875d880b4f 100644 --- a/server/src/main/java/org/opensearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/opensearch/transport/InboundAggregator.java @@ -40,6 +40,7 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.bytes.CompositeBytesReference; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.ArrayList; @@ -113,7 +114,7 @@ public void aggregate(ReleasableBytesReference content) { } } - public InboundMessage finishAggregation() throws IOException { + public NativeInboundMessage finishAggregation() throws IOException { ensureOpen(); final ReleasableBytesReference releasableContent; if (isFirstContent()) { @@ -127,7 +128,7 @@ public InboundMessage finishAggregation() throws IOException { } final BreakerControl breakerControl = new BreakerControl(circuitBreaker); - final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl); + final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { @@ -142,7 +143,7 @@ public InboundMessage finishAggregation() throws IOException { if (isShortCircuited()) { aggregated.close(); success = true; - return new InboundMessage(aggregated.getHeader(), aggregationException); + return new NativeInboundMessage(aggregated.getHeader(), aggregationException); } else { success = true; return aggregated; diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java index 861b95a8098f2..c5b65f9eb7a11 100644 --- a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -51,6 +51,7 @@ import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.EOFException; import java.io.IOException; @@ -111,7 +112,7 @@ public void messageReceived( long slowLogThresholdMs, TransportMessageListener messageListener ) throws IOException { - InboundMessage inboundMessage = (InboundMessage) message; + NativeInboundMessage inboundMessage = (NativeInboundMessage) message; TransportLogger.logInboundMessage(channel, inboundMessage); if (inboundMessage.isPing()) { keepAlive.receiveKeepAlive(channel); @@ -122,7 +123,7 @@ public void messageReceived( private void handleMessage( TcpChannel channel, - InboundMessage message, + NativeInboundMessage message, long startTime, long slowLogThresholdMs, TransportMessageListener messageListener @@ -194,7 +195,7 @@ private Map> extractHeaders(Map heade private void handleRequest( TcpChannel channel, Header header, - InboundMessage message, + NativeInboundMessage message, TransportMessageListener messageListener ) throws IOException { final String action = header.getActionName(); diff --git a/server/src/main/java/org/opensearch/transport/TransportLogger.java b/server/src/main/java/org/opensearch/transport/TransportLogger.java index 997b3bb5ba18e..e780f643aafd7 100644 --- a/server/src/main/java/org/opensearch/transport/TransportLogger.java +++ b/server/src/main/java/org/opensearch/transport/TransportLogger.java @@ -40,6 +40,7 @@ import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.compress.CompressorRegistry; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; @@ -64,7 +65,7 @@ static void logInboundMessage(TcpChannel channel, BytesReference message) { } } - static void logInboundMessage(TcpChannel channel, InboundMessage message) { + static void logInboundMessage(TcpChannel channel, NativeInboundMessage message) { if (logger.isTraceEnabled()) { try { String logMessage = format(channel, message, "READ"); @@ -136,7 +137,7 @@ private static String format(TcpChannel channel, BytesReference message, String return sb.toString(); } - private static String format(TcpChannel channel, InboundMessage message, String event) throws IOException { + private static String format(TcpChannel channel, NativeInboundMessage message, String event) throws IOException { final StringBuilder sb = new StringBuilder(); sb.append(channel); diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java index a8a4c0da7ec0f..97981aeb6736e 100644 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java @@ -16,7 +16,6 @@ import org.opensearch.transport.InboundAggregator; import org.opensearch.transport.InboundBytesHandler; import org.opensearch.transport.InboundDecoder; -import org.opensearch.transport.InboundMessage; import org.opensearch.transport.ProtocolInboundMessage; import org.opensearch.transport.StatsTracker; import org.opensearch.transport.TcpChannel; @@ -32,7 +31,7 @@ public class NativeInboundBytesHandler implements InboundBytesHandler { private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); + private static final NativeInboundMessage PING_MESSAGE = new NativeInboundMessage(null, true); private final ArrayDeque pending; private final InboundDecoder decoder; @@ -152,7 +151,7 @@ private void forwardFragments( messageHandler.accept(channel, PING_MESSAGE); } else if (fragment == InboundDecoder.END_CONTENT) { assert aggregator.isAggregating(); - try (InboundMessage aggregated = aggregator.finishAggregation()) { + try (NativeInboundMessage aggregated = aggregator.finishAggregation()) { statsTracker.markMessageReceived(); messageHandler.accept(channel, aggregated); } diff --git a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java index 2dd98a8efe2a3..4ac78366360d7 100644 --- a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java @@ -42,6 +42,7 @@ import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.Before; import java.io.IOException; @@ -107,7 +108,7 @@ public void testInboundAggregation() throws IOException { } // Signal EOS - InboundMessage aggregated = aggregator.finishAggregation(); + NativeInboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(aggregated.isPing()); @@ -138,7 +139,7 @@ public void testInboundUnknownAction() throws IOException { assertEquals(0, content.refCount()); // Signal EOS - InboundMessage aggregated = aggregator.finishAggregation(); + NativeInboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertTrue(aggregated.isShortCircuit()); @@ -161,7 +162,7 @@ public void testCircuitBreak() throws IOException { content1.close(); // Signal EOS - InboundMessage aggregated1 = aggregator.finishAggregation(); + NativeInboundMessage aggregated1 = aggregator.finishAggregation(); assertEquals(0, content1.refCount()); assertThat(aggregated1, notNullValue()); @@ -180,7 +181,7 @@ public void testCircuitBreak() throws IOException { content2.close(); // Signal EOS - InboundMessage aggregated2 = aggregator.finishAggregation(); + NativeInboundMessage aggregated2 = aggregator.finishAggregation(); assertEquals(1, content2.refCount()); assertThat(aggregated2, notNullValue()); @@ -199,7 +200,7 @@ public void testCircuitBreak() throws IOException { content3.close(); // Signal EOS - InboundMessage aggregated3 = aggregator.finishAggregation(); + NativeInboundMessage aggregated3 = aggregator.finishAggregation(); assertEquals(1, content3.refCount()); assertThat(aggregated3, notNullValue()); @@ -263,7 +264,7 @@ public void testFinishAggregationWillFinishHeader() throws IOException { content.close(); // Signal EOS - InboundMessage aggregated = aggregator.finishAggregation(); + NativeInboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(header.needsToReadVariableHeader()); diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 0d171e17e70e1..2dde27d62e759 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -56,6 +56,7 @@ import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.After; import org.junit.Before; @@ -142,7 +143,7 @@ public void testPing() throws Exception { ); requestHandlers.registerHandler(registry); - handler.inboundMessage(channel, new InboundMessage(null, true)); + handler.inboundMessage(channel, new NativeInboundMessage(null, true)); if (channel.isServerChannel()) { BytesReference ping = channel.getMessageCaptor().get(); assertEquals('E', ping.get(0)); @@ -208,7 +209,11 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + NativeInboundMessage requestMessage = new NativeInboundMessage( + requestHeader, + ReleasableBytesReference.wrap(requestContent), + () -> {} + ); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -229,7 +234,11 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); + NativeInboundMessage responseMessage = new NativeInboundMessage( + responseHeader, + ReleasableBytesReference.wrap(responseContent), + () -> {} + ); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -256,7 +265,7 @@ public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exce TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -296,7 +305,7 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -327,13 +336,17 @@ public void testLogsSlowInboundProcessing() throws Exception { TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> { - try { - TimeUnit.SECONDS.sleep(1L); - } catch (InterruptedException e) { - throw new AssertionError(e); + final NativeInboundMessage requestMessage = new NativeInboundMessage( + requestHeader, + ReleasableBytesReference.wrap(BytesArray.EMPTY), + () -> { + try { + TimeUnit.SECONDS.sleep(1L); + } catch (InterruptedException e) { + throw new AssertionError(e); + } } - }); + ); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Collections.emptyMap(), Collections.emptyMap()); requestHeader.features = Set.of(); @@ -407,7 +420,11 @@ public void onResponseSent(long requestId, String action, Exception error) { BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + NativeInboundMessage requestMessage = new NativeInboundMessage( + requestHeader, + ReleasableBytesReference.wrap(requestContent), + () -> {} + ); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -474,7 +491,11 @@ public void onResponseSent(long requestId, String action, Exception error) { // Create the request payload by intentionally stripping 1 byte away BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + NativeInboundMessage requestMessage = new NativeInboundMessage( + requestHeader, + ReleasableBytesReference.wrap(requestContent), + () -> {} + ); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -540,7 +561,11 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + NativeInboundMessage requestMessage = new NativeInboundMessage( + requestHeader, + ReleasableBytesReference.wrap(requestContent), + () -> {} + ); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -562,7 +587,11 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); + NativeInboundMessage responseMessage = new NativeInboundMessage( + responseHeader, + ReleasableBytesReference.wrap(responseContent), + () -> {} + ); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -628,7 +657,11 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + NativeInboundMessage requestMessage = new NativeInboundMessage( + requestHeader, + ReleasableBytesReference.wrap(requestContent), + () -> {} + ); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -645,7 +678,11 @@ public TestResponse read(StreamInput in) throws IOException { // Create the response payload by intentionally stripping 1 byte away BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1); Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); + NativeInboundMessage responseMessage = new NativeInboundMessage( + responseHeader, + ReleasableBytesReference.wrap(responseContent), + () -> {} + ); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -654,8 +691,8 @@ public TestResponse read(StreamInput in) throws IOException { assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler")); } - private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { - return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { + private static NativeInboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { + return new NativeInboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { @Override public StreamInput openOrGetStreamInput() { final StreamInput streamInput = new InputStreamStreamInput(new InputStream() { diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 2dfe8a0dd8590..d54f7e6fd2c2b 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -49,6 +49,7 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.ArrayList; @@ -74,7 +75,7 @@ public void testPipelineHandling() throws IOException { final List toRelease = new ArrayList<>(); final BiConsumer messageHandler = (c, m) -> { try { - InboundMessage message = (InboundMessage) m; + NativeInboundMessage message = (NativeInboundMessage) m; final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion(); diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index 36ba409a2de03..ad7d4401af13c 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -53,6 +53,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.After; import org.junit.Before; @@ -97,7 +98,7 @@ public void setUp() throws Exception { final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - InboundMessage m1 = (InboundMessage) m; + NativeInboundMessage m1 = (NativeInboundMessage) m; Streams.copy(m1.openOrGetStreamInput(), streamOutput); message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); } catch (IOException e) {