Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #714 from Particular/encrypt-column
Browse files Browse the repository at this point in the history
Change read code to support Always Encrypted
WilliamBZA authored Nov 18, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
2 parents 6b6fe72 + d85153c commit c1147b3
Showing 11 changed files with 66 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ public async Task SetUp()

await ResetQueue(addressParser, sqlConnectionFactory);

queue = new TableBasedQueue(addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName);
queue = new TableBasedQueue(addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, false);
}

[Test]
Original file line number Diff line number Diff line change
@@ -106,7 +106,7 @@ public void Prepare()
async Task PrepareAsync()
{
var addressParser = new QueueAddressTranslator("nservicebus", "dbo", null, null);
var tableCache = new TableBasedQueueCache(addressParser);
var tableCache = new TableBasedQueueCache(addressParser, true);

await CreateOutputQueueIfNecessary(addressParser, sqlConnectionFactory);

@@ -119,8 +119,8 @@ Task PurgeOutputQueue(QueueAddressTranslator addressTranslator)
{
purger = new QueuePurger(sqlConnectionFactory);
var queueAddress = addressTranslator.Parse(validAddress).QualifiedTableName;
queue = new TableBasedQueue(queueAddress, validAddress);

queue = new TableBasedQueue(queueAddress, validAddress, true);
return purger.Purge(queue);
}

Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ public async Task SetUp()

await CreateQueueIfNotExists(addressParser, sqlConnectionFactory);

queue = new TableBasedQueue(addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName);
queue = new TableBasedQueue(addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, true);
}

