diff --git a/src/main/java/net/snowflake/ingest/streaming/SnowflakeStreamingIngestClient.java b/src/main/java/net/snowflake/ingest/streaming/SnowflakeStreamingIngestClient.java index 3054d9564..664e43528 100644 --- a/src/main/java/net/snowflake/ingest/streaming/SnowflakeStreamingIngestClient.java +++ b/src/main/java/net/snowflake/ingest/streaming/SnowflakeStreamingIngestClient.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.Map; +import net.snowflake.ingest.streaming.internal.DropChannelResponse; /** * A class that is the starting point for using the Streaming Ingest client APIs, a single client @@ -37,7 +38,7 @@ public interface SnowflakeStreamingIngestClient extends AutoCloseable { * * @param request the drop channel request */ - void dropChannel(DropChannelRequest request); + DropChannelResponse dropChannel(DropChannelRequest request); /** * Get the client name diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index e80c979fc..5b8ea2814 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -353,7 +353,7 @@ public SnowflakeStreamingIngestChannelInternal openChannel(OpenChannelRequest } @Override - public void dropChannel(DropChannelRequest request) { + public DropChannelResponse dropChannel(DropChannelRequest request) { if (isClosed) { throw new SFException(ErrorCode.CLOSED_CLIENT); } @@ -404,6 +404,7 @@ public void dropChannel(DropChannelRequest request) { request.getFullyQualifiedTableName(), request.getClientSequencer(), getName()); + return response; } catch (IOException | IngestResponseException e) { throw new SFException(e, ErrorCode.OPEN_CHANNEL_FAILURE, e.getMessage()); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java index ca20fa0cc..ca36626ee 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java @@ -7,6 +7,7 @@ import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; import static net.snowflake.ingest.utils.Constants.ROLE; import static net.snowflake.ingest.utils.Constants.USER; +import static org.mockito.ArgumentMatchers.argThat; import java.security.KeyPair; import java.security.PrivateKey; @@ -32,6 +33,7 @@ import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; +import net.snowflake.ingest.streaming.DropChannelRequest; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel; @@ -816,6 +818,46 @@ public void testClose() throws Exception { // Calling close again on closed channel shouldn't fail channel.close().get(); + Mockito.verify(client, Mockito.times(0)).dropChannel(Mockito.any()); + } + + @Test + public void testDropOnClose() throws Exception { + SnowflakeStreamingIngestClientInternal client = + Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client")); + SnowflakeStreamingIngestChannelInternal channel = + new SnowflakeStreamingIngestChannelInternal<>( + "channel", + "db", + "schema", + "table", + "0", + 1L, + 0L, + client, + "key", + 1234L, + OpenChannelRequest.OnErrorOption.CONTINUE, + UTC, + true); + ChannelsStatusResponse response = new ChannelsStatusResponse(); + response.setStatusCode(0L); + response.setMessage("Success"); + response.setChannels(new ArrayList<>()); + + Mockito.doReturn(response).when(client).getChannelsStatus(Mockito.any()); + + Assert.assertFalse(channel.isClosed()); + DropChannelResponse dropChannelResponse = new DropChannelResponse(); + Mockito.doReturn(dropChannelResponse).when(client).dropChannel(Mockito.any()); + channel.close().get(); + Assert.assertTrue(channel.isClosed()); + Mockito.verify(client, Mockito.times(1)) + .dropChannel( + argThat( + (DropChannelRequest req) -> + req.getChannelName().equals(channel.getName()) + && req.getClientSequencer().equals(channel.getChannelSequencer()))); } @Test diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index fc6421f0e..d280ec439 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.io.StringWriter; +import java.nio.charset.Charset; import java.security.KeyPair; import java.security.PrivateKey; import java.time.ZoneOffset; @@ -43,6 +44,7 @@ import net.snowflake.client.jdbc.internal.google.common.collect.Sets; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; +import net.snowflake.ingest.streaming.DropChannelRequest; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory; @@ -373,6 +375,47 @@ public void testGetChannelsStatusWithRequest() throws Exception { objectMapper.writeValueAsString(request), CHANNEL_STATUS_ENDPOINT, "channel status"); } + @Test + public void testDropChannel() throws Exception { + DropChannelResponse response = new DropChannelResponse(); + response.setStatusCode(RESPONSE_SUCCESS); + response.setMessage("dropped"); + String responseString = objectMapper.writeValueAsString(response); + + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); + StatusLine statusLine = Mockito.mock(StatusLine.class); + HttpEntity httpEntity = Mockito.mock(HttpEntity.class); + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()) + .thenReturn(IOUtils.toInputStream(responseString, Charset.defaultCharset())); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + + RequestBuilder requestBuilder = + Mockito.spy( + new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); + SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); + + DropChannelRequest request = + DropChannelRequest.builder("channel") + .setDBName("db") + .setTableName("table") + .setSchemaName("schema") + .build(); + DropChannelResponse result = client.dropChannel(request); + Assert.assertEquals(response.getMessage(), result.getMessage()); + } + @Test public void testGetChannelsStatusWithRequestError() throws Exception { ChannelsStatusResponse response = new ChannelsStatusResponse(); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java index 2318a7162..c77d6d455 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java @@ -1,8 +1,6 @@ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.utils.Constants.BLOB_NO_HEADER; -import static net.snowflake.ingest.utils.Constants.COMPRESS_BLOB_TWICE; -import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.*; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; @@ -189,6 +187,41 @@ public void testSimpleIngest() throws Exception { Assert.fail("Row sequencer not updated before timeout"); } + @Test + public void testDropChannel() throws Exception { + SnowflakeURL url = new SnowflakeURL(TestUtils.getAccountURL()); + RequestBuilder requestBuilder = + Mockito.spy( + new RequestBuilder( + url, + TestUtils.getUser(), + TestUtils.getKeyPair(), + HttpUtil.getHttpClient(url.getAccount()), + "testrequestbuilder")); + client.injectRequestBuilder(requestBuilder); + + OpenChannelRequest request1 = + OpenChannelRequest.builder("CHANNEL") + .setDBName(testDb) + .setSchemaName(TEST_SCHEMA) + .setTableName(TEST_TABLE) + .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE) + .setDropOnClose(true) + .build(); + + // Open a streaming ingest channel from the given client + SnowflakeStreamingIngestChannel channel1 = client.openChannel(request1); + // Close the channel after insertion + channel1.close().get(); + + // verify expected request sent to server + Mockito.verify(requestBuilder) + .generateStreamingIngestPostRequest( + ArgumentMatchers.contains("channel"), + ArgumentMatchers.refEq(DROP_CHANNEL_ENDPOINT), + ArgumentMatchers.refEq("drop_channel")); + } + @Test public void testParameterOverrides() throws Exception { Map parameterMap = new HashMap<>();