From 98562a155dd056b1e43b09baf163999bb6d9e4df Mon Sep 17 00:00:00 2001 From: Lukas Sembera Date: Sat, 13 Jan 2024 16:32:00 +0000 Subject: [PATCH] SNOW-995369 GCS token refresh --- .../java/net/snowflake/IngestTestUtils.java | 33 +++++++++- .../java/net/snowflake/FipsIngestE2ETest.java | 13 +++- e2e-jar-test/pom.xml | 3 +- .../net/snowflake/StandardIngestE2ETest.java | 14 ++++- .../streaming/internal/FlushService.java | 6 +- .../internal/StreamingIngestStage.java | 62 +++++++++++++++---- .../internal/StreamingIngestStageTest.java | 52 ++++++++++++---- 7 files changed, 150 insertions(+), 33 deletions(-) diff --git a/e2e-jar-test/core/src/main/java/net/snowflake/IngestTestUtils.java b/e2e-jar-test/core/src/main/java/net/snowflake/IngestTestUtils.java index 100972ea0..e7db8d1c2 100644 --- a/e2e-jar-test/core/src/main/java/net/snowflake/IngestTestUtils.java +++ b/e2e-jar-test/core/src/main/java/net/snowflake/IngestTestUtils.java @@ -13,6 +13,9 @@ import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Base64; import java.util.HashMap; @@ -26,6 +29,8 @@ import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class IngestTestUtils { private static final String PROFILE_PATH = "profile.json"; @@ -38,6 +43,8 @@ public class IngestTestUtils { private final String testId; + private static final Logger logger = LoggerFactory.getLogger(IngestTestUtils.class); + private final SnowflakeStreamingIngestClient client; private final SnowflakeStreamingIngestChannel channel; @@ -146,7 +153,7 @@ private void waitForOffset(SnowflakeStreamingIngestChannel channel, String expec expectedOffset, 
lastCommittedOffset)); } - public void test() throws InterruptedException { + public void runBasicTest() throws InterruptedException { // Insert few rows one by one for (int offset = 2; offset < 1000; offset++) { offset++; @@ -161,6 +168,30 @@ public void test() throws InterruptedException { waitForOffset(channel, offset); } + public void runLongRunningTest(Duration testDuration) throws InterruptedException { + final Instant testStart = Instant.now(); + int counter = 0; + while(true) { + counter++; + + channel.insertRow(createRow(), String.valueOf(counter)); + + if (!channel.isValid()) { + throw new IllegalStateException("Channel has been invalidated"); + } + Thread.sleep(60000); + + final Duration elapsed = Duration.between(testStart, Instant.now()); + + logger.info("Test loop_nr={} duration={}s/{}s committed_offset={}", counter, elapsed.get(ChronoUnit.SECONDS), testDuration.get(ChronoUnit.SECONDS), channel.getLatestCommittedOffsetToken()); + + if (elapsed.compareTo(testDuration) > 0) { + break; + } + } + waitForOffset(channel, String.valueOf(counter)); + } + public void close() throws Exception { connection.close(); channel.close().get(); diff --git a/e2e-jar-test/fips/src/test/java/net/snowflake/FipsIngestE2ETest.java b/e2e-jar-test/fips/src/test/java/net/snowflake/FipsIngestE2ETest.java index 7279f23ff..c6f9bfe33 100644 --- a/e2e-jar-test/fips/src/test/java/net/snowflake/FipsIngestE2ETest.java +++ b/e2e-jar-test/fips/src/test/java/net/snowflake/FipsIngestE2ETest.java @@ -3,9 +3,12 @@ import org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import java.security.Security; +import java.time.Duration; +import java.time.temporal.ChronoUnit; public class FipsIngestE2ETest { @@ -25,7 +28,13 @@ public void tearDown() throws Exception { } @Test - public void name() throws InterruptedException { - ingestTestUtils.test(); + public void basicTest() throws 
InterruptedException { + ingestTestUtils.runBasicTest(); + } + + @Test + @Ignore("Takes too long to run") + public void longRunningTest() throws InterruptedException { + ingestTestUtils.runLongRunningTest(Duration.of(80, ChronoUnit.MINUTES)); } } diff --git a/e2e-jar-test/pom.xml b/e2e-jar-test/pom.xml index 4c236b483..e4ba3635b 100644 --- a/e2e-jar-test/pom.xml +++ b/e2e-jar-test/pom.xml @@ -27,7 +27,8 @@ net.snowflake snowflake-ingest-sdk - 2.0.4 + + 2.0.5-SNAPSHOT diff --git a/e2e-jar-test/standard/src/test/java/net/snowflake/StandardIngestE2ETest.java b/e2e-jar-test/standard/src/test/java/net/snowflake/StandardIngestE2ETest.java index 255577655..211c421fc 100644 --- a/e2e-jar-test/standard/src/test/java/net/snowflake/StandardIngestE2ETest.java +++ b/e2e-jar-test/standard/src/test/java/net/snowflake/StandardIngestE2ETest.java @@ -2,8 +2,12 @@ import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; +import java.time.Duration; +import java.time.temporal.ChronoUnit; + public class StandardIngestE2ETest { private IngestTestUtils ingestTestUtils; @@ -19,7 +23,13 @@ public void tearDown() throws Exception { } @Test - public void name() throws InterruptedException { - ingestTestUtils.test(); + public void basicTest() throws InterruptedException { + ingestTestUtils.runBasicTest(); + } + + @Test + @Ignore("Takes too long to run") + public void longRunningTest() throws InterruptedException { + ingestTestUtils.runLongRunningTest(Duration.of(80, ChronoUnit.MINUTES)); } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index 0e6998bdc..955b05d6f 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -59,6 +59,9 @@ */ class FlushService { + // The max number of upload retry attempts to the stage + private static 
final int DEFAULT_MAX_UPLOAD_RETRIES = 5; + // Static class to save the list of channels that are used to build a blob, which is mainly used // to invalidate all the channels when there is a failure static class BlobData { @@ -163,7 +166,8 @@ List>> getData() { client.getRole(), client.getHttpClient(), client.getRequestBuilder(), - client.getName()); + client.getName(), + DEFAULT_MAX_UPLOAD_RETRIES); } catch (SnowflakeSQLException | IOException err) { throw new SFException(err, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java index 0d7e3f211..15d90f33f 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java @@ -34,6 +34,7 @@ import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ObjectNode; +import net.snowflake.client.jdbc.internal.google.cloud.storage.StorageException; import net.snowflake.ingest.connection.IngestResponseException; import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.utils.ErrorCode; @@ -46,7 +47,6 @@ class StreamingIngestStage { private static final ObjectMapper mapper = new ObjectMapper(); private static final long REFRESH_THRESHOLD_IN_MS = TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES); - static final int MAX_RETRY_COUNT = 1; private static final Logging logger = new Logging(StreamingIngestStage.class); @@ -86,6 +86,8 @@ state to record unknown age. 
private final String clientName; private String clientPrefix; + private final int maxUploadRetries; + // Proxy parameters that we set while calling the Snowflake JDBC to upload the streams private final Properties proxyProperties; @@ -94,13 +96,15 @@ state to record unknown age. String role, CloseableHttpClient httpClient, RequestBuilder requestBuilder, - String clientName) + String clientName, + int maxUploadRetries) throws SnowflakeSQLException, IOException { this.httpClient = httpClient; this.role = role; this.requestBuilder = requestBuilder; this.clientName = clientName; this.proxyProperties = generateProxyPropertiesForJDBC(); + this.maxUploadRetries = maxUploadRetries; if (!isTestMode) { refreshSnowflakeMetadata(); @@ -123,9 +127,10 @@ state to record unknown age. CloseableHttpClient httpClient, RequestBuilder requestBuilder, String clientName, - SnowflakeFileTransferMetadataWithAge testMetadata) + SnowflakeFileTransferMetadataWithAge testMetadata, + int maxRetryCount) throws SnowflakeSQLException, IOException { - this(isTestMode, role, httpClient, requestBuilder, clientName); + this(isTestMode, role, httpClient, requestBuilder, clientName, maxRetryCount); if (!isTestMode) { throw new SFException(ErrorCode.INTERNAL_ERROR); } @@ -187,17 +192,49 @@ private void putRemote(String fullFilePath, byte[] data, int retryCount) .setProxyProperties(this.proxyProperties) .setDestFileName(fullFilePath) .build()); - } catch (SnowflakeSQLException e) { - if (e.getErrorCode() != CLOUD_STORAGE_CREDENTIALS_EXPIRED || retryCount >= MAX_RETRY_COUNT) { + } catch (Exception e) { + if (retryCount >= maxUploadRetries) { logger.logError( - "Failed to upload to stage, client={}, message={}", clientName, e.getMessage()); - throw e; + "Failed to upload to stage, retry attempts exhausted ({}), client={}, message={}", + maxUploadRetries, + clientName, + e.getMessage()); + throw new SFException(e, ErrorCode.IO_ERROR); } - this.refreshSnowflakeMetadata(); - this.putRemote(fullFilePath, data, 
++retryCount); - } catch (Exception e) { - throw new SFException(e, ErrorCode.IO_ERROR); + + if (isCredentialsExpiredException(e)) { + logger.logInfo( + "Stage metadata need to be refreshed due to upload error: {}", e.getMessage()); + this.refreshSnowflakeMetadata(); + } + retryCount++; + StreamingIngestUtils.sleepForRetry(retryCount); + logger.logInfo( + "Retrying upload, attempt {}/{} {}", + retryCount, + maxUploadRetries, + e.getMessage()); + this.putRemote(fullFilePath, data, retryCount); + } + } + + /** + * @return Whether the passed exception means that credentials expired and the stage metadata + * should be refreshed from Snowflake. The reasons for refresh are SnowflakeSQLException with + * error code 240001 (thrown by the JDBC driver) or GCP StorageException with HTTP status 401. + */ + static boolean isCredentialsExpiredException(Exception e) { + if (e == null) { + return false; } + + if (e instanceof SnowflakeSQLException) { + return ((SnowflakeSQLException) e).getErrorCode() == CLOUD_STORAGE_CREDENTIALS_EXPIRED; + } else if (e instanceof StorageException) { + return ((StorageException) e).getCode() == 401; + } + + return false; } SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata() @@ -399,7 +436,6 @@ void putLocal(String fullFilePath, byte[] data) { String stageLocation = this.fileTransferMetadataWithAge.localLocation; File destFile = Paths.get(stageLocation, fullFilePath).toFile(); FileUtils.copyInputStreamToFile(input, destFile); - System.out.println("Filename: " + destFile); // TODO @rcheng - remove this before merge } catch (Exception ex) { throw new SFException(ex, ErrorCode.BLOB_UPLOAD_FAILURE); } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java index a137ab9ed..a14f1c46b 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java +++ 
b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java @@ -1,6 +1,7 @@ package net.snowflake.ingest.streaming.internal; import static net.snowflake.client.core.Constants.CLOUD_STORAGE_CREDENTIALS_EXPIRED; +import static net.snowflake.ingest.streaming.internal.StreamingIngestStage.isCredentialsExpiredException; import static net.snowflake.ingest.utils.HttpUtil.HTTP_PROXY_PASSWORD; import static net.snowflake.ingest.utils.HttpUtil.HTTP_PROXY_USER; import static net.snowflake.ingest.utils.HttpUtil.NON_PROXY_HOSTS; @@ -39,11 +40,13 @@ import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; +import net.snowflake.client.jdbc.internal.google.cloud.storage.StorageException; import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ParameterProvider; +import net.snowflake.ingest.utils.SFException; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -121,7 +124,8 @@ public void testPutRemote() throws Exception { null, "clientName", new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - originalMetadata, Optional.of(System.currentTimeMillis()))); + originalMetadata, Optional.of(System.currentTimeMillis())), + 1); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); final ArgumentCaptor captor = @@ -163,7 +167,8 @@ public void testPutLocal() throws Exception { null, "clientName", new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - fullFilePath, Optional.of(System.currentTimeMillis())))); + fullFilePath, Optional.of(System.currentTimeMillis())), + 1)); 
Mockito.doReturn(true).when(stage).isLocalFS(); stage.put(fileName, dataBytes); @@ -174,7 +179,8 @@ public void testPutLocal() throws Exception { } @Test - public void testPutRemoteRefreshes() throws Exception { + public void doTestPutRemoteRefreshes() throws Exception { + int maxUploadRetryCount = 2; JsonNode exampleJson = mapper.readTree(exampleRemoteMeta); SnowflakeFileTransferMetadataV1 originalMetadata = (SnowflakeFileTransferMetadataV1) @@ -190,7 +196,8 @@ public void testPutRemoteRefreshes() throws Exception { null, "clientName", new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - originalMetadata, Optional.of(System.currentTimeMillis()))); + originalMetadata, Optional.of(System.currentTimeMillis())), + maxUploadRetryCount); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeSQLException e = new SnowflakeSQLException( @@ -202,16 +209,15 @@ public void testPutRemoteRefreshes() throws Exception { try { stage.putRemote("test/path", dataBytes); - Assert.assertTrue(false); - } catch (SnowflakeSQLException ex) { + Assert.fail("Should not succeed"); + } catch (SFException ex) { // Expected behavior given mocked response } - PowerMockito.verifyStatic( - SnowflakeFileTransferAgent.class, times(StreamingIngestStage.MAX_RETRY_COUNT + 1)); + PowerMockito.verifyStatic(SnowflakeFileTransferAgent.class, times(maxUploadRetryCount + 1)); SnowflakeFileTransferAgent.uploadWithoutConnection(captor.capture()); SnowflakeFileTransferConfig capturedConfig = captor.getValue(); - Assert.assertEquals(false, capturedConfig.getRequireCompress()); + Assert.assertFalse(capturedConfig.getRequireCompress()); Assert.assertEquals(OCSPMode.FAIL_OPEN, capturedConfig.getOcspMode()); SnowflakeFileTransferMetadataV1 capturedMetadata = @@ -245,7 +251,8 @@ public void testPutRemoteGCS() throws Exception { null, "clientName", new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - originalMetadata, Optional.of(System.currentTimeMillis())))); + originalMetadata, 
Optional.of(System.currentTimeMillis())), + 1)); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeFileTransferMetadataV1 metaMock = Mockito.mock(SnowflakeFileTransferMetadataV1.class); @@ -273,7 +280,7 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { ParameterProvider parameterProvider = new ParameterProvider(); StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName"); + new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); StreamingIngestStage.SnowflakeFileTransferMetadataWithAge metadataWithAge = stage.refreshSnowflakeMetadata(true); @@ -314,7 +321,7 @@ public void testFetchSignedURL() throws Exception { Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName"); + new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); SnowflakeFileTransferMetadataV1 metadata = stage.fetchSignedURL("path/fileName"); @@ -359,7 +366,8 @@ public void testRefreshSnowflakeMetadataSynchronized() throws Exception { mockBuilder, "clientName", new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - originalMetadata, Optional.of(0L))); + originalMetadata, Optional.of(0L)), + 1); ThreadFactory buildUploadThreadFactory = new ThreadFactoryBuilder().setNameFormat("ingest-build-upload-thread-%d").build(); @@ -476,4 +484,22 @@ public void testShouldBypassProxy() { System.setProperty(NON_PROXY_HOSTS, oldNonProxyHosts); } } + + @Test + public void testIsCredentialExpiredException() { + Assert.assertTrue( + isCredentialsExpiredException( + new SnowflakeSQLException("Error", CLOUD_STORAGE_CREDENTIALS_EXPIRED))); + Assert.assertTrue(isCredentialsExpiredException(new StorageException(401, "unauthorized"))); + + Assert.assertFalse(isCredentialsExpiredException(new StorageException(400, "bad request"))); + 
Assert.assertFalse(isCredentialsExpiredException(null)); + Assert.assertFalse(isCredentialsExpiredException(new RuntimeException())); + Assert.assertFalse( + isCredentialsExpiredException( + new RuntimeException(String.valueOf(CLOUD_STORAGE_CREDENTIALS_EXPIRED)))); + Assert.assertFalse( + isCredentialsExpiredException( + new SnowflakeSQLException("Error", CLOUD_STORAGE_CREDENTIALS_EXPIRED + 1))); + } }