Skip to content

Commit

Permalink
Updating toXContent implementation for retrievers (elastic#114017)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpailis authored Oct 8, 2024
1 parent 7bbebbd commit 44f3791
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,19 @@ public ActionRequestValidationException validate(
@Override
public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.startObject(getName());
if (preFilterQueryBuilders.isEmpty() == false) {
builder.field(PRE_FILTER_FIELD.getPreferredName(), preFilterQueryBuilders);
}
if (minScore != null) {
builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (retrieverName != null) {
builder.field(NAME_FIELD.getPreferredName(), retrieverName);
}
doToXContent(builder, params);
builder.endObject();
builder.endObject();

return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import org.elasticsearch.search.collapse.CollapseBuilderTests;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.slice.SliceBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
Expand Down Expand Up @@ -600,6 +602,75 @@ public void testNegativeTrackTotalHits() throws IOException {
}
}

public void testStandardRetrieverParsing() throws IOException {
String restContent = "{"
+ " \"retriever\": {"
+ " \"standard\": {"
+ " \"query\": {"
+ " \"match_all\": {}"
+ " },"
+ " \"min_score\": 10,"
+ " \"_name\": \"foo_standard\""
+ " }"
+ " }"
+ "}";
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
assertThat(source.retriever(), instanceOf(StandardRetrieverBuilder.class));
StandardRetrieverBuilder parsed = (StandardRetrieverBuilder) source.retriever();
assertThat(parsed.minScore(), equalTo(10f));
assertThat(parsed.retrieverName(), equalTo("foo_standard"));
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
parseSerialized,
true,
searchUsageHolder,
nf -> true
);
assertThat(deserializedSource.retriever(), instanceOf(StandardRetrieverBuilder.class));
StandardRetrieverBuilder deserialized = (StandardRetrieverBuilder) source.retriever();
assertThat(parsed, equalTo(deserialized));
}
}
}

public void testKnnRetrieverParsing() throws IOException {
String restContent = "{"
+ " \"retriever\": {"
+ " \"knn\": {"
+ " \"query_vector\": ["
+ " 3"
+ " ],"
+ " \"field\": \"vector\","
+ " \"k\": 10,"
+ " \"num_candidates\": 15,"
+ " \"min_score\": 10,"
+ " \"_name\": \"foo_knn\""
+ " }"
+ " }"
+ "}";
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
assertThat(source.retriever(), instanceOf(KnnRetrieverBuilder.class));
KnnRetrieverBuilder parsed = (KnnRetrieverBuilder) source.retriever();
assertThat(parsed.minScore(), equalTo(10f));
assertThat(parsed.retrieverName(), equalTo("foo_knn"));
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
parseSerialized,
true,
searchUsageHolder,
nf -> true
);
assertThat(deserializedSource.retriever(), instanceOf(KnnRetrieverBuilder.class));
KnnRetrieverBuilder deserialized = (KnnRetrieverBuilder) source.retriever();
assertThat(parsed, equalTo(deserialized));
}
}
}

