diff --git a/CHANGELOG.md b/CHANGELOG.md index c8def76226014..84310884b891d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added - Add useCompoundFile index setting ([#13478](https://github.com/opensearch-project/OpenSearch/pull/13478)) +- Make outbound side of transport protocol dependent ([#13293](https://github.com/opensearch-project/OpenSearch/pull/13293)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index 6492900c49a0e..f77c44ea362cf 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -32,7 +32,9 @@ package org.opensearch.transport; +import org.opensearch.Version; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.BigArrays; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; @@ -57,7 +59,12 @@ public class InboundHandler { private final Map protocolMessageHandlers; InboundHandler( + String nodeName, + Version version, + String[] features, + StatsTracker statsTracker, ThreadPool threadPool, + BigArrays bigArrays, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry, TransportHandshaker handshaker, @@ -70,7 +77,12 @@ public class InboundHandler { this.protocolMessageHandlers = Map.of( NativeInboundMessage.NATIVE_PROTOCOL, new NativeMessageHandler( + nodeName, + version, + features, + statsTracker, threadPool, + bigArrays, outboundHandler, namedWriteableRegistry, handshaker, @@ -83,6 +95,7 @@ public class InboundHandler { } void setMessageListener(TransportMessageListener listener) { + protocolMessageHandlers.values().forEach(handler -> handler.setMessageListener(listener)); if (messageListener == TransportMessageListener.NOOP_LISTENER) { messageListener = listener; } else { diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java index 861b95a8098f2..58adc2d3d68a5 100644 --- a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -37,6 +37,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.BytesRef; import org.opensearch.Version; +import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.io.stream.ByteBufferStreamInput; @@ -51,6 +52,7 @@ import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; import java.io.EOFException; import java.io.IOException; @@ -71,7 +73,7 @@ public class NativeMessageHandler implements ProtocolMessageHandler { private static final Logger logger = LogManager.getLogger(NativeMessageHandler.class); private final ThreadPool threadPool; - private final OutboundHandler outboundHandler; + private final NativeOutboundHandler outboundHandler; private final NamedWriteableRegistry namedWriteableRegistry; private final TransportHandshaker handshaker; private final TransportKeepAlive keepAlive; @@ -81,7 +83,12 @@ public class NativeMessageHandler implements ProtocolMessageHandler { private final Tracer tracer; NativeMessageHandler( + String nodeName, + Version version, + String[] features, + StatsTracker statsTracker, ThreadPool threadPool, + BigArrays bigArrays, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry, TransportHandshaker handshaker, @@ -91,7 +98,7 @@ public class NativeMessageHandler implements ProtocolMessageHandler { TransportKeepAlive keepAlive ) { this.threadPool = threadPool; - this.outboundHandler = outboundHandler; + this.outboundHandler = new NativeOutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays, outboundHandler); this.namedWriteableRegistry = namedWriteableRegistry; this.handshaker = handshaker; this.requestHandlers = requestHandlers; @@ -491,4 +498,9 @@ public void onFailure(Exception e) { } } + @Override + public void setMessageListener(TransportMessageListener listener) { + outboundHandler.setMessageListener(listener); + } + } diff --git a/server/src/main/java/org/opensearch/transport/OutboundHandler.java b/server/src/main/java/org/opensearch/transport/OutboundHandler.java index b83dbdd0effe4..43f53e4011260 100644 --- a/server/src/main/java/org/opensearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/OutboundHandler.java @@ -35,164 +35,47 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.Version; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.CheckedSupplier; -import org.opensearch.common.io.stream.ReleasableBytesStreamOutput; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.common.network.CloseableChannel; import org.opensearch.common.transport.NetworkExceptionHelper; -import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.util.io.IOUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.NotifyOnceListener; import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.transport.TransportResponse; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.util.Set; /** * Outbound data handler * * @opensearch.internal */ -final class OutboundHandler { +public final class OutboundHandler { private static final Logger logger = LogManager.getLogger(OutboundHandler.class); - private final String nodeName; - private final Version version; - private final String[] features; private final StatsTracker statsTracker; private final ThreadPool threadPool; - private final BigArrays bigArrays; - private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; - OutboundHandler( - String nodeName, - Version version, - String[] features, - StatsTracker statsTracker, - ThreadPool threadPool, - BigArrays bigArrays - ) { - this.nodeName = nodeName; - this.version = version; - this.features = features; + public OutboundHandler(StatsTracker statsTracker, ThreadPool threadPool) { this.statsTracker = statsTracker; this.threadPool = threadPool; - this.bigArrays = bigArrays; } void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener listener) { - SendContext sendContext = new SendContext(channel, () -> bytes, listener); + SendContext sendContext = new SendContext(statsTracker, channel, () -> bytes, listener); try { - internalSend(channel, sendContext); + sendBytes(channel, sendContext); } catch (IOException e) { // This should not happen as the bytes are already serialized throw new AssertionError(e); } } - /** - * Sends the request to the given channel. This method should be used to send {@link TransportRequest} - * objects back to the caller. - */ - void sendRequest( - final DiscoveryNode node, - final TcpChannel channel, - final long requestId, - final String action, - final TransportRequest request, - final TransportRequestOptions options, - final Version channelVersion, - final boolean compressRequest, - final boolean isHandshake - ) throws IOException, TransportException { - Version version = Version.min(this.version, channelVersion); - OutboundMessage.Request message = new OutboundMessage.Request( - threadPool.getThreadContext(), - features, - request, - version, - action, - requestId, - isHandshake, - compressRequest - ); - ActionListener listener = ActionListener.wrap(() -> messageListener.onRequestSent(node, requestId, action, request, options)); - sendMessage(channel, message, listener); - } - - /** - * Sends the response to the given channel. This method should be used to send {@link TransportResponse} - * objects back to the caller. - * - * @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses - */ - void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final long requestId, - final String action, - final TransportResponse response, - final boolean compress, - final boolean isHandshake - ) throws IOException { - Version version = Version.min(this.version, nodeVersion); - OutboundMessage.Response message = new OutboundMessage.Response( - threadPool.getThreadContext(), - features, - response, - version, - requestId, - isHandshake, - compress - ); - ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); - sendMessage(channel, message, listener); - } - - /** - * Sends back an error response to the caller via the given channel - */ - void sendErrorResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final long requestId, - final String action, - final Exception error - ) throws IOException { - Version version = Version.min(this.version, nodeVersion); - TransportAddress address = new TransportAddress(channel.getLocalAddress()); - RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); - OutboundMessage.Response message = new OutboundMessage.Response( - threadPool.getThreadContext(), - features, - tx, - version, - requestId, - false, - false - ); - ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); - sendMessage(channel, message, listener); - } - - private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { - MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); - SendContext sendContext = new SendContext(channel, serializer, listener, serializer); - internalSend(channel, sendContext); - } - - private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException { + public void sendBytes(TcpChannel channel, SendContext sendContext) throws IOException { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); BytesReference reference = sendContext.get(); // stash thread context so that channel event loop is not polluted by thread context @@ -205,59 +88,30 @@ private void internalSend(TcpChannel channel, SendContext sendContext) throws IO } } - void setMessageListener(TransportMessageListener listener) { - if (messageListener == TransportMessageListener.NOOP_LISTENER) { - messageListener = listener; - } else { - throw new IllegalStateException("Cannot set message listener twice"); - } - } - /** * Internal message serializer * * @opensearch.internal */ - private static class MessageSerializer implements CheckedSupplier, Releasable { - - private final OutboundMessage message; - private final BigArrays bigArrays; - private volatile ReleasableBytesStreamOutput bytesStreamOutput; - - private MessageSerializer(OutboundMessage message, BigArrays bigArrays) { - this.message = message; - this.bigArrays = bigArrays; - } - - @Override - public BytesReference get() throws IOException { - bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays); - return message.serialize(bytesStreamOutput); - } - - @Override - public void close() { - IOUtils.closeWhileHandlingException(bytesStreamOutput); - } - } - - private class SendContext extends NotifyOnceListener implements CheckedSupplier { - + public static class SendContext extends NotifyOnceListener implements CheckedSupplier { + private final StatsTracker statsTracker; private final TcpChannel channel; private final CheckedSupplier messageSupplier; private final ActionListener listener; private final Releasable optionalReleasable; private long messageSize = -1; - private SendContext( + SendContext( + StatsTracker statsTracker, TcpChannel channel, CheckedSupplier messageSupplier, ActionListener listener ) { - this(channel, messageSupplier, listener, null); + this(statsTracker, channel, messageSupplier, listener, null); } - private SendContext( + public SendContext( + StatsTracker statsTracker, TcpChannel channel, CheckedSupplier messageSupplier, ActionListener listener, @@ -267,6 +121,7 @@ private SendContext( this.messageSupplier = messageSupplier; this.listener = listener; this.optionalReleasable = optionalReleasable; + this.statsTracker = statsTracker; } public BytesReference get() throws IOException { diff --git a/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java b/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java index 714d91d1e74c7..3c3fafebc34df 100644 --- a/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java @@ -17,6 +17,14 @@ */ public interface ProtocolMessageHandler { + /** + * Handles the message received on the channel. + * @param channel the channel on which the message was received + * @param message the message received + * @param startTime the start time + * @param slowLogThresholdMs the threshold for slow logs + * @param messageListener the message listener + */ public void messageReceived( TcpChannel channel, ProtocolInboundMessage message, @@ -24,4 +32,10 @@ public void messageReceived( long slowLogThresholdMs, TransportMessageListener messageListener ) throws IOException; + + /** + * Sets the message listener to be used by the handler. + * @param listener the message listener + */ + public void setMessageListener(TransportMessageListener listener); } diff --git a/server/src/main/java/org/opensearch/transport/ProtocolOutboundHandler.java b/server/src/main/java/org/opensearch/transport/ProtocolOutboundHandler.java new file mode 100644 index 0000000000000..42c5462fddf80 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/ProtocolOutboundHandler.java @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.transport.TransportResponse; + +import java.io.IOException; +import java.util.Set; + +/** + * Protocol based outbound data handler. + * Different transport protocols can have different implementations of this class. + * + * @opensearch.internal + */ +public abstract class ProtocolOutboundHandler { + + /** + * Sends the request to the given channel. This method should be used to send {@link TransportRequest} + * objects back to the caller. + */ + public abstract void sendRequest( + final DiscoveryNode node, + final TcpChannel channel, + final long requestId, + final String action, + final TransportRequest request, + final TransportRequestOptions options, + final Version channelVersion, + final boolean compressRequest, + final boolean isHandshake + ) throws IOException, TransportException; + + /** + * Sends the response to the given channel. This method should be used to send {@link TransportResponse} + * objects back to the caller. + * + * @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses + */ + public abstract void sendResponse( + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final long requestId, + final String action, + final TransportResponse response, + final boolean compress, + final boolean isHandshake + ) throws IOException; + + /** + * Sends back an error response to the caller via the given channel + */ + public abstract void sendErrorResponse( + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final long requestId, + final String action, + final Exception error + ) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index e32bba5e836d3..78452a25a58d6 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -71,6 +71,7 @@ import org.opensearch.node.Node; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; import java.io.IOException; import java.io.StreamCorruptedException; @@ -150,6 +151,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final TransportKeepAlive keepAlive; private final OutboundHandler outboundHandler; private final InboundHandler inboundHandler; + private final NativeOutboundHandler handshakerHandler; private final ResponseHandlers responseHandlers = new ResponseHandlers(); private final RequestHandlers requestHandlers = new RequestHandlers(); @@ -188,11 +190,20 @@ public TcpTransport( } BigArrays bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS); - this.outboundHandler = new OutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays); + this.outboundHandler = new OutboundHandler(statsTracker, threadPool); + this.handshakerHandler = new NativeOutboundHandler( + nodeName, + version, + features, + statsTracker, + threadPool, + bigArrays, + outboundHandler + ); this.handshaker = new TransportHandshaker( version, threadPool, - (node, channel, requestId, v) -> outboundHandler.sendRequest( + (node, channel, requestId, v) -> handshakerHandler.sendRequest( node, channel, requestId, @@ -206,7 +217,12 @@ public TcpTransport( ); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); this.inboundHandler = new InboundHandler( + nodeName, + version, + features, + statsTracker, threadPool, + bigArrays, outboundHandler, namedWriteableRegistry, handshaker, @@ -238,7 +254,7 @@ protected void doStart() {} @Override public synchronized void setMessageListener(TransportMessageListener listener) { - outboundHandler.setMessageListener(listener); + handshakerHandler.setMessageListener(listener); inboundHandler.setMessageListener(listener); } @@ -319,7 +335,7 @@ public void sendRequest(long requestId, String action, TransportRequest request, throw new NodeNotConnectedException(node, "connection already closed"); } TcpChannel channel = channel(options.type()); - outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), compress, false); + handshakerHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), compress, false); } } diff --git a/server/src/main/java/org/opensearch/transport/TcpTransportChannel.java b/server/src/main/java/org/opensearch/transport/TcpTransportChannel.java index 81de0af07ea7c..750fd50a4c44c 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransportChannel.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransportChannel.java @@ -50,7 +50,7 @@ public final class TcpTransportChannel extends BaseTcpTransportChannel { private final AtomicBoolean released = new AtomicBoolean(); - private final OutboundHandler outboundHandler; + private final ProtocolOutboundHandler outboundHandler; private final String action; private final long requestId; private final Version version; @@ -60,7 +60,7 @@ public final class TcpTransportChannel extends BaseTcpTransportChannel { private final Releasable breakerRelease; TcpTransportChannel( - OutboundHandler outboundHandler, + ProtocolOutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, diff --git a/server/src/main/java/org/opensearch/transport/TransportStatus.java b/server/src/main/java/org/opensearch/transport/TransportStatus.java index dab572949e688..76377468535b9 100644 --- a/server/src/main/java/org/opensearch/transport/TransportStatus.java +++ b/server/src/main/java/org/opensearch/transport/TransportStatus.java @@ -76,11 +76,11 @@ public static byte setCompress(byte value) { return value; } - static boolean isHandshake(byte value) { // pkg private since it's only used internally + public static boolean isHandshake(byte value) { return (value & STATUS_HANDSHAKE) != 0; } - static byte setHandshake(byte value) { // pkg private since it's only used internally + public static byte setHandshake(byte value) { value |= STATUS_HANDSHAKE; return value; } diff --git a/server/src/main/java/org/opensearch/transport/CompressibleBytesOutputStream.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/CompressibleBytesOutputStream.java similarity index 98% rename from server/src/main/java/org/opensearch/transport/CompressibleBytesOutputStream.java rename to server/src/main/java/org/opensearch/transport/nativeprotocol/CompressibleBytesOutputStream.java index 57707d3b44477..92b682370bcd5 100644 --- a/server/src/main/java/org/opensearch/transport/CompressibleBytesOutputStream.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/CompressibleBytesOutputStream.java @@ -30,7 +30,7 @@ * GitHub history for details. */ -package org.opensearch.transport; +package org.opensearch.transport.nativeprotocol; import org.opensearch.common.io.Streams; import org.opensearch.common.util.io.IOUtils; diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeOutboundHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeOutboundHandler.java new file mode 100644 index 0000000000000..66ed0d8e3eb2b --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeOutboundHandler.java @@ -0,0 +1,224 @@ +/* +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +/* +* Licensed to Elasticsearch under one or more contributor +* license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright +* ownership. Elasticsearch licenses this file to you under +* the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/* +* Modifications Copyright OpenSearch Contributors. See +* GitHub history for details. +*/ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.CheckedSupplier; +import org.opensearch.common.io.stream.ReleasableBytesStreamOutput; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.OutboundHandler; +import org.opensearch.transport.ProtocolOutboundHandler; +import org.opensearch.transport.RemoteTransportException; +import org.opensearch.transport.StatsTracker; +import org.opensearch.transport.TcpChannel; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportMessageListener; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; + +import java.io.IOException; +import java.util.Set; + +/** + * Outbound data handler + * + * @opensearch.internal + */ +public final class NativeOutboundHandler extends ProtocolOutboundHandler { + private final String nodeName; + private final Version version; + private final String[] features; + private final StatsTracker statsTracker; + private final ThreadPool threadPool; + private final BigArrays bigArrays; + private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; + private final OutboundHandler handler; + + public NativeOutboundHandler( + String nodeName, + Version version, + String[] features, + StatsTracker statsTracker, + ThreadPool threadPool, + BigArrays bigArrays, + OutboundHandler handler + ) { + this.nodeName = nodeName; + this.version = version; + this.features = features; + this.statsTracker = statsTracker; + this.threadPool = threadPool; + this.bigArrays = bigArrays; + this.handler = handler; + } + + /** + * Sends the request to the given channel. This method should be used to send {@link TransportRequest} + * objects back to the caller. + */ + @Override + public void sendRequest( + final DiscoveryNode node, + final TcpChannel channel, + final long requestId, + final String action, + final TransportRequest request, + final TransportRequestOptions options, + final Version channelVersion, + final boolean compressRequest, + final boolean isHandshake + ) throws IOException, TransportException { + Version version = Version.min(this.version, channelVersion); + NativeOutboundMessage.Request message = new NativeOutboundMessage.Request( + threadPool.getThreadContext(), + features, + request, + version, + action, + requestId, + isHandshake, + compressRequest + ); + ActionListener listener = ActionListener.wrap(() -> messageListener.onRequestSent(node, requestId, action, request, options)); + sendMessage(channel, message, listener); + } + + /** + * Sends the response to the given channel. This method should be used to send {@link TransportResponse} + * objects back to the caller. + * + * @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses + */ + @Override + public void sendResponse( + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final long requestId, + final String action, + final TransportResponse response, + final boolean compress, + final boolean isHandshake + ) throws IOException { + Version version = Version.min(this.version, nodeVersion); + NativeOutboundMessage.Response message = new NativeOutboundMessage.Response( + threadPool.getThreadContext(), + features, + response, + version, + requestId, + isHandshake, + compress + ); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); + sendMessage(channel, message, listener); + } + + /** + * Sends back an error response to the caller via the given channel + */ + @Override + public void sendErrorResponse( + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final long requestId, + final String action, + final Exception error + ) throws IOException { + Version version = Version.min(this.version, nodeVersion); + TransportAddress address = new TransportAddress(channel.getLocalAddress()); + RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); + NativeOutboundMessage.Response message = new NativeOutboundMessage.Response( + threadPool.getThreadContext(), + features, + tx, + version, + requestId, + false, + false + ); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); + sendMessage(channel, message, listener); + } + + private void sendMessage(TcpChannel channel, NativeOutboundMessage networkMessage, ActionListener listener) throws IOException { + MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); + OutboundHandler.SendContext sendContext = new OutboundHandler.SendContext(statsTracker, channel, serializer, listener, serializer); + handler.sendBytes(channel, sendContext); + } + + public void setMessageListener(TransportMessageListener listener) { + if (messageListener == TransportMessageListener.NOOP_LISTENER) { + messageListener = listener; + } else { + throw new IllegalStateException("Cannot set message listener twice"); + } + } + + /** + * Internal message serializer + * + * @opensearch.internal + */ + private static class MessageSerializer implements CheckedSupplier, Releasable { + + private final NativeOutboundMessage message; + private final BigArrays bigArrays; + private volatile ReleasableBytesStreamOutput bytesStreamOutput; + + private MessageSerializer(NativeOutboundMessage message, BigArrays bigArrays) { + this.message = message; + this.bigArrays = bigArrays; + } + + @Override + public BytesReference get() throws IOException { + bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays); + return message.serialize(bytesStreamOutput); + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(bytesStreamOutput); + } + } +} diff --git a/server/src/main/java/org/opensearch/transport/OutboundMessage.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeOutboundMessage.java similarity index 91% rename from server/src/main/java/org/opensearch/transport/OutboundMessage.java rename to server/src/main/java/org/opensearch/transport/nativeprotocol/NativeOutboundMessage.java index 358655288f849..a86c994a87e2e 100644 --- a/server/src/main/java/org/opensearch/transport/OutboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeOutboundMessage.java @@ -29,7 +29,7 @@ * GitHub history for details. */ -package org.opensearch.transport; +package org.opensearch.transport.nativeprotocol; import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -39,6 +39,10 @@ import org.opensearch.core.common.bytes.CompositeBytesReference; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.transport.BytesTransportRequest; +import org.opensearch.transport.RemoteTransportException; +import org.opensearch.transport.TcpHeader; +import org.opensearch.transport.TransportStatus; import java.io.IOException; import java.util.Set; @@ -48,11 +52,11 @@ * * @opensearch.internal */ -abstract class OutboundMessage extends NetworkMessage { +abstract class NativeOutboundMessage extends NetworkMessage { private final Writeable message; - OutboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, Writeable message) { + NativeOutboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, Writeable message) { super(threadContext, version, status, requestId); this.message = message; } @@ -96,7 +100,7 @@ protected BytesReference writeMessage(CompressibleBytesOutputStream stream) thro if (message instanceof BytesTransportRequest) { BytesTransportRequest bRequest = (BytesTransportRequest) message; bRequest.writeThin(stream); - zeroCopyBuffer = bRequest.bytes; + zeroCopyBuffer = bRequest.bytes(); } else if (message instanceof RemoteTransportException) { stream.writeException((RemoteTransportException) message); zeroCopyBuffer = BytesArray.EMPTY; @@ -122,7 +126,7 @@ protected BytesReference writeMessage(CompressibleBytesOutputStream stream) thro * * @opensearch.internal */ - static class Request extends OutboundMessage { + static class Request extends NativeOutboundMessage { private final String[] features; private final String action; @@ -152,7 +156,7 @@ protected void writeVariableHeader(StreamOutput stream) throws IOException { private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) { byte status = 0; status = TransportStatus.setRequest(status); - if (compress && OutboundMessage.canCompress(message)) { + if (compress && NativeOutboundMessage.canCompress(message)) { status = TransportStatus.setCompress(status); } if (isHandshake) { @@ -168,7 +172,7 @@ private static byte setStatus(boolean compress, boolean isHandshake, Writeable m * * @opensearch.internal */ - static class Response extends OutboundMessage { + static class Response extends NativeOutboundMessage { private final Set features; diff --git a/server/src/main/java/org/opensearch/transport/NetworkMessage.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NetworkMessage.java similarity index 96% rename from server/src/main/java/org/opensearch/transport/NetworkMessage.java rename to server/src/main/java/org/opensearch/transport/nativeprotocol/NetworkMessage.java index f02d664b65929..c197539d2e009 100644 --- a/server/src/main/java/org/opensearch/transport/NetworkMessage.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NetworkMessage.java @@ -29,11 +29,12 @@ * GitHub history for details. */ -package org.opensearch.transport; +package org.opensearch.transport.nativeprotocol; import org.opensearch.Version; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.transport.TransportStatus; /** * Represents a transport message sent over the network. Subclasses implement serialization and diff --git a/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java b/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java index abde20cd4dcd1..dc1aeacf05295 100644 --- a/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java @@ -40,19 +40,29 @@ import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.transport.TransportMessage; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.VersionUtils; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import static org.hamcrest.Matchers.hasItems; -public class InboundDecoderTests extends OpenSearchTestCase { +public abstract class InboundDecoderTests extends OpenSearchTestCase { - private ThreadContext threadContext; + protected ThreadContext threadContext; + + protected abstract BytesReference serialize( + boolean isRequest, + Version version, + boolean handshake, + boolean compress, + String action, + long requestId, + Writeable transportMessage + ) throws IOException; @Override public void setUp() throws Exception { @@ -66,36 +76,16 @@ public void testDecode() throws IOException { long requestId = randomNonNegativeLong(); final String headerKey = randomAlphaOfLength(10); final String headerValue = randomAlphaOfLength(20); + TransportMessage transportMessage; if (isRequest) { threadContext.putHeader(headerKey, headerValue); + transportMessage = new TestRequest(randomAlphaOfLength(100)); } else { threadContext.addResponseHeader(headerKey, headerValue); - } - OutboundMessage message; - if (isRequest) { - message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(randomAlphaOfLength(100)), - Version.CURRENT, - action, - requestId, - false, - false - ); - } else { - message = new OutboundMessage.Response( - threadContext, - Collections.emptySet(), - new TestResponse(randomAlphaOfLength(100)), - Version.CURRENT, - requestId, - false, - false - ); + transportMessage = new TestResponse(randomAlphaOfLength(100)); } - final BytesReference totalBytes = message.serialize(new BytesStreamOutput()); + final BytesReference totalBytes = serialize(isRequest, Version.CURRENT, false, false, action, requestId, transportMessage); int totalHeaderSize = TcpHeader.headerSize(Version.CURRENT) + totalBytes.getInt(TcpHeader.VARIABLE_HEADER_SIZE_POSITION); final BytesReference messageBytes = totalBytes.slice(totalHeaderSize, totalBytes.length() - totalHeaderSize); @@ -143,18 +133,16 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException { long requestId = randomNonNegativeLong(); final Version preHeaderVariableInt = LegacyESVersion.V_7_5_0; final String contentValue = randomAlphaOfLength(100); - final OutboundMessage message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(contentValue), + + final BytesReference totalBytes = serialize( + true, preHeaderVariableInt, + true, + isCompressed, action, requestId, - true, - isCompressed + new TestRequest(contentValue) ); - - final BytesReference totalBytes = message.serialize(new BytesStreamOutput()); int partialHeaderSize = TcpHeader.headerSize(preHeaderVariableInt); InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); @@ -188,18 +176,16 @@ public void testDecodeHandshakeCompatibility() throws IOException { final String headerValue = randomAlphaOfLength(20); threadContext.putHeader(headerKey, headerValue); Version handshakeCompat = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); - OutboundMessage message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(randomAlphaOfLength(100)), + + final BytesReference bytes = serialize( + true, handshakeCompat, + true, + false, action, requestId, - true, - false + new TestRequest(randomAlphaOfLength(100)) ); - - final BytesReference bytes = message.serialize(new BytesStreamOutput()); int totalHeaderSize = TcpHeader.headerSize(handshakeCompat); InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); @@ -231,34 +217,14 @@ public void testCompressedDecode() throws IOException { } else { threadContext.addResponseHeader(headerKey, headerValue); } - OutboundMessage message; TransportMessage transportMessage; if (isRequest) { transportMessage = new TestRequest(randomAlphaOfLength(100)); - message = new OutboundMessage.Request( - threadContext, - new String[0], - transportMessage, - Version.CURRENT, - action, - requestId, - false, - true - ); } else { transportMessage = new TestResponse(randomAlphaOfLength(100)); - message = new OutboundMessage.Response( - threadContext, - Collections.emptySet(), - transportMessage, - Version.CURRENT, - requestId, - false, - true - ); } - final BytesReference totalBytes = message.serialize(new BytesStreamOutput()); + final BytesReference totalBytes = serialize(isRequest, Version.CURRENT, false, true, action, requestId, transportMessage); final BytesStreamOutput out = new BytesStreamOutput(); transportMessage.writeTo(out); final BytesReference uncompressedBytes = out.bytes(); @@ -308,18 +274,16 @@ public void testCompressedDecodeHandshakeCompatibility() throws IOException { final String headerValue = randomAlphaOfLength(20); threadContext.putHeader(headerKey, headerValue); Version handshakeCompat = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); - OutboundMessage message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(randomAlphaOfLength(100)), + + final BytesReference bytes = serialize( + true, handshakeCompat, + true, + true, action, requestId, - true, - true + new TestRequest(randomAlphaOfLength(100)) ); - - final BytesReference bytes = message.serialize(new BytesStreamOutput()); int totalHeaderSize = TcpHeader.headerSize(handshakeCompat); InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); @@ -344,19 +308,17 @@ public void testVersionIncompatibilityDecodeException() throws IOException { String action = "test-request"; long requestId = randomNonNegativeLong(); Version incompatibleVersion = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); - OutboundMessage message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(randomAlphaOfLength(100)), + + final BytesReference bytes = serialize( + true, incompatibleVersion, + false, + true, action, requestId, - false, - true + new TestRequest(randomAlphaOfLength(100)) ); - final BytesReference bytes = message.serialize(new BytesStreamOutput()); - InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes); diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 0d171e17e70e1..ea656f6651b1e 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -39,16 +39,17 @@ import org.opensearch.Version; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.collect.Tuple; -import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.tasks.TaskManager; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.MockLogAppender; @@ -74,7 +75,17 @@ import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.Matchers.instanceOf; -public class InboundHandlerTests extends OpenSearchTestCase { +public abstract class InboundHandlerTests extends OpenSearchTestCase { + + public abstract BytesReference serializeOutboundRequest( + ThreadContext threadContext, + Writeable message, + Version version, + String action, + long requestId, + boolean compress, + boolean handshake + ) throws IOException; private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); private final Version version = Version.CURRENT; @@ -100,19 +111,17 @@ public void sendMessage(BytesReference reference, ActionListener listener) }; NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}); - outboundHandler = new OutboundHandler( - "node", - version, - new String[0], - new StatsTracker(), - threadPool, - BigArrays.NON_RECYCLING_INSTANCE - ); + outboundHandler = new OutboundHandler(new StatsTracker(), threadPool); TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, outboundHandler::sendBytes); requestHandlers = new Transport.RequestHandlers(); responseHandlers = new Transport.ResponseHandlers(); handler = new InboundHandler( + "node", + version, + new String[0], + new StatsTracker(), threadPool, + BigArrays.NON_RECYCLING_INSTANCE, outboundHandler, namedWriteableRegistry, handshaker, @@ -194,9 +203,9 @@ public TestResponse read(StreamInput in) throws IOException { ); requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - OutboundMessage.Request request = new OutboundMessage.Request( + + BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), - new String[0], new TestRequest(requestValue), version, action, @@ -204,8 +213,6 @@ public TestResponse read(StreamInput in) throws IOException { false, false ); - - BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); @@ -380,18 +387,8 @@ public TestResponse read(StreamInput in) throws IOException { requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - OutboundMessage.Request request = new OutboundMessage.Request( - threadPool.getThreadContext(), - new String[0], - new TestRequest(requestValue), - version, - action, - requestId, - false, - false - ); - outboundHandler.setMessageListener(new TransportMessageListener() { + handler.setMessageListener(new TransportMessageListener() { @Override public void onResponseSent(long requestId, String action, Exception error) { exceptionCaptor.set(error); @@ -399,7 +396,15 @@ public void onResponseSent(long requestId, String action, Exception error) { }); // Create the request payload with 1 byte overflow - final BytesRef bytes = request.serialize(new BytesStreamOutput()).toBytesRef(); + final BytesRef bytes = serializeOutboundRequest( + threadPool.getThreadContext(), + new TestRequest(requestValue), + version, + action, + requestId, + false, + false + ).toBytesRef(); final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1); buffer.put(bytes.bytes, 0, bytes.length); buffer.put((byte) 1); @@ -452,9 +457,16 @@ public TestResponse read(StreamInput in) throws IOException { requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - OutboundMessage.Request request = new OutboundMessage.Request( + + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, Exception error) { + exceptionCaptor.set(error); + } + }); + + final BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), - new String[0], new TestRequest(requestValue), version, action, @@ -462,15 +474,6 @@ public TestResponse read(StreamInput in) throws IOException { false, false ); - - outboundHandler.setMessageListener(new TransportMessageListener() { - @Override - public void onResponseSent(long requestId, String action, Exception error) { - exceptionCaptor.set(error); - } - }); - - final BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); // Create the request payload by intentionally stripping 1 byte away BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); @@ -526,9 +529,9 @@ public TestResponse read(StreamInput in) throws IOException { ); requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - OutboundMessage.Request request = new OutboundMessage.Request( + + BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), - new String[0], new TestRequest(requestValue), version, action, @@ -536,8 +539,6 @@ public TestResponse read(StreamInput in) throws IOException { false, false ); - - BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); @@ -614,9 +615,9 @@ public TestResponse read(StreamInput in) throws IOException { ); requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - OutboundMessage.Request request = new OutboundMessage.Request( + + BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), - new String[0], new TestRequest(requestValue), version, action, @@ -624,8 +625,6 @@ public TestResponse read(StreamInput in) throws IOException { false, false ); - - BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 2dfe8a0dd8590..74457e2b153fd 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -52,7 +52,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; @@ -63,12 +62,22 @@ import static org.hamcrest.Matchers.instanceOf; -public class InboundPipelineTests extends OpenSearchTestCase { +public abstract class InboundPipelineTests extends OpenSearchTestCase { private static final int BYTE_THRESHOLD = 128 * 1024; - private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - - public void testPipelineHandling() throws IOException { + public final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + protected abstract BytesReference serialize( + boolean isRequest, + Version version, + boolean handshake, + boolean compress, + String action, + long requestId, + String value + ) throws IOException; + + public void testPipelineHandlingForNativeProtocol() throws IOException { final List> expected = new ArrayList<>(); final List> actual = new ArrayList<>(); final List toRelease = new ArrayList<>(); @@ -85,10 +94,10 @@ public void testPipelineHandling() throws IOException { actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), null); } else if (isRequest) { final TestRequest request = new TestRequest(message.openOrGetStreamInput()); - actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), request.value); + actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), request.getValue()); } else { final TestResponse response = new TestResponse(message.openOrGetStreamInput()); - actualData = new MessageData(version, requestId, isRequest, isCompressed, null, response.value); + actualData = new MessageData(version, requestId, isRequest, isCompressed, null, response.getValue()); } actual.add(new Tuple<>(actualData, message.getException())); } catch (IOException e) { @@ -127,49 +136,23 @@ public void testPipelineHandling() throws IOException { final MessageData messageData; Exception expectedExceptionClass = null; - OutboundMessage message; + // NativeOutboundMessage message; + final BytesReference reference; if (isRequest) { if (rarely()) { messageData = new MessageData(version, requestId, true, isCompressed, breakThisAction, null); - message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(value), - version, - breakThisAction, - requestId, - false, - isCompressed - ); + reference = serialize(true, version, false, isCompressed, breakThisAction, requestId, value); expectedExceptionClass = new CircuitBreakingException("", CircuitBreaker.Durability.PERMANENT); } else { messageData = new MessageData(version, requestId, true, isCompressed, actionName, value); - message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(value), - version, - actionName, - requestId, - false, - isCompressed - ); + reference = serialize(true, version, false, isCompressed, actionName, requestId, value); } } else { messageData = new MessageData(version, requestId, false, isCompressed, null, value); - message = new OutboundMessage.Response( - threadContext, - Collections.emptySet(), - new TestResponse(value), - version, - requestId, - false, - isCompressed - ); + reference = serialize(false, version, false, isCompressed, actionName, requestId, value); } expected.add(new Tuple<>(messageData, expectedExceptionClass)); - final BytesReference reference = message.serialize(new BytesStreamOutput()); Streams.copy(reference.streamInput(), streamOutput); } @@ -230,31 +213,7 @@ public void testDecodeExceptionIsPropagated() throws IOException { final boolean isRequest = randomBoolean(); final long requestId = randomNonNegativeLong(); - OutboundMessage message; - if (isRequest) { - message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(value), - invalidVersion, - actionName, - requestId, - false, - false - ); - } else { - message = new OutboundMessage.Response( - threadContext, - Collections.emptySet(), - new TestResponse(value), - invalidVersion, - requestId, - false, - false - ); - } - - final BytesReference reference = message.serialize(streamOutput); + final BytesReference reference = serialize(isRequest, invalidVersion, false, false, actionName, requestId, value); try (ReleasableBytesReference releasable = ReleasableBytesReference.wrap(reference)) { expectThrows(IllegalStateException.class, () -> pipeline.handleBytes(new FakeTcpChannel(), releasable)); } @@ -284,31 +243,7 @@ public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { final boolean isRequest = randomBoolean(); final long requestId = randomNonNegativeLong(); - OutboundMessage message; - if (isRequest) { - message = new OutboundMessage.Request( - threadContext, - new String[0], - new TestRequest(value), - version, - actionName, - requestId, - false, - false - ); - } else { - message = new OutboundMessage.Response( - threadContext, - Collections.emptySet(), - new TestResponse(value), - version, - requestId, - false, - false - ); - } - - final BytesReference reference = message.serialize(streamOutput); + final BytesReference reference = serialize(isRequest, version, false, false, actionName, requestId, value); final int fixedHeaderSize = TcpHeader.headerSize(Version.CURRENT); final int variableHeaderSize = reference.getInt(fixedHeaderSize - 4); final int totalHeaderSize = fixedHeaderSize + variableHeaderSize; diff --git a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java new file mode 100644 index 0000000000000..a42a896e373bc --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java @@ -0,0 +1,301 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.transport; + +import org.opensearch.OpenSearchException; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.PageCacheRecycler; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.io.Streams; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.NoopCircuitBreaker; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.instanceOf; + +public class NativeOutboundHandlerTests extends OpenSearchTestCase { + + private final String feature1 = "feature1"; + private final String feature2 = "feature2"; + private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); + private final TransportRequestOptions options = TransportRequestOptions.EMPTY; + private final AtomicReference> message = new AtomicReference<>(); + private InboundPipeline pipeline; + private OutboundHandler handler; + private NativeOutboundHandler nativeOutboundHandler; + private FakeTcpChannel channel; + private DiscoveryNode node; + + @Before + public void setUp() throws Exception { + super.setUp(); + channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); + TransportAddress transportAddress = buildNewFakeTransportAddress(); + node = new DiscoveryNode("", transportAddress, Version.CURRENT); + String[] features = { feature1, feature2 }; + StatsTracker statsTracker = new StatsTracker(); + handler = new OutboundHandler(statsTracker, threadPool); + nativeOutboundHandler = new NativeOutboundHandler( + "node", + Version.CURRENT, + features, + statsTracker, + threadPool, + BigArrays.NON_RECYCLING_INSTANCE, + handler + ); + + final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + InboundMessage m1 = (InboundMessage) m; + Streams.copy(m1.openOrGetStreamInput(), streamOutput); + message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); + } catch (IOException e) { + throw new AssertionError(e); + } + }); + } + + @After + public void tearDown() throws Exception { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + super.tearDown(); + } + + public void testSendRequest() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + String value = "message"; + threadContext.putHeader("header", "header_value"); + TestRequest request = new TestRequest(value); + + AtomicReference nodeRef = new AtomicReference<>(); + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference requestRef = new AtomicReference<>(); + nativeOutboundHandler.setMessageListener(new TransportMessageListener() { + @Override + public void onRequestSent( + DiscoveryNode node, + long requestId, + String action, + TransportRequest request, + TransportRequestOptions options + ) { + nodeRef.set(node); + requestIdRef.set(requestId); + actionRef.set(action); + requestRef.set(request); + } + }); + nativeOutboundHandler.sendRequest(node, channel, requestId, action, request, options, version, compress, isHandshake); + + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(node, nodeRef.get()); + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(request, requestRef.get()); + + pipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); + final Tuple tuple = message.get(); + final Header header = tuple.v1(); + final TestRequest message = new TestRequest(tuple.v2().streamInput()); + assertEquals(version, header.getVersion()); + assertEquals(requestId, header.getRequestId()); + assertTrue(header.isRequest()); + assertFalse(header.isResponse()); + if (isHandshake) { + assertTrue(header.isHandshake()); + } else { + assertFalse(header.isHandshake()); + } + if (compress) { + assertTrue(header.isCompressed()); + } else { + assertFalse(header.isCompressed()); + } + + assertEquals(value, message.getValue()); + assertEquals("header_value", header.getHeaders().v1().get("header")); + } + + public void testSendResponse() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + String value = "message"; + threadContext.putHeader("header", "header_value"); + TestResponse response = new TestResponse(value); + + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + nativeOutboundHandler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(response); + } + }); + nativeOutboundHandler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake); + + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(response, responseRef.get()); + + pipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); + final Tuple tuple = message.get(); + final Header header = tuple.v1(); + final TestResponse message = new TestResponse(tuple.v2().streamInput()); + assertEquals(version, header.getVersion()); + assertEquals(requestId, header.getRequestId()); + assertFalse(header.isRequest()); + assertTrue(header.isResponse()); + if (isHandshake) { + assertTrue(header.isHandshake()); + } else { + assertFalse(header.isHandshake()); + } + if (compress) { + assertTrue(header.isCompressed()); + } else { + assertFalse(header.isCompressed()); + } + + assertFalse(header.isError()); + + assertEquals(value, message.getValue()); + assertEquals("header_value", header.getHeaders().v1().get("header")); + } + + public void testErrorResponse() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + threadContext.putHeader("header", "header_value"); + OpenSearchException error = new OpenSearchException("boom"); + + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + nativeOutboundHandler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, Exception error) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(error); + } + }); + nativeOutboundHandler.sendErrorResponse(version, Collections.emptySet(), channel, requestId, action, error); + + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(error, responseRef.get()); + + pipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); + final Tuple tuple = message.get(); + final Header header = tuple.v1(); + assertEquals(version, header.getVersion()); + assertEquals(requestId, header.getRequestId()); + assertFalse(header.isRequest()); + assertTrue(header.isResponse()); + assertFalse(header.isCompressed()); + assertFalse(header.isHandshake()); + assertTrue(header.isError()); + + RemoteTransportException remoteException = tuple.v2().streamInput().readException(); + assertThat(remoteException.getCause(), instanceOf(OpenSearchException.class)); + assertEquals(remoteException.getCause().getMessage(), "boom"); + assertEquals(action, remoteException.action()); + assertEquals(channel.getLocalAddress(), remoteException.address().address()); + + assertEquals("header_value", header.getHeaders().v1().get("header")); + } +} diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index 36ba409a2de03..7e7c60e2d3d29 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -32,24 +32,9 @@ package org.opensearch.transport; -import org.opensearch.OpenSearchException; -import org.opensearch.Version; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.bytes.ReleasableBytesReference; -import org.opensearch.common.collect.Tuple; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.BigArrays; -import org.opensearch.common.util.PageCacheRecycler; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.util.io.Streams; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.breaker.CircuitBreaker; -import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.transport.TransportResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -58,52 +43,22 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.LongSupplier; -import java.util.function.Predicate; -import java.util.function.Supplier; - -import static org.hamcrest.Matchers.instanceOf; public class OutboundHandlerTests extends OpenSearchTestCase { - private final String feature1 = "feature1"; - private final String feature2 = "feature2"; private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); - private final TransportRequestOptions options = TransportRequestOptions.EMPTY; - private final AtomicReference> message = new AtomicReference<>(); - private InboundPipeline pipeline; private OutboundHandler handler; private FakeTcpChannel channel; - private DiscoveryNode node; @Before public void setUp() throws Exception { super.setUp(); channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); - TransportAddress transportAddress = buildNewFakeTransportAddress(); - node = new DiscoveryNode("", transportAddress, Version.CURRENT); - String[] features = { feature1, feature2 }; StatsTracker statsTracker = new StatsTracker(); - handler = new OutboundHandler("node", Version.CURRENT, features, statsTracker, threadPool, BigArrays.NON_RECYCLING_INSTANCE); - - final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); - final Supplier breaker = () -> new NoopCircuitBreaker("test"); - final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { - try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - InboundMessage m1 = (InboundMessage) m; - Streams.copy(m1.openOrGetStreamInput(), streamOutput); - message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); - } catch (IOException e) { - throw new AssertionError(e); - } - }); + handler = new OutboundHandler(statsTracker, threadPool); } @After @@ -136,182 +91,4 @@ public void testSendRawBytes() { assertEquals(bytesArray, reference); } - public void testSendRequest() throws IOException { - ThreadContext threadContext = threadPool.getThreadContext(); - Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); - String action = "handshake"; - long requestId = randomLongBetween(0, 300); - boolean isHandshake = randomBoolean(); - boolean compress = randomBoolean(); - String value = "message"; - threadContext.putHeader("header", "header_value"); - TestRequest request = new TestRequest(value); - - AtomicReference nodeRef = new AtomicReference<>(); - AtomicLong requestIdRef = new AtomicLong(); - AtomicReference actionRef = new AtomicReference<>(); - AtomicReference requestRef = new AtomicReference<>(); - handler.setMessageListener(new TransportMessageListener() { - @Override - public void onRequestSent( - DiscoveryNode node, - long requestId, - String action, - TransportRequest request, - TransportRequestOptions options - ) { - nodeRef.set(node); - requestIdRef.set(requestId); - actionRef.set(action); - requestRef.set(request); - } - }); - handler.sendRequest(node, channel, requestId, action, request, options, version, compress, isHandshake); - - BytesReference reference = channel.getMessageCaptor().get(); - ActionListener sendListener = channel.getListenerCaptor().get(); - if (randomBoolean()) { - sendListener.onResponse(null); - } else { - sendListener.onFailure(new IOException("failed")); - } - assertEquals(node, nodeRef.get()); - assertEquals(requestId, requestIdRef.get()); - assertEquals(action, actionRef.get()); - assertEquals(request, requestRef.get()); - - pipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); - final Tuple tuple = message.get(); - final Header header = tuple.v1(); - final TestRequest message = new TestRequest(tuple.v2().streamInput()); - assertEquals(version, header.getVersion()); - assertEquals(requestId, header.getRequestId()); - assertTrue(header.isRequest()); - assertFalse(header.isResponse()); - if (isHandshake) { - assertTrue(header.isHandshake()); - } else { - assertFalse(header.isHandshake()); - } - if (compress) { - assertTrue(header.isCompressed()); - } else { - assertFalse(header.isCompressed()); - } - - assertEquals(value, message.value); - assertEquals("header_value", header.getHeaders().v1().get("header")); - } - - public void testSendResponse() throws IOException { - ThreadContext threadContext = threadPool.getThreadContext(); - Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); - String action = "handshake"; - long requestId = randomLongBetween(0, 300); - boolean isHandshake = randomBoolean(); - boolean compress = randomBoolean(); - String value = "message"; - threadContext.putHeader("header", "header_value"); - TestResponse response = new TestResponse(value); - - AtomicLong requestIdRef = new AtomicLong(); - AtomicReference actionRef = new AtomicReference<>(); - AtomicReference responseRef = new AtomicReference<>(); - handler.setMessageListener(new TransportMessageListener() { - @Override - public void onResponseSent(long requestId, String action, TransportResponse response) { - requestIdRef.set(requestId); - actionRef.set(action); - responseRef.set(response); - } - }); - handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake); - - BytesReference reference = channel.getMessageCaptor().get(); - ActionListener sendListener = channel.getListenerCaptor().get(); - if (randomBoolean()) { - sendListener.onResponse(null); - } else { - sendListener.onFailure(new IOException("failed")); - } - assertEquals(requestId, requestIdRef.get()); - assertEquals(action, actionRef.get()); - assertEquals(response, responseRef.get()); - - pipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); - final Tuple tuple = message.get(); - final Header header = tuple.v1(); - final TestResponse message = new TestResponse(tuple.v2().streamInput()); - assertEquals(version, header.getVersion()); - assertEquals(requestId, header.getRequestId()); - assertFalse(header.isRequest()); - assertTrue(header.isResponse()); - if (isHandshake) { - assertTrue(header.isHandshake()); - } else { - assertFalse(header.isHandshake()); - } - if (compress) { - assertTrue(header.isCompressed()); - } else { - assertFalse(header.isCompressed()); - } - - assertFalse(header.isError()); - - assertEquals(value, message.value); - assertEquals("header_value", header.getHeaders().v1().get("header")); - } - - public void testErrorResponse() throws IOException { - ThreadContext threadContext = threadPool.getThreadContext(); - Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); - String action = "handshake"; - long requestId = randomLongBetween(0, 300); - threadContext.putHeader("header", "header_value"); - OpenSearchException error = new OpenSearchException("boom"); - - AtomicLong requestIdRef = new AtomicLong(); - AtomicReference actionRef = new AtomicReference<>(); - AtomicReference responseRef = new AtomicReference<>(); - handler.setMessageListener(new TransportMessageListener() { - @Override - public void onResponseSent(long requestId, String action, Exception error) { - requestIdRef.set(requestId); - actionRef.set(action); - responseRef.set(error); - } - }); - handler.sendErrorResponse(version, Collections.emptySet(), channel, requestId, action, error); - - BytesReference reference = channel.getMessageCaptor().get(); - ActionListener sendListener = channel.getListenerCaptor().get(); - if (randomBoolean()) { - sendListener.onResponse(null); - } else { - sendListener.onFailure(new IOException("failed")); - } - assertEquals(requestId, requestIdRef.get()); - assertEquals(action, actionRef.get()); - assertEquals(error, responseRef.get()); - - pipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); - final Tuple tuple = message.get(); - final Header header = tuple.v1(); - assertEquals(version, header.getVersion()); - assertEquals(requestId, header.getRequestId()); - assertFalse(header.isRequest()); - assertTrue(header.isResponse()); - assertFalse(header.isCompressed()); - assertFalse(header.isHandshake()); - assertTrue(header.isError()); - - RemoteTransportException remoteException = tuple.v2().streamInput().readException(); - assertThat(remoteException.getCause(), instanceOf(OpenSearchException.class)); - assertEquals(remoteException.getCause().getMessage(), "boom"); - assertEquals(action, remoteException.action()); - assertEquals(channel.getLocalAddress(), remoteException.address().address()); - - assertEquals("header_value", header.getHeaders().v1().get("header")); - } } diff --git a/server/src/test/java/org/opensearch/transport/TcpTransportTests.java b/server/src/test/java/org/opensearch/transport/TcpTransportTests.java index 7ab78cca7d615..7c5c9ec12360d 100644 --- a/server/src/test/java/org/opensearch/transport/TcpTransportTests.java +++ b/server/src/test/java/org/opensearch/transport/TcpTransportTests.java @@ -43,7 +43,6 @@ import org.opensearch.common.network.NetworkService; import org.opensearch.common.network.NetworkUtils; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.MockPageCacheRecycler; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; @@ -540,19 +539,7 @@ private void testExceptionHandling( final PlainActionFuture listener = new PlainActionFuture<>(); channel.addCloseListener(listener); - TcpTransport.handleException( - channel, - exception, - lifecycle, - new OutboundHandler( - randomAlphaOfLength(10), - Version.CURRENT, - new String[0], - new StatsTracker(), - testThreadPool, - BigArrays.NON_RECYCLING_INSTANCE - ) - ); + TcpTransport.handleException(channel, exception, lifecycle, new OutboundHandler(new StatsTracker(), testThreadPool)); if (expectClosed) { assertTrue(listener.isDone()); diff --git a/server/src/test/java/org/opensearch/transport/TransportLoggerTests.java b/server/src/test/java/org/opensearch/transport/TransportLoggerTests.java index 05296e9308657..cbd1b959f5f16 100644 --- a/server/src/test/java/org/opensearch/transport/TransportLoggerTests.java +++ b/server/src/test/java/org/opensearch/transport/TransportLoggerTests.java @@ -33,12 +33,6 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; -import org.opensearch.Version; -import org.opensearch.action.admin.cluster.stats.ClusterStatsAction; -import org.opensearch.action.admin.cluster.stats.ClusterStatsRequest; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.test.MockLogAppender; import org.opensearch.test.OpenSearchTestCase; @@ -49,7 +43,7 @@ import static org.mockito.Mockito.mock; @TestLogging(value = "org.opensearch.transport.TransportLogger:trace", reason = "to ensure we log network events on TRACE level") -public class TransportLoggerTests extends OpenSearchTestCase { +public abstract class TransportLoggerTests extends OpenSearchTestCase { public void testLoggingHandler() throws Exception { try (MockLogAppender appender = MockLogAppender.createForLoggers(LogManager.getLogger(TransportLogger.class))) { final String writePattern = ".*\\[length: \\d+" @@ -90,20 +84,5 @@ public void testLoggingHandler() throws Exception { } } - private BytesReference buildRequest() throws IOException { - boolean compress = randomBoolean(); - try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { - OutboundMessage.Request request = new OutboundMessage.Request( - new ThreadContext(Settings.EMPTY), - new String[0], - new ClusterStatsRequest(), - Version.CURRENT, - ClusterStatsAction.NAME, - randomInt(30), - false, - compress - ); - return request.serialize(bytesStreamOutput); - } - } + public abstract BytesReference buildRequest() throws IOException; } diff --git a/server/src/test/java/org/opensearch/transport/CompressibleBytesOutputStreamTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/CompressibleBytesOutputStreamTests.java similarity index 99% rename from server/src/test/java/org/opensearch/transport/CompressibleBytesOutputStreamTests.java rename to server/src/test/java/org/opensearch/transport/nativeprotocol/CompressibleBytesOutputStreamTests.java index 89018b7353e7c..eaa35469b9ec0 100644 --- a/server/src/test/java/org/opensearch/transport/CompressibleBytesOutputStreamTests.java +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/CompressibleBytesOutputStreamTests.java @@ -30,7 +30,7 @@ * GitHub history for details. */ -package org.opensearch.transport; +package org.opensearch.transport.nativeprotocol; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesReference; diff --git a/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundDecoderTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundDecoderTests.java new file mode 100644 index 0000000000000..bd85939c753fa --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundDecoderTests.java @@ -0,0 +1,59 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.transport.InboundDecoderTests; + +import java.io.IOException; +import java.util.Collections; + +public class NativeInboundDecoderTests extends InboundDecoderTests { + + @Override + protected BytesReference serialize( + boolean isRequest, + Version version, + boolean handshake, + boolean compress, + String action, + long requestId, + Writeable transportMessage + ) throws IOException { + NativeOutboundMessage message; + if (isRequest) { + message = new NativeOutboundMessage.Request( + threadContext, + new String[0], + transportMessage, + version, + action, + requestId, + handshake, + compress + ); + } else { + message = new NativeOutboundMessage.Response( + threadContext, + Collections.emptySet(), + transportMessage, + version, + requestId, + handshake, + compress + ); + } + + return message.serialize(new BytesStreamOutput()); + } + +} diff --git a/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundHandlerTests.java new file mode 100644 index 0000000000000..ec0c1a50d5560 --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundHandlerTests.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.transport.InboundHandlerTests; + +import java.io.IOException; + +public class NativeInboundHandlerTests extends InboundHandlerTests { + + @Override + public BytesReference serializeOutboundRequest( + ThreadContext threadContext, + Writeable message, + Version version, + String action, + long requestId, + boolean compress, + boolean handshake + ) throws IOException { + NativeOutboundMessage.Request request = new NativeOutboundMessage.Request( + threadContext, + new String[0], + message, + version, + action, + requestId, + handshake, + compress + ); + return request.serialize(new BytesStreamOutput()); + } + +} diff --git a/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundPipelineTests.java new file mode 100644 index 0000000000000..6b5bf46eee3ae --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundPipelineTests.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.transport.InboundPipelineTests; +import org.opensearch.transport.TestRequest; +import org.opensearch.transport.TestResponse; + +import java.io.IOException; +import java.util.Collections; + +public class NativeInboundPipelineTests extends InboundPipelineTests { + + @Override + protected BytesReference serialize( + boolean isRequest, + Version version, + boolean handshake, + boolean compress, + String action, + long requestId, + String value + ) throws IOException { + NativeOutboundMessage message; + if (isRequest) { + message = new NativeOutboundMessage.Request( + threadContext, + new String[0], + new TestRequest(value), + version, + action, + requestId, + handshake, + compress + ); + } else { + message = new NativeOutboundMessage.Response( + threadContext, + Collections.emptySet(), + new TestResponse(value), + version, + requestId, + handshake, + compress + ); + } + + return message.serialize(new BytesStreamOutput()); + } + +} diff --git a/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeOutboundMessageTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeOutboundMessageTests.java new file mode 100644 index 0000000000000..75c4e84b4456e --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeOutboundMessageTests.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TestRequest; + +import java.io.IOException; + +public class NativeOutboundMessageTests extends OpenSearchTestCase { + + public void testNativeOutboundMessageRequestSerialization() throws IOException { + NativeOutboundMessage.Request message = new NativeOutboundMessage.Request( + new ThreadContext(Settings.EMPTY), + new String[0], + new TestRequest("content"), + Version.CURRENT, + "action", + 1, + false, + false + ); + BytesStreamOutput output = new BytesStreamOutput(); + message.serialize(output); + + BytesStreamInput input = new BytesStreamInput(output.bytes().toBytesRef().bytes); + assertEquals(Version.CURRENT, input.getVersion()); + // reading header details + assertEquals((byte) 'E', input.readByte()); + assertEquals((byte) 'S', input.readByte()); + assertNotEquals(0, input.readInt()); + assertEquals(1, input.readLong()); + assertEquals(0, input.readByte()); + assertEquals(Version.CURRENT.id, input.readInt()); + int variableHeaderSize = input.readInt(); + assertNotEquals(-1, variableHeaderSize); + } + +} diff --git a/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeTransportLoggerTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeTransportLoggerTests.java new file mode 100644 index 0000000000000..db75d6ff45556 --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeTransportLoggerTests.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.action.admin.cluster.stats.ClusterStatsAction; +import org.opensearch.action.admin.cluster.stats.ClusterStatsRequest; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.test.junit.annotations.TestLogging; +import org.opensearch.transport.TransportLoggerTests; + +import java.io.IOException; + +@TestLogging(value = "org.opensearch.transport.TransportLogger:trace", reason = "to ensure we log network events on TRACE level") +public class NativeTransportLoggerTests extends TransportLoggerTests { + + public BytesReference buildRequest() throws IOException { + boolean compress = randomBoolean(); + try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { + NativeOutboundMessage.Request request = new NativeOutboundMessage.Request( + new ThreadContext(Settings.EMPTY), + new String[0], + new ClusterStatsRequest(), + Version.CURRENT, + ClusterStatsAction.NAME, + randomInt(30), + false, + compress + ); + return request.serialize(new BytesStreamOutput()); + } + } +} diff --git a/test/framework/src/main/java/org/opensearch/transport/TestRequest.java b/test/framework/src/main/java/org/opensearch/transport/TestRequest.java index 2fe917235e948..0bb5a6e16fff1 100644 --- a/test/framework/src/main/java/org/opensearch/transport/TestRequest.java +++ b/test/framework/src/main/java/org/opensearch/transport/TestRequest.java @@ -54,4 +54,8 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(value); } + + public String getValue() { + return value; + } } diff --git a/test/framework/src/main/java/org/opensearch/transport/TestResponse.java b/test/framework/src/main/java/org/opensearch/transport/TestResponse.java index 14db8b3372bf2..7fd7c760c9cf6 100644 --- a/test/framework/src/main/java/org/opensearch/transport/TestResponse.java +++ b/test/framework/src/main/java/org/opensearch/transport/TestResponse.java @@ -54,4 +54,8 @@ public TestResponse(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeString(value); } + + public String getValue() { + return value; + } }