From b8664ca609c563f57625a891aecd0ceb2a9910e7 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-hx Date: Thu, 24 Oct 2024 16:30:35 -0700 Subject: [PATCH] implementation --- CMakeLists.txt | 4 + cpp/FileTransferAgent.cpp | 34 +- cpp/lib/BindUploader.cpp | 897 ++++++++++------------------- cpp/lib/ClientBindUploader.cpp | 184 ++++++ cpp/lib/ClientBindUploader.hpp | 72 +++ cpp/lib/Exceptions.cpp | 124 +++- cpp/util/SnowflakeCommon.cpp | 26 + cpp/util/SnowflakeCommon.hpp | 24 +- include/snowflake/BindUploader.hpp | 333 ++++++----- include/snowflake/Exceptions.hpp | 136 ++++- include/snowflake/client.h | 39 +- lib/client.c | 244 ++++++-- lib/client_int.h | 48 ++ tests/test_bind_params.c | 183 ++++++ 14 files changed, 1496 insertions(+), 852 deletions(-) create mode 100644 cpp/lib/ClientBindUploader.cpp create mode 100644 cpp/lib/ClientBindUploader.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 82ccd84aa5..5184c733d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -215,6 +215,7 @@ set(SOURCE_FILES_CPP_WRAPPER include/snowflake/SFURL.hpp include/snowflake/CurlDesc.hpp include/snowflake/CurlDescPool.hpp + include/snowflake/BindUploader.hpp cpp/lib/Exceptions.cpp cpp/lib/Connection.cpp cpp/lib/Statement.cpp @@ -235,6 +236,9 @@ set(SOURCE_FILES_CPP_WRAPPER cpp/lib/ResultSetJson.hpp cpp/lib/Authenticator.cpp cpp/lib/Authenticator.hpp + cpp/lib/BindUploader.cpp + cpp/lib/ClientBindUploader.hpp + cpp/lib/ClientBindUploader.cpp cpp/jwt/jwtWrapper.cpp cpp/util/SnowflakeCommon.cpp cpp/util/SFURL.cpp diff --git a/cpp/FileTransferAgent.cpp b/cpp/FileTransferAgent.cpp index cc0fbaee90..fc2ce86613 100755 --- a/cpp/FileTransferAgent.cpp +++ b/cpp/FileTransferAgent.cpp @@ -17,6 +17,7 @@ #include "crypto/Cryptor.hpp" #include "util/CompressionUtil.hpp" #include "util/ThreadPool.hpp" +#include "util/SnowflakeCommon.hpp" #include "EncryptionProvider.hpp" #include "logger/SFLogger.hpp" #include "error.h" @@ -31,35 +32,11 @@ using ::std::string; using ::std::vector; using ::Snowflake::Client::RemoteStorageRequestOutcome; +using namespace Snowflake::Client::Util; namespace { const std::string FILE_PROTOCOL = "file://"; - - void replaceStrAll(std::string& stringToReplace, - std::string const& oldValue, - std::string const& newValue) - { - size_t oldValueLen = oldValue.length(); - size_t newValueLen = newValue.length(); - if (0 == oldValueLen) - { - return; - } - - size_t index = 0; - while (true) { - /* Locate the substring to replace. */ - index = stringToReplace.find(oldValue, index); - if (index == std::string::npos) break; - - /* Make the replacement. */ - stringToReplace.replace(index, oldValueLen, newValue); - - /* Advance index forward so the next iteration doesn't pick it up as well. */ - index += newValueLen; - } - } } Snowflake::Client::FileTransferAgent::FileTransferAgent( @@ -968,6 +945,8 @@ using namespace Snowflake::Client; extern "C" { SF_STATUS STDCALL _snowflake_execute_put_get_native( SF_STMT* sfstmt, + void* upload_stream, + size_t stream_size, struct SF_QUERY_RESULT_CAPTURE* result_capture) { if (!sfstmt) @@ -996,6 +975,11 @@ extern "C" { agent.setGetMaxRetries(sfconn->get_maxretries); agent.setRandomDeviceAsUrand(sfconn->put_use_urand_dev); + if (upload_stream) + { + agent.setUploadStream((std::basic_iostream*)upload_stream, stream_size); + } + ITransferResult* result; try { diff --git a/cpp/lib/BindUploader.cpp b/cpp/lib/BindUploader.cpp index 4a8c81d56b..e990c0e749 100644 --- a/cpp/lib/BindUploader.cpp +++ b/cpp/lib/BindUploader.cpp @@ -1,24 +1,18 @@ /* - * File: BindUploader.cpp - * Author: harryx - * - * Copyright (c) 2020 Snowflake Computing - * - * Created on March 5, 2020, 3:14 PM + * Copyright (c) 2024 Snowflake Computing, Inc. All rights reserved. */ + #include #include +#include "zlib.h" + +#include "snowflake/BindUploader.hpp" +#include "../logger/SFLogger.hpp" +#include "snowflake/basic_types.h" +#include "snowflake/SF_CRTFunctionSafe.h" +#include "../util/SnowflakeCommon.hpp" -#include "BindUploader.hpp" -#include "picojson.h" -#include "Logger.hpp" -#include "NumberConverter.h" -#include "Mutex.hpp" -#include "TDWTime.h" -#include "TDWDate.h" -#include "TDWTimestamp.h" -#include "Platform/DataConversion.hpp" #ifdef _WIN32 # include # include @@ -31,688 +25,395 @@ #define WINDOW_BIT 15 #define GZIP_ENCODING 16 +using namespace Snowflake::Client::Util; + namespace { - static const simba_wstring STAGE_NAME(L"SYSTEM$BIND"); + static const std::string STAGE_NAME("SYSTEM$BIND"); - static const simba_wstring CREATE_STAGE_STMT( - L"CREATE TEMPORARY STAGE " + static const std::string CREATE_STAGE_STMT( + "CREATE TEMPORARY STAGE " + STAGE_NAME - + L" file_format=(" - + L" type=csv" - + L" field_optionally_enclosed_by='\"'" - + L")"); - - static const simba_wstring PUT_STMT( - L"PUT" - L" file://%s" // argument 1: fake file name - L" '%s'" // argument 2: stage path - L" overwrite=true" // skip file existence check - L" auto_compress=false" // we compress already - L" source_compression=gzip" // (with gzip) + + " file_format=(" + + " type=csv" + + " field_optionally_enclosed_by='\"'" + + ")"); + + static const std::string PUT_STMT( + "PUT" + " file://%s" // argument 1: fake file name + " '%s'" // argument 2: stage path + " overwrite=true" // skip file existence check + " auto_compress=false" // we compress already + " source_compression=gzip" // (with gzip) ); static const unsigned int PUT_RETRY_COUNT = 3; } -namespace sf +namespace Snowflake { - using namespace picojson; - using namespace Simba::Support; - - BindUploader::BindUploader(Connection &connection, const simba_wstring& stageDir, - unsigned int numParams, unsigned int numParamSets, - int compressLevel, bool injectError) : - m_connection(connection), - m_stagePath(L"@" + STAGE_NAME + L"/" + stageDir + L"/"), - m_fileNo(0), - m_retryCount(PUT_RETRY_COUNT), - m_maxFileSize(connection.getStageBindMaxFileSize()), - m_numParams(numParams), - m_numParamSets(numParamSets), - m_curParamIndex(0), - m_curParamSetIndex(0), - m_dataSize(0), - m_startTime(std::chrono::steady_clock::now()), - m_serializeStartTime(std::chrono::steady_clock::now()), - m_compressTime(0), - m_serializeTime(0), - m_putTime(0), - m_hasBindingUploaded(false), - m_compressLevel(compressLevel), - m_injectError(injectError) - { - SF_TRACE_LOG("sf", "BindUploader", "BindUploader", - "Constructing BindUploader%s", ""); - } +namespace Client +{ +BindUploader::BindUploader(const std::string& stageDir, + unsigned int numParams, unsigned int numParamSets, + unsigned int maxFileSize, + int compressLevel) : + m_stagePath("@" + STAGE_NAME + "/" + stageDir + "/"), + m_fileNo(0), + m_retryCount(PUT_RETRY_COUNT), + m_maxFileSize(maxFileSize), + m_numParams(numParams), + m_numParamSets(numParamSets), + m_curParamIndex(0), + m_curParamSetIndex(0), + m_dataSize(0), + m_startTime(std::chrono::steady_clock::now()), + m_serializeStartTime(std::chrono::steady_clock::now()), + m_compressTime(0), + m_serializeTime(0), + m_putTime(0), + m_hasBindingUploaded(false), + m_compressLevel(compressLevel) +{ + CXX_LOG_TRACE("Constructing BindUploader: stageDir:%s, numParams: %d, numParamSets: %d, " + "maxFileSize: %d, compressLevel: %d", + stageDir.c_str(), numParams, numParamSets, + maxFileSize, compressLevel); +} - void BindUploader::putBinds() +void BindUploader::putBinds() +{ + // count serialize time since this function is called when serialization for + // one chunk is done + m_serializeTime += std::chrono::duration_cast(std::chrono::steady_clock::now() - m_serializeStartTime).count(); + m_serializeStartTime = std::chrono::steady_clock::now(); + + createStageIfNeeded(); + auto compressStartTime = std::chrono::steady_clock::now(); + size_t compressedSize = compressWithGzip(); + m_compressTime += std::chrono::duration_cast(std::chrono::steady_clock::now() - compressStartTime).count(); + + auto putStartTime = std::chrono::steady_clock::now(); + std::string filename = std::to_string(m_fileNo++); + while (m_retryCount > 0) { - // count serialize time since this function is called when serialization for - // one chunk is done - m_serializeTime += std::chrono::duration_cast(std::chrono::steady_clock::now() - m_serializeStartTime).count(); - m_serializeStartTime = std::chrono::steady_clock::now(); - - createStageIfNeeded(); - auto compressStartTime = std::chrono::steady_clock::now(); - size_t compressedSize = compressWithGzip(); - m_compressTime += std::chrono::duration_cast(std::chrono::steady_clock::now() - compressStartTime).count(); - - auto putStartTime = std::chrono::steady_clock::now(); - std::string filename = NumberConverter::ConvertUInt32ToString(m_fileNo++); - while (m_retryCount > 0) + std::string putStmt = getPutStmt(filename); + try { - std::string putStmt = getPutStmt(filename); - try - { - sf::Statement statement(m_connection); - statement.setUploadStream(m_compressStream, compressedSize); - statement.executeTransfer(putStmt); - m_hasBindingUploaded = true; - if (m_injectError && (m_fileNo == 1)) - { - // throw error on second chunk uploading to test the logic of fallback - // to regular binding - SF_THROWGEN1_NO_INCIDENT(L"SFFileTransferError", "Error injection."); - } - break; - } - catch (...) - { - SF_TRACE_LOG("sf", "BindUploader", "putBinds", - "Failed to upload array binds, retry%s", ""); - m_retryCount--; - if (0 == m_retryCount) - { - SF_TRACE_LOG("sf", "BindUploader", "putBinds", - "Failed to upload array binds with all retry%s", ""); - throw; - } - } + executeUploading(putStmt, m_compressStream, compressedSize); + m_hasBindingUploaded = true; + break; } - m_putTime += std::chrono::duration_cast(std::chrono::steady_clock::now() - putStartTime).count(); - - m_csvStream = std::stringstream(); - m_dataSize = 0; - if (m_curParamSetIndex >= m_numParamSets) + catch (...) { - auto totalTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - m_startTime).count(); - SF_INFO_LOG("sf", "BindUploader", "addStringValue", - "total time: %ld, serialize time: %d, compress time: %ld, put time %ld", totalTime, m_serializeTime, m_compressTime, m_putTime); + CXX_LOG_WARN("BindUploader::putBinds: Failed to upload array binds, retry"); + m_retryCount--; + if (0 == m_retryCount) + { + CXX_LOG_ERROR("BindUploader::putBinds: Failed to upload array binds with all retry"); + throw; + } } } + m_putTime += std::chrono::duration_cast(std::chrono::steady_clock::now() - putStartTime).count(); - size_t BindUploader::compressWithGzip() + m_csvStream = std::stringstream(); + m_dataSize = 0; + if (m_curParamSetIndex >= m_numParamSets) { - int ret, flush; - unsigned have; - z_stream strm; - unsigned char in[CHUNK]; - unsigned char out[CHUNK]; - - m_compressStream = std::stringstream(); - m_csvStream.seekg(0); - - /* allocate deflate state */ - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; - ret = deflateInit2(&strm, m_compressLevel, Z_DEFLATED, - WINDOW_BIT | GZIP_ENCODING, 8, Z_DEFAULT_STRATEGY); - if (ret != Z_OK) - { - SF_TRACE_LOG("sf", "BindUploader", "compressWithGzip", - "Compression initial failed with error code %d", ret); - throw; - } - - /* compress until end of source data */ - do - { - m_csvStream.read((char*)in, CHUNK); - strm.next_in = in; - strm.avail_in = m_csvStream.gcount(); - flush = m_csvStream.eof() ? Z_FINISH : Z_NO_FLUSH; - - /* run deflate() on input until output buffer not full, finish - compression if all of source has been read in */ - do - { - strm.avail_out = CHUNK; - strm.next_out = out; - ret = deflate(&strm, flush); /* no bad return value */ - assert(ret != Z_STREAM_ERROR); /* state not clobbered */ - have = CHUNK - strm.avail_out; - m_compressStream.write((char*)out, have); - } while (strm.avail_out == 0); - assert(strm.avail_in == 0); /* all input will be used */ - - /* done when last data in file processed */ - } while (flush != Z_FINISH); - assert(ret == Z_STREAM_END); /* stream will be complete */ - - size_t destSize = strm.total_out; - - /* clean up and return */ - (void)deflateEnd(&strm); - return destSize; + auto totalTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - m_startTime).count(); + CXX_LOG_INFO("BindUploader::putBinds: total time: %ld, serialize time: %d, compress time: %ld, put time %ld", + totalTime, m_serializeTime, m_compressTime, m_putTime); } +} - void BindUploader::createStageIfNeeded() +size_t BindUploader::compressWithGzip() +{ + int ret, flush; + unsigned have; + z_stream strm; + unsigned char in[CHUNK]; + unsigned char out[CHUNK]; + + m_compressStream = std::stringstream(); + m_csvStream.seekg(0); + + /* allocate deflate state */ + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + ret = deflateInit2(&strm, m_compressLevel, Z_DEFLATED, + WINDOW_BIT | GZIP_ENCODING, 8, Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { - // Check the flag without locking to get better performance. - if (m_connection.isArrayBindStageCreated()) - { - return; - } - - MutexGuard guard(m_connection.getArrayBindingMutex()); + CXX_LOG_TRACE("BindUploader: Compression initial failed with error code %d", ret); + throw; + } - // another thread may have created the session by the time we enter this block - // so check the flag again. - if (m_connection.isArrayBindStageCreated()) - { - return; - } + /* compress until end of source data */ + do + { + m_csvStream.read((char*)in, CHUNK); + strm.next_in = in; + strm.avail_in = m_csvStream.gcount(); + flush = m_csvStream.eof() ? Z_FINISH : Z_NO_FLUSH; - try - { - sf::Statement statement(m_connection); - statement.executeQuery(CREATE_STAGE_STMT.GetAsUTF8(), false, true); - m_connection.setArrayBindStageCreated(); - } - catch (...) + /* run deflate() on input until output buffer not full, finish + compression if all of source has been read in */ + do { - SF_TRACE_LOG("sf", "BindUploader", "createStageIfNeeded", - "Failed to create temporary stage for array binds.%s", ""); - throw; - } - } + strm.avail_out = CHUNK; + strm.next_out = out; + ret = deflate(&strm, flush); /* no bad return value */ + have = CHUNK - strm.avail_out; + m_compressStream.write((char*)out, have); + } while (strm.avail_out == 0); - std::string BindUploader::getPutStmt(const std::string& srcFilePath) - { - char strBuf[MAX_PATH * 2]; // *2 to make sure there is enough space - simba_sprintf(strBuf, sizeof(strBuf), PUT_STMT.GetAsUTF8().c_str(), - srcFilePath.c_str(), getStagePath().c_str()); + /* done when last data in file processed */ + } while (flush != Z_FINISH); - return std::string(strBuf); - } + size_t destSize = strm.total_out; - std::string BindUploader::convertTimeFormat(const std::string& timeInNano) - { - simba_uint32 seconds; - simba_uint32 fraction; - int len = timeInNano.length(); - if (len < 10) - { - seconds = 0; - fraction = NumberConverter::ConvertStringToUInt32(timeInNano); - } - else - { - seconds = NumberConverter::ConvertStringToUInt32(timeInNano.substr(0, len - 9)); - fraction = NumberConverter::ConvertStringToUInt32(timeInNano.substr(len - 9)); - } + /* clean up and return */ + (void)deflateEnd(&strm); + return destSize; +} - simba_uint16 hour, min, sec; - hour = seconds / 3600; - seconds = seconds % 3600; - min = seconds / 60; - sec = seconds % 60; - TDWTime time(hour, min, sec, fraction); +std::string BindUploader::getPutStmt(const std::string& srcFilePath) +{ + char strBuf[MAX_PATH * 2]; // *2 to make sure there is enough space + sf_sprintf(strBuf, sizeof(strBuf), PUT_STMT.c_str(), + srcFilePath.c_str(), getStagePath().c_str()); - return time.ToString(9); - } + return std::string(strBuf); +} - std::string BindUploader::revertTimeFormat(const std::string& formatedTime) - { - TDWTime time(formatedTime); - std::string seconds = std::to_string(time.Hour * 3600 + time.Minute * 60 + time.Second); - std::string fraction = std::to_string(time.Fraction); - if (fraction.length() < 9) - { - fraction = std::string(9 - fraction.length(), '0') + fraction; - } - return seconds + fraction; - } +std::string BindUploader::getCreateStageStmt() +{ + return CREATE_STAGE_STMT; +} - std::string BindUploader::convertDateFormat(const std::string& millisecondSinceEpoch) +void BindUploader::addStringValue(const std::string& val, SF_DB_TYPE type) +{ + if (m_curParamIndex != 0) { - simba_int64 SecondsSinceEpoch = - NumberConverter::ConvertStringToInt64(millisecondSinceEpoch) / 1000; - TDWDate date = sf::DataConversions::parseSnowflakeDate(SecondsSinceEpoch); - return date.ToString(); + m_csvStream << ","; + m_dataSize++; } - - std::string BindUploader::revertDateFormat(const std::string& formatedDate) + else if (m_dataSize == 0) { - TDWDate date(formatedDate); - struct tm datetm; - datetm.tm_year = date.Year -1900; - datetm.tm_mon = date.Month - 1; - datetm.tm_mday = date.Day; - datetm.tm_hour = 0; - datetm.tm_min = 0; - datetm.tm_sec = 0; - - simba_int64 secondsSinceEpoch = (simba_int64)sf::DataConversions::sfchrono_timegm(&datetm); - return std::to_string(secondsSinceEpoch * 1000); + m_serializeStartTime = std::chrono::steady_clock::now(); } - std::string BindUploader::convertTimestampFormat(const std::string& timestampInNano, - simba_int16 type) + if (val.empty()) { - TDWExactNumericType totalFracSeconds(timestampInNano.c_str(), - timestampInNano.length(), - true); - totalFracSeconds.MultiplyByTenToThePowerOf(-9); - - sb8 seconds = totalFracSeconds.GetInt64(); - bool isTruncated; - bool isOutofRange; - simba_uint32 fraction = totalFracSeconds.GetFraction(isTruncated, isOutofRange, 9); - if (!totalFracSeconds.IsPositive() && (fraction != 0)) - { - seconds--; - fraction = 1000000000 - fraction; - } - - TDWTimestamp timestamp; - LogicalType_t ltype; - if (type == SQL_SF_TIMESTAMP_NTZ) - { - ltype = LTY_TIMESTAMP_NTZ; - } - else - { - ltype = LTY_TIMESTAMP_LTZ; - } - - timestamp = sf::DataConversions::parseSnowflakeTimestamp( - seconds, - fraction, - ltype, - 9, - true, - true); - - // Get the local time offset - tm tmV; - time_t timeV = (time_t)seconds; - int offset = 0; - sf::DataConversions::sfchrono_localtime(&timeV, &tmV); -#if defined(WIN32) || defined(_WIN64) - sb8 localEpoch = (sf::sb8)(sf::DataConversions::sfchrono_timegm(&tmV)); - offset = (int)(localEpoch - (sf::sb8)seconds); -#else - offset = tmV.tm_gmtoff; -#endif - int tzh = offset / 3600; - int tzm = (offset - (tzh * 3600)) / 60; - std::ostringstream stz; - stz << ((offset < 0) ? "-" : "+") - << std::setfill('0') << std::setw(2) << abs(tzh) - << ":" << std::setfill('0') << std::setw(2) << abs(tzm); - - return timestamp.ToString() + " " + stz.str(); + m_csvStream << "\"\""; // an empty string => an empty string with quotes + m_dataSize += sizeof("\"\""); } - - std::string BindUploader::revertTimestampFormat(const std::string& formatedtTimestamp, simba_int16 type) + else { - // separate timestamp and timezone information - // this is reverting the output from convertTimestampFormat so we should - // always have timezone part lead with a space - size_t timezonePos = formatedtTimestamp.rfind(' '); - if (timezonePos == std::string::npos) + switch (type) { - // not possible but just in case - return ""; - } - - simba_wstring timestampStr = formatedtTimestamp.substr(0, timezonePos); - timestampStr.Trim(); - TDWTimestamp timestamp(timestampStr); - struct tm gmttm; - gmttm.tm_year = timestamp.Year - 1900; - gmttm.tm_mon = timestamp.Month - 1; - gmttm.tm_mday = timestamp.Day; - gmttm.tm_hour = timestamp.Hour; - gmttm.tm_min = timestamp.Minute; - gmttm.tm_sec = timestamp.Second; - simba_int64 secondsSinceEpoch = (simba_int64)sf::DataConversions::sfchrono_timegm(&gmttm); - - // For local timezone add timezone information to get gmt time. - if (type != SQL_SF_TIMESTAMP_NTZ) - { - simba_wstring timezoneStr = formatedtTimestamp.substr(timezonePos); - timezoneStr.Trim(); - bool isTimezoneSigned = false; - if ((timezoneStr.GetLength() > 0) && - ((timezoneStr.CharAt(0) == '+') || (timezoneStr.CharAt(0) == '-'))) + case SF_DB_TYPE_TIME: { - if (timezoneStr.CharAt(0) == '-') - { - isTimezoneSigned = true; - } - timezoneStr = timezoneStr.Substr(1) + ":00"; + std::string timeStr = convertTimeFormat(val); + m_csvStream << timeStr; + m_dataSize += timeStr.length(); + break; } - TDWTime timezone(timezoneStr); - int timezoneSeconds = timezone.Hour * 3600 + timezone.Minute * 60; - if (isTimezoneSigned) + + case SF_DB_TYPE_DATE: { - timezoneSeconds *= -1; + std::string dateStr = convertDateFormat(val); + m_csvStream << dateStr; + m_dataSize += dateStr.length(); + break; } - secondsSinceEpoch -= timezoneSeconds; - } - - // If the seconds is negative, convert fraction to negative value as well. - simba_uint32 fraction = timestamp.Fraction; - if ((secondsSinceEpoch < 0) && (fraction > 0)) - { - fraction = 1000000000 - fraction; - secondsSinceEpoch++; - } - - // return string in nano seconds combining second and fraction parts - std::string fractionStr = std::to_string(fraction); - if (fractionStr.length() < 9) - { - fractionStr = std::string(9 - fractionStr.length(), '0') + fractionStr; - } - - return std::to_string(secondsSinceEpoch) + fractionStr; - } - void BindUploader::addStringValue(const std::string& val, simba_int16 type) - { - if (m_curParamIndex != 0) - { - m_csvStream << ","; - m_dataSize++; - } - else if (m_dataSize == 0) - { - m_serializeStartTime = std::chrono::steady_clock::now(); - } - - if (val.empty()) - { - m_csvStream << "\"\""; // an empty string => an empty string with quotes - m_dataSize += sizeof("\"\""); - } - else - { - switch (type) + case SF_DB_TYPE_TIMESTAMP_LTZ: + case SF_DB_TYPE_TIMESTAMP_NTZ: + case SF_DB_TYPE_TIMESTAMP_TZ: { - case SQL_TYPE_TIME: - case SQL_TIME: - { - std::string timeStr = convertTimeFormat(val); - m_csvStream << timeStr; - m_dataSize += timeStr.length(); - break; - } - - case SQL_TYPE_DATE: - case SQL_DATE: - { - std::string dateStr = convertDateFormat(val); - m_csvStream << dateStr; - m_dataSize += dateStr.length(); - break; - } + std::string timestampStr = convertTimestampFormat(val, type); + m_csvStream << timestampStr; + m_dataSize += timestampStr.length(); + break; + } - case SQL_TYPE_TIMESTAMP: - case SQL_TIMESTAMP: - case SQL_SF_TIMESTAMP_LTZ: - case SQL_SF_TIMESTAMP_NTZ: + default: + { + if (std::string::npos == val.find_first_of("\"\n,\\")) { - std::string timestampStr = convertTimestampFormat(val, type); - m_csvStream << timestampStr; - m_dataSize += timestampStr.length(); - break; + m_csvStream << val; + m_dataSize += val.length(); } - - default: + else { - if (std::string::npos == val.find_first_of("\"\n,\\")) - { - m_csvStream << val; - m_dataSize += val.length(); - } - else - { - simba_wstring escapeSimbaStr(val); - escapeSimbaStr.Replace("\"", "\"\""); - escapeSimbaStr = "\"" + escapeSimbaStr + "\""; - std::string escapeStr = escapeSimbaStr.GetAsUTF8(); + std::string escapeStr(val); + replaceStrAll(escapeStr, "\"", "\"\""); + escapeStr = "\"" + escapeStr + "\""; - m_csvStream << escapeStr; - m_dataSize += escapeStr.length(); - } - break; + m_csvStream << escapeStr; + m_dataSize += escapeStr.length(); } + break; } } + } - // The last column in the current row, add new line - // Also upload the data as needed. - if (++m_curParamIndex >= m_numParams) + // The last column in the current row, add new line + // Also upload the data as needed. + if (++m_curParamIndex >= m_numParams) + { + m_csvStream << "\n"; + m_dataSize++; + m_curParamIndex = 0; + m_curParamSetIndex++; + + // Upload data when exceed file size limit or all rows are added + if ((m_dataSize >= m_maxFileSize) || + (m_curParamSetIndex >= m_numParamSets)) { - m_csvStream << "\n"; - m_dataSize++; - m_curParamIndex = 0; - m_curParamSetIndex++; - - // Upload data when exceed file size limit or all rows are added - if ((m_dataSize >= m_maxFileSize) || - (m_curParamSetIndex >= m_numParamSets)) - { - putBinds(); - } + putBinds(); } } +} - void BindUploader::addNullValue() +void BindUploader::addNullValue() +{ + if (m_curParamIndex != 0) { - if (m_curParamIndex != 0) - { - m_csvStream << ","; - m_dataSize++; - } + m_csvStream << ","; + m_dataSize++; + } - // The last column in the current row, add new line - // Also upload the data as needed. - if (++m_curParamIndex >= m_numParams) + // The last column in the current row, add new line + // Also upload the data as needed. + if (++m_curParamIndex >= m_numParams) + { + m_csvStream << "\n"; + m_dataSize++; + m_curParamIndex = 0; + m_curParamSetIndex++; + + // Upload data when exceed file size limit or all rows are added + if ((m_dataSize >= m_maxFileSize) || + (m_curParamSetIndex >= m_numParamSets)) { - m_csvStream << "\n"; - m_dataSize++; - m_curParamIndex = 0; - m_curParamSetIndex++; - - // Upload data when exceed file size limit or all rows are added - if ((m_dataSize >= m_maxFileSize) || - (m_curParamSetIndex >= m_numParamSets)) - { - putBinds(); - } + putBinds(); } } +} - bool BindUploader::csvGetNextField(std::string& fieldValue, - bool& isNull, bool& isEndOfRow) - { - char c; +bool BindUploader::csvGetNextField(std::string& fieldValue, + bool& isNull, bool& isEndOfRow) +{ + char c; - // the flag indecate if currently in a quoted value - bool inQuote = false; - // the flag indecate if the value has been quoted, quoted empty string is - // empty value (like ,"",) while unquoted empty string is null (like ,,) - bool quoted = false; - // the flag indecate a value is found to end the loop - bool done = false; - // the flag indicate the next char already fetched by checking double quote escape ("") - bool nextCharFetched = false; + // the flag indecate if currently in a quoted value + bool inQuote = false; + // the flag indecate if the value has been quoted, quoted empty string is + // empty value (like ,"",) while unquoted empty string is null (like ,,) + bool quoted = false; + // the flag indecate a value is found to end the loop + bool done = false; + // the flag indicate the next char already fetched by checking double quote escape ("") + bool nextCharFetched = false; - fieldValue.clear(); + fieldValue.clear(); - if (!m_csvStream.get(c)) - { - return false; - } + if (!m_csvStream.get(c)) + { + return false; + } - while (!done) + while (!done) + { + switch (c) { - switch (c) + case ',': { - case ',': + if (!inQuote) { - if (!inQuote) - { - done = true; - } - else - { - fieldValue.push_back(c); - } - break; + done = true; + } + else + { + fieldValue.push_back(c); } + break; + } - case '\n': + case '\n': + { + if (!inQuote) { - if (!inQuote) - { - done = true; - isEndOfRow = true; - } - else - { - fieldValue.push_back(c); - } - break; + done = true; + isEndOfRow = true; } + else + { + fieldValue.push_back(c); + } + break; + } - case '\"': + case '\"': + { + if (!inQuote) + { + quoted = true; + inQuote = true; + } + else { - if (!inQuote) + if (!m_csvStream.get(c)) { - quoted = true; - inQuote = true; + isEndOfRow = true; + done = true; } else { - if (!m_csvStream.get(c)) + if (c == '\"') { - isEndOfRow = true; - done = true; + // escape double qoute in quoted string + fieldValue.push_back(c); } else { - if (c == '\"') - { - // escape double qoute in quoted string - fieldValue.push_back(c); - } - else - { - inQuote = false; - nextCharFetched = true; - } + inQuote = false; + nextCharFetched = true; } } - - break; } - default: - { - fieldValue.push_back(c); - } + break; } - if ((!done) && (!nextCharFetched)) + default: { - if (!m_csvStream.get(c)) - { - isEndOfRow = true; - break; - } - } - else - { - nextCharFetched = false; + fieldValue.push_back(c); } } - isNull = (fieldValue.empty() && !quoted); - return true; - } - - void BindUploader::convertBindingFromCsvToJson( - std::vector& paramBindOrder, jsonObject_t& parameterBinds) - { - bool endOfData = false; - m_csvStream.seekg(0); - - while (!endOfData) + if ((!done) && (!nextCharFetched)) { - bool endOfRow = false; - for (size_t i = 0; i < paramBindOrder.size(); i++) + if (!m_csvStream.get(c)) { - std::string fieldValue; - bool isNull = false; - std::string bindName = paramBindOrder[i]; - jsonObject_t& bind = parameterBinds[bindName].get(); - jsonArray_t& valueList = bind["value"].get(); - const std::string& colType = bind["type"].get(); - - // Should not happen as we are parsing the result generated by - // addStringValue()/addNullValue(), fill missing fields with null - // just in case - if (endOfData || endOfRow) - { - valueList.push_back(value()); - continue; - } - - if (!csvGetNextField(fieldValue, isNull, endOfRow)) - { - endOfData = true; - if (i == 0) - { - // Normal case of reaching end of data - break; - } - // Should not happen but just fill the missing data with null - valueList.push_back(value()); - continue; - } - - if (isNull) - { - valueList.push_back(value()); - continue; - } - - if (strcasecmp(colType.c_str(), "TIME") == 0) - { - fieldValue = revertTimeFormat(fieldValue); - } - else if (strcasecmp(colType.c_str(), "DATE") == 0) - { - fieldValue = revertDateFormat(fieldValue); - } - else if (strcasecmp(colType.c_str(), "TIMESTAMP") == 0) - { - fieldValue = revertTimestampFormat(fieldValue, SQL_SF_TIMESTAMP_LTZ); - } - else if (strcasecmp(colType.c_str(), "TIMESTAMP_NTZ") == 0) - { - fieldValue = revertTimestampFormat(fieldValue, SQL_SF_TIMESTAMP_NTZ); - } - - valueList.push_back(value(fieldValue)); + isEndOfRow = true; + break; } } + else + { + nextCharFetched = false; + } } -} // namespace sf + + isNull = (fieldValue.empty() && !quoted); + return true; +} + +} // namespace Client +} // namespace Snowflake diff --git a/cpp/lib/ClientBindUploader.cpp b/cpp/lib/ClientBindUploader.cpp new file mode 100644 index 0000000000..ee26cc63ab --- /dev/null +++ b/cpp/lib/ClientBindUploader.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2024 Snowflake Computing, Inc. All rights reserved. + */ + + +#include +#include + +#include "ClientBindUploader.hpp" +#include "../logger/SFLogger.hpp" +#include "snowflake/basic_types.h" +#include "snowflake/SF_CRTFunctionSafe.h" +#include "../util/SnowflakeCommon.hpp" +#include "snowflake/Exceptions.hpp" +#include "client_int.h" +#include "results.h" +#include "error.h" + +namespace Snowflake +{ +namespace Client +{ +ClientBindUploader::ClientBindUploader(SF_STMT *sfstmt, + const std::string& stageDir, + unsigned int numParams, unsigned int numParamSets, + unsigned int maxFileSize, + int compressLevel) : + BindUploader(stageDir, numParams, numParamSets, maxFileSize, compressLevel) +{ + if (!sfstmt || !sfstmt->connection) + { + SNOWFLAKE_THROW("BindUploader:: Invalid statement"); + } + SF_STATUS ret; + m_stmt = snowflake_stmt(sfstmt->connection); + if (sfstmt == NULL) { + SET_SNOWFLAKE_ERROR( + &sfstmt->error, + SF_STATUS_ERROR_OUT_OF_MEMORY, + "Out of memory in creating SF_STMT. ", + SF_SQLSTATE_UNABLE_TO_CONNECT); + + SNOWFLAKE_THROW_S(&sfstmt->error); + } +} + +ClientBindUploader::~ClientBindUploader() +{ + if (m_stmt) + { + snowflake_stmt_term(m_stmt); + } +} + +void ClientBindUploader::createStageIfNeeded() +{ + SF_CONNECT* conn = m_stmt->connection; + // Check the flag without locking to get better performance. + if (conn->binding_stage_created) + { + return; + } + + _mutex_lock(&conn->mutex_parameters); + if (conn->binding_stage_created) + { + _mutex_unlock(&conn->mutex_parameters); + return; + } + + std::string command = getCreateStageStmt(); + SF_STATUS ret = snowflake_query(m_stmt, command.c_str(), 0); + if (ret != SF_STATUS_SUCCESS) + { + _mutex_unlock(&conn->mutex_parameters); + SNOWFLAKE_THROW_S(&m_stmt->error); + } + + conn->binding_stage_created = SF_BOOLEAN_TRUE; + _mutex_unlock(&conn->mutex_parameters); +} + +void ClientBindUploader::executeUploading(const std::string &sql, + std::basic_iostream& uploadStream, + size_t dataSize) +{ + snowflake_prepare(m_stmt, sql.c_str(), 0); + SF_STATUS ret = _snowflake_execute_put_get_native(m_stmt, &uploadStream, dataSize, NULL); + if (ret != SF_STATUS_SUCCESS) + { + SNOWFLAKE_THROW_S(&m_stmt->error); + } +} + +} // namespace Client +} // namespace Snowflake + +extern "C" { + +using namespace Snowflake::Client; + +char* STDCALL +_snowflake_stage_bind_upload(SF_STMT* sfstmt) +{ + std::string bindStage; + try + { + ClientBindUploader uploader(sfstmt, sfstmt->request_id, + sfstmt->params_len, sfstmt->paramset_size, + SF_DEFAULT_STAGE_BINDING_MAX_FILESIZE, 0); + + const char* type; + char name_buf[SF_PARAM_NAME_BUF_LEN]; + char* name = NULL; + char* value = NULL; + struct bind_info { + SF_BIND_INPUT* input; + void* val_ptr; + int step; + }; + std::vector bindInfo; + for (unsigned int i = 0; i < sfstmt->params_len; i++) + { + SF_BIND_INPUT* input = _snowflake_get_binding_by_index(sfstmt, i, &name, + name_buf, SF_PARAM_NAME_BUF_LEN); + if (input == NULL) + { + log_error("_snowflake_execute_ex: No parameter by this name %s", name); + return NULL; + } + bindInfo.emplace_back(); + bindInfo.back().input = input; + bindInfo.back().val_ptr = input->value; + bindInfo.back().step = _snowflake_get_binding_value_size(input); + } + for (int64 i = 0; i < sfstmt->paramset_size; i++) + { + for (unsigned int j = 0; j < sfstmt->params_len; j++) + { + SF_BIND_INPUT* input = bindInfo[j].input; + void* val_ptr = bindInfo[j].val_ptr; + int val_len = input->len; + if (input->len_ind) + { + val_len = input->len_ind[i]; + } + + if (SF_BIND_LEN_NULL == val_len) + { + uploader.addNullValue(); + } + + if ((SF_C_TYPE_STRING == input->c_type) && + (SF_BIND_LEN_NTS == val_len)) + { + val_len = strlen((char*)val_ptr); + } + + value = value_to_string(val_ptr, val_len, input->c_type); + if (value) { + uploader.addStringValue(value, input->type); + SF_FREE(value); + } + bindInfo[j].val_ptr = (char*)bindInfo[j].val_ptr + bindInfo[j].step; + } + } + bindStage = uploader.getStagePath(); + } + catch (SnowflakeGeneralException& e) + { + return NULL; + } + + if (!bindStage.empty()) + { + char* bind_stage = (char*) SF_CALLOC(1, bindStage.size() + 1); + sf_strncpy(bind_stage, bindStage.size() + 1, bindStage.c_str(), bindStage.size()); + return bind_stage; + } + + return NULL; +} + +} // extern "C" diff --git a/cpp/lib/ClientBindUploader.hpp b/cpp/lib/ClientBindUploader.hpp new file mode 100644 index 0000000000..c81134fdfe --- /dev/null +++ b/cpp/lib/ClientBindUploader.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2024 Snowflake Computing, Inc. All rights reserved. + */ + +#pragma once +#ifndef SNOWFLAKECLIENT_CLIENTBINDUPLOADER_HPP +#define SNOWFLAKECLIENT_CLIENTBINDUPLOADER_HPP + +#include +#include "snowflake/client.h" +#include "snowflake/BindUploader.hpp" + +namespace Snowflake +{ +namespace Client +{ + +class ClientBindUploader : public BindUploader +{ +public: + /** + * constructor + * + * @param sfstmt The SNOWFLAKE_STMT context. + * @param stageDir The unique stage path for bindings uploading, could be a GUID. + * @param numParams Number of parameters. + * @param numParamSets Number of parameter sets. + * @param maxFileSize The max size of single file for bindings uploading. + * Separate into multiple files when exceed. + * @param compressLevel The compress level, between -1(default) to 9. + */ + explicit ClientBindUploader(SF_STMT *sfstmt, + const std::string& stageDir, + unsigned int numParams, + unsigned int numParamSets, + unsigned int maxFileSize, + int compressLevel); + + ~ClientBindUploader(); + +protected: + /** + * Check whether the session's temporary stage has been created, and create it + * if not. + * + * @throws Exception if creating the stage fails + */ + virtual void createStageIfNeeded() override; + + /** + * Execute uploading for single data file. + * + * @param sql PUT command for single data file uploading + * @param uploadStream stream for data file to be uploaded + * @param dataSize Size of the data to be uploaded. + * + * @throws Exception if uploading fails + */ + virtual void executeUploading(const std::string &sql, + std::basic_iostream& uploadStream, + size_t dataSize) override; + +private: + // SNOWFLAKE_STMT context + SF_STMT * m_stmt; + +}; + +} // namespace Client +} // namespace Snowflake + +#endif // SNOWFLAKECLIENT_CLIENTBINDUPLOADER_HPP diff --git a/cpp/lib/Exceptions.cpp b/cpp/lib/Exceptions.cpp index 28f834e885..dbe8fef856 100644 --- a/cpp/lib/Exceptions.cpp +++ b/cpp/lib/Exceptions.cpp @@ -1,4 +1,126 @@ /* - * Copyright (c) 2018-2019 Snowflake Computing, Inc. All rights reserved. + * Copyright (c) 2018-2024 Snowflake Computing, Inc. All rights reserved. */ +#include "snowflake/Exceptions.hpp" +#include "../logger/SFLogger.hpp" + +// helper functions +namespace +{ + std::string setupErrorMessage(const std::string& message, + const std::string& file, + int line, + const std::string& queryId, + const std::string& sqlState, + int code) + { + std::string errmsg = "Snowflake exception: "; + if (!file.empty()) + { + errmsg += file + ":" + std::to_string(line) + ", "; + } + if (!queryId.empty()) + { + errmsg += std::string("query ID: ") + queryId + ", "; + } + if (!sqlState.empty()) + { + errmsg += std::string("SQLState: ") + sqlState + ", "; + } + errmsg += std::string("error code :") + std::to_string(code) + ", "; + + errmsg += std::string("error message: ") + message; + + return errmsg; + } +} + +namespace Snowflake +{ +namespace Client +{ + +void SnowflakeException::setErrorMessage(const std::string& errmsg) +{ + m_errmsg = SFLogger::getMaskedMsg("%s", errmsg.c_str()); +} + +void SnowflakeException::setErrorMessage(const char* fmt, va_list args) +{ + m_errmsg = SFLogger::getMaskedMsgVA(fmt, args); +} + +SnowflakeGeneralException::SnowflakeGeneralException(SF_ERROR_STRUCT *error) : + m_message(error->msg ? error->msg : ""), + m_file(error->file ? error->file : ""), + m_line(error->line), + m_queryId(error->sfqid ? error->sfqid : ""), + m_sqlState(error->sqlstate ? error->sqlstate : ""), + m_code((int)error->error_code) +{ + std::string errmsg = setupErrorMessage(m_message, m_file, m_line, m_queryId, m_sqlState, m_code); + setErrorMessage(errmsg); +} + +SnowflakeGeneralException::SnowflakeGeneralException(const std::string& message, + const char* file, int line, + int code, + const std::string queryId, + const std::string sqlState) : + m_message(message), + m_file(file ? file : ""), + m_line(line), + m_queryId(queryId), + m_sqlState(sqlState), + m_code(code) +{ + std::string errmsg = setupErrorMessage(m_message, m_file, m_line, m_queryId, m_sqlState, m_code); + setErrorMessage(errmsg); +} + +SnowflakeGeneralException::SnowflakeGeneralException(const char* file, int line, + int code, + const std::string queryId, + const std::string sqlState, + const char* fmt, ...) +{ + va_list args; + va_start(args, fmt); + m_message = SFLogger::getMaskedMsgVA(fmt, args); + va_end(args); + m_errmsg = setupErrorMessage(m_message, m_file, m_line, m_queryId, m_sqlState, m_code); +} + +int SnowflakeGeneralException::code() +{ + return m_code; +} + +const char* SnowflakeGeneralException::sqlstate() +{ + return m_sqlState.c_str(); +} + +const char* SnowflakeGeneralException::msg() +{ + return m_message.c_str(); +} + +const char* SnowflakeGeneralException::sfqid() +{ + return m_queryId.c_str(); +} + +const char* SnowflakeGeneralException::file() +{ + return m_file.c_str(); +} + +int SnowflakeGeneralException::line() +{ + return m_line; +} + +} // namespace Client +} // namespace Snowflake diff --git a/cpp/util/SnowflakeCommon.cpp b/cpp/util/SnowflakeCommon.cpp index b12642b6c4..b66f3c129c 100644 --- a/cpp/util/SnowflakeCommon.cpp +++ b/cpp/util/SnowflakeCommon.cpp @@ -11,6 +11,7 @@ #include "snowflake/Proxy.hpp" #include "../logger/SFLogger.hpp" #include +#include "SnowflakeCommon.hpp" using namespace Snowflake; using namespace Snowflake::Client; @@ -144,3 +145,28 @@ uint64 sf_get_current_time_millis() } +void Snowflake::Client::Util::replaceStrAll(std::string& stringToReplace, + std::string const& oldValue, + std::string const& newValue) +{ + size_t oldValueLen = oldValue.length(); + size_t newValueLen = newValue.length(); + if (0 == oldValueLen) + { + return; + } + + size_t index = 0; + while (true) { + /* Locate the substring to replace. */ + index = stringToReplace.find(oldValue, index); + if (index == std::string::npos) break; + + /* Make the replacement. */ + stringToReplace.replace(index, oldValueLen, newValue); + + /* Advance index forward so the next iteration doesn't pick it up as well. */ + index += newValueLen; + } +} + diff --git a/cpp/util/SnowflakeCommon.hpp b/cpp/util/SnowflakeCommon.hpp index f151d1c691..4557ebb430 100644 --- a/cpp/util/SnowflakeCommon.hpp +++ b/cpp/util/SnowflakeCommon.hpp @@ -1,20 +1,28 @@ /* - * Copyright (c) 2018-2019 Snowflake Computing, Inc. All rights reserved. + * Copyright (c) 2024 Snowflake Computing, Inc. All rights reserved. */ #ifndef SNOWFLAKECLIENT_SNOWFLAKECOMMON_HPP #define SNOWFLAKECLIENT_SNOWFLAKECOMMON_HPP -#include +#include #include #include -// unsigned integer types -typedef uint8_t ub1; -typedef uint16_t ub2; -typedef uint32_t ub4; -typedef uint64_t ub8; -typedef uint128_t ub16; +/* CPP only utilities */ +namespace Snowflake +{ +namespace Client +{ +namespace Util +{ +void replaceStrAll(std::string& stringToReplace, + std::string const& oldValue, + std::string const& newValue); + +} // namespace Util +} // namespace Client +} // namespace Snowflake #endif //SNOWFLAKECLIENT_SNOWFLAKECOMMON_HPP diff --git a/include/snowflake/BindUploader.hpp b/include/snowflake/BindUploader.hpp index a81b3674b9..6754eea5a4 100644 --- a/include/snowflake/BindUploader.hpp +++ b/include/snowflake/BindUploader.hpp @@ -1,182 +1,217 @@ /* -* File: BindUploader.hpp -* Author: harryx -* -* Copyright (c) 2020 Snowflake Computing -* -* Created on March 5, 2020, 3:14 PM -*/ + * Copyright (c) 2024 Snowflake Computing, Inc. All rights reserved. + */ #pragma once -#ifndef BINDUPLOADER_HPP -#define BINDUPLOADER_HPP +#ifndef SNOWFLAKECLIENT_BINDUPLOADER_HPP +#define SNOWFLAKECLIENT_BINDUPLOADER_HPP -#include "picojson.h" -#include "Statement.hpp" -#include "Logger.hpp" +#include +#include +#include "client.h" -namespace sf +namespace Snowflake { - using namespace picojson; +namespace Client +{ + +class BindUploader +{ +public: + /** + * constructor + * + * @param stageDir The unique stage path for bindings uploading, could be a GUID. + * @param numParams Number of parameters. + * @param numParamSets Number of parameter sets. + * @param maxFileSize The max size of single file for bindings uploading. + * Separate into multiple files when exceed. + * @param compressLevel The compress level, between -1(default) to 9. + */ + explicit BindUploader(const std::string& stageDir, + unsigned int numParams, + unsigned int numParamSets, + unsigned int maxFileSize, + int compressLevel); + + void addStringValue(const std::string& value, SF_DB_TYPE type); + + void addNullValue(); + + inline std::string getStagePath() + { + return m_stagePath; + } + + inline bool hasBindingUploaded() + { + return m_hasBindingUploaded; + } + +protected: + /** + * @return The statement for creating temporary stage for bind uploading. + */ + std::string getCreateStageStmt(); + + /** + * Check whether the session's temporary stage has been created, and create it + * if not. + * + * @throws Exception if creating the stage fails + */ + virtual void createStageIfNeeded() = 0; + + /** + * Execute uploading for single data file. + * + * @param sql PUT command for single data file uploading + * @param uploadStream stream for data file to be uploaded + * @param dataSize Size of the data to be uploaded. + * + * @throws Exception if uploading fails + */ + virtual void executeUploading(const std::string &sql, + std::basic_iostream& uploadStream, + size_t dataSize) = 0; + + /* date/time format conversions to be overridden by drivers (such as ODBC) + * that need native date/time type support. + * Will be called to converting binding format between regular binding and + * bulk binding. + * No conversion by default, in such case application/driver should bind + * data/time data as string. + */ + + /** + * Convert time data format from nanoseconds to HH:MM:SS.F9 + * @param timeInNano The time data string in nanoseconds. + */ + virtual std::string convertTimeFormat(const std::string& timeInNano) + { + return timeInNano; + } + + /** + * Convert date data format from days to YYYY-MM-DD + * @param milliseconds since Epoch + */ + virtual std::string convertDateFormat(const std::string& millisecondSinceEpoch) + { + return millisecondSinceEpoch; + } /** - * Class BindUploader + * Convert timestamp data format from nanoseconds to YYYY_MM_DD HH:MM:SS.F9 + * @param timestampInNano The timestamp data string in nanoseconds. + * @param type Either TIMESTAMP_LTZ or NTZ depends on CLIENT_TIMESTAMP_TYPE_MAPPING */ - class BindUploader + virtual std::string convertTimestampFormat(const std::string& timestampInNano, + SF_DB_TYPE type) { - public: - explicit BindUploader(Connection &connection, - const simba_wstring& stageDir, - unsigned int numParams, - unsigned int numParamSets, - int compressLevel, - bool injectError); - - void addStringValue(const std::string& value, simba_int16 type); - - void addNullValue(); - - inline std::string getStagePath() - { - return m_stagePath.GetAsUTF8(); - } - - /** - * Convert binding data in csv format for stage binding into json format - * for regular binding. This is for fallback to regular binding when stage - * binding fails. - * @param paramBindOrder The bind order for parameters with parameter names. - * @param parameterBinds The output of parameter bindings in json - */ - void convertBindingFromCsvToJson(std::vector& paramBindOrder, - jsonObject_t& parameterBinds); - - inline bool hasBindingUploaded() - { - return m_hasBindingUploaded; - } - - private: - /** - * Upload serialized binds in CSV stream to stage - * - * @throws BindException if uploading the binds fails - */ - void putBinds(); - - /** - * Compress data from csv stream to compress stream with gzip - * @return The data size of compress stream if compress succeeded. - * @throw when compress failed. - */ - size_t compressWithGzip(); - - /** - * Check whether the session's temporary stage has been created, and create it - * if not. - * - * @throws Exception if creating the stage fails - */ - void createStageIfNeeded(); + return timestampInNano; + } - /** - * Build PUT statement string. Handle filesystem differences and escaping backslashes. - * @param srcFilePath The faked source file path to upload. - */ - std::string getPutStmt(const std::string& srcFilePath); - - /** - * Convert time data format from nanoseconds to HH:MM:SS.F9 - * @param timeInNano The time data string in nanoseconds. - */ - std::string convertTimeFormat(const std::string& timeInNano); - - /** - * Convert date data format from days to YYYY-MM-DD - * @param milliseconds since Epoch - */ - std::string convertDateFormat(const std::string& millisecondSinceEpoch); - - /** - * Convert timestamp data format from nanoseconds to YYYY_MM_DD HH:MM:SS.F9 - * @param timestampInNano The timestamp data string in nanoseconds. - * @param type Either SQL_SF_TIMESTAMP_LTZ or NTZ depends on CLIENT_TIMESTAMP_TYPE_MAPPING - */ - std::string convertTimestampFormat(const std::string& timestampInNano, - simba_int16 type); - - /** - * Revert time data format from HH:MM:SS.F9 to nanoseconds - * @param formatedTime The time data string in HH:MM:SS.F9. - */ - std::string revertTimeFormat(const std::string& formatedTime); - - /** - * Convert date data format from YYYY-MM-DD to milliseconds since Epoch - * @param formatedDate the date string in YYYY-MM-DD - */ - std::string revertDateFormat(const std::string& formatedDate); - - /** - * Convert timestamp data format from YYYY_MM_DD HH:MM:SS.F9 to nanoseconds - * @param Formatedtimestamp The timestamp data string in YYYY_MM_DD HH:MM:SS.F9. - * @param type Either SQL_SF_TIMESTAMP_LTZ or NTZ depends on CLIENT_TIMESTAMP_TYPE_MAPPING - */ - std::string revertTimestampFormat(const std::string& Formatedtimestamp, - simba_int16 type); - - /** - * csv parsing function called by convertBindingFromCsvToJson(), get value of - * next field. - * @param fieldValue The output of the field value. - * @param isNull The output of the flag whether the filed is null. - * @param isEndofRow The output of the flag wether the end of row is reached. - * @return true if a field value is retrieved successfully, false if end of data - * is reached and no field value available. - */ - bool csvGetNextField(std::string& fieldValue, bool& isNull, bool& isEndofRow); + /** + * Revert time data format from HH:MM:SS.F9 to nanoseconds + * @param formatedTime The time data string in HH:MM:SS.F9. + */ + virtual std::string revertTimeFormat(const std::string& formatedTime) + { + return formatedTime; + } - Connection &m_connection; + /** + * Convert date data format from YYYY-MM-DD to milliseconds since Epoch + * @param formatedDate the date string in YYYY-MM-DD + */ + virtual std::string revertDateFormat(const std::string& formatedDate) + { + return formatedDate; + } + + /** + * Convert timestamp data format from YYYY_MM_DD HH:MM:SS.F9 to nanoseconds + * @param Formatedtimestamp The timestamp data string in YYYY_MM_DD HH:MM:SS.F9. + * @param type Either TIMESTAMP_LTZ or NTZ depends on CLIENT_TIMESTAMP_TYPE_MAPPING + */ + virtual std::string revertTimestampFormat(const std::string& Formatedtimestamp, + SF_DB_TYPE type) + { + return Formatedtimestamp; + } + +private: + /** + * Upload serialized binds in CSV stream to stage + * + * @throws BindException if uploading the binds fails + */ + void putBinds(); + + /** + * Compress data from csv stream to compress stream with gzip + * @return The data size of compress stream if compress succeeded. + * @throw when compress failed. + */ + size_t compressWithGzip(); + + /** + * Build PUT statement string. Handle filesystem differences and escaping backslashes. + * @param srcFilePath The faked source file path to upload. + */ + std::string getPutStmt(const std::string& srcFilePath); + + /** + * csv parsing function called by convertBindingFromCsvToJson(), get value of + * next field. + * @param fieldValue The output of the field value. + * @param isNull The output of the flag whether the filed is null. + * @param isEndofRow The output of the flag wether the end of row is reached. + * @return true if a field value is retrieved successfully, false if end of data + * is reached and no field value available. + */ + bool csvGetNextField(std::string& fieldValue, bool& isNull, bool& isEndofRow); - std::stringstream m_csvStream; + std::stringstream m_csvStream; - std::stringstream m_compressStream; + std::stringstream m_compressStream; - simba_wstring m_stagePath; + std::string m_stagePath; - unsigned int m_fileNo; + unsigned int m_fileNo; - unsigned int m_retryCount; + unsigned int m_retryCount; - unsigned int m_maxFileSize; + unsigned int m_maxFileSize; - unsigned int m_numParams; + unsigned int m_numParams; - unsigned int m_numParamSets; + unsigned int m_numParamSets; - unsigned int m_curParamIndex; + unsigned int m_curParamIndex; - unsigned int m_curParamSetIndex; + unsigned int m_curParamSetIndex; - size_t m_dataSize; + size_t m_dataSize; - std::chrono::steady_clock::time_point m_startTime; + std::chrono::steady_clock::time_point m_startTime; - std::chrono::steady_clock::time_point m_serializeStartTime; + std::chrono::steady_clock::time_point m_serializeStartTime; - long long m_compressTime; + long long m_compressTime; - long long m_serializeTime; + long long m_serializeTime; - long long m_putTime; + long long m_putTime; - bool m_hasBindingUploaded; + bool m_hasBindingUploaded; - int m_compressLevel; + int m_compressLevel; - bool m_injectError; - }; +}; -} // namespace sf +} // namespace Client +} // namespace Snowflake -#endif // BINDUPLOADER_HPP +#endif // SNOWFLAKECLIENT_BINDUPLOADER_HPP diff --git a/include/snowflake/Exceptions.hpp b/include/snowflake/Exceptions.hpp index ea32f0a6b0..e8c0282623 100644 --- a/include/snowflake/Exceptions.hpp +++ b/include/snowflake/Exceptions.hpp @@ -1,38 +1,146 @@ /* - * Copyright (c) 2018-2019 Snowflake Computing, Inc. All rights reserved. + * Copyright (c) 2018-2024 Snowflake Computing, Inc. All rights reserved. */ #ifndef SNOWFLAKECLIENT_EXCEPTIONS_HPP #define SNOWFLAKECLIENT_EXCEPTIONS_HPP #include +#include #include "client.h" -class SnowflakeException: public std::exception { -public: - SnowflakeException(SF_ERROR_STRUCT *error); +namespace Snowflake +{ +namespace Client +{ - const char * what() const throw(); +class SnowflakeException: public std::exception +{ +public: + // Return error message combine all information + // sub-classes need to setup m_errmsg. + virtual const char* what() const noexcept override + { + return m_errmsg.c_str(); + } - SF_STATUS code(); + // optional properties sub-classes could choose what to override + // with information available + virtual int code() + { + return 0; + } - const char *sqlstate(); + virtual const char* sqlstate() + { + return ""; + } - const char *msg(); + // Return the original error message without other information (sqlstate etc.). + virtual const char* msg() + { + return ""; + } - const char *sfqid(); + virtual const char* sfqid() + { + return ""; + } - const char *file(); + virtual const char* file() + { + return ""; + } - int line(); + virtual int line() + { + return 0; + } protected: - SF_ERROR_STRUCT *error; + // update error message + void setErrorMessage(const std::string& errmsg); + + // update error message with formatted arguments + void setErrorMessage(const char* fmt, va_list args); + + std::string m_errmsg; }; -class GeneralException: public SnowflakeException { +class SnowflakeGeneralException: public SnowflakeException +{ public: - GeneralException(SF_ERROR_STRUCT *error) : SnowflakeException(error) {}; + SnowflakeGeneralException(SF_ERROR_STRUCT *error); + SnowflakeGeneralException(const std::string& message, + const char* file, int line, + int code = 0, + const std::string queryId = "", + const std::string sqlState = ""); + SnowflakeGeneralException(const char* file, int line, + int code, + const std::string queryId, + const std::string sqlState, + const char* fmt, ...); + + virtual int code() override; + + virtual const char* sqlstate() override; + + virtual const char* msg() override; + + virtual const char* sfqid() override; + + virtual const char* file() override; + + virtual int line() override; + +protected: + std::string m_message; + std::string m_file; + int m_line; + std::string m_queryId; + std::string m_sqlState; + int m_code; }; +// macro for throw general exception with SF_ERROR_STRUCT +#define SNOWFLAKE_THROW_S(error) \ +{ \ + throw SnowflakeGeneralException(error); \ +} + +// macro for throw general exception with error message +#define SNOWFLAKE_THROW(errmsg) \ +{ \ + throw SnowflakeGeneralException(errmsg, \ + __FILE__, __LINE__); \ +} + +// macro for throw general exception with more detail information +#define SNOWFLAKE_THROW_DETAIL(errmsg, code, qid, state) \ +{ \ + throw SnowflakeGeneralException(errmsg, \ + __FILE__, __LINE__, \ + code, qid, state); \ +} + +// macro for throw general exception with formatted arguments +#define SNOWFLAKE_THROW_FORMATTED(fmt, ...) \ +{ \ + throw SnowflakeGeneralException(__FILE__, __LINE__, \ + 0, "", "", \ + fmt, __VA_ARGS__); \ +} + +// macro for throw general exception with formatted arguments and detail information. +#define SNOWFLAKE_THROW_FORMATTED_DETAIL(code, qid, state, fmt, ...) \ +{ \ + throw SnowflakeGeneralException(__FILE__, __LINE__, \ + code, qid, state, \ + fmt, __VA_ARGS__); \ +} + +} // namespace Client +} // namespace Snowflake + #endif //SNOWFLAKECLIENT_EXCEPTIONS_HPP diff --git a/include/snowflake/client.h b/include/snowflake/client.h index c045795f6a..4009c3f7e3 100644 --- a/include/snowflake/client.h +++ b/include/snowflake/client.h @@ -282,6 +282,7 @@ typedef enum SF_ATTRIBUTE { SF_CON_GET_FASTFAIL, SF_CON_GET_MAXRETRIES, SF_CON_GET_THRESHOLD, + SF_CON_STAGE_BIND_THRESHOLD, SF_DIR_QUERY_URL, SF_DIR_QUERY_URL_PARAM, SF_DIR_QUERY_TOKEN, @@ -306,7 +307,8 @@ typedef enum SF_GLOBAL_ATTRIBUTE { */ typedef enum SF_STMT_ATTRIBUTE { SF_STMT_USER_REALLOC_FUNC, - SF_STMT_MULTI_STMT_COUNT + SF_STMT_MULTI_STMT_COUNT, + SF_STMT_PARAMSET_SIZE } SF_STMT_ATTRIBUTE; #define SF_MULTI_STMT_COUNT_UNSET (-1) #define SF_MULTI_STMT_COUNT_UNLIMITED 0 @@ -432,6 +434,14 @@ typedef struct SF_CONNECT { sf_bool get_fastfail; int8 get_maxretries; int64 get_threshold; + + // stage binding + sf_bool binding_stage_created; + uint64 stage_binding_threshold; + // the flag indecates the threshold from session parameter is overridden + // by the setting from connection attribute + sf_bool binding_threshold_overridden; + sf_bool stage_binding_disabled; } SF_CONNECT; /** @@ -517,6 +527,7 @@ typedef struct SF_STMT { sf_bool is_multi_stmt; void* multi_stmt_result_ids; int64 multi_stmt_count; + int64 paramset_size; /** * User realloc function used in snowflake_fetch @@ -529,14 +540,32 @@ typedef struct SF_STMT { /** * Bind input parameter context - */ + * Array binding (usually for insert/update multiple rows with one query) supported. + * To do that, value should be set to the array having multiple values, + * statement attribute SF_STMT_PARAMSET_SIZE set to the number of elements of the array + * in each binding. + * for SF_C_TYPE_STRING len should be set to the buffer length of each string value, + * NOT the entire length of the array. It would be used to find the start of each value. + * len_ind should be set to an array of length,indicating the actual length of each value. + * each length could be set to + * SF_BIND_LEN_NULL to indicate NULL data + * SF_BIND_LEN_NTS to indicate NULL terminated string (for SF_C_TYPE_STRING only). + * >= 0 for actual data length (for SF_C_TYPE_STRING only). + * len_ind could be omitted (set to NULL) as well if no NULL data, + * and for for SF_C_TYPE_STRING, all string values are null terminated. + */ + +#define SF_BIND_LEN_NULL -1 +#define SF_BIND_LEN_NTS -3 + typedef struct { size_t idx; /* One based index of the columns, 0 if Named */ char * name; /* Named Parameter name, NULL if positional */ SF_C_TYPE c_type; /* input data type in C */ - void *value; /* input value */ - size_t len; /* input value length. valid only for SF_C_TYPE_STRING */ + void *value; /* input value, could be array of multiple values */ + size_t len; /* The length of each input value. valid only for SF_C_TYPE_STRING */ SF_DB_TYPE type; /* (optional) target Snowflake data type */ + int* len_ind; /* (optional) The array of length indicator to support array binding*/ } SF_BIND_INPUT; /** @@ -875,6 +904,8 @@ uint64 STDCALL snowflake_num_params(SF_STMT *sfstmt); * * For Positional parameters: * SF_BIND_INPUT name = NULL; + * + * * * @param input preallocated SF_BIND_INPUT instance * @return void diff --git a/lib/client.c b/lib/client.c index f6bdc93cdf..3afe1e9b2d 100644 --- a/lib/client.c +++ b/lib/client.c @@ -186,6 +186,9 @@ static SF_STATUS STDCALL _reset_connection_parameters( else if (strcmp(name->valuestring, "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1") == 0) { sf->use_s3_regional_url = snowflake_cJSON_IsTrue(value) ? SF_BOOLEAN_TRUE : SF_BOOLEAN_FALSE; } + else if (strcmp(name->valuestring, "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD") == 0) { + sf->stage_binding_threshold = snowflake_cJSON_GetUint64Value(value); + } } } SF_STATUS ret = SF_STATUS_ERROR_GENERAL; @@ -764,6 +767,9 @@ SF_CONNECT *STDCALL snowflake_init() { sf->get_fastfail = SF_BOOLEAN_FALSE; sf->get_maxretries = SF_DEFAULT_GET_MAX_RETRIES; sf->get_threshold = SF_DEFAULT_GET_THRESHOLD; + + sf->binding_stage_created = SF_BOOLEAN_FALSE; + sf->stage_binding_threshold = SF_DEFAULT_STAGE_BINDING_THRESHOLD; } return sf; @@ -1852,6 +1858,170 @@ static void STDCALL _snowflake_deallocate_named_param_list(void * name_list) SF_FREE(name_list); } +#define SF_PARAM_NAME_BUF_LEN 20 +/** + * Get parameter binding by index for both POSITIONAL and NAMED cases. + * @param sfstmt SNOWFLAKE_STMT context. + * @param index The 0 based index of parameter binding to get. + * @param name Output the name of binding. + * @param name_buf The buffer to store name. + Used for POSITIONAL and name will point to this buffer in such case. + * @param name_buf_len The size of name_buf. + * @return parameter binding with specified index. + */ +SF_BIND_INPUT* STDCALL _snowflake_get_binding_by_index(SF_STMT* sfstmt, + size_t index, + char** name, + char* name_buf, + size_t name_buf_len) +{ + SF_BIND_INPUT* input = NULL; + if (_snowflake_get_current_param_style(sfstmt) == POSITIONAL) + { + input = (SF_BIND_INPUT*)sf_param_store_get(sfstmt->params, + index + 1, NULL); + sf_sprintf(name_buf, name_buf_len, "%lu", (unsigned long)(index + 1)); + *name = name_buf; + } + else if (_snowflake_get_current_param_style(sfstmt) == NAMED) + { + *name = (char*)(((NamedParams*)sfstmt->name_list)->name_list[index]); + input = (SF_BIND_INPUT*)sf_param_store_get(sfstmt->params, 0, *name); + } + + return input; +} + +/* + * @return size of single binding value per data type. + */ +size_t STDCALL _snowflake_get_binding_value_size(SF_BIND_INPUT* bind) +{ + switch (bind->c_type) + { + case SF_C_TYPE_INT8: + return sizeof (int8); + case SF_C_TYPE_UINT8: + return sizeof(uint8); + case SF_C_TYPE_INT64: + return sizeof(int64); + case SF_C_TYPE_UINT64: + return sizeof(uint64); + case SF_C_TYPE_FLOAT64: + return sizeof(float64); + case SF_C_TYPE_BOOLEAN: + return sizeof(sf_bool); + case SF_C_TYPE_BINARY: + case SF_C_TYPE_STRING: + return bind->len; + case SF_C_TYPE_TIMESTAMP: + // TODO Add timestamp case + case SF_C_TYPE_NULL: + default: + return 0; + } +} + +/** + * @param sfstmt SNOWFLAKE_STMT context. + * @return parameter bindings in cJSON. + */ +cJSON* STDCALL _snowflake_get_binding_json(SF_STMT* sfstmt) +{ + size_t i; + SF_BIND_INPUT* input; + const char* type; + char name_buf[SF_PARAM_NAME_BUF_LEN]; + char* name = NULL; + char* value = NULL; + cJSON* bindings = NULL; + + if (_snowflake_get_current_param_style(sfstmt) == INVALID_PARAM_TYPE) + { + return NULL; + } + bindings = snowflake_cJSON_CreateObject(); + for (i = 0; i < sfstmt->params_len; i++) + { + cJSON* binding; + input = _snowflake_get_binding_by_index(sfstmt, i, &name, + name_buf, SF_PARAM_NAME_BUF_LEN); + if (input == NULL) + { + log_error("_snowflake_execute_ex: No parameter by this name %s", name); + continue; + } + binding = snowflake_cJSON_CreateObject(); + type = snowflake_type_to_string( + c_type_to_snowflake(input->c_type, SF_DB_TYPE_TIMESTAMP_NTZ)); + if (sfstmt->paramset_size > 1) + { + cJSON* val_array = snowflake_cJSON_CreateArray(); + size_t step = _snowflake_get_binding_value_size(input); + void* val_ptr = input->value; + int64 val_len; + cJSON* single_val = NULL; + for (int64 j = 0; j < sfstmt->paramset_size; j++, val_ptr = (char*)val_ptr + step) + { + val_len = input->len; + if (input->len_ind) + { + val_len = input->len_ind[j]; + } + + if (SF_BIND_LEN_NULL == val_len) + { + single_val = snowflake_cJSON_CreateNull(); + snowflake_cJSON_AddItemToArray(val_array, single_val); + continue; + } + + if ((SF_C_TYPE_STRING == input->c_type) && + (SF_BIND_LEN_NTS == val_len)) + { + val_len = strlen((char*)val_ptr); + } + + value = value_to_string(val_ptr, val_len, input->c_type); + single_val = snowflake_cJSON_CreateString(value); + snowflake_cJSON_AddItemToArray(val_array, single_val); + if (value) { + SF_FREE(value); + } + } + snowflake_cJSON_AddItemToObject(binding, "value", val_array); + } + else // paramset_size = 1, single value binding + { + value = value_to_string(input->value, input->len, input->c_type); + snowflake_cJSON_AddStringToObject(binding, "value", value); + if (value) { + SF_FREE(value); + } + } + snowflake_cJSON_AddStringToObject(binding, "type", type); + snowflake_cJSON_AddItemToObject(bindings, name, binding); + } + + return bindings; +} + +sf_bool STDCALL _snowflake_needs_stage_binding(SF_STMT* sfstmt) +{ + if (!sfstmt || !sfstmt->connection || + (_snowflake_get_current_param_style(sfstmt) == INVALID_PARAM_TYPE) || + sfstmt->connection->stage_binding_disabled || + sfstmt->paramset_size <= 1) + { + return SF_BOOLEAN_FALSE; + } + + if (sfstmt->paramset_size * sfstmt->params_len >= sfstmt->connection->stage_binding_threshold) + { + return SF_BOOLEAN_TRUE; + } + return SF_BOOLEAN_FALSE; +} /** * Resets SNOWFLAKE_STMT parameters. * @@ -1973,6 +2143,8 @@ SF_STMT *STDCALL snowflake_stmt(SF_CONNECT *sf) { _snowflake_stmt_reset(sfstmt); sfstmt->connection = sf; sfstmt->multi_stmt_count = SF_MULTI_STMT_COUNT_UNSET; + // single value binding by default + sfstmt->paramset_size = 1; } return sfstmt; } @@ -2010,6 +2182,7 @@ void STDCALL snowflake_bind_input_init(SF_BIND_INPUT * input) input->idx = 0; input->name = NULL; input->value = NULL; + input->len_ind = NULL; } /** @@ -2400,7 +2573,7 @@ SF_STATUS STDCALL _snowflake_execute_ex(SF_STMT *sfstmt, if (is_put_get_command && is_native_put_get && !is_describe_only) { _snowflake_stmt_desc_reset(sfstmt); - return _snowflake_execute_put_get_native(sfstmt, result_capture); + return _snowflake_execute_put_get_native(sfstmt, NULL, 0, result_capture); } clear_snowflake_error(&sfstmt->error); @@ -2419,6 +2592,7 @@ SF_STATUS STDCALL _snowflake_execute_ex(SF_STMT *sfstmt, }; size_t i; cJSON *bindings = NULL; + char* bind_stage = NULL; SF_BIND_INPUT *input; const char *type; char *value; @@ -2427,60 +2601,13 @@ SF_STATUS STDCALL _snowflake_execute_ex(SF_STMT *sfstmt, sfstmt->sequence_counter = ++sfstmt->connection->sequence_counter; _mutex_unlock(&sfstmt->connection->mutex_sequence_counter); - if (_snowflake_get_current_param_style(sfstmt) == POSITIONAL) + if (_snowflake_needs_stage_binding(sfstmt)) { - bindings = snowflake_cJSON_CreateObject(); - for (i = 0; i < sfstmt->params_len; i++) - { - cJSON *binding; - input = (SF_BIND_INPUT *) sf_param_store_get(sfstmt->params, - i+1,NULL); - if (input == NULL) { - continue; - } - // TODO check if input is null and either set error or write msg to log - type = snowflake_type_to_string( - c_type_to_snowflake(input->c_type, SF_DB_TYPE_TIMESTAMP_NTZ)); - value = value_to_string(input->value, input->len, input->c_type); - binding = snowflake_cJSON_CreateObject(); - char idxbuf[20]; - sf_sprintf(idxbuf, sizeof(idxbuf), "%lu", (unsigned long) (i + 1)); - snowflake_cJSON_AddStringToObject(binding, "type", type); - snowflake_cJSON_AddStringToObject(binding, "value", value); - snowflake_cJSON_AddItemToObject(bindings, idxbuf, binding); - if (value) { - SF_FREE(value); - } - } + bind_stage = _snowflake_stage_bind_upload(sfstmt); } - else if (_snowflake_get_current_param_style(sfstmt) == NAMED) + if (!bind_stage) { - bindings = snowflake_cJSON_CreateObject(); - char *named_param = NULL; - for(i = 0; i < sfstmt->params_len; i++) - { - cJSON *binding; - named_param = (char *)(((NamedParams *)sfstmt->name_list)->name_list[i]); - input = (SF_BIND_INPUT *) sf_param_store_get(sfstmt->params, - 0,named_param); - if (input == NULL) - { - log_error("_snowflake_execute_ex: No parameter by this name %s",named_param); - continue; - } - type = snowflake_type_to_string( - c_type_to_snowflake(input->c_type, SF_DB_TYPE_TIMESTAMP_NTZ)); - value = value_to_string(input->value, input->len, input->c_type); - binding = snowflake_cJSON_CreateObject(); - - snowflake_cJSON_AddStringToObject(binding, "type", type); - snowflake_cJSON_AddStringToObject(binding, "value", value); - snowflake_cJSON_AddItemToObject(bindings, named_param, binding); - if (value) - { - SF_FREE(value); - } - } + bindings = _snowflake_get_binding_json(sfstmt); } if (is_string_empty(sfstmt->connection->directURL) && @@ -2500,7 +2627,12 @@ SF_STATUS STDCALL _snowflake_execute_ex(SF_STMT *sfstmt, is_string_empty(sfstmt->connection->directURL) ? NULL : sfstmt->request_id, is_describe_only, sfstmt->multi_stmt_count); - if (bindings != NULL) { + if (bind_stage) + { + snowflake_cJSON_AddStringToObject(body, "bindStage", bind_stage); + SF_FREE(bind_stage); + } + else if (bindings != NULL) { /* binding parameters if exists */ snowflake_cJSON_AddItemToObject(body, "bindings", bindings); } @@ -2787,6 +2919,9 @@ SF_STATUS STDCALL snowflake_stmt_get_attr( case SF_STMT_MULTI_STMT_COUNT: *value = &sfstmt->multi_stmt_count; break; + case SF_STMT_PARAMSET_SIZE: + *value = &sfstmt->paramset_size; + break; default: SET_SNOWFLAKE_ERROR( &sfstmt->error, SF_STATUS_ERROR_BAD_ATTRIBUTE_TYPE, @@ -2810,6 +2945,9 @@ SF_STATUS STDCALL snowflake_stmt_set_attr( case SF_STMT_MULTI_STMT_COUNT: sfstmt->multi_stmt_count = value ? *((int64*)value) : SF_MULTI_STMT_COUNT_UNSET; break; + case SF_STMT_PARAMSET_SIZE: + sfstmt->paramset_size = value ? *((int64*)value) : 1; + break; default: SET_SNOWFLAKE_ERROR( &sfstmt->error, SF_STATUS_ERROR_BAD_ATTRIBUTE_TYPE, diff --git a/lib/client_int.h b/lib/client_int.h index 47db68c595..c6e97ca986 100644 --- a/lib/client_int.h +++ b/lib/client_int.h @@ -25,6 +25,9 @@ #define SF_DEFAULT_MAX_OBJECT_SIZE 16777216 +#define SF_DEFAULT_STAGE_BINDING_THRESHOLD 65280 +#define SF_DEFAULT_STAGE_BINDING_MAX_FILESIZE 100 * 1024 * 1024 + // defaults for put get configurations #define SF_DEFAULT_PUT_COMPRESS_LEVEL (-1) #define SF_MAX_PUT_COMPRESS_LEVEL 9 @@ -173,6 +176,12 @@ sf_bool STDCALL _is_put_get_command(char* sql_text); */ PARAM_TYPE STDCALL _snowflake_get_param_style(const SF_BIND_INPUT *input); +/** + * @param sfstmt SNOWFLAKE_STMT context. + * @return parameter bindings in cJSON. + */ +cJSON* STDCALL _snowflake_get_binding_json(SF_STMT *sfstmt); + #ifdef __cplusplus extern "C" { #endif @@ -190,13 +199,52 @@ _snowflake_query_put_get_legacy(SF_STMT* sfstmt, const char* command, size_t com /** * Executes put get command natively. * @param sfstmt SNOWFLAKE_STMT context. + * @param upload_stream Internal support for bind uploading, pointer to std::basic_iostream. + * @param stream_size The data size of upload_stream. * @param raw_response_buffer optional pointer to an SF_QUERY_RESULT_CAPTURE, * * @return 0 if success, otherwise an errno is returned. */ SF_STATUS STDCALL _snowflake_execute_put_get_native( SF_STMT *sfstmt, + void* upload_stream, + size_t stream_size, struct SF_QUERY_RESULT_CAPTURE* result_capture); + +/* + * @return size of single binding value per data type. + */ +size_t STDCALL _snowflake_get_binding_value_size(SF_BIND_INPUT* bind); + +#define SF_PARAM_NAME_BUF_LEN 20 +/** + * Get parameter binding by index for both POSITIONAL and NAMED cases. + * @param sfstmt SNOWFLAKE_STMT context. + * @param index The 0 based index of parameter binding to get. + * @param name Output the name of binding. + * @param name_buf The buffer to store name. + Used for POSITIONAL and name will point to this buffer in such case. + * @param name_buf_len The size of name_buf. + * @return parameter binding with specified index. + */ +SF_BIND_INPUT* STDCALL _snowflake_get_binding_by_index(SF_STMT* sfstmt, + size_t index, + char** name, + char* name_buf, + size_t name_buf_len); + +sf_bool STDCALL _snowflake_needs_stage_binding(SF_STMT* sfstmt); + +/** + * Upload parameter bindings through internal stage. + * @param sfstmt SNOWFLAKE_STMT context. + * + * @return Stage path for uploaded bindings if success, + * otherwise NULL is returned and error is put in sfstmt->error. + */ +char* STDCALL +_snowflake_stage_bind_upload(SF_STMT* sfstmt); + #ifdef __cplusplus } // extern "C" #endif diff --git a/tests/test_bind_params.c b/tests/test_bind_params.c index 255cc0b857..4ffa97da8b 100644 --- a/tests/test_bind_params.c +++ b/tests/test_bind_params.c @@ -4,6 +4,7 @@ #include #include "utils/test_setup.h" +#include "memory.h" #define INPUT_ARRAY_SIZE 3 @@ -111,10 +112,192 @@ void test_bind_parameters(void **unused) { snowflake_term(sf); } +void test_array_binding_core(unsigned int array_size) { + /* init */ + SF_STATUS status; + int8* int8_array = NULL; + int8 int8_value = -12; + char int8_expected_result[] = "-12"; + uint8* uint8_array = NULL; + uint8 uint8_value = 12; + char uint8_expected_result[] = "12"; + int64* int64_array = NULL; + int64 int64_value = -12345; + char int64_expected_result[] = "-12345"; + uint64* uint64_array = NULL; + uint64 uint64_value = 12345; + char uint64_expected_result[] = "12345"; + float64* float_array = NULL; + float64 float_value = 1.23; + char float_expected_result[] = "1.23"; + char* string_array = NULL; + char string_value[] = "str"; + char string_expected_result[] = "str"; + byte* binary_array = NULL; + byte binary_value[] = {0x12, 0x34, 0x56, 0x78}; + char binary_expected_result[] = "12345678"; + sf_bool* bool_array = NULL; + sf_bool bool_value = SF_BOOLEAN_TRUE; + char bool_expected_result[] = "1"; + SF_BIND_INPUT int8_input; + SF_BIND_INPUT uint8_input; + SF_BIND_INPUT int64_input; + SF_BIND_INPUT uint64_input; + SF_BIND_INPUT float_input; + SF_BIND_INPUT string_input; + SF_BIND_INPUT binary_input; + SF_BIND_INPUT bool_input; + + SF_BIND_INPUT input_array[8]; + char* expected_results[8]; + unsigned int i = 0, j = 0; + + // initialize bindings with argument + int8_array = SF_CALLOC(array_size, sizeof(int8_value)); + uint8_array = SF_CALLOC(array_size, sizeof(uint8_value)); + int64_array = SF_CALLOC(array_size, sizeof(int64_value)); + uint64_array = SF_CALLOC(array_size, sizeof(uint64_value)); + float_array = SF_CALLOC(array_size, sizeof(float_value)); + string_array = SF_CALLOC(array_size, sizeof(string_value)); + binary_array = SF_CALLOC(array_size, sizeof(binary_value)); + bool_array = SF_CALLOC(array_size, sizeof(bool_value)); + + for (i = 0; i < array_size; i++) + { + int8_array[i] = int8_value; + uint8_array[i] = uint8_value; + int64_array[i] = int64_value; + uint64_array[i] = uint64_value; + float_array[i] = float_value; + memcpy(string_array + sizeof(string_value) * i, string_value, sizeof(string_value)); + memcpy(binary_array + sizeof(binary_value) * i, binary_value, sizeof(binary_value)); + bool_array[i] = bool_value; + } + + snowflake_bind_input_init(&int8_input); + snowflake_bind_input_init(&uint8_input); + snowflake_bind_input_init(&int64_input); + snowflake_bind_input_init(&uint64_input); + snowflake_bind_input_init(&float_input); + snowflake_bind_input_init(&string_input); + snowflake_bind_input_init(&binary_input); + snowflake_bind_input_init(&bool_input); + + int8_input.idx = 1; + int8_input.c_type = SF_C_TYPE_INT8; + int8_input.value = int8_array; + + uint8_input.idx = 2; + uint8_input.c_type = SF_C_TYPE_UINT8; + uint8_input.value = uint8_array; + + int64_input.idx = 3; + int64_input.c_type = SF_C_TYPE_INT64; + int64_input.value = int64_array; + + uint64_input.idx = 4; + uint64_input.c_type = SF_C_TYPE_UINT64; + uint64_input.value = uint64_array; + + float_input.idx = 5; + float_input.c_type = SF_C_TYPE_FLOAT64; + float_input.value = float_array; + + string_input.idx = 6; + string_input.c_type = SF_C_TYPE_STRING; + string_input.value = string_array; + string_input.len = sizeof(string_value); + + binary_input.idx = 7; + binary_input.c_type = SF_C_TYPE_BINARY; + binary_input.value = binary_array; + binary_input.len = sizeof(binary_value); + + bool_input.idx = 8; + bool_input.c_type = SF_C_TYPE_BOOLEAN; + bool_input.value = bool_array; + + input_array[0] = int8_input; + input_array[1] = uint8_input; + input_array[2] = int64_input; + input_array[3] = uint64_input; + input_array[4] = float_input; + input_array[5] = string_input; + input_array[6] = binary_input; + input_array[7] = bool_input; + + expected_results[0] = int8_expected_result; + expected_results[1] = uint8_expected_result; + expected_results[2] = int64_expected_result; + expected_results[3] = uint64_expected_result; + expected_results[4] = float_expected_result; + expected_results[5] = string_expected_result; + expected_results[6] = binary_expected_result; + expected_results[7] = bool_expected_result; + + /* Connect with all parameters set */ + SF_CONNECT* sf = setup_snowflake_connection(); + status = snowflake_connect(sf); + assert_int_equal(status, SF_STATUS_SUCCESS); + + /* Create a statement once and reused */ + SF_STMT* stmt = snowflake_stmt(sf); + status = snowflake_query( + stmt, + "create or replace temporary table t (c1 number, c2 number, c3 number, c4 number, c5 float, c6 string, c7 binary, c8 boolean)", + 0 + ); + assert_int_equal(status, SF_STATUS_SUCCESS); + + int64 paramset_size = array_size; + status = snowflake_stmt_set_attr(stmt, SF_STMT_PARAMSET_SIZE, ¶mset_size); + status = snowflake_prepare( + stmt, + "insert into t values(?, ?, ?, ?, ?, ?, ?, ?)", + 0 + ); + assert_int_equal(status, SF_STATUS_SUCCESS); + + status = snowflake_bind_param_array(stmt, input_array, sizeof(input_array) / sizeof(SF_BIND_INPUT)); + assert_int_equal(status, SF_STATUS_SUCCESS); + + status = snowflake_execute(stmt); + assert_int_equal(status, SF_STATUS_SUCCESS); + assert_int_equal(snowflake_affected_rows(stmt), array_size); + + status = snowflake_query(stmt, "select * from t", 0); + assert_int_equal(status, SF_STATUS_SUCCESS); + assert_int_equal(snowflake_num_rows(stmt), array_size); + + for (i = 0; i < array_size; i++) + { + status = snowflake_fetch(stmt); + assert_int_equal(status, SF_STATUS_SUCCESS); + char* result = NULL; + for (j = 0; j < 8; j++) + { + snowflake_column_as_const_str(stmt, j + 1, &result); + assert_string_equal(result, expected_results[j]); + } + } + snowflake_stmt_term(stmt); + snowflake_term(sf); +} + +void test_array_binding_normal(void** unused) { + test_array_binding_core(1000); +} + +void test_array_binding_stage(void** unused) { + test_array_binding_core(100000); +} + int main(void) { initialize_test(SF_BOOLEAN_FALSE); const struct CMUnitTest tests[] = { cmocka_unit_test(test_bind_parameters), + cmocka_unit_test(test_array_binding_normal), + cmocka_unit_test(test_array_binding_stage), }; int ret = cmocka_run_group_tests(tests, NULL, NULL); snowflake_global_term();