From 0dab8ea7473d85fb0f7a166a712b5bf9a09b6722 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 29 Nov 2024 14:51:14 +0100 Subject: [PATCH] Add test for knn retriever --- .../search/retriever/KnnRetrieverBuilder.java | 10 ++++++++-- .../retriever/KnnRetrieverBuilderParsingTests.java | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 97c87d755ca25..da6254201072b 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -257,7 +257,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder searchSourceBuilder.knnSearch(knnSearchBuilders); } - // ---- FOR TESTING XCONTENT PARSING ---- + RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + +// ---- FOR TESTING XCONTENT PARSING ---- @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { @@ -278,7 +282,9 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } if (rescoreVectorBuilder != null) { - builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); + builder.startObject(RESCORE_FIELD.getPreferredName()); + rescoreVectorBuilder.toXContent(builder, params); + builder.endObject(); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 0213a83385739..da28b0eff441f 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -105,6 +105,7 @@ public void testRewrite() throws IOException { assertNull(source.query()); assertThat(source.knnSearch().size(), equalTo(1)); assertThat(source.knnSearch().get(0).getFilterQueries().size(), equalTo(knnRetriever.preFilterQueryBuilders.size())); + assertThat(source.knnSearch().get(0).getRescoreVectorBuilder(), equalTo(knnRetriever.rescoreVectorBuilder())); for (int j = 0; j < knnRetriever.preFilterQueryBuilders.size(); j++) { assertThat( source.knnSearch().get(0).getFilterQueries().get(j),