Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding 'dense_vector' field type #3659

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.search.knn;

import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.VersionUtils;

import java.util.Map;

import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;

public class DenseVectorIT extends OpenSearchIntegTestCase {

private static final float[] VECTOR_ONE = { 2.0f, 4.5f, 5.6f, 4.2f };
private static final float[] VECTOR_TWO = { 4.0f, 2.5f, 1.6f, 2.2f };

@Override
protected boolean forbidPrivateIndexSettings() {
return false;
}

public void testIndexingSingleDocumentWithoutKnn() throws Exception {
Version version = VersionUtils.randomIndexCompatibleVersion(random());
Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build();
XContentBuilder defaultMapping = jsonBuilder().startObject()
.startObject("properties")
.startObject("vector_field")
.field("type", "dense_vector")
.field("dimension", 4)
.endObject()
.endObject()
.endObject();
assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping));
ensureGreen();

indexRandom(
true,
client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject())
);
ensureSearchable("test");
}

public void testIndexingSingleDocumentWithDefaultKnnParams() throws Exception {
Version version = VersionUtils.randomIndexCompatibleVersion(random());
Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build();
XContentBuilder defaultMapping = jsonBuilder().startObject()
.startObject("properties")
.startObject("vector_field")
.field("type", "dense_vector")
.field("dimension", 4)
.field("knn", Map.of())
.endObject()
.endObject()
.endObject();
assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping));
ensureGreen();

indexRandom(
true,
client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject())
);
ensureSearchable("test");
}

public void testIndexingMultipleDocumentsWithHnswDefinition() throws Exception {
Version version = VersionUtils.randomIndexCompatibleVersion(random());
Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version).build();
XContentBuilder defaultMapping = jsonBuilder().startObject()
.startObject("properties")
.startObject("field")
.field("type", "dense_vector")
.field("dimension", 4)
.field(
"knn",
Map.of("metric", "l2", "algorithm", Map.of("name", "hnsw", "parameters", Map.of("max_connections", 12, "beam_width", 256)))
)
.endObject()
.endObject()
.endObject();
assertAcked(prepareCreate("test").setSettings(settings).setMapping(defaultMapping));
ensureGreen();

indexRandom(
true,
client().prepareIndex("test").setId("1").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_ONE).endObject()),
client().prepareIndex("test").setId("2").setSource(jsonBuilder().startObject().field("vector_field", VECTOR_TWO).endObject())
);
ensureSearchable("test");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.index.codec;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene92.Lucene92Codec;
import org.apache.lucene.codecs.lucene92.Lucene92HnswVectorsFormat;
import org.opensearch.index.mapper.KnnAlgorithmContext;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;

import java.util.Map;

import static org.opensearch.index.mapper.DenseVectorFieldMapper.DenseVectorFieldType;
import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_BEAM_WIDTH;
import static org.opensearch.index.mapper.KnnAlgorithmContextFactory.HNSW_PARAMETER_MAX_CONNECTIONS;

/**
* Factory that creates a {@link KnnVectorsFormat knn vector format} based on a mapping
* configuration for the field.
*
* @opensearch.internal
*/
public class KnnVectorFormatFactory {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT?

Suggested change
public class KnnVectorFormatFactory {
public class KNNVectorFormatFactory {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure to be honest - we're wrapping Lucene classes and they have notation as "Knn" (e.g. https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java), so I've follow them


private final MapperService mapperService;

public KnnVectorFormatFactory(MapperService mapperService) {
this.mapperService = mapperService;
}

/**
* Create KnnVectorsFormat with parameters specified in the field definition or return codec's default
* Knn Vector Format if field is not of DenseVector type
* @param field name of the field
* @return KnnVectorFormat that is specific to a mapped field
*/
public KnnVectorsFormat create(final String field) {
final MappedFieldType mappedFieldType = mapperService.fieldType(field);
if (isDenseVectorFieldType(mappedFieldType)) {
final DenseVectorFieldType knnVectorFieldType = (DenseVectorFieldType) mappedFieldType;
final KnnAlgorithmContext algorithmContext = knnVectorFieldType.getKnnContext().getKnnAlgorithmContext();
final Map<String, Object> methodParams = algorithmContext.getParameters();
int maxConnections = getIntegerParam(methodParams, HNSW_PARAMETER_MAX_CONNECTIONS);
int beamWidth = getIntegerParam(methodParams, HNSW_PARAMETER_BEAM_WIDTH);
final KnnVectorsFormat luceneHnswVectorsFormat = new Lucene92HnswVectorsFormat(maxConnections, beamWidth);
return luceneHnswVectorsFormat;
}
return Lucene92Codec.getDefault().knnVectorsFormat();
}

private boolean isDenseVectorFieldType(final MappedFieldType mappedFieldType) {
return mappedFieldType instanceof DenseVectorFieldType;
}

private int getIntegerParam(Map<String, Object> methodParams, String name) {
return (Integer) methodParams.get(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.lucene92.Lucene92Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
Expand All @@ -57,16 +58,18 @@ public class PerFieldMappingPostingFormatCodec extends Lucene92Codec {
private final Logger logger;
private final MapperService mapperService;
private final DocValuesFormat dvFormat = new Lucene90DocValuesFormat();
private final KnnVectorFormatFactory knnVectorsFormatFactory;

static {
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class)
: "PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
: "PerFieldMappingPostingFormatCodec must subclass the latest lucene codec: " + Lucene.LATEST_CODEC;
}

public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService, Logger logger) {
super(compressionMode);
this.mapperService = mapperService;
this.logger = logger;
this.knnVectorsFormatFactory = new KnnVectorFormatFactory(mapperService);
}

@Override
Expand All @@ -84,4 +87,9 @@ public PostingsFormat getPostingsFormatForField(String field) {
public DocValuesFormat getDocValuesFormatForField(String field) {
return dvFormat;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return knnVectorsFormatFactory.create(field);
}
}
Loading