Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graph merge stats size calculation #1844

Merged
merged 9 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
* * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,20 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
NativeIndexCreator indexCreator;
KNNCodecUtil.Pair pair;
Map<String, String> fieldAttributes = field.attributes();
VectorDataType vectorDataType;

if (fieldAttributes.containsKey(MODEL_ID)) {
String modelId = fieldAttributes.get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType();
vectorDataType = model.getModelMetadata().getVectorDataType();
pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath);
} else {
// get vector data type from field attributes or provide default value
VectorDataType vectorDataType = VectorDataType.get(
vectorDataType = VectorDataType.get(
fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())
);
pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
Expand All @@ -156,7 +157,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
return;
}

long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType);

if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.index.codec.transfer.VectorTransfer;

Expand All @@ -21,12 +22,6 @@
public class KNNCodecUtil {
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
public static final int JAVA_REFERENCE_SIZE = 4;
// Each array in Java has a header that is 12 bytes
public static final int JAVA_ARRAY_HEADER_SIZE = 12;
// Java rounds each array size up to multiples of 8 bytes
public static final int JAVA_ROUNDING_NUMBER = 8;

@AllArgsConstructor
public static final class Pair {
Expand Down Expand Up @@ -67,39 +62,18 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect
);
}

public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) {
if (serializationMode == SerializationMode.ARRAY) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) {
int vectorSize = vectorLength;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
/**
* This method provides a rough estimate of the number of bytes used for storing an array with the given parameters.
* @param numVectors number of vectors in the array
* @param vectorLength the length of each vector
* @param vectorDataType type of data stored in each vector
* @return rough estimate of number of bytes used to store an array with the given parameters
*/
public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) {
if (vectorDataType == VectorDataType.FLOAT) {
return numVectors * vectorLength * FLOAT_BYTE_SIZE;
} else {
throw new IllegalStateException("Unreachable code");
return numVectors * vectorLength;
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.SneakyThrows;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.transfer.VectorTransfer;

import java.util.Arrays;
Expand All @@ -18,6 +19,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;

public class KNNCodecUtilTests extends TestCase {
@SneakyThrows
Expand Down Expand Up @@ -52,4 +54,21 @@ public void testGetPair_whenCalled_thenReturn() {
assertEquals(dimension, pair.getDimension());
assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode);
}

public void testCalculateArraySize() {
int numVectors = 4;
int vectorLength = 10;

// Float data type
VectorDataType vectorDataType = VectorDataType.FLOAT;
assertEquals(160, calculateArraySize(numVectors, vectorLength, vectorDataType));

// Byte data type
vectorDataType = VectorDataType.BYTE;
assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType));

// Binary data type
vectorDataType = VectorDataType.BINARY;
assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType));
}
}
Loading