Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1524269: support put/get for GCP #738

Merged
merged 9 commits into from
Oct 22, 2024
153 changes: 153 additions & 0 deletions cpp/StatementPutGet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,35 @@
*/

#include <client_int.h>
#include "connection.h"
#include "snowflake/PutGetParseResponse.hpp"
#include "StatementPutGet.hpp"
#include "curl_desc_pool.h"

using namespace Snowflake::Client;

static size_t file_get_write_callback(char* ptr, size_t size, size_t nmemb, void* userdata)
{
size_t data_size = size * nmemb;
std::basic_iostream<char>* recvStream = (std::basic_iostream<char>*)(userdata);
if (recvStream)
{
recvStream->write(static_cast<const char*>(ptr), data_size);
}

return data_size;
}

static size_t file_put_read_callback(void* ptr, size_t size, size_t nmemb, void* userdata)
{
std::basic_iostream<char>* payload = (std::basic_iostream<char>*)(userdata);
size_t data_size = size * nmemb;

payload->read(static_cast<char*>(ptr), data_size);
size_t ret = payload->gcount();
return payload->gcount();
}

StatementPutGet::StatementPutGet(SF_STMT *stmt) :
m_stmt(stmt), m_useProxy(false)
{
Expand Down Expand Up @@ -104,6 +128,14 @@ bool StatementPutGet::parsePutGetCommand(std::string *sql,
};
putGetParseResponse->stageInfo.endPoint = response->stage_info->endPoint;

}
else if (sf_strncasecmp(response->stage_info->location_type, "gcs", 3) == 0)
{
putGetParseResponse->stageInfo.stageType = StageType::GCS;
putGetParseResponse->stageInfo.credentials = {
{"GCS_ACCESS_TOKEN", response->stage_info->stage_cred->gcs_access_token}
};

} else if (sf_strncasecmp(response->stage_info->location_type,
"local_fs", 8) == 0)
{
Expand All @@ -123,3 +155,124 @@ Util::Proxy* StatementPutGet::get_proxy()
return &m_proxy;
}
}

bool StatementPutGet::http_put(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>& payload,
size_t payloadLen,
std::string& responseHeaders)
{
if (!m_stmt || !m_stmt->connection)
{
return false;
}
SF_CONNECT* sf = m_stmt->connection;
void* curl_desc = get_curl_desc_from_pool(url.c_str(), sf->proxy, sf->no_proxy);
CURL* curl = get_curl_from_desc(curl_desc);
if (!curl)
{
return false;
}

char* urlbuf = (char*)SF_CALLOC(1, url.length() + 1);
sf_strcpy(urlbuf, url.length() + 1, url.c_str());

SF_HEADER reqHeaders;
reqHeaders.header = NULL;
for (auto itr = headers.begin(); itr != headers.end(); itr++)
{
reqHeaders.header = curl_slist_append(reqHeaders.header, itr->c_str());
}

PUT_PAYLOAD putPayload;
putPayload.buffer = &payload;
putPayload.length = payloadLen;
putPayload.read_callback = file_put_read_callback;

char* respHeaders = NULL;
sf_bool success = SF_BOOLEAN_FALSE;

success = http_perform(curl, PUT_REQUEST_TYPE, urlbuf, &reqHeaders, NULL, &putPayload, NULL,
NULL, &respHeaders, get_retry_timeout(sf),
SF_BOOLEAN_FALSE, &m_stmt->error, sf->insecure_mode,sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
0, sf->retry_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE,
sf->proxy, sf->no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE);

free_curl_desc(curl_desc);
SF_FREE(urlbuf);
curl_slist_free_all(reqHeaders.header);
if (respHeaders)
{
responseHeaders = std::string(respHeaders);
SF_FREE(respHeaders);
sfc-gh-dprzybysz marked this conversation as resolved.
Show resolved Hide resolved
}

return success;
}

