Skip to content

Commit

Permalink
Replacing InboundMessage with NativeInboundMessage for deprecation
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Apr 8, 2024
1 parent d202d90 commit 38e668c
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -113,7 +114,7 @@ public void aggregate(ReleasableBytesReference content) {
}
}

public InboundMessage finishAggregation() throws IOException {
public NativeInboundMessage finishAggregation() throws IOException {
ensureOpen();
final ReleasableBytesReference releasableContent;
if (isFirstContent()) {
Expand All @@ -127,7 +128,7 @@ public InboundMessage finishAggregation() throws IOException {
}

final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl);
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
boolean success = false;
try {
if (aggregated.getHeader().needsToReadVariableHeader()) {
Expand All @@ -142,7 +143,7 @@ public InboundMessage finishAggregation() throws IOException {
if (isShortCircuited()) {
aggregated.close();
success = true;
return new InboundMessage(aggregated.getHeader(), aggregationException);
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
} else {
success = true;
return aggregated;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.EOFException;
import java.io.IOException;
Expand Down Expand Up @@ -111,7 +112,7 @@ public void messageReceived(
long slowLogThresholdMs,
TransportMessageListener messageListener
) throws IOException {
InboundMessage inboundMessage = (InboundMessage) message;
NativeInboundMessage inboundMessage = (NativeInboundMessage) message;
TransportLogger.logInboundMessage(channel, inboundMessage);
if (inboundMessage.isPing()) {
keepAlive.receiveKeepAlive(channel);
Expand All @@ -122,7 +123,7 @@ public void messageReceived(

private void handleMessage(
TcpChannel channel,
InboundMessage message,
NativeInboundMessage message,
long startTime,
long slowLogThresholdMs,
TransportMessageListener messageListener
Expand Down Expand Up @@ -194,7 +195,7 @@ private Map<String, Collection<String>> extractHeaders(Map<String, String> heade
private <T extends TransportRequest> void handleRequest(
TcpChannel channel,
Header header,
InboundMessage message,
NativeInboundMessage message,
TransportMessageListener messageListener
) throws IOException {
final String action = header.getActionName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.compress.CompressorRegistry;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;

Expand All @@ -64,7 +65,7 @@ static void logInboundMessage(TcpChannel channel, BytesReference message) {
}
}

static void logInboundMessage(TcpChannel channel, InboundMessage message) {
static void logInboundMessage(TcpChannel channel, NativeInboundMessage message) {
if (logger.isTraceEnabled()) {
try {
String logMessage = format(channel, message, "READ");
Expand Down Expand Up @@ -136,7 +137,7 @@ private static String format(TcpChannel channel, BytesReference message, String
return sb.toString();
}

private static String format(TcpChannel channel, InboundMessage message, String event) throws IOException {
private static String format(TcpChannel channel, NativeInboundMessage message, String event) throws IOException {
final StringBuilder sb = new StringBuilder();
sb.append(channel);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.transport.InboundAggregator;
import org.opensearch.transport.InboundBytesHandler;
import org.opensearch.transport.InboundDecoder;
import org.opensearch.transport.InboundMessage;
import org.opensearch.transport.ProtocolInboundMessage;
import org.opensearch.transport.StatsTracker;
import org.opensearch.transport.TcpChannel;
Expand All @@ -32,7 +31,7 @@
public class NativeInboundBytesHandler implements InboundBytesHandler {

private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);
private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true);
private static final NativeInboundMessage PING_MESSAGE = new NativeInboundMessage(null, true);

private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
Expand Down Expand Up @@ -152,7 +151,7 @@ private void forwardFragments(
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (InboundMessage aggregated = aggregator.finishAggregation()) {
try (NativeInboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.junit.Before;

import java.io.IOException;
Expand Down Expand Up @@ -107,7 +108,7 @@ public void testInboundAggregation() throws IOException {
}

// Signal EOS
InboundMessage aggregated = aggregator.finishAggregation();
NativeInboundMessage aggregated = aggregator.finishAggregation();

assertThat(aggregated, notNullValue());
assertFalse(aggregated.isPing());
Expand Down Expand Up @@ -138,7 +139,7 @@ public void testInboundUnknownAction() throws IOException {
assertEquals(0, content.refCount());

// Signal EOS
InboundMessage aggregated = aggregator.finishAggregation();
NativeInboundMessage aggregated = aggregator.finishAggregation();

assertThat(aggregated, notNullValue());
assertTrue(aggregated.isShortCircuit());
Expand All @@ -161,7 +162,7 @@ public void testCircuitBreak() throws IOException {
content1.close();

// Signal EOS
InboundMessage aggregated1 = aggregator.finishAggregation();
NativeInboundMessage aggregated1 = aggregator.finishAggregation();

assertEquals(0, content1.refCount());
assertThat(aggregated1, notNullValue());
Expand All @@ -180,7 +181,7 @@ public void testCircuitBreak() throws IOException {
content2.close();

// Signal EOS
InboundMessage aggregated2 = aggregator.finishAggregation();
NativeInboundMessage aggregated2 = aggregator.finishAggregation();

assertEquals(1, content2.refCount());
assertThat(aggregated2, notNullValue());
Expand All @@ -199,7 +200,7 @@ public void testCircuitBreak() throws IOException {
content3.close();

// Signal EOS
InboundMessage aggregated3 = aggregator.finishAggregation();
NativeInboundMessage aggregated3 = aggregator.finishAggregation();

assertEquals(1, content3.refCount());
assertThat(aggregated3, notNullValue());
Expand Down Expand Up @@ -263,7 +264,7 @@ public void testFinishAggregationWillFinishHeader() throws IOException {
content.close();

// Signal EOS
InboundMessage aggregated = aggregator.finishAggregation();
NativeInboundMessage aggregated = aggregator.finishAggregation();

assertThat(aggregated, notNullValue());
assertFalse(header.needsToReadVariableHeader());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.test.VersionUtils;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -142,7 +143,7 @@ public void testPing() throws Exception {
);
requestHandlers.registerHandler(registry);

handler.inboundMessage(channel, new InboundMessage(null, true));
handler.inboundMessage(channel, new NativeInboundMessage(null, true));
if (channel.isServerChannel()) {
BytesReference ping = channel.getMessageCaptor().get();
assertEquals('E', ping.get(0));
Expand Down Expand Up @@ -208,7 +209,11 @@ public TestResponse read(StreamInput in) throws IOException {
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
NativeInboundMessage requestMessage = new NativeInboundMessage(
requestHeader,
ReleasableBytesReference.wrap(requestContent),
() -> {}
);
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);

Expand All @@ -229,7 +234,11 @@ public TestResponse read(StreamInput in) throws IOException {
BytesReference fullResponseBytes = channel.getMessageCaptor().get();
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize);
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
NativeInboundMessage responseMessage = new NativeInboundMessage(
responseHeader,
ReleasableBytesReference.wrap(responseContent),
() -> {}
);
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
handler.inboundMessage(channel, responseMessage);

Expand All @@ -256,7 +265,7 @@ public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exce
TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)),
remoteVersion
);
final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader);
final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader);
requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME;
requestHeader.headers = Tuple.tuple(Map.of(), Map.of());
requestHeader.features = Set.of();
Expand Down Expand Up @@ -296,7 +305,7 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws
TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)),
remoteVersion
);
final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader);
final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader);
requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME;
requestHeader.headers = Tuple.tuple(Map.of(), Map.of());
requestHeader.features = Set.of();
Expand Down Expand Up @@ -327,13 +336,17 @@ public void testLogsSlowInboundProcessing() throws Exception {
TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)),
remoteVersion
);
final InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {
try {
TimeUnit.SECONDS.sleep(1L);
} catch (InterruptedException e) {
throw new AssertionError(e);
final NativeInboundMessage requestMessage = new NativeInboundMessage(
requestHeader,
ReleasableBytesReference.wrap(BytesArray.EMPTY),
() -> {
try {
TimeUnit.SECONDS.sleep(1L);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
});
);
requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME;
requestHeader.headers = Tuple.tuple(Collections.emptyMap(), Collections.emptyMap());
requestHeader.features = Set.of();
Expand Down Expand Up @@ -407,7 +420,11 @@ public void onResponseSent(long requestId, String action, Exception error) {
BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
NativeInboundMessage requestMessage = new NativeInboundMessage(
requestHeader,
ReleasableBytesReference.wrap(requestContent),
() -> {}
);
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);

Expand Down Expand Up @@ -474,7 +491,11 @@ public void onResponseSent(long requestId, String action, Exception error) {
// Create the request payload by intentionally stripping 1 byte away
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
NativeInboundMessage requestMessage = new NativeInboundMessage(
requestHeader,
ReleasableBytesReference.wrap(requestContent),
() -> {}
);
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);

Expand Down Expand Up @@ -540,7 +561,11 @@ public TestResponse read(StreamInput in) throws IOException {
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
NativeInboundMessage requestMessage = new NativeInboundMessage(
requestHeader,
ReleasableBytesReference.wrap(requestContent),
() -> {}
);
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);

Expand All @@ -562,7 +587,11 @@ public TestResponse read(StreamInput in) throws IOException {
BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip());
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize);
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
NativeInboundMessage responseMessage = new NativeInboundMessage(
responseHeader,
ReleasableBytesReference.wrap(responseContent),
() -> {}
);
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
handler.inboundMessage(channel, responseMessage);

Expand Down Expand Up @@ -628,7 +657,11 @@ public TestResponse read(StreamInput in) throws IOException {
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
NativeInboundMessage requestMessage = new NativeInboundMessage(
requestHeader,
ReleasableBytesReference.wrap(requestContent),
() -> {}
);
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);

Expand All @@ -645,7 +678,11 @@ public TestResponse read(StreamInput in) throws IOException {
// Create the response payload by intentionally stripping 1 byte away
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1);
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
NativeInboundMessage responseMessage = new NativeInboundMessage(
responseHeader,
ReleasableBytesReference.wrap(responseContent),
() -> {}
);
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
handler.inboundMessage(channel, responseMessage);

Expand All @@ -654,8 +691,8 @@ public TestResponse read(StreamInput in) throws IOException {
assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler"));
}

private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) {
return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) {
private static NativeInboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) {
return new NativeInboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) {
@Override
public StreamInput openOrGetStreamInput() {
final StreamInput streamInput = new InputStreamStreamInput(new InputStream() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -74,7 +75,7 @@ public void testPipelineHandling() throws IOException {
final List<ReleasableBytesReference> toRelease = new ArrayList<>();
final BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler = (c, m) -> {
try {
InboundMessage message = (InboundMessage) m;
NativeInboundMessage message = (NativeInboundMessage) m;
final Header header = message.getHeader();
final MessageData actualData;
final Version version = header.getVersion();
Expand Down
Loading

0 comments on commit 38e668c

Please sign in to comment.