Skip to content

Commit

Permalink
Extracting protobuf serialization from model classes
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Mar 26, 2024
1 parent 7d1d117 commit c96647c
Show file tree
Hide file tree
Showing 25 changed files with 787 additions and 446 deletions.
4 changes: 4 additions & 0 deletions distribution/src/config/opensearch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ ${path.logs}
# Gates the functionality of enabling Opensearch to use pluggable caches with respective store names via setting.
#
#opensearch.experimental.feature.pluggable.caching.enabled: false
#
# Gates the experimental functionality that enables OpenSearch to use protobuf for basic searches and for node-to-node communication.
#
#opensearch.experimental.feature.search_with_protobuf.enabled: false
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ private void process(ExecutableElement executable, ReferenceType ref) {
return;
}

// Skip protobuf generated classes used in public apis
if (ref.toString().contains("Proto")) {
return;
}

if (ref instanceof DeclaredType) {
final DeclaredType declaredType = (DeclaredType) ref;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ public String toString() {

@Override
public Response read(StreamInput in) throws IOException {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'read'");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
Expand Down Expand Up @@ -115,11 +114,7 @@ public QueryPhaseResultConsumer(

SearchSourceBuilder source = request.source();
this.hasTopDocs = source == null || source.size() != 0;
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
this.hasAggs = false;
} else {
this.hasAggs = source != null && source.aggregations() != null;
}
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
}
Expand Down Expand Up @@ -325,7 +320,7 @@ synchronized long addEstimateAndMaybeBreak(long estimatedSize) {
* provided {@link QuerySearchResult}.
*/
long ramBytesUsedQueryResult(QuerySearchResult result) {
if (hasAggs == false || FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
if (hasAggs == false) {
return 0;
}
return result.aggregations().asSerialized(InternalAggregations::readFrom, namedWriteableRegistry).ramBytesUsed();
Expand Down Expand Up @@ -494,7 +489,7 @@ public synchronized List<TopDocs> consumeTopDocs() {
}

public synchronized List<InternalAggregations> consumeAggs() {
if (hasAggs == false || FeatureFlags.isEnabled(FeatureFlags.PROTOBUF)) {
if (hasAggs == false) {
return Collections.emptyList();
}
List<InternalAggregations> aggsList = new ArrayList<>();
Expand Down
110 changes: 2 additions & 108 deletions server/src/main/java/org/opensearch/common/document/DocumentField.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,22 @@

package org.opensearch.common.document;

import com.google.protobuf.ByteString;
import org.opensearch.OpenSearchException;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.common.document.serializer.DocumentFieldProtobufSerializer;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.text.Text;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.get.GetResult;
import org.opensearch.search.SearchHit;
import org.opensearch.server.proto.FetchSearchResultProto;
import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue;
import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder;

import java.io.IOException;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
Expand Down Expand Up @@ -90,110 +78,16 @@ public DocumentField(String name, List<Object> values) {
this.values = Objects.requireNonNull(values, "values must not be null");
}

public DocumentField(byte[] in) throws IOException {
assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(in);
name = documentField.getName();
values = new ArrayList<>();
for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) {
values.add(readDocumentFieldValueFromProtobuf(value));
}
}

public static FetchSearchResultProto.SearchHit.DocumentField convertDocumentFieldToProto(DocumentField documentField) {
FetchSearchResultProto.SearchHit.DocumentField.Builder builder = FetchSearchResultProto.SearchHit.DocumentField.newBuilder();
builder.setName(documentField.getName());
for (Object value : documentField.getValues()) {
FetchSearchResultProto.DocumentFieldValue.Builder valueBuilder = FetchSearchResultProto.DocumentFieldValue.newBuilder();
builder.addValues(convertDocumentFieldValueToProto(value, valueBuilder));
builder.addValues(DocumentFieldProtobufSerializer.convertDocumentFieldValueToProto(value, valueBuilder));
}
return builder.build();
}

private static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) {
if (value == null) {
// null is not allowed in protobuf, so we use a special string to represent null
return valueBuilder.setValueString("null");
}
Class type = value.getClass();
if (type == String.class) {
valueBuilder.setValueString((String) value);
} else if (type == Integer.class) {
valueBuilder.setValueInt((Integer) value);
} else if (type == Long.class) {
valueBuilder.setValueLong((Long) value);
} else if (type == Float.class) {
valueBuilder.setValueFloat((Float) value);
} else if (type == Double.class) {
valueBuilder.setValueDouble((Double) value);
} else if (type == Boolean.class) {
valueBuilder.setValueBool((Boolean) value);
} else if (type == byte[].class) {
valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value));
} else if (type == List.class) {
List<Object> list = (List<Object>) value;
for (Object listValue : list) {
valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, valueBuilder));
}
} else if (type == Map.class || type == HashMap.class || type == LinkedHashMap.class) {
Map<String, Object> map = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : map.entrySet()) {
valueBuilder.putValueMap(entry.getKey(), convertDocumentFieldValueToProto(entry.getValue(), valueBuilder).build());
}
} else if (type == Date.class) {
valueBuilder.setValueDate(((Date) value).getTime());
} else if (type == ZonedDateTime.class) {
valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId());
valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli());
} else if (type == Text.class) {
valueBuilder.setValueText(((Text) value).string());
} else {
throw new OpenSearchException("Can't convert generic value of type [" + type + "] to protobuf");
}
return valueBuilder;
}

