Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ignore_unmapped support in KNNQueryBuilder #1071

Merged
merged 16 commits into from
Sep 29, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
### Enhancements
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.opensearch.core.common.Strings;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -43,6 +44,7 @@
public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -55,6 +57,7 @@
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

/**
* Constructs a new knn query
Expand Down Expand Up @@ -88,6 +91,7 @@
this.vector = vector;
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -113,6 +117,7 @@
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
ignoreUnmapped = in.readOptionalBoolean();
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -126,6 +131,7 @@
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
Expand All @@ -144,6 +150,8 @@
boost = parser.floatValue();
} else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (IGNORE_UNMAPPED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
ignoreUnmapped = parser.booleanValue();
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
Expand Down Expand Up @@ -176,6 +184,7 @@
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
Expand All @@ -187,6 +196,7 @@
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
out.writeBoolean(ignoreUnmapped);
}

/**
Expand All @@ -211,6 +221,20 @@
return this.filter;
}

/**
* Sets whether the query builder should ignore unmapped paths (and run a
* {@link MatchNoDocsQuery} in place of this query) or throw an exception if
* the path is unmapped.
*/
public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
this.ignoreUnmapped = ignoreUnmapped;
return this;
}

public boolean getIgnoreUnmapped() {
return this.ignoreUnmapped;

Check warning on line 235 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L235

Added line #L235 was not covered by tests
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand All @@ -221,6 +245,7 @@
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand All @@ -230,6 +255,10 @@
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
return new MatchNoDocsQuery();
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Optional;

import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -307,4 +309,13 @@ private void assertSerialization(final Version version, final Optional<QueryBuil
}
}
}

public void testIgnoreUnmapped() throws IOException {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
knnQueryBuilder.ignoreUnmapped(true);
Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class));
assertNotNull(query);
assertThat(query, instanceOf(MatchNoDocsQuery.class));
}
}
Loading