bool StatementPutGet::http_get(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>* payload,
std::string& responseHeaders,
bool headerOnly)
{
SF_REQUEST_TYPE reqType = GET_REQUEST_TYPE;
if (headerOnly)
{
reqType = HEAD_REQUEST_TYPE;
}

if (!m_stmt || !m_stmt->connection)
{
return false;
}
SF_CONNECT* sf = m_stmt->connection;

void* curl_desc = get_curl_desc_from_pool(url.c_str(), sf->proxy, sf->no_proxy);
CURL* curl = get_curl_from_desc(curl_desc);
if (!curl)
{
return false;
}

char* urlbuf = (char*)SF_CALLOC(1, url.length() + 1);
sf_strcpy(urlbuf, url.length() + 1, url.c_str());

SF_HEADER reqHeaders;
reqHeaders.header = NULL;
for (auto itr = headers.begin(); itr != headers.end(); itr++)
{
reqHeaders.header = curl_slist_append(reqHeaders.header, itr->c_str());
}

NON_JSON_RESP resp;
resp.buffer = payload;
resp.write_callback = file_get_write_callback;

char* respHeaders = NULL;
sf_bool success = SF_BOOLEAN_FALSE;

success = http_perform(curl, reqType, urlbuf, &reqHeaders, NULL, NULL, NULL,
&resp, &respHeaders, get_retry_timeout(sf),
SF_BOOLEAN_FALSE, &m_stmt->error, sf->insecure_mode, sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
0, sf->retry_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE,
sf->proxy, sf->no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE);

free_curl_desc(curl_desc);
SF_FREE(urlbuf);
curl_slist_free_all(reqHeaders.header);
if (respHeaders)
{
responseHeaders = respHeaders;
SF_FREE(respHeaders);
}

if (payload)
{
payload->flush();
}

return success;
}
33 changes: 33 additions & 0 deletions cpp/StatementPutGet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,39 @@ class StatementPutGet : public Snowflake::Client::IStatementPutGet

virtual Util::Proxy* get_proxy();

/**
* PUT/GET on GCS use this interface to perform put request.
* Not implemented by default.
* @param url The url of the request.
* @param headers The headers of the request.
* @param payload The upload data.
* @param responseHeaders The headers of the response.
*
* return true if succeed otherwise false
*/
virtual bool http_put(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>& payload,
size_t payloadLen,
std::string& responseHeaders);

/**
* PUT/GET on GCS use this interface to perform put request.
* Not implemented by default.
* @param url The url of the request.
* @param headers The headers of the request.
* @param payload The upload data.
* @param responseHeaders The headers of the response.
* @param headerOnly True if get response header only without payload body.
*
* return true if succeed otherwise false
*/
virtual bool http_get(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>* payload,
std::string& responseHeaders,
bool headerOnly);

