Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graph merge stats size calculation #1844

Merged
merged 9 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
public class KNNCodecUtil {
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
public static final int JAVA_REFERENCE_SIZE = 4;
// Each array in Java has a header that is 12 bytes
public static final int JAVA_ARRAY_HEADER_SIZE = 12;
// Java rounds each array size up to multiples of 8 bytes
public static final int JAVA_ROUNDING_NUMBER = 8;

@AllArgsConstructor
public static final class Pair {
Expand Down Expand Up @@ -67,39 +61,18 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect
);
}

/**
* This method provides a rough estimate of the number of bytes used for storing an array with the given parameters.
* @param numVectors number of vectors in the array
* @param vectorLength the length of each vector
* @param serializationMode serialization mode
* @return rough estimate of number of bytes used to store an array with the given parameters
*/
public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) {
if (serializationMode == SerializationMode.ARRAY) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) {
int vectorSize = vectorLength;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we get rid of this serializationMode attribute completely?

Copy link
Member Author

@ryanbogan ryanbogan Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is used to calculate array size from a Pair class typically. The issue is that the KNNCodecUtil.Pair class only has doc id's, a vector address, dimension, and serialization mode as instance variables. Therefore, without reading memory from the vector address I don't think it's possible to differentiate whether the data is floats or bytes without the serialization mode.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use vector datatype to know if the vector is byte or float?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that would work, binary type would be the same calculation as byte right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes same thing

return numVectors * vectorLength;
} else {
throw new IllegalStateException("Unreachable code");
return numVectors * vectorLength * FLOAT_BYTE_SIZE;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;

public class KNNCodecUtilTests extends TestCase {
@SneakyThrows
Expand Down Expand Up @@ -52,4 +53,21 @@ public void testGetPair_whenCalled_thenReturn() {
assertEquals(dimension, pair.getDimension());
assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode);
}

public void testCalculateArraySize() {
int numVectors = 4;
int vectorLength = 10;

// Array
SerializationMode serializationMode = SerializationMode.ARRAY;
assertEquals(160, calculateArraySize(numVectors, vectorLength, serializationMode));

// Collection of floats
serializationMode = SerializationMode.COLLECTION_OF_FLOATS;
assertEquals(160, calculateArraySize(numVectors, vectorLength, serializationMode));

// Collection of bytes
serializationMode = SerializationMode.COLLECTIONS_OF_BYTES;
assertEquals(40, calculateArraySize(numVectors, vectorLength, serializationMode));
}
}
Loading