
Commit

[FLINK-36427] Multiple shard IT cases for KDS source
elphastori authored and hlteoh37 committed Oct 7, 2024
1 parent b681cfc commit 91381cd
Showing 2 changed files with 80 additions and 10 deletions.
@@ -129,7 +129,10 @@ public RecordsWithSplitIds<Record> fetch() throws IOException {
.get(getRecordsResponse.records().size() - 1)
.sequenceNumber()));

assignedSplits.add(splitState);
if (!isComplete) {
assignedSplits.add(splitState);
}

return new KinesisRecordsWithSplitIds(
getRecordsResponse.records().iterator(), splitState.getSplitId(), isComplete);
}
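The hunk above touches the shard split reader's fetch() path: a split goes back onto the polling queue only while its shard is still open, while a split whose shard has been closed (for example by a reshard) is reported as finished and never polled again. A minimal sketch of that pattern, using simplified stand-in names (SplitPollingSketch, fetchOnce) rather than the connector's actual classes:

import java.util.ArrayDeque;
import java.util.Queue;

// Simplified stand-ins; the real reader works with KinesisShardSplitState and
// KinesisRecordsWithSplitIds.
final class SplitPollingSketch {
    // Splits that should be polled again on the next fetch cycle.
    private final Queue<String> assignedSplits = new ArrayDeque<>();

    // Returns the id of a finished split, or null if the split stays assigned.
    String fetchOnce(String splitId, boolean shardClosed) {
        if (!shardClosed) {
            // Shard still open: keep polling this split.
            assignedSplits.add(splitId);
        }
        // A closed shard is reported once as finished and never re-queued, so the
        // child shards created by a reshard can take over.
        return shardClosed ? splitId : null;
    }
}

The remaining hunks below are from the KinesisStreamsSourceITCase test class.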
@@ -8,6 +8,7 @@
import org.apache.flink.connector.aws.testutils.AWSServicesTestUtils;
import org.apache.flink.connector.aws.testutils.LocalstackContainer;
import org.apache.flink.connector.aws.util.AWSGeneralUtil;
import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.test.junit5.MiniClusterExtension;

@@ -35,10 +36,13 @@
import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry;
import software.amazon.awssdk.services.kinesis.model.ScalingType;
import software.amazon.awssdk.services.kinesis.model.StreamStatus;
import software.amazon.awssdk.services.kinesis.model.UpdateShardCountRequest;

import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -49,8 +53,12 @@
import static org.apache.flink.connector.aws.config.AWSConfigConstants.AWS_SECRET_ACCESS_KEY;
import static org.apache.flink.connector.aws.config.AWSConfigConstants.HTTP_PROTOCOL_VERSION;
import static org.apache.flink.connector.aws.config.AWSConfigConstants.TRUST_ALL_CERTIFICATES;
import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_INITIAL_POSITION;

/** IT cases for using {@code KinesisStreamsSource} using a localstack container. */
/**
* IT cases for using {@code KinesisStreamsSource} using a localstack container. StopWithSavepoint
* IT cases have not been added because it is not exposed in the new source API.
*/
@Testcontainers
@ExtendWith(MiniClusterExtension.class)
public class KinesisStreamsSourceITCase {
@@ -100,11 +108,33 @@ void nonExistentStreamShouldResultInFailure() {
}

@Test
void validStreamIsConsumed() throws Exception {
void singleShardStreamIsConsumed() throws Exception {
new Scenario()
.localstackStreamName("valid-stream")
.localstackStreamName("single-shard-stream")
.shardCount(1)
.withSourceConnectionStreamArn(
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/valid-stream")
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/single-shard-stream")
.runScenario();
}

@Test
void multipleShardStreamIsConsumed() throws Exception {
new Scenario()
.localstackStreamName("multiple-shard-stream")
.shardCount(4)
.withSourceConnectionStreamArn(
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/multiple-shard-stream")
.runScenario();
}

@Test
void reshardedStreamIsConsumed() throws Exception {
new Scenario()
.localstackStreamName("resharded-stream")
.shardCount(1)
.reshardStream(2)
.withSourceConnectionStreamArn(
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/resharded-stream")
.runScenario();
}

@@ -116,12 +146,17 @@ private Configuration getDefaultConfiguration() {
configuration.setString(AWS_REGION, Region.AP_SOUTHEAST_1.toString());
configuration.setString(TRUST_ALL_CERTIFICATES, "true");
configuration.setString(HTTP_PROTOCOL_VERSION, "HTTP1_1");
configuration.set(
STREAM_INITIAL_POSITION, KinesisSourceConfigOptions.InitialPosition.TRIM_HORIZON);
return configuration;
}

private class Scenario {
private final int expectedElements = 1000;
private String localstackStreamName = null;
private int shardCount = 1;
private boolean shouldReshardStream = false;
private int targetReshardCount = -1;
private String sourceConnectionStreamArn;
private final Configuration configuration =
KinesisStreamsSourceITCase.this.getDefaultConfiguration();
@@ -131,7 +166,7 @@ public void runScenario() throws Exception {
prepareStream(localstackStreamName);
}

putRecords(localstackStreamName, expectedElements);
putRecordsWithReshard(localstackStreamName, expectedElements);

KinesisStreamsSource<String> kdsSource =
KinesisStreamsSource.<String>builder()
@@ -158,6 +193,17 @@ public Scenario localstackStreamName(String localstackStreamName) {
return this;
}

public Scenario shardCount(int shardCount) {
this.shardCount = shardCount;
return this;
}

public Scenario reshardStream(int targetShardCount) {
this.shouldReshardStream = true;
this.targetReshardCount = targetShardCount;
return this;
}

private void prepareStream(String streamName) throws Exception {
final RateLimiter rateLimiter =
RateLimiterBuilder.newBuilder()
@@ -166,7 +212,10 @@ private void prepareStream(String streamName) throws Exception {
.build();

kinesisClient.createStream(
CreateStreamRequest.builder().streamName(streamName).shardCount(1).build());
CreateStreamRequest.builder()
.streamName(streamName)
.shardCount(shardCount)
.build());

Deadline deadline = Deadline.fromNow(Duration.ofMinutes(1));
while (!rateLimiter.getWhenReady(() -> streamExists(streamName))) {
@@ -176,9 +225,18 @@ private void putRecords(String streamName, int numRecords) {
}
}

private void putRecords(String streamName, int numRecords) {
private void putRecordsWithReshard(String name, int numRecords) {
int midpoint = numRecords / 2;
putRecords(name, 0, midpoint);
if (shouldReshardStream) {
reshard(name);
}
putRecords(name, midpoint, numRecords);
}

private void putRecords(String streamName, int startInclusive, int endInclusive) {
List<byte[]> messages =
IntStream.range(0, numRecords)
IntStream.range(startInclusive, endInclusive)
.mapToObj(String::valueOf)
.map(String::getBytes)
.collect(Collectors.toList());
@@ -189,7 +247,7 @@ private void putRecords(String streamName, int numRecords) {
.map(
msg ->
PutRecordsRequestEntry.builder()
.partitionKey("fakePartitionKey")
.partitionKey(UUID.randomUUID().toString())
.data(SdkBytes.fromByteArray(msg))
.build())
.collect(Collectors.toList());
@@ -202,6 +260,15 @@ private void putRecords(String streamName, int numRecords) {
}
}

private void reshard(String streamName) {
kinesisClient.updateShardCount(
UpdateShardCountRequest.builder()
.streamName(streamName)
.targetShardCount(targetReshardCount)
.scalingType(ScalingType.UNIFORM_SCALING)
.build());
}

private boolean streamExists(final String streamName) {
try {
return kinesisClient
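For orientation, the source constructed in runScenario() is truncated in the hunk above; a minimal, self-contained job using the same KinesisStreamsSource builder pattern is sketched below. The builder method names (setStreamArn, setSourceConfig, setDeserializationSchema) are assumed from the connector's documented API and should be verified against the version under test, and the stream ARN is a placeholder.

import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.kinesis.source.KinesisStreamsSource;
import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class KinesisStreamsSourceSketch {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // Same initial position the IT case sets, so records written before the
        // job starts are still consumed.
        Configuration sourceConfig = new Configuration();
        sourceConfig.set(
                KinesisSourceConfigOptions.STREAM_INITIAL_POSITION,
                KinesisSourceConfigOptions.InitialPosition.TRIM_HORIZON);

        KinesisStreamsSource<String> source =
                KinesisStreamsSource.<String>builder()
                        // Placeholder ARN; point this at a real (or localstack) stream.
                        .setStreamArn("arn:aws:kinesis:ap-southeast-1:000000000000:stream/example-stream")
                        .setSourceConfig(sourceConfig)
                        .setDeserializationSchema(new SimpleStringSchema())
                        .build();

        env.fromSource(source, WatermarkStrategy.noWatermarks(), "Kinesis source")
                .print();

        env.execute("KinesisStreamsSource sketch");
    }
}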
