Skip to content

Commit

Permalink
Fixed Hybrid query for cases when it's wrapped into other compound qu…
Browse files Browse the repository at this point in the history
…eries (opensearch-project#498) (opensearch-project#501)

* Fixed nested field case

(cherry picked from commit 3991fe7)

Signed-off-by: Martin Gaievski <[email protected]>
(cherry picked from commit 27a38cb)
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 8, 2023
1 parent 51e6c00 commit d6095ab
Show file tree
Hide file tree
Showing 6 changed files with 792 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,10 @@ public DocIdSetIterator iterator() {
* Return the maximum score that documents between the last target that this iterator was shallow-advanced to included and upTo included.
* @param upTo upper limit for document id
* @return max score
* @throws IOException
*/
@Override
public float getMaxScore(int upTo) throws IOException {
return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
public float getMaxScore(int upTo) {
return subScorers.stream().filter(Objects::nonNull).filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
try {
return scorer.getMaxScore(upTo);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.SeqNoFieldMapper;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
Expand Down Expand Up @@ -60,12 +67,120 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
if (isHybridQuery(query, searchContext)) {
Query hybridQuery = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

private boolean isHybridQuery(final Query query, final SearchContext searchContext) {
if (query instanceof HybridQuery) {
return true;
} else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) {
/* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code
https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370.
main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks
hybrid query for indexes with nested field types.
in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for
this search request.
below is sample structure of such query:
Boolean {
should: {
hybrid: {
sub_query1 {}
sub_query2 {}
}
}
filter: {
exists: {
field: "_primary_term"
}
}
}
TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */
// we have already checked if query in instance of Boolean in higher level else if condition
return ((BooleanQuery) query).clauses()
.stream()
.filter(clause -> clause.getQuery() instanceof HybridQuery == false)
.allMatch(clause -> {
return clause.getOccur() == BooleanClause.Occur.FILTER
&& clause.getQuery() instanceof FieldExistsQuery
&& SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField());
});
}
return false;
}

private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query);
}

private boolean isWrappedHybridQuery(final Query query) {
return query instanceof BooleanQuery
&& ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
}

private Query extractHybridQuery(final SearchContext searchContext, final Query query) {
if (hasNestedFieldOrNestedDocs(query, searchContext)
&& isWrappedHybridQuery(query)
&& ((BooleanQuery) query).clauses().size() > 0) {
// extract hybrid query and replace bool with hybrid query
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
if (booleanClauses.isEmpty() || booleanClauses.get(0).getQuery() instanceof HybridQuery == false) {
throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query");
}
return booleanClauses.get(0).getQuery();
}
return query;
}

/**
* Validate the query from neural-search plugin point of view. Current main goal for validation is to block cases
* when hybrid query is wrapped into other compound queries.
* For example, if we have Bool query like below we need to throw an error
* bool: {
* should: [
* match: {},
* hybrid: {
* sub_query1 {}
* sub_query2 {}
* }
* ]
* }
* TODO add similar validation for other compound type queries like dis_max, constant_score etc.
*
* @param query query to validate
*/
private void validateQuery(final SearchContext searchContext, final Query query) {
if (query instanceof BooleanQuery) {
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
for (BooleanClause booleanClause : booleanClauses) {
validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext));
}
}
}

private void validateNestedBooleanQuery(final Query query, final int level) {
if (query instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
}
if (level <= 0) {
// ideally we should throw an error here but this code is on the main search workflow path and that might block
// execution of some queries. Instead, we're silently exit and allow such query to execute and potentially produce incorrect
// results in case hybrid query is wrapped into such bool query
log.error("reached max nested query limit, cannot process bool query with that many nested clauses");
return;
}
if (query instanceof BooleanQuery) {
for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
validateNestedBooleanQuery(booleanClause.getQuery(), level - 1);
}
}
}

@VisibleForTesting
protected boolean searchWithCollector(
final SearchContext searchContext,
Expand Down Expand Up @@ -209,4 +324,9 @@ private float getMaxScore(final List<TopDocs> topDocs) {
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
return sortAndFormats == null ? null : sortAndFormats.formats;
}

private int getMaxDepthLimit(final SearchContext searchContext) {
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,18 @@ protected void addKnnDoc(String index, String docId, List<String> vectorFieldNam
addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList());
}

@SneakyThrows
protected void addKnnDoc(
String index,
String docId,
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
) {
addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList());
}

/**
* Add a set of knn vectors and text to an index
*
Expand All @@ -422,6 +434,8 @@ protected void addKnnDoc(String index, String docId, List<String> vectorFieldNam
* @param vectors List of vectors corresponding to those fields
* @param textFieldNames List of text fields to be added
* @param texts List of text corresponding to those fields
* @param nestedFieldNames List of nested fields to be added
* @param nestedFields List of fields and values corresponding to those fields
*/
@SneakyThrows
protected void addKnnDoc(
Expand All @@ -430,7 +444,9 @@ protected void addKnnDoc(
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
List<String> texts,
List<String> nestedFieldNames,
List<Map<String, String>> nestedFields
) {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
Expand All @@ -441,6 +457,16 @@ protected void addKnnDoc(
for (int i = 0; i < textFieldNames.size(); i++) {
builder.field(textFieldNames.get(i), texts.get(i));
}

for (int i = 0; i < nestedFieldNames.size(); i++) {
builder.field(nestedFieldNames.get(i));
builder.startObject();
Map<String, String> nestedValues = nestedFields.get(i);
for (Map.Entry<String, String> entry : nestedValues.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();

request.setJsonEntity(builder.toString());
Expand Down Expand Up @@ -523,7 +549,16 @@ protected boolean checkComplete(Map<String, Object> node) {
}

@SneakyThrows
private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int numberOfShards) {
protected String buildIndexConfiguration(final List<KNNFieldConfig> knnFieldConfigs, final int numberOfShards) {
return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards);
}

@SneakyThrows
protected String buildIndexConfiguration(
final List<KNNFieldConfig> knnFieldConfigs,
final List<String> nestedFields,
final int numberOfShards
) {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("settings")
Expand All @@ -544,6 +579,11 @@ private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int
.endObject()
.endObject();
}

for (String nestedField : nestedFields) {
xContentBuilder.startObject(nestedField).field("type", "nested").endObject();
}

xContentBuilder.endObject().endObject().endObject();
return xContentBuilder.toString();
}
Expand Down
Loading

0 comments on commit d6095ab

Please sign in to comment.