Skip to content

Commit

Permalink
Refactored KNNVectorScriptDocValues to support both float and byte ve…
Browse files Browse the repository at this point in the history
…ctors

Signed-off-by: Bansi Kasundra <[email protected]>
  • Loading branch information
kasundra07 committed Jan 9, 2025
1 parent 4ab4cb4 commit bdfc4f8
Show file tree
Hide file tree
Showing 14 changed files with 459 additions and 266 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public long ramBytesUsed() {
}

@Override
public ScriptDocValues<float[]> getScriptValues() {
public ScriptDocValues<?> getScriptValues() {
try {
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName);
if (fieldInfo == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.opensearch.index.fielddata.ScriptDocValues;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
public abstract class KNNVectorScriptDocValues<T> extends ScriptDocValues<T> {

private final DocIdSetIterator vectorValues;
private final String fieldName;
Expand All @@ -42,7 +42,7 @@ public void setNextDocId(int docId) throws IOException {
docExists = lastDocID == curDocID;
}

public float[] getValue() {
public T getValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
Expand All @@ -60,37 +60,15 @@ public float[] getValue() {
}
}

public byte[] getByteValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
+ "This can be avoided by checking if a document has a value for the field or not "
+ "by doc['%s'].size() == 0 ? 0 : {your script}",
fieldName,
fieldName
);
throw new IllegalStateException(errorMessage);
}
try {
return doGetByteValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected byte[] doGetByteValue() throws IOException {
throw new UnsupportedOperationException();
}

protected abstract float[] doGetValue() throws IOException;
protected abstract T doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
}

@Override
public float[] get(int i) {
public T get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

Expand All @@ -103,20 +81,20 @@ public float[] get(int i) {
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
public static KNNVectorScriptDocValues<?> create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
return new KNNNativeVectorScriptDocValues<>((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues<byte[]> {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
Expand All @@ -125,17 +103,7 @@ private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptD
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
}
return value;
}

@Override
public byte[] doGetByteValue() {
protected byte[] doGetValue() throws IOException {
try {
return values.vectorValue();
} catch (IOException e) {
Expand All @@ -144,7 +112,7 @@ public byte[] doGetByteValue() {
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues<float[]> {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
Expand All @@ -158,7 +126,7 @@ protected float[] doGetValue() throws IOException {
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNNativeVectorScriptDocValues<T> extends KNNVectorScriptDocValues<T> {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
Expand All @@ -167,18 +135,9 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip
}

@Override
protected float[] doGetValue() throws IOException {
protected T doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}

@Override
public byte[] doGetByteValue() {
try {
return values.binaryValue().bytes;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

/**
Expand All @@ -188,10 +147,18 @@ public byte[] doGetByteValue() {
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
public static KNNVectorScriptDocValues<?> emptyValues(String fieldName, VectorDataType type) {
if (type == VectorDataType.FLOAT) {
return new KNNVectorScriptDocValues<float[]>(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
return new KNNVectorScriptDocValues<byte[]>(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
protected byte[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
Expand Down
24 changes: 5 additions & 19 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
public byte[] getVectorFromBytesRef(BytesRef binaryValue) {
return binaryValue.bytes;
}

@Override
Expand All @@ -75,15 +68,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
public byte[] getVectorFromBytesRef(BytesRef binaryValue) {
return binaryValue.bytes;
}

@Override
Expand Down Expand Up @@ -143,7 +129,7 @@ public void freeNativeMemory(long memoryAddress) {
* @param binaryValue Binary Value
* @return float vector deserialized from binary value
*/
public abstract float[] getVectorFromBytesRef(BytesRef binaryValue);
public abstract <T> T getVectorFromBytesRef(BytesRef binaryValue);

/**
* @param trainingDataAllocation training data that has been allocated in native memory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
* KNNVectors with float[] type. The query value passed in is expected to be float[]. The fieldType of the docs
* being searched over are expected to be KNNVector type.
*/
public static class KNNVectorType extends KNNScoreScript<float[]> {
public static class KNNFloatVectorType extends KNNScoreScript<float[]> {

public KNNVectorType(
public KNNFloatVectorType(
Map<String, Object> params,
float[] queryValue,
String field,
Expand All @@ -136,8 +136,9 @@ public KNNVectorType(
* @return score of the vector to the query vector
*/
@Override
@SuppressWarnings("unchecked")
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
KNNVectorScriptDocValues<float[]> scriptDocValues = (KNNVectorScriptDocValues<float[]>) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
Expand Down Expand Up @@ -171,12 +172,13 @@ public KNNByteVectorType(
* @return score of the vector to the query vector
*/
@Override
@SuppressWarnings("unchecked")
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
KNNVectorScriptDocValues<byte[]> scriptDocValues = (KNNVectorScriptDocValues<byte[]>) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getByteValue());
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue());
}
}
}
Loading

0 comments on commit bdfc4f8

Please sign in to comment.