From fdcb5d80df07672a0b44317783503c09a27afda9 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 21 Nov 2023 18:37:52 -0800 Subject: [PATCH] Fixed nested field case, draft version Signed-off-by: Martin Gaievski --- .../query/HybridQueryPhaseSearcher.java | 56 ++++++++++++++++++- .../query/HybridQueryPhaseSearcherTests.java | 3 + 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index f65e30222..975abc642 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -19,12 +19,15 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; @@ -48,6 +51,8 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { + final static int MAX_NESTED_SUBQUERY_LIMIT = 20; + public HybridQueryPhaseSearcher() { super(); } @@ -55,17 +60,64 @@ public HybridQueryPhaseSearcher() { public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, - final Query query, + Query query, final LinkedList collectors, final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (query instanceof HybridQuery) { + if (isHybridQuery(query, searchContext)) { + query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + validateHybridQuery(query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + void validateHybridQuery(final Query query) { + if (query instanceof BooleanQuery) { + List booleanClauses = ((BooleanQuery) query).clauses(); + for (BooleanClause booleanClause : booleanClauses) { + validateNestedBooleanQuery(booleanClause.getQuery(), 1); + } + } + } + + void validateNestedBooleanQuery(final Query query, int level) { + if (query instanceof HybridQuery) { + throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); + } + if (level >= MAX_NESTED_SUBQUERY_LIMIT) { + throw new IllegalStateException("reached max nested query limit, cannot process query"); + } + if (query instanceof BooleanQuery) { + for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) { + validateNestedBooleanQuery(booleanClause.getQuery(), level + 1); + } + } + } + + private Query extractHybridQuery(SearchContext searchContext, Query query) { + if (query instanceof BooleanQuery + && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query) + && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery)) { + //extract hybrid query and replace bool with hybrid query + query = ((BooleanQuery) query).clauses().get(0).getQuery(); + } + return query; + } + + boolean isHybridQuery(Query query, SearchContext searchContext) { + if (query instanceof HybridQuery) { + return true; + } + else if (new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query) + && query instanceof BooleanQuery + && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery)) { + return true; + } + return false; + } + @VisibleForTesting protected boolean searchWithCollector( final SearchContext searchContext, diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e9c55cc54..f57903406 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -41,6 +41,7 @@ import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -204,6 +205,8 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + MapperService mapperService = mock(MapperService.class); + when(searchContext.mapperService()).thenReturn(mapperService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);