SNOW-995369 GCS token refresh

sfc-gh-lsembera committed Jan 17, 2024
1 parent db1fd02 commit 98562a1
Showing 7 changed files with 150 additions and 33 deletions.
IngestTestUtils.java
@@ -13,6 +13,9 @@
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
+ import java.time.Duration;
+ import java.time.Instant;
+ import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
@@ -26,6 +29,8 @@
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory;
+ import org.slf4j.Logger;
+ import org.slf4j.LoggerFactory;

public class IngestTestUtils {
private static final String PROFILE_PATH = "profile.json";
@@ -38,6 +43,8 @@ public class IngestTestUtils {

private final String testId;

+ private static final Logger logger = LoggerFactory.getLogger(IngestTestUtils.class);

private final SnowflakeStreamingIngestClient client;

private final SnowflakeStreamingIngestChannel channel;
@@ -146,7 +153,7 @@ private void waitForOffset(SnowflakeStreamingIngestChannel channel, String expectedOffset)
expectedOffset, lastCommittedOffset));
}

- public void test() throws InterruptedException {
+ public void runBasicTest() throws InterruptedException {
// Insert few rows one by one
for (int offset = 2; offset < 1000; offset++) {
offset++;
@@ -161,6 +168,30 @@ public void test() throws InterruptedException {
waitForOffset(channel, offset);
}

+ public void runLongRunningTest(Duration testDuration) throws InterruptedException {
+   final Instant testStart = Instant.now();
+   int counter = 0;
+   while (true) {
+     counter++;
+
+     channel.insertRow(createRow(), String.valueOf(counter));
+
+     if (!channel.isValid()) {
+       throw new IllegalStateException("Channel has been invalidated");
+     }
+     Thread.sleep(60000);
+
+     final Duration elapsed = Duration.between(testStart, Instant.now());
+
+     logger.info(
+         "Test loop_nr={} duration={}s/{}s committed_offset={}",
+         counter,
+         elapsed.get(ChronoUnit.SECONDS),
+         testDuration.get(ChronoUnit.SECONDS),
+         channel.getLatestCommittedOffsetToken());
+
+     if (elapsed.compareTo(testDuration) > 0) {
+       break;
+     }
+   }
+   waitForOffset(channel, String.valueOf(counter));
+ }

public void close() throws Exception {
connection.close();
channel.close().get();
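The 80-minute duration used by the @Ignore'd tests below appears chosen to outlive a GCS OAuth access-token lifetime (commonly 3600 seconds), so at least one token refresh has to happen while rows keep flowing. A minimal sketch of that arithmetic, assuming the one-hour lifetime (the commit itself does not state it):

  import java.time.Duration;
  import java.time.temporal.ChronoUnit;

  class TokenLifetimeCheck {
    public static void main(String[] args) {
      // Assumed: GCS access tokens live for about one hour (3600 s).
      Duration tokenLifetime = Duration.ofHours(1);
      Duration testDuration = Duration.of(80, ChronoUnit.MINUTES);
      // 80 min > 60 min, so the long-running test must cross at least one token expiry.
      System.out.println("Test spans a token expiry: " + (testDuration.compareTo(tokenLifetime) > 0));
    }
  }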
FipsIngestE2ETest.java
@@ -3,9 +3,12 @@
import org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider;
import org.junit.After;
import org.junit.Before;
+ import org.junit.Ignore;
import org.junit.Test;

import java.security.Security;
+ import java.time.Duration;
+ import java.time.temporal.ChronoUnit;

public class FipsIngestE2ETest {

@@ -25,7 +28,13 @@ public void tearDown() throws Exception {
}

@Test
- public void name() throws InterruptedException {
-   ingestTestUtils.test();
+ public void basicTest() throws InterruptedException {
+   ingestTestUtils.runBasicTest();
}

+ @Test
+ @Ignore("Takes too long to run")
+ public void longRunningTest() throws InterruptedException {
+   ingestTestUtils.runLongRunningTest(Duration.of(80, ChronoUnit.MINUTES));
+ }
}
e2e-jar-test/pom.xml (3 changes: 2 additions & 1 deletion)
@@ -27,7 +27,8 @@
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>snowflake-ingest-sdk</artifactId>
- <version>2.0.4</version>
+ <!-- This value should be the same as the version in the pom.xml of the SDK -->
+ <version>2.0.5-SNAPSHOT</version>
</dependency>

<dependency>
StandardIngestE2ETest.java
@@ -2,8 +2,12 @@

import org.junit.After;
import org.junit.Before;
+ import org.junit.Ignore;
import org.junit.Test;

+ import java.time.Duration;
+ import java.time.temporal.ChronoUnit;

public class StandardIngestE2ETest {

private IngestTestUtils ingestTestUtils;
@@ -19,7 +23,13 @@ public void tearDown() throws Exception {
}

@Test
- public void name() throws InterruptedException {
-   ingestTestUtils.test();
+ public void basicTest() throws InterruptedException {
+   ingestTestUtils.runBasicTest();
}

+ @Test
+ @Ignore("Takes too long to run")
+ public void longRunningTest() throws InterruptedException {
+   ingestTestUtils.runLongRunningTest(Duration.of(80, ChronoUnit.MINUTES));
+ }
}
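One assumed way to exercise the long-running path locally (not described in this commit): temporarily remove the @Ignore annotation, then select the test through Maven Surefire, which otherwise reports @Ignore'd JUnit 4 tests as skipped:

  mvn -Dtest=StandardIngestE2ETest#longRunningTest test

Running the test from an IDE after removing the annotation works the same way.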
FlushService.java
@@ -59,6 +59,9 @@
*/
class FlushService<T> {

+ // The max number of upload retry attempts to the stage
+ private static final int DEFAULT_MAX_UPLOAD_RETRIES = 5;

// Static class to save the list of channels that are used to build a blob, which is mainly used
// to invalidate all the channels when there is a failure
static class BlobData<T> {
@@ -163,7 +166,8 @@ List<List<ChannelData<T>>> getData() {
client.getRole(),
client.getHttpClient(),
client.getRequestBuilder(),
-     client.getName());
+     client.getName(),
+     DEFAULT_MAX_UPLOAD_RETRIES);
} catch (SnowflakeSQLException | IOException err) {
throw new SFException(err, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE);
}
StreamingIngestStage.java
@@ -34,6 +34,7 @@
import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode;
import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper;
import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ObjectNode;
+ import net.snowflake.client.jdbc.internal.google.cloud.storage.StorageException;
import net.snowflake.ingest.connection.IngestResponseException;
import net.snowflake.ingest.connection.RequestBuilder;
import net.snowflake.ingest.utils.ErrorCode;
@@ -46,7 +47,6 @@ class StreamingIngestStage {
private static final ObjectMapper mapper = new ObjectMapper();
private static final long REFRESH_THRESHOLD_IN_MS =
TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES);
- static final int MAX_RETRY_COUNT = 1;

private static final Logging logger = new Logging(StreamingIngestStage.class);

@@ -86,6 +86,8 @@ state to record unknown age.
private final String clientName;
private String clientPrefix;

+ private final int maxUploadRetries;

// Proxy parameters that we set while calling the Snowflake JDBC to upload the streams
private final Properties proxyProperties;

@@ -94,13 +96,15 @@ state to record unknown age.
String role,
CloseableHttpClient httpClient,
RequestBuilder requestBuilder,
-     String clientName)
+     String clientName,
+     int maxUploadRetries)
throws SnowflakeSQLException, IOException {
this.httpClient = httpClient;
this.role = role;
this.requestBuilder = requestBuilder;
this.clientName = clientName;
this.proxyProperties = generateProxyPropertiesForJDBC();
+ this.maxUploadRetries = maxUploadRetries;

if (!isTestMode) {
refreshSnowflakeMetadata();
@@ -123,9 +127,10 @@ state to record unknown age.
CloseableHttpClient httpClient,
RequestBuilder requestBuilder,
String clientName,
-     SnowflakeFileTransferMetadataWithAge testMetadata)
+     SnowflakeFileTransferMetadataWithAge testMetadata,
+     int maxRetryCount)
throws SnowflakeSQLException, IOException {
- this(isTestMode, role, httpClient, requestBuilder, clientName);
+ this(isTestMode, role, httpClient, requestBuilder, clientName, maxRetryCount);
if (!isTestMode) {
throw new SFException(ErrorCode.INTERNAL_ERROR);
}
@@ -187,17 +192,49 @@ private void putRemote(String fullFilePath, byte[] data, int retryCount)
.setProxyProperties(this.proxyProperties)
.setDestFileName(fullFilePath)
.build());
- } catch (SnowflakeSQLException e) {
-   if (e.getErrorCode() != CLOUD_STORAGE_CREDENTIALS_EXPIRED || retryCount >= MAX_RETRY_COUNT) {
+ } catch (Exception e) {
+   if (retryCount >= maxUploadRetries) {
      logger.logError(
-         "Failed to upload to stage, client={}, message={}", clientName, e.getMessage());
-     throw e;
+         "Failed to upload to stage, retry attempts exhausted ({}), client={}, message={}",
+         maxUploadRetries,
+         clientName,
+         e.getMessage());
+     throw new SFException(e, ErrorCode.IO_ERROR);
    }
-   this.refreshSnowflakeMetadata();
-   this.putRemote(fullFilePath, data, ++retryCount);
- } catch (Exception e) {
-   throw new SFException(e, ErrorCode.IO_ERROR);

+   if (isCredentialsExpiredException(e)) {
+     logger.logInfo(
+         "Stage metadata needs to be refreshed due to upload error: {}", e.getMessage());
+     this.refreshSnowflakeMetadata();
+   }
+   retryCount++;
+   StreamingIngestUtils.sleepForRetry(retryCount);
+   logger.logInfo(
+       "Retrying upload, attempt {}/{} {}",
+       retryCount,
+       maxUploadRetries,
+       e.getMessage());
+   this.putRemote(fullFilePath, data, retryCount);
}
}

+ /**
+  * @return Whether the passed exception means that credentials expired and the stage metadata
+  *     should be refreshed from Snowflake. The reasons for a refresh are a SnowflakeSQLException
+  *     with error code 240001 (thrown by the JDBC driver) or a GCP StorageException with HTTP
+  *     status 401.
+  */
+ static boolean isCredentialsExpiredException(Exception e) {
+   if (e == null || e.getClass() == null) {
+     return false;
+   }
+
+   if (e instanceof SnowflakeSQLException) {
+     return ((SnowflakeSQLException) e).getErrorCode() == CLOUD_STORAGE_CREDENTIALS_EXPIRED;
+   } else if (e instanceof StorageException) {
+     return ((StorageException) e).getCode() == 401;
+   }
+
+   return false;
+ }

SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata()
@@ -399,7 +436,6 @@ void putLocal(String fullFilePath, byte[] data) {
String stageLocation = this.fileTransferMetadataWithAge.localLocation;
File destFile = Paths.get(stageLocation, fullFilePath).toFile();
FileUtils.copyInputStreamToFile(input, destFile);
- System.out.println("Filename: " + destFile); // TODO @rcheng - remove this before merge
} catch (Exception ex) {
throw new SFException(ex, ErrorCode.BLOB_UPLOAD_FAILURE);
}
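Taken together, putRemote and isCredentialsExpiredException above form a bounded retry loop that refreshes stage metadata only when a failure looks like expired credentials. The following is a hypothetical, self-contained sketch of that pattern, not code from this commit: upload, refreshMetadata, looksLikeExpiredCredentials, and the linear Thread.sleep are stand-ins for the real JDBC file-transfer upload, refreshSnowflakeMetadata(), isCredentialsExpiredException(), and StreamingIngestUtils.sleepForRetry(retryCount).

  import java.util.concurrent.atomic.AtomicInteger;

  public final class UploadRetrySketch {

    static final int MAX_UPLOAD_RETRIES = 5; // mirrors DEFAULT_MAX_UPLOAD_RETRIES above

    // Simulated stage: the first two uploads fail as if the storage token had expired.
    static final AtomicInteger attempts = new AtomicInteger();

    static void upload(byte[] data) throws Exception {
      if (attempts.incrementAndGet() <= 2) {
        throw new Exception("401 Unauthorized");
      }
      System.out.println("Upload succeeded on attempt " + attempts.get());
    }

    // Stand-in for isCredentialsExpiredException: the real check tests for
    // SnowflakeSQLException error code 240001 or a GCS StorageException with HTTP 401.
    static boolean looksLikeExpiredCredentials(Exception e) {
      return e.getMessage() != null && e.getMessage().startsWith("401");
    }

    // Stand-in for refreshSnowflakeMetadata(), which re-fetches stage credentials.
    static void refreshMetadata() {
      System.out.println("Refreshing stage metadata");
    }

    static void putWithRetries(byte[] data, int retryCount) throws Exception {
      try {
        upload(data);
      } catch (Exception e) {
        if (retryCount >= MAX_UPLOAD_RETRIES) {
          // The commit wraps and rethrows as SFException(ErrorCode.IO_ERROR) here.
          throw new IllegalStateException("Upload retries exhausted", e);
        }
        if (looksLikeExpiredCredentials(e)) {
          refreshMetadata(); // refresh only when the failure looks like expired credentials
        }
        retryCount++;
        Thread.sleep(100L * retryCount); // stand-in for StreamingIngestUtils.sleepForRetry
        System.out.println("Retrying upload, attempt " + retryCount + "/" + MAX_UPLOAD_RETRIES);
        putWithRetries(data, retryCount);
      }
    }

    public static void main(String[] args) throws Exception {
      putWithRetries(new byte[0], 0);
    }
  }

Compared to the old behavior, which retried at most once and only on JDBC error 240001, every upload failure is now retried up to the configured maximum with backoff, and the metadata refresh is reserved for failures that look like expired credentials.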
