From 05e9779a1c59a40e62f74974bdfbaab5cfc83aa7 Mon Sep 17 00:00:00 2001
From: Ryan Bogan <rbogan@amazon.com>
Date: Fri, 9 Aug 2024 12:03:34 -0700
Subject: [PATCH] Fix graph merge stats size calculation (#1844) (#1941)

* Fix graph merge stats size calculation

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add changelog entry

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add javadocs

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Make calculations easier to read

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Remove java overhead from calculations

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Change from serialization mode to vector data type for calculations

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Minor change to if statements

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
(cherry picked from commit e3158f990d058b02568da617688fd4857d0d521b)
---
 CHANGELOG.md                                  |  1 +
 .../KNN80Codec/KNN80DocValuesConsumer.java    |  7 +--
 .../knn/index/codec/util/KNNCodecUtil.java    | 54 ++++++-------------
 .../index/codec/util/KNNCodecUtilTests.java   | 19 +++++++
 4 files changed, 40 insertions(+), 41 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6c761d27f..f9d715823 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Enhancements
 ### Bug Fixes
 * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
+* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
 ### Infrastructure
 ### Documentation
 ### Maintenance
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
index 31862c25d..7dad3a8fd 100644
--- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
@@ -129,6 +129,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
         NativeIndexCreator indexCreator;
         KNNCodecUtil.Pair pair;
         Map<String, String> fieldAttributes = field.attributes();
+        VectorDataType vectorDataType;
 
         if (fieldAttributes.containsKey(MODEL_ID)) {
             String modelId = fieldAttributes.get(MODEL_ID);
@@ -136,12 +137,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
             if (model.getModelBlob() == null) {
                 throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
             }
-            VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType();
+            vectorDataType = model.getModelMetadata().getVectorDataType();
             pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
             indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath);
         } else {
             // get vector data type from field attributes or provide default value
-            VectorDataType vectorDataType = VectorDataType.get(
+            vectorDataType = VectorDataType.get(
                 fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())
             );
             pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
@@ -154,7 +155,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
             return;
         }
 
-        long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
+        long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType);
 
         if (isMerge) {
             KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
index d208d8179..ea14fe883 100644
--- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
+++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
@@ -11,6 +11,7 @@
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.util.BytesRef;
+import org.opensearch.knn.index.VectorDataType;
 import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
 import org.opensearch.knn.index.codec.transfer.VectorTransfer;
 
@@ -21,12 +22,6 @@
 public class KNNCodecUtil {
     // Floats are 4 bytes in size
     public static final int FLOAT_BYTE_SIZE = 4;
-    // References to objects are 4 bytes in size
-    public static final int JAVA_REFERENCE_SIZE = 4;
-    // Each array in Java has a header that is 12 bytes
-    public static final int JAVA_ARRAY_HEADER_SIZE = 12;
-    // Java rounds each array size up to multiples of 8 bytes
-    public static final int JAVA_ROUNDING_NUMBER = 8;
 
     @AllArgsConstructor
     public static final class Pair {
@@ -67,39 +62,22 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect
         );
     }
 
-    public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) {
-        if (serializationMode == SerializationMode.ARRAY) {
-            int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
-            if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
-                vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
-            }
-            int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
-            if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
-                vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
-            }
-            return vectorsSize;
-        } else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) {
-            int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
-            if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
-                vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
-            }
-            int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
-            if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
-                vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
-            }
-            return vectorsSize;
-        } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) {
-            int vectorSize = vectorLength;
-            if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
-                vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
-            }
-            int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
-            if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
-                vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
-            }
-            return vectorsSize;
+    /**
+     * Roughly estimates the bytes needed to store the given vectors; only float, byte, and binary data types are supported.
+     * @param numVectors number of vectors in the array
+     * @param vectorLength the length of each vector
+     * @param vectorDataType type of data stored in each vector
+     * @return rough estimate of the number of bytes, computed in long arithmetic so large inputs do not overflow int
+     */
+    public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) {
+        if (vectorDataType == VectorDataType.FLOAT) {
+            return (long) numVectors * vectorLength * FLOAT_BYTE_SIZE;
+        } else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) {
+            return (long) numVectors * vectorLength;
         } else {
-            throw new IllegalStateException("Unreachable code");
+            throw new IllegalArgumentException(
+                "Float, binary, and byte are the only supported vector data types for array size calculation."
+            );
         }
     }
 
diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java
index 2ff0f08e5..47dd1dda9 100644
--- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java
@@ -9,6 +9,7 @@
 import lombok.SneakyThrows;
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.util.BytesRef;
+import org.opensearch.knn.index.VectorDataType;
 import org.opensearch.knn.index.codec.transfer.VectorTransfer;
 
 import java.util.Arrays;
@@ -18,6 +19,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;
 
 public class KNNCodecUtilTests extends TestCase {
     @SneakyThrows
@@ -52,4 +54,21 @@ public void testGetPair_whenCalled_thenReturn() {
         assertEquals(dimension, pair.getDimension());
         assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode);
     }
+
+    public void testCalculateArraySize() {
+        final int vectorCount = 4;
+        final int dimension = 10;
+
+        // FLOAT vectors take four bytes per element: 4 * 10 * 4
+        long floatSize = calculateArraySize(vectorCount, dimension, VectorDataType.FLOAT);
+        assertEquals(160, floatSize);
+
+        // BYTE vectors take one byte per element: 4 * 10
+        long byteSize = calculateArraySize(vectorCount, dimension, VectorDataType.BYTE);
+        assertEquals(40, byteSize);
+
+        // BINARY vectors are sized the same way as BYTE vectors
+        long binarySize = calculateArraySize(vectorCount, dimension, VectorDataType.BINARY);
+        assertEquals(40, binarySize);
+    }
 }