From f9b5b804baaf5da2630affd640e84ee678e0bd89 Mon Sep 17 00:00:00 2001 From: Purujit Saha Date: Tue, 16 Jul 2024 11:16:57 -0700 Subject: [PATCH] Reject new stage metadata if the deployment id does not match what the client was created with (#794) * Reject new stage metadata if the deployment id does not match what the client was created with * Review comments --- .../internal/StreamingIngestStage.java | 22 ++++++- .../net/snowflake/ingest/utils/ErrorCode.java | 3 +- .../ingest/ingest_error_messages.properties | 5 +- .../internal/StreamingIngestStageTest.java | 62 +++++++++++++++++++ 4 files changed, 86 insertions(+), 6 deletions(-) diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java index ed73e3774..5556b7205 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java @@ -25,6 +25,7 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; import java.util.function.Function; +import javax.annotation.Nullable; import net.snowflake.client.core.OCSPMode; import net.snowflake.client.jdbc.SnowflakeFileTransferAgent; import net.snowflake.client.jdbc.SnowflakeFileTransferConfig; @@ -91,6 +92,7 @@ state to record unknown age. private final String role; private final String clientName; private String clientPrefix; + private Long deploymentId; private final int maxUploadRetries; @@ -258,9 +260,11 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole payload.put("role", this.role); Map response = this.makeClientConfigureCall(payload); - JsonNode responseNode = this.parseClientConfigureResponse(response); + JsonNode responseNode = this.parseClientConfigureResponse(response, this.deploymentId); // Do not change the prefix everytime we have to refresh credentials if (Utils.isNullOrEmpty(this.clientPrefix)) { + this.deploymentId = + responseNode.has("deployment_id") ? responseNode.get("deployment_id").longValue() : null; this.clientPrefix = createClientPrefix(responseNode); } Utils.assertStringNotNullOrEmpty("client prefix", this.clientPrefix); @@ -326,7 +330,7 @@ SnowflakeFileTransferMetadataV1 fetchSignedURL(String fileName) payload.put("file_name", fileName); Map response = this.makeClientConfigureCall(payload); - JsonNode responseNode = this.parseClientConfigureResponse(response); + JsonNode responseNode = this.parseClientConfigureResponse(response, this.deploymentId); SnowflakeFileTransferMetadataV1 metadata = (SnowflakeFileTransferMetadataV1) @@ -350,7 +354,8 @@ public Long apply(T input) { private static final MapStatusGetter statusGetter = new MapStatusGetter(); - private JsonNode parseClientConfigureResponse(Map response) { + private JsonNode parseClientConfigureResponse( + Map response, @Nullable Long expectedDeploymentId) { JsonNode responseNode = mapper.valueToTree(response); // Currently there are a few mismatches between the client/configure response and what @@ -362,6 +367,17 @@ private JsonNode parseClientConfigureResponse(Map response) { // JDBC expects this field which maps to presignedFileUrlName. We will set this later dataNode.putArray("src_locations").add("placeholder"); + if (expectedDeploymentId != null) { + Long actualDeploymentId = + responseNode.has("deployment_id") ? responseNode.get("deployment_id").longValue() : null; + if (actualDeploymentId != null && !actualDeploymentId.equals(expectedDeploymentId)) { + throw new SFException( + ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH, + expectedDeploymentId, + actualDeploymentId, + clientName); + } + } return responseNode; } diff --git a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java index b863717e9..fdd8a713b 100644 --- a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java +++ b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java @@ -41,7 +41,8 @@ public enum ErrorCode { OAUTH_REFRESH_TOKEN_ERROR("0033"), INVALID_CONFIG_PARAMETER("0034"), CRYPTO_PROVIDER_ERROR("0035"), - DROP_CHANNEL_FAILURE("0036"); + DROP_CHANNEL_FAILURE("0036"), + CLIENT_DEPLOYMENT_ID_MISMATCH("0037"); public static final String errorMessageResource = "net.snowflake.ingest.ingest_error_messages"; diff --git a/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties b/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties index d2fea0b0d..03e50d9b6 100644 --- a/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties +++ b/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties @@ -37,5 +37,6 @@ 0032=URI builder fail to build url: {0} 0033=OAuth token refresh failure: {0} 0034=Invalid config parameter: {0} -0035=Too large batch of rows passed to insertRows, the batch size cannot exceed {0} bytes, recommended batch size for optimal performance and memory utilization is {1} bytes. We recommend splitting large batches into multiple smaller ones and call insertRows for each smaller batch separately. -0036=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1} \ No newline at end of file +0035=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1} +0036=Failed to drop channel: {0} +0037=Deployment ID mismatch, Client was created on: {0}, Got upload location for: {1}. Please restart client: {2}. \ No newline at end of file diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java index 1ba9f98df..2458c5bd5 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java @@ -43,6 +43,7 @@ import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.utils.Constants; +import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.ParameterProvider; import net.snowflake.ingest.utils.SFException; import org.junit.Assert; @@ -99,6 +100,21 @@ public class StreamingIngestStageTest { + " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":" + " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null," + " \"endPoint\": null}}"; + String remoteMetaResponseDifferentDeployment = + "{\"src_locations\": [\"foo/\"]," + + " \"deployment_id\": " + + (deploymentId + 1) + + "," + + " \"status_code\": 0, \"message\": \"Success\", \"prefix\":" + + " \"" + + prefix + + "\", \"stage_location\": {\"locationType\": \"S3\", \"location\":" + + " \"foo/streaming_ingest/\", \"path\": \"streaming_ingest/\", \"region\":" + + " \"us-east-1\", \"storageAccount\": null, \"isClientSideEncrypted\": true," + + " \"creds\": {\"AWS_KEY_ID\": \"EXAMPLE_AWS_KEY_ID\", \"AWS_SECRET_KEY\":" + + " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":" + + " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null," + + " \"endPoint\": null}}"; private void setupMocksForRefresh() throws Exception { PowerMockito.mockStatic(HttpUtil.class); @@ -302,6 +318,52 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix()); } + @Test + public void testRefreshSnowflakeMetadataDeploymentIdMismatch() throws Exception { + RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); + CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); + CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + StatusLine mockStatusLine = Mockito.mock(StatusLine.class); + Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); + Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); + + BasicHttpEntity entity = new BasicHttpEntity(); + entity.setContent( + new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); + + BasicHttpEntity entityFromDifferentDeployment = new BasicHttpEntity(); + entityFromDifferentDeployment.setContent( + new ByteArrayInputStream( + remoteMetaResponseDifferentDeployment.getBytes(StandardCharsets.UTF_8))); + Mockito.when(mockResponse.getEntity()) + .thenReturn(entity) + .thenReturn(entityFromDifferentDeployment); + Mockito.when(mockClient.execute(Mockito.any())) + .thenReturn(mockResponse) + .thenReturn(mockResponse); + + StreamingIngestStage stage = + new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); + + StreamingIngestStage.SnowflakeFileTransferMetadataWithAge metadataWithAge = + stage.refreshSnowflakeMetadata(true); + + Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix()); + + SFException exception = + Assert.assertThrows(SFException.class, () -> stage.refreshSnowflakeMetadata(true)); + Assert.assertEquals( + ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH.getMessageCode(), exception.getVendorCode()); + Assert.assertEquals( + "Deployment ID mismatch, Client was created on: " + + deploymentId + + ", Got upload location for: " + + (deploymentId + 1) + + ". Please" + + " restart client: clientName.", + exception.getMessage()); + } + @Test public void testFetchSignedURL() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class);