Skip to content

Commit

Permalink
SNOW-1524269: support put/get for GCP (#738)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-hx authored Oct 22, 2024
1 parent 6bc1b92 commit e3e1b88
Show file tree
Hide file tree
Showing 12 changed files with 322 additions and 44 deletions.
153 changes: 153 additions & 0 deletions cpp/StatementPutGet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,35 @@
*/

#include <client_int.h>
#include "connection.h"
#include "snowflake/PutGetParseResponse.hpp"
#include "StatementPutGet.hpp"
#include "curl_desc_pool.h"

using namespace Snowflake::Client;

static size_t file_get_write_callback(char* ptr, size_t size, size_t nmemb, void* userdata)
{
size_t data_size = size * nmemb;
std::basic_iostream<char>* recvStream = (std::basic_iostream<char>*)(userdata);
if (recvStream)
{
recvStream->write(static_cast<const char*>(ptr), data_size);
}

return data_size;
}

static size_t file_put_read_callback(void* ptr, size_t size, size_t nmemb, void* userdata)
{
std::basic_iostream<char>* payload = (std::basic_iostream<char>*)(userdata);
size_t data_size = size * nmemb;

payload->read(static_cast<char*>(ptr), data_size);
size_t ret = payload->gcount();
return payload->gcount();
}

StatementPutGet::StatementPutGet(SF_STMT *stmt) :
m_stmt(stmt), m_useProxy(false)
{
Expand Down Expand Up @@ -104,6 +128,14 @@ bool StatementPutGet::parsePutGetCommand(std::string *sql,
};
putGetParseResponse->stageInfo.endPoint = response->stage_info->endPoint;

}
else if (sf_strncasecmp(response->stage_info->location_type, "gcs", 3) == 0)
{
putGetParseResponse->stageInfo.stageType = StageType::GCS;
putGetParseResponse->stageInfo.credentials = {
{"GCS_ACCESS_TOKEN", response->stage_info->stage_cred->gcs_access_token}
};

} else if (sf_strncasecmp(response->stage_info->location_type,
"local_fs", 8) == 0)
{
Expand All @@ -123,3 +155,124 @@ Util::Proxy* StatementPutGet::get_proxy()
return &m_proxy;
}
}

bool StatementPutGet::http_put(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>& payload,
size_t payloadLen,
std::string& responseHeaders)
{
if (!m_stmt || !m_stmt->connection)
{
return false;
}
SF_CONNECT* sf = m_stmt->connection;
void* curl_desc = get_curl_desc_from_pool(url.c_str(), sf->proxy, sf->no_proxy);
CURL* curl = get_curl_from_desc(curl_desc);
if (!curl)
{
return false;
}

char* urlbuf = (char*)SF_CALLOC(1, url.length() + 1);
sf_strcpy(urlbuf, url.length() + 1, url.c_str());

SF_HEADER reqHeaders;
reqHeaders.header = NULL;
for (auto itr = headers.begin(); itr != headers.end(); itr++)
{
reqHeaders.header = curl_slist_append(reqHeaders.header, itr->c_str());
}

PUT_PAYLOAD putPayload;
putPayload.buffer = &payload;
putPayload.length = payloadLen;
putPayload.read_callback = file_put_read_callback;

char* respHeaders = NULL;
sf_bool success = SF_BOOLEAN_FALSE;

success = http_perform(curl, PUT_REQUEST_TYPE, urlbuf, &reqHeaders, NULL, &putPayload, NULL,
NULL, &respHeaders, get_retry_timeout(sf),
SF_BOOLEAN_FALSE, &m_stmt->error, sf->insecure_mode,sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
0, sf->retry_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE,
sf->proxy, sf->no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE);

free_curl_desc(curl_desc);
SF_FREE(urlbuf);
curl_slist_free_all(reqHeaders.header);
if (respHeaders)
{
responseHeaders = std::string(respHeaders);
SF_FREE(respHeaders);
}

return success;
}

