Skip to content

Commit

Permalink
Use dot_product as default similarity for dense_vector fields
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Sep 6, 2023
1 parent 595e69f commit 0760f9c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ setup:
type: dense_vector
dims: 5
index: true
similarity: cosine
similarity: dot_product

---
"Indexed by default with specified similarity and index options":
Expand All @@ -42,7 +42,7 @@ setup:
vector:
type: dense_vector
dims: 5
similarity: dot_product
similarity: l2_norm
index_options:
type: hnsw
m: 32
Expand All @@ -62,7 +62,7 @@ setup:
type: dense_vector
dims: 5
index: true
similarity: dot_product
similarity: l2_norm
index_options:
type: hnsw
m: 32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ private static IndexVersion registerIndexVersion(int id, Version luceneVersion,
*/
public static final IndexVersion V_8_500_000 = registerIndexVersion(8_500_000, Version.LUCENE_9_7_0, "bf656f5e-5808-4eee-bf8a-e2bf6736ff55");
public static final IndexVersion V_8_500_001 = registerIndexVersion(8_500_001, Version.LUCENE_9_7_0, "45045a5a-fc57-4462-89f6-6bc04cda6015");
public static final IndexVersion V_8_500_002 = registerIndexVersion(8_500_002, Version.LUCENE_9_7_0, "5c49ca52-dd9e-4ca8-a201-94cb42c922d4");
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down Expand Up @@ -150,7 +151,7 @@ private static IndexVersion registerIndexVersion(int id, Version luceneVersion,
*/

private static class CurrentHolder {
private static final IndexVersion CURRENT = findCurrent(V_8_500_001);
private static final IndexVersion CURRENT = findCurrent(V_8_500_002);

// finds the pluggable current version, or uses the given fallback
private static IndexVersion findCurrent(IndexVersion fallback) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersion.V_8_11_0;
public static final IndexVersion DOT_PRODUCT_AUTO_NORMALIZED = IndexVersion.V_8_11_0;
public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersion.V_8_9_0;
public static final IndexVersion DOT_PRODUCT_DEFAULT_SIMILARITY = IndexVersion.V_8_500_002;

public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; // maximum allowed number of dimensions
Expand Down Expand Up @@ -145,7 +146,7 @@ public Builder(String name, IndexVersion indexVersionCreated) {
"similarity",
false,
m -> toType(m).similarity,
(Supplier<VectorSimilarity>) () -> indexedByDefault && indexed.getValue() ? VectorSimilarity.COSINE : null,
(Supplier<VectorSimilarity>) () -> indexedByDefault && indexed.getValue() ? defaultSimilarity(indexVersionCreated) : null,
VectorSimilarity.class
).acceptsNull().setSerializerCheck((id, ic, v) -> v != null);
this.indexed.addValidator(v -> {
Expand Down Expand Up @@ -198,6 +199,10 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
}
}

private static VectorSimilarity defaultSimilarity(IndexVersion indexVersionCreated) {
return indexVersionCreated.onOrAfter(DOT_PRODUCT_DEFAULT_SIMILARITY) ? VectorSimilarity.DOT_PRODUCT : VectorSimilarity.COSINE;
}

private static FieldType getDenseVectorFieldType(
int dimension,
VectorEncoding vectorEncoding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.search.lookup.Source;
import org.elasticsearch.search.lookup.SourceProvider;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.index.IndexVersionUtils;
import org.elasticsearch.xcontent.XContentBuilder;
import org.junit.AssumptionViolatedException;

Expand All @@ -53,6 +54,7 @@

import static org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.INDEXED_BY_DEFAULT_INDEX_VERSION;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand All @@ -79,7 +81,7 @@ protected void minimalMapping(XContentBuilder b) throws IOException {

@Override
protected void minimalMapping(XContentBuilder b, IndexVersion indexVersion) throws IOException {
indexMapping(b, indexVersion.onOrAfter(DenseVectorFieldMapper.INDEXED_BY_DEFAULT_INDEX_VERSION));
indexMapping(b, indexVersion.onOrAfter(INDEXED_BY_DEFAULT_INDEX_VERSION));
}

private void indexMapping(XContentBuilder b, boolean indexedByDefault) throws IOException {
Expand All @@ -92,7 +94,10 @@ private void indexMapping(XContentBuilder b, boolean indexedByDefault) throws IO
b.field("index", indexed);
}
if (indexed) {
b.field("similarity", "dot_product");
if (indexedByDefault == false) {
// Add similarity when it's required
b.field("similarity", "dot_product");
}
if (indexOptionsSet) {
b.startObject("index_options");
b.field("type", "hnsw");
Expand Down Expand Up @@ -242,7 +247,7 @@ public void testDims() {
public void testDefaults() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dims", 3)));

testIndexedVector(VectorSimilarity.COSINE, mapper);
testIndexedVector(VectorSimilarity.DOT_PRODUCT, mapper);
}

public void testIndexedVector() throws Exception {
Expand Down Expand Up @@ -517,11 +522,20 @@ public void testParamsBeforeIndexByDefault() throws Exception {
assertEquals(VectorSimilarity.DOT_PRODUCT, denseVectorFieldType.getSimilarity());
}

public void testDefaultParamsIndexByDefault() throws Exception {
public void testDefaultParamsIndexedByDefault() throws Exception {
DocumentMapper documentMapper = createDocumentMapper(fieldMapping(b -> { b.field("type", "dense_vector").field("dims", 3); }));
DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) documentMapper.mappers().getMapper("field");
DenseVectorFieldType denseVectorFieldType = denseVectorFieldMapper.fieldType();

assertTrue(denseVectorFieldType.isIndexed());
assertEquals(VectorSimilarity.DOT_PRODUCT, denseVectorFieldType.getSimilarity());
}

public void testDefaultParamsBeforeDotProductNormalization() throws Exception {
DocumentMapper documentMapper = createDocumentMapper(INDEXED_BY_DEFAULT_INDEX_VERSION, fieldMapping(b -> { b.field("type", "dense_vector").field("dims", 3); }));
DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) documentMapper.mappers().getMapper("field");
DenseVectorFieldType denseVectorFieldType = denseVectorFieldMapper.fieldType();

assertTrue(denseVectorFieldType.isIndexed());
assertEquals(VectorSimilarity.COSINE, denseVectorFieldType.getSimilarity());
}
Expand Down

0 comments on commit 0760f9c

Please sign in to comment.