Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ignore_unmapped support in KNNQueryBuilder #1071

Merged
merged 16 commits into from
Sep 29, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
### Enhancements
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.opensearch.Version;
import org.opensearch.core.common.Strings;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand All @@ -28,7 +31,9 @@
import org.opensearch.index.query.QueryShardContext;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
Expand All @@ -43,6 +48,7 @@
// ParseField for the "vector" key of the knn query body.
public static final ParseField VECTOR_FIELD = new ParseField("vector");
// ParseField for the "k" key; its value is parsed as an integer when the query is read.
public static final ParseField K_FIELD = new ParseField("k");
// ParseField for the optional "filter" key; emitted by doXContent only when a filter is set.
public static final ParseField FILTER_FIELD = new ParseField("filter");
// ParseField for the optional "ignore_unmapped" key; emitted by doXContent only when the flag is true.
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
// NOTE(review): presumably the upper bound accepted for k -- its use is not visible here; confirm.
// NOTE(review): this is a public, mutable static; consider making it final if no caller reassigns it.
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -55,6 +61,13 @@
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
}
};

/**
* Constructs a new knn query
Expand Down Expand Up @@ -88,6 +101,7 @@
this.vector = vector;
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -113,6 +127,9 @@
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -126,6 +143,7 @@
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
Expand All @@ -134,6 +152,8 @@
} else if (token == XContentParser.Token.START_OBJECT) {
throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
fieldName = currentFieldName;
System.out.println(currentFieldName);
System.out.println(IGNORE_UNMAPPED_FIELD.getPreferredName());
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
Expand All @@ -144,6 +164,10 @@
boost = parser.floatValue();
} else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = parser.booleanValue();

Check warning on line 169 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L169

Added line #L169 was not covered by tests
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
Expand Down Expand Up @@ -176,6 +200,7 @@
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
Expand All @@ -187,6 +212,9 @@
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
}

/**
Expand All @@ -211,6 +239,20 @@
return this.filter;
}

/**
 * Sets whether the query builder should ignore unmapped paths (and run a
 * {@link MatchNoDocsQuery} in place of this query) or throw an exception if
 * the path is unmapped.
 *
 * @param ignoreUnmapped {@code true} to substitute a {@link MatchNoDocsQuery} when the field has
 *                       no mapping; {@code false} (the field's initial value) to fail instead
 * @return this builder, so that calls can be chained
 */
public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
this.ignoreUnmapped = ignoreUnmapped;
return this;
}

/**
 * Reports the current ignore_unmapped setting of this builder.
 *
 * @return {@code true} if an unmapped field is answered with a {@link MatchNoDocsQuery},
 *         {@code false} if it is rejected with an exception
 */
public boolean getIgnoreUnmapped() {
    return ignoreUnmapped;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand All @@ -221,6 +263,9 @@
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);

Check warning on line 267 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L267

Added line #L267 was not covered by tests
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand All @@ -230,6 +275,10 @@
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
return new MatchNoDocsQuery();
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
}
Expand Down Expand Up @@ -314,4 +363,12 @@
public String getWriteableName() {
// The writeable name of this builder is its NAME constant.
return NAME;
}

private static boolean isClusterOnOrAfterMinRequiredVersion(String key) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
Version minimalRequiredVersion = minimalRequiredVersionMap.get(key);
if (minimalRequiredVersion == null) {
return false;

Check warning on line 370 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L370

Added line #L370 was not covered by tests
}
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Optional;

import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -307,4 +309,16 @@ private void assertSerialization(final Version version, final Optional<QueryBuil
}
}
}

public void testIgnoreUnmapped() throws IOException {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
knnQueryBuilder.ignoreUnmapped(true);
assertTrue(knnQueryBuilder.getIgnoreUnmapped());
Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class));
assertNotNull(query);
assertThat(query, instanceOf(MatchNoDocsQuery.class));
knnQueryBuilder.ignoreUnmapped(false);
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class)));
}
}
Loading