Skip to content

Commit

Permalink
Use SQL system catalog views to check for the presence of a Recoverab…
Browse files Browse the repository at this point in the history
…le column. This removes the need for SELECT permissions to send a message to a queue table. (#1451)

Co-authored-by: Marc Wils <[email protected]>
  • Loading branch information
tmasternak and MarcWils authored Oct 17, 2024
1 parent 95b66cc commit a12efe0
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async Task SetupInputQueue()
(address, isStreamSupported) =>
{
var canonicalAddress = addressTranslator.Parse(address);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
},
s => addressTranslator.Parse(s).Address,
true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public async Task SetUp()

await ResetQueue(addressParser, dbConnectionFactory);

queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, false);
queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName), QueueTableName, false);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
(address, isStreamSupported) =>
{
var canonicalAddress = addressTranslator.Parse(address);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
},
s => addressTranslator.Parse(s).Address,
true);
Expand All @@ -122,7 +122,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
Task PurgeOutputQueue(QueueAddressTranslator addressTranslator, CancellationToken cancellationToken = default)
{
purger = new QueuePurger(dbConnectionFactory);
var queueAddress = addressTranslator.Parse(ValidAddress).QualifiedTableName;
var queueAddress = addressTranslator.Parse(ValidAddress);
queue = new SqlTableBasedQueue(sqlConstants, queueAddress, ValidAddress, true);

return purger.Purge(queue, cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public async Task SetUp()

await CreateQueueIfNotExists(addressParser, dbConnectionFactory);

queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, true);
queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName), QueueTableName, true);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task Should_stop_receiving_messages_after_first_unsuccessful_receiv
transport.Testing.QueueFactoryOverride = qa =>
qa == inputQueueAddress
? inputQueue
: new SqlTableBasedQueue(sqlConstants, parser.Parse(qa).QualifiedTableName, qa, true);
: new SqlTableBasedQueue(sqlConstants, parser.Parse(qa), qa, true);

