Skip to content

Commit

Permalink
SNOW-1569293 Read encryption headers in a case insensitive way (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-knozderko authored Dec 4, 2024
1 parent 6d0ba1f commit f7201b2
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 39 deletions.
35 changes: 34 additions & 1 deletion Snowflake.Data.Tests/UnitTests/SFAzureClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace Snowflake.Data.Tests.UnitTests
using Azure;
using Azure.Storage.Blobs.Models;

[TestFixture]
[TestFixture, NonParallelizable]
class SFAzureClientTest : SFBaseTest
{
// Mock data for file metadata
Expand Down Expand Up @@ -377,5 +377,38 @@ public async Task TestDownloadFileAsync(HttpStatusCode httpStatusCode, ResultSta
// Assert
Assert.AreEqual(expectedResultStatus.ToString(), _fileMetadata.resultStatus);
}

[Test]
public void TestEncryptionMetadataReadingIsCaseInsensitive()
{
// arrange
var metadata = new Dictionary<string, string>
{
{
"ENCRYPTIONDATA",
@"{
""ContentEncryptionIV"": ""initVector"",
""WrappedContentKey"": {
""EncryptedKey"": ""key""
}
}"
},
{ "MATDESC", "description" },
{ "SFCDIGEST", "something"}
};
var blobProperties = BlobsModelFactory.BlobProperties(metadata: metadata, contentLength: 10);
var mockBlobServiceClient = new Mock<BlobServiceClient>();
_client = new SFSnowflakeAzureClient(_fileMetadata.stageInfo, mockBlobServiceClient.Object);

// act
var fileHeader = _client.HandleFileHeaderResponse(ref _fileMetadata, blobProperties);

// assert
Assert.AreEqual(ResultStatus.UPLOADED.ToString(), _fileMetadata.resultStatus);
Assert.AreEqual("something", fileHeader.digest);
Assert.AreEqual("initVector", fileHeader.encryptionMetadata.iv);
Assert.AreEqual("key", fileHeader.encryptionMetadata.key);
Assert.AreEqual("description", fileHeader.encryptionMetadata.matDesc);
}
}
}
63 changes: 51 additions & 12 deletions Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
*/

using System;
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.FileTransfer.StorageClient;
using Snowflake.Data.Core.FileTransfer;
using System.Collections.Generic;
using System.Net;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using System.Threading;
using Snowflake.Data.Tests.Mock;
using Moq;

