Skip to content

Commit

Permalink
Add new VectorScorer interface to vector value iterators (apache#13181)
Browse files Browse the repository at this point in the history
With both quantized and regular vectors, we currently separate "scoring" from "iteration", requiring the user to always iterate the raw vectors and provide their own similarity function.

While this is flexible, it creates frustration in:

 - Just iterating and scoring, especially since the field already has a similarity function stored. Why can't we just know which one to use and use it?
 - Iterating and scoring quantized vectors. By default it would be good to be able to iterate and score quantized vectors (e.g. without going through the HNSW graph).

This significantly hampers support for true exact kNN search.

This commit extends the vector value iterators so that they can return a scorer given some vector value (this is what the PR demonstrates). The scorer contains a copy of the originating iterator and allows iterating and scoring in the most optimized way the provided codec can offer.

Users can still iterate vector values directly, read them on heap, and score any way they please.
  • Loading branch information
benwtrent authored May 9, 2024
1 parent 8d7e417 commit b60e86c
Show file tree
Hide file tree
Showing 47 changed files with 1,169 additions and 343 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ New Features
* GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add
new Hnsw format for binary quantized vectors. (Ben Trent)

* GITHUB#13181: Add new VectorScorer interface to vector value iterators. This allows for vector codecs to supply
simpler and more optimized vector scoring when iterating vector values directly. (Ben Trent)

Improvements
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
Expand Down Expand Up @@ -272,7 +274,8 @@ private OffHeapFloatVectorValues getOffHeapVectorValues(FieldEntry fieldEntry)
throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new OffHeapFloatVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice);
return new OffHeapFloatVectorValues(
fieldEntry.dimension, fieldEntry.ordToDoc, fieldEntry.similarityFunction, bytesSlice);
}

private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
Expand Down Expand Up @@ -359,14 +362,20 @@ static class OffHeapFloatVectorValues extends FloatVectorValues
final int byteSize;
int lastOrd = -1;
final float[] value;
final VectorSimilarityFunction similarityFunction;

int ord = -1;
int doc = -1;

OffHeapFloatVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) {
OffHeapFloatVectorValues(
int dimension,
int[] ordToDoc,
VectorSimilarityFunction similarityFunction,
IndexInput dataIn) {
this.dimension = dimension;
this.ordToDoc = ordToDoc;
this.dataIn = dataIn;
this.similarityFunction = similarityFunction;

byteSize = Float.BYTES * dimension;
value = new float[dimension];
Expand Down Expand Up @@ -420,7 +429,7 @@ public int advance(int target) {

@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone());
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
}

@Override
Expand All @@ -433,6 +442,22 @@ public float[] vectorValue(int targetOrd) throws IOException {
lastOrd = targetOrd;
return value;
}

/**
 * Builds a {@link VectorScorer} that compares stored vectors against {@code target}.
 *
 * <p>The scorer works on a fresh {@link #copy()} of this instance, which is also the object
 * handed out by {@link VectorScorer#iterator()}.
 *
 * @param target the vector every stored vector is compared against
 * @return a scorer backed by its own copy of these vector values
 */
@Override
public VectorScorer scorer(float[] target) {
  final OffHeapFloatVectorValues scoredValues = copy();
  final VectorSimilarityFunction similarity = scoredValues.similarityFunction;
  return new VectorScorer() {
    @Override
    public DocIdSetIterator iterator() {
      // The copy itself is the iterator that drives this scorer.
      return scoredValues;
    }

    @Override
    public float score() throws IOException {
      final float[] current = scoredValues.vectorValue();
      return similarity.compare(current, target);
    }
  };
}
}

/** Read the nearest-neighbors graph from the index input */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
Expand Down Expand Up @@ -255,7 +257,11 @@ private OffHeapFloatVectorValues getOffHeapVectorValues(FieldEntry fieldEntry)
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new OffHeapFloatVectorValues(
fieldEntry.dimension, fieldEntry.size(), fieldEntry.ordToDoc, bytesSlice);
fieldEntry.dimension,
fieldEntry.size(),
fieldEntry.ordToDoc,
fieldEntry.similarityFunction,
bytesSlice);
}

private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
Expand Down Expand Up @@ -399,16 +405,23 @@ static class OffHeapFloatVectorValues extends FloatVectorValues
private final IndexInput dataIn;
private final int byteSize;
private final float[] value;
private final VectorSimilarityFunction similarityFunction;

