Skip to content

Commit

Permalink
Reject new stage metadata if the deployment id does not match what th…
Browse files Browse the repository at this point in the history
…e client was created with (snowflakedb#794)

* Reject new stage metadata if the deployment id does not match what the client was created with

* Review comments
  • Loading branch information
sfc-gh-psaha authored and sfc-gh-kgaputis committed Sep 12, 2024
1 parent 252d4c5 commit f9b5b80
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import javax.annotation.Nullable;
import net.snowflake.client.core.OCSPMode;
import net.snowflake.client.jdbc.SnowflakeFileTransferAgent;
import net.snowflake.client.jdbc.SnowflakeFileTransferConfig;
Expand Down Expand Up @@ -91,6 +92,7 @@ state to record unknown age.
private final String role;
private final String clientName;
private String clientPrefix;
private Long deploymentId;

private final int maxUploadRetries;

Expand Down Expand Up @@ -258,9 +260,11 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole
payload.put("role", this.role);
Map<String, Object> response = this.makeClientConfigureCall(payload);

JsonNode responseNode = this.parseClientConfigureResponse(response);
JsonNode responseNode = this.parseClientConfigureResponse(response, this.deploymentId);
// Do not change the prefix everytime we have to refresh credentials
if (Utils.isNullOrEmpty(this.clientPrefix)) {
this.deploymentId =
responseNode.has("deployment_id") ? responseNode.get("deployment_id").longValue() : null;
this.clientPrefix = createClientPrefix(responseNode);
}
Utils.assertStringNotNullOrEmpty("client prefix", this.clientPrefix);
Expand Down Expand Up @@ -326,7 +330,7 @@ SnowflakeFileTransferMetadataV1 fetchSignedURL(String fileName)
payload.put("file_name", fileName);
Map<String, Object> response = this.makeClientConfigureCall(payload);

JsonNode responseNode = this.parseClientConfigureResponse(response);
JsonNode responseNode = this.parseClientConfigureResponse(response, this.deploymentId);

SnowflakeFileTransferMetadataV1 metadata =
(SnowflakeFileTransferMetadataV1)
Expand All @@ -350,7 +354,8 @@ public Long apply(T input) {

private static final MapStatusGetter statusGetter = new MapStatusGetter();

private JsonNode parseClientConfigureResponse(Map<String, Object> response) {
private JsonNode parseClientConfigureResponse(
Map<String, Object> response, @Nullable Long expectedDeploymentId) {
JsonNode responseNode = mapper.valueToTree(response);

// Currently there are a few mismatches between the client/configure response and what
Expand All @@ -362,6 +367,17 @@ private JsonNode parseClientConfigureResponse(Map<String, Object> response) {

// JDBC expects this field which maps to presignedFileUrlName. We will set this later
dataNode.putArray("src_locations").add("placeholder");
if (expectedDeploymentId != null) {
Long actualDeploymentId =
responseNode.has("deployment_id") ? responseNode.get("deployment_id").longValue() : null;
if (actualDeploymentId != null && !actualDeploymentId.equals(expectedDeploymentId)) {
throw new SFException(
ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH,
expectedDeploymentId,
actualDeploymentId,
clientName);
}
}
return responseNode;
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/java/net/snowflake/ingest/utils/ErrorCode.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ public enum ErrorCode {
OAUTH_REFRESH_TOKEN_ERROR("0033"),
INVALID_CONFIG_PARAMETER("0034"),
CRYPTO_PROVIDER_ERROR("0035"),
DROP_CHANNEL_FAILURE("0036");
DROP_CHANNEL_FAILURE("0036"),
CLIENT_DEPLOYMENT_ID_MISMATCH("0037");

public static final String errorMessageResource = "net.snowflake.ingest.ingest_error_messages";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@
0032=URI builder fail to build url: {0}
0033=OAuth token refresh failure: {0}
0034=Invalid config parameter: {0}
0035=Too large batch of rows passed to insertRows, the batch size cannot exceed {0} bytes, recommended batch size for optimal performance and memory utilization is {1} bytes. We recommend splitting large batches into multiple smaller ones and call insertRows for each smaller batch separately.
0036=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1}
0035=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1}
0036=Failed to drop channel: {0}
0037=Deployment ID mismatch, Client was created on: {0}, Got upload location for: {1}. Please restart client: {2}.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import net.snowflake.ingest.TestUtils;
import net.snowflake.ingest.connection.RequestBuilder;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.ParameterProvider;
import net.snowflake.ingest.utils.SFException;
import org.junit.Assert;
Expand Down Expand Up @@ -99,6 +100,21 @@ public class StreamingIngestStageTest {
+ " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":"
+ " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null,"
+ " \"endPoint\": null}}";
String remoteMetaResponseDifferentDeployment =
"{\"src_locations\": [\"foo/\"],"
+ " \"deployment_id\": "
+ (deploymentId + 1)
+ ","
+ " \"status_code\": 0, \"message\": \"Success\", \"prefix\":"
+ " \""
+ prefix
+ "\", \"stage_location\": {\"locationType\": \"S3\", \"location\":"
+ " \"foo/streaming_ingest/\", \"path\": \"streaming_ingest/\", \"region\":"
+ " \"us-east-1\", \"storageAccount\": null, \"isClientSideEncrypted\": true,"
+ " \"creds\": {\"AWS_KEY_ID\": \"EXAMPLE_AWS_KEY_ID\", \"AWS_SECRET_KEY\":"
+ " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":"
+ " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null,"
+ " \"endPoint\": null}}";

private void setupMocksForRefresh() throws Exception {
PowerMockito.mockStatic(HttpUtil.class);
Expand Down Expand Up @@ -302,6 +318,52 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception {
Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix());
}

@Test
public void testRefreshSnowflakeMetadataDeploymentIdMismatch() throws Exception {
RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class);
CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);
CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class);
StatusLine mockStatusLine = Mockito.mock(StatusLine.class);
Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200);
Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);

BasicHttpEntity entity = new BasicHttpEntity();
entity.setContent(
new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8)));

BasicHttpEntity entityFromDifferentDeployment = new BasicHttpEntity();
entityFromDifferentDeployment.setContent(
new ByteArrayInputStream(
remoteMetaResponseDifferentDeployment.getBytes(StandardCharsets.UTF_8)));
Mockito.when(mockResponse.getEntity())
.thenReturn(entity)
.thenReturn(entityFromDifferentDeployment);
Mockito.when(mockClient.execute(Mockito.any()))
.thenReturn(mockResponse)
.thenReturn(mockResponse);

StreamingIngestStage stage =
new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1);

StreamingIngestStage.SnowflakeFileTransferMetadataWithAge metadataWithAge =
stage.refreshSnowflakeMetadata(true);

Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix());

SFException exception =
Assert.assertThrows(SFException.class, () -> stage.refreshSnowflakeMetadata(true));
Assert.assertEquals(
ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH.getMessageCode(), exception.getVendorCode());
Assert.assertEquals(
"Deployment ID mismatch, Client was created on: "
+ deploymentId
+ ", Got upload location for: "
+ (deploymentId + 1)
+ ". Please"
+ " restart client: clientName.",
exception.getMessage());
}

@Test
public void testFetchSignedURL() throws Exception {
RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class);
Expand Down

0 comments on commit f9b5b80

Please sign in to comment.