From 334767d75aa919b91ffadee3fa85ff01377663a6 Mon Sep 17 00:00:00 2001 From: Alec Huang Date: Mon, 22 Jul 2024 19:22:36 -0700 Subject: [PATCH] Code refactor for Iceberg support (#792) Co-authored-by: Hitesh Madan --- .../ingest/connection/OAuthClient.java | 1 + .../ingest/streaming/OpenChannelRequest.java | 4 +- .../internal/ChannelFlushContext.java | 10 +- .../internal/ChannelsStatusRequest.java | 22 +- .../internal/ClientConfigureRequest.java | 49 ++++ .../internal/ClientConfigureResponse.java | 75 ++++++ .../internal/DropChannelRequestInternal.java | 92 +++++++ .../streaming/internal/FileLocationInfo.java | 143 ++++++++++ .../streaming/internal/FlushService.java | 143 +++------- .../streaming/internal/IStorageManager.java | 67 +++++ .../internal/IStreamingIngestRequest.java | 13 + .../internal/InternalStageManager.java | 208 +++++++++++++++ .../internal/OpenChannelRequestInternal.java | 97 +++++++ .../internal/RegisterBlobRequest.java | 48 ++++ .../internal/SnowflakeServiceClient.java | 193 ++++++++++++++ ...nowflakeStreamingIngestClientInternal.java | 178 ++++--------- .../internal/StreamingIngestResponse.java | 9 +- ...Stage.java => StreamingIngestStorage.java} | 251 +++++++----------- .../internal/StreamingIngestUtils.java | 7 +- .../net/snowflake/ingest/utils/ErrorCode.java | 2 +- .../net/snowflake/ingest/utils/Utils.java | 29 +- .../streaming/internal/FlushServiceTest.java | 43 +-- .../SnowflakeStreamingIngestChannelTest.java | 7 +- ...t.java => StreamingIngestStorageTest.java} | 176 ++++++------ .../internal/StreamingIngestUtilsIT.java | 23 +- 25 files changed, 1379 insertions(+), 511 deletions(-) create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java create mode 100644 src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java rename src/main/java/net/snowflake/ingest/streaming/internal/{StreamingIngestStage.java => StreamingIngestStorage.java} (63%) rename src/test/java/net/snowflake/ingest/streaming/internal/{StreamingIngestStageTest.java => StreamingIngestStorageTest.java} (83%) diff --git a/src/main/java/net/snowflake/ingest/connection/OAuthClient.java b/src/main/java/net/snowflake/ingest/connection/OAuthClient.java index 61c736a42..a592899ea 100644 --- a/src/main/java/net/snowflake/ingest/connection/OAuthClient.java +++ b/src/main/java/net/snowflake/ingest/connection/OAuthClient.java @@ -93,6 +93,7 @@ public void refreshToken() throws IOException { /** Helper method for making refresh request */ private HttpUriRequest makeRefreshTokenRequest() { + // TODO SNOW-1538108 Use SnowflakeServiceClient to make the request HttpPost post = new HttpPost(oAuthCredential.get().getOAuthTokenEndpoint()); post.addHeader(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER); post.addHeader(HttpHeaders.AUTHORIZATION, oAuthCredential.get().getAuthHeader()); diff --git a/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java b/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java index cc8782dbd..4d3ea19aa 100644 --- a/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java +++ b/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming; @@ -150,7 +150,7 @@ public ZoneId getDefaultTimezone() { } public String getFullyQualifiedTableName() { - return String.format("%s.%s.%s", this.dbName, this.schemaName, this.tableName); + return Utils.getFullyQualifiedTableName(this.dbName, this.schemaName, this.tableName); } public OnErrorOption getOnErrorOption() { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java index 3e5265719..fe9542267 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; +import net.snowflake.ingest.utils.Utils; + /** * Channel immutable identification and encryption attributes. * @@ -36,12 +38,12 @@ class ChannelFlushContext { String encryptionKey, Long encryptionKeyId) { this.name = name; - this.fullyQualifiedName = String.format("%s.%s.%s.%s", dbName, schemaName, tableName, name); + this.fullyQualifiedName = + Utils.getFullyQualifiedChannelName(dbName, schemaName, tableName, name); this.dbName = dbName; this.schemaName = schemaName; this.tableName = tableName; - this.fullyQualifiedTableName = - String.format("%s.%s.%s", this.getDbName(), this.getSchemaName(), this.getTableName()); + this.fullyQualifiedTableName = Utils.getFullyQualifiedTableName(dbName, schemaName, tableName); this.channelSequencer = channelSequencer; this.encryptionKey = encryptionKey; this.encryptionKeyId = encryptionKeyId; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java index c72c62a4b..025647f14 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java @@ -1,14 +1,16 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; +import java.util.stream.Collectors; +import net.snowflake.ingest.utils.Utils; /** Class to deserialize a request from a channel status request */ -class ChannelsStatusRequest { +class ChannelsStatusRequest implements IStreamingIngestRequest { // Used to deserialize a channel request static class ChannelStatusRequestDTO { @@ -86,4 +88,20 @@ void setChannels(List channels) { List getChannels() { return channels; } + + @Override + public String getStringForLogging() { + return String.format( + "ChannelsStatusRequest(role=%s, channels=[%s])", + role, + channels.stream() + .map( + r -> + Utils.getFullyQualifiedChannelName( + r.getDatabaseName(), + r.getSchemaName(), + r.getTableName(), + r.getChannelName())) + .collect(Collectors.joining(", "))); + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java new file mode 100644 index 000000000..79b282079 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** Class used to serialize client configure request */ +class ClientConfigureRequest implements IStreamingIngestRequest { + /** + * Constructor for client configure request + * + * @param role Role to be used for the request. + */ + ClientConfigureRequest(String role) { + this.role = role; + } + + @JsonProperty("role") + private String role; + + // File name for the GCS signed url request + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonProperty("file_name") + private String fileName; + + String getRole() { + return role; + } + + void setRole(String role) { + this.role = role; + } + + String getFileName() { + return fileName; + } + + void setFileName(String fileName) { + this.fileName = fileName; + } + + @Override + public String getStringForLogging() { + return String.format("ClientConfigureRequest(role=%s, file_name=%s)", getRole(), getFileName()); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java new file mode 100644 index 000000000..03a1d3576 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** Class used to deserialize responses from configure endpoint */ +@JsonIgnoreProperties(ignoreUnknown = true) +class ClientConfigureResponse extends StreamingIngestResponse { + @JsonProperty("prefix") + private String prefix; + + @JsonProperty("status_code") + private Long statusCode; + + @JsonProperty("message") + private String message; + + @JsonProperty("stage_location") + private FileLocationInfo stageLocation; + + @JsonProperty("deployment_id") + private Long deploymentId; + + String getPrefix() { + return prefix; + } + + void setPrefix(String prefix) { + this.prefix = prefix; + } + + @Override + Long getStatusCode() { + return statusCode; + } + + void setStatusCode(Long statusCode) { + this.statusCode = statusCode; + } + + String getMessage() { + return message; + } + + void setMessage(String message) { + this.message = message; + } + + FileLocationInfo getStageLocation() { + return stageLocation; + } + + void setStageLocation(FileLocationInfo stageLocation) { + this.stageLocation = stageLocation; + } + + Long getDeploymentId() { + return deploymentId; + } + + void setDeploymentId(Long deploymentId) { + this.deploymentId = deploymentId; + } + + String getClientPrefix() { + if (this.deploymentId == null) { + return this.prefix; + } + return this.prefix + "_" + this.deploymentId; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java new file mode 100644 index 000000000..322b53acf --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import net.snowflake.ingest.streaming.DropChannelRequest; +import net.snowflake.ingest.utils.Utils; + +/** Class used to serialize the {@link DropChannelRequest} */ +class DropChannelRequestInternal implements IStreamingIngestRequest { + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("role") + private String role; + + @JsonProperty("channel") + private String channel; + + @JsonProperty("table") + private String table; + + @JsonProperty("database") + private String database; + + @JsonProperty("schema") + private String schema; + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonProperty("client_sequencer") + Long clientSequencer; + + DropChannelRequestInternal( + String requestId, + String role, + String database, + String schema, + String table, + String channel, + Long clientSequencer) { + this.requestId = requestId; + this.role = role; + this.database = database; + this.schema = schema; + this.table = table; + this.channel = channel; + this.clientSequencer = clientSequencer; + } + + String getRequestId() { + return requestId; + } + + String getRole() { + return role; + } + + String getChannel() { + return channel; + } + + String getTable() { + return table; + } + + String getDatabase() { + return database; + } + + String getSchema() { + return schema; + } + + Long getClientSequencer() { + return clientSequencer; + } + + String getFullyQualifiedTableName() { + return Utils.getFullyQualifiedTableName(database, schema, table); + } + + @Override + public String getStringForLogging() { + return String.format( + "DropChannelRequest(requestId=%s, role=%s, db=%s, schema=%s, table=%s, channel=%s," + + " clientSequencer=%s)", + requestId, role, database, schema, table, channel, clientSequencer); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java b/src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java new file mode 100644 index 000000000..add98a6fb --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; + +/** Class used to deserialized volume information response by server */ +class FileLocationInfo { + /** The stage type */ + @JsonProperty("locationType") + private String locationType; + + /** The container or bucket */ + @JsonProperty("location") + private String location; + + /** The path of the target file */ + @JsonProperty("path") + private String path; + + /** The credentials required for the stage */ + @JsonProperty("creds") + private Map credentials; + + /** AWS/S3/GCS region (S3/GCS only) */ + @JsonProperty("region") + private String region; + + /** The Azure Storage endpoint (Azure only) */ + @JsonProperty("endPoint") + private String endPoint; + + /** The Azure Storage account (Azure only) */ + @JsonProperty("storageAccount") + private String storageAccount; + + /** GCS gives us back a presigned URL instead of a cred */ + @JsonProperty("presignedUrl") + private String presignedUrl; + + /** Whether to encrypt/decrypt files on the stage */ + @JsonProperty("isClientSideEncrypted") + private boolean isClientSideEncrypted; + + /** Whether to use s3 regional URL (AWS Only) */ + @JsonProperty("useS3RegionalUrl") + private boolean useS3RegionalUrl; + + /** A unique id for volume assigned by server */ + @JsonProperty("volumeHash") + private String volumeHash; + + String getLocationType() { + return locationType; + } + + void setLocationType(String locationType) { + this.locationType = locationType; + } + + String getLocation() { + return location; + } + + void setLocation(String location) { + this.location = location; + } + + String getPath() { + return path; + } + + void setPath(String path) { + this.path = path; + } + + Map getCredentials() { + return credentials; + } + + void setCredentials(Map credentials) { + this.credentials = credentials; + } + + String getRegion() { + return region; + } + + void setRegion(String region) { + this.region = region; + } + + String getEndPoint() { + return endPoint; + } + + void setEndPoint(String endPoint) { + this.endPoint = endPoint; + } + + String getStorageAccount() { + return storageAccount; + } + + void setStorageAccount(String storageAccount) { + this.storageAccount = storageAccount; + } + + String getPresignedUrl() { + return presignedUrl; + } + + void setPresignedUrl(String presignedUrl) { + this.presignedUrl = presignedUrl; + } + + boolean getIsClientSideEncrypted() { + return this.isClientSideEncrypted; + } + + void setIsClientSideEncrypted(boolean isClientSideEncrypted) { + this.isClientSideEncrypted = isClientSideEncrypted; + } + + boolean getUseS3RegionalUrl() { + return this.useS3RegionalUrl; + } + + void setUseS3RegionalUrl(boolean useS3RegionalUrl) { + this.useS3RegionalUrl = useS3RegionalUrl; + } + + String getVolumeHash() { + return this.volumeHash; + } + + void setVolumeHash(String volumeHash) { + this.volumeHash = volumeHash; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index 76e43ff4d..954abfc4a 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -1,10 +1,9 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.utils.Constants.BLOB_EXTENSION_TYPE; import static net.snowflake.ingest.utils.Constants.DISABLE_BACKGROUND_FLUSH; import static net.snowflake.ingest.utils.Constants.MAX_BLOB_SIZE_IN_BYTES; import static net.snowflake.ingest.utils.Constants.MAX_THREAD_COUNT; @@ -19,13 +18,11 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.util.ArrayList; -import java.util.Calendar; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.TimeZone; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -34,11 +31,9 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; -import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; @@ -84,9 +79,6 @@ List>> getData() { private static final Logging logger = new Logging(FlushService.class); - // Increasing counter to generate a unique blob name per client - private final AtomicLong counter; - // The client that owns this flush service private final SnowflakeStreamingIngestClientInternal owningClient; @@ -102,8 +94,8 @@ List>> getData() { // Reference to the channel cache private final ChannelCache channelCache; - // Reference to the Streaming Ingest stage - private final StreamingIngestStage targetStage; + // Reference to the Streaming Ingest storage manager + private final IStorageManager storageManager; // Reference to register service private final RegisterService registerService; @@ -125,56 +117,22 @@ List>> getData() { private volatile int numProcessors = Runtime.getRuntime().availableProcessors(); /** - * Constructor for TESTING that takes (usually mocked) StreamingIngestStage + * Default constructor * - * @param client - * @param cache - * @param isTestMode + * @param client the owning client + * @param cache the channel cache + * @param storageManager the storage manager + * @param isTestMode whether the service is running in test mode */ FlushService( SnowflakeStreamingIngestClientInternal client, ChannelCache cache, - StreamingIngestStage targetStage, // For testing + IStorageManager storageManager, boolean isTestMode) { this.owningClient = client; this.channelCache = cache; - this.targetStage = targetStage; - this.counter = new AtomicLong(0); - this.registerService = new RegisterService<>(client, isTestMode); - this.isNeedFlush = false; - this.lastFlushTime = System.currentTimeMillis(); - this.isTestMode = isTestMode; - this.latencyTimerContextMap = new ConcurrentHashMap<>(); - this.bdecVersion = this.owningClient.getParameterProvider().getBlobFormatVersion(); - createWorkers(); - } - - /** - * Default constructor - * - * @param client - * @param cache - * @param isTestMode - */ - FlushService( - SnowflakeStreamingIngestClientInternal client, ChannelCache cache, boolean isTestMode) { - this.owningClient = client; - this.channelCache = cache; - try { - this.targetStage = - new StreamingIngestStage( - isTestMode, - client.getRole(), - client.getHttpClient(), - client.getRequestBuilder(), - client.getName(), - DEFAULT_MAX_UPLOAD_RETRIES); - } catch (SnowflakeSQLException | IOException err) { - throw new SFException(err, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE); - } - + this.storageManager = storageManager; this.registerService = new RegisterService<>(client, isTestMode); - this.counter = new AtomicLong(0); this.isNeedFlush = false; this.lastFlushTime = System.currentTimeMillis(); this.isTestMode = isTestMode; @@ -367,7 +325,7 @@ void distributeFlushTasks() { while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) { List>> blobData = new ArrayList<>(); float totalBufferSizeInBytes = 0F; - final String blobPath = getBlobPath(this.targetStage.getClientPrefix()); + final String blobPath = this.storageManager.generateBlobPath(); // Distribute work at table level, split the blob if reaching the blob size limit or the // channel has different encryption key ids @@ -449,9 +407,9 @@ && shouldStopProcessing( // Kick off a build job if (blobData.isEmpty()) { - // we decrement the counter so that we do not have gaps in the blob names created by this - // client. See method getBlobPath() below. - this.counter.decrementAndGet(); + // we decrement the blob sequencer so that we do not have gaps in the blob names created by + // this client. + this.storageManager.decrementBlobSequencer(); } else { long flushStartMs = System.currentTimeMillis(); if (this.owningClient.flushLatency != null) { @@ -463,7 +421,13 @@ && shouldStopProcessing( CompletableFuture.supplyAsync( () -> { try { - BlobMetadata blobMetadata = buildAndUpload(blobPath, blobData); + // Get the fully qualified table name from the first channel in the blob. + // This only matters when the client is in Iceberg mode. In Iceberg mode, + // all channels in the blob belong to the same table. + String fullyQualifiedTableName = + blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName(); + BlobMetadata blobMetadata = + buildAndUpload(blobPath, blobData, fullyQualifiedTableName); blobMetadata.getBlobStats().setFlushStartMs(flushStartMs); return blobMetadata; } catch (Throwable e) { @@ -546,9 +510,12 @@ private boolean shouldStopProcessing( * @param blobPath Path of the destination blob in cloud storage * @param blobData All the data for one blob. Assumes that all ChannelData in the inner List * belongs to the same table. Will error if this is not the case + * @param fullyQualifiedTableName the table name of the first channel in the blob, only matters in + * Iceberg mode * @return BlobMetadata for FlushService.upload */ - BlobMetadata buildAndUpload(String blobPath, List>> blobData) + BlobMetadata buildAndUpload( + String blobPath, List>> blobData, String fullyQualifiedTableName) throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException, InvalidKeyException { @@ -559,12 +526,18 @@ BlobMetadata buildAndUpload(String blobPath, List>> blobData blob.blobStats.setBuildDurationMs(buildContext); - return upload(blobPath, blob.blobBytes, blob.chunksMetadataList, blob.blobStats); + return upload( + this.storageManager.getStorage(fullyQualifiedTableName), + blobPath, + blob.blobBytes, + blob.chunksMetadataList, + blob.blobStats); } /** * Upload a blob to Streaming Ingest dedicated stage * + * @param storage the storage to upload the blob * @param blobPath full path of the blob * @param blob blob data * @param metadata a list of chunk metadata @@ -572,13 +545,17 @@ BlobMetadata buildAndUpload(String blobPath, List>> blobData * @return BlobMetadata object used to create the register blob request */ BlobMetadata upload( - String blobPath, byte[] blob, List metadata, BlobStats blobStats) + StreamingIngestStorage storage, + String blobPath, + byte[] blob, + List metadata, + BlobStats blobStats) throws NoSuchAlgorithmException { logger.logInfo("Start uploading blob={}, size={}", blobPath, blob.length); long startTime = System.currentTimeMillis(); Timer.Context uploadContext = Utils.createTimerContext(this.owningClient.uploadLatency); - this.targetStage.put(blobPath, blob); + storage.put(blobPath, blob); if (uploadContext != null) { blobStats.setUploadDurationMs(uploadContext); @@ -635,45 +612,6 @@ void setNeedFlush() { this.isNeedFlush = true; } - /** - * Generate a blob path, which is: "YEAR/MONTH/DAY_OF_MONTH/HOUR_OF_DAY/MINUTE/.BDEC" - * - * @return the generated blob file path - */ - private String getBlobPath(String clientPrefix) { - Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); - return getBlobPath(calendar, clientPrefix); - } - - /** For TESTING */ - String getBlobPath(Calendar calendar, String clientPrefix) { - if (isTestMode && clientPrefix == null) { - clientPrefix = "testPrefix"; - } - - Utils.assertStringNotNullOrEmpty("client prefix", clientPrefix); - int year = calendar.get(Calendar.YEAR); - int month = calendar.get(Calendar.MONTH) + 1; // Gregorian calendar starts from 0 - int day = calendar.get(Calendar.DAY_OF_MONTH); - int hour = calendar.get(Calendar.HOUR_OF_DAY); - int minute = calendar.get(Calendar.MINUTE); - long time = TimeUnit.MILLISECONDS.toSeconds(calendar.getTimeInMillis()); - long threadId = Thread.currentThread().getId(); - // Create the blob short name, the clientPrefix contains the deployment id - String blobShortName = - Long.toString(time, 36) - + "_" - + clientPrefix - + "_" - + threadId - + "_" - + this.counter.getAndIncrement() - + "." - + BLOB_EXTENSION_TYPE; - return year + "/" + month + "/" + day + "/" + hour + "/" + minute + "/" + blobShortName; - } - /** * Invalidate all the channels in the blob data * @@ -697,11 +635,6 @@ void invalidateAllChannelsInBlob( })); } - /** Get the server generated unique prefix for this client */ - String getClientPrefix() { - return this.targetStage.getClientPrefix(); - } - /** * Throttle if the number of queued buildAndUpload tasks is bigger than the total number of * available processors diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java new file mode 100644 index 000000000..51f4a82de --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import java.util.Optional; + +/** + * Interface to manage {@link StreamingIngestStorage} for {@link FlushService} + * + * @param The type of chunk data + * @param the type of location that's being managed (internal stage / external volume) + */ +interface IStorageManager { + /** Default max upload retries for streaming ingest storage */ + int DEFAULT_MAX_UPLOAD_RETRIES = 5; + + /** + * Given a fully qualified table name, return the target storage + * + * @param fullyQualifiedTableName the target fully qualified table name + * @return target stage + */ + StreamingIngestStorage getStorage(String fullyQualifiedTableName); + + /** + * Add a storage to the manager + * + * @param dbName the database name + * @param schemaName the schema name + * @param tableName the table name + * @param fileLocationInfo file location info from configure response + */ + void addStorage( + String dbName, String schemaName, String tableName, FileLocationInfo fileLocationInfo); + + /** + * Gets the latest file location info (with a renewed short-lived access token) for the specified + * location + * + * @param location A reference to the target location + * @param fileName optional filename for single-file signed URL fetch from server + * @return the new location information + */ + FileLocationInfo getRefreshedLocation(TLocation location, Optional fileName); + + /** + * Generate a unique blob path and increment the blob sequencer + * + * @return the blob path + */ + String generateBlobPath(); + + /** + * Decrement the blob sequencer, this method is needed to prevent gap between file name sequencer. + * See {@link IStorageManager#generateBlobPath()} for more details. + */ + void decrementBlobSequencer(); + + /** + * Get the unique client prefix generated by the Snowflake server + * + * @return the client prefix + */ + String getClientPrefix(); +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java new file mode 100644 index 000000000..a4b5e29d1 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +/** + * The StreamingIngestRequest interface is a marker interface used for type safety in the {@link + * SnowflakeServiceClient} for streaming ingest API request. + */ +interface IStreamingIngestRequest { + String getStringForLogging(); +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java new file mode 100644 index 000000000..d33a80738 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import static net.snowflake.ingest.utils.Constants.BLOB_EXTENSION_TYPE; + +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.util.Calendar; +import java.util.Optional; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.ingest.connection.IngestResponseException; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.SFException; +import net.snowflake.ingest.utils.Utils; + +class InternalStageLocation { + public InternalStageLocation() {} +} + +/** Class to manage single Snowflake internal stage */ +class InternalStageManager implements IStorageManager { + /** Target stage for the client */ + private final StreamingIngestStorage targetStage; + + /** Increasing counter to generate a unique blob name per client */ + private final AtomicLong counter; + + /** Whether the manager in test mode */ + private final boolean isTestMode; + + /** Snowflake service client used for configure calls */ + private final SnowflakeServiceClient snowflakeServiceClient; + + /** The name of the client */ + private final String clientName; + + /** The role of the client */ + private final String role; + + /** Client prefix generated by the Snowflake server */ + private String clientPrefix; + + /** Deployment ID generated by the Snowflake server */ + private Long deploymentId; + + /** + * Constructor for InternalStageManager + * + * @param isTestMode whether the manager in test mode + * @param role the role of the client + * @param clientName the name of the client + * @param snowflakeServiceClient the Snowflake service client to use for configure calls + */ + InternalStageManager( + boolean isTestMode, + String role, + String clientName, + SnowflakeServiceClient snowflakeServiceClient) { + this.snowflakeServiceClient = snowflakeServiceClient; + this.isTestMode = isTestMode; + this.clientName = clientName; + this.role = role; + this.counter = new AtomicLong(0); + try { + if (!isTestMode) { + ClientConfigureResponse response = + this.snowflakeServiceClient.clientConfigure(new ClientConfigureRequest(role)); + this.clientPrefix = response.getClientPrefix(); + this.deploymentId = response.getDeploymentId(); + this.targetStage = + new StreamingIngestStorage( + this, + clientName, + response.getStageLocation(), + new InternalStageLocation(), + DEFAULT_MAX_UPLOAD_RETRIES); + } else { + this.clientPrefix = null; + this.deploymentId = null; + this.targetStage = + new StreamingIngestStorage( + this, + "testClient", + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + new InternalStageLocation(), + DEFAULT_MAX_UPLOAD_RETRIES); + } + } catch (IngestResponseException | IOException e) { + throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); + } catch (SnowflakeSQLException e) { + throw new SFException(e, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE, e.getMessage()); + } + } + + /** + * Get the storage. In this case, the storage is always the target stage as there's only one stage + * in non-iceberg mode. + * + * @param fullyQualifiedTableName the target fully qualified table name + * @return the target storage + */ + @Override + @SuppressWarnings("unused") + public StreamingIngestStorage getStorage( + String fullyQualifiedTableName) { + // There's always only one stage for the client in non-iceberg mode + return targetStage; + } + + /** Add storage to the manager. Do nothing as there's only one stage in non-Iceberg mode. */ + @Override + public void addStorage( + String dbName, String schemaName, String tableName, FileLocationInfo fileLocationInfo) {} + + /** + * Gets the latest file location info (with a renewed short-lived access token) for the specified + * location + * + * @param location A reference to the target location + * @param fileName optional filename for single-file signed URL fetch from server + * @return the new location information + */ + @Override + public FileLocationInfo getRefreshedLocation( + InternalStageLocation location, Optional fileName) { + try { + ClientConfigureRequest request = new ClientConfigureRequest(this.role); + fileName.ifPresent(request::setFileName); + ClientConfigureResponse response = snowflakeServiceClient.clientConfigure(request); + if (this.clientPrefix == null) { + this.clientPrefix = response.getClientPrefix(); + this.deploymentId = response.getDeploymentId(); + } + if (this.deploymentId != null && !this.deploymentId.equals(response.getDeploymentId())) { + throw new SFException( + ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH, + this.deploymentId, + response.getDeploymentId(), + this.clientName); + } + return response.getStageLocation(); + } catch (IngestResponseException | IOException e) { + throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); + } + } + + /** + * Generate a blob path, which is: "YEAR/MONTH/DAY_OF_MONTH/HOUR_OF_DAY/MINUTE/.BDEC" + * + * @return the generated blob file path + */ + @Override + public String generateBlobPath() { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + return getBlobPath(calendar, this.clientPrefix); + } + + @Override + public void decrementBlobSequencer() { + this.counter.decrementAndGet(); + } + + /** For TESTING */ + @VisibleForTesting + public String getBlobPath(Calendar calendar, String clientPrefix) { + if (this.isTestMode && clientPrefix == null) { + clientPrefix = "testPrefix"; + } + + Utils.assertStringNotNullOrEmpty("client prefix", clientPrefix); + int year = calendar.get(Calendar.YEAR); + int month = calendar.get(Calendar.MONTH) + 1; // Gregorian calendar starts from 0 + int day = calendar.get(Calendar.DAY_OF_MONTH); + int hour = calendar.get(Calendar.HOUR_OF_DAY); + int minute = calendar.get(Calendar.MINUTE); + long time = TimeUnit.MILLISECONDS.toSeconds(calendar.getTimeInMillis()); + long threadId = Thread.currentThread().getId(); + // Create the blob short name, the clientPrefix contains the deployment id + String blobShortName = + Long.toString(time, 36) + + "_" + + clientPrefix + + "_" + + threadId + + "_" + + this.counter.getAndIncrement() + + "." + + BLOB_EXTENSION_TYPE; + return year + "/" + month + "/" + day + "/" + hour + "/" + minute + "/" + blobShortName; + } + + /** + * Get the unique client prefix generated by the Snowflake server + * + * @return the client prefix + */ + @Override + public String getClientPrefix() { + return this.clientPrefix; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java new file mode 100644 index 000000000..ff53f6729 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.utils.Constants; + +/** Class used to serialize the {@link OpenChannelRequest} */ +class OpenChannelRequestInternal implements IStreamingIngestRequest { + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("role") + private String role; + + @JsonProperty("channel") + private String channel; + + @JsonProperty("table") + private String table; + + @JsonProperty("database") + private String database; + + @JsonProperty("schema") + private String schema; + + @JsonProperty("write_mode") + private String writeMode; + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonProperty("offset_token") + private String offsetToken; + + OpenChannelRequestInternal( + String requestId, + String role, + String database, + String schema, + String table, + String channel, + Constants.WriteMode writeMode, + String offsetToken) { + this.requestId = requestId; + this.role = role; + this.database = database; + this.schema = schema; + this.table = table; + this.channel = channel; + this.writeMode = writeMode.name(); + this.offsetToken = offsetToken; + } + + String getRequestId() { + return requestId; + } + + String getRole() { + return role; + } + + String getChannel() { + return channel; + } + + String getTable() { + return table; + } + + String getDatabase() { + return database; + } + + String getSchema() { + return schema; + } + + String getWriteMode() { + return writeMode; + } + + String getOffsetToken() { + return offsetToken; + } + + @Override + public String getStringForLogging() { + return String.format( + "OpenChannelRequestInternal(requestId=%s, role=%s, db=%s, schema=%s, table=%s, channel=%s," + + " writeMode=%s)", + requestId, role, database, schema, table, channel, writeMode); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java new file mode 100644 index 000000000..fcb7edf4f --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.stream.Collectors; + +/** Class used to serialize the blob register request */ +class RegisterBlobRequest implements IStreamingIngestRequest { + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("role") + private String role; + + @JsonProperty("blobs") + private List blobs; + + RegisterBlobRequest(String requestId, String role, List blobs) { + this.requestId = requestId; + this.role = role; + this.blobs = blobs; + } + + String getRequestId() { + return requestId; + } + + String getRole() { + return role; + } + + List getBlobs() { + return blobs; + } + + @Override + public String getStringForLogging() { + return String.format( + "RegisterBlobRequest(requestId=%s, role=%s, blobs=[%s])", + requestId, + role, + blobs.stream().map(BlobMetadata::getPath).collect(Collectors.joining(", "))); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java new file mode 100644 index 000000000..67958618b --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CHANNEL_STATUS; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_DROP_CHANNEL; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_OPEN_CHANNEL; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_REGISTER_BLOB; +import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; +import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.OPEN_CHANNEL_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; + +import java.io.IOException; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.ingest.connection.IngestResponseException; +import net.snowflake.ingest.connection.RequestBuilder; +import net.snowflake.ingest.connection.ServiceResponseHandler; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.Logging; +import net.snowflake.ingest.utils.SFException; + +/** + * The SnowflakeServiceClient class is responsible for making API requests to the Snowflake service. + */ +class SnowflakeServiceClient { + private static final Logging logger = new Logging(SnowflakeServiceClient.class); + + /** HTTP client used for making requests */ + private final CloseableHttpClient httpClient; + + /** Request builder for building streaming API request */ + private final RequestBuilder requestBuilder; + + /** + * Default constructor + * + * @param httpClient the HTTP client used for making requests + * @param requestBuilder the request builder for building streaming API requests + */ + SnowflakeServiceClient(CloseableHttpClient httpClient, RequestBuilder requestBuilder) { + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + } + + /** + * Configures the client given a {@link ClientConfigureRequest}. + * + * @param request the client configuration request + * @return the response from the configuration request + */ + ClientConfigureResponse clientConfigure(ClientConfigureRequest request) + throws IngestResponseException, IOException { + ClientConfigureResponse response = + executeApiRequestWithRetries( + ClientConfigureResponse.class, + request, + CLIENT_CONFIGURE_ENDPOINT, + "client configure", + STREAMING_CLIENT_CONFIGURE); + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Client configure request failed, request={}, message={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.CLIENT_CONFIGURE_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Opens a channel given a {@link OpenChannelRequestInternal}. + * + * @param request the open channel request + * @return the response from the open channel request + */ + OpenChannelResponse openChannel(OpenChannelRequestInternal request) + throws IngestResponseException, IOException { + OpenChannelResponse response = + executeApiRequestWithRetries( + OpenChannelResponse.class, + request, + OPEN_CHANNEL_ENDPOINT, + "open channel", + STREAMING_OPEN_CHANNEL); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Open channel request failed, request={}, response={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.OPEN_CHANNEL_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Drops a channel given a {@link DropChannelRequestInternal}. + * + * @param request the drop channel request + * @return the response from the drop channel request + */ + DropChannelResponse dropChannel(DropChannelRequestInternal request) + throws IngestResponseException, IOException { + DropChannelResponse response = + executeApiRequestWithRetries( + DropChannelResponse.class, + request, + DROP_CHANNEL_ENDPOINT, + "drop channel", + STREAMING_DROP_CHANNEL); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Drop channel request failed, request={}, response={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.DROP_CHANNEL_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Gets the status of a channel given a {@link ChannelsStatusRequest}. + * + * @param request the channel status request + * @return the response from the channel status request + */ + ChannelsStatusResponse getChannelStatus(ChannelsStatusRequest request) + throws IngestResponseException, IOException { + ChannelsStatusResponse response = + executeApiRequestWithRetries( + ChannelsStatusResponse.class, + request, + CHANNEL_STATUS_ENDPOINT, + "channel status", + STREAMING_CHANNEL_STATUS); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Channel status request failed, request={}, response={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.CHANNEL_STATUS_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Registers a blob given a {@link RegisterBlobRequest}. + * + * @param request the register blob request + * @param executionCount the number of times the request has been executed, used for logging + * @return the response from the register blob request + */ + RegisterBlobResponse registerBlob(RegisterBlobRequest request, final int executionCount) + throws IngestResponseException, IOException { + RegisterBlobResponse response = + executeApiRequestWithRetries( + RegisterBlobResponse.class, + request, + REGISTER_BLOB_ENDPOINT, + "register blob", + STREAMING_REGISTER_BLOB); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Register blob request failed, request={}, response={}, executionCount={}", + request.getStringForLogging(), + response.getMessage(), + executionCount); + throw new SFException(ErrorCode.REGISTER_BLOB_FAILURE, response.getMessage()); + } + return response; + } + + private T executeApiRequestWithRetries( + Class responseClass, + IStreamingIngestRequest request, + String endpoint, + String operation, + ServiceResponseHandler.ApiName apiName) + throws IngestResponseException, IOException { + return executeWithRetries( + responseClass, endpoint, request, operation, apiName, this.httpClient, this.requestBuilder); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index 75eb4f717..6331a4045 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -1,23 +1,14 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CHANNEL_STATUS; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_DROP_CHANNEL; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_OPEN_CHANNEL; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_REGISTER_BLOB; -import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.sleepForRetry; -import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT; import static net.snowflake.ingest.utils.Constants.COMMIT_MAX_RETRY_COUNT; import static net.snowflake.ingest.utils.Constants.COMMIT_RETRY_INTERVAL_IN_MS; -import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT; import static net.snowflake.ingest.utils.Constants.ENABLE_TELEMETRY_TO_SF; import static net.snowflake.ingest.utils.Constants.MAX_STREAMING_INGEST_API_CHANNEL_RETRY; -import static net.snowflake.ingest.utils.Constants.OPEN_CHANNEL_ENDPOINT; -import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; import static net.snowflake.ingest.utils.Constants.RESPONSE_ERR_ENQUEUE_TABLE_CHUNK_QUEUE_FULL; import static net.snowflake.ingest.utils.Constants.RESPONSE_ERR_GENERAL_EXCEPTION_RETRY_REQUEST; import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; @@ -37,7 +28,6 @@ import com.codahale.metrics.jmx.JmxReporter; import com.codahale.metrics.jvm.MemoryUsageGaugeSet; import com.codahale.metrics.jvm.ThreadStatesGaugeSet; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.net.URI; @@ -94,9 +84,6 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea private static final Logging logger = new Logging(SnowflakeStreamingIngestClientInternal.class); - // Object mapper for all marshalling and unmarshalling - private static final ObjectMapper objectMapper = new ObjectMapper(); - // Counter to generate unique request ids per client private final AtomicLong counter = new AtomicLong(0); @@ -118,6 +105,9 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea // Reference to the flush service private final FlushService flushService; + // Reference to storage manager + private final IStorageManager storageManager; + // Indicates whether the client has closed private volatile boolean isClosed; @@ -145,6 +135,9 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea // Background thread that uploads telemetry data periodically private ScheduledExecutorService telemetryWorker; + // Snowflake service client to make API calls + private SnowflakeServiceClient snowflakeServiceClient; + /** * Constructor * @@ -228,8 +221,14 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea this.setupMetricsForClient(); } + this.snowflakeServiceClient = new SnowflakeServiceClient(this.httpClient, this.requestBuilder); + + this.storageManager = + new InternalStageManager(isTestMode, this.role, this.name, this.snowflakeServiceClient); + try { - this.flushService = new FlushService<>(this, this.channelCache, this.isTestMode); + this.flushService = + new FlushService<>(this, this.channelCache, this.storageManager, this.isTestMode); } catch (Exception e) { // Need to clean up the resources before throwing any exceptions cleanUpResources(); @@ -274,6 +273,7 @@ public SnowflakeStreamingIngestClientInternal( @VisibleForTesting public void injectRequestBuilder(RequestBuilder requestBuilder) { this.requestBuilder = requestBuilder; + this.snowflakeServiceClient = new SnowflakeServiceClient(this.httpClient, this.requestBuilder); } /** @@ -320,39 +320,17 @@ public SnowflakeStreamingIngestChannelInternal openChannel(OpenChannelRequest getName()); try { - Map payload = new HashMap<>(); - payload.put( - "request_id", this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - payload.put("channel", request.getChannelName()); - payload.put("table", request.getTableName()); - payload.put("database", request.getDBName()); - payload.put("schema", request.getSchemaName()); - payload.put("write_mode", Constants.WriteMode.CLOUD_STORAGE.name()); - payload.put("role", this.role); - if (request.isOffsetTokenProvided()) { - payload.put("offset_token", request.getOffsetToken()); - } - - OpenChannelResponse response = - executeWithRetries( - OpenChannelResponse.class, - OPEN_CHANNEL_ENDPOINT, - payload, - "open channel", - STREAMING_OPEN_CHANNEL, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - logger.logDebug( - "Open channel request failed, channel={}, table={}, client={}, message={}", - request.getChannelName(), - request.getFullyQualifiedTableName(), - getName(), - response.getMessage()); - throw new SFException(ErrorCode.OPEN_CHANNEL_FAILURE, response.getMessage()); - } + OpenChannelRequestInternal openChannelRequest = + new OpenChannelRequestInternal( + this.storageManager.getClientPrefix() + "_" + counter.getAndIncrement(), + this.role, + request.getDBName(), + request.getSchemaName(), + request.getTableName(), + request.getChannelName(), + Constants.WriteMode.CLOUD_STORAGE, + request.getOffsetToken()); + OpenChannelResponse response = snowflakeServiceClient.openChannel(openChannelRequest); logger.logInfo( "Open channel request succeeded, channel={}, table={}, clientSequencer={}," @@ -405,51 +383,28 @@ public void dropChannel(DropChannelRequest request) { getName()); try { - Map payload = new HashMap<>(); - payload.put( - "request_id", this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - payload.put("channel", request.getChannelName()); - payload.put("table", request.getTableName()); - payload.put("database", request.getDBName()); - payload.put("schema", request.getSchemaName()); - payload.put("role", this.role); - Long clientSequencer = null; - if (request instanceof DropChannelVersionRequest) { - clientSequencer = ((DropChannelVersionRequest) request).getClientSequencer(); - if (clientSequencer != null) { - payload.put("client_sequencer", clientSequencer); - } - } - - DropChannelResponse response = - executeWithRetries( - DropChannelResponse.class, - DROP_CHANNEL_ENDPOINT, - payload, - "drop channel", - STREAMING_DROP_CHANNEL, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - logger.logDebug( - "Drop channel request failed, channel={}, table={}, client={}, message={}", - request.getChannelName(), - request.getFullyQualifiedTableName(), - getName(), - response.getMessage()); - throw new SFException(ErrorCode.DROP_CHANNEL_FAILURE, response.getMessage()); - } + DropChannelRequestInternal dropChannelRequest = + new DropChannelRequestInternal( + this.storageManager.getClientPrefix() + "_" + counter.getAndIncrement(), + this.role, + request.getDBName(), + request.getSchemaName(), + request.getTableName(), + request.getChannelName(), + request instanceof DropChannelVersionRequest + ? ((DropChannelVersionRequest) request).getClientSequencer() + : null); + snowflakeServiceClient.dropChannel(dropChannelRequest); logger.logInfo( "Drop channel request succeeded, channel={}, table={}, clientSequencer={} client={}", request.getChannelName(), request.getFullyQualifiedTableName(), - clientSequencer, + request instanceof DropChannelVersionRequest + ? ((DropChannelVersionRequest) request).getClientSequencer() + : null, getName()); - - } catch (IOException | IngestResponseException e) { + } catch (IngestResponseException | IOException e) { throw new SFException(e, ErrorCode.DROP_CHANNEL_FAILURE, e.getMessage()); } } @@ -494,22 +449,7 @@ ChannelsStatusResponse getChannelsStatus( request.setChannels(requestDTOs); request.setRole(this.role); - String payload = objectMapper.writeValueAsString(request); - - ChannelsStatusResponse response = - executeWithRetries( - ChannelsStatusResponse.class, - CHANNEL_STATUS_ENDPOINT, - payload, - "channel status", - STREAMING_CHANNEL_STATUS, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - throw new SFException(ErrorCode.CHANNEL_STATUS_FAILURE, response.getMessage()); - } + ChannelsStatusResponse response = snowflakeServiceClient.getChannelStatus(request); for (int idx = 0; idx < channels.size(); idx++) { SnowflakeStreamingIngestChannelInternal channel = channels.get(idx); @@ -606,32 +546,12 @@ void registerBlobs(List blobs, final int executionCount) { RegisterBlobResponse response = null; try { - Map payload = new HashMap<>(); - payload.put( - "request_id", this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - payload.put("blobs", blobs); - payload.put("role", this.role); - - response = - executeWithRetries( - RegisterBlobResponse.class, - REGISTER_BLOB_ENDPOINT, - payload, - "register blob", - STREAMING_REGISTER_BLOB, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - logger.logDebug( - "Register blob request failed for blob={}, client={}, message={}, executionCount={}", - blobs.stream().map(BlobMetadata::getPath).collect(Collectors.toList()), - this.name, - response.getMessage(), - executionCount); - throw new SFException(ErrorCode.REGISTER_BLOB_FAILURE, response.getMessage()); - } + RegisterBlobRequest request = + new RegisterBlobRequest( + this.storageManager.getClientPrefix() + "_" + counter.getAndIncrement(), + this.role, + blobs); + response = snowflakeServiceClient.registerBlob(request, executionCount); } catch (IOException | IngestResponseException e) { throw new SFException(e, ErrorCode.REGISTER_BLOB_FAILURE, e.getMessage()); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java index 1ec01fceb..6c4df8c6d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java @@ -1,9 +1,16 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; +/** + * The StreamingIngestResponse class is an abstract class that represents a response from the + * Snowflake streaming ingest API. This class provides a common structure for all types of responses + * that can be received from the {@link SnowflakeServiceClient}. + */ abstract class StreamingIngestResponse { abstract Long getStatusCode(); + + abstract String getMessage(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorage.java similarity index 63% rename from src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java rename to src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorage.java index 5556b7205..242b5cc43 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorage.java @@ -1,16 +1,16 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; -import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; -import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; -import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; import static net.snowflake.ingest.utils.HttpUtil.generateProxyPropertiesForJDBC; import static net.snowflake.ingest.utils.Utils.getStackTrace; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; import java.io.ByteArrayInputStream; import java.io.File; @@ -19,13 +19,9 @@ import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; -import java.util.HashMap; -import java.util.Map; import java.util.Optional; import java.util.Properties; import java.util.concurrent.TimeUnit; -import java.util.function.Function; -import javax.annotation.Nullable; import net.snowflake.client.core.OCSPMode; import net.snowflake.client.jdbc.SnowflakeFileTransferAgent; import net.snowflake.client.jdbc.SnowflakeFileTransferConfig; @@ -33,20 +29,26 @@ import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.cloud.storage.StageInfo; import net.snowflake.client.jdbc.internal.apache.commons.io.FileUtils; -import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ObjectNode; -import net.snowflake.ingest.connection.IngestResponseException; -import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.Logging; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.Utils; -/** Handles uploading files to the Snowflake Streaming Ingest Stage */ -class StreamingIngestStage { +/** Handles uploading files to the Snowflake Streaming Ingest Storage */ +class StreamingIngestStorage { private static final ObjectMapper mapper = new ObjectMapper(); + + /** + * Object mapper for parsing the client/configure response to Jackson version the same as + * jdbc.internal.fasterxml.jackson. We need two different versions of ObjectMapper because {@link + * SnowflakeFileTransferAgent#getFileTransferMetadatas(net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode)} + * expects a different version of json object than {@link StreamingIngestResponse}. TODO: + * SNOW-1493470 Align Jackson version + */ + private static final net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper + parseConfigureResponseMapper = + new net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper(); + private static final long REFRESH_THRESHOLD_IN_MS = TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES); @@ -55,7 +57,7 @@ class StreamingIngestStage { private static final Duration refreshDuration = Duration.ofMinutes(58); private static Instant prevRefresh = Instant.EPOCH; - private static final Logging logger = new Logging(StreamingIngestStage.class); + private static final Logging logger = new Logging(StreamingIngestStorage.class); /** * Wrapper class containing SnowflakeFileTransferMetadata and the timestamp at which the metadata @@ -87,61 +89,61 @@ state to record unknown age. } private SnowflakeFileTransferMetadataWithAge fileTransferMetadataWithAge; - private final CloseableHttpClient httpClient; - private final RequestBuilder requestBuilder; - private final String role; + private final IStorageManager owningManager; + private final TLocation location; private final String clientName; - private String clientPrefix; - private Long deploymentId; private final int maxUploadRetries; // Proxy parameters that we set while calling the Snowflake JDBC to upload the streams private final Properties proxyProperties; - StreamingIngestStage( - boolean isTestMode, - String role, - CloseableHttpClient httpClient, - RequestBuilder requestBuilder, + /** + * Default constructor + * + * @param owningManager the storage manager owning this storage + * @param clientName The client name + * @param fileLocationInfo The file location information from open channel response + * @param location A reference to the target location + * @param maxUploadRetries The maximum number of retries to attempt + */ + StreamingIngestStorage( + IStorageManager owningManager, String clientName, + FileLocationInfo fileLocationInfo, + TLocation location, int maxUploadRetries) throws SnowflakeSQLException, IOException { - this.httpClient = httpClient; - this.role = role; - this.requestBuilder = requestBuilder; - this.clientName = clientName; - this.proxyProperties = generateProxyPropertiesForJDBC(); - this.maxUploadRetries = maxUploadRetries; - - if (!isTestMode) { - refreshSnowflakeMetadata(); - } + this( + owningManager, + clientName, + (SnowflakeFileTransferMetadataWithAge) null, + location, + maxUploadRetries); + createFileTransferMetadataWithAge(fileLocationInfo); } /** * Constructor for TESTING that takes SnowflakeFileTransferMetadataWithAge as input * - * @param isTestMode must be true - * @param role Snowflake role used by the Client - * @param httpClient http client reference - * @param requestBuilder request builder to build the HTTP request + * @param owningManager the storage manager owning this storage * @param clientName the client name * @param testMetadata SnowflakeFileTransferMetadataWithAge to test with + * @param location A reference to the target location + * @param maxUploadRetries the maximum number of retries to attempt */ - StreamingIngestStage( - boolean isTestMode, - String role, - CloseableHttpClient httpClient, - RequestBuilder requestBuilder, + StreamingIngestStorage( + IStorageManager owningManager, String clientName, SnowflakeFileTransferMetadataWithAge testMetadata, - int maxRetryCount) + TLocation location, + int maxUploadRetries) throws SnowflakeSQLException, IOException { - this(isTestMode, role, httpClient, requestBuilder, clientName, maxRetryCount); - if (!isTestMode) { - throw new SFException(ErrorCode.INTERNAL_ERROR); - } + this.owningManager = owningManager; + this.clientName = clientName; + this.maxUploadRetries = maxUploadRetries; + this.proxyProperties = generateProxyPropertiesForJDBC(); + this.location = location; this.fileTransferMetadataWithAge = testMetadata; } @@ -201,7 +203,7 @@ private void putRemote(String fullFilePath, byte[] data, int retryCount) .setUploadStream(inStream) .setRequireCompress(false) .setOcspMode(OCSPMode.FAIL_OPEN) - .setStreamingIngestClientKey(this.clientPrefix) + .setStreamingIngestClientKey(this.owningManager.getClientPrefix()) .setStreamingIngestClientName(this.clientName) .setProxyProperties(this.proxyProperties) .setDestFileName(fullFilePath) @@ -256,34 +258,27 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole return fileTransferMetadataWithAge; } - Map payload = new HashMap<>(); - payload.put("role", this.role); - Map response = this.makeClientConfigureCall(payload); + FileLocationInfo location = + this.owningManager.getRefreshedLocation(this.location, Optional.empty()); + return createFileTransferMetadataWithAge(location); + } - JsonNode responseNode = this.parseClientConfigureResponse(response, this.deploymentId); - // Do not change the prefix everytime we have to refresh credentials - if (Utils.isNullOrEmpty(this.clientPrefix)) { - this.deploymentId = - responseNode.has("deployment_id") ? responseNode.get("deployment_id").longValue() : null; - this.clientPrefix = createClientPrefix(responseNode); - } - Utils.assertStringNotNullOrEmpty("client prefix", this.clientPrefix); + private SnowflakeFileTransferMetadataWithAge createFileTransferMetadataWithAge( + FileLocationInfo fileLocationInfo) + throws JsonProcessingException, + net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException, + SnowflakeSQLException { + Utils.assertStringNotNullOrEmpty("client prefix", this.owningManager.getClientPrefix()); - if (responseNode - .get("data") - .get("stageInfo") - .get("locationType") - .toString() + if (fileLocationInfo + .getLocationType() .replaceAll( "^[\"]|[\"]$", "") // Replace the first and last character if they're double quotes .equals(StageInfo.StageType.LOCAL_FS.name())) { this.fileTransferMetadataWithAge = new SnowflakeFileTransferMetadataWithAge( - responseNode - .get("data") - .get("stageInfo") - .get("location") - .toString() + fileLocationInfo + .getLocation() .replaceAll( "^[\"]|[\"]$", ""), // Replace the first and last character if they're double quotes @@ -292,7 +287,9 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole this.fileTransferMetadataWithAge = new SnowflakeFileTransferMetadataWithAge( (SnowflakeFileTransferMetadataV1) - SnowflakeFileTransferAgent.getFileTransferMetadatas(responseNode).get(0), + SnowflakeFileTransferAgent.getFileTransferMetadatas( + parseFileLocationInfo(fileLocationInfo)) + .get(0), Optional.of(System.currentTimeMillis())); } @@ -300,22 +297,6 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole return this.fileTransferMetadataWithAge; } - /** - * Creates a client-specific prefix that will be also part of the files registered by this client. - * The prefix will include a server-side generated string and the GlobalID of the deployment the - * client is registering blobs to. The latter (deploymentId) is needed in order to guarantee that - * blob filenames are unique across deployments even with replication enabled. - * - * @param response the client/configure response from the server - * @return the client prefix. - */ - private String createClientPrefix(final JsonNode response) { - final String prefix = response.get("prefix").textValue(); - final String deploymentId = - response.has("deployment_id") ? "_" + response.get("deployment_id").longValue() : ""; - return prefix + deploymentId; - } - /** * GCS requires a signed url per file. We need to fetch this from the server for each put * @@ -325,86 +306,39 @@ private String createClientPrefix(final JsonNode response) { SnowflakeFileTransferMetadataV1 fetchSignedURL(String fileName) throws SnowflakeSQLException, IOException { - Map payload = new HashMap<>(); - payload.put("role", this.role); - payload.put("file_name", fileName); - Map response = this.makeClientConfigureCall(payload); - - JsonNode responseNode = this.parseClientConfigureResponse(response, this.deploymentId); + FileLocationInfo location = + this.owningManager.getRefreshedLocation(this.location, Optional.of(fileName)); SnowflakeFileTransferMetadataV1 metadata = (SnowflakeFileTransferMetadataV1) - SnowflakeFileTransferAgent.getFileTransferMetadatas(responseNode).get(0); + SnowflakeFileTransferAgent.getFileTransferMetadatas(parseFileLocationInfo(location)) + .get(0); // Transfer agent trims path for fileName metadata.setPresignedUrlFileName(fileName); return metadata; } - private static class MapStatusGetter implements Function { - public MapStatusGetter() {} - - public Long apply(T input) { - try { - return ((Integer) ((Map) input).get("status_code")).longValue(); - } catch (Exception e) { - throw new SFException(ErrorCode.INTERNAL_ERROR, "failed to get status_code from response"); - } - } - } - - private static final MapStatusGetter statusGetter = new MapStatusGetter(); - - private JsonNode parseClientConfigureResponse( - Map response, @Nullable Long expectedDeploymentId) { - JsonNode responseNode = mapper.valueToTree(response); + private net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode + parseFileLocationInfo(FileLocationInfo fileLocationInfo) + throws JsonProcessingException, + net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException { + JsonNode fileLocationInfoNode = mapper.valueToTree(fileLocationInfo); // Currently there are a few mismatches between the client/configure response and what // SnowflakeFileTransferAgent expects - ObjectNode mutable = (ObjectNode) responseNode; - mutable.putObject("data"); - ObjectNode dataNode = (ObjectNode) mutable.get("data"); - dataNode.set("stageInfo", responseNode.get("stage_location")); + + ObjectNode node = mapper.createObjectNode(); + node.putObject("data"); + ObjectNode dataNode = (ObjectNode) node.get("data"); + dataNode.set("stageInfo", fileLocationInfoNode); // JDBC expects this field which maps to presignedFileUrlName. We will set this later dataNode.putArray("src_locations").add("placeholder"); - if (expectedDeploymentId != null) { - Long actualDeploymentId = - responseNode.has("deployment_id") ? responseNode.get("deployment_id").longValue() : null; - if (actualDeploymentId != null && !actualDeploymentId.equals(expectedDeploymentId)) { - throw new SFException( - ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH, - expectedDeploymentId, - actualDeploymentId, - clientName); - } - } - return responseNode; - } - - private Map makeClientConfigureCall(Map payload) - throws IOException { - try { - Map response = - executeWithRetries( - Map.class, - CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), - "client configure", - STREAMING_CLIENT_CONFIGURE, - httpClient, - requestBuilder, - statusGetter); - - // Check for Snowflake specific response code - if (!response.get("status_code").equals((int) RESPONSE_SUCCESS)) { - throw new SFException( - ErrorCode.CLIENT_CONFIGURE_FAILURE, response.get("message").toString()); - } - return response; - } catch (IngestResponseException e) { - throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); - } + // use String as intermediate object to avoid Jackson version mismatch + // TODO: SNOW-1493470 Align Jackson version + String responseString = mapper.writeValueAsString(node); + return parseConfigureResponseMapper.readTree(responseString); } /** @@ -450,9 +384,4 @@ void putLocal(String fullFilePath, byte[] data) { throw new SFException(ex, ErrorCode.BLOB_UPLOAD_FAILURE); } } - - /** Get the server generated unique prefix for this client */ - String getClientPrefix() { - return this.clientPrefix; - } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java index 56e960064..538283b4e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.utils.Constants.MAX_STREAMING_INGEST_API_CHANNEL_RETRY; @@ -6,7 +10,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; -import java.util.Map; import java.util.function.Function; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpUriRequest; @@ -77,7 +80,7 @@ public static void sleepForRetry(int executionCount) { static T executeWithRetries( Class targetClass, String endpoint, - Map payload, + IStreamingIngestRequest payload, String message, ServiceResponseHandler.ApiName apiName, CloseableHttpClient httpClient, diff --git a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java index fdd8a713b..a9aab9c3b 100644 --- a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java +++ b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.utils; diff --git a/src/main/java/net/snowflake/ingest/utils/Utils.java b/src/main/java/net/snowflake/ingest/utils/Utils.java index a06df4027..5220625da 100644 --- a/src/main/java/net/snowflake/ingest/utils/Utils.java +++ b/src/main/java/net/snowflake/ingest/utils/Utils.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.utils; @@ -384,4 +384,31 @@ public static String getStackTrace(Throwable e) { } return stackTrace.toString(); } + + /** + * Get the fully qualified table name + * + * @param dbName the database name + * @param schemaName the schema name + * @param tableName the table name + * @return the fully qualified table name + */ + public static String getFullyQualifiedTableName( + String dbName, String schemaName, String tableName) { + return String.format("%s.%s.%s", dbName, schemaName, tableName); + } + + /** + * Get the fully qualified channel name + * + * @param dbName the database name + * @param schemaName the schema name + * @param tableName the table name + * @param channelName the channel name + * @return the fully qualified channel name + */ + public static String getFullyQualifiedChannelName( + String dbName, String schemaName, String tableName, String channelName) { + return String.format("%s.%s.%s.%s", dbName, schemaName, tableName, channelName); + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java index f200c7177..8ac1d2b85 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java @@ -77,22 +77,25 @@ private abstract static class TestContext implements AutoCloseable { ChannelCache channelCache; final Map> channels = new HashMap<>(); FlushService flushService; - StreamingIngestStage stage; + IStorageManager storageManager; + StreamingIngestStorage storage; ParameterProvider parameterProvider; RegisterService registerService; final List> channelData = new ArrayList<>(); TestContext() { - stage = Mockito.mock(StreamingIngestStage.class); - Mockito.when(stage.getClientPrefix()).thenReturn("client_prefix"); + storage = Mockito.mock(StreamingIngestStorage.class); parameterProvider = new ParameterProvider(); client = Mockito.mock(SnowflakeStreamingIngestClientInternal.class); Mockito.when(client.getParameterProvider()).thenReturn(parameterProvider); + storageManager = Mockito.spy(new InternalStageManager<>(true, "role", "client", null)); + Mockito.doReturn(storage).when(storageManager).getStorage(ArgumentMatchers.any()); + Mockito.when(storageManager.getClientPrefix()).thenReturn("client_prefix"); channelCache = new ChannelCache<>(); Mockito.when(client.getChannelCache()).thenReturn(channelCache); registerService = Mockito.spy(new RegisterService(client, client.isTestMode())); - flushService = Mockito.spy(new FlushService<>(client, channelCache, stage, true)); + flushService = Mockito.spy(new FlushService<>(client, channelCache, storageManager, true)); } ChannelData flushChannel(String name) { @@ -105,7 +108,10 @@ ChannelData flushChannel(String name) { BlobMetadata buildAndUpload() throws Exception { List>> blobData = Collections.singletonList(channelData); - return flushService.buildAndUpload("file_name", blobData); + return flushService.buildAndUpload( + "file_name", + blobData, + blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName()); } abstract SnowflakeStreamingIngestChannelInternal createChannel( @@ -389,10 +395,11 @@ private static ColumnMetadata createLargeTestTextColumn(String name) { @Test public void testGetFilePath() { TestContext testContext = testContextFactory.create(); - FlushService flushService = testContext.flushService; + IStorageManager storageManager = testContext.storageManager; Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); String clientPrefix = "honk"; - String outputString = flushService.getBlobPath(calendar, clientPrefix); + String outputString = + ((InternalStageManager) storageManager).getBlobPath(calendar, clientPrefix); Path outputPath = Paths.get(outputString); Assert.assertTrue(outputPath.getFileName().toString().contains(clientPrefix)); Assert.assertTrue( @@ -480,7 +487,8 @@ public void testBlobCreation() throws Exception { // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService, Mockito.atLeast(2)).buildAndUpload(Mockito.any(), Mockito.any()); + Mockito.verify(flushService, Mockito.atLeast(2)) + .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any()); } @Test @@ -529,7 +537,8 @@ public void testBlobSplitDueToDifferentSchema() throws Exception { // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService, Mockito.atLeast(2)).buildAndUpload(Mockito.any(), Mockito.any()); + Mockito.verify(flushService, Mockito.atLeast(2)) + .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any()); } @Test @@ -563,7 +572,8 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception { // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService, Mockito.times(2)).buildAndUpload(Mockito.any(), Mockito.any()); + Mockito.verify(flushService, Mockito.times(2)) + .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any()); } @Test @@ -603,7 +613,7 @@ public void runTestBlobSplitDueToNumberOfChunks(int numberOfRows) throws Excepti ArgumentCaptor>>>>> blobDataCaptor = ArgumentCaptor.forClass(List.class); Mockito.verify(flushService, Mockito.times(expectedBlobs)) - .buildAndUpload(Mockito.any(), blobDataCaptor.capture()); + .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any()); // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns List>>>>> allUploadedBlobs = @@ -646,7 +656,7 @@ public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Except ArgumentCaptor>>>>> blobDataCaptor = ArgumentCaptor.forClass(List.class); Mockito.verify(flushService, Mockito.atLeast(2)) - .buildAndUpload(Mockito.any(), blobDataCaptor.capture()); + .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any()); // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns List>>>>> allUploadedBlobs = @@ -764,12 +774,15 @@ public void testBuildAndUpload() throws Exception { .build(); // Check FlushService.upload called with correct arguments + final ArgumentCaptor storageCaptor = + ArgumentCaptor.forClass(StreamingIngestStorage.class); final ArgumentCaptor nameCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor blobCaptor = ArgumentCaptor.forClass(byte[].class); final ArgumentCaptor> metadataCaptor = ArgumentCaptor.forClass(List.class); Mockito.verify(testContext.flushService) .upload( + storageCaptor.capture(), nameCaptor.capture(), blobCaptor.capture(), metadataCaptor.capture(), @@ -912,10 +925,10 @@ public void testInvalidateChannels() { innerData.add(channel1Data); innerData.add(channel2Data); - StreamingIngestStage stage = Mockito.mock(StreamingIngestStage.class); - Mockito.when(stage.getClientPrefix()).thenReturn("client_prefix"); + IStorageManager storageManager = + Mockito.spy(new InternalStageManager<>(true, "role", "client", null)); FlushService flushService = - new FlushService<>(client, channelCache, stage, false); + new FlushService<>(client, channelCache, storageManager, false); flushService.invalidateAllChannelsInBlob(blobData, "Invalidated by test"); Assert.assertFalse(channel1.isValid()); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java index b4fa769a1..5d8d8d36a 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static java.time.ZoneOffset.UTC; @@ -259,7 +263,7 @@ public void testOpenChannelRequestCreationSuccess() { Assert.assertEquals( "STREAMINGINGEST_TEST.PUBLIC.T_STREAMINGINGEST", request.getFullyQualifiedTableName()); - Assert.assertFalse(request.isOffsetTokenProvided()); + Assert.assertNull(request.getOffsetToken()); } @Test @@ -276,7 +280,6 @@ public void testOpenChannelRequesCreationtWithOffsetToken() { Assert.assertEquals( "STREAMINGINGEST_TEST.PUBLIC.T_STREAMINGINGEST", request.getFullyQualifiedTableName()); Assert.assertEquals("TEST_TOKEN", request.getOffsetToken()); - Assert.assertTrue(request.isOffsetTokenProvided()); } @Test diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorageTest.java similarity index 83% rename from src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java rename to src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorageTest.java index 2458c5bd5..d4c3f0374 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorageTest.java @@ -1,6 +1,7 @@ package net.snowflake.ingest.streaming.internal; import static net.snowflake.client.core.Constants.CLOUD_STORAGE_CREDENTIALS_EXPIRED; +import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.HttpUtil.HTTP_PROXY_PASSWORD; import static net.snowflake.ingest.utils.HttpUtil.HTTP_PROXY_USER; import static net.snowflake.ingest.utils.HttpUtil.NON_PROXY_HOSTS; @@ -33,6 +34,7 @@ import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.cloud.storage.StageInfo; import net.snowflake.client.jdbc.internal.amazonaws.util.IOUtils; +import net.snowflake.client.jdbc.internal.apache.http.HttpEntity; import net.snowflake.client.jdbc.internal.apache.http.StatusLine; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; import net.snowflake.client.jdbc.internal.apache.http.entity.BasicHttpEntity; @@ -42,9 +44,7 @@ import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; -import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; -import net.snowflake.ingest.utils.ParameterProvider; import net.snowflake.ingest.utils.SFException; import org.junit.Assert; import org.junit.Test; @@ -57,7 +57,7 @@ @RunWith(PowerMockRunner.class) @PrepareForTest({TestUtils.class, HttpUtil.class, SnowflakeFileTransferAgent.class}) -public class StreamingIngestStageTest { +public class StreamingIngestStorageTest { private final String prefix = "EXAMPLE_PREFIX"; @@ -130,15 +130,16 @@ public void testPutRemote() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - null, - null, + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = + new StreamingIngestStorage( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, 1); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); @@ -172,16 +173,14 @@ public void testPutLocal() throws Exception { String fullFilePath = "testOutput"; String fileName = "putLocalOutput"; - StreamingIngestStage stage = + StreamingIngestStorage stage = Mockito.spy( - new StreamingIngestStage( - true, - "role", - null, + new StreamingIngestStorage( null, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( fullFilePath, Optional.of(System.currentTimeMillis())), + null, 1)); Mockito.doReturn(true).when(stage).isLocalFS(); @@ -202,15 +201,16 @@ public void doTestPutRemoteRefreshes() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - null, - null, + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, maxUploadRetryCount); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeSQLException e = @@ -256,16 +256,17 @@ public void testPutRemoteGCS() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = Mockito.spy( - new StreamingIngestStage( - true, - "role", - null, - null, + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, 1)); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeFileTransferMetadataV1 metaMock = Mockito.mock(SnowflakeFileTransferMetadataV1.class); @@ -281,22 +282,30 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + SnowflakeStreamingIngestClientInternal mockClientInternal = + Mockito.mock(SnowflakeStreamingIngestClientInternal.class); + Mockito.when(mockClientInternal.getRole()).thenReturn("role"); StatusLine mockStatusLine = Mockito.mock(StatusLine.class); Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); - BasicHttpEntity entity = new BasicHttpEntity(); - entity.setContent( - new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); - Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); - Mockito.when(mockResponse.getEntity()).thenReturn(entity); + Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse)); Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); - ParameterProvider parameterProvider = new ParameterProvider(); - StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "client", snowflakeServiceClient); + + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, + "clientName", + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + null, + 1); - StreamingIngestStage.SnowflakeFileTransferMetadataWithAge metadataWithAge = + StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge metadataWithAge = stage.refreshSnowflakeMetadata(true); final ArgumentCaptor endpointCaptor = ArgumentCaptor.forClass(String.class); @@ -304,7 +313,7 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { Mockito.verify(mockBuilder) .generateStreamingIngestPostRequest( stringCaptor.capture(), endpointCaptor.capture(), Mockito.any()); - Assert.assertEquals(Constants.CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); + Assert.assertEquals(CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); Assert.assertTrue(metadataWithAge.timestamp.isPresent()); Assert.assertEquals( StageInfo.StageType.S3, metadataWithAge.fileTransferMetadata.getStageInfo().getStageType()); @@ -315,7 +324,7 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { Assert.assertEquals( Paths.get("placeholder").toAbsolutePath(), Paths.get(metadataWithAge.fileTransferMetadata.getPresignedUrlFileName()).toAbsolutePath()); - Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix()); + Assert.assertEquals(prefix + "_" + deploymentId, storageManager.getClientPrefix()); } @Test @@ -342,16 +351,18 @@ public void testRefreshSnowflakeMetadataDeploymentIdMismatch() throws Exception .thenReturn(mockResponse) .thenReturn(mockResponse); - StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "clientName", snowflakeServiceClient); - StreamingIngestStage.SnowflakeFileTransferMetadataWithAge metadataWithAge = - stage.refreshSnowflakeMetadata(true); + StreamingIngestStorage storage = storageManager.getStorage(""); + storage.refreshSnowflakeMetadata(true); - Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix()); + Assert.assertEquals(prefix + "_" + deploymentId, storageManager.getClientPrefix()); SFException exception = - Assert.assertThrows(SFException.class, () -> stage.refreshSnowflakeMetadata(true)); + Assert.assertThrows(SFException.class, () -> storage.refreshSnowflakeMetadata(true)); Assert.assertEquals( ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH.getMessageCode(), exception.getVendorCode()); Assert.assertEquals( @@ -369,19 +380,27 @@ public void testFetchSignedURL() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + SnowflakeStreamingIngestClientInternal mockClientInternal = + Mockito.mock(SnowflakeStreamingIngestClientInternal.class); + Mockito.when(mockClientInternal.getRole()).thenReturn("role"); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "client", snowflakeServiceClient); StatusLine mockStatusLine = Mockito.mock(StatusLine.class); Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); - BasicHttpEntity entity = new BasicHttpEntity(); - entity.setContent( - new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); - Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); - Mockito.when(mockResponse.getEntity()).thenReturn(entity); + Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse)); Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); - StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); + StreamingIngestStorage stage = + new StreamingIngestStorage( + storageManager, + "clientName", + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + null, + 1); SnowflakeFileTransferMetadataV1 metadata = stage.fetchSignedURL("path/fileName"); @@ -390,7 +409,7 @@ public void testFetchSignedURL() throws Exception { Mockito.verify(mockBuilder) .generateStreamingIngestPostRequest( stringCaptor.capture(), endpointCaptor.capture(), Mockito.any()); - Assert.assertEquals(Constants.CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); + Assert.assertEquals(CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); Assert.assertEquals(StageInfo.StageType.S3, metadata.getStageInfo().getStageType()); Assert.assertEquals("foo/streaming_ingest/", metadata.getStageInfo().getLocation()); Assert.assertEquals("path/fileName", metadata.getPresignedUrlFileName()); @@ -407,26 +426,26 @@ public void testRefreshSnowflakeMetadataSynchronized() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + SnowflakeStreamingIngestClientInternal mockClientInternal = + Mockito.mock(SnowflakeStreamingIngestClientInternal.class); + Mockito.when(mockClientInternal.getRole()).thenReturn("role"); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "client", snowflakeServiceClient); StatusLine mockStatusLine = Mockito.mock(StatusLine.class); Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); - BasicHttpEntity entity = new BasicHttpEntity(); - entity.setContent( - new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); - Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); - Mockito.when(mockResponse.getEntity()).thenReturn(entity); + Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse)); Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - mockClient, - mockBuilder, + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - originalMetadata, Optional.of(0L)), + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + null, 1); ThreadFactory buildUploadThreadFactory = @@ -555,15 +574,16 @@ public void testRefreshMetadataOnFirstPutException() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - null, - null, + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, maxUploadRetryCount); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeSQLException e = @@ -608,4 +628,10 @@ public Object answer(org.mockito.invocation.InvocationOnMock invocation) InputStream capturedInput = capturedConfig.getUploadStream(); Assert.assertEquals("Hello Upload", IOUtils.toString(capturedInput)); } + + private HttpEntity createHttpEntity(String content) { + BasicHttpEntity entity = new BasicHttpEntity(); + entity.setContent(new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8))); + return entity; + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java index 4e054c209..b40cdad82 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; @@ -5,9 +9,6 @@ import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.util.HashMap; -import java.util.Map; import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.IngestResponseException; @@ -53,11 +54,11 @@ public void testJWTRetries() throws Exception { "testJWTRetries")); // build payload - Map payload = new HashMap<>(); - if (!TestUtils.getRole().isEmpty() && !TestUtils.getRole().equals("DEFAULT_ROLE")) { - payload.put("role", TestUtils.getRole()); - } - ObjectMapper mapper = new ObjectMapper(); + ClientConfigureRequest request = + new ClientConfigureRequest( + !TestUtils.getRole().isEmpty() && !TestUtils.getRole().equals("DEFAULT_ROLE") + ? TestUtils.getRole() + : null); // request wih invalid token, should get 401 3 times PowerMockito.doReturn("invalid_token").when(spyManager).getToken(); @@ -66,7 +67,7 @@ public void testJWTRetries() throws Exception { executeWithRetries( ChannelsStatusResponse.class, CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), + request, "client configure", STREAMING_CLIENT_CONFIGURE, httpClient, @@ -84,7 +85,7 @@ public void testJWTRetries() throws Exception { executeWithRetries( ChannelsStatusResponse.class, CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), + request, "client configure", STREAMING_CLIENT_CONFIGURE, httpClient, @@ -101,7 +102,7 @@ public void testJWTRetries() throws Exception { executeWithRetries( ChannelsStatusResponse.class, CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), + request, "client configure", STREAMING_CLIENT_CONFIGURE, httpClient,