diff --git a/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java b/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java index 69be6cbecc96a..37bfcef434582 100644 --- a/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java +++ b/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java @@ -46,6 +46,8 @@ public abstract class TransportMessage implements Writeable, ProtobufWriteable { private TransportAddress remoteAddress; + private boolean isProtobuf; + public void remoteAddress(TransportAddress remoteAddress) { this.remoteAddress = remoteAddress; } @@ -54,6 +56,10 @@ public TransportAddress remoteAddress() { return remoteAddress; } + public boolean isMessageProtobuf() { + return isProtobuf; + } + /** * Constructs a new empty transport message */ diff --git a/modules/transport-netty4/src/main/java/org/opensearch/transport/netty4/Netty4MessageChannelHandler.java b/modules/transport-netty4/src/main/java/org/opensearch/transport/netty4/Netty4MessageChannelHandler.java index 7b9999ce5b20e..a05c15d689b22 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/transport/netty4/Netty4MessageChannelHandler.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/transport/netty4/Netty4MessageChannelHandler.java @@ -78,7 +78,8 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { threadPool::relativeTimeInMillis, transport.getInflightBreaker(), requestHandlers::getHandler, - transport::inboundMessage + transport::inboundMessage, + transport::inboundMessageProtobuf ); } diff --git a/plugins/transport-nio/src/main/java/org/opensearch/transport/nio/TcpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/opensearch/transport/nio/TcpReadWriteHandler.java index 0c90deed6411c..545f03e03761e 100644 --- a/plugins/transport-nio/src/main/java/org/opensearch/transport/nio/TcpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/opensearch/transport/nio/TcpReadWriteHandler.java @@ -68,7 +68,8 @@ public TcpReadWriteHandler(NioTcpChannel channel, PageCacheRecycler recycler, Tc threadPool::relativeTimeInMillis, breaker, requestHandlers::getHandler, - transport::inboundMessage + transport::inboundMessage, + transport::inboundMessageProtobuf ); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchExecutionStatsCollector.java b/server/src/main/java/org/opensearch/action/search/SearchExecutionStatsCollector.java index 842e87b3eb635..5fb838cb6c70d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchExecutionStatsCollector.java +++ b/server/src/main/java/org/opensearch/action/search/SearchExecutionStatsCollector.java @@ -70,8 +70,10 @@ public static BiFunction reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new; final ActionListener handler = responseWrapper.apply(connection, listener); diff --git a/server/src/main/java/org/opensearch/search/SearchHit.java b/server/src/main/java/org/opensearch/search/SearchHit.java index 10e65fca3afb5..d8d78be4c6b68 100644 --- a/server/src/main/java/org/opensearch/search/SearchHit.java +++ b/server/src/main/java/org/opensearch/search/SearchHit.java @@ -44,6 +44,7 @@ import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.ProtobufWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -66,9 +67,12 @@ import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.search.fetch.subphase.highlight.HighlightField; import org.opensearch.search.lookup.SourceLookup; +import org.opensearch.server.proto.FetchSearchResultProto; import org.opensearch.transport.RemoteClusterAware; import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -96,7 +100,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public final class SearchHit implements Writeable, ToXContentObject, Iterable { +public final class SearchHit implements Writeable, ToXContentObject, Iterable, ProtobufWriteable { private final transient int docId; @@ -137,6 +141,8 @@ public final class SearchHit implements Writeable, ToXContentObject, Iterable innerHits; + private FetchSearchResultProto.SearchHit searchHitProto; + // used only in tests public SearchHit(int docId) { this(docId, null, null, null); @@ -224,6 +230,23 @@ public SearchHit(StreamInput in) throws IOException { } } + public SearchHit(byte[] in) throws IOException { + this.searchHitProto = FetchSearchResultProto.SearchHit.parseFrom(in); + docId = -1; + score = this.searchHitProto.getScore(); + id = new Text(this.searchHitProto.getId()); + // Support for nestedIdentity to be added in the future + nestedIdentity = null; + version = this.searchHitProto.getVersion(); + seqNo = this.searchHitProto.getSeqNo(); + primaryTerm = this.searchHitProto.getPrimaryTerm(); + source = BytesReference.fromByteBuffer(ByteBuffer.wrap(this.searchHitProto.getSource().toByteArray())); + if (source.length() == 0) { + source = null; + } + metaFields = new HashMap<>(); + } + private Map readFields(StreamInput in) throws IOException { Map fields; int size = in.readVInt(); @@ -306,6 +329,11 @@ public void writeTo(StreamOutput out) throws IOException { } } + @Override + public void writeTo(OutputStream out) throws IOException { + out.write(this.searchHitProto.toByteArray()); + } + public int docId() { return this.docId; } diff --git a/server/src/main/java/org/opensearch/search/SearchHits.java b/server/src/main/java/org/opensearch/search/SearchHits.java index 8232643b353f5..2df927318b6c7 100644 --- a/server/src/main/java/org/opensearch/search/SearchHits.java +++ b/server/src/main/java/org/opensearch/search/SearchHits.java @@ -38,6 +38,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.lucene.Lucene; +import org.opensearch.core.common.io.stream.ProtobufWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -47,6 +48,7 @@ import org.opensearch.rest.action.search.RestSearchAction; import java.io.IOException; +import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; @@ -61,7 +63,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public final class SearchHits implements Writeable, ToXContentFragment, Iterable { +public final class SearchHits implements Writeable, ToXContentFragment, Iterable, ProtobufWriteable { public static SearchHits empty() { return empty(true); } @@ -82,6 +84,8 @@ public static SearchHits empty(boolean withTotalHits) { @Nullable private final Object[] collapseValues; + private org.opensearch.server.proto.FetchSearchResultProto.SearchHits searchHitsProto; + public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore) { this(hits, totalHits, maxScore, null, null, null); } @@ -124,6 +128,23 @@ public SearchHits(StreamInput in) throws IOException { collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new); } + public SearchHits(byte[] in) throws IOException { + this.searchHitsProto = org.opensearch.server.proto.FetchSearchResultProto.SearchHits.parseFrom(in); + this.hits = new SearchHit[this.searchHitsProto.getHitsCount()]; + for (int i = 0; i < this.searchHitsProto.getHitsCount(); i++) { + this.hits[i] = new SearchHit(this.searchHitsProto.getHits(i).toByteArray()); + } + this.totalHits = new TotalHits( + this.searchHitsProto.getTotalHits().getValue(), + Relation.valueOf(this.searchHitsProto.getTotalHits().getRelation().toString()) + ); + this.maxScore = this.searchHitsProto.getMaxScore(); + this.collapseField = this.searchHitsProto.getCollapseField(); + // Below fields are set to null currently, support to be added in the future + this.collapseValues = null; + this.sortFields = null; + } + @Override public void writeTo(StreamOutput out) throws IOException { final boolean hasTotalHits = totalHits != null; @@ -342,4 +363,9 @@ private static Relation parseRelation(String relation) { throw new IllegalArgumentException("invalid total hits relation: " + relation); } } + + @Override + public void writeTo(OutputStream out) throws IOException { + out.write(searchHitsProto.toByteArray()); + } } diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java index 667b022091a85..76368f3f2de07 100644 --- a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java @@ -32,7 +32,10 @@ package org.opensearch.search.fetch; +import com.google.protobuf.ByteString; +import org.apache.lucene.search.TotalHits.Relation; import org.opensearch.common.annotation.PublicApi; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchHit; @@ -76,6 +79,7 @@ public FetchSearchResult(byte[] in) throws IOException { this.fetchSearchResultProto.getContextId().getSessionId(), this.fetchSearchResultProto.getContextId().getId() ); + hits = new SearchHits(this.fetchSearchResultProto.getHits().toByteArray()); } public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) { @@ -101,6 +105,30 @@ public FetchSearchResult fetchResult() { public void hits(SearchHits hits) { assert assertNoSearchTarget(hits); this.hits = hits; + if (this.fetchSearchResultProto != null) { + QuerySearchResultProto.TotalHits.Builder totalHitsBuilder = QuerySearchResultProto.TotalHits.newBuilder(); + totalHitsBuilder.setValue(hits.getTotalHits().value); + totalHitsBuilder.setRelation( + hits.getTotalHits().relation == Relation.EQUAL_TO + ? QuerySearchResultProto.TotalHits.Relation.EQUAL_TO + : QuerySearchResultProto.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + ); + FetchSearchResultProto.SearchHits.Builder searchHitsBuilder = FetchSearchResultProto.SearchHits.newBuilder(); + searchHitsBuilder.setMaxScore(hits.getMaxScore()); + searchHitsBuilder.setTotalHits(totalHitsBuilder.build()); + for (SearchHit hit : hits.getHits()) { + FetchSearchResultProto.SearchHit.Builder searchHitBuilder = FetchSearchResultProto.SearchHit.newBuilder(); + searchHitBuilder.setIndex(hit.getIndex()); + searchHitBuilder.setId(hit.getId()); + searchHitBuilder.setScore(hit.getScore()); + searchHitBuilder.setSeqNo(hit.getSeqNo()); + searchHitBuilder.setPrimaryTerm(hit.getPrimaryTerm()); + searchHitBuilder.setVersion(hit.getVersion()); + searchHitBuilder.setSource(ByteString.copyFrom(BytesReference.toBytes(hit.getSourceRef()))); + searchHitsBuilder.addHits(searchHitBuilder.build()); + } + this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder().setHits(searchHitsBuilder.build()).build(); + } } private boolean assertNoSearchTarget(SearchHits hits) { diff --git a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java index b57434a99edb4..6e0c0bf20e7ac 100644 --- a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java @@ -32,6 +32,7 @@ package org.opensearch.search.fetch; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchPhaseResult; @@ -117,4 +118,23 @@ public void writeTo(StreamOutput out) throws IOException { queryResult.writeTo(out); fetchResult.writeTo(out); } + + @Override + public boolean isMessageProtobuf() { + // System.setProperty(FeatureFlags.PROTOBUF, "true"); + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return true; + } + return false; + } + + public QueryFetchSearchResultProto.QueryFetchSearchResult response() { + return this.queryFetchSearchResultProto; + } + + public QueryFetchSearchResult(QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResult) { + this.queryFetchSearchResultProto = queryFetchSearchResult; + this.queryResult = new QuerySearchResult(queryFetchSearchResult.getQueryResult()); + this.fetchResult = new FetchSearchResult(queryFetchSearchResult.getFetchResult()); + } } diff --git a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java index b181d6c2462fb..f786c00d05377 100644 --- a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java @@ -34,7 +34,9 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TotalHits.Relation; import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.io.stream.DelayableWriteable; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; @@ -128,28 +130,27 @@ public QuerySearchResult( isNull = false; setShardSearchRequest(shardSearchRequest); - if (FeatureFlags.PROTOBUF_SETTING.get(settings)) { - QuerySearchResultProto.ShardId shardIdProto = QuerySearchResultProto.ShardId.newBuilder() - .setShardId(shardTarget.getShardId().getId()) - .setHashCode(shardTarget.getShardId().hashCode()) - .setIndexName(shardTarget.getShardId().getIndexName()) - .setIndexUUID(shardTarget.getShardId().getIndex().getUUID()) - .build(); - QuerySearchResultProto.SearchShardTarget searchShardTarget = QuerySearchResultProto.SearchShardTarget.newBuilder() - .setNodeId(shardTarget.getNodeId()) - .setShardId(shardIdProto) - .setClusterAlias(shardTarget.getClusterAlias()) - .build(); - this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.newBuilder() - .setContextId( - QuerySearchResultProto.ShardSearchContextId.newBuilder() - .setSessionId(contextId.getSessionId()) - .setId(contextId.getId()) - .build() - ) - .setSearchShardTarget(searchShardTarget) - .build(); + QuerySearchResultProto.ShardId shardIdProto = QuerySearchResultProto.ShardId.newBuilder() + .setShardId(shardTarget.getShardId().getId()) + .setHashCode(shardTarget.getShardId().hashCode()) + .setIndexName(shardTarget.getShardId().getIndexName()) + .setIndexUUID(shardTarget.getShardId().getIndex().getUUID()) + .build(); + QuerySearchResultProto.SearchShardTarget.Builder searchShardTarget = QuerySearchResultProto.SearchShardTarget.newBuilder() + .setNodeId(shardTarget.getNodeId()) + .setShardId(shardIdProto); + if (shardTarget.getClusterAlias() != null) { + searchShardTarget.setClusterAlias(shardTarget.getClusterAlias()); } + this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.newBuilder() + .setContextId( + QuerySearchResultProto.ShardSearchContextId.newBuilder() + .setSessionId(contextId.getSessionId()) + .setId(contextId.getId()) + .build() + ) + .setSearchShardTarget(searchShardTarget.build()) + .build(); } private QuerySearchResult(boolean isNull) { @@ -199,9 +200,33 @@ public Boolean terminatedEarly() { } public TopDocsAndMaxScore topDocs() { - if (topDocsAndMaxScore == null) { + if (topDocsAndMaxScore == null && this.querySearchResultProto.getTopDocsAndMaxScore() == null) { throw new IllegalStateException("topDocs already consumed"); } + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + ScoreDoc[] scoreDocs = new ScoreDoc[this.querySearchResultProto.getTopDocsAndMaxScore().getTopDocs().getScoreDocsCount()]; + for (int i = 0; i < scoreDocs.length; i++) { + org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocs.ScoreDoc scoreDoc = this.querySearchResultProto + .getTopDocsAndMaxScore() + .getTopDocs() + .getScoreDocsList() + .get(i); + scoreDocs[i] = new ScoreDoc(scoreDoc.getDoc(), scoreDoc.getScore(), scoreDoc.getShardIndex()); + } + TopDocs topDocsFromProtobuf = new TopDocs( + new TotalHits( + this.querySearchResultProto.getTotalHits().getValue(), + Relation.valueOf(this.querySearchResultProto.getTotalHits().getRelation().toString()) + ), + scoreDocs + ); + + TopDocsAndMaxScore topDocsFromProtobufAndMaxScore = new TopDocsAndMaxScore( + topDocsFromProtobuf, + this.querySearchResultProto.getMaxScore() + ); + return topDocsFromProtobufAndMaxScore; + } return topDocsAndMaxScore; } @@ -217,7 +242,7 @@ public boolean hasConsumedTopDocs() { * @throws IllegalStateException if the top docs have already been consumed. */ public TopDocsAndMaxScore consumeTopDocs() { - TopDocsAndMaxScore topDocsAndMaxScore = this.topDocsAndMaxScore; + TopDocsAndMaxScore topDocsAndMaxScore = this.topDocsAndMaxScore == null ? topDocs() : this.topDocsAndMaxScore; if (topDocsAndMaxScore == null) { throw new IllegalStateException("topDocs already consumed"); } @@ -273,7 +298,7 @@ private void setTopDocs(TopDocsAndMaxScore topDocsAndMaxScore) { .setMaxScore(topDocsAndMaxScore.maxScore) .setTopDocs(topDocsBuilder) .build(); - this.querySearchResultProto.toBuilder() + this.querySearchResultProto = this.querySearchResultProto.toBuilder() .setTopDocsAndMaxScore(topDocsAndMaxScoreBuilder) .setMaxScore(this.maxScore) .setTotalHits(topDocsBuilder.getTotalHits()) diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index a8315c3cae4e0..d6ec2770e4176 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -37,15 +37,18 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.BytesRef; import org.opensearch.Version; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.ByteBufferStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.transport.TransportResponse; +import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; import org.opensearch.telemetry.tracing.SpanScope; @@ -60,6 +63,7 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** @@ -128,6 +132,13 @@ void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception } } + void inboundMessageProtobuf(TcpChannel channel, BytesReference message) throws IOException { + final long startTime = threadPool.relativeTimeInMillis(); + channel.getChannelStats().markAccessed(startTime); + NodeToNodeMessage protobufMessage = new NodeToNodeMessage(BytesReference.toBytes(message)); + messageReceivedProtobuf(channel, protobufMessage, startTime); + } + // Empty stream constant to avoid instantiating a new stream for empty messages. private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES)); @@ -192,6 +203,41 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st } } + private void messageReceivedProtobuf(TcpChannel channel, NodeToNodeMessage message, long startTime) throws IOException { + final InetSocketAddress remoteAddress = channel.getRemoteAddress(); + final org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header header = message.getHeader(); + + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + // Place the context with the headers from the message + final Tuple, Map>> headers = new Tuple, Map>>( + message.getRequestHeaders(), + message.getResponseHandlers() + ); + threadContext.setHeaders(headers); + threadContext.putTransient("_remote_address", remoteAddress); + + long requestId = header.getRequestId(); + TransportResponseHandler handler = responseHandlers.onResponseReceived(requestId, messageListener); + if (handler != null) { + // if (handler.toString().contains("Protobuf")) { + handleProtobufResponse(requestId, remoteAddress, message, handler); + // } + } + } finally { + final long took = threadPool.relativeTimeInMillis() - startTime; + final long logThreshold = slowLogThresholdMs; + if (logThreshold > 0 && took > logThreshold) { + logger.warn( + "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]", + message, + took, + logThreshold + ); + } + } + } + private Map> extractHeaders(Map headers) { return headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> Collections.singleton(e.getValue()))); } @@ -415,6 +461,39 @@ private void handleResponse( } } + private void handleProtobufResponse( + final long requestId, + InetSocketAddress remoteAddress, + final NodeToNodeMessage message, + final TransportResponseHandler handler + ) throws IOException { + try { + org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage receivedMessage = message.getMessage(); + if (receivedMessage.hasQueryFetchSearchResult()) { + final QueryFetchSearchResult queryFetchSearchResult = receivedMessage.getQueryFetchSearchResult(); + org.opensearch.search.fetch.QueryFetchSearchResult queryFetchSearchResult2 = + new org.opensearch.search.fetch.QueryFetchSearchResult(queryFetchSearchResult); + final T response = (T) queryFetchSearchResult2; + response.remoteAddress(new TransportAddress(remoteAddress)); + + final String executor = handler.executor(); + if (ThreadPool.Names.SAME.equals(executor)) { + doHandleResponse(handler, response); + } else { + threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); + } + } + } catch (Exception e) { + final Exception serializationException = new TransportSerializationException( + "Failed to deserialize response from handler [" + handler + "]", + e + ); + logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException); + handleException(handler, serializationException); + return; + } + } + private void doHandleResponse(TransportResponseHandler handler, T response) { try { handler.handleResponse(response); diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index dd4690e5e6abf..d1410c5272e76 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -38,6 +38,7 @@ import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.bytes.CompositeBytesReference; import java.io.IOException; @@ -63,6 +64,7 @@ public class InboundPipeline implements Releasable { private final InboundDecoder decoder; private final InboundAggregator aggregator; private final BiConsumer messageHandler; + private final BiConsumer messageHandlerProtobuf; private Exception uncaughtException; private final ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; @@ -74,14 +76,16 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, Supplier circuitBreaker, Function> registryFunction, - BiConsumer messageHandler + BiConsumer messageHandler, + BiConsumer messageHandlerProtobuf ) { this( statsTracker, relativeTimeInMillis, new InboundDecoder(version, recycler), new InboundAggregator(circuitBreaker, registryFunction), - messageHandler + messageHandler, + messageHandlerProtobuf ); } @@ -90,13 +94,15 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, InboundDecoder decoder, InboundAggregator aggregator, - BiConsumer messageHandler + BiConsumer messageHandler, + BiConsumer messageHandlerProtobuf ) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; this.messageHandler = messageHandler; + this.messageHandlerProtobuf = messageHandlerProtobuf; } @Override @@ -120,41 +126,49 @@ public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) } public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { - channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); - statsTracker.markBytesRead(reference.length()); - pending.add(reference.retain()); - - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + try { + byte[] incomingBytes = BytesReference.toBytes(reference); + NodeToNodeMessage protobufMessage = new NodeToNodeMessage(incomingBytes); + if (protobufMessage.isProtobuf()) { + forwardFragmentsProtobuf(channel, reference); + } + } catch (Exception e) { + channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); + statsTracker.markBytesRead(reference.length()); + pending.add(reference.retain()); + + final ArrayList fragments = fragmentList.get(); + boolean continueHandling = true; + + while (continueHandling && isClosed == false) { + boolean continueDecoding = true; + while (continueDecoding && pending.isEmpty() == false) { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + final int bytesDecoded = decoder.decode(toDecode, fragments::add); + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); + if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + continueDecoding = false; + } + } else { continueDecoding = false; } - } else { - continueDecoding = false; } } - } - if (fragments.isEmpty()) { - continueHandling = false; - } else { - try { - forwardFragments(channel, fragments); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); + if (fragments.isEmpty()) { + continueHandling = false; + } else { + try { + forwardFragments(channel, fragments); + } finally { + for (Object fragment : fragments) { + if (fragment instanceof ReleasableBytesReference) { + ((ReleasableBytesReference) fragment).close(); + } } + fragments.clear(); } - fragments.clear(); } } } @@ -182,6 +196,10 @@ private void forwardFragments(TcpChannel channel, ArrayList fragments) t } } + private void forwardFragmentsProtobuf(TcpChannel channel, ReleasableBytesReference reference) throws IOException { + messageHandlerProtobuf.accept(channel, reference); + } + private boolean endOfMessage(Object fragment) { return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; } diff --git a/server/src/main/java/org/opensearch/transport/NodeToNodeMessage.java b/server/src/main/java/org/opensearch/transport/NodeToNodeMessage.java new file mode 100644 index 0000000000000..468a9327b86e5 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/NodeToNodeMessage.java @@ -0,0 +1,118 @@ +/* +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +package org.opensearch.transport; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import org.opensearch.Version; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.server.proto.NodeToNodeMessageProto; +import org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header; +import org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.ResponseHandlersList; +import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Outbound data as a message +* +* @opensearch.internal +*/ +public class NodeToNodeMessage { + + private final NodeToNodeMessageProto.NodeToNodeMessage message; + private static final byte[] PREFIX = { (byte) 'E', (byte) 'S' }; + + public NodeToNodeMessage( + long requestId, + byte[] status, + Version version, + ThreadContext threadContext, + QueryFetchSearchResult queryFetchSearchResult, + Set features, + String action + ) { + Header header = Header.newBuilder() + .addAllPrefix(Arrays.asList(ByteString.copyFrom(PREFIX))) + .setRequestId(requestId) + .setStatus(ByteString.copyFrom(status)) + .setVersionId(version.id) + .build(); + Map requestHeaders = threadContext.getHeaders(); + Map> responseHeaders = threadContext.getResponseHeaders(); + Map responseHandlers = new HashMap<>(); + for (Map.Entry> entry : responseHeaders.entrySet()) { + String key = entry.getKey(); + List value = entry.getValue(); + ResponseHandlersList responseHandlersList = ResponseHandlersList.newBuilder().addAllSetOfResponseHandlers(value).build(); + responseHandlers.put(key, responseHandlersList); + } + this.message = NodeToNodeMessageProto.NodeToNodeMessage.newBuilder() + .setHeader(header) + .putAllRequestHeaders(requestHeaders) + .putAllResponseHandlers(responseHandlers) + .setVersion(version.toString()) + .setStatus(ByteString.copyFrom(status)) + .setRequestId(requestId) + .setQueryFetchSearchResult(queryFetchSearchResult) + .setAction(action) + .addAllFeatures(features) + .setIsProtobuf(true) + .build(); + + } + + public NodeToNodeMessage(byte[] data) throws InvalidProtocolBufferException { + this.message = NodeToNodeMessageProto.NodeToNodeMessage.parseFrom(data); + } + + public void writeTo(OutputStream out) throws IOException { + out.write(this.message.toByteArray()); + } + + public NodeToNodeMessageProto.NodeToNodeMessage getMessage() { + return this.message; + } + + @Override + public String toString() { + return "NodeToNodeMessage [message=" + message + "]"; + } + + public org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header getHeader() { + return this.message.getHeader(); + } + + public Map getRequestHeaders() { + return this.message.getRequestHeadersMap(); + } + + public Map> getResponseHandlers() { + Map responseHandlers = this.message.getResponseHandlersMap(); + Map> responseHandlersMap = new HashMap<>(); + for (Map.Entry entry : responseHandlers.entrySet()) { + String key = entry.getKey(); + ResponseHandlersList value = entry.getValue(); + Set setOfResponseHandlers = value.getSetOfResponseHandlersList().stream().collect(Collectors.toSet()); + responseHandlersMap.put(key, setOfResponseHandlers); + } + return responseHandlersMap; + } + + public boolean isProtobuf() { + return this.message.getIsProtobuf(); + } +} diff --git a/server/src/main/java/org/opensearch/transport/OutboundHandler.java b/server/src/main/java/org/opensearch/transport/OutboundHandler.java index b83dbdd0effe4..bb102c47c3085 100644 --- a/server/src/main/java/org/opensearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/OutboundHandler.java @@ -38,6 +38,7 @@ import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.CheckedSupplier; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.ReleasableBytesStreamOutput; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; @@ -51,9 +52,11 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.transport.TransportResponse; +import org.opensearch.search.fetch.QueryFetchSearchResult; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Set; /** @@ -146,17 +149,35 @@ void sendResponse( final boolean isHandshake ) throws IOException { Version version = Version.min(this.version, nodeVersion); - OutboundMessage.Response message = new OutboundMessage.Response( - threadPool.getThreadContext(), - features, - response, - version, - requestId, - isHandshake, - compress - ); ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); - sendMessage(channel, message, listener); + if (response.isMessageProtobuf()) { + QueryFetchSearchResult queryFetchSearchResult = (QueryFetchSearchResult) response; + if (queryFetchSearchResult.response() != null) { + byte[] bytes = new byte[1]; + bytes[0] = 1; + NodeToNodeMessage protobufMessage = new NodeToNodeMessage( + requestId, + bytes, + Version.CURRENT, + threadPool.getThreadContext(), + queryFetchSearchResult.response(), + features, + action + ); + sendProtobufMessage(channel, protobufMessage, listener); + } + } else { + OutboundMessage.Response message = new OutboundMessage.Response( + threadPool.getThreadContext(), + features, + response, + version, + requestId, + isHandshake, + compress + ); + sendMessage(channel, message, listener); + } } /** @@ -192,6 +213,12 @@ private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, Act internalSend(channel, sendContext); } + private void sendProtobufMessage(TcpChannel channel, NodeToNodeMessage message, ActionListener listener) throws IOException { + ProtobufMessageSerializer serializer = new ProtobufMessageSerializer(message, bigArrays); + SendContext sendContext = new SendContext(channel, serializer, listener, serializer); + internalSend(channel, sendContext); + } + private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); BytesReference reference = sendContext.get(); @@ -241,6 +268,36 @@ public void close() { } } + private static class ProtobufMessageSerializer implements CheckedSupplier, Releasable { + + private final NodeToNodeMessage message; + private final BigArrays bigArrays; + private volatile ReleasableBytesStreamOutput bytesStreamOutput; + + private ProtobufMessageSerializer(NodeToNodeMessage message, BigArrays bigArrays) { + this.message = message; + this.bigArrays = bigArrays; + } + + @Override + public BytesReference get() throws IOException { + bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays); + BytesReference reference = serialize(bytesStreamOutput); + return reference; + } + + private BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { + ByteBuffer byteBuffers = ByteBuffer.wrap(message.getMessage().toByteArray()); + message.getMessage().writeTo(bytesStream); + return BytesReference.fromByteBuffer(byteBuffers); + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(bytesStreamOutput); + } + } + private class SendContext extends NotifyOnceListener implements CheckedSupplier { private final TcpChannel channel; diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index d0e6516973382..7f7f023411916 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -773,6 +773,14 @@ public void inboundMessage(TcpChannel channel, InboundMessage message) { } } + public void inboundMessageProtobuf(TcpChannel channel, BytesReference message) { + try { + inboundHandler.inboundMessageProtobuf(channel, message); + } catch (Exception e) { + onException(channel, e); + } + } + /** * Validates the first 6 bytes of the message header and returns the length of the message. If 6 bytes * are not available, it returns -1. diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index b84f85cedfeb8..c06aa36ea6bc5 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -1505,8 +1505,7 @@ void setTimeoutHandler(TimeoutHandler handler) { @Override public T read(byte[] in) throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'read'"); + return delegate.read(in); } } @@ -1725,8 +1724,7 @@ public String toString() { @Override public T read(byte[] in) throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'read'"); + return handler.read(in); } }; } else { diff --git a/server/src/main/proto/server/NodeToNodeMessageProto.proto b/server/src/main/proto/server/NodeToNodeMessageProto.proto new file mode 100644 index 0000000000000..9348941cd2345 --- /dev/null +++ b/server/src/main/proto/server/NodeToNodeMessageProto.proto @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/search/QueryFetchSearchResultProto.proto"; + +option java_outer_classname = "NodeToNodeMessageProto"; + +message NodeToNodeMessage { + Header header = 1; + map requestHeaders = 2; + map responseHandlers = 3; + string version = 4; + repeated string features = 5; + bytes status = 6; + string action = 7; + int64 requestId = 8; + oneof message { + QueryFetchSearchResult queryFetchSearchResult = 16; + } + bool isProtobuf = 17; + + message Header { + repeated bytes prefix = 1; + int64 requestId = 2; + bytes status = 3; + int32 versionId = 4; + } + + message ResponseHandlersList { + repeated string setOfResponseHandlers = 1; + } +} diff --git a/server/src/main/proto/server/search/QuerySearchResultProto.proto b/server/src/main/proto/server/search/QuerySearchResultProto.proto index 8caf561bf2bdd..f61e144501a71 100644 --- a/server/src/main/proto/server/search/QuerySearchResultProto.proto +++ b/server/src/main/proto/server/search/QuerySearchResultProto.proto @@ -62,7 +62,7 @@ message QuerySearchResult { message SearchShardTarget { string nodeId = 1; ShardId shardId = 2; - string clusterAlias = 3; + optional string clusterAlias = 3; } message TotalHits { diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index ae4b537223394..af768809e9970 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -94,6 +94,7 @@ public void testPipelineHandling() throws IOException { throw new AssertionError(e); } }; + final BiConsumer messageHandlerProtobuf = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); @@ -104,7 +105,14 @@ public void testPipelineHandling() throws IOException { final TestCircuitBreaker circuitBreaker = new TestCircuitBreaker(); circuitBreaker.startBreaking(); final InboundAggregator aggregator = new InboundAggregator(() -> circuitBreaker, canTripBreaker); - final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + final InboundPipeline pipeline = new InboundPipeline( + statsTracker, + millisSupplier, + decoder, + aggregator, + messageHandler, + messageHandlerProtobuf + ); final FakeTcpChannel channel = new FakeTcpChannel(); final int iterations = randomIntBetween(5, 10); @@ -215,12 +223,20 @@ public void testPipelineHandling() throws IOException { public void testDecodeExceptionIsPropagated() throws IOException { BiConsumer messageHandler = (c, m) -> {}; + final BiConsumer messageHandlerProtobuf = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + final InboundPipeline pipeline = new InboundPipeline( + statsTracker, + millisSupplier, + decoder, + aggregator, + messageHandler, + messageHandlerProtobuf + ); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; @@ -269,12 +285,20 @@ public void testDecodeExceptionIsPropagated() throws IOException { public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { BiConsumer messageHandler = (c, m) -> {}; + final BiConsumer messageHandlerProtobuf = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + final InboundPipeline pipeline = new InboundPipeline( + statsTracker, + millisSupplier, + decoder, + aggregator, + messageHandler, + messageHandlerProtobuf + ); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index ff99435f765d8..32ceb123f0e37 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -76,6 +76,7 @@ public class OutboundHandlerTests extends OpenSearchTestCase { private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); private final TransportRequestOptions options = TransportRequestOptions.EMPTY; private final AtomicReference> message = new AtomicReference<>(); + private final AtomicReference protobufMessage = new AtomicReference<>(); private InboundPipeline pipeline; private OutboundHandler handler; private FakeTcpChannel channel; @@ -102,7 +103,7 @@ public void setUp() throws Exception { } catch (IOException e) { throw new AssertionError(e); } - }); + }, (c, m) -> { protobufMessage.set(m); }); } @After diff --git a/test/framework/src/main/java/org/opensearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/opensearch/transport/nio/MockNioTransport.java index cd6bf02efef6f..76eac48bc1a6d 100644 --- a/test/framework/src/main/java/org/opensearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/opensearch/transport/nio/MockNioTransport.java @@ -327,7 +327,8 @@ private MockTcpReadWriteHandler(MockSocketChannel channel, PageCacheRecycler rec threadPool::relativeTimeInMillis, breaker, requestHandlers::getHandler, - transport::inboundMessage + transport::inboundMessage, + transport::inboundMessageProtobuf ); }