diff --git a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java index a3f36c6d0d..d2e97ababb 100644 --- a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java +++ b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java @@ -6,6 +6,9 @@ package org.opensearch.dataprepper.aws.api; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.Optional; /** * An interface available to plugins via the AWS Plugin Extension which supplies @@ -19,4 +22,10 @@ public interface AwsCredentialsSupplier { * @return An {@link AwsCredentialsProvider} to use. */ AwsCredentialsProvider getProvider(AwsCredentialsOptions options); + + /** + * Gets the default region if it is configured. Otherwise returns null + * @return Default {@link Region} + */ + Optional getDefaultRegion(); } diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java index 54a7b1c2c9..44d2d22931 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java @@ -5,6 +5,7 @@ package org.opensearch.dataprepper.plugins.aws; +import org.opensearch.dataprepper.model.annotations.DataPrepperExtensionPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; import org.opensearch.dataprepper.model.plugin.ExtensionPlugin; import org.opensearch.dataprepper.model.plugin.ExtensionPoints; @@ -13,12 +14,18 @@ * The {@link ExtensionPlugin} class which adds the AWS Plugin to * Data Prepper as an extension plugin. Everything starts from here. */ +@DataPrepperExtensionPlugin(modelType = AwsPluginConfig.class, rootKeyJsonPath = "/aws/configurations") public class AwsPlugin implements ExtensionPlugin { private final DefaultAwsCredentialsSupplier defaultAwsCredentialsSupplier; + private final AwsPluginConfig awsPluginConfig; + @DataPrepperPluginConstructor - public AwsPlugin() { - final CredentialsProviderFactory credentialsProviderFactory = new CredentialsProviderFactory(); + public AwsPlugin(final AwsPluginConfig awsPluginConfig) { + + this.awsPluginConfig = awsPluginConfig; + + final CredentialsProviderFactory credentialsProviderFactory = new CredentialsProviderFactory(awsPluginConfig != null ? awsPluginConfig.getDefaultStsConfiguration() : new AwsStsConfiguration()); final CredentialsCache credentialsCache = new CredentialsCache(); defaultAwsCredentialsSupplier = new DefaultAwsCredentialsSupplier(credentialsProviderFactory, credentialsCache); } diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPluginConfig.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPluginConfig.java new file mode 100644 index 0000000000..3ca5fce020 --- /dev/null +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPluginConfig.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.aws; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class AwsPluginConfig { + + @JsonProperty("default") + private AwsStsConfiguration defaultStsConfiguration = new AwsStsConfiguration(); + + public AwsStsConfiguration getDefaultStsConfiguration() { + return defaultStsConfiguration; + } +} diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsPluginConfigPublisher.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsPluginConfigPublisher.java index 1714dafd63..b19802a6d9 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsPluginConfigPublisher.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsPluginConfigPublisher.java @@ -5,8 +5,8 @@ package org.opensearch.dataprepper.plugins.aws; -import org.opensearch.dataprepper.model.plugin.PluginConfigPublisher; import org.opensearch.dataprepper.model.plugin.PluginConfigObservable; +import org.opensearch.dataprepper.model.plugin.PluginConfigPublisher; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java new file mode 100644 index 0000000000..ee22244fa7 --- /dev/null +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.aws; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.Size; +import software.amazon.awssdk.regions.Region; + +public class AwsStsConfiguration { + + @JsonProperty("region") + @Size(min = 1, message = "Region cannot be empty string") + private String awsRegion; + + @JsonProperty("sts_role_arn") + @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters") + private String awsStsRoleArn; + + public Region getAwsRegion() { + return awsRegion != null ? Region.of(awsRegion) : null; + } + + public String getAwsStsRoleArn() { + return awsStsRoleArn; + } +} diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java index 26d0caf450..222051beab 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java @@ -36,10 +36,21 @@ class CredentialsProviderFactory { static final long STS_CLIENT_BASE_BACKOFF_MILLIS = 1000L; static final long STS_CLIENT_MAX_BACKOFF_MILLIS = 60000L; + private final AwsStsConfiguration defaultStsConfiguration; + + public CredentialsProviderFactory(final AwsStsConfiguration defaultStsConfiguration) { + Objects.requireNonNull(defaultStsConfiguration); + this.defaultStsConfiguration = defaultStsConfiguration; + } + + Region getDefaultRegion() { + return defaultStsConfiguration.getAwsRegion(); + } + AwsCredentialsProvider providerFromOptions(final AwsCredentialsOptions credentialsOptions) { Objects.requireNonNull(credentialsOptions); - if(credentialsOptions.getStsRoleArn() != null) { + if(credentialsOptions.getStsRoleArn() != null || defaultStsConfiguration.getAwsStsRoleArn() != null) { return createStsCredentials(credentialsOptions); } @@ -48,13 +59,15 @@ AwsCredentialsProvider providerFromOptions(final AwsCredentialsOptions credentia private AwsCredentialsProvider createStsCredentials(final AwsCredentialsOptions credentialsOptions) { - final String stsRoleArn = credentialsOptions.getStsRoleArn(); + final String stsRoleArn = credentialsOptions.getStsRoleArn() == null ? defaultStsConfiguration.getAwsStsRoleArn() : credentialsOptions.getStsRoleArn(); validateStsRoleArn(stsRoleArn); LOG.debug("Creating new AwsCredentialsProvider with role {}.", stsRoleArn); - final StsClient stsClient = createStsClient(credentialsOptions.getRegion()); + final Region region = credentialsOptions.getRegion() == null ? defaultStsConfiguration.getAwsRegion() : credentialsOptions.getRegion(); + + final StsClient stsClient = createStsClient(region); AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder() .roleSessionName("Data-Prepper-" + UUID.randomUUID()) diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java index 3739446336..d6a647706a 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java @@ -8,6 +8,9 @@ import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.Optional; class DefaultAwsCredentialsSupplier implements AwsCredentialsSupplier { private final CredentialsProviderFactory credentialsProviderFactory; @@ -22,4 +25,9 @@ class DefaultAwsCredentialsSupplier implements AwsCredentialsSupplier { public AwsCredentialsProvider getProvider(final AwsCredentialsOptions options) { return credentialsCache.getOrCreate(options, () -> credentialsProviderFactory.providerFromOptions(options)); } + + @Override + public Optional getDefaultRegion() { + return Optional.ofNullable(credentialsProviderFactory.getDefaultRegion()); + } } diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginConfigTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginConfigTest.java new file mode 100644 index 0000000000..5e9322d324 --- /dev/null +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginConfigTest.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.aws; + +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class AwsPluginConfigTest { + + @Test + void testDefault() { + final AwsPluginConfig objectUnderTest = new AwsPluginConfig(); + + assertThat(objectUnderTest, notNullValue()); + assertThat(objectUnderTest.getDefaultStsConfiguration(), notNullValue()); + assertThat(objectUnderTest.getDefaultStsConfiguration().getAwsRegion(), nullValue()); + assertThat(objectUnderTest.getDefaultStsConfiguration().getAwsStsRoleArn(), nullValue()); + } +} diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginIT.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginIT.java index 66bef11587..a1e81198c6 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginIT.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginIT.java @@ -5,6 +5,7 @@ package org.opensearch.dataprepper.plugins.aws; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; @@ -28,17 +29,29 @@ import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.MatcherAssert.assertThat; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class AwsPluginIT { + @Mock + private AwsPluginConfig awsPluginConfig; + @Mock private ExtensionPoints extensionPoints; @Mock private ExtensionProvider.Context context; + @Mock + private AwsStsConfiguration awsDefaultStsConfiguration; + + @BeforeEach + void setup() { + when(awsPluginConfig.getDefaultStsConfiguration()).thenReturn(awsDefaultStsConfiguration); + } + private AwsPlugin createObjectUnderTest() { - return new AwsPlugin(); + return new AwsPlugin(awsPluginConfig); } @Test @@ -78,6 +91,8 @@ void test_AwsPlugin_with_STS_role() { @Test void test_AwsPlugin_without_STS_role() { + when(awsDefaultStsConfiguration.getAwsStsRoleArn()).thenReturn(null); + createObjectUnderTest().apply(extensionPoints); final ArgumentCaptor> extensionProviderArgumentCaptor = ArgumentCaptor.forClass(ExtensionProvider.class); @@ -108,6 +123,40 @@ void test_AwsPlugin_without_STS_role() { assertThat(awsCredentialsProvider2, sameInstance(awsCredentialsProvider1)); } + @Test + void test_AwsPlugin_without_STS_role_and_with_default_role_uses_default_role() { + when(awsDefaultStsConfiguration.getAwsStsRoleArn()).thenReturn(createStsRole()); + + createObjectUnderTest().apply(extensionPoints); + + final ArgumentCaptor> extensionProviderArgumentCaptor = ArgumentCaptor.forClass(ExtensionProvider.class); + verify(extensionPoints).addExtensionProvider(extensionProviderArgumentCaptor.capture()); + + final ExtensionProvider extensionProvider = extensionProviderArgumentCaptor.getValue(); + + final Optional optionalSupplier = extensionProvider.provideInstance(context); + assertThat(optionalSupplier, notNullValue()); + assertThat(optionalSupplier.isPresent(), equalTo(true)); + + final AwsCredentialsSupplier awsCredentialsSupplier = optionalSupplier.get(); + + final AwsCredentialsOptions awsCredentialsOptions1 = AwsCredentialsOptions.builder() + .withRegion(Region.US_EAST_1) + .build(); + + final AwsCredentialsProvider awsCredentialsProvider1 = awsCredentialsSupplier.getProvider(awsCredentialsOptions1); + + assertThat(awsCredentialsProvider1, instanceOf(StsAssumeRoleCredentialsProvider.class)); + + final AwsCredentialsOptions awsCredentialsOptions2 = AwsCredentialsOptions.builder() + .withRegion(Region.US_EAST_1) + .build(); + + final AwsCredentialsProvider awsCredentialsProvider2 = awsCredentialsSupplier.getProvider(awsCredentialsOptions2); + + assertThat(awsCredentialsProvider2, sameInstance(awsCredentialsProvider1)); + } + private String createStsRole() { return String.format("arn:aws:iam::123456789012:role/%s", UUID.randomUUID()); } diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginTest.java index e643d32c7c..2bbaf3740f 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsPluginTest.java @@ -16,18 +16,25 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class AwsPluginTest { + + @Mock + private AwsPluginConfig awsPluginConfig; + @Mock private ExtensionPoints extensionPoints; private AwsPlugin createObjectUnderTest() { - return new AwsPlugin(); + return new AwsPlugin(awsPluginConfig); } @Test void apply_should_addExtensionProvider() { + when(awsPluginConfig.getDefaultStsConfiguration()).thenReturn(new AwsStsConfiguration()); + createObjectUnderTest().apply(extensionPoints); final ArgumentCaptor extensionProviderArgumentCaptor = @@ -39,4 +46,20 @@ void apply_should_addExtensionProvider() { assertThat(actualExtensionProvider, instanceOf(AwsExtensionProvider.class)); } + + @Test + void null_aws_plugin_config_applies_extensions_correctly() { + final AwsPlugin objectUnderTest = new AwsPlugin(null); + + objectUnderTest.apply(extensionPoints); + + final ArgumentCaptor extensionProviderArgumentCaptor = + ArgumentCaptor.forClass(ExtensionProvider.class); + + verify(extensionPoints).addExtensionProvider(extensionProviderArgumentCaptor.capture()); + + final ExtensionProvider actualExtensionProvider = extensionProviderArgumentCaptor.getValue(); + + assertThat(actualExtensionProvider, instanceOf(AwsExtensionProvider.class)); + } } \ No newline at end of file diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java new file mode 100644 index 0000000000..f44e4dd932 --- /dev/null +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.aws; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.regions.Region; + +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; + +public class AwsStsConfigurationTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + @ParameterizedTest + @MethodSource("getRegions") + void testStsConfiguration(final Region region) throws JsonProcessingException { + + final String defaultConfigurationAsString = "{\"region\": \"" + region.toString() + "\", \"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\"}"; + + final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(defaultConfigurationAsString, AwsStsConfiguration.class); + + assertThat(objectUnderTest, notNullValue()); + assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role")); + assertThat(objectUnderTest.getAwsRegion(), equalTo(region)); + } + + private static List getRegions() { + return Region.regions(); + } +} diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java index 2189f20fd8..d8676cabb0 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -31,6 +32,7 @@ import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.UUID; @@ -58,8 +60,11 @@ class CredentialsProviderFactoryTest { @Mock private AwsCredentialsOptions awsCredentialsOptions; + @Mock + private AwsStsConfiguration defaultStsConfiguration; + private CredentialsProviderFactory createObjectUnderTest() { - return new CredentialsProviderFactory(); + return new CredentialsProviderFactory(defaultStsConfiguration); } @Test @@ -99,6 +104,22 @@ void providerFromOptions_with_StsRoleArn() { assertThat(awsCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class)); } + @ParameterizedTest + @MethodSource("getRegions") + void getDefaultRegion_returns_expected_region(final Region region) { + when(defaultStsConfiguration.getAwsRegion()).thenReturn(region); + + final CredentialsProviderFactory credentialsProviderFactory = createObjectUnderTest(); + + final Region actualRegion = credentialsProviderFactory.getDefaultRegion(); + + assertThat(actualRegion, equalTo(region)); + } + + private static List getRegions() { + return Region.regions(); + } + @Nested class WithSts { diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java index 37a2c98a6b..ca62dc24ac 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java @@ -7,12 +7,18 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import java.util.List; +import java.util.Optional; import java.util.function.Supplier; import static org.hamcrest.CoreMatchers.equalTo; @@ -62,4 +68,27 @@ void getProvider_calls_getOrCreate_with_Supplier() { when(credentialsProviderFactory.providerFromOptions(options)).thenReturn(awsCredentialsProvider); assertThat(actualCredentialsSupplier.get(), equalTo(awsCredentialsProvider)); } + + @ParameterizedTest + @MethodSource("getRegions") + void getDefaultRegion_returns_default_region(final Region region) { + when(credentialsProviderFactory.getDefaultRegion()).thenReturn(region); + + final AwsCredentialsSupplier objectUnderTest = createObjectUnderTest(); + assertThat(objectUnderTest.getDefaultRegion().isPresent(), equalTo(true)); + assertThat(objectUnderTest.getDefaultRegion().get(), equalTo(credentialsProviderFactory.getDefaultRegion())); + } + + @Test + void no_default_region_returns_empty_optional() { + when(credentialsProviderFactory.getDefaultRegion()).thenReturn(null); + + final AwsCredentialsSupplier objectUnderTest = createObjectUnderTest(); + assertThat(objectUnderTest.getDefaultRegion(), equalTo(Optional.empty())); + + } + + private static List getRegions() { + return Region.regions(); + } } \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer_KmsIT.java b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer_KmsIT.java index 047e896c7f..464d7966e3 100644 --- a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer_KmsIT.java +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer_KmsIT.java @@ -15,6 +15,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.model.CheckpointState; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; @@ -72,6 +74,9 @@ public class KafkaBuffer_KmsIT { @Mock private AcknowledgementSet acknowledgementSet; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + private Random random; private BufferTopicConfig topicConfig; @@ -121,10 +126,12 @@ void setUp() { kmsClient = KmsClient.create(); byteDecoder = null; + + when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenAnswer(options -> DefaultCredentialsProvider.create()); } private KafkaBuffer createObjectUnderTest() { - return new KafkaBuffer(pluginSetting, kafkaBufferConfig, acknowledgementSetManager, null, ignored -> DefaultCredentialsProvider.create(), null); + return new KafkaBuffer(pluginSetting, kafkaBufferConfig, acknowledgementSetManager, null, awsCredentialsSupplier, null); } @Nested