var receiveSettings = new ReceiveSettings("receiver", new Transport.QueueAddress(inputQueueName), true, false, "error");
var hostSettings = new HostSettings("IntegrationTests", string.Empty, new StartupDiagnosticEntries(),
Expand Down Expand Up @@ -95,7 +95,7 @@ class FakeTableBasedQueue : SqlTableBasedQueue
int queueSize;
int successfulReceives;

public FakeTableBasedQueue(SqlServerConstants sqlConstants, string address, int queueSize, int successfulReceives) : base(sqlConstants, address, "", true)
public FakeTableBasedQueue(SqlServerConstants sqlConstants, string address, int queueSize, int successfulReceives) : base(sqlConstants, new QueueAddressTranslator("nservicebus", "dbo", null, new QueueSchemaAndCatalogOptions()).Parse(address), "", true)
{
this.queueSize = queueSize;
this.successfulReceives = successfulReceives;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task Should_recover(Type contextProviderType, DispatchConsistency d
(address, isStreamSupported) =>
{
var canonicalAddress = addressTranslator.Parse(address);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
},
s => addressTranslator.Parse(s).Address,
true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
(address, isStreamSupported) =>
{
var canonicalAddress = addressTranslator.Parse(address);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
},
s => addressTranslator.Parse(s).Address,
true);
Expand All @@ -147,7 +147,7 @@ Task PurgeOutputQueue(QueueAddressTranslator addressParser, CancellationToken ca
{
purger = new QueuePurger(dbConnectionFactory);
var queueAddress = addressParser.Parse(ValidAddress);
queue = new SqlTableBasedQueue(sqlConstants, queueAddress.QualifiedTableName, queueAddress.Address, true);
queue = new SqlTableBasedQueue(sqlConstants, queueAddress, queueAddress.Address, true);

return purger.Purge(queue, cancellationToken);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ namespace NServiceBus.TransportTests;
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using NUnit.Framework;
using Transport;
using Transport.SqlServer;
Expand Down Expand Up @@ -44,15 +45,15 @@ public async Task Peeker_should_provide_accurate_queue_length_estimate(Transport
Assert.That(peekCount, Is.EqualTo(1), "A long running receive transaction should not skew the estimation for number of messages in the queue.");
}

static async Task<SqlTableBasedQueue> CreateATestQueue(SqlServerDbConnectionFactory connectionFactory)
async Task<SqlTableBasedQueue> CreateATestQueue(SqlServerDbConnectionFactory connectionFactory)
{
var queueName = "queue_length_estimation_test";

var sqlConstants = new SqlServerConstants();

var queue = new SqlTableBasedQueue(sqlConstants, queueName, queueName, false);
var queue = new SqlTableBasedQueue(sqlConstants, new CanonicalQueueAddress(queueName, "dbo", catalogName), queueName, false);

var addressTranslator = new QueueAddressTranslator("nservicebus", "dbo", null, null);
var addressTranslator = new QueueAddressTranslator(catalogName, "dbo", null, null);
var queueCreator = new QueueCreator(sqlConstants, connectionFactory, addressTranslator.Parse, false);

await queueCreator.CreateQueueIfNecessary(new[] { queueName }, null);
Expand Down Expand Up @@ -98,6 +99,11 @@ await queue.Send(
[SetUp]
public async Task Setup()
{
var connectionString = ConfigureSqlServerTransportInfrastructure.ConnectionString;
var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString);

catalogName = connectionStringBuilder.InitialCatalog;

connectionFactory = new SqlServerDbConnectionFactory(ConfigureSqlServerTransportInfrastructure.ConnectionString);

queue = await CreateATestQueue(connectionFactory);
Expand All @@ -119,6 +125,7 @@ public async Task TearDown()
await comm.ExecuteNonQueryAsync(CancellationToken.None);
}

string catalogName;
SqlTableBasedQueue queue;
SqlServerDbConnectionFactory connectionFactory;
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ THEN DATEADD(ms, @TimeToBeReceivedMs, GETUTCDATE()) END,
IF (@NOCOUNT = 'ON') SET NOCOUNT ON;
IF (@NOCOUNT = 'OFF') SET NOCOUNT OFF;";

public string CheckIfTableHasRecoverableText { get; set; } = "SELECT TOP (0) * FROM {0} WITH (NOLOCK);";
public string CheckIfTableHasRecoverableText { get; set; } = @"
SELECT COUNT(*)
FROM {0}.sys.columns c
WHERE c.object_id = OBJECT_ID(N'{1}')
AND c.name = 'Recoverable'";

public string StoreDelayedMessageText { get; set; } =
@"
Expand Down
47 changes: 23 additions & 24 deletions src/NServiceBus.Transport.SqlServer/Queuing/SqlTableBasedQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@

class SqlTableBasedQueue : TableBasedQueue
{
public SqlTableBasedQueue(SqlServerConstants sqlConstants, string qualifiedTableName, string queueName, bool isStreamSupported) :
base(sqlConstants, qualifiedTableName, queueName, isStreamSupported)
public SqlTableBasedQueue(SqlServerConstants sqlConstants, CanonicalQueueAddress queueAddress, string queueName, bool isStreamSupported) :
base(sqlConstants, queueAddress.QualifiedTableName, queueName, isStreamSupported)
{
sqlServerConstants = sqlConstants;

purgeExpiredCommand = Format(sqlConstants.PurgeBatchOfExpiredMessagesText, this.qualifiedTableName);
checkExpiresIndexCommand = Format(sqlConstants.CheckIfExpiresIndexIsPresent, this.qualifiedTableName);
checkNonClusteredRowVersionIndexCommand = Format(sqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, this.qualifiedTableName);
checkHeadersColumnTypeCommand = Format(sqlConstants.CheckHeadersColumnType, this.qualifiedTableName);
purgeExpiredCommand = Format(sqlConstants.PurgeBatchOfExpiredMessagesText, qualifiedTableName);
checkExpiresIndexCommand = Format(sqlConstants.CheckIfExpiresIndexIsPresent, qualifiedTableName);
checkNonClusteredRowVersionIndexCommand = Format(sqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, qualifiedTableName);
checkHeadersColumnTypeCommand = Format(sqlConstants.CheckHeadersColumnType, qualifiedTableName);
checkRecoverableColumnCommand = Format(sqlConstants.CheckIfTableHasRecoverableText, queueAddress.Catalog, qualifiedTableName);
}

public async Task<int> PurgeBatchOfExpiredMessages(DbConnection connection, int purgeBatchSize, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -85,7 +86,7 @@ protected override async Task SendRawMessage(MessageRow message, DbConnection co

message.PrepareSendCommand(command);

await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
_ = await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
// 207 = Invalid column name
Expand Down Expand Up @@ -123,27 +124,24 @@ async Task<string> GetSendCommandText(DbConnection connection, DbTransaction tra
return sendCommand;
}

var commandText = Format(sqlServerConstants.CheckIfTableHasRecoverableText, qualifiedTableName);
using (var command = connection.CreateCommand())
{
command.CommandText = checkRecoverableColumnCommand;
command.CommandType = CommandType.Text;
command.CommandText = commandText;
command.Transaction = transaction;

using (var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
var rowsCount = await command.ExecuteScalarAsync<int>(nameof(checkRecoverableColumnCommand), cancellationToken).ConfigureAwait(false);
if (rowsCount > 0)
{
for (int fieldIndex = 0; fieldIndex < reader.FieldCount; fieldIndex++)
{
if (string.Equals("Recoverable", reader.GetName(fieldIndex), StringComparison.OrdinalIgnoreCase))
{
cachedSendCommand = Format(sqlServerConstants.SendTextWithRecoverable, qualifiedTableName);
return cachedSendCommand;
}
}
cachedSendCommand = Format(sqlServerConstants.SendTextWithRecoverable, qualifiedTableName);
return cachedSendCommand;
}
else
{

cachedSendCommand = Format(sqlServerConstants.SendTextWithoutRecoverable, qualifiedTableName);
return cachedSendCommand;
cachedSendCommand = Format(sqlServerConstants.SendTextWithoutRecoverable, qualifiedTableName);
return cachedSendCommand;
}
}
}
finally
Expand All @@ -153,10 +151,11 @@ async Task<string> GetSendCommandText(DbConnection connection, DbTransaction tra
}

string cachedSendCommand;
string purgeExpiredCommand;
string checkExpiresIndexCommand;
string checkNonClusteredRowVersionIndexCommand;
string checkHeadersColumnTypeCommand;
readonly string purgeExpiredCommand;
readonly string checkExpiresIndexCommand;
readonly string checkNonClusteredRowVersionIndexCommand;
readonly string checkHeadersColumnTypeCommand;
readonly string checkRecoverableColumnCommand;
readonly SemaphoreSlim sendCommandLock = new SemaphoreSlim(1, 1);
readonly SqlServerConstants sqlServerConstants;
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public async Task Initialize(CancellationToken cancellationToken = default)
(address, isStreamSupported) =>
{
var canonicalAddress = addressTranslator.Parse(address);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
},
s => addressTranslator.Parse(s).Address,
!connectionAttributes.IsEncrypted);
Expand Down Expand Up @@ -161,7 +161,7 @@ async Task ConfigureReceiveInfrastructure(CancellationToken cancellationToken)

var schemaVerification = new SchemaInspector((queue, token) => connectionFactory.OpenNewConnection(token), validateExpiredIndex);

var queueFactory = transport.Testing.QueueFactoryOverride ?? (queueName => new SqlTableBasedQueue(sqlConstants, addressTranslator.Parse(queueName).QualifiedTableName, queueName, !connectionAttributes.IsEncrypted));
var queueFactory = transport.Testing.QueueFactoryOverride ?? (queueName => new SqlTableBasedQueue(sqlConstants, addressTranslator.Parse(queueName), queueName, !connectionAttributes.IsEncrypted));

//Create delayed delivery infrastructure
CanonicalQueueAddress delayedQueueCanonicalAddress = null;
Expand Down

0 comments on commit a12efe0

Please sign in to comment.