diff --git a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/LearningSynchronizer.java b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/LearningSynchronizer.java index f7e9e812d895..ae61b9bef0d1 100644 --- a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/LearningSynchronizer.java +++ b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/LearningSynchronizer.java @@ -79,6 +79,7 @@ public class LearningSynchronizer implements ReconnectNodeCount { private volatile AsyncOutputStream out; private final Queue rootsToReceive; + private boolean processingNullStartingRoot; // All root/custom tree views, by view ID private final Map> views; private final Deque> viewsToInitialize; @@ -137,7 +138,7 @@ public LearningSynchronizer( @NonNull final ThreadManager threadManager, @NonNull final MerkleDataInputStream in, @NonNull final MerkleDataOutputStream out, - @NonNull final MerkleNode root, + final MerkleNode root, @NonNull final Runnable breakConnection, @NonNull final ReconnectConfig reconnectConfig) { @@ -152,7 +153,12 @@ public LearningSynchronizer( views.put(viewId, nodeTreeView(root)); viewsToInitialize = new ConcurrentLinkedDeque<>(); rootsToReceive = new ConcurrentLinkedQueue<>(); - rootsToReceive.add(root); + if (root == null) { + processingNullStartingRoot = true; + } else { + processingNullStartingRoot = false; + rootsToReceive.add(root); + } this.breakConnection = breakConnection; } @@ -294,11 +300,18 @@ private synchronized boolean receiveNextSubtree( return false; } - if (rootsToReceive.isEmpty()) { - viewsInProgress.decrementAndGet(); - return false; + final MerkleNode root; + if (processingNullStartingRoot) { + assert rootsToReceive.isEmpty(); + root = null; + processingNullStartingRoot = false; + } else { + if (rootsToReceive.isEmpty()) { + viewsInProgress.decrementAndGet(); + return false; + } + root = rootsToReceive.poll(); } - final MerkleNode root = rootsToReceive.poll(); final String route = root == null ? "[]" : root.getRoute().toString(); final int viewId = nextViewId++; diff --git a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/streams/AsyncOutputStream.java b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/streams/AsyncOutputStream.java index 539f067628c9..154689cd6479 100644 --- a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/streams/AsyncOutputStream.java +++ b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/streams/AsyncOutputStream.java @@ -28,13 +28,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.concurrent.BlockingQueue; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -66,7 +63,7 @@ public class AsyncOutputStream implements AutoCloseable { /** * A queue that need to be written to the output stream. */ - private final BlockingQueue streamQueue; + private final Queue streamQueue; /** * The time that has elapsed since the last flush was attempted. @@ -113,7 +110,7 @@ public AsyncOutputStream( this.outputStream = Objects.requireNonNull(outputStream, "outputStream must not be null"); this.workGroup = Objects.requireNonNull(workGroup, "workGroup must not be null"); - this.streamQueue = new LinkedBlockingQueue<>(config.asyncStreamBufferSize() * 32); + this.streamQueue = new ConcurrentLinkedQueue<>(); this.timeSinceLastFlush = new StopWatch(); this.timeSinceLastFlush.start(); this.flushInterval = config.asyncOutputStreamFlush(); @@ -197,8 +194,8 @@ public void whenCurrentMessagesProcessed(final Runnable run) throws InterruptedE sendAsync(new QueueItem(run)); } - private void sendAsync(final QueueItem item) throws InterruptedException { - final boolean success = streamQueue.offer(item, timeout.toMillis(), TimeUnit.MILLISECONDS); + private void sendAsync(final QueueItem item) { + final boolean success = streamQueue.offer(item); if (!success) { try { outputStream.close(); @@ -228,30 +225,29 @@ public void waitForCompletion() throws InterruptedException { * @return true if a message was sent. */ private boolean handleQueuedMessages() { - if (!streamQueue.isEmpty()) { - final int size = streamQueue.size(); - final List localQueue = new ArrayList<>(size); - streamQueue.drainTo(localQueue, size); - for (final QueueItem item : localQueue) { - if (item.toNotify() != null) { - assert item.messageBytes() == null; - item.toNotify().run(); - } else { - final int viewId = item.viewId(); - final byte[] messageBytes = item.messageBytes(); - try { - outputStream.writeInt(viewId); - outputStream.writeInt(messageBytes.length); - outputStream.write(messageBytes); - } catch (final IOException e) { - throw new MerkleSynchronizationException(e); - } - bufferedMessageCount += 1; + QueueItem item = streamQueue.poll(); + if (item == null) { + return false; + } + while (item != null) { + if (item.toNotify() != null) { + assert item.messageBytes() == null; + item.toNotify().run(); + } else { + final int viewId = item.viewId(); + final byte[] messageBytes = item.messageBytes(); + try { + outputStream.writeInt(viewId); + outputStream.writeInt(messageBytes.length); + outputStream.write(messageBytes); + } catch (final IOException e) { + throw new MerkleSynchronizationException(e); } + bufferedMessageCount += 1; } - return true; + item = streamQueue.poll(); } - return false; + return true; } protected void serializeMessage(final SelfSerializable message) throws IOException {