diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 7053e6151..33a4b1fe5 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -39,7 +39,7 @@ public long ramBytesUsed() { } @Override - public ScriptDocValues getScriptValues() { + public ScriptDocValues getScriptValues() { try { FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName); if (fieldInfo == null) { diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 146177ba9..03109fb1a 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -18,7 +18,7 @@ import org.opensearch.index.fielddata.ScriptDocValues; @RequiredArgsConstructor(access = AccessLevel.PRIVATE) -public abstract class KNNVectorScriptDocValues extends ScriptDocValues { +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { private final DocIdSetIterator vectorValues; private final String fieldName; @@ -42,7 +42,7 @@ public void setNextDocId(int docId) throws IOException { docExists = lastDocID == curDocID; } - public float[] getValue() { + public T getValue() { if (!docExists) { String errorMessage = String.format( "One of the document doesn't have a value for field '%s'. " @@ -60,29 +60,7 @@ public float[] getValue() { } } - public byte[] getByteValue() { - if (!docExists) { - String errorMessage = String.format( - "One of the document doesn't have a value for field '%s'. " - + "This can be avoided by checking if a document has a value for the field or not " - + "by doc['%s'].size() == 0 ? 0 : {your script}", - fieldName, - fieldName - ); - throw new IllegalStateException(errorMessage); - } - try { - return doGetByteValue(); - } catch (IOException e) { - throw ExceptionsHelper.convertToOpenSearchException(e); - } - } - - protected byte[] doGetByteValue() throws IOException { - throw new UnsupportedOperationException(); - } - - protected abstract float[] doGetValue() throws IOException; + protected abstract T doGetValue() throws IOException; @Override public int size() { @@ -90,7 +68,7 @@ public int size() { } @Override - public float[] get(int i) { + public T get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } @@ -103,20 +81,20 @@ public float[] get(int i) { * @return A KNNVectorScriptDocValues object based on the type of the values. * @throws IllegalArgumentException If the type of values is unsupported. */ - public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { Objects.requireNonNull(values, "values must not be null"); if (values instanceof ByteVectorValues) { return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); } else if (values instanceof FloatVectorValues) { return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); } else if (values instanceof BinaryDocValues) { - return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); + return new KNNNativeVectorScriptDocValues<>((BinaryDocValues) values, fieldName, vectorDataType); } else { throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); } } - private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { private final ByteVectorValues values; KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { @@ -125,17 +103,7 @@ private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptD } @Override - protected float[] doGetValue() throws IOException { - byte[] bytes = values.vectorValue(); - float[] value = new float[bytes.length]; - for (int i = 0; i < bytes.length; i++) { - value[i] = (float) bytes[i]; - } - return value; - } - - @Override - public byte[] doGetByteValue() { + protected byte[] doGetValue() throws IOException { try { return values.vectorValue(); } catch (IOException e) { @@ -144,7 +112,7 @@ public byte[] doGetByteValue() { } } - private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { private final FloatVectorValues values; KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { @@ -158,7 +126,7 @@ protected float[] doGetValue() throws IOException { } } - private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { private final BinaryDocValues values; KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { @@ -167,18 +135,9 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip } @Override - protected float[] doGetValue() throws IOException { + protected T doGetValue() throws IOException { return getVectorDataType().getVectorFromBytesRef(values.binaryValue()); } - - @Override - public byte[] doGetByteValue() { - try { - return values.binaryValue().bytes; - } catch (IOException e) { - throw ExceptionsHelper.convertToOpenSearchException(e); - } - } } /** @@ -188,10 +147,18 @@ public byte[] doGetByteValue() { * @param type The data type of the vector. * @return An empty KNNVectorScriptDocValues object. */ - public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { - return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + if (type == VectorDataType.FLOAT) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { @Override - protected float[] doGetValue() throws IOException { + protected byte[] doGetValue() throws IOException { throw new UnsupportedOperationException("empty values"); } }; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index e97bd2dbf..001acaae8 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -46,15 +46,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc } @Override - public float[] getVectorFromBytesRef(BytesRef binaryValue) { - float[] vector = new float[binaryValue.length]; - int i = 0; - int j = binaryValue.offset; - - while (i < binaryValue.length) { - vector[i++] = binaryValue.bytes[j++]; - } - return vector; + public byte[] getVectorFromBytesRef(BytesRef binaryValue) { + return binaryValue.bytes; } @Override @@ -75,15 +68,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc } @Override - public float[] getVectorFromBytesRef(BytesRef binaryValue) { - float[] vector = new float[binaryValue.length]; - int i = 0; - int j = binaryValue.offset; - - while (i < binaryValue.length) { - vector[i++] = binaryValue.bytes[j++]; - } - return vector; + public byte[] getVectorFromBytesRef(BytesRef binaryValue) { + return binaryValue.bytes; } @Override @@ -143,7 +129,7 @@ public void freeNativeMemory(long memoryAddress) { * @param binaryValue Binary Value * @return float vector deserialized from binary value */ - public abstract float[] getVectorFromBytesRef(BytesRef binaryValue); + public abstract T getVectorFromBytesRef(BytesRef binaryValue); /** * @param trainingDataAllocation training data that has been allocated in native memory diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java index 4b2a2b598..1817ca73a 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java @@ -114,9 +114,9 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) { * KNNVectors with float[] type. The query value passed in is expected to be float[]. The fieldType of the docs * being searched over are expected to be KNNVector type. */ - public static class KNNVectorType extends KNNScoreScript { + public static class KNNFloatVectorType extends KNNScoreScript { - public KNNVectorType( + public KNNFloatVectorType( Map params, float[] queryValue, String field, @@ -136,8 +136,9 @@ public KNNVectorType( * @return score of the vector to the query vector */ @Override + @SuppressWarnings("unchecked") public double execute(ScoreScript.ExplanationHolder explanationHolder) { - KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); + KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); if (scriptDocValues.isEmpty()) { return 0.0; } @@ -171,12 +172,13 @@ public KNNByteVectorType( * @return score of the vector to the query vector */ @Override + @SuppressWarnings("unchecked") public double execute(ScoreScript.ExplanationHolder explanationHolder) { - KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); + KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); if (scriptDocValues.isEmpty()) { return 0.0; } - return this.scoringMethod.apply(this.queryValue, scriptDocValues.getByteValue()); + return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue()); } } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index f59748c39..33907dd3a 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -27,16 +27,14 @@ import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType; -import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryVectorDataType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToBigInteger; -import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray; +import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong; public interface KNNScoringSpace { - /** * Return the correct scoring script for a given query. The scoring script * @@ -57,9 +55,9 @@ ScoreScript getScoreScript(Map params, String field, SearchLooku abstract class KNNFieldSpace implements KNNScoringSpace { public static final Set DATA_TYPES_DEFAULT = Set.of(VectorDataType.FLOAT, VectorDataType.BYTE); - private float[] processedQuery; + private Object processedQuery; @Getter - private BiFunction scoringMethod; + private BiFunction scoringMethod; public KNNFieldSpace(final Object query, final MappedFieldType fieldType, final String spaceName) { this(query, fieldType, spaceName, DATA_TYPES_DEFAULT); @@ -76,6 +74,7 @@ public KNNFieldSpace( this.scoringMethod = getScoringMethod(this.processedQuery, knnVectorFieldType.getKnnMappingConfig().getIndexCreatedVersion()); } + @SuppressWarnings("unchecked") public ScoreScript getScoreScript( Map params, String field, @@ -83,7 +82,27 @@ public ScoreScript getScoreScript( LeafReaderContext ctx, IndexSearcher searcher ) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); + if (processedQuery instanceof float[]) { + return new KNNScoreScript.KNNFloatVectorType( + params, + (float[]) this.processedQuery, + field, + (BiFunction) this.scoringMethod, + lookup, + ctx, + searcher + ); + } else { + return new KNNScoreScript.KNNByteVectorType( + params, + (byte[]) this.processedQuery, + field, + (BiFunction) this.scoringMethod, + lookup, + ctx, + searcher + ); + } } private KNNVectorFieldType toKNNVectorFieldType( @@ -116,17 +135,27 @@ private KNNVectorFieldType toKNNVectorFieldType( return knnVectorFieldType; } - protected float[] getProcessedQuery(final Object query, final KNNVectorFieldType knnVectorFieldType) { - return parseToFloatArray( + protected Object getProcessedQuery(final Object query, final KNNVectorFieldType knnVectorFieldType) { + VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType() == null + ? VectorDataType.FLOAT + : knnVectorFieldType.getVectorDataType(); + if (vectorDataType.equals(VectorDataType.FLOAT)) { + return parseToFloatArray( + query, + KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), + knnVectorFieldType.getVectorDataType() + ); + } + return parseToByteArray( query, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), knnVectorFieldType.getVectorDataType() ); } - protected abstract BiFunction getScoringMethod(final float[] processedQuery); + public abstract BiFunction getScoringMethod(final Object processedQuery); - protected BiFunction getScoringMethod(final float[] processedQuery, Version indexCreatedVersion) { + protected BiFunction getScoringMethod(final Object processedQuery, Version indexCreatedVersion) { return getScoringMethod(processedQuery); } @@ -138,8 +167,12 @@ public L2(final Object query, final MappedFieldType fieldType) { } @Override - public BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); + } else { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); + } } } @@ -149,30 +182,35 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(float[] processedQuery) { + public BiFunction getScoringMethod(Object processedQuery) { return getScoringMethod(processedQuery, Version.CURRENT); } @Override - protected BiFunction getScoringMethod(final float[] processedQuery, Version indexCreatedVersion) { - SpaceType.COSINESIMIL.validateVector(processedQuery); - float qVectorSquaredMagnitude = getVectorMagnitudeSquared(processedQuery); - if (indexCreatedVersion.onOrAfter(Version.V_2_19_0)) { - // To be consistent, we will be using same formula used by lucene as mentioned below - // https://github.com/apache/lucene/blob/0494c824e0ac8049b757582f60d085932a890800/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L73 - // for indices that are created on or after 2.19.0 - // - // OS Score = ( 2 − cosineSimil) / 2 - // However cosineSimil = 1 - cos θ, after applying this to above formula, - // OS Score = ( 2 − ( 1 − cos θ ) ) / 2 - // which simplifies to - // OS Score = ( 1 + cos θ ) / 2 - return (float[] q, float[] v) -> Math.max( - ((1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude)) / 2.0F), - 0 - ); + protected BiFunction getScoringMethod(final Object processedQuery, Version indexCreatedVersion) { + if (processedQuery instanceof float[]) { + SpaceType.COSINESIMIL.validateVector((float[]) processedQuery); + float qVectorSquaredMagnitude = getVectorMagnitudeSquared((float[]) processedQuery); + if (indexCreatedVersion.onOrAfter(Version.V_2_19_0)) { + // To be consistent, we will be using same formula used by lucene as mentioned below + // https://github.com/apache/lucene/blob/0494c824e0ac8049b757582f60d085932a890800/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L73 + // for indices that are created on or after 2.19.0 + // + // OS Score = ( 2 − cosineSimil) / 2 + // However cosineSimil = 1 - cos θ, after applying this to above formula, + // OS Score = ( 2 − ( 1 − cos θ ) ) / 2 + // which simplifies to + // OS Score = ( 1 + cos θ ) / 2 + return (float[] q, float[] v) -> Math.max( + ((1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude)) / 2.0F), + 0 + ); + } + return (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); + } else { + SpaceType.COSINESIMIL.validateVector((byte[]) processedQuery); + return (byte[] q, byte[] v) -> 1 + KNNScoringUtil.cosinesimil(q, v); } - return (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } } @@ -182,8 +220,12 @@ public L1(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); + } else { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); + } } } @@ -193,8 +235,12 @@ public LInf(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); + } else { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); + } } } @@ -204,51 +250,25 @@ public InnerProd(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); + } else { + return (byte[] q, byte[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); + } } } - class Hamming implements KNNScoringSpace { - private byte[] processedQuery; - BiFunction scoringMethod; + class Hamming extends KNNFieldSpace { + private static final Set DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY); public Hamming(Object query, MappedFieldType fieldType) { - if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for hamming space. The field type must be knn_vector."); - } - KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType; - if (!isBinaryVectorDataType(knnVectorFieldType)) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Incompatible field_type for hamming space. The data type should be [BINARY] but got %s", - knnVectorFieldType.getVectorDataType() - ) - ); - } - - this.processedQuery = parseToByteArray( - query, - KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), - knnVectorFieldType.getVectorDataType() - ); - this.scoringMethod = getHammingScoringMethod(); - } - - public BiFunction getHammingScoringMethod() { - return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v)); + super(query, fieldType, "hamming", DATA_TYPES_HAMMING); } @Override - public ScoreScript getScoreScript( - Map params, - String field, - SearchLookup lookup, - LeafReaderContext ctx, - IndexSearcher searcher - ) throws IOException { - return new KNNScoreScript.KNNByteVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); + public BiFunction getScoringMethod(final Object processedQuery) { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v)); } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 17e47ba8e..b4789e145 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -111,24 +111,6 @@ public static float[] parseToFloatArray(Object object, int expectedVectorLength, return floatArray; } - /** - * Convert an Object to a byte array. - * - * @param object Object to be converted to a byte array - * @param expectedVectorLength int representing the expected vector length of this array. - * @return byte[] of the object - */ - public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) { - byte[] byteArray = convertVectorToByteArray(object, vectorDataType); - if (expectedVectorLength != byteArray.length) { - KNNCounter.SCRIPT_QUERY_ERRORS.increment(); - throw new IllegalStateException( - "Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "." - ); - } - return byteArray; - } - /** * Converts Object vector to primitive float[] * @@ -152,6 +134,24 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec return primitiveVector; } + /** + * Convert an Object to a byte array. + * + * @param object Object to be converted to a byte array + * @param expectedVectorLength int representing the expected vector length of this array. + * @return byte[] of the object + */ + public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) { + byte[] byteArray = convertVectorToByteArray(object, vectorDataType); + if (expectedVectorLength != byteArray.length) { + KNNCounter.SCRIPT_QUERY_ERRORS.increment(); + throw new IllegalStateException( + "Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "." + ); + } + return byteArray; + } + /** * Converts Object vector to byte[] * diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index f61ae4349..e14058e00 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -99,6 +99,18 @@ public static float l2Squared(float[] queryVector, float[] inputVector) { return VectorUtil.squareDistance(queryVector, inputVector); } + /** + * This method calculates L2 squared distance between byte query vector + * and byte input vector + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return L2 score + */ + public static float l2Squared(byte[] queryVector, byte[] inputVector) { + return VectorUtil.squareDistance(queryVector, inputVector); + } + private static float[] toFloat(final List inputVector, final VectorDataType vectorDataType) { Objects.requireNonNull(inputVector); float[] value = new float[inputVector.size()]; @@ -144,6 +156,23 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) { } } + /** + * This method calculates cosine similarity + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return cosine score + */ + public static float cosinesimil(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + try { + return VectorUtil.cosine(queryVector, inputVector); + } catch (IllegalArgumentException | AssertionError e) { + logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); + return 0.0f; + } + } + /** * This method can be used script to avoid repeated calculation of normalization * for query vector for each filtered documents @@ -222,6 +251,24 @@ public static float l1Norm(float[] queryVector, float[] inputVector) { return distance; } + /** + * This method calculates L1 distance between byte query vector + * and byte input vector + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return L1 score + */ + public static float l1Norm(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + float distance = 0; + for (int i = 0; i < inputVector.length; i++) { + float diff = queryVector[i] - inputVector[i]; + distance += Math.abs(diff); + } + return distance; + } + /** * This method calculates L-inf distance between query vector * and input vector @@ -240,6 +287,24 @@ public static float lInfNorm(float[] queryVector, float[] inputVector) { return distance; } + /** + * This method calculates L-inf distance between byte query vector + * and input vector + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return L-inf score + */ + public static float lInfNorm(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + float distance = 0; + for (int i = 0; i < inputVector.length; i++) { + float diff = queryVector[i] - inputVector[i]; + distance = Math.max(Math.abs(diff), distance); + } + return distance; + } + /** * This method calculates dot product distance between query vector * and input vector @@ -253,6 +318,19 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { return VectorUtil.dotProduct(queryVector, inputVector); } + /** + * This method calculates dot product distance between byte query vector + * and byte input vector + * + * @param queryVector query vector + * @param inputVector input vector + * @return dot product score + */ + public static float innerProduct(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + return VectorUtil.dotProduct(queryVector, inputVector); + } + /** ********************************************************************************************* * Functions to be used in painless script which is defined in knn_allowlist.txt @@ -275,9 +353,13 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { * @param docValues script doc values * @return L2 score */ - public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("l2Squared", docValues.getVectorDataType()); - return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("l2Squared", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return l2Squared(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -296,9 +378,13 @@ public static float l2Squared(List queryVector, KNNVectorScriptDocValues * @param docValues script doc values * @return L-inf score */ - public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("lInfNorm", docValues.getVectorDataType()); - return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("lInfNorm", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return lInfNorm(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -317,9 +403,13 @@ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues * @param docValues script doc values * @return L1 score */ - public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("l1Norm", docValues.getVectorDataType()); - return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("l1Norm", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return l1Norm(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -338,9 +428,13 @@ public static float l1Norm(List queryVector, KNNVectorScriptDocValues do * @param docValues script doc values * @return inner product score */ - public static float innerProduct(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("innerProduct", docValues.getVectorDataType()); - return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float innerProduct(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("innerProduct", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return innerProduct(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -359,11 +453,18 @@ public static float innerProduct(List queryVector, KNNVectorScriptDocVal * @param docValues script doc values * @return cosine score */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("cosineSimilarity", docValues.getVectorDataType()); - float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimil(inputVector, docValues.getValue()); + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("cosineSimilarity", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimil(inputVector, (float[]) docValues.getValue()); + } else { + byte[] inputVector = toByte(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimil(inputVector, (byte[]) docValues.getValue()); + } } /** @@ -383,11 +484,21 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo * @param queryVectorMagnitude the magnitude of the query vector. * @return cosine score */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - requireNonBinaryType("cosineSimilarity", docValues.getVectorDataType()); + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("cosineSimilarity", vectorDataType); float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); + if (VectorDataType.FLOAT == vectorDataType) { + return cosinesimilOptimized(inputVector, (float[]) docValues.getValue(), queryVectorMagnitude.floatValue()); + } else { + byte[] docVectorInByte = (byte[]) docValues.getValue(); + float[] docVectorInFloat = new float[docVectorInByte.length]; + for (int i = 0; i < docVectorInByte.length; i++) { + docVectorInFloat[i] = docVectorInByte[i]; + } + return cosinesimilOptimized(inputVector, docVectorInFloat, queryVectorMagnitude.floatValue()); + } } /** @@ -406,17 +517,9 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo * @param docValues script doc values * @return hamming score */ - public static float hamming(List queryVector, KNNVectorScriptDocValues docValues) { + public static float hamming(List queryVector, KNNVectorScriptDocValues docValues) { requireBinaryType("hamming", docValues.getVectorDataType()); byte[] queryVectorInByte = toByte(queryVector, docValues.getVectorDataType()); - - // TODO Optimization need be done for doc value to return byte[] instead of float[] - float[] docVectorInFloat = docValues.getValue(); - byte[] docVectorInByte = new byte[docVectorInFloat.length]; - for (int i = 0; i < docVectorInByte.length; i++) { - docVectorInByte[i] = (byte) docVectorInFloat[i]; - } - - return calculateHammingBit(queryVectorInByte, docVectorInByte); + return calculateHammingBit(queryVectorInByte, (byte[]) docValues.getValue()); } } diff --git a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt index 388cdda8a..0462cab03 100644 --- a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt +++ b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt @@ -4,7 +4,7 @@ # Painless definition of classes used by knn plugin class org.opensearch.knn.index.KNNVectorScriptDocValues { - float[] getValue() + Object getValue() } static_import { float l2Squared(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index cbe11dd6b..2649888a5 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -61,20 +61,22 @@ public void tearDown() throws Exception { directory.close(); } + @SuppressWarnings("unchecked") public void testGetScriptValues() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( leafReaderContext.reader(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT ); - ScriptDocValues scriptValues = leafFieldData.getScriptValues(); + ScriptDocValues scriptValues = (ScriptDocValues) leafFieldData.getScriptValues(); assertNotNull(scriptValues); assertTrue(scriptValues instanceof KNNVectorScriptDocValues); } + @SuppressWarnings("unchecked") public void testGetScriptValuesWrongFieldName() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid", VectorDataType.FLOAT); - ScriptDocValues scriptValues = leafFieldData.getScriptValues(); + ScriptDocValues scriptValues = (ScriptDocValues) leafFieldData.getScriptValues(); assertNotNull(scriptValues); } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 66e2893c0..70e3a5caa 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -33,15 +33,16 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 }; - private KNNVectorScriptDocValues scriptDocValues; + private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; + private Class valuesClass; @Before public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); + valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); createKNNVectorDocument(directory, valuesClass); reader = DirectoryReader.open(directory); LeafReader leafReader = reader.getContext().leaves().get(0).reader(); @@ -86,9 +87,14 @@ public void tearDown() throws Exception { directory.close(); } + @SuppressWarnings("unchecked") public void testGetValue() throws IOException { scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + if (ByteVectorValues.class.equals(valuesClass)) { + Assert.assertArrayEquals(SAMPLE_BYTE_VECTOR_DATA, ((KNNVectorScriptDocValues) scriptDocValues).getValue()); + } else { + Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, ((KNNVectorScriptDocValues) scriptDocValues).getValue(), 0.1f); + } } // Test getValue without calling setNextDocId diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 73af608c1..b056c5e6e 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -32,7 +32,7 @@ public class VectorDataTypeTests extends KNNTestCase { @SneakyThrows public void testGetDocValuesWithFloatVectorDataType() { - KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues(); + KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues(); scriptDocValues.setNextDocId(0); Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); @@ -43,35 +43,37 @@ public void testGetDocValuesWithFloatVectorDataType() { @SneakyThrows public void testGetDocValuesWithByteVectorDataType() { - KNNVectorScriptDocValues scriptDocValues = getKNNByteVectorScriptDocValues(); + KNNVectorScriptDocValues scriptDocValues = getKNNByteVectorScriptDocValues(); scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + Assert.assertArrayEquals(SAMPLE_BYTE_VECTOR_DATA, scriptDocValues.getValue()); reader.close(); directory.close(); } + @SuppressWarnings("unchecked") @SneakyThrows - private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { + private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { directory = newDirectory(); createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return KNNVectorScriptDocValues.create( + return (KNNVectorScriptDocValues) KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT ); } + @SuppressWarnings("unchecked") @SneakyThrows - private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { + private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { directory = newDirectory(); createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return KNNVectorScriptDocValues.create( + return (KNNVectorScriptDocValues) KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE @@ -110,8 +112,7 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException public void testGetVectorFromBytesRef_whenBinary_thenException() { byte[] vector = { 1, 2, 3 }; - float[] expected = { 1, 2, 3 }; BytesRef bytesRef = new BytesRef(vector); - assertArrayEquals(expected, VectorDataType.BINARY.getVectorFromBytesRef(bytesRef), 0.01f); + assertArrayEquals(vector, VectorDataType.BINARY.getVectorFromBytesRef(bytesRef)); } } diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index 8151107e0..ef7afcd99 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -69,33 +69,54 @@ public void testKNNL2ScriptScore() throws Exception { testKNNScriptScore(SpaceType.L2); } + public void testKNNL2ByteScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.L2); + } + public void testKNNL1ScriptScore() throws Exception { testKNNScriptScore(SpaceType.L1); } + public void testKNNL1ByteScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.L1); + } + public void testKNNLInfScriptScore() throws Exception { testKNNScriptScore(SpaceType.LINF); } + public void testKNNLInfByteScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.LINF); + } + public void testKNNCosineScriptScore() throws Exception { testKNNScriptScore(SpaceType.COSINESIMIL); } + public void testKNNByteCosineScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.COSINESIMIL); + } + @SneakyThrows public void testKNNHammingScriptScore() { testKNNScriptScoreOnBinaryIndex(SpaceType.HAMMING); } + @SuppressWarnings("unchecked") @SneakyThrows public void testKNNHammingScriptScore_whenNonBinary_thenException() { final int dims = randomIntBetween(2, 10) * 8; final float[] queryVector = randomVector(dims, VectorDataType.BYTE); - final BiFunction scoreFunction = getHammingScoreFunction(); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + SpaceType.HAMMING, + queryVector, + VectorDataType.BINARY + ); List nonBinary = List.of(VectorDataType.FLOAT, VectorDataType.BYTE); for (VectorDataType vectorDataType : nonBinary) { Exception e = expectThrows( Exception.class, - () -> createIndexAndAssertHammingScriptScore( + () -> createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, vectorDataType), SpaceType.HAMMING, scoreFunction, @@ -110,18 +131,23 @@ public void testKNNHammingScriptScore_whenNonBinary_thenException() { } } + @SuppressWarnings("unchecked") public void testKNNNonHammingScriptScore_whenBinary_thenException() { final int dims = randomIntBetween(2, 10) * 8; final float[] queryVector = randomVector(dims, VectorDataType.BINARY); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + SpaceType.HAMMING, + queryVector, + VectorDataType.BINARY + ); Set spaceTypeToExclude = Set.of(SpaceType.UNDEFINED, SpaceType.HAMMING); Arrays.stream(SpaceType.values()).filter(s -> spaceTypeToExclude.contains(s) == false).forEach(s -> { - final BiFunction scoreFunction = getScoreFunction(s, queryVector); Exception e = expectThrows( Exception.class, - () -> createIndexAndAssertScriptScore( + () -> createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, VectorDataType.BINARY), s, - v -> scoreFunction.apply(queryVector, v), + scoreFunction, dims, queryVector, true, @@ -625,6 +651,7 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { assertEquals(1, secondQueryCacheMap.get("hit_count")); } + @SuppressWarnings("unchecked") public void testKNNScriptScoreOnModelBasedIndex() throws Exception { int dimensions = randomIntBetween(2, 10); String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions); @@ -661,7 +688,11 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { continue; } final float[] queryVector = randomVector(dimensions); - final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.FLOAT + ); createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); } } @@ -688,6 +719,30 @@ private List createMappers(int dimensions) throws Exception { ); } + private List createByteMappers(int dimensions) throws Exception { + return List.of( + createKnnIndexMapping(FIELD_NAME, dimensions, VectorDataType.BYTE), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + true, + VectorDataType.BYTE + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false, + VectorDataType.BYTE + ) + ); + } + private List createBinaryIndexMappers(int dimensions) throws Exception { return List.of( createKnnIndexMapping( @@ -724,6 +779,15 @@ private float[] randomVector(final int dimensions, final VectorDataType vectorDa return vector; } + private byte[] randomByteVector(final int dimensions, final VectorDataType vectorDataType) { + int size = VectorDataType.BINARY == vectorDataType ? dimensions / 8 : dimensions; + final byte[] vector = new byte[size]; + for (int i = 0; i < size; i++) { + vector[i] = randomByte(); + } + return vector; + } + private Map createDataset( Function scoreFunction, int dimensions, @@ -745,29 +809,7 @@ private Map createDataset( return dataset; } - private BiFunction getHammingScoreFunction() { - final int dims = randomIntBetween(2, 10); - final float[] queryVector = randomVector(dims, VectorDataType.BINARY); - final SpaceType spaceType = SpaceType.HAMMING; - List target = new ArrayList<>(queryVector.length); - for (float f : queryVector) { - target.add(f); - } - KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create( - spaceType.getValue(), - target, - new KNNVectorFieldType( - FIELD_NAME, - Collections.emptyMap(), - VectorDataType.BINARY, - getMappingConfigForFlatMapping(queryVector.length * 8) - ) - ); - - return ((KNNScoringSpace.Hamming) knnScoringSpace).getHammingScoringMethod(); - } - - private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { + private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector, VectorDataType vectorDataType) { List target = new ArrayList<>(queryVector.length); for (float f : queryVector) { target.add(f); @@ -778,8 +820,8 @@ private BiFunction getScoreFunction(SpaceType spaceType new KNNVectorFieldType( FIELD_NAME, Collections.emptyMap(), - SpaceType.HAMMING == spaceType ? VectorDataType.BINARY : VectorDataType.FLOAT, - getMappingConfigForFlatMapping(SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length) + vectorDataType, + getMappingConfigForFlatMapping(vectorDataType.equals(VectorDataType.BINARY) ? queryVector.length * 8 : queryVector.length) ) ); switch (spaceType) { @@ -788,35 +830,64 @@ private BiFunction getScoreFunction(SpaceType spaceType case LINF: case COSINESIMIL: case INNER_PRODUCT: - return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(); + case HAMMING: + if (vectorDataType.equals(VectorDataType.FLOAT)) { + return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(queryVector); + } + return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(toByte(queryVector)); default: throw new IllegalArgumentException(); } } + @SuppressWarnings("unchecked") private void testKNNScriptScore(SpaceType spaceType) throws Exception { final int dims = randomIntBetween(2, 10); final float[] queryVector = randomVector(dims); - final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.FLOAT + ); for (String mapper : createMappers(dims)) { createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true); createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false); } } + @SuppressWarnings("unchecked") + private void testKNNByteScriptScore(SpaceType spaceType) throws Exception { + final int dims = randomIntBetween(2, 10); + final float[] queryVector = randomVector(dims, VectorDataType.BYTE); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.BYTE + ); + for (String mapper : createByteMappers(dims)) { + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true, true, VectorDataType.BYTE); + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false, true, VectorDataType.BYTE); + } + } + + @SuppressWarnings("unchecked") private void testKNNScriptScoreOnBinaryIndex(SpaceType spaceType) throws Exception { final int dims = randomIntBetween(2, 10) * 8; final float[] queryVector = randomVector(dims, VectorDataType.BINARY); - final BiFunction scoreFunction = getHammingScoreFunction(); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.BINARY + ); // Test when knn is enabled and engine is Faiss for (String mapper : createBinaryIndexMappers(dims)) { - createIndexAndAssertHammingScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true, true, VectorDataType.BINARY); - createIndexAndAssertHammingScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false, true, VectorDataType.BINARY); + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true, true, VectorDataType.BINARY); + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false, true, VectorDataType.BINARY); } // Test when knn is disabled and engine is default(Nmslib) - createIndexAndAssertHammingScriptScore( + createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, VectorDataType.BINARY), spaceType, scoreFunction, @@ -826,7 +897,7 @@ private void testKNNScriptScoreOnBinaryIndex(SpaceType spaceType) throws Excepti false, VectorDataType.BINARY ); - createIndexAndAssertHammingScriptScore( + createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, VectorDataType.BINARY), spaceType, scoreFunction, @@ -911,7 +982,7 @@ private void createIndexAndAssertScriptScore( } } - private void createIndexAndAssertHammingScriptScore( + private void createIndexAndAssertByteScriptScore( String mapper, SpaceType spaceType, BiFunction scoreFunction, diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 7141553cd..c571053e0 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -55,6 +55,7 @@ private void expectThrowsExceptionWithKNNFieldWithBinaryDataType(Class clazz) th } @SneakyThrows + @SuppressWarnings("unchecked") public void testL2_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); @@ -66,7 +67,12 @@ public void testL2_whenValid_thenSucceed() { getMappingConfigForMethodMapping(knnMethodContext, 3) ); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); - assertEquals(1F, l2.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); + float[] processedFloatQuery = (float[]) l2.getProcessedQuery(arrayListQueryObject, fieldType); + assertEquals( + 1F, + ((BiFunction) l2.getScoringMethod(processedFloatQuery)).apply(arrayFloat, arrayFloat), + 0.1F + ); } @SneakyThrows @@ -75,6 +81,7 @@ public void testL2_whenInvalidType_thenException() { expectThrowsExceptionWithKNNFieldWithBinaryDataType(KNNScoringSpace.L2.class); } + @SuppressWarnings("unchecked") public void testCosineSimilarity_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); @@ -87,9 +94,10 @@ public void testCosineSimilarity_whenValid_thenSucceed() { getMappingConfigForMethodMapping(knnMethodContext, 3) ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); + float[] processedFloatQuery = (float[]) cosineSimilarity.getProcessedQuery(arrayListQueryObject, fieldType); assertEquals( VectorSimilarityFunction.COSINE.compare(arrayFloat2, arrayFloat), - cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), + ((BiFunction) cosineSimilarity.getScoringMethod(processedFloatQuery)).apply(arrayFloat2, arrayFloat), 0.1F ); @@ -131,6 +139,7 @@ public void testCosineSimilarity_whenInvalidType_thenException() { expectThrowsExceptionWithKNNFieldWithBinaryDataType(KNNScoringSpace.CosineSimilarity.class); } + @SuppressWarnings("unchecked") public void testInnerProd_whenValid_thenSucceed() { float[] arrayFloat_case1 = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject_case1 = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); @@ -145,23 +154,45 @@ public void testInnerProd_whenValid_thenSucceed() { ); KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); - assertEquals(7.0F, innerProd.getScoringMethod().apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); + float[] processedFloatQuery_case1 = (float[]) innerProd.getProcessedQuery(arrayListQueryObject_case1, fieldType); + assertEquals( + 7.0F, + ((BiFunction) innerProd.getScoringMethod(processedFloatQuery_case1)).apply( + arrayFloat_case1, + arrayFloat2_case1 + ), + 0.001F + ); float[] arrayFloat_case2 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; List arrayListQueryObject_case2 = new ArrayList<>(Arrays.asList(100_000.0, 200_000.0, 300_000.0)); float[] arrayFloat2_case2 = new float[] { -100_000.0f, -200_000.0f, -300_000.0f }; innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case2, fieldType); - - assertEquals(7.142857143E-12F, innerProd.getScoringMethod().apply(arrayFloat_case2, arrayFloat2_case2), 1.0E-11F); + float[] processedFloatQuery_case2 = (float[]) innerProd.getProcessedQuery(arrayListQueryObject_case2, fieldType); + assertEquals( + 7.142857143E-12F, + ((BiFunction) innerProd.getScoringMethod(processedFloatQuery_case2)).apply( + arrayFloat_case2, + arrayFloat2_case2 + ), + 1.0E-11F + ); float[] arrayFloat_case3 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; List arrayListQueryObject_case3 = new ArrayList<>(Arrays.asList(100_000.0, 200_000.0, 300_000.0)); float[] arrayFloat2_case3 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case3, fieldType); - - assertEquals(140_000_000_001F, innerProd.getScoringMethod().apply(arrayFloat_case3, arrayFloat2_case3), 0.01F); + float[] processedFloatQuery_case3 = (float[]) innerProd.getProcessedQuery(arrayListQueryObject_case3, fieldType); + assertEquals( + 140_000_000_001F, + ((BiFunction) innerProd.getScoringMethod(processedFloatQuery_case3)).apply( + arrayFloat_case3, + arrayFloat2_case3 + ), + 0.01F + ); } @SneakyThrows @@ -205,6 +236,7 @@ public void testHammingBit_Base64() { ); } + @SuppressWarnings("unchecked") public void testHamming_whenKNNFieldType_thenSucceed() { List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); @@ -216,9 +248,13 @@ public void testHamming_whenKNNFieldType_thenSucceed() { ); KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType); - + byte[] processedByteQuery = (byte[]) hamming.getProcessedQuery(arrayListQueryObject, fieldType); byte[] arrayByte = new byte[] { 1, 2, 3 }; - assertEquals(1F, ((BiFunction) hamming.scoringMethod).apply(arrayByte, arrayByte), 0.1F); + assertEquals( + 1F, + ((BiFunction) hamming.getScoringMethod(processedByteQuery)).apply(arrayByte, arrayByte), + 0.1F + ); } public void testHamming_whenNonBinaryVectorDataType_thenException() { diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 2cc20c8f9..22e54b111 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -314,11 +314,10 @@ public void testHamming_whenKNNVectorScriptDocValuesOfBinary_thenSuccess() { byte[] b1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 byte[] b2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 float[] f1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 - float[] f2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 List queryVector = Arrays.asList(f1[0], f1[1], f1[2]); - KNNVectorScriptDocValues docValues = mock(KNNVectorScriptDocValues.class); + KNNVectorScriptDocValues docValues = mock(KNNVectorScriptDocValues.class); when(docValues.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(docValues.getValue()).thenReturn(f2); + when(docValues.getValue()).thenReturn(b2); assertEquals(KNNScoringUtil.calculateHammingBit(b1, b2), KNNScoringUtil.hamming(queryVector, docValues), 0.01f); }