Skip to content

Commit

Permalink
Pass empty QueryCollectorContext in case of hybrid query
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed May 3, 2024
1 parent 2dd2e8a commit 9659df1
Show file tree
Hide file tree
Showing 7 changed files with 419 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x)
### Features
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
### Bug Fixes
- Fix multi node "no such index" error in text chunking processor ([#713](https://github.com/opensearch-project/neural-search/pull/713))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import org.apache.lucene.search.Query;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.ConcurrentQueryPhaseSearcher;
import org.opensearch.search.query.QueryCollectorContext;

import java.io.IOException;
import java.util.LinkedList;

/**
 * {@link ConcurrentQueryPhaseSearcher} variant whose collector-based search is always executed
 * with {@link QueryCollectorContext#EMPTY_CONTEXT} as the query collector context.
 */
public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher {

    /**
     * Forwards every argument unchanged to the seven-argument {@code searchWithCollector} overload,
     * supplying {@link QueryCollectorContext#EMPTY_CONTEXT} as the additional collector context.
     *
     * @param searchContext      context of the current search request
     * @param searcher           index searcher to run the query with
     * @param query              query to execute
     * @param collectors         collector contexts, passed through as-is
     * @param hasFilterCollector passed through as-is
     * @param hasTimeout         passed through as-is
     * @return whatever the delegated overload returns
     * @throws IOException if the delegated search throws it
     */
    @Override
    protected boolean searchWithCollector(
        SearchContext searchContext,
        ContextIndexSearcher searcher,
        Query query,
        LinkedList<QueryCollectorContext> collectors,
        boolean hasFilterCollector,
        boolean hasTimeout
    ) throws IOException {
        final QueryCollectorContext emptyCollectorContext = QueryCollectorContext.EMPTY_CONTEXT;
        return searchWithCollector(searchContext, searcher, query, collectors, emptyCollectorContext, hasFilterCollector, hasTimeout);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import org.apache.lucene.search.Query;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;

import java.io.IOException;
import java.util.LinkedList;

/**
 * {@link QueryPhase.DefaultQueryPhaseSearcher} variant whose collector-based search is always
 * executed with {@link QueryCollectorContext#EMPTY_CONTEXT} as the query collector context.
 */
public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher {

    /**
     * Forwards every argument unchanged to the seven-argument {@code searchWithCollector} overload,
     * supplying {@link QueryCollectorContext#EMPTY_CONTEXT} as the additional collector context.
     *
     * @param searchContext      context of the current search request
     * @param searcher           index searcher to run the query with
     * @param query              query to execute
     * @param collectors         collector contexts, passed through as-is
     * @param hasFilterCollector passed through as-is
     * @param hasTimeout         passed through as-is
     * @return whatever the delegated overload returns
     * @throws IOException if the delegated search throws it
     */
    @Override
    protected boolean searchWithCollector(
        SearchContext searchContext,
        ContextIndexSearcher searcher,
        Query query,
        LinkedList<QueryCollectorContext> collectors,
        boolean hasFilterCollector,
        boolean hasTimeout
    ) throws IOException {
        final QueryCollectorContext emptyCollectorContext = QueryCollectorContext.EMPTY_CONTEXT;
        return searchWithCollector(searchContext, searcher, query, collectors, emptyCollectorContext, hasFilterCollector, hasTimeout);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.search.query.QueryPhaseSearcherWrapper;

import lombok.extern.log4j.Log4j2;
Expand All @@ -36,6 +37,14 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {

private final QueryPhaseSearcher defaultQueryPhaseSearcherWithEmptyCollectorContext;
private final QueryPhaseSearcher concurrentQueryPhaseSearcherWithEmptyCollectorContext;

/**
 * Creates the searcher and instantiates both delegates up front: one
 * {@link ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext} and one
 * {@link DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext}.
 */
public HybridQueryPhaseSearcher() {
    concurrentQueryPhaseSearcherWithEmptyCollectorContext = new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext();
    defaultQueryPhaseSearcherWithEmptyCollectorContext = new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext();
}

public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
Expand All @@ -49,10 +58,17 @@ public boolean searchWith(
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
Query hybridQuery = extractHybridQuery(searchContext, query);
return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
}

/**
 * Selects the delegate searcher for the given search context.
 *
 * @param searchContext context whose {@code shouldUseConcurrentSearch()} flag drives the choice
 * @return the concurrent delegate when concurrent search should be used, otherwise the default delegate
 */
private QueryPhaseSearcher getQueryPhaseSearcher(final SearchContext searchContext) {
    if (searchContext.shouldUseConcurrentSearch()) {
        return concurrentQueryPhaseSearcherWithEmptyCollectorContext;
    }
    return defaultQueryPhaseSearcherWithEmptyCollectorContext;
}

private static boolean isWrappedHybridQuery(final Query query) {
return query instanceof BooleanQuery
&& ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import lombok.SneakyThrows;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.action.OriginalIndices;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QuerySearchResult;

import java.io.IOException;
import java.util.LinkedList;

import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext}.
 */
public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase {
    private static final String TEXT_FIELD_NAME = "field";
    private static final String TEST_DOC_TEXT1 = "Hello world";
    private static final String QUERY_TEXT1 = "hello";
    private static final Index dummyIndex = new Index("dummy", "dummy");

    /**
     * Indexes a single document, runs a hybrid query with one term sub-query through the searcher
     * under test and asserts that {@code querySearchResult.hasConsumedTopDocs()} is true afterwards.
     */
    @SneakyThrows
    public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
        ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy(
            new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext()
        );
        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
        when(mockQueryShardContext.index()).thenReturn(dummyIndex);
        MapperService mapperService = createMapperService();
        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

        Directory directory = newDirectory();
        IndexWriter w = null;
        IndexReader reader = null;
        // Everything that can throw runs inside try/finally so the Lucene resources are released
        // even when setup, the search itself, or the assertion fails.
        try {
            w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
            FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
            ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
            ft.setOmitNorms(random().nextBoolean());
            ft.freeze();
            int docId1 = RandomizedTest.randomInt();
            w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
            w.commit();

            reader = DirectoryReader.open(w);
            SearchContext searchContext = mock(SearchContext.class);

            ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
                reader,
                IndexSearcher.getDefaultSimilarity(),
                IndexSearcher.getDefaultQueryCache(),
                IndexSearcher.getDefaultQueryCachingPolicy(),
                true,
                null,
                searchContext
            );

            ShardId shardId = new ShardId(dummyIndex, 1);
            SearchShardTarget shardTarget = new SearchShardTarget(
                randomAlphaOfLength(10),
                shardId,
                randomAlphaOfLength(10),
                OriginalIndices.NONE
            );
            when(searchContext.shardTarget()).thenReturn(shardTarget);
            when(searchContext.searcher()).thenReturn(contextIndexSearcher);
            when(searchContext.size()).thenReturn(3);
            when(searchContext.numberOfShards()).thenReturn(1);
            IndexShard indexShard = mock(IndexShard.class);
            when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
            when(searchContext.indexShard()).thenReturn(indexShard);
            QuerySearchResult querySearchResult = new QuerySearchResult();
            when(searchContext.queryResult()).thenReturn(querySearchResult);
            when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
            when(searchContext.mapperService()).thenReturn(mapperService);

            LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
            boolean hasFilterCollector = randomBoolean();
            boolean hasTimeout = randomBoolean();

            HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

            TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
            queryBuilder.add(termSubQuery);

            Query query = queryBuilder.toQuery(mockQueryShardContext);
            when(searchContext.query()).thenReturn(query);
            queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext);
            queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

            assertTrue(querySearchResult.hasConsumedTopDocs());
        } finally {
            releaseResources(directory, w, reader);
        }
    }

    /**
     * Closes the writer, then the reader, then the directory. Arguments may be {@code null}
     * (try-with-resources skips null resources), and a failure while closing one resource does
     * not prevent the remaining ones from being closed.
     *
     * @param directory directory backing the index
     * @param w         index writer, or {@code null} if it was never created
     * @param reader    index reader, or {@code null} if it was never opened
     * @throws IOException if closing any of the resources fails
     */
    private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException {
        // Resources are closed in reverse declaration order: writer, reader, directory.
        try (Directory directoryToClose = directory; IndexReader readerToClose = reader; IndexWriter writerToClose = w) {
            // nothing to do here: closing the resources is the whole purpose of this block
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import lombok.SneakyThrows;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.action.OriginalIndices;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QuerySearchResult;

import java.io.IOException;
import java.util.LinkedList;

import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext}.
 */
public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase {
    private static final String TEXT_FIELD_NAME = "field";
    private static final String TEST_DOC_TEXT1 = "Hello world";
    private static final String QUERY_TEXT1 = "hello";
    private static final Index dummyIndex = new Index("dummy", "dummy");

    /**
     * Indexes a single document, runs a hybrid query with one term sub-query through the searcher
     * under test and asserts that {@code querySearchResult.hasConsumedTopDocs()} is true afterwards.
     */
    @SneakyThrows
    public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
        DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy(
            new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext()
        );
        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
        when(mockQueryShardContext.index()).thenReturn(dummyIndex);
        MapperService mapperService = createMapperService();
        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

        Directory directory = newDirectory();
        IndexWriter w = null;
        IndexReader reader = null;
        // Everything that can throw runs inside try/finally so the Lucene resources are released
        // even when setup, the search itself, or the assertion fails.
        try {
            w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
            FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
            ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
            ft.setOmitNorms(random().nextBoolean());
            ft.freeze();
            int docId1 = RandomizedTest.randomInt();
            w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
            w.commit();

            reader = DirectoryReader.open(w);
            SearchContext searchContext = mock(SearchContext.class);

            ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
                reader,
                IndexSearcher.getDefaultSimilarity(),
                IndexSearcher.getDefaultQueryCache(),
                IndexSearcher.getDefaultQueryCachingPolicy(),
                true,
                null,
                searchContext
            );

            ShardId shardId = new ShardId(dummyIndex, 1);
            SearchShardTarget shardTarget = new SearchShardTarget(
                randomAlphaOfLength(10),
                shardId,
                randomAlphaOfLength(10),
                OriginalIndices.NONE
            );
            when(searchContext.shardTarget()).thenReturn(shardTarget);
            when(searchContext.searcher()).thenReturn(contextIndexSearcher);
            when(searchContext.size()).thenReturn(3);
            when(searchContext.numberOfShards()).thenReturn(1);
            IndexShard indexShard = mock(IndexShard.class);
            when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
            when(searchContext.indexShard()).thenReturn(indexShard);
            QuerySearchResult querySearchResult = new QuerySearchResult();
            when(searchContext.queryResult()).thenReturn(querySearchResult);
            when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
            when(searchContext.mapperService()).thenReturn(mapperService);

            LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
            boolean hasFilterCollector = randomBoolean();
            boolean hasTimeout = randomBoolean();

            HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

            TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
            queryBuilder.add(termSubQuery);

            Query query = queryBuilder.toQuery(mockQueryShardContext);
            when(searchContext.query()).thenReturn(query);
            queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext);
            queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

            assertTrue(querySearchResult.hasConsumedTopDocs());
        } finally {
            releaseResources(directory, w, reader);
        }
    }

    /**
     * Closes the writer, then the reader, then the directory. Arguments may be {@code null}
     * (try-with-resources skips null resources), and a failure while closing one resource does
     * not prevent the remaining ones from being closed.
     *
     * @param directory directory backing the index
     * @param w         index writer, or {@code null} if it was never created
     * @param reader    index reader, or {@code null} if it was never opened
     * @throws IOException if closing any of the resources fails
     */
    private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException {
        // Resources are closed in reverse declaration order: writer, reader, directory.
        try (Directory directoryToClose = directory; IndexReader readerToClose = reader; IndexWriter writerToClose = w) {
            // nothing to do here: closing the resources is the whole purpose of this block
        }
    }
}
Loading

0 comments on commit 9659df1

Please sign in to comment.