Skip to content

Commit

Permalink
Fix graph merge stats size calculation (#1844) (#1941)
Browse files Browse the repository at this point in the history
* Fix graph merge stats size calculation

Signed-off-by: Ryan Bogan <[email protected]>

* Add changelog entry

Signed-off-by: Ryan Bogan <[email protected]>

* Add javadocs

Signed-off-by: Ryan Bogan <[email protected]>

* Make calculations easier to read

Signed-off-by: Ryan Bogan <[email protected]>

* Remove java overhead from calculations

Signed-off-by: Ryan Bogan <[email protected]>

* Change from serialization mode to vector data type for calculations

Signed-off-by: Ryan Bogan <[email protected]>

* Minor change to if statements

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
(cherry picked from commit e3158f9)
  • Loading branch information
ryanbogan authored Aug 9, 2024
1 parent 4b5e210 commit 05e9779
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,20 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
NativeIndexCreator indexCreator;
KNNCodecUtil.Pair pair;
Map<String, String> fieldAttributes = field.attributes();
VectorDataType vectorDataType;

if (fieldAttributes.containsKey(MODEL_ID)) {
String modelId = fieldAttributes.get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType();
vectorDataType = model.getModelMetadata().getVectorDataType();
pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath);
} else {
// get vector data type from field attributes or provide default value
VectorDataType vectorDataType = VectorDataType.get(
vectorDataType = VectorDataType.get(
fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())
);
pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
Expand All @@ -154,7 +155,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
return;
}

long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType);

if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
Expand Down
54 changes: 16 additions & 38 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.index.codec.transfer.VectorTransfer;

Expand All @@ -21,12 +22,6 @@
public class KNNCodecUtil {
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
public static final int JAVA_REFERENCE_SIZE = 4;
// Each array in Java has a header that is 12 bytes
public static final int JAVA_ARRAY_HEADER_SIZE = 12;
// Java rounds each array size up to multiples of 8 bytes
public static final int JAVA_ROUNDING_NUMBER = 8;

@AllArgsConstructor
public static final class Pair {
Expand Down Expand Up @@ -67,39 +62,22 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect
);
}

public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) {
if (serializationMode == SerializationMode.ARRAY) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) {
int vectorSize = vectorLength;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
/**
* This method provides a rough estimate of the number of bytes used for storing an array with the given parameters.
* @param numVectors number of vectors in the array
* @param vectorLength the length of each vector
* @param vectorDataType type of data stored in each vector
* @return rough estimate of number of bytes used to store an array with the given parameters
*/
public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) {
if (vectorDataType == VectorDataType.FLOAT) {
return numVectors * vectorLength * FLOAT_BYTE_SIZE;
} else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) {
return numVectors * vectorLength;
} else {
throw new IllegalStateException("Unreachable code");
throw new IllegalArgumentException(
"Float, binary, and byte are the only supported vector data types for array size calculation."
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.SneakyThrows;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.transfer.VectorTransfer;

import java.util.Arrays;
Expand All @@ -18,6 +19,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;

public class KNNCodecUtilTests extends TestCase {
@SneakyThrows
Expand Down Expand Up @@ -52,4 +54,21 @@ public void testGetPair_whenCalled_thenReturn() {
assertEquals(dimension, pair.getDimension());
assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode);
}

public void testCalculateArraySize() {
int numVectors = 4;
int vectorLength = 10;

// Float data type
VectorDataType vectorDataType = VectorDataType.FLOAT;
assertEquals(160, calculateArraySize(numVectors, vectorLength, vectorDataType));

// Byte data type
vectorDataType = VectorDataType.BYTE;
assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType));

// Binary data type
vectorDataType = VectorDataType.BINARY;
assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType));
}
}

0 comments on commit 05e9779

Please sign in to comment.