diff --git a/distribution/src/config/opensearch.yml b/distribution/src/config/opensearch.yml index 10bab9b3fce92..e0e623652d015 100644 --- a/distribution/src/config/opensearch.yml +++ b/distribution/src/config/opensearch.yml @@ -125,3 +125,7 @@ ${path.logs} # Gates the functionality of enabling Opensearch to use pluggable caches with respective store names via setting. # #opensearch.experimental.feature.pluggable.caching.enabled: false +# +# Gates the functionality of enabling Opensearch to use protobuf with basic searches and for node-to-node communication. +# +#opensearch.experimental.feature.search_with_protobuf.enabled: false diff --git a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java index 569f48a8465f3..6264d00f01887 100644 --- a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java +++ b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java @@ -238,7 +238,15 @@ private boolean inspectable(ExecutableElement executable) { */ private boolean inspectable(Element element) { final PackageElement pckg = processingEnv.getElementUtils().getPackageOf(element); - return pckg.getQualifiedName().toString().startsWith(OPENSEARCH_PACKAGE); + return pckg.getQualifiedName().toString().startsWith(OPENSEARCH_PACKAGE) + && !element.getEnclosingElement() + .getAnnotationMirrors() + .stream() + .anyMatch( + m -> m.getAnnotationType() + .toString() /* ClassSymbol.toString() returns class name */ + .equalsIgnoreCase("javax.annotation.Generated") + ); } /** diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/BytesWriteable.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BytesWriteable.java new file mode 100644 index 0000000000000..8b4f9a6aaf6ef --- /dev/null +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BytesWriteable.java @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.core.common.io.stream; + +import org.opensearch.common.annotation.ExperimentalApi; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * Implementers can be written to a {@linkplain OutputStream} and read from a byte array. This allows them to be "thrown + * across the wire" using OpenSearch's internal protocol with protobuf bytes. + * + * @opensearch.api + */ +@ExperimentalApi +public interface BytesWriteable { + + /** + * Write this into the {@linkplain OutputStream}. + */ + void writeTo(OutputStream out) throws IOException; + + /** + * Reference to a method that can write some object to a {@link OutputStream}. + * + * @opensearch.experimental + */ + @FunctionalInterface + @ExperimentalApi + interface Writer { + + /** + * Write {@code V}-type {@code value} to the {@code out}put stream. + * + * @param out Output to write the {@code value} too + * @param value The value to add + */ + void write(final OutputStream out, V value) throws IOException; + } + + /** + * Reference to a method that can read some object from a byte array. + * + * @opensearch.experimental + */ + @FunctionalInterface + @ExperimentalApi + interface Reader { + + /** + * Read {@code V}-type value from a byte array. + * + * @param in byte array to read the value from + */ + V read(final InputStream in) throws IOException; + } +} diff --git a/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java b/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java index 941babda40aa3..a50fc51e8e964 100644 --- a/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java +++ b/libs/core/src/main/java/org/opensearch/core/transport/TransportMessage.java @@ -32,19 +32,24 @@ package org.opensearch.core.transport; +import org.opensearch.core.common.io.stream.BytesWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.common.transport.TransportAddress; +import java.io.InputStream; + /** * Message over the transport interface * * @opensearch.internal */ -public abstract class TransportMessage implements Writeable { +public abstract class TransportMessage implements Writeable, BytesWriteable { private TransportAddress remoteAddress; + private String protocol; + public void remoteAddress(TransportAddress remoteAddress) { this.remoteAddress = remoteAddress; } @@ -53,6 +58,13 @@ public TransportAddress remoteAddress() { return remoteAddress; } + public String getProtocol() { + if (protocol != null) { + return protocol; + } + return "native"; + } + /** * Constructs a new empty transport message */ @@ -63,4 +75,10 @@ public TransportMessage() {} * currently a no-op */ public TransportMessage(StreamInput in) {} + + /** + * Constructs a new transport message with the data from the byte array. This is + * currently a no-op + */ + public TransportMessage(InputStream in) {} } diff --git a/libs/core/src/main/java/org/opensearch/core/transport/TransportResponse.java b/libs/core/src/main/java/org/opensearch/core/transport/TransportResponse.java index 4ae01e140a89c..3f0d2595365c7 100644 --- a/libs/core/src/main/java/org/opensearch/core/transport/TransportResponse.java +++ b/libs/core/src/main/java/org/opensearch/core/transport/TransportResponse.java @@ -37,6 +37,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; /** * Response over the transport interface @@ -60,6 +62,24 @@ public TransportResponse(StreamInput in) throws IOException { super(in); } + /** + * Constructs a new transport response with the data from the byte array. This is + * currently a no-op. However, this exists to allow extenders to call super(in) + * so that reading can mirror writing where we often call super.writeTo(out). + */ + public TransportResponse(InputStream in) throws IOException { + super(in); + } + + /** + * Writes this response to the {@linkplain OutputStream}. This is added here so that classes + * don't have to implement since it is an experimental feature and only being added for + * search apis incrementally. + */ + public void writeTo(OutputStream out) throws IOException { + // no-op + } + /** * Empty transport response * @@ -75,5 +95,8 @@ public String toString() { @Override public void writeTo(StreamOutput out) throws IOException {} + + @Override + public void writeTo(OutputStream out) throws IOException {} } } diff --git a/server/build.gradle b/server/build.gradle index cb48142a61159..9ea9f6a81425f 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -350,6 +350,32 @@ tasks.named("dependencyLicenses").configure { } } +tasks.named("missingJavadoc").configure { + /* + * annotate_code in L210 does not add the Generated annotation to nested code generated using protobuf. + * TODO: Add support to missingJavadoc task to ignore all such nested classes. + * https://github.com/opensearch-project/OpenSearch/issues/11913 + */ + dependsOn("generateProto") + javadocMissingIgnore = [ + "org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.RescoreDocIds.setIntegerOrBuilder", + "org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.RescoreDocIdsOrBuilder", + "org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocs.ScoreDocOrBuilder", + "org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocsOrBuilder", + "org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocsAndMaxScoreOrBuilder", + "org.opensearch.server.proto.FetchSearchResultProto.SearchHit.SearchSortValuesOrBuilder", + "org.opensearch.server.proto.FetchSearchResultProto.SearchHit.HighlightFieldOrBuilder", + "org.opensearch.server.proto.FetchSearchResultProto.SearchHit.DocumentFieldOrBuilder", + "org.opensearch.server.proto.FetchSearchResultProto.SearchHit.NestedIdentityOrBuilder", + "org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.MessageCase", + "org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.ResponseHandlersListOrBuilder", + "org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.HeaderOrBuilder", + "org.opensearch.server.proto.FetchSearchResultProto.SearchHit.Explanation.ExplanationValueCase", + "org.opensearch.server.proto.FetchSearchResultProto.SearchHit.ExplanationOrBuilder", + "org.opensearch.server.proto.ShardSearchRequestProto.OriginalIndices.IndicesOptionsOrBuilder", + ] +} + tasks.named("filepermissions").configure { mustRunAfter("generateProto") } @@ -364,6 +390,7 @@ tasks.named("licenseHeaders").configure { excludes << 'org/opensearch/client/documentation/placeholder.txt' // Ignore for protobuf generated code excludes << 'org/opensearch/extensions/proto/*' + excludes << 'org/opensearch/server/proto/*' } tasks.test { diff --git a/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java b/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java new file mode 100644 index 0000000000000..e62eaed1f51a4 --- /dev/null +++ b/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java @@ -0,0 +1,78 @@ +/* +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +package org.opensearch.action; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.BytesWriteable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportResponseHandler; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; + +/** + * A simple base class for action response listeners, defaulting to using the SAME executor (as its +* very common on response handlers). +* +* @opensearch.api +*/ +public class ProtobufActionListenerResponseHandler implements TransportResponseHandler { + + private final ActionListener listener; + private final BytesWriteable.Reader reader; + private final String executor; + + public ProtobufActionListenerResponseHandler( + ActionListener listener, + BytesWriteable.Reader reader, + String executor + ) { + this.listener = Objects.requireNonNull(listener); + this.reader = Objects.requireNonNull(reader); + this.executor = Objects.requireNonNull(executor); + } + + public ProtobufActionListenerResponseHandler(ActionListener listener, BytesWriteable.Reader reader) { + this(listener, reader, ThreadPool.Names.SAME); + } + + @Override + public void handleResponse(Response response) { + listener.onResponse(response); + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + + @Override + public String executor() { + return executor; + } + + @Override + public String toString() { + return super.toString() + "/" + listener; + } + + @Override + public Response read(StreamInput in) throws IOException { + throw new UnsupportedOperationException("Unimplemented method 'read'"); + } + + @Override + public Response read(InputStream in) throws IOException { + return reader.read(in); + } +} diff --git a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java index 64c738f633f2e..ce244883d457c 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java @@ -35,12 +35,15 @@ import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.IndicesRequest; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.ProtobufActionListenerResponseHandler; import org.opensearch.action.support.ChannelActionListener; import org.opensearch.action.support.IndicesOptions; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Nullable; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.BytesWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -67,6 +70,7 @@ import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; import java.io.IOException; @@ -241,16 +245,21 @@ public void sendExecuteQuery( // we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request // this used to be the QUERY_AND_FETCH which doesn't exist anymore. final boolean fetchDocuments = request.numberOfShards() == 1; - Writeable.Reader reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new; - final ActionListener handler = responseWrapper.apply(connection, listener); - transportService.sendChildRequest( - connection, - QUERY_ACTION_NAME, - request, - task, - new ConnectionCountingHandler<>(handler, reader, clientConnections, connection.getNode().getId()) - ); + TransportResponseHandler transportResponseHandler; + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + BytesWriteable.Reader reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new; + transportResponseHandler = new ProtobufConnectionCountingHandler<>( + handler, + reader, + clientConnections, + connection.getNode().getId() + ); + } else { + Writeable.Reader reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new; + transportResponseHandler = new ConnectionCountingHandler<>(handler, reader, clientConnections, connection.getNode().getId()); + } + transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task, transportResponseHandler); } public void sendExecuteQuery( @@ -775,4 +784,57 @@ private boolean assertNodePresent() { return true; } } + + /** + * A handler that counts connections for protobuf + * + * @opensearch.internal + */ + final class ProtobufConnectionCountingHandler extends ProtobufActionListenerResponseHandler< + Response> { + private final Map clientConnections; + private final String nodeId; + + ProtobufConnectionCountingHandler( + final ActionListener listener, + final BytesWriteable.Reader responseReader, + final Map clientConnections, + final String nodeId + ) { + super(listener, responseReader); + this.clientConnections = clientConnections; + this.nodeId = nodeId; + // Increment the number of connections for this node by one + clientConnections.compute(nodeId, (id, conns) -> conns == null ? 1 : conns + 1); + } + + @Override + public void handleResponse(Response response) { + super.handleResponse(response); + // Decrement the number of connections or remove it entirely if there are no more connections + // We need to remove the entry here so we don't leak when nodes go away forever + assert assertNodePresent(); + clientConnections.computeIfPresent(nodeId, (id, conns) -> conns.longValue() == 1 ? null : conns - 1); + } + + @Override + public void handleException(TransportException e) { + super.handleException(e); + // Decrement the number of connections or remove it entirely if there are no more connections + // We need to remove the entry here so we don't leak when nodes go away forever + assert assertNodePresent(); + clientConnections.computeIfPresent(nodeId, (id, conns) -> conns.longValue() == 1 ? null : conns - 1); + } + + private boolean assertNodePresent() { + clientConnections.compute(nodeId, (id, conns) -> { + assert conns != null : "number of connections for " + id + " is null, but should be an integer"; + assert conns >= 1 : "number of connections for " + id + " should be >= 1 but was " + conns; + return conns; + }); + // Always return true, there is additional asserting here, the boolean is just so this + // can be skipped when assertions are not enabled + return true; + } + } } diff --git a/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldProtobufSerializer.java b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldProtobufSerializer.java new file mode 100644 index 0000000000000..e606773d07826 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldProtobufSerializer.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.document.serializer; + +import com.google.protobuf.ByteString; +import org.opensearch.OpenSearchException; +import org.opensearch.common.document.DocumentField; +import org.opensearch.core.common.text.Text; +import org.opensearch.server.proto.FetchSearchResultProto; +import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue; +import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder; + +import java.io.IOException; +import java.io.InputStream; +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Serializer for {@link DocumentField} to/from protobuf. + */ +public class DocumentFieldProtobufSerializer implements DocumentFieldSerializer { + + private FetchSearchResultProto.SearchHit.DocumentField documentField; + + @Override + public DocumentField createDocumentField(InputStream inputStream) throws IOException { + documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(inputStream); + String name = documentField.getName(); + List values = new ArrayList<>(); + for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) { + values.add(readDocumentFieldValueFromProtobuf(value)); + } + return new DocumentField(name, values); + } + + private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException { + if (documentFieldValue.hasValueString()) { + return documentFieldValue.getValueString(); + } else if (documentFieldValue.hasValueInt()) { + return documentFieldValue.getValueInt(); + } else if (documentFieldValue.hasValueLong()) { + return documentFieldValue.getValueLong(); + } else if (documentFieldValue.hasValueFloat()) { + return documentFieldValue.getValueFloat(); + } else if (documentFieldValue.hasValueDouble()) { + return documentFieldValue.getValueDouble(); + } else if (documentFieldValue.hasValueBool()) { + return documentFieldValue.getValueBool(); + } else if (documentFieldValue.getValueByteArrayList().size() > 0) { + return documentFieldValue.getValueByteArrayList().toArray(); + } else if (documentFieldValue.getValueArrayListList().size() > 0) { + List list = new ArrayList<>(); + for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) { + list.add(readDocumentFieldValueFromProtobuf(value)); + } + return list; + } else if (documentFieldValue.getValueMapMap().size() > 0) { + Map map = Map.of(); + for (Map.Entry entrySet : documentFieldValue.getValueMapMap().entrySet()) { + map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue())); + } + return map; + } else if (documentFieldValue.hasValueDate()) { + return new Date(documentFieldValue.getValueDate()); + } else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) { + return ZonedDateTime.ofInstant( + Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()), + ZoneId.of(documentFieldValue.getValueZonedDate()) + ); + } else if (documentFieldValue.hasValueText()) { + return new Text(documentFieldValue.getValueText()); + } else { + throw new IOException("Can't read generic value of type [" + documentFieldValue + "]"); + } + } + + public static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) { + if (value == null) { + // null is not allowed in protobuf, so we use a special string to represent null + return valueBuilder.setValueString("null"); + } + Class type = value.getClass(); + if (type == String.class) { + valueBuilder.setValueString((String) value); + } else if (type == Integer.class) { + valueBuilder.setValueInt((Integer) value); + } else if (type == Long.class) { + valueBuilder.setValueLong((Long) value); + } else if (type == Float.class) { + valueBuilder.setValueFloat((Float) value); + } else if (type == Double.class) { + valueBuilder.setValueDouble((Double) value); + } else if (type == Boolean.class) { + valueBuilder.setValueBool((Boolean) value); + } else if (type == byte[].class) { + valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value)); + } else if (type == List.class) { + List list = (List) value; + for (Object listValue : list) { + valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, valueBuilder)); + } + } else if (type == Map.class || type == HashMap.class || type == LinkedHashMap.class) { + Map map = (Map) value; + for (Map.Entry entry : map.entrySet()) { + valueBuilder.putValueMap(entry.getKey(), convertDocumentFieldValueToProto(entry.getValue(), valueBuilder).build()); + } + } else if (type == Date.class) { + valueBuilder.setValueDate(((Date) value).getTime()); + } else if (type == ZonedDateTime.class) { + valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId()); + valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli()); + } else if (type == Text.class) { + valueBuilder.setValueText(((Text) value).string()); + } else { + throw new OpenSearchException("Can't convert generic value of type [" + type + "] to protobuf"); + } + return valueBuilder; + } + + public static FetchSearchResultProto.SearchHit.DocumentField convertDocumentFieldToProto(DocumentField documentField) { + FetchSearchResultProto.SearchHit.DocumentField.Builder builder = FetchSearchResultProto.SearchHit.DocumentField.newBuilder(); + builder.setName(documentField.getName()); + for (Object value : documentField.getValues()) { + FetchSearchResultProto.DocumentFieldValue.Builder valueBuilder = FetchSearchResultProto.DocumentFieldValue.newBuilder(); + builder.addValues(DocumentFieldProtobufSerializer.convertDocumentFieldValueToProto(value, valueBuilder)); + } + return builder.build(); + } + +} diff --git a/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldSerializer.java b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldSerializer.java new file mode 100644 index 0000000000000..48916c4b87b51 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldSerializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.document.serializer; + +import org.opensearch.common.document.DocumentField; + +import java.io.IOException; + +/** + * Serializer for {@link DocumentField} which can be implemented for different types of serialization. + */ +public interface DocumentFieldSerializer { + + DocumentField createDocumentField(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/common/document/serializer/package-info.java b/server/src/main/java/org/opensearch/common/document/serializer/package-info.java new file mode 100644 index 0000000000000..e8419ac59bb03 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serializer package for documents. */ +package org.opensearch.common.document.serializer; diff --git a/server/src/main/java/org/opensearch/common/lucene/Lucene.java b/server/src/main/java/org/opensearch/common/lucene/Lucene.java index 2c7b6b552b43f..d195859b63fed 100644 --- a/server/src/main/java/org/opensearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/opensearch/common/lucene/Lucene.java @@ -93,6 +93,7 @@ import org.opensearch.index.analysis.NamedAnalyzer; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.search.sort.SortedWiderNumericSortField; +import org.opensearch.server.proto.FetchSearchResultProto; import java.io.IOException; import java.math.BigInteger; @@ -651,6 +652,29 @@ public static Explanation readExplanation(StreamInput in) throws IOException { } } + public static Explanation readExplanation(byte[] in) throws IOException { + FetchSearchResultProto.SearchHit.Explanation explanationProto = FetchSearchResultProto.SearchHit.Explanation.parseFrom(in); + boolean match = explanationProto.getMatch(); + String description = explanationProto.getDescription(); + final Explanation[] subExplanations = new Explanation[explanationProto.getSubExplanationsCount()]; + for (int i = 0; i < subExplanations.length; ++i) { + subExplanations[i] = readExplanation(explanationProto.getSubExplanations(i).toByteArray()); + } + Number explanationValue = null; + if (explanationProto.hasValue1()) { + explanationValue = explanationProto.getValue1(); + } else if (explanationProto.hasValue2()) { + explanationValue = explanationProto.getValue2(); + } else if (explanationProto.hasValue3()) { + explanationValue = explanationProto.getValue3(); + } + if (match) { + return Explanation.match(explanationValue, description, subExplanations); + } else { + return Explanation.noMatch(description, subExplanations); + } + } + private static void writeExplanationValue(StreamOutput out, Number value) throws IOException { if (value instanceof Float) { out.writeByte((byte) 0); diff --git a/server/src/main/java/org/opensearch/common/settings/FeatureFlagSettings.java b/server/src/main/java/org/opensearch/common/settings/FeatureFlagSettings.java index 985eb40711e16..832138739d44b 100644 --- a/server/src/main/java/org/opensearch/common/settings/FeatureFlagSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/FeatureFlagSettings.java @@ -36,6 +36,7 @@ protected FeatureFlagSettings( FeatureFlags.DATETIME_FORMATTER_CACHING_SETTING, FeatureFlags.WRITEABLE_REMOTE_INDEX_SETTING, FeatureFlags.REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, - FeatureFlags.PLUGGABLE_CACHE_SETTING + FeatureFlags.PLUGGABLE_CACHE_SETTING, + FeatureFlags.PROTOBUF_SETTING ); } diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index bdfce72d106d3..9dbc340e54689 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -61,6 +61,11 @@ public class FeatureFlags { */ public static final String WRITEABLE_REMOTE_INDEX = "opensearch.experimental.feature.writeable_remote_index.enabled"; + /** + * Gates the functionality of integrating protobuf within search API and node-to-node communication. + */ + public static final String PROTOBUF = "opensearch.experimental.feature.search_with_protobuf.enabled"; + /** * Gates the functionality of pluggable cache. * Enables OpenSearch to use pluggable caches with respective store names via setting. @@ -93,6 +98,8 @@ public class FeatureFlags { public static final Setting PLUGGABLE_CACHE_SETTING = Setting.boolSetting(PLUGGABLE_CACHE, false, Property.NodeScope); + public static final Setting PROTOBUF_SETTING = Setting.boolSetting(PROTOBUF, false, Property.NodeScope, Property.Dynamic); + private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, EXTENSIONS_SETTING, @@ -100,7 +107,8 @@ public class FeatureFlags { TELEMETRY_SETTING, DATETIME_FORMATTER_CACHING_SETTING, WRITEABLE_REMOTE_INDEX_SETTING, - PLUGGABLE_CACHE_SETTING + PLUGGABLE_CACHE_SETTING, + PROTOBUF_SETTING ); /** * Should store the settings from opensearch.yml. diff --git a/server/src/main/java/org/opensearch/search/SearchPhaseResult.java b/server/src/main/java/org/opensearch/search/SearchPhaseResult.java index a351b3bd2dda6..ac411e1bdb362 100644 --- a/server/src/main/java/org/opensearch/search/SearchPhaseResult.java +++ b/server/src/main/java/org/opensearch/search/SearchPhaseResult.java @@ -43,6 +43,7 @@ import org.opensearch.search.query.QuerySearchResult; import java.io.IOException; +import java.io.InputStream; /** * This class is a base class for all search related results. It contains the shard target it @@ -71,6 +72,10 @@ protected SearchPhaseResult(StreamInput in) throws IOException { super(in); } + protected SearchPhaseResult(InputStream in) throws IOException { + super(in); + } + /** * Returns the search context ID that is used to reference the search context on the executing node * or null if no context was created. diff --git a/server/src/main/java/org/opensearch/search/SearchSortValues.java b/server/src/main/java/org/opensearch/search/SearchSortValues.java index cbc3900f72f79..d03cc80b90de3 100644 --- a/server/src/main/java/org/opensearch/search/SearchSortValues.java +++ b/server/src/main/java/org/opensearch/search/SearchSortValues.java @@ -67,6 +67,11 @@ public class SearchSortValues implements ToXContentFragment, Writeable { this.rawSortValues = EMPTY_ARRAY; } + public SearchSortValues(Object[] sortValues, Object[] rawSortValues) { + this.formattedSortValues = Objects.requireNonNull(sortValues, "sort values must not be empty"); + this.rawSortValues = rawSortValues; + } + public SearchSortValues(Object[] rawSortValues, DocValueFormat[] sortValueFormats) { Objects.requireNonNull(rawSortValues); Objects.requireNonNull(sortValueFormats); diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java index 26fa90141c2a9..681398d36e07b 100644 --- a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java @@ -33,6 +33,7 @@ package org.opensearch.search.fetch; import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchHit; @@ -41,8 +42,14 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.serializer.SearchHitsProtobufSerializer; +import org.opensearch.server.proto.FetchSearchResultProto; +import org.opensearch.server.proto.ShardSearchRequestProto; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; /** * Result from a fetch @@ -56,6 +63,8 @@ public final class FetchSearchResult extends SearchPhaseResult { // client side counter private transient int counter; + private FetchSearchResultProto.FetchSearchResult fetchSearchResultProto; + public FetchSearchResult() {} public FetchSearchResult(StreamInput in) throws IOException { @@ -64,9 +73,25 @@ public FetchSearchResult(StreamInput in) throws IOException { hits = new SearchHits(in); } + public FetchSearchResult(InputStream in) throws IOException { + super(in); + this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.parseFrom(in); + contextId = new ShardSearchContextId( + this.fetchSearchResultProto.getContextId().getSessionId(), + this.fetchSearchResultProto.getContextId().getId() + ); + SearchHitsProtobufSerializer protobufSerializer = new SearchHitsProtobufSerializer(); + hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray())); + } + public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) { this.contextId = id; setSearchShardTarget(shardTarget); + this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.newBuilder() + .setContextId( + ShardSearchRequestProto.ShardSearchContextId.newBuilder().setSessionId(id.getSessionId()).setId(id.getId()).build() + ) + .build(); } @Override @@ -82,6 +107,11 @@ public FetchSearchResult fetchResult() { public void hits(SearchHits hits) { assert assertNoSearchTarget(hits); this.hits = hits; + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) { + this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder() + .setHits(SearchHitsProtobufSerializer.convertHitsToProto(hits)) + .build(); + } } private boolean assertNoSearchTarget(SearchHits hits) { @@ -92,6 +122,16 @@ private boolean assertNoSearchTarget(SearchHits hits) { } public SearchHits hits() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) { + SearchHits hits; + try { + SearchHitsProtobufSerializer protobufSerializer = new SearchHitsProtobufSerializer(); + hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray())); + return hits; + } catch (IOException e) { + throw new RuntimeException(e); + } + } return hits; } @@ -109,4 +149,17 @@ public void writeTo(StreamOutput out) throws IOException { contextId.writeTo(out); hits.writeTo(out); } + + @Override + public void writeTo(OutputStream out) throws IOException { + out.write(fetchSearchResultProto.toByteArray()); + } + + public FetchSearchResultProto.FetchSearchResult response() { + return this.fetchSearchResultProto; + } + + public FetchSearchResult(FetchSearchResultProto.FetchSearchResult fetchSearchResult) { + this.fetchSearchResultProto = fetchSearchResult; + } } diff --git a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java index ce4c59fc77489..8531fe027abd4 100644 --- a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java @@ -32,14 +32,19 @@ package org.opensearch.search.fetch; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.server.proto.QueryFetchSearchResultProto; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import java.io.IOException; +import java.io.InputStream; /** * Query fetch result @@ -51,15 +56,30 @@ public final class QueryFetchSearchResult extends SearchPhaseResult { private final QuerySearchResult queryResult; private final FetchSearchResult fetchResult; + private QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResultProto; + public QueryFetchSearchResult(StreamInput in) throws IOException { super(in); queryResult = new QuerySearchResult(in); fetchResult = new FetchSearchResult(in); } + public QueryFetchSearchResult(InputStream in) throws IOException { + super(in); + this.queryFetchSearchResultProto = QueryFetchSearchResultProto.QueryFetchSearchResult.parseFrom(in); + queryResult = new QuerySearchResult(in); + fetchResult = new FetchSearchResult(in); + } + public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) { this.queryResult = queryResult; this.fetchResult = fetchResult; + if (queryResult.response() != null && fetchResult.response() != null) { + this.queryFetchSearchResultProto = QueryFetchSearchResultProto.QueryFetchSearchResult.newBuilder() + .setQueryResult(queryResult.response()) + .setFetchResult(fetchResult.response()) + .build(); + } } @Override @@ -101,4 +121,23 @@ public void writeTo(StreamOutput out) throws IOException { queryResult.writeTo(out); fetchResult.writeTo(out); } + + @Override + public String getProtocol() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return ProtobufInboundMessage.PROTOBUF_PROTOCOL; + } + return NativeInboundMessage.NATIVE_PROTOCOL; + } + + public QueryFetchSearchResultProto.QueryFetchSearchResult response() { + return this.queryFetchSearchResultProto; + } + + public QueryFetchSearchResult(QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResult) { + this.queryFetchSearchResultProto = queryFetchSearchResult; + this.queryResult = new QuerySearchResult(queryFetchSearchResult.getQueryResult()); + this.fetchResult = new FetchSearchResult(queryFetchSearchResult.getFetchResult()); + } + } diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldProtobufSerializer.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldProtobufSerializer.java new file mode 100644 index 0000000000000..74557d5618ce0 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldProtobufSerializer.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.fetch.subphase.highlight.serializer; + +import org.opensearch.core.common.text.Text; +import org.opensearch.search.fetch.subphase.highlight.HighlightField; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * Serializer for {@link HighlightField} to/from protobuf. + */ +public class HighlightFieldProtobufSerializer implements HighlightFieldSerializer { + + @Override + public HighlightField createHighLightField(InputStream inputStream) throws IOException { + FetchSearchResultProto.SearchHit.HighlightField highlightField = FetchSearchResultProto.SearchHit.HighlightField.parseFrom( + inputStream + ); + String name = highlightField.getName(); + Text[] fragments = Text.EMPTY_ARRAY; + if (highlightField.getFragmentsCount() > 0) { + List values = new ArrayList<>(); + for (String fragment : highlightField.getFragmentsList()) { + values.add(new Text(fragment)); + } + fragments = values.toArray(new Text[0]); + } + return new HighlightField(name, fragments); + } + +} diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldSerializer.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldSerializer.java new file mode 100644 index 0000000000000..21a6afdac565f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldSerializer.java @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.fetch.subphase.highlight.serializer; + +import org.opensearch.search.fetch.subphase.highlight.HighlightField; + +import java.io.IOException; + +/** + * Serializer for {@link HighlightField} which can be implemented for different types of serialization. + */ +public interface HighlightFieldSerializer { + + HighlightField createHighLightField(T inputStream) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/package-info.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/package-info.java new file mode 100644 index 0000000000000..dc08282a8954f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serializer package for highlights. */ +package org.opensearch.search.fetch.subphase.highlight.serializer; diff --git a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java index de1d5fb8b4098..0674101dc888f 100644 --- a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java @@ -67,6 +67,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.server.proto.ShardSearchRequestProto; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; @@ -99,7 +100,7 @@ public class ShardSearchRequest extends TransportRequest implements IndicesReque private final long nowInMillis; private long inboundNetworkTime; private long outboundNetworkTime; - private final boolean allowPartialSearchResults; + private final Boolean allowPartialSearchResults; private final String[] indexRoutings; private final String preference; private final OriginalIndices originalIndices; @@ -269,6 +270,54 @@ public ShardSearchRequest(StreamInput in) throws IOException { assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; } + public ShardSearchRequest(byte[] in) throws IOException { + ShardSearchRequestProto.ShardSearchRequest searchRequestProto = ShardSearchRequestProto.ShardSearchRequest.parseFrom(in); + this.clusterAlias = searchRequestProto.getClusterAlias(); + shardId = new ShardId( + searchRequestProto.getShardId().getIndexName(), + searchRequestProto.getShardId().getIndexUUID(), + searchRequestProto.getShardId().getShardId() + ); + this.numberOfShards = searchRequestProto.getNumberOfShards(); + // Since protobuf is currently done only for Query then Fetch types + searchType = SearchType.QUERY_THEN_FETCH; + this.scroll = searchRequestProto.hasScroll() + ? new Scroll(TimeValue.parseTimeValue(searchRequestProto.getScroll().getKeepAlive(), "keepAlive")) + : null; + this.indexBoost = searchRequestProto.getIndexBoost(); + this.requestCache = searchRequestProto.getRequestCache(); + this.nowInMillis = searchRequestProto.getNowInMillis(); + this.allowPartialSearchResults = searchRequestProto.getAllowPartialSearchResults(); + this.indexRoutings = searchRequestProto.getIndexRoutingsList().toArray(Strings.EMPTY_ARRAY); + this.preference = searchRequestProto.getPreference(); + ShardSearchRequestProto.OriginalIndices.IndicesOptions indicesOptionsFromProtobuf = searchRequestProto.getOriginalIndices() + .getIndicesOptions(); + IndicesOptions indicesOptions = IndicesOptions.fromOptions( + indicesOptionsFromProtobuf.getIgnoreUnavailable(), + indicesOptionsFromProtobuf.getAllowNoIndices(), + indicesOptionsFromProtobuf.getExpandWildcardsOpen(), + indicesOptionsFromProtobuf.getExpandWildcardsClosed(), + indicesOptionsFromProtobuf.getExpandWildcardsHidden(), + indicesOptionsFromProtobuf.getAllowAliasesToMultipleIndices(), + indicesOptionsFromProtobuf.getForbidClosedIndices(), + indicesOptionsFromProtobuf.getIgnoreAliases(), + indicesOptionsFromProtobuf.getIgnoreThrottled() + ); + this.originalIndices = new OriginalIndices( + searchRequestProto.getOriginalIndices().getIndicesList().toArray(Strings.EMPTY_ARRAY), + indicesOptions + ); + this.readerId = searchRequestProto.hasReaderId() + ? new ShardSearchContextId(searchRequestProto.getReaderId().getSessionId(), searchRequestProto.getReaderId().getId()) + : null; + this.keepAlive = searchRequestProto.hasTimeValue() + ? TimeValue.parseTimeValue(searchRequestProto.getTimeValue(), "keepAlive") + : null; + this.aliasFilter = searchRequestProto.hasAliasFilter() + ? new AliasFilter(null, searchRequestProto.getAliasFilter().getAliasesList().toArray(Strings.EMPTY_ARRAY)) + : AliasFilter.EMPTY; + } + public ShardSearchRequest(ShardSearchRequest clone) { this.shardId = clone.shardId; this.searchType = clone.searchType; diff --git a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java index f3ac953ab9d1d..0cc151766084c 100644 --- a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java @@ -32,11 +32,17 @@ package org.opensearch.search.query; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TotalHits.Relation; import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.io.stream.DelayableWriteable; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.DocValueFormat; @@ -50,8 +56,18 @@ import org.opensearch.search.profile.NetworkTime; import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.suggest.Suggest; +import org.opensearch.server.proto.QuerySearchResultProto; +import org.opensearch.server.proto.ShardSearchRequestProto; +import org.opensearch.server.proto.ShardSearchRequestProto.AliasFilter; +import org.opensearch.server.proto.ShardSearchRequestProto.ShardSearchRequest.SearchType; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; import static org.opensearch.common.lucene.Lucene.readTopDocs; import static org.opensearch.common.lucene.Lucene.writeTopDocs; @@ -64,6 +80,8 @@ @PublicApi(since = "1.0.0") public final class QuerySearchResult extends SearchPhaseResult { + private static final Logger logger = LogManager.getLogger(QuerySearchResult.class); + private int from; private int size; private TopDocsAndMaxScore topDocsAndMaxScore; @@ -90,6 +108,8 @@ public final class QuerySearchResult extends SearchPhaseResult { private final boolean isNull; + private QuerySearchResultProto.QuerySearchResult querySearchResultProto; + public QuerySearchResult() { this(false); } @@ -103,11 +123,98 @@ public QuerySearchResult(StreamInput in) throws IOException { } } + public QuerySearchResult(InputStream in) throws IOException { + super(in); + this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.parseFrom(in); + isNull = this.querySearchResultProto.getIsNull(); + if (!isNull) { + this.contextId = new ShardSearchContextId( + this.querySearchResultProto.getContextId().getSessionId(), + this.querySearchResultProto.getContextId().getId() + ); + ShardSearchRequest shardSearchRequest; + hasAggs = false; + try { + shardSearchRequest = new ShardSearchRequest(this.querySearchResultProto.getSearchShardRequest().toByteArray()); + setShardSearchRequest(shardSearchRequest); + } catch (IOException e) { + logger.error("Error while setting shard search request", e); + } + } + } + public QuerySearchResult(ShardSearchContextId contextId, SearchShardTarget shardTarget, ShardSearchRequest shardSearchRequest) { this.contextId = contextId; setSearchShardTarget(shardTarget); isNull = false; setShardSearchRequest(shardSearchRequest); + + ShardSearchRequestProto.ShardId shardIdProto = ShardSearchRequestProto.ShardId.newBuilder() + .setShardId(shardTarget.getShardId().getId()) + .setHashCode(shardTarget.getShardId().hashCode()) + .setIndexName(shardTarget.getShardId().getIndexName()) + .setIndexUUID(shardTarget.getShardId().getIndex().getUUID()) + .build(); + QuerySearchResultProto.SearchShardTarget.Builder searchShardTarget = QuerySearchResultProto.SearchShardTarget.newBuilder() + .setNodeId(shardTarget.getNodeId()) + .setShardId(shardIdProto); + ShardSearchRequestProto.ShardSearchContextId shardSearchContextId = ShardSearchRequestProto.ShardSearchContextId.newBuilder() + .setSessionId(contextId.getSessionId()) + .setId(contextId.getId()) + .build(); + ShardSearchRequestProto.ShardSearchRequest.Builder shardSearchRequestProto = ShardSearchRequestProto.ShardSearchRequest + .newBuilder(); + if (shardSearchRequest != null) { + ShardSearchRequestProto.OriginalIndices.Builder originalIndices = ShardSearchRequestProto.OriginalIndices.newBuilder(); + if (shardSearchRequest.indices() != null) { + for (String index : shardSearchRequest.indices()) { + originalIndices.addIndices(index); + } + originalIndices.setIndicesOptions( + ShardSearchRequestProto.OriginalIndices.IndicesOptions.newBuilder() + .setIgnoreUnavailable(shardSearchRequest.indicesOptions().ignoreUnavailable()) + .setAllowNoIndices(shardSearchRequest.indicesOptions().allowNoIndices()) + .setExpandWildcardsOpen(shardSearchRequest.indicesOptions().expandWildcardsOpen()) + .setExpandWildcardsClosed(shardSearchRequest.indicesOptions().expandWildcardsClosed()) + .setExpandWildcardsHidden(shardSearchRequest.indicesOptions().expandWildcardsHidden()) + .setAllowAliasesToMultipleIndices(shardSearchRequest.indicesOptions().allowAliasesToMultipleIndices()) + .setForbidClosedIndices(shardSearchRequest.indicesOptions().forbidClosedIndices()) + .setIgnoreAliases(shardSearchRequest.indicesOptions().ignoreAliases()) + .setIgnoreThrottled(shardSearchRequest.indicesOptions().ignoreThrottled()) + .build() + ); + } + AliasFilter.Builder aliasFilter = AliasFilter.newBuilder(); + if (shardSearchRequest.getAliasFilter() != null) { + for (int i = 0; i < shardSearchRequest.getAliasFilter().getAliases().length; i++) { + aliasFilter.addAliases(shardSearchRequest.getAliasFilter().getAliases()[i]); + } + } + shardSearchRequestProto.setInboundNetworkTime(shardSearchRequest.getInboundNetworkTime()) + .setOutboundNetworkTime(shardSearchRequest.getOutboundNetworkTime()) + .setShardId(shardIdProto) + .setAllowPartialSearchResults(shardSearchRequest.allowPartialSearchResults()) + .setNumberOfShards(shardSearchRequest.numberOfShards()) + .setReaderId(shardSearchContextId) + .setOriginalIndices(originalIndices) + .setSearchType(SearchType.QUERY_THEN_FETCH) + .setAliasFilter(aliasFilter); + if (shardSearchRequest.keepAlive() != null) { + shardSearchRequestProto.setTimeValue(shardSearchRequest.keepAlive().getStringRep()); + } + } + + if (shardTarget.getClusterAlias() != null) { + searchShardTarget.setClusterAlias(shardTarget.getClusterAlias()); + } + + this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.newBuilder() + .setContextId(shardSearchContextId) + .setSearchShardTarget(searchShardTarget.build()) + .setSearchShardRequest(shardSearchRequestProto.build()) + .setHasAggs(false) + .setIsNull(isNull) + .build(); } private QuerySearchResult(boolean isNull) { @@ -157,9 +264,33 @@ public Boolean terminatedEarly() { } public TopDocsAndMaxScore topDocs() { - if (topDocsAndMaxScore == null) { + if (topDocsAndMaxScore == null && this.querySearchResultProto.getTopDocsAndMaxScore() == null) { throw new IllegalStateException("topDocs already consumed"); } + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + ScoreDoc[] scoreDocs = new ScoreDoc[this.querySearchResultProto.getTopDocsAndMaxScore().getTopDocs().getScoreDocsCount()]; + for (int i = 0; i < scoreDocs.length; i++) { + org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocs.ScoreDoc scoreDoc = this.querySearchResultProto + .getTopDocsAndMaxScore() + .getTopDocs() + .getScoreDocsList() + .get(i); + scoreDocs[i] = new ScoreDoc(scoreDoc.getDoc(), scoreDoc.getScore(), scoreDoc.getShardIndex()); + } + TopDocs topDocsFromProtobuf = new TopDocs( + new TotalHits( + this.querySearchResultProto.getTotalHits().getValue(), + Relation.valueOf(this.querySearchResultProto.getTotalHits().getRelation().toString()) + ), + scoreDocs + ); + + TopDocsAndMaxScore topDocsFromProtobufAndMaxScore = new TopDocsAndMaxScore( + topDocsFromProtobuf, + this.querySearchResultProto.getMaxScore() + ); + return topDocsFromProtobufAndMaxScore; + } return topDocsAndMaxScore; } @@ -175,7 +306,7 @@ public boolean hasConsumedTopDocs() { * @throws IllegalStateException if the top docs have already been consumed. */ public TopDocsAndMaxScore consumeTopDocs() { - TopDocsAndMaxScore topDocsAndMaxScore = this.topDocsAndMaxScore; + TopDocsAndMaxScore topDocsAndMaxScore = this.topDocsAndMaxScore == null ? topDocs() : this.topDocsAndMaxScore; if (topDocsAndMaxScore == null) { throw new IllegalStateException("topDocs already consumed"); } @@ -201,6 +332,43 @@ private void setTopDocs(TopDocsAndMaxScore topDocsAndMaxScore) { this.totalHits = topDocsAndMaxScore.topDocs.totalHits; this.maxScore = topDocsAndMaxScore.maxScore; this.hasScoreDocs = topDocsAndMaxScore.topDocs.scoreDocs.length > 0; + + if (this.querySearchResultProto != null) { + List scoreDocs = new ArrayList<>(); + if (this.hasScoreDocs) { + for (ScoreDoc scoreDoc : topDocsAndMaxScore.topDocs.scoreDocs) { + scoreDocs.add( + QuerySearchResultProto.QuerySearchResult.TopDocs.ScoreDoc.newBuilder() + .setDoc(scoreDoc.doc) + .setScore(scoreDoc.score) + .setShardIndex(scoreDoc.shardIndex) + .build() + ); + } + } + QuerySearchResultProto.QuerySearchResult.TopDocs topDocsBuilder = QuerySearchResultProto.QuerySearchResult.TopDocs.newBuilder() + .setTotalHits( + QuerySearchResultProto.TotalHits.newBuilder() + .setValue(topDocsAndMaxScore.topDocs.totalHits.value) + .setRelation( + QuerySearchResultProto.TotalHits.Relation.valueOf(topDocsAndMaxScore.topDocs.totalHits.relation.name()) + ) + .build() + ) + .addAllScoreDocs(scoreDocs) + .build(); + QuerySearchResultProto.QuerySearchResult.TopDocsAndMaxScore topDocsAndMaxScoreBuilder = + QuerySearchResultProto.QuerySearchResult.TopDocsAndMaxScore.newBuilder() + .setMaxScore(topDocsAndMaxScore.maxScore) + .setTopDocs(topDocsBuilder) + .build(); + this.querySearchResultProto = this.querySearchResultProto.toBuilder() + .setTopDocsAndMaxScore(topDocsAndMaxScoreBuilder) + .setMaxScore(this.maxScore) + .setTotalHits(topDocsBuilder.getTotalHits()) + .setHasScoreDocs(this.hasScoreDocs) + .build(); + } } public DocValueFormat[] sortValueFormats() { @@ -289,11 +457,17 @@ public void suggest(Suggest suggest) { } public int from() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return this.querySearchResultProto.getFrom(); + } return from; } public QuerySearchResult from(int from) { this.from = from; + if (this.querySearchResultProto != null) { + this.querySearchResultProto = this.querySearchResultProto.toBuilder().setFrom(from).build(); + } return this; } @@ -301,11 +475,17 @@ public QuerySearchResult from(int from) { * Returns the maximum size of this results top docs. */ public int size() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return this.querySearchResultProto.getSize(); + } return size; } public QuerySearchResult size(int size) { this.size = size; + if (this.querySearchResultProto != null) { + this.querySearchResultProto = this.querySearchResultProto.toBuilder().setSize(size).build(); + } return this; } @@ -377,6 +557,13 @@ public void writeTo(StreamOutput out) throws IOException { } } + @Override + public void writeTo(OutputStream out) throws IOException { + if (!isNull) { + out.write(this.querySearchResultProto.toByteArray()); + } + } + public void writeToNoId(StreamOutput out) throws IOException { out.writeVInt(from); out.writeVInt(size); @@ -417,4 +604,32 @@ public TotalHits getTotalHits() { public float getMaxScore() { return maxScore; } + + public QuerySearchResultProto.QuerySearchResult response() { + return this.querySearchResultProto; + } + + public QuerySearchResult(QuerySearchResultProto.QuerySearchResult querySearchResult) { + this.querySearchResultProto = querySearchResult; + this.isNull = this.querySearchResultProto.getIsNull(); + this.contextId = new ShardSearchContextId( + this.querySearchResultProto.getContextId().getSessionId(), + this.querySearchResultProto.getContextId().getId() + ); + ShardSearchRequest shardSearchRequest; + try { + shardSearchRequest = new ShardSearchRequest(this.querySearchResultProto.getSearchShardRequest().toByteArray()); + setShardSearchRequest(shardSearchRequest); + } catch (IOException e) { + logger.error("Error while setting shard search request", e); + } + } + + @Override + public String getProtocol() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return ProtobufInboundMessage.PROTOBUF_PROTOCOL; + } + return NativeInboundMessage.NATIVE_PROTOCOL; + } } diff --git a/server/src/main/java/org/opensearch/search/serializer/NestedIdentityProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/NestedIdentityProtobufSerializer.java new file mode 100644 index 0000000000000..0e1081c9cc8d8 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/NestedIdentityProtobufSerializer.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHit.NestedIdentity; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * Serializer for {@link NestedIdentity} to/from protobuf. + */ +public class NestedIdentityProtobufSerializer implements NestedIdentitySerializer { + + @Override + public NestedIdentity createNestedIdentity(InputStream inputStream) throws IOException { + FetchSearchResultProto.SearchHit.NestedIdentity proto = FetchSearchResultProto.SearchHit.NestedIdentity.parseFrom(inputStream); + String field; + int offset; + NestedIdentity child; + if (proto.hasField()) { + field = proto.getField(); + } else { + field = null; + } + if (proto.hasOffset()) { + offset = proto.getOffset(); + } else { + offset = -1; + } + if (proto.hasChild()) { + child = createNestedIdentity(new ByteArrayInputStream(proto.getChild().toByteArray())); + } else { + child = null; + } + return new NestedIdentity(field, offset, child); + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/NestedIdentitySerializer.java b/server/src/main/java/org/opensearch/search/serializer/NestedIdentitySerializer.java new file mode 100644 index 0000000000000..5ee30337bb4b6 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/NestedIdentitySerializer.java @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHit.NestedIdentity; + +import java.io.IOException; + +/** + * Serializer for {@link NestedIdentity} which can be implemented for different types of serialization. + */ +public interface NestedIdentitySerializer { + + public NestedIdentity createNestedIdentity(T inputStream) throws IOException, Exception; +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitProtobufSerializer.java new file mode 100644 index 0000000000000..e06e31bee249f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitProtobufSerializer.java @@ -0,0 +1,180 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import com.google.protobuf.ByteString; +import org.apache.lucene.search.Explanation; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.document.serializer.DocumentFieldProtobufSerializer; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHit.NestedIdentity; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.SearchSortValues; +import org.opensearch.search.fetch.subphase.highlight.HighlightField; +import org.opensearch.search.fetch.subphase.highlight.serializer.HighlightFieldProtobufSerializer; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Serializer for {@link SearchHit} to/from protobuf. + */ +public class SearchHitProtobufSerializer implements SearchHitSerializer { + + private FetchSearchResultProto.SearchHit searchHitProto; + + @Override + public SearchHit createSearchHit(InputStream inputStream) throws IOException { + this.searchHitProto = FetchSearchResultProto.SearchHit.parseFrom(inputStream); + int docId = -1; + float score = this.searchHitProto.getScore(); + String id = this.searchHitProto.getId(); + NestedIdentity nestedIdentity; + if (!this.searchHitProto.hasNestedIdentity() && this.searchHitProto.getNestedIdentity().toByteArray().length > 0) { + NestedIdentityProtobufSerializer protobufSerializer = new NestedIdentityProtobufSerializer(); + nestedIdentity = protobufSerializer.createNestedIdentity( + new ByteArrayInputStream(this.searchHitProto.getNestedIdentity().toByteArray()) + ); + } else { + nestedIdentity = null; + } + long version = this.searchHitProto.getVersion(); + long seqNo = this.searchHitProto.getSeqNo(); + long primaryTerm = this.searchHitProto.getPrimaryTerm(); + BytesReference source = BytesReference.fromByteBuffer(ByteBuffer.wrap(this.searchHitProto.getSource().toByteArray())); + if (source.length() == 0) { + source = null; + } + Map documentFields = new HashMap<>(); + DocumentFieldProtobufSerializer protobufSerializer = new DocumentFieldProtobufSerializer(); + this.searchHitProto.getDocumentFieldsMap().forEach((k, v) -> { + try { + documentFields.put(k, protobufSerializer.createDocumentField(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse document field", e); + } + }); + Map metaFields = new HashMap<>(); + this.searchHitProto.getMetaFieldsMap().forEach((k, v) -> { + try { + metaFields.put(k, protobufSerializer.createDocumentField(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse document field", e); + } + }); + Map highlightFields = new HashMap<>(); + HighlightFieldProtobufSerializer highlightFieldProtobufSerializer = new HighlightFieldProtobufSerializer(); + this.searchHitProto.getHighlightFieldsMap().forEach((k, v) -> { + try { + highlightFields.put(k, highlightFieldProtobufSerializer.createHighLightField(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse highlight field", e); + } + }); + SearchSortValuesProtobufSerializer sortValueProtobufSerializer = new SearchSortValuesProtobufSerializer(); + SearchSortValues sortValues = sortValueProtobufSerializer.createSearchSortValues( + new ByteArrayInputStream(this.searchHitProto.getSortValues().toByteArray()) + ); + Map matchedQueries = new HashMap<>(); + if (this.searchHitProto.getMatchedQueriesCount() > 0) { + matchedQueries = new LinkedHashMap<>(this.searchHitProto.getMatchedQueriesCount()); + for (String query : this.searchHitProto.getMatchedQueriesList()) { + matchedQueries.put(query, Float.NaN); + } + } + if (this.searchHitProto.getMatchedQueriesWithScoresCount() > 0) { + Map tempMap = this.searchHitProto.getMatchedQueriesWithScoresMap() + .entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue())); + matchedQueries = tempMap.entrySet() + .stream() + .sorted(Map.Entry.comparingByKey()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new)); + } + Explanation explanation = null; + if (this.searchHitProto.hasExplanation()) { + explanation = Lucene.readExplanation(this.searchHitProto.getExplanation().toByteArray()); + } + SearchShardTarget searchShardTarget = new SearchShardTarget( + this.searchHitProto.getShard().getNodeId(), + new ShardId( + this.searchHitProto.getShard().getShardId().getIndexName(), + this.searchHitProto.getShard().getShardId().getIndexUUID(), + this.searchHitProto.getShard().getShardId().getShardId() + ), + this.searchHitProto.getShard().getClusterAlias(), + OriginalIndices.NONE + ); + Map innerHits; + if (this.searchHitProto.getInnerHitsCount() > 0) { + innerHits = new HashMap<>(); + this.searchHitProto.getInnerHitsMap().forEach((k, v) -> { + try { + SearchHitsProtobufSerializer protobufHitsFactory = new SearchHitsProtobufSerializer(); + innerHits.put(k, protobufHitsFactory.createSearchHits(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse inner hits", e); + } + }); + } else { + innerHits = null; + } + SearchHit searchHit = new SearchHit(docId, id, nestedIdentity, documentFields, metaFields); + searchHit.score(score); + searchHit.version(version); + searchHit.setSeqNo(seqNo); + searchHit.setPrimaryTerm(primaryTerm); + searchHit.sourceRef(source); + searchHit.highlightFields(highlightFields); + searchHit.sortValues(sortValues); + searchHit.matchedQueriesWithScores(matchedQueries); + searchHit.explanation(explanation); + searchHit.shard(searchShardTarget); + searchHit.setInnerHits(innerHits); + return searchHit; + } + + public static FetchSearchResultProto.SearchHit convertHitToProto(SearchHit hit) { + FetchSearchResultProto.SearchHit.Builder searchHitBuilder = FetchSearchResultProto.SearchHit.newBuilder(); + if (hit.getIndex() != null) { + searchHitBuilder.setIndex(hit.getIndex()); + } + searchHitBuilder.setId(hit.getId()); + searchHitBuilder.setScore(hit.getScore()); + searchHitBuilder.setSeqNo(hit.getSeqNo()); + searchHitBuilder.setPrimaryTerm(hit.getPrimaryTerm()); + searchHitBuilder.setVersion(hit.getVersion()); + searchHitBuilder.setDocId(hit.docId()); + if (hit.getSourceRef() != null) { + searchHitBuilder.setSource(ByteString.copyFrom(hit.getSourceRef().toBytesRef().bytes)); + } + for (Map.Entry entry : hit.getFields().entrySet()) { + searchHitBuilder.putDocumentFields( + entry.getKey(), + DocumentFieldProtobufSerializer.convertDocumentFieldToProto(entry.getValue()) + ); + } + return searchHitBuilder.build(); + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitSerializer.java new file mode 100644 index 0000000000000..217266f720079 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitSerializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHit; + +import java.io.IOException; + +/** + * Serializer for {@link SearchHit} which can be implemented for different types of serialization. + */ +public interface SearchHitSerializer { + + SearchHit createSearchHit(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitsProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitsProtobufSerializer.java new file mode 100644 index 0000000000000..d32043d9caa49 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitsProtobufSerializer.java @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import com.google.protobuf.ByteString; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TotalHits.Relation; +import org.apache.lucene.util.BytesRef; +import org.opensearch.OpenSearchException; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.server.proto.FetchSearchResultProto; +import org.opensearch.server.proto.QuerySearchResultProto; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; + +/** + * Serializer for {@link SearchHits} to/from protobuf. + */ +public class SearchHitsProtobufSerializer implements SearchHitsSerializer { + + private FetchSearchResultProto.SearchHits searchHitsProto; + + @Override + public SearchHits createSearchHits(InputStream inputStream) throws IOException { + this.searchHitsProto = FetchSearchResultProto.SearchHits.parseFrom(inputStream); + SearchHit[] hits = new SearchHit[this.searchHitsProto.getHitsCount()]; + SearchHitProtobufSerializer protobufSerializer = new SearchHitProtobufSerializer(); + for (int i = 0; i < this.searchHitsProto.getHitsCount(); i++) { + hits[i] = protobufSerializer.createSearchHit(new ByteArrayInputStream(this.searchHitsProto.getHits(i).toByteArray())); + } + TotalHits totalHits = new TotalHits( + this.searchHitsProto.getTotalHits().getValue(), + Relation.valueOf(this.searchHitsProto.getTotalHits().getRelation().toString()) + ); + float maxScore = this.searchHitsProto.getMaxScore(); + SortField[] sortFields = this.searchHitsProto.getSortFieldsList() + .stream() + .map(sortField -> new SortField(sortField.getField(), SortField.Type.valueOf(sortField.getType().toString()))) + .toArray(SortField[]::new); + String collapseField = this.searchHitsProto.getCollapseField(); + Object[] collapseValues = new Object[this.searchHitsProto.getCollapseValuesCount()]; + for (int i = 0; i < this.searchHitsProto.getCollapseValuesCount(); i++) { + collapseValues[i] = readSortValueFromProtobuf(this.searchHitsProto.getCollapseValues(i)); + } + return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues); + } + + public static Object readSortValueFromProtobuf(FetchSearchResultProto.SortValue collapseValue) throws IOException { + if (collapseValue.hasCollapseString()) { + return collapseValue.getCollapseString(); + } else if (collapseValue.hasCollapseInt()) { + return collapseValue.getCollapseInt(); + } else if (collapseValue.hasCollapseLong()) { + return collapseValue.getCollapseLong(); + } else if (collapseValue.hasCollapseFloat()) { + return collapseValue.getCollapseFloat(); + } else if (collapseValue.hasCollapseDouble()) { + return collapseValue.getCollapseDouble(); + } else if (collapseValue.hasCollapseBytes()) { + return new BytesRef(collapseValue.getCollapseBytes().toByteArray()); + } else if (collapseValue.hasCollapseBool()) { + return collapseValue.getCollapseBool(); + } else { + throw new IOException("Can't handle sort field value of type [" + collapseValue + "]"); + } + } + + public static FetchSearchResultProto.SearchHits convertHitsToProto(SearchHits hits) { + List searchHitList = new ArrayList<>(); + for (SearchHit hit : hits) { + searchHitList.add(SearchHitProtobufSerializer.convertHitToProto(hit)); + } + QuerySearchResultProto.TotalHits.Builder totalHitsBuilder = QuerySearchResultProto.TotalHits.newBuilder(); + if (hits.getTotalHits() != null) { + totalHitsBuilder.setValue(hits.getTotalHits().value); + totalHitsBuilder.setRelation( + hits.getTotalHits().relation == Relation.EQUAL_TO + ? QuerySearchResultProto.TotalHits.Relation.EQUAL_TO + : QuerySearchResultProto.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + ); + } + FetchSearchResultProto.SearchHits.Builder searchHitsBuilder = FetchSearchResultProto.SearchHits.newBuilder(); + searchHitsBuilder.setMaxScore(hits.getMaxScore()); + searchHitsBuilder.addAllHits(searchHitList); + searchHitsBuilder.setTotalHits(totalHitsBuilder.build()); + if (hits.getSortFields() != null && hits.getSortFields().length > 0) { + for (SortField sortField : hits.getSortFields()) { + FetchSearchResultProto.SortField.Builder sortFieldBuilder = FetchSearchResultProto.SortField.newBuilder(); + if (sortField.getField() != null) { + sortFieldBuilder.setField(sortField.getField()); + } + sortFieldBuilder.setType(FetchSearchResultProto.SortField.Type.valueOf(sortField.getType().name())); + searchHitsBuilder.addSortFields(sortFieldBuilder.build()); + } + } + if (hits.getCollapseField() != null) { + searchHitsBuilder.setCollapseField(hits.getCollapseField()); + for (Object value : hits.getCollapseValues()) { + FetchSearchResultProto.SortValue.Builder collapseValueBuilder = FetchSearchResultProto.SortValue.newBuilder(); + try { + collapseValueBuilder = readSortValueForProtobuf(value, collapseValueBuilder); + } catch (IOException e) { + throw new OpenSearchException(e); + } + searchHitsBuilder.addCollapseValues(collapseValueBuilder.build()); + } + } + return searchHitsBuilder.build(); + } + + public static FetchSearchResultProto.SortValue.Builder readSortValueForProtobuf( + Object collapseValue, + FetchSearchResultProto.SortValue.Builder collapseValueBuilder + ) throws IOException { + Class type = collapseValue.getClass(); + if (type == String.class) { + collapseValueBuilder.setCollapseString((String) collapseValue); + } else if (type == Integer.class || type == Short.class) { + collapseValueBuilder.setCollapseInt((Integer) collapseValue); + } else if (type == Long.class) { + collapseValueBuilder.setCollapseLong((Long) collapseValue); + } else if (type == Float.class) { + collapseValueBuilder.setCollapseFloat((Float) collapseValue); + } else if (type == Double.class) { + collapseValueBuilder.setCollapseDouble((Double) collapseValue); + } else if (type == Byte.class) { + byte b = (Byte) collapseValue; + collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(new byte[] { b })); + } else if (type == Boolean.class) { + collapseValueBuilder.setCollapseBool((Boolean) collapseValue); + } else if (type == BytesRef.class) { + collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(((BytesRef) collapseValue).bytes)); + } else if (type == BigInteger.class) { + BigInteger bigInt = (BigInteger) collapseValue; + collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(bigInt.toByteArray())); + } else { + throw new IOException("Can't handle sort field value of type [" + type + "]"); + } + return collapseValueBuilder; + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitsSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitsSerializer.java new file mode 100644 index 0000000000000..f09e369aaeee1 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitsSerializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHits; + +import java.io.IOException; + +/** + * Serializer for {@link SearchHits} which can be implemented for different types of serialization. + */ +public interface SearchHitsSerializer { + + SearchHits createSearchHits(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesProtobufSerializer.java new file mode 100644 index 0000000000000..007e951cbf10a --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesProtobufSerializer.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchSortValues; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.IOException; +import java.io.InputStream; + +/** + * Serializer for {@link SearchSortValues} to/from protobuf. + */ +public class SearchSortValuesProtobufSerializer implements SearchSortValuesSerializer { + + @Override + public SearchSortValues createSearchSortValues(InputStream inputStream) throws IOException { + FetchSearchResultProto.SearchHit.SearchSortValues searchSortValues = FetchSearchResultProto.SearchHit.SearchSortValues.parseFrom( + inputStream + ); + Object[] formattedSortValues = new Object[searchSortValues.getFormattedSortValuesCount()]; + for (int i = 0; i < searchSortValues.getFormattedSortValuesCount(); i++) { + formattedSortValues[i] = SearchHitsProtobufSerializer.readSortValueFromProtobuf(searchSortValues.getFormattedSortValues(i)); + } + Object[] rawSortValues = new Object[searchSortValues.getRawSortValuesCount()]; + for (int i = 0; i < searchSortValues.getRawSortValuesCount(); i++) { + rawSortValues[i] = SearchHitsProtobufSerializer.readSortValueFromProtobuf(searchSortValues.getRawSortValues(i)); + } + return new SearchSortValues(formattedSortValues, rawSortValues); + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesSerializer.java new file mode 100644 index 0000000000000..31feb206eb74d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesSerializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchSortValues; + +import java.io.IOException; + +/** + * Serializer for {@link SearchSortValues} which can be implemented for different types of serialization. + */ +public interface SearchSortValuesSerializer { + + SearchSortValues createSearchSortValues(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/package-info.java b/server/src/main/java/org/opensearch/search/serializer/package-info.java new file mode 100644 index 0000000000000..25a4d1935016e --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serializer package for search. */ +package org.opensearch.search.serializer; diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index 6492900c49a0e..d58d80857da0b 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -33,10 +33,13 @@ package org.opensearch.transport; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.nativeprotocol.NativeInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufMessageHandler; import java.io.IOException; import java.util.Map; @@ -67,19 +70,37 @@ public class InboundHandler { Tracer tracer ) { this.threadPool = threadPool; - this.protocolMessageHandlers = Map.of( - NativeInboundMessage.NATIVE_PROTOCOL, - new NativeMessageHandler( - threadPool, - outboundHandler, - namedWriteableRegistry, - handshaker, - requestHandlers, - responseHandlers, - tracer, - keepAlive - ) - ); + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + this.protocolMessageHandlers = Map.of( + ProtobufInboundMessage.PROTOBUF_PROTOCOL, + new ProtobufMessageHandler(threadPool, responseHandlers), + NativeInboundMessage.NATIVE_PROTOCOL, + new NativeMessageHandler( + threadPool, + outboundHandler, + namedWriteableRegistry, + handshaker, + requestHandlers, + responseHandlers, + tracer, + keepAlive + ) + ); + } else { + this.protocolMessageHandlers = Map.of( + NativeInboundMessage.NATIVE_PROTOCOL, + new NativeMessageHandler( + threadPool, + outboundHandler, + namedWriteableRegistry, + handshaker, + requestHandlers, + responseHandlers, + tracer, + keepAlive + ) + ); + } } void setMessageListener(TransportMessageListener listener) { diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index 5cee3bb975223..0e154d571f832 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -36,9 +36,11 @@ import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler; +import org.opensearch.transport.protobufprotocol.ProtobufInboundBytesHandler; import java.io.IOException; import java.util.ArrayDeque; @@ -95,7 +97,14 @@ public InboundPipeline( this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; - this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker)); + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + this.protocolBytesHandlers = List.of( + new ProtobufInboundBytesHandler(), + new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker) + ); + } else { + this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker)); + } this.messageHandler = messageHandler; } diff --git a/server/src/main/java/org/opensearch/transport/OutboundHandler.java b/server/src/main/java/org/opensearch/transport/OutboundHandler.java index b83dbdd0effe4..e5bb7764c70dd 100644 --- a/server/src/main/java/org/opensearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/OutboundHandler.java @@ -51,7 +51,10 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.transport.TransportResponse; +import org.opensearch.search.fetch.QueryFetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import java.io.IOException; import java.util.Set; @@ -146,17 +149,54 @@ void sendResponse( final boolean isHandshake ) throws IOException { Version version = Version.min(this.version, nodeVersion); - OutboundMessage.Response message = new OutboundMessage.Response( - threadPool.getThreadContext(), - features, - response, - version, - requestId, - isHandshake, - compress - ); ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); - sendMessage(channel, message, listener); + if ((response.getProtocol()).equals(ProtobufInboundMessage.PROTOBUF_PROTOCOL) && version.onOrAfter(Version.V_3_0_0)) { + if (response instanceof QueryFetchSearchResult) { + QueryFetchSearchResult queryFetchSearchResult = (QueryFetchSearchResult) response; + if (queryFetchSearchResult.response() != null) { + byte[] bytes = new byte[1]; + bytes[0] = 1; + ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage( + requestId, + bytes, + Version.CURRENT, + threadPool.getThreadContext(), + queryFetchSearchResult.response(), + features, + action + ); + sendProtobufMessage(channel, protobufMessage, listener); + } + } else if (response instanceof QuerySearchResult) { + QuerySearchResult querySearchResult = (QuerySearchResult) response; + if (querySearchResult.response() != null) { + byte[] bytes = new byte[1]; + bytes[0] = 1; + ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage( + requestId, + bytes, + Version.CURRENT, + threadPool.getThreadContext(), + querySearchResult.response(), + features, + action + ); + sendProtobufMessage(channel, protobufMessage, listener); + } + } + + } else { + OutboundMessage.Response message = new OutboundMessage.Response( + threadPool.getThreadContext(), + features, + response, + version, + requestId, + isHandshake, + compress + ); + sendMessage(channel, message, listener); + } } /** @@ -192,6 +232,12 @@ private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, Act internalSend(channel, sendContext); } + private void sendProtobufMessage(TcpChannel channel, ProtobufInboundMessage message, ActionListener listener) throws IOException { + ProtobufMessageSerializer serializer = new ProtobufMessageSerializer(message, bigArrays); + SendContext sendContext = new SendContext(channel, serializer, listener, serializer); + internalSend(channel, sendContext); + } + private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); BytesReference reference = sendContext.get(); @@ -241,6 +287,29 @@ public void close() { } } + private static class ProtobufMessageSerializer implements CheckedSupplier, Releasable { + + private final ProtobufInboundMessage message; + private final BigArrays bigArrays; + private volatile ReleasableBytesStreamOutput bytesStreamOutput; + + private ProtobufMessageSerializer(ProtobufInboundMessage message, BigArrays bigArrays) { + this.message = message; + this.bigArrays = bigArrays; + } + + @Override + public BytesReference get() throws IOException { + bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays); + return message.serialize(bytesStreamOutput); + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(bytesStreamOutput); + } + } + private class SendContext extends NotifyOnceListener implements CheckedSupplier { private final TcpChannel channel; diff --git a/server/src/main/java/org/opensearch/transport/TcpHeader.java b/server/src/main/java/org/opensearch/transport/TcpHeader.java index 78353a9a80403..1dd6e89ca9bee 100644 --- a/server/src/main/java/org/opensearch/transport/TcpHeader.java +++ b/server/src/main/java/org/opensearch/transport/TcpHeader.java @@ -73,6 +73,7 @@ public static int headerSize(Version version) { } private static final byte[] PREFIX = { (byte) 'E', (byte) 'S' }; + private static final byte[] PROTOBUF_PREFIX = { (byte) 'O', (byte) 'S', (byte) 'P' }; public static void writeHeader( StreamOutput output, @@ -91,4 +92,8 @@ public static void writeHeader( assert variableHeaderSize != -1 : "Variable header size not set"; output.writeInt(variableHeaderSize); } + + public static void writeHeaderForProtobuf(StreamOutput output) throws IOException { + output.writeBytes(PROTOBUF_PREFIX); + } } diff --git a/server/src/main/java/org/opensearch/transport/TransportRequest.java b/server/src/main/java/org/opensearch/transport/TransportRequest.java index c62cf59d3be2f..8cc53c8ca43a4 100644 --- a/server/src/main/java/org/opensearch/transport/TransportRequest.java +++ b/server/src/main/java/org/opensearch/transport/TransportRequest.java @@ -40,6 +40,8 @@ import org.opensearch.tasks.TaskAwareRequest; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; /** * A transport request @@ -61,6 +63,10 @@ public Empty() {} public Empty(StreamInput in) throws IOException { super(in); } + + public Empty(InputStream in) throws IOException { + super(in); + } } /** @@ -74,6 +80,11 @@ public TransportRequest(StreamInput in) throws IOException { parentTaskId = TaskId.readFromStream(in); } + /** + * This is added here so that classes don't have to implement since it is an experimental feature and only being added for search apis incrementally. + */ + public TransportRequest(InputStream in) throws IOException {} + /** * Set a reference to task that created this request. */ @@ -94,4 +105,10 @@ public TaskId getParentTask() { public void writeTo(StreamOutput out) throws IOException { parentTaskId.writeTo(out); } + + /** + * This is added here so that classes don't have to implement since it is an experimental feature and only being added for search apis incrementally. + */ + @Override + public void writeTo(OutputStream out) throws IOException {} } diff --git a/server/src/main/java/org/opensearch/transport/TransportResponseHandler.java b/server/src/main/java/org/opensearch/transport/TransportResponseHandler.java index 748d2a4d867ec..42048258bb5c9 100644 --- a/server/src/main/java/org/opensearch/transport/TransportResponseHandler.java +++ b/server/src/main/java/org/opensearch/transport/TransportResponseHandler.java @@ -33,11 +33,13 @@ package org.opensearch.transport; import org.opensearch.common.annotation.PublicApi; +import org.opensearch.core.common.io.stream.BytesWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.transport.TransportResponse; import java.io.IOException; +import java.io.InputStream; import java.util.function.Function; /** @@ -46,7 +48,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public interface TransportResponseHandler extends Writeable.Reader { +public interface TransportResponseHandler extends Writeable.Reader, BytesWriteable.Reader { void handleResponse(T response); @@ -54,6 +56,16 @@ public interface TransportResponseHandler extends W String executor(); + /** + * Read {@code V}-type value from a byte array. + * + * @param in byte array to read the value from + */ + default T read(final InputStream in) throws IOException { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'read'"); + } + /** * This method should be handling the rejection/failure scenarios where connection to the node is rejected or failed. * It should be used to clear up the resources held by the {@link TransportResponseHandler}. @@ -83,6 +95,12 @@ public String executor() { public Q read(StreamInput in) throws IOException { return reader.read(in); } + + @Override + public Q read(InputStream in) throws IOException { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'read'"); + } }; } } diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index d08b28730d417..a716a6efb97f7 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -76,6 +76,7 @@ import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.io.InputStream; import java.io.UncheckedIOException; import java.net.UnknownHostException; import java.util.Arrays; @@ -1508,6 +1509,11 @@ void setTimeoutHandler(TimeoutHandler handler) { this.handler = handler; } + @Override + public T read(InputStream in) throws IOException { + return delegate.read(in); + } + } /** @@ -1721,6 +1727,11 @@ public T read(StreamInput in) throws IOException { public String toString() { return getClass().getName() + "/[" + action + "]:" + handler.toString(); } + + @Override + public T read(InputStream in) throws IOException { + return handler.read(in); + } }; } else { delegate = handler; diff --git a/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundBytesHandler.java new file mode 100644 index 0000000000000..0b3d268001f69 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundBytesHandler.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.protobufprotocol; + +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.transport.InboundBytesHandler; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.TcpChannel; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.function.BiConsumer; + +/** + * Handler for inbound bytes for the protobuf protocol. + */ +public class ProtobufInboundBytesHandler implements InboundBytesHandler { + + public ProtobufInboundBytesHandler() {} + + @Override + public void doHandleBytes( + TcpChannel channel, + ReleasableBytesReference reference, + BiConsumer messageHandler + ) throws IOException { + // removing the first byte we added for protobuf message + byte[] incomingBytes = BytesReference.toBytes(reference.slice(3, reference.length() - 3)); + ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage(new ByteArrayInputStream(incomingBytes)); + messageHandler.accept(channel, protobufMessage); + } + + @Override + public boolean canHandleBytes(ReleasableBytesReference reference) { + if (reference.get(0) == 'O' && reference.get(1) == 'S' && reference.get(2) == 'P') { + return true; + } + return false; + } + + @Override + public void close() { + // no-op + } + +} diff --git a/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundMessage.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundMessage.java new file mode 100644 index 0000000000000..3f650ff61d3ab --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundMessage.java @@ -0,0 +1,172 @@ +/* +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +package org.opensearch.transport.protobufprotocol; + +import com.google.protobuf.ByteString; +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.server.proto.NodeToNodeMessageProto; +import org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header; +import org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.ResponseHandlersList; +import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult; +import org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.TcpHeader; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Outbound data as a message +* +* @opensearch.internal +*/ +public class ProtobufInboundMessage implements ProtocolInboundMessage { + + /** + * The protocol used to encode this message + */ + public static String PROTOBUF_PROTOCOL = "protobuf"; + + private final NodeToNodeMessageProto.NodeToNodeMessage message; + private static final byte[] PREFIX = { (byte) 'E', (byte) 'S' }; + + public ProtobufInboundMessage( + long requestId, + byte[] status, + Version version, + ThreadContext threadContext, + QueryFetchSearchResult queryFetchSearchResult, + Set features, + String action + ) { + Header header = Header.newBuilder() + .addAllPrefix(Arrays.asList(ByteString.copyFrom(PREFIX))) + .setRequestId(requestId) + .setStatus(ByteString.copyFrom(status)) + .setVersionId(version.id) + .build(); + Map requestHeaders = threadContext.getHeaders(); + Map> responseHeaders = threadContext.getResponseHeaders(); + Map responseHandlers = new HashMap<>(); + for (Map.Entry> entry : responseHeaders.entrySet()) { + String key = entry.getKey(); + List value = entry.getValue(); + ResponseHandlersList responseHandlersList = ResponseHandlersList.newBuilder().addAllSetOfResponseHandlers(value).build(); + responseHandlers.put(key, responseHandlersList); + } + this.message = NodeToNodeMessageProto.NodeToNodeMessage.newBuilder() + .setHeader(header) + .putAllRequestHeaders(requestHeaders) + .putAllResponseHandlers(responseHandlers) + .setVersion(version.toString()) + .setStatus(ByteString.copyFrom(status)) + .setRequestId(requestId) + .setQueryFetchSearchResult(queryFetchSearchResult) + .setAction(action) + .addAllFeatures(features) + .build(); + } + + public ProtobufInboundMessage( + long requestId, + byte[] status, + Version version, + ThreadContext threadContext, + QuerySearchResult querySearchResult, + Set features, + String action + ) { + Header header = Header.newBuilder() + .addAllPrefix(Arrays.asList(ByteString.copyFrom(PREFIX))) + .setRequestId(requestId) + .setStatus(ByteString.copyFrom(status)) + .setVersionId(version.id) + .build(); + Map requestHeaders = threadContext.getHeaders(); + Map> responseHeaders = threadContext.getResponseHeaders(); + Map responseHandlers = new HashMap<>(); + for (Map.Entry> entry : responseHeaders.entrySet()) { + String key = entry.getKey(); + List value = entry.getValue(); + ResponseHandlersList responseHandlersList = ResponseHandlersList.newBuilder().addAllSetOfResponseHandlers(value).build(); + responseHandlers.put(key, responseHandlersList); + } + this.message = NodeToNodeMessageProto.NodeToNodeMessage.newBuilder() + .setHeader(header) + .putAllRequestHeaders(requestHeaders) + .putAllResponseHandlers(responseHandlers) + .setVersion(version.toString()) + .setStatus(ByteString.copyFrom(status)) + .setRequestId(requestId) + .setQuerySearchResult(querySearchResult) + .setAction(action) + .addAllFeatures(features) + .build(); + } + + public ProtobufInboundMessage(InputStream in) throws IOException { + this.message = NodeToNodeMessageProto.NodeToNodeMessage.parseFrom(in); + } + + public void writeTo(OutputStream out) throws IOException { + this.message.writeTo(out); + } + + public BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { + NodeToNodeMessageProto.NodeToNodeMessage message = getMessage(); + TcpHeader.writeHeaderForProtobuf(bytesStream); + message.writeTo(bytesStream); + return bytesStream.bytes(); + } + + public NodeToNodeMessageProto.NodeToNodeMessage getMessage() { + return this.message; + } + + @Override + public String toString() { + return "ProtobufInboundMessage [message=" + message + "]"; + } + + public org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header getHeader() { + return this.message.getHeader(); + } + + public Map getRequestHeaders() { + return this.message.getRequestHeadersMap(); + } + + public Map> getResponseHandlers() { + Map responseHandlers = this.message.getResponseHandlersMap(); + Map> responseHandlersMap = new HashMap<>(); + for (Map.Entry entry : responseHandlers.entrySet()) { + String key = entry.getKey(); + ResponseHandlersList value = entry.getValue(); + Set setOfResponseHandlers = value.getSetOfResponseHandlersList().stream().collect(Collectors.toSet()); + responseHandlersMap.put(key, setOfResponseHandlers); + } + return responseHandlersMap; + } + + @Override + public String getProtocol() { + return PROTOBUF_PROTOCOL; + } + +} diff --git a/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufMessageHandler.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufMessageHandler.java new file mode 100644 index 0000000000000..a945bab7a345a --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufMessageHandler.java @@ -0,0 +1,178 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.protobufprotocol; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.ProtocolMessageHandler; +import org.opensearch.transport.RemoteTransportException; +import org.opensearch.transport.ResponseHandlerFailureTransportException; +import org.opensearch.transport.TcpChannel; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportMessageListener; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportSerializationException; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Map; +import java.util.Set; + +/** + * Protobuf handler for inbound data + * + * @opensearch.internal + */ +public class ProtobufMessageHandler implements ProtocolMessageHandler { + + private static final Logger logger = LogManager.getLogger(ProtobufMessageHandler.class); + + private final ThreadPool threadPool; + private final Transport.ResponseHandlers responseHandlers; + + private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; + + private volatile long slowLogThresholdMs = Long.MAX_VALUE; + + public ProtobufMessageHandler(ThreadPool threadPool, Transport.ResponseHandlers responseHandlers) { + this.threadPool = threadPool; + this.responseHandlers = responseHandlers; + } + + void setMessageListener(TransportMessageListener listener) { + if (messageListener == TransportMessageListener.NOOP_LISTENER) { + messageListener = listener; + } else { + throw new IllegalStateException("Cannot set message listener twice"); + } + } + + void setSlowLogThreshold(TimeValue slowLogThreshold) { + this.slowLogThresholdMs = slowLogThreshold.getMillis(); + } + + @Override + public void messageReceived( + TcpChannel channel, + ProtocolInboundMessage message, + long startTime, + long slowLogThresholdMs, + TransportMessageListener messageListener + ) throws IOException { + ProtobufInboundMessage nodeToNodeMessage = (ProtobufInboundMessage) message; + final InetSocketAddress remoteAddress = channel.getRemoteAddress(); + final org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header header = nodeToNodeMessage.getHeader(); + + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + // Place the context with the headers from the message + final Tuple, Map>> headers = new Tuple, Map>>( + nodeToNodeMessage.getRequestHeaders(), + nodeToNodeMessage.getResponseHandlers() + ); + threadContext.setHeaders(headers); + threadContext.putTransient("_remote_address", remoteAddress); + + long requestId = header.getRequestId(); + TransportResponseHandler handler = responseHandlers.onResponseReceived(requestId, messageListener); + if (handler != null) { + handleProtobufResponse(requestId, remoteAddress, nodeToNodeMessage, handler); + } + } finally { + final long took = threadPool.relativeTimeInMillis() - startTime; + final long logThreshold = slowLogThresholdMs; + if (logThreshold > 0 && took > logThreshold) { + logger.warn( + "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]", + message, + took, + logThreshold + ); + } + } + } + + private void handleProtobufResponse( + final long requestId, + InetSocketAddress remoteAddress, + final ProtobufInboundMessage message, + final TransportResponseHandler handler + ) throws IOException { + try { + org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage receivedMessage = message.getMessage(); + if (receivedMessage.hasQueryFetchSearchResult()) { + final QueryFetchSearchResult queryFetchSearchResult = receivedMessage.getQueryFetchSearchResult(); + org.opensearch.search.fetch.QueryFetchSearchResult queryFetchSearchResult2 = + new org.opensearch.search.fetch.QueryFetchSearchResult(queryFetchSearchResult); + final T response = (T) queryFetchSearchResult2; + response.remoteAddress(new TransportAddress(remoteAddress)); + + final String executor = handler.executor(); + if (ThreadPool.Names.SAME.equals(executor)) { + doHandleResponse(handler, response); + } else { + threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); + } + } else if (receivedMessage.hasQuerySearchResult()) { + final org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult querySearchResult = receivedMessage + .getQuerySearchResult(); + QuerySearchResult querySearchResult2 = new QuerySearchResult(querySearchResult); + final T response = (T) querySearchResult2; + response.remoteAddress(new TransportAddress(remoteAddress)); + + final String executor = handler.executor(); + if (ThreadPool.Names.SAME.equals(executor)) { + doHandleResponse(handler, response); + } else { + threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); + } + } + } catch (Exception e) { + final Exception serializationException = new TransportSerializationException( + "Failed to deserialize response from handler [" + handler + "]", + e + ); + logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException); + handleException(handler, serializationException); + return; + } + } + + private void doHandleResponse(TransportResponseHandler handler, T response) { + try { + handler.handleResponse(response); + } catch (Exception e) { + handleException(handler, new ResponseHandlerFailureTransportException(e)); + } + } + + private void handleException(final TransportResponseHandler handler, Throwable error) { + if (!(error instanceof RemoteTransportException)) { + error = new RemoteTransportException(error.getMessage(), error); + } + final RemoteTransportException rtx = (RemoteTransportException) error; + threadPool.executor(handler.executor()).execute(() -> { + try { + handler.handleException(rtx); + } catch (Exception e) { + logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e); + } + }); + } +} diff --git a/server/src/main/java/org/opensearch/transport/protobufprotocol/package-info.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/package-info.java new file mode 100644 index 0000000000000..8bccdc2be0b79 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Protobuf transport protocol package. */ +package org.opensearch.transport.protobufprotocol; diff --git a/server/src/main/proto/server/NodeToNodeMessageProto.proto b/server/src/main/proto/server/NodeToNodeMessageProto.proto new file mode 100644 index 0000000000000..e83143202dce7 --- /dev/null +++ b/server/src/main/proto/server/NodeToNodeMessageProto.proto @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/search/QueryFetchSearchResultProto.proto"; +import "server/search/QuerySearchResultProto.proto"; + +option java_outer_classname = "NodeToNodeMessageProto"; + +message NodeToNodeMessage { + Header header = 1; + map requestHeaders = 2; + map responseHandlers = 3; + string version = 4; + repeated string features = 5; + bytes status = 6; + string action = 7; + int64 requestId = 8; + oneof message { + QueryFetchSearchResult queryFetchSearchResult = 9; + QuerySearchResult querySearchResult = 10; + } + + message Header { + repeated bytes prefix = 1; + int64 requestId = 2; + bytes status = 3; + int32 versionId = 4; + } + + message ResponseHandlersList { + repeated string setOfResponseHandlers = 1; + } +} diff --git a/server/src/main/proto/server/search/FetchSearchResultProto.proto b/server/src/main/proto/server/search/FetchSearchResultProto.proto new file mode 100644 index 0000000000000..f0f4f495ad179 --- /dev/null +++ b/server/src/main/proto/server/search/FetchSearchResultProto.proto @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "google/protobuf/any.proto"; +import "server/search/QuerySearchResultProto.proto"; +import "server/search/ShardSearchRequestProto.proto"; + +option java_outer_classname = "FetchSearchResultProto"; + +message FetchSearchResult { + ShardSearchContextId contextId = 1; + optional SearchHits hits = 2; +} + +message SearchHits { + TotalHits totalHits = 1; + float maxScore = 2; + int32 size = 3; + repeated SearchHit hits = 4; + repeated SortField sortFields = 5; + optional string collapseField = 6; + repeated SortValue collapseValues = 7; +} + +message SearchHit { + int32 docId = 1; + float score = 2; + string id = 3; + optional NestedIdentity nestedIdentity = 4; + int64 version = 5; + int64 seqNo = 6; + int64 primaryTerm = 7; + bytes source = 8; + map documentFields = 9; + map metaFields = 10; + map highlightFields = 11; + SearchSortValues sortValues = 12; + repeated string matchedQueries = 13; + optional Explanation explanation = 14; + SearchShardTarget shard = 15; + optional string index = 16; + optional string clusterAlias = 17; + map innerHits = 18; + map matchedQueriesWithScores = 19; + + message NestedIdentity { + optional string field = 1; + optional int32 offset = 2; + optional NestedIdentity child = 3; + } + + message DocumentField { + string name = 1; + repeated DocumentFieldValue values = 2; + } + + message HighlightField { + string name = 1; + repeated string fragments = 2; + } + + message SearchSortValues { + repeated SortValue formattedSortValues = 1; + repeated SortValue rawSortValues = 2; + } + + message Explanation { + bool match = 1; + string description = 2; + repeated Explanation subExplanations = 3; + oneof explanationValue { + float value1 = 4; + double value2 = 5; + int64 value3 = 6; + } + } +} + +message SortField { + Type type = 1; + string field = 2; + + enum Type { + SCORE = 0; + DOC = 1; + STRING = 2; + INT = 3; + FLOAT = 4; + LONG = 5; + DOUBLE = 6; + CUSTOM = 7; + STRING_VAL = 8; + REWRITEABLE = 9; + } +} + +message SortValue { + optional string collapseString = 1; + optional int32 collapseInt = 2; + optional int64 collapseLong = 3; + optional float collapseFloat = 4; + optional double collapseDouble = 5; + optional bytes collapseBytes = 6; + optional bool collapseBool = 7; +} + +message DocumentFieldValue { + optional string valueString = 1; + optional int32 valueInt = 2; + optional int64 valueLong = 3; + optional float valueFloat = 4; + optional double valueDouble = 5; + optional bool valueBool = 6; + repeated bytes valueByteArray = 7; + repeated DocumentFieldValue valueArrayList = 8; + map valueMap = 9; + optional int64 valueDate = 10; + optional string valueZonedDate = 11; + optional int64 valueZonedTime = 12; + optional string valueText = 13; +} + diff --git a/server/src/main/proto/server/search/QueryFetchSearchResultProto.proto b/server/src/main/proto/server/search/QueryFetchSearchResultProto.proto new file mode 100644 index 0000000000000..deac135c0e3d0 --- /dev/null +++ b/server/src/main/proto/server/search/QueryFetchSearchResultProto.proto @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/search/QuerySearchResultProto.proto"; +import "server/search/FetchSearchResultProto.proto"; + +option java_outer_classname = "QueryFetchSearchResultProto"; + +message QueryFetchSearchResult { + QuerySearchResult queryResult = 1; + FetchSearchResult fetchResult = 2; +} diff --git a/server/src/main/proto/server/search/QuerySearchResultProto.proto b/server/src/main/proto/server/search/QuerySearchResultProto.proto new file mode 100644 index 0000000000000..e4fc82207042c --- /dev/null +++ b/server/src/main/proto/server/search/QuerySearchResultProto.proto @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/search/ShardSearchRequestProto.proto"; + +option java_outer_classname = "QuerySearchResultProto"; + +message QuerySearchResult { + ShardSearchContextId contextId = 1; + optional int32 from = 2; + optional int32 size = 3; + optional TopDocsAndMaxScore topDocsAndMaxScore = 4; + optional bool hasScoreDocs = 5; + optional TotalHits totalHits = 6; + optional float maxScore = 7; + optional TopDocs topDocs = 8; + optional bool hasAggs = 9; + optional bool hasSuggest = 10; + optional bool searchTimedOut = 11; + optional bool terminatedEarly = 12; + optional bytes profileShardResults = 13; + optional int64 serviceTimeEWMA = 14; + optional int32 nodeQueueSize = 15; + SearchShardTarget searchShardTarget = 17; + ShardSearchRequest searchShardRequest = 18; + bool isNull = 19; + + message TopDocsAndMaxScore { + TopDocs topDocs = 1; + float maxScore = 2; + } + + message TopDocs { + TotalHits totalHits = 1; + repeated ScoreDoc scoreDocs = 2; + + message ScoreDoc { + int32 doc = 1; + float score = 2; + int32 shardIndex = 3; + } + } + + message RescoreDocIds { + map docIds = 1; + + message setInteger { + repeated int32 values = 1; + } + } + +} + +message SearchShardTarget { + string nodeId = 1; + ShardId shardId = 2; + optional string clusterAlias = 3; +} + +message TotalHits { + int64 value = 1; + Relation relation = 2; + + enum Relation { + EQUAL_TO = 0; + GREATER_THAN_OR_EQUAL_TO = 1; + } +} diff --git a/server/src/main/proto/server/search/ShardSearchRequestProto.proto b/server/src/main/proto/server/search/ShardSearchRequestProto.proto new file mode 100644 index 0000000000000..114e9154ee5d9 --- /dev/null +++ b/server/src/main/proto/server/search/ShardSearchRequestProto.proto @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +option java_outer_classname = "ShardSearchRequestProto"; + +message ShardSearchRequest { + OriginalIndices originalIndices = 1; + ShardId shardId = 2; + int32 numberOfShards = 3; + SearchType searchType = 4; + bytes source = 5; + bool requestCache = 6; + AliasFilter aliasFilter = 7; + float indexBoost = 8; + bool allowPartialSearchResults = 9; + repeated string indexRoutings = 10; + string preference = 11; + Scroll scroll = 12; + int64 nowInMillis = 13; + optional string clusterAlias = 14; + optional ShardSearchContextId readerId = 15; + optional string timeValue = 16; + int64 inboundNetworkTime = 17; + int64 outboundNetworkTime = 18; + bool canReturnNullResponseIfMatchNoDocs = 19; + + enum SearchType { + QUERY_THEN_FETCH = 0; + DFS_QUERY_THEN_FETCH = 1; + } +} + +message ShardSearchContextId { + string sessionId = 1; + int64 id = 2; +} + +message ShardId { + int32 shardId = 1; + int32 hashCode = 2; + string indexName = 3; + string indexUUID = 4; +} + +message Scroll { + string keepAlive = 1; +} + +message OriginalIndices { + repeated string indices = 1; + IndicesOptions indicesOptions = 2; + + message IndicesOptions { + bool ignoreUnavailable = 1; + bool allowNoIndices = 2; + bool expandWildcardsOpen = 3; + bool expandWildcardsClosed = 4; + bool expandWildcardsHidden = 5; + bool allowAliasesToMultipleIndices = 6; + bool forbidClosedIndices = 7; + bool ignoreAliases = 8; + bool ignoreThrottled = 9; + } +} + +message AliasFilter { + repeated string aliases = 1; +} \ No newline at end of file diff --git a/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java b/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java index 41e4e1ae45a73..5ab6714a69116 100644 --- a/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java +++ b/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java @@ -39,9 +39,11 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.action.OriginalIndicesTests; import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.SuppressForbidden; import org.opensearch.common.UUIDs; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.index.shard.ShardId; @@ -54,8 +56,13 @@ import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.suggest.SuggestTests; +import org.opensearch.server.proto.QuerySearchResultProto; import org.opensearch.test.OpenSearchTestCase; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; + import static java.util.Collections.emptyList; public class QuerySearchResultTests extends OpenSearchTestCase { @@ -125,4 +132,30 @@ public void testNullResponse() throws Exception { QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new, Version.CURRENT); assertEquals(querySearchResult.isNull(), deserialized.isNull()); } + + @SuppressForbidden(reason = "manipulates system properties for testing") + public void testProtobufSerialization() throws Exception { + System.setProperty(FeatureFlags.PROTOBUF, "true"); + QuerySearchResult querySearchResult = createTestInstance(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + querySearchResult.writeTo(stream); + + InputStream inputStream = new ByteArrayInputStream(stream.toByteArray()); + QuerySearchResult deserialized = new QuerySearchResult(inputStream); + QuerySearchResultProto.QuerySearchResult querySearchResultProto = deserialized.response(); + assertNotNull(querySearchResultProto); + assertEquals(querySearchResult.getContextId().getId(), querySearchResultProto.getContextId().getId()); + assertEquals( + querySearchResult.getSearchShardTarget().getShardId().getIndex().getUUID(), + querySearchResultProto.getSearchShardTarget().getShardId().getIndexUUID() + ); + assertEquals(querySearchResult.topDocs().maxScore, querySearchResultProto.getTopDocsAndMaxScore().getMaxScore(), 0f); + assertEquals( + querySearchResult.topDocs().topDocs.totalHits.value, + querySearchResultProto.getTopDocsAndMaxScore().getTopDocs().getTotalHits().getValue() + ); + assertEquals(querySearchResult.from(), querySearchResultProto.getFrom()); + assertEquals(querySearchResult.size(), querySearchResultProto.getSize()); + System.setProperty(FeatureFlags.PROTOBUF, "false"); + } } diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 0d171e17e70e1..c6c9a33041a66 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -37,18 +37,23 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.OpenSearchException; import org.opensearch.Version; +import org.opensearch.common.SuppressForbidden; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.fetch.QueryFetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; import org.opensearch.tasks.TaskManager; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.MockLogAppender; @@ -56,9 +61,11 @@ import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import org.junit.After; import org.junit.Before; +import java.io.ByteArrayInputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; @@ -80,6 +87,9 @@ public class InboundHandlerTests extends OpenSearchTestCase { private final Version version = Version.CURRENT; private TaskManager taskManager; + private NamedWriteableRegistry namedWriteableRegistry; + private TransportHandshaker handshaker; + private TransportKeepAlive keepAlive; private Transport.ResponseHandlers responseHandlers; private Transport.RequestHandlers requestHandlers; private InboundHandler handler; @@ -98,8 +108,8 @@ public void sendMessage(BytesReference reference, ActionListener listener) } } }; - NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); - TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}); + namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}); outboundHandler = new OutboundHandler( "node", version, @@ -108,7 +118,7 @@ public void sendMessage(BytesReference reference, ActionListener listener) threadPool, BigArrays.NON_RECYCLING_INSTANCE ); - TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, outboundHandler::sendBytes); + keepAlive = new TransportKeepAlive(threadPool, outboundHandler::sendBytes); requestHandlers = new Transport.RequestHandlers(); responseHandlers = new Transport.ResponseHandlers(); handler = new InboundHandler( @@ -242,6 +252,104 @@ public TestResponse read(StreamInput in) throws IOException { } } + @SuppressForbidden(reason = "manipulates system properties for testing") + public void testProtobufResponse() throws Exception { + System.setProperty(FeatureFlags.PROTOBUF_SETTING.getKey(), "true"); + InboundHandler inboundHandler = new InboundHandler( + threadPool, + outboundHandler, + namedWriteableRegistry, + handshaker, + keepAlive, + requestHandlers, + responseHandlers, + NoopTracer.INSTANCE + ); + String action = "test-request"; + int headerSize = TcpHeader.headerSize(version); + AtomicReference requestCaptor = new AtomicReference<>(); + AtomicReference exceptionCaptor = new AtomicReference<>(); + AtomicReference responseCaptor = new AtomicReference<>(); + AtomicReference channelCaptor = new AtomicReference<>(); + + long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler() { + @Override + public void handleResponse(QueryFetchSearchResult response) { + responseCaptor.set(response); + } + + @Override + public void handleException(TransportException exp) { + exceptionCaptor.set(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public QueryFetchSearchResult read(StreamInput in) throws IOException { + throw new UnsupportedOperationException("Unimplemented method 'read'"); + } + + @Override + public QueryFetchSearchResult read(InputStream in) throws IOException { + return new QueryFetchSearchResult(in); + } + }, null, action)); + RequestHandlerRegistry registry = new RequestHandlerRegistry<>( + action, + TestRequest::new, + taskManager, + (request, channel, task) -> { + channelCaptor.set(channel); + requestCaptor.set(request); + }, + ThreadPool.Names.SAME, + false, + true + ); + requestHandlers.registerHandler(registry); + String requestValue = randomAlphaOfLength(10); + OutboundMessage.Request request = new OutboundMessage.Request( + threadPool.getThreadContext(), + new String[0], + new TestRequest(requestValue), + version, + action, + requestId, + false, + false + ); + + BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); + BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); + Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); + inboundHandler.inboundMessage(channel, requestMessage); + + TransportChannel transportChannel = channelCaptor.get(); + assertEquals(Version.CURRENT, transportChannel.getVersion()); + assertEquals("transport", transportChannel.getChannelType()); + assertEquals(requestValue, requestCaptor.get().value); + + QuerySearchResult queryResult = OutboundHandlerTests.createQuerySearchResult(); + FetchSearchResult fetchResult = OutboundHandlerTests.createFetchSearchResult(); + QueryFetchSearchResult response = new QueryFetchSearchResult(queryResult, fetchResult); + transportChannel.sendResponse(response); + + BytesReference fullResponseBytes = channel.getMessageCaptor().get(); + byte[] incomingBytes = BytesReference.toBytes(fullResponseBytes.slice(3, fullResponseBytes.length() - 3)); + ProtobufInboundMessage nodeToNodeMessage = new ProtobufInboundMessage(new ByteArrayInputStream(incomingBytes)); + inboundHandler.inboundMessage(channel, nodeToNodeMessage); + QueryFetchSearchResult result = responseCaptor.get(); + assertNotNull(result); + assertEquals(queryResult.getMaxScore(), result.queryResult().getMaxScore(), 0.0); + System.setProperty(FeatureFlags.PROTOBUF, "false"); + } + public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exception { // Nodes use their minimum compatibility version for the TCP handshake, so a node from v(major-1).x will report its version as // v(major-2).last in the TCP handshake, with which we are not really compatible. We put extra effort into making sure that if diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 2dfe8a0dd8590..03f10bb702144 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -95,7 +95,6 @@ public void testPipelineHandling() throws IOException { throw new AssertionError(e); } }; - final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index 36ba409a2de03..a9d8d3c45b9f9 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -34,29 +34,47 @@ import org.opensearch.OpenSearchException; import org.opensearch.Version; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.OriginalIndicesTests; +import org.opensearch.action.search.SearchRequest; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.UUIDs; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.io.Streams; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.transport.TransportResponse; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.fetch.QueryFetchSearchResult; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import org.junit.After; import org.junit.Before; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.concurrent.TimeUnit; @@ -76,6 +94,7 @@ public class OutboundHandlerTests extends OpenSearchTestCase { private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); private final TransportRequestOptions options = TransportRequestOptions.EMPTY; private final AtomicReference> message = new AtomicReference<>(); + private final AtomicReference protobufMessage = new AtomicReference<>(); private InboundPipeline pipeline; private OutboundHandler handler; private FakeTcpChannel channel; @@ -263,6 +282,66 @@ public void onResponseSent(long requestId, String action, TransportResponse resp assertEquals("header_value", header.getHeaders().v1().get("header")); } + @SuppressForbidden(reason = "manipulates system properties for testing") + public void testSendProtobufResponse() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = Version.CURRENT; + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + QuerySearchResult queryResult = createQuerySearchResult(); + FetchSearchResult fetchResult = createFetchSearchResult(); + QueryFetchSearchResult response = new QueryFetchSearchResult(queryResult, fetchResult); + System.setProperty(FeatureFlags.PROTOBUF, "true"); + assertTrue((response.getProtocol()).equals(ProtobufInboundMessage.PROTOBUF_PROTOCOL)); + + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(response); + } + }); + handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake); + + StatsTracker statsTracker = new StatsTracker(); + final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) requestCanTripBreaker -> true); + InboundPipeline inboundPipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { + ProtobufInboundMessage m1 = (ProtobufInboundMessage) m; + protobufMessage.set(BytesReference.fromByteBuffer(ByteBuffer.wrap(m1.getMessage().toByteArray()))); + }); + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(response, responseRef.get()); + + inboundPipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); + final BytesReference responseBytes = protobufMessage.get(); + final ProtobufInboundMessage message = new ProtobufInboundMessage(new ByteArrayInputStream(responseBytes.toBytesRef().bytes)); + assertEquals(version.toString(), message.getMessage().getVersion()); + assertEquals(requestId, message.getHeader().getRequestId()); + assertNotNull(message.getRequestHeaders()); + assertNotNull(message.getResponseHandlers()); + assertNotNull(message.getMessage()); + assertTrue(message.getMessage().hasQueryFetchSearchResult()); + System.setProperty(FeatureFlags.PROTOBUF, "false"); + } + public void testErrorResponse() throws IOException { ThreadContext threadContext = threadPool.getThreadContext(); Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); @@ -314,4 +393,35 @@ public void onResponseSent(long requestId, String action, Exception error) { assertEquals("header_value", header.getHeaders().v1().get("header")); } + + public static QuerySearchResult createQuerySearchResult() { + ShardId shardId = new ShardId("index", "uuid", randomInt()); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()); + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + OriginalIndicesTests.randomOriginalIndices(), + searchRequest, + shardId, + 1, + new AliasFilter(null, Strings.EMPTY_ARRAY), + 1.0f, + randomNonNegativeLong(), + null, + new String[0] + ); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId(UUIDs.base64UUID(), randomLong()), + new SearchShardTarget("node", shardId, null, OriginalIndices.NONE), + shardSearchRequest + ); + return result; + } + + public static FetchSearchResult createFetchSearchResult() { + ShardId shardId = new ShardId("index", "uuid", randomInt()); + FetchSearchResult result = new FetchSearchResult( + new ShardSearchContextId(UUIDs.base64UUID(), randomLong()), + new SearchShardTarget("node", shardId, null, OriginalIndices.NONE) + ); + return result; + } }