Skip to content

Commit

Permalink
SNOW-935778: Add support for unicode characters in file path in PUT/G…
Browse files Browse the repository at this point in the history
…ET command
  • Loading branch information
Harry Xi authored Oct 26, 2023
2 parents 5f5cf96 + 64c8069 commit 2c9f77c
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 28 deletions.
30 changes: 17 additions & 13 deletions cpp/FileMetadataInitializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

Snowflake::Client::FileMetadataInitializer::FileMetadataInitializer(
std::vector<FileMetadata> &smallFileMetadata,
std::vector<FileMetadata> &largeFileMetadata) :
std::vector<FileMetadata> &largeFileMetadata,
IStatementPutGet *stmtPutGet) :
m_smallFileMetadata(smallFileMetadata),
m_largeFileMetadata(largeFileMetadata),
m_autoCompress(true)
m_autoCompress(true),
m_stmtPutGet(stmtPutGet)
{
}

Expand All @@ -39,9 +41,9 @@ Snowflake::Client::FileMetadataInitializer::initUploadFileMetadata(const std::st
fileNameFull += fileName;

FileMetadata fileMetadata;
fileMetadata.srcFileName = fileNameFull;
fileMetadata.srcFileName = m_stmtPutGet->platformStringToUTF8(fileNameFull);
fileMetadata.srcFileSize = fileSize;
fileMetadata.destFileName = std::string(fileName);
fileMetadata.destFileName = m_stmtPutGet->platformStringToUTF8(std::string(fileName));
// process compression type
initCompressionMetadata(fileMetadata);

Expand All @@ -56,9 +58,11 @@ void Snowflake::Client::FileMetadataInitializer::populateSrcLocUploadMetadata(st
size_t putThreshold)
{
// looking for files on disk.
std::string srcLocationPlatform = m_stmtPutGet->UTF8ToPlatformString(sourceLocation);

#ifdef _WIN32
WIN32_FIND_DATA fdd;
HANDLE hFind = FindFirstFile(sourceLocation.c_str(), &fdd);
HANDLE hFind = FindFirstFile(srcLocationPlatform.c_str(), &fdd);
if (hFind == INVALID_HANDLE_VALUE)
{
DWORD dwError = GetLastError();
Expand All @@ -73,22 +77,22 @@ void Snowflake::Client::FileMetadataInitializer::populateSrcLocUploadMetadata(st
{
CXX_LOG_ERROR("Failed on FindFirstFile. Error: %d", dwError);
throw SnowflakeTransferException(TransferError::DIR_OPEN_ERROR,
sourceLocation.c_str(), dwError);
srcLocationPlatform.c_str(), dwError);
}
}

do {
if (!(fdd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) )
{
std::string fileFullPath = std::string(fdd.cFileName);
size_t dirSep = sourceLocation.find_last_of(PATH_SEP);
size_t dirSep = srcLocationPlatform.find_last_of(PATH_SEP);
if (dirSep == std::string::npos)
{
dirSep = sourceLocation.find_last_of(ALTER_PATH_SEP);
}
if (dirSep != std::string::npos)
{
std::string dirPath = sourceLocation.substr(0, dirSep + 1);
std::string dirPath = srcLocationPlatform.substr(0, dirSep + 1);
LARGE_INTEGER fileSize;
fileSize.LowPart = fdd.nFileSizeLow;
fileSize.HighPart = fdd.nFileSizeHigh;
Expand All @@ -102,14 +106,14 @@ void Snowflake::Client::FileMetadataInitializer::populateSrcLocUploadMetadata(st
{
CXX_LOG_ERROR("Failed on FindNextFile. Error: %d", dwError);
throw SnowflakeTransferException(TransferError::DIR_OPEN_ERROR,
sourceLocation.c_str(), dwError);
srcLocationPlatform.c_str(), dwError);
}
FindClose(hFind);

#else
unsigned long dirSep = sourceLocation.find_last_of(PATH_SEP);
std::string dirPath = sourceLocation.substr(0, dirSep + 1);
std::string filePattern = sourceLocation.substr(dirSep + 1);
unsigned long dirSep = srcLocationPlatform.find_last_of(PATH_SEP);
std::string dirPath = srcLocationPlatform.substr(0, dirSep + 1);
std::string filePattern = srcLocationPlatform.substr(dirSep + 1);

DIR * dir = nullptr;
struct dirent * dir_entry;
Expand All @@ -133,7 +137,7 @@ void Snowflake::Client::FileMetadataInitializer::populateSrcLocUploadMetadata(st
{
CXX_LOG_ERROR("Cannot read path struct");
throw SnowflakeTransferException(TransferError::DIR_OPEN_ERROR,
sourceLocation.c_str(), ret);
srcLocationPlatform.c_str(), ret);
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion cpp/FileMetadataInitializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>
#include "FileMetadata.hpp"
#include "IStorageClient.hpp"
#include "snowflake/IStatementPutGet.hpp"

// used to decide whether to upload in sequence or in parallel
#define DEFAULT_UPLOAD_DATA_SIZE_THRESHOLD 209715200 //200Mb
Expand All @@ -25,7 +26,8 @@ class FileMetadataInitializer
{
public:
FileMetadataInitializer(std::vector<FileMetadata> &smallFileMetadata,
std::vector<FileMetadata> &largeFileMetadata);
std::vector<FileMetadata> &largeFileMetadata,
IStatementPutGet *stmtPutGet);

/**
* Given a source locations, find all files that match the location pattern,
Expand Down Expand Up @@ -101,6 +103,9 @@ class FileMetadataInitializer

/// Random device for crypto random num generator.
Crypto::CryptoRandomDevice m_randDevice;

// statement which provides encoding conversion functionality
IStatementPutGet *m_stmtPutGet;
};
}
}
Expand Down
16 changes: 9 additions & 7 deletions cpp/FileTransferAgent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Snowflake::Client::FileTransferAgent::FileTransferAgent(
IStatementPutGet *statement,
TransferConfig *transferConfig) :
m_stmtPutGet(statement),
m_FileMetadataInitializer(m_smallFilesMeta, m_largeFilesMeta),
m_FileMetadataInitializer(m_smallFilesMeta, m_largeFilesMeta, statement),
m_executionResults(nullptr),
m_storageClient(nullptr),
m_lastRefreshTokenSec(0),
Expand Down Expand Up @@ -488,7 +488,7 @@ RemoteStorageRequestOutcome Snowflake::Client::FileTransferAgent::uploadSingleFi
srcFileStream = m_uploadStream;
} else {
try {
fs = ::std::fstream(fileMetadata->srcFileToUpload.c_str(),
fs = ::std::fstream(m_stmtPutGet->UTF8ToPlatformString(fileMetadata->srcFileToUpload).c_str(),
::std::ios_base::in |
::std::ios_base::binary);
}
Expand Down Expand Up @@ -613,20 +613,21 @@ void Snowflake::Client::FileTransferAgent::compressSourceFile(
}

std::string stagingFile(tempDir);
stagingFile += fileMetadata->destFileName;
stagingFile += m_stmtPutGet->UTF8ToPlatformString(fileMetadata->destFileName);
std::string srcFileNamePlatform = m_stmtPutGet->UTF8ToPlatformString(fileMetadata->srcFileName);

FILE *sourceFile = fopen(fileMetadata->srcFileName.c_str(), "r");
FILE *sourceFile = fopen(srcFileNamePlatform.c_str(), "r");
if( !sourceFile ){
CXX_LOG_ERROR("Failed to open srcFileName %s. Errno: %d", fileMetadata->srcFileName.c_str(), errno);
throw SnowflakeTransferException(TransferError::FILE_OPEN_ERROR, fileMetadata->srcFileName.c_str(), -1);
throw SnowflakeTransferException(TransferError::FILE_OPEN_ERROR, srcFileNamePlatform.c_str(), -1);
}
FILE *destFile = fopen(stagingFile.c_str(), "w");
if ( !destFile) {
CXX_LOG_ERROR("Failed to open srcFileToUpload file %s. Errno: %d", stagingFile.c_str(), errno);
throw SnowflakeTransferException(TransferError::FILE_OPEN_ERROR, stagingFile.c_str(), -1);
}
// set srcFileToUpload after open file successfully to prevent command injection.
fileMetadata->srcFileToUpload = stagingFile;
fileMetadata->srcFileToUpload = m_stmtPutGet->platformStringToUTF8(stagingFile);

int ret = Util::CompressionUtil::compressWithGzip(sourceFile, destFile,
fileMetadata->srcFileToUploadSize, level);
Expand Down Expand Up @@ -829,6 +830,7 @@ RemoteStorageRequestOutcome Snowflake::Client::FileTransferAgent::downloadSingle
{
fileMetadata->destPath = std::string(response.localLocation) + PATH_SEP +
fileMetadata->destFileName;
std::string destPathPlatform = m_stmtPutGet->UTF8ToPlatformString(fileMetadata->destPath);

RemoteStorageRequestOutcome outcome = RemoteStorageRequestOutcome::FAILED;
RetryContext getRetryCtx(fileMetadata->srcFileName, m_maxGetRetries);
Expand All @@ -839,7 +841,7 @@ RemoteStorageRequestOutcome Snowflake::Client::FileTransferAgent::downloadSingle

std::basic_fstream<char> dstFile;
try {
dstFile = std::basic_fstream<char>(fileMetadata->destPath.c_str(),
dstFile = std::basic_fstream<char>(destPathPlatform.c_str(),
std::ios_base::out | std::ios_base::binary);
}
catch (...) {
Expand Down
12 changes: 12 additions & 0 deletions include/snowflake/IStatementPutGet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ class IStatementPutGet
return NULL;
}

// Utility functions to convert enconding between UTF-8 to the encoding
// from system locale. No coversion by default.
virtual std::string UTF8ToPlatformString(const std::string& utf8_str)
{
return utf8_str;
}

/**
 * Convert a string in the platform encoding to UTF-8.
 *
 * This default implementation performs no conversion and returns a copy of
 * the input; the method is virtual so that a derived statement can supply a
 * real conversion.
 *
 * @param platform_str the string to convert.
 * @return the converted string (here: an unchanged copy of platform_str).
 */
virtual std::string platformStringToUTF8(const std::string& platform_str)
{
return platform_str;
}

virtual ~IStatementPutGet()
{

Expand Down
119 changes: 113 additions & 6 deletions tests/test_simple_put.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "snowflake/IStatementPutGet.hpp"
#include "StatementPutGet.hpp"
#include "FileTransferAgent.hpp"
#include "boost/filesystem.hpp"

#define COLUMN_STATUS "STATUS"
#define COLUMN_SOURCE "SOURCE"
Expand All @@ -27,6 +28,54 @@
#define MAX_BUF_SIZE 4096

using namespace ::Snowflake::Client;
using namespace boost::filesystem;

// File-name stem used by the unicode PUT/GET test, in two encodings.
// StatementPutGetUnicode swaps one for the other to emulate an encoding
// conversion between UTF-8 and the platform encoding.
#ifdef _WIN32
// Single byte 0xE9: the letter e-acute in Latin-1 / Windows-1252.
static std::string PLATFORM_STR = "\xe9";
#else
// Bytes are whatever this source file's encoding produces for the literal;
// presumably UTF-8, which would make it equal to UTF8_STR -- TODO confirm.
static std::string PLATFORM_STR = "é";
#endif
// Bytes 0xC3 0xA9: the UTF-8 encoding of e-acute (U+00E9).
static std::string UTF8_STR = "\xc3\xa9";

/**
 * Replace every non-overlapping occurrence of replaceThis inside str with
 * withThis, scanning left to right. Text inserted by a replacement is never
 * re-scanned, so a replacement that itself contains replaceThis (for example
 * "\\" -> "\\\\") terminates.
 *
 * @param str         string modified in place.
 * @param replaceThis substring to look for. An empty pattern matches nothing
 *                    and leaves str untouched.
 * @param withThis    text substituted for each occurrence.
 * @return true when at least one occurrence was replaced, false otherwise.
 */
bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ) {
    // find("") succeeds at every offset, so an empty pattern would keep
    // inserting withThis forever; treat it as "nothing to replace".
    if( replaceThis.empty() )
        return false;

    bool replaced = false;
    std::size_t pos = str.find( replaceThis );
    while( pos != std::string::npos ) {
        replaced = true;
        // Substitute in place instead of rebuilding the whole string.
        str.replace( pos, replaceThis.size(), withThis );
        // Resume right after the inserted text; find() yields npos by itself
        // once the offset reaches the end of the string.
        pos = str.find( replaceThis, pos + withThis.size() );
    }
    return replaced;
}

namespace Snowflake
{
namespace Client
{
/**
 * Test double for StatementPutGet that emulates an encoding conversion by
 * swapping the UTF8_STR and PLATFORM_STR byte sequences inside a string.
 * All other text passes through unchanged.
 */
class StatementPutGetUnicode : public Snowflake::Client::StatementPutGet
{
public:
  StatementPutGetUnicode(SF_STMT *stmt) : StatementPutGet(stmt) {}

  /// Returns a copy of utf8_str with each UTF8_STR replaced by PLATFORM_STR.
  virtual std::string UTF8ToPlatformString(const std::string& utf8_str)
  {
    return substitute(utf8_str, UTF8_STR, PLATFORM_STR);
  }

  /// Returns a copy of platform_str with each PLATFORM_STR replaced by UTF8_STR.
  virtual std::string platformStringToUTF8(const std::string& platform_str)
  {
    return substitute(platform_str, PLATFORM_STR, UTF8_STR);
  }

private:
  /// Copies input, then replaces every occurrence of from with to in the copy.
  static std::string substitute(const std::string& input,
                                const std::string& from,
                                const std::string& to)
  {
    std::string converted(input);
    replaceInPlace(converted, from, to);
    return converted;
  }
};

}
}

//File list to be made available to re-upload.
static std::vector<std::string> fileList;
Expand Down Expand Up @@ -64,11 +113,13 @@ void test_simple_put_core(const char * fileName,
bool useS3regionalUrl = false,
int compressLevel = -1,
bool overwrite = false,
SF_CONNECT * connection = nullptr)
SF_CONNECT * connection = nullptr,
bool testUnicode = false)
{
/* init */
SF_STATUS status;
SF_CONNECT *sf;

if (!connection) {
sf = setup_snowflake_connection();
status = snowflake_connect(sf);
Expand Down Expand Up @@ -104,6 +155,12 @@ void test_simple_put_core(const char * fileName,
std::string dataDir = TestSetup::getDataDir();
std::string file = dataDir + fileName;
std::string putCommand = "put file://" + file + " @%test_small_put";
if (testUnicode)
{
replaceInPlace(file, "\\", "\\\\");
putCommand = "put 'file://" + file + "' @%test_small_put";
}

if(createDupTable)
{
putCommand = "put file://" + std::string(fileName) + " @%test_small_put_dup";
Expand Down Expand Up @@ -132,8 +189,17 @@ void test_simple_put_core(const char * fileName,
{
putCommand += " overwrite=true";
}
std::unique_ptr<IStatementPutGet> stmtPutGet = std::unique_ptr
<StatementPutGet>(new Snowflake::Client::StatementPutGet(sfstmt));
std::unique_ptr<IStatementPutGet> stmtPutGet;
if (testUnicode)
{
stmtPutGet = std::unique_ptr
<StatementPutGetUnicode>(new Snowflake::Client::StatementPutGetUnicode(sfstmt));
}
else
{
stmtPutGet = std::unique_ptr
<StatementPutGet>(new Snowflake::Client::StatementPutGet(sfstmt));
}

TransferConfig transConfig;
TransferConfig * transConfigPtr = nullptr;
Expand Down Expand Up @@ -282,7 +348,7 @@ static int teardown(void **unused)
}

void test_simple_get_data(const char *getCommand, const char *size,
long getThreshold = 0)
long getThreshold = 0, bool testUnicode = false)
{
/* init */
SF_STATUS status;
Expand All @@ -296,8 +362,17 @@ void test_simple_get_data(const char *getCommand, const char *size,
/* query */
sfstmt = snowflake_stmt(sf);

std::unique_ptr<IStatementPutGet> stmtPutGet = std::unique_ptr
<StatementPutGet>(new Snowflake::Client::StatementPutGet(sfstmt));
std::unique_ptr<IStatementPutGet> stmtPutGet;
if (testUnicode)
{
stmtPutGet = std::unique_ptr
<StatementPutGetUnicode>(new Snowflake::Client::StatementPutGetUnicode(sfstmt));
}
else
{
stmtPutGet = std::unique_ptr
<StatementPutGet>(new Snowflake::Client::StatementPutGet(sfstmt));
}

TransferConfig transConfig;
TransferConfig * transConfigPtr = nullptr;
Expand Down Expand Up @@ -1502,6 +1577,37 @@ void test_upload_file_to_stage_using_stream(void **unused)
snowflake_term(sf);
}

/**
 * PUT then GET a file whose name contains a non-ASCII character.
 *
 * small_file.csv is copied inside the test data directory to a file named
 * with PLATFORM_STR; the PUT and GET commands refer to the same file with
 * UTF8_STR, and both run with testUnicode enabled so the
 * StatementPutGetUnicode conversion hooks are used.
 */
void test_put_get_with_unicode(void **unused)
{
  const std::string dataDir = TestSetup::getDataDir();

  // On-disk copy is named in the platform encoding.
  const std::string platformFileName = PLATFORM_STR + ".csv";
  copy_file(dataDir + "small_file.csv", dataDir + platformFileName,
            copy_option::overwrite_if_exists);

  // The command text names the same file in UTF-8.
  const std::string utf8FileName = UTF8_STR + ".csv";
  test_simple_put_core(
    utf8FileName.c_str(), // filename
    "auto",               // source compression
    true,                 // auto compress
    true,                 // copyUploadFile
    true,                 // verifyCopyUploadFile
    false,                // copyTableToStaging
    false,                // createDupTable
    false,                // setCustomThreshold
    64 * 1024 * 1024,     // customThreshold
    false,                // useDevUrand
    false,                // createSubfolder
    nullptr,              // tmpDir
    false,                // useS3regionalUrl
    -1,                   // compressLevel
    false,                // overwrite
    nullptr,              // connection
    true                  // testUnicode
  );

  // Fetch the uploaded (gzip-suffixed) file back into the data directory.
  std::string getCommand = "get '@%test_small_put/";
  getCommand += UTF8_STR;
  getCommand += ".csv.gz'" " file://";
  getCommand += TestSetup::getDataDir();
  test_simple_get_data(getCommand.c_str(), "48", 0, true);
}

int main(void) {

#ifdef __APPLE__
Expand Down Expand Up @@ -1563,6 +1669,7 @@ int main(void) {
cmocka_unit_test_teardown(test_simple_put_with_proxy_fromenv, teardown),
cmocka_unit_test_teardown(test_simple_put_with_noproxy_fromenv, teardown),
cmocka_unit_test_teardown(test_upload_file_to_stage_using_stream, donothing),
cmocka_unit_test_teardown(test_put_get_with_unicode, teardown),
};
int ret = cmocka_run_group_tests(tests, gr_setup, gr_teardown);
return ret;
Expand Down
Loading

0 comments on commit 2c9f77c

Please sign in to comment.