diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c761d27f..f9d715823 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 31862c25d..7dad3a8fd 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -129,6 +129,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, NativeIndexCreator indexCreator; KNNCodecUtil.Pair pair; Map fieldAttributes = field.attributes(); + VectorDataType vectorDataType; if (fieldAttributes.containsKey(MODEL_ID)) { String modelId = fieldAttributes.get(MODEL_ID); @@ -136,12 +137,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); + vectorDataType = model.getModelMetadata().getVectorDataType(); pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { // get vector data type from field attributes or provide default value - VectorDataType vectorDataType = VectorDataType.get( + vectorDataType = VectorDataType.get( fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) ); pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); @@ -154,7 +155,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, return; } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType); if (isMerge) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index d208d8179..ea14fe883 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.index.codec.transfer.VectorTransfer; @@ -21,12 +22,6 @@ public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - // References to objects are 4 bytes in size - public static final int JAVA_REFERENCE_SIZE = 4; - // Each array in Java has a header that is 12 bytes - public static final int JAVA_ARRAY_HEADER_SIZE = 12; - // Java rounds each array size up to multiples of 8 bytes - public static final int JAVA_ROUNDING_NUMBER = 8; @AllArgsConstructor public static final class Pair { @@ -67,39 +62,22 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect ); } - public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { - if (serializationMode == SerializationMode.ARRAY) { - int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE; - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } - return vectorsSize; - } else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) { - int vectorSize = vectorLength * FLOAT_BYTE_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } - return vectorsSize; - } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { - int vectorSize = vectorLength; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } - return vectorsSize; + /** + * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. + * @param numVectors number of vectors in the array + * @param vectorLength the length of each vector + * @param vectorDataType type of data stored in each vector + * @return rough estimate of number of bytes used to store an array with the given parameters + */ + public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) { + if (vectorDataType == VectorDataType.FLOAT) { + return numVectors * vectorLength * FLOAT_BYTE_SIZE; + } else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) { + return numVectors * vectorLength; } else { - throw new IllegalStateException("Unreachable code"); + throw new IllegalArgumentException( + "Float, binary, and byte are the only supported vector data types for array size calculation." + ); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 2ff0f08e5..47dd1dda9 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -9,6 +9,7 @@ import lombok.SneakyThrows; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.transfer.VectorTransfer; import java.util.Arrays; @@ -18,6 +19,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { @SneakyThrows @@ -52,4 +54,21 @@ public void testGetPair_whenCalled_thenReturn() { assertEquals(dimension, pair.getDimension()); assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); } + + public void testCalculateArraySize() { + int numVectors = 4; + int vectorLength = 10; + + // Float data type + VectorDataType vectorDataType = VectorDataType.FLOAT; + assertEquals(160, calculateArraySize(numVectors, vectorLength, vectorDataType)); + + // Byte data type + vectorDataType = VectorDataType.BYTE; + assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); + + // Binary data type + vectorDataType = VectorDataType.BINARY; + assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); + } }