bool StatementPutGet::http_get(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>* payload,
std::string& responseHeaders,
bool headerOnly)
{
SF_REQUEST_TYPE reqType = GET_REQUEST_TYPE;
if (headerOnly)
{
reqType = HEAD_REQUEST_TYPE;
}

if (!m_stmt || !m_stmt->connection)
{
return false;
}
SF_CONNECT* sf = m_stmt->connection;

void* curl_desc = get_curl_desc_from_pool(url.c_str(), sf->proxy, sf->no_proxy);
CURL* curl = get_curl_from_desc(curl_desc);
if (!curl)
{
return false;
}

char* urlbuf = (char*)SF_CALLOC(1, url.length() + 1);
sf_strcpy(urlbuf, url.length() + 1, url.c_str());

SF_HEADER reqHeaders;
reqHeaders.header = NULL;
for (auto itr = headers.begin(); itr != headers.end(); itr++)
{
reqHeaders.header = curl_slist_append(reqHeaders.header, itr->c_str());
}

NON_JSON_RESP resp;
resp.buffer = payload;
resp.write_callback = file_get_write_callback;

char* respHeaders = NULL;
sf_bool success = SF_BOOLEAN_FALSE;

success = http_perform(curl, reqType, urlbuf, &reqHeaders, NULL, NULL, NULL,
&resp, &respHeaders, get_retry_timeout(sf),
SF_BOOLEAN_FALSE, &m_stmt->error, sf->insecure_mode, sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
0, sf->retry_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE,
sf->proxy, sf->no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE);

free_curl_desc(curl_desc);
SF_FREE(urlbuf);
curl_slist_free_all(reqHeaders.header);
if (respHeaders)
{
responseHeaders = respHeaders;
SF_FREE(respHeaders);
}

if (payload)
{
payload->flush();
}

return success;
}
33 changes: 33 additions & 0 deletions cpp/StatementPutGet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,39 @@ class StatementPutGet : public Snowflake::Client::IStatementPutGet

virtual Util::Proxy* get_proxy();

/**
* PUT/GET on GCS use this interface to perform put request.
* Not implemented by default.
* @param url The url of the request.
* @param headers The headers of the request.
* @param payload The upload data.
* @param responseHeaders The headers of the response.
*
* return true if succeed otherwise false
*/
virtual bool http_put(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>& payload,
size_t payloadLen,
std::string& responseHeaders);

/**
* PUT/GET on GCS use this interface to perform put request.
* Not implemented by default.
* @param url The url of the request.
* @param headers The headers of the request.
* @param payload The upload data.
* @param responseHeaders The headers of the response.
* @param headerOnly True if get response header only without payload body.
*
* return true if succeed otherwise false
*/
virtual bool http_get(std::string const& url,
std::vector<std::string> const& headers,
std::basic_iostream<char>* payload,
std::string& responseHeaders,
bool headerOnly);

