diff --git a/.idea/copyright/SPDX_ALv2.xml b/.idea/copyright/SPDX_ALv2.xml index a2485beef..3475d1512 100644 --- a/.idea/copyright/SPDX_ALv2.xml +++ b/.idea/copyright/SPDX_ALv2.xml @@ -1,6 +1,6 @@ - \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 09b4c47f9..c90ea1621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,4 +20,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Infrastructure ### Documentation ### Maintenance -### Refactoring \ No newline at end of file +### Refactoring +* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) diff --git a/micro-benchmarks/src/main/java/org/opensearch/knn/QueryParsingBenchmarks.java b/micro-benchmarks/src/main/java/org/opensearch/knn/QueryParsingBenchmarks.java new file mode 100644 index 000000000..1c5a3b875 --- /dev/null +++ b/micro-benchmarks/src/main/java/org/opensearch/knn/QueryParsingBenchmarks.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; +import org.opensearch.plugins.SearchPlugin; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +/** + * Benchmarks for impact of changes around query parsing + */ +@Warmup(iterations = 5, time = 10) +@Measurement(iterations = 3, time = 10) +@Fork(3) +@State(Scope.Benchmark) +public class QueryParsingBenchmarks { + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + private static final NamedXContentRegistry NAMED_X_CONTENT_REGISTRY = xContentRegistry(); + + @Param({ "128", "1024" }) + private int dimension; + @Param({ "basic", "filter" }) + private String type; + + private BytesReference bytesReference; + + @Setup + public void setup() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject("test"); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), generateVectorWithOnes(dimension)); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), 1); + if (type.equals("filter")) { + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), TERM_QUERY); + } + builder.endObject(); + builder.endObject(); + bytesReference = BytesReference.bytes(builder); + } + + @Benchmark + public void fromXContent(final Blackhole bh) throws IOException { + XContentParser xContentParser = createParser(); + bh.consume(KNNQueryBuilderParser.fromXContent(xContentParser)); + } + + private XContentParser createParser() throws IOException { + XContentParser contentParser = createParser(bytesReference); + contentParser.nextToken(); + return contentParser; + } + + private float[] generateVectorWithOnes(final int dimensions) { + float[] vector = new float[dimensions]; + Arrays.fill(vector, (float) 1); + return vector; + } + + private XContentParser createParser(final BytesReference data) throws IOException { + BytesArray array = (BytesArray) data; + return JsonXContent.jsonXContent.createParser( + NAMED_X_CONTENT_REGISTRY, + LoggingDeprecationHandler.INSTANCE, + array.array(), + array.offset(), + array.length() + ); + } + + private static NamedXContentRegistry xContentRegistry() { + List list = ClusterModule.getNamedXWriteables(); + SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( + TermQueryBuilder.NAME, + TermQueryBuilder::new, + TermQueryBuilder::fromXContent + ); + list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); + return new NamedXContentRegistry(list); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 2f6a1fd90..69039e7c3 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -14,19 +14,15 @@ import org.apache.lucene.search.Query; import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; -import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.MappedFieldType; -import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; @@ -34,7 +30,7 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.index.query.parser.MethodParametersParser; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.index.util.QueryContext; @@ -44,7 +40,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -55,7 +50,6 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; import static org.opensearch.knn.validation.ParameterValidator.validateParameters; @@ -78,6 +72,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES); public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER); + public static final int K_MAX = 10000; /** * The name for the knn query @@ -141,7 +136,7 @@ public static class Builder { private String queryName; private float boost = DEFAULT_BOOST; - private Builder() {} + public Builder() {} public Builder fieldName(String fieldName) { this.fieldName = fieldName; @@ -294,154 +289,26 @@ public static void initialize(ModelDao modelDao) { KNNQueryBuilder.modelDao = modelDao; } - private static float[] ObjectsToFloats(List objs) { - if (Objects.isNull(objs) || objs.isEmpty()) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME) - ); - } - float[] vec = new float[objs.size()]; - for (int i = 0; i < objs.size(); i++) { - if ((objs.get(i) instanceof Number) == false) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME) - ); - } - vec[i] = ((Number) objs.get(i)).floatValue(); - } - return vec; - } - /** * @param in Reads from stream * @throws IOException Throws IO Exception */ public KNNQueryBuilder(StreamInput in) throws IOException { super(in); - try { - fieldName = in.readString(); - vector = in.readFloatArray(); - k = in.readInt(); - filter = in.readOptionalNamedWriteable(QueryBuilder.class); - if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { - ignoreUnmapped = in.readOptionalBoolean(); - } - if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - maxDistance = in.readOptionalFloat(); - } - if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - minScore = in.readOptionalFloat(); - } - if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) { - methodParameters = MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion); - } - - } catch (IOException ex) { - throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); - } - } - - public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOException { - String fieldName = null; - List vector = null; - float boost = AbstractQueryBuilder.DEFAULT_BOOST; - Integer k = null; - Float maxDistance = null; - Float minScore = null; - QueryBuilder filter = null; - String queryName = null; - String currentFieldName = null; - boolean ignoreUnmapped = false; - Map methodParameters = null; - XContentParser.Token token; - while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } else if (token == XContentParser.Token.START_OBJECT) { - throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName); - fieldName = currentFieldName; - while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } else if (token.isValue() || token == XContentParser.Token.START_ARRAY) { - if (VECTOR_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - vector = parser.list(); - } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - boost = parser.floatValue(); - } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false); - } else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) { - if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { - ignoreUnmapped = parser.booleanValue(); - } - } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - queryName = parser.text(); - } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); - } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); - } else { - throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]" - ); - } - } else if (token == XContentParser.Token.START_OBJECT) { - String tokenName = parser.currentName(); - if (FILTER_FIELD.getPreferredName().equals(tokenName)) { - log.debug(String.format("Start parsing filter for field [%s]", fieldName)); - filter = parseInnerQueryBuilder(parser); - } else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - methodParameters = MethodParametersParser.fromXContent(parser); - } else { - throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); - } - } else { - throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" - ); - } - } - } else { - throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName()); - fieldName = parser.currentName(); - vector = parser.list(); - } - } - - return KNNQueryBuilder.builder() - .queryName(queryName) - .boost(boost) - .fieldName(fieldName) - .vector(ObjectsToFloats(vector)) - .k(k) - .maxDistance(maxDistance) - .minScore(minScore) - .methodParameters(methodParameters) - .ignoreUnmapped(ignoreUnmapped) - .filter(filter) - .build(); + KNNQueryBuilder.Builder builder = KNNQueryBuilderParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion); + fieldName = builder.fieldName; + vector = builder.vector; + k = builder.k; + filter = builder.filter; + ignoreUnmapped = builder.ignoreUnmapped; + maxDistance = builder.maxDistance; + minScore = builder.minScore; + methodParameters = builder.methodParameters; } @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(fieldName); - out.writeFloatArray(vector); - out.writeInt(k); - out.writeOptionalNamedWriteable(filter); - if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { - out.writeOptionalBoolean(ignoreUnmapped); - } - if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(maxDistance); - } - if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(minScore); - } - if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) { - MethodParametersParser.streamOutput(out, methodParameters, IndexUtil::isClusterOnOrAfterMinRequiredVersion); - } + KNNQueryBuilderParser.streamOutput(out, this, IndexUtil::isClusterOnOrAfterMinRequiredVersion); } /** @@ -460,29 +327,7 @@ public Object vector() { @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); - builder.startObject(fieldName); - - builder.field(VECTOR_FIELD.getPreferredName(), vector); - builder.field(K_FIELD.getPreferredName(), k); - if (filter != null) { - builder.field(FILTER_FIELD.getPreferredName(), filter); - } - if (maxDistance != null) { - builder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance); - } - if (ignoreUnmapped) { - builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); - } - if (minScore != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); - } - if (methodParameters != null) { - MethodParametersParser.doXContent(builder, methodParameters); - } - printBoostAndQueryName(builder); - builder.endObject(); - builder.endObject(); + KNNQueryBuilderParser.toXContent(builder, params, this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java new file mode 100644 index 000000000..eb2273a7a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java @@ -0,0 +1,259 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.index.query.parser; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentLocation; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.query.KNNQueryBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.function.Function; + +import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; +import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.K_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; +import static org.opensearch.knn.index.query.KNNQueryBuilder.VECTOR_FIELD; + +/** + * Helper class responsible for parsing and reverse parsing KNNQueryBuilder's + */ +@Log4j2 +public final class KNNQueryBuilderParser { + + private static final ObjectParser INTERNAL_PARSER = createInternalObjectParser(); + + /** + * For a k-NN query, we need to parse roughly the following structure into a KNNQueryBuilder: + * "my_vector2": { + * "vector": [2, 3, 5, 6], + * "k": 2, + * ... + * } + * to simplify the parsing process, we can define an object parser that will the internal structure after the + * field name. We cannot unfortunately also parse the field name because it ends up in the same structure + * as the nested portion. So we need to do that separately. + */ + private static ObjectParser createInternalObjectParser() { + ObjectParser internalParser = new ObjectParser<>(NAME, KNNQueryBuilder.Builder::new); + internalParser.declareFloat(KNNQueryBuilder.Builder::boost, BOOST_FIELD); + internalParser.declareString(KNNQueryBuilder.Builder::queryName, NAME_FIELD); + internalParser.declareFloatArray((b, v) -> b.vector(floatListToFloatArray(v)), VECTOR_FIELD); + internalParser.declareInt(KNNQueryBuilder.Builder::k, K_FIELD); + internalParser.declareBoolean((b, v) -> { + if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { + b.ignoreUnmapped(v); + } + }, IGNORE_UNMAPPED_FIELD); + internalParser.declareFloat(KNNQueryBuilder.Builder::maxDistance, MAX_DISTANCE_FIELD); + internalParser.declareFloat(KNNQueryBuilder.Builder::minScore, MIN_SCORE_FIELD); + + internalParser.declareObject( + KNNQueryBuilder.Builder::methodParameters, + (p, v) -> MethodParametersParser.fromXContent(p), + METHOD_PARAMS_FIELD + ); + internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD); + + return internalParser; + } + + /** + * Stream input for KNNQueryBuilder + * + * @param in stream out + * @param minClusterVersionCheck function to check min version + * @return KNNQueryBuilder.Builder class + * @throws IOException on stream failure + */ + public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function minClusterVersionCheck) throws IOException { + KNNQueryBuilder.Builder builder = new KNNQueryBuilder.Builder(); + builder.fieldName(in.readString()); + builder.vector(in.readFloatArray()); + builder.k(in.readInt()); + builder.filter(in.readOptionalNamedWriteable(QueryBuilder.class)); + + if (minClusterVersionCheck.apply("ignore_unmapped")) { + builder.ignoreUnmapped(in.readOptionalBoolean()); + } + if (minClusterVersionCheck.apply(KNNConstants.RADIAL_SEARCH_KEY)) { + builder.maxDistance(in.readOptionalFloat()); + } + if (minClusterVersionCheck.apply(KNNConstants.RADIAL_SEARCH_KEY)) { + builder.minScore(in.readOptionalFloat()); + } + if (minClusterVersionCheck.apply(METHOD_PARAMETER)) { + builder.methodParameters(MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion)); + } + + return builder; + } + + /** + * Stream output for KNNQueryBuilder + * + * @param out stream out + * @param builder KNNQueryBuilder to stream + * @param minClusterVersionCheck function to check min version + * @throws IOException on stream failure + */ + public static void streamOutput(StreamOutput out, KNNQueryBuilder builder, Function minClusterVersionCheck) + throws IOException { + out.writeString(builder.fieldName()); + out.writeFloatArray((float[]) builder.vector()); + out.writeInt(builder.getK()); + out.writeOptionalNamedWriteable(builder.getFilter()); + if (minClusterVersionCheck.apply("ignore_unmapped")) { + out.writeOptionalBoolean(builder.isIgnoreUnmapped()); + } + if (minClusterVersionCheck.apply(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(builder.getMaxDistance()); + } + if (minClusterVersionCheck.apply(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(builder.getMinScore()); + } + if (minClusterVersionCheck.apply(METHOD_PARAMETER)) { + MethodParametersParser.streamOutput(out, builder.getMethodParameters(), IndexUtil::isClusterOnOrAfterMinRequiredVersion); + } + } + + /** + * Convert XContent to KNNQueryBuilder + * + * @param parser input parser + * @return KNNQueryBuilder + * @throws IOException on parsing failure + */ + public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOException { + String fieldName = null; + String currentFieldName = null; + XContentParser.Token token; + KNNQueryBuilder.Builder builder = null; + List vector = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + throwParsingExceptionOnMultipleFields(parser.getTokenLocation(), fieldName, currentFieldName); + fieldName = currentFieldName; + builder = INTERNAL_PARSER.apply(parser, null); + } else { + throwParsingExceptionOnMultipleFields(parser.getTokenLocation(), fieldName, parser.currentName()); + fieldName = parser.currentName(); + vector = parser.list(); + } + } + + if (builder == null) { + builder = KNNQueryBuilder.builder().vector(objectsToFloats(vector)); + } + builder.fieldName(fieldName); + return builder.build(); + } + + /** + * Convert KNNQueryBuilder to XContent + * + * @param builder xcontent builder to add KNNQueryBuilder + * @param params ToXContent params + * @param knnQueryBuilder KNNQueryBuilder to convert + * @throws IOException on conversion failure + */ + public static void toXContent(XContentBuilder builder, ToXContent.Params params, KNNQueryBuilder knnQueryBuilder) throws IOException { + builder.startObject(NAME); + builder.startObject(knnQueryBuilder.fieldName()); + + builder.field(VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + if (knnQueryBuilder.getFilter() != null) { + builder.field(FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + } + if (knnQueryBuilder.getMaxDistance() != null) { + builder.field(MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); + } + if (knnQueryBuilder.isIgnoreUnmapped()) { + builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), knnQueryBuilder.isIgnoreUnmapped()); + } + if (knnQueryBuilder.getMinScore() != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + } + if (knnQueryBuilder.getMethodParameters() != null) { + MethodParametersParser.doXContent(builder, knnQueryBuilder.getMethodParameters()); + } + + builder.field(BOOST_FIELD.getPreferredName(), knnQueryBuilder.boost()); + if (knnQueryBuilder.queryName() != null) { + builder.field(NAME_FIELD.getPreferredName(), knnQueryBuilder.queryName()); + } + + builder.endObject(); + builder.endObject(); + } + + private static float[] floatListToFloatArray(List floats) { + if (Objects.isNull(floats) || floats.isEmpty()) { + throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME)); + } + float[] vec = new float[floats.size()]; + for (int i = 0; i < floats.size(); i++) { + vec[i] = floats.get(i); + } + return vec; + } + + private static float[] objectsToFloats(List objs) { + if (Objects.isNull(objs) || objs.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME) + ); + } + float[] vec = new float[objs.size()]; + for (int i = 0; i < objs.size(); i++) { + if ((objs.get(i) instanceof Number) == false) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME) + ); + } + vec[i] = ((Number) objs.get(i)).floatValue(); + } + return vec; + } + + private static void throwParsingExceptionOnMultipleFields( + XContentLocation contentLocation, + String processedFieldName, + String currentFieldName + ) { + if (processedFieldName != null) { + throw new ParsingException( + contentLocation, + "[" + NAME + "] query doesn't support multiple fields, found [" + processedFieldName + "] and [" + currentFieldName + "]" + ); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index f898b622e..5301b6e4e 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -19,6 +19,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.codec.KNNCodecService; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -173,7 +174,7 @@ public Map getMappers() { @Override public List> getQueries() { - return singletonList(new QuerySpec<>(KNNQueryBuilder.NAME, KNNQueryBuilder::new, KNNQueryBuilder::fromXContent)); + return singletonList(new QuerySpec<>(KNNQueryBuilder.NAME, KNNQueryBuilder::new, KNNQueryBuilderParser::fromXContent)); } @Override diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index cc8b86572..1ba79a495 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -452,7 +452,7 @@ public void testSearchWithInvalidSearchVectorType() { ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); assertEquals(400, ex.getResponse().getStatusLine().getStatusCode()); - assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); + assertTrue(ex.getMessage(), ex.getMessage().contains("[knn] failed to parse field [vector]")); } @SneakyThrows @@ -474,7 +474,7 @@ public void testSearchWithMissingQueryVector() { ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); assertEquals(400, ex.getResponse().getStatusLine().getStatusCode()); - assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty")); + assertTrue(ex.getMessage().contains("[knn] requires query vector")); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index e9c6ae449..d3c17d383 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -15,14 +15,10 @@ import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.index.Index; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; @@ -41,10 +37,8 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; -import org.opensearch.plugins.SearchPlugin; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Locale; @@ -58,7 +52,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; -import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; public class KNNQueryBuilderTests extends KNNTestCase { @@ -163,307 +156,6 @@ public void testEmptyVector() { ); } - public void testFromXContent() throws Exception { - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).k(K).build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_KnnWithMethodParameters() throws Exception { - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .k(K) - .methodParameters(HNSW_METHOD_PARAMS) - .build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); - builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); - builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); - builder.endObject(); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_whenMethodParameter_thenSucceed() throws Exception { - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .maxDistance(MAX_DISTANCE) - .methodParameters(HNSW_METHOD_PARAMS) - .build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); - builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); - builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); - builder.endObject(); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_whenMethodParameter_thenSucceed() throws Exception { - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .minScore(MAX_DISTANCE) - .methodParameters(HNSW_METHOD_PARAMS) - .build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); - builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); - builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); - builder.endObject(); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_withFilter() throws Exception { - final ClusterService clusterService = mockClusterService(Version.CURRENT); - - final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); - - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .k(K) - .filter(TERM_QUERY) - .build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); - builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_KnnWithEfSearch_withFilter() throws Exception { - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .k(K) - .filter(TERM_QUERY) - .methodParameters(HNSW_METHOD_PARAMS) - .build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); - builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); - builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); - builder.endObject(); - builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() throws Exception { - final ClusterService clusterService = mockClusterService(Version.CURRENT); - - final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); - - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .maxDistance(MAX_DISTANCE) - .filter(TERM_QUERY) - .build(); - - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); - builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() throws Exception { - final ClusterService clusterService = mockClusterService(Version.CURRENT); - - final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); - - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(queryVector) - .minScore(MIN_SCORE) - .filter(TERM_QUERY) - .build(); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(knnQueryBuilder.fieldName()); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); - builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - assertEquals(knnQueryBuilder, actualBuilder); - } - - public void testFromXContent_InvalidQueryVectorType() throws Exception { - final ClusterService clusterService = mockClusterService(Version.CURRENT); - - final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); - - List invalidTypeQueryVector = new ArrayList<>(); - invalidTypeQueryVector.add(1.5); - invalidTypeQueryVector.add(2.5); - invalidTypeQueryVector.add("a"); - invalidTypeQueryVector.add(null); - - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(FIELD_NAME); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); - builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - IllegalArgumentException exception = expectThrows( - IllegalArgumentException.class, - () -> KNNQueryBuilder.fromXContent(contentParser) - ); - assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); - } - - public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_thenException() throws Exception { - final ClusterService clusterService = mockClusterService(Version.CURRENT); - - final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); - - List invalidTypeQueryVector = new ArrayList<>(); - invalidTypeQueryVector.add(1.5); - invalidTypeQueryVector.add(2.5); - invalidTypeQueryVector.add("a"); - invalidTypeQueryVector.add(null); - - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.startObject(FIELD_NAME); - builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE); - builder.endObject(); - builder.endObject(); - XContentParser contentParser = createParser(builder); - contentParser.nextToken(); - IllegalArgumentException exception = expectThrows( - IllegalArgumentException.class, - () -> KNNQueryBuilder.fromXContent(contentParser) - ); - assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); - } - - public void testFromXContent_missingQueryVector() throws Exception { - final ClusterService clusterService = mockClusterService(Version.CURRENT); - - final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); - - // Test without vector field - XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder(); - builderWithoutVectorField.startObject(); - builderWithoutVectorField.startObject(FIELD_NAME); - builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); - builderWithoutVectorField.endObject(); - builderWithoutVectorField.endObject(); - XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField); - contentParserWithoutVectorField.nextToken(); - IllegalArgumentException exception = expectThrows( - IllegalArgumentException.class, - () -> KNNQueryBuilder.fromXContent(contentParserWithoutVectorField) - ); - assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty")); - - // Test empty vector field - List emptyQueryVector = new ArrayList<>(); - XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder(); - builderWithEmptyVector.startObject(); - builderWithEmptyVector.startObject(FIELD_NAME); - builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector); - builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); - builderWithEmptyVector.endObject(); - builderWithEmptyVector.endObject(); - XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector); - contentParserWithEmptyVector.nextToken(); - exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParserWithEmptyVector)); - assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty")); - } - - @Override - protected NamedXContentRegistry xContentRegistry() { - List list = ClusterModule.getNamedXWriteables(); - SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( - TermQueryBuilder.NAME, - TermQueryBuilder::new, - TermQueryBuilder::fromXContent - ); - list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); - NamedXContentRegistry registry = new NamedXContentRegistry(list); - return registry; - } - @Override protected NamedWriteableRegistry writableRegistry() { final List entries = ClusterModule.getNamedWriteables(); diff --git a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryParserTests.java new file mode 100644 index 000000000..8a1fadf89 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryParserTests.java @@ -0,0 +1,486 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.index.query.parser; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.plugins.SearchPlugin; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; +import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; + +public class KNNQueryParserTests extends KNNTestCase { + + private static final String FIELD_NAME = "myvector"; + private static final int K = 1; + private static final int EF_SEARCH = 10; + private static final Map HNSW_METHOD_PARAMS = Map.of("ef_search", EF_SEARCH); + private static final Float MAX_DISTANCE = 1.0f; + private static final Float MIN_SCORE = 0.5f; + private static final Float BOOST = 10.5f; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + + public void testFromXContent() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).k(K).build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_KnnWithMethodParameters() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_whenMethodParameter_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_whenMethodParameter_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MAX_DISTANCE) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_withFilter() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_KnnWithEfSearch_withFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .filter(TERM_QUERY) + .build(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MIN_SCORE) + .filter(TERM_QUERY) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilderParser.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_InvalidQueryVectorType() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + List invalidTypeQueryVector = new ArrayList<>(); + invalidTypeQueryVector.add(1.5); + invalidTypeQueryVector.add(2.5); + invalidTypeQueryVector.add("a"); + invalidTypeQueryVector.add(null); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilderParser.fromXContent(contentParser) + ); + assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] failed to parse field [vector]")); + } + + public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_thenException() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + List invalidTypeQueryVector = new ArrayList<>(); + invalidTypeQueryVector.add(1.5); + invalidTypeQueryVector.add(2.5); + invalidTypeQueryVector.add("a"); + invalidTypeQueryVector.add(null); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilderParser.fromXContent(contentParser) + ); + assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] failed to parse field [vector]")); + } + + public void testFromXContent_missingQueryVector() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + // Test without vector field + XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder(); + builderWithoutVectorField.startObject(); + builderWithoutVectorField.startObject(FIELD_NAME); + builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builderWithoutVectorField.endObject(); + builderWithoutVectorField.endObject(); + XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField); + contentParserWithoutVectorField.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilderParser.fromXContent(contentParserWithoutVectorField) + ); + assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] requires query vector")); + + // Test empty vector field + List emptyQueryVector = new ArrayList<>(); + XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder(); + builderWithEmptyVector.startObject(); + builderWithEmptyVector.startObject(FIELD_NAME); + builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector); + builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builderWithEmptyVector.endObject(); + builderWithEmptyVector.endObject(); + XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector); + contentParserWithEmptyVector.nextToken(); + exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilderParser.fromXContent(contentParserWithEmptyVector)); + assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] failed to parse field [vector]")); + } + + public void testFromXContent_whenFlat_thenException() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(FIELD_NAME, queryVector); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + Exception exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilderParser.fromXContent(contentParser)); + assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] requires exactly one of k, distance or score to be set")); + } + + public void testFromXContent_whenMultiFields_thenException() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(FIELD_NAME + "1"); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builder.endObject(); + builder.startObject(FIELD_NAME + "2"); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + Exception exception = expectThrows(ParsingException.class, () -> KNNQueryBuilderParser.fromXContent(contentParser)); + assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] query doesn't support multiple fields")); + } + + public void testToXContent_whenParamsVectorBoostK_thenSucceed() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(NAME); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builder.field(BOOST_FIELD.getPreferredName(), BOOST); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).k(K).boost(BOOST).build(); + XContentBuilder testBuilder = XContentFactory.jsonBuilder(); + testBuilder.startObject(); + KNNQueryBuilderParser.toXContent(testBuilder, EMPTY_PARAMS, knnQueryBuilder); + testBuilder.endObject(); + assertEquals(builder.toString(), testBuilder.toString()); + } + + public void testToXContent_whenFilter_thenSucceed() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(NAME); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), TERM_QUERY); + builder.field(BOOST_FIELD.getPreferredName(), BOOST); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .boost(BOOST) + .filter(TERM_QUERY) + .build(); + XContentBuilder testBuilder = XContentFactory.jsonBuilder(); + testBuilder.startObject(); + KNNQueryBuilderParser.toXContent(testBuilder, EMPTY_PARAMS, knnQueryBuilder); + testBuilder.endObject(); + assertEquals(builder.toString(), testBuilder.toString()); + } + + public void testToXContent_whenMaxDistance_thenSucceed() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(NAME); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), 0); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE); + builder.field(BOOST_FIELD.getPreferredName(), BOOST); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .boost(BOOST) + .maxDistance(MAX_DISTANCE) + .build(); + XContentBuilder testBuilder = XContentFactory.jsonBuilder(); + testBuilder.startObject(); + KNNQueryBuilderParser.toXContent(testBuilder, EMPTY_PARAMS, knnQueryBuilder); + testBuilder.endObject(); + assertEquals(builder.toString(), testBuilder.toString()); + } + + public void testToXContent_whenMethodParams_thenSucceed() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(NAME); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.field(BOOST_FIELD.getPreferredName(), BOOST); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .boost(BOOST) + .k(K) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder testBuilder = XContentFactory.jsonBuilder(); + testBuilder.startObject(); + KNNQueryBuilderParser.toXContent(testBuilder, EMPTY_PARAMS, knnQueryBuilder); + testBuilder.endObject(); + logger.info(builder.toString()); + logger.info(testBuilder.toString()); + assertEquals(builder.toString(), testBuilder.toString()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List list = ClusterModule.getNamedXWriteables(); + SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( + TermQueryBuilder.NAME, + TermQueryBuilder::new, + TermQueryBuilder::fromXContent + ); + list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); + NamedXContentRegistry registry = new NamedXContentRegistry(list); + return registry; + } +}