From 6212da4cc4a97b32aed8591bd58324a8b1328bca Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Thu, 29 Aug 2024 09:03:50 +0000 Subject: [PATCH 1/6] SNOW-1640968 dispose http message --- Snowflake.Data/Core/RestRequester.cs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Snowflake.Data/Core/RestRequester.cs b/Snowflake.Data/Core/RestRequester.cs index a38a11f70..4c2c7ee39 100644 --- a/Snowflake.Data/Core/RestRequester.cs +++ b/Snowflake.Data/Core/RestRequester.cs @@ -93,8 +93,10 @@ private async Task SendAsync(HttpMethod method, IRestRequest request, CancellationToken externalCancellationToken) { - HttpRequestMessage message = request.ToRequestMessage(method); - return await SendAsync(message, request.GetRestTimeout(), externalCancellationToken, request.getSid()).ConfigureAwait(false); + using (HttpRequestMessage message = request.ToRequestMessage(method)) + { + return await SendAsync(message, request.GetRestTimeout(), externalCancellationToken, request.getSid()).ConfigureAwait(false); + } } protected virtual async Task SendAsync(HttpRequestMessage message, @@ -130,7 +132,7 @@ protected virtual async Task SendAsync(HttpRequestMessage m } catch (Exception e) { - // Disposing of the response if not null now that we don't need it anymore + // Disposing of the response if not null now that we don't need it anymore response?.Dispose(); if (restRequestTimeout.IsCancellationRequested) { From c1eb279cb65053e74fd215ac8551e0409e5b8da5 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Thu, 29 Aug 2024 10:15:30 +0000 Subject: [PATCH 2/6] dispose streams --- .../UnitTests/ConcatenatedStreamTest.cs | 11 +- .../Util}/ConcatenatedStream.cs | 4 +- .../Authenticator/KeyPairAuthenticator.cs | 64 +++++---- .../Core/FileTransfer/EncryptionProvider.cs | 124 ++++++++++-------- .../StorageClient/SFLocalStorageUtil.cs | 2 +- .../FileTransfer/StorageClient/SFS3Client.cs | 2 +- Snowflake.Data/Core/SFBindUploader.cs | 90 +++++++------ .../Core/SFBlockingChunkDownloaderV3.cs | 16 ++- 8 files changed, 176 insertions(+), 137 deletions(-) rename {Snowflake.Data/Core => Snowflake.Data.Tests/Util}/ConcatenatedStream.cs (91%) diff --git a/Snowflake.Data.Tests/UnitTests/ConcatenatedStreamTest.cs b/Snowflake.Data.Tests/UnitTests/ConcatenatedStreamTest.cs index db06693c7..4ba3ae83d 100644 --- a/Snowflake.Data.Tests/UnitTests/ConcatenatedStreamTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConcatenatedStreamTest.cs @@ -2,13 +2,14 @@ * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. */ +using Snowflake.Data.Tests.Util; +using NUnit.Framework; +using System; +using System.IO; +using System.Text; + namespace Snowflake.Data.Tests.UnitTests { - using NUnit.Framework; - using Snowflake.Data.Core; - using System; - using System.IO; - using System.Text; [TestFixture] class ConcatenatedStreamTest diff --git a/Snowflake.Data/Core/ConcatenatedStream.cs b/Snowflake.Data.Tests/Util/ConcatenatedStream.cs similarity index 91% rename from Snowflake.Data/Core/ConcatenatedStream.cs rename to Snowflake.Data.Tests/Util/ConcatenatedStream.cs index 41d05dbb7..724d890ec 100755 --- a/Snowflake.Data/Core/ConcatenatedStream.cs +++ b/Snowflake.Data.Tests/Util/ConcatenatedStream.cs @@ -1,12 +1,12 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; using System.Collections.Generic; using System.IO; -namespace Snowflake.Data.Core +namespace Snowflake.Data.Tests.Util { /// /// Used to concat multiple streams without copying. Since we need to preappend '[' and append ']' diff --git a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs index e0c28d4ef..7d86d02c9 100644 --- a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs @@ -100,36 +100,29 @@ private string GenerateJwtToken() { try { - PemReader pr = null; - if (null != pkPwd) + using (PemReader pr = CreatePemReader(tr, pkPwd)) { - IPasswordFinder ipwdf = new PasswordFinder(pkPwd); - pr = new PemReader(tr, ipwdf); - } - else - { - pr = new PemReader(tr); - } - - object key = pr.ReadObject(); - // Infer what the pem reader is sending back based on the object properties - if (key.GetType().GetProperty("Private") != null) - { - // PKCS1 key - keypair = (AsymmetricCipherKeyPair)key; - rsaParams = DotNetUtilities.ToRSAParameters( - keypair.Private as RsaPrivateCrtKeyParameters); - } - else - { - // PKCS8 key - RsaPrivateCrtKeyParameters pk = (RsaPrivateCrtKeyParameters)key; - rsaParams = DotNetUtilities.ToRSAParameters(pk); - keypair = DotNetUtilities.GetRsaKeyPair(rsaParams); - } - if (keypair == null) - { - throw new Exception("Unknown error."); + object key = pr.ReadObject(); + // Infer what the pem reader is sending back based on the object properties + if (key.GetType().GetProperty("Private") != null) + { + // PKCS1 key + keypair = (AsymmetricCipherKeyPair)key; + rsaParams = DotNetUtilities.ToRSAParameters( + keypair.Private as RsaPrivateCrtKeyParameters); + } + else + { + // PKCS8 key + RsaPrivateCrtKeyParameters pk = (RsaPrivateCrtKeyParameters)key; + rsaParams = DotNetUtilities.ToRSAParameters(pk); + keypair = DotNetUtilities.GetRsaKeyPair(rsaParams); + } + + if (keypair == null) + { + throw new Exception("Unknown error."); + } } } catch (Exception e) @@ -207,6 +200,19 @@ private string GenerateJwtToken() return jwtToken; } + private PemReader CreatePemReader(TextReader textReader, string privateKeyPassword) + { + if (null != privateKeyPassword) + { + IPasswordFinder ipwdf = new PasswordFinder(privateKeyPassword); + return new PemReader(textReader, ipwdf); + } + else + { + return new PemReader(textReader); + } + } + /// /// Helper class to handle the password for the certificate if there is one. /// diff --git a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs index 463363c6c..edbee426e 100644 --- a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs +++ b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs @@ -10,7 +10,7 @@ namespace Snowflake.Data.Core.FileTransfer { /// - /// The encryption materials. + /// The encryption materials. /// internal class MaterialDescriptor { @@ -22,7 +22,7 @@ internal class MaterialDescriptor } /// - /// The encryptor/decryptor for PUT/GET files. + /// The encryptor/decryptor for PUT/GET files. /// class EncryptionProvider { @@ -89,7 +89,7 @@ public static Stream EncryptStream( // Encrypt file key byte[] encryptedFileKey = encryptFileKey(decodedMasterKey, keyData); - + // Store encryption metadata information MaterialDescriptor matDesc = new MaterialDescriptor { @@ -116,17 +116,21 @@ public static Stream EncryptStream( /// The encrypted key. private static byte[] encryptFileKey(byte[] masterKey, byte[] unencryptedFileKey) { - Aes aes = Aes.Create(); - aes.Key = masterKey; - aes.Mode = CipherMode.ECB; - aes.Padding = PaddingMode.PKCS7; - - MemoryStream cipherStream = new MemoryStream(); - CryptoStream cryptoStream = new CryptoStream(cipherStream, aes.CreateEncryptor(), CryptoStreamMode.Write); - cryptoStream.Write(unencryptedFileKey, 0, unencryptedFileKey.Length); - cryptoStream.FlushFinalBlock(); + using (Aes aes = Aes.Create()) + { + aes.Key = masterKey; + aes.Mode = CipherMode.ECB; + aes.Padding = PaddingMode.PKCS7; - return cipherStream.ToArray(); + using (MemoryStream cipherStream = new MemoryStream()) + using (var encryptor = aes.CreateEncryptor()) + using (CryptoStream cryptoStream = new CryptoStream(cipherStream, encryptor, CryptoStreamMode.Write)) + { + cryptoStream.Write(unencryptedFileKey, 0, unencryptedFileKey.Length); + cryptoStream.FlushFinalBlock(); + return cipherStream.ToArray(); + } + } } /// @@ -143,26 +147,31 @@ private static Stream CreateEncryptedBytesStream( byte[] iv, FileTransferConfiguration transferConfiguration) { - Aes aes = Aes.Create(); - aes.Key = key; - aes.Mode = CipherMode.CBC; - aes.Padding = PaddingMode.PKCS7; - aes.IV = iv; - inputStream.Position = 0; - - var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); - CryptoStream cryptoStream = new CryptoStream(targetStream, aes.CreateEncryptor(), CryptoStreamMode.Write); - byte[] buffer = new byte[transferConfiguration.MaxBytesInMemory]; - int bytesRead; - while ((bytesRead = inputStream.Read(buffer, 0, buffer.Length)) > 0) + using (Aes aes = Aes.Create()) { - cryptoStream.Write(buffer, 0, bytesRead); + aes.Key = key; + aes.Mode = CipherMode.CBC; + aes.Padding = PaddingMode.PKCS7; + aes.IV = iv; + inputStream.Position = 0; + + var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); + using (var encryptor = aes.CreateEncryptor()) + using (CryptoStream cryptoStream = new CryptoStream(targetStream, encryptor, CryptoStreamMode.Write)) + { + byte[] buffer = new byte[transferConfiguration.MaxBytesInMemory]; + int bytesRead; + while ((bytesRead = inputStream.Read(buffer, 0, buffer.Length)) > 0) + { + cryptoStream.Write(buffer, 0, bytesRead); + } + cryptoStream.FlushFinalBlock(); + + return targetStream; + } } - cryptoStream.FlushFinalBlock(); - - return targetStream; } - + /// /// Decrypt data and write to the outStream. /// @@ -218,17 +227,22 @@ public static string DecryptFile( /// The encrypted key. private static byte[] decryptFileKey(byte[] masterKey, byte[] unencryptedFileKey) { - Aes aes = Aes.Create(); - aes.Key = masterKey; - aes.Mode = CipherMode.ECB; - aes.Padding = PaddingMode.PKCS7; + using (Aes aes = Aes.Create()) + { + aes.Key = masterKey; + aes.Mode = CipherMode.ECB; + aes.Padding = PaddingMode.PKCS7; - MemoryStream cipherStream = new MemoryStream(); - CryptoStream cryptoStream = new CryptoStream(cipherStream, aes.CreateDecryptor(), CryptoStreamMode.Write); - cryptoStream.Write(unencryptedFileKey, 0, unencryptedFileKey.Length); - cryptoStream.FlushFinalBlock(); + using (MemoryStream cipherStream = new MemoryStream()) + using (var encryptor = aes.CreateDecryptor()) + using (CryptoStream cryptoStream = new CryptoStream(cipherStream, encryptor, CryptoStreamMode.Write)) + { + cryptoStream.Write(unencryptedFileKey, 0, unencryptedFileKey.Length); + cryptoStream.FlushFinalBlock(); - return cipherStream.ToArray(); + return cipherStream.ToArray(); + } + } } /// @@ -244,22 +258,26 @@ private static Stream CreateDecryptedBytesStream( byte[] iv, FileTransferConfiguration transferConfiguration) { - Aes aes = Aes.Create(); - aes.Key = key; - aes.Mode = CipherMode.CBC; - aes.Padding = PaddingMode.PKCS7; - aes.IV = iv; - - var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); - CryptoStream cryptoStream = new CryptoStream(targetStream, aes.CreateDecryptor(), CryptoStreamMode.Write); - - using(Stream inStream = File.OpenRead(inFile)) + using (Aes aes = Aes.Create()) { - inStream.CopyTo(cryptoStream); - } - cryptoStream.FlushFinalBlock(); + aes.Key = key; + aes.Mode = CipherMode.CBC; + aes.Padding = PaddingMode.PKCS7; + aes.IV = iv; + + var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); + using (var decryptor = aes.CreateDecryptor()) + using (CryptoStream cryptoStream = new CryptoStream(targetStream, decryptor, CryptoStreamMode.Write)) + { + using (Stream inStream = File.OpenRead(inFile)) + { + inStream.CopyTo(cryptoStream); + } + cryptoStream.FlushFinalBlock(); - return targetStream; + return targetStream; + } + } } } } diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs index 6b98f9fb1..b81581fca 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs @@ -62,7 +62,7 @@ internal static void DownloadOneFile(SFFileMetadata fileMetadata) } // Create stream object for reader and writer - Stream stream = new MemoryStream(File.ReadAllBytes(realSrcFilePath)); + using (Stream stream = new MemoryStream(File.ReadAllBytes(realSrcFilePath))) using (var fileStream = File.Create(output)) { // Write file diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs index 88b20c1d5..60d67b5d7 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs @@ -484,7 +484,7 @@ public async Task DownloadFileAsync(SFFileMetadata fileMetadata, string fullDstP try { // Issue the GET request - GetObjectResponse response = await client.GetObjectAsync(getObjectRequest, cancellationToken).ConfigureAwait(false); + using (GetObjectResponse response = await client.GetObjectAsync(getObjectRequest, cancellationToken).ConfigureAwait(false)) // Write to file using (var fileStream = File.Create(fullDstPath)) diff --git a/Snowflake.Data/Core/SFBindUploader.cs b/Snowflake.Data/Core/SFBindUploader.cs index 71dec60fb..d62a80e75 100644 --- a/Snowflake.Data/Core/SFBindUploader.cs +++ b/Snowflake.Data/Core/SFBindUploader.cs @@ -75,26 +75,36 @@ public void Upload(Dictionary bindings) StringBuilder sBuffer = new StringBuilder(); MemoryStream ms = new MemoryStream(); - StreamWriter tw = new StreamWriter(ms); - - for (int i = startIndex; i < rowNum; i++) - { - sBuffer.Append(dataRows[i]); - } - tw.Write(sBuffer.ToString()); - tw.Flush(); - try { - string fileName = (++fileCount).ToString(); - UploadStream(ref ms, fileName); - startIndex = rowNum; - curBytes = 0; + using (StreamWriter tw = new StreamWriter(ms)) + { + + for (int i = startIndex; i < rowNum; i++) + { + sBuffer.Append(dataRows[i]); + } + + tw.Write(sBuffer.ToString()); + tw.Flush(); + + try + { + string fileName = (++fileCount).ToString(); + UploadStream(ref ms, fileName); + startIndex = rowNum; + curBytes = 0; + } + catch (IOException e) + { + // failure using stream put + throw new Exception("file stream upload error." + e.ToString()); + } + } } - catch (IOException e) + finally { - // failure using stream put - throw new Exception("file stream upload error." + e.ToString()); + ms.Dispose(); } } } @@ -122,27 +132,29 @@ internal async Task UploadAsync(Dictionary bindings, Cancell } StringBuilder sBuffer = new StringBuilder(); - MemoryStream ms = new MemoryStream(); - StreamWriter tw = new StreamWriter(ms); - - for (int i = startIndex; i < rowNum; i++) + using (MemoryStream ms = new MemoryStream()) + using (StreamWriter tw = new StreamWriter(ms)) { - sBuffer.Append(dataRows[i]); - } - tw.Write(sBuffer.ToString()); - tw.Flush(); + for (int i = startIndex; i < rowNum; i++) + { + sBuffer.Append(dataRows[i]); + } - try - { - string fileName = (++fileCount).ToString(); - await UploadStreamAsync(ms, fileName, cancellationToken).ConfigureAwait(false); - startIndex = rowNum; - curBytes = 0; - } - catch (IOException e) - { - // failure using stream put - throw new Exception("file stream upload error." + e.ToString()); + tw.Write(sBuffer.ToString()); + tw.Flush(); + + try + { + string fileName = (++fileCount).ToString(); + await UploadStreamAsync(ms, fileName, cancellationToken).ConfigureAwait(false); + startIndex = rowNum; + curBytes = 0; + } + catch (IOException e) + { + // failure using stream put + throw new Exception("file stream upload error." + e.ToString()); + } } } } @@ -152,7 +164,7 @@ private void CreateDataRows(ref List dataRows, Dictionary arrbinds = bindings.Values.ToList(); List> bindList = new List>(); List types = new List(); // for the binding types - dataRows = new List(); // for the converted data string + dataRows = new List(); // for the converted data string int rowSize = ((List)arrbinds[0].value).Count; int paramSize = arrbinds.Count; @@ -202,7 +214,7 @@ private void UploadStream(ref MemoryStream stream, string destFileName) SFStatement statement = new SFStatement(session); statement.SetUploadStream(stream, destFileName, stagePath); statement.ExecuteTransfer(putStmt); - + } internal async Task UploadStreamAsync(MemoryStream stream, string destFileName, CancellationToken cancellationToken) @@ -224,7 +236,7 @@ internal async Task UploadStreamAsync(MemoryStream stream, string destFileName, statement.SetUploadStream(stream, destFileName, stagePath); await statement.ExecuteTransferAsync(putStmt, cancellationToken).ConfigureAwait(false); } - + internal string GetCSVData(string sType, string sValue) { if (sValue == null) @@ -264,7 +276,7 @@ internal string GetCSVData(string sType, string sValue) case "TIMESTAMP_TZ": string[] tstzString = sValue.Split(' '); long nsFromEpochTz = long.Parse(tstzString[0]); // SFDateConverter provides in [ns] from Epoch - int timeZoneOffset = int.Parse(tstzString[1]) - 1440; // SFDateConverter provides in minutes increased by 1440m + int timeZoneOffset = int.Parse(tstzString[1]) - 1440; // SFDateConverter provides in minutes increased by 1440m DateTime timestamp = epoch.AddTicks(nsFromEpochTz/100).AddMinutes(timeZoneOffset); TimeSpan offset = TimeSpan.FromMinutes(timeZoneOffset); DateTimeOffset tzDateTimeOffset = new DateTimeOffset(timestamp.Ticks, offset); diff --git a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs index 33451a0da..52fc754b5 100755 --- a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs +++ b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs @@ -70,11 +70,11 @@ public SFBlockingChunkDownloaderV3(int colCount, for (int i=0; i DownloadChunkAsync(DownloadContextV3 downloa { if (String.Compare(encoding.First(), "gzip", true) == 0) { - Stream stream_gzip = new GZipStream(stream, CompressionMode.Decompress); - await ParseStreamIntoChunk(stream_gzip, chunk); + using (Stream streamGzip = new GZipStream(stream, CompressionMode.Decompress)) + { + await ParseStreamIntoChunk(streamGzip, chunk).ConfigureAwait(false); + } } else { - await ParseStreamIntoChunk(stream, chunk); + await ParseStreamIntoChunk(stream, chunk).ConfigureAwait(false); } } else { - await ParseStreamIntoChunk(stream, chunk); + await ParseStreamIntoChunk(stream, chunk).ConfigureAwait(false); } } catch (Exception e) @@ -214,7 +216,7 @@ private async Task DownloadChunkAsync(DownloadContextV3 downloa logger.Info($"Succeed downloading chunk #{chunk.ChunkIndex}"); return chunk; } - + private async Task ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk) { IChunkParser parser = ChunkParserFactory.Instance.GetParser(resultChunk.ResultFormat, content); From d108ac222853c25154ddd54fe47a61ad2c642e07 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Fri, 30 Aug 2024 13:04:54 +0000 Subject: [PATCH 3/6] fix the problem with disposing encrypted streams --- .../UnitTests/SFRemoteStorageClientTest.cs | 4 +-- .../Core/FileTransfer/EncryptionProvider.cs | 31 ++++++++++++------- .../FileTransfer/FileBackedOutputStream.cs | 2 +- .../StorageClient/SFRemoteStorageUtil.cs | 26 ++++++++-------- .../Core/FileTransfer/StreamPair.cs | 17 ++++++++++ 5 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 Snowflake.Data/Core/FileTransfer/StreamPair.cs diff --git a/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs index fa7df6250..c05a7f0f5 100644 --- a/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs @@ -527,14 +527,14 @@ private void SetUpMockEncryptedFileForDownload() // Get encrypted stream from file SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); - Stream stream = EncryptionProvider.EncryptFile( + StreamPair streamPair = EncryptionProvider.EncryptFile( t_downloadFileName, _fileMetadata.encryptionMaterial, encryptionMetadata, FileTransferConfiguration.FromFileMetadata(_fileMetadata)); // Set up the stream and metadata for decryption - MockRemoteStorageClient.SetEncryptionData(stream, encryptionMetadata.iv, encryptionMetadata.key); + MockRemoteStorageClient.SetEncryptionData(streamPair.MainStream, encryptionMetadata.iv, encryptionMetadata.key); } [Test] diff --git a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs index edbee426e..5135f42f7 100644 --- a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs +++ b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs @@ -43,7 +43,7 @@ class EncryptionProvider /// Store the encryption metadata into /// Contains parameters used during encryption process /// The encrypted bytes of the file to upload. - public static Stream EncryptFile( + public static StreamPair EncryptFile( string inFile, PutGetEncryptionMaterial encryptionMaterial, SFEncryptionMetadata encryptionMetadata, @@ -63,7 +63,7 @@ public static Stream EncryptFile( /// Store the encryption metadata into /// Contains parameters used during encryption process /// The encrypted bytes of the file to upload. - public static Stream EncryptStream( + public static StreamPair EncryptStream( Stream inputStream, PutGetEncryptionMaterial encryptionMaterial, SFEncryptionMetadata encryptionMetadata, @@ -141,7 +141,7 @@ private static byte[] encryptFileKey(byte[] masterKey, byte[] unencryptedFileKey /// The encryption IV or null if it needs to be generated. /// Contains parameters used during encryption process /// The encrypted bytes. - private static Stream CreateEncryptedBytesStream( + private static StreamPair CreateEncryptedBytesStream( Stream inputStream, byte[] key, byte[] iv, @@ -157,8 +157,8 @@ private static Stream CreateEncryptedBytesStream( var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); using (var encryptor = aes.CreateEncryptor()) - using (CryptoStream cryptoStream = new CryptoStream(targetStream, encryptor, CryptoStreamMode.Write)) { + CryptoStream cryptoStream = new CryptoStream(targetStream, encryptor, CryptoStreamMode.Write); byte[] buffer = new byte[transferConfiguration.MaxBytesInMemory]; int bytesRead; while ((bytesRead = inputStream.Read(buffer, 0, buffer.Length)) > 0) @@ -167,7 +167,11 @@ private static Stream CreateEncryptedBytesStream( } cryptoStream.FlushFinalBlock(); - return targetStream; + return new StreamPair + { + MainStream = targetStream, + HelperStream = cryptoStream // cryptoStream cannot be closed here because it would close target stream as well + }; } } } @@ -204,7 +208,7 @@ public static string DecryptFile( byte[] decryptedFileKey = decryptFileKey(decodedMasterKey, keyBytes); // Create key decipher with decoded key and AES ECB - using (var decryptedBytesStream = CreateDecryptedBytesStream( + using (var decryptedBytesStreamPair = CreateDecryptedBytesStream( inFile, decryptedFileKey, ivBytes, @@ -212,6 +216,7 @@ public static string DecryptFile( { using (var decryptedFileStream = File.Create(tempFileName)) { + var decryptedBytesStream = decryptedBytesStreamPair.MainStream; decryptedBytesStream.Position = 0; decryptedBytesStream.CopyTo(decryptedFileStream); } @@ -234,8 +239,8 @@ private static byte[] decryptFileKey(byte[] masterKey, byte[] unencryptedFileKey aes.Padding = PaddingMode.PKCS7; using (MemoryStream cipherStream = new MemoryStream()) - using (var encryptor = aes.CreateDecryptor()) - using (CryptoStream cryptoStream = new CryptoStream(cipherStream, encryptor, CryptoStreamMode.Write)) + using (var decryptor = aes.CreateDecryptor()) + using (CryptoStream cryptoStream = new CryptoStream(cipherStream, decryptor, CryptoStreamMode.Write)) { cryptoStream.Write(unencryptedFileKey, 0, unencryptedFileKey.Length); cryptoStream.FlushFinalBlock(); @@ -252,7 +257,7 @@ private static byte[] decryptFileKey(byte[] masterKey, byte[] unencryptedFileKey /// The encryption key. /// The encryption IV or null if it needs to be generated. /// The decrypted bytes stream - private static Stream CreateDecryptedBytesStream( + private static StreamPair CreateDecryptedBytesStream( string inFile, byte[] key, byte[] iv, @@ -267,15 +272,19 @@ private static Stream CreateDecryptedBytesStream( var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); using (var decryptor = aes.CreateDecryptor()) - using (CryptoStream cryptoStream = new CryptoStream(targetStream, decryptor, CryptoStreamMode.Write)) { + CryptoStream cryptoStream = new CryptoStream(targetStream, decryptor, CryptoStreamMode.Write); using (Stream inStream = File.OpenRead(inFile)) { inStream.CopyTo(cryptoStream); } cryptoStream.FlushFinalBlock(); - return targetStream; + return new StreamPair + { + MainStream = targetStream, + HelperStream = cryptoStream // cryptoStream cannot be closed here because it would close target stream as well + }; } } } diff --git a/Snowflake.Data/Core/FileTransfer/FileBackedOutputStream.cs b/Snowflake.Data/Core/FileTransfer/FileBackedOutputStream.cs index 3fea3f6ad..76707a060 100644 --- a/Snowflake.Data/Core/FileTransfer/FileBackedOutputStream.cs +++ b/Snowflake.Data/Core/FileTransfer/FileBackedOutputStream.cs @@ -89,7 +89,7 @@ internal bool IsUsingFileOutputStream() { return _fileOutputStream != null; } - + private void SwitchFromMemoryToTempFile() { _fileName = GenerateTempFilePath(); diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFRemoteStorageUtil.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFRemoteStorageUtil.cs index 92e8dd467..7ca2ee27d 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFRemoteStorageUtil.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFRemoteStorageUtil.cs @@ -11,7 +11,7 @@ namespace Snowflake.Data.Core.FileTransfer { /// - /// The class containing file header information. + /// The class containing file header information. /// internal class FileHeader { @@ -21,12 +21,12 @@ internal class FileHeader } /// - /// The interface for the storage clients. + /// The interface for the storage clients. /// class SFRemoteStorageUtil { /// - /// Strings to indicate specific storage type. + /// Strings to indicate specific storage type. /// public const string S3_FS = "S3"; public const string AZURE_FS = "AZURE"; @@ -34,12 +34,12 @@ class SFRemoteStorageUtil public const string LOCAL_FS = "LOCAL_FS"; /// - /// Amount of concurrency to use by default. + /// Amount of concurrency to use by default. /// const int DEFAULT_CONCURRENCY = 1; /// - /// Maximum amount of times to retry. + /// Maximum amount of times to retry. /// const int DEFAULT_MAX_RETRY = 5; @@ -87,7 +87,7 @@ internal static ISFRemoteStorageClient GetRemoteStorage(PutGetResponseData respo internal static void UploadOneFile(SFFileMetadata fileMetadata) { SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); - using (var fileBytesStream = GetFileBytesStream(fileMetadata, encryptionMetadata)) + using (var fileBytesStreamPair = GetFileBytesStream(fileMetadata, encryptionMetadata)) { int maxConcurrency = fileMetadata.parallel; @@ -116,7 +116,7 @@ internal static void UploadOneFile(SFFileMetadata fileMetadata) if (fileMetadata.overwrite || fileMetadata.resultStatus == ResultStatus.NOT_FOUND_FILE.ToString()) { // Upload the file - client.UploadFile(fileMetadata, fileBytesStream, encryptionMetadata); + client.UploadFile(fileMetadata, fileBytesStreamPair.MainStream, encryptionMetadata); } if (fileMetadata.resultStatus == ResultStatus.UPLOADED.ToString() || @@ -149,7 +149,7 @@ internal static void UploadOneFile(SFFileMetadata fileMetadata) internal static async Task UploadOneFileAsync(SFFileMetadata fileMetadata, CancellationToken cancellationToken) { SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); - using (var fileBytesStream = GetFileBytesStream(fileMetadata, encryptionMetadata)) + using (var fileBytesStreamPair = GetFileBytesStream(fileMetadata, encryptionMetadata)) { int maxConcurrency = fileMetadata.parallel; @@ -180,7 +180,7 @@ internal static async Task UploadOneFileAsync(SFFileMetadata fileMetadata, Cance { // Upload the file await client - .UploadFileAsync(fileMetadata, fileBytesStream, encryptionMetadata, cancellationToken) + .UploadFileAsync(fileMetadata, fileBytesStreamPair.MainStream, encryptionMetadata, cancellationToken) .ConfigureAwait(false); } @@ -516,8 +516,8 @@ private static void HandleDownloadFileErr(ref SFFileMetadata fileMetadata, ref i System.Threading.Thread.Sleep(sleepingTime); } } - - private static Stream GetFileBytesStream(SFFileMetadata fileMetadata, SFEncryptionMetadata encryptionMetadata) + + private static StreamPair GetFileBytesStream(SFFileMetadata fileMetadata, SFEncryptionMetadata encryptionMetadata) { // If encryption enabled, encrypt the file to be uploaded if (fileMetadata.encryptionMaterial != null) @@ -543,11 +543,11 @@ private static Stream GetFileBytesStream(SFFileMetadata fileMetadata, SFEncrypti { if (fileMetadata.memoryStream != null) { - return fileMetadata.memoryStream; + return new StreamPair { MainStream = fileMetadata.memoryStream }; } else { - return File.OpenRead(fileMetadata.realSrcFilePath); + return new StreamPair { MainStream = File.OpenRead(fileMetadata.realSrcFilePath) }; } } } diff --git a/Snowflake.Data/Core/FileTransfer/StreamPair.cs b/Snowflake.Data/Core/FileTransfer/StreamPair.cs new file mode 100644 index 000000000..fe628a3a7 --- /dev/null +++ b/Snowflake.Data/Core/FileTransfer/StreamPair.cs @@ -0,0 +1,17 @@ +using System; +using System.IO; + +namespace Snowflake.Data.Core.FileTransfer +{ + internal class StreamPair: IDisposable + { + public Stream MainStream { get; set; } + public Stream HelperStream { get; set; } + + public void Dispose() + { + MainStream?.Dispose(); + HelperStream?.Dispose(); + } + } +} From a8f9d7e2a4b51ffa48b833bed89323630f8c5c92 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Mon, 2 Sep 2024 08:25:42 +0000 Subject: [PATCH 4/6] test from StreamPair class --- .../UnitTests/StreamPairTest.cs | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 Snowflake.Data.Tests/UnitTests/StreamPairTest.cs diff --git a/Snowflake.Data.Tests/UnitTests/StreamPairTest.cs b/Snowflake.Data.Tests/UnitTests/StreamPairTest.cs new file mode 100644 index 000000000..2476a5a90 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/StreamPairTest.cs @@ -0,0 +1,39 @@ +using System.IO; +using NUnit.Framework; +using Moq; +using Snowflake.Data.Core.FileTransfer; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + public class StreamPairTest + { + [Test] + public void TestCloseBothStreams() + { + // arrange + var mockedMainStream = new Mock(); + var mockedHelperStream = new Mock(); + + // act + using (new StreamPair { MainStream = mockedMainStream.Object, HelperStream = mockedHelperStream.Object }) {} + + // assert + mockedMainStream.Verify(stream => stream.Close()); + mockedHelperStream.Verify(stream => stream.Close()); + } + + [Test] + public void TestCloseMainStreamOnlyWhenHelperStreamNotGiven() + { + // arrange + var mockedMainStream = new Mock(); + + // act + using (new StreamPair { MainStream = mockedMainStream.Object }) {} + + // assert + mockedMainStream.Verify(stream => stream.Close()); + } + } +} From 27d63df9f082c0292d7ca1ce62ec4ab61c73baab Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Mon, 2 Sep 2024 12:48:46 +0000 Subject: [PATCH 5/6] fix review comments --- .../Core/FileTransfer/EncryptionProvider.cs | 4 +- .../StorageClient/SFLocalStorageUtil.cs | 2 +- Snowflake.Data/Core/SFBindUploader.cs | 47 ++++++++----------- 3 files changed, 23 insertions(+), 30 deletions(-) diff --git a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs index 5135f42f7..411a6eeab 100644 --- a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs +++ b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs @@ -122,7 +122,7 @@ private static byte[] encryptFileKey(byte[] masterKey, byte[] unencryptedFileKey aes.Mode = CipherMode.ECB; aes.Padding = PaddingMode.PKCS7; - using (MemoryStream cipherStream = new MemoryStream()) + MemoryStream cipherStream = new MemoryStream(); using (var encryptor = aes.CreateEncryptor()) using (CryptoStream cryptoStream = new CryptoStream(cipherStream, encryptor, CryptoStreamMode.Write)) { @@ -238,7 +238,7 @@ private static byte[] decryptFileKey(byte[] masterKey, byte[] unencryptedFileKey aes.Mode = CipherMode.ECB; aes.Padding = PaddingMode.PKCS7; - using (MemoryStream cipherStream = new MemoryStream()) + MemoryStream cipherStream = new MemoryStream(); using (var decryptor = aes.CreateDecryptor()) using (CryptoStream cryptoStream = new CryptoStream(cipherStream, decryptor, CryptoStreamMode.Write)) { diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs index b81581fca..6b98f9fb1 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFLocalStorageUtil.cs @@ -62,7 +62,7 @@ internal static void DownloadOneFile(SFFileMetadata fileMetadata) } // Create stream object for reader and writer - using (Stream stream = new MemoryStream(File.ReadAllBytes(realSrcFilePath))) + Stream stream = new MemoryStream(File.ReadAllBytes(realSrcFilePath)); using (var fileStream = File.Create(output)) { // Write file diff --git a/Snowflake.Data/Core/SFBindUploader.cs b/Snowflake.Data/Core/SFBindUploader.cs index d62a80e75..6268c724c 100644 --- a/Snowflake.Data/Core/SFBindUploader.cs +++ b/Snowflake.Data/Core/SFBindUploader.cs @@ -75,36 +75,29 @@ public void Upload(Dictionary bindings) StringBuilder sBuffer = new StringBuilder(); MemoryStream ms = new MemoryStream(); - try + using (StreamWriter tw = new StreamWriter(ms)) { - using (StreamWriter tw = new StreamWriter(ms)) + + for (int i = startIndex; i < rowNum; i++) { + sBuffer.Append(dataRows[i]); + } - for (int i = startIndex; i < rowNum; i++) - { - sBuffer.Append(dataRows[i]); - } - - tw.Write(sBuffer.ToString()); - tw.Flush(); - - try - { - string fileName = (++fileCount).ToString(); - UploadStream(ref ms, fileName); - startIndex = rowNum; - curBytes = 0; - } - catch (IOException e) - { - // failure using stream put - throw new Exception("file stream upload error." + e.ToString()); - } + tw.Write(sBuffer.ToString()); + tw.Flush(); + + try + { + string fileName = (++fileCount).ToString(); + UploadStream(ref ms, fileName); + startIndex = rowNum; + curBytes = 0; + } + catch (IOException e) + { + // failure using stream put + throw new Exception("file stream upload error." + e.ToString()); } - } - finally - { - ms.Dispose(); } } } @@ -132,7 +125,7 @@ internal async Task UploadAsync(Dictionary bindings, Cancell } StringBuilder sBuffer = new StringBuilder(); - using (MemoryStream ms = new MemoryStream()) + MemoryStream ms = new MemoryStream(); using (StreamWriter tw = new StreamWriter(ms)) { for (int i = startIndex; i < rowNum; i++) From 4e35ac97ad7a49be431828c0c8c31eb9e72588f6 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Mon, 2 Sep 2024 14:05:44 +0000 Subject: [PATCH 6/6] Add comment to StreamPair class --- Snowflake.Data/Core/FileTransfer/StreamPair.cs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Snowflake.Data/Core/FileTransfer/StreamPair.cs b/Snowflake.Data/Core/FileTransfer/StreamPair.cs index fe628a3a7..73e44fe12 100644 --- a/Snowflake.Data/Core/FileTransfer/StreamPair.cs +++ b/Snowflake.Data/Core/FileTransfer/StreamPair.cs @@ -3,6 +3,13 @@ namespace Snowflake.Data.Core.FileTransfer { + /* + * StreamPair class has been introduced to solve the issue for a stream which is meant to be returned from a method, + * but another helper stream is created in this method and is tightly coupled with the main stream, + * so the helper stream cannot be closed in this method because it would close the main stream as well + * (if CryptoStream in EncryptionProvider class would be disposed it would close the base stream as well). + * The solution is to return both streams and dispose both of them together when processing of the main stream is over. + */ internal class StreamPair: IDisposable { public Stream MainStream { get; set; }