From 15fc51d876a042a5ec2d7d377a8f508bdbf910c2 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 22 Jun 2022 16:56:56 -0700 Subject: [PATCH 1/7] Adding 'dense_vector' field type Signed-off-by: Martin Gaievski --- .../opensearch/search/knn/DenseVectorIT.java | 125 +++++ .../index/codec/KnnVectorFormatFactory.java | 93 ++++ .../PerFieldMappingPostingFormatCodec.java | 10 +- .../index/mapper/DenseVectorFieldMapper.java | 360 ++++++++++++++ .../index/mapper/KnnAlgorithmContext.java | 192 ++++++++ .../mapper/KnnAlgorithmContextFactory.java | 146 ++++++ .../opensearch/index/mapper/KnnContext.java | 144 ++++++ .../org/opensearch/index/mapper/Metric.java | 78 +++ .../org/opensearch/indices/IndicesModule.java | 2 + .../mapper/DenseVectorFieldTypeTests.java | 459 ++++++++++++++++++ .../index/mapper/DenseVectorMapperTests.java | 161 ++++++ .../aggregations/AggregatorTestCase.java | 2 + 12 files changed, 1771 insertions(+), 1 deletion(-) create mode 100644 server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java create mode 100644 server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java create mode 100644 server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java create mode 100644 server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java create mode 100644 server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java create mode 100644 server/src/main/java/org/opensearch/index/mapper/KnnContext.java create mode 100644 server/src/main/java/org/opensearch/index/mapper/Metric.java create mode 100644 server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java create mode 100644 server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java diff --git a/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java b/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java new file mode 100644 index 0000000000000..deefc3dcc309b --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java @@ -0,0 +1,125 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.search.knn; + +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.VersionUtils; + +import java.util.Map; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; + +public class DenseVectorIT extends OpenSearchIntegTestCase { + + private static final float[] VECTOR_ONE = { 2.0f, 4.5f, 5.6f, 4.2f }; + private static final float[] VECTOR_TWO = { 4.0f, 2.5f, 1.6f, 2.2f }; + + @Override + protected boolean forbidPrivateIndexSettings() { + return false; + } + + public void testIndexingSingleDocumentWithoutKnn() throws Exception { + Version version = VersionUtils.randomIndexCompatibleVersion(random()); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build(); + XContentBuilder defaultMapping = jsonBuilder().startObject() + .startObject("properties") + .startObject("vector_field") + .field("type", "dense_vector") + .field("dimension", 4) + .endObject() + .endObject() + .endObject(); + assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping)); + ensureGreen(); + + indexRandom( + true, + client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()) + ); + ensureSearchable("test"); + } + + public void testIndexingSingleDocumentWithDefaultKnnParams() throws Exception { + Version version = VersionUtils.randomIndexCompatibleVersion(random()); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build(); + XContentBuilder defaultMapping = jsonBuilder().startObject() + .startObject("properties") + .startObject("vector_field") + .field("type", "dense_vector") + .field("dimension", 4) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject(); + assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping)); + ensureGreen(); + + indexRandom( + true, + client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()) + ); + ensureSearchable("test"); + } + + public void testIndexingMultipleDocumentsWithHnswDefinition() throws Exception { + Version version = VersionUtils.randomIndexCompatibleVersion(random()); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build(); + XContentBuilder defaultMapping = jsonBuilder().startObject() + .startObject("properties") + .startObject("field") + .field("type", "dense_vector") + .field("dimension", 4) + .field( + "knn", + Map.of("metric", "l2", "algorithm", Map.of("name", "hnsw", "parameters", Map.of("max_connections", 12, "beam_width", 256))) + ) + .endObject() + .endObject() + .endObject(); + assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping)); + ensureGreen(); + + indexRandom( + true, + client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()), + client().prepareIndex("test").setId("2").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_TWO).endObject()) + ); + ensureSearchable("test"); + } +} diff --git a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java 
b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java new file mode 100644 index 0000000000000..5eb1e2b9b0765 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.index.codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene92.Lucene92Codec; +import org.apache.lucene.codecs.lucene92.Lucene92HnswVectorsFormat; +import org.opensearch.index.mapper.DenseVectorFieldMapper; +import org.opensearch.index.mapper.KnnAlgorithmContext; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; + +import java.util.Map; + +import static org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +/** + * Factory that creates a {@link KnnVectorsFormat knn vector format} based on a mapping + * configuration for the field. 
+ * + * @opensearch.internal + */ +public class KnnVectorFormatFactory { + + private final MapperService mapperService; + + public KnnVectorFormatFactory(MapperService mapperService) { + this.mapperService = mapperService; + } + + /** + * Create KnnVectorsFormat with parameters specified in the field definition or return codec's default + * Knn Vector Format if field is not of DenseVector type + * @param field name of the field + * @return KnnVectorFormat that is specific to a mapped field + */ + public KnnVectorsFormat create(final String field) { + final MappedFieldType mappedFieldType = mapperService.fieldType(field); + if (isDenseVectorFieldType(mappedFieldType)) { + final DenseVectorFieldType knnVectorFieldType = (DenseVectorFieldType) mappedFieldType; + final KnnAlgorithmContext algorithmContext = knnVectorFieldType.getKnnContext().getKnnAlgorithmContext(); + final Map methodParams = algorithmContext.getParameters(); + int maxConnections = getIntegerParam(methodParams, HNSW_PARAMETER_MAX_CONNECTIONS); + int beamWidth = getIntegerParam(methodParams, HNSW_PARAMETER_BEAM_WIDTH); + final KnnVectorsFormat luceneHnswVectorsFormat = new Lucene92HnswVectorsFormat(maxConnections, beamWidth); + return luceneHnswVectorsFormat; + } + return Lucene92Codec.getDefault().knnVectorsFormat(); + } + + private boolean isDenseVectorFieldType(final MappedFieldType mappedFieldType) { + if (mappedFieldType != null && mappedFieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType) { + return true; + } + return false; + } + + private int getIntegerParam(Map methodParams, String name) { + return (Integer) methodParams.get(name); + } +} diff --git a/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java b/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java index fd0c66983208a..f1cd21bfae186 100644 --- a/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java +++ b/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java @@ -35,6 +35,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; import org.apache.lucene.codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; @@ -57,16 +58,18 @@ public class PerFieldMappingPostingFormatCodec extends Lucene92Codec { private final Logger logger; private final MapperService mapperService; private final DocValuesFormat dvFormat = new Lucene90DocValuesFormat(); + private final KnnVectorFormatFactory knnVectorsFormatFactory; static { assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) - : "PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC; + : "PerFieldMappingPostingFormatCodec must subclass the latest lucene codec: " + Lucene.LATEST_CODEC; } public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService, Logger logger) { super(compressionMode); this.mapperService = mapperService; this.logger = logger; + this.knnVectorsFormatFactory = new KnnVectorFormatFactory(mapperService); } @Override @@ -84,4 +87,9 @@ public PostingsFormat getPostingsFormatForField(String field) { public DocValuesFormat getDocValuesFormatForField(String field) { return dvFormat; } + + @Override + public KnnVectorsFormat 
getKnnVectorsFormatForField(String field) { + return knnVectorsFormatFactory.create(field); + } } diff --git a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java new file mode 100644 index 0000000000000..8d1ab0ceb69e7 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java @@ -0,0 +1,360 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.Query; +import org.opensearch.common.Explicit; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; +import org.opensearch.search.lookup.SearchLookup; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Field Mapper for Dense vector type. Extends ParametrizedFieldMapper in order to easily configure mapping parameters. + */ +public class DenseVectorFieldMapper extends ParametrizedFieldMapper { + + public static final String CONTENT_TYPE = "dense_vector"; + + /** + * Define the max dimension a knn_vector mapping can have. + */ + public static final int MAX_DIMENSION = 1024; + + private static DenseVectorFieldMapper toType(FieldMapper in) { + return (DenseVectorFieldMapper) in; + } + + /** + * Builder for DenseVectorFieldMapper. 
This class defines the set of parameters that can be applied to the knn_vector + * field type + */ + public static class Builder extends ParametrizedFieldMapper.Builder { + + private final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, false); + + protected final Parameter dimension = new Parameter<>(Names.DIMENSION.getValue(), false, () -> 1, (n, c, o) -> { + int value = XContentMapValues.nodeIntegerValue(o); + if (value > MAX_DIMENSION) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[dimension] value cannot be greater than %d for vector [%s]", MAX_DIMENSION, name) + ); + } + if (value <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[dimension] value must be greater than 0 for vector [%s]", name) + ); + } + return value; + }, m -> toType(m).dimension).setSerializer((b, n, v) -> b.field(n, v.intValue()), v -> Integer.toString(v.intValue())); + + private final Parameter knnContext = new Parameter<>( + Names.KNN.getValue(), + false, + () -> null, + (n, c, o) -> KnnContext.parse(o), + m -> toType(m).knnContext + ).setSerializer(((b, n, v) -> { + if (v == null) { + return; + } + b.startObject(n); + v.toXContent(b, ToXContent.EMPTY_PARAMS); + b.endObject(); + }), m -> m.getKnnAlgorithmContext().getMethod().name()); + + public Builder(String name) { + super(name); + } + + @Override + protected List> getParameters() { + return List.of(dimension, knnContext, hasDocValues); + } + + @Override + public DenseVectorFieldMapper build(BuilderContext context) { + return new DenseVectorFieldMapper( + buildFullName(context), + new DenseVectorFieldType(buildFullName(context), dimension.get(), knnContext.get()), + multiFieldsBuilder.build(this, context), + copyTo.build() + ); + } + } + + /** + * Type parser for dense_vector field mapper + * + * @opensearch.internal + */ + public static class TypeParser implements Mapper.TypeParser { + + @Override + public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { + Builder builder = new DenseVectorFieldMapper.Builder(name); + Object dimensionField = node.get(Names.DIMENSION.getValue()); + String dimension = XContentMapValues.nodeStringValue(dimensionField, null); + if (dimension == null) { + throw new MapperParsingException(String.format(Locale.ROOT, "[dimension] property must be specified for field [%s]", name)); + } + builder.parse(name, parserContext, node); + return builder; + } + } + + /** + * Field type for dense_vector field mapper + * + * @opensearch.internal + */ + public static class DenseVectorFieldType extends MappedFieldType { + + private final int dimension; + private final KnnContext knnContext; + private final boolean hasDocValues; + + public DenseVectorFieldType(String name, int dimension, KnnContext knnContext) { + this(name, Collections.emptyMap(), dimension, knnContext, false); + } + + public DenseVectorFieldType(String name, Map meta, int dimension, KnnContext knnContext, boolean hasDocValues) { + super(name, false, false, false, TextSearchInfo.NONE, meta); + this.dimension = dimension; + this.knnContext = knnContext; + this.hasDocValues = hasDocValues; + } + + @Override + public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { + throw new UnsupportedOperationException("Dense_vector does not support fields search"); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public Query existsQuery(QueryShardContext context) { + 
return new FieldExistsQuery(name()); + } + + @Override + public Query termQuery(Object value, QueryShardContext context) { + throw new QueryShardException( + context, + "Dense_vector does not support exact searching, use KNN queries instead [" + name() + "]" + ); + } + + public int getDimension() { + return dimension; + } + + public KnnContext getKnnContext() { + return knnContext; + } + } + + protected Explicit ignoreMalformed; + protected Integer dimension; + protected boolean isKnnEnabled; + protected KnnContext knnContext; + protected boolean hasDocValues; + protected String modelId; + + public DenseVectorFieldMapper(String simpleName, DenseVectorFieldType mappedFieldType, MultiFields multiFields, CopyTo copyTo) { + super(simpleName, mappedFieldType, multiFields, copyTo); + dimension = mappedFieldType.getDimension(); + fieldType = new FieldType(DenseVectorFieldMapper.Defaults.FIELD_TYPE); + isKnnEnabled = mappedFieldType.getKnnContext() != null; + if (isKnnEnabled) { + knnContext = mappedFieldType.getKnnContext(); + fieldType.setVectorDimensionsAndSimilarityFunction( + mappedFieldType.getDimension(), + Metric.toSimilarityFunction(knnContext.getMetric()) + ); + } + fieldType.freeze(); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + protected void parseCreateField(ParseContext context) throws IOException { + parseCreateField(context, fieldType().getDimension()); + } + + protected void parseCreateField(ParseContext context, int dimension) throws IOException { + + context.path().add(simpleName()); + + ArrayList vector = new ArrayList<>(); + XContentParser.Token token = context.parser().currentToken(); + float value; + if (token == XContentParser.Token.START_ARRAY) { + token = context.parser().nextToken(); + while (token != XContentParser.Token.END_ARRAY) { + value = context.parser().floatValue(); + + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + + vector.add(value); + token = context.parser().nextToken(); + } + } else if (token == XContentParser.Token.VALUE_NUMBER) { + value = context.parser().floatValue(); + + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + + vector.add(value); + context.parser().nextToken(); + } else if (token == XContentParser.Token.VALUE_NULL) { + context.path().remove(); + return; + } + + if (dimension != vector.size()) { + String errorMessage = String.format( + Locale.ROOT, + "Vector dimensions mismatch, expected [%d] but given [%d]", + dimension, + vector.size() + ); + throw new IllegalArgumentException(errorMessage); + } + + float[] array = new float[vector.size()]; + int i = 0; + for (Float f : vector) { + array[i++] = f; + } + + KnnVectorField point = new KnnVectorField(name(), array, fieldType); + + context.doc().add(point); + context.path().remove(); + } + + @Override + protected boolean docValuesByDefault() { + return false; + } + + @Override + public ParametrizedFieldMapper.Builder getMergeBuilder() { + return new DenseVectorFieldMapper.Builder(simpleName()).init(this); + } + + @Override + public final boolean parsesArrayValue() { + return true; + } + + @Override + public DenseVectorFieldType fieldType() { + return (DenseVectorFieldType) super.fieldType(); + } + + 
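    /*
     * Illustrative sketch only (assumed usage, not part of this change): roughly what the
     * dense_vector mapping translates to at the Lucene level. parseCreateField() above indexes a
     * KnnVectorField with the configured dimension and similarity function, the per-field
     * Lucene92HnswVectorsFormat chosen by KnnVectorFormatFactory builds the HNSW graph from
     * max_connections/beam_width, and a vector search would then go through Lucene's
     * KnnVectorQuery. The field name, vectors, k, and the indexWriter/indexSearcher
     * variables below are arbitrary example values, not names taken from this patch.
     */
    // org.apache.lucene.document.Document doc = new org.apache.lucene.document.Document();
    // doc.add(new KnnVectorField("vector_field",
    //     new float[] { 2.0f, 4.5f, 5.6f, 4.2f },
    //     org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN)); // Metric.L2 -> EUCLIDEAN
    // indexWriter.addDocument(doc);
    //
    // org.apache.lucene.search.Query query = new org.apache.lucene.search.KnnVectorQuery(
    //     "vector_field", new float[] { 4.0f, 2.5f, 1.6f, 2.2f }, 10);
    // org.apache.lucene.search.TopDocs topDocs = indexSearcher.search(query, 10);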
@Override + protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException { + super.doXContentBody(builder, includeDefaults, params); + } + + /** + * Define names for dense_vector parameters + * + * @opensearch.internal + */ + enum Names { + DIMENSION("dimension"), + KNN("knn"); + + Names(String value) { + this.value = value; + } + + String value; + + String getValue() { + return this.value; + } + } + + /** + * Default parameters + * + * @opensearch.internal + */ + static class Defaults { + public static final FieldType FIELD_TYPE = new FieldType(); + + static { + FIELD_TYPE.setTokenized(false); + FIELD_TYPE.setIndexOptions(IndexOptions.NONE); + FIELD_TYPE.setDocValuesType(DocValuesType.NONE); + FIELD_TYPE.freeze(); + } + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java new file mode 100644 index 0000000000000..ae47da1e7e3b6 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java @@ -0,0 +1,192 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.index.mapper; + +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Abstracts KNN Algorithm segment of dense_vector field type + */ +public class KnnAlgorithmContext implements ToXContentFragment, Writeable { + + private static final String PARAMETERS = "parameters"; + private static final String NAME = "name"; + + private final Method method; + private final Map parameters; + + public KnnAlgorithmContext(Method method, Map parameters) { + this.method = method; + this.parameters = parameters; + } + + public Method getMethod() { + return method; + } + + public Map getParameters() { + return parameters; + } + + public static KnnAlgorithmContext parse(Object in) { + if (!(in instanceof Map)) { + throw new MapperParsingException("Unable to parse [algorithm] component"); + } + @SuppressWarnings("unchecked") + final Map methodMap = (Map) in; + Method method = Method.HNSW; + Map parameters = Map.of(); + + for (Map.Entry methodEntry : methodMap.entrySet()) { + final String key = methodEntry.getKey(); + final Object value = methodEntry.getValue(); + if (NAME.equals(key)) { + if (!(value instanceof String)) { + throw new MapperParsingException("Component [name] should be a string"); + } + try { + Method.fromName((String) value); + } catch (IllegalArgumentException illegalArgumentException) { + throw new MapperParsingException( + String.format(Locale.ROOT, "[algorithm name] value [%s] is invalid or not supported", value) + ); + } + } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + if (!(value instanceof Map)) { + throw new MapperParsingException("Unable to parse [parameters] for algorithm"); + } + // Check to interpret map parameters as sub-methodComponentContexts + parameters = ((Map) value).entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> { + Object v = e.getValue(); + if (v instanceof Map) { + throw new MapperParsingException( + String.format(Locale.ROOT, "Unable to parse parameter [%s] for [algorithm]", e.getValue()) + ); + } + return v; + })); + + } else { + throw new MapperParsingException(String.format(Locale.ROOT, "Invalid parameter %s for [algorithm]", key)); + } + } + return KnnAlgorithmContextFactory.createContext(method, parameters); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(NAME, method.name()); + if (parameters == null) { + builder.field(PARAMETERS, (String) null); + } else { + builder.startObject(PARAMETERS); + parameters.forEach((key, value) -> { + try { + builder.field(key, value); + } catch (IOException ioe) { + throw new RuntimeException("Unable to generate xcontent for method component"); + } + + }); + builder.endObject(); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.method.name()); + if (this.parameters != null) { + out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + } + } + + private static class ParameterMapValueWriter implements Writer { + + private ParameterMapValueWriter() {} + + @Override + public void write(StreamOutput out, 
Object o) throws IOException { + if (o instanceof KnnAlgorithmContext) { + out.writeBoolean(true); + ((KnnAlgorithmContext) o).writeTo(out); + } else { + out.writeBoolean(false); + out.writeGenericValue(o); + } + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + KnnAlgorithmContext that = (KnnAlgorithmContext) obj; + return method == that.method && this.parameters.equals(that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(method, parameters); + } + + /** + * Abstracts supported search methods for KNN + */ + public enum Method { + HNSW; + + private static final Map STRING_TO_METHOD = Map.of("hnsw", HNSW); + + public static Method fromName(String methodName) { + return Optional.ofNullable(STRING_TO_METHOD.get(methodName.toLowerCase(Locale.ROOT))) + .orElseThrow( + () -> new IllegalArgumentException(String.format(Locale.ROOT, "Provided knn method %s is not supported", methodName)) + ); + } + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java new file mode 100644 index 0000000000000..81ecb77d0cecc --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java @@ -0,0 +1,146 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.index.mapper; + +import org.opensearch.index.mapper.KnnAlgorithmContext.Method; + +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +/** + * Class abstracts creation of KNN Algorithm context + */ +public class KnnAlgorithmContextFactory { + + public static final String HNSW_PARAMETER_MAX_CONNECTIONS = "max_connections"; + public static final String HNSW_PARAMETER_BEAM_WIDTH = "beam_width"; + + protected static final int MAX_CONNECTIONS_DEFAULT_VALUE = 16; + protected static final int MAX_CONNECTIONS_MAX_VALUE = 16; + protected static final int BEAM_WIDTH_DEFAULT_VALUE = 100; + protected static final int BEAM_WIDTH_MAX_VALUE = 512; + + private static final Map DEFAULT_CONTEXTS = Map.of( + Method.HNSW, + createContext( + Method.HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, MAX_CONNECTIONS_DEFAULT_VALUE, HNSW_PARAMETER_BEAM_WIDTH, BEAM_WIDTH_DEFAULT_VALUE) + ) + ); + + public static KnnAlgorithmContext defaultContext(Method method) { + return Optional.ofNullable(DEFAULT_CONTEXTS.get(method)) + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid knn method provided [%s], only HNSW is supported", method.name()) + ) + ); + } + + public static KnnAlgorithmContext createContext(Method method, Map parameters) { + Map, KnnAlgorithmContext>> methodToContextSupplier = Map.of(Method.HNSW, hnswContext()); + + return Optional.ofNullable(methodToContextSupplier.get(method)) + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid knn method provided [%s], only HNSW is supported", method.name()) + ) + ) + .apply(parameters); + } + + private static Function, KnnAlgorithmContext> hnswContext() { + Function, KnnAlgorithmContext> supplierFunction = params -> { + validateForSupportedParameters(Set.of(HNSW_PARAMETER_MAX_CONNECTIONS, HNSW_PARAMETER_BEAM_WIDTH), params); + + int maxConnections = getParameter( + params, + HNSW_PARAMETER_MAX_CONNECTIONS, + MAX_CONNECTIONS_DEFAULT_VALUE, + MAX_CONNECTIONS_MAX_VALUE + ); + int beamWidth = getParameter(params, HNSW_PARAMETER_BEAM_WIDTH, BEAM_WIDTH_DEFAULT_VALUE, BEAM_WIDTH_MAX_VALUE); + + KnnAlgorithmContext hnswKnnMethodContext = new KnnAlgorithmContext( + Method.HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, maxConnections, HNSW_PARAMETER_BEAM_WIDTH, beamWidth) + ); + return hnswKnnMethodContext; + }; + return supplierFunction; + } + + private static int getParameter(Map parameters, String paramName, int defaultValue, int maxValue) { + int value = defaultValue; + if (isNullOrEmpty(parameters)) { + return value; + } + if (parameters.containsKey(paramName)) { + if (!(parameters.get(paramName) instanceof Integer)) { + throw new MapperParsingException( + String.format(Locale.ROOT, "Invalid value for field %s, it must be an integer number", paramName) + ); + } + value = (int) parameters.get(paramName); + } + if (value > maxValue) { + throw new MapperParsingException(String.format(Locale.ROOT, "%s value cannot be greater than %d", paramName, maxValue)); + } + if (value <= 0) { + throw new MapperParsingException(String.format(Locale.ROOT, "%s value must be greater than 0", paramName)); + } + return value; + } + + static void validateForSupportedParameters(Set supportedParams, Map actualParams) { + if (isNullOrEmpty(actualParams)) { + return; + } + Optional unsupportedParam = actualParams.keySet() + .stream() + .filter(actualParamName -> !supportedParams.contains(actualParamName)) + 
.findAny(); + if (unsupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Algorithm parameter [%s] is not supported", unsupportedParam.get()) + ); + } + } + + private static boolean isNullOrEmpty(Map parameters) { + return parameters == null || parameters.isEmpty(); + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java new file mode 100644 index 0000000000000..fa94972476a51 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.index.mapper; + +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +/** + * Abstracts KNN segment of dense_vector field type + */ +public class KnnContext implements ToXContentFragment, Writeable { + + private final Metric metric; + private final KnnAlgorithmContext knnAlgorithmContext; + private static final String KNN_METRIC_NAME = "metric"; + private static final String ALGORITHM = "algorithm"; + + KnnContext(final Metric metric, final KnnAlgorithmContext knnAlgorithmContext) { + this.metric = metric; + this.knnAlgorithmContext = knnAlgorithmContext; + } + + /** + * Parses an Object into a KnnContext. + * + * @param in Object containing mapping to be parsed + * @return KnnContext + */ + public static KnnContext parse(Object in) { + if (!(in instanceof Map)) { + throw new MapperParsingException("Unable to parse mapping into KnnContext. 
Object not of type \"Map\""); + } + + Map knnMap = (Map) in; + + Metric metric = Metric.L2; + KnnAlgorithmContext knnAlgorithmContext = KnnAlgorithmContextFactory.defaultContext(KnnAlgorithmContext.Method.HNSW); + + String key; + Object value; + for (Map.Entry methodEntry : knnMap.entrySet()) { + key = methodEntry.getKey(); + value = methodEntry.getValue(); + if (KNN_METRIC_NAME.equals(key)) { + if (value != null && !(value instanceof String)) { + throw new MapperParsingException(String.format(Locale.ROOT, "[%s] must be a string", KNN_METRIC_NAME)); + } + + if (value != null) { + try { + metric = Metric.fromName((String) value); + } catch (IllegalArgumentException illegalArgumentException) { + throw new MapperParsingException(String.format(Locale.ROOT, "[%s] value [%s] is invalid", key, value)); + } + } + } else if (ALGORITHM.equals(key)) { + if (value == null) { + continue; + } + + if (!(value instanceof Map)) { + throw new MapperParsingException("Unable to parse knn algorithm"); + } + knnAlgorithmContext = KnnAlgorithmContext.parse(value); + } else { + throw new MapperParsingException(String.format(Locale.ROOT, "%s value is invalid", key)); + } + } + return new KnnContext(metric, knnAlgorithmContext); + } + + public Metric getMetric() { + return metric; + } + + public KnnAlgorithmContext getKnnAlgorithmContext() { + return knnAlgorithmContext; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(KNN_METRIC_NAME, metric.name()); + builder.startObject(ALGORITHM); + builder = knnAlgorithmContext.toXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(metric.name()); + this.knnAlgorithmContext.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KnnContext that = (KnnContext) o; + return metric == that.metric && Objects.equals(knnAlgorithmContext, that.knnAlgorithmContext); + } + + @Override + public int hashCode() { + return Objects.hash(metric, knnAlgorithmContext.hashCode()); + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/Metric.java b/server/src/main/java/org/opensearch/index/mapper/Metric.java new file mode 100644 index 0000000000000..5219f8e20b5fb --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/Metric.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.index.mapper; + +import org.apache.lucene.index.VectorSimilarityFunction; + +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +/** + * Abstracts supported metrics for dense_vector and KNN search + */ +public enum Metric { + L2, + COSINE, + DOT_PRODUCT; + + private static final Map STRING_TO_METRIC = Map.of("l2", L2, "cosine", COSINE, "dot_product", DOT_PRODUCT); + + private static final Map METRIC_TO_VECTOR_SIMILARITY_FUNCTION = Map.of( + L2, + VectorSimilarityFunction.EUCLIDEAN, + COSINE, + VectorSimilarityFunction.COSINE, + DOT_PRODUCT, + VectorSimilarityFunction.DOT_PRODUCT + ); + + public static Metric fromName(String metricName) { + return Optional.ofNullable(STRING_TO_METRIC.get(metricName.toLowerCase(Locale.ROOT))) + .orElseThrow( + () -> new IllegalArgumentException(String.format(Locale.ROOT, "Provided [metric] %s is not supported", metricName)) + ); + } + + /** + * Convert from our Metric type to Lucene VectorSimilarityFunction type. Only Euclidean metric is supported + */ + public static VectorSimilarityFunction toSimilarityFunction(Metric metric) { + return Optional.ofNullable(METRIC_TO_VECTOR_SIMILARITY_FUNCTION.get(metric)) + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Provided metric %s cannot be converted to vector similarity function", metric.name()) + ) + ); + } +} diff --git a/server/src/main/java/org/opensearch/indices/IndicesModule.java b/server/src/main/java/org/opensearch/indices/IndicesModule.java index 29ff507ad9fcf..88f4c3359bd8a 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesModule.java +++ b/server/src/main/java/org/opensearch/indices/IndicesModule.java @@ -48,6 +48,7 @@ import org.opensearch.index.mapper.CompletionFieldMapper; import org.opensearch.index.mapper.DataStreamFieldMapper; import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.DenseVectorFieldMapper; import org.opensearch.index.mapper.FieldAliasMapper; import org.opensearch.index.mapper.FieldNamesFieldMapper; import org.opensearch.index.mapper.GeoPointFieldMapper; @@ -161,6 +162,7 @@ public static Map getMappers(List mappe mappers.put(CompletionFieldMapper.CONTENT_TYPE, CompletionFieldMapper.PARSER); mappers.put(FieldAliasMapper.CONTENT_TYPE, new FieldAliasMapper.TypeParser()); mappers.put(GeoPointFieldMapper.CONTENT_TYPE, new GeoPointFieldMapper.TypeParser()); + mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, new DenseVectorFieldMapper.TypeParser()); for (MapperPlugin mapperPlugin : mapperPlugins) { for (Map.Entry entry : mapperPlugin.getMappers().entrySet()) { diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java new file mode 100644 index 0000000000000..6697691936500 --- /dev/null +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java @@ -0,0 +1,459 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.opensearch.index.mapper; + +import org.junit.Before; +import org.mockito.Mockito; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.Strings; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.compress.CompressedXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.IndexService; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.index.mapper.FieldTypeTestCase.MOCK_QSC_DISALLOW_EXPENSIVE; +import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +public class DenseVectorFieldTypeTests extends OpenSearchSingleNodeTestCase { + private static final String ALGORITHM_HNSW = "HNSW"; + private static final String DENSE_VECTOR_TYPE_NAME = "dense_vector"; + private static final int DIMENSION = 2; + private static final String FIELD_NAME = "field"; + private static final String METRIC_L2 = "L2"; + private static final float[] VECTOR = { 2.0f, 4.5f }; + + private IndexService indexService; + private DocumentMapperParser parser; + private MappedFieldType fieldType; + + @Before + public void setup() throws Exception { + indexService = createIndex("test"); + parser = indexService.mapperService().documentMapperParser(); + + KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( + HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 10, HNSW_PARAMETER_BEAM_WIDTH, 100) + ); + KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); + fieldType = new DenseVectorFieldMapper.DenseVectorFieldType(FIELD_NAME, 1, knnContext); + } + + public void testIndexingWithoutEnablingKnn() throws IOException { + XContentBuilder mappingAllDefaults = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .endObject() + .endObject() + .endObject() + .endObject(); + parser.parse("type", new CompressedXContent(Strings.toString(mappingAllDefaults))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); + } + + public void testIndexingWithDefaultParams() throws IOException { + XContentBuilder mappingAllDefaults = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + 
.startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + parser.parse("type", new CompressedXContent(Strings.toString(mappingAllDefaults))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); + } + + public void testIndexingWithAlgorithmParameters() throws IOException { + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field( + "knn", + Map.of( + "metric", + METRIC_L2, + "algorithm", + Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 256, "max_connections", 16)) + ) + ) + .endObject() + .endObject() + .endObject() + .endObject(); + parser.parse("type", new CompressedXContent(Strings.toString(mapping))); + } + + public void testCosineMetric() throws IOException { + XContentBuilder mappingCosineMetric = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of("metric", "cosine", "algorithm", Map.of("name", ALGORITHM_HNSW))) + .endObject() + .endObject() + .endObject() + .endObject(); + parser.parse("type", new CompressedXContent(Strings.toString(mappingCosineMetric))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); + } + + public void testDotProductMetric() throws IOException { + XContentBuilder mappingDotProductMetric = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of("metric", "dot_product", "algorithm", Map.of("name", ALGORITHM_HNSW))) + .endObject() + .endObject() + .endObject() + .endObject(); + parser.parse("type", new CompressedXContent(Strings.toString(mappingDotProductMetric))) + .parse(source(b -> b.field(FIELD_NAME, VECTOR))); + } + + public void testHNSWAlgorithmParametersInvalidInput() throws Exception { + XContentBuilder mappingInvalidMaxConnections = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field( + "knn", + Map.of( + "metric", + METRIC_L2, + "algorithm", + Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 256, "max_connections", 50)) + ) + ) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionInvalidMaxConnections = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidMaxConnections))) + ); + org.hamcrest.MatcherAssert.assertThat( + mapperExceptionInvalidMaxConnections.getMessage(), + containsString("max_connections value cannot be greater than") + ); + + XContentBuilder mappingInvalidBeamWidth = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field( + "knn", + Map.of( + "metric", + METRIC_L2, + "algorithm", + Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 1024, "max_connections", 6)) + ) + ) + .endObject() + .endObject() + .endObject() + 
.endObject(); + + final MapperParsingException mapperExceptionInvalidmBeamWidth = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidBeamWidth))) + ); + org.hamcrest.MatcherAssert.assertThat( + mapperExceptionInvalidmBeamWidth.getMessage(), + containsString("beam_width value cannot be greater than") + ); + + XContentBuilder mappingUnsupportedParam = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field( + "knn", + Map.of( + "metric", + METRIC_L2, + "algorithm", + Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 256, "max_connections", 6, "some_param", 23)) + ) + ) + .endObject() + .endObject() + .endObject() + .endObject(); + + final IllegalArgumentException mapperExceptionUnsupportedParam = expectThrows( + IllegalArgumentException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingUnsupportedParam))) + ); + assertEquals(mapperExceptionUnsupportedParam.getMessage(), "Algorithm parameter [some_param] is not supported"); + } + + public void testInvalidVectorDimension() throws Exception { + XContentBuilder mappingMissingDimension = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionMissingDimension = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingMissingDimension))) + ); + org.hamcrest.MatcherAssert.assertThat( + mapperExceptionMissingDimension.getMessage(), + containsString("[dimension] property must be specified for field") + ); + + XContentBuilder mappingInvalidDimension = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", 1200) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + + final IllegalArgumentException exceptionInvalidDimension = expectThrows( + IllegalArgumentException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidDimension))) + ); + org.hamcrest.MatcherAssert.assertThat( + exceptionInvalidDimension.getMessage(), + containsString("[dimension] value cannot be greater than 1024 for vector") + ); + + XContentBuilder mappingDimentionsMismatch = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionIDimentionsMismatch = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingDimentionsMismatch))) + .parse(source(b -> b.field(FIELD_NAME, new float[] { 2.0f, 4.5f, 5.6f }))) + ); + org.hamcrest.MatcherAssert.assertThat( + mapperExceptionIDimentionsMismatch.getMessage(), + containsString("failed to parse field [field] of type [dense_vector]") + ); + } + + public void testInvalidMetric() throws Exception { + XContentBuilder 
mappingInvalidMetric = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of("metric", "LAMBDA", "algorithm", Map.of("name", ALGORITHM_HNSW))) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionInvalidMetric = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidMetric))) + ); + org.hamcrest.MatcherAssert.assertThat( + mapperExceptionInvalidMetric.getMessage(), + containsString("[metric] value [LAMBDA] is invalid") + ); + } + + public void testInvalidAlgorithm() throws Exception { + XContentBuilder mappingInvalidAlgorithm = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of("metric", METRIC_L2, "algorithm", Map.of("name", "MY_ALGORITHM"))) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionInvalidAlgorithm = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidAlgorithm))) + ); + assertEquals(mapperExceptionInvalidAlgorithm.getMessage(), "[algorithm name] value [MY_ALGORITHM] is invalid or not supported"); + } + + public void testInvalidParams() throws Exception { + XContentBuilder mappingInvalidMaxConnections = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("my_field", "some_value") + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionInvalidMaxConnections = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidMaxConnections))) + ); + assertEquals( + mapperExceptionInvalidMaxConnections.getMessage(), + "unknown parameter [my_field] on mapper [field] of type [dense_vector]" + ); + } + + public void testValueDisplay() { + Object actualFloatArray = fieldType.valueForDisplay(VECTOR); + assertTrue(actualFloatArray instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArray, 0.0f); + + KnnContext knnContextDEfaultAlgorithmContext = new KnnContext( + Metric.L2, + KnnAlgorithmContextFactory.defaultContext(KnnAlgorithmContext.Method.HNSW) + ); + MappedFieldType ftDefaultAlgorithmContext = new DenseVectorFieldMapper.DenseVectorFieldType( + FIELD_NAME, + 1, + knnContextDEfaultAlgorithmContext + ); + Object actualFloatArrayDefaultAlgorithmContext = ftDefaultAlgorithmContext.valueForDisplay(VECTOR); + assertTrue(actualFloatArrayDefaultAlgorithmContext instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArrayDefaultAlgorithmContext, 0.0f); + } + + public void testTermQueryNotSupported() { + QueryShardContext context = Mockito.mock(QueryShardContext.class); + QueryShardException exception = expectThrows(QueryShardException.class, () -> fieldType.termsQuery(Arrays.asList(VECTOR), context)); + assertEquals(exception.getMessage(), "Dense_vector does not support exact searching, use KNN queries instead [field]"); + } + + public void testPrefixQueryNotSupported() 
{ + QueryShardException ee = expectThrows( + QueryShardException.class, + () -> fieldType.prefixQuery("foo*", null, MOCK_QSC_DISALLOW_EXPENSIVE) + ); + assertEquals( + "Can only use prefix queries on keyword, text and wildcard fields - not on [field] which is of type [dense_vector]", + ee.getMessage() + ); + } + + public void testRegexpQueryNotSupported() { + QueryShardException ee = expectThrows( + QueryShardException.class, + () -> fieldType.regexpQuery("foo?", randomInt(10), 0, randomInt(10) + 1, null, MOCK_QSC_DISALLOW_EXPENSIVE) + ); + assertEquals( + "Can only use regexp queries on keyword and text fields - not on [field] which is of type [dense_vector]", + ee.getMessage() + ); + } + + public void testWildcardQueryNotSupported() { + QueryShardException ee = expectThrows( + QueryShardException.class, + () -> fieldType.wildcardQuery("valu*", null, MOCK_QSC_DISALLOW_EXPENSIVE) + ); + assertEquals( + "Can only use wildcard queries on keyword, text and wildcard fields - not on [field] which is of type [dense_vector]", + ee.getMessage() + ); + } + + private final SourceToParse source(CheckedConsumer build) throws IOException { + XContentBuilder builder = JsonXContent.contentBuilder().startObject(); + build.accept(builder); + builder.endObject(); + return new SourceToParse("test", "1", BytesReference.bytes(builder), XContentType.JSON); + } +} diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java new file mode 100644 index 0000000000000..d219cd9871c21 --- /dev/null +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java @@ -0,0 +1,161 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.opensearch.index.mapper; + +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +public class DenseVectorMapperTests extends MapperServiceTestCase { + + private static final float[] VECTOR = { 2.0f, 4.5f }; + + public void testValueDisplay() { + KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( + HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 16, HNSW_PARAMETER_BEAM_WIDTH, 100) + ); + KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); + MappedFieldType ft = new DenseVectorFieldType("field", 1, knnContext); + Object actualFloatArray = ft.valueForDisplay(VECTOR); + assertTrue(actualFloatArray instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArray, 0.0f); + } + + public void testSerializationWithoutKnn() throws IOException { + DocumentMapper defaultMapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + Mapper mapper = defaultMapper.mappers().getMapper("field"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + mapper.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + assertEquals("{\"field\":{\"type\":\"dense_vector\",\"dimension\":2}}", Strings.toString(builder)); + } + + public void testSerializationWithKnn() throws IOException { + DocumentMapper defaultMapper = createDocumentMapper(fieldMapping(b -> { + minimalMapping(b); + b.field("knn", Map.of()); + })); + Mapper mapper = defaultMapper.mappers().getMapper("field"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + mapper.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + assertTrue( + Set.of( + "{\"field\":{\"type\":\"dense_vector\"," + + "\"dimension\":2," + + "\"knn\":" + + "{\"metric\":\"L2\"," + + "\"algorithm\":{" + + "\"name\":\"HNSW\"," + + "\"parameters\":{\"beam_width\":100,\"max_connections\":16}}}}}", + "{\"field\":{\"type\":\"dense_vector\"," + + "\"dimension\":2," + + "\"knn\":" + + "{\"metric\":\"L2\"," + + "\"algorithm\":{" + + "\"name\":\"HNSW\"," + + "\"parameters\":{\"max_connections\":16,\"beam_width\":100}}}}}" + ).contains(Strings.toString(builder)) + ); + } + + public void testMinimalToMaximal() throws IOException { + XContentBuilder orig = JsonXContent.contentBuilder().startObject(); + createMapperService(fieldMapping(this::minimalMapping)).documentMapper().mapping().toXContent(orig, INCLUDE_DEFAULTS); + orig.endObject(); + XContentBuilder parsedFromOrig = JsonXContent.contentBuilder().startObject(); + createMapperService(orig).documentMapper().mapping().toXContent(parsedFromOrig, INCLUDE_DEFAULTS); + parsedFromOrig.endObject(); + assertEquals(Strings.toString(orig), Strings.toString(parsedFromOrig)); + } + + public void testDeprecatedBoost() throws IOException { + createMapperService(fieldMapping(b -> { + minimalMapping(b); + b.field("boost", 2.0); + })); + 
String type = typeName(); + String[] warnings = new String[] { + "Parameter [boost] on field [field] is deprecated and will be removed in 8.0", + "Parameter [boost] has no effect on type [" + type + "] and will be removed in future" }; + allowedWarnings(warnings); + } + + public void testIfMinimalSerializesToItself() throws IOException { + XContentBuilder orig = JsonXContent.contentBuilder().startObject(); + createMapperService(fieldMapping(this::minimalMapping)).documentMapper().mapping().toXContent(orig, ToXContent.EMPTY_PARAMS); + orig.endObject(); + XContentBuilder parsedFromOrig = JsonXContent.contentBuilder().startObject(); + createMapperService(orig).documentMapper().mapping().toXContent(parsedFromOrig, ToXContent.EMPTY_PARAMS); + parsedFromOrig.endObject(); + assertEquals(Strings.toString(orig), Strings.toString(parsedFromOrig)); + } + + public void testForEmptyName() { + MapperParsingException e = expectThrows(MapperParsingException.class, () -> createMapperService(mapping(b -> { + b.startObject(""); + minimalMapping(b); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("name cannot be empty string")); + } + + protected void writeFieldValue(XContentBuilder b) throws IOException { + b.value(new float[] { 2.5f }); + } + + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "dense_vector"); + b.field("dimension", 2); + // b.field("knn", Map.of()); + } + + protected void registerParameters(MapperTestCase.ParameterChecker checker) throws IOException { + checker.registerConflictCheck("doc_values", b -> b.field("doc_values", false)); + checker.registerConflictCheck("index", b -> b.field("index", false)); + checker.registerConflictCheck("store", b -> b.field("store", false)); + } + + protected String typeName() throws IOException { + MapperService ms = createMapperService(fieldMapping(this::minimalMapping)); + return ms.fieldType("field").typeName(); + } +} diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 832328cb0242f..b172be9b6be64 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -91,6 +91,7 @@ import org.opensearch.index.mapper.CompletionFieldMapper; import org.opensearch.index.mapper.ContentPath; import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.DenseVectorFieldMapper; import org.opensearch.index.mapper.FieldAliasMapper; import org.opensearch.index.mapper.FieldMapper; import org.opensearch.index.mapper.GeoPointFieldMapper; @@ -184,6 +185,7 @@ public abstract class AggregatorTestCase extends OpenSearchTestCase { denylist.add(ObjectMapper.NESTED_CONTENT_TYPE); // TODO support for nested denylist.add(CompletionFieldMapper.CONTENT_TYPE); // TODO support completion denylist.add(FieldAliasMapper.CONTENT_TYPE); // TODO support alias + denylist.add(DenseVectorFieldMapper.CONTENT_TYPE); // Cannot aggregate dense_vector TYPE_TEST_DENYLIST = denylist; } From d317584a0e63b65d84c7e4eda9108ac730ba2146 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 24 Jun 2022 14:57:45 -0700 Subject: [PATCH 2/7] Adding checks on size limits for field type and vector data Signed-off-by: Martin Gaievski --- .../index/mapper/KnnAlgorithmContext.java | 13 ++- .../mapper/DenseVectorFieldTypeTests.java | 94 +++++++++++++++++++ 2 
files changed, 106 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java index ae47da1e7e3b6..3cf44eb8d63cc 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java @@ -55,6 +55,8 @@ public class KnnAlgorithmContext implements ToXContentFragment, Writeable { private final Method method; private final Map parameters; + private static final int MAX_NUMBER_OF_ALGORITHM_PARAMETERS = 50; + public KnnAlgorithmContext(Method method, Map parameters) { this.method = method; this.parameters = parameters; @@ -109,7 +111,16 @@ public static KnnAlgorithmContext parse(Object in) { } return v; })); - + if (parameters.size() > MAX_NUMBER_OF_ALGORITHM_PARAMETERS) { + throw new MapperParsingException( + String.format( + Locale.ROOT, + "Invalid number of parameters for [algorithm], max allowed is [%d] but given [%d]", + MAX_NUMBER_OF_ALGORITHM_PARAMETERS, + parameters.size() + ) + ); + } } else { throw new MapperParsingException(String.format(Locale.ROOT, "Invalid parameter %s for [algorithm]", key)); } diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java index 6697691936500..c34415b4b07b2 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java @@ -44,7 +44,9 @@ import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.containsString; import static org.opensearch.index.mapper.FieldTypeTestCase.MOCK_QSC_DISALLOW_EXPENSIVE; @@ -392,6 +394,98 @@ public void testInvalidParams() throws Exception { ); } + public void testExceedMaxNumberOfAlgorithmParams() throws Exception { + Map algorithmParams = new HashMap<>(); + IntStream.range(0, 100).forEach(number -> algorithmParams.put("param" + number, randomInt(Integer.MAX_VALUE))); + XContentBuilder mappingInvalidAlgorithm = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of("metric", METRIC_L2, "algorithm", Map.of("name", ALGORITHM_HNSW, "parameters", algorithmParams))) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionInvalidAlgorithm = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidAlgorithm))) + ); + assertEquals( + mapperExceptionInvalidAlgorithm.getMessage(), + "Invalid number of parameters for [algorithm], max allowed is [50] but given [100]" + ); + } + + public void testInvalidVectorNumberFormat() throws Exception { + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", 1) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + + final MapperParsingException mapperExceptionStringAsVectorValue = expectThrows( + MapperParsingException.class, + () -> 
parser.parse("type", new CompressedXContent(Strings.toString(mapping))) + .parse(source(b -> b.field(FIELD_NAME, "some malicious script content"))) + ); + assertEquals( + mapperExceptionStringAsVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'some malicious script content'" + ); + + final MapperParsingException mapperExceptionInfinityVectorValue = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mapping))) + .parse(source(b -> b.field(FIELD_NAME, new Float[] { Float.POSITIVE_INFINITY }))) + ); + assertEquals( + mapperExceptionInfinityVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'Infinity'" + ); + + final MapperParsingException mapperExceptionNullVectorValue = expectThrows( + MapperParsingException.class, + () -> parser.parse("type", new CompressedXContent(Strings.toString(mapping))) + .parse(source(b -> b.field(FIELD_NAME, new Float[] { null }))) + ); + assertEquals( + mapperExceptionNullVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'null'" + ); + } + + public void testNullVectorValue() throws Exception { + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", DENSE_VECTOR_TYPE_NAME) + .field("dimension", DIMENSION) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject() + .endObject(); + + parser.parse("type", new CompressedXContent(Strings.toString(mapping))).parse(source(b -> b.field(FIELD_NAME, (Float) null))); + + parser.parse("type", new CompressedXContent(Strings.toString(mapping))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); + + parser.parse("type", new CompressedXContent(Strings.toString(mapping))).parse(source(b -> b.field(FIELD_NAME, (Float) null))); + } + public void testValueDisplay() { Object actualFloatArray = fieldType.valueForDisplay(VECTOR); assertTrue(actualFloatArray instanceof float[]); From 65c094463c411c274a38da44c75b6e60a28c4d08 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 28 Jun 2022 12:15:03 -0700 Subject: [PATCH 3/7] Updating license header, exclude ES grant part Signed-off-by: Martin Gaievski --- .../opensearch/search/knn/DenseVectorIT.java | 31 ++----------------- .../index/codec/KnnVectorFormatFactory.java | 31 ++----------------- .../index/mapper/DenseVectorFieldMapper.java | 31 ++----------------- .../index/mapper/KnnAlgorithmContext.java | 31 ++----------------- .../mapper/KnnAlgorithmContextFactory.java | 31 ++----------------- .../opensearch/index/mapper/KnnContext.java | 31 ++----------------- .../org/opensearch/index/mapper/Metric.java | 31 ++----------------- .../mapper/DenseVectorFieldTypeTests.java | 26 ++-------------- .../index/mapper/DenseVectorMapperTests.java | 26 ++-------------- 9 files changed, 18 insertions(+), 251 deletions(-) diff --git a/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java b/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java index deefc3dcc309b..3cedcb4fd0c34 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * 
- * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.search.knn; diff --git a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java index 5eb1e2b9b0765..70e50fe4e3067 100644 --- a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java +++ b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.codec; diff --git a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java index 8d1ab0ceb69e7..dd2018909c323 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java index 3cf44eb8d63cc..7f4246ce5ca87 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java index 81ecb77d0cecc..b2907ea795674 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java index fa94972476a51..2355a2cbaae82 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; diff --git a/server/src/main/java/org/opensearch/index/mapper/Metric.java b/server/src/main/java/org/opensearch/index/mapper/Metric.java index 5219f8e20b5fb..2fd31432974f0 100644 --- a/server/src/main/java/org/opensearch/index/mapper/Metric.java +++ b/server/src/main/java/org/opensearch/index/mapper/Metric.java @@ -1,33 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java index c34415b4b07b2..b0b99a4054fbd 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java @@ -1,28 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java index d219cd9871c21..db663029f6139 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java @@ -1,28 +1,6 @@ /* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
+ * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.index.mapper; From 0325d0b7be4038079b370904294403449574b4a0 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 1 Jul 2022 10:09:41 -0700 Subject: [PATCH 4/7] Review comments: inherit hasValues from base class, error messages, javadocs Signed-off-by: Martin Gaievski --- .../index/mapper/DenseVectorFieldMapper.java | 12 ++++++------ .../java/org/opensearch/index/mapper/KnnContext.java | 4 +++- .../java/org/opensearch/index/mapper/Metric.java | 4 +++- .../index/mapper/DenseVectorFieldTypeTests.java | 5 +---- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java index dd2018909c323..8c9f9ec4b50e9 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java @@ -29,8 +29,10 @@ /** * Field Mapper for Dense vector type. Extends ParametrizedFieldMapper in order to easily configure mapping parameters. + * + * @opensearch.internal */ -public class DenseVectorFieldMapper extends ParametrizedFieldMapper { +public final class DenseVectorFieldMapper extends ParametrizedFieldMapper { public static final String CONTENT_TYPE = "dense_vector"; @@ -55,12 +57,12 @@ public static class Builder extends ParametrizedFieldMapper.Builder { int value = XContentMapValues.nodeIntegerValue(o); if (value > MAX_DIMENSION) { throw new IllegalArgumentException( - String.format(Locale.ROOT, "[dimension] value cannot be greater than %d for vector [%s]", MAX_DIMENSION, name) + String.format(Locale.ROOT, "[dimension] value %d cannot be greater than %d for vector [%s]", value, MAX_DIMENSION, name) ); } if (value <= 0) { throw new IllegalArgumentException( - String.format(Locale.ROOT, "[dimension] value must be greater than 0 for vector [%s]", name) + String.format(Locale.ROOT, "[dimension] value %d must be greater than 0 for vector [%s]", value, name) ); } return value; @@ -130,17 +132,15 @@ public static class DenseVectorFieldType extends MappedFieldType { private final int dimension; private final KnnContext knnContext; - private final boolean hasDocValues; public DenseVectorFieldType(String name, int dimension, KnnContext knnContext) { this(name, Collections.emptyMap(), dimension, knnContext, false); } public DenseVectorFieldType(String name, Map meta, int dimension, KnnContext knnContext, boolean hasDocValues) { - super(name, false, false, false, TextSearchInfo.NONE, meta); + super(name, false, false, hasDocValues, TextSearchInfo.NONE, meta); this.dimension = dimension; this.knnContext = knnContext; - this.hasDocValues = hasDocValues; } @Override diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java index 2355a2cbaae82..106c7fcf2101e 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java @@ -17,8 +17,10 @@ /** * Abstracts KNN segment of dense_vector field type + * + * @opensearch.internal */ -public class KnnContext implements ToXContentFragment, Writeable { +public final class KnnContext implements ToXContentFragment, Writeable { private final Metric metric; private final KnnAlgorithmContext knnAlgorithmContext; diff --git 
a/server/src/main/java/org/opensearch/index/mapper/Metric.java b/server/src/main/java/org/opensearch/index/mapper/Metric.java index 2fd31432974f0..bc41128ef7479 100644 --- a/server/src/main/java/org/opensearch/index/mapper/Metric.java +++ b/server/src/main/java/org/opensearch/index/mapper/Metric.java @@ -13,8 +13,10 @@ /** * Abstracts supported metrics for dense_vector and KNN search + * + * @opensearch.internal */ -public enum Metric { +enum Metric { L2, COSINE, DOT_PRODUCT; diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java index b0b99a4054fbd..31bfce1beaa5c 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java @@ -273,10 +273,7 @@ public void testInvalidVectorDimension() throws Exception { IllegalArgumentException.class, () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidDimension))) ); - org.hamcrest.MatcherAssert.assertThat( - exceptionInvalidDimension.getMessage(), - containsString("[dimension] value cannot be greater than 1024 for vector") - ); + assertEquals(exceptionInvalidDimension.getMessage(), "[dimension] value 1200 cannot be greater than 1024 for vector [field]"); XContentBuilder mappingDimentionsMismatch = XContentFactory.jsonBuilder() .startObject() From ecb18fd947f7837338c24dc3afd787a6aba5b6f6 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 1 Jul 2022 21:22:43 -0700 Subject: [PATCH 5/7] Refactor mapper to extend FieldMapper, adjust tests accordingly Signed-off-by: Martin Gaievski --- .../index/mapper/DenseVectorFieldMapper.java | 224 +++++--- .../mapper/DenseVectorFieldMapperTests.java | 380 ++++++++++++++ .../mapper/DenseVectorFieldTypeTests.java | 489 +----------------- .../index/mapper/DenseVectorMapperTests.java | 139 ----- 4 files changed, 569 insertions(+), 663 deletions(-) create mode 100644 server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java delete mode 100644 server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java diff --git a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java index 8c9f9ec4b50e9..7d41382f0191e 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java @@ -9,30 +9,32 @@ import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.MultiTermQuery; import org.apache.lucene.search.Query; import org.opensearch.common.Explicit; -import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.Nullable; +import org.opensearch.common.unit.Fuzziness; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.QueryShardException; import org.opensearch.search.lookup.SearchLookup; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.Iterator; import java.util.List; import 
java.util.Locale; import java.util.Map; +import java.util.Objects; /** * Field Mapper for Dense vector type. Extends ParametrizedFieldMapper in order to easily configure mapping parameters. * * @opensearch.internal */ -public final class DenseVectorFieldMapper extends ParametrizedFieldMapper { +public final class DenseVectorFieldMapper extends FieldMapper { public static final String CONTENT_TYPE = "dense_vector"; @@ -49,12 +51,29 @@ private static DenseVectorFieldMapper toType(FieldMapper in) { * Builder for DenseVectorFieldMapper. This class defines the set of parameters that can be applied to the knn_vector * field type */ - public static class Builder extends ParametrizedFieldMapper.Builder { + public static class Builder extends FieldMapper.Builder { + private CopyTo copyTo = CopyTo.empty(); + private Integer dimension = 1; + private KnnContext knnContext = null; - private final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, false); + public Builder(String name) { + super(name, Defaults.FIELD_TYPE); + builder = this; + } + + @Override + public DenseVectorFieldMapper build(BuilderContext context) { + final DenseVectorFieldType mappedFieldType = new DenseVectorFieldType(buildFullName(context), dimension, knnContext); + return new DenseVectorFieldMapper( + buildFullName(context), + fieldType, + mappedFieldType, + multiFieldsBuilder.build(this, context), + copyTo + ); + } - protected final Parameter dimension = new Parameter<>(Names.DIMENSION.getValue(), false, () -> 1, (n, c, o) -> { - int value = XContentMapValues.nodeIntegerValue(o); + public Builder dimension(int value) { if (value > MAX_DIMENSION) { throw new IllegalArgumentException( String.format(Locale.ROOT, "[dimension] value %d cannot be greater than %d for vector [%s]", value, MAX_DIMENSION, name) @@ -65,41 +84,13 @@ public static class Builder extends ParametrizedFieldMapper.Builder { String.format(Locale.ROOT, "[dimension] value %d must be greater than 0 for vector [%s]", value, name) ); } - return value; - }, m -> toType(m).dimension).setSerializer((b, n, v) -> b.field(n, v.intValue()), v -> Integer.toString(v.intValue())); - - private final Parameter knnContext = new Parameter<>( - Names.KNN.getValue(), - false, - () -> null, - (n, c, o) -> KnnContext.parse(o), - m -> toType(m).knnContext - ).setSerializer(((b, n, v) -> { - if (v == null) { - return; - } - b.startObject(n); - v.toXContent(b, ToXContent.EMPTY_PARAMS); - b.endObject(); - }), m -> m.getKnnAlgorithmContext().getMethod().name()); - - public Builder(String name) { - super(name); + this.dimension = value; + return this; } - @Override - protected List> getParameters() { - return List.of(dimension, knnContext, hasDocValues); - } - - @Override - public DenseVectorFieldMapper build(BuilderContext context) { - return new DenseVectorFieldMapper( - buildFullName(context), - new DenseVectorFieldType(buildFullName(context), dimension.get(), knnContext.get()), - multiFieldsBuilder.build(this, context), - copyTo.build() - ); + public Builder knn(KnnContext value) { + this.knnContext = value; + return this; } } @@ -113,12 +104,30 @@ public static class TypeParser implements Mapper.TypeParser { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { Builder builder = new DenseVectorFieldMapper.Builder(name); - Object dimensionField = node.get(Names.DIMENSION.getValue()); - String dimension = XContentMapValues.nodeStringValue(dimensionField, null); - if (dimension == null) { - 
throw new MapperParsingException(String.format(Locale.ROOT, "[dimension] property must be specified for field [%s]", name)); + TypeParsers.parseField(builder, name, node, parserContext); + + for (Iterator> iterator = node.entrySet().iterator(); iterator.hasNext();) { + Map.Entry entry = iterator.next(); + String fieldName = entry.getKey(); + Object fieldNode = entry.getValue(); + switch (fieldName) { + case "dimension": + if (fieldNode == null) { + throw new MapperParsingException( + String.format(Locale.ROOT, "[dimension] property must be specified for field [%s]", name) + ); + } + builder.dimension(XContentMapValues.nodeIntegerValue(fieldNode, 1)); + iterator.remove(); + break; + case "knn": + builder.knn(KnnContext.parse(fieldNode)); + iterator.remove(); + break; + default: + break; + } } - builder.parse(name, parserContext, node); return builder; } } @@ -145,7 +154,7 @@ public DenseVectorFieldType(String name, Map meta, int dimension @Override public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { - throw new UnsupportedOperationException("Dense_vector does not support fields search"); + throw new UnsupportedOperationException("[fields search] are not supported on [" + CONTENT_TYPE + "] fields."); } @Override @@ -154,16 +163,47 @@ public String typeName() { } @Override - public Query existsQuery(QueryShardContext context) { - return new FieldExistsQuery(name()); + public Query termQuery(Object value, QueryShardContext context) { + throw new UnsupportedOperationException("[term] queries are not supported on [" + CONTENT_TYPE + "] fields."); } @Override - public Query termQuery(Object value, QueryShardContext context) { - throw new QueryShardException( - context, - "Dense_vector does not support exact searching, use KNN queries instead [" + name() + "]" - ); + public Query fuzzyQuery( + Object value, + Fuzziness fuzziness, + int prefixLength, + int maxExpansions, + boolean transpositions, + QueryShardContext context + ) { + throw new UnsupportedOperationException("[fuzzy] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query prefixQuery(String value, MultiTermQuery.RewriteMethod method, boolean caseInsensitive, QueryShardContext context) { + throw new UnsupportedOperationException("[prefix] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query wildcardQuery( + String value, + @Nullable MultiTermQuery.RewriteMethod method, + boolean caseInsensitive, + QueryShardContext context + ) { + throw new UnsupportedOperationException("[wildcard] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query regexpQuery( + String value, + int syntaxFlags, + int matchFlags, + int maxDeterminizedStates, + MultiTermQuery.RewriteMethod method, + QueryShardContext context + ) { + throw new UnsupportedOperationException("[regexp] queries are not supported on [" + CONTENT_TYPE + "] fields."); } public int getDimension() { @@ -182,8 +222,14 @@ public KnnContext getKnnContext() { protected boolean hasDocValues; protected String modelId; - public DenseVectorFieldMapper(String simpleName, DenseVectorFieldType mappedFieldType, MultiFields multiFields, CopyTo copyTo) { - super(simpleName, mappedFieldType, multiFields, copyTo); + public DenseVectorFieldMapper( + String simpleName, + FieldType fieldType, + DenseVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo + ) { + super(simpleName, fieldType, mappedFieldType, multiFields, 
copyTo); dimension = mappedFieldType.getDimension(); fieldType = new FieldType(DenseVectorFieldMapper.Defaults.FIELD_TYPE); isKnnEnabled = mappedFieldType.getKnnContext() != null; @@ -207,6 +253,57 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, fieldType().getDimension()); } + @Override + protected void mergeOptions(FieldMapper other, List conflicts) { + DenseVectorFieldMapper denseVectorMergeWith = (DenseVectorFieldMapper) other; + if (!Objects.equals(dimension, denseVectorMergeWith.dimension)) { + conflicts.add("mapper [" + name() + "] has different [dimension]"); + } + + if (isOnlyOneObjectNull(knnContext, denseVectorMergeWith.knnContext) + || (isBothObjectsNotNull(knnContext, denseVectorMergeWith.knnContext) + && !Objects.equals(knnContext.getMetric(), denseVectorMergeWith.knnContext.getMetric()))) { + conflicts.add("mapper [" + name() + "] has different [metric]"); + } + + if (isBothObjectsNotNull(knnContext, denseVectorMergeWith.knnContext)) { + + if (!Objects.equals(knnContext.getMetric(), denseVectorMergeWith.knnContext.getMetric())) { + conflicts.add("mapper [" + name() + "] has different [metric]"); + } + + if (isBothObjectsNotNull(knnContext.getKnnAlgorithmContext(), denseVectorMergeWith.knnContext.getKnnAlgorithmContext())) { + KnnAlgorithmContext knnAlgorithmContext = knnContext.getKnnAlgorithmContext(); + KnnAlgorithmContext mergeWithKnnAlgorithmContext = denseVectorMergeWith.knnContext.getKnnAlgorithmContext(); + + if (isOnlyOneObjectNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext) + || (isBothObjectsNotNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext) + && !Objects.equals(knnAlgorithmContext.getMethod(), mergeWithKnnAlgorithmContext.getMethod()))) { + conflicts.add("mapper [" + name() + "] has different [method]"); + } + + if (isBothObjectsNotNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext)) { + Map knnAlgoParams = knnAlgorithmContext.getParameters(); + Map mergeWithKnnAlgoParams = mergeWithKnnAlgorithmContext.getParameters(); + + if (isOnlyOneObjectNull(knnAlgoParams, mergeWithKnnAlgoParams) + || (isBothObjectsNotNull(knnAlgoParams, mergeWithKnnAlgoParams) + && !Objects.equals(knnAlgoParams, mergeWithKnnAlgoParams))) { + conflicts.add("mapper [" + name() + "] has different [knn algorithm parameters]"); + } + } + } + } + } + + private boolean isOnlyOneObjectNull(Object object1, Object object2) { + return object1 == null && object2 != null || object2 == null && object1 != null; + } + + private boolean isBothObjectsNotNull(Object object1, Object object2) { + return object1 != null && object2 != null; + } + protected void parseCreateField(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); @@ -276,12 +373,7 @@ protected boolean docValuesByDefault() { } @Override - public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new DenseVectorFieldMapper.Builder(simpleName()).init(this); - } - - @Override - public final boolean parsesArrayValue() { + public boolean parsesArrayValue() { return true; } @@ -293,6 +385,13 @@ public DenseVectorFieldType fieldType() { @Override protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException { super.doXContentBody(builder, includeDefaults, params); + + builder.field("dimension", dimension); + if (knnContext != null) { + builder.startObject("knn"); + knnContext.toXContent(builder, params); + builder.endObject(); + } } /** @@ -325,6 +424,7 @@ static class Defaults { 
static { FIELD_TYPE.setTokenized(false); + FIELD_TYPE.setOmitNorms(true); FIELD_TYPE.setIndexOptions(IndexOptions.NONE); FIELD_TYPE.setDocValuesType(DocValuesType.NONE); FIELD_TYPE.freeze(); diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java new file mode 100644 index 0000000000000..30061af3e71eb --- /dev/null +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java @@ -0,0 +1,380 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.IndexableField; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +public class DenseVectorFieldMapperTests extends FieldMapperTestCase2 { + + private static final float[] VECTOR = { 2.0f, 4.5f }; + + public void testValueDisplay() { + KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( + HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 16, HNSW_PARAMETER_BEAM_WIDTH, 100) + ); + KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); + MappedFieldType ft = new DenseVectorFieldType("field", 1, knnContext); + Object actualFloatArray = ft.valueForDisplay(VECTOR); + assertTrue(actualFloatArray instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArray, 0.0f); + } + + public void testSerializationWithoutKnn() throws IOException { + DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dimension", 2))); + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + + ParsedDocument doc = mapper.parse(source(b -> b.field("field", VECTOR))); + IndexableField[] fields = doc.rootDoc().getFields("field"); + assertEquals(1, fields.length); + assertTrue(fields[0] instanceof KnnVectorField); + float[] actualVector = ((KnnVectorField) fields[0]).vectorValue(); + assertArrayEquals(VECTOR, actualVector, 0.0f); + } + + public void testSerializationWithKnn() throws IOException { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "L2", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) 
fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + + ParsedDocument doc = mapper.parse(source(b -> b.field("field", VECTOR))); + IndexableField[] fields = doc.rootDoc().getFields("field"); + assertEquals(1, fields.length); + assertTrue(fields[0] instanceof KnnVectorField); + float[] actualVector = ((KnnVectorField) fields[0]).vectorValue(); + assertArrayEquals(VECTOR, actualVector, 0.0f); + } + + @Override + protected DenseVectorFieldMapper.Builder newBuilder() { + return new DenseVectorFieldMapper.Builder("dense_vector"); + } + + public void testDeprecatedBoost() throws IOException { + createMapperService(fieldMapping(b -> { + minimalMapping(b); + b.field("boost", 2.0); + })); + String type = typeName(); + String[] warnings = new String[] { + "Parameter [boost] on field [field] is deprecated and will be removed in 8.0", + "Parameter [boost] has no effect on type [" + type + "] and will be removed in future" }; + allowedWarnings(warnings); + } + + public void testIfMinimalSerializesToItself() throws IOException { + XContentBuilder orig = JsonXContent.contentBuilder().startObject(); + createMapperService(fieldMapping(this::minimalMapping)).documentMapper().mapping().toXContent(orig, ToXContent.EMPTY_PARAMS); + orig.endObject(); + XContentBuilder parsedFromOrig = JsonXContent.contentBuilder().startObject(); + createMapperService(orig).documentMapper().mapping().toXContent(parsedFromOrig, ToXContent.EMPTY_PARAMS); + parsedFromOrig.endObject(); + assertEquals(Strings.toString(orig), Strings.toString(parsedFromOrig)); + } + + public void testForEmptyName() { + MapperParsingException e = expectThrows(MapperParsingException.class, () -> createMapperService(mapping(b -> { + b.startObject(""); + minimalMapping(b); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("name cannot be empty string")); + } + + protected void writeFieldValue(XContentBuilder b) throws IOException { + b.value(new float[] { 2.5f }); + } + + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "dense_vector"); + b.field("dimension", 1); + } + + protected void registerParameters(MapperTestCase.ParameterChecker checker) throws IOException {} + + @Override + protected Set unsupportedProperties() { + return org.opensearch.common.collect.Set.of("analyzer", "similarity", "doc_values", "store", "index"); + } + + protected String typeName() throws IOException { + MapperService ms = createMapperService(fieldMapping(this::minimalMapping)); + return ms.fieldType("field").typeName(); + } + + @Override + protected boolean supportsMeta() { + return false; + } + + public void testCosineMetric() throws IOException { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "cosine", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + } + + public void testDotProductMetric() throws IOException { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + 
Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + } + + public void testHNSWAlgorithmParametersInvalidInput() throws Exception { + XContentBuilder mappingInvalidMaxConnections = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 256, "beam_width", 50)) + ) + ) + ); + final MapperParsingException mapperExceptionInvalidMaxConnections = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidMaxConnections) + ); + assertEquals("max_connections value cannot be greater than 16", mapperExceptionInvalidMaxConnections.getRootCause().getMessage()); + + XContentBuilder mappingInvalidBeamWidth = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 6, "beam_width", 1024)) + ) + ) + ); + final MapperParsingException mapperExceptionInvalidmBeamWidth = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidBeamWidth) + ); + assertEquals("beam_width value cannot be greater than 512", mapperExceptionInvalidmBeamWidth.getRootCause().getMessage()); + + XContentBuilder mappingUnsupportedParam = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 6, "beam_width", 256, "some_param", 23)) + ) + ) + ); + final MapperParsingException mapperExceptionUnsupportedParam = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingUnsupportedParam) + ); + assertEquals("Algorithm parameter [some_param] is not supported", mapperExceptionUnsupportedParam.getRootCause().getMessage()); + } + + public void testInvalidMetric() throws Exception { + XContentBuilder mappingInvalidMetric = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field("knn", Map.of("metric", "LAMBDA", "algorithm", Map.of("name", "HNSW"))) + ); + final MapperParsingException mapperExceptionInvalidMetric = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidMetric) + ); + assertEquals("[metric] value [LAMBDA] is invalid", mapperExceptionInvalidMetric.getRootCause().getMessage()); + } + + public void testInvalidAlgorithm() throws Exception { + XContentBuilder mappingInvalidAlgorithm = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field("knn", Map.of("metric", "dot_product", "algorithm", Map.of("name", "MY_ALGORITHM"))) + ); + final MapperParsingException mapperExceptionInvalidAlgorithm = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidAlgorithm) + ); + assertEquals( + "[algorithm name] value [MY_ALGORITHM] is invalid or not supported", + mapperExceptionInvalidAlgorithm.getRootCause().getMessage() + ); + } + + public void testInvalidParams() throws Exception { + XContentBuilder mapping = fieldMapping( + b -> 
b.field("type", "dense_vector").field("dimension", 2).field("my_field", "some_value").field("knn", Map.of()) + ); + final MapperParsingException mapperParsingException = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mapping) + ); + assertEquals( + "Mapping definition for [field] has unsupported parameters: [my_field : some_value]", + mapperParsingException.getRootCause().getMessage() + ); + } + + public void testExceedMaxNumberOfAlgorithmParams() throws Exception { + Map algorithmParams = new HashMap<>(); + IntStream.range(0, 100).forEach(number -> algorithmParams.put("param" + number, randomInt(Integer.MAX_VALUE))); + XContentBuilder mapping = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field("knn", Map.of("metric", "dot_product", "algorithm", Map.of("name", "HNSW", "parameters", algorithmParams))) + ); + final MapperParsingException mapperParsingException = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mapping) + ); + assertEquals( + "Invalid number of parameters for [algorithm], max allowed is [50] but given [100]", + mapperParsingException.getRootCause().getMessage() + ); + } + + public void testInvalidVectorNumberFormat() throws Exception { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "L2", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + final MapperParsingException mapperExceptionStringAsVectorValue = expectThrows( + MapperParsingException.class, + () -> mapper.parse(source(b -> b.field("field", "some malicious script content"))) + ); + assertEquals( + mapperExceptionStringAsVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'some malicious script content'" + ); + + final MapperParsingException mapperExceptionInfinityVectorValue = expectThrows( + MapperParsingException.class, + () -> mapper.parse(source(b -> b.field("field", new Float[] { Float.POSITIVE_INFINITY }))) + ); + assertEquals( + mapperExceptionInfinityVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'Infinity'" + ); + + final MapperParsingException mapperExceptionNullVectorValue = expectThrows( + MapperParsingException.class, + () -> mapper.parse(source(b -> b.field("field", new Float[] { null }))) + ); + assertEquals( + mapperExceptionNullVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. 
Preview of field's value: 'null'" + ); + } + + public void testNullVectorValue() throws Exception { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "L2", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + mapper.parse(source(b -> b.field("field", (Float) null))); + mapper.parse(source(b -> b.field("field", VECTOR))); + mapper.parse(source(b -> b.field("field", (Float) null))); + } +} diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java index 31bfce1beaa5c..7ae117881a43f 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java @@ -7,458 +7,32 @@ import org.junit.Before; import org.mockito.Mockito; -import org.opensearch.common.CheckedConsumer; -import org.opensearch.common.Strings; -import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.compress.CompressedXContent; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.index.IndexService; +import org.opensearch.common.unit.Fuzziness; +import org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.QueryShardException; -import org.opensearch.test.OpenSearchSingleNodeTestCase; -import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; import java.util.Map; -import java.util.stream.IntStream; -import static org.hamcrest.Matchers.containsString; -import static org.opensearch.index.mapper.FieldTypeTestCase.MOCK_QSC_DISALLOW_EXPENSIVE; import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; -public class DenseVectorFieldTypeTests extends OpenSearchSingleNodeTestCase { - private static final String ALGORITHM_HNSW = "HNSW"; - private static final String DENSE_VECTOR_TYPE_NAME = "dense_vector"; - private static final int DIMENSION = 2; +public class DenseVectorFieldTypeTests extends FieldTypeTestCase { + private static final String FIELD_NAME = "field"; - private static final String METRIC_L2 = "L2"; private static final float[] VECTOR = { 2.0f, 4.5f }; - private IndexService indexService; - private DocumentMapperParser parser; - private MappedFieldType fieldType; + private DenseVectorFieldType fieldType; @Before public void setup() throws Exception { - indexService = createIndex("test"); - parser = indexService.mapperService().documentMapperParser(); - KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( HNSW, Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 10, HNSW_PARAMETER_BEAM_WIDTH, 100) ); KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); - fieldType = new DenseVectorFieldMapper.DenseVectorFieldType(FIELD_NAME, 1, knnContext); - } - - public void testIndexingWithoutEnablingKnn() throws IOException { - XContentBuilder mappingAllDefaults = 
XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .endObject() - .endObject() - .endObject() - .endObject(); - parser.parse("type", new CompressedXContent(Strings.toString(mappingAllDefaults))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); - } - - public void testIndexingWithDefaultParams() throws IOException { - XContentBuilder mappingAllDefaults = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of()) - .endObject() - .endObject() - .endObject() - .endObject(); - parser.parse("type", new CompressedXContent(Strings.toString(mappingAllDefaults))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); - } - - public void testIndexingWithAlgorithmParameters() throws IOException { - XContentBuilder mapping = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field( - "knn", - Map.of( - "metric", - METRIC_L2, - "algorithm", - Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 256, "max_connections", 16)) - ) - ) - .endObject() - .endObject() - .endObject() - .endObject(); - parser.parse("type", new CompressedXContent(Strings.toString(mapping))); - } - - public void testCosineMetric() throws IOException { - XContentBuilder mappingCosineMetric = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of("metric", "cosine", "algorithm", Map.of("name", ALGORITHM_HNSW))) - .endObject() - .endObject() - .endObject() - .endObject(); - parser.parse("type", new CompressedXContent(Strings.toString(mappingCosineMetric))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); - } - - public void testDotProductMetric() throws IOException { - XContentBuilder mappingDotProductMetric = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of("metric", "dot_product", "algorithm", Map.of("name", ALGORITHM_HNSW))) - .endObject() - .endObject() - .endObject() - .endObject(); - parser.parse("type", new CompressedXContent(Strings.toString(mappingDotProductMetric))) - .parse(source(b -> b.field(FIELD_NAME, VECTOR))); - } - - public void testHNSWAlgorithmParametersInvalidInput() throws Exception { - XContentBuilder mappingInvalidMaxConnections = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field( - "knn", - Map.of( - "metric", - METRIC_L2, - "algorithm", - Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 256, "max_connections", 50)) - ) - ) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionInvalidMaxConnections = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidMaxConnections))) - ); - 
org.hamcrest.MatcherAssert.assertThat( - mapperExceptionInvalidMaxConnections.getMessage(), - containsString("max_connections value cannot be greater than") - ); - - XContentBuilder mappingInvalidBeamWidth = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field( - "knn", - Map.of( - "metric", - METRIC_L2, - "algorithm", - Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 1024, "max_connections", 6)) - ) - ) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionInvalidmBeamWidth = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidBeamWidth))) - ); - org.hamcrest.MatcherAssert.assertThat( - mapperExceptionInvalidmBeamWidth.getMessage(), - containsString("beam_width value cannot be greater than") - ); - - XContentBuilder mappingUnsupportedParam = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field( - "knn", - Map.of( - "metric", - METRIC_L2, - "algorithm", - Map.of("name", ALGORITHM_HNSW, "parameters", Map.of("beam_width", 256, "max_connections", 6, "some_param", 23)) - ) - ) - .endObject() - .endObject() - .endObject() - .endObject(); - - final IllegalArgumentException mapperExceptionUnsupportedParam = expectThrows( - IllegalArgumentException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingUnsupportedParam))) - ); - assertEquals(mapperExceptionUnsupportedParam.getMessage(), "Algorithm parameter [some_param] is not supported"); - } - - public void testInvalidVectorDimension() throws Exception { - XContentBuilder mappingMissingDimension = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("knn", Map.of()) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionMissingDimension = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingMissingDimension))) - ); - org.hamcrest.MatcherAssert.assertThat( - mapperExceptionMissingDimension.getMessage(), - containsString("[dimension] property must be specified for field") - ); - - XContentBuilder mappingInvalidDimension = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", 1200) - .field("knn", Map.of()) - .endObject() - .endObject() - .endObject() - .endObject(); - - final IllegalArgumentException exceptionInvalidDimension = expectThrows( - IllegalArgumentException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidDimension))) - ); - assertEquals(exceptionInvalidDimension.getMessage(), "[dimension] value 1200 cannot be greater than 1024 for vector [field]"); - - XContentBuilder mappingDimentionsMismatch = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of()) - .endObject() - 
.endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionIDimentionsMismatch = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingDimentionsMismatch))) - .parse(source(b -> b.field(FIELD_NAME, new float[] { 2.0f, 4.5f, 5.6f }))) - ); - org.hamcrest.MatcherAssert.assertThat( - mapperExceptionIDimentionsMismatch.getMessage(), - containsString("failed to parse field [field] of type [dense_vector]") - ); - } - - public void testInvalidMetric() throws Exception { - XContentBuilder mappingInvalidMetric = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of("metric", "LAMBDA", "algorithm", Map.of("name", ALGORITHM_HNSW))) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionInvalidMetric = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidMetric))) - ); - org.hamcrest.MatcherAssert.assertThat( - mapperExceptionInvalidMetric.getMessage(), - containsString("[metric] value [LAMBDA] is invalid") - ); - } - - public void testInvalidAlgorithm() throws Exception { - XContentBuilder mappingInvalidAlgorithm = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of("metric", METRIC_L2, "algorithm", Map.of("name", "MY_ALGORITHM"))) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionInvalidAlgorithm = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidAlgorithm))) - ); - assertEquals(mapperExceptionInvalidAlgorithm.getMessage(), "[algorithm name] value [MY_ALGORITHM] is invalid or not supported"); - } - - public void testInvalidParams() throws Exception { - XContentBuilder mappingInvalidMaxConnections = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("my_field", "some_value") - .field("knn", Map.of()) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionInvalidMaxConnections = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidMaxConnections))) - ); - assertEquals( - mapperExceptionInvalidMaxConnections.getMessage(), - "unknown parameter [my_field] on mapper [field] of type [dense_vector]" - ); - } - - public void testExceedMaxNumberOfAlgorithmParams() throws Exception { - Map algorithmParams = new HashMap<>(); - IntStream.range(0, 100).forEach(number -> algorithmParams.put("param" + number, randomInt(Integer.MAX_VALUE))); - XContentBuilder mappingInvalidAlgorithm = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of("metric", METRIC_L2, "algorithm", Map.of("name", ALGORITHM_HNSW, "parameters", algorithmParams))) - .endObject() - .endObject() - 
.endObject() - .endObject(); - - final MapperParsingException mapperExceptionInvalidAlgorithm = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mappingInvalidAlgorithm))) - ); - assertEquals( - mapperExceptionInvalidAlgorithm.getMessage(), - "Invalid number of parameters for [algorithm], max allowed is [50] but given [100]" - ); - } - - public void testInvalidVectorNumberFormat() throws Exception { - XContentBuilder mapping = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", 1) - .field("knn", Map.of()) - .endObject() - .endObject() - .endObject() - .endObject(); - - final MapperParsingException mapperExceptionStringAsVectorValue = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mapping))) - .parse(source(b -> b.field(FIELD_NAME, "some malicious script content"))) - ); - assertEquals( - mapperExceptionStringAsVectorValue.getMessage(), - "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'some malicious script content'" - ); - - final MapperParsingException mapperExceptionInfinityVectorValue = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mapping))) - .parse(source(b -> b.field(FIELD_NAME, new Float[] { Float.POSITIVE_INFINITY }))) - ); - assertEquals( - mapperExceptionInfinityVectorValue.getMessage(), - "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'Infinity'" - ); - - final MapperParsingException mapperExceptionNullVectorValue = expectThrows( - MapperParsingException.class, - () -> parser.parse("type", new CompressedXContent(Strings.toString(mapping))) - .parse(source(b -> b.field(FIELD_NAME, new Float[] { null }))) - ); - assertEquals( - mapperExceptionNullVectorValue.getMessage(), - "failed to parse field [field] of type [dense_vector] in document with id '1'. 
Preview of field's value: 'null'" - ); - } - - public void testNullVectorValue() throws Exception { - XContentBuilder mapping = XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", DENSE_VECTOR_TYPE_NAME) - .field("dimension", DIMENSION) - .field("knn", Map.of()) - .endObject() - .endObject() - .endObject() - .endObject(); - - parser.parse("type", new CompressedXContent(Strings.toString(mapping))).parse(source(b -> b.field(FIELD_NAME, (Float) null))); - - parser.parse("type", new CompressedXContent(Strings.toString(mapping))).parse(source(b -> b.field(FIELD_NAME, VECTOR))); - - parser.parse("type", new CompressedXContent(Strings.toString(mapping))).parse(source(b -> b.field(FIELD_NAME, (Float) null))); + fieldType = new DenseVectorFieldType(FIELD_NAME, 1, knnContext); } public void testValueDisplay() { @@ -470,11 +44,7 @@ public void testValueDisplay() { Metric.L2, KnnAlgorithmContextFactory.defaultContext(KnnAlgorithmContext.Method.HNSW) ); - MappedFieldType ftDefaultAlgorithmContext = new DenseVectorFieldMapper.DenseVectorFieldType( - FIELD_NAME, - 1, - knnContextDEfaultAlgorithmContext - ); + MappedFieldType ftDefaultAlgorithmContext = new DenseVectorFieldType(FIELD_NAME, 1, knnContextDEfaultAlgorithmContext); Object actualFloatArrayDefaultAlgorithmContext = ftDefaultAlgorithmContext.valueForDisplay(VECTOR); assertTrue(actualFloatArrayDefaultAlgorithmContext instanceof float[]); assertArrayEquals(VECTOR, (float[]) actualFloatArrayDefaultAlgorithmContext, 0.0f); @@ -482,47 +52,42 @@ public void testValueDisplay() { public void testTermQueryNotSupported() { QueryShardContext context = Mockito.mock(QueryShardContext.class); - QueryShardException exception = expectThrows(QueryShardException.class, () -> fieldType.termsQuery(Arrays.asList(VECTOR), context)); - assertEquals(exception.getMessage(), "Dense_vector does not support exact searching, use KNN queries instead [field]"); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.termsQuery(Arrays.asList(VECTOR), context) + ); + assertEquals(exception.getMessage(), "[term] queries are not supported on [dense_vector] fields."); } public void testPrefixQueryNotSupported() { - QueryShardException ee = expectThrows( - QueryShardException.class, + UnsupportedOperationException ee = expectThrows( + UnsupportedOperationException.class, () -> fieldType.prefixQuery("foo*", null, MOCK_QSC_DISALLOW_EXPENSIVE) ); - assertEquals( - "Can only use prefix queries on keyword, text and wildcard fields - not on [field] which is of type [dense_vector]", - ee.getMessage() - ); + assertEquals("[prefix] queries are not supported on [dense_vector] fields.", ee.getMessage()); } public void testRegexpQueryNotSupported() { - QueryShardException ee = expectThrows( - QueryShardException.class, + UnsupportedOperationException ee = expectThrows( + UnsupportedOperationException.class, () -> fieldType.regexpQuery("foo?", randomInt(10), 0, randomInt(10) + 1, null, MOCK_QSC_DISALLOW_EXPENSIVE) ); - assertEquals( - "Can only use regexp queries on keyword and text fields - not on [field] which is of type [dense_vector]", - ee.getMessage() - ); + assertEquals("[regexp] queries are not supported on [dense_vector] fields.", ee.getMessage()); } public void testWildcardQueryNotSupported() { - QueryShardException ee = expectThrows( - QueryShardException.class, + UnsupportedOperationException ee = expectThrows( + 
UnsupportedOperationException.class, () -> fieldType.wildcardQuery("valu*", null, MOCK_QSC_DISALLOW_EXPENSIVE) ); - assertEquals( - "Can only use wildcard queries on keyword, text and wildcard fields - not on [field] which is of type [dense_vector]", - ee.getMessage() - ); + assertEquals("[wildcard] queries are not supported on [dense_vector] fields.", ee.getMessage()); } - private final SourceToParse source(CheckedConsumer build) throws IOException { - XContentBuilder builder = JsonXContent.contentBuilder().startObject(); - build.accept(builder); - builder.endObject(); - return new SourceToParse("test", "1", BytesReference.bytes(builder), XContentType.JSON); + public void testFuzzyQuery() { + UnsupportedOperationException e = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.fuzzyQuery("foo", Fuzziness.fromEdits(2), 1, 50, true, randomMockShardContext()) + ); + assertEquals("[fuzzy] queries are not supported on [dense_vector] fields.", e.getMessage()); } } diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java deleted file mode 100644 index db663029f6139..0000000000000 --- a/server/src/test/java/org/opensearch/index/mapper/DenseVectorMapperTests.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.index.mapper; - -import org.opensearch.common.Strings; -import org.opensearch.common.xcontent.ToXContent; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; - -import java.io.IOException; -import java.util.Map; -import java.util.Set; - -import static org.hamcrest.Matchers.containsString; -import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; -import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; -import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; - -public class DenseVectorMapperTests extends MapperServiceTestCase { - - private static final float[] VECTOR = { 2.0f, 4.5f }; - - public void testValueDisplay() { - KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( - HNSW, - Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 16, HNSW_PARAMETER_BEAM_WIDTH, 100) - ); - KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); - MappedFieldType ft = new DenseVectorFieldType("field", 1, knnContext); - Object actualFloatArray = ft.valueForDisplay(VECTOR); - assertTrue(actualFloatArray instanceof float[]); - assertArrayEquals(VECTOR, (float[]) actualFloatArray, 0.0f); - } - - public void testSerializationWithoutKnn() throws IOException { - DocumentMapper defaultMapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - Mapper mapper = defaultMapper.mappers().getMapper("field"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - mapper.toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.endObject(); - assertEquals("{\"field\":{\"type\":\"dense_vector\",\"dimension\":2}}", Strings.toString(builder)); - } - - public void testSerializationWithKnn() throws IOException { - DocumentMapper defaultMapper = createDocumentMapper(fieldMapping(b -> { - minimalMapping(b); - b.field("knn", Map.of()); - })); - Mapper mapper = 
defaultMapper.mappers().getMapper("field"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - mapper.toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.endObject(); - assertTrue( - Set.of( - "{\"field\":{\"type\":\"dense_vector\"," - + "\"dimension\":2," - + "\"knn\":" - + "{\"metric\":\"L2\"," - + "\"algorithm\":{" - + "\"name\":\"HNSW\"," - + "\"parameters\":{\"beam_width\":100,\"max_connections\":16}}}}}", - "{\"field\":{\"type\":\"dense_vector\"," - + "\"dimension\":2," - + "\"knn\":" - + "{\"metric\":\"L2\"," - + "\"algorithm\":{" - + "\"name\":\"HNSW\"," - + "\"parameters\":{\"max_connections\":16,\"beam_width\":100}}}}}" - ).contains(Strings.toString(builder)) - ); - } - - public void testMinimalToMaximal() throws IOException { - XContentBuilder orig = JsonXContent.contentBuilder().startObject(); - createMapperService(fieldMapping(this::minimalMapping)).documentMapper().mapping().toXContent(orig, INCLUDE_DEFAULTS); - orig.endObject(); - XContentBuilder parsedFromOrig = JsonXContent.contentBuilder().startObject(); - createMapperService(orig).documentMapper().mapping().toXContent(parsedFromOrig, INCLUDE_DEFAULTS); - parsedFromOrig.endObject(); - assertEquals(Strings.toString(orig), Strings.toString(parsedFromOrig)); - } - - public void testDeprecatedBoost() throws IOException { - createMapperService(fieldMapping(b -> { - minimalMapping(b); - b.field("boost", 2.0); - })); - String type = typeName(); - String[] warnings = new String[] { - "Parameter [boost] on field [field] is deprecated and will be removed in 8.0", - "Parameter [boost] has no effect on type [" + type + "] and will be removed in future" }; - allowedWarnings(warnings); - } - - public void testIfMinimalSerializesToItself() throws IOException { - XContentBuilder orig = JsonXContent.contentBuilder().startObject(); - createMapperService(fieldMapping(this::minimalMapping)).documentMapper().mapping().toXContent(orig, ToXContent.EMPTY_PARAMS); - orig.endObject(); - XContentBuilder parsedFromOrig = JsonXContent.contentBuilder().startObject(); - createMapperService(orig).documentMapper().mapping().toXContent(parsedFromOrig, ToXContent.EMPTY_PARAMS); - parsedFromOrig.endObject(); - assertEquals(Strings.toString(orig), Strings.toString(parsedFromOrig)); - } - - public void testForEmptyName() { - MapperParsingException e = expectThrows(MapperParsingException.class, () -> createMapperService(mapping(b -> { - b.startObject(""); - minimalMapping(b); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("name cannot be empty string")); - } - - protected void writeFieldValue(XContentBuilder b) throws IOException { - b.value(new float[] { 2.5f }); - } - - protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "dense_vector"); - b.field("dimension", 2); - // b.field("knn", Map.of()); - } - - protected void registerParameters(MapperTestCase.ParameterChecker checker) throws IOException { - checker.registerConflictCheck("doc_values", b -> b.field("doc_values", false)); - checker.registerConflictCheck("index", b -> b.field("index", false)); - checker.registerConflictCheck("store", b -> b.field("store", false)); - } - - protected String typeName() throws IOException { - MapperService ms = createMapperService(fieldMapping(this::minimalMapping)); - return ms.fieldType("field").typeName(); - } -} From 6ad1bda84fa7660b0bb3f2d3be7bb37c46974466 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 6 Jul 2022 10:16:49 -0700 Subject: [PATCH 6/7] Remove redundant 
null-check Signed-off-by: Martin Gaievski --- .../java/org/opensearch/index/codec/KnnVectorFormatFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java index 70e50fe4e3067..2cbc1d53392ac 100644 --- a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java +++ b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java @@ -54,7 +54,7 @@ public KnnVectorsFormat create(final String field) { } private boolean isDenseVectorFieldType(final MappedFieldType mappedFieldType) { - if (mappedFieldType != null && mappedFieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType) { + if (mappedFieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType) { return true; } return false; From 24dfcdb0f6c5905ac24df6fc8d98f55622c8b06a Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 6 Jul 2022 11:29:03 -0700 Subject: [PATCH 7/7] Refactoring code to address comments Signed-off-by: Martin Gaievski --- .../index/codec/KnnVectorFormatFactory.java | 6 +--- .../index/mapper/DenseVectorFieldMapper.java | 2 +- .../index/mapper/KnnAlgorithmContext.java | 34 ++++++++----------- .../opensearch/index/mapper/KnnContext.java | 7 ++-- .../translog/InternalTranslogManager.java | 12 +++---- .../index/translog/TranslogManager.java | 6 ++-- 6 files changed, 30 insertions(+), 37 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java index 2cbc1d53392ac..e353fa2ca991a 100644 --- a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java +++ b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java @@ -8,7 +8,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.lucene92.Lucene92HnswVectorsFormat; -import org.opensearch.index.mapper.DenseVectorFieldMapper; import org.opensearch.index.mapper.KnnAlgorithmContext; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; @@ -54,10 +53,7 @@ public KnnVectorsFormat create(final String field) { } private boolean isDenseVectorFieldType(final MappedFieldType mappedFieldType) { - if (mappedFieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType) { - return true; - } - return false; + return mappedFieldType instanceof DenseVectorFieldType; } private int getIntegerParam(Map methodParams, String name) { diff --git a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java index 7d41382f0191e..93dd360641ac6 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java @@ -41,7 +41,7 @@ public final class DenseVectorFieldMapper extends FieldMapper { /** * Define the max dimension a knn_vector mapping can have. 
*/ - public static final int MAX_DIMENSION = 1024; + private static final int MAX_DIMENSION = 1024; private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java index 7f4246ce5ca87..b8cdd3a717628 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java @@ -25,14 +25,14 @@ public class KnnAlgorithmContext implements ToXContentFragment, Writeable { private static final String PARAMETERS = "parameters"; private static final String NAME = "name"; + private static final int MAX_NUMBER_OF_ALGORITHM_PARAMETERS = 50; + private final Method method; private final Map parameters; - private static final int MAX_NUMBER_OF_ALGORITHM_PARAMETERS = 50; - public KnnAlgorithmContext(Method method, Map parameters) { - this.method = method; - this.parameters = parameters; + this.method = Objects.requireNonNull(method, "[method] for knn algorithm context cannot be null"); + this.parameters = Objects.requireNonNull(parameters, "[parameters] for knn algorithm context cannot be null"); } public Method getMethod() { @@ -78,9 +78,7 @@ public static KnnAlgorithmContext parse(Object in) { parameters = ((Map) value).entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> { Object v = e.getValue(); if (v instanceof Map) { - throw new MapperParsingException( - String.format(Locale.ROOT, "Unable to parse parameter [%s] for [algorithm]", e.getValue()) - ); + throw new MapperParsingException(String.format(Locale.ROOT, "Unable to parse parameter [%s] for [algorithm]", v)); } return v; })); @@ -105,19 +103,17 @@ public static KnnAlgorithmContext parse(Object in) { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(NAME, method.name()); if (parameters == null) { - builder.field(PARAMETERS, (String) null); - } else { - builder.startObject(PARAMETERS); - parameters.forEach((key, value) -> { - try { - builder.field(key, value); - } catch (IOException ioe) { - throw new RuntimeException("Unable to generate xcontent for method component"); - } - - }); - builder.endObject(); + return builder.field(PARAMETERS, (String) null); } + builder.startObject(PARAMETERS); + parameters.forEach((key, value) -> { + try { + builder.field(key, value); + } catch (IOException ioe) { + throw new RuntimeException("Unable to generate xcontent for method component"); + } + }); + builder.endObject(); return builder; } diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java index 106c7fcf2101e..3d67bd4c44604 100644 --- a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java +++ b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java @@ -22,13 +22,14 @@ */ public final class KnnContext implements ToXContentFragment, Writeable { - private final Metric metric; - private final KnnAlgorithmContext knnAlgorithmContext; private static final String KNN_METRIC_NAME = "metric"; private static final String ALGORITHM = "algorithm"; + private final Metric metric; + private final KnnAlgorithmContext knnAlgorithmContext; + KnnContext(final Metric metric, final KnnAlgorithmContext knnAlgorithmContext) { - this.metric = metric; + this.metric = Objects.requireNonNull(metric, "[metric] for 
knn context cannot be null"); this.knnAlgorithmContext = knnAlgorithmContext; } diff --git a/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java b/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java index c7fdf1e30e6a1..e5ffe799eb90b 100644 --- a/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java +++ b/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java @@ -285,9 +285,9 @@ public void ensureCanFlush() { /** * Reads operations from the translog - * @param location + * @param location the location in the translog * @return the translog operation - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ @Override public Translog.Operation readOperation(Translog.Location location) throws IOException { @@ -296,9 +296,9 @@ public Translog.Operation readOperation(Translog.Location location) throws IOExc /** * Adds an operation to the translog - * @param operation + * @param operation the operation in the translog * @return the location in the translog - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ @Override public Translog.Location add(Translog.Operation operation) throws IOException { @@ -396,8 +396,8 @@ public String getTranslogUUID() { /** * - * @param localCheckpointOfLastCommit - * @param flushThreshold + * @param localCheckpointOfLastCommit the localCheckpoint of last commit in the translog + * @param flushThreshold the flush threshold in the translog * @return if the translog should be flushed */ public boolean shouldPeriodicallyFlush(long localCheckpointOfLastCommit, long flushThreshold) { diff --git a/server/src/main/java/org/opensearch/index/translog/TranslogManager.java b/server/src/main/java/org/opensearch/index/translog/TranslogManager.java index f82434f40b06c..fec51d9aa9463 100644 --- a/server/src/main/java/org/opensearch/index/translog/TranslogManager.java +++ b/server/src/main/java/org/opensearch/index/translog/TranslogManager.java @@ -98,15 +98,15 @@ public interface TranslogManager { * Reads operations for the translog * @param location the location in the translog * @return the translog operation - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ Translog.Operation readOperation(Translog.Location location) throws IOException; /** * Adds an operation to the translog - * @param operation + * @param operation translog operation * @return the location in the translog - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ Translog.Location add(Translog.Operation operation) throws IOException;
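The tests in this series all exercise mappings of one shape. Below is a minimal sketch of that mapping, built with the same XContentBuilder idiom the tests use and kept inside the limits they assert (dimension at most 1024, max_connections at most 16, beam_width at most 512); the field name and the concrete values are illustrative only and not taken from the patch.

import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;

import java.io.IOException;
import java.util.Map;

class DenseVectorMappingSketch {
    // Builds a hypothetical "my_vector" dense_vector mapping with explicit HNSW parameters,
    // mirroring the builder pattern used by the tests in this patch series.
    static XContentBuilder exampleMapping() throws IOException {
        return XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject("my_vector")                  // hypothetical field name
            .field("type", "dense_vector")
            .field("dimension", 4)                     // mapper rejects values above 1024
            .field(
                "knn",
                Map.of(
                    "metric", "L2",                    // cosine and dot_product are also accepted
                    "algorithm",
                    Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 256))
                )
            )
            .endObject()
            .endObject()
            .endObject();
    }
}

Omitting the "knn" entry, or passing Map.of() for it, falls back to the defaults shown in the serialization tests (L2 metric, HNSW with max_connections 16 and beam_width 100).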