[Test]
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ public async Task Should_stop_pumping_messages_after_first_unsuccessful_receive(

var pump = new MessagePump(
m => new ProcessWithNoTransaction(sqlConnectionFactory, null),
qa => qa == "input" ? (TableBasedQueue)inputQueue : new TableBasedQueue(parser.Parse(qa).QualifiedTableName, qa),
qa => qa == "input" ? (TableBasedQueue)inputQueue : new TableBasedQueue(parser.Parse(qa).QualifiedTableName, qa, true),
new QueuePurger(sqlConnectionFactory),
new NoOpExpiredMessagesPurger(),
new QueuePeeker(sqlConnectionFactory, new QueuePeekerOptions()),
@@ -82,7 +82,7 @@ class FakeTableBasedQueue : TableBasedQueue
int queueSize;
int successfulReceives;

public FakeTableBasedQueue(string address, int queueSize, int successfulReceives) : base(address, "")
public FakeTableBasedQueue(string address, int queueSize, int successfulReceives) : base(address, "", true)
{
this.queueSize = queueSize;
this.successfulReceives = successfulReceives;
Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ public void Prepare()
async Task PrepareAsync()
{
var addressParser = new QueueAddressTranslator("nservicebus", "dbo", null, new QueueSchemaAndCatalogSettings());
var tableCache = new TableBasedQueueCache(addressParser);
var tableCache = new TableBasedQueueCache(addressParser, true);

var connectionString = Environment.GetEnvironmentVariable("SqlServerTransportConnectionString");
if (string.IsNullOrEmpty(connectionString))
@@ -142,7 +142,7 @@ Task PurgeOutputQueue(QueueAddressTranslator addressParser)
{
purger = new QueuePurger(sqlConnectionFactory);
var queueAddress = addressParser.Parse(validAddress);
queue = new TableBasedQueue(queueAddress.QualifiedTableName, queueAddress.Address);
queue = new TableBasedQueue(queueAddress.QualifiedTableName, queueAddress.Address, true);

return purger.Purge(queue);
}
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ public TransportConfigurationResult Configure(SettingsHolder settings, Transport
var localAddress = settings.EndpointName();
return new TransportConfigurationResult
{
TransportInfrastructure = new SqlServerTransportInfrastructure("nservicebus", settings, connectionString, () => localAddress, () => logicalAddress)
TransportInfrastructure = new SqlServerTransportInfrastructure("nservicebus", settings, connectionString, () => localAddress, () => logicalAddress, false)
};
}

14 changes: 9 additions & 5 deletions src/NServiceBus.Transport.SqlServer/Queuing/MessageRow.cs
Original file line number Diff line number Diff line change
@@ -16,9 +16,9 @@ class MessageRow
{
MessageRow() { }

public static async Task<MessageReadResult> Read(SqlDataReader dataReader)
public static async Task<MessageReadResult> Read(SqlDataReader dataReader, bool isStreamSupported)
{
var row = await ReadRow(dataReader).ConfigureAwait(false);
var row = await ReadRow(dataReader, isStreamSupported).ConfigureAwait(false);
return row.TryParse();
}

@@ -46,17 +46,16 @@ public void PrepareSendCommand(SqlCommand command)
AddParameter(command, "Body", SqlDbType.VarBinary, bodyBytes, -1);
}

static async Task<MessageRow> ReadRow(SqlDataReader dataReader)
static async Task<MessageRow> ReadRow(SqlDataReader dataReader, bool isStreamSupported)
{
//HINT: we are assuming that dataReader is sequential. Order or reads is important !
return new MessageRow
{
id = await dataReader.GetFieldValueAsync<Guid>(0).ConfigureAwait(false),
correlationId = await GetNullableAsync<string>(dataReader, 1).ConfigureAwait(false),
replyToAddress = await GetNullableAsync<string>(dataReader, 2).ConfigureAwait(false),
expired = await dataReader.GetFieldValueAsync<int>(3).ConfigureAwait(false) == 1,
headers = await GetHeaders(dataReader, 4).ConfigureAwait(false),
bodyBytes = await GetBody(dataReader, 5).ConfigureAwait(false)
bodyBytes = isStreamSupported ? await GetBody(dataReader, 5).ConfigureAwait(false) : await GetNonStreamBody(dataReader, 5).ConfigureAwait(false)
};
}

@@ -106,6 +105,11 @@ static async Task<byte[]> GetBody(SqlDataReader dataReader, int bodyIndex)
}
}

static Task<byte[]> GetNonStreamBody(SqlDataReader dataReader, int bodyIndex)
{
return Task.FromResult((byte[])dataReader[bodyIndex]);
}

static async Task<T> GetNullableAsync<T>(SqlDataReader dataReader, int index) where T : class
{
if (await dataReader.IsDBNullAsync(index).ConfigureAwait(false))
17 changes: 12 additions & 5 deletions src/NServiceBus.Transport.SqlServer/Queuing/TableBasedQueue.cs
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ class TableBasedQueue
{
public string Name { get; }

public TableBasedQueue(string qualifiedTableName, string queueName)
public TableBasedQueue(string qualifiedTableName, string queueName, bool isStreamSupported)
{
#pragma warning disable 618
this.qualifiedTableName = qualifiedTableName;
@@ -28,6 +28,7 @@ public TableBasedQueue(string qualifiedTableName, string queueName)
checkExpiresIndexCommand = Format(SqlConstants.CheckIfExpiresIndexIsPresent, this.qualifiedTableName);
checkNonClusteredRowVersionIndexCommand = Format(SqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, this.qualifiedTableName);
checkHeadersColumnTypeCommand = Format(SqlConstants.CheckHeadersColumnType, this.qualifiedTableName);
this.isStreamSupported = isStreamSupported;
#pragma warning restore 618
}

@@ -70,17 +71,22 @@ public Task Send(OutgoingMessage message, TimeSpan timeToBeReceived, SqlConnecti
return SendRawMessage(messageRow, connection, transaction);
}

static async Task<MessageReadResult> ReadMessage(SqlCommand command)
async Task<MessageReadResult> ReadMessage(SqlCommand command)
{
// We need sequential access to not buffer everything into memory
using (var dataReader = await command.ExecuteReaderAsync(CommandBehavior.SingleRow | CommandBehavior.SequentialAccess).ConfigureAwait(false))
var behavior = CommandBehavior.SingleRow;
if (isStreamSupported)
{
behavior |= CommandBehavior.SequentialAccess;
}

using (var dataReader = await command.ExecuteReaderAsync(behavior).ConfigureAwait(false))
{
if (!await dataReader.ReadAsync().ConfigureAwait(false))
{
return MessageReadResult.NoMessage;
}

return await MessageRow.Read(dataReader).ConfigureAwait(false);
return await MessageRow.Read(dataReader, isStreamSupported).ConfigureAwait(false);
}
}

@@ -177,5 +183,6 @@ public override string ToString()
string checkExpiresIndexCommand;
string checkNonClusteredRowVersionIndexCommand;
string checkHeadersColumnTypeCommand;
bool isStreamSupported;
}
}
Original file line number Diff line number Diff line change
@@ -5,21 +5,23 @@ namespace NServiceBus.Transport.SqlServer

class TableBasedQueueCache
{
public TableBasedQueueCache(QueueAddressTranslator addressTranslator)
public TableBasedQueueCache(QueueAddressTranslator addressTranslator, bool isStreamSupported)
{
this.addressTranslator = addressTranslator;
this.isStreamSupported = isStreamSupported;
}

public TableBasedQueue Get(string destination)
{
var address = addressTranslator.Parse(destination);
var key = Tuple.Create(address.QualifiedTableName, address.Address);
var queue = cache.GetOrAdd(key, x => new TableBasedQueue(x.Item1, x.Item2));
var queue = cache.GetOrAdd(key, x => new TableBasedQueue(x.Item1, x.Item2, isStreamSupported));

return queue;
}

QueueAddressTranslator addressTranslator;
ConcurrentDictionary<Tuple<string, string>, TableBasedQueue> cache = new ConcurrentDictionary<Tuple<string, string>, TableBasedQueue>();
bool isStreamSupported;
}
}
26 changes: 25 additions & 1 deletion src/NServiceBus.Transport.SqlServer/SqlServerTransport.cs
Original file line number Diff line number Diff line change
@@ -43,8 +43,9 @@ static bool LegacyMultiInstanceModeTurnedOn(SettingsHolder settings)
public override TransportInfrastructure Initialize(SettingsHolder settings, string connectionString)
{
var catalog = GetDefaultCatalog(settings, connectionString);
var isEncrypted = IsEncrypted(settings, connectionString);

return new SqlServerTransportInfrastructure(catalog, settings, connectionString, settings.LocalAddress, settings.LogicalAddress);
return new SqlServerTransportInfrastructure(catalog, settings, connectionString, settings.LocalAddress, settings.LogicalAddress, isEncrypted);
}

static string GetDefaultCatalog(SettingsHolder settings, string connectionString)
@@ -71,5 +72,28 @@ static string GetDefaultCatalog(SettingsHolder settings, string connectionString
}
throw new Exception("Initial Catalog property is mandatory in the connection string.");
}

