Skip to content

Commit

Permalink
Allow max connection pool setting to be used (#1431)
Browse files Browse the repository at this point in the history
* Split ConnectionPoolValidator for sqlserver and postgresql

* Move validation of connection pooling to the initialize phase

---------

Co-authored-by: Jo Palac <[email protected]>
  • Loading branch information
andreasohlund and jpalac authored Sep 26, 2024
1 parent 3f053a3 commit 9d7ff39
Show file tree
Hide file tree
Showing 15 changed files with 134 additions and 81 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Npgsql;
using NServiceBus;
using NServiceBus.Transport;
using NServiceBus.Transport.PostgreSql;
using NServiceBus.TransportTests;
using NUnit.Framework;
using QueueAddress = NServiceBus.Transport.QueueAddress;

public class ConfigurePostgreSqlTransportInfrastructure : IConfigureTransportInfrastructure
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
namespace NServiceBus.Transport.PostgreSql.UnitTests
{
using NUnit.Framework;

[TestFixture]
public class ConnectionPoolValidatorTests
{
[Test]
public void Is_not_validated_when_connection_pooling_not_specified()
{
var result = ConnectionPoolValidator.Validate("Database = xxx");

Assert.That(result.IsValid, Is.False);
}

[Test]
public void Is_validated_when_both_min_and_max_pool_size_is_specified()
{
var result = ConnectionPoolValidator.Validate("Database = xxx; minimum pool size = 20; maximum pool size=120");

Assert.That(result.IsValid, Is.True);
}

[Test]
public void Is_not_validated_when_only_min_pool_size_is_specified()
{
var result = ConnectionPoolValidator.Validate("Database = xxx; Minimum Pool Size = 20;");

Assert.That(result.IsValid, Is.False);
}

[Test]
public void Is_not_validated_when_pooling_is_enabled_and_no_min_and_max_is_set()
{
var result = ConnectionPoolValidator.Validate("Database = xxx; Pooling = true");

Assert.That(result.IsValid, Is.False);
}

[Test]
public void Is_validated_when_pooling_is_disabled()
{
var result = ConnectionPoolValidator.Validate("Database = xxx; Pooling = false");

Assert.That(result.IsValid, Is.True);
}

[Test]
public void Parses_pool_disable_values_with_yes_or_no()
{
var result = ConnectionPoolValidator.Validate("Database = xxx; Pooling = no");

Assert.That(result.IsValid, Is.True);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace NServiceBus.Transport.PostgreSql;

using System;
using System.Data.Common;
using Sql.Shared;

static class ConnectionPoolValidator
{
public static ValidationCheckResult Validate(string connectionString)
{
var keys = new DbConnectionStringBuilder { ConnectionString = connectionString };
var hasPoolingValue = keys.TryGetValue("Pooling", out object poolingValue);
if (hasPoolingValue && !string.Equals(poolingValue.ToString(), "true", StringComparison.InvariantCultureIgnoreCase))
{
return ValidationCheckResult.Valid();
}
if (keys.ContainsKey("Maximum Pool Size"))
{
return ValidationCheckResult.Valid();
}
return ValidationCheckResult.Invalid(ConnectionPoolSizeNotSet);
}

const string ConnectionPoolSizeNotSet =
"Maximum connection pooling value (Maximum Pool Size=N) is not " +
"configured on the provided connection string. The default value (100) will be used.";
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ public PostgreSqlDbConnectionFactory(string connectionString)
{
openNewConnection = async cancellationToken =>
{
ValidateConnectionPool(connectionString);

var connection = new NpgsqlConnection(connectionString);
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class PostgreSqlTransportInfrastructure : TransportInfrastructure
IDelayedMessageStore delayedMessageStore = new SendOnlyDelayedMessageStore();
PostgreSqlDbConnectionFactory connectionFactory;

static ILog _logger = LogManager.GetLogger<PostgreSqlTransportInfrastructure>();
static ILog Logger = LogManager.GetLogger<PostgreSqlTransportInfrastructure>();
readonly PostgreSqlExceptionClassifier exceptionClassifier;

public PostgreSqlTransportInfrastructure(PostgreSqlTransport transport, HostSettings hostSettings,
Expand Down Expand Up @@ -72,6 +72,8 @@ public override string ToTransportAddress(Transport.QueueAddress address)
{
connectionFactory = CreateConnectionFactory();

await ValidateDatabaseAccess(cancellationToken).ConfigureAwait(false);

addressTranslator = new QueueAddressTranslator("public", transport.DefaultSchema, transport.Schema);

tableBasedQueueCache = new TableBasedQueueCache(
Expand Down Expand Up @@ -192,7 +194,7 @@ async Task ConfigureReceiveInfrastructure(CancellationToken cancellationToken)

if (receiveSetting.PurgeOnStartup)
{
_logger.Warn($"The {receiveSetting.PurgeOnStartup} should only be used in the development environment.");
Logger.Warn($"The {receiveSetting.PurgeOnStartup} should only be used in the development environment.");
}

return new MessageReceiver(transport, receiveSetting.Id, receiveAddress, receiveSetting.ErrorQueue,
Expand All @@ -201,8 +203,6 @@ async Task ConfigureReceiveInfrastructure(CancellationToken cancellationToken)
subscriptionManager, receiveSetting.PurgeOnStartup, exceptionClassifier);
}).ToDictionary<MessageReceiver, string, IMessageReceiver>(receiver => receiver.Id, receiver => receiver);

await ValidateDatabaseAccess(cancellationToken).ConfigureAwait(false);

var receiveAddresses = Receivers.Values.Select(r => r.ReceiveAddress).ToList();

if (hostSettings.SetupInfrastructure)
Expand Down Expand Up @@ -246,8 +246,14 @@ async Task TryOpenDatabaseConnection(CancellationToken cancellationToken)
{
try
{
await using (await connectionFactory.OpenNewConnection(cancellationToken).ConfigureAwait(false))
using (var connection = await connectionFactory.OpenNewConnection(cancellationToken).ConfigureAwait(false))
{
var result = ConnectionPoolValidator.Validate(connection.ConnectionString);

if (!result.IsValid)
{
Logger.Warn(result.Message);
}
}
}
catch (Exception ex) when (!ex.IsCausedBy(cancellationToken))
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,16 @@
using System.Threading.Tasks;
using System.Threading;
using System;
using NServiceBus.Logging;

public abstract class DbConnectionFactory
{
public DbConnectionFactory(Func<CancellationToken, Task<DbConnection>> factory)
{
openNewConnection = factory;
}
protected DbConnectionFactory(Func<CancellationToken, Task<DbConnection>> factory) => openNewConnection = factory;

protected DbConnectionFactory()
{
}

public async Task<DbConnection> OpenNewConnection(CancellationToken cancellationToken = default)
{
var connection = await openNewConnection(cancellationToken).ConfigureAwait(false);

ValidateConnectionPool(connection.ConnectionString);

return connection;
}

protected void ValidateConnectionPool(string connectionString)
{
if (hasValidated)
{
return;
}

var validationResult = ConnectionPoolValidator.Validate(connectionString);
if (!validationResult.IsValid)
{
Logger.Warn(validationResult.Message);
}

hasValidated = true;
}

static bool hasValidated;
public Task<DbConnection> OpenNewConnection(CancellationToken cancellationToken = default) => openNewConnection(cancellationToken);

protected Func<CancellationToken, Task<DbConnection>> openNewConnection;

static ILog Logger = LogManager.GetLogger<DbConnectionFactory>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
<InternalsVisibleTo Include="NServiceBus.Transport.SqlServer" Key="$(NServiceBusKey)" />
<InternalsVisibleTo Include="NServiceBus.Transport.SqlServer.UnitTests" Key="$(NServiceBusTestsKey)" />
<InternalsVisibleTo Include="NServiceBus.Transport.PostgreSql" Key="$(NServiceBusKey)" />
<InternalsVisibleTo Include="NServiceBus.Transport.PostgreSql.UnitTests" Key="$(NServiceBusTestsKey)" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Threading;
using System.Threading.Tasks;
using NUnit.Framework;
using Sql.Shared.Queuing;
using SqlServer;

public class When_checking_schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Transactions;
using NUnit.Framework;
using Sql.Shared.Queuing;
using Sql.Shared.Receiving;
using SqlServer;
using Transport;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ namespace NServiceBus.Transport.SqlServer.IntegrationTests
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using Microsoft.Data.SqlClient;
using System.Threading.Tasks;
using Extensibility;
using NUnit.Framework;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
namespace NServiceBus.Transport.SqlServer.UnitTests
{
using NUnit.Framework;
using Sql.Shared;

[TestFixture]
public class ConnectionPoolValidatorTests
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace NServiceBus.Transport.SqlServer;

using System;
using System.Data.Common;
using Sql.Shared;

static class ConnectionPoolValidator
{
public static ValidationCheckResult Validate(string connectionString)
{
var keys = new DbConnectionStringBuilder { ConnectionString = connectionString };
var hasPoolingValue = keys.TryGetValue("Pooling", out object poolingValue);
if (hasPoolingValue && !string.Equals(poolingValue.ToString(), "true", StringComparison.InvariantCultureIgnoreCase))
{
return ValidationCheckResult.Valid();
}
if (keys.ContainsKey("Max Pool Size"))
{
return ValidationCheckResult.Valid();
}
return ValidationCheckResult.Invalid(ConnectionPoolSizeNotSet);
}

const string ConnectionPoolSizeNotSet =
"Maximum connection pooling value (Max Pool Size=N) is not " +
"configured on the provided connection string. The default value (100) will be used.";
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@ public SqlServerDbConnectionFactory(Func<CancellationToken, Task<DbConnection>>
{
}


public SqlServerDbConnectionFactory(string connectionString)
{
openNewConnection = async cancellationToken =>
{
ValidateConnectionPool(connectionString);

var connection = new SqlConnection(connectionString);
try
{
Expand All @@ -46,6 +43,6 @@ public SqlServerDbConnectionFactory(string connectionString)
};
}

static ILog Logger = LogManager.GetLogger<SqlServerDbConnectionFactory>();
static readonly ILog Logger = LogManager.GetLogger<SqlServerDbConnectionFactory>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ public async Task Initialize(CancellationToken cancellationToken = default)
}
}

var result = ConnectionPoolValidator.Validate(connectionString);

if (!result.IsValid)
{
Logger.Warn(result.Message);
}

connectionAttributes = ConnectionAttributesParser.Parse(connectionString, transport.DefaultCatalog);

addressTranslator = new QueueAddressTranslator(connectionAttributes.Catalog, "dbo", transport.DefaultSchema, transport.SchemaAndCatalog);
Expand Down Expand Up @@ -320,7 +327,7 @@ async Task TryEscalateToDistributedTransactions(TransactionOptions transactionOp

if (!string.IsNullOrWhiteSpace(message))
{
_logger.Warn(message);
Logger.Warn(message);
}
}
}
Expand Down Expand Up @@ -361,6 +368,6 @@ public override Task Shutdown(CancellationToken cancellationToken = default)
Distributed transactions are not available on Linux. The other transaction modes can be used by setting the `SqlServerTransport.TransportTransactionMode` property when configuring the endpoint.
Be aware that different transaction modes affect consistency guarantees since distributed transactions won't be atomically updating the resources together with consuming the incoming message.";

static ILog _logger = LogManager.GetLogger<SqlServerTransportInfrastructure>();
static ILog Logger = LogManager.GetLogger<SqlServerTransportInfrastructure>();
}
}

0 comments on commit 9d7ff39

Please sign in to comment.