From 26699ea58aa43935f5c1756208e64f7262e86af3 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:17:04 -0800 Subject: [PATCH] Enable support for default model id on HybridQueryBuilder (#541) (#549) * Enable support for default model id on HybridQueryBuilder Signed-off-by: Varun Jain * Adding tests and updating changelog.md Signed-off-by: Varun Jain * Optimizing code Signed-off-by: Varun Jain * modyfing the tests4 Signed-off-by: Varun Jain * Addressing Heemin comment Signed-off-by: Varun Jain --------- Signed-off-by: Varun Jain (cherry picked from commit 98e5534929efb054e7b5d4382af46b0626ef1f50) Co-authored-by: Varun Jain --- CHANGELOG.md | 1 + .../query/HybridQueryBuilder.java | 16 +++++++++++ .../NeuralQueryEnricherProcessorIT.java | 20 ++++++++++++++ .../query/HybridQueryBuilderTests.java | 8 ++++++ .../query/OpenSearchQueryTestCase.java | 27 +++++++++++++++---- 5 files changed, 67 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5932d5ea9..13f949f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) - Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524)) - Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533)) +- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541)) ### Infrastructure - BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515)) - Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535)) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index b0104639b..aa4242c2e 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -13,6 +13,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.Query; import org.opensearch.common.lucene.search.Queries; import org.opensearch.core.ParseField; @@ -27,6 +28,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; import org.opensearch.index.query.Rewriteable; +import org.opensearch.index.query.QueryBuilderVisitor; import lombok.Getter; import lombok.NoArgsConstructor; @@ -295,4 +297,18 @@ private Collection toQueries(Collection queryBuilders, Quer }).filter(Objects::nonNull).collect(Collectors.toList()); return queries; } + + /** + * visit method to parse the HybridQueryBuilder by a visitor + */ + @Override + public void visit(QueryBuilderVisitor visitor) { + visitor.accept(this); + // getChildVisitor of NeuralSearchQueryVisitor return this. + // therefore any argument can be passed. Here we have used Occcur.MUST as an argument. + QueryBuilderVisitor subVisitor = visitor.getChildVisitor(Occur.MUST); + for (QueryBuilder subQueryBuilder : queries) { + subQueryBuilder.visit(subVisitor); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java index 3c8adf5b3..b37db7c82 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java @@ -15,6 +15,7 @@ import org.junit.Before; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import com.google.common.primitives.Floats; @@ -62,6 +63,25 @@ public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() { } + @SneakyThrows + public void testNeuralQueryEnricherProcessor_whenHybridQueryBuilderAndNoModelIdPassed_thenSuccess() { + initializeIndexIfNotExist(); + String modelId = getDeployedModelId(); + createSearchRequestProcessor(modelId, search_pipeline); + createPipelineProcessor(modelId, ingest_pipeline); + updateIndexSettings(index, Settings.builder().put("index.search.default_pipeline", search_pipeline)); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + neuralQueryBuilder.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1); + neuralQueryBuilder.queryText("Hello World"); + neuralQueryBuilder.k(1); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + Map response = search(index, hybridQueryBuilder, 2); + + assertFalse(response.isEmpty()); + + } + @SneakyThrows private void initializeIndexIfNotExist() { if (index.equals(NeuralQueryEnricherProcessorIT.index) && !indexExists(index)) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 7bae96780..72ea8fa7b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -20,6 +20,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.ArrayList; import java.util.function.Supplier; import org.apache.lucene.search.MatchNoDocsQuery; @@ -708,6 +709,13 @@ public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { assertNotNull(hybridQueryBuilder); } + public void testVisit() { + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(new NeuralQueryBuilder()).add(new NeuralSparseQueryBuilder()); + List visitedQueries = new ArrayList<>(); + hybridQueryBuilder.visit(createTestVisitor(visitedQueries)); + assertEquals(3, visitedQueries.size()); + } + private Map getInnerMap(Object innerObject, String queryName, String fieldName) { if (!(innerObject instanceof Map)) { fail("field name does not map to nested object"); diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index 97e1f7c5d..a1e8210e6 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -7,6 +7,14 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonMap; import static java.util.stream.Collectors.toList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.BooleanClause; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilderVisitor; import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; import java.io.IOException; @@ -20,11 +28,6 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.Weight; import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.CheckedConsumer; @@ -230,4 +233,18 @@ public float getMaxScore(int upTo) { protected static void initFeatureFlags() { System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey(), "true"); } + + protected static QueryBuilderVisitor createTestVisitor(List visitedQueries) { + return new QueryBuilderVisitor() { + @Override + public void accept(QueryBuilder qb) { + visitedQueries.add(qb); + } + + @Override + public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) { + return this; + } + }; + } }