diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java index c2f84537e..1f7267b08 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java @@ -335,6 +335,7 @@ public void restartFrom(RecordsRetrieved recordsRetrieved) { @Override public void subscribe(Subscriber s) { + throwOnIllegalState(); subscriber = s; subscriber.onSubscribe(new Subscription() { @Override diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java index 8d88151b1..278c3fdf5 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java @@ -31,8 +31,10 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -64,6 +66,7 @@ import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; @@ -377,13 +380,27 @@ record = Record.builder().data(createByteBufferWithSize(1024)).build(); @Test(expected = IllegalStateException.class) public void testGetNextRecordsWithoutStarting() { verify(executorService, never()).execute(any()); - getRecordsCache.drainQueueForRequests(); + Subscriber mockSubscriber = mock(Subscriber.class); + getRecordsCache.subscribe(mockSubscriber); } @Test(expected = IllegalStateException.class) public void testCallAfterShutdown() { + GetRecordsResponse response = GetRecordsResponse.builder().records( + Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber("123").build()) + .nextShardIterator(NEXT_SHARD_ITERATOR).build(); + when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(response); + + getRecordsCache.start(sequenceNumber, initialPosition); + + verify(getRecordsRetrievalStrategy, timeout(100).atLeastOnce()).getRecords(anyInt()); + when(executorService.isShutdown()).thenReturn(true); - getRecordsCache.drainQueueForRequests(); + Subscriber mockSubscriber = mock(Subscriber.class); + getRecordsCache.subscribe(mockSubscriber); + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + verify(mockSubscriber).onSubscribe(subscriptionCaptor.capture()); + subscriptionCaptor.getValue().request(1); } @Test