Skip to content

Commit

Permalink
mrege in master
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alhuang committed Jul 16, 2024
2 parents 73edfb4 + 7310921 commit 4507637
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ class InternalStageManager<T> implements StorageManager<T, InternalStageLocation
/** Snowflake service client used for configure calls */
private final SnowflakeServiceClient snowflakeServiceClient;

/** The name of the client */
private final String clientName;

/** The role of the client */
private final String role;

/** Client prefix generated by the Snowflake server */
private final String clientPrefix;
private String clientPrefix;

/** Deployment ID generated by the Snowflake server */
private Long deploymentId;

/**
* Constructor for InternalStageManager
Expand All @@ -58,13 +64,15 @@ class InternalStageManager<T> implements StorageManager<T, InternalStageLocation
SnowflakeServiceClient snowflakeServiceClient) {
this.snowflakeServiceClient = snowflakeServiceClient;
this.isTestMode = isTestMode;
this.clientName = clientName;
this.role = role;
this.counter = new AtomicLong(0);
try {
if (!isTestMode) {
ClientConfigureResponse response =
this.snowflakeServiceClient.clientConfigure(new ClientConfigureRequest(role));
this.clientPrefix = response.getClientPrefix();
this.deploymentId = response.getDeploymentId();
this.targetStage =
new StreamingIngestStorage<T, InternalStageLocation>(
this,
Expand All @@ -73,7 +81,8 @@ class InternalStageManager<T> implements StorageManager<T, InternalStageLocation
new InternalStageLocation(),
DEFAULT_MAX_UPLOAD_RETRIES);
} else {
this.clientPrefix = "testPrefix";
this.clientPrefix = null;
this.deploymentId = null;
this.targetStage =
new StreamingIngestStorage<T, InternalStageLocation>(
this,
Expand Down Expand Up @@ -124,6 +133,17 @@ public FileLocationInfo getRefreshedLocation(
ClientConfigureRequest request = new ClientConfigureRequest(this.role);
fileName.ifPresent(request::setFileName);
ClientConfigureResponse response = snowflakeServiceClient.clientConfigure(request);
if (this.clientPrefix == null) {
this.clientPrefix = response.getClientPrefix();
this.deploymentId = response.getDeploymentId();
}
if (this.deploymentId != null && !this.deploymentId.equals(response.getDeploymentId())) {
throw new SFException(
ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH,
this.deploymentId,
response.getDeploymentId(),
this.clientName);
}
return response.getStageLocation();
} catch (IngestResponseException | IOException e) {
throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import javax.annotation.Nullable;
import net.snowflake.client.core.OCSPMode;
import net.snowflake.client.jdbc.SnowflakeFileTransferAgent;
import net.snowflake.client.jdbc.SnowflakeFileTransferConfig;
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/net/snowflake/ingest/utils/ErrorCode.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ public enum ErrorCode {
OAUTH_REFRESH_TOKEN_ERROR("0033"),
INVALID_CONFIG_PARAMETER("0034"),
CRYPTO_PROVIDER_ERROR("0035"),
DROP_CHANNEL_FAILURE("0036");
DROP_CHANNEL_FAILURE("0036"),
CLIENT_DEPLOYMENT_ID_MISMATCH("0037");

public static final String errorMessageResource = "net.snowflake.ingest.ingest_error_messages";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@
0032=URI builder fail to build url: {0}
0033=OAuth token refresh failure: {0}
0034=Invalid config parameter: {0}
0035=Too large batch of rows passed to insertRows, the batch size cannot exceed {0} bytes, recommended batch size for optimal performance and memory utilization is {1} bytes. We recommend splitting large batches into multiple smaller ones and call insertRows for each smaller batch separately.
0036=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1}
0035=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1}
0036=Failed to drop channel: {0}
0037=Deployment ID mismatch, Client was created on: {0}, Got upload location for: {1}. Please restart client: {2}.
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.streaming.internal;

import static net.snowflake.client.core.Constants.CLOUD_STORAGE_CREDENTIALS_EXPIRED;
Expand Down Expand Up @@ -48,6 +44,7 @@
import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder;
import net.snowflake.ingest.TestUtils;
import net.snowflake.ingest.connection.RequestBuilder;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.SFException;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -103,6 +100,21 @@ public class StreamingIngestStorageTest {
+ " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":"
+ " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null,"
+ " \"endPoint\": null}}";
String remoteMetaResponseDifferentDeployment =
"{\"src_locations\": [\"foo/\"],"
+ " \"deployment_id\": "
+ (deploymentId + 1)
+ ","
+ " \"status_code\": 0, \"message\": \"Success\", \"prefix\":"
+ " \""
+ prefix
+ "\", \"stage_location\": {\"locationType\": \"S3\", \"location\":"
+ " \"foo/streaming_ingest/\", \"path\": \"streaming_ingest/\", \"region\":"
+ " \"us-east-1\", \"storageAccount\": null, \"isClientSideEncrypted\": true,"
+ " \"creds\": {\"AWS_KEY_ID\": \"EXAMPLE_AWS_KEY_ID\", \"AWS_SECRET_KEY\":"
+ " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":"
+ " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null,"
+ " \"endPoint\": null}}";

private void setupMocksForRefresh() throws Exception {
PowerMockito.mockStatic(HttpUtil.class);
Expand Down Expand Up @@ -193,7 +205,7 @@ public void doTestPutRemoteRefreshes() throws Exception {
Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix");

StreamingIngestStorage<?, ?> stage =
new StreamingIngestStorage(
new StreamingIngestStorage<>(
storageManager,
"clientName",
new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge(
Expand Down Expand Up @@ -249,7 +261,7 @@ public void testPutRemoteGCS() throws Exception {

StreamingIngestStorage<?, ?> stage =
Mockito.spy(
new StreamingIngestStorage(
new StreamingIngestStorage<>(
storageManager,
"clientName",
new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge(
Expand Down Expand Up @@ -283,10 +295,10 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception {
SnowflakeServiceClient snowflakeServiceClient =
new SnowflakeServiceClient(mockClient, mockBuilder);
StorageManager<?, ?> storageManager =
new InternalStageManager(true, "role", "client", snowflakeServiceClient);
new InternalStageManager<>(true, "role", "client", snowflakeServiceClient);

StreamingIngestStorage<?, ?> stage =
new StreamingIngestStorage(
new StreamingIngestStorage<>(
storageManager,
"clientName",
(StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null,
Expand All @@ -312,7 +324,55 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception {
Assert.assertEquals(
Paths.get("placeholder").toAbsolutePath(),
Paths.get(metadataWithAge.fileTransferMetadata.getPresignedUrlFileName()).toAbsolutePath());
Assert.assertEquals("testPrefix", storageManager.getClientPrefix());
Assert.assertEquals(prefix + "_" + deploymentId, storageManager.getClientPrefix());
}

@Test
public void testRefreshSnowflakeMetadataDeploymentIdMismatch() throws Exception {
RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class);
CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);
CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class);
StatusLine mockStatusLine = Mockito.mock(StatusLine.class);
Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200);
Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);

BasicHttpEntity entity = new BasicHttpEntity();
entity.setContent(
new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8)));

BasicHttpEntity entityFromDifferentDeployment = new BasicHttpEntity();
entityFromDifferentDeployment.setContent(
new ByteArrayInputStream(
remoteMetaResponseDifferentDeployment.getBytes(StandardCharsets.UTF_8)));
Mockito.when(mockResponse.getEntity())
.thenReturn(entity)
.thenReturn(entityFromDifferentDeployment);
Mockito.when(mockClient.execute(Mockito.any()))
.thenReturn(mockResponse)
.thenReturn(mockResponse);

SnowflakeServiceClient snowflakeServiceClient =
new SnowflakeServiceClient(mockClient, mockBuilder);
StorageManager<?, ?> storageManager =
new InternalStageManager<>(true, "role", "clientName", snowflakeServiceClient);

StreamingIngestStorage<?, ?> storage = storageManager.getStorage("");
storage.refreshSnowflakeMetadata(true);

Assert.assertEquals(prefix + "_" + deploymentId, storageManager.getClientPrefix());

SFException exception =
Assert.assertThrows(SFException.class, () -> storage.refreshSnowflakeMetadata(true));
Assert.assertEquals(
ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH.getMessageCode(), exception.getVendorCode());
Assert.assertEquals(
"Deployment ID mismatch, Client was created on: "
+ deploymentId
+ ", Got upload location for: "
+ (deploymentId + 1)
+ ". Please"
+ " restart client: clientName.",
exception.getMessage());
}

@Test
Expand All @@ -325,8 +385,8 @@ public void testFetchSignedURL() throws Exception {
Mockito.when(mockClientInternal.getRole()).thenReturn("role");
SnowflakeServiceClient snowflakeServiceClient =
new SnowflakeServiceClient(mockClient, mockBuilder);
StorageManager storageManager =
new InternalStageManager(true, "role", "client", snowflakeServiceClient);
StorageManager<?, ?> storageManager =
new InternalStageManager<>(true, "role", "client", snowflakeServiceClient);
StatusLine mockStatusLine = Mockito.mock(StatusLine.class);
Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200);

Expand Down Expand Up @@ -371,17 +431,17 @@ public void testRefreshSnowflakeMetadataSynchronized() throws Exception {
Mockito.when(mockClientInternal.getRole()).thenReturn("role");
SnowflakeServiceClient snowflakeServiceClient =
new SnowflakeServiceClient(mockClient, mockBuilder);
StorageManager storageManager =
new InternalStageManager(true, "role", "client", snowflakeServiceClient);
StorageManager<?, ?> storageManager =
new InternalStageManager<>(true, "role", "client", snowflakeServiceClient);
StatusLine mockStatusLine = Mockito.mock(StatusLine.class);
Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200);

Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);
Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse));
Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse);

StreamingIngestStorage stage =
new StreamingIngestStorage(
StreamingIngestStorage<?, ?> stage =
new StreamingIngestStorage<>(
storageManager,
"clientName",
(StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null,
Expand Down Expand Up @@ -517,8 +577,8 @@ public void testRefreshMetadataOnFirstPutException() throws Exception {
StorageManager<?, ?> storageManager = Mockito.mock(StorageManager.class);
Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix");

StreamingIngestStorage stage =
new StreamingIngestStorage(
StreamingIngestStorage<?, ?> stage =
new StreamingIngestStorage<>(
storageManager,
"clientName",
new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge(
Expand Down

0 comments on commit 4507637

Please sign in to comment.