diff --git a/distribution/src/config/opensearch.yml b/distribution/src/config/opensearch.yml
index 10bab9b3fce92..e0e623652d015 100644
--- a/distribution/src/config/opensearch.yml
+++ b/distribution/src/config/opensearch.yml
@@ -125,3 +125,7 @@ ${path.logs}
 # Gates the functionality of enabling Opensearch to use pluggable caches with respective store names via setting.
 #
 #opensearch.experimental.feature.pluggable.caching.enabled: false
+#
+# Gates the functionality of enabling OpenSearch to use protobuf for basic searches and for node-to-node communication.
+#
+#opensearch.experimental.feature.search_with_protobuf.enabled: false
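The new flag mirrors the experimental pluggable-caching toggle directly above it. Elsewhere in this change the server gates the protobuf paths on the same flag; a minimal sketch of that guard, assuming the FeatureFlags.PROTOBUF constant that the touched classes below reference:

import org.opensearch.common.util.FeatureFlags;

final class ProtobufFeatureCheck {
    // Mirrors the guards in SearchHit, SearchHits and FetchSearchResult below:
    // the protobuf serialization path is taken only when the experimental
    // opensearch.experimental.feature.search_with_protobuf.enabled flag is set.
    static boolean useProtobuf() {
        return FeatureFlags.isEnabled(FeatureFlags.PROTOBUF);
    }
}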
diff --git a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java
index 569f48a8465f3..0d4e342bd9dc6 100644
--- a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java
+++ b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java
@@ -179,6 +179,11 @@ private void process(ExecutableElement executable, ReferenceType ref) {
             return;
         }
 
+        // Skip protobuf-generated classes used in public APIs
+        if (ref.toString().contains("Proto")) {
+            return;
+        }
+
         if (ref instanceof DeclaredType) {
             final DeclaredType declaredType = (DeclaredType) ref;
 
diff --git a/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java b/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java
index 7e69bb9dc6cbd..e62eaed1f51a4 100644
--- a/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java
+++ b/server/src/main/java/org/opensearch/action/ProtobufActionListenerResponseHandler.java
@@ -68,7 +68,6 @@ public String toString() {
 
     @Override
     public Response read(StreamInput in) throws IOException {
-        // TODO Auto-generated method stub
         throw new UnsupportedOperationException("Unimplemented method 'read'");
     }
 
diff --git a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java
index 3081e95dc5e26..f1b06378bd579 100644
--- a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java
+++ b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java
@@ -38,7 +38,6 @@
 import org.opensearch.common.lease.Releasable;
 import org.opensearch.common.lease.Releasables;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
-import org.opensearch.common.util.FeatureFlags;
 import org.opensearch.common.util.concurrent.AbstractRunnable;
 import org.opensearch.core.common.breaker.CircuitBreaker;
 import org.opensearch.core.common.breaker.CircuitBreakingException;
@@ -115,11 +114,7 @@ public QueryPhaseResultConsumer(
         SearchSourceBuilder source = request.source();
         this.hasTopDocs = source == null || source.size() != 0;
-        if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
-            this.hasAggs = false;
-        } else {
-            this.hasAggs = source != null && source.aggregations() != null;
-        }
+        this.hasAggs = source != null && source.aggregations() != null;
         int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
         this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
     }
@@ -325,7 +320,7 @@ synchronized long addEstimateAndMaybeBreak(long estimatedSize) {
     * provided {@link QuerySearchResult}.
     */
    long ramBytesUsedQueryResult(QuerySearchResult result) {
-        if (hasAggs == false || FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
+        if (hasAggs == false) {
            return 0;
        }
        return result.aggregations().asSerialized(InternalAggregations::readFrom, namedWriteableRegistry).ramBytesUsed();
@@ -494,7 +489,7 @@ public synchronized List<TopDocsAndMaxScore> consumeTopDocs() {
    }

    public synchronized List<InternalAggregations> consumeAggs() {
-        if (hasAggs == false || FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
+        if (hasAggs == false) {
            return Collections.emptyList();
        }
        List<InternalAggregations> aggsList = new ArrayList<>();
diff --git a/server/src/main/java/org/opensearch/common/document/DocumentField.java b/server/src/main/java/org/opensearch/common/document/DocumentField.java
index 71b0c4488dd82..383648e175dc0 100644
--- a/server/src/main/java/org/opensearch/common/document/DocumentField.java
+++ b/server/src/main/java/org/opensearch/common/document/DocumentField.java
@@ -32,34 +32,22 @@
 
 package org.opensearch.common.document;
 
-import com.google.protobuf.ByteString;
-import org.opensearch.OpenSearchException;
 import org.opensearch.common.annotation.PublicApi;
-import org.opensearch.common.util.FeatureFlags;
+import org.opensearch.common.document.serializer.DocumentFieldProtobufSerializer;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.common.io.stream.Writeable;
-import org.opensearch.core.common.text.Text;
 import org.opensearch.core.xcontent.ToXContentFragment;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.get.GetResult;
 import org.opensearch.search.SearchHit;
 import org.opensearch.server.proto.FetchSearchResultProto;
-import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue;
-import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder;
 
 import java.io.IOException;
-import java.time.Instant;
-import java.time.ZoneId;
-import java.time.ZonedDateTime;
 import java.util.ArrayList;
-import java.util.Date;
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.LinkedHashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
 
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -90,110 +78,16 @@ public DocumentField(String name, List<Object> values) {
         this.values = Objects.requireNonNull(values, "values must not be null");
     }
 
-    public DocumentField(byte[] in) throws IOException {
-        assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
-        documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(in);
-        name = documentField.getName();
-        values = new ArrayList<>();
-        for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) {
-            values.add(readDocumentFieldValueFromProtobuf(value));
-        }
-    }
-
     public static FetchSearchResultProto.SearchHit.DocumentField convertDocumentFieldToProto(DocumentField documentField) {
         FetchSearchResultProto.SearchHit.DocumentField.Builder builder = FetchSearchResultProto.SearchHit.DocumentField.newBuilder();
         builder.setName(documentField.getName());
         for (Object value : documentField.getValues()) {
             FetchSearchResultProto.DocumentFieldValue.Builder valueBuilder = FetchSearchResultProto.DocumentFieldValue.newBuilder();
-            builder.addValues(convertDocumentFieldValueToProto(value, valueBuilder));
+            builder.addValues(DocumentFieldProtobufSerializer.convertDocumentFieldValueToProto(value, valueBuilder));
         }
         return builder.build();
     }
 
-    private static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) {
-        if (value == null) {
-            // null is not allowed in protobuf, so we use a special string to represent null
-            return valueBuilder.setValueString("null");
-        }
-        Class<?> type = value.getClass();
-        if (type == String.class) {
-            valueBuilder.setValueString((String) value);
-        } else if (type == Integer.class) {
-            valueBuilder.setValueInt((Integer) value);
-        } else if (type == Long.class) {
-            valueBuilder.setValueLong((Long) value);
-        } else if (type == Float.class) {
-            valueBuilder.setValueFloat((Float) value);
-        } else if (type == Double.class) {
-            valueBuilder.setValueDouble((Double) value);
-        } else if (type == Boolean.class) {
-            valueBuilder.setValueBool((Boolean) value);
-        } else if (type == byte[].class) {
-            valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value));
-        } else if (type == List.class) {
-            List<?> list = (List<?>) value;
-            for (Object listValue : list) {
-                valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, valueBuilder));
-            }
-        } else if (type == Map.class || type == HashMap.class || type == LinkedHashMap.class) {
-            Map<String, Object> map = (Map<String, Object>) value;
-            for (Map.Entry<String, Object> entry : map.entrySet()) {
-                valueBuilder.putValueMap(entry.getKey(), convertDocumentFieldValueToProto(entry.getValue(), valueBuilder).build());
-            }
-        } else if (type == Date.class) {
-            valueBuilder.setValueDate(((Date) value).getTime());
-        } else if (type == ZonedDateTime.class) {
-            valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId());
-            valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli());
-        } else if (type == Text.class) {
-            valueBuilder.setValueText(((Text) value).string());
-        } else {
-            throw new OpenSearchException("Can't convert generic value of type [" + type + "] to protobuf");
-        }
-        return valueBuilder;
-    }
-
-    private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException {
-        if (documentFieldValue.hasValueString()) {
-            return documentFieldValue.getValueString();
-        } else if (documentFieldValue.hasValueInt()) {
-            return documentFieldValue.getValueInt();
-        } else if (documentFieldValue.hasValueLong()) {
-            return documentFieldValue.getValueLong();
-        } else if (documentFieldValue.hasValueFloat()) {
-            return documentFieldValue.getValueFloat();
-        } else if (documentFieldValue.hasValueDouble()) {
-            return documentFieldValue.getValueDouble();
-        } else if (documentFieldValue.hasValueBool()) {
-            return documentFieldValue.getValueBool();
-        } else if (documentFieldValue.getValueByteArrayList().size() > 0) {
-            return documentFieldValue.getValueByteArrayList().toArray();
-        } else if (documentFieldValue.getValueArrayListList().size() > 0) {
-            List<Object> list = new ArrayList<>();
-            for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) {
-                list.add(readDocumentFieldValueFromProtobuf(value));
-            }
-            return list;
-        } else if (documentFieldValue.getValueMapMap().size() > 0) {
-            Map<String, Object> map = Map.of();
-            for (Map.Entry<String, FetchSearchResultProto.DocumentFieldValue> entrySet : documentFieldValue.getValueMapMap().entrySet()) {
-                map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue()));
-            }
-            return map;
-        } else if (documentFieldValue.hasValueDate()) {
-            return new Date(documentFieldValue.getValueDate());
-        } else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) {
-            return ZonedDateTime.ofInstant(
-                Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()),
-                ZoneId.of(documentFieldValue.getValueZonedDate())
-            );
-        } else if (documentFieldValue.hasValueText()) {
-            return new Text(documentFieldValue.getValueText());
-        } else {
-            throw new IOException("Can't read generic value of type [" + documentFieldValue + "]");
-        }
-    }
-
     /**
      * The name of the field.
      */
diff --git a/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldProtobufSerializer.java b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldProtobufSerializer.java
new file mode 100644
index 0000000000000..614837ce7c9d3
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldProtobufSerializer.java
@@ -0,0 +1,135 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.common.document.serializer;
+
+import com.google.protobuf.ByteString;
+import org.opensearch.OpenSearchException;
+import org.opensearch.common.document.DocumentField;
+import org.opensearch.core.common.text.Text;
+import org.opensearch.server.proto.FetchSearchResultProto;
+import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue;
+import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.time.Instant;
+import java.time.ZoneId;
+import java.time.ZonedDateTime;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+public class DocumentFieldProtobufSerializer implements DocumentFieldSerializer<InputStream> {
+
+    private FetchSearchResultProto.SearchHit.DocumentField documentField;
+
+    @Override
+    public DocumentField createDocumentField(InputStream inputStream) throws IOException {
+        documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(inputStream);
+        String name = documentField.getName();
+        List<Object> values = new ArrayList<>();
+        for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) {
+            values.add(readDocumentFieldValueFromProtobuf(value));
+        }
+        return new DocumentField(name, values);
+    }
+
+    private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException {
+        if (documentFieldValue.hasValueString()) {
+            return documentFieldValue.getValueString();
+        } else if (documentFieldValue.hasValueInt()) {
+            return documentFieldValue.getValueInt();
+        } else if (documentFieldValue.hasValueLong()) {
+            return documentFieldValue.getValueLong();
+        } else if (documentFieldValue.hasValueFloat()) {
+            return documentFieldValue.getValueFloat();
+        } else if (documentFieldValue.hasValueDouble()) {
+            return documentFieldValue.getValueDouble();
+        } else if (documentFieldValue.hasValueBool()) {
+            return documentFieldValue.getValueBool();
+        } else if (documentFieldValue.getValueByteArrayList().size() > 0) {
+            return documentFieldValue.getValueByteArrayList().toArray();
+        } else if (documentFieldValue.getValueArrayListList().size() > 0) {
+            List<Object> list = new ArrayList<>();
+            for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) {
+                list.add(readDocumentFieldValueFromProtobuf(value));
+            }
+            return list;
+        } else if (documentFieldValue.getValueMapMap().size() > 0) {
+            Map<String, Object> map = new HashMap<>();
+            for (Map.Entry<String, FetchSearchResultProto.DocumentFieldValue> entrySet : documentFieldValue.getValueMapMap().entrySet()) {
+                map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue()));
+            }
+            return map;
+        } else if (documentFieldValue.hasValueDate()) {
+            return new Date(documentFieldValue.getValueDate());
+        } else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) {
+            return ZonedDateTime.ofInstant(
+                Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()),
+                ZoneId.of(documentFieldValue.getValueZonedDate())
+            );
+        } else if (documentFieldValue.hasValueText()) {
+            return new Text(documentFieldValue.getValueText());
+        } else {
+            throw new IOException("Can't read generic value of type [" + documentFieldValue + "]");
+        }
+    }
+
+    public static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) {
+        if (value == null) {
+            // null is not allowed in protobuf, so we use a special string to represent null
+            return valueBuilder.setValueString("null");
+        }
+        Class<?> type = value.getClass();
+        if (type == String.class) {
+            valueBuilder.setValueString((String) value);
+        } else if (type == Integer.class) {
+            valueBuilder.setValueInt((Integer) value);
+        } else if (type == Long.class) {
+            valueBuilder.setValueLong((Long) value);
+        } else if (type == Float.class) {
+            valueBuilder.setValueFloat((Float) value);
+        } else if (type == Double.class) {
+            valueBuilder.setValueDouble((Double) value);
+        } else if (type == Boolean.class) {
+            valueBuilder.setValueBool((Boolean) value);
+        } else if (type == byte[].class) {
+            valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value));
+        } else if (type == List.class) {
+            List<?> list = (List<?>) value;
+            for (Object listValue : list) {
+                valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, valueBuilder));
+            }
+        } else if (type == Map.class || type == HashMap.class || type == LinkedHashMap.class) {
+            Map<String, Object> map = (Map<String, Object>) value;
+            for (Map.Entry<String, Object> entry : map.entrySet()) {
+                valueBuilder.putValueMap(entry.getKey(), convertDocumentFieldValueToProto(entry.getValue(), valueBuilder).build());
+            }
+        } else if (type == Date.class) {
+            valueBuilder.setValueDate(((Date) value).getTime());
+        } else if (type == ZonedDateTime.class) {
+            valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId());
+            valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli());
+        } else if (type == Text.class) {
+            valueBuilder.setValueText(((Text) value).string());
+        } else {
+            throw new OpenSearchException("Can't convert generic value of type [" + type + "] to protobuf");
+        }
+        return valueBuilder;
+    }
+
+}
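A minimal round-trip sketch for the serializer above, not part of the change itself (field name and values are arbitrary example data; note that a null value is written as the sentinel string "null", as the converter's comment explains):

import org.opensearch.common.document.DocumentField;
import org.opensearch.common.document.serializer.DocumentFieldProtobufSerializer;
import org.opensearch.server.proto.FetchSearchResultProto;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;

public final class DocumentFieldRoundTrip {
    public static void main(String[] args) throws IOException {
        // The cast forces List<Object>, which the DocumentField constructor expects.
        DocumentField field = new DocumentField("tags", List.of((Object) "alpha", "beta"));
        FetchSearchResultProto.SearchHit.DocumentField proto = DocumentField.convertDocumentFieldToProto(field);
        DocumentField copy = new DocumentFieldProtobufSerializer()
            .createDocumentField(new ByteArrayInputStream(proto.toByteArray()));
        System.out.println(copy.getName() + " -> " + copy.getValues());
    }
}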
diff --git a/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldSerializer.java b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldSerializer.java
new file mode 100644
index 0000000000000..04d0b5a81bcf6
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldSerializer.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.common.document.serializer;
+
+import org.opensearch.common.document.DocumentField;
+
+import java.io.IOException;
+
+public interface DocumentFieldSerializer<T> {
+
+    DocumentField createDocumentField(T inputStream) throws IOException;
+
+}
diff --git a/server/src/main/java/org/opensearch/common/lucene/Lucene.java b/server/src/main/java/org/opensearch/common/lucene/Lucene.java
index 51025209a5fd1..d195859b63fed 100644
--- a/server/src/main/java/org/opensearch/common/lucene/Lucene.java
+++ b/server/src/main/java/org/opensearch/common/lucene/Lucene.java
@@ -85,7 +85,6 @@
 import org.opensearch.common.Nullable;
 import org.opensearch.common.SuppressForbidden;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
-import org.opensearch.common.util.FeatureFlags;
 import org.opensearch.common.util.iterable.Iterables;
 import org.opensearch.core.common.Strings;
 import org.opensearch.core.common.io.stream.StreamInput;
@@ -654,7 +653,6 @@ public static Explanation readExplanation(StreamInput in) throws IOException {
     }
 
     public static Explanation readExplanation(byte[] in) throws IOException {
-        assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
         FetchSearchResultProto.SearchHit.Explanation explanationProto = FetchSearchResultProto.SearchHit.Explanation.parseFrom(in);
         boolean match = explanationProto.getMatch();
         String description = explanationProto.getDescription();
diff --git a/server/src/main/java/org/opensearch/search/SearchHit.java b/server/src/main/java/org/opensearch/search/SearchHit.java
index 1e43e60611eb9..3cd85488bd0c6 100644
--- a/server/src/main/java/org/opensearch/search/SearchHit.java
+++ b/server/src/main/java/org/opensearch/search/SearchHit.java
@@ -32,7 +32,6 @@
 
 package org.opensearch.search;
 
-import com.google.protobuf.ByteString;
 import org.apache.lucene.search.Explanation;
 import org.opensearch.OpenSearchParseException;
 import org.opensearch.Version;
@@ -69,12 +68,11 @@
 import org.opensearch.rest.action.search.RestSearchAction;
 import org.opensearch.search.fetch.subphase.highlight.HighlightField;
 import org.opensearch.search.lookup.SourceLookup;
+import org.opensearch.search.serializer.SearchHitProtobufSerializer;
 import org.opensearch.server.proto.FetchSearchResultProto;
 import org.opensearch.transport.RemoteClusterAware;
 
 import java.io.IOException;
-import java.io.OutputStream;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -172,7 +170,7 @@ public SearchHit(
         this.documentFields = documentFields == null ? emptyMap() : documentFields;
         this.metaFields = metaFields == null ? emptyMap() : metaFields;
         if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
-            this.searchHitProto = convertHitToProto(this);
+            this.searchHitProto = SearchHitProtobufSerializer.convertHitToProto(this);
         }
     }
 
@@ -246,114 +244,6 @@ public SearchHit(StreamInput in) throws IOException {
         }
     }
 
-    public SearchHit(byte[] in) throws IOException {
-        assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
-        this.searchHitProto = FetchSearchResultProto.SearchHit.parseFrom(in);
-        this.docId = -1;
-        this.score = this.searchHitProto.getScore();
-        this.id = new Text(this.searchHitProto.getId());
-        if (!this.searchHitProto.hasNestedIdentity() && this.searchHitProto.getNestedIdentity().toByteArray().length > 0) {
-            this.nestedIdentity = new NestedIdentity(this.searchHitProto.getNestedIdentity().toByteArray());
-        } else {
-            this.nestedIdentity = null;
-        }
-        this.version = this.searchHitProto.getVersion();
-        this.seqNo = this.searchHitProto.getSeqNo();
-        this.primaryTerm = this.searchHitProto.getPrimaryTerm();
-        this.source = BytesReference.fromByteBuffer(ByteBuffer.wrap(this.searchHitProto.getSource().toByteArray()));
-        if (source.length() == 0) {
-            source = null;
-        }
-        this.documentFields = new HashMap<>();
-        this.searchHitProto.getDocumentFieldsMap().forEach((k, v) -> {
-            try {
-                this.documentFields.put(k, new DocumentField(v.toByteArray()));
-            } catch (IOException e) {
-                throw new OpenSearchParseException("failed to parse document field", e);
-            }
-        });
-        this.metaFields = new HashMap<>();
-        this.searchHitProto.getMetaFieldsMap().forEach((k, v) -> {
-            try {
-                this.metaFields.put(k, new DocumentField(v.toByteArray()));
-            } catch (IOException e) {
-                throw new OpenSearchParseException("failed to parse document field", e);
-            }
-        });
-        this.highlightFields = new HashMap<>();
-        this.searchHitProto.getHighlightFieldsMap().forEach((k, v) -> {
-            try {
-                this.highlightFields.put(k, new HighlightField(v.toByteArray()));
-            } catch (IOException e) {
-                throw new OpenSearchParseException("failed to parse highlight field", e);
-            }
-        });
-        this.sortValues = new SearchSortValues(this.searchHitProto.getSortValues().toByteArray());
-        if (this.searchHitProto.getMatchedQueriesCount() > 0) {
-            this.matchedQueries = new LinkedHashMap<>(this.searchHitProto.getMatchedQueriesCount());
-            for (String query : this.searchHitProto.getMatchedQueriesList()) {
-                matchedQueries.put(query, Float.NaN);
-            }
-        }
-        if (this.searchHitProto.getMatchedQueriesWithScoresCount() > 0) {
-            Map<String, Float> tempMap = this.searchHitProto.getMatchedQueriesWithScoresMap()
-                .entrySet()
-                .stream()
-                .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue()));
-            this.matchedQueries = tempMap.entrySet()
-                .stream()
-                .sorted(Map.Entry.comparingByKey())
-                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new));
-        }
-        if (this.searchHitProto.hasExplanation()) {
-            this.explanation = readExplanation(this.searchHitProto.getExplanation().toByteArray());
-        }
-        SearchShardTarget searchShardTarget = new SearchShardTarget(
-            this.searchHitProto.getShard().getNodeId(),
-            new ShardId(
-                this.searchHitProto.getShard().getShardId().getIndexName(),
-                this.searchHitProto.getShard().getShardId().getIndexUUID(),
-                this.searchHitProto.getShard().getShardId().getShardId()
-            ),
-            this.searchHitProto.getShard().getClusterAlias(),
-            OriginalIndices.NONE
-        );
-        shard(searchShardTarget);
-        if (this.searchHitProto.getInnerHitsCount() > 0) {
-            this.innerHits = new HashMap<>();
-            this.searchHitProto.getInnerHitsMap().forEach((k, v) -> {
-                try {
-                    this.innerHits.put(k, new SearchHits(v.toByteArray()));
-                } catch (IOException e) {
-                    throw new OpenSearchParseException("failed to parse inner hits", e);
-                }
-            });
-        } else {
-            this.innerHits = null;
-        }
-
-    }
-
-    public static FetchSearchResultProto.SearchHit convertHitToProto(SearchHit hit) {
-        FetchSearchResultProto.SearchHit.Builder searchHitBuilder = FetchSearchResultProto.SearchHit.newBuilder();
-        if (hit.getIndex() != null) {
-            searchHitBuilder.setIndex(hit.getIndex());
-        }
-        searchHitBuilder.setId(hit.getId());
-        searchHitBuilder.setScore(hit.getScore());
-        searchHitBuilder.setSeqNo(hit.getSeqNo());
-        searchHitBuilder.setPrimaryTerm(hit.getPrimaryTerm());
-        searchHitBuilder.setVersion(hit.getVersion());
-        searchHitBuilder.setDocId(hit.docId);
-        if (hit.getSourceRef() != null) {
-            searchHitBuilder.setSource(ByteString.copyFrom(hit.getSourceRef().toBytesRef().bytes));
-        }
-        for (Map.Entry<String, DocumentField> entry : hit.getFields().entrySet()) {
-            searchHitBuilder.putDocumentFields(entry.getKey(), DocumentField.convertDocumentFieldToProto(entry.getValue()));
-        }
-        return searchHitBuilder.build();
-    }
-
     private static final Text SINGLE_MAPPING_TYPE = new Text(MapperService.SINGLE_MAPPING_NAME);
 
     @Override
@@ -408,10 +298,6 @@ public void writeTo(StreamOutput out) throws IOException {
         }
     }
 
-    public void writeTo(OutputStream out) throws IOException {
-        out.write(this.searchHitProto.toByteArray());
-    }
-
     public int docId() {
         return this.docId;
     }
@@ -571,7 +457,7 @@ public void setDocumentField(String fieldName, DocumentField field) {
         if (documentFields.isEmpty()) this.documentFields = new HashMap<>();
         this.documentFields.put(fieldName, field);
         if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
-            this.searchHitProto = convertHitToProto(this);
+            this.searchHitProto = SearchHitProtobufSerializer.convertHitToProto(this);
         }
     }
 
@@ -1188,26 +1074,6 @@ public NestedIdentity(String field, int offset, NestedIdentity child) {
             child = in.readOptionalWriteable(NestedIdentity::new);
         }
 
-        NestedIdentity(byte[] in) throws IOException {
-            assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
-            FetchSearchResultProto.SearchHit.NestedIdentity proto = FetchSearchResultProto.SearchHit.NestedIdentity.parseFrom(in);
-            if (proto.hasField()) {
-                field = new Text(proto.getField());
-            } else {
-                field = null;
-            }
-            if (proto.hasOffset()) {
-                offset = proto.getOffset();
-            } else {
-                offset = -1;
-            }
-            if (proto.hasChild()) {
-                child = new NestedIdentity(proto.getChild().toByteArray());
-            } else {
-                child = null;
-            }
-        }
-
         /**
          * Returns the nested field in the source this hit originates from
          */
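With the byte[] constructors removed, deserialization flows through the serializer classes introduced later in this diff; a sketch of the replacement call pattern (protoBytes stands for a serialized FetchSearchResultProto.SearchHit):

import org.opensearch.search.SearchHit;
import org.opensearch.search.serializer.SearchHitProtobufSerializer;

import java.io.ByteArrayInputStream;
import java.io.IOException;

public final class SearchHitParseSketch {
    // Where callers previously wrote `new SearchHit(bytes)`, they now parse
    // through the serializer added below.
    static SearchHit parse(byte[] protoBytes) throws IOException {
        return new SearchHitProtobufSerializer().createSearchHit(new ByteArrayInputStream(protoBytes));
    }
}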
diff --git a/server/src/main/java/org/opensearch/search/SearchHits.java b/server/src/main/java/org/opensearch/search/SearchHits.java
index 630a7c90f8b3d..cd5854898fccd 100644
--- a/server/src/main/java/org/opensearch/search/SearchHits.java
+++ b/server/src/main/java/org/opensearch/search/SearchHits.java
@@ -32,13 +32,9 @@
 
 package org.opensearch.search;
 
-import com.google.protobuf.ByteString;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.TotalHits.Relation;
-import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.SuppressForbidden;
-import org.opensearch.OpenSearchException;
 import org.opensearch.common.Nullable;
 import org.opensearch.common.annotation.PublicApi;
 import org.opensearch.common.lucene.Lucene;
@@ -50,12 +46,9 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.rest.action.search.RestSearchAction;
-import org.opensearch.server.proto.FetchSearchResultProto;
-import org.opensearch.server.proto.QuerySearchResultProto;
+import org.opensearch.search.serializer.SearchHitsProtobufSerializer;
 
 import java.io.IOException;
-import java.io.OutputStream;
-import java.math.BigInteger;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Iterator;
@@ -112,84 +105,10 @@ public SearchHits(
         this.collapseField = collapseField;
         this.collapseValues = collapseValues;
         if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
-            this.searchHitsProto = convertHitsToProto(this);
+            this.searchHitsProto = SearchHitsProtobufSerializer.convertHitsToProto(this);
         }
     }
 
-    public static FetchSearchResultProto.SearchHits convertHitsToProto(SearchHits hits) {
-        List<FetchSearchResultProto.SearchHit> searchHitList = new ArrayList<>();
-        for (SearchHit hit : hits) {
-            searchHitList.add(SearchHit.convertHitToProto(hit));
-        }
-        QuerySearchResultProto.TotalHits.Builder totalHitsBuilder = QuerySearchResultProto.TotalHits.newBuilder();
-        if (hits.getTotalHits() != null) {
-            totalHitsBuilder.setValue(hits.getTotalHits().value);
-            totalHitsBuilder.setRelation(
-                hits.getTotalHits().relation == Relation.EQUAL_TO
-                    ? QuerySearchResultProto.TotalHits.Relation.EQUAL_TO
-                    : QuerySearchResultProto.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
-            );
-        }
-        FetchSearchResultProto.SearchHits.Builder searchHitsBuilder = FetchSearchResultProto.SearchHits.newBuilder();
-        searchHitsBuilder.setMaxScore(hits.getMaxScore());
-        searchHitsBuilder.addAllHits(searchHitList);
-        searchHitsBuilder.setTotalHits(totalHitsBuilder.build());
-        if (hits.getSortFields() != null && hits.getSortFields().length > 0) {
-            for (SortField sortField : hits.getSortFields()) {
-                FetchSearchResultProto.SortField.Builder sortFieldBuilder = FetchSearchResultProto.SortField.newBuilder();
-                if (sortField.getField() != null) {
-                    sortFieldBuilder.setField(sortField.getField());
-                }
-                sortFieldBuilder.setType(FetchSearchResultProto.SortField.Type.valueOf(sortField.getType().name()));
-                searchHitsBuilder.addSortFields(sortFieldBuilder.build());
-            }
-        }
-        if (hits.getCollapseField() != null) {
-            searchHitsBuilder.setCollapseField(hits.getCollapseField());
-            for (Object value : hits.getCollapseValues()) {
-                FetchSearchResultProto.CollapseValue.Builder collapseValueBuilder = FetchSearchResultProto.CollapseValue.newBuilder();
-                try {
-                    collapseValueBuilder = readCollapseValueForProtobuf(value, collapseValueBuilder);
-                } catch (IOException e) {
-                    throw new OpenSearchException(e);
-                }
-                searchHitsBuilder.addCollapseValues(collapseValueBuilder.build());
-            }
-        }
-        return searchHitsBuilder.build();
-    }
-
-    private static FetchSearchResultProto.CollapseValue.Builder readCollapseValueForProtobuf(
-        Object collapseValue,
-        FetchSearchResultProto.CollapseValue.Builder collapseValueBuilder
-    ) throws IOException {
-        Class<?> type = collapseValue.getClass();
-        if (type == String.class) {
-            collapseValueBuilder.setCollapseString((String) collapseValue);
-        } else if (type == Integer.class || type == Short.class) {
-            collapseValueBuilder.setCollapseInt((Integer) collapseValue);
-        } else if (type == Long.class) {
-            collapseValueBuilder.setCollapseLong((Long) collapseValue);
-        } else if (type == Float.class) {
-            collapseValueBuilder.setCollapseFloat((Float) collapseValue);
-        } else if (type == Double.class) {
-            collapseValueBuilder.setCollapseDouble((Double) collapseValue);
-        } else if (type == Byte.class) {
-            byte b = (Byte) collapseValue;
-            collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(new byte[] { b }));
-        } else if (type == Boolean.class) {
-            collapseValueBuilder.setCollapseBool((Boolean) collapseValue);
-        } else if (type == BytesRef.class) {
-            collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(((BytesRef) collapseValue).bytes));
-        } else if (type == BigInteger.class) {
-            BigInteger bigInt = (BigInteger) collapseValue;
-            collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(bigInt.toByteArray()));
-        } else {
-            throw new IOException("Can't handle sort field value of type [" + type + "]");
-        }
-        return collapseValueBuilder;
-    }
-
     public SearchHits(StreamInput in) throws IOException {
         if (in.readBoolean()) {
             totalHits = Lucene.readTotalHits(in);
@@ -212,50 +131,6 @@ public SearchHits(StreamInput in) throws IOException {
         collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new);
     }
 
-    @SuppressForbidden(reason = "serialization of object to protobuf")
-    public SearchHits(byte[] in) throws IOException {
-        assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
-        this.searchHitsProto = org.opensearch.server.proto.FetchSearchResultProto.SearchHits.parseFrom(in);
-        this.hits = new SearchHit[this.searchHitsProto.getHitsCount()];
-        for (int i = 0; i < this.searchHitsProto.getHitsCount(); i++) {
-            this.hits[i] = new SearchHit(this.searchHitsProto.getHits(i).toByteArray());
-        }
-        this.totalHits = new TotalHits(
-            this.searchHitsProto.getTotalHits().getValue(),
-            Relation.valueOf(this.searchHitsProto.getTotalHits().getRelation().toString())
-        );
-        this.maxScore = this.searchHitsProto.getMaxScore();
-        this.sortFields = this.searchHitsProto.getSortFieldsList()
-            .stream()
-            .map(sortField -> new SortField(sortField.getField(), SortField.Type.valueOf(sortField.getType().toString())))
-            .toArray(SortField[]::new);
-        this.collapseField = this.searchHitsProto.getCollapseField();
-        this.collapseValues = new Object[this.searchHitsProto.getCollapseValuesCount()];
-        for (int i = 0; i < this.searchHitsProto.getCollapseValuesCount(); i++) {
-            this.collapseValues[i] = readCollapseValueFromProtobuf(this.searchHitsProto.getCollapseValues(i));
-        }
-    }
-
-    private Object readCollapseValueFromProtobuf(FetchSearchResultProto.CollapseValue collapseValue) throws IOException {
-        if (collapseValue.hasCollapseString()) {
-            return collapseValue.getCollapseString();
-        } else if (collapseValue.hasCollapseInt()) {
-            return collapseValue.getCollapseInt();
-        } else if (collapseValue.hasCollapseLong()) {
-            return collapseValue.getCollapseLong();
-        } else if (collapseValue.hasCollapseFloat()) {
-            return collapseValue.getCollapseFloat();
-        } else if (collapseValue.hasCollapseDouble()) {
-            return collapseValue.getCollapseDouble();
-        } else if (collapseValue.hasCollapseBytes()) {
-            return new BytesRef(collapseValue.getCollapseBytes().toByteArray());
-        } else if (collapseValue.hasCollapseBool()) {
-            return collapseValue.getCollapseBool();
-        } else {
-            throw new IOException("Can't handle sort field value of type [" + collapseValue + "]");
-        }
-    }
-
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         final boolean hasTotalHits = totalHits != null;
@@ -475,7 +350,4 @@ private static Relation parseRelation(String relation) {
         }
     }
 
-    public void writeTo(OutputStream out) throws IOException {
-        out.write(searchHitsProto.toByteArray());
-    }
 }
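The collapse-value conversion removed here reappears in SearchHitsProtobufSerializer below. One behavioral subtlety worth noting: Byte and BigInteger values are written into the shared collapseBytes field, so they decode back as a Lucene BytesRef rather than the original boxed type. A small sketch of that asymmetry, using the static helpers defined later in this diff:

import org.apache.lucene.util.BytesRef;
import org.opensearch.search.serializer.SearchHitsProtobufSerializer;
import org.opensearch.server.proto.FetchSearchResultProto;

import java.io.IOException;
import java.math.BigInteger;

public final class CollapseValueSketch {
    public static void main(String[] args) throws IOException {
        FetchSearchResultProto.SortValue.Builder builder = FetchSearchResultProto.SortValue.newBuilder();
        SearchHitsProtobufSerializer.readSortValueForProtobuf(BigInteger.TEN, builder);
        Object decoded = SearchHitsProtobufSerializer.readSortValueFromProtobuf(builder.build());
        // true: the BigInteger does not round-trip as-is
        System.out.println(decoded instanceof BytesRef);
    }
}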
diff --git a/server/src/main/java/org/opensearch/search/SearchSortValues.java b/server/src/main/java/org/opensearch/search/SearchSortValues.java
index 1eef677ae6863..d03cc80b90de3 100644
--- a/server/src/main/java/org/opensearch/search/SearchSortValues.java
+++ b/server/src/main/java/org/opensearch/search/SearchSortValues.java
@@ -32,13 +32,9 @@
 
 package org.opensearch.search;
 
-import com.google.protobuf.ByteString;
 import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.SuppressForbidden;
-import org.opensearch.OpenSearchException;
 import org.opensearch.common.annotation.PublicApi;
 import org.opensearch.common.lucene.Lucene;
-import org.opensearch.common.util.FeatureFlags;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.common.io.stream.Writeable;
@@ -47,12 +43,8 @@
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.core.xcontent.XContentParserUtils;
 import org.opensearch.search.SearchHit.Fields;
-import org.opensearch.server.proto.FetchSearchResultProto;
 
-import java.io.ByteArrayInputStream;
 import java.io.IOException;
-import java.io.InputStream;
-import java.io.ObjectInputStream;
 import java.util.Arrays;
 import java.util.Objects;
 
@@ -75,6 +67,11 @@ public class SearchSortValues implements ToXContentFragment, Writeable {
         this.rawSortValues = EMPTY_ARRAY;
     }
 
+    public SearchSortValues(Object[] sortValues, Object[] rawSortValues) {
+        this.formattedSortValues = Objects.requireNonNull(sortValues, "sort values must not be null");
+        this.rawSortValues = rawSortValues;
+    }
+
     public SearchSortValues(Object[] rawSortValues, DocValueFormat[] sortValueFormats) {
         Objects.requireNonNull(rawSortValues);
         Objects.requireNonNull(sortValueFormats);
@@ -102,34 +99,6 @@ public SearchSortValues(Object[] rawSortValues, DocValueFormat[] sortValueFormats) {
         this.rawSortValues = in.readArray(Lucene::readSortValue, Object[]::new);
     }
 
-    @SuppressForbidden(reason = "We need to read from a byte array")
-    SearchSortValues(byte[] in) throws IOException {
-        assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
-        FetchSearchResultProto.SearchHit.SearchSortValues searchSortValues = FetchSearchResultProto.SearchHit.SearchSortValues.parseFrom(
-            in
-        );
-        this.formattedSortValues = new Object[searchSortValues.getFormattedSortValuesCount()];
-        for (int i = 0; i < searchSortValues.getFormattedSortValuesCount(); i++) {
-            ByteString formattedSortValue = searchSortValues.getFormattedSortValues(i);
-            InputStream is = new ByteArrayInputStream(formattedSortValue.toByteArray());
-            try (ObjectInputStream ois = new ObjectInputStream(is)) {
-                this.formattedSortValues[i] = ois.readObject();
-            } catch (ClassNotFoundException e) {
-                throw new OpenSearchException(e);
-            }
-        }
-        this.rawSortValues = new Object[searchSortValues.getRawSortValuesCount()];
-        for (int i = 0; i < searchSortValues.getRawSortValuesCount(); i++) {
-            ByteString rawSortValue = searchSortValues.getRawSortValues(i);
-            InputStream is = new ByteArrayInputStream(rawSortValue.toByteArray());
-            try (ObjectInputStream ois = new ObjectInputStream(is)) {
-                this.rawSortValues[i] = ois.readObject();
-            } catch (ClassNotFoundException e) {
-                throw new OpenSearchException(e);
-            }
-        }
-    }
-
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeArray(Lucene::writeSortValue, this.formattedSortValues);
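The new public constructor exists so that external serializers (such as the SearchSortValuesProtobufSerializer referenced further down) can rebuild an instance from already-decoded arrays; a trivial usage sketch with made-up values:

import org.opensearch.search.SearchSortValues;

public final class SortValuesSketch {
    public static void main(String[] args) {
        // Formatted values are the rendered sort keys; raw values are the originals.
        SearchSortValues values = new SearchSortValues(new Object[] { "2.5", 42L }, new Object[] { 2.5d, 42L });
    }
}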
diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java
index a9c70d336225d..b46b2c5ca12d8 100644
--- a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java
+++ b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java
@@ -42,9 +42,11 @@
 import org.opensearch.search.SearchShardTarget;
 import org.opensearch.search.internal.ShardSearchContextId;
 import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.serializer.SearchHitsProtobufSerializer;
 import org.opensearch.server.proto.FetchSearchResultProto;
 import org.opensearch.server.proto.ShardSearchRequestProto;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -79,7 +81,8 @@ public FetchSearchResult(InputStream in) throws IOException {
             this.fetchSearchResultProto.getContextId().getSessionId(),
             this.fetchSearchResultProto.getContextId().getId()
         );
-        hits = new SearchHits(this.fetchSearchResultProto.getHits().toByteArray());
+        SearchHitsProtobufSerializer protobufSerializer = new SearchHitsProtobufSerializer();
+        hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray()));
     }
 
     public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) {
@@ -106,7 +109,9 @@ public void hits(SearchHits hits) {
         assert assertNoSearchTarget(hits);
         this.hits = hits;
         if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) {
-            this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder().setHits(SearchHits.convertHitsToProto(hits)).build();
+            this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder()
+                .setHits(SearchHitsProtobufSerializer.convertHitsToProto(hits))
+                .build();
         }
     }
 
@@ -121,7 +126,8 @@ public SearchHits hits() {
         if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) {
             SearchHits hits;
             try {
-                hits = new SearchHits(this.fetchSearchResultProto.getHits().toByteArray());
+                SearchHitsProtobufSerializer protobufSerializer = new SearchHitsProtobufSerializer();
+                hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray()));
                 return hits;
             } catch (IOException e) {
                 throw new RuntimeException(e);
diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/HighlightField.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/HighlightField.java
index 7ea5f5fdca7d0..30effe2826d76 100644
--- a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/HighlightField.java
+++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/HighlightField.java
@@ -33,7 +33,6 @@
 package org.opensearch.search.fetch.subphase.highlight;
 
 import org.opensearch.common.annotation.PublicApi;
-import org.opensearch.common.util.FeatureFlags;
 import org.opensearch.core.common.ParsingException;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
@@ -42,7 +41,6 @@
 import org.opensearch.core.xcontent.ToXContentFragment;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
-import org.opensearch.server.proto.FetchSearchResultProto;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -79,21 +77,6 @@ public HighlightField(StreamInput in) throws IOException {
         }
     }
 
-    public HighlightField(byte[] in) throws IOException {
-        assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
-        FetchSearchResultProto.SearchHit.HighlightField highlightField = FetchSearchResultProto.SearchHit.HighlightField.parseFrom(in);
-        name = highlightField.getName();
-        if (highlightField.getFragmentsCount() == 0) {
-            fragments = Text.EMPTY_ARRAY;
-        } else {
-            List<Text> values = new ArrayList<>();
-            for (String fragment : highlightField.getFragmentsList()) {
-                values.add(new Text(fragment));
-            }
-            fragments = values.toArray(new Text[0]);
-        }
-    }
-
     public HighlightField(String name, Text[] fragments) {
         this.name = Objects.requireNonNull(name, "missing highlight field name");
         this.fragments = fragments;
diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldProtobufSerializer.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldProtobufSerializer.java
new file mode 100644
index 0000000000000..dbf9ac5ac48ee
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldProtobufSerializer.java
@@ -0,0 +1,44 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.fetch.subphase.highlight.serializer;
+
+import org.opensearch.core.common.text.Text;
+import org.opensearch.search.fetch.subphase.highlight.HighlightField;
+import org.opensearch.server.proto.FetchSearchResultProto;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+public class HighlightFieldProtobufSerializer implements HighlightFieldSerializer {
+
+    @Override
+    public HighlightField createHighLightField(InputStream inputStream) throws IOException {
+        FetchSearchResultProto.SearchHit.HighlightField highlightField = FetchSearchResultProto.SearchHit.HighlightField.parseFrom(
+            inputStream
+        );
+        String name = highlightField.getName();
+        Text[] fragments = Text.EMPTY_ARRAY;
+        if (highlightField.getFragmentsCount() > 0) {
+            List<Text> values = new ArrayList<>();
+            for (String fragment : highlightField.getFragmentsList()) {
+                values.add(new Text(fragment));
+            }
+            fragments = values.toArray(new Text[0]);
+        }
+        return new HighlightField(name, fragments);
+    }
+
+}
diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldSerializer.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldSerializer.java
new file mode 100644
index 0000000000000..d62016969885d
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldSerializer.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.fetch.subphase.highlight.serializer;
+
+import org.opensearch.search.fetch.subphase.highlight.HighlightField;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+public interface HighlightFieldSerializer {
+
+    HighlightField createHighLightField(InputStream inputStream) throws IOException;
+}
diff --git a/server/src/main/java/org/opensearch/search/serializer/NestedIdentityProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/NestedIdentityProtobufSerializer.java
new file mode 100644
index 0000000000000..fc6902193af84
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/NestedIdentityProtobufSerializer.java
@@ -0,0 +1,49 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import org.opensearch.search.SearchHit.NestedIdentity;
+import org.opensearch.server.proto.FetchSearchResultProto;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+public class NestedIdentityProtobufSerializer implements NestedIdentitySerializer<InputStream> {
+
+    @Override
+    public NestedIdentity createNestedIdentity(InputStream inputStream) throws IOException {
+        FetchSearchResultProto.SearchHit.NestedIdentity proto = FetchSearchResultProto.SearchHit.NestedIdentity.parseFrom(inputStream);
+        String field;
+        int offset;
+        NestedIdentity child;
+        if (proto.hasField()) {
+            field = proto.getField();
+        } else {
+            field = null;
+        }
+        if (proto.hasOffset()) {
+            offset = proto.getOffset();
+        } else {
+            offset = -1;
+        }
+        if (proto.hasChild()) {
+            child = createNestedIdentity(new ByteArrayInputStream(proto.getChild().toByteArray()));
+        } else {
+            child = null;
+        }
+        return new NestedIdentity(field, offset, child);
+    }
+
+}
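createNestedIdentity recurses into the child bytes, so arbitrarily deep nested-hit paths are rebuilt level by level. A sketch, assuming the usual generated builder methods on the NestedIdentity message (field names are example data):

import org.opensearch.search.SearchHit.NestedIdentity;
import org.opensearch.search.serializer.NestedIdentityProtobufSerializer;
import org.opensearch.server.proto.FetchSearchResultProto;

import java.io.ByteArrayInputStream;
import java.io.IOException;

public final class NestedIdentitySketch {
    public static void main(String[] args) throws IOException {
        FetchSearchResultProto.SearchHit.NestedIdentity child = FetchSearchResultProto.SearchHit.NestedIdentity.newBuilder()
            .setField("comments.votes")
            .setOffset(0)
            .build();
        FetchSearchResultProto.SearchHit.NestedIdentity parent = FetchSearchResultProto.SearchHit.NestedIdentity.newBuilder()
            .setField("comments")
            .setOffset(2)
            .setChild(child)
            .build();
        NestedIdentity identity = new NestedIdentityProtobufSerializer()
            .createNestedIdentity(new ByteArrayInputStream(parent.toByteArray()));
        System.out.println(identity.getField() + " / " + identity.getChild().getField());
    }
}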
diff --git a/server/src/main/java/org/opensearch/search/serializer/NestedIdentitySerializer.java b/server/src/main/java/org/opensearch/search/serializer/NestedIdentitySerializer.java
new file mode 100644
index 0000000000000..7631fc61b28c7
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/NestedIdentitySerializer.java
@@ -0,0 +1,23 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import org.opensearch.search.SearchHit.NestedIdentity;
+
+import java.io.IOException;
+
+public interface NestedIdentitySerializer<T> {
+
+    NestedIdentity createNestedIdentity(T inputStream) throws IOException;
+}
diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitProtobufSerializer.java
new file mode 100644
index 0000000000000..78e2f76e6a8d1
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitProtobufSerializer.java
@@ -0,0 +1,179 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import com.google.protobuf.ByteString;
+import org.apache.lucene.search.Explanation;
+import org.opensearch.OpenSearchParseException;
+import org.opensearch.action.OriginalIndices;
+import org.opensearch.common.document.DocumentField;
+import org.opensearch.common.document.serializer.DocumentFieldProtobufSerializer;
+import org.opensearch.common.lucene.Lucene;
+import org.opensearch.core.common.bytes.BytesReference;
+import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHit.NestedIdentity;
+import org.opensearch.search.SearchHits;
+import org.opensearch.search.SearchShardTarget;
+import org.opensearch.search.SearchSortValues;
+import org.opensearch.search.fetch.subphase.highlight.HighlightField;
+import org.opensearch.search.fetch.subphase.highlight.serializer.HighlightFieldProtobufSerializer;
+import org.opensearch.server.proto.FetchSearchResultProto;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public class SearchHitProtobufSerializer implements SearchHitSerializer<InputStream> {
+
+    private FetchSearchResultProto.SearchHit searchHitProto;
+
+    @Override
+    public SearchHit createSearchHit(InputStream inputStream) throws IOException {
+        this.searchHitProto = FetchSearchResultProto.SearchHit.parseFrom(inputStream);
+        int docId = -1;
+        float score = this.searchHitProto.getScore();
+        String id = this.searchHitProto.getId();
+        NestedIdentity nestedIdentity;
+        if (this.searchHitProto.hasNestedIdentity() && this.searchHitProto.getNestedIdentity().toByteArray().length > 0) {
+            NestedIdentityProtobufSerializer protobufSerializer = new NestedIdentityProtobufSerializer();
+            nestedIdentity = protobufSerializer.createNestedIdentity(
+                new ByteArrayInputStream(this.searchHitProto.getNestedIdentity().toByteArray())
+            );
+        } else {
+            nestedIdentity = null;
+        }
+        long version = this.searchHitProto.getVersion();
+        long seqNo = this.searchHitProto.getSeqNo();
+        long primaryTerm = this.searchHitProto.getPrimaryTerm();
+        BytesReference source = BytesReference.fromByteBuffer(ByteBuffer.wrap(this.searchHitProto.getSource().toByteArray()));
+        if (source.length() == 0) {
+            source = null;
+        }
+        Map<String, DocumentField> documentFields = new HashMap<>();
+        DocumentFieldProtobufSerializer protobufSerializer = new DocumentFieldProtobufSerializer();
+        this.searchHitProto.getDocumentFieldsMap().forEach((k, v) -> {
+            try {
+                documentFields.put(k, protobufSerializer.createDocumentField(new ByteArrayInputStream(v.toByteArray())));
+            } catch (IOException e) {
+                throw new OpenSearchParseException("failed to parse document field", e);
+            }
+        });
+        Map<String, DocumentField> metaFields = new HashMap<>();
+        this.searchHitProto.getMetaFieldsMap().forEach((k, v) -> {
+            try {
+                metaFields.put(k, protobufSerializer.createDocumentField(new ByteArrayInputStream(v.toByteArray())));
+            } catch (IOException e) {
+                throw new OpenSearchParseException("failed to parse document field", e);
+            }
+        });
+        Map<String, HighlightField> highlightFields = new HashMap<>();
+        HighlightFieldProtobufSerializer highlightFieldProtobufSerializer = new HighlightFieldProtobufSerializer();
+        this.searchHitProto.getHighlightFieldsMap().forEach((k, v) -> {
+            try {
+                highlightFields.put(k, highlightFieldProtobufSerializer.createHighLightField(new ByteArrayInputStream(v.toByteArray())));
+            } catch (IOException e) {
+                throw new OpenSearchParseException("failed to parse highlight field", e);
+            }
+        });
+        SearchSortValuesProtobufSerializer sortValueProtobufSerializer = new SearchSortValuesProtobufSerializer();
+        SearchSortValues sortValues = sortValueProtobufSerializer.createSearchSortValues(
+            new ByteArrayInputStream(this.searchHitProto.getSortValues().toByteArray())
+        );
+        Map<String, Float> matchedQueries = new HashMap<>();
+        if (this.searchHitProto.getMatchedQueriesCount() > 0) {
+            matchedQueries = new LinkedHashMap<>(this.searchHitProto.getMatchedQueriesCount());
+            for (String query : this.searchHitProto.getMatchedQueriesList()) {
+                matchedQueries.put(query, Float.NaN);
+            }
+        }
+        if (this.searchHitProto.getMatchedQueriesWithScoresCount() > 0) {
+            Map<String, Float> tempMap = this.searchHitProto.getMatchedQueriesWithScoresMap()
+                .entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue()));
+            matchedQueries = tempMap.entrySet()
+                .stream()
+                .sorted(Map.Entry.comparingByKey())
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new));
+        }
+        Explanation explanation = null;
+        if (this.searchHitProto.hasExplanation()) {
+            explanation = Lucene.readExplanation(this.searchHitProto.getExplanation().toByteArray());
+        }
+        SearchShardTarget searchShardTarget = new SearchShardTarget(
+            this.searchHitProto.getShard().getNodeId(),
+            new ShardId(
+                this.searchHitProto.getShard().getShardId().getIndexName(),
+                this.searchHitProto.getShard().getShardId().getIndexUUID(),
+                this.searchHitProto.getShard().getShardId().getShardId()
+            ),
+            this.searchHitProto.getShard().getClusterAlias(),
+            OriginalIndices.NONE
+        );
+        Map<String, SearchHits> innerHits;
+        if (this.searchHitProto.getInnerHitsCount() > 0) {
+            innerHits = new HashMap<>();
+            this.searchHitProto.getInnerHitsMap().forEach((k, v) -> {
+                try {
+                    SearchHitsProtobufSerializer protobufHitsFactory = new SearchHitsProtobufSerializer();
+                    innerHits.put(k, protobufHitsFactory.createSearchHits(new ByteArrayInputStream(v.toByteArray())));
+                } catch (IOException e) {
+                    throw new OpenSearchParseException("failed to parse inner hits", e);
+                }
+            });
+        } else {
+            innerHits = null;
+        }
+        SearchHit searchHit = new SearchHit(docId, id, nestedIdentity, documentFields, metaFields);
+        searchHit.score(score);
+        searchHit.version(version);
+        searchHit.setSeqNo(seqNo);
+        searchHit.setPrimaryTerm(primaryTerm);
+        searchHit.sourceRef(source);
+        searchHit.highlightFields(highlightFields);
+        searchHit.sortValues(sortValues);
+        searchHit.matchedQueriesWithScores(matchedQueries);
+        searchHit.explanation(explanation);
+        searchHit.shard(searchShardTarget);
+        searchHit.setInnerHits(innerHits);
+        return searchHit;
+    }
+
+    public static FetchSearchResultProto.SearchHit convertHitToProto(SearchHit hit) {
+        FetchSearchResultProto.SearchHit.Builder searchHitBuilder = FetchSearchResultProto.SearchHit.newBuilder();
+        if (hit.getIndex() != null) {
+            searchHitBuilder.setIndex(hit.getIndex());
+        }
+        searchHitBuilder.setId(hit.getId());
+        searchHitBuilder.setScore(hit.getScore());
+        searchHitBuilder.setSeqNo(hit.getSeqNo());
+        searchHitBuilder.setPrimaryTerm(hit.getPrimaryTerm());
+        searchHitBuilder.setVersion(hit.getVersion());
+        searchHitBuilder.setDocId(hit.docId());
+        if (hit.getSourceRef() != null) {
+            searchHitBuilder.setSource(ByteString.copyFrom(hit.getSourceRef().toBytesRef().bytes));
+        }
+        for (Map.Entry<String, DocumentField> entry : hit.getFields().entrySet()) {
+            searchHitBuilder.putDocumentFields(entry.getKey(), DocumentField.convertDocumentFieldToProto(entry.getValue()));
+        }
+        return searchHitBuilder.build();
+    }
+
+}
diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitSerializer.java
new file mode 100644
index 0000000000000..0f66eeb9d03da
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitSerializer.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import org.opensearch.search.SearchHit;
+
+import java.io.IOException;
+
+public interface SearchHitSerializer<T> {
+
+    SearchHit createSearchHit(T inputStream) throws IOException;
+
+}
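The type parameter keeps the interface transport-agnostic; every protobuf implementation in this change fixes T = InputStream. A hypothetical adapter (illustrative only, not part of the change) shows how a different source type slots in:

import org.opensearch.search.SearchHit;
import org.opensearch.search.serializer.SearchHitProtobufSerializer;
import org.opensearch.search.serializer.SearchHitSerializer;

import java.io.ByteArrayInputStream;
import java.io.IOException;

// Hypothetical: accepts raw byte[] and delegates to the protobuf serializer.
public class ByteArraySearchHitSerializer implements SearchHitSerializer<byte[]> {

    private final SearchHitProtobufSerializer delegate = new SearchHitProtobufSerializer();

    @Override
    public SearchHit createSearchHit(byte[] bytes) throws IOException {
        return delegate.createSearchHit(new ByteArrayInputStream(bytes));
    }
}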
diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitsProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitsProtobufSerializer.java
new file mode 100644
index 0000000000000..b61fc194200b5
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitsProtobufSerializer.java
@@ -0,0 +1,157 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import com.google.protobuf.ByteString;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.TotalHits.Relation;
+import org.apache.lucene.util.BytesRef;
+import org.opensearch.OpenSearchException;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+import org.opensearch.server.proto.FetchSearchResultProto;
+import org.opensearch.server.proto.QuerySearchResultProto;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.List;
+
+public class SearchHitsProtobufSerializer implements SearchHitsSerializer<InputStream> {
+
+    private FetchSearchResultProto.SearchHits searchHitsProto;
+
+    @Override
+    public SearchHits createSearchHits(InputStream inputStream) throws IOException {
+        this.searchHitsProto = FetchSearchResultProto.SearchHits.parseFrom(inputStream);
+        SearchHit[] hits = new SearchHit[this.searchHitsProto.getHitsCount()];
+        SearchHitProtobufSerializer protobufSerializer = new SearchHitProtobufSerializer();
+        for (int i = 0; i < this.searchHitsProto.getHitsCount(); i++) {
+            hits[i] = protobufSerializer.createSearchHit(new ByteArrayInputStream(this.searchHitsProto.getHits(i).toByteArray()));
+        }
+        TotalHits totalHits = new TotalHits(
+            this.searchHitsProto.getTotalHits().getValue(),
+            Relation.valueOf(this.searchHitsProto.getTotalHits().getRelation().toString())
+        );
+        float maxScore = this.searchHitsProto.getMaxScore();
+        SortField[] sortFields = this.searchHitsProto.getSortFieldsList()
+            .stream()
+            .map(sortField -> new SortField(sortField.getField(), SortField.Type.valueOf(sortField.getType().toString())))
+            .toArray(SortField[]::new);
+        String collapseField = this.searchHitsProto.getCollapseField();
+        Object[] collapseValues = new Object[this.searchHitsProto.getCollapseValuesCount()];
+        for (int i = 0; i < this.searchHitsProto.getCollapseValuesCount(); i++) {
+            collapseValues[i] = readSortValueFromProtobuf(this.searchHitsProto.getCollapseValues(i));
+        }
+        return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues);
+    }
+
+    public static Object readSortValueFromProtobuf(FetchSearchResultProto.SortValue collapseValue) throws IOException {
+        if (collapseValue.hasCollapseString()) {
+            return collapseValue.getCollapseString();
+        } else if (collapseValue.hasCollapseInt()) {
+            return collapseValue.getCollapseInt();
+        } else if (collapseValue.hasCollapseLong()) {
+            return collapseValue.getCollapseLong();
+        } else if (collapseValue.hasCollapseFloat()) {
+            return collapseValue.getCollapseFloat();
+        } else if (collapseValue.hasCollapseDouble()) {
+            return collapseValue.getCollapseDouble();
+        } else if (collapseValue.hasCollapseBytes()) {
+            return new BytesRef(collapseValue.getCollapseBytes().toByteArray());
+        } else if (collapseValue.hasCollapseBool()) {
+            return collapseValue.getCollapseBool();
+        } else {
+            throw new IOException("Can't handle sort field value of type [" + collapseValue + "]");
+        }
+    }
+
+    public static FetchSearchResultProto.SearchHits convertHitsToProto(SearchHits hits) {
+        List<FetchSearchResultProto.SearchHit> searchHitList = new ArrayList<>();
+        for (SearchHit hit : hits) {
+            searchHitList.add(SearchHitProtobufSerializer.convertHitToProto(hit));
+        }
+        QuerySearchResultProto.TotalHits.Builder totalHitsBuilder = QuerySearchResultProto.TotalHits.newBuilder();
+        if (hits.getTotalHits() != null) {
+            totalHitsBuilder.setValue(hits.getTotalHits().value);
+            totalHitsBuilder.setRelation(
+                hits.getTotalHits().relation == Relation.EQUAL_TO
+                    ? QuerySearchResultProto.TotalHits.Relation.EQUAL_TO
+                    : QuerySearchResultProto.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            );
+        }
+        FetchSearchResultProto.SearchHits.Builder searchHitsBuilder = FetchSearchResultProto.SearchHits.newBuilder();
+        searchHitsBuilder.setMaxScore(hits.getMaxScore());
+        searchHitsBuilder.addAllHits(searchHitList);
+        searchHitsBuilder.setTotalHits(totalHitsBuilder.build());
+        if (hits.getSortFields() != null && hits.getSortFields().length > 0) {
+            for (SortField sortField : hits.getSortFields()) {
+                FetchSearchResultProto.SortField.Builder sortFieldBuilder = FetchSearchResultProto.SortField.newBuilder();
+                if (sortField.getField() != null) {
+                    sortFieldBuilder.setField(sortField.getField());
+                }
+                sortFieldBuilder.setType(FetchSearchResultProto.SortField.Type.valueOf(sortField.getType().name()));
+                searchHitsBuilder.addSortFields(sortFieldBuilder.build());
+            }
+        }
+        if (hits.getCollapseField() != null) {
+            searchHitsBuilder.setCollapseField(hits.getCollapseField());
+            for (Object value : hits.getCollapseValues()) {
+                FetchSearchResultProto.SortValue.Builder collapseValueBuilder = FetchSearchResultProto.SortValue.newBuilder();
+                try {
+                    collapseValueBuilder = readSortValueForProtobuf(value, collapseValueBuilder);
+                } catch (IOException e) {
+                    throw new OpenSearchException(e);
+                }
+                searchHitsBuilder.addCollapseValues(collapseValueBuilder.build());
+            }
+        }
+        return searchHitsBuilder.build();
+    }
+
+    public static FetchSearchResultProto.SortValue.Builder readSortValueForProtobuf(
+        Object collapseValue,
+        FetchSearchResultProto.SortValue.Builder collapseValueBuilder
+    ) throws IOException {
+        Class<?> type = collapseValue.getClass();
+        if (type == String.class) {
+            collapseValueBuilder.setCollapseString((String) collapseValue);
+        } else if (type == Integer.class || type == Short.class) {
+            // Go through Number to avoid a ClassCastException when the value is a Short
+            collapseValueBuilder.setCollapseInt(((Number) collapseValue).intValue());
+        } else if (type == Long.class) {
+            collapseValueBuilder.setCollapseLong((Long) collapseValue);
+        } else if (type == Float.class) {
+            collapseValueBuilder.setCollapseFloat((Float) collapseValue);
+        } else if (type == Double.class) {
+            collapseValueBuilder.setCollapseDouble((Double) collapseValue);
+        } else if (type == Byte.class) {
+            byte b = (Byte) collapseValue;
+            collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(new byte[] { b }));
+        } else if (type == Boolean.class) {
+            collapseValueBuilder.setCollapseBool((Boolean) collapseValue);
+        } else if (type == BytesRef.class) {
+            collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(((BytesRef) collapseValue).bytes));
+        } else if (type == BigInteger.class) {
+            BigInteger bigInt = (BigInteger) collapseValue;
+            collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(bigInt.toByteArray()));
+        } else {
+            throw new IOException("Can't handle sort field value of type [" + type + "]");
+        }
+        return collapseValueBuilder;
+    }
+
+}
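A sketch of the round trip these helpers are meant to support (hypothetical test-style code, not part of the diff). Note one design consequence visible above: Byte and BigInteger collapse values are written to collapseBytes, so readSortValueFromProtobuf surfaces them as BytesRef rather than as their original types, making the round trip lossy for those cases:

    // Round-trip sketch (assumption): SearchHits -> protobuf -> SearchHits.
    import org.opensearch.search.SearchHits;
    import org.opensearch.search.serializer.SearchHitsProtobufSerializer;
    import org.opensearch.server.proto.FetchSearchResultProto;

    import java.io.ByteArrayInputStream;
    import java.io.IOException;

    class SearchHitsRoundTripSketch { // hypothetical
        static SearchHits roundTrip(SearchHits original) throws IOException {
            FetchSearchResultProto.SearchHits proto = SearchHitsProtobufSerializer.convertHitsToProto(original);
            return new SearchHitsProtobufSerializer().createSearchHits(new ByteArrayInputStream(proto.toByteArray()));
        }
    }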
diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitsSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitsSerializer.java
new file mode 100644
index 0000000000000..f61af18682627
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitsSerializer.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import org.opensearch.search.SearchHits;
+
+import java.io.IOException;
+
+public interface SearchHitsSerializer<T> {
+
+    SearchHits createSearchHits(T inputStream) throws IOException;
+
+}
diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesProtobufSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesProtobufSerializer.java
new file mode 100644
index 0000000000000..dec522ba4a7ef
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesProtobufSerializer.java
@@ -0,0 +1,40 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import org.opensearch.search.SearchSortValues;
+import org.opensearch.server.proto.FetchSearchResultProto;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+public class SearchSortValuesProtobufSerializer implements SearchSortValuesSerializer<InputStream> {
+
+    @Override
+    public SearchSortValues createSearchSortValues(InputStream inputStream) throws IOException {
+        FetchSearchResultProto.SearchHit.SearchSortValues searchSortValues = FetchSearchResultProto.SearchHit.SearchSortValues.parseFrom(
+            inputStream
+        );
+        Object[] formattedSortValues = new Object[searchSortValues.getFormattedSortValuesCount()];
+        for (int i = 0; i < searchSortValues.getFormattedSortValuesCount(); i++) {
+            formattedSortValues[i] = SearchHitsProtobufSerializer.readSortValueFromProtobuf(searchSortValues.getFormattedSortValues(i));
+        }
+        Object[] rawSortValues = new Object[searchSortValues.getRawSortValuesCount()];
+        for (int i = 0; i < searchSortValues.getRawSortValuesCount(); i++) {
+            rawSortValues[i] = SearchHitsProtobufSerializer.readSortValueFromProtobuf(searchSortValues.getRawSortValues(i));
+        }
+        return new SearchSortValues(formattedSortValues, rawSortValues);
+    }
+
+}
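A minimal decode sketch (assumption, not part of the diff) showing how this serializer plugs into the SearchHit deserialization above, given an already parsed FetchSearchResultProto.SearchHit:

    // Sketch only: decode the SearchSortValues sub-message of a parsed hit.
    import org.opensearch.search.SearchSortValues;
    import org.opensearch.search.serializer.SearchSortValuesProtobufSerializer;
    import org.opensearch.server.proto.FetchSearchResultProto;

    import java.io.ByteArrayInputStream;
    import java.io.IOException;

    class SortValuesDecodeSketch { // hypothetical
        static SearchSortValues decode(FetchSearchResultProto.SearchHit hitProto) throws IOException {
            return new SearchSortValuesProtobufSerializer().createSearchSortValues(
                new ByteArrayInputStream(hitProto.getSortValues().toByteArray())
            );
        }
    }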
diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesSerializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesSerializer.java
new file mode 100644
index 0000000000000..6a9a23700635d
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesSerializer.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.search.serializer;
+
+import org.opensearch.search.SearchSortValues;
+
+import java.io.IOException;
+
+public interface SearchSortValuesSerializer<T> {
+
+    SearchSortValues createSearchSortValues(T inputStream) throws IOException;
+
+}
diff --git a/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java b/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java
index 112b55fdf72be..db0ddfc3f0cd0 100644
--- a/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java
+++ b/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java
@@ -8,12 +8,15 @@
 
 package org.opensearch.transport;
 
+import org.opensearch.common.annotation.ExperimentalApi;
+
 /**
  * Base class for inbound data as a message.
  * Different implementations are used for different protocols.
  *
  * @opensearch.internal
  */
+@ExperimentalApi
 public interface BaseInboundMessage {
 
     /**
diff --git a/server/src/main/proto/server/search/FetchSearchResultProto.proto b/server/src/main/proto/server/search/FetchSearchResultProto.proto
index b126ca439e9f2..88f306d31a173 100644
--- a/server/src/main/proto/server/search/FetchSearchResultProto.proto
+++ b/server/src/main/proto/server/search/FetchSearchResultProto.proto
@@ -30,7 +30,7 @@ message SearchHits {
   repeated SearchHit hits = 4;
   repeated SortField sortFields = 5;
   optional string collapseField = 6;
-  repeated CollapseValue collapseValues = 7;
+  repeated SortValue collapseValues = 7;
 }
 
 message SearchHit {
@@ -71,8 +71,8 @@ message SearchHit {
   }
 
   message SearchSortValues {
-    repeated bytes formattedSortValues = 1;
-    repeated bytes rawSortValues = 2;
+    repeated SortValue formattedSortValues = 1;
+    repeated SortValue rawSortValues = 2;
   }
 
   message Explanation {
@@ -105,7 +105,7 @@ message SortField {
   }
 }
 
-message CollapseValue {
+message SortValue {
   optional string collapseString = 1;
   optional int32 collapseInt = 2;
   optional int64 collapseLong = 3;
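Since the renamed SortValue message is now shared by collapse values and typed sort values, a short sketch (assumption, not part of the diff) of how a runtime value maps onto it via the helper added above; a String lands in collapseString, a Long in collapseLong, and byte-like values (Byte, BytesRef, BigInteger) in collapseBytes:

    // Mapping sketch only; the wrapper class name is hypothetical.
    import org.opensearch.search.serializer.SearchHitsProtobufSerializer;
    import org.opensearch.server.proto.FetchSearchResultProto;

    import java.io.IOException;

    class SortValueMappingSketch { // hypothetical
        static FetchSearchResultProto.SortValue toSortValue(Object value) throws IOException {
            FetchSearchResultProto.SortValue.Builder builder = FetchSearchResultProto.SortValue.newBuilder();
            return SearchHitsProtobufSerializer.readSortValueForProtobuf(value, builder).build();
        }
    }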