diff --git a/build.sbt b/build.sbt index 95324fc99..915aee126 100644 --- a/build.sbt +++ b/build.sbt @@ -63,6 +63,7 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-databind"), "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), + "software.amazon.awssdk" % "auth-crt" % "2.25.23", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptor.java new file mode 100644 index 000000000..1bb7fe4ad --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptor.java @@ -0,0 +1,228 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.auth; + +import com.amazonaws.auth.AWSSessionCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.services.glue.model.InvalidStateException; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.message.BasicHeader; +import org.apache.http.protocol.HttpContext; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; +import static org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor.nvpToMapParams; +import static org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor.skipHeader; + +/** + * Interceptor for signing AWS requests according to Signature Version 4A. + * This interceptor processes HTTP requests, signs them with AWS credentials, + * and updates the request headers to include the signature. + */ +public class AWSRequestSigV4ASigningApacheInterceptor implements HttpRequestInterceptor { + private static final String HTTPS_PROTOCOL = "https"; + private static final int HTTPS_PORT = 443; + + private final String service; + private final String region; + private final Signer signer; + private final AWSCredentialsProvider awsCredentialsProvider; + + /** + * Constructs an interceptor for AWS request signing with metadata access. + * + * @param service The AWS service name. + * @param region The AWS region for signing. + * @param signer The signer implementation. + * @param awsCredentialsProvider The credentials provider for metadata access. + */ + public AWSRequestSigV4ASigningApacheInterceptor(String service, String region, Signer signer, AWSCredentialsProvider awsCredentialsProvider) { + this.service = service; + this.region = region; + this.signer = signer; + this.awsCredentialsProvider = awsCredentialsProvider; + } + + /** + * Processes and signs an HTTP request, updating its headers with the signature. + * + * @param request the HTTP request to process and sign. + * @param context the HTTP context associated with the request. + * @throws IOException if an I/O error occurs during request processing. + */ + @Override + public void process(HttpRequest request, HttpContext context) throws IOException { + SdkHttpFullRequest requestToSign = buildSdkHttpRequest(request, context); + SdkHttpFullRequest signedRequest = signRequest(requestToSign); + updateRequestHeaders(request, signedRequest.headers()); + updateRequestEntity(request, signedRequest); + } + + /** + * Builds an {@link SdkHttpFullRequest} from the Apache {@link HttpRequest}. + * + * @param request the HTTP request to process and sign. + * @param context the HTTP context associated with the request. + * @return an SDK HTTP request ready to be signed. + * @throws IOException if an error occurs while building the request. + */ + private SdkHttpFullRequest buildSdkHttpRequest(HttpRequest request, HttpContext context) throws IOException { + URIBuilder uriBuilder = parseUri(request); + SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod())) + .protocol(HTTPS_PROTOCOL) + .port(HTTPS_PORT) + .headers(headerArrayToMap(request.getAllHeaders())) + .rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams())); + + HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST); + if (host == null) { + throw new InvalidStateException("Host must not be null"); + } + builder.host(host.getHostName()); + try { + builder.encodedPath(uriBuilder.build().getRawPath()); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI", e); + } + setRequestEntity(request, builder); + return builder.build(); + } + + /** + * Sets the request entity for the {@link SdkHttpFullRequest.Builder} if the original request contains an entity. + * This is used for requests that have a body, such as POST or PUT requests. + * + * @param request the original HTTP request. + * @param builder the SDK HTTP request builder. + */ + private void setRequestEntity(HttpRequest request, SdkHttpFullRequest.Builder builder) { + if (request instanceof HttpEntityEnclosingRequest) { + HttpEntity entity = ((HttpEntityEnclosingRequest) request).getEntity(); + if (entity != null) { + builder.contentStreamProvider(() -> { + try { + return entity.getContent(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } + } + } + + private URIBuilder parseUri(HttpRequest request) throws IOException { + try { + return new URIBuilder(request.getRequestLine().getUri()); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI", e); + } + } + + /** + * Signs the given SDK HTTP request using the provided AWS credentials and signer. + * + * @param request the SDK HTTP request to sign. + * @return a signed SDK HTTP request. + */ + private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) { + AWSSessionCredentials sessionCredentials = (AWSSessionCredentials) awsCredentialsProvider.getCredentials(); + AwsSessionCredentials awsCredentials = AwsSessionCredentials.create( + sessionCredentials.getAWSAccessKeyId(), + sessionCredentials.getAWSSecretKey(), + sessionCredentials.getSessionToken() + ); + + ExecutionAttributes executionAttributes = new ExecutionAttributes() + .putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, awsCredentials) + .putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, service) + .putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, Region.of(region)); + + return signer.sign(request, executionAttributes); + } + + /** + * Updates the HTTP request headers with the signed headers. + * + * @param request the original HTTP request. + * @param signedHeaders the headers after signing. + */ + private void updateRequestHeaders(HttpRequest request, Map> signedHeaders) { + Header[] headers = convertHeaderMapToArray(signedHeaders); + request.setHeaders(headers); + } + + /** + * Updates the request entity based on the signed request. This is used to update the request body after signing. + * + * @param request the original HTTP request. + * @param signedRequest the signed SDK HTTP request. + */ + private void updateRequestEntity(HttpRequest request, SdkHttpFullRequest signedRequest) { + if (request instanceof HttpEntityEnclosingRequest) { + HttpEntityEnclosingRequest httpEntityEnclosingRequest = (HttpEntityEnclosingRequest) request; + signedRequest.contentStreamProvider().ifPresent(provider -> { + InputStream contentStream = provider.newStream(); + BasicHttpEntity basicHttpEntity = new BasicHttpEntity(); + basicHttpEntity.setContent(contentStream); + signedRequest.firstMatchingHeader("Content-Length").ifPresent(value -> + basicHttpEntity.setContentLength(Long.parseLong(value))); + signedRequest.firstMatchingHeader("Content-Type").ifPresent(basicHttpEntity::setContentType); + httpEntityEnclosingRequest.setEntity(basicHttpEntity); + }); + } + } + + /** + * Converts an array of {@link Header} objects into a map, consolidating multiple values for the same header name. + * + * @param headers the array of {@link Header} objects to convert. + * @return a map where each key is a header name and each value is a list of header values. + */ + private static Map> headerArrayToMap(final Header[] headers) { + Map> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + for (Header header : headers) { + if (!skipHeader(header)) { + headersMap.computeIfAbsent(header.getName(), k -> new ArrayList<>()).add(header.getValue()); + } + } + return headersMap; + } + + /** + * Converts a map of headers back into an array of {@link Header} objects. + * + * @param mapHeaders the map of headers to convert. + * @return an array of {@link Header} objects. + */ + private Header[] convertHeaderMapToArray(final Map> mapHeaders) { + return mapHeaders.entrySet().stream() + .map(entry -> new BasicHeader(entry.getKey(), String.join(",", entry.getValue()))) + .toArray(Header[]::new); + } +} \ No newline at end of file diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java index c11677c3f..dd5fd78bc 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java @@ -126,7 +126,7 @@ public void process(final HttpRequest request, final HttpContext context) * @param params list of HTTP query params as NameValuePairs * @return a multimap of HTTP query params */ - private static Map> nvpToMapParams(final List params) { + static Map> nvpToMapParams(final List params) { Map> parameterMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); for (NameValuePair nvp : params) { List argsList = @@ -154,10 +154,11 @@ private static Map headerArrayToMap(final Header[] headers) { * @param header header line to check * @return true if the given header should be excluded when signing */ - private static boolean skipHeader(final Header header) { + static boolean skipHeader(final Header header) { return ("content-length".equalsIgnoreCase(header.getName()) - && "0".equals(header.getValue())) // Strip Content-Length: 0 - || "host".equalsIgnoreCase(header.getName()); // Host comes from endpoint + && "0".equals(header.getValue())) // Strip Content-Length: 0 + || "host".equalsIgnoreCase(header.getName()) // Host comes from endpoint + || "connection".equalsIgnoreCase(header.getName()); // Skip setting Connection manually } /** diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java index c3e65fef3..e71709ec2 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java @@ -1,12 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.auth; +import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.Signer; import org.apache.http.HttpException; import org.apache.http.HttpRequest; import org.apache.http.HttpRequestInterceptor; import org.apache.http.client.utils.URIBuilder; import org.apache.http.protocol.HttpContext; +import org.jetbrains.annotations.TestOnly; +import org.opensearch.common.Strings; +import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner; import java.io.IOException; import java.net.URISyntaxException; @@ -19,33 +27,45 @@ public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequ private final String service; private final String metadataAccessIdentifier; - final AWSRequestSigningApacheInterceptor primaryInterceptor; - final AWSRequestSigningApacheInterceptor metadataAccessInterceptor; + final HttpRequestInterceptor primaryInterceptor; + final HttpRequestInterceptor metadataAccessInterceptor; /** * Constructs an interceptor for AWS request signing with optional metadata access. * * @param service The AWS service name. - * @param signer The AWS request signer. + * @param region The AWS region for signing. * @param primaryCredentialsProvider The credentials provider for general access. * @param metadataAccessCredentialsProvider The credentials provider for metadata access. * @param metadataAccessIdentifier Identifier for operations requiring metadata access. */ public ResourceBasedAWSRequestSigningApacheInterceptor(final String service, - final Signer signer, + final String region, final AWSCredentialsProvider primaryCredentialsProvider, final AWSCredentialsProvider metadataAccessCredentialsProvider, final String metadataAccessIdentifier) { - this(service, - new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider), - new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider), - metadataAccessIdentifier); + if (Strings.isNullOrEmpty(service)) { + throw new IllegalArgumentException("Service name must not be null or empty."); + } + if (Strings.isNullOrEmpty(region)) { + throw new IllegalArgumentException("Region must not be null or empty."); + } + this.service = service; + this.metadataAccessIdentifier = metadataAccessIdentifier; + AWS4Signer signer = new AWS4Signer(); + signer.setServiceName(service); + signer.setRegionName(region); + this.primaryInterceptor = new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider); + this.metadataAccessInterceptor = primaryCredentialsProvider.equals(metadataAccessCredentialsProvider) + ? this.primaryInterceptor + : new AWSRequestSigV4ASigningApacheInterceptor(service, region, AwsCrtV4aSigner.builder().build(), metadataAccessCredentialsProvider); } // Test constructor allowing injection of mock interceptors + @TestOnly ResourceBasedAWSRequestSigningApacheInterceptor(final String service, - final AWSRequestSigningApacheInterceptor primaryInterceptor, - final AWSRequestSigningApacheInterceptor metadataAccessInterceptor, + final HttpRequestInterceptor primaryInterceptor, + final HttpRequestInterceptor metadataAccessInterceptor, final String metadataAccessIdentifier) { this.service = service == null ? "unknown" : service; this.primaryInterceptor = primaryInterceptor; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index b03ac0c6f..da90df8b5 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -7,7 +7,6 @@ import static org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; -import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import java.io.IOException; @@ -68,6 +67,9 @@ public class FlintOpenSearchClient implements FlintClient { private static final Logger LOG = Logger.getLogger(FlintOpenSearchClient.class.getName()); + private static final String SERVICE_NAME = "es"; + + /** * {@link NamedXContentRegistry} from {@link SearchModule} used for construct {@link QueryBuilder} from DSL query string. */ @@ -257,10 +259,6 @@ public IRestHighLevelClient createClient() { // SigV4 support if (options.getAuth().equals(FlintOptions.SIGV4_AUTH)) { - AWS4Signer signer = new AWS4Signer(); - signer.setServiceName("es"); - signer.setRegionName(options.getRegion()); - // Use DefaultAWSCredentialsProviderChain by default. final AtomicReference customAWSCredentialsProvider = new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); @@ -283,7 +281,7 @@ public IRestHighLevelClient createClient() { restClientBuilder.setHttpClientConfigCallback(builder -> { HttpAsyncClientBuilder delegate = builder.addInterceptorLast( new ResourceBasedAWSRequestSigningApacheInterceptor( - signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName())); + SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName())); return RetryableHttpAsyncClient.builder(delegate, options); } ); diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptorTest.java new file mode 100644 index 000000000..9045fa9c1 --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptorTest.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.auth; + +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.message.BasicHttpRequest; +import org.apache.http.protocol.HttpContext; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSSessionCredentials; + +import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class AWSRequestSigV4ASigningApacheInterceptorTest { + @Mock + private Signer mockSigner; + @Mock + private AWSCredentialsProvider mockCredentialsProvider; + @Mock + private HttpContext mockContext; + @Mock + private AWSSessionCredentials mockSessionCredentials; + @Mock + private SdkHttpFullRequest mockSdkHttpFullRequest; + @Captor + private ArgumentCaptor sdkHttpFullRequestCaptor; + + private AWSRequestSigV4ASigningApacheInterceptor interceptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + when(mockCredentialsProvider.getCredentials()).thenReturn(mockSessionCredentials); + when(mockSessionCredentials.getAWSAccessKeyId()).thenReturn("ACCESS_KEY_ID"); + when(mockSessionCredentials.getAWSSecretKey()).thenReturn("SECRET_ACCESS_KEY"); + when(mockSessionCredentials.getSessionToken()).thenReturn("SESSION_TOKEN"); + interceptor = new AWSRequestSigV4ASigningApacheInterceptor("s3", "us-west-2", mockSigner, mockCredentialsProvider); + when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(new HttpHost("localhost", 443, "https")); + when(mockSigner.sign(any(), any())).thenReturn(mockSdkHttpFullRequest); + } + + @Test + public void testSigningProcess() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + interceptor.process(request, mockContext); + + verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any()); + SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue(); + + assertEquals(SdkHttpMethod.GET, signedRequest.method()); + assertEquals("/path/to/resource", signedRequest.encodedPath()); + } + + @Test(expected = IOException.class) + public void testInvalidUriHandling() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", ":///this/is/not/a/valid/uri"); + interceptor.process(request, mockContext); + } + + @Test + public void testHeaderUpdateAfterSigning() throws Exception { + // Setup mock signer to return a new SdkHttpFullRequest with an "Authorization" header + when(mockSigner.sign(any(SdkHttpFullRequest.class), any())).thenAnswer(invocation -> { + SdkHttpFullRequest originalRequest = invocation.getArgument(0); + Map> modifiedHeaders = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + modifiedHeaders.putAll(originalRequest.headers()); + modifiedHeaders.put("Authorization", List.of("AWS4-ECDSA-P256-SHA256 Credential=...")); + + // Build a new SdkHttpFullRequest with the modified headers + return SdkHttpFullRequest.builder() + .method(originalRequest.method()) + .uri(originalRequest.getUri()) + .headers(modifiedHeaders) + .build(); + }); + + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + interceptor.process(request, mockContext); + + // Now verify that the HttpRequest has been updated with the new headers from the signed request + assertTrue("The request does not contain the expected 'Authorization' header", + request.containsHeader("Authorization")); + assertEquals("AWS4-ECDSA-P256-SHA256 Credential=...", + request.getFirstHeader("Authorization").getValue()); + } + + @Test + public void testSigningProcessWithCorrectHostFormat() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + + // Setup the interceptor with a mock HTTP context to return an HttpHost with scheme and port + HttpHost expectedHost = new HttpHost("localhost", 443, "https"); + when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(expectedHost); + + interceptor.process(request, mockContext); + + // Capture the SdkHttpFullRequest passed to the signer + verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any()); + SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue(); + + // Assert method and path + assertEquals(SdkHttpMethod.GET, signedRequest.method()); + assertEquals("/path/to/resource", signedRequest.encodedPath()); + + // Verify the host format is correct (hostname only, without scheme and port) + String expectedHostName = "localhost"; // Expected hostname without scheme and port + assertEquals("The host in the signed request should contain only the hostname, without scheme and port.", + expectedHostName, signedRequest.host()); + } + + @Test + public void testConnectionHeaderIsSkippedDuringSigning() throws Exception { + // Create a new HttpRequest with a Connection header + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + request.addHeader("Connection", "keep-alive"); + + interceptor.process(request, mockContext); + + // Verify that the SdkHttpFullRequest passed to the signer does not contain the Connection header + verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any()); + SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue(); + + // Assert that the signed request does not have the Connection header + assertFalse("The signed request should not contain a 'Connection' header.", + signedRequest.headers().containsKey("Connection")); + } +} \ No newline at end of file diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java index 0ef021b53..e3edf3f73 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java @@ -1,8 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.auth; import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSSessionCredentials; import com.amazonaws.auth.Signer; +import org.apache.http.HttpHost; import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; import org.apache.http.message.BasicHttpRequest; import org.apache.http.protocol.HttpContext; import org.junit.Before; @@ -11,19 +19,25 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; import static org.mockito.Mockito.*; public class ResourceBasedAWSRequestSigningApacheInterceptorTest { @Mock - private Signer mockSigner; + private Signer mockPrimarySigner; + @Mock + private software.amazon.awssdk.core.signer.Signer mockSigV4ASigner; @Mock private AWSCredentialsProvider mockPrimaryCredentialsProvider; @Mock private AWSCredentialsProvider mockMetadataAccessCredentialsProvider; @Mock private HttpContext mockContext; + @Mock + private SdkHttpFullRequest mockSdkHttpFullRequest; @Captor private ArgumentCaptor httpRequestCaptor; @@ -32,8 +46,20 @@ public class ResourceBasedAWSRequestSigningApacheInterceptorTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - AWSRequestSigningApacheInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockPrimaryCredentialsProvider)); - AWSRequestSigningApacheInterceptor metadataInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockMetadataAccessCredentialsProvider)); + + // Configure the mockMetadataAccessCredentialsProvider to return the mock credentials + AWSSessionCredentials mockSessionCredentials = mock(AWSSessionCredentials.class); + when(mockSessionCredentials.getAWSAccessKeyId()).thenReturn("accessKey"); + when(mockSessionCredentials.getAWSSecretKey()).thenReturn("secretKey"); + when(mockSessionCredentials.getSessionToken()).thenReturn("sessionToken"); + when(mockMetadataAccessCredentialsProvider.getCredentials()).thenReturn(mockSessionCredentials); + + HttpRequestInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockPrimarySigner, mockPrimaryCredentialsProvider)); + HttpRequestInterceptor metadataInterceptorSpy = spy(new AWSRequestSigV4ASigningApacheInterceptor("es", "us-east-1", mockSigV4ASigner, mockMetadataAccessCredentialsProvider)); + + // Configure the mockMetadataAccessCredentialsProvider to avoid NPEs + when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(new HttpHost("http://localhost")); + when(mockSigV4ASigner.sign(any(), any())).thenReturn(mockSdkHttpFullRequest); interceptor = new ResourceBasedAWSRequestSigningApacheInterceptor( "es", @@ -45,6 +71,7 @@ public void setUp() { @Test public void testProcessWithMetadataAccess() throws Exception { HttpRequest request = new BasicHttpRequest("GET", "/es/metadata/resource"); + request.addHeader("Content-Type", "application/json"); interceptor.process(request, mockContext);