private int ord = -1;
private int doc = -1;

OffHeapFloatVectorValues(int dimension, int size, int[] ordToDoc, IndexInput dataIn) {
/**
 * Creates vector values that read floats from {@code dataIn}.
 *
 * @param dimension number of floats per vector; also sizes the reusable {@code value} buffer
 * @param size value stored in the {@code size} field
 * @param ordToDoc ordinal-to-document mapping, or {@code null} to use the identity mapping
 * @param similarityFunction similarity used by scorers created from this instance
 * @param dataIn input the vector data is read from
 */
OffHeapFloatVectorValues(
    int dimension,
    int size,
    int[] ordToDoc,
    VectorSimilarityFunction similarityFunction,
    IndexInput dataIn) {
  this.dimension = dimension;
  this.size = size;
  this.byteSize = Float.BYTES * dimension;
  this.value = new float[dimension];
  this.similarityFunction = similarityFunction;
  this.dataIn = dataIn;
  this.ordToDoc = ordToDoc;
  if (ordToDoc == null) {
    // No explicit mapping: an ordinal maps to itself.
    this.ordToDocOperator = IntUnaryOperator.identity();
  } else {
    this.ordToDocOperator = (ordinal) -> ordToDoc[ordinal];
  }
}
Expand Down Expand Up @@ -468,7 +481,8 @@ public int advance(int target) {

@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone());
return new OffHeapFloatVectorValues(
dimension, size, ordToDoc, similarityFunction, dataIn.clone());
}

@Override
Expand All @@ -477,6 +491,22 @@ public float[] vectorValue(int targetOrd) throws IOException {
dataIn.readFloats(value, 0, value.length);
return value;
}

/**
 * Builds a {@link VectorScorer} that compares stored vectors against {@code target}.
 *
 * <p>The scorer works on a fresh {@link #copy()} of this instance, which is also the object
 * handed out by {@link VectorScorer#iterator()}.
 *
 * @param target the vector every stored vector is compared against
 * @return a scorer backed by its own copy of these vector values
 */
@Override
public VectorScorer scorer(float[] target) {
  final OffHeapFloatVectorValues scoredValues = copy();
  final VectorSimilarityFunction similarity = scoredValues.similarityFunction;
  return new VectorScorer() {
    @Override
    public DocIdSetIterator iterator() {
      // The copy itself is the iterator that drives this scorer.
      return scoredValues;
    }

    @Override
    public float score() throws IOException {
      final float[] current = scoredValues.vectorValue();
      return similarity.compare(current, target);
    }
  };
}
}

/** Read the nearest-neighbors graph from the index input */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import java.io.IOException;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
Expand All @@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;

OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) {
/**
 * Similarity function supplied at construction time; scorers created by subclasses use it to
 * compare a stored vector with the query vector.
 */
protected final VectorSimilarityFunction vectorSimilarityFunction;

/**
 * Initializes the state shared by all off-heap float vector value implementations.
 *
 * @param dimension number of floats per vector; also sizes the reusable {@code value} buffer
 * @param size value stored in the {@code size} field
 * @param vectorSimilarityFunction similarity used when scoring vectors
 * @param slice input the vector data is read from
 */
OffHeapFloatVectorValues(
    int dimension,
    int size,
    VectorSimilarityFunction vectorSimilarityFunction,
    IndexInput slice) {
  this.vectorSimilarityFunction = vectorSimilarityFunction;
  this.slice = slice;
  this.size = size;
  this.dimension = dimension;
  this.byteSize = dimension * Float.BYTES;
  this.value = new float[dimension];
}

@Override
Expand Down Expand Up @@ -75,18 +85,24 @@ static OffHeapFloatVectorValues load(
vectorData.slice(
"vector-data", fieldEntry.vectorDataOffset(), fieldEntry.vectorDataLength());
if (fieldEntry.docsWithFieldOffset() == -1) {
return new DenseOffHeapVectorValues(fieldEntry.dimension(), fieldEntry.size(), bytesSlice);
return new DenseOffHeapVectorValues(
fieldEntry.dimension(), fieldEntry.size(), fieldEntry.similarityFunction(), bytesSlice);
} else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice);
return new SparseOffHeapVectorValues(
fieldEntry, vectorData, fieldEntry.similarityFunction(), bytesSlice);
}
}

