[dvc][server][samza] Heap size estimation improvement (#1281)
Introduced two new utilities to make our on-heap memory usage assessment
more accurate and easier to maintain as class hierarchies evolve:

- ClassSizeEstimator: Predicts (based on assumptions about how the JVM's
  memory layout works) the shallow size of a class. This includes the object
  header, all primitive fields, and all references to other objects (but it
  does not count these other objects, hence the shallowness). Reflection is
  used in this class.
- InstanceSizeEstimator: Predicts the size of instances of a limited number
  of classes. This is not a general-purpose utility, and it requires some
  manual effort to onboard a new class. Reflection is not used in this class.
  (A usage sketch follows this list.)
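
To make the intended division of labor concrete, here is a minimal usage
sketch. The entry points match the static imports visible in the diff below,
but the exact overload set (e.g. getSize(byte[])) is an assumption:

    import com.linkedin.venice.memory.ClassSizeEstimator;
    import com.linkedin.venice.memory.InstanceSizeEstimator;

    public class EstimatorSketch {
      public static void main(String[] args) {
        // Reflective, intended to run once per class: shallow overhead only
        // (object header + primitive fields + references, not the referents).
        int shallowOverhead = ClassSizeEstimator.getClassOverhead(EstimatorSketch.class);

        // Reflection-free, but limited to explicitly onboarded types; the
        // byte[] overload is assumed based on usages in the diff below.
        int arraySize = InstanceSizeEstimator.getSize(new byte[16]);

        System.out.println(shallowOverhead + " / " + arraySize);
      }
    }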

The general design goals are the following:

- Reflection should only be used once per class per runtime, and the result
  of this logic should be stored in static constants.
- On the hot path, there should be no reflection, and we should leverage our
  knowledge of the Venice code base to determine which objects should be
  counted. For example, singleton or otherwise shared instances should not
  be counted, since their amortized cost is negligible (beyond the size of
  the pointer that refers to them). The sketch after this list shows the
  resulting pattern.
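
Putting both goals together, a typical Measurable implementation ends up
looking roughly like this (SketchNode and its fields are hypothetical; the
static-constant caching pattern mirrors the diff below):

    import static com.linkedin.venice.memory.ClassSizeEstimator.getClassOverhead;
    import static com.linkedin.venice.memory.InstanceSizeEstimator.getSize;

    import com.linkedin.venice.memory.Measurable;

    public class SketchNode implements Measurable {
      // Reflection runs once per class per runtime, and the result is
      // cached in a static constant.
      private static final int SHALLOW_CLASS_OVERHEAD = getClassOverhead(SketchNode.class);

      private final byte[] payload;
      private final Object sharedSingleton; // shared across many instances

      public SketchNode(byte[] payload, Object sharedSingleton) {
        this.payload = payload;
        this.sharedSingleton = sharedSingleton;
      }

      @Override
      public int getHeapSize() {
        // Hot path: plain arithmetic, no reflection. The shared singleton
        // is deliberately not counted; the reference to it is already part
        // of the shallow class overhead.
        return SHALLOW_CLASS_OVERHEAD + getSize(this.payload);
      }
    }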

The above utilities have been integrated into all classes that implement the
Measurable interface, and several new classes have been given this interface
as well. The Measurable::getSize function has been renamed to getHeapSize, to
minimize the chance of clashing with other function names, and to make it
extra clear what kind of size is meant.
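
After the rename, the interface boils down to a single method (a sketch
consistent with the implementations in the diff):

    package com.linkedin.venice.memory;

    public interface Measurable {
      /** Estimated on-heap size of this instance, in bytes. */
      int getHeapSize();
    }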

Miscellaneous:

- Minor efficiency improvements to PubSubMessageHeaders and ApacheKafkaUtils
  so that empty headers (a common case) carry less overhead. Also made
  PubSubMessageHeaders implement Iterable, so call sites can iterate without
  first materializing a list (see the sketch after this list).
- Created a DefaultLeaderMetadata static class in VeniceWriter, so that a
  shared instance can be leveraged in cases where that object is always the
  same (e.g. when producing to the RT topic).
- BlobSnapshotManagerTest improvements:
  - Added timeouts to all tests.
  - Fixed a race condition in testMultipleThreads.
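
For the Iterable change specifically, call sites can now skip the
intermediate list (a sketch; the toList() accessor is the pre-existing API
shown being removed from a call site in the diff below):

    import com.linkedin.venice.pubsub.api.PubSubMessageHeader;
    import com.linkedin.venice.pubsub.api.PubSubMessageHeaders;

    public class HeadersIterationSketch {
      static int totalValueBytes(PubSubMessageHeaders headers) {
        int total = 0;
        // Previously: for (PubSubMessageHeader h: headers.toList()) {...}
        // materialized a List on every call. Iterating directly avoids the
        // copy now that PubSubMessageHeaders implements Iterable.
        for (PubSubMessageHeader header: headers) {
          total += header.value().length;
        }
        return total;
      }
    }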
FelixGV authored Nov 14, 2024
1 parent 06bf58c commit 020f099
Showing 40 changed files with 1,810 additions and 185 deletions.
3 changes: 2 additions & 1 deletion build.gradle
@@ -369,7 +369,8 @@ subprojects {
}
classes = classDirs.asFileTree.matching { exclude generatedClasses }
auxClassPaths += classDirs.asFileTree.matching { include generatedClasses }.each {
println "Excluding generated class ${project.relativePath(it)}"
// Muted to reduce noise, but can be uncommented to debug exclusions
// println "Excluding generated class ${project.relativePath(it)}"
}
}
}
ImmutablePubSubMessage.java
Expand Up @@ -75,4 +75,9 @@ public String toString() {
return "PubSubMessage{" + topicPartition + ", offset=" + offset + ", timestamp=" + timestamp + ", isEndOfBootstrap="
+ isEndOfBootstrap + '}';
}

@Override
public int getHeapSize() {
throw new UnsupportedOperationException("getHeapSize is not supported on " + this.getClass().getSimpleName());
}
}
StoreIngestionTask.java
@@ -2154,7 +2154,7 @@ protected void getAndUpdateLeaderCompletedState(
&& Arrays.equals(kafkaKey.getKey(), KafkaKey.HEART_BEAT.getKey())) {
LeaderCompleteState oldState = partitionConsumptionState.getLeaderCompleteState();
LeaderCompleteState newState = oldState;
for (PubSubMessageHeader header: pubSubMessageHeaders.toList()) {
for (PubSubMessageHeader header: pubSubMessageHeaders) {
if (header.key().equals(VENICE_LEADER_COMPLETION_STATE_HEADER)) {
newState = LeaderCompleteState.valueOf(header.value()[0]);
partitionConsumptionState
LeaderProducedRecordContext.java
@@ -3,11 +3,14 @@
import static com.linkedin.venice.kafka.protocol.enums.MessageType.CONTROL_MESSAGE;
import static com.linkedin.venice.kafka.protocol.enums.MessageType.DELETE;
import static com.linkedin.venice.kafka.protocol.enums.MessageType.PUT;
import static com.linkedin.venice.memory.ClassSizeEstimator.getClassOverhead;
import static com.linkedin.venice.memory.InstanceSizeEstimator.getSize;

import com.linkedin.venice.kafka.protocol.ControlMessage;
import com.linkedin.venice.kafka.protocol.Delete;
import com.linkedin.venice.kafka.protocol.Put;
import com.linkedin.venice.kafka.protocol.enums.MessageType;
import com.linkedin.venice.memory.Measurable;
import java.util.concurrent.CompletableFuture;


@@ -25,7 +28,9 @@
* drainer thread completes the persistedToDBFuture.
*/

public class LeaderProducedRecordContext {
public class LeaderProducedRecordContext implements Measurable {
private static final int PARTIAL_CLASS_OVERHEAD =
getClassOverhead(LeaderProducedRecordContext.class) + getClassOverhead(CompletableFuture.class);
private static final int NO_UPSTREAM = -1;
/**
* Kafka cluster ID where the source kafka consumer record was consumed from.
@@ -235,4 +240,26 @@ private static void checkConsumedOffsetParam(long consumedOffset) {
throw new IllegalArgumentException("consumedOffset cannot be negative");
}
}

@Override
public int getHeapSize() {
int size = PARTIAL_CLASS_OVERHEAD + getSize(this.keyBytes);
switch (this.messageType) {
case PUT:
size += getSize((Put) this.valueUnion);
break;
case CONTROL_MESSAGE:
size += getSize((ControlMessage) this.valueUnion);
break;
default:
/**
* Only the above two cases contribute any size.
*
* {@link DELETE} contributes nothing, and {@link com.linkedin.venice.kafka.protocol.enums.MessageType.UPDATE}
* should never happen.
*/
break;
}
return size;
}
}
StoreBufferService.java
@@ -1,16 +1,17 @@
package com.linkedin.davinci.kafka.consumer;

import static com.linkedin.venice.utils.ProtocolUtils.getEstimateOfMessageEnvelopeSizeOnHeap;
import static com.linkedin.venice.memory.ClassSizeEstimator.getClassOverhead;
import static java.util.Collections.reverseOrder;
import static java.util.Comparator.comparing;
import static java.util.stream.Collectors.toList;

import com.linkedin.davinci.stats.StoreBufferServiceStats;
import com.linkedin.davinci.utils.LockAssistedCompletableFuture;
import com.linkedin.venice.common.Measurable;
import com.linkedin.venice.exceptions.VeniceChecksumException;
import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.kafka.protocol.KafkaMessageEnvelope;
import com.linkedin.venice.memory.ClassSizeEstimator;
import com.linkedin.venice.memory.Measurable;
import com.linkedin.venice.message.KafkaKey;
import com.linkedin.venice.pubsub.api.PubSubMessage;
import com.linkedin.venice.pubsub.api.PubSubTopicPartition;
@@ -71,42 +72,53 @@ public StoreBufferService(
boolean queueLeaderWrites,
MetricsRepository metricsRepository,
boolean sorted) {
this.drainerNum = drainerNum;
this.blockingQueueArr = new ArrayList<>();
this.bufferCapacityPerDrainer = bufferCapacityPerDrainer;
for (int cur = 0; cur < drainerNum; ++cur) {
this.blockingQueueArr.add(new MemoryBoundBlockingQueue<>(bufferCapacityPerDrainer, bufferNotifyDelta));
}
this.isSorted = sorted;
this.leaderRecordHandler = queueLeaderWrites ? this::queueLeaderRecord : StoreBufferService::processRecord;
String metricNamePrefix = sorted ? "StoreBufferServiceSorted" : "StoreBufferServiceUnsorted";
this.storeBufferServiceStats = new StoreBufferServiceStats(
metricsRepository,
metricNamePrefix,
this::getTotalMemoryUsage,
this::getTotalRemainingMemory,
this::getMaxMemoryUsagePerDrainer,
this::getMinMemoryUsagePerDrainer);
this(drainerNum, bufferCapacityPerDrainer, bufferNotifyDelta, queueLeaderWrites, null, metricsRepository, sorted);
}

/**
* Constructor for testing
* Package-private constructor for testing
*/
public StoreBufferService(
StoreBufferService(
int drainerNum,
long bufferCapacityPerDrainer,
long bufferNotifyDelta,
boolean queueLeaderWrites,
StoreBufferServiceStats stats) {
this(drainerNum, bufferCapacityPerDrainer, bufferNotifyDelta, queueLeaderWrites, stats, null, true);
}

/**
* Shared code for the main and test constructors.
*
* N.B.: Either {@param stats} or {@param metricsRepository} should be null, but not both. If neither are null, then
* we default to the main code's expected path, meaning that the metric repo will be used to construct a
* {@link StoreBufferServiceStats} instance, and the passed in stats object will be ignored.
*/
private StoreBufferService(
int drainerNum,
long bufferCapacityPerDrainer,
long bufferNotifyDelta,
boolean queueLeaderWrites,
StoreBufferServiceStats stats,
MetricsRepository metricsRepository,
boolean sorted) {
this.drainerNum = drainerNum;
this.blockingQueueArr = new ArrayList<>();
this.bufferCapacityPerDrainer = bufferCapacityPerDrainer;
for (int cur = 0; cur < drainerNum; ++cur) {
this.blockingQueueArr.add(new MemoryBoundBlockingQueue<>(bufferCapacityPerDrainer, bufferNotifyDelta));
}
this.isSorted = sorted;
this.leaderRecordHandler = queueLeaderWrites ? this::queueLeaderRecord : StoreBufferService::processRecord;
this.storeBufferServiceStats = stats;
this.isSorted = true;
this.storeBufferServiceStats = metricsRepository == null
? Objects.requireNonNull(stats)
: new StoreBufferServiceStats(
Objects.requireNonNull(metricsRepository),
sorted ? "StoreBufferServiceSorted" : "StoreBufferServiceUnsorted",
this::getTotalMemoryUsage,
this::getTotalRemainingMemory,
this::getMaxMemoryUsagePerDrainer,
this::getMinMemoryUsagePerDrainer);
}

protected MemoryBoundBlockingQueue<QueueNode> getDrainerForConsumerRecord(
@@ -378,11 +390,8 @@ public long getMinMemoryUsagePerDrainer() {
/**
* Queue node type in {@link BlockingQueue} of each drainer thread.
*/
private static class QueueNode implements Measurable {
/**
* Considering the overhead of {@link PubSubMessage} and its internal structures.
*/
private static final int QUEUE_NODE_OVERHEAD_IN_BYTE = 256;
static class QueueNode implements Measurable {
private static final int SHALLOW_CLASS_OVERHEAD = ClassSizeEstimator.getClassOverhead(QueueNode.class);
private final PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long> consumerRecord;
private final StoreIngestionTask ingestionTask;
private final String kafkaUrl;
@@ -447,15 +456,14 @@ public int hashCode() {
return consumerRecord.hashCode();
}

protected int getBaseClassOverhead() {
return SHALLOW_CLASS_OVERHEAD;
}

@Override
public int getSize() {
// For FakePubSubMessage, the key and the value are null.
if (consumerRecord instanceof FakePubSubMessage) {
return QUEUE_NODE_OVERHEAD_IN_BYTE;
}
// N.B.: This is just an estimate. TODO: Consider if it is really useful, and whether to get rid of it.
return this.consumerRecord.getKey().getEstimatedObjectSizeOnHeap()
+ getEstimateOfMessageEnvelopeSizeOnHeap(this.consumerRecord.getValue()) + QUEUE_NODE_OVERHEAD_IN_BYTE;
public int getHeapSize() {
/** The other non-primitive fields point to shared instances and are therefore ignored. */
return getBaseClassOverhead() + consumerRecord.getHeapSize();
}

@Override
@@ -465,6 +473,13 @@ public String toString() {
}

private static class FollowerQueueNode extends QueueNode {
/**
* N.B.: We don't want to recurse fully into the {@link CompletableFuture}, but we do want to take into account an
* "empty" one.
*/
private static final int PARTIAL_CLASS_OVERHEAD =
getClassOverhead(FollowerQueueNode.class) + getClassOverhead(CompletableFuture.class);

private final CompletableFuture<Void> queuedRecordPersistedFuture;

public FollowerQueueNode(
@@ -491,9 +506,16 @@ public int hashCode() {
public boolean equals(Object o) {
return super.equals(o);
}

@Override
protected int getBaseClassOverhead() {
return PARTIAL_CLASS_OVERHEAD;
}
}

private static class LeaderQueueNode extends QueueNode {
static class LeaderQueueNode extends QueueNode {
private static final int SHALLOW_CLASS_OVERHEAD = ClassSizeEstimator.getClassOverhead(LeaderQueueNode.class);

private final LeaderProducedRecordContext leaderProducedRecordContext;

public LeaderQueueNode(
@@ -520,9 +542,21 @@ public int hashCode() {
public boolean equals(Object o) {
return super.equals(o);
}

@Override
protected int getBaseClassOverhead() {
return SHALLOW_CLASS_OVERHEAD + leaderProducedRecordContext.getHeapSize();
}
}

private static class CommandQueueNode extends QueueNode {
/**
* N.B.: We don't want to recurse fully into the {@link CompletableFuture}, but we do want to take into account an
* "empty" one.
*/
private static final int PARTIAL_CLASS_OVERHEAD =
getClassOverhead(CommandQueueNode.class) + getClassOverhead(LockAssistedCompletableFuture.class);

enum CommandType {
// only supports SYNC_OFFSET command today.
SYNC_OFFSET
@@ -589,6 +623,10 @@ public int hashCode() {
public boolean equals(Object o) {
return super.equals(o);
}

protected int getBaseClassOverhead() {
return PARTIAL_CLASS_OVERHEAD;
}
}

/**
@@ -714,6 +752,7 @@ public void run() {
}

private static class FakePubSubMessage implements PubSubMessage {
private static final int SHALLOW_CLASS_OVERHEAD = ClassSizeEstimator.getClassOverhead(FakePubSubMessage.class);
private final PubSubTopicPartition topicPartition;

FakePubSubMessage(PubSubTopicPartition topicPartition) {
@@ -754,5 +793,11 @@ public int getPayloadSize() {
public boolean isEndOfBootstrap() {
return false;
}

@Override
public int getHeapSize() {
/** We assume that {@link #topicPartition} is a singleton instance, and therefore we're not counting it. */
return SHALLOW_CLASS_OVERHEAD;
}
}
}
BlobSnapshotManagerTest.java
@@ -10,6 +10,7 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertTrue;

import com.linkedin.davinci.storage.StorageEngineRepository;
import com.linkedin.davinci.storage.StorageMetadataService;
@@ -19,12 +20,14 @@
import com.linkedin.venice.meta.ReadOnlyStoreRepository;
import com.linkedin.venice.meta.Store;
import com.linkedin.venice.store.rocksdb.RocksDBUtils;
import com.linkedin.venice.utils.Time;
import com.linkedin.venice.utils.Utils;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
@@ -36,6 +39,7 @@


public class BlobSnapshotManagerTest {
private static final int TIMEOUT = 30 * Time.MS_PER_SECOND;
private static final String STORE_NAME = "test-store";
private static final int VERSION_ID = 1;
private static final String TOPIC_NAME = STORE_NAME + "_v" + VERSION_ID;
@@ -51,7 +55,7 @@ public class BlobSnapshotManagerTest {
private static final BlobTransferPayload blobTransferPayload =
new BlobTransferPayload(BASE_PATH, STORE_NAME, VERSION_ID, PARTITION_ID);

@Test
@Test(timeOut = TIMEOUT)
public void testHybridSnapshot() {
AbstractStorageEngine storageEngine = Mockito.mock(AbstractStorageEngine.class);
Mockito.doReturn(storageEngine).when(storageEngineRepository).getLocalStorageEngine(TOPIC_NAME);
@@ -76,7 +80,7 @@ public void testHybridSnapshot() {
Assert.assertEquals(actualBlobTransferPartitionMetadata, blobTransferPartitionMetadata);
}

@Test
@Test(timeOut = TIMEOUT)
public void testSameSnapshotWhenConcurrentUsersNotExceedMaxAllowedUsers() {
Store mockStore = mock(Store.class);

@@ -105,7 +109,7 @@ public void testSameSnapshotWhenConcurrentUsersNotExceedMaxAllowedUsers() {
Assert.assertEquals(actualBlobTransferPartitionMetadata, blobTransferPartitionMetadata);
}

@Test
@Test(timeOut = TIMEOUT)
public void testSameSnapshotWhenConcurrentUsersExceedsMaxAllowedUsers() {
Store mockStore = mock(Store.class);

@@ -145,7 +149,7 @@ public void testSameSnapshotWhenConcurrentUsersExceedsMaxAllowedUsers() {
BlobSnapshotManager.DEFAULT_MAX_CONCURRENT_USERS);
}

@Test
@Test(timeOut = TIMEOUT)
public void testTwoRequestUsingSameOffset() {
// Prepare
Store mockStore = mock(Store.class);
@@ -180,8 +184,8 @@ public void testTwoRequestUsingSameOffset() {
blobTransferPartitionMetadata);
}

@Test
public void testMultipleThreads() {
@Test(timeOut = TIMEOUT)
public void testMultipleThreads() throws InterruptedException {
final int numberOfThreads = 2;
final ExecutorService asyncExecutor = Executors.newFixedThreadPool(numberOfThreads);
final CountDownLatch latch = new CountDownLatch(numberOfThreads);
@@ -218,10 +222,12 @@ public void testMultipleThreads() {
Assert.assertEquals(e.getMessage(), errorMessage);
}

assertTrue(latch.await(TIMEOUT / 2, TimeUnit.MILLISECONDS));

Assert.assertEquals(blobSnapshotManager.getConcurrentSnapshotUsers(TOPIC_NAME, PARTITION_ID), 0);
}

@Test
@Test(timeOut = TIMEOUT)
public void testCreateSnapshotForBatch() throws RocksDBException {
try (MockedStatic<Checkpoint> checkpointMockedStatic = Mockito.mockStatic(Checkpoint.class)) {
try (MockedStatic<FileUtils> fileUtilsMockedStatic = Mockito.mockStatic(FileUtils.class)) {