private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException {
if (documentFieldValue.hasValueString()) {
return documentFieldValue.getValueString();
} else if (documentFieldValue.hasValueInt()) {
return documentFieldValue.getValueInt();
} else if (documentFieldValue.hasValueLong()) {
return documentFieldValue.getValueLong();
} else if (documentFieldValue.hasValueFloat()) {
return documentFieldValue.getValueFloat();
} else if (documentFieldValue.hasValueDouble()) {
return documentFieldValue.getValueDouble();
} else if (documentFieldValue.hasValueBool()) {
return documentFieldValue.getValueBool();
} else if (documentFieldValue.getValueByteArrayList().size() > 0) {
return documentFieldValue.getValueByteArrayList().toArray();
} else if (documentFieldValue.getValueArrayListList().size() > 0) {
List<Object> list = new ArrayList<>();
for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) {
list.add(readDocumentFieldValueFromProtobuf(value));
}
return list;
} else if (documentFieldValue.getValueMapMap().size() > 0) {
Map<String, Object> map = Map.of();
for (Map.Entry<String, FetchSearchResultProto.DocumentFieldValue> entrySet : documentFieldValue.getValueMapMap().entrySet()) {
map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue()));
}
return map;
} else if (documentFieldValue.hasValueDate()) {
return new Date(documentFieldValue.getValueDate());
} else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) {
return ZonedDateTime.ofInstant(
Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()),
ZoneId.of(documentFieldValue.getValueZonedDate())
);
} else if (documentFieldValue.hasValueText()) {
return new Text(documentFieldValue.getValueText());
} else {
throw new IOException("Can't read generic value of type [" + documentFieldValue + "]");
}
}

/**
* The name of the field.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.common.document.serializer;

import com.google.protobuf.ByteString;
import org.opensearch.OpenSearchException;
import org.opensearch.common.document.DocumentField;
import org.opensearch.core.common.text.Text;
import org.opensearch.server.proto.FetchSearchResultProto;
import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue;
import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder;

import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Protobuf-backed serializer for {@link DocumentField}. Parses a
 * {@code FetchSearchResultProto.SearchHit.DocumentField} message from an
 * {@link InputStream} and converts it into a {@link DocumentField}, and
 * provides the reverse conversion from a Java field value to its protobuf
 * representation.
 */
public class DocumentFieldProtobufSerializer implements DocumentFieldSerializer<InputStream> {

    // Last message parsed by createDocumentField; kept so callers of this
    // (stateful) serializer can inspect the raw protobuf if needed.
    private FetchSearchResultProto.SearchHit.DocumentField documentField;