static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {

private int doc = -1;

public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
super(dimension, size, slice);
/**
 * Creates the values by forwarding every argument, unchanged, to the
 * {@link OffHeapFloatVectorValues} constructor.
 *
 * @param dimension number of floats per vector
 * @param size size passed on to the superclass
 * @param vectorSimilarityFunction similarity passed on to the superclass
 * @param slice input passed on to the superclass
 */
public DenseOffHeapVectorValues(
int dimension,
int size,
VectorSimilarityFunction vectorSimilarityFunction,
IndexInput slice) {
super(dimension, size, vectorSimilarityFunction, slice);
}

@Override
Expand Down Expand Up @@ -115,13 +131,29 @@ public int advance(int target) throws IOException {

@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone());
}

/**
 * Returns {@code acceptDocs} as-is, without any translation.
 *
 * <p>NOTE(review): presumably valid because ordinals and doc IDs coincide for this dense
 * implementation; confirm against the hidden {@code ordToDoc} code.
 */
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}

/**
 * Builds a {@link VectorScorer} that compares stored vectors against {@code query}.
 *
 * <p>The scorer works on a fresh {@link #copy()} of this instance, which is also the object
 * handed out by {@link VectorScorer#iterator()}.
 *
 * @param query the vector every stored vector is compared against
 * @return a scorer backed by its own copy of these vector values
 * @throws IOException if creating the copy fails
 */
@Override
public VectorScorer scorer(float[] query) throws IOException {
  final DenseOffHeapVectorValues scoredValues = copy();
  final VectorSimilarityFunction similarity = scoredValues.vectorSimilarityFunction;
  return new VectorScorer() {
    @Override
    public DocIdSetIterator iterator() {
      // The copy itself is the iterator that drives this scorer.
      return scoredValues;
    }

    @Override
    public float score() throws IOException {
      final float[] current = scoredValues.vectorValue();
      return similarity.compare(current, query);
    }
  };
}
}

private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
Expand All @@ -132,10 +164,13 @@ private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues
private final Lucene92HnswVectorsReader.FieldEntry fieldEntry;

public SparseOffHeapVectorValues(
Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice)
Lucene92HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
VectorSimilarityFunction vectorSimilarityFunction,
IndexInput slice)
throws IOException {

super(fieldEntry.dimension(), fieldEntry.size(), slice);
super(fieldEntry.dimension(), fieldEntry.size(), vectorSimilarityFunction, slice);
this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
Expand Down Expand Up @@ -173,8 +208,9 @@ public int advance(int target) throws IOException {
}

@Override
public OffHeapFloatVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
fieldEntry, dataIn, vectorSimilarityFunction, slice.clone());
}

@Override
Expand All @@ -199,12 +235,28 @@ public int length() {
}
};
}

/**
 * Builds a {@link VectorScorer} that compares stored vectors against {@code query}.
 *
 * <p>The scorer works on a fresh {@link #copy()} of this instance, which is also the object
 * handed out by {@link VectorScorer#iterator()}.
 *
 * @param query the vector every stored vector is compared against
 * @return a scorer backed by its own copy of these vector values
 * @throws IOException if creating the copy fails
 */
@Override
public VectorScorer scorer(float[] query) throws IOException {
  final SparseOffHeapVectorValues scoredValues = copy();
  final VectorSimilarityFunction similarity = scoredValues.vectorSimilarityFunction;
  return new VectorScorer() {
    @Override
    public DocIdSetIterator iterator() {
      // The copy itself is the iterator that drives this scorer.
      return scoredValues;
    }

    @Override
    public float score() throws IOException {
      final float[] current = scoredValues.vectorValue();
      return similarity.compare(current, query);
    }
  };
}
}

private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {

public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null);
super(dimension, 0, VectorSimilarityFunction.COSINE, null);
}

private int doc = -1;
Expand Down Expand Up @@ -258,5 +310,10 @@ public int ordToDoc(int ord) {
// Always returns null; the acceptDocs argument is ignored.
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}

/**
 * Always returns {@code null}; the {@code query} argument is ignored. This class passes a size
 * of 0 to its superclass, so there are no vectors to score.
 *
 * <p>NOTE(review): callers presumably must handle a {@code null} scorer; confirm against the
 * contract documented on {@code FloatVectorValues#scorer}.
 */
@Override
public VectorScorer scorer(float[] query) {
return null;
}
}
}
Loading

0 comments on commit b60e86c

Please sign in to comment.