namespace Snowflake.Data.Tests.UnitTests
{
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.FileTransfer.StorageClient;
using Snowflake.Data.Core.FileTransfer;
using System.Collections.Generic;
using System.Net;
using System.IO;
using System.Threading.Tasks;
using System.Threading;
using Snowflake.Data.Tests.Mock;
using Moq;

[TestFixture, NonParallelizable]
class SFGCSClientTest : SFBaseTest
{
Expand Down Expand Up @@ -371,6 +372,44 @@ public void TestUseUriWithRegionsWhenNeeded(string region, string endPoint, bool
Assert.AreEqual(expectedRequestUri, uri);
}

[Test]
[TestCase("some-header-name", "SOME-HEADER-NAME")]
[TestCase("SOME-HEADER-NAME", "some-header-name")]
public void TestGcsHeadersAreCaseInsensitiveForHttpResponseMessage(string headerNameToAdd, string headerNameToGet)
{
// arrange
const string HeaderValue = "someValue";
var responseMessage = new HttpResponseMessage( HttpStatusCode.OK ) {Content = new StringContent( "Response content" ) };
responseMessage.Headers.Add(headerNameToAdd, HeaderValue);

// act
var header = responseMessage.Headers.GetValues(headerNameToGet);

// assert
Assert.NotNull(header);
Assert.AreEqual(1, header.Count());
Assert.AreEqual(HeaderValue, header.First());
}

[Test]
[TestCase("some-header-name", "SOME-HEADER-NAME")]
[TestCase("SOME-HEADER-NAME", "some-header-name")]
public void TestGcsHeadersAreCaseInsensitiveForWebHeaderCollection(string headerNameToAdd, string headerNameToGet)
{
// arrange
const string HeaderValue = "someValue";
var headers = new WebHeaderCollection();
headers.Add(headerNameToAdd, HeaderValue);

// act
var header = headers.GetValues(headerNameToGet);

// assert
Assert.NotNull(header);
Assert.AreEqual(1, header.Count());
Assert.AreEqual(HeaderValue, header.First());
}

private void AssertForDownloadFileTests(ResultStatus expectedResultStatus)
{
if (expectedResultStatus == ResultStatus.DOWNLOADED)
Expand Down
55 changes: 38 additions & 17 deletions Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
/*
* Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
* Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
*/

using System;
using Amazon.S3.Encryption;
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.FileTransfer.StorageClient;
using Snowflake.Data.Core.FileTransfer;
using System.Collections.Generic;
using Amazon.S3;
using Snowflake.Data.Tests.Mock;
using System.Threading.Tasks;
using Amazon;
using System.Threading;
using System.IO;
using Moq;
using Amazon.S3.Model;

namespace Snowflake.Data.Tests.UnitTests
{
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.FileTransfer.StorageClient;
using Snowflake.Data.Core.FileTransfer;
using System.Collections.Generic;
using Amazon.S3;
using Snowflake.Data.Tests.Mock;
using System.Threading.Tasks;
using Amazon;
using System.Threading;
using System.IO;
using Moq;
using Amazon.S3.Model;

[TestFixture]
[TestFixture, NonParallelizable]
class SFS3ClientTest : SFBaseTest
{
// Mock data for file metadata
Expand Down Expand Up @@ -320,6 +318,29 @@ public async Task TestDownloadFileAsync(string awsStatusCode, ResultStatus expec
AssertForDownloadFileTests(expectedResultStatus);
}

[Test]
public void TestEncryptionMetadataReadingIsCaseInsensitive()
{
// arrange
var mockAmazonS3Client = new Mock<AmazonS3Client>(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig);
_client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object);
var response = new GetObjectResponse();
response.Metadata.Add(SFS3Client.AMZ_IV.ToUpper(), "initVector");
response.Metadata.Add(SFS3Client.AMZ_KEY.ToUpper(), "key");
response.Metadata.Add(SFS3Client.AMZ_MATDESC.ToUpper(), "description");
response.Metadata.Add(SFS3Client.SFC_DIGEST.ToUpper(), "something");

// act
var fileHeader = _client.HandleFileHeaderResponse(ref _fileMetadata, response);

// assert
Assert.AreEqual(ResultStatus.UPLOADED.ToString(), _fileMetadata.resultStatus);
Assert.AreEqual("something", fileHeader.digest);
Assert.AreEqual("initVector", fileHeader.encryptionMetadata.iv);
Assert.AreEqual("key", fileHeader.encryptionMetadata.key);
Assert.AreEqual("description", fileHeader.encryptionMetadata.matDesc);
}

private void AssertForDownloadFileTests(ResultStatus expectedResultStatus)
{
if (expectedResultStatus == ResultStatus.DOWNLOADED)
Expand Down
23 changes: 18 additions & 5 deletions Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Snowflake.Data.Log;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -266,26 +267,38 @@ private GetObjectRequest GetFileHeaderRequest(ref AmazonS3Client client, SFFileM
/// <param name="fileMetadata">The S3 file metadata.</param>
/// <param name="response">The Amazon S3 response.</param>
/// <returns>The file header of the S3 file.</returns>
private FileHeader HandleFileHeaderResponse(ref SFFileMetadata fileMetadata, GetObjectResponse response)
internal FileHeader HandleFileHeaderResponse(ref SFFileMetadata fileMetadata, GetObjectResponse response)
{
// Update the result status of the file metadata
fileMetadata.resultStatus = ResultStatus.UPLOADED.ToString();

SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata
{
iv = response.Metadata[AMZ_IV],
key = response.Metadata[AMZ_KEY],
matDesc = response.Metadata[AMZ_MATDESC]
iv = GetMetadataCaseInsensitive(response.Metadata, AMZ_IV),
key = GetMetadataCaseInsensitive(response.Metadata, AMZ_KEY),
matDesc = GetMetadataCaseInsensitive(response.Metadata, AMZ_MATDESC)
};

return new FileHeader
{
digest = response.Metadata[SFC_DIGEST],
digest = GetMetadataCaseInsensitive(response.Metadata, SFC_DIGEST),
contentLength = response.ContentLength,
encryptionMetadata = encryptionMetadata
};
}

private string GetMetadataCaseInsensitive(MetadataCollection metadataCollection, string metadataKey)
{
var value = metadataCollection[metadataKey];
if (value != null)
return value;
if (string.IsNullOrEmpty(metadataKey))
return null;
var keysCaseInsensitive = metadataCollection.Keys
.Where(key => $"x-amz-meta-{metadataKey}".Equals(key, StringComparison.OrdinalIgnoreCase));
return keysCaseInsensitive.Any() ? metadataCollection[keysCaseInsensitive.First()] : null;
}

/// <summary>
/// Set the client configuration common to both client with and without client-side
/// encryption.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Azure;
using Azure.Storage.Blobs.Models;
using Newtonsoft.Json;
Expand Down Expand Up @@ -154,30 +155,48 @@ public async Task<FileHeader> GetFileHeaderAsync(SFFileMetadata fileMetadata, Ca
/// <param name="fileMetadata">The S3 file metadata.</param>
/// <param name="response">The Amazon S3 response.</param>
/// <returns>The file header of the S3 file.</returns>
private FileHeader HandleFileHeaderResponse(ref SFFileMetadata fileMetadata, BlobProperties response)
internal FileHeader HandleFileHeaderResponse(ref SFFileMetadata fileMetadata, BlobProperties response)
{
fileMetadata.resultStatus = ResultStatus.UPLOADED.ToString();

SFEncryptionMetadata encryptionMetadata = null;
if (response.Metadata.TryGetValue("encryptiondata", out var encryptionDataStr))
if (TryGetMetadataValueCaseInsensitive(response, "encryptiondata", out var encryptionDataStr))
{
dynamic encryptionData = JsonConvert.DeserializeObject(encryptionDataStr);
encryptionMetadata = new SFEncryptionMetadata
{
iv = encryptionData["ContentEncryptionIV"],
key = encryptionData.WrappedContentKey["EncryptedKey"],
matDesc = response.Metadata["matdesc"]
matDesc = GetMetadataValueCaseInsensitive(response, "matdesc")
};
}

return new FileHeader
{
digest = response.Metadata["sfcdigest"],
digest = GetMetadataValueCaseInsensitive(response, "sfcdigest"),
contentLength = response.ContentLength,
encryptionMetadata = encryptionMetadata
};
}

private bool TryGetMetadataValueCaseInsensitive(BlobProperties properties, string metadataKey, out string metadataValue)
{
if (properties.Metadata.TryGetValue(metadataKey, out metadataValue))
return true;
if (string.IsNullOrEmpty(metadataKey))
return false;
var keysCaseInsensitive = properties.Metadata.Keys
.Where(key => metadataKey.Equals(key, StringComparison.OrdinalIgnoreCase));
return keysCaseInsensitive.Any() ? properties.Metadata.TryGetValue(keysCaseInsensitive.First(), out metadataValue) : false;
}

private string GetMetadataValueCaseInsensitive(BlobProperties properties, string metadataKey)
{
if (TryGetMetadataValueCaseInsensitive(properties, metadataKey, out var metadataValue))
return metadataValue;
throw new KeyNotFoundException($"The given key '{metadataKey}' was not present in the dictionary.");
}

/// <summary>
/// Upload the file to the Azure location.
/// </summary>
Expand Down

0 comments on commit f7201b2

Please sign in to comment.