private:
SF_STMT *m_stmt;
Util::Proxy m_proxy;
Expand Down
5 changes: 5 additions & 0 deletions include/snowflake/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ extern "C" {
/**
* API Name
*/
/* TODO: Temporarily change to ODBC for now to pass the test before
* features (PUT for GCP, multiple statements etc.) unblocked
* on server side.
* Need to revert to C_API when merging to master.
*/
#define SF_API_NAME "ODBC"
sfc-gh-dprzybysz marked this conversation as resolved.
Show resolved Hide resolved

/**
Expand Down
6 changes: 5 additions & 1 deletion include/snowflake/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#ifndef SNOWFLAKE_CLIENT_VERSION_H
#define SNOWFLAKE_CLIENT_VERSION_H

// TODO: temporary change for testing, will remove
/* TODO: Temporarily change to ODBC version for now to pass the test before
* features (PUT for GCP, multiple statements etc.) unblocked
* on server side.
* Need to revert to libsfclient version when merging to master.
*/
#define SF_API_VERSION "3.0.1"

#endif /* SNOWFLAKE_CLIENT_VERSION_H */
4 changes: 2 additions & 2 deletions lib/chunk_downloader.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ sf_bool STDCALL download_chunk(char *url, SF_HEADER *headers,
CURL *curl = get_curl_from_desc(curl_desc);

if (!curl ||
!http_perform(curl, GET_REQUEST_TYPE, url, headers, NULL, chunk,
non_json_resp, network_timeout,
!http_perform(curl, GET_REQUEST_TYPE, url, headers, NULL, NULL, chunk,
non_json_resp, NULL, network_timeout,
SF_BOOLEAN_TRUE, error, insecure_mode, fail_open, 0,
0, retry_max_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE,
proxy, no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE)) {
Expand Down
3 changes: 3 additions & 0 deletions lib/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -2488,6 +2488,9 @@ SF_STATUS STDCALL _snowflake_execute_ex(SF_STMT *sfstmt,
json_copy_string(
&sfstmt->put_get_response->stage_info->stage_cred->azure_sas_token,
stage_cred, "AZURE_SAS_TOKEN");
json_copy_string(
&sfstmt->put_get_response->stage_info->stage_cred->gcs_access_token,
stage_cred, "GCS_ACCESS_TOKEN");
json_copy_string(
&sfstmt->put_get_response->localLocation, data,
"localLocation");
Expand Down
1 change: 1 addition & 0 deletions lib/client_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ typedef struct SF_STAGE_CRED {
char *aws_secret_key;
char *aws_token;
char *azure_sas_token;
char* gcs_access_token;
} SF_STAGE_CRED;

typedef struct SF_STAGE_INFO {
Expand Down
19 changes: 9 additions & 10 deletions lib/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ cJSON *STDCALL create_query_json_body(const char *sql_text,
parameters = snowflake_cJSON_CreateObject();
}
snowflake_cJSON_AddStringToObject(parameters, "C_API_QUERY_RESULT_FORMAT", "JSON");

// temporary code to fake as ODBC to have multiple statements enabled
snowflake_cJSON_AddStringToObject(parameters, "ODBC_QUERY_RESULT_FORMAT", "JSON");
#endif
Expand Down Expand Up @@ -376,7 +375,7 @@ sf_bool STDCALL curl_post_call(SF_CONNECT *sf,
}

do {
if (!http_perform(curl, POST_REQUEST_TYPE, url, header, body, json, NULL,
if (!http_perform(curl, POST_REQUEST_TYPE, url, header, body, NULL, json, NULL, NULL,
retry_timeout, SF_BOOLEAN_FALSE, error,
sf->insecure_mode, sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
Expand Down Expand Up @@ -503,7 +502,7 @@ sf_bool STDCALL curl_get_call(SF_CONNECT *sf,
memset(query_code, 0, QUERYCODE_LEN);

do {
if (!http_perform(curl, GET_REQUEST_TYPE, url, header, NULL, json, NULL,
if (!http_perform(curl, GET_REQUEST_TYPE, url, header, NULL, NULL, json, NULL, NULL,
get_retry_timeout(sf), SF_BOOLEAN_FALSE, error,
sf->insecure_mode, sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
Expand Down Expand Up @@ -906,16 +905,16 @@ ARRAY_LIST *json_get_object_keys(const cJSON *item) {
}

size_t
json_resp_cb(char *data, size_t size, size_t nmemb, RAW_JSON_BUFFER *raw_json) {
char_resp_cb(char *data, size_t size, size_t nmemb, RAW_CHAR_BUFFER *raw_buf) {
size_t data_size = size * nmemb;
log_debug("Curl response size: %zu", data_size);
raw_json->buffer = (char *) SF_REALLOC(raw_json->buffer,
raw_json->size + data_size + 1);
raw_buf->buffer = (char *) SF_REALLOC(raw_buf->buffer,
raw_buf->size + data_size + 1);
// Start copying where last null terminator existed
sf_memcpy(&raw_json->buffer[raw_json->size], data_size, data, data_size);
raw_json->size += data_size;
// Set null terminator
raw_json->buffer[raw_json->size] = '\0';
sf_memcpy(&raw_buf->buffer[raw_buf->size], data_size, data, data_size);
raw_buf->size += data_size;
// Set null raw_buf
raw_buf->buffer[raw_buf->size] = '\0';
return data_size;
}

Expand Down
Loading
Loading