Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ignore_unmapped support in KNNQueryBuilder #1071

Merged
merged 16 commits into from
Sep 29, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
### Enhancements
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.common.ValidationException;
Expand All @@ -24,6 +25,7 @@

import java.io.File;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand All @@ -32,6 +34,13 @@

public class IndexUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
}
};

/**
* Determines the size of a file on disk in kilobytes
*
Expand Down Expand Up @@ -195,4 +204,12 @@

return Collections.unmodifiableMap(loadParameters);
}

/**
 * Checks whether the cluster satisfies the minimum version registered for a feature key.
 * <p>
 * The key is looked up in {@code minimalRequiredVersionMap}. When a version is registered,
 * the result is whether the version returned by
 * {@code KNNClusterUtil.instance().getClusterMinVersion()} is on or after it.
 *
 * @param key feature key to look up, e.g. {@code "ignore_unmapped"}
 * @return {@code false} when no minimum version is registered for {@code key}; otherwise
 *         the outcome of the {@code onOrAfter} comparison described above
 */
public static boolean isClusterOnOrAfterMinRequiredVersion(String key) {
Version minimalRequiredVersion = minimalRequiredVersionMap.get(key);
if (minimalRequiredVersion == null) {
// Unknown feature key: report the feature as unsupported rather than failing.
return false;

Check warning on line 211 in src/main/java/org/opensearch/knn/index/IndexUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/IndexUtil.java#L211

Added line #L211 was not covered by tests
}
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.opensearch.core.common.Strings;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
Expand All @@ -31,6 +32,7 @@
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;

/**
Expand All @@ -43,6 +45,7 @@
public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -55,6 +58,7 @@
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

/**
* Constructs a new knn query
Expand Down Expand Up @@ -88,6 +92,7 @@
this.vector = vector;
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -113,6 +118,9 @@
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -126,6 +134,7 @@
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
Expand All @@ -134,6 +143,8 @@
} else if (token == XContentParser.Token.START_OBJECT) {
throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
fieldName = currentFieldName;
System.out.println(currentFieldName);
System.out.println(IGNORE_UNMAPPED_FIELD.getPreferredName());
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
Expand All @@ -144,6 +155,10 @@
boost = parser.floatValue();
} else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = parser.booleanValue();

Check warning on line 160 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L160

Added line #L160 was not covered by tests
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
Expand Down Expand Up @@ -176,6 +191,7 @@
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
Expand All @@ -187,6 +203,9 @@
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
}

/**
Expand All @@ -211,6 +230,20 @@
return this.filter;
}

/**
 * Sets how this query builder reacts when its target field is unmapped.
 * <p>
 * When enabled, an unmapped field makes the builder produce a {@link MatchNoDocsQuery}
 * in place of this query; when disabled, an unmapped field results in an exception.
 *
 * @param ignoreUnmapped {@code true} to ignore an unmapped field, {@code false} to fail on it
 * @return this builder, allowing calls to be chained
 */
public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
this.ignoreUnmapped = ignoreUnmapped;
return this;
}

/**
 * Reports whether this builder is configured to ignore an unmapped field.
 *
 * @return the current value of the {@code ignoreUnmapped} flag
 */
public boolean getIgnoreUnmapped() {
    return ignoreUnmapped;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand All @@ -221,6 +254,9 @@
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);

Check warning on line 258 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L258

Added line #L258 was not covered by tests
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand All @@ -230,6 +266,10 @@
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
return new MatchNoDocsQuery();
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Optional;

import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -307,4 +309,16 @@ private void assertSerialization(final Version version, final Optional<QueryBuil
}
}
}

/**
 * Verifies both settings of {@code ignoreUnmapped} when the field cannot be resolved:
 * with the flag enabled {@code doToQuery} yields a {@link MatchNoDocsQuery}, and with it
 * disabled {@code doToQuery} throws an {@link IllegalArgumentException}.
 */
public void testIgnoreUnmapped() throws IOException {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
// Enable the flag and confirm the accessor reflects it.
knnQueryBuilder.ignoreUnmapped(true);
assertTrue(knnQueryBuilder.getIgnoreUnmapped());
// The context is a bare mock with no stubbing; presumably its field lookup returns null
// (Mockito's default answer), which is what makes the field count as unmapped here.
Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class));
assertNotNull(query);
assertThat(query, instanceOf(MatchNoDocsQuery.class));
// With the flag disabled, the same unresolved field must be rejected.
knnQueryBuilder.ignoreUnmapped(false);
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class)));
}
}
Loading