Skip to content

Commit

Permalink
Removing redundant conversion for hamming space for binary vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
kasundra07 committed Dec 17, 2024
1 parent d57fdea commit e7d655d
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
Expand All @@ -20,6 +22,7 @@
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private static final Logger logger = LogManager.getLogger(KNNVectorScriptDocValues.class);
private final DocIdSetIterator vectorValues;
private final String fieldName;
@Getter
Expand Down Expand Up @@ -60,6 +63,28 @@ public float[] getValue() {
}
}

public byte[] getByteValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
+ "This can be avoided by checking if a document has a value for the field or not "
+ "by doc['%s'].size() == 0 ? 0 : {your script}",
fieldName,
fieldName
);
throw new IllegalStateException(errorMessage);
}
try {
return doGetByteValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected byte[] doGetByteValue() throws IOException {
throw new UnsupportedOperationException();
}

protected abstract float[] doGetValue() throws IOException;

@Override
Expand Down Expand Up @@ -111,6 +136,16 @@ protected float[] doGetValue() throws IOException {
}
return value;
}

@Override
public byte[] doGetByteValue() {
try {
logger.info("KNNByteVectorScriptDocValues getByteValue called");
return values.vectorValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
Expand Down Expand Up @@ -139,6 +174,16 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}

@Override
public byte[] doGetByteValue() {
try {
logger.info("KNNNativeVectorScriptDocValues getByteValue called");
return values.binaryValue().bytes;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

/**
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.plugin.script;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -14,6 +16,7 @@

import java.io.IOException;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Map;
import java.util.function.BiFunction;

Expand All @@ -23,6 +26,7 @@
* only concerned with the types of the query and docs being processed.
*/
public abstract class KNNScoreScript<T> extends ScoreScript {
private static final Logger logger = LogManager.getLogger(KNNScoreScript.class);
protected final T queryValue;
protected final String field;
protected final BiFunction<T, T, Float> scoringMethod;
Expand Down Expand Up @@ -144,4 +148,35 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue());
}
}

public static class KNNByteVectorType extends KNNScoreScript<byte[]> {

public KNNByteVectorType(
Map<String, Object> params,
byte[] queryValue,
String field,
BiFunction<byte[], byte[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
* This function called for each doc in the segment. We evaluate the score of the vector in the doc
*
* @param explanationHolder A helper to take in an explanation from a script and turn
* it into an {@link org.apache.lucene.search.Explanation}
* @return score of the vector to the query vector
*/
@Override
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getByteValue());
}
}
}
52 changes: 38 additions & 14 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.knn.plugin.script;

import lombok.Getter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -26,13 +28,17 @@

import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryVectorDataType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToBigInteger;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {
public static final Logger logger = LogManager.getLogger(KNNScoringSpace.class);

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -181,25 +187,43 @@ protected BiFunction<float[], float[], Float> getScoringMethod(final float[] pro
}
}

class Hamming extends KNNFieldSpace {
private static final Set<VectorDataType> DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY);
class Hamming implements KNNScoringSpace {
private byte[] processedQuery;
BiFunction<byte[], byte[], Float> scoringMethod;

public Hamming(Object query, MappedFieldType fieldType) {
super(query, fieldType, "hamming", DATA_TYPES_HAMMING);
}
if (!isKNNVectorFieldType(fieldType)) {
throw new IllegalArgumentException(
"Incompatible field_type for hamming space. The field type must be knn_vector."
);
}
KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType;
if (!isBinaryVectorDataType(knnVectorFieldType)){
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Incompatible field_type for hamming space. The data type should be binary but got %s",
knnVectorFieldType.getVectorDataType()
)
);
}

@Override
protected BiFunction<float[], float[], Float> getScoringMethod(final float[] processedQuery) {
// TODO we want to avoid converting back and forth between byte and float
return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(toByte(q), toByte(v)));
this.processedQuery = parseToByteArray(query, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), knnVectorFieldType.getVectorDataType());
this.scoringMethod = (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v));
}

private byte[] toByte(final float[] vector) {
byte[] bytes = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
bytes[i] = (byte) vector[i];
}
return bytes;
@Override
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNByteVectorType(
params,
this.processedQuery,
field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ public static float[] parseToFloatArray(Object object, int expectedVectorLength,
return floatArray;
}

public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) {
byte[] byteArray = convertVectorToByteArray(object, vectorDataType);
if (expectedVectorLength != byteArray.length) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalStateException(
"Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "."
);
}
return byteArray;
}

/**
* Converts Object vector to primitive float[]
*
Expand All @@ -134,6 +145,23 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec
return primitiveVector;
}

@SuppressWarnings("unchecked")
public static byte[] convertVectorToByteArray(Object vector, VectorDataType vectorDataType) {
byte[] primitiveVector = null;
if (vector != null) {
final List<Number> tmp = (List<Number>) vector;
primitiveVector = new byte[tmp.size()];
for (int i = 0; i < primitiveVector.length; i++) {
float value = tmp.get(i).floatValue();
if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) {
validateByteVectorValue(value, vectorDataType);
}
primitiveVector[i] = tmp.get(i).byteValue();
}
}
return primitiveVector;
}

/**
* Calculates the magnitude of given vector
*
Expand Down

0 comments on commit e7d655d

Please sign in to comment.