diff --git a/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java b/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java new file mode 100644 index 0000000000000..3cedcb4fd0c34 --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/search/knn/DenseVectorIT.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.search.knn; + +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.VersionUtils; + +import java.util.Map; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; + +public class DenseVectorIT extends OpenSearchIntegTestCase { + + private static final float[] VECTOR_ONE = { 2.0f, 4.5f, 5.6f, 4.2f }; + private static final float[] VECTOR_TWO = { 4.0f, 2.5f, 1.6f, 2.2f }; + + @Override + protected boolean forbidPrivateIndexSettings() { + return false; + } + + public void testIndexingSingleDocumentWithoutKnn() throws Exception { + Version version = VersionUtils.randomIndexCompatibleVersion(random()); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build(); + XContentBuilder defaultMapping = jsonBuilder().startObject() + .startObject("properties") + .startObject("vector_field") + .field("type", "dense_vector") + .field("dimension", 4) + .endObject() + .endObject() + .endObject(); + assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping)); + ensureGreen(); + + indexRandom( + true, + client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()) + ); + ensureSearchable("test"); + } + + public void testIndexingSingleDocumentWithDefaultKnnParams() throws Exception { + Version version = VersionUtils.randomIndexCompatibleVersion(random()); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build(); + XContentBuilder defaultMapping = jsonBuilder().startObject() + .startObject("properties") + .startObject("vector_field") + .field("type", "dense_vector") + .field("dimension", 4) + .field("knn", Map.of()) + .endObject() + .endObject() + .endObject(); + assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping)); + ensureGreen(); + + indexRandom( + true, + client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()) + ); + ensureSearchable("test"); + } + + public void testIndexingMultipleDocumentsWithHnswDefinition() throws Exception { + Version version = VersionUtils.randomIndexCompatibleVersion(random()); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build(); + XContentBuilder defaultMapping = jsonBuilder().startObject() + .startObject("properties") + .startObject("field") + .field("type", "dense_vector") + .field("dimension", 4) + .field( + "knn", + Map.of("metric", "l2", "algorithm", Map.of("name", "hnsw", "parameters", Map.of("max_connections", 12, "beam_width", 256))) + ) + .endObject() + .endObject() + .endObject(); + assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping)); + ensureGreen(); + + indexRandom( + true, + client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()), + client().prepareIndex("test").setId("2").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_TWO).endObject()) + ); + ensureSearchable("test"); + } +} diff --git a/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java new file mode 100644 index 0000000000000..e353fa2ca991a --- /dev/null +++ b/server/src/main/java/org/opensearch/index/codec/KnnVectorFormatFactory.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene92.Lucene92Codec; +import org.apache.lucene.codecs.lucene92.Lucene92HnswVectorsFormat; +import org.opensearch.index.mapper.KnnAlgorithmContext; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; + +import java.util.Map; + +import static org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +/** + * Factory that creates a {@link KnnVectorsFormat knn vector format} based on a mapping + * configuration for the field. + * + * @opensearch.internal + */ +public class KnnVectorFormatFactory { + + private final MapperService mapperService; + + public KnnVectorFormatFactory(MapperService mapperService) { + this.mapperService = mapperService; + } + + /** + * Create KnnVectorsFormat with parameters specified in the field definition or return codec's default + * Knn Vector Format if field is not of DenseVector type + * @param field name of the field + * @return KnnVectorFormat that is specific to a mapped field + */ + public KnnVectorsFormat create(final String field) { + final MappedFieldType mappedFieldType = mapperService.fieldType(field); + if (isDenseVectorFieldType(mappedFieldType)) { + final DenseVectorFieldType knnVectorFieldType = (DenseVectorFieldType) mappedFieldType; + final KnnAlgorithmContext algorithmContext = knnVectorFieldType.getKnnContext().getKnnAlgorithmContext(); + final Map methodParams = algorithmContext.getParameters(); + int maxConnections = getIntegerParam(methodParams, HNSW_PARAMETER_MAX_CONNECTIONS); + int beamWidth = getIntegerParam(methodParams, HNSW_PARAMETER_BEAM_WIDTH); + final KnnVectorsFormat luceneHnswVectorsFormat = new Lucene92HnswVectorsFormat(maxConnections, beamWidth); + return luceneHnswVectorsFormat; + } + return Lucene92Codec.getDefault().knnVectorsFormat(); + } + + private boolean isDenseVectorFieldType(final MappedFieldType mappedFieldType) { + return mappedFieldType instanceof DenseVectorFieldType; + } + + private int getIntegerParam(Map methodParams, String name) { + return (Integer) methodParams.get(name); + } +} diff --git a/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java b/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java index fd0c66983208a..f1cd21bfae186 100644 --- a/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java +++ b/server/src/main/java/org/opensearch/index/codec/PerFieldMappingPostingFormatCodec.java @@ -35,6 +35,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; import org.apache.lucene.codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; @@ -57,16 +58,18 @@ public class PerFieldMappingPostingFormatCodec extends Lucene92Codec { private final Logger logger; private final MapperService mapperService; private final DocValuesFormat dvFormat = new Lucene90DocValuesFormat(); + private final KnnVectorFormatFactory knnVectorsFormatFactory; static { assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) - : "PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC; + : "PerFieldMappingPostingFormatCodec must subclass the latest lucene codec: " + Lucene.LATEST_CODEC; } public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService, Logger logger) { super(compressionMode); this.mapperService = mapperService; this.logger = logger; + this.knnVectorsFormatFactory = new KnnVectorFormatFactory(mapperService); } @Override @@ -84,4 +87,9 @@ public PostingsFormat getPostingsFormatForField(String field) { public DocValuesFormat getDocValuesFormatForField(String field) { return dvFormat; } + + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return knnVectorsFormatFactory.create(field); + } } diff --git a/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java new file mode 100644 index 0000000000000..93dd360641ac6 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/DenseVectorFieldMapper.java @@ -0,0 +1,433 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.search.MultiTermQuery; +import org.apache.lucene.search.Query; +import org.opensearch.common.Explicit; +import org.opensearch.common.Nullable; +import org.opensearch.common.unit.Fuzziness; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.search.lookup.SearchLookup; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +/** + * Field Mapper for Dense vector type. Extends ParametrizedFieldMapper in order to easily configure mapping parameters. + * + * @opensearch.internal + */ +public final class DenseVectorFieldMapper extends FieldMapper { + + public static final String CONTENT_TYPE = "dense_vector"; + + /** + * Define the max dimension a knn_vector mapping can have. + */ + private static final int MAX_DIMENSION = 1024; + + private static DenseVectorFieldMapper toType(FieldMapper in) { + return (DenseVectorFieldMapper) in; + } + + /** + * Builder for DenseVectorFieldMapper. This class defines the set of parameters that can be applied to the knn_vector + * field type + */ + public static class Builder extends FieldMapper.Builder { + private CopyTo copyTo = CopyTo.empty(); + private Integer dimension = 1; + private KnnContext knnContext = null; + + public Builder(String name) { + super(name, Defaults.FIELD_TYPE); + builder = this; + } + + @Override + public DenseVectorFieldMapper build(BuilderContext context) { + final DenseVectorFieldType mappedFieldType = new DenseVectorFieldType(buildFullName(context), dimension, knnContext); + return new DenseVectorFieldMapper( + buildFullName(context), + fieldType, + mappedFieldType, + multiFieldsBuilder.build(this, context), + copyTo + ); + } + + public Builder dimension(int value) { + if (value > MAX_DIMENSION) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[dimension] value %d cannot be greater than %d for vector [%s]", value, MAX_DIMENSION, name) + ); + } + if (value <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[dimension] value %d must be greater than 0 for vector [%s]", value, name) + ); + } + this.dimension = value; + return this; + } + + public Builder knn(KnnContext value) { + this.knnContext = value; + return this; + } + } + + /** + * Type parser for dense_vector field mapper + * + * @opensearch.internal + */ + public static class TypeParser implements Mapper.TypeParser { + + @Override + public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { + Builder builder = new DenseVectorFieldMapper.Builder(name); + TypeParsers.parseField(builder, name, node, parserContext); + + for (Iterator> iterator = node.entrySet().iterator(); iterator.hasNext();) { + Map.Entry entry = iterator.next(); + String fieldName = entry.getKey(); + Object fieldNode = entry.getValue(); + switch (fieldName) { + case "dimension": + if (fieldNode == null) { + throw new MapperParsingException( + String.format(Locale.ROOT, "[dimension] property must be specified for field [%s]", name) + ); + } + builder.dimension(XContentMapValues.nodeIntegerValue(fieldNode, 1)); + iterator.remove(); + break; + case "knn": + builder.knn(KnnContext.parse(fieldNode)); + iterator.remove(); + break; + default: + break; + } + } + return builder; + } + } + + /** + * Field type for dense_vector field mapper + * + * @opensearch.internal + */ + public static class DenseVectorFieldType extends MappedFieldType { + + private final int dimension; + private final KnnContext knnContext; + + public DenseVectorFieldType(String name, int dimension, KnnContext knnContext) { + this(name, Collections.emptyMap(), dimension, knnContext, false); + } + + public DenseVectorFieldType(String name, Map meta, int dimension, KnnContext knnContext, boolean hasDocValues) { + super(name, false, false, hasDocValues, TextSearchInfo.NONE, meta); + this.dimension = dimension; + this.knnContext = knnContext; + } + + @Override + public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { + throw new UnsupportedOperationException("[fields search] are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public Query termQuery(Object value, QueryShardContext context) { + throw new UnsupportedOperationException("[term] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query fuzzyQuery( + Object value, + Fuzziness fuzziness, + int prefixLength, + int maxExpansions, + boolean transpositions, + QueryShardContext context + ) { + throw new UnsupportedOperationException("[fuzzy] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query prefixQuery(String value, MultiTermQuery.RewriteMethod method, boolean caseInsensitive, QueryShardContext context) { + throw new UnsupportedOperationException("[prefix] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query wildcardQuery( + String value, + @Nullable MultiTermQuery.RewriteMethod method, + boolean caseInsensitive, + QueryShardContext context + ) { + throw new UnsupportedOperationException("[wildcard] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + @Override + public Query regexpQuery( + String value, + int syntaxFlags, + int matchFlags, + int maxDeterminizedStates, + MultiTermQuery.RewriteMethod method, + QueryShardContext context + ) { + throw new UnsupportedOperationException("[regexp] queries are not supported on [" + CONTENT_TYPE + "] fields."); + } + + public int getDimension() { + return dimension; + } + + public KnnContext getKnnContext() { + return knnContext; + } + } + + protected Explicit ignoreMalformed; + protected Integer dimension; + protected boolean isKnnEnabled; + protected KnnContext knnContext; + protected boolean hasDocValues; + protected String modelId; + + public DenseVectorFieldMapper( + String simpleName, + FieldType fieldType, + DenseVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo + ) { + super(simpleName, fieldType, mappedFieldType, multiFields, copyTo); + dimension = mappedFieldType.getDimension(); + fieldType = new FieldType(DenseVectorFieldMapper.Defaults.FIELD_TYPE); + isKnnEnabled = mappedFieldType.getKnnContext() != null; + if (isKnnEnabled) { + knnContext = mappedFieldType.getKnnContext(); + fieldType.setVectorDimensionsAndSimilarityFunction( + mappedFieldType.getDimension(), + Metric.toSimilarityFunction(knnContext.getMetric()) + ); + } + fieldType.freeze(); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + protected void parseCreateField(ParseContext context) throws IOException { + parseCreateField(context, fieldType().getDimension()); + } + + @Override + protected void mergeOptions(FieldMapper other, List conflicts) { + DenseVectorFieldMapper denseVectorMergeWith = (DenseVectorFieldMapper) other; + if (!Objects.equals(dimension, denseVectorMergeWith.dimension)) { + conflicts.add("mapper [" + name() + "] has different [dimension]"); + } + + if (isOnlyOneObjectNull(knnContext, denseVectorMergeWith.knnContext) + || (isBothObjectsNotNull(knnContext, denseVectorMergeWith.knnContext) + && !Objects.equals(knnContext.getMetric(), denseVectorMergeWith.knnContext.getMetric()))) { + conflicts.add("mapper [" + name() + "] has different [metric]"); + } + + if (isBothObjectsNotNull(knnContext, denseVectorMergeWith.knnContext)) { + + if (!Objects.equals(knnContext.getMetric(), denseVectorMergeWith.knnContext.getMetric())) { + conflicts.add("mapper [" + name() + "] has different [metric]"); + } + + if (isBothObjectsNotNull(knnContext.getKnnAlgorithmContext(), denseVectorMergeWith.knnContext.getKnnAlgorithmContext())) { + KnnAlgorithmContext knnAlgorithmContext = knnContext.getKnnAlgorithmContext(); + KnnAlgorithmContext mergeWithKnnAlgorithmContext = denseVectorMergeWith.knnContext.getKnnAlgorithmContext(); + + if (isOnlyOneObjectNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext) + || (isBothObjectsNotNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext) + && !Objects.equals(knnAlgorithmContext.getMethod(), mergeWithKnnAlgorithmContext.getMethod()))) { + conflicts.add("mapper [" + name() + "] has different [method]"); + } + + if (isBothObjectsNotNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext)) { + Map knnAlgoParams = knnAlgorithmContext.getParameters(); + Map mergeWithKnnAlgoParams = mergeWithKnnAlgorithmContext.getParameters(); + + if (isOnlyOneObjectNull(knnAlgoParams, mergeWithKnnAlgoParams) + || (isBothObjectsNotNull(knnAlgoParams, mergeWithKnnAlgoParams) + && !Objects.equals(knnAlgoParams, mergeWithKnnAlgoParams))) { + conflicts.add("mapper [" + name() + "] has different [knn algorithm parameters]"); + } + } + } + } + } + + private boolean isOnlyOneObjectNull(Object object1, Object object2) { + return object1 == null && object2 != null || object2 == null && object1 != null; + } + + private boolean isBothObjectsNotNull(Object object1, Object object2) { + return object1 != null && object2 != null; + } + + protected void parseCreateField(ParseContext context, int dimension) throws IOException { + + context.path().add(simpleName()); + + ArrayList vector = new ArrayList<>(); + XContentParser.Token token = context.parser().currentToken(); + float value; + if (token == XContentParser.Token.START_ARRAY) { + token = context.parser().nextToken(); + while (token != XContentParser.Token.END_ARRAY) { + value = context.parser().floatValue(); + + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + + vector.add(value); + token = context.parser().nextToken(); + } + } else if (token == XContentParser.Token.VALUE_NUMBER) { + value = context.parser().floatValue(); + + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + + vector.add(value); + context.parser().nextToken(); + } else if (token == XContentParser.Token.VALUE_NULL) { + context.path().remove(); + return; + } + + if (dimension != vector.size()) { + String errorMessage = String.format( + Locale.ROOT, + "Vector dimensions mismatch, expected [%d] but given [%d]", + dimension, + vector.size() + ); + throw new IllegalArgumentException(errorMessage); + } + + float[] array = new float[vector.size()]; + int i = 0; + for (Float f : vector) { + array[i++] = f; + } + + KnnVectorField point = new KnnVectorField(name(), array, fieldType); + + context.doc().add(point); + context.path().remove(); + } + + @Override + protected boolean docValuesByDefault() { + return false; + } + + @Override + public boolean parsesArrayValue() { + return true; + } + + @Override + public DenseVectorFieldType fieldType() { + return (DenseVectorFieldType) super.fieldType(); + } + + @Override + protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException { + super.doXContentBody(builder, includeDefaults, params); + + builder.field("dimension", dimension); + if (knnContext != null) { + builder.startObject("knn"); + knnContext.toXContent(builder, params); + builder.endObject(); + } + } + + /** + * Define names for dense_vector parameters + * + * @opensearch.internal + */ + enum Names { + DIMENSION("dimension"), + KNN("knn"); + + Names(String value) { + this.value = value; + } + + String value; + + String getValue() { + return this.value; + } + } + + /** + * Default parameters + * + * @opensearch.internal + */ + static class Defaults { + public static final FieldType FIELD_TYPE = new FieldType(); + + static { + FIELD_TYPE.setTokenized(false); + FIELD_TYPE.setOmitNorms(true); + FIELD_TYPE.setIndexOptions(IndexOptions.NONE); + FIELD_TYPE.setDocValuesType(DocValuesType.NONE); + FIELD_TYPE.freeze(); + } + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java new file mode 100644 index 0000000000000..b8cdd3a717628 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContext.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Abstracts KNN Algorithm segment of dense_vector field type + */ +public class KnnAlgorithmContext implements ToXContentFragment, Writeable { + + private static final String PARAMETERS = "parameters"; + private static final String NAME = "name"; + + private static final int MAX_NUMBER_OF_ALGORITHM_PARAMETERS = 50; + + private final Method method; + private final Map parameters; + + public KnnAlgorithmContext(Method method, Map parameters) { + this.method = Objects.requireNonNull(method, "[method] for knn algorithm context cannot be null"); + this.parameters = Objects.requireNonNull(parameters, "[parameters] for knn algorithm context cannot be null"); + } + + public Method getMethod() { + return method; + } + + public Map getParameters() { + return parameters; + } + + public static KnnAlgorithmContext parse(Object in) { + if (!(in instanceof Map)) { + throw new MapperParsingException("Unable to parse [algorithm] component"); + } + @SuppressWarnings("unchecked") + final Map methodMap = (Map) in; + Method method = Method.HNSW; + Map parameters = Map.of(); + + for (Map.Entry methodEntry : methodMap.entrySet()) { + final String key = methodEntry.getKey(); + final Object value = methodEntry.getValue(); + if (NAME.equals(key)) { + if (!(value instanceof String)) { + throw new MapperParsingException("Component [name] should be a string"); + } + try { + Method.fromName((String) value); + } catch (IllegalArgumentException illegalArgumentException) { + throw new MapperParsingException( + String.format(Locale.ROOT, "[algorithm name] value [%s] is invalid or not supported", value) + ); + } + } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + if (!(value instanceof Map)) { + throw new MapperParsingException("Unable to parse [parameters] for algorithm"); + } + // Check to interpret map parameters as sub-methodComponentContexts + parameters = ((Map) value).entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> { + Object v = e.getValue(); + if (v instanceof Map) { + throw new MapperParsingException(String.format(Locale.ROOT, "Unable to parse parameter [%s] for [algorithm]", v)); + } + return v; + })); + if (parameters.size() > MAX_NUMBER_OF_ALGORITHM_PARAMETERS) { + throw new MapperParsingException( + String.format( + Locale.ROOT, + "Invalid number of parameters for [algorithm], max allowed is [%d] but given [%d]", + MAX_NUMBER_OF_ALGORITHM_PARAMETERS, + parameters.size() + ) + ); + } + } else { + throw new MapperParsingException(String.format(Locale.ROOT, "Invalid parameter %s for [algorithm]", key)); + } + } + return KnnAlgorithmContextFactory.createContext(method, parameters); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(NAME, method.name()); + if (parameters == null) { + return builder.field(PARAMETERS, (String) null); + } + builder.startObject(PARAMETERS); + parameters.forEach((key, value) -> { + try { + builder.field(key, value); + } catch (IOException ioe) { + throw new RuntimeException("Unable to generate xcontent for method component"); + } + }); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.method.name()); + if (this.parameters != null) { + out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + } + } + + private static class ParameterMapValueWriter implements Writer { + + private ParameterMapValueWriter() {} + + @Override + public void write(StreamOutput out, Object o) throws IOException { + if (o instanceof KnnAlgorithmContext) { + out.writeBoolean(true); + ((KnnAlgorithmContext) o).writeTo(out); + } else { + out.writeBoolean(false); + out.writeGenericValue(o); + } + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + KnnAlgorithmContext that = (KnnAlgorithmContext) obj; + return method == that.method && this.parameters.equals(that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(method, parameters); + } + + /** + * Abstracts supported search methods for KNN + */ + public enum Method { + HNSW; + + private static final Map STRING_TO_METHOD = Map.of("hnsw", HNSW); + + public static Method fromName(String methodName) { + return Optional.ofNullable(STRING_TO_METHOD.get(methodName.toLowerCase(Locale.ROOT))) + .orElseThrow( + () -> new IllegalArgumentException(String.format(Locale.ROOT, "Provided knn method %s is not supported", methodName)) + ); + } + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java new file mode 100644 index 0000000000000..b2907ea795674 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/KnnAlgorithmContextFactory.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.opensearch.index.mapper.KnnAlgorithmContext.Method; + +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +/** + * Class abstracts creation of KNN Algorithm context + */ +public class KnnAlgorithmContextFactory { + + public static final String HNSW_PARAMETER_MAX_CONNECTIONS = "max_connections"; + public static final String HNSW_PARAMETER_BEAM_WIDTH = "beam_width"; + + protected static final int MAX_CONNECTIONS_DEFAULT_VALUE = 16; + protected static final int MAX_CONNECTIONS_MAX_VALUE = 16; + protected static final int BEAM_WIDTH_DEFAULT_VALUE = 100; + protected static final int BEAM_WIDTH_MAX_VALUE = 512; + + private static final Map DEFAULT_CONTEXTS = Map.of( + Method.HNSW, + createContext( + Method.HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, MAX_CONNECTIONS_DEFAULT_VALUE, HNSW_PARAMETER_BEAM_WIDTH, BEAM_WIDTH_DEFAULT_VALUE) + ) + ); + + public static KnnAlgorithmContext defaultContext(Method method) { + return Optional.ofNullable(DEFAULT_CONTEXTS.get(method)) + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid knn method provided [%s], only HNSW is supported", method.name()) + ) + ); + } + + public static KnnAlgorithmContext createContext(Method method, Map parameters) { + Map, KnnAlgorithmContext>> methodToContextSupplier = Map.of(Method.HNSW, hnswContext()); + + return Optional.ofNullable(methodToContextSupplier.get(method)) + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid knn method provided [%s], only HNSW is supported", method.name()) + ) + ) + .apply(parameters); + } + + private static Function, KnnAlgorithmContext> hnswContext() { + Function, KnnAlgorithmContext> supplierFunction = params -> { + validateForSupportedParameters(Set.of(HNSW_PARAMETER_MAX_CONNECTIONS, HNSW_PARAMETER_BEAM_WIDTH), params); + + int maxConnections = getParameter( + params, + HNSW_PARAMETER_MAX_CONNECTIONS, + MAX_CONNECTIONS_DEFAULT_VALUE, + MAX_CONNECTIONS_MAX_VALUE + ); + int beamWidth = getParameter(params, HNSW_PARAMETER_BEAM_WIDTH, BEAM_WIDTH_DEFAULT_VALUE, BEAM_WIDTH_MAX_VALUE); + + KnnAlgorithmContext hnswKnnMethodContext = new KnnAlgorithmContext( + Method.HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, maxConnections, HNSW_PARAMETER_BEAM_WIDTH, beamWidth) + ); + return hnswKnnMethodContext; + }; + return supplierFunction; + } + + private static int getParameter(Map parameters, String paramName, int defaultValue, int maxValue) { + int value = defaultValue; + if (isNullOrEmpty(parameters)) { + return value; + } + if (parameters.containsKey(paramName)) { + if (!(parameters.get(paramName) instanceof Integer)) { + throw new MapperParsingException( + String.format(Locale.ROOT, "Invalid value for field %s, it must be an integer number", paramName) + ); + } + value = (int) parameters.get(paramName); + } + if (value > maxValue) { + throw new MapperParsingException(String.format(Locale.ROOT, "%s value cannot be greater than %d", paramName, maxValue)); + } + if (value <= 0) { + throw new MapperParsingException(String.format(Locale.ROOT, "%s value must be greater than 0", paramName)); + } + return value; + } + + static void validateForSupportedParameters(Set supportedParams, Map actualParams) { + if (isNullOrEmpty(actualParams)) { + return; + } + Optional unsupportedParam = actualParams.keySet() + .stream() + .filter(actualParamName -> !supportedParams.contains(actualParamName)) + .findAny(); + if (unsupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Algorithm parameter [%s] is not supported", unsupportedParam.get()) + ); + } + } + + private static boolean isNullOrEmpty(Map parameters) { + return parameters == null || parameters.isEmpty(); + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/KnnContext.java b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java new file mode 100644 index 0000000000000..3d67bd4c44604 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/KnnContext.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +/** + * Abstracts KNN segment of dense_vector field type + * + * @opensearch.internal + */ +public final class KnnContext implements ToXContentFragment, Writeable { + + private static final String KNN_METRIC_NAME = "metric"; + private static final String ALGORITHM = "algorithm"; + + private final Metric metric; + private final KnnAlgorithmContext knnAlgorithmContext; + + KnnContext(final Metric metric, final KnnAlgorithmContext knnAlgorithmContext) { + this.metric = Objects.requireNonNull(metric, "[metric] for knn context cannot be null"); + this.knnAlgorithmContext = knnAlgorithmContext; + } + + /** + * Parses an Object into a KnnContext. + * + * @param in Object containing mapping to be parsed + * @return KnnContext + */ + public static KnnContext parse(Object in) { + if (!(in instanceof Map)) { + throw new MapperParsingException("Unable to parse mapping into KnnContext. Object not of type \"Map\""); + } + + Map knnMap = (Map) in; + + Metric metric = Metric.L2; + KnnAlgorithmContext knnAlgorithmContext = KnnAlgorithmContextFactory.defaultContext(KnnAlgorithmContext.Method.HNSW); + + String key; + Object value; + for (Map.Entry methodEntry : knnMap.entrySet()) { + key = methodEntry.getKey(); + value = methodEntry.getValue(); + if (KNN_METRIC_NAME.equals(key)) { + if (value != null && !(value instanceof String)) { + throw new MapperParsingException(String.format(Locale.ROOT, "[%s] must be a string", KNN_METRIC_NAME)); + } + + if (value != null) { + try { + metric = Metric.fromName((String) value); + } catch (IllegalArgumentException illegalArgumentException) { + throw new MapperParsingException(String.format(Locale.ROOT, "[%s] value [%s] is invalid", key, value)); + } + } + } else if (ALGORITHM.equals(key)) { + if (value == null) { + continue; + } + + if (!(value instanceof Map)) { + throw new MapperParsingException("Unable to parse knn algorithm"); + } + knnAlgorithmContext = KnnAlgorithmContext.parse(value); + } else { + throw new MapperParsingException(String.format(Locale.ROOT, "%s value is invalid", key)); + } + } + return new KnnContext(metric, knnAlgorithmContext); + } + + public Metric getMetric() { + return metric; + } + + public KnnAlgorithmContext getKnnAlgorithmContext() { + return knnAlgorithmContext; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(KNN_METRIC_NAME, metric.name()); + builder.startObject(ALGORITHM); + builder = knnAlgorithmContext.toXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(metric.name()); + this.knnAlgorithmContext.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KnnContext that = (KnnContext) o; + return metric == that.metric && Objects.equals(knnAlgorithmContext, that.knnAlgorithmContext); + } + + @Override + public int hashCode() { + return Objects.hash(metric, knnAlgorithmContext.hashCode()); + } +} diff --git a/server/src/main/java/org/opensearch/index/mapper/Metric.java b/server/src/main/java/org/opensearch/index/mapper/Metric.java new file mode 100644 index 0000000000000..bc41128ef7479 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/Metric.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.apache.lucene.index.VectorSimilarityFunction; + +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +/** + * Abstracts supported metrics for dense_vector and KNN search + * + * @opensearch.internal + */ +enum Metric { + L2, + COSINE, + DOT_PRODUCT; + + private static final Map STRING_TO_METRIC = Map.of("l2", L2, "cosine", COSINE, "dot_product", DOT_PRODUCT); + + private static final Map METRIC_TO_VECTOR_SIMILARITY_FUNCTION = Map.of( + L2, + VectorSimilarityFunction.EUCLIDEAN, + COSINE, + VectorSimilarityFunction.COSINE, + DOT_PRODUCT, + VectorSimilarityFunction.DOT_PRODUCT + ); + + public static Metric fromName(String metricName) { + return Optional.ofNullable(STRING_TO_METRIC.get(metricName.toLowerCase(Locale.ROOT))) + .orElseThrow( + () -> new IllegalArgumentException(String.format(Locale.ROOT, "Provided [metric] %s is not supported", metricName)) + ); + } + + /** + * Convert from our Metric type to Lucene VectorSimilarityFunction type. Only Euclidean metric is supported + */ + public static VectorSimilarityFunction toSimilarityFunction(Metric metric) { + return Optional.ofNullable(METRIC_TO_VECTOR_SIMILARITY_FUNCTION.get(metric)) + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Provided metric %s cannot be converted to vector similarity function", metric.name()) + ) + ); + } +} diff --git a/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java b/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java index c7fdf1e30e6a1..e5ffe799eb90b 100644 --- a/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java +++ b/server/src/main/java/org/opensearch/index/translog/InternalTranslogManager.java @@ -285,9 +285,9 @@ public void ensureCanFlush() { /** * Reads operations from the translog - * @param location + * @param location the location in the translog * @return the translog operation - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ @Override public Translog.Operation readOperation(Translog.Location location) throws IOException { @@ -296,9 +296,9 @@ public Translog.Operation readOperation(Translog.Location location) throws IOExc /** * Adds an operation to the translog - * @param operation + * @param operation the operation in the translog * @return the location in the translog - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ @Override public Translog.Location add(Translog.Operation operation) throws IOException { @@ -396,8 +396,8 @@ public String getTranslogUUID() { /** * - * @param localCheckpointOfLastCommit - * @param flushThreshold + * @param localCheckpointOfLastCommit the localCheckpoint of last commit in the translog + * @param flushThreshold the flush threshold in the translog * @return if the translog should be flushed */ public boolean shouldPeriodicallyFlush(long localCheckpointOfLastCommit, long flushThreshold) { diff --git a/server/src/main/java/org/opensearch/index/translog/TranslogManager.java b/server/src/main/java/org/opensearch/index/translog/TranslogManager.java index f82434f40b06c..fec51d9aa9463 100644 --- a/server/src/main/java/org/opensearch/index/translog/TranslogManager.java +++ b/server/src/main/java/org/opensearch/index/translog/TranslogManager.java @@ -98,15 +98,15 @@ public interface TranslogManager { * Reads operations for the translog * @param location the location in the translog * @return the translog operation - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ Translog.Operation readOperation(Translog.Location location) throws IOException; /** * Adds an operation to the translog - * @param operation + * @param operation translog operation * @return the location in the translog - * @throws IOException + * @throws IOException if an {@link IOException} occurs while executing method */ Translog.Location add(Translog.Operation operation) throws IOException; diff --git a/server/src/main/java/org/opensearch/indices/IndicesModule.java b/server/src/main/java/org/opensearch/indices/IndicesModule.java index 29ff507ad9fcf..88f4c3359bd8a 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesModule.java +++ b/server/src/main/java/org/opensearch/indices/IndicesModule.java @@ -48,6 +48,7 @@ import org.opensearch.index.mapper.CompletionFieldMapper; import org.opensearch.index.mapper.DataStreamFieldMapper; import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.DenseVectorFieldMapper; import org.opensearch.index.mapper.FieldAliasMapper; import org.opensearch.index.mapper.FieldNamesFieldMapper; import org.opensearch.index.mapper.GeoPointFieldMapper; @@ -161,6 +162,7 @@ public static Map getMappers(List mappe mappers.put(CompletionFieldMapper.CONTENT_TYPE, CompletionFieldMapper.PARSER); mappers.put(FieldAliasMapper.CONTENT_TYPE, new FieldAliasMapper.TypeParser()); mappers.put(GeoPointFieldMapper.CONTENT_TYPE, new GeoPointFieldMapper.TypeParser()); + mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, new DenseVectorFieldMapper.TypeParser()); for (MapperPlugin mapperPlugin : mapperPlugins) { for (Map.Entry entry : mapperPlugin.getMappers().entrySet()) { diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java new file mode 100644 index 0000000000000..30061af3e71eb --- /dev/null +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldMapperTests.java @@ -0,0 +1,380 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.IndexableField; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +public class DenseVectorFieldMapperTests extends FieldMapperTestCase2 { + + private static final float[] VECTOR = { 2.0f, 4.5f }; + + public void testValueDisplay() { + KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( + HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 16, HNSW_PARAMETER_BEAM_WIDTH, 100) + ); + KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); + MappedFieldType ft = new DenseVectorFieldType("field", 1, knnContext); + Object actualFloatArray = ft.valueForDisplay(VECTOR); + assertTrue(actualFloatArray instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArray, 0.0f); + } + + public void testSerializationWithoutKnn() throws IOException { + DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dimension", 2))); + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + + ParsedDocument doc = mapper.parse(source(b -> b.field("field", VECTOR))); + IndexableField[] fields = doc.rootDoc().getFields("field"); + assertEquals(1, fields.length); + assertTrue(fields[0] instanceof KnnVectorField); + float[] actualVector = ((KnnVectorField) fields[0]).vectorValue(); + assertArrayEquals(VECTOR, actualVector, 0.0f); + } + + public void testSerializationWithKnn() throws IOException { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "L2", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + + ParsedDocument doc = mapper.parse(source(b -> b.field("field", VECTOR))); + IndexableField[] fields = doc.rootDoc().getFields("field"); + assertEquals(1, fields.length); + assertTrue(fields[0] instanceof KnnVectorField); + float[] actualVector = ((KnnVectorField) fields[0]).vectorValue(); + assertArrayEquals(VECTOR, actualVector, 0.0f); + } + + @Override + protected DenseVectorFieldMapper.Builder newBuilder() { + return new DenseVectorFieldMapper.Builder("dense_vector"); + } + + public void testDeprecatedBoost() throws IOException { + createMapperService(fieldMapping(b -> { + minimalMapping(b); + b.field("boost", 2.0); + })); + String type = typeName(); + String[] warnings = new String[] { + "Parameter [boost] on field [field] is deprecated and will be removed in 8.0", + "Parameter [boost] has no effect on type [" + type + "] and will be removed in future" }; + allowedWarnings(warnings); + } + + public void testIfMinimalSerializesToItself() throws IOException { + XContentBuilder orig = JsonXContent.contentBuilder().startObject(); + createMapperService(fieldMapping(this::minimalMapping)).documentMapper().mapping().toXContent(orig, ToXContent.EMPTY_PARAMS); + orig.endObject(); + XContentBuilder parsedFromOrig = JsonXContent.contentBuilder().startObject(); + createMapperService(orig).documentMapper().mapping().toXContent(parsedFromOrig, ToXContent.EMPTY_PARAMS); + parsedFromOrig.endObject(); + assertEquals(Strings.toString(orig), Strings.toString(parsedFromOrig)); + } + + public void testForEmptyName() { + MapperParsingException e = expectThrows(MapperParsingException.class, () -> createMapperService(mapping(b -> { + b.startObject(""); + minimalMapping(b); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("name cannot be empty string")); + } + + protected void writeFieldValue(XContentBuilder b) throws IOException { + b.value(new float[] { 2.5f }); + } + + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "dense_vector"); + b.field("dimension", 1); + } + + protected void registerParameters(MapperTestCase.ParameterChecker checker) throws IOException {} + + @Override + protected Set unsupportedProperties() { + return org.opensearch.common.collect.Set.of("analyzer", "similarity", "doc_values", "store", "index"); + } + + protected String typeName() throws IOException { + MapperService ms = createMapperService(fieldMapping(this::minimalMapping)); + return ms.fieldType("field").typeName(); + } + + @Override + protected boolean supportsMeta() { + return false; + } + + public void testCosineMetric() throws IOException { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "cosine", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + } + + public void testDotProductMetric() throws IOException { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + + Mapper fieldMapper = mapper.mappers().getMapper("field"); + assertTrue(fieldMapper instanceof DenseVectorFieldMapper); + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; + assertEquals(2, denseVectorFieldMapper.fieldType().getDimension()); + } + + public void testHNSWAlgorithmParametersInvalidInput() throws Exception { + XContentBuilder mappingInvalidMaxConnections = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 256, "beam_width", 50)) + ) + ) + ); + final MapperParsingException mapperExceptionInvalidMaxConnections = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidMaxConnections) + ); + assertEquals("max_connections value cannot be greater than 16", mapperExceptionInvalidMaxConnections.getRootCause().getMessage()); + + XContentBuilder mappingInvalidBeamWidth = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 6, "beam_width", 1024)) + ) + ) + ); + final MapperParsingException mapperExceptionInvalidmBeamWidth = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidBeamWidth) + ); + assertEquals("beam_width value cannot be greater than 512", mapperExceptionInvalidmBeamWidth.getRootCause().getMessage()); + + XContentBuilder mappingUnsupportedParam = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "dot_product", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 6, "beam_width", 256, "some_param", 23)) + ) + ) + ); + final MapperParsingException mapperExceptionUnsupportedParam = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingUnsupportedParam) + ); + assertEquals("Algorithm parameter [some_param] is not supported", mapperExceptionUnsupportedParam.getRootCause().getMessage()); + } + + public void testInvalidMetric() throws Exception { + XContentBuilder mappingInvalidMetric = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field("knn", Map.of("metric", "LAMBDA", "algorithm", Map.of("name", "HNSW"))) + ); + final MapperParsingException mapperExceptionInvalidMetric = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidMetric) + ); + assertEquals("[metric] value [LAMBDA] is invalid", mapperExceptionInvalidMetric.getRootCause().getMessage()); + } + + public void testInvalidAlgorithm() throws Exception { + XContentBuilder mappingInvalidAlgorithm = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field("knn", Map.of("metric", "dot_product", "algorithm", Map.of("name", "MY_ALGORITHM"))) + ); + final MapperParsingException mapperExceptionInvalidAlgorithm = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mappingInvalidAlgorithm) + ); + assertEquals( + "[algorithm name] value [MY_ALGORITHM] is invalid or not supported", + mapperExceptionInvalidAlgorithm.getRootCause().getMessage() + ); + } + + public void testInvalidParams() throws Exception { + XContentBuilder mapping = fieldMapping( + b -> b.field("type", "dense_vector").field("dimension", 2).field("my_field", "some_value").field("knn", Map.of()) + ); + final MapperParsingException mapperParsingException = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mapping) + ); + assertEquals( + "Mapping definition for [field] has unsupported parameters: [my_field : some_value]", + mapperParsingException.getRootCause().getMessage() + ); + } + + public void testExceedMaxNumberOfAlgorithmParams() throws Exception { + Map algorithmParams = new HashMap<>(); + IntStream.range(0, 100).forEach(number -> algorithmParams.put("param" + number, randomInt(Integer.MAX_VALUE))); + XContentBuilder mapping = fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field("knn", Map.of("metric", "dot_product", "algorithm", Map.of("name", "HNSW", "parameters", algorithmParams))) + ); + final MapperParsingException mapperParsingException = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper(mapping) + ); + assertEquals( + "Invalid number of parameters for [algorithm], max allowed is [50] but given [100]", + mapperParsingException.getRootCause().getMessage() + ); + } + + public void testInvalidVectorNumberFormat() throws Exception { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "L2", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + final MapperParsingException mapperExceptionStringAsVectorValue = expectThrows( + MapperParsingException.class, + () -> mapper.parse(source(b -> b.field("field", "some malicious script content"))) + ); + assertEquals( + mapperExceptionStringAsVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'some malicious script content'" + ); + + final MapperParsingException mapperExceptionInfinityVectorValue = expectThrows( + MapperParsingException.class, + () -> mapper.parse(source(b -> b.field("field", new Float[] { Float.POSITIVE_INFINITY }))) + ); + assertEquals( + mapperExceptionInfinityVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'Infinity'" + ); + + final MapperParsingException mapperExceptionNullVectorValue = expectThrows( + MapperParsingException.class, + () -> mapper.parse(source(b -> b.field("field", new Float[] { null }))) + ); + assertEquals( + mapperExceptionNullVectorValue.getMessage(), + "failed to parse field [field] of type [dense_vector] in document with id '1'. Preview of field's value: 'null'" + ); + } + + public void testNullVectorValue() throws Exception { + DocumentMapper mapper = createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dimension", 2) + .field( + "knn", + Map.of( + "metric", + "L2", + "algorithm", + Map.of("name", "HNSW", "parameters", Map.of("max_connections", 16, "beam_width", 100)) + ) + ) + ) + ); + mapper.parse(source(b -> b.field("field", (Float) null))); + mapper.parse(source(b -> b.field("field", VECTOR))); + mapper.parse(source(b -> b.field("field", (Float) null))); + } +} diff --git a/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java new file mode 100644 index 0000000000000..7ae117881a43f --- /dev/null +++ b/server/src/test/java/org/opensearch/index/mapper/DenseVectorFieldTypeTests.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.index.mapper; + +import org.junit.Before; +import org.mockito.Mockito; +import org.opensearch.common.unit.Fuzziness; +import org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType; +import org.opensearch.index.query.QueryShardContext; + +import java.util.Arrays; +import java.util.Map; + +import static org.opensearch.index.mapper.KnnAlgorithmContext.Method.HNSW; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH; +import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS; + +public class DenseVectorFieldTypeTests extends FieldTypeTestCase { + + private static final String FIELD_NAME = "field"; + private static final float[] VECTOR = { 2.0f, 4.5f }; + + private DenseVectorFieldType fieldType; + + @Before + public void setup() throws Exception { + KnnAlgorithmContext knnMethodContext = new KnnAlgorithmContext( + HNSW, + Map.of(HNSW_PARAMETER_MAX_CONNECTIONS, 10, HNSW_PARAMETER_BEAM_WIDTH, 100) + ); + KnnContext knnContext = new KnnContext(Metric.L2, knnMethodContext); + fieldType = new DenseVectorFieldType(FIELD_NAME, 1, knnContext); + } + + public void testValueDisplay() { + Object actualFloatArray = fieldType.valueForDisplay(VECTOR); + assertTrue(actualFloatArray instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArray, 0.0f); + + KnnContext knnContextDEfaultAlgorithmContext = new KnnContext( + Metric.L2, + KnnAlgorithmContextFactory.defaultContext(KnnAlgorithmContext.Method.HNSW) + ); + MappedFieldType ftDefaultAlgorithmContext = new DenseVectorFieldType(FIELD_NAME, 1, knnContextDEfaultAlgorithmContext); + Object actualFloatArrayDefaultAlgorithmContext = ftDefaultAlgorithmContext.valueForDisplay(VECTOR); + assertTrue(actualFloatArrayDefaultAlgorithmContext instanceof float[]); + assertArrayEquals(VECTOR, (float[]) actualFloatArrayDefaultAlgorithmContext, 0.0f); + } + + public void testTermQueryNotSupported() { + QueryShardContext context = Mockito.mock(QueryShardContext.class); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.termsQuery(Arrays.asList(VECTOR), context) + ); + assertEquals(exception.getMessage(), "[term] queries are not supported on [dense_vector] fields."); + } + + public void testPrefixQueryNotSupported() { + UnsupportedOperationException ee = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.prefixQuery("foo*", null, MOCK_QSC_DISALLOW_EXPENSIVE) + ); + assertEquals("[prefix] queries are not supported on [dense_vector] fields.", ee.getMessage()); + } + + public void testRegexpQueryNotSupported() { + UnsupportedOperationException ee = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.regexpQuery("foo?", randomInt(10), 0, randomInt(10) + 1, null, MOCK_QSC_DISALLOW_EXPENSIVE) + ); + assertEquals("[regexp] queries are not supported on [dense_vector] fields.", ee.getMessage()); + } + + public void testWildcardQueryNotSupported() { + UnsupportedOperationException ee = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.wildcardQuery("valu*", null, MOCK_QSC_DISALLOW_EXPENSIVE) + ); + assertEquals("[wildcard] queries are not supported on [dense_vector] fields.", ee.getMessage()); + } + + public void testFuzzyQuery() { + UnsupportedOperationException e = expectThrows( + UnsupportedOperationException.class, + () -> fieldType.fuzzyQuery("foo", Fuzziness.fromEdits(2), 1, 50, true, randomMockShardContext()) + ); + assertEquals("[fuzzy] queries are not supported on [dense_vector] fields.", e.getMessage()); + } +} diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 832328cb0242f..b172be9b6be64 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -91,6 +91,7 @@ import org.opensearch.index.mapper.CompletionFieldMapper; import org.opensearch.index.mapper.ContentPath; import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.DenseVectorFieldMapper; import org.opensearch.index.mapper.FieldAliasMapper; import org.opensearch.index.mapper.FieldMapper; import org.opensearch.index.mapper.GeoPointFieldMapper; @@ -184,6 +185,7 @@ public abstract class AggregatorTestCase extends OpenSearchTestCase { denylist.add(ObjectMapper.NESTED_CONTENT_TYPE); // TODO support for nested denylist.add(CompletionFieldMapper.CONTENT_TYPE); // TODO support completion denylist.add(FieldAliasMapper.CONTENT_TYPE); // TODO support alias + denylist.add(DenseVectorFieldMapper.CONTENT_TYPE); // Cannot aggregate dense_vector TYPE_TEST_DENYLIST = denylist; }