private:
SF_STMT *m_stmt;
Util::Proxy m_proxy;
Expand Down
5 changes: 5 additions & 0 deletions include/snowflake/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ extern "C" {
/**
* API Name
*/
/* TODO: Temporarily change to ODBC for now to pass the test before
* features (PUT for GCP, multiple statements etc.) unblocked
* on server side.
* Need to revert to C_API when merging to master.
*/
#define SF_API_NAME "ODBC"

/**
Expand Down
6 changes: 5 additions & 1 deletion include/snowflake/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#ifndef SNOWFLAKE_CLIENT_VERSION_H
#define SNOWFLAKE_CLIENT_VERSION_H

// TODO: temporary change for testing, will remove
/* TODO: Temporarily change to ODBC version for now to pass the test before
* features (PUT for GCP, multiple statements etc.) unblocked
* on server side.
* Need to revert to libsfclient version when merging to master.
*/
#define SF_API_VERSION "3.0.1"

#endif /* SNOWFLAKE_CLIENT_VERSION_H */
4 changes: 2 additions & 2 deletions lib/chunk_downloader.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ sf_bool STDCALL download_chunk(char *url, SF_HEADER *headers,
CURL *curl = get_curl_from_desc(curl_desc);

if (!curl ||
!http_perform(curl, GET_REQUEST_TYPE, url, headers, NULL, chunk,
non_json_resp, network_timeout,
!http_perform(curl, GET_REQUEST_TYPE, url, headers, NULL, NULL, chunk,
non_json_resp, NULL, network_timeout,
SF_BOOLEAN_TRUE, error, insecure_mode, fail_open, 0,
0, retry_max_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE,
proxy, no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE)) {
Expand Down
3 changes: 3 additions & 0 deletions lib/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -2488,6 +2488,9 @@ SF_STATUS STDCALL _snowflake_execute_ex(SF_STMT *sfstmt,
json_copy_string(
&sfstmt->put_get_response->stage_info->stage_cred->azure_sas_token,
stage_cred, "AZURE_SAS_TOKEN");
json_copy_string(
&sfstmt->put_get_response->stage_info->stage_cred->gcs_access_token,
stage_cred, "GCS_ACCESS_TOKEN");
json_copy_string(
&sfstmt->put_get_response->localLocation, data,
"localLocation");
Expand Down
1 change: 1 addition & 0 deletions lib/client_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ typedef struct SF_STAGE_CRED {
char *aws_secret_key;
char *aws_token;
char *azure_sas_token;
char* gcs_access_token;
} SF_STAGE_CRED;

typedef struct SF_STAGE_INFO {
Expand Down
19 changes: 9 additions & 10 deletions lib/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ cJSON *STDCALL create_query_json_body(const char *sql_text,
parameters = snowflake_cJSON_CreateObject();
}
snowflake_cJSON_AddStringToObject(parameters, "C_API_QUERY_RESULT_FORMAT", "JSON");

// temporary code to fake as ODBC to have multiple statements enabled
snowflake_cJSON_AddStringToObject(parameters, "ODBC_QUERY_RESULT_FORMAT", "JSON");
#endif
Expand Down Expand Up @@ -376,7 +375,7 @@ sf_bool STDCALL curl_post_call(SF_CONNECT *sf,
}

do {
if (!http_perform(curl, POST_REQUEST_TYPE, url, header, body, json, NULL,
if (!http_perform(curl, POST_REQUEST_TYPE, url, header, body, NULL, json, NULL, NULL,
retry_timeout, SF_BOOLEAN_FALSE, error,
sf->insecure_mode, sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
Expand Down Expand Up @@ -503,7 +502,7 @@ sf_bool STDCALL curl_get_call(SF_CONNECT *sf,
memset(query_code, 0, QUERYCODE_LEN);

do {
if (!http_perform(curl, GET_REQUEST_TYPE, url, header, NULL, json, NULL,
if (!http_perform(curl, GET_REQUEST_TYPE, url, header, NULL, NULL, json, NULL, NULL,
get_retry_timeout(sf), SF_BOOLEAN_FALSE, error,
sf->insecure_mode, sf->ocsp_fail_open,
sf->retry_on_curle_couldnt_connect_count,
Expand Down Expand Up @@ -906,16 +905,16 @@ ARRAY_LIST *json_get_object_keys(const cJSON *item) {
}

size_t
json_resp_cb(char *data, size_t size, size_t nmemb, RAW_JSON_BUFFER *raw_json) {
char_resp_cb(char *data, size_t size, size_t nmemb, RAW_CHAR_BUFFER *raw_buf) {
size_t data_size = size * nmemb;
log_debug("Curl response size: %zu", data_size);
raw_json->buffer = (char *) SF_REALLOC(raw_json->buffer,
raw_json->size + data_size + 1);
raw_buf->buffer = (char *) SF_REALLOC(raw_buf->buffer,
raw_buf->size + data_size + 1);
// Start copying where last null terminator existed
sf_memcpy(&raw_json->buffer[raw_json->size], data_size, data, data_size);
raw_json->size += data_size;
// Set null terminator
raw_json->buffer[raw_json->size] = '\0';
sf_memcpy(&raw_buf->buffer[raw_buf->size], data_size, data, data_size);
raw_buf->size += data_size;
// Set null raw_buf
raw_buf->buffer[raw_buf->size] = '\0';
return data_size;
}

Expand Down
Loading

0 comments on commit e3e1b88

Please sign in to comment.