    @Override
    public DocumentField createDocumentField(InputStream inputStream) throws IOException {
        documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(inputStream);
        String name = documentField.getName();
        List<Object> values = new ArrayList<>(documentField.getValuesCount());
        for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) {
            values.add(readDocumentFieldValueFromProtobuf(value));
        }
        return new DocumentField(name, values);
    }

    /**
     * Converts a single protobuf {@code DocumentFieldValue} back into the Java
     * object it encodes, recursing into nested lists and maps.
     *
     * @throws IOException if the message carries no recognized value kind
     */
    private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException {
        if (documentFieldValue.hasValueString()) {
            return documentFieldValue.getValueString();
        } else if (documentFieldValue.hasValueInt()) {
            return documentFieldValue.getValueInt();
        } else if (documentFieldValue.hasValueLong()) {
            return documentFieldValue.getValueLong();
        } else if (documentFieldValue.hasValueFloat()) {
            return documentFieldValue.getValueFloat();
        } else if (documentFieldValue.hasValueDouble()) {
            return documentFieldValue.getValueDouble();
        } else if (documentFieldValue.hasValueBool()) {
            return documentFieldValue.getValueBool();
        } else if (documentFieldValue.getValueByteArrayList().size() > 0) {
            // NOTE(review): this returns an array of ByteString, not the byte[]
            // that convertDocumentFieldValueToProto wrote — round-tripping a
            // byte[] value does not restore the original type. Confirm intended.
            return documentFieldValue.getValueByteArrayList().toArray();
        } else if (documentFieldValue.getValueArrayListList().size() > 0) {
            List<Object> list = new ArrayList<>();
            for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) {
                list.add(readDocumentFieldValueFromProtobuf(value));
            }
            return list;
        } else if (documentFieldValue.getValueMapMap().size() > 0) {
            // Must be a mutable map: Map.of() is immutable and put() on it
            // throws UnsupportedOperationException.
            Map<String, Object> map = new HashMap<>();
            for (Map.Entry<String, FetchSearchResultProto.DocumentFieldValue> entrySet : documentFieldValue.getValueMapMap().entrySet()) {
                map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue()));
            }
            return map;
        } else if (documentFieldValue.hasValueDate()) {
            return new Date(documentFieldValue.getValueDate());
        } else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) {
            return ZonedDateTime.ofInstant(
                Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()),
                ZoneId.of(documentFieldValue.getValueZonedDate())
            );
        } else if (documentFieldValue.hasValueText()) {
            return new Text(documentFieldValue.getValueText());
        } else {
            throw new IOException("Can't read generic value of type [" + documentFieldValue + "]");
        }
    }

    /**
     * Populates {@code valueBuilder} with the protobuf encoding of {@code value}.
     *
     * @param value        the Java field value to encode; {@code null} is encoded
     *                     as the string {@code "null"} since protobuf has no null
     * @param valueBuilder the builder to populate
     * @return {@code valueBuilder}, for chaining
     * @throws OpenSearchException if the value's type has no protobuf mapping
     */
    public static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) {
        if (value == null) {
            // null is not allowed in protobuf, so we use a special string to represent null
            return valueBuilder.setValueString("null");
        }
        // instanceof rather than exact getClass() comparison: runtime values are
        // concrete classes (e.g. ArrayList, ImmutableMap), so `type == List.class`
        // could never match and list/map values fell through to the error branch.
        if (value instanceof String) {
            valueBuilder.setValueString((String) value);
        } else if (value instanceof Integer) {
            valueBuilder.setValueInt((Integer) value);
        } else if (value instanceof Long) {
            valueBuilder.setValueLong((Long) value);
        } else if (value instanceof Float) {
            valueBuilder.setValueFloat((Float) value);
        } else if (value instanceof Double) {
            valueBuilder.setValueDouble((Double) value);
        } else if (value instanceof Boolean) {
            valueBuilder.setValueBool((Boolean) value);
        } else if (value instanceof byte[]) {
            valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value));
        } else if (value instanceof List) {
            // Each element needs its own nested builder; reusing the parent
            // builder would fold previously-encoded state into every element.
            for (Object listValue : (List<?>) value) {
                valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, DocumentFieldValue.newBuilder()));
            }
        } else if (value instanceof Map) {
            Map<?, ?> map = (Map<?, ?>) value;
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                // Keys are expected to be Strings; a non-String key fails fast here.
                valueBuilder.putValueMap(
                    (String) entry.getKey(),
                    convertDocumentFieldValueToProto(entry.getValue(), DocumentFieldValue.newBuilder()).build()
                );
            }
        } else if (value instanceof Date) {
            valueBuilder.setValueDate(((Date) value).getTime());
        } else if (value instanceof ZonedDateTime) {
            valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId());
            valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli());
        } else if (value instanceof Text) {
            valueBuilder.setValueText(((Text) value).string());
        } else {
            throw new OpenSearchException("Can't convert generic value of type [" + value.getClass() + "] to protobuf");
        }
        return valueBuilder;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.common.document.serializer;

import org.opensearch.common.document.DocumentField;

import java.io.IOException;

/**
 * Serializer contract for reconstructing a {@link DocumentField} from a
 * serialized source of type {@code T} (for example an {@code InputStream}
 * carrying a protobuf message).
 *
 * @param <T> the type of the serialized source the field is read from
 */
@FunctionalInterface
public interface DocumentFieldSerializer<T> {

    /**
     * Deserializes a {@link DocumentField} from the given source.
     *
     * @param source the serialized representation to read from
     * @return the reconstructed document field
     * @throws IOException if the source cannot be read or parsed
     */
    DocumentField createDocumentField(T source) throws IOException;

}
2 changes: 0 additions & 2 deletions server/src/main/java/org/opensearch/common/lucene/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.common.util.iterable.Iterables;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -654,7 +653,6 @@ public static Explanation readExplanation(StreamInput in) throws IOException {
}

public static Explanation readExplanation(byte[] in) throws IOException {
assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
FetchSearchResultProto.SearchHit.Explanation explanationProto = FetchSearchResultProto.SearchHit.Explanation.parseFrom(in);
boolean match = explanationProto.getMatch();
String description = explanationProto.getDescription();
Expand Down
Loading

0 comments on commit c96647c

Please sign in to comment.