diff --git a/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java b/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java index 02902e5d32..ec8a8e3a61 100644 --- a/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java +++ b/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java @@ -200,6 +200,9 @@ int doWrite(final ByteBuffer src) throws IOException { @Override public void truncateWrites() throws IOException { + if (anyAreSet(state, FLAG_FINISHED)) { + return; + } try { if (lastChunkBuffer != null) { lastChunkBuffer.close(); @@ -259,6 +262,9 @@ public long transferFrom(final StreamSourceChannel source, final long count, fin @Override public boolean flush() throws IOException { + if (anyAreSet(state, FLAG_FINISHED)) { + return true; + } this.state |= FLAG_FIRST_DATA_WRITTEN; if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) { if (anyAreSet(state, FLAG_NEXT_SHUTDOWN)) { diff --git a/core/src/test/java/io/undertow/conduits/ChunkedStreamSinkConduitTest.java b/core/src/test/java/io/undertow/conduits/ChunkedStreamSinkConduitTest.java new file mode 100644 index 0000000000..5b820d69e2 --- /dev/null +++ b/core/src/test/java/io/undertow/conduits/ChunkedStreamSinkConduitTest.java @@ -0,0 +1,177 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2024 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.undertow.conduits; + +import io.undertow.connector.ByteBufferPool; +import io.undertow.server.DefaultByteBufferPool; +import io.undertow.testutils.category.UnitTest; +import io.undertow.util.AbstractAttachable; +import io.undertow.util.HeaderMap; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.xnio.XnioIoThread; +import org.xnio.XnioWorker; +import org.xnio.channels.StreamSourceChannel; +import org.xnio.conduits.StreamSinkConduit; +import org.xnio.conduits.WriteReadyHandler; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static org.junit.Assert.*; + +/** + * Test case for UNDERTOW-2424. + */ +@Category(UnitTest.class) +public class ChunkedStreamSinkConduitTest { + + @Test + public void testChunkedStreamSinkConduit() throws IOException { + ByteBufferPool pool = new DefaultByteBufferPool(false, 1024, -1, -1); + AtomicLong written = new AtomicLong(); + AtomicInteger flushes = new AtomicInteger(); + AtomicInteger listenerInvocations = new AtomicInteger(); + StreamSinkConduit next = new StreamSinkConduit() { + + @Override + public long transferFrom(FileChannel src, long position, long count) throws IOException { + written.addAndGet(count); + return count; + } + + @Override + public long transferFrom(StreamSourceChannel source, long count, ByteBuffer throughBuffer) throws IOException { + written.addAndGet(count); + return count; + } + + @Override + public int write(ByteBuffer src) { + int remaining = src.remaining(); + src.position(src.position() + remaining); + written.addAndGet(remaining); + return remaining; + } + + @Override + public long write(ByteBuffer[] srcs, int offs, int len) throws IOException { + long total = 0; + for (int i = offs; i < len; i++) { + int written = write(srcs[i]); + if (written == 0) { + break; + } + total += written; + } + return total; + } + + @Override + public int writeFinal(ByteBuffer src) throws IOException { + return write(src); + } + + @Override + public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException { + return write(srcs, offset, length); + } + + @Override + public void terminateWrites() { + } + + @Override + public boolean isWriteShutdown() { + return false; + } + + @Override + public void resumeWrites() { + } + + @Override + public void suspendWrites() { + } + + @Override + public void wakeupWrites() { + } + + @Override + public boolean isWriteResumed() { + return false; + } + + @Override + public void awaitWritable() { + } + + @Override + public void awaitWritable(long time, TimeUnit timeUnit) { + + } + + @Override + public XnioIoThread getWriteThread() { + return null; + } + + @Override + public void setWriteReadyHandler(WriteReadyHandler handler) { + + } + + @Override + public void truncateWrites() { + + } + + @Override + public boolean flush() { + flushes.incrementAndGet(); + return true; + } + + @Override + public XnioWorker getWorker() { + return null; + } + }; + ConduitListener listener = channel -> listenerInvocations.incrementAndGet(); + ChunkedStreamSinkConduit conduit = new ChunkedStreamSinkConduit(next, pool, false, false, new HeaderMap(), listener, new AbstractAttachable() {}); + + assertEquals(5, conduit.write(ByteBuffer.wrap("Hello".getBytes(StandardCharsets.UTF_8)))); + assertEquals("Expected 11 bytes to be flushed including chunk headers", 11, written.get()); + assertEquals(0, flushes.get()); + conduit.terminateWrites(); + assertTrue(conduit.flush()); + int flushesAfterTerminate = flushes.get(); + assertTrue(conduit.flush()); + // UNDERTOW-2424: If this isn't the case, invocations from response wrappers may invoke flush on persistent + // connections that are already being used to process other requests on other threads. + assertEquals("Expected flushing after termination not to have any impact", flushesAfterTerminate, flushes.get()); + assertEquals(1, listenerInvocations.get()); + } +}