Skip to content

Commit

Permalink
Block top docs collector for hybrid query
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Apr 24, 2024
1 parent 82d9432 commit 0c0a733
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.ConcurrentQueryPhaseSearcher;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.search.query.QueryPhaseSearcherWrapper;

import lombok.extern.log4j.Log4j2;
Expand All @@ -36,6 +38,14 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {

private final QueryPhaseSearcher noDocCollectorDefaultQueryPhaseSearcher;
private final QueryPhaseSearcher concurrentQueryPhaseSearcher;

public HybridQueryPhaseSearcher() {
this.noDocCollectorDefaultQueryPhaseSearcher = new NoDocCollectorDefaultQueryPhaseSearcher();
this.concurrentQueryPhaseSearcher = new ConcurrentQueryPhaseSearcher();
}

public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
Expand All @@ -49,7 +59,29 @@ public boolean searchWith(
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
Query hybridQuery = extractHybridQuery(searchContext, query);
return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
return searchWithForHybridQuery(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
}

private boolean searchWithForHybridQuery(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (searchContext.shouldUseConcurrentSearch()) {
return concurrentQueryPhaseSearcher.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
return noDocCollectorDefaultQueryPhaseSearcher.searchWith(
searchContext,
searcher,
query,
collectors,
hasFilterCollector,
hasTimeout
);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import org.apache.lucene.search.Query;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;

import java.io.IOException;
import java.util.LinkedList;

public class NoDocCollectorDefaultQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

@Override
protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
return super.searchWithCollector(searchContext, searcher, query, collectors, null, hasFilterCollector, hasTimeout, false);
}
}

0 comments on commit 0c0a733

Please sign in to comment.