public void testStoredFieldsUsage() throws IOException {
Set<String> storedFieldRestVariations = Set.of(
"{\"stored_fields\" : [\"_none_\"]}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ protected KnnRetrieverBuilder createTestInstance() {

@Override
protected KnnRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return KnnRetrieverBuilder.fromXContent(
return (KnnRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(
new SearchUsage(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ protected StandardRetrieverBuilder createTestInstance() {

@Override
protected StandardRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return StandardRetrieverBuilder.fromXContent(
return (StandardRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(
new SearchUsage(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,7 @@ public int rankWindowSize() {

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(RETRIEVER_FIELD.getPreferredName());
builder.startObject();
builder.field(retrieverBuilder.getName(), retrieverBuilder);
builder.endObject();
builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
builder.field(FIELD_FIELD.getPreferredName(), field);
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
if (seed != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,11 @@ public int rankWindowSize() {

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(RETRIEVER_FIELD.getPreferredName());
builder.startObject();
builder.field(retrieverBuilder.getName(), retrieverBuilder);
builder.endObject();
builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
builder.field(FIELD_FIELD.getPreferredName(), field);
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
if (minScore != null) {
builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -48,8 +46,8 @@ protected RandomRankRetrieverBuilder createTestInstance() {
}

@Override
protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) {
return RandomRankRetrieverBuilder.PARSER.apply(
protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return (RandomRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(
new SearchUsage(),
Expand Down Expand Up @@ -77,8 +75,8 @@ protected NamedXContentRegistry xContentRegistry() {
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
new ParseField(TextSimilarityRankBuilder.NAME),
(p, c) -> TextSimilarityRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
new ParseField(RandomRankBuilder.NAME),
(p, c) -> RandomRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
)
);
return new NamedXContentRegistry(entries);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.rank.textsimilarity;

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
Expand All @@ -25,6 +26,8 @@
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.usage.SearchUsageHolder;
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
Expand Down Expand Up @@ -72,8 +75,8 @@ protected TextSimilarityRankRetrieverBuilder createTestInstance() {
}

@Override
protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) {
return TextSimilarityRankRetrieverBuilder.PARSER.apply(
protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return (TextSimilarityRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(
new SearchUsage(),
Expand Down Expand Up @@ -208,6 +211,45 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
}
}

public void testTextSimilarityRetrieverParsing() throws IOException {
String restContent = "{"
+ " \"retriever\": {"
+ " \"text_similarity_reranker\": {"
+ " \"retriever\": {"
+ " \"test\": {"
+ " \"value\": \"my-test-retriever\""
+ " }"
+ " },"
+ " \"field\": \"my-field\","
+ " \"inference_id\": \"my-inference-id\","
+ " \"inference_text\": \"my-inference-text\","
+ " \"rank_window_size\": 100,"
+ " \"min_score\": 20.0,"
+ " \"_name\": \"foo_reranker\""
+ " }"
+ " }"
+ "}";
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
assertThat(source.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever();
assertThat(parsed.minScore(), equalTo(20f));
assertThat(parsed.retrieverName(), equalTo("foo_reranker"));
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
parseSerialized,
true,
searchUsageHolder,
nf -> true
);
assertThat(deserializedSource.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
TextSimilarityRankRetrieverBuilder deserialized = (TextSimilarityRankRetrieverBuilder) source.retriever();
assertThat(parsed, equalTo(deserialized));
}
}
}

public void testIsCompound() {
RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
builder.startArray(RETRIEVERS_FIELD.getPreferredName());

for (var entry : innerRetrievers) {
builder.startObject();
builder.field(entry.retriever().getName());
entry.retriever().toXContent(builder, params);
builder.endObject();
}
builder.endArray();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,27 @@
package org.elasticsearch.xpack.rank.rrf;

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.Strings;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.usage.SearchUsageHolder;
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.json.JsonXContent;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RRFRetrieverBuilder> {

/**
Expand Down Expand Up @@ -53,7 +61,10 @@ protected RRFRetrieverBuilder createTestInstance() {

@Override
protected RRFRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return RRFRetrieverBuilder.PARSER.apply(parser, new RetrieverParserContext(new SearchUsage(), nf -> true));
return (RRFRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(new SearchUsage(), nf -> true)
);
}

@Override
Expand Down Expand Up @@ -81,4 +92,48 @@ protected NamedXContentRegistry xContentRegistry() {
);
return new NamedXContentRegistry(entries);
}

public void testRRFRetrieverParsing() throws IOException {
String restContent = "{"
+ " \"retriever\": {"
+ " \"rrf\": {"
+ " \"retrievers\": ["
+ " {"
+ " \"test\": {"
+ " \"value\": \"foo\""
+ " }"
+ " },"
+ " {"
+ " \"test\": {"
+ " \"value\": \"bar\""
+ " }"
+ " }"
+ " ],"
+ " \"rank_window_size\": 100,"
+ " \"rank_constant\": 10,"
+ " \"min_score\": 20.0,"
+ " \"_name\": \"foo_rrf\""
+ " }"
+ " }"
+ "}";
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class));
RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever();
assertThat(parsed.minScore(), equalTo(20f));
assertThat(parsed.retrieverName(), equalTo("foo_rrf"));
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
parseSerialized,
true,
searchUsageHolder,
nf -> true
);
assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class));
RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever();
assertThat(parsed, equalTo(deserialized));
}
}
}
}

0 comments on commit 44f3791

Please sign in to comment.