Skip to content

Commit

Permalink
Add float|byte vector support to memory index (#13633)
Browse files Browse the repository at this point in the history
* Add float|byte vector support to memory index

* adding changes
  • Loading branch information
benwtrent authored Aug 7, 2024
1 parent 8221018 commit 9e831ee
Show file tree
Hide file tree
Showing 3 changed files with 408 additions and 2 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ Improvements

* GITHUB#13285: Early terminate graph searches of AbstractVectorSimilarityQuery to follow timeout set from
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)

* GITHUB#13633: Add ability to read/write knn vector values to a MemoryIndex. (Ben Trent)

Optimizations
---------------------
Expand Down
238 changes: 236 additions & 2 deletions lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,19 @@
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
Expand Down Expand Up @@ -636,6 +640,10 @@ public void addField(IndexableField field, Analyzer analyzer) {
if (field.fieldType().stored()) {
storeValues(info, field);
}

if (field.fieldType().vectorDimension() > 0) {
storeVectorValues(info, field);
}
}

/**
Expand Down Expand Up @@ -749,6 +757,56 @@ private void storePointValues(Info info, BytesRef pointValue) {
info.pointValues[info.pointValuesCount++] = BytesRef.deepCopyOf(pointValue);
}

private void storeVectorValues(Info info, IndexableField vectorField) {
assert vectorField instanceof KnnFloatVectorField || vectorField instanceof KnnByteVectorField;
switch (info.fieldInfo.getVectorEncoding()) {
case BYTE -> {
if (vectorField instanceof KnnByteVectorField byteVectorField) {
if (info.byteVectorCount == 1) {
throw new IllegalArgumentException(
"Only one value per field allowed for byte vector field ["
+ vectorField.name()
+ "]");
}
info.byteVectorCount++;
if (info.byteVectorValues == null) {
info.byteVectorValues = new byte[1][];
}
info.byteVectorValues[0] =
ArrayUtil.copyOfSubArray(
byteVectorField.vectorValue(), 0, info.fieldInfo.getVectorDimension());
return;
}
throw new IllegalArgumentException(
"Field ["
+ vectorField.name()
+ "] is not a byte vector field, but the field info is configured for byte vectors");
}
case FLOAT32 -> {
if (vectorField instanceof KnnFloatVectorField floatVectorField) {
if (info.floatVectorCount == 1) {
throw new IllegalArgumentException(
"Only one value per field allowed for float vector field ["
+ vectorField.name()
+ "]");
}
info.floatVectorCount++;
if (info.floatVectorValues == null) {
info.floatVectorValues = new float[1][];
}
info.floatVectorValues[0] =
ArrayUtil.copyOfSubArray(
floatVectorField.vectorValue(), 0, info.fieldInfo.getVectorDimension());
return;
}
throw new IllegalArgumentException(
"Field ["
+ vectorField.name()
+ "] is not a float vector field, but the field info is configured for float vectors");
}
}
}

private void storeValues(Info info, IndexableField field) {
if (info.storedValues == null) {
info.storedValues = new ArrayList<>();
Expand Down Expand Up @@ -1148,6 +1206,18 @@ private final class Info {

private BytesRef[] pointValues;

/** Number of float vectors added for this field */
private int floatVectorCount;

/** the float vectors added for this field */
private float[][] floatVectorValues;

/** Number of byte vectors added for this field */
private int byteVectorCount;

/** the byte vectors added for this field */
private byte[][] byteVectorValues;

private byte[] minPackedValue;

private byte[] maxPackedValue;
Expand Down Expand Up @@ -1641,12 +1711,20 @@ public PointValues getPointValues(String fieldName) {

@Override
public FloatVectorValues getFloatVectorValues(String fieldName) {
return null;
Info info = fields.get(fieldName);
if (info == null || info.floatVectorValues == null) {
return null;
}
return new MemoryFloatVectorValues(info);
}

@Override
public ByteVectorValues getByteVectorValues(String fieldName) {
return null;
Info info = fields.get(fieldName);
if (info == null || info.byteVectorValues == null) {
return null;
}
return new MemoryByteVectorValues(info);
}

@Override
Expand Down Expand Up @@ -2204,4 +2282,160 @@ public int[] clear() {
return super.clear();
}
}

private static final class MemoryFloatVectorValues extends FloatVectorValues {
private final Info info;
private int currentDoc = -1;

MemoryFloatVectorValues(Info info) {
this.info = info;
}

@Override
public int dimension() {
return info.fieldInfo.getVectorDimension();
}

@Override
public int size() {
return info.floatVectorCount;
}

@Override
public float[] vectorValue() {
if (currentDoc == 0) {
return info.floatVectorValues[0];
} else {
return null;
}
}

@Override
public VectorScorer scorer(float[] query) {
if (query.length != info.fieldInfo.getVectorDimension()) {
throw new IllegalArgumentException(
"query vector dimension "
+ query.length
+ " does not match field dimension "
+ info.fieldInfo.getVectorDimension());
}
MemoryFloatVectorValues vectorValues = new MemoryFloatVectorValues(info);
return new VectorScorer() {
@Override
public float score() throws IOException {
return info.fieldInfo
.getVectorSimilarityFunction()
.compare(vectorValues.vectorValue(), query);
}

@Override
public DocIdSetIterator iterator() {
return vectorValues;
}
};
}

@Override
public int docID() {
return currentDoc;
}

@Override
public int nextDoc() {
int doc = ++currentDoc;
if (doc == 0) {
return doc;
} else {
return NO_MORE_DOCS;
}
}

@Override
public int advance(int target) {
if (target == 0) {
currentDoc = target;
return target;
} else {
return NO_MORE_DOCS;
}
}
}

private static final class MemoryByteVectorValues extends ByteVectorValues {
private final Info info;
private int currentDoc = -1;

MemoryByteVectorValues(Info info) {
this.info = info;
}

@Override
public int dimension() {
return info.fieldInfo.getVectorDimension();
}

@Override
public int size() {
return info.byteVectorCount;
}

@Override
public byte[] vectorValue() {
if (currentDoc == 0) {
return info.byteVectorValues[0];
} else {
return null;
}
}

@Override
public VectorScorer scorer(byte[] query) {
if (query.length != info.fieldInfo.getVectorDimension()) {
throw new IllegalArgumentException(
"query vector dimension "
+ query.length
+ " does not match field dimension "
+ info.fieldInfo.getVectorDimension());
}
MemoryByteVectorValues vectorValues = new MemoryByteVectorValues(info);
return new VectorScorer() {
@Override
public float score() {
return info.fieldInfo
.getVectorSimilarityFunction()
.compare(vectorValues.vectorValue(), query);
}

@Override
public DocIdSetIterator iterator() {
return vectorValues;
}
};
}

@Override
public int docID() {
return currentDoc;
}

@Override
public int nextDoc() {
int doc = ++currentDoc;
if (doc == 0) {
return doc;
} else {
return NO_MORE_DOCS;
}
}

@Override
public int advance(int target) {
if (target == 0) {
currentDoc = target;
return target;
} else {
return NO_MORE_DOCS;
}
}
}
}
Loading

0 comments on commit 9e831ee

Please sign in to comment.