From e8e554e1060ec44ae758609c1f066835206bf770 Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Fri, 10 May 2024 14:47:23 +0100 Subject: [PATCH] add manual test --- .../vec/VectorScorerFactoryTests.java | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java b/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java index 115cf8e8cf9f8..246ddaeb2ebcf 100644 --- a/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java +++ b/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java @@ -8,6 +8,8 @@ package org.elasticsearch.vec; +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -17,6 +19,8 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Objects; +import java.util.Random; import java.util.function.Function; import static org.elasticsearch.vec.VectorSimilarityType.COSINE; @@ -226,6 +230,67 @@ void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, Functi } } + // Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow + @Nightly + public void testLarge() throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir(getTestName()))) { + final int dims = 8192; + final int size = 262144; + final float correction = randomFloat(); + + String fileName = getTestName() + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var off = (float) i; + out.writeBytes(vec, 0, vec.length); + out.writeInt(Float.floatToIntBits(off)); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + float off0 = (float) idx0; + float off1 = (float) idx1; + // dot product + float expected = luceneScore(DOT_PRODUCT, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); + var scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, DOT_PRODUCT, in).get(); + assertThat(scorer.score(idx0, idx1), equalTo(expected)); + assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); + // max inner product + expected = luceneScore(MAXIMUM_INNER_PRODUCT, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); + scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, MAXIMUM_INNER_PRODUCT, in).get(); + assertThat(scorer.score(idx0, idx1), equalTo(expected)); + assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); + // cosine + expected = luceneScore(COSINE, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); + scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, COSINE, in).get(); + assertThat(scorer.score(idx0, idx1), equalTo(expected)); + assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); + // euclidean + expected = luceneScore(EUCLIDEAN, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); + scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, EUCLIDEAN, in).get(); + assertThat(scorer.score(idx0, idx1), equalTo(expected)); + assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); + } + } + } + } + + // creates the vector based on the given ordinal, which is reproducible given the ord and dims + static byte[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + byte[] ba = new byte[dims]; + for (int i = 0; i < dims; i++) { + ba[i] = (byte) RandomNumbers.randomIntBetween(random, MIN_INT7_VALUE, MAX_INT7_VALUE); + } + return ba; + } + static Function BYTE_ARRAY_RANDOM_INT7_FUNC = size -> { byte[] ba = new byte[size]; randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE);