diff --git a/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java b/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java
index 88d7837a..fe9ca413 100644
--- a/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java
+++ b/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java
@@ -19,7 +19,6 @@
import org.ice4j.*;
import org.ice4j.attribute.*;
-import org.ice4j.ice.*;
import org.ice4j.message.*;
import org.ice4j.socket.*;
import org.ice4j.util.*;
@@ -37,7 +36,7 @@
import static org.ice4j.ice.harvest.HarvestConfig.config;
/**
- * A class which holds a {@link DatagramSocket} and runs a thread
+ * A class which holds a {@link SocketPool} and runs a thread
* ({@link #thread}) which perpetually reads from it.
*
* When a datagram from an unknown source is received, it is parsed as a STUN
@@ -196,13 +195,18 @@ static String getUfrag(byte[] buf, int off, int len)
*/
protected final TransportAddress localAddress;
+ /**
+ * The pool of sockets available for writing.
+ */
+ private final SocketPool socketPool;
+
/**
* The "main" socket that this harvester reads from.
*/
- private final DatagramSocket socket;
+ private final DatagramSocket receiveSocket;
/**
- * The thread reading from {@link #socket}.
+ * The thread reading from {@link #receiveSocket}.
*/
private final Thread thread;
@@ -236,12 +240,14 @@ protected AbstractUdpListener(TransportAddress localAddress)
);
}
- socket = new DatagramSocket( tempAddress );
+ socketPool = new SocketPool( tempAddress, config.udpSocketPoolSize() );
+
+ receiveSocket = socketPool.getReceiveSocket();
Integer receiveBufferSize = config.udpReceiveBufferSize();
if (receiveBufferSize != null)
{
- socket.setReceiveBufferSize(receiveBufferSize);
+ receiveSocket.setReceiveBufferSize(receiveBufferSize);
}
/* Update the port number if needed. */
@@ -249,7 +255,7 @@ protected AbstractUdpListener(TransportAddress localAddress)
{
tempAddress = new TransportAddress(
tempAddress.getAddress(),
- socket.getLocalPort(),
+ receiveSocket.getLocalPort(),
tempAddress.getTransport()
);
}
@@ -257,11 +263,12 @@ protected AbstractUdpListener(TransportAddress localAddress)
String logMessage
= "Initialized AbstractUdpListener with address " + this.localAddress;
- logMessage += ". Receive buffer size " + socket.getReceiveBufferSize();
+ logMessage += ". Receive buffer size " + receiveSocket.getReceiveBufferSize();
if (receiveBufferSize != null)
{
logMessage += " (asked for " + receiveBufferSize + ")";
}
+ logMessage += "; socket pool size " + socketPool.getNumSockets();
logger.info(logMessage);
thread = new Thread(() ->
@@ -292,11 +299,11 @@ public TransportAddress getLocalAddress()
public void close()
{
close = true;
- socket.close(); // causes socket#receive to stop blocking.
+ socketPool.close(); // causes socket#receive to stop blocking.
}
/**
- * Perpetually reads datagrams from {@link #socket} and handles them
+ * Perpetually reads datagrams from {@link #receiveSocket} and handles them
* accordingly.
*
* It is important that this blocks are little as possible (except on
@@ -326,7 +333,7 @@ private void runInHarvesterThread()
try
{
- socket.receive(pkt);
+ receiveSocket.receive(pkt);
}
catch (IOException ioe)
{
@@ -376,13 +383,13 @@ private void runInHarvesterThread()
{
candidateSocket.close();
}
- socket.close();
+ socketPool.close();
}
/**
* Read packets from the socket and forward them via the push API. Note that the memory model here is different
* than the other case. Specifically, we:
- * 1. Receive from {@link #socket} into a fixed buffer
+ * 1. Receive from {@link #receiveSocket} into a fixed buffer
* 2. Obtain a buffer of the required size using {@link BufferPool#getBuffer}
* 3. Copy the data into the buffer and either
* 3.1 Call the associated {@link BufferHandler} if the packet is payload
@@ -410,7 +417,7 @@ private void runInHarvesterThreadPush()
try
{
- socket.receive(pkt);
+ receiveSocket.receive(pkt);
receivedTime = clock.instant();
}
catch (IOException ioe)
@@ -467,7 +474,7 @@ private void runInHarvesterThreadPush()
{
candidateSocket.close();
}
- socket.close();
+ socketPool.close();
}
private Buffer bufferFromPacket(DatagramPacket p, Instant receivedTime)
@@ -478,7 +485,7 @@ private Buffer bufferFromPacket(DatagramPacket p, Instant receivedTime)
System.arraycopy(p.getData(), p.getOffset(), buffer.getBuffer(), off, p.getLength());
buffer.setOffset(off);
buffer.setLength(p.getLength());
- buffer.setLocalAddress(socket.getLocalSocketAddress());
+ buffer.setLocalAddress(receiveSocket.getLocalSocketAddress());
buffer.setRemoteAddress(p.getSocketAddress());
buffer.setReceivedTime(receivedTime);
@@ -808,14 +815,14 @@ public void receive(DatagramPacket p)
/**
* {@inheritDoc}
*
- * Delegates to the actual socket of the harvester.
+ * Delegates to the socket pool.
*/
@Override
public void send(DatagramPacket p)
throws IOException
{
p.setSocketAddress(remoteAddress);
- socket.send(p);
+ socketPool.send(p);
}
}
}
diff --git a/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt b/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt
index 5c554f18..6ac60132 100644
--- a/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt
+++ b/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt
@@ -41,6 +41,12 @@ class HarvestConfig {
}
fun udpReceiveBufferSize() = udpReceiveBufferSize
+ val udpSocketPoolSize: Int by config {
+ "ice4j.harvest.udp.socket-pool-size".from(configSource)
+ }
+
+ fun udpSocketPoolSize() = udpSocketPoolSize
+
val useIpv6: Boolean by config {
"org.ice4j.ipv6.DISABLED".from(configSource)
.transformedBy { !it }
diff --git a/src/main/kotlin/org/ice4j/socket/SocketPool.kt b/src/main/kotlin/org/ice4j/socket/SocketPool.kt
new file mode 100644
index 00000000..624e5b23
--- /dev/null
+++ b/src/main/kotlin/org/ice4j/socket/SocketPool.kt
@@ -0,0 +1,114 @@
+/*
+ * Copyright @ 2020 - Present, 8x8 Inc
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.ice4j.socket
+
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.DatagramSocketImpl
+import java.net.SocketAddress
+import java.nio.channels.DatagramChannel
+
+/** A pool of datagram sockets all bound on the same port.
+ *
+ * This is necessary to allow multiple threads to send packets simultaneously from the same source address,
+ * in JDK 15 and later, because the [DatagramChannel]-based implementation of [DatagramSocketImpl] introduced
+ * in that version locks the socket during a call to [DatagramSocket.send].
+ *
+ * (The old [DatagramSocketImpl] implementation can be used by setting the system property
+ * `jdk.net.usePlainDatagramSocketImpl` in JDK versions 15 through 17, but was removed in versions 18 and later.)
+ *
+ * This feature may also be useful on older JDK versions on non-Linux operating systems, such as macOS,
+ * which block simultaneous writes through the same UDP socket at the operating system level.
+ *
+ * The sockets are opened such that packets will be _received_ on exactly one socket.
+ */
+class SocketPool(
+ /** The address to which to bind the pool of sockets. */
+ address: SocketAddress,
+ /** The number of sockets to create for the pool. If this is set to zero (the default), the number
+ * will be set automatically to an appropriate value.
+ */
+ requestedNumSockets: Int = 0
+) {
+ init {
+ require(requestedNumSockets >= 0) { "RequestedNumSockets must be >= 0" }
+ }
+
+ internal class SocketAndIndex(
+ val socket: DatagramSocket,
+ var count: Int = 0
+ )
+
+ val numSockets: Int =
+ if (requestedNumSockets != 0) {
+ requestedNumSockets
+ } else {
+ // TODO: set this to 1 in situations where pools aren't needed?
+ Runtime.getRuntime().availableProcessors()
+ }
+
+ private val sockets = buildList {
+ val multipleSockets = numSockets > 1
+ var bindAddr = address
+ for (i in 0 until numSockets) {
+ val sock = DatagramSocket(null)
+ if (multipleSockets) {
+ sock.reuseAddress = true
+ }
+ sock.bind(bindAddr)
+ if (i == 0 && multipleSockets) {
+ bindAddr = sock.localSocketAddress
+ }
+ add(SocketAndIndex(sock, 0))
+ }
+ }
+
+ /** The socket on which packets will be received. */
+ val receiveSocket: DatagramSocket
+ // On all platforms I've tested, the last-bound socket is the one which receives packets.
+ // TODO: should we support Linux's flavor of SO_REUSEPORT, in which packets can be received on *all* the
+ // sockets, spreading load?
+ get() = sockets.last().socket
+
+ fun send(packet: DatagramPacket) {
+ val sendSocket = getSendSocket()
+ sendSocket.socket.send(packet)
+ returnSocket(sendSocket)
+ }
+
+ /** Gets a socket on which packets can be sent, chosen from among all the available send sockets. */
+ internal fun getSendSocket(): SocketAndIndex {
+ if (numSockets == 1) {
+ return sockets.first()
+ }
+ synchronized(sockets) {
+ val min = sockets.minBy { it.count }
+ min.count++
+
+ return min
+ }
+ }
+
+ internal fun returnSocket(socket: SocketAndIndex) {
+ synchronized(sockets) {
+ socket.count--
+ }
+ }
+
+ fun close() {
+ sockets.forEach { it.socket.close() }
+ }
+}
diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf
index 09955622..9906d4cb 100644
--- a/src/main/resources/reference.conf
+++ b/src/main/resources/reference.conf
@@ -57,6 +57,10 @@ ice4j {
// Whether to allocate ephemeral ports for local candidates. This is the default value, and can be overridden
// for Agent instances.
use-dynamic-ports = true
+
+ // The size of the socket pool to use to send packets on the "single port" harvester. 0 means the
+ // default (Java's reported number of available processors). 1 is equivalent to not using a socket pool.
+ socket-pool-size = 0
}
// The list of IP addresses that are allowed to be used for host candidate allocations. When empty, any address is
diff --git a/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java b/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java
index 3535e227..d6f8aeb5 100644
--- a/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java
+++ b/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java
@@ -31,51 +31,6 @@
*/
public class SinglePortUdpHarvesterTest
{
- /**
- * Verifies that, without closing, the address used by a harvester cannot be re-used.
- *
- * @see https://github.com/jitsi/ice4j/issues/139
- */
- @Test
- public void testRebindWithoutCloseThrows() throws Exception
- {
- // Setup test fixture.
- final TransportAddress address = new TransportAddress( "127.0.0.1", 10000, Transport.UDP );
- SinglePortUdpHarvester firstHarvester;
- try
- {
- firstHarvester = new SinglePortUdpHarvester( address );
- }
- catch (BindException ex)
- {
- // This is not expected at this stage (the port is likely already in use by another process, voiding this
- // test). Rethrow as a different exception than the BindException, that is expected to be thrown later in
- // this test.
- throw new Exception( "Test fixture is invalid.", ex );
- }
-
- // Execute system under test.
- SinglePortUdpHarvester secondHarvester = null;
- try
- {
- secondHarvester = new SinglePortUdpHarvester( address );
- fail("expected BindException to be thrown at this point");
- }
- catch (BindException ex)
- {
- //expected, do nothing
- }
- finally
- {
- // Tear down
- firstHarvester.close();
- if (secondHarvester != null)
- {
- secondHarvester.close();
- }
- }
- }
-
/**
* Verifies that, after closing, the address used by a harvester can be re-used.
*
diff --git a/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt b/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt
index 0fbf396a..8ed65d26 100644
--- a/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt
+++ b/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt
@@ -30,6 +30,7 @@ class HarvestConfigTest : ConfigTest() {
config.useIpv6 shouldBe true
config.useLinkLocalAddresses shouldBe true
config.udpReceiveBufferSize shouldBe null
+ config.udpSocketPoolSize shouldBe 0
config.stunMappingCandidateHarvesterAddresses shouldBe emptyList()
}
context("Setting via legacy config (system properties)") {
@@ -39,6 +40,7 @@ class HarvestConfigTest : ConfigTest() {
config.useIpv6 shouldBe false
config.useLinkLocalAddresses shouldBe false
config.udpReceiveBufferSize shouldBe 555
+ config.udpSocketPoolSize shouldBe 0
config.stunMappingCandidateHarvesterAddresses shouldBe listOf("stun1.legacy:555", "stun2.legacy")
}
}
@@ -49,6 +51,7 @@ class HarvestConfigTest : ConfigTest() {
config.useIpv6 shouldBe false
config.useLinkLocalAddresses shouldBe false
config.udpReceiveBufferSize shouldBe 666
+ config.udpSocketPoolSize shouldBe 3
config.stunMappingCandidateHarvesterAddresses shouldBe listOf("stun1.new:666", "stun2.new")
}
}
@@ -60,6 +63,7 @@ class HarvestConfigTest : ConfigTest() {
config.useIpv6 shouldBe false
config.useLinkLocalAddresses shouldBe false
config.udpReceiveBufferSize shouldBe 555
+ config.udpSocketPoolSize shouldBe 0
config.stunMappingCandidateHarvesterAddresses shouldBe listOf("stun1.legacy:555", "stun2.legacy")
}
}
@@ -153,6 +157,7 @@ private val newConfigNonDefault = """
udp {
receive-buffer-size = 666
use-dynamic-ports = false
+ socket-pool-size = 3
}
mapping {
stun {
diff --git a/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt b/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt
new file mode 100644
index 00000000..b82bcb4f
--- /dev/null
+++ b/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt
@@ -0,0 +1,266 @@
+package org.ice4j.socket
+
+import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.core.test.Enabled
+import io.kotest.core.test.TestCase
+import io.kotest.matchers.comparables.shouldBeLessThan
+import io.kotest.matchers.should
+import io.kotest.matchers.shouldBe
+import io.kotest.matchers.shouldNotBe
+import io.kotest.matchers.types.beInstanceOf
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetSocketAddress
+import java.net.SocketAddress
+import java.time.Clock
+import java.time.Duration
+import java.time.Instant
+import java.util.concurrent.CyclicBarrier
+
+private val loopbackAny = InetSocketAddress("127.0.0.1", 0)
+private val loopbackDiscard = InetSocketAddress("127.0.0.1", 9)
+
+@OptIn(io.kotest.common.ExperimentalKotest::class)
+class SocketPoolTest : ShouldSpec() {
+ init {
+ context("Creating a new socket pool") {
+ val pool = SocketPool(loopbackAny)
+ should("Bind to a random port") {
+ val local = pool.receiveSocket.localSocketAddress
+ local should beInstanceOf()
+ (local as InetSocketAddress).port shouldNotBe 0
+ }
+ pool.close()
+ }
+
+ context("Getting multiple send sockets from a pool") {
+ val numSockets = 4
+ val pool = SocketPool(loopbackAny, numSockets)
+ val sockets = mutableListOf()
+ should("be possible") {
+ repeat(numSockets) {
+ sockets.add(pool.getSendSocket().socket)
+ }
+ }
+ // All sockets should be distinct
+ sockets.toSet().size shouldBe sockets.size
+ pool.close()
+ }
+
+ context("Packets sent from each of the send sockets in the pool") {
+ val numSockets = 4
+ val pool = SocketPool(loopbackAny, numSockets)
+ val local = pool.receiveSocket.localSocketAddress
+ val sockets = mutableListOf()
+ repeat(numSockets) {
+ sockets.add(pool.getSendSocket().socket)
+ }
+ sockets.forEachIndexed { i, it ->
+ val buf = i.toString().toByteArray()
+ val packet = DatagramPacket(buf, buf.size, local)
+ it.send(packet)
+ }
+
+ should("be received") {
+ for (i in 0 until numSockets) {
+ val buf = ByteArray(1500)
+ val packet = DatagramPacket(buf, buf.size)
+ pool.receiveSocket.soTimeout = 1 // Don't block if something's wrong
+ pool.receiveSocket.receive(packet)
+ packet.data.decodeToString(0, packet.length).toInt() shouldBe i
+ packet.socketAddress shouldBe local
+ }
+ }
+ pool.close()
+ }
+
+ context("The number of send sockets") {
+ val numSockets = 4
+ val pool = SocketPool(loopbackAny, numSockets)
+
+ val sockets = mutableSetOf()
+
+ repeat(2 * numSockets) {
+ // This should cycle through all the available send sockets
+ sockets.add(pool.getSendSocket().socket)
+ }
+
+ should("be correct") {
+ sockets.size shouldBe numSockets
+ }
+
+ pool.close()
+ }
+
+ val disableIfOnlyOneCore: (TestCase) -> Enabled = {
+ if (Runtime.getRuntime().availableProcessors() > 1) {
+ Enabled.enabled
+ } else {
+ Enabled.disabled("Need multiple processors to run test")
+ }
+ }
+
+ context("Sending packets from multiple threads").config(enabledOrReasonIf = disableIfOnlyOneCore) {
+ val poolWarmup = SocketPool(loopbackAny, 1)
+ sendTimeOnAllSockets(poolWarmup)
+
+ val pool1 = SocketPool(loopbackAny, 1)
+ val elapsed1 = sendTimeOnAllSockets(pool1)
+
+ // 0 means pick the default value, currently Runtime.getRuntime().availableProcessors().
+ val poolN = SocketPool(loopbackAny, 0)
+ val elapsedN = sendTimeOnAllSockets(poolN)
+
+ elapsedN shouldBeLessThan elapsed1 // Very weak test
+ }
+
+ val enableOnlyIfPropertySet: (TestCase) -> Enabled = {
+ if (System.getProperty("doPerfTests") != null) {
+ Enabled.enabled
+ } else {
+ Enabled.disabled("Set \"doPerfTests\" property to enable SocketPool performance tests")
+ }
+ }
+
+ context("Test sending packets from multiple threads").config(enabledOrReasonIf = enableOnlyIfPropertySet) {
+ testSending()
+ }
+ }
+ private class Sender(
+ private val count: Int,
+ private val pool: SocketPool,
+ private val destAddr: SocketAddress
+ ) : Runnable {
+ private val buf = ByteArray(BUFFER_SIZE)
+
+ private fun sendToSocket(count: Int) {
+ for (i in 0 until count) {
+ pool.send(DatagramPacket(buf, BUFFER_SIZE, destAddr))
+ }
+ }
+
+ override fun run() {
+ barrier.await()
+
+ start()
+ sendToSocket(count)
+ end()
+ }
+
+ companion object {
+ private const val BUFFER_SIZE = 1500
+ const val NUM_PACKETS = 600000
+ private val clock = Clock.systemUTC()
+
+ private var start = Instant.MAX
+ private var end = Instant.MIN
+
+ val elapsed: Duration
+ get() = Duration.between(start, end)
+
+ fun start() {
+ val now = clock.instant()
+ synchronized(this) {
+ if (start.isAfter(now)) {
+ start = now
+ }
+ }
+ }
+
+ fun end() {
+ val now = clock.instant()
+ synchronized(this) {
+ if (end.isBefore(now)) {
+ end = now
+ }
+ }
+ }
+
+ private var barrier: CyclicBarrier = CyclicBarrier(1)
+
+ fun reset(numThreads: Int) {
+ barrier = CyclicBarrier(numThreads)
+ start = Instant.MAX
+ end = Instant.MIN
+ }
+ }
+ }
+
+ companion object {
+ private fun sendTimeOnAllSockets(
+ pool: SocketPool,
+ numThreads: Int = pool.numSockets,
+ numPackets: Int = Sender.NUM_PACKETS
+ ): Duration {
+ val threads = mutableListOf()
+ Sender.reset(numThreads)
+ repeat(numThreads) {
+ val thread = Thread(Sender(numPackets / numThreads, pool, loopbackDiscard))
+ threads.add(thread)
+ thread.start()
+ }
+ threads.forEach { it.join() }
+ return Sender.elapsed
+ }
+
+ private fun testSendingOnce(
+ numSockets: Int,
+ numThreads: Int,
+ numPackets: Int = Sender.NUM_PACKETS,
+ warmup: Boolean = false
+ ) {
+ val pool = SocketPool(loopbackAny, numSockets)
+ val elapsed = sendTimeOnAllSockets(pool, numThreads, numPackets)
+ if (!warmup) {
+ println(
+ "Send $numPackets packets on $numSockets sockets on $numThreads threads " +
+ "took $elapsed"
+ )
+ }
+ }
+
+ fun testSending() {
+ val numProcessors = Runtime.getRuntime().availableProcessors()
+
+ testSendingOnce(1, 1, warmup = true)
+ testSendingOnce(2 * numProcessors, 2 * numProcessors, warmup = true)
+
+ testSendingOnce(1, 1)
+ testSendingOnce(1, numProcessors)
+ testSendingOnce(1, 2 * numProcessors)
+ testSendingOnce(1, 4 * numProcessors)
+ testSendingOnce(1, 8 * numProcessors)
+
+ testSendingOnce(numProcessors, numProcessors)
+ testSendingOnce(numProcessors, 2 * numProcessors)
+ testSendingOnce(numProcessors, 4 * numProcessors)
+ testSendingOnce(numProcessors, 8 * numProcessors)
+
+ testSendingOnce(2 * numProcessors, 2 * numProcessors)
+ testSendingOnce(2 * numProcessors, 4 * numProcessors)
+ testSendingOnce(2 * numProcessors, 8 * numProcessors)
+
+ testSendingOnce(4 * numProcessors, 4 * numProcessors)
+ testSendingOnce(4 * numProcessors, 8 * numProcessors)
+
+ testSendingOnce(8 * numProcessors, 8 * numProcessors)
+ }
+
+ @JvmStatic
+ fun main(args: Array) {
+ if (args.size >= 2) {
+ val numSockets = args[0].toInt()
+ val numThreads = args[1].toInt()
+ val numPackets = if (args.size > 2) {
+ args[2].toInt()
+ } else {
+ Sender.NUM_PACKETS
+ }
+ testSendingOnce(numThreads = numThreads, numSockets = numSockets, numPackets = 10000, warmup = true)
+ testSendingOnce(numThreads = numThreads, numSockets = numSockets, numPackets = numPackets)
+ } else {
+ testSending()
+ }
+ }
+ }
+}