static bool IsEncrypted(SettingsHolder settings, string connectionString)
{
if (settings.TryGet(SettingsKeys.ConnectionFactoryOverride, out Func<Task<SqlConnection>> factoryOverride))
{
using (var connection = factoryOverride().GetAwaiter().GetResult())
{
connectionString = connection.ConnectionString;
}
}

var parser = new DbConnectionStringBuilder
{
ConnectionString = connectionString
};

if (parser.TryGetValue("Column Encryption Setting", out var enabled))
{
return ((string)enabled).Equals("enabled", StringComparison.InvariantCultureIgnoreCase);
}

return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -23,12 +23,13 @@ namespace NServiceBus.Transport.SqlServer
/// </summary>
class SqlServerTransportInfrastructure : TransportInfrastructure
{
internal SqlServerTransportInfrastructure(string catalog, SettingsHolder settings, string connectionString, Func<string> localAddress, Func<LogicalAddress> logicalAddress)
internal SqlServerTransportInfrastructure(string catalog, SettingsHolder settings, string connectionString, Func<string> localAddress, Func<LogicalAddress> logicalAddress, bool isEncrypted)
{
this.settings = settings;
this.connectionString = connectionString;
this.localAddress = localAddress;
this.logicalAddress = logicalAddress;
this.isEncrypted = isEncrypted;

if (settings.HasSetting(SettingsKeys.DisableNativePubSub))
{
@@ -43,7 +44,7 @@ internal SqlServerTransportInfrastructure(string catalog, SettingsHolder setting

var queueSchemaSettings = settings.GetOrDefault<QueueSchemaAndCatalogSettings>();
addressTranslator = new QueueAddressTranslator(catalog, "dbo", defaultSchemaOverride, queueSchemaSettings);
tableBasedQueueCache = new TableBasedQueueCache(addressTranslator);
tableBasedQueueCache = new TableBasedQueueCache(addressTranslator, !isEncrypted);
connectionFactory = CreateConnectionFactory();

//Configure the schema and catalog for logical endpoint-based routing
@@ -156,7 +157,7 @@ public override TransportReceiveInfrastructure ConfigureReceiveInfrastructure()

var schemaVerification = new SchemaInspector(queue => connectionFactory.OpenNewConnection(), validateExpiredIndex);

Func<string, TableBasedQueue> queueFactory = queueName => new TableBasedQueue(addressTranslator.Parse(queueName).QualifiedTableName, queueName);
Func<string, TableBasedQueue> queueFactory = queueName => new TableBasedQueue(addressTranslator.Parse(queueName).QualifiedTableName, queueName, !isEncrypted);

//Create delayed delivery infrastructure
CanonicalQueueAddress delayedQueueCanonicalAddress = null;
@@ -398,6 +399,8 @@ public override string MakeCanonicalForm(string transportAddress)
ISubscriptionStore subscriptionStore;
IDelayedMessageStore delayedMessageStore = new SendOnlyDelayedMessageStore();
TableBasedQueueCache tableBasedQueueCache;
bool isEncrypted;

static ILog Logger = LogManager.GetLogger<SqlServerTransportInfrastructure>();
}
}

0 comments on commit c1147b3

Please sign in to comment.