From c3c7a27b9e046190544d7a7beb05f3a17a947095 Mon Sep 17 00:00:00 2001 From: Andrew Ross Date: Mon, 26 Aug 2024 16:58:41 -0700 Subject: [PATCH] Change abstraction point for transport protocol The previous implementation had a transport switch point in InboundPipeline when the bytes were initially pulled off the wire. There was no implementation for any other protocol as the `canHandleBytes` method was hardcoded to return true. I believe this is the wrong point to switch on the protocol. This change makes NativeInboundBytesHandler protocol agnostic beyond the header. With this change, a complete message is parsed from the stream of bytes, with the header schema being unchanged from what exists today. The protocol switch point will now be at `InboundHandler::inboundMessage`. The header will indicate what protocol was used to serialize the the non-header bytes of the message and then invoke the appropriate handler based on that field. Signed-off-by: Andrew Ross --- .../java/org/opensearch/transport/Header.java | 10 +- .../transport/InboundBytesHandler.java | 135 ++++++++++++-- .../opensearch/transport/InboundDecoder.java | 5 +- .../opensearch/transport/InboundHandler.java | 13 +- .../opensearch/transport/InboundMessage.java | 89 +++++++--- .../opensearch/transport/InboundPipeline.java | 36 +--- .../opensearch/transport/TcpTransport.java | 14 +- .../transport/TransportProtocol.java | 29 +++ .../NativeInboundBytesHandler.java | 167 ------------------ .../nativeprotocol/NativeInboundMessage.java | 100 +---------- .../transport/InboundAggregatorTests.java | 32 +++- .../transport/InboundHandlerTests.java | 51 +++++- .../transport/InboundPipelineTests.java | 7 +- .../transport/NativeOutboundHandlerTests.java | 6 +- 14 files changed, 328 insertions(+), 366 deletions(-) create mode 100644 server/src/main/java/org/opensearch/transport/TransportProtocol.java delete mode 100644 server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java diff --git a/server/src/main/java/org/opensearch/transport/Header.java b/server/src/main/java/org/opensearch/transport/Header.java index ac30df8dda02c..fcfeb9c632075 100644 --- a/server/src/main/java/org/opensearch/transport/Header.java +++ b/server/src/main/java/org/opensearch/transport/Header.java @@ -55,6 +55,7 @@ public class Header { private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES"; + private final TransportProtocol protocol; private final int networkMessageSize; private final Version version; private final long requestId; @@ -64,13 +65,18 @@ public class Header { Tuple, Map>> headers; Set features; - Header(int networkMessageSize, long requestId, byte status, Version version) { + Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) { + this.protocol = protocol; this.networkMessageSize = networkMessageSize; this.version = version; this.requestId = requestId; this.status = status; } + TransportProtocol getTransportProtocol() { + return protocol; + } + public int getNetworkMessageSize() { return networkMessageSize; } @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException { @Override public String toString() { return "Header{" + + protocol + + "}{" + networkMessageSize + "}{" + version diff --git a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java index 276891212e43f..a8c774111ae50 100644 --- a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java @@ -9,24 +9,137 @@ package org.opensearch.transport; import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.core.common.bytes.CompositeBytesReference; -import java.io.Closeable; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.function.BiConsumer; /** - * Interface for handling inbound bytes. Can be implemented by different transport protocols. + * Handler for inbound bytes for the native protocol. */ -public interface InboundBytesHandler extends Closeable { +class InboundBytesHandler { - public void doHandleBytes( - TcpChannel channel, - ReleasableBytesReference reference, - BiConsumer messageHandler - ) throws IOException; + private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - public boolean canHandleBytes(ReleasableBytesReference reference); + private final ArrayDeque pending; + private final InboundDecoder decoder; + private final InboundAggregator aggregator; + private final StatsTracker statsTracker; + private boolean isClosed = false; + + InboundBytesHandler( + ArrayDeque pending, + InboundDecoder decoder, + InboundAggregator aggregator, + StatsTracker statsTracker + ) { + this.pending = pending; + this.decoder = decoder; + this.aggregator = aggregator; + this.statsTracker = statsTracker; + } + + public void close() { + isClosed = true; + } + + public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference, BiConsumer messageHandler) + throws IOException { + final ArrayList fragments = fragmentList.get(); + boolean continueHandling = true; + + while (continueHandling && isClosed == false) { + boolean continueDecoding = true; + while (continueDecoding && pending.isEmpty() == false) { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + final int bytesDecoded = decoder.decode(toDecode, fragments::add); + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); + if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + continueDecoding = false; + } + } else { + continueDecoding = false; + } + } + } + + if (fragments.isEmpty()) { + continueHandling = false; + } else { + try { + forwardFragments(channel, fragments, messageHandler); + } finally { + for (Object fragment : fragments) { + if (fragment instanceof ReleasableBytesReference) { + ((ReleasableBytesReference) fragment).close(); + } + } + fragments.clear(); + } + } + } + } + + private ReleasableBytesReference getPendingBytes() { + if (pending.size() == 1) { + return pending.peekFirst().retain(); + } else { + final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; + int index = 0; + for (ReleasableBytesReference pendingReference : pending) { + bytesReferences[index] = pendingReference.retain(); + ++index; + } + final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); + return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); + } + } + + private void releasePendingBytes(int bytesConsumed) { + int bytesToRelease = bytesConsumed; + while (bytesToRelease != 0) { + try (ReleasableBytesReference reference = pending.pollFirst()) { + assert reference != null; + if (bytesToRelease < reference.length()) { + pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); + bytesToRelease -= bytesToRelease; + } else { + bytesToRelease -= reference.length(); + } + } + } + } + + private boolean endOfMessage(Object fragment) { + return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; + } + + private void forwardFragments(TcpChannel channel, ArrayList fragments, BiConsumer messageHandler) + throws IOException { + for (Object fragment : fragments) { + if (fragment instanceof Header) { + assert aggregator.isAggregating() == false; + aggregator.headerReceived((Header) fragment); + } else if (fragment == InboundDecoder.PING) { + assert aggregator.isAggregating() == false; + messageHandler.accept(channel, InboundMessage.PING); + } else if (fragment == InboundDecoder.END_CONTENT) { + assert aggregator.isAggregating(); + try (InboundMessage aggregated = aggregator.finishAggregation()) { + statsTracker.markMessageReceived(); + messageHandler.accept(channel, aggregated); + } + } else { + assert aggregator.isAggregating(); + assert fragment instanceof ReleasableBytesReference; + aggregator.aggregate((ReleasableBytesReference) fragment); + } + } + } - @Override - void close(); } diff --git a/server/src/main/java/org/opensearch/transport/InboundDecoder.java b/server/src/main/java/org/opensearch/transport/InboundDecoder.java index d6b7a98e876b3..3e735d4be2420 100644 --- a/server/src/main/java/org/opensearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/opensearch/transport/InboundDecoder.java @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) { // exposed for use in tests static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException { try (StreamInput streamInput = bytesReference.streamInput()) { - streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE); + TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte()); + streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE); long requestId = streamInput.readLong(); byte status = streamInput.readByte(); Version remoteVersion = Version.fromId(streamInput.readInt()); - Header header = new Header(networkMessageSize, requestId, status, remoteVersion); + Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion); final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake()); if (invalidVersion != null) { throw invalidVersion; diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index f77c44ea362cf..76a44832b08dc 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -38,7 +38,6 @@ import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.Map; @@ -56,7 +55,7 @@ public class InboundHandler { private volatile long slowLogThresholdMs = Long.MAX_VALUE; - private final Map protocolMessageHandlers; + private final Map protocolMessageHandlers; InboundHandler( String nodeName, @@ -75,7 +74,7 @@ public class InboundHandler { ) { this.threadPool = threadPool; this.protocolMessageHandlers = Map.of( - NativeInboundMessage.NATIVE_PROTOCOL, + TransportProtocol.NATIVE, new NativeMessageHandler( nodeName, version, @@ -107,16 +106,16 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) { this.slowLogThresholdMs = slowLogThreshold.getMillis(); } - void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception { + void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception { final long startTime = threadPool.relativeTimeInMillis(); channel.getChannelStats().markAccessed(startTime); messageReceivedFromPipeline(channel, message, startTime); } - private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException { - ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol()); + private void messageReceivedFromPipeline(TcpChannel channel, InboundMessage message, long startTime) throws IOException { + ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol()); if (protocolMessageHandler == null) { - throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol()); + throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol()); } protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener); } diff --git a/server/src/main/java/org/opensearch/transport/InboundMessage.java b/server/src/main/java/org/opensearch/transport/InboundMessage.java index 5c68257557061..576ab73ce9c98 100644 --- a/server/src/main/java/org/opensearch/transport/InboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/InboundMessage.java @@ -32,77 +32,118 @@ package org.opensearch.transport; -import org.opensearch.common.annotation.DeprecatedApi; +import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.io.IOUtils; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; /** * Inbound data as a message - * This api is deprecated, please use {@link org.opensearch.transport.nativeprotocol.NativeInboundMessage} instead. - * @opensearch.api */ -@DeprecatedApi(since = "2.14.0") +@PublicApi(since = "1.0.0") public class InboundMessage implements Releasable, ProtocolInboundMessage { - private final NativeInboundMessage nativeInboundMessage; + static final InboundMessage PING = new InboundMessage(null, null, null, true, null); + + protected final Header header; + protected final ReleasableBytesReference content; + protected final Exception exception; + protected final boolean isPing; + private Releasable breakerRelease; + private StreamInput streamInput; public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { - this.nativeInboundMessage = new NativeInboundMessage(header, content, breakerRelease); + this(header, content, null, false, breakerRelease); } public InboundMessage(Header header, Exception exception) { - this.nativeInboundMessage = new NativeInboundMessage(header, exception); + this(header, null, exception, false, null); } public InboundMessage(Header header, boolean isPing) { - this.nativeInboundMessage = new NativeInboundMessage(header, isPing); + this(header, null, null, isPing, null); + } + + private InboundMessage( + Header header, + ReleasableBytesReference content, + Exception exception, + boolean isPing, + Releasable breakerRelease + ) { + this.header = header; + this.content = content; + this.exception = exception; + this.isPing = isPing; + this.breakerRelease = breakerRelease; + } + + TransportProtocol getTransportProtocol() { + if (isPing) { + return TransportProtocol.NATIVE; + } + return header.getTransportProtocol(); + } + + public String getProtocol() { + return header.getTransportProtocol().toString(); } public Header getHeader() { - return this.nativeInboundMessage.getHeader(); + return header; } public int getContentLength() { - return this.nativeInboundMessage.getContentLength(); + if (content == null) { + return 0; + } else { + return content.length(); + } } public Exception getException() { - return this.nativeInboundMessage.getException(); + return exception; } public boolean isPing() { - return this.nativeInboundMessage.isPing(); + return isPing; } public boolean isShortCircuit() { - return this.nativeInboundMessage.getException() != null; + return exception != null; } public Releasable takeBreakerReleaseControl() { - return this.nativeInboundMessage.takeBreakerReleaseControl(); + final Releasable toReturn = breakerRelease; + breakerRelease = null; + if (toReturn != null) { + return toReturn; + } else { + return () -> {}; + } } public StreamInput openOrGetStreamInput() throws IOException { - return this.nativeInboundMessage.openOrGetStreamInput(); + assert isPing == false && content != null; + if (streamInput == null) { + streamInput = content.streamInput(); + streamInput.setVersion(header.getVersion()); + } + return streamInput; } @Override public void close() { - this.nativeInboundMessage.close(); + IOUtils.closeWhileHandlingException(streamInput); + Releasables.closeWhileHandlingException(content, breakerRelease); } @Override public String toString() { - return this.nativeInboundMessage.toString(); - } - - @Override - public String getProtocol() { - return this.nativeInboundMessage.getProtocol(); + return "InboundMessage{" + header + "}"; } - } diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index 5cee3bb975223..3acb43f58b443 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -38,11 +38,9 @@ import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; -import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler; import java.io.IOException; import java.util.ArrayDeque; -import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.LongSupplier; @@ -62,9 +60,8 @@ public class InboundPipeline implements Releasable { private Exception uncaughtException; private final ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; - private final BiConsumer messageHandler; - private final List protocolBytesHandlers; - private InboundBytesHandler currentHandler; + private final BiConsumer messageHandler; + private final InboundBytesHandler bytesHandler; public InboundPipeline( Version version, @@ -73,7 +70,7 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, Supplier circuitBreaker, Function> registryFunction, - BiConsumer messageHandler + BiConsumer messageHandler ) { this( statsTracker, @@ -89,23 +86,20 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, InboundDecoder decoder, InboundAggregator aggregator, - BiConsumer messageHandler + BiConsumer messageHandler ) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; - this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker)); + this.bytesHandler = new InboundBytesHandler(pending, decoder, aggregator, statsTracker); this.messageHandler = messageHandler; } @Override public void close() { isClosed = true; - if (currentHandler != null) { - currentHandler.close(); - currentHandler = null; - } + bytesHandler.close(); Releasables.closeWhileHandlingException(decoder, aggregator); Releasables.closeWhileHandlingException(pending); pending.clear(); @@ -127,22 +121,6 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); statsTracker.markBytesRead(reference.length()); pending.add(reference.retain()); - - // If we don't have a current handler, we should try to find one based on the protocol of the incoming bytes. - if (currentHandler == null) { - for (InboundBytesHandler handler : protocolBytesHandlers) { - if (handler.canHandleBytes(reference)) { - currentHandler = handler; - break; - } - } - } - - // If we have a current handler determined based on protocol, we should continue to use it for the fragmented bytes. - if (currentHandler != null) { - currentHandler.doHandleBytes(channel, reference, messageHandler); - } else { - throw new IllegalStateException("No bytes handler found for the incoming transport protocol"); - } + bytesHandler.doHandleBytes(channel, reference, messageHandler); } } diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index 78452a25a58d6..f56cd146ce953 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -778,15 +778,11 @@ protected void serverAcceptedChannel(TcpChannel channel) { protected abstract void stopInternal(); /** - * @deprecated use {@link #inboundMessage(TcpChannel, ProtocolInboundMessage)} - * Handles inbound message that has been decoded. - * - * @param channel the channel the message is from - * @param message the message + * @deprecated Use {{@link #inboundMessage(TcpChannel, InboundMessage)}} instead */ - @Deprecated(since = "2.14.0", forRemoval = true) - public void inboundMessage(TcpChannel channel, InboundMessage message) { - inboundMessage(channel, (ProtocolInboundMessage) message); + @Deprecated + public void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) { + inboundMessage(channel, (InboundMessage) message); } /** @@ -795,7 +791,7 @@ public void inboundMessage(TcpChannel channel, InboundMessage message) { * @param channel the channel the message is from * @param message the message */ - public void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) { + public void inboundMessage(TcpChannel channel, InboundMessage message) { try { inboundHandler.inboundMessage(channel, message); } catch (Exception e) { diff --git a/server/src/main/java/org/opensearch/transport/TransportProtocol.java b/server/src/main/java/org/opensearch/transport/TransportProtocol.java new file mode 100644 index 0000000000000..4a11520d38d56 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/TransportProtocol.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +/** + * Enumeration of transport protocols. + */ +enum TransportProtocol { + /** + * The original, hand-rolled binary protocol used for node-to-node + * communication. Message schemas are defined implicitly in code using the + * StreamInput and StreamOutput classes to parse and generate binary data. + */ + NATIVE; + + public static TransportProtocol fromBytes(byte b1, byte b2) { + if (b1 == 'E' && b2 == 'S') { + return NATIVE; + } + + throw new IllegalArgumentException("Unknown transport protocol: [" + b1 + ", " + b2 + "]"); + } +} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java deleted file mode 100644 index a8a4c0da7ec0f..0000000000000 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.transport.nativeprotocol; - -import org.opensearch.common.bytes.ReleasableBytesReference; -import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.core.common.bytes.CompositeBytesReference; -import org.opensearch.transport.Header; -import org.opensearch.transport.InboundAggregator; -import org.opensearch.transport.InboundBytesHandler; -import org.opensearch.transport.InboundDecoder; -import org.opensearch.transport.InboundMessage; -import org.opensearch.transport.ProtocolInboundMessage; -import org.opensearch.transport.StatsTracker; -import org.opensearch.transport.TcpChannel; - -import java.io.IOException; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.function.BiConsumer; - -/** - * Handler for inbound bytes for the native protocol. - */ -public class NativeInboundBytesHandler implements InboundBytesHandler { - - private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); - - private final ArrayDeque pending; - private final InboundDecoder decoder; - private final InboundAggregator aggregator; - private final StatsTracker statsTracker; - private boolean isClosed = false; - - public NativeInboundBytesHandler( - ArrayDeque pending, - InboundDecoder decoder, - InboundAggregator aggregator, - StatsTracker statsTracker - ) { - this.pending = pending; - this.decoder = decoder; - this.aggregator = aggregator; - this.statsTracker = statsTracker; - } - - @Override - public void close() { - isClosed = true; - } - - @Override - public boolean canHandleBytes(ReleasableBytesReference reference) { - return true; - } - - @Override - public void doHandleBytes( - TcpChannel channel, - ReleasableBytesReference reference, - BiConsumer messageHandler - ) throws IOException { - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { - continueDecoding = false; - } - } else { - continueDecoding = false; - } - } - } - - if (fragments.isEmpty()) { - continueHandling = false; - } else { - try { - forwardFragments(channel, fragments, messageHandler); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); - } - } - fragments.clear(); - } - } - } - } - - private ReleasableBytesReference getPendingBytes() { - if (pending.size() == 1) { - return pending.peekFirst().retain(); - } else { - final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; - int index = 0; - for (ReleasableBytesReference pendingReference : pending) { - bytesReferences[index] = pendingReference.retain(); - ++index; - } - final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); - return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); - } - } - - private void releasePendingBytes(int bytesConsumed) { - int bytesToRelease = bytesConsumed; - while (bytesToRelease != 0) { - try (ReleasableBytesReference reference = pending.pollFirst()) { - assert reference != null; - if (bytesToRelease < reference.length()) { - pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); - bytesToRelease -= bytesToRelease; - } else { - bytesToRelease -= reference.length(); - } - } - } - } - - private boolean endOfMessage(Object fragment) { - return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; - } - - private void forwardFragments( - TcpChannel channel, - ArrayList fragments, - BiConsumer messageHandler - ) throws IOException { - for (Object fragment : fragments) { - if (fragment instanceof Header) { - assert aggregator.isAggregating() == false; - aggregator.headerReceived((Header) fragment); - } else if (fragment == InboundDecoder.PING) { - assert aggregator.isAggregating() == false; - messageHandler.accept(channel, PING_MESSAGE); - } else if (fragment == InboundDecoder.END_CONTENT) { - assert aggregator.isAggregating(); - try (InboundMessage aggregated = aggregator.finishAggregation()) { - statsTracker.markMessageReceived(); - messageHandler.accept(channel, aggregated); - } - } else { - assert aggregator.isAggregating(); - assert fragment instanceof ReleasableBytesReference; - aggregator.aggregate((ReleasableBytesReference) fragment); - } - } - } - -} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java index 1143f129b6319..db0b5eb4ffa72 100644 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java @@ -32,118 +32,34 @@ package org.opensearch.transport.nativeprotocol; -import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.annotation.DeprecatedApi; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.common.util.io.IOUtils; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.transport.Header; -import org.opensearch.transport.ProtocolInboundMessage; - -import java.io.IOException; +import org.opensearch.transport.InboundMessage; /** * Inbound data as a message * - * @opensearch.api + * This class is deprecated in favor of the InboundMessage. */ -@PublicApi(since = "2.14.0") -public class NativeInboundMessage implements Releasable, ProtocolInboundMessage { +@DeprecatedApi(since = "2.17.0") +public class NativeInboundMessage extends InboundMessage { /** * The protocol used to encode this message */ public static String NATIVE_PROTOCOL = "native"; - private final Header header; - private final ReleasableBytesReference content; - private final Exception exception; - private final boolean isPing; - private Releasable breakerRelease; - private StreamInput streamInput; - public NativeInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { - this.header = header; - this.content = content; - this.breakerRelease = breakerRelease; - this.exception = null; - this.isPing = false; + super(header, content, breakerRelease); } public NativeInboundMessage(Header header, Exception exception) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = exception; - this.isPing = false; + super(header, exception); } public NativeInboundMessage(Header header, boolean isPing) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = null; - this.isPing = isPing; - } - - @Override - public String getProtocol() { - return NATIVE_PROTOCOL; - } - - public Header getHeader() { - return header; - } - - public int getContentLength() { - if (content == null) { - return 0; - } else { - return content.length(); - } - } - - public Exception getException() { - return exception; - } - - public boolean isPing() { - return isPing; + super(header, isPing); } - - public boolean isShortCircuit() { - return exception != null; - } - - public Releasable takeBreakerReleaseControl() { - final Releasable toReturn = breakerRelease; - breakerRelease = null; - if (toReturn != null) { - return toReturn; - } else { - return () -> {}; - } - } - - public StreamInput openOrGetStreamInput() throws IOException { - assert isPing == false && content != null; - if (streamInput == null) { - streamInput = content.streamInput(); - streamInput.setVersion(header.getVersion()); - } - return streamInput; - } - - @Override - public void close() { - IOUtils.closeWhileHandlingException(streamInput); - Releasables.closeWhileHandlingException(content, breakerRelease); - } - - @Override - public String toString() { - return "InboundMessage{" + header + "}"; - } - } diff --git a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java index 2dd98a8efe2a3..6168fd1c6a307 100644 --- a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java @@ -78,7 +78,7 @@ public void setUp() throws Exception { public void testInboundAggregation() throws IOException { long requestId = randomNonNegativeLong(); - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); header.actionName = "action_name"; // Initiate Message @@ -125,7 +125,7 @@ public void testInboundAggregation() throws IOException { public void testInboundUnknownAction() throws IOException { long requestId = randomNonNegativeLong(); - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); header.actionName = unknownAction; // Initiate Message @@ -149,7 +149,13 @@ public void testInboundUnknownAction() throws IOException { public void testCircuitBreak() throws IOException { circuitBreaker.startBreaking(); // Actions are breakable - Header breakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header breakableHeader = new Header( + TransportProtocol.NATIVE, + randomInt(), + randomNonNegativeLong(), + TransportStatus.setRequest((byte) 0), + Version.CURRENT + ); breakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); breakableHeader.actionName = "action_name"; // Initiate Message @@ -169,7 +175,13 @@ public void testCircuitBreak() throws IOException { assertThat(aggregated1.getException(), instanceOf(CircuitBreakingException.class)); // Actions marked as unbreakable are not broken - Header unbreakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header unbreakableHeader = new Header( + TransportProtocol.NATIVE, + randomInt(), + randomNonNegativeLong(), + TransportStatus.setRequest((byte) 0), + Version.CURRENT + ); unbreakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); unbreakableHeader.actionName = unBreakableAction; // Initiate Message @@ -188,7 +200,13 @@ public void testCircuitBreak() throws IOException { // Handshakes are not broken final byte handshakeStatus = TransportStatus.setHandshake(TransportStatus.setRequest((byte) 0)); - Header handshakeHeader = new Header(randomInt(), randomNonNegativeLong(), handshakeStatus, Version.CURRENT); + Header handshakeHeader = new Header( + TransportProtocol.NATIVE, + randomInt(), + randomNonNegativeLong(), + handshakeStatus, + Version.CURRENT + ); handshakeHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); handshakeHeader.actionName = "handshake"; // Initiate Message @@ -208,7 +226,7 @@ public void testCircuitBreak() throws IOException { public void testCloseWillCloseContent() { long requestId = randomNonNegativeLong(); - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); header.actionName = "action_name"; // Initiate Message @@ -248,7 +266,7 @@ public void testFinishAggregationWillFinishHeader() throws IOException { } else { actionName = "action_name"; } - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); // Initiate Message aggregator.headerReceived(header); diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index ea656f6651b1e..7779db9dacc3c 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -151,7 +151,7 @@ public void testPing() throws Exception { ); requestHandlers.registerHandler(registry); - handler.inboundMessage(channel, new InboundMessage(null, true)); + handler.inboundMessage(channel, InboundMessage.PING); if (channel.isServerChannel()) { BytesReference ping = channel.getMessageCaptor().get(); assertEquals('E', ping.get(0)); @@ -214,7 +214,13 @@ public TestResponse read(StreamInput in) throws IOException { false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version + ); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -235,7 +241,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); - Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); + Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -258,6 +264,7 @@ public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exce final Version remoteVersion = VersionUtils.randomCompatibleVersion(random(), version); final long requestId = randomNonNegativeLong(); final Header requestHeader = new Header( + TransportProtocol.NATIVE, between(0, 100), requestId, TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), @@ -298,6 +305,7 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws final Version remoteVersion = Version.fromId(randomIntBetween(0, version.minimumCompatibilityVersion().id - 1)); final long requestId = randomNonNegativeLong(); final Header requestHeader = new Header( + TransportProtocol.NATIVE, between(0, 100), requestId, TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), @@ -329,6 +337,7 @@ public void testLogsSlowInboundProcessing() throws Exception { final Version remoteVersion = Version.CURRENT; final long requestId = randomNonNegativeLong(); final Header requestHeader = new Header( + TransportProtocol.NATIVE, between(0, 100), requestId, TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), @@ -411,7 +420,13 @@ public void onResponseSent(long requestId, String action, Exception error) { BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version + ); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -476,7 +491,13 @@ public void onResponseSent(long requestId, String action, Exception error) { ); // Create the request payload by intentionally stripping 1 byte away BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version + ); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -540,7 +561,13 @@ public TestResponse read(StreamInput in) throws IOException { false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version + ); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -562,7 +589,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); - Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); + Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -626,7 +653,13 @@ public TestResponse read(StreamInput in) throws IOException { false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version + ); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -643,7 +676,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); // Create the response payload by intentionally stripping 1 byte away BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1); - Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); + Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 74457e2b153fd..cd6c4cf260176 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -81,9 +81,8 @@ public void testPipelineHandlingForNativeProtocol() throws IOException { final List> expected = new ArrayList<>(); final List> actual = new ArrayList<>(); final List toRelease = new ArrayList<>(); - final BiConsumer messageHandler = (c, m) -> { + final BiConsumer messageHandler = (c, message) -> { try { - InboundMessage message = (InboundMessage) m; final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion(); @@ -198,7 +197,7 @@ public void testPipelineHandlingForNativeProtocol() throws IOException { } public void testDecodeExceptionIsPropagated() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); @@ -228,7 +227,7 @@ public void testDecodeExceptionIsPropagated() throws IOException { } public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java index 01f19bea7a37f..11ca683c306bf 100644 --- a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java @@ -52,7 +52,6 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; import org.junit.After; import org.junit.Before; @@ -106,9 +105,8 @@ public void setUp() throws Exception { final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - NativeInboundMessage m1 = (NativeInboundMessage) m; - Streams.copy(m1.openOrGetStreamInput(), streamOutput); - message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); + Streams.copy(m.openOrGetStreamInput(), streamOutput); + message.set(new Tuple<>(m.getHeader(), streamOutput.bytes())); } catch (IOException e) { throw new AssertionError(e); }