Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ignore_unmapped support in KNNQueryBuilder #1071

Merged
merged 16 commits into from
Sep 29, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/1013)
* Improved the logic to switch to exact search for restrictive filters search for better recall. [#1059](https://github.com/opensearch-project/k-NN/pull/1059)
* Added max distance computation logic to enhance the switch to exact search in filtered Nearest Neighbor Search. [#1066](https://github.com/opensearch-project/k-NN/pull/1066)
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.opensearch.core.common.Strings;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -43,6 +44,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -55,6 +57,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

/**
* Constructs a new knn query
Expand Down Expand Up @@ -88,6 +91,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.vector = vector;
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -113,6 +117,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
ignoreUnmapped = in.readBoolean();
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -126,6 +131,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
Expand All @@ -144,6 +150,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
boost = parser.floatValue();
} else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (IGNORE_UNMAPPED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
ignoreUnmapped = parser.booleanValue();
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
Expand Down Expand Up @@ -176,6 +184,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
Expand All @@ -187,6 +196,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
out.writeBoolean(ignoreUnmapped);
}

/**
Expand All @@ -211,6 +221,20 @@ public QueryBuilder getFilter() {
return this.filter;
}

/**
 * Configures how this query reacts when the field it targets has no mapping.
 *
 * @param ignoreUnmapped {@code true} to have the query resolve to a
 *                       {@link MatchNoDocsQuery} when the field is unmapped;
 *                       {@code false} to have query construction fail with an
 *                       exception instead
 * @return this builder, so that calls can be chained
 */
public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
    this.ignoreUnmapped = ignoreUnmapped;
    return this;
}

/**
 * Reports the current value of the {@code ignore_unmapped} setting.
 *
 * @return {@code true} if an unmapped field is to be ignored by this query,
 *         {@code false} otherwise
 */
public boolean getIgnoreUnmapped() {
    return ignoreUnmapped;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand All @@ -221,6 +245,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand All @@ -230,6 +255,10 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
return new MatchNoDocsQuery();
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Optional;

import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -307,4 +309,13 @@ private void assertSerialization(final Version version, final Optional<QueryBuil
}
}
}

/**
 * Checks that a k-NN query builder with {@code ignoreUnmapped} set to
 * {@code true} produces a {@link MatchNoDocsQuery} from {@code doToQuery}
 * when given a mocked {@link QueryShardContext}.
 */
public void testIgnoreUnmapped() throws IOException {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
knnQueryBuilder.ignoreUnmapped(true);
// The mock is not stubbed in this test; presumably fieldMapper(...) yields
// null, which makes the field look unmapped — TODO confirm.
Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class));
assertNotNull(query);
assertThat(query, instanceOf(MatchNoDocsQuery.class));
}
}