From d9b0bebceb7c69942e0be434f1fce816e33a88ae Mon Sep 17 00:00:00 2001
From: Jay Patel
Date: Wed, 1 Mar 2023 01:06:35 -0800
Subject: [PATCH] NO-SNOW Revert one client change, JDBC upgrade and config for arrow and parquet (#555)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Revert "no-snow Add connector name in KC client name for better debugging the client … (#547)"

This reverts commit f7df64df6a3288664d38f4f2488e4abbb994e669.

* Revert "Refactor sink task integration tests to be more readable (#543)"

This reverts commit 1b24187c864322549e60963db58bc75bb8ca3924.

* Ignore flaky test

* Revert "Fix client off by one issue (#540)"

This reverts commit a5e9b456b940b759cf8f6d30f1323107571fa7bf.

* Revert "[SNOW-692657] Create one client per connector instead of one client per task (#528)"

This reverts commit 0b4b1cb234ee94dc3627555f63a140c723f871ee.

* Revert "SNOW-740327 Create config for using Arrow BDEC file format + Bump Client SDK 1.0.3-beta (#541)"

This reverts commit c5f8aed1a3474fe2cc46ba9fb41faec29341b3d9.

* Revert "[SNOW-726924] Add New jvm.nonProxy.hosts Parameter and update JDBC to 3.13.23 (#533)"

This reverts commit d2c35614f940f2b20fb345ecadcb3918c610f637.

* Revert "Modify pom.xml to use jdbc defined in connector and not from ingest sdk (#546)"

This reverts commit 79051a16813a34825a8454ae7d6291837b118c50.
---
 pom.xml                                       |  20 +-
 pom_confluent.xml                             |  20 +-
 .../connector/SnowflakeSinkConnector.java     |  35 +-
 .../SnowflakeSinkConnectorConfig.java         |  31 +-
 .../com/snowflake/kafka/connector/Utils.java  |  26 +-
 .../connector/internal/InternalUtils.java     |   7 -
 .../internal/SnowflakeConnectionService.java  |   7 -
 .../SnowflakeConnectionServiceV1.java         |   5 -
 .../connector/internal/SnowflakeErrors.java   |   7 -
 .../internal/ingestsdk/IngestSdkProvider.java |  42 --
 .../ingestsdk/KcStreamingIngestClient.java    | 186 -----
 .../ingestsdk/StreamingClientManager.java     | 202 ------
 .../streaming/SnowflakeSinkServiceV2.java     |  60 ++
 .../streaming/TopicPartitionChannel.java      |  73 +-
 .../kafka/connector/ConnectorConfigTest.java  |  55 --
 .../SnowflakeSinkTaskForStreamingIT.java      | 306 ++++++++
 .../SnowflakeSinkTaskTestForStreamingIT.java  | 256 -------
 .../kafka/connector/internal/TestUtils.java   |  73 +-
 .../KcStreamingIngestClientTest.java          | 201 -----
 .../ingestsdk/StreamingClientManagerTest.java | 182 -----
 .../streaming/SnowflakeSinkServiceV2IT.java   |  96 ++-
 .../streaming/SnowflakeSinkTaskTest.java      | 146 ----
 .../streaming/TopicPartitionChannelIT.java    | 109 ++-
 .../streaming/TopicPartitionChannelTest.java  | 686 +++++++-----------
 24 files changed, 837 insertions(+), 1994 deletions(-)
 delete mode 100644 src/main/java/com/snowflake/kafka/connector/internal/ingestsdk/IngestSdkProvider.java
 delete mode 100644 src/main/java/com/snowflake/kafka/connector/internal/ingestsdk/KcStreamingIngestClient.java
 delete mode 100644 src/main/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManager.java
 create mode 100644 src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskForStreamingIT.java
 delete mode 100644 src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskTestForStreamingIT.java
 delete mode 100644 src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/KcStreamingIngestClientTest.java
 delete mode 100644 src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManagerTest.java
 delete mode 100644 src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkTaskTest.java

diff --git a/pom.xml b/pom.xml index
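Note on the two pom hunks that follow: the Maven XML element tags were stripped in this rendering, leaving only the text content (group IDs, artifact IDs, versions). Read loosely, the change drops the explicit snowflake-jdbc 3.13.23 dependency (together with what appears to be an exclusions entry for snowflake-jdbc under snowflake-ingest-sdk 1.1.0) and restores an explicit snowflake-jdbc 3.13.14 dependency next to io.dropwizard.metrics. The sketch below is a hedged reconstruction: element names, nesting, and placement are assumptions based on standard Maven layout; only the coordinates and versions come from the diff text.

  <!-- removed by this revert: explicit JDBC pin at 3.13.23 -->
  <dependency>
    <groupId>net.snowflake</groupId>
    <artifactId>snowflake-jdbc</artifactId>
    <version>3.13.23</version>
  </dependency>

  <!-- restored by this revert: explicit JDBC pin at 3.13.14 -->
  <dependency>
    <groupId>net.snowflake</groupId>
    <artifactId>snowflake-jdbc</artifactId>
    <version>3.13.14</version>
  </dependency>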
3f03651e4..4177f9321 100644 --- a/pom.xml +++ b/pom.xml @@ -313,24 +313,11 @@ - - - net.snowflake - snowflake-jdbc - 3.13.23 - - net.snowflake snowflake-ingest-sdk 1.1.0 - - - net.snowflake - snowflake-jdbc - - @@ -408,6 +395,13 @@ 7.2.1 + + + net.snowflake + snowflake-jdbc + 3.13.14 + + io.dropwizard.metrics diff --git a/pom_confluent.xml b/pom_confluent.xml index 5bd4e0fad..b62af2921 100644 --- a/pom_confluent.xml +++ b/pom_confluent.xml @@ -364,24 +364,11 @@ - - - net.snowflake - snowflake-jdbc - 3.13.23 - - net.snowflake snowflake-ingest-sdk 1.1.0 - - - net.snowflake - snowflake-jdbc - - @@ -459,6 +446,13 @@ 7.2.1 + + + net.snowflake + snowflake-jdbc + 3.13.14 + + io.dropwizard.metrics diff --git a/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnector.java b/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnector.java index 0a6641d4f..f0337211f 100644 --- a/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnector.java +++ b/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnector.java @@ -21,8 +21,6 @@ import com.snowflake.kafka.connector.internal.SnowflakeConnectionServiceFactory; import com.snowflake.kafka.connector.internal.SnowflakeErrors; import com.snowflake.kafka.connector.internal.SnowflakeKafkaConnectorException; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; -import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig; import com.snowflake.kafka.connector.internal.telemetry.SnowflakeTelemetryService; import java.util.ArrayList; import java.util.HashMap; @@ -44,15 +42,12 @@ * running on Kafka Connect Workers. */ public class SnowflakeSinkConnector extends SinkConnector { - // TEMPORARY config of num tasks assigned per client, round up if number is not divisible - // currently set to 1 for a 1:1 task to client ratio, so we can maintain the current functionality - private static final int NUM_TASK_TO_CLIENT = 1; - // create logger without correlationId for now private static LoggerHandler LOGGER = new LoggerHandler(SnowflakeSinkConnector.class.getName()); private Map config; // connector configuration, provided by // user through kafka connect framework + private String connectorName; // unique name of this connector instance // SnowflakeJDBCWrapper provides methods to interact with user's snowflake // account and executes queries @@ -69,12 +64,6 @@ public class SnowflakeSinkConnector extends SinkConnector { // Using setupComplete to synchronize private boolean setupComplete; - // The id of this connector instance. Should only be reset on start - private String kcInstanceId; - - // If this connector is configured to use streaming snowpipe ingestion - private boolean usesStreamingIngestion; - /** No-Arg constructor. 
Required by Kafka Connect framework */ public SnowflakeSinkConnector() { setupComplete = false; @@ -100,8 +89,7 @@ public void start(final Map parsedConfig) { connectorStartTime = System.currentTimeMillis(); // initialize logging with global instance Id - this.kcInstanceId = this.getKcInstanceId(this.connectorStartTime); - LoggerHandler.setConnectGlobalInstanceId(kcInstanceId); + LoggerHandler.setConnectGlobalInstanceId(this.getKcInstanceId(this.connectorStartTime)); config = new HashMap<>(parsedConfig); @@ -119,14 +107,6 @@ public void start(final Map parsedConfig) { // config as a side effect conn = SnowflakeConnectionServiceFactory.builder().setProperties(config).build(); - // check if we are using snowpipe streaming ingestion - this.usesStreamingIngestion = - config != null - && config.get(SnowflakeSinkConnectorConfig.INGESTION_METHOD_OPT) != null - && config - .get(SnowflakeSinkConnectorConfig.INGESTION_METHOD_OPT) - .equalsIgnoreCase(IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); - telemetryClient = conn.getTelemetryClient(); telemetryClient.reportKafkaConnectStart(connectorStartTime, this.config); @@ -147,9 +127,6 @@ public void stop() { // set task logging to default SnowflakeSinkTask.setTotalTaskCreationCount(-1); setupComplete = false; - - IngestSdkProvider.getStreamingClientManager().closeAllStreamingClients(); - LOGGER.info("SnowflakeSinkConnector:stop"); telemetryClient.reportKafkaConnectStop(connectorStartTime); } @@ -179,13 +156,6 @@ public Class taskClass() { */ @Override public List> taskConfigs(final int maxTasks) { - // create all necessary clients, evenly mapping tasks to clients - // must be done here instead of start() because we need the maxTasks value - if (this.usesStreamingIngestion) { - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(config, kcInstanceId, maxTasks, NUM_TASK_TO_CLIENT); - } - // wait for setup to complete int counter = 0; while (counter < 120) // poll for 120*5 seconds (10 mins) maximum @@ -206,7 +176,6 @@ public List> taskConfigs(final int maxTasks) { throw SnowflakeErrors.ERROR_5007.getException(telemetryClient); } - // taskIds must be consecutive, the StreamingClientManager relies on this List> taskConfigs = new ArrayList<>(maxTasks); for (int i = 0; i < maxTasks; i++) { Map conf = new HashMap<>(config); diff --git a/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnectorConfig.java b/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnectorConfig.java index 19b091637..4498c7d8a 100644 --- a/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnectorConfig.java +++ b/src/main/java/com/snowflake/kafka/connector/SnowflakeSinkConnectorConfig.java @@ -78,7 +78,6 @@ public class SnowflakeSinkConnectorConfig { private static final String PROXY_INFO = "Proxy Info"; public static final String JVM_PROXY_HOST = "jvm.proxy.host"; public static final String JVM_PROXY_PORT = "jvm.proxy.port"; - public static final String JVM_NON_PROXY_HOSTS = "jvm.nonProxy.hosts"; public static final String JVM_PROXY_USERNAME = "jvm.proxy.username"; public static final String JVM_PROXY_PASSWORD = "jvm.proxy.password"; @@ -112,10 +111,6 @@ public class SnowflakeSinkConnectorConfig { public static final String INGESTION_METHOD_DEFAULT_SNOWPIPE = IngestionMethodConfig.SNOWPIPE.toString(); - // This is the streaming bdec file version which can be defined in config - // NOTE: Please do not override this value unless recommended from snowflake - public static final String SNOWPIPE_STREAMING_FILE_VERSION = 
"snowflake.streaming.file.version"; - // TESTING public static final String REBALANCING = "snowflake.test.rebalancing"; public static final boolean REBALANCING_DEFAULT = false; @@ -307,16 +302,6 @@ static ConfigDef newConfigDef() { 1, ConfigDef.Width.NONE, JVM_PROXY_PORT) - .define( - JVM_NON_PROXY_HOSTS, - Type.STRING, - "", - Importance.LOW, - "JVM option: http.nonProxyHosts", - PROXY_INFO, - 2, - ConfigDef.Width.NONE, - JVM_NON_PROXY_HOSTS) .define( JVM_PROXY_USERNAME, Type.STRING, @@ -324,7 +309,7 @@ static ConfigDef newConfigDef() { Importance.LOW, "JVM proxy username", PROXY_INFO, - 3, + 2, ConfigDef.Width.NONE, JVM_PROXY_USERNAME) .define( @@ -334,7 +319,7 @@ static ConfigDef newConfigDef() { Importance.LOW, "JVM proxy password", PROXY_INFO, - 4, + 3, ConfigDef.Width.NONE, JVM_PROXY_PASSWORD) // Connector Config @@ -478,18 +463,6 @@ static ConfigDef newConfigDef() { 5, ConfigDef.Width.NONE, INGESTION_METHOD_OPT) - .define( - SNOWPIPE_STREAMING_FILE_VERSION, - Type.STRING, - "", // default is handled in Ingest SDK - null, // no validator - Importance.LOW, - "Acceptable values for Snowpipe Streaming BDEC Versions: 1 and 3. Check Ingest" - + " SDK for default behavior. Please do not set this unless Absolutely needed. ", - CONNECTOR_CONFIG, - 6, - ConfigDef.Width.NONE, - SNOWPIPE_STREAMING_FILE_VERSION) .define( ERRORS_TOLERANCE_CONFIG, Type.STRING, diff --git a/src/main/java/com/snowflake/kafka/connector/Utils.java b/src/main/java/com/snowflake/kafka/connector/Utils.java index d9b4bf570..3f0c73ea6 100644 --- a/src/main/java/com/snowflake/kafka/connector/Utils.java +++ b/src/main/java/com/snowflake/kafka/connector/Utils.java @@ -81,7 +81,6 @@ public class Utils { public static final String HTTPS_PROXY_PORT = "https.proxyPort"; public static final String HTTP_PROXY_HOST = "http.proxyHost"; public static final String HTTP_PROXY_PORT = "http.proxyPort"; - public static final String HTTP_NON_PROXY_HOSTS = "http.nonProxyHosts"; public static final String JDK_HTTP_AUTH_TUNNELING = "jdk.http.auth.tunneling.disabledSchemes"; public static final String HTTPS_PROXY_USER = "https.proxyUser"; @@ -225,7 +224,6 @@ static void validateProxySetting(Map config) { String port = SnowflakeSinkConnectorConfig.getProperty( config, SnowflakeSinkConnectorConfig.JVM_PROXY_PORT); - // either both host and port are provided or none of them are provided if (host != null ^ port != null) { throw SnowflakeErrors.ERROR_0022.getException( @@ -264,12 +262,8 @@ static boolean enableJVMProxy(Map config) { String port = SnowflakeSinkConnectorConfig.getProperty( config, SnowflakeSinkConnectorConfig.JVM_PROXY_PORT); - String nonProxyHosts = - SnowflakeSinkConnectorConfig.getProperty( - config, SnowflakeSinkConnectorConfig.JVM_NON_PROXY_HOSTS); if (host != null && port != null) { - LOGGER.info( - "enable jvm proxy: {}:{} and bypass proxy for hosts: {}", host, port, nonProxyHosts); + LOGGER.info("enable jvm proxy: {}:{}", host, port); // enable https proxy System.setProperty(HTTP_USE_PROXY, "true"); @@ -278,17 +272,6 @@ static boolean enableJVMProxy(Map config) { System.setProperty(HTTPS_PROXY_HOST, host); System.setProperty(HTTPS_PROXY_PORT, port); - // If the user provided the jvm.nonProxy.hosts configuration then we - // will append that to the list provided by the JVM argument - // -Dhttp.nonProxyHosts and not override it altogether, if it exists. - if (nonProxyHosts != null) { - nonProxyHosts = - (System.getProperty(HTTP_NON_PROXY_HOSTS) != null) - ? 
System.getProperty(HTTP_NON_PROXY_HOSTS) + "|" + nonProxyHosts - : nonProxyHosts; - System.setProperty(HTTP_NON_PROXY_HOSTS, nonProxyHosts); - } - // set username and password String username = SnowflakeSinkConnectorConfig.getProperty( @@ -382,13 +365,6 @@ static String validateConfig(Map config) { "Schematization is only available with {}.", IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); } - if (config.containsKey(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION)) { - configIsValid = false; - LOGGER.error( - "{} is only available with ingestion type: {}.", - SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, - IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); - } } if (config.containsKey(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MAP) diff --git a/src/main/java/com/snowflake/kafka/connector/internal/InternalUtils.java b/src/main/java/com/snowflake/kafka/connector/internal/InternalUtils.java index 9b936ec87..fcd40b1d6 100644 --- a/src/main/java/com/snowflake/kafka/connector/internal/InternalUtils.java +++ b/src/main/java/com/snowflake/kafka/connector/internal/InternalUtils.java @@ -202,13 +202,6 @@ protected static Properties generateProxyParametersIfRequired(MapAny exceptions will be passed up, a sfexception will be converted to a connectexception - * - * @param streamingClientProps the properties for the client - * @param parameterOverrides Helps to override any default parameters in streaming client - * @param clientName the client name to uniquely identify the client - */ - protected KcStreamingIngestClient( - Properties streamingClientProps, Map parameterOverrides, String clientName) { - try { - LOGGER.info("Creating Streaming Client: {}", clientName); - this.client = - SnowflakeStreamingIngestClientFactory.builder(clientName) - .setProperties(streamingClientProps) - .setParameterOverrides(parameterOverrides) - .build(); - - assert this.client != null; // client is final, so never need to do another null check - assert this.client.getName().equals(clientName); - } catch (SFException ex) { - throw new ConnectException(ex); - } - } - - /** - * Creates an ingest sdk OpenChannelRequest and opens the client's channel - * - *

No exception handling done, all exceptions will be passed through - * - * @param channelName the name of the channel to open - * @param config config to get the database and schema names for the channel - * @param tableName table name of the channel - * @return the opened channel - */ - public SnowflakeStreamingIngestChannel openChannel( - String channelName, Map config, String tableName) { - OpenChannelRequest channelRequest = - OpenChannelRequest.builder(channelName) - .setDBName(config.get(Utils.SF_DATABASE)) - .setSchemaName(config.get(Utils.SF_SCHEMA)) - .setTableName(tableName) - .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE) - .build(); - LOGGER.info("Opening a channel with name:{} for table name:{}", channelName, tableName); - - return this.client.openChannel(channelRequest); - } - - /** - * Calls the ingest sdk to close the client sdk - * - *

Swallows all exceptions and returns t/f if the client was closed because closing is best - * effort - * - * @return if the client was successfully closed - */ - public boolean close() { - if (this.client.isClosed()) { - LOGGER.debug("Streaming client is already closed"); - return true; - } - - LOGGER.info("Closing Streaming Client:{}", this.client.getName()); - - try { - this.client.close(); - return true; - } catch (Exception e) { - String message = - e.getMessage() != null && !e.getMessage().isEmpty() - ? e.getMessage() - : "no error message provided"; - - String cause = - e.getCause() != null - && e.getCause().getStackTrace() != null - && !Arrays.toString(e.getCause().getStackTrace()).isEmpty() - ? Arrays.toString(e.getCause().getStackTrace()) - : "no cause provided"; - - // don't throw an exception because closing the client is best effort - // the actual close, not in the catch or finally - LOGGER.error("Failure closing Streaming client msg:{}, cause:{}", message, cause); - return false; - } - } - - /** - * Checks if the current client is closed - * - * @return if the client is closed - */ - public boolean isClosed() { - return this.client.isClosed(); - } - - /** - * Returns the clients name. We treat this as the id - * - * @return the clients name - */ - public String getName() { - return this.client.getName(); - } - - /** - * Equality between clients is verified by the client name and the state (if it is closed or not) - * - * @param o Other object to check equality - * @return If the given object is the same - */ - @Override - public boolean equals(Object o) { - if (!(o instanceof KcStreamingIngestClient)) { - return false; - } - - KcStreamingIngestClient otherClient = (KcStreamingIngestClient) o; - return otherClient.getName().equals(this.getName()) - && otherClient.isClosed() == this.isClosed(); - } -} diff --git a/src/main/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManager.java b/src/main/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManager.java deleted file mode 100644 index cf0ebdb1a..000000000 --- a/src/main/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManager.java +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Copyright (c) 2023 Snowflake Inc. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package com.snowflake.kafka.connector.internal.ingestsdk; - -import static net.snowflake.ingest.utils.ParameterProvider.BLOB_FORMAT_VERSION; - -import com.google.common.annotations.VisibleForTesting; -import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; -import com.snowflake.kafka.connector.Utils; -import com.snowflake.kafka.connector.internal.LoggerHandler; -import com.snowflake.kafka.connector.internal.SnowflakeErrors; -import com.snowflake.kafka.connector.internal.streaming.StreamingUtils; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.Properties; - -/** - * Provides access to the streaming ingest clients. 
This should be the only place to manage clients. - */ -public class StreamingClientManager { - private LoggerHandler LOGGER; - - private Map taskToClientMap; - private int maxTasks; - private final int minTasks = 0; - private int clientId; // this should only ever increase - - // TESTING ONLY - inject the client map - @VisibleForTesting - public StreamingClientManager(Map taskToClientMap) { - this(); - this.taskToClientMap = taskToClientMap; - this.clientId = (int) taskToClientMap.values().stream().distinct().count() - 1; - } - - /** Creates a new client manager */ - protected StreamingClientManager() { - LOGGER = new LoggerHandler(this.getClass().getName()); - this.taskToClientMap = new HashMap<>(); - this.maxTasks = 0; - this.clientId = -1; // will be incremented when a client is created - } - - /** Gets the task to client map associated with StreamingClientManager */ - @VisibleForTesting - public Map getTaskToClientMap() { - return taskToClientMap; - } - - /** - * Creates as many clients as needed with the connector config and kc instance id. This assumes - * that all taskIds are consecutive ranging from 0 to maxTasks. - * - * @param connectorConfig the config for the clients, cannot be null - * @param kcInstanceId the kafka connector id requesting the clients, cannot be null - * @param maxTasks the max number of tasks assigned to this connector, must be greater than 0 - * @param numTasksPerClient the max number of tasks to be assigned to each client, must be greater - * than 0 - */ - public void createAllStreamingClients( - Map connectorConfig, - String kcInstanceId, - int maxTasks, - int numTasksPerClient) { - assert connectorConfig != null && kcInstanceId != null && maxTasks > 0 && numTasksPerClient > 0; - - this.maxTasks = maxTasks; - - int clientCount = (int) Math.ceil((double) maxTasks / (double) numTasksPerClient); - - Properties clientProperties = new Properties(); - clientProperties.putAll( - StreamingUtils.convertConfigForStreamingClient(new HashMap<>(connectorConfig))); - - // Override only if bdec version is explicitly set in config, default to the version set inside - // Ingest SDK - Map parameterOverrides = new HashMap<>(); - Optional snowpipeStreamingBdecVersion = - Optional.ofNullable( - connectorConfig.get(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION)); - - snowpipeStreamingBdecVersion.ifPresent( - overriddenValue -> parameterOverrides.put(BLOB_FORMAT_VERSION, overriddenValue)); - - LOGGER.info( - "Creating {} clients for {} tasks with max {} tasks per client using {} file format", - clientCount, - maxTasks, - numTasksPerClient, - snowpipeStreamingBdecVersion); - - // put a new client for every tasksToCurrClient taskIds - int tasksToCurrClient = 0; - KcStreamingIngestClient createdClient = - this.getClientHelper( - clientProperties, - parameterOverrides, - connectorConfig.get(Utils.NAME), - kcInstanceId, - 0); // asserted that we have at least 1 task - - for (int taskId = 0; taskId < this.maxTasks; taskId++) { - if (tasksToCurrClient == numTasksPerClient) { - createdClient = - this.getClientHelper( - clientProperties, - parameterOverrides, - connectorConfig.get(Utils.NAME), - kcInstanceId, - taskId); - tasksToCurrClient = 1; - } else { - tasksToCurrClient++; - } - - this.taskToClientMap.put(taskId, createdClient); - } - } - - // builds the client name and returns the created client. 
note taskId is used just for logging - private KcStreamingIngestClient getClientHelper( - Properties props, - Map parameterOverrides, - String connectorName, - String kcInstanceId, - int taskId) { - this.clientId++; - String clientName = - KcStreamingIngestClient.buildStreamingIngestClientName( - connectorName, kcInstanceId, this.clientId); - LOGGER.debug("Creating client {} for taskid {}", clientName, taskId); - - return new KcStreamingIngestClient(props, parameterOverrides, clientName); - } - - /** - * Gets the client corresponding to the task id and validates it (not null and is closed) - * - * @param taskId the task id to get the corresponding client - * @return The streaming client, throws an exception if no client was initialized - */ - public KcStreamingIngestClient getValidClient(int taskId) { - if (taskId > this.maxTasks || taskId < this.minTasks) { - throw SnowflakeErrors.ERROR_3010.getException( - Utils.formatString( - "taskId must be between 0 and {} but was given {}", this.maxTasks, taskId)); - } - - if (this.clientId < 0) { - throw SnowflakeErrors.ERROR_3009.getException("call the manager to create the clients"); - } - - KcStreamingIngestClient client = this.taskToClientMap.get(taskId); - if (client == null || client.isClosed()) { - throw SnowflakeErrors.ERROR_3009.getException(); - } - - return client; - } - - /** - * Gets the number of clients created - * - * @return the number of clients created - */ - public int getClientCount() { - return this.clientId + 1; // clientid starts at 0, so off by one - } - - /** - * Closes all the streaming clients in the map. Client closure exceptions will be swallowed and - * logged - * - * @return if all the clients were closed - */ - public boolean closeAllStreamingClients() { - boolean isAllClosed = true; - LOGGER.info("Closing all clients"); - - for (Integer taskId : this.taskToClientMap.keySet()) { - KcStreamingIngestClient client = this.taskToClientMap.get(taskId); - isAllClosed &= client.close(); - } - - return isAllClosed; - } -} diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java index df9def7d3..52fbbb2e8 100644 --- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java +++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java @@ -16,11 +16,17 @@ import com.snowflake.kafka.connector.internal.telemetry.SnowflakeTelemetryService; import com.snowflake.kafka.connector.records.RecordService; import com.snowflake.kafka.connector.records.SnowflakeMetadataConfig; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.Properties; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory; +import net.snowflake.ingest.utils.SFException; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.errors.ConnectException; import org.apache.kafka.connect.sink.SinkRecord; import org.apache.kafka.connect.sink.SinkTaskContext; @@ -44,6 +50,8 @@ public class SnowflakeSinkServiceV2 implements SnowflakeSinkService { private static final LoggerHandler LOGGER = new LoggerHandler(SnowflakeSinkServiceV2.class.getName()); + private static String STREAMING_CLIENT_PREFIX_NAME = "KC_CLIENT_"; + // Assume next three values are a threshold 
after which we will call insertRows API // Set in config (Time based flush) in seconds private long flushTimeSeconds; @@ -82,11 +90,16 @@ public class SnowflakeSinkServiceV2 implements SnowflakeSinkService { private SinkTaskContext sinkTaskContext; // ------ Streaming Ingest ------ // + // needs url, username. p8 key, role name + private SnowflakeStreamingIngestClient streamingIngestClient; + // Config set in JSON private final Map connectorConfig; private final String taskId; + private final String streamingIngestClientName; + private boolean enableSchematization; /** @@ -120,6 +133,9 @@ public SnowflakeSinkServiceV2( this.recordService.setAndGetEnableSchematizationFromConfig(this.connectorConfig); this.taskId = connectorConfig.getOrDefault(Utils.TASK_ID, "-1"); + this.streamingIngestClientName = + STREAMING_CLIENT_PREFIX_NAME + conn.getConnectorName() + "_" + taskId; + initStreamingClient(); this.partitionsToChannel = new HashMap<>(); } @@ -155,6 +171,7 @@ private void createStreamingChannelForTopicPartition( partitionsToChannel.put( partitionChannelKey, new TopicPartitionChannel( + this.streamingIngestClient, topicPartition, partitionChannelKey, // Streaming channel name tableName, @@ -253,6 +270,7 @@ public void closeAll() { topicPartitionChannel.closeChannel(); }); partitionsToChannel.clear(); + closeStreamingClient(); } /** @@ -426,6 +444,12 @@ protected static String partitionChannelKey(String topic, int partition) { return topic + "_" + partition; } + /* Used for testing */ + @VisibleForTesting + SnowflakeStreamingIngestClient getStreamingIngestClient() { + return this.streamingIngestClient; + } + /** * Used for testing Only * @@ -440,6 +464,42 @@ protected Optional getTopicPartitionChannelFromCacheKey( } // ------ Streaming Ingest Related Functions ------ // + + /* Init Streaming client. If is also used to re-init the client if client was closed before. */ + private void initStreamingClient() { + Map streamingPropertiesMap = + StreamingUtils.convertConfigForStreamingClient(new HashMap<>(this.connectorConfig)); + Properties streamingClientProps = new Properties(); + streamingClientProps.putAll(streamingPropertiesMap); + if (this.streamingIngestClient == null || this.streamingIngestClient.isClosed()) { + try { + LOGGER.info("Initializing Streaming Client. ClientName:{}", this.streamingIngestClientName); + this.streamingIngestClient = + SnowflakeStreamingIngestClientFactory.builder(this.streamingIngestClientName) + .setProperties(streamingClientProps) + .build(); + } catch (SFException ex) { + LOGGER.error( + "Exception creating streamingIngestClient with name:{}", + this.streamingIngestClientName); + throw new ConnectException(ex); + } + } + } + + /** Closes the streaming client. 
*/ + private void closeStreamingClient() { + LOGGER.info("Closing Streaming Client:{}", this.streamingIngestClientName); + try { + streamingIngestClient.close(); + } catch (Exception e) { + LOGGER.error( + "Failure closing Streaming client msg:{}, cause:{}", + e.getMessage(), + Arrays.toString(e.getCause().getStackTrace())); + } + } + private void createTableIfNotExists(final String tableName) { if (this.conn.tableExist(tableName)) { if (!this.enableSchematization) { diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannel.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannel.java index cc81f6aba..626962f08 100644 --- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannel.java +++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannel.java @@ -12,13 +12,12 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.snowflake.kafka.connector.Utils; import com.snowflake.kafka.connector.dlq.KafkaRecordErrorReporter; import com.snowflake.kafka.connector.internal.BufferThreshold; import com.snowflake.kafka.connector.internal.LoggerHandler; import com.snowflake.kafka.connector.internal.PartitionBuffer; import com.snowflake.kafka.connector.internal.SnowflakeConnectionService; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; -import com.snowflake.kafka.connector.internal.ingestsdk.KcStreamingIngestClient; import com.snowflake.kafka.connector.records.RecordService; import com.snowflake.kafka.connector.records.SnowflakeJsonSchema; import com.snowflake.kafka.connector.records.SnowflakeRecordContent; @@ -38,7 +37,9 @@ import java.util.concurrent.locks.ReentrantLock; import net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException; import net.snowflake.ingest.streaming.InsertValidationResponse; +import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; import net.snowflake.ingest.utils.Pair; import net.snowflake.ingest.utils.SFException; import org.apache.kafka.common.TopicPartition; @@ -122,7 +123,7 @@ public class TopicPartitionChannel { // -------- private final fields -------- // - private final KcStreamingIngestClient streamingIngestClient; + private final SnowflakeStreamingIngestClient streamingIngestClient; // Topic partition Object from connect consisting of topic and partition private final TopicPartition topicPartition; @@ -181,7 +182,30 @@ public class TopicPartitionChannel { // Reference to the Snowflake connection service private final SnowflakeConnectionService conn; + /** Testing only, initialize TopicPartitionChannel without the connection service */ + public TopicPartitionChannel( + SnowflakeStreamingIngestClient streamingIngestClient, + TopicPartition topicPartition, + final String channelName, + final String tableName, + final BufferThreshold streamingBufferThreshold, + final Map sfConnectorConfig, + KafkaRecordErrorReporter kafkaRecordErrorReporter, + SinkTaskContext sinkTaskContext) { + this( + streamingIngestClient, + topicPartition, + channelName, + tableName, + streamingBufferThreshold, + sfConnectorConfig, + kafkaRecordErrorReporter, + sinkTaskContext, + null); + } + /** + * @param streamingIngestClient client created specifically for this task * @param topicPartition 
topic partition corresponding to this Streaming Channel * (TopicPartitionChannel) * @param channelName channel Name which is deterministic for topic and partition @@ -193,6 +217,7 @@ public class TopicPartitionChannel { * @param conn the snowflake connection service */ public TopicPartitionChannel( + SnowflakeStreamingIngestClient streamingIngestClient, TopicPartition topicPartition, final String channelName, final String tableName, @@ -201,21 +226,18 @@ public TopicPartitionChannel( KafkaRecordErrorReporter kafkaRecordErrorReporter, SinkTaskContext sinkTaskContext, SnowflakeConnectionService conn) { - this.streamingIngestClient = - Preconditions.checkNotNull( - IngestSdkProvider.getStreamingClientManager().getValidClient(conn.getTaskId())); + this.streamingIngestClient = Preconditions.checkNotNull(streamingIngestClient); + Preconditions.checkState(!streamingIngestClient.isClosed()); this.topicPartition = Preconditions.checkNotNull(topicPartition); this.channelName = Preconditions.checkNotNull(channelName); this.tableName = Preconditions.checkNotNull(tableName); this.streamingBufferThreshold = Preconditions.checkNotNull(streamingBufferThreshold); this.sfConnectorConfig = Preconditions.checkNotNull(sfConnectorConfig); - this.channel = - Preconditions.checkNotNull( - this.streamingIngestClient.openChannel( - this.channelName, this.sfConnectorConfig, this.tableName)); + this.channel = Preconditions.checkNotNull(openChannelForTable()); this.kafkaRecordErrorReporter = Preconditions.checkNotNull(kafkaRecordErrorReporter); this.sinkTaskContext = Preconditions.checkNotNull(sinkTaskContext); this.conn = conn; + this.recordService = new RecordService(); this.previousFlushTimeStampMs = System.currentTimeMillis(); @@ -910,10 +932,7 @@ private void resetChannelMetadataAfterRecovery( private long getRecoveredOffsetFromSnowflake( final StreamingApiFallbackInvoker streamingApiFallbackInvoker) { LOGGER.warn("{} Re-opening channel:{}", streamingApiFallbackInvoker, this.getChannelName()); - this.channel = - Preconditions.checkNotNull( - this.streamingIngestClient.openChannel( - this.channelName, this.sfConnectorConfig, this.tableName)); + this.channel = Preconditions.checkNotNull(openChannelForTable()); LOGGER.warn( "{} Fetching offsetToken after re-opening the channel:{}", streamingApiFallbackInvoker, @@ -958,6 +977,32 @@ private long fetchLatestCommittedOffsetFromSnowflake() { } } + /** + * Open a channel for Table with given channel name and tableName. + * + *

Open channels happens at: + * + *

Constructor of TopicPartitionChannel -> which means we will wipe of all states and it will + * call precomputeOffsetTokenForChannel + * + *

Failure handling which will call reopen, replace instance variable with new channel and call + * offsetToken/insertRows. + * + * @return new channel which was fetched after open/reopen + */ + private SnowflakeStreamingIngestChannel openChannelForTable() { + OpenChannelRequest channelRequest = + OpenChannelRequest.builder(this.channelName) + .setDBName(this.sfConnectorConfig.get(Utils.SF_DATABASE)) + .setSchemaName(this.sfConnectorConfig.get(Utils.SF_SCHEMA)) + .setTableName(this.tableName) + .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE) + .build(); + LOGGER.info( + "Opening a channel with name:{} for table name:{}", this.channelName, this.tableName); + return streamingIngestClient.openChannel(channelRequest); + } + /** * Close channel associated to this partition Not rethrowing connect exception because the * connector will stop. Channel will eventually be reopened. diff --git a/src/test/java/com/snowflake/kafka/connector/ConnectorConfigTest.java b/src/test/java/com/snowflake/kafka/connector/ConnectorConfigTest.java index a8697aae9..7236a20b7 100644 --- a/src/test/java/com/snowflake/kafka/connector/ConnectorConfigTest.java +++ b/src/test/java/com/snowflake/kafka/connector/ConnectorConfigTest.java @@ -3,7 +3,6 @@ import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_LOG_ENABLE_CONFIG; import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_TOLERANCE_CONFIG; import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.NAME; -import static com.snowflake.kafka.connector.Utils.HTTP_NON_PROXY_HOSTS; import static com.snowflake.kafka.connector.internal.TestUtils.getConfig; import com.snowflake.kafka.connector.internal.SnowflakeKafkaConnectorException; @@ -11,7 +10,6 @@ import com.snowflake.kafka.connector.internal.streaming.StreamingUtils; import java.util.Locale; import java.util.Map; -import org.junit.Assert; import org.junit.Test; public class ConnectorConfigTest { @@ -119,33 +117,6 @@ public void testEmptyHost() { Utils.validateConfig(config); } - @Test - public void testNonProxyHosts() { - String oldNonProxyHosts = - (System.getProperty(HTTP_NON_PROXY_HOSTS) != null) - ? 
System.getProperty(HTTP_NON_PROXY_HOSTS) - : null; - - System.setProperty(HTTP_NON_PROXY_HOSTS, "host1.com|host2.com|localhost"); - Map config = getConfig(); - config.put(SnowflakeSinkConnectorConfig.JVM_PROXY_HOST, "127.0.0.1"); - config.put(SnowflakeSinkConnectorConfig.JVM_PROXY_PORT, "3128"); - config.put( - SnowflakeSinkConnectorConfig.JVM_NON_PROXY_HOSTS, - "*.snowflakecomputing.com|*.amazonaws.com"); - Utils.enableJVMProxy(config); - String mergedNonProxyHosts = System.getProperty(HTTP_NON_PROXY_HOSTS); - Assert.assertTrue( - mergedNonProxyHosts.equals( - "host1.com|host2.com|localhost|*.snowflakecomputing.com|*.amazonaws.com")); - - if (oldNonProxyHosts != null) { - System.setProperty(HTTP_NON_PROXY_HOSTS, oldNonProxyHosts); - } else { - System.clearProperty(HTTP_NON_PROXY_HOSTS); - } - } - @Test(expected = SnowflakeKafkaConnectorException.class) public void testIllegalTopicMap() { Map config = getConfig(); @@ -622,30 +593,4 @@ public void testValidSchematizationForStreamingSnowpipe() { config.put(Utils.SF_ROLE, "ACCOUNTADMIN"); Utils.validateConfig(config); } - - @Test(expected = SnowflakeKafkaConnectorException.class) - public void testInValidConfigFileTypeForSnowpipe() { - Map config = getConfig(); - config.put(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, "3"); - Utils.validateConfig(config); - } - - @Test - public void testValidFileTypesForSnowpipeStreaming() { - Map config = getConfig(); - config.put( - SnowflakeSinkConnectorConfig.INGESTION_METHOD_OPT, - IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); - config.put(Utils.SF_ROLE, "ACCOUNTADMIN"); - - config.put(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, "3"); - Utils.validateConfig(config); - - config.put(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, "1"); - Utils.validateConfig(config); - - // lower case - config.put(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, "abcd"); - Utils.validateConfig(config); - } } diff --git a/src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskForStreamingIT.java b/src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskForStreamingIT.java new file mode 100644 index 000000000..31dbdf6c1 --- /dev/null +++ b/src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskForStreamingIT.java @@ -0,0 +1,306 @@ +package com.snowflake.kafka.connector; + +import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.BUFFER_COUNT_RECORDS; +import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.INGESTION_METHOD_OPT; + +import com.snowflake.kafka.connector.internal.LoggerHandler; +import com.snowflake.kafka.connector.internal.TestUtils; +import com.snowflake.kafka.connector.internal.streaming.InMemorySinkTaskContext; +import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig; +import java.sql.ResultSet; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException; +import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; +import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.After; +import org.junit.Assert; +import 
org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.mockito.AdditionalMatchers; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.Spy; +import org.slf4j.Logger; + +/** + * Sink Task IT test which uses {@link + * com.snowflake.kafka.connector.internal.streaming.SnowflakeSinkServiceV2} + */ +public class SnowflakeSinkTaskForStreamingIT { + + private String topicName; + private static int partition = 0; + private TopicPartition topicPartition; + + @Mock Logger logger = Mockito.mock(Logger.class); + + @InjectMocks @Spy + private LoggerHandler loggerHandler = Mockito.spy(new LoggerHandler(this.getClass().getName())); + + @InjectMocks private SnowflakeSinkTask sinkTask1 = new SnowflakeSinkTask(); + + @Before + public void setup() { + topicName = TestUtils.randomTableName(); + topicPartition = new TopicPartition(topicName, partition); + } + + @After + public void after() { + TestUtils.dropTable(topicName); + } + + @Test + public void testSinkTask() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + config.put(BUFFER_COUNT_RECORDS, "1"); // override + + config.put(INGESTION_METHOD_OPT, IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); + + SnowflakeSinkTask sinkTask = new SnowflakeSinkTask(); + + // Inits the sinktaskcontext + sinkTask.initialize(new InMemorySinkTaskContext(Collections.singleton(topicPartition))); + sinkTask.start(config); + ArrayList topicPartitions = new ArrayList<>(); + topicPartitions.add(new TopicPartition(topicName, partition)); + sinkTask.open(topicPartitions); + + // send regular data + List records = TestUtils.createJsonStringSinkRecords(0, 1, topicName, partition); + sinkTask.put(records); + + // commit offset + final Map offsetMap = new HashMap<>(); + offsetMap.put(topicPartitions.get(0), new OffsetAndMetadata(10000)); + + TestUtils.assertWithRetry(() -> sinkTask.preCommit(offsetMap).size() == 1, 20, 5); + + TestUtils.assertWithRetry( + () -> sinkTask.preCommit(offsetMap).get(topicPartitions.get(0)).offset() == 1, 20, 5); + + sinkTask.close(topicPartitions); + sinkTask.stop(); + } + + @Test + @Ignore + public void testMultipleSinkTaskWithLogs() throws Exception { + // setup log mocking for task1 + MockitoAnnotations.initMocks(this); + Mockito.when(logger.isInfoEnabled()).thenReturn(true); + Mockito.when(logger.isDebugEnabled()).thenReturn(true); + Mockito.when(logger.isWarnEnabled()).thenReturn(true); + + // set up configs + String task0Id = "0"; + Map config0 = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config0); + config0.put(BUFFER_COUNT_RECORDS, "1"); // override + config0.put(INGESTION_METHOD_OPT, IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); + config0.put(Utils.TASK_ID, task0Id); + + String task1Id = "1"; + int taskOpen1Count = 0; + Map config1 = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config1); + config1.put(BUFFER_COUNT_RECORDS, "1"); // override + config1.put(INGESTION_METHOD_OPT, IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); + config1.put(Utils.TASK_ID, task1Id); + + SnowflakeSinkTask sinkTask0 = new SnowflakeSinkTask(); + + sinkTask0.initialize(new InMemorySinkTaskContext(Collections.singleton(topicPartition))); + sinkTask1.initialize(new InMemorySinkTaskContext(Collections.singleton(topicPartition))); + + // set up task1 logging tag + String expectedTask1Tag = + 
TestUtils.getExpectedLogTagWithoutCreationCount(task1Id, taskOpen1Count); + Mockito.doCallRealMethod().when(loggerHandler).setLoggerInstanceTag(expectedTask1Tag); + + // start tasks + sinkTask0.start(config0); + sinkTask1.start(config1); + + // verify task1 start logs + Mockito.verify(loggerHandler, Mockito.times(1)) + .setLoggerInstanceTag(Mockito.contains(expectedTask1Tag)); + Mockito.verify(logger, Mockito.times(2)) + .debug( + AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("start"))); + + // open tasks + ArrayList topicPartitions0 = new ArrayList<>(); + topicPartitions0.add(new TopicPartition(topicName, partition)); + ArrayList topicPartitions1 = new ArrayList<>(); + topicPartitions1.add(new TopicPartition(topicName, partition)); + + sinkTask0.open(topicPartitions0); + sinkTask1.open(topicPartitions1); + + taskOpen1Count++; + expectedTask1Tag = TestUtils.getExpectedLogTagWithoutCreationCount(task1Id, taskOpen1Count); + + // verify task1 open logs + Mockito.verify(logger, Mockito.times(1)) + .debug( + AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("open"))); + + // send data to tasks + List records0 = TestUtils.createJsonStringSinkRecords(0, 1, topicName, partition); + List records1 = TestUtils.createJsonStringSinkRecords(0, 1, topicName, partition); + + sinkTask0.put(records0); + sinkTask1.put(records1); + + // verify task1 put logs + Mockito.verify(logger, Mockito.times(1)) + .debug(AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("put"))); + + // commit offsets + final Map offsetMap0 = new HashMap<>(); + final Map offsetMap1 = new HashMap<>(); + offsetMap0.put(topicPartitions0.get(0), new OffsetAndMetadata(10000)); + offsetMap1.put(topicPartitions1.get(0), new OffsetAndMetadata(10000)); + + TestUtils.assertWithRetry(() -> sinkTask0.preCommit(offsetMap0).size() == 1, 20, 5); + TestUtils.assertWithRetry(() -> sinkTask1.preCommit(offsetMap1).size() == 1, 20, 5); + + // verify task1 precommit logs + Mockito.verify(logger, Mockito.times(1)) + .debug( + AdditionalMatchers.and( + Mockito.contains(expectedTask1Tag), Mockito.contains("precommit"))); + + TestUtils.assertWithRetry( + () -> sinkTask0.preCommit(offsetMap0).get(topicPartitions0.get(0)).offset() == 1, 20, 5); + TestUtils.assertWithRetry( + () -> sinkTask1.preCommit(offsetMap1).get(topicPartitions1.get(0)).offset() == 1, 20, 5); + + // close tasks + sinkTask0.close(topicPartitions0); + sinkTask1.close(topicPartitions1); + + // verify task1 close logs + Mockito.verify(logger, Mockito.times(1)) + .debug( + AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("closed"))); + + // stop tasks + sinkTask0.stop(); + sinkTask1.stop(); + + // verify task1 stop logs + Mockito.verify(logger, Mockito.times(1)) + .debug( + AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("stop"))); + } + + @Test + public void testSinkTaskWithMultipleOpenClose() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + config.put(BUFFER_COUNT_RECORDS, "1"); // override + + config.put(INGESTION_METHOD_OPT, IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); + + SnowflakeSinkTask sinkTask = new SnowflakeSinkTask(); + // Inits the sinktaskcontext + sinkTask.initialize(new InMemorySinkTaskContext(Collections.singleton(topicPartition))); + + sinkTask.start(config); + ArrayList topicPartitions = new ArrayList<>(); + topicPartitions.add(new TopicPartition(topicName, 
partition)); + sinkTask.open(topicPartitions); + + final long noOfRecords = 1l; + final long lastOffsetNo = noOfRecords - 1; + + // send regular data + List records = + TestUtils.createJsonStringSinkRecords(0, noOfRecords, topicName, partition); + sinkTask.put(records); + + // commit offset + final Map offsetMap = new HashMap<>(); + offsetMap.put(topicPartitions.get(0), new OffsetAndMetadata(lastOffsetNo)); + + TestUtils.assertWithRetry(() -> sinkTask.preCommit(offsetMap).size() == 1, 20, 5); + + // precommit is one more than offset last inserted + TestUtils.assertWithRetry( + () -> sinkTask.preCommit(offsetMap).get(topicPartitions.get(0)).offset() == noOfRecords, + 20, + 5); + + sinkTask.close(topicPartitions); + + // Add one more partition + topicPartitions.add(new TopicPartition(topicName, partition + 1)); + + sinkTask.open(topicPartitions); + + // trying to put same records + sinkTask.put(records); + + List recordsWithAnotherPartition = + TestUtils.createJsonStringSinkRecords(0, noOfRecords, topicName, partition + 1); + sinkTask.put(recordsWithAnotherPartition); + + // Adding to offsetMap so that this gets into precommit + offsetMap.put(topicPartitions.get(1), new OffsetAndMetadata(lastOffsetNo)); + + TestUtils.assertWithRetry(() -> sinkTask.preCommit(offsetMap).size() == 2, 20, 5); + + TestUtils.assertWithRetry( + () -> sinkTask.preCommit(offsetMap).get(topicPartitions.get(0)).offset() == 1, 20, 5); + + TestUtils.assertWithRetry( + () -> sinkTask.preCommit(offsetMap).get(topicPartitions.get(1)).offset() == 1, 20, 5); + + sinkTask.close(topicPartitions); + + sinkTask.stop(); + + ResultSet resultSet = TestUtils.showTable(topicName); + LinkedList contentResult = new LinkedList<>(); + LinkedList metadataResult = new LinkedList<>(); + + while (resultSet.next()) { + contentResult.add(resultSet.getString("RECORD_CONTENT")); + metadataResult.add(resultSet.getString("RECORD_METADATA")); + } + resultSet.close(); + assert metadataResult.size() == 2; + assert contentResult.size() == 2; + ObjectMapper mapper = new ObjectMapper(); + + Set partitionsInTable = new HashSet<>(); + metadataResult.forEach( + s -> { + try { + JsonNode metadata = mapper.readTree(s); + metadata.get("offset").asText().equals("0"); + partitionsInTable.add(metadata.get("partition").asLong()); + } catch (JsonProcessingException e) { + Assert.fail(); + } + }); + + assert partitionsInTable.size() == 2; + } +} diff --git a/src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskTestForStreamingIT.java b/src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskTestForStreamingIT.java deleted file mode 100644 index 5d2a1c428..000000000 --- a/src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskTestForStreamingIT.java +++ /dev/null @@ -1,256 +0,0 @@ -package com.snowflake.kafka.connector; - -import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.BUFFER_COUNT_RECORDS; -import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.INGESTION_METHOD_OPT; - -import com.snowflake.kafka.connector.internal.TestUtils; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; -import com.snowflake.kafka.connector.internal.streaming.InMemorySinkTaskContext; -import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig; -import java.sql.ResultSet; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import 
net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; -import org.apache.kafka.clients.consumer.OffsetAndMetadata; -import org.apache.kafka.common.TopicPartition; -import org.apache.kafka.connect.sink.SinkRecord; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -/** - * Sink Task IT test which uses {@link - * com.snowflake.kafka.connector.internal.streaming.SnowflakeSinkServiceV2} - */ -public class SnowflakeSinkTaskTestForStreamingIT { - private int taskId; - private int partitionCount; - private String topicName; - private Map config; - private List topicPartitions; - private SnowflakeSinkTask sinkTask; - private InMemorySinkTaskContext sinkTaskContext; - private List records; - private Map offsetMap; - - // sets up default objects for normal testing - // NOTE: everything defaults to having one of each (ex: # topic partitions, # tasks, # clients, - // etc) - @Before - public void setup() throws Exception { - this.taskId = 0; - this.partitionCount = 1; - this.topicName = "topicName"; - this.config = this.getConfig(this.taskId); - this.topicPartitions = getTopicPartitions(this.topicName, this.partitionCount); - this.sinkTask = new SnowflakeSinkTask(); - this.sinkTaskContext = - new InMemorySinkTaskContext(Collections.singleton(this.topicPartitions.get(0))); - this.records = TestUtils.createJsonStringSinkRecords(0, 1, this.topicName, 0); - this.offsetMap = new HashMap<>(); - this.offsetMap.put(this.topicPartitions.get(0), new OffsetAndMetadata(10000)); - - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(this.config, "testkcid", 1, 1); - assert IngestSdkProvider.getStreamingClientManager().getClientCount() == 1; - } - - @After - public void after() throws Exception { - this.sinkTask.close(this.topicPartitions); - this.sinkTask.stop(); - TestUtils.dropTable(topicName); - IngestSdkProvider.setStreamingClientManager( - TestUtils.resetAndGetEmptyStreamingClientManager()); // reset to clean initial manager - } - - @Test - public void testSinkTask() throws Exception { - // Inits the sinktaskcontext - this.sinkTask.initialize(this.sinkTaskContext); - this.sinkTask.start(this.config); - this.sinkTask.open(this.topicPartitions); - - // send regular data - this.sinkTask.put(this.records); - - // commit offset - TestUtils.assertWithRetry(() -> this.sinkTask.preCommit(this.offsetMap).size() == 1, 20, 5); - - // verify offset - TestUtils.assertWithRetry( - () -> - this.sinkTask.preCommit(this.offsetMap).get(this.topicPartitions.get(0)).offset() == 1, - 20, - 5); - - // cleanup - this.sinkTask.close(this.topicPartitions); - this.sinkTask.stop(); - } - - // test two tasks map to one client behaves as expected - @Test - public void testTaskToClientMapping() throws Exception { - // setup two tasks pointing to one client - IngestSdkProvider.setStreamingClientManager(TestUtils.resetAndGetEmptyStreamingClientManager()); - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(this.config, "kcid", 2, 2); - assert IngestSdkProvider.getStreamingClientManager().getClientCount() == 1; - - // setup task0, not strictly necessary but makes test more readable - Map config0 = this.config; - List topicPartitions0 = this.topicPartitions; - SnowflakeSinkTask sinkTask0 = this.sinkTask; - InMemorySinkTaskContext sinkTaskContext0 = 
this.sinkTaskContext; - List records0 = this.records; - Map offsetMap0 = this.offsetMap; - - // setup task1 - int taskId1 = 1; - String topicName1 = "topicName1"; - Map config1 = this.getConfig(taskId1); - List topicPartitions1 = getTopicPartitions(topicName1, 1); - SnowflakeSinkTask sinkTask1 = new SnowflakeSinkTask(); - InMemorySinkTaskContext sinkTaskContext1 = - new InMemorySinkTaskContext(Collections.singleton(topicPartitions1.get(0))); - List records1 = TestUtils.createJsonStringSinkRecords(0, 1, topicName1, 0); - Map offsetMap1 = new HashMap<>(); - offsetMap1.put(topicPartitions1.get(0), new OffsetAndMetadata(10000)); - - // start init and open tasks - sinkTask0.initialize(sinkTaskContext0); - sinkTask1.initialize(sinkTaskContext1); - sinkTask0.start(config0); - sinkTask1.start(config1); - sinkTask0.open(topicPartitions0); - sinkTask1.open(topicPartitions1); - - // send data to both tasks - sinkTask0.put(records0); - sinkTask1.put(records1); - - // verify that data was ingested - TestUtils.assertWithRetry(() -> sinkTask0.preCommit(offsetMap0).size() == 1, 20, 5); - TestUtils.assertWithRetry(() -> sinkTask1.preCommit(offsetMap1).size() == 1, 20, 5); - - TestUtils.assertWithRetry( - () -> sinkTask0.preCommit(offsetMap0).get(topicPartitions0.get(0)).offset() == 1, 20, 5); - TestUtils.assertWithRetry( - () -> sinkTask1.preCommit(offsetMap1).get(topicPartitions1.get(0)).offset() == 1, 20, 5); - - // clean up tasks - sinkTask0.close(topicPartitions0); - sinkTask1.close(topicPartitions1); - sinkTask0.stop(); - sinkTask1.stop(); - } - - @Test - public void testSinkTaskWithMultipleOpenClose() throws Exception { - final long noOfRecords = 1l; - final long lastOffsetNo = noOfRecords - 1; - - this.sinkTask.initialize(this.sinkTaskContext); - this.sinkTask.start(this.config); - this.sinkTask.open(this.topicPartitions); - - List recordsPart0 = this.records; - List recordsPart1 = TestUtils.createJsonStringSinkRecords(0, 1, this.topicName, 1); - - // send regular data to partition 0, verify data was committed - this.sinkTask.put(recordsPart0); - TestUtils.assertWithRetry( - () -> this.sinkTask.preCommit(this.offsetMap).size() == noOfRecords, 20, 5); - TestUtils.assertWithRetry( - () -> - this.sinkTask.preCommit(this.offsetMap).get(this.topicPartitions.get(0)).offset() - == noOfRecords, - 20, - 5); - - this.sinkTask.close(this.topicPartitions); - - // Add one more partition and open last partition - this.partitionCount++; - this.topicPartitions = this.getTopicPartitions(this.topicName, this.partitionCount); - this.sinkTask.open(this.topicPartitions); - - // trying to put records to partition 0 and 1 - this.sinkTask.put(recordsPart0); - this.sinkTask.put(recordsPart1); - - // Adding to offsetMap so that this gets into precommit - this.offsetMap.put(this.topicPartitions.get(1), new OffsetAndMetadata(lastOffsetNo)); - - // verify precommit for task and each partition - TestUtils.assertWithRetry(() -> this.sinkTask.preCommit(this.offsetMap).size() == 2, 20, 5); - TestUtils.assertWithRetry( - () -> - this.sinkTask.preCommit(this.offsetMap).get(this.topicPartitions.get(0)).offset() == 1, - 20, - 5); - TestUtils.assertWithRetry( - () -> sinkTask.preCommit(offsetMap).get(topicPartitions.get(1)).offset() == 1, 20, 5); - - // clean up - sinkTask.close(topicPartitions); - sinkTask.stop(); - - // verify content and metadata - ResultSet resultSet = TestUtils.showTable(topicName); - LinkedList contentResult = new LinkedList<>(); - LinkedList metadataResult = new LinkedList<>(); - - while (resultSet.next()) { - 
contentResult.add(resultSet.getString("RECORD_CONTENT")); - metadataResult.add(resultSet.getString("RECORD_METADATA")); - } - resultSet.close(); - assert metadataResult.size() == 2; - assert contentResult.size() == 2; - ObjectMapper mapper = new ObjectMapper(); - - Set partitionsInTable = new HashSet<>(); - metadataResult.forEach( - s -> { - try { - JsonNode metadata = mapper.readTree(s); - metadata.get("offset").asText().equals("0"); - partitionsInTable.add(metadata.get("partition").asLong()); - } catch (JsonProcessingException e) { - Assert.fail(); - } - }); - - assert partitionsInTable.size() == 2; - } - - public static Map getConfig(int taskId) { - Map config = TestUtils.getConfForStreaming(); - config.put(BUFFER_COUNT_RECORDS, "1"); // override - config.put(INGESTION_METHOD_OPT, IngestionMethodConfig.SNOWPIPE_STREAMING.toString()); - SnowflakeSinkConnectorConfig.setDefaultValues(config); - config.put(Utils.TASK_ID, taskId + ""); - - return config; - } - - public static ArrayList getTopicPartitions(String topicName, int numPartitions) { - ArrayList topicPartitions = new ArrayList<>(); - for (int i = 0; i < numPartitions; i++) { - topicPartitions.add(new TopicPartition(topicName, i)); - } - - return topicPartitions; - } -} diff --git a/src/test/java/com/snowflake/kafka/connector/internal/TestUtils.java b/src/test/java/com/snowflake/kafka/connector/internal/TestUtils.java index 87a0199ca..1ae188b96 100644 --- a/src/test/java/com/snowflake/kafka/connector/internal/TestUtils.java +++ b/src/test/java/com/snowflake/kafka/connector/internal/TestUtils.java @@ -35,9 +35,7 @@ import com.snowflake.client.jdbc.SnowflakeDriver; import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; import com.snowflake.kafka.connector.Utils; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; -import com.snowflake.kafka.connector.internal.ingestsdk.KcStreamingIngestClient; -import com.snowflake.kafka.connector.internal.ingestsdk.StreamingClientManager; +import com.snowflake.kafka.connector.internal.streaming.StreamingUtils; import com.snowflake.kafka.connector.records.SnowflakeJsonSchema; import com.snowflake.kafka.connector.records.SnowflakeRecordContent; import io.confluent.connect.avro.AvroConverter; @@ -62,6 +60,8 @@ import java.util.regex.Pattern; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory; import org.apache.kafka.common.record.TimestampType; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaAndValue; @@ -165,7 +165,6 @@ private static Map getPropertiesMapFromProfile(final String prof configuration.put(Utils.SF_URL, getProfile(profileFileName).get(HOST).asText()); configuration.put(Utils.SF_WAREHOUSE, getProfile(profileFileName).get(WAREHOUSE).asText()); configuration.put(Utils.SF_PRIVATE_KEY, getProfile(profileFileName).get(PRIVATE_KEY).asText()); - // configuration.put(Utils.SF_ROLE, getProfile(profileFileName).get(ROLE).asText()); configuration.put(Utils.NAME, TEST_CONNECTOR_NAME); @@ -373,28 +372,9 @@ public static boolean assertError(SnowflakeErrors error, Runnable func) { return false; } - /** - * Check Snowflake Error Code in test - * - * @param error Snowflake error - * @param func function throwing exception - * @return true is error code is correct, otherwise, false 
- */ - public static boolean assertExceptionType(Class exceptionClass, Runnable func) { - try { - func.run(); - } catch (Exception ex) { - return ex.getClass().equals(exceptionClass); - } - return false; - } - /** @return snowflake connection for test */ public static SnowflakeConnectionService getConnectionService() { - return SnowflakeConnectionServiceFactory.builder() - .setProperties(getConf()) - .setTaskID("0") - .build(); + return SnowflakeConnectionServiceFactory.builder().setProperties(getConf()).build(); } /** @@ -544,15 +524,6 @@ public static List createNativeJsonSinkRecords( final int partitionNo) { ArrayList records = new ArrayList<>(); - for (long i = startOffset; i < startOffset + noOfRecords; ++i) { - records.add(createNativeJsonSinkRecord(i, topicName, partitionNo)); - } - - return records; - } - - public static SinkRecord createNativeJsonSinkRecord( - final long offset, final String topicName, final int partitionNo) { JsonConverter converter = new JsonConverter(); HashMap converterConfig = new HashMap<>(); converterConfig.put("schemas.enable", "true"); @@ -561,14 +532,18 @@ public static SinkRecord createNativeJsonSinkRecord( converter.toConnectData( "test", TestUtils.JSON_WITH_SCHEMA.getBytes(StandardCharsets.UTF_8)); - return new SinkRecord( - topicName, - partitionNo, - Schema.STRING_SCHEMA, - "test", - schemaInputValue.schema(), - schemaInputValue.value(), - offset); + for (long i = startOffset; i < startOffset + noOfRecords; ++i) { + records.add( + new SinkRecord( + topicName, + partitionNo, + Schema.STRING_SCHEMA, + "test", + schemaInputValue.schema(), + schemaInputValue.value(), + i)); + } + return records; } /* Generate (noOfRecords - startOffset) for a given topic and partition which were essentially avro records */ @@ -760,14 +735,12 @@ public static String getExpectedLogTagWithoutCreationCount(String taskId, int ta return Utils.formatString(TASK_INSTANCE_TAG_FORMAT, taskId, taskOpenCount, "").split("#")[0]; } - /** Resets existing streaming clients and get a new client manager used in testing. */ - public static StreamingClientManager resetAndGetEmptyStreamingClientManager() { - Map taskToClientMap = - IngestSdkProvider.getStreamingClientManager().getTaskToClientMap(); - if (taskToClientMap != null && !taskToClientMap.isEmpty()) { - taskToClientMap.forEach( - (integer, kcStreamingIngestClient) -> kcStreamingIngestClient.close()); - } - return new StreamingClientManager(new HashMap<>()); + public static SnowflakeStreamingIngestClient createStreamingClient( + Map config, String clientName) { + Properties clientProperties = new Properties(); + clientProperties.putAll(StreamingUtils.convertConfigForStreamingClient(new HashMap<>(config))); + return SnowflakeStreamingIngestClientFactory.builder(clientName) + .setProperties(clientProperties) + .build(); } } diff --git a/src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/KcStreamingIngestClientTest.java b/src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/KcStreamingIngestClientTest.java deleted file mode 100644 index 1f6ea1ccc..000000000 --- a/src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/KcStreamingIngestClientTest.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2023 Snowflake Inc. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package com.snowflake.kafka.connector.internal.ingestsdk; - -import static net.snowflake.ingest.utils.ParameterProvider.BLOB_FORMAT_VERSION; - -import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; -import com.snowflake.kafka.connector.Utils; -import com.snowflake.kafka.connector.internal.TestUtils; -import com.snowflake.kafka.connector.internal.streaming.StreamingUtils; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Properties; -import net.snowflake.ingest.streaming.OpenChannelRequest; -import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel; -import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; -import org.apache.kafka.connect.errors.ConnectException; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentMatchers; -import org.mockito.Mockito; - -public class KcStreamingIngestClientTest { - private String clientName; - private Map config; - private Properties properties; - - private SnowflakeStreamingIngestClient mockClient; - - @Before - public void setup() { - this.config = TestUtils.getConfForStreaming(); - this.clientName = - KcStreamingIngestClient.buildStreamingIngestClientName( - config.getOrDefault(Utils.NAME, "TEST_CONNECTOR_NAME"), "testKcId", 0); - SnowflakeSinkConnectorConfig.setDefaultValues(config); - this.properties = new Properties(); - this.properties.putAll(StreamingUtils.convertConfigForStreamingClient(new HashMap<>(config))); - - this.mockClient = Mockito.mock(SnowflakeStreamingIngestClient.class); - } - - @Test - public void testCreateClient() { - // setup - Mockito.when(this.mockClient.getName()).thenReturn(this.clientName); - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - - // test - KcStreamingIngestClient kcActualClient = - new KcStreamingIngestClient(this.properties, null, this.clientName); - - // verify - assert kcActualClient.getName().equals(kcMockClient.getName()); - Assert.assertTrue(kcActualClient.getName().contains(this.config.get(Utils.NAME))); - Mockito.verify(this.mockClient, Mockito.times(1)).getName(); - } - - @Test - public void testCreateClientFailure() { - TestUtils.assertExceptionType( - ConnectException.class, () -> new KcStreamingIngestClient(null, null, null)); - TestUtils.assertExceptionType( - ConnectException.class, () -> new KcStreamingIngestClient(null, null, this.clientName)); - TestUtils.assertExceptionType( - ConnectException.class, () -> new KcStreamingIngestClient(this.properties, null, null)); - } - - @Test - public void testCreateClientWithArrowBDECFileFormat() { - // setup - Mockito.when(this.mockClient.getName()).thenReturn(this.clientName); - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - - Map parameterOverrides = Collections.singletonMap(BLOB_FORMAT_VERSION, "1"); - - // test - KcStreamingIngestClient kcActualClient = - new KcStreamingIngestClient(this.properties, parameterOverrides, this.clientName); - - // verify - assert kcActualClient.getName().equals(kcMockClient.getName()); - 
Mockito.verify(this.mockClient, Mockito.times(1)).getName(); - } - - @Test - public void testOpenChannel() { - String channelName = "testchannel"; - String tableName = "testtable"; - this.config.put(Utils.SF_DATABASE, "testdb"); - this.config.put(Utils.SF_SCHEMA, "testschema"); - OpenChannelRequest request = - OpenChannelRequest.builder(channelName) - .setDBName(this.config.get(Utils.SF_DATABASE)) - .setSchemaName(this.config.get(Utils.SF_SCHEMA)) - .setTableName(tableName) - .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE) - .build(); - - // setup mocks - SnowflakeStreamingIngestChannel goalChannel = - Mockito.mock(SnowflakeStreamingIngestChannel.class); - Mockito.when(this.mockClient.openChannel(ArgumentMatchers.refEq(request))) - .thenReturn(goalChannel); - - // test - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - SnowflakeStreamingIngestChannel res = - kcMockClient.openChannel(channelName, this.config, tableName); - - // verify - assert res.equals(goalChannel); - Mockito.verify(this.mockClient, Mockito.times(1)).openChannel(ArgumentMatchers.refEq(request)); - } - - @Test - public void testCloseClient() throws Exception { - Mockito.when(this.mockClient.isClosed()).thenReturn(false); - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - assert kcMockClient.close(); - Mockito.verify(this.mockClient, Mockito.times(1)).close(); - Mockito.verify(this.mockClient, Mockito.times(1)).isClosed(); - } - - @Test - public void testCloseAlreadyClosedClient() throws Exception { - Mockito.when(this.mockClient.isClosed()).thenReturn(true); - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - assert kcMockClient.close(); - Mockito.verify(this.mockClient, Mockito.times(1)).isClosed(); - } - - @Test - public void testCloseClientFailure() throws Exception { - Exception exceptionToThrow = new Exception(); - this.testCloseClientFailureRunner(exceptionToThrow); - exceptionToThrow = new Exception("did you pet a cat today though"); - this.testCloseClientFailureRunner(exceptionToThrow); - exceptionToThrow.initCause(new Exception("because you should")); - this.testCloseClientFailureRunner(exceptionToThrow); - } - - private void testCloseClientFailureRunner(Exception exceptionToThrow) throws Exception { - this.mockClient = Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.doThrow(exceptionToThrow).when(this.mockClient).close(); - Mockito.when(this.mockClient.isClosed()).thenReturn(false); - - // test - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - assert !kcMockClient.close(); - - // verify - Mockito.verify(this.mockClient, Mockito.times(1)).close(); - Mockito.verify(this.mockClient, Mockito.times(1)).isClosed(); - } - - @Test - public void testClientIsClosed() { - boolean isClosed = false; - Mockito.when(this.mockClient.isClosed()).thenReturn(isClosed); - KcStreamingIngestClient kcMockClient = new KcStreamingIngestClient(this.mockClient); - assert kcMockClient.isClosed() == isClosed; - Mockito.verify(this.mockClient, Mockito.times(1)).isClosed(); - } - - @Test - public void testInvalidInsertRowsWithInvalidBDECFormat() throws Exception { - // Wipe off existing clients. 
- IngestSdkProvider.setStreamingClientManager( - TestUtils.resetAndGetEmptyStreamingClientManager()); // reset to clean initial manager - - // add config which overrides the bdec file format - Map overriddenConfig = new HashMap<>(this.config); - overriddenConfig.put( - SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, - "TWOO_HUNDRED"); // some random string not supported in enum - try { - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(overriddenConfig, "testkcid", 1, 1); - } catch (IllegalArgumentException ex) { - Assert.assertEquals(NumberFormatException.class, ex.getCause().getClass()); - } - } -} diff --git a/src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManagerTest.java b/src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManagerTest.java deleted file mode 100644 index 29710968a..000000000 --- a/src/test/java/com/snowflake/kafka/connector/internal/ingestsdk/StreamingClientManagerTest.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Copyright (c) 2023 Snowflake Inc. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package com.snowflake.kafka.connector.internal.ingestsdk; - -import com.snowflake.kafka.connector.internal.SnowflakeErrors; -import com.snowflake.kafka.connector.internal.TestUtils; -import java.util.HashMap; -import java.util.Map; -import org.junit.Test; -import org.mockito.Mockito; - -public class StreamingClientManagerTest { - - @Test - public void testCreateAndGetAllStreamingClientsWithUnevenRatio() { - // test to create the following mapping - // [0, 1] -> clientA, [2, 3] -> clientB, [4] -> clientC - // test - int maxTasks = 5; - int numTasksPerClient = 2; - - StreamingClientManager manager = new StreamingClientManager(); - manager.createAllStreamingClients( - TestUtils.getConfForStreaming(), "testkcid", maxTasks, numTasksPerClient); - - // verify - KcStreamingIngestClient task0Client = manager.getValidClient(0); - KcStreamingIngestClient task1Client = manager.getValidClient(1); - assert task0Client.equals(task1Client); - - KcStreamingIngestClient task2Client = manager.getValidClient(2); - KcStreamingIngestClient task3Client = manager.getValidClient(3); - assert task2Client.equals(task3Client); - assert !task2Client.equals(task0Client); - - KcStreamingIngestClient task4Client = manager.getValidClient(4); - assert !task4Client.equals(task0Client); - assert !task4Client.equals(task2Client); - - assert Math.ceil((double) maxTasks / (double) numTasksPerClient) == manager.getClientCount(); - - // close clients - task0Client.close(); - task1Client.close(); - task2Client.close(); - task3Client.close(); - task4Client.close(); - } - - @Test - public void testCreateAndGetAllStreamingClientsWithEvenRatio() { - // test to create the following mapping - // [0, 1, 2] -> clientA, [3, 4, 5] -> clientB - // test - int maxTasks = 6; - int numTasksPerClient = 3; - - StreamingClientManager manager = new StreamingClientManager(); - 
manager.createAllStreamingClients( - TestUtils.getConfForStreaming(), "testkcid", maxTasks, numTasksPerClient); - - // verify - KcStreamingIngestClient task0Client = manager.getValidClient(0); - KcStreamingIngestClient task1Client = manager.getValidClient(1); - KcStreamingIngestClient task2Client = manager.getValidClient(2); - assert task0Client.equals(task1Client); - assert task1Client.equals(task2Client); - - KcStreamingIngestClient task3Client = manager.getValidClient(3); - KcStreamingIngestClient task4Client = manager.getValidClient(4); - KcStreamingIngestClient task5Client = manager.getValidClient(5); - assert task3Client.equals(task4Client); - assert task4Client.equals(task5Client); - - assert !task0Client.equals(task5Client); - - assert Math.ceil((double) maxTasks / (double) numTasksPerClient) == manager.getClientCount(); - - // close clients - task0Client.close(); - task1Client.close(); - task2Client.close(); - task3Client.close(); - task4Client.close(); - task5Client.close(); - } - - @Test - public void testCloseAllStreamingClients() { - // test to close the following mapping - // [0, 1] -> clientA, [2, 3] -> clientB, [4] -> clientC - KcStreamingIngestClient task01Client = Mockito.mock(KcStreamingIngestClient.class); - KcStreamingIngestClient task23Client = Mockito.mock(KcStreamingIngestClient.class); - KcStreamingIngestClient task4Client = Mockito.mock(KcStreamingIngestClient.class); - - Mockito.when(task01Client.close()).thenReturn(true); - Mockito.when(task23Client.close()).thenReturn(true); - Mockito.when(task4Client.close()).thenReturn(true); - - Map taskToClientMap = new HashMap<>(); - taskToClientMap.put(0, task01Client); - taskToClientMap.put(1, task01Client); - taskToClientMap.put(2, task23Client); - taskToClientMap.put(3, task23Client); - taskToClientMap.put(4, task4Client); - - StreamingClientManager manager = new StreamingClientManager(taskToClientMap); - - // test - assert manager.closeAllStreamingClients(); - - // verify - Mockito.verify(task01Client, Mockito.times(2)).close(); - Mockito.verify(task23Client, Mockito.times(2)).close(); - Mockito.verify(task4Client, Mockito.times(1)).close(); - } - - @Test - public void testGetClosedClient() { - int taskId = 0; - Map taskToClientMap = new HashMap<>(); - - KcStreamingIngestClient mockClient = Mockito.mock(KcStreamingIngestClient.class); - Mockito.when(mockClient.isClosed()).thenReturn(true); - taskToClientMap.put(taskId, mockClient); - - StreamingClientManager manager = new StreamingClientManager(taskToClientMap); - TestUtils.assertError(SnowflakeErrors.ERROR_3009, () -> manager.getValidClient(taskId)); - - Mockito.verify(mockClient, Mockito.times(1)).isClosed(); - } - - @Test - public void testGetNullClient() { - int taskId = 0; - Map taskToClientMap = new HashMap<>(); - taskToClientMap.put(taskId, null); - - StreamingClientManager manager = new StreamingClientManager(taskToClientMap); - TestUtils.assertError(SnowflakeErrors.ERROR_3009, () -> manager.getValidClient(taskId)); - } - - @Test - public void testGetInvalidTaskId() { - int maxTasks = 5; - - // create with max task id 4 (starts from 0) - StreamingClientManager manager = new StreamingClientManager(); - manager.createAllStreamingClients(TestUtils.getConfForStreaming(), "testkcid", maxTasks, 2); - - // test throws error - TestUtils.assertError(SnowflakeErrors.ERROR_3010, () -> manager.getValidClient(-1)); - TestUtils.assertError(SnowflakeErrors.ERROR_3010, () -> manager.getValidClient(maxTasks)); - - // verify can get a client - assert manager.getValidClient(0) 
!= null; - } - - @Test - public void testGetUnInitClient() { - int taskId = 0; - Map taskToClientMap = new HashMap<>(); - - StreamingClientManager manager = new StreamingClientManager(taskToClientMap); - TestUtils.assertError(SnowflakeErrors.ERROR_3009, () -> manager.getValidClient(taskId)); - } -} diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java index cb9150d68..a218b51b4 100644 --- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java +++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java @@ -1,6 +1,7 @@ package com.snowflake.kafka.connector.internal.streaming; import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; +import com.snowflake.kafka.connector.Utils; import com.snowflake.kafka.connector.dlq.InMemoryKafkaRecordErrorReporter; import com.snowflake.kafka.connector.internal.SchematizationTestUtils; import com.snowflake.kafka.connector.internal.SnowflakeConnectionService; @@ -8,7 +9,6 @@ import com.snowflake.kafka.connector.internal.SnowflakeSinkService; import com.snowflake.kafka.connector.internal.SnowflakeSinkServiceFactory; import com.snowflake.kafka.connector.internal.TestUtils; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; import com.snowflake.kafka.connector.records.SnowflakeConverter; import com.snowflake.kafka.connector.records.SnowflakeJsonConverter; import io.confluent.connect.avro.AvroConverter; @@ -25,19 +25,19 @@ import java.util.List; import java.util.Map; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; +import net.snowflake.ingest.utils.SFException; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaAndValue; import org.apache.kafka.connect.data.SchemaBuilder; import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.ConnectException; import org.apache.kafka.connect.json.JsonConverter; import org.apache.kafka.connect.sink.SinkRecord; import org.junit.After; -import org.junit.Before; import org.junit.Test; public class SnowflakeSinkServiceV2IT { - private final String clientName = "testclient"; private SnowflakeConnectionService conn = TestUtils.getConnectionService(); private String table = TestUtils.randomTableName(); @@ -47,30 +47,18 @@ public class SnowflakeSinkServiceV2IT { private TopicPartition topicPartition = new TopicPartition(topic, partition); private static ObjectMapper MAPPER = new ObjectMapper(); - private Map config; - - @Before - public void setup() { - // config - this.config = TestUtils.getConfForStreaming(); - SnowflakeSinkConnectorConfig.setDefaultValues(this.config); - - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(config, "testkcid", 1, 1); - } - @After - public void afterEach() throws Exception { + public void afterEach() { TestUtils.dropTable(table); - IngestSdkProvider.setStreamingClientManager( - TestUtils.resetAndGetEmptyStreamingClientManager()); // reset to clean initial manager } @Test public void testSinkServiceV2Builder() { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + 
SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .build(); assert service instanceof SnowflakeSinkServiceV2; @@ -80,7 +68,7 @@ public void testSinkServiceV2Builder() { SnowflakeErrors.ERROR_5010, () -> SnowflakeSinkServiceFactory.builder( - null, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + null, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .build()); assert TestUtils.assertError( SnowflakeErrors.ERROR_5010, @@ -88,7 +76,7 @@ public void testSinkServiceV2Builder() { SnowflakeConnectionService conn = TestUtils.getConnectionService(); conn.close(); SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .build(); }); } @@ -101,8 +89,7 @@ public void testChannelCloseIngestion() throws Exception { // opens a channel for partition 0, table and topic SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -146,8 +133,7 @@ public void testRebalanceOpenCloseIngestion() throws Exception { // opens a channel for partition 0, table and topic SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -192,8 +178,7 @@ public void testStreamingIngestion() throws Exception { // opens a channel for partition 0, table and topic SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -257,8 +242,7 @@ public void testStreamingIngest_multipleChannelPartitions() throws Exception { // opens a channel for partition 0, table and topic SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(5) .setFlushTime(5) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) @@ -309,8 +293,7 @@ public void testStreamingIngestion_timeBased() throws Exception { // opens a channel for partition 0, table and topic SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(100) .setFlushTime(11) // 11 seconds .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) @@ -419,8 +402,7 @@ public void testNativeJsonInputIngestion() throws Exception { startOffset + 3); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + 
SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -583,8 +565,7 @@ public void testNativeAvroInputIngestion() throws Exception { conn.createTable(table); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -644,8 +625,7 @@ public void testBrokenIngestion() throws Exception { InMemoryKafkaRecordErrorReporter errorReporter = new InMemoryKafkaRecordErrorReporter(); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(errorReporter) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -692,8 +672,7 @@ public void testBrokenRecordIngestionFollowedUpByValidRecord() throws Exception InMemoryKafkaRecordErrorReporter errorReporter = new InMemoryKafkaRecordErrorReporter(); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setErrorReporter(errorReporter) .setRecordNumber(recordCount) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -751,8 +730,7 @@ public void testBrokenRecordIngestionAfterValidRecord() throws Exception { InMemoryKafkaRecordErrorReporter errorReporter = new InMemoryKafkaRecordErrorReporter(); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setErrorReporter(errorReporter) .setRecordNumber(recordCount) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -775,8 +753,23 @@ public void testBrokenRecordIngestionAfterValidRecord() throws Exception { service.closeAll(); } - /* Service start -> Insert -> Close. service start -> fetch the offsetToken, compare and ingest - check data */ + @Test(expected = ConnectException.class) + public void testMissingPropertiesForStreamingClient() { + Map config = TestUtils.getConfForStreaming(); + config.remove(Utils.SF_ROLE); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + + try { + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) + .build(); + } catch (ConnectException ex) { + assert ex.getCause() instanceof SFException; + assert ex.getCause().getMessage().contains("Missing role"); + throw ex; + } + } + + /* Service start -> Insert -> Close. 
service start -> fetch the offsetToken, compare and ingest check data */ @Test public void testStreamingIngestionWithExactlyOnceSemanticsNoOverlappingOffsets() @@ -785,8 +778,7 @@ public void testStreamingIngestionWithExactlyOnceSemanticsNoOverlappingOffsets() Map config = TestUtils.getConfForStreaming(); SnowflakeSinkConnectorConfig.setDefaultValues(config); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -814,8 +806,7 @@ public void testStreamingIngestionWithExactlyOnceSemanticsNoOverlappingOffsets() // initialize a new sink service SnowflakeSinkService service2 = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -839,8 +830,7 @@ public void testStreamingIngestionWithExactlyOnceSemanticsNoOverlappingOffsets() service2.closeAll(); } - /* Service start -> Insert -> Close. service start -> fetch the offsetToken, compare and ingest - check data */ + /* Service start -> Insert -> Close. service start -> fetch the offsetToken, compare and ingest check data */ @Test public void testStreamingIngestionWithExactlyOnceSemanticsOverlappingOffsets() throws Exception { @@ -848,8 +838,7 @@ public void testStreamingIngestionWithExactlyOnceSemanticsOverlappingOffsets() t Map config = TestUtils.getConfForStreaming(); SnowflakeSinkConnectorConfig.setDefaultValues(config); SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) @@ -878,8 +867,7 @@ public void testStreamingIngestionWithExactlyOnceSemanticsOverlappingOffsets() t // initialize a new sink service SnowflakeSinkService service2 = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(1) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition))) diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkTaskTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkTaskTest.java deleted file mode 100644 index 429130a67..000000000 --- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkTaskTest.java +++ /dev/null @@ -1,146 +0,0 @@ -package com.snowflake.kafka.connector.internal.streaming; - -import com.snowflake.kafka.connector.SnowflakeSinkTask; -import com.snowflake.kafka.connector.SnowflakeSinkTaskTestForStreamingIT; -import com.snowflake.kafka.connector.internal.LoggerHandler; -import com.snowflake.kafka.connector.internal.TestUtils; 
-import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.kafka.clients.consumer.OffsetAndMetadata; -import org.apache.kafka.common.TopicPartition; -import org.apache.kafka.connect.sink.SinkRecord; -import org.junit.Ignore; -import org.junit.Test; -import org.mockito.AdditionalMatchers; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.Spy; -import org.slf4j.Logger; - -public class SnowflakeSinkTaskTest { - - // JUST FOR LOGGING TESTING, these should not be used anywhere else - @Mock private Logger logger = Mockito.mock(Logger.class); - - @InjectMocks @Spy - private LoggerHandler loggerHandler = Mockito.spy(new LoggerHandler(this.getClass().getName())); - - @InjectMocks private SnowflakeSinkTask sinkTask1 = new SnowflakeSinkTask(); - - @Ignore - @Test - public void testMultipleSinkTaskWithLogs() throws Exception { - // setup task0, the real one - int taskId0 = 0; - String topicName0 = "topicName0"; - SnowflakeSinkTask sinkTask0 = new SnowflakeSinkTask(); - Map config0 = SnowflakeSinkTaskTestForStreamingIT.getConfig(taskId0); - List topicPartitions0 = - SnowflakeSinkTaskTestForStreamingIT.getTopicPartitions(topicName0, 1); - InMemorySinkTaskContext sinkTaskContext0 = - new InMemorySinkTaskContext(Collections.singleton(topicPartitions0.get(0))); - List records0 = TestUtils.createJsonStringSinkRecords(0, 1, topicName0, 0); - Map offsetMap0 = new HashMap<>(); - offsetMap0.put(topicPartitions0.get(0), new OffsetAndMetadata(10000)); - - // set up task1, the !real one (logging verification one) - // basically the same as sinktask0 except its logger is mocked - int taskId1 = 1; - String topicName1 = "topicName1"; - Map config1 = SnowflakeSinkTaskTestForStreamingIT.getConfig(taskId1); - List topicPartitions1 = - SnowflakeSinkTaskTestForStreamingIT.getTopicPartitions(topicName1, 1); - InMemorySinkTaskContext sinkTaskContext1 = - new InMemorySinkTaskContext(Collections.singleton(topicPartitions1.get(0))); - List records1 = TestUtils.createJsonStringSinkRecords(0, 1, topicName1, 0); - Map offsetMap1 = new HashMap<>(); - offsetMap1.put(topicPartitions1.get(0), new OffsetAndMetadata(10000)); - // task1 logging - int task1OpenCount = 0; - MockitoAnnotations.initMocks(this); - Mockito.when(logger.isInfoEnabled()).thenReturn(true); - Mockito.when(logger.isDebugEnabled()).thenReturn(true); - Mockito.when(logger.isWarnEnabled()).thenReturn(true); - String expectedTask1Tag = - TestUtils.getExpectedLogTagWithoutCreationCount(taskId1 + "", task1OpenCount); - Mockito.doCallRealMethod().when(loggerHandler).setLoggerInstanceTag(expectedTask1Tag); - - // set up two clients - IngestSdkProvider.setStreamingClientManager(TestUtils.resetAndGetEmptyStreamingClientManager()); - IngestSdkProvider.getStreamingClientManager().createAllStreamingClients(config0, "kcid", 2, 1); - assert IngestSdkProvider.getStreamingClientManager().getClientCount() == 2; - - // init tasks - sinkTask0.initialize(sinkTaskContext0); - sinkTask1.initialize(sinkTaskContext1); - - // start tasks - sinkTask0.start(config0); - sinkTask1.start(config1); - - // verify task1 start logs - Mockito.verify(loggerHandler, Mockito.times(1)) - .setLoggerInstanceTag(Mockito.contains(expectedTask1Tag)); - Mockito.verify(logger, Mockito.times(2)) - .debug( - AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), 
Mockito.contains("start"))); - - // open tasks - sinkTask0.open(topicPartitions0); - sinkTask1.open(topicPartitions1); - - // verify task1 open logs - task1OpenCount++; - expectedTask1Tag = - TestUtils.getExpectedLogTagWithoutCreationCount(taskId1 + "", task1OpenCount); - Mockito.verify(logger, Mockito.times(1)) - .debug( - AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("open"))); - - // send data to tasks - sinkTask0.put(records0); - sinkTask1.put(records1); - - // verify task1 put logs - Mockito.verify(logger, Mockito.times(1)) - .debug(AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("put"))); - - // commit offsets - TestUtils.assertWithRetry(() -> sinkTask0.preCommit(offsetMap0).size() == 1, 20, 5); - TestUtils.assertWithRetry(() -> sinkTask1.preCommit(offsetMap1).size() == 1, 20, 5); - - // verify task1 precommit logs - Mockito.verify(logger, Mockito.times(1)) - .debug( - AdditionalMatchers.and( - Mockito.contains(expectedTask1Tag), Mockito.contains("precommit"))); - - TestUtils.assertWithRetry( - () -> sinkTask0.preCommit(offsetMap0).get(topicPartitions0.get(0)).offset() == 1, 20, 5); - TestUtils.assertWithRetry( - () -> sinkTask1.preCommit(offsetMap1).get(topicPartitions1.get(0)).offset() == 1, 20, 5); - - // close tasks - sinkTask0.close(topicPartitions0); - sinkTask1.close(topicPartitions1); - - // verify task1 close logs - Mockito.verify(logger, Mockito.times(1)) - .debug( - AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("closed"))); - - // stop tasks - sinkTask0.stop(); - sinkTask1.stop(); - - // verify task1 stop logs - Mockito.verify(logger, Mockito.times(1)) - .debug( - AdditionalMatchers.and(Mockito.contains(expectedTask1Tag), Mockito.contains("stop"))); - } -} diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelIT.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelIT.java index 9b75219a7..d0994c5a7 100644 --- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelIT.java +++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelIT.java @@ -1,17 +1,18 @@ package com.snowflake.kafka.connector.internal.streaming; import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; +import com.snowflake.kafka.connector.Utils; import com.snowflake.kafka.connector.dlq.InMemoryKafkaRecordErrorReporter; import com.snowflake.kafka.connector.internal.SnowflakeConnectionService; import com.snowflake.kafka.connector.internal.SnowflakeSinkService; import com.snowflake.kafka.connector.internal.SnowflakeSinkServiceFactory; import com.snowflake.kafka.connector.internal.TestUtils; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.connect.sink.SinkRecord; import org.junit.After; @@ -20,7 +21,6 @@ import org.junit.Test; public class TopicPartitionChannelIT { - private final String clientName = "testclient"; private SnowflakeConnectionService conn = TestUtils.getConnectionService(); private String testTableName; @@ -30,35 +30,29 @@ public class TopicPartitionChannelIT { private TopicPartition topicPartition, 
topicPartition2; private String testChannelName, testChannelName2; - private Map config; - @Before public void beforeEach() { testTableName = TestUtils.randomTableName(); topic = testTableName; - topicPartition = new TopicPartition(topic, PARTITION); + topicPartition2 = new TopicPartition(topic, PARTITION_2); testChannelName = SnowflakeSinkServiceV2.partitionChannelKey(topic, PARTITION); - testChannelName2 = SnowflakeSinkServiceV2.partitionChannelKey(topic, PARTITION_2); - - this.config = TestUtils.getConfForStreaming(); - SnowflakeSinkConnectorConfig.setDefaultValues(this.config); - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(this.config, "testkcid", 2, 1); + testChannelName2 = SnowflakeSinkServiceV2.partitionChannelKey(topic, PARTITION_2); } @After - public void afterEach() throws Exception { + public void afterEach() { TestUtils.dropTable(testTableName); - IngestSdkProvider.setStreamingClientManager( - TestUtils.resetAndGetEmptyStreamingClientManager()); // reset to clean initial manager } @Test public void testAutoChannelReopenOn_OffsetTokenSFException() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + InMemorySinkTaskContext inMemorySinkTaskContext = new InMemorySinkTaskContext(Collections.singleton(topicPartition)); @@ -87,14 +81,14 @@ public void testAutoChannelReopenOn_OffsetTokenSFException() throws Exception { // Ctor of TopicPartitionChannel tries to open the channel. TopicPartitionChannel channel = new TopicPartitionChannel( + snowflakeSinkServiceV2.getStreamingIngestClient(), topicPartition, testChannelName, testTableName, new StreamingBufferThreshold(10, 10_000, 1), config, new InMemoryKafkaRecordErrorReporter(), - new InMemorySinkTaskContext(Collections.singleton(topicPartition)), - conn); + new InMemorySinkTaskContext(Collections.singleton(topicPartition))); // since channel is updated, try to insert data again or may be call getOffsetToken // We will reopen the channel in since the older channel in service is stale because we @@ -109,6 +103,9 @@ public void testAutoChannelReopenOn_OffsetTokenSFException() throws Exception { /* This will automatically open the channel. */ @Test public void testInsertRowsOnChannelClosed() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + InMemorySinkTaskContext inMemorySinkTaskContext = new InMemorySinkTaskContext(Collections.singleton(topicPartition)); @@ -164,6 +161,9 @@ public void testInsertRowsOnChannelClosed() throws Exception { */ @Test public void testAutoChannelReopen_InsertRowsSFException() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + InMemorySinkTaskContext inMemorySinkTaskContext = new InMemorySinkTaskContext(Collections.singleton(topicPartition)); @@ -224,7 +224,7 @@ public void testAutoChannelReopen_InsertRowsSFException() throws Exception { } /** - * Two partitions for a topic Partition 1 -> 10(0-9) records -> Success Partition 2 -> 10(0-9) + * Two partions for a topic Partition 1 -> 10(0-9) records -> Success Partition 2 -> 10(0-9) * records -> Success * *

Partition 1 -> Channel 1 -> open with same client Client sequencer for channel 1 - 1 @@ -238,13 +238,15 @@ public void testAutoChannelReopen_InsertRowsSFException() throws Exception { */ @Test public void testAutoChannelReopen_MultiplePartitionsInsertRowsSFException() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + InMemorySinkTaskContext inMemorySinkTaskContext = new InMemorySinkTaskContext(Collections.singleton(topicPartition)); // This will automatically create a channel for topicPartition. SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder( - conn, IngestionMethodConfig.SNOWPIPE_STREAMING, this.config) + SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) .setRecordNumber(5) .setFlushTime(5) .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) @@ -276,10 +278,18 @@ public void testAutoChannelReopen_MultiplePartitionsInsertRowsSFException() thro 20, 5); + SnowflakeStreamingIngestClient client = + ((SnowflakeSinkServiceV2) service).getStreamingIngestClient(); + OpenChannelRequest channelRequest = + OpenChannelRequest.builder(testChannelName) + .setDBName(config.get(Utils.SF_DATABASE)) + .setSchemaName(config.get(Utils.SF_SCHEMA)) + .setTableName(this.testTableName) + .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE) + .build(); + // Open a channel with same name will bump up the client sequencer number for this channel - IngestSdkProvider.getStreamingClientManager() - .getValidClient(0) - .openChannel(testChannelName, config, testTableName); + client.openChannel(channelRequest); assert TestUtils.getClientSequencerForChannelAndTable(testTableName, testChannelName) == 1; @@ -334,6 +344,9 @@ public void testAutoChannelReopen_MultiplePartitionsInsertRowsSFException() thro @Test public void testAutoChannelReopen_SinglePartitionsInsertRowsSFException() throws Exception { + Map config = TestUtils.getConfForStreaming(); + SnowflakeSinkConnectorConfig.setDefaultValues(config); + InMemorySinkTaskContext inMemorySinkTaskContext = new InMemorySinkTaskContext(Collections.singleton(topicPartition)); @@ -360,9 +373,16 @@ public void testAutoChannelReopen_SinglePartitionsInsertRowsSFException() throws 20, 5); - IngestSdkProvider.getStreamingClientManager() - .getValidClient(0) - .openChannel(testChannelName, config, testTableName); + SnowflakeStreamingIngestClient client = + ((SnowflakeSinkServiceV2) service).getStreamingIngestClient(); + OpenChannelRequest channelRequest = + OpenChannelRequest.builder(testChannelName) + .setDBName(config.get(Utils.SF_DATABASE)) + .setSchemaName(config.get(Utils.SF_SCHEMA)) + .setTableName(this.testTableName) + .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE) + .build(); + client.openChannel(channelRequest); Thread.sleep(5_000); // send offset 10 - 19 @@ -388,41 +408,4 @@ public void testAutoChannelReopen_SinglePartitionsInsertRowsSFException() throws assert TestUtils.getOffsetTokenForChannelAndTable(testTableName, testChannelName) == (recordsInPartition1 + anotherSetOfRecords - 1); } - - /* This will automatically open the channel. */ - @Test - public void testSimpleInsertRowsWithArrowBDECFormat() throws Exception { - // Wipe off existing clients. 
- IngestSdkProvider.setStreamingClientManager( - TestUtils.resetAndGetEmptyStreamingClientManager()); // reset to clean initial manager - - // add config which overrides the bdec file format - Map overriddenConfig = new HashMap<>(this.config); - overriddenConfig.put(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, "1"); - IngestSdkProvider.getStreamingClientManager() - .createAllStreamingClients(overriddenConfig, "testkcid", 1, 1); - - InMemorySinkTaskContext inMemorySinkTaskContext = - new InMemorySinkTaskContext(Collections.singleton(topicPartition)); - - // This will automatically create a channel for topicPartition. - SnowflakeSinkService service = - SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config) - .setRecordNumber(1) - .setErrorReporter(new InMemoryKafkaRecordErrorReporter()) - .setSinkTaskContext(inMemorySinkTaskContext) - .addTask(testTableName, topicPartition) - .build(); - - final long noOfRecords = 1; - - // send regular data - List records = - TestUtils.createJsonStringSinkRecords(0, noOfRecords, topic, PARTITION); - - service.insert(records); - - TestUtils.assertWithRetry( - () -> service.getOffset(new TopicPartition(topic, PARTITION)) == noOfRecords, 20, 5); - } } diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java index 34139f6e9..26655feff 100644 --- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java +++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java @@ -8,13 +8,11 @@ import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.MAX_GET_OFFSET_TOKEN_RETRIES; import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; +import com.snowflake.kafka.connector.dlq.InMemoryKafkaRecordErrorReporter; import com.snowflake.kafka.connector.dlq.KafkaRecordErrorReporter; +import com.snowflake.kafka.connector.internal.BufferThreshold; import com.snowflake.kafka.connector.internal.SnowflakeConnectionService; -import com.snowflake.kafka.connector.internal.SnowflakeErrors; import com.snowflake.kafka.connector.internal.TestUtils; -import com.snowflake.kafka.connector.internal.ingestsdk.IngestSdkProvider; -import com.snowflake.kafka.connector.internal.ingestsdk.KcStreamingIngestClient; -import com.snowflake.kafka.connector.internal.ingestsdk.StreamingClientManager; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; @@ -24,7 +22,9 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import net.snowflake.ingest.streaming.InsertValidationResponse; +import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import org.apache.kafka.common.TopicPartition; @@ -35,7 +35,6 @@ import org.apache.kafka.connect.json.JsonConverter; import org.apache.kafka.connect.sink.SinkRecord; import org.apache.kafka.connect.sink.SinkTaskContext; -import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -47,34 +46,32 @@ @RunWith(Parameterized.class) public class TopicPartitionChannelTest { + @Mock private KafkaRecordErrorReporter mockKafkaRecordErrorReporter; + + @Mock 
private SnowflakeStreamingIngestClient mockStreamingClient; + @Mock private SnowflakeStreamingIngestChannel mockStreamingChannel; + @Mock private SinkTaskContext mockSinkTaskContext; - @Mock private KcStreamingIngestClient mockKcStreamingIngestClient; - @Mock private StreamingClientManager mockStreamingClientManager; - @Mock private SnowflakeConnectionService mockConn; - @Mock private StreamingBufferThreshold mockStreamingBufferThreshold; - // constants private static final String TOPIC = "TEST"; + private static final int PARTITION = 0; + private static final String TEST_CHANNEL_NAME = SnowflakeSinkServiceV2.partitionChannelKey(TOPIC, PARTITION); private static final String TEST_TABLE_NAME = "TEST_TABLE"; - private static final long INVALID_OFFSET_VALUE = -1L; - - private final boolean enableSchematization; - private final int taskId = 0; - // models private TopicPartition topicPartition; + private Map sfConnectorConfig; + + private BufferThreshold streamingBufferThreshold; + private SFException SF_EXCEPTION = new SFException(ErrorCode.INVALID_CHANNEL, "INVALID_CHANNEL"); - // expected mock verification count - private int expectedCallGetValidClientCount = 1; - private int expectedCallOpenChannelCount = 1; - private int expectedCallGetTaskIdCount = 1; + private final boolean enableSchematization; public TopicPartitionChannelTest(boolean enableSchematization) { this.enableSchematization = enableSchematization; @@ -87,73 +84,34 @@ public static Collection input() { @Before public void setupEachTest() { - // recreate mocks - this.mockStreamingChannel = Mockito.mock(SnowflakeStreamingIngestChannel.class); - this.mockKafkaRecordErrorReporter = Mockito.mock(KafkaRecordErrorReporter.class); - this.mockSinkTaskContext = Mockito.mock(SinkTaskContext.class); - this.mockKcStreamingIngestClient = Mockito.mock(KcStreamingIngestClient.class); - this.mockStreamingClientManager = Mockito.mock(StreamingClientManager.class); - this.mockConn = Mockito.mock(SnowflakeConnectionService.class); - this.mockStreamingBufferThreshold = Mockito.mock(StreamingBufferThreshold.class); - - // sunny mock interactions and verifications - Mockito.when(this.mockStreamingClientManager.getValidClient(this.taskId)) - .thenReturn(this.mockKcStreamingIngestClient); - Mockito.when( - this.mockKcStreamingIngestClient.openChannel( - Mockito.refEq(TEST_CHANNEL_NAME), - ArgumentMatchers.any(Map.class), - Mockito.refEq(TEST_TABLE_NAME))) - .thenReturn(this.mockStreamingChannel); - Mockito.when(this.mockConn.getTaskId()).thenReturn(this.taskId); - expectedCallGetValidClientCount = 1; - expectedCallOpenChannelCount = 1; - expectedCallGetTaskIdCount = 1; - + mockStreamingClient = Mockito.mock(SnowflakeStreamingIngestClient.class); + mockStreamingChannel = Mockito.mock(SnowflakeStreamingIngestChannel.class); + mockKafkaRecordErrorReporter = Mockito.mock(KafkaRecordErrorReporter.class); + mockSinkTaskContext = Mockito.mock(SinkTaskContext.class); + Mockito.when(mockStreamingClient.isClosed()).thenReturn(false); + Mockito.when(mockStreamingClient.openChannel(ArgumentMatchers.any(OpenChannelRequest.class))) + .thenReturn(mockStreamingChannel); + Mockito.when(mockStreamingChannel.getFullyQualifiedName()).thenReturn(TEST_CHANNEL_NAME); this.topicPartition = new TopicPartition(TOPIC, PARTITION); this.sfConnectorConfig = TestUtils.getConfig(); + this.streamingBufferThreshold = new StreamingBufferThreshold(10, 10_000, 1); this.sfConnectorConfig.put( SnowflakeSinkConnectorConfig.ENABLE_SCHEMATIZATION_CONFIG, 
Boolean.toString(this.enableSchematization)); - - IngestSdkProvider.setStreamingClientManager(this.mockStreamingClientManager); - } - - @After - public void afterEachTest() { - // need to reset client manager since it is global static variable - IngestSdkProvider.setStreamingClientManager(TestUtils.resetAndGetEmptyStreamingClientManager()); - - // verify the mocks setup above - Mockito.verify(this.mockStreamingClientManager, Mockito.times(expectedCallGetValidClientCount)) - .getValidClient(this.taskId); - Mockito.verify(this.mockKcStreamingIngestClient, Mockito.times(expectedCallOpenChannelCount)) - .openChannel( - Mockito.refEq(TEST_CHANNEL_NAME), - ArgumentMatchers.any(Map.class), - Mockito.refEq(TEST_TABLE_NAME)); - Mockito.verify(this.mockConn, Mockito.times(expectedCallGetTaskIdCount)).getTaskId(); } - @Test + @Test(expected = IllegalStateException.class) public void testTopicPartitionChannelInit_streamingClientClosed() { - Mockito.when( - this.mockStreamingClientManager.getValidClient(ArgumentMatchers.any(Integer.class))) - .thenThrow(SnowflakeErrors.ERROR_3009.getException()); - this.expectedCallOpenChannelCount = 0; // constructor fails before open - - TestUtils.assertError( - SnowflakeErrors.ERROR_3009, - () -> - new TopicPartitionChannel( - topicPartition, - TEST_CHANNEL_NAME, - TEST_TABLE_NAME, - mockStreamingBufferThreshold, - sfConnectorConfig, - mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn)); + Mockito.when(mockStreamingClient.isClosed()).thenReturn(true); + new TopicPartitionChannel( + mockStreamingClient, + topicPartition, + TEST_CHANNEL_NAME, + TEST_TABLE_NAME, + streamingBufferThreshold, + sfConnectorConfig, + mockKafkaRecordErrorReporter, + mockSinkTaskContext); } @Test @@ -162,62 +120,57 @@ public void testFetchOffsetTokenWithRetry_null() { TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); Assert.assertEquals(-1L, topicPartitionChannel.fetchOffsetTokenWithRetry()); - - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); } @Test public void testFetchOffsetTokenWithRetry_validLong() { + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenReturn("100"); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); Assert.assertEquals(100L, topicPartitionChannel.fetchOffsetTokenWithRetry()); - - Mockito.verify(mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); } + // TODO:: Fix this test @Test public void testFirstRecordForChannel() { Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenReturn(null); + Mockito.when( mockStreamingChannel.insertRows( ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) .thenReturn(new InsertValidationResponse()); - Mockito.when( - this.mockStreamingBufferThreshold.isFlushBufferedBytesBased( - ArgumentMatchers.any(Long.class))) - .thenReturn(true); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, 
- mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); JsonConverter converter = new JsonConverter(); HashMap converterConfig = new HashMap(); @@ -240,13 +193,8 @@ public void testFirstRecordForChannel() { topicPartitionChannel.insertRecordToBuffer(record1); Assert.assertEquals(-1l, topicPartitionChannel.getOffsetPersistedInSnowflake()); - Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); - Mockito.verify(mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); - Mockito.verify(mockStreamingChannel, Mockito.times(1)) - .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); - Mockito.verify(this.mockStreamingBufferThreshold, Mockito.times(1)) - .isFlushBufferedBytesBased(ArgumentMatchers.any(Long.class)); + Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); } @Test @@ -255,62 +203,58 @@ public void testCloseChannelException() throws Exception { Mockito.when(mockStreamingChannel.close()).thenReturn(mockFuture); Mockito.when(mockStreamingChannel.getFullyQualifiedName()).thenReturn(TEST_CHANNEL_NAME); - Mockito.when(mockFuture.get()).thenThrow(new InterruptedException("Interrupted Exception")); + Mockito.when(mockFuture.get()).thenThrow(new InterruptedException("Interrupted Exception")); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); topicPartitionChannel.closeChannel(); - - Mockito.verify(mockStreamingChannel, Mockito.times(1)).close(); - Mockito.verify(mockStreamingChannel, Mockito.times(1)).getFullyQualifiedName(); - Mockito.verify(mockFuture, Mockito.times(1)).get(); } - @Test + /* Only SFExceptions are retried and goes into fallback. 
*/ + @Test(expected = SFException.class) public void testFetchOffsetTokenWithRetry_SFException() { - Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenThrow(SF_EXCEPTION); + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()) + .thenThrow(SF_EXCEPTION) + .thenThrow(SF_EXCEPTION) + .thenThrow(SF_EXCEPTION) + .thenThrow(SF_EXCEPTION); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); - long fetchedOffset; try { - this.expectedCallOpenChannelCount++; // retry getting offset reopens channel - fetchedOffset = topicPartitionChannel.fetchOffsetTokenWithRetry(); + Assert.assertEquals(-1L, topicPartitionChannel.fetchOffsetTokenWithRetry()); } catch (SFException ex) { - fetchedOffset = INVALID_OFFSET_VALUE; + Mockito.verify(mockStreamingClient, Mockito.times(2)).openChannel(ArgumentMatchers.any()); + Mockito.verify( + topicPartitionChannel.getChannel(), Mockito.times(MAX_GET_OFFSET_TOKEN_RETRIES + 1)) + .getLatestCommittedOffsetToken(); + throw ex; } - - assert fetchedOffset == INVALID_OFFSET_VALUE; - - Mockito.verify(this.mockStreamingChannel, Mockito.times(MAX_GET_OFFSET_TOKEN_RETRIES + 1)) - .getLatestCommittedOffsetToken(); } - /* SFExceptions are retried and goes into fallback where it will reopen the channel and return a - 0 offsetToken */ + /* SFExceptions are retried and goes into fallback where it will reopen the channel and return a 0 offsetToken */ @Test - public void testFetchOffsetTokenWithRetry_validOffsetTokenAfterMaxRetrySFExceptions() { + public void testFetchOffsetTokenWithRetry_validOffsetTokenAfterThreeSFExceptions() { final String offsetTokenAfterMaxAttempts = "0"; - // max retry is currently 3, so throw on first 3 and return correct on last retry Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()) .thenThrow(SF_EXCEPTION) .thenThrow(SF_EXCEPTION) @@ -319,199 +263,168 @@ public void testFetchOffsetTokenWithRetry_validOffsetTokenAfterMaxRetrySFExcepti TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); Assert.assertEquals( Long.parseLong(offsetTokenAfterMaxAttempts), topicPartitionChannel.fetchOffsetTokenWithRetry()); - this.expectedCallOpenChannelCount++; // retry getting offset reopens channel - - Mockito.verify(this.mockStreamingChannel, Mockito.times(MAX_GET_OFFSET_TOKEN_RETRIES + 1)) + Mockito.verify(mockStreamingClient, Mockito.times(2)).openChannel(ArgumentMatchers.any()); + Mockito.verify( + topicPartitionChannel.getChannel(), Mockito.times(MAX_GET_OFFSET_TOKEN_RETRIES + 1)) .getLatestCommittedOffsetToken(); } /* No retries are since it throws NumberFormatException */ - @Test + @Test(expected = ConnectException.class) public void testFetchOffsetTokenWithRetry_InvalidNumber() { + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenReturn("invalidNo"); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, 
mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); - long fetchedOffset; try { - fetchedOffset = topicPartitionChannel.fetchOffsetTokenWithRetry(); + topicPartitionChannel.fetchOffsetTokenWithRetry(); + Assert.fail("Should throw exception"); } catch (ConnectException exception) { - fetchedOffset = INVALID_OFFSET_VALUE; + // Open channel is not called again. + Mockito.verify(mockStreamingClient, Mockito.times(1)).openChannel(ArgumentMatchers.any()); + + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) + .getLatestCommittedOffsetToken(); Assert.assertTrue(exception.getMessage().contains("invalidNo")); + throw exception; } - - assert fetchedOffset == INVALID_OFFSET_VALUE; - - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); } /* No retries and fallback here too since it throws an unknown NPE. */ - @Test + @Test(expected = NullPointerException.class) public void testFetchOffsetTokenWithRetry_NullPointerException() { - NullPointerException npe = new NullPointerException("NPE"); - Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenThrow(npe); + NullPointerException exception = new NullPointerException("NPE"); + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenThrow(exception); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); - long fetchedOffset; try { - fetchedOffset = topicPartitionChannel.fetchOffsetTokenWithRetry(); + Assert.assertEquals(-1L, topicPartitionChannel.fetchOffsetTokenWithRetry()); } catch (NullPointerException ex) { - fetchedOffset = INVALID_OFFSET_VALUE; - assert ex.getMessage().equals(npe.getMessage()); + Mockito.verify(mockStreamingClient, Mockito.times(1)).openChannel(ArgumentMatchers.any()); + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) + .getLatestCommittedOffsetToken(); + throw ex; } - - assert fetchedOffset == INVALID_OFFSET_VALUE; - - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); } - /* No retries and fallback here too since it throws an unknown runtime exception. */ + /* No retries and fallback here too since it throws an unknown NPE. 
*/ @Test(expected = RuntimeException.class) public void testFetchOffsetTokenWithRetry_RuntimeException() { - RuntimeException runtimeException = new RuntimeException("runtime exception"); - Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenThrow(runtimeException); + RuntimeException exception = new RuntimeException("runtime exception"); + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenThrow(exception); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); try { Assert.assertEquals(-1L, topicPartitionChannel.fetchOffsetTokenWithRetry()); } catch (RuntimeException ex) { + Mockito.verify(mockStreamingClient, Mockito.times(1)).openChannel(ArgumentMatchers.any()); + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) + .getLatestCommittedOffsetToken(); + throw ex; } - - long fetchedOffset; - try { - fetchedOffset = topicPartitionChannel.fetchOffsetTokenWithRetry(); - } catch (NullPointerException ex) { - fetchedOffset = INVALID_OFFSET_VALUE; - assert ex.getMessage().equals(runtimeException.getMessage()); - } - - assert fetchedOffset == INVALID_OFFSET_VALUE; - - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); } - /* - try insert rows twice, first will fail, second reopens channel and succeeds - first insertrows: - 1. precompute offset because new channel - got null offset - 2. try insert, fail - throw sfexception - 3. reopen channel - 4. get offset token - null offset - second insert rows: - 1. try insert, succeed - */ + /* Only SFExceptions goes into fallback -> reopens channel, fetch offsetToken and throws Appropriate exception */ @Test public void testInsertRows_SuccessAfterReopenChannel() throws Exception { - final int noOfRecords = 5; - int expectedCallInsertRowCount = 0; - int expectedCallGetOffsetCount = 0; - - // first insert fails, so first offset response is null - // second insert succeeds, and offset is bumped accordingly Mockito.when( mockStreamingChannel.insertRows( ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) - .thenThrow(SF_EXCEPTION) - .thenReturn(new InsertValidationResponse()); - // get offset token is called twice - after channel re-open and before a new partition is just - // created (In Precomputation). 
So first two returns are the failure, second two are the success - Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()) - .thenReturn(null) - .thenReturn(null) - .thenReturn(Long.toString(noOfRecords - 1)) - .thenReturn(Long.toString(noOfRecords - 1)); - Mockito.when( - mockStreamingBufferThreshold.isFlushBufferedBytesBased( - ArgumentMatchers.any(Long.class))) - .thenReturn(true); + .thenThrow(SF_EXCEPTION); + + // get null from snowflake first time it is called and null for second time too since insert + // rows was failure + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenReturn(null); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); - - // verify channel did nothing - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallInsertRowCount)) - .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallGetOffsetCount)) - .getLatestCommittedOffsetToken(); - - // TEST inserting - should fail + mockSinkTaskContext); + final int noOfRecords = 5; + // Since record 0 was not able to ingest, all records in this batch will not be added into the + // buffer. List records = TestUtils.createJsonStringSinkRecords(0, noOfRecords, TOPIC, PARTITION); + records.forEach(topicPartitionChannel::insertRecordToBuffer); - this.expectedCallOpenChannelCount++; // should reopen channel here - expectedCallInsertRowCount += 1; - expectedCallGetOffsetCount += 2; - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallInsertRowCount)) + Mockito.verify(mockStreamingClient, Mockito.times(2)).openChannel(ArgumentMatchers.any()); + // insert rows is only called once. 
+ Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallGetOffsetCount)) + + // get offset token is called once after channel re-open + once before a new partition is just + // created (In Precomputation) + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(2)) .getLatestCommittedOffsetToken(); - // TEST inserting - should succeed + // Now, it should be successful + Mockito.when( + mockStreamingChannel.insertRows( + ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) + .thenReturn(new InsertValidationResponse()); + + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()) + .thenReturn(Long.toString(noOfRecords - 1)); + + // We will mimick the retry strategy now + // This time since record 0 is again trying to insert, we will call insertFiles noOfRecords + // times records.forEach(topicPartitionChannel::insertRecordToBuffer); - expectedCallInsertRowCount += noOfRecords; - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallInsertRowCount)) + Mockito.verify( + topicPartitionChannel.getChannel(), + Mockito.times(noOfRecords + 1)) // noOfRecords + 1 (before retry) .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallGetOffsetCount)) - .getLatestCommittedOffsetToken(); - // expected number of records were ingested Assert.assertEquals(noOfRecords - 1, topicPartitionChannel.fetchOffsetTokenWithRetry()); - expectedCallGetOffsetCount += 1; // one more get offset call - - Mockito.verify(this.mockStreamingChannel, Mockito.times(expectedCallGetOffsetCount)) - .getLatestCommittedOffsetToken(); - Mockito.verify(this.mockStreamingBufferThreshold, Mockito.times(6)) - .isFlushBufferedBytesBased(ArgumentMatchers.any(Long.class)); } @Test @@ -519,97 +432,78 @@ public void testInsertRowsWithSchemaEvolution() throws Exception { if (this.sfConnectorConfig .get(SnowflakeSinkConnectorConfig.ENABLE_SCHEMATIZATION_CONFIG) .equals("true")) { - int noOfRecords = 0; - - // this response should insert the row - InsertValidationResponse validResponse = new InsertValidationResponse(); - SinkRecord validRecord = TestUtils.createNativeJsonSinkRecord(noOfRecords, TOPIC, PARTITION); - noOfRecords++; - - // this response should get row sent to dlq - InsertValidationResponse failureResponse = new InsertValidationResponse(); + InsertValidationResponse validationResponse1 = new InsertValidationResponse(); InsertValidationResponse.InsertError insertError1 = - new InsertValidationResponse.InsertError("CONTENT", noOfRecords); + new InsertValidationResponse.InsertError("CONTENT", 0); insertError1.setException(SF_EXCEPTION); - failureResponse.addError(insertError1); - SinkRecord failureRecord = - TestUtils.createNativeJsonSinkRecord(noOfRecords, TOPIC, PARTITION); - noOfRecords++; + validationResponse1.addError(insertError1); - // this response should make schema evolve - InsertValidationResponse evolveSchemaResponse = new InsertValidationResponse(); + InsertValidationResponse validationResponse2 = new InsertValidationResponse(); InsertValidationResponse.InsertError insertError2 = - new InsertValidationResponse.InsertError("CONTENT", noOfRecords); + new InsertValidationResponse.InsertError("CONTENT", 0); insertError2.setException(SF_EXCEPTION); - 
insertError2.setExtraColNames(Collections.singletonList("gender")); // will evolve schema - evolveSchemaResponse.addError(insertError2); - SinkRecord evolveSchemaRecord = - TestUtils.createNativeJsonSinkRecord(noOfRecords, TOPIC, PARTITION); - noOfRecords++; - - this.sfConnectorConfig.put( - ERRORS_TOLERANCE_CONFIG, SnowflakeSinkConnectorConfig.ErrorTolerance.ALL.toString()); - this.sfConnectorConfig.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ"); + insertError2.setExtraColNames(Collections.singletonList("gender")); + validationResponse2.addError(insertError2); - // mocks Mockito.when( - this.mockStreamingChannel.insertRow( + mockStreamingChannel.insertRow( ArgumentMatchers.any(), ArgumentMatchers.any(String.class))) - .thenReturn(validResponse) - .thenReturn(failureResponse) - .thenReturn(evolveSchemaResponse); + .thenReturn(new InsertValidationResponse()) + .thenReturn(validationResponse1) + .thenReturn(validationResponse2); + + SnowflakeConnectionService conn = Mockito.mock(SnowflakeConnectionService.class); Mockito.when( - this.mockConn.hasSchemaEvolutionPermission( - ArgumentMatchers.any(), ArgumentMatchers.any())) + conn.hasSchemaEvolutionPermission(ArgumentMatchers.any(), ArgumentMatchers.any())) .thenReturn(true); Mockito.doNothing() - .when(this.mockConn) + .when(conn) .appendColumnsToTable(ArgumentMatchers.any(), ArgumentMatchers.any()); - Mockito.when(mockStreamingBufferThreshold.isFlushTimeBased(ArgumentMatchers.any(Long.class))) - .thenReturn(true); - // test + long bufferFlushTimeSeconds = 5L; + StreamingBufferThreshold bufferThreshold = + new StreamingBufferThreshold(bufferFlushTimeSeconds, 1_000 /* < 1KB */, 10000000L); + + Map sfConnectorConfigWithErrors = new HashMap<>(sfConnectorConfig); + sfConnectorConfigWithErrors.put( + ERRORS_TOLERANCE_CONFIG, SnowflakeSinkConnectorConfig.ErrorTolerance.ALL.toString()); + sfConnectorConfigWithErrors.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ"); + InMemoryKafkaRecordErrorReporter kafkaRecordErrorReporter = + new InMemoryKafkaRecordErrorReporter(); + TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, - this.sfConnectorConfig, - this.mockKafkaRecordErrorReporter, - this.mockSinkTaskContext, - this.mockConn); + bufferThreshold, + sfConnectorConfigWithErrors, + kafkaRecordErrorReporter, + mockSinkTaskContext, + conn); + + final int noOfRecords = 3; + List records = + TestUtils.createNativeJsonSinkRecords(0, noOfRecords, TOPIC, PARTITION); - topicPartitionChannel.insertRecordToBuffer(validRecord); - topicPartitionChannel.insertRecordToBuffer(failureRecord); - topicPartitionChannel.insertRecordToBuffer(evolveSchemaRecord); + records.forEach(topicPartitionChannel::insertRecordToBuffer); + + // In an ideal world, put API is going to invoke this to check if flush time threshold has + // reached. + // We are mimicking that call. + // Will wait for 10 seconds. 
+ Thread.sleep(bufferFlushTimeSeconds * 1000 + 10); topicPartitionChannel.insertBufferedRecordsIfFlushTimeThresholdReached(); - expectedCallOpenChannelCount++; - // Verify that the buffer is cleaned up and one record is in the DLQ (one error reported) + // Verify that the buffer is cleaned up and one record is in the DLQ Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); - Mockito.verify(this.mockKafkaRecordErrorReporter, Mockito.times(1)) - .reportError(Mockito.refEq(failureRecord), ArgumentMatchers.any(SFException.class)); - - Mockito.verify(this.mockStreamingChannel, Mockito.times(noOfRecords)) - .insertRow(ArgumentMatchers.any(), ArgumentMatchers.any(String.class)); - Mockito.verify(this.mockConn, Mockito.times(1)) - .hasSchemaEvolutionPermission(ArgumentMatchers.any(), ArgumentMatchers.any()); - Mockito.verify(this.mockConn, Mockito.times(1)) - .appendColumnsToTable(ArgumentMatchers.any(), ArgumentMatchers.any()); - Mockito.verify(mockStreamingBufferThreshold, Mockito.times(1)) - .isFlushTimeBased(ArgumentMatchers.any(Long.class)); - } else { - // not streaming means nothing is executed - this.expectedCallGetValidClientCount = 0; - this.expectedCallOpenChannelCount = 0; - this.expectedCallGetTaskIdCount = 0; + Assert.assertEquals(1, kafkaRecordErrorReporter.getReportedRecords().size()); } } - /* SFExceptions is thrown in first attempt of insert rows. It is also thrown while refetching - committed offset from snowflake after reopening the channel */ + /* SFExceptions is thrown in first attempt of insert rows. It is also thrown while refetching committed offset from snowflake after reopening the channel */ @Test(expected = SFException.class) public void testInsertRows_GetOffsetTokenFailureAfterReopenChannel() throws Exception { Mockito.when( @@ -622,28 +516,29 @@ public void testInsertRows_GetOffsetTokenFailureAfterReopenChannel() throws Exce TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); List records = TestUtils.createJsonStringSinkRecords(0, 1, TOPIC, PARTITION); try { - this.expectedCallOpenChannelCount++; // retry getting offset reopens channel TopicPartitionChannel.StreamingBuffer streamingBuffer = topicPartitionChannel.new StreamingBuffer(); streamingBuffer.insert(records.get(0)); topicPartitionChannel.insertBufferedRecords(streamingBuffer); } catch (SFException ex) { - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)) + Mockito.verify(mockStreamingClient, Mockito.times(2)).openChannel(ArgumentMatchers.any()); + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); // get offset token is called once after channel re-open - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) + .getLatestCommittedOffsetToken(); throw ex; } } @@ -659,14 +554,14 @@ public void testInsertRows_RuntimeException() throws Exception { TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - 
mockSinkTaskContext, - mockConn); + mockSinkTaskContext); List records = TestUtils.createJsonStringSinkRecords(0, 1, TOPIC, PARTITION); @@ -675,7 +570,8 @@ public void testInsertRows_RuntimeException() throws Exception { try { topicPartitionChannel.insertBufferedRecords(topicPartitionChannel.getStreamingBuffer()); } catch (RuntimeException ex) { - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)) + Mockito.verify(mockStreamingClient, Mockito.times(1)).openChannel(ArgumentMatchers.any()); + Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(1)) .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); throw ex; } @@ -696,14 +592,14 @@ public void testInsertRows_ValidationResponseHasErrors_NoErrorTolerance() throws TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + streamingBufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); List records = TestUtils.createJsonStringSinkRecords(0, 1, TOPIC, PARTITION); @@ -712,8 +608,6 @@ public void testInsertRows_ValidationResponseHasErrors_NoErrorTolerance() throws try { topicPartitionChannel.insertBufferedRecords(topicPartitionChannel.getStreamingBuffer()); } catch (DataException ex) { - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)) - .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); throw ex; } } @@ -727,23 +621,26 @@ public void testInsertRows_ValidationResponseHasErrors_ErrorTolerance_ALL() thro insertErrorWithException.setException(SF_EXCEPTION); validationResponse.addError(insertErrorWithException); Mockito.when( - this.mockStreamingChannel.insertRows( + mockStreamingChannel.insertRows( ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) .thenReturn(validationResponse); - this.sfConnectorConfig.put( + Map sfConnectorConfigWithErrors = new HashMap<>(sfConnectorConfig); + sfConnectorConfigWithErrors.put( ERRORS_TOLERANCE_CONFIG, SnowflakeSinkConnectorConfig.ErrorTolerance.ALL.toString()); - this.sfConnectorConfig.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ"); + sfConnectorConfigWithErrors.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ"); + InMemoryKafkaRecordErrorReporter kafkaRecordErrorReporter = + new InMemoryKafkaRecordErrorReporter(); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( - this.topicPartition, + mockStreamingClient, + topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - this.mockStreamingBufferThreshold, - this.sfConnectorConfig, - this.mockKafkaRecordErrorReporter, - this.mockSinkTaskContext, - this.mockConn); + new StreamingBufferThreshold(1000, 10_000_000, 10000), + sfConnectorConfigWithErrors, + kafkaRecordErrorReporter, + mockSinkTaskContext); List records = TestUtils.createJsonStringSinkRecords(0, 1, TOPIC, PARTITION); @@ -752,10 +649,8 @@ public void testInsertRows_ValidationResponseHasErrors_ErrorTolerance_ALL() thro streamingBuffer.insert(records.get(0)); assert topicPartitionChannel.insertBufferedRecords(streamingBuffer).hasErrors(); - Mockito.verify(this.mockKafkaRecordErrorReporter, Mockito.times(1)) - .reportError(Mockito.refEq(records.get(0)), ArgumentMatchers.any(SFException.class)); - Mockito.verify(this.mockStreamingChannel, Mockito.times(1)) - .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); + + assert 
kafkaRecordErrorReporter.getReportedRecords().size() == 1; } /* Valid response but has errors, error tolerance is ALL. Meaning it will ignore the error. */ @@ -772,21 +667,24 @@ public void testInsertRows_ValidationResponseHasErrors_ErrorTolerance_ALL_LogEna ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) .thenReturn(validationResponse); - this.sfConnectorConfig.put( + Map sfConnectorConfigWithErrors = new HashMap<>(sfConnectorConfig); + sfConnectorConfigWithErrors.put( ERRORS_TOLERANCE_CONFIG, SnowflakeSinkConnectorConfig.ErrorTolerance.ALL.toString()); - this.sfConnectorConfig.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ"); - this.sfConnectorConfig.put(ERRORS_LOG_ENABLE_CONFIG, "true"); + sfConnectorConfigWithErrors.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ"); + sfConnectorConfigWithErrors.put(ERRORS_LOG_ENABLE_CONFIG, "true"); + InMemoryKafkaRecordErrorReporter kafkaRecordErrorReporter = + new InMemoryKafkaRecordErrorReporter(); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( - this.topicPartition, + mockStreamingClient, + topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - this.mockStreamingBufferThreshold, - this.sfConnectorConfig, - this.mockKafkaRecordErrorReporter, - this.mockSinkTaskContext, - this.mockConn); + streamingBufferThreshold, + sfConnectorConfigWithErrors, + kafkaRecordErrorReporter, + mockSinkTaskContext); List records = TestUtils.createJsonStringSinkRecords(0, 1, TOPIC, PARTITION); @@ -795,126 +693,106 @@ public void testInsertRows_ValidationResponseHasErrors_ErrorTolerance_ALL_LogEna streamingBuffer.insert(records.get(0)); assert topicPartitionChannel.insertBufferedRecords(streamingBuffer).hasErrors(); - Mockito.verify(this.mockKafkaRecordErrorReporter, Mockito.times(1)) - .reportError(Mockito.refEq(records.get(0)), ArgumentMatchers.any(SFException.class)); - Mockito.verify(mockStreamingChannel, Mockito.times(1)) - .insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)); + + assert kafkaRecordErrorReporter.getReportedRecords().size() == 1; } // --------------- TEST THRESHOLDS --------------- - // insert 5 records, 4th will trigger the byte threshold, 5th will trigger time threshold @Test public void testBufferBytesThreshold() throws Exception { + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()) + .thenReturn(null) + .thenReturn("0") + .thenReturn("1"); + Mockito.when( mockStreamingChannel.insertRows( ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) .thenReturn(new InsertValidationResponse()); - Mockito.when( - mockStreamingBufferThreshold.isFlushBufferedBytesBased( - ArgumentMatchers.any(Long.class))) - .thenReturn(false) - .thenReturn(false) - .thenReturn(false) - .thenReturn(true) - .thenReturn(false); - Mockito.when(mockStreamingBufferThreshold.isFlushTimeBased(ArgumentMatchers.any(Long.class))) - .thenReturn(true); + + final long bufferFlushTimeSeconds = 5L; + StreamingBufferThreshold bufferThreshold = + new StreamingBufferThreshold(bufferFlushTimeSeconds, 800 /* < 1KB */, 10000000L); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + bufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); + // Sending 5 records will trigger a buffer bytes based threshold after 4 records have been + // added. 
Size of each record after serialization to Json is 260 Bytes List records = createNativeJsonSinkRecords(0, 5, "test", 0); - // insert 3 records, verify rows in buffer - topicPartitionChannel.insertRecordToBuffer(records.get(0)); - topicPartitionChannel.insertRecordToBuffer(records.get(1)); - topicPartitionChannel.insertRecordToBuffer(records.get(2)); - Assert.assertTrue(!topicPartitionChannel.isPartitionBufferEmpty()); - Mockito.verify(mockStreamingChannel, Mockito.times(0)) - .insertRows(ArgumentMatchers.any(), ArgumentMatchers.any()); - // insert 4th record, verify byte flush - no rows in buffer - topicPartitionChannel.insertRecordToBuffer(records.get(3)); - Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); - Mockito.verify(mockStreamingChannel, Mockito.times(1)) - .insertRows(ArgumentMatchers.any(), ArgumentMatchers.any()); + records.forEach(topicPartitionChannel::insertRecordToBuffer); - // insert 5th record - topicPartitionChannel.insertRecordToBuffer(records.get(4)); + Assert.assertEquals(0L, topicPartitionChannel.fetchOffsetTokenWithRetry()); + + // In an ideal world, put API is going to invoke this to check if flush time threshold has + // reached. + // We are mimicking that call. + // Will wait for 10 seconds. + Thread.sleep(bufferFlushTimeSeconds * 1000 + 10); - // flush on time buffer, verify time flush - no rows in buffer topicPartitionChannel.insertBufferedRecordsIfFlushTimeThresholdReached(); + Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); Mockito.verify(mockStreamingChannel, Mockito.times(2)) .insertRows(ArgumentMatchers.any(), ArgumentMatchers.any()); - - Mockito.verify(mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); - Mockito.verify(mockStreamingBufferThreshold, Mockito.times(1)) - .isFlushTimeBased(ArgumentMatchers.any(Long.class)); - Mockito.verify(mockStreamingBufferThreshold, Mockito.times(5)) - .isFlushBufferedBytesBased(ArgumentMatchers.any(Long.class)); } - // insert 3 records, 2nd will trigger the byte threshold, 3rd will trigger time threshold @Test public void testBigAvroBufferBytesThreshold() throws Exception { + Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()) + .thenReturn(null) + .thenReturn("1") + .thenReturn("2"); + Mockito.when( mockStreamingChannel.insertRows( ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class))) .thenReturn(new InsertValidationResponse()); - Mockito.when( - mockStreamingBufferThreshold.isFlushBufferedBytesBased( - ArgumentMatchers.any(Long.class))) - .thenReturn(false) - .thenReturn(true) - .thenReturn(false); - Mockito.when(mockStreamingBufferThreshold.isFlushTimeBased(ArgumentMatchers.any(Long.class))) - .thenReturn(true); + + final long bufferFlushTimeSeconds = 5L; + StreamingBufferThreshold bufferThreshold = + new StreamingBufferThreshold(bufferFlushTimeSeconds, 10_000 /* < 10 KB */, 10000000L); TopicPartitionChannel topicPartitionChannel = new TopicPartitionChannel( + mockStreamingClient, topicPartition, TEST_CHANNEL_NAME, TEST_TABLE_NAME, - mockStreamingBufferThreshold, + bufferThreshold, sfConnectorConfig, mockKafkaRecordErrorReporter, - mockSinkTaskContext, - mockConn); + mockSinkTaskContext); // Sending 3 records will trigger a buffer bytes based threshold after 2 records have been // added. 
Size of each record after serialization to Json is ~6 KBytes List records = createBigAvroRecords(0, 3, "test", 0); - // insert 1 record, verify row in buffer - topicPartitionChannel.insertRecordToBuffer(records.get(0)); - Assert.assertTrue(!topicPartitionChannel.isPartitionBufferEmpty()); - Mockito.verify(mockStreamingChannel, Mockito.times(0)) - .insertRows(ArgumentMatchers.any(), ArgumentMatchers.any()); - // insert 2nd record, verify byte flush - no rows in buffer - topicPartitionChannel.insertRecordToBuffer(records.get(1)); - Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); - Mockito.verify(mockStreamingChannel, Mockito.times(1)) - .insertRows(ArgumentMatchers.any(), ArgumentMatchers.any()); + records.forEach(topicPartitionChannel::insertRecordToBuffer); - // insert 3th record - topicPartitionChannel.insertRecordToBuffer(records.get(2)); + Assert.assertEquals(1L, topicPartitionChannel.fetchOffsetTokenWithRetry()); + + // In an ideal world, put API is going to invoke this to check if flush time threshold has + // reached. + // We are mimicking that call. + // Will wait for 10 seconds. + Thread.sleep(bufferFlushTimeSeconds * 1000 + 10); - // flush on time buffer, verify time flush - no rows in buffer topicPartitionChannel.insertBufferedRecordsIfFlushTimeThresholdReached(); + Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty()); + Mockito.verify(mockStreamingChannel, Mockito.times(2)) + .insertRows(ArgumentMatchers.any(), ArgumentMatchers.any()); - Mockito.verify(mockStreamingChannel, Mockito.times(1)).getLatestCommittedOffsetToken(); - Mockito.verify(mockStreamingBufferThreshold, Mockito.times(1)) - .isFlushTimeBased(ArgumentMatchers.any(Long.class)); - Mockito.verify(mockStreamingBufferThreshold, Mockito.times(3)) - .isFlushBufferedBytesBased(ArgumentMatchers.any(Long.class)); + Assert.assertEquals(2L, topicPartitionChannel.fetchOffsetTokenWithRetry()); } }
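
Note on the two threshold tests above: the partition buffer is flushed in two ways, a byte-based flush as soon as the buffered bytes cross the configured limit, and a time-based flush driven by a periodic insertBufferedRecordsIfFlushTimeThresholdReached() call. Below is a minimal, self-contained illustration of that double threshold, assuming a plain in-memory buffer; PartitionBuffer and its fields are hypothetical stand-ins, not the connector's StreamingBufferThreshold/StreamingBuffer classes.

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class BufferThresholdSketch {

  /** Hypothetical per-partition buffer flushed on a byte threshold or a flush-time threshold. */
  static class PartitionBuffer {
    private final long maxBufferedBytes;
    private final long flushTimeMillis;
    private final List<String> rows = new ArrayList<>();
    private long bufferedBytes = 0;
    private long lastFlushTime = System.currentTimeMillis();

    PartitionBuffer(long maxBufferedBytes, long flushTimeMillis) {
      this.maxBufferedBytes = maxBufferedBytes;
      this.flushTimeMillis = flushTimeMillis;
    }

    /** Add a record; flush immediately if the byte threshold is crossed. */
    void insert(String serializedRecord) {
      rows.add(serializedRecord);
      bufferedBytes += serializedRecord.getBytes(StandardCharsets.UTF_8).length;
      if (bufferedBytes >= maxBufferedBytes) {
        flush("bytes threshold");
      }
    }

    /** Called periodically (e.g. from put()); flushes if the flush time has elapsed. */
    void flushIfTimeThresholdReached() {
      if (!rows.isEmpty() && System.currentTimeMillis() - lastFlushTime >= flushTimeMillis) {
        flush("time threshold");
      }
    }

    private void flush(String reason) {
      // In the real connector this is where insertRows(...) would be invoked on the channel.
      System.out.printf("flushing %d rows (%d bytes) due to %s%n", rows.size(), bufferedBytes, reason);
      rows.clear();
      bufferedBytes = 0;
      lastFlushTime = System.currentTimeMillis();
    }
  }

  public static void main(String[] args) throws InterruptedException {
    // 800 bytes / 5 seconds, mirroring the values used in testBufferBytesThreshold.
    PartitionBuffer buffer = new PartitionBuffer(800, 5_000);
    for (int i = 0; i < 5; i++) {
      // ~260 bytes per serialized record (String.repeat requires Java 11+).
      buffer.insert("{\"id\":" + i + ",\"payload\":\"" + "x".repeat(240) + "\"}");
    }
    Thread.sleep(5_050);
    buffer.flushIfTimeThresholdReached(); // picks up the leftover fifth record
  }
}

With ~260-byte records and an 800-byte limit, the fourth insert triggers the byte-based flush and the fifth record is picked up by the later time-based flush, which is the sequence the tests assert via the two insertRows invocations.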
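
The fetchOffsetTokenWithRetry tests above distinguish three outcomes: SFException is retried with the channel reopened between attempts, a NumberFormatException on the offset token is wrapped and not retried, and other runtime exceptions propagate unchanged. The standalone sketch below illustrates only that retry-with-reopen pattern; OffsetChannel, RetriableChannelException, MAX_RETRIES and NO_OFFSET are hypothetical stand-ins, not the connector's or the Ingest SDK's real types.

import java.util.function.Supplier;

public class OffsetFetchSketch {

  /** Hypothetical channel abstraction: returns the last committed offset token as a string. */
  interface OffsetChannel {
    String getLatestCommittedOffsetToken();
  }

  /** Hypothetical marker for the only exception type worth retrying. */
  static class RetriableChannelException extends RuntimeException {
    RetriableChannelException(String msg) { super(msg); }
  }

  private static final int MAX_RETRIES = 3; // assumption: mirrors MAX_GET_OFFSET_TOKEN_RETRIES
  private static final long NO_OFFSET = -1L; // "nothing committed yet"

  private OffsetChannel channel;
  private final Supplier<OffsetChannel> reopenChannel; // re-creates the channel after a failure

  OffsetFetchSketch(OffsetChannel channel, Supplier<OffsetChannel> reopenChannel) {
    this.channel = channel;
    this.reopenChannel = reopenChannel;
  }

  long fetchOffsetTokenWithRetry() {
    RuntimeException lastRetriable = null;
    // One initial attempt plus MAX_RETRIES retries, reopening the channel between attempts.
    for (int attempt = 0; attempt <= MAX_RETRIES; attempt++) {
      try {
        String token = channel.getLatestCommittedOffsetToken();
        return token == null ? NO_OFFSET : Long.parseLong(token);
      } catch (NumberFormatException nfe) {
        // Not retriable: a malformed token is a data/configuration bug, surface it immediately.
        throw new IllegalStateException("Invalid offset token", nfe);
      } catch (RetriableChannelException re) {
        // Retriable: reopen the channel and try again.
        lastRetriable = re;
        channel = reopenChannel.get();
      }
      // Any other RuntimeException propagates unchanged (no retry, no fallback).
    }
    throw lastRetriable; // retries exhausted
  }
}

Under this sketch, three consecutive retriable failures followed by a valid "0" token return 0L, matching the shape of testFetchOffsetTokenWithRetry_validOffsetTokenAfterThreeSFExceptions.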
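
The schematization test above drives three insert outcomes: a clean insert, a failure routed to the dead-letter queue under errors.tolerance=all, and a failure whose InsertError carries an extra column name and therefore triggers schema evolution. The sketch below shows that decision flow in isolation; InsertOutcome, Table and the dlq consumer are hypothetical stand-ins, and the single-retry-after-evolution policy is an assumption rather than the connector's actual code path.

import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

public class SchemaEvolutionSketch {

  /** Hypothetical result of a single-row insert. */
  static class InsertOutcome {
    final boolean success;
    final List<String> extraColNames; // columns missing from the target table, may be empty
    InsertOutcome(boolean success, List<String> extraColNames) {
      this.success = success;
      this.extraColNames = extraColNames;
    }
  }

  /** Hypothetical table abstraction that can be extended with new columns. */
  interface Table {
    void appendColumns(List<String> columnNames);
  }

  /**
   * If the failure names extra columns, evolve the table and retry the row once;
   * otherwise hand the row to the DLQ reporter and keep going.
   */
  static boolean insertWithSchemaEvolution(
      Table table, Function<String, InsertOutcome> insertRow, Consumer<String> dlq, String row) {
    InsertOutcome outcome = insertRow.apply(row);
    if (outcome.success) {
      return true;
    }
    if (outcome.extraColNames != null && !outcome.extraColNames.isEmpty()) {
      table.appendColumns(outcome.extraColNames); // evolve the schema...
      return insertRow.apply(row).success;        // ...then retry the same row once
    }
    dlq.accept(row); // errors.tolerance=all: report to the DLQ instead of failing the task
    return false;
  }
}

Fed the three responses mocked in testInsertRowsWithSchemaEvolution, this flow yields roughly one successful insert, one DLQ report and one appendColumns call, which is the state the test's assertions check for.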