Skip to content

Commit

Permalink
Refactor parsing of KNNQueryBuilder (opensearch-project#1824)
Browse files Browse the repository at this point in the history
Refactors parsing of KNNQueryBuilder. First, it moves parsing logic to a
separate class. Next, it uses ObjectParser instead of parsing manually
by hand. Lastly, it also moves out the streaming. To test, it contains a 
simple jmh benchmark for testing as well as new unit tests.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored Jul 29, 2024
1 parent 161a60a commit 52636c4
Show file tree
Hide file tree
Showing 9 changed files with 875 additions and 482 deletions.
2 changes: 1 addition & 1 deletion .idea/copyright/SPDX_ALv2.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
### Documentation
### Maintenance
### Refactoring
### Refactoring
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import org.opensearch.cluster.ClusterModule;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
* Benchmarks for impact of changes around query parsing
*/
@Warmup(iterations = 5, time = 10)
@Measurement(iterations = 3, time = 10)
@Fork(3)
@State(Scope.Benchmark)
public class QueryParsingBenchmarks {
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");
private static final NamedXContentRegistry NAMED_X_CONTENT_REGISTRY = xContentRegistry();

@Param({ "128", "1024" })
private int dimension;
@Param({ "basic", "filter" })
private String type;

private BytesReference bytesReference;

@Setup
public void setup() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject("test");
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), generateVectorWithOnes(dimension));
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), 1);
if (type.equals("filter")) {
builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), TERM_QUERY);
}
builder.endObject();
builder.endObject();
bytesReference = BytesReference.bytes(builder);
}

@Benchmark
public void fromXContent(final Blackhole bh) throws IOException {
XContentParser xContentParser = createParser();
bh.consume(KNNQueryBuilderParser.fromXContent(xContentParser));
}

private XContentParser createParser() throws IOException {
XContentParser contentParser = createParser(bytesReference);
contentParser.nextToken();
return contentParser;
}

private float[] generateVectorWithOnes(final int dimensions) {
float[] vector = new float[dimensions];
Arrays.fill(vector, (float) 1);
return vector;
}

private XContentParser createParser(final BytesReference data) throws IOException {
BytesArray array = (BytesArray) data;
return JsonXContent.jsonXContent.createParser(
NAMED_X_CONTENT_REGISTRY,
LoggingDeprecationHandler.INSTANCE,
array.array(),
array.offset(),
array.length()
);
}

private static NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> list = ClusterModule.getNamedXWriteables();
SearchPlugin.QuerySpec<?> spec = new SearchPlugin.QuerySpec<>(
TermQueryBuilder.NAME,
TermQueryBuilder::new,
TermQueryBuilder::fromXContent
);
list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p)));
return new NamedXContentRegistry(list);
}
}
183 changes: 14 additions & 169 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,23 @@
import org.apache.lucene.search.Query;
import org.opensearch.common.ValidationException;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.util.QueryContext;
Expand All @@ -44,7 +40,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand All @@ -55,7 +50,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters;
import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;
import static org.opensearch.knn.validation.ParameterValidator.validateParameters;
Expand All @@ -78,6 +72,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);

public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand Down Expand Up @@ -141,7 +136,7 @@ public static class Builder {
private String queryName;
private float boost = DEFAULT_BOOST;

private Builder() {}
public Builder() {}

public Builder fieldName(String fieldName) {
this.fieldName = fieldName;
Expand Down Expand Up @@ -294,154 +289,26 @@ public static void initialize(ModelDao modelDao) {
KNNQueryBuilder.modelDao = modelDao;
}

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME)
);
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME)
);
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
return vec;
}

/**
* @param in Reads from stream
* @throws IOException Throws IO Exception
*/
public KNNQueryBuilder(StreamInput in) throws IOException {
super(in);
try {
fieldName = in.readString();
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
maxDistance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) {
methodParameters = MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}

} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
}

public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOException {
String fieldName = null;
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float maxDistance = null;
Float minScore = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
boolean ignoreUnmapped = false;
Map<String, ?> methodParameters = null;
XContentParser.Token token;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token == XContentParser.Token.START_OBJECT) {
throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
fieldName = currentFieldName;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token.isValue() || token == XContentParser.Token.START_ARRAY) {
if (VECTOR_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
vector = parser.list();
} else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
} else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = parser.booleanValue();
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
filter = parseInnerQueryBuilder(parser);
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
methodParameters = MethodParametersParser.fromXContent(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
}
}
} else {
throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName());
fieldName = parser.currentName();
vector = parser.list();
}
}

return KNNQueryBuilder.builder()
.queryName(queryName)
.boost(boost)
.fieldName(fieldName)
.vector(ObjectsToFloats(vector))
.k(k)
.maxDistance(maxDistance)
.minScore(minScore)
.methodParameters(methodParameters)
.ignoreUnmapped(ignoreUnmapped)
.filter(filter)
.build();
KNNQueryBuilder.Builder builder = KNNQueryBuilderParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
fieldName = builder.fieldName;
vector = builder.vector;
k = builder.k;
filter = builder.filter;
ignoreUnmapped = builder.ignoreUnmapped;
maxDistance = builder.maxDistance;
minScore = builder.minScore;
methodParameters = builder.methodParameters;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(maxDistance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(minScore);
}
if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) {
MethodParametersParser.streamOutput(out, methodParameters, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}
KNNQueryBuilderParser.streamOutput(out, this, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}

/**
Expand All @@ -460,29 +327,7 @@ public Object vector() {

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.startObject(fieldName);

builder.field(VECTOR_FIELD.getPreferredName(), vector);
builder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (maxDistance != null) {
builder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (minScore != null) {
builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (methodParameters != null) {
MethodParametersParser.doXContent(builder, methodParameters);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
KNNQueryBuilderParser.toXContent(builder, params, this);
}

@Override
Expand Down
Loading

0 comments on commit 52636c4

Please sign in to comment.