From 3d6726256b72a81ab00ab355f38ecae9512abd62 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 29 Aug 2023 14:39:37 -0700 Subject: [PATCH] Add SearchExtBuilders to SearchResponse (#9379) * Add SearchExtBuilders to SearchResponse. [Issue #9328](https://github.com/opensearch-project/OpenSearch/issues/9328) Signed-off-by: Austin Lee * Keep SearchResponse immutable, add a constructor to take a List of SearchExtBuilders. Signed-off-by: Austin Lee * Fix spotlessJavaCheck findings. Signed-off-by: Austin Lee * Move SearchExtBuilders into SearchResponseSections, fix indenting in SearchRequest. Signed-off-by: Austin Lee * Updated changelog (mixed minor formatting issues), added version checks on serialization/deserialization, added a Builder for making copies of SearchResponse easier. Signed-off-by: Austin Lee * Add GenericSearchExtBuilder as a catch-all for SearchExtBuilders not registered in xcontent registry. Signed-off-by: Austin Lee * Simplify GenericSearchExtBuilder using a single Object member. Signed-off-by: Austin Lee * Address additional review comments. Signed-off-by: Austin Lee * Add Javadocs. Signed-off-by: Austin Lee --------- Signed-off-by: Austin Lee Signed-off-by: Ivan Brusic --- CHANGELOG.md | 11 +- .../action/search/SearchResponse.java | 36 +- .../action/search/SearchResponseSections.java | 33 ++ .../search/GenericSearchExtBuilder.java | 165 +++++++ .../internal/InternalSearchResponse.java | 33 +- .../action/search/SearchResponseTests.java | 209 ++++++++- .../search/GenericSearchExtBuilderTests.java | 422 ++++++++++++++++++ 7 files changed, 893 insertions(+), 16 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/GenericSearchExtBuilder.java create mode 100644 server/src/test/java/org/opensearch/search/GenericSearchExtBuilderTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c26302cb2272..eabc17f81917f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,17 +78,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added - Add server version as REST response header [#6583](https://github.com/opensearch-project/OpenSearch/issues/6583) -- Start replication checkpointTimers on primary before segments upload to remote store. ([#8221]()https://github.com/opensearch-project/OpenSearch/pull/8221) -- [distribution/archives] [Linux] [x64] Provide the variant of the distributions bundled with JRE ([#8195]()https://github.com/opensearch-project/OpenSearch/pull/8195) +- Start replication checkpointTimers on primary before segments upload to remote store. ([#8221](https://github.com/opensearch-project/OpenSearch/pull/8221)) +- [distribution/archives] [Linux] [x64] Provide the variant of the distributions bundled with JRE ([#8195](https://github.com/opensearch-project/OpenSearch/pull/8195)) - Add configuration for file cache size to max remote data ratio to prevent oversubscription of file cache ([#8606](https://github.com/opensearch-project/OpenSearch/pull/8606)) -- Disallow compression level to be set for default and best_compression index codecs ([#8737]()https://github.com/opensearch-project/OpenSearch/pull/8737) +- Disallow compression level to be set for default and best_compression index codecs ([#8737](https://github.com/opensearch-project/OpenSearch/pull/8737)) - Prioritize replica shard movement during shard relocation ([#8875](https://github.com/opensearch-project/OpenSearch/pull/8875)) -- Introducing Default and Best Compression codecs as their algorithm name ([#9123]()https://github.com/opensearch-project/OpenSearch/pull/9123) -- Make SearchTemplateRequest implement IndicesRequest.Replaceable ([#9122]()https://github.com/opensearch-project/OpenSearch/pull/9122) +- Introducing Default and Best Compression codecs as their algorithm name ([#9123](https://github.com/opensearch-project/OpenSearch/pull/9123)) +- Make SearchTemplateRequest implement IndicesRequest.Replaceable ([#9122](https://github.com/opensearch-project/OpenSearch/pull/9122)) - [BWC and API enforcement] Define the initial set of annotations, their meaning and relations between them ([#9223](https://github.com/opensearch-project/OpenSearch/pull/9223)) - [Segment Replication] Support realtime reads for GET requests ([#9212](https://github.com/opensearch-project/OpenSearch/pull/9212)) - [Feature] Expose term frequency in Painless script score context ([#9081](https://github.com/opensearch-project/OpenSearch/pull/9081)) - Add support for reading partial files to HDFS repository ([#9513](https://github.com/opensearch-project/OpenSearch/issues/9513)) +- Add support for extensions to search responses using SearchExtBuilder ([#9379](https://github.com/opensearch-project/OpenSearch/pull/9379)) ### Dependencies - Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307)) diff --git a/server/src/main/java/org/opensearch/action/search/SearchResponse.java b/server/src/main/java/org/opensearch/action/search/SearchResponse.java index cfdbe5647df5a..a546311a1f668 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchResponse.java +++ b/server/src/main/java/org/opensearch/action/search/SearchResponse.java @@ -46,9 +46,12 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParser.Token; import org.opensearch.rest.action.RestActions; +import org.opensearch.search.GenericSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregations; @@ -65,6 +68,7 @@ import java.util.Objects; import java.util.function.Supplier; +import static org.opensearch.action.search.SearchResponseSections.EXT_FIELD; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; /** @@ -312,6 +316,7 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t ); clusters.toXContent(builder, params); internalResponse.toXContent(builder, params); + return builder; } @@ -339,6 +344,7 @@ public static SearchResponse innerFromXContent(XContentParser parser) throws IOE String searchContextId = null; List failures = new ArrayList<>(); Clusters clusters = Clusters.EMPTY; + List extBuilders = new ArrayList<>(); for (Token token = parser.nextToken(); token != Token.END_OBJECT; token = parser.nextToken()) { if (token == Token.FIELD_NAME) { currentFieldName = parser.currentName(); @@ -417,6 +423,33 @@ public static SearchResponse innerFromXContent(XContentParser parser) throws IOE } } clusters = new Clusters(total, successful, skipped); + } else if (EXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + String extSectionName = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + extSectionName = parser.currentName(); + } else { + SearchExtBuilder searchExtBuilder; + try { + searchExtBuilder = parser.namedObject(SearchExtBuilder.class, extSectionName, null); + if (!searchExtBuilder.getWriteableName().equals(extSectionName)) { + throw new IllegalStateException( + "The parsed [" + + searchExtBuilder.getClass().getName() + + "] object has a " + + "different writeable name compared to the name of the section that it was parsed from: found [" + + searchExtBuilder.getWriteableName() + + "] expected [" + + extSectionName + + "]" + ); + } + } catch (XContentParseException e) { + searchExtBuilder = GenericSearchExtBuilder.fromXContent(parser); + } + extBuilders.add(searchExtBuilder); + } + } } else { parser.skipChildren(); } @@ -429,7 +462,8 @@ public static SearchResponse innerFromXContent(XContentParser parser) throws IOE timedOut, terminatedEarly, profile, - numReducePhases + numReducePhases, + extBuilders ); return new SearchResponse( searchResponseSections, diff --git a/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java b/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java index 214bc0448b90c..2e447abd125c5 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java +++ b/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java @@ -32,9 +32,11 @@ package org.opensearch.action.search; +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.profile.ProfileShardResult; @@ -42,8 +44,11 @@ import org.opensearch.search.suggest.Suggest; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.Objects; /** * Base class that holds the various sections which a search response is @@ -57,6 +62,8 @@ */ public class SearchResponseSections implements ToXContentFragment { + public static final ParseField EXT_FIELD = new ParseField("ext"); + protected final SearchHits hits; protected final Aggregations aggregations; protected final Suggest suggest; @@ -64,6 +71,7 @@ public class SearchResponseSections implements ToXContentFragment { protected final boolean timedOut; protected final Boolean terminatedEarly; protected final int numReducePhases; + protected final List searchExtBuilders = new ArrayList<>(); public SearchResponseSections( SearchHits hits, @@ -73,6 +81,19 @@ public SearchResponseSections( Boolean terminatedEarly, SearchProfileShardResults profileResults, int numReducePhases + ) { + this(hits, aggregations, suggest, timedOut, terminatedEarly, profileResults, numReducePhases, Collections.emptyList()); + } + + public SearchResponseSections( + SearchHits hits, + Aggregations aggregations, + Suggest suggest, + boolean timedOut, + Boolean terminatedEarly, + SearchProfileShardResults profileResults, + int numReducePhases, + List searchExtBuilders ) { this.hits = hits; this.aggregations = aggregations; @@ -81,6 +102,7 @@ public SearchResponseSections( this.timedOut = timedOut; this.terminatedEarly = terminatedEarly; this.numReducePhases = numReducePhases; + this.searchExtBuilders.addAll(Objects.requireNonNull(searchExtBuilders, "searchExtBuilders must not be null")); } public final boolean timedOut() { @@ -135,9 +157,20 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params) if (profileResults != null) { profileResults.toXContent(builder, params); } + if (!searchExtBuilders.isEmpty()) { + builder.startObject(EXT_FIELD.getPreferredName()); + for (SearchExtBuilder searchExtBuilder : searchExtBuilders) { + searchExtBuilder.toXContent(builder, params); + } + builder.endObject(); + } return builder; } + public List getSearchExtBuilders() { + return Collections.unmodifiableList(this.searchExtBuilders); + } + protected void writeTo(StreamOutput out) throws IOException { throw new UnsupportedOperationException(); } diff --git a/server/src/main/java/org/opensearch/search/GenericSearchExtBuilder.java b/server/src/main/java/org/opensearch/search/GenericSearchExtBuilder.java new file mode 100644 index 0000000000000..35e68f78774e3 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/GenericSearchExtBuilder.java @@ -0,0 +1,165 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * This is a catch-all SearchExtBuilder implementation that is used when an appropriate SearchExtBuilder + * is not found during SearchResponse's fromXContent operation. + */ +public final class GenericSearchExtBuilder extends SearchExtBuilder { + + public final static ParseField EXT_BUILDER_NAME = new ParseField("generic_ext"); + + private final Object genericObj; + private final ValueType valueType; + + enum ValueType { + SIMPLE(0), + MAP(1), + LIST(2); + + private final int value; + + ValueType(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + static ValueType fromInt(int value) { + switch (value) { + case 0: + return SIMPLE; + case 1: + return MAP; + case 2: + return LIST; + default: + throw new IllegalArgumentException("Unsupported value: " + value); + } + } + } + + public GenericSearchExtBuilder(Object genericObj, ValueType valueType) { + this.genericObj = genericObj; + this.valueType = valueType; + } + + public GenericSearchExtBuilder(StreamInput in) throws IOException { + valueType = ValueType.fromInt(in.readInt()); + switch (valueType) { + case SIMPLE: + genericObj = in.readGenericValue(); + break; + case MAP: + genericObj = in.readMap(); + break; + case LIST: + genericObj = in.readList(r -> r.readGenericValue()); + break; + default: + throw new IllegalStateException("Unable to construct GenericSearchExtBuilder from incoming stream."); + } + } + + public static GenericSearchExtBuilder fromXContent(XContentParser parser) throws IOException { + // Look at the parser's next token. + // If it's START_OBJECT, parse as map, if it's START_ARRAY, parse as list, else + // parse as simpleVal + XContentParser.Token token = parser.currentToken(); + ValueType valueType; + Object genericObj; + if (token == XContentParser.Token.START_OBJECT) { + genericObj = parser.map(); + valueType = ValueType.MAP; + } else if (token == XContentParser.Token.START_ARRAY) { + genericObj = parser.list(); + valueType = ValueType.LIST; + } else if (token.isValue()) { + genericObj = parser.objectText(); + valueType = ValueType.SIMPLE; + } else { + throw new XContentParseException("Unknown token: " + token); + } + + return new GenericSearchExtBuilder(genericObj, valueType); + } + + @Override + public String getWriteableName() { + return EXT_BUILDER_NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(valueType.getValue()); + switch (valueType) { + case SIMPLE: + out.writeGenericValue(genericObj); + break; + case MAP: + out.writeMap((Map) genericObj); + break; + case LIST: + out.writeCollection((List) genericObj, StreamOutput::writeGenericValue); + break; + default: + throw new IllegalStateException("Unknown valueType: " + valueType); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + switch (valueType) { + case SIMPLE: + return builder.field(EXT_BUILDER_NAME.getPreferredName(), genericObj); + case MAP: + return builder.field(EXT_BUILDER_NAME.getPreferredName(), (Map) genericObj); + case LIST: + return builder.field(EXT_BUILDER_NAME.getPreferredName(), (List) genericObj); + default: + return null; + } + } + + // We need this for the equals method. + Object getValue() { + return genericObj; + } + + @Override + public int hashCode() { + return Objects.hash(this.valueType, this.genericObj); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof GenericSearchExtBuilder)) { + return false; + } + return Objects.equals(getValue(), ((GenericSearchExtBuilder) obj).getValue()) + && Objects.equals(valueType, ((GenericSearchExtBuilder) obj).valueType); + } +} diff --git a/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java b/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java index 1561d18f3040a..8e3979045f857 100644 --- a/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java +++ b/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java @@ -32,17 +32,21 @@ package org.opensearch.search.internal; +import org.opensearch.Version; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; import java.io.IOException; +import java.util.Collections; +import java.util.List; /** * {@link SearchResponseSections} subclass that can be serialized over the wire. @@ -67,7 +71,20 @@ public InternalSearchResponse( Boolean terminatedEarly, int numReducePhases ) { - super(hits, aggregations, suggest, timedOut, terminatedEarly, profileResults, numReducePhases); + this(hits, aggregations, suggest, profileResults, timedOut, terminatedEarly, numReducePhases, Collections.emptyList()); + } + + public InternalSearchResponse( + SearchHits hits, + InternalAggregations aggregations, + Suggest suggest, + SearchProfileShardResults profileResults, + boolean timedOut, + Boolean terminatedEarly, + int numReducePhases, + List searchExtBuilderList + ) { + super(hits, aggregations, suggest, timedOut, terminatedEarly, profileResults, numReducePhases, searchExtBuilderList); } public InternalSearchResponse(StreamInput in) throws IOException { @@ -78,7 +95,8 @@ public InternalSearchResponse(StreamInput in) throws IOException { in.readBoolean(), in.readOptionalBoolean(), in.readOptionalWriteable(SearchProfileShardResults::new), - in.readVInt() + in.readVInt(), + readSearchExtBuildersOnOrAfter(in) ); } @@ -91,5 +109,16 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(terminatedEarly); out.writeOptionalWriteable(profileResults); out.writeVInt(numReducePhases); + writeSearchExtBuildersOnOrAfter(out, searchExtBuilders); + } + + private static List readSearchExtBuildersOnOrAfter(StreamInput in) throws IOException { + return (in.getVersion().onOrAfter(Version.V_3_0_0)) ? in.readNamedWriteableList(SearchExtBuilder.class) : Collections.emptyList(); + } + + private static void writeSearchExtBuildersOnOrAfter(StreamOutput out, List searchExtBuilders) throws IOException { + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + out.writeNamedWriteableList(searchExtBuilders); + } } } diff --git a/server/src/test/java/org/opensearch/action/search/SearchResponseTests.java b/server/src/test/java/org/opensearch/action/search/SearchResponseTests.java index c35bdf9c14587..097e922147698 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchResponseTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchResponseTests.java @@ -37,9 +37,13 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -47,7 +51,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentHelper; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.plugins.SearchPlugin; import org.opensearch.rest.action.search.RestSearchAction; +import org.opensearch.search.GenericSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchHitsTests; @@ -68,8 +75,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.UUID; -import static java.util.Collections.emptyList; import static java.util.Collections.singletonMap; import static org.opensearch.test.XContentTestUtils.insertRandomFields; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertToXContentEquivalent; @@ -80,11 +87,25 @@ public class SearchResponseTests extends OpenSearchTestCase { static { List namedXContents = new ArrayList<>(InternalAggregationTestCase.getDefaultNamedXContents()); namedXContents.addAll(SuggestTests.getDefaultNamedXContents()); + namedXContents.add( + new NamedXContentRegistry.Entry(SearchExtBuilder.class, DummySearchExtBuilder.DUMMY_FIELD, DummySearchExtBuilder::parse) + ); xContentRegistry = new NamedXContentRegistry(namedXContents); } private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry( - new SearchModule(Settings.EMPTY, emptyList()).getNamedWriteables() + new SearchModule(Settings.EMPTY, List.of(new SearchPlugin() { + @Override + public List> getSearchExts() { + return List.of( + new SearchExtSpec<>( + DummySearchExtBuilder.DUMMY_FIELD, + DummySearchExtBuilder::new, + parser -> DummySearchExtBuilder.parse(parser) + ) + ); + } + })).getNamedWriteables() ); private AggregationsTests aggregationsTests = new AggregationsTests(); @@ -119,6 +140,14 @@ private SearchResponse createMinimalTestItem() { * if minimal is set, don't include search hits, aggregations, suggest etc... to make test simpler */ private SearchResponse createTestItem(boolean minimal, ShardSearchFailure... shardSearchFailures) { + return createTestItem(minimal, Collections.emptyList(), shardSearchFailures); + } + + public SearchResponse createTestItem( + boolean minimal, + List searchExtBuilders, + ShardSearchFailure... shardSearchFailures + ) { boolean timedOut = randomBoolean(); Boolean terminatedEarly = randomBoolean() ? null : randomBoolean(); int numReducePhases = randomIntBetween(1, 10); @@ -139,7 +168,8 @@ private SearchResponse createTestItem(boolean minimal, ShardSearchFailure... sha profileShardResults, timedOut, terminatedEarly, - numReducePhases + numReducePhases, + searchExtBuilders ); } else { internalSearchResponse = InternalSearchResponse.empty(); @@ -153,7 +183,8 @@ private SearchResponse createTestItem(boolean minimal, ShardSearchFailure... sha skippedShards, tookInMillis, shardSearchFailures, - randomBoolean() ? randomClusters() : SearchResponse.Clusters.EMPTY + randomBoolean() ? randomClusters() : SearchResponse.Clusters.EMPTY, + null ); } @@ -172,6 +203,32 @@ public void testFromXContent() throws IOException { doFromXContentTestWithRandomFields(createTestItem(), false); } + public void testFromXContentWithSearchExtBuilders() throws IOException { + doFromXContentTestWithRandomFields(createTestItem(false, List.of(new DummySearchExtBuilder(UUID.randomUUID().toString()))), false); + } + + public void testFromXContentWithUnregisteredSearchExtBuilders() throws IOException { + List namedXContents = new ArrayList<>(InternalAggregationTestCase.getDefaultNamedXContents()); + namedXContents.addAll(SuggestTests.getDefaultNamedXContents()); + String dummyId = UUID.randomUUID().toString(); + String fakeId = UUID.randomUUID().toString(); + List extBuilders = List.of(new DummySearchExtBuilder(dummyId), new FakeSearchExtBuilder(fakeId)); + SearchResponse response = createTestItem(false, extBuilders); + MediaType xcontentType = randomFrom(XContentType.values()); + boolean humanReadable = randomBoolean(); + final ToXContent.Params params = new ToXContent.MapParams(singletonMap(RestSearchAction.TYPED_KEYS_PARAM, "true")); + BytesReference originalBytes = toShuffledXContent(response, xcontentType, params, humanReadable); + XContentParser parser = createParser(new NamedXContentRegistry(namedXContents), xcontentType.xContent(), originalBytes); + SearchResponse parsed = SearchResponse.fromXContent(parser); + assertEquals(extBuilders.size(), response.getInternalResponse().getSearchExtBuilders().size()); + + List actual = parsed.getInternalResponse().getSearchExtBuilders(); + assertEquals(extBuilders.size(), actual.size()); + for (int i = 0; i < actual.size(); i++) { + assertTrue(actual.get(0) instanceof GenericSearchExtBuilder); + } + } + /** * This test adds random fields and objects to the xContent rendered out to * ensure we can parse it back to be forward compatible with additions to @@ -182,7 +239,7 @@ public void testFromXContentWithRandomFields() throws IOException { doFromXContentTestWithRandomFields(createMinimalTestItem(), true); } - private void doFromXContentTestWithRandomFields(SearchResponse response, boolean addRandomFields) throws IOException { + public void doFromXContentTestWithRandomFields(SearchResponse response, boolean addRandomFields) throws IOException { MediaType xcontentType = randomFrom(XContentType.values()); boolean humanReadable = randomBoolean(); final ToXContent.Params params = new ToXContent.MapParams(singletonMap(RestSearchAction.TYPED_KEYS_PARAM, "true")); @@ -245,6 +302,7 @@ public void testToXContent() { SearchHit hit = new SearchHit(1, "id1", Collections.emptyMap(), Collections.emptyMap()); hit.score(2.0f); SearchHit[] hits = new SearchHit[] { hit }; + String dummyId = UUID.randomUUID().toString(); { SearchResponse response = new SearchResponse( new InternalSearchResponse( @@ -254,7 +312,8 @@ public void testToXContent() { null, false, null, - 1 + 1, + List.of(new DummySearchExtBuilder(dummyId)) ), null, 0, @@ -262,7 +321,8 @@ public void testToXContent() { 0, 0, ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + null ); StringBuilder expectedString = new StringBuilder(); expectedString.append("{"); @@ -280,11 +340,17 @@ public void testToXContent() { { expectedString.append("{\"total\":{\"value\":100,\"relation\":\"eq\"},"); expectedString.append("\"max_score\":1.5,"); - expectedString.append("\"hits\":[{\"_id\":\"id1\",\"_score\":2.0}]}"); + expectedString.append("\"hits\":[{\"_id\":\"id1\",\"_score\":2.0}]},"); + } + expectedString.append("\"ext\":"); + { + expectedString.append("{\"dummy\":\"" + dummyId + "\"}"); } } expectedString.append("}"); assertEquals(expectedString.toString(), Strings.toString(MediaTypeRegistry.JSON, response)); + List searchExtBuilders = response.getInternalResponse().getSearchExtBuilders(); + assertEquals(1, searchExtBuilders.size()); } { SearchResponse response = new SearchResponse( @@ -352,6 +418,48 @@ public void testSerialization() throws IOException { assertEquals(searchResponse.getClusters(), deserialized.getClusters()); } + public void testSerializationWithSearchExtBuilders() throws IOException { + String id = UUID.randomUUID().toString(); + SearchResponse searchResponse = createTestItem(false, List.of(new DummySearchExtBuilder(id))); + SearchResponse deserialized = copyWriteable(searchResponse, namedWriteableRegistry, SearchResponse::new, Version.CURRENT); + if (searchResponse.getHits().getTotalHits() == null) { + assertNull(deserialized.getHits().getTotalHits()); + } else { + assertEquals(searchResponse.getHits().getTotalHits().value, deserialized.getHits().getTotalHits().value); + assertEquals(searchResponse.getHits().getTotalHits().relation, deserialized.getHits().getTotalHits().relation); + } + assertEquals(searchResponse.getHits().getHits().length, deserialized.getHits().getHits().length); + assertEquals(searchResponse.getNumReducePhases(), deserialized.getNumReducePhases()); + assertEquals(searchResponse.getFailedShards(), deserialized.getFailedShards()); + assertEquals(searchResponse.getTotalShards(), deserialized.getTotalShards()); + assertEquals(searchResponse.getSkippedShards(), deserialized.getSkippedShards()); + assertEquals(searchResponse.getClusters(), deserialized.getClusters()); + assertEquals( + searchResponse.getInternalResponse().getSearchExtBuilders().get(0), + deserialized.getInternalResponse().getSearchExtBuilders().get(0) + ); + } + + public void testSerializationWithSearchExtBuildersOnUnsupportedWriterVersion() throws IOException { + String id = UUID.randomUUID().toString(); + SearchResponse searchResponse = createTestItem(false, List.of(new DummySearchExtBuilder(id))); + SearchResponse deserialized = copyWriteable(searchResponse, namedWriteableRegistry, SearchResponse::new, Version.V_2_9_0); + if (searchResponse.getHits().getTotalHits() == null) { + assertNull(deserialized.getHits().getTotalHits()); + } else { + assertEquals(searchResponse.getHits().getTotalHits().value, deserialized.getHits().getTotalHits().value); + assertEquals(searchResponse.getHits().getTotalHits().relation, deserialized.getHits().getTotalHits().relation); + } + assertEquals(searchResponse.getHits().getHits().length, deserialized.getHits().getHits().length); + assertEquals(searchResponse.getNumReducePhases(), deserialized.getNumReducePhases()); + assertEquals(searchResponse.getFailedShards(), deserialized.getFailedShards()); + assertEquals(searchResponse.getTotalShards(), deserialized.getTotalShards()); + assertEquals(searchResponse.getSkippedShards(), deserialized.getSkippedShards()); + assertEquals(searchResponse.getClusters(), deserialized.getClusters()); + assertEquals(1, searchResponse.getInternalResponse().getSearchExtBuilders().size()); + assertTrue(deserialized.getInternalResponse().getSearchExtBuilders().isEmpty()); + } + public void testToXContentEmptyClusters() throws IOException { SearchResponse searchResponse = new SearchResponse( InternalSearchResponse.empty(), @@ -368,4 +476,89 @@ public void testToXContentEmptyClusters() throws IOException { deserialized.getClusters().toXContent(builder, ToXContent.EMPTY_PARAMS); assertEquals(0, builder.toString().length()); } + + static class DummySearchExtBuilder extends SearchExtBuilder { + + static ParseField DUMMY_FIELD = new ParseField("dummy"); + + protected final String id; + + public DummySearchExtBuilder(String id) { + assertNotNull(id); + this.id = id; + } + + public DummySearchExtBuilder(StreamInput in) throws IOException { + this.id = in.readString(); + } + + public String getId() { + return this.id; + } + + @Override + public String getWriteableName() { + return DUMMY_FIELD.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field("dummy", id); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + + if (!(obj instanceof DummySearchExtBuilder)) { + return false; + } + + return this.id.equals(((DummySearchExtBuilder) obj).getId()); + } + + public static DummySearchExtBuilder parse(XContentParser parser) throws IOException { + String id; + XContentParser.Token token = parser.currentToken(); + if (token == XContentParser.Token.VALUE_STRING) { + id = parser.text(); + } else { + throw new ParsingException(parser.getTokenLocation(), "Expected a VALUE_STRING but got " + token); + } + if (id == null) { + throw new ParsingException(parser.getTokenLocation(), "no id specified for " + DUMMY_FIELD.getPreferredName()); + } + return new DummySearchExtBuilder(id); + } + } + + static class FakeSearchExtBuilder extends DummySearchExtBuilder { + static ParseField DUMMY_FIELD = new ParseField("fake"); + + public FakeSearchExtBuilder(String id) { + super(id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(DUMMY_FIELD.getPreferredName()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field(DUMMY_FIELD.getPreferredName(), id); + } + } } diff --git a/server/src/test/java/org/opensearch/search/GenericSearchExtBuilderTests.java b/server/src/test/java/org/opensearch/search/GenericSearchExtBuilderTests.java new file mode 100644 index 0000000000000..8fb1814962155 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/GenericSearchExtBuilderTests.java @@ -0,0 +1,422 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search; + +import org.opensearch.Version; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseTests; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.rest.action.search.RestSearchAction; +import org.opensearch.search.aggregations.AggregationsTests; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.profile.SearchProfileShardResultsTests; +import org.opensearch.search.suggest.Suggest; +import org.opensearch.search.suggest.SuggestTests; +import org.opensearch.test.InternalAggregationTestCase; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; + +import static java.util.Collections.singletonMap; + +public class GenericSearchExtBuilderTests extends OpenSearchTestCase { + + private static final NamedXContentRegistry xContentRegistry; + static { + List namedXContents = new ArrayList<>(InternalAggregationTestCase.getDefaultNamedXContents()); + namedXContents.addAll(SuggestTests.getDefaultNamedXContents()); + namedXContents.add( + new NamedXContentRegistry.Entry( + SearchExtBuilder.class, + GenericSearchExtBuilder.EXT_BUILDER_NAME, + GenericSearchExtBuilder::fromXContent + ) + ); + xContentRegistry = new NamedXContentRegistry(namedXContents); + } + + private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry( + new SearchModule(Settings.EMPTY, List.of(new SearchPlugin() { + @Override + public List> getSearchExts() { + return List.of( + new SearchExtSpec<>( + GenericSearchExtBuilder.EXT_BUILDER_NAME, + GenericSearchExtBuilder::new, + GenericSearchExtBuilder::fromXContent + ) + ); + } + })).getNamedWriteables() + ); + + @Override + protected NamedXContentRegistry xContentRegistry() { + return xContentRegistry; + } + + SearchResponseTests srt = new SearchResponseTests(); + private AggregationsTests aggregationsTests = new AggregationsTests(); + + @Before + public void init() throws Exception { + aggregationsTests.init(); + } + + @After + public void cleanUp() throws Exception { + aggregationsTests.cleanUp(); + } + + public void testFromXContentWithUnregisteredSearchExtBuilders() throws IOException { + List namedXContents = new ArrayList<>(InternalAggregationTestCase.getDefaultNamedXContents()); + namedXContents.addAll(SuggestTests.getDefaultNamedXContents()); + String dummyId = UUID.randomUUID().toString(); + List extBuilders = List.of( + new SimpleValueSearchExtBuilder(dummyId), + new MapSearchExtBuilder(Map.of("x", "y", "a", "b")), + new ListSearchExtBuilder(List.of("1", "2", "3")) + ); + SearchResponse response = srt.createTestItem(false, extBuilders); + MediaType xcontentType = randomFrom(XContentType.values()); + boolean humanReadable = randomBoolean(); + final ToXContent.Params params = new ToXContent.MapParams(singletonMap(RestSearchAction.TYPED_KEYS_PARAM, "true")); + BytesReference originalBytes = toShuffledXContent(response, xcontentType, params, humanReadable); + XContentParser parser = createParser(new NamedXContentRegistry(namedXContents), xcontentType.xContent(), originalBytes); + SearchResponse parsed = SearchResponse.fromXContent(parser); + assertEquals(extBuilders.size(), response.getInternalResponse().getSearchExtBuilders().size()); + + List actual = parsed.getInternalResponse().getSearchExtBuilders(); + assertEquals(extBuilders.size(), actual.size()); + for (int i = 0; i < actual.size(); i++) { + assertTrue(actual.get(0) instanceof GenericSearchExtBuilder); + } + } + + // This test case fails because GenericSearchExtBuilder does not retain the name of the SearchExtBuilder that it is replacing. + // GenericSearchExtBuilder has its own "generic_ext" section name. + // public void testFromXContentWithSearchExtBuilders() throws IOException { + // String dummyId = UUID.randomUUID().toString(); + // srt.doFromXContentTestWithRandomFields(createTestItem(false, List.of(new SimpleValueSearchExtBuilder(dummyId))), false); + // } + + public void testFromXContentWithGenericSearchExtBuildersForSimpleValues() throws IOException { + String dummyId = UUID.randomUUID().toString(); + srt.doFromXContentTestWithRandomFields( + createTestItem(false, List.of(new GenericSearchExtBuilder(dummyId, GenericSearchExtBuilder.ValueType.SIMPLE))), + false + ); + } + + public void testFromXContentWithGenericSearchExtBuildersForMapValues() throws IOException { + srt.doFromXContentTestWithRandomFields( + createTestItem(false, List.of(new GenericSearchExtBuilder(Map.of("x", "y", "a", "b"), GenericSearchExtBuilder.ValueType.MAP))), + false + ); + } + + public void testFromXContentWithGenericSearchExtBuildersForListValues() throws IOException { + String dummyId = UUID.randomUUID().toString(); + srt.doFromXContentTestWithRandomFields( + createTestItem(false, List.of(new GenericSearchExtBuilder(List.of("1", "2", "3"), GenericSearchExtBuilder.ValueType.LIST))), + false + ); + } + + public void testSerializationWithGenericSearchExtBuildersForSimpleValues() throws IOException { + String id = UUID.randomUUID().toString(); + SearchResponse searchResponse = createTestItem( + false, + List.of(new GenericSearchExtBuilder(id, GenericSearchExtBuilder.ValueType.SIMPLE)) + ); + SearchResponse deserialized = copyWriteable(searchResponse, namedWriteableRegistry, SearchResponse::new, Version.CURRENT); + if (searchResponse.getHits().getTotalHits() == null) { + assertNull(deserialized.getHits().getTotalHits()); + } else { + assertEquals(searchResponse.getHits().getTotalHits().value, deserialized.getHits().getTotalHits().value); + assertEquals(searchResponse.getHits().getTotalHits().relation, deserialized.getHits().getTotalHits().relation); + } + assertEquals(searchResponse.getHits().getHits().length, deserialized.getHits().getHits().length); + assertEquals(searchResponse.getNumReducePhases(), deserialized.getNumReducePhases()); + assertEquals(searchResponse.getFailedShards(), deserialized.getFailedShards()); + assertEquals(searchResponse.getTotalShards(), deserialized.getTotalShards()); + assertEquals(searchResponse.getSkippedShards(), deserialized.getSkippedShards()); + assertEquals(searchResponse.getClusters(), deserialized.getClusters()); + assertEquals( + searchResponse.getInternalResponse().getSearchExtBuilders().get(0), + deserialized.getInternalResponse().getSearchExtBuilders().get(0) + ); + } + + public void testSerializationWithGenericSearchExtBuildersForMapValues() throws IOException { + SearchResponse searchResponse = createTestItem( + false, + List.of(new GenericSearchExtBuilder(Map.of("x", "y", "a", "b"), GenericSearchExtBuilder.ValueType.MAP)) + ); + SearchResponse deserialized = copyWriteable(searchResponse, namedWriteableRegistry, SearchResponse::new, Version.CURRENT); + if (searchResponse.getHits().getTotalHits() == null) { + assertNull(deserialized.getHits().getTotalHits()); + } else { + assertEquals(searchResponse.getHits().getTotalHits().value, deserialized.getHits().getTotalHits().value); + assertEquals(searchResponse.getHits().getTotalHits().relation, deserialized.getHits().getTotalHits().relation); + } + assertEquals(searchResponse.getHits().getHits().length, deserialized.getHits().getHits().length); + assertEquals(searchResponse.getNumReducePhases(), deserialized.getNumReducePhases()); + assertEquals(searchResponse.getFailedShards(), deserialized.getFailedShards()); + assertEquals(searchResponse.getTotalShards(), deserialized.getTotalShards()); + assertEquals(searchResponse.getSkippedShards(), deserialized.getSkippedShards()); + assertEquals(searchResponse.getClusters(), deserialized.getClusters()); + assertEquals( + searchResponse.getInternalResponse().getSearchExtBuilders().get(0), + deserialized.getInternalResponse().getSearchExtBuilders().get(0) + ); + } + + public void testSerializationWithGenericSearchExtBuildersForListValues() throws IOException { + SearchResponse searchResponse = createTestItem( + false, + List.of(new GenericSearchExtBuilder(List.of("1", "2", "3"), GenericSearchExtBuilder.ValueType.LIST)) + ); + SearchResponse deserialized = copyWriteable(searchResponse, namedWriteableRegistry, SearchResponse::new, Version.CURRENT); + if (searchResponse.getHits().getTotalHits() == null) { + assertNull(deserialized.getHits().getTotalHits()); + } else { + assertEquals(searchResponse.getHits().getTotalHits().value, deserialized.getHits().getTotalHits().value); + assertEquals(searchResponse.getHits().getTotalHits().relation, deserialized.getHits().getTotalHits().relation); + } + assertEquals(searchResponse.getHits().getHits().length, deserialized.getHits().getHits().length); + assertEquals(searchResponse.getNumReducePhases(), deserialized.getNumReducePhases()); + assertEquals(searchResponse.getFailedShards(), deserialized.getFailedShards()); + assertEquals(searchResponse.getTotalShards(), deserialized.getTotalShards()); + assertEquals(searchResponse.getSkippedShards(), deserialized.getSkippedShards()); + assertEquals(searchResponse.getClusters(), deserialized.getClusters()); + assertEquals( + searchResponse.getInternalResponse().getSearchExtBuilders().get(0), + deserialized.getInternalResponse().getSearchExtBuilders().get(0) + ); + } + + public SearchResponse createTestItem( + boolean minimal, + List searchExtBuilders, + ShardSearchFailure... shardSearchFailures + ) { + boolean timedOut = randomBoolean(); + Boolean terminatedEarly = randomBoolean() ? null : randomBoolean(); + int numReducePhases = randomIntBetween(1, 10); + long tookInMillis = randomNonNegativeLong(); + int totalShards = randomIntBetween(1, Integer.MAX_VALUE); + int successfulShards = randomIntBetween(0, totalShards); + int skippedShards = randomIntBetween(0, totalShards); + InternalSearchResponse internalSearchResponse; + if (minimal == false) { + SearchHits hits = SearchHitsTests.createTestItem(true, true); + InternalAggregations aggregations = aggregationsTests.createTestInstance(); + Suggest suggest = SuggestTests.createTestItem(); + SearchProfileShardResults profileShardResults = SearchProfileShardResultsTests.createTestItem(); + internalSearchResponse = new InternalSearchResponse( + hits, + aggregations, + suggest, + profileShardResults, + timedOut, + terminatedEarly, + numReducePhases, + searchExtBuilders + ); + } else { + internalSearchResponse = InternalSearchResponse.empty(); + } + + return new SearchResponse( + internalSearchResponse, + null, + totalShards, + successfulShards, + skippedShards, + tookInMillis, + shardSearchFailures, + randomBoolean() ? randomClusters() : SearchResponse.Clusters.EMPTY, + null + ); + } + + static SearchResponse.Clusters randomClusters() { + int totalClusters = randomIntBetween(0, 10); + int successfulClusters = randomIntBetween(0, totalClusters); + int skippedClusters = totalClusters - successfulClusters; + return new SearchResponse.Clusters(totalClusters, successfulClusters, skippedClusters); + } + + static class SimpleValueSearchExtBuilder extends SearchExtBuilder { + + static ParseField FIELD = new ParseField("simple_value"); + + private final String id; + + public SimpleValueSearchExtBuilder(String id) { + assertNotNull(id); + this.id = id; + } + + public SimpleValueSearchExtBuilder(StreamInput in) throws IOException { + this.id = in.readString(); + } + + public String getId() { + return this.id; + } + + @Override + public String getWriteableName() { + return FIELD.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field(FIELD.getPreferredName(), id); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + + if (!(obj instanceof SimpleValueSearchExtBuilder)) { + return false; + } + + return this.id.equals(((SimpleValueSearchExtBuilder) obj).getId()); + } + + public static SimpleValueSearchExtBuilder parse(XContentParser parser) throws IOException { + String id; + XContentParser.Token token = parser.currentToken(); + if (token == XContentParser.Token.VALUE_STRING) { + id = parser.text(); + } else { + throw new ParsingException(parser.getTokenLocation(), "Expected a VALUE_STRING but got " + token); + } + if (id == null) { + throw new ParsingException(parser.getTokenLocation(), "no id specified for " + FIELD.getPreferredName()); + } + return new SimpleValueSearchExtBuilder(id); + } + } + + static class MapSearchExtBuilder extends SearchExtBuilder { + + private final static String EXT_FIELD = "map0"; + + private final Map map; + + public MapSearchExtBuilder(Map map) { + this.map = new HashMap<>(); + for (Map.Entry e : map.entrySet()) { + this.map.put(e.getKey(), e.getValue()); + } + } + + @Override + public String getWriteableName() { + return EXT_FIELD; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(this.map); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field(EXT_FIELD, this.map); + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.map); + } + + @Override + public boolean equals(Object obj) { + return false; + } + } + + static class ListSearchExtBuilder extends SearchExtBuilder { + + private final static String EXT_FIELD = "list0"; + + private final List list; + + public ListSearchExtBuilder(List list) { + this.list = new ArrayList<>(); + list.forEach(e -> this.list.add(e)); + } + + @Override + public String getWriteableName() { + return EXT_FIELD; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(this.list, StreamOutput::writeString); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field(EXT_FIELD, this.list); + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.list); + } + + @Override + public boolean equals(Object obj